package middleware

import (
	"strconv"

	"emperror.dev/errors"
	"github.com/gofiber/fiber/v3"
	"lmika.dev/lmika/weiro/models"
	"lmika.dev/lmika/weiro/providers/db"
	"lmika.dev/lmika/weiro/services/sites"
)

// RequiresSite returns a Fiber handler that loads the site named by the
// "siteID" route parameter and makes it available to later handlers.
//
// The parameter must be present and parse as a base-10 int64; otherwise the
// handler returns fiber.ErrBadRequest. The site is then fetched with
// sites.GetSiteByID. A failed lookup always ends the request:
//
//   - models.UserRequiredError yields fiber.ErrForbidden;
//   - models.PermissionError, or an error for which db.ErrorIsNoRows reports
//     true, yields fiber.ErrNotFound;
//   - any other error is returned unchanged.
//
// On success the site is stored in c.Locals under the key "site", attached to
// the request context via models.WithSite, and c.Next is called.
func RequiresSite(sites *sites.Service) func(c fiber.Ctx) error {
	return func(c fiber.Ctx) error {
		siteIDStr := c.Params("siteID")
		if siteIDStr == "" {
			return fiber.ErrBadRequest
		}

		siteID, err := strconv.ParseInt(siteIDStr, 10, 64)
		if err != nil {
			return fiber.ErrBadRequest
		}

		site, err := sites.GetSiteByID(c.Context(), siteID)
		if err != nil {
			switch {
			case errors.Is(err, models.UserRequiredError):
				return fiber.ErrForbidden
			case errors.Is(err, models.PermissionError), db.ErrorIsNoRows(err):
				return fiber.ErrNotFound
			default:
				// Never continue after a failed lookup: previously an error
				// that matched no branch fell through, and the request went
				// on with whatever site value accompanied the error.
				//
				// NOTE(review): models.NotFoundError is still returned
				// unchanged, exactly as before. Confirm whether it should map
				// to fiber.ErrNotFound like the no-rows case above.
				return err
			}
		}

		c.Locals("site", site)
		c.SetContext(models.WithSite(c.Context(), site))
		return c.Next()
	}
}