package handler import ( "bytes" "context" "encoding/json" "fmt" "net/http" "net/http/httptest" "os" "path/filepath" "sync/atomic" "testing" "time" "gcy_hpc_server/internal/model" "gcy_hpc_server/internal/service" "gcy_hpc_server/internal/slurm" "gcy_hpc_server/internal/store" "github.com/gin-gonic/gin" "go.uber.org/zap" "gorm.io/driver/sqlite" "gorm.io/gorm" gormlogger "gorm.io/gorm/logger" ) var taskDBCounter atomic.Int64 func setupTaskHandler(t *testing.T, slurmSrv *httptest.Server) (*TaskHandler, *gorm.DB) { t.Helper() dbFile := filepath.Join(t.TempDir(), fmt.Sprintf("test-%d.db", taskDBCounter.Add(1))) db, err := gorm.Open(sqlite.Open(dbFile), &gorm.Config{Logger: gormlogger.Default.LogMode(gormlogger.Silent)}) if err != nil { t.Fatalf("open db: %v", err) } db.AutoMigrate(&model.Task{}, &model.Application{}) t.Cleanup(func() { os.Remove(dbFile) }) taskStore := store.NewTaskStore(db) appStore := store.NewApplicationStore(db) var jobSvc *service.JobService if slurmSrv != nil { client, _ := slurm.NewClient(slurmSrv.URL, slurmSrv.Client()) jobSvc = service.NewJobService(client, zap.NewNop()) } workDir := filepath.Join(t.TempDir(), "work") taskSvc := service.NewTaskService(taskStore, appStore, nil, nil, nil, jobSvc, workDir, zap.NewNop()) h := NewTaskHandler(taskSvc, zap.NewNop()) return h, db } func setupTaskRouter(h *TaskHandler) *gin.Engine { gin.SetMode(gin.TestMode) r := gin.New() v1 := r.Group("/api/v1") tasks := v1.Group("/tasks") tasks.POST("", h.CreateTask) tasks.GET("", h.ListTasks) return r } func createTestAppForTask(db *gorm.DB) int64 { app := &model.Application{ Name: "test-app", ScriptTemplate: "#!/bin/bash\necho hello", Parameters: json.RawMessage(`[]`), } db.Create(app) return app.ID } // ---- CreateTask Tests ---- func TestTaskHandler_CreateTask_Success(t *testing.T) { slurmSrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusCreated) json.NewEncoder(w).Encode(map[string]interface{}{"job_id": 12345}) })) defer slurmSrv.Close() h, db := setupTaskHandler(t, slurmSrv) r := setupTaskRouter(h) appID := createTestAppForTask(db) taskSvc := h.svc.(*service.TaskService) ctx := context.Background() taskSvc.StartProcessor(ctx) defer taskSvc.StopProcessor() body, _ := json.Marshal(model.CreateTaskRequest{ AppID: appID, TaskName: "my-task", }) w := httptest.NewRecorder() req, _ := http.NewRequest(http.MethodPost, "/api/v1/tasks", bytes.NewReader(body)) req.Header.Set("Content-Type", "application/json") r.ServeHTTP(w, req) if w.Code != http.StatusCreated { t.Fatalf("expected 201, got %d: %s", w.Code, w.Body.String()) } var resp map[string]interface{} json.Unmarshal(w.Body.Bytes(), &resp) if !resp["success"].(bool) { t.Fatal("expected success=true") } data := resp["data"].(map[string]interface{}) if _, ok := data["id"]; !ok { t.Fatal("expected id in response data") } } func TestTaskHandler_CreateTask_MissingAppID(t *testing.T) { h, _ := setupTaskHandler(t, nil) r := setupTaskRouter(h) body := `{"task_name":"no-app"}` w := httptest.NewRecorder() req, _ := http.NewRequest(http.MethodPost, "/api/v1/tasks", bytes.NewReader([]byte(body))) req.Header.Set("Content-Type", "application/json") r.ServeHTTP(w, req) if w.Code != http.StatusBadRequest { t.Fatalf("expected 400, got %d: %s", w.Code, w.Body.String()) } } func TestTaskHandler_CreateTask_InvalidJSON(t *testing.T) { h, _ := setupTaskHandler(t, nil) r := setupTaskRouter(h) w := httptest.NewRecorder() req, _ := http.NewRequest(http.MethodPost, "/api/v1/tasks", bytes.NewReader([]byte("not-json"))) req.Header.Set("Content-Type", "application/json") r.ServeHTTP(w, req) if w.Code != http.StatusBadRequest { t.Fatalf("expected 400, got %d: %s", w.Code, w.Body.String()) } } // ---- ListTasks Tests ---- func TestTaskHandler_ListTasks_Pagination(t *testing.T) { slurmSrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusCreated) json.NewEncoder(w).Encode(map[string]interface{}{"job_id": int32(100)}) })) defer slurmSrv.Close() h, db := setupTaskHandler(t, slurmSrv) r := setupTaskRouter(h) appID := createTestAppForTask(db) taskSvc := h.svc.(*service.TaskService) ctx := context.Background() taskSvc.StartProcessor(ctx) defer taskSvc.StopProcessor() for i := 0; i < 5; i++ { body, _ := json.Marshal(model.CreateTaskRequest{ AppID: appID, TaskName: fmt.Sprintf("task-%d", i), }) w := httptest.NewRecorder() req, _ := http.NewRequest(http.MethodPost, "/api/v1/tasks", bytes.NewReader(body)) req.Header.Set("Content-Type", "application/json") r.ServeHTTP(w, req) } // Wait for async processing time.Sleep(200 * time.Millisecond) w := httptest.NewRecorder() req, _ := http.NewRequest(http.MethodGet, "/api/v1/tasks?page=1&page_size=3", nil) r.ServeHTTP(w, req) if w.Code != http.StatusOK { t.Fatalf("expected 200, got %d: %s", w.Code, w.Body.String()) } var resp map[string]interface{} json.Unmarshal(w.Body.Bytes(), &resp) data := resp["data"].(map[string]interface{}) if data["total"].(float64) != 5 { t.Fatalf("expected total=5, got %v", data["total"]) } items := data["items"].([]interface{}) if len(items) != 3 { t.Fatalf("expected 3 items, got %d", len(items)) } } func TestTaskHandler_ListTasks_StatusFilter(t *testing.T) { slurmSrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusCreated) json.NewEncoder(w).Encode(map[string]interface{}{"job_id": int32(200)}) })) defer slurmSrv.Close() h, db := setupTaskHandler(t, slurmSrv) r := setupTaskRouter(h) appID := createTestAppForTask(db) taskSvc := h.svc.(*service.TaskService) ctx := context.Background() taskSvc.StartProcessor(ctx) defer taskSvc.StopProcessor() for i := 0; i < 3; i++ { body, _ := json.Marshal(model.CreateTaskRequest{ AppID: appID, TaskName: fmt.Sprintf("filter-task-%d", i), }) w := httptest.NewRecorder() req, _ := http.NewRequest(http.MethodPost, "/api/v1/tasks", bytes.NewReader(body)) req.Header.Set("Content-Type", "application/json") r.ServeHTTP(w, req) } // Wait for async processing time.Sleep(200 * time.Millisecond) w := httptest.NewRecorder() req, _ := http.NewRequest(http.MethodGet, "/api/v1/tasks?status=queued", nil) r.ServeHTTP(w, req) if w.Code != http.StatusOK { t.Fatalf("expected 200, got %d: %s", w.Code, w.Body.String()) } var resp map[string]interface{} json.Unmarshal(w.Body.Bytes(), &resp) data := resp["data"].(map[string]interface{}) items := data["items"].([]interface{}) for _, item := range items { m := item.(map[string]interface{}) if m["status"] != "queued" { t.Fatalf("expected status=queued, got %v", m["status"]) } } } func TestTaskHandler_ListTasks_DefaultPagination(t *testing.T) { h, db := setupTaskHandler(t, nil) r := setupTaskRouter(h) _ = createTestAppForTask(db) // Directly insert tasks via DB to avoid needing processor for i := 0; i < 15; i++ { task := &model.Task{ TaskName: fmt.Sprintf("default-task-%d", i), AppID: 1, AppName: "test-app", Status: model.TaskStatusSubmitted, SubmittedAt: time.Now(), } db.Create(task) } w := httptest.NewRecorder() req, _ := http.NewRequest(http.MethodGet, "/api/v1/tasks", nil) r.ServeHTTP(w, req) if w.Code != http.StatusOK { t.Fatalf("expected 200, got %d: %s", w.Code, w.Body.String()) } var resp map[string]interface{} json.Unmarshal(w.Body.Bytes(), &resp) data := resp["data"].(map[string]interface{}) if data["total"].(float64) != 15 { t.Fatalf("expected total=15, got %v", data["total"]) } items := data["items"].([]interface{}) if len(items) != 10 { t.Fatalf("expected 10 items (default page_size), got %d", len(items)) } }