repos / pico

pico services - prose.sh, pastes.sh, imgs.sh, feeds.sh, pgs.sh
git clone https://github.com/picosh/pico.git

pico / imgs
Eric Bower · 20 Sep 24

ssh.go

  1package imgs
  2
  3import (
  4	"bytes"
  5	"context"
  6	"encoding/json"
  7	"fmt"
  8	"io"
  9	"log"
 10	"log/slog"
 11	"net/http"
 12	"net/http/httputil"
 13	"net/url"
 14	"os"
 15	"os/signal"
 16	"strconv"
 17	"strings"
 18	"syscall"
 19	"time"
 20
 21	"github.com/charmbracelet/ssh"
 22	"github.com/charmbracelet/wish"
 23	"github.com/google/uuid"
 24	"github.com/picosh/pico/db"
 25	"github.com/picosh/pico/db/postgres"
 26	"github.com/picosh/pico/shared"
 27	"github.com/picosh/pico/shared/storage"
 28	psub "github.com/picosh/pubsub"
 29	"github.com/picosh/tunkit"
 30)
 31
 32type ctxUserKey struct{}
 33
 34func getUserCtx(ctx ssh.Context) (*db.User, error) {
 35	user, ok := ctx.Value(ctxUserKey{}).(*db.User)
 36	if user == nil || !ok {
 37		return user, fmt.Errorf("user not set on `ssh.Context()` for connection")
 38	}
 39	return user, nil
 40}
 41func setUserCtx(ctx ssh.Context, user *db.User) {
 42	ctx.SetValue(ctxUserKey{}, user)
 43}
 44
 45func AuthHandler(dbh db.DB, log *slog.Logger) func(ssh.Context, ssh.PublicKey) bool {
 46	return func(ctx ssh.Context, key ssh.PublicKey) bool {
 47		kk, err := shared.KeyForKeyText(key)
 48		if err != nil {
 49			log.Error("cannot get pubkey", "err", err)
 50			return false
 51		}
 52
 53		user, err := dbh.FindUserForKey("", kk)
 54		if err != nil {
 55			log.Error("user not found", "err", err)
 56			return false
 57		}
 58
 59		if user == nil {
 60			log.Error("user not found", "err", err)
 61			return false
 62		}
 63
 64		setUserCtx(ctx, user)
 65
 66		if !dbh.HasFeatureForUser(user.ID, "plus") {
 67			log.Error("not a pico+ user", "user", user.Name)
 68			return false
 69		}
 70
 71		return true
 72	}
 73}
 74
 75type ErrorHandler struct {
 76	Err error
 77}
 78
 79func (e *ErrorHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
 80	log.Println(e.Err.Error())
 81	http.Error(w, e.Err.Error(), http.StatusInternalServerError)
 82}
 83
 84func createServeMux(handler *CliHandler, pubsub *psub.Cfg) func(ctx ssh.Context) http.Handler {
 85	return func(ctx ssh.Context) http.Handler {
 86		router := http.NewServeMux()
 87
 88		slug := ""
 89		user, err := getUserCtx(ctx)
 90		if err == nil && user != nil {
 91			slug = user.Name
 92		}
 93
 94		proxy := httputil.NewSingleHostReverseProxy(&url.URL{
 95			Scheme: "http",
 96			Host:   handler.RegistryUrl,
 97		})
 98
 99		oldDirector := proxy.Director
100
101		proxy.Director = func(r *http.Request) {
102			handler.Logger.Info("director", "request", r)
103			oldDirector(r)
104
105			if strings.HasSuffix(r.URL.Path, "_catalog") || r.URL.Path == "/v2" || r.URL.Path == "/v2/" {
106				return
107			}
108
109			fullPath := strings.TrimPrefix(r.URL.Path, "/v2")
110
111			newPath, err := url.JoinPath("/v2", slug, fullPath)
112			if err != nil {
113				return
114			}
115
116			r.URL.Path = newPath
117
118			query := r.URL.Query()
119
120			if query.Has("from") {
121				joinedFrom, err := url.JoinPath(slug, query.Get("from"))
122				if err != nil {
123					return
124				}
125				query.Set("from", joinedFrom)
126
127				r.URL.RawQuery = query.Encode()
128			}
129		}
130
131		proxy.ModifyResponse = func(r *http.Response) error {
132			handler.Logger.Info("modify", "request", r)
133			shared.CorsHeaders(r.Header)
134
135			if r.Request.Method == http.MethodGet && strings.HasSuffix(r.Request.URL.Path, "_catalog") {
136				b, err := io.ReadAll(r.Body)
137				if err != nil {
138					return err
139				}
140
141				err = r.Body.Close()
142				if err != nil {
143					return err
144				}
145
146				var data map[string]any
147				err = json.Unmarshal(b, &data)
148				if err != nil {
149					return err
150				}
151
152				newRepos := []string{}
153
154				if repos, ok := data["repositories"].([]any); ok {
155					for _, repo := range repos {
156						if repoStr, ok := repo.(string); ok && strings.HasPrefix(repoStr, slug) {
157							newRepos = append(newRepos, strings.Replace(repoStr, fmt.Sprintf("%s/", slug), "", 1))
158						}
159					}
160				}
161
162				data["repositories"] = newRepos
163
164				newB, err := json.Marshal(data)
165				if err != nil {
166					return err
167				}
168
169				jsonBuf := bytes.NewBuffer(newB)
170
171				r.ContentLength = int64(jsonBuf.Len())
172				r.Header.Set("Content-Length", strconv.FormatInt(r.ContentLength, 10))
173				r.Body = io.NopCloser(jsonBuf)
174			}
175
176			if r.Request.Method == http.MethodGet && (strings.Contains(r.Request.URL.Path, "/tags/") || strings.Contains(r.Request.URL.Path, "/manifests/")) {
177				splitPath := strings.Split(r.Request.URL.Path, "/")
178
179				if len(splitPath) > 1 {
180					ele := splitPath[len(splitPath)-2]
181					if ele == "tags" || ele == "manifests" {
182						b, err := io.ReadAll(r.Body)
183						if err != nil {
184							return err
185						}
186
187						err = r.Body.Close()
188						if err != nil {
189							return err
190						}
191
192						var data map[string]any
193						err = json.Unmarshal(b, &data)
194						if err != nil {
195							return err
196						}
197
198						if name, ok := data["name"].(string); ok {
199							if strings.HasPrefix(name, slug) {
200								data["name"] = strings.Replace(name, fmt.Sprintf("%s/", slug), "", 1)
201							}
202						}
203
204						newB, err := json.Marshal(data)
205						if err != nil {
206							return err
207						}
208
209						jsonBuf := bytes.NewBuffer(newB)
210
211						r.ContentLength = int64(jsonBuf.Len())
212						r.Header.Set("Content-Length", strconv.FormatInt(r.ContentLength, 10))
213						r.Body = io.NopCloser(jsonBuf)
214					}
215				}
216			}
217
218			if r.Request.Method == http.MethodPut && strings.Contains(r.Request.URL.Path, "/manifests/") {
219				digest := r.Header.Get("Docker-Content-Digest")
220				// [ ]/v2/erock/alpine/manifests/latest
221				splitPath := strings.Split(r.Request.URL.Path, "/")
222				img := splitPath[3]
223				tag := splitPath[5]
224
225				furl := fmt.Sprintf(
226					"digest=%s&image=%s&tag=%s",
227					url.QueryEscape(digest),
228					img,
229					tag,
230				)
231				handler.Logger.Info("sending event", "url", furl)
232
233				err := pubsub.PubSub.Pub(fmt.Sprintf("%s@%s:%s", user.Name, img, tag), &psub.Pub{
234					ID:     uuid.NewString(),
235					Done:   make(chan struct{}),
236					Reader: strings.NewReader(furl),
237				})
238
239				if err != nil {
240					handler.Logger.Error("pub error", "err", err)
241				}
242			}
243
244			locationHeader := r.Header.Get("location")
245			if strings.Contains(locationHeader, fmt.Sprintf("/v2/%s", slug)) {
246				r.Header.Set("location", strings.ReplaceAll(locationHeader, fmt.Sprintf("/v2/%s", slug), "/v2"))
247			}
248
249			return nil
250		}
251
252		router.HandleFunc("/", proxy.ServeHTTP)
253
254		return router
255	}
256}
257
// StartSshServer boots the imgs SSH service and blocks until SIGINT or
// SIGTERM. It then shuts the server down, allowing up to 30 seconds.
//
// It reads configuration from the environment:
//
//   - SSH_HOST (default 0.0.0.0) and SSH_PORT (default 2222)
//   - DATABASE_URL
//   - REGISTRY_URL (default 0.0.0.0:5000)
//   - MINIO_URL (default http://0.0.0.0:9000)
//   - MINIO_ROOT_USER and MINIO_ROOT_PASSWORD
//
// It then opens the database and MinIO storage and serves SSH, with a web
// tunnel that proxies to the registry.
func StartSshServer() {
	host := os.Getenv("SSH_HOST")
	if host == "" {
		host = "0.0.0.0"
	}
	port := os.Getenv("SSH_PORT")
	if port == "" {
		port = "2222"
	}
	dbUrl := os.Getenv("DATABASE_URL")
	registryUrl := shared.GetEnv("REGISTRY_URL", "0.0.0.0:5000")
	minioUrl := shared.GetEnv("MINIO_URL", "http://0.0.0.0:9000")
	minioUser := shared.GetEnv("MINIO_ROOT_USER", "")
	minioPass := shared.GetEnv("MINIO_ROOT_PASSWORD", "")

	logger := shared.CreateLogger("imgs")
	logger.Info("bootup", "registry", registryUrl, "minio", minioUrl)
	dbh := postgres.NewDB(dbUrl, logger)
	st, err := storage.NewStorageMinio(minioUrl, minioUser, minioPass)
	if err != nil {
		panic(err)
	}

	pubsub := &psub.Cfg{
		Logger: logger,
		PubSub: &psub.PubSubMulticast{
			Logger: logger,
		},
	}

	handler := &CliHandler{
		Logger:      logger,
		DBPool:      dbh,
		Storage:     st,
		RegistryUrl: registryUrl,
		PubSub:      pubsub,
	}

	s, err := wish.NewServer(
		wish.WithAddress(fmt.Sprintf("%s:%s", host, port)),
		wish.WithHostKeyPath("ssh_data/term_info_ed25519"),
		wish.WithPublicKeyAuth(AuthHandler(dbh, logger)),
		wish.WithMiddleware(WishMiddleware(handler)),
		tunkit.WithWebTunnel(
			tunkit.NewWebTunnelHandler(createServeMux(handler, pubsub), logger),
		),
	)
	if err != nil {
		// s is unusable here; continuing would dereference a nil server.
		logger.Error("could not create server", "err", err)
		os.Exit(1)
	}

	done := make(chan os.Signal, 1)
	signal.Notify(done, os.Interrupt, syscall.SIGINT, syscall.SIGTERM)
	logger.Info("starting SSH server", "host", host, "port", port)
	go func() {
		// Shutdown makes ListenAndServe return ErrServerClosed. That is the
		// normal exit path and must not abort the graceful shutdown below.
		if err := s.ListenAndServe(); err != nil && err != ssh.ErrServerClosed {
			logger.Error("serve error", "err", err)
			os.Exit(1)
		}
	}()

	<-done
	logger.Info("stopping SSH server")
	ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
	defer cancel()
	if err := s.Shutdown(ctx); err != nil {
		logger.Error("shutdown", "err", err)
		os.Exit(1)
	}
}