diff --git a/internal/service/application_service_test.go b/internal/service/application_service_test.go index 9683a87..ae0a080 100644 --- a/internal/service/application_service_test.go +++ b/internal/service/application_service_test.go @@ -104,6 +104,89 @@ func TestValidateParams_BooleanValues(t *testing.T) { } } +func TestValidateParams_FileTypeValid(t *testing.T) { + params := []model.ParameterSchema{ + {Name: "MODEL", Type: model.ParamTypeFile, Required: true}, + } + values := map[string]string{"MODEL": "12345"} + if err := ValidateParams(params, values); err != nil { + t.Errorf("expected no error for valid file ID, got %v", err) + } +} + +func TestValidateParams_FileTypeInvalid(t *testing.T) { + params := []model.ParameterSchema{ + {Name: "MODEL", Type: model.ParamTypeFile, Required: true}, + } + values := map[string]string{"MODEL": "not_a_number"} + err := ValidateParams(params, values) + if err == nil { + t.Fatal("expected error for non-numeric file ID") + } + if !strings.Contains(err.Error(), "file ID") { + t.Errorf("error should mention 'file ID', got: %v", err) + } +} + +func TestValidateParams_DirectoryTypeValid(t *testing.T) { + params := []model.ParameterSchema{ + {Name: "DATA_DIR", Type: model.ParamTypeDirectory, Required: true}, + } + values := map[string]string{"DATA_DIR": "99"} + if err := ValidateParams(params, values); err != nil { + t.Errorf("expected no error for valid directory ID, got %v", err) + } +} + +func TestValidateParams_DirectoryTypeInvalid(t *testing.T) { + params := []model.ParameterSchema{ + {Name: "DATA_DIR", Type: model.ParamTypeDirectory, Required: true}, + } + values := map[string]string{"DATA_DIR": "abc"} + err := ValidateParams(params, values) + if err == nil { + t.Fatal("expected error for non-numeric directory ID") + } + if !strings.Contains(err.Error(), "file ID") { + t.Errorf("error should mention 'file ID', got: %v", err) + } +} + +func TestRenderScript_FileTypeNotEscaped(t *testing.T) { + params := []model.ParameterSchema{{Name: "MODEL_PATH", Type: model.ParamTypeFile}} + values := map[string]string{"MODEL_PATH": "model_v2.bin"} + result := RenderScript("python train.py --model $MODEL_PATH", params, values) + expected := "python train.py --model model_v2.bin" + if result != expected { + t.Errorf("got %q, want %q", result, expected) + } +} + +func TestRenderScript_WorkDirInjected(t *testing.T) { + params := []model.ParameterSchema{{Name: "INPUT", Type: model.ParamTypeString}} + values := map[string]string{ + "INPUT": "data.txt", + "WORK_DIR": "/data/work/myapp_20260101_abcd", + } + result := RenderScript("#SBATCH --chdir=$WORK_DIR\necho $INPUT", params, values) + if !strings.Contains(result, "#SBATCH --chdir=/data/work/myapp_20260101_abcd") { + t.Errorf("WORK_DIR should be replaced raw, got: %s", result) + } + if !strings.Contains(result, "'data.txt'") { + t.Errorf("INPUT should still be shell-escaped, got: %s", result) + } +} + +func TestRenderScript_DirectoryTypeNotEscaped(t *testing.T) { + params := []model.ParameterSchema{{Name: "DATA_DIR", Type: model.ParamTypeDirectory}} + values := map[string]string{"DATA_DIR": "input_folder"} + result := RenderScript("ls $DATA_DIR", params, values) + expected := "ls input_folder" + if result != expected { + t.Errorf("got %q, want %q", result, expected) + } +} + func TestRenderScript_SimpleReplacement(t *testing.T) { params := []model.ParameterSchema{{Name: "INPUT", Type: model.ParamTypeString}} values := map[string]string{"INPUT": "data.txt"} diff --git a/internal/service/task_service_test.go b/internal/service/task_service_test.go index d87c70b..6360aa7 100644 --- a/internal/service/task_service_test.go +++ b/internal/service/task_service_test.go @@ -3,6 +3,7 @@ package service import ( "context" "encoding/json" + "fmt" "net/http" "net/http/httptest" "os" @@ -536,3 +537,139 @@ func TestTaskService_ProcessTask_ValidateParams_ValidParamsSucceed(t *testing.T) t.Errorf("SlurmJobID = %v, want 99", updated.SlurmJobID) } } + +func TestTaskService_ProcessTask_FileParamResolution(t *testing.T) { + jobID := int32(88) + env := newTaskTestEnv(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + json.NewEncoder(w).Encode(slurm.OpenapiJobSubmitResponse{ + Result: &slurm.JobSubmitResponseMsg{JobID: &jobID}, + }) + })) + defer env.close() + + ctx := context.Background() + + blob := &model.FileBlob{SHA256: "deadbeef", MinioKey: "test/file.bin", FileSize: 1024} + if err := env.db.Create(blob).Error; err != nil { + t.Fatalf("create blob: %v", err) + } + file := &model.File{Name: "model_weights.bin", BlobSHA256: blob.SHA256} + if err := env.fileStore.Create(ctx, file); err != nil { + t.Fatalf("create file: %v", err) + } + + appID := env.createApp(t, "file-param-app", + "#!/bin/bash\n#SBATCH --chdir=$WORK_DIR\npython train.py --model $MODEL_PATH", + json.RawMessage(`[{"name":"MODEL_PATH","type":"file","required":true}]`)) + + // No InputFileIDs because stagingSvc is nil in tests; file is still resolvable via fileStore. + task, err := env.svc.CreateTask(ctx, &model.CreateTaskRequest{ + AppID: appID, + Values: map[string]string{"MODEL_PATH": fmt.Sprintf("%d", file.ID)}, + }) + if err != nil { + t.Fatalf("CreateTask: %v", err) + } + + err = env.svc.ProcessTask(ctx, task.ID) + if err != nil { + t.Fatalf("ProcessTask with file param: %v", err) + } + + updated, _ := env.taskStore.GetByID(ctx, task.ID) + if updated.Status != model.TaskStatusQueued { + t.Errorf("Status = %q, want %q", updated.Status, model.TaskStatusQueued) + } +} + +func TestTaskService_ProcessTask_WorkDirInScript(t *testing.T) { + jobID := int32(77) + env := newTaskTestEnv(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + json.NewEncoder(w).Encode(slurm.OpenapiJobSubmitResponse{ + Result: &slurm.JobSubmitResponseMsg{JobID: &jobID}, + }) + })) + defer env.close() + + appID := env.createApp(t, "workdir-app", + "#!/bin/bash\n#SBATCH --chdir=$WORK_DIR\necho hello", + json.RawMessage(`[]`)) + + task, err := env.svc.CreateTask(context.Background(), &model.CreateTaskRequest{ + AppID: appID, + Values: map[string]string{}, + }) + if err != nil { + t.Fatalf("CreateTask: %v", err) + } + + err = env.svc.ProcessTask(context.Background(), task.ID) + if err != nil { + t.Fatalf("ProcessTask: %v", err) + } + + updated, _ := env.taskStore.GetByID(context.Background(), task.ID) + if updated.WorkDir == "" { + t.Fatal("WorkDir should be set") + } +} + +func TestTaskService_ProcessTask_FileParamInvalidID(t *testing.T) { + jobID := int32(42) + env := newTaskTestEnv(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + json.NewEncoder(w).Encode(slurm.OpenapiJobSubmitResponse{ + Result: &slurm.JobSubmitResponseMsg{JobID: &jobID}, + }) + })) + defer env.close() + + appID := env.createApp(t, "bad-file-app", + "#!/bin/bash\npython train.py --model $MODEL", + json.RawMessage(`[{"name":"MODEL","type":"file","required":true}]`)) + + task, err := env.svc.CreateTask(context.Background(), &model.CreateTaskRequest{ + AppID: appID, + Values: map[string]string{"MODEL": "not_a_number"}, + }) + if err != nil { + t.Fatalf("CreateTask: %v", err) + } + + err = env.svc.ProcessTask(context.Background(), task.ID) + if err == nil { + t.Fatal("expected error for non-numeric file ID in file-type param") + } + if !strings.Contains(err.Error(), "file ID") { + t.Errorf("error should mention 'file ID', got: %v", err) + } +} + +func TestTaskService_ProcessTask_FileParamNotInInputFiles(t *testing.T) { + jobID := int32(42) + env := newTaskTestEnv(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + json.NewEncoder(w).Encode(slurm.OpenapiJobSubmitResponse{ + Result: &slurm.JobSubmitResponseMsg{JobID: &jobID}, + }) + })) + defer env.close() + + appID := env.createApp(t, "missing-file-app", + "#!/bin/bash\npython train.py --model $MODEL", + json.RawMessage(`[{"name":"MODEL","type":"file","required":true}]`)) + + task, err := env.svc.CreateTask(context.Background(), &model.CreateTaskRequest{ + AppID: appID, + Values: map[string]string{"MODEL": "999"}, + }) + if err != nil { + t.Fatalf("CreateTask: %v", err) + } + + err = env.svc.ProcessTask(context.Background(), task.ID) + if err == nil { + t.Fatal("expected error for file_id not in input_file_ids") + } + if !strings.Contains(err.Error(), "not found") { + t.Errorf("error should mention 'not found', got: %v", err) + } +}