diff --git a/providers/db/categories.go b/providers/db/categories.go new file mode 100644 index 0000000..72fac94 --- /dev/null +++ b/providers/db/categories.go @@ -0,0 +1,132 @@ +package db + +import ( + "context" + "time" + + "lmika.dev/lmika/weiro/models" + "lmika.dev/lmika/weiro/providers/db/gen/sqlgen" +) + +func (db *Provider) SelectCategoriesOfSite(ctx context.Context, siteID int64) ([]*models.Category, error) { + rows, err := db.queries.SelectCategoriesOfSite(ctx, siteID) + if err != nil { + return nil, err + } + cats := make([]*models.Category, len(rows)) + for i, row := range rows { + cats[i] = dbCategoryToCategory(row) + } + return cats, nil +} + +func (db *Provider) SelectCategory(ctx context.Context, id int64) (*models.Category, error) { + row, err := db.queries.SelectCategory(ctx, id) + if err != nil { + return nil, err + } + return dbCategoryToCategory(row), nil +} + +func (db *Provider) SelectCategoryBySlugAndSite(ctx context.Context, siteID int64, slug string) (*models.Category, error) { + row, err := db.queries.SelectCategoryBySlugAndSite(ctx, sqlgen.SelectCategoryBySlugAndSiteParams{ + SiteID: siteID, + Slug: slug, + }) + if err != nil { + return nil, err + } + return dbCategoryToCategory(row), nil +} + +func (db *Provider) SaveCategory(ctx context.Context, cat *models.Category) error { + if cat.ID == 0 { + newID, err := db.queries.InsertCategory(ctx, sqlgen.InsertCategoryParams{ + SiteID: cat.SiteID, + Guid: cat.GUID, + Name: cat.Name, + Slug: cat.Slug, + Description: cat.Description, + CreatedAt: timeToInt(cat.CreatedAt), + UpdatedAt: timeToInt(cat.UpdatedAt), + }) + if err != nil { + return err + } + cat.ID = newID + return nil + } + + return db.queries.UpdateCategory(ctx, sqlgen.UpdateCategoryParams{ + ID: cat.ID, + Name: cat.Name, + Slug: cat.Slug, + Description: cat.Description, + UpdatedAt: timeToInt(cat.UpdatedAt), + }) +} + +func (db *Provider) DeleteCategory(ctx context.Context, id int64) error { + return db.queries.DeleteCategory(ctx, id) +} + +func (db *Provider) SelectCategoriesOfPost(ctx context.Context, postID int64) ([]*models.Category, error) { + rows, err := db.queries.SelectCategoriesOfPost(ctx, postID) + if err != nil { + return nil, err + } + cats := make([]*models.Category, len(rows)) + for i, row := range rows { + cats[i] = dbCategoryToCategory(row) + } + return cats, nil +} + +func (db *Provider) SelectPostsOfCategory(ctx context.Context, categoryID int64, pp PagingParams) ([]*models.Post, error) { + rows, err := db.queries.SelectPostsOfCategory(ctx, sqlgen.SelectPostsOfCategoryParams{ + CategoryID: categoryID, + Limit: pp.Limit, + Offset: pp.Offset, + }) + if err != nil { + return nil, err + } + posts := make([]*models.Post, len(rows)) + for i, row := range rows { + posts[i] = dbPostToPost(row) + } + return posts, nil +} + +func (db *Provider) CountPostsOfCategory(ctx context.Context, categoryID int64) (int64, error) { + return db.queries.CountPostsOfCategory(ctx, categoryID) +} + +// SetPostCategories replaces all category associations for a post. +func (db *Provider) SetPostCategories(ctx context.Context, postID int64, categoryIDs []int64) error { + if err := db.queries.DeletePostCategoriesByPost(ctx, postID); err != nil { + return err + } + for _, catID := range categoryIDs { + if err := db.queries.InsertPostCategory(ctx, sqlgen.InsertPostCategoryParams{ + PostID: postID, + CategoryID: catID, + }); err != nil { + return err + } + } + return nil +} + +func dbCategoryToCategory(row sqlgen.Category) *models.Category { + return &models.Category{ + ID: row.ID, + SiteID: row.SiteID, + GUID: row.Guid, + Name: row.Name, + Slug: row.Slug, + Description: row.Description, + CreatedAt: time.Unix(row.CreatedAt, 0).UTC(), + UpdatedAt: time.Unix(row.UpdatedAt, 0).UTC(), + } +} diff --git a/providers/db/provider.go b/providers/db/provider.go index eda0513..cc35225 100644 --- a/providers/db/provider.go +++ b/providers/db/provider.go @@ -40,6 +40,17 @@ func (db *Provider) Close() error { return db.drvr.Close() } +func (db *Provider) BeginTx(ctx context.Context) (*sql.Tx, error) { + return db.drvr.BeginTx(ctx, nil) +} + +func (db *Provider) QueriesWithTx(tx *sql.Tx) *Provider { + return &Provider{ + drvr: db.drvr, + queries: db.queries.WithTx(tx), + } +} + func (db *Provider) SoftDeletePost(ctx context.Context, postID int64) error { return db.queries.SoftDeletePost(ctx, sqlgen.SoftDeletePostParams{ DeletedAt: time.Now().Unix(), diff --git a/providers/db/provider_test.go b/providers/db/provider_test.go index 4781d61..caf83d1 100644 --- a/providers/db/provider_test.go +++ b/providers/db/provider_test.go @@ -158,7 +158,7 @@ func TestProvider_Posts(t *testing.T) { require.NoError(t, err) assert.NotZero(t, post.ID) - posts, err := p.SelectPostsOfSite(ctx, site.ID, false) + posts, err := p.SelectPostsOfSite(ctx, site.ID, false, db.PagingParams{}) require.NoError(t, err) require.Len(t, posts, 1) assert.Equal(t, post.ID, posts[0].ID) @@ -205,7 +205,7 @@ func TestProvider_Posts(t *testing.T) { require.NoError(t, p.SavePost(ctx, post1)) require.NoError(t, p.SavePost(ctx, post2)) - posts, err := p.SelectPostsOfSite(ctx, site2.ID, false) + posts, err := p.SelectPostsOfSite(ctx, site2.ID, false, db.PagingParams{}) require.NoError(t, err) require.Len(t, posts, 2) assert.Equal(t, "New Post", posts[0].Title) @@ -220,7 +220,7 @@ func TestProvider_Posts(t *testing.T) { } require.NoError(t, p.SaveSite(ctx, emptySite)) - posts, err := p.SelectPostsOfSite(ctx, emptySite.ID, false) + posts, err := p.SelectPostsOfSite(ctx, emptySite.ID, false, db.PagingParams{}) require.NoError(t, err) assert.Empty(t, posts) }) @@ -283,6 +283,165 @@ func TestProvider_PublishTargets(t *testing.T) { }) } +func TestProvider_Categories(t *testing.T) { + ctx := context.Background() + p := newTestDB(t) + + user := &models.User{Username: "testuser", PasswordHashed: []byte("password")} + require.NoError(t, p.SaveUser(ctx, user)) + + site := &models.Site{OwnerID: user.ID, Title: "My Blog", Tagline: "test"} + require.NoError(t, p.SaveSite(ctx, site)) + + t.Run("save and select categories", func(t *testing.T) { + now := time.Date(2026, 3, 18, 12, 0, 0, 0, time.UTC) + cat := &models.Category{ + SiteID: site.ID, + GUID: "cat-001", + Name: "Go Programming", + Slug: "go-programming", + Description: "Posts about Go", + CreatedAt: now, + UpdatedAt: now, + } + + err := p.SaveCategory(ctx, cat) + require.NoError(t, err) + assert.NotZero(t, cat.ID) + + cats, err := p.SelectCategoriesOfSite(ctx, site.ID) + require.NoError(t, err) + require.Len(t, cats, 1) + assert.Equal(t, "Go Programming", cats[0].Name) + assert.Equal(t, "go-programming", cats[0].Slug) + assert.Equal(t, "Posts about Go", cats[0].Description) + }) + + t.Run("update category", func(t *testing.T) { + now := time.Date(2026, 3, 18, 12, 0, 0, 0, time.UTC) + cat := &models.Category{ + SiteID: site.ID, + GUID: "cat-002", + Name: "Original", + Slug: "original", + CreatedAt: now, + UpdatedAt: now, + } + require.NoError(t, p.SaveCategory(ctx, cat)) + + cat.Name = "Updated" + cat.Slug = "updated" + cat.UpdatedAt = now.Add(time.Hour) + require.NoError(t, p.SaveCategory(ctx, cat)) + + got, err := p.SelectCategory(ctx, cat.ID) + require.NoError(t, err) + assert.Equal(t, "Updated", got.Name) + assert.Equal(t, "updated", got.Slug) + }) + + t.Run("delete category", func(t *testing.T) { + now := time.Date(2026, 3, 18, 12, 0, 0, 0, time.UTC) + cat := &models.Category{ + SiteID: site.ID, + GUID: "cat-003", + Name: "ToDelete", + Slug: "to-delete", + CreatedAt: now, + UpdatedAt: now, + } + require.NoError(t, p.SaveCategory(ctx, cat)) + + err := p.DeleteCategory(ctx, cat.ID) + require.NoError(t, err) + + _, err = p.SelectCategory(ctx, cat.ID) + assert.Error(t, err) + }) +} + +func TestProvider_PostCategories(t *testing.T) { + ctx := context.Background() + p := newTestDB(t) + + user := &models.User{Username: "testuser", PasswordHashed: []byte("password")} + require.NoError(t, p.SaveUser(ctx, user)) + + site := &models.Site{OwnerID: user.ID, Title: "My Blog", Tagline: "test"} + require.NoError(t, p.SaveSite(ctx, site)) + + now := time.Date(2026, 3, 18, 12, 0, 0, 0, time.UTC) + post := &models.Post{ + SiteID: site.ID, + GUID: "post-pc-001", + Title: "Test Post", + Body: "body", + Slug: "/test", + CreatedAt: now, + } + require.NoError(t, p.SavePost(ctx, post)) + + cat1 := &models.Category{SiteID: site.ID, GUID: "cat-pc-1", Name: "Alpha", Slug: "alpha", CreatedAt: now, UpdatedAt: now} + cat2 := &models.Category{SiteID: site.ID, GUID: "cat-pc-2", Name: "Beta", Slug: "beta", CreatedAt: now, UpdatedAt: now} + require.NoError(t, p.SaveCategory(ctx, cat1)) + require.NoError(t, p.SaveCategory(ctx, cat2)) + + t.Run("set and get post categories", func(t *testing.T) { + err := p.SetPostCategories(ctx, post.ID, []int64{cat1.ID, cat2.ID}) + require.NoError(t, err) + + cats, err := p.SelectCategoriesOfPost(ctx, post.ID) + require.NoError(t, err) + require.Len(t, cats, 2) + assert.Equal(t, "Alpha", cats[0].Name) + assert.Equal(t, "Beta", cats[1].Name) + }) + + t.Run("replace post categories", func(t *testing.T) { + err := p.SetPostCategories(ctx, post.ID, []int64{cat2.ID}) + require.NoError(t, err) + + cats, err := p.SelectCategoriesOfPost(ctx, post.ID) + require.NoError(t, err) + require.Len(t, cats, 1) + assert.Equal(t, "Beta", cats[0].Name) + }) + + t.Run("clear post categories", func(t *testing.T) { + err := p.SetPostCategories(ctx, post.ID, []int64{}) + require.NoError(t, err) + + cats, err := p.SelectCategoriesOfPost(ctx, post.ID) + require.NoError(t, err) + assert.Empty(t, cats) + }) + + t.Run("count posts of category", func(t *testing.T) { + post.State = models.StatePublished + post.PublishedAt = now + require.NoError(t, p.SavePost(ctx, post)) + require.NoError(t, p.SetPostCategories(ctx, post.ID, []int64{cat1.ID})) + + count, err := p.CountPostsOfCategory(ctx, cat1.ID) + require.NoError(t, err) + assert.Equal(t, int64(1), count) + + count, err = p.CountPostsOfCategory(ctx, cat2.ID) + require.NoError(t, err) + assert.Equal(t, int64(0), count) + }) + + t.Run("cascade delete category removes associations", func(t *testing.T) { + require.NoError(t, p.SetPostCategories(ctx, post.ID, []int64{cat1.ID, cat2.ID})) + require.NoError(t, p.DeleteCategory(ctx, cat1.ID)) + + cats, err := p.SelectCategoriesOfPost(ctx, post.ID) + require.NoError(t, err) + require.Len(t, cats, 1) + assert.Equal(t, "Beta", cats[0].Name) + }) +} + // Verify that password encoding roundtrips correctly through base64 func TestProvider_UserPasswordEncoding(t *testing.T) { ctx := context.Background()