Files
hpc/cmd/server/task_test.go

432 lines
12 KiB
Go

package main
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"net/http/httptest"
"os"
"path/filepath"
"sync/atomic"
"testing"
"time"
"gcy_hpc_server/internal/handler"
"gcy_hpc_server/internal/model"
"gcy_hpc_server/internal/server"
"gcy_hpc_server/internal/service"
"gcy_hpc_server/internal/slurm"
"gcy_hpc_server/internal/store"
"go.uber.org/zap"
"gorm.io/driver/sqlite"
"gorm.io/gorm"
"gorm.io/gorm/logger"
)
// newTaskTestDB opens a fresh in-memory SQLite database through GORM with
// logging silenced, migrates the Application, File, FileBlob and Task models,
// and arranges for the underlying connection pool to be closed when the test
// finishes. Any setup error aborts the test immediately.
func newTaskTestDB(t *testing.T) *gorm.DB {
	t.Helper()
	db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{Logger: logger.Default.LogMode(logger.Silent)})
	if err != nil {
		t.Fatalf("open sqlite: %v", err)
	}
	sqlDB, err := db.DB()
	if err != nil {
		t.Fatalf("get sql.DB: %v", err)
	}
	// Register the close before any further step that can fail, so the pool
	// is released even when migration aborts the test.
	t.Cleanup(func() {
		sqlDB.Close()
	})
	// Each new connection to ":memory:" sees its own empty database, so pin
	// the pool to a single connection to keep every query on the migrated
	// schema.
	sqlDB.SetMaxOpenConns(1)
	if err := db.AutoMigrate(
		&model.Application{},
		&model.File{},
		&model.FileBlob{},
		&model.Task{},
	); err != nil {
		t.Fatalf("auto-migrate: %v", err)
	}
	return db
}
// mockSlurmHandler customises the mock Slurm REST server built by
// newMockSlurmServer.
type mockSlurmHandler struct {
	// submitFn, when non-nil, answers requests to the job-submit endpoint in
	// place of the built-in canned success response.
	submitFn func(w http.ResponseWriter, r *http.Request)
}

