Antonio Mika
·
11 Oct 24
ssh.go
1package imgs
2
import (
	"bytes"
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"log"
	"log/slog"
	"net/http"
	"net/http/httputil"
	"net/url"
	"os"
	"os/signal"
	"strconv"
	"strings"
	"syscall"
	"time"

	"github.com/charmbracelet/promwish"
	"github.com/charmbracelet/ssh"
	"github.com/charmbracelet/wish"
	"github.com/google/uuid"
	"github.com/picosh/pico/db"
	"github.com/picosh/pico/db/postgres"
	"github.com/picosh/pico/shared"
	"github.com/picosh/pico/shared/storage"
	wsh "github.com/picosh/pico/wish"
	psub "github.com/picosh/pubsub"
	"github.com/picosh/tunkit"
	"github.com/picosh/utils"
)
34
35type ctxUserKey struct{}
36
37func getUserCtx(ctx ssh.Context) (*db.User, error) {
38 user, ok := ctx.Value(ctxUserKey{}).(*db.User)
39 if user == nil || !ok {
40 return user, fmt.Errorf("user not set on `ssh.Context()` for connection")
41 }
42 return user, nil
43}
44func setUserCtx(ctx ssh.Context, user *db.User) {
45 ctx.SetValue(ctxUserKey{}, user)
46}
47
48func AuthHandler(dbh db.DB, log *slog.Logger) func(ssh.Context, ssh.PublicKey) bool {
49 return func(ctx ssh.Context, key ssh.PublicKey) bool {
50 kk := utils.KeyForKeyText(key)
51
52 user, err := dbh.FindUserForKey("", kk)
53 if err != nil {
54 log.Error("user not found", "err", err)
55 return false
56 }
57
58 if user == nil {
59 log.Error("user not found", "err", err)
60 return false
61 }
62
63 setUserCtx(ctx, user)
64
65 if !dbh.HasFeatureForUser(user.ID, "plus") {
66 log.Error("not a pico+ user", "user", user.Name)
67 return false
68 }
69
70 return true
71 }
72}
73
// ErrorHandler is an http.Handler that fails every request with the wrapped
// error.
type ErrorHandler struct {
	Err error
}

