package store

import (
	"context"
	"testing"
	"time"

	"gcy_hpc_server/internal/model"
	"gorm.io/driver/sqlite"
	"gorm.io/gorm"
	"gorm.io/gorm/logger"
)

// newTaskTestDB opens a fresh in-memory SQLite database, migrates the Task
// schema into it, and registers a cleanup that closes it when the test ends.
//
// The pool is capped at one connection: with the ":memory:" DSN every new
// SQLite connection gets its own empty database, so if database/sql ever
// opened a second connection it would see no tables and no rows.
func newTaskTestDB(t *testing.T) *gorm.DB {
	t.Helper()

	db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{
		Logger: logger.Default.LogMode(logger.Silent),
	})
	if err != nil {
		t.Fatalf("open sqlite: %v", err)
	}

	sqlDB, err := db.DB()
	if err != nil {
		t.Fatalf("get sql.DB: %v", err)
	}
	sqlDB.SetMaxOpenConns(1)
	t.Cleanup(func() {
		if err := sqlDB.Close(); err != nil {
			t.Errorf("close sqlite: %v", err)
		}
	})

	if err := db.AutoMigrate(&model.Task{}); err != nil {
		t.Fatalf("auto migrate: %v", err)
	}
	return db
}

// makeTestTask builds an unsaved Task with the given name and status and
// fixed placeholder values for the remaining required fields.
func makeTestTask(name, status string) *model.Task {
	return &model.Task{
		TaskName:    name,
		AppID:       1,
		AppName:     "test-app",
		Status:      status,
		CurrentStep: "",
		RetryCount:  0,
		UserID:      "user1",
		SubmittedAt: time.Now(),
	}
}

// TestTaskStore_CreateAndGetByID checks that a created task gets a positive ID
// and can be read back with its name, status and ID intact.
func TestTaskStore_CreateAndGetByID(t *testing.T) {
	db := newTaskTestDB(t)
	s := NewTaskStore(db)
	ctx := context.Background()

	task := makeTestTask("test-task", model.TaskStatusSubmitted)
	id, err := s.Create(ctx, task)
	if err != nil {
		t.Fatalf("Create() error = %v", err)
	}
	if id <= 0 {
		t.Errorf("Create() id = %d, want positive", id)
	}

	got, err := s.GetByID(ctx, id)
	if err != nil {
		t.Fatalf("GetByID() error = %v", err)
	}
	if got == nil {
		t.Fatal("GetByID() returned nil")
	}
	if got.TaskName != "test-task" {
		t.Errorf("TaskName = %q, want %q", got.TaskName, "test-task")
	}
	if got.Status != model.TaskStatusSubmitted {
		t.Errorf("Status = %q, want %q", got.Status, model.TaskStatusSubmitted)
	}
	if got.ID != id {
		t.Errorf("ID = %d, want %d", got.ID, id)
	}
}

// TestTaskStore_GetByID_NotFound checks that looking up a missing ID yields
// (nil, nil) rather than an error.
func TestTaskStore_GetByID_NotFound(t *testing.T) {
	db := newTaskTestDB(t)
	s := NewTaskStore(db)

	got, err := s.GetByID(context.Background(), 999)
	if err != nil {
		t.Fatalf("GetByID() error = %v", err)
	}
	if got != nil {
		t.Error("GetByID() expected nil for not-found, got non-nil")
	}
}

// TestTaskStore_ListPagination seeds five tasks and checks that pages of size
// three report the full total and return 3 and then 2 items.
func TestTaskStore_ListPagination(t *testing.T) {
	db := newTaskTestDB(t)
	s := NewTaskStore(db)
	ctx := context.Background()

	for i := 0; i < 5; i++ {
		name := "task-" + string(rune('A'+i))
		if _, err := s.Create(ctx, makeTestTask(name, model.TaskStatusSubmitted)); err != nil {
			t.Fatalf("Create(%q) error = %v", name, err)
		}
	}

	tasks, total, err := s.List(ctx, &model.TaskListQuery{Page: 1, PageSize: 3})
	if err != nil {
		t.Fatalf("List() error = %v", err)
	}
	if total != 5 {
		t.Errorf("total = %d, want 5", total)
	}
	if len(tasks) != 3 {
		t.Errorf("len(tasks) = %d, want 3", len(tasks))
	}

	tasks2, total2, err := s.List(ctx, &model.TaskListQuery{Page: 2, PageSize: 3})
	if err != nil {
		t.Fatalf("List() page 2 error = %v", err)
	}
	if total2 != 5 {
		t.Errorf("total2 = %d, want 5", total2)
	}
	if len(tasks2) != 2 {
		t.Errorf("len(tasks2) = %d, want 2", len(tasks2))
	}
}

// TestTaskStore_ListStatusFilter checks that filtering by status returns only
// (and all of) the tasks in that status.
func TestTaskStore_ListStatusFilter(t *testing.T) {
	db := newTaskTestDB(t)
	s := NewTaskStore(db)
	ctx := context.Background()

	seed := []struct{ name, status string }{
		{"running-1", model.TaskStatusRunning},
		{"running-2", model.TaskStatusRunning},
		{"completed-1", model.TaskStatusCompleted},
	}
	for _, sd := range seed {
		if _, err := s.Create(ctx, makeTestTask(sd.name, sd.status)); err != nil {
			t.Fatalf("Create(%q) error = %v", sd.name, err)
		}
	}

	tasks, total, err := s.List(ctx, &model.TaskListQuery{Page: 1, PageSize: 10, Status: model.TaskStatusRunning})
	if err != nil {
		t.Fatalf("List() error = %v", err)
	}
	if total != 2 {
		t.Errorf("total = %d, want 2", total)
	}
	// Check the returned slice too; otherwise an empty result would make the
	// status loop below pass without checking anything.
	if len(tasks) != 2 {
		t.Fatalf("len(tasks) = %d, want 2", len(tasks))
	}
	for _, task := range tasks {
		if task.Status != model.TaskStatusRunning {
			t.Errorf("Status = %q, want %q", task.Status, model.TaskStatusRunning)
		}
	}
}

