feat(store): add TaskStore CRUD and batch query methods for files and blobs

This commit is contained in:
dailz
2026-04-15 21:30:51 +08:00
parent d46a784efb
commit acf8c1d62b
7 changed files with 520 additions and 0 deletions

View File

@@ -82,6 +82,15 @@ func (s *BlobStore) Delete(ctx context.Context, sha256 string) error {
return nil return nil
} }
func (s *BlobStore) GetBySHA256s(ctx context.Context, sha256s []string) ([]model.FileBlob, error) {
var blobs []model.FileBlob
if len(sha256s) == 0 {
return blobs, nil
}
err := s.db.WithContext(ctx).Where("sha256 IN ?", sha256s).Find(&blobs).Error
return blobs, err
}
// GetBySHA256ForUpdate returns the FileBlob with a SELECT ... FOR UPDATE lock. // GetBySHA256ForUpdate returns the FileBlob with a SELECT ... FOR UPDATE lock.
// Returns (nil, nil) if not found. // Returns (nil, nil) if not found.
func (s *BlobStore) GetBySHA256ForUpdate(ctx context.Context, tx *gorm.DB, sha256 string) (*model.FileBlob, error) { func (s *BlobStore) GetBySHA256ForUpdate(ctx context.Context, tx *gorm.DB, sha256 string) (*model.FileBlob, error) {

View File

@@ -133,6 +133,59 @@ func TestBlobStore_Delete(t *testing.T) {
} }
} }
func TestBlobStore_GetBySHA256s(t *testing.T) {
db := setupBlobTestDB(t)
store := NewBlobStore(db)
ctx := context.Background()
store.Create(ctx, &model.FileBlob{SHA256: "h1", MinioKey: "files/h1", FileSize: 100})
store.Create(ctx, &model.FileBlob{SHA256: "h2", MinioKey: "files/h2", FileSize: 200})
store.Create(ctx, &model.FileBlob{SHA256: "h3", MinioKey: "files/h3", FileSize: 300})
blobs, err := store.GetBySHA256s(ctx, []string{"h1", "h3"})
if err != nil {
t.Fatalf("GetBySHA256s() error = %v", err)
}
if len(blobs) != 2 {
t.Fatalf("len(blobs) = %d, want 2", len(blobs))
}
keys := map[string]bool{}
for _, b := range blobs {
keys[b.SHA256] = true
}
if !keys["h1"] || !keys["h3"] {
t.Errorf("expected h1 and h3, got %v", blobs)
}
}
func TestBlobStore_GetBySHA256s_Empty(t *testing.T) {
db := setupBlobTestDB(t)
store := NewBlobStore(db)
ctx := context.Background()
blobs, err := store.GetBySHA256s(ctx, []string{})
if err != nil {
t.Fatalf("GetBySHA256s() error = %v", err)
}
if len(blobs) != 0 {
t.Errorf("len(blobs) = %d, want 0", len(blobs))
}
}
func TestBlobStore_GetBySHA256s_NotFound(t *testing.T) {
db := setupBlobTestDB(t)
store := NewBlobStore(db)
ctx := context.Background()
blobs, err := store.GetBySHA256s(ctx, []string{"nonexistent"})
if err != nil {
t.Fatalf("GetBySHA256s() error = %v", err)
}
if len(blobs) != 0 {
t.Errorf("len(blobs) = %d, want 0 for non-existent SHA256s", len(blobs))
}
}
func TestBlobStore_SHA256_UniqueConstraint(t *testing.T) { func TestBlobStore_SHA256_UniqueConstraint(t *testing.T) {
db := setupBlobTestDB(t) db := setupBlobTestDB(t)
store := NewBlobStore(db) store := NewBlobStore(db)

View File

@@ -86,6 +86,15 @@ func (s *FileStore) CountByBlobSHA256(ctx context.Context, blobSHA256 string) (i
return count, err return count, err
} }
func (s *FileStore) GetByIDs(ctx context.Context, ids []int64) ([]model.File, error) {
var files []model.File
if len(ids) == 0 {
return files, nil
}
err := s.db.WithContext(ctx).Where("id IN ?", ids).Find(&files).Error
return files, err
}
func (s *FileStore) GetBlobSHA256ByID(ctx context.Context, id int64) (string, error) { func (s *FileStore) GetBlobSHA256ByID(ctx context.Context, id int64) (string, error) {
var file model.File var file model.File
err := s.db.WithContext(ctx).Select("blob_sha256").First(&file, id).Error err := s.db.WithContext(ctx).Select("blob_sha256").First(&file, id).Error

View File

@@ -230,6 +230,84 @@ func TestFileStore_GetBlobSHA256ByID(t *testing.T) {
} }
} }
func TestFileStore_GetByIDs(t *testing.T) {
db := setupFileTestDB(t)
store := NewFileStore(db)
ctx := context.Background()
store.Create(ctx, &model.File{Name: "a.bin", BlobSHA256: "h1"})
store.Create(ctx, &model.File{Name: "b.bin", BlobSHA256: "h2"})
store.Create(ctx, &model.File{Name: "c.bin", BlobSHA256: "h3"})
files, err := store.GetByIDs(ctx, []int64{1, 3})
if err != nil {
t.Fatalf("GetByIDs() error = %v", err)
}
if len(files) != 2 {
t.Fatalf("len(files) = %d, want 2", len(files))
}
names := map[string]bool{}
for _, f := range files {
names[f.Name] = true
}
if !names["a.bin"] || !names["c.bin"] {
t.Errorf("expected a.bin and c.bin, got %v", files)
}
}
func TestFileStore_GetByIDs_Empty(t *testing.T) {
db := setupFileTestDB(t)
store := NewFileStore(db)
ctx := context.Background()
files, err := store.GetByIDs(ctx, []int64{})
if err != nil {
t.Fatalf("GetByIDs() error = %v", err)
}
if len(files) != 0 {
t.Errorf("len(files) = %d, want 0", len(files))
}
}
func TestFileStore_GetByIDs_NotFound(t *testing.T) {
db := setupFileTestDB(t)
store := NewFileStore(db)
ctx := context.Background()
files, err := store.GetByIDs(ctx, []int64{999})
if err != nil {
t.Fatalf("GetByIDs() error = %v", err)
}
if len(files) != 0 {
t.Errorf("len(files) = %d, want 0 for non-existent IDs", len(files))
}
}
func TestFileStore_GetByIDs_SoftDeleteExcluded(t *testing.T) {
db := setupFileTestDB(t)
store := NewFileStore(db)
ctx := context.Background()
store.Create(ctx, &model.File{Name: "a.bin", BlobSHA256: "h1"})
store.Create(ctx, &model.File{Name: "b.bin", BlobSHA256: "h2"})
store.Create(ctx, &model.File{Name: "c.bin", BlobSHA256: "h3"})
store.Delete(ctx, 2)
files, err := store.GetByIDs(ctx, []int64{1, 2, 3})
if err != nil {
t.Fatalf("GetByIDs() error = %v", err)
}
if len(files) != 2 {
t.Fatalf("len(files) = %d, want 2 (soft-deleted excluded)", len(files))
}
for _, f := range files {
if f.ID == 2 {
t.Error("soft-deleted file ID 2 should not appear")
}
}
}
func TestFileStore_GetBlobSHA256ByID_NotFound(t *testing.T) { func TestFileStore_GetBlobSHA256ByID_NotFound(t *testing.T) {
db := setupFileTestDB(t) db := setupFileTestDB(t)
store := NewFileStore(db) store := NewFileStore(db)

View File

@@ -47,5 +47,6 @@ func AutoMigrate(db *gorm.DB) error {
&model.Folder{}, &model.Folder{},
&model.UploadSession{}, &model.UploadSession{},
&model.UploadChunk{}, &model.UploadChunk{},
&model.Task{},
) )
} }

