package service

import (
	"context"
	"encoding/json"
	"net/http"
	"net/http/httptest"
	"sync/atomic"
	"testing"
	"time"

	"gcy_hpc_server/internal/model"
	"gcy_hpc_server/internal/slurm"
	"gcy_hpc_server/internal/store"
	"go.uber.org/zap"
	"gorm.io/driver/sqlite"
	"gorm.io/gorm"
	"gorm.io/gorm/logger"
)

// newTaskSvcTestDB opens a fresh in-memory SQLite database with the
// model.Task schema auto-migrated. The underlying *sql.DB is closed when the
// test finishes.
//
// NOTE(review): with ":memory:" each pooled connection gets its own private
// database, so these tests rely on database/sql reusing one connection. If
// "no such table" flakiness appears, consider sqlDB.SetMaxOpenConns(1)
// (verify the service never needs two connections at once first).
func newTaskSvcTestDB(t *testing.T) *gorm.DB {
	t.Helper()
	db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{
		Logger: logger.Default.LogMode(logger.Silent),
	})
	if err != nil {
		t.Fatalf("open sqlite: %v", err)
	}
	sqlDB, err := db.DB()
	if err != nil {
		t.Fatalf("get sql.DB: %v", err)
	}
	// Release the in-memory database once the test (and its defers) are done.
	t.Cleanup(func() { _ = sqlDB.Close() })
	if err := db.AutoMigrate(&model.Task{}); err != nil {
		t.Fatalf("auto migrate: %v", err)
	}
	return db
}

// taskSvcTestEnv bundles a TaskService under test with its collaborators so
// tests can seed/inspect the store and DB directly.
type taskSvcTestEnv struct {
	taskStore *store.TaskStore
	jobSvc    *JobService
	svc       *TaskService
	srv       *httptest.Server
	db        *gorm.DB
}

// newTaskSvcTestEnv wires a TaskService backed by an in-memory task store and
// a JobService whose Slurm client talks to an httptest server running handler.
//
// A nil handler means the test expects no Slurm traffic: any request is
// reported as a test error and answered with 500. (Passing a nil
// http.HandlerFunc straight to httptest.NewServer would instead yield a
// non-nil Handler that panics inside the server goroutine.)
//
// Callers must call close on the returned env.
func newTaskSvcTestEnv(t *testing.T, handler http.HandlerFunc) *taskSvcTestEnv {
	t.Helper()
	if handler == nil {
		handler = func(w http.ResponseWriter, r *http.Request) {
			t.Errorf("unexpected Slurm request: %s %s", r.Method, r.URL.Path)
			w.WriteHeader(http.StatusInternalServerError)
		}
	}

	db := newTaskSvcTestDB(t)
	ts := store.NewTaskStore(db)
	srv := httptest.NewServer(handler)

	client, err := slurm.NewClient(srv.URL, srv.Client())
	if err != nil {
		srv.Close()
		t.Fatalf("new slurm client: %v", err)
	}
	jobSvc := NewJobService(client, zap.NewNop())
	svc := NewTaskService(ts, nil, nil, nil, nil, jobSvc, "/tmp", zap.NewNop())

	return &taskSvcTestEnv{
		taskStore: ts,
		jobSvc:    jobSvc,
		svc:       svc,
		srv:       srv,
		db:        db,
	}
}

// close shuts down the fake Slurm HTTP server, waiting for in-flight requests.
func (e *taskSvcTestEnv) close() {
	e.srv.Close()
}

// makeTaskForTest builds an unsaved task owned by "user1" for app 1 with the
// given name, status and (optional) Slurm job id.
func makeTaskForTest(name, status string, slurmJobID *int32) *model.Task {
	return &model.Task{
		TaskName:    name,
		AppID:       1,
		AppName:     "test-app",
		Status:      status,
		CurrentStep: "",
		RetryCount:  0,
		UserID:      "user1",
		SubmittedAt: time.Now(),
		SlurmJobID:  slurmJobID,
	}
}

// taskSvcJobInfoHandler returns a handler that answers every request with a
// single-job OpenapiJobInfoResp for jobID in the given Slurm states.
func taskSvcJobInfoHandler(t *testing.T, jobID int32, states ...string) http.HandlerFunc {
	return func(w http.ResponseWriter, r *http.Request) {
		resp := slurm.OpenapiJobInfoResp{
			Jobs: slurm.JobInfoMsg{
				{
					JobID:    &jobID,
					JobState: states,
				},
			},
		}
		// t.Errorf (not Fatalf) because this runs on the server goroutine.
		if err := json.NewEncoder(w).Encode(resp); err != nil {
			t.Errorf("encode job info response: %v", err)
		}
	}
}

