From e90904cedb1972cf12473c941ccd8d4f2fd5a902 Mon Sep 17 00:00:00 2001 From: dailz Date: Mon, 20 Apr 2026 10:38:49 +0800 Subject: [PATCH] test(service): add tests for task defaults and job status Ultraworked with [Sisyphus](https://github.com/code-yeongyu/oh-my-openagent) Co-authored-by: Sisyphus --- internal/service/job_service_test.go | 350 +++++++++++ .../service/task_service_defaults_test.go | 184 ++++++ internal/service/task_service_test.go | 589 ++++++++++++++++++ 3 files changed, 1123 insertions(+) create mode 100644 internal/service/task_service_defaults_test.go diff --git a/internal/service/job_service_test.go b/internal/service/job_service_test.go index 49a2e26..d5e5630 100644 --- a/internal/service/job_service_test.go +++ b/internal/service/job_service_test.go @@ -829,6 +829,356 @@ func TestGetJob_FallbackToHistory_HistoryError(t *testing.T) { } } +// --------------------------------------------------------------------------- +// New scheduling field mapping tests +// --------------------------------------------------------------------------- + +func TestSubmitJob_AllSchedulingFields(t *testing.T) { + jobID := int32(999) + + // Prepare all 22 new scheduling field values + var ( + memoryPerNode = int64(4096) + memoryPerCpu = int64(1024) + nodes = "2" + tasks = int32(4) + cpusPerTask = int32(2) + constraints = "gpu&fast" + reservation = "resv01" + account = "proj-alpha" + nice = int32(100) + mailType = "BEGIN,END,FAIL" + mailUser = "admin@example.com" + stdOut = "/tmp/job_%j.out" + stdErr = "/tmp/job_%j.err" + stdIn = "/dev/null" + reqNodes = "node01,node02" + exclNodes = "node03,node04" + beginTime = int64(1700000000) + deadline = int64(1700099999) + array = "1-10" + dependency = "afterok:123" + requeue = true + killOnNodeFail = true + ) + + client, cleanup := mockJobServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + var body slurm.JobSubmitReq + if err := json.NewDecoder(r.Body).Decode(&body); err != nil { + t.Fatalf("decode body: %v", err) + } + if body.Job == nil { + t.Fatal("job desc is nil") + } + j := body.Job + + // --- Existing fields still work --- + if j.Script == nil || *j.Script != "#!/bin/bash\necho test" { + t.Errorf("Script mismatch: %v", j.Script) + } + if j.Partition == nil || *j.Partition != "normal" { + t.Errorf("Partition mismatch: %v", j.Partition) + } + if j.Qos == nil || *j.Qos != "high" { + t.Errorf("QOS mismatch: %v", j.Qos) + } + if j.Name == nil || *j.Name != "full-test" { + t.Errorf("Name mismatch: %v", j.Name) + } + if j.MinimumCpus == nil || *j.MinimumCpus != int32(8) { + t.Errorf("MinimumCpus mismatch: %v", j.MinimumCpus) + } + + // --- 22 new scheduling fields --- + + // MemoryPerNode → *Uint64NoVal + if j.MemoryPerNode == nil || j.MemoryPerNode.Number == nil || *j.MemoryPerNode.Number != memoryPerNode { + t.Errorf("MemoryPerNode mismatch: %v", j.MemoryPerNode) + } + // MemoryPerCpu → *Uint64NoVal + if j.MemoryPerCpu == nil || j.MemoryPerCpu.Number == nil || *j.MemoryPerCpu.Number != memoryPerCpu { + t.Errorf("MemoryPerCpu mismatch: %v", j.MemoryPerCpu) + } + // Nodes → *string + if j.Nodes == nil || *j.Nodes != nodes { + t.Errorf("Nodes mismatch: %v", j.Nodes) + } + // Tasks → *int32 + if j.Tasks == nil || *j.Tasks != tasks { + t.Errorf("Tasks mismatch: got %v, want %d", j.Tasks, tasks) + } + // CpusPerTask → *int32 + if j.CpusPerTask == nil || *j.CpusPerTask != cpusPerTask { + t.Errorf("CpusPerTask mismatch: got %v, want %d", j.CpusPerTask, cpusPerTask) + } + // Constraints → *string + if j.Constraints == nil || *j.Constraints != constraints { + t.Errorf("Constraints mismatch: %v", j.Constraints) + } + // Reservation → *string + if j.Reservation == nil || *j.Reservation != reservation { + t.Errorf("Reservation mismatch: %v", j.Reservation) + } + // Account → *string + if j.Account == nil || *j.Account != account { + t.Errorf("Account mismatch: %v", j.Account) + } + // Nice → *int32 + if j.Nice == nil || *j.Nice != nice { + t.Errorf("Nice mismatch: got %v, want %d", j.Nice, nice) + } + // MailType → []string (comma-split) + if len(j.MailType) != 3 || j.MailType[0] != "BEGIN" || j.MailType[1] != "END" || j.MailType[2] != "FAIL" { + t.Errorf("MailType mismatch: %v", j.MailType) + } + // MailUser → *string + if j.MailUser == nil || *j.MailUser != mailUser { + t.Errorf("MailUser mismatch: %v", j.MailUser) + } + // StandardOutput → *string + if j.StandardOutput == nil || *j.StandardOutput != stdOut { + t.Errorf("StandardOutput mismatch: %v", j.StandardOutput) + } + // StandardError → *string + if j.StandardError == nil || *j.StandardError != stdErr { + t.Errorf("StandardError mismatch: %v", j.StandardError) + } + // StandardInput → *string + if j.StandardInput == nil || *j.StandardInput != stdIn { + t.Errorf("StandardInput mismatch: %v", j.StandardInput) + } + // RequiredNodes → CSVString ([]string) + if len(j.RequiredNodes) != 2 || j.RequiredNodes[0] != "node01" || j.RequiredNodes[1] != "node02" { + t.Errorf("RequiredNodes mismatch: %v", j.RequiredNodes) + } + // ExcludedNodes → CSVString ([]string) + if len(j.ExcludedNodes) != 2 || j.ExcludedNodes[0] != "node03" || j.ExcludedNodes[1] != "node04" { + t.Errorf("ExcludedNodes mismatch: %v", j.ExcludedNodes) + } + // BeginTime → *Uint64NoVal + if j.BeginTime == nil || j.BeginTime.Number == nil || *j.BeginTime.Number != beginTime { + t.Errorf("BeginTime mismatch: %v", j.BeginTime) + } + // Deadline → *int64 (NO wrapper) + if j.Deadline == nil || *j.Deadline != deadline { + t.Errorf("Deadline mismatch: %v", j.Deadline) + } + // Array → *string + if j.Array == nil || *j.Array != array { + t.Errorf("Array mismatch: %v", j.Array) + } + // Dependency → *string + if j.Dependency == nil || *j.Dependency != dependency { + t.Errorf("Dependency mismatch: %v", j.Dependency) + } + // Requeue → *bool + if j.Requeue == nil || *j.Requeue != requeue { + t.Errorf("Requeue mismatch: %v", j.Requeue) + } + // KillOnNodeFail → *bool + if j.KillOnNodeFail == nil || *j.KillOnNodeFail != killOnNodeFail { + t.Errorf("KillOnNodeFail mismatch: %v", j.KillOnNodeFail) + } + + resp := slurm.OpenapiJobSubmitResponse{ + Result: &slurm.JobSubmitResponseMsg{JobID: &jobID}, + } + json.NewEncoder(w).Encode(resp) + })) + defer cleanup() + + svc := NewJobService(client, zap.NewNop()) + resp, err := svc.SubmitJob(context.Background(), &model.SubmitJobRequest{ + Script: "#!/bin/bash\necho test", + Partition: "normal", + QOS: "high", + JobName: "full-test", + CPUs: 8, + TimeLimit: "60", + MemoryPerNode: &memoryPerNode, + MemoryPerCpu: &memoryPerCpu, + Nodes: &nodes, + Tasks: &tasks, + CpusPerTask: &cpusPerTask, + Constraints: &constraints, + Reservation: &reservation, + Account: &account, + Nice: &nice, + MailType: &mailType, + MailUser: &mailUser, + StandardOutput: &stdOut, + StandardError: &stdErr, + StandardInput: &stdIn, + RequiredNodes: &reqNodes, + ExcludedNodes: &exclNodes, + BeginTime: &beginTime, + Deadline: &deadline, + Array: &array, + Dependency: &dependency, + Requeue: &requeue, + KillOnNodeFail: &killOnNodeFail, + }) + if err != nil { + t.Fatalf("SubmitJob: %v", err) + } + if resp.JobID != 999 { + t.Errorf("expected JobID 999, got %d", resp.JobID) + } +} + +func TestSubmitJob_BackwardCompat(t *testing.T) { + jobID := int32(555) + + client, cleanup := mockJobServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + var body slurm.JobSubmitReq + if err := json.NewDecoder(r.Body).Decode(&body); err != nil { + t.Fatalf("decode body: %v", err) + } + if body.Job == nil { + t.Fatal("job desc is nil") + } + j := body.Job + + // Existing fields: Script and WorkDir should be set + if j.Script == nil || *j.Script != "echo hi" { + t.Errorf("Script mismatch: %v", j.Script) + } + if j.CurrentWorkingDirectory == nil || *j.CurrentWorkingDirectory != "/home/user" { + t.Errorf("CurrentWorkingDirectory mismatch: %v", j.CurrentWorkingDirectory) + } + + // All new scheduling fields should be nil/empty + if j.MemoryPerNode != nil { + t.Errorf("MemoryPerNode should be nil, got %v", j.MemoryPerNode) + } + if j.MemoryPerCpu != nil { + t.Errorf("MemoryPerCpu should be nil, got %v", j.MemoryPerCpu) + } + if j.Nodes != nil { + t.Errorf("Nodes should be nil, got %v", j.Nodes) + } + if j.Tasks != nil { + t.Errorf("Tasks should be nil, got %v", j.Tasks) + } + if j.CpusPerTask != nil { + t.Errorf("CpusPerTask should be nil, got %v", j.CpusPerTask) + } + if j.Constraints != nil { + t.Errorf("Constraints should be nil, got %v", j.Constraints) + } + if j.Reservation != nil { + t.Errorf("Reservation should be nil, got %v", j.Reservation) + } + if j.Account != nil { + t.Errorf("Account should be nil, got %v", j.Account) + } + if j.Nice != nil { + t.Errorf("Nice should be nil, got %v", j.Nice) + } + if len(j.MailType) != 0 { + t.Errorf("MailType should be empty, got %v", j.MailType) + } + if j.MailUser != nil { + t.Errorf("MailUser should be nil, got %v", j.MailUser) + } + if j.StandardOutput != nil { + t.Errorf("StandardOutput should be nil, got %v", j.StandardOutput) + } + if j.StandardError != nil { + t.Errorf("StandardError should be nil, got %v", j.StandardError) + } + if j.StandardInput != nil { + t.Errorf("StandardInput should be nil, got %v", j.StandardInput) + } + if len(j.RequiredNodes) != 0 { + t.Errorf("RequiredNodes should be empty, got %v", j.RequiredNodes) + } + if len(j.ExcludedNodes) != 0 { + t.Errorf("ExcludedNodes should be empty, got %v", j.ExcludedNodes) + } + if j.BeginTime != nil { + t.Errorf("BeginTime should be nil, got %v", j.BeginTime) + } + if j.Deadline != nil { + t.Errorf("Deadline should be nil, got %v", j.Deadline) + } + if j.Array != nil { + t.Errorf("Array should be nil, got %v", j.Array) + } + if j.Dependency != nil { + t.Errorf("Dependency should be nil, got %v", j.Dependency) + } + if j.Requeue != nil { + t.Errorf("Requeue should be nil, got %v", j.Requeue) + } + if j.KillOnNodeFail != nil { + t.Errorf("KillOnNodeFail should be nil, got %v", j.KillOnNodeFail) + } + + resp := slurm.OpenapiJobSubmitResponse{ + Result: &slurm.JobSubmitResponseMsg{JobID: &jobID}, + } + json.NewEncoder(w).Encode(resp) + })) + defer cleanup() + + svc := NewJobService(client, zap.NewNop()) + resp, err := svc.SubmitJob(context.Background(), &model.SubmitJobRequest{ + Script: "echo hi", + WorkDir: "/home/user", + }) + if err != nil { + t.Fatalf("SubmitJob: %v", err) + } + if resp.JobID != 555 { + t.Errorf("expected JobID 555, got %d", resp.JobID) + } +} + +func TestSubmitJob_MemoryBothSet(t *testing.T) { + jobID := int32(777) + memoryPerNode := int64(4096) + memoryPerCpu := int64(1024) + + client, cleanup := mockJobServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + var body slurm.JobSubmitReq + if err := json.NewDecoder(r.Body).Decode(&body); err != nil { + t.Fatalf("decode body: %v", err) + } + if body.Job == nil { + t.Fatal("job desc is nil") + } + j := body.Job + + // Both memory fields should be mapped independently + if j.MemoryPerNode == nil || j.MemoryPerNode.Number == nil || *j.MemoryPerNode.Number != memoryPerNode { + t.Errorf("MemoryPerNode mismatch: %v", j.MemoryPerNode) + } + if j.MemoryPerCpu == nil || j.MemoryPerCpu.Number == nil || *j.MemoryPerCpu.Number != memoryPerCpu { + t.Errorf("MemoryPerCpu mismatch: %v", j.MemoryPerCpu) + } + + resp := slurm.OpenapiJobSubmitResponse{ + Result: &slurm.JobSubmitResponseMsg{JobID: &jobID}, + } + json.NewEncoder(w).Encode(resp) + })) + defer cleanup() + + svc := NewJobService(client, zap.NewNop()) + resp, err := svc.SubmitJob(context.Background(), &model.SubmitJobRequest{ + Script: "echo mem", + MemoryPerNode: &memoryPerNode, + MemoryPerCpu: &memoryPerCpu, + }) + if err != nil { + t.Fatalf("SubmitJob: %v", err) + } + if resp.JobID != 777 { + t.Errorf("expected JobID 777, got %d", resp.JobID) + } +} + func TestGetJob_FallbackToHistory_EmptyHistory(t *testing.T) { client, cleanup := mockJobServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "application/json") diff --git a/internal/service/task_service_defaults_test.go b/internal/service/task_service_defaults_test.go new file mode 100644 index 0000000..a00bfe3 --- /dev/null +++ b/internal/service/task_service_defaults_test.go @@ -0,0 +1,184 @@ +package service + +import ( + "context" + "encoding/json" + "net/http" + "strings" + "testing" + "time" + + "gcy_hpc_server/internal/model" + "gcy_hpc_server/internal/slurm" +) + +func TestProcessTask_DefaultTimeLimit(t *testing.T) { + jobID := int32(42) + + var capturedReq slurm.JobSubmitReq + + env := newTaskTestEnv(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if err := json.NewDecoder(r.Body).Decode(&capturedReq); err != nil { + t.Fatalf("decode request body: %v", err) + } + json.NewEncoder(w).Encode(slurm.OpenapiJobSubmitResponse{ + Result: &slurm.JobSubmitResponseMsg{JobID: &jobID}, + }) + })) + defer env.close() + + appID := env.createApp(t, "default-tl-app", "#!/bin/bash\necho hello", json.RawMessage(`[]`)) + + ctx := context.Background() + + task := &model.Task{ + AppName: "default-tl-app", + AppID: appID, + Status: model.TaskStatusSubmitted, + SubmittedAt: time.Now(), + Partition: "debug", + } + taskID, err := env.taskStore.Create(ctx, task) + if err != nil { + t.Fatalf("create task in DB: %v", err) + } + + err = env.svc.ProcessTask(ctx, taskID) + if err != nil { + t.Fatalf("ProcessTask: %v", err) + } + + j := capturedReq.Job + if j == nil { + t.Fatal("Job desc is nil in captured request") + } + + if j.TimeLimit == nil { + t.Fatal("TimeLimit should not be nil, got nil") + } + if j.TimeLimit.Number == nil { + t.Fatal("TimeLimit.Number should not be nil, got nil") + } + if *j.TimeLimit.Number != int64(10080) { + t.Errorf("TimeLimit.Number = %d, want %d", *j.TimeLimit.Number, int64(10080)) + } +} + +func TestProcessTask_DefaultStdoutStderr(t *testing.T) { + jobID := int32(42) + + var capturedReq slurm.JobSubmitReq + + env := newTaskTestEnv(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if err := json.NewDecoder(r.Body).Decode(&capturedReq); err != nil { + t.Fatalf("decode request body: %v", err) + } + json.NewEncoder(w).Encode(slurm.OpenapiJobSubmitResponse{ + Result: &slurm.JobSubmitResponseMsg{JobID: &jobID}, + }) + })) + defer env.close() + + appID := env.createApp(t, "default-std-app", "#!/bin/bash\necho hello", json.RawMessage(`[]`)) + + ctx := context.Background() + + task := &model.Task{ + AppName: "default-std-app", + AppID: appID, + Status: model.TaskStatusSubmitted, + SubmittedAt: time.Now(), + Partition: "debug", + } + taskID, err := env.taskStore.Create(ctx, task) + if err != nil { + t.Fatalf("create task in DB: %v", err) + } + + err = env.svc.ProcessTask(ctx, taskID) + if err != nil { + t.Fatalf("ProcessTask: %v", err) + } + + j := capturedReq.Job + if j == nil { + t.Fatal("Job desc is nil in captured request") + } + + if j.StandardOutput == nil { + t.Fatal("StandardOutput should not be nil, got nil") + } + if !strings.HasSuffix(*j.StandardOutput, "/slurm-%j.out") { + t.Errorf("StandardOutput = %q, want suffix /slurm-%%j.out", *j.StandardOutput) + } + + if j.StandardError == nil { + t.Fatal("StandardError should not be nil, got nil") + } + if !strings.HasSuffix(*j.StandardError, "/slurm-%j.err") { + t.Errorf("StandardError = %q, want suffix /slurm-%%j.err", *j.StandardError) + } +} + +func TestProcessTask_NoOverrideWhenSet(t *testing.T) { + jobID := int32(42) + + var capturedReq slurm.JobSubmitReq + + env := newTaskTestEnv(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if err := json.NewDecoder(r.Body).Decode(&capturedReq); err != nil { + t.Fatalf("decode request body: %v", err) + } + json.NewEncoder(w).Encode(slurm.OpenapiJobSubmitResponse{ + Result: &slurm.JobSubmitResponseMsg{JobID: &jobID}, + }) + })) + defer env.close() + + appID := env.createApp(t, "no-override-app", "#!/bin/bash\necho hello", json.RawMessage(`[]`)) + + ctx := context.Background() + + customTL := int32(60) + customOut := "/custom/path.out" + task := &model.Task{ + AppName: "no-override-app", + AppID: appID, + Status: model.TaskStatusSubmitted, + SubmittedAt: time.Now(), + Partition: "debug", + TimeLimit: &customTL, + StandardOutput: &customOut, + } + taskID, err := env.taskStore.Create(ctx, task) + if err != nil { + t.Fatalf("create task in DB: %v", err) + } + + err = env.svc.ProcessTask(ctx, taskID) + if err != nil { + t.Fatalf("ProcessTask: %v", err) + } + + j := capturedReq.Job + if j == nil { + t.Fatal("Job desc is nil in captured request") + } + + if j.TimeLimit == nil { + t.Fatal("TimeLimit should not be nil, got nil") + } + if j.TimeLimit.Number == nil { + t.Fatal("TimeLimit.Number should not be nil, got nil") + } + if *j.TimeLimit.Number != int64(60) { + t.Errorf("TimeLimit.Number = %d, want %d (user value should be preserved)", *j.TimeLimit.Number, int64(60)) + } + + if j.StandardOutput == nil { + t.Fatal("StandardOutput should not be nil, got nil") + } + if *j.StandardOutput != "/custom/path.out" { + t.Errorf("StandardOutput = %q, want %q (user value should be preserved)", *j.StandardOutput, "/custom/path.out") + } +} diff --git a/internal/service/task_service_test.go b/internal/service/task_service_test.go index 6360aa7..ca9540d 100644 --- a/internal/service/task_service_test.go +++ b/internal/service/task_service_test.go @@ -10,6 +10,7 @@ import ( "path/filepath" "strings" "testing" + "time" "gcy_hpc_server/internal/model" "gcy_hpc_server/internal/slurm" @@ -673,3 +674,591 @@ func TestTaskService_ProcessTask_FileParamNotInInputFiles(t *testing.T) { t.Errorf("error should mention 'not found', got: %v", err) } } + +// --- Pointer helpers for scheduling field tests --- + +func int64Ptr(i int64) *int64 { return &i } +func boolPtr(b bool) *bool { return &b } + +func TestDerefInt32ToStr(t *testing.T) { + t.Run("nil_returns_empty", func(t *testing.T) { + if got := derefInt32ToStr(nil); got != "" { + t.Errorf("derefInt32ToStr(nil) = %q, want %q", got, "") + } + }) + t.Run("zero_returns_zero_string", func(t *testing.T) { + if got := derefInt32ToStr(int32Ptr(0)); got != "0" { + t.Errorf("derefInt32ToStr(0) = %q, want %q", got, "0") + } + }) + t.Run("positive_returns_string", func(t *testing.T) { + if got := derefInt32ToStr(int32Ptr(60)); got != "60" { + t.Errorf("derefInt32ToStr(60) = %q, want %q", got, "60") + } + }) +} + +func TestCreateTask_SchedulingFields(t *testing.T) { + env := newTaskTestEnv(t, nil) + defer env.close() + + appID := env.createApp(t, "sched-app", "#!/bin/bash\necho hello", json.RawMessage(`[]`)) + + var ( + partition = "gpu" + cpus = int32(8) + memoryPerNode = int64(4096) + memoryPerCpu = int64(1024) + timeLimit = int32(60) + qos = "high" + jobName = "test-job" + nodes = "2" + tasks = int32(4) + cpusPerTask = int32(2) + constraints = "gpu&fast" + reservation = "resv01" + account = "proj-alpha" + nice = int32(100) + mailType = "BEGIN,END" + mailUser = "admin@example.com" + stdOut = "/tmp/job_%j.out" + stdErr = "/tmp/job_%j.err" + stdIn = "/dev/null" + reqNodes = "node01,node02" + exclNodes = "node03,node04" + beginTime = int64(1700000000) + deadline = int64(1700099999) + array = "1-10" + dependency = "afterok:123" + requeue = true + killOnNodeFail = true + ) + + task, err := env.svc.CreateTask(context.Background(), &model.CreateTaskRequest{ + AppID: appID, + TaskName: "sched-task", + Partition: &partition, + Cpus: &cpus, + MemoryPerNode: &memoryPerNode, + MemoryPerCpu: &memoryPerCpu, + TimeLimit: &timeLimit, + QOS: &qos, + JobName: &jobName, + Nodes: &nodes, + Tasks: &tasks, + CpusPerTask: &cpusPerTask, + Constraints: &constraints, + Reservation: &reservation, + Account: &account, + Nice: &nice, + MailType: &mailType, + MailUser: &mailUser, + StandardOutput: &stdOut, + StandardError: &stdErr, + StandardInput: &stdIn, + RequiredNodes: &reqNodes, + ExcludedNodes: &exclNodes, + BeginTime: &beginTime, + Deadline: &deadline, + Array: &array, + Dependency: &dependency, + Requeue: &requeue, + KillOnNodeFail: &killOnNodeFail, + }) + if err != nil { + t.Fatalf("CreateTask: %v", err) + } + + if task.Partition != partition { + t.Errorf("Partition = %q, want %q", task.Partition, partition) + } + if task.Cpus == nil || *task.Cpus != cpus { + t.Errorf("Cpus = %v, want %d", task.Cpus, cpus) + } + if task.MemoryPerNode == nil || *task.MemoryPerNode != memoryPerNode { + t.Errorf("MemoryPerNode = %v, want %d", task.MemoryPerNode, memoryPerNode) + } + if task.MemoryPerCpu == nil || *task.MemoryPerCpu != memoryPerCpu { + t.Errorf("MemoryPerCpu = %v, want %d", task.MemoryPerCpu, memoryPerCpu) + } + if task.TimeLimit == nil || *task.TimeLimit != timeLimit { + t.Errorf("TimeLimit = %v, want %d", task.TimeLimit, timeLimit) + } + if task.QOS == nil || *task.QOS != qos { + t.Errorf("QOS = %v, want %q", task.QOS, qos) + } + if task.JobName == nil || *task.JobName != jobName { + t.Errorf("JobName = %v, want %q", task.JobName, jobName) + } + if task.Nodes == nil || *task.Nodes != nodes { + t.Errorf("Nodes = %v, want %q", task.Nodes, nodes) + } + if task.Tasks == nil || *task.Tasks != tasks { + t.Errorf("Tasks = %v, want %d", task.Tasks, tasks) + } + if task.CpusPerTask == nil || *task.CpusPerTask != cpusPerTask { + t.Errorf("CpusPerTask = %v, want %d", task.CpusPerTask, cpusPerTask) + } + if task.Constraints == nil || *task.Constraints != constraints { + t.Errorf("Constraints = %v, want %q", task.Constraints, constraints) + } + if task.Reservation == nil || *task.Reservation != reservation { + t.Errorf("Reservation = %v, want %q", task.Reservation, reservation) + } + if task.Account == nil || *task.Account != account { + t.Errorf("Account = %v, want %q", task.Account, account) + } + if task.Nice == nil || *task.Nice != nice { + t.Errorf("Nice = %v, want %d", task.Nice, nice) + } + if task.MailType == nil || *task.MailType != mailType { + t.Errorf("MailType = %v, want %q", task.MailType, mailType) + } + if task.MailUser == nil || *task.MailUser != mailUser { + t.Errorf("MailUser = %v, want %q", task.MailUser, mailUser) + } + if task.StandardOutput == nil || *task.StandardOutput != stdOut { + t.Errorf("StandardOutput = %v, want %q", task.StandardOutput, stdOut) + } + if task.StandardError == nil || *task.StandardError != stdErr { + t.Errorf("StandardError = %v, want %q", task.StandardError, stdErr) + } + if task.StandardInput == nil || *task.StandardInput != stdIn { + t.Errorf("StandardInput = %v, want %q", task.StandardInput, stdIn) + } + if task.RequiredNodes == nil || *task.RequiredNodes != reqNodes { + t.Errorf("RequiredNodes = %v, want %q", task.RequiredNodes, reqNodes) + } + if task.ExcludedNodes == nil || *task.ExcludedNodes != exclNodes { + t.Errorf("ExcludedNodes = %v, want %q", task.ExcludedNodes, exclNodes) + } + if task.BeginTime == nil || *task.BeginTime != beginTime { + t.Errorf("BeginTime = %v, want %d", task.BeginTime, beginTime) + } + if task.Deadline == nil || *task.Deadline != deadline { + t.Errorf("Deadline = %v, want %d", task.Deadline, deadline) + } + if task.Array == nil || *task.Array != array { + t.Errorf("Array = %v, want %q", task.Array, array) + } + if task.Dependency == nil || *task.Dependency != dependency { + t.Errorf("Dependency = %v, want %q", task.Dependency, dependency) + } + if task.Requeue == nil || *task.Requeue != requeue { + t.Errorf("Requeue = %v, want %v", task.Requeue, requeue) + } + if task.KillOnNodeFail == nil || *task.KillOnNodeFail != killOnNodeFail { + t.Errorf("KillOnNodeFail = %v, want %v", task.KillOnNodeFail, killOnNodeFail) + } +} + +func TestCreateTask_BackwardCompat(t *testing.T) { + env := newTaskTestEnv(t, nil) + defer env.close() + + appID := env.createApp(t, "compat-app", "#!/bin/bash\necho hello", json.RawMessage(`[]`)) + + task, err := env.svc.CreateTask(context.Background(), &model.CreateTaskRequest{ + AppID: appID, + TaskName: "compat-task", + }) + if err != nil { + t.Fatalf("CreateTask: %v", err) + } + + if task.Partition != "" { + t.Errorf("Partition = %q, want empty", task.Partition) + } + if task.Cpus != nil { + t.Errorf("Cpus = %v, want nil", task.Cpus) + } + if task.MemoryPerNode != nil { + t.Errorf("MemoryPerNode = %v, want nil", task.MemoryPerNode) + } + if task.MemoryPerCpu != nil { + t.Errorf("MemoryPerCpu = %v, want nil", task.MemoryPerCpu) + } + if task.TimeLimit != nil { + t.Errorf("TimeLimit = %v, want nil", task.TimeLimit) + } + if task.QOS != nil { + t.Errorf("QOS = %v, want nil", task.QOS) + } + if task.JobName != nil { + t.Errorf("JobName = %v, want nil", task.JobName) + } + if task.Nodes != nil { + t.Errorf("Nodes = %v, want nil", task.Nodes) + } + if task.Tasks != nil { + t.Errorf("Tasks = %v, want nil", task.Tasks) + } + if task.CpusPerTask != nil { + t.Errorf("CpusPerTask = %v, want nil", task.CpusPerTask) + } + if task.Constraints != nil { + t.Errorf("Constraints = %v, want nil", task.Constraints) + } + if task.Reservation != nil { + t.Errorf("Reservation = %v, want nil", task.Reservation) + } + if task.Account != nil { + t.Errorf("Account = %v, want nil", task.Account) + } + if task.Nice != nil { + t.Errorf("Nice = %v, want nil", task.Nice) + } + if task.MailType != nil { + t.Errorf("MailType = %v, want nil", task.MailType) + } + if task.MailUser != nil { + t.Errorf("MailUser = %v, want nil", task.MailUser) + } + if task.StandardOutput != nil { + t.Errorf("StandardOutput = %v, want nil", task.StandardOutput) + } + if task.StandardError != nil { + t.Errorf("StandardError = %v, want nil", task.StandardError) + } + if task.StandardInput != nil { + t.Errorf("StandardInput = %v, want nil", task.StandardInput) + } + if task.RequiredNodes != nil { + t.Errorf("RequiredNodes = %v, want nil", task.RequiredNodes) + } + if task.ExcludedNodes != nil { + t.Errorf("ExcludedNodes = %v, want nil", task.ExcludedNodes) + } + if task.BeginTime != nil { + t.Errorf("BeginTime = %v, want nil", task.BeginTime) + } + if task.Deadline != nil { + t.Errorf("Deadline = %v, want nil", task.Deadline) + } + if task.Array != nil { + t.Errorf("Array = %v, want nil", task.Array) + } + if task.Dependency != nil { + t.Errorf("Dependency = %v, want nil", task.Dependency) + } + if task.Requeue != nil { + t.Errorf("Requeue = %v, want nil", task.Requeue) + } + if task.KillOnNodeFail != nil { + t.Errorf("KillOnNodeFail = %v, want nil", task.KillOnNodeFail) + } +} + +func TestProcessTask_SchedulingParams(t *testing.T) { + jobID := int32(42) + + var ( + cpus = int32(8) + memoryPerNode = int64(4096) + memoryPerCpu = int64(1024) + timeLimit = int32(60) + qos = "high" + jobName = "test-job" + nodes = "2" + tasks = int32(4) + cpusPerTask = int32(2) + constraints = "gpu&fast" + reservation = "resv01" + account = "proj-alpha" + nice = int32(100) + mailType = "BEGIN,END" + mailUser = "admin@example.com" + stdOut = "/tmp/job_%j.out" + stdErr = "/tmp/job_%j.err" + stdIn = "/dev/null" + reqNodes = "node01,node02" + exclNodes = "node03,node04" + beginTime = int64(1700000000) + deadline = int64(1700099999) + array = "1-10" + dependency = "afterok:123" + requeue = true + killOnNodeFail = true + ) + + var capturedReq slurm.JobSubmitReq + + env := newTaskTestEnv(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if err := json.NewDecoder(r.Body).Decode(&capturedReq); err != nil { + t.Fatalf("decode request body: %v", err) + } + json.NewEncoder(w).Encode(slurm.OpenapiJobSubmitResponse{ + Result: &slurm.JobSubmitResponseMsg{JobID: &jobID}, + }) + })) + defer env.close() + + appID := env.createApp(t, "sched-app", "#!/bin/bash\necho hello", json.RawMessage(`[]`)) + + ctx := context.Background() + + task := &model.Task{ + AppName: "sched-app", + AppID: appID, + Status: model.TaskStatusSubmitted, + SubmittedAt: time.Now(), + Partition: "gpu", + Cpus: &cpus, + MemoryPerNode: &memoryPerNode, + MemoryPerCpu: &memoryPerCpu, + TimeLimit: &timeLimit, + QOS: &qos, + JobName: &jobName, + Nodes: &nodes, + Tasks: &tasks, + CpusPerTask: &cpusPerTask, + Constraints: &constraints, + Reservation: &reservation, + Account: &account, + Nice: &nice, + MailType: &mailType, + MailUser: &mailUser, + StandardOutput: &stdOut, + StandardError: &stdErr, + StandardInput: &stdIn, + RequiredNodes: &reqNodes, + ExcludedNodes: &exclNodes, + BeginTime: &beginTime, + Deadline: &deadline, + Array: &array, + Dependency: &dependency, + Requeue: &requeue, + KillOnNodeFail: &killOnNodeFail, + } + taskID, err := env.taskStore.Create(ctx, task) + if err != nil { + t.Fatalf("create task in DB: %v", err) + } + + err = env.svc.ProcessTask(ctx, taskID) + if err != nil { + t.Fatalf("ProcessTask: %v", err) + } + + j := capturedReq.Job + if j == nil { + t.Fatal("Job desc is nil in captured request") + } + + if j.Partition == nil || *j.Partition != "gpu" { + t.Errorf("Partition = %v, want %q", j.Partition, "gpu") + } + if j.MinimumCpus == nil || *j.MinimumCpus != int32(8) { + t.Errorf("MinimumCpus = %v, want 8", j.MinimumCpus) + } + if j.TimeLimit == nil || j.TimeLimit.Number == nil || *j.TimeLimit.Number != int64(60) { + t.Errorf("TimeLimit = %v, want 60", j.TimeLimit) + } + if j.Qos == nil || *j.Qos != "high" { + t.Errorf("Qos = %v, want %q", j.Qos, "high") + } + if j.Name == nil || *j.Name != "test-job" { + t.Errorf("Name = %v, want %q", j.Name, "test-job") + } + if j.MemoryPerNode == nil || j.MemoryPerNode.Number == nil || *j.MemoryPerNode.Number != int64(4096) { + t.Errorf("MemoryPerNode = %v, want 4096", j.MemoryPerNode) + } + if j.MemoryPerCpu == nil || j.MemoryPerCpu.Number == nil || *j.MemoryPerCpu.Number != int64(1024) { + t.Errorf("MemoryPerCpu = %v, want 1024", j.MemoryPerCpu) + } + if j.Nodes == nil || *j.Nodes != "2" { + t.Errorf("Nodes = %v, want %q", j.Nodes, "2") + } + if j.Tasks == nil || *j.Tasks != int32(4) { + t.Errorf("Tasks = %v, want 4", j.Tasks) + } + if j.CpusPerTask == nil || *j.CpusPerTask != int32(2) { + t.Errorf("CpusPerTask = %v, want 2", j.CpusPerTask) + } + if j.Constraints == nil || *j.Constraints != "gpu&fast" { + t.Errorf("Constraints = %v, want %q", j.Constraints, "gpu&fast") + } + if j.Reservation == nil || *j.Reservation != "resv01" { + t.Errorf("Reservation = %v, want %q", j.Reservation, "resv01") + } + if j.Account == nil || *j.Account != "proj-alpha" { + t.Errorf("Account = %v, want %q", j.Account, "proj-alpha") + } + if j.Nice == nil || *j.Nice != int32(100) { + t.Errorf("Nice = %v, want 100", j.Nice) + } + // MailType is split by comma in job_service.go + if len(j.MailType) != 2 || j.MailType[0] != "BEGIN" || j.MailType[1] != "END" { + t.Errorf("MailType = %v, want [BEGIN, END]", j.MailType) + } + if j.MailUser == nil || *j.MailUser != "admin@example.com" { + t.Errorf("MailUser = %v, want %q", j.MailUser, "admin@example.com") + } + if j.StandardOutput == nil || *j.StandardOutput != "/tmp/job_%j.out" { + t.Errorf("StandardOutput = %v, want %q", j.StandardOutput, "/tmp/job_%j.out") + } + if j.StandardError == nil || *j.StandardError != "/tmp/job_%j.err" { + t.Errorf("StandardError = %v, want %q", j.StandardError, "/tmp/job_%j.err") + } + if j.StandardInput == nil || *j.StandardInput != "/dev/null" { + t.Errorf("StandardInput = %v, want %q", j.StandardInput, "/dev/null") + } + // RequiredNodes/ExcludedNodes are split by comma in job_service.go + if len(j.RequiredNodes) != 2 || j.RequiredNodes[0] != "node01" || j.RequiredNodes[1] != "node02" { + t.Errorf("RequiredNodes = %v, want [node01, node02]", j.RequiredNodes) + } + if len(j.ExcludedNodes) != 2 || j.ExcludedNodes[0] != "node03" || j.ExcludedNodes[1] != "node04" { + t.Errorf("ExcludedNodes = %v, want [node03, node04]", j.ExcludedNodes) + } + if j.BeginTime == nil || j.BeginTime.Number == nil || *j.BeginTime.Number != int64(1700000000) { + t.Errorf("BeginTime = %v, want 1700000000", j.BeginTime) + } + if j.Deadline == nil || *j.Deadline != int64(1700099999) { + t.Errorf("Deadline = %v, want 1700099999", j.Deadline) + } + if j.Array == nil || *j.Array != "1-10" { + t.Errorf("Array = %v, want %q", j.Array, "1-10") + } + if j.Dependency == nil || *j.Dependency != "afterok:123" { + t.Errorf("Dependency = %v, want %q", j.Dependency, "afterok:123") + } + if j.Requeue == nil || *j.Requeue != true { + t.Errorf("Requeue = %v, want true", j.Requeue) + } + if j.KillOnNodeFail == nil || *j.KillOnNodeFail != true { + t.Errorf("KillOnNodeFail = %v, want true", j.KillOnNodeFail) + } +} + +func TestProcessTask_PartialSchedulingParams(t *testing.T) { + jobID := int32(42) + + var capturedReq slurm.JobSubmitReq + + env := newTaskTestEnv(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if err := json.NewDecoder(r.Body).Decode(&capturedReq); err != nil { + t.Fatalf("decode request body: %v", err) + } + json.NewEncoder(w).Encode(slurm.OpenapiJobSubmitResponse{ + Result: &slurm.JobSubmitResponseMsg{JobID: &jobID}, + }) + })) + defer env.close() + + appID := env.createApp(t, "partial-app", "#!/bin/bash\necho hello", json.RawMessage(`[]`)) + + ctx := context.Background() + + task := &model.Task{ + AppName: "partial-app", + AppID: appID, + Status: model.TaskStatusSubmitted, + SubmittedAt: time.Now(), + Partition: "debug", + } + taskID, err := env.taskStore.Create(ctx, task) + if err != nil { + t.Fatalf("create task in DB: %v", err) + } + + err = env.svc.ProcessTask(ctx, taskID) + if err != nil { + t.Fatalf("ProcessTask: %v", err) + } + + j := capturedReq.Job + if j == nil { + t.Fatal("Job desc is nil in captured request") + } + + if j.Partition == nil || *j.Partition != "debug" { + t.Errorf("Partition = %v, want %q", j.Partition, "debug") + } + + if j.MinimumCpus != nil { + t.Errorf("MinimumCpus = %v, want nil (no cpus set)", j.MinimumCpus) + } + if j.TimeLimit == nil { + t.Errorf("TimeLimit = nil, want non-nil (default should be injected)") + } else if j.TimeLimit.Number == nil { + t.Errorf("TimeLimit.Number = nil, want non-nil") + } else if *j.TimeLimit.Number != int64(10080) { + t.Errorf("TimeLimit.Number = %d, want %d (default)", *j.TimeLimit.Number, int64(10080)) + } + if j.Qos != nil { + t.Errorf("Qos = %v, want nil (no qos set)", j.Qos) + } + if j.Name != nil { + t.Errorf("Name = %v, want nil (no job_name set)", j.Name) + } + if j.MemoryPerNode != nil { + t.Errorf("MemoryPerNode = %v, want nil", j.MemoryPerNode) + } + if j.MemoryPerCpu != nil { + t.Errorf("MemoryPerCpu = %v, want nil", j.MemoryPerCpu) + } + if j.Nodes != nil { + t.Errorf("Nodes = %v, want nil", j.Nodes) + } + if j.Tasks != nil { + t.Errorf("Tasks = %v, want nil", j.Tasks) + } + if j.CpusPerTask != nil { + t.Errorf("CpusPerTask = %v, want nil", j.CpusPerTask) + } + if j.Constraints != nil { + t.Errorf("Constraints = %v, want nil", j.Constraints) + } + if j.Reservation != nil { + t.Errorf("Reservation = %v, want nil", j.Reservation) + } + if j.Account != nil { + t.Errorf("Account = %v, want nil", j.Account) + } + if j.Nice != nil { + t.Errorf("Nice = %v, want nil", j.Nice) + } + if len(j.MailType) != 0 { + t.Errorf("MailType = %v, want empty", j.MailType) + } + if j.MailUser != nil { + t.Errorf("MailUser = %v, want nil", j.MailUser) + } + if j.StandardOutput == nil { + t.Errorf("StandardOutput = nil, want non-nil (default should be injected)") + } else if !strings.HasSuffix(*j.StandardOutput, "/slurm-%j.out") { + t.Errorf("StandardOutput = %q, want suffix /slurm-%%j.out (default)", *j.StandardOutput) + } + if j.StandardError == nil { + t.Errorf("StandardError = nil, want non-nil (default should be injected)") + } else if !strings.HasSuffix(*j.StandardError, "/slurm-%j.err") { + t.Errorf("StandardError = %q, want suffix /slurm-%%j.err (default)", *j.StandardError) + } + if j.StandardInput != nil { + t.Errorf("StandardInput = %v, want nil", j.StandardInput) + } + if len(j.RequiredNodes) != 0 { + t.Errorf("RequiredNodes = %v, want empty", j.RequiredNodes) + } + if len(j.ExcludedNodes) != 0 { + t.Errorf("ExcludedNodes = %v, want empty", j.ExcludedNodes) + } + if j.BeginTime != nil { + t.Errorf("BeginTime = %v, want nil", j.BeginTime) + } + if j.Deadline != nil { + t.Errorf("Deadline = %v, want nil", j.Deadline) + } + if j.Array != nil { + t.Errorf("Array = %v, want nil", j.Array) + } + if j.Dependency != nil { + t.Errorf("Dependency = %v, want nil", j.Dependency) + } + if j.Requeue != nil { + t.Errorf("Requeue = %v, want nil", j.Requeue) + } + if j.KillOnNodeFail != nil { + t.Errorf("KillOnNodeFail = %v, want nil", j.KillOnNodeFail) + } +}