package main import ( "bytes" "context" "encoding/json" "fmt" "io" "net/http" "net/http/httptest" "os" "path/filepath" "sync/atomic" "testing" "time" "gcy_hpc_server/internal/handler" "gcy_hpc_server/internal/model" "gcy_hpc_server/internal/server" "gcy_hpc_server/internal/service" "gcy_hpc_server/internal/slurm" "gcy_hpc_server/internal/store" "go.uber.org/zap" "gorm.io/driver/sqlite" "gorm.io/gorm" "gorm.io/gorm/logger" ) func newTaskTestDB(t *testing.T) *gorm.DB { t.Helper() db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{Logger: logger.Default.LogMode(logger.Silent)}) if err != nil { t.Fatalf("open sqlite: %v", err) } sqlDB, _ := db.DB() sqlDB.SetMaxOpenConns(1) db.AutoMigrate( &model.Application{}, &model.File{}, &model.FileBlob{}, &model.Task{}, ) t.Cleanup(func() { sqlDB.Close() }) return db } type mockSlurmHandler struct { submitFn func(w http.ResponseWriter, r *http.Request) } func newMockSlurmServer(t *testing.T, h *mockSlurmHandler) *httptest.Server { t.Helper() mux := http.NewServeMux() mux.HandleFunc("/slurm/v0.0.40/job/submit", func(w http.ResponseWriter, r *http.Request) { if h != nil && h.submitFn != nil { h.submitFn(w, r) return } w.Header().Set("Content-Type", "application/json") fmt.Fprint(w, `{"job_id": 42, "step_id": "0", "result": {"job_id": 42, "step_id": "0"}}`) }) mux.HandleFunc("/slurm/v0.0.40/jobs", func(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "application/json") json.NewEncoder(w).Encode(map[string]interface{}{"jobs": []interface{}{}}) }) mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "application/json") json.NewEncoder(w).Encode(map[string]interface{}{}) }) srv := httptest.NewServer(mux) t.Cleanup(srv.Close) return srv } func setupTaskTestServer(t *testing.T) (*httptest.Server, *store.TaskStore, func()) { t.Helper() return setupTaskTestServerWithSlurm(t, nil) } func setupTaskTestServerWithSlurm(t *testing.T, slurmHandler *mockSlurmHandler) (*httptest.Server, *store.TaskStore, func()) { t.Helper() db := newTaskTestDB(t) slurmSrv := newMockSlurmServer(t, slurmHandler) client, err := slurm.NewClientWithOpts(slurmSrv.URL, slurm.WithHTTPClient(slurmSrv.Client())) if err != nil { t.Fatalf("slurm client: %v", err) } log := zap.NewNop() jobSvc := service.NewJobService(client, log) appStore := store.NewApplicationStore(db) taskStore := store.NewTaskStore(db) tmpDir, err := os.MkdirTemp("", "task-test-workdir-*") if err != nil { t.Fatalf("temp dir: %v", err) } taskSvc := service.NewTaskService( taskStore, appStore, nil, nil, nil, jobSvc, tmpDir, log, ) ctx, cancel := context.WithCancel(context.Background()) taskSvc.StartProcessor(ctx) appSvc := service.NewApplicationService(appStore, jobSvc, tmpDir, log, taskSvc) appH := handler.NewApplicationHandler(appSvc, log) taskH := handler.NewTaskHandler(taskSvc, log) router := server.NewRouter( handler.NewJobHandler(jobSvc, log), handler.NewClusterHandler(service.NewClusterService(client, log), log), appH, nil, nil, nil, taskH, log, ) httpSrv := httptest.NewServer(router) cleanup := func() { taskSvc.StopProcessor() cancel() httpSrv.Close() os.RemoveAll(tmpDir) } return httpSrv, taskStore, cleanup } func createTestApp(t *testing.T, srvURL string) int64 { t.Helper() body := `{"name":"test-app","script_template":"#!/bin/bash\necho {{.np}}","parameters":[{"name":"np","type":"string","default":"1"}]}` resp, err := http.Post(srvURL+"/api/v1/applications", "application/json", bytes.NewReader([]byte(body))) if err != nil { t.Fatalf("create app: %v", err) } defer resp.Body.Close() if resp.StatusCode != http.StatusCreated { b, _ := io.ReadAll(resp.Body) t.Fatalf("create app: status %d, body %s", resp.StatusCode, b) } var result struct { Data struct { ID int64 `json:"id"` } `json:"data"` } if err := json.NewDecoder(resp.Body).Decode(&result); err != nil { t.Fatalf("decode app response: %v", err) } return result.Data.ID } func postTask(t *testing.T, srvURL string, body string) (*http.Response, map[string]interface{}) { t.Helper() resp, err := http.Post(srvURL+"/api/v1/tasks", "application/json", bytes.NewReader([]byte(body))) if err != nil { t.Fatalf("post task: %v", err) } defer resp.Body.Close() b, _ := io.ReadAll(resp.Body) var result map[string]interface{} json.Unmarshal(b, &result) return resp, result } func getTasks(t *testing.T, srvURL string, query string) (int, []interface{}) { t.Helper() url := srvURL + "/api/v1/tasks" if query != "" { url += "?" + query } resp, err := http.Get(url) if err != nil { t.Fatalf("get tasks: %v", err) } defer resp.Body.Close() b, _ := io.ReadAll(resp.Body) var result map[string]interface{} json.Unmarshal(b, &result) data, _ := result["data"].(map[string]interface{}) items, _ := data["items"].([]interface{}) total, _ := data["total"].(float64) return int(total), items } func TestTask_FullLifecycle(t *testing.T) { srv, _, cleanup := setupTaskTestServer(t) defer cleanup() appID := createTestApp(t, srv.URL) resp, result := postTask(t, srv.URL, fmt.Sprintf( `{"app_id":%d,"task_name":"lifecycle-test","values":{"np":"4"}}`, appID, )) if resp.StatusCode != http.StatusCreated { t.Fatalf("expected 201, got %d: %v", resp.StatusCode, result) } data, _ := result["data"].(map[string]interface{}) if data["id"] == nil { t.Fatal("expected id in response") } time.Sleep(300 * time.Millisecond) total, items := getTasks(t, srv.URL, "") if total < 1 { t.Fatalf("expected at least 1 task, got %d", total) } task := items[0].(map[string]interface{}) if task["status"] == nil || task["status"] == "" { t.Error("expected non-empty status") } if task["task_name"] != "lifecycle-test" { t.Errorf("expected task_name=lifecycle-test, got %v", task["task_name"]) } if task["app_id"] != float64(appID) { t.Errorf("expected app_id=%d, got %v", appID, task["app_id"]) } } func TestTask_CreateWithMissingApp(t *testing.T) { srv, _, cleanup := setupTaskTestServer(t) defer cleanup() resp, result := postTask(t, srv.URL, `{"app_id":9999,"task_name":"no-app"}`) if resp.StatusCode != http.StatusNotFound { t.Fatalf("expected 404, got %d: %v", resp.StatusCode, result) } } func TestTask_CreateWithInvalidBody(t *testing.T) { srv, _, cleanup := setupTaskTestServer(t) defer cleanup() resp, result := postTask(t, srv.URL, `{}`) if resp.StatusCode != http.StatusBadRequest { t.Fatalf("expected 400, got %d: %v", resp.StatusCode, result) } } func TestTask_FileLimitExceeded(t *testing.T) { srv, _, cleanup := setupTaskTestServer(t) defer cleanup() appID := createTestApp(t, srv.URL) fileIDs := make([]int64, 101) for i := range fileIDs { fileIDs[i] = int64(i + 1) } idsJSON, _ := json.Marshal(fileIDs) body := fmt.Sprintf(`{"app_id":%d,"file_ids":%s}`, appID, string(idsJSON)) resp, result := postTask(t, srv.URL, body) if resp.StatusCode != http.StatusBadRequest { t.Fatalf("expected 400 for file limit, got %d: %v", resp.StatusCode, result) } } func TestTask_RetryScenario(t *testing.T) { var failCount int32 slurmH := &mockSlurmHandler{ submitFn: func(w http.ResponseWriter, r *http.Request) { if atomic.AddInt32(&failCount, 1) <= 2 { w.WriteHeader(http.StatusInternalServerError) fmt.Fprint(w, `{"errors":[{"error":"internal error"}]}`) return } w.Header().Set("Content-Type", "application/json") fmt.Fprint(w, `{"job_id": 99, "step_id": "0", "result": {"job_id": 99, "step_id": "0"}}`) }, } srv, taskStore, cleanup := setupTaskTestServerWithSlurm(t, slurmH) defer cleanup() appID := createTestApp(t, srv.URL) resp, result := postTask(t, srv.URL, fmt.Sprintf( `{"app_id":%d,"task_name":"retry-test","values":{"np":"2"}}`, appID, )) if resp.StatusCode != http.StatusCreated { t.Fatalf("expected 201, got %d: %v", resp.StatusCode, result) } taskID := int64(result["data"].(map[string]interface{})["id"].(float64)) deadline := time.Now().Add(5 * time.Second) for time.Now().Before(deadline) { task, _ := taskStore.GetByID(context.Background(), taskID) if task != nil && task.Status == model.TaskStatusQueued { if task.SlurmJobID != nil && *task.SlurmJobID == 99 { return } } time.Sleep(100 * time.Millisecond) } task, _ := taskStore.GetByID(context.Background(), taskID) if task == nil { t.Fatalf("task %d not found after deadline", taskID) } t.Fatalf("task did not reach queued with slurm_job_id=99; status=%s retry_count=%d slurm_job_id=%v", task.Status, task.RetryCount, task.SlurmJobID) } // [已禁用] 前端已全部迁移到 POST /tasks 接口,旧 API 兼容性测试不再需要。 /* func TestTask_OldAPICompatibility(t *testing.T) { srv, _, cleanup := setupTaskTestServer(t) defer cleanup() appID := createTestApp(t, srv.URL) body := `{"values":{"np":"8"}}` url := fmt.Sprintf("%s/api/v1/applications/%d/submit", srv.URL, appID) resp, err := http.Post(url, "application/json", bytes.NewReader([]byte(body))) if err != nil { t.Fatalf("post submit: %v", err) } defer resp.Body.Close() if resp.StatusCode != http.StatusCreated { b, _ := io.ReadAll(resp.Body) t.Fatalf("expected 201, got %d: %s", resp.StatusCode, b) } var result map[string]interface{} json.NewDecoder(resp.Body).Decode(&result) data, _ := result["data"].(map[string]interface{}) if data == nil { t.Fatal("expected data in response") } if data["job_id"] == nil { t.Errorf("expected job_id in old API response, got: %v", data) } jobID, ok := data["job_id"].(float64) if !ok || jobID == 0 { t.Errorf("expected non-zero job_id, got %v", data["job_id"]) } } */ func TestTask_ListWithFilters(t *testing.T) { srv, taskStore, cleanup := setupTaskTestServer(t) defer cleanup() appID := createTestApp(t, srv.URL) now := time.Now() taskStore.Create(context.Background(), &model.Task{ TaskName: "task-completed", AppID: appID, AppName: "test-app", Status: model.TaskStatusCompleted, SubmittedAt: now, }) taskStore.Create(context.Background(), &model.Task{ TaskName: "task-failed", AppID: appID, AppName: "test-app", Status: model.TaskStatusFailed, ErrorMessage: "boom", SubmittedAt: now, }) taskStore.Create(context.Background(), &model.Task{ TaskName: "task-queued", AppID: appID, AppName: "test-app", Status: model.TaskStatusQueued, SubmittedAt: now, }) total, items := getTasks(t, srv.URL, "status=completed") if total != 1 { t.Fatalf("expected 1 completed, got %d", total) } task := items[0].(map[string]interface{}) if task["status"] != "completed" { t.Errorf("expected status=completed, got %v", task["status"]) } total, items = getTasks(t, srv.URL, "status=failed") if total != 1 { t.Fatalf("expected 1 failed, got %d", total) } total, _ = getTasks(t, srv.URL, "") if total != 3 { t.Fatalf("expected 3 total, got %d", total) } } func TestTask_WorkDirCreated(t *testing.T) { srv, taskStore, cleanup := setupTaskTestServer(t) defer cleanup() appID := createTestApp(t, srv.URL) resp, result := postTask(t, srv.URL, fmt.Sprintf( `{"app_id":%d,"task_name":"workdir-test","values":{"np":"1"}}`, appID, )) if resp.StatusCode != http.StatusCreated { t.Fatalf("expected 201, got %d: %v", resp.StatusCode, result) } taskID := int64(result["data"].(map[string]interface{})["id"].(float64)) deadline := time.Now().Add(5 * time.Second) for time.Now().Before(deadline) { task, _ := taskStore.GetByID(context.Background(), taskID) if task != nil && task.WorkDir != "" { if _, err := os.Stat(task.WorkDir); os.IsNotExist(err) { t.Fatalf("work dir %s not created", task.WorkDir) } if !filepath.IsAbs(task.WorkDir) { t.Errorf("expected absolute work dir, got %s", task.WorkDir) } return } time.Sleep(50 * time.Millisecond) } task, _ := taskStore.GetByID(context.Background(), taskID) if task == nil { t.Fatalf("task %d not found after deadline", taskID) } t.Fatalf("work dir never set; status=%s workdir=%s", task.Status, task.WorkDir) }