repos / pico

pico services - prose.sh, pastes.sh, imgs.sh, feeds.sh, pgs.sh
git clone https://github.com/picosh/pico.git

pico / pgs
Eric Bower · 04 Dec 24

wish.go

  1package pgs
  2
  3import (
  4	"flag"
  5	"fmt"
  6	"slices"
  7	"strings"
  8
  9	"github.com/charmbracelet/ssh"
 10	"github.com/charmbracelet/wish"
 11	bm "github.com/charmbracelet/wish/bubbletea"
 12	"github.com/muesli/termenv"
 13	"github.com/picosh/pico/db"
 14	"github.com/picosh/pico/tui/common"
 15	sendutils "github.com/picosh/send/utils"
 16	"github.com/picosh/utils"
 17)
 18
 19func getUser(s ssh.Session, dbpool db.DB) (*db.User, error) {
 20	if s.PublicKey() == nil {
 21		return nil, fmt.Errorf("key not found")
 22	}
 23
 24	key := utils.KeyForKeyText(s.PublicKey())
 25
 26	user, err := dbpool.FindUserForKey(s.User(), key)
 27	if err != nil {
 28		return nil, err
 29	}
 30
 31	if user.Name == "" {
 32		return nil, fmt.Errorf("must have username set")
 33	}
 34
 35	return user, nil
 36}
 37
 38type arrayFlags []string
 39
 40func (i *arrayFlags) String() string {
 41	return "array flags"
 42}
 43
 44func (i *arrayFlags) Set(value string) error {
 45	*i = append(*i, value)
 46	return nil
 47}
 48
 49func flagSet(cmdName string, sesh ssh.Session) (*flag.FlagSet, *bool) {
 50	cmd := flag.NewFlagSet(cmdName, flag.ContinueOnError)
 51	cmd.SetOutput(sesh)
 52	write := cmd.Bool("write", false, "apply changes")
 53	return cmd, write
 54}
 55
 56func flagCheck(cmd *flag.FlagSet, posArg string, cmdArgs []string) bool {
 57	_ = cmd.Parse(cmdArgs)
 58
 59	if posArg == "-h" || posArg == "--help" || posArg == "-help" {
 60		cmd.Usage()
 61		return false
 62	}
 63	return true
 64}
 65
 66func WishMiddleware(handler *UploadAssetHandler) wish.Middleware {
 67	dbpool := handler.DBPool
 68	log := handler.Cfg.Logger
 69	cfg := handler.Cfg
 70	store := handler.Storage
 71
 72	return func(next ssh.Handler) ssh.Handler {
 73		return func(sesh ssh.Session) {
 74			args := sesh.Command()
 75			if len(args) == 0 {
 76				next(sesh)
 77				return
 78			}
 79
 80			// default width and height when no pty
 81			width := 100
 82			height := 24
 83			pty, _, ok := sesh.Pty()
 84			if ok {
 85				width = pty.Window.Width
 86				height = pty.Window.Height
 87			}
 88
 89			user, err := getUser(sesh, dbpool)
 90			if err != nil {
 91				sendutils.ErrorHandler(sesh, err)
 92				return
 93			}
 94
 95			renderer := bm.MakeRenderer(sesh)
 96			renderer.SetColorProfile(termenv.TrueColor)
 97			styles := common.DefaultStyles(renderer)
 98
 99			opts := Cmd{
100				Session: sesh,
101				User:    user,
102				Store:   store,
103				Log:     log,
104				Dbpool:  dbpool,
105				Write:   false,
106				Styles:  styles,
107				Width:   width,
108				Height:  height,
109				Cfg:     handler.Cfg,
110			}
111
112			cmd := strings.TrimSpace(args[0])
113			if len(args) == 1 {
114				if cmd == "help" {
115					opts.help()
116					return
117				} else if cmd == "stats" {
118					err := opts.stats(cfg.MaxSize)
119					opts.bail(err)
120					return
121				} else if cmd == "ls" {
122					err := opts.ls()
123					opts.bail(err)
124					return
125				} else if cmd == "cache-all" {
126					opts.Write = true
127					err := opts.cacheAll()
128					opts.notice()
129					opts.bail(err)
130					return
131				} else {
132					next(sesh)
133					return
134				}
135			}
136
137			projectName := strings.TrimSpace(args[1])
138			cmdArgs := args[2:]
139			log.Info(
140				"pgs middleware detected command",
141				"args", args,
142				"cmd", cmd,
143				"projectName", projectName,
144				"cmdArgs", cmdArgs,
145			)
146
147			if cmd == "stats" {
148				err := opts.statsByProject(projectName)
149				opts.bail(err)
150				return
151			} else if cmd == "link" {
152				linkCmd, write := flagSet("link", sesh)
153				linkTo := linkCmd.String("to", "", "symbolic link to this project")
154				if !flagCheck(linkCmd, projectName, cmdArgs) {
155					return
156				}
157				opts.Write = *write
158
159				if *linkTo == "" {
160					err := fmt.Errorf(
161						"must provide `--to` flag",
162					)
163					opts.bail(err)
164					return
165				}
166
167				err := opts.link(projectName, *linkTo)
168				opts.notice()
169				if err != nil {
170					opts.bail(err)
171				}
172				return
173			} else if cmd == "unlink" {
174				unlinkCmd, write := flagSet("unlink", sesh)
175				if !flagCheck(unlinkCmd, projectName, cmdArgs) {
176					return
177				}
178				opts.Write = *write
179
180				err := opts.unlink(projectName)
181				opts.notice()
182				opts.bail(err)
183				return
184			} else if cmd == "depends" {
185				err := opts.depends(projectName)
186				opts.bail(err)
187				return
188			} else if cmd == "retain" {
189				retainCmd, write := flagSet("retain", sesh)
190				retainNum := retainCmd.Int("n", 3, "latest number of projects to keep")
191				if !flagCheck(retainCmd, projectName, cmdArgs) {
192					return
193				}
194				opts.Write = *write
195
196				err := opts.prune(projectName, *retainNum)
197				opts.notice()
198				opts.bail(err)
199				return
200			} else if cmd == "prune" {
201				pruneCmd, write := flagSet("prune", sesh)
202				if !flagCheck(pruneCmd, projectName, cmdArgs) {
203					return
204				}
205				opts.Write = *write
206
207				err := opts.prune(projectName, 0)
208				opts.notice()
209				opts.bail(err)
210				return
211			} else if cmd == "rm" {
212				rmCmd, write := flagSet("rm", sesh)
213				if !flagCheck(rmCmd, projectName, cmdArgs) {
214					return
215				}
216				opts.Write = *write
217
218				err := opts.rm(projectName)
219				opts.notice()
220				opts.bail(err)
221				return
222			} else if cmd == "cache" {
223				cacheCmd, write := flagSet("cache", sesh)
224				if !flagCheck(cacheCmd, projectName, cmdArgs) {
225					return
226				}
227				opts.Write = *write
228
229				err := opts.cache(projectName)
230				opts.notice()
231				opts.bail(err)
232				return
233			} else if cmd == "acl" {
234				aclCmd, write := flagSet("acl", sesh)
235				aclType := aclCmd.String("type", "", "access type: public, pico, pubkeys")
236				var acls arrayFlags
237				aclCmd.Var(
238					&acls,
239					"acl",
240					"list of pico usernames or sha256 public keys, delimited by commas",
241				)
242				if !flagCheck(aclCmd, projectName, cmdArgs) {
243					return
244				}
245				opts.Write = *write
246
247				if !slices.Contains([]string{"public", "pubkeys", "pico"}, *aclType) {
248					err := fmt.Errorf(
249						"acl type must be one of the following: [public, pubkeys, pico], found %s",
250						*aclType,
251					)
252					opts.bail(err)
253					return
254				}
255
256				err := opts.acl(projectName, *aclType, acls)
257				opts.notice()
258				opts.bail(err)
259			} else {
260				next(sesh)
261				return
262			}
263		}
264	}
265}