weiro/services/uploads/pending.go

166 lines
3.9 KiB
Go

package uploads
import (
"bytes"
"context"
"crypto/sha256"
"encoding/hex"
"fmt"
"io"
"log"
"os"
"path/filepath"
"time"
"emperror.dev/errors"
"lmika.dev/lmika/weiro/models"
)
// NewPendingRequest carries the metadata of an upload that is about to be
// started. It is decodable from JSON using the keys "size", "name" and
// "type"; NewPending copies these values into the stored PendingUpload.
type NewPendingRequest struct {
	// FileSize is the declared size of the upload in bytes.
	FileSize int64 `json:"size"`
	// Filename is the original name of the file being uploaded; its
	// extension is reused when the final upload slug is built.
	Filename string `json:"name"`
	// MIMEType is the declared content type of the upload.
	MIMEType string `json:"type"`
}
// NewPending starts a new pending upload for the current site and user.
//
// It creates an empty "<GUID>.upload" data file in s.pendingDir (creating
// the directory if needed) and saves a PendingUpload record holding the
// declared size, filename and MIME type from req. Data is later appended to
// that file by WriteToPending and checked by FinalizePending.
//
// The data file is prepared before the record is saved, so a filesystem
// failure no longer leaves a pending record without backing storage. If
// saving the record fails, the freshly created file is removed again. On any
// error the returned PendingUpload is the zero value.
func (s *Service) NewPending(ctx context.Context, req NewPendingRequest) (models.PendingUpload, error) {
	site, user, err := s.fetchSiteAndUser(ctx)
	if err != nil {
		return models.PendingUpload{}, err
	}

	pending := models.PendingUpload{
		GUID:          models.NewNanoID(),
		SiteID:        site.ID,
		UserID:        user.ID,
		FileSize:      req.FileSize,
		Filename:      req.Filename,
		MIMEType:      req.MIMEType,
		UploadStarted: time.Now(),
	}

	if err := os.MkdirAll(s.pendingDir, 0755); err != nil {
		return models.PendingUpload{}, err
	}

	dataFilename := filepath.Join(s.pendingDir, pending.GUID+".upload")
	dataFile, err := os.Create(dataFilename)
	if err != nil {
		return models.PendingUpload{}, err
	}
	if err := dataFile.Close(); err != nil {
		// Best-effort cleanup; the close error is the one worth reporting.
		_ = os.Remove(dataFilename)
		return models.PendingUpload{}, err
	}

	if err := s.db.SavePendingUpload(ctx, &pending); err != nil {
		// Don't leave an orphaned data file behind when the record could
		// not be stored; the save error is the one worth reporting.
		_ = os.Remove(dataFilename)
		return models.PendingUpload{}, err
	}
	return pending, nil
}
// WriteToPending appends data to the data file of the pending upload
// identified by pendingGUID.
//
// The pending upload must belong to the current site and user; otherwise an
// "invalid pending upload" error is returned. The data file must already
// exist, as created by NewPending. It is opened without O_CREATE, so a
// missing file results in an error rather than a new file.
//
// NOTE(review): the accumulated size is not checked against the declared
// pu.FileSize; consider bounding it if clients are untrusted.
func (s *Service) WriteToPending(ctx context.Context, pendingGUID string, data []byte) error {
	site, user, err := s.fetchSiteAndUser(ctx)
	if err != nil {
		return err
	}

	pu, err := s.db.SelectPendingUploadByGUID(ctx, pendingGUID)
	if err != nil {
		return err
	}
	if pu.SiteID != site.ID || pu.UserID != user.ID {
		return errors.New("invalid pending upload")
	}

	pendingDataFilename := filepath.Join(s.pendingDir, pu.GUID+".upload")

	// O_APPEND positions every write at the current end of file, so no
	// explicit Seek is needed. Omitting O_CREATE makes OpenFile fail when the
	// file is missing, which replaces a separate Stat check.
	f, err := os.OpenFile(pendingDataFilename, os.O_WRONLY|os.O_APPEND, 0644)
	if err != nil {
		return err
	}
	if _, err := f.Write(data); err != nil {
		_ = f.Close() // the write error is the one worth reporting
		return err
	}
	// Close can report write errors that Write did not, so its result is
	// returned instead of being discarded by a defer.
	return f.Close()
}
// FinalizePending turns the pending upload identified by pendingGUID into a
// permanent upload.
//
// The pending upload must belong to the current site and user. Its data file
// must hash (SHA-256) to expectedHash, given as hex; otherwise the
// verification error is returned and nothing is saved. On success a new
// Upload is saved with slug "YYYY/MM/<guid><ext>" (UTC date, extension taken
// from the original filename). The data file is handed to s.up.AdoptFile,
// and the pending record is deleted. EXIF stripping afterwards is
// best-effort: failures are only logged.
//
// NOTE(review): if AdoptFile or DeletePendingUpload fails after SaveUpload
// succeeded, the saved Upload row is not rolled back.
func (s *Service) FinalizePending(ctx context.Context, pendingGUID string, expectedHash string) error {
	site, user, err := s.fetchSiteAndUser(ctx)
	if err != nil {
		return err
	}

	pu, err := s.db.SelectPendingUploadByGUID(ctx, pendingGUID)
	if err != nil {
		return err
	}
	if pu.SiteID != site.ID || pu.UserID != user.ID {
		return errors.New("invalid pending upload")
	}

	pendingDataFilename := filepath.Join(s.pendingDir, pu.GUID+".upload")
	if err := s.verifyPendingUpload(pendingDataFilename, expectedHash); err != nil {
		return err
	}

	newUploadGUID := models.NewNanoID()
	newTime := time.Now().UTC()
	// NOTE(review): filepath.Join uses the OS path separator; if Slug is
	// served as a URL path, path.Join may be the intended helper — confirm.
	newSlug := filepath.Join(
		fmt.Sprintf("%04d", newTime.Year()),
		fmt.Sprintf("%02d", newTime.Month()),
		newUploadGUID+filepath.Ext(pu.Filename),
	)

	newUpload := models.Upload{
		SiteID: site.ID,
		// Use the same GUID that names the file in the slug, rather than
		// minting an unrelated second ID.
		GUID:      newUploadGUID,
		FileSize:  pu.FileSize,
		MIMEType:  pu.MIMEType,
		Filename:  pu.Filename,
		CreatedAt: newTime,
		Slug:      newSlug,
	}
	if err := s.db.SaveUpload(ctx, &newUpload); err != nil {
		return err
	}
	if err := s.up.AdoptFile(site, newUpload, pendingDataFilename); err != nil {
		return err
	}

	// The record to remove is the *pending* upload, keyed by its own GUID.
	// Passing newUpload.GUID here (as before) never matched it, so pending
	// records were left behind.
	if err := s.db.DeletePendingUpload(ctx, pu.GUID); err != nil {
		return err
	}

	if err := s.up.StripeEXIFData(site, newUpload); err != nil {
		log.Printf("warn: failed to extract exif data from %s: %v\n", newUpload.Slug, err)
	}
	return nil
}
// verifyPendingUpload reports whether the file at pendingDataFilename has the
// SHA-256 digest given, hex-encoded, by expectedHash.
//
// Errors are returned in this order:
//   - the hex decoding error if expectedHash is not valid hex;
//   - any stat, open or read error for the file;
//   - "hash mismatch" when the computed digest differs.
//
// A nil result means the digests match.
func (s *Service) verifyPendingUpload(pendingDataFilename string, expectedHash string) error {
	want, decodeErr := hex.DecodeString(expectedHash)
	if decodeErr != nil {
		return decodeErr
	}

	if _, statErr := os.Stat(pendingDataFilename); statErr != nil {
		return statErr
	}

	f, openErr := os.Open(pendingDataFilename)
	if openErr != nil {
		return openErr
	}
	defer f.Close()

	digest := sha256.New()
	if _, copyErr := io.Copy(digest, f); copyErr != nil {
		return copyErr
	}

	if got := digest.Sum(nil); bytes.Equal(got, want) {
		return nil
	}
	return errors.New("hash mismatch")
}