repos / pico

pico services - prose.sh, pastes.sh, imgs.sh, feeds.sh, pgs.sh
git clone https://github.com/picosh/pico.git

pico / pubsub
Eric Bower · 24 Sep 24

cli.go

  1package pubsub
  2
  3import (
  4	"bytes"
  5	"flag"
  6	"fmt"
  7	"io"
  8	"log/slog"
  9	"strings"
 10	"text/tabwriter"
 11	"time"
 12
 13	"github.com/charmbracelet/ssh"
 14	"github.com/charmbracelet/wish"
 15	"github.com/google/uuid"
 16	"github.com/picosh/pico/db"
 17	"github.com/picosh/pico/shared"
 18	psub "github.com/picosh/pubsub"
 19	"github.com/picosh/send/send/utils"
 20)
 21
 22func flagSet(cmdName string, sesh ssh.Session) *flag.FlagSet {
 23	cmd := flag.NewFlagSet(cmdName, flag.ContinueOnError)
 24	cmd.SetOutput(sesh)
 25	return cmd
 26}
 27
 28func flagCheck(cmd *flag.FlagSet, posArg string, cmdArgs []string) bool {
 29	_ = cmd.Parse(cmdArgs)
 30
 31	if posArg == "-h" || posArg == "--help" || posArg == "-help" {
 32		cmd.Usage()
 33		return false
 34	}
 35	return true
 36}
 37
 38func NewTabWriter(out io.Writer) *tabwriter.Writer {
 39	return tabwriter.NewWriter(out, 0, 0, 1, ' ', tabwriter.TabIndent)
 40}
 41
 42func getUser(s ssh.Session, dbpool db.DB) (*db.User, error) {
 43	var err error
 44	key, err := shared.KeyText(s)
 45	if err != nil {
 46		return nil, fmt.Errorf("key not found")
 47	}
 48
 49	user, err := dbpool.FindUserForKey(s.User(), key)
 50	if err != nil {
 51		return nil, err
 52	}
 53
 54	if user.Name == "" {
 55		return nil, fmt.Errorf("must have username set")
 56	}
 57
 58	return user, nil
 59}
 60
 61// scope channel to user by prefixing name.
 62func toChannel(userName, name string) string {
 63	return fmt.Sprintf("%s/%s", userName, name)
 64}
 65
 66func toPublicChannel(name string) string {
 67	return fmt.Sprintf("public/%s", name)
 68}
 69
 70var helpStr = `Commands: [pub, sub, ls, pipe]
 71
 72The simplest authenticated pubsub system.  Send messages through
 73user-defined channels.  Channels are private to the authenticated
 74ssh user.  The default pubsub model is multicast with bidirectional
 75blocking, meaning a publisher ("pub") will send its message to all
 76subscribers ("sub").  Further, both "pub" and "sub" will wait for
 77at least one event to be sent or received. Pipe ("pipe") allows
 78for bidirectional messages to be sent between any clients connected
 79to a pipe.`
 80
 81type CliHandler struct {
 82	DBPool db.DB
 83	Logger *slog.Logger
 84	PubSub *psub.Cfg
 85	Cfg    *shared.ConfigSite
 86}
 87
 88func toSshCmd(cfg *shared.ConfigSite) string {
 89	port := "22"
 90	if cfg.Port != "" {
 91		port = fmt.Sprintf("-p %s", cfg.Port)
 92	}
 93	return fmt.Sprintf("%s %s", port, cfg.Domain)
 94}
 95
 96func WishMiddleware(handler *CliHandler) wish.Middleware {
 97	dbpool := handler.DBPool
 98	pubsub := handler.PubSub
 99
100	return func(next ssh.Handler) ssh.Handler {
101		return func(sesh ssh.Session) {
102			logger := handler.Logger
103			ctx := sesh.Context()
104			user, err := getUser(sesh, dbpool)
105			if err != nil {
106				utils.ErrorHandler(sesh, err)
107				return
108			}
109
110			logger = shared.LoggerWithUser(logger, user)
111
112			args := sesh.Command()
113
114			if len(args) == 0 {
115				wish.Println(sesh, helpStr)
116				next(sesh)
117				return
118			}
119
120			cmd := strings.TrimSpace(args[0])
121			if cmd == "help" {
122				wish.Println(sesh, helpStr)
123			} else if cmd == "ls" {
124				channelFilter := fmt.Sprintf("%s/", user.Name)
125				if handler.DBPool.HasFeatureForUser(user.ID, "admin") {
126					channelFilter = ""
127				}
128
129				channels := pubsub.PubSub.GetChannels(channelFilter)
130				pipes := pubsub.PubSub.GetPipes(channelFilter)
131
132				if len(channels) == 0 && len(pipes) == 0 {
133					wish.Println(sesh, "no pubsub channels or pipes found")
134				} else {
135					var outputData string
136					if len(channels) > 0 {
137						outputData += "Channel Information\r\n"
138						for _, channel := range channels {
139							outputData += fmt.Sprintf("- %s:\r\n", channel.Name)
140							outputData += "\tPubs:\r\n"
141
142							channel.Pubs.Range(func(I string, J *psub.Pub) bool {
143								outputData += fmt.Sprintf("\t- %s:\r\n", I)
144								return true
145							})
146
147							outputData += "\tSubs:\r\n"
148
149							channel.Subs.Range(func(I string, J *psub.Sub) bool {
150								outputData += fmt.Sprintf("\t- %s:\r\n", I)
151								return true
152							})
153						}
154					}
155
156					if len(pipes) > 0 {
157						outputData += "Pipe Information\r\n"
158						for _, pipe := range pipes {
159							outputData += fmt.Sprintf("- %s:\r\n", pipe.Name)
160							outputData += "\tClients:\r\n"
161
162							pipe.Clients.Range(func(I string, J *psub.PipeClient) bool {
163								outputData += fmt.Sprintf("\t- %s:\r\n", I)
164								return true
165							})
166						}
167					}
168
169					_, _ = sesh.Write([]byte(outputData))
170				}
171			}
172
173			channelName := ""
174			cmdArgs := args[1:]
175			if len(args) > 1 {
176				channelName = strings.TrimSpace(args[1])
177				cmdArgs = args[2:]
178			}
179			logger.Info(
180				"imgs middleware detected command",
181				"args", args,
182				"cmd", cmd,
183				"channelName", channelName,
184				"cmdArgs", cmdArgs,
185			)
186
187			if cmd == "pub" {
188				defaultTimeout, _ := time.ParseDuration("720h")
189
190				pubCmd := flagSet("pub", sesh)
191				empty := pubCmd.Bool("e", false, "Send an empty message to subs")
192				public := pubCmd.Bool("p", false, "Anyone can sub to this channel")
193				timeout := pubCmd.Duration("t", defaultTimeout, "Timeout as a Go duration before cancelling the pub event. Valid time units are 'ns', 'us' (or 'µs'), 'ms', 's', 'm', 'h'. Default is 30 days.")
194				if !flagCheck(pubCmd, channelName, cmdArgs) {
195					return
196				}
197
198				var reader io.Reader
199				if *empty {
200					reader = bytes.NewReader(make([]byte, 1))
201				} else {
202					reader = sesh
203				}
204
205				if channelName == "" {
206					channelName = uuid.NewString()
207				}
208				name := toChannel(user.Name, channelName)
209				if *public {
210					name = toPublicChannel(channelName)
211				}
212				wish.Printf(
213					sesh,
214					"subscribe to this channel:\n\tssh %s sub %s\n",
215					toSshCmd(handler.Cfg),
216					channelName,
217				)
218
219				wish.Println(sesh, "sending msg ...")
220				pub := &psub.Pub{
221					ID:     fmt.Sprintf("%s (%s@%s)", uuid.NewString(), user.Name, sesh.RemoteAddr().String()),
222					Done:   make(chan struct{}),
223					Reader: reader,
224				}
225
226				count := 0
227				channelInfo := pubsub.PubSub.GetChannel(name)
228
229				if channelInfo != nil {
230					channelInfo.Subs.Range(func(I string, J *psub.Sub) bool {
231						count++
232						return true
233					})
234				}
235
236				tt := *timeout
237				if count == 0 {
238					str := "no subs found ... waiting"
239					if tt > 0 {
240						str += " " + tt.String()
241					}
242					wish.Println(sesh, str)
243				}
244
245				go func() {
246					select {
247					case <-ctx.Done():
248						pub.Cleanup()
249					case <-time.After(tt):
250						wish.Fatalln(sesh, "timeout reached, exiting ...")
251						pub.Cleanup()
252					}
253				}()
254
255				err = pubsub.PubSub.Pub(name, pub)
256				wish.Println(sesh, "msg sent!")
257				if err != nil {
258					wish.Errorln(sesh, err)
259				}
260			} else if cmd == "sub" {
261				pubCmd := flagSet("pub", sesh)
262				public := pubCmd.Bool("p", false, "Subscribe to a public channel")
263				if !flagCheck(pubCmd, channelName, cmdArgs) {
264					return
265				}
266				channelName := channelName
267
268				name := toChannel(user.Name, channelName)
269				if *public {
270					name = toPublicChannel(channelName)
271				}
272
273				sub := &psub.Sub{
274					ID:     fmt.Sprintf("%s (%s@%s)", uuid.NewString(), user.Name, sesh.RemoteAddr().String()),
275					Writer: sesh,
276					Done:   make(chan struct{}),
277					Data:   make(chan []byte),
278				}
279
280				go func() {
281					<-ctx.Done()
282					sub.Cleanup()
283				}()
284				err = pubsub.PubSub.Sub(name, sub)
285				if err != nil {
286					wish.Errorln(sesh, err)
287				}
288			} else if cmd == "pipe" {
289				pipeCmd := flagSet("pipe", sesh)
290				public := pipeCmd.Bool("p", false, "Pipe to a public channel")
291				replay := pipeCmd.Bool("r", false, "Replay messages to the client that sent it")
292				if !flagCheck(pipeCmd, channelName, cmdArgs) {
293					return
294				}
295				isCreator := channelName == ""
296				if isCreator {
297					channelName = uuid.NewString()
298				}
299				name := toChannel(user.Name, channelName)
300				if *public {
301					name = toPublicChannel(channelName)
302				}
303				if isCreator {
304					wish.Printf(
305						sesh,
306						"subscribe to this channel:\n\tssh %s sub %s\n",
307						toSshCmd(handler.Cfg),
308						channelName,
309					)
310				}
311
312				pipe := &psub.PipeClient{
313					ID:         fmt.Sprintf("%s (%s@%s)", uuid.NewString(), user.Name, sesh.RemoteAddr().String()),
314					Done:       make(chan struct{}),
315					Data:       make(chan psub.PipeMessage),
316					ReadWriter: sesh,
317					Replay:     *replay,
318				}
319
320				go func() {
321					<-ctx.Done()
322					pipe.Cleanup()
323				}()
324				readErr, writeErr := pubsub.PubSub.Pipe(name, pipe)
325				if readErr != nil {
326					wish.Errorln(sesh, readErr)
327				}
328				if writeErr != nil {
329					wish.Errorln(sesh, writeErr)
330				}
331			}
332
333			next(sesh)
334		}
335	}
336}