func TestTaskService_MapSlurmState_AllStates(t *testing.T) {
	env := newTaskSvcTestEnv(t, nil)
	defer env.close()

	cases := []struct {
		input    []string
		expected string
	}{
		{[]string{"PENDING"}, model.TaskStatusQueued},
		{[]string{"RUNNING"}, model.TaskStatusRunning},
		{[]string{"CONFIGURING"}, model.TaskStatusRunning},
		{[]string{"COMPLETING"}, model.TaskStatusRunning},
		{[]string{"COMPLETED"}, model.TaskStatusCompleted},
		{[]string{"FAILED"}, model.TaskStatusFailed},
		{[]string{"CANCELLED"}, model.TaskStatusFailed},
		{[]string{"TIMEOUT"}, model.TaskStatusFailed},
		{[]string{"NODE_FAIL"}, model.TaskStatusFailed},
		{[]string{"OUT_OF_MEMORY"}, model.TaskStatusFailed},
		{[]string{"PREEMPTED"}, model.TaskStatusFailed},
		{[]string{"SPECIAL_EXIT"}, model.TaskStatusRunning},
		{[]string{"unknown_state"}, model.TaskStatusRunning},
		// Mapping is expected to be case-insensitive.
		{[]string{"pending"}, model.TaskStatusQueued},
		{[]string{"Running"}, model.TaskStatusRunning},
	}
	for _, tc := range cases {
		got := env.svc.mapSlurmStateToTaskStatus(tc.input)
		if got != tc.expected {
			t.Errorf("mapSlurmStateToTaskStatus(%v) = %q, want %q", tc.input, got, tc.expected)
		}
	}
}

func TestTaskService_MapSlurmState_Empty(t *testing.T) {
	env := newTaskSvcTestEnv(t, nil)
	defer env.close()

	got := env.svc.mapSlurmStateToTaskStatus([]string{})
	if got != model.TaskStatusRunning {
		t.Errorf("mapSlurmStateToTaskStatus([]) = %q, want %q", got, model.TaskStatusRunning)
	}

	got = env.svc.mapSlurmStateToTaskStatus(nil)
	if got != model.TaskStatusRunning {
		t.Errorf("mapSlurmStateToTaskStatus(nil) = %q, want %q", got, model.TaskStatusRunning)
	}
}

func TestTaskService_RefreshTaskStatus_UpdatesDB(t *testing.T) {
	jobID := int32(42)
	env := newTaskSvcTestEnv(t, taskSvcJobInfoHandler(t, jobID, "RUNNING"))
	defer env.close()

	ctx := context.Background()
	id, err := env.taskStore.Create(ctx, makeTaskForTest("refresh-test", model.TaskStatusQueued, &jobID))
	if err != nil {
		t.Fatalf("Create: %v", err)
	}

	if err := env.svc.refreshTaskStatus(ctx, id); err != nil {
		t.Fatalf("refreshTaskStatus: %v", err)
	}

	updated, err := env.taskStore.GetByID(ctx, id)
	if err != nil {
		t.Fatalf("GetByID: %v", err)
	}
	if updated == nil {
		t.Fatal("task not found after refresh")
	}
	if updated.Status != model.TaskStatusRunning {
		t.Errorf("status = %q, want %q", updated.Status, model.TaskStatusRunning)
	}
}

func TestTaskService_RefreshTaskStatus_NoSlurmJobID(t *testing.T) {
	// nil handler: any Slurm request fails the test.
	env := newTaskSvcTestEnv(t, nil)
	defer env.close()

	ctx := context.Background()
	id, err := env.taskStore.Create(ctx, makeTaskForTest("no-slurm", model.TaskStatusQueued, nil))
	if err != nil {
		t.Fatalf("Create: %v", err)
	}

	if err := env.svc.refreshTaskStatus(ctx, id); err != nil {
		t.Fatalf("expected no error, got %v", err)
	}

	got, err := env.taskStore.GetByID(ctx, id)
	if err != nil {
		t.Fatalf("GetByID: %v", err)
	}
	if got == nil {
		t.Fatal("task not found after refresh")
	}
	if got.Status != model.TaskStatusQueued {
		t.Errorf("status should remain unchanged, got %q", got.Status)
	}
}

