feat: add DB provider methods for categories

Implements SaveCategory, SelectCategory, SelectCategoriesOfSite,
SelectCategoryBySlugAndSite, DeleteCategory, SelectCategoriesOfPost,
SelectPostsOfCategory, CountPostsOfCategory, and SetPostCategories on
the DB Provider, along with BeginTx/QueriesWithTx for transaction
support. Also fixes pre-existing compilation errors in provider_test.go
(missing PagingParams args) so new tests can compile and run.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
Leon Mika 2026-03-18 21:37:01 +11:00
parent d47095a902
commit 15bc6b7f73
3 changed files with 305 additions and 3 deletions

132
providers/db/categories.go Normal file
View file

@ -0,0 +1,132 @@
package db
import (
"context"
"time"
"lmika.dev/lmika/weiro/models"
"lmika.dev/lmika/weiro/providers/db/gen/sqlgen"
)
func (db *Provider) SelectCategoriesOfSite(ctx context.Context, siteID int64) ([]*models.Category, error) {
rows, err := db.queries.SelectCategoriesOfSite(ctx, siteID)
if err != nil {
return nil, err
}
cats := make([]*models.Category, len(rows))
for i, row := range rows {
cats[i] = dbCategoryToCategory(row)
}
return cats, nil
}
func (db *Provider) SelectCategory(ctx context.Context, id int64) (*models.Category, error) {
row, err := db.queries.SelectCategory(ctx, id)
if err != nil {
return nil, err
}
return dbCategoryToCategory(row), nil
}
func (db *Provider) SelectCategoryBySlugAndSite(ctx context.Context, siteID int64, slug string) (*models.Category, error) {
row, err := db.queries.SelectCategoryBySlugAndSite(ctx, sqlgen.SelectCategoryBySlugAndSiteParams{
SiteID: siteID,
Slug: slug,
})
if err != nil {
return nil, err
}
return dbCategoryToCategory(row), nil
}
func (db *Provider) SaveCategory(ctx context.Context, cat *models.Category) error {
if cat.ID == 0 {
newID, err := db.queries.InsertCategory(ctx, sqlgen.InsertCategoryParams{
SiteID: cat.SiteID,
Guid: cat.GUID,
Name: cat.Name,
Slug: cat.Slug,
Description: cat.Description,
CreatedAt: timeToInt(cat.CreatedAt),
UpdatedAt: timeToInt(cat.UpdatedAt),
})
if err != nil {
return err
}
cat.ID = newID
return nil
}
return db.queries.UpdateCategory(ctx, sqlgen.UpdateCategoryParams{
ID: cat.ID,
Name: cat.Name,
Slug: cat.Slug,
Description: cat.Description,
UpdatedAt: timeToInt(cat.UpdatedAt),
})
}
func (db *Provider) DeleteCategory(ctx context.Context, id int64) error {
return db.queries.DeleteCategory(ctx, id)
}
func (db *Provider) SelectCategoriesOfPost(ctx context.Context, postID int64) ([]*models.Category, error) {
rows, err := db.queries.SelectCategoriesOfPost(ctx, postID)
if err != nil {
return nil, err
}
cats := make([]*models.Category, len(rows))
for i, row := range rows {
cats[i] = dbCategoryToCategory(row)
}
return cats, nil
}
func (db *Provider) SelectPostsOfCategory(ctx context.Context, categoryID int64, pp PagingParams) ([]*models.Post, error) {
rows, err := db.queries.SelectPostsOfCategory(ctx, sqlgen.SelectPostsOfCategoryParams{
CategoryID: categoryID,
Limit: pp.Limit,
Offset: pp.Offset,
})
if err != nil {
return nil, err
}
posts := make([]*models.Post, len(rows))
for i, row := range rows {
posts[i] = dbPostToPost(row)
}
return posts, nil
}
func (db *Provider) CountPostsOfCategory(ctx context.Context, categoryID int64) (int64, error) {
return db.queries.CountPostsOfCategory(ctx, categoryID)
}
// SetPostCategories replaces all category associations for a post.
func (db *Provider) SetPostCategories(ctx context.Context, postID int64, categoryIDs []int64) error {
if err := db.queries.DeletePostCategoriesByPost(ctx, postID); err != nil {
return err
}
for _, catID := range categoryIDs {
if err := db.queries.InsertPostCategory(ctx, sqlgen.InsertPostCategoryParams{
PostID: postID,
CategoryID: catID,
}); err != nil {
return err
}
}
return nil
}
func dbCategoryToCategory(row sqlgen.Category) *models.Category {
return &models.Category{
ID: row.ID,
SiteID: row.SiteID,
GUID: row.Guid,
Name: row.Name,
Slug: row.Slug,
Description: row.Description,
CreatedAt: time.Unix(row.CreatedAt, 0).UTC(),
UpdatedAt: time.Unix(row.UpdatedAt, 0).UTC(),
}
}

