Eric Bower
·
13 Dec 24
cli.go
1package pipe
2
3import (
4 "bytes"
5 "context"
6 "flag"
7 "fmt"
8 "io"
9 "log/slog"
10 "slices"
11 "strings"
12 "text/tabwriter"
13 "time"
14
15 "github.com/antoniomika/syncmap"
16 "github.com/charmbracelet/ssh"
17 "github.com/charmbracelet/wish"
18 "github.com/google/uuid"
19 "github.com/picosh/pico/db"
20 "github.com/picosh/pico/shared"
21 psub "github.com/picosh/pubsub"
22 "github.com/picosh/utils"
23 gossh "golang.org/x/crypto/ssh"
24)
25
26func flagSet(cmdName string, sesh ssh.Session) *flag.FlagSet {
27 cmd := flag.NewFlagSet(cmdName, flag.ContinueOnError)
28 cmd.SetOutput(sesh)
29 cmd.Usage = func() {
30 fmt.Fprintf(cmd.Output(), "Usage: %s <topic> [args...]\nArgs:\n", cmdName)
31 cmd.PrintDefaults()
32 }
33 return cmd
34}
35
// flagCheck parses cmdArgs into cmd and reports whether the subcommand should
// go ahead. It returns false when parsing fails (the flag package reports
// that itself) or when the positional argument is "help", in which case the
// usage text is printed first.
func flagCheck(cmd *flag.FlagSet, posArg string, cmdArgs []string) bool {
	parseErr := cmd.Parse(cmdArgs)

	wantsHelp := posArg == "help"
	if wantsHelp {
		cmd.Usage()
	}

	return parseErr == nil && !wantsHelp
}
47
// NewTabWriter returns a tabwriter.Writer over out that aligns tab-separated
// columns, padding each cell with a single space. Callers must Flush it to
// emit buffered output.
func NewTabWriter(out io.Writer) *tabwriter.Writer {
	const (
		minWidth = 0
		tabWidth = 0
		padding  = 1
	)
	w := tabwriter.NewWriter(out, minWidth, tabWidth, padding, ' ', tabwriter.TabIndent)
	return w
}
51
// toTopic scopes topic to a user by prefixing it with "<userName>/".
func toTopic(userName, topic string) string {
	return userName + "/" + topic
}
56
// toPublicTopic places topic in the shared "public/" namespace.
func toPublicTopic(topic string) string {
	const publicPrefix = "public/"
	return publicPrefix + topic
}
60
61func clientInfo(clients []*psub.Client, clientType string) string {
62 if len(clients) == 0 {
63 return ""
64 }
65
66 outputData := fmt.Sprintf(" %s:\r\n", clientType)
67
68 for _, client := range clients {
69 outputData += fmt.Sprintf(" - %s\r\n", client.ID)
70 }
71
72 return outputData
73}
74
// helpStr renders the top-level usage text shown for "help" or when no
// command is given. sshCmd is the ssh target to advertise (see toSshCmd).
var helpStr = func(sshCmd string) string {
	const usageTemplate = `Command: ssh %s <help | ls | pub | sub | pipe> <topic> [-h | args...]

The simplest authenticated pubsub system. Send messages through
user-defined topics. Topics are private to the authenticated
ssh user. The default pubsub model is multicast with bidirectional
blocking, meaning a publisher ("pub") will send its message to all
subscribers ("sub"). Further, both "pub" and "sub" will wait for
at least one event to be sent or received. Pipe ("pipe") allows
for bidirectional messages to be sent between any clients connected
to a pipe.

Think of these different commands in terms of the direction the
data is being sent:

- pub => writes to client
- sub => reads from client
- pipe => read and write between clients
`
	return fmt.Sprintf(usageTemplate, sshCmd)
}
95
// CliHandler holds the dependencies and in-memory state that WishMiddleware
// uses to serve the pipe ssh commands.
type CliHandler struct {
	// DBPool resolves the user behind an ssh key and checks whether that
	// user has the "admin" feature.
	DBPool db.DB
	// Logger is the base logger; WishMiddleware enriches it per session with
	// the resolved user.
	Logger *slog.Logger
	// PubSub is the broker behind the pub, sub, pipe and ls commands.
	PubSub psub.PubSub
	// Cfg supplies the domain and port used to render example ssh commands.
	Cfg *shared.ConfigSite
	// Waiters maps a topic name to the client IDs of blocking publishers
	// still waiting for a subscriber; "ls" lists them.
	Waiters *syncmap.Map[string, []string]
	// Access maps a topic name to the access list given with -a (pico
	// usernames or ssh-key fingerprints), stored by the session that first
	// registered that name.
	Access *syncmap.Map[string, []string]
}
104
105func toSshCmd(cfg *shared.ConfigSite) string {
106 port := ""
107 if cfg.PortOverride != "22" {
108 port = fmt.Sprintf("-p %s ", cfg.PortOverride)
109 }
110 return fmt.Sprintf("%s%s", port, cfg.Domain)
111}
112
// parseArgList splits a comma-separated argument into its entries, trimming
// surrounding whitespace from each. Empty entries are kept, so "" yields [""].
func parseArgList(arg string) []string {
	parts := strings.Split(arg, ",")
	entries := make([]string, 0, len(parts))
	for _, part := range parts {
		entries = append(entries, strings.TrimSpace(part))
	}
	return entries
}
121
122// checkAccess checks if the user has access to a topic based on an access list.
123func checkAccess(accessList []string, userName string, sesh ssh.Session) bool {
124 for _, acc := range accessList {
125 if acc == userName {
126 return true
127 }
128
129 if key := sesh.PublicKey(); key != nil && acc == gossh.FingerprintSHA256(key) {
130 return true
131 }
132 }
133
134 return false
135}
136
137func WishMiddleware(handler *CliHandler) wish.Middleware {
138 pubsub := handler.PubSub
139
140 return func(next ssh.Handler) ssh.Handler {
141 return func(sesh ssh.Session) {
142 logger := handler.Logger
143 ctx := sesh.Context()
144
145 pubkey := utils.KeyForKeyText(sesh.PublicKey())
146 user, err := handler.DBPool.FindUserForKey(sesh.User(), pubkey)
147 if err != nil {
148 logger.Info("user not found", "err", err)
149 }
150
151 if user != nil {
152 logger = shared.LoggerWithUser(logger, user)
153 }
154
155 args := sesh.Command()
156
157 if len(args) == 0 {
158 wish.Println(sesh, helpStr(toSshCmd(handler.Cfg)))
159 next(sesh)
160 return
161 }
162
163 userName := "public"
164
165 userNameAddition := ""
166
167 isAdmin := false
168 if user != nil {
169 isAdmin = handler.DBPool.HasFeatureForUser(user.ID, "admin")
170
171 userName = user.Name
172 if user.PublicKey.Name != "" {
173 userNameAddition = fmt.Sprintf("-%s", user.PublicKey.Name)
174 }
175 }
176
177 pipeCtx, cancel := context.WithCancel(ctx)
178
179 go func() {
180 defer cancel()
181
182 for {
183 select {
184 case <-pipeCtx.Done():
185 return
186 default:
187 _, err := sesh.SendRequest("ping@pico.sh", false, nil)
188 if err != nil {
189 logger.Error("error sending ping", "err", err)
190 return
191 }
192
193 time.Sleep(5 * time.Second)
194 }
195 }
196 }()
197
198 cmd := strings.TrimSpace(args[0])
199 if cmd == "help" {
200 wish.Println(sesh, helpStr(toSshCmd(handler.Cfg)))
201 next(sesh)
202 return
203 } else if cmd == "ls" {
204 if userName == "public" {
205 wish.Fatalln(sesh, "access denied")
206 return
207 }
208
209 topicFilter := fmt.Sprintf("%s/", userName)
210 if isAdmin {
211 topicFilter = ""
212 if len(args) > 1 {
213 topicFilter = args[1]
214 }
215 }
216
217 var channels []*psub.Channel
218 waitingChannels := map[string][]string{}
219
220 for topic, channel := range pubsub.GetChannels() {
221 if strings.HasPrefix(topic, topicFilter) {
222 channels = append(channels, channel)
223 }
224 }
225
226 for channel, clients := range handler.Waiters.Range {
227 if strings.HasPrefix(channel, topicFilter) {
228 waitingChannels[channel] = clients
229 }
230 }
231
232 if len(channels) == 0 && len(waitingChannels) == 0 {
233 wish.Println(sesh, "no pubsub channels found")
234 } else {
235 var outputData string
236 if len(channels) > 0 || len(waitingChannels) > 0 {
237 outputData += "Channel Information\r\n"
238 for _, channel := range channels {
239 extraData := ""
240
241 if accessList, ok := handler.Access.Load(channel.Topic); ok && len(accessList) > 0 {
242 extraData += fmt.Sprintf(" (Access List: %s)", strings.Join(accessList, ", "))
243 }
244
245 outputData += fmt.Sprintf("- %s:%s\r\n", channel.Topic, extraData)
246 outputData += " Clients:\r\n"
247
248 var pubs []*psub.Client
249 var subs []*psub.Client
250 var pipes []*psub.Client
251
252 for _, client := range channel.GetClients() {
253 if client.Direction == psub.ChannelDirectionInput {
254 pubs = append(pubs, client)
255 } else if client.Direction == psub.ChannelDirectionOutput {
256 subs = append(subs, client)
257 } else if client.Direction == psub.ChannelDirectionInputOutput {
258 pipes = append(pipes, client)
259 }
260 }
261 outputData += clientInfo(pubs, "Pubs")
262 outputData += clientInfo(subs, "Subs")
263 outputData += clientInfo(pipes, "Pipes")
264 }
265
266 for waitingChannel, channelPubs := range waitingChannels {
267 extraData := ""
268
269 if accessList, ok := handler.Access.Load(waitingChannel); ok && len(accessList) > 0 {
270 extraData += fmt.Sprintf(" (Access List: %s)", strings.Join(accessList, ", "))
271 }
272
273 outputData += fmt.Sprintf("- %s:%s\r\n", waitingChannel, extraData)
274 outputData += " Clients:\r\n"
275 outputData += fmt.Sprintf(" %s:\r\n", "Waiting Pubs")
276 for _, client := range channelPubs {
277 outputData += fmt.Sprintf(" - %s\r\n", client)
278 }
279 }
280 }
281
282 _, _ = sesh.Write([]byte(outputData))
283 }
284
285 next(sesh)
286 return
287 }
288
289 topic := ""
290 cmdArgs := args[1:]
291 if len(args) > 1 && !strings.HasPrefix(args[1], "-") {
292 topic = strings.TrimSpace(args[1])
293 cmdArgs = args[2:]
294 }
295
296 logger.Info(
297 "pubsub middleware detected command",
298 "args", args,
299 "cmd", cmd,
300 "topic", topic,
301 "cmdArgs", cmdArgs,
302 )
303
304 clientID := fmt.Sprintf("%s (%s%s@%s)", uuid.NewString(), userName, userNameAddition, sesh.RemoteAddr().String())
305
306 if cmd == "pub" {
307 pubCmd := flagSet("pub", sesh)
308 access := pubCmd.String("a", "", "Comma separated list of pico usernames or ssh-key fingerprints to allow access to a topic")
309 empty := pubCmd.Bool("e", false, "Send an empty message to subs")
310 public := pubCmd.Bool("p", false, "Publish message to public topic")
311 block := pubCmd.Bool("b", true, "Block writes until a subscriber is available")
312 timeout := pubCmd.Duration("t", 30*24*time.Hour, "Timeout as a Go duration to block for a subscriber to be available. Valid time units are 'ns', 'us' (or 'µs'), 'ms', 's', 'm', 'h'. Default is 30 days.")
313 clean := pubCmd.Bool("c", false, "Don't send status messages")
314
315 if !flagCheck(pubCmd, topic, cmdArgs) {
316 return
317 }
318
319 if pubCmd.NArg() == 1 && topic == "" {
320 topic = pubCmd.Arg(0)
321 }
322
323 logger.Info(
324 "flags parsed",
325 "cmd", cmd,
326 "empty", *empty,
327 "public", *public,
328 "block", *block,
329 "timeout", *timeout,
330 "topic", topic,
331 "access", *access,
332 "clean", *clean,
333 )
334
335 var accessList []string
336
337 if *access != "" {
338 accessList = parseArgList(*access)
339 }
340
341 var rw io.ReadWriter
342 if *empty {
343 rw = bytes.NewBuffer(make([]byte, 1))
344 } else {
345 rw = sesh
346 }
347
348 if topic == "" {
349 topic = uuid.NewString()
350 }
351
352 var withoutUser string
353 var name string
354 msgFlag := ""
355
356 if isAdmin && strings.HasPrefix(topic, "/") {
357 name = strings.TrimPrefix(topic, "/")
358 } else {
359 name = toTopic(userName, topic)
360 if *public {
361 name = toPublicTopic(topic)
362 msgFlag = "-p "
363 withoutUser = name
364 } else {
365 withoutUser = topic
366 }
367 }
368
369 var accessListCreator bool
370
371 _, loaded := handler.Access.LoadOrStore(name, accessList)
372 if !loaded {
373 defer func() {
374 handler.Access.Delete(name)
375 }()
376
377 accessListCreator = true
378 }
379
380 if accessList, ok := handler.Access.Load(withoutUser); ok && len(accessList) > 0 && !isAdmin {
381 if checkAccess(accessList, userName, sesh) || accessListCreator {
382 name = withoutUser
383 } else if !*public {
384 name = toTopic(userName, withoutUser)
385 } else {
386 topic = uuid.NewString()
387 name = toPublicTopic(topic)
388 }
389 }
390
391 if !*clean {
392 wish.Printf(
393 sesh,
394 "subscribe to this channel:\n ssh %s sub %s%s\n",
395 toSshCmd(handler.Cfg),
396 msgFlag,
397 topic,
398 )
399 }
400
401 if *block {
402 count := 0
403 for topic, channel := range pubsub.GetChannels() {
404 if topic == name {
405 for _, client := range channel.GetClients() {
406 if client.Direction == psub.ChannelDirectionOutput || client.Direction == psub.ChannelDirectionInputOutput {
407 count++
408 }
409 }
410 break
411 }
412 }
413
414 tt := *timeout
415 if count == 0 {
416 currentWaiters, _ := handler.Waiters.LoadOrStore(name, nil)
417 handler.Waiters.Store(name, append(currentWaiters, clientID))
418
419 termMsg := "no subs found ... waiting"
420 if tt > 0 {
421 termMsg += " " + tt.String()
422 }
423
424 if !*clean {
425 wish.Println(sesh, termMsg)
426 }
427
428 ready := make(chan struct{})
429
430 go func() {
431 for {
432 select {
433 case <-pipeCtx.Done():
434 cancel()
435 return
436 case <-time.After(1 * time.Millisecond):
437 count := 0
438 for topic, channel := range pubsub.GetChannels() {
439 if topic == name {
440 for _, client := range channel.GetClients() {
441 if client.Direction == psub.ChannelDirectionOutput || client.Direction == psub.ChannelDirectionInputOutput {
442 count++
443 }
444 }
445 break
446 }
447 }
448
449 if count > 0 {
450 close(ready)
451 return
452 }
453 }
454 }
455 }()
456
457 select {
458 case <-ready:
459 case <-pipeCtx.Done():
460 case <-time.After(tt):
461 cancel()
462
463 if !*clean {
464 wish.Fatalln(sesh, "timeout reached, exiting ...")
465 } else {
466 err = sesh.Exit(1)
467 if err != nil {
468 logger.Error("error exiting session", "err", err)
469 }
470
471 sesh.Close()
472 }
473 }
474
475 newWaiters, _ := handler.Waiters.LoadOrStore(name, nil)
476 newWaiters = slices.DeleteFunc(newWaiters, func(cl string) bool {
477 return cl == clientID
478 })
479 handler.Waiters.Store(name, newWaiters)
480
481 var toDelete []string
482
483 for channel, clients := range handler.Waiters.Range {
484 if len(clients) == 0 {
485 toDelete = append(toDelete, channel)
486 }
487 }
488
489 for _, channel := range toDelete {
490 handler.Waiters.Delete(channel)
491 }
492 }
493 }
494
495 if !*clean {
496 wish.Println(sesh, "sending msg ...")
497 }
498
499 err = pubsub.Pub(
500 pipeCtx,
501 clientID,
502 rw,
503 []*psub.Channel{
504 psub.NewChannel(name),
505 },
506 *block,
507 )
508
509 if !*clean {
510 wish.Println(sesh, "msg sent!")
511 }
512
513 if err != nil && !*clean {
514 wish.Errorln(sesh, err)
515 }
516 } else if cmd == "sub" {
517 subCmd := flagSet("sub", sesh)
518 access := subCmd.String("a", "", "Comma separated list of pico usernames or ssh-key fingerprints to allow access to a topic")
519 public := subCmd.Bool("p", false, "Subscribe to a public topic")
520 keepAlive := subCmd.Bool("k", false, "Keep the subscription alive even after the publisher has died")
521 clean := subCmd.Bool("c", false, "Don't send status messages")
522
523 if !flagCheck(subCmd, topic, cmdArgs) {
524 return
525 }
526
527 if subCmd.NArg() == 1 && topic == "" {
528 topic = subCmd.Arg(0)
529 }
530
531 logger.Info(
532 "flags parsed",
533 "cmd", cmd,
534 "public", *public,
535 "keepAlive", *keepAlive,
536 "topic", topic,
537 "clean", *clean,
538 "access", *access,
539 )
540
541 var accessList []string
542
543 if *access != "" {
544 accessList = parseArgList(*access)
545 }
546
547 var withoutUser string
548 var name string
549
550 if isAdmin && strings.HasPrefix(topic, "/") {
551 name = strings.TrimPrefix(topic, "/")
552 } else {
553 name = toTopic(userName, topic)
554 if *public {
555 name = toPublicTopic(topic)
556 withoutUser = name
557 } else {
558 withoutUser = topic
559 }
560 }
561
562 var accessListCreator bool
563
564 _, loaded := handler.Access.LoadOrStore(name, accessList)
565 if !loaded {
566 defer func() {
567 handler.Access.Delete(name)
568 }()
569
570 accessListCreator = true
571 }
572
573 if accessList, ok := handler.Access.Load(withoutUser); ok && len(accessList) > 0 && !isAdmin {
574 if checkAccess(accessList, userName, sesh) || accessListCreator {
575 name = withoutUser
576 } else if !*public {
577 name = toTopic(userName, withoutUser)
578 } else {
579 wish.Errorln(sesh, "access denied")
580 return
581 }
582 }
583
584 err = pubsub.Sub(
585 pipeCtx,
586 clientID,
587 sesh,
588 []*psub.Channel{
589 psub.NewChannel(name),
590 },
591 *keepAlive,
592 )
593
594 if err != nil && !*clean {
595 wish.Errorln(sesh, err)
596 }
597 } else if cmd == "pipe" {
598 pipeCmd := flagSet("pipe", sesh)
599 access := pipeCmd.String("a", "", "Comma separated list of pico usernames or ssh-key fingerprints to allow access to a topic")
600 public := pipeCmd.Bool("p", false, "Pipe to a public topic")
601 replay := pipeCmd.Bool("r", false, "Replay messages to the client that sent it")
602 clean := pipeCmd.Bool("c", false, "Don't send status messages")
603
604 if !flagCheck(pipeCmd, topic, cmdArgs) {
605 return
606 }
607
608 if pipeCmd.NArg() == 1 && topic == "" {
609 topic = pipeCmd.Arg(0)
610 }
611
612 logger.Info(
613 "flags parsed",
614 "cmd", cmd,
615 "public", *public,
616 "replay", *replay,
617 "topic", topic,
618 "access", *access,
619 "clean", *clean,
620 )
621
622 var accessList []string
623
624 if *access != "" {
625 accessList = parseArgList(*access)
626 }
627
628 isCreator := topic == ""
629 if isCreator {
630 topic = uuid.NewString()
631 }
632
633 var withoutUser string
634 var name string
635 flagMsg := ""
636
637 if isAdmin && strings.HasPrefix(topic, "/") {
638 name = strings.TrimPrefix(topic, "/")
639 } else {
640 name = toTopic(userName, topic)
641 if *public {
642 name = toPublicTopic(topic)
643 flagMsg = "-p "
644 withoutUser = name
645 } else {
646 withoutUser = topic
647 }
648 }
649
650 var accessListCreator bool
651
652 _, loaded := handler.Access.LoadOrStore(name, accessList)
653 if !loaded {
654 defer func() {
655 handler.Access.Delete(name)
656 }()
657
658 accessListCreator = true
659 }
660
661 if accessList, ok := handler.Access.Load(withoutUser); ok && len(accessList) > 0 && !isAdmin {
662 if checkAccess(accessList, userName, sesh) || accessListCreator {
663 name = withoutUser
664 } else if !*public {
665 name = toTopic(userName, withoutUser)
666 } else {
667 topic = uuid.NewString()
668 name = toPublicTopic(topic)
669 }
670 }
671
672 if isCreator && !*clean {
673 wish.Printf(
674 sesh,
675 "subscribe to this topic:\n ssh %s sub %s%s\n",
676 toSshCmd(handler.Cfg),
677 flagMsg,
678 topic,
679 )
680 }
681
682 readErr, writeErr := pubsub.Pipe(
683 pipeCtx,
684 clientID,
685 sesh,
686 []*psub.Channel{
687 psub.NewChannel(name),
688 },
689 *replay,
690 )
691
692 if readErr != nil && !*clean {
693 wish.Errorln(sesh, "error reading from pipe", readErr)
694 }
695
696 if writeErr != nil && !*clean {
697 wish.Errorln(sesh, "error writing to pipe", writeErr)
698 }
699 }
700
701 next(sesh)
702 }
703 }
704}