Eric Bower · 10 Dec 24
uploader.go
package pgs

import (
	"bytes"
	"context"
	"fmt"
	"io"
	"io/fs"
	"log/slog"
	"os"
	"path"
	"path/filepath"
	"slices"
	"strings"
	"sync"
	"time"

	"github.com/charmbracelet/ssh"
	"github.com/charmbracelet/wish"
	"github.com/picosh/pico/db"
	"github.com/picosh/pico/shared"
	"github.com/picosh/pobj"
	sst "github.com/picosh/pobj/storage"
	sendutils "github.com/picosh/send/utils"
	"github.com/picosh/utils"
	ignore "github.com/sabhiram/go-gitignore"
)

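// Context keys for values cached on the ssh.Session for the lifetime of a
// connection. Empty struct types keep the keys collision-free.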
type ctxBucketKey struct{}
type ctxStorageSizeKey struct{}
type ctxProjectKey struct{}
type ctxDenylistKey struct{}

type DenyList struct {
	Denylist string
}

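// getDenylist returns the denylist cached on the session, or nil if it has
// not been loaded yet for this connection.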
func getDenylist(s ssh.Session) *DenyList {
	v := s.Context().Value(ctxDenylistKey{})
	if v == nil {
		return nil
	}
	return v.(*DenyList)
}

func setDenylist(s ssh.Session, denylist string) {
	s.Context().SetValue(ctxDenylistKey{}, &DenyList{Denylist: denylist})
}

func getProject(s ssh.Session) *db.Project {
	v := s.Context().Value(ctxProjectKey{})
	if v == nil {
		return nil
	}
	return v.(*db.Project)
}

func setProject(s ssh.Session, project *db.Project) {
	s.Context().SetValue(ctxProjectKey{}, project)
}

// getBucket returns the bucket cached by Validate; the comma-ok assertion
// avoids a panic when the value was never set on the session.
func getBucket(s ssh.Session) (sst.Bucket, error) {
	bucket, ok := s.Context().Value(ctxBucketKey{}).(sst.Bucket)
	if !ok || bucket.Name == "" {
		return bucket, fmt.Errorf("bucket not set on `ssh.Context()` for connection")
	}
	return bucket, nil
}

func getStorageSize(s ssh.Session) uint64 {
	return s.Context().Value(ctxStorageSizeKey{}).(uint64)
}

// incrementStorageSize adjusts the running storage total for this session
// by a signed byte delta and returns the new total.
func incrementStorageSize(s ssh.Session, delta int64) uint64 {
	curSize := getStorageSize(s)
	var nextStorageSize uint64
	if delta < 0 {
		nextStorageSize = curSize - uint64(-delta)
	} else {
		nextStorageSize = curSize + uint64(delta)
	}
	s.Context().SetValue(ctxStorageSizeKey{}, nextStorageSize)
	return nextStorageSize
}

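// shouldIgnoreFile reports whether fp matches any gitignore-style pattern
// in ignoreStr. For example (illustrative), with a _pgs_ignore of "*.log",
// shouldIgnoreFile("/debug.log", "*.log") returns true.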
func shouldIgnoreFile(fp, ignoreStr string) bool {
	object := ignore.CompileIgnoreLines(strings.Split(ignoreStr, "\n")...)
	return object.MatchesPath(fp)
}

type FileData struct {
	*sendutils.FileEntry
	User     *db.User
	Bucket   sst.Bucket
	Project  *db.Project
	DenyList string
}

type UploadAssetHandler struct {
	DBPool             db.DB
	Cfg                *shared.ConfigSite
	Storage            sst.ObjectStorage
	CacheClearingQueue chan string
}

func NewUploadAssetHandler(dbpool db.DB, cfg *shared.ConfigSite, storage sst.ObjectStorage, ctx context.Context) *UploadAssetHandler {
	// Enable buffering so we don't slow down uploads.
	ch := make(chan string, 100)
	go runCacheQueue(cfg, ctx, ch)
	return &UploadAssetHandler{
		DBPool:             dbpool,
		Cfg:                cfg,
		Storage:            storage,
		CacheClearingQueue: ch,
	}
}

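// A sketch of constructing the handler (variable names here are assumed;
// the real wiring lives in pico's SSH server setup):
//
//	handler := NewUploadAssetHandler(dbpool, cfg, st, ctx)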
func (h *UploadAssetHandler) GetLogger() *slog.Logger {
	return h.Cfg.Logger
}

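// Read fetches a single object from the user's asset bucket so scp/sftp
// clients can download it.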
func (h *UploadAssetHandler) Read(s ssh.Session, entry *sendutils.FileEntry) (os.FileInfo, sendutils.ReaderAtCloser, error) {
	user, err := h.DBPool.FindUser(s.Permissions().Extensions["user_id"])
	if err != nil {
		return nil, nil, err
	}

	fileInfo := &sendutils.VirtualFile{
		FName:    filepath.Base(entry.Filepath),
		FIsDir:   false,
		FSize:    entry.Size,
		FModTime: time.Unix(entry.Mtime, 0),
	}

	bucket, err := h.Storage.GetBucket(shared.GetAssetBucketName(user.ID))
	if err != nil {
		return nil, nil, err
	}

	fname := shared.GetAssetFileName(entry)
	contents, info, err := h.Storage.GetObject(bucket, fname)
	if err != nil {
		return nil, nil, err
	}

	fileInfo.FSize = info.Size
	fileInfo.FModTime = info.LastModified

	reader := pobj.NewAllReaderAt(contents)

	return fileInfo, reader, nil
}

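// List maps the object keys under fpath onto a virtual directory listing
// for sftp clients.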
func (h *UploadAssetHandler) List(s ssh.Session, fpath string, isDir bool, recursive bool) ([]os.FileInfo, error) {
	var fileList []os.FileInfo

	user, err := h.DBPool.FindUser(s.Permissions().Extensions["user_id"])
	if err != nil {
		return fileList, err
	}

	cleanFilename := fpath

	bucketName := shared.GetAssetBucketName(user.ID)
	bucket, err := h.Storage.GetBucket(bucketName)
	if err != nil {
		return fileList, err
	}

	if cleanFilename == "" || cleanFilename == "." {
		name := cleanFilename
		if name == "" {
			name = "/"
		}

		info := &sendutils.VirtualFile{
			FName:  name,
			FIsDir: true,
		}

		fileList = append(fileList, info)
	} else {
		if cleanFilename != "/" && isDir {
			cleanFilename += "/"
		}

		foundList, err := h.Storage.ListObjects(bucket, cleanFilename, recursive)
		if err != nil {
			return fileList, err
		}

		fileList = append(fileList, foundList...)
	}

	return fileList, nil
}

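// Validate runs once at the start of an upload session: it upserts the
// user's asset bucket and caches the bucket and its current size on the
// session context for quota accounting in Write.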
func (h *UploadAssetHandler) Validate(s ssh.Session) error {
	user, err := h.DBPool.FindUser(s.Permissions().Extensions["user_id"])
	if err != nil {
		return err
	}

	assetBucket := shared.GetAssetBucketName(user.ID)
	bucket, err := h.Storage.UpsertBucket(assetBucket)
	if err != nil {
		return err
	}
	s.Context().SetValue(ctxBucketKey{}, bucket)

	totalStorageSize, err := h.Storage.GetBucketQuota(bucket)
	if err != nil {
		return err
	}
	s.Context().SetValue(ctxStorageSizeKey{}, totalStorageSize)
	h.Cfg.Logger.Info(
		"bucket size",
		"user", user.Name,
		"bytes", totalStorageSize,
	)

	h.Cfg.Logger.Info(
		"attempting to upload files",
		"user", user.Name,
		"space", h.Cfg.Space,
	)

	return nil
}

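// findDenylist loads the project's _pgs_ignore file from object storage
// and returns its raw, newline-separated patterns.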
func (h *UploadAssetHandler) findDenylist(bucket sst.Bucket, project *db.Project, logger *slog.Logger) (string, error) {
	fp, _, err := h.Storage.GetObject(bucket, filepath.Join(project.ProjectDir, "_pgs_ignore"))
	if err != nil {
		return "", fmt.Errorf("_pgs_ignore not found")
	}
	defer fp.Close()

	buf := new(strings.Builder)
	_, err = io.Copy(buf, fp)
	if err != nil {
		logger.Error("io copy", "err", err.Error())
		return "", err
	}

	return buf.String(), nil
}

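// Write stores one uploaded file, creating or updating the project record
// on first use and enforcing the denylist plus the plan's storage quotas.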
func (h *UploadAssetHandler) Write(s ssh.Session, entry *sendutils.FileEntry) (string, error) {
	user, err := h.DBPool.FindUser(s.Permissions().Extensions["user_id"])
	if err != nil {
		h.Cfg.Logger.Error("user not found in ctx", "err", err.Error())
		return "", err
	}
	if user == nil {
		return "", fmt.Errorf("user not found in ctx")
	}

	if entry.Mode.IsDir() && strings.Count(entry.Filepath, "/") == 1 {
		entry.Filepath = strings.TrimPrefix(entry.Filepath, "/")
	}

	logger := h.GetLogger()
	logger = shared.LoggerWithUser(logger, user)
	logger = logger.With(
		"file", entry.Filepath,
		"size", entry.Size,
	)

	bucket, err := getBucket(s)
	if err != nil {
		logger.Error("could not find bucket in ctx", "err", err.Error())
		return "", err
	}

	project := getProject(s)
	projectName := shared.GetProjectName(entry)
	logger = logger.With("project", projectName)

	// find, create, or update the project if we haven't already done so
	if project == nil {
		project, err = h.DBPool.FindProjectByName(user.ID, projectName)
		if err == nil {
			err = h.DBPool.UpdateProject(user.ID, projectName)
			if err != nil {
				logger.Error("could not update project", "err", err.Error())
				return "", err
			}
		} else {
			_, err = h.DBPool.InsertProject(user.ID, projectName, projectName)
			if err != nil {
				logger.Error("could not create project", "err", err.Error())
				return "", err
			}
			project, err = h.DBPool.FindProjectByName(user.ID, projectName)
			if err != nil {
				logger.Error("could not find project", "err", err.Error())
				return "", err
			}
		}
		setProject(s, project)
	}

	if project.Blocked != "" {
		msg := "project has been blocked and cannot upload files: %s"
		return "", fmt.Errorf(msg, project.Blocked)
	}

	if entry.Mode.IsDir() {
		_, _, err := h.Storage.PutObject(
			bucket,
			path.Join(shared.GetAssetFileName(entry), "._pico_keep_dir"),
			bytes.NewReader([]byte{}),
			entry,
		)
		return "", err
	}

	featureFlag := shared.FindPlusFF(h.DBPool, h.Cfg, user.ID)
	// calculate the file-size difference between the same file already
	// stored and the updated file being uploaded
	assetFilename := shared.GetAssetFileName(entry)
	_, info, _ := h.Storage.GetObject(bucket, assetFilename)
	var curFileSize int64
	if info != nil {
		curFileSize = info.Size
	}

	denylist := getDenylist(s)
	if denylist == nil {
		dlist, err := h.findDenylist(bucket, project, logger)
		if err != nil {
			logger.Info("failed to get denylist, setting default (.*)", "err", err.Error())
			dlist = ".*"
		}
		setDenylist(s, dlist)
		denylist = &DenyList{Denylist: dlist}
	}

	data := &FileData{
		FileEntry: entry,
		User:      user,
		Bucket:    bucket,
		DenyList:  denylist.Denylist,
		Project:   project,
	}

	valid, err := h.validateAsset(data)
	if !valid {
		return "", err
	}

	// SFTP does not report file size up front, so the more performant way
	// to enforce size constraints is to upload through a specialized reader
	// that errors out once the limit is reached.
	storageMax := featureFlag.Data.StorageMax
	fileMax := featureFlag.Data.FileMax
	curStorageSize := getStorageSize(s)
	remaining := int64(storageMax) - int64(curStorageSize)
	sizeRemaining := min(remaining+curFileSize, fileMax)
	if sizeRemaining <= 0 {
		wish.Fatalln(s, "storage quota reached")
		return "", fmt.Errorf("storage quota reached")
	}
	logger = logger.With(
		"storageMax", storageMax,
		"curStorageSize", curStorageSize,
		"fileMax", fileMax,
		"sizeRemaining", sizeRemaining,
	)

	specialFileMax := featureFlag.Data.SpecialFileMax
	if isSpecialFile(entry) {
		sizeRemaining = min(sizeRemaining, specialFileMax)
	}

	fsize, err := h.writeAsset(
		utils.NewMaxBytesReader(data.Reader, int64(sizeRemaining)),
		data,
	)
	if err != nil {
		logger.Error("could not write asset", "err", err.Error())
		cerr := fmt.Errorf(
			"%w: storage size %.2fMB, storage max %.2fMB, file max %.2fMB, special file max %.4fMB",
			err,
			utils.BytesToMB(int(curStorageSize)),
			utils.BytesToMB(int(storageMax)),
			utils.BytesToMB(int(fileMax)),
			utils.BytesToMB(int(specialFileMax)),
		)
		return "", cerr
	}

	// positive when the new file is larger than what was previously
	// stored, negative when it is smaller
	deltaFileSize := fsize - curFileSize
	nextStorageSize := incrementStorageSize(s, deltaFileSize)

	url := h.Cfg.AssetURL(
		user.Name,
		projectName,
		strings.Replace(data.Filepath, "/"+projectName+"/", "", 1),
	)

	maxSize := int(featureFlag.Data.StorageMax)
	str := fmt.Sprintf(
		"%s (space: %.2f/%.2fGB, %.2f%%)",
		url,
		utils.BytesToGB(int(nextStorageSize)),
		utils.BytesToGB(maxSize),
		(float32(nextStorageSize)/float32(maxSize))*100,
	)

	surrogate := getSurrogateKey(user.Name, projectName)
	h.CacheClearingQueue <- surrogate

	return str, nil
}

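// isSpecialFile reports whether the upload is one of the pgs routing files
// (_headers or _redirects), which get their own smaller size cap.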
func isSpecialFile(entry *sendutils.FileEntry) bool {
	fname := filepath.Base(entry.Filepath)
	return fname == "_headers" || fname == "_redirects"
}

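// Delete removes an object and, when it was the last file in its directory,
// writes a ._pico_keep_dir placeholder so the empty directory survives.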
func (h *UploadAssetHandler) Delete(s ssh.Session, entry *sendutils.FileEntry) error {
	user, err := h.DBPool.FindUser(s.Permissions().Extensions["user_id"])
	if err != nil {
		h.Cfg.Logger.Error("user not found in ctx", "err", err.Error())
		return err
	}

	if entry.Mode.IsDir() && strings.Count(entry.Filepath, "/") == 1 {
		entry.Filepath = strings.TrimPrefix(entry.Filepath, "/")
	}

	assetFilepath := shared.GetAssetFileName(entry)

	logger := h.GetLogger()
	logger = shared.LoggerWithUser(logger, user)
	logger = logger.With(
		"file", assetFilepath,
	)

	bucket, err := getBucket(s)
	if err != nil {
		logger.Error("could not find bucket in ctx", "err", err.Error())
		return err
	}

	projectName := shared.GetProjectName(entry)
	logger = logger.With("project", projectName)

	if assetFilepath == filepath.Join("/", projectName, "._pico_keep_dir") {
		return os.ErrPermission
	}

	logger.Info("deleting file")

	pathDir := filepath.Dir(assetFilepath)
	fileName := filepath.Base(assetFilepath)

	sibs, err := h.Storage.ListObjects(bucket, pathDir+"/", false)
	if err != nil {
		return err
	}

	sibs = slices.DeleteFunc(sibs, func(sib fs.FileInfo) bool {
		return sib.Name() == fileName
	})

	if len(sibs) == 0 {
		_, _, err := h.Storage.PutObject(
			bucket,
			filepath.Join(pathDir, "._pico_keep_dir"),
			bytes.NewReader([]byte{}),
			entry,
		)
		if err != nil {
			return err
		}
	}

	err = h.Storage.DeleteObject(bucket, assetFilepath)

	surrogate := getSurrogateKey(user.Name, projectName)
	h.CacheClearingQueue <- surrogate

	return err
}

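// validateAsset rejects files uploaded outside a project folder and anything
// matched by the project's denylist; the routing files are always accepted.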
func (h *UploadAssetHandler) validateAsset(data *FileData) (bool, error) {
	fname := filepath.Base(data.Filepath)

	projectName := shared.GetProjectName(data.FileEntry)
	if projectName == "" || projectName == "/" || projectName == "." {
		return false, fmt.Errorf("ERROR: invalid project name, you must copy files to a non-root folder (e.g. pgs.sh:/project-name)")
	}

	// special files we use for custom routing
	if fname == "_pgs_ignore" || fname == "_redirects" || fname == "_headers" {
		return true, nil
	}

	fpath := strings.Replace(data.Filepath, "/"+projectName, "", 1)
	if shouldIgnoreFile(fpath, data.DenyList) {
		err := fmt.Errorf(
			"ERROR: (%s) file rejected, https://pico.sh/pgs#file-denylist",
			data.Filepath,
		)
		return false, err
	}

	return true, nil
}

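// writeAsset streams the size-capped reader into object storage and returns
// the number of bytes written.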
func (h *UploadAssetHandler) writeAsset(reader io.Reader, data *FileData) (int64, error) {
	assetFilepath := shared.GetAssetFileName(data.FileEntry)

	logger := shared.LoggerWithUser(h.Cfg.Logger, data.User)
	logger.Info(
		"uploading file to bucket",
		"bucket", data.Bucket.Name,
		"filename", assetFilepath,
	)

	_, fsize, err := h.Storage.PutObject(
		data.Bucket,
		assetFilepath,
		reader,
		data.FileEntry,
	)
	return fsize, err
}

// runCacheQueue processes requests to purge the cache for a single site.
// One message arrives per file that is written or deleted during uploads.
// Repeated messages for the same site are grouped so that we only flush
// once per site per 5-second window.
func runCacheQueue(cfg *shared.ConfigSite, ctx context.Context, ch chan string) {
	send := createPubCacheDrain(ctx, cfg.Logger)
	var pendingFlushes sync.Map
	tick := time.NewTicker(5 * time.Second)
	defer tick.Stop()
	for {
		select {
		case <-ctx.Done():
			return
		case host := <-ch:
			pendingFlushes.Store(host, host)
		case <-tick.C:
			go func() {
				pendingFlushes.Range(func(key, value any) bool {
					pendingFlushes.Delete(key)
					err := purgeCache(cfg, send, key.(string))
					if err != nil {
						cfg.Logger.Error("failed to clear cache", "err", err.Error())
					}
					return true
				})
			}()
		}
	}
}