Antonio Mika
·
08 Oct 24
router_handler.go
1package filehandlers
2
3import (
4 "database/sql"
5 "errors"
6 "fmt"
7 "log/slog"
8 "os"
9 "path/filepath"
10
11 "github.com/charmbracelet/ssh"
12 "github.com/picosh/pico/db"
13 "github.com/picosh/pico/shared"
14 "github.com/picosh/send/utils"
15)
16
17type ReadWriteHandler interface {
18 Write(ssh.Session, *utils.FileEntry) (string, error)
19 Read(ssh.Session, *utils.FileEntry) (os.FileInfo, utils.ReaderAtCloser, error)
20 Delete(ssh.Session, *utils.FileEntry) error
21}
22
23type FileHandlerRouter struct {
24 FileMap map[string]ReadWriteHandler
25 Cfg *shared.ConfigSite
26 DBPool db.DB
27 Spaces []string
28}
29
30var _ utils.CopyFromClientHandler = &FileHandlerRouter{} // Verify implementation
31var _ utils.CopyFromClientHandler = (*FileHandlerRouter)(nil) // Verify implementation
32
33func NewFileHandlerRouter(cfg *shared.ConfigSite, dbpool db.DB, mapper map[string]ReadWriteHandler) *FileHandlerRouter {
34 return &FileHandlerRouter{
35 Cfg: cfg,
36 DBPool: dbpool,
37 FileMap: mapper,
38 Spaces: []string{cfg.Space},
39 }
40}
41
42func (r *FileHandlerRouter) findHandler(entry *utils.FileEntry) (ReadWriteHandler, error) {
43 fext := filepath.Ext(entry.Filepath)
44 handler, ok := r.FileMap[fext]
45 if !ok {
46 hand, hasFallback := r.FileMap["fallback"]
47 if !hasFallback {
48 return nil, fmt.Errorf("no corresponding handler for file extension: %s", fext)
49 }
50 handler = hand
51 }
52 return handler, nil
53}
54
55func (r *FileHandlerRouter) Write(s ssh.Session, entry *utils.FileEntry) (string, error) {
56 if entry.Mode.IsDir() {
57 return "", os.ErrInvalid
58 }
59
60 handler, err := r.findHandler(entry)
61 if err != nil {
62 return "", err
63 }
64 return handler.Write(s, entry)
65}
66
67func (r *FileHandlerRouter) Delete(s ssh.Session, entry *utils.FileEntry) error {
68 handler, err := r.findHandler(entry)
69 if err != nil {
70 return err
71 }
72 return handler.Delete(s, entry)
73}
74
75func (r *FileHandlerRouter) Read(s ssh.Session, entry *utils.FileEntry) (os.FileInfo, utils.ReaderAtCloser, error) {
76 handler, err := r.findHandler(entry)
77 if err != nil {
78 return nil, nil, err
79 }
80 return handler.Read(s, entry)
81}
82
83func BaseList(s ssh.Session, fpath string, isDir bool, recursive bool, spaces []string, dbpool db.DB) ([]os.FileInfo, error) {
84 var fileList []os.FileInfo
85 user, err := shared.GetUser(s.Context())
86 if err != nil {
87 return fileList, err
88 }
89 cleanFilename := filepath.Base(fpath)
90
91 var post *db.Post
92 var posts []*db.Post
93
94 if cleanFilename == "" || cleanFilename == "." || cleanFilename == "/" {
95 name := cleanFilename
96 if name == "" {
97 name = "/"
98 }
99
100 fileList = append(fileList, &utils.VirtualFile{
101 FName: name,
102 FIsDir: true,
103 })
104
105 for _, space := range spaces {
106 curPosts, e := dbpool.FindAllPostsForUser(user.ID, space)
107 if e != nil {
108 err = e
109 break
110 }
111 posts = append(posts, curPosts...)
112 }
113 } else {
114 for _, space := range spaces {
115
116 p, e := dbpool.FindPostWithFilename(cleanFilename, user.ID, space)
117 if e != nil {
118 err = e
119 continue
120 }
121 post = p
122 }
123
124 posts = append(posts, post)
125 }
126
127 if err != nil && !errors.Is(err, sql.ErrNoRows) {
128 return nil, err
129 }
130
131 for _, post := range posts {
132 if post == nil {
133 continue
134 }
135
136 fileList = append(fileList, &utils.VirtualFile{
137 FName: post.Filename,
138 FIsDir: false,
139 FSize: int64(post.FileSize),
140 FModTime: *post.UpdatedAt,
141 })
142 }
143
144 return fileList, nil
145}
146
147func (r *FileHandlerRouter) List(s ssh.Session, fpath string, isDir bool, recursive bool) ([]os.FileInfo, error) {
148 return BaseList(s, fpath, isDir, recursive, r.Spaces, r.DBPool)
149}
150
151func (r *FileHandlerRouter) GetLogger() *slog.Logger {
152 return r.Cfg.Logger
153}
154
155func (r *FileHandlerRouter) Validate(s ssh.Session) error {
156 user, err := shared.GetUser(s.Context())
157 if err != nil {
158 return err
159 }
160
161 r.Cfg.Logger.Info(
162 "attempting to upload files",
163 "user", user.Name,
164 "space", r.Cfg.Space,
165 )
166 return nil
167}