View file

@ -40,6 +40,17 @@ func (db *Provider) Close() error {
return db.drvr.Close()
}
func (db *Provider) BeginTx(ctx context.Context) (*sql.Tx, error) {
return db.drvr.BeginTx(ctx, nil)
}
func (db *Provider) QueriesWithTx(tx *sql.Tx) *Provider {
return &Provider{
drvr: db.drvr,
queries: db.queries.WithTx(tx),
}
}
func (db *Provider) SoftDeletePost(ctx context.Context, postID int64) error {
return db.queries.SoftDeletePost(ctx, sqlgen.SoftDeletePostParams{
DeletedAt: time.Now().Unix(),

View file

@ -158,7 +158,7 @@ func TestProvider_Posts(t *testing.T) {
require.NoError(t, err)
assert.NotZero(t, post.ID)
posts, err := p.SelectPostsOfSite(ctx, site.ID, false)
posts, err := p.SelectPostsOfSite(ctx, site.ID, false, db.PagingParams{})
require.NoError(t, err)
require.Len(t, posts, 1)
assert.Equal(t, post.ID, posts[0].ID)
@ -205,7 +205,7 @@ func TestProvider_Posts(t *testing.T) {
require.NoError(t, p.SavePost(ctx, post1))
require.NoError(t, p.SavePost(ctx, post2))
posts, err := p.SelectPostsOfSite(ctx, site2.ID, false)
posts, err := p.SelectPostsOfSite(ctx, site2.ID, false, db.PagingParams{})
require.NoError(t, err)
require.Len(t, posts, 2)
assert.Equal(t, "New Post", posts[0].Title)
@ -220,7 +220,7 @@ func TestProvider_Posts(t *testing.T) {
}
require.NoError(t, p.SaveSite(ctx, emptySite))
posts, err := p.SelectPostsOfSite(ctx, emptySite.ID, false)
posts, err := p.SelectPostsOfSite(ctx, emptySite.ID, false, db.PagingParams{})
require.NoError(t, err)
assert.Empty(t, posts)
})
@ -283,6 +283,165 @@ func TestProvider_PublishTargets(t *testing.T) {
})
}
func TestProvider_Categories(t *testing.T) {
ctx := context.Background()
p := newTestDB(t)
user := &models.User{Username: "testuser", PasswordHashed: []byte("password")}
require.NoError(t, p.SaveUser(ctx, user))
site := &models.Site{OwnerID: user.ID, Title: "My Blog", Tagline: "test"}
require.NoError(t, p.SaveSite(ctx, site))
t.Run("save and select categories", func(t *testing.T) {
now := time.Date(2026, 3, 18, 12, 0, 0, 0, time.UTC)
cat := &models.Category{
SiteID: site.ID,
GUID: "cat-001",
Name: "Go Programming",
Slug: "go-programming",
Description: "Posts about Go",
CreatedAt: now,
UpdatedAt: now,
}
err := p.SaveCategory(ctx, cat)
require.NoError(t, err)
assert.NotZero(t, cat.ID)
cats, err := p.SelectCategoriesOfSite(ctx, site.ID)
require.NoError(t, err)
require.Len(t, cats, 1)
assert.Equal(t, "Go Programming", cats[0].Name)
assert.Equal(t, "go-programming", cats[0].Slug)
assert.Equal(t, "Posts about Go", cats[0].Description)
})
t.Run("update category", func(t *testing.T) {
now := time.Date(2026, 3, 18, 12, 0, 0, 0, time.UTC)
cat := &models.Category{
SiteID: site.ID,
GUID: "cat-002",
Name: "Original",
Slug: "original",
CreatedAt: now,
UpdatedAt: now,
}
require.NoError(t, p.SaveCategory(ctx, cat))
cat.Name = "Updated"
cat.Slug = "updated"
cat.UpdatedAt = now.Add(time.Hour)
require.NoError(t, p.SaveCategory(ctx, cat))
got, err := p.SelectCategory(ctx, cat.ID)
require.NoError(t, err)
assert.Equal(t, "Updated", got.Name)
assert.Equal(t, "updated", got.Slug)
})
t.Run("delete category", func(t *testing.T) {
now := time.Date(2026, 3, 18, 12, 0, 0, 0, time.UTC)
cat := &models.Category{
SiteID: site.ID,
GUID: "cat-003",
Name: "ToDelete",
Slug: "to-delete",
CreatedAt: now,
UpdatedAt: now,
}
require.NoError(t, p.SaveCategory(ctx, cat))
err := p.DeleteCategory(ctx, cat.ID)
require.NoError(t, err)
_, err = p.SelectCategory(ctx, cat.ID)
assert.Error(t, err)
})
}
func TestProvider_PostCategories(t *testing.T) {
ctx := context.Background()
p := newTestDB(t)
user := &models.User{Username: "testuser", PasswordHashed: []byte("password")}
require.NoError(t, p.SaveUser(ctx, user))
site := &models.Site{OwnerID: user.ID, Title: "My Blog", Tagline: "test"}
require.NoError(t, p.SaveSite(ctx, site))
now := time.Date(2026, 3, 18, 12, 0, 0, 0, time.UTC)
post := &models.Post{
SiteID: site.ID,
GUID: "post-pc-001",
Title: "Test Post",
Body: "body",
Slug: "/test",
CreatedAt: now,
}
require.NoError(t, p.SavePost(ctx, post))
cat1 := &models.Category{SiteID: site.ID, GUID: "cat-pc-1", Name: "Alpha", Slug: "alpha", CreatedAt: now, UpdatedAt: now}
cat2 := &models.Category{SiteID: site.ID, GUID: "cat-pc-2", Name: "Beta", Slug: "beta", CreatedAt: now, UpdatedAt: now}
require.NoError(t, p.SaveCategory(ctx, cat1))
require.NoError(t, p.SaveCategory(ctx, cat2))
t.Run("set and get post categories", func(t *testing.T) {
err := p.SetPostCategories(ctx, post.ID, []int64{cat1.ID, cat2.ID})
require.NoError(t, err)
cats, err := p.SelectCategoriesOfPost(ctx, post.ID)
require.NoError(t, err)
require.Len(t, cats, 2)
assert.Equal(t, "Alpha", cats[0].Name)
assert.Equal(t, "Beta", cats[1].Name)
})
t.Run("replace post categories", func(t *testing.T) {
err := p.SetPostCategories(ctx, post.ID, []int64{cat2.ID})
require.NoError(t, err)
cats, err := p.SelectCategoriesOfPost(ctx, post.ID)
require.NoError(t, err)
require.Len(t, cats, 1)
assert.Equal(t, "Beta", cats[0].Name)
})
t.Run("clear post categories", func(t *testing.T) {
err := p.SetPostCategories(ctx, post.ID, []int64{})
require.NoError(t, err)
cats, err := p.SelectCategoriesOfPost(ctx, post.ID)
require.NoError(t, err)
assert.Empty(t, cats)
})
t.Run("count posts of category", func(t *testing.T) {
post.State = models.StatePublished
post.PublishedAt = now
require.NoError(t, p.SavePost(ctx, post))
require.NoError(t, p.SetPostCategories(ctx, post.ID, []int64{cat1.ID}))
count, err := p.CountPostsOfCategory(ctx, cat1.ID)
require.NoError(t, err)
assert.Equal(t, int64(1), count)
count, err = p.CountPostsOfCategory(ctx, cat2.ID)
require.NoError(t, err)
assert.Equal(t, int64(0), count)
})
t.Run("cascade delete category removes associations", func(t *testing.T) {
require.NoError(t, p.SetPostCategories(ctx, post.ID, []int64{cat1.ID, cat2.ID}))
require.NoError(t, p.DeleteCategory(ctx, cat1.ID))
cats, err := p.SelectCategoriesOfPost(ctx, post.ID)
require.NoError(t, err)
require.Len(t, cats, 1)
assert.Equal(t, "Beta", cats[0].Name)
})
}
// Verify that password encoding roundtrips correctly through base64
func TestProvider_UserPasswordEncoding(t *testing.T) {
ctx := context.Background()