Eric Bower
·
19 Aug 24
wish.go
1package pgs
2
3import (
4 "flag"
5 "fmt"
6 "slices"
7 "strings"
8
9 "github.com/charmbracelet/ssh"
10 "github.com/charmbracelet/wish"
11 bm "github.com/charmbracelet/wish/bubbletea"
12 "github.com/picosh/pico/db"
13 uploadassets "github.com/picosh/pico/filehandlers/assets"
14 "github.com/picosh/pico/shared"
15 "github.com/picosh/pico/tui/common"
16 "github.com/picosh/send/send/utils"
17)
18
19func getUser(s ssh.Session, dbpool db.DB) (*db.User, error) {
20 var err error
21 key, err := shared.KeyText(s)
22 if err != nil {
23 return nil, fmt.Errorf("key not found")
24 }
25
26 user, err := dbpool.FindUserForKey(s.User(), key)
27 if err != nil {
28 return nil, err
29 }
30
31 if user.Name == "" {
32 return nil, fmt.Errorf("must have username set")
33 }
34
35 return user, nil
36}
37
// arrayFlags is a flag.Value that collects every occurrence of a repeatable
// flag (e.g. `--acl a --acl b`) into a slice, in the order given. Each value
// is stored exactly as passed; no splitting is done here.
type arrayFlags []string

// String renders the collected values joined by commas. The flag package may
// call String on a zero value (including a nil pointer) when printing
// defaults, so a nil receiver yields "".
func (i *arrayFlags) String() string {
	if i == nil {
		return ""
	}
	return strings.Join(*i, ",")
}

