diff --git a/internal/handler/task_handler.go b/internal/handler/task_handler.go index d5027be..496c858 100644 --- a/internal/handler/task_handler.go +++ b/internal/handler/task_handler.go @@ -76,18 +76,45 @@ func (h *TaskHandler) ListTasks(c *gin.Context) { responses := make([]model.TaskResponse, 0, len(tasks)) for i := range tasks { responses = append(responses, model.TaskResponse{ - ID: tasks[i].ID, - TaskName: tasks[i].TaskName, - AppID: tasks[i].AppID, - AppName: tasks[i].AppName, - Status: tasks[i].Status, - CurrentStep: tasks[i].CurrentStep, - RetryCount: tasks[i].RetryCount, - SlurmJobID: tasks[i].SlurmJobID, - WorkDir: tasks[i].WorkDir, - ErrorMessage: tasks[i].ErrorMessage, - CreatedAt: tasks[i].CreatedAt, - UpdatedAt: tasks[i].UpdatedAt, + ID: tasks[i].ID, + TaskName: tasks[i].TaskName, + AppID: tasks[i].AppID, + AppName: tasks[i].AppName, + Status: tasks[i].Status, + CurrentStep: tasks[i].CurrentStep, + RetryCount: tasks[i].RetryCount, + SlurmJobID: tasks[i].SlurmJobID, + WorkDir: tasks[i].WorkDir, + ErrorMessage: tasks[i].ErrorMessage, + CreatedAt: tasks[i].CreatedAt, + UpdatedAt: tasks[i].UpdatedAt, + Partition: tasks[i].Partition, + Cpus: tasks[i].Cpus, + MemoryPerNode: tasks[i].MemoryPerNode, + MemoryPerCpu: tasks[i].MemoryPerCpu, + TimeLimit: tasks[i].TimeLimit, + QOS: tasks[i].QOS, + JobName: tasks[i].JobName, + Nodes: tasks[i].Nodes, + Tasks: tasks[i].Tasks, + CpusPerTask: tasks[i].CpusPerTask, + Constraints: tasks[i].Constraints, + Reservation: tasks[i].Reservation, + Account: tasks[i].Account, + Nice: tasks[i].Nice, + MailType: tasks[i].MailType, + MailUser: tasks[i].MailUser, + StandardOutput: tasks[i].StandardOutput, + StandardError: tasks[i].StandardError, + StandardInput: tasks[i].StandardInput, + RequiredNodes: tasks[i].RequiredNodes, + ExcludedNodes: tasks[i].ExcludedNodes, + BeginTime: tasks[i].BeginTime, + Deadline: tasks[i].Deadline, + Array: tasks[i].Array, + Dependency: tasks[i].Dependency, + Requeue: tasks[i].Requeue, + KillOnNodeFail: tasks[i].KillOnNodeFail, }) } diff --git a/internal/handler/task_handler_test.go b/internal/handler/task_handler_test.go index 166f81a..73aa7e3 100644 --- a/internal/handler/task_handler_test.go +++ b/internal/handler/task_handler_test.go @@ -248,6 +248,148 @@ func TestTaskHandler_ListTasks_StatusFilter(t *testing.T) { } } +func TestTaskHandler_CreateTask_WithSchedulingFields(t *testing.T) { + slurmSrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusCreated) + json.NewEncoder(w).Encode(map[string]interface{}{"job_id": 12345}) + })) + defer slurmSrv.Close() + + h, db := setupTaskHandler(t, slurmSrv) + r := setupTaskRouter(h) + + appID := createTestAppForTask(db) + + taskSvc := h.svc.(*service.TaskService) + ctx := context.Background() + taskSvc.StartProcessor(ctx) + defer taskSvc.StopProcessor() + + cpus := int32(8) + memNode := int64(4096) + tl := int32(60) + qos := "high" + body, _ := json.Marshal(model.CreateTaskRequest{ + AppID: appID, + TaskName: "sched-task", + Partition: ptrStr("gpu"), + Cpus: &cpus, + MemoryPerNode: &memNode, + TimeLimit: &tl, + QOS: &qos, + }) + + w := httptest.NewRecorder() + req, _ := http.NewRequest(http.MethodPost, "/api/v1/tasks", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + r.ServeHTTP(w, req) + + if w.Code != http.StatusCreated { + t.Fatalf("expected 201, got %d: %s", w.Code, w.Body.String()) + } + var resp map[string]interface{} + json.Unmarshal(w.Body.Bytes(), &resp) + if !resp["success"].(bool) { + t.Fatal("expected success=true") + } + data := resp["data"].(map[string]interface{}) + if _, ok := data["id"]; !ok { + t.Fatal("expected id in response data") + } + taskID := int64(data["id"].(float64)) + + // Wait for async processing + time.Sleep(300 * time.Millisecond) + + // Verify persisted scheduling fields + var task model.Task + db.First(&task, taskID) + if task.Partition != "gpu" { + t.Errorf("expected partition=gpu, got %q", task.Partition) + } + if task.Cpus == nil || *task.Cpus != 8 { + t.Errorf("expected cpus=8, got %v", task.Cpus) + } + if task.MemoryPerNode == nil || *task.MemoryPerNode != 4096 { + t.Errorf("expected memory_per_node=4096, got %v", task.MemoryPerNode) + } + if task.TimeLimit == nil || *task.TimeLimit != 60 { + t.Errorf("expected time_limit=60, got %v", task.TimeLimit) + } + if task.QOS == nil || *task.QOS != "high" { + t.Errorf("expected qos=high, got %v", task.QOS) + } +} + +func TestTaskHandler_ListTasks_ReturnsSchedulingFields(t *testing.T) { + h, db := setupTaskHandler(t, nil) + r := setupTaskRouter(h) + + _ = createTestAppForTask(db) + + cpus := int32(16) + memNode := int64(8192) + tl := int32(120) + partition := "gpu" + qos := "normal" + task := &model.Task{ + TaskName: "sched-list-task", + AppID: 1, + AppName: "test-app", + Status: model.TaskStatusSubmitted, + SubmittedAt: time.Now(), + Partition: partition, + Cpus: &cpus, + MemoryPerNode: &memNode, + TimeLimit: &tl, + QOS: &qos, + } + db.Create(task) + + w := httptest.NewRecorder() + req, _ := http.NewRequest(http.MethodGet, "/api/v1/tasks", nil) + r.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("expected 200, got %d: %s", w.Code, w.Body.String()) + } + var resp map[string]interface{} + json.Unmarshal(w.Body.Bytes(), &resp) + data := resp["data"].(map[string]interface{}) + items := data["items"].([]interface{}) + if len(items) == 0 { + t.Fatal("expected at least 1 item") + } + item := items[0].(map[string]interface{}) + + if item["partition"] != "gpu" { + t.Errorf("expected partition=gpu, got %v", item["partition"]) + } + if item["cpus"] == nil { + t.Error("expected cpus to be set") + } else if item["cpus"].(float64) != 16 { + t.Errorf("expected cpus=16, got %v", item["cpus"]) + } + if item["memory_per_node"] == nil { + t.Error("expected memory_per_node to be set") + } else if item["memory_per_node"].(float64) != 8192 { + t.Errorf("expected memory_per_node=8192, got %v", item["memory_per_node"]) + } + if item["time_limit"] == nil { + t.Error("expected time_limit to be set") + } else if item["time_limit"].(float64) != 120 { + t.Errorf("expected time_limit=120, got %v", item["time_limit"]) + } + if item["qos"] == nil { + t.Error("expected qos to be set") + } else if item["qos"].(string) != "normal" { + t.Errorf("expected qos=normal, got %v", item["qos"]) + } +} + +func ptrStr(s string) *string { return &s } + func TestTaskHandler_ListTasks_DefaultPagination(t *testing.T) { h, db := setupTaskHandler(t, nil) r := setupTaskRouter(h)