Eric Bower
·
24 Sep 24
ssh.go
1package pubsub
2
3import (
4 "context"
5 "fmt"
6 "os"
7 "os/signal"
8 "syscall"
9 "time"
10
11 "github.com/antoniomika/syncmap"
12 "github.com/charmbracelet/promwish"
13 "github.com/charmbracelet/wish"
14 "github.com/picosh/pico/db/postgres"
15 "github.com/picosh/pico/filehandlers/util"
16 "github.com/picosh/pico/shared"
17 wsh "github.com/picosh/pico/wish"
18 psub "github.com/picosh/pubsub"
19)
20
21func StartSshServer() {
22 host := shared.GetEnv("PUBSUB_HOST", "0.0.0.0")
23 port := shared.GetEnv("PUBSUB_SSH_PORT", "2222")
24 promPort := shared.GetEnv("PUBSUB_PROM_PORT", "9222")
25 cfg := NewConfigSite()
26 logger := cfg.Logger
27 dbh := postgres.NewDB(cfg.DbURL, cfg.Logger)
28 defer dbh.Close()
29
30 cfg.Port = port
31
32 pubsub := &psub.Cfg{
33 Logger: logger,
34 PubSub: &psub.PubSubMulticast{
35 Logger: logger,
36 Channels: syncmap.New[string, *psub.Channel](),
37 Pipes: syncmap.New[string, *psub.Pipe](),
38 },
39 }
40
41 handler := &CliHandler{
42 Logger: logger,
43 DBPool: dbh,
44 PubSub: pubsub,
45 Cfg: cfg,
46 }
47
48 sshAuth := util.NewSshAuthHandler(dbh, logger, cfg)
49 s, err := wish.NewServer(
50 wish.WithAddress(fmt.Sprintf("%s:%s", host, port)),
51 wish.WithHostKeyPath("ssh_data/term_info_ed25519"),
52 wish.WithPublicKeyAuth(sshAuth.PubkeyAuthHandler),
53 wish.WithMiddleware(
54 WishMiddleware(handler),
55 promwish.Middleware(fmt.Sprintf("%s:%s", host, promPort), "pubsub-ssh"),
56 wsh.LogMiddleware(logger),
57 ),
58 )
59 if err != nil {
60 logger.Error("wish server", "err", err.Error())
61 return
62 }
63
64 done := make(chan os.Signal, 1)
65 signal.Notify(done, os.Interrupt, syscall.SIGINT, syscall.SIGTERM)
66 logger.Info("Starting SSH server", "host", host, "port", port)
67 go func() {
68 if err = s.ListenAndServe(); err != nil {
69 logger.Error("listen", "err", err.Error())
70 }
71 }()
72
73 <-done
74 logger.Info("Stopping SSH server")
75 ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
76 defer func() { cancel() }()
77 if err := s.Shutdown(ctx); err != nil {
78 logger.Error("shutdown", "err", err.Error())
79 }
80}