repos / pico

pico services - prose.sh, pastes.sh, imgs.sh, feeds.sh, pgs.sh
git clone https://github.com/picosh/pico.git

pico / pipe
Eric Bower · 13 Dec 24

cli.go

  1package pipe
  2
  3import (
  4	"bytes"
  5	"context"
  6	"flag"
  7	"fmt"
  8	"io"
  9	"log/slog"
 10	"slices"
 11	"strings"
 12	"text/tabwriter"
 13	"time"
 14
 15	"github.com/antoniomika/syncmap"
 16	"github.com/charmbracelet/ssh"
 17	"github.com/charmbracelet/wish"
 18	"github.com/google/uuid"
 19	"github.com/picosh/pico/db"
 20	"github.com/picosh/pico/shared"
 21	psub "github.com/picosh/pubsub"
 22	"github.com/picosh/utils"
 23	gossh "golang.org/x/crypto/ssh"
 24)
 25
 26func flagSet(cmdName string, sesh ssh.Session) *flag.FlagSet {
 27	cmd := flag.NewFlagSet(cmdName, flag.ContinueOnError)
 28	cmd.SetOutput(sesh)
 29	cmd.Usage = func() {
 30		fmt.Fprintf(cmd.Output(), "Usage: %s <topic> [args...]\nArgs:\n", cmdName)
 31		cmd.PrintDefaults()
 32	}
 33	return cmd
 34}
 35
 36func flagCheck(cmd *flag.FlagSet, posArg string, cmdArgs []string) bool {
 37	err := cmd.Parse(cmdArgs)
 38
 39	if err != nil || posArg == "help" {
 40		if posArg == "help" {
 41			cmd.Usage()
 42		}
 43		return false
 44	}
 45	return true
 46}
 47
 48func NewTabWriter(out io.Writer) *tabwriter.Writer {
 49	return tabwriter.NewWriter(out, 0, 0, 1, ' ', tabwriter.TabIndent)
 50}
 51
 52// scope topic to user by prefixing name.
 53func toTopic(userName, topic string) string {
 54	return fmt.Sprintf("%s/%s", userName, topic)
 55}
 56
 57func toPublicTopic(topic string) string {
 58	return fmt.Sprintf("public/%s", topic)
 59}
 60
 61func clientInfo(clients []*psub.Client, clientType string) string {
 62	if len(clients) == 0 {
 63		return ""
 64	}
 65
 66	outputData := fmt.Sprintf("    %s:\r\n", clientType)
 67
 68	for _, client := range clients {
 69		outputData += fmt.Sprintf("    - %s\r\n", client.ID)
 70	}
 71
 72	return outputData
 73}
 74
 75var helpStr = func(sshCmd string) string {
 76	return fmt.Sprintf(`Command: ssh %s <help | ls | pub | sub | pipe> <topic> [-h | args...]
 77
 78The simplest authenticated pubsub system.  Send messages through
 79user-defined topics.  Topics are private to the authenticated
 80ssh user.  The default pubsub model is multicast with bidirectional
 81blocking, meaning a publisher ("pub") will send its message to all
 82subscribers ("sub").  Further, both "pub" and "sub" will wait for
 83at least one event to be sent or received. Pipe ("pipe") allows
 84for bidirectional messages to be sent between any clients connected
 85to a pipe.
 86
 87Think of these different commands in terms of the direction the
 88data is being sent:
 89
 90- pub => writes to client
 91- sub => reads from client
 92- pipe => read and write between clients
 93`, sshCmd)
 94}
 95
 96type CliHandler struct {
 97	DBPool  db.DB
 98	Logger  *slog.Logger
 99	PubSub  psub.PubSub
100	Cfg     *shared.ConfigSite
101	Waiters *syncmap.Map[string, []string]
102	Access  *syncmap.Map[string, []string]
103}
104
105func toSshCmd(cfg *shared.ConfigSite) string {
106	port := ""
107	if cfg.PortOverride != "22" {
108		port = fmt.Sprintf("-p %s ", cfg.PortOverride)
109	}
110	return fmt.Sprintf("%s%s", port, cfg.Domain)
111}
112
// parseArgList splits a comma separated argument into its items and trims the
// surrounding whitespace from each one. Empty items are preserved, so "" yields
// [""] and "a,,b" yields ["a", "", "b"].
func parseArgList(arg string) []string {
	parts := strings.Split(arg, ",")
	items := make([]string, 0, len(parts))
	for _, part := range parts {
		items = append(items, strings.TrimSpace(part))
	}
	return items
}
121
122// checkAccess checks if the user has access to a topic based on an access list.
123func checkAccess(accessList []string, userName string, sesh ssh.Session) bool {
124	for _, acc := range accessList {
125		if acc == userName {
126			return true
127		}
128
129		if key := sesh.PublicKey(); key != nil && acc == gossh.FingerprintSHA256(key) {
130			return true
131		}
132	}
133
134	return false
135}
136
137func WishMiddleware(handler *CliHandler) wish.Middleware {
138	pubsub := handler.PubSub
139
140	return func(next ssh.Handler) ssh.Handler {
141		return func(sesh ssh.Session) {
142			logger := handler.Logger
143			ctx := sesh.Context()
144
145			pubkey := utils.KeyForKeyText(sesh.PublicKey())
146			user, err := handler.DBPool.FindUserForKey(sesh.User(), pubkey)
147			if err != nil {
148				logger.Info("user not found", "err", err)
149			}
150
151			if user != nil {
152				logger = shared.LoggerWithUser(logger, user)
153			}
154
155			args := sesh.Command()
156
157			if len(args) == 0 {
158				wish.Println(sesh, helpStr(toSshCmd(handler.Cfg)))
159				next(sesh)
160				return
161			}
162
163			userName := "public"
164
165			userNameAddition := ""
166
167			isAdmin := false
168			if user != nil {
169				isAdmin = handler.DBPool.HasFeatureForUser(user.ID, "admin")
170
171				userName = user.Name
172				if user.PublicKey.Name != "" {
173					userNameAddition = fmt.Sprintf("-%s", user.PublicKey.Name)
174				}
175			}
176
177			pipeCtx, cancel := context.WithCancel(ctx)
178
179			go func() {
180				defer cancel()
181
182				for {
183					select {
184					case <-pipeCtx.Done():
185						return
186					default:
187						_, err := sesh.SendRequest("ping@pico.sh", false, nil)
188						if err != nil {
189							logger.Error("error sending ping", "err", err)
190							return
191						}
192
193						time.Sleep(5 * time.Second)
194					}
195				}
196			}()
197
198			cmd := strings.TrimSpace(args[0])
199			if cmd == "help" {
200				wish.Println(sesh, helpStr(toSshCmd(handler.Cfg)))
201				next(sesh)
202				return
203			} else if cmd == "ls" {
204				if userName == "public" {
205					wish.Fatalln(sesh, "access denied")
206					return
207				}
208
209				topicFilter := fmt.Sprintf("%s/", userName)
210				if isAdmin {
211					topicFilter = ""
212					if len(args) > 1 {
213						topicFilter = args[1]
214					}
215				}
216
217				var channels []*psub.Channel
218				waitingChannels := map[string][]string{}
219
220				for topic, channel := range pubsub.GetChannels() {
221					if strings.HasPrefix(topic, topicFilter) {
222						channels = append(channels, channel)
223					}
224				}
225
226				for channel, clients := range handler.Waiters.Range {
227					if strings.HasPrefix(channel, topicFilter) {
228						waitingChannels[channel] = clients
229					}
230				}
231
232				if len(channels) == 0 && len(waitingChannels) == 0 {
233					wish.Println(sesh, "no pubsub channels found")
234				} else {
235					var outputData string
236					if len(channels) > 0 || len(waitingChannels) > 0 {
237						outputData += "Channel Information\r\n"
238						for _, channel := range channels {
239							extraData := ""
240
241							if accessList, ok := handler.Access.Load(channel.Topic); ok && len(accessList) > 0 {
242								extraData += fmt.Sprintf(" (Access List: %s)", strings.Join(accessList, ", "))
243							}
244
245							outputData += fmt.Sprintf("- %s:%s\r\n", channel.Topic, extraData)
246							outputData += "  Clients:\r\n"
247
248							var pubs []*psub.Client
249							var subs []*psub.Client
250							var pipes []*psub.Client
251
252							for _, client := range channel.GetClients() {
253								if client.Direction == psub.ChannelDirectionInput {
254									pubs = append(pubs, client)
255								} else if client.Direction == psub.ChannelDirectionOutput {
256									subs = append(subs, client)
257								} else if client.Direction == psub.ChannelDirectionInputOutput {
258									pipes = append(pipes, client)
259								}
260							}
261							outputData += clientInfo(pubs, "Pubs")
262							outputData += clientInfo(subs, "Subs")
263							outputData += clientInfo(pipes, "Pipes")
264						}
265
266						for waitingChannel, channelPubs := range waitingChannels {
267							extraData := ""
268
269							if accessList, ok := handler.Access.Load(waitingChannel); ok && len(accessList) > 0 {
270								extraData += fmt.Sprintf(" (Access List: %s)", strings.Join(accessList, ", "))
271							}
272
273							outputData += fmt.Sprintf("- %s:%s\r\n", waitingChannel, extraData)
274							outputData += "  Clients:\r\n"
275							outputData += fmt.Sprintf("    %s:\r\n", "Waiting Pubs")
276							for _, client := range channelPubs {
277								outputData += fmt.Sprintf("    - %s\r\n", client)
278							}
279						}
280					}
281
282					_, _ = sesh.Write([]byte(outputData))
283				}
284
285				next(sesh)
286				return
287			}
288
289			topic := ""
290			cmdArgs := args[1:]
291			if len(args) > 1 && !strings.HasPrefix(args[1], "-") {
292				topic = strings.TrimSpace(args[1])
293				cmdArgs = args[2:]
294			}
295
296			logger.Info(
297				"pubsub middleware detected command",
298				"args", args,
299				"cmd", cmd,
300				"topic", topic,
301				"cmdArgs", cmdArgs,
302			)
303
304			clientID := fmt.Sprintf("%s (%s%s@%s)", uuid.NewString(), userName, userNameAddition, sesh.RemoteAddr().String())
305
306			if cmd == "pub" {
307				pubCmd := flagSet("pub", sesh)
308				access := pubCmd.String("a", "", "Comma separated list of pico usernames or ssh-key fingerprints to allow access to a topic")
309				empty := pubCmd.Bool("e", false, "Send an empty message to subs")
310				public := pubCmd.Bool("p", false, "Publish message to public topic")
311				block := pubCmd.Bool("b", true, "Block writes until a subscriber is available")
312				timeout := pubCmd.Duration("t", 30*24*time.Hour, "Timeout as a Go duration to block for a subscriber to be available. Valid time units are 'ns', 'us' (or 'µs'), 'ms', 's', 'm', 'h'. Default is 30 days.")
313				clean := pubCmd.Bool("c", false, "Don't send status messages")
314
315				if !flagCheck(pubCmd, topic, cmdArgs) {
316					return
317				}
318
319				if pubCmd.NArg() == 1 && topic == "" {
320					topic = pubCmd.Arg(0)
321				}
322
323				logger.Info(
324					"flags parsed",
325					"cmd", cmd,
326					"empty", *empty,
327					"public", *public,
328					"block", *block,
329					"timeout", *timeout,
330					"topic", topic,
331					"access", *access,
332					"clean", *clean,
333				)
334
335				var accessList []string
336
337				if *access != "" {
338					accessList = parseArgList(*access)
339				}
340
341				var rw io.ReadWriter
342				if *empty {
343					rw = bytes.NewBuffer(make([]byte, 1))
344				} else {
345					rw = sesh
346				}
347
348				if topic == "" {
349					topic = uuid.NewString()
350				}
351
352				var withoutUser string
353				var name string
354				msgFlag := ""
355
356				if isAdmin && strings.HasPrefix(topic, "/") {
357					name = strings.TrimPrefix(topic, "/")
358				} else {
359					name = toTopic(userName, topic)
360					if *public {
361						name = toPublicTopic(topic)
362						msgFlag = "-p "
363						withoutUser = name
364					} else {
365						withoutUser = topic
366					}
367				}
368
369				var accessListCreator bool
370
371				_, loaded := handler.Access.LoadOrStore(name, accessList)
372				if !loaded {
373					defer func() {
374						handler.Access.Delete(name)
375					}()
376
377					accessListCreator = true
378				}
379
380				if accessList, ok := handler.Access.Load(withoutUser); ok && len(accessList) > 0 && !isAdmin {
381					if checkAccess(accessList, userName, sesh) || accessListCreator {
382						name = withoutUser
383					} else if !*public {
384						name = toTopic(userName, withoutUser)
385					} else {
386						topic = uuid.NewString()
387						name = toPublicTopic(topic)
388					}
389				}
390
391				if !*clean {
392					wish.Printf(
393						sesh,
394						"subscribe to this channel:\n  ssh %s sub %s%s\n",
395						toSshCmd(handler.Cfg),
396						msgFlag,
397						topic,
398					)
399				}
400
401				if *block {
402					count := 0
403					for topic, channel := range pubsub.GetChannels() {
404						if topic == name {
405							for _, client := range channel.GetClients() {
406								if client.Direction == psub.ChannelDirectionOutput || client.Direction == psub.ChannelDirectionInputOutput {
407									count++
408								}
409							}
410							break
411						}
412					}
413
414					tt := *timeout
415					if count == 0 {
416						currentWaiters, _ := handler.Waiters.LoadOrStore(name, nil)
417						handler.Waiters.Store(name, append(currentWaiters, clientID))
418
419						termMsg := "no subs found ... waiting"
420						if tt > 0 {
421							termMsg += " " + tt.String()
422						}
423
424						if !*clean {
425							wish.Println(sesh, termMsg)
426						}
427
428						ready := make(chan struct{})
429
430						go func() {
431							for {
432								select {
433								case <-pipeCtx.Done():
434									cancel()
435									return
436								case <-time.After(1 * time.Millisecond):
437									count := 0
438									for topic, channel := range pubsub.GetChannels() {
439										if topic == name {
440											for _, client := range channel.GetClients() {
441												if client.Direction == psub.ChannelDirectionOutput || client.Direction == psub.ChannelDirectionInputOutput {
442													count++
443												}
444											}
445											break
446										}
447									}
448
449									if count > 0 {
450										close(ready)
451										return
452									}
453								}
454							}
455						}()
456
457						select {
458						case <-ready:
459						case <-pipeCtx.Done():
460						case <-time.After(tt):
461							cancel()
462
463							if !*clean {
464								wish.Fatalln(sesh, "timeout reached, exiting ...")
465							} else {
466								err = sesh.Exit(1)
467								if err != nil {
468									logger.Error("error exiting session", "err", err)
469								}
470
471								sesh.Close()
472							}
473						}
474
475						newWaiters, _ := handler.Waiters.LoadOrStore(name, nil)
476						newWaiters = slices.DeleteFunc(newWaiters, func(cl string) bool {
477							return cl == clientID
478						})
479						handler.Waiters.Store(name, newWaiters)
480
481						var toDelete []string
482
483						for channel, clients := range handler.Waiters.Range {
484							if len(clients) == 0 {
485								toDelete = append(toDelete, channel)
486							}
487						}
488
489						for _, channel := range toDelete {
490							handler.Waiters.Delete(channel)
491						}
492					}
493				}
494
495				if !*clean {
496					wish.Println(sesh, "sending msg ...")
497				}
498
499				err = pubsub.Pub(
500					pipeCtx,
501					clientID,
502					rw,
503					[]*psub.Channel{
504						psub.NewChannel(name),
505					},
506					*block,
507				)
508
509				if !*clean {
510					wish.Println(sesh, "msg sent!")
511				}
512
513				if err != nil && !*clean {
514					wish.Errorln(sesh, err)
515				}
516			} else if cmd == "sub" {
517				subCmd := flagSet("sub", sesh)
518				access := subCmd.String("a", "", "Comma separated list of pico usernames or ssh-key fingerprints to allow access to a topic")
519				public := subCmd.Bool("p", false, "Subscribe to a public topic")
520				keepAlive := subCmd.Bool("k", false, "Keep the subscription alive even after the publisher has died")
521				clean := subCmd.Bool("c", false, "Don't send status messages")
522
523				if !flagCheck(subCmd, topic, cmdArgs) {
524					return
525				}
526
527				if subCmd.NArg() == 1 && topic == "" {
528					topic = subCmd.Arg(0)
529				}
530
531				logger.Info(
532					"flags parsed",
533					"cmd", cmd,
534					"public", *public,
535					"keepAlive", *keepAlive,
536					"topic", topic,
537					"clean", *clean,
538					"access", *access,
539				)
540
541				var accessList []string
542
543				if *access != "" {
544					accessList = parseArgList(*access)
545				}
546
547				var withoutUser string
548				var name string
549
550				if isAdmin && strings.HasPrefix(topic, "/") {
551					name = strings.TrimPrefix(topic, "/")
552				} else {
553					name = toTopic(userName, topic)
554					if *public {
555						name = toPublicTopic(topic)
556						withoutUser = name
557					} else {
558						withoutUser = topic
559					}
560				}
561
562				var accessListCreator bool
563
564				_, loaded := handler.Access.LoadOrStore(name, accessList)
565				if !loaded {
566					defer func() {
567						handler.Access.Delete(name)
568					}()
569
570					accessListCreator = true
571				}
572
573				if accessList, ok := handler.Access.Load(withoutUser); ok && len(accessList) > 0 && !isAdmin {
574					if checkAccess(accessList, userName, sesh) || accessListCreator {
575						name = withoutUser
576					} else if !*public {
577						name = toTopic(userName, withoutUser)
578					} else {
579						wish.Errorln(sesh, "access denied")
580						return
581					}
582				}
583
584				err = pubsub.Sub(
585					pipeCtx,
586					clientID,
587					sesh,
588					[]*psub.Channel{
589						psub.NewChannel(name),
590					},
591					*keepAlive,
592				)
593
594				if err != nil && !*clean {
595					wish.Errorln(sesh, err)
596				}
597			} else if cmd == "pipe" {
598				pipeCmd := flagSet("pipe", sesh)
599				access := pipeCmd.String("a", "", "Comma separated list of pico usernames or ssh-key fingerprints to allow access to a topic")
600				public := pipeCmd.Bool("p", false, "Pipe to a public topic")
601				replay := pipeCmd.Bool("r", false, "Replay messages to the client that sent it")
602				clean := pipeCmd.Bool("c", false, "Don't send status messages")
603
604				if !flagCheck(pipeCmd, topic, cmdArgs) {
605					return
606				}
607
608				if pipeCmd.NArg() == 1 && topic == "" {
609					topic = pipeCmd.Arg(0)
610				}
611
612				logger.Info(
613					"flags parsed",
614					"cmd", cmd,
615					"public", *public,
616					"replay", *replay,
617					"topic", topic,
618					"access", *access,
619					"clean", *clean,
620				)
621
622				var accessList []string
623
624				if *access != "" {
625					accessList = parseArgList(*access)
626				}
627
628				isCreator := topic == ""
629				if isCreator {
630					topic = uuid.NewString()
631				}
632
633				var withoutUser string
634				var name string
635				flagMsg := ""
636
637				if isAdmin && strings.HasPrefix(topic, "/") {
638					name = strings.TrimPrefix(topic, "/")
639				} else {
640					name = toTopic(userName, topic)
641					if *public {
642						name = toPublicTopic(topic)
643						flagMsg = "-p "
644						withoutUser = name
645					} else {
646						withoutUser = topic
647					}
648				}
649
650				var accessListCreator bool
651
652				_, loaded := handler.Access.LoadOrStore(name, accessList)
653				if !loaded {
654					defer func() {
655						handler.Access.Delete(name)
656					}()
657
658					accessListCreator = true
659				}
660
661				if accessList, ok := handler.Access.Load(withoutUser); ok && len(accessList) > 0 && !isAdmin {
662					if checkAccess(accessList, userName, sesh) || accessListCreator {
663						name = withoutUser
664					} else if !*public {
665						name = toTopic(userName, withoutUser)
666					} else {
667						topic = uuid.NewString()
668						name = toPublicTopic(topic)
669					}
670				}
671
672				if isCreator && !*clean {
673					wish.Printf(
674						sesh,
675						"subscribe to this topic:\n  ssh %s sub %s%s\n",
676						toSshCmd(handler.Cfg),
677						flagMsg,
678						topic,
679					)
680				}
681
682				readErr, writeErr := pubsub.Pipe(
683					pipeCtx,
684					clientID,
685					sesh,
686					[]*psub.Channel{
687						psub.NewChannel(name),
688					},
689					*replay,
690				)
691
692				if readErr != nil && !*clean {
693					wish.Errorln(sesh, "error reading from pipe", readErr)
694				}
695
696				if writeErr != nil && !*clean {
697					wish.Errorln(sesh, "error writing to pipe", writeErr)
698				}
699			}
700
701			next(sesh)
702		}
703	}
704}