package service

import (
	"context"
	"encoding/json"
	"fmt"
	"net/http"
	"net/http/httptest"
	"os"
	"path/filepath"
	"strings"
	"testing"
	"time"

	"gcy_hpc_server/internal/model"
	"gcy_hpc_server/internal/slurm"
	"gcy_hpc_server/internal/store"

	"go.uber.org/zap"
	"gorm.io/driver/sqlite"
	"gorm.io/gorm"
	gormlogger "gorm.io/gorm/logger"
)

// setupTaskTestDB opens an in-memory SQLite database with GORM logging
// silenced and auto-migrates the Task, Application, File and FileBlob models.
// Any failure aborts the calling test.
func setupTaskTestDB(t *testing.T) *gorm.DB {
	t.Helper()
	db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{
		Logger: gormlogger.Default.LogMode(gormlogger.Silent),
	})
	if err != nil {
		t.Fatalf("open sqlite: %v", err)
	}
	if err := db.AutoMigrate(&model.Task{}, &model.Application{}, &model.File{}, &model.FileBlob{}); err != nil {
		t.Fatalf("auto migrate: %v", err)
	}
	return db
}

// taskTestEnv bundles everything a TaskService test needs: the stores backed
// by the test database, the service under test, the fake Slurm HTTP server,
// and the base directory under which work directories are created.
type taskTestEnv struct {
	taskStore   *store.TaskStore
	appStore    *store.ApplicationStore
	fileStore   *store.FileStore
	blobStore   *store.BlobStore
	svc         *TaskService
	srv         *httptest.Server
	db          *gorm.DB
	workDirBase string
}

// newTaskTestEnv builds a taskTestEnv whose Slurm client talks to an httptest
// server driven by slurmHandler.
//
// slurmHandler may be nil for tests that never reach Slurm; a request against
// a nil handler would invoke a nil function, so such tests must not submit
// jobs. The caller is responsible for calling close on the returned value.
func newTaskTestEnv(t *testing.T, slurmHandler http.HandlerFunc) *taskTestEnv {
	t.Helper()

	db := setupTaskTestDB(t)
	ts := store.NewTaskStore(db)
	as := store.NewApplicationStore(db)
	fs := store.NewFileStore(db)
	bs := store.NewBlobStore(db)

	// Create the work directory before starting the server so that a failure
	// here cannot leak a listening socket.
	workDirBase := filepath.Join(t.TempDir(), "workdir")
	if err := os.MkdirAll(workDirBase, 0777); err != nil {
		t.Fatalf("create workdir base: %v", err)
	}

	srv := httptest.NewServer(slurmHandler)
	client, err := slurm.NewClient(srv.URL, srv.Client())
	if err != nil {
		srv.Close()
		t.Fatalf("new slurm client: %v", err)
	}
	jobSvc := NewJobService(client, zap.NewNop())

	// NOTE(review): the fifth argument is nil; presumably this is the staging
	// service dependency — confirm against NewTaskService.
	svc := NewTaskService(ts, as, fs, bs, nil, jobSvc, workDirBase, zap.NewNop())

	return &taskTestEnv{
		taskStore:   ts,
		appStore:    as,
		fileStore:   fs,
		blobStore:   bs,
		svc:         svc,
		srv:         srv,
		db:          db,
		workDirBase: workDirBase,
	}
}

// close shuts down the fake Slurm HTTP server.
func (e *taskTestEnv) close() {
	e.srv.Close()
}

// createApp inserts an application with the given name, script template and
// parameter schema, and returns its ID. Any store error aborts the test.
func (e *taskTestEnv) createApp(t *testing.T, name, script string, params json.RawMessage) int64 {
	t.Helper()
	id, err := e.appStore.Create(context.Background(), &model.CreateApplicationRequest{
		Name:           name,
		ScriptTemplate: script,
		Parameters:     params,
	})
	if err != nil {
		t.Fatalf("create app: %v", err)
	}
	return id
}

// TestTaskService_CreateTask_Success checks that CreateTask persists a task
// and echoes the application, name, status and values back to the caller.
func
TestTaskService_CreateTask_Success(t *testing.T) { env := newTaskTestEnv(t, nil) defer env.close() appID := env.createApp(t, "my-app", "#!/bin/bash\necho hello", json.RawMessage(`[]`)) task, err := env.svc.CreateTask(context.Background(), &model.CreateTaskRequest{ AppID: appID, TaskName: "test-task", Values: map[string]string{"KEY": "val"}, }) if err != nil { t.Fatalf("CreateTask: %v", err) } if task.ID == 0 { t.Error("expected non-zero task ID") } if task.AppID != appID { t.Errorf("AppID = %d, want %d", task.AppID, appID) } if task.AppName != "my-app" { t.Errorf("AppName = %q, want %q", task.AppName, "my-app") } if task.Status != model.TaskStatusSubmitted { t.Errorf("Status = %q, want %q", task.Status, model.TaskStatusSubmitted) } if task.TaskName != "test-task" { t.Errorf("TaskName = %q, want %q", task.TaskName, "test-task") } var values map[string]string if err := json.Unmarshal(task.Values, &values); err != nil { t.Fatalf("unmarshal values: %v", err) } if values["KEY"] != "val" { t.Errorf("values[KEY] = %q, want %q", values["KEY"], "val") } } func TestTaskService_CreateTask_InvalidAppID(t *testing.T) { env := newTaskTestEnv(t, nil) defer env.close() _, err := env.svc.CreateTask(context.Background(), &model.CreateTaskRequest{ AppID: 999, }) if err == nil { t.Fatal("expected error for invalid app_id") } if !strings.Contains(err.Error(), "not found") { t.Errorf("error should mention 'not found', got: %v", err) } } func TestTaskService_CreateTask_ExceedsFileLimit(t *testing.T) { env := newTaskTestEnv(t, nil) defer env.close() appID := env.createApp(t, "app", "#!/bin/bash\necho hi", nil) fileIDs := make([]int64, 101) for i := range fileIDs { fileIDs[i] = int64(i + 1) } _, err := env.svc.CreateTask(context.Background(), &model.CreateTaskRequest{ AppID: appID, InputFileIDs: fileIDs, }) if err == nil { t.Fatal("expected error for exceeding file limit") } if !strings.Contains(err.Error(), "exceeds limit") { t.Errorf("error should mention limit, got: %v", err) } } func 
TestTaskService_CreateTask_DuplicateFileIDs(t *testing.T) {
	env := newTaskTestEnv(t, nil)
	defer env.close()

	appID := env.createApp(t, "app", "#!/bin/bash\necho hi", nil)
	ctx := context.Background()

	// Seed two files; the test requires the store to hand out IDs 1 and 2.
	for _, id := range []int64{1, 2} {
		f := &model.File{
			Name:       "file.txt",
			BlobSHA256: "abc123",
		}
		if err := env.fileStore.Create(ctx, f); err != nil {
			t.Fatalf("create file: %v", err)
		}
		if f.ID != id {
			t.Fatalf("expected file ID %d, got %d", id, f.ID)
		}
	}

	task, err := env.svc.CreateTask(ctx, &model.CreateTaskRequest{
		AppID:        appID,
		InputFileIDs: []int64{1, 1, 2, 2},
	})
	if err != nil {
		t.Fatalf("CreateTask: %v", err)
	}

	var fileIDs []int64
	if err := json.Unmarshal(task.InputFileIDs, &fileIDs); err != nil {
		t.Fatalf("unmarshal file ids: %v", err)
	}
	if len(fileIDs) != 2 {
		t.Fatalf("expected 2 deduplicated file IDs, got %d: %v", len(fileIDs), fileIDs)
	}
	if fileIDs[0] != 1 || fileIDs[1] != 2 {
		t.Errorf("expected [1,2], got %v", fileIDs)
	}
}

