diff --git a/cmds/server.go b/cmds/server.go index 6b2e71b..56517e7 100644 --- a/cmds/server.go +++ b/cmds/server.go @@ -109,7 +109,7 @@ Starting weiro without any arguments will start the server. ih := handlers.IndexHandler{SiteService: svcs.Sites} lh := handlers.LoginHandler{Config: cfg, AuthService: svcs.Auth} - ph := handlers.PostsHandler{PostService: svcs.Posts} + ph := handlers.PostsHandler{PostService: svcs.Posts, CategoryService: svcs.Categories} uh := handlers.UploadsHandler{UploadsService: svcs.Uploads} ssh := handlers.SiteSettingsHandler{SiteService: svcs.Sites} ch := handlers.CategoriesHandler{CategoryService: svcs.Categories} diff --git a/handlers/posts.go b/handlers/posts.go index 3f282e0..e0234fc 100644 --- a/handlers/posts.go +++ b/handlers/posts.go @@ -6,11 +6,13 @@ import ( "github.com/gofiber/fiber/v3" "lmika.dev/lmika/weiro/models" + "lmika.dev/lmika/weiro/services/categories" "lmika.dev/lmika/weiro/services/posts" ) type PostsHandler struct { - PostService *posts.Service + PostService *posts.Service + CategoryService *categories.Service } func (ph PostsHandler) Index(c fiber.Ctx) error { @@ -42,8 +44,15 @@ func (ph PostsHandler) New(c fiber.Ctx) error { State: models.StateDraft, } + cats, err := ph.CategoryService.ListCategories(c.Context()) + if err != nil { + return err + } + return c.Render("posts/edit", fiber.Map{ - "post": p, + "post": p, + "categories": cats, + "selectedCategories": map[int64]bool{}, }) } @@ -62,11 +71,28 @@ func (ph PostsHandler) Edit(c fiber.Ctx) error { return err } + cats, err := ph.CategoryService.ListCategories(c.Context()) + if err != nil { + return err + } + + postCats, err := ph.PostService.GetPostCategories(c.Context(), postID) + if err != nil { + return err + } + + selectedCategories := make(map[int64]bool) + for _, pc := range postCats { + selectedCategories[pc.ID] = true + } + return accepts(c, json(func() any { return post }), html(func(c fiber.Ctx) error { return c.Render("posts/edit", fiber.Map{ - "post": post, + "post": post, + "categories": cats, + "selectedCategories": selectedCategories, }) })) } diff --git a/services/posts/create.go b/services/posts/create.go index f73d49c..b1a6466 100644 --- a/services/posts/create.go +++ b/services/posts/create.go @@ -10,10 +10,11 @@ import ( ) type CreatePostParams struct { - GUID string `form:"guid" json:"guid"` - Title string `form:"title" json:"title"` - Body string `form:"body" json:"body"` - Action string `form:"action" json:"action"` + GUID string `form:"guid" json:"guid"` + Title string `form:"title" json:"title"` + Body string `form:"body" json:"body"` + Action string `form:"action" json:"action"` + CategoryIDs []int64 `form:"category_ids" json:"category_ids"` } func (s *Service) UpdatePost(ctx context.Context, params CreatePostParams) (*models.Post, error) { @@ -53,7 +54,21 @@ func (s *Service) UpdatePost(ctx context.Context, params CreatePostParams) (*mod // Leave unchanged } - if err := s.db.SavePost(ctx, post); err != nil { + // Use a transaction for atomicity of post save + category reassignment + tx, err := s.db.BeginTx(ctx) + if err != nil { + return nil, err + } + defer tx.Rollback() + + txDB := s.db.QueriesWithTx(tx) + if err := txDB.SavePost(ctx, post); err != nil { + return nil, err + } + if err := txDB.SetPostCategories(ctx, post.ID, params.CategoryIDs); err != nil { + return nil, err + } + if err := tx.Commit(); err != nil { return nil, err } diff --git a/services/posts/list.go b/services/posts/list.go index ae70e1c..15e14d3 100644 --- a/services/posts/list.go +++ b/services/posts/list.go @@ -7,7 +7,12 @@ import ( "lmika.dev/lmika/weiro/providers/db" ) -func (s *Service) ListPosts(ctx context.Context, showDeleted bool) ([]*models.Post, error) { +type PostWithCategories struct { + *models.Post + Categories []*models.Category +} + +func (s *Service) ListPosts(ctx context.Context, showDeleted bool) ([]*PostWithCategories, error) { site, ok := models.GetSite(ctx) if !ok { return nil, models.SiteRequiredError @@ -21,7 +26,15 @@ func (s *Service) ListPosts(ctx context.Context, showDeleted bool) ([]*models.Po return nil, err } - return posts, nil + result := make([]*PostWithCategories, len(posts)) + for i, post := range posts { + cats, err := s.db.SelectCategoriesOfPost(ctx, post.ID) + if err != nil { + return nil, err + } + result[i] = &PostWithCategories{Post: post, Categories: cats} + } + return result, nil } func (s *Service) GetPost(ctx context.Context, pid int64) (*models.Post, error) { @@ -32,3 +45,7 @@ func (s *Service) GetPost(ctx context.Context, pid int64) (*models.Post, error) return post, nil } + +func (s *Service) GetPostCategories(ctx context.Context, postID int64) ([]*models.Category, error) { + return s.db.SelectCategoriesOfPost(ctx, postID) +} diff --git a/views/posts/edit.html b/views/posts/edit.html index 475c9a0..07be770 100644 --- a/views/posts/edit.html +++ b/views/posts/edit.html @@ -4,20 +4,41 @@ data-controller="postedit" data-action="keydown.meta+s->postedit#save keydown.meta+enter->postedit#publish" data-postedit-save-action-value="{{ if $isPublished }}Update{{ else }}Save Draft{{ end }}"> - -