When multiple chunk uploads race on the pending→uploading transition, ignore ErrRecordNotFound from UpdateSessionStatus since another request already completed the update. Ultraworked with [Sisyphus](https://github.com/code-yeongyu/oh-my-openagent) Co-authored-by: Sisyphus <clio-agent@sisyphuslabs.ai>
444 lines
12 KiB
Go
444 lines
12 KiB
Go
package service
|
|
|
|
import (
|
|
"context"
|
|
"crypto/sha256"
|
|
"encoding/hex"
|
|
"errors"
|
|
"fmt"
|
|
"io"
|
|
"time"
|
|
|
|
"gcy_hpc_server/internal/config"
|
|
"gcy_hpc_server/internal/model"
|
|
"gcy_hpc_server/internal/storage"
|
|
"gcy_hpc_server/internal/store"
|
|
|
|
"go.uber.org/zap"
|
|
"gorm.io/gorm"
|
|
"gorm.io/gorm/clause"
|
|
)
|
|
|
|
// UploadService implements chunked file uploads with content-addressed
// deduplication: session and file metadata live in the relational stores,
// raw bytes in object storage, and completed files point at a shared blob
// keyed by SHA-256 with a reference count.
type UploadService struct {
	storage     storage.ObjectStorage // object store for temp chunks and merged blobs
	blobStore   *store.BlobStore      // content-addressed blob records (ref-counted)
	fileStore   *store.FileStore      // user-visible file entries
	uploadStore *store.UploadStore    // upload sessions and per-chunk progress
	cfg         config.MinioConfig    // bucket name, size/chunk limits, session TTL
	db          *gorm.DB              // raw handle for cross-store transactions (CompleteUpload)
	logger      *zap.Logger
}
|
|
|
|
func NewUploadService(
|
|
st storage.ObjectStorage,
|
|
blobStore *store.BlobStore,
|
|
fileStore *store.FileStore,
|
|
uploadStore *store.UploadStore,
|
|
cfg config.MinioConfig,
|
|
db *gorm.DB,
|
|
logger *zap.Logger,
|
|
) *UploadService {
|
|
return &UploadService{
|
|
storage: st,
|
|
blobStore: blobStore,
|
|
fileStore: fileStore,
|
|
uploadStore: uploadStore,
|
|
cfg: cfg,
|
|
db: db,
|
|
logger: logger,
|
|
}
|
|
}
|
|
|
|
func (s *UploadService) InitUpload(ctx context.Context, req model.InitUploadRequest) (interface{}, error) {
|
|
if err := model.ValidateFileName(req.FileName); err != nil {
|
|
return nil, fmt.Errorf("invalid file name: %w", err)
|
|
}
|
|
|
|
if req.FileSize < 0 {
|
|
return nil, fmt.Errorf("file size cannot be negative")
|
|
}
|
|
|
|
if req.FileSize > s.cfg.MaxFileSize {
|
|
return nil, fmt.Errorf("file size %d exceeds maximum %d", req.FileSize, s.cfg.MaxFileSize)
|
|
}
|
|
|
|
chunkSize := s.cfg.ChunkSize
|
|
if req.ChunkSize != nil {
|
|
chunkSize = *req.ChunkSize
|
|
}
|
|
if chunkSize < s.cfg.MinChunkSize {
|
|
return nil, fmt.Errorf("chunk size %d is below minimum %d", chunkSize, s.cfg.MinChunkSize)
|
|
}
|
|
|
|
totalChunks := int((req.FileSize + chunkSize - 1) / chunkSize)
|
|
if req.FileSize == 0 {
|
|
totalChunks = 0
|
|
}
|
|
if totalChunks > 10000 {
|
|
return nil, fmt.Errorf("total chunks %d exceeds limit of 10000", totalChunks)
|
|
}
|
|
|
|
blob, err := s.blobStore.GetBySHA256(ctx, req.SHA256)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("check dedup: %w", err)
|
|
}
|
|
if blob != nil {
|
|
if err := s.blobStore.IncrementRef(ctx, req.SHA256); err != nil {
|
|
return nil, fmt.Errorf("increment ref: %w", err)
|
|
}
|
|
file := &model.File{
|
|
Name: req.FileName,
|
|
FolderID: req.FolderID,
|
|
BlobSHA256: req.SHA256,
|
|
}
|
|
if err := s.fileStore.Create(ctx, file); err != nil {
|
|
return nil, fmt.Errorf("create file: %w", err)
|
|
}
|
|
return model.FileResponse{
|
|
ID: file.ID,
|
|
Name: file.Name,
|
|
FolderID: file.FolderID,
|
|
Size: blob.FileSize,
|
|
MimeType: blob.MimeType,
|
|
SHA256: blob.SHA256,
|
|
CreatedAt: file.CreatedAt,
|
|
UpdatedAt: file.UpdatedAt,
|
|
}, nil
|
|
}
|
|
|
|
mimeType := req.MimeType
|
|
if mimeType == "" {
|
|
mimeType = "application/octet-stream"
|
|
}
|
|
|
|
session := &model.UploadSession{
|
|
FileName: req.FileName,
|
|
FileSize: req.FileSize,
|
|
ChunkSize: chunkSize,
|
|
TotalChunks: totalChunks,
|
|
SHA256: req.SHA256,
|
|
FolderID: req.FolderID,
|
|
Status: "pending",
|
|
MinioPrefix: fmt.Sprintf("uploads/%d/", time.Now().UnixNano()),
|
|
MimeType: mimeType,
|
|
ExpiresAt: time.Now().Add(time.Duration(s.cfg.SessionTTL) * time.Hour),
|
|
}
|
|
|
|
if err := s.uploadStore.CreateSession(ctx, session); err != nil {
|
|
return nil, fmt.Errorf("create session: %w", err)
|
|
}
|
|
|
|
return model.UploadSessionResponse{
|
|
ID: session.ID,
|
|
FileName: session.FileName,
|
|
FileSize: session.FileSize,
|
|
ChunkSize: session.ChunkSize,
|
|
TotalChunks: session.TotalChunks,
|
|
SHA256: session.SHA256,
|
|
Status: session.Status,
|
|
UploadedChunks: []int{},
|
|
ExpiresAt: session.ExpiresAt,
|
|
CreatedAt: session.CreatedAt,
|
|
}, nil
|
|
}
|
|
|
|
// UploadChunk streams one chunk into object storage under the session's temp
// prefix, records its hash and size, and flips a "pending" session to
// "uploading" on the first chunk.
//
// Re-uploading an index is allowed (UpsertChunk), which is what makes chunk
// retries and resumed uploads idempotent.
func (s *UploadService) UploadChunk(ctx context.Context, sessionID int64, chunkIndex int, reader io.Reader, size int64) error {
	session, err := s.uploadStore.GetSession(ctx, sessionID)
	if err != nil {
		return fmt.Errorf("get session: %w", err)
	}
	if session == nil {
		return fmt.Errorf("session not found")
	}

	// Only sessions that are still accepting bytes may receive chunks;
	// "failed" is included so a failed merge can be retried with new data.
	switch session.Status {
	case "pending", "uploading", "failed":
	default:
		return fmt.Errorf("cannot upload to session with status %q", session.Status)
	}

	if chunkIndex < 0 || chunkIndex >= session.TotalChunks {
		return fmt.Errorf("chunk index %d out of range [0, %d)", chunkIndex, session.TotalChunks)
	}

	// Zero-padded index keeps chunk keys lexicographically ordered under the prefix.
	key := fmt.Sprintf("%schunk_%05d", session.MinioPrefix, chunkIndex)

	// Hash the bytes as they stream through to storage — no second read pass.
	hasher := sha256.New()
	teeReader := io.TeeReader(reader, hasher)

	_, err = s.storage.PutObject(ctx, s.cfg.Bucket, key, teeReader, size, storage.PutObjectOptions{
		DisableMultipart: true,
	})
	if err != nil {
		return fmt.Errorf("put object: %w", err)
	}

	chunkSHA256 := hex.EncodeToString(hasher.Sum(nil))

	chunk := &model.UploadChunk{
		SessionID:  sessionID,
		ChunkIndex: chunkIndex,
		MinioKey:   key,
		SHA256:     chunkSHA256,
		Size:       size,
		Status:     "uploaded",
	}

	if err := s.uploadStore.UpsertChunk(ctx, chunk); err != nil {
		return fmt.Errorf("upsert chunk: %w", err)
	}

	if session.Status == "pending" {
		if err := s.uploadStore.UpdateSessionStatus(ctx, sessionID, "uploading"); err != nil {
			// Concurrent first chunks race on the pending→uploading flip: a
			// sibling request may have already moved the status, in which case
			// the update matches zero rows and surfaces as ErrRecordNotFound.
			// That outcome is success, not failure — swallow it deliberately.
			if !errors.Is(err, gorm.ErrRecordNotFound) {
				return fmt.Errorf("update status to uploading: %w", err)
			}
		}
	}

	return nil
}
|
|
|
|
func (s *UploadService) GetUploadStatus(ctx context.Context, sessionID int64) (*model.UploadSessionResponse, error) {
|
|
session, chunks, err := s.uploadStore.GetSessionWithChunks(ctx, sessionID)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("get session: %w", err)
|
|
}
|
|
if session == nil {
|
|
return nil, fmt.Errorf("session not found")
|
|
}
|
|
|
|
uploadedChunks := make([]int, 0, len(chunks))
|
|
for _, c := range chunks {
|
|
if c.Status == "uploaded" {
|
|
uploadedChunks = append(uploadedChunks, c.ChunkIndex)
|
|
}
|
|
}
|
|
|
|
return &model.UploadSessionResponse{
|
|
ID: session.ID,
|
|
FileName: session.FileName,
|
|
FileSize: session.FileSize,
|
|
ChunkSize: session.ChunkSize,
|
|
TotalChunks: session.TotalChunks,
|
|
SHA256: session.SHA256,
|
|
Status: session.Status,
|
|
UploadedChunks: uploadedChunks,
|
|
ExpiresAt: session.ExpiresAt,
|
|
CreatedAt: session.CreatedAt,
|
|
}, nil
|
|
}
|
|
|
|
// CompleteUpload finalizes a session inside one database transaction: it
// verifies all chunks arrived, merges them into a content-addressed blob (or
// ref-counts an existing identical blob), creates the File row, and marks the
// session "completed". On a compose failure the session is marked "failed"
// outside the rolled-back transaction so the client can retry. Temp chunks
// are deleted asynchronously after commit.
func (s *UploadService) CompleteUpload(ctx context.Context, sessionID int64) (*model.FileResponse, error) {
	var fileResp *model.FileResponse
	// Captured inside the transaction for the post-commit chunk cleanup.
	var totalChunks int
	var minioPrefix string

	err := s.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
		var session model.UploadSession
		query := tx.WithContext(ctx)
		// SQLite has no SELECT ... FOR UPDATE; elsewhere, lock the session
		// row so concurrent CompleteUpload calls serialize.
		if !isSQLite(tx) {
			query = query.Clauses(clause.Locking{Strength: "UPDATE"})
		}
		if err := query.First(&session, sessionID).Error; err != nil {
			return fmt.Errorf("get session: %w", err)
		}

		// "failed" is retryable (a previous merge attempt broke); anything
		// else ("pending", "merging", "completed", ...) cannot be completed.
		switch session.Status {
		case "uploading", "failed":
		default:
			return fmt.Errorf("cannot complete session with status %q", session.Status)
		}

		// Every expected chunk must have an "uploaded" row before merging.
		if session.TotalChunks > 0 {
			var chunkCount int64
			if err := tx.WithContext(ctx).Model(&model.UploadChunk{}).
				Where("session_id = ? AND status = ?", sessionID, "uploaded").
				Count(&chunkCount).Error; err != nil {
				return fmt.Errorf("count chunks: %w", err)
			}
			if int(chunkCount) != session.TotalChunks {
				return fmt.Errorf("not all chunks uploaded: %d/%d", chunkCount, session.TotalChunks)
			}
		}

		totalChunks = session.TotalChunks
		minioPrefix = session.MinioPrefix

		if err := updateStatusTx(tx, ctx, sessionID, "merging"); err != nil {
			return fmt.Errorf("update status to merging: %w", err)
		}

		// Lock the blob row (if any) so the ref-count / create decision below
		// is race-free against concurrent completions of identical content.
		blob, err := s.blobStore.GetBySHA256ForUpdate(ctx, tx, session.SHA256)
		if err != nil {
			return fmt.Errorf("get blob for update: %w", err)
		}

		if blob != nil {
			// Identical content already stored: atomically bump its ref count
			// in SQL rather than read-modify-write.
			result := tx.WithContext(ctx).Model(&model.FileBlob{}).
				Where("sha256 = ?", session.SHA256).
				UpdateColumn("ref_count", gorm.Expr("ref_count + 1"))
			if result.Error != nil {
				return fmt.Errorf("increment ref: %w", result.Error)
			}
			if result.RowsAffected == 0 {
				return fmt.Errorf("blob not found for ref increment")
			}
		} else {
			// New content: compose the temp chunks into one object keyed by
			// the file hash. Zero-chunk (empty-file) sessions skip the merge.
			if session.TotalChunks > 0 {
				chunkKeys := make([]string, session.TotalChunks)
				for i := 0; i < session.TotalChunks; i++ {
					chunkKeys[i] = fmt.Sprintf("%schunk_%05d", session.MinioPrefix, i)
				}

				dstKey := "files/" + session.SHA256
				// Wrapped in errComposeFailed so the post-transaction handler
				// can distinguish this failure and mark the session "failed".
				if _, err := s.storage.ComposeObject(ctx, s.cfg.Bucket, dstKey, chunkKeys); err != nil {
					return errComposeFailed{err: err}
				}
			}

			blobRecord := &model.FileBlob{
				SHA256:   session.SHA256,
				MinioKey: "files/" + session.SHA256,
				FileSize: session.FileSize,
				MimeType: session.MimeType,
				RefCount: 1,
			}
			if session.TotalChunks == 0 {
				blobRecord.MinioKey = "files/" + session.SHA256
				blobRecord.FileSize = 0
			}
			if err := tx.WithContext(ctx).Create(blobRecord).Error; err != nil {
				return fmt.Errorf("create blob: %w", err)
			}
		}

		file := &model.File{
			Name:       session.FileName,
			FolderID:   session.FolderID,
			BlobSHA256: session.SHA256,
		}
		if err := tx.WithContext(ctx).Create(file).Error; err != nil {
			return fmt.Errorf("create file: %w", err)
		}

		if err := updateStatusTx(tx, ctx, sessionID, "completed"); err != nil {
			return fmt.Errorf("update status to completed: %w", err)
		}

		fileResp = &model.FileResponse{
			ID:        file.ID,
			Name:      file.Name,
			FolderID:  file.FolderID,
			Size:      session.FileSize,
			MimeType:  session.MimeType,
			SHA256:    session.SHA256,
			CreatedAt: file.CreatedAt,
			UpdatedAt: file.UpdatedAt,
		}

		return nil
	})

	if err != nil {
		var cfe errComposeFailed
		if errors.As(err, &cfe) {
			// The transaction rolled back (session is back to its prior
			// status); mark it "failed" so clients know to re-upload/retry.
			// Best-effort: a failure here is only logged.
			if statusErr := s.uploadStore.UpdateSessionStatus(ctx, sessionID, "failed"); statusErr != nil {
				s.logger.Warn("failed to mark session as failed", zap.Int64("session_id", sessionID), zap.Error(statusErr))
			}
			return nil, fmt.Errorf("compose object: %w", cfe.err)
		}
		return nil, err
	}

	// Commit succeeded: delete temp chunk objects in the background so the
	// response is not delayed by storage cleanup. Failures are logged only.
	if totalChunks > 0 {
		keys := make([]string, totalChunks)
		for i := 0; i < totalChunks; i++ {
			keys[i] = fmt.Sprintf("%schunk_%05d", minioPrefix, i)
		}
		go func() {
			// Detached from the request context on purpose: cleanup should
			// proceed even after the HTTP request finishes.
			bgCtx := context.Background()
			if delErr := s.storage.RemoveObjects(bgCtx, s.cfg.Bucket, keys, storage.RemoveObjectsOptions{}); delErr != nil {
				s.logger.Warn("delete temp chunks", zap.Error(delErr))
			}
		}()
	}

	return fileResp, nil
}
|
|
|
|
func (s *UploadService) CancelUpload(ctx context.Context, sessionID int64) error {
|
|
session, err := s.uploadStore.GetSession(ctx, sessionID)
|
|
if err != nil {
|
|
return fmt.Errorf("get session: %w", err)
|
|
}
|
|
if session == nil {
|
|
return fmt.Errorf("session not found")
|
|
}
|
|
|
|
if err := s.uploadStore.UpdateSessionStatus(ctx, sessionID, "cancelled"); err != nil {
|
|
return fmt.Errorf("update status to cancelled: %w", err)
|
|
}
|
|
|
|
if session.TotalChunks > 0 {
|
|
keys, listErr := s.listChunkKeys(ctx, sessionID)
|
|
if listErr != nil {
|
|
s.logger.Warn("list chunk keys for cancel", zap.Error(listErr))
|
|
} else if len(keys) > 0 {
|
|
if delErr := s.storage.RemoveObjects(ctx, s.cfg.Bucket, keys, storage.RemoveObjectsOptions{}); delErr != nil {
|
|
s.logger.Warn("remove chunk objects for cancel", zap.Error(delErr))
|
|
}
|
|
}
|
|
}
|
|
|
|
if err := s.uploadStore.DeleteSession(ctx, sessionID); err != nil {
|
|
return fmt.Errorf("delete session: %w", err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (s *UploadService) listChunkKeys(ctx context.Context, sessionID int64) ([]string, error) {
|
|
session, err := s.uploadStore.GetSession(ctx, sessionID)
|
|
if err != nil || session == nil {
|
|
return nil, fmt.Errorf("get session: %w", err)
|
|
}
|
|
|
|
keys := make([]string, session.TotalChunks)
|
|
for i := 0; i < session.TotalChunks; i++ {
|
|
keys[i] = fmt.Sprintf("%schunk_%05d", session.MinioPrefix, i)
|
|
}
|
|
return keys, nil
|
|
}
|
|
|
|
func updateStatusTx(tx *gorm.DB, ctx context.Context, id int64, status string) error {
|
|
result := tx.WithContext(ctx).Model(&model.UploadSession{}).Where("id = ?", id).Update("status", status)
|
|
if result.Error != nil {
|
|
return result.Error
|
|
}
|
|
if result.RowsAffected == 0 {
|
|
return gorm.ErrRecordNotFound
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func isSQLite(db *gorm.DB) bool {
|
|
return db.Dialector.Name() == "sqlite"
|
|
}
|
|
|
|
func objectKeys(objects []storage.ObjectInfo) []string {
|
|
keys := make([]string, len(objects))
|
|
for i, o := range objects {
|
|
keys[i] = o.Key
|
|
}
|
|
return keys
|
|
}
|
|
|
|
type errComposeFailed struct {
|
|
err error
|
|
}
|
|
|
|
func (e errComposeFailed) Error() string {
|
|
return fmt.Sprintf("compose failed: %v", e.err)
|
|
}
|