repos / pico

pico services - prose.sh, pastes.sh, imgs.sh, feeds.sh, pgs.sh
git clone https://github.com/picosh/pico.git

pico / filehandlers
Eric Bower · 17 Jun 24

router_handler.go

  1package filehandlers
  2
  3import (
  4	"database/sql"
  5	"errors"
  6	"fmt"
  7	"log/slog"
  8	"os"
  9	"path/filepath"
 10
 11	"github.com/charmbracelet/ssh"
 12	"github.com/picosh/pico/db"
 13	"github.com/picosh/pico/filehandlers/util"
 14	"github.com/picosh/pico/shared"
 15	"github.com/picosh/send/send/utils"
 16)
 17
 18type ReadWriteHandler interface {
 19	Write(ssh.Session, *utils.FileEntry) (string, error)
 20	Read(ssh.Session, *utils.FileEntry) (os.FileInfo, utils.ReaderAtCloser, error)
 21	Delete(ssh.Session, *utils.FileEntry) error
 22}
 23
 24type FileHandlerRouter struct {
 25	FileMap map[string]ReadWriteHandler
 26	Cfg     *shared.ConfigSite
 27	DBPool  db.DB
 28	Spaces  []string
 29}
 30
 31var _ utils.CopyFromClientHandler = &FileHandlerRouter{}      // Verify implementation
 32var _ utils.CopyFromClientHandler = (*FileHandlerRouter)(nil) // Verify implementation
 33
 34func NewFileHandlerRouter(cfg *shared.ConfigSite, dbpool db.DB, mapper map[string]ReadWriteHandler) *FileHandlerRouter {
 35	return &FileHandlerRouter{
 36		Cfg:     cfg,
 37		DBPool:  dbpool,
 38		FileMap: mapper,
 39		Spaces:  []string{cfg.Space},
 40	}
 41}
 42
 43func (r *FileHandlerRouter) findHandler(entry *utils.FileEntry) (ReadWriteHandler, error) {
 44	fext := filepath.Ext(entry.Filepath)
 45	handler, ok := r.FileMap[fext]
 46	if !ok {
 47		hand, hasFallback := r.FileMap["fallback"]
 48		if !hasFallback {
 49			return nil, fmt.Errorf("no corresponding handler for file extension: %s", fext)
 50		}
 51		handler = hand
 52	}
 53	return handler, nil
 54}
 55
 56func (r *FileHandlerRouter) Write(s ssh.Session, entry *utils.FileEntry) (string, error) {
 57	if entry.Mode.IsDir() {
 58		return "", os.ErrInvalid
 59	}
 60
 61	handler, err := r.findHandler(entry)
 62	if err != nil {
 63		return "", err
 64	}
 65	return handler.Write(s, entry)
 66}
 67
 68func (r *FileHandlerRouter) Delete(s ssh.Session, entry *utils.FileEntry) error {
 69	handler, err := r.findHandler(entry)
 70	if err != nil {
 71		return err
 72	}
 73	return handler.Delete(s, entry)
 74}
 75
 76func (r *FileHandlerRouter) Read(s ssh.Session, entry *utils.FileEntry) (os.FileInfo, utils.ReaderAtCloser, error) {
 77	handler, err := r.findHandler(entry)
 78	if err != nil {
 79		return nil, nil, err
 80	}
 81	return handler.Read(s, entry)
 82}
 83
 84func BaseList(s ssh.Session, fpath string, isDir bool, recursive bool, spaces []string, dbpool db.DB) ([]os.FileInfo, error) {
 85	var fileList []os.FileInfo
 86	user, err := util.GetUser(s.Context())
 87	if err != nil {
 88		return fileList, err
 89	}
 90	cleanFilename := filepath.Base(fpath)
 91
 92	var post *db.Post
 93	var posts []*db.Post
 94
 95	if cleanFilename == "" || cleanFilename == "." || cleanFilename == "/" {
 96		name := cleanFilename
 97		if name == "" {
 98			name = "/"
 99		}
100
101		fileList = append(fileList, &utils.VirtualFile{
102			FName:  name,
103			FIsDir: true,
104		})
105
106		for _, space := range spaces {
107			curPosts, e := dbpool.FindAllPostsForUser(user.ID, space)
108			if e != nil {
109				err = e
110				break
111			}
112			posts = append(posts, curPosts...)
113		}
114	} else {
115		for _, space := range spaces {
116
117			p, e := dbpool.FindPostWithFilename(cleanFilename, user.ID, space)
118			if e != nil {
119				err = e
120				continue
121			}
122			post = p
123		}
124
125		posts = append(posts, post)
126	}
127
128	if err != nil && !errors.Is(err, sql.ErrNoRows) {
129		return nil, err
130	}
131
132	for _, post := range posts {
133		if post == nil {
134			continue
135		}
136
137		fileList = append(fileList, &utils.VirtualFile{
138			FName:    post.Filename,
139			FIsDir:   false,
140			FSize:    int64(post.FileSize),
141			FModTime: *post.UpdatedAt,
142		})
143	}
144
145	return fileList, nil
146}
147
148func (r *FileHandlerRouter) List(s ssh.Session, fpath string, isDir bool, recursive bool) ([]os.FileInfo, error) {
149	return BaseList(s, fpath, isDir, recursive, r.Spaces, r.DBPool)
150}
151
152func (r *FileHandlerRouter) GetLogger() *slog.Logger {
153	return r.Cfg.Logger
154}
155
156func (r *FileHandlerRouter) Validate(s ssh.Session) error {
157	user, err := util.GetUser(s.Context())
158	if err != nil {
159		return err
160	}
161
162	r.Cfg.Logger.Info(
163		"attempting to upload files",
164		"user", user.Name,
165		"space", r.Cfg.Space,
166	)
167	return nil
168}