repos / pico

pico services - prose.sh, pastes.sh, imgs.sh, feeds.sh, pgs.sh
git clone https://github.com/picosh/pico.git

commit
4a514c1
parent
090d8de
author
Eric Bower
date
2024-03-03 12:47:50 +0000 UTC
feat(pgs): pico-ui
7 files changed,  +406, -135
M db/db.go
+20, -12
 1@@ -28,6 +28,12 @@ type User struct {
 2 	CreatedAt *time.Time `json:"created_at"`
 3 }
 4 
 5+type PicoApi struct {
 6+	UserID    string `json:"user_id"`
 7+	UserName  string `json:"username"`
 8+	PublicKey string `json:"pubkey"`
 9+}
10+
11 type PostData struct {
12 	ImgPath    string     `json:"img_path"`
13 	LastDigest *time.Time `json:"last_digest"`
14@@ -165,21 +171,21 @@ type FeedItem struct {
15 }
16 
17 type Token struct {
18-	ID        string
19-	UserID    string
20-	Name      string
21-	CreatedAt *time.Time
22-	ExpiresAt *time.Time
23+	ID        string     `json:"id"`
24+	UserID    string     `json:"user_id"`
25+	Name      string     `json:"name"`
26+	CreatedAt *time.Time `json:"created_at"`
27+	ExpiresAt *time.Time `json:"expires_at"`
28 }
29 
30 type FeatureFlag struct {
31-	ID               string
32-	UserID           string
33-	PaymentHistoryID string
34-	Name             string
35-	CreatedAt        *time.Time
36-	ExpiresAt        *time.Time
37-	Data             FeatureFlagData
38+	ID               string          `json:"id"`
39+	UserID           string          `json:"user_id"`
40+	PaymentHistoryID string          `json:"payment_history_id"`
41+	Name             string          `json:"name"`
42+	CreatedAt        *time.Time      `json:"created_at"`
43+	ExpiresAt        *time.Time      `json:"expires_at"`
44+	Data             FeatureFlagData `json:"data"`
45 }
46 
47 func NewFeatureFlag(userID, name string, storageMax uint64, fileMax int64) *FeatureFlag {
48@@ -297,6 +303,7 @@ type DB interface {
49 	FindUserForToken(token string) (*User, error)
50 	FindTokensForUser(userID string) ([]*Token, error)
51 	InsertToken(userID, name string) (string, error)
52+	FindRssToken(userID string) (string, error)
53 	RemoveToken(tokenID string) error
54 
55 	FindPosts() ([]*Post, error)
56@@ -326,6 +333,7 @@ type DB interface {
57 
58 	AddPicoPlusUser(username string, paymentType, txId string) error
59 	FindFeatureForUser(userID string, feature string) (*FeatureFlag, error)
60+	FindFeaturesForUser(userID string) ([]*FeatureFlag, error)
61 	HasFeatureForUser(userID string, feature string) bool
62 	FindTotalSizeForUser(userID string) (int, error)
63 
M db/postgres/storage.go
+45, -3
 1@@ -148,9 +148,10 @@ const (
 2 	FROM app_users
 3 	LEFT JOIN tokens ON tokens.user_id = app_users.id
 4 	WHERE tokens.token = $1 AND tokens.expires_at > NOW()`
 5-	sqlInsertToken         = `INSERT INTO tokens (user_id, name) VALUES($1, $2) RETURNING token;`
 6-	sqlRemoveToken         = `DELETE FROM tokens WHERE id = $1`
 7-	sqlSelectTokensForUser = `SELECT id, user_id, name, created_at, expires_at FROM tokens WHERE user_id = $1`
 8+	sqlInsertToken           = `INSERT INTO tokens (user_id, name) VALUES($1, $2) RETURNING token;`
 9+	sqlRemoveToken           = `DELETE FROM tokens WHERE id = $1`
10+	sqlSelectTokensForUser   = `SELECT id, user_id, name, created_at, expires_at FROM tokens WHERE user_id = $1`
11+	sqlSelectRssTokenForUser = `SELECT token FROM tokens WHERE user_id = $1 AND name = 'pico-rss'`
12 
13 	sqlSelectTotalUsers          = `SELECT count(id) FROM app_users`
14 	sqlSelectUsersAfterDate      = `SELECT count(id) FROM app_users WHERE created_at >= $1`
15@@ -1199,6 +1200,38 @@ func (me *PsqlDB) FindFeatureForUser(userID string, feature string) (*db.Feature
16 	return ff, nil
17 }
18 
19+func (me *PsqlDB) FindFeaturesForUser(userID string) ([]*db.FeatureFlag, error) {
20+	var features []*db.FeatureFlag
21+	query := "SELECT id, user_id, payment_history_id, name, data, created_at, expires_at FROM feature_flags WHERE user_id=$1"
22+	rs, err := me.Db.Query(query, userID)
23+	if err != nil {
24+		return features, err
25+	}
26+	for rs.Next() {
27+		var paymentHistoryID sql.NullString
28+		ff := &db.FeatureFlag{}
29+		err := rs.Scan(
30+			&ff.ID,
31+			&ff.UserID,
32+			&paymentHistoryID,
33+			&ff.Name,
34+			&ff.Data,
35+			&ff.CreatedAt,
36+			&ff.ExpiresAt,
37+		)
38+		if err != nil {
39+			return features, err
40+		}
41+		ff.Name = paymentHistoryID.String
42+
43+		features = append(features, ff)
44+	}
45+	if rs.Err() != nil {
46+		return features, rs.Err()
47+	}
48+	return features, nil
49+}
50+
51 func (me *PsqlDB) HasFeatureForUser(userID string, feature string) bool {
52 	ff, err := me.FindFeatureForUser(userID, feature)
53 	if err != nil {
54@@ -1515,6 +1548,15 @@ func (me *PsqlDB) InsertToken(userID, name string) (string, error) {
55 	return token, nil
56 }
57 
58+func (me *PsqlDB) FindRssToken(userID string) (string, error) {
59+	var token string
60+	err := me.Db.QueryRow(sqlSelectRssTokenForUser, userID).Scan(&token)
61+	if err != nil {
62+		return "", err
63+	}
64+	return token, nil
65+}
66+
67 func (me *PsqlDB) RemoveToken(tokenID string) error {
68 	_, err := me.Db.Exec(sqlRemoveToken, tokenID)
69 	return err
M pgs/ssh.go
+1, -120
  1@@ -2,9 +2,7 @@ package pgs
  2 
  3 import (
  4 	"context"
  5-	"encoding/json"
  6 	"fmt"
  7-	"net/http"
  8 	"os"
  9 	"os/signal"
 10 	"syscall"
 11@@ -14,7 +12,6 @@ import (
 12 	"github.com/charmbracelet/ssh"
 13 	"github.com/charmbracelet/wish"
 14 	bm "github.com/charmbracelet/wish/bubbletea"
 15-	"github.com/picosh/pico/db"
 16 	"github.com/picosh/pico/db/postgres"
 17 	uploadassets "github.com/picosh/pico/filehandlers/assets"
 18 	"github.com/picosh/pico/shared"
 19@@ -30,21 +27,8 @@ import (
 20 	"github.com/picosh/send/send/sftp"
 21 )
 22 
 23-type ctxPublicKey struct{}
 24-
 25-func getPublicKeyCtx(ctx ssh.Context) (ssh.PublicKey, error) {
 26-	pk, ok := ctx.Value(ctxPublicKey{}).(ssh.PublicKey)
 27-	if !ok {
 28-		return nil, fmt.Errorf("public key not set on `ssh.Context()` for connection")
 29-	}
 30-	return pk, nil
 31-}
 32-func setPublicKeyCtx(ctx ssh.Context, pk ssh.PublicKey) {
 33-	ctx.SetValue(ctxPublicKey{}, pk)
 34-}
 35-
 36 func authHandler(ctx ssh.Context, key ssh.PublicKey) bool {
 37-	setPublicKeyCtx(ctx, key)
 38+	shared.SetPublicKeyCtx(ctx, key)
 39 	return true
 40 }
 41 
 42@@ -74,109 +58,6 @@ func withProxy(cfg *shared.ConfigSite, handler *uploadassets.UploadAssetHandler,
 43 	}
 44 }
 45 
 46-func unauthorizedHandler(w http.ResponseWriter, r *http.Request) {
 47-	http.Error(w, "You do not have access to this site", http.StatusUnauthorized)
 48-}
 49-
 50-func allowPerm(proj *db.Project) bool {
 51-	return true
 52-}
 53-
 54-type PicoApi struct {
 55-	UserID    string `json:"user_id"`
 56-	UserName  string `json:"username"`
 57-	PublicKey string `json:"public_key"`
 58-}
 59-
 60-type CtxHttpBridge = func(ssh.Context) http.Handler
 61-
 62-func createHttpHandler(httpCtx *shared.HttpCtx) CtxHttpBridge {
 63-	return func(ctx ssh.Context) http.Handler {
 64-		subdomain := ctx.User()
 65-		dbh := httpCtx.Dbpool
 66-		logger := httpCtx.Cfg.Logger
 67-		log := logger.With(
 68-			"subdomain", subdomain,
 69-		)
 70-
 71-		pubkey, err := getPublicKeyCtx(ctx)
 72-		if err != nil {
 73-			log.Error(err.Error(), "subdomain", subdomain)
 74-			return http.HandlerFunc(unauthorizedHandler)
 75-		}
 76-		pubkeyStr, err := shared.KeyForKeyText(pubkey)
 77-		if err != nil {
 78-			log.Error(err.Error())
 79-			return http.HandlerFunc(unauthorizedHandler)
 80-		}
 81-		log = log.With(
 82-			"pubkey", pubkeyStr,
 83-		)
 84-
 85-		props, err := getProjectFromSubdomain(subdomain)
 86-		if err != nil {
 87-			log.Error(err.Error())
 88-			return http.HandlerFunc(unauthorizedHandler)
 89-		}
 90-
 91-		owner, err := dbh.FindUserForName(props.Username)
 92-		if err != nil {
 93-			log.Error(err.Error())
 94-			return http.HandlerFunc(unauthorizedHandler)
 95-		}
 96-		log = log.With(
 97-			"owner", owner.Name,
 98-		)
 99-
100-		project, err := dbh.FindProjectByName(owner.ID, props.ProjectName)
101-		if err != nil {
102-			log.Error(err.Error())
103-			return http.HandlerFunc(unauthorizedHandler)
104-		}
105-
106-		requester, _ := dbh.FindUserForKey("", pubkeyStr)
107-		if requester != nil {
108-			log = logger.With(
109-				"requester", requester.Name,
110-			)
111-		}
112-
113-		if !HasProjectAccess(project, owner, requester, pubkey) {
114-			log.Error("no access")
115-			return http.HandlerFunc(unauthorizedHandler)
116-		}
117-
118-		log.Info("user has access to site")
119-
120-		routes := []shared.Route{
121-			// special API endpoint for tunnel users accessing site
122-			shared.NewRoute("GET", "/pico", func(w http.ResponseWriter, r *http.Request) {
123-				w.Header().Set("Content-Type", "application/json")
124-				pico := &PicoApi{
125-					UserID:    "",
126-					UserName:  "",
127-					PublicKey: pubkeyStr,
128-				}
129-				if requester != nil {
130-					pico.UserID = requester.ID
131-					pico.UserName = requester.Name
132-				}
133-				err := json.NewEncoder(w).Encode(pico)
134-				if err != nil {
135-					log.Error(err.Error())
136-				}
137-			}),
138-		}
139-
140-		subdomainRoutes := createSubdomainRoutes(allowPerm)
141-		routes = append(routes, subdomainRoutes...)
142-		finctx := httpCtx.CreateCtx(ctx, subdomain)
143-		httpHandler := shared.CreateServeBasic(routes, finctx)
144-		httpRouter := http.HandlerFunc(httpHandler)
145-		return httpRouter
146-	}
147-}
148-
149 func StartSshServer() {
150 	host := shared.GetEnv("PGS_HOST", "0.0.0.0")
151 	port := shared.GetEnv("PGS_SSH_PORT", "2222")
A pgs/tunnel.go
+109, -0
  1@@ -0,0 +1,109 @@
  2+package pgs
  3+
  4+import (
  5+	"encoding/json"
  6+	"net/http"
  7+
  8+	"github.com/charmbracelet/ssh"
  9+	"github.com/picosh/pico/db"
 10+	"github.com/picosh/pico/plus"
 11+	"github.com/picosh/pico/shared"
 12+)
 13+
 14+func allowPerm(proj *db.Project) bool {
 15+	return true
 16+}
 17+
 18+type CtxHttpBridge = func(ssh.Context) http.Handler
 19+
 20+func createHttpHandler(httpCtx *shared.HttpCtx) CtxHttpBridge {
 21+	return func(ctx ssh.Context) http.Handler {
 22+		subdomain := ctx.User()
 23+		dbh := httpCtx.Dbpool
 24+		logger := httpCtx.Cfg.Logger
 25+		log := logger.With(
 26+			"subdomain", subdomain,
 27+		)
 28+
 29+		pubkey, err := shared.GetPublicKeyCtx(ctx)
 30+		if err != nil {
 31+			log.Error(err.Error(), "subdomain", subdomain)
 32+			return http.HandlerFunc(shared.UnauthorizedHandler)
 33+		}
 34+		pubkeyStr, err := shared.KeyForKeyText(pubkey)
 35+		if err != nil {
 36+			log.Error(err.Error())
 37+			return http.HandlerFunc(shared.UnauthorizedHandler)
 38+		}
 39+		log = log.With(
 40+			"pubkey", pubkeyStr,
 41+		)
 42+
 43+		props, err := getProjectFromSubdomain(subdomain)
 44+		if err != nil {
 45+			log.Error(err.Error())
 46+			return http.HandlerFunc(shared.UnauthorizedHandler)
 47+		}
 48+
 49+		owner, err := dbh.FindUserForName(props.Username)
 50+		if err != nil {
 51+			log.Error(err.Error())
 52+			return http.HandlerFunc(shared.UnauthorizedHandler)
 53+		}
 54+		log = log.With(
 55+			"owner", owner.Name,
 56+		)
 57+
 58+		project, err := dbh.FindProjectByName(owner.ID, props.ProjectName)
 59+		if err != nil {
 60+			log.Error(err.Error())
 61+			return http.HandlerFunc(shared.UnauthorizedHandler)
 62+		}
 63+
 64+		requester, _ := dbh.FindUserForKey("", pubkeyStr)
 65+		if requester != nil {
 66+			log = logger.With(
 67+				"requester", requester.Name,
 68+			)
 69+		}
 70+
 71+		if !HasProjectAccess(project, owner, requester, pubkey) {
 72+			log.Error("no access")
 73+			return http.HandlerFunc(shared.UnauthorizedHandler)
 74+		}
 75+
 76+		log.Info("user has access to site")
 77+
 78+		routes := []shared.Route{
 79+			// special API endpoint for tunnel users accessing site
 80+			shared.NewRoute("GET", "/api/current_user", func(w http.ResponseWriter, r *http.Request) {
 81+				w.Header().Set("Content-Type", "application/json")
 82+				pico := &db.PicoApi{
 83+					UserID:    "",
 84+					UserName:  "",
 85+					PublicKey: pubkeyStr,
 86+				}
 87+				if requester != nil {
 88+					pico.UserID = requester.ID
 89+					pico.UserName = requester.Name
 90+				}
 91+				err := json.NewEncoder(w).Encode(pico)
 92+				if err != nil {
 93+					log.Error(err.Error())
 94+				}
 95+			}),
 96+		}
 97+
 98+		if subdomain == "pico-ui" || subdomain == "erock-ui" {
 99+			rts := plus.CreateRoutes(httpCtx, ctx)
100+			routes = append(routes, rts...)
101+		}
102+
103+		subdomainRoutes := createSubdomainRoutes(allowPerm)
104+		routes = append(routes, subdomainRoutes...)
105+		finctx := httpCtx.CreateCtx(ctx, subdomain)
106+		httpHandler := shared.CreateServeBasic(routes, finctx)
107+		httpRouter := http.HandlerFunc(httpHandler)
108+		return httpRouter
109+	}
110+}
A plus/routes.go
+196, -0
  1@@ -0,0 +1,196 @@
  2+package plus
  3+
  4+import (
  5+	"encoding/json"
  6+	"fmt"
  7+	"io"
  8+	"net/http"
  9+
 10+	"github.com/charmbracelet/ssh"
 11+	"github.com/picosh/pico/db"
 12+	"github.com/picosh/pico/shared"
 13+)
 14+
 15+type registerPayload struct {
 16+	Name string `json:"name"`
 17+}
 18+
 19+func registerUser(httpCtx *shared.HttpCtx, ctx ssh.Context, pubkey string) http.HandlerFunc {
 20+	logger := httpCtx.Cfg.Logger
 21+	return func(w http.ResponseWriter, r *http.Request) {
 22+		dbpool := shared.GetDB(r)
 23+		var payload registerPayload
 24+		body, _ := io.ReadAll(r.Body)
 25+		_ = json.Unmarshal(body, &payload)
 26+
 27+		user, err := dbpool.RegisterUser(payload.Name, pubkey)
 28+		if err != nil {
 29+			errMsg := fmt.Sprintf("error registering user: %s", err.Error())
 30+			logger.Info(errMsg)
 31+			shared.JSONError(w, errMsg, http.StatusUnprocessableEntity)
 32+			return
 33+		}
 34+
 35+		w.Header().Set("Content-Type", "application/json")
 36+		pico := &db.PicoApi{
 37+			UserID:    user.ID,
 38+			UserName:  user.Name,
 39+			PublicKey: pubkey,
 40+		}
 41+		err = json.NewEncoder(w).Encode(pico)
 42+		if err != nil {
 43+			logger.Error(err.Error())
 44+		}
 45+	}
 46+}
 47+
 48+type featuresPayload struct {
 49+	Features []*db.FeatureFlag `json:"features"`
 50+}
 51+
 52+func getFeatures(httpCtx *shared.HttpCtx, ctx ssh.Context, pubkey string) http.HandlerFunc {
 53+	logger := httpCtx.Cfg.Logger
 54+	return func(w http.ResponseWriter, r *http.Request) {
 55+		w.Header().Set("Content-Type", "application/json")
 56+
 57+		dbpool := shared.GetDB(r)
 58+		user, err := dbpool.FindUserForKey("", pubkey)
 59+		if err != nil {
 60+			shared.JSONError(w, "User not found", http.StatusNotFound)
 61+			return
 62+		}
 63+
 64+		features, err := dbpool.FindFeaturesForUser(user.ID)
 65+		if err != nil {
 66+			shared.JSONError(w, err.Error(), http.StatusUnprocessableEntity)
 67+			return
 68+		}
 69+
 70+		if features == nil {
 71+			features = []*db.FeatureFlag{}
 72+		}
 73+		err = json.NewEncoder(w).Encode(&featuresPayload{Features: features})
 74+		if err != nil {
 75+			logger.Error(err.Error())
 76+		}
 77+	}
 78+}
 79+
 80+type rssTokenPayload struct {
 81+	Token string `json:"token"`
 82+}
 83+
 84+func findOrCreateRssToken(httpCtx *shared.HttpCtx, ctx ssh.Context, pubkey string) http.HandlerFunc {
 85+	logger := httpCtx.Cfg.Logger
 86+	return func(w http.ResponseWriter, r *http.Request) {
 87+		dbpool := shared.GetDB(r)
 88+		user, err := dbpool.FindUserForKey("", pubkey)
 89+		if err != nil {
 90+			shared.JSONError(w, "User not found", http.StatusUnprocessableEntity)
 91+			return
 92+		}
 93+
 94+		rssToken, err := dbpool.FindRssToken(user.ID)
 95+		if err != nil {
 96+			shared.JSONError(w, err.Error(), http.StatusUnprocessableEntity)
 97+			return
 98+		}
 99+
100+		if rssToken == "" {
101+			rssToken, err = dbpool.InsertToken(user.ID, "pico-rss")
102+			if err != nil {
103+				shared.JSONError(w, err.Error(), http.StatusUnprocessableEntity)
104+				return
105+			}
106+		}
107+
108+		w.Header().Set("Content-Type", "application/json")
109+		err = json.NewEncoder(w).Encode(&rssTokenPayload{Token: rssToken})
110+		if err != nil {
111+			logger.Error(err.Error())
112+		}
113+	}
114+}
115+
116+type pubkeysPayload struct {
117+	Pubkeys []*db.PublicKey `json:"pubkeys"`
118+}
119+
120+func getPublicKeys(httpCtx *shared.HttpCtx, ctx ssh.Context, pubkey string) http.HandlerFunc {
121+	logger := httpCtx.Cfg.Logger
122+	return func(w http.ResponseWriter, r *http.Request) {
123+		dbpool := shared.GetDB(r)
124+		user, err := dbpool.FindUserForKey("", pubkey)
125+		if err != nil {
126+			shared.JSONError(w, "User not found", http.StatusUnprocessableEntity)
127+			return
128+		}
129+
130+		pubkeys, err := dbpool.FindKeysForUser(user)
131+		if err != nil {
132+			shared.JSONError(w, err.Error(), http.StatusUnprocessableEntity)
133+			return
134+		}
135+
136+		w.Header().Set("Content-Type", "application/json")
137+		err = json.NewEncoder(w).Encode(&pubkeysPayload{Pubkeys: pubkeys})
138+		if err != nil {
139+			logger.Error(err.Error())
140+		}
141+	}
142+}
143+
144+type tokensPayload struct {
145+	Tokens []*db.Token `json:"tokens"`
146+}
147+
148+func getTokens(httpCtx *shared.HttpCtx, ctx ssh.Context, pubkey string) http.HandlerFunc {
149+	logger := httpCtx.Cfg.Logger
150+	return func(w http.ResponseWriter, r *http.Request) {
151+		dbpool := shared.GetDB(r)
152+		user, err := dbpool.FindUserForKey("", pubkey)
153+		if err != nil {
154+			shared.JSONError(w, "User not found", http.StatusUnprocessableEntity)
155+			return
156+		}
157+
158+		tokens, err := dbpool.FindTokensForUser(user.ID)
159+		if err != nil {
160+			shared.JSONError(w, err.Error(), http.StatusUnprocessableEntity)
161+			return
162+		}
163+
164+		if tokens == nil {
165+			tokens = []*db.Token{}
166+		}
167+
168+		w.Header().Set("Content-Type", "application/json")
169+		err = json.NewEncoder(w).Encode(&tokensPayload{Tokens: tokens})
170+		if err != nil {
171+			logger.Error(err.Error())
172+		}
173+	}
174+}
175+
176+func CreateRoutes(httpCtx *shared.HttpCtx, ctx ssh.Context) []shared.Route {
177+	logger := httpCtx.Cfg.Logger
178+	pubkey, err := shared.GetPublicKeyCtx(ctx)
179+	if err != nil {
180+		logger.Error("could not get pubkey from ctx", "err", err.Error())
181+		return []shared.Route{}
182+	}
183+
184+	pubkeyStr, err := shared.KeyForKeyText(pubkey)
185+	if err != nil {
186+		logger.Error("could not convert key to text", "err", err.Error())
187+		return []shared.Route{}
188+	}
189+
190+	return []shared.Route{
191+		shared.NewRoute("POST", "/api/users", registerUser(httpCtx, ctx, pubkeyStr)),
192+		shared.NewRoute("GET", "/api/features", getFeatures(httpCtx, ctx, pubkeyStr)),
193+		shared.NewRoute("PUT", "/api/rss-token", findOrCreateRssToken(httpCtx, ctx, pubkeyStr)),
194+		shared.NewRoute("GET", "/api/pubkeys", getPublicKeys(httpCtx, ctx, pubkeyStr)),
195+		shared.NewRoute("GET", "/api/tokens", getTokens(httpCtx, ctx, pubkeyStr)),
196+	}
197+}
M shared/api.go
+15, -0
 1@@ -1,6 +1,7 @@
 2 package shared
 3 
 4 import (
 5+	"encoding/json"
 6 	"fmt"
 7 	"html/template"
 8 	"net/http"
 9@@ -8,6 +9,20 @@ import (
10 	"strings"
11 )
12 
13+func UnauthorizedHandler(w http.ResponseWriter, r *http.Request) {
14+	http.Error(w, "You do not have access to this site", http.StatusUnauthorized)
15+}
16+
17+type errPayload struct {
18+	Message string `json:"message"`
19+}
20+
21+func JSONError(w http.ResponseWriter, msg string, code int) {
22+	w.Header().Set("Content-Type", "application/json")
23+	w.WriteHeader(code)
24+	_ = json.NewEncoder(w).Encode(errPayload{Message: msg})
25+}
26+
27 func CheckHandler(w http.ResponseWriter, r *http.Request) {
28 	dbpool := GetDB(r)
29 	cfg := GetCfg(r)
A shared/tunnel.go
+20, -0
 1@@ -0,0 +1,20 @@
 2+package shared
 3+
 4+import (
 5+	"fmt"
 6+
 7+	"github.com/charmbracelet/ssh"
 8+)
 9+
10+type ctxPublicKey struct{}
11+
12+func GetPublicKeyCtx(ctx ssh.Context) (ssh.PublicKey, error) {
13+	pk, ok := ctx.Value(ctxPublicKey{}).(ssh.PublicKey)
14+	if !ok {
15+		return nil, fmt.Errorf("public key not set on `ssh.Context()` for connection")
16+	}
17+	return pk, nil
18+}
19+func SetPublicKeyCtx(ctx ssh.Context, pk ssh.PublicKey) {
20+	ctx.SetValue(ctxPublicKey{}, pk)
21+}