repos / pico

pico services - prose.sh, pastes.sh, imgs.sh, feeds.sh, pgs.sh
git clone https://github.com/picosh/pico.git

commit
412982b
parent
bbb4d7c
author
Eric Bower
date
2022-10-10 03:30:08 +0000 UTC
bug(upload): removed user field from shared file uploader struct

When a user attempts to upload files to our server, the
logic we use to perform file uploads is baked into a few similar
structs (`ScpUploadHandler` and `UploadImgHandler`).  In those
structs we authenticate their `ssh.Session` by using their name
to find a user record in our db.

We mistakenly had a user field on those shared structs which would
sometimes result in a race condition where user A would upload files to
user B's blog.

We have removed the user field and instead leverage
`ssh.Session.Context` to set and access the user object for our file
uploaders.
3 files changed,  +75, -24
M filehandlers/imgs/handler.go
+34, -11
  1@@ -26,6 +26,16 @@ var GB = MB * 1024
  2 var maxSize = 1 * GB
  3 var maxImgSize = 10 * MB
  4 
  5+type ctxUserKey struct{}
  6+
  7+func getUser(s ssh.Session) (*db.User, error) {
  8+	user := s.Context().Value(ctxUserKey{}).(*db.User)
  9+	if user == nil {
 10+		return user, fmt.Errorf("user not set on `ssh.Context()` for connection")
 11+	}
 12+	return user, nil
 13+}
 14+
 15 type PostMetaData struct {
 16 	*db.Post
 17 	OrigText  []byte
 18@@ -36,7 +46,6 @@ type PostMetaData struct {
 19 }
 20 
 21 type UploadImgHandler struct {
 22-	User    *db.User
 23 	DBPool  db.DB
 24 	Cfg     *shared.ConfigSite
 25 	Storage storage.ObjectStorage
 26@@ -68,13 +77,18 @@ func (h *UploadImgHandler) removePost(data *PostMetaData) error {
 27 }
 28 
 29 func (h *UploadImgHandler) Read(s ssh.Session, filename string) (os.FileInfo, io.ReaderAt, error) {
 30+	user, err := getUser(s)
 31+	if err != nil {
 32+		return nil, nil, err
 33+	}
 34+
 35 	cleanFilename := strings.ReplaceAll(filename, "/", "")
 36 
 37 	if cleanFilename == "" || cleanFilename == "." {
 38 		return nil, nil, os.ErrNotExist
 39 	}
 40 
 41-	post, err := h.DBPool.FindPostWithFilename(cleanFilename, h.User.ID, h.Cfg.Space)
 42+	post, err := h.DBPool.FindPostWithFilename(cleanFilename, user.ID, h.Cfg.Space)
 43 	if err != nil {
 44 		return nil, nil, err
 45 	}
 46@@ -86,7 +100,7 @@ func (h *UploadImgHandler) Read(s ssh.Session, filename string) (os.FileInfo, io
 47 		FModTime: *post.UpdatedAt,
 48 	}
 49 
 50-	bucket, err := h.Storage.GetBucket(h.User.ID)
 51+	bucket, err := h.Storage.GetBucket(user.ID)
 52 	if err != nil {
 53 		return nil, nil, err
 54 	}
 55@@ -101,9 +115,12 @@ func (h *UploadImgHandler) Read(s ssh.Session, filename string) (os.FileInfo, io
 56 
 57 func (h *UploadImgHandler) List(s ssh.Session, filename string) ([]os.FileInfo, error) {
 58 	var fileList []os.FileInfo
 59+	user, err := getUser(s)
 60+	if err != nil {
 61+		return fileList, err
 62+	}
 63 	cleanFilename := strings.ReplaceAll(filename, "/", "")
 64 
 65-	var err error
 66 	var post *db.Post
 67 	var posts []*db.Post
 68 
 69@@ -118,9 +135,9 @@ func (h *UploadImgHandler) List(s ssh.Session, filename string) ([]os.FileInfo,
 70 			FIsDir: true,
 71 		})
 72 
 73-		posts, err = h.DBPool.FindAllPostsForUser(h.User.ID, h.Cfg.Space)
 74+		posts, err = h.DBPool.FindAllPostsForUser(user.ID, h.Cfg.Space)
 75 	} else {
 76-		post, err = h.DBPool.FindPostWithFilename(cleanFilename, h.User.ID, h.Cfg.Space)
 77+		post, err = h.DBPool.FindPostWithFilename(cleanFilename, user.ID, h.Cfg.Space)
 78 
 79 		posts = append(posts, post)
 80 	}
 81@@ -157,11 +174,17 @@ func (h *UploadImgHandler) Validate(s ssh.Session) error {
 82 		return fmt.Errorf("must have username set")
 83 	}
 84 
 85-	h.User = user
 86+	s.Context().SetValue(ctxUserKey{}, user)
 87+	h.Cfg.Logger.Infof("(%s) attempting to upload files to (%s)", user.Name, h.Cfg.Space)
 88 	return nil
 89 }
 90 
 91 func (h *UploadImgHandler) Write(s ssh.Session, entry *utils.FileEntry) (string, error) {
 92+	user, err := getUser(s)
 93+	if err != nil {
 94+		return "", err
 95+	}
 96+
 97 	filename := entry.Name
 98 
 99 	var text []byte
100@@ -207,7 +230,7 @@ func (h *UploadImgHandler) Write(s ssh.Session, entry *utils.FileEntry) (string,
101 
102 	post, err := h.DBPool.FindPostWithFilename(
103 		nextPost.Filename,
104-		h.User.ID,
105+		user.ID,
106 		h.Cfg.Space,
107 	)
108 	if err != nil {
109@@ -218,7 +241,7 @@ func (h *UploadImgHandler) Write(s ssh.Session, entry *utils.FileEntry) (string,
110 	metadata := PostMetaData{
111 		OrigText:  text,
112 		Post:      &nextPost,
113-		User:      h.User,
114+		User:      user,
115 		FileEntry: entry,
116 		Cur:       post,
117 	}
118@@ -227,7 +250,7 @@ func (h *UploadImgHandler) Write(s ssh.Session, entry *utils.FileEntry) (string,
119 		metadata.Post.PublishAt = post.PublishAt
120 	}
121 
122-	err = h.writeImg(&metadata)
123+	err = h.writeImg(s, &metadata)
124 	if err != nil {
125 		return "", err
126 	}
127@@ -235,7 +258,7 @@ func (h *UploadImgHandler) Write(s ssh.Session, entry *utils.FileEntry) (string,
128 	curl := shared.NewCreateURL(h.Cfg)
129 	url := h.Cfg.FullPostURL(
130 		curl,
131-		h.User.Name,
132+		user.Name,
133 		metadata.Slug,
134 	)
135 	return url, nil
M filehandlers/imgs/img.go
+7, -2
 1@@ -8,6 +8,7 @@ import (
 2 	"git.sr.ht/~erock/pico/db"
 3 	"git.sr.ht/~erock/pico/imgs/storage"
 4 	"git.sr.ht/~erock/pico/shared"
 5+	"github.com/gliderlabs/ssh"
 6 )
 7 
 8 func (h *UploadImgHandler) validateImg(data *PostMetaData) (bool, error) {
 9@@ -96,11 +97,15 @@ func (h *UploadImgHandler) metaImg(data *PostMetaData) error {
10 	return nil
11 }
12 
13-func (h *UploadImgHandler) writeImg(data *PostMetaData) error {
14+func (h *UploadImgHandler) writeImg(s ssh.Session, data *PostMetaData) error {
15 	valid, err := h.validateImg(data)
16 	if !valid {
17 		return err
18 	}
19+	user, err := getUser(s)
20+	if err != nil {
21+		return err
22+	}
23 
24 	err = h.metaImg(data)
25 	if err != nil {
26@@ -130,7 +135,7 @@ func (h *UploadImgHandler) writeImg(data *PostMetaData) error {
27 	} else if data.Cur == nil {
28 		h.Cfg.Logger.Infof("(%s) not found, adding record", data.Filename)
29 		insertPost := db.Post{
30-			UserID: h.User.ID,
31+			UserID: user.ID,
32 			Space:  h.Cfg.Space,
33 
34 			Data:        data.Data,
M filehandlers/post_handler.go
+34, -11
  1@@ -19,6 +19,16 @@ import (
  2 	"github.com/gliderlabs/ssh"
  3 )
  4 
  5+type ctxUserKey struct{}
  6+
  7+func getUser(s ssh.Session) (*db.User, error) {
  8+	user := s.Context().Value(ctxUserKey{}).(*db.User)
  9+	if user == nil {
 10+		return user, fmt.Errorf("user not set on `ssh.Context()` for connection")
 11+	}
 12+	return user, nil
 13+}
 14+
 15 type PostMetaData struct {
 16 	*db.Post
 17 	Cur       *db.Post
 18@@ -33,7 +43,6 @@ type ScpFileHooks interface {
 19 }
 20 
 21 type ScpUploadHandler struct {
 22-	User      *db.User
 23 	DBPool    db.DB
 24 	Cfg       *shared.ConfigSite
 25 	Hooks     ScpFileHooks
 26@@ -52,13 +61,17 @@ func NewScpPostHandler(dbpool db.DB, cfg *shared.ConfigSite, hooks ScpFileHooks,
 27 }
 28 
 29 func (h *ScpUploadHandler) Read(s ssh.Session, filename string) (os.FileInfo, io.ReaderAt, error) {
 30+	user, err := getUser(s)
 31+	if err != nil {
 32+		return nil, nil, err
 33+	}
 34 	cleanFilename := strings.ReplaceAll(filename, "/", "")
 35 
 36 	if cleanFilename == "" || cleanFilename == "." {
 37 		return nil, nil, os.ErrNotExist
 38 	}
 39 
 40-	post, err := h.DBPool.FindPostWithFilename(cleanFilename, h.User.ID, h.Cfg.Space)
 41+	post, err := h.DBPool.FindPostWithFilename(cleanFilename, user.ID, h.Cfg.Space)
 42 	if err != nil {
 43 		return nil, nil, err
 44 	}
 45@@ -75,9 +88,13 @@ func (h *ScpUploadHandler) Read(s ssh.Session, filename string) (os.FileInfo, io
 46 
 47 func (h *ScpUploadHandler) List(s ssh.Session, filename string) ([]os.FileInfo, error) {
 48 	var fileList []os.FileInfo
 49+	user, err := getUser(s)
 50+	if err != nil {
 51+		return fileList, err
 52+	}
 53+
 54 	cleanFilename := strings.ReplaceAll(filename, "/", "")
 55 
 56-	var err error
 57 	var post *db.Post
 58 	var posts []*db.Post
 59 
 60@@ -92,9 +109,9 @@ func (h *ScpUploadHandler) List(s ssh.Session, filename string) ([]os.FileInfo,
 61 			FIsDir: true,
 62 		})
 63 
 64-		posts, err = h.DBPool.FindAllPostsForUser(h.User.ID, h.Cfg.Space)
 65+		posts, err = h.DBPool.FindAllPostsForUser(user.ID, h.Cfg.Space)
 66 	} else {
 67-		post, err = h.DBPool.FindPostWithFilename(cleanFilename, h.User.ID, h.Cfg.Space)
 68+		post, err = h.DBPool.FindPostWithFilename(cleanFilename, user.ID, h.Cfg.Space)
 69 
 70 		posts = append(posts, post)
 71 	}
 72@@ -131,19 +148,25 @@ func (h *ScpUploadHandler) Validate(s ssh.Session) error {
 73 		return fmt.Errorf("must have username set")
 74 	}
 75 
 76-	h.User = user
 77+	s.Context().SetValue(ctxUserKey{}, user)
 78+	h.Cfg.Logger.Infof("(%s) attempting to upload files to (%s)", user.Name, h.Cfg.Space)
 79 	return nil
 80 }
 81 
 82 func (h *ScpUploadHandler) Write(s ssh.Session, entry *utils.FileEntry) (string, error) {
 83 	logger := h.Cfg.Logger
 84-	userID := h.User.ID
 85+	user, err := getUser(s)
 86+	if err != nil {
 87+		return "", err
 88+	}
 89+
 90+	userID := user.ID
 91 	filename := entry.Name
 92 
 93 	if shared.IsExtAllowed(filename, h.ImgClient.Cfg.AllowedExt) {
 94 		if !h.ImgClient.HasAccess(userID) {
 95 			msg := "user (%s) does not have access to imgs.sh, cannot upload file (%s)"
 96-			return "", fmt.Errorf(msg, h.User.Name, filename)
 97+			return "", fmt.Errorf(msg, user.Name, filename)
 98 		}
 99 
100 		return h.ImgClient.Upload(s, entry)
101@@ -178,7 +201,7 @@ func (h *ScpUploadHandler) Write(s ssh.Session, entry *utils.FileEntry) (string,
102 
103 	metadata := PostMetaData{
104 		Post:      &nextPost,
105-		User:      h.User,
106+		User:      user,
107 		FileEntry: entry,
108 	}
109 
110@@ -258,7 +281,7 @@ func (h *ScpUploadHandler) Write(s ssh.Session, entry *utils.FileEntry) (string,
111 		if metadata.Text == post.Text {
112 			logger.Infof("(%s) found, but text is identical, skipping", filename)
113 			curl := shared.NewCreateURL(h.Cfg)
114-			return h.Cfg.FullPostURL(curl, h.User.Name, metadata.Slug), nil
115+			return h.Cfg.FullPostURL(curl, user.Name, metadata.Slug), nil
116 		}
117 
118 		logger.Infof("(%s) found, updating record", filename)
119@@ -292,5 +315,5 @@ func (h *ScpUploadHandler) Write(s ssh.Session, entry *utils.FileEntry) (string,
120 	}
121 
122 	curl := shared.NewCreateURL(h.Cfg)
123-	return h.Cfg.FullPostURL(curl, h.User.Name, metadata.Slug), nil
124+	return h.Cfg.FullPostURL(curl, user.Name, metadata.Slug), nil
125 }