repos / pico

pico services - prose.sh, pastes.sh, imgs.sh, feeds.sh, pgs.sh
git clone https://github.com/picosh/pico.git

pico / auth
Eric Bower · 20 Sep 24

api.go

  1package auth
  2
  3import (
  4	"context"
  5	"crypto/hmac"
  6	"encoding/json"
  7	"fmt"
  8	"html/template"
  9	"io"
 10	"log/slog"
 11	"net/http"
 12	"net/url"
 13	"os"
 14	"strings"
 15	"time"
 16
 17	"github.com/gorilla/feeds"
 18	"github.com/picosh/pico/db"
 19	"github.com/picosh/pico/db/postgres"
 20	"github.com/picosh/pico/shared"
 21)
 22
 23type Client struct {
 24	Cfg    *AuthCfg
 25	Dbpool db.DB
 26	Logger *slog.Logger
 27}
 28
 29func (client *Client) hasPrivilegedAccess(apiToken string) bool {
 30	user, err := client.Dbpool.FindUserForToken(apiToken)
 31	if err != nil {
 32		return false
 33	}
 34	return client.Dbpool.HasFeatureForUser(user.ID, "auth")
 35}
 36
 37type ctxClient struct{}
 38type ctxKey struct{}
 39
 40func getClient(r *http.Request) *Client {
 41	return r.Context().Value(ctxClient{}).(*Client)
 42}
 43
 44func getField(r *http.Request, index int) string {
 45	fields := r.Context().Value(ctxKey{}).([]string)
 46	if index >= len(fields) {
 47		return ""
 48	}
 49	return fields[index]
 50}
 51
 52func getApiToken(r *http.Request) string {
 53	authHeader := r.Header.Get("authorization")
 54	if authHeader == "" {
 55		return ""
 56	}
 57	return strings.TrimPrefix(authHeader, "Bearer ")
 58}
 59
 60type oauth2Server struct {
 61	Issuer                                    string   `json:"issuer"`
 62	IntrospectionEndpoint                     string   `json:"introspection_endpoint"`
 63	IntrospectionEndpointAuthMethodsSupported []string `json:"introspection_endpoint_auth_methods_supported"`
 64	AuthorizationEndpoint                     string   `json:"authorization_endpoint"`
 65	TokenEndpoint                             string   `json:"token_endpoint"`
 66	ResponseTypesSupported                    []string `json:"response_types_supported"`
 67}
 68
 69func generateURL(cfg *AuthCfg, path string) string {
 70	return fmt.Sprintf("%s/%s", cfg.Domain, path)
 71}
 72
 73func wellKnownHandler(w http.ResponseWriter, r *http.Request) {
 74	client := getClient(r)
 75
 76	p := oauth2Server{
 77		Issuer:                client.Cfg.Issuer,
 78		IntrospectionEndpoint: generateURL(client.Cfg, "introspect"),
 79		IntrospectionEndpointAuthMethodsSupported: []string{
 80			"none",
 81		},
 82		AuthorizationEndpoint:  generateURL(client.Cfg, "authorize"),
 83		TokenEndpoint:          generateURL(client.Cfg, "token"),
 84		ResponseTypesSupported: []string{"code"},
 85	}
 86	w.Header().Set("Content-Type", "application/json")
 87	w.WriteHeader(http.StatusOK)
 88	err := json.NewEncoder(w).Encode(p)
 89	if err != nil {
 90		client.Logger.Error(err.Error())
 91		http.Error(w, err.Error(), http.StatusInternalServerError)
 92	}
 93}
 94
 95type oauth2Introspection struct {
 96	Active   bool   `json:"active"`
 97	Username string `json:"username"`
 98}
 99
