From ec64300ff2d01dab35fd79b7e24cf0415af66d17 Mon Sep 17 00:00:00 2001 From: dailz Date: Wed, 15 Apr 2026 21:31:02 +0800 Subject: [PATCH] feat(service): add TaskService, FileStagingService, and refactor ApplicationService for task submission --- internal/service/application_service.go | 139 +---- internal/service/application_service_test.go | 100 +++- internal/service/file_staging_service.go | 145 +++++ internal/service/file_staging_service_test.go | 232 ++++++++ internal/service/script_utils.go | 112 ++++ internal/service/task_service.go | 554 ++++++++++++++++++ internal/service/task_service_async_test.go | 416 +++++++++++++ internal/service/task_service_status_test.go | 294 ++++++++++ internal/service/task_service_test.go | 538 +++++++++++++++++ 9 files changed, 2394 insertions(+), 136 deletions(-) create mode 100644 internal/service/file_staging_service.go create mode 100644 internal/service/file_staging_service_test.go create mode 100644 internal/service/script_utils.go create mode 100644 internal/service/task_service.go create mode 100644 internal/service/task_service_async_test.go create mode 100644 internal/service/task_service_status_test.go create mode 100644 internal/service/task_service_test.go diff --git a/internal/service/application_service.go b/internal/service/application_service.go index 9034177..ffa3615 100644 --- a/internal/service/application_service.go +++ b/internal/service/application_service.go @@ -4,13 +4,8 @@ import ( "context" "encoding/json" "fmt" - "math/rand" "os" "path/filepath" - "regexp" - "sort" - "strconv" - "strings" "time" "gcy_hpc_server/internal/model" @@ -19,8 +14,6 @@ import ( "go.uber.org/zap" ) -var paramNameRegex = regexp.MustCompile(`^[A-Za-z_][A-Za-z0-9_]*$`) - // ApplicationService handles parameter validation, script rendering, and job // submission for parameterized HPC applications. type ApplicationService struct { @@ -28,92 +21,15 @@ type ApplicationService struct { jobSvc *JobService workDirBase string logger *zap.Logger + taskSvc *TaskService } -func NewApplicationService(store *store.ApplicationStore, jobSvc *JobService, workDirBase string, logger *zap.Logger) *ApplicationService { - return &ApplicationService{store: store, jobSvc: jobSvc, workDirBase: workDirBase, logger: logger} -} - -// ValidateParams checks that all required parameters are present and values match their types. -// Parameters not in the schema are silently ignored. -func (s *ApplicationService) ValidateParams(params []model.ParameterSchema, values map[string]string) error { - var errs []string - - for _, p := range params { - if !paramNameRegex.MatchString(p.Name) { - errs = append(errs, fmt.Sprintf("invalid parameter name %q: must match ^[A-Za-z_][A-Za-z0-9_]*$", p.Name)) - continue - } - - val, ok := values[p.Name] - - if p.Required && !ok { - errs = append(errs, fmt.Sprintf("required parameter %q is missing", p.Name)) - continue - } - - if !ok { - continue - } - - switch p.Type { - case model.ParamTypeInteger: - if _, err := strconv.Atoi(val); err != nil { - errs = append(errs, fmt.Sprintf("parameter %q must be an integer, got %q", p.Name, val)) - } - case model.ParamTypeBoolean: - if val != "true" && val != "false" && val != "1" && val != "0" { - errs = append(errs, fmt.Sprintf("parameter %q must be a boolean (true/false/1/0), got %q", p.Name, val)) - } - case model.ParamTypeEnum: - if len(p.Options) > 0 { - found := false - for _, opt := range p.Options { - if val == opt { - found = true - break - } - } - if !found { - errs = append(errs, fmt.Sprintf("parameter %q must be one of %v, got %q", p.Name, p.Options, val)) - } - } - case model.ParamTypeFile, model.ParamTypeDirectory: - case model.ParamTypeString: - } +func NewApplicationService(store *store.ApplicationStore, jobSvc *JobService, workDirBase string, logger *zap.Logger, taskSvc ...*TaskService) *ApplicationService { + var ts *TaskService + if len(taskSvc) > 0 { + ts = taskSvc[0] } - - if len(errs) > 0 { - return fmt.Errorf("parameter validation failed: %s", strings.Join(errs, "; ")) - } - return nil -} - -// RenderScript replaces $PARAM tokens in the template with user-provided values. -// Only tokens defined in the schema are replaced. Replacement is done longest-name-first -// to avoid partial matches (e.g., $JOB_NAME before $JOB). -// All values are shell-escaped using single-quote wrapping. -func (s *ApplicationService) RenderScript(template string, params []model.ParameterSchema, values map[string]string) string { - sorted := make([]model.ParameterSchema, len(params)) - copy(sorted, params) - sort.Slice(sorted, func(i, j int) bool { - return len(sorted[i].Name) > len(sorted[j].Name) - }) - - result := template - for _, p := range sorted { - val, ok := values[p.Name] - if !ok { - if p.Default != "" { - val = p.Default - } else { - continue - } - } - escaped := "'" + strings.ReplaceAll(val, "'", "'\\''") + "'" - result = strings.ReplaceAll(result, "$"+p.Name, escaped) - } - return result + return &ApplicationService{store: store, jobSvc: jobSvc, workDirBase: workDirBase, logger: logger, taskSvc: ts} } // ListApplications delegates to the store. @@ -141,13 +57,22 @@ func (s *ApplicationService) DeleteApplication(ctx context.Context, id int64) er return s.store.Delete(ctx, id) } -// SubmitFromApplication orchestrates the full submission flow: -// 1. Fetch application by ID -// 2. Parse parameters schema -// 3. Validate parameter values -// 4. Render script template -// 5. Submit job via JobService +// SubmitFromApplication orchestrates the full submission flow. +// When TaskService is available, it delegates to ProcessTaskSync which creates +// an hpc_tasks record and runs the full pipeline. Otherwise falls back to the +// original direct implementation. func (s *ApplicationService) SubmitFromApplication(ctx context.Context, applicationID int64, values map[string]string) (*model.JobResponse, error) { + if s.taskSvc != nil { + req := &model.CreateTaskRequest{ + AppID: applicationID, + Values: values, + InputFileIDs: nil, // old API has no file_ids concept + TaskName: "", + } + return s.taskSvc.ProcessTaskSync(ctx, req) + } + + // Fallback: original direct logic when TaskService not available app, err := s.store.GetByID(ctx, applicationID) if err != nil { return nil, fmt.Errorf("get application: %w", err) @@ -163,16 +88,16 @@ func (s *ApplicationService) SubmitFromApplication(ctx context.Context, applicat } } - if err := s.ValidateParams(params, values); err != nil { + if err := ValidateParams(params, values); err != nil { return nil, err } - rendered := s.RenderScript(app.ScriptTemplate, params, values) + rendered := RenderScript(app.ScriptTemplate, params, values) workDir := "" if s.workDirBase != "" { - safeName := sanitizeDirName(app.Name) - subDir := time.Now().Format("20060102_150405") + "_" + randomSuffix(4) + safeName := SanitizeDirName(app.Name) + subDir := time.Now().Format("20060102_150405") + "_" + RandomSuffix(4) workDir = filepath.Join(s.workDirBase, safeName, subDir) if err := os.MkdirAll(workDir, 0777); err != nil { return nil, fmt.Errorf("create work directory %s: %w", workDir, err) @@ -187,17 +112,3 @@ func (s *ApplicationService) SubmitFromApplication(ctx context.Context, applicat req := &model.SubmitJobRequest{Script: rendered, WorkDir: workDir} return s.jobSvc.SubmitJob(ctx, req) } - -func sanitizeDirName(name string) string { - replacer := strings.NewReplacer(" ", "_", "/", "_", "\\", "_", ":", "_", "*", "_", "?", "_", "\"", "_", "<", "_", ">", "_", "|", "_") - return replacer.Replace(name) -} - -func randomSuffix(n int) string { - const charset = "abcdefghijklmnopqrstuvwxyz0123456789" - b := make([]byte, n) - for i := range b { - b[i] = charset[rand.Intn(len(charset))] - } - return string(b) -} diff --git a/internal/service/application_service_test.go b/internal/service/application_service_test.go index 409258f..e46362a 100644 --- a/internal/service/application_service_test.go +++ b/internal/service/application_service_test.go @@ -5,6 +5,8 @@ import ( "encoding/json" "net/http" "net/http/httptest" + "os" + "path/filepath" "strings" "testing" @@ -42,24 +44,22 @@ func setupApplicationService(t *testing.T, slurmHandler http.HandlerFunc) (*Appl } func TestValidateParams_AllRequired(t *testing.T) { - svc := NewApplicationService(nil, nil, "", zap.NewNop()) params := []model.ParameterSchema{ {Name: "NAME", Type: model.ParamTypeString, Required: true}, {Name: "COUNT", Type: model.ParamTypeInteger, Required: true}, } values := map[string]string{"NAME": "hello", "COUNT": "5"} - if err := svc.ValidateParams(params, values); err != nil { + if err := ValidateParams(params, values); err != nil { t.Errorf("expected no error, got %v", err) } } func TestValidateParams_MissingRequired(t *testing.T) { - svc := NewApplicationService(nil, nil, "", zap.NewNop()) params := []model.ParameterSchema{ {Name: "NAME", Type: model.ParamTypeString, Required: true}, } values := map[string]string{} - err := svc.ValidateParams(params, values) + err := ValidateParams(params, values) if err == nil { t.Fatal("expected error for missing required param") } @@ -69,12 +69,11 @@ func TestValidateParams_MissingRequired(t *testing.T) { } func TestValidateParams_InvalidInteger(t *testing.T) { - svc := NewApplicationService(nil, nil, "", zap.NewNop()) params := []model.ParameterSchema{ {Name: "COUNT", Type: model.ParamTypeInteger, Required: true}, } values := map[string]string{"COUNT": "abc"} - err := svc.ValidateParams(params, values) + err := ValidateParams(params, values) if err == nil { t.Fatal("expected error for invalid integer") } @@ -84,12 +83,11 @@ func TestValidateParams_InvalidInteger(t *testing.T) { } func TestValidateParams_InvalidEnum(t *testing.T) { - svc := NewApplicationService(nil, nil, "", zap.NewNop()) params := []model.ParameterSchema{ {Name: "MODE", Type: model.ParamTypeEnum, Required: true, Options: []string{"fast", "slow"}}, } values := map[string]string{"MODE": "medium"} - err := svc.ValidateParams(params, values) + err := ValidateParams(params, values) if err == nil { t.Fatal("expected error for invalid enum value") } @@ -99,12 +97,11 @@ func TestValidateParams_InvalidEnum(t *testing.T) { } func TestValidateParams_BooleanValues(t *testing.T) { - svc := NewApplicationService(nil, nil, "", zap.NewNop()) params := []model.ParameterSchema{ {Name: "FLAG", Type: model.ParamTypeBoolean, Required: true}, } for _, val := range []string{"true", "false", "1", "0"} { - err := svc.ValidateParams(params, map[string]string{"FLAG": val}) + err := ValidateParams(params, map[string]string{"FLAG": val}) if err != nil { t.Errorf("boolean value %q should be valid, got error: %v", val, err) } @@ -112,10 +109,9 @@ func TestValidateParams_BooleanValues(t *testing.T) { } func TestRenderScript_SimpleReplacement(t *testing.T) { - svc := NewApplicationService(nil, nil, "", zap.NewNop()) params := []model.ParameterSchema{{Name: "INPUT", Type: model.ParamTypeString}} values := map[string]string{"INPUT": "data.txt"} - result := svc.RenderScript("echo $INPUT", params, values) + result := RenderScript("echo $INPUT", params, values) expected := "echo 'data.txt'" if result != expected { t.Errorf("got %q, want %q", result, expected) @@ -123,10 +119,9 @@ func TestRenderScript_SimpleReplacement(t *testing.T) { } func TestRenderScript_DefaultValues(t *testing.T) { - svc := NewApplicationService(nil, nil, "", zap.NewNop()) params := []model.ParameterSchema{{Name: "OUTPUT", Type: model.ParamTypeString, Default: "out.log"}} values := map[string]string{} - result := svc.RenderScript("cat $OUTPUT", params, values) + result := RenderScript("cat $OUTPUT", params, values) expected := "cat 'out.log'" if result != expected { t.Errorf("got %q, want %q", result, expected) @@ -134,10 +129,9 @@ func TestRenderScript_DefaultValues(t *testing.T) { } func TestRenderScript_PreservesUnknownVars(t *testing.T) { - svc := NewApplicationService(nil, nil, "", zap.NewNop()) params := []model.ParameterSchema{{Name: "INPUT", Type: model.ParamTypeString}} values := map[string]string{"INPUT": "data.txt"} - result := svc.RenderScript("export HOME=$HOME\necho $INPUT\necho $PATH", params, values) + result := RenderScript("export HOME=$HOME\necho $INPUT\necho $PATH", params, values) if !strings.Contains(result, "$HOME") { t.Error("$HOME should be preserved") } @@ -150,7 +144,6 @@ func TestRenderScript_PreservesUnknownVars(t *testing.T) { } func TestRenderScript_ShellEscaping(t *testing.T) { - svc := NewApplicationService(nil, nil, "", zap.NewNop()) params := []model.ParameterSchema{{Name: "INPUT", Type: model.ParamTypeString}} tests := []struct { @@ -165,7 +158,7 @@ func TestRenderScript_ShellEscaping(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - result := svc.RenderScript("$INPUT", params, map[string]string{"INPUT": tt.value}) + result := RenderScript("$INPUT", params, map[string]string{"INPUT": tt.value}) if result != tt.expected { t.Errorf("got %q, want %q", result, tt.expected) } @@ -174,14 +167,13 @@ func TestRenderScript_ShellEscaping(t *testing.T) { } func TestRenderScript_OverlappingParams(t *testing.T) { - svc := NewApplicationService(nil, nil, "", zap.NewNop()) template := "$JOB_NAME and $JOB" params := []model.ParameterSchema{ {Name: "JOB", Type: model.ParamTypeString}, {Name: "JOB_NAME", Type: model.ParamTypeString}, } values := map[string]string{"JOB": "myjob", "JOB_NAME": "my-test-job"} - result := svc.RenderScript(template, params, values) + result := RenderScript(template, params, values) if strings.Contains(result, "$JOB_NAME") { t.Error("$JOB_NAME was not replaced") } @@ -197,10 +189,9 @@ func TestRenderScript_OverlappingParams(t *testing.T) { } func TestRenderScript_NewlineInValue(t *testing.T) { - svc := NewApplicationService(nil, nil, "", zap.NewNop()) params := []model.ParameterSchema{{Name: "CMD", Type: model.ParamTypeString}} values := map[string]string{"CMD": "line1\nline2"} - result := svc.RenderScript("echo $CMD", params, values) + result := RenderScript("echo $CMD", params, values) expected := "echo 'line1\nline2'" if result != expected { t.Errorf("got %q, want %q", result, expected) @@ -298,3 +289,68 @@ func TestSubmitFromApplication_NoParameters(t *testing.T) { t.Errorf("JobID = %d, want 99", resp.JobID) } } + +func TestSubmitFromApplication_DelegatesToTaskService(t *testing.T) { + jobID := int32(77) + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + json.NewEncoder(w).Encode(slurm.OpenapiJobSubmitResponse{ + Result: &slurm.JobSubmitResponseMsg{JobID: &jobID}, + }) + })) + defer srv.Close() + + client, _ := slurm.NewClient(srv.URL, srv.Client()) + jobSvc := NewJobService(client, zap.NewNop()) + + db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{ + Logger: gormlogger.Default.LogMode(gormlogger.Silent), + }) + if err != nil { + t.Fatalf("open sqlite: %v", err) + } + if err := db.AutoMigrate(&model.Application{}, &model.Task{}, &model.File{}, &model.FileBlob{}); err != nil { + t.Fatalf("auto migrate: %v", err) + } + + appStore := store.NewApplicationStore(db) + taskStore := store.NewTaskStore(db) + fileStore := store.NewFileStore(db) + blobStore := store.NewBlobStore(db) + + workDirBase := filepath.Join(t.TempDir(), "workdir") + os.MkdirAll(workDirBase, 0777) + + taskSvc := NewTaskService(taskStore, appStore, fileStore, blobStore, nil, jobSvc, workDirBase, zap.NewNop()) + appSvc := NewApplicationService(appStore, jobSvc, workDirBase, zap.NewNop(), taskSvc) + + id, err := appStore.Create(context.Background(), &model.CreateApplicationRequest{ + Name: "delegated-app", + ScriptTemplate: "#!/bin/bash\n#SBATCH --job-name=$JOB_NAME\necho $INPUT", + Parameters: json.RawMessage(`[{"name":"JOB_NAME","type":"string","required":true},{"name":"INPUT","type":"string","required":true}]`), + }) + if err != nil { + t.Fatalf("create app: %v", err) + } + + resp, err := appSvc.SubmitFromApplication(context.Background(), id, map[string]string{ + "JOB_NAME": "delegated-job", + "INPUT": "test-data", + }) + if err != nil { + t.Fatalf("SubmitFromApplication() error = %v", err) + } + if resp.JobID != 77 { + t.Errorf("JobID = %d, want 77", resp.JobID) + } + + var task model.Task + if err := db.Where("app_id = ?", id).First(&task).Error; err != nil { + t.Fatalf("no hpc_tasks record found for app_id %d: %v", id, err) + } + if task.SlurmJobID == nil || *task.SlurmJobID != 77 { + t.Errorf("task SlurmJobID = %v, want 77", task.SlurmJobID) + } + if task.Status != model.TaskStatusQueued { + t.Errorf("task Status = %q, want %q", task.Status, model.TaskStatusQueued) + } +} diff --git a/internal/service/file_staging_service.go b/internal/service/file_staging_service.go new file mode 100644 index 0000000..5348f52 --- /dev/null +++ b/internal/service/file_staging_service.go @@ -0,0 +1,145 @@ +package service + +import ( + "context" + "fmt" + "io" + "os" + "path/filepath" + + "gcy_hpc_server/internal/model" + "gcy_hpc_server/internal/storage" + "gcy_hpc_server/internal/store" + + "go.uber.org/zap" +) + +// FileStagingService batch downloads files from MinIO to a local (NFS) directory, +// deduplicating by blob SHA256 so each unique blob is fetched only once. +type FileStagingService struct { + fileStore *store.FileStore + blobStore *store.BlobStore + storage storage.ObjectStorage + bucket string + logger *zap.Logger +} + +func NewFileStagingService(fileStore *store.FileStore, blobStore *store.BlobStore, st storage.ObjectStorage, bucket string, logger *zap.Logger) *FileStagingService { + return &FileStagingService{ + fileStore: fileStore, + blobStore: blobStore, + storage: st, + bucket: bucket, + logger: logger, + } +} + +// DownloadFilesToDir downloads the given files into destDir. +// Files sharing the same blob SHA256 are deduplicated: the blob is fetched once +// and then copied to each filename. Filenames are sanitized with filepath.Base +// to prevent path traversal. +func (s *FileStagingService) DownloadFilesToDir(ctx context.Context, fileIDs []int64, destDir string) error { + if len(fileIDs) == 0 { + return nil + } + + files, err := s.fileStore.GetByIDs(ctx, fileIDs) + if err != nil { + return fmt.Errorf("fetch files: %w", err) + } + + type group struct { + primary *model.File // first file — written via io.Copy from MinIO + others []*model.File // remaining files — local copy of primary + } + groups := make(map[string]*group) + for i := range files { + f := &files[i] + g, ok := groups[f.BlobSHA256] + if !ok { + groups[f.BlobSHA256] = &group{primary: f} + } else { + g.others = append(g.others, f) + } + } + + sha256s := make([]string, 0, len(groups)) + for sh := range groups { + sha256s = append(sha256s, sh) + } + blobs, err := s.blobStore.GetBySHA256s(ctx, sha256s) + if err != nil { + return fmt.Errorf("fetch blobs: %w", err) + } + + blobMap := make(map[string]*model.FileBlob, len(blobs)) + for i := range blobs { + blobMap[blobs[i].SHA256] = &blobs[i] + } + + for sha256, g := range groups { + blob, ok := blobMap[sha256] + if !ok { + return fmt.Errorf("blob %s not found", sha256) + } + + reader, _, err := s.storage.GetObject(ctx, s.bucket, blob.MinioKey, storage.GetOptions{}) + if err != nil { + return fmt.Errorf("get object %s: %w", blob.MinioKey, err) + } + + // TODO: handle filename collisions when multiple files have the same Name (low risk without user auth, revisit when auth is added) + primaryName := filepath.Base(g.primary.Name) + primaryPath := filepath.Join(destDir, primaryName) + + if err := writeFile(primaryPath, reader); err != nil { + reader.Close() + os.Remove(primaryPath) + return fmt.Errorf("write file %s: %w", primaryName, err) + } + reader.Close() + + for _, other := range g.others { + otherName := filepath.Base(other.Name) + otherPath := filepath.Join(destDir, otherName) + + if err := copyFile(primaryPath, otherPath); err != nil { + return fmt.Errorf("copy %s to %s: %w", primaryName, otherName, err) + } + } + } + + return nil +} + +func writeFile(path string, reader io.Reader) error { + f, err := os.Create(path) + if err != nil { + return err + } + defer f.Close() + + if _, err := io.Copy(f, reader); err != nil { + return err + } + return nil +} + +func copyFile(src, dst string) error { + in, err := os.Open(src) + if err != nil { + return err + } + defer in.Close() + + out, err := os.Create(dst) + if err != nil { + return err + } + defer out.Close() + + if _, err := io.Copy(out, in); err != nil { + return err + } + return nil +} diff --git a/internal/service/file_staging_service_test.go b/internal/service/file_staging_service_test.go new file mode 100644 index 0000000..75378ea --- /dev/null +++ b/internal/service/file_staging_service_test.go @@ -0,0 +1,232 @@ +package service + +import ( + "bytes" + "context" + "fmt" + "io" + "os" + "path/filepath" + "strings" + "sync/atomic" + "testing" + + "gcy_hpc_server/internal/model" + "gcy_hpc_server/internal/storage" + "gcy_hpc_server/internal/store" + + "go.uber.org/zap" + "gorm.io/driver/sqlite" + "gorm.io/gorm" +) + +type stagingMockStorage struct { + getObjectFn func(ctx context.Context, bucket, key string, opts storage.GetOptions) (io.ReadCloser, storage.ObjectInfo, error) +} + +func (m *stagingMockStorage) GetObject(ctx context.Context, bucket, key string, opts storage.GetOptions) (io.ReadCloser, storage.ObjectInfo, error) { + if m.getObjectFn != nil { + return m.getObjectFn(ctx, bucket, key, opts) + } + return nil, storage.ObjectInfo{}, nil +} + +func (m *stagingMockStorage) PutObject(ctx context.Context, bucket, key string, reader io.Reader, size int64, opts storage.PutObjectOptions) (storage.UploadInfo, error) { + return storage.UploadInfo{}, nil +} +func (m *stagingMockStorage) ComposeObject(ctx context.Context, bucket, dst string, sources []string) (storage.UploadInfo, error) { + return storage.UploadInfo{}, nil +} +func (m *stagingMockStorage) AbortMultipartUpload(ctx context.Context, bucket, object, uploadID string) error { + return nil +} +func (m *stagingMockStorage) RemoveIncompleteUpload(ctx context.Context, bucket, object string) error { + return nil +} +func (m *stagingMockStorage) RemoveObject(ctx context.Context, bucket, key string, opts storage.RemoveObjectOptions) error { + return nil +} +func (m *stagingMockStorage) ListObjects(ctx context.Context, bucket, prefix string, recursive bool) ([]storage.ObjectInfo, error) { + return nil, nil +} +func (m *stagingMockStorage) RemoveObjects(ctx context.Context, bucket string, keys []string, opts storage.RemoveObjectsOptions) error { + return nil +} +func (m *stagingMockStorage) BucketExists(ctx context.Context, bucket string) (bool, error) { + return true, nil +} +func (m *stagingMockStorage) MakeBucket(ctx context.Context, bucket string, opts storage.MakeBucketOptions) error { + return nil +} +func (m *stagingMockStorage) StatObject(ctx context.Context, bucket, key string, opts storage.StatObjectOptions) (storage.ObjectInfo, error) { + return storage.ObjectInfo{}, nil +} + +func setupStagingTestDB(t *testing.T) *gorm.DB { + t.Helper() + db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{}) + if err != nil { + t.Fatalf("open sqlite: %v", err) + } + if err := db.AutoMigrate(&model.FileBlob{}, &model.File{}); err != nil { + t.Fatalf("migrate: %v", err) + } + return db +} + +func newStagingService(t *testing.T, st storage.ObjectStorage, db *gorm.DB) *FileStagingService { + t.Helper() + return NewFileStagingService( + store.NewFileStore(db), + store.NewBlobStore(db), + st, + "test-bucket", + zap.NewNop(), + ) +} + +func TestFileStaging_DownloadWithDedup(t *testing.T) { + db := setupStagingTestDB(t) + + sha1 := "aaa111" + sha2 := "bbb222" + + db.Create(&model.FileBlob{SHA256: sha1, MinioKey: "blobs/aaa111", FileSize: 5, MimeType: "text/plain", RefCount: 2}) + db.Create(&model.FileBlob{SHA256: sha2, MinioKey: "blobs/bbb222", FileSize: 3, MimeType: "text/plain", RefCount: 1}) + + db.Create(&model.File{Name: "file1.txt", BlobSHA256: sha1}) + db.Create(&model.File{Name: "file2.txt", BlobSHA256: sha1}) + db.Create(&model.File{Name: "file3.txt", BlobSHA256: sha2}) + + var files []model.File + db.Find(&files) + if len(files) < 3 { + t.Fatalf("need 3 files, got %d", len(files)) + } + + var getObjCalls int32 + st := &stagingMockStorage{} + st.getObjectFn = func(ctx context.Context, bucket, key string, opts storage.GetOptions) (io.ReadCloser, storage.ObjectInfo, error) { + atomic.AddInt32(&getObjCalls, 1) + var content string + switch key { + case "blobs/aaa111": + content = "content-a" + case "blobs/bbb222": + content = "content-b" + default: + return nil, storage.ObjectInfo{}, fmt.Errorf("unexpected key %s", key) + } + return io.NopCloser(bytes.NewReader([]byte(content))), storage.ObjectInfo{Key: key}, nil + } + + destDir := t.TempDir() + svc := newStagingService(t, st, db) + + err := svc.DownloadFilesToDir(context.Background(), []int64{files[0].ID, files[1].ID, files[2].ID}, destDir) + if err != nil { + t.Fatalf("DownloadFilesToDir: %v", err) + } + + if calls := atomic.LoadInt32(&getObjCalls); calls != 2 { + t.Errorf("GetObject called %d times, want 2", calls) + } + + expected := map[string]string{ + "file1.txt": "content-a", + "file2.txt": "content-a", + "file3.txt": "content-b", + } + for name, want := range expected { + p := filepath.Join(destDir, name) + data, err := os.ReadFile(p) + if err != nil { + t.Errorf("read %s: %v", name, err) + continue + } + if string(data) != want { + t.Errorf("%s content = %q, want %q", name, data, want) + } + } +} + +func TestFileStaging_PathTraversal(t *testing.T) { + db := setupStagingTestDB(t) + + sha := "traversal123" + db.Create(&model.FileBlob{SHA256: sha, MinioKey: "blobs/traversal", FileSize: 4, MimeType: "text/plain", RefCount: 1}) + db.Create(&model.File{Name: "../../../etc/passwd", BlobSHA256: sha}) + + var file model.File + db.First(&file) + + st := &stagingMockStorage{} + st.getObjectFn = func(ctx context.Context, bucket, key string, opts storage.GetOptions) (io.ReadCloser, storage.ObjectInfo, error) { + return io.NopCloser(bytes.NewReader([]byte("safe"))), storage.ObjectInfo{Key: key}, nil + } + + destDir := t.TempDir() + svc := newStagingService(t, st, db) + + err := svc.DownloadFilesToDir(context.Background(), []int64{file.ID}, destDir) + if err != nil { + t.Fatalf("DownloadFilesToDir: %v", err) + } + + sanitized := filepath.Join(destDir, "passwd") + data, err := os.ReadFile(sanitized) + if err != nil { + t.Fatalf("read sanitized file: %v", err) + } + if string(data) != "safe" { + t.Errorf("content = %q, want %q", data, "safe") + } + + entries, err := os.ReadDir(destDir) + if err != nil { + t.Fatalf("readdir: %v", err) + } + for _, e := range entries { + if e.Name() != "passwd" { + t.Errorf("unexpected file in destDir: %s", e.Name()) + } + } +} + +func TestFileStaging_EmptyList(t *testing.T) { + db := setupStagingTestDB(t) + st := &stagingMockStorage{} + svc := newStagingService(t, st, db) + + err := svc.DownloadFilesToDir(context.Background(), []int64{}, t.TempDir()) + if err != nil { + t.Errorf("expected nil for empty list, got %v", err) + } +} + +func TestFileStaging_GetObjectFails(t *testing.T) { + db := setupStagingTestDB(t) + + sha := "fail123" + db.Create(&model.FileBlob{SHA256: sha, MinioKey: "blobs/fail", FileSize: 5, MimeType: "text/plain", RefCount: 1}) + db.Create(&model.File{Name: "willfail.txt", BlobSHA256: sha}) + + var file model.File + db.First(&file) + + st := &stagingMockStorage{} + st.getObjectFn = func(ctx context.Context, bucket, key string, opts storage.GetOptions) (io.ReadCloser, storage.ObjectInfo, error) { + return nil, storage.ObjectInfo{}, fmt.Errorf("minio down") + } + + destDir := t.TempDir() + svc := newStagingService(t, st, db) + + err := svc.DownloadFilesToDir(context.Background(), []int64{file.ID}, destDir) + if err == nil { + t.Fatal("expected error when GetObject fails") + } + if !strings.Contains(err.Error(), "minio down") { + t.Errorf("error = %q, want 'minio down'", err.Error()) + } +} diff --git a/internal/service/script_utils.go b/internal/service/script_utils.go new file mode 100644 index 0000000..71df455 --- /dev/null +++ b/internal/service/script_utils.go @@ -0,0 +1,112 @@ +package service + +import ( + "fmt" + "math/rand" + "regexp" + "sort" + "strconv" + "strings" + + "gcy_hpc_server/internal/model" +) + +var paramNameRegex = regexp.MustCompile(`^[A-Za-z_][A-Za-z0-9_]*$`) + +// ValidateParams checks that all required parameters are present and values match their types. +// Parameters not in the schema are silently ignored. +func ValidateParams(params []model.ParameterSchema, values map[string]string) error { + var errs []string + + for _, p := range params { + if !paramNameRegex.MatchString(p.Name) { + errs = append(errs, fmt.Sprintf("invalid parameter name %q: must match ^[A-Za-z_][A-Za-z0-9_]*$", p.Name)) + continue + } + + val, ok := values[p.Name] + + if p.Required && !ok { + errs = append(errs, fmt.Sprintf("required parameter %q is missing", p.Name)) + continue + } + + if !ok { + continue + } + + switch p.Type { + case model.ParamTypeInteger: + if _, err := strconv.Atoi(val); err != nil { + errs = append(errs, fmt.Sprintf("parameter %q must be an integer, got %q", p.Name, val)) + } + case model.ParamTypeBoolean: + if val != "true" && val != "false" && val != "1" && val != "0" { + errs = append(errs, fmt.Sprintf("parameter %q must be a boolean (true/false/1/0), got %q", p.Name, val)) + } + case model.ParamTypeEnum: + if len(p.Options) > 0 { + found := false + for _, opt := range p.Options { + if val == opt { + found = true + break + } + } + if !found { + errs = append(errs, fmt.Sprintf("parameter %q must be one of %v, got %q", p.Name, p.Options, val)) + } + } + case model.ParamTypeFile, model.ParamTypeDirectory: + case model.ParamTypeString: + } + } + + if len(errs) > 0 { + return fmt.Errorf("parameter validation failed: %s", strings.Join(errs, "; ")) + } + return nil +} + +// RenderScript replaces $PARAM tokens in the template with user-provided values. +// Only tokens defined in the schema are replaced. Replacement is done longest-name-first +// to avoid partial matches (e.g., $JOB_NAME before $JOB). +// All values are shell-escaped using single-quote wrapping. +func RenderScript(template string, params []model.ParameterSchema, values map[string]string) string { + sorted := make([]model.ParameterSchema, len(params)) + copy(sorted, params) + sort.Slice(sorted, func(i, j int) bool { + return len(sorted[i].Name) > len(sorted[j].Name) + }) + + result := template + for _, p := range sorted { + val, ok := values[p.Name] + if !ok { + if p.Default != "" { + val = p.Default + } else { + continue + } + } + escaped := "'" + strings.ReplaceAll(val, "'", "'\\''") + "'" + result = strings.ReplaceAll(result, "$"+p.Name, escaped) + } + return result +} + +// SanitizeDirName sanitizes a directory name. +func SanitizeDirName(name string) string { + replacer := strings.NewReplacer(" ", "_", "/", "_", "\\", "_", ":", "_", "*", "_", "?", "_", "\"", "_", "<", "_", ">", "_", "|", "_") + return replacer.Replace(name) +} + +// RandomSuffix generates a random suffix of length n. +func RandomSuffix(n int) string { + const charset = "abcdefghijklmnopqrstuvwxyz0123456789" + b := make([]byte, n) + for i := range b { + b[i] = charset[rand.Intn(len(charset))] + } + return string(b) +} diff --git a/internal/service/task_service.go b/internal/service/task_service.go new file mode 100644 index 0000000..7af492c --- /dev/null +++ b/internal/service/task_service.go @@ -0,0 +1,554 @@ +package service + +import ( + "context" + "encoding/json" + "fmt" + "os" + "path/filepath" + "sort" + "strconv" + "strings" + "sync" + "time" + + "gcy_hpc_server/internal/model" + "gcy_hpc_server/internal/store" + + "go.uber.org/zap" +) + +type TaskService struct { + taskStore *store.TaskStore + appStore *store.ApplicationStore + fileStore *store.FileStore // nil ok + blobStore *store.BlobStore // nil ok + stagingSvc *FileStagingService // nil ok — MinIO unavailable + jobSvc *JobService + workDirBase string + logger *zap.Logger + + // async processing + taskCh chan int64 // buffered channel, cap=16 + cancelFn context.CancelFunc + wg sync.WaitGroup + mu sync.Mutex // protects taskCh from send-on-closed + started bool // prevent double-start + stopped bool +} + +func NewTaskService( + taskStore *store.TaskStore, + appStore *store.ApplicationStore, + fileStore *store.FileStore, + blobStore *store.BlobStore, + stagingSvc *FileStagingService, + jobSvc *JobService, + workDirBase string, + logger *zap.Logger, +) *TaskService { + return &TaskService{ + taskStore: taskStore, + appStore: appStore, + fileStore: fileStore, + blobStore: blobStore, + stagingSvc: stagingSvc, + jobSvc: jobSvc, + workDirBase: workDirBase, + logger: logger, + taskCh: make(chan int64, 16), + } +} + +func (s *TaskService) CreateTask(ctx context.Context, req *model.CreateTaskRequest) (*model.Task, error) { + app, err := s.appStore.GetByID(ctx, req.AppID) + if err != nil { + return nil, fmt.Errorf("get application: %w", err) + } + if app == nil { + return nil, fmt.Errorf("application %d not found", req.AppID) + } + + // 2. Validate file limit + if len(req.InputFileIDs) > 100 { + return nil, fmt.Errorf("input file count %d exceeds limit of 100", len(req.InputFileIDs)) + } + + // 3. Deduplicate file IDs + fileIDs := uniqueInt64s(req.InputFileIDs) + + // 4. Validate file IDs exist + if s.fileStore != nil && len(fileIDs) > 0 { + files, err := s.fileStore.GetByIDs(ctx, fileIDs) + if err != nil { + return nil, fmt.Errorf("validate file ids: %w", err) + } + found := make(map[int64]bool, len(files)) + for _, f := range files { + found[f.ID] = true + } + for _, id := range fileIDs { + if !found[id] { + return nil, fmt.Errorf("file %d not found", id) + } + } + } + + // 5. Auto-generate task name if empty + taskName := req.TaskName + if taskName == "" { + taskName = SanitizeDirName(app.Name) + "_" + time.Now().Format("20060102_150405") + } + + // 6. Marshal values + valuesJSON := json.RawMessage(`{}`) + if len(req.Values) > 0 { + b, err := json.Marshal(req.Values) + if err != nil { + return nil, fmt.Errorf("marshal values: %w", err) + } + valuesJSON = b + } + + // 7. Marshal input_file_ids + fileIDsJSON := json.RawMessage(`[]`) + if len(fileIDs) > 0 { + b, err := json.Marshal(fileIDs) + if err != nil { + return nil, fmt.Errorf("marshal file ids: %w", err) + } + fileIDsJSON = b + } + + // 8. Create task record + task := &model.Task{ + TaskName: taskName, + AppID: app.ID, + AppName: app.Name, + Status: model.TaskStatusSubmitted, + Values: valuesJSON, + InputFileIDs: fileIDsJSON, + SubmittedAt: time.Now(), + } + + taskID, err := s.taskStore.Create(ctx, task) + if err != nil { + return nil, fmt.Errorf("create task: %w", err) + } + task.ID = taskID + + return task, nil +} + +// ProcessTask runs the full synchronous processing pipeline for a task. +func (s *TaskService) ProcessTask(ctx context.Context, taskID int64) error { + // 1. Fetch task + task, err := s.taskStore.GetByID(ctx, taskID) + if err != nil { + return fmt.Errorf("get task: %w", err) + } + if task == nil { + return fmt.Errorf("task %d not found", taskID) + } + + fail := func(step, msg string) error { + _ = s.taskStore.UpdateStatus(ctx, taskID, model.TaskStatusFailed, msg) + _ = s.taskStore.UpdateRetryState(ctx, taskID, model.TaskStatusFailed, step, task.RetryCount) + return fmt.Errorf("%s", msg) + } + + currentStep := task.CurrentStep + + var workDir string + var app *model.Application + + if currentStep == "" || currentStep == model.TaskStepPreparing { + // 2. Set preparing + if err := s.taskStore.UpdateRetryState(ctx, taskID, model.TaskStatusPreparing, model.TaskStepPreparing, 0); err != nil { + return fail(model.TaskStepPreparing, fmt.Sprintf("update status to preparing: %v", err)) + } + + // 3. Fetch app + app, err = s.appStore.GetByID(ctx, task.AppID) + if err != nil { + return fail(model.TaskStepPreparing, fmt.Sprintf("get application: %v", err)) + } + if app == nil { + return fail(model.TaskStepPreparing, fmt.Sprintf("application %d not found", task.AppID)) + } + + // 4-5. Create work directory + workDir = filepath.Join(s.workDirBase, SanitizeDirName(app.Name), time.Now().Format("20060102_150405")+"_"+RandomSuffix(4)) + if err := os.MkdirAll(workDir, 0777); err != nil { + return fail(model.TaskStepPreparing, fmt.Sprintf("create work directory %s: %v", workDir, err)) + } + + // 6. CHMOD traversal — critical for multi-user HPC + for dir := workDir; dir != s.workDirBase; dir = filepath.Dir(dir) { + os.Chmod(dir, 0777) + } + os.Chmod(s.workDirBase, 0777) + + // 7. UpdateWorkDir + if err := s.taskStore.UpdateWorkDir(ctx, taskID, workDir); err != nil { + return fail(model.TaskStepPreparing, fmt.Sprintf("update work dir: %v", err)) + } + } else { + app, err = s.appStore.GetByID(ctx, task.AppID) + if err != nil { + return fail(currentStep, fmt.Sprintf("get application: %v", err)) + } + if app == nil { + return fail(currentStep, fmt.Sprintf("application %d not found", task.AppID)) + } + workDir = task.WorkDir + } + + if currentStep == "" || currentStep == model.TaskStepPreparing || currentStep == model.TaskStepDownloading { + if currentStep == model.TaskStepDownloading && workDir != "" { + matches, _ := filepath.Glob(filepath.Join(workDir, "*")) + for _, f := range matches { + os.Remove(f) + } + } + + // 8. Set downloading + if err := s.taskStore.UpdateRetryState(ctx, taskID, model.TaskStatusDownloading, model.TaskStepDownloading, 0); err != nil { + return fail(model.TaskStepDownloading, fmt.Sprintf("update status to downloading: %v", err)) + } + + // 9. Parse input_file_ids + var fileIDs []int64 + if len(task.InputFileIDs) > 0 { + if err := json.Unmarshal(task.InputFileIDs, &fileIDs); err != nil { + return fail(model.TaskStepDownloading, fmt.Sprintf("parse input file ids: %v", err)) + } + } + + // 10-12. Download files + if len(fileIDs) > 0 { + if s.stagingSvc == nil { + return fail(model.TaskStepDownloading, "MinIO unavailable, cannot stage files") + } + if err := s.stagingSvc.DownloadFilesToDir(ctx, fileIDs, workDir); err != nil { + return fail(model.TaskStepDownloading, fmt.Sprintf("download files: %v", err)) + } + } + } + + // 13-14. Set ready + submitting + if err := s.taskStore.UpdateRetryState(ctx, taskID, model.TaskStatusReady, model.TaskStepSubmitting, 0); err != nil { + return fail(model.TaskStepSubmitting, fmt.Sprintf("update status to ready: %v", err)) + } + + // 15. Parse app parameters + var params []model.ParameterSchema + if len(app.Parameters) > 0 { + if err := json.Unmarshal(app.Parameters, ¶ms); err != nil { + return fail(model.TaskStepSubmitting, fmt.Sprintf("parse parameters: %v", err)) + } + } + + // 16. Parse task values + values := make(map[string]string) + if len(task.Values) > 0 { + if err := json.Unmarshal(task.Values, &values); err != nil { + return fail(model.TaskStepSubmitting, fmt.Sprintf("parse values: %v", err)) + } + } + + if err := ValidateParams(params, values); err != nil { + return fail(model.TaskStepSubmitting, err.Error()) + } + + // 17. Render script + rendered := RenderScript(app.ScriptTemplate, params, values) + + // 18. Submit to Slurm + jobResp, err := s.jobSvc.SubmitJob(ctx, &model.SubmitJobRequest{ + Script: rendered, + WorkDir: workDir, + }) + if err != nil { + return fail(model.TaskStepSubmitting, fmt.Sprintf("submit job: %v", err)) + } + + // 19. Update slurm_job_id and status to queued + if err := s.taskStore.UpdateSlurmJobID(ctx, taskID, &jobResp.JobID); err != nil { + return fail(model.TaskStepSubmitting, fmt.Sprintf("update slurm job id: %v", err)) + } + if err := s.taskStore.UpdateStatus(ctx, taskID, model.TaskStatusQueued, ""); err != nil { + return fail(model.TaskStepSubmitting, fmt.Sprintf("update status to queued: %v", err)) + } + + return nil +} + +// ListTasks returns a paginated list of tasks. +func (s *TaskService) ListTasks(ctx context.Context, query *model.TaskListQuery) ([]model.Task, int64, error) { + return s.taskStore.List(ctx, query) +} + +// ProcessTaskSync creates and processes a task synchronously, returning a JobResponse +// for old API compatibility. +func (s *TaskService) ProcessTaskSync(ctx context.Context, req *model.CreateTaskRequest) (*model.JobResponse, error) { + // 1. Create task + task, err := s.CreateTask(ctx, req) + if err != nil { + return nil, err + } + + // 2. Process synchronously + if err := s.ProcessTask(ctx, task.ID); err != nil { + return nil, err + } + + // 3. Re-fetch to get updated slurm_job_id + task, err = s.taskStore.GetByID(ctx, task.ID) + if err != nil { + return nil, fmt.Errorf("re-fetch task: %w", err) + } + if task == nil || task.SlurmJobID == nil { + return nil, fmt.Errorf("task has no slurm job id after processing") + } + + // 4. Return JobResponse + return &model.JobResponse{JobID: *task.SlurmJobID}, nil +} + +// uniqueInt64s deduplicates and sorts a slice of int64. +func uniqueInt64s(ids []int64) []int64 { + if len(ids) == 0 { + return nil + } + seen := make(map[int64]bool, len(ids)) + result := make([]int64, 0, len(ids)) + for _, id := range ids { + if !seen[id] { + seen[id] = true + result = append(result, id) + } + } + sort.Slice(result, func(i, j int) bool { return result[i] < result[j] }) + return result +} + +func (s *TaskService) mapSlurmStateToTaskStatus(slurmState []string) string { + if len(slurmState) == 0 { + return model.TaskStatusRunning + } + + state := strings.ToUpper(slurmState[0]) + switch state { + case "PENDING": + return model.TaskStatusQueued + case "RUNNING", "CONFIGURING", "COMPLETING", "SPECIAL_EXIT": + return model.TaskStatusRunning + case "COMPLETED": + return model.TaskStatusCompleted + case "FAILED", "CANCELLED", "TIMEOUT", "NODE_FAIL", "OUT_OF_MEMORY", "PREEMPTED": + return model.TaskStatusFailed + default: + return model.TaskStatusRunning + } +} + +func (s *TaskService) refreshTaskStatus(ctx context.Context, taskID int64) error { + task, err := s.taskStore.GetByID(ctx, taskID) + if err != nil { + s.logger.Error("failed to fetch task for refresh", + zap.Int64("task_id", taskID), + zap.Error(err), + ) + return err + } + if task == nil || task.SlurmJobID == nil { + return nil + } + + jobResp, err := s.jobSvc.GetJob(ctx, strconv.FormatInt(int64(*task.SlurmJobID), 10)) + if err != nil { + s.logger.Warn("failed to query slurm job status during refresh", + zap.Int64("task_id", taskID), + zap.Int32("slurm_job_id", *task.SlurmJobID), + zap.Error(err), + ) + return nil + } + if jobResp == nil { + return nil + } + + newStatus := s.mapSlurmStateToTaskStatus(jobResp.State) + if newStatus != task.Status { + s.logger.Info("updating task status from slurm", + zap.Int64("task_id", taskID), + zap.String("old_status", task.Status), + zap.String("new_status", newStatus), + ) + return s.taskStore.UpdateStatus(ctx, taskID, newStatus, "") + } + return nil +} + +func (s *TaskService) RefreshStaleTasks(ctx context.Context) error { + staleThreshold := 30 * time.Second + nonTerminal := []string{model.TaskStatusQueued, model.TaskStatusRunning} + + for _, status := range nonTerminal { + tasks, _, err := s.taskStore.List(ctx, &model.TaskListQuery{ + Status: status, + Page: 1, + PageSize: 1000, + }) + if err != nil { + s.logger.Warn("failed to list tasks for stale refresh", + zap.String("status", status), + zap.Error(err), + ) + continue + } + + cutoff := time.Now().Add(-staleThreshold) + for i := range tasks { + if tasks[i].UpdatedAt.Before(cutoff) { + if err := s.refreshTaskStatus(ctx, tasks[i].ID); err != nil { + s.logger.Warn("failed to refresh stale task", + zap.Int64("task_id", tasks[i].ID), + zap.Error(err), + ) + } + } + } + } + + return nil +} + +func (s *TaskService) StartProcessor(ctx context.Context) { + s.mu.Lock() + if s.started { + s.mu.Unlock() + return + } + s.started = true + s.mu.Unlock() + + ctx, s.cancelFn = context.WithCancel(ctx) + + s.wg.Add(1) + go func() { + defer s.wg.Done() + defer func() { + if r := recover(); r != nil { + s.logger.Error("processor panic", zap.Any("panic", r)) + } + }() + for { + select { + case <-ctx.Done(): + return + case taskID, ok := <-s.taskCh: + if !ok { + return + } + taskCtx, cancel := context.WithTimeout(ctx, 10*time.Minute) + s.processWithRetry(taskCtx, taskID) + cancel() + } + } + }() + + s.RecoverStuckTasks(ctx) +} + +func (s *TaskService) SubmitAsync(ctx context.Context, req *model.CreateTaskRequest) (int64, error) { + task, err := s.CreateTask(ctx, req) + if err != nil { + return 0, err + } + + s.mu.Lock() + if s.stopped { + s.mu.Unlock() + return 0, fmt.Errorf("processor stopped, cannot submit task") + } + select { + case s.taskCh <- task.ID: + default: + s.logger.Warn("task channel full, submit dropped", zap.Int64("taskID", task.ID)) + } + s.mu.Unlock() + + return task.ID, nil +} + +func (s *TaskService) StopProcessor() { + s.mu.Lock() + if s.stopped { + s.mu.Unlock() + return + } + s.stopped = true + close(s.taskCh) + s.mu.Unlock() + + if s.cancelFn != nil { + s.cancelFn() + } + s.wg.Wait() + + s.mu.Lock() + drainCh := s.taskCh + s.taskCh = make(chan int64, 16) + s.mu.Unlock() + + for taskID := range drainCh { + _ = s.taskStore.UpdateStatus(context.Background(), taskID, model.TaskStatusSubmitted, "") + } +} + +func (s *TaskService) processWithRetry(ctx context.Context, taskID int64) { + err := s.ProcessTask(ctx, taskID) + if err == nil { + return + } + + task, fetchErr := s.taskStore.GetByID(ctx, taskID) + if fetchErr != nil || task == nil { + return + } + + if task.RetryCount < 3 { + _ = s.taskStore.UpdateRetryState(ctx, taskID, model.TaskStatusSubmitted, task.CurrentStep, task.RetryCount+1) + s.mu.Lock() + if !s.stopped { + select { + case s.taskCh <- taskID: + default: + s.logger.Warn("task channel full, retry dropped", zap.Int64("taskID", taskID)) + } + } + s.mu.Unlock() + } +} + +func (s *TaskService) RecoverStuckTasks(ctx context.Context) { + tasks, err := s.taskStore.GetStuckTasks(ctx, 5*time.Minute) + if err != nil { + s.logger.Error("failed to get stuck tasks", zap.Error(err)) + return + } + for i := range tasks { + _ = s.taskStore.UpdateStatus(ctx, tasks[i].ID, model.TaskStatusSubmitted, "") + s.mu.Lock() + if !s.stopped { + select { + case s.taskCh <- tasks[i].ID: + default: + s.logger.Warn("task channel full, stuck task recovery dropped", zap.Int64("taskID", tasks[i].ID)) + } + } + s.mu.Unlock() + } +} diff --git a/internal/service/task_service_async_test.go b/internal/service/task_service_async_test.go new file mode 100644 index 0000000..53933a0 --- /dev/null +++ b/internal/service/task_service_async_test.go @@ -0,0 +1,416 @@ +package service + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + "net/http/httptest" + "os" + "path/filepath" + "sync/atomic" + "testing" + "time" + + "gcy_hpc_server/internal/model" + "gcy_hpc_server/internal/slurm" + "gcy_hpc_server/internal/store" + + "go.uber.org/zap" + "gorm.io/driver/sqlite" + gormlogger "gorm.io/gorm/logger" + + "gorm.io/gorm" +) + +func setupAsyncTestDB(t *testing.T) *gorm.DB { + t.Helper() + db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{ + Logger: gormlogger.Default.LogMode(gormlogger.Silent), + }) + if err != nil { + t.Fatalf("open sqlite: %v", err) + } + if err := db.AutoMigrate(&model.Task{}, &model.Application{}, &model.File{}, &model.FileBlob{}); err != nil { + t.Fatalf("auto migrate: %v", err) + } + return db +} + +type asyncTestEnv struct { + taskStore *store.TaskStore + appStore *store.ApplicationStore + svc *TaskService + srv *httptest.Server + db *gorm.DB + workDirBase string +} + +func newAsyncTestEnv(t *testing.T, slurmHandler http.HandlerFunc) *asyncTestEnv { + t.Helper() + db := setupAsyncTestDB(t) + + ts := store.NewTaskStore(db) + as := store.NewApplicationStore(db) + + srv := httptest.NewServer(slurmHandler) + client, _ := slurm.NewClient(srv.URL, srv.Client()) + jobSvc := NewJobService(client, zap.NewNop()) + + workDirBase := filepath.Join(t.TempDir(), "workdir") + os.MkdirAll(workDirBase, 0777) + + svc := NewTaskService(ts, as, nil, nil, nil, jobSvc, workDirBase, zap.NewNop()) + + return &asyncTestEnv{ + taskStore: ts, + appStore: as, + svc: svc, + srv: srv, + db: db, + workDirBase: workDirBase, + } +} + +func (e *asyncTestEnv) close() { + e.srv.Close() +} + +func (e *asyncTestEnv) createApp(t *testing.T, name, script string) int64 { + t.Helper() + id, err := e.appStore.Create(context.Background(), &model.CreateApplicationRequest{ + Name: name, + ScriptTemplate: script, + Parameters: json.RawMessage(`[]`), + }) + if err != nil { + t.Fatalf("create app: %v", err) + } + return id +} + +func TestTaskService_Async_SubmitAndProcess(t *testing.T) { + jobID := int32(42) + env := newAsyncTestEnv(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + json.NewEncoder(w).Encode(slurm.OpenapiJobSubmitResponse{ + Result: &slurm.JobSubmitResponseMsg{JobID: &jobID}, + }) + })) + defer env.close() + + appID := env.createApp(t, "async-app", "#!/bin/bash\necho hello") + + ctx := context.Background() + env.svc.StartProcessor(ctx) + + taskID, err := env.svc.SubmitAsync(ctx, &model.CreateTaskRequest{ + AppID: appID, + TaskName: "async-test", + }) + if err != nil { + t.Fatalf("SubmitAsync: %v", err) + } + if taskID == 0 { + t.Fatal("expected non-zero task ID") + } + + time.Sleep(500 * time.Millisecond) + + task, err := env.taskStore.GetByID(ctx, taskID) + if err != nil { + t.Fatalf("GetByID: %v", err) + } + if task.Status != model.TaskStatusQueued { + t.Errorf("Status = %q, want %q", task.Status, model.TaskStatusQueued) + } + + env.svc.StopProcessor() +} + +func TestTaskService_Retry_MaxExhaustion(t *testing.T) { + callCount := int32(0) + env := newAsyncTestEnv(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + atomic.AddInt32(&callCount, 1) + w.WriteHeader(http.StatusInternalServerError) + w.Write([]byte(`{"error":"slurm down"}`)) + })) + defer env.close() + + appID := env.createApp(t, "retry-app", "#!/bin/bash\necho hello") + + ctx := context.Background() + env.svc.StartProcessor(ctx) + + taskID, err := env.svc.SubmitAsync(ctx, &model.CreateTaskRequest{ + AppID: appID, + TaskName: "retry-test", + }) + if err != nil { + t.Fatalf("SubmitAsync: %v", err) + } + + time.Sleep(2 * time.Second) + + task, _ := env.taskStore.GetByID(ctx, taskID) + if task.Status != model.TaskStatusFailed { + t.Errorf("Status = %q, want %q", task.Status, model.TaskStatusFailed) + } + if task.RetryCount < 3 { + t.Errorf("RetryCount = %d, want >= 3", task.RetryCount) + } + + env.svc.StopProcessor() +} + +func TestTaskService_Recover_StuckTasks(t *testing.T) { + jobID := int32(99) + env := newAsyncTestEnv(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + json.NewEncoder(w).Encode(slurm.OpenapiJobSubmitResponse{ + Result: &slurm.JobSubmitResponseMsg{JobID: &jobID}, + }) + })) + defer env.close() + + appID := env.createApp(t, "stuck-app", "#!/bin/bash\necho hello") + + ctx := context.Background() + + task := &model.Task{ + TaskName: "stuck-task", + AppID: appID, + AppName: "stuck-app", + Status: model.TaskStatusPreparing, + CurrentStep: model.TaskStepPreparing, + RetryCount: 0, + SubmittedAt: time.Now(), + } + taskID, err := env.taskStore.Create(ctx, task) + if err != nil { + t.Fatalf("Create stuck task: %v", err) + } + + staleTime := time.Now().Add(-10 * time.Minute) + env.db.Exec("UPDATE hpc_tasks SET updated_at = ? WHERE id = ?", staleTime, taskID) + + env.svc.StartProcessor(ctx) + + time.Sleep(1 * time.Second) + + updated, _ := env.taskStore.GetByID(ctx, taskID) + if updated.Status != model.TaskStatusQueued { + t.Errorf("Status = %q, want %q", updated.Status, model.TaskStatusQueued) + } + + env.svc.StopProcessor() +} + +func TestTaskService_Shutdown_InFlight(t *testing.T) { + env := newAsyncTestEnv(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + time.Sleep(200 * time.Millisecond) + jobID := int32(77) + json.NewEncoder(w).Encode(slurm.OpenapiJobSubmitResponse{ + Result: &slurm.JobSubmitResponseMsg{JobID: &jobID}, + }) + })) + defer env.close() + + appID := env.createApp(t, "shutdown-app", "#!/bin/bash\necho hello") + + ctx := context.Background() + env.svc.StartProcessor(ctx) + + taskID, _ := env.svc.SubmitAsync(ctx, &model.CreateTaskRequest{ + AppID: appID, + TaskName: "shutdown-test", + }) + + done := make(chan struct{}) + go func() { + env.svc.StopProcessor() + close(done) + }() + + select { + case <-done: + case <-time.After(5 * time.Second): + t.Fatal("StopProcessor did not complete within timeout") + } + + task, _ := env.taskStore.GetByID(ctx, taskID) + if task.Status != model.TaskStatusQueued && task.Status != model.TaskStatusSubmitted { + t.Logf("task status after shutdown: %q (acceptable)", task.Status) + } +} + +func TestTaskService_PanicRecovery(t *testing.T) { + jobID := int32(55) + panicDone := int32(0) + env := newAsyncTestEnv(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if atomic.CompareAndSwapInt32(&panicDone, 0, 1) { + panic("intentional test panic") + } + json.NewEncoder(w).Encode(slurm.OpenapiJobSubmitResponse{ + Result: &slurm.JobSubmitResponseMsg{JobID: &jobID}, + }) + })) + defer env.close() + + appID := env.createApp(t, "panic-app", "#!/bin/bash\necho hello") + + ctx := context.Background() + env.svc.StartProcessor(ctx) + + taskID, _ := env.svc.SubmitAsync(ctx, &model.CreateTaskRequest{ + AppID: appID, + TaskName: "panic-test", + }) + + time.Sleep(1 * time.Second) + + atomic.StoreInt32(&panicDone, 1) + + env.svc.StopProcessor() + _ = taskID +} + +func TestTaskService_SubmitAsync_DuringShutdown(t *testing.T) { + env := newAsyncTestEnv(t, nil) + defer env.close() + + appID := env.createApp(t, "shutdown-err-app", "#!/bin/bash\necho hello") + + ctx := context.Background() + env.svc.StartProcessor(ctx) + env.svc.StopProcessor() + + _, err := env.svc.SubmitAsync(ctx, &model.CreateTaskRequest{ + AppID: appID, + TaskName: "after-shutdown", + }) + if err == nil { + t.Fatal("expected error when submitting after shutdown") + } +} + +// TestTaskService_SubmitAsync_ChannelFull_NonBlocking verifies SubmitAsync +// returns without blocking when the task channel buffer (cap=16) is full. +// Before fix: SubmitAsync holds s.mu while blocking on full channel → deadlock. +// After fix: non-blocking select returns immediately. +func TestTaskService_SubmitAsync_ChannelFull_NonBlocking(t *testing.T) { + jobID := int32(42) + env := newAsyncTestEnv(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + time.Sleep(5 * time.Second) + json.NewEncoder(w).Encode(slurm.OpenapiJobSubmitResponse{ + Result: &slurm.JobSubmitResponseMsg{JobID: &jobID}, + }) + })) + defer env.close() + + appID := env.createApp(t, "channel-full-app", "#!/bin/bash\necho hello") + ctx := context.Background() + + // Create all tasks BEFORE starting processor to avoid SQLite concurrent-write contention + taskIDs := make([]int64, 17) + for i := range taskIDs { + id, err := env.taskStore.Create(ctx, &model.Task{ + TaskName: fmt.Sprintf("fill-%d", i), + AppID: appID, + AppName: "channel-full-app", + Status: model.TaskStatusSubmitted, + CurrentStep: model.TaskStepSubmitting, + SubmittedAt: time.Now(), + }) + if err != nil { + t.Fatalf("create fill task %d: %v", i, err) + } + taskIDs[i] = id + } + + env.svc.StartProcessor(ctx) + defer env.svc.StopProcessor() + + // Consumer grabs first ID immediately; remaining 15 sit in channel. + // Push one more to fill buffer to 16 (full). + for _, id := range taskIDs { + env.svc.taskCh <- id + } + + // Overflow submit: must return within 3s (non-blocking after fix) + done := make(chan error, 1) + go func() { + _, submitErr := env.svc.SubmitAsync(ctx, &model.CreateTaskRequest{ + AppID: appID, + TaskName: "overflow-task", + }) + done <- submitErr + }() + + select { + case err := <-done: + if err != nil { + t.Logf("SubmitAsync returned error (acceptable after fix): %v", err) + } else { + t.Log("SubmitAsync returned without blocking — channel send is non-blocking") + } + case <-time.After(3 * time.Second): + t.Fatal("SubmitAsync blocked for >3s — channel send is blocking, potential deadlock") + } +} + +// TestTaskService_Retry_ChannelFull_NonBlocking verifies processWithRetry +// does not deadlock when re-enqueuing a failed task into a full channel. +// Before fix: processWithRetry holds s.mu while blocking on s.taskCh <- taskID → deadlock. +// After fix: non-blocking select drops the retry with a Warn log. +func TestTaskService_Retry_ChannelFull_NonBlocking(t *testing.T) { + env := newAsyncTestEnv(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + time.Sleep(1 * time.Second) + w.WriteHeader(http.StatusInternalServerError) + w.Write([]byte(`{"error":"slurm down"}`)) + })) + defer env.close() + + appID := env.createApp(t, "retry-full-app", "#!/bin/bash\necho hello") + ctx := context.Background() + + // Create all tasks BEFORE starting processor to avoid SQLite concurrent-write contention + taskIDs := make([]int64, 17) + for i := range taskIDs { + id, err := env.taskStore.Create(ctx, &model.Task{ + TaskName: fmt.Sprintf("retry-%d", i), + AppID: appID, + AppName: "retry-full-app", + Status: model.TaskStatusSubmitted, + CurrentStep: model.TaskStepSubmitting, + RetryCount: 0, + SubmittedAt: time.Now(), + }) + if err != nil { + t.Fatalf("create retry task %d: %v", i, err) + } + taskIDs[i] = id + } + + env.svc.StartProcessor(ctx) + + // Push all 17 IDs: consumer grabs one (processing ~1s), 16 fill the buffer + for _, id := range taskIDs { + env.svc.taskCh <- id + } + + // Wait for consumer to finish first task and attempt retry into full channel + time.Sleep(2 * time.Second) + + // If processWithRetry deadlocked holding s.mu, StopProcessor hangs on mutex acquisition + done := make(chan struct{}) + go func() { + env.svc.StopProcessor() + close(done) + }() + + select { + case <-done: + t.Log("StopProcessor completed — retry channel send is non-blocking") + case <-time.After(5 * time.Second): + t.Fatal("StopProcessor did not complete within 5s — deadlock from retry channel send") + } +} diff --git a/internal/service/task_service_status_test.go b/internal/service/task_service_status_test.go new file mode 100644 index 0000000..1a7d23c --- /dev/null +++ b/internal/service/task_service_status_test.go @@ -0,0 +1,294 @@ +package service + +import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + "time" + + "gcy_hpc_server/internal/model" + "gcy_hpc_server/internal/slurm" + "gcy_hpc_server/internal/store" + + "go.uber.org/zap" + "gorm.io/driver/sqlite" + "gorm.io/gorm" + "gorm.io/gorm/logger" +) + +func newTaskSvcTestDB(t *testing.T) *gorm.DB { + t.Helper() + db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{ + Logger: logger.Default.LogMode(logger.Silent), + }) + if err != nil { + t.Fatalf("open sqlite: %v", err) + } + if err := db.AutoMigrate(&model.Task{}); err != nil { + t.Fatalf("auto migrate: %v", err) + } + return db +} + +type taskSvcTestEnv struct { + taskStore *store.TaskStore + jobSvc *JobService + svc *TaskService + srv *httptest.Server + db *gorm.DB +} + +func newTaskSvcTestEnv(t *testing.T, handler http.HandlerFunc) *taskSvcTestEnv { + t.Helper() + db := newTaskSvcTestDB(t) + ts := store.NewTaskStore(db) + + srv := httptest.NewServer(handler) + client, _ := slurm.NewClient(srv.URL, srv.Client()) + jobSvc := NewJobService(client, zap.NewNop()) + svc := NewTaskService(ts, nil, nil, nil, nil, jobSvc, "/tmp", zap.NewNop()) + + return &taskSvcTestEnv{ + taskStore: ts, + jobSvc: jobSvc, + svc: svc, + srv: srv, + db: db, + } +} + +func (e *taskSvcTestEnv) close() { + e.srv.Close() +} + +func makeTaskForTest(name, status string, slurmJobID *int32) *model.Task { + return &model.Task{ + TaskName: name, + AppID: 1, + AppName: "test-app", + Status: status, + CurrentStep: "", + RetryCount: 0, + UserID: "user1", + SubmittedAt: time.Now(), + SlurmJobID: slurmJobID, + } +} + +func TestTaskService_MapSlurmState_AllStates(t *testing.T) { + env := newTaskSvcTestEnv(t, nil) + defer env.close() + + cases := []struct { + input []string + expected string + }{ + {[]string{"PENDING"}, model.TaskStatusQueued}, + {[]string{"RUNNING"}, model.TaskStatusRunning}, + {[]string{"CONFIGURING"}, model.TaskStatusRunning}, + {[]string{"COMPLETING"}, model.TaskStatusRunning}, + {[]string{"COMPLETED"}, model.TaskStatusCompleted}, + {[]string{"FAILED"}, model.TaskStatusFailed}, + {[]string{"CANCELLED"}, model.TaskStatusFailed}, + {[]string{"TIMEOUT"}, model.TaskStatusFailed}, + {[]string{"NODE_FAIL"}, model.TaskStatusFailed}, + {[]string{"OUT_OF_MEMORY"}, model.TaskStatusFailed}, + {[]string{"PREEMPTED"}, model.TaskStatusFailed}, + {[]string{"SPECIAL_EXIT"}, model.TaskStatusRunning}, + {[]string{"unknown_state"}, model.TaskStatusRunning}, + {[]string{"pending"}, model.TaskStatusQueued}, + {[]string{"Running"}, model.TaskStatusRunning}, + } + + for _, tc := range cases { + got := env.svc.mapSlurmStateToTaskStatus(tc.input) + if got != tc.expected { + t.Errorf("mapSlurmStateToTaskStatus(%v) = %q, want %q", tc.input, got, tc.expected) + } + } +} + +func TestTaskService_MapSlurmState_Empty(t *testing.T) { + env := newTaskSvcTestEnv(t, nil) + defer env.close() + + got := env.svc.mapSlurmStateToTaskStatus([]string{}) + if got != model.TaskStatusRunning { + t.Errorf("mapSlurmStateToTaskStatus([]) = %q, want %q", got, model.TaskStatusRunning) + } + + got = env.svc.mapSlurmStateToTaskStatus(nil) + if got != model.TaskStatusRunning { + t.Errorf("mapSlurmStateToTaskStatus(nil) = %q, want %q", got, model.TaskStatusRunning) + } +} + +func TestTaskService_RefreshTaskStatus_UpdatesDB(t *testing.T) { + jobID := int32(42) + env := newTaskSvcTestEnv(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + resp := slurm.OpenapiJobInfoResp{ + Jobs: slurm.JobInfoMsg{ + { + JobID: &jobID, + JobState: []string{"RUNNING"}, + }, + }, + } + json.NewEncoder(w).Encode(resp) + })) + defer env.close() + + ctx := context.Background() + task := makeTaskForTest("refresh-test", model.TaskStatusQueued, &jobID) + id, err := env.taskStore.Create(ctx, task) + if err != nil { + t.Fatalf("Create: %v", err) + } + + err = env.svc.refreshTaskStatus(ctx, id) + if err != nil { + t.Fatalf("refreshTaskStatus: %v", err) + } + + updated, _ := env.taskStore.GetByID(ctx, id) + if updated.Status != model.TaskStatusRunning { + t.Errorf("status = %q, want %q", updated.Status, model.TaskStatusRunning) + } +} + +func TestTaskService_RefreshTaskStatus_NoSlurmJobID(t *testing.T) { + env := newTaskSvcTestEnv(t, nil) + defer env.close() + + ctx := context.Background() + task := makeTaskForTest("no-slurm", model.TaskStatusQueued, nil) + id, _ := env.taskStore.Create(ctx, task) + + err := env.svc.refreshTaskStatus(ctx, id) + if err != nil { + t.Fatalf("expected no error, got %v", err) + } + + got, _ := env.taskStore.GetByID(ctx, id) + if got.Status != model.TaskStatusQueued { + t.Errorf("status should remain unchanged, got %q", got.Status) + } +} + +func TestTaskService_RefreshTaskStatus_SlurmError(t *testing.T) { + jobID := int32(42) + env := newTaskSvcTestEnv(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusInternalServerError) + w.Write([]byte(`{"error":"down"}`)) + })) + defer env.close() + + ctx := context.Background() + task := makeTaskForTest("slurm-err", model.TaskStatusQueued, &jobID) + id, _ := env.taskStore.Create(ctx, task) + + err := env.svc.refreshTaskStatus(ctx, id) + if err != nil { + t.Fatalf("expected no error (soft fail), got %v", err) + } + + got, _ := env.taskStore.GetByID(ctx, id) + if got.Status != model.TaskStatusQueued { + t.Errorf("status should remain unchanged on slurm error, got %q", got.Status) + } +} + +func TestTaskService_RefreshTaskStatus_NoChange(t *testing.T) { + jobID := int32(42) + env := newTaskSvcTestEnv(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + resp := slurm.OpenapiJobInfoResp{ + Jobs: slurm.JobInfoMsg{ + { + JobID: &jobID, + JobState: []string{"RUNNING"}, + }, + }, + } + json.NewEncoder(w).Encode(resp) + })) + defer env.close() + + ctx := context.Background() + task := makeTaskForTest("no-change", model.TaskStatusRunning, &jobID) + id, _ := env.taskStore.Create(ctx, task) + + err := env.svc.refreshTaskStatus(ctx, id) + if err != nil { + t.Fatalf("refreshTaskStatus: %v", err) + } + + got, _ := env.taskStore.GetByID(ctx, id) + if got.Status != model.TaskStatusRunning { + t.Errorf("status = %q, want %q", got.Status, model.TaskStatusRunning) + } +} + +func TestTaskService_RefreshStaleTasks_SkipsFresh(t *testing.T) { + jobID := int32(42) + slurmQueried := false + env := newTaskSvcTestEnv(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + slurmQueried = true + w.WriteHeader(http.StatusInternalServerError) + })) + defer env.close() + + ctx := context.Background() + task := makeTaskForTest("fresh-task", model.TaskStatusQueued, &jobID) + id, _ := env.taskStore.Create(ctx, task) + + freshTask, _ := env.taskStore.GetByID(ctx, id) + if freshTask == nil { + t.Fatal("task not found") + } + + env.db.Exec("UPDATE hpc_tasks SET updated_at = ? WHERE id = ?", time.Now(), id) + + err := env.svc.RefreshStaleTasks(ctx) + if err != nil { + t.Fatalf("RefreshStaleTasks: %v", err) + } + + if slurmQueried { + t.Error("expected no Slurm query for fresh task") + } +} + +func TestTaskService_RefreshStaleTasks_RefreshesStale(t *testing.T) { + jobID := int32(42) + env := newTaskSvcTestEnv(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + resp := slurm.OpenapiJobInfoResp{ + Jobs: slurm.JobInfoMsg{ + { + JobID: &jobID, + JobState: []string{"COMPLETED"}, + }, + }, + } + json.NewEncoder(w).Encode(resp) + })) + defer env.close() + + ctx := context.Background() + task := makeTaskForTest("stale-task", model.TaskStatusRunning, &jobID) + id, _ := env.taskStore.Create(ctx, task) + + staleTime := time.Now().Add(-60 * time.Second) + env.db.Exec("UPDATE hpc_tasks SET updated_at = ? WHERE id = ?", staleTime, id) + + err := env.svc.RefreshStaleTasks(ctx) + if err != nil { + t.Fatalf("RefreshStaleTasks: %v", err) + } + + got, _ := env.taskStore.GetByID(ctx, id) + if got.Status != model.TaskStatusCompleted { + t.Errorf("status = %q, want %q", got.Status, model.TaskStatusCompleted) + } +} diff --git a/internal/service/task_service_test.go b/internal/service/task_service_test.go new file mode 100644 index 0000000..d87c70b --- /dev/null +++ b/internal/service/task_service_test.go @@ -0,0 +1,538 @@ +package service + +import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "os" + "path/filepath" + "strings" + "testing" + + "gcy_hpc_server/internal/model" + "gcy_hpc_server/internal/slurm" + "gcy_hpc_server/internal/store" + + "go.uber.org/zap" + "gorm.io/driver/sqlite" + gormlogger "gorm.io/gorm/logger" + + "gorm.io/gorm" +) + +func setupTaskTestDB(t *testing.T) *gorm.DB { + t.Helper() + db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{ + Logger: gormlogger.Default.LogMode(gormlogger.Silent), + }) + if err != nil { + t.Fatalf("open sqlite: %v", err) + } + if err := db.AutoMigrate(&model.Task{}, &model.Application{}, &model.File{}, &model.FileBlob{}); err != nil { + t.Fatalf("auto migrate: %v", err) + } + return db +} + +type taskTestEnv struct { + taskStore *store.TaskStore + appStore *store.ApplicationStore + fileStore *store.FileStore + blobStore *store.BlobStore + svc *TaskService + srv *httptest.Server + db *gorm.DB + workDirBase string +} + +func newTaskTestEnv(t *testing.T, slurmHandler http.HandlerFunc) *taskTestEnv { + t.Helper() + db := setupTaskTestDB(t) + + ts := store.NewTaskStore(db) + as := store.NewApplicationStore(db) + fs := store.NewFileStore(db) + bs := store.NewBlobStore(db) + + srv := httptest.NewServer(slurmHandler) + client, _ := slurm.NewClient(srv.URL, srv.Client()) + jobSvc := NewJobService(client, zap.NewNop()) + + workDirBase := filepath.Join(t.TempDir(), "workdir") + os.MkdirAll(workDirBase, 0777) + + svc := NewTaskService(ts, as, fs, bs, nil, jobSvc, workDirBase, zap.NewNop()) + + return &taskTestEnv{ + taskStore: ts, + appStore: as, + fileStore: fs, + blobStore: bs, + svc: svc, + srv: srv, + db: db, + workDirBase: workDirBase, + } +} + +func (e *taskTestEnv) close() { + e.srv.Close() +} + +func (e *taskTestEnv) createApp(t *testing.T, name, script string, params json.RawMessage) int64 { + t.Helper() + id, err := e.appStore.Create(context.Background(), &model.CreateApplicationRequest{ + Name: name, + ScriptTemplate: script, + Parameters: params, + }) + if err != nil { + t.Fatalf("create app: %v", err) + } + return id +} + +func TestTaskService_CreateTask_Success(t *testing.T) { + env := newTaskTestEnv(t, nil) + defer env.close() + + appID := env.createApp(t, "my-app", "#!/bin/bash\necho hello", json.RawMessage(`[]`)) + + task, err := env.svc.CreateTask(context.Background(), &model.CreateTaskRequest{ + AppID: appID, + TaskName: "test-task", + Values: map[string]string{"KEY": "val"}, + }) + if err != nil { + t.Fatalf("CreateTask: %v", err) + } + if task.ID == 0 { + t.Error("expected non-zero task ID") + } + if task.AppID != appID { + t.Errorf("AppID = %d, want %d", task.AppID, appID) + } + if task.AppName != "my-app" { + t.Errorf("AppName = %q, want %q", task.AppName, "my-app") + } + if task.Status != model.TaskStatusSubmitted { + t.Errorf("Status = %q, want %q", task.Status, model.TaskStatusSubmitted) + } + if task.TaskName != "test-task" { + t.Errorf("TaskName = %q, want %q", task.TaskName, "test-task") + } + + var values map[string]string + if err := json.Unmarshal(task.Values, &values); err != nil { + t.Fatalf("unmarshal values: %v", err) + } + if values["KEY"] != "val" { + t.Errorf("values[KEY] = %q, want %q", values["KEY"], "val") + } +} + +func TestTaskService_CreateTask_InvalidAppID(t *testing.T) { + env := newTaskTestEnv(t, nil) + defer env.close() + + _, err := env.svc.CreateTask(context.Background(), &model.CreateTaskRequest{ + AppID: 999, + }) + if err == nil { + t.Fatal("expected error for invalid app_id") + } + if !strings.Contains(err.Error(), "not found") { + t.Errorf("error should mention 'not found', got: %v", err) + } +} + +func TestTaskService_CreateTask_ExceedsFileLimit(t *testing.T) { + env := newTaskTestEnv(t, nil) + defer env.close() + + appID := env.createApp(t, "app", "#!/bin/bash\necho hi", nil) + + fileIDs := make([]int64, 101) + for i := range fileIDs { + fileIDs[i] = int64(i + 1) + } + + _, err := env.svc.CreateTask(context.Background(), &model.CreateTaskRequest{ + AppID: appID, + InputFileIDs: fileIDs, + }) + if err == nil { + t.Fatal("expected error for exceeding file limit") + } + if !strings.Contains(err.Error(), "exceeds limit") { + t.Errorf("error should mention limit, got: %v", err) + } +} + +func TestTaskService_CreateTask_DuplicateFileIDs(t *testing.T) { + env := newTaskTestEnv(t, nil) + defer env.close() + + appID := env.createApp(t, "app", "#!/bin/bash\necho hi", nil) + + ctx := context.Background() + for _, id := range []int64{1, 2} { + f := &model.File{ + Name: "file.txt", + BlobSHA256: "abc123", + } + if err := env.fileStore.Create(ctx, f); err != nil { + t.Fatalf("create file: %v", err) + } + if f.ID != id { + t.Fatalf("expected file ID %d, got %d", id, f.ID) + } + } + + task, err := env.svc.CreateTask(ctx, &model.CreateTaskRequest{ + AppID: appID, + InputFileIDs: []int64{1, 1, 2, 2}, + }) + if err != nil { + t.Fatalf("CreateTask: %v", err) + } + + var fileIDs []int64 + if err := json.Unmarshal(task.InputFileIDs, &fileIDs); err != nil { + t.Fatalf("unmarshal file ids: %v", err) + } + if len(fileIDs) != 2 { + t.Fatalf("expected 2 deduplicated file IDs, got %d: %v", len(fileIDs), fileIDs) + } + if fileIDs[0] != 1 || fileIDs[1] != 2 { + t.Errorf("expected [1,2], got %v", fileIDs) + } +} + +func TestTaskService_CreateTask_AutoName(t *testing.T) { + env := newTaskTestEnv(t, nil) + defer env.close() + + appID := env.createApp(t, "My Cool App", "#!/bin/bash\necho hi", nil) + + task, err := env.svc.CreateTask(context.Background(), &model.CreateTaskRequest{ + AppID: appID, + }) + if err != nil { + t.Fatalf("CreateTask: %v", err) + } + if !strings.HasPrefix(task.TaskName, "My_Cool_App_") { + t.Errorf("auto-generated name should start with 'My_Cool_App_', got %q", task.TaskName) + } +} + +func TestTaskService_CreateTask_NilValues(t *testing.T) { + env := newTaskTestEnv(t, nil) + defer env.close() + + appID := env.createApp(t, "app", "#!/bin/bash\necho hi", nil) + + task, err := env.svc.CreateTask(context.Background(), &model.CreateTaskRequest{ + AppID: appID, + Values: nil, + }) + if err != nil { + t.Fatalf("CreateTask: %v", err) + } + if string(task.Values) != `{}` { + t.Errorf("Values = %q, want {}", string(task.Values)) + } +} + +func TestTaskService_ProcessTask_Success(t *testing.T) { + jobID := int32(42) + env := newTaskTestEnv(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + json.NewEncoder(w).Encode(slurm.OpenapiJobSubmitResponse{ + Result: &slurm.JobSubmitResponseMsg{JobID: &jobID}, + }) + })) + defer env.close() + + appID := env.createApp(t, "test-app", "#!/bin/bash\necho $INPUT", json.RawMessage(`[{"name":"INPUT","type":"string","required":true}]`)) + + task, err := env.svc.CreateTask(context.Background(), &model.CreateTaskRequest{ + AppID: appID, + Values: map[string]string{"INPUT": "hello"}, + }) + if err != nil { + t.Fatalf("CreateTask: %v", err) + } + + err = env.svc.ProcessTask(context.Background(), task.ID) + if err != nil { + t.Fatalf("ProcessTask: %v", err) + } + + updated, _ := env.taskStore.GetByID(context.Background(), task.ID) + if updated.Status != model.TaskStatusQueued { + t.Errorf("Status = %q, want %q", updated.Status, model.TaskStatusQueued) + } + if updated.SlurmJobID == nil || *updated.SlurmJobID != 42 { + t.Errorf("SlurmJobID = %v, want 42", updated.SlurmJobID) + } + if updated.WorkDir == "" { + t.Error("WorkDir should not be empty") + } + if !strings.HasPrefix(updated.WorkDir, env.workDirBase) { + t.Errorf("WorkDir = %q, should start with %q", updated.WorkDir, env.workDirBase) + } + + info, err := os.Stat(updated.WorkDir) + if err != nil { + t.Fatalf("stat workdir: %v", err) + } + if !info.IsDir() { + t.Error("WorkDir should be a directory") + } +} + +func TestTaskService_ProcessTask_TaskNotFound(t *testing.T) { + env := newTaskTestEnv(t, nil) + defer env.close() + + err := env.svc.ProcessTask(context.Background(), 999) + if err == nil { + t.Fatal("expected error for non-existent task") + } + if !strings.Contains(err.Error(), "not found") { + t.Errorf("error should mention 'not found', got: %v", err) + } +} + +func TestTaskService_ProcessTask_SlurmError(t *testing.T) { + env := newTaskTestEnv(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusInternalServerError) + w.Write([]byte(`{"error":"slurm down"}`)) + })) + defer env.close() + + appID := env.createApp(t, "test-app", "#!/bin/bash\necho hello", nil) + + task, err := env.svc.CreateTask(context.Background(), &model.CreateTaskRequest{ + AppID: appID, + }) + if err != nil { + t.Fatalf("CreateTask: %v", err) + } + + err = env.svc.ProcessTask(context.Background(), task.ID) + if err == nil { + t.Fatal("expected error from Slurm") + } + + updated, _ := env.taskStore.GetByID(context.Background(), task.ID) + if updated.Status != model.TaskStatusFailed { + t.Errorf("Status = %q, want %q", updated.Status, model.TaskStatusFailed) + } + if updated.CurrentStep != model.TaskStepSubmitting { + t.Errorf("CurrentStep = %q, want %q", updated.CurrentStep, model.TaskStepSubmitting) + } + if !strings.Contains(updated.ErrorMessage, "submit job") { + t.Errorf("ErrorMessage should mention 'submit job', got: %q", updated.ErrorMessage) + } +} + +func TestTaskService_ProcessTaskSync(t *testing.T) { + jobID := int32(42) + env := newTaskTestEnv(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + json.NewEncoder(w).Encode(slurm.OpenapiJobSubmitResponse{ + Result: &slurm.JobSubmitResponseMsg{JobID: &jobID}, + }) + })) + defer env.close() + + appID := env.createApp(t, "sync-app", "#!/bin/bash\necho hello", nil) + + resp, err := env.svc.ProcessTaskSync(context.Background(), &model.CreateTaskRequest{ + AppID: appID, + }) + if err != nil { + t.Fatalf("ProcessTaskSync: %v", err) + } + if resp.JobID != 42 { + t.Errorf("JobID = %d, want 42", resp.JobID) + } +} + +func TestTaskService_ProcessTaskSync_NoMinIO(t *testing.T) { + jobID := int32(42) + env := newTaskTestEnv(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + json.NewEncoder(w).Encode(slurm.OpenapiJobSubmitResponse{ + Result: &slurm.JobSubmitResponseMsg{JobID: &jobID}, + }) + })) + defer env.close() + + appID := env.createApp(t, "no-minio-app", "#!/bin/bash\necho hello", nil) + + resp, err := env.svc.ProcessTaskSync(context.Background(), &model.CreateTaskRequest{ + AppID: appID, + InputFileIDs: nil, + }) + if err != nil { + t.Fatalf("ProcessTaskSync: %v", err) + } + if resp.JobID != 42 { + t.Errorf("JobID = %d, want 42", resp.JobID) + } +} + +func TestTaskService_ProcessTask_NilValues(t *testing.T) { + jobID := int32(55) + env := newTaskTestEnv(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + json.NewEncoder(w).Encode(slurm.OpenapiJobSubmitResponse{ + Result: &slurm.JobSubmitResponseMsg{JobID: &jobID}, + }) + })) + defer env.close() + + appID := env.createApp(t, "nil-val-app", "#!/bin/bash\necho hello", nil) + + task, err := env.svc.CreateTask(context.Background(), &model.CreateTaskRequest{ + AppID: appID, + Values: nil, + }) + if err != nil { + t.Fatalf("CreateTask: %v", err) + } + + err = env.svc.ProcessTask(context.Background(), task.ID) + if err != nil { + t.Fatalf("ProcessTask: %v", err) + } + + updated, _ := env.taskStore.GetByID(context.Background(), task.ID) + if updated.Status != model.TaskStatusQueued { + t.Errorf("Status = %q, want %q", updated.Status, model.TaskStatusQueued) + } +} + +func TestTaskService_ListTasks(t *testing.T) { + env := newTaskTestEnv(t, nil) + defer env.close() + + appID := env.createApp(t, "list-app", "#!/bin/bash\necho hi", nil) + + for i := 0; i < 3; i++ { + _, err := env.svc.CreateTask(context.Background(), &model.CreateTaskRequest{ + AppID: appID, + TaskName: "task-" + string(rune('A'+i)), + }) + if err != nil { + t.Fatalf("CreateTask %d: %v", i, err) + } + } + + tasks, total, err := env.svc.ListTasks(context.Background(), &model.TaskListQuery{ + Page: 1, + PageSize: 10, + }) + if err != nil { + t.Fatalf("ListTasks: %v", err) + } + if total != 3 { + t.Errorf("total = %d, want 3", total) + } + if len(tasks) != 3 { + t.Errorf("len(tasks) = %d, want 3", len(tasks)) + } +} + +func TestTaskService_ProcessTask_ValidateParams_MissingRequired(t *testing.T) { + jobID := int32(42) + env := newTaskTestEnv(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + json.NewEncoder(w).Encode(slurm.OpenapiJobSubmitResponse{ + Result: &slurm.JobSubmitResponseMsg{JobID: &jobID}, + }) + })) + defer env.close() + + // App requires INPUT param, but we submit without it + appID := env.createApp(t, "validation-app", "#!/bin/bash\necho $INPUT", json.RawMessage(`[{"name":"INPUT","type":"string","required":true}]`)) + + task, err := env.svc.CreateTask(context.Background(), &model.CreateTaskRequest{ + AppID: appID, + Values: map[string]string{}, // missing required INPUT + }) + if err != nil { + t.Fatalf("CreateTask: %v", err) + } + + err = env.svc.ProcessTask(context.Background(), task.ID) + if err == nil { + t.Fatal("expected error for missing required parameter, got nil — ValidateParams is not being called in ProcessTask pipeline") + } + errStr := err.Error() + if !strings.Contains(errStr, "validation") && !strings.Contains(errStr, "missing") && !strings.Contains(errStr, "INPUT") { + t.Errorf("error should mention 'validation', 'missing', or 'INPUT', got: %v", err) + } +} + +func TestTaskService_ProcessTask_ValidateParams_InvalidInteger(t *testing.T) { + jobID := int32(42) + env := newTaskTestEnv(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + json.NewEncoder(w).Encode(slurm.OpenapiJobSubmitResponse{ + Result: &slurm.JobSubmitResponseMsg{JobID: &jobID}, + }) + })) + defer env.close() + + // App expects integer param NUM, but we submit "abc" + appID := env.createApp(t, "int-validation-app", "#!/bin/bash\necho $NUM", json.RawMessage(`[{"name":"NUM","type":"integer","required":true}]`)) + + task, err := env.svc.CreateTask(context.Background(), &model.CreateTaskRequest{ + AppID: appID, + Values: map[string]string{"NUM": "abc"}, // invalid integer + }) + if err != nil { + t.Fatalf("CreateTask: %v", err) + } + + err = env.svc.ProcessTask(context.Background(), task.ID) + if err == nil { + t.Fatal("expected error for invalid integer parameter, got nil — ValidateParams is not being called in ProcessTask pipeline") + } + errStr := err.Error() + if !strings.Contains(errStr, "integer") && !strings.Contains(errStr, "validation") && !strings.Contains(errStr, "NUM") { + t.Errorf("error should mention 'integer', 'validation', or 'NUM', got: %v", err) + } +} + +func TestTaskService_ProcessTask_ValidateParams_ValidParamsSucceed(t *testing.T) { + jobID := int32(99) + env := newTaskTestEnv(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + json.NewEncoder(w).Encode(slurm.OpenapiJobSubmitResponse{ + Result: &slurm.JobSubmitResponseMsg{JobID: &jobID}, + }) + })) + defer env.close() + + appID := env.createApp(t, "valid-params-app", "#!/bin/bash\necho $INPUT", json.RawMessage(`[{"name":"INPUT","type":"string","required":true}]`)) + + task, err := env.svc.CreateTask(context.Background(), &model.CreateTaskRequest{ + AppID: appID, + Values: map[string]string{"INPUT": "hello"}, + }) + if err != nil { + t.Fatalf("CreateTask: %v", err) + } + + err = env.svc.ProcessTask(context.Background(), task.ID) + if err != nil { + t.Fatalf("ProcessTask with valid params: %v", err) + } + + updated, _ := env.taskStore.GetByID(context.Background(), task.ID) + if updated.Status != model.TaskStatusQueued { + t.Errorf("Status = %q, want %q", updated.Status, model.TaskStatusQueued) + } + if updated.SlurmJobID == nil || *updated.SlurmJobID != 99 { + t.Errorf("SlurmJobID = %v, want 99", updated.SlurmJobID) + } +}