repos / pico

pico services - prose.sh, pastes.sh, imgs.sh, feeds.sh, pgs.sh
git clone https://github.com/picosh/pico.git

pico / pipe
Antonio Mika · 22 Nov 24

cli.go

  1package pipe
  2
  3import (
  4	"bytes"
  5	"context"
  6	"flag"
  7	"fmt"
  8	"io"
  9	"log/slog"
 10	"slices"
 11	"strings"
 12	"text/tabwriter"
 13	"time"
 14
 15	"github.com/antoniomika/syncmap"
 16	"github.com/charmbracelet/ssh"
 17	"github.com/charmbracelet/wish"
 18	"github.com/google/uuid"
 19	"github.com/picosh/pico/db"
 20	"github.com/picosh/pico/shared"
 21	psub "github.com/picosh/pubsub"
 22	gossh "golang.org/x/crypto/ssh"
 23)
 24
 25func flagSet(cmdName string, sesh ssh.Session) *flag.FlagSet {
 26	cmd := flag.NewFlagSet(cmdName, flag.ContinueOnError)
 27	cmd.SetOutput(sesh)
 28	cmd.Usage = func() {
 29		fmt.Fprintf(cmd.Output(), "Usage: %s <topic> [args...]\nArgs:\n", cmdName)
 30		cmd.PrintDefaults()
 31	}
 32	return cmd
 33}
 34
 35func flagCheck(cmd *flag.FlagSet, posArg string, cmdArgs []string) bool {
 36	err := cmd.Parse(cmdArgs)
 37
 38	if err != nil || posArg == "help" {
 39		if posArg == "help" {
 40			cmd.Usage()
 41		}
 42		return false
 43	}
 44	return true
 45}
 46
 47func NewTabWriter(out io.Writer) *tabwriter.Writer {
 48	return tabwriter.NewWriter(out, 0, 0, 1, ' ', tabwriter.TabIndent)
 49}
 50
 51// scope topic to user by prefixing name.
 52func toTopic(userName, topic string) string {
 53	return fmt.Sprintf("%s/%s", userName, topic)
 54}
 55
 56func toPublicTopic(topic string) string {
 57	return fmt.Sprintf("public/%s", topic)
 58}
 59
 60func clientInfo(clients []*psub.Client, clientType string) string {
 61	if len(clients) == 0 {
 62		return ""
 63	}
 64
 65	outputData := fmt.Sprintf("    %s:\r\n", clientType)
 66
 67	for _, client := range clients {
 68		outputData += fmt.Sprintf("    - %s\r\n", client.ID)
 69	}
 70
 71	return outputData
 72}
 73
 74var helpStr = func(sshCmd string) string {
 75	return fmt.Sprintf(`Command: ssh %s <help | ls | pub | sub | pipe> <topic> [-h | args...]
 76
 77The simplest authenticated pubsub system.  Send messages through
 78user-defined topics.  Topics are private to the authenticated
 79ssh user.  The default pubsub model is multicast with bidirectional
 80blocking, meaning a publisher ("pub") will send its message to all
 81subscribers ("sub").  Further, both "pub" and "sub" will wait for
 82at least one event to be sent or received. Pipe ("pipe") allows
 83for bidirectional messages to be sent between any clients connected
 84to a pipe.
 85
 86Think of these different commands in terms of the direction the
 87data is being sent:
 88
 89- pub => writes to client
 90- sub => reads from client
 91- pipe => read and write between clients
 92`, sshCmd)
 93}
 94
 95type CliHandler struct {
 96	DBPool  db.DB
 97	Logger  *slog.Logger
 98	PubSub  psub.PubSub
 99	Cfg     *shared.ConfigSite
100	Waiters *syncmap.Map[string, []string]
101	Access  *syncmap.Map[string, []string]
102}
103
104func toSshCmd(cfg *shared.ConfigSite) string {
105	port := ""
106	if cfg.PortOverride != "22" {
107		port = fmt.Sprintf("-p %s ", cfg.PortOverride)
108	}
109	return fmt.Sprintf("%s%s", port, cfg.Domain)
110}
111
// parseArgList splits a comma-separated flag value into its entries and
// trims surrounding whitespace from each. Empty entries are preserved as "".
func parseArgList(arg string) []string {
	parts := strings.Split(arg, ",")
	entries := make([]string, 0, len(parts))
	for _, part := range parts {
		entries = append(entries, strings.TrimSpace(part))
	}
	return entries
}
120
121// checkAccess checks if the user has access to a topic based on an access list.
122func checkAccess(accessList []string, userName string, sesh ssh.Session) bool {
123	for _, acc := range accessList {
124		if acc == userName {
125			return true
126		}
127
128		if key := sesh.PublicKey(); key != nil && acc == gossh.FingerprintSHA256(key) {
129			return true
130		}
131	}
132
133	return false
134}
135
136func WishMiddleware(handler *CliHandler) wish.Middleware {
137	pubsub := handler.PubSub
138
139	return func(next ssh.Handler) ssh.Handler {
140		return func(sesh ssh.Session) {
141			logger := handler.Logger
142			ctx := sesh.Context()
143
144			user, err := shared.GetUser(ctx)
145			if err != nil {
146				logger.Info("user not found", "err", err)
147			}
148
149			if user != nil {
150				logger = shared.LoggerWithUser(logger, user)
151			}
152
153			args := sesh.Command()
154
155			if len(args) == 0 {
156				wish.Println(sesh, helpStr(toSshCmd(handler.Cfg)))
157				next(sesh)
158				return
159			}
160
161			userName := "public"
162
163			userNameAddition := ""
164
165			isAdmin := false
166			if user != nil {
167				isAdmin = handler.DBPool.HasFeatureForUser(user.ID, "admin")
168
169				userName = user.Name
170				if user.PublicKey.Name != "" {
171					userNameAddition = fmt.Sprintf("-%s", user.PublicKey.Name)
172				}
173			}
174
175			pipeCtx, cancel := context.WithCancel(ctx)
176
177			go func() {
178				defer cancel()
179
180				for {
181					select {
182					case <-pipeCtx.Done():
183						return
184					default:
185						_, err := sesh.SendRequest("ping@pico.sh", false, nil)
186						if err != nil {
187							logger.Error("error sending ping", "err", err)
188							return
189						}
190
191						time.Sleep(5 * time.Second)
192					}
193				}
194			}()
195
196			cmd := strings.TrimSpace(args[0])
197			if cmd == "help" {
198				wish.Println(sesh, helpStr(toSshCmd(handler.Cfg)))
199				next(sesh)
200				return
201			} else if cmd == "ls" {
202				if userName == "public" {
203					wish.Fatalln(sesh, "access denied")
204					return
205				}
206
207				topicFilter := fmt.Sprintf("%s/", userName)
208				if isAdmin {
209					topicFilter = ""
210					if len(args) > 1 {
211						topicFilter = args[1]
212					}
213				}
214
215				var channels []*psub.Channel
216				waitingChannels := map[string][]string{}
217
218				for topic, channel := range pubsub.GetChannels() {
219					if strings.HasPrefix(topic, topicFilter) {
220						channels = append(channels, channel)
221					}
222				}
223
224				for channel, clients := range handler.Waiters.Range {
225					if strings.HasPrefix(channel, topicFilter) {
226						waitingChannels[channel] = clients
227					}
228				}
229
230				if len(channels) == 0 && len(waitingChannels) == 0 {
231					wish.Println(sesh, "no pubsub channels found")
232				} else {
233					var outputData string
234					if len(channels) > 0 || len(waitingChannels) > 0 {
235						outputData += "Channel Information\r\n"
236						for _, channel := range channels {
237							extraData := ""
238
239							if accessList, ok := handler.Access.Load(channel.Topic); ok && len(accessList) > 0 {
240								extraData += fmt.Sprintf(" (Access List: %s)", strings.Join(accessList, ", "))
241							}
242
243							outputData += fmt.Sprintf("- %s:%s\r\n", channel.Topic, extraData)
244							outputData += "  Clients:\r\n"
245
246							var pubs []*psub.Client
247							var subs []*psub.Client
248							var pipes []*psub.Client
249
250							for _, client := range channel.GetClients() {
251								if client.Direction == psub.ChannelDirectionInput {
252									pubs = append(pubs, client)
253								} else if client.Direction == psub.ChannelDirectionOutput {
254									subs = append(subs, client)
255								} else if client.Direction == psub.ChannelDirectionInputOutput {
256									pipes = append(pipes, client)
257								}
258							}
259							outputData += clientInfo(pubs, "Pubs")
260							outputData += clientInfo(subs, "Subs")
261							outputData += clientInfo(pipes, "Pipes")
262						}
263
264						for waitingChannel, channelPubs := range waitingChannels {
265							extraData := ""
266
267							if accessList, ok := handler.Access.Load(waitingChannel); ok && len(accessList) > 0 {
268								extraData += fmt.Sprintf(" (Access List: %s)", strings.Join(accessList, ", "))
269							}
270
271							outputData += fmt.Sprintf("- %s:%s\r\n", waitingChannel, extraData)
272							outputData += "  Clients:\r\n"
273							outputData += fmt.Sprintf("    %s:\r\n", "Waiting Pubs")
274							for _, client := range channelPubs {
275								outputData += fmt.Sprintf("    - %s\r\n", client)
276							}
277						}
278					}
279
280					_, _ = sesh.Write([]byte(outputData))
281				}
282
283				next(sesh)
284				return
285			}
286
287			topic := ""
288			cmdArgs := args[1:]
289			if len(args) > 1 && !strings.HasPrefix(args[1], "-") {
290				topic = strings.TrimSpace(args[1])
291				cmdArgs = args[2:]
292			}
293
294			logger.Info(
295				"pubsub middleware detected command",
296				"args", args,
297				"cmd", cmd,
298				"topic", topic,
299				"cmdArgs", cmdArgs,
300			)
301
302			clientID := fmt.Sprintf("%s (%s%s@%s)", uuid.NewString(), userName, userNameAddition, sesh.RemoteAddr().String())
303
304			if cmd == "pub" {
305				pubCmd := flagSet("pub", sesh)
306				access := pubCmd.String("a", "", "Comma separated list of pico usernames or ssh-key fingerprints to allow access to a topic")
307				empty := pubCmd.Bool("e", false, "Send an empty message to subs")
308				public := pubCmd.Bool("p", false, "Publish message to public topic")
309				block := pubCmd.Bool("b", true, "Block writes until a subscriber is available")
310				timeout := pubCmd.Duration("t", 30*24*time.Hour, "Timeout as a Go duration to block for a subscriber to be available. Valid time units are 'ns', 'us' (or 'µs'), 'ms', 's', 'm', 'h'. Default is 30 days.")
311				clean := pubCmd.Bool("c", false, "Don't send status messages")
312
313				if !flagCheck(pubCmd, topic, cmdArgs) {
314					return
315				}
316
317				if pubCmd.NArg() == 1 && topic == "" {
318					topic = pubCmd.Arg(0)
319				}
320
321				logger.Info(
322					"flags parsed",
323					"cmd", cmd,
324					"empty", *empty,
325					"public", *public,
326					"block", *block,
327					"timeout", *timeout,
328					"topic", topic,
329					"access", *access,
330					"clean", *clean,
331				)
332
333				var accessList []string
334
335				if *access != "" {
336					accessList = parseArgList(*access)
337				}
338
339				var rw io.ReadWriter
340				if *empty {
341					rw = bytes.NewBuffer(make([]byte, 1))
342				} else {
343					rw = sesh
344				}
345
346				if topic == "" {
347					topic = uuid.NewString()
348				}
349
350				var withoutUser string
351				var name string
352				msgFlag := ""
353
354				if isAdmin && strings.HasPrefix(topic, "/") {
355					name = strings.TrimPrefix(topic, "/")
356				} else {
357					name = toTopic(userName, topic)
358					if *public {
359						name = toPublicTopic(topic)
360						msgFlag = "-p "
361						withoutUser = name
362					} else {
363						withoutUser = topic
364					}
365				}
366
367				var accessListCreator bool
368
369				_, loaded := handler.Access.LoadOrStore(name, accessList)
370				if !loaded {
371					defer func() {
372						handler.Access.Delete(name)
373					}()
374
375					accessListCreator = true
376				}
377
378				if accessList, ok := handler.Access.Load(withoutUser); ok && len(accessList) > 0 && !isAdmin {
379					if checkAccess(accessList, userName, sesh) || accessListCreator {
380						name = withoutUser
381					} else if !*public {
382						name = toTopic(userName, withoutUser)
383					} else {
384						topic = uuid.NewString()
385						name = toPublicTopic(topic)
386					}
387				}
388
389				if !*clean {
390					wish.Printf(
391						sesh,
392						"subscribe to this channel:\n  ssh %s sub %s%s\n",
393						toSshCmd(handler.Cfg),
394						msgFlag,
395						topic,
396					)
397				}
398
399				if *block {
400					count := 0
401					for topic, channel := range pubsub.GetChannels() {
402						if topic == name {
403							for _, client := range channel.GetClients() {
404								if client.Direction == psub.ChannelDirectionOutput || client.Direction == psub.ChannelDirectionInputOutput {
405									count++
406								}
407							}
408							break
409						}
410					}
411
412					tt := *timeout
413					if count == 0 {
414						currentWaiters, _ := handler.Waiters.LoadOrStore(name, nil)
415						handler.Waiters.Store(name, append(currentWaiters, clientID))
416
417						termMsg := "no subs found ... waiting"
418						if tt > 0 {
419							termMsg += " " + tt.String()
420						}
421
422						if !*clean {
423							wish.Println(sesh, termMsg)
424						}
425
426						ready := make(chan struct{})
427
428						go func() {
429							for {
430								select {
431								case <-pipeCtx.Done():
432									cancel()
433									return
434								case <-time.After(1 * time.Millisecond):
435									count := 0
436									for topic, channel := range pubsub.GetChannels() {
437										if topic == name {
438											for _, client := range channel.GetClients() {
439												if client.Direction == psub.ChannelDirectionOutput || client.Direction == psub.ChannelDirectionInputOutput {
440													count++
441												}
442											}
443											break
444										}
445									}
446
447									if count > 0 {
448										close(ready)
449										return
450									}
451								}
452							}
453						}()
454
455						select {
456						case <-ready:
457						case <-pipeCtx.Done():
458						case <-time.After(tt):
459							cancel()
460
461							if !*clean {
462								wish.Fatalln(sesh, "timeout reached, exiting ...")
463							} else {
464								err = sesh.Exit(1)
465								if err != nil {
466									logger.Error("error exiting session", "err", err)
467								}
468
469								sesh.Close()
470							}
471						}
472
473						newWaiters, _ := handler.Waiters.LoadOrStore(name, nil)
474						newWaiters = slices.DeleteFunc(newWaiters, func(cl string) bool {
475							return cl == clientID
476						})
477						handler.Waiters.Store(name, newWaiters)
478
479						var toDelete []string
480
481						for channel, clients := range handler.Waiters.Range {
482							if len(clients) == 0 {
483								toDelete = append(toDelete, channel)
484							}
485						}
486
487						for _, channel := range toDelete {
488							handler.Waiters.Delete(channel)
489						}
490					}
491				}
492
493				if !*clean {
494					wish.Println(sesh, "sending msg ...")
495				}
496
497				err = pubsub.Pub(
498					pipeCtx,
499					clientID,
500					rw,
501					[]*psub.Channel{
502						psub.NewChannel(name),
503					},
504					*block,
505				)
506
507				if !*clean {
508					wish.Println(sesh, "msg sent!")
509				}
510
511				if err != nil && !*clean {
512					wish.Errorln(sesh, err)
513				}
514			} else if cmd == "sub" {
515				subCmd := flagSet("sub", sesh)
516				access := subCmd.String("a", "", "Comma separated list of pico usernames or ssh-key fingerprints to allow access to a topic")
517				public := subCmd.Bool("p", false, "Subscribe to a public topic")
518				keepAlive := subCmd.Bool("k", false, "Keep the subscription alive even after the publisher has died")
519				clean := subCmd.Bool("c", false, "Don't send status messages")
520
521				if !flagCheck(subCmd, topic, cmdArgs) {
522					return
523				}
524
525				if subCmd.NArg() == 1 && topic == "" {
526					topic = subCmd.Arg(0)
527				}
528
529				logger.Info(
530					"flags parsed",
531					"cmd", cmd,
532					"public", *public,
533					"keepAlive", *keepAlive,
534					"topic", topic,
535					"clean", *clean,
536					"access", *access,
537				)
538
539				var accessList []string
540
541				if *access != "" {
542					accessList = parseArgList(*access)
543				}
544
545				var withoutUser string
546				var name string
547
548				if isAdmin && strings.HasPrefix(topic, "/") {
549					name = strings.TrimPrefix(topic, "/")
550				} else {
551					name = toTopic(userName, topic)
552					if *public {
553						name = toPublicTopic(topic)
554						withoutUser = name
555					} else {
556						withoutUser = topic
557					}
558				}
559
560				var accessListCreator bool
561
562				_, loaded := handler.Access.LoadOrStore(name, accessList)
563				if !loaded {
564					defer func() {
565						handler.Access.Delete(name)
566					}()
567
568					accessListCreator = true
569				}
570
571				if accessList, ok := handler.Access.Load(withoutUser); ok && len(accessList) > 0 && !isAdmin {
572					if checkAccess(accessList, userName, sesh) || accessListCreator {
573						name = withoutUser
574					} else if !*public {
575						name = toTopic(userName, withoutUser)
576					} else {
577						wish.Errorln(sesh, "access denied")
578						return
579					}
580				}
581
582				err = pubsub.Sub(
583					pipeCtx,
584					clientID,
585					sesh,
586					[]*psub.Channel{
587						psub.NewChannel(name),
588					},
589					*keepAlive,
590				)
591
592				if err != nil && !*clean {
593					wish.Errorln(sesh, err)
594				}
595			} else if cmd == "pipe" {
596				pipeCmd := flagSet("pipe", sesh)
597				access := pipeCmd.String("a", "", "Comma separated list of pico usernames or ssh-key fingerprints to allow access to a topic")
598				public := pipeCmd.Bool("p", false, "Pipe to a public topic")
599				replay := pipeCmd.Bool("r", false, "Replay messages to the client that sent it")
600				clean := pipeCmd.Bool("c", false, "Don't send status messages")
601
602				if !flagCheck(pipeCmd, topic, cmdArgs) {
603					return
604				}
605
606				if pipeCmd.NArg() == 1 && topic == "" {
607					topic = pipeCmd.Arg(0)
608				}
609
610				logger.Info(
611					"flags parsed",
612					"cmd", cmd,
613					"public", *public,
614					"replay", *replay,
615					"topic", topic,
616					"access", *access,
617					"clean", *clean,
618				)
619
620				var accessList []string
621
622				if *access != "" {
623					accessList = parseArgList(*access)
624				}
625
626				isCreator := topic == ""
627				if isCreator {
628					topic = uuid.NewString()
629				}
630
631				var withoutUser string
632				var name string
633				flagMsg := ""
634
635				if isAdmin && strings.HasPrefix(topic, "/") {
636					name = strings.TrimPrefix(topic, "/")
637				} else {
638					name = toTopic(userName, topic)
639					if *public {
640						name = toPublicTopic(topic)
641						flagMsg = "-p "
642						withoutUser = name
643					} else {
644						withoutUser = topic
645					}
646				}
647
648				var accessListCreator bool
649
650				_, loaded := handler.Access.LoadOrStore(name, accessList)
651				if !loaded {
652					defer func() {
653						handler.Access.Delete(name)
654					}()
655
656					accessListCreator = true
657				}
658
659				if accessList, ok := handler.Access.Load(withoutUser); ok && len(accessList) > 0 && !isAdmin {
660					if checkAccess(accessList, userName, sesh) || accessListCreator {
661						name = withoutUser
662					} else if !*public {
663						name = toTopic(userName, withoutUser)
664					} else {
665						topic = uuid.NewString()
666						name = toPublicTopic(topic)
667					}
668				}
669
670				if isCreator && !*clean {
671					wish.Printf(
672						sesh,
673						"subscribe to this topic:\n  ssh %s sub %s%s\n",
674						toSshCmd(handler.Cfg),
675						flagMsg,
676						topic,
677					)
678				}
679
680				readErr, writeErr := pubsub.Pipe(
681					pipeCtx,
682					clientID,
683					sesh,
684					[]*psub.Channel{
685						psub.NewChannel(name),
686					},
687					*replay,
688				)
689
690				if readErr != nil && !*clean {
691					wish.Errorln(sesh, "error reading from pipe", readErr)
692				}
693
694				if writeErr != nil && !*clean {
695					wish.Errorln(sesh, "error writing to pipe", writeErr)
696				}
697			}
698
699			next(sesh)
700		}
701	}
702}