feat: add DB provider methods for categories
Implements SaveCategory, SelectCategory, SelectCategoriesOfSite, SelectCategoryBySlugAndSite, DeleteCategory, SelectCategoriesOfPost, SelectPostsOfCategory, CountPostsOfCategory, and SetPostCategories on the DB Provider, along with BeginTx/QueriesWithTx for transaction support. Also fixes pre-existing compilation errors in provider_test.go (missing PagingParams args) so new tests can compile and run. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
parent
d47095a902
commit
15bc6b7f73
132
providers/db/categories.go
Normal file
132
providers/db/categories.go
Normal file
|
|
@ -0,0 +1,132 @@
|
|||
package db
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"lmika.dev/lmika/weiro/models"
|
||||
"lmika.dev/lmika/weiro/providers/db/gen/sqlgen"
|
||||
)
|
||||
|
||||
func (db *Provider) SelectCategoriesOfSite(ctx context.Context, siteID int64) ([]*models.Category, error) {
|
||||
rows, err := db.queries.SelectCategoriesOfSite(ctx, siteID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
cats := make([]*models.Category, len(rows))
|
||||
for i, row := range rows {
|
||||
cats[i] = dbCategoryToCategory(row)
|
||||
}
|
||||
return cats, nil
|
||||
}
|
||||
|
||||
func (db *Provider) SelectCategory(ctx context.Context, id int64) (*models.Category, error) {
|
||||
row, err := db.queries.SelectCategory(ctx, id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return dbCategoryToCategory(row), nil
|
||||
}
|
||||
|
||||
func (db *Provider) SelectCategoryBySlugAndSite(ctx context.Context, siteID int64, slug string) (*models.Category, error) {
|
||||
row, err := db.queries.SelectCategoryBySlugAndSite(ctx, sqlgen.SelectCategoryBySlugAndSiteParams{
|
||||
SiteID: siteID,
|
||||
Slug: slug,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return dbCategoryToCategory(row), nil
|
||||
}
|
||||
|
||||
func (db *Provider) SaveCategory(ctx context.Context, cat *models.Category) error {
|
||||
if cat.ID == 0 {
|
||||
newID, err := db.queries.InsertCategory(ctx, sqlgen.InsertCategoryParams{
|
||||
SiteID: cat.SiteID,
|
||||
Guid: cat.GUID,
|
||||
Name: cat.Name,
|
||||
Slug: cat.Slug,
|
||||
Description: cat.Description,
|
||||
CreatedAt: timeToInt(cat.CreatedAt),
|
||||
UpdatedAt: timeToInt(cat.UpdatedAt),
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
cat.ID = newID
|
||||
return nil
|
||||
}
|
||||
|
||||
return db.queries.UpdateCategory(ctx, sqlgen.UpdateCategoryParams{
|
||||
ID: cat.ID,
|
||||
Name: cat.Name,
|
||||
Slug: cat.Slug,
|
||||
Description: cat.Description,
|
||||
UpdatedAt: timeToInt(cat.UpdatedAt),
|
||||
})
|
||||
}
|
||||
|
||||
func (db *Provider) DeleteCategory(ctx context.Context, id int64) error {
|
||||
return db.queries.DeleteCategory(ctx, id)
|
||||
}
|
||||
|
||||
func (db *Provider) SelectCategoriesOfPost(ctx context.Context, postID int64) ([]*models.Category, error) {
|
||||
rows, err := db.queries.SelectCategoriesOfPost(ctx, postID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
cats := make([]*models.Category, len(rows))
|
||||
for i, row := range rows {
|
||||
cats[i] = dbCategoryToCategory(row)
|
||||
}
|
||||
return cats, nil
|
||||
}
|
||||
|
||||
func (db *Provider) SelectPostsOfCategory(ctx context.Context, categoryID int64, pp PagingParams) ([]*models.Post, error) {
|
||||
rows, err := db.queries.SelectPostsOfCategory(ctx, sqlgen.SelectPostsOfCategoryParams{
|
||||
CategoryID: categoryID,
|
||||
Limit: pp.Limit,
|
||||
Offset: pp.Offset,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
posts := make([]*models.Post, len(rows))
|
||||
for i, row := range rows {
|
||||
posts[i] = dbPostToPost(row)
|
||||
}
|
||||
return posts, nil
|
||||
}
|
||||
|
||||
func (db *Provider) CountPostsOfCategory(ctx context.Context, categoryID int64) (int64, error) {
|
||||
return db.queries.CountPostsOfCategory(ctx, categoryID)
|
||||
}
|
||||
|
||||
// SetPostCategories replaces all category associations for a post.
|
||||
func (db *Provider) SetPostCategories(ctx context.Context, postID int64, categoryIDs []int64) error {
|
||||
if err := db.queries.DeletePostCategoriesByPost(ctx, postID); err != nil {
|
||||
return err
|
||||
}
|
||||
for _, catID := range categoryIDs {
|
||||
if err := db.queries.InsertPostCategory(ctx, sqlgen.InsertPostCategoryParams{
|
||||
PostID: postID,
|
||||
CategoryID: catID,
|
||||
}); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func dbCategoryToCategory(row sqlgen.Category) *models.Category {
|
||||
return &models.Category{
|
||||
ID: row.ID,
|
||||
SiteID: row.SiteID,
|
||||
GUID: row.Guid,
|
||||
Name: row.Name,
|
||||
Slug: row.Slug,
|
||||
Description: row.Description,
|
||||
CreatedAt: time.Unix(row.CreatedAt, 0).UTC(),
|
||||
UpdatedAt: time.Unix(row.UpdatedAt, 0).UTC(),
|
||||
}
|
||||
}
|
||||
|
|
@ -40,6 +40,17 @@ func (db *Provider) Close() error {
|
|||
return db.drvr.Close()
|
||||
}
|
||||
|
||||
func (db *Provider) BeginTx(ctx context.Context) (*sql.Tx, error) {
|
||||
return db.drvr.BeginTx(ctx, nil)
|
||||
}
|
||||
|
||||
func (db *Provider) QueriesWithTx(tx *sql.Tx) *Provider {
|
||||
return &Provider{
|
||||
drvr: db.drvr,
|
||||
queries: db.queries.WithTx(tx),
|
||||
}
|
||||
}
|
||||
|
||||
func (db *Provider) SoftDeletePost(ctx context.Context, postID int64) error {
|
||||
return db.queries.SoftDeletePost(ctx, sqlgen.SoftDeletePostParams{
|
||||
DeletedAt: time.Now().Unix(),
|
||||
|
|
|
|||
|
|
@ -158,7 +158,7 @@ func TestProvider_Posts(t *testing.T) {
|
|||
require.NoError(t, err)
|
||||
assert.NotZero(t, post.ID)
|
||||
|
||||
posts, err := p.SelectPostsOfSite(ctx, site.ID, false)
|
||||
posts, err := p.SelectPostsOfSite(ctx, site.ID, false, db.PagingParams{})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, posts, 1)
|
||||
assert.Equal(t, post.ID, posts[0].ID)
|
||||
|
|
@ -205,7 +205,7 @@ func TestProvider_Posts(t *testing.T) {
|
|||
require.NoError(t, p.SavePost(ctx, post1))
|
||||
require.NoError(t, p.SavePost(ctx, post2))
|
||||
|
||||
posts, err := p.SelectPostsOfSite(ctx, site2.ID, false)
|
||||
posts, err := p.SelectPostsOfSite(ctx, site2.ID, false, db.PagingParams{})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, posts, 2)
|
||||
assert.Equal(t, "New Post", posts[0].Title)
|
||||
|
|
@ -220,7 +220,7 @@ func TestProvider_Posts(t *testing.T) {
|
|||
}
|
||||
require.NoError(t, p.SaveSite(ctx, emptySite))
|
||||
|
||||
posts, err := p.SelectPostsOfSite(ctx, emptySite.ID, false)
|
||||
posts, err := p.SelectPostsOfSite(ctx, emptySite.ID, false, db.PagingParams{})
|
||||
require.NoError(t, err)
|
||||
assert.Empty(t, posts)
|
||||
})
|
||||
|
|
@ -283,6 +283,165 @@ func TestProvider_PublishTargets(t *testing.T) {
|
|||
})
|
||||
}
|
||||
|
||||
func TestProvider_Categories(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
p := newTestDB(t)
|
||||
|
||||
user := &models.User{Username: "testuser", PasswordHashed: []byte("password")}
|
||||
require.NoError(t, p.SaveUser(ctx, user))
|
||||
|
||||
site := &models.Site{OwnerID: user.ID, Title: "My Blog", Tagline: "test"}
|
||||
require.NoError(t, p.SaveSite(ctx, site))
|
||||
|
||||
t.Run("save and select categories", func(t *testing.T) {
|
||||
now := time.Date(2026, 3, 18, 12, 0, 0, 0, time.UTC)
|
||||
cat := &models.Category{
|
||||
SiteID: site.ID,
|
||||
GUID: "cat-001",
|
||||
Name: "Go Programming",
|
||||
Slug: "go-programming",
|
||||
Description: "Posts about Go",
|
||||
CreatedAt: now,
|
||||
UpdatedAt: now,
|
||||
}
|
||||
|
||||
err := p.SaveCategory(ctx, cat)
|
||||
require.NoError(t, err)
|
||||
assert.NotZero(t, cat.ID)
|
||||
|
||||
cats, err := p.SelectCategoriesOfSite(ctx, site.ID)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, cats, 1)
|
||||
assert.Equal(t, "Go Programming", cats[0].Name)
|
||||
assert.Equal(t, "go-programming", cats[0].Slug)
|
||||
assert.Equal(t, "Posts about Go", cats[0].Description)
|
||||
})
|
||||
|
||||
t.Run("update category", func(t *testing.T) {
|
||||
now := time.Date(2026, 3, 18, 12, 0, 0, 0, time.UTC)
|
||||
cat := &models.Category{
|
||||
SiteID: site.ID,
|
||||
GUID: "cat-002",
|
||||
Name: "Original",
|
||||
Slug: "original",
|
||||
CreatedAt: now,
|
||||
UpdatedAt: now,
|
||||
}
|
||||
require.NoError(t, p.SaveCategory(ctx, cat))
|
||||
|
||||
cat.Name = "Updated"
|
||||
cat.Slug = "updated"
|
||||
cat.UpdatedAt = now.Add(time.Hour)
|
||||
require.NoError(t, p.SaveCategory(ctx, cat))
|
||||
|
||||
got, err := p.SelectCategory(ctx, cat.ID)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "Updated", got.Name)
|
||||
assert.Equal(t, "updated", got.Slug)
|
||||
})
|
||||
|
||||
t.Run("delete category", func(t *testing.T) {
|
||||
now := time.Date(2026, 3, 18, 12, 0, 0, 0, time.UTC)
|
||||
cat := &models.Category{
|
||||
SiteID: site.ID,
|
||||
GUID: "cat-003",
|
||||
Name: "ToDelete",
|
||||
Slug: "to-delete",
|
||||
CreatedAt: now,
|
||||
UpdatedAt: now,
|
||||
}
|
||||
require.NoError(t, p.SaveCategory(ctx, cat))
|
||||
|
||||
err := p.DeleteCategory(ctx, cat.ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = p.SelectCategory(ctx, cat.ID)
|
||||
assert.Error(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
func TestProvider_PostCategories(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
p := newTestDB(t)
|
||||
|
||||
user := &models.User{Username: "testuser", PasswordHashed: []byte("password")}
|
||||
require.NoError(t, p.SaveUser(ctx, user))
|
||||
|
||||
site := &models.Site{OwnerID: user.ID, Title: "My Blog", Tagline: "test"}
|
||||
require.NoError(t, p.SaveSite(ctx, site))
|
||||
|
||||
now := time.Date(2026, 3, 18, 12, 0, 0, 0, time.UTC)
|
||||
post := &models.Post{
|
||||
SiteID: site.ID,
|
||||
GUID: "post-pc-001",
|
||||
Title: "Test Post",
|
||||
Body: "body",
|
||||
Slug: "/test",
|
||||
CreatedAt: now,
|
||||
}
|
||||
require.NoError(t, p.SavePost(ctx, post))
|
||||
|
||||
cat1 := &models.Category{SiteID: site.ID, GUID: "cat-pc-1", Name: "Alpha", Slug: "alpha", CreatedAt: now, UpdatedAt: now}
|
||||
cat2 := &models.Category{SiteID: site.ID, GUID: "cat-pc-2", Name: "Beta", Slug: "beta", CreatedAt: now, UpdatedAt: now}
|
||||
require.NoError(t, p.SaveCategory(ctx, cat1))
|
||||
require.NoError(t, p.SaveCategory(ctx, cat2))
|
||||
|
||||
t.Run("set and get post categories", func(t *testing.T) {
|
||||
err := p.SetPostCategories(ctx, post.ID, []int64{cat1.ID, cat2.ID})
|
||||
require.NoError(t, err)
|
||||
|
||||
cats, err := p.SelectCategoriesOfPost(ctx, post.ID)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, cats, 2)
|
||||
assert.Equal(t, "Alpha", cats[0].Name)
|
||||
assert.Equal(t, "Beta", cats[1].Name)
|
||||
})
|
||||
|
||||
t.Run("replace post categories", func(t *testing.T) {
|
||||
err := p.SetPostCategories(ctx, post.ID, []int64{cat2.ID})
|
||||
require.NoError(t, err)
|
||||
|
||||
cats, err := p.SelectCategoriesOfPost(ctx, post.ID)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, cats, 1)
|
||||
assert.Equal(t, "Beta", cats[0].Name)
|
||||
})
|
||||
|
||||
t.Run("clear post categories", func(t *testing.T) {
|
||||
err := p.SetPostCategories(ctx, post.ID, []int64{})
|
||||
require.NoError(t, err)
|
||||
|
||||
cats, err := p.SelectCategoriesOfPost(ctx, post.ID)
|
||||
require.NoError(t, err)
|
||||
assert.Empty(t, cats)
|
||||
})
|
||||
|
||||
t.Run("count posts of category", func(t *testing.T) {
|
||||
post.State = models.StatePublished
|
||||
post.PublishedAt = now
|
||||
require.NoError(t, p.SavePost(ctx, post))
|
||||
require.NoError(t, p.SetPostCategories(ctx, post.ID, []int64{cat1.ID}))
|
||||
|
||||
count, err := p.CountPostsOfCategory(ctx, cat1.ID)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, int64(1), count)
|
||||
|
||||
count, err = p.CountPostsOfCategory(ctx, cat2.ID)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, int64(0), count)
|
||||
})
|
||||
|
||||
t.Run("cascade delete category removes associations", func(t *testing.T) {
|
||||
require.NoError(t, p.SetPostCategories(ctx, post.ID, []int64{cat1.ID, cat2.ID}))
|
||||
require.NoError(t, p.DeleteCategory(ctx, cat1.ID))
|
||||
|
||||
cats, err := p.SelectCategoriesOfPost(ctx, post.ID)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, cats, 1)
|
||||
assert.Equal(t, "Beta", cats[0].Name)
|
||||
})
|
||||
}
|
||||
|
||||
// Verify that password encoding roundtrips correctly through base64
|
||||
func TestProvider_UserPasswordEncoding(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
|
|
|||
Loading…
Reference in a new issue