weiro/handlers/middleware/site.go

55 lines
1.3 KiB
Go

package middleware
import (
"strconv"
"emperror.dev/errors"
"github.com/gofiber/fiber/v3"
"github.com/gofiber/fiber/v3/middleware/session"
"lmika.dev/lmika/weiro/models"
"lmika.dev/lmika/weiro/providers/db"
"lmika.dev/lmika/weiro/services/sites"
)
// RequiresSite returns middleware that resolves the site named by the
// "siteID" route parameter and makes it available to later handlers.
//
// On success it stores the site in c.Locals("site") and in the request
// context (via models.WithSite), stores the result of ListSites in
// c.Locals("allSites"), records the site ID under "last_site_id" in the
// session, and, when BestPubTarget succeeds, stores its result in
// c.Locals("pubTarget") before calling c.Next().
//
// It stops the chain with:
//   - fiber.ErrBadRequest if siteID is missing or not a base-10 int64;
//   - fiber.ErrForbidden if the lookup reports models.UserRequiredError;
//   - fiber.ErrNotFound if the lookup reports models.PermissionError or no rows;
//   - the lookup error itself for any other failure (models.NotFoundError included);
//   - the ListSites error, if listing the user's sites fails.
func RequiresSite(siteSvc *sites.Service) func(c fiber.Ctx) error {
	return func(c fiber.Ctx) error {
		siteIDStr := c.Params("siteID")
		if siteIDStr == "" {
			return fiber.ErrBadRequest
		}
		siteID, err := strconv.ParseInt(siteIDStr, 10, 64)
		if err != nil {
			return fiber.ErrBadRequest
		}

		site, err := siteSvc.GetSiteByID(c.Context(), siteID)
		if err != nil {
			switch {
			case errors.Is(err, models.UserRequiredError):
				return fiber.ErrForbidden
			case errors.Is(err, models.PermissionError), db.ErrorIsNoRows(err):
				// NOTE(review): presumably 404 rather than 403 so that users
				// cannot probe for sites they do not own — confirm.
				return fiber.ErrNotFound
			default:
				// Never fall through on a lookup failure: continuing would run
				// the downstream handlers with an unusable site.
				return err
			}
		}

		c.Locals("site", site)
		c.SetContext(models.WithSite(c.Context(), site))

		sitesOwnedByUser, err := siteSvc.ListSites(c.Context())
		if err != nil {
			return err
		}
		c.Locals("allSites", sitesOwnedByUser)

		// FromContext may return nil (e.g. when the session middleware is not
		// registered for this route); skip remembering the site rather than
		// panicking on a nil session.
		if sess := session.FromContext(c); sess != nil {
			sess.Set("last_site_id", siteID)
		}

		// Best-effort: if no publishing target can be determined, "pubTarget"
		// is simply left unset and the request still proceeds.
		if pubTarget, err := siteSvc.BestPubTarget(c.Context(), site); err == nil {
			c.Locals("pubTarget", pubTarget)
		}

		return c.Next()
	}
}