repos / pico

pico services - prose.sh, pastes.sh, imgs.sh, feeds.sh, pgs.sh
git clone https://github.com/picosh/pico.git

pico / imgs
Antonio Mika · 11 Oct 24

ssh.go

  1package imgs
  2
import (
	"bytes"
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"log"
	"log/slog"
	"net/http"
	"net/http/httputil"
	"net/url"
	"os"
	"os/signal"
	"strconv"
	"strings"
	"syscall"
	"time"

	"github.com/charmbracelet/promwish"
	"github.com/charmbracelet/ssh"
	"github.com/charmbracelet/wish"
	"github.com/google/uuid"
	"github.com/picosh/pico/db"
	"github.com/picosh/pico/db/postgres"
	"github.com/picosh/pico/shared"
	"github.com/picosh/pico/shared/storage"
	wsh "github.com/picosh/pico/wish"
	psub "github.com/picosh/pubsub"
	"github.com/picosh/tunkit"
	"github.com/picosh/utils"
)
 34
 35type ctxUserKey struct{}
 36
 37func getUserCtx(ctx ssh.Context) (*db.User, error) {
 38	user, ok := ctx.Value(ctxUserKey{}).(*db.User)
 39	if user == nil || !ok {
 40		return user, fmt.Errorf("user not set on `ssh.Context()` for connection")
 41	}
 42	return user, nil
 43}
 44func setUserCtx(ctx ssh.Context, user *db.User) {
 45	ctx.SetValue(ctxUserKey{}, user)
 46}
 47
 48func AuthHandler(dbh db.DB, log *slog.Logger) func(ssh.Context, ssh.PublicKey) bool {
 49	return func(ctx ssh.Context, key ssh.PublicKey) bool {
 50		kk := utils.KeyForKeyText(key)
 51
 52		user, err := dbh.FindUserForKey("", kk)
 53		if err != nil {
 54			log.Error("user not found", "err", err)
 55			return false
 56		}
 57
 58		if user == nil {
 59			log.Error("user not found", "err", err)
 60			return false
 61		}
 62
 63		setUserCtx(ctx, user)
 64
 65		if !dbh.HasFeatureForUser(user.ID, "plus") {
 66			log.Error("not a pico+ user", "user", user.Name)
 67			return false
 68		}
 69
 70		return true
 71	}
 72}
 73
 74type ErrorHandler struct {
 75	Err error
 76}
 77
 78func (e *ErrorHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
 79	log.Println(e.Err.Error())
 80	http.Error(w, e.Err.Error(), http.StatusInternalServerError)
 81}
 82
 83func createServeMux(handler *CliHandler, pubsub psub.PubSub) func(ctx ssh.Context) http.Handler {
 84	return func(ctx ssh.Context) http.Handler {
 85		router := http.NewServeMux()
 86
 87		slug := ""
 88		user, err := getUserCtx(ctx)
 89		if err == nil && user != nil {
 90			slug = user.Name
 91		}
 92
 93		proxy := httputil.NewSingleHostReverseProxy(&url.URL{
 94			Scheme: "http",
 95			Host:   handler.RegistryUrl,
 96		})
 97
 98		oldDirector := proxy.Director
 99
100		proxy.Director = func(r *http.Request) {
101			handler.Logger.Info("director", "request", r)
102			oldDirector(r)
103
104			if strings.HasSuffix(r.URL.Path, "_catalog") || r.URL.Path == "/v2" || r.URL.Path == "/v2/" {
105				return
106			}
107
108			fullPath := strings.TrimPrefix(r.URL.Path, "/v2")
109
110			newPath, err := url.JoinPath("/v2", slug, fullPath)
111			if err != nil {
112				return
113			}
114
115			r.URL.Path = newPath
116
117			query := r.URL.Query()
118
119			if query.Has("from") {
120				joinedFrom, err := url.JoinPath(slug, query.Get("from"))
121				if err != nil {
122					return
123				}
124				query.Set("from", joinedFrom)
125
126				r.URL.RawQuery = query.Encode()
127			}
128		}
129
130		proxy.ModifyResponse = func(r *http.Response) error {
131			handler.Logger.Info("modify", "request", r)
132			shared.CorsHeaders(r.Header)
133
134			if r.Request.Method == http.MethodGet && strings.HasSuffix(r.Request.URL.Path, "_catalog") {
135				b, err := io.ReadAll(r.Body)
136				if err != nil {
137					return err
138				}
139
140				err = r.Body.Close()
141				if err != nil {
142					return err
143				}
144
145				var data map[string]any
146				err = json.Unmarshal(b, &data)
147				if err != nil {
148					return err
149				}
150
151				newRepos := []string{}
152
153				if repos, ok := data["repositories"].([]any); ok {
154					for _, repo := range repos {
155						if repoStr, ok := repo.(string); ok && strings.HasPrefix(repoStr, slug) {
156							newRepos = append(newRepos, strings.Replace(repoStr, fmt.Sprintf("%s/", slug), "", 1))
157						}
158					}
159				}
160
161				data["repositories"] = newRepos
162
163				newB, err := json.Marshal(data)
164				if err != nil {
165					return err
166				}
167
168				jsonBuf := bytes.NewBuffer(newB)
169
170				r.ContentLength = int64(jsonBuf.Len())
171				r.Header.Set("Content-Length", strconv.FormatInt(r.ContentLength, 10))
172				r.Body = io.NopCloser(jsonBuf)
173			}
174
175			if r.Request.Method == http.MethodGet && (strings.Contains(r.Request.URL.Path, "/tags/") || strings.Contains(r.Request.URL.Path, "/manifests/")) {
176				splitPath := strings.Split(r.Request.URL.Path, "/")
177
178				if len(splitPath) > 1 {
179					ele := splitPath[len(splitPath)-2]
180					if ele == "tags" || ele == "manifests" {
181						b, err := io.ReadAll(r.Body)
182						if err != nil {
183							return err
184						}
185
186						err = r.Body.Close()
187						if err != nil {
188							return err
189						}
190
191						var data map[string]any
192						err = json.Unmarshal(b, &data)
193						if err != nil {
194							return err
195						}
196
197						if name, ok := data["name"].(string); ok {
198							if strings.HasPrefix(name, slug) {
199								data["name"] = strings.Replace(name, fmt.Sprintf("%s/", slug), "", 1)
200							}
201						}
202
203						newB, err := json.Marshal(data)
204						if err != nil {
205							return err
206						}
207
208						jsonBuf := bytes.NewBuffer(newB)
209
210						r.ContentLength = int64(jsonBuf.Len())
211						r.Header.Set("Content-Length", strconv.FormatInt(r.ContentLength, 10))
212						r.Body = io.NopCloser(jsonBuf)
213					}
214				}
215			}
216
217			if r.Request.Method == http.MethodPut && strings.Contains(r.Request.URL.Path, "/manifests/") {
218				digest := r.Header.Get("Docker-Content-Digest")
219				// [ ]/v2/erock/alpine/manifests/latest
220				splitPath := strings.Split(r.Request.URL.Path, "/")
221				img := splitPath[3]
222				tag := splitPath[5]
223
224				furl := fmt.Sprintf(
225					"digest=%s&image=%s&tag=%s",
226					url.QueryEscape(digest),
227					img,
228					tag,
229				)
230				handler.Logger.Info("sending event", "url", furl)
231
232				err := pubsub.Pub(ctx, uuid.NewString(), bytes.NewBufferString(furl), []*psub.Channel{
233					psub.NewChannel(fmt.Sprintf("%s/%s:%s", user.Name, img, tag)),
234				}, false)
235
236				if err != nil {
237					handler.Logger.Error("pub error", "err", err)
238				}
239			}
240
241			locationHeader := r.Header.Get("location")
242			if strings.Contains(locationHeader, fmt.Sprintf("/v2/%s", slug)) {
243				r.Header.Set("location", strings.ReplaceAll(locationHeader, fmt.Sprintf("/v2/%s", slug), "/v2"))
244			}
245
246			return nil
247		}
248
249		router.HandleFunc("/", proxy.ServeHTTP)
250
251		return router
252	}
253}
254
// StartSshServer boots the imgs SSH server and blocks until SIGINT or
// SIGTERM arrives, then shuts it down with a 30 second timeout.
//
// Configuration is read from the environment: IMGS_HOST (default 0.0.0.0),
// IMGS_SSH_PORT (2222), IMGS_PROM_PORT (9222), DATABASE_URL, REGISTRY_URL
// (0.0.0.0:5000), MINIO_URL (http://0.0.0.0:9000), MINIO_ROOT_USER and
// MINIO_ROOT_PASSWORD. It panics if the MinIO storage cannot be created and
// exits with status 1 if the server cannot be created, fails while serving,
// or fails to shut down.
func StartSshServer() {
	host := utils.GetEnv("IMGS_HOST", "0.0.0.0")
	port := utils.GetEnv("IMGS_SSH_PORT", "2222")
	promPort := utils.GetEnv("IMGS_PROM_PORT", "9222")
	dbUrl := os.Getenv("DATABASE_URL")
	registryUrl := utils.GetEnv("REGISTRY_URL", "0.0.0.0:5000")
	minioUrl := utils.GetEnv("MINIO_URL", "http://0.0.0.0:9000")
	minioUser := utils.GetEnv("MINIO_ROOT_USER", "")
	minioPass := utils.GetEnv("MINIO_ROOT_PASSWORD", "")

	logger := shared.CreateLogger("imgs")
	logger.Info("bootup", "registry", registryUrl, "minio", minioUrl)
	dbh := postgres.NewDB(dbUrl, logger)
	st, err := storage.NewStorageMinio(minioUrl, minioUser, minioPass)
	if err != nil {
		panic(err)
	}

	pubsub := psub.NewMulticast(logger)
	handler := &CliHandler{
		Logger:      logger,
		DBPool:      dbh,
		Storage:     st,
		RegistryUrl: registryUrl,
		PubSub:      pubsub,
	}

	s, err := wish.NewServer(
		wish.WithAddress(fmt.Sprintf("%s:%s", host, port)),
		wish.WithHostKeyPath("ssh_data/term_info_ed25519"),
		wish.WithPublicKeyAuth(AuthHandler(dbh, logger)),
		wish.WithMiddleware(
			WishMiddleware(handler),
			promwish.Middleware(fmt.Sprintf("%s:%s", host, promPort), "imgs-ssh"),
			wsh.LogMiddleware(logger),
		),
		tunkit.WithWebTunnel(
			tunkit.NewWebTunnelHandler(createServeMux(handler, pubsub), logger),
		),
	)
	if err != nil {
		// There is no usable server to run; s may be nil here, so stop
		// instead of calling methods on it below.
		logger.Error("could not create server", "err", err)
		os.Exit(1)
	}

	done := make(chan os.Signal, 1)
	signal.Notify(done, os.Interrupt, syscall.SIGINT, syscall.SIGTERM)
	logger.Info("starting SSH server", "host", host, "port", port)
	go func() {
		// ErrServerClosed is the expected result of the Shutdown call below,
		// not a failure; exiting on it would race the graceful shutdown.
		if err := s.ListenAndServe(); err != nil && !errors.Is(err, ssh.ErrServerClosed) {
			logger.Error("serve error", "err", err)
			os.Exit(1)
		}
	}()

	<-done
	logger.Info("stopping SSH server")
	ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
	defer cancel()
	if err := s.Shutdown(ctx); err != nil {
		logger.Error("shutdown", "err", err)
		os.Exit(1)
	}
}
309
310	<-done
311	logger.Info("stopping SSH server")
312	ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
313	defer func() { cancel() }()
314	if err := s.Shutdown(ctx); err != nil {
315		logger.Error("shutdown", "err", err)
316		os.Exit(1)
317	}
318}