Eric Bower
·
28 Oct 24
wish.go
1package pgs
2
3import (
4 "flag"
5 "fmt"
6 "slices"
7 "strings"
8
9 "github.com/charmbracelet/ssh"
10 "github.com/charmbracelet/wish"
11 bm "github.com/charmbracelet/wish/bubbletea"
12 "github.com/picosh/pico/db"
13 "github.com/picosh/pico/tui/common"
14 sendutils "github.com/picosh/send/utils"
15 "github.com/picosh/utils"
16)
17
18func getUser(s ssh.Session, dbpool db.DB) (*db.User, error) {
19 if s.PublicKey() == nil {
20 return nil, fmt.Errorf("key not found")
21 }
22
23 key := utils.KeyForKeyText(s.PublicKey())
24
25 user, err := dbpool.FindUserForKey(s.User(), key)
26 if err != nil {
27 return nil, err
28 }
29
30 if user.Name == "" {
31 return nil, fmt.Errorf("must have username set")
32 }
33
34 return user, nil
35}
36
// arrayFlags is a flag.Value that accumulates string values. Values may be
// supplied by repeating the flag (`--acl alice --acl bob`), as a
// comma-delimited list (`--acl alice,bob`), or any mix of the two.
type arrayFlags []string

// String renders the collected values joined by commas, as the flag.Value
// contract expects. A nil receiver yields "" because the flag package may
// call String on a zero value when printing defaults.
func (i *arrayFlags) String() string {
	if i == nil {
		return ""
	}
	return strings.Join(*i, ",")
}

