Files
hpc/cmd/server/integration_task_test.go

519 lines
14 KiB
Go

package main
import (
"encoding/json"
"fmt"
"io"
"net/http"
"strings"
"testing"
"time"
"gcy_hpc_server/internal/testutil/testenv"
)
// taskAPIResponse mirrors the JSON envelope shared by the task endpoints.
//
// Data is kept as raw JSON so each caller can decode it into the payload
// type it expects; Error holds the message string carried by the envelope.
type taskAPIResponse struct {
	Success bool            `json:"success"`
	Data    json.RawMessage `json:"data,omitempty"`
	Error   string          `json:"error,omitempty"`
}
// taskCreateData is the "data" payload decoded after a task has been
// created; only the new task's ID is read from it.
type taskCreateData struct {
	ID int64 `json:"id"`
}
// taskListData is the "data" payload decoded from the task list endpoint:
// the returned items plus the total count reported alongside them.
type taskListData struct {
	Items []taskListItem `json:"items"`
	Total int64          `json:"total"`
}
// taskListItem is a single entry of the task list payload.
//
// Most optional fields are pointers, so a test can tell a key that was
// absent from the response (nil) apart from an explicit zero value.
type taskListItem struct {
	// Identity and current state of the task.
	ID         int64  `json:"id"`
	TaskName   string `json:"task_name"`
	AppID      int64  `json:"app_id"`
	Status     string `json:"status"`
	SlurmJobID *int32 `json:"slurm_job_id"`

	// Optional job parameters echoed back by the API.
	Partition      string  `json:"partition,omitempty"`
	Cpus           *int32  `json:"cpus,omitempty"`
	MemoryPerNode  *int64  `json:"memory_per_node,omitempty"`
	MemoryPerCpu   *int64  `json:"memory_per_cpu,omitempty"`
	TimeLimit      *int32  `json:"time_limit,omitempty"`
	QOS            *string `json:"qos,omitempty"`
	JobName        *string `json:"job_name,omitempty"`
	Nodes          *string `json:"nodes,omitempty"`
	Tasks          *int32  `json:"tasks,omitempty"`
	CpusPerTask    *int32  `json:"cpus_per_task,omitempty"`
	Constraints    *string `json:"constraints,omitempty"`
	Reservation    *string `json:"reservation,omitempty"`
	Account        *string `json:"account,omitempty"`
	Nice           *int32  `json:"nice,omitempty"`
	MailType       *string `json:"mail_type,omitempty"`
	MailUser       *string `json:"mail_user,omitempty"`
	StandardOutput *string `json:"standard_output,omitempty"`
	StandardError  *string `json:"standard_error,omitempty"`
	StandardInput  *string `json:"standard_input,omitempty"`
	RequiredNodes  *string `json:"required_nodes,omitempty"`
	ExcludedNodes  *string `json:"excluded_nodes,omitempty"`
	BeginTime      *int64  `json:"begin_time,omitempty"`
	Deadline       *int64  `json:"deadline,omitempty"`
	Array          *string `json:"array,omitempty"`
	Dependency     *string `json:"dependency,omitempty"`
	Requeue        *bool   `json:"requeue,omitempty"`
	KillOnNodeFail *bool   `json:"kill_on_node_fail,omitempty"`
}
// taskSendReq issues one request through the test environment and hands the
// raw response back to the caller, who is responsible for closing its body.
//
// An empty body string means "no request body".
func taskSendReq(t *testing.T, env *testenv.TestEnv, method, path string, body string) *http.Response {
	t.Helper()

	if body == "" {
		// Pass a true nil reader rather than a reader over an empty string.
		return env.DoRequest(method, path, nil)
	}
	return env.DoRequest(method, path, strings.NewReader(body))
}
// taskParseResp drains and closes resp.Body, then decodes it as the API
// envelope. A read or decode failure aborts the calling test; the decode
// failure message includes the raw body to ease debugging.
func taskParseResp(t *testing.T, resp *http.Response) taskAPIResponse {
	t.Helper()

	raw, readErr := io.ReadAll(resp.Body)
	resp.Body.Close()
	if readErr != nil {
		t.Fatalf("read response body: %v", readErr)
	}

	var envelope taskAPIResponse
	decodeErr := json.Unmarshal(raw, &envelope)
	if decodeErr != nil {
		t.Fatalf("unmarshal response: %v (body: %s)", decodeErr, string(raw))
	}
	return envelope
}
// taskCreateViaAPI creates a task through POST /api/v1/tasks and returns the
// new task's ID.
//
// The request carries appID and taskName with empty "values" and "file_ids".
// Any deviation from a 201 response with success=true and a non-zero ID
// aborts the calling test.
func taskCreateViaAPI(t *testing.T, env *testenv.TestEnv, appID int64, taskName string) int64 {
	t.Helper()

	// Encode the name with encoding/json so quotes, backslashes and control
	// characters in taskName cannot produce a malformed request body.
	nameJSON, err := json.Marshal(taskName)
	if err != nil {
		t.Fatalf("marshal task name: %v", err)
	}
	body := fmt.Sprintf(`{"app_id":%d,"task_name":%s,"values":{},"file_ids":[]}`, appID, nameJSON)

	resp := taskSendReq(t, env, http.MethodPost, "/api/v1/tasks", body)
	defer resp.Body.Close()
	if resp.StatusCode != http.StatusCreated {
		// Best effort: the body is only used to enrich the failure message.
		b, _ := io.ReadAll(resp.Body)
		t.Fatalf("expected 201, got %d: %s", resp.StatusCode, string(b))
	}

	parsed := taskParseResp(t, resp)
	if !parsed.Success {
		t.Fatalf("expected success=true, got error: %s", parsed.Error)
	}

	var data taskCreateData
	if err := json.Unmarshal(parsed.Data, &data); err != nil {
		t.Fatalf("unmarshal create data: %v", err)
	}
	if data.ID == 0 {
		t.Fatal("expected non-zero task ID")
	}
	return data.ID
}
// ---------- Tests ----------
// TestIntegration_Task_Create creates an application and a task through the
// HTTP API, then checks that the task shows up in GET /api/v1/tasks with the
// expected name and application ID.
func TestIntegration_Task_Create(t *testing.T) {
	env := testenv.NewTestEnv(t)

	// Create application
	appID, err := env.CreateApp("task-create-app", "#!/bin/bash\necho hello", nil)
	if err != nil {
		t.Fatalf("create app: %v", err)
	}

	// Create task via API
	taskID := taskCreateViaAPI(t, env, appID, "test-task-create")

	// Verify the task ID is positive
	if taskID <= 0 {
		t.Fatalf("expected positive task ID, got %d", taskID)
	}

	// Wait briefly for async processing, then verify task exists via list
	time.Sleep(200 * time.Millisecond)

	resp := taskSendReq(t, env, http.MethodGet, "/api/v1/tasks", "")
	defer resp.Body.Close()

	// Report a failed list request as such instead of letting it surface
	// later as a confusing decode error or "task not found".
	if resp.StatusCode != http.StatusOK {
		// Best effort: the body is only used to enrich the failure message.
		b, _ := io.ReadAll(resp.Body)
		t.Fatalf("expected 200 listing tasks, got %d: %s", resp.StatusCode, string(b))
	}
	parsed := taskParseResp(t, resp)
	if !parsed.Success {
		t.Fatalf("expected success listing tasks, got error: %s", parsed.Error)
	}

	var listData taskListData
	if err := json.Unmarshal(parsed.Data, &listData); err != nil {
		t.Fatalf("unmarshal list data: %v", err)
	}

	found := false
	for _, item := range listData.Items {
		if item.ID != taskID {
			continue
		}
		found = true
		if item.TaskName != "test-task-create" {
			t.Errorf("expected task_name=test-task-create, got %s", item.TaskName)
		}
		if item.AppID != appID {
			t.Errorf("expected app_id=%d, got %d", appID, item.AppID)
		}
		break
	}
	if !found {
		t.Fatalf("task %d not found in list", taskID)
	}
}
// TestIntegration_Task_List creates three tasks for one application and
// checks that the list endpoint answers 200 with success=true, reports a
// total of at least three, and that every returned item carries an ID, a
// status and an application ID.
func TestIntegration_Task_List(t *testing.T) {
	env := testenv.NewTestEnv(t)

	// One application owns all three tasks.
	app, err := env.CreateApp("task-list-app", "#!/bin/bash\necho hello", nil)
	if err != nil {
		t.Fatalf("create app: %v", err)
	}

	for _, name := range []string{"list-task-1", "list-task-2", "list-task-3"} {
		taskCreateViaAPI(t, env, app, name)
	}

	// Give asynchronous processing a moment before listing.
	time.Sleep(200 * time.Millisecond)

	listResp := taskSendReq(t, env, http.MethodGet, "/api/v1/tasks", "")
	defer listResp.Body.Close()
	if code := listResp.StatusCode; code != http.StatusOK {
		raw, _ := io.ReadAll(listResp.Body)
		t.Fatalf("expected 200, got %d: %s", code, string(raw))
	}

	envelope := taskParseResp(t, listResp)
	if !envelope.Success {
		t.Fatalf("expected success, got error: %s", envelope.Error)
	}

	var payload taskListData
	if err := json.Unmarshal(envelope.Data, &payload); err != nil {
		t.Fatalf("unmarshal list data: %v", err)
	}
	if payload.Total < 3 {
		t.Fatalf("expected at least 3 tasks, got %d", payload.Total)
	}

	// Every returned item must have its required fields populated.
	for i := range payload.Items {
		entry := &payload.Items[i]
		if entry.ID == 0 {
			t.Error("expected non-zero ID")
		}
		if entry.Status == "" {
			t.Error("expected non-empty status")
		}
		if entry.AppID == 0 {
			t.Error("expected non-zero app_id")
		}
	}
}
// TestIntegration_Task_PollerLifecycle drives a task through
// queued → running → completed by changing the job state in the mock Slurm
// and then marking the task stale, waiting for each status in turn.
func TestIntegration_Task_PollerLifecycle(t *testing.T) {
	// Upper bound for each status wait.
	const waitTimeout = 5 * time.Second

	env := testenv.NewTestEnv(t)

	// Step 1: an application to own the task.
	app, err := env.CreateApp("poller-lifecycle-app", "#!/bin/bash\necho hello", nil)
	if err != nil {
		t.Fatalf("create app: %v", err)
	}

	// Step 2: submit the task through the API.
	task := taskCreateViaAPI(t, env, app, "poller-lifecycle-task")

	// Step 3: submission to MockSlurm happens asynchronously and the
	// intermediate states (submitted→preparing→downloading→ready→queued)
	// are non-deterministic, so only the final "queued" state is asserted.
	err = env.WaitForTaskStatus(task, "queued", waitTimeout)
	if err != nil {
		t.Fatalf("wait for queued: %v", err)
	}

	// Step 4: the Slurm job ID comes from the DB (not returned by API).
	jobID, err := env.GetTaskSlurmJobID(task)
	if err != nil {
		t.Fatalf("get slurm job id: %v", err)
	}

	// Step 5: queued → running.
	// ORDER IS CRITICAL: SetJobState must run BEFORE MakeTaskStale.
	env.MockSlurm.SetJobState(jobID, "RUNNING")
	err = env.MakeTaskStale(task)
	if err != nil {
		t.Fatalf("make task stale (running): %v", err)
	}
	err = env.WaitForTaskStatus(task, "running", waitTimeout)
	if err != nil {
		t.Fatalf("wait for running: %v", err)
	}

	// Step 6: running → completed.
	// ORDER IS CRITICAL: SetJobState must run BEFORE MakeTaskStale.
	env.MockSlurm.SetJobState(jobID, "COMPLETED")
	err = env.MakeTaskStale(task)
	if err != nil {
		t.Fatalf("make task stale (completed): %v", err)
	}
	err = env.WaitForTaskStatus(task, "completed", waitTimeout)
	if err != nil {
		t.Fatalf("wait for completed: %v", err)
	}
}
// TestIntegration_Task_Validation posts a task without app_id and expects a
// 400 response whose envelope has success=false and a non-empty error.
func TestIntegration_Task_Validation(t *testing.T) {
	env := testenv.NewTestEnv(t)

	// The payload deliberately omits the required app_id.
	const payload = `{"task_name":"no-app-id"}`
	resp := taskSendReq(t, env, http.MethodPost, "/api/v1/tasks", payload)
	defer resp.Body.Close()

	if code := resp.StatusCode; code != http.StatusBadRequest {
		t.Fatalf("expected 400 for missing app_id, got %d", code)
	}

	result := taskParseResp(t, resp)
	switch {
	case result.Success:
		t.Fatal("expected success=false for validation error")
	case result.Error == "":
		t.Error("expected non-empty error message")
	}
}
// TestIntegration_Task_WithSchedulingParams posts a task that carries
// partition, cpus, memory_per_node and time_limit, and expects a 201 with
// success=true and a non-zero task ID.
func TestIntegration_Task_WithSchedulingParams(t *testing.T) {
	env := testenv.NewTestEnv(t)

	app, err := env.CreateApp("sched-param-app", "#!/bin/bash\necho hello", nil)
	if err != nil {
		t.Fatalf("create app: %v", err)
	}

	// Request template; %d is replaced by the application ID.
	const bodyTmpl = `{
"app_id": %d,
"task_name": "sched-task",
"values": {},
"file_ids": [],
"partition": "gpu",
"cpus": 16,
"memory_per_node": 32768,
"time_limit": 120
}`
	createResp := taskSendReq(t, env, http.MethodPost, "/api/v1/tasks", fmt.Sprintf(bodyTmpl, app))
	defer createResp.Body.Close()

	if code := createResp.StatusCode; code != http.StatusCreated {
		raw, _ := io.ReadAll(createResp.Body)
		t.Fatalf("expected 201, got %d: %s", code, string(raw))
	}

	envelope := taskParseResp(t, createResp)
	if !envelope.Success {
		t.Fatalf("expected success=true, got error: %s", envelope.Error)
	}

	var created taskCreateData
	if err := json.Unmarshal(envelope.Data, &created); err != nil {
		t.Fatalf("unmarshal create data: %v", err)
	}
	if created.ID == 0 {
		t.Fatal("expected non-zero task ID")
	}
}
// TestIntegration_Task_ListWithScheduling creates a task with partition,
// cpus, memory_per_node and time_limit set, then checks that the list
// endpoint echoes all four values back for that task.
func TestIntegration_Task_ListWithScheduling(t *testing.T) {
	env := testenv.NewTestEnv(t)

	appID, err := env.CreateApp("sched-list-app", "#!/bin/bash\necho hello", nil)
	if err != nil {
		t.Fatalf("create app: %v", err)
	}

	body := fmt.Sprintf(`{
"app_id": %d,
"task_name": "sched-list-task",
"values": {},
"file_ids": [],
"partition": "gpu",
"cpus": 16,
"memory_per_node": 32768,
"time_limit": 120
}`, appID)
	createResp := taskSendReq(t, env, http.MethodPost, "/api/v1/tasks", body)
	defer createResp.Body.Close()
	if createResp.StatusCode != http.StatusCreated {
		// Best effort: the body is only used to enrich the failure message.
		b, _ := io.ReadAll(createResp.Body)
		t.Fatalf("expected 201 creating task, got %d: %s", createResp.StatusCode, string(b))
	}
	createParsed := taskParseResp(t, createResp)
	if !createParsed.Success {
		t.Fatalf("expected success creating task, got error: %s", createParsed.Error)
	}
	var createData taskCreateData
	if err := json.Unmarshal(createParsed.Data, &createData); err != nil {
		t.Fatalf("unmarshal create data: %v", err)
	}

	// Give asynchronous processing a moment before listing.
	time.Sleep(200 * time.Millisecond)

	listResp := taskSendReq(t, env, http.MethodGet, "/api/v1/tasks", "")
	defer listResp.Body.Close()
	if listResp.StatusCode != http.StatusOK {
		// Best effort: the body is only used to enrich the failure message.
		b, _ := io.ReadAll(listResp.Body)
		t.Fatalf("expected 200 listing tasks, got %d: %s", listResp.StatusCode, string(b))
	}
	listParsed := taskParseResp(t, listResp)
	if !listParsed.Success {
		t.Fatalf("expected success listing tasks, got error: %s", listParsed.Error)
	}
	var listData taskListData
	if err := json.Unmarshal(listParsed.Data, &listData); err != nil {
		t.Fatalf("unmarshal list data: %v", err)
	}

	// Take the address of the slice element so the checks below read the
	// decoded item itself rather than a copy.
	var found *taskListItem
	for i := range listData.Items {
		if listData.Items[i].ID == createData.ID {
			found = &listData.Items[i]
			break
		}
	}
	if found == nil {
		t.Fatalf("task %d not found in list", createData.ID)
	}

	if found.Partition != "gpu" {
		t.Errorf("expected partition=gpu, got %q", found.Partition)
	}

	// Optional fields are pointers: dereference them in failure messages so
	// the actual value is reported instead of a memory address.
	switch {
	case found.Cpus == nil:
		t.Error("expected cpus=16, got nil")
	case *found.Cpus != 16:
		t.Errorf("expected cpus=16, got %d", *found.Cpus)
	}
	switch {
	case found.MemoryPerNode == nil:
		t.Error("expected memory_per_node=32768, got nil")
	case *found.MemoryPerNode != 32768:
		t.Errorf("expected memory_per_node=32768, got %d", *found.MemoryPerNode)
	}
	switch {
	case found.TimeLimit == nil:
		t.Error("expected time_limit=120, got nil")
	case *found.TimeLimit != 120:
		t.Errorf("expected time_limit=120, got %d", *found.TimeLimit)
	}
}
// TestIntegration_Task_PartialScheduling creates a task that sets only
// partition and cpus, then checks that the list endpoint returns those two
// values and leaves memory_per_node and time_limit absent.
func TestIntegration_Task_PartialScheduling(t *testing.T) {
	env := testenv.NewTestEnv(t)

	appID, err := env.CreateApp("partial-sched-app", "#!/bin/bash\necho hello", nil)
	if err != nil {
		t.Fatalf("create app: %v", err)
	}

	body := fmt.Sprintf(`{
"app_id": %d,
"task_name": "partial-sched-task",
"values": {},
"file_ids": [],
"partition": "gpu",
"cpus": 8
}`, appID)
	createResp := taskSendReq(t, env, http.MethodPost, "/api/v1/tasks", body)
	defer createResp.Body.Close()
	if createResp.StatusCode != http.StatusCreated {
		// Best effort: the body is only used to enrich the failure message.
		b, _ := io.ReadAll(createResp.Body)
		t.Fatalf("expected 201, got %d: %s", createResp.StatusCode, string(b))
	}
	createParsed := taskParseResp(t, createResp)
	if !createParsed.Success {
		t.Fatalf("expected success creating task, got error: %s", createParsed.Error)
	}
	var createData taskCreateData
	if err := json.Unmarshal(createParsed.Data, &createData); err != nil {
		t.Fatalf("unmarshal create data: %v", err)
	}

	// Give asynchronous processing a moment before listing.
	time.Sleep(200 * time.Millisecond)

	listResp := taskSendReq(t, env, http.MethodGet, "/api/v1/tasks", "")
	defer listResp.Body.Close()
	// Report a failed list request as such instead of letting it surface
	// later as a confusing decode error or "task not found".
	if listResp.StatusCode != http.StatusOK {
		// Best effort: the body is only used to enrich the failure message.
		b, _ := io.ReadAll(listResp.Body)
		t.Fatalf("expected 200 listing tasks, got %d: %s", listResp.StatusCode, string(b))
	}
	listParsed := taskParseResp(t, listResp)
	if !listParsed.Success {
		t.Fatalf("expected success listing tasks, got error: %s", listParsed.Error)
	}
	var listData taskListData
	if err := json.Unmarshal(listParsed.Data, &listData); err != nil {
		t.Fatalf("unmarshal list data: %v", err)
	}

	var found *taskListItem
	for i := range listData.Items {
		if listData.Items[i].ID == createData.ID {
			found = &listData.Items[i]
			break
		}
	}
	if found == nil {
		t.Fatalf("task %d not found in list", createData.ID)
	}

	if found.Partition != "gpu" {
		t.Errorf("expected partition=gpu, got %q", found.Partition)
	}

	// Optional fields are pointers: dereference them in failure messages so
	// the actual value is reported instead of a memory address.
	switch {
	case found.Cpus == nil:
		t.Error("expected cpus=8, got nil")
	case *found.Cpus != 8:
		t.Errorf("expected cpus=8, got %d", *found.Cpus)
	}
	if found.MemoryPerNode != nil {
		t.Errorf("expected memory_per_node=nil, got %d", *found.MemoryPerNode)
	}
	if found.TimeLimit != nil {
		t.Errorf("expected time_limit=nil, got %d", *found.TimeLimit)
	}
}
// TestIntegration_Task_BackwardCompat creates a task without any scheduling
// parameters and checks that the list endpoint returns it with an empty
// partition and no cpus, memory_per_node, time_limit or qos.
func TestIntegration_Task_BackwardCompat(t *testing.T) {
	env := testenv.NewTestEnv(t)

	appID, err := env.CreateApp("compat-app", "#!/bin/bash\necho hello", nil)
	if err != nil {
		t.Fatalf("create app: %v", err)
	}
	taskID := taskCreateViaAPI(t, env, appID, "compat-task")

	// Give asynchronous processing a moment before listing.
	time.Sleep(200 * time.Millisecond)

	listResp := taskSendReq(t, env, http.MethodGet, "/api/v1/tasks", "")
	defer listResp.Body.Close()
	// Report a failed list request as such instead of letting it surface
	// later as a confusing decode error or "task not found".
	if listResp.StatusCode != http.StatusOK {
		// Best effort: the body is only used to enrich the failure message.
		b, _ := io.ReadAll(listResp.Body)
		t.Fatalf("expected 200 listing tasks, got %d: %s", listResp.StatusCode, string(b))
	}
	listParsed := taskParseResp(t, listResp)
	if !listParsed.Success {
		t.Fatalf("expected success listing tasks, got error: %s", listParsed.Error)
	}
	var listData taskListData
	if err := json.Unmarshal(listParsed.Data, &listData); err != nil {
		t.Fatalf("unmarshal list data: %v", err)
	}

	var found *taskListItem
	for i := range listData.Items {
		if listData.Items[i].ID == taskID {
			found = &listData.Items[i]
			break
		}
	}
	if found == nil {
		t.Fatalf("task %d not found in list", taskID)
	}

	if found.Partition != "" {
		t.Errorf("expected empty partition, got %q", found.Partition)
	}

	// Optional fields are pointers: dereference them in failure messages so
	// the actual value is reported instead of a memory address.
	if found.Cpus != nil {
		t.Errorf("expected nil cpus, got %d", *found.Cpus)
	}
	if found.MemoryPerNode != nil {
		t.Errorf("expected nil memory_per_node, got %d", *found.MemoryPerNode)
	}
	if found.TimeLimit != nil {
		t.Errorf("expected nil time_limit, got %d", *found.TimeLimit)
	}
	if found.QOS != nil {
		t.Errorf("expected nil qos, got %q", *found.QOS)
	}
}