View File

@@ -0,0 +1,141 @@
package store
import (
"context"
"errors"
"fmt"
"time"
"gcy_hpc_server/internal/model"
"gorm.io/gorm"
)
type TaskStore struct {
db *gorm.DB
}
func NewTaskStore(db *gorm.DB) *TaskStore {
return &TaskStore{db: db}
}
func (s *TaskStore) Create(ctx context.Context, task *model.Task) (int64, error) {
if err := s.db.WithContext(ctx).Create(task).Error; err != nil {
return 0, err
}
return task.ID, nil
}
func (s *TaskStore) GetByID(ctx context.Context, id int64) (*model.Task, error) {
var task model.Task
err := s.db.WithContext(ctx).First(&task, id).Error
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, nil
}
if err != nil {
return nil, err
}
return &task, nil
}
func (s *TaskStore) List(ctx context.Context, query *model.TaskListQuery) ([]model.Task, int64, error) {
page := query.Page
pageSize := query.PageSize
if page <= 0 {
page = 1
}
if pageSize <= 0 {
pageSize = 10
}
q := s.db.WithContext(ctx).Model(&model.Task{})
if query.Status != "" {
q = q.Where("status = ?", query.Status)
}
var total int64
if err := q.Count(&total).Error; err != nil {
return nil, 0, err
}
var tasks []model.Task
offset := (page - 1) * pageSize
if err := q.Order("id DESC").Limit(pageSize).Offset(offset).Find(&tasks).Error; err != nil {
return nil, 0, err
}
return tasks, total, nil
}
func (s *TaskStore) UpdateStatus(ctx context.Context, id int64, status, errorMsg string) error {
updates := map[string]interface{}{
"status": status,
"error_message": errorMsg,
}
now := time.Now()
switch status {
case model.TaskStatusPreparing, model.TaskStatusDownloading, model.TaskStatusReady,
model.TaskStatusQueued, model.TaskStatusRunning:
updates["started_at"] = &now
case model.TaskStatusCompleted, model.TaskStatusFailed:
updates["finished_at"] = &now
}
result := s.db.WithContext(ctx).Model(&model.Task{}).Where("id = ?", id).Updates(updates)
if result.Error != nil {
return result.Error
}
if result.RowsAffected == 0 {
return fmt.Errorf("task %d not found", id)
}
return nil
}
func (s *TaskStore) UpdateSlurmJobID(ctx context.Context, id int64, slurmJobID *int32) error {
result := s.db.WithContext(ctx).Model(&model.Task{}).Where("id = ?", id).
Update("slurm_job_id", slurmJobID)
if result.Error != nil {
return result.Error
}
if result.RowsAffected == 0 {
return fmt.Errorf("task %d not found", id)
}
return nil
}
func (s *TaskStore) UpdateWorkDir(ctx context.Context, id int64, workDir string) error {
result := s.db.WithContext(ctx).Model(&model.Task{}).Where("id = ?", id).
Update("work_dir", workDir)
if result.Error != nil {
return result.Error
}
if result.RowsAffected == 0 {
return fmt.Errorf("task %d not found", id)
}
return nil
}
func (s *TaskStore) UpdateRetryState(ctx context.Context, id int64, status, currentStep string, retryCount int) error {
result := s.db.WithContext(ctx).Model(&model.Task{}).Where("id = ?", id).Updates(map[string]interface{}{
"status": status,
"current_step": currentStep,
"retry_count": retryCount,
})
if result.Error != nil {
return result.Error
}
if result.RowsAffected == 0 {
return fmt.Errorf("task %d not found", id)
}
return nil
}
func (s *TaskStore) GetStuckTasks(ctx context.Context, maxAge time.Duration) ([]model.Task, error) {
cutoff := time.Now().Add(-maxAge)
var tasks []model.Task
err := s.db.WithContext(ctx).
Where("status NOT IN ?", []string{model.TaskStatusCompleted, model.TaskStatusFailed}).
Where("updated_at < ?", cutoff).
Find(&tasks).Error
return tasks, err
}