func TestTaskService_RefreshTaskStatus_SlurmError(t *testing.T) {
	jobID := int32(42)
	env := newTaskSvcTestEnv(t, func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusInternalServerError)
		_, _ = w.Write([]byte(`{"error":"down"}`))
	})
	defer env.close()

	ctx := context.Background()
	id, err := env.taskStore.Create(ctx, makeTaskForTest("slurm-err", model.TaskStatusQueued, &jobID))
	if err != nil {
		t.Fatalf("Create: %v", err)
	}

	// Slurm failures are expected to be soft: logged, not returned.
	if err := env.svc.refreshTaskStatus(ctx, id); err != nil {
		t.Fatalf("expected no error (soft fail), got %v", err)
	}

	got, err := env.taskStore.GetByID(ctx, id)
	if err != nil {
		t.Fatalf("GetByID: %v", err)
	}
	if got == nil {
		t.Fatal("task not found after refresh")
	}
	if got.Status != model.TaskStatusQueued {
		t.Errorf("status should remain unchanged on slurm error, got %q", got.Status)
	}
}

func TestTaskService_RefreshTaskStatus_NoChange(t *testing.T) {
	jobID := int32(42)
	env := newTaskSvcTestEnv(t, taskSvcJobInfoHandler(t, jobID, "RUNNING"))
	defer env.close()

	ctx := context.Background()
	id, err := env.taskStore.Create(ctx, makeTaskForTest("no-change", model.TaskStatusRunning, &jobID))
	if err != nil {
		t.Fatalf("Create: %v", err)
	}

	if err := env.svc.refreshTaskStatus(ctx, id); err != nil {
		t.Fatalf("refreshTaskStatus: %v", err)
	}

	got, err := env.taskStore.GetByID(ctx, id)
	if err != nil {
		t.Fatalf("GetByID: %v", err)
	}
	if got == nil {
		t.Fatal("task not found after refresh")
	}
	if got.Status != model.TaskStatusRunning {
		t.Errorf("status = %q, want %q", got.Status, model.TaskStatusRunning)
	}
}

func TestTaskService_RefreshStaleTasks_SkipsFresh(t *testing.T) {
	jobID := int32(42)
	// Written on the server goroutine and read on the test goroutine, so it
	// must be accessed atomically to stay race-free.
	var slurmQueries int32
	env := newTaskSvcTestEnv(t, func(w http.ResponseWriter, r *http.Request) {
		atomic.AddInt32(&slurmQueries, 1)
		w.WriteHeader(http.StatusInternalServerError)
	})
	defer env.close()

	ctx := context.Background()
	id, err := env.taskStore.Create(ctx, makeTaskForTest("fresh-task", model.TaskStatusQueued, &jobID))
	if err != nil {
		t.Fatalf("Create: %v", err)
	}
	freshTask, err := env.taskStore.GetByID(ctx, id)
	if err != nil {
		t.Fatalf("GetByID: %v", err)
	}
	if freshTask == nil {
		t.Fatal("task not found")
	}

	res := env.db.Exec("UPDATE hpc_tasks SET updated_at = ? WHERE id = ?", time.Now(), id)
	if res.Error != nil {
		t.Fatalf("set updated_at: %v", res.Error)
	}
	if res.RowsAffected != 1 {
		t.Fatalf("set updated_at: affected %d rows, want 1", res.RowsAffected)
	}

	if err := env.svc.RefreshStaleTasks(ctx); err != nil {
		t.Fatalf("RefreshStaleTasks: %v", err)
	}

	if n := atomic.LoadInt32(&slurmQueries); n != 0 {
		t.Errorf("expected no Slurm query for fresh task, got %d", n)
	}
}

func TestTaskService_RefreshStaleTasks_RefreshesStale(t *testing.T) {
	jobID := int32(42)
	env := newTaskSvcTestEnv(t, taskSvcJobInfoHandler(t, jobID, "COMPLETED"))
	defer env.close()

	ctx := context.Background()
	id, err := env.taskStore.Create(ctx, makeTaskForTest("stale-task", model.TaskStatusRunning, &jobID))
	if err != nil {
		t.Fatalf("Create: %v", err)
	}

	// NOTE(review): assumes RefreshStaleTasks' staleness threshold is < 60s.
	staleTime := time.Now().Add(-60 * time.Second)
	res := env.db.Exec("UPDATE hpc_tasks SET updated_at = ? WHERE id = ?", staleTime, id)
	if res.Error != nil {
		t.Fatalf("set updated_at: %v", res.Error)
	}
	if res.RowsAffected != 1 {
		t.Fatalf("set updated_at: affected %d rows, want 1", res.RowsAffected)
	}

	if err := env.svc.RefreshStaleTasks(ctx); err != nil {
		t.Fatalf("RefreshStaleTasks: %v", err)
	}

	got, err := env.taskStore.GetByID(ctx, id)
	if err != nil {
		t.Fatalf("GetByID: %v", err)
	}
	if got == nil {
		t.Fatal("task not found after refresh")
	}
	if got.Status != model.TaskStatusCompleted {
		t.Errorf("status = %q, want %q", got.Status, model.TaskStatusCompleted)
	}
}