230 lines
5.8 KiB
Go
230 lines
5.8 KiB
Go
package store
|
|
|
|
import (
|
|
"context"
|
|
"testing"
|
|
"time"
|
|
|
|
"gcy_hpc_server/internal/model"
|
|
|
|
"gorm.io/driver/sqlite"
|
|
"gorm.io/gorm"
|
|
"gorm.io/gorm/logger"
|
|
)
|
|
|
|
func newTaskTestDB(t *testing.T) *gorm.DB {
|
|
t.Helper()
|
|
db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{
|
|
Logger: logger.Default.LogMode(logger.Silent),
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("open sqlite: %v", err)
|
|
}
|
|
if err := db.AutoMigrate(&model.Task{}); err != nil {
|
|
t.Fatalf("auto migrate: %v", err)
|
|
}
|
|
return db
|
|
}
|
|
|
|
func makeTestTask(name, status string) *model.Task {
|
|
return &model.Task{
|
|
TaskName: name,
|
|
AppID: 1,
|
|
AppName: "test-app",
|
|
Status: status,
|
|
CurrentStep: "",
|
|
RetryCount: 0,
|
|
UserID: "user1",
|
|
SubmittedAt: time.Now(),
|
|
}
|
|
}
|
|
|
|
func TestTaskStore_CreateAndGetByID(t *testing.T) {
|
|
db := newTaskTestDB(t)
|
|
s := NewTaskStore(db)
|
|
ctx := context.Background()
|
|
|
|
task := makeTestTask("test-task", model.TaskStatusSubmitted)
|
|
id, err := s.Create(ctx, task)
|
|
if err != nil {
|
|
t.Fatalf("Create() error = %v", err)
|
|
}
|
|
if id <= 0 {
|
|
t.Errorf("Create() id = %d, want positive", id)
|
|
}
|
|
|
|
got, err := s.GetByID(ctx, id)
|
|
if err != nil {
|
|
t.Fatalf("GetByID() error = %v", err)
|
|
}
|
|
if got == nil {
|
|
t.Fatal("GetByID() returned nil")
|
|
}
|
|
if got.TaskName != "test-task" {
|
|
t.Errorf("TaskName = %q, want %q", got.TaskName, "test-task")
|
|
}
|
|
if got.Status != model.TaskStatusSubmitted {
|
|
t.Errorf("Status = %q, want %q", got.Status, model.TaskStatusSubmitted)
|
|
}
|
|
if got.ID != id {
|
|
t.Errorf("ID = %d, want %d", got.ID, id)
|
|
}
|
|
}
|
|
|
|
func TestTaskStore_GetByID_NotFound(t *testing.T) {
|
|
db := newTaskTestDB(t)
|
|
s := NewTaskStore(db)
|
|
|
|
got, err := s.GetByID(context.Background(), 999)
|
|
if err != nil {
|
|
t.Fatalf("GetByID() error = %v", err)
|
|
}
|
|
if got != nil {
|
|
t.Error("GetByID() expected nil for not-found, got non-nil")
|
|
}
|
|
}
|
|
|
|
func TestTaskStore_ListPagination(t *testing.T) {
|
|
db := newTaskTestDB(t)
|
|
s := NewTaskStore(db)
|
|
ctx := context.Background()
|
|
|
|
for i := 0; i < 5; i++ {
|
|
s.Create(ctx, makeTestTask("task-"+string(rune('A'+i)), model.TaskStatusSubmitted))
|
|
}
|
|
|
|
tasks, total, err := s.List(ctx, &model.TaskListQuery{Page: 1, PageSize: 3})
|
|
if err != nil {
|
|
t.Fatalf("List() error = %v", err)
|
|
}
|
|
if total != 5 {
|
|
t.Errorf("total = %d, want 5", total)
|
|
}
|
|
if len(tasks) != 3 {
|
|
t.Errorf("len(tasks) = %d, want 3", len(tasks))
|
|
}
|
|
|
|
tasks2, total2, err := s.List(ctx, &model.TaskListQuery{Page: 2, PageSize: 3})
|
|
if err != nil {
|
|
t.Fatalf("List() page 2 error = %v", err)
|
|
}
|
|
if total2 != 5 {
|
|
t.Errorf("total2 = %d, want 5", total2)
|
|
}
|
|
if len(tasks2) != 2 {
|
|
t.Errorf("len(tasks2) = %d, want 2", len(tasks2))
|
|
}
|
|
}
|
|
|
|
func TestTaskStore_ListStatusFilter(t *testing.T) {
|
|
db := newTaskTestDB(t)
|
|
s := NewTaskStore(db)
|
|
ctx := context.Background()
|
|
|
|
s.Create(ctx, makeTestTask("running-1", model.TaskStatusRunning))
|
|
s.Create(ctx, makeTestTask("running-2", model.TaskStatusRunning))
|
|
s.Create(ctx, makeTestTask("completed-1", model.TaskStatusCompleted))
|
|
|
|
tasks, total, err := s.List(ctx, &model.TaskListQuery{Page: 1, PageSize: 10, Status: model.TaskStatusRunning})
|
|
if err != nil {
|
|
t.Fatalf("List() error = %v", err)
|
|
}
|
|
if total != 2 {
|
|
t.Errorf("total = %d, want 2", total)
|
|
}
|
|
for _, t2 := range tasks {
|
|
if t2.Status != model.TaskStatusRunning {
|
|
t.Errorf("Status = %q, want %q", t2.Status, model.TaskStatusRunning)
|
|
}
|
|
}
|
|
}
|
|
|
|
func TestTaskStore_UpdateStatus(t *testing.T) {
|
|
db := newTaskTestDB(t)
|
|
s := NewTaskStore(db)
|
|
ctx := context.Background()
|
|
|
|
id, _ := s.Create(ctx, makeTestTask("status-test", model.TaskStatusSubmitted))
|
|
|
|
err := s.UpdateStatus(ctx, id, model.TaskStatusPreparing, "")
|
|
if err != nil {
|
|
t.Fatalf("UpdateStatus() error = %v", err)
|
|
}
|
|
|
|
got, _ := s.GetByID(ctx, id)
|
|
if got.Status != model.TaskStatusPreparing {
|
|
t.Errorf("Status = %q, want %q", got.Status, model.TaskStatusPreparing)
|
|
}
|
|
if got.StartedAt == nil {
|
|
t.Error("StartedAt expected non-nil after preparing status")
|
|
}
|
|
|
|
err = s.UpdateStatus(ctx, id, model.TaskStatusFailed, "something broke")
|
|
if err != nil {
|
|
t.Fatalf("UpdateStatus(failed) error = %v", err)
|
|
}
|
|
got, _ = s.GetByID(ctx, id)
|
|
if got.ErrorMessage != "something broke" {
|
|
t.Errorf("ErrorMessage = %q, want %q", got.ErrorMessage, "something broke")
|
|
}
|
|
if got.FinishedAt == nil {
|
|
t.Error("FinishedAt expected non-nil after failed status")
|
|
}
|
|
}
|
|
|
|
func TestTaskStore_UpdateRetryState(t *testing.T) {
|
|
db := newTaskTestDB(t)
|
|
s := NewTaskStore(db)
|
|
ctx := context.Background()
|
|
|
|
task := makeTestTask("retry-test", model.TaskStatusFailed)
|
|
task.CurrentStep = model.TaskStepDownloading
|
|
task.RetryCount = 1
|
|
id, _ := s.Create(ctx, task)
|
|
|
|
err := s.UpdateRetryState(ctx, id, model.TaskStatusSubmitted, model.TaskStepDownloading, 2)
|
|
if err != nil {
|
|
t.Fatalf("UpdateRetryState() error = %v", err)
|
|
}
|
|
|
|
got, _ := s.GetByID(ctx, id)
|
|
if got.Status != model.TaskStatusSubmitted {
|
|
t.Errorf("Status = %q, want %q", got.Status, model.TaskStatusSubmitted)
|
|
}
|
|
if got.CurrentStep != model.TaskStepDownloading {
|
|
t.Errorf("CurrentStep = %q, want %q", got.CurrentStep, model.TaskStepDownloading)
|
|
}
|
|
if got.RetryCount != 2 {
|
|
t.Errorf("RetryCount = %d, want 2", got.RetryCount)
|
|
}
|
|
}
|
|
|
|
func TestTaskStore_GetStuckTasks(t *testing.T) {
|
|
db := newTaskTestDB(t)
|
|
s := NewTaskStore(db)
|
|
ctx := context.Background()
|
|
|
|
stuck := makeTestTask("stuck-1", model.TaskStatusDownloading)
|
|
stuck.UpdatedAt = time.Now().Add(-1 * time.Hour)
|
|
s.Create(ctx, stuck)
|
|
|
|
recent := makeTestTask("recent-1", model.TaskStatusDownloading)
|
|
s.Create(ctx, recent)
|
|
|
|
done := makeTestTask("done-1", model.TaskStatusCompleted)
|
|
done.UpdatedAt = time.Now().Add(-2 * time.Hour)
|
|
s.Create(ctx, done)
|
|
|
|
tasks, err := s.GetStuckTasks(ctx, 30*time.Minute)
|
|
if err != nil {
|
|
t.Fatalf("GetStuckTasks() error = %v", err)
|
|
}
|
|
|
|
if len(tasks) != 1 {
|
|
t.Fatalf("len(tasks) = %d, want 1", len(tasks))
|
|
}
|
|
if tasks[0].TaskName != "stuck-1" {
|
|
t.Errorf("TaskName = %q, want %q", tasks[0].TaskName, "stuck-1")
|
|
}
|
|
}
|