Antonio Mika
·
17 Nov 24
cli.go
1package pico
2
3import (
4 "bufio"
5 "context"
6 "encoding/json"
7 "fmt"
8 "log/slog"
9 "strings"
10
11 "github.com/charmbracelet/ssh"
12 "github.com/charmbracelet/wish"
13 "github.com/picosh/pico/db"
14 "github.com/picosh/pico/shared"
15 "github.com/picosh/pico/tui/common"
16 "github.com/picosh/pico/tui/notifications"
17 "github.com/picosh/pico/tui/plus"
18 "github.com/picosh/utils"
19
20 pipeLogger "github.com/picosh/utils/pipe/log"
21)
22
23func getUser(s ssh.Session, dbpool db.DB) (*db.User, error) {
24 if s.PublicKey() == nil {
25 return nil, fmt.Errorf("key not found")
26 }
27
28 key := utils.KeyForKeyText(s.PublicKey())
29
30 user, err := dbpool.FindUserForKey(s.User(), key)
31 if err != nil {
32 return nil, err
33 }
34
35 if user.Name == "" {
36 return nil, fmt.Errorf("must have username set")
37 }
38
39 return user, nil
40}
41
42type Cmd struct {
43 User *db.User
44 SshSession ssh.Session
45 Session utils.CmdSession
46 Log *slog.Logger
47 Dbpool db.DB
48 Write bool
49 Styles common.Styles
50}
51
52func (c *Cmd) output(out string) {
53 _, _ = c.Session.Write([]byte(out + "\r\n"))
54}
55
56func (c *Cmd) help() {
57 helpStr := "Commands: [help, pico+]\n"
58 c.output(helpStr)
59}
60
61func (c *Cmd) plus() {
62 view := plus.PlusView(c.User.Name, 80)
63 c.output(view)
64}
65
66func (c *Cmd) notifications() error {
67 md := notifications.NotificationsView(c.Dbpool, c.User.ID, 80)
68 c.output(md)
69 return nil
70}
71
72func (c *Cmd) logs(ctx context.Context) error {
73 conn := shared.NewPicoPipeClient()
74 stdoutPipe, err := pipeLogger.ReadLogs(ctx, c.Log, conn)
75
76 if err != nil {
77 return err
78 }
79
80 scanner := bufio.NewScanner(stdoutPipe)
81 for scanner.Scan() {
82 line := scanner.Text()
83 parsedData := map[string]any{}
84
85 err := json.Unmarshal([]byte(line), &parsedData)
86 if err != nil {
87 c.Log.Error("json unmarshal", "err", err)
88 continue
89 }
90
91 user := utils.AnyToStr(parsedData, "user")
92 userId := utils.AnyToStr(parsedData, "userId")
93 if user == c.User.Name || userId == c.User.ID {
94 wish.Println(c.SshSession, line)
95 }
96 }
97 return scanner.Err()
98}
99
100type CliHandler struct {
101 DBPool db.DB
102 Logger *slog.Logger
103}
104
105func WishMiddleware(handler *CliHandler) wish.Middleware {
106 dbpool := handler.DBPool
107 log := handler.Logger
108
109 return func(next ssh.Handler) ssh.Handler {
110 return func(sesh ssh.Session) {
111 args := sesh.Command()
112 if len(args) == 0 {
113 next(sesh)
114 return
115 }
116
117 user, err := getUser(sesh, dbpool)
118 if err != nil {
119 wish.Errorf(sesh, "detected ssh command: %s\n", args)
120 s := fmt.Errorf("error: you need to create an account before using the remote cli: %w", err)
121 wish.Fatalln(sesh, s)
122 return
123 }
124
125 if len(args) > 0 && args[0] == "chat" {
126 _, _, hasPty := sesh.Pty()
127 if !hasPty {
128 wish.Fatalln(
129 sesh,
130 "In order to render chat you need to enable PTY with the `ssh -t` flag",
131 )
132 return
133 }
134
135 pass, err := dbpool.UpsertToken(user.ID, "pico-chat")
136 if err != nil {
137 wish.Fatalln(sesh, err)
138 return
139 }
140 app, err := shared.NewSenpaiApp(sesh, user.Name, pass)
141 if err != nil {
142 wish.Fatalln(sesh, err)
143 return
144 }
145 app.Run()
146 app.Close()
147 return
148 }
149
150 opts := Cmd{
151 Session: sesh,
152 SshSession: sesh,
153 User: user,
154 Log: log,
155 Dbpool: dbpool,
156 Write: false,
157 }
158
159 cmd := strings.TrimSpace(args[0])
160 if len(args) == 1 {
161 if cmd == "help" {
162 opts.help()
163 return
164 } else if cmd == "logs" {
165 err = opts.logs(sesh.Context())
166 if err != nil {
167 wish.Fatalln(sesh, err)
168 }
169 return
170 } else if cmd == "pico+" {
171 opts.plus()
172 return
173 } else if cmd == "notifications" {
174 err := opts.notifications()
175 if err != nil {
176 wish.Fatalln(sesh, err)
177 }
178 return
179 } else {
180 next(sesh)
181 return
182 }
183 }
184
185 next(sesh)
186 }
187 }
188}