Eric Bower
·
19 Apr 24
router.go
1package shared
2
3import (
4 "context"
5 "fmt"
6 "log/slog"
7 "net"
8 "net/http"
9 "net/http/pprof"
10 "regexp"
11 "strings"
12
13 "github.com/charmbracelet/ssh"
14 "github.com/picosh/pico/db"
15 "github.com/picosh/pico/shared/storage"
16)
17
// Route maps an HTTP method and a fully anchored path regex to a handler.
// Capture groups of Regex are handed to the handler via the request context
// and can be read back with GetField.
type Route struct {
	Method      string
	Regex       *regexp.Regexp
	Handler     http.HandlerFunc
	CorsEnabled bool // add CORS headers and answer OPTIONS preflight requests
}

// NewRoute builds a Route whose pattern must match the entire request path.
//
// The pattern is wrapped in a non-capturing group before anchoring, so a
// top-level alternation such as "/a|/b" is anchored as a whole. A plain
// "^" + pattern + "$" would bind "^" only to the first alternative and "$"
// only to the last, letting "/a/anything" match. The extra group does not
// capture, so submatch indices seen by GetField are unchanged.
// It panics if pattern is not a valid regular expression.
func NewRoute(method, pattern string, handler http.HandlerFunc) Route {
	return Route{
		Method:  method,
		Regex:   regexp.MustCompile("^(?:" + pattern + ")$"),
		Handler: handler,
	}
}

