- commit
- 9d945c2
- parent
- ce404cd
- author
- Antonio Mika
- date
- 2024-11-22 15:48:15 +0000 UTC
feat(pipe) web (#160) refactor(auth): use new http mux router --------- Co-authored-by: Eric Bower <me@erock.io>
11 files changed,
+1330,
-503
M
Makefile
+4,
-0
1@@ -28,6 +28,10 @@ test:
2 go test ./...
3 .PHONY: test
4
5+snaps:
6+ UPDATE_SNAPS=true go test ./...
7+.PHONY: snaps
8+
9 bp-setup:
10 $(DOCKER_CMD) buildx ls | grep pico || $(DOCKER_CMD) buildx create --name pico
11 $(DOCKER_CMD) buildx use pico
+169,
-0
1@@ -0,0 +1,169 @@
2+
3+[TestPaymentWebhook - 1]
4+successfully added pico+ user
5+---
6+
7+[TestAuthApi/authorize - 1]
8+<!doctype html>
9+<html lang="en">
10+ <head>
11+ <meta charset='utf-8'>
12+ <meta name="viewport" content="width=device-width, initial-scale=1" />
13+ <title>auth redirect</title>
14+
15+ <link rel="icon" type="image/png" sizes="16x16" href="/favicon-16x16.png">
16+
17+ <meta name="keywords" content="static, site, hosting" />
18+
19+ <link rel="stylesheet" href="/main.css" />
20+
21+
22+ </head>
23+ <body >
24+<header>
25+ <h1 class="text-2xl">Auth Redirect</h1>
26+ <hr />
27+</header>
28+<main>
29+ <section>
30+ <h2 class="text-xl">You are being redirected to pico.test</h2>
31+ <p>
32+ Here is their auth request data:
33+ </p>
34+
35+ <article>
36+ <h2 class="text-lg">Client ID</h2>
37+ <div>333</div>
38+ </article>
39+
40+ <br />
41+
42+ <article>
43+ <h2 class="text-lg">Redirect URI</h2>
44+ <div>pico.test</div>
45+ </article>
46+
47+ <br />
48+
49+ <article>
50+ <h2 class="text-lg">Scope</h2>
51+ <div>admin</div>
52+ </article>
53+
54+ <br />
55+
56+ <article>
57+ <h2 class="text-lg">Response Type</h2>
58+ <div>json</div>
59+ </article>
60+
61+ <br />
62+
63+ <article>
64+ <h2 class="text-lg">If you would like to continue authenticating with this service, ssh into a pico service and generate a token. Then input it here and click submit.</h2>
65+ <br />
66+ <form action="/redirect" method="POST">
67+ <label for="token">Auth Token:</label><br>
68+ <input type="text" id="token" name="token"><br>
69+ <br />
70+ <input type="hidden" id="redirect_uri" name="redirect_uri" value="pico.test">
71+ <input type="hidden" id="response_type" name="response_type" value="json">
72+ <input type="submit" value="Submit">
73+ </form>
74+ </article>
75+ </section>
76+</main>
77+
78+<footer>
79+ <hr />
80+ <p class="font-italic">Built and maintained by <a href="https://pico.sh">pico.sh</a>.</p>
81+ <div>
82+ <a href="https://github.com/picosh/pico">source</a>
83+ </div>
84+</footer>
85+
86+</body>
87+</html>
88+---
89+
90+[TestUser - 1]
91+[{"id":"1","user_id":"user-1","name":"my-key","key":"nice-pubkey","created_at":"0001-01-01T00:00:00Z"}]
92+---
93+
94+[TestAuthApi/rss - 1]
95+<?xml version="1.0" encoding="UTF-8"?><feed xmlns="http://www.w3.org/2005/Atom">
96+ <title>pico+</title>
97+ <id>https://auth.pico.sh/rss/123</id>
98+ <updated></updated>
99+ <subtitle>get notified of important membership updates</subtitle>
100+ <link href="https://auth.pico.sh/rss/123"></link>
101+ <author>
102+ <name>team pico</name>
103+ </author>
104+ <entry>
105+ <title>pico+ membership activated on 2021-08-15 14:30:45</title>
106+ <updated>2021-08-15T14:30:45Z</updated>
107+ <id>pico-plus-activated-1629037845</id>
108+ <content type="html">Thanks for joining pico+! You now have access to all our premium services for exactly one year. We will send you pico+ expiration notifications through this RSS feed. Go to <a href="https://pico.sh/getting-started#next-steps">pico.sh/getting-started#next-steps</a> to start using our services.</content>
109+ <link href="https://pico.sh" rel="alternate"></link>
110+ <summary type="html">Thanks for joining pico+! You now have access to all our premium services for exactly one year. We will send you pico+ expiration notifications through this RSS feed. Go to <a href="https://pico.sh/getting-started#next-steps">pico.sh/getting-started#next-steps</a> to start using our services.</summary>
111+ <author>
112+ <name>team pico</name>
113+ </author>
114+ </entry>
115+ <entry>
116+ <title>pico+ 1-month expiration notice</title>
117+ <updated>2021-07-16T14:30:45Z</updated>
118+ <id>1626445845</id>
119+ <content type="html">Your pico+ membership is going to expire on 2021-08-16 14:30:45</content>
120+ <link href="https://pico.sh" rel="alternate"></link>
121+ <summary type="html">Your pico+ membership is going to expire on 2021-08-16 14:30:45</summary>
122+ <author>
123+ <name>team pico</name>
124+ </author>
125+ </entry>
126+ <entry>
127+ <title>pico+ 1-week expiration notice</title>
128+ <updated>2021-08-09T14:30:45Z</updated>
129+ <id>1628519445</id>
130+ <content type="html">Your pico+ membership is going to expire on 2021-08-16 14:30:45</content>
131+ <link href="https://pico.sh" rel="alternate"></link>
132+ <summary type="html">Your pico+ membership is going to expire on 2021-08-16 14:30:45</summary>
133+ <author>
134+ <name>team pico</name>
135+ </author>
136+ </entry>
137+ <entry>
138+ <title>pico+ 1-day expiration notice</title>
139+ <updated>2021-08-14T14:30:45Z</updated>
140+ <id>1628951445</id>
141+ <content type="html">Your pico+ membership is going to expire on 2021-08-16 14:30:45</content>
142+ <link href="https://pico.sh" rel="alternate"></link>
143+ <summary type="html">Your pico+ membership is going to expire on 2021-08-16 14:30:45</summary>
144+ <author>
145+ <name>team pico</name>
146+ </author>
147+ </entry>
148+</feed>
149+---
150+
151+[TestKey - 1]
152+{"id":"user-1","name":"user-a","created_at":null}
153+---
154+
155+[TestAuthApi/fileserver - 1]
156+User-agent: *
157+Allow: /
158+---
159+
160+[TestAuthApi/well-known - 1]
161+{"issuer":"auth.pico.test","introspection_endpoint":"http://0.0.0.0:3000/introspect","introspection_endpoint_auth_methods_supported":["none"],"authorization_endpoint":"http://0.0.0.0:3000/authorize","token_endpoint":"http://0.0.0.0:3000/token","response_types_supported":["code"]}
162+---
163+
164+[TestIntrospect - 1]
165+{"active":true,"username":"user-a"}
166+---
167+
168+[TestToken - 1]
169+{"access_token":"123"}
170+---
+395,
-444
1@@ -4,16 +4,16 @@ import (
2 "bufio"
3 "context"
4 "crypto/hmac"
5+ "embed"
6 "encoding/json"
7 "errors"
8 "fmt"
9 "html/template"
10 "io"
11+ "io/fs"
12 "log/slog"
13 "net/http"
14 "net/url"
15- "os"
16- "strings"
17 "time"
18
19 "github.com/gorilla/feeds"
20@@ -24,42 +24,8 @@ import (
21 "github.com/picosh/utils/pipe/metrics"
22 )
23
24-type Client struct {
25- Cfg *AuthCfg
26- Dbpool db.DB
27- Logger *slog.Logger
28-}
29-
30-func (client *Client) hasPrivilegedAccess(apiToken string) bool {
31- user, err := client.Dbpool.FindUserForToken(apiToken)
32- if err != nil {
33- return false
34- }
35- return client.Dbpool.HasFeatureForUser(user.ID, "auth")
36-}
37-
38-type ctxClient struct{}
39-type ctxKey struct{}
40-
41-func getClient(r *http.Request) *Client {
42- return r.Context().Value(ctxClient{}).(*Client)
43-}
44-
45-func getField(r *http.Request, index int) string {
46- fields := r.Context().Value(ctxKey{}).([]string)
47- if index >= len(fields) {
48- return ""
49- }
50- return fields[index]
51-}
52-
53-func getApiToken(r *http.Request) string {
54- authHeader := r.Header.Get("authorization")
55- if authHeader == "" {
56- return ""
57- }
58- return strings.TrimPrefix(authHeader, "Bearer ")
59-}
60+//go:embed html/* public/*
61+var embedFS embed.FS
62
63 type oauth2Server struct {
64 Issuer string `json:"issuer"`
65@@ -70,7 +36,7 @@ type oauth2Server struct {
66 ResponseTypesSupported []string `json:"response_types_supported"`
67 }
68
69-func generateURL(cfg *AuthCfg, path string, space string) string {
70+func generateURL(cfg *shared.ConfigSite, path string, space string) string {
71 query := ""
72
73 if space != "" {
74@@ -80,38 +46,30 @@ func generateURL(cfg *AuthCfg, path string, space string) string {
75 return fmt.Sprintf("%s/%s%s", cfg.Domain, path, query)
76 }
77
78-func hasPlusOrSpace(client *Client, user *db.User, space string) bool {
79- return client.Dbpool.HasFeatureForUser(user.ID, "plus") || client.Dbpool.HasFeatureForUser(user.ID, space)
80-}
81-
82-func wellKnownHandler(w http.ResponseWriter, r *http.Request) {
83- client := getClient(r)
84-
85- space, err := url.PathUnescape(getField(r, 0))
86- if err != nil {
87- client.Logger.Error(err.Error())
88- }
89-
90- if space == "" {
91- space = r.URL.Query().Get("space")
92- }
93+func wellKnownHandler(apiConfig *shared.ApiConfig) http.HandlerFunc {
94+ return func(w http.ResponseWriter, r *http.Request) {
95+ space := r.PathValue("space")
96+ if space == "" {
97+ space = r.URL.Query().Get("space")
98+ }
99
100- p := oauth2Server{
101- Issuer: client.Cfg.Issuer,
102- IntrospectionEndpoint: generateURL(client.Cfg, "introspect", space),
103- IntrospectionEndpointAuthMethodsSupported: []string{
104- "none",
105- },
106- AuthorizationEndpoint: generateURL(client.Cfg, "authorize", ""),
107- TokenEndpoint: generateURL(client.Cfg, "token", ""),
108- ResponseTypesSupported: []string{"code"},
109- }
110- w.Header().Set("Content-Type", "application/json")
111- w.WriteHeader(http.StatusOK)
112- err = json.NewEncoder(w).Encode(p)
113- if err != nil {
114- client.Logger.Error(err.Error())
115- http.Error(w, err.Error(), http.StatusInternalServerError)
116+ p := oauth2Server{
117+ Issuer: apiConfig.Cfg.Issuer,
118+ IntrospectionEndpoint: generateURL(apiConfig.Cfg, "introspect", space),
119+ IntrospectionEndpointAuthMethodsSupported: []string{
120+ "none",
121+ },
122+ AuthorizationEndpoint: generateURL(apiConfig.Cfg, "authorize", ""),
123+ TokenEndpoint: generateURL(apiConfig.Cfg, "token", ""),
124+ ResponseTypesSupported: []string{"code"},
125+ }
126+ w.Header().Set("Content-Type", "application/json")
127+ w.WriteHeader(http.StatusOK)
128+ err := json.NewEncoder(w).Encode(p)
129+ if err != nil {
130+ apiConfig.Cfg.Logger.Error(err.Error())
131+ http.Error(w, err.Error(), http.StatusInternalServerError)
132+ }
133 }
134 }
135
136@@ -120,148 +78,150 @@ type oauth2Introspection struct {
137 Username string `json:"username"`
138 }
139
140-func introspectHandler(w http.ResponseWriter, r *http.Request) {
141- client := getClient(r)
142- token := r.FormValue("token")
143- client.Logger.Info("introspect token", "token", token)
144+func introspectHandler(apiConfig *shared.ApiConfig) http.HandlerFunc {
145+ return func(w http.ResponseWriter, r *http.Request) {
146+ token := r.FormValue("token")
147+ apiConfig.Cfg.Logger.Info("introspect token", "token", token)
148
149- user, err := client.Dbpool.FindUserForToken(token)
150- if err != nil {
151- client.Logger.Error(err.Error())
152- http.Error(w, err.Error(), http.StatusUnauthorized)
153- return
154- }
155+ user, err := apiConfig.Dbpool.FindUserForToken(token)
156+ if err != nil {
157+ apiConfig.Cfg.Logger.Error(err.Error())
158+ http.Error(w, err.Error(), http.StatusUnauthorized)
159+ return
160+ }
161
162- p := oauth2Introspection{
163- Active: true,
164- Username: user.Name,
165- }
166+ p := oauth2Introspection{
167+ Active: true,
168+ Username: user.Name,
169+ }
170
171- space := r.URL.Query().Get("space")
172- if space != "" {
173- if !hasPlusOrSpace(client, user, space) {
174- p.Active = false
175+ space := r.URL.Query().Get("space")
176+ if space != "" {
177+ if !apiConfig.HasPlusOrSpace(user, space) {
178+ p.Active = false
179+ }
180 }
181- }
182
183- w.Header().Set("Content-Type", "application/json")
184- w.WriteHeader(http.StatusOK)
185- err = json.NewEncoder(w).Encode(p)
186- if err != nil {
187- client.Logger.Error(err.Error())
188- http.Error(w, err.Error(), http.StatusInternalServerError)
189+ w.Header().Set("Content-Type", "application/json")
190+ w.WriteHeader(http.StatusOK)
191+ err = json.NewEncoder(w).Encode(p)
192+ if err != nil {
193+ apiConfig.Cfg.Logger.Error(err.Error())
194+ http.Error(w, err.Error(), http.StatusInternalServerError)
195+ }
196 }
197 }
198
199-func authorizeHandler(w http.ResponseWriter, r *http.Request) {
200- client := getClient(r)
201-
202- responseType := r.URL.Query().Get("response_type")
203- clientID := r.URL.Query().Get("client_id")
204- redirectURI := r.URL.Query().Get("redirect_uri")
205- scope := r.URL.Query().Get("scope")
206-
207- client.Logger.Info(
208- "authorize handler",
209- "responseType", responseType,
210- "clientID", clientID,
211- "redirectURI", redirectURI,
212- "scope", scope,
213- )
214+func authorizeHandler(apiConfig *shared.ApiConfig) http.HandlerFunc {
215+ return func(w http.ResponseWriter, r *http.Request) {
216+ responseType := r.URL.Query().Get("response_type")
217+ clientID := r.URL.Query().Get("client_id")
218+ redirectURI := r.URL.Query().Get("redirect_uri")
219+ scope := r.URL.Query().Get("scope")
220+
221+ apiConfig.Cfg.Logger.Info(
222+ "authorize handler",
223+ "responseType", responseType,
224+ "clientID", clientID,
225+ "redirectURI", redirectURI,
226+ "scope", scope,
227+ )
228
229- ts, err := template.ParseFiles(
230- "auth/html/redirect.page.tmpl",
231- "auth/html/footer.partial.tmpl",
232- "auth/html/marketing-footer.partial.tmpl",
233- "auth/html/base.layout.tmpl",
234- )
235+ ts, err := template.ParseFS(
236+ embedFS,
237+ "html/redirect.page.tmpl",
238+ "html/footer.partial.tmpl",
239+ "html/marketing-footer.partial.tmpl",
240+ "html/base.layout.tmpl",
241+ )
242
243- if err != nil {
244- client.Logger.Error(err.Error())
245- http.Error(w, err.Error(), http.StatusUnauthorized)
246- return
247- }
248+ if err != nil {
249+ apiConfig.Cfg.Logger.Error(err.Error())
250+ http.Error(w, err.Error(), http.StatusUnauthorized)
251+ return
252+ }
253
254- err = ts.Execute(w, map[string]any{
255- "response_type": responseType,
256- "client_id": clientID,
257- "redirect_uri": redirectURI,
258- "scope": scope,
259- })
260+ err = ts.Execute(w, map[string]any{
261+ "response_type": responseType,
262+ "client_id": clientID,
263+ "redirect_uri": redirectURI,
264+ "scope": scope,
265+ })
266
267- if err != nil {
268- client.Logger.Error(err.Error())
269- http.Error(w, err.Error(), http.StatusUnauthorized)
270- return
271+ if err != nil {
272+ apiConfig.Cfg.Logger.Error(err.Error())
273+ http.Error(w, err.Error(), http.StatusUnauthorized)
274+ return
275+ }
276 }
277 }
278
279-func redirectHandler(w http.ResponseWriter, r *http.Request) {
280- client := getClient(r)
281-
282- token := r.FormValue("token")
283- redirectURI := r.FormValue("redirect_uri")
284- responseType := r.FormValue("response_type")
285-
286- client.Logger.Info("redirect handler",
287- "token", token,
288- "redirectURI", redirectURI,
289- "responseType", responseType,
290- )
291+func redirectHandler(apiConfig *shared.ApiConfig) http.HandlerFunc {
292+ return func(w http.ResponseWriter, r *http.Request) {
293+ token := r.FormValue("token")
294+ redirectURI := r.FormValue("redirect_uri")
295+ responseType := r.FormValue("response_type")
296+
297+ apiConfig.Cfg.Logger.Info("redirect handler",
298+ "token", token,
299+ "redirectURI", redirectURI,
300+ "responseType", responseType,
301+ )
302
303- if token == "" || redirectURI == "" || responseType != "code" {
304- http.Error(w, "bad request", http.StatusBadRequest)
305- return
306- }
307+ if token == "" || redirectURI == "" || responseType != "code" {
308+ http.Error(w, "bad request", http.StatusBadRequest)
309+ return
310+ }
311
312- url, err := url.Parse(redirectURI)
313- if err != nil {
314- http.Error(w, err.Error(), http.StatusBadRequest)
315- return
316- }
317+ url, err := url.Parse(redirectURI)
318+ if err != nil {
319+ http.Error(w, err.Error(), http.StatusBadRequest)
320+ return
321+ }
322
323- urlQuery := url.Query()
324- urlQuery.Add("code", token)
325+ urlQuery := url.Query()
326+ urlQuery.Add("code", token)
327
328- url.RawQuery = urlQuery.Encode()
329+ url.RawQuery = urlQuery.Encode()
330
331- http.Redirect(w, r, url.String(), http.StatusFound)
332+ http.Redirect(w, r, url.String(), http.StatusFound)
333+ }
334 }
335
336 type oauth2Token struct {
337 AccessToken string `json:"access_token"`
338 }
339
340-func tokenHandler(w http.ResponseWriter, r *http.Request) {
341- client := getClient(r)
342-
343- token := r.FormValue("code")
344- redirectURI := r.FormValue("redirect_uri")
345- grantType := r.FormValue("grant_type")
346-
347- client.Logger.Info(
348- "handle token",
349- "token", token,
350- "redirectURI", redirectURI,
351- "grantType", grantType,
352- )
353+func tokenHandler(apiConfig *shared.ApiConfig) http.HandlerFunc {
354+ return func(w http.ResponseWriter, r *http.Request) {
355+ token := r.FormValue("code")
356+ redirectURI := r.FormValue("redirect_uri")
357+ grantType := r.FormValue("grant_type")
358+
359+ apiConfig.Cfg.Logger.Info(
360+ "handle token",
361+ "token", token,
362+ "redirectURI", redirectURI,
363+ "grantType", grantType,
364+ )
365
366- _, err := client.Dbpool.FindUserForToken(token)
367- if err != nil {
368- client.Logger.Error(err.Error())
369- http.Error(w, err.Error(), http.StatusUnauthorized)
370- return
371- }
372+ _, err := apiConfig.Dbpool.FindUserForToken(token)
373+ if err != nil {
374+ apiConfig.Cfg.Logger.Error(err.Error())
375+ http.Error(w, err.Error(), http.StatusUnauthorized)
376+ return
377+ }
378
379- p := oauth2Token{
380- AccessToken: token,
381- }
382- w.Header().Set("Content-Type", "application/json")
383- w.WriteHeader(http.StatusOK)
384- err = json.NewEncoder(w).Encode(p)
385- if err != nil {
386- client.Logger.Error(err.Error())
387- http.Error(w, err.Error(), http.StatusInternalServerError)
388+ p := oauth2Token{
389+ AccessToken: token,
390+ }
391+ w.Header().Set("Content-Type", "application/json")
392+ w.WriteHeader(http.StatusOK)
393+ err = json.NewEncoder(w).Encode(p)
394+ if err != nil {
395+ apiConfig.Cfg.Logger.Error(err.Error())
396+ http.Error(w, err.Error(), http.StatusInternalServerError)
397+ }
398 }
399 }
400
401@@ -271,98 +231,98 @@ type sishData struct {
402 RemoteAddress string `json:"remote_addr"`
403 }
404
405-func keyHandler(w http.ResponseWriter, r *http.Request) {
406- client := getClient(r)
407+func keyHandler(apiConfig *shared.ApiConfig) http.HandlerFunc {
408+ return func(w http.ResponseWriter, r *http.Request) {
409+ var data sishData
410
411- var data sishData
412+ err := json.NewDecoder(r.Body).Decode(&data)
413+ if err != nil {
414+ apiConfig.Cfg.Logger.Error(err.Error())
415+ http.Error(w, err.Error(), http.StatusBadRequest)
416+ return
417+ }
418
419- err := json.NewDecoder(r.Body).Decode(&data)
420- if err != nil {
421- client.Logger.Error(err.Error())
422- http.Error(w, err.Error(), http.StatusBadRequest)
423- return
424- }
425+ space := r.URL.Query().Get("space")
426
427- space := r.URL.Query().Get("space")
428+ apiConfig.Cfg.Logger.Info(
429+ "handle key",
430+ "remoteAddress", data.RemoteAddress,
431+ "user", data.Username,
432+ "space", space,
433+ "publicKey", data.PublicKey,
434+ )
435
436- client.Logger.Info(
437- "handle key",
438- "remoteAddress", data.RemoteAddress,
439- "user", data.Username,
440- "space", space,
441- "publicKey", data.PublicKey,
442- )
443+ user, err := apiConfig.Dbpool.FindUserForKey(data.Username, data.PublicKey)
444+ if err != nil {
445+ apiConfig.Cfg.Logger.Error(err.Error())
446+ w.WriteHeader(http.StatusUnauthorized)
447+ return
448+ }
449
450- user, err := client.Dbpool.FindUserForKey(data.Username, data.PublicKey)
451- if err != nil {
452- client.Logger.Error(err.Error())
453- w.WriteHeader(http.StatusUnauthorized)
454- return
455- }
456+ if !apiConfig.HasPlusOrSpace(user, space) {
457+ w.WriteHeader(http.StatusUnauthorized)
458+ return
459+ }
460
461- if !hasPlusOrSpace(client, user, space) {
462- w.WriteHeader(http.StatusUnauthorized)
463- return
464- }
465+ if !apiConfig.HasPrivilegedAccess(shared.GetApiToken(r)) {
466+ w.WriteHeader(http.StatusOK)
467+ return
468+ }
469
470- if !client.hasPrivilegedAccess(getApiToken(r)) {
471+ w.Header().Set("Content-Type", "application/json")
472 w.WriteHeader(http.StatusOK)
473- return
474- }
475-
476- w.Header().Set("Content-Type", "application/json")
477- w.WriteHeader(http.StatusOK)
478- err = json.NewEncoder(w).Encode(user)
479- if err != nil {
480- client.Logger.Error(err.Error())
481- http.Error(w, err.Error(), http.StatusInternalServerError)
482+ err = json.NewEncoder(w).Encode(user)
483+ if err != nil {
484+ apiConfig.Cfg.Logger.Error(err.Error())
485+ http.Error(w, err.Error(), http.StatusInternalServerError)
486+ }
487 }
488 }
489
490-func userHandler(w http.ResponseWriter, r *http.Request) {
491- client := getClient(r)
492-
493- if !client.hasPrivilegedAccess(getApiToken(r)) {
494- w.WriteHeader(http.StatusForbidden)
495- return
496- }
497+func userHandler(apiConfig *shared.ApiConfig) http.HandlerFunc {
498+ return func(w http.ResponseWriter, r *http.Request) {
499+ if !apiConfig.HasPrivilegedAccess(shared.GetApiToken(r)) {
500+ w.WriteHeader(http.StatusForbidden)
501+ return
502+ }
503
504- var data sishData
505+ var data sishData
506
507- err := json.NewDecoder(r.Body).Decode(&data)
508- if err != nil {
509- client.Logger.Error(err.Error())
510- http.Error(w, err.Error(), http.StatusBadRequest)
511- return
512- }
513+ err := json.NewDecoder(r.Body).Decode(&data)
514+ if err != nil {
515+ apiConfig.Cfg.Logger.Error(err.Error())
516+ http.Error(w, err.Error(), http.StatusBadRequest)
517+ return
518+ }
519
520- client.Logger.Info(
521- "handle key",
522- "remoteAddress", data.RemoteAddress,
523- "user", data.Username,
524- "publicKey", data.PublicKey,
525- )
526+ apiConfig.Cfg.Logger.Info(
527+ "handle key",
528+ "remoteAddress", data.RemoteAddress,
529+ "user", data.Username,
530+ "publicKey", data.PublicKey,
531+ )
532
533- user, err := client.Dbpool.FindUserForName(data.Username)
534- if err != nil {
535- client.Logger.Error(err.Error())
536- http.Error(w, err.Error(), http.StatusNotFound)
537- return
538- }
539+ user, err := apiConfig.Dbpool.FindUserForName(data.Username)
540+ if err != nil {
541+ apiConfig.Cfg.Logger.Error(err.Error())
542+ http.Error(w, err.Error(), http.StatusNotFound)
543+ return
544+ }
545
546- keys, err := client.Dbpool.FindKeysForUser(user)
547- if err != nil {
548- client.Logger.Error(err.Error())
549- http.Error(w, err.Error(), http.StatusNotFound)
550- return
551- }
552+ keys, err := apiConfig.Dbpool.FindKeysForUser(user)
553+ if err != nil {
554+ apiConfig.Cfg.Logger.Error(err.Error())
555+ http.Error(w, err.Error(), http.StatusNotFound)
556+ return
557+ }
558
559- w.Header().Set("Content-Type", "application/json")
560- w.WriteHeader(http.StatusOK)
561- err = json.NewEncoder(w).Encode(keys)
562- if err != nil {
563- client.Logger.Error(err.Error())
564- http.Error(w, err.Error(), http.StatusInternalServerError)
565+ w.Header().Set("Content-Type", "application/json")
566+ w.WriteHeader(http.StatusOK)
567+ err = json.NewEncoder(w).Encode(keys)
568+ if err != nil {
569+ apiConfig.Cfg.Logger.Error(err.Error())
570+ http.Error(w, err.Error(), http.StatusInternalServerError)
571+ }
572 }
573 }
574
575@@ -387,121 +347,127 @@ func genFeedItem(now time.Time, expiresAt time.Time, warning time.Time, txt stri
576 return nil
577 }
578
579-func rssHandler(w http.ResponseWriter, r *http.Request) {
580- client := getClient(r)
581- apiToken, err := url.PathUnescape(getField(r, 0))
582- if err != nil {
583- client.Logger.Error(err.Error())
584- http.Error(w, err.Error(), http.StatusNotFound)
585- return
586- }
587- user, err := client.Dbpool.FindUserForToken(apiToken)
588- if err != nil {
589- client.Logger.Error(err.Error())
590- http.Error(w, "invalid token", http.StatusNotFound)
591- return
592- }
593-
594- href := fmt.Sprintf("https://auth.pico.sh/rss/%s", apiToken)
595+func rssHandler(apiConfig *shared.ApiConfig) http.HandlerFunc {
596+ return func(w http.ResponseWriter, r *http.Request) {
597+ apiToken := r.PathValue("token")
598+ user, err := apiConfig.Dbpool.FindUserForToken(apiToken)
599+ if err != nil {
600+ apiConfig.Cfg.Logger.Error(err.Error())
601+ http.Error(w, "invalid token", http.StatusNotFound)
602+ return
603+ }
604
605- feed := &feeds.Feed{
606- Title: "pico+",
607- Link: &feeds.Link{Href: href},
608- Description: "get notified of important membership updates",
609- Author: &feeds.Author{Name: "team pico"},
610- Created: time.Now(),
611- }
612- var feedItems []*feeds.Item
613+ href := fmt.Sprintf("https://auth.pico.sh/rss/%s", apiToken)
614
615- now := time.Now()
616- ff, err := client.Dbpool.FindFeatureForUser(user.ID, "plus")
617- if err != nil {
618- // still want to send an empty feed
619- } else {
620- createdAt := ff.CreatedAt
621- createdAtStr := createdAt.Format("2006-01-02 15:04:05")
622- id := fmt.Sprintf("pico-plus-activated-%d", createdAt.Unix())
623- content := `Thanks for joining pico+! You now have access to all our premium services for exactly one year. We will send you pico+ expiration notifications through this RSS feed. Go to <a href="https://pico.sh/getting-started#next-steps">pico.sh/getting-started#next-steps</a> to start using our services.`
624- plus := &feeds.Item{
625- Id: id,
626- Title: fmt.Sprintf("pico+ membership activated on %s", createdAtStr),
627- Link: &feeds.Link{Href: "https://pico.sh"},
628- Content: content,
629- Created: *createdAt,
630- Updated: *createdAt,
631- Description: content,
632+ feed := &feeds.Feed{
633+ Title: "pico+",
634+ Link: &feeds.Link{Href: href},
635+ Description: "get notified of important membership updates",
636 Author: &feeds.Author{Name: "team pico"},
637 }
638- feedItems = append(feedItems, plus)
639+ var feedItems []*feeds.Item
640
641- oneMonthWarning := ff.ExpiresAt.AddDate(0, -1, 0)
642- mo := genFeedItem(now, *ff.ExpiresAt, oneMonthWarning, "1-month")
643- if mo != nil {
644- feedItems = append(feedItems, mo)
645+ now := time.Now()
646+ ff, err := apiConfig.Dbpool.FindFeatureForUser(user.ID, "plus")
647+ if err != nil {
648+ // still want to send an empty feed
649+ } else {
650+ createdAt := ff.CreatedAt
651+ createdAtStr := createdAt.Format("2006-01-02 15:04:05")
652+ id := fmt.Sprintf("pico-plus-activated-%d", createdAt.Unix())
653+ content := `Thanks for joining pico+! You now have access to all our premium services for exactly one year. We will send you pico+ expiration notifications through this RSS feed. Go to <a href="https://pico.sh/getting-started#next-steps">pico.sh/getting-started#next-steps</a> to start using our services.`
654+ plus := &feeds.Item{
655+ Id: id,
656+ Title: fmt.Sprintf("pico+ membership activated on %s", createdAtStr),
657+ Link: &feeds.Link{Href: "https://pico.sh"},
658+ Content: content,
659+ Created: *createdAt,
660+ Updated: *createdAt,
661+ Description: content,
662+ Author: &feeds.Author{Name: "team pico"},
663+ }
664+ feedItems = append(feedItems, plus)
665+
666+ oneMonthWarning := ff.ExpiresAt.AddDate(0, -1, 0)
667+ mo := genFeedItem(now, *ff.ExpiresAt, oneMonthWarning, "1-month")
668+ if mo != nil {
669+ feedItems = append(feedItems, mo)
670+ }
671+
672+ oneWeekWarning := ff.ExpiresAt.AddDate(0, 0, -7)
673+ wk := genFeedItem(now, *ff.ExpiresAt, oneWeekWarning, "1-week")
674+ if wk != nil {
675+ feedItems = append(feedItems, wk)
676+ }
677+
678+ oneDayWarning := ff.ExpiresAt.AddDate(0, 0, -2)
679+ day := genFeedItem(now, *ff.ExpiresAt, oneDayWarning, "1-day")
680+ if day != nil {
681+ feedItems = append(feedItems, day)
682+ }
683 }
684
685- oneWeekWarning := ff.ExpiresAt.AddDate(0, 0, -7)
686- wk := genFeedItem(now, *ff.ExpiresAt, oneWeekWarning, "1-week")
687- if wk != nil {
688- feedItems = append(feedItems, wk)
689+ feed.Items = feedItems
690+
691+ rss, err := feed.ToAtom()
692+ if err != nil {
693+ apiConfig.Cfg.Logger.Error(err.Error())
694+ http.Error(w, "Could not generate atom rss feed", http.StatusInternalServerError)
695 }
696
697- oneDayWarning := ff.ExpiresAt.AddDate(0, 0, -1)
698- day := genFeedItem(now, *ff.ExpiresAt, oneDayWarning, "1-day")
699- if day != nil {
700- feedItems = append(feedItems, day)
701+ w.Header().Add("Content-Type", "application/atom+xml")
702+ _, err = w.Write([]byte(rss))
703+ if err != nil {
704+ apiConfig.Cfg.Logger.Error(err.Error())
705 }
706 }
707+}
708
709- feed.Items = feedItems
710+type CustomDataMeta struct {
711+ PicoUsername string `json:"username"`
712+}
713
714- rss, err := feed.ToAtom()
715- if err != nil {
716- client.Logger.Error(err.Error())
717- http.Error(w, "Could not generate atom rss feed", http.StatusInternalServerError)
718- }
719+type OrderEventMeta struct {
720+ EventName string `json:"event_name"`
721+ CustomData *CustomDataMeta `json:"custom_data"`
722+}
723
724- w.Header().Add("Content-Type", "application/atom+xml")
725- _, err = w.Write([]byte(rss))
726- if err != nil {
727- client.Logger.Error(err.Error())
728- }
729+type OrderEventData struct {
730+ Type string `json:"type"`
731+ ID string `json:"id"`
732+ Attr *OrderEventDataAttr `json:"attributes"`
733+}
734+
735+type OrderEventDataAttr struct {
736+ OrderNumber int `json:"order_number"`
737+ Identifier string `json:"identifier"`
738+ UserName string `json:"user_name"`
739+ UserEmail string `json:"user_email"`
740+ CreatedAt time.Time `json:"created_at"`
741+ Status string `json:"status"` // `paid`, `refund`
742 }
743
744 type OrderEvent struct {
745- Meta *struct {
746- EventName string `json:"event_name"`
747- CustomData *struct {
748- PicoUsername string `json:"username"`
749- } `json:"custom_data"`
750- } `json:"meta"`
751- Data *struct {
752- Type string `json:"type"`
753- ID string `json:"id"`
754- Attr *struct {
755- OrderNumber int `json:"order_number"`
756- Identifier string `json:"identifier"`
757- UserName string `json:"user_name"`
758- UserEmail string `json:"user_email"`
759- CreatedAt time.Time `json:"created_at"`
760- Status string `json:"status"` // `paid`, `refund`
761- } `json:"attributes"`
762- } `json:"data"`
763+ Meta *OrderEventMeta `json:"meta"`
764+ Data *OrderEventData `json:"data"`
765 }
766
767 // Status code must be 200 or else lemonsqueezy will keep retrying
768 // https://docs.lemonsqueezy.com/help/webhooks
769-func paymentWebhookHandler(secret string) func(http.ResponseWriter, *http.Request) {
770+func paymentWebhookHandler(apiConfig *shared.ApiConfig) http.HandlerFunc {
771 return func(w http.ResponseWriter, r *http.Request) {
772- client := getClient(r)
773- dbpool := client.Dbpool
774- logger := client.Logger
775+ dbpool := apiConfig.Dbpool
776+ logger := apiConfig.Cfg.Logger
777 const MaxBodyBytes = int64(65536)
778 r.Body = http.MaxBytesReader(w, r.Body, MaxBodyBytes)
779 payload, err := io.ReadAll(r.Body)
780+
781+ w.Header().Add("content-type", "text/plain")
782+
783 if err != nil {
784 logger.Error("error reading request body", "err", err.Error())
785 w.WriteHeader(http.StatusOK)
786+ _, _ = w.Write([]byte(fmt.Sprintf("error reading request body %s", err.Error())))
787 return
788 }
789
790@@ -510,32 +476,37 @@ func paymentWebhookHandler(secret string) func(http.ResponseWriter, *http.Reques
791 if err := json.Unmarshal(payload, &event); err != nil {
792 logger.Error("failed to parse webhook body JSON", "err", err.Error())
793 w.WriteHeader(http.StatusOK)
794+ _, _ = w.Write([]byte(fmt.Sprintf("failed to parse webhook body JSON %s", err.Error())))
795 return
796 }
797
798- hash := shared.HmacString(secret, string(payload))
799+ hash := shared.HmacString(apiConfig.Cfg.SecretWebhook, string(payload))
800 sig := r.Header.Get("X-Signature")
801 if !hmac.Equal([]byte(hash), []byte(sig)) {
802 logger.Error("invalid signature X-Signature")
803 w.WriteHeader(http.StatusOK)
804+ _, _ = w.Write([]byte("invalid signature x-signature"))
805 return
806 }
807
808 if event.Meta == nil {
809 logger.Error("no meta field found")
810 w.WriteHeader(http.StatusOK)
811+ _, _ = w.Write([]byte("no meta field found"))
812 return
813 }
814
815 if event.Meta.EventName != "order_created" {
816 logger.Error("event not order_created", "event", event.Meta.EventName)
817 w.WriteHeader(http.StatusOK)
818+ _, _ = w.Write([]byte("event not order_created"))
819 return
820 }
821
822 if event.Meta.CustomData == nil {
823 logger.Error("no custom data found")
824 w.WriteHeader(http.StatusOK)
825+ _, _ = w.Write([]byte("no custom data found"))
826 return
827 }
828
829@@ -544,6 +515,7 @@ func paymentWebhookHandler(secret string) func(http.ResponseWriter, *http.Reques
830 if event.Data == nil || event.Data.Attr == nil {
831 logger.Error("no data or data.attributes fields found")
832 w.WriteHeader(http.StatusOK)
833+ _, _ = w.Write([]byte("no data or data.attributes fields found"))
834 return
835 }
836
837@@ -567,95 +539,42 @@ func paymentWebhookHandler(secret string) func(http.ResponseWriter, *http.Reques
838 if username == "" {
839 log.Error("no `?checkout[custom][username]=xxx` found in URL, cannot add pico+ membership")
840 w.WriteHeader(http.StatusOK)
841+ _, _ = w.Write([]byte("no `?checkout[custom][username]=xxx` found in URL, cannot add pico+ membership"))
842 return
843 }
844
845 if status != "paid" {
846 log.Error("status not paid")
847 w.WriteHeader(http.StatusOK)
848+ _, _ = w.Write([]byte("status not paid"))
849 return
850 }
851
852 err = dbpool.AddPicoPlusUser(username, "lemonsqueezy", txID)
853 if err != nil {
854 log.Error("failed to add pico+ user", "err", err)
855- } else {
856- log.Info("successfully added pico+ user")
857+ w.WriteHeader(http.StatusOK)
858+ _, _ = w.Write([]byte("status not paid"))
859+ return
860 }
861
862+ log.Info("successfully added pico+ user")
863 w.WriteHeader(http.StatusOK)
864+ _, _ = w.Write([]byte("successfully added pico+ user"))
865 }
866 }
867
868 // URL shortener for out pico+ URL.
869-func checkoutHandler(w http.ResponseWriter, r *http.Request) {
870- username, err := url.PathUnescape(getField(r, 0))
871- if err != nil {
872- w.WriteHeader(http.StatusUnprocessableEntity)
873- return
874- }
875- link := "https://checkout.pico.sh/buy/73c26cf9-3fac-44c3-b744-298b3032a96b"
876- url := fmt.Sprintf(
877- "%s?discount=0&checkout[custom][username]=%s",
878- link,
879- username,
880- )
881- http.Redirect(w, r, url, http.StatusMovedPermanently)
882-}
883-
884-func createMainRoutes() []shared.Route {
885- fileServer := http.FileServer(http.Dir("auth/public"))
886- secret := os.Getenv("PICO_SECRET_WEBHOOK")
887- if secret == "" {
888- panic("must provide PICO_SECRET_WEBHOOK environment variable")
889- }
890-
891- routes := []shared.Route{
892- shared.NewRoute("GET", "/checkout/(.+)", checkoutHandler),
893- shared.NewRoute("GET", "/.well-known/oauth-authorization-server/?(.+)?", wellKnownHandler),
894- shared.NewRoute("POST", "/introspect", introspectHandler),
895- shared.NewRoute("GET", "/authorize", authorizeHandler),
896- shared.NewRoute("POST", "/token", tokenHandler),
897- shared.NewRoute("POST", "/key", keyHandler),
898- shared.NewRoute("POST", "/user", userHandler),
899- shared.NewRoute("GET", "/rss/(.+)", rssHandler),
900- shared.NewRoute("POST", "/redirect", redirectHandler),
901- shared.NewRoute("POST", "/webhook", paymentWebhookHandler(secret)),
902- shared.NewRoute("GET", "/main.css", fileServer.ServeHTTP),
903- shared.NewRoute("GET", "/card.png", fileServer.ServeHTTP),
904- shared.NewRoute("GET", "/favicon-16x16.png", fileServer.ServeHTTP),
905- shared.NewRoute("GET", "/favicon-32x32.png", fileServer.ServeHTTP),
906- shared.NewRoute("GET", "/apple-touch-icon.png", fileServer.ServeHTTP),
907- shared.NewRoute("GET", "/favicon.ico", fileServer.ServeHTTP),
908- shared.NewRoute("GET", "/robots.txt", fileServer.ServeHTTP),
909- }
910-
911- return routes
912-}
913-
914-func handler(routes []shared.Route, client *Client) http.HandlerFunc {
915+func checkoutHandler() http.HandlerFunc {
916 return func(w http.ResponseWriter, r *http.Request) {
917- var allow []string
918-
919- for _, route := range routes {
920- matches := route.Regex.FindStringSubmatch(r.URL.Path)
921- if len(matches) > 0 {
922- if r.Method != route.Method {
923- allow = append(allow, route.Method)
924- continue
925- }
926- clientCtx := context.WithValue(r.Context(), ctxClient{}, client)
927- ctx := context.WithValue(clientCtx, ctxKey{}, matches[1:])
928- route.Handler(w, r.WithContext(ctx))
929- return
930- }
931- }
932- if len(allow) > 0 {
933- w.Header().Set("Allow", strings.Join(allow, ", "))
934- http.Error(w, "405 method not allowed", http.StatusMethodNotAllowed)
935- return
936- }
937- http.NotFound(w, r)
938+ username := r.PathValue("username")
939+ link := "https://checkout.pico.sh/buy/73c26cf9-3fac-44c3-b744-298b3032a96b"
940+ url := fmt.Sprintf(
941+ "%s?discount=0&checkout[custom][username]=%s",
942+ link,
943+ username,
944+ )
945+ http.Redirect(w, r, url, http.StatusMovedPermanently)
946 }
947 }
948
949@@ -702,56 +621,88 @@ func metricDrainSub(ctx context.Context, dbpool db.DB, logger *slog.Logger, secr
950 }
951 }
952
953-type AuthCfg struct {
954- Debug bool
955- Port string
956- DbURL string
957- Domain string
958- Issuer string
959- Secret string
960+func authMux(apiConfig *shared.ApiConfig) *http.ServeMux {
961+ serverRoot, err := fs.Sub(embedFS, "public")
962+ if err != nil {
963+ panic(err)
964+ }
965+ fileServer := http.FileServerFS(serverRoot)
966+
967+ mux := http.NewServeMux()
968+ // ensure legacy router is disabled
969+ // GODEBUG=httpmuxgo121=0
970+ mux.Handle("GET /checkout/{username}", checkoutHandler())
971+ mux.Handle("GET /.well-known/oauth-authorization-server", wellKnownHandler(apiConfig))
972+ mux.Handle("GET /.well-known/oauth-authorization-server/{space}", wellKnownHandler(apiConfig))
973+ mux.Handle("POST /introspect", introspectHandler(apiConfig))
974+ mux.Handle("GET /authorize", authorizeHandler(apiConfig))
975+ mux.Handle("POST /token", tokenHandler(apiConfig))
976+ mux.Handle("POST /key", keyHandler(apiConfig))
977+ mux.Handle("POST /user", userHandler(apiConfig))
978+ mux.Handle("GET /rss/{token}", rssHandler(apiConfig))
979+ mux.Handle("POST /redirect", redirectHandler(apiConfig))
980+ mux.Handle("POST /webhook", paymentWebhookHandler(apiConfig))
981+ mux.HandleFunc("GET /main.css", fileServer.ServeHTTP)
982+ mux.HandleFunc("GET /card.png", fileServer.ServeHTTP)
983+ mux.HandleFunc("GET /favicon-16x16.png", fileServer.ServeHTTP)
984+ mux.HandleFunc("GET /favicon-32x32.png", fileServer.ServeHTTP)
985+ mux.HandleFunc("GET /apple-touch-icon.png", fileServer.ServeHTTP)
986+ mux.HandleFunc("GET /favicon.ico", fileServer.ServeHTTP)
987+ mux.HandleFunc("GET /robots.txt", fileServer.ServeHTTP)
988+
989+ if apiConfig.Cfg.Debug {
990+ shared.CreatePProfRoutesMux(mux)
991+ }
992+
993+ return mux
994 }
995
996 func StartApiServer() {
997 debug := utils.GetEnv("AUTH_DEBUG", "0")
998- cfg := &AuthCfg{
999- DbURL: utils.GetEnv("DATABASE_URL", ""),
1000- Debug: debug == "1",
1001- Issuer: utils.GetEnv("AUTH_ISSUER", "pico.sh"),
1002- Domain: utils.GetEnv("AUTH_DOMAIN", "http://0.0.0.0:3000"),
1003- Port: utils.GetEnv("AUTH_WEB_PORT", "3000"),
1004- Secret: utils.GetEnv("PICO_SECRET", ""),
1005+
1006+ cfg := &shared.ConfigSite{
1007+ DbURL: utils.GetEnv("DATABASE_URL", ""),
1008+ Debug: debug == "1",
1009+ Issuer: utils.GetEnv("AUTH_ISSUER", "pico.sh"),
1010+ Domain: utils.GetEnv("AUTH_DOMAIN", "http://0.0.0.0:3000"),
1011+ Port: utils.GetEnv("AUTH_WEB_PORT", "3000"),
1012+ Secret: utils.GetEnv("PICO_SECRET", ""),
1013+ SecretWebhook: utils.GetEnv("PICO_SECRET_WEBHOOK", ""),
1014 }
1015+
1016+ if cfg.SecretWebhook == "" {
1017+ panic("must provide PICO_SECRET_WEBHOOK environment variable")
1018+ }
1019+
1020 if cfg.Secret == "" {
1021 panic("must provide PICO_SECRET environment variable")
1022 }
1023
1024 logger := shared.CreateLogger("auth")
1025+
1026+ cfg.Logger = logger
1027+
1028 db := postgres.NewDB(cfg.DbURL, logger)
1029 defer db.Close()
1030
1031- client := &Client{
1032- Cfg: cfg,
1033- Dbpool: db,
1034- Logger: logger,
1035- }
1036-
1037 ctx := context.Background()
1038+
1039 // gather metrics in the auth service
1040 go metricDrainSub(ctx, db, logger, cfg.Secret)
1041 defer ctx.Done()
1042
1043- routes := createMainRoutes()
1044-
1045- if cfg.Debug {
1046- routes = shared.CreatePProfRoutes(routes)
1047+ apiConfig := &shared.ApiConfig{
1048+ Cfg: cfg,
1049+ Dbpool: db,
1050 }
1051
1052- router := http.HandlerFunc(handler(routes, client))
1053+ mux := authMux(apiConfig)
1054
1055 portStr := fmt.Sprintf(":%s", cfg.Port)
1056- client.Logger.Info("starting server on port", "port", cfg.Port)
1057- err := http.ListenAndServe(portStr, router)
1058+ logger.Info("starting server on port", "port", cfg.Port)
1059+
1060+ err := http.ListenAndServe(portStr, mux)
1061 if err != nil {
1062- client.Logger.Info("http-serve", "err", err.Error())
1063+ logger.Info("http-serve", "err", err.Error())
1064 }
1065 }
+283,
-0
1@@ -0,0 +1,283 @@
2+package auth
3+
4+import (
5+ "bytes"
6+ "encoding/json"
7+ "fmt"
8+ "log/slog"
9+ "net/http"
10+ "net/http/httptest"
11+ "strings"
12+ "testing"
13+ "time"
14+
15+ "github.com/gkampitakis/go-snaps/snaps"
16+ "github.com/picosh/pico/db"
17+ "github.com/picosh/pico/db/stub"
18+ "github.com/picosh/pico/shared"
19+)
20+
21+var testUserID = "user-1"
22+var testUsername = "user-a"
23+
24+func TestPaymentWebhook(t *testing.T) {
25+ apiConfig := setupTest()
26+
27+ event := OrderEvent{
28+ Meta: &OrderEventMeta{
29+ EventName: "order_created",
30+ CustomData: &CustomDataMeta{
31+ PicoUsername: testUsername,
32+ },
33+ },
34+ Data: &OrderEventData{
35+ Attr: &OrderEventDataAttr{
36+ UserEmail: "auth@pico.test",
37+ CreatedAt: time.Now(),
38+ Status: "paid",
39+ OrderNumber: 1337,
40+ },
41+ },
42+ }
43+ jso, err := json.Marshal(event)
44+ bail(err)
45+ hash := shared.HmacString(apiConfig.Cfg.SecretWebhook, string(jso))
46+ body := bytes.NewReader(jso)
47+
48+ request := httptest.NewRequest("POST", mkpath("/webhook"), body)
49+ request.Header.Add("X-signature", hash)
50+ responseRecorder := httptest.NewRecorder()
51+
52+ mux := authMux(apiConfig)
53+ mux.ServeHTTP(responseRecorder, request)
54+
55+ testResponse(t, responseRecorder, 200, "text/plain")
56+}
57+
58+func TestUser(t *testing.T) {
59+ apiConfig := setupTest()
60+
61+ data := sishData{
62+ Username: testUsername,
63+ }
64+ jso, err := json.Marshal(data)
65+ bail(err)
66+ body := bytes.NewReader(jso)
67+
68+ request := httptest.NewRequest("POST", mkpath("/user"), body)
69+ request.Header.Add("Authorization", "Bearer 123")
70+ responseRecorder := httptest.NewRecorder()
71+
72+ mux := authMux(apiConfig)
73+ mux.ServeHTTP(responseRecorder, request)
74+
75+ testResponse(t, responseRecorder, 200, "application/json")
76+}
77+
78+func TestKey(t *testing.T) {
79+ apiConfig := setupTest()
80+
81+ data := sishData{
82+ Username: testUsername,
83+ PublicKey: "zzz",
84+ }
85+ jso, err := json.Marshal(data)
86+ bail(err)
87+ body := bytes.NewReader(jso)
88+
89+ request := httptest.NewRequest("POST", mkpath("/key"), body)
90+ request.Header.Add("Authorization", "Bearer 123")
91+ responseRecorder := httptest.NewRecorder()
92+
93+ mux := authMux(apiConfig)
94+ mux.ServeHTTP(responseRecorder, request)
95+
96+ testResponse(t, responseRecorder, 200, "application/json")
97+}
98+
99+func TestCheckout(t *testing.T) {
100+ apiConfig := setupTest()
101+
102+ request := httptest.NewRequest("GET", mkpath("/checkout/"+testUsername), strings.NewReader(""))
103+ request.Header.Add("Authorization", "Bearer 123")
104+ responseRecorder := httptest.NewRecorder()
105+
106+ mux := authMux(apiConfig)
107+ mux.ServeHTTP(responseRecorder, request)
108+
109+ loc := responseRecorder.Header().Get("Location")
110+ if loc != "https://checkout.pico.sh/buy/73c26cf9-3fac-44c3-b744-298b3032a96b?discount=0&checkout[custom][username]=user-a" {
111+ t.Errorf("Have Location %s, want checkout", loc)
112+ }
113+ if responseRecorder.Code != http.StatusMovedPermanently {
114+ t.Errorf("Want status '%d', got '%d'", http.StatusMovedPermanently, responseRecorder.Code)
115+ return
116+ }
117+}
118+
119+func TestIntrospect(t *testing.T) {
120+ apiConfig := setupTest()
121+
122+ request := httptest.NewRequest("POST", mkpath("/introspect?token=123"), strings.NewReader(""))
123+ responseRecorder := httptest.NewRecorder()
124+
125+ mux := authMux(apiConfig)
126+ mux.ServeHTTP(responseRecorder, request)
127+
128+ testResponse(t, responseRecorder, 200, "application/json")
129+}
130+
131+func TestToken(t *testing.T) {
132+ apiConfig := setupTest()
133+
134+ request := httptest.NewRequest("POST", mkpath("/token?code=123"), strings.NewReader(""))
135+ responseRecorder := httptest.NewRecorder()
136+
137+ mux := authMux(apiConfig)
138+ mux.ServeHTTP(responseRecorder, request)
139+
140+ testResponse(t, responseRecorder, 200, "application/json")
141+}
142+
143+func TestAuthApi(t *testing.T) {
144+ apiConfig := setupTest()
145+ tt := []*ApiExample{
146+ {
147+ name: "authorize",
148+ path: "/authorize?response_type=json&client_id=333&redirect_uri=pico.test&scope=admin",
149+ status: http.StatusOK,
150+ contentType: "text/html; charset=utf-8",
151+ dbpool: apiConfig.Dbpool,
152+ },
153+ {
154+ name: "rss",
155+ path: "/rss/123",
156+ status: http.StatusOK,
157+ contentType: "application/atom+xml",
158+ dbpool: apiConfig.Dbpool,
159+ },
160+ {
161+ name: "fileserver",
162+ path: "/robots.txt",
163+ status: http.StatusOK,
164+ contentType: "text/plain; charset=utf-8",
165+ dbpool: apiConfig.Dbpool,
166+ },
167+ {
168+ name: "well-known",
169+ path: "/.well-known/oauth-authorization-server",
170+ status: http.StatusOK,
171+ contentType: "application/json",
172+ dbpool: apiConfig.Dbpool,
173+ },
174+ }
175+
176+ for _, tc := range tt {
177+ t.Run(tc.name, func(t *testing.T) {
178+ request := httptest.NewRequest("GET", mkpath(tc.path), strings.NewReader(""))
179+ responseRecorder := httptest.NewRecorder()
180+
181+ mux := authMux(apiConfig)
182+ mux.ServeHTTP(responseRecorder, request)
183+
184+ testResponse(t, responseRecorder, tc.status, tc.contentType)
185+ })
186+ }
187+}
188+
189+type ApiExample struct {
190+ name string
191+ path string
192+ status int
193+ contentType string
194+ dbpool db.DB
195+}
196+
197+type AuthDb struct {
198+ *stub.StubDB
199+}
200+
201+func (a *AuthDb) AddPicoPlusUser(username, from, txid string) error {
202+ return nil
203+}
204+
205+func (a *AuthDb) FindUserForName(username string) (*db.User, error) {
206+ return &db.User{ID: testUserID, Name: username}, nil
207+}
208+
209+func (a *AuthDb) FindUserForKey(username string, pubkey string) (*db.User, error) {
210+ return &db.User{ID: testUserID, Name: username}, nil
211+}
212+
213+func (a *AuthDb) FindUserForToken(token string) (*db.User, error) {
214+ if token != "123" {
215+ return nil, fmt.Errorf("invalid token")
216+ }
217+ return &db.User{ID: testUserID, Name: testUsername}, nil
218+}
219+
220+func (a *AuthDb) HasFeatureForUser(userID string, feature string) bool {
221+ return true
222+}
223+
224+func (a *AuthDb) FindKeysForUser(user *db.User) ([]*db.PublicKey, error) {
225+ return []*db.PublicKey{{ID: "1", UserID: user.ID, Name: "my-key", Key: "nice-pubkey", CreatedAt: &time.Time{}}}, nil
226+}
227+
228+func (a *AuthDb) FindFeatureForUser(userID string, feature string) (*db.FeatureFlag, error) {
229+ now := time.Date(2021, 8, 15, 14, 30, 45, 100, time.UTC)
230+ oneDayWarning := now.AddDate(0, 0, 1)
231+ return &db.FeatureFlag{ID: "2", UserID: userID, Name: "plus", ExpiresAt: &oneDayWarning, CreatedAt: &now}, nil
232+}
233+
234+func NewAuthDb(logger *slog.Logger) *AuthDb {
235+ sb := stub.NewStubDB(logger)
236+ return &AuthDb{
237+ StubDB: sb,
238+ }
239+}
240+
241+func mkpath(path string) string {
242+ return fmt.Sprintf("https://auth.pico.test%s", path)
243+}
244+
245+func setupTest() *shared.ApiConfig {
246+ logger := shared.CreateLogger("auth")
247+ cfg := &shared.ConfigSite{
248+ Issuer: "auth.pico.test",
249+ Domain: "http://0.0.0.0:3000",
250+ Port: "3000",
251+ Secret: "",
252+ SecretWebhook: "my-secret",
253+ }
254+ cfg.Logger = logger
255+ db := NewAuthDb(cfg.Logger)
256+ apiConfig := &shared.ApiConfig{
257+ Cfg: cfg,
258+ Dbpool: db,
259+ }
260+
261+ return apiConfig
262+}
263+
264+func testResponse(t *testing.T, responseRecorder *httptest.ResponseRecorder, status int, contentType string) {
265+ if responseRecorder.Code != status {
266+ t.Errorf("Want status '%d', got '%d'", status, responseRecorder.Code)
267+ return
268+ }
269+
270+ ct := responseRecorder.Header().Get("content-type")
271+ if ct != contentType {
272+ t.Errorf("Want content type '%s', got '%s'", contentType, ct)
273+ return
274+ }
275+
276+ body := strings.TrimSpace(responseRecorder.Body.String())
277+ snaps.MatchSnapshot(t, body)
278+}
279+
280+func bail(err error) {
281+ if err != nil {
282+ panic(bail)
283+ }
284+}
M
go.mod
+12,
-2
1@@ -28,6 +28,7 @@ require (
2 github.com/charmbracelet/promwish v0.7.0
3 github.com/charmbracelet/ssh v0.0.0-20240725163421-eb71b85b27aa
4 github.com/charmbracelet/wish v1.4.3
5+ github.com/gkampitakis/go-snaps v0.5.7
6 github.com/google/go-cmp v0.6.0
7 github.com/google/uuid v1.6.0
8 github.com/gorilla/feeds v1.2.0
9@@ -41,7 +42,7 @@ require (
10 github.com/picosh/pubsub v0.0.0-20241114191831-ec8f16c0eb88
11 github.com/picosh/send v0.0.0-20241107150437-0febb0049b4f
12 github.com/picosh/tunkit v0.0.0-20240905223921-532404cef9d9
13- github.com/picosh/utils v0.0.0-20241118014950-9515a3e4c5f9
14+ github.com/picosh/utils v0.0.0-20241120033529-8ca070c09bf4
15 github.com/sabhiram/go-gitignore v0.0.0-20210923224102-525f6e181f06
16 github.com/sendgrid/sendgrid-go v3.16.0+incompatible
17 github.com/simplesurance/go-ip-anonymizer v0.0.0-20200429124537-35a880f8e87d
18@@ -110,6 +111,8 @@ require (
19 github.com/forPelevin/gomoji v1.2.0 // indirect
20 github.com/gdamore/encoding v1.0.1 // indirect
21 github.com/gdamore/tcell/v2 v2.7.4 // indirect
22+ github.com/gkampitakis/ciinfo v0.3.0 // indirect
23+ github.com/gkampitakis/go-diff v1.3.2 // indirect
24 github.com/go-errors/errors v1.5.1 // indirect
25 github.com/go-ini/ini v1.67.0 // indirect
26 github.com/go-logfmt/logfmt v0.6.0 // indirect
27@@ -126,8 +129,11 @@ require (
28 github.com/klauspost/compress v1.17.11 // indirect
29 github.com/klauspost/cpuid/v2 v2.2.9 // indirect
30 github.com/kr/fs v0.1.0 // indirect
31+ github.com/kr/pretty v0.3.1 // indirect
32+ github.com/kr/text v0.2.0 // indirect
33 github.com/lucasb-eyer/go-colorful v1.2.0 // indirect
34 github.com/lufia/plan9stats v0.0.0-20240909124753-873cd0166683 // indirect
35+ github.com/maruel/natural v1.1.1 // indirect
36 github.com/mattn/go-isatty v0.0.20 // indirect
37 github.com/mattn/go-localereader v0.0.1 // indirect
38 github.com/mattn/go-runewidth v0.0.16 // indirect
39@@ -154,13 +160,17 @@ require (
40 github.com/prometheus/prom2json v1.4.1 // indirect
41 github.com/prometheus/prometheus v0.300.0 // indirect
42 github.com/rivo/uniseg v0.4.7 // indirect
43- github.com/rogpeppe/go-internal v1.11.0 // indirect
44+ github.com/rogpeppe/go-internal v1.12.0 // indirect
45 github.com/rs/xid v1.6.0 // indirect
46 github.com/safchain/ethtool v0.4.1 // indirect
47 github.com/secure-io/sio-go v0.3.1 // indirect
48 github.com/sendgrid/rest v2.6.9+incompatible // indirect
49 github.com/shirou/gopsutil/v3 v3.24.5 // indirect
50 github.com/shoenig/go-m1cpu v0.1.6 // indirect
51+ github.com/tidwall/gjson v1.17.0 // indirect
52+ github.com/tidwall/match v1.1.1 // indirect
53+ github.com/tidwall/pretty v1.2.1 // indirect
54+ github.com/tidwall/sjson v1.2.5 // indirect
55 github.com/tinylib/msgp v1.2.4 // indirect
56 github.com/tklauser/go-sysconf v0.3.14 // indirect
57 github.com/tklauser/numcpus v0.9.0 // indirect
M
go.sum
+25,
-4
1@@ -103,6 +103,7 @@ github.com/charmbracelet/x/term v0.2.1 h1:AQeHeLZ1OqSXhrAWpYUtZyX1T3zVxfpZuEQMIQ
2 github.com/charmbracelet/x/term v0.2.1/go.mod h1:oQ4enTYFV7QN4m0i9mzHrViD7TQKvNEEkHUMCmsxdUg=
3 github.com/charmbracelet/x/termios v0.1.0 h1:y4rjAHeFksBAfGbkRDmVinMg7x7DELIGAFbdNvxg97k=
4 github.com/charmbracelet/x/termios v0.1.0/go.mod h1:H/EVv/KRnrYjz+fCYa9bsKdqF3S8ouDK0AZEbG7r+/U=
5+github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E=
6 github.com/creack/pty v1.1.24 h1:bJrF4RRfyJnbTJqzRLHzcGaZK1NeM5kTC9jGgovnR1s=
7 github.com/creack/pty v1.1.24/go.mod h1:08sCNb52WyoAwi2QDyzUCTgcvVFhUzewun7wtTfvcwE=
8 github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
9@@ -153,6 +154,12 @@ github.com/forPelevin/gomoji v1.2.0/go.mod h1:8+Z3KNGkdslmeGZBC3tCrwMrcPy5GRzAD+
10 github.com/gdamore/encoding v1.0.0/go.mod h1:alR0ol34c49FCSBLjhosxzcPHQbf2trDkoo5dl+VrEg=
11 github.com/gdamore/encoding v1.0.1 h1:YzKZckdBL6jVt2Gc+5p82qhrGiqMdG/eNs6Wy0u3Uhw=
12 github.com/gdamore/encoding v1.0.1/go.mod h1:0Z0cMFinngz9kS1QfMjCP8TY7em3bZYeeklsSDPivEo=
13+github.com/gkampitakis/ciinfo v0.3.0 h1:gWZlOC2+RYYttL0hBqcoQhM7h1qNkVqvRCV1fOvpAv8=
14+github.com/gkampitakis/ciinfo v0.3.0/go.mod h1:1NIwaOcFChN4fa/B0hEBdAb6npDlFL8Bwx4dfRLRqAo=
15+github.com/gkampitakis/go-diff v1.3.2 h1:Qyn0J9XJSDTgnsgHRdz9Zp24RaJeKMUHg2+PDZZdC4M=
16+github.com/gkampitakis/go-diff v1.3.2/go.mod h1:LLgOrpqleQe26cte8s36HTWcTmMEur6OPYerdAAS9tk=
17+github.com/gkampitakis/go-snaps v0.5.7 h1:uVGjHR4t4pPHU944udMx7VKHpwepZXmvDMF+yDmI0rg=
18+github.com/gkampitakis/go-snaps v0.5.7/go.mod h1:ZABkO14uCuVxBHAXAfKG+bqNz+aa1bGPAg8jkI0Nk8Y=
19 github.com/go-errors/errors v1.0.1/go.mod h1:f4zRHt4oKfwPJE5k8C9vpYG+aDHdBFUsgrm6/TyX73Q=
20 github.com/go-errors/errors v1.0.2/go.mod h1:psDX2osz5VnTOnFWbDeWwS7yejl+uV3FEWEp4lssFEs=
21 github.com/go-errors/errors v1.1.1/go.mod h1:psDX2osz5VnTOnFWbDeWwS7yejl+uV3FEWEp4lssFEs=
22@@ -216,6 +223,8 @@ github.com/lucasb-eyer/go-colorful v1.2.0 h1:1nnpGOrhyZZuNyfu1QjKiUICQ74+3FNCN69
23 github.com/lucasb-eyer/go-colorful v1.2.0/go.mod h1:R4dSotOR9KMtayYi1e77YzuveK+i7ruzyGqttikkLy0=
24 github.com/lufia/plan9stats v0.0.0-20240909124753-873cd0166683 h1:7UMa6KCCMjZEMDtTVdcGu0B1GmmC7QJKiCCjyTAWQy0=
25 github.com/lufia/plan9stats v0.0.0-20240909124753-873cd0166683/go.mod h1:ilwx/Dta8jXAgpFYFvSWEMwxmbWXyiUHkd5FwyKhb5k=
26+github.com/maruel/natural v1.1.1 h1:Hja7XhhmvEFhcByqDoHz9QZbkWey+COd9xWfCfn1ioo=
27+github.com/maruel/natural v1.1.1/go.mod h1:v+Rfd79xlw1AgVBjbO0BEQmptqb5HvL/k9GRHB7ZKEg=
28 github.com/matryer/is v1.4.1 h1:55ehd8zaGABKLXQUe2awZ99BD/PTc2ls+KV/dXphgEQ=
29 github.com/matryer/is v1.4.1/go.mod h1:8I/i5uYgLzgsgEloJE1U6xx5HkBQpAZvepWuujKwMRU=
30 github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
31@@ -277,8 +286,9 @@ github.com/picosh/senpai v0.0.0-20240503200611-af89e73973b0 h1:pBRIbiCj7K6rGELij
32 github.com/picosh/senpai v0.0.0-20240503200611-af89e73973b0/go.mod h1:QaBDtybFC5gz7EG/9c3bgzuyW7W5W2rYLFZxWNuWk3Q=
33 github.com/picosh/tunkit v0.0.0-20240905223921-532404cef9d9 h1:g5oZmnDFr11HarA8IAXcc4o9PBlolSM59QIATCSoato=
34 github.com/picosh/tunkit v0.0.0-20240905223921-532404cef9d9/go.mod h1:UrDH/VCIc1wg/L6iY2zSYt4TiGw+25GsKSnkVkU40Dw=
35-github.com/picosh/utils v0.0.0-20241118014950-9515a3e4c5f9 h1:utr7lmPRDBubE52IzLkzzjQFmOu8gIKD+JTLRpxqgnI=
36-github.com/picosh/utils v0.0.0-20241118014950-9515a3e4c5f9/go.mod h1:HogYEyJ43IGXrOa3D/kjM1pkzNAyh+pejRyv8Eo//pk=
37+github.com/picosh/utils v0.0.0-20241120033529-8ca070c09bf4 h1:pwbgY9shKyMlpYvpUalTyV0ZVd5paj8pSEYT4OPOYTk=
38+github.com/picosh/utils v0.0.0-20241120033529-8ca070c09bf4/go.mod h1:HogYEyJ43IGXrOa3D/kjM1pkzNAyh+pejRyv8Eo//pk=
39+github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA=
40 github.com/pkg/sftp v1.13.7 h1:uv+I3nNJvlKZIQGSr8JVQLNHFU9YhhNpvC14Y6KgmSM=
41 github.com/pkg/sftp v1.13.7/go.mod h1:KMKI0t3T6hfA+lTR/ssZdunHo+uwq7ghoN09/FSu3DY=
42 github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
43@@ -303,9 +313,10 @@ github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJ
44 github.com/rivo/uniseg v0.4.3/go.mod h1:FN3SvrM+Zdj16jyLfmOkMNblXMcoc8DfTHruCPUcx88=
45 github.com/rivo/uniseg v0.4.7 h1:WUdvkW8uEhrYfLC4ZzdpI2ztxP1I582+49Oc5Mq64VQ=
46 github.com/rivo/uniseg v0.4.7/go.mod h1:FN3SvrM+Zdj16jyLfmOkMNblXMcoc8DfTHruCPUcx88=
47+github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs=
48 github.com/rogpeppe/go-internal v1.10.0/go.mod h1:UQnix2H7Ngw/k4C5ijL5+65zddjncjaFoBhdsK/akog=
49-github.com/rogpeppe/go-internal v1.11.0 h1:cWPaGQEPrBb5/AsnsZesgZZ9yb1OQ+GOISoDNXVBh4M=
50-github.com/rogpeppe/go-internal v1.11.0/go.mod h1:ddIwULY96R17DhadqLgMfk9H9tvdUzkipdSkR5nkCZA=
51+github.com/rogpeppe/go-internal v1.12.0 h1:exVL4IDcn6na9z1rAb56Vxr+CgyK3nn3O+epU5NdKM8=
52+github.com/rogpeppe/go-internal v1.12.0/go.mod h1:E+RYuTGaKKdloAfM02xzb0FW3Paa99yedzYV+kq4uf4=
53 github.com/rs/xid v1.6.0 h1:fV591PaemRlL6JfRxGDEPl69wICngIQ3shQtzfy2gxU=
54 github.com/rs/xid v1.6.0/go.mod h1:7XoLgs4eV+QndskICGsho+ADou8ySMSjJKDIan90Nz0=
55 github.com/sabhiram/go-gitignore v0.0.0-20210923224102-525f6e181f06 h1:OkMGxebDjyw0ULyrTYWeN0UNCCkmCWfjPnIA2W6oviI=
56@@ -336,6 +347,16 @@ github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/
57 github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
58 github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
59 github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
60+github.com/tidwall/gjson v1.14.2/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
61+github.com/tidwall/gjson v1.17.0 h1:/Jocvlh98kcTfpN2+JzGQWQcqrPQwDrVEMApx/M5ZwM=
62+github.com/tidwall/gjson v1.17.0/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
63+github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA=
64+github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM=
65+github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU=
66+github.com/tidwall/pretty v1.2.1 h1:qjsOFOWWQl+N3RsoF5/ssm1pHmJJwhjlSbZ51I6wMl4=
67+github.com/tidwall/pretty v1.2.1/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU=
68+github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY=
69+github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28=
70 github.com/tinylib/msgp v1.2.4 h1:yLFeUGostXXSGW5vxfT5dXG/qzkn4schv2I7at5+hVU=
71 github.com/tinylib/msgp v1.2.4/go.mod h1:ykjzy2wzgrlvpDCRc4LA8UXy6D8bzMSuAF3WD57Gok0=
72 github.com/tklauser/go-sysconf v0.3.14 h1:g5vzr9iPFFz24v2KZXs/pvpvh8/V9Fw6vQK5ZZb78yU=
+3,
-0
1@@ -83,6 +83,9 @@ func NewWebRouter(cfg *shared.ConfigSite, logger *slog.Logger, dbpool db.DB, st
2 }
3
4 func (web *WebRouter) initRouters() {
5+ // ensure legacy router is disabled
6+ // GODEBUG=httpmuxgo121=0
7+
8 // root domain
9 rootRouter := http.NewServeMux()
10 rootRouter.HandleFunc("GET /check", web.checkHandler)
+258,
-1
1@@ -1,12 +1,28 @@
2 package pipe
3
4 import (
5+ "bufio"
6+ "context"
7+ "errors"
8 "fmt"
9+ "io"
10 "net/http"
11+ "net/url"
12 "os"
13+ "regexp"
14+ "strings"
15+ "sync"
16+ "time"
17
18+ "github.com/google/uuid"
19 "github.com/picosh/pico/db/postgres"
20 "github.com/picosh/pico/shared"
21+ "github.com/picosh/utils/pipe"
22+)
23+
24+var (
25+ cleanRegex = regexp.MustCompile(`[^0-9a-zA-Z,/]`)
26+ sshClient *pipe.Client
27 )
28
29 func serveFile(file string, contentType string) http.HandlerFunc {
30@@ -24,7 +40,7 @@ func serveFile(file string, contentType string) http.HandlerFunc {
31 _, err = w.Write(contents)
32 if err != nil {
33 logger.Error("could not write static file", "err", err.Error())
34- http.Error(w, "server error", 500)
35+ http.Error(w, "server error", http.StatusInternalServerError)
36 }
37 }
38 }
39@@ -44,12 +60,228 @@ func createStaticRoutes() []shared.Route {
40 }
41 }
42
43+type writeFlusher struct {
44+ responseWriter http.ResponseWriter
45+ controller *http.ResponseController
46+}
47+
48+func (w writeFlusher) Write(p []byte) (n int, err error) {
49+ n, err = w.responseWriter.Write(p)
50+ if err == nil {
51+ err = w.controller.Flush()
52+ }
53+ return
54+}
55+
56+var _ io.Writer = writeFlusher{}
57+
58+func handleSub(pubsub bool) http.HandlerFunc {
59+ return func(w http.ResponseWriter, r *http.Request) {
60+ logger := shared.GetLogger(r)
61+
62+ clientInfo := shared.NewPicoPipeClient()
63+ topic, _ := url.PathUnescape(shared.GetField(r, 0))
64+
65+ topic = cleanRegex.ReplaceAllString(topic, "")
66+
67+ logger.Info("sub", "topic", topic, "info", clientInfo, "pubsub", pubsub)
68+
69+ params := "-p"
70+ if r.URL.Query().Get("persist") == "true" {
71+ params += " -k"
72+ }
73+
74+ if accessList := r.URL.Query().Get("access"); accessList != "" {
75+ logger.Info("adding access list", "topic", topic, "info", clientInfo, "access", accessList)
76+ cleanList := cleanRegex.ReplaceAllString(accessList, "")
77+ params += fmt.Sprintf(" -a=%s", cleanList)
78+ }
79+
80+ id := uuid.NewString()
81+
82+ p, err := sshClient.AddSession(id, fmt.Sprintf("sub %s %s", params, topic), 0, -1, -1)
83+ if err != nil {
84+ logger.Error("sub error", "topic", topic, "info", clientInfo, "err", err.Error())
85+ http.Error(w, "server error", http.StatusInternalServerError)
86+ return
87+ }
88+
89+ go func() {
90+ <-r.Context().Done()
91+ err := sshClient.RemoveSession(id)
92+ if err != nil {
93+ logger.Error("sub remove error", "topic", topic, "info", clientInfo, "err", err.Error())
94+ }
95+ }()
96+
97+ if mime := r.URL.Query().Get("mime"); mime != "" {
98+ w.Header().Add("Content-Type", r.URL.Query().Get("mime"))
99+ }
100+
101+ w.WriteHeader(http.StatusOK)
102+
103+ _, err = io.Copy(writeFlusher{w, http.NewResponseController(w)}, p)
104+ if err != nil {
105+ logger.Error("sub copy error", "topic", topic, "info", clientInfo, "err", err.Error())
106+ return
107+ }
108+ }
109+}
110+
111+func handlePub(pubsub bool) http.HandlerFunc {
112+ return func(w http.ResponseWriter, r *http.Request) {
113+ logger := shared.GetLogger(r)
114+
115+ clientInfo := shared.NewPicoPipeClient()
116+ topic, _ := url.PathUnescape(shared.GetField(r, 0))
117+
118+ topic = cleanRegex.ReplaceAllString(topic, "")
119+
120+ logger.Info("pub", "topic", topic, "info", clientInfo)
121+
122+ params := "-p"
123+ if pubsub {
124+ params += " -b=false"
125+ }
126+
127+ if accessList := r.URL.Query().Get("access"); accessList != "" {
128+ logger.Info("adding access list", "topic", topic, "info", clientInfo, "access", accessList)
129+ cleanList := cleanRegex.ReplaceAllString(accessList, "")
130+ params += fmt.Sprintf(" -a=%s", cleanList)
131+ }
132+
133+ var wg sync.WaitGroup
134+
135+ reader := bufio.NewReaderSize(r.Body, 1)
136+
137+ first := make([]byte, 1)
138+
139+ nFirst, err := reader.Read(first)
140+ if err != nil && !errors.Is(err, io.EOF) {
141+ logger.Error("pub peek error", "topic", topic, "info", clientInfo, "err", err.Error())
142+ http.Error(w, "server error", http.StatusInternalServerError)
143+ return
144+ }
145+
146+ if nFirst == 0 {
147+ params += " -e"
148+ }
149+
150+ id := uuid.NewString()
151+
152+ p, err := sshClient.AddSession(id, fmt.Sprintf("pub %s %s", params, topic), 0, -1, -1)
153+ if err != nil {
154+ logger.Error("pub error", "topic", topic, "info", clientInfo, "err", err.Error())
155+ http.Error(w, "server error", http.StatusInternalServerError)
156+ return
157+ }
158+
159+ go func() {
160+ <-r.Context().Done()
161+ err := sshClient.RemoveSession(id)
162+ if err != nil {
163+ logger.Error("pub remove error", "topic", topic, "info", clientInfo, "err", err.Error())
164+ }
165+ }()
166+
167+ var scanErr error
168+ scanStatus := http.StatusInternalServerError
169+
170+ wg.Add(1)
171+
172+ go func() {
173+ defer wg.Done()
174+
175+ s := bufio.NewScanner(p)
176+
177+ for s.Scan() {
178+ if s.Text() == "sending msg ..." {
179+ time.Sleep(10 * time.Millisecond)
180+ break
181+ }
182+
183+ if strings.HasPrefix(s.Text(), " ssh ") {
184+ f := strings.Fields(s.Text())
185+ if len(f) > 1 && f[len(f)-1] != topic {
186+ scanErr = fmt.Errorf("pub is not same as used, expected `%s` and received `%s`", topic, f[len(f)-1])
187+ scanStatus = http.StatusUnauthorized
188+ return
189+ }
190+ }
191+ }
192+
193+ if err := s.Err(); err != nil {
194+ scanErr = err
195+ return
196+ }
197+ }()
198+
199+ wg.Wait()
200+
201+ if scanErr != nil {
202+ logger.Error("pub scan error", "topic", topic, "info", clientInfo, "err", scanErr.Error())
203+
204+ msg := "server error"
205+ if scanStatus == http.StatusUnauthorized {
206+ msg = "access denied"
207+ }
208+
209+ http.Error(w, msg, scanStatus)
210+ return
211+ }
212+
213+ outer:
214+ for {
215+ select {
216+ case <-r.Context().Done():
217+ break outer
218+ default:
219+ n, err := p.Write(first)
220+ if err != nil {
221+ logger.Error("pub write error", "topic", topic, "info", clientInfo, "err", err.Error())
222+ http.Error(w, "server error", http.StatusInternalServerError)
223+ return
224+ }
225+
226+ if n > 0 {
227+ break outer
228+ }
229+
230+ time.Sleep(10 * time.Millisecond)
231+ }
232+ }
233+
234+ _, err = io.Copy(p, reader)
235+ if err != nil {
236+ logger.Error("pub copy error", "topic", topic, "info", clientInfo, "err", err.Error())
237+ http.Error(w, "server error", http.StatusInternalServerError)
238+ return
239+ }
240+
241+ w.WriteHeader(http.StatusOK)
242+
243+ time.Sleep(10 * time.Millisecond)
244+ }
245+}
246+
247 func createMainRoutes(staticRoutes []shared.Route) []shared.Route {
248 routes := []shared.Route{
249 shared.NewRoute("GET", "/", shared.CreatePageHandler("html/marketing.page.tmpl")),
250 shared.NewRoute("GET", "/check", shared.CheckHandler),
251 }
252
253+ pipeRoutes := []shared.Route{
254+ shared.NewRoute("GET", "/topic/(.+)", handleSub(false)),
255+ shared.NewRoute("POST", "/topic/(.+)", handlePub(false)),
256+ shared.NewRoute("GET", "/pubsub/(.+)", handleSub(true)),
257+ shared.NewRoute("POST", "/pubsub/(.+)", handlePub(true)),
258+ }
259+
260+ for _, route := range pipeRoutes {
261+ route.CorsEnabled = true
262+ routes = append(routes, route)
263+ }
264+
265 routes = append(
266 routes,
267 staticRoutes...,
268@@ -73,6 +305,31 @@ func StartApiServer() {
269 mainRoutes := createMainRoutes(staticRoutes)
270 subdomainRoutes := staticRoutes
271
272+ info := shared.NewPicoPipeClient()
273+
274+ client, err := pipe.NewClient(context.Background(), logger.With("info", info), info)
275+ if err != nil {
276+ panic(err)
277+ }
278+
279+ sshClient = client
280+
281+ pingSession, err := sshClient.AddSession("ping", "pub -b=false -c ping", 0, -1, -1)
282+ if err != nil {
283+ panic(err)
284+ }
285+
286+ go func() {
287+ for {
288+ _, err := pingSession.Write([]byte(fmt.Sprintf("%s: pipe-web ping\n", time.Now().UTC().Format(time.RFC3339))))
289+ if err != nil {
290+ logger.Error("pipe ping error", "err", err.Error())
291+ }
292+
293+ time.Sleep(5 * time.Second)
294+ }
295+ }()
296+
297 apiConfig := &shared.ApiConfig{
298 Cfg: cfg,
299 Dbpool: db,
+141,
-52
1@@ -172,6 +172,27 @@ func WishMiddleware(handler *CliHandler) wish.Middleware {
2 }
3 }
4
5+ pipeCtx, cancel := context.WithCancel(ctx)
6+
7+ go func() {
8+ defer cancel()
9+
10+ for {
11+ select {
12+ case <-pipeCtx.Done():
13+ return
14+ default:
15+ _, err := sesh.SendRequest("ping@pico.sh", false, nil)
16+ if err != nil {
17+ logger.Error("error sending ping", "err", err)
18+ return
19+ }
20+
21+ time.Sleep(5 * time.Second)
22+ }
23+ }
24+ }()
25+
26 cmd := strings.TrimSpace(args[0])
27 if cmd == "help" {
28 wish.Println(sesh, helpStr(toSshCmd(handler.Cfg)))
29@@ -186,6 +207,9 @@ func WishMiddleware(handler *CliHandler) wish.Middleware {
30 topicFilter := fmt.Sprintf("%s/", userName)
31 if isAdmin {
32 topicFilter = ""
33+ if len(args) > 1 {
34+ topicFilter = args[1]
35+ }
36 }
37
38 var channels []*psub.Channel
39@@ -212,7 +236,7 @@ func WishMiddleware(handler *CliHandler) wish.Middleware {
40 for _, channel := range channels {
41 extraData := ""
42
43- if accessList, ok := handler.Access.Load(channel.Topic); ok {
44+ if accessList, ok := handler.Access.Load(channel.Topic); ok && len(accessList) > 0 {
45 extraData += fmt.Sprintf(" (Access List: %s)", strings.Join(accessList, ", "))
46 }
47
48@@ -240,7 +264,7 @@ func WishMiddleware(handler *CliHandler) wish.Middleware {
49 for waitingChannel, channelPubs := range waitingChannels {
50 extraData := ""
51
52- if accessList, ok := handler.Access.Load(waitingChannel); ok {
53+ if accessList, ok := handler.Access.Load(waitingChannel); ok && len(accessList) > 0 {
54 extraData += fmt.Sprintf(" (Access List: %s)", strings.Join(accessList, ", "))
55 }
56
57@@ -284,6 +308,8 @@ func WishMiddleware(handler *CliHandler) wish.Middleware {
58 public := pubCmd.Bool("p", false, "Publish message to public topic")
59 block := pubCmd.Bool("b", true, "Block writes until a subscriber is available")
60 timeout := pubCmd.Duration("t", 30*24*time.Hour, "Timeout as a Go duration to block for a subscriber to be available. Valid time units are 'ns', 'us' (or 'µs'), 'ms', 's', 'm', 'h'. Default is 30 days.")
61+ clean := pubCmd.Bool("c", false, "Don't send status messages")
62+
63 if !flagCheck(pubCmd, topic, cmdArgs) {
64 return
65 }
66@@ -301,6 +327,7 @@ func WishMiddleware(handler *CliHandler) wish.Middleware {
67 "timeout", *timeout,
68 "topic", topic,
69 "access", *access,
70+ "clean", *clean,
71 )
72
73 var accessList []string
74@@ -331,35 +358,43 @@ func WishMiddleware(handler *CliHandler) wish.Middleware {
75 if *public {
76 name = toPublicTopic(topic)
77 msgFlag = "-p "
78+ withoutUser = name
79 } else {
80 withoutUser = topic
81 }
82 }
83
84- if len(accessList) > 0 && !*public {
85- _, loaded := handler.Access.LoadOrStore(name, accessList)
86- if !loaded {
87- defer func() {
88- handler.Access.Delete(name)
89- }()
90- }
91- }
92+ var accessListCreator bool
93
94- wish.Printf(
95- sesh,
96- "subscribe to this channel:\n ssh %s sub %s%s\n",
97- toSshCmd(handler.Cfg),
98- msgFlag,
99- topic,
100- )
101+ _, loaded := handler.Access.LoadOrStore(name, accessList)
102+ if !loaded {
103+ defer func() {
104+ handler.Access.Delete(name)
105+ }()
106+
107+ accessListCreator = true
108+ }
109
110- if accessList, ok := handler.Access.Load(withoutUser); !isAdmin && !*public && ok {
111- if checkAccess(accessList, userName, sesh) {
112+ if accessList, ok := handler.Access.Load(withoutUser); ok && len(accessList) > 0 && !isAdmin {
113+ if checkAccess(accessList, userName, sesh) || accessListCreator {
114 name = withoutUser
115+ } else if !*public {
116+ name = toTopic(userName, withoutUser)
117+ } else {
118+ topic = uuid.NewString()
119+ name = toPublicTopic(topic)
120 }
121 }
122
123- var pubCtx context.Context = ctx
124+ if !*clean {
125+ wish.Printf(
126+ sesh,
127+ "subscribe to this channel:\n ssh %s sub %s%s\n",
128+ toSshCmd(handler.Cfg),
129+ msgFlag,
130+ topic,
131+ )
132+ }
133
134 if *block {
135 count := 0
136@@ -383,18 +418,18 @@ func WishMiddleware(handler *CliHandler) wish.Middleware {
137 if tt > 0 {
138 termMsg += " " + tt.String()
139 }
140- wish.Println(sesh, termMsg)
141
142- downCtx, cancelFunc := context.WithCancel(ctx)
143- pubCtx = downCtx
144+ if !*clean {
145+ wish.Println(sesh, termMsg)
146+ }
147
148 ready := make(chan struct{})
149
150 go func() {
151 for {
152 select {
153- case <-ctx.Done():
154- cancelFunc()
155+ case <-pipeCtx.Done():
156+ cancel()
157 return
158 case <-time.After(1 * time.Millisecond):
159 count := 0
160@@ -419,10 +454,20 @@ func WishMiddleware(handler *CliHandler) wish.Middleware {
161
162 select {
163 case <-ready:
164- case <-ctx.Done():
165+ case <-pipeCtx.Done():
166 case <-time.After(tt):
167- cancelFunc()
168- wish.Fatalln(sesh, "timeout reached, exiting ...")
169+ cancel()
170+
171+ if !*clean {
172+ wish.Fatalln(sesh, "timeout reached, exiting ...")
173+ } else {
174+ err = sesh.Exit(1)
175+ if err != nil {
176+ logger.Error("error exiting session", "err", err)
177+ }
178+
179+ sesh.Close()
180+ }
181 }
182
183 newWaiters, _ := handler.Waiters.LoadOrStore(name, nil)
184@@ -445,10 +490,12 @@ func WishMiddleware(handler *CliHandler) wish.Middleware {
185 }
186 }
187
188- wish.Println(sesh, "sending msg ...")
189+ if !*clean {
190+ wish.Println(sesh, "sending msg ...")
191+ }
192
193 err = pubsub.Pub(
194- pubCtx,
195+ pipeCtx,
196 clientID,
197 rw,
198 []*psub.Channel{
199@@ -457,14 +504,20 @@ func WishMiddleware(handler *CliHandler) wish.Middleware {
200 *block,
201 )
202
203- wish.Println(sesh, "msg sent!")
204- if err != nil {
205+ if !*clean {
206+ wish.Println(sesh, "msg sent!")
207+ }
208+
209+ if err != nil && !*clean {
210 wish.Errorln(sesh, err)
211 }
212 } else if cmd == "sub" {
213 subCmd := flagSet("sub", sesh)
214+ access := subCmd.String("a", "", "Comma separated list of pico usernames or ssh-key fingerprints to allow access to a topic")
215 public := subCmd.Bool("p", false, "Subscribe to a public topic")
216 keepAlive := subCmd.Bool("k", false, "Keep the subscription alive even after the publisher has died")
217+ clean := subCmd.Bool("c", false, "Don't send status messages")
218+
219 if !flagCheck(subCmd, topic, cmdArgs) {
220 return
221 }
222@@ -479,8 +532,16 @@ func WishMiddleware(handler *CliHandler) wish.Middleware {
223 "public", *public,
224 "keepAlive", *keepAlive,
225 "topic", topic,
226+ "clean", *clean,
227+ "access", *access,
228 )
229
230+ var accessList []string
231+
232+ if *access != "" {
233+ accessList = parseArgList(*access)
234+ }
235+
236 var withoutUser string
237 var name string
238
239@@ -490,19 +551,36 @@ func WishMiddleware(handler *CliHandler) wish.Middleware {
240 name = toTopic(userName, topic)
241 if *public {
242 name = toPublicTopic(topic)
243+ withoutUser = name
244 } else {
245 withoutUser = topic
246 }
247 }
248
249- if accessList, ok := handler.Access.Load(withoutUser); !isAdmin && !*public && ok {
250- if checkAccess(accessList, userName, sesh) {
251+ var accessListCreator bool
252+
253+ _, loaded := handler.Access.LoadOrStore(name, accessList)
254+ if !loaded {
255+ defer func() {
256+ handler.Access.Delete(name)
257+ }()
258+
259+ accessListCreator = true
260+ }
261+
262+ if accessList, ok := handler.Access.Load(withoutUser); ok && len(accessList) > 0 && !isAdmin {
263+ if checkAccess(accessList, userName, sesh) || accessListCreator {
264 name = withoutUser
265+ } else if !*public {
266+ name = toTopic(userName, withoutUser)
267+ } else {
268+ wish.Errorln(sesh, "access denied")
269+ return
270 }
271 }
272
273 err = pubsub.Sub(
274- ctx,
275+ pipeCtx,
276 clientID,
277 sesh,
278 []*psub.Channel{
279@@ -511,7 +589,7 @@ func WishMiddleware(handler *CliHandler) wish.Middleware {
280 *keepAlive,
281 )
282
283- if err != nil {
284+ if err != nil && !*clean {
285 wish.Errorln(sesh, err)
286 }
287 } else if cmd == "pipe" {
288@@ -519,6 +597,8 @@ func WishMiddleware(handler *CliHandler) wish.Middleware {
289 access := pipeCmd.String("a", "", "Comma separated list of pico usernames or ssh-key fingerprints to allow access to a topic")
290 public := pipeCmd.Bool("p", false, "Pipe to a public topic")
291 replay := pipeCmd.Bool("r", false, "Replay messages to the client that sent it")
292+ clean := pipeCmd.Bool("c", false, "Don't send status messages")
293+
294 if !flagCheck(pipeCmd, topic, cmdArgs) {
295 return
296 }
297@@ -534,6 +614,7 @@ func WishMiddleware(handler *CliHandler) wish.Middleware {
298 "replay", *replay,
299 "topic", topic,
300 "access", *access,
301+ "clean", *clean,
302 )
303
304 var accessList []string
305@@ -558,21 +639,35 @@ func WishMiddleware(handler *CliHandler) wish.Middleware {
306 if *public {
307 name = toPublicTopic(topic)
308 flagMsg = "-p "
309+ withoutUser = name
310 } else {
311 withoutUser = topic
312 }
313 }
314
315- if len(accessList) > 0 && !*public {
316- _, loaded := handler.Access.LoadOrStore(name, accessList)
317- if !loaded {
318- defer func() {
319- handler.Access.Delete(name)
320- }()
321+ var accessListCreator bool
322+
323+ _, loaded := handler.Access.LoadOrStore(name, accessList)
324+ if !loaded {
325+ defer func() {
326+ handler.Access.Delete(name)
327+ }()
328+
329+ accessListCreator = true
330+ }
331+
332+ if accessList, ok := handler.Access.Load(withoutUser); ok && len(accessList) > 0 && !isAdmin {
333+ if checkAccess(accessList, userName, sesh) || accessListCreator {
334+ name = withoutUser
335+ } else if !*public {
336+ name = toTopic(userName, withoutUser)
337+ } else {
338+ topic = uuid.NewString()
339+ name = toPublicTopic(topic)
340 }
341 }
342
343- if isCreator {
344+ if isCreator && !*clean {
345 wish.Printf(
346 sesh,
347 "subscribe to this topic:\n ssh %s sub %s%s\n",
348@@ -582,14 +677,8 @@ func WishMiddleware(handler *CliHandler) wish.Middleware {
349 )
350 }
351
352- if accessList, ok := handler.Access.Load(withoutUser); !isAdmin && !*public && ok {
353- if checkAccess(accessList, userName, sesh) {
354- name = withoutUser
355- }
356- }
357-
358 readErr, writeErr := pubsub.Pipe(
359- ctx,
360+ pipeCtx,
361 clientID,
362 sesh,
363 []*psub.Channel{
364@@ -598,11 +687,11 @@ func WishMiddleware(handler *CliHandler) wish.Middleware {
365 *replay,
366 )
367
368- if readErr != nil {
369+ if readErr != nil && !*clean {
370 wish.Errorln(sesh, "error reading from pipe", readErr)
371 }
372
373- if writeErr != nil {
374+ if writeErr != nil && !*clean {
375 wish.Errorln(sesh, "error writing to pipe", writeErr)
376 }
377 }
1@@ -42,6 +42,9 @@ type ConfigSite struct {
2 MinioUser string
3 MinioPass string
4 Space string
5+ Issuer string
6+ Secret string
7+ SecretWebhook string
8 AllowedExt []string
9 HiddenPosts []string
10 MaxSize uint64
1@@ -55,6 +55,19 @@ func CreatePProfRoutes(routes []Route) []Route {
2 )
3 }
4
5+func CreatePProfRoutesMux(mux *http.ServeMux) {
6+ mux.HandleFunc("GET /debug/pprof/cmdline", pprof.Cmdline)
7+ mux.HandleFunc("GET /debug/pprof/profile", pprof.Profile)
8+ mux.HandleFunc("GET /debug/pprof/symbol", pprof.Symbol)
9+ mux.HandleFunc("GET /debug/pprof/trace", pprof.Trace)
10+ mux.HandleFunc("GET /debug/pprof/(.*)", pprof.Index)
11+ mux.HandleFunc("POST /debug/pprof/cmdline", pprof.Cmdline)
12+ mux.HandleFunc("POST /debug/pprof/profile", pprof.Profile)
13+ mux.HandleFunc("POST /debug/pprof/symbol", pprof.Symbol)
14+ mux.HandleFunc("POST /debug/pprof/trace", pprof.Trace)
15+ mux.HandleFunc("POST /debug/pprof/(.*)", pprof.Index)
16+}
17+
18 type ApiConfig struct {
19 Cfg *ConfigSite
20 Dbpool db.DB
21@@ -62,6 +75,18 @@ type ApiConfig struct {
22 AnalyticsQueue chan *db.AnalyticsVisits
23 }
24
25+func (hc *ApiConfig) HasPrivilegedAccess(apiToken string) bool {
26+ user, err := hc.Dbpool.FindUserForToken(apiToken)
27+ if err != nil {
28+ return false
29+ }
30+ return hc.Dbpool.HasFeatureForUser(user.ID, "auth")
31+}
32+
33+func (hc *ApiConfig) HasPlusOrSpace(user *db.User, space string) bool {
34+ return hc.Dbpool.HasFeatureForUser(user.ID, "plus") || hc.Dbpool.HasFeatureForUser(user.ID, space)
35+}
36+
37 func (hc *ApiConfig) CreateCtx(prevCtx context.Context, subdomain string) context.Context {
38 ctx := context.WithValue(prevCtx, ctxLoggerKey{}, hc.Cfg.Logger)
39 ctx = context.WithValue(ctx, CtxSubdomainKey{}, subdomain)
40@@ -123,6 +148,10 @@ func GetSubdomainFromRequest(r *http.Request, domain, space string) string {
41 }
42
43 func findRouteConfig(r *http.Request, routes []Route, subdomainRoutes []Route, cfg *ConfigSite) ([]Route, string) {
44+ if len(subdomainRoutes) == 0 {
45+ return routes, ""
46+ }
47+
48 subdomain := GetSubdomainFromRequest(r, cfg.Domain, cfg.Space)
49 if subdomain == "" {
50 return routes, subdomain
51@@ -202,3 +231,11 @@ func GetCustomDomain(host string, space string) string {
52 func GetAnalyticsQueue(r *http.Request) chan *db.AnalyticsVisits {
53 return r.Context().Value(ctxAnalyticsQueue{}).(chan *db.AnalyticsVisits)
54 }
55+
56+func GetApiToken(r *http.Request) string {
57+ authHeader := r.Header.Get("authorization")
58+ if authHeader == "" {
59+ return ""
60+ }
61+ return strings.TrimPrefix(authHeader, "Bearer ")
62+}