- commit
- 7d1f66a
- parent
- d5b9a59
- author
- Eric Bower
- date
- 2024-03-10 18:25:17 +0000 UTC
feat(pgs): allow cors for private sites
4 files changed,
+44,
-14
+1,
-1
1@@ -76,7 +76,7 @@ func createHttpHandler(httpCtx *shared.HttpCtx) CtxHttpBridge {
2
3 routes := []shared.Route{
4 // special API endpoint for tunnel users accessing site
5- shared.NewRoute("GET", "/api/current_user", func(w http.ResponseWriter, r *http.Request) {
6+ shared.NewCorsRoute("GET", "/api/current_user", func(w http.ResponseWriter, r *http.Request) {
7 w.Header().Set("Content-Type", "application/json")
8 pico := &db.PicoApi{
9 UserID: "",
+9,
-9
1@@ -18,6 +18,7 @@ type registerPayload struct {
2 func registerUser(httpCtx *shared.HttpCtx, ctx ssh.Context, pubkey string) http.HandlerFunc {
3 logger := httpCtx.Cfg.Logger
4 return func(w http.ResponseWriter, r *http.Request) {
5+ w.Header().Set("Content-Type", "application/json")
6 dbpool := shared.GetDB(r)
7 var payload registerPayload
8 body, _ := io.ReadAll(r.Body)
9@@ -31,7 +32,6 @@ func registerUser(httpCtx *shared.HttpCtx, ctx ssh.Context, pubkey string) http.
10 return
11 }
12
13- w.Header().Set("Content-Type", "application/json")
14 pico := &db.PicoApi{
15 UserID: user.ID,
16 UserName: user.Name,
17@@ -83,6 +83,7 @@ type rssTokenPayload struct {
18 func findOrCreateRssToken(httpCtx *shared.HttpCtx, ctx ssh.Context, pubkey string) http.HandlerFunc {
19 logger := httpCtx.Cfg.Logger
20 return func(w http.ResponseWriter, r *http.Request) {
21+ w.Header().Set("Content-Type", "application/json")
22 dbpool := shared.GetDB(r)
23 user, err := dbpool.FindUserForKey("", pubkey)
24 if err != nil {
25@@ -104,7 +105,6 @@ func findOrCreateRssToken(httpCtx *shared.HttpCtx, ctx ssh.Context, pubkey strin
26 }
27 }
28
29- w.Header().Set("Content-Type", "application/json")
30 err = json.NewEncoder(w).Encode(&rssTokenPayload{Token: rssToken})
31 if err != nil {
32 logger.Error(err.Error())
33@@ -119,6 +119,7 @@ type pubkeysPayload struct {
34 func getPublicKeys(httpCtx *shared.HttpCtx, ctx ssh.Context, pubkey string) http.HandlerFunc {
35 logger := httpCtx.Cfg.Logger
36 return func(w http.ResponseWriter, r *http.Request) {
37+ w.Header().Set("Content-Type", "application/json")
38 dbpool := shared.GetDB(r)
39 user, err := dbpool.FindUserForKey("", pubkey)
40 if err != nil {
41@@ -132,7 +133,6 @@ func getPublicKeys(httpCtx *shared.HttpCtx, ctx ssh.Context, pubkey string) http
42 return
43 }
44
45- w.Header().Set("Content-Type", "application/json")
46 err = json.NewEncoder(w).Encode(&pubkeysPayload{Pubkeys: pubkeys})
47 if err != nil {
48 logger.Error(err.Error())
49@@ -147,6 +147,7 @@ type tokensPayload struct {
50 func getTokens(httpCtx *shared.HttpCtx, ctx ssh.Context, pubkey string) http.HandlerFunc {
51 logger := httpCtx.Cfg.Logger
52 return func(w http.ResponseWriter, r *http.Request) {
53+ w.Header().Set("Content-Type", "application/json")
54 dbpool := shared.GetDB(r)
55 user, err := dbpool.FindUserForKey("", pubkey)
56 if err != nil {
57@@ -164,7 +165,6 @@ func getTokens(httpCtx *shared.HttpCtx, ctx ssh.Context, pubkey string) http.Han
58 tokens = []*db.Token{}
59 }
60
61- w.Header().Set("Content-Type", "application/json")
62 err = json.NewEncoder(w).Encode(&tokensPayload{Tokens: tokens})
63 if err != nil {
64 logger.Error(err.Error())
65@@ -187,10 +187,10 @@ func CreateRoutes(httpCtx *shared.HttpCtx, ctx ssh.Context) []shared.Route {
66 }
67
68 return []shared.Route{
69- shared.NewRoute("POST", "/api/users", registerUser(httpCtx, ctx, pubkeyStr)),
70- shared.NewRoute("GET", "/api/features", getFeatures(httpCtx, ctx, pubkeyStr)),
71- shared.NewRoute("PUT", "/api/rss-token", findOrCreateRssToken(httpCtx, ctx, pubkeyStr)),
72- shared.NewRoute("GET", "/api/pubkeys", getPublicKeys(httpCtx, ctx, pubkeyStr)),
73- shared.NewRoute("GET", "/api/tokens", getTokens(httpCtx, ctx, pubkeyStr)),
74+ shared.NewCorsRoute("POST", "/api/users", registerUser(httpCtx, ctx, pubkeyStr)),
75+ shared.NewCorsRoute("GET", "/api/features", getFeatures(httpCtx, ctx, pubkeyStr)),
76+ shared.NewCorsRoute("PUT", "/api/rss-token", findOrCreateRssToken(httpCtx, ctx, pubkeyStr)),
77+ shared.NewCorsRoute("GET", "/api/pubkeys", getPublicKeys(httpCtx, ctx, pubkeyStr)),
78+ shared.NewCorsRoute("GET", "/api/tokens", getTokens(httpCtx, ctx, pubkeyStr)),
79 }
80 }
1@@ -9,6 +9,16 @@ import (
2 "strings"
3 )
4
5+func CorsHeaders(w http.ResponseWriter) {
6+ headers := w.Header()
7+ headers.Add("Access-Control-Allow-Origin", "*")
8+ headers.Add("Vary", "Origin")
9+ headers.Add("Vary", "Access-Control-Request-Method")
10+ headers.Add("Vary", "Access-Control-Request-Headers")
11+ headers.Add("Access-Control-Allow-Headers", "Content-Type, Accept")
12+ headers.Add("Access-Control-Allow-Methods", "GET, POST, OPTIONS, PUT, PATCH, DELETE")
13+}
14+
15 func UnauthorizedHandler(w http.ResponseWriter, r *http.Request) {
16 http.Error(w, "You do not have access to this site", http.StatusUnauthorized)
17 }
1@@ -15,9 +15,10 @@ import (
2 )
3
4 type Route struct {
5- Method string
6- Regex *regexp.Regexp
7- Handler http.HandlerFunc
8+ Method string
9+ Regex *regexp.Regexp
10+ Handler http.HandlerFunc
11+ CorsEnabled bool
12 }
13
14 func NewRoute(method, pattern string, handler http.HandlerFunc) Route {
15@@ -25,6 +26,16 @@ func NewRoute(method, pattern string, handler http.HandlerFunc) Route {
16 method,
17 regexp.MustCompile("^" + pattern + "$"),
18 handler,
19+ false,
20+ }
21+}
22+
23+func NewCorsRoute(method, pattern string, handler http.HandlerFunc) Route {
24+ return Route{
25+ method,
26+ regexp.MustCompile("^" + pattern + "$"),
27+ handler,
28+ true,
29 }
30 }
31
32@@ -65,10 +76,19 @@ func CreateServeBasic(routes []Route, ctx context.Context) ServeFn {
33 for _, route := range routes {
34 matches := route.Regex.FindStringSubmatch(r.URL.Path)
35 if len(matches) > 0 {
36- if r.Method != route.Method {
37+ if r.Method == "OPTIONS" && route.CorsEnabled {
38+ CorsHeaders(w)
39+ w.WriteHeader(http.StatusOK)
40+ return
41+ } else if r.Method != route.Method {
42 allow = append(allow, route.Method)
43 continue
44 }
45+
46+ if route.CorsEnabled {
47+ CorsHeaders(w)
48+ }
49+
50 finctx := context.WithValue(ctx, ctxKey{}, matches[1:])
51 route.Handler(w, r.WithContext(finctx))
52 return