repos / pico

pico services - prose.sh, pastes.sh, imgs.sh, feeds.sh, pgs.sh
git clone https://github.com/picosh/pico.git

pico / filehandlers
Antonio Mika · 08 Oct 24

router_handler.go

  1package filehandlers
  2
  3import (
  4	"database/sql"
  5	"errors"
  6	"fmt"
  7	"log/slog"
  8	"os"
  9	"path/filepath"
 10
 11	"github.com/charmbracelet/ssh"
 12	"github.com/picosh/pico/db"
 13	"github.com/picosh/pico/shared"
 14	"github.com/picosh/send/utils"
 15)
 16
 17type ReadWriteHandler interface {
 18	Write(ssh.Session, *utils.FileEntry) (string, error)
 19	Read(ssh.Session, *utils.FileEntry) (os.FileInfo, utils.ReaderAtCloser, error)
 20	Delete(ssh.Session, *utils.FileEntry) error
 21}
 22
 23type FileHandlerRouter struct {
 24	FileMap map[string]ReadWriteHandler
 25	Cfg     *shared.ConfigSite
 26	DBPool  db.DB
 27	Spaces  []string
 28}
 29
 30var _ utils.CopyFromClientHandler = &FileHandlerRouter{}      // Verify implementation
 31var _ utils.CopyFromClientHandler = (*FileHandlerRouter)(nil) // Verify implementation
 32
 33func NewFileHandlerRouter(cfg *shared.ConfigSite, dbpool db.DB, mapper map[string]ReadWriteHandler) *FileHandlerRouter {
 34	return &FileHandlerRouter{
 35		Cfg:     cfg,
 36		DBPool:  dbpool,
 37		FileMap: mapper,
 38		Spaces:  []string{cfg.Space},
 39	}
 40}
 41
 42func (r *FileHandlerRouter) findHandler(entry *utils.FileEntry) (ReadWriteHandler, error) {
 43	fext := filepath.Ext(entry.Filepath)
 44	handler, ok := r.FileMap[fext]
 45	if !ok {
 46		hand, hasFallback := r.FileMap["fallback"]
 47		if !hasFallback {
 48			return nil, fmt.Errorf("no corresponding handler for file extension: %s", fext)
 49		}
 50		handler = hand
 51	}
 52	return handler, nil
 53}
 54
 55func (r *FileHandlerRouter) Write(s ssh.Session, entry *utils.FileEntry) (string, error) {
 56	if entry.Mode.IsDir() {
 57		return "", os.ErrInvalid
 58	}
 59
 60	handler, err := r.findHandler(entry)
 61	if err != nil {
 62		return "", err
 63	}
 64	return handler.Write(s, entry)
 65}
 66
 67func (r *FileHandlerRouter) Delete(s ssh.Session, entry *utils.FileEntry) error {
 68	handler, err := r.findHandler(entry)
 69	if err != nil {
 70		return err
 71	}
 72	return handler.Delete(s, entry)
 73}
 74
 75func (r *FileHandlerRouter) Read(s ssh.Session, entry *utils.FileEntry) (os.FileInfo, utils.ReaderAtCloser, error) {
 76	handler, err := r.findHandler(entry)
 77	if err != nil {
 78		return nil, nil, err
 79	}
 80	return handler.Read(s, entry)
 81}
 82
 83func BaseList(s ssh.Session, fpath string, isDir bool, recursive bool, spaces []string, dbpool db.DB) ([]os.FileInfo, error) {
 84	var fileList []os.FileInfo
 85	user, err := shared.GetUser(s.Context())
 86	if err != nil {
 87		return fileList, err
 88	}
 89	cleanFilename := filepath.Base(fpath)
 90
 91	var post *db.Post
 92	var posts []*db.Post
 93
 94	if cleanFilename == "" || cleanFilename == "." || cleanFilename == "/" {
 95		name := cleanFilename
 96		if name == "" {
 97			name = "/"
 98		}
 99
100		fileList = append(fileList, &utils.VirtualFile{
101			FName:  name,
102			FIsDir: true,
103		})
104
105		for _, space := range spaces {
106			curPosts, e := dbpool.FindAllPostsForUser(user.ID, space)
107			if e != nil {
108				err = e
109				break
110			}
111			posts = append(posts, curPosts...)
112		}
113	} else {
114		for _, space := range spaces {
115
116			p, e := dbpool.FindPostWithFilename(cleanFilename, user.ID, space)
117			if e != nil {
118				err = e
119				continue
120			}
121			post = p
122		}
123
124		posts = append(posts, post)
125	}
126
127	if err != nil && !errors.Is(err, sql.ErrNoRows) {
128		return nil, err
129	}
130
131	for _, post := range posts {
132		if post == nil {
133			continue
134		}
135
136		fileList = append(fileList, &utils.VirtualFile{
137			FName:    post.Filename,
138			FIsDir:   false,
139			FSize:    int64(post.FileSize),
140			FModTime: *post.UpdatedAt,
141		})
142	}
143
144	return fileList, nil
145}
146
147func (r *FileHandlerRouter) List(s ssh.Session, fpath string, isDir bool, recursive bool) ([]os.FileInfo, error) {
148	return BaseList(s, fpath, isDir, recursive, r.Spaces, r.DBPool)
149}
150
151func (r *FileHandlerRouter) GetLogger() *slog.Logger {
152	return r.Cfg.Logger
153}
154
155func (r *FileHandlerRouter) Validate(s ssh.Session) error {
156	user, err := shared.GetUser(s.Context())
157	if err != nil {
158		return err
159	}
160
161	r.Cfg.Logger.Info(
162		"attempting to upload files",
163		"user", user.Name,
164		"space", r.Cfg.Space,
165	)
166	return nil
167}