// TestTaskService_CreateTask_AutoName checks that a task created without a
// name gets one derived from the application name, with spaces replaced by
// underscores.
func TestTaskService_CreateTask_AutoName(t *testing.T) {
	env := newTaskTestEnv(t, nil)
	defer env.close()

	appID := env.createApp(t, "My Cool App", "#!/bin/bash\necho hi", nil)
	task, err := env.svc.CreateTask(context.Background(), &model.CreateTaskRequest{
		AppID: appID,
	})
	if err != nil {
		t.Fatalf("CreateTask: %v", err)
	}
	if !strings.HasPrefix(task.TaskName, "My_Cool_App_") {
		t.Errorf("auto-generated name should start with 'My_Cool_App_', got %q", task.TaskName)
	}
}

// TestTaskService_CreateTask_NilValues checks that nil Values are stored as
// an empty JSON object.
func TestTaskService_CreateTask_NilValues(t *testing.T) {
	env := newTaskTestEnv(t, nil)
	defer env.close()

	appID := env.createApp(t, "app", "#!/bin/bash\necho hi", nil)
	task, err := env.svc.CreateTask(context.Background(), &model.CreateTaskRequest{
		AppID:  appID,
		Values: nil,
	})
	if err != nil {
		t.Fatalf("CreateTask: %v", err)
	}
	if string(task.Values) != `{}` {
		t.Errorf("Values = %q, want {}", string(task.Values))
	}
}

// taskTestSubmitHandler returns a fake Slurm handler that answers every
// request with a successful job-submit response carrying jobID.
//
// Handlers run on the HTTP server's goroutine, not the test goroutine, so
// failures are reported with t.Errorf rather than t.Fatalf.
func taskTestSubmitHandler(t *testing.T, jobID int32) http.HandlerFunc {
	t.Helper()
	return func(w http.ResponseWriter, r *http.Request) {
		resp := slurm.OpenapiJobSubmitResponse{
			Result: &slurm.JobSubmitResponseMsg{JobID: &jobID},
		}
		if err := json.NewEncoder(w).Encode(resp); err != nil {
			t.Errorf("encode submit response: %v", err)
		}
	}
}

// TestTaskService_ProcessTask_Success checks the happy path of ProcessTask:
// the task ends up queued with the Slurm job ID recorded and a work directory
// created under the configured base.
func TestTaskService_ProcessTask_Success(t *testing.T) {
	env := newTaskTestEnv(t, taskTestSubmitHandler(t, 42))
	defer env.close()

	appID := env.createApp(t, "test-app", "#!/bin/bash\necho $INPUT", json.RawMessage(`[{"name":"INPUT","type":"string","required":true}]`))
	task, err := env.svc.CreateTask(context.Background(), &model.CreateTaskRequest{
		AppID:  appID,
		Values: map[string]string{"INPUT": "hello"},
	})
	if err != nil {
		t.Fatalf("CreateTask: %v", err)
	}

	err = env.svc.ProcessTask(context.Background(), task.ID)
	if err != nil {
		t.Fatalf("ProcessTask: %v", err)
	}

	updated, err := env.taskStore.GetByID(context.Background(), task.ID)
	if err != nil {
		t.Fatalf("GetByID: %v", err)
	}
	if updated.Status != model.TaskStatusQueued {
		t.Errorf("Status = %q, want %q", updated.Status, model.TaskStatusQueued)
	}
	if updated.SlurmJobID == nil || *updated.SlurmJobID != 42 {
		t.Errorf("SlurmJobID = %v, want 42", updated.SlurmJobID)
	}
	if updated.WorkDir == "" {
		t.Error("WorkDir should not be empty")
	}
	if !strings.HasPrefix(updated.WorkDir, env.workDirBase) {
		t.Errorf("WorkDir = %q, should start with %q", updated.WorkDir, env.workDirBase)
	}

	info, err := os.Stat(updated.WorkDir)
	if err != nil {
		t.Fatalf("stat workdir: %v", err)
	}
	if !info.IsDir() {
		t.Error("WorkDir should be a directory")
	}
}

// TestTaskService_ProcessTask_TaskNotFound checks that processing an unknown
// task ID fails with an error that mentions "not found".
func TestTaskService_ProcessTask_TaskNotFound(t *testing.T) {
	env := newTaskTestEnv(t, nil)
	defer env.close()

	err := env.svc.ProcessTask(context.Background(), 999)
	if err == nil {
		t.Fatal("expected error for non-existent task")
	}
	if !strings.Contains(err.Error(), "not found") {
		t.Errorf("error should mention 'not found', got: %v", err)
	}
}

// TestTaskService_ProcessTask_SlurmError checks that a 500 response from
// Slurm fails the task at the submitting step and records an error message
// that mentions "submit job".
func TestTaskService_ProcessTask_SlurmError(t *testing.T) {
	env := newTaskTestEnv(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusInternalServerError)
		w.Write([]byte(`{"error":"slurm down"}`))
	}))
	defer env.close()

	appID := env.createApp(t, "test-app", "#!/bin/bash\necho hello", nil)
	task, err := env.svc.CreateTask(context.Background(), &model.CreateTaskRequest{
		AppID: appID,
	})
	if err != nil {
		t.Fatalf("CreateTask: %v", err)
	}

	err = env.svc.ProcessTask(context.Background(), task.ID)
	if err == nil {
		t.Fatal("expected error from Slurm")
	}

	updated, err := env.taskStore.GetByID(context.Background(), task.ID)
	if err != nil {
		t.Fatalf("GetByID: %v", err)
	}
	if updated.Status != model.TaskStatusFailed {
		t.Errorf("Status = %q, want %q", updated.Status, model.TaskStatusFailed)
	}
	if updated.CurrentStep != model.TaskStepSubmitting {
		t.Errorf("CurrentStep = %q, want %q", updated.CurrentStep, model.TaskStepSubmitting)
	}
	if !strings.Contains(updated.ErrorMessage, "submit job") {
		t.Errorf("ErrorMessage should mention 'submit job', got: %q", updated.ErrorMessage)
	}
}

