Eric Bower
·
20 Sep 24
ssh.go
1package imgs
2
3import (
4 "bytes"
5 "context"
6 "encoding/json"
7 "fmt"
8 "io"
9 "log"
10 "log/slog"
11 "net/http"
12 "net/http/httputil"
13 "net/url"
14 "os"
15 "os/signal"
16 "strconv"
17 "strings"
18 "syscall"
19 "time"
20
21 "github.com/charmbracelet/ssh"
22 "github.com/charmbracelet/wish"
23 "github.com/google/uuid"
24 "github.com/picosh/pico/db"
25 "github.com/picosh/pico/db/postgres"
26 "github.com/picosh/pico/shared"
27 "github.com/picosh/pico/shared/storage"
28 psub "github.com/picosh/pubsub"
29 "github.com/picosh/tunkit"
30)
31
32type ctxUserKey struct{}
33
34func getUserCtx(ctx ssh.Context) (*db.User, error) {
35 user, ok := ctx.Value(ctxUserKey{}).(*db.User)
36 if user == nil || !ok {
37 return user, fmt.Errorf("user not set on `ssh.Context()` for connection")
38 }
39 return user, nil
40}
41func setUserCtx(ctx ssh.Context, user *db.User) {
42 ctx.SetValue(ctxUserKey{}, user)
43}
44
45func AuthHandler(dbh db.DB, log *slog.Logger) func(ssh.Context, ssh.PublicKey) bool {
46 return func(ctx ssh.Context, key ssh.PublicKey) bool {
47 kk, err := shared.KeyForKeyText(key)
48 if err != nil {
49 log.Error("cannot get pubkey", "err", err)
50 return false
51 }
52
53 user, err := dbh.FindUserForKey("", kk)
54 if err != nil {
55 log.Error("user not found", "err", err)
56 return false
57 }
58
59 if user == nil {
60 log.Error("user not found", "err", err)
61 return false
62 }
63
64 setUserCtx(ctx, user)
65
66 if !dbh.HasFeatureForUser(user.ID, "plus") {
67 log.Error("not a pico+ user", "user", user.Name)
68 return false
69 }
70
71 return true
72 }
73}
74
// ErrorHandler is an http.Handler that answers every request with the same
// failure. It logs Err through the standard library logger and replies with
// Err's message and status 500. Err must be non-nil, because ServeHTTP calls
// its Error method.
type ErrorHandler struct {
	// Err is the error reported for every request.
	Err error
}

// ServeHTTP implements http.Handler. It logs e.Err and writes its message to
// w with http.StatusInternalServerError. The request itself is not inspected.
func (e *ErrorHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
	msg := e.Err.Error()
	log.Println(msg)
	http.Error(w, msg, http.StatusInternalServerError)
}
83
84func createServeMux(handler *CliHandler, pubsub *psub.Cfg) func(ctx ssh.Context) http.Handler {
85 return func(ctx ssh.Context) http.Handler {
86 router := http.NewServeMux()
87
88 slug := ""
89 user, err := getUserCtx(ctx)
90 if err == nil && user != nil {
91 slug = user.Name
92 }
93
94 proxy := httputil.NewSingleHostReverseProxy(&url.URL{
95 Scheme: "http",
96 Host: handler.RegistryUrl,
97 })
98
99 oldDirector := proxy.Director
100
101 proxy.Director = func(r *http.Request) {
102 handler.Logger.Info("director", "request", r)
103 oldDirector(r)
104
105 if strings.HasSuffix(r.URL.Path, "_catalog") || r.URL.Path == "/v2" || r.URL.Path == "/v2/" {
106 return
107 }
108
109 fullPath := strings.TrimPrefix(r.URL.Path, "/v2")
110
111 newPath, err := url.JoinPath("/v2", slug, fullPath)
112 if err != nil {
113 return
114 }
115
116 r.URL.Path = newPath
117
118 query := r.URL.Query()
119
120 if query.Has("from") {
121 joinedFrom, err := url.JoinPath(slug, query.Get("from"))
122 if err != nil {
123 return
124 }
125 query.Set("from", joinedFrom)
126
127 r.URL.RawQuery = query.Encode()
128 }
129 }
130
131 proxy.ModifyResponse = func(r *http.Response) error {
132 handler.Logger.Info("modify", "request", r)
133 shared.CorsHeaders(r.Header)
134
135 if r.Request.Method == http.MethodGet && strings.HasSuffix(r.Request.URL.Path, "_catalog") {
136 b, err := io.ReadAll(r.Body)
137 if err != nil {
138 return err
139 }
140
141 err = r.Body.Close()
142 if err != nil {
143 return err
144 }
145
146 var data map[string]any
147 err = json.Unmarshal(b, &data)
148 if err != nil {
149 return err
150 }
151
152 newRepos := []string{}
153
154 if repos, ok := data["repositories"].([]any); ok {
155 for _, repo := range repos {
156 if repoStr, ok := repo.(string); ok && strings.HasPrefix(repoStr, slug) {
157 newRepos = append(newRepos, strings.Replace(repoStr, fmt.Sprintf("%s/", slug), "", 1))
158 }
159 }
160 }
161
162 data["repositories"] = newRepos
163
164 newB, err := json.Marshal(data)
165 if err != nil {
166 return err
167 }
168
169 jsonBuf := bytes.NewBuffer(newB)
170
171 r.ContentLength = int64(jsonBuf.Len())
172 r.Header.Set("Content-Length", strconv.FormatInt(r.ContentLength, 10))
173 r.Body = io.NopCloser(jsonBuf)
174 }
175
176 if r.Request.Method == http.MethodGet && (strings.Contains(r.Request.URL.Path, "/tags/") || strings.Contains(r.Request.URL.Path, "/manifests/")) {
177 splitPath := strings.Split(r.Request.URL.Path, "/")
178
179 if len(splitPath) > 1 {
180 ele := splitPath[len(splitPath)-2]
181 if ele == "tags" || ele == "manifests" {
182 b, err := io.ReadAll(r.Body)
183 if err != nil {
184 return err
185 }
186
187 err = r.Body.Close()
188 if err != nil {
189 return err
190 }
191
192 var data map[string]any
193 err = json.Unmarshal(b, &data)
194 if err != nil {
195 return err
196 }
197
198 if name, ok := data["name"].(string); ok {
199 if strings.HasPrefix(name, slug) {
200 data["name"] = strings.Replace(name, fmt.Sprintf("%s/", slug), "", 1)
201 }
202 }
203
204 newB, err := json.Marshal(data)
205 if err != nil {
206 return err
207 }
208
209 jsonBuf := bytes.NewBuffer(newB)
210
211 r.ContentLength = int64(jsonBuf.Len())
212 r.Header.Set("Content-Length", strconv.FormatInt(r.ContentLength, 10))
213 r.Body = io.NopCloser(jsonBuf)
214 }
215 }
216 }
217
218 if r.Request.Method == http.MethodPut && strings.Contains(r.Request.URL.Path, "/manifests/") {
219 digest := r.Header.Get("Docker-Content-Digest")
220 // [ ]/v2/erock/alpine/manifests/latest
221 splitPath := strings.Split(r.Request.URL.Path, "/")
222 img := splitPath[3]
223 tag := splitPath[5]
224
225 furl := fmt.Sprintf(
226 "digest=%s&image=%s&tag=%s",
227 url.QueryEscape(digest),
228 img,
229 tag,
230 )
231 handler.Logger.Info("sending event", "url", furl)
232
233 err := pubsub.PubSub.Pub(fmt.Sprintf("%s@%s:%s", user.Name, img, tag), &psub.Pub{
234 ID: uuid.NewString(),
235 Done: make(chan struct{}),
236 Reader: strings.NewReader(furl),
237 })
238
239 if err != nil {
240 handler.Logger.Error("pub error", "err", err)
241 }
242 }
243
244 locationHeader := r.Header.Get("location")
245 if strings.Contains(locationHeader, fmt.Sprintf("/v2/%s", slug)) {
246 r.Header.Set("location", strings.ReplaceAll(locationHeader, fmt.Sprintf("/v2/%s", slug), "/v2"))
247 }
248
249 return nil
250 }
251
252 router.HandleFunc("/", proxy.ServeHTTP)
253
254 return router
255 }
256}
257
258func StartSshServer() {
259 host := os.Getenv("SSH_HOST")
260 if host == "" {
261 host = "0.0.0.0"
262 }
263 port := os.Getenv("SSH_PORT")
264 if port == "" {
265 port = "2222"
266 }
267 dbUrl := os.Getenv("DATABASE_URL")
268 registryUrl := shared.GetEnv("REGISTRY_URL", "0.0.0.0:5000")
269 minioUrl := shared.GetEnv("MINIO_URL", "http://0.0.0.0:9000")
270 minioUser := shared.GetEnv("MINIO_ROOT_USER", "")
271 minioPass := shared.GetEnv("MINIO_ROOT_PASSWORD", "")
272
273 logger := shared.CreateLogger("imgs")
274 logger.Info("bootup", "registry", registryUrl, "minio", minioUrl)
275 dbh := postgres.NewDB(dbUrl, logger)
276 st, err := storage.NewStorageMinio(minioUrl, minioUser, minioPass)
277 if err != nil {
278 panic(err)
279 }
280
281 pubsub := &psub.Cfg{
282 Logger: logger,
283 PubSub: &psub.PubSubMulticast{
284 Logger: logger,
285 },
286 }
287
288 handler := &CliHandler{
289 Logger: logger,
290 DBPool: dbh,
291 Storage: st,
292 RegistryUrl: registryUrl,
293 PubSub: pubsub,
294 }
295
296 s, err := wish.NewServer(
297 wish.WithAddress(fmt.Sprintf("%s:%s", host, port)),
298 wish.WithHostKeyPath("ssh_data/term_info_ed25519"),
299 wish.WithPublicKeyAuth(AuthHandler(dbh, logger)),
300 wish.WithMiddleware(WishMiddleware(handler)),
301 tunkit.WithWebTunnel(
302 tunkit.NewWebTunnelHandler(createServeMux(handler, pubsub), logger),
303 ),
304 )
305
306 if err != nil {
307 logger.Error("could not create server", "err", err)
308 }
309
310 done := make(chan os.Signal, 1)
311 signal.Notify(done, os.Interrupt, syscall.SIGINT, syscall.SIGTERM)
312 logger.Info("starting SSH server", "host", host, "port", port)
313 go func() {
314 if err = s.ListenAndServe(); err != nil {
315 logger.Error("serve error", "err", err)
316 os.Exit(1)
317 }
318 }()
319
320 <-done
321 logger.Info("stopping SSH server")
322 ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
323 defer func() { cancel() }()
324 if err := s.Shutdown(ctx); err != nil {
325 logger.Error("shutdown", "err", err)
326 os.Exit(1)
327 }
328}