Eric Bower
·
17 Jun 24
router_handler.go
1package filehandlers
2
3import (
4 "database/sql"
5 "errors"
6 "fmt"
7 "log/slog"
8 "os"
9 "path/filepath"
10
11 "github.com/charmbracelet/ssh"
12 "github.com/picosh/pico/db"
13 "github.com/picosh/pico/filehandlers/util"
14 "github.com/picosh/pico/shared"
15 "github.com/picosh/send/send/utils"
16)
17
18type ReadWriteHandler interface {
19 Write(ssh.Session, *utils.FileEntry) (string, error)
20 Read(ssh.Session, *utils.FileEntry) (os.FileInfo, utils.ReaderAtCloser, error)
21 Delete(ssh.Session, *utils.FileEntry) error
22}
23
24type FileHandlerRouter struct {
25 FileMap map[string]ReadWriteHandler
26 Cfg *shared.ConfigSite
27 DBPool db.DB
28 Spaces []string
29}
30
31var _ utils.CopyFromClientHandler = &FileHandlerRouter{} // Verify implementation
32var _ utils.CopyFromClientHandler = (*FileHandlerRouter)(nil) // Verify implementation
33
34func NewFileHandlerRouter(cfg *shared.ConfigSite, dbpool db.DB, mapper map[string]ReadWriteHandler) *FileHandlerRouter {
35 return &FileHandlerRouter{
36 Cfg: cfg,
37 DBPool: dbpool,
38 FileMap: mapper,
39 Spaces: []string{cfg.Space},
40 }
41}
42
43func (r *FileHandlerRouter) findHandler(entry *utils.FileEntry) (ReadWriteHandler, error) {
44 fext := filepath.Ext(entry.Filepath)
45 handler, ok := r.FileMap[fext]
46 if !ok {
47 hand, hasFallback := r.FileMap["fallback"]
48 if !hasFallback {
49 return nil, fmt.Errorf("no corresponding handler for file extension: %s", fext)
50 }
51 handler = hand
52 }
53 return handler, nil
54}
55
56func (r *FileHandlerRouter) Write(s ssh.Session, entry *utils.FileEntry) (string, error) {
57 if entry.Mode.IsDir() {
58 return "", os.ErrInvalid
59 }
60
61 handler, err := r.findHandler(entry)
62 if err != nil {
63 return "", err
64 }
65 return handler.Write(s, entry)
66}
67
68func (r *FileHandlerRouter) Delete(s ssh.Session, entry *utils.FileEntry) error {
69 handler, err := r.findHandler(entry)
70 if err != nil {
71 return err
72 }
73 return handler.Delete(s, entry)
74}
75
76func (r *FileHandlerRouter) Read(s ssh.Session, entry *utils.FileEntry) (os.FileInfo, utils.ReaderAtCloser, error) {
77 handler, err := r.findHandler(entry)
78 if err != nil {
79 return nil, nil, err
80 }
81 return handler.Read(s, entry)
82}
83
84func BaseList(s ssh.Session, fpath string, isDir bool, recursive bool, spaces []string, dbpool db.DB) ([]os.FileInfo, error) {
85 var fileList []os.FileInfo
86 user, err := util.GetUser(s.Context())
87 if err != nil {
88 return fileList, err
89 }
90 cleanFilename := filepath.Base(fpath)
91
92 var post *db.Post
93 var posts []*db.Post
94
95 if cleanFilename == "" || cleanFilename == "." || cleanFilename == "/" {
96 name := cleanFilename
97 if name == "" {
98 name = "/"
99 }
100
101 fileList = append(fileList, &utils.VirtualFile{
102 FName: name,
103 FIsDir: true,
104 })
105
106 for _, space := range spaces {
107 curPosts, e := dbpool.FindAllPostsForUser(user.ID, space)
108 if e != nil {
109 err = e
110 break
111 }
112 posts = append(posts, curPosts...)
113 }
114 } else {
115 for _, space := range spaces {
116
117 p, e := dbpool.FindPostWithFilename(cleanFilename, user.ID, space)
118 if e != nil {
119 err = e
120 continue
121 }
122 post = p
123 }
124
125 posts = append(posts, post)
126 }
127
128 if err != nil && !errors.Is(err, sql.ErrNoRows) {
129 return nil, err
130 }
131
132 for _, post := range posts {
133 if post == nil {
134 continue
135 }
136
137 fileList = append(fileList, &utils.VirtualFile{
138 FName: post.Filename,
139 FIsDir: false,
140 FSize: int64(post.FileSize),
141 FModTime: *post.UpdatedAt,
142 })
143 }
144
145 return fileList, nil
146}
147
148func (r *FileHandlerRouter) List(s ssh.Session, fpath string, isDir bool, recursive bool) ([]os.FileInfo, error) {
149 return BaseList(s, fpath, isDir, recursive, r.Spaces, r.DBPool)
150}
151
152func (r *FileHandlerRouter) GetLogger() *slog.Logger {
153 return r.Cfg.Logger
154}
155
156func (r *FileHandlerRouter) Validate(s ssh.Session) error {
157 user, err := util.GetUser(s.Context())
158 if err != nil {
159 return err
160 }
161
162 r.Cfg.Logger.Info(
163 "attempting to upload files",
164 "user", user.Name,
165 "space", r.Cfg.Space,
166 )
167 return nil
168}