Files
hpc/internal/service/upload_service.go
dailz 79870333cb fix(service): tolerate concurrent pending-to-uploading status race in UploadChunk
When multiple chunk uploads race on the pending→uploading transition, ignore ErrRecordNotFound from UpdateSessionStatus since another request already completed the update.

Ultraworked with [Sisyphus](https://github.com/code-yeongyu/oh-my-openagent)

Co-authored-by: Sisyphus <clio-agent@sisyphuslabs.ai>
2026-04-15 10:27:12 +08:00

444 lines
12 KiB
Go

package service
import (
"context"
"crypto/sha256"
"encoding/hex"
"errors"
"fmt"
"io"
"time"
"gcy_hpc_server/internal/config"
"gcy_hpc_server/internal/model"
"gcy_hpc_server/internal/storage"
"gcy_hpc_server/internal/store"
"go.uber.org/zap"
"gorm.io/gorm"
"gorm.io/gorm/clause"
)
// UploadService implements chunked file uploads with content-addressed
// deduplication: blob payloads are stored once per SHA-256 in object
// storage and shared between files via a reference count.
type UploadService struct {
	storage     storage.ObjectStorage // object store holding temp chunks and merged blobs
	blobStore   *store.BlobStore      // content-addressed, ref-counted blob records
	fileStore   *store.FileStore      // logical file entries pointing at blobs
	uploadStore *store.UploadStore    // upload sessions and per-chunk progress
	cfg         config.MinioConfig    // bucket name, size limits, chunk sizing, session TTL
	db          *gorm.DB              // used directly for the transactional completion path
	logger      *zap.Logger
}
// NewUploadService wires the storage backend, persistence stores,
// configuration, database handle, and logger into a ready-to-use
// UploadService.
func NewUploadService(
	objStorage storage.ObjectStorage,
	blobs *store.BlobStore,
	files *store.FileStore,
	uploads *store.UploadStore,
	minioCfg config.MinioConfig,
	gormDB *gorm.DB,
	log *zap.Logger,
) *UploadService {
	svc := &UploadService{
		storage:     objStorage,
		blobStore:   blobs,
		fileStore:   files,
		uploadStore: uploads,
		cfg:         minioCfg,
		db:          gormDB,
		logger:      log,
	}
	return svc
}
// InitUpload validates an upload request and starts a chunked upload.
//
// Fast path: when a blob with the requested SHA-256 already exists, no
// bytes need to be transferred — the blob's reference count is bumped,
// a new file row pointing at it is created, and a FileResponse is
// returned immediately. Otherwise a pending UploadSession is created
// and returned so the client can begin sending chunks.
func (s *UploadService) InitUpload(ctx context.Context, req model.InitUploadRequest) (interface{}, error) {
	// --- request validation ---
	if err := model.ValidateFileName(req.FileName); err != nil {
		return nil, fmt.Errorf("invalid file name: %w", err)
	}
	if req.FileSize < 0 {
		return nil, fmt.Errorf("file size cannot be negative")
	}
	if req.FileSize > s.cfg.MaxFileSize {
		return nil, fmt.Errorf("file size %d exceeds maximum %d", req.FileSize, s.cfg.MaxFileSize)
	}

	// The client may override the server's default chunk size, subject
	// to the configured minimum.
	effectiveChunk := s.cfg.ChunkSize
	if req.ChunkSize != nil {
		effectiveChunk = *req.ChunkSize
	}
	if effectiveChunk < s.cfg.MinChunkSize {
		return nil, fmt.Errorf("chunk size %d is below minimum %d", effectiveChunk, s.cfg.MinChunkSize)
	}

	// Ceiling division; an empty file needs no chunks at all.
	numChunks := int((req.FileSize + effectiveChunk - 1) / effectiveChunk)
	if req.FileSize == 0 {
		numChunks = 0
	}
	if numChunks > 10000 {
		return nil, fmt.Errorf("total chunks %d exceeds limit of 10000", numChunks)
	}

	// --- dedup fast path ---
	existing, err := s.blobStore.GetBySHA256(ctx, req.SHA256)
	if err != nil {
		return nil, fmt.Errorf("check dedup: %w", err)
	}
	if existing != nil {
		// NOTE(review): the ref increment and the file creation below are
		// not atomic; a Create failure leaks one reference — confirm
		// whether a compensating decrement is needed here.
		if err := s.blobStore.IncrementRef(ctx, req.SHA256); err != nil {
			return nil, fmt.Errorf("increment ref: %w", err)
		}
		newFile := &model.File{
			Name:       req.FileName,
			FolderID:   req.FolderID,
			BlobSHA256: req.SHA256,
		}
		if err := s.fileStore.Create(ctx, newFile); err != nil {
			return nil, fmt.Errorf("create file: %w", err)
		}
		return model.FileResponse{
			ID:        newFile.ID,
			Name:      newFile.Name,
			FolderID:  newFile.FolderID,
			Size:      existing.FileSize,
			MimeType:  existing.MimeType,
			SHA256:    existing.SHA256,
			CreatedAt: newFile.CreatedAt,
			UpdatedAt: newFile.UpdatedAt,
		}, nil
	}

	// --- new upload: create a pending session ---
	contentType := req.MimeType
	if contentType == "" {
		contentType = "application/octet-stream"
	}
	session := &model.UploadSession{
		FileName:    req.FileName,
		FileSize:    req.FileSize,
		ChunkSize:   effectiveChunk,
		TotalChunks: numChunks,
		SHA256:      req.SHA256,
		FolderID:    req.FolderID,
		Status:      "pending",
		MinioPrefix: fmt.Sprintf("uploads/%d/", time.Now().UnixNano()),
		MimeType:    contentType,
		ExpiresAt:   time.Now().Add(time.Duration(s.cfg.SessionTTL) * time.Hour),
	}
	if err := s.uploadStore.CreateSession(ctx, session); err != nil {
		return nil, fmt.Errorf("create session: %w", err)
	}
	return model.UploadSessionResponse{
		ID:             session.ID,
		FileName:       session.FileName,
		FileSize:       session.FileSize,
		ChunkSize:      session.ChunkSize,
		TotalChunks:    session.TotalChunks,
		SHA256:         session.SHA256,
		Status:         session.Status,
		UploadedChunks: []int{},
		ExpiresAt:      session.ExpiresAt,
		CreatedAt:      session.CreatedAt,
	}, nil
}
// UploadChunk streams one chunk of an active session into object
// storage, hashing it on the fly, and records its metadata (key,
// SHA-256, size).
//
// The chunk is accepted only when the session exists, is in an
// uploadable state ("pending", "uploading", or "failed"), the index is
// in range, and the payload size matches the session's chunking scheme.
// On the first chunk the session transitions pending→uploading; a
// concurrent upload may win that race, in which case the resulting
// ErrRecordNotFound from the conditional update is deliberately ignored.
func (s *UploadService) UploadChunk(ctx context.Context, sessionID int64, chunkIndex int, reader io.Reader, size int64) error {
	session, err := s.uploadStore.GetSession(ctx, sessionID)
	if err != nil {
		return fmt.Errorf("get session: %w", err)
	}
	if session == nil {
		return fmt.Errorf("session not found")
	}
	switch session.Status {
	case "pending", "uploading", "failed":
	default:
		return fmt.Errorf("cannot upload to session with status %q", session.Status)
	}
	if chunkIndex < 0 || chunkIndex >= session.TotalChunks {
		return fmt.Errorf("chunk index %d out of range [0, %d)", chunkIndex, session.TotalChunks)
	}
	// Reject payloads that do not fit the session's chunking scheme:
	// every chunk except the last must be exactly ChunkSize bytes, and
	// the last chunk carries the remainder. Without this check an
	// oversized or truncated chunk would pass the per-chunk count in
	// CompleteUpload and be composed into a corrupt blob.
	expectedSize := session.ChunkSize
	if chunkIndex == session.TotalChunks-1 {
		expectedSize = session.FileSize - int64(session.TotalChunks-1)*session.ChunkSize
	}
	if size != expectedSize {
		return fmt.Errorf("chunk %d size %d does not match expected size %d", chunkIndex, size, expectedSize)
	}
	key := fmt.Sprintf("%schunk_%05d", session.MinioPrefix, chunkIndex)
	// Hash while streaming to storage so the payload is read only once.
	hasher := sha256.New()
	teeReader := io.TeeReader(reader, hasher)
	_, err = s.storage.PutObject(ctx, s.cfg.Bucket, key, teeReader, size, storage.PutObjectOptions{
		DisableMultipart: true,
	})
	if err != nil {
		return fmt.Errorf("put object: %w", err)
	}
	chunkSHA256 := hex.EncodeToString(hasher.Sum(nil))
	chunk := &model.UploadChunk{
		SessionID:  sessionID,
		ChunkIndex: chunkIndex,
		MinioKey:   key,
		SHA256:     chunkSHA256,
		Size:       size,
		Status:     "uploaded",
	}
	// Upsert so retrying the same chunk index overwrites cleanly.
	if err := s.uploadStore.UpsertChunk(ctx, chunk); err != nil {
		return fmt.Errorf("upsert chunk: %w", err)
	}
	if session.Status == "pending" {
		// First chunk of this session: move pending→uploading. A
		// concurrent chunk upload may have already performed the
		// transition, so a not-found result from the conditional
		// update is benign.
		if err := s.uploadStore.UpdateSessionStatus(ctx, sessionID, "uploading"); err != nil {
			if !errors.Is(err, gorm.ErrRecordNotFound) {
				return fmt.Errorf("update status to uploading: %w", err)
			}
		}
	}
	return nil
}
// GetUploadStatus reports the current state of an upload session,
// including the indices of every chunk that has finished uploading.
func (s *UploadService) GetUploadStatus(ctx context.Context, sessionID int64) (*model.UploadSessionResponse, error) {
	session, chunks, err := s.uploadStore.GetSessionWithChunks(ctx, sessionID)
	if err != nil {
		return nil, fmt.Errorf("get session: %w", err)
	}
	if session == nil {
		return nil, fmt.Errorf("session not found")
	}
	// Collect only chunks that completed; anything else is omitted so
	// the client knows which indices still need (re)uploading.
	done := make([]int, 0, len(chunks))
	for _, chunk := range chunks {
		if chunk.Status != "uploaded" {
			continue
		}
		done = append(done, chunk.ChunkIndex)
	}
	resp := &model.UploadSessionResponse{
		ID:             session.ID,
		FileName:       session.FileName,
		FileSize:       session.FileSize,
		ChunkSize:      session.ChunkSize,
		TotalChunks:    session.TotalChunks,
		SHA256:         session.SHA256,
		Status:         session.Status,
		UploadedChunks: done,
		ExpiresAt:      session.ExpiresAt,
		CreatedAt:      session.CreatedAt,
	}
	return resp, nil
}
// CompleteUpload finalizes an upload session: inside one transaction it
// verifies all chunks are present, merges them into a content-addressed
// blob (or bumps the ref count if an identical blob already exists),
// creates the file row, and marks the session completed. Temporary
// chunk objects are deleted asynchronously after the transaction
// commits. A compose failure marks the session "failed" so the client
// can retry completion.
func (s *UploadService) CompleteUpload(ctx context.Context, sessionID int64) (*model.FileResponse, error) {
	// Captured inside the transaction for use in post-commit cleanup.
	var fileResp *model.FileResponse
	var totalChunks int
	var minioPrefix string
	err := s.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
		var session model.UploadSession
		query := tx.WithContext(ctx)
		// SQLite does not support SELECT ... FOR UPDATE; skip row
		// locking there (single-writer semantics make it unnecessary).
		if !isSQLite(tx) {
			query = query.Clauses(clause.Locking{Strength: "UPDATE"})
		}
		if err := query.First(&session, sessionID).Error; err != nil {
			return fmt.Errorf("get session: %w", err)
		}
		// Only sessions that started uploading, or whose previous merge
		// attempt failed, may be completed.
		switch session.Status {
		case "uploading", "failed":
		default:
			return fmt.Errorf("cannot complete session with status %q", session.Status)
		}
		// Every expected chunk must be recorded as uploaded before merging.
		if session.TotalChunks > 0 {
			var chunkCount int64
			if err := tx.WithContext(ctx).Model(&model.UploadChunk{}).
				Where("session_id = ? AND status = ?", sessionID, "uploaded").
				Count(&chunkCount).Error; err != nil {
				return fmt.Errorf("count chunks: %w", err)
			}
			if int(chunkCount) != session.TotalChunks {
				return fmt.Errorf("not all chunks uploaded: %d/%d", chunkCount, session.TotalChunks)
			}
		}
		totalChunks = session.TotalChunks
		minioPrefix = session.MinioPrefix
		if err := updateStatusTx(tx, ctx, sessionID, "merging"); err != nil {
			return fmt.Errorf("update status to merging: %w", err)
		}
		// Lock the blob row (if any) so concurrent completions of the
		// same content serialize on the ref count.
		blob, err := s.blobStore.GetBySHA256ForUpdate(ctx, tx, session.SHA256)
		if err != nil {
			return fmt.Errorf("get blob for update: %w", err)
		}
		if blob != nil {
			// Dedup: identical content already stored — just add a reference.
			result := tx.WithContext(ctx).Model(&model.FileBlob{}).
				Where("sha256 = ?", session.SHA256).
				UpdateColumn("ref_count", gorm.Expr("ref_count + 1"))
			if result.Error != nil {
				return fmt.Errorf("increment ref: %w", result.Error)
			}
			if result.RowsAffected == 0 {
				return fmt.Errorf("blob not found for ref increment")
			}
		} else {
			// New content: compose the chunk objects into the final blob
			// object keyed by content hash.
			if session.TotalChunks > 0 {
				chunkKeys := make([]string, session.TotalChunks)
				for i := 0; i < session.TotalChunks; i++ {
					chunkKeys[i] = fmt.Sprintf("%schunk_%05d", session.MinioPrefix, i)
				}
				dstKey := "files/" + session.SHA256
				if _, err := s.storage.ComposeObject(ctx, s.cfg.Bucket, dstKey, chunkKeys); err != nil {
					// Wrapped in a sentinel type so the post-transaction
					// handler can mark the session "failed".
					return errComposeFailed{err: err}
				}
			}
			blobRecord := &model.FileBlob{
				SHA256:   session.SHA256,
				MinioKey: "files/" + session.SHA256,
				FileSize: session.FileSize,
				MimeType: session.MimeType,
				RefCount: 1,
			}
			// Zero-chunk session means an empty file; no object was
			// composed for it. NOTE(review): MinioKey is re-assigned to
			// the same value here — presumably a leftover; confirm the
			// empty-file blob is intentionally left without a backing object.
			if session.TotalChunks == 0 {
				blobRecord.MinioKey = "files/" + session.SHA256
				blobRecord.FileSize = 0
			}
			if err := tx.WithContext(ctx).Create(blobRecord).Error; err != nil {
				return fmt.Errorf("create blob: %w", err)
			}
		}
		file := &model.File{
			Name:       session.FileName,
			FolderID:   session.FolderID,
			BlobSHA256: session.SHA256,
		}
		if err := tx.WithContext(ctx).Create(file).Error; err != nil {
			return fmt.Errorf("create file: %w", err)
		}
		if err := updateStatusTx(tx, ctx, sessionID, "completed"); err != nil {
			return fmt.Errorf("update status to completed: %w", err)
		}
		fileResp = &model.FileResponse{
			ID:        file.ID,
			Name:      file.Name,
			FolderID:  file.FolderID,
			Size:      session.FileSize,
			MimeType:  session.MimeType,
			SHA256:    session.SHA256,
			CreatedAt: file.CreatedAt,
			UpdatedAt: file.UpdatedAt,
		}
		return nil
	})
	if err != nil {
		// Compose failures are recoverable: mark the session "failed"
		// (UploadChunk and CompleteUpload both accept that status) and
		// surface the underlying storage error.
		var cfe errComposeFailed
		if errors.As(err, &cfe) {
			if statusErr := s.uploadStore.UpdateSessionStatus(ctx, sessionID, "failed"); statusErr != nil {
				s.logger.Warn("failed to mark session as failed", zap.Int64("session_id", sessionID), zap.Error(statusErr))
			}
			return nil, fmt.Errorf("compose object: %w", cfe.err)
		}
		return nil, err
	}
	// Best-effort async cleanup of the temporary chunk objects; the
	// upload already succeeded, so failures are only logged.
	if totalChunks > 0 {
		keys := make([]string, totalChunks)
		for i := 0; i < totalChunks; i++ {
			keys[i] = fmt.Sprintf("%schunk_%05d", minioPrefix, i)
		}
		go func() {
			// Detached from the request context so cleanup survives the
			// HTTP request ending.
			bgCtx := context.Background()
			if delErr := s.storage.RemoveObjects(bgCtx, s.cfg.Bucket, keys, storage.RemoveObjectsOptions{}); delErr != nil {
				s.logger.Warn("delete temp chunks", zap.Error(delErr))
			}
		}()
	}
	return fileResp, nil
}
// CancelUpload aborts an upload session: it marks the session
// cancelled, best-effort deletes any chunk objects already written to
// storage, and removes the session record.
//
// Sessions that are already completed or currently merging are
// rejected — cancelling them would delete the record of a finished
// upload or race a concurrent CompleteUpload.
func (s *UploadService) CancelUpload(ctx context.Context, sessionID int64) error {
	session, err := s.uploadStore.GetSession(ctx, sessionID)
	if err != nil {
		return fmt.Errorf("get session: %w", err)
	}
	if session == nil {
		return fmt.Errorf("session not found")
	}
	// Guard against cancelling sessions past the point of no return.
	switch session.Status {
	case "completed", "merging":
		return fmt.Errorf("cannot cancel session with status %q", session.Status)
	}
	if err := s.uploadStore.UpdateSessionStatus(ctx, sessionID, "cancelled"); err != nil {
		return fmt.Errorf("update status to cancelled: %w", err)
	}
	// Chunk cleanup is best-effort: failures are logged, not returned,
	// so a storage hiccup never blocks the cancel itself.
	if session.TotalChunks > 0 {
		keys, listErr := s.listChunkKeys(ctx, sessionID)
		if listErr != nil {
			s.logger.Warn("list chunk keys for cancel", zap.Error(listErr))
		} else if len(keys) > 0 {
			if delErr := s.storage.RemoveObjects(ctx, s.cfg.Bucket, keys, storage.RemoveObjectsOptions{}); delErr != nil {
				s.logger.Warn("remove chunk objects for cancel", zap.Error(delErr))
			}
		}
	}
	if err := s.uploadStore.DeleteSession(ctx, sessionID); err != nil {
		return fmt.Errorf("delete session: %w", err)
	}
	return nil
}
// listChunkKeys derives the object-storage keys of every chunk in a
// session from its prefix and chunk count (keys are deterministic, so
// no storage listing is required).
func (s *UploadService) listChunkKeys(ctx context.Context, sessionID int64) ([]string, error) {
	session, err := s.uploadStore.GetSession(ctx, sessionID)
	if err != nil {
		return nil, fmt.Errorf("get session: %w", err)
	}
	// A missing session must be reported separately: the previous code
	// wrapped a nil error with %w here, yielding "get session: %!w(<nil>)".
	if session == nil {
		return nil, fmt.Errorf("session not found")
	}
	keys := make([]string, session.TotalChunks)
	for i := 0; i < session.TotalChunks; i++ {
		keys[i] = fmt.Sprintf("%schunk_%05d", session.MinioPrefix, i)
	}
	return keys, nil
}
// updateStatusTx sets the status of an upload session within the given
// transaction, returning gorm.ErrRecordNotFound when no row matched so
// callers can distinguish "session gone" from database errors.
func updateStatusTx(tx *gorm.DB, ctx context.Context, id int64, status string) error {
	res := tx.WithContext(ctx).
		Model(&model.UploadSession{}).
		Where("id = ?", id).
		Update("status", status)
	switch {
	case res.Error != nil:
		return res.Error
	case res.RowsAffected == 0:
		return gorm.ErrRecordNotFound
	default:
		return nil
	}
}
// isSQLite reports whether the connection's GORM dialector is SQLite,
// which is used to skip row-locking clauses SQLite does not support.
func isSQLite(db *gorm.DB) bool {
	const sqliteDialector = "sqlite"
	return db.Dialector.Name() == sqliteDialector
}
// objectKeys extracts the storage key from each ObjectInfo entry,
// preserving order.
func objectKeys(objects []storage.ObjectInfo) []string {
	keys := make([]string, 0, len(objects))
	for _, obj := range objects {
		keys = append(keys, obj.Key)
	}
	return keys
}
// errComposeFailed wraps a storage ComposeObject error so that
// CompleteUpload can distinguish compose failures (which mark the
// session "failed" and are retryable) from other transaction errors.
type errComposeFailed struct {
	err error
}

// Error implements the error interface.
func (e errComposeFailed) Error() string {
	return fmt.Sprintf("compose failed: %v", e.err)
}

// Unwrap exposes the underlying storage error so errors.Is/errors.As
// chains can reach it through this wrapper.
func (e errComposeFailed) Unwrap() error {
	return e.err
}