package handler

import (
	"bytes"
	"context"
	"encoding/json"
	"fmt"
	"net/http"
	"net/http/httptest"
	"os"
	"path/filepath"
	"sync/atomic"
	"testing"
	"time"

	"gcy_hpc_server/internal/model"
	"gcy_hpc_server/internal/service"
	"gcy_hpc_server/internal/slurm"
	"gcy_hpc_server/internal/store"

	"github.com/gin-gonic/gin"
	"go.uber.org/zap"
	"gorm.io/driver/sqlite"
	"gorm.io/gorm"
	gormlogger "gorm.io/gorm/logger"
)

// taskDBCounter gives every SQLite file created by setupTaskHandler a distinct name.
var taskDBCounter atomic.Int64

// setupTaskHandler builds a TaskHandler backed by a fresh file-based SQLite
// database with the Task and Application tables migrated.
//
// When slurmSrv is non-nil, a JobService talking to that server is passed to
// the TaskService; otherwise the TaskService receives a nil JobService. The
// database connection is closed and its file removed when the test finishes.
func setupTaskHandler(t *testing.T, slurmSrv *httptest.Server) (*TaskHandler, *gorm.DB) {
	t.Helper()

	dbFile := filepath.Join(t.TempDir(), fmt.Sprintf("test-%d.db", taskDBCounter.Add(1)))
	db, err := gorm.Open(sqlite.Open(dbFile), &gorm.Config{Logger: gormlogger.Default.LogMode(gormlogger.Silent)})
	if err != nil {
		t.Fatalf("open db: %v", err)
	}
	sqlDB, err := db.DB()
	if err != nil {
		t.Fatalf("get sql.DB: %v", err)
	}
	// Close the handle before deleting the file: an open SQLite connection can
	// keep the file locked, which would also break the t.TempDir cleanup.
	// This cleanup is registered after t.TempDir's, so it runs first.
	t.Cleanup(func() {
		_ = sqlDB.Close() // best-effort teardown
		_ = os.Remove(dbFile)
	})
	if err := db.AutoMigrate(&model.Task{}, &model.Application{}); err != nil {
		t.Fatalf("migrate db: %v", err)
	}

	taskStore := store.NewTaskStore(db)
	appStore := store.NewApplicationStore(db)

	var jobSvc *service.JobService
	if slurmSrv != nil {
		client, err := slurm.NewClient(slurmSrv.URL, slurmSrv.Client())
		if err != nil {
			t.Fatalf("create slurm client: %v", err)
		}
		jobSvc = service.NewJobService(client, zap.NewNop())
	}

	workDir := filepath.Join(t.TempDir(), "work")
	taskSvc := service.NewTaskService(taskStore, appStore, nil, nil, nil, jobSvc, workDir, zap.NewNop())
	h := NewTaskHandler(taskSvc, zap.NewNop())
	return h, db
}

// setupTaskRouter mounts h's create and list endpoints on a gin engine under
// /api/v1/tasks.
func setupTaskRouter(h *TaskHandler) *gin.Engine {
	gin.SetMode(gin.TestMode)
	r := gin.New()
	v1 := r.Group("/api/v1")
	tasks := v1.Group("/tasks")
	tasks.POST("", h.CreateTask)
	tasks.GET("", h.ListTasks)
	return r
}

// createTestAppForTask inserts an Application named "test-app" whose script
// template echoes "hello" and which takes no parameters, then returns its ID.
// It panics if the insert fails, since it has no *testing.T to report through.
func createTestAppForTask(db *gorm.DB) int64 {
	app := &model.Application{
		Name:           "test-app",
		ScriptTemplate: "#!/bin/bash\necho hello",
		Parameters:     json.RawMessage(`[]`),
	}
	if err := db.Create(app).Error; err != nil {
		panic(fmt.Sprintf("create test application: %v", err))
	}
	return app.ID
}

// newTaskSlurmServer starts a fake Slurm endpoint that answers every request
// with 201 Created and the JSON body {"job_id": jobID}. The server is closed
// on test cleanup.
func newTaskSlurmServer(t *testing.T, jobID int32) *httptest.Server {
	t.Helper()
	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.Header().Set("Content-Type", "application/json")
		w.WriteHeader(http.StatusCreated)
		// A write error here only means the client went away; nothing to report.
		_ = json.NewEncoder(w).Encode(map[string]any{"job_id": jobID})
	}))
	t.Cleanup(srv.Close)
	return srv
}

