repos / pico

pico services - prose.sh, pastes.sh, imgs.sh, feeds.sh, pgs.sh
git clone https://github.com/picosh/pico.git

pico / shared
Eric Bower · 15 Nov 24

router.go

  1package shared
  2
  3import (
  4	"context"
  5	"fmt"
  6	"log/slog"
  7	"net"
  8	"net/http"
  9	"net/http/pprof"
 10	"regexp"
 11	"strings"
 12
 13	"github.com/charmbracelet/ssh"
 14	"github.com/picosh/pico/db"
 15	"github.com/picosh/pico/shared/storage"
 16)
 17
 18type Route struct {
 19	Method      string
 20	Regex       *regexp.Regexp
 21	Handler     http.HandlerFunc
 22	CorsEnabled bool
 23}
 24
 25func NewRoute(method, pattern string, handler http.HandlerFunc) Route {
 26	return Route{
 27		method,
 28		regexp.MustCompile("^" + pattern + "$"),
 29		handler,
 30		false,
 31	}
 32}
 33
 34func NewCorsRoute(method, pattern string, handler http.HandlerFunc) Route {
 35	return Route{
 36		method,
 37		regexp.MustCompile("^" + pattern + "$"),
 38		handler,
 39		true,
 40	}
 41}
 42
 43func CreatePProfRoutes(routes []Route) []Route {
 44	return append(routes,
 45		NewRoute("GET", "/debug/pprof/cmdline", pprof.Cmdline),
 46		NewRoute("GET", "/debug/pprof/profile", pprof.Profile),
 47		NewRoute("GET", "/debug/pprof/symbol", pprof.Symbol),
 48		NewRoute("GET", "/debug/pprof/trace", pprof.Trace),
 49		NewRoute("GET", "/debug/pprof/(.*)", pprof.Index),
 50		NewRoute("POST", "/debug/pprof/cmdline", pprof.Cmdline),
 51		NewRoute("POST", "/debug/pprof/profile", pprof.Profile),
 52		NewRoute("POST", "/debug/pprof/symbol", pprof.Symbol),
 53		NewRoute("POST", "/debug/pprof/trace", pprof.Trace),
 54		NewRoute("POST", "/debug/pprof/(.*)", pprof.Index),
 55	)
 56}
 57
 58func CreatePProfRoutesMux(mux *http.ServeMux) {
 59	mux.HandleFunc("GET /debug/pprof/cmdline", pprof.Cmdline)
 60	mux.HandleFunc("GET /debug/pprof/profile", pprof.Profile)
 61	mux.HandleFunc("GET /debug/pprof/symbol", pprof.Symbol)
 62	mux.HandleFunc("GET /debug/pprof/trace", pprof.Trace)
 63	mux.HandleFunc("GET /debug/pprof/(.*)", pprof.Index)
 64	mux.HandleFunc("POST /debug/pprof/cmdline", pprof.Cmdline)
 65	mux.HandleFunc("POST /debug/pprof/profile", pprof.Profile)
 66	mux.HandleFunc("POST /debug/pprof/symbol", pprof.Symbol)
 67	mux.HandleFunc("POST /debug/pprof/trace", pprof.Trace)
 68	mux.HandleFunc("POST /debug/pprof/(.*)", pprof.Index)
 69}
 70
 71type ApiConfig struct {
 72	Cfg     *ConfigSite
 73	Dbpool  db.DB
 74	Storage storage.StorageServe
 75}
 76
 77func (hc *ApiConfig) HasPrivilegedAccess(apiToken string) bool {
 78	user, err := hc.Dbpool.FindUserForToken(apiToken)
 79	if err != nil {
 80		return false
 81	}
 82	return hc.Dbpool.HasFeatureForUser(user.ID, "auth")
 83}
 84
 85func (hc *ApiConfig) HasPlusOrSpace(user *db.User, space string) bool {
 86	return hc.Dbpool.HasFeatureForUser(user.ID, "plus") || hc.Dbpool.HasFeatureForUser(user.ID, space)
 87}
 88
 89func (hc *ApiConfig) CreateCtx(prevCtx context.Context, subdomain string) context.Context {
 90	ctx := context.WithValue(prevCtx, ctxLoggerKey{}, hc.Cfg.Logger)
 91	ctx = context.WithValue(ctx, CtxSubdomainKey{}, subdomain)
 92	ctx = context.WithValue(ctx, ctxDBKey{}, hc.Dbpool)
 93	ctx = context.WithValue(ctx, ctxStorageKey{}, hc.Storage)
 94	ctx = context.WithValue(ctx, ctxCfg{}, hc.Cfg)
 95	return ctx
 96}
 97
 98func CreateServeBasic(routes []Route, ctx context.Context) http.HandlerFunc {
 99	return func(w http.ResponseWriter, r *http.Request) {
100		var allow []string
101		for _, route := range routes {
102			matches := route.Regex.FindStringSubmatch(r.URL.Path)
103			if len(matches) > 0 {
104				if r.Method == "OPTIONS" && route.CorsEnabled {
105					CorsHeaders(w.Header())
106					w.WriteHeader(http.StatusOK)
107					return
108				} else if r.Method != route.Method {
109					allow = append(allow, route.Method)
110					continue
111				}
112
113				if route.CorsEnabled {
114					CorsHeaders(w.Header())
115				}
116
117				finctx := context.WithValue(ctx, ctxKey{}, matches[1:])
118				route.Handler(w, r.WithContext(finctx))
119				return
120			}
121		}
122		if len(allow) > 0 {
123			w.Header().Set("Allow", strings.Join(allow, ", "))
124			http.Error(w, "405 method not allowed", http.StatusMethodNotAllowed)
125			return
126		}
127		http.NotFound(w, r)
128	}
129}
130
131func GetSubdomainFromRequest(r *http.Request, domain, space string) string {
132	hostDomain := strings.ToLower(strings.Split(r.Host, ":")[0])
133	appDomain := strings.ToLower(strings.Split(domain, ":")[0])
134
135	if hostDomain != appDomain {
136		if strings.Contains(hostDomain, appDomain) {
137			subdomain := strings.TrimSuffix(hostDomain, fmt.Sprintf(".%s", appDomain))
138			return subdomain
139		} else {
140			subdomain := GetCustomDomain(hostDomain, space)
141			return subdomain
142		}
143	}
144
145	return ""
146}
147
148func findRouteConfig(r *http.Request, routes []Route, subdomainRoutes []Route, cfg *ConfigSite) ([]Route, string) {
149	if len(subdomainRoutes) == 0 {
150		return routes, ""
151	}
152
153	subdomain := GetSubdomainFromRequest(r, cfg.Domain, cfg.Space)
154	if subdomain == "" {
155		return routes, subdomain
156	}
157	return subdomainRoutes, subdomain
158}
159
160func CreateServe(routes []Route, subdomainRoutes []Route, apiConfig *ApiConfig) http.HandlerFunc {
161	return func(w http.ResponseWriter, r *http.Request) {
162		curRoutes, subdomain := findRouteConfig(r, routes, subdomainRoutes, apiConfig.Cfg)
163		ctx := apiConfig.CreateCtx(r.Context(), subdomain)
164		router := CreateServeBasic(curRoutes, ctx)
165		router(w, r)
166	}
167}
168
// Context key types. Each key is a distinct empty struct type, so a value
// stored under one key can never be read back through another.
type (
	// ctxDBKey holds the database handle; read by GetDB.
	ctxDBKey struct{}
	// ctxStorageKey holds the storage backend; read by GetStorage.
	ctxStorageKey struct{}
	// ctxLoggerKey holds the logger; read by GetLogger.
	ctxLoggerKey struct{}
	// ctxCfg holds the site configuration; read by GetCfg.
	ctxCfg struct{}

	// CtxSubdomainKey holds the resolved subdomain; read by GetSubdomain.
	CtxSubdomainKey struct{}
	// ctxKey holds the route regex capture groups; read by GetField.
	ctxKey struct{}
	// CtxSshKey holds the SSH context; read by GetSshCtx.
	CtxSshKey struct{}
)
177
178func GetSshCtx(r *http.Request) (ssh.Context, error) {
179	payload, ok := r.Context().Value(CtxSshKey{}).(ssh.Context)
180	if payload == nil || !ok {
181		return payload, fmt.Errorf("sshCtx not set on `r.Context()` for connection")
182	}
183	return payload, nil
184}
185
186func GetCfg(r *http.Request) *ConfigSite {
187	return r.Context().Value(ctxCfg{}).(*ConfigSite)
188}
189
190func GetLogger(r *http.Request) *slog.Logger {
191	return r.Context().Value(ctxLoggerKey{}).(*slog.Logger)
192}
193
194func GetDB(r *http.Request) db.DB {
195	return r.Context().Value(ctxDBKey{}).(db.DB)
196}
197
198func GetStorage(r *http.Request) storage.StorageServe {
199	return r.Context().Value(ctxStorageKey{}).(storage.StorageServe)
200}
201
202func GetField(r *http.Request, index int) string {
203	fields := r.Context().Value(ctxKey{}).([]string)
204	if index >= len(fields) {
205		return ""
206	}
207	return fields[index]
208}
209
210func GetSubdomain(r *http.Request) string {
211	return r.Context().Value(CtxSubdomainKey{}).(string)
212}
213
// GetCustomDomain looks up the TXT records of "_<space>.<host>" and returns
// the first record with surrounding whitespace trimmed. It returns "" when the
// lookup fails or yields no records.
func GetCustomDomain(host string, space string) string {
	txt := fmt.Sprintf("_%s.%s", space, host)
	records, err := net.LookupTXT(txt)
	if err != nil || len(records) == 0 {
		return ""
	}
	// Only the first TXT record is consulted.
	return strings.TrimSpace(records[0])
}
227
// GetApiToken extracts the API token from the request's Authorization header.
//
// A leading "Bearer " scheme is removed; the scheme is matched
// case-insensitively because HTTP authentication scheme names are
// case-insensitive. A header without that prefix is returned unchanged, and a
// missing header yields "".
func GetApiToken(r *http.Request) string {
	const scheme = "Bearer "

	authHeader := r.Header.Get("authorization")
	if len(authHeader) >= len(scheme) && strings.EqualFold(authHeader[:len(scheme)], scheme) {
		return authHeader[len(scheme):]
	}
	return authHeader
}