// TestTaskService_ProcessTaskSync checks that ProcessTaskSync returns the job
// ID reported by Slurm.
func TestTaskService_ProcessTaskSync(t *testing.T) {
	env := newTaskTestEnv(t, taskTestSubmitHandler(t, 42))
	defer env.close()

	appID := env.createApp(t, "sync-app", "#!/bin/bash\necho hello", nil)
	resp, err := env.svc.ProcessTaskSync(context.Background(), &model.CreateTaskRequest{
		AppID: appID,
	})
	if err != nil {
		t.Fatalf("ProcessTaskSync: %v", err)
	}
	if resp.JobID != 42 {
		t.Errorf("JobID = %d, want 42", resp.JobID)
	}
}

// TestTaskService_ProcessTaskSync_NoMinIO checks that ProcessTaskSync works
// when the request carries no input files.
func TestTaskService_ProcessTaskSync_NoMinIO(t *testing.T) {
	env := newTaskTestEnv(t, taskTestSubmitHandler(t, 42))
	defer env.close()

	appID := env.createApp(t, "no-minio-app", "#!/bin/bash\necho hello", nil)
	resp, err := env.svc.ProcessTaskSync(context.Background(), &model.CreateTaskRequest{
		AppID:        appID,
		InputFileIDs: nil,
	})
	if err != nil {
		t.Fatalf("ProcessTaskSync: %v", err)
	}
	if resp.JobID != 42 {
		t.Errorf("JobID = %d, want 42", resp.JobID)
	}
}

// TestTaskService_ProcessTask_NilValues checks that a task created with nil
// Values can still be processed and ends up queued.
func TestTaskService_ProcessTask_NilValues(t *testing.T) {
	env := newTaskTestEnv(t, taskTestSubmitHandler(t, 55))
	defer env.close()

	appID := env.createApp(t, "nil-val-app", "#!/bin/bash\necho hello", nil)
	task, err := env.svc.CreateTask(context.Background(), &model.CreateTaskRequest{
		AppID:  appID,
		Values: nil,
	})
	if err != nil {
		t.Fatalf("CreateTask: %v", err)
	}

	err = env.svc.ProcessTask(context.Background(), task.ID)
	if err != nil {
		t.Fatalf("ProcessTask: %v", err)
	}

	updated, err := env.taskStore.GetByID(context.Background(), task.ID)
	if err != nil {
		t.Fatalf("GetByID: %v", err)
	}
	if updated.Status != model.TaskStatusQueued {
		t.Errorf("Status = %q, want %q", updated.Status, model.TaskStatusQueued)
	}
}

// TestTaskService_ListTasks checks that ListTasks returns all three created
// tasks and reports a matching total.
func TestTaskService_ListTasks(t *testing.T) {
	env := newTaskTestEnv(t, nil)
	defer env.close()

	appID := env.createApp(t, "list-app", "#!/bin/bash\necho hi", nil)
	for i := 0; i < 3; i++ {
		_, err := env.svc.CreateTask(context.Background(), &model.CreateTaskRequest{
			AppID:    appID,
			TaskName: "task-" + string(rune('A'+i)),
		})
		if err != nil {
			t.Fatalf("CreateTask %d: %v", i, err)
		}
	}

	tasks, total, err := env.svc.ListTasks(context.Background(), &model.TaskListQuery{
		Page:     1,
		PageSize: 10,
	})
	if err != nil {
		t.Fatalf("ListTasks: %v", err)
	}
	if total != 3 {
		t.Errorf("total = %d, want 3", total)
	}
	if len(tasks) != 3 {
		t.Errorf("len(tasks) = %d, want 3", len(tasks))
	}
}

// TestTaskService_ProcessTask_ValidateParams_MissingRequired checks that
// ProcessTask fails when a required application parameter is not supplied.
func TestTaskService_ProcessTask_ValidateParams_MissingRequired(t *testing.T) {
	env := newTaskTestEnv(t, taskTestSubmitHandler(t, 42))
	defer env.close()

	// App requires INPUT param, but we submit without it
	appID := env.createApp(t, "validation-app", "#!/bin/bash\necho $INPUT", json.RawMessage(`[{"name":"INPUT","type":"string","required":true}]`))
	task, err := env.svc.CreateTask(context.Background(), &model.CreateTaskRequest{
		AppID:  appID,
		Values: map[string]string{}, // missing required INPUT
	})
	if err != nil {
		t.Fatalf("CreateTask: %v", err)
	}

	err = env.svc.ProcessTask(context.Background(), task.ID)
	if err == nil {
		t.Fatal("expected error for missing required parameter, got nil — ValidateParams is not being called in ProcessTask pipeline")
	}
	errStr := err.Error()
	if !strings.Contains(errStr, "validation") && !strings.Contains(errStr, "missing") && !strings.Contains(errStr, "INPUT") {
		t.Errorf("error should mention 'validation', 'missing', or 'INPUT', got: %v", err)
	}
}

// TestTaskService_ProcessTask_ValidateParams_InvalidInteger checks that
// ProcessTask fails when an integer-typed parameter receives a non-numeric
// value.
func TestTaskService_ProcessTask_ValidateParams_InvalidInteger(t *testing.T) {
	env := newTaskTestEnv(t, taskTestSubmitHandler(t, 42))
	defer env.close()

	// App expects integer param NUM, but we submit "abc"
	appID := env.createApp(t, "int-validation-app", "#!/bin/bash\necho $NUM", json.RawMessage(`[{"name":"NUM","type":"integer","required":true}]`))
	task, err := env.svc.CreateTask(context.Background(), &model.CreateTaskRequest{
		AppID:  appID,
		Values: map[string]string{"NUM": "abc"}, // invalid integer
	})
	if err != nil {
		t.Fatalf("CreateTask: %v", err)
	}

	err = env.svc.ProcessTask(context.Background(), task.ID)
	if err == nil {
		t.Fatal("expected error for invalid integer parameter, got nil — ValidateParams is not being called in ProcessTask pipeline")
	}
	errStr := err.Error()
	if !strings.Contains(errStr, "integer") && !strings.Contains(errStr, "validation") && !strings.Contains(errStr, "NUM") {
		t.Errorf("error should mention 'integer', 'validation', or 'NUM', got: %v", err)
	}
}

