- commit
- 34adbd4
- parent
- 13803bd
- author
- Eric Bower
- date
- 2024-06-17 15:02:33 +0000 UTC
refactor: be pickier with which pubkey to accept
12 files changed,
+114,
-83
+3,
-8
1@@ -13,6 +13,7 @@ import (
2 "github.com/charmbracelet/wish"
3 "github.com/picosh/pico/db/postgres"
4 "github.com/picosh/pico/filehandlers"
5+ "github.com/picosh/pico/filehandlers/util"
6 "github.com/picosh/pico/shared"
7 "github.com/picosh/pico/shared/storage"
8 wsh "github.com/picosh/pico/wish"
9@@ -25,12 +26,6 @@ import (
10 "github.com/picosh/send/send/sftp"
11 )
12
13-type SSHServer struct{}
14-
15-func (me *SSHServer) authHandler(ctx ssh.Context, key ssh.PublicKey) bool {
16- return true
17-}
18-
19 func createRouter(handler *filehandlers.FileHandlerRouter) proxy.Router {
20 return func(sh ssh.Handler, s ssh.Session) []wish.Middleware {
21 return []wish.Middleware{
22@@ -88,11 +83,11 @@ func StartSshServer() {
23 }
24 handler := filehandlers.NewFileHandlerRouter(cfg, dbh, fileMap)
25
26- sshServer := &SSHServer{}
27+ sshAuth := util.NewSshAuthHandler(dbh, logger, cfg)
28 s, err := wish.NewServer(
29 wish.WithAddress(fmt.Sprintf("%s:%s", host, port)),
30 wish.WithHostKeyPath("ssh_data/term_info_ed25519"),
31- wish.WithPublicKeyAuth(sshServer.authHandler),
32+ wish.WithPublicKeyAuth(sshAuth.PubkeyAuthHandler),
33 withProxy(
34 handler,
35 promwish.Middleware(fmt.Sprintf("%s:%s", host, promPort), "feeds-ssh"),
+7,
-7
1@@ -116,7 +116,7 @@ func (h *UploadAssetHandler) GetLogger() *slog.Logger {
2 }
3
4 func (h *UploadAssetHandler) Read(s ssh.Session, entry *utils.FileEntry) (os.FileInfo, utils.ReaderAtCloser, error) {
5- user, err := futil.GetUser(s)
6+ user, err := futil.GetUser(s.Context())
7 if err != nil {
8 return nil, nil, err
9 }
10@@ -150,7 +150,7 @@ func (h *UploadAssetHandler) Read(s ssh.Session, entry *utils.FileEntry) (os.Fil
11 func (h *UploadAssetHandler) List(s ssh.Session, fpath string, isDir bool, recursive bool) ([]os.FileInfo, error) {
12 var fileList []os.FileInfo
13
14- user, err := futil.GetUser(s)
15+ user, err := futil.GetUser(s.Context())
16 if err != nil {
17 return fileList, err
18 }
19@@ -222,8 +222,8 @@ func (h *UploadAssetHandler) Validate(s ssh.Session) error {
20 ff.Data.StorageMax = ff.FindStorageMax(h.Cfg.MaxSize)
21 ff.Data.FileMax = ff.FindFileMax(h.Cfg.MaxAssetSize)
22
23- futil.SetFeatureFlag(s, ff)
24- futil.SetUser(s, user)
25+ futil.SetFeatureFlag(s.Context(), ff)
26+ futil.SetUser(s.Context(), user)
27
28 assetBucket := shared.GetAssetBucketName(user.ID)
29 bucket, err := h.Storage.UpsertBucket(assetBucket)
30@@ -271,7 +271,7 @@ func (h *UploadAssetHandler) findDenylist(bucket sst.Bucket, project *db.Project
31 }
32
33 func (h *UploadAssetHandler) Write(s ssh.Session, entry *utils.FileEntry) (string, error) {
34- user, err := futil.GetUser(s)
35+ user, err := futil.GetUser(s.Context())
36 if err != nil {
37 h.Cfg.Logger.Error("user not found in ctx", "err", err.Error())
38 return "", err
39@@ -330,7 +330,7 @@ func (h *UploadAssetHandler) Write(s ssh.Session, entry *utils.FileEntry) (strin
40 return "", err
41 }
42
43- featureFlag, err := futil.GetFeatureFlag(s)
44+ featureFlag, err := futil.GetFeatureFlag(s.Context())
45 if err != nil {
46 return "", err
47 }
48@@ -423,7 +423,7 @@ func (h *UploadAssetHandler) Write(s ssh.Session, entry *utils.FileEntry) (strin
49 }
50
51 func (h *UploadAssetHandler) Delete(s ssh.Session, entry *utils.FileEntry) error {
52- user, err := futil.GetUser(s)
53+ user, err := futil.GetUser(s.Context())
54 if err != nil {
55 h.Cfg.Logger.Error("user not found in ctx", "err", err.Error())
56 return err
+4,
-4
1@@ -48,7 +48,7 @@ func NewUploadImgHandler(dbpool db.DB, cfg *shared.ConfigSite, storage storage.S
2 }
3
4 func (h *UploadImgHandler) Read(s ssh.Session, entry *utils.FileEntry) (os.FileInfo, utils.ReaderAtCloser, error) {
5- user, err := util.GetUser(s)
6+ user, err := util.GetUser(s.Context())
7 if err != nil {
8 return nil, nil, err
9 }
10@@ -87,7 +87,7 @@ func (h *UploadImgHandler) Read(s ssh.Session, entry *utils.FileEntry) (os.FileI
11 }
12
13 func (h *UploadImgHandler) Write(s ssh.Session, entry *utils.FileEntry) (string, error) {
14- user, err := util.GetUser(s)
15+ user, err := util.GetUser(s.Context())
16 if err != nil {
17 h.Cfg.Logger.Error(err.Error())
18 return "", err
19@@ -143,7 +143,7 @@ func (h *UploadImgHandler) Write(s ssh.Session, entry *utils.FileEntry) (string,
20 h.Cfg.Logger.Info("unable to find image, continuing", "filename", nextPost.Filename, "err", err.Error())
21 }
22
23- featureFlag, err := util.GetFeatureFlag(s)
24+ featureFlag, err := util.GetFeatureFlag(s.Context())
25 if err != nil {
26 return "", err
27 }
28@@ -190,7 +190,7 @@ func (h *UploadImgHandler) Write(s ssh.Session, entry *utils.FileEntry) (string,
29 }
30
31 func (h *UploadImgHandler) Delete(s ssh.Session, entry *utils.FileEntry) error {
32- user, err := util.GetUser(s)
33+ user, err := util.GetUser(s.Context())
34 if err != nil {
35 return err
36 }
+1,
-1
1@@ -80,7 +80,7 @@ func (h *UploadImgHandler) writeImg(s ssh.Session, data *PostMetaData) error {
2 if !valid {
3 return err
4 }
5- user, err := util.GetUser(s)
6+ user, err := util.GetUser(s.Context())
7 if err != nil {
8 return err
9 }
+3,
-3
1@@ -47,7 +47,7 @@ func NewScpPostHandler(dbpool db.DB, cfg *shared.ConfigSite, hooks ScpFileHooks,
2 }
3
4 func (h *ScpUploadHandler) Read(s ssh.Session, entry *utils.FileEntry) (os.FileInfo, utils.ReaderAtCloser, error) {
5- user, err := util.GetUser(s)
6+ user, err := util.GetUser(s.Context())
7 if err != nil {
8 return nil, nil, err
9 }
10@@ -76,7 +76,7 @@ func (h *ScpUploadHandler) Read(s ssh.Session, entry *utils.FileEntry) (os.FileI
11
12 func (h *ScpUploadHandler) Write(s ssh.Session, entry *utils.FileEntry) (string, error) {
13 logger := h.Cfg.Logger
14- user, err := util.GetUser(s)
15+ user, err := util.GetUser(s.Context())
16 if err != nil {
17 logger.Error(err.Error())
18 return "", err
19@@ -262,7 +262,7 @@ func (h *ScpUploadHandler) Write(s ssh.Session, entry *utils.FileEntry) (string,
20
21 func (h *ScpUploadHandler) Delete(s ssh.Session, entry *utils.FileEntry) error {
22 logger := h.Cfg.Logger
23- user, err := util.GetUser(s)
24+ user, err := util.GetUser(s.Context())
25 if err != nil {
26 logger.Error(err.Error())
27 return err
+7,
-31
1@@ -83,7 +83,7 @@ func (r *FileHandlerRouter) Read(s ssh.Session, entry *utils.FileEntry) (os.File
2
3 func BaseList(s ssh.Session, fpath string, isDir bool, recursive bool, spaces []string, dbpool db.DB) ([]os.FileInfo, error) {
4 var fileList []os.FileInfo
5- user, err := util.GetUser(s)
6+ user, err := util.GetUser(s.Context())
7 if err != nil {
8 return fileList, err
9 }
10@@ -154,39 +154,15 @@ func (r *FileHandlerRouter) GetLogger() *slog.Logger {
11 }
12
13 func (r *FileHandlerRouter) Validate(s ssh.Session) error {
14- var err error
15- key, err := utils.KeyText(s)
16- if err != nil {
17- return fmt.Errorf("key not found")
18- }
19-
20- user, err := r.DBPool.FindUserForKey(s.User(), key)
21+ user, err := util.GetUser(s.Context())
22 if err != nil {
23 return err
24 }
25
26- if user.Name == "" {
27- return fmt.Errorf("must have username set")
28- }
29-
30- ff, _ := r.DBPool.FindFeatureForUser(user.ID, "plus")
31- // we have free tiers so users might not have a feature flag
32- // in which case we set sane defaults
33- if ff == nil {
34- ff = db.NewFeatureFlag(
35- user.ID,
36- "plus",
37- r.Cfg.MaxSize,
38- r.Cfg.MaxAssetSize,
39- )
40- }
41- // this is jank
42- ff.Data.StorageMax = ff.FindStorageMax(r.Cfg.MaxSize)
43- ff.Data.FileMax = ff.FindFileMax(r.Cfg.MaxAssetSize)
44-
45- util.SetUser(s, user)
46- util.SetFeatureFlag(s, ff)
47-
48- r.Cfg.Logger.Info("attempting to upload files", "user", user.Name, "space", r.Cfg.Space)
49+ r.Cfg.Logger.Info(
50+ "attempting to upload files",
51+ "user", user.Name,
52+ "space", r.Cfg.Space,
53+ )
54 return nil
55 }
+64,
-0
1@@ -0,0 +1,64 @@
2+package util
3+
4+import (
5+ "log/slog"
6+
7+ "github.com/charmbracelet/ssh"
8+ "github.com/picosh/pico/db"
9+ "github.com/picosh/pico/shared"
10+)
11+
12+type SshAuthHandler struct {
13+ DBPool db.DB
14+ Logger *slog.Logger
15+ Cfg *shared.ConfigSite
16+}
17+
18+func NewSshAuthHandler(dbpool db.DB, logger *slog.Logger, cfg *shared.ConfigSite) *SshAuthHandler {
19+ return &SshAuthHandler{
20+ DBPool: dbpool,
21+ Logger: logger,
22+ Cfg: cfg,
23+ }
24+}
25+
26+func (r *SshAuthHandler) PubkeyAuthHandler(ctx ssh.Context, key ssh.PublicKey) bool {
27+ pubkey, err := shared.KeyForKeyText(key)
28+ if err != nil {
29+ return false
30+ }
31+
32+ user, err := r.DBPool.FindUserForKey(ctx.User(), pubkey)
33+ if err != nil {
34+ r.Logger.Error(
35+ "could not find user for key",
36+ "key", key,
37+ "err", err,
38+ )
39+ return false
40+ }
41+
42+ if user.Name == "" {
43+ r.Logger.Error("username is not set")
44+ return false
45+ }
46+
47+ ff, _ := r.DBPool.FindFeatureForUser(user.ID, "plus")
48+ // we have free tiers so users might not have a feature flag
49+ // in which case we set sane defaults
50+ if ff == nil {
51+ ff = db.NewFeatureFlag(
52+ user.ID,
53+ "plus",
54+ r.Cfg.MaxSize,
55+ r.Cfg.MaxAssetSize,
56+ )
57+ }
58+ // this is jank
59+ ff.Data.StorageMax = ff.FindStorageMax(r.Cfg.MaxSize)
60+ ff.Data.FileMax = ff.FindFileMax(r.Cfg.MaxAssetSize)
61+
62+ SetUser(ctx, user)
63+ SetFeatureFlag(ctx, ff)
64+ return true
65+}
+8,
-8
1@@ -10,26 +10,26 @@ import (
2 type ctxUserKey struct{}
3 type ctxFeatureFlagKey struct{}
4
5-func GetUser(s ssh.Session) (*db.User, error) {
6- user, ok := s.Context().Value(ctxUserKey{}).(*db.User)
7+func GetUser(ctx ssh.Context) (*db.User, error) {
8+ user, ok := ctx.Value(ctxUserKey{}).(*db.User)
9 if !ok {
10 return user, fmt.Errorf("user not set on `ssh.Context()` for connection")
11 }
12 return user, nil
13 }
14
15-func SetUser(s ssh.Session, user *db.User) {
16- s.Context().SetValue(ctxUserKey{}, user)
17+func SetUser(ctx ssh.Context, user *db.User) {
18+ ctx.SetValue(ctxUserKey{}, user)
19 }
20
21-func GetFeatureFlag(s ssh.Session) (*db.FeatureFlag, error) {
22- ff, ok := s.Context().Value(ctxFeatureFlagKey{}).(*db.FeatureFlag)
23+func GetFeatureFlag(ctx ssh.Context) (*db.FeatureFlag, error) {
24+ ff, ok := ctx.Value(ctxFeatureFlagKey{}).(*db.FeatureFlag)
25 if !ok || ff.Name == "" {
26 return ff, fmt.Errorf("feature flag not set on `ssh.Context()` for connection")
27 }
28 return ff, nil
29 }
30
31-func SetFeatureFlag(s ssh.Session, ff *db.FeatureFlag) {
32- s.Context().SetValue(ctxFeatureFlagKey{}, ff)
33+func SetFeatureFlag(ctx ssh.Context, ff *db.FeatureFlag) {
34+ ctx.SetValue(ctxFeatureFlagKey{}, ff)
35 }
+3,
-9
1@@ -13,6 +13,7 @@ import (
2 "github.com/charmbracelet/wish"
3 "github.com/picosh/pico/db/postgres"
4 "github.com/picosh/pico/filehandlers"
5+ "github.com/picosh/pico/filehandlers/util"
6 "github.com/picosh/pico/shared"
7 "github.com/picosh/pico/shared/storage"
8 wsh "github.com/picosh/pico/wish"
9@@ -25,12 +26,6 @@ import (
10 "github.com/picosh/send/send/sftp"
11 )
12
13-type SSHServer struct{}
14-
15-func (me *SSHServer) authHandler(ctx ssh.Context, key ssh.PublicKey) bool {
16- return true
17-}
18-
19 func createRouter(handler *filehandlers.FileHandlerRouter) proxy.Router {
20 return func(sh ssh.Handler, s ssh.Session) []wish.Middleware {
21 return []wish.Middleware{
22@@ -86,12 +81,11 @@ func StartSshServer() {
23 "fallback": filehandlers.NewScpPostHandler(dbh, cfg, hooks, st),
24 }
25 handler := filehandlers.NewFileHandlerRouter(cfg, dbh, fileMap)
26-
27- sshServer := &SSHServer{}
28+ sshAuth := util.NewSshAuthHandler(dbh, logger, cfg)
29 s, err := wish.NewServer(
30 wish.WithAddress(fmt.Sprintf("%s:%s", host, port)),
31 wish.WithHostKeyPath("ssh_data/term_info_ed25519"),
32- wish.WithPublicKeyAuth(sshServer.authHandler),
33+ wish.WithPublicKeyAuth(sshAuth.PubkeyAuthHandler),
34 withProxy(
35 handler,
36 promwish.Middleware(fmt.Sprintf("%s:%s", host, promPort), "pastes-ssh"),
+4,
-4
1@@ -56,7 +56,7 @@ func (h *UploadHandler) Delete(s ssh.Session, entry *utils.FileEntry) error {
2 }
3
4 func (h *UploadHandler) Read(s ssh.Session, entry *utils.FileEntry) (os.FileInfo, utils.ReaderAtCloser, error) {
5- user, err := util.GetUser(s)
6+ user, err := util.GetUser(s.Context())
7 if err != nil {
8 return nil, nil, err
9 }
10@@ -80,7 +80,7 @@ func (h *UploadHandler) Read(s ssh.Session, entry *utils.FileEntry) (os.FileInfo
11
12 func (h *UploadHandler) List(s ssh.Session, fpath string, isDir bool, recursive bool) ([]os.FileInfo, error) {
13 var fileList []os.FileInfo
14- user, err := util.GetUser(s)
15+ user, err := util.GetUser(s.Context())
16 if err != nil {
17 return fileList, err
18 }
19@@ -135,7 +135,7 @@ func (h *UploadHandler) Validate(s ssh.Session) error {
20 return fmt.Errorf("must have username set")
21 }
22
23- util.SetUser(s, user)
24+ util.SetUser(s.Context(), user)
25 return nil
26 }
27
28@@ -282,7 +282,7 @@ func (h *UploadHandler) ProcessAuthorizedKeys(text []byte, logger *slog.Logger,
29
30 func (h *UploadHandler) Write(s ssh.Session, entry *utils.FileEntry) (string, error) {
31 logger := h.Cfg.Logger
32- user, err := util.GetUser(s)
33+ user, err := util.GetUser(s.Context())
34 if err != nil {
35 logger.Error(err.Error())
36 return "", err
+3,
-8
1@@ -14,6 +14,7 @@ import (
2 "github.com/picosh/pico/db/postgres"
3 "github.com/picosh/pico/filehandlers"
4 uploadimgs "github.com/picosh/pico/filehandlers/imgs"
5+ "github.com/picosh/pico/filehandlers/util"
6 "github.com/picosh/pico/shared"
7 "github.com/picosh/pico/shared/storage"
8 wsh "github.com/picosh/pico/wish"
9@@ -26,12 +27,6 @@ import (
10 "github.com/picosh/send/send/sftp"
11 )
12
13-type SSHServer struct{}
14-
15-func (me *SSHServer) authHandler(ctx ssh.Context, key ssh.PublicKey) bool {
16- return true
17-}
18-
19 func createRouter(handler *filehandlers.FileHandlerRouter, cliHandler *CliHandler) proxy.Router {
20 return func(sh ssh.Handler, s ssh.Session) []wish.Middleware {
21 return []wish.Middleware{
22@@ -97,11 +92,11 @@ func StartSshServer() {
23 DBPool: dbh,
24 }
25
26- sshServer := &SSHServer{}
27+ sshAuth := util.NewSshAuthHandler(dbh, logger, cfg)
28 s, err := wish.NewServer(
29 wish.WithAddress(fmt.Sprintf("%s:%s", host, port)),
30 wish.WithHostKeyPath("ssh_data/term_info_ed25519"),
31- wish.WithPublicKeyAuth(sshServer.authHandler),
32+ wish.WithPublicKeyAuth(sshAuth.PubkeyAuthHandler),
33 withProxy(
34 handler,
35 cliHandler,
1@@ -0,0 +1,7 @@
2+package shared
3+
4+import "github.com/charmbracelet/ssh"
5+
6+func PubkeyAuthHandler(ctx ssh.Context, key ssh.PublicKey) bool {
7+ return true
8+}