Eric Bower
·
15 Nov 24
router.go
1package shared
2
import (
	"context"
	"fmt"
	"log/slog"
	"net"
	"net/http"
	"net/http/pprof"
	"regexp"
	"slices"
	"strings"

	"github.com/charmbracelet/ssh"
	"github.com/picosh/pico/db"
	"github.com/picosh/pico/shared/storage"
)
17
// Route maps an HTTP method plus a path regular expression to a handler.
type Route struct {
	// Method is compared verbatim against http.Request.Method.
	Method string
	// Regex is matched against the request's URL path; its capture groups
	// are exposed to Handler through GetField.
	Regex *regexp.Regexp
	// Handler serves requests that match both Method and Regex.
	Handler http.HandlerFunc
	// CorsEnabled makes the router add CORS headers for this route and
	// answer OPTIONS requests for it directly.
	CorsEnabled bool
}

// NewRoute builds a Route whose pattern must match the entire URL path.
//
// The pattern is wrapped as ^(?:pattern)$. The non-capturing group makes the
// anchors apply to the whole expression even when it contains a top-level
// alternation ("/a|/b"); plain "^"+pattern+"$" would anchor only the first
// and last alternatives. Because the group is non-capturing, capture-group
// numbering (and therefore GetField indices) is unaffected.
//
// It panics if pattern is not a valid regular expression (regexp.MustCompile).
func NewRoute(method, pattern string, handler http.HandlerFunc) Route {
	return Route{
		Method:  method,
		Regex:   regexp.MustCompile("^(?:" + pattern + ")$"),
		Handler: handler,
	}
}

// NewCorsRoute is like NewRoute but returns a route with CorsEnabled set.
func NewCorsRoute(method, pattern string, handler http.HandlerFunc) Route {
	route := NewRoute(method, pattern, handler)
	route.CorsEnabled = true
	return route
}
42
43func CreatePProfRoutes(routes []Route) []Route {
44 return append(routes,
45 NewRoute("GET", "/debug/pprof/cmdline", pprof.Cmdline),
46 NewRoute("GET", "/debug/pprof/profile", pprof.Profile),
47 NewRoute("GET", "/debug/pprof/symbol", pprof.Symbol),
48 NewRoute("GET", "/debug/pprof/trace", pprof.Trace),
49 NewRoute("GET", "/debug/pprof/(.*)", pprof.Index),
50 NewRoute("POST", "/debug/pprof/cmdline", pprof.Cmdline),
51 NewRoute("POST", "/debug/pprof/profile", pprof.Profile),
52 NewRoute("POST", "/debug/pprof/symbol", pprof.Symbol),
53 NewRoute("POST", "/debug/pprof/trace", pprof.Trace),
54 NewRoute("POST", "/debug/pprof/(.*)", pprof.Index),
55 )
56}
57
58func CreatePProfRoutesMux(mux *http.ServeMux) {
59 mux.HandleFunc("GET /debug/pprof/cmdline", pprof.Cmdline)
60 mux.HandleFunc("GET /debug/pprof/profile", pprof.Profile)
61 mux.HandleFunc("GET /debug/pprof/symbol", pprof.Symbol)
62 mux.HandleFunc("GET /debug/pprof/trace", pprof.Trace)
63 mux.HandleFunc("GET /debug/pprof/(.*)", pprof.Index)
64 mux.HandleFunc("POST /debug/pprof/cmdline", pprof.Cmdline)
65 mux.HandleFunc("POST /debug/pprof/profile", pprof.Profile)
66 mux.HandleFunc("POST /debug/pprof/symbol", pprof.Symbol)
67 mux.HandleFunc("POST /debug/pprof/trace", pprof.Trace)
68 mux.HandleFunc("POST /debug/pprof/(.*)", pprof.Index)
69}
70
71type ApiConfig struct {
72 Cfg *ConfigSite
73 Dbpool db.DB
74 Storage storage.StorageServe
75}
76
77func (hc *ApiConfig) HasPrivilegedAccess(apiToken string) bool {
78 user, err := hc.Dbpool.FindUserForToken(apiToken)
79 if err != nil {
80 return false
81 }
82 return hc.Dbpool.HasFeatureForUser(user.ID, "auth")
83}
84
85func (hc *ApiConfig) HasPlusOrSpace(user *db.User, space string) bool {
86 return hc.Dbpool.HasFeatureForUser(user.ID, "plus") || hc.Dbpool.HasFeatureForUser(user.ID, space)
87}
88
89func (hc *ApiConfig) CreateCtx(prevCtx context.Context, subdomain string) context.Context {
90 ctx := context.WithValue(prevCtx, ctxLoggerKey{}, hc.Cfg.Logger)
91 ctx = context.WithValue(ctx, CtxSubdomainKey{}, subdomain)
92 ctx = context.WithValue(ctx, ctxDBKey{}, hc.Dbpool)
93 ctx = context.WithValue(ctx, ctxStorageKey{}, hc.Storage)
94 ctx = context.WithValue(ctx, ctxCfg{}, hc.Cfg)
95 return ctx
96}
97
98func CreateServeBasic(routes []Route, ctx context.Context) http.HandlerFunc {
99 return func(w http.ResponseWriter, r *http.Request) {
100 var allow []string
101 for _, route := range routes {
102 matches := route.Regex.FindStringSubmatch(r.URL.Path)
103 if len(matches) > 0 {
104 if r.Method == "OPTIONS" && route.CorsEnabled {
105 CorsHeaders(w.Header())
106 w.WriteHeader(http.StatusOK)
107 return
108 } else if r.Method != route.Method {
109 allow = append(allow, route.Method)
110 continue
111 }
112
113 if route.CorsEnabled {
114 CorsHeaders(w.Header())
115 }
116
117 finctx := context.WithValue(ctx, ctxKey{}, matches[1:])
118 route.Handler(w, r.WithContext(finctx))
119 return
120 }
121 }
122 if len(allow) > 0 {
123 w.Header().Set("Allow", strings.Join(allow, ", "))
124 http.Error(w, "405 method not allowed", http.StatusMethodNotAllowed)
125 return
126 }
127 http.NotFound(w, r)
128 }
129}
130
131func GetSubdomainFromRequest(r *http.Request, domain, space string) string {
132 hostDomain := strings.ToLower(strings.Split(r.Host, ":")[0])
133 appDomain := strings.ToLower(strings.Split(domain, ":")[0])
134
135 if hostDomain != appDomain {
136 if strings.Contains(hostDomain, appDomain) {
137 subdomain := strings.TrimSuffix(hostDomain, fmt.Sprintf(".%s", appDomain))
138 return subdomain
139 } else {
140 subdomain := GetCustomDomain(hostDomain, space)
141 return subdomain
142 }
143 }
144
145 return ""
146}
147
148func findRouteConfig(r *http.Request, routes []Route, subdomainRoutes []Route, cfg *ConfigSite) ([]Route, string) {
149 if len(subdomainRoutes) == 0 {
150 return routes, ""
151 }
152
153 subdomain := GetSubdomainFromRequest(r, cfg.Domain, cfg.Space)
154 if subdomain == "" {
155 return routes, subdomain
156 }
157 return subdomainRoutes, subdomain
158}
159
160func CreateServe(routes []Route, subdomainRoutes []Route, apiConfig *ApiConfig) http.HandlerFunc {
161 return func(w http.ResponseWriter, r *http.Request) {
162 curRoutes, subdomain := findRouteConfig(r, routes, subdomainRoutes, apiConfig.Cfg)
163 ctx := apiConfig.CreateCtx(r.Context(), subdomain)
164 router := CreateServeBasic(curRoutes, ctx)
165 router(w, r)
166 }
167}
168
// Private context keys under which ApiConfig.CreateCtx stores per-request
// dependencies. Each key is its own empty struct type, so values stored under
// them cannot collide with keys of any other type.
type (
	ctxDBKey      struct{}
	ctxStorageKey struct{}
	ctxLoggerKey  struct{}
	ctxCfg        struct{}
)