// TestTaskService_ProcessTask_ValidateParams_ValidParamsSucceed checks that a
// task whose values satisfy the application's parameter schema is queued with
// the Slurm job ID recorded.
func TestTaskService_ProcessTask_ValidateParams_ValidParamsSucceed(t *testing.T) {
	env := newTaskTestEnv(t, taskTestSubmitHandler(t, 99))
	defer env.close()

	appID := env.createApp(t, "valid-params-app", "#!/bin/bash\necho $INPUT", json.RawMessage(`[{"name":"INPUT","type":"string","required":true}]`))
	task, err := env.svc.CreateTask(context.Background(), &model.CreateTaskRequest{
		AppID:  appID,
		Values: map[string]string{"INPUT": "hello"},
	})
	if err != nil {
		t.Fatalf("CreateTask: %v", err)
	}

	err = env.svc.ProcessTask(context.Background(), task.ID)
	if err != nil {
		t.Fatalf("ProcessTask with valid params: %v", err)
	}

	updated, err := env.taskStore.GetByID(context.Background(), task.ID)
	if err != nil {
		t.Fatalf("GetByID: %v", err)
	}
	if updated.Status != model.TaskStatusQueued {
		t.Errorf("Status = %q, want %q", updated.Status, model.TaskStatusQueued)
	}
	if updated.SlurmJobID == nil || *updated.SlurmJobID != 99 {
		t.Errorf("SlurmJobID = %v, want 99", updated.SlurmJobID)
	}
}

// TestTaskService_ProcessTask_FileParamResolution checks that a file-typed
// parameter whose value is the ID of an existing file lets ProcessTask
// succeed and queue the task.
func TestTaskService_ProcessTask_FileParamResolution(t *testing.T) {
	env := newTaskTestEnv(t, taskTestSubmitHandler(t, 88))
	defer env.close()

	ctx := context.Background()

	blob := &model.FileBlob{SHA256: "deadbeef", MinioKey: "test/file.bin", FileSize: 1024}
	if err := env.db.Create(blob).Error; err != nil {
		t.Fatalf("create blob: %v", err)
	}
	file := &model.File{Name: "model_weights.bin", BlobSHA256: blob.SHA256}
	if err := env.fileStore.Create(ctx, file); err != nil {
		t.Fatalf("create file: %v", err)
	}

	appID := env.createApp(t, "file-param-app", "#!/bin/bash\n#SBATCH --chdir=$WORK_DIR\npython train.py --model $MODEL_PATH", json.RawMessage(`[{"name":"MODEL_PATH","type":"file","required":true}]`))

	// No InputFileIDs because stagingSvc is nil in tests; file is still resolvable via fileStore.
	task, err := env.svc.CreateTask(ctx, &model.CreateTaskRequest{
		AppID:  appID,
		Values: map[string]string{"MODEL_PATH": fmt.Sprintf("%d", file.ID)},
	})
	if err != nil {
		t.Fatalf("CreateTask: %v", err)
	}

	err = env.svc.ProcessTask(ctx, task.ID)
	if err != nil {
		t.Fatalf("ProcessTask with file param: %v", err)
	}

	updated, err := env.taskStore.GetByID(ctx, task.ID)
	if err != nil {
		t.Fatalf("GetByID: %v", err)
	}
	if updated.Status != model.TaskStatusQueued {
		t.Errorf("Status = %q, want %q", updated.Status, model.TaskStatusQueued)
	}
}

// TestTaskService_ProcessTask_WorkDirInScript checks that processing a task
// whose script template references $WORK_DIR leaves WorkDir set on the task.
func TestTaskService_ProcessTask_WorkDirInScript(t *testing.T) {
	env := newTaskTestEnv(t, taskTestSubmitHandler(t, 77))
	defer env.close()

	appID := env.createApp(t, "workdir-app", "#!/bin/bash\n#SBATCH --chdir=$WORK_DIR\necho hello", json.RawMessage(`[]`))
	task, err := env.svc.CreateTask(context.Background(), &model.CreateTaskRequest{
		AppID:  appID,
		Values: map[string]string{},
	})
	if err != nil {
		t.Fatalf("CreateTask: %v", err)
	}

	err = env.svc.ProcessTask(context.Background(), task.ID)
	if err != nil {
		t.Fatalf("ProcessTask: %v", err)
	}

	updated, err := env.taskStore.GetByID(context.Background(), task.ID)
	if err != nil {
		t.Fatalf("GetByID: %v", err)
	}
	if updated.WorkDir == "" {
		t.Fatal("WorkDir should be set")
	}
}

// TestTaskService_ProcessTask_FileParamInvalidID checks that a file-typed
// parameter with a non-numeric value makes ProcessTask fail with an error
// that mentions "file ID".
func TestTaskService_ProcessTask_FileParamInvalidID(t *testing.T) {
	env := newTaskTestEnv(t, taskTestSubmitHandler(t, 42))
	defer env.close()

	appID := env.createApp(t, "bad-file-app", "#!/bin/bash\npython train.py --model $MODEL", json.RawMessage(`[{"name":"MODEL","type":"file","required":true}]`))
	task, err := env.svc.CreateTask(context.Background(), &model.CreateTaskRequest{
		AppID:  appID,
		Values: map[string]string{"MODEL": "not_a_number"},
	})
	if err != nil {
		t.Fatalf("CreateTask: %v", err)
	}

	err = env.svc.ProcessTask(context.Background(), task.ID)
	if err == nil {
		t.Fatal("expected error for non-numeric file ID in file-type param")
	}
	if !strings.Contains(err.Error(), "file ID") {
		t.Errorf("error should mention 'file ID', got: %v", err)
	}
}

// TestTaskService_ProcessTask_FileParamNotInInputFiles checks that a
// file-typed parameter naming a file ID that does not exist makes ProcessTask
// fail with an error that mentions "not found".
func TestTaskService_ProcessTask_FileParamNotInInputFiles(t *testing.T) {
	env := newTaskTestEnv(t, taskTestSubmitHandler(t, 42))
	defer env.close()

	appID := env.createApp(t, "missing-file-app", "#!/bin/bash\npython train.py --model $MODEL", json.RawMessage(`[{"name":"MODEL","type":"file","required":true}]`))
	task, err := env.svc.CreateTask(context.Background(), &model.CreateTaskRequest{
		AppID:  appID,
		Values: map[string]string{"MODEL": "999"},
	})
	if err != nil {
		t.Fatalf("CreateTask: %v", err)
	}

	err = env.svc.ProcessTask(context.Background(), task.ID)
	if err == nil {
		t.Fatal("expected error for file_id not in input_file_ids")
	}
	if !strings.Contains(err.Error(), "not found") {
		t.Errorf("error should mention 'not found', got: %v", err)
	}
}

// --- Pointer helpers for scheduling field tests ---

// int64Ptr returns a pointer to a copy of i.
func int64Ptr(i int64) *int64 { return &i }

// boolPtr returns a pointer to a copy of b.
func boolPtr(b bool) *bool { return &b }