// TestTaskStore_UpdateStatus checks that moving to "preparing" stamps
// StartedAt, and that moving to "failed" records the error message and
// stamps FinishedAt.
func TestTaskStore_UpdateStatus(t *testing.T) {
	db := newTaskTestDB(t)
	s := NewTaskStore(db)
	ctx := context.Background()

	id, err := s.Create(ctx, makeTestTask("status-test", model.TaskStatusSubmitted))
	if err != nil {
		t.Fatalf("Create() error = %v", err)
	}

	if err := s.UpdateStatus(ctx, id, model.TaskStatusPreparing, ""); err != nil {
		t.Fatalf("UpdateStatus() error = %v", err)
	}
	got, err := s.GetByID(ctx, id)
	if err != nil {
		t.Fatalf("GetByID() error = %v", err)
	}
	if got == nil {
		t.Fatal("GetByID() returned nil after preparing")
	}
	if got.Status != model.TaskStatusPreparing {
		t.Errorf("Status = %q, want %q", got.Status, model.TaskStatusPreparing)
	}
	if got.StartedAt == nil {
		t.Error("StartedAt expected non-nil after preparing status")
	}

	if err := s.UpdateStatus(ctx, id, model.TaskStatusFailed, "something broke"); err != nil {
		t.Fatalf("UpdateStatus(failed) error = %v", err)
	}
	got, err = s.GetByID(ctx, id)
	if err != nil {
		t.Fatalf("GetByID() error = %v", err)
	}
	if got == nil {
		t.Fatal("GetByID() returned nil after failed")
	}
	if got.ErrorMessage != "something broke" {
		t.Errorf("ErrorMessage = %q, want %q", got.ErrorMessage, "something broke")
	}
	if got.FinishedAt == nil {
		t.Error("FinishedAt expected non-nil after failed status")
	}
}

// TestTaskStore_UpdateRetryState checks that status, current step and retry
// count are all persisted by UpdateRetryState.
func TestTaskStore_UpdateRetryState(t *testing.T) {
	db := newTaskTestDB(t)
	s := NewTaskStore(db)
	ctx := context.Background()

	task := makeTestTask("retry-test", model.TaskStatusFailed)
	task.CurrentStep = model.TaskStepDownloading
	task.RetryCount = 1
	id, err := s.Create(ctx, task)
	if err != nil {
		t.Fatalf("Create() error = %v", err)
	}

	if err := s.UpdateRetryState(ctx, id, model.TaskStatusSubmitted, model.TaskStepDownloading, 2); err != nil {
		t.Fatalf("UpdateRetryState() error = %v", err)
	}

	got, err := s.GetByID(ctx, id)
	if err != nil {
		t.Fatalf("GetByID() error = %v", err)
	}
	if got == nil {
		t.Fatal("GetByID() returned nil")
	}
	if got.Status != model.TaskStatusSubmitted {
		t.Errorf("Status = %q, want %q", got.Status, model.TaskStatusSubmitted)
	}
	if got.CurrentStep != model.TaskStepDownloading {
		t.Errorf("CurrentStep = %q, want %q", got.CurrentStep, model.TaskStepDownloading)
	}
	if got.RetryCount != 2 {
		t.Errorf("RetryCount = %d, want 2", got.RetryCount)
	}
}

// TestTaskStore_GetStuckTasks checks that only a non-terminal task whose
// UpdatedAt is older than the threshold is reported. A recently updated
// downloading task and an old completed task must both be excluded.
//
// This relies on Create persisting the preset UpdatedAt. GORM only auto-fills
// UpdatedAt on insert when the field is zero.
func TestTaskStore_GetStuckTasks(t *testing.T) {
	db := newTaskTestDB(t)
	s := NewTaskStore(db)
	ctx := context.Background()

	stuck := makeTestTask("stuck-1", model.TaskStatusDownloading)
	stuck.UpdatedAt = time.Now().Add(-1 * time.Hour)

	recent := makeTestTask("recent-1", model.TaskStatusDownloading)

	done := makeTestTask("done-1", model.TaskStatusCompleted)
	done.UpdatedAt = time.Now().Add(-2 * time.Hour)

	for _, task := range []*model.Task{stuck, recent, done} {
		if _, err := s.Create(ctx, task); err != nil {
			t.Fatalf("Create(%q) error = %v", task.TaskName, err)
		}
	}

	tasks, err := s.GetStuckTasks(ctx, 30*time.Minute)
	if err != nil {
		t.Fatalf("GetStuckTasks() error = %v", err)
	}
	if len(tasks) != 1 {
		t.Fatalf("len(tasks) = %d, want 1", len(tasks))
	}
	if tasks[0].TaskName != "stuck-1" {
		t.Errorf("TaskName = %q, want %q", tasks[0].TaskName, "stuck-1")
	}
}