repos / pico

pico services - prose.sh, pastes.sh, imgs.sh, feeds.sh, pgs.sh
git clone https://github.com/picosh/pico.git

pico / shared
Eric Bower · 19 Apr 24

router.go

  1package shared
  2
import (
	"context"
	"fmt"
	"log/slog"
	"net"
	"net/http"
	"net/http/pprof"
	"regexp"
	"slices"
	"strings"

	"github.com/charmbracelet/ssh"
	"github.com/picosh/pico/db"
	"github.com/picosh/pico/shared/storage"
)
 17
 18type Route struct {
 19	Method      string
 20	Regex       *regexp.Regexp
 21	Handler     http.HandlerFunc
 22	CorsEnabled bool
 23}
 24
 25func NewRoute(method, pattern string, handler http.HandlerFunc) Route {
 26	return Route{
 27		method,
 28		regexp.MustCompile("^" + pattern + "$"),
 29		handler,
 30		false,
 31	}
 32}
 33
 34func NewCorsRoute(method, pattern string, handler http.HandlerFunc) Route {
 35	return Route{
 36		method,
 37		regexp.MustCompile("^" + pattern + "$"),
 38		handler,
 39		true,
 40	}
 41}
 42
 43func CreatePProfRoutes(routes []Route) []Route {
 44	return append(routes,
 45		NewRoute("GET", "/debug/pprof/cmdline", pprof.Cmdline),
 46		NewRoute("GET", "/debug/pprof/profile", pprof.Profile),
 47		NewRoute("GET", "/debug/pprof/symbol", pprof.Symbol),
 48		NewRoute("GET", "/debug/pprof/trace", pprof.Trace),
 49		NewRoute("GET", "/debug/pprof/(.*)", pprof.Index),
 50		NewRoute("POST", "/debug/pprof/cmdline", pprof.Cmdline),
 51		NewRoute("POST", "/debug/pprof/profile", pprof.Profile),
 52		NewRoute("POST", "/debug/pprof/symbol", pprof.Symbol),
 53		NewRoute("POST", "/debug/pprof/trace", pprof.Trace),
 54		NewRoute("POST", "/debug/pprof/(.*)", pprof.Index),
 55	)
 56}
 57
 58type ServeFn func(http.ResponseWriter, *http.Request)
 59type ApiConfig struct {
 60	Cfg            *ConfigSite
 61	Dbpool         db.DB
 62	Storage        storage.StorageServe
 63	AnalyticsQueue chan *db.AnalyticsVisits
 64}
 65
 66func (hc *ApiConfig) CreateCtx(prevCtx context.Context, subdomain string) context.Context {
 67	ctx := context.WithValue(prevCtx, ctxLoggerKey{}, hc.Cfg.Logger)
 68	ctx = context.WithValue(ctx, ctxSubdomainKey{}, subdomain)
 69	ctx = context.WithValue(ctx, ctxDBKey{}, hc.Dbpool)
 70	ctx = context.WithValue(ctx, ctxStorageKey{}, hc.Storage)
 71	ctx = context.WithValue(ctx, ctxCfg{}, hc.Cfg)
 72	ctx = context.WithValue(ctx, ctxAnalyticsQueue{}, hc.AnalyticsQueue)
 73	return ctx
 74}
 75
 76func CreateServeBasic(routes []Route, ctx context.Context) ServeFn {
 77	return func(w http.ResponseWriter, r *http.Request) {
 78		var allow []string
 79		for _, route := range routes {
 80			matches := route.Regex.FindStringSubmatch(r.URL.Path)
 81			if len(matches) > 0 {
 82				if r.Method == "OPTIONS" && route.CorsEnabled {
 83					CorsHeaders(w.Header())
 84					w.WriteHeader(http.StatusOK)
 85					return
 86				} else if r.Method != route.Method {
 87					allow = append(allow, route.Method)
 88					continue
 89				}
 90
 91				if route.CorsEnabled {
 92					CorsHeaders(w.Header())
 93				}
 94
 95				finctx := context.WithValue(ctx, ctxKey{}, matches[1:])
 96				route.Handler(w, r.WithContext(finctx))
 97				return
 98			}
 99		}
100		if len(allow) > 0 {
101			w.Header().Set("Allow", strings.Join(allow, ", "))
102			http.Error(w, "405 method not allowed", http.StatusMethodNotAllowed)
103			return
104		}
105		http.NotFound(w, r)
106	}
107}
108
109func findRouteConfig(r *http.Request, routes []Route, subdomainRoutes []Route, cfg *ConfigSite) ([]Route, string) {
110	var subdomain string
111	curRoutes := routes
112
113	if cfg.IsCustomdomains() || cfg.IsSubdomains() {
114		hostDomain := strings.ToLower(strings.Split(r.Host, ":")[0])
115		appDomain := strings.ToLower(strings.Split(cfg.Domain, ":")[0])
116
117		if hostDomain != appDomain {
118			if strings.Contains(hostDomain, appDomain) {
119				subdomain = strings.TrimSuffix(hostDomain, fmt.Sprintf(".%s", appDomain))
120				if subdomain != "" {
121					curRoutes = subdomainRoutes
122				}
123			} else {
124				subdomain = GetCustomDomain(hostDomain, cfg.Space)
125				if subdomain != "" {
126					curRoutes = subdomainRoutes
127				}
128			}
129		}
130	}
131
132	return curRoutes, subdomain
133}
134
135func CreateServe(routes []Route, subdomainRoutes []Route, apiConfig *ApiConfig) ServeFn {
136	return func(w http.ResponseWriter, r *http.Request) {
137		curRoutes, subdomain := findRouteConfig(r, routes, subdomainRoutes, apiConfig.Cfg)
138		ctx := apiConfig.CreateCtx(r.Context(), subdomain)
139		router := CreateServeBasic(curRoutes, ctx)
140		router(w, r)
141	}
142}
143
// Context keys. Each is a distinct zero-size type, so values stored under
// one key can never be read back through another.
type (
	// ctxDBKey stores the db.DB handle.
	ctxDBKey struct{}
	// ctxStorageKey stores the storage.StorageServe backend.
	ctxStorageKey struct{}
	// ctxKey stores the []string regex captures of the matched route.
	ctxKey struct{}
	// ctxLoggerKey stores the *slog.Logger.
	ctxLoggerKey struct{}
	// ctxSubdomainKey stores the resolved subdomain string.
	ctxSubdomainKey struct{}
	// ctxCfg stores the *ConfigSite.
	ctxCfg struct{}
	// ctxAnalyticsQueue stores the analytics channel.
	ctxAnalyticsQueue struct{}
	// CtxSshKey is the exported key under which GetSshCtx looks up an
	// ssh.Context.
	CtxSshKey struct{}
)
152
153func GetSshCtx(r *http.Request) (ssh.Context, error) {
154	payload, ok := r.Context().Value(CtxSshKey{}).(ssh.Context)
155	if payload == nil || !ok {
156		return payload, fmt.Errorf("sshCtx not set on `r.Context()` for connection")
157	}
158	return payload, nil
159}
160
161func GetCfg(r *http.Request) *ConfigSite {
162	return r.Context().Value(ctxCfg{}).(*ConfigSite)
163}
164
165func GetLogger(r *http.Request) *slog.Logger {
166	return r.Context().Value(ctxLoggerKey{}).(*slog.Logger)
167}
168
169func GetDB(r *http.Request) db.DB {
170	return r.Context().Value(ctxDBKey{}).(db.DB)
171}
172
173func GetStorage(r *http.Request) storage.StorageServe {
174	return r.Context().Value(ctxStorageKey{}).(storage.StorageServe)
175}
176
177func GetField(r *http.Request, index int) string {
178	fields := r.Context().Value(ctxKey{}).([]string)
179	if index >= len(fields) {
180		return ""
181	}
182	return fields[index]
183}
184
185func GetSubdomain(r *http.Request) string {
186	return r.Context().Value(ctxSubdomainKey{}).(string)
187}
188
// GetCustomDomain resolves the DNS TXT record "_<space>.<host>" and returns
// the first record's text with surrounding whitespace trimmed.
// It returns "" if the lookup fails or yields no records.
//
// NOTE(review): net.LookupTXT has no per-call deadline here, and this is
// called on request paths. Consider a context-aware resolver with a timeout.
func GetCustomDomain(host string, space string) string {
	records, err := net.LookupTXT("_" + space + "." + host)
	if err != nil || len(records) == 0 {
		return ""
	}
	return strings.TrimSpace(records[0])
}
202
203func GetAnalyticsQueue(r *http.Request) chan *db.AnalyticsVisits {
204	return r.Context().Value(ctxAnalyticsQueue{}).(chan *db.AnalyticsVisits)
205}