Eric Bower
·
09 Nov 24
uploader.go
1package pgs
2
3import (
4 "bytes"
5 "fmt"
6 "io"
7 "io/fs"
8 "log/slog"
9 "os"
10 "path"
11 "path/filepath"
12 "slices"
13 "strings"
14 "time"
15
16 "github.com/charmbracelet/ssh"
17 "github.com/charmbracelet/wish"
18 "github.com/picosh/pico/db"
19 "github.com/picosh/pico/shared"
20 "github.com/picosh/pobj"
21 sst "github.com/picosh/pobj/storage"
22 sendutils "github.com/picosh/send/utils"
23 "github.com/picosh/utils"
24 ignore "github.com/sabhiram/go-gitignore"
25)
26
// Unexported, zero-size key types for values kept on the ssh session context.
// Each kind of value (bucket, running storage size, project, denylist) is
// stored and looked up under its own distinct key type.
type (
	ctxBucketKey      struct{}
	ctxStorageSizeKey struct{}
	ctxProjectKey     struct{}
	ctxDenylistKey    struct{}
)
31
// DenyList wraps the raw denylist text so it can be stored on, and read back
// from, the ssh session context as a single pointer value.
type DenyList struct {
	// Denylist holds the ignore patterns as one string; it is split on
	// newlines when a file path is checked against it.
	Denylist string
}
35
36func getDenylist(s ssh.Session) *DenyList {
37 v := s.Context().Value(ctxDenylistKey{})
38 if v == nil {
39 return nil
40 }
41 denylist := s.Context().Value(ctxDenylistKey{}).(*DenyList)
42 return denylist
43}
44
45func setDenylist(s ssh.Session, denylist string) {
46 s.Context().SetValue(ctxDenylistKey{}, &DenyList{Denylist: denylist})
47}
48
49func getProject(s ssh.Session) *db.Project {
50 v := s.Context().Value(ctxProjectKey{})
51 if v == nil {
52 return nil
53 }
54 project := s.Context().Value(ctxProjectKey{}).(*db.Project)
55 return project
56}
57
58func setProject(s ssh.Session, project *db.Project) {
59 s.Context().SetValue(ctxProjectKey{}, project)
60}
61
62func getBucket(s ssh.Session) (sst.Bucket, error) {
63 bucket := s.Context().Value(ctxBucketKey{}).(sst.Bucket)
64 if bucket.Name == "" {
65 return bucket, fmt.Errorf("bucket not set on `ssh.Context()` for connection")
66 }
67 return bucket, nil
68}
69
70func getStorageSize(s ssh.Session) uint64 {
71 return s.Context().Value(ctxStorageSizeKey{}).(uint64)
72}
73
74func incrementStorageSize(s ssh.Session, fileSize int64) uint64 {
75 curSize := getStorageSize(s)
76 var nextStorageSize uint64
77 if fileSize < 0 {
78 nextStorageSize = curSize - uint64(fileSize)
79 } else {
80 nextStorageSize = curSize + uint64(fileSize)
81 }
82 s.Context().SetValue(ctxStorageSizeKey{}, nextStorageSize)
83 return nextStorageSize
84}
85
86func shouldIgnoreFile(fp, ignoreStr string) bool {
87 object := ignore.CompileIgnoreLines(strings.Split(ignoreStr, "\n")...)
88 return object.MatchesPath(fp)
89}
90
91type FileData struct {
92 *sendutils.FileEntry
93 User *db.User
94 Bucket sst.Bucket
95 Project *db.Project
96 DenyList string
97}
98
99type UploadAssetHandler struct {
100 DBPool db.DB
101 Cfg *shared.ConfigSite
102 Storage sst.ObjectStorage
103}
104
105func NewUploadAssetHandler(dbpool db.DB, cfg *shared.ConfigSite, storage sst.ObjectStorage) *UploadAssetHandler {
106 return &UploadAssetHandler{
107 DBPool: dbpool,
108 Cfg: cfg,
109 Storage: storage,
110 }
111}
112
113func (h *UploadAssetHandler) GetLogger() *slog.Logger {
114 return h.Cfg.Logger
115}
116
117func (h *UploadAssetHandler) Read(s ssh.Session, entry *sendutils.FileEntry) (os.FileInfo, sendutils.ReaderAtCloser, error) {
118 user, err := shared.GetUser(s.Context())
119 if err != nil {
120 return nil, nil, err
121 }
122
123 fileInfo := &sendutils.VirtualFile{
124 FName: filepath.Base(entry.Filepath),
125 FIsDir: false,
126 FSize: entry.Size,
127 FModTime: time.Unix(entry.Mtime, 0),
128 }
129
130 bucket, err := h.Storage.GetBucket(shared.GetAssetBucketName(user.ID))
131 if err != nil {
132 return nil, nil, err
133 }
134
135 fname := shared.GetAssetFileName(entry)
136 contents, info, err := h.Storage.GetObject(bucket, fname)
137 if err != nil {
138 return nil, nil, err
139 }
140
141 fileInfo.FSize = info.Size
142 fileInfo.FModTime = info.LastModified
143
144 reader := pobj.NewAllReaderAt(contents)
145
146 return fileInfo, reader, nil
147}
148
149func (h *UploadAssetHandler) List(s ssh.Session, fpath string, isDir bool, recursive bool) ([]os.FileInfo, error) {
150 var fileList []os.FileInfo
151
152 user, err := shared.GetUser(s.Context())
153 if err != nil {
154 return fileList, err
155 }
156
157 cleanFilename := fpath
158
159 bucketName := shared.GetAssetBucketName(user.ID)
160 bucket, err := h.Storage.GetBucket(bucketName)
161 if err != nil {
162 return fileList, err
163 }
164
165 if cleanFilename == "" || cleanFilename == "." {
166 name := cleanFilename
167 if name == "" {
168 name = "/"
169 }
170
171 info := &sendutils.VirtualFile{
172 FName: name,
173 FIsDir: true,
174 }
175
176 fileList = append(fileList, info)
177 } else {
178 if cleanFilename != "/" && isDir {
179 cleanFilename += "/"
180 }
181
182 foundList, err := h.Storage.ListObjects(bucket, cleanFilename, recursive)
183 if err != nil {
184 return fileList, err
185 }
186
187 fileList = append(fileList, foundList...)
188 }
189
190 return fileList, nil
191}
192
193func (h *UploadAssetHandler) Validate(s ssh.Session) error {
194 user, err := shared.GetUser(s.Context())
195 if err != nil {
196 return err
197 }
198
199 assetBucket := shared.GetAssetBucketName(user.ID)
200 bucket, err := h.Storage.UpsertBucket(assetBucket)
201 if err != nil {
202 return err
203 }
204 s.Context().SetValue(ctxBucketKey{}, bucket)
205
206 totalStorageSize, err := h.Storage.GetBucketQuota(bucket)
207 if err != nil {
208 return err
209 }
210 s.Context().SetValue(ctxStorageSizeKey{}, totalStorageSize)
211 h.Cfg.Logger.Info(
212 "bucket size",
213 "user", user.Name,
214 "bytes", totalStorageSize,
215 )
216
217 h.Cfg.Logger.Info(
218 "attempting to upload files",
219 "user", user.Name,
220 "space", h.Cfg.Space,
221 )
222
223 return nil
224}
225
226func (h *UploadAssetHandler) findDenylist(bucket sst.Bucket, project *db.Project, logger *slog.Logger) (string, error) {
227 fp, _, err := h.Storage.GetObject(bucket, filepath.Join(project.ProjectDir, "_pgs_ignore"))
228 if err != nil {
229 return "", fmt.Errorf("_pgs_ignore not found")
230 }
231
232 defer fp.Close()
233 buf := new(strings.Builder)
234 _, err = io.Copy(buf, fp)
235 if err != nil {
236 logger.Error("io copy", "err", err.Error())
237 return "", err
238 }
239
240 str := buf.String()
241 return str, nil
242}
243
244func (h *UploadAssetHandler) Write(s ssh.Session, entry *sendutils.FileEntry) (string, error) {
245 user, err := shared.GetUser(s.Context())
246 if user == nil || err != nil {
247 h.Cfg.Logger.Error("user not found in ctx", "err", err.Error())
248 return "", err
249 }
250
251 if entry.Mode.IsDir() && strings.Count(entry.Filepath, "/") == 1 {
252 entry.Filepath = strings.TrimPrefix(entry.Filepath, "/")
253 }
254
255 logger := h.GetLogger()
256 logger = shared.LoggerWithUser(logger, user)
257 logger = logger.With(
258 "file", entry.Filepath,
259 "size", entry.Size,
260 )
261
262 bucket, err := getBucket(s)
263 if err != nil {
264 logger.Error("could not find bucket in ctx", "err", err.Error())
265 return "", err
266 }
267
268 project := getProject(s)
269 projectName := shared.GetProjectName(entry)
270 logger = logger.With("project", projectName)
271
272 // find, create, or update project if we haven't already done it
273 if project == nil {
274 project, err = h.DBPool.FindProjectByName(user.ID, projectName)
275 if err == nil {
276 err = h.DBPool.UpdateProject(user.ID, projectName)
277 if err != nil {
278 logger.Error("could not update project", "err", err.Error())
279 return "", err
280 }
281 } else {
282 _, err = h.DBPool.InsertProject(user.ID, projectName, projectName)
283 if err != nil {
284 logger.Error("could not create project", "err", err.Error())
285 return "", err
286 }
287 project, err = h.DBPool.FindProjectByName(user.ID, projectName)
288 if err != nil {
289 logger.Error("could not find project", "err", err.Error())
290 return "", err
291 }
292 }
293 setProject(s, project)
294 }
295
296 if project.Blocked != "" {
297 msg := "project has been blocked and cannot upload files: %s"
298 return "", fmt.Errorf(msg, project.Blocked)
299 }
300
301 if entry.Mode.IsDir() {
302 _, _, err := h.Storage.PutObject(
303 bucket,
304 path.Join(shared.GetAssetFileName(entry), "._pico_keep_dir"),
305 bytes.NewReader([]byte{}),
306 entry,
307 )
308 return "", err
309 }
310
311 featureFlag, err := shared.GetFeatureFlag(s.Context())
312 if err != nil {
313 return "", err
314 }
315
316 // calculate the filsize difference between the same file already
317 // stored and the updated file being uploaded
318 assetFilename := shared.GetAssetFileName(entry)
319 _, info, _ := h.Storage.GetObject(bucket, assetFilename)
320 var curFileSize int64
321 if info != nil {
322 curFileSize = info.Size
323 }
324
325 denylist := getDenylist(s)
326 if denylist == nil {
327 dlist, err := h.findDenylist(bucket, project, logger)
328 if err != nil {
329 logger.Info("failed to get denylist, setting default (.*)", "err", err.Error())
330 dlist = ".*"
331 }
332 setDenylist(s, dlist)
333 denylist = &DenyList{Denylist: dlist}
334 }
335
336 data := &FileData{
337 FileEntry: entry,
338 User: user,
339 Bucket: bucket,
340 DenyList: denylist.Denylist,
341 Project: project,
342 }
343
344 valid, err := h.validateAsset(data)
345 if !valid {
346 return "", err
347 }
348
349 // SFTP does not report file size so the more performant way to
350 // check filesize constraints is to try and upload the file to s3
351 // with a specialized reader that raises an error if the filesize limit
352 // has been reached
353 storageMax := featureFlag.Data.StorageMax
354 fileMax := featureFlag.Data.FileMax
355 curStorageSize := getStorageSize(s)
356 remaining := int64(storageMax) - int64(curStorageSize)
357 sizeRemaining := min(remaining+curFileSize, fileMax)
358 if sizeRemaining <= 0 {
359 wish.Fatalln(s, "storage quota reached")
360 return "", fmt.Errorf("storage quota reached")
361 }
362 logger = logger.With(
363 "storageMax", storageMax,
364 "currentStorageMax", curStorageSize,
365 "fileMax", fileMax,
366 "sizeRemaining", sizeRemaining,
367 )
368
369 specialFileMax := featureFlag.Data.SpecialFileMax
370 if isSpecialFile(entry) {
371 sizeRemaining = min(sizeRemaining, specialFileMax)
372 }
373
374 fsize, err := h.writeAsset(
375 utils.NewMaxBytesReader(data.Reader, int64(sizeRemaining)),
376 data,
377 )
378 if err != nil {
379 logger.Error("could not write asset", "err", err.Error())
380 cerr := fmt.Errorf(
381 "%s: storage size %.2fmb, storage max %.2fmb, file max %.2fmb, special file max %.4fmb",
382 err,
383 utils.BytesToMB(int(curStorageSize)),
384 utils.BytesToMB(int(storageMax)),
385 utils.BytesToMB(int(fileMax)),
386 utils.BytesToMB(int(specialFileMax)),
387 )
388 return "", cerr
389 }
390
391 deltaFileSize := curFileSize - fsize
392 nextStorageSize := incrementStorageSize(s, deltaFileSize)
393
394 url := h.Cfg.AssetURL(
395 user.Name,
396 projectName,
397 strings.Replace(data.Filepath, "/"+projectName+"/", "", 1),
398 )
399
400 maxSize := int(featureFlag.Data.StorageMax)
401 str := fmt.Sprintf(
402 "%s (space: %.2f/%.2fGB, %.2f%%)",
403 url,
404 utils.BytesToGB(int(nextStorageSize)),
405 utils.BytesToGB(maxSize),
406 (float32(nextStorageSize)/float32(maxSize))*100,
407 )
408
409 return str, nil
410}
411
412func isSpecialFile(entry *sendutils.FileEntry) bool {
413 fname := filepath.Base(entry.Filepath)
414 return fname == "_headers" || fname == "_redirects"
415}
416
417func (h *UploadAssetHandler) Delete(s ssh.Session, entry *sendutils.FileEntry) error {
418 user, err := shared.GetUser(s.Context())
419 if err != nil {
420 h.Cfg.Logger.Error("user not found in ctx", "err", err.Error())
421 return err
422 }
423
424 if entry.Mode.IsDir() && strings.Count(entry.Filepath, "/") == 1 {
425 entry.Filepath = strings.TrimPrefix(entry.Filepath, "/")
426 }
427
428 assetFilepath := shared.GetAssetFileName(entry)
429
430 logger := h.GetLogger()
431 logger = shared.LoggerWithUser(logger, user)
432 logger = logger.With(
433 "file", assetFilepath,
434 )
435
436 bucket, err := getBucket(s)
437 if err != nil {
438 logger.Error("could not find bucket in ctx", "err", err.Error())
439 return err
440 }
441
442 projectName := shared.GetProjectName(entry)
443 logger = logger.With("project", projectName)
444
445 if assetFilepath == filepath.Join("/", projectName, "._pico_keep_dir") {
446 return os.ErrPermission
447 }
448
449 logger.Info("deleting file")
450
451 pathDir := filepath.Dir(assetFilepath)
452 fileName := filepath.Base(assetFilepath)
453
454 sibs, err := h.Storage.ListObjects(bucket, pathDir+"/", false)
455 if err != nil {
456 return err
457 }
458
459 sibs = slices.DeleteFunc(sibs, func(sib fs.FileInfo) bool {
460 return sib.Name() == fileName
461 })
462
463 if len(sibs) == 0 {
464 _, _, err := h.Storage.PutObject(
465 bucket,
466 filepath.Join(pathDir, "._pico_keep_dir"),
467 bytes.NewReader([]byte{}),
468 entry,
469 )
470 if err != nil {
471 return err
472 }
473 }
474
475 return h.Storage.DeleteObject(bucket, assetFilepath)
476}
477
478func (h *UploadAssetHandler) validateAsset(data *FileData) (bool, error) {
479 fname := filepath.Base(data.Filepath)
480
481 projectName := shared.GetProjectName(data.FileEntry)
482 if projectName == "" || projectName == "/" || projectName == "." {
483 return false, fmt.Errorf("ERROR: invalid project name, you must copy files to a non-root folder (e.g. pgs.sh:/project-name)")
484 }
485
486 // special files we use for custom routing
487 if fname == "_pgs_ignore" || fname == "_redirects" || fname == "_headers" {
488 return true, nil
489 }
490
491 fpath := strings.Replace(data.Filepath, "/"+projectName, "", 1)
492 if shouldIgnoreFile(fpath, data.DenyList) {
493 err := fmt.Errorf(
494 "ERROR: (%s) file rejected, https://pico.sh/pgs#file-denylist",
495 data.Filepath,
496 )
497 return false, err
498 }
499
500 return true, nil
501}
502
503func (h *UploadAssetHandler) writeAsset(reader io.Reader, data *FileData) (int64, error) {
504 assetFilepath := shared.GetAssetFileName(data.FileEntry)
505
506 logger := shared.LoggerWithUser(h.Cfg.Logger, data.User)
507 logger.Info(
508 "uploading file to bucket",
509 "bucket", data.Bucket.Name,
510 "filename", assetFilepath,
511 )
512
513 _, fsize, err := h.Storage.PutObject(
514 data.Bucket,
515 assetFilepath,
516 reader,
517 data.FileEntry,
518 )
519 return fsize, err
520}