// startTaskProcessor starts the asynchronous processor of the TaskService
// behind h and stops it on test cleanup. Because this cleanup is registered
// after those of the fake server and the database, it runs before them.
func startTaskProcessor(t *testing.T, h *TaskHandler) {
	t.Helper()
	taskSvc, ok := h.svc.(*service.TaskService)
	if !ok {
		t.Fatalf("expected *service.TaskService behind handler, got %T", h.svc)
	}
	taskSvc.StartProcessor(context.Background())
	t.Cleanup(func() { taskSvc.StopProcessor() })
}

// postTaskRaw sends body to POST /api/v1/tasks with a JSON content type and
// returns the recorded response.
func postTaskRaw(t *testing.T, r *gin.Engine, body []byte) *httptest.ResponseRecorder {
	t.Helper()
	req, err := http.NewRequest(http.MethodPost, "/api/v1/tasks", bytes.NewReader(body))
	if err != nil {
		t.Fatalf("build POST request: %v", err)
	}
	req.Header.Set("Content-Type", "application/json")
	w := httptest.NewRecorder()
	r.ServeHTTP(w, req)
	return w
}

// postTaskRequest JSON-encodes createReq and posts it to /api/v1/tasks.
func postTaskRequest(t *testing.T, r *gin.Engine, createReq model.CreateTaskRequest) *httptest.ResponseRecorder {
	t.Helper()
	body, err := json.Marshal(createReq)
	if err != nil {
		t.Fatalf("marshal create request: %v", err)
	}
	return postTaskRaw(t, r, body)
}

// getTaskList issues a GET for target (path plus optional query string) and
// returns the recorded response.
func getTaskList(t *testing.T, r *gin.Engine, target string) *httptest.ResponseRecorder {
	t.Helper()
	req, err := http.NewRequest(http.MethodGet, target, nil)
	if err != nil {
		t.Fatalf("build GET request: %v", err)
	}
	w := httptest.NewRecorder()
	r.ServeHTTP(w, req)
	return w
}

// decodeTaskResponse unmarshals a recorded JSON body into a generic map and
// fails the test if the body is not a JSON object.
func decodeTaskResponse(t *testing.T, w *httptest.ResponseRecorder) map[string]any {
	t.Helper()
	var resp map[string]any
	if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
		t.Fatalf("decode response: %v; body: %s", err, w.Body.String())
	}
	return resp
}

// taskResponseData returns resp["data"] as an object, failing the test if it
// is missing or has another type.
func taskResponseData(t *testing.T, resp map[string]any) map[string]any {
	t.Helper()
	data, ok := resp["data"].(map[string]any)
	if !ok {
		t.Fatalf("expected object in response data, got %T (%v)", resp["data"], resp["data"])
	}
	return data
}

// taskResponseItems returns data["items"] as an array, failing the test if it
// is missing or has another type.
func taskResponseItems(t *testing.T, data map[string]any) []any {
	t.Helper()
	items, ok := data["items"].([]any)
	if !ok {
		t.Fatalf("expected array in data.items, got %T (%v)", data["items"], data["items"])
	}
	return items
}

// ---- CreateTask Tests ----

func TestTaskHandler_CreateTask_Success(t *testing.T) {
	slurmSrv := newTaskSlurmServer(t, 12345)
	h, db := setupTaskHandler(t, slurmSrv)
	r := setupTaskRouter(h)
	appID := createTestAppForTask(db)
	startTaskProcessor(t, h)

	w := postTaskRequest(t, r, model.CreateTaskRequest{
		AppID:    appID,
		TaskName: "my-task",
	})
	if w.Code != http.StatusCreated {
		t.Fatalf("expected 201, got %d: %s", w.Code, w.Body.String())
	}

	resp := decodeTaskResponse(t, w)
	if success, _ := resp["success"].(bool); !success {
		t.Fatal("expected success=true")
	}
	data := taskResponseData(t, resp)
	if _, ok := data["id"]; !ok {
		t.Fatal("expected id in response data")
	}
}

func TestTaskHandler_CreateTask_MissingAppID(t *testing.T) {
	h, _ := setupTaskHandler(t, nil)
	r := setupTaskRouter(h)

	w := postTaskRaw(t, r, []byte(`{"task_name":"no-app"}`))
	if w.Code != http.StatusBadRequest {
		t.Fatalf("expected 400, got %d: %s", w.Code, w.Body.String())
	}
}

