repos / pico

pico services - prose.sh, pastes.sh, imgs.sh, feeds.sh, pgs.sh
git clone https://github.com/picosh/pico.git

pico / pico
Antonio Mika · 08 Oct 24

file_handler.go

  1package pico
  2
  3import (
  4	"bytes"
  5	"errors"
  6	"fmt"
  7	"io"
  8	"log/slog"
  9	"os"
 10	"path/filepath"
 11	"strings"
 12	"time"
 13
 14	"github.com/charmbracelet/ssh"
 15	"github.com/charmbracelet/wish"
 16	"github.com/picosh/pico/db"
 17	"github.com/picosh/pico/shared"
 18	sendutils "github.com/picosh/send/utils"
 19	"github.com/picosh/utils"
 20)
 21
 22type UploadHandler struct {
 23	DBPool db.DB
 24	Cfg    *shared.ConfigSite
 25}
 26
 27func NewUploadHandler(dbpool db.DB, cfg *shared.ConfigSite) *UploadHandler {
 28	return &UploadHandler{
 29		DBPool: dbpool,
 30		Cfg:    cfg,
 31	}
 32}
 33
 34func (h *UploadHandler) getAuthorizedKeyFile(user *db.User) (*sendutils.VirtualFile, string, error) {
 35	keys, err := h.DBPool.FindKeysForUser(user)
 36	text := ""
 37	var modTime time.Time
 38	for _, pk := range keys {
 39		text += fmt.Sprintf("%s %s\n", pk.Key, pk.Name)
 40		modTime = *pk.CreatedAt
 41	}
 42	if err != nil {
 43		return nil, "", err
 44	}
 45	fileInfo := &sendutils.VirtualFile{
 46		FName:    "authorized_keys",
 47		FIsDir:   false,
 48		FSize:    int64(len(text)),
 49		FModTime: modTime,
 50	}
 51	return fileInfo, text, nil
 52}
 53
 54func (h *UploadHandler) Delete(s ssh.Session, entry *sendutils.FileEntry) error {
 55	return errors.New("unsupported")
 56}
 57
 58func (h *UploadHandler) Read(s ssh.Session, entry *sendutils.FileEntry) (os.FileInfo, sendutils.ReaderAtCloser, error) {
 59	user, err := shared.GetUser(s.Context())
 60	if err != nil {
 61		return nil, nil, err
 62	}
 63	cleanFilename := filepath.Base(entry.Filepath)
 64
 65	if cleanFilename == "" || cleanFilename == "." {
 66		return nil, nil, os.ErrNotExist
 67	}
 68
 69	if cleanFilename == "authorized_keys" {
 70		fileInfo, text, err := h.getAuthorizedKeyFile(user)
 71		if err != nil {
 72			return nil, nil, err
 73		}
 74		reader := sendutils.NopReaderAtCloser(strings.NewReader(text))
 75		return fileInfo, reader, nil
 76	}
 77
 78	return nil, nil, os.ErrNotExist
 79}
 80
 81func (h *UploadHandler) List(s ssh.Session, fpath string, isDir bool, recursive bool) ([]os.FileInfo, error) {
 82	var fileList []os.FileInfo
 83	user, err := shared.GetUser(s.Context())
 84	if err != nil {
 85		return fileList, err
 86	}
 87	cleanFilename := filepath.Base(fpath)
 88
 89	if cleanFilename == "" || cleanFilename == "." || cleanFilename == "/" {
 90		name := cleanFilename
 91		if name == "" {
 92			name = "/"
 93		}
 94
 95		fileList = append(fileList, &sendutils.VirtualFile{
 96			FName:  name,
 97			FIsDir: true,
 98		})
 99
100		flist, _, err := h.getAuthorizedKeyFile(user)
101		if err != nil {
102			return fileList, err
103		}
104		fileList = append(fileList, flist)
105	} else {
106		if cleanFilename == "authorized_keys" {
107			flist, _, err := h.getAuthorizedKeyFile(user)
108			if err != nil {
109				return fileList, err
110			}
111			fileList = append(fileList, flist)
112		}
113	}
114
115	return fileList, nil
116}
117
118func (h *UploadHandler) GetLogger() *slog.Logger {
119	return h.Cfg.Logger
120}
121
122func (h *UploadHandler) Validate(s ssh.Session) error {
123	var err error
124	key, err := sendutils.KeyText(s)
125	if err != nil {
126		return fmt.Errorf("key not found")
127	}
128
129	user, err := h.DBPool.FindUserForKey(s.User(), key)
130	if err != nil {
131		return err
132	}
133
134	if user.Name == "" {
135		return fmt.Errorf("must have username set")
136	}
137
138	shared.SetUser(s.Context(), user)
139	return nil
140}
141
142type KeyWithId struct {
143	Pk      ssh.PublicKey
144	ID      string
145	Comment string
146}
147
148type KeyDiffResult struct {
149	Add    []KeyWithId
150	Rm     []string
151	Update []KeyWithId
152}
153
154func authorizedKeysDiff(keyInUse ssh.PublicKey, curKeys []KeyWithId, nextKeys []KeyWithId) KeyDiffResult {
155	update := []KeyWithId{}
156	add := []KeyWithId{}
157	for _, nk := range nextKeys {
158		found := false
159		for _, ck := range curKeys {
160			if ssh.KeysEqual(nk.Pk, ck.Pk) {
161				found = true
162
163				// update the comment field
164				if nk.Comment != ck.Comment {
165					ck.Comment = nk.Comment
166					update = append(update, ck)
167				}
168				break
169			}
170		}
171		if !found {
172			add = append(add, nk)
173		}
174	}
175
176	rm := []string{}
177	for _, ck := range curKeys {
178		// we never want to remove the key that's in the current ssh session
179		// in an effort to avoid mistakenly removing their current key
180		if ssh.KeysEqual(ck.Pk, keyInUse) {
181			continue
182		}
183
184		found := false
185		for _, nk := range nextKeys {
186			if ssh.KeysEqual(ck.Pk, nk.Pk) {
187				found = true
188				break
189			}
190		}
191		if !found {
192			rm = append(rm, ck.ID)
193		}
194	}
195
196	return KeyDiffResult{
197		Add:    add,
198		Rm:     rm,
199		Update: update,
200	}
201}
202
203func (h *UploadHandler) ProcessAuthorizedKeys(text []byte, logger *slog.Logger, user *db.User, s ssh.Session) error {
204	logger.Info("processing new authorized_keys")
205	dbpool := h.DBPool
206
207	curKeysStr, err := dbpool.FindKeysForUser(user)
208	if err != nil {
209		return err
210	}
211
212	splitKeys := bytes.Split(text, []byte{'\n'})
213	nextKeys := []KeyWithId{}
214	for _, pk := range splitKeys {
215		key, comment, _, _, err := ssh.ParseAuthorizedKey(bytes.TrimSpace(pk))
216		if err != nil {
217			continue
218		}
219		nextKeys = append(nextKeys, KeyWithId{Pk: key, Comment: comment})
220	}
221
222	curKeys := []KeyWithId{}
223	for _, pk := range curKeysStr {
224		key, _, _, _, err := ssh.ParseAuthorizedKey([]byte(pk.Key))
225		if err != nil {
226			continue
227		}
228		curKeys = append(curKeys, KeyWithId{Pk: key, ID: pk.ID, Comment: pk.Name})
229	}
230
231	diff := authorizedKeysDiff(s.PublicKey(), curKeys, nextKeys)
232
233	for _, pk := range diff.Add {
234		key := utils.KeyForKeyText(pk.Pk)
235
236		wish.Errorf(s, "adding pubkey (%s)\n", key)
237		logger.Info("adding pubkey", "pubkey", key)
238
239		err = dbpool.InsertPublicKey(user.ID, key, pk.Comment, nil)
240		if err != nil {
241			wish.Errorf(s, "error: could not insert pubkey: %s (%s)\n", err.Error(), key)
242			logger.Error("could not insert pubkey", "err", err.Error())
243		}
244	}
245
246	for _, pk := range diff.Update {
247		key := utils.KeyForKeyText(pk.Pk)
248
249		wish.Errorf(s, "updating pubkey with comment: %s (%s)\n", pk.Comment, key)
250		logger.Info(
251			"updating pubkey with comment",
252			"pubkey", key,
253			"comment", pk.Comment,
254		)
255
256		_, err = dbpool.UpdatePublicKey(pk.ID, pk.Comment)
257		if err != nil {
258			wish.Errorf(s, "error: could not update pubkey: %s (%s)\n", err.Error(), key)
259			logger.Error("could not update pubkey", "err", err.Error(), "key", key)
260		}
261	}
262
263	if len(diff.Rm) > 0 {
264		wish.Errorf(s, "removing pubkeys: %s\n", diff.Rm)
265		logger.Info("removing pubkeys", "pubkeys", diff.Rm)
266
267		err = dbpool.RemoveKeys(diff.Rm)
268		if err != nil {
269			wish.Errorf(s, "error: could not rm pubkeys: %s\n", err.Error())
270			logger.Error("could not remove pubkey", "err", err.Error())
271		}
272	}
273
274	return nil
275}
276
277func (h *UploadHandler) Write(s ssh.Session, entry *sendutils.FileEntry) (string, error) {
278	logger := h.Cfg.Logger
279	user, err := shared.GetUser(s.Context())
280	if err != nil {
281		logger.Error(err.Error())
282		return "", err
283	}
284
285	filename := filepath.Base(entry.Filepath)
286	logger = logger.With(
287		"user", user.Name,
288		"filename", filename,
289	)
290
291	var text []byte
292	if b, err := io.ReadAll(entry.Reader); err == nil {
293		text = b
294	}
295
296	if filename == "authorized_keys" {
297		err := h.ProcessAuthorizedKeys(text, logger, user, s)
298		if err != nil {
299			return "", err
300		}
301	} else {
302		return "", fmt.Errorf("validation error: invalid file, received %s", entry.Filepath)
303	}
304
305	return "", nil
306}