Eric Bower
·
20 Sep 24
api.go
1package auth
2
3import (
4 "context"
5 "crypto/hmac"
6 "encoding/json"
7 "fmt"
8 "html/template"
9 "io"
10 "log/slog"
11 "net/http"
12 "net/url"
13 "os"
14 "strings"
15 "time"
16
17 "github.com/gorilla/feeds"
18 "github.com/picosh/pico/db"
19 "github.com/picosh/pico/db/postgres"
20 "github.com/picosh/pico/shared"
21)
22
23type Client struct {
24 Cfg *AuthCfg
25 Dbpool db.DB
26 Logger *slog.Logger
27}
28
29func (client *Client) hasPrivilegedAccess(apiToken string) bool {
30 user, err := client.Dbpool.FindUserForToken(apiToken)
31 if err != nil {
32 return false
33 }
34 return client.Dbpool.HasFeatureForUser(user.ID, "auth")
35}
36
// Request-context key types. Each is a distinct unexported empty struct so
// the two values stored on a request context cannot collide.
type (
	// ctxClient keys the *Client stored on the request context.
	ctxClient struct{}
	// ctxKey keys the route's regex capture groups ([]string).
	ctxKey struct{}
)
39
40func getClient(r *http.Request) *Client {
41 return r.Context().Value(ctxClient{}).(*Client)
42}
43
44func getField(r *http.Request, index int) string {
45 fields := r.Context().Value(ctxKey{}).([]string)
46 if index >= len(fields) {
47 return ""
48 }
49 return fields[index]
50}
51
// getApiToken extracts the API token from the request's Authorization
// header. A leading "Bearer " is stripped when present; any other header
// value is returned unchanged. A missing header yields "".
func getApiToken(r *http.Request) string {
	const scheme = "Bearer "

	header := r.Header.Get("authorization")
	if token, found := strings.CutPrefix(header, scheme); found {
		return token
	}
	return header
}
59
// oauth2Server is the JSON document served by wellKnownHandler at
// /.well-known/oauth-authorization-server.
type oauth2Server struct {
	// Issuer identifies this authorization server.
	Issuer string `json:"issuer"`
	// IntrospectionEndpoint is the URL of the token introspection endpoint.
	IntrospectionEndpoint string `json:"introspection_endpoint"`
	// IntrospectionEndpointAuthMethodsSupported lists the client auth methods
	// accepted by the introspection endpoint.
	IntrospectionEndpointAuthMethodsSupported []string `json:"introspection_endpoint_auth_methods_supported"`
	// AuthorizationEndpoint is the URL of the authorization endpoint.
	AuthorizationEndpoint string `json:"authorization_endpoint"`
	// TokenEndpoint is the URL of the token endpoint.
	TokenEndpoint string `json:"token_endpoint"`
	// ResponseTypesSupported lists the supported OAuth2 response types.
	ResponseTypesSupported []string `json:"response_types_supported"`
}
68
69func generateURL(cfg *AuthCfg, path string) string {
70 return fmt.Sprintf("%s/%s", cfg.Domain, path)
71}
72
73func wellKnownHandler(w http.ResponseWriter, r *http.Request) {
74 client := getClient(r)
75
76 p := oauth2Server{
77 Issuer: client.Cfg.Issuer,
78 IntrospectionEndpoint: generateURL(client.Cfg, "introspect"),
79 IntrospectionEndpointAuthMethodsSupported: []string{
80 "none",
81 },
82 AuthorizationEndpoint: generateURL(client.Cfg, "authorize"),
83 TokenEndpoint: generateURL(client.Cfg, "token"),
84 ResponseTypesSupported: []string{"code"},
85 }
86 w.Header().Set("Content-Type", "application/json")
87 w.WriteHeader(http.StatusOK)
88 err := json.NewEncoder(w).Encode(p)
89 if err != nil {
90 client.Logger.Error(err.Error())
91 http.Error(w, err.Error(), http.StatusInternalServerError)
92 }
93}
94
// oauth2Introspection is the JSON body returned by introspectHandler for a
// token that resolves to a user.
type oauth2Introspection struct {
	// Active is set to true by introspectHandler when the token is known.
	Active bool `json:"active"`
	// Username is the name of the user that owns the token.
	Username string `json:"username"`
}
99
100func introspectHandler(w http.ResponseWriter, r *http.Request) {
101 client := getClient(r)
102 token := r.FormValue("token")
103 client.Logger.Info("introspect token", "token", token)
104
105 user, err := client.Dbpool.FindUserForToken(token)
106 if err != nil {
107 client.Logger.Error(err.Error())
108 http.Error(w, err.Error(), http.StatusUnauthorized)
109 return
110 }
111
112 p := oauth2Introspection{
113 Active: true,
114 Username: user.Name,
115 }
116 w.Header().Set("Content-Type", "application/json")
117 w.WriteHeader(http.StatusOK)
118 err = json.NewEncoder(w).Encode(p)
119 if err != nil {
120 client.Logger.Error(err.Error())
121 http.Error(w, err.Error(), http.StatusInternalServerError)
122 }
123}
124
125func authorizeHandler(w http.ResponseWriter, r *http.Request) {
126 client := getClient(r)
127
128 responseType := r.URL.Query().Get("response_type")
129 clientID := r.URL.Query().Get("client_id")
130 redirectURI := r.URL.Query().Get("redirect_uri")
131 scope := r.URL.Query().Get("scope")
132
133 client.Logger.Info(
134 "authorize handler",
135 "responseType", responseType,
136 "clientID", clientID,
137 "redirectURI", redirectURI,
138 "scope", scope,
139 )
140
141 ts, err := template.ParseFiles(
142 "auth/html/redirect.page.tmpl",
143 "auth/html/footer.partial.tmpl",
144 "auth/html/marketing-footer.partial.tmpl",
145 "auth/html/base.layout.tmpl",
146 )
147
148 if err != nil {
149 client.Logger.Error(err.Error())
150 http.Error(w, err.Error(), http.StatusUnauthorized)
151 return
152 }
153
154 err = ts.Execute(w, map[string]any{
155 "response_type": responseType,
156 "client_id": clientID,
157 "redirect_uri": redirectURI,
158 "scope": scope,
159 })
160
161 if err != nil {
162 client.Logger.Error(err.Error())
163 http.Error(w, err.Error(), http.StatusUnauthorized)
164 return
165 }
166}
167
168func redirectHandler(w http.ResponseWriter, r *http.Request) {
169 client := getClient(r)
170
171 token := r.FormValue("token")
172 redirectURI := r.FormValue("redirect_uri")
173 responseType := r.FormValue("response_type")
174
175 client.Logger.Info("redirect handler",
176 "token", token,
177 "redirectURI", redirectURI,
178 "responseType", responseType,
179 )
180
181 if token == "" || redirectURI == "" || responseType != "code" {
182 http.Error(w, "bad request", http.StatusBadRequest)
183 return
184 }
185
186 url, err := url.Parse(redirectURI)
187 if err != nil {
188 http.Error(w, err.Error(), http.StatusBadRequest)
189 return
190 }
191
192 urlQuery := url.Query()
193 urlQuery.Add("code", token)
194
195 url.RawQuery = urlQuery.Encode()
196
197 http.Redirect(w, r, url.String(), http.StatusFound)
198}
199
// oauth2Token is the JSON body returned by tokenHandler.
type oauth2Token struct {
	// AccessToken is the token handed back to the OAuth2 client.
	AccessToken string `json:"access_token"`
}
203
204func tokenHandler(w http.ResponseWriter, r *http.Request) {
205 client := getClient(r)
206
207 token := r.FormValue("code")
208 redirectURI := r.FormValue("redirect_uri")
209 grantType := r.FormValue("grant_type")
210
211 client.Logger.Info(
212 "handle token",
213 "token", token,
214 "redirectURI", redirectURI,
215 "grantType", grantType,
216 )
217
218 _, err := client.Dbpool.FindUserForToken(token)
219 if err != nil {
220 client.Logger.Error(err.Error())
221 http.Error(w, err.Error(), http.StatusUnauthorized)
222 return
223 }
224
225 p := oauth2Token{
226 AccessToken: token,
227 }
228 w.Header().Set("Content-Type", "application/json")
229 w.WriteHeader(http.StatusOK)
230 err = json.NewEncoder(w).Encode(p)
231 if err != nil {
232 client.Logger.Error(err.Error())
233 http.Error(w, err.Error(), http.StatusInternalServerError)
234 }
235}
236
// sishData is the JSON request body decoded by keyHandler and userHandler.
type sishData struct {
	// PublicKey is the public key presented by the connecting user.
	PublicKey string `json:"auth_key"`
	// Username is the name of the connecting user.
	Username string `json:"user"`
	// RemoteAddress is the remote address reported for the connection.
	RemoteAddress string `json:"remote_addr"`
}
242
243func keyHandler(w http.ResponseWriter, r *http.Request) {
244 client := getClient(r)
245
246 var data sishData
247
248 err := json.NewDecoder(r.Body).Decode(&data)
249 if err != nil {
250 client.Logger.Error(err.Error())
251 http.Error(w, err.Error(), http.StatusBadRequest)
252 return
253 }
254
255 space := r.URL.Query().Get("space")
256 if space == "" {
257 spaceErr := fmt.Errorf("Must provide `space` query parameter")
258 client.Logger.Error(spaceErr.Error())
259 http.Error(w, spaceErr.Error(), http.StatusUnprocessableEntity)
260 }
261
262 client.Logger.Info(
263 "handle key",
264 "remoteAddress", data.RemoteAddress,
265 "user", data.Username,
266 "space", space,
267 "publicKey", data.PublicKey,
268 )
269
270 user, err := client.Dbpool.FindUserForKey(data.Username, data.PublicKey)
271 if err != nil {
272 client.Logger.Error(err.Error())
273 http.Error(w, err.Error(), http.StatusUnauthorized)
274 return
275 }
276
277 if space == "tuns" {
278 if !client.Dbpool.HasFeatureForUser(user.ID, "plus") {
279 w.WriteHeader(http.StatusUnauthorized)
280 return
281 }
282 } else if !client.Dbpool.HasFeatureForUser(user.ID, space) {
283 w.WriteHeader(http.StatusUnauthorized)
284 return
285 }
286
287 if !client.hasPrivilegedAccess(getApiToken(r)) {
288 w.WriteHeader(http.StatusOK)
289 return
290 }
291
292 w.Header().Set("Content-Type", "application/json")
293 w.WriteHeader(http.StatusOK)
294 err = json.NewEncoder(w).Encode(user)
295 if err != nil {
296 client.Logger.Error(err.Error())
297 http.Error(w, err.Error(), http.StatusInternalServerError)
298 }
299}
300
301func userHandler(w http.ResponseWriter, r *http.Request) {
302 client := getClient(r)
303
304 if !client.hasPrivilegedAccess(getApiToken(r)) {
305 w.WriteHeader(http.StatusForbidden)
306 return
307 }
308
309 var data sishData
310
311 err := json.NewDecoder(r.Body).Decode(&data)
312 if err != nil {
313 client.Logger.Error(err.Error())
314 http.Error(w, err.Error(), http.StatusBadRequest)
315 return
316 }
317
318 client.Logger.Info(
319 "handle key",
320 "remoteAddress", data.RemoteAddress,
321 "user", data.Username,
322 "publicKey", data.PublicKey,
323 )
324
325 user, err := client.Dbpool.FindUserForName(data.Username)
326 if err != nil {
327 client.Logger.Error(err.Error())
328 http.Error(w, err.Error(), http.StatusNotFound)
329 return
330 }
331
332 keys, err := client.Dbpool.FindKeysForUser(user)
333 if err != nil {
334 client.Logger.Error(err.Error())
335 http.Error(w, err.Error(), http.StatusNotFound)
336 return
337 }
338
339 w.Header().Set("Content-Type", "application/json")
340 w.WriteHeader(http.StatusOK)
341 err = json.NewEncoder(w).Encode(keys)
342 if err != nil {
343 client.Logger.Error(err.Error())
344 http.Error(w, err.Error(), http.StatusInternalServerError)
345 }
346}
347
348func genFeedItem(now time.Time, expiresAt time.Time, warning time.Time, txt string) *feeds.Item {
349 if now.After(warning) {
350 content := fmt.Sprintf(
351 "Your pico+ membership is going to expire on %s",
352 expiresAt.Format("2006-01-02 15:04:05"),
353 )
354 return &feeds.Item{
355 Id: fmt.Sprintf("%d", warning.Unix()),
356 Title: fmt.Sprintf("pico+ %s expiration notice", txt),
357 Link: &feeds.Link{Href: "https://pico.sh"},
358 Content: content,
359 Created: warning,
360 Updated: warning,
361 Description: content,
362 Author: &feeds.Author{Name: "team pico"},
363 }
364 }
365
366 return nil
367}
368
369func rssHandler(w http.ResponseWriter, r *http.Request) {
370 client := getClient(r)
371 apiToken, err := url.PathUnescape(getField(r, 0))
372 if err != nil {
373 client.Logger.Error(err.Error())
374 http.Error(w, err.Error(), http.StatusNotFound)
375 return
376 }
377 user, err := client.Dbpool.FindUserForToken(apiToken)
378 if err != nil {
379 client.Logger.Error(err.Error())
380 http.Error(w, "invalid token", http.StatusNotFound)
381 return
382 }
383
384 href := fmt.Sprintf("https://auth.pico.sh/rss/%s", apiToken)
385
386 feed := &feeds.Feed{
387 Title: "pico+",
388 Link: &feeds.Link{Href: href},
389 Description: "get notified of important membership updates",
390 Author: &feeds.Author{Name: "team pico"},
391 Created: time.Now(),
392 }
393 var feedItems []*feeds.Item
394
395 now := time.Now()
396 ff, err := client.Dbpool.FindFeatureForUser(user.ID, "plus")
397 if err != nil {
398 // still want to send an empty feed
399 } else {
400 createdAt := ff.CreatedAt
401 createdAtStr := createdAt.Format("2006-01-02 15:04:05")
402 id := fmt.Sprintf("pico-plus-activated-%d", createdAt.Unix())
403 content := `Thanks for joining pico+! You now have access to all our premium services for exactly one year. We will send you pico+ expiration notifications through this RSS feed. Go to <a href="https://pico.sh/getting-started#next-steps">pico.sh/getting-started#next-steps</a> to start using our services.`
404 plus := &feeds.Item{
405 Id: id,
406 Title: fmt.Sprintf("pico+ membership activated on %s", createdAtStr),
407 Link: &feeds.Link{Href: "https://pico.sh"},
408 Content: content,
409 Created: *createdAt,
410 Updated: *createdAt,
411 Description: content,
412 Author: &feeds.Author{Name: "team pico"},
413 }
414 feedItems = append(feedItems, plus)
415
416 oneMonthWarning := ff.ExpiresAt.AddDate(0, -1, 0)
417 mo := genFeedItem(now, *ff.ExpiresAt, oneMonthWarning, "1-month")
418 if mo != nil {
419 feedItems = append(feedItems, mo)
420 }
421
422 oneWeekWarning := ff.ExpiresAt.AddDate(0, 0, -7)
423 wk := genFeedItem(now, *ff.ExpiresAt, oneWeekWarning, "1-week")
424 if wk != nil {
425 feedItems = append(feedItems, wk)
426 }
427
428 oneDayWarning := ff.ExpiresAt.AddDate(0, 0, -1)
429 day := genFeedItem(now, *ff.ExpiresAt, oneDayWarning, "1-day")
430 if day != nil {
431 feedItems = append(feedItems, day)
432 }
433 }
434
435 feed.Items = feedItems
436
437 rss, err := feed.ToAtom()
438 if err != nil {
439 client.Logger.Error(err.Error())
440 http.Error(w, "Could not generate atom rss feed", http.StatusInternalServerError)
441 }
442
443 w.Header().Add("Content-Type", "application/atom+xml")
444 _, err = w.Write([]byte(rss))
445 if err != nil {
446 client.Logger.Error(err.Error())
447 }
448}
449
// OrderEvent is the JSON payload of the payment webhook as decoded by
// paymentWebhookHandler. The nested objects are pointers so that a missing
// section decodes to nil and can be detected.
type OrderEvent struct {
	// Meta describes the event itself.
	Meta *struct {
		// EventName is the webhook event type.
		EventName string `json:"event_name"`
		// CustomData carries the custom fields attached at checkout.
		CustomData *struct {
			// PicoUsername is the pico user the order is for.
			PicoUsername string `json:"username"`
		} `json:"custom_data"`
	} `json:"meta"`
	// Data is the object the event refers to.
	Data *struct {
		Type string `json:"type"`
		ID   string `json:"id"`
		// Attr holds the order attributes.
		Attr *struct {
			OrderNumber int       `json:"order_number"`
			Identifier  string    `json:"identifier"`
			UserName    string    `json:"user_name"`
			UserEmail   string    `json:"user_email"`
			CreatedAt   time.Time `json:"created_at"`
			Status      string    `json:"status"` // `paid`, `refund`
		} `json:"attributes"`
	} `json:"data"`
}
470
471// Status code must be 200 or else lemonsqueezy will keep retrying
472// https://docs.lemonsqueezy.com/help/webhooks
473func paymentWebhookHandler(secret string) func(http.ResponseWriter, *http.Request) {
474 return func(w http.ResponseWriter, r *http.Request) {
475 client := getClient(r)
476 dbpool := client.Dbpool
477 logger := client.Logger
478 const MaxBodyBytes = int64(65536)
479 r.Body = http.MaxBytesReader(w, r.Body, MaxBodyBytes)
480 payload, err := io.ReadAll(r.Body)
481 if err != nil {
482 logger.Error("error reading request body", "err", err.Error())
483 w.WriteHeader(http.StatusOK)
484 return
485 }
486
487 event := OrderEvent{}
488
489 if err := json.Unmarshal(payload, &event); err != nil {
490 logger.Error("failed to parse webhook body JSON", "err", err.Error())
491 w.WriteHeader(http.StatusOK)
492 return
493 }
494
495 hash := shared.HmacString(secret, string(payload))
496 sig := r.Header.Get("X-Signature")
497 if !hmac.Equal([]byte(hash), []byte(sig)) {
498 logger.Error("invalid signature X-Signature")
499 w.WriteHeader(http.StatusOK)
500 return
501 }
502
503 if event.Meta == nil {
504 logger.Error("no meta field found")
505 w.WriteHeader(http.StatusOK)
506 return
507 }
508
509 if event.Meta.EventName != "order_created" {
510 logger.Error("event not order_created", "event", event.Meta.EventName)
511 w.WriteHeader(http.StatusOK)
512 return
513 }
514
515 if event.Meta.CustomData == nil {
516 logger.Error("no custom data found")
517 w.WriteHeader(http.StatusOK)
518 return
519 }
520
521 username := event.Meta.CustomData.PicoUsername
522
523 if event.Data == nil || event.Data.Attr == nil {
524 logger.Error("no data or data.attributes fields found")
525 w.WriteHeader(http.StatusOK)
526 return
527 }
528
529 email := event.Data.Attr.UserEmail
530 created := event.Data.Attr.CreatedAt
531 status := event.Data.Attr.Status
532 txID := fmt.Sprint(event.Data.Attr.OrderNumber)
533
534 log := logger.With(
535 "username", username,
536 "email", email,
537 "created", created,
538 "paymentStatus", status,
539 "txId", txID,
540 )
541 log.Info(
542 "order_created event",
543 )
544
545 // https://checkout.pico.sh/buy/35b1be57-1e25-487f-84dd-5f09bb8783ec?discount=0&checkout[custom][username]=erock
546 if username == "" {
547 log.Error("no `?checkout[custom][username]=xxx` found in URL, cannot add pico+ membership")
548 w.WriteHeader(http.StatusOK)
549 return
550 }
551
552 if status != "paid" {
553 log.Error("status not paid")
554 w.WriteHeader(http.StatusOK)
555 return
556 }
557
558 err = dbpool.AddPicoPlusUser(username, "lemonsqueezy", txID)
559 if err != nil {
560 log.Error("failed to add pico+ user", "err", err)
561 } else {
562 log.Info("successfully added pico+ user")
563 }
564
565 w.WriteHeader(http.StatusOK)
566 }
567}
568
569// URL shortener for out pico+ URL.
570func checkoutHandler(w http.ResponseWriter, r *http.Request) {
571 username, err := url.PathUnescape(getField(r, 0))
572 if err != nil {
573 w.WriteHeader(http.StatusUnprocessableEntity)
574 return
575 }
576 link := "https://checkout.pico.sh/buy/73c26cf9-3fac-44c3-b744-298b3032a96b"
577 url := fmt.Sprintf(
578 "%s?discount=0&checkout[custom][username]=%s",
579 link,
580 username,
581 )
582 http.Redirect(w, r, url, http.StatusMovedPermanently)
583}
584
585func createMainRoutes() []shared.Route {
586 fileServer := http.FileServer(http.Dir("auth/public"))
587 secret := os.Getenv("PICO_SECRET_WEBHOOK")
588 if secret == "" {
589 panic("must provide PICO_SECRET_WEBHOOK environment variable")
590 }
591
592 routes := []shared.Route{
593 shared.NewRoute("GET", "/checkout/(.+)", checkoutHandler),
594 shared.NewRoute("GET", "/.well-known/oauth-authorization-server", wellKnownHandler),
595 shared.NewRoute("POST", "/introspect", introspectHandler),
596 shared.NewRoute("GET", "/authorize", authorizeHandler),
597 shared.NewRoute("POST", "/token", tokenHandler),
598 shared.NewRoute("POST", "/key", keyHandler),
599 shared.NewRoute("POST", "/user", userHandler),
600 shared.NewRoute("GET", "/rss/(.+)", rssHandler),
601 shared.NewRoute("POST", "/redirect", redirectHandler),
602 shared.NewRoute("POST", "/webhook", paymentWebhookHandler(secret)),
603 shared.NewRoute("GET", "/main.css", fileServer.ServeHTTP),
604 shared.NewRoute("GET", "/card.png", fileServer.ServeHTTP),
605 shared.NewRoute("GET", "/favicon-16x16.png", fileServer.ServeHTTP),
606 shared.NewRoute("GET", "/favicon-32x32.png", fileServer.ServeHTTP),
607 shared.NewRoute("GET", "/apple-touch-icon.png", fileServer.ServeHTTP),
608 shared.NewRoute("GET", "/favicon.ico", fileServer.ServeHTTP),
609 shared.NewRoute("GET", "/robots.txt", fileServer.ServeHTTP),
610 }
611
612 return routes
613}
614
615func handler(routes []shared.Route, client *Client) shared.ServeFn {
616 return func(w http.ResponseWriter, r *http.Request) {
617 var allow []string
618
619 for _, route := range routes {
620 matches := route.Regex.FindStringSubmatch(r.URL.Path)
621 if len(matches) > 0 {
622 if r.Method != route.Method {
623 allow = append(allow, route.Method)
624 continue
625 }
626 clientCtx := context.WithValue(r.Context(), ctxClient{}, client)
627 ctx := context.WithValue(clientCtx, ctxKey{}, matches[1:])
628 route.Handler(w, r.WithContext(ctx))
629 return
630 }
631 }
632 if len(allow) > 0 {
633 w.Header().Set("Allow", strings.Join(allow, ", "))
634 http.Error(w, "405 method not allowed", http.StatusMethodNotAllowed)
635 return
636 }
637 http.NotFound(w, r)
638 }
639}
640
// AuthCfg is the runtime configuration of the auth API server.
type AuthCfg struct {
	// Debug enables the pprof routes.
	Debug bool
	// Port is the TCP port the HTTP server listens on.
	Port string
	// DbURL is the database connection string.
	DbURL string
	// Domain is the base URL used to build the advertised endpoint URLs.
	Domain string
	// Issuer is the issuer reported in the OAuth2 server metadata.
	Issuer string
}
648
649func StartApiServer() {
650 debug := shared.GetEnv("AUTH_DEBUG", "0")
651 cfg := &AuthCfg{
652 DbURL: shared.GetEnv("DATABASE_URL", ""),
653 Debug: debug == "1",
654 Issuer: shared.GetEnv("AUTH_ISSUER", "pico.sh"),
655 Domain: shared.GetEnv("AUTH_DOMAIN", "http://0.0.0.0:3000"),
656 Port: shared.GetEnv("AUTH_WEB_PORT", "3000"),
657 }
658
659 logger := shared.CreateLogger("auth")
660 db := postgres.NewDB(cfg.DbURL, logger)
661 defer db.Close()
662
663 client := &Client{
664 Cfg: cfg,
665 Dbpool: db,
666 Logger: logger,
667 }
668
669 routes := createMainRoutes()
670
671 if cfg.Debug {
672 routes = shared.CreatePProfRoutes(routes)
673 }
674
675 router := http.HandlerFunc(handler(routes, client))
676
677 portStr := fmt.Sprintf(":%s", cfg.Port)
678 client.Logger.Info("starting server on port", "port", cfg.Port)
679 err := http.ListenAndServe(portStr, router)
680 if err != nil {
681 client.Logger.Info(err.Error())
682 }
683}