Antonio Mika
·
22 Nov 24
cli.go
1package pipe
2
3import (
4 "bytes"
5 "context"
6 "flag"
7 "fmt"
8 "io"
9 "log/slog"
10 "slices"
11 "strings"
12 "text/tabwriter"
13 "time"
14
15 "github.com/antoniomika/syncmap"
16 "github.com/charmbracelet/ssh"
17 "github.com/charmbracelet/wish"
18 "github.com/google/uuid"
19 "github.com/picosh/pico/db"
20 "github.com/picosh/pico/shared"
21 psub "github.com/picosh/pubsub"
22 gossh "golang.org/x/crypto/ssh"
23)
24
25func flagSet(cmdName string, sesh ssh.Session) *flag.FlagSet {
26 cmd := flag.NewFlagSet(cmdName, flag.ContinueOnError)
27 cmd.SetOutput(sesh)
28 cmd.Usage = func() {
29 fmt.Fprintf(cmd.Output(), "Usage: %s <topic> [args...]\nArgs:\n", cmdName)
30 cmd.PrintDefaults()
31 }
32 return cmd
33}
34
// flagCheck parses cmdArgs into cmd and reports whether the subcommand should
// run. It returns false when the positional argument is "help" (after
// printing cmd's usage) or when parsing fails.
func flagCheck(cmd *flag.FlagSet, posArg string, cmdArgs []string) bool {
	// Parse first, even for "help", so flag errors are still reported.
	parseErr := cmd.Parse(cmdArgs)

	if posArg == "help" {
		cmd.Usage()
		return false
	}

	return parseErr == nil
}
46
// NewTabWriter returns a tabwriter that writes to out, separating columns
// with a single space of padding (no minimum cell width) and, via
// tabwriter.TabIndent, using tabs for leading empty (indentation) cells.
func NewTabWriter(out io.Writer) *tabwriter.Writer {
	const (
		minWidth = 0
		tabWidth = 0
		padding  = 1
		padChar  = ' '
	)
	return tabwriter.NewWriter(out, minWidth, tabWidth, padding, padChar, tabwriter.TabIndent)
}
50
// toTopic scopes topic to a user by prefixing it with "<userName>/".
func toTopic(userName, topic string) string {
	return userName + "/" + topic
}
55
// toPublicTopic places topic in the shared "public/" namespace.
func toPublicTopic(topic string) string {
	const publicPrefix = "public/"
	return publicPrefix + topic
}
59
60func clientInfo(clients []*psub.Client, clientType string) string {
61 if len(clients) == 0 {
62 return ""
63 }
64
65 outputData := fmt.Sprintf(" %s:\r\n", clientType)
66
67 for _, client := range clients {
68 outputData += fmt.Sprintf(" - %s\r\n", client.ID)
69 }
70
71 return outputData
72}
73
// helpStr renders the top-level usage text. sshCmd is the host portion of
// the ssh invocation, as produced by toSshCmd.
var helpStr = func(sshCmd string) string {
	const usage = `Command: ssh %s <help | ls | pub | sub | pipe> <topic> [-h | args...]

The simplest authenticated pubsub system. Send messages through
user-defined topics. Topics are private to the authenticated
ssh user. The default pubsub model is multicast with bidirectional
blocking, meaning a publisher ("pub") will send its message to all
subscribers ("sub"). Further, both "pub" and "sub" will wait for
at least one event to be sent or received. Pipe ("pipe") allows
for bidirectional messages to be sent between any clients connected
to a pipe.

Think of these different commands in terms of the direction the
data is being sent:

- pub => writes to client
- sub => reads from client
- pipe => read and write between clients
`
	return fmt.Sprintf(usage, sshCmd)
}
94
95type CliHandler struct {
96 DBPool db.DB
97 Logger *slog.Logger
98 PubSub psub.PubSub
99 Cfg *shared.ConfigSite
100 Waiters *syncmap.Map[string, []string]
101 Access *syncmap.Map[string, []string]
102}
103
104func toSshCmd(cfg *shared.ConfigSite) string {
105 port := ""
106 if cfg.PortOverride != "22" {
107 port = fmt.Sprintf("-p %s ", cfg.PortOverride)
108 }
109 return fmt.Sprintf("%s%s", port, cfg.Domain)
110}
111
// parseArgList splits a comma-separated argument into its items, trimming
// surrounding whitespace from each. Empty items are kept, so the result
// always has at least one element.
func parseArgList(arg string) []string {
	raw := strings.Split(arg, ",")
	items := make([]string, 0, len(raw))
	for _, item := range raw {
		items = append(items, strings.TrimSpace(item))
	}
	return items
}
120
121// checkAccess checks if the user has access to a topic based on an access list.
122func checkAccess(accessList []string, userName string, sesh ssh.Session) bool {
123 for _, acc := range accessList {
124 if acc == userName {
125 return true
126 }
127
128 if key := sesh.PublicKey(); key != nil && acc == gossh.FingerprintSHA256(key) {
129 return true
130 }
131 }
132
133 return false
134}
135
136func WishMiddleware(handler *CliHandler) wish.Middleware {
137 pubsub := handler.PubSub
138
139 return func(next ssh.Handler) ssh.Handler {
140 return func(sesh ssh.Session) {
141 logger := handler.Logger
142 ctx := sesh.Context()
143
144 user, err := shared.GetUser(ctx)
145 if err != nil {
146 logger.Info("user not found", "err", err)
147 }
148
149 if user != nil {
150 logger = shared.LoggerWithUser(logger, user)
151 }
152
153 args := sesh.Command()
154
155 if len(args) == 0 {
156 wish.Println(sesh, helpStr(toSshCmd(handler.Cfg)))
157 next(sesh)
158 return
159 }
160
161 userName := "public"
162
163 userNameAddition := ""
164
165 isAdmin := false
166 if user != nil {
167 isAdmin = handler.DBPool.HasFeatureForUser(user.ID, "admin")
168
169 userName = user.Name
170 if user.PublicKey.Name != "" {
171 userNameAddition = fmt.Sprintf("-%s", user.PublicKey.Name)
172 }
173 }
174
175 pipeCtx, cancel := context.WithCancel(ctx)
176
177 go func() {
178 defer cancel()
179
180 for {
181 select {
182 case <-pipeCtx.Done():
183 return
184 default:
185 _, err := sesh.SendRequest("ping@pico.sh", false, nil)
186 if err != nil {
187 logger.Error("error sending ping", "err", err)
188 return
189 }
190
191 time.Sleep(5 * time.Second)
192 }
193 }
194 }()
195
196 cmd := strings.TrimSpace(args[0])
197 if cmd == "help" {
198 wish.Println(sesh, helpStr(toSshCmd(handler.Cfg)))
199 next(sesh)
200 return
201 } else if cmd == "ls" {
202 if userName == "public" {
203 wish.Fatalln(sesh, "access denied")
204 return
205 }
206
207 topicFilter := fmt.Sprintf("%s/", userName)
208 if isAdmin {
209 topicFilter = ""
210 if len(args) > 1 {
211 topicFilter = args[1]
212 }
213 }
214
215 var channels []*psub.Channel
216 waitingChannels := map[string][]string{}
217
218 for topic, channel := range pubsub.GetChannels() {
219 if strings.HasPrefix(topic, topicFilter) {
220 channels = append(channels, channel)
221 }
222 }
223
224 for channel, clients := range handler.Waiters.Range {
225 if strings.HasPrefix(channel, topicFilter) {
226 waitingChannels[channel] = clients
227 }
228 }
229
230 if len(channels) == 0 && len(waitingChannels) == 0 {
231 wish.Println(sesh, "no pubsub channels found")
232 } else {
233 var outputData string
234 if len(channels) > 0 || len(waitingChannels) > 0 {
235 outputData += "Channel Information\r\n"
236 for _, channel := range channels {
237 extraData := ""
238
239 if accessList, ok := handler.Access.Load(channel.Topic); ok && len(accessList) > 0 {
240 extraData += fmt.Sprintf(" (Access List: %s)", strings.Join(accessList, ", "))
241 }
242
243 outputData += fmt.Sprintf("- %s:%s\r\n", channel.Topic, extraData)
244 outputData += " Clients:\r\n"
245
246 var pubs []*psub.Client
247 var subs []*psub.Client
248 var pipes []*psub.Client
249
250 for _, client := range channel.GetClients() {
251 if client.Direction == psub.ChannelDirectionInput {
252 pubs = append(pubs, client)
253 } else if client.Direction == psub.ChannelDirectionOutput {
254 subs = append(subs, client)
255 } else if client.Direction == psub.ChannelDirectionInputOutput {
256 pipes = append(pipes, client)
257 }
258 }
259 outputData += clientInfo(pubs, "Pubs")
260 outputData += clientInfo(subs, "Subs")
261 outputData += clientInfo(pipes, "Pipes")
262 }
263
264 for waitingChannel, channelPubs := range waitingChannels {
265 extraData := ""
266
267 if accessList, ok := handler.Access.Load(waitingChannel); ok && len(accessList) > 0 {
268 extraData += fmt.Sprintf(" (Access List: %s)", strings.Join(accessList, ", "))
269 }
270
271 outputData += fmt.Sprintf("- %s:%s\r\n", waitingChannel, extraData)
272 outputData += " Clients:\r\n"
273 outputData += fmt.Sprintf(" %s:\r\n", "Waiting Pubs")
274 for _, client := range channelPubs {
275 outputData += fmt.Sprintf(" - %s\r\n", client)
276 }
277 }
278 }
279
280 _, _ = sesh.Write([]byte(outputData))
281 }
282
283 next(sesh)
284 return
285 }
286
287 topic := ""
288 cmdArgs := args[1:]
289 if len(args) > 1 && !strings.HasPrefix(args[1], "-") {
290 topic = strings.TrimSpace(args[1])
291 cmdArgs = args[2:]
292 }
293
294 logger.Info(
295 "pubsub middleware detected command",
296 "args", args,
297 "cmd", cmd,
298 "topic", topic,
299 "cmdArgs", cmdArgs,
300 )
301
302 clientID := fmt.Sprintf("%s (%s%s@%s)", uuid.NewString(), userName, userNameAddition, sesh.RemoteAddr().String())
303
304 if cmd == "pub" {
305 pubCmd := flagSet("pub", sesh)
306 access := pubCmd.String("a", "", "Comma separated list of pico usernames or ssh-key fingerprints to allow access to a topic")
307 empty := pubCmd.Bool("e", false, "Send an empty message to subs")
308 public := pubCmd.Bool("p", false, "Publish message to public topic")
309 block := pubCmd.Bool("b", true, "Block writes until a subscriber is available")
310 timeout := pubCmd.Duration("t", 30*24*time.Hour, "Timeout as a Go duration to block for a subscriber to be available. Valid time units are 'ns', 'us' (or 'µs'), 'ms', 's', 'm', 'h'. Default is 30 days.")
311 clean := pubCmd.Bool("c", false, "Don't send status messages")
312
313 if !flagCheck(pubCmd, topic, cmdArgs) {
314 return
315 }
316
317 if pubCmd.NArg() == 1 && topic == "" {
318 topic = pubCmd.Arg(0)
319 }
320
321 logger.Info(
322 "flags parsed",
323 "cmd", cmd,
324 "empty", *empty,
325 "public", *public,
326 "block", *block,
327 "timeout", *timeout,
328 "topic", topic,
329 "access", *access,
330 "clean", *clean,
331 )
332
333 var accessList []string
334
335 if *access != "" {
336 accessList = parseArgList(*access)
337 }
338
339 var rw io.ReadWriter
340 if *empty {
341 rw = bytes.NewBuffer(make([]byte, 1))
342 } else {
343 rw = sesh
344 }
345
346 if topic == "" {
347 topic = uuid.NewString()
348 }
349
350 var withoutUser string
351 var name string
352 msgFlag := ""
353
354 if isAdmin && strings.HasPrefix(topic, "/") {
355 name = strings.TrimPrefix(topic, "/")
356 } else {
357 name = toTopic(userName, topic)
358 if *public {
359 name = toPublicTopic(topic)
360 msgFlag = "-p "
361 withoutUser = name
362 } else {
363 withoutUser = topic
364 }
365 }
366
367 var accessListCreator bool
368
369 _, loaded := handler.Access.LoadOrStore(name, accessList)
370 if !loaded {
371 defer func() {
372 handler.Access.Delete(name)
373 }()
374
375 accessListCreator = true
376 }
377
378 if accessList, ok := handler.Access.Load(withoutUser); ok && len(accessList) > 0 && !isAdmin {
379 if checkAccess(accessList, userName, sesh) || accessListCreator {
380 name = withoutUser
381 } else if !*public {
382 name = toTopic(userName, withoutUser)
383 } else {
384 topic = uuid.NewString()
385 name = toPublicTopic(topic)
386 }
387 }
388
389 if !*clean {
390 wish.Printf(
391 sesh,
392 "subscribe to this channel:\n ssh %s sub %s%s\n",
393 toSshCmd(handler.Cfg),
394 msgFlag,
395 topic,
396 )
397 }
398
399 if *block {
400 count := 0
401 for topic, channel := range pubsub.GetChannels() {
402 if topic == name {
403 for _, client := range channel.GetClients() {
404 if client.Direction == psub.ChannelDirectionOutput || client.Direction == psub.ChannelDirectionInputOutput {
405 count++
406 }
407 }
408 break
409 }
410 }
411
412 tt := *timeout
413 if count == 0 {
414 currentWaiters, _ := handler.Waiters.LoadOrStore(name, nil)
415 handler.Waiters.Store(name, append(currentWaiters, clientID))
416
417 termMsg := "no subs found ... waiting"
418 if tt > 0 {
419 termMsg += " " + tt.String()
420 }
421
422 if !*clean {
423 wish.Println(sesh, termMsg)
424 }
425
426 ready := make(chan struct{})
427
428 go func() {
429 for {
430 select {
431 case <-pipeCtx.Done():
432 cancel()
433 return
434 case <-time.After(1 * time.Millisecond):
435 count := 0
436 for topic, channel := range pubsub.GetChannels() {
437 if topic == name {
438 for _, client := range channel.GetClients() {
439 if client.Direction == psub.ChannelDirectionOutput || client.Direction == psub.ChannelDirectionInputOutput {
440 count++
441 }
442 }
443 break
444 }
445 }
446
447 if count > 0 {
448 close(ready)
449 return
450 }
451 }
452 }
453 }()
454
455 select {
456 case <-ready:
457 case <-pipeCtx.Done():
458 case <-time.After(tt):
459 cancel()
460
461 if !*clean {
462 wish.Fatalln(sesh, "timeout reached, exiting ...")
463 } else {
464 err = sesh.Exit(1)
465 if err != nil {
466 logger.Error("error exiting session", "err", err)
467 }
468
469 sesh.Close()
470 }
471 }
472
473 newWaiters, _ := handler.Waiters.LoadOrStore(name, nil)
474 newWaiters = slices.DeleteFunc(newWaiters, func(cl string) bool {
475 return cl == clientID
476 })
477 handler.Waiters.Store(name, newWaiters)
478
479 var toDelete []string
480
481 for channel, clients := range handler.Waiters.Range {
482 if len(clients) == 0 {
483 toDelete = append(toDelete, channel)
484 }
485 }
486
487 for _, channel := range toDelete {
488 handler.Waiters.Delete(channel)
489 }
490 }
491 }
492
493 if !*clean {
494 wish.Println(sesh, "sending msg ...")
495 }
496
497 err = pubsub.Pub(
498 pipeCtx,
499 clientID,
500 rw,
501 []*psub.Channel{
502 psub.NewChannel(name),
503 },
504 *block,
505 )
506
507 if !*clean {
508 wish.Println(sesh, "msg sent!")
509 }
510
511 if err != nil && !*clean {
512 wish.Errorln(sesh, err)
513 }
514 } else if cmd == "sub" {
515 subCmd := flagSet("sub", sesh)
516 access := subCmd.String("a", "", "Comma separated list of pico usernames or ssh-key fingerprints to allow access to a topic")
517 public := subCmd.Bool("p", false, "Subscribe to a public topic")
518 keepAlive := subCmd.Bool("k", false, "Keep the subscription alive even after the publisher has died")
519 clean := subCmd.Bool("c", false, "Don't send status messages")
520
521 if !flagCheck(subCmd, topic, cmdArgs) {
522 return
523 }
524
525 if subCmd.NArg() == 1 && topic == "" {
526 topic = subCmd.Arg(0)
527 }
528
529 logger.Info(
530 "flags parsed",
531 "cmd", cmd,
532 "public", *public,
533 "keepAlive", *keepAlive,
534 "topic", topic,
535 "clean", *clean,
536 "access", *access,
537 )
538
539 var accessList []string
540
541 if *access != "" {
542 accessList = parseArgList(*access)
543 }
544
545 var withoutUser string
546 var name string
547
548 if isAdmin && strings.HasPrefix(topic, "/") {
549 name = strings.TrimPrefix(topic, "/")
550 } else {
551 name = toTopic(userName, topic)
552 if *public {
553 name = toPublicTopic(topic)
554 withoutUser = name
555 } else {
556 withoutUser = topic
557 }
558 }
559
560 var accessListCreator bool
561
562 _, loaded := handler.Access.LoadOrStore(name, accessList)
563 if !loaded {
564 defer func() {
565 handler.Access.Delete(name)
566 }()
567
568 accessListCreator = true
569 }
570
571 if accessList, ok := handler.Access.Load(withoutUser); ok && len(accessList) > 0 && !isAdmin {
572 if checkAccess(accessList, userName, sesh) || accessListCreator {
573 name = withoutUser
574 } else if !*public {
575 name = toTopic(userName, withoutUser)
576 } else {
577 wish.Errorln(sesh, "access denied")
578 return
579 }
580 }
581
582 err = pubsub.Sub(
583 pipeCtx,
584 clientID,
585 sesh,
586 []*psub.Channel{
587 psub.NewChannel(name),
588 },
589 *keepAlive,
590 )
591
592 if err != nil && !*clean {
593 wish.Errorln(sesh, err)
594 }
595 } else if cmd == "pipe" {
596 pipeCmd := flagSet("pipe", sesh)
597 access := pipeCmd.String("a", "", "Comma separated list of pico usernames or ssh-key fingerprints to allow access to a topic")
598 public := pipeCmd.Bool("p", false, "Pipe to a public topic")
599 replay := pipeCmd.Bool("r", false, "Replay messages to the client that sent it")
600 clean := pipeCmd.Bool("c", false, "Don't send status messages")
601
602 if !flagCheck(pipeCmd, topic, cmdArgs) {
603 return
604 }
605
606 if pipeCmd.NArg() == 1 && topic == "" {
607 topic = pipeCmd.Arg(0)
608 }
609
610 logger.Info(
611 "flags parsed",
612 "cmd", cmd,
613 "public", *public,
614 "replay", *replay,
615 "topic", topic,
616 "access", *access,
617 "clean", *clean,
618 )
619
620 var accessList []string
621
622 if *access != "" {
623 accessList = parseArgList(*access)
624 }
625
626 isCreator := topic == ""
627 if isCreator {
628 topic = uuid.NewString()
629 }
630
631 var withoutUser string
632 var name string
633 flagMsg := ""
634
635 if isAdmin && strings.HasPrefix(topic, "/") {
636 name = strings.TrimPrefix(topic, "/")
637 } else {
638 name = toTopic(userName, topic)
639 if *public {
640 name = toPublicTopic(topic)
641 flagMsg = "-p "
642 withoutUser = name
643 } else {
644 withoutUser = topic
645 }
646 }
647
648 var accessListCreator bool
649
650 _, loaded := handler.Access.LoadOrStore(name, accessList)
651 if !loaded {
652 defer func() {
653 handler.Access.Delete(name)
654 }()
655
656 accessListCreator = true
657 }
658
659 if accessList, ok := handler.Access.Load(withoutUser); ok && len(accessList) > 0 && !isAdmin {
660 if checkAccess(accessList, userName, sesh) || accessListCreator {
661 name = withoutUser
662 } else if !*public {
663 name = toTopic(userName, withoutUser)
664 } else {
665 topic = uuid.NewString()
666 name = toPublicTopic(topic)
667 }
668 }
669
670 if isCreator && !*clean {
671 wish.Printf(
672 sesh,
673 "subscribe to this topic:\n ssh %s sub %s%s\n",
674 toSshCmd(handler.Cfg),
675 flagMsg,
676 topic,
677 )
678 }
679
680 readErr, writeErr := pubsub.Pipe(
681 pipeCtx,
682 clientID,
683 sesh,
684 []*psub.Channel{
685 psub.NewChannel(name),
686 },
687 *replay,
688 )
689
690 if readErr != nil && !*clean {
691 wish.Errorln(sesh, "error reading from pipe", readErr)
692 }
693
694 if writeErr != nil && !*clean {
695 wish.Errorln(sesh, "error writing to pipe", writeErr)
696 }
697 }
698
699 next(sesh)
700 }
701 }
702}