diff --git a/internal/model/application.go b/internal/model/application.go index b572f9e..07bad3b 100644 --- a/internal/model/application.go +++ b/internal/model/application.go @@ -43,7 +43,8 @@ type ParameterSchema struct { Required bool `json:"required,omitempty"` // 是否必填 Default string `json:"default,omitempty"` // 默认值 Options []string `json:"options,omitempty"` // 枚举选项列表 - Description string `json:"description,omitempty"` // 参数说明 + Description string `json:"description,omitempty"` // 参数说明 + SchedulingMap string `json:"scheduling_map,omitempty"` // maps to a scheduling param } // CreateApplicationRequest 是创建应用的 API 请求。 diff --git a/internal/service/cluster_service.go b/internal/service/cluster_service.go index 962af08..e0b6d8a 100644 --- a/internal/service/cluster_service.go +++ b/internal/service/cluster_service.go @@ -42,6 +42,13 @@ func derefInt32ToStr(i *int32) string { return strconv.FormatInt(int64(*i), 10) } +func derefInt64ToStr(i *int64) string { + if i == nil { + return "" + } + return strconv.FormatInt(*i, 10) +} + func uint32NoValString(v *slurm.Uint32NoVal) string { if v == nil { return "" diff --git a/internal/service/cluster_service_test.go b/internal/service/cluster_service_test.go index fed218d..2b7b123 100644 --- a/internal/service/cluster_service_test.go +++ b/internal/service/cluster_service_test.go @@ -444,6 +444,26 @@ func TestClusterService_GetPartition_ErrorLogging(t *testing.T) { } } +func TestDerefInt64ToStr(t *testing.T) { + t.Run("nil returns empty", func(t *testing.T) { + if got := derefInt64ToStr(nil); got != "" { + t.Errorf("derefInt64ToStr(nil) = %q, want empty", got) + } + }) + t.Run("non-nil returns string", func(t *testing.T) { + v := int64(4096) + if got := derefInt64ToStr(&v); got != "4096" { + t.Errorf("derefInt64ToStr(4096) = %q, want %q", got, "4096") + } + }) + t.Run("zero value", func(t *testing.T) { + v := int64(0) + if got := derefInt64ToStr(&v); got != "0" { + t.Errorf("derefInt64ToStr(0) = %q, want %q", got, "0") + } + }) +} + func TestClusterService_GetDiag_ErrorLogging(t *testing.T) { srv := errorServer() defer srv.Close() diff --git a/internal/service/script_utils.go b/internal/service/script_utils.go index 97b1b4e..2b0ed3f 100644 --- a/internal/service/script_utils.go +++ b/internal/service/script_utils.go @@ -123,3 +123,28 @@ func RandomSuffix(n int) string { } return string(b) } + +func ResolveSchedulingMap(field string, task *model.Task) string { + switch field { + case "cpus": + return derefInt32ToStr(task.Cpus) + case "memory_per_node": + return derefInt64ToStr(task.MemoryPerNode) + case "memory_per_cpu": + return derefInt64ToStr(task.MemoryPerCpu) + case "nodes": + return derefStr(task.Nodes) + case "tasks": + return derefInt32ToStr(task.Tasks) + case "cpus_per_task": + return derefInt32ToStr(task.CpusPerTask) + case "partition": + return task.Partition + case "time_limit": + return derefInt32ToStr(task.TimeLimit) + case "qos": + return derefStr(task.QOS) + default: + return "" + } +} diff --git a/internal/service/script_utils_scheduling_test.go b/internal/service/script_utils_scheduling_test.go new file mode 100644 index 0000000..8e209c5 --- /dev/null +++ b/internal/service/script_utils_scheduling_test.go @@ -0,0 +1,72 @@ +package service + +import ( + "testing" + + "gcy_hpc_server/internal/model" +) + +func strPtr(v string) *string { return &v } + +func TestResolveSchedulingMap(t *testing.T) { + cpus := int32Ptr(8) + memPerNode := int64Ptr(4096) + memPerCpu := int64Ptr(512) + nodes := strPtr("2-4") + tasks := int32Ptr(4) + cpusPerTask := int32Ptr(2) + timeLimit := int32Ptr(60) + qos := strPtr("high") + + task := &model.Task{ + Partition: "gpu", + Cpus: cpus, + MemoryPerNode: memPerNode, + MemoryPerCpu: memPerCpu, + Nodes: nodes, + Tasks: tasks, + CpusPerTask: cpusPerTask, + TimeLimit: timeLimit, + QOS: qos, + } + + tests := []struct { + field string + want string + }{ + {"cpus", "8"}, + {"memory_per_node", "4096"}, + {"memory_per_cpu", "512"}, + {"nodes", "2-4"}, + {"tasks", "4"}, + {"cpus_per_task", "2"}, + {"partition", "gpu"}, + {"time_limit", "60"}, + {"qos", "high"}, + {"unknown_field", ""}, + } + + for _, tt := range tests { + t.Run(tt.field, func(t *testing.T) { + got := ResolveSchedulingMap(tt.field, task) + if got != tt.want { + t.Errorf("ResolveSchedulingMap(%q) = %q, want %q", tt.field, got, tt.want) + } + }) + } +} + +func TestResolveSchedulingMap_NilFields(t *testing.T) { + // All scheduling fields are nil/empty — should return empty strings + task := &model.Task{} + for _, field := range []string{"cpus", "memory_per_node", "memory_per_cpu", "nodes", "tasks", "cpus_per_task", "time_limit", "qos"} { + got := ResolveSchedulingMap(field, task) + if got != "" { + t.Errorf("ResolveSchedulingMap(%q) with nil fields = %q, want empty", field, got) + } + } + // partition is a plain string, not a pointer — empty string is the zero value + if got := ResolveSchedulingMap("partition", task); got != "" { + t.Errorf("ResolveSchedulingMap(partition) = %q, want empty", got) + } +} diff --git a/internal/service/task_service.go b/internal/service/task_service.go index c922a08..993e6ac 100644 --- a/internal/service/task_service.go +++ b/internal/service/task_service.go @@ -369,6 +369,18 @@ func (s *TaskService) ProcessTask(ctx context.Context, taskID int64) error { } } + // 16b-3. Auto-inject scheduling params based on scheduling_map. + // If an Application parameter declares scheduling_map, the corresponding + // scheduling field value overrides any user-provided value. + for _, p := range params { + if p.SchedulingMap == "" { + continue + } + if val := ResolveSchedulingMap(p.SchedulingMap, task); val != "" { + values[p.Name] = val + } + } + // 16c. Validate all params (WORK_DIR and file params now have values). if err := ValidateParams(params, values); err != nil { return fail(model.TaskStepSubmitting, err.Error()) diff --git a/internal/service/task_service_test.go b/internal/service/task_service_test.go index bc7e983..df02104 100644 --- a/internal/service/task_service_test.go +++ b/internal/service/task_service_test.go @@ -1266,3 +1266,49 @@ func TestProcessTask_PartialSchedulingParams(t *testing.T) { t.Errorf("KillOnNodeFail = %v, want nil", j.KillOnNodeFail) } } + +func TestTaskService_ProcessTask_SchedulingMapInjection(t *testing.T) { + jobID := int32(42) + + var capturedReq slurm.JobSubmitReq + + env := newTaskTestEnv(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if err := json.NewDecoder(r.Body).Decode(&capturedReq); err != nil { + t.Fatalf("decode request body: %v", err) + } + json.NewEncoder(w).Encode(slurm.OpenapiJobSubmitResponse{ + Result: &slurm.JobSubmitResponseMsg{JobID: &jobID}, + }) + })) + defer env.close() + + params := json.RawMessage(`[ + {"name": "NP", "type": "integer", "scheduling_map": "cpus", "required": true} + ]`) + appID := env.createApp(t, "sched-map-app", "#!/bin/bash\nmpirun -np $NP my_app", params) + + cpus := int32(8) + task, err := env.svc.CreateTask(context.Background(), &model.CreateTaskRequest{ + AppID: appID, + TaskName: "sched-map-test", + Cpus: &cpus, + }) + if err != nil { + t.Fatalf("CreateTask: %v", err) + } + + if err := env.svc.ProcessTask(context.Background(), task.ID); err != nil { + t.Fatalf("ProcessTask: %v", err) + } + + if capturedReq.Script == nil { + t.Fatal("submitted script is nil") + } + + if !strings.Contains(*capturedReq.Script, "'8'") { + t.Errorf("rendered script does not contain shell-escaped scheduling value:\n%s", *capturedReq.Script) + } + if !strings.Contains(*capturedReq.Script, "mpirun -np '8'") { + t.Errorf("rendered script does not contain expected mpirun command:\n%s", *capturedReq.Script) + } +} diff --git a/web/src/types/tasks.ts b/web/src/types/tasks.ts index cff37e8..f44c20f 100644 --- a/web/src/types/tasks.ts +++ b/web/src/types/tasks.ts @@ -6,6 +6,7 @@ export interface ParameterSchema { default?: string options?: string[] description?: string + scheduling_map?: string } export interface Application { diff --git a/web/src/views/Tasks/Submit.vue b/web/src/views/Tasks/Submit.vue index 7f7b6e4..e864b80 100644 --- a/web/src/views/Tasks/Submit.vue +++ b/web/src/views/Tasks/Submit.vue @@ -124,7 +124,7 @@ const selectedApp = computed(() => appList.value.find(a => a.id === selectedAppI const autoParams = new Set(['WORK_DIR']) const visibleParams = computed(() => - (selectedApp.value?.parameters || []).filter((p: any) => !autoParams.has(p.name)) + (selectedApp.value?.parameters || []).filter((p: any) => !autoParams.has(p.name) && !p.scheduling_map) ) const fileParams = computed(() => @@ -159,12 +159,32 @@ onMounted(async () => { } }) +const resolveSchedMapValue = (mapField: string): string | undefined => { + switch (mapField) { + case 'cpus': return form.cpus != null ? String(form.cpus) : undefined + case 'memory_per_node': return form.memory_per_node != null ? String(form.memory_per_node) : undefined + case 'nodes': return form.nodes || undefined + case 'tasks': return form.tasks != null ? String(form.tasks) : undefined + case 'cpus_per_task': return form.cpus_per_task != null ? String(form.cpus_per_task) : undefined + case 'partition': return form.partition || undefined + default: return undefined + } +} + const handleSubmit = async () => { if (!selectedAppId.value) { ElMessage.warning('请选择应用'); return } submitting.value = true try { const taskName = form.task_name.trim() || `task_${selectedAppId.value}_${Date.now()}` const mergedValues = { ...values.value, ...fileParamMapping.value } + // Auto-inject scheduling_map values + const schedParams = (selectedApp.value?.parameters || []).filter((p: any) => p.scheduling_map) + for (const p of schedParams) { + const val = resolveSchedMapValue(p.scheduling_map) + if (val !== undefined && val !== '') { + mergedValues[p.name] = val + } + } const resp = await createTask({ ...form, task_name: taskName, job_name: taskName, app_id: selectedAppId.value, values: mergedValues, file_ids: selectedFiles.value.map(f => f.id) }) if (resp.success) { ElMessage.success('任务提交成功')