// TestDerefInt32ToStr checks derefInt32ToStr for a nil pointer, zero, and a
// positive value.
func TestDerefInt32ToStr(t *testing.T) {
	t.Run("nil_returns_empty", func(t *testing.T) {
		if got := derefInt32ToStr(nil); got != "" {
			t.Errorf("derefInt32ToStr(nil) = %q, want %q", got, "")
		}
	})
	t.Run("zero_returns_zero_string", func(t *testing.T) {
		if got := derefInt32ToStr(int32Ptr(0)); got != "0" {
			t.Errorf("derefInt32ToStr(0) = %q, want %q", got, "0")
		}
	})
	t.Run("positive_returns_string", func(t *testing.T) {
		if got := derefInt32ToStr(int32Ptr(60)); got != "60" {
			t.Errorf("derefInt32ToStr(60) = %q, want %q", got, "60")
		}
	})
}

// TestCreateTask_SchedulingFields checks that every scheduling field supplied
// on the create request is copied onto the created task.
func TestCreateTask_SchedulingFields(t *testing.T) {
	env := newTaskTestEnv(t, nil)
	defer env.close()

	appID := env.createApp(t, "sched-app", "#!/bin/bash\necho hello", json.RawMessage(`[]`))

	var (
		partition      = "gpu"
		cpus           = int32(8)
		memoryPerNode  = int64(4096)
		memoryPerCpu   = int64(1024)
		timeLimit      = int32(60)
		qos            = "high"
		jobName        = "test-job"
		nodes          = "2"
		tasks          = int32(4)
		cpusPerTask    = int32(2)
		constraints    = "gpu&fast"
		reservation    = "resv01"
		account        = "proj-alpha"
		nice           = int32(100)
		mailType       = "BEGIN,END"
		mailUser       = "admin@example.com"
		stdOut         = "/tmp/job_%j.out"
		stdErr         = "/tmp/job_%j.err"
		stdIn          = "/dev/null"
		reqNodes       = "node01,node02"
		exclNodes      = "node03,node04"
		beginTime      = int64(1700000000)
		deadline       = int64(1700099999)
		array          = "1-10"
		dependency     = "afterok:123"
		requeue        = true
		killOnNodeFail = true
	)

	task, err := env.svc.CreateTask(context.Background(), &model.CreateTaskRequest{
		AppID:          appID,
		TaskName:       "sched-task",
		Partition:      &partition,
		Cpus:           &cpus,
		MemoryPerNode:  &memoryPerNode,
		MemoryPerCpu:   &memoryPerCpu,
		TimeLimit:      &timeLimit,
		QOS:            &qos,
		JobName:        &jobName,
		Nodes:          &nodes,
		Tasks:          &tasks,
		CpusPerTask:    &cpusPerTask,
		Constraints:    &constraints,
		Reservation:    &reservation,
		Account:        &account,
		Nice:           &nice,
		MailType:       &mailType,
		MailUser:       &mailUser,
		StandardOutput: &stdOut,
		StandardError:  &stdErr,
		StandardInput:  &stdIn,
		RequiredNodes:  &reqNodes,
		ExcludedNodes:  &exclNodes,
		BeginTime:      &beginTime,
		Deadline:       &deadline,
		Array:          &array,
		Dependency:     &dependency,
		Requeue:        &requeue,
		KillOnNodeFail: &killOnNodeFail,
	})
	if err != nil {
		t.Fatalf("CreateTask: %v", err)
	}

	if task.Partition != partition {
		t.Errorf("Partition = %q, want %q", task.Partition, partition)
	}
	if task.Cpus == nil || *task.Cpus != cpus {
		t.Errorf("Cpus = %v, want %d", task.Cpus, cpus)
	}
	if task.MemoryPerNode == nil || *task.MemoryPerNode != memoryPerNode {
		t.Errorf("MemoryPerNode = %v, want %d", task.MemoryPerNode, memoryPerNode)
	}
	if task.MemoryPerCpu == nil || *task.MemoryPerCpu != memoryPerCpu {
		t.Errorf("MemoryPerCpu = %v, want %d", task.MemoryPerCpu, memoryPerCpu)
	}
	if task.TimeLimit == nil || *task.TimeLimit != timeLimit {
		t.Errorf("TimeLimit = %v, want %d", task.TimeLimit, timeLimit)
	}
	if task.QOS == nil || *task.QOS != qos {
		t.Errorf("QOS = %v, want %q", task.QOS, qos)
	}
	if task.JobName == nil || *task.JobName != jobName {
		t.Errorf("JobName = %v, want %q", task.JobName, jobName)
	}
	if task.Nodes == nil || *task.Nodes != nodes {
		t.Errorf("Nodes = %v, want %q", task.Nodes, nodes)
	}
	if task.Tasks == nil || *task.Tasks != tasks {
		t.Errorf("Tasks = %v, want %d", task.Tasks, tasks)
	}
	if task.CpusPerTask == nil || *task.CpusPerTask != cpusPerTask {
		t.Errorf("CpusPerTask = %v, want %d", task.CpusPerTask, cpusPerTask)
	}
	if task.Constraints == nil || *task.Constraints != constraints {
		t.Errorf("Constraints = %v, want %q", task.Constraints, constraints)
	}
	if task.Reservation == nil || *task.Reservation != reservation {
		t.Errorf("Reservation = %v, want %q", task.Reservation, reservation)
	}
	if task.Account == nil || *task.Account != account {
		t.Errorf("Account = %v, want %q", task.Account, account)
	}
	if task.Nice == nil || *task.Nice != nice {
		t.Errorf("Nice = %v, want %d", task.Nice, nice)
	}
	if task.MailType == nil || *task.MailType != mailType {
		t.Errorf("MailType = %v, want %q", task.MailType, mailType)
	}
	if task.MailUser == nil || *task.MailUser != mailUser {
		t.Errorf("MailUser = %v, want %q", task.MailUser, mailUser)
	}
	if task.StandardOutput == nil || *task.StandardOutput != stdOut {
		t.Errorf("StandardOutput = %v, want %q", task.StandardOutput, stdOut)
	}
	if task.StandardError == nil || *task.StandardError != stdErr {
		t.Errorf("StandardError = %v, want %q", task.StandardError, stdErr)
	}
	if task.StandardInput == nil || *task.StandardInput != stdIn {
		t.Errorf("StandardInput = %v, want %q", task.StandardInput, stdIn)
	}
	if task.RequiredNodes == nil || *task.RequiredNodes != reqNodes {
		t.Errorf("RequiredNodes = %v, want %q", task.RequiredNodes, reqNodes)
	}
	if task.ExcludedNodes == nil || *task.ExcludedNodes != exclNodes {
		t.Errorf("ExcludedNodes = %v, want %q", task.ExcludedNodes, exclNodes)
	}
	if task.BeginTime == nil || *task.BeginTime != beginTime {
		t.Errorf("BeginTime = %v, want %d", task.BeginTime, beginTime)
	}
	if task.Deadline == nil || *task.Deadline != deadline {
		t.Errorf("Deadline = %v, want %d", task.Deadline, deadline)
	}
	if task.Array == nil || *task.Array != array {
		t.Errorf("Array = %v, want %q", task.Array, array)
	}
	if task.Dependency == nil || *task.Dependency != dependency {
		t.Errorf("Dependency = %v, want %q", task.Dependency, dependency)
	}
	if task.Requeue == nil || *task.Requeue != requeue {
		t.Errorf("Requeue = %v, want %v", task.Requeue, requeue)
	}
	if task.KillOnNodeFail == nil || *task.KillOnNodeFail != killOnNodeFail {
		t.Errorf("KillOnNodeFail = %v, want %v", task.KillOnNodeFail, killOnNodeFail)
	}
}

