Antonio Mika
·
08 Oct 24
post_handler.go
1package filehandlers
2
3import (
4 "encoding/binary"
5 "fmt"
6 "io"
7 "net/http"
8 "os"
9 "path/filepath"
10 "strings"
11 "time"
12
13 "github.com/charmbracelet/ssh"
14 "github.com/picosh/pico/db"
15 "github.com/picosh/pico/shared"
16 "github.com/picosh/pico/shared/storage"
17 sendutils "github.com/picosh/send/utils"
18 "github.com/picosh/utils"
19)
20
// PostMetaData is the working state for a single upload handled by
// ScpUploadHandler.Write. It embeds the post being assembled from the
// uploaded file and carries the extra context that ScpFileHooks
// implementations read and may modify before the post is persisted.
type PostMetaData struct {
	// Post is the post built from the uploaded file (filename, slug, text,
	// MIME type, size, shasum and publish time).
	*db.Post
	// Cur is the already-stored post with the same filename, or nil when the
	// upload will create a new post.
	Cur *db.Post
	// Tags are written to the post via ReplaceTagsForPost after it is stored.
	Tags []string
	// User is the account performing the upload.
	User *db.User
	// FileEntry is the file entry received from the client.
	FileEntry *sendutils.FileEntry
	// Aliases are written to the post via ReplaceAliasesForPost after it is
	// stored.
	Aliases []string
}
29
// ScpFileHooks lets a site customise how ScpUploadHandler.Write accepts and
// enriches an upload before it is stored as a post.
type ScpFileHooks interface {
	// FileValidate reports whether the upload described by data may be
	// stored. When it returns false, Write aborts the upload and returns the
	// error from this call.
	// NOTE(review): Write calls Error() on the returned error when the result
	// is false, so implementations should return a non-nil error alongside
	// false — confirm for every implementer.
	FileValidate(s ssh.Session, data *PostMetaData) (bool, error)
	// FileMeta runs after validation and after any existing post has been
	// loaded into data.Cur. Write persists fields such as Title, Description,
	// Hidden, ExpiresAt, Tags and Aliases from data after this call. A
	// non-nil error aborts the upload.
	FileMeta(s ssh.Session, data *PostMetaData) error
}
34
// ScpUploadHandler stores files uploaded over an SSH session as posts and
// exposes them through Read, Write and Delete.
type ScpUploadHandler struct {
	// DBPool is used to look up, insert, update and remove posts.
	DBPool db.DB
	// Cfg supplies the site's Space, Logger and post URL construction.
	Cfg *shared.ConfigSite
	// Hooks validates uploads and fills in post metadata before writes.
	Hooks ScpFileHooks
}
40
41func NewScpPostHandler(dbpool db.DB, cfg *shared.ConfigSite, hooks ScpFileHooks, st storage.StorageServe) *ScpUploadHandler {
42 return &ScpUploadHandler{
43 DBPool: dbpool,
44 Cfg: cfg,
45 Hooks: hooks,
46 }
47}
48
49func (h *ScpUploadHandler) Read(s ssh.Session, entry *sendutils.FileEntry) (os.FileInfo, sendutils.ReaderAtCloser, error) {
50 user, err := shared.GetUser(s.Context())
51 if err != nil {
52 return nil, nil, err
53 }
54 cleanFilename := filepath.Base(entry.Filepath)
55
56 if cleanFilename == "" || cleanFilename == "." {
57 return nil, nil, os.ErrNotExist
58 }
59
60 post, err := h.DBPool.FindPostWithFilename(cleanFilename, user.ID, h.Cfg.Space)
61 if err != nil {
62 return nil, nil, err
63 }
64
65 fileInfo := &sendutils.VirtualFile{
66 FName: post.Filename,
67 FIsDir: false,
68 FSize: int64(post.FileSize),
69 FModTime: *post.UpdatedAt,
70 }
71
72 reader := sendutils.NopReaderAtCloser(strings.NewReader(post.Text))
73
74 return fileInfo, reader, nil
75}
76
77func (h *ScpUploadHandler) Write(s ssh.Session, entry *sendutils.FileEntry) (string, error) {
78 logger := h.Cfg.Logger
79 user, err := shared.GetUser(s.Context())
80 if err != nil {
81 logger.Error("error getting user from ctx", "err", err.Error())
82 return "", err
83 }
84
85 userID := user.ID
86 filename := filepath.Base(entry.Filepath)
87 logger = shared.LoggerWithUser(logger, user)
88 logger = logger.With(
89 "filename", filename,
90 )
91
92 if entry.Mode.IsDir() {
93 return "", fmt.Errorf("file entry is directory, but only files are supported: %s", filename)
94 }
95
96 var origText []byte
97 if b, err := io.ReadAll(entry.Reader); err == nil {
98 origText = b
99 }
100
101 mimeType := http.DetectContentType(origText)
102 ext := filepath.Ext(filename)
103 // DetectContentType does not detect markdown
104 if ext == ".md" {
105 mimeType = "text/markdown; charset=UTF-8"
106 }
107
108 now := time.Now()
109 slug := utils.SanitizeFileExt(filename)
110 fileSize := binary.Size(origText)
111 shasum := utils.Shasum(origText)
112
113 nextPost := db.Post{
114 Filename: filename,
115 Slug: slug,
116 PublishAt: &now,
117 Text: string(origText),
118 MimeType: mimeType,
119 FileSize: fileSize,
120 Shasum: shasum,
121 }
122
123 metadata := PostMetaData{
124 Post: &nextPost,
125 User: user,
126 FileEntry: entry,
127 }
128
129 valid, err := h.Hooks.FileValidate(s, &metadata)
130 if !valid {
131 logger.Error("file failed validation", "err", err.Error())
132 return "", err
133 }
134
135 post, err := h.DBPool.FindPostWithFilename(metadata.Filename, metadata.User.ID, h.Cfg.Space)
136 if err != nil {
137 logger.Error("unable to load post, continuing", "err", err.Error())
138 }
139
140 if post != nil {
141 metadata.Cur = post
142 metadata.Data = post.Data
143 metadata.Post.PublishAt = post.PublishAt
144 }
145
146 err = h.Hooks.FileMeta(s, &metadata)
147 if err != nil {
148 logger.Error("file could not load meta", "err", err.Error())
149 return "", err
150 }
151
152 modTime := time.Now()
153
154 if entry.Mtime > 0 {
155 modTime = time.Unix(entry.Mtime, 0)
156 }
157
158 // if the file is empty we remove it from our database
159 if post == nil {
160 logger.Info("file not found, adding record")
161 insertPost := db.Post{
162 UserID: userID,
163 Space: h.Cfg.Space,
164
165 Data: metadata.Data,
166 Description: metadata.Description,
167 Filename: metadata.Filename,
168 FileSize: metadata.FileSize,
169 Hidden: metadata.Hidden,
170 MimeType: metadata.MimeType,
171 PublishAt: metadata.PublishAt,
172 Shasum: metadata.Shasum,
173 Slug: metadata.Slug,
174 Text: metadata.Text,
175 Title: metadata.Title,
176 ExpiresAt: metadata.ExpiresAt,
177 UpdatedAt: &modTime,
178 }
179 post, err = h.DBPool.InsertPost(&insertPost)
180 if err != nil {
181 logger.Error("post could not be created", "err", err.Error())
182 return "", fmt.Errorf("error for %s: %v", filename, err)
183 }
184
185 if len(metadata.Aliases) > 0 {
186 logger.Info(
187 "found post aliases, replacing with old aliases",
188 "aliases",
189 strings.Join(metadata.Aliases, ","),
190 )
191 err = h.DBPool.ReplaceAliasesForPost(metadata.Aliases, post.ID)
192 if err != nil {
193 logger.Error("post could not replace aliases", "err", err.Error())
194 return "", fmt.Errorf("error for %s: %v", filename, err)
195 }
196 }
197
198 if len(metadata.Tags) > 0 {
199 logger.Info(
200 "found post tags, replacing with old tags",
201 "tags", strings.Join(metadata.Tags, ","),
202 )
203 err = h.DBPool.ReplaceTagsForPost(metadata.Tags, post.ID)
204 if err != nil {
205 logger.Error("post could not replace tags", "err", err.Error())
206 return "", fmt.Errorf("error for %s: %v", filename, err)
207 }
208 }
209 } else {
210 if metadata.Text == post.Text && modTime.Equal(*post.UpdatedAt) {
211 logger.Info("file found, but text is identical, skipping")
212 curl := shared.NewCreateURL(h.Cfg)
213 return h.Cfg.FullPostURL(curl, user.Name, metadata.Slug), nil
214 }
215
216 logger.Info("file found, updating record")
217
218 updatePost := db.Post{
219 ID: post.ID,
220
221 Data: metadata.Data,
222 FileSize: metadata.FileSize,
223 Description: metadata.Description,
224 PublishAt: metadata.PublishAt,
225 Slug: metadata.Slug,
226 Shasum: metadata.Shasum,
227 Text: metadata.Text,
228 Title: metadata.Title,
229 Hidden: metadata.Hidden,
230 ExpiresAt: metadata.ExpiresAt,
231 UpdatedAt: &modTime,
232 }
233 _, err = h.DBPool.UpdatePost(&updatePost)
234 if err != nil {
235 logger.Error("post could not be updated", "err", err.Error())
236 return "", fmt.Errorf("error for %s: %v", filename, err)
237 }
238
239 logger.Info(
240 "found post tags, replacing with old tags",
241 "tags", strings.Join(metadata.Tags, ","),
242 )
243 err = h.DBPool.ReplaceTagsForPost(metadata.Tags, post.ID)
244 if err != nil {
245 logger.Error("post could not replace tags", "err", err.Error())
246 return "", fmt.Errorf("error for %s: %v", filename, err)
247 }
248
249 logger.Info(
250 "found post aliases, replacing with old aliases",
251 "aliases", strings.Join(metadata.Aliases, ","),
252 )
253 err = h.DBPool.ReplaceAliasesForPost(metadata.Aliases, post.ID)
254 if err != nil {
255 logger.Error("post could not replace aliases", "err", err.Error())
256 return "", fmt.Errorf("error for %s: %v", filename, err)
257 }
258 }
259
260 curl := shared.NewCreateURL(h.Cfg)
261 return h.Cfg.FullPostURL(curl, user.Name, metadata.Slug), nil
262}
263
264func (h *ScpUploadHandler) Delete(s ssh.Session, entry *sendutils.FileEntry) error {
265 logger := h.Cfg.Logger
266 user, err := shared.GetUser(s.Context())
267 if err != nil {
268 logger.Error("could not get user from ctx", "err", err.Error())
269 return err
270 }
271
272 userID := user.ID
273 filename := filepath.Base(entry.Filepath)
274 logger = shared.LoggerWithUser(logger, user)
275 logger = logger.With(
276 "filename", filename,
277 )
278
279 post, err := h.DBPool.FindPostWithFilename(filename, userID, h.Cfg.Space)
280 if err != nil {
281 return err
282 }
283
284 if post == nil {
285 return os.ErrNotExist
286 }
287
288 err = h.DBPool.RemovePosts([]string{post.ID})
289 logger.Info("removing record")
290 if err != nil {
291 logger.Error("post could not remove", "err", err.Error())
292 return fmt.Errorf("error for %s: %v", filename, err)
293 }
294 return nil
295}