Eric Bower
·
04 Dec 24
wish.go
1package pgs
2
3import (
4 "flag"
5 "fmt"
6 "slices"
7 "strings"
8
9 "github.com/charmbracelet/ssh"
10 "github.com/charmbracelet/wish"
11 bm "github.com/charmbracelet/wish/bubbletea"
12 "github.com/muesli/termenv"
13 "github.com/picosh/pico/db"
14 "github.com/picosh/pico/tui/common"
15 sendutils "github.com/picosh/send/utils"
16 "github.com/picosh/utils"
17)
18
19func getUser(s ssh.Session, dbpool db.DB) (*db.User, error) {
20 if s.PublicKey() == nil {
21 return nil, fmt.Errorf("key not found")
22 }
23
24 key := utils.KeyForKeyText(s.PublicKey())
25
26 user, err := dbpool.FindUserForKey(s.User(), key)
27 if err != nil {
28 return nil, err
29 }
30
31 if user.Name == "" {
32 return nil, fmt.Errorf("must have username set")
33 }
34
35 return user, nil
36}
37
// arrayFlags is a repeatable string flag (a flag.Value). Each occurrence on the
// command line appends its argument, so `--acl a --acl b` yields ["a", "b"].
// Values are stored verbatim; no comma splitting is performed here.
type arrayFlags []string

// String implements flag.Value by presenting the collected values joined with
// commas. The flag package may call String on a zero-valued receiver (including
// a nil pointer) when rendering defaults, so a nil receiver yields "".
func (a *arrayFlags) String() string {
	if a == nil {
		return ""
	}
	return strings.Join(*a, ",")
}