View File

@@ -0,0 +1,229 @@
package store
import (
"context"
"testing"
"time"
"gcy_hpc_server/internal/model"
"gorm.io/driver/sqlite"
"gorm.io/gorm"
"gorm.io/gorm/logger"
)
func newTaskTestDB(t *testing.T) *gorm.DB {
t.Helper()
db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{
Logger: logger.Default.LogMode(logger.Silent),
})
if err != nil {
t.Fatalf("open sqlite: %v", err)
}
if err := db.AutoMigrate(&model.Task{}); err != nil {
t.Fatalf("auto migrate: %v", err)
}
return db
}
func makeTestTask(name, status string) *model.Task {
return &model.Task{
TaskName: name,
AppID: 1,
AppName: "test-app",
Status: status,
CurrentStep: "",
RetryCount: 0,
UserID: "user1",
SubmittedAt: time.Now(),
}
}
func TestTaskStore_CreateAndGetByID(t *testing.T) {
db := newTaskTestDB(t)
s := NewTaskStore(db)
ctx := context.Background()
task := makeTestTask("test-task", model.TaskStatusSubmitted)
id, err := s.Create(ctx, task)
if err != nil {
t.Fatalf("Create() error = %v", err)
}
if id <= 0 {
t.Errorf("Create() id = %d, want positive", id)
}
got, err := s.GetByID(ctx, id)
if err != nil {
t.Fatalf("GetByID() error = %v", err)
}
if got == nil {
t.Fatal("GetByID() returned nil")
}
if got.TaskName != "test-task" {
t.Errorf("TaskName = %q, want %q", got.TaskName, "test-task")
}
if got.Status != model.TaskStatusSubmitted {
t.Errorf("Status = %q, want %q", got.Status, model.TaskStatusSubmitted)
}
if got.ID != id {
t.Errorf("ID = %d, want %d", got.ID, id)
}
}
func TestTaskStore_GetByID_NotFound(t *testing.T) {
db := newTaskTestDB(t)
s := NewTaskStore(db)
got, err := s.GetByID(context.Background(), 999)
if err != nil {
t.Fatalf("GetByID() error = %v", err)
}
if got != nil {
t.Error("GetByID() expected nil for not-found, got non-nil")
}
}
func TestTaskStore_ListPagination(t *testing.T) {
db := newTaskTestDB(t)
s := NewTaskStore(db)
ctx := context.Background()
for i := 0; i < 5; i++ {
s.Create(ctx, makeTestTask("task-"+string(rune('A'+i)), model.TaskStatusSubmitted))
}
tasks, total, err := s.List(ctx, &model.TaskListQuery{Page: 1, PageSize: 3})
if err != nil {
t.Fatalf("List() error = %v", err)
}
if total != 5 {
t.Errorf("total = %d, want 5", total)
}
if len(tasks) != 3 {
t.Errorf("len(tasks) = %d, want 3", len(tasks))
}
tasks2, total2, err := s.List(ctx, &model.TaskListQuery{Page: 2, PageSize: 3})
if err != nil {
t.Fatalf("List() page 2 error = %v", err)
}
if total2 != 5 {
t.Errorf("total2 = %d, want 5", total2)
}
if len(tasks2) != 2 {
t.Errorf("len(tasks2) = %d, want 2", len(tasks2))
}
}
func TestTaskStore_ListStatusFilter(t *testing.T) {
db := newTaskTestDB(t)
s := NewTaskStore(db)
ctx := context.Background()
s.Create(ctx, makeTestTask("running-1", model.TaskStatusRunning))
s.Create(ctx, makeTestTask("running-2", model.TaskStatusRunning))
s.Create(ctx, makeTestTask("completed-1", model.TaskStatusCompleted))
tasks, total, err := s.List(ctx, &model.TaskListQuery{Page: 1, PageSize: 10, Status: model.TaskStatusRunning})
if err != nil {
t.Fatalf("List() error = %v", err)
}
if total != 2 {
t.Errorf("total = %d, want 2", total)
}
for _, t2 := range tasks {
if t2.Status != model.TaskStatusRunning {
t.Errorf("Status = %q, want %q", t2.Status, model.TaskStatusRunning)
}
}
}
func TestTaskStore_UpdateStatus(t *testing.T) {
db := newTaskTestDB(t)
s := NewTaskStore(db)
ctx := context.Background()
id, _ := s.Create(ctx, makeTestTask("status-test", model.TaskStatusSubmitted))
err := s.UpdateStatus(ctx, id, model.TaskStatusPreparing, "")
if err != nil {
t.Fatalf("UpdateStatus() error = %v", err)
}
got, _ := s.GetByID(ctx, id)
if got.Status != model.TaskStatusPreparing {
t.Errorf("Status = %q, want %q", got.Status, model.TaskStatusPreparing)
}
if got.StartedAt == nil {
t.Error("StartedAt expected non-nil after preparing status")
}
err = s.UpdateStatus(ctx, id, model.TaskStatusFailed, "something broke")
if err != nil {
t.Fatalf("UpdateStatus(failed) error = %v", err)
}
got, _ = s.GetByID(ctx, id)
if got.ErrorMessage != "something broke" {
t.Errorf("ErrorMessage = %q, want %q", got.ErrorMessage, "something broke")
}
if got.FinishedAt == nil {
t.Error("FinishedAt expected non-nil after failed status")
}
}
func TestTaskStore_UpdateRetryState(t *testing.T) {
db := newTaskTestDB(t)
s := NewTaskStore(db)
ctx := context.Background()
task := makeTestTask("retry-test", model.TaskStatusFailed)
task.CurrentStep = model.TaskStepDownloading
task.RetryCount = 1
id, _ := s.Create(ctx, task)
err := s.UpdateRetryState(ctx, id, model.TaskStatusSubmitted, model.TaskStepDownloading, 2)
if err != nil {
t.Fatalf("UpdateRetryState() error = %v", err)
}
got, _ := s.GetByID(ctx, id)
if got.Status != model.TaskStatusSubmitted {
t.Errorf("Status = %q, want %q", got.Status, model.TaskStatusSubmitted)
}
if got.CurrentStep != model.TaskStepDownloading {
t.Errorf("CurrentStep = %q, want %q", got.CurrentStep, model.TaskStepDownloading)
}
if got.RetryCount != 2 {
t.Errorf("RetryCount = %d, want 2", got.RetryCount)
}
}
func TestTaskStore_GetStuckTasks(t *testing.T) {
db := newTaskTestDB(t)
s := NewTaskStore(db)
ctx := context.Background()
stuck := makeTestTask("stuck-1", model.TaskStatusDownloading)
stuck.UpdatedAt = time.Now().Add(-1 * time.Hour)
s.Create(ctx, stuck)
recent := makeTestTask("recent-1", model.TaskStatusDownloading)
s.Create(ctx, recent)
done := makeTestTask("done-1", model.TaskStatusCompleted)
done.UpdatedAt = time.Now().Add(-2 * time.Hour)
s.Create(ctx, done)
tasks, err := s.GetStuckTasks(ctx, 30*time.Minute)
if err != nil {
t.Fatalf("GetStuckTasks() error = %v", err)
}
if len(tasks) != 1 {
t.Fatalf("len(tasks) = %d, want 1", len(tasks))
}
if tasks[0].TaskName != "stuck-1" {
t.Errorf("TaskName = %q, want %q", tasks[0].TaskName, "stuck-1")
}
}