// ServeHTTP writes the error text to the standard logger and replies with the
// same text and status 500.
func (e *ErrorHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
	msg := e.Err.Error()
	log.Println(msg)
	http.Error(w, msg, http.StatusInternalServerError)
}
82
83func createServeMux(handler *CliHandler, pubsub psub.PubSub) func(ctx ssh.Context) http.Handler {
84 return func(ctx ssh.Context) http.Handler {
85 router := http.NewServeMux()
86
87 slug := ""
88 user, err := getUserCtx(ctx)
89 if err == nil && user != nil {
90 slug = user.Name
91 }
92
93 proxy := httputil.NewSingleHostReverseProxy(&url.URL{
94 Scheme: "http",
95 Host: handler.RegistryUrl,
96 })
97
98 oldDirector := proxy.Director
99
100 proxy.Director = func(r *http.Request) {
101 handler.Logger.Info("director", "request", r)
102 oldDirector(r)
103
104 if strings.HasSuffix(r.URL.Path, "_catalog") || r.URL.Path == "/v2" || r.URL.Path == "/v2/" {
105 return
106 }
107
108 fullPath := strings.TrimPrefix(r.URL.Path, "/v2")
109
110 newPath, err := url.JoinPath("/v2", slug, fullPath)
111 if err != nil {
112 return
113 }
114
115 r.URL.Path = newPath
116
117 query := r.URL.Query()
118
119 if query.Has("from") {
120 joinedFrom, err := url.JoinPath(slug, query.Get("from"))
121 if err != nil {
122 return
123 }
124 query.Set("from", joinedFrom)
125
126 r.URL.RawQuery = query.Encode()
127 }
128 }
129
130 proxy.ModifyResponse = func(r *http.Response) error {
131 handler.Logger.Info("modify", "request", r)
132 shared.CorsHeaders(r.Header)
133
134 if r.Request.Method == http.MethodGet && strings.HasSuffix(r.Request.URL.Path, "_catalog") {
135 b, err := io.ReadAll(r.Body)
136 if err != nil {
137 return err
138 }
139
140 err = r.Body.Close()
141 if err != nil {
142 return err
143 }
144
145 var data map[string]any
146 err = json.Unmarshal(b, &data)
147 if err != nil {
148 return err
149 }
150
151 newRepos := []string{}
152
153 if repos, ok := data["repositories"].([]any); ok {
154 for _, repo := range repos {
155 if repoStr, ok := repo.(string); ok && strings.HasPrefix(repoStr, slug) {
156 newRepos = append(newRepos, strings.Replace(repoStr, fmt.Sprintf("%s/", slug), "", 1))
157 }
158 }
159 }
160
161 data["repositories"] = newRepos
162
163 newB, err := json.Marshal(data)
164 if err != nil {
165 return err
166 }
167
168 jsonBuf := bytes.NewBuffer(newB)
169
170 r.ContentLength = int64(jsonBuf.Len())
171 r.Header.Set("Content-Length", strconv.FormatInt(r.ContentLength, 10))
172 r.Body = io.NopCloser(jsonBuf)
173 }
174
175 if r.Request.Method == http.MethodGet && (strings.Contains(r.Request.URL.Path, "/tags/") || strings.Contains(r.Request.URL.Path, "/manifests/")) {
176 splitPath := strings.Split(r.Request.URL.Path, "/")
177
178 if len(splitPath) > 1 {
179 ele := splitPath[len(splitPath)-2]
180 if ele == "tags" || ele == "manifests" {
181 b, err := io.ReadAll(r.Body)
182 if err != nil {
183 return err
184 }
185
186 err = r.Body.Close()
187 if err != nil {
188 return err
189 }
190
191 var data map[string]any
192 err = json.Unmarshal(b, &data)
193 if err != nil {
194 return err
195 }
196
197 if name, ok := data["name"].(string); ok {
198 if strings.HasPrefix(name, slug) {
199 data["name"] = strings.Replace(name, fmt.Sprintf("%s/", slug), "", 1)
200 }
201 }
202
203 newB, err := json.Marshal(data)
204 if err != nil {
205 return err
206 }
207
208 jsonBuf := bytes.NewBuffer(newB)
209
210 r.ContentLength = int64(jsonBuf.Len())
211 r.Header.Set("Content-Length", strconv.FormatInt(r.ContentLength, 10))
212 r.Body = io.NopCloser(jsonBuf)
213 }
214 }
215 }
216
217 if r.Request.Method == http.MethodPut && strings.Contains(r.Request.URL.Path, "/manifests/") {
218 digest := r.Header.Get("Docker-Content-Digest")
219 // [ ]/v2/erock/alpine/manifests/latest
220 splitPath := strings.Split(r.Request.URL.Path, "/")
221 img := splitPath[3]
222 tag := splitPath[5]
223
224 furl := fmt.Sprintf(
225 "digest=%s&image=%s&tag=%s",
226 url.QueryEscape(digest),
227 img,
228 tag,
229 )
230 handler.Logger.Info("sending event", "url", furl)
231
232 err := pubsub.Pub(ctx, uuid.NewString(), bytes.NewBufferString(furl), []*psub.Channel{
233 psub.NewChannel(fmt.Sprintf("%s/%s:%s", user.Name, img, tag)),
234 }, false)
235
236 if err != nil {
237 handler.Logger.Error("pub error", "err", err)
238 }
239 }
240
241 locationHeader := r.Header.Get("location")
242 if strings.Contains(locationHeader, fmt.Sprintf("/v2/%s", slug)) {
243 r.Header.Set("location", strings.ReplaceAll(locationHeader, fmt.Sprintf("/v2/%s", slug), "/v2"))
244 }
245
246 return nil
247 }
248
249 router.HandleFunc("/", proxy.ServeHTTP)
250
251 return router
252 }
253}
254
255func StartSshServer() {
256 host := utils.GetEnv("IMGS_HOST", "0.0.0.0")
257 port := utils.GetEnv("IMGS_SSH_PORT", "2222")
258 promPort := utils.GetEnv("IMGS_PROM_PORT", "9222")
259 dbUrl := os.Getenv("DATABASE_URL")
260 registryUrl := utils.GetEnv("REGISTRY_URL", "0.0.0.0:5000")
261 minioUrl := utils.GetEnv("MINIO_URL", "http://0.0.0.0:9000")
262 minioUser := utils.GetEnv("MINIO_ROOT_USER", "")
263 minioPass := utils.GetEnv("MINIO_ROOT_PASSWORD", "")
264
265 logger := shared.CreateLogger("imgs")
266 logger.Info("bootup", "registry", registryUrl, "minio", minioUrl)
267 dbh := postgres.NewDB(dbUrl, logger)
268 st, err := storage.NewStorageMinio(minioUrl, minioUser, minioPass)
269 if err != nil {
270 panic(err)
271 }
272
273 pubsub := psub.NewMulticast(logger)
274 handler := &CliHandler{
275 Logger: logger,
276 DBPool: dbh,
277 Storage: st,
278 RegistryUrl: registryUrl,
279 PubSub: pubsub,
280 }
281
282 s, err := wish.NewServer(
283 wish.WithAddress(fmt.Sprintf("%s:%s", host, port)),
284 wish.WithHostKeyPath("ssh_data/term_info_ed25519"),
285 wish.WithPublicKeyAuth(AuthHandler(dbh, logger)),
286 wish.WithMiddleware(
287 WishMiddleware(handler),
288 promwish.Middleware(fmt.Sprintf("%s:%s", host, promPort), "imgs-ssh"),
289 wsh.LogMiddleware(logger),
290 ),
291 tunkit.WithWebTunnel(
292 tunkit.NewWebTunnelHandler(createServeMux(handler, pubsub), logger),
293 ),
294 )
295
296 if err != nil {
297 logger.Error("could not create server", "err", err)
298 }
299
300 done := make(chan os.Signal, 1)
301 signal.Notify(done, os.Interrupt, syscall.SIGINT, syscall.SIGTERM)
302 logger.Info("starting SSH server", "host", host, "port", port)
303 go func() {
304 if err = s.ListenAndServe(); err != nil {
305 logger.Error("serve error", "err", err)
306 os.Exit(1)
307 }
308 }()
309
310 <-done
311 logger.Info("stopping SSH server")
312 ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
313 defer func() { cancel() }()
314 if err := s.Shutdown(ctx); err != nil {
315 logger.Error("shutdown", "err", err)
316 os.Exit(1)
317 }
318}