// Set implements flag.Value by appending value to the list. It never fails.
func (a *arrayFlags) Set(value string) error {
	*a = append(*a, value)
	return nil
}
48
49func flagSet(cmdName string, sesh ssh.Session) (*flag.FlagSet, *bool) {
50 cmd := flag.NewFlagSet(cmdName, flag.ContinueOnError)
51 cmd.SetOutput(sesh)
52 write := cmd.Bool("write", false, "apply changes")
53 return cmd, write
54}
55
// flagCheck parses cmdArgs into cmd and reports whether the subcommand should
// go on to run.
//
// It returns false, and the caller should stop, when:
//   - posArg (the positional project-name slot) is a help flag ("-h",
//     "--help", "-help"), after printing cmd's usage; or
//   - cmdArgs fail to parse: an unknown flag, a malformed value, or -h/-help
//     among the flags. Because the set uses flag.ContinueOnError, Parse has
//     already written the error and/or usage to the set's output, so nothing
//     more is printed here.
//
// Parse errors must stop the command. Ignoring them would let an invocation
// such as `retain proj --write -n oops` proceed with --write applied and the
// default -n after reporting the error.
func flagCheck(cmd *flag.FlagSet, posArg string, cmdArgs []string) bool {
	switch posArg {
	case "-h", "--help", "-help":
		cmd.Usage()
		return false
	}

	if err := cmd.Parse(cmdArgs); err != nil {
		return false
	}
	return true
}
65
66func WishMiddleware(handler *UploadAssetHandler) wish.Middleware {
67 dbpool := handler.DBPool
68 log := handler.Cfg.Logger
69 cfg := handler.Cfg
70 store := handler.Storage
71
72 return func(next ssh.Handler) ssh.Handler {
73 return func(sesh ssh.Session) {
74 args := sesh.Command()
75 if len(args) == 0 {
76 next(sesh)
77 return
78 }
79
80 // default width and height when no pty
81 width := 100
82 height := 24
83 pty, _, ok := sesh.Pty()
84 if ok {
85 width = pty.Window.Width
86 height = pty.Window.Height
87 }
88
89 user, err := getUser(sesh, dbpool)
90 if err != nil {
91 sendutils.ErrorHandler(sesh, err)
92 return
93 }
94
95 renderer := bm.MakeRenderer(sesh)
96 renderer.SetColorProfile(termenv.TrueColor)
97 styles := common.DefaultStyles(renderer)
98
99 opts := Cmd{
100 Session: sesh,
101 User: user,
102 Store: store,
103 Log: log,
104 Dbpool: dbpool,
105 Write: false,
106 Styles: styles,
107 Width: width,
108 Height: height,
109 Cfg: handler.Cfg,
110 }
111
112 cmd := strings.TrimSpace(args[0])
113 if len(args) == 1 {
114 if cmd == "help" {
115 opts.help()
116 return
117 } else if cmd == "stats" {
118 err := opts.stats(cfg.MaxSize)
119 opts.bail(err)
120 return
121 } else if cmd == "ls" {
122 err := opts.ls()
123 opts.bail(err)
124 return
125 } else if cmd == "cache-all" {
126 opts.Write = true
127 err := opts.cacheAll()
128 opts.notice()
129 opts.bail(err)
130 return
131 } else {
132 next(sesh)
133 return
134 }
135 }
136
137 projectName := strings.TrimSpace(args[1])
138 cmdArgs := args[2:]
139 log.Info(
140 "pgs middleware detected command",
141 "args", args,
142 "cmd", cmd,
143 "projectName", projectName,
144 "cmdArgs", cmdArgs,
145 )
146
147 if cmd == "stats" {
148 err := opts.statsByProject(projectName)
149 opts.bail(err)
150 return
151 } else if cmd == "link" {
152 linkCmd, write := flagSet("link", sesh)
153 linkTo := linkCmd.String("to", "", "symbolic link to this project")
154 if !flagCheck(linkCmd, projectName, cmdArgs) {
155 return
156 }
157 opts.Write = *write
158
159 if *linkTo == "" {
160 err := fmt.Errorf(
161 "must provide `--to` flag",
162 )
163 opts.bail(err)
164 return
165 }
166
167 err := opts.link(projectName, *linkTo)
168 opts.notice()
169 if err != nil {
170 opts.bail(err)
171 }
172 return
173 } else if cmd == "unlink" {
174 unlinkCmd, write := flagSet("unlink", sesh)
175 if !flagCheck(unlinkCmd, projectName, cmdArgs) {
176 return
177 }
178 opts.Write = *write
179
180 err := opts.unlink(projectName)
181 opts.notice()
182 opts.bail(err)
183 return
184 } else if cmd == "depends" {
185 err := opts.depends(projectName)
186 opts.bail(err)
187 return
188 } else if cmd == "retain" {
189 retainCmd, write := flagSet("retain", sesh)
190 retainNum := retainCmd.Int("n", 3, "latest number of projects to keep")
191 if !flagCheck(retainCmd, projectName, cmdArgs) {
192 return
193 }
194 opts.Write = *write
195
196 err := opts.prune(projectName, *retainNum)
197 opts.notice()
198 opts.bail(err)
199 return
200 } else if cmd == "prune" {
201 pruneCmd, write := flagSet("prune", sesh)
202 if !flagCheck(pruneCmd, projectName, cmdArgs) {
203 return
204 }
205 opts.Write = *write
206
207 err := opts.prune(projectName, 0)
208 opts.notice()
209 opts.bail(err)
210 return
211 } else if cmd == "rm" {
212 rmCmd, write := flagSet("rm", sesh)
213 if !flagCheck(rmCmd, projectName, cmdArgs) {
214 return
215 }
216 opts.Write = *write
217
218 err := opts.rm(projectName)
219 opts.notice()
220 opts.bail(err)
221 return
222 } else if cmd == "cache" {
223 cacheCmd, write := flagSet("cache", sesh)
224 if !flagCheck(cacheCmd, projectName, cmdArgs) {
225 return
226 }
227 opts.Write = *write
228
229 err := opts.cache(projectName)
230 opts.notice()
231 opts.bail(err)
232 return
233 } else if cmd == "acl" {
234 aclCmd, write := flagSet("acl", sesh)
235 aclType := aclCmd.String("type", "", "access type: public, pico, pubkeys")
236 var acls arrayFlags
237 aclCmd.Var(
238 &acls,
239 "acl",
240 "list of pico usernames or sha256 public keys, delimited by commas",
241 )
242 if !flagCheck(aclCmd, projectName, cmdArgs) {
243 return
244 }
245 opts.Write = *write
246
247 if !slices.Contains([]string{"public", "pubkeys", "pico"}, *aclType) {
248 err := fmt.Errorf(
249 "acl type must be one of the following: [public, pubkeys, pico], found %s",
250 *aclType,
251 )
252 opts.bail(err)
253 return
254 }
255
256 err := opts.acl(projectName, *aclType, acls)
257 opts.notice()
258 opts.bail(err)
259 } else {
260 next(sesh)
261 return
262 }
263 }
264 }
265}