feat(handler): add task defaults and file_ids support in task submission
Ultraworked with [Sisyphus](https://github.com/code-yeongyu/oh-my-openagent) Co-authored-by: Sisyphus <clio-agent@sisyphuslabs.ai>
This commit is contained in:
@@ -248,6 +248,148 @@ func TestTaskHandler_ListTasks_StatusFilter(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestTaskHandler_CreateTask_WithSchedulingFields(t *testing.T) {
|
||||
slurmSrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusCreated)
|
||||
json.NewEncoder(w).Encode(map[string]interface{}{"job_id": 12345})
|
||||
}))
|
||||
defer slurmSrv.Close()
|
||||
|
||||
h, db := setupTaskHandler(t, slurmSrv)
|
||||
r := setupTaskRouter(h)
|
||||
|
||||
appID := createTestAppForTask(db)
|
||||
|
||||
taskSvc := h.svc.(*service.TaskService)
|
||||
ctx := context.Background()
|
||||
taskSvc.StartProcessor(ctx)
|
||||
defer taskSvc.StopProcessor()
|
||||
|
||||
cpus := int32(8)
|
||||
memNode := int64(4096)
|
||||
tl := int32(60)
|
||||
qos := "high"
|
||||
body, _ := json.Marshal(model.CreateTaskRequest{
|
||||
AppID: appID,
|
||||
TaskName: "sched-task",
|
||||
Partition: ptrStr("gpu"),
|
||||
Cpus: &cpus,
|
||||
MemoryPerNode: &memNode,
|
||||
TimeLimit: &tl,
|
||||
QOS: &qos,
|
||||
})
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req, _ := http.NewRequest(http.MethodPost, "/api/v1/tasks", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusCreated {
|
||||
t.Fatalf("expected 201, got %d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
var resp map[string]interface{}
|
||||
json.Unmarshal(w.Body.Bytes(), &resp)
|
||||
if !resp["success"].(bool) {
|
||||
t.Fatal("expected success=true")
|
||||
}
|
||||
data := resp["data"].(map[string]interface{})
|
||||
if _, ok := data["id"]; !ok {
|
||||
t.Fatal("expected id in response data")
|
||||
}
|
||||
taskID := int64(data["id"].(float64))
|
||||
|
||||
// Wait for async processing
|
||||
time.Sleep(300 * time.Millisecond)
|
||||
|
||||
// Verify persisted scheduling fields
|
||||
var task model.Task
|
||||
db.First(&task, taskID)
|
||||
if task.Partition != "gpu" {
|
||||
t.Errorf("expected partition=gpu, got %q", task.Partition)
|
||||
}
|
||||
if task.Cpus == nil || *task.Cpus != 8 {
|
||||
t.Errorf("expected cpus=8, got %v", task.Cpus)
|
||||
}
|
||||
if task.MemoryPerNode == nil || *task.MemoryPerNode != 4096 {
|
||||
t.Errorf("expected memory_per_node=4096, got %v", task.MemoryPerNode)
|
||||
}
|
||||
if task.TimeLimit == nil || *task.TimeLimit != 60 {
|
||||
t.Errorf("expected time_limit=60, got %v", task.TimeLimit)
|
||||
}
|
||||
if task.QOS == nil || *task.QOS != "high" {
|
||||
t.Errorf("expected qos=high, got %v", task.QOS)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTaskHandler_ListTasks_ReturnsSchedulingFields(t *testing.T) {
|
||||
h, db := setupTaskHandler(t, nil)
|
||||
r := setupTaskRouter(h)
|
||||
|
||||
_ = createTestAppForTask(db)
|
||||
|
||||
cpus := int32(16)
|
||||
memNode := int64(8192)
|
||||
tl := int32(120)
|
||||
partition := "gpu"
|
||||
qos := "normal"
|
||||
task := &model.Task{
|
||||
TaskName: "sched-list-task",
|
||||
AppID: 1,
|
||||
AppName: "test-app",
|
||||
Status: model.TaskStatusSubmitted,
|
||||
SubmittedAt: time.Now(),
|
||||
Partition: partition,
|
||||
Cpus: &cpus,
|
||||
MemoryPerNode: &memNode,
|
||||
TimeLimit: &tl,
|
||||
QOS: &qos,
|
||||
}
|
||||
db.Create(task)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req, _ := http.NewRequest(http.MethodGet, "/api/v1/tasks", nil)
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("expected 200, got %d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
var resp map[string]interface{}
|
||||
json.Unmarshal(w.Body.Bytes(), &resp)
|
||||
data := resp["data"].(map[string]interface{})
|
||||
items := data["items"].([]interface{})
|
||||
if len(items) == 0 {
|
||||
t.Fatal("expected at least 1 item")
|
||||
}
|
||||
item := items[0].(map[string]interface{})
|
||||
|
||||
if item["partition"] != "gpu" {
|
||||
t.Errorf("expected partition=gpu, got %v", item["partition"])
|
||||
}
|
||||
if item["cpus"] == nil {
|
||||
t.Error("expected cpus to be set")
|
||||
} else if item["cpus"].(float64) != 16 {
|
||||
t.Errorf("expected cpus=16, got %v", item["cpus"])
|
||||
}
|
||||
if item["memory_per_node"] == nil {
|
||||
t.Error("expected memory_per_node to be set")
|
||||
} else if item["memory_per_node"].(float64) != 8192 {
|
||||
t.Errorf("expected memory_per_node=8192, got %v", item["memory_per_node"])
|
||||
}
|
||||
if item["time_limit"] == nil {
|
||||
t.Error("expected time_limit to be set")
|
||||
} else if item["time_limit"].(float64) != 120 {
|
||||
t.Errorf("expected time_limit=120, got %v", item["time_limit"])
|
||||
}
|
||||
if item["qos"] == nil {
|
||||
t.Error("expected qos to be set")
|
||||
} else if item["qos"].(string) != "normal" {
|
||||
t.Errorf("expected qos=normal, got %v", item["qos"])
|
||||
}
|
||||
}
|
||||
|
||||
func ptrStr(s string) *string { return &s }
|
||||
|
||||
func TestTaskHandler_ListTasks_DefaultPagination(t *testing.T) {
|
||||
h, db := setupTaskHandler(t, nil)
|
||||
r := setupTaskRouter(h)
|
||||
|
||||
Reference in New Issue
Block a user