repos / pico

pico services - prose.sh, pastes.sh, imgs.sh, feeds.sh, pgs.sh
git clone https://github.com/picosh/pico.git

commit
7d1f66a
parent
d5b9a59
author
Eric Bower
date
2024-03-10 18:25:17 +0000 UTC
feat(pgs): allow cors for private sites
4 files changed,  +44, -14
M pgs/tunnel.go
+1, -1
1@@ -76,7 +76,7 @@ func createHttpHandler(httpCtx *shared.HttpCtx) CtxHttpBridge {
2 
3 		routes := []shared.Route{
4 			// special API endpoint for tunnel users accessing site
5-			shared.NewRoute("GET", "/api/current_user", func(w http.ResponseWriter, r *http.Request) {
6+			shared.NewCorsRoute("GET", "/api/current_user", func(w http.ResponseWriter, r *http.Request) {
7 				w.Header().Set("Content-Type", "application/json")
8 				pico := &db.PicoApi{
9 					UserID:    "",
M plus/routes.go
+9, -9
 1@@ -18,6 +18,7 @@ type registerPayload struct {
 2 func registerUser(httpCtx *shared.HttpCtx, ctx ssh.Context, pubkey string) http.HandlerFunc {
 3 	logger := httpCtx.Cfg.Logger
 4 	return func(w http.ResponseWriter, r *http.Request) {
 5+		w.Header().Set("Content-Type", "application/json")
 6 		dbpool := shared.GetDB(r)
 7 		var payload registerPayload
 8 		body, _ := io.ReadAll(r.Body)
 9@@ -31,7 +32,6 @@ func registerUser(httpCtx *shared.HttpCtx, ctx ssh.Context, pubkey string) http.
10 			return
11 		}
12 
13-		w.Header().Set("Content-Type", "application/json")
14 		pico := &db.PicoApi{
15 			UserID:    user.ID,
16 			UserName:  user.Name,
17@@ -83,6 +83,7 @@ type rssTokenPayload struct {
18 func findOrCreateRssToken(httpCtx *shared.HttpCtx, ctx ssh.Context, pubkey string) http.HandlerFunc {
19 	logger := httpCtx.Cfg.Logger
20 	return func(w http.ResponseWriter, r *http.Request) {
21+		w.Header().Set("Content-Type", "application/json")
22 		dbpool := shared.GetDB(r)
23 		user, err := dbpool.FindUserForKey("", pubkey)
24 		if err != nil {
25@@ -104,7 +105,6 @@ func findOrCreateRssToken(httpCtx *shared.HttpCtx, ctx ssh.Context, pubkey strin
26 			}
27 		}
28 
29-		w.Header().Set("Content-Type", "application/json")
30 		err = json.NewEncoder(w).Encode(&rssTokenPayload{Token: rssToken})
31 		if err != nil {
32 			logger.Error(err.Error())
33@@ -119,6 +119,7 @@ type pubkeysPayload struct {
34 func getPublicKeys(httpCtx *shared.HttpCtx, ctx ssh.Context, pubkey string) http.HandlerFunc {
35 	logger := httpCtx.Cfg.Logger
36 	return func(w http.ResponseWriter, r *http.Request) {
37+		w.Header().Set("Content-Type", "application/json")
38 		dbpool := shared.GetDB(r)
39 		user, err := dbpool.FindUserForKey("", pubkey)
40 		if err != nil {
41@@ -132,7 +133,6 @@ func getPublicKeys(httpCtx *shared.HttpCtx, ctx ssh.Context, pubkey string) http
42 			return
43 		}
44 
45-		w.Header().Set("Content-Type", "application/json")
46 		err = json.NewEncoder(w).Encode(&pubkeysPayload{Pubkeys: pubkeys})
47 		if err != nil {
48 			logger.Error(err.Error())
49@@ -147,6 +147,7 @@ type tokensPayload struct {
50 func getTokens(httpCtx *shared.HttpCtx, ctx ssh.Context, pubkey string) http.HandlerFunc {
51 	logger := httpCtx.Cfg.Logger
52 	return func(w http.ResponseWriter, r *http.Request) {
53+		w.Header().Set("Content-Type", "application/json")
54 		dbpool := shared.GetDB(r)
55 		user, err := dbpool.FindUserForKey("", pubkey)
56 		if err != nil {
57@@ -164,7 +165,6 @@ func getTokens(httpCtx *shared.HttpCtx, ctx ssh.Context, pubkey string) http.Han
58 			tokens = []*db.Token{}
59 		}
60 
61-		w.Header().Set("Content-Type", "application/json")
62 		err = json.NewEncoder(w).Encode(&tokensPayload{Tokens: tokens})
63 		if err != nil {
64 			logger.Error(err.Error())
65@@ -187,10 +187,10 @@ func CreateRoutes(httpCtx *shared.HttpCtx, ctx ssh.Context) []shared.Route {
66 	}
67 
68 	return []shared.Route{
69-		shared.NewRoute("POST", "/api/users", registerUser(httpCtx, ctx, pubkeyStr)),
70-		shared.NewRoute("GET", "/api/features", getFeatures(httpCtx, ctx, pubkeyStr)),
71-		shared.NewRoute("PUT", "/api/rss-token", findOrCreateRssToken(httpCtx, ctx, pubkeyStr)),
72-		shared.NewRoute("GET", "/api/pubkeys", getPublicKeys(httpCtx, ctx, pubkeyStr)),
73-		shared.NewRoute("GET", "/api/tokens", getTokens(httpCtx, ctx, pubkeyStr)),
74+		shared.NewCorsRoute("POST", "/api/users", registerUser(httpCtx, ctx, pubkeyStr)),
75+		shared.NewCorsRoute("GET", "/api/features", getFeatures(httpCtx, ctx, pubkeyStr)),
76+		shared.NewCorsRoute("PUT", "/api/rss-token", findOrCreateRssToken(httpCtx, ctx, pubkeyStr)),
77+		shared.NewCorsRoute("GET", "/api/pubkeys", getPublicKeys(httpCtx, ctx, pubkeyStr)),
78+		shared.NewCorsRoute("GET", "/api/tokens", getTokens(httpCtx, ctx, pubkeyStr)),
79 	}
80 }
M shared/api.go
+10, -0
 1@@ -9,6 +9,16 @@ import (
 2 	"strings"
 3 )
 4 
 5+func CorsHeaders(w http.ResponseWriter) {
 6+	headers := w.Header()
 7+	headers.Add("Access-Control-Allow-Origin", "*")
 8+	headers.Add("Vary", "Origin")
 9+	headers.Add("Vary", "Access-Control-Request-Method")
10+	headers.Add("Vary", "Access-Control-Request-Headers")
11+	headers.Add("Access-Control-Allow-Headers", "Content-Type, Accept")
12+	headers.Add("Access-Control-Allow-Methods", "GET, POST, OPTIONS, PUT, PATCH, DELETE")
13+}
14+
15 func UnauthorizedHandler(w http.ResponseWriter, r *http.Request) {
16 	http.Error(w, "You do not have access to this site", http.StatusUnauthorized)
17 }
M shared/router.go
+24, -4
 1@@ -15,9 +15,10 @@ import (
 2 )
 3 
 4 type Route struct {
 5-	Method  string
 6-	Regex   *regexp.Regexp
 7-	Handler http.HandlerFunc
 8+	Method      string
 9+	Regex       *regexp.Regexp
10+	Handler     http.HandlerFunc
11+	CorsEnabled bool
12 }
13 
14 func NewRoute(method, pattern string, handler http.HandlerFunc) Route {
15@@ -25,6 +26,16 @@ func NewRoute(method, pattern string, handler http.HandlerFunc) Route {
16 		method,
17 		regexp.MustCompile("^" + pattern + "$"),
18 		handler,
19+		false,
20+	}
21+}
22+
23+func NewCorsRoute(method, pattern string, handler http.HandlerFunc) Route {
24+	return Route{
25+		method,
26+		regexp.MustCompile("^" + pattern + "$"),
27+		handler,
28+		true,
29 	}
30 }
31 
32@@ -65,10 +76,19 @@ func CreateServeBasic(routes []Route, ctx context.Context) ServeFn {
33 		for _, route := range routes {
34 			matches := route.Regex.FindStringSubmatch(r.URL.Path)
35 			if len(matches) > 0 {
36-				if r.Method != route.Method {
37+				if r.Method == "OPTIONS" && route.CorsEnabled {
38+					CorsHeaders(w)
39+					w.WriteHeader(http.StatusOK)
40+					return
41+				} else if r.Method != route.Method {
42 					allow = append(allow, route.Method)
43 					continue
44 				}
45+
46+				if route.CorsEnabled {
47+					CorsHeaders(w)
48+				}
49+
50 				finctx := context.WithValue(ctx, ctxKey{}, matches[1:])
51 				route.Handler(w, r.WithContext(finctx))
52 				return