From f894e870ed5b663a7515c48c8b2e877f8366f6b2 Mon Sep 17 00:00:00 2001 From: dailz Date: Mon, 20 Apr 2026 10:38:30 +0800 Subject: [PATCH] test(model): add tests for task defaults, job queries, and DTOs Ultraworked with [Sisyphus](https://github.com/code-yeongyu/oh-my-openagent) Co-authored-by: Sisyphus --- internal/model/job_test.go | 200 ++++++++++++++++++++ internal/model/task_test.go | 358 ++++++++++++++++++++++++++++++++++-- 2 files changed, 538 insertions(+), 20 deletions(-) create mode 100644 internal/model/job_test.go diff --git a/internal/model/job_test.go b/internal/model/job_test.go new file mode 100644 index 0000000..d02d04e --- /dev/null +++ b/internal/model/job_test.go @@ -0,0 +1,200 @@ +package model + +import ( + "encoding/json" + "reflect" + "testing" +) + +func TestSubmitJobRequest_SchedulingFields(t *testing.T) { + payload := `{ + "script": "#!/bin/bash\necho hello", + "work_dir": "/tmp/work", + "partition": "gpu", + "qos": "high", + "cpus": 16, + "memory": "4GB", + "time_limit": "60", + "job_name": "test-job", + "environment": {"PATH": "/usr/bin"}, + "memory_per_node": 32768, + "memory_per_cpu": 4096, + "nodes": "2", + "tasks": 4, + "cpus_per_task": 8, + "constraints": "gpu&a100", + "reservation": "res-001", + "account": "project-x", + "nice": 100, + "mail_type": "END,FAIL", + "mail_user": "user@example.com", + "standard_output": "/tmp/out_%j.log", + "standard_error": "/tmp/err_%j.log", + "standard_input": "/tmp/input.txt", + "required_nodes": "node[01-03]", + "excluded_nodes": "node04", + "begin_time": 1700000000, + "deadline": 1700086400, + "array": "1-100%10", + "dependency": "afterok:12345", + "requeue": true, + "kill_on_node_fail": false + }` + + var req SubmitJobRequest + if err := json.Unmarshal([]byte(payload), &req); err != nil { + t.Fatalf("unmarshal SubmitJobRequest: %v", err) + } + + // Existing fields + if req.Script != "#!/bin/bash\necho hello" { + t.Errorf("Script = %q, want script content", req.Script) + } + if req.WorkDir != "/tmp/work" { + t.Errorf("WorkDir = %q, want /tmp/work", req.WorkDir) + } + if req.Partition != "gpu" { + t.Errorf("Partition = %q, want gpu", req.Partition) + } + if req.QOS != "high" { + t.Errorf("QOS = %q, want high", req.QOS) + } + if req.CPUs != 16 { + t.Errorf("CPUs = %d, want 16", req.CPUs) + } + if req.Memory != "4GB" { + t.Errorf("Memory = %q, want 4GB", req.Memory) + } + if req.TimeLimit != "60" { + t.Errorf("TimeLimit = %q, want 60", req.TimeLimit) + } + if req.JobName != "test-job" { + t.Errorf("JobName = %q, want test-job", req.JobName) + } + if v, ok := req.Environment["PATH"]; !ok || v != "/usr/bin" { + t.Errorf("Environment[PATH] = %q, want /usr/bin", v) + } + + // New scheduling fields + if req.MemoryPerNode == nil || *req.MemoryPerNode != 32768 { + t.Errorf("MemoryPerNode = %v, want 32768", req.MemoryPerNode) + } + if req.MemoryPerCpu == nil || *req.MemoryPerCpu != 4096 { + t.Errorf("MemoryPerCpu = %v, want 4096", req.MemoryPerCpu) + } + if req.Nodes == nil || *req.Nodes != "2" { + t.Errorf("Nodes = %v, want 2", req.Nodes) + } + if req.Tasks == nil || *req.Tasks != 4 { + t.Errorf("Tasks = %v, want 4", req.Tasks) + } + if req.CpusPerTask == nil || *req.CpusPerTask != 8 { + t.Errorf("CpusPerTask = %v, want 8", req.CpusPerTask) + } + if req.Constraints == nil || *req.Constraints != "gpu&a100" { + t.Errorf("Constraints = %v, want gpu&a100", req.Constraints) + } + if req.Reservation == nil || *req.Reservation != "res-001" { + t.Errorf("Reservation = %v, want res-001", req.Reservation) + } + if req.Account == nil || *req.Account != "project-x" { + t.Errorf("Account = %v, want project-x", req.Account) + } + if req.Nice == nil || *req.Nice != 100 { + t.Errorf("Nice = %v, want 100", req.Nice) + } + if req.MailType == nil || *req.MailType != "END,FAIL" { + t.Errorf("MailType = %v, want END,FAIL", req.MailType) + } + if req.MailUser == nil || *req.MailUser != "user@example.com" { + t.Errorf("MailUser = %v, want user@example.com", req.MailUser) + } + if req.StandardOutput == nil || *req.StandardOutput != "/tmp/out_%j.log" { + t.Errorf("StandardOutput = %v, want /tmp/out_%%j.log", req.StandardOutput) + } + if req.StandardError == nil || *req.StandardError != "/tmp/err_%j.log" { + t.Errorf("StandardError = %v, want /tmp/err_%%j.log", req.StandardError) + } + if req.StandardInput == nil || *req.StandardInput != "/tmp/input.txt" { + t.Errorf("StandardInput = %v, want /tmp/input.txt", req.StandardInput) + } + if req.RequiredNodes == nil || *req.RequiredNodes != "node[01-03]" { + t.Errorf("RequiredNodes = %v, want node[01-03]", req.RequiredNodes) + } + if req.ExcludedNodes == nil || *req.ExcludedNodes != "node04" { + t.Errorf("ExcludedNodes = %v, want node04", req.ExcludedNodes) + } + if req.BeginTime == nil || *req.BeginTime != 1700000000 { + t.Errorf("BeginTime = %v, want 1700000000", req.BeginTime) + } + if req.Deadline == nil || *req.Deadline != 1700086400 { + t.Errorf("Deadline = %v, want 1700086400", req.Deadline) + } + if req.Array == nil || *req.Array != "1-100%10" { + t.Errorf("Array = %v, want 1-100%%10", req.Array) + } + if req.Dependency == nil || *req.Dependency != "afterok:12345" { + t.Errorf("Dependency = %v, want afterok:12345", req.Dependency) + } + if req.Requeue == nil || *req.Requeue != true { + t.Errorf("Requeue = %v, want true", req.Requeue) + } + if req.KillOnNodeFail == nil || *req.KillOnNodeFail != false { + t.Errorf("KillOnNodeFail = %v, want false", req.KillOnNodeFail) + } +} + +func TestSubmitJobRequest_BackwardCompat(t *testing.T) { + // Minimal JSON — only required fields + payload := `{"script": "#!/bin/bash\necho hello", "work_dir": "/tmp"}` + + var req SubmitJobRequest + if err := json.Unmarshal([]byte(payload), &req); err != nil { + t.Fatalf("unmarshal minimal SubmitJobRequest: %v", err) + } + + // Required fields present + if req.Script != "#!/bin/bash\necho hello" { + t.Errorf("Script = %q, want script content", req.Script) + } + if req.WorkDir != "/tmp" { + t.Errorf("WorkDir = %q, want /tmp", req.WorkDir) + } + + // Old fields exist with zero values + if req.Memory != "" { + t.Errorf("Memory = %q, want empty", req.Memory) + } + if req.Environment != nil { + t.Errorf("Environment = %v, want nil", req.Environment) + } + + // All new scheduling fields are nil + assertNil := func(name string, val any) { + if !reflect.ValueOf(val).IsNil() { + t.Errorf("%s = %v, want nil", name, val) + } + } + assertNil("MemoryPerNode", req.MemoryPerNode) + assertNil("MemoryPerCpu", req.MemoryPerCpu) + assertNil("Nodes", req.Nodes) + assertNil("Tasks", req.Tasks) + assertNil("CpusPerTask", req.CpusPerTask) + assertNil("Constraints", req.Constraints) + assertNil("Reservation", req.Reservation) + assertNil("Account", req.Account) + assertNil("Nice", req.Nice) + assertNil("MailType", req.MailType) + assertNil("MailUser", req.MailUser) + assertNil("StandardOutput", req.StandardOutput) + assertNil("StandardError", req.StandardError) + assertNil("StandardInput", req.StandardInput) + assertNil("RequiredNodes", req.RequiredNodes) + assertNil("ExcludedNodes", req.ExcludedNodes) + assertNil("BeginTime", req.BeginTime) + assertNil("Deadline", req.Deadline) + assertNil("Array", req.Array) + assertNil("Dependency", req.Dependency) + assertNil("Requeue", req.Requeue) + assertNil("KillOnNodeFail", req.KillOnNodeFail) +} diff --git a/internal/model/task_test.go b/internal/model/task_test.go index bd3882f..7bdff4d 100644 --- a/internal/model/task_test.go +++ b/internal/model/task_test.go @@ -17,27 +17,33 @@ func TestTask_JSONRoundTrip(t *testing.T) { now := time.Now().UTC().Truncate(time.Second) jobID := int32(42) + cpus := int32(16) + memPerNode := int64(32768) + timeLimit := int32(60) task := Task{ - ID: 1, - TaskName: "test task", - AppID: 10, - AppName: "GROMACS", - Status: TaskStatusRunning, - CurrentStep: TaskStepSubmitting, - RetryCount: 1, - Values: json.RawMessage(`{"np":"4"}`), - InputFileIDs: json.RawMessage(`[1,2,3]`), - Script: "#!/bin/bash", - SlurmJobID: &jobID, - WorkDir: "/data/work", - Partition: "gpu", - ErrorMessage: "", - UserID: "user1", - SubmittedAt: now, - StartedAt: &now, - FinishedAt: nil, - CreatedAt: now, - UpdatedAt: now, + ID: 1, + TaskName: "test task", + AppID: 10, + AppName: "GROMACS", + Status: TaskStatusRunning, + CurrentStep: TaskStepSubmitting, + RetryCount: 1, + Values: json.RawMessage(`{"np":"4"}`), + InputFileIDs: json.RawMessage(`[1,2,3]`), + Script: "#!/bin/bash", + SlurmJobID: &jobID, + WorkDir: "/data/work", + Partition: "gpu", + ErrorMessage: "", + UserID: "user1", + SubmittedAt: now, + StartedAt: &now, + FinishedAt: nil, + CreatedAt: now, + UpdatedAt: now, + Cpus: &cpus, + MemoryPerNode: &memPerNode, + TimeLimit: &timeLimit, } data, err := json.Marshal(task) @@ -80,6 +86,18 @@ func TestTask_JSONRoundTrip(t *testing.T) { if got.FinishedAt != nil { t.Errorf("FinishedAt = %v, want nil", got.FinishedAt) } + if got.Partition != task.Partition { + t.Errorf("Partition = %q, want %q", got.Partition, task.Partition) + } + if got.Cpus == nil || *got.Cpus != cpus { + t.Errorf("Cpus = %v, want %d", got.Cpus, cpus) + } + if got.MemoryPerNode == nil || *got.MemoryPerNode != memPerNode { + t.Errorf("MemoryPerNode = %v, want %d", got.MemoryPerNode, memPerNode) + } + if got.TimeLimit == nil || *got.TimeLimit != timeLimit { + t.Errorf("TimeLimit = %v, want %d", got.TimeLimit, timeLimit) + } } func TestCreateTaskRequest_JSONBinding(t *testing.T) { @@ -102,3 +120,303 @@ func TestCreateTaskRequest_JSONBinding(t *testing.T) { t.Errorf("InputFileIDs = %v, want [10 20]", req.InputFileIDs) } } + +func TestTaskResponse_JSONSerialization(t *testing.T) { + cpus := int32(8) + memPerNode := int64(16384) + memPerCpu := int64(4096) + timeLimit := int32(120) + qos := "high" + jobName := "gmx-md" + nodes := "2" + tasks := int32(4) + cpusPerTask := int32(2) + constraints := "haswell" + reservation := "my-resv" + account := "proj-123" + nice := int32(100) + mailType := "END" + mailUser := "user@example.com" + stdout := "/tmp/%j.out" + stderr := "/tmp/%j.err" + stdin := "/dev/null" + reqNodes := "node[01-03]" + exclNodes := "node04" + beginTime := int64(1700000000) + deadline := int64(1700086400) + array := "1-10" + dependency := "afterok:12345" + requeue := true + killOnNodeFail := true + + resp := TaskResponse{ + ID: 1, + TaskName: "test", + AppID: 10, + AppName: "GROMACS", + Status: "running", + CurrentStep: "submitting", + RetryCount: 0, + SlurmJobID: nil, + WorkDir: "/data", + ErrorMessage: "", + CreatedAt: time.Now().UTC().Truncate(time.Second), + UpdatedAt: time.Now().UTC().Truncate(time.Second), + Partition: "gpu", + Cpus: &cpus, + MemoryPerNode: &memPerNode, + MemoryPerCpu: &memPerCpu, + TimeLimit: &timeLimit, + QOS: &qos, + JobName: &jobName, + Nodes: &nodes, + Tasks: &tasks, + CpusPerTask: &cpusPerTask, + Constraints: &constraints, + Reservation: &reservation, + Account: &account, + Nice: &nice, + MailType: &mailType, + MailUser: &mailUser, + StandardOutput: &stdout, + StandardError: &stderr, + StandardInput: &stdin, + RequiredNodes: &reqNodes, + ExcludedNodes: &exclNodes, + BeginTime: &beginTime, + Deadline: &deadline, + Array: &array, + Dependency: &dependency, + Requeue: &requeue, + KillOnNodeFail: &killOnNodeFail, + } + + data, err := json.Marshal(resp) + if err != nil { + t.Fatalf("marshal TaskResponse: %v", err) + } + + var m map[string]interface{} + if err := json.Unmarshal(data, &m); err != nil { + t.Fatalf("unmarshal to map: %v", err) + } + + assertString(t, m, "partition", "gpu") + assertFloat64(t, m, "cpus", float64(cpus)) + assertFloat64(t, m, "memory_per_node", float64(memPerNode)) + assertFloat64(t, m, "memory_per_cpu", float64(memPerCpu)) + assertFloat64(t, m, "time_limit", float64(timeLimit)) + assertString(t, m, "qos", qos) + assertString(t, m, "job_name", jobName) + assertString(t, m, "nodes", nodes) + assertFloat64(t, m, "tasks", float64(tasks)) + assertFloat64(t, m, "cpus_per_task", float64(cpusPerTask)) + assertString(t, m, "constraints", constraints) + assertString(t, m, "reservation", reservation) + assertString(t, m, "account", account) + assertFloat64(t, m, "nice", float64(nice)) + assertString(t, m, "mail_type", mailType) + assertString(t, m, "mail_user", mailUser) + assertString(t, m, "standard_output", stdout) + assertString(t, m, "standard_error", stderr) + assertString(t, m, "standard_input", stdin) + assertString(t, m, "required_nodes", reqNodes) + assertString(t, m, "excluded_nodes", exclNodes) + assertFloat64(t, m, "begin_time", float64(beginTime)) + assertFloat64(t, m, "deadline", float64(deadline)) + assertString(t, m, "array", array) + assertString(t, m, "dependency", dependency) + assertBool(t, m, "requeue", requeue) + assertBool(t, m, "kill_on_node_fail", killOnNodeFail) +} + +func TestCreateTaskRequest_SchedulingFields(t *testing.T) { + payload := `{ + "app_id": 5, + "partition": "gpu", + "cpus": 16, + "memory_per_node": 32768, + "memory_per_cpu": 4096, + "time_limit": 120, + "qos": "high", + "job_name": "gmx-sim", + "nodes": "2", + "tasks": 4, + "cpus_per_task": 2, + "constraints": "haswell", + "reservation": "my-resv", + "account": "proj-123", + "nice": 50, + "mail_type": "ALL", + "mail_user": "user@example.com", + "standard_output": "/tmp/%j.out", + "standard_error": "/tmp/%j.err", + "standard_input": "/dev/null", + "required_nodes": "node[01-03]", + "excluded_nodes": "node04", + "begin_time": 1700000000, + "deadline": 1700086400, + "array": "1-10", + "dependency": "afterok:12345", + "requeue": true, + "kill_on_node_fail": false + }` + + var req CreateTaskRequest + if err := json.Unmarshal([]byte(payload), &req); err != nil { + t.Fatalf("unmarshal CreateTaskRequest: %v", err) + } + + if req.AppID != 5 { + t.Errorf("AppID = %d, want 5", req.AppID) + } + assertPtrString(t, req.Partition, "gpu") + assertPtrInt32(t, req.Cpus, 16) + assertPtrInt64(t, req.MemoryPerNode, 32768) + assertPtrInt64(t, req.MemoryPerCpu, 4096) + assertPtrInt32(t, req.TimeLimit, 120) + assertPtrString(t, req.QOS, "high") + assertPtrString(t, req.JobName, "gmx-sim") + assertPtrString(t, req.Nodes, "2") + assertPtrInt32(t, req.Tasks, 4) + assertPtrInt32(t, req.CpusPerTask, 2) + assertPtrString(t, req.Constraints, "haswell") + assertPtrString(t, req.Reservation, "my-resv") + assertPtrString(t, req.Account, "proj-123") + assertPtrInt32(t, req.Nice, 50) + assertPtrString(t, req.MailType, "ALL") + assertPtrString(t, req.MailUser, "user@example.com") + assertPtrString(t, req.StandardOutput, "/tmp/%j.out") + assertPtrString(t, req.StandardError, "/tmp/%j.err") + assertPtrString(t, req.StandardInput, "/dev/null") + assertPtrString(t, req.RequiredNodes, "node[01-03]") + assertPtrString(t, req.ExcludedNodes, "node04") + assertPtrInt64(t, req.BeginTime, 1700000000) + assertPtrInt64(t, req.Deadline, 1700086400) + assertPtrString(t, req.Array, "1-10") + assertPtrString(t, req.Dependency, "afterok:12345") + assertPtrBool(t, req.Requeue, true) + assertPtrBool(t, req.KillOnNodeFail, false) +} + +func TestCreateTaskRequest_BackwardCompat(t *testing.T) { + payload := `{"app_id": 1}` + var req CreateTaskRequest + if err := json.Unmarshal([]byte(payload), &req); err != nil { + t.Fatalf("unmarshal minimal CreateTaskRequest: %v", err) + } + + if req.AppID != 1 { + t.Errorf("AppID = %d, want 1", req.AppID) + } + if req.Partition != nil { + t.Errorf("Partition = %v, want nil", req.Partition) + } + if req.Cpus != nil { + t.Errorf("Cpus = %v, want nil", req.Cpus) + } + if req.MemoryPerNode != nil { + t.Errorf("MemoryPerNode = %v, want nil", req.MemoryPerNode) + } + if req.MemoryPerCpu != nil { + t.Errorf("MemoryPerCpu = %v, want nil", req.MemoryPerCpu) + } + if req.TimeLimit != nil { + t.Errorf("TimeLimit = %v, want nil", req.TimeLimit) + } + if req.QOS != nil { + t.Errorf("QOS = %v, want nil", req.QOS) + } + if req.Nodes != nil { + t.Errorf("Nodes = %v, want nil", req.Nodes) + } + if req.Tasks != nil { + t.Errorf("Tasks = %v, want nil", req.Tasks) + } + if req.Requeue != nil { + t.Errorf("Requeue = %v, want nil", req.Requeue) + } + if req.KillOnNodeFail != nil { + t.Errorf("KillOnNodeFail = %v, want nil", req.KillOnNodeFail) + } +} + +func assertString(t *testing.T, m map[string]interface{}, key, want string) { + t.Helper() + got, ok := m[key].(string) + if !ok { + t.Errorf("%s: not a string, got %T", key, m[key]) + return + } + if got != want { + t.Errorf("%s = %q, want %q", key, got, want) + } +} + +func assertFloat64(t *testing.T, m map[string]interface{}, key string, want float64) { + t.Helper() + got, ok := m[key].(float64) + if !ok { + t.Errorf("%s: not a float64, got %T", key, m[key]) + return + } + if got != want { + t.Errorf("%s = %v, want %v", key, got, want) + } +} + +func assertBool(t *testing.T, m map[string]interface{}, key string, want bool) { + t.Helper() + got, ok := m[key].(bool) + if !ok { + t.Errorf("%s: not a bool, got %T", key, m[key]) + return + } + if got != want { + t.Errorf("%s = %v, want %v", key, got, want) + } +} + +func assertPtrString(t *testing.T, got *string, want string) { + t.Helper() + if got == nil { + t.Errorf("got nil, want %q", want) + return + } + if *got != want { + t.Errorf("got %q, want %q", *got, want) + } +} + +func assertPtrInt32(t *testing.T, got *int32, want int32) { + t.Helper() + if got == nil { + t.Errorf("got nil, want %d", want) + return + } + if *got != want { + t.Errorf("got %d, want %d", *got, want) + } +} + +func assertPtrInt64(t *testing.T, got *int64, want int64) { + t.Helper() + if got == nil { + t.Errorf("got nil, want %d", want) + return + } + if *got != want { + t.Errorf("got %d, want %d", *got, want) + } +} + +func assertPtrBool(t *testing.T, got *bool, want bool) { + t.Helper() + if got == nil { + t.Errorf("got nil, want %v", want) + return + } + if *got != want { + t.Errorf("got %v, want %v", *got, want) + } +}