RecoverStuckTasks now skips tasks that already have a slurm_job_id, and ProcessTask adds a guard before the submitting step to prevent re-submission even if a task is incorrectly re-enqueued. Also deprecates POST /api/v1/jobs/submit endpoint (replaced by POST /tasks) and comments out related handlers and tests.
724 lines
20 KiB
Go
724 lines
20 KiB
Go
package service
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"fmt"
|
|
"os"
|
|
"path/filepath"
|
|
"sort"
|
|
"strconv"
|
|
"strings"
|
|
"sync"
|
|
"time"
|
|
|
|
"gcy_hpc_server/internal/model"
|
|
"gcy_hpc_server/internal/store"
|
|
|
|
"go.uber.org/zap"
|
|
)
|
|
|
|
type TaskService struct {
|
|
taskStore *store.TaskStore
|
|
appStore *store.ApplicationStore
|
|
fileStore *store.FileStore // nil ok
|
|
blobStore *store.BlobStore // nil ok
|
|
stagingSvc *FileStagingService // nil ok — MinIO unavailable
|
|
jobSvc *JobService
|
|
workDirBase string
|
|
logger *zap.Logger
|
|
|
|
// async processing
|
|
taskCh chan int64 // buffered channel, cap=16
|
|
cancelFn context.CancelFunc
|
|
wg sync.WaitGroup
|
|
mu sync.Mutex // protects taskCh from send-on-closed
|
|
started bool // prevent double-start
|
|
stopped bool
|
|
}
|
|
|
|
func NewTaskService(
|
|
taskStore *store.TaskStore,
|
|
appStore *store.ApplicationStore,
|
|
fileStore *store.FileStore,
|
|
blobStore *store.BlobStore,
|
|
stagingSvc *FileStagingService,
|
|
jobSvc *JobService,
|
|
workDirBase string,
|
|
logger *zap.Logger,
|
|
) *TaskService {
|
|
return &TaskService{
|
|
taskStore: taskStore,
|
|
appStore: appStore,
|
|
fileStore: fileStore,
|
|
blobStore: blobStore,
|
|
stagingSvc: stagingSvc,
|
|
jobSvc: jobSvc,
|
|
workDirBase: workDirBase,
|
|
logger: logger,
|
|
taskCh: make(chan int64, 16),
|
|
}
|
|
}
|
|
|
|
func (s *TaskService) CreateTask(ctx context.Context, req *model.CreateTaskRequest) (*model.Task, error) {
|
|
app, err := s.appStore.GetByID(ctx, req.AppID)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("get application: %w", err)
|
|
}
|
|
if app == nil {
|
|
return nil, fmt.Errorf("application %d not found", req.AppID)
|
|
}
|
|
|
|
// 2. Validate file limit
|
|
if len(req.InputFileIDs) > 100 {
|
|
return nil, fmt.Errorf("input file count %d exceeds limit of 100", len(req.InputFileIDs))
|
|
}
|
|
|
|
// 3. Deduplicate file IDs
|
|
fileIDs := uniqueInt64s(req.InputFileIDs)
|
|
|
|
// 4. Validate file IDs exist
|
|
if s.fileStore != nil && len(fileIDs) > 0 {
|
|
files, err := s.fileStore.GetByIDs(ctx, fileIDs)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("validate file ids: %w", err)
|
|
}
|
|
found := make(map[int64]bool, len(files))
|
|
for _, f := range files {
|
|
found[f.ID] = true
|
|
}
|
|
for _, id := range fileIDs {
|
|
if !found[id] {
|
|
return nil, fmt.Errorf("file %d not found", id)
|
|
}
|
|
}
|
|
}
|
|
|
|
// 5. Auto-generate task name if empty
|
|
taskName := req.TaskName
|
|
if taskName == "" {
|
|
taskName = SanitizeDirName(app.Name) + "_" + time.Now().Format("20060102_150405")
|
|
}
|
|
|
|
// 6. Marshal values
|
|
valuesJSON := json.RawMessage(`{}`)
|
|
if len(req.Values) > 0 {
|
|
b, err := json.Marshal(req.Values)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("marshal values: %w", err)
|
|
}
|
|
valuesJSON = b
|
|
}
|
|
|
|
// 7. Marshal input_file_ids
|
|
fileIDsJSON := json.RawMessage(`[]`)
|
|
if len(fileIDs) > 0 {
|
|
b, err := json.Marshal(fileIDs)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("marshal file ids: %w", err)
|
|
}
|
|
fileIDsJSON = b
|
|
}
|
|
|
|
// 8. Create task record
|
|
task := &model.Task{
|
|
TaskName: taskName,
|
|
AppID: app.ID,
|
|
AppName: app.Name,
|
|
Status: model.TaskStatusSubmitted,
|
|
Values: valuesJSON,
|
|
InputFileIDs: fileIDsJSON,
|
|
SubmittedAt: time.Now(),
|
|
Partition: derefStr(req.Partition),
|
|
Cpus: req.Cpus,
|
|
MemoryPerNode: req.MemoryPerNode,
|
|
MemoryPerCpu: req.MemoryPerCpu,
|
|
TimeLimit: req.TimeLimit,
|
|
QOS: req.QOS,
|
|
JobName: req.JobName,
|
|
Nodes: req.Nodes,
|
|
Tasks: req.Tasks,
|
|
CpusPerTask: req.CpusPerTask,
|
|
Constraints: req.Constraints,
|
|
Reservation: req.Reservation,
|
|
Account: req.Account,
|
|
Nice: req.Nice,
|
|
MailType: req.MailType,
|
|
MailUser: req.MailUser,
|
|
StandardOutput: req.StandardOutput,
|
|
StandardError: req.StandardError,
|
|
StandardInput: req.StandardInput,
|
|
RequiredNodes: req.RequiredNodes,
|
|
ExcludedNodes: req.ExcludedNodes,
|
|
BeginTime: req.BeginTime,
|
|
Deadline: req.Deadline,
|
|
Array: req.Array,
|
|
Dependency: req.Dependency,
|
|
Requeue: req.Requeue,
|
|
KillOnNodeFail: req.KillOnNodeFail,
|
|
}
|
|
|
|
taskID, err := s.taskStore.Create(ctx, task)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("create task: %w", err)
|
|
}
|
|
task.ID = taskID
|
|
|
|
return task, nil
|
|
}
|
|
|
|
// ProcessTask runs the full synchronous processing pipeline for a task.
|
|
func (s *TaskService) ProcessTask(ctx context.Context, taskID int64) error {
|
|
// 1. Fetch task
|
|
task, err := s.taskStore.GetByID(ctx, taskID)
|
|
if err != nil {
|
|
return fmt.Errorf("get task: %w", err)
|
|
}
|
|
if task == nil {
|
|
return fmt.Errorf("task %d not found", taskID)
|
|
}
|
|
|
|
fail := func(step, msg string) error {
|
|
_ = s.taskStore.UpdateStatus(ctx, taskID, model.TaskStatusFailed, msg)
|
|
_ = s.taskStore.UpdateRetryState(ctx, taskID, model.TaskStatusFailed, step, task.RetryCount)
|
|
return fmt.Errorf("%s", msg)
|
|
}
|
|
|
|
currentStep := task.CurrentStep
|
|
|
|
var workDir string
|
|
var app *model.Application
|
|
|
|
if currentStep == "" || currentStep == model.TaskStepPreparing {
|
|
// 2. Set preparing
|
|
if err := s.taskStore.UpdateRetryState(ctx, taskID, model.TaskStatusPreparing, model.TaskStepPreparing, 0); err != nil {
|
|
return fail(model.TaskStepPreparing, fmt.Sprintf("update status to preparing: %v", err))
|
|
}
|
|
|
|
// 3. Fetch app
|
|
app, err = s.appStore.GetByID(ctx, task.AppID)
|
|
if err != nil {
|
|
return fail(model.TaskStepPreparing, fmt.Sprintf("get application: %v", err))
|
|
}
|
|
if app == nil {
|
|
return fail(model.TaskStepPreparing, fmt.Sprintf("application %d not found", task.AppID))
|
|
}
|
|
|
|
// 4-5. Create work directory
|
|
workDir = filepath.Join(s.workDirBase, SanitizeDirName(app.Name), time.Now().Format("20060102_150405")+"_"+RandomSuffix(4))
|
|
if err := os.MkdirAll(workDir, 0777); err != nil {
|
|
return fail(model.TaskStepPreparing, fmt.Sprintf("create work directory %s: %v", workDir, err))
|
|
}
|
|
|
|
// 6. CHMOD traversal — critical for multi-user HPC
|
|
for dir := workDir; dir != s.workDirBase; dir = filepath.Dir(dir) {
|
|
os.Chmod(dir, 0777)
|
|
}
|
|
os.Chmod(s.workDirBase, 0777)
|
|
|
|
// 7. UpdateWorkDir
|
|
if err := s.taskStore.UpdateWorkDir(ctx, taskID, workDir); err != nil {
|
|
return fail(model.TaskStepPreparing, fmt.Sprintf("update work dir: %v", err))
|
|
}
|
|
} else {
|
|
app, err = s.appStore.GetByID(ctx, task.AppID)
|
|
if err != nil {
|
|
return fail(currentStep, fmt.Sprintf("get application: %v", err))
|
|
}
|
|
if app == nil {
|
|
return fail(currentStep, fmt.Sprintf("application %d not found", task.AppID))
|
|
}
|
|
workDir = task.WorkDir
|
|
}
|
|
|
|
if currentStep == "" || currentStep == model.TaskStepPreparing || currentStep == model.TaskStepDownloading {
|
|
if currentStep == model.TaskStepDownloading && workDir != "" {
|
|
matches, _ := filepath.Glob(filepath.Join(workDir, "*"))
|
|
for _, f := range matches {
|
|
os.Remove(f)
|
|
}
|
|
}
|
|
|
|
// 8. Set downloading
|
|
if err := s.taskStore.UpdateRetryState(ctx, taskID, model.TaskStatusDownloading, model.TaskStepDownloading, 0); err != nil {
|
|
return fail(model.TaskStepDownloading, fmt.Sprintf("update status to downloading: %v", err))
|
|
}
|
|
|
|
// 9. Parse input_file_ids
|
|
var fileIDs []int64
|
|
if len(task.InputFileIDs) > 0 {
|
|
if err := json.Unmarshal(task.InputFileIDs, &fileIDs); err != nil {
|
|
return fail(model.TaskStepDownloading, fmt.Sprintf("parse input file ids: %v", err))
|
|
}
|
|
}
|
|
|
|
// 10-12. Download files
|
|
if len(fileIDs) > 0 {
|
|
if s.stagingSvc == nil {
|
|
return fail(model.TaskStepDownloading, "MinIO unavailable, cannot stage files")
|
|
}
|
|
if err := s.stagingSvc.DownloadFilesToDir(ctx, fileIDs, workDir); err != nil {
|
|
return fail(model.TaskStepDownloading, fmt.Sprintf("download files: %v", err))
|
|
}
|
|
}
|
|
}
|
|
|
|
// 13-14. Set ready + submitting (guard: skip if already submitted to Slurm)
|
|
if task.SlurmJobID != nil {
|
|
s.logger.Info("task already has slurm job, skipping submission",
|
|
zap.Int64("task_id", taskID),
|
|
zap.Int32("slurm_job_id", *task.SlurmJobID),
|
|
)
|
|
return nil
|
|
}
|
|
|
|
if err := s.taskStore.UpdateRetryState(ctx, taskID, model.TaskStatusReady, model.TaskStepSubmitting, 0); err != nil {
|
|
return fail(model.TaskStepSubmitting, fmt.Sprintf("update status to ready: %v", err))
|
|
}
|
|
|
|
// 15. Parse app parameters
|
|
var params []model.ParameterSchema
|
|
if len(app.Parameters) > 0 {
|
|
if err := json.Unmarshal(app.Parameters, ¶ms); err != nil {
|
|
return fail(model.TaskStepSubmitting, fmt.Sprintf("parse parameters: %v", err))
|
|
}
|
|
}
|
|
|
|
// 15a. Parse app environment
|
|
var appEnv map[string]string
|
|
if len(app.Environment) > 0 {
|
|
if err := json.Unmarshal(app.Environment, &appEnv); err != nil {
|
|
return fail(model.TaskStepSubmitting, fmt.Sprintf("parse application environment: %v", err))
|
|
}
|
|
}
|
|
|
|
// 16. Parse task values
|
|
values := make(map[string]string)
|
|
if len(task.Values) > 0 {
|
|
if err := json.Unmarshal(task.Values, &values); err != nil {
|
|
return fail(model.TaskStepSubmitting, fmt.Sprintf("parse values: %v", err))
|
|
}
|
|
}
|
|
|
|
// 16a. Auto-inject WORK_DIR if the app defines it as a parameter.
|
|
// The work directory is created by the server, not provided by the user.
|
|
for _, p := range params {
|
|
if p.Name == "WORK_DIR" {
|
|
values["WORK_DIR"] = workDir
|
|
break
|
|
}
|
|
}
|
|
|
|
// 16b. Map input_file_ids to file-type parameters by order.
|
|
// User selects files via FilePicker; we assign their IDs to file/directory
|
|
// params sequentially so the backend can resolve them to filenames later.
|
|
var inputFileIDs []int64
|
|
if len(task.InputFileIDs) > 0 {
|
|
if err := json.Unmarshal(task.InputFileIDs, &inputFileIDs); err != nil {
|
|
return fail(model.TaskStepSubmitting, fmt.Sprintf("parse input file ids: %v", err))
|
|
}
|
|
}
|
|
if len(inputFileIDs) > 0 {
|
|
fileParamIdx := 0
|
|
for _, p := range params {
|
|
if p.Type != model.ParamTypeFile && p.Type != model.ParamTypeDirectory {
|
|
continue
|
|
}
|
|
if fileParamIdx < len(inputFileIDs) {
|
|
values[p.Name] = strconv.FormatInt(inputFileIDs[fileParamIdx], 10)
|
|
fileParamIdx++
|
|
}
|
|
}
|
|
}
|
|
|
|
// 16c. Validate all params (WORK_DIR and file params now have values).
|
|
if err := ValidateParams(params, values); err != nil {
|
|
return fail(model.TaskStepSubmitting, err.Error())
|
|
}
|
|
|
|
// 16d. Resolve file-type parameter values: file_id → filename.
|
|
var fileLookupIDs []int64
|
|
for _, p := range params {
|
|
if p.Type != model.ParamTypeFile && p.Type != model.ParamTypeDirectory {
|
|
continue
|
|
}
|
|
val, ok := values[p.Name]
|
|
if !ok || val == "" {
|
|
continue
|
|
}
|
|
fileID, err := strconv.ParseInt(val, 10, 64)
|
|
if err != nil {
|
|
return fail(model.TaskStepSubmitting, fmt.Sprintf("parameter %q: invalid file_id %q, expected numeric file ID", p.Name, val))
|
|
}
|
|
fileLookupIDs = append(fileLookupIDs, fileID)
|
|
}
|
|
|
|
if len(fileLookupIDs) > 0 && s.fileStore != nil {
|
|
fetchedFiles, err := s.fileStore.GetByIDs(ctx, fileLookupIDs)
|
|
if err != nil {
|
|
return fail(model.TaskStepSubmitting, fmt.Sprintf("fetch file names for parameter resolution: %v", err))
|
|
}
|
|
fileMap := make(map[int64]string, len(fetchedFiles))
|
|
for _, f := range fetchedFiles {
|
|
fileMap[f.ID] = f.Name
|
|
}
|
|
for _, p := range params {
|
|
if p.Type != model.ParamTypeFile && p.Type != model.ParamTypeDirectory {
|
|
continue
|
|
}
|
|
val, ok := values[p.Name]
|
|
if !ok || val == "" {
|
|
continue
|
|
}
|
|
fileID, _ := strconv.ParseInt(val, 10, 64)
|
|
filename, found := fileMap[fileID]
|
|
if !found {
|
|
return fail(model.TaskStepSubmitting, fmt.Sprintf("parameter %q: file_id %d not found", p.Name, fileID))
|
|
}
|
|
values[p.Name] = filename
|
|
}
|
|
}
|
|
|
|
// 注入默认调度参数(仅在内存中,不持久化到数据库)
|
|
if task.TimeLimit == nil {
|
|
task.TimeLimit = int32Ptr(10080) // 168 小时
|
|
}
|
|
if task.StandardOutput == nil {
|
|
task.StandardOutput = strToPtrOrNil(filepath.Join(workDir, "slurm-%j.out"))
|
|
}
|
|
if task.StandardError == nil {
|
|
task.StandardError = strToPtrOrNil(filepath.Join(workDir, "slurm-%j.err"))
|
|
}
|
|
|
|
// 17. Render script
|
|
rendered := RenderScript(app.ScriptTemplate, params, values)
|
|
s.logger.Info("rendered script",
|
|
zap.Int64("task_id", taskID),
|
|
zap.String("work_dir", workDir),
|
|
zap.String("script", rendered),
|
|
)
|
|
|
|
// 18. Submit to Slurm
|
|
jobResp, err := s.jobSvc.SubmitJob(ctx, &model.SubmitJobRequest{
|
|
Script: rendered,
|
|
WorkDir: workDir,
|
|
Partition: task.Partition,
|
|
CPUs: derefInt32(task.Cpus),
|
|
TimeLimit: derefInt32ToStr(task.TimeLimit),
|
|
QOS: derefStr(task.QOS),
|
|
JobName: derefStr(task.JobName),
|
|
MemoryPerNode: task.MemoryPerNode,
|
|
MemoryPerCpu: task.MemoryPerCpu,
|
|
Nodes: task.Nodes,
|
|
Tasks: task.Tasks,
|
|
CpusPerTask: task.CpusPerTask,
|
|
Constraints: task.Constraints,
|
|
Reservation: task.Reservation,
|
|
Account: task.Account,
|
|
Nice: task.Nice,
|
|
MailType: task.MailType,
|
|
MailUser: task.MailUser,
|
|
StandardOutput: task.StandardOutput,
|
|
StandardError: task.StandardError,
|
|
StandardInput: task.StandardInput,
|
|
RequiredNodes: task.RequiredNodes,
|
|
ExcludedNodes: task.ExcludedNodes,
|
|
BeginTime: task.BeginTime,
|
|
Deadline: task.Deadline,
|
|
Array: task.Array,
|
|
Dependency: task.Dependency,
|
|
Requeue: task.Requeue,
|
|
KillOnNodeFail: task.KillOnNodeFail,
|
|
Environment: appEnv,
|
|
})
|
|
if err != nil {
|
|
return fail(model.TaskStepSubmitting, fmt.Sprintf("submit job: %v", err))
|
|
}
|
|
|
|
// 19. Update slurm_job_id and status to queued
|
|
if err := s.taskStore.UpdateSlurmJobID(ctx, taskID, &jobResp.JobID); err != nil {
|
|
return fail(model.TaskStepSubmitting, fmt.Sprintf("update slurm job id: %v", err))
|
|
}
|
|
if err := s.taskStore.UpdateStatus(ctx, taskID, model.TaskStatusQueued, ""); err != nil {
|
|
return fail(model.TaskStepSubmitting, fmt.Sprintf("update status to queued: %v", err))
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// ListTasks returns a paginated list of tasks.
|
|
func (s *TaskService) ListTasks(ctx context.Context, query *model.TaskListQuery) ([]model.Task, int64, error) {
|
|
return s.taskStore.List(ctx, query)
|
|
}
|
|
|
|
// ProcessTaskSync creates and processes a task synchronously, returning a JobResponse
|
|
// for old API compatibility.
|
|
func (s *TaskService) ProcessTaskSync(ctx context.Context, req *model.CreateTaskRequest) (*model.JobResponse, error) {
|
|
// 1. Create task
|
|
task, err := s.CreateTask(ctx, req)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// 2. Process synchronously
|
|
if err := s.ProcessTask(ctx, task.ID); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// 3. Re-fetch to get updated slurm_job_id
|
|
task, err = s.taskStore.GetByID(ctx, task.ID)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("re-fetch task: %w", err)
|
|
}
|
|
if task == nil || task.SlurmJobID == nil {
|
|
return nil, fmt.Errorf("task has no slurm job id after processing")
|
|
}
|
|
|
|
// 4. Return JobResponse
|
|
return &model.JobResponse{JobID: *task.SlurmJobID}, nil
|
|
}
|
|
|
|
// uniqueInt64s returns the distinct values of ids in ascending order.
// A nil or empty input yields nil. The input slice is never modified.
func uniqueInt64s(ids []int64) []int64 {
	if len(ids) == 0 {
		return nil
	}

	// Work on a private copy: sort it, then squeeze out adjacent duplicates
	// in place.
	sorted := make([]int64, len(ids))
	copy(sorted, ids)
	sort.Slice(sorted, func(a, b int) bool { return sorted[a] < sorted[b] })

	kept := 1 // sorted[:kept] holds the distinct values seen so far
	for _, v := range sorted[1:] {
		if v != sorted[kept-1] {
			sorted[kept] = v
			kept++
		}
	}
	return sorted[:kept]
}
|
|
|
|
func (s *TaskService) mapSlurmStateToTaskStatus(slurmState []string) string {
|
|
if len(slurmState) == 0 {
|
|
return model.TaskStatusRunning
|
|
}
|
|
|
|
state := strings.ToUpper(slurmState[0])
|
|
switch state {
|
|
case "PENDING":
|
|
return model.TaskStatusQueued
|
|
case "RUNNING", "CONFIGURING", "COMPLETING", "SPECIAL_EXIT":
|
|
return model.TaskStatusRunning
|
|
case "COMPLETED":
|
|
return model.TaskStatusCompleted
|
|
case "FAILED", "CANCELLED", "TIMEOUT", "NODE_FAIL", "OUT_OF_MEMORY", "PREEMPTED":
|
|
return model.TaskStatusFailed
|
|
default:
|
|
return model.TaskStatusRunning
|
|
}
|
|
}
|
|
|
|
func (s *TaskService) refreshTaskStatus(ctx context.Context, taskID int64) error {
|
|
task, err := s.taskStore.GetByID(ctx, taskID)
|
|
if err != nil {
|
|
s.logger.Error("failed to fetch task for refresh",
|
|
zap.Int64("task_id", taskID),
|
|
zap.Error(err),
|
|
)
|
|
return err
|
|
}
|
|
if task == nil || task.SlurmJobID == nil {
|
|
return nil
|
|
}
|
|
|
|
jobResp, err := s.jobSvc.GetJob(ctx, strconv.FormatInt(int64(*task.SlurmJobID), 10))
|
|
if err != nil {
|
|
s.logger.Warn("failed to query slurm job status during refresh",
|
|
zap.Int64("task_id", taskID),
|
|
zap.Int32("slurm_job_id", *task.SlurmJobID),
|
|
zap.Error(err),
|
|
)
|
|
return nil
|
|
}
|
|
if jobResp == nil {
|
|
return nil
|
|
}
|
|
|
|
newStatus := s.mapSlurmStateToTaskStatus(jobResp.State)
|
|
if newStatus != task.Status {
|
|
s.logger.Info("updating task status from slurm",
|
|
zap.Int64("task_id", taskID),
|
|
zap.String("old_status", task.Status),
|
|
zap.String("new_status", newStatus),
|
|
)
|
|
return s.taskStore.UpdateStatus(ctx, taskID, newStatus, "")
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (s *TaskService) RefreshStaleTasks(ctx context.Context) error {
|
|
staleThreshold := 30 * time.Second
|
|
nonTerminal := []string{model.TaskStatusQueued, model.TaskStatusRunning}
|
|
|
|
for _, status := range nonTerminal {
|
|
tasks, _, err := s.taskStore.List(ctx, &model.TaskListQuery{
|
|
Status: status,
|
|
Page: 1,
|
|
PageSize: 1000,
|
|
})
|
|
if err != nil {
|
|
s.logger.Warn("failed to list tasks for stale refresh",
|
|
zap.String("status", status),
|
|
zap.Error(err),
|
|
)
|
|
continue
|
|
}
|
|
|
|
cutoff := time.Now().Add(-staleThreshold)
|
|
for i := range tasks {
|
|
if tasks[i].UpdatedAt.Before(cutoff) {
|
|
if err := s.refreshTaskStatus(ctx, tasks[i].ID); err != nil {
|
|
s.logger.Warn("failed to refresh stale task",
|
|
zap.Int64("task_id", tasks[i].ID),
|
|
zap.Error(err),
|
|
)
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (s *TaskService) StartProcessor(ctx context.Context) {
|
|
s.mu.Lock()
|
|
if s.started {
|
|
s.mu.Unlock()
|
|
return
|
|
}
|
|
s.started = true
|
|
s.mu.Unlock()
|
|
|
|
ctx, s.cancelFn = context.WithCancel(ctx)
|
|
|
|
s.wg.Add(1)
|
|
go func() {
|
|
defer s.wg.Done()
|
|
defer func() {
|
|
if r := recover(); r != nil {
|
|
s.logger.Error("processor panic", zap.Any("panic", r))
|
|
}
|
|
}()
|
|
for {
|
|
select {
|
|
case <-ctx.Done():
|
|
return
|
|
case taskID, ok := <-s.taskCh:
|
|
if !ok {
|
|
return
|
|
}
|
|
taskCtx, cancel := context.WithTimeout(ctx, 10*time.Minute)
|
|
s.processWithRetry(taskCtx, taskID)
|
|
cancel()
|
|
}
|
|
}
|
|
}()
|
|
|
|
s.RecoverStuckTasks(ctx)
|
|
}
|
|
|
|
func (s *TaskService) SubmitAsync(ctx context.Context, req *model.CreateTaskRequest) (int64, error) {
|
|
task, err := s.CreateTask(ctx, req)
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
|
|
s.mu.Lock()
|
|
if s.stopped {
|
|
s.mu.Unlock()
|
|
return 0, fmt.Errorf("processor stopped, cannot submit task")
|
|
}
|
|
select {
|
|
case s.taskCh <- task.ID:
|
|
default:
|
|
s.logger.Warn("task channel full, submit dropped", zap.Int64("taskID", task.ID))
|
|
}
|
|
s.mu.Unlock()
|
|
|
|
return task.ID, nil
|
|
}
|
|
|
|
func (s *TaskService) StopProcessor() {
|
|
s.mu.Lock()
|
|
if s.stopped {
|
|
s.mu.Unlock()
|
|
return
|
|
}
|
|
s.stopped = true
|
|
close(s.taskCh)
|
|
s.mu.Unlock()
|
|
|
|
if s.cancelFn != nil {
|
|
s.cancelFn()
|
|
}
|
|
s.wg.Wait()
|
|
|
|
s.mu.Lock()
|
|
drainCh := s.taskCh
|
|
s.taskCh = make(chan int64, 16)
|
|
s.mu.Unlock()
|
|
|
|
for taskID := range drainCh {
|
|
_ = s.taskStore.UpdateStatus(context.Background(), taskID, model.TaskStatusSubmitted, "")
|
|
}
|
|
}
|
|
|
|
func (s *TaskService) processWithRetry(ctx context.Context, taskID int64) {
|
|
err := s.ProcessTask(ctx, taskID)
|
|
if err == nil {
|
|
return
|
|
}
|
|
|
|
task, fetchErr := s.taskStore.GetByID(ctx, taskID)
|
|
if fetchErr != nil || task == nil {
|
|
return
|
|
}
|
|
|
|
if task.RetryCount < 3 {
|
|
_ = s.taskStore.UpdateRetryState(ctx, taskID, model.TaskStatusSubmitted, task.CurrentStep, task.RetryCount+1)
|
|
s.mu.Lock()
|
|
if !s.stopped {
|
|
select {
|
|
case s.taskCh <- taskID:
|
|
default:
|
|
s.logger.Warn("task channel full, retry dropped", zap.Int64("taskID", taskID))
|
|
}
|
|
}
|
|
s.mu.Unlock()
|
|
}
|
|
}
|
|
|
|
func (s *TaskService) RecoverStuckTasks(ctx context.Context) {
|
|
tasks, err := s.taskStore.GetStuckTasks(ctx, 5*time.Minute)
|
|
if err != nil {
|
|
s.logger.Error("failed to get stuck tasks", zap.Error(err))
|
|
return
|
|
}
|
|
for i := range tasks {
|
|
if tasks[i].SlurmJobID != nil {
|
|
s.logger.Info("skipping stuck task recovery, already in slurm",
|
|
zap.Int64("taskID", tasks[i].ID),
|
|
zap.Int32("slurm_job_id", *tasks[i].SlurmJobID),
|
|
)
|
|
continue
|
|
}
|
|
_ = s.taskStore.UpdateStatus(ctx, tasks[i].ID, model.TaskStatusSubmitted, "")
|
|
s.mu.Lock()
|
|
if !s.stopped {
|
|
select {
|
|
case s.taskCh <- tasks[i].ID:
|
|
default:
|
|
s.logger.Warn("task channel full, stuck task recovery dropped", zap.Int64("taskID", tasks[i].ID))
|
|
}
|
|
}
|
|
s.mu.Unlock()
|
|
}
|
|
}
|