Antonio Mika
·
08 Oct 24
file_handler.go
1package pico
2
3import (
4 "bytes"
5 "errors"
6 "fmt"
7 "io"
8 "log/slog"
9 "os"
10 "path/filepath"
11 "strings"
12 "time"
13
14 "github.com/charmbracelet/ssh"
15 "github.com/charmbracelet/wish"
16 "github.com/picosh/pico/db"
17 "github.com/picosh/pico/shared"
18 sendutils "github.com/picosh/send/utils"
19 "github.com/picosh/utils"
20)
21
22type UploadHandler struct {
23 DBPool db.DB
24 Cfg *shared.ConfigSite
25}
26
27func NewUploadHandler(dbpool db.DB, cfg *shared.ConfigSite) *UploadHandler {
28 return &UploadHandler{
29 DBPool: dbpool,
30 Cfg: cfg,
31 }
32}
33
34func (h *UploadHandler) getAuthorizedKeyFile(user *db.User) (*sendutils.VirtualFile, string, error) {
35 keys, err := h.DBPool.FindKeysForUser(user)
36 text := ""
37 var modTime time.Time
38 for _, pk := range keys {
39 text += fmt.Sprintf("%s %s\n", pk.Key, pk.Name)
40 modTime = *pk.CreatedAt
41 }
42 if err != nil {
43 return nil, "", err
44 }
45 fileInfo := &sendutils.VirtualFile{
46 FName: "authorized_keys",
47 FIsDir: false,
48 FSize: int64(len(text)),
49 FModTime: modTime,
50 }
51 return fileInfo, text, nil
52}
53
54func (h *UploadHandler) Delete(s ssh.Session, entry *sendutils.FileEntry) error {
55 return errors.New("unsupported")
56}
57
58func (h *UploadHandler) Read(s ssh.Session, entry *sendutils.FileEntry) (os.FileInfo, sendutils.ReaderAtCloser, error) {
59 user, err := shared.GetUser(s.Context())
60 if err != nil {
61 return nil, nil, err
62 }
63 cleanFilename := filepath.Base(entry.Filepath)
64
65 if cleanFilename == "" || cleanFilename == "." {
66 return nil, nil, os.ErrNotExist
67 }
68
69 if cleanFilename == "authorized_keys" {
70 fileInfo, text, err := h.getAuthorizedKeyFile(user)
71 if err != nil {
72 return nil, nil, err
73 }
74 reader := sendutils.NopReaderAtCloser(strings.NewReader(text))
75 return fileInfo, reader, nil
76 }
77
78 return nil, nil, os.ErrNotExist
79}
80
81func (h *UploadHandler) List(s ssh.Session, fpath string, isDir bool, recursive bool) ([]os.FileInfo, error) {
82 var fileList []os.FileInfo
83 user, err := shared.GetUser(s.Context())
84 if err != nil {
85 return fileList, err
86 }
87 cleanFilename := filepath.Base(fpath)
88
89 if cleanFilename == "" || cleanFilename == "." || cleanFilename == "/" {
90 name := cleanFilename
91 if name == "" {
92 name = "/"
93 }
94
95 fileList = append(fileList, &sendutils.VirtualFile{
96 FName: name,
97 FIsDir: true,
98 })
99
100 flist, _, err := h.getAuthorizedKeyFile(user)
101 if err != nil {
102 return fileList, err
103 }
104 fileList = append(fileList, flist)
105 } else {
106 if cleanFilename == "authorized_keys" {
107 flist, _, err := h.getAuthorizedKeyFile(user)
108 if err != nil {
109 return fileList, err
110 }
111 fileList = append(fileList, flist)
112 }
113 }
114
115 return fileList, nil
116}
117
118func (h *UploadHandler) GetLogger() *slog.Logger {
119 return h.Cfg.Logger
120}
121
122func (h *UploadHandler) Validate(s ssh.Session) error {
123 var err error
124 key, err := sendutils.KeyText(s)
125 if err != nil {
126 return fmt.Errorf("key not found")
127 }
128
129 user, err := h.DBPool.FindUserForKey(s.User(), key)
130 if err != nil {
131 return err
132 }
133
134 if user.Name == "" {
135 return fmt.Errorf("must have username set")
136 }
137
138 shared.SetUser(s.Context(), user)
139 return nil
140}
141
142type KeyWithId struct {
143 Pk ssh.PublicKey
144 ID string
145 Comment string
146}
147
148type KeyDiffResult struct {
149 Add []KeyWithId
150 Rm []string
151 Update []KeyWithId
152}
153
154func authorizedKeysDiff(keyInUse ssh.PublicKey, curKeys []KeyWithId, nextKeys []KeyWithId) KeyDiffResult {
155 update := []KeyWithId{}
156 add := []KeyWithId{}
157 for _, nk := range nextKeys {
158 found := false
159 for _, ck := range curKeys {
160 if ssh.KeysEqual(nk.Pk, ck.Pk) {
161 found = true
162
163 // update the comment field
164 if nk.Comment != ck.Comment {
165 ck.Comment = nk.Comment
166 update = append(update, ck)
167 }
168 break
169 }
170 }
171 if !found {
172 add = append(add, nk)
173 }
174 }
175
176 rm := []string{}
177 for _, ck := range curKeys {
178 // we never want to remove the key that's in the current ssh session
179 // in an effort to avoid mistakenly removing their current key
180 if ssh.KeysEqual(ck.Pk, keyInUse) {
181 continue
182 }
183
184 found := false
185 for _, nk := range nextKeys {
186 if ssh.KeysEqual(ck.Pk, nk.Pk) {
187 found = true
188 break
189 }
190 }
191 if !found {
192 rm = append(rm, ck.ID)
193 }
194 }
195
196 return KeyDiffResult{
197 Add: add,
198 Rm: rm,
199 Update: update,
200 }
201}
202
203func (h *UploadHandler) ProcessAuthorizedKeys(text []byte, logger *slog.Logger, user *db.User, s ssh.Session) error {
204 logger.Info("processing new authorized_keys")
205 dbpool := h.DBPool
206
207 curKeysStr, err := dbpool.FindKeysForUser(user)
208 if err != nil {
209 return err
210 }
211
212 splitKeys := bytes.Split(text, []byte{'\n'})
213 nextKeys := []KeyWithId{}
214 for _, pk := range splitKeys {
215 key, comment, _, _, err := ssh.ParseAuthorizedKey(bytes.TrimSpace(pk))
216 if err != nil {
217 continue
218 }
219 nextKeys = append(nextKeys, KeyWithId{Pk: key, Comment: comment})
220 }
221
222 curKeys := []KeyWithId{}
223 for _, pk := range curKeysStr {
224 key, _, _, _, err := ssh.ParseAuthorizedKey([]byte(pk.Key))
225 if err != nil {
226 continue
227 }
228 curKeys = append(curKeys, KeyWithId{Pk: key, ID: pk.ID, Comment: pk.Name})
229 }
230
231 diff := authorizedKeysDiff(s.PublicKey(), curKeys, nextKeys)
232
233 for _, pk := range diff.Add {
234 key := utils.KeyForKeyText(pk.Pk)
235
236 wish.Errorf(s, "adding pubkey (%s)\n", key)
237 logger.Info("adding pubkey", "pubkey", key)
238
239 err = dbpool.InsertPublicKey(user.ID, key, pk.Comment, nil)
240 if err != nil {
241 wish.Errorf(s, "error: could not insert pubkey: %s (%s)\n", err.Error(), key)
242 logger.Error("could not insert pubkey", "err", err.Error())
243 }
244 }
245
246 for _, pk := range diff.Update {
247 key := utils.KeyForKeyText(pk.Pk)
248
249 wish.Errorf(s, "updating pubkey with comment: %s (%s)\n", pk.Comment, key)
250 logger.Info(
251 "updating pubkey with comment",
252 "pubkey", key,
253 "comment", pk.Comment,
254 )
255
256 _, err = dbpool.UpdatePublicKey(pk.ID, pk.Comment)
257 if err != nil {
258 wish.Errorf(s, "error: could not update pubkey: %s (%s)\n", err.Error(), key)
259 logger.Error("could not update pubkey", "err", err.Error(), "key", key)
260 }
261 }
262
263 if len(diff.Rm) > 0 {
264 wish.Errorf(s, "removing pubkeys: %s\n", diff.Rm)
265 logger.Info("removing pubkeys", "pubkeys", diff.Rm)
266
267 err = dbpool.RemoveKeys(diff.Rm)
268 if err != nil {
269 wish.Errorf(s, "error: could not rm pubkeys: %s\n", err.Error())
270 logger.Error("could not remove pubkey", "err", err.Error())
271 }
272 }
273
274 return nil
275}
276
277func (h *UploadHandler) Write(s ssh.Session, entry *sendutils.FileEntry) (string, error) {
278 logger := h.Cfg.Logger
279 user, err := shared.GetUser(s.Context())
280 if err != nil {
281 logger.Error(err.Error())
282 return "", err
283 }
284
285 filename := filepath.Base(entry.Filepath)
286 logger = logger.With(
287 "user", user.Name,
288 "filename", filename,
289 )
290
291 var text []byte
292 if b, err := io.ReadAll(entry.Reader); err == nil {
293 text = b
294 }
295
296 if filename == "authorized_keys" {
297 err := h.ProcessAuthorizedKeys(text, logger, user, s)
298 if err != nil {
299 return "", err
300 }
301 } else {
302 return "", fmt.Errorf("validation error: invalid file, received %s", entry.Filepath)
303 }
304
305 return "", nil
306}