repos / pico

pico services - prose.sh, pastes.sh, imgs.sh, feeds.sh, pgs.sh
git clone https://github.com/picosh/pico.git

pico / pgs
Eric Bower · 19 Aug 24

wish.go

  1package pgs
  2
  3import (
  4	"flag"
  5	"fmt"
  6	"slices"
  7	"strings"
  8
  9	"github.com/charmbracelet/ssh"
 10	"github.com/charmbracelet/wish"
 11	bm "github.com/charmbracelet/wish/bubbletea"
 12	"github.com/picosh/pico/db"
 13	uploadassets "github.com/picosh/pico/filehandlers/assets"
 14	"github.com/picosh/pico/shared"
 15	"github.com/picosh/pico/tui/common"
 16	"github.com/picosh/send/send/utils"
 17)
 18
 19func getUser(s ssh.Session, dbpool db.DB) (*db.User, error) {
 20	var err error
 21	key, err := shared.KeyText(s)
 22	if err != nil {
 23		return nil, fmt.Errorf("key not found")
 24	}
 25
 26	user, err := dbpool.FindUserForKey(s.User(), key)
 27	if err != nil {
 28		return nil, err
 29	}
 30
 31	if user.Name == "" {
 32		return nil, fmt.Errorf("must have username set")
 33	}
 34
 35	return user, nil
 36}
 37
 38type arrayFlags []string
 39
 40func (i *arrayFlags) String() string {
 41	return "array flags"
 42}
 43
 44func (i *arrayFlags) Set(value string) error {
 45	*i = append(*i, value)
 46	return nil
 47}
 48
 49func flagSet(cmdName string, sesh ssh.Session) (*flag.FlagSet, *bool) {
 50	cmd := flag.NewFlagSet(cmdName, flag.ContinueOnError)
 51	cmd.SetOutput(sesh)
 52	write := cmd.Bool("write", false, "apply changes")
 53	return cmd, write
 54}
 55
 56func flagCheck(cmd *flag.FlagSet, posArg string, cmdArgs []string) bool {
 57	_ = cmd.Parse(cmdArgs)
 58
 59	if posArg == "-h" || posArg == "--help" || posArg == "-help" {
 60		cmd.Usage()
 61		return false
 62	}
 63	return true
 64}
 65
 66func WishMiddleware(handler *uploadassets.UploadAssetHandler) wish.Middleware {
 67	dbpool := handler.DBPool
 68	log := handler.Cfg.Logger
 69	cfg := handler.Cfg
 70	store := handler.Storage
 71
 72	return func(next ssh.Handler) ssh.Handler {
 73		return func(sesh ssh.Session) {
 74			args := sesh.Command()
 75			if len(args) == 0 {
 76				next(sesh)
 77				return
 78			}
 79
 80			// default width and height when no pty
 81			width := 100
 82			height := 24
 83			pty, _, ok := sesh.Pty()
 84			if ok {
 85				width = pty.Window.Width
 86				height = pty.Window.Height
 87			}
 88
 89			user, err := getUser(sesh, dbpool)
 90			if err != nil {
 91				utils.ErrorHandler(sesh, err)
 92				return
 93			}
 94
 95			renderer := bm.MakeRenderer(sesh)
 96			styles := common.DefaultStyles(renderer)
 97
 98			opts := Cmd{
 99				Session: sesh,
100				User:    user,
101				Store:   store,
102				Log:     log,
103				Dbpool:  dbpool,
104				Write:   false,
105				Styles:  styles,
106				Width:   width,
107				Height:  height,
108			}
109
110			cmd := strings.TrimSpace(args[0])
111			if len(args) == 1 {
112				if cmd == "help" {
113					opts.help()
114					return
115				} else if cmd == "stats" {
116					err := opts.stats(cfg.MaxSize)
117					opts.bail(err)
118					return
119				} else if cmd == "ls" {
120					err := opts.ls()
121					opts.bail(err)
122					return
123				} else {
124					next(sesh)
125					return
126				}
127			}
128
129			projectName := strings.TrimSpace(args[1])
130			cmdArgs := args[2:]
131			log.Info(
132				"pgs middleware detected command",
133				"args", args,
134				"cmd", cmd,
135				"projectName", projectName,
136				"cmdArgs", cmdArgs,
137			)
138
139			if cmd == "stats" {
140				err := opts.statsByProject(projectName)
141				opts.bail(err)
142				return
143			} else if cmd == "link" {
144				linkCmd, write := flagSet("link", sesh)
145				linkTo := linkCmd.String("to", "", "symbolic link to this project")
146				if !flagCheck(linkCmd, projectName, cmdArgs) {
147					return
148				}
149				opts.Write = *write
150
151				if *linkTo == "" {
152					err := fmt.Errorf(
153						"must provide `--to` flag",
154					)
155					opts.bail(err)
156					return
157				}
158
159				err := opts.link(projectName, *linkTo)
160				opts.notice()
161				if err != nil {
162					opts.bail(err)
163				}
164				return
165			} else if cmd == "unlink" {
166				unlinkCmd, write := flagSet("unlink", sesh)
167				if !flagCheck(unlinkCmd, projectName, cmdArgs) {
168					return
169				}
170				opts.Write = *write
171
172				err := opts.unlink(projectName)
173				opts.notice()
174				opts.bail(err)
175				return
176			} else if cmd == "depends" {
177				err := opts.depends(projectName)
178				opts.bail(err)
179				return
180			} else if cmd == "retain" {
181				retainCmd, write := flagSet("retain", sesh)
182				retainNum := retainCmd.Int("n", 3, "latest number of projects to keep")
183				if !flagCheck(retainCmd, projectName, cmdArgs) {
184					return
185				}
186				opts.Write = *write
187
188				err := opts.prune(projectName, *retainNum)
189				opts.notice()
190				opts.bail(err)
191				return
192			} else if cmd == "prune" {
193				pruneCmd, write := flagSet("prune", sesh)
194				if !flagCheck(pruneCmd, projectName, cmdArgs) {
195					return
196				}
197				opts.Write = *write
198
199				err := opts.prune(projectName, 0)
200				opts.notice()
201				opts.bail(err)
202				return
203			} else if cmd == "rm" {
204				rmCmd, write := flagSet("rm", sesh)
205				if !flagCheck(rmCmd, projectName, cmdArgs) {
206					return
207				}
208				opts.Write = *write
209
210				err := opts.rm(projectName)
211				opts.notice()
212				opts.bail(err)
213				return
214			} else if cmd == "acl" {
215				aclCmd, write := flagSet("acl", sesh)
216				aclType := aclCmd.String("type", "", "access type: public, pico, pubkeys")
217				var acls arrayFlags
218				aclCmd.Var(
219					&acls,
220					"acl",
221					"list of pico usernames or sha256 public keys, delimited by commas",
222				)
223				if !flagCheck(aclCmd, projectName, cmdArgs) {
224					return
225				}
226				opts.Write = *write
227
228				if !slices.Contains([]string{"public", "pubkeys", "pico"}, *aclType) {
229					err := fmt.Errorf(
230						"acl type must be one of the following: [public, pubkeys, pico], found %s",
231						*aclType,
232					)
233					opts.bail(err)
234					return
235				}
236
237				err := opts.acl(projectName, *aclType, acls)
238				opts.notice()
239				opts.bail(err)
240			} else {
241				next(sesh)
242				return
243			}
244		}
245	}
246}