package service

import (
	"context"
	"encoding/json"
	"fmt"
	"os"
	"path/filepath"
	"sort"
	"strconv"
	"strings"
	"sync"
	"time"

	"gcy_hpc_server/internal/model"
	"gcy_hpc_server/internal/store"

	"go.uber.org/zap"
)

type TaskService struct {
	taskStore   *store.TaskStore
	appStore    *store.ApplicationStore
	fileStore   *store.FileStore    // nil ok
	blobStore   *store.BlobStore    // nil ok
	stagingSvc  *FileStagingService // nil ok — MinIO unavailable
	jobSvc      *JobService
	workDirBase string
	logger      *zap.Logger

	// async processing
	taskCh   chan int64 // buffered channel, cap=16
	cancelFn context.CancelFunc
	wg       sync.WaitGroup
	mu       sync.Mutex // protects taskCh from send-on-closed
	started  bool       // prevent double-start
	stopped  bool
}

func NewTaskService(
	taskStore *store.TaskStore,
	appStore *store.ApplicationStore,
	fileStore *store.FileStore,
	blobStore *store.BlobStore,
	stagingSvc *FileStagingService,
	jobSvc *JobService,
	workDirBase string,
	logger *zap.Logger,
) *TaskService {
	return &TaskService{
		taskStore:   taskStore,
		appStore:    appStore,
		fileStore:   fileStore,
		blobStore:   blobStore,
		stagingSvc:  stagingSvc,
		jobSvc:      jobSvc,
		workDirBase: workDirBase,
		logger:      logger,
		taskCh:      make(chan int64, 16),
	}
}

// CreateTask validates the request and persists a new task record in the submitted state.
func (s *TaskService) CreateTask(ctx context.Context, req *model.CreateTaskRequest) (*model.Task, error) {
	// 1. Validate the application exists
	app, err := s.appStore.GetByID(ctx, req.AppID)
	if err != nil {
		return nil, fmt.Errorf("get application: %w", err)
	}
	if app == nil {
		return nil, fmt.Errorf("application %d not found", req.AppID)
	}

	// 2. Validate file limit
	if len(req.InputFileIDs) > 100 {
		return nil, fmt.Errorf("input file count %d exceeds limit of 100", len(req.InputFileIDs))
	}

	// 3. Deduplicate file IDs
	fileIDs := uniqueInt64s(req.InputFileIDs)

	// 4. Validate file IDs exist
	if s.fileStore != nil && len(fileIDs) > 0 {
		files, err := s.fileStore.GetByIDs(ctx, fileIDs)
		if err != nil {
			return nil, fmt.Errorf("validate file ids: %w", err)
		}
		found := make(map[int64]bool, len(files))
		for _, f := range files {
			found[f.ID] = true
		}
		for _, id := range fileIDs {
			if !found[id] {
				return nil, fmt.Errorf("file %d not found", id)
			}
		}
	}

	// 5. Auto-generate task name if empty
	taskName := req.TaskName
	if taskName == "" {
		taskName = SanitizeDirName(app.Name) + "_" + time.Now().Format("20060102_150405")
	}

	// 6. Marshal values
	valuesJSON := json.RawMessage(`{}`)
	if len(req.Values) > 0 {
		b, err := json.Marshal(req.Values)
		if err != nil {
			return nil, fmt.Errorf("marshal values: %w", err)
		}
		valuesJSON = b
	}

	// 7. Marshal input_file_ids
	fileIDsJSON := json.RawMessage(`[]`)
	if len(fileIDs) > 0 {
		b, err := json.Marshal(fileIDs)
		if err != nil {
			return nil, fmt.Errorf("marshal file ids: %w", err)
		}
		fileIDsJSON = b
	}

	// 8. Create task record
	task := &model.Task{
		TaskName:     taskName,
		AppID:        app.ID,
		AppName:      app.Name,
		Status:       model.TaskStatusSubmitted,
		Values:       valuesJSON,
		InputFileIDs: fileIDsJSON,
		SubmittedAt:  time.Now(),
	}
	taskID, err := s.taskStore.Create(ctx, task)
	if err != nil {
		return nil, fmt.Errorf("create task: %w", err)
	}
	task.ID = taskID
	return task, nil
}
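// An illustrative CreateTask request with hypothetical values. The field names
// come from model.CreateTaskRequest as used above; the map type for Values is
// assumed from how ProcessTask later decodes task.Values:
//
//	req := &model.CreateTaskRequest{
//		AppID:        42,
//		TaskName:     "", // empty: a name is generated from the app name and a timestamp
//		Values:       map[string]string{"threads": "8"},
//		InputFileIDs: []int64{101, 102, 101}, // duplicates are dropped and IDs are sorted
//	}
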
// ProcessTask runs the full synchronous processing pipeline for a task.
func (s *TaskService) ProcessTask(ctx context.Context, taskID int64) error {
	// 1. Fetch task
	task, err := s.taskStore.GetByID(ctx, taskID)
	if err != nil {
		return fmt.Errorf("get task: %w", err)
	}
	if task == nil {
		return fmt.Errorf("task %d not found", taskID)
	}

	fail := func(step, msg string) error {
		_ = s.taskStore.UpdateStatus(ctx, taskID, model.TaskStatusFailed, msg)
		_ = s.taskStore.UpdateRetryState(ctx, taskID, model.TaskStatusFailed, step, task.RetryCount)
		return fmt.Errorf("%s", msg)
	}

	currentStep := task.CurrentStep
	var workDir string
	var app *model.Application

	if currentStep == "" || currentStep == model.TaskStepPreparing {
		// 2. Set preparing
		if err := s.taskStore.UpdateRetryState(ctx, taskID, model.TaskStatusPreparing, model.TaskStepPreparing, 0); err != nil {
			return fail(model.TaskStepPreparing, fmt.Sprintf("update status to preparing: %v", err))
		}

		// 3. Fetch app
		app, err = s.appStore.GetByID(ctx, task.AppID)
		if err != nil {
			return fail(model.TaskStepPreparing, fmt.Sprintf("get application: %v", err))
		}
		if app == nil {
			return fail(model.TaskStepPreparing, fmt.Sprintf("application %d not found", task.AppID))
		}

		// 4-5. Create work directory
		workDir = filepath.Join(s.workDirBase, SanitizeDirName(app.Name), time.Now().Format("20060102_150405")+"_"+RandomSuffix(4))
		if err := os.MkdirAll(workDir, 0777); err != nil {
			return fail(model.TaskStepPreparing, fmt.Sprintf("create work directory %s: %v", workDir, err))
		}

		// 6. CHMOD traversal — critical for multi-user HPC.
		// The filepath.Dir(dir) guard stops the walk at the filesystem root in
		// case workDirBase is not a clean ancestor of workDir.
		for dir := workDir; dir != s.workDirBase && dir != filepath.Dir(dir); dir = filepath.Dir(dir) {
			os.Chmod(dir, 0777)
		}
		os.Chmod(s.workDirBase, 0777)

		// 7. UpdateWorkDir
		if err := s.taskStore.UpdateWorkDir(ctx, taskID, workDir); err != nil {
			return fail(model.TaskStepPreparing, fmt.Sprintf("update work dir: %v", err))
		}
	} else {
		app, err = s.appStore.GetByID(ctx, task.AppID)
		if err != nil {
			return fail(currentStep, fmt.Sprintf("get application: %v", err))
		}
		if app == nil {
			return fail(currentStep, fmt.Sprintf("application %d not found", task.AppID))
		}
		workDir = task.WorkDir
	}

	if currentStep == "" || currentStep == model.TaskStepPreparing || currentStep == model.TaskStepDownloading {
		// Clear any partially downloaded files before retrying the download step.
		if currentStep == model.TaskStepDownloading && workDir != "" {
			matches, _ := filepath.Glob(filepath.Join(workDir, "*"))
			for _, f := range matches {
				os.Remove(f)
			}
		}

		// 8. Set downloading
		if err := s.taskStore.UpdateRetryState(ctx, taskID, model.TaskStatusDownloading, model.TaskStepDownloading, 0); err != nil {
			return fail(model.TaskStepDownloading, fmt.Sprintf("update status to downloading: %v", err))
		}

		// 9. Parse input_file_ids
		var fileIDs []int64
		if len(task.InputFileIDs) > 0 {
			if err := json.Unmarshal(task.InputFileIDs, &fileIDs); err != nil {
				return fail(model.TaskStepDownloading, fmt.Sprintf("parse input file ids: %v", err))
			}
		}

		// 10-12. Download files
		if len(fileIDs) > 0 {
			if s.stagingSvc == nil {
				return fail(model.TaskStepDownloading, "MinIO unavailable, cannot stage files")
			}
			if err := s.stagingSvc.DownloadFilesToDir(ctx, fileIDs, workDir); err != nil {
				return fail(model.TaskStepDownloading, fmt.Sprintf("download files: %v", err))
			}
		}
	}

	// 13-14. Set ready + submitting
	if err := s.taskStore.UpdateRetryState(ctx, taskID, model.TaskStatusReady, model.TaskStepSubmitting, 0); err != nil {
		return fail(model.TaskStepSubmitting, fmt.Sprintf("update status to ready: %v", err))
	}

	// 15. Parse app parameters
	var params []model.ParameterSchema
	if len(app.Parameters) > 0 {
		if err := json.Unmarshal(app.Parameters, &params); err != nil {
			return fail(model.TaskStepSubmitting, fmt.Sprintf("parse parameters: %v", err))
		}
	}
	// 16. Parse task values
	values := make(map[string]string)
	if len(task.Values) > 0 {
		if err := json.Unmarshal(task.Values, &values); err != nil {
			return fail(model.TaskStepSubmitting, fmt.Sprintf("parse values: %v", err))
		}
	}
	if err := ValidateParams(params, values); err != nil {
		return fail(model.TaskStepSubmitting, err.Error())
	}

	// 17. Render script
	rendered := RenderScript(app.ScriptTemplate, params, values)

	// 18. Submit to Slurm
	jobResp, err := s.jobSvc.SubmitJob(ctx, &model.SubmitJobRequest{
		Script:  rendered,
		WorkDir: workDir,
	})
	if err != nil {
		return fail(model.TaskStepSubmitting, fmt.Sprintf("submit job: %v", err))
	}

	// 19. Update slurm_job_id and status to queued
	if err := s.taskStore.UpdateSlurmJobID(ctx, taskID, &jobResp.JobID); err != nil {
		return fail(model.TaskStepSubmitting, fmt.Sprintf("update slurm job id: %v", err))
	}
	if err := s.taskStore.UpdateStatus(ctx, taskID, model.TaskStatusQueued, ""); err != nil {
		return fail(model.TaskStepSubmitting, fmt.Sprintf("update status to queued: %v", err))
	}
	return nil
}

// ListTasks returns a paginated list of tasks.
func (s *TaskService) ListTasks(ctx context.Context, query *model.TaskListQuery) ([]model.Task, int64, error) {
	return s.taskStore.List(ctx, query)
}

// ProcessTaskSync creates and processes a task synchronously, returning a JobResponse
// for old API compatibility.
func (s *TaskService) ProcessTaskSync(ctx context.Context, req *model.CreateTaskRequest) (*model.JobResponse, error) {
	// 1. Create task
	task, err := s.CreateTask(ctx, req)
	if err != nil {
		return nil, err
	}

	// 2. Process synchronously
	if err := s.ProcessTask(ctx, task.ID); err != nil {
		return nil, err
	}

	// 3. Re-fetch to get updated slurm_job_id
	task, err = s.taskStore.GetByID(ctx, task.ID)
	if err != nil {
		return nil, fmt.Errorf("re-fetch task: %w", err)
	}
	if task == nil || task.SlurmJobID == nil {
		return nil, fmt.Errorf("task has no slurm job id after processing")
	}

	// 4. Return JobResponse
	return &model.JobResponse{JobID: *task.SlurmJobID}, nil
}
// uniqueInt64s deduplicates and sorts a slice of int64.
func uniqueInt64s(ids []int64) []int64 {
	if len(ids) == 0 {
		return nil
	}
	seen := make(map[int64]bool, len(ids))
	result := make([]int64, 0, len(ids))
	for _, id := range ids {
		if !seen[id] {
			seen[id] = true
			result = append(result, id)
		}
	}
	sort.Slice(result, func(i, j int) bool { return result[i] < result[j] })
	return result
}

// mapSlurmStateToTaskStatus maps a Slurm job state to an internal task status.
func (s *TaskService) mapSlurmStateToTaskStatus(slurmState []string) string {
	if len(slurmState) == 0 {
		return model.TaskStatusRunning
	}
	state := strings.ToUpper(slurmState[0])
	switch state {
	case "PENDING":
		return model.TaskStatusQueued
	case "RUNNING", "CONFIGURING", "COMPLETING", "SPECIAL_EXIT":
		return model.TaskStatusRunning
	case "COMPLETED":
		return model.TaskStatusCompleted
	case "FAILED", "CANCELLED", "TIMEOUT", "NODE_FAIL", "OUT_OF_MEMORY", "PREEMPTED":
		return model.TaskStatusFailed
	default:
		return model.TaskStatusRunning
	}
}

// refreshTaskStatus re-queries Slurm for a single task and persists any status change.
func (s *TaskService) refreshTaskStatus(ctx context.Context, taskID int64) error {
	task, err := s.taskStore.GetByID(ctx, taskID)
	if err != nil {
		s.logger.Error("failed to fetch task for refresh",
			zap.Int64("task_id", taskID),
			zap.Error(err),
		)
		return err
	}
	if task == nil || task.SlurmJobID == nil {
		return nil
	}

	jobResp, err := s.jobSvc.GetJob(ctx, strconv.FormatInt(int64(*task.SlurmJobID), 10))
	if err != nil {
		s.logger.Warn("failed to query slurm job status during refresh",
			zap.Int64("task_id", taskID),
			zap.Int32("slurm_job_id", *task.SlurmJobID),
			zap.Error(err),
		)
		return nil
	}
	if jobResp == nil {
		return nil
	}

	newStatus := s.mapSlurmStateToTaskStatus(jobResp.State)
	if newStatus != task.Status {
		s.logger.Info("updating task status from slurm",
			zap.Int64("task_id", taskID),
			zap.String("old_status", task.Status),
			zap.String("new_status", newStatus),
		)
		return s.taskStore.UpdateStatus(ctx, taskID, newStatus, "")
	}
	return nil
}

// RefreshStaleTasks refreshes the status of queued and running tasks that have
// not been updated within the stale threshold.
func (s *TaskService) RefreshStaleTasks(ctx context.Context) error {
	staleThreshold := 30 * time.Second
	nonTerminal := []string{model.TaskStatusQueued, model.TaskStatusRunning}

	for _, status := range nonTerminal {
		tasks, _, err := s.taskStore.List(ctx, &model.TaskListQuery{
			Status:   status,
			Page:     1,
			PageSize: 1000,
		})
		if err != nil {
			s.logger.Warn("failed to list tasks for stale refresh",
				zap.String("status", status),
				zap.Error(err),
			)
			continue
		}
		cutoff := time.Now().Add(-staleThreshold)
		for i := range tasks {
			if tasks[i].UpdatedAt.Before(cutoff) {
				if err := s.refreshTaskStatus(ctx, tasks[i].ID); err != nil {
					s.logger.Warn("failed to refresh stale task",
						zap.Int64("task_id", tasks[i].ID),
						zap.Error(err),
					)
				}
			}
		}
	}
	return nil
}

// StartProcessor launches the single background worker goroutine and recovers
// tasks that were stuck in flight. It is safe to call more than once.
func (s *TaskService) StartProcessor(ctx context.Context) {
	s.mu.Lock()
	if s.started {
		s.mu.Unlock()
		return
	}
	s.started = true
	s.mu.Unlock()

	ctx, s.cancelFn = context.WithCancel(ctx)
	s.wg.Add(1)
	go func() {
		defer s.wg.Done()
		defer func() {
			if r := recover(); r != nil {
				s.logger.Error("processor panic", zap.Any("panic", r))
			}
		}()
		for {
			select {
			case <-ctx.Done():
				return
			case taskID, ok := <-s.taskCh:
				if !ok {
					return
				}
				taskCtx, cancel := context.WithTimeout(ctx, 10*time.Minute)
				s.processWithRetry(taskCtx, taskID)
				cancel()
			}
		}
	}()

	s.RecoverStuckTasks(ctx)
}

// SubmitAsync creates a task and enqueues it for background processing. If the
// channel is full, the task stays in the submitted state until stuck-task
// recovery picks it up.
func (s *TaskService) SubmitAsync(ctx context.Context, req *model.CreateTaskRequest) (int64, error) {
	task, err := s.CreateTask(ctx, req)
	if err != nil {
		return 0, err
	}

	s.mu.Lock()
	if s.stopped {
		s.mu.Unlock()
		return 0, fmt.Errorf("processor stopped, cannot submit task")
	}
	select {
	case s.taskCh <- task.ID:
	default:
		s.logger.Warn("task channel full, submit dropped", zap.Int64("taskID", task.ID))
	}
	s.mu.Unlock()
	return task.ID, nil
}
// StopProcessor shuts down the worker goroutine, then drains any task IDs still
// buffered in the channel and resets them to the submitted state.
func (s *TaskService) StopProcessor() {
	s.mu.Lock()
	if s.stopped {
		s.mu.Unlock()
		return
	}
	s.stopped = true
	close(s.taskCh)
	s.mu.Unlock()

	if s.cancelFn != nil {
		s.cancelFn()
	}
	s.wg.Wait()

	s.mu.Lock()
	drainCh := s.taskCh
	s.taskCh = make(chan int64, 16)
	s.mu.Unlock()

	for taskID := range drainCh {
		_ = s.taskStore.UpdateStatus(context.Background(), taskID, model.TaskStatusSubmitted, "")
	}
}

// processWithRetry runs ProcessTask and, on failure, requeues the task for up
// to three retries, resuming from the recorded step.
func (s *TaskService) processWithRetry(ctx context.Context, taskID int64) {
	err := s.ProcessTask(ctx, taskID)
	if err == nil {
		return
	}

	task, fetchErr := s.taskStore.GetByID(ctx, taskID)
	if fetchErr != nil || task == nil {
		return
	}

	if task.RetryCount < 3 {
		_ = s.taskStore.UpdateRetryState(ctx, taskID, model.TaskStatusSubmitted, task.CurrentStep, task.RetryCount+1)
		s.mu.Lock()
		if !s.stopped {
			select {
			case s.taskCh <- taskID:
			default:
				s.logger.Warn("task channel full, retry dropped", zap.Int64("taskID", taskID))
			}
		}
		s.mu.Unlock()
	}
}

// RecoverStuckTasks resets tasks stuck in intermediate states for too long and
// requeues them for processing.
func (s *TaskService) RecoverStuckTasks(ctx context.Context) {
	tasks, err := s.taskStore.GetStuckTasks(ctx, 5*time.Minute)
	if err != nil {
		s.logger.Error("failed to get stuck tasks", zap.Error(err))
		return
	}
	for i := range tasks {
		_ = s.taskStore.UpdateStatus(ctx, tasks[i].ID, model.TaskStatusSubmitted, "")
		s.mu.Lock()
		if !s.stopped {
			select {
			case s.taskCh <- tasks[i].ID:
			default:
				s.logger.Warn("task channel full, stuck task recovery dropped", zap.Int64("taskID", tasks[i].ID))
			}
		}
		s.mu.Unlock()
	}
}
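// Typical wiring of the asynchronous pipeline, as a rough sketch. The stores,
// services, and work directory shown here are placeholders; the real values
// come from the server bootstrap:
//
//	svc := NewTaskService(taskStore, appStore, fileStore, blobStore, stagingSvc, jobSvc, "/data/hpc_work", logger)
//	svc.StartProcessor(ctx)   // single worker goroutine + recovery of stuck tasks
//	defer svc.StopProcessor() // close channel, wait for worker, requeue buffered tasks
//
//	taskID, err := svc.SubmitAsync(ctx, req) // returns immediately; processing happens in the worker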