Eric Bower
·
15 Nov 24
api.go
1package shared
2
3import (
4 "encoding/json"
5 "fmt"
6 "html/template"
7 "net/http"
8 "os"
9 "strings"
10
11 "github.com/charmbracelet/ssh"
12 "github.com/picosh/pico/db"
13 "github.com/picosh/utils"
14)
15
// SubdomainProps identifies a project addressed through a subdomain.
type SubdomainProps struct {
	// ProjectName is, as filled in by GetProjectFromSubdomain, the text after
	// the first "-" of the subdomain, or the whole subdomain if it has no "-".
	ProjectName string
	// Username is, as filled in by GetProjectFromSubdomain, the text before
	// the first "-" of the subdomain.
	Username string
}

// GetProjectFromSubdomain splits subdomain at its first "-" into a username
// and a project name, e.g. "alice-my-blog" -> Username "alice",
// ProjectName "my-blog". A subdomain without any "-" is used as both the
// username and the project name. The returned error is always nil.
func GetProjectFromSubdomain(subdomain string) (*SubdomainProps, error) {
	sep := strings.Index(subdomain, "-")
	if sep < 0 {
		return &SubdomainProps{Username: subdomain, ProjectName: subdomain}, nil
	}
	return &SubdomainProps{
		Username:    subdomain[:sep],
		ProjectName: subdomain[sep+1:],
	}, nil
}
32
// CorsHeaders adds permissive CORS headers to headers: any origin, the
// request headers Content-Type and Accept, and the methods GET, POST,
// OPTIONS, PUT, PATCH and DELETE. Values are appended with Add, so calling it
// twice on the same header set duplicates them.
func CorsHeaders(headers http.Header) {
	headers.Add("Access-Control-Allow-Origin", "*")

	// Vary lists the request headers the response may depend on, so shared
	// caches key on them.
	varyOn := []string{
		"Origin",
		"Access-Control-Request-Method",
		"Access-Control-Request-Headers",
	}
	for _, name := range varyOn {
		headers.Add("Vary", name)
	}

	headers.Add("Access-Control-Allow-Headers", "Content-Type, Accept")
	headers.Add("Access-Control-Allow-Methods", "GET, POST, OPTIONS, PUT, PATCH, DELETE")
}
41
// UnauthorizedHandler is an http.HandlerFunc that always replies with
// 401 Unauthorized and a short plain-text message, regardless of the request.
func UnauthorizedHandler(w http.ResponseWriter, r *http.Request) {
	const denied = "You do not have access to this site"
	status := http.StatusUnauthorized
	http.Error(w, denied, status)
}
45
// errPayload is the JSON body written by JSONError.
type errPayload struct {
	Message string `json:"message"`
}

// JSONError replies with the given HTTP status code and a JSON body of the
// form {"message": msg}, using Content-Type application/json.
func JSONError(w http.ResponseWriter, msg string, code int) {
	body := errPayload{Message: msg}

	hdr := w.Header()
	hdr.Set("Content-Type", "application/json")
	w.WriteHeader(code)

	// The status line has already been sent, so a failure while encoding or
	// writing the body cannot be reported to the client; ignore it.
	enc := json.NewEncoder(w)
	_ = enc.Encode(body)
}
55
56type UserApi struct {
57 *db.User
58 Fingerprint string `json:"fingerprint"`
59}
60
61func NewUserApi(user *db.User, pubkey ssh.PublicKey) *UserApi {
62 return &UserApi{
63 User: user,
64 Fingerprint: utils.KeyForSha256(pubkey),
65 }
66}
67
68func CheckHandler(w http.ResponseWriter, r *http.Request) {
69 dbpool := GetDB(r)
70 cfg := GetCfg(r)
71
72 if cfg.IsCustomdomains() {
73 hostDomain := r.URL.Query().Get("domain")
74 appDomain := strings.Split(cfg.Domain, ":")[0]
75
76 if !strings.Contains(hostDomain, appDomain) {
77 subdomain := GetCustomDomain(hostDomain, cfg.Space)
78 if subdomain != "" {
79 u, err := dbpool.FindUserForName(subdomain)
80 if u != nil && err == nil {
81 w.WriteHeader(http.StatusOK)
82 return
83 }
84 }
85 }
86 }
87
88 w.WriteHeader(http.StatusNotFound)
89}
90
91func GetUsernameFromRequest(r *http.Request) string {
92 subdomain := GetSubdomain(r)
93 cfg := GetCfg(r)
94
95 if !cfg.IsSubdomains() || subdomain == "" {
96 return GetField(r, 0)
97 }
98 return subdomain
99}
100
101func ServeFile(file string, contentType string) http.HandlerFunc {
102 return func(w http.ResponseWriter, r *http.Request) {
103 logger := GetLogger(r)
104 cfg := GetCfg(r)
105
106 contents, err := os.ReadFile(cfg.StaticPath(fmt.Sprintf("public/%s", file)))
107 if err != nil {
108 logger.Error(err.Error())
109 http.Error(w, "file not found", 404)
110 }
111
112 w.Header().Add("Content-Type", contentType)
113
114 _, err = w.Write(contents)
115 if err != nil {
116 logger.Error(err.Error())
117 }
118 }
119}
120
// minus returns a - b. It is registered in FuncMap as "minus" because Go
// templates have no built-in arithmetic.
func minus(a, b int) int {
	difference := a - b
	return difference
}
124
// intRange returns the integers from start to end inclusive, in ascending
// order: intRange(1, 3) == []int{1, 2, 3}. It is registered in FuncMap as
// "intRange". When end < start the result is an empty (non-nil) slice. The
// previous form computed a negative length and panicked in make.
func intRange(start, end int) []int {
	if end < start {
		return []int{}
	}
	n := end - start + 1
	result := make([]int, n)
	for i := range result {
		result[i] = start + i
	}
	return result
}
133
134var FuncMap = template.FuncMap{
135 "minus": minus,
136 "intRange": intRange,
137}
138
139func RenderTemplate(cfg *ConfigSite, templates []string) (*template.Template, error) {
140 files := make([]string, len(templates))
141 copy(files, templates)
142 files = append(
143 files,
144 cfg.StaticPath("html/footer.partial.tmpl"),
145 cfg.StaticPath("html/marketing-footer.partial.tmpl"),
146 cfg.StaticPath("html/base.layout.tmpl"),
147 )
148
149 ts, err := template.New("base").Funcs(FuncMap).ParseFiles(files...)
150 if err != nil {
151 return nil, err
152 }
153 return ts, nil
154}
155
156func CreatePageHandler(fname string) http.HandlerFunc {
157 return func(w http.ResponseWriter, r *http.Request) {
158 logger := GetLogger(r)
159 cfg := GetCfg(r)
160 ts, err := RenderTemplate(cfg, []string{cfg.StaticPath(fname)})
161
162 if err != nil {
163 logger.Error(err.Error())
164 http.Error(w, err.Error(), http.StatusInternalServerError)
165 return
166 }
167
168 data := PageData{
169 Site: *cfg.GetSiteData(),
170 }
171 err = ts.Execute(w, data)
172 if err != nil {
173 logger.Error(err.Error())
174 http.Error(w, err.Error(), http.StatusInternalServerError)
175 }
176 }
177}