// Set implements flag.Value. It splits value on commas, trims surrounding
// whitespace from each item, drops empty items, and appends the remaining
// items in order. It never returns an error.
func (i *arrayFlags) Set(value string) error {
	for _, item := range strings.Split(value, ",") {
		if item = strings.TrimSpace(item); item != "" {
			*i = append(*i, item)
		}
	}
	return nil
}
47
48func flagSet(cmdName string, sesh ssh.Session) (*flag.FlagSet, *bool) {
49 cmd := flag.NewFlagSet(cmdName, flag.ContinueOnError)
50 cmd.SetOutput(sesh)
51 write := cmd.Bool("write", false, "apply changes")
52 return cmd, write
53}
54
// flagCheck parses cmdArgs into cmd and reports whether the subcommand
// should proceed.
//
// posArg is the positional argument that precedes the flags (the project
// name). If it is itself a help flag, usage is printed and false returned.
//
// Any parse error also returns false. This includes an unknown flag, a bad
// value, or an explicit -h/--help among cmdArgs. With flag.ContinueOnError,
// Parse has already written the error (and usage) to cmd's output.
// Continuing would run the command with a partially parsed flag set, e.g.
// honoring a --write that appeared before an unknown flag.
func flagCheck(cmd *flag.FlagSet, posArg string, cmdArgs []string) bool {
	if err := cmd.Parse(cmdArgs); err != nil {
		return false
	}

	if posArg == "-h" || posArg == "--help" || posArg == "-help" {
		cmd.Usage()
		return false
	}
	return true
}
64
65func WishMiddleware(handler *UploadAssetHandler) wish.Middleware {
66 dbpool := handler.DBPool
67 log := handler.Cfg.Logger
68 cfg := handler.Cfg
69 store := handler.Storage
70
71 return func(next ssh.Handler) ssh.Handler {
72 return func(sesh ssh.Session) {
73 args := sesh.Command()
74 if len(args) == 0 {
75 next(sesh)
76 return
77 }
78
79 // default width and height when no pty
80 width := 100
81 height := 24
82 pty, _, ok := sesh.Pty()
83 if ok {
84 width = pty.Window.Width
85 height = pty.Window.Height
86 }
87
88 user, err := getUser(sesh, dbpool)
89 if err != nil {
90 sendutils.ErrorHandler(sesh, err)
91 return
92 }
93
94 renderer := bm.MakeRenderer(sesh)
95 styles := common.DefaultStyles(renderer)
96
97 opts := Cmd{
98 Session: sesh,
99 User: user,
100 Store: store,
101 Log: log,
102 Dbpool: dbpool,
103 Write: false,
104 Styles: styles,
105 Width: width,
106 Height: height,
107 }
108
109 cmd := strings.TrimSpace(args[0])
110 if len(args) == 1 {
111 if cmd == "help" {
112 opts.help()
113 return
114 } else if cmd == "stats" {
115 err := opts.stats(cfg.MaxSize)
116 opts.bail(err)
117 return
118 } else if cmd == "ls" {
119 err := opts.ls()
120 opts.bail(err)
121 return
122 } else {
123 next(sesh)
124 return
125 }
126 }
127
128 projectName := strings.TrimSpace(args[1])
129 cmdArgs := args[2:]
130 log.Info(
131 "pgs middleware detected command",
132 "args", args,
133 "cmd", cmd,
134 "projectName", projectName,
135 "cmdArgs", cmdArgs,
136 )
137
138 if cmd == "stats" {
139 err := opts.statsByProject(projectName)
140 opts.bail(err)
141 return
142 } else if cmd == "link" {
143 linkCmd, write := flagSet("link", sesh)
144 linkTo := linkCmd.String("to", "", "symbolic link to this project")
145 if !flagCheck(linkCmd, projectName, cmdArgs) {
146 return
147 }
148 opts.Write = *write
149
150 if *linkTo == "" {
151 err := fmt.Errorf(
152 "must provide `--to` flag",
153 )
154 opts.bail(err)
155 return
156 }
157
158 err := opts.link(projectName, *linkTo)
159 opts.notice()
160 if err != nil {
161 opts.bail(err)
162 }
163 return
164 } else if cmd == "unlink" {
165 unlinkCmd, write := flagSet("unlink", sesh)
166 if !flagCheck(unlinkCmd, projectName, cmdArgs) {
167 return
168 }
169 opts.Write = *write
170
171 err := opts.unlink(projectName)
172 opts.notice()
173 opts.bail(err)
174 return
175 } else if cmd == "depends" {
176 err := opts.depends(projectName)
177 opts.bail(err)
178 return
179 } else if cmd == "retain" {
180 retainCmd, write := flagSet("retain", sesh)
181 retainNum := retainCmd.Int("n", 3, "latest number of projects to keep")
182 if !flagCheck(retainCmd, projectName, cmdArgs) {
183 return
184 }
185 opts.Write = *write
186
187 err := opts.prune(projectName, *retainNum)
188 opts.notice()
189 opts.bail(err)
190 return
191 } else if cmd == "prune" {
192 pruneCmd, write := flagSet("prune", sesh)
193 if !flagCheck(pruneCmd, projectName, cmdArgs) {
194 return
195 }
196 opts.Write = *write
197
198 err := opts.prune(projectName, 0)
199 opts.notice()
200 opts.bail(err)
201 return
202 } else if cmd == "rm" {
203 rmCmd, write := flagSet("rm", sesh)
204 if !flagCheck(rmCmd, projectName, cmdArgs) {
205 return
206 }
207 opts.Write = *write
208
209 err := opts.rm(projectName)
210 opts.notice()
211 opts.bail(err)
212 return
213 } else if cmd == "acl" {
214 aclCmd, write := flagSet("acl", sesh)
215 aclType := aclCmd.String("type", "", "access type: public, pico, pubkeys")
216 var acls arrayFlags
217 aclCmd.Var(
218 &acls,
219 "acl",
220 "list of pico usernames or sha256 public keys, delimited by commas",
221 )
222 if !flagCheck(aclCmd, projectName, cmdArgs) {
223 return
224 }
225 opts.Write = *write
226
227 if !slices.Contains([]string{"public", "pubkeys", "pico"}, *aclType) {
228 err := fmt.Errorf(
229 "acl type must be one of the following: [public, pubkeys, pico], found %s",
230 *aclType,
231 )
232 opts.bail(err)
233 return
234 }
235
236 err := opts.acl(projectName, *aclType, acls)
237 opts.notice()
238 opts.bail(err)
239 } else {
240 next(sesh)
241 return
242 }
243 }
244 }
245}