diff --git a/providers/db/gen/sqlgen/posts.sql.go b/providers/db/gen/sqlgen/posts.sql.go index 8bff191..ef3d170 100644 --- a/providers/db/gen/sqlgen/posts.sql.go +++ b/providers/db/gen/sqlgen/posts.sql.go @@ -9,6 +9,28 @@ import ( "context" ) +const countPostsOfSite = `-- name: CountPostsOfSite :one +SELECT COUNT(*) FROM posts +WHERE site_id = ?1 AND ( + CASE CAST (?2 AS TEXT) + WHEN 'deleted' THEN deleted_at > 0 + ELSE deleted_at = 0 + END +) +` + +type CountPostsOfSiteParams struct { + SiteID int64 + PostFilter string +} + +func (q *Queries) CountPostsOfSite(ctx context.Context, arg CountPostsOfSiteParams) (int64, error) { + row := q.db.QueryRowContext(ctx, countPostsOfSite, arg.SiteID, arg.PostFilter) + var count int64 + err := row.Scan(&count) + return count, err +} + const hardDeletePost = `-- name: HardDeletePost :exec DELETE FROM posts WHERE id = ? ` diff --git a/providers/db/posts.go b/providers/db/posts.go index 218e931..7f58d1a 100644 --- a/providers/db/posts.go +++ b/providers/db/posts.go @@ -13,6 +13,17 @@ type PagingParams struct { Offset int64 } +func (db *Provider) CountPostsOfSite(ctx context.Context, siteID int64, showDeleted bool) (int64, error) { + filter := "active" + if showDeleted { + filter = "deleted" + } + return db.queries.CountPostsOfSite(ctx, sqlgen.CountPostsOfSiteParams{ + SiteID: siteID, + PostFilter: filter, + }) +} + func (db *Provider) SelectPostsOfSite(ctx context.Context, siteID int64, showDeleted bool, pp PagingParams) ([]*models.Post, error) { var filter = "" if showDeleted { diff --git a/providers/db/provider_test.go b/providers/db/provider_test.go index 06f03c0..0a2e6df 100644 --- a/providers/db/provider_test.go +++ b/providers/db/provider_test.go @@ -3,6 +3,7 @@ package db_test import ( "context" "encoding/base64" + "fmt" "path/filepath" "testing" "time" @@ -229,6 +230,45 @@ func TestProvider_Posts(t *testing.T) { require.NoError(t, err) assert.Empty(t, posts) }) + + t.Run("count posts of site", func(t *testing.T) { + countSite := &models.Site{ + OwnerID: user.ID, + GUID: models.NewNanoID(), + Title: "Count Blog", + } + require.NoError(t, p.SaveSite(ctx, countSite)) + + now := time.Date(2026, 3, 22, 12, 0, 0, 0, time.UTC) + for i := 0; i < 3; i++ { + post := &models.Post{ + SiteID: countSite.ID, + GUID: models.NewNanoID(), + Title: fmt.Sprintf("Post %d", i), + Body: "body", + Slug: fmt.Sprintf("/post-%d", i), + CreatedAt: now, + } + require.NoError(t, p.SavePost(ctx, post)) + } + + count, err := p.CountPostsOfSite(ctx, countSite.ID, false) + require.NoError(t, err) + assert.Equal(t, int64(3), count) + + // Soft-delete one post + posts, err := p.SelectPostsOfSite(ctx, countSite.ID, false, db.PagingParams{Limit: 10, Offset: 0}) + require.NoError(t, err) + require.NoError(t, p.SoftDeletePost(ctx, posts[0].ID)) + + count, err = p.CountPostsOfSite(ctx, countSite.ID, false) + require.NoError(t, err) + assert.Equal(t, int64(2), count) + + count, err = p.CountPostsOfSite(ctx, countSite.ID, true) + require.NoError(t, err) + assert.Equal(t, int64(1), count) + }) } func TestProvider_PublishTargets(t *testing.T) { diff --git a/sql/queries/posts.sql b/sql/queries/posts.sql index dae1f39..5a4c18e 100644 --- a/sql/queries/posts.sql +++ b/sql/queries/posts.sql @@ -1,3 +1,12 @@ +-- name: CountPostsOfSite :one +SELECT COUNT(*) FROM posts +WHERE site_id = sqlc.arg(site_id) AND ( + CASE CAST (sqlc.arg(post_filter) AS TEXT) + WHEN 'deleted' THEN deleted_at > 0 + ELSE deleted_at = 0 + END +); + -- name: SelectPostsOfSite :many SELECT * FROM posts