Antonio Mika
·
22 Nov 24
api.go
1package pipe
2
3import (
4 "bufio"
5 "context"
6 "errors"
7 "fmt"
8 "io"
9 "net/http"
10 "net/url"
11 "os"
12 "regexp"
13 "strings"
14 "sync"
15 "time"
16
17 "github.com/google/uuid"
18 "github.com/picosh/pico/db/postgres"
19 "github.com/picosh/pico/shared"
20 "github.com/picosh/utils/pipe"
21)
22
23var (
24 cleanRegex = regexp.MustCompile(`[^0-9a-zA-Z,/]`)
25 sshClient *pipe.Client
26)
27
28func serveFile(file string, contentType string) http.HandlerFunc {
29 return func(w http.ResponseWriter, r *http.Request) {
30 logger := shared.GetLogger(r)
31 cfg := shared.GetCfg(r)
32
33 contents, err := os.ReadFile(cfg.StaticPath(fmt.Sprintf("public/%s", file)))
34 if err != nil {
35 logger.Error("could not read statis file", "err", err.Error())
36 http.Error(w, "file not found", 404)
37 }
38 w.Header().Add("Content-Type", contentType)
39
40 _, err = w.Write(contents)
41 if err != nil {
42 logger.Error("could not write static file", "err", err.Error())
43 http.Error(w, "server error", http.StatusInternalServerError)
44 }
45 }
46}
47
48func createStaticRoutes() []shared.Route {
49 return []shared.Route{
50 shared.NewRoute("GET", "/main.css", serveFile("main.css", "text/css")),
51 shared.NewRoute("GET", "/smol.css", serveFile("smol.css", "text/css")),
52 shared.NewRoute("GET", "/syntax.css", serveFile("syntax.css", "text/css")),
53 shared.NewRoute("GET", "/card.png", serveFile("card.png", "image/png")),
54 shared.NewRoute("GET", "/favicon-16x16.png", serveFile("favicon-16x16.png", "image/png")),
55 shared.NewRoute("GET", "/favicon-32x32.png", serveFile("favicon-32x32.png", "image/png")),
56 shared.NewRoute("GET", "/apple-touch-icon.png", serveFile("apple-touch-icon.png", "image/png")),
57 shared.NewRoute("GET", "/favicon.ico", serveFile("favicon.ico", "image/x-icon")),
58 shared.NewRoute("GET", "/robots.txt", serveFile("robots.txt", "text/plain")),
59 shared.NewRoute("GET", "/anim.js", serveFile("anim.js", "text/javascript")),
60 }
61}
62
// writeFlusher is an io.Writer over an http.ResponseWriter that flushes after
// every successful write. io.Copy into it therefore delivers data to the
// client as it arrives instead of letting the server buffer it.
type writeFlusher struct {
	responseWriter http.ResponseWriter
	controller     *http.ResponseController
}

// Write sends p to the response writer. Only if that succeeds does it flush
// through the response controller. It returns the byte count reported by the
// underlying Write together with the write error, or else the flush error.
func (w writeFlusher) Write(p []byte) (int, error) {
	written, err := w.responseWriter.Write(p)
	if err != nil {
		return written, err
	}
	return written, w.controller.Flush()
}