// Set appends one occurrence of the flag's value. It never returns an error.
func (i *arrayFlags) Set(value string) error {
	*i = append(*i, value)
	return nil
}
48
49func flagSet(cmdName string, sesh ssh.Session) (*flag.FlagSet, *bool) {
50 cmd := flag.NewFlagSet(cmdName, flag.ContinueOnError)
51 cmd.SetOutput(sesh)
52 write := cmd.Bool("write", false, "apply changes")
53 return cmd, write
54}
55
// flagCheck parses cmdArgs into cmd and reports whether the caller should go
// on to run the command.
//
// It returns false when:
//   - posArg (the argument in the project-name slot) is a help flag
//     ("-h", "--help" or "-help"); cmd's usage is printed in that case.
//   - parsing cmdArgs fails, including an explicit -h/--help among cmdArgs.
//     For a flag.ContinueOnError set, Parse has already written the error
//     and usage to the set's output, so nothing more is printed here.
func flagCheck(cmd *flag.FlagSet, posArg string, cmdArgs []string) bool {
	if posArg == "-h" || posArg == "--help" || posArg == "-help" {
		cmd.Usage()
		return false
	}

	// Never run with partially-parsed flags: e.g. `-write -n abc` would
	// otherwise apply changes using the default value of -n.
	if err := cmd.Parse(cmdArgs); err != nil {
		return false
	}
	return true
}
65
66func WishMiddleware(handler *uploadassets.UploadAssetHandler) wish.Middleware {
67 dbpool := handler.DBPool
68 log := handler.Cfg.Logger
69 cfg := handler.Cfg
70 store := handler.Storage
71
72 return func(next ssh.Handler) ssh.Handler {
73 return func(sesh ssh.Session) {
74 args := sesh.Command()
75 if len(args) == 0 {
76 next(sesh)
77 return
78 }
79
80 // default width and height when no pty
81 width := 100
82 height := 24
83 pty, _, ok := sesh.Pty()
84 if ok {
85 width = pty.Window.Width
86 height = pty.Window.Height
87 }
88
89 user, err := getUser(sesh, dbpool)
90 if err != nil {
91 utils.ErrorHandler(sesh, err)
92 return
93 }
94
95 renderer := bm.MakeRenderer(sesh)
96 styles := common.DefaultStyles(renderer)
97
98 opts := Cmd{
99 Session: sesh,
100 User: user,
101 Store: store,
102 Log: log,
103 Dbpool: dbpool,
104 Write: false,
105 Styles: styles,
106 Width: width,
107 Height: height,
108 }
109
110 cmd := strings.TrimSpace(args[0])
111 if len(args) == 1 {
112 if cmd == "help" {
113 opts.help()
114 return
115 } else if cmd == "stats" {
116 err := opts.stats(cfg.MaxSize)
117 opts.bail(err)
118 return
119 } else if cmd == "ls" {
120 err := opts.ls()
121 opts.bail(err)
122 return
123 } else {
124 next(sesh)
125 return
126 }
127 }
128
129 projectName := strings.TrimSpace(args[1])
130 cmdArgs := args[2:]
131 log.Info(
132 "pgs middleware detected command",
133 "args", args,
134 "cmd", cmd,
135 "projectName", projectName,
136 "cmdArgs", cmdArgs,
137 )
138
139 if cmd == "stats" {
140 err := opts.statsByProject(projectName)
141 opts.bail(err)
142 return
143 } else if cmd == "link" {
144 linkCmd, write := flagSet("link", sesh)
145 linkTo := linkCmd.String("to", "", "symbolic link to this project")
146 if !flagCheck(linkCmd, projectName, cmdArgs) {
147 return
148 }
149 opts.Write = *write
150
151 if *linkTo == "" {
152 err := fmt.Errorf(
153 "must provide `--to` flag",
154 )
155 opts.bail(err)
156 return
157 }
158
159 err := opts.link(projectName, *linkTo)
160 opts.notice()
161 if err != nil {
162 opts.bail(err)
163 }
164 return
165 } else if cmd == "unlink" {
166 unlinkCmd, write := flagSet("unlink", sesh)
167 if !flagCheck(unlinkCmd, projectName, cmdArgs) {
168 return
169 }
170 opts.Write = *write
171
172 err := opts.unlink(projectName)
173 opts.notice()
174 opts.bail(err)
175 return
176 } else if cmd == "depends" {
177 err := opts.depends(projectName)
178 opts.bail(err)
179 return
180 } else if cmd == "retain" {
181 retainCmd, write := flagSet("retain", sesh)
182 retainNum := retainCmd.Int("n", 3, "latest number of projects to keep")
183 if !flagCheck(retainCmd, projectName, cmdArgs) {
184 return
185 }
186 opts.Write = *write
187
188 err := opts.prune(projectName, *retainNum)
189 opts.notice()
190 opts.bail(err)
191 return
192 } else if cmd == "prune" {
193 pruneCmd, write := flagSet("prune", sesh)
194 if !flagCheck(pruneCmd, projectName, cmdArgs) {
195 return
196 }
197 opts.Write = *write
198
199 err := opts.prune(projectName, 0)
200 opts.notice()
201 opts.bail(err)
202 return
203 } else if cmd == "rm" {
204 rmCmd, write := flagSet("rm", sesh)
205 if !flagCheck(rmCmd, projectName, cmdArgs) {
206 return
207 }
208 opts.Write = *write
209
210 err := opts.rm(projectName)
211 opts.notice()
212 opts.bail(err)
213 return
214 } else if cmd == "acl" {
215 aclCmd, write := flagSet("acl", sesh)
216 aclType := aclCmd.String("type", "", "access type: public, pico, pubkeys")
217 var acls arrayFlags
218 aclCmd.Var(
219 &acls,
220 "acl",
221 "list of pico usernames or sha256 public keys, delimited by commas",
222 )
223 if !flagCheck(aclCmd, projectName, cmdArgs) {
224 return
225 }
226 opts.Write = *write
227
228 if !slices.Contains([]string{"public", "pubkeys", "pico"}, *aclType) {
229 err := fmt.Errorf(
230 "acl type must be one of the following: [public, pubkeys, pico], found %s",
231 *aclType,
232 )
233 opts.bail(err)
234 return
235 }
236
237 err := opts.acl(projectName, *aclType, acls)
238 opts.notice()
239 opts.bail(err)
240 } else {
241 next(sesh)
242 return
243 }
244 }
245 }
246}