feat(handler): add task defaults and file_ids support in task submission
Ultraworked with [Sisyphus](https://github.com/code-yeongyu/oh-my-openagent) Co-authored-by: Sisyphus <clio-agent@sisyphuslabs.ai>
This commit is contained in:
@@ -76,18 +76,45 @@ func (h *TaskHandler) ListTasks(c *gin.Context) {
|
||||
responses := make([]model.TaskResponse, 0, len(tasks))
|
||||
for i := range tasks {
|
||||
responses = append(responses, model.TaskResponse{
|
||||
ID: tasks[i].ID,
|
||||
TaskName: tasks[i].TaskName,
|
||||
AppID: tasks[i].AppID,
|
||||
AppName: tasks[i].AppName,
|
||||
Status: tasks[i].Status,
|
||||
CurrentStep: tasks[i].CurrentStep,
|
||||
RetryCount: tasks[i].RetryCount,
|
||||
SlurmJobID: tasks[i].SlurmJobID,
|
||||
WorkDir: tasks[i].WorkDir,
|
||||
ErrorMessage: tasks[i].ErrorMessage,
|
||||
CreatedAt: tasks[i].CreatedAt,
|
||||
UpdatedAt: tasks[i].UpdatedAt,
|
||||
ID: tasks[i].ID,
|
||||
TaskName: tasks[i].TaskName,
|
||||
AppID: tasks[i].AppID,
|
||||
AppName: tasks[i].AppName,
|
||||
Status: tasks[i].Status,
|
||||
CurrentStep: tasks[i].CurrentStep,
|
||||
RetryCount: tasks[i].RetryCount,
|
||||
SlurmJobID: tasks[i].SlurmJobID,
|
||||
WorkDir: tasks[i].WorkDir,
|
||||
ErrorMessage: tasks[i].ErrorMessage,
|
||||
CreatedAt: tasks[i].CreatedAt,
|
||||
UpdatedAt: tasks[i].UpdatedAt,
|
||||
Partition: tasks[i].Partition,
|
||||
Cpus: tasks[i].Cpus,
|
||||
MemoryPerNode: tasks[i].MemoryPerNode,
|
||||
MemoryPerCpu: tasks[i].MemoryPerCpu,
|
||||
TimeLimit: tasks[i].TimeLimit,
|
||||
QOS: tasks[i].QOS,
|
||||
JobName: tasks[i].JobName,
|
||||
Nodes: tasks[i].Nodes,
|
||||
Tasks: tasks[i].Tasks,
|
||||
CpusPerTask: tasks[i].CpusPerTask,
|
||||
Constraints: tasks[i].Constraints,
|
||||
Reservation: tasks[i].Reservation,
|
||||
Account: tasks[i].Account,
|
||||
Nice: tasks[i].Nice,
|
||||
MailType: tasks[i].MailType,
|
||||
MailUser: tasks[i].MailUser,
|
||||
StandardOutput: tasks[i].StandardOutput,
|
||||
StandardError: tasks[i].StandardError,
|
||||
StandardInput: tasks[i].StandardInput,
|
||||
RequiredNodes: tasks[i].RequiredNodes,
|
||||
ExcludedNodes: tasks[i].ExcludedNodes,
|
||||
BeginTime: tasks[i].BeginTime,
|
||||
Deadline: tasks[i].Deadline,
|
||||
Array: tasks[i].Array,
|
||||
Dependency: tasks[i].Dependency,
|
||||
Requeue: tasks[i].Requeue,
|
||||
KillOnNodeFail: tasks[i].KillOnNodeFail,
|
||||
})
|
||||
}
|
||||
|
||||
|
||||
@@ -248,6 +248,148 @@ func TestTaskHandler_ListTasks_StatusFilter(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestTaskHandler_CreateTask_WithSchedulingFields(t *testing.T) {
|
||||
slurmSrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusCreated)
|
||||
json.NewEncoder(w).Encode(map[string]interface{}{"job_id": 12345})
|
||||
}))
|
||||
defer slurmSrv.Close()
|
||||
|
||||
h, db := setupTaskHandler(t, slurmSrv)
|
||||
r := setupTaskRouter(h)
|
||||
|
||||
appID := createTestAppForTask(db)
|
||||
|
||||
taskSvc := h.svc.(*service.TaskService)
|
||||
ctx := context.Background()
|
||||
taskSvc.StartProcessor(ctx)
|
||||
defer taskSvc.StopProcessor()
|
||||
|
||||
cpus := int32(8)
|
||||
memNode := int64(4096)
|
||||
tl := int32(60)
|
||||
qos := "high"
|
||||
body, _ := json.Marshal(model.CreateTaskRequest{
|
||||
AppID: appID,
|
||||
TaskName: "sched-task",
|
||||
Partition: ptrStr("gpu"),
|
||||
Cpus: &cpus,
|
||||
MemoryPerNode: &memNode,
|
||||
TimeLimit: &tl,
|
||||
QOS: &qos,
|
||||
})
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req, _ := http.NewRequest(http.MethodPost, "/api/v1/tasks", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusCreated {
|
||||
t.Fatalf("expected 201, got %d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
var resp map[string]interface{}
|
||||
json.Unmarshal(w.Body.Bytes(), &resp)
|
||||
if !resp["success"].(bool) {
|
||||
t.Fatal("expected success=true")
|
||||
}
|
||||
data := resp["data"].(map[string]interface{})
|
||||
if _, ok := data["id"]; !ok {
|
||||
t.Fatal("expected id in response data")
|
||||
}
|
||||
taskID := int64(data["id"].(float64))
|
||||
|
||||
// Wait for async processing
|
||||
time.Sleep(300 * time.Millisecond)
|
||||
|
||||
// Verify persisted scheduling fields
|
||||
var task model.Task
|
||||
db.First(&task, taskID)
|
||||
if task.Partition != "gpu" {
|
||||
t.Errorf("expected partition=gpu, got %q", task.Partition)
|
||||
}
|
||||
if task.Cpus == nil || *task.Cpus != 8 {
|
||||
t.Errorf("expected cpus=8, got %v", task.Cpus)
|
||||
}
|
||||
if task.MemoryPerNode == nil || *task.MemoryPerNode != 4096 {
|
||||
t.Errorf("expected memory_per_node=4096, got %v", task.MemoryPerNode)
|
||||
}
|
||||
if task.TimeLimit == nil || *task.TimeLimit != 60 {
|
||||
t.Errorf("expected time_limit=60, got %v", task.TimeLimit)
|
||||
}
|
||||
if task.QOS == nil || *task.QOS != "high" {
|
||||
t.Errorf("expected qos=high, got %v", task.QOS)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTaskHandler_ListTasks_ReturnsSchedulingFields(t *testing.T) {
|
||||
h, db := setupTaskHandler(t, nil)
|
||||
r := setupTaskRouter(h)
|
||||
|
||||
_ = createTestAppForTask(db)
|
||||
|
||||
cpus := int32(16)
|
||||
memNode := int64(8192)
|
||||
tl := int32(120)
|
||||
partition := "gpu"
|
||||
qos := "normal"
|
||||
task := &model.Task{
|
||||
TaskName: "sched-list-task",
|
||||
AppID: 1,
|
||||
AppName: "test-app",
|
||||
Status: model.TaskStatusSubmitted,
|
||||
SubmittedAt: time.Now(),
|
||||
Partition: partition,
|
||||
Cpus: &cpus,
|
||||
MemoryPerNode: &memNode,
|
||||
TimeLimit: &tl,
|
||||
QOS: &qos,
|
||||
}
|
||||
db.Create(task)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req, _ := http.NewRequest(http.MethodGet, "/api/v1/tasks", nil)
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("expected 200, got %d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
var resp map[string]interface{}
|
||||
json.Unmarshal(w.Body.Bytes(), &resp)
|
||||
data := resp["data"].(map[string]interface{})
|
||||
items := data["items"].([]interface{})
|
||||
if len(items) == 0 {
|
||||
t.Fatal("expected at least 1 item")
|
||||
}
|
||||
item := items[0].(map[string]interface{})
|
||||
|
||||
if item["partition"] != "gpu" {
|
||||
t.Errorf("expected partition=gpu, got %v", item["partition"])
|
||||
}
|
||||
if item["cpus"] == nil {
|
||||
t.Error("expected cpus to be set")
|
||||
} else if item["cpus"].(float64) != 16 {
|
||||
t.Errorf("expected cpus=16, got %v", item["cpus"])
|
||||
}
|
||||
if item["memory_per_node"] == nil {
|
||||
t.Error("expected memory_per_node to be set")
|
||||
} else if item["memory_per_node"].(float64) != 8192 {
|
||||
t.Errorf("expected memory_per_node=8192, got %v", item["memory_per_node"])
|
||||
}
|
||||
if item["time_limit"] == nil {
|
||||
t.Error("expected time_limit to be set")
|
||||
} else if item["time_limit"].(float64) != 120 {
|
||||
t.Errorf("expected time_limit=120, got %v", item["time_limit"])
|
||||
}
|
||||
if item["qos"] == nil {
|
||||
t.Error("expected qos to be set")
|
||||
} else if item["qos"].(string) != "normal" {
|
||||
t.Errorf("expected qos=normal, got %v", item["qos"])
|
||||
}
|
||||
}
|
||||
|
||||
func ptrStr(s string) *string { return &s }
|
||||
|
||||
func TestTaskHandler_ListTasks_DefaultPagination(t *testing.T) {
|
||||
h, db := setupTaskHandler(t, nil)
|
||||
r := setupTaskRouter(h)
|
||||
|
||||
Reference in New Issue
Block a user