Eric Bower
·
03 Dec 24
storage.go
1package postgres
2
3import (
4 "context"
5 "database/sql"
6 "errors"
7 "fmt"
8 "log/slog"
9 "math"
10 "strings"
11 "time"
12
13 "slices"
14
15 _ "github.com/lib/pq"
16 "github.com/picosh/pico/db"
17 "github.com/picosh/utils"
18)
19
// PAGER_SIZE is presumably the default number of posts per page used by
// callers when building a db.Pager — it is not referenced in this section of
// the file; NOTE(review): confirm against its users.
var PAGER_SIZE = 15
21
// SelectPost is the shared column list spliced into the post queries below.
// Its column order must stay in sync with the Scan order in CreatePostFromRow
// and CreatePostWithTagsFromRow (id, user_id, author name, filename, slug,
// title, text, description, created_at, publish_at, updated_at, hidden,
// file_size, mime_type, shasum, data, expires_at, views).
var SelectPost = `
	posts.id, user_id, app_users.name, filename, slug, title, text, description,
	posts.created_at, publish_at, posts.updated_at, hidden, file_size, mime_type, shasum, data, expires_at, views`
25
var (
	// sqlSelectPosts selects every post (all users, all spaces) together with
	// its author's name; rows are scanned with CreatePostFromRow.
	sqlSelectPosts = fmt.Sprintf(`
	SELECT %s
	FROM posts
	LEFT JOIN app_users ON app_users.id = posts.user_id`, SelectPost)

	// sqlSelectPostsBeforeDate selects posts in space $2 whose publish date
	// (publish_at cast to a date) is on or before $1.
	sqlSelectPostsBeforeDate = fmt.Sprintf(`
	SELECT %s
	FROM posts
	LEFT JOIN app_users ON app_users.id = posts.user_id
	WHERE publish_at::date <= $1 AND cur_space = $2`, SelectPost)

	// sqlSelectPostWithFilename selects the post with filename $1 owned by user
	// $2 in space $3, followed by a "tags" column holding its tag names joined
	// with commas (scanned with CreatePostWithTagsFromRow). SelectPost is
	// repeated in GROUP BY so the plain columns are grouped alongside STRING_AGG.
	sqlSelectPostWithFilename = fmt.Sprintf(`
	SELECT %s, STRING_AGG(coalesce(post_tags.name, ''), ',') tags
	FROM posts
	LEFT JOIN app_users ON app_users.id = posts.user_id
	LEFT JOIN post_tags ON post_tags.post_id = posts.id
	WHERE filename = $1 AND user_id = $2 AND cur_space = $3
	GROUP BY %s`, SelectPost, SelectPost)

	// sqlSelectPostWithSlug is sqlSelectPostWithFilename keyed by slug ($1)
	// instead of filename.
	sqlSelectPostWithSlug = fmt.Sprintf(`
	SELECT %s, STRING_AGG(coalesce(post_tags.name, ''), ',') tags
	FROM posts
	LEFT JOIN app_users ON app_users.id = posts.user_id
	LEFT JOIN post_tags ON post_tags.post_id = posts.id
	WHERE slug = $1 AND user_id = $2 AND cur_space = $3
	GROUP BY %s`, SelectPost, SelectPost)

	// sqlSelectPost selects a single post by its id ($1), without tags.
	sqlSelectPost = fmt.Sprintf(`
	SELECT %s
	FROM posts
	LEFT JOIN app_users ON app_users.id = posts.user_id
	WHERE posts.id = $1`, SelectPost)

	// sqlSelectUpdatedPostsForUser selects user $1's posts in space $2 that are
	// published on or before today, most recently updated first. Hidden posts
	// are not filtered out.
	sqlSelectUpdatedPostsForUser = fmt.Sprintf(`
	SELECT %s
	FROM posts
	LEFT JOIN app_users ON app_users.id = posts.user_id
	WHERE user_id = $1 AND publish_at::date <= CURRENT_DATE AND cur_space = $2
	ORDER BY posts.updated_at DESC`, SelectPost)

	// sqlSelectExpiredPosts selects posts in space $1 whose expires_at is at or
	// before the current time.
	sqlSelectExpiredPosts = fmt.Sprintf(`
	SELECT %s
	FROM posts
	LEFT JOIN app_users ON app_users.id = posts.user_id
	WHERE
		cur_space = $1 AND
		expires_at <= now();
	`, SelectPost)

	// sqlSelectPostsForUser selects one page of user $1's non-hidden posts in
	// space $2 published on or before today, newest first, with the aggregated
	// "tags" column. $3 is the LIMIT (page size) and $4 the OFFSET.
	sqlSelectPostsForUser = fmt.Sprintf(`
	SELECT %s, STRING_AGG(coalesce(post_tags.name, ''), ',') tags
	FROM posts
	LEFT JOIN app_users ON app_users.id = posts.user_id
	LEFT JOIN post_tags ON post_tags.post_id = posts.id
	WHERE
		hidden = FALSE AND
		user_id = $1 AND
		publish_at::date <= CURRENT_DATE AND
		cur_space = $2
	GROUP BY %s
	ORDER BY publish_at DESC, slug DESC
	LIMIT $3 OFFSET $4`, SelectPost, SelectPost)

	// sqlSelectAllPostsForUser selects every post of user $1 in space $2 —
	// hidden and future-dated posts included — newest publish_at first.
	sqlSelectAllPostsForUser = fmt.Sprintf(`
	SELECT %s
	FROM posts
	LEFT JOIN app_users ON app_users.id = posts.user_id
	WHERE
		user_id = $1 AND
		cur_space = $2
	ORDER BY publish_at DESC`, SelectPost)

	// sqlSelectPostsByTag selects one page of published posts in space $4
	// tagged $3. Note the placeholder order: $1 is the LIMIT and $2 the OFFSET.
	// Its column list (not SelectPost) matches the Scan order in postPager.
	// NOTE(review): unlike sqlSelectTagPostCount it does not filter hidden
	// posts, so listed rows and page totals may disagree — confirm intent.
	sqlSelectPostsByTag = `
	SELECT
		posts.id,
		user_id,
		filename,
		slug,
		title,
		text,
		description,
		publish_at,
		app_users.name as username,
		posts.updated_at,
		posts.mime_type
	FROM posts
	LEFT JOIN app_users ON app_users.id = posts.user_id
	LEFT JOIN post_tags ON post_tags.post_id = posts.id
	WHERE
		post_tags.name = $3 AND
		publish_at::date <= CURRENT_DATE AND
		cur_space = $4
	ORDER BY publish_at DESC
	LIMIT $1 OFFSET $2`

	// sqlSelectUserPostsByTag selects one page of user $1's published posts in
	// space $3 tagged $2; $4 is the LIMIT and $5 the OFFSET.
	// NOTE(review): `hidden = FALSE` makes the `OR hidden = true` branch
	// unreachable, so this effectively filters on post_tags.name = $2 only —
	// confirm which behavior was intended.
	sqlSelectUserPostsByTag = fmt.Sprintf(`
	SELECT %s
	FROM posts
	LEFT JOIN app_users ON app_users.id = posts.user_id
	LEFT JOIN post_tags ON post_tags.post_id = posts.id
	WHERE
		hidden = FALSE AND
		user_id = $1 AND
		(post_tags.name = $2 OR hidden = true) AND
		publish_at::date <= CURRENT_DATE AND
		cur_space = $3
	ORDER BY publish_at DESC
	LIMIT $4 OFFSET $5`, SelectPost)
)
136
const (
	// Public key and user lookups.
	sqlSelectPublicKey         = `SELECT id, user_id, name, public_key, created_at FROM public_keys WHERE public_key = $1`
	sqlSelectPublicKeys        = `SELECT id, user_id, name, public_key, created_at FROM public_keys WHERE user_id = $1 ORDER BY created_at ASC`
	sqlSelectUser              = `SELECT id, name, created_at FROM app_users WHERE id = $1`
	sqlSelectUserForName       = `SELECT id, name, created_at FROM app_users WHERE name = $1`
	sqlSelectUserForNameAndKey = `SELECT app_users.id, app_users.name, app_users.created_at, public_keys.id as pk_id, public_keys.public_key, public_keys.created_at as pk_created_at FROM app_users LEFT JOIN public_keys ON public_keys.user_id = app_users.id WHERE app_users.name = $1 AND public_keys.public_key = $2`
	sqlSelectUsers             = `SELECT id, name, created_at FROM app_users ORDER BY name ASC`

	// sqlSelectUserForToken finds the owner of token $1, ignoring tokens whose
	// expires_at is not in the future.
	sqlSelectUserForToken = `
	SELECT app_users.id, app_users.name, app_users.created_at
	FROM app_users
	LEFT JOIN tokens ON tokens.user_id = app_users.id
	WHERE tokens.token = $1 AND tokens.expires_at > NOW()`
	// Token management.
	sqlInsertToken              = `INSERT INTO tokens (user_id, name) VALUES($1, $2) RETURNING token;`
	sqlRemoveToken              = `DELETE FROM tokens WHERE id = $1`
	sqlSelectTokensForUser      = `SELECT id, user_id, name, created_at, expires_at FROM tokens WHERE user_id = $1`
	sqlSelectTokenByNameForUser = `SELECT token FROM tokens WHERE user_id = $1 AND name = $2`

	// Site-wide counters read by FindSiteAnalytics.
	sqlSelectTotalUsers          = `SELECT count(id) FROM app_users`
	sqlSelectUsersAfterDate      = `SELECT count(id) FROM app_users WHERE created_at >= $1`
	sqlSelectTotalPosts          = `SELECT count(id) FROM posts WHERE cur_space = $1`
	sqlSelectTotalPostsAfterDate = `SELECT count(id) FROM posts WHERE created_at >= $1 AND cur_space = $2`
	sqlSelectUsersWithPost       = `SELECT count(app_users.id) FROM app_users WHERE EXISTS (SELECT 1 FROM posts WHERE user_id = app_users.id AND cur_space = $1);`

	// sqlSelectFeatureForUser returns user $1's feature flag named $2 with the
	// latest expires_at; sqlSelectSizeForUser sums file_size over all of a
	// user's posts (0 when there are none).
	sqlSelectFeatureForUser = `SELECT id, user_id, payment_history_id, name, data, created_at, expires_at FROM feature_flags WHERE user_id = $1 AND name = $2 ORDER BY expires_at DESC LIMIT 1`
	sqlSelectSizeForUser    = `SELECT COALESCE(sum(file_size), 0) FROM posts WHERE user_id = $1`

	// sqlSelectPostIdByAliasSlug resolves an alias slug to a post id.
	// NOTE(review): it is not scoped to a user or space.
	sqlSelectPostIdByAliasSlug = `SELECT post_id FROM post_aliases WHERE slug = $1`
	// sqlSelectTagPostCount counts non-hidden posts in space $1 tagged $2.
	sqlSelectTagPostCount = `
	SELECT count(posts.id)
	FROM posts
	LEFT JOIN post_tags ON post_tags.post_id = posts.id
	WHERE hidden = FALSE AND cur_space=$1 and post_tags.name = $2`
	// sqlSelectPostCount counts non-hidden posts in space $1 across all users.
	sqlSelectPostCount = `SELECT count(id) FROM posts WHERE hidden = FALSE AND cur_space=$1`
	// sqlSelectAllUpdatedPosts selects one page ($1 LIMIT, $2 OFFSET) of
	// non-hidden posts in space $3 published on or before today, most recently
	// updated first. Its columns match the Scan order in postPager.
	sqlSelectAllUpdatedPosts = `
	SELECT
		posts.id,
		user_id,
		filename,
		slug,
		title,
		text,
		description,
		publish_at,
		app_users.name as username,
		posts.updated_at,
		posts.mime_type
	FROM posts
	LEFT JOIN app_users ON app_users.id = posts.user_id
	WHERE hidden = FALSE AND publish_at::date <= CURRENT_DATE AND cur_space = $3
	ORDER BY updated_at DESC
	LIMIT $1 OFFSET $2`
	// sqlSelectPostsByRank selects one page ($1 LIMIT, $2 OFFSET) of non-hidden
	// posts in space $3 published on or before today, newest publish_at first.
	// It excludes a hard-coded deny list of users who auto-generate many posts
	// per day and would otherwise flood the feed. Its columns match postPager.
	sqlSelectPostsByRank = `
	SELECT
		posts.id,
		user_id,
		filename,
		slug,
		title,
		text,
		description,
		publish_at,
		app_users.name as username,
		posts.updated_at,
		posts.mime_type
	FROM posts
	LEFT JOIN app_users ON app_users.id = posts.user_id
	WHERE
		hidden = FALSE AND
		publish_at::date <= CURRENT_DATE AND
		cur_space = $3 AND
		app_users.name NOT IN ('algiegray', 'mrrccc')
	ORDER BY publish_at DESC
	LIMIT $1 OFFSET $2`

	// sqlSelectPopularTags returns the five tag names in space $1 attached to
	// the most posts, with their tally.
	sqlSelectPopularTags = `
	SELECT name, count(post_id) as "tally"
	FROM post_tags
	LEFT JOIN posts ON posts.id = post_id
	WHERE posts.cur_space = $1
	GROUP BY name
	ORDER BY tally DESC
	LIMIT 5`
	// Per-post tag and feed item lookups.
	sqlSelectTagsForPost     = `SELECT name FROM post_tags WHERE post_id=$1`
	sqlSelectFeedItemsByPost = `SELECT id, post_id, guid, data, created_at FROM feed_items WHERE post_id=$1`

	// sqlInsertPublicKey has no name column; InsertPublicKey in this file uses
	// its own inline INSERT that also stores the key's name.
	sqlInsertPublicKey = `INSERT INTO public_keys (user_id, public_key) VALUES ($1, $2)`
	// sqlInsertPost inserts a post and returns its id; $9 is cur_space.
	sqlInsertPost = `
	INSERT INTO posts
		(user_id, filename, slug, title, text, description, publish_at, hidden, cur_space,
		file_size, mime_type, shasum, data, expires_at, updated_at)
	VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15)
	RETURNING id`
	// Row inserts that return the new row's id.
	sqlInsertUser      = `INSERT INTO app_users (name) VALUES($1) returning id`
	sqlInsertTag       = `INSERT INTO post_tags (post_id, name) VALUES($1, $2) RETURNING id;`
	sqlInsertAliases   = `INSERT INTO post_aliases (post_id, slug) VALUES($1, $2) RETURNING id;`
	sqlInsertFeedItems = `INSERT INTO feed_items (post_id, guid, data) VALUES ($1, $2, $3) RETURNING id;`

	// sqlUpdatePost rewrites a post's mutable columns. Placeholders are not in
	// column order: $10 is the id in the WHERE clause and $11/$12 are hidden
	// and expires_at, so callers must pass arguments in exactly that order.
	sqlUpdatePost = `
	UPDATE posts
	SET slug = $1, title = $2, text = $3, description = $4, updated_at = $5, publish_at = $6,
		file_size = $7, shasum = $8, data = $9, hidden = $11, expires_at = $12
	WHERE id = $10`
	sqlUpdateUserName = `UPDATE app_users SET name = $1 WHERE id = $2`
	sqlIncrementViews = `UPDATE posts SET views = views + 1 WHERE id = $1 RETURNING views`

	// Deletions. The bulk variants take a single Postgres array literal that
	// is cast to uuid[].
	sqlRemoveAliasesByPost = `DELETE FROM post_aliases WHERE post_id = $1`
	sqlRemoveTagsByPost    = `DELETE FROM post_tags WHERE post_id = $1`
	sqlRemovePosts         = `DELETE FROM posts WHERE id = ANY($1::uuid[])`
	sqlRemoveKeys          = `DELETE FROM public_keys WHERE id = ANY($1::uuid[])`
	sqlRemoveUsers         = `DELETE FROM app_users WHERE id = ANY($1::uuid[])`

	// Project queries. A project whose name equals its project_dir is a
	// "real" project; rows with name != project_dir point at (link to)
	// another project's directory.
	sqlInsertProject        = `INSERT INTO projects (user_id, name, project_dir) VALUES ($1, $2, $3) RETURNING id;`
	sqlUpdateProject        = `UPDATE projects SET updated_at = $3 WHERE user_id = $1 AND name = $2;`
	sqlUpdateProjectAcl     = `UPDATE projects SET acl = $3, updated_at = $4 WHERE user_id = $1 AND name = $2;`
	sqlFindProjectByName    = `SELECT id, user_id, name, project_dir, acl, blocked, created_at, updated_at FROM projects WHERE user_id = $1 AND name = $2;`
	sqlSelectProjectCount   = `SELECT count(id) FROM projects`
	sqlFindProjectsByUser   = `SELECT id, user_id, name, project_dir, acl, blocked, created_at, updated_at FROM projects WHERE user_id = $1 ORDER BY name ASC, updated_at DESC;`
	sqlFindProjectsByPrefix = `SELECT id, user_id, name, project_dir, acl, blocked, created_at, updated_at FROM projects WHERE user_id = $1 AND name = project_dir AND name ILIKE $2 ORDER BY updated_at ASC, name ASC;`
	sqlFindProjectLinks     = `SELECT id, user_id, name, project_dir, acl, blocked, created_at, updated_at FROM projects WHERE user_id = $1 AND name != project_dir AND project_dir = $2 ORDER BY name ASC;`
	sqlLinkToProject        = `UPDATE projects SET project_dir = $1, updated_at = $2 WHERE id = $3;`
	sqlRemoveProject        = `DELETE FROM projects WHERE id = $1;`
)
258 sqlFindProjectLinks = `SELECT id, user_id, name, project_dir, acl, blocked, created_at, updated_at FROM projects WHERE user_id = $1 AND name != project_dir AND project_dir = $2 ORDER BY name ASC;`
259 sqlLinkToProject = `UPDATE projects SET project_dir = $1, updated_at = $2 WHERE id = $3;`
260 sqlRemoveProject = `DELETE FROM projects WHERE id = $1;`
261)
262
263type PsqlDB struct {
264 Logger *slog.Logger
265 Db *sql.DB
266}
267
268type RowScanner interface {
269 Scan(dest ...any) error
270}
271
272func CreatePostFromRow(r RowScanner) (*db.Post, error) {
273 post := &db.Post{}
274 err := r.Scan(
275 &post.ID,
276 &post.UserID,
277 &post.Username,
278 &post.Filename,
279 &post.Slug,
280 &post.Title,
281 &post.Text,
282 &post.Description,
283 &post.CreatedAt,
284 &post.PublishAt,
285 &post.UpdatedAt,
286 &post.Hidden,
287 &post.FileSize,
288 &post.MimeType,
289 &post.Shasum,
290 &post.Data,
291 &post.ExpiresAt,
292 &post.Views,
293 )
294 if err != nil {
295 return nil, err
296 }
297 return post, nil
298}
299
300func CreatePostWithTagsFromRow(r RowScanner) (*db.Post, error) {
301 post := &db.Post{}
302 tagStr := ""
303 err := r.Scan(
304 &post.ID,
305 &post.UserID,
306 &post.Username,
307 &post.Filename,
308 &post.Slug,
309 &post.Title,
310 &post.Text,
311 &post.Description,
312 &post.CreatedAt,
313 &post.PublishAt,
314 &post.UpdatedAt,
315 &post.Hidden,
316 &post.FileSize,
317 &post.MimeType,
318 &post.Shasum,
319 &post.Data,
320 &post.ExpiresAt,
321 &post.Views,
322 &tagStr,
323 )
324 if err != nil {
325 return nil, err
326 }
327
328 tags := strings.Split(tagStr, ",")
329 for _, tag := range tags {
330 tg := strings.TrimSpace(tag)
331 if tg == "" {
332 continue
333 }
334 post.Tags = append(post.Tags, tg)
335 }
336
337 return post, nil
338}
339
340func NewDB(databaseUrl string, logger *slog.Logger) *PsqlDB {
341 var err error
342 d := &PsqlDB{
343 Logger: logger,
344 }
345 d.Logger.Info("Connecting to postgres", "databaseUrl", databaseUrl)
346
347 db, err := sql.Open("postgres", databaseUrl)
348 if err != nil {
349 d.Logger.Error(err.Error())
350 }
351 d.Db = db
352 return d
353}
354
355func (me *PsqlDB) RegisterUser(username, pubkey, comment string) (*db.User, error) {
356 lowerName := strings.ToLower(username)
357 valid, err := me.ValidateName(lowerName)
358 if !valid {
359 return nil, err
360 }
361
362 ctx := context.Background()
363 tx, err := me.Db.BeginTx(ctx, nil)
364 if err != nil {
365 return nil, err
366 }
367 defer func() {
368 err = tx.Rollback()
369 }()
370
371 stmt, err := tx.Prepare(sqlInsertUser)
372 if err != nil {
373 return nil, err
374 }
375 defer stmt.Close()
376
377 var id string
378 err = stmt.QueryRow(lowerName).Scan(&id)
379 if err != nil {
380 return nil, err
381 }
382
383 err = me.InsertPublicKey(id, pubkey, comment, tx)
384 if err != nil {
385 return nil, err
386 }
387
388 err = tx.Commit()
389 if err != nil {
390 return nil, err
391 }
392
393 return me.FindUserForKey(username, pubkey)
394}
395
396func (me *PsqlDB) RemoveUsers(userIDs []string) error {
397 param := "{" + strings.Join(userIDs, ",") + "}"
398 _, err := me.Db.Exec(sqlRemoveUsers, param)
399 return err
400}
401
402func (me *PsqlDB) InsertPublicKey(userID, key, name string, tx *sql.Tx) error {
403 pk, _ := me.FindPublicKeyForKey(key)
404 if pk != nil {
405 return db.ErrPublicKeyTaken
406 }
407 query := `INSERT INTO public_keys (user_id, public_key, name) VALUES ($1, $2, $3)`
408 var err error
409 if tx != nil {
410 _, err = tx.Exec(query, userID, key, name)
411 } else {
412 _, err = me.Db.Exec(query, userID, key, name)
413 }
414 if err != nil {
415 return err
416 }
417
418 return nil
419}
420
421func (me *PsqlDB) UpdatePublicKey(pubkeyID, name string) (*db.PublicKey, error) {
422 pk, err := me.FindPublicKey(pubkeyID)
423 if err != nil {
424 return nil, err
425 }
426
427 query := `UPDATE public_keys SET name=$1 WHERE id=$2;`
428 _, err = me.Db.Exec(query, name, pk.ID)
429 if err != nil {
430 return nil, err
431 }
432
433 pk, err = me.FindPublicKey(pubkeyID)
434 if err != nil {
435 return nil, err
436 }
437 return pk, nil
438}
439
440func (me *PsqlDB) FindPublicKeyForKey(key string) (*db.PublicKey, error) {
441 var keys []*db.PublicKey
442 rs, err := me.Db.Query(sqlSelectPublicKey, key)
443 if err != nil {
444 return nil, err
445 }
446
447 for rs.Next() {
448 pk := &db.PublicKey{}
449 err := rs.Scan(&pk.ID, &pk.UserID, &pk.Name, &pk.Key, &pk.CreatedAt)
450 if err != nil {
451 return nil, err
452 }
453
454 keys = append(keys, pk)
455 }
456
457 if rs.Err() != nil {
458 return nil, rs.Err()
459 }
460
461 if len(keys) == 0 {
462 return nil, fmt.Errorf("pubkey not found in our database: [%s]", key)
463 }
464
465 // When we run PublicKeyForKey and there are multiple public keys returned from the database
466 // that should mean that we don't have the correct username for this public key.
467 // When that happens we need to reject the authentication and ask the user to provide the correct
468 // username when using ssh. So instead of `ssh <domain>` it should be `ssh user@<domain>`
469 if len(keys) > 1 {
470 return nil, &db.ErrMultiplePublicKeys{}
471 }
472
473 return keys[0], nil
474}
475
476func (me *PsqlDB) FindPublicKey(pubkeyID string) (*db.PublicKey, error) {
477 var keys []*db.PublicKey
478 rs, err := me.Db.Query(`SELECT id, user_id, name, public_key, created_at FROM public_keys WHERE id = $1`, pubkeyID)
479 if err != nil {
480 return nil, err
481 }
482
483 for rs.Next() {
484 pk := &db.PublicKey{}
485 err := rs.Scan(&pk.ID, &pk.UserID, &pk.Name, &pk.Key, &pk.CreatedAt)
486 if err != nil {
487 return nil, err
488 }
489
490 keys = append(keys, pk)
491 }
492
493 if rs.Err() != nil {
494 return nil, rs.Err()
495 }
496
497 if len(keys) == 0 {
498 return nil, errors.New("no public keys found for key provided")
499 }
500
501 return keys[0], nil
502}
503
504func (me *PsqlDB) FindKeysForUser(user *db.User) ([]*db.PublicKey, error) {
505 var keys []*db.PublicKey
506 rs, err := me.Db.Query(sqlSelectPublicKeys, user.ID)
507 if err != nil {
508 return keys, err
509 }
510 for rs.Next() {
511 pk := &db.PublicKey{}
512 err := rs.Scan(&pk.ID, &pk.UserID, &pk.Name, &pk.Key, &pk.CreatedAt)
513 if err != nil {
514 return keys, err
515 }
516
517 keys = append(keys, pk)
518 }
519 if rs.Err() != nil {
520 return keys, rs.Err()
521 }
522 return keys, nil
523}
524
525func (me *PsqlDB) RemoveKeys(keyIDs []string) error {
526 param := "{" + strings.Join(keyIDs, ",") + "}"
527 _, err := me.Db.Exec(sqlRemoveKeys, param)
528 return err
529}
530
531func (me *PsqlDB) FindSiteAnalytics(space string) (*db.Analytics, error) {
532 analytics := &db.Analytics{}
533 r := me.Db.QueryRow(sqlSelectTotalUsers)
534 err := r.Scan(&analytics.TotalUsers)
535 if err != nil {
536 return nil, err
537 }
538
539 r = me.Db.QueryRow(sqlSelectTotalPosts, space)
540 err = r.Scan(&analytics.TotalPosts)
541 if err != nil {
542 return nil, err
543 }
544
545 now := time.Now()
546 year, month, _ := now.Date()
547 begMonth := time.Date(year, month, 1, 0, 0, 0, 0, now.Location())
548
549 r = me.Db.QueryRow(sqlSelectTotalPostsAfterDate, begMonth, space)
550 err = r.Scan(&analytics.PostsLastMonth)
551 if err != nil {
552 return nil, err
553 }
554
555 r = me.Db.QueryRow(sqlSelectUsersAfterDate, begMonth)
556 err = r.Scan(&analytics.UsersLastMonth)
557 if err != nil {
558 return nil, err
559 }
560
561 r = me.Db.QueryRow(sqlSelectUsersWithPost, space)
562 err = r.Scan(&analytics.UsersWithPost)
563 if err != nil {
564 return nil, err
565 }
566
567 return analytics, nil
568}
569
570func (me *PsqlDB) FindPostsBeforeDate(date *time.Time, space string) ([]*db.Post, error) {
571 // now := time.Now()
572 // expired := now.AddDate(0, 0, -3)
573 var posts []*db.Post
574 rs, err := me.Db.Query(sqlSelectPostsBeforeDate, date, space)
575 if err != nil {
576 return posts, err
577 }
578 for rs.Next() {
579 post, err := CreatePostFromRow(rs)
580 if err != nil {
581 return nil, err
582 }
583
584 posts = append(posts, post)
585 }
586 if rs.Err() != nil {
587 return posts, rs.Err()
588 }
589 return posts, nil
590}
591
592func (me *PsqlDB) FindUserForKey(username string, key string) (*db.User, error) {
593 me.Logger.Info("attempting to find user with only public key", "key", key)
594 pk, err := me.FindPublicKeyForKey(key)
595 if err == nil {
596 me.Logger.Info("found pubkey, looking for user", "key", key, "userId", pk.UserID)
597 user, err := me.FindUser(pk.UserID)
598 if err != nil {
599 return nil, err
600 }
601 user.PublicKey = pk
602 return user, nil
603 }
604
605 if errors.Is(err, &db.ErrMultiplePublicKeys{}) {
606 me.Logger.Info("detected multiple users with same public key", "user", username)
607 user, err := me.FindUserForNameAndKey(username, key)
608 if err != nil {
609 me.Logger.Info("could not find user by username and public key", "user", username, "key", key)
610 // this is a little hacky but if we cannot find a user by name and public key
611 // then we return the multiple keys detected error so the user knows to specify their
612 // when logging in
613 return nil, &db.ErrMultiplePublicKeys{}
614 }
615 return user, nil
616 }
617
618 return nil, err
619}
620
621func (me *PsqlDB) FindUser(userID string) (*db.User, error) {
622 user := &db.User{}
623 var un sql.NullString
624 r := me.Db.QueryRow(sqlSelectUser, userID)
625 err := r.Scan(&user.ID, &un, &user.CreatedAt)
626 if err != nil {
627 return nil, err
628 }
629 if un.Valid {
630 user.Name = un.String
631 }
632 return user, nil
633}
634
635func (me *PsqlDB) ValidateName(name string) (bool, error) {
636 lower := strings.ToLower(name)
637 if slices.Contains(db.DenyList, lower) {
638 return false, fmt.Errorf("%s is on deny list: %w", lower, db.ErrNameDenied)
639 }
640 v := db.NameValidator.MatchString(lower)
641 if !v {
642 return false, fmt.Errorf("%s is invalid: %w", lower, db.ErrNameInvalid)
643 }
644 user, _ := me.FindUserForName(lower)
645 if user == nil {
646 return true, nil
647 }
648 return false, fmt.Errorf("%s already taken: %w", lower, db.ErrNameTaken)
649}
650
651func (me *PsqlDB) FindUserForName(name string) (*db.User, error) {
652 user := &db.User{}
653 r := me.Db.QueryRow(sqlSelectUserForName, strings.ToLower(name))
654 err := r.Scan(&user.ID, &user.Name, &user.CreatedAt)
655 if err != nil {
656 return nil, err
657 }
658 return user, nil
659}
660
661func (me *PsqlDB) FindUserForNameAndKey(name string, key string) (*db.User, error) {
662 user := &db.User{}
663 pk := &db.PublicKey{}
664
665 r := me.Db.QueryRow(sqlSelectUserForNameAndKey, strings.ToLower(name), key)
666 err := r.Scan(&user.ID, &user.Name, &user.CreatedAt, &pk.ID, &pk.Key, &pk.CreatedAt)
667 if err != nil {
668 return nil, err
669 }
670
671 user.PublicKey = pk
672 return user, nil
673}
674
675func (me *PsqlDB) FindUserForToken(token string) (*db.User, error) {
676 user := &db.User{}
677
678 r := me.Db.QueryRow(sqlSelectUserForToken, token)
679 err := r.Scan(&user.ID, &user.Name, &user.CreatedAt)
680 if err != nil {
681 return nil, err
682 }
683
684 return user, nil
685}
686
687func (me *PsqlDB) SetUserName(userID string, name string) error {
688 lowerName := strings.ToLower(name)
689 valid, err := me.ValidateName(lowerName)
690 if !valid {
691 return err
692 }
693
694 _, err = me.Db.Exec(sqlUpdateUserName, lowerName, userID)
695 return err
696}
697
698func (me *PsqlDB) FindPostWithFilename(filename string, persona_id string, space string) (*db.Post, error) {
699 r := me.Db.QueryRow(sqlSelectPostWithFilename, filename, persona_id, space)
700 post, err := CreatePostWithTagsFromRow(r)
701 if err != nil {
702 return nil, err
703 }
704
705 return post, nil
706}
707
708func (me *PsqlDB) FindPostWithSlug(slug string, user_id string, space string) (*db.Post, error) {
709 r := me.Db.QueryRow(sqlSelectPostWithSlug, slug, user_id, space)
710 post, err := CreatePostWithTagsFromRow(r)
711 if err != nil {
712 // attempt to find post inside post_aliases
713 alias := me.Db.QueryRow(sqlSelectPostIdByAliasSlug, slug)
714 postID := ""
715 err := alias.Scan(&postID)
716 if err != nil {
717 return nil, err
718 }
719
720 return me.FindPost(postID)
721 }
722
723 return post, nil
724}
725
726func (me *PsqlDB) FindPost(postID string) (*db.Post, error) {
727 r := me.Db.QueryRow(sqlSelectPost, postID)
728 post, err := CreatePostFromRow(r)
729 if err != nil {
730 return nil, err
731 }
732
733 return post, nil
734}
735
736func (me *PsqlDB) postPager(rs *sql.Rows, pageNum int, space string, tag string) (*db.Paginate[*db.Post], error) {
737 var posts []*db.Post
738 for rs.Next() {
739 post := &db.Post{}
740 err := rs.Scan(
741 &post.ID,
742 &post.UserID,
743 &post.Filename,
744 &post.Slug,
745 &post.Title,
746 &post.Text,
747 &post.Description,
748 &post.PublishAt,
749 &post.Username,
750 &post.UpdatedAt,
751 &post.MimeType,
752 )
753 if err != nil {
754 return nil, err
755 }
756
757 posts = append(posts, post)
758 }
759 if rs.Err() != nil {
760 return nil, rs.Err()
761 }
762
763 var count int
764 var err error
765 if tag == "" {
766 err = me.Db.QueryRow(sqlSelectPostCount, space).Scan(&count)
767 } else {
768 err = me.Db.QueryRow(sqlSelectTagPostCount, space, tag).Scan(&count)
769 }
770 if err != nil {
771 return nil, err
772 }
773
774 pager := &db.Paginate[*db.Post]{
775 Data: posts,
776 Total: int(math.Ceil(float64(count) / float64(pageNum))),
777 }
778
779 return pager, nil
780}
781
782func (me *PsqlDB) FindAllPosts(page *db.Pager, space string) (*db.Paginate[*db.Post], error) {
783 rs, err := me.Db.Query(sqlSelectPostsByRank, page.Num, page.Num*page.Page, space)
784 if err != nil {
785 return nil, err
786 }
787 return me.postPager(rs, page.Num, space, "")
788}
789
790func (me *PsqlDB) FindAllUpdatedPosts(page *db.Pager, space string) (*db.Paginate[*db.Post], error) {
791 rs, err := me.Db.Query(sqlSelectAllUpdatedPosts, page.Num, page.Num*page.Page, space)
792 if err != nil {
793 return nil, err
794 }
795 return me.postPager(rs, page.Num, space, "")
796}
797
798func (me *PsqlDB) InsertPost(post *db.Post) (*db.Post, error) {
799 var id string
800 err := me.Db.QueryRow(
801 sqlInsertPost,
802 post.UserID,
803 post.Filename,
804 post.Slug,
805 post.Title,
806 post.Text,
807 post.Description,
808 post.PublishAt,
809 post.Hidden,
810 post.Space,
811 post.FileSize,
812 post.MimeType,
813 post.Shasum,
814 post.Data,
815 post.ExpiresAt,
816 post.UpdatedAt,
817 ).Scan(&id)
818 if err != nil {
819 return nil, err
820 }
821
822 return me.FindPost(id)
823}
824
825func (me *PsqlDB) UpdatePost(post *db.Post) (*db.Post, error) {
826 _, err := me.Db.Exec(
827 sqlUpdatePost,
828 post.Slug,
829 post.Title,
830 post.Text,
831 post.Description,
832 post.UpdatedAt,
833 post.PublishAt,
834 post.FileSize,
835 post.Shasum,
836 post.Data,
837 post.ID,
838 post.Hidden,
839 post.ExpiresAt,
840 )
841 if err != nil {
842 return nil, err
843 }
844
845 return me.FindPost(post.ID)
846}
847
848func (me *PsqlDB) RemovePosts(postIDs []string) error {
849 param := "{" + strings.Join(postIDs, ",") + "}"
850 _, err := me.Db.Exec(sqlRemovePosts, param)
851 return err
852}
853
854func (me *PsqlDB) FindPostsForUser(page *db.Pager, userID string, space string) (*db.Paginate[*db.Post], error) {
855 var posts []*db.Post
856 rs, err := me.Db.Query(
857 sqlSelectPostsForUser,
858 userID,
859 space,
860 page.Num,
861 page.Num*page.Page,
862 )
863 if err != nil {
864 return nil, err
865 }
866 for rs.Next() {
867 post, err := CreatePostWithTagsFromRow(rs)
868 if err != nil {
869 return nil, err
870 }
871
872 posts = append(posts, post)
873 }
874
875 if rs.Err() != nil {
876 return nil, rs.Err()
877 }
878
879 var count int
880 err = me.Db.QueryRow(sqlSelectPostCount, space).Scan(&count)
881 if err != nil {
882 return nil, err
883 }
884
885 pager := &db.Paginate[*db.Post]{
886 Data: posts,
887 Total: int(math.Ceil(float64(count) / float64(page.Num))),
888 }
889 return pager, nil
890}
891
892func (me *PsqlDB) FindAllPostsForUser(userID string, space string) ([]*db.Post, error) {
893 var posts []*db.Post
894 rs, err := me.Db.Query(sqlSelectAllPostsForUser, userID, space)
895 if err != nil {
896 return posts, err
897 }
898 for rs.Next() {
899 post, err := CreatePostFromRow(rs)
900 if err != nil {
901 return nil, err
902 }
903
904 posts = append(posts, post)
905 }
906 if rs.Err() != nil {
907 return posts, rs.Err()
908 }
909 return posts, nil
910}
911
912func (me *PsqlDB) FindPosts() ([]*db.Post, error) {
913 var posts []*db.Post
914 rs, err := me.Db.Query(sqlSelectPosts)
915 if err != nil {
916 return posts, err
917 }
918 for rs.Next() {
919 post, err := CreatePostFromRow(rs)
920 if err != nil {
921 return nil, err
922 }
923
924 posts = append(posts, post)
925 }
926 if rs.Err() != nil {
927 return posts, rs.Err()
928 }
929 return posts, nil
930}
931
932func (me *PsqlDB) FindExpiredPosts(space string) ([]*db.Post, error) {
933 var posts []*db.Post
934 rs, err := me.Db.Query(sqlSelectExpiredPosts, space)
935 if err != nil {
936 return posts, err
937 }
938 for rs.Next() {
939 post, err := CreatePostFromRow(rs)
940 if err != nil {
941 return nil, err
942 }
943
944 posts = append(posts, post)
945 }
946 if rs.Err() != nil {
947 return posts, rs.Err()
948 }
949 return posts, nil
950}
951
952func (me *PsqlDB) FindUpdatedPostsForUser(userID string, space string) ([]*db.Post, error) {
953 var posts []*db.Post
954 rs, err := me.Db.Query(sqlSelectUpdatedPostsForUser, userID, space)
955 if err != nil {
956 return posts, err
957 }
958 for rs.Next() {
959 post, err := CreatePostFromRow(rs)
960 if err != nil {
961 return nil, err
962 }
963
964 posts = append(posts, post)
965 }
966 if rs.Err() != nil {
967 return posts, rs.Err()
968 }
969 return posts, nil
970}
971
972func (me *PsqlDB) Close() error {
973 me.Logger.Info("Closing db")
974 return me.Db.Close()
975}
976
977func newNullString(s string) sql.NullString {
978 if len(s) == 0 {
979 return sql.NullString{}
980 }
981 return sql.NullString{
982 String: s,
983 Valid: true,
984 }
985}
986
987func (me *PsqlDB) InsertVisit(visit *db.AnalyticsVisits) error {
988 _, err := me.Db.Exec(
989 `INSERT INTO analytics_visits (user_id, project_id, post_id, namespace, host, path, ip_address, user_agent, referer, status, content_type) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11);`,
990 visit.UserID,
991 newNullString(visit.ProjectID),
992 newNullString(visit.PostID),
993 newNullString(visit.Namespace),
994 visit.Host,
995 visit.Path,
996 visit.IpAddress,
997 visit.UserAgent,
998 visit.Referer,
999 visit.Status,
1000 visit.ContentType,
1001 )
1002 return err
1003}
1004
1005func visitFilterBy(opts *db.SummaryOpts) (string, string) {
1006 where := ""
1007 val := ""
1008 if opts.Host != "" {
1009 where = "host"
1010 val = opts.Host
1011 } else if opts.Path != "" {
1012 where = "path"
1013 val = opts.Path
1014 }
1015
1016 return where, val
1017}
1018
1019func (me *PsqlDB) visitUnique(opts *db.SummaryOpts) ([]*db.VisitInterval, error) {
1020 where, with := visitFilterBy(opts)
1021 uniqueVisitors := fmt.Sprintf(`SELECT
1022 date_trunc('%s', created_at) as interval_start,
1023 count(DISTINCT ip_address) as unique_visitors
1024 FROM analytics_visits
1025 WHERE created_at >= $1 AND %s = $2 AND user_id = $3 AND status <> 404
1026 GROUP BY interval_start`, opts.Interval, where)
1027
1028 intervals := []*db.VisitInterval{}
1029 rs, err := me.Db.Query(uniqueVisitors, opts.Origin, with, opts.UserID)
1030 if err != nil {
1031 return nil, err
1032 }
1033
1034 for rs.Next() {
1035 interval := &db.VisitInterval{}
1036 err := rs.Scan(
1037 &interval.Interval,
1038 &interval.Visitors,
1039 )
1040 if err != nil {
1041 return nil, err
1042 }
1043
1044 intervals = append(intervals, interval)
1045 }
1046 if rs.Err() != nil {
1047 return nil, rs.Err()
1048 }
1049 return intervals, nil
1050}
1051
1052func (me *PsqlDB) visitReferer(opts *db.SummaryOpts) ([]*db.VisitUrl, error) {
1053 where, with := visitFilterBy(opts)
1054 topUrls := fmt.Sprintf(`SELECT
1055 referer,
1056 count(DISTINCT ip_address) as referer_count
1057 FROM analytics_visits
1058 WHERE created_at >= $1 AND %s = $2 AND user_id = $3 AND referer <> '' AND status <> 404
1059 GROUP BY referer
1060 ORDER BY referer_count DESC
1061 LIMIT 10`, where)
1062
1063 intervals := []*db.VisitUrl{}
1064 rs, err := me.Db.Query(topUrls, opts.Origin, with, opts.UserID)
1065 if err != nil {
1066 return nil, err
1067 }
1068
1069 for rs.Next() {
1070 interval := &db.VisitUrl{}
1071 err := rs.Scan(
1072 &interval.Url,
1073 &interval.Count,
1074 )
1075 if err != nil {
1076 return nil, err
1077 }
1078
1079 intervals = append(intervals, interval)
1080 }
1081 if rs.Err() != nil {
1082 return nil, rs.Err()
1083 }
1084 return intervals, nil
1085}
1086
1087func (me *PsqlDB) visitUrl(opts *db.SummaryOpts) ([]*db.VisitUrl, error) {
1088 where, with := visitFilterBy(opts)
1089 topUrls := fmt.Sprintf(`SELECT
1090 path,
1091 count(DISTINCT ip_address) as path_count
1092 FROM analytics_visits
1093 WHERE created_at >= $1 AND %s = $2 AND user_id = $3 AND path <> '' AND status <> 404
1094 GROUP BY path
1095 ORDER BY path_count DESC
1096 LIMIT 10`, where)
1097
1098 intervals := []*db.VisitUrl{}
1099 rs, err := me.Db.Query(topUrls, opts.Origin, with, opts.UserID)
1100 if err != nil {
1101 return nil, err
1102 }
1103
1104 for rs.Next() {
1105 interval := &db.VisitUrl{}
1106 err := rs.Scan(
1107 &interval.Url,
1108 &interval.Count,
1109 )
1110 if err != nil {
1111 return nil, err
1112 }
1113
1114 intervals = append(intervals, interval)
1115 }
1116 if rs.Err() != nil {
1117 return nil, rs.Err()
1118 }
1119 return intervals, nil
1120}
1121
1122func (me *PsqlDB) visitUrlNotFound(opts *db.SummaryOpts) ([]*db.VisitUrl, error) {
1123 where, with := visitFilterBy(opts)
1124 topUrls := fmt.Sprintf(`SELECT
1125 path,
1126 count(DISTINCT ip_address) as path_count
1127 FROM analytics_visits
1128 WHERE created_at >= $1 AND %s = $2 AND user_id = $3 AND path <> '' AND status = 404
1129 GROUP BY path
1130 ORDER BY path_count DESC
1131 LIMIT 10`, where)
1132
1133 intervals := []*db.VisitUrl{}
1134 rs, err := me.Db.Query(topUrls, opts.Origin, with, opts.UserID)
1135 if err != nil {
1136 return nil, err
1137 }
1138
1139 for rs.Next() {
1140 interval := &db.VisitUrl{}
1141 err := rs.Scan(
1142 &interval.Url,
1143 &interval.Count,
1144 )
1145 if err != nil {
1146 return nil, err
1147 }
1148
1149 intervals = append(intervals, interval)
1150 }
1151 if rs.Err() != nil {
1152 return nil, rs.Err()
1153 }
1154 return intervals, nil
1155}
1156
1157func (me *PsqlDB) visitHost(opts *db.SummaryOpts) ([]*db.VisitUrl, error) {
1158 topUrls := `SELECT
1159 host,
1160 count(DISTINCT ip_address) as host_count
1161 FROM analytics_visits
1162 WHERE user_id = $1 AND host <> '' AND status >= 200 AND status < 300
1163 GROUP BY host
1164 ORDER BY host_count DESC`
1165
1166 intervals := []*db.VisitUrl{}
1167 rs, err := me.Db.Query(topUrls, opts.UserID)
1168 if err != nil {
1169 return nil, err
1170 }
1171
1172 for rs.Next() {
1173 interval := &db.VisitUrl{}
1174 err := rs.Scan(
1175 &interval.Url,
1176 &interval.Count,
1177 )
1178 if err != nil {
1179 return nil, err
1180 }
1181
1182 intervals = append(intervals, interval)
1183 }
1184 if rs.Err() != nil {
1185 return nil, rs.Err()
1186 }
1187 return intervals, nil
1188}
1189
1190func (me *PsqlDB) VisitSummary(opts *db.SummaryOpts) (*db.SummaryVisits, error) {
1191 visitors, err := me.visitUnique(opts)
1192 if err != nil {
1193 return nil, err
1194 }
1195
1196 urls, err := me.visitUrl(opts)
1197 if err != nil {
1198 return nil, err
1199 }
1200
1201 notFound, err := me.visitUrlNotFound(opts)
1202 if err != nil {
1203 return nil, err
1204 }
1205
1206 refs, err := me.visitReferer(opts)
1207 if err != nil {
1208 return nil, err
1209 }
1210
1211 return &db.SummaryVisits{
1212 Intervals: visitors,
1213 TopUrls: urls,
1214 TopReferers: refs,
1215 NotFoundUrls: notFound,
1216 }, nil
1217}
1218
1219func (me *PsqlDB) FindVisitSiteList(opts *db.SummaryOpts) ([]*db.VisitUrl, error) {
1220 return me.visitHost(opts)
1221}
1222
1223func (me *PsqlDB) FindUsers() ([]*db.User, error) {
1224 var users []*db.User
1225 rs, err := me.Db.Query(sqlSelectUsers)
1226 if err != nil {
1227 return users, err
1228 }
1229 for rs.Next() {
1230 var name sql.NullString
1231 user := &db.User{}
1232 err := rs.Scan(
1233 &user.ID,
1234 &name,
1235 &user.CreatedAt,
1236 )
1237 if err != nil {
1238 return users, err
1239 }
1240 user.Name = name.String
1241
1242 users = append(users, user)
1243 }
1244 if rs.Err() != nil {
1245 return users, rs.Err()
1246 }
1247 return users, nil
1248}
1249
1250func (me *PsqlDB) removeTagsForPost(tx *sql.Tx, postID string) error {
1251 _, err := tx.Exec(sqlRemoveTagsByPost, postID)
1252 return err
1253}
1254
1255func (me *PsqlDB) insertTagsForPost(tx *sql.Tx, tags []string, postID string) ([]string, error) {
1256 ids := make([]string, 0)
1257 for _, tag := range tags {
1258 id := ""
1259 err := tx.QueryRow(sqlInsertTag, postID, tag).Scan(&id)
1260 if err != nil {
1261 return nil, err
1262 }
1263 ids = append(ids, id)
1264 }
1265
1266 return ids, nil
1267}
1268
1269func (me *PsqlDB) ReplaceTagsForPost(tags []string, postID string) error {
1270 ctx := context.Background()
1271 tx, err := me.Db.BeginTx(ctx, nil)
1272 if err != nil {
1273 return err
1274 }
1275 defer func() {
1276 err = tx.Rollback()
1277 }()
1278
1279 err = me.removeTagsForPost(tx, postID)
1280 if err != nil {
1281 return err
1282 }
1283
1284 _, err = me.insertTagsForPost(tx, tags, postID)
1285 if err != nil {
1286 return err
1287 }
1288
1289 err = tx.Commit()
1290 return err
1291}
1292
1293func (me *PsqlDB) removeAliasesForPost(tx *sql.Tx, postID string) error {
1294 _, err := tx.Exec(sqlRemoveAliasesByPost, postID)
1295 return err
1296}
1297
1298func (me *PsqlDB) insertAliasesForPost(tx *sql.Tx, aliases []string, postID string) ([]string, error) {
1299 // hardcoded
1300 denyList := []string{
1301 "rss",
1302 "rss.xml",
1303 "atom.xml",
1304 "feed.xml",
1305 "smol.css",
1306 "main.css",
1307 "syntax.css",
1308 "card.png",
1309 "favicon-16x16.png",
1310 "favicon-32x32.png",
1311 "apple-touch-icon.png",
1312 "favicon.ico",
1313 "robots.txt",
1314 "atom",
1315 "blog/index.xml",
1316 }
1317
1318 ids := make([]string, 0)
1319 for _, alias := range aliases {
1320 if slices.Contains(denyList, alias) {
1321 me.Logger.Info(
1322 "name is in the deny list for aliases because it conflicts with a static route, skipping",
1323 "alias", alias,
1324 )
1325 continue
1326 }
1327 id := ""
1328 err := tx.QueryRow(sqlInsertAliases, postID, alias).Scan(&id)
1329 if err != nil {
1330 return nil, err
1331 }
1332 ids = append(ids, id)
1333 }
1334
1335 return ids, nil
1336}
1337
1338func (me *PsqlDB) ReplaceAliasesForPost(aliases []string, postID string) error {
1339 ctx := context.Background()
1340 tx, err := me.Db.BeginTx(ctx, nil)
1341 if err != nil {
1342 return err
1343 }
1344 defer func() {
1345 err = tx.Rollback()
1346 }()
1347
1348 err = me.removeAliasesForPost(tx, postID)
1349 if err != nil {
1350 return err
1351 }
1352
1353 _, err = me.insertAliasesForPost(tx, aliases, postID)
1354 if err != nil {
1355 return err
1356 }
1357
1358 err = tx.Commit()
1359 return err
1360}
1361
1362func (me *PsqlDB) FindUserPostsByTag(page *db.Pager, tag, userID, space string) (*db.Paginate[*db.Post], error) {
1363 var posts []*db.Post
1364 rs, err := me.Db.Query(
1365 sqlSelectUserPostsByTag,
1366 userID,
1367 tag,
1368 space,
1369 page.Num,
1370 page.Num*page.Page,
1371 )
1372 if err != nil {
1373 return nil, err
1374 }
1375 for rs.Next() {
1376 post, err := CreatePostFromRow(rs)
1377 if err != nil {
1378 return nil, err
1379 }
1380
1381 posts = append(posts, post)
1382 }
1383
1384 if rs.Err() != nil {
1385 return nil, rs.Err()
1386 }
1387
1388 var count int
1389 err = me.Db.QueryRow(sqlSelectPostCount, space).Scan(&count)
1390 if err != nil {
1391 return nil, err
1392 }
1393
1394 pager := &db.Paginate[*db.Post]{
1395 Data: posts,
1396 Total: int(math.Ceil(float64(count) / float64(page.Num))),
1397 }
1398 return pager, nil
1399}
1400
1401func (me *PsqlDB) FindPostsByTag(pager *db.Pager, tag, space string) (*db.Paginate[*db.Post], error) {
1402 rs, err := me.Db.Query(
1403 sqlSelectPostsByTag,
1404 pager.Num,
1405 pager.Num*pager.Page,
1406 tag,
1407 space,
1408 )
1409 if err != nil {
1410 return nil, err
1411 }
1412
1413 return me.postPager(rs, pager.Num, space, tag)
1414}
1415
1416func (me *PsqlDB) FindPopularTags(space string) ([]string, error) {
1417 tags := make([]string, 0)
1418 rs, err := me.Db.Query(sqlSelectPopularTags, space)
1419 if err != nil {
1420 return tags, err
1421 }
1422 for rs.Next() {
1423 name := ""
1424 tally := 0
1425 err := rs.Scan(&name, &tally)
1426 if err != nil {
1427 return tags, err
1428 }
1429
1430 tags = append(tags, name)
1431 }
1432 if rs.Err() != nil {
1433 return tags, rs.Err()
1434 }
1435 return tags, nil
1436}
1437
1438func (me *PsqlDB) FindTagsForPost(postID string) ([]string, error) {
1439 tags := make([]string, 0)
1440 rs, err := me.Db.Query(sqlSelectTagsForPost, postID)
1441 if err != nil {
1442 return tags, err
1443 }
1444
1445 for rs.Next() {
1446 name := ""
1447 err := rs.Scan(&name)
1448 if err != nil {
1449 return tags, err
1450 }
1451
1452 tags = append(tags, name)
1453 }
1454
1455 if rs.Err() != nil {
1456 return tags, rs.Err()
1457 }
1458
1459 return tags, nil
1460}
1461
1462func (me *PsqlDB) FindFeatureForUser(userID string, feature string) (*db.FeatureFlag, error) {
1463 ff := &db.FeatureFlag{}
1464 // payment history is allowed to be null
1465 // https://devtidbits.com/2020/08/03/go-sql-error-converting-null-to-string-is-unsupported/
1466 var paymentHistoryID sql.NullString
1467 err := me.Db.QueryRow(sqlSelectFeatureForUser, userID, feature).Scan(
1468 &ff.ID,
1469 &ff.UserID,
1470 &paymentHistoryID,
1471 &ff.Name,
1472 &ff.Data,
1473 &ff.CreatedAt,
1474 &ff.ExpiresAt,
1475 )
1476 if err != nil {
1477 return nil, err
1478 }
1479
1480 ff.PaymentHistoryID = paymentHistoryID.String
1481
1482 return ff, nil
1483}
1484
1485func (me *PsqlDB) FindFeaturesForUser(userID string) ([]*db.FeatureFlag, error) {
1486 var features []*db.FeatureFlag
1487 // https://stackoverflow.com/a/16920077
1488 query := `SELECT DISTINCT ON (name)
1489 id, user_id, payment_history_id, name, data, created_at, expires_at
1490 FROM feature_flags
1491 WHERE user_id=$1
1492 ORDER BY name, expires_at DESC;`
1493 rs, err := me.Db.Query(query, userID)
1494 if err != nil {
1495 return features, err
1496 }
1497 for rs.Next() {
1498 var paymentHistoryID sql.NullString
1499 ff := &db.FeatureFlag{}
1500 err := rs.Scan(
1501 &ff.ID,
1502 &ff.UserID,
1503 &paymentHistoryID,
1504 &ff.Name,
1505 &ff.Data,
1506 &ff.CreatedAt,
1507 &ff.ExpiresAt,
1508 )
1509 if err != nil {
1510 return features, err
1511 }
1512 ff.PaymentHistoryID = paymentHistoryID.String
1513
1514 features = append(features, ff)
1515 }
1516 if rs.Err() != nil {
1517 return features, rs.Err()
1518 }
1519 return features, nil
1520}
1521
1522func (me *PsqlDB) HasFeatureForUser(userID string, feature string) bool {
1523 ff, err := me.FindFeatureForUser(userID, feature)
1524 if err != nil {
1525 return false
1526 }
1527 return ff.IsValid()
1528}
1529
1530func (me *PsqlDB) FindTotalSizeForUser(userID string) (int, error) {
1531 var fileSize int
1532 err := me.Db.QueryRow(sqlSelectSizeForUser, userID).Scan(&fileSize)
1533 if err != nil {
1534 return 0, err
1535 }
1536 return fileSize, nil
1537}
1538
1539func (me *PsqlDB) InsertFeedItems(postID string, items []*db.FeedItem) error {
1540 ctx := context.Background()
1541 tx, err := me.Db.BeginTx(ctx, nil)
1542 if err != nil {
1543 return err
1544 }
1545 defer func() {
1546 err = tx.Rollback()
1547 }()
1548
1549 for _, item := range items {
1550 _, err := tx.Exec(
1551 sqlInsertFeedItems,
1552 item.PostID,
1553 item.GUID,
1554 item.Data,
1555 )
1556 if err != nil {
1557 return err
1558 }
1559 }
1560
1561 err = tx.Commit()
1562 return err
1563}
1564
1565func (me *PsqlDB) FindFeedItemsByPostID(postID string) ([]*db.FeedItem, error) {
1566 // sqlSelectFeedItemsByPost
1567 items := make([]*db.FeedItem, 0)
1568 rs, err := me.Db.Query(sqlSelectFeedItemsByPost, postID)
1569 if err != nil {
1570 return items, err
1571 }
1572
1573 for rs.Next() {
1574 item := &db.FeedItem{}
1575 err := rs.Scan(
1576 &item.ID,
1577 &item.PostID,
1578 &item.GUID,
1579 &item.Data,
1580 &item.CreatedAt,
1581 )
1582 if err != nil {
1583 return items, err
1584 }
1585
1586 items = append(items, item)
1587 }
1588
1589 if rs.Err() != nil {
1590 return items, rs.Err()
1591 }
1592
1593 return items, nil
1594}
1595
1596func (me *PsqlDB) InsertProject(userID, name, projectDir string) (string, error) {
1597 if !utils.IsValidSubdomain(name) {
1598 return "", fmt.Errorf("'%s' is not a valid project name, must match /^[a-z0-9-]+$/", name)
1599 }
1600
1601 var id string
1602 err := me.Db.QueryRow(sqlInsertProject, userID, name, projectDir).Scan(&id)
1603 if err != nil {
1604 return "", err
1605 }
1606 return id, nil
1607}
1608
1609func (me *PsqlDB) UpdateProject(userID, name string) error {
1610 _, err := me.Db.Exec(sqlUpdateProject, userID, name, time.Now())
1611 return err
1612}
1613
1614func (me *PsqlDB) UpdateProjectAcl(userID, name string, acl db.ProjectAcl) error {
1615 _, err := me.Db.Exec(sqlUpdateProjectAcl, userID, name, acl, time.Now())
1616 return err
1617}
1618
1619func (me *PsqlDB) LinkToProject(userID, projectID, projectDir string, commit bool) error {
1620 linkToProject, err := me.FindProjectByName(userID, projectDir)
1621 if err != nil {
1622 return err
1623 }
1624 isAlreadyLinked := linkToProject.Name != linkToProject.ProjectDir
1625 sameProject := linkToProject.ID == projectID
1626
1627 /*
1628 A project linked to another project which is also linked to a
1629 project is forbidden. CI/CD Example:
1630 - ProjectProd links to ProjectStaging
1631 - ProjectStaging links to ProjectMain
1632 - We merge `main` and trigger a deploy which uploads to ProjectMain
1633 - All three get updated immediately
1634 This scenario was not the intent of our CI/CD. What we actually
1635 wanted was to create a snapshot of ProjectMain and have ProjectStaging
1636 link to the snapshot, but that's not the intended design of pgs.
1637
1638 So we want to close that gap here.
1639
1640 We ensure that `project.Name` and `project.ProjectDir` are identical
1641 when there is no aliasing.
1642 */
1643 if !sameProject && isAlreadyLinked {
1644 return fmt.Errorf(
1645 "cannot link (%s) to (%s) because it is also a link to (%s)",
1646 projectID,
1647 projectDir,
1648 linkToProject.ProjectDir,
1649 )
1650 }
1651
1652 if commit {
1653 _, err = me.Db.Exec(
1654 sqlLinkToProject,
1655 projectDir,
1656 time.Now(),
1657 projectID,
1658 )
1659 }
1660 return err
1661}
1662
1663func (me *PsqlDB) RemoveProject(projectID string) error {
1664 _, err := me.Db.Exec(sqlRemoveProject, projectID)
1665 return err
1666}
1667
1668func (me *PsqlDB) FindProjectByName(userID, name string) (*db.Project, error) {
1669 project := &db.Project{}
1670 r := me.Db.QueryRow(sqlFindProjectByName, userID, name)
1671 err := r.Scan(
1672 &project.ID,
1673 &project.UserID,
1674 &project.Name,
1675 &project.ProjectDir,
1676 &project.Acl,
1677 &project.Blocked,
1678 &project.CreatedAt,
1679 &project.UpdatedAt,
1680 )
1681 if err != nil {
1682 return nil, err
1683 }
1684
1685 return project, nil
1686}
1687
1688func (me *PsqlDB) FindProjectLinks(userID, name string) ([]*db.Project, error) {
1689 var projects []*db.Project
1690 rs, err := me.Db.Query(sqlFindProjectLinks, userID, name)
1691 if err != nil {
1692 return nil, err
1693 }
1694 for rs.Next() {
1695 project := &db.Project{}
1696 err := rs.Scan(
1697 &project.ID,
1698 &project.UserID,
1699 &project.Name,
1700 &project.ProjectDir,
1701 &project.Acl,
1702 &project.Blocked,
1703 &project.CreatedAt,
1704 &project.UpdatedAt,
1705 )
1706 if err != nil {
1707 return nil, err
1708 }
1709
1710 projects = append(projects, project)
1711 }
1712
1713 if rs.Err() != nil {
1714 return nil, rs.Err()
1715 }
1716
1717 return projects, nil
1718}
1719
1720func (me *PsqlDB) FindProjectsByPrefix(userID, prefix string) ([]*db.Project, error) {
1721 var projects []*db.Project
1722 rs, err := me.Db.Query(sqlFindProjectsByPrefix, userID, prefix+"%")
1723 if err != nil {
1724 return nil, err
1725 }
1726 for rs.Next() {
1727 project := &db.Project{}
1728 err := rs.Scan(
1729 &project.ID,
1730 &project.UserID,
1731 &project.Name,
1732 &project.ProjectDir,
1733 &project.Acl,
1734 &project.Blocked,
1735 &project.CreatedAt,
1736 &project.UpdatedAt,
1737 )
1738 if err != nil {
1739 return nil, err
1740 }
1741
1742 projects = append(projects, project)
1743 }
1744
1745 if rs.Err() != nil {
1746 return nil, rs.Err()
1747 }
1748
1749 return projects, nil
1750}
1751
1752func (me *PsqlDB) FindProjectsByUser(userID string) ([]*db.Project, error) {
1753 var projects []*db.Project
1754 rs, err := me.Db.Query(sqlFindProjectsByUser, userID)
1755 if err != nil {
1756 return nil, err
1757 }
1758 for rs.Next() {
1759 project := &db.Project{}
1760 err := rs.Scan(
1761 &project.ID,
1762 &project.UserID,
1763 &project.Name,
1764 &project.ProjectDir,
1765 &project.Acl,
1766 &project.Blocked,
1767 &project.CreatedAt,
1768 &project.UpdatedAt,
1769 )
1770 if err != nil {
1771 return nil, err
1772 }
1773
1774 projects = append(projects, project)
1775 }
1776
1777 if rs.Err() != nil {
1778 return nil, rs.Err()
1779 }
1780
1781 return projects, nil
1782}
1783
1784func (me *PsqlDB) FindAllProjects(page *db.Pager, by string) (*db.Paginate[*db.Project], error) {
1785 var projects []*db.Project
1786 sqlFindAllProjects := fmt.Sprintf(`
1787 SELECT projects.id, user_id, app_users.name as username, projects.name, project_dir, projects.acl, projects.blocked, projects.created_at, projects.updated_at
1788 FROM projects
1789 LEFT JOIN app_users ON app_users.id = projects.user_id
1790 ORDER BY %s DESC
1791 LIMIT $1 OFFSET $2`, by)
1792 rs, err := me.Db.Query(sqlFindAllProjects, page.Num, page.Num*page.Page)
1793 if err != nil {
1794 return nil, err
1795 }
1796 for rs.Next() {
1797 project := &db.Project{}
1798 err := rs.Scan(
1799 &project.ID,
1800 &project.UserID,
1801 &project.Username,
1802 &project.Name,
1803 &project.ProjectDir,
1804 &project.Acl,
1805 &project.Blocked,
1806 &project.CreatedAt,
1807 &project.UpdatedAt,
1808 )
1809 if err != nil {
1810 return nil, err
1811 }
1812
1813 projects = append(projects, project)
1814 }
1815
1816 if rs.Err() != nil {
1817 return nil, rs.Err()
1818 }
1819
1820 var count int
1821 err = me.Db.QueryRow(sqlSelectProjectCount).Scan(&count)
1822 if err != nil {
1823 return nil, err
1824 }
1825
1826 pager := &db.Paginate[*db.Project]{
1827 Data: projects,
1828 Total: int(math.Ceil(float64(count) / float64(page.Num))),
1829 }
1830
1831 return pager, nil
1832}
1833
1834func (me *PsqlDB) InsertToken(userID, name string) (string, error) {
1835 var token string
1836 err := me.Db.QueryRow(sqlInsertToken, userID, name).Scan(&token)
1837 if err != nil {
1838 return "", err
1839 }
1840 return token, nil
1841}
1842
1843func (me *PsqlDB) UpsertToken(userID, name string) (string, error) {
1844 token, _ := me.FindTokenByName(userID, name)
1845 if token != "" {
1846 return token, nil
1847 }
1848
1849 token, err := me.InsertToken(userID, name)
1850 return token, err
1851}
1852
1853func (me *PsqlDB) FindTokenByName(userID, name string) (string, error) {
1854 var token string
1855 err := me.Db.QueryRow(sqlSelectTokenByNameForUser, userID, name).Scan(&token)
1856 if err != nil {
1857 return "", err
1858 }
1859 return token, nil
1860}
1861
1862func (me *PsqlDB) RemoveToken(tokenID string) error {
1863 _, err := me.Db.Exec(sqlRemoveToken, tokenID)
1864 return err
1865}
1866
1867func (me *PsqlDB) FindTokensForUser(userID string) ([]*db.Token, error) {
1868 var keys []*db.Token
1869 rs, err := me.Db.Query(sqlSelectTokensForUser, userID)
1870 if err != nil {
1871 return keys, err
1872 }
1873 for rs.Next() {
1874 pk := &db.Token{}
1875 err := rs.Scan(&pk.ID, &pk.UserID, &pk.Name, &pk.CreatedAt, &pk.ExpiresAt)
1876 if err != nil {
1877 return keys, err
1878 }
1879
1880 keys = append(keys, pk)
1881 }
1882 if rs.Err() != nil {
1883 return keys, rs.Err()
1884 }
1885 return keys, nil
1886}
1887
1888func (me *PsqlDB) InsertFeature(userID, name string, expiresAt time.Time) (*db.FeatureFlag, error) {
1889 var featureID string
1890 err := me.Db.QueryRow(
1891 `INSERT INTO feature_flags (user_id, name, expires_at) VALUES ($1, $2, $3) RETURNING id;`,
1892 userID,
1893 name,
1894 expiresAt,
1895 ).Scan(&featureID)
1896 if err != nil {
1897 return nil, err
1898 }
1899
1900 feature, err := me.FindFeatureForUser(userID, name)
1901 if err != nil {
1902 return nil, err
1903 }
1904
1905 return feature, nil
1906}
1907
1908func (me *PsqlDB) RemoveFeature(userID string, name string) error {
1909 _, err := me.Db.Exec(`DELETE FROM feature_flags WHERE user_id = $1 AND name = $2`, userID, name)
1910 return err
1911}
1912
1913func (me *PsqlDB) createFeatureExpiresAt(userID, name string) time.Time {
1914 ff, _ := me.FindFeatureForUser(userID, name)
1915 if ff == nil {
1916 t := time.Now()
1917 return t.AddDate(1, 0, 0)
1918 }
1919 return ff.ExpiresAt.AddDate(1, 0, 0)
1920}
1921
1922func (me *PsqlDB) AddPicoPlusUser(username, email, paymentType, txId string) error {
1923 user, err := me.FindUserForName(username)
1924 if err != nil {
1925 return err
1926 }
1927
1928 ctx := context.Background()
1929 tx, err := me.Db.BeginTx(ctx, nil)
1930 if err != nil {
1931 return err
1932 }
1933 defer func() {
1934 err = tx.Rollback()
1935 }()
1936
1937 var paymentHistoryId sql.NullString
1938 if paymentType != "" {
1939 data := db.PaymentHistoryData{
1940 Notes: "",
1941 TxID: txId,
1942 }
1943
1944 err := tx.QueryRow(
1945 `INSERT INTO payment_history (user_id, payment_type, amount, data) VALUES ($1, $2, 24 * 1000000, $3) RETURNING id;`,
1946 user.ID,
1947 paymentType,
1948 data,
1949 ).Scan(&paymentHistoryId)
1950 if err != nil {
1951 return err
1952 }
1953 }
1954
1955 plus := me.createFeatureExpiresAt(user.ID, "plus")
1956 plusQuery := fmt.Sprintf(`INSERT INTO feature_flags (user_id, name, data, expires_at, payment_history_id)
1957 VALUES ($1, 'plus', '{"storage_max":10000000000, "file_max":50000000, "email": "%s"}'::jsonb, $2, $3);`, email)
1958 _, err = tx.Exec(plusQuery, user.ID, plus, paymentHistoryId)
1959 if err != nil {
1960 return err
1961 }
1962
1963 return tx.Commit()
1964}