// newMockSlurmServer starts an httptest server imitating the Slurm REST
// v0.0.40 endpoints the tests touch. The server is closed via t.Cleanup.
//
//   - /slurm/v0.0.40/job/submit: delegates to h.submitFn if h and submitFn are
//     both non-nil; otherwise returns a JSON body reporting job_id 42.
//   - /slurm/v0.0.40/jobs: returns an empty "jobs" list.
//   - anything else: returns an empty JSON object.
func newMockSlurmServer(t *testing.T, h *mockSlurmHandler) *httptest.Server {
	t.Helper()

	writeJSON := func(w http.ResponseWriter, payload interface{}) {
		w.Header().Set("Content-Type", "application/json")
		json.NewEncoder(w).Encode(payload)
	}

	routes := http.NewServeMux()
	routes.HandleFunc("/slurm/v0.0.40/job/submit", func(w http.ResponseWriter, r *http.Request) {
		if h == nil || h.submitFn == nil {
			w.Header().Set("Content-Type", "application/json")
			fmt.Fprint(w, `{"job_id": 42, "step_id": "0", "result": {"job_id": 42, "step_id": "0"}}`)
			return
		}
		h.submitFn(w, r)
	})
	routes.HandleFunc("/slurm/v0.0.40/jobs", func(w http.ResponseWriter, r *http.Request) {
		writeJSON(w, map[string]interface{}{"jobs": []interface{}{}})
	})
	routes.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
		writeJSON(w, map[string]interface{}{})
	})

	server := httptest.NewServer(routes)
	t.Cleanup(server.Close)
	return server
}
// setupTaskTestServer is setupTaskTestServerWithSlurm with the mock Slurm
// server left on its built-in responses (no custom submit handler).
func setupTaskTestServer(t *testing.T) (*httptest.Server, *store.TaskStore, func()) {
	t.Helper()
	var defaultSlurm *mockSlurmHandler // nil selects the mock's canned responses
	return setupTaskTestServerWithSlurm(t, defaultSlurm)
}
// setupTaskTestServerWithSlurm wires the full HTTP stack used by the task
// tests:
//   - an in-memory SQLite database;
//   - a mock Slurm REST server customised by slurmHandler (nil for its
//     defaults);
//   - the job, application and task services, with the task processor
//     started;
//   - the API router, served by an httptest server.
//
// It returns the API server, the task store (so tests can seed or inspect
// tasks directly), and a cleanup function the caller must invoke, typically
// via defer. Cleanup shuts the API server down and stops the task processor.
// The database, the mock Slurm server and the work directory are released
// afterwards through t.Cleanup.
func setupTaskTestServerWithSlurm(t *testing.T, slurmHandler *mockSlurmHandler) (*httptest.Server, *store.TaskStore, func()) {
	t.Helper()
	db := newTaskTestDB(t)
	slurmSrv := newMockSlurmServer(t, slurmHandler)
	client, err := slurm.NewClientWithOpts(slurmSrv.URL, slurm.WithHTTPClient(slurmSrv.Client()))
	if err != nil {
		t.Fatalf("slurm client: %v", err)
	}
	log := zap.NewNop()
	jobSvc := service.NewJobService(client, log)
	appStore := store.NewApplicationStore(db)
	taskStore := store.NewTaskStore(db)
	// t.TempDir is removed by the testing package once the test (including its
	// deferred cleanup) is done, so the work directory cannot leak even if the
	// returned cleanup is never called.
	workDir := t.TempDir()
	taskSvc := service.NewTaskService(
		taskStore, appStore,
		nil, nil, nil,
		jobSvc,
		workDir,
		log,
	)
	ctx, cancel := context.WithCancel(context.Background())
	taskSvc.StartProcessor(ctx)
	appSvc := service.NewApplicationService(appStore, jobSvc, workDir, log, taskSvc)
	appH := handler.NewApplicationHandler(appSvc, log)
	taskH := handler.NewTaskHandler(taskSvc, log)
	router := server.NewRouter(
		handler.NewJobHandler(jobSvc, log),
		handler.NewClusterHandler(service.NewClusterService(client, log), log),
		appH,
		nil, nil, nil,
		taskH,
		log,
	)
	httpSrv := httptest.NewServer(router)
	cleanup := func() {
		// Close the API server first. httptest.Server.Close waits for in-flight
		// requests, so no request can reach the task service while its
		// processor is being stopped.
		httpSrv.Close()
		taskSvc.StopProcessor()
		cancel()
	}
	return httpSrv, taskStore, cleanup
}
// createTestApp registers an application named "test-app" via
// POST /api/v1/applications. Its bash script template echoes the single
// string parameter "np" (default "1"). It returns the ID reported in the
// response's data.id, and fails the test on any transport error, any status
// other than 201 Created, or an undecodable body.
func createTestApp(t *testing.T, srvURL string) int64 {
	t.Helper()
	const appJSON = `{"name":"test-app","script_template":"#!/bin/bash\necho {{.np}}","parameters":[{"name":"np","type":"string","default":"1"}]}`
	endpoint := srvURL + "/api/v1/applications"

	resp, err := http.Post(endpoint, "application/json", bytes.NewReader([]byte(appJSON)))
	if err != nil {
		t.Fatalf("create app: %v", err)
	}
	defer resp.Body.Close()

	if status := resp.StatusCode; status != http.StatusCreated {
		raw, _ := io.ReadAll(resp.Body)
		t.Fatalf("create app: status %d, body %s", status, raw)
	}

	var created struct {
		Data struct {
			ID int64 `json:"id"`
		} `json:"data"`
	}
	if decodeErr := json.NewDecoder(resp.Body).Decode(&created); decodeErr != nil {
		t.Fatalf("decode app response: %v", decodeErr)
	}
	return created.Data.ID
}
// postTask POSTs body as JSON to /api/v1/tasks. It returns the response and
// the response body decoded as a JSON object. The body has already been read
// and closed on return, so only metadata such as StatusCode is usable on the
// response. A body that is not a JSON object yields a nil map. Transport
// errors fail the test.
func postTask(t *testing.T, srvURL string, body string) (*http.Response, map[string]interface{}) {
	t.Helper()
	payload := bytes.NewReader([]byte(body))
	resp, err := http.Post(srvURL+"/api/v1/tasks", "application/json", payload)
	if err != nil {
		t.Fatalf("post task: %v", err)
	}
	raw, _ := io.ReadAll(resp.Body)
	resp.Body.Close()

	var decoded map[string]interface{}
	_ = json.Unmarshal(raw, &decoded) // best effort: non-JSON bodies leave decoded nil
	return resp, decoded
}
// getTasks lists tasks via GET /api/v1/tasks, appending query (a raw query
// string such as "status=completed") when it is non-empty. It returns
// data.total (truncated to int) and data.items from the JSON response. Missing
// or malformed fields come back as zero values instead of failing. Only
// transport errors fail the test.
func getTasks(t *testing.T, srvURL string, query string) (int, []interface{}) {
	t.Helper()
	endpoint := srvURL + "/api/v1/tasks"
	if len(query) > 0 {
		endpoint = endpoint + "?" + query
	}

	resp, err := http.Get(endpoint)
	if err != nil {
		t.Fatalf("get tasks: %v", err)
	}
	defer resp.Body.Close()

	raw, _ := io.ReadAll(resp.Body)
	var envelope map[string]interface{}
	json.Unmarshal(raw, &envelope)

	page, _ := envelope["data"].(map[string]interface{})
	count, _ := page["total"].(float64)
	list, _ := page["items"].([]interface{})
	return int(count), list
}
// TestTask_FullLifecycle creates an application, submits a task for it through
// POST /api/v1/tasks, and then checks that the list endpoint reports the task.
// The listed task must carry a non-empty status and the submitted task name
// and app ID.
func TestTask_FullLifecycle(t *testing.T) {
	srv, _, cleanup := setupTaskTestServer(t)
	defer cleanup()
	appID := createTestApp(t, srv.URL)
	resp, result := postTask(t, srv.URL, fmt.Sprintf(
		`{"app_id":%d,"task_name":"lifecycle-test","values":{"np":"4"}}`, appID,
	))
	if resp.StatusCode != http.StatusCreated {
		t.Fatalf("expected 201, got %d: %v", resp.StatusCode, result)
	}
	data, _ := result["data"].(map[string]interface{})
	if data["id"] == nil {
		t.Fatal("expected id in response")
	}
	// Give the background task processor a moment to act on the new task; the
	// checks below accept any non-empty status.
	time.Sleep(300 * time.Millisecond)
	total, items := getTasks(t, srv.URL, "")
	// Guard the slice as well as the reported total so a malformed listing
	// fails the test instead of panicking on items[0].
	if total < 1 || len(items) == 0 {
		t.Fatalf("expected at least 1 task, got total=%d items=%d", total, len(items))
	}
	task, ok := items[0].(map[string]interface{})
	if !ok {
		t.Fatalf("expected task object in list, got %T", items[0])
	}
	if task["status"] == nil || task["status"] == "" {
		t.Error("expected non-empty status")
	}
	if task["task_name"] != "lifecycle-test" {
		t.Errorf("expected task_name=lifecycle-test, got %v", task["task_name"])
	}
	if task["app_id"] != float64(appID) {
		t.Errorf("expected app_id=%d, got %v", appID, task["app_id"])
	}
}
// TestTask_CreateWithMissingApp expects 404 Not Found when a task references
// an application ID that was never created.
func TestTask_CreateWithMissingApp(t *testing.T) {
	srv, _, cleanup := setupTaskTestServer(t)
	defer cleanup()

	const payload = `{"app_id":9999,"task_name":"no-app"}`
	resp, result := postTask(t, srv.URL, payload)
	if got, want := resp.StatusCode, http.StatusNotFound; got != want {
		t.Fatalf("expected 404, got %d: %v", got, result)
	}
}
// TestTask_CreateWithInvalidBody expects 400 Bad Request for an empty JSON
// object, which lacks the required task fields.
func TestTask_CreateWithInvalidBody(t *testing.T) {
	srv, _, cleanup := setupTaskTestServer(t)
	defer cleanup()

	const payload = `{}`
	resp, result := postTask(t, srv.URL, payload)
	if got, want := resp.StatusCode, http.StatusBadRequest; got != want {
		t.Fatalf("expected 400, got %d: %v", got, result)
	}
}
// TestTask_FileLimitExceeded submits a task referencing 101 file IDs (1..101)
// and expects 400 Bad Request.
// NOTE(review): presumably the service caps a task at 100 files — confirm.
func TestTask_FileLimitExceeded(t *testing.T) {
	srv, _, cleanup := setupTaskTestServer(t)
	defer cleanup()
	appID := createTestApp(t, srv.URL)

	const fileCount = 101
	fileIDs := make([]int64, 0, fileCount)
	for id := int64(1); id <= fileCount; id++ {
		fileIDs = append(fileIDs, id)
	}
	encodedIDs, _ := json.Marshal(fileIDs) // an []int64 always marshals

	payload := fmt.Sprintf(`{"app_id":%d,"file_ids":%s}`, appID, string(encodedIDs))
	resp, result := postTask(t, srv.URL, payload)
	if got := resp.StatusCode; got != http.StatusBadRequest {
		t.Fatalf("expected 400 for file limit, got %d: %v", got, result)
	}
}
// TestTask_RetryScenario makes the mock Slurm submit endpoint fail the first
// two submissions with HTTP 500 and succeed afterwards with job 99. It then
// polls the task store for up to 5 seconds until the task is queued with
// SlurmJobID 99. Reaching that state implies the submission was retried past
// the failures.
func TestTask_RetryScenario(t *testing.T) {
	var submitCalls int32
	flakySlurm := &mockSlurmHandler{
		submitFn: func(w http.ResponseWriter, r *http.Request) {
			if n := atomic.AddInt32(&submitCalls, 1); n > 2 {
				w.Header().Set("Content-Type", "application/json")
				fmt.Fprint(w, `{"job_id": 99, "step_id": "0", "result": {"job_id": 99, "step_id": "0"}}`)
				return
			}
			w.WriteHeader(http.StatusInternalServerError)
			fmt.Fprint(w, `{"errors":[{"error":"internal error"}]}`)
		},
	}
	srv, taskStore, cleanup := setupTaskTestServerWithSlurm(t, flakySlurm)
	defer cleanup()

	appID := createTestApp(t, srv.URL)
	payload := fmt.Sprintf(`{"app_id":%d,"task_name":"retry-test","values":{"np":"2"}}`, appID)
	resp, result := postTask(t, srv.URL, payload)
	if resp.StatusCode != http.StatusCreated {
		t.Fatalf("expected 201, got %d: %v", resp.StatusCode, result)
	}
	created := result["data"].(map[string]interface{})
	taskID := int64(created["id"].(float64))

	ctx := context.Background()
	for deadline := time.Now().Add(5 * time.Second); time.Now().Before(deadline); time.Sleep(100 * time.Millisecond) {
		task, _ := taskStore.GetByID(ctx, taskID)
		if task != nil && task.Status == model.TaskStatusQueued &&
			task.SlurmJobID != nil && *task.SlurmJobID == 99 {
			return
		}
	}

	final, _ := taskStore.GetByID(ctx, taskID)
	if final == nil {
		t.Fatalf("task %d not found after deadline", taskID)
	}
	t.Fatalf("task did not reach queued with slurm_job_id=99; status=%s retry_count=%d slurm_job_id=%v",
		final.Status, final.RetryCount, final.SlurmJobID)
}
// TestTask_OldAPICompatibility exercises the legacy submit endpoint
// POST /api/v1/applications/{id}/submit. It expects 201 Created and a JSON
// body whose data object contains a non-zero numeric job_id.
func TestTask_OldAPICompatibility(t *testing.T) {
	srv, _, cleanup := setupTaskTestServer(t)
	defer cleanup()
	appID := createTestApp(t, srv.URL)
	body := `{"values":{"np":"8"}}`
	url := fmt.Sprintf("%s/api/v1/applications/%d/submit", srv.URL, appID)
	resp, err := http.Post(url, "application/json", bytes.NewReader([]byte(body)))
	if err != nil {
		t.Fatalf("post submit: %v", err)
	}
	defer resp.Body.Close()
	if resp.StatusCode != http.StatusCreated {
		b, _ := io.ReadAll(resp.Body)
		t.Fatalf("expected 201, got %d: %s", resp.StatusCode, b)
	}
	var result map[string]interface{}
	// Report a decode failure directly rather than as a misleading
	// "expected data" failure below.
	if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
		t.Fatalf("decode submit response: %v", err)
	}
	data, _ := result["data"].(map[string]interface{})
	if data == nil {
		t.Fatal("expected data in response")
	}
	rawJobID := data["job_id"]
	if rawJobID == nil {
		t.Fatalf("expected job_id in old API response, got: %v", data)
	}
	if jobID, ok := rawJobID.(float64); !ok || jobID == 0 {
		t.Errorf("expected non-zero job_id, got %v", rawJobID)
	}
}
// TestTask_ListWithFilters seeds three tasks (completed, failed, queued)
// directly through the task store. It checks that GET /api/v1/tasks filters
// by status and that the unfiltered listing reports all three.
func TestTask_ListWithFilters(t *testing.T) {
	srv, taskStore, cleanup := setupTaskTestServer(t)
	defer cleanup()
	appID := createTestApp(t, srv.URL)
	ctx := context.Background()
	now := time.Now()
	// NOTE(review): Create's result is discarded; if seeding can fail, checking
	// it here would give a clearer failure than the count assertions below.
	taskStore.Create(ctx, &model.Task{
		TaskName: "task-completed", AppID: appID, AppName: "test-app",
		Status: model.TaskStatusCompleted, SubmittedAt: now,
	})
	taskStore.Create(ctx, &model.Task{
		TaskName: "task-failed", AppID: appID, AppName: "test-app",
		Status: model.TaskStatusFailed, ErrorMessage: "boom", SubmittedAt: now,
	})
	taskStore.Create(ctx, &model.Task{
		TaskName: "task-queued", AppID: appID, AppName: "test-app",
		Status: model.TaskStatusQueued, SubmittedAt: now,
	})

	total, items := getTasks(t, srv.URL, "status=completed")
	if total != 1 || len(items) != 1 {
		t.Fatalf("expected 1 completed, got total=%d items=%d", total, len(items))
	}
	if completed, _ := items[0].(map[string]interface{}); completed["status"] != "completed" {
		t.Errorf("expected status=completed, got %v", completed["status"])
	}

	total, items = getTasks(t, srv.URL, "status=failed")
	if total != 1 || len(items) != 1 {
		t.Fatalf("expected 1 failed, got total=%d items=%d", total, len(items))
	}
	// Previously the failed listing was fetched but never inspected; make sure
	// the filter returned the task seeded as failed.
	if failed, _ := items[0].(map[string]interface{}); failed["task_name"] != "task-failed" {
		t.Errorf("expected task_name=task-failed, got %v", failed["task_name"])
	}

	total, _ = getTasks(t, srv.URL, "")
	if total != 3 {
		t.Fatalf("expected 3 total, got %d", total)
	}
}
// TestTask_WorkDirCreated submits a task and polls the task store for up to 5
// seconds until the task records a WorkDir. The recorded path must then be an
// existing directory given as an absolute path.
func TestTask_WorkDirCreated(t *testing.T) {
	srv, taskStore, cleanup := setupTaskTestServer(t)
	defer cleanup()
	appID := createTestApp(t, srv.URL)
	resp, result := postTask(t, srv.URL, fmt.Sprintf(
		`{"app_id":%d,"task_name":"workdir-test","values":{"np":"1"}}`, appID,
	))
	if resp.StatusCode != http.StatusCreated {
		t.Fatalf("expected 201, got %d: %v", resp.StatusCode, result)
	}
	taskID := int64(result["data"].(map[string]interface{})["id"].(float64))
	deadline := time.Now().Add(5 * time.Second)
	for time.Now().Before(deadline) {
		task, _ := taskStore.GetByID(context.Background(), taskID)
		if task != nil && task.WorkDir != "" {
			// Any Stat error (not only "does not exist") means the directory
			// is unusable, and the path must be a directory rather than a file.
			info, err := os.Stat(task.WorkDir)
			if err != nil {
				t.Fatalf("work dir %s not created: %v", task.WorkDir, err)
			}
			if !info.IsDir() {
				t.Fatalf("work dir %s is not a directory", task.WorkDir)
			}
			if !filepath.IsAbs(task.WorkDir) {
				t.Errorf("expected absolute work dir, got %s", task.WorkDir)
			}
			return
		}
		time.Sleep(50 * time.Millisecond)
	}
	task, _ := taskStore.GetByID(context.Background(), taskID)
	if task == nil {
		t.Fatalf("task %d not found after deadline", taskID)
	}
	t.Fatalf("work dir never set; status=%s workdir=%s", task.Status, task.WorkDir)
}