// TestCreateTask_BackwardCompat checks that a request carrying no scheduling
// fields yields a task whose scheduling fields are all unset.
func TestCreateTask_BackwardCompat(t *testing.T) {
	env := newTaskTestEnv(t, nil)
	defer env.close()

	appID := env.createApp(t, "compat-app", "#!/bin/bash\necho hello", json.RawMessage(`[]`))
	task, err := env.svc.CreateTask(context.Background(), &model.CreateTaskRequest{
		AppID:    appID,
		TaskName: "compat-task",
	})
	if err != nil {
		t.Fatalf("CreateTask: %v", err)
	}

	if task.Partition != "" {
		t.Errorf("Partition = %q, want empty", task.Partition)
	}
	if task.Cpus != nil {
		t.Errorf("Cpus = %v, want nil", task.Cpus)
	}
	if task.MemoryPerNode != nil {
		t.Errorf("MemoryPerNode = %v, want nil", task.MemoryPerNode)
	}
	if task.MemoryPerCpu != nil {
		t.Errorf("MemoryPerCpu = %v, want nil", task.MemoryPerCpu)
	}
	if task.TimeLimit != nil {
		t.Errorf("TimeLimit = %v, want nil", task.TimeLimit)
	}
	if task.QOS != nil {
		t.Errorf("QOS = %v, want nil", task.QOS)
	}
	if task.JobName != nil {
		t.Errorf("JobName = %v, want nil", task.JobName)
	}
	if task.Nodes != nil {
		t.Errorf("Nodes = %v, want nil", task.Nodes)
	}
	if task.Tasks != nil {
		t.Errorf("Tasks = %v, want nil", task.Tasks)
	}
	if task.CpusPerTask != nil {
		t.Errorf("CpusPerTask = %v, want nil", task.CpusPerTask)
	}
	if task.Constraints != nil {
		t.Errorf("Constraints = %v, want nil", task.Constraints)
	}
	if task.Reservation != nil {
		t.Errorf("Reservation = %v, want nil", task.Reservation)
	}
	if task.Account != nil {
		t.Errorf("Account = %v, want nil", task.Account)
	}
	if task.Nice != nil {
		t.Errorf("Nice = %v, want nil", task.Nice)
	}
	if task.MailType != nil {
		t.Errorf("MailType = %v, want nil", task.MailType)
	}
	if task.MailUser != nil {
		t.Errorf("MailUser = %v, want nil", task.MailUser)
	}
	if task.StandardOutput != nil {
		t.Errorf("StandardOutput = %v, want nil", task.StandardOutput)
	}
	if task.StandardError != nil {
		t.Errorf("StandardError = %v, want nil", task.StandardError)
	}
	if task.StandardInput != nil {
		t.Errorf("StandardInput = %v, want nil", task.StandardInput)
	}
	if task.RequiredNodes != nil {
		t.Errorf("RequiredNodes = %v, want nil", task.RequiredNodes)
	}
	if task.ExcludedNodes != nil {
		t.Errorf("ExcludedNodes = %v, want nil", task.ExcludedNodes)
	}
	if task.BeginTime != nil {
		t.Errorf("BeginTime = %v, want nil", task.BeginTime)
	}
	if task.Deadline != nil {
		t.Errorf("Deadline = %v, want nil", task.Deadline)
	}
	if task.Array != nil {
		t.Errorf("Array = %v, want nil", task.Array)
	}
	if task.Dependency != nil {
		t.Errorf("Dependency = %v, want nil", task.Dependency)
	}
	if task.Requeue != nil {
		t.Errorf("Requeue = %v, want nil", task.Requeue)
	}
	if task.KillOnNodeFail != nil {
		t.Errorf("KillOnNodeFail = %v, want nil", task.KillOnNodeFail)
	}
}