// NewCorsRoute is like NewRoute but marks the route as CORS-enabled.
func NewCorsRoute(method, pattern string, handler http.HandlerFunc) Route {
	route := NewRoute(method, pattern, handler)
	route.CorsEnabled = true
	return route
}
42
43func CreatePProfRoutes(routes []Route) []Route {
44 return append(routes,
45 NewRoute("GET", "/debug/pprof/cmdline", pprof.Cmdline),
46 NewRoute("GET", "/debug/pprof/profile", pprof.Profile),
47 NewRoute("GET", "/debug/pprof/symbol", pprof.Symbol),
48 NewRoute("GET", "/debug/pprof/trace", pprof.Trace),
49 NewRoute("GET", "/debug/pprof/(.*)", pprof.Index),
50 NewRoute("POST", "/debug/pprof/cmdline", pprof.Cmdline),
51 NewRoute("POST", "/debug/pprof/profile", pprof.Profile),
52 NewRoute("POST", "/debug/pprof/symbol", pprof.Symbol),
53 NewRoute("POST", "/debug/pprof/trace", pprof.Trace),
54 NewRoute("POST", "/debug/pprof/(.*)", pprof.Index),
55 )
56}
57
58type ServeFn func(http.ResponseWriter, *http.Request)
59type ApiConfig struct {
60 Cfg *ConfigSite
61 Dbpool db.DB
62 Storage storage.StorageServe
63 AnalyticsQueue chan *db.AnalyticsVisits
64}
65
66func (hc *ApiConfig) CreateCtx(prevCtx context.Context, subdomain string) context.Context {
67 ctx := context.WithValue(prevCtx, ctxLoggerKey{}, hc.Cfg.Logger)
68 ctx = context.WithValue(ctx, ctxSubdomainKey{}, subdomain)
69 ctx = context.WithValue(ctx, ctxDBKey{}, hc.Dbpool)
70 ctx = context.WithValue(ctx, ctxStorageKey{}, hc.Storage)
71 ctx = context.WithValue(ctx, ctxCfg{}, hc.Cfg)
72 ctx = context.WithValue(ctx, ctxAnalyticsQueue{}, hc.AnalyticsQueue)
73 return ctx
74}
75
76func CreateServeBasic(routes []Route, ctx context.Context) ServeFn {
77 return func(w http.ResponseWriter, r *http.Request) {
78 var allow []string
79 for _, route := range routes {
80 matches := route.Regex.FindStringSubmatch(r.URL.Path)
81 if len(matches) > 0 {
82 if r.Method == "OPTIONS" && route.CorsEnabled {
83 CorsHeaders(w.Header())
84 w.WriteHeader(http.StatusOK)
85 return
86 } else if r.Method != route.Method {
87 allow = append(allow, route.Method)
88 continue
89 }
90
91 if route.CorsEnabled {
92 CorsHeaders(w.Header())
93 }
94
95 finctx := context.WithValue(ctx, ctxKey{}, matches[1:])
96 route.Handler(w, r.WithContext(finctx))
97 return
98 }
99 }
100 if len(allow) > 0 {
101 w.Header().Set("Allow", strings.Join(allow, ", "))
102 http.Error(w, "405 method not allowed", http.StatusMethodNotAllowed)
103 return
104 }
105 http.NotFound(w, r)
106 }
107}
108
109func findRouteConfig(r *http.Request, routes []Route, subdomainRoutes []Route, cfg *ConfigSite) ([]Route, string) {
110 var subdomain string
111 curRoutes := routes
112
113 if cfg.IsCustomdomains() || cfg.IsSubdomains() {
114 hostDomain := strings.ToLower(strings.Split(r.Host, ":")[0])
115 appDomain := strings.ToLower(strings.Split(cfg.Domain, ":")[0])
116
117 if hostDomain != appDomain {
118 if strings.Contains(hostDomain, appDomain) {
119 subdomain = strings.TrimSuffix(hostDomain, fmt.Sprintf(".%s", appDomain))
120 if subdomain != "" {
121 curRoutes = subdomainRoutes
122 }
123 } else {
124 subdomain = GetCustomDomain(hostDomain, cfg.Space)
125 if subdomain != "" {
126 curRoutes = subdomainRoutes
127 }
128 }
129 }
130 }
131
132 return curRoutes, subdomain
133}
134
135func CreateServe(routes []Route, subdomainRoutes []Route, apiConfig *ApiConfig) ServeFn {
136 return func(w http.ResponseWriter, r *http.Request) {
137 curRoutes, subdomain := findRouteConfig(r, routes, subdomainRoutes, apiConfig.Cfg)
138 ctx := apiConfig.CreateCtx(r.Context(), subdomain)
139 router := CreateServeBasic(curRoutes, ctx)
140 router(w, r)
141 }
142}
143
// Context keys for per-request values. Each key is a distinct empty struct
// type, so it can only compare equal to a key of the same type.
type (
	ctxDBKey          struct{}
	ctxStorageKey     struct{}
	ctxKey            struct{} // capture groups of the matched route
	ctxLoggerKey      struct{}
	ctxSubdomainKey   struct{}
	ctxCfg            struct{}
	ctxAnalyticsQueue struct{}

	// CtxSshKey is the exported key GetSshCtx reads an ssh.Context from;
	// presumably exported so callers in other packages can set it (confirm).
	CtxSshKey struct{}
)
152
153func GetSshCtx(r *http.Request) (ssh.Context, error) {
154 payload, ok := r.Context().Value(CtxSshKey{}).(ssh.Context)
155 if payload == nil || !ok {
156 return payload, fmt.Errorf("sshCtx not set on `r.Context()` for connection")
157 }
158 return payload, nil
159}
160
161func GetCfg(r *http.Request) *ConfigSite {
162 return r.Context().Value(ctxCfg{}).(*ConfigSite)
163}
164
165func GetLogger(r *http.Request) *slog.Logger {
166 return r.Context().Value(ctxLoggerKey{}).(*slog.Logger)
167}
168
169func GetDB(r *http.Request) db.DB {
170 return r.Context().Value(ctxDBKey{}).(db.DB)
171}
172
173func GetStorage(r *http.Request) storage.StorageServe {
174 return r.Context().Value(ctxStorageKey{}).(storage.StorageServe)
175}
176
177func GetField(r *http.Request, index int) string {
178 fields := r.Context().Value(ctxKey{}).([]string)
179 if index >= len(fields) {
180 return ""
181 }
182 return fields[index]
183}
184
185func GetSubdomain(r *http.Request) string {
186 return r.Context().Value(ctxSubdomainKey{}).(string)
187}
188
// GetCustomDomain looks up the DNS TXT record "_<space>.<host>" and returns
// its first value with surrounding whitespace trimmed. It returns "" if the
// lookup fails or yields no records. findRouteConfig uses the result as the
// subdomain for a custom domain.
func GetCustomDomain(host string, space string) string {
	name := "_" + space + "." + host
	records, err := net.LookupTXT(name)
	if err != nil || len(records) == 0 {
		return ""
	}
	return strings.TrimSpace(records[0])
}
202
203func GetAnalyticsQueue(r *http.Request) chan *db.AnalyticsVisits {
204 return r.Context().Value(ctxAnalyticsQueue{}).(chan *db.AnalyticsVisits)
205}