feat(service): add TaskService, FileStagingService, and refactor ApplicationService for task submission

This commit is contained in:
dailz
2026-04-15 21:31:02 +08:00
parent acf8c1d62b
commit ec64300ff2
9 changed files with 2394 additions and 136 deletions

View File

@@ -0,0 +1,554 @@
package service
import (
"context"
"encoding/json"
"fmt"
"os"
"path/filepath"
"sort"
"strconv"
"strings"
"sync"
"time"
"gcy_hpc_server/internal/model"
"gcy_hpc_server/internal/store"
"go.uber.org/zap"
)
// TaskService orchestrates the task lifecycle: creation, the staged
// processing pipeline (prepare work dir -> stage input files -> submit to
// Slurm), status refresh against Slurm, and a single-worker async
// processing queue with retry and stuck-task recovery.
type TaskService struct {
	taskStore   *store.TaskStore
	appStore    *store.ApplicationStore
	fileStore   *store.FileStore    // nil ok — input-file validation is skipped when absent
	blobStore   *store.BlobStore    // nil ok
	stagingSvc  *FileStagingService // nil ok — MinIO unavailable; tasks with input files then fail at download
	jobSvc      *JobService
	workDirBase string // root directory under which per-task work dirs are created
	logger      *zap.Logger

	// async processing
	taskCh   chan int64 // buffered channel, cap=16; carries task IDs to the worker
	cancelFn context.CancelFunc
	wg       sync.WaitGroup
	mu       sync.Mutex // protects taskCh from send-on-closed; also guards started/stopped
	started  bool       // prevent double-start
	stopped  bool
}
// NewTaskService assembles a TaskService from its dependencies.
//
// fileStore, blobStore, and stagingSvc may be nil when their backing
// services (e.g. MinIO) are unavailable; the service degrades gracefully.
// The internal processing channel is created with a fixed capacity of 16.
func NewTaskService(
	taskStore *store.TaskStore,
	appStore *store.ApplicationStore,
	fileStore *store.FileStore,
	blobStore *store.BlobStore,
	stagingSvc *FileStagingService,
	jobSvc *JobService,
	workDirBase string,
	logger *zap.Logger,
) *TaskService {
	svc := TaskService{
		taskStore:   taskStore,
		appStore:    appStore,
		fileStore:   fileStore,
		blobStore:   blobStore,
		stagingSvc:  stagingSvc,
		jobSvc:      jobSvc,
		workDirBase: workDirBase,
		logger:      logger,
	}
	svc.taskCh = make(chan int64, 16)
	return &svc
}
// CreateTask validates the request, resolves the target application, and
// persists a new task record in the "submitted" state. It performs no
// processing itself; see ProcessTask / SubmitAsync for that.
func (s *TaskService) CreateTask(ctx context.Context, req *model.CreateTaskRequest) (*model.Task, error) {
	// Resolve the application the task is bound to.
	app, err := s.appStore.GetByID(ctx, req.AppID)
	if err != nil {
		return nil, fmt.Errorf("get application: %w", err)
	}
	if app == nil {
		return nil, fmt.Errorf("application %d not found", req.AppID)
	}

	// Reject oversized file lists before touching the file store.
	if n := len(req.InputFileIDs); n > 100 {
		return nil, fmt.Errorf("input file count %d exceeds limit of 100", n)
	}

	// Deduplicate (and sort) the requested file IDs.
	ids := uniqueInt64s(req.InputFileIDs)

	// Verify every referenced file exists. Skipped when the file store is
	// unavailable or no files were requested.
	if s.fileStore != nil && len(ids) > 0 {
		existing, err := s.fileStore.GetByIDs(ctx, ids)
		if err != nil {
			return nil, fmt.Errorf("validate file ids: %w", err)
		}
		known := make(map[int64]bool, len(existing))
		for i := range existing {
			known[existing[i].ID] = true
		}
		for _, id := range ids {
			if !known[id] {
				return nil, fmt.Errorf("file %d not found", id)
			}
		}
	}

	// Fall back to "<app>_<timestamp>" when no name was supplied.
	name := req.TaskName
	if name == "" {
		name = SanitizeDirName(app.Name) + "_" + time.Now().Format("20060102_150405")
	}

	// Serialize parameter values; an empty map becomes the literal "{}".
	valuesRaw := json.RawMessage(`{}`)
	if len(req.Values) > 0 {
		encoded, err := json.Marshal(req.Values)
		if err != nil {
			return nil, fmt.Errorf("marshal values: %w", err)
		}
		valuesRaw = encoded
	}

	// Serialize the file-ID list; an empty list becomes the literal "[]".
	idsRaw := json.RawMessage(`[]`)
	if len(ids) > 0 {
		encoded, err := json.Marshal(ids)
		if err != nil {
			return nil, fmt.Errorf("marshal file ids: %w", err)
		}
		idsRaw = encoded
	}

	// Persist the record and hand back the populated task.
	record := &model.Task{
		TaskName:     name,
		AppID:        app.ID,
		AppName:      app.Name,
		Status:       model.TaskStatusSubmitted,
		Values:       valuesRaw,
		InputFileIDs: idsRaw,
		SubmittedAt:  time.Now(),
	}
	newID, err := s.taskStore.Create(ctx, record)
	if err != nil {
		return nil, fmt.Errorf("create task: %w", err)
	}
	record.ID = newID
	return record, nil
}
// ProcessTask runs the full synchronous processing pipeline for a task.
//
// The pipeline is resumable: task.CurrentStep records how far a previous
// attempt got, and already-completed stages are skipped on re-entry.
// Stages: preparing (create work dir) -> downloading (stage input files)
// -> submitting (render script, submit to Slurm). Any failure marks the
// task failed and records the step that broke so a retry resumes there.
func (s *TaskService) ProcessTask(ctx context.Context, taskID int64) error {
	// 1. Fetch task
	task, err := s.taskStore.GetByID(ctx, taskID)
	if err != nil {
		return fmt.Errorf("get task: %w", err)
	}
	if task == nil {
		return fmt.Errorf("task %d not found", taskID)
	}
	// fail marks the task failed at the given step and returns msg as an
	// error. Both store writes are best-effort (errors dropped) because
	// the original failure is what the caller needs to see.
	// NOTE(review): this writes task.RetryCount as read at the top of the
	// function — presumably intentional so a failure preserves the retry
	// counter; confirm against processWithRetry.
	fail := func(step, msg string) error {
		_ = s.taskStore.UpdateStatus(ctx, taskID, model.TaskStatusFailed, msg)
		_ = s.taskStore.UpdateRetryState(ctx, taskID, model.TaskStatusFailed, step, task.RetryCount)
		return fmt.Errorf("%s", msg)
	}
	currentStep := task.CurrentStep
	var workDir string
	var app *model.Application
	if currentStep == "" || currentStep == model.TaskStepPreparing {
		// 2. Set preparing
		if err := s.taskStore.UpdateRetryState(ctx, taskID, model.TaskStatusPreparing, model.TaskStepPreparing, 0); err != nil {
			return fail(model.TaskStepPreparing, fmt.Sprintf("update status to preparing: %v", err))
		}
		// 3. Fetch app
		app, err = s.appStore.GetByID(ctx, task.AppID)
		if err != nil {
			return fail(model.TaskStepPreparing, fmt.Sprintf("get application: %v", err))
		}
		if app == nil {
			return fail(model.TaskStepPreparing, fmt.Sprintf("application %d not found", task.AppID))
		}
		// 4-5. Create work directory: <base>/<sanitized app>/<timestamp>_<rand>
		workDir = filepath.Join(s.workDirBase, SanitizeDirName(app.Name), time.Now().Format("20060102_150405")+"_"+RandomSuffix(4))
		if err := os.MkdirAll(workDir, 0777); err != nil {
			return fail(model.TaskStepPreparing, fmt.Sprintf("create work directory %s: %v", workDir, err))
		}
		// 6. CHMOD traversal — critical for multi-user HPC: open up every
		// directory from the work dir up to (and including) the base so
		// other cluster users/daemons can traverse in. Chmod errors are
		// deliberately ignored (best-effort).
		// NOTE(review): assumes workDirBase is an ancestor of workDir;
		// otherwise this loop would not terminate at the base — confirm
		// workDirBase is always a clean absolute prefix.
		for dir := workDir; dir != s.workDirBase; dir = filepath.Dir(dir) {
			os.Chmod(dir, 0777)
		}
		os.Chmod(s.workDirBase, 0777)
		// 7. Persist the work dir so later (resumed) steps can find it.
		if err := s.taskStore.UpdateWorkDir(ctx, taskID, workDir); err != nil {
			return fail(model.TaskStepPreparing, fmt.Sprintf("update work dir: %v", err))
		}
	} else {
		// Resuming past the preparing step: the app is still needed for
		// parameter parsing below; the work dir comes from the record.
		app, err = s.appStore.GetByID(ctx, task.AppID)
		if err != nil {
			return fail(currentStep, fmt.Sprintf("get application: %v", err))
		}
		if app == nil {
			return fail(currentStep, fmt.Sprintf("application %d not found", task.AppID))
		}
		workDir = task.WorkDir
	}
	if currentStep == "" || currentStep == model.TaskStepPreparing || currentStep == model.TaskStepDownloading {
		// When resuming a failed download, wipe whatever was partially
		// staged so re-downloading does not leave duplicates or partial
		// files. Glob/remove errors are intentionally ignored.
		if currentStep == model.TaskStepDownloading && workDir != "" {
			matches, _ := filepath.Glob(filepath.Join(workDir, "*"))
			for _, f := range matches {
				os.Remove(f)
			}
		}
		// 8. Set downloading
		if err := s.taskStore.UpdateRetryState(ctx, taskID, model.TaskStatusDownloading, model.TaskStepDownloading, 0); err != nil {
			return fail(model.TaskStepDownloading, fmt.Sprintf("update status to downloading: %v", err))
		}
		// 9. Parse input_file_ids (stored as a JSON array on the task row)
		var fileIDs []int64
		if len(task.InputFileIDs) > 0 {
			if err := json.Unmarshal(task.InputFileIDs, &fileIDs); err != nil {
				return fail(model.TaskStepDownloading, fmt.Sprintf("parse input file ids: %v", err))
			}
		}
		// 10-12. Download files into the work dir via the staging service.
		if len(fileIDs) > 0 {
			if s.stagingSvc == nil {
				return fail(model.TaskStepDownloading, "MinIO unavailable, cannot stage files")
			}
			if err := s.stagingSvc.DownloadFilesToDir(ctx, fileIDs, workDir); err != nil {
				return fail(model.TaskStepDownloading, fmt.Sprintf("download files: %v", err))
			}
		}
	}
	// 13-14. Set ready + submitting (single combined state write).
	if err := s.taskStore.UpdateRetryState(ctx, taskID, model.TaskStatusReady, model.TaskStepSubmitting, 0); err != nil {
		return fail(model.TaskStepSubmitting, fmt.Sprintf("update status to ready: %v", err))
	}
	// 15. Parse app parameter schema (JSON on the application row).
	var params []model.ParameterSchema
	if len(app.Parameters) > 0 {
		if err := json.Unmarshal(app.Parameters, &params); err != nil {
			return fail(model.TaskStepSubmitting, fmt.Sprintf("parse parameters: %v", err))
		}
	}
	// 16. Parse the task's user-supplied values and validate them against
	// the schema.
	values := make(map[string]string)
	if len(task.Values) > 0 {
		if err := json.Unmarshal(task.Values, &values); err != nil {
			return fail(model.TaskStepSubmitting, fmt.Sprintf("parse values: %v", err))
		}
	}
	if err := ValidateParams(params, values); err != nil {
		return fail(model.TaskStepSubmitting, err.Error())
	}
	// 17. Render the job script from the template + values.
	rendered := RenderScript(app.ScriptTemplate, params, values)
	// 18. Submit to Slurm
	jobResp, err := s.jobSvc.SubmitJob(ctx, &model.SubmitJobRequest{
		Script:  rendered,
		WorkDir: workDir,
	})
	if err != nil {
		return fail(model.TaskStepSubmitting, fmt.Sprintf("submit job: %v", err))
	}
	// 19. Record the Slurm job ID, then move the task to queued.
	if err := s.taskStore.UpdateSlurmJobID(ctx, taskID, &jobResp.JobID); err != nil {
		return fail(model.TaskStepSubmitting, fmt.Sprintf("update slurm job id: %v", err))
	}
	if err := s.taskStore.UpdateStatus(ctx, taskID, model.TaskStatusQueued, ""); err != nil {
		return fail(model.TaskStepSubmitting, fmt.Sprintf("update status to queued: %v", err))
	}
	return nil
}
// ListTasks returns one page of tasks matching query together with the
// total match count, delegating straight to the task store.
func (s *TaskService) ListTasks(ctx context.Context, query *model.TaskListQuery) ([]model.Task, int64, error) {
	tasks, total, err := s.taskStore.List(ctx, query)
	return tasks, total, err
}
// ProcessTaskSync creates a task and runs the full processing pipeline in
// the caller's goroutine, returning the resulting Slurm job reference for
// old-API compatibility.
func (s *TaskService) ProcessTaskSync(ctx context.Context, req *model.CreateTaskRequest) (*model.JobResponse, error) {
	// Create the record first, then process it synchronously.
	created, err := s.CreateTask(ctx, req)
	if err != nil {
		return nil, err
	}
	if err = s.ProcessTask(ctx, created.ID); err != nil {
		return nil, err
	}
	// Reload the row so we observe the slurm_job_id written by ProcessTask.
	fresh, err := s.taskStore.GetByID(ctx, created.ID)
	if err != nil {
		return nil, fmt.Errorf("re-fetch task: %w", err)
	}
	if fresh == nil || fresh.SlurmJobID == nil {
		return nil, fmt.Errorf("task has no slurm job id after processing")
	}
	return &model.JobResponse{JobID: *fresh.SlurmJobID}, nil
}
// uniqueInt64s returns the distinct values of ids in ascending order.
// A nil or empty input yields nil.
func uniqueInt64s(ids []int64) []int64 {
	if len(ids) == 0 {
		return nil
	}
	// Work on a copy so the caller's slice is untouched, then sort and
	// compact adjacent duplicates in a single pass.
	work := make([]int64, len(ids))
	copy(work, ids)
	sort.Slice(work, func(a, b int) bool { return work[a] < work[b] })
	out := work[:1]
	for _, v := range work[1:] {
		if v != out[len(out)-1] {
			out = append(out, v)
		}
	}
	return out
}
// mapSlurmStateToTaskStatus translates the first reported Slurm job state
// into our internal task status. An absent or unrecognized state is treated
// as still running (case-insensitive match).
func (s *TaskService) mapSlurmStateToTaskStatus(slurmState []string) string {
	if len(slurmState) == 0 {
		return model.TaskStatusRunning
	}
	switch strings.ToUpper(slurmState[0]) {
	case "PENDING":
		return model.TaskStatusQueued
	case "COMPLETED":
		return model.TaskStatusCompleted
	case "FAILED", "CANCELLED", "TIMEOUT", "NODE_FAIL", "OUT_OF_MEMORY", "PREEMPTED":
		return model.TaskStatusFailed
	case "RUNNING", "CONFIGURING", "COMPLETING", "SPECIAL_EXIT":
		return model.TaskStatusRunning
	default:
		// Unknown states (e.g. newly added Slurm states) map to running
		// so we keep polling rather than mislabeling the task terminal.
		return model.TaskStatusRunning
	}
}
// refreshTaskStatus re-queries Slurm for a single task and persists the
// mapped status when it differs from what is stored. Slurm query failures
// are logged but swallowed (nil return) so one flaky lookup does not abort
// a sweep; only task-store fetch failures propagate.
func (s *TaskService) refreshTaskStatus(ctx context.Context, taskID int64) error {
	task, err := s.taskStore.GetByID(ctx, taskID)
	if err != nil {
		s.logger.Error("failed to fetch task for refresh",
			zap.Int64("task_id", taskID),
			zap.Error(err),
		)
		return err
	}
	// Nothing to refresh until the task has been handed off to Slurm.
	if task == nil || task.SlurmJobID == nil {
		return nil
	}
	jobID := strconv.FormatInt(int64(*task.SlurmJobID), 10)
	jobResp, err := s.jobSvc.GetJob(ctx, jobID)
	if err != nil {
		s.logger.Warn("failed to query slurm job status during refresh",
			zap.Int64("task_id", taskID),
			zap.Int32("slurm_job_id", *task.SlurmJobID),
			zap.Error(err),
		)
		return nil
	}
	if jobResp == nil {
		return nil
	}
	mapped := s.mapSlurmStateToTaskStatus(jobResp.State)
	if mapped == task.Status {
		return nil
	}
	s.logger.Info("updating task status from slurm",
		zap.Int64("task_id", taskID),
		zap.String("old_status", task.Status),
		zap.String("new_status", mapped),
	)
	return s.taskStore.UpdateStatus(ctx, taskID, mapped, "")
}
// RefreshStaleTasks sweeps queued and running tasks whose rows have not
// been updated for 30 seconds and re-syncs each against Slurm. Listing is
// capped at the first 1000 tasks per status. Always returns nil; per-task
// and per-status failures are only logged.
func (s *TaskService) RefreshStaleTasks(ctx context.Context) error {
	staleThreshold := 30 * time.Second
	for _, status := range []string{model.TaskStatusQueued, model.TaskStatusRunning} {
		page, _, err := s.taskStore.List(ctx, &model.TaskListQuery{
			Status:   status,
			Page:     1,
			PageSize: 1000,
		})
		if err != nil {
			s.logger.Warn("failed to list tasks for stale refresh",
				zap.String("status", status),
				zap.Error(err),
			)
			continue
		}
		// Compute the cutoff after the (possibly slow) list call so the
		// staleness window is measured from now.
		cutoff := time.Now().Add(-staleThreshold)
		for i := range page {
			if !page[i].UpdatedAt.Before(cutoff) {
				continue
			}
			if err := s.refreshTaskStatus(ctx, page[i].ID); err != nil {
				s.logger.Warn("failed to refresh stale task",
					zap.Int64("task_id", page[i].ID),
					zap.Error(err),
				)
			}
		}
	}
	return nil
}
// StartProcessor launches the single background goroutine that consumes
// task IDs from taskCh and processes them one at a time. Safe to call more
// than once: only the first call starts anything. The worker exits when
// ctx is canceled, StopProcessor cancels/closes, or taskCh is closed.
// Before returning, it synchronously requeues tasks left in-flight by a
// previous run.
func (s *TaskService) StartProcessor(ctx context.Context) {
	s.mu.Lock()
	if s.started {
		s.mu.Unlock()
		return
	}
	s.started = true
	s.mu.Unlock()
	// NOTE(review): cancelFn is written outside the mutex; fine only if
	// StartProcessor and StopProcessor are never called concurrently —
	// confirm the caller's lifecycle guarantees this.
	ctx, s.cancelFn = context.WithCancel(ctx)
	s.wg.Add(1)
	go func() {
		defer s.wg.Done()
		// A panic during task processing kills the worker; log it so the
		// outage is visible. NOTE(review): the worker is not restarted
		// after a panic, so later submissions would queue unprocessed.
		defer func() {
			if r := recover(); r != nil {
				s.logger.Error("processor panic", zap.Any("panic", r))
			}
		}()
		for {
			select {
			case <-ctx.Done():
				return
			case taskID, ok := <-s.taskCh:
				if !ok {
					// Channel closed by StopProcessor.
					return
				}
				// Bound each task's processing to 10 minutes.
				taskCtx, cancel := context.WithTimeout(ctx, 10*time.Minute)
				s.processWithRetry(taskCtx, taskID)
				cancel()
			}
		}
	}()
	s.RecoverStuckTasks(ctx)
}
// SubmitAsync creates the task record and enqueues it for background
// processing, returning the new task ID immediately.
//
// NOTE(review): when taskCh is full the enqueue is dropped with only a
// warning, yet the ID is still returned as success — the task then stays
// in "submitted" until RecoverStuckTasks requeues it (after its stuck
// threshold). Confirm this best-effort behavior is intended.
func (s *TaskService) SubmitAsync(ctx context.Context, req *model.CreateTaskRequest) (int64, error) {
	task, err := s.CreateTask(ctx, req)
	if err != nil {
		return 0, err
	}
	// Hold mu across the stopped check and the send so StopProcessor
	// cannot close taskCh in between (a send on a closed channel panics).
	s.mu.Lock()
	if s.stopped {
		s.mu.Unlock()
		return 0, fmt.Errorf("processor stopped, cannot submit task")
	}
	select {
	case s.taskCh <- task.ID:
	default:
		s.logger.Warn("task channel full, submit dropped", zap.Int64("taskID", task.ID))
	}
	s.mu.Unlock()
	return task.ID, nil
}
// StopProcessor shuts down the background worker: it closes taskCh (so the
// worker drains to the close and exits), cancels the worker context, waits
// for the goroutine, then resets any still-buffered task IDs back to
// "submitted" so a later run can pick them up. Idempotent.
func (s *TaskService) StopProcessor() {
	s.mu.Lock()
	if s.stopped {
		s.mu.Unlock()
		return
	}
	s.stopped = true
	// Closing while holding mu pairs with the mu-guarded sends in
	// SubmitAsync/processWithRetry/RecoverStuckTasks, preventing a send
	// on a closed channel.
	close(s.taskCh)
	s.mu.Unlock()
	if s.cancelFn != nil {
		s.cancelFn()
	}
	s.wg.Wait()
	// Swap in a fresh channel before draining: the old one is closed, so
	// ranging over it yields any IDs left in the buffer and terminates.
	s.mu.Lock()
	drainCh := s.taskCh
	s.taskCh = make(chan int64, 16)
	s.mu.Unlock()
	for taskID := range drainCh {
		// Best-effort reset (fresh context — the worker ctx is canceled)
		// so these tasks are eligible for recovery on the next start.
		_ = s.taskStore.UpdateStatus(context.Background(), taskID, model.TaskStatusSubmitted, "")
	}
}
// processWithRetry runs ProcessTask for taskID and, on failure, requeues
// the task for another attempt (up to 3 retries). The task's status is
// reset to "submitted" while its current_step is preserved so the next
// attempt resumes where this one broke.
//
// Retry bookkeeping runs on a detached context when ctx is already dead:
// previously, a ProcessTask failure caused by ctx's own timeout (the
// 10-minute per-task deadline) made the follow-up GetByID/UpdateRetryState
// fail on the same expired context, so timed-out tasks were never requeued.
func (s *TaskService) processWithRetry(ctx context.Context, taskID int64) {
	err := s.ProcessTask(ctx, taskID)
	if err == nil {
		return
	}
	// If ctx itself expired or was canceled, switch to a fresh short-lived
	// context so the retry bookkeeping below can still reach the store.
	bkCtx := ctx
	if ctx.Err() != nil {
		var cancel context.CancelFunc
		bkCtx, cancel = context.WithTimeout(context.Background(), 30*time.Second)
		defer cancel()
	}
	task, fetchErr := s.taskStore.GetByID(bkCtx, taskID)
	if fetchErr != nil || task == nil {
		return
	}
	// Give up after 3 retries; ProcessTask already marked the task failed.
	if task.RetryCount >= 3 {
		return
	}
	// Best-effort: record the incremented retry count and requeue.
	_ = s.taskStore.UpdateRetryState(bkCtx, taskID, model.TaskStatusSubmitted, task.CurrentStep, task.RetryCount+1)
	// mu guards against StopProcessor closing taskCh mid-send.
	s.mu.Lock()
	if !s.stopped {
		select {
		case s.taskCh <- taskID:
		default:
			s.logger.Warn("task channel full, retry dropped", zap.Int64("taskID", taskID))
		}
	}
	s.mu.Unlock()
}
// RecoverStuckTasks requeues tasks that have sat in an in-flight state for
// more than five minutes (e.g. after a crash or restart). Each is reset to
// "submitted" (best-effort) and pushed back onto the processing channel;
// the push is dropped with a warning when the channel is full or the
// processor has been stopped.
func (s *TaskService) RecoverStuckTasks(ctx context.Context) {
	stuck, err := s.taskStore.GetStuckTasks(ctx, 5*time.Minute)
	if err != nil {
		s.logger.Error("failed to get stuck tasks", zap.Error(err))
		return
	}
	for i := range stuck {
		id := stuck[i].ID
		_ = s.taskStore.UpdateStatus(ctx, id, model.TaskStatusSubmitted, "")
		// mu guards against StopProcessor closing taskCh mid-send.
		s.mu.Lock()
		if !s.stopped {
			select {
			case s.taskCh <- id:
			default:
				s.logger.Warn("task channel full, stuck task recovery dropped", zap.Int64("taskID", id))
			}
		}
		s.mu.Unlock()
	}
}