// TestProcessTask_SchedulingParams stores a task with every scheduling field
// set, processes it, and checks the job description captured from the request
// sent to the fake Slurm server.
func TestProcessTask_SchedulingParams(t *testing.T) {
	jobID := int32(42)

	var (
		cpus           = int32(8)
		memoryPerNode  = int64(4096)
		memoryPerCpu   = int64(1024)
		timeLimit      = int32(60)
		qos            = "high"
		jobName        = "test-job"
		nodes          = "2"
		tasks          = int32(4)
		cpusPerTask    = int32(2)
		constraints    = "gpu&fast"
		reservation    = "resv01"
		account        = "proj-alpha"
		nice           = int32(100)
		mailType       = "BEGIN,END"
		mailUser       = "admin@example.com"
		stdOut         = "/tmp/job_%j.out"
		stdErr         = "/tmp/job_%j.err"
		stdIn          = "/dev/null"
		reqNodes       = "node01,node02"
		exclNodes      = "node03,node04"
		beginTime      = int64(1700000000)
		deadline       = int64(1700099999)
		array          = "1-10"
		dependency     = "afterok:123"
		requeue        = true
		killOnNodeFail = true
	)

	var capturedReq slurm.JobSubmitReq
	env := newTaskTestEnv(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if err := json.NewDecoder(r.Body).Decode(&capturedReq); err != nil {
			// The handler runs off the test goroutine, where t.Fatalf is not
			// allowed: record the failure and reject the request so that
			// ProcessTask fails on the test goroutine instead.
			t.Errorf("decode request body: %v", err)
			http.Error(w, "bad request body", http.StatusBadRequest)
			return
		}
		resp := slurm.OpenapiJobSubmitResponse{
			Result: &slurm.JobSubmitResponseMsg{JobID: &jobID},
		}
		if err := json.NewEncoder(w).Encode(resp); err != nil {
			t.Errorf("encode submit response: %v", err)
		}
	}))
	defer env.close()

	appID := env.createApp(t, "sched-app", "#!/bin/bash\necho hello", json.RawMessage(`[]`))
	ctx := context.Background()

	task := &model.Task{
		AppName:        "sched-app",
		AppID:          appID,
		Status:         model.TaskStatusSubmitted,
		SubmittedAt:    time.Now(),
		Partition:      "gpu",
		Cpus:           &cpus,
		MemoryPerNode:  &memoryPerNode,
		MemoryPerCpu:   &memoryPerCpu,
		TimeLimit:      &timeLimit,
		QOS:            &qos,
		JobName:        &jobName,
		Nodes:          &nodes,
		Tasks:          &tasks,
		CpusPerTask:    &cpusPerTask,
		Constraints:    &constraints,
		Reservation:    &reservation,
		Account:        &account,
		Nice:           &nice,
		MailType:       &mailType,
		MailUser:       &mailUser,
		StandardOutput: &stdOut,
		StandardError:  &stdErr,
		StandardInput:  &stdIn,
		RequiredNodes:  &reqNodes,
		ExcludedNodes:  &exclNodes,
		BeginTime:      &beginTime,
		Deadline:       &deadline,
		Array:          &array,
		Dependency:     &dependency,
		Requeue:        &requeue,
		KillOnNodeFail: &killOnNodeFail,
	}
	taskID, err := env.taskStore.Create(ctx, task)
	if err != nil {
		t.Fatalf("create task in DB: %v", err)
	}

	err = env.svc.ProcessTask(ctx, taskID)
	if err != nil {
		t.Fatalf("ProcessTask: %v", err)
	}

	j := capturedReq.Job
	if j == nil {
		t.Fatal("Job desc is nil in captured request")
	}

	if j.Partition == nil || *j.Partition != "gpu" {
		t.Errorf("Partition = %v, want %q", j.Partition, "gpu")
	}
	// CPUs=8 maps to CpusPerTask, then overridden by explicit CpusPerTask=2
	if j.CpusPerTask == nil || *j.CpusPerTask != int32(2) {
		t.Errorf("CpusPerTask = %v, want 2 (explicit CpusPerTask overrides CPUs)", j.CpusPerTask)
	}
	if j.MinimumCpus != nil {
		t.Errorf("MinimumCpus should be nil, got %v", j.MinimumCpus)
	}
	if j.TimeLimit == nil || j.TimeLimit.Number == nil || *j.TimeLimit.Number != int64(60) {
		t.Errorf("TimeLimit = %v, want 60", j.TimeLimit)
	}
	if j.Qos == nil || *j.Qos != "high" {
		t.Errorf("Qos = %v, want %q", j.Qos, "high")
	}
	if j.Name == nil || *j.Name != "test-job" {
		t.Errorf("Name = %v, want %q", j.Name, "test-job")
	}
	if j.MemoryPerNode == nil || j.MemoryPerNode.Number == nil || *j.MemoryPerNode.Number != int64(4096) {
		t.Errorf("MemoryPerNode = %v, want 4096", j.MemoryPerNode)
	}
	if j.MemoryPerCpu == nil || j.MemoryPerCpu.Number == nil || *j.MemoryPerCpu.Number != int64(1024) {
		t.Errorf("MemoryPerCpu = %v, want 1024", j.MemoryPerCpu)
	}
	if j.Nodes == nil || *j.Nodes != "2" {
		t.Errorf("Nodes = %v, want %q", j.Nodes, "2")
	}
	if j.Tasks == nil || *j.Tasks != int32(4) {
		t.Errorf("Tasks = %v, want 4", j.Tasks)
	}
	if j.Constraints == nil || *j.Constraints != "gpu&fast" {
		t.Errorf("Constraints = %v, want %q", j.Constraints, "gpu&fast")
	}
	if j.Reservation == nil || *j.Reservation != "resv01" {
		t.Errorf("Reservation = %v, want %q", j.Reservation, "resv01")
	}
	if j.Account == nil || *j.Account != "proj-alpha" {
		t.Errorf("Account = %v, want %q", j.Account, "proj-alpha")
	}
	if j.Nice == nil || *j.Nice != int32(100) {
		t.Errorf("Nice = %v, want 100", j.Nice)
	}
	// MailType is split by comma in job_service.go
	if len(j.MailType) != 2 || j.MailType[0] != "BEGIN" || j.MailType[1] != "END" {
		t.Errorf("MailType = %v, want [BEGIN, END]", j.MailType)
	}
	if j.MailUser == nil || *j.MailUser != "admin@example.com" {
		t.Errorf("MailUser = %v, want %q", j.MailUser, "admin@example.com")
	}
	if j.StandardOutput == nil || *j.StandardOutput != "/tmp/job_%j.out" {
		t.Errorf("StandardOutput = %v, want %q", j.StandardOutput, "/tmp/job_%j.out")
	}
	if j.StandardError == nil || *j.StandardError != "/tmp/job_%j.err" {
		t.Errorf("StandardError = %v, want %q", j.StandardError, "/tmp/job_%j.err")
	}
	if j.StandardInput == nil || *j.StandardInput != "/dev/null" {
		t.Errorf("StandardInput = %v, want %q", j.StandardInput, "/dev/null")
	}
	// RequiredNodes/ExcludedNodes are split by comma in job_service.go
	if len(j.RequiredNodes) != 2 || j.RequiredNodes[0] != "node01" || j.RequiredNodes[1] != "node02" {
		t.Errorf("RequiredNodes = %v, want [node01, node02]", j.RequiredNodes)
	}
	if len(j.ExcludedNodes) != 2 || j.ExcludedNodes[0] != "node03" || j.ExcludedNodes[1] != "node04" {
		t.Errorf("ExcludedNodes = %v, want [node03, node04]", j.ExcludedNodes)
	}
	if j.BeginTime == nil || j.BeginTime.Number == nil || *j.BeginTime.Number != int64(1700000000) {
		t.Errorf("BeginTime = %v, want 1700000000", j.BeginTime)
	}
	if j.Deadline == nil || *j.Deadline != int64(1700099999) {
		t.Errorf("Deadline = %v, want 1700099999", j.Deadline)
	}
	if j.Array == nil || *j.Array != "1-10" {
		t.Errorf("Array = %v, want %q", j.Array, "1-10")
	}
	if j.Dependency == nil || *j.Dependency != "afterok:123" {
		t.Errorf("Dependency = %v, want %q", j.Dependency, "afterok:123")
	}
	if j.Requeue == nil || !*j.Requeue {
		t.Errorf("Requeue = %v, want true", j.Requeue)
	}
	if j.KillOnNodeFail == nil || !*j.KillOnNodeFail {
		t.Errorf("KillOnNodeFail = %v, want true", j.KillOnNodeFail)
	}
}