100func introspectHandler(w http.ResponseWriter, r *http.Request) {
101	client := getClient(r)
102	token := r.FormValue("token")
103	client.Logger.Info("introspect token", "token", token)
104
105	user, err := client.Dbpool.FindUserForToken(token)
106	if err != nil {
107		client.Logger.Error(err.Error())
108		http.Error(w, err.Error(), http.StatusUnauthorized)
109		return
110	}
111
112	p := oauth2Introspection{
113		Active:   true,
114		Username: user.Name,
115	}
116	w.Header().Set("Content-Type", "application/json")
117	w.WriteHeader(http.StatusOK)
118	err = json.NewEncoder(w).Encode(p)
119	if err != nil {
120		client.Logger.Error(err.Error())
121		http.Error(w, err.Error(), http.StatusInternalServerError)
122	}
123}
124
125func authorizeHandler(w http.ResponseWriter, r *http.Request) {
126	client := getClient(r)
127
128	responseType := r.URL.Query().Get("response_type")
129	clientID := r.URL.Query().Get("client_id")
130	redirectURI := r.URL.Query().Get("redirect_uri")
131	scope := r.URL.Query().Get("scope")
132
133	client.Logger.Info(
134		"authorize handler",
135		"responseType", responseType,
136		"clientID", clientID,
137		"redirectURI", redirectURI,
138		"scope", scope,
139	)
140
141	ts, err := template.ParseFiles(
142		"auth/html/redirect.page.tmpl",
143		"auth/html/footer.partial.tmpl",
144		"auth/html/marketing-footer.partial.tmpl",
145		"auth/html/base.layout.tmpl",
146	)
147
148	if err != nil {
149		client.Logger.Error(err.Error())
150		http.Error(w, err.Error(), http.StatusUnauthorized)
151		return
152	}
153
154	err = ts.Execute(w, map[string]any{
155		"response_type": responseType,
156		"client_id":     clientID,
157		"redirect_uri":  redirectURI,
158		"scope":         scope,
159	})
160
161	if err != nil {
162		client.Logger.Error(err.Error())
163		http.Error(w, err.Error(), http.StatusUnauthorized)
164		return
165	}
166}
167
// redirectHandler completes the authorization step. It takes the "token" form
// value and redirects (302) to redirect_uri with that token appended as the
// OAuth "code" query parameter. It responds 400 when token or redirect_uri is
// missing, when response_type is not "code", or when redirect_uri does not
// parse.
//
// NOTE(review): redirect_uri is not checked against any registered client
// URIs, so the token can be delivered to an arbitrary host (open redirect).
// Validate it once client registration exists.
func redirectHandler(w http.ResponseWriter, r *http.Request) {
	client := getClient(r)

	token := r.FormValue("token")
	redirectURI := r.FormValue("redirect_uri")
	responseType := r.FormValue("response_type")

	// The token is a bearer credential, so log only whether one was supplied.
	client.Logger.Info("redirect handler",
		"hasToken", token != "",
		"redirectURI", redirectURI,
		"responseType", responseType,
	)

	if token == "" || redirectURI == "" || responseType != "code" {
		http.Error(w, "bad request", http.StatusBadRequest)
		return
	}

	// Named target (not url) to avoid shadowing the net/url package.
	target, err := url.Parse(redirectURI)
	if err != nil {
		http.Error(w, err.Error(), http.StatusBadRequest)
		return
	}

	query := target.Query()
	query.Add("code", token)
	target.RawQuery = query.Encode()

	http.Redirect(w, r, target.String(), http.StatusFound)
}
199
200type oauth2Token struct {
201	AccessToken string `json:"access_token"`
202}
203
204func tokenHandler(w http.ResponseWriter, r *http.Request) {
205	client := getClient(r)
206
207	token := r.FormValue("code")
208	redirectURI := r.FormValue("redirect_uri")
209	grantType := r.FormValue("grant_type")
210
211	client.Logger.Info(
212		"handle token",
213		"token", token,
214		"redirectURI", redirectURI,
215		"grantType", grantType,
216	)
217
218	_, err := client.Dbpool.FindUserForToken(token)
219	if err != nil {
220		client.Logger.Error(err.Error())
221		http.Error(w, err.Error(), http.StatusUnauthorized)
222		return
223	}
224
225	p := oauth2Token{
226		AccessToken: token,
227	}
228	w.Header().Set("Content-Type", "application/json")
229	w.WriteHeader(http.StatusOK)
230	err = json.NewEncoder(w).Encode(p)
231	if err != nil {
232		client.Logger.Error(err.Error())
233		http.Error(w, err.Error(), http.StatusInternalServerError)
234	}
235}
236
237type sishData struct {
238	PublicKey     string `json:"auth_key"`
239	Username      string `json:"user"`
240	RemoteAddress string `json:"remote_addr"`
241}
242
243func keyHandler(w http.ResponseWriter, r *http.Request) {
244	client := getClient(r)
245
246	var data sishData
247
248	err := json.NewDecoder(r.Body).Decode(&data)
249	if err != nil {
250		client.Logger.Error(err.Error())
251		http.Error(w, err.Error(), http.StatusBadRequest)
252		return
253	}
254
255	space := r.URL.Query().Get("space")
256	if space == "" {
257		spaceErr := fmt.Errorf("Must provide `space` query parameter")
258		client.Logger.Error(spaceErr.Error())
259		http.Error(w, spaceErr.Error(), http.StatusUnprocessableEntity)
260	}
261
262	client.Logger.Info(
263		"handle key",
264		"remoteAddress", data.RemoteAddress,
265		"user", data.Username,
266		"space", space,
267		"publicKey", data.PublicKey,
268	)
269
270	user, err := client.Dbpool.FindUserForKey(data.Username, data.PublicKey)
271	if err != nil {
272		client.Logger.Error(err.Error())
273		http.Error(w, err.Error(), http.StatusUnauthorized)
274		return
275	}
276
277	if space == "tuns" {
278		if !client.Dbpool.HasFeatureForUser(user.ID, "plus") {
279			w.WriteHeader(http.StatusUnauthorized)
280			return
281		}
282	} else if !client.Dbpool.HasFeatureForUser(user.ID, space) {
283		w.WriteHeader(http.StatusUnauthorized)
284		return
285	}
286
287	if !client.hasPrivilegedAccess(getApiToken(r)) {
288		w.WriteHeader(http.StatusOK)
289		return
290	}
291
292	w.Header().Set("Content-Type", "application/json")
293	w.WriteHeader(http.StatusOK)
294	err = json.NewEncoder(w).Encode(user)
295	if err != nil {
296		client.Logger.Error(err.Error())
297		http.Error(w, err.Error(), http.StatusInternalServerError)
298	}
299}
300
// userHandler returns, as JSON, the keys registered to a user.
//
// It is restricted to callers whose bearer token has privileged access (the
// "auth" feature); everyone else gets 403. The username is taken from the
// "user" field of a sishData JSON body (400 if the body does not decode). An
// unknown user or a failed key lookup yields 404.
func userHandler(w http.ResponseWriter, r *http.Request) {
	client := getClient(r)

	if !client.hasPrivilegedAccess(getApiToken(r)) {
		w.WriteHeader(http.StatusForbidden)
		return
	}

	var data sishData
	if err := json.NewDecoder(r.Body).Decode(&data); err != nil {
		client.Logger.Error(err.Error())
		http.Error(w, err.Error(), http.StatusBadRequest)
		return
	}

	client.Logger.Info(
		"handle user",
		"remoteAddress", data.RemoteAddress,
		"user", data.Username,
		"publicKey", data.PublicKey,
	)

	user, err := client.Dbpool.FindUserForName(data.Username)
	if err != nil {
		client.Logger.Error(err.Error())
		http.Error(w, err.Error(), http.StatusNotFound)
		return
	}

	keys, err := client.Dbpool.FindKeysForUser(user)
	if err != nil {
		client.Logger.Error(err.Error())
		http.Error(w, err.Error(), http.StatusNotFound)
		return
	}

	// Encode before committing to a 200 so an encoding failure can still be
	// reported as a 500.
	body, err := json.Marshal(keys)
	if err != nil {
		client.Logger.Error(err.Error())
		http.Error(w, err.Error(), http.StatusInternalServerError)
		return
	}
	w.Header().Set("Content-Type", "application/json")
	w.WriteHeader(http.StatusOK)
	// The trailing newline matches json.Encoder's output format.
	if _, err := w.Write(append(body, '\n')); err != nil {
		client.Logger.Error(err.Error())
	}
}
347
// genFeedItem builds the pico+ expiration notice for a single warning
// threshold; txt labels the threshold (e.g. "1-month"). It returns nil while
// now is not yet after warning. Once the threshold has passed, the item is
// dated at the warning time and its ID is the warning's Unix timestamp, so
// repeated feed renders produce the same entry.
func genFeedItem(now time.Time, expiresAt time.Time, warning time.Time, txt string) *feeds.Item {
	if !now.After(warning) {
		return nil
	}

	expiry := expiresAt.Format("2006-01-02 15:04:05")
	notice := fmt.Sprintf("Your pico+ membership is going to expire on %s", expiry)

	item := &feeds.Item{
		Id:      fmt.Sprintf("%d", warning.Unix()),
		Title:   fmt.Sprintf("pico+ %s expiration notice", txt),
		Link:    &feeds.Link{Href: "https://pico.sh"},
		Author:  &feeds.Author{Name: "team pico"},
		Created: warning,
		Updated: warning,
	}
	item.Content = notice
	item.Description = notice
	return item
}
368
// rssHandler serves a per-user Atom feed of pico+ membership notices at
// /rss/{apiToken}.
//
// A path that fails to unescape, or a token that matches no user, yields 404.
// Users with the "plus" feature get:
//   - an activation entry, and
//   - 1-month, 1-week and 1-day expiration warnings once each threshold has
//     passed (see genFeedItem).
//
// If the "plus" lookup fails, a valid but empty feed is still served.
func rssHandler(w http.ResponseWriter, r *http.Request) {
	client := getClient(r)
	apiToken, err := url.PathUnescape(getField(r, 0))
	if err != nil {
		client.Logger.Error(err.Error())
		http.Error(w, err.Error(), http.StatusNotFound)
		return
	}
	user, err := client.Dbpool.FindUserForToken(apiToken)
	if err != nil {
		client.Logger.Error(err.Error())
		http.Error(w, "invalid token", http.StatusNotFound)
		return
	}

	now := time.Now()
	href := fmt.Sprintf("https://auth.pico.sh/rss/%s", apiToken)

	feed := &feeds.Feed{
		Title:       "pico+",
		Link:        &feeds.Link{Href: href},
		Description: "get notified of important membership updates",
		Author:      &feeds.Author{Name: "team pico"},
		Created:     now,
	}
	var feedItems []*feeds.Item

	// A lookup error is deliberately not an HTTP error: the feed is still
	// served, just without items.
	if ff, err := client.Dbpool.FindFeatureForUser(user.ID, "plus"); err == nil {
		if ff.CreatedAt != nil {
			createdAt := *ff.CreatedAt
			createdAtStr := createdAt.Format("2006-01-02 15:04:05")
			content := `Thanks for joining pico+! You now have access to all our premium services for exactly one year.  We will send you pico+ expiration notifications through this RSS feed.  Go to <a href="https://pico.sh/getting-started#next-steps">pico.sh/getting-started#next-steps</a> to start using our services.`
			feedItems = append(feedItems, &feeds.Item{
				Id:          fmt.Sprintf("pico-plus-activated-%d", createdAt.Unix()),
				Title:       fmt.Sprintf("pico+ membership activated on %s", createdAtStr),
				Link:        &feeds.Link{Href: "https://pico.sh"},
				Content:     content,
				Created:     createdAt,
				Updated:     createdAt,
				Description: content,
				Author:      &feeds.Author{Name: "team pico"},
			})
		}

		if ff.ExpiresAt != nil {
			expiresAt := *ff.ExpiresAt
			warnings := []struct {
				at    time.Time
				label string
			}{
				{expiresAt.AddDate(0, -1, 0), "1-month"},
				{expiresAt.AddDate(0, 0, -7), "1-week"},
				{expiresAt.AddDate(0, 0, -1), "1-day"},
			}
			for _, wn := range warnings {
				if item := genFeedItem(now, expiresAt, wn.at, wn.label); item != nil {
					feedItems = append(feedItems, item)
				}
			}
		}
	}

	feed.Items = feedItems

	rss, err := feed.ToAtom()
	if err != nil {
		client.Logger.Error(err.Error())
		http.Error(w, "Could not generate atom rss feed", http.StatusInternalServerError)
		// Without this return the handler would go on to write an empty feed
		// after the error response.
		return
	}

	w.Header().Set("Content-Type", "application/atom+xml")
	if _, err := w.Write([]byte(rss)); err != nil {
		client.Logger.Error(err.Error())
	}
}
449
450type OrderEvent struct {
451	Meta *struct {
452		EventName  string `json:"event_name"`
453		CustomData *struct {
454			PicoUsername string `json:"username"`
455		} `json:"custom_data"`
456	} `json:"meta"`
457	Data *struct {
458		Type string `json:"type"`
459		ID   string `json:"id"`
460		Attr *struct {
461			OrderNumber int       `json:"order_number"`
462			Identifier  string    `json:"identifier"`
463			UserName    string    `json:"user_name"`
464			UserEmail   string    `json:"user_email"`
465			CreatedAt   time.Time `json:"created_at"`
466			Status      string    `json:"status"` // `paid`, `refund`
467		} `json:"attributes"`
468	} `json:"data"`
469}
470
471// Status code must be 200 or else lemonsqueezy will keep retrying
472// https://docs.lemonsqueezy.com/help/webhooks
473func paymentWebhookHandler(secret string) func(http.ResponseWriter, *http.Request) {
474	return func(w http.ResponseWriter, r *http.Request) {
475		client := getClient(r)
476		dbpool := client.Dbpool
477		logger := client.Logger
478		const MaxBodyBytes = int64(65536)
479		r.Body = http.MaxBytesReader(w, r.Body, MaxBodyBytes)
480		payload, err := io.ReadAll(r.Body)
481		if err != nil {
482			logger.Error("error reading request body", "err", err.Error())
483			w.WriteHeader(http.StatusOK)
484			return
485		}
486
487		event := OrderEvent{}
488
489		if err := json.Unmarshal(payload, &event); err != nil {
490			logger.Error("failed to parse webhook body JSON", "err", err.Error())
491			w.WriteHeader(http.StatusOK)
492			return
493		}
494
495		hash := shared.HmacString(secret, string(payload))
496		sig := r.Header.Get("X-Signature")
497		if !hmac.Equal([]byte(hash), []byte(sig)) {
498			logger.Error("invalid signature X-Signature")
499			w.WriteHeader(http.StatusOK)
500			return
501		}
502
503		if event.Meta == nil {
504			logger.Error("no meta field found")
505			w.WriteHeader(http.StatusOK)
506			return
507		}
508
509		if event.Meta.EventName != "order_created" {
510			logger.Error("event not order_created", "event", event.Meta.EventName)
511			w.WriteHeader(http.StatusOK)
512			return
513		}
514
515		if event.Meta.CustomData == nil {
516			logger.Error("no custom data found")
517			w.WriteHeader(http.StatusOK)
518			return
519		}
520
521		username := event.Meta.CustomData.PicoUsername
522
523		if event.Data == nil || event.Data.Attr == nil {
524			logger.Error("no data or data.attributes fields found")
525			w.WriteHeader(http.StatusOK)
526			return
527		}
528
529		email := event.Data.Attr.UserEmail
530		created := event.Data.Attr.CreatedAt
531		status := event.Data.Attr.Status
532		txID := fmt.Sprint(event.Data.Attr.OrderNumber)
533
534		log := logger.With(
535			"username", username,
536			"email", email,
537			"created", created,
538			"paymentStatus", status,
539			"txId", txID,
540		)
541		log.Info(
542			"order_created event",
543		)
544
545		// https://checkout.pico.sh/buy/35b1be57-1e25-487f-84dd-5f09bb8783ec?discount=0&checkout[custom][username]=erock
546		if username == "" {
547			log.Error("no `?checkout[custom][username]=xxx` found in URL, cannot add pico+ membership")
548			w.WriteHeader(http.StatusOK)
549			return
550		}
551
552		if status != "paid" {
553			log.Error("status not paid")
554			w.WriteHeader(http.StatusOK)
555			return
556		}
557
558		err = dbpool.AddPicoPlusUser(username, "lemonsqueezy", txID)
559		if err != nil {
560			log.Error("failed to add pico+ user", "err", err)
561		} else {
562			log.Info("successfully added pico+ user")
563		}
564
565		w.WriteHeader(http.StatusOK)
566	}
567}
568
// checkoutHandler is a URL shortener for our pico+ checkout link. It
// redirects (301) /checkout/{username} to the checkout page with the username
// passed as checkout[custom][username]. That is the value the payment webhook
// reads to credit the account. A path segment that fails to unescape yields
// 422.
func checkoutHandler(w http.ResponseWriter, r *http.Request) {
	username, err := url.PathUnescape(getField(r, 0))
	if err != nil {
		w.WriteHeader(http.StatusUnprocessableEntity)
		return
	}
	link := "https://checkout.pico.sh/buy/73c26cf9-3fac-44c3-b744-298b3032a96b"
	// The username comes straight from the request path. Unescaped, a value
	// such as "bob&discount=X" or "bob#x" would inject extra query parameters
	// or a fragment into the checkout URL.
	target := fmt.Sprintf(
		"%s?discount=0&checkout[custom][username]=%s",
		link,
		url.QueryEscape(username),
	)
	http.Redirect(w, r, target, http.StatusMovedPermanently)
}
584
// createMainRoutes returns the auth service's route table.
//
// handler tries routes in slice order, so ordering is significant. The
// function panics when PICO_SECRET_WEBHOOK is unset, because that secret is
// what paymentWebhookHandler verifies signatures with.
func createMainRoutes() []shared.Route {
	secret := os.Getenv("PICO_SECRET_WEBHOOK")
	if secret == "" {
		panic("must provide PICO_SECRET_WEBHOOK environment variable")
	}

	routes := []shared.Route{
		shared.NewRoute("GET", "/checkout/(.+)", checkoutHandler),
		shared.NewRoute("GET", "/.well-known/oauth-authorization-server", wellKnownHandler),
		shared.NewRoute("POST", "/introspect", introspectHandler),
		shared.NewRoute("GET", "/authorize", authorizeHandler),
		shared.NewRoute("POST", "/token", tokenHandler),
		shared.NewRoute("POST", "/key", keyHandler),
		shared.NewRoute("POST", "/user", userHandler),
		shared.NewRoute("GET", "/rss/(.+)", rssHandler),
		shared.NewRoute("POST", "/redirect", redirectHandler),
		shared.NewRoute("POST", "/webhook", paymentWebhookHandler(secret)),
	}

	// Static assets, all served by one file server rooted at auth/public.
	assets := http.FileServer(http.Dir("auth/public"))
	staticPaths := []string{
		"/main.css",
		"/card.png",
		"/favicon-16x16.png",
		"/favicon-32x32.png",
		"/apple-touch-icon.png",
		"/favicon.ico",
		"/robots.txt",
	}
	for _, p := range staticPaths {
		routes = append(routes, shared.NewRoute("GET", p, assets.ServeHTTP))
	}

	return routes
}
614
// handler builds the auth service's request dispatcher.
//
// Routes are scanned in order. The first route whose regexp matches the URL
// path and whose method equals the request method is invoked. Its request
// context carries the Client (ctxClient) and the regexp capture groups
// (ctxKey).
//
// When the path matched some routes but none with the request's method, the
// response is 405 with an Allow header listing those routes' methods.
// Otherwise it is 404.
func handler(routes []shared.Route, client *Client) shared.ServeFn {
	return func(w http.ResponseWriter, r *http.Request) {
		var allowed []string

		for _, rt := range routes {
			groups := rt.Regex.FindStringSubmatch(r.URL.Path)
			if groups == nil {
				continue
			}
			if rt.Method != r.Method {
				allowed = append(allowed, rt.Method)
				continue
			}
			ctx := context.WithValue(r.Context(), ctxClient{}, client)
			ctx = context.WithValue(ctx, ctxKey{}, groups[1:])
			rt.Handler(w, r.WithContext(ctx))
			return
		}

		if len(allowed) == 0 {
			http.NotFound(w, r)
			return
		}
		w.Header().Set("Allow", strings.Join(allowed, ", "))
		http.Error(w, "405 method not allowed", http.StatusMethodNotAllowed)
	}
}
640
641type AuthCfg struct {
642	Debug  bool
643	Port   string
644	DbURL  string
645	Domain string
646	Issuer string
647}
648
649func StartApiServer() {
650	debug := shared.GetEnv("AUTH_DEBUG", "0")
651	cfg := &AuthCfg{
652		DbURL:  shared.GetEnv("DATABASE_URL", ""),
653		Debug:  debug == "1",
654		Issuer: shared.GetEnv("AUTH_ISSUER", "pico.sh"),
655		Domain: shared.GetEnv("AUTH_DOMAIN", "http://0.0.0.0:3000"),
656		Port:   shared.GetEnv("AUTH_WEB_PORT", "3000"),
657	}
658
659	logger := shared.CreateLogger("auth")
660	db := postgres.NewDB(cfg.DbURL, logger)
661	defer db.Close()
662
663	client := &Client{
664		Cfg:    cfg,
665		Dbpool: db,
666		Logger: logger,
667	}
668
669	routes := createMainRoutes()
670
671	if cfg.Debug {
672		routes = shared.CreatePProfRoutes(routes)
673	}
674
675	router := http.HandlerFunc(handler(routes, client))
676
677	portStr := fmt.Sprintf(":%s", cfg.Port)
678	client.Logger.Info("starting server on port", "port", cfg.Port)
679	err := http.ListenAndServe(portStr, router)
680	if err != nil {
681		client.Logger.Info(err.Error())
682	}
683}