// Compile-time check that writeFlusher satisfies io.Writer.
var _ io.Writer = writeFlusher{}
77
78func handleSub(pubsub bool) http.HandlerFunc {
79 return func(w http.ResponseWriter, r *http.Request) {
80 logger := shared.GetLogger(r)
81
82 clientInfo := shared.NewPicoPipeClient()
83 topic, _ := url.PathUnescape(shared.GetField(r, 0))
84
85 topic = cleanRegex.ReplaceAllString(topic, "")
86
87 logger.Info("sub", "topic", topic, "info", clientInfo, "pubsub", pubsub)
88
89 params := "-p"
90 if r.URL.Query().Get("persist") == "true" {
91 params += " -k"
92 }
93
94 if accessList := r.URL.Query().Get("access"); accessList != "" {
95 logger.Info("adding access list", "topic", topic, "info", clientInfo, "access", accessList)
96 cleanList := cleanRegex.ReplaceAllString(accessList, "")
97 params += fmt.Sprintf(" -a=%s", cleanList)
98 }
99
100 id := uuid.NewString()
101
102 p, err := sshClient.AddSession(id, fmt.Sprintf("sub %s %s", params, topic), 0, -1, -1)
103 if err != nil {
104 logger.Error("sub error", "topic", topic, "info", clientInfo, "err", err.Error())
105 http.Error(w, "server error", http.StatusInternalServerError)
106 return
107 }
108
109 go func() {
110 <-r.Context().Done()
111 err := sshClient.RemoveSession(id)
112 if err != nil {
113 logger.Error("sub remove error", "topic", topic, "info", clientInfo, "err", err.Error())
114 }
115 }()
116
117 if mime := r.URL.Query().Get("mime"); mime != "" {
118 w.Header().Add("Content-Type", r.URL.Query().Get("mime"))
119 }
120
121 w.WriteHeader(http.StatusOK)
122
123 _, err = io.Copy(writeFlusher{w, http.NewResponseController(w)}, p)
124 if err != nil {
125 logger.Error("sub copy error", "topic", topic, "info", clientInfo, "err", err.Error())
126 return
127 }
128 }
129}
130
131func handlePub(pubsub bool) http.HandlerFunc {
132 return func(w http.ResponseWriter, r *http.Request) {
133 logger := shared.GetLogger(r)
134
135 clientInfo := shared.NewPicoPipeClient()
136 topic, _ := url.PathUnescape(shared.GetField(r, 0))
137
138 topic = cleanRegex.ReplaceAllString(topic, "")
139
140 logger.Info("pub", "topic", topic, "info", clientInfo)
141
142 params := "-p"
143 if pubsub {
144 params += " -b=false"
145 }
146
147 if accessList := r.URL.Query().Get("access"); accessList != "" {
148 logger.Info("adding access list", "topic", topic, "info", clientInfo, "access", accessList)
149 cleanList := cleanRegex.ReplaceAllString(accessList, "")
150 params += fmt.Sprintf(" -a=%s", cleanList)
151 }
152
153 var wg sync.WaitGroup
154
155 reader := bufio.NewReaderSize(r.Body, 1)
156
157 first := make([]byte, 1)
158
159 nFirst, err := reader.Read(first)
160 if err != nil && !errors.Is(err, io.EOF) {
161 logger.Error("pub peek error", "topic", topic, "info", clientInfo, "err", err.Error())
162 http.Error(w, "server error", http.StatusInternalServerError)
163 return
164 }
165
166 if nFirst == 0 {
167 params += " -e"
168 }
169
170 id := uuid.NewString()
171
172 p, err := sshClient.AddSession(id, fmt.Sprintf("pub %s %s", params, topic), 0, -1, -1)
173 if err != nil {
174 logger.Error("pub error", "topic", topic, "info", clientInfo, "err", err.Error())
175 http.Error(w, "server error", http.StatusInternalServerError)
176 return
177 }
178
179 go func() {
180 <-r.Context().Done()
181 err := sshClient.RemoveSession(id)
182 if err != nil {
183 logger.Error("pub remove error", "topic", topic, "info", clientInfo, "err", err.Error())
184 }
185 }()
186
187 var scanErr error
188 scanStatus := http.StatusInternalServerError
189
190 wg.Add(1)
191
192 go func() {
193 defer wg.Done()
194
195 s := bufio.NewScanner(p)
196
197 for s.Scan() {
198 if s.Text() == "sending msg ..." {
199 time.Sleep(10 * time.Millisecond)
200 break
201 }
202
203 if strings.HasPrefix(s.Text(), " ssh ") {
204 f := strings.Fields(s.Text())
205 if len(f) > 1 && f[len(f)-1] != topic {
206 scanErr = fmt.Errorf("pub is not same as used, expected `%s` and received `%s`", topic, f[len(f)-1])
207 scanStatus = http.StatusUnauthorized
208 return
209 }
210 }
211 }
212
213 if err := s.Err(); err != nil {
214 scanErr = err
215 return
216 }
217 }()
218
219 wg.Wait()
220
221 if scanErr != nil {
222 logger.Error("pub scan error", "topic", topic, "info", clientInfo, "err", scanErr.Error())
223
224 msg := "server error"
225 if scanStatus == http.StatusUnauthorized {
226 msg = "access denied"
227 }
228
229 http.Error(w, msg, scanStatus)
230 return
231 }
232
233 outer:
234 for {
235 select {
236 case <-r.Context().Done():
237 break outer
238 default:
239 n, err := p.Write(first)
240 if err != nil {
241 logger.Error("pub write error", "topic", topic, "info", clientInfo, "err", err.Error())
242 http.Error(w, "server error", http.StatusInternalServerError)
243 return
244 }
245
246 if n > 0 {
247 break outer
248 }
249
250 time.Sleep(10 * time.Millisecond)
251 }
252 }
253
254 _, err = io.Copy(p, reader)
255 if err != nil {
256 logger.Error("pub copy error", "topic", topic, "info", clientInfo, "err", err.Error())
257 http.Error(w, "server error", http.StatusInternalServerError)
258 return
259 }
260
261 w.WriteHeader(http.StatusOK)
262
263 time.Sleep(10 * time.Millisecond)
264 }
265}
266
267func createMainRoutes(staticRoutes []shared.Route) []shared.Route {
268 routes := []shared.Route{
269 shared.NewRoute("GET", "/", shared.CreatePageHandler("html/marketing.page.tmpl")),
270 shared.NewRoute("GET", "/check", shared.CheckHandler),
271 }
272
273 pipeRoutes := []shared.Route{
274 shared.NewRoute("GET", "/topic/(.+)", handleSub(false)),
275 shared.NewRoute("POST", "/topic/(.+)", handlePub(false)),
276 shared.NewRoute("GET", "/pubsub/(.+)", handleSub(true)),
277 shared.NewRoute("POST", "/pubsub/(.+)", handlePub(true)),
278 }
279
280 for _, route := range pipeRoutes {
281 route.CorsEnabled = true
282 routes = append(routes, route)
283 }
284
285 routes = append(
286 routes,
287 staticRoutes...,
288 )
289
290 return routes
291}
292
293func StartApiServer() {
294 cfg := NewConfigSite()
295 db := postgres.NewDB(cfg.DbURL, cfg.Logger)
296 defer db.Close()
297 logger := cfg.Logger
298
299 staticRoutes := createStaticRoutes()
300
301 if cfg.Debug {
302 staticRoutes = shared.CreatePProfRoutes(staticRoutes)
303 }
304
305 mainRoutes := createMainRoutes(staticRoutes)
306 subdomainRoutes := staticRoutes
307
308 info := shared.NewPicoPipeClient()
309
310 client, err := pipe.NewClient(context.Background(), logger.With("info", info), info)
311 if err != nil {
312 panic(err)
313 }
314
315 sshClient = client
316
317 pingSession, err := sshClient.AddSession("ping", "pub -b=false -c ping", 0, -1, -1)
318 if err != nil {
319 panic(err)
320 }
321
322 go func() {
323 for {
324 _, err := pingSession.Write([]byte(fmt.Sprintf("%s: pipe-web ping\n", time.Now().UTC().Format(time.RFC3339))))
325 if err != nil {
326 logger.Error("pipe ping error", "err", err.Error())
327 }
328
329 time.Sleep(5 * time.Second)
330 }
331 }()
332
333 apiConfig := &shared.ApiConfig{
334 Cfg: cfg,
335 Dbpool: db,
336 }
337 handler := shared.CreateServe(mainRoutes, subdomainRoutes, apiConfig)
338 router := http.HandlerFunc(handler)
339
340 portStr := fmt.Sprintf(":%s", cfg.Port)
341 logger.Info(
342 "Starting server on port",
343 "port", cfg.Port,
344 "domain", cfg.Domain,
345 )
346
347 logger.Error("listen", "err", http.ListenAndServe(portStr, router).Error())
348}