func TestTaskHandler_CreateTask_InvalidJSON(t *testing.T) {
	h, _ := setupTaskHandler(t, nil)
	r := setupTaskRouter(h)

	w := postTaskRaw(t, r, []byte("not-json"))
	if w.Code != http.StatusBadRequest {
		t.Fatalf("expected 400, got %d: %s", w.Code, w.Body.String())
	}
}

// ---- ListTasks Tests ----

func TestTaskHandler_ListTasks_Pagination(t *testing.T) {
	slurmSrv := newTaskSlurmServer(t, 100)
	h, db := setupTaskHandler(t, slurmSrv)
	r := setupTaskRouter(h)
	appID := createTestAppForTask(db)
	startTaskProcessor(t, h)

	for i := 0; i < 5; i++ {
		w := postTaskRequest(t, r, model.CreateTaskRequest{
			AppID:    appID,
			TaskName: fmt.Sprintf("task-%d", i),
		})
		if w.Code != http.StatusCreated {
			t.Fatalf("create task-%d: expected 201, got %d: %s", i, w.Code, w.Body.String())
		}
	}

	// Give the async processor time to pick up the tasks.
	time.Sleep(200 * time.Millisecond)

	w := getTaskList(t, r, "/api/v1/tasks?page=1&page_size=3")
	if w.Code != http.StatusOK {
		t.Fatalf("expected 200, got %d: %s", w.Code, w.Body.String())
	}

	data := taskResponseData(t, decodeTaskResponse(t, w))
	if total, _ := data["total"].(float64); total != 5 {
		t.Fatalf("expected total=5, got %v", data["total"])
	}
	if items := taskResponseItems(t, data); len(items) != 3 {
		t.Fatalf("expected 3 items, got %d", len(items))
	}
}

func TestTaskHandler_ListTasks_StatusFilter(t *testing.T) {
	slurmSrv := newTaskSlurmServer(t, 200)
	h, db := setupTaskHandler(t, slurmSrv)
	r := setupTaskRouter(h)
	appID := createTestAppForTask(db)
	startTaskProcessor(t, h)

	for i := 0; i < 3; i++ {
		w := postTaskRequest(t, r, model.CreateTaskRequest{
			AppID:    appID,
			TaskName: fmt.Sprintf("filter-task-%d", i),
		})
		if w.Code != http.StatusCreated {
			t.Fatalf("create filter-task-%d: expected 201, got %d: %s", i, w.Code, w.Body.String())
		}
	}

	// Give the async processor time to pick up the tasks.
	time.Sleep(200 * time.Millisecond)

	w := getTaskList(t, r, "/api/v1/tasks?status=queued")
	if w.Code != http.StatusOK {
		t.Fatalf("expected 200, got %d: %s", w.Code, w.Body.String())
	}

	data := taskResponseData(t, decodeTaskResponse(t, w))
	// NOTE(review): this only checks that every returned item is queued; it
	// passes vacuously if no task is in that state when the list is fetched.
	for _, item := range taskResponseItems(t, data) {
		m, ok := item.(map[string]any)
		if !ok {
			t.Fatalf("expected object item, got %T", item)
		}
		if m["status"] != "queued" {
			t.Fatalf("expected status=queued, got %v", m["status"])
		}
	}
}

func TestTaskHandler_CreateTask_WithSchedulingFields(t *testing.T) {
	slurmSrv := newTaskSlurmServer(t, 12345)
	h, db := setupTaskHandler(t, slurmSrv)
	r := setupTaskRouter(h)
	appID := createTestAppForTask(db)
	startTaskProcessor(t, h)

	cpus := int32(8)
	memNode := int64(4096)
	tl := int32(60)
	qos := "high"
	w := postTaskRequest(t, r, model.CreateTaskRequest{
		AppID:         appID,
		TaskName:      "sched-task",
		Partition:     ptrStr("gpu"),
		Cpus:          &cpus,
		MemoryPerNode: &memNode,
		TimeLimit:     &tl,
		QOS:           &qos,
	})
	if w.Code != http.StatusCreated {
		t.Fatalf("expected 201, got %d: %s", w.Code, w.Body.String())
	}

	resp := decodeTaskResponse(t, w)
	if success, _ := resp["success"].(bool); !success {
		t.Fatal("expected success=true")
	}
	data := taskResponseData(t, resp)
	rawID, ok := data["id"]
	if !ok {
		t.Fatal("expected id in response data")
	}
	idNum, ok := rawID.(float64)
	if !ok {
		t.Fatalf("expected numeric id, got %T (%v)", rawID, rawID)
	}
	taskID := int64(idNum)

	// Give the async processor time to pick up the task.
	time.Sleep(300 * time.Millisecond)

	// Verify the scheduling fields were persisted.
	var task model.Task
	if err := db.First(&task, taskID).Error; err != nil {
		t.Fatalf("load task %d: %v", taskID, err)
	}
	if task.Partition != "gpu" {
		t.Errorf("expected partition=gpu, got %q", task.Partition)
	}
	if task.Cpus == nil || *task.Cpus != 8 {
		t.Errorf("expected cpus=8, got %v", task.Cpus)
	}
	if task.MemoryPerNode == nil || *task.MemoryPerNode != 4096 {
		t.Errorf("expected memory_per_node=4096, got %v", task.MemoryPerNode)
	}
	if task.TimeLimit == nil || *task.TimeLimit != 60 {
		t.Errorf("expected time_limit=60, got %v", task.TimeLimit)
	}
	if task.QOS == nil || *task.QOS != "high" {
		t.Errorf("expected qos=high, got %v", task.QOS)
	}
}

