repos / pico

pico services - prose.sh, pastes.sh, imgs.sh, feeds.sh, pgs.sh
git clone https://github.com/picosh/pico.git

pico / pgs
Eric Bower · 28 Oct 24

wish.go

  1package pgs
  2
  3import (
  4	"flag"
  5	"fmt"
  6	"slices"
  7	"strings"
  8
  9	"github.com/charmbracelet/ssh"
 10	"github.com/charmbracelet/wish"
 11	bm "github.com/charmbracelet/wish/bubbletea"
 12	"github.com/picosh/pico/db"
 13	"github.com/picosh/pico/tui/common"
 14	sendutils "github.com/picosh/send/utils"
 15	"github.com/picosh/utils"
 16)
 17
 18func getUser(s ssh.Session, dbpool db.DB) (*db.User, error) {
 19	if s.PublicKey() == nil {
 20		return nil, fmt.Errorf("key not found")
 21	}
 22
 23	key := utils.KeyForKeyText(s.PublicKey())
 24
 25	user, err := dbpool.FindUserForKey(s.User(), key)
 26	if err != nil {
 27		return nil, err
 28	}
 29
 30	if user.Name == "" {
 31		return nil, fmt.Errorf("must have username set")
 32	}
 33
 34	return user, nil
 35}
 36
 37type arrayFlags []string
 38
 39func (i *arrayFlags) String() string {
 40	return "array flags"
 41}
 42
 43func (i *arrayFlags) Set(value string) error {
 44	*i = append(*i, value)
 45	return nil
 46}
 47
 48func flagSet(cmdName string, sesh ssh.Session) (*flag.FlagSet, *bool) {
 49	cmd := flag.NewFlagSet(cmdName, flag.ContinueOnError)
 50	cmd.SetOutput(sesh)
 51	write := cmd.Bool("write", false, "apply changes")
 52	return cmd, write
 53}
 54
 55func flagCheck(cmd *flag.FlagSet, posArg string, cmdArgs []string) bool {
 56	_ = cmd.Parse(cmdArgs)
 57
 58	if posArg == "-h" || posArg == "--help" || posArg == "-help" {
 59		cmd.Usage()
 60		return false
 61	}
 62	return true
 63}
 64
 65func WishMiddleware(handler *UploadAssetHandler) wish.Middleware {
 66	dbpool := handler.DBPool
 67	log := handler.Cfg.Logger
 68	cfg := handler.Cfg
 69	store := handler.Storage
 70
 71	return func(next ssh.Handler) ssh.Handler {
 72		return func(sesh ssh.Session) {
 73			args := sesh.Command()
 74			if len(args) == 0 {
 75				next(sesh)
 76				return
 77			}
 78
 79			// default width and height when no pty
 80			width := 100
 81			height := 24
 82			pty, _, ok := sesh.Pty()
 83			if ok {
 84				width = pty.Window.Width
 85				height = pty.Window.Height
 86			}
 87
 88			user, err := getUser(sesh, dbpool)
 89			if err != nil {
 90				sendutils.ErrorHandler(sesh, err)
 91				return
 92			}
 93
 94			renderer := bm.MakeRenderer(sesh)
 95			styles := common.DefaultStyles(renderer)
 96
 97			opts := Cmd{
 98				Session: sesh,
 99				User:    user,
100				Store:   store,
101				Log:     log,
102				Dbpool:  dbpool,
103				Write:   false,
104				Styles:  styles,
105				Width:   width,
106				Height:  height,
107			}
108
109			cmd := strings.TrimSpace(args[0])
110			if len(args) == 1 {
111				if cmd == "help" {
112					opts.help()
113					return
114				} else if cmd == "stats" {
115					err := opts.stats(cfg.MaxSize)
116					opts.bail(err)
117					return
118				} else if cmd == "ls" {
119					err := opts.ls()
120					opts.bail(err)
121					return
122				} else {
123					next(sesh)
124					return
125				}
126			}
127
128			projectName := strings.TrimSpace(args[1])
129			cmdArgs := args[2:]
130			log.Info(
131				"pgs middleware detected command",
132				"args", args,
133				"cmd", cmd,
134				"projectName", projectName,
135				"cmdArgs", cmdArgs,
136			)
137
138			if cmd == "stats" {
139				err := opts.statsByProject(projectName)
140				opts.bail(err)
141				return
142			} else if cmd == "link" {
143				linkCmd, write := flagSet("link", sesh)
144				linkTo := linkCmd.String("to", "", "symbolic link to this project")
145				if !flagCheck(linkCmd, projectName, cmdArgs) {
146					return
147				}
148				opts.Write = *write
149
150				if *linkTo == "" {
151					err := fmt.Errorf(
152						"must provide `--to` flag",
153					)
154					opts.bail(err)
155					return
156				}
157
158				err := opts.link(projectName, *linkTo)
159				opts.notice()
160				if err != nil {
161					opts.bail(err)
162				}
163				return
164			} else if cmd == "unlink" {
165				unlinkCmd, write := flagSet("unlink", sesh)
166				if !flagCheck(unlinkCmd, projectName, cmdArgs) {
167					return
168				}
169				opts.Write = *write
170
171				err := opts.unlink(projectName)
172				opts.notice()
173				opts.bail(err)
174				return
175			} else if cmd == "depends" {
176				err := opts.depends(projectName)
177				opts.bail(err)
178				return
179			} else if cmd == "retain" {
180				retainCmd, write := flagSet("retain", sesh)
181				retainNum := retainCmd.Int("n", 3, "latest number of projects to keep")
182				if !flagCheck(retainCmd, projectName, cmdArgs) {
183					return
184				}
185				opts.Write = *write
186
187				err := opts.prune(projectName, *retainNum)
188				opts.notice()
189				opts.bail(err)
190				return
191			} else if cmd == "prune" {
192				pruneCmd, write := flagSet("prune", sesh)
193				if !flagCheck(pruneCmd, projectName, cmdArgs) {
194					return
195				}
196				opts.Write = *write
197
198				err := opts.prune(projectName, 0)
199				opts.notice()
200				opts.bail(err)
201				return
202			} else if cmd == "rm" {
203				rmCmd, write := flagSet("rm", sesh)
204				if !flagCheck(rmCmd, projectName, cmdArgs) {
205					return
206				}
207				opts.Write = *write
208
209				err := opts.rm(projectName)
210				opts.notice()
211				opts.bail(err)
212				return
213			} else if cmd == "acl" {
214				aclCmd, write := flagSet("acl", sesh)
215				aclType := aclCmd.String("type", "", "access type: public, pico, pubkeys")
216				var acls arrayFlags
217				aclCmd.Var(
218					&acls,
219					"acl",
220					"list of pico usernames or sha256 public keys, delimited by commas",
221				)
222				if !flagCheck(aclCmd, projectName, cmdArgs) {
223					return
224				}
225				opts.Write = *write
226
227				if !slices.Contains([]string{"public", "pubkeys", "pico"}, *aclType) {
228					err := fmt.Errorf(
229						"acl type must be one of the following: [public, pubkeys, pico], found %s",
230						*aclType,
231					)
232					opts.bail(err)
233					return
234				}
235
236				err := opts.acl(projectName, *aclType, acls)
237				opts.notice()
238				opts.bail(err)
239			} else {
240				next(sesh)
241				return
242			}
243		}
244	}
245}