// Context keys for routing and connection data.
type (
	// CtxSubdomainKey holds the request's subdomain string (set by
	// ApiConfig.CreateCtx, read by GetSubdomain).
	CtxSubdomainKey struct{}
	// ctxKey holds the []string of capture groups of the matched Route (set
	// by CreateServeBasic, read by GetField).
	ctxKey struct{}
	// CtxSshKey holds an ssh.Context (read by GetSshCtx).
	CtxSshKey struct{}
)
177
178func GetSshCtx(r *http.Request) (ssh.Context, error) {
179 payload, ok := r.Context().Value(CtxSshKey{}).(ssh.Context)
180 if payload == nil || !ok {
181 return payload, fmt.Errorf("sshCtx not set on `r.Context()` for connection")
182 }
183 return payload, nil
184}
185
186func GetCfg(r *http.Request) *ConfigSite {
187 return r.Context().Value(ctxCfg{}).(*ConfigSite)
188}
189
190func GetLogger(r *http.Request) *slog.Logger {
191 return r.Context().Value(ctxLoggerKey{}).(*slog.Logger)
192}
193
194func GetDB(r *http.Request) db.DB {
195 return r.Context().Value(ctxDBKey{}).(db.DB)
196}
197
198func GetStorage(r *http.Request) storage.StorageServe {
199 return r.Context().Value(ctxStorageKey{}).(storage.StorageServe)
200}
201
202func GetField(r *http.Request, index int) string {
203 fields := r.Context().Value(ctxKey{}).([]string)
204 if index >= len(fields) {
205 return ""
206 }
207 return fields[index]
208}
209
210func GetSubdomain(r *http.Request) string {
211 return r.Context().Value(CtxSubdomainKey{}).(string)
212}
213
// GetCustomDomain looks up the DNS TXT record "_<space>.<host>" and returns
// the first record's value with surrounding whitespace trimmed. It returns ""
// when the lookup fails or yields no records.
//
// NOTE(review): this is a blocking lookup on the default resolver with no
// explicit timeout. When called from GetSubdomainFromRequest, host comes from
// the request's Host header, so a context-bound lookup may be worth
// considering.
func GetCustomDomain(host string, space string) string {
	txt := fmt.Sprintf("_%s.%s", space, host)
	records, err := net.LookupTXT(txt)
	if err != nil || len(records) == 0 {
		return ""
	}
	// Only the first record is used; any additional records are ignored.
	return strings.TrimSpace(records[0])
}
227
// GetApiToken extracts the API token from the request's Authorization header.
//
// For a "Bearer <token>" header the token part is returned. The scheme is
// matched case-insensitively, because HTTP auth-scheme names are
// case-insensitive, and surrounding whitespace is trimmed. A header without
// the Bearer scheme is returned as-is (trimmed), so clients that send a bare
// token keep working. A missing or blank header yields "".
func GetApiToken(r *http.Request) string {
	authHeader := strings.TrimSpace(r.Header.Get("authorization"))
	if authHeader == "" {
		return ""
	}

	const scheme = "Bearer "
	if len(authHeader) >= len(scheme) && strings.EqualFold(authHeader[:len(scheme)], scheme) {
		return strings.TrimSpace(authHeader[len(scheme):])
	}
	return authHeader
}