func TestTaskHandler_ListTasks_ReturnsSchedulingFields(t *testing.T) {
	h, db := setupTaskHandler(t, nil)
	r := setupTaskRouter(h)
	appID := createTestAppForTask(db)

	cpus := int32(16)
	memNode := int64(8192)
	tl := int32(120)
	qos := "normal"
	task := &model.Task{
		TaskName:      "sched-list-task",
		AppID:         appID, // use the generated ID rather than assuming it is 1
		AppName:       "test-app",
		Status:        model.TaskStatusSubmitted,
		SubmittedAt:   time.Now(),
		Partition:     "gpu",
		Cpus:          &cpus,
		MemoryPerNode: &memNode,
		TimeLimit:     &tl,
		QOS:           &qos,
	}
	if err := db.Create(task).Error; err != nil {
		t.Fatalf("insert task: %v", err)
	}

	w := getTaskList(t, r, "/api/v1/tasks")
	if w.Code != http.StatusOK {
		t.Fatalf("expected 200, got %d: %s", w.Code, w.Body.String())
	}

	data := taskResponseData(t, decodeTaskResponse(t, w))
	items := taskResponseItems(t, data)
	if len(items) == 0 {
		t.Fatal("expected at least 1 item")
	}
	item, ok := items[0].(map[string]any)
	if !ok {
		t.Fatalf("expected object item, got %T", items[0])
	}

	if item["partition"] != "gpu" {
		t.Errorf("expected partition=gpu, got %v", item["partition"])
	}
	if v, ok := item["cpus"].(float64); !ok || v != 16 {
		t.Errorf("expected cpus=16, got %v", item["cpus"])
	}
	if v, ok := item["memory_per_node"].(float64); !ok || v != 8192 {
		t.Errorf("expected memory_per_node=8192, got %v", item["memory_per_node"])
	}
	if v, ok := item["time_limit"].(float64); !ok || v != 120 {
		t.Errorf("expected time_limit=120, got %v", item["time_limit"])
	}
	if v, ok := item["qos"].(string); !ok || v != "normal" {
		t.Errorf("expected qos=normal, got %v", item["qos"])
	}
}

// ptrStr returns a pointer to a copy of s.
func ptrStr(s string) *string { return &s }

func TestTaskHandler_ListTasks_DefaultPagination(t *testing.T) {
	h, db := setupTaskHandler(t, nil)
	r := setupTaskRouter(h)
	appID := createTestAppForTask(db)

	// Insert tasks directly so the async processor is not needed.
	for i := 0; i < 15; i++ {
		task := &model.Task{
			TaskName:    fmt.Sprintf("default-task-%d", i),
			AppID:       appID,
			AppName:     "test-app",
			Status:      model.TaskStatusSubmitted,
			SubmittedAt: time.Now(),
		}
		if err := db.Create(task).Error; err != nil {
			t.Fatalf("insert task %d: %v", i, err)
		}
	}

	w := getTaskList(t, r, "/api/v1/tasks")
	if w.Code != http.StatusOK {
		t.Fatalf("expected 200, got %d: %s", w.Code, w.Body.String())
	}

	data := taskResponseData(t, decodeTaskResponse(t, w))
	if total, _ := data["total"].(float64); total != 15 {
		t.Fatalf("expected total=15, got %v", data["total"])
	}
	if items := taskResponseItems(t, data); len(items) != 10 {
		t.Fatalf("expected 10 items (default page_size), got %d", len(items))
	}
}