// TestProcessTask_PartialSchedulingParams stores a task with only Partition
// set, processes it, and checks that unset fields stay unset in the submitted
// job description while defaults are injected for the time limit and standard
// output.
//
// NOTE(review): the end of this function lies outside the reviewed excerpt,
// so its statements are kept exactly as they were. The handler below still
// calls t.Fatalf off the test goroutine and should be changed to
// t.Errorf + return once the whole function is in view.
func TestProcessTask_PartialSchedulingParams(t *testing.T) {
	jobID := int32(42)
	var capturedReq slurm.JobSubmitReq
	env := newTaskTestEnv(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if err := json.NewDecoder(r.Body).Decode(&capturedReq); err != nil {
			t.Fatalf("decode request body: %v", err)
		}
		json.NewEncoder(w).Encode(slurm.OpenapiJobSubmitResponse{
			Result: &slurm.JobSubmitResponseMsg{JobID: &jobID},
		})
	}))
	defer env.close()

	appID := env.createApp(t, "partial-app", "#!/bin/bash\necho hello", json.RawMessage(`[]`))
	ctx := context.Background()

	task := &model.Task{
		AppName:     "partial-app",
		AppID:       appID,
		Status:      model.TaskStatusSubmitted,
		SubmittedAt: time.Now(),
		Partition:   "debug",
	}
	taskID, err := env.taskStore.Create(ctx, task)
	if err != nil {
		t.Fatalf("create task in DB: %v", err)
	}

	err = env.svc.ProcessTask(ctx, taskID)
	if err != nil {
		t.Fatalf("ProcessTask: %v", err)
	}

	j := capturedReq.Job
	if j == nil {
		t.Fatal("Job desc is nil in captured request")
	}

	if j.Partition == nil || *j.Partition != "debug" {
		t.Errorf("Partition = %v, want %q", j.Partition, "debug")
	}
	if j.CpusPerTask != nil {
		t.Errorf("CpusPerTask = %v, want nil (no cpus set)", j.CpusPerTask)
	}
	if j.TimeLimit == nil {
		t.Errorf("TimeLimit = nil, want non-nil (default should be injected)")
	} else if j.TimeLimit.Number == nil {
		t.Errorf("TimeLimit.Number = nil, want non-nil")
	} else if *j.TimeLimit.Number != int64(10080) {
		t.Errorf("TimeLimit.Number = %d, want %d (default)", *j.TimeLimit.Number, int64(10080))
	}
	if j.Qos != nil {
		t.Errorf("Qos = %v, want nil (no qos set)", j.Qos)
	}
	if j.Name != nil {
		t.Errorf("Name = %v, want nil (no job_name set)", j.Name)
	}
	if j.MemoryPerNode != nil {
		t.Errorf("MemoryPerNode = %v, want nil", j.MemoryPerNode)
	}
	if j.MemoryPerCpu != nil {
		t.Errorf("MemoryPerCpu = %v, want nil", j.MemoryPerCpu)
	}
	if j.Nodes != nil {
		t.Errorf("Nodes = %v, want nil", j.Nodes)
	}
	if j.Tasks != nil {
		t.Errorf("Tasks = %v, want nil", j.Tasks)
	}
	if j.CpusPerTask != nil {
		t.Errorf("CpusPerTask = %v, want nil", j.CpusPerTask)
	}
	if j.Constraints != nil {
		t.Errorf("Constraints = %v, want nil", j.Constraints)
	}
	if j.Reservation != nil {
		t.Errorf("Reservation = %v, want nil", j.Reservation)
	}
	if j.Account != nil {
		t.Errorf("Account = %v, want nil", j.Account)
	}
	if j.Nice != nil {
		t.Errorf("Nice = %v, want nil", j.Nice)
	}
	if len(j.MailType) != 0 {
		t.Errorf("MailType = %v, want empty", j.MailType)
	}
	if j.MailUser != nil {
		t.Errorf("MailUser = %v, want nil", j.MailUser)
	}
	if j.StandardOutput == nil {
		t.Errorf("StandardOutput = nil, want non-nil (default should be injected)")
	} else if !strings.HasSuffix(*j.StandardOutput, "/slurm-%j.out") {
		t.Errorf("StandardOutput = %q, want suffix /slurm-%%j.out (default)", *j.StandardOutput)
	}
	if j.StandardError == nil {
		t.Errorf("StandardError = nil, want non-nil (default should be injected)")
	}
else if !strings.HasSuffix(*j.StandardError, "/slurm-%j.err") { t.Errorf("StandardError = %q, want suffix /slurm-%%j.err (default)", *j.StandardError) } if j.StandardInput != nil { t.Errorf("StandardInput = %v, want nil", j.StandardInput) } if len(j.RequiredNodes) != 0 { t.Errorf("RequiredNodes = %v, want empty", j.RequiredNodes) } if len(j.ExcludedNodes) != 0 { t.Errorf("ExcludedNodes = %v, want empty", j.ExcludedNodes) } if j.BeginTime != nil { t.Errorf("BeginTime = %v, want nil", j.BeginTime) } if j.Deadline != nil { t.Errorf("Deadline = %v, want nil", j.Deadline) } if j.Array != nil { t.Errorf("Array = %v, want nil", j.Array) } if j.Dependency != nil { t.Errorf("Dependency = %v, want nil", j.Dependency) } if j.Requeue != nil { t.Errorf("Requeue = %v, want nil", j.Requeue) } if j.KillOnNodeFail != nil { t.Errorf("KillOnNodeFail = %v, want nil", j.KillOnNodeFail) } } func TestTaskService_ProcessTask_SchedulingMapInjection(t *testing.T) { jobID := int32(42) var capturedReq slurm.JobSubmitReq env := newTaskTestEnv(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if err := json.NewDecoder(r.Body).Decode(&capturedReq); err != nil { t.Fatalf("decode request body: %v", err) } json.NewEncoder(w).Encode(slurm.OpenapiJobSubmitResponse{ Result: &slurm.JobSubmitResponseMsg{JobID: &jobID}, }) })) defer env.close() params := json.RawMessage(`[ {"name": "NP", "type": "integer", "scheduling_map": "cpus", "required": true} ]`) appID := env.createApp(t, "sched-map-app", "#!/bin/bash\nmpirun -np $NP my_app", params) cpus := int32(8) task, err := env.svc.CreateTask(context.Background(), &model.CreateTaskRequest{ AppID: appID, TaskName: "sched-map-test", Cpus: &cpus, }) if err != nil { t.Fatalf("CreateTask: %v", err) } if err := env.svc.ProcessTask(context.Background(), task.ID); err != nil { t.Fatalf("ProcessTask: %v", err) } if capturedReq.Script == nil { t.Fatal("submitted script is nil") } if !strings.Contains(*capturedReq.Script, "'8'") { t.Errorf("rendered 
script does not contain shell-escaped scheduling value:\n%s", *capturedReq.Script) } if !strings.Contains(*capturedReq.Script, "mpirun -np '8'") { t.Errorf("rendered script does not contain expected mpirun command:\n%s", *capturedReq.Script) } }