repos / pico

pico services - prose.sh, pastes.sh, imgs.sh, feeds.sh, pgs.sh
git clone https://github.com/picosh/pico.git

pico / filehandlers
Eric Bower · 11 Dec 24

post_handler.go

  1package filehandlers
  2
  3import (
  4	"encoding/binary"
  5	"fmt"
  6	"io"
  7	"net/http"
  8	"os"
  9	"path/filepath"
 10	"strings"
 11	"time"
 12
 13	"github.com/charmbracelet/ssh"
 14	"github.com/picosh/pico/db"
 15	"github.com/picosh/pico/shared"
 16	"github.com/picosh/pico/shared/storage"
 17	sendutils "github.com/picosh/send/utils"
 18	"github.com/picosh/utils"
 19)
 20
 21type PostMetaData struct {
 22	*db.Post
 23	Cur       *db.Post
 24	Tags      []string
 25	User      *db.User
 26	FileEntry *sendutils.FileEntry
 27	Aliases   []string
 28}
 29
 30type ScpFileHooks interface {
 31	FileValidate(s ssh.Session, data *PostMetaData) (bool, error)
 32	FileMeta(s ssh.Session, data *PostMetaData) error
 33}
 34
 35type ScpUploadHandler struct {
 36	DBPool db.DB
 37	Cfg    *shared.ConfigSite
 38	Hooks  ScpFileHooks
 39}
 40
 41func NewScpPostHandler(dbpool db.DB, cfg *shared.ConfigSite, hooks ScpFileHooks, st storage.StorageServe) *ScpUploadHandler {
 42	return &ScpUploadHandler{
 43		DBPool: dbpool,
 44		Cfg:    cfg,
 45		Hooks:  hooks,
 46	}
 47}
 48
 49func (h *ScpUploadHandler) Read(s ssh.Session, entry *sendutils.FileEntry) (os.FileInfo, sendutils.ReaderAtCloser, error) {
 50	user, err := h.DBPool.FindUser(s.Permissions().Extensions["user_id"])
 51	if err != nil {
 52		return nil, nil, err
 53	}
 54	cleanFilename := filepath.Base(entry.Filepath)
 55
 56	if cleanFilename == "" || cleanFilename == "." {
 57		return nil, nil, os.ErrNotExist
 58	}
 59
 60	post, err := h.DBPool.FindPostWithFilename(cleanFilename, user.ID, h.Cfg.Space)
 61	if err != nil {
 62		return nil, nil, err
 63	}
 64
 65	fileInfo := &sendutils.VirtualFile{
 66		FName:    post.Filename,
 67		FIsDir:   false,
 68		FSize:    int64(post.FileSize),
 69		FModTime: *post.UpdatedAt,
 70	}
 71
 72	reader := sendutils.NopReaderAtCloser(strings.NewReader(post.Text))
 73
 74	return fileInfo, reader, nil
 75}
 76
 77func (h *ScpUploadHandler) Write(s ssh.Session, entry *sendutils.FileEntry) (string, error) {
 78	logger := h.Cfg.Logger
 79	user, err := h.DBPool.FindUser(s.Permissions().Extensions["user_id"])
 80	if err != nil {
 81		logger.Error("error getting user from ctx", "err", err.Error())
 82		return "", err
 83	}
 84
 85	userID := user.ID
 86	filename := filepath.Base(entry.Filepath)
 87	logger = shared.LoggerWithUser(logger, user)
 88	logger = logger.With(
 89		"filename", filename,
 90	)
 91
 92	if entry.Mode.IsDir() {
 93		return "", fmt.Errorf("file entry is directory, but only files are supported: %s", filename)
 94	}
 95
 96	var origText []byte
 97	if b, err := io.ReadAll(entry.Reader); err == nil {
 98		origText = b
 99	}
100
101	mimeType := http.DetectContentType(origText)
102	ext := filepath.Ext(filename)
103	// DetectContentType does not detect markdown
104	if ext == ".md" {
105		mimeType = "text/markdown; charset=UTF-8"
106	}
107
108	now := time.Now()
109	slug := utils.SanitizeFileExt(filename)
110	fileSize := binary.Size(origText)
111	shasum := utils.Shasum(origText)
112
113	nextPost := db.Post{
114		Filename:  filename,
115		Slug:      slug,
116		PublishAt: &now,
117		Text:      string(origText),
118		MimeType:  mimeType,
119		FileSize:  fileSize,
120		Shasum:    shasum,
121	}
122
123	metadata := PostMetaData{
124		Post:      &nextPost,
125		User:      user,
126		FileEntry: entry,
127	}
128
129	valid, err := h.Hooks.FileValidate(s, &metadata)
130	if !valid {
131		logger.Error("file failed validation", "err", err.Error())
132		return "", err
133	}
134
135	post, err := h.DBPool.FindPostWithFilename(metadata.Filename, metadata.User.ID, h.Cfg.Space)
136	if err != nil {
137		logger.Error("unable to load post, continuing", "err", err.Error())
138	}
139
140	if post != nil {
141		metadata.Cur = post
142		metadata.Data = post.Data
143		metadata.Post.PublishAt = post.PublishAt
144	}
145
146	err = h.Hooks.FileMeta(s, &metadata)
147	if err != nil {
148		logger.Error("file could not load meta", "err", err.Error())
149		return "", err
150	}
151
152	modTime := time.Now()
153
154	if entry.Mtime > 0 {
155		modTime = time.Unix(entry.Mtime, 0)
156	}
157
158	// if the file is empty we remove it from our database
159	if post == nil {
160		logger.Info("file not found, adding record")
161		insertPost := db.Post{
162			UserID: userID,
163			Space:  h.Cfg.Space,
164
165			Data:        metadata.Data,
166			Description: metadata.Description,
167			Filename:    metadata.Filename,
168			FileSize:    metadata.FileSize,
169			Hidden:      metadata.Hidden,
170			MimeType:    metadata.MimeType,
171			PublishAt:   metadata.PublishAt,
172			Shasum:      metadata.Shasum,
173			Slug:        metadata.Slug,
174			Text:        metadata.Text,
175			Title:       metadata.Title,
176			ExpiresAt:   metadata.ExpiresAt,
177			UpdatedAt:   &modTime,
178		}
179		post, err = h.DBPool.InsertPost(&insertPost)
180		if err != nil {
181			logger.Error("post could not be created", "err", err.Error())
182			return "", fmt.Errorf("error for %s: %v", filename, err)
183		}
184
185		if len(metadata.Aliases) > 0 {
186			logger.Info(
187				"found post aliases, replacing with old aliases",
188				"aliases",
189				strings.Join(metadata.Aliases, ","),
190			)
191			err = h.DBPool.ReplaceAliasesForPost(metadata.Aliases, post.ID)
192			if err != nil {
193				logger.Error("post could not replace aliases", "err", err.Error())
194				return "", fmt.Errorf("error for %s: %v", filename, err)
195			}
196		}
197
198		if len(metadata.Tags) > 0 {
199			logger.Info(
200				"found post tags, replacing with old tags",
201				"tags", strings.Join(metadata.Tags, ","),
202			)
203			err = h.DBPool.ReplaceTagsForPost(metadata.Tags, post.ID)
204			if err != nil {
205				logger.Error("post could not replace tags", "err", err.Error())
206				return "", fmt.Errorf("error for %s: %v", filename, err)
207			}
208		}
209	} else {
210		if metadata.Text == post.Text && modTime.Equal(*post.UpdatedAt) {
211			logger.Info("file found, but text is identical, skipping")
212			curl := shared.NewCreateURL(h.Cfg)
213			return h.Cfg.FullPostURL(curl, user.Name, metadata.Slug), nil
214		}
215
216		logger.Info("file found, updating record")
217
218		updatePost := db.Post{
219			ID: post.ID,
220
221			Data:        metadata.Data,
222			FileSize:    metadata.FileSize,
223			Description: metadata.Description,
224			PublishAt:   metadata.PublishAt,
225			Slug:        metadata.Slug,
226			Shasum:      metadata.Shasum,
227			Text:        metadata.Text,
228			Title:       metadata.Title,
229			Hidden:      metadata.Hidden,
230			ExpiresAt:   metadata.ExpiresAt,
231			UpdatedAt:   &modTime,
232		}
233		_, err = h.DBPool.UpdatePost(&updatePost)
234		if err != nil {
235			logger.Error("post could not be updated", "err", err.Error())
236			return "", fmt.Errorf("error for %s: %v", filename, err)
237		}
238
239		logger.Info(
240			"found post tags, replacing with old tags",
241			"tags", strings.Join(metadata.Tags, ","),
242		)
243		err = h.DBPool.ReplaceTagsForPost(metadata.Tags, post.ID)
244		if err != nil {
245			logger.Error("post could not replace tags", "err", err.Error())
246			return "", fmt.Errorf("error for %s: %v", filename, err)
247		}
248
249		logger.Info(
250			"found post aliases, replacing with old aliases",
251			"aliases", strings.Join(metadata.Aliases, ","),
252		)
253		err = h.DBPool.ReplaceAliasesForPost(metadata.Aliases, post.ID)
254		if err != nil {
255			logger.Error("post could not replace aliases", "err", err.Error())
256			return "", fmt.Errorf("error for %s: %v", filename, err)
257		}
258	}
259
260	curl := shared.NewCreateURL(h.Cfg)
261	return h.Cfg.FullPostURL(curl, user.Name, metadata.Slug), nil
262}
263
264func (h *ScpUploadHandler) Delete(s ssh.Session, entry *sendutils.FileEntry) error {
265	logger := h.Cfg.Logger
266	user, err := h.DBPool.FindUser(s.Permissions().Extensions["user_id"])
267	if err != nil {
268		logger.Error("could not get user from ctx", "err", err.Error())
269		return err
270	}
271
272	userID := user.ID
273	filename := filepath.Base(entry.Filepath)
274	logger = shared.LoggerWithUser(logger, user)
275	logger = logger.With(
276		"filename", filename,
277	)
278
279	post, err := h.DBPool.FindPostWithFilename(filename, userID, h.Cfg.Space)
280	if err != nil {
281		return err
282	}
283
284	if post == nil {
285		return os.ErrNotExist
286	}
287
288	err = h.DBPool.RemovePosts([]string{post.ID})
289	logger.Info("removing record")
290	if err != nil {
291		logger.Error("post could not remove", "err", err.Error())
292		return fmt.Errorf("error for %s: %v", filename, err)
293	}
294	return nil
295}