feat(store): add TaskStore CRUD and batch query methods for files and blobs
This commit is contained in:
@@ -82,6 +82,15 @@ func (s *BlobStore) Delete(ctx context.Context, sha256 string) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *BlobStore) GetBySHA256s(ctx context.Context, sha256s []string) ([]model.FileBlob, error) {
|
||||||
|
var blobs []model.FileBlob
|
||||||
|
if len(sha256s) == 0 {
|
||||||
|
return blobs, nil
|
||||||
|
}
|
||||||
|
err := s.db.WithContext(ctx).Where("sha256 IN ?", sha256s).Find(&blobs).Error
|
||||||
|
return blobs, err
|
||||||
|
}
|
||||||
|
|
||||||
// GetBySHA256ForUpdate returns the FileBlob with a SELECT ... FOR UPDATE lock.
|
// GetBySHA256ForUpdate returns the FileBlob with a SELECT ... FOR UPDATE lock.
|
||||||
// Returns (nil, nil) if not found.
|
// Returns (nil, nil) if not found.
|
||||||
func (s *BlobStore) GetBySHA256ForUpdate(ctx context.Context, tx *gorm.DB, sha256 string) (*model.FileBlob, error) {
|
func (s *BlobStore) GetBySHA256ForUpdate(ctx context.Context, tx *gorm.DB, sha256 string) (*model.FileBlob, error) {
|
||||||
|
|||||||
@@ -133,6 +133,59 @@ func TestBlobStore_Delete(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestBlobStore_GetBySHA256s(t *testing.T) {
|
||||||
|
db := setupBlobTestDB(t)
|
||||||
|
store := NewBlobStore(db)
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
store.Create(ctx, &model.FileBlob{SHA256: "h1", MinioKey: "files/h1", FileSize: 100})
|
||||||
|
store.Create(ctx, &model.FileBlob{SHA256: "h2", MinioKey: "files/h2", FileSize: 200})
|
||||||
|
store.Create(ctx, &model.FileBlob{SHA256: "h3", MinioKey: "files/h3", FileSize: 300})
|
||||||
|
|
||||||
|
blobs, err := store.GetBySHA256s(ctx, []string{"h1", "h3"})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("GetBySHA256s() error = %v", err)
|
||||||
|
}
|
||||||
|
if len(blobs) != 2 {
|
||||||
|
t.Fatalf("len(blobs) = %d, want 2", len(blobs))
|
||||||
|
}
|
||||||
|
keys := map[string]bool{}
|
||||||
|
for _, b := range blobs {
|
||||||
|
keys[b.SHA256] = true
|
||||||
|
}
|
||||||
|
if !keys["h1"] || !keys["h3"] {
|
||||||
|
t.Errorf("expected h1 and h3, got %v", blobs)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBlobStore_GetBySHA256s_Empty(t *testing.T) {
|
||||||
|
db := setupBlobTestDB(t)
|
||||||
|
store := NewBlobStore(db)
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
blobs, err := store.GetBySHA256s(ctx, []string{})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("GetBySHA256s() error = %v", err)
|
||||||
|
}
|
||||||
|
if len(blobs) != 0 {
|
||||||
|
t.Errorf("len(blobs) = %d, want 0", len(blobs))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBlobStore_GetBySHA256s_NotFound(t *testing.T) {
|
||||||
|
db := setupBlobTestDB(t)
|
||||||
|
store := NewBlobStore(db)
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
blobs, err := store.GetBySHA256s(ctx, []string{"nonexistent"})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("GetBySHA256s() error = %v", err)
|
||||||
|
}
|
||||||
|
if len(blobs) != 0 {
|
||||||
|
t.Errorf("len(blobs) = %d, want 0 for non-existent SHA256s", len(blobs))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestBlobStore_SHA256_UniqueConstraint(t *testing.T) {
|
func TestBlobStore_SHA256_UniqueConstraint(t *testing.T) {
|
||||||
db := setupBlobTestDB(t)
|
db := setupBlobTestDB(t)
|
||||||
store := NewBlobStore(db)
|
store := NewBlobStore(db)
|
||||||
|
|||||||
@@ -86,6 +86,15 @@ func (s *FileStore) CountByBlobSHA256(ctx context.Context, blobSHA256 string) (i
|
|||||||
return count, err
|
return count, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *FileStore) GetByIDs(ctx context.Context, ids []int64) ([]model.File, error) {
|
||||||
|
var files []model.File
|
||||||
|
if len(ids) == 0 {
|
||||||
|
return files, nil
|
||||||
|
}
|
||||||
|
err := s.db.WithContext(ctx).Where("id IN ?", ids).Find(&files).Error
|
||||||
|
return files, err
|
||||||
|
}
|
||||||
|
|
||||||
func (s *FileStore) GetBlobSHA256ByID(ctx context.Context, id int64) (string, error) {
|
func (s *FileStore) GetBlobSHA256ByID(ctx context.Context, id int64) (string, error) {
|
||||||
var file model.File
|
var file model.File
|
||||||
err := s.db.WithContext(ctx).Select("blob_sha256").First(&file, id).Error
|
err := s.db.WithContext(ctx).Select("blob_sha256").First(&file, id).Error
|
||||||
|
|||||||
@@ -230,6 +230,84 @@ func TestFileStore_GetBlobSHA256ByID(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestFileStore_GetByIDs(t *testing.T) {
|
||||||
|
db := setupFileTestDB(t)
|
||||||
|
store := NewFileStore(db)
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
store.Create(ctx, &model.File{Name: "a.bin", BlobSHA256: "h1"})
|
||||||
|
store.Create(ctx, &model.File{Name: "b.bin", BlobSHA256: "h2"})
|
||||||
|
store.Create(ctx, &model.File{Name: "c.bin", BlobSHA256: "h3"})
|
||||||
|
|
||||||
|
files, err := store.GetByIDs(ctx, []int64{1, 3})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("GetByIDs() error = %v", err)
|
||||||
|
}
|
||||||
|
if len(files) != 2 {
|
||||||
|
t.Fatalf("len(files) = %d, want 2", len(files))
|
||||||
|
}
|
||||||
|
names := map[string]bool{}
|
||||||
|
for _, f := range files {
|
||||||
|
names[f.Name] = true
|
||||||
|
}
|
||||||
|
if !names["a.bin"] || !names["c.bin"] {
|
||||||
|
t.Errorf("expected a.bin and c.bin, got %v", files)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFileStore_GetByIDs_Empty(t *testing.T) {
|
||||||
|
db := setupFileTestDB(t)
|
||||||
|
store := NewFileStore(db)
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
files, err := store.GetByIDs(ctx, []int64{})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("GetByIDs() error = %v", err)
|
||||||
|
}
|
||||||
|
if len(files) != 0 {
|
||||||
|
t.Errorf("len(files) = %d, want 0", len(files))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFileStore_GetByIDs_NotFound(t *testing.T) {
|
||||||
|
db := setupFileTestDB(t)
|
||||||
|
store := NewFileStore(db)
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
files, err := store.GetByIDs(ctx, []int64{999})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("GetByIDs() error = %v", err)
|
||||||
|
}
|
||||||
|
if len(files) != 0 {
|
||||||
|
t.Errorf("len(files) = %d, want 0 for non-existent IDs", len(files))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFileStore_GetByIDs_SoftDeleteExcluded(t *testing.T) {
|
||||||
|
db := setupFileTestDB(t)
|
||||||
|
store := NewFileStore(db)
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
store.Create(ctx, &model.File{Name: "a.bin", BlobSHA256: "h1"})
|
||||||
|
store.Create(ctx, &model.File{Name: "b.bin", BlobSHA256: "h2"})
|
||||||
|
store.Create(ctx, &model.File{Name: "c.bin", BlobSHA256: "h3"})
|
||||||
|
|
||||||
|
store.Delete(ctx, 2)
|
||||||
|
|
||||||
|
files, err := store.GetByIDs(ctx, []int64{1, 2, 3})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("GetByIDs() error = %v", err)
|
||||||
|
}
|
||||||
|
if len(files) != 2 {
|
||||||
|
t.Fatalf("len(files) = %d, want 2 (soft-deleted excluded)", len(files))
|
||||||
|
}
|
||||||
|
for _, f := range files {
|
||||||
|
if f.ID == 2 {
|
||||||
|
t.Error("soft-deleted file ID 2 should not appear")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestFileStore_GetBlobSHA256ByID_NotFound(t *testing.T) {
|
func TestFileStore_GetBlobSHA256ByID_NotFound(t *testing.T) {
|
||||||
db := setupFileTestDB(t)
|
db := setupFileTestDB(t)
|
||||||
store := NewFileStore(db)
|
store := NewFileStore(db)
|
||||||
|
|||||||
@@ -47,5 +47,6 @@ func AutoMigrate(db *gorm.DB) error {
|
|||||||
&model.Folder{},
|
&model.Folder{},
|
||||||
&model.UploadSession{},
|
&model.UploadSession{},
|
||||||
&model.UploadChunk{},
|
&model.UploadChunk{},
|
||||||
|
&model.Task{},
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|||||||
141
internal/store/task_store.go
Normal file
141
internal/store/task_store.go
Normal file
@@ -0,0 +1,141 @@
|
|||||||
|
package store
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"gcy_hpc_server/internal/model"
|
||||||
|
|
||||||
|
"gorm.io/gorm"
|
||||||
|
)
|
||||||
|
|
||||||
|
type TaskStore struct {
|
||||||
|
db *gorm.DB
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewTaskStore(db *gorm.DB) *TaskStore {
|
||||||
|
return &TaskStore{db: db}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *TaskStore) Create(ctx context.Context, task *model.Task) (int64, error) {
|
||||||
|
if err := s.db.WithContext(ctx).Create(task).Error; err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
return task.ID, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *TaskStore) GetByID(ctx context.Context, id int64) (*model.Task, error) {
|
||||||
|
var task model.Task
|
||||||
|
err := s.db.WithContext(ctx).First(&task, id).Error
|
||||||
|
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &task, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *TaskStore) List(ctx context.Context, query *model.TaskListQuery) ([]model.Task, int64, error) {
|
||||||
|
page := query.Page
|
||||||
|
pageSize := query.PageSize
|
||||||
|
if page <= 0 {
|
||||||
|
page = 1
|
||||||
|
}
|
||||||
|
if pageSize <= 0 {
|
||||||
|
pageSize = 10
|
||||||
|
}
|
||||||
|
|
||||||
|
q := s.db.WithContext(ctx).Model(&model.Task{})
|
||||||
|
if query.Status != "" {
|
||||||
|
q = q.Where("status = ?", query.Status)
|
||||||
|
}
|
||||||
|
|
||||||
|
var total int64
|
||||||
|
if err := q.Count(&total).Error; err != nil {
|
||||||
|
return nil, 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
var tasks []model.Task
|
||||||
|
offset := (page - 1) * pageSize
|
||||||
|
if err := q.Order("id DESC").Limit(pageSize).Offset(offset).Find(&tasks).Error; err != nil {
|
||||||
|
return nil, 0, err
|
||||||
|
}
|
||||||
|
return tasks, total, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *TaskStore) UpdateStatus(ctx context.Context, id int64, status, errorMsg string) error {
|
||||||
|
updates := map[string]interface{}{
|
||||||
|
"status": status,
|
||||||
|
"error_message": errorMsg,
|
||||||
|
}
|
||||||
|
|
||||||
|
now := time.Now()
|
||||||
|
switch status {
|
||||||
|
case model.TaskStatusPreparing, model.TaskStatusDownloading, model.TaskStatusReady,
|
||||||
|
model.TaskStatusQueued, model.TaskStatusRunning:
|
||||||
|
updates["started_at"] = &now
|
||||||
|
case model.TaskStatusCompleted, model.TaskStatusFailed:
|
||||||
|
updates["finished_at"] = &now
|
||||||
|
}
|
||||||
|
|
||||||
|
result := s.db.WithContext(ctx).Model(&model.Task{}).Where("id = ?", id).Updates(updates)
|
||||||
|
if result.Error != nil {
|
||||||
|
return result.Error
|
||||||
|
}
|
||||||
|
if result.RowsAffected == 0 {
|
||||||
|
return fmt.Errorf("task %d not found", id)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *TaskStore) UpdateSlurmJobID(ctx context.Context, id int64, slurmJobID *int32) error {
|
||||||
|
result := s.db.WithContext(ctx).Model(&model.Task{}).Where("id = ?", id).
|
||||||
|
Update("slurm_job_id", slurmJobID)
|
||||||
|
if result.Error != nil {
|
||||||
|
return result.Error
|
||||||
|
}
|
||||||
|
if result.RowsAffected == 0 {
|
||||||
|
return fmt.Errorf("task %d not found", id)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *TaskStore) UpdateWorkDir(ctx context.Context, id int64, workDir string) error {
|
||||||
|
result := s.db.WithContext(ctx).Model(&model.Task{}).Where("id = ?", id).
|
||||||
|
Update("work_dir", workDir)
|
||||||
|
if result.Error != nil {
|
||||||
|
return result.Error
|
||||||
|
}
|
||||||
|
if result.RowsAffected == 0 {
|
||||||
|
return fmt.Errorf("task %d not found", id)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *TaskStore) UpdateRetryState(ctx context.Context, id int64, status, currentStep string, retryCount int) error {
|
||||||
|
result := s.db.WithContext(ctx).Model(&model.Task{}).Where("id = ?", id).Updates(map[string]interface{}{
|
||||||
|
"status": status,
|
||||||
|
"current_step": currentStep,
|
||||||
|
"retry_count": retryCount,
|
||||||
|
})
|
||||||
|
if result.Error != nil {
|
||||||
|
return result.Error
|
||||||
|
}
|
||||||
|
if result.RowsAffected == 0 {
|
||||||
|
return fmt.Errorf("task %d not found", id)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *TaskStore) GetStuckTasks(ctx context.Context, maxAge time.Duration) ([]model.Task, error) {
|
||||||
|
cutoff := time.Now().Add(-maxAge)
|
||||||
|
var tasks []model.Task
|
||||||
|
err := s.db.WithContext(ctx).
|
||||||
|
Where("status NOT IN ?", []string{model.TaskStatusCompleted, model.TaskStatusFailed}).
|
||||||
|
Where("updated_at < ?", cutoff).
|
||||||
|
Find(&tasks).Error
|
||||||
|
return tasks, err
|
||||||
|
}
|
||||||
229
internal/store/task_store_test.go
Normal file
229
internal/store/task_store_test.go
Normal file
@@ -0,0 +1,229 @@
|
|||||||
|
package store
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"gcy_hpc_server/internal/model"
|
||||||
|
|
||||||
|
"gorm.io/driver/sqlite"
|
||||||
|
"gorm.io/gorm"
|
||||||
|
"gorm.io/gorm/logger"
|
||||||
|
)
|
||||||
|
|
||||||
|
func newTaskTestDB(t *testing.T) *gorm.DB {
|
||||||
|
t.Helper()
|
||||||
|
db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{
|
||||||
|
Logger: logger.Default.LogMode(logger.Silent),
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("open sqlite: %v", err)
|
||||||
|
}
|
||||||
|
if err := db.AutoMigrate(&model.Task{}); err != nil {
|
||||||
|
t.Fatalf("auto migrate: %v", err)
|
||||||
|
}
|
||||||
|
return db
|
||||||
|
}
|
||||||
|
|
||||||
|
func makeTestTask(name, status string) *model.Task {
|
||||||
|
return &model.Task{
|
||||||
|
TaskName: name,
|
||||||
|
AppID: 1,
|
||||||
|
AppName: "test-app",
|
||||||
|
Status: status,
|
||||||
|
CurrentStep: "",
|
||||||
|
RetryCount: 0,
|
||||||
|
UserID: "user1",
|
||||||
|
SubmittedAt: time.Now(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTaskStore_CreateAndGetByID(t *testing.T) {
|
||||||
|
db := newTaskTestDB(t)
|
||||||
|
s := NewTaskStore(db)
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
task := makeTestTask("test-task", model.TaskStatusSubmitted)
|
||||||
|
id, err := s.Create(ctx, task)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Create() error = %v", err)
|
||||||
|
}
|
||||||
|
if id <= 0 {
|
||||||
|
t.Errorf("Create() id = %d, want positive", id)
|
||||||
|
}
|
||||||
|
|
||||||
|
got, err := s.GetByID(ctx, id)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("GetByID() error = %v", err)
|
||||||
|
}
|
||||||
|
if got == nil {
|
||||||
|
t.Fatal("GetByID() returned nil")
|
||||||
|
}
|
||||||
|
if got.TaskName != "test-task" {
|
||||||
|
t.Errorf("TaskName = %q, want %q", got.TaskName, "test-task")
|
||||||
|
}
|
||||||
|
if got.Status != model.TaskStatusSubmitted {
|
||||||
|
t.Errorf("Status = %q, want %q", got.Status, model.TaskStatusSubmitted)
|
||||||
|
}
|
||||||
|
if got.ID != id {
|
||||||
|
t.Errorf("ID = %d, want %d", got.ID, id)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTaskStore_GetByID_NotFound(t *testing.T) {
|
||||||
|
db := newTaskTestDB(t)
|
||||||
|
s := NewTaskStore(db)
|
||||||
|
|
||||||
|
got, err := s.GetByID(context.Background(), 999)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("GetByID() error = %v", err)
|
||||||
|
}
|
||||||
|
if got != nil {
|
||||||
|
t.Error("GetByID() expected nil for not-found, got non-nil")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTaskStore_ListPagination(t *testing.T) {
|
||||||
|
db := newTaskTestDB(t)
|
||||||
|
s := NewTaskStore(db)
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
for i := 0; i < 5; i++ {
|
||||||
|
s.Create(ctx, makeTestTask("task-"+string(rune('A'+i)), model.TaskStatusSubmitted))
|
||||||
|
}
|
||||||
|
|
||||||
|
tasks, total, err := s.List(ctx, &model.TaskListQuery{Page: 1, PageSize: 3})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("List() error = %v", err)
|
||||||
|
}
|
||||||
|
if total != 5 {
|
||||||
|
t.Errorf("total = %d, want 5", total)
|
||||||
|
}
|
||||||
|
if len(tasks) != 3 {
|
||||||
|
t.Errorf("len(tasks) = %d, want 3", len(tasks))
|
||||||
|
}
|
||||||
|
|
||||||
|
tasks2, total2, err := s.List(ctx, &model.TaskListQuery{Page: 2, PageSize: 3})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("List() page 2 error = %v", err)
|
||||||
|
}
|
||||||
|
if total2 != 5 {
|
||||||
|
t.Errorf("total2 = %d, want 5", total2)
|
||||||
|
}
|
||||||
|
if len(tasks2) != 2 {
|
||||||
|
t.Errorf("len(tasks2) = %d, want 2", len(tasks2))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTaskStore_ListStatusFilter(t *testing.T) {
|
||||||
|
db := newTaskTestDB(t)
|
||||||
|
s := NewTaskStore(db)
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
s.Create(ctx, makeTestTask("running-1", model.TaskStatusRunning))
|
||||||
|
s.Create(ctx, makeTestTask("running-2", model.TaskStatusRunning))
|
||||||
|
s.Create(ctx, makeTestTask("completed-1", model.TaskStatusCompleted))
|
||||||
|
|
||||||
|
tasks, total, err := s.List(ctx, &model.TaskListQuery{Page: 1, PageSize: 10, Status: model.TaskStatusRunning})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("List() error = %v", err)
|
||||||
|
}
|
||||||
|
if total != 2 {
|
||||||
|
t.Errorf("total = %d, want 2", total)
|
||||||
|
}
|
||||||
|
for _, t2 := range tasks {
|
||||||
|
if t2.Status != model.TaskStatusRunning {
|
||||||
|
t.Errorf("Status = %q, want %q", t2.Status, model.TaskStatusRunning)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTaskStore_UpdateStatus(t *testing.T) {
|
||||||
|
db := newTaskTestDB(t)
|
||||||
|
s := NewTaskStore(db)
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
id, _ := s.Create(ctx, makeTestTask("status-test", model.TaskStatusSubmitted))
|
||||||
|
|
||||||
|
err := s.UpdateStatus(ctx, id, model.TaskStatusPreparing, "")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("UpdateStatus() error = %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
got, _ := s.GetByID(ctx, id)
|
||||||
|
if got.Status != model.TaskStatusPreparing {
|
||||||
|
t.Errorf("Status = %q, want %q", got.Status, model.TaskStatusPreparing)
|
||||||
|
}
|
||||||
|
if got.StartedAt == nil {
|
||||||
|
t.Error("StartedAt expected non-nil after preparing status")
|
||||||
|
}
|
||||||
|
|
||||||
|
err = s.UpdateStatus(ctx, id, model.TaskStatusFailed, "something broke")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("UpdateStatus(failed) error = %v", err)
|
||||||
|
}
|
||||||
|
got, _ = s.GetByID(ctx, id)
|
||||||
|
if got.ErrorMessage != "something broke" {
|
||||||
|
t.Errorf("ErrorMessage = %q, want %q", got.ErrorMessage, "something broke")
|
||||||
|
}
|
||||||
|
if got.FinishedAt == nil {
|
||||||
|
t.Error("FinishedAt expected non-nil after failed status")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTaskStore_UpdateRetryState(t *testing.T) {
|
||||||
|
db := newTaskTestDB(t)
|
||||||
|
s := NewTaskStore(db)
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
task := makeTestTask("retry-test", model.TaskStatusFailed)
|
||||||
|
task.CurrentStep = model.TaskStepDownloading
|
||||||
|
task.RetryCount = 1
|
||||||
|
id, _ := s.Create(ctx, task)
|
||||||
|
|
||||||
|
err := s.UpdateRetryState(ctx, id, model.TaskStatusSubmitted, model.TaskStepDownloading, 2)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("UpdateRetryState() error = %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
got, _ := s.GetByID(ctx, id)
|
||||||
|
if got.Status != model.TaskStatusSubmitted {
|
||||||
|
t.Errorf("Status = %q, want %q", got.Status, model.TaskStatusSubmitted)
|
||||||
|
}
|
||||||
|
if got.CurrentStep != model.TaskStepDownloading {
|
||||||
|
t.Errorf("CurrentStep = %q, want %q", got.CurrentStep, model.TaskStepDownloading)
|
||||||
|
}
|
||||||
|
if got.RetryCount != 2 {
|
||||||
|
t.Errorf("RetryCount = %d, want 2", got.RetryCount)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTaskStore_GetStuckTasks(t *testing.T) {
|
||||||
|
db := newTaskTestDB(t)
|
||||||
|
s := NewTaskStore(db)
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
stuck := makeTestTask("stuck-1", model.TaskStatusDownloading)
|
||||||
|
stuck.UpdatedAt = time.Now().Add(-1 * time.Hour)
|
||||||
|
s.Create(ctx, stuck)
|
||||||
|
|
||||||
|
recent := makeTestTask("recent-1", model.TaskStatusDownloading)
|
||||||
|
s.Create(ctx, recent)
|
||||||
|
|
||||||
|
done := makeTestTask("done-1", model.TaskStatusCompleted)
|
||||||
|
done.UpdatedAt = time.Now().Add(-2 * time.Hour)
|
||||||
|
s.Create(ctx, done)
|
||||||
|
|
||||||
|
tasks, err := s.GetStuckTasks(ctx, 30*time.Minute)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("GetStuckTasks() error = %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(tasks) != 1 {
|
||||||
|
t.Fatalf("len(tasks) = %d, want 1", len(tasks))
|
||||||
|
}
|
||||||
|
if tasks[0].TaskName != "stuck-1" {
|
||||||
|
t.Errorf("TaskName = %q, want %q", tasks[0].TaskName, "stuck-1")
|
||||||
|
}
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user