Eric Bower
·
24 Sep 24
cli.go
1package pubsub
2
3import (
4 "bytes"
5 "flag"
6 "fmt"
7 "io"
8 "log/slog"
9 "strings"
10 "text/tabwriter"
11 "time"
12
13 "github.com/charmbracelet/ssh"
14 "github.com/charmbracelet/wish"
15 "github.com/google/uuid"
16 "github.com/picosh/pico/db"
17 "github.com/picosh/pico/shared"
18 psub "github.com/picosh/pubsub"
19 "github.com/picosh/send/send/utils"
20)
21
22func flagSet(cmdName string, sesh ssh.Session) *flag.FlagSet {
23 cmd := flag.NewFlagSet(cmdName, flag.ContinueOnError)
24 cmd.SetOutput(sesh)
25 return cmd
26}
27
// flagCheck parses cmdArgs with cmd and reports whether the caller should go
// on to run the command.
//
// It returns false in two cases:
//   - posArg (the first positional argument) is "-h", "--help" or "-help";
//     the usage text is printed and cmdArgs is not parsed.
//   - parsing cmdArgs fails, e.g. an unknown flag, a malformed value, or
//     -h/-help among the flags.
//
// Otherwise the flags registered on cmd have been populated and it returns
// true.
func flagCheck(cmd *flag.FlagSet, posArg string, cmdArgs []string) bool {
	switch posArg {
	case "-h", "--help", "-help":
		cmd.Usage()
		return false
	}

	// On failure the flag package has already written the error (or the usage
	// text for -h/-help) to the flag set's output, so there is nothing more to
	// report here; the caller only needs to know not to proceed.
	if err := cmd.Parse(cmdArgs); err != nil {
		return false
	}
	return true
}
37
// NewTabWriter returns a tabwriter that writes aligned columns to out.
// Cells are padded with single spaces (no minimum cell width) and the
// tabwriter.TabIndent flag is set.
func NewTabWriter(out io.Writer) *tabwriter.Writer {
	const (
		minWidth = 0
		tabWidth = 0
		padding  = 1
		padChar  = ' '
	)
	tw := new(tabwriter.Writer)
	return tw.Init(out, minWidth, tabWidth, padding, padChar, tabwriter.TabIndent)
}
41
42func getUser(s ssh.Session, dbpool db.DB) (*db.User, error) {
43 var err error
44 key, err := shared.KeyText(s)
45 if err != nil {
46 return nil, fmt.Errorf("key not found")
47 }
48
49 user, err := dbpool.FindUserForKey(s.User(), key)
50 if err != nil {
51 return nil, err
52 }
53
54 if user.Name == "" {
55 return nil, fmt.Errorf("must have username set")
56 }
57
58 return user, nil
59}
60
// toChannel scopes a channel to a user by prefixing the channel name with
// the user name, yielding "<userName>/<name>".
func toChannel(userName, name string) string {
	const sep = "/"
	return fmt.Sprint(userName, sep, name)
}
65
// toPublicChannel places a channel in the shared namespace by prefixing the
// channel name with "public/" instead of a user name.
func toPublicChannel(name string) string {
	const publicPrefix = "public/"
	return fmt.Sprint(publicPrefix, name)
}
69
// helpStr is the help text printed to the session when no command is given
// or when the "help" command is used.
var helpStr = `Commands: [pub, sub, ls, pipe]

The simplest authenticated pubsub system. Send messages through
user-defined channels. Channels are private to the authenticated
ssh user. The default pubsub model is multicast with bidirectional
blocking, meaning a publisher ("pub") will send its message to all
subscribers ("sub"). Further, both "pub" and "sub" will wait for
at least one event to be sent or received. Pipe ("pipe") allows
for bidirectional messages to be sent between any clients connected
to a pipe.`
80
81type CliHandler struct {
82 DBPool db.DB
83 Logger *slog.Logger
84 PubSub *psub.Cfg
85 Cfg *shared.ConfigSite
86}
87
88func toSshCmd(cfg *shared.ConfigSite) string {
89 port := "22"
90 if cfg.Port != "" {
91 port = fmt.Sprintf("-p %s", cfg.Port)
92 }
93 return fmt.Sprintf("%s %s", port, cfg.Domain)
94}
95
96func WishMiddleware(handler *CliHandler) wish.Middleware {
97 dbpool := handler.DBPool
98 pubsub := handler.PubSub
99
100 return func(next ssh.Handler) ssh.Handler {
101 return func(sesh ssh.Session) {
102 logger := handler.Logger
103 ctx := sesh.Context()
104 user, err := getUser(sesh, dbpool)
105 if err != nil {
106 utils.ErrorHandler(sesh, err)
107 return
108 }
109
110 logger = shared.LoggerWithUser(logger, user)
111
112 args := sesh.Command()
113
114 if len(args) == 0 {
115 wish.Println(sesh, helpStr)
116 next(sesh)
117 return
118 }
119
120 cmd := strings.TrimSpace(args[0])
121 if cmd == "help" {
122 wish.Println(sesh, helpStr)
123 } else if cmd == "ls" {
124 channelFilter := fmt.Sprintf("%s/", user.Name)
125 if handler.DBPool.HasFeatureForUser(user.ID, "admin") {
126 channelFilter = ""
127 }
128
129 channels := pubsub.PubSub.GetChannels(channelFilter)
130 pipes := pubsub.PubSub.GetPipes(channelFilter)
131
132 if len(channels) == 0 && len(pipes) == 0 {
133 wish.Println(sesh, "no pubsub channels or pipes found")
134 } else {
135 var outputData string
136 if len(channels) > 0 {
137 outputData += "Channel Information\r\n"
138 for _, channel := range channels {
139 outputData += fmt.Sprintf("- %s:\r\n", channel.Name)
140 outputData += "\tPubs:\r\n"
141
142 channel.Pubs.Range(func(I string, J *psub.Pub) bool {
143 outputData += fmt.Sprintf("\t- %s:\r\n", I)
144 return true
145 })
146
147 outputData += "\tSubs:\r\n"
148
149 channel.Subs.Range(func(I string, J *psub.Sub) bool {
150 outputData += fmt.Sprintf("\t- %s:\r\n", I)
151 return true
152 })
153 }
154 }
155
156 if len(pipes) > 0 {
157 outputData += "Pipe Information\r\n"
158 for _, pipe := range pipes {
159 outputData += fmt.Sprintf("- %s:\r\n", pipe.Name)
160 outputData += "\tClients:\r\n"
161
162 pipe.Clients.Range(func(I string, J *psub.PipeClient) bool {
163 outputData += fmt.Sprintf("\t- %s:\r\n", I)
164 return true
165 })
166 }
167 }
168
169 _, _ = sesh.Write([]byte(outputData))
170 }
171 }
172
173 channelName := ""
174 cmdArgs := args[1:]
175 if len(args) > 1 {
176 channelName = strings.TrimSpace(args[1])
177 cmdArgs = args[2:]
178 }
179 logger.Info(
180 "imgs middleware detected command",
181 "args", args,
182 "cmd", cmd,
183 "channelName", channelName,
184 "cmdArgs", cmdArgs,
185 )
186
187 if cmd == "pub" {
188 defaultTimeout, _ := time.ParseDuration("720h")
189
190 pubCmd := flagSet("pub", sesh)
191 empty := pubCmd.Bool("e", false, "Send an empty message to subs")
192 public := pubCmd.Bool("p", false, "Anyone can sub to this channel")
193 timeout := pubCmd.Duration("t", defaultTimeout, "Timeout as a Go duration before cancelling the pub event. Valid time units are 'ns', 'us' (or 'µs'), 'ms', 's', 'm', 'h'. Default is 30 days.")
194 if !flagCheck(pubCmd, channelName, cmdArgs) {
195 return
196 }
197
198 var reader io.Reader
199 if *empty {
200 reader = bytes.NewReader(make([]byte, 1))
201 } else {
202 reader = sesh
203 }
204
205 if channelName == "" {
206 channelName = uuid.NewString()
207 }
208 name := toChannel(user.Name, channelName)
209 if *public {
210 name = toPublicChannel(channelName)
211 }
212 wish.Printf(
213 sesh,
214 "subscribe to this channel:\n\tssh %s sub %s\n",
215 toSshCmd(handler.Cfg),
216 channelName,
217 )
218
219 wish.Println(sesh, "sending msg ...")
220 pub := &psub.Pub{
221 ID: fmt.Sprintf("%s (%s@%s)", uuid.NewString(), user.Name, sesh.RemoteAddr().String()),
222 Done: make(chan struct{}),
223 Reader: reader,
224 }
225
226 count := 0
227 channelInfo := pubsub.PubSub.GetChannel(name)
228
229 if channelInfo != nil {
230 channelInfo.Subs.Range(func(I string, J *psub.Sub) bool {
231 count++
232 return true
233 })
234 }
235
236 tt := *timeout
237 if count == 0 {
238 str := "no subs found ... waiting"
239 if tt > 0 {
240 str += " " + tt.String()
241 }
242 wish.Println(sesh, str)
243 }
244
245 go func() {
246 select {
247 case <-ctx.Done():
248 pub.Cleanup()
249 case <-time.After(tt):
250 wish.Fatalln(sesh, "timeout reached, exiting ...")
251 pub.Cleanup()
252 }
253 }()
254
255 err = pubsub.PubSub.Pub(name, pub)
256 wish.Println(sesh, "msg sent!")
257 if err != nil {
258 wish.Errorln(sesh, err)
259 }
260 } else if cmd == "sub" {
261 pubCmd := flagSet("pub", sesh)
262 public := pubCmd.Bool("p", false, "Subscribe to a public channel")
263 if !flagCheck(pubCmd, channelName, cmdArgs) {
264 return
265 }
266 channelName := channelName
267
268 name := toChannel(user.Name, channelName)
269 if *public {
270 name = toPublicChannel(channelName)
271 }
272
273 sub := &psub.Sub{
274 ID: fmt.Sprintf("%s (%s@%s)", uuid.NewString(), user.Name, sesh.RemoteAddr().String()),
275 Writer: sesh,
276 Done: make(chan struct{}),
277 Data: make(chan []byte),
278 }
279
280 go func() {
281 <-ctx.Done()
282 sub.Cleanup()
283 }()
284 err = pubsub.PubSub.Sub(name, sub)
285 if err != nil {
286 wish.Errorln(sesh, err)
287 }
288 } else if cmd == "pipe" {
289 pipeCmd := flagSet("pipe", sesh)
290 public := pipeCmd.Bool("p", false, "Pipe to a public channel")
291 replay := pipeCmd.Bool("r", false, "Replay messages to the client that sent it")
292 if !flagCheck(pipeCmd, channelName, cmdArgs) {
293 return
294 }
295 isCreator := channelName == ""
296 if isCreator {
297 channelName = uuid.NewString()
298 }
299 name := toChannel(user.Name, channelName)
300 if *public {
301 name = toPublicChannel(channelName)
302 }
303 if isCreator {
304 wish.Printf(
305 sesh,
306 "subscribe to this channel:\n\tssh %s sub %s\n",
307 toSshCmd(handler.Cfg),
308 channelName,
309 )
310 }
311
312 pipe := &psub.PipeClient{
313 ID: fmt.Sprintf("%s (%s@%s)", uuid.NewString(), user.Name, sesh.RemoteAddr().String()),
314 Done: make(chan struct{}),
315 Data: make(chan psub.PipeMessage),
316 ReadWriter: sesh,
317 Replay: *replay,
318 }
319
320 go func() {
321 <-ctx.Done()
322 pipe.Cleanup()
323 }()
324 readErr, writeErr := pubsub.PubSub.Pipe(name, pipe)
325 if readErr != nil {
326 wish.Errorln(sesh, readErr)
327 }
328 if writeErr != nil {
329 wish.Errorln(sesh, writeErr)
330 }
331 }
332
333 next(sesh)
334 }
335 }
336}