From acf8c1d62be3e9438893a2a972551cde861d96d5 Mon Sep 17 00:00:00 2001 From: dailz Date: Wed, 15 Apr 2026 21:30:51 +0800 Subject: [PATCH] feat(store): add TaskStore CRUD and batch query methods for files and blobs --- internal/store/blob_store.go | 9 ++ internal/store/blob_store_test.go | 53 +++++++ internal/store/file_store.go | 9 ++ internal/store/file_store_test.go | 78 ++++++++++ internal/store/mysql.go | 1 + internal/store/task_store.go | 141 ++++++++++++++++++ internal/store/task_store_test.go | 229 ++++++++++++++++++++++++++++++ 7 files changed, 520 insertions(+) create mode 100644 internal/store/task_store.go create mode 100644 internal/store/task_store_test.go diff --git a/internal/store/blob_store.go b/internal/store/blob_store.go index 1bfcbd6..6280594 100644 --- a/internal/store/blob_store.go +++ b/internal/store/blob_store.go @@ -82,6 +82,15 @@ func (s *BlobStore) Delete(ctx context.Context, sha256 string) error { return nil } +func (s *BlobStore) GetBySHA256s(ctx context.Context, sha256s []string) ([]model.FileBlob, error) { + var blobs []model.FileBlob + if len(sha256s) == 0 { + return blobs, nil + } + err := s.db.WithContext(ctx).Where("sha256 IN ?", sha256s).Find(&blobs).Error + return blobs, err +} + // GetBySHA256ForUpdate returns the FileBlob with a SELECT ... FOR UPDATE lock. // Returns (nil, nil) if not found. func (s *BlobStore) GetBySHA256ForUpdate(ctx context.Context, tx *gorm.DB, sha256 string) (*model.FileBlob, error) { diff --git a/internal/store/blob_store_test.go b/internal/store/blob_store_test.go index 1e85f60..ddb23d2 100644 --- a/internal/store/blob_store_test.go +++ b/internal/store/blob_store_test.go @@ -133,6 +133,59 @@ func TestBlobStore_Delete(t *testing.T) { } } +func TestBlobStore_GetBySHA256s(t *testing.T) { + db := setupBlobTestDB(t) + store := NewBlobStore(db) + ctx := context.Background() + + store.Create(ctx, &model.FileBlob{SHA256: "h1", MinioKey: "files/h1", FileSize: 100}) + store.Create(ctx, &model.FileBlob{SHA256: "h2", MinioKey: "files/h2", FileSize: 200}) + store.Create(ctx, &model.FileBlob{SHA256: "h3", MinioKey: "files/h3", FileSize: 300}) + + blobs, err := store.GetBySHA256s(ctx, []string{"h1", "h3"}) + if err != nil { + t.Fatalf("GetBySHA256s() error = %v", err) + } + if len(blobs) != 2 { + t.Fatalf("len(blobs) = %d, want 2", len(blobs)) + } + keys := map[string]bool{} + for _, b := range blobs { + keys[b.SHA256] = true + } + if !keys["h1"] || !keys["h3"] { + t.Errorf("expected h1 and h3, got %v", blobs) + } +} + +func TestBlobStore_GetBySHA256s_Empty(t *testing.T) { + db := setupBlobTestDB(t) + store := NewBlobStore(db) + ctx := context.Background() + + blobs, err := store.GetBySHA256s(ctx, []string{}) + if err != nil { + t.Fatalf("GetBySHA256s() error = %v", err) + } + if len(blobs) != 0 { + t.Errorf("len(blobs) = %d, want 0", len(blobs)) + } +} + +func TestBlobStore_GetBySHA256s_NotFound(t *testing.T) { + db := setupBlobTestDB(t) + store := NewBlobStore(db) + ctx := context.Background() + + blobs, err := store.GetBySHA256s(ctx, []string{"nonexistent"}) + if err != nil { + t.Fatalf("GetBySHA256s() error = %v", err) + } + if len(blobs) != 0 { + t.Errorf("len(blobs) = %d, want 0 for non-existent SHA256s", len(blobs)) + } +} + func TestBlobStore_SHA256_UniqueConstraint(t *testing.T) { db := setupBlobTestDB(t) store := NewBlobStore(db) diff --git a/internal/store/file_store.go b/internal/store/file_store.go index ba0dd48..403ce8d 100644 --- a/internal/store/file_store.go +++ b/internal/store/file_store.go @@ -86,6 +86,15 @@ func (s *FileStore) CountByBlobSHA256(ctx context.Context, blobSHA256 string) (i return count, err } +func (s *FileStore) GetByIDs(ctx context.Context, ids []int64) ([]model.File, error) { + var files []model.File + if len(ids) == 0 { + return files, nil + } + err := s.db.WithContext(ctx).Where("id IN ?", ids).Find(&files).Error + return files, err +} + func (s *FileStore) GetBlobSHA256ByID(ctx context.Context, id int64) (string, error) { var file model.File err := s.db.WithContext(ctx).Select("blob_sha256").First(&file, id).Error diff --git a/internal/store/file_store_test.go b/internal/store/file_store_test.go index a0cc95d..7a8ded0 100644 --- a/internal/store/file_store_test.go +++ b/internal/store/file_store_test.go @@ -230,6 +230,84 @@ func TestFileStore_GetBlobSHA256ByID(t *testing.T) { } } +func TestFileStore_GetByIDs(t *testing.T) { + db := setupFileTestDB(t) + store := NewFileStore(db) + ctx := context.Background() + + store.Create(ctx, &model.File{Name: "a.bin", BlobSHA256: "h1"}) + store.Create(ctx, &model.File{Name: "b.bin", BlobSHA256: "h2"}) + store.Create(ctx, &model.File{Name: "c.bin", BlobSHA256: "h3"}) + + files, err := store.GetByIDs(ctx, []int64{1, 3}) + if err != nil { + t.Fatalf("GetByIDs() error = %v", err) + } + if len(files) != 2 { + t.Fatalf("len(files) = %d, want 2", len(files)) + } + names := map[string]bool{} + for _, f := range files { + names[f.Name] = true + } + if !names["a.bin"] || !names["c.bin"] { + t.Errorf("expected a.bin and c.bin, got %v", files) + } +} + +func TestFileStore_GetByIDs_Empty(t *testing.T) { + db := setupFileTestDB(t) + store := NewFileStore(db) + ctx := context.Background() + + files, err := store.GetByIDs(ctx, []int64{}) + if err != nil { + t.Fatalf("GetByIDs() error = %v", err) + } + if len(files) != 0 { + t.Errorf("len(files) = %d, want 0", len(files)) + } +} + +func TestFileStore_GetByIDs_NotFound(t *testing.T) { + db := setupFileTestDB(t) + store := NewFileStore(db) + ctx := context.Background() + + files, err := store.GetByIDs(ctx, []int64{999}) + if err != nil { + t.Fatalf("GetByIDs() error = %v", err) + } + if len(files) != 0 { + t.Errorf("len(files) = %d, want 0 for non-existent IDs", len(files)) + } +} + +func TestFileStore_GetByIDs_SoftDeleteExcluded(t *testing.T) { + db := setupFileTestDB(t) + store := NewFileStore(db) + ctx := context.Background() + + store.Create(ctx, &model.File{Name: "a.bin", BlobSHA256: "h1"}) + store.Create(ctx, &model.File{Name: "b.bin", BlobSHA256: "h2"}) + store.Create(ctx, &model.File{Name: "c.bin", BlobSHA256: "h3"}) + + store.Delete(ctx, 2) + + files, err := store.GetByIDs(ctx, []int64{1, 2, 3}) + if err != nil { + t.Fatalf("GetByIDs() error = %v", err) + } + if len(files) != 2 { + t.Fatalf("len(files) = %d, want 2 (soft-deleted excluded)", len(files)) + } + for _, f := range files { + if f.ID == 2 { + t.Error("soft-deleted file ID 2 should not appear") + } + } +} + func TestFileStore_GetBlobSHA256ByID_NotFound(t *testing.T) { db := setupFileTestDB(t) store := NewFileStore(db) diff --git a/internal/store/mysql.go b/internal/store/mysql.go index d202e37..db9e7b6 100644 --- a/internal/store/mysql.go +++ b/internal/store/mysql.go @@ -47,5 +47,6 @@ func AutoMigrate(db *gorm.DB) error { &model.Folder{}, &model.UploadSession{}, &model.UploadChunk{}, + &model.Task{}, ) } diff --git a/internal/store/task_store.go b/internal/store/task_store.go new file mode 100644 index 0000000..c457561 --- /dev/null +++ b/internal/store/task_store.go @@ -0,0 +1,141 @@ +package store + +import ( + "context" + "errors" + "fmt" + "time" + + "gcy_hpc_server/internal/model" + + "gorm.io/gorm" +) + +type TaskStore struct { + db *gorm.DB +} + +func NewTaskStore(db *gorm.DB) *TaskStore { + return &TaskStore{db: db} +} + +func (s *TaskStore) Create(ctx context.Context, task *model.Task) (int64, error) { + if err := s.db.WithContext(ctx).Create(task).Error; err != nil { + return 0, err + } + return task.ID, nil +} + +func (s *TaskStore) GetByID(ctx context.Context, id int64) (*model.Task, error) { + var task model.Task + err := s.db.WithContext(ctx).First(&task, id).Error + if errors.Is(err, gorm.ErrRecordNotFound) { + return nil, nil + } + if err != nil { + return nil, err + } + return &task, nil +} + +func (s *TaskStore) List(ctx context.Context, query *model.TaskListQuery) ([]model.Task, int64, error) { + page := query.Page + pageSize := query.PageSize + if page <= 0 { + page = 1 + } + if pageSize <= 0 { + pageSize = 10 + } + + q := s.db.WithContext(ctx).Model(&model.Task{}) + if query.Status != "" { + q = q.Where("status = ?", query.Status) + } + + var total int64 + if err := q.Count(&total).Error; err != nil { + return nil, 0, err + } + + var tasks []model.Task + offset := (page - 1) * pageSize + if err := q.Order("id DESC").Limit(pageSize).Offset(offset).Find(&tasks).Error; err != nil { + return nil, 0, err + } + return tasks, total, nil +} + +func (s *TaskStore) UpdateStatus(ctx context.Context, id int64, status, errorMsg string) error { + updates := map[string]interface{}{ + "status": status, + "error_message": errorMsg, + } + + now := time.Now() + switch status { + case model.TaskStatusPreparing, model.TaskStatusDownloading, model.TaskStatusReady, + model.TaskStatusQueued, model.TaskStatusRunning: + updates["started_at"] = &now + case model.TaskStatusCompleted, model.TaskStatusFailed: + updates["finished_at"] = &now + } + + result := s.db.WithContext(ctx).Model(&model.Task{}).Where("id = ?", id).Updates(updates) + if result.Error != nil { + return result.Error + } + if result.RowsAffected == 0 { + return fmt.Errorf("task %d not found", id) + } + return nil +} + +func (s *TaskStore) UpdateSlurmJobID(ctx context.Context, id int64, slurmJobID *int32) error { + result := s.db.WithContext(ctx).Model(&model.Task{}).Where("id = ?", id). + Update("slurm_job_id", slurmJobID) + if result.Error != nil { + return result.Error + } + if result.RowsAffected == 0 { + return fmt.Errorf("task %d not found", id) + } + return nil +} + +func (s *TaskStore) UpdateWorkDir(ctx context.Context, id int64, workDir string) error { + result := s.db.WithContext(ctx).Model(&model.Task{}).Where("id = ?", id). + Update("work_dir", workDir) + if result.Error != nil { + return result.Error + } + if result.RowsAffected == 0 { + return fmt.Errorf("task %d not found", id) + } + return nil +} + +func (s *TaskStore) UpdateRetryState(ctx context.Context, id int64, status, currentStep string, retryCount int) error { + result := s.db.WithContext(ctx).Model(&model.Task{}).Where("id = ?", id).Updates(map[string]interface{}{ + "status": status, + "current_step": currentStep, + "retry_count": retryCount, + }) + if result.Error != nil { + return result.Error + } + if result.RowsAffected == 0 { + return fmt.Errorf("task %d not found", id) + } + return nil +} + +func (s *TaskStore) GetStuckTasks(ctx context.Context, maxAge time.Duration) ([]model.Task, error) { + cutoff := time.Now().Add(-maxAge) + var tasks []model.Task + err := s.db.WithContext(ctx). + Where("status NOT IN ?", []string{model.TaskStatusCompleted, model.TaskStatusFailed}). + Where("updated_at < ?", cutoff). + Find(&tasks).Error + return tasks, err +} diff --git a/internal/store/task_store_test.go b/internal/store/task_store_test.go new file mode 100644 index 0000000..a2aa403 --- /dev/null +++ b/internal/store/task_store_test.go @@ -0,0 +1,229 @@ +package store + +import ( + "context" + "testing" + "time" + + "gcy_hpc_server/internal/model" + + "gorm.io/driver/sqlite" + "gorm.io/gorm" + "gorm.io/gorm/logger" +) + +func newTaskTestDB(t *testing.T) *gorm.DB { + t.Helper() + db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{ + Logger: logger.Default.LogMode(logger.Silent), + }) + if err != nil { + t.Fatalf("open sqlite: %v", err) + } + if err := db.AutoMigrate(&model.Task{}); err != nil { + t.Fatalf("auto migrate: %v", err) + } + return db +} + +func makeTestTask(name, status string) *model.Task { + return &model.Task{ + TaskName: name, + AppID: 1, + AppName: "test-app", + Status: status, + CurrentStep: "", + RetryCount: 0, + UserID: "user1", + SubmittedAt: time.Now(), + } +} + +func TestTaskStore_CreateAndGetByID(t *testing.T) { + db := newTaskTestDB(t) + s := NewTaskStore(db) + ctx := context.Background() + + task := makeTestTask("test-task", model.TaskStatusSubmitted) + id, err := s.Create(ctx, task) + if err != nil { + t.Fatalf("Create() error = %v", err) + } + if id <= 0 { + t.Errorf("Create() id = %d, want positive", id) + } + + got, err := s.GetByID(ctx, id) + if err != nil { + t.Fatalf("GetByID() error = %v", err) + } + if got == nil { + t.Fatal("GetByID() returned nil") + } + if got.TaskName != "test-task" { + t.Errorf("TaskName = %q, want %q", got.TaskName, "test-task") + } + if got.Status != model.TaskStatusSubmitted { + t.Errorf("Status = %q, want %q", got.Status, model.TaskStatusSubmitted) + } + if got.ID != id { + t.Errorf("ID = %d, want %d", got.ID, id) + } +} + +func TestTaskStore_GetByID_NotFound(t *testing.T) { + db := newTaskTestDB(t) + s := NewTaskStore(db) + + got, err := s.GetByID(context.Background(), 999) + if err != nil { + t.Fatalf("GetByID() error = %v", err) + } + if got != nil { + t.Error("GetByID() expected nil for not-found, got non-nil") + } +} + +func TestTaskStore_ListPagination(t *testing.T) { + db := newTaskTestDB(t) + s := NewTaskStore(db) + ctx := context.Background() + + for i := 0; i < 5; i++ { + s.Create(ctx, makeTestTask("task-"+string(rune('A'+i)), model.TaskStatusSubmitted)) + } + + tasks, total, err := s.List(ctx, &model.TaskListQuery{Page: 1, PageSize: 3}) + if err != nil { + t.Fatalf("List() error = %v", err) + } + if total != 5 { + t.Errorf("total = %d, want 5", total) + } + if len(tasks) != 3 { + t.Errorf("len(tasks) = %d, want 3", len(tasks)) + } + + tasks2, total2, err := s.List(ctx, &model.TaskListQuery{Page: 2, PageSize: 3}) + if err != nil { + t.Fatalf("List() page 2 error = %v", err) + } + if total2 != 5 { + t.Errorf("total2 = %d, want 5", total2) + } + if len(tasks2) != 2 { + t.Errorf("len(tasks2) = %d, want 2", len(tasks2)) + } +} + +func TestTaskStore_ListStatusFilter(t *testing.T) { + db := newTaskTestDB(t) + s := NewTaskStore(db) + ctx := context.Background() + + s.Create(ctx, makeTestTask("running-1", model.TaskStatusRunning)) + s.Create(ctx, makeTestTask("running-2", model.TaskStatusRunning)) + s.Create(ctx, makeTestTask("completed-1", model.TaskStatusCompleted)) + + tasks, total, err := s.List(ctx, &model.TaskListQuery{Page: 1, PageSize: 10, Status: model.TaskStatusRunning}) + if err != nil { + t.Fatalf("List() error = %v", err) + } + if total != 2 { + t.Errorf("total = %d, want 2", total) + } + for _, t2 := range tasks { + if t2.Status != model.TaskStatusRunning { + t.Errorf("Status = %q, want %q", t2.Status, model.TaskStatusRunning) + } + } +} + +func TestTaskStore_UpdateStatus(t *testing.T) { + db := newTaskTestDB(t) + s := NewTaskStore(db) + ctx := context.Background() + + id, _ := s.Create(ctx, makeTestTask("status-test", model.TaskStatusSubmitted)) + + err := s.UpdateStatus(ctx, id, model.TaskStatusPreparing, "") + if err != nil { + t.Fatalf("UpdateStatus() error = %v", err) + } + + got, _ := s.GetByID(ctx, id) + if got.Status != model.TaskStatusPreparing { + t.Errorf("Status = %q, want %q", got.Status, model.TaskStatusPreparing) + } + if got.StartedAt == nil { + t.Error("StartedAt expected non-nil after preparing status") + } + + err = s.UpdateStatus(ctx, id, model.TaskStatusFailed, "something broke") + if err != nil { + t.Fatalf("UpdateStatus(failed) error = %v", err) + } + got, _ = s.GetByID(ctx, id) + if got.ErrorMessage != "something broke" { + t.Errorf("ErrorMessage = %q, want %q", got.ErrorMessage, "something broke") + } + if got.FinishedAt == nil { + t.Error("FinishedAt expected non-nil after failed status") + } +} + +func TestTaskStore_UpdateRetryState(t *testing.T) { + db := newTaskTestDB(t) + s := NewTaskStore(db) + ctx := context.Background() + + task := makeTestTask("retry-test", model.TaskStatusFailed) + task.CurrentStep = model.TaskStepDownloading + task.RetryCount = 1 + id, _ := s.Create(ctx, task) + + err := s.UpdateRetryState(ctx, id, model.TaskStatusSubmitted, model.TaskStepDownloading, 2) + if err != nil { + t.Fatalf("UpdateRetryState() error = %v", err) + } + + got, _ := s.GetByID(ctx, id) + if got.Status != model.TaskStatusSubmitted { + t.Errorf("Status = %q, want %q", got.Status, model.TaskStatusSubmitted) + } + if got.CurrentStep != model.TaskStepDownloading { + t.Errorf("CurrentStep = %q, want %q", got.CurrentStep, model.TaskStepDownloading) + } + if got.RetryCount != 2 { + t.Errorf("RetryCount = %d, want 2", got.RetryCount) + } +} + +func TestTaskStore_GetStuckTasks(t *testing.T) { + db := newTaskTestDB(t) + s := NewTaskStore(db) + ctx := context.Background() + + stuck := makeTestTask("stuck-1", model.TaskStatusDownloading) + stuck.UpdatedAt = time.Now().Add(-1 * time.Hour) + s.Create(ctx, stuck) + + recent := makeTestTask("recent-1", model.TaskStatusDownloading) + s.Create(ctx, recent) + + done := makeTestTask("done-1", model.TaskStatusCompleted) + done.UpdatedAt = time.Now().Add(-2 * time.Hour) + s.Create(ctx, done) + + tasks, err := s.GetStuckTasks(ctx, 30*time.Minute) + if err != nil { + t.Fatalf("GetStuckTasks() error = %v", err) + } + + if len(tasks) != 1 { + t.Fatalf("len(tasks) = %d, want 1", len(tasks)) + } + if tasks[0].TaskName != "stuck-1" { + t.Errorf("TaskName = %q, want %q", tasks[0].TaskName, "stuck-1") + } +}