Eric Bower
·
10 Dec 24
ssh.go
1package pgs
2
import (
	"context"
	"errors"
	"fmt"
	"os"
	"os/signal"
	"syscall"
	"time"

	"github.com/charmbracelet/promwish"
	"github.com/charmbracelet/ssh"
	"github.com/charmbracelet/wish"
	"github.com/picosh/pico/db/postgres"
	"github.com/picosh/pico/shared"
	"github.com/picosh/pico/shared/storage"
	wsh "github.com/picosh/pico/wish"
	"github.com/picosh/send/auth"
	"github.com/picosh/send/list"
	"github.com/picosh/send/pipe"
	wishrsync "github.com/picosh/send/protocols/rsync"
	"github.com/picosh/send/protocols/scp"
	"github.com/picosh/send/protocols/sftp"
	"github.com/picosh/send/proxy"
	"github.com/picosh/tunkit"
	"github.com/picosh/utils"
)
28
29func createRouter(handler *UploadAssetHandler) proxy.Router {
30 return func(sh ssh.Handler, s ssh.Session) []wish.Middleware {
31 return []wish.Middleware{
32 pipe.Middleware(handler, ""),
33 list.Middleware(handler),
34 scp.Middleware(handler),
35 wishrsync.Middleware(handler),
36 auth.Middleware(handler),
37 wsh.PtyMdw(wsh.DeprecatedNotice()),
38 WishMiddleware(handler),
39 wsh.LogMiddleware(handler.GetLogger()),
40 }
41 }
42}
43
44func withProxy(handler *UploadAssetHandler, otherMiddleware ...wish.Middleware) ssh.Option {
45 return func(server *ssh.Server) error {
46 err := sftp.SSHOption(handler)(server)
47 if err != nil {
48 return err
49 }
50
51 return proxy.WithProxy(createRouter(handler), otherMiddleware...)(server)
52 }
53}
54
55func StartSshServer() {
56 host := utils.GetEnv("PGS_HOST", "0.0.0.0")
57 port := utils.GetEnv("PGS_SSH_PORT", "2222")
58 promPort := utils.GetEnv("PGS_PROM_PORT", "9222")
59 cfg := NewConfigSite()
60 logger := cfg.Logger
61 dbpool := postgres.NewDB(cfg.DbURL, cfg.Logger)
62 defer dbpool.Close()
63
64 var st storage.StorageServe
65 var err error
66 if cfg.MinioURL == "" {
67 st, err = storage.NewStorageFS(cfg.StorageDir)
68 } else {
69 st, err = storage.NewStorageMinio(cfg.MinioURL, cfg.MinioUser, cfg.MinioPass)
70 }
71
72 if err != nil {
73 logger.Error(err.Error())
74 return
75 }
76
77 ctx := context.Background()
78 defer ctx.Done()
79 handler := NewUploadAssetHandler(
80 dbpool,
81 cfg,
82 st,
83 ctx,
84 )
85
86 apiConfig := &shared.ApiConfig{
87 Cfg: cfg,
88 Dbpool: dbpool,
89 Storage: st,
90 }
91
92 webTunnel := &tunkit.WebTunnelHandler{
93 Logger: logger,
94 HttpHandler: createHttpHandler(apiConfig),
95 }
96
97 sshAuth := shared.NewSshAuthHandler(dbpool, logger, cfg)
98 s, err := wish.NewServer(
99 wish.WithAddress(fmt.Sprintf("%s:%s", host, port)),
100 wish.WithHostKeyPath("ssh_data/term_info_ed25519"),
101 wish.WithPublicKeyAuth(sshAuth.PubkeyAuthHandler),
102 tunkit.WithWebTunnel(webTunnel),
103 withProxy(
104 handler,
105 promwish.Middleware(fmt.Sprintf("%s:%s", host, promPort), "pgs-ssh"),
106 ),
107 )
108 if err != nil {
109 logger.Error(err.Error())
110 return
111 }
112
113 done := make(chan os.Signal, 1)
114 signal.Notify(done, os.Interrupt, syscall.SIGINT, syscall.SIGTERM)
115 logger.Info("starting SSH server on", "host", host, "port", port)
116 go func() {
117 if err = s.ListenAndServe(); err != nil {
118 logger.Error("serve", "err", err.Error())
119 os.Exit(1)
120 }
121 }()
122
123 <-done
124 logger.Info("stopping SSH server")
125 ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
126 defer func() { cancel() }()
127 if err := s.Shutdown(ctx); err != nil {
128 logger.Error("shutdown", "err", err.Error())
129 os.Exit(1)
130 }
131}