Add scheduling_map field to ParameterSchema so Application creators can declare that a parameter (e.g. NP) maps to a scheduling field (e.g. cpus). The backend auto-injects the scheduling value into script template variables before rendering, eliminating duplicate user input. The frontend hides mapped parameters from the form and injects their values on submit.
812 lines
24 KiB
Go
812 lines
24 KiB
Go
package service
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"fmt"
|
|
"os"
|
|
"path/filepath"
|
|
"sort"
|
|
"strconv"
|
|
"strings"
|
|
"sync"
|
|
"time"
|
|
|
|
"gcy_hpc_server/internal/model"
|
|
"gcy_hpc_server/internal/store"
|
|
|
|
"go.uber.org/zap"
|
|
)
|
|
|
|
type TaskService struct {
|
|
taskStore *store.TaskStore
|
|
appStore *store.ApplicationStore
|
|
fileStore *store.FileStore // nil ok
|
|
blobStore *store.BlobStore // nil ok
|
|
stagingSvc *FileStagingService // nil ok — MinIO unavailable
|
|
jobSvc *JobService
|
|
workDirBase string
|
|
logger *zap.Logger
|
|
|
|
// async processing
|
|
taskCh chan int64 // buffered channel for task IDs awaiting processing
|
|
cancelFn context.CancelFunc
|
|
wg sync.WaitGroup
|
|
mu sync.Mutex // protects taskCh from send-on-closed
|
|
started bool // prevent double-start
|
|
stopped bool
|
|
// inflight tracks task IDs currently being processed by the worker goroutine.
|
|
//
|
|
// Why it exists: taskCh is an in-memory Go channel — all pending taskIDs are
|
|
// lost when the server restarts. RecoverStuckTasks is responsible for
|
|
// recovering those lost tasks from the DB. However, GetStuckTasks uses a
|
|
// broad query (status NOT IN completed/failed AND updated_at < 5min ago) that
|
|
// also matches tasks being actively processed by the worker (e.g. a slow
|
|
// download). Without inflight, RecoverStuckTasks would reset those tasks to
|
|
// "submitted" and re-enqueue them, causing double-processing.
|
|
//
|
|
// How it works:
|
|
// - ProcessTask stores the taskID on entry, deletes on exit (via defer).
|
|
// - RecoverStuckTasks checks inflight before re-enqueueing; in-flight tasks
|
|
// are skipped.
|
|
// - On server restart inflight is empty (in-memory), so all genuinely stuck
|
|
// tasks are correctly recovered without false negatives.
|
|
inflight sync.Map // map[int64]struct{}
|
|
}
|
|
|
|
func NewTaskService(
|
|
taskStore *store.TaskStore,
|
|
appStore *store.ApplicationStore,
|
|
fileStore *store.FileStore,
|
|
blobStore *store.BlobStore,
|
|
stagingSvc *FileStagingService,
|
|
jobSvc *JobService,
|
|
workDirBase string,
|
|
logger *zap.Logger,
|
|
) *TaskService {
|
|
return &TaskService{
|
|
taskStore: taskStore,
|
|
appStore: appStore,
|
|
fileStore: fileStore,
|
|
blobStore: blobStore,
|
|
stagingSvc: stagingSvc,
|
|
jobSvc: jobSvc,
|
|
workDirBase: workDirBase,
|
|
logger: logger,
|
|
taskCh: make(chan int64, 10000),
|
|
}
|
|
}
|
|
|
|
func (s *TaskService) CreateTask(ctx context.Context, req *model.CreateTaskRequest) (*model.Task, error) {
|
|
app, err := s.appStore.GetByID(ctx, req.AppID)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("get application: %w", err)
|
|
}
|
|
if app == nil {
|
|
return nil, fmt.Errorf("application %d not found", req.AppID)
|
|
}
|
|
|
|
// 2. Validate file limit
|
|
if len(req.InputFileIDs) > 100 {
|
|
return nil, fmt.Errorf("input file count %d exceeds limit of 100", len(req.InputFileIDs))
|
|
}
|
|
|
|
// 3. Deduplicate file IDs
|
|
fileIDs := uniqueInt64s(req.InputFileIDs)
|
|
|
|
// 4. Validate file IDs exist
|
|
if s.fileStore != nil && len(fileIDs) > 0 {
|
|
files, err := s.fileStore.GetByIDs(ctx, fileIDs)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("validate file ids: %w", err)
|
|
}
|
|
found := make(map[int64]bool, len(files))
|
|
for _, f := range files {
|
|
found[f.ID] = true
|
|
}
|
|
for _, id := range fileIDs {
|
|
if !found[id] {
|
|
return nil, fmt.Errorf("file %d not found", id)
|
|
}
|
|
}
|
|
}
|
|
|
|
// 5. Auto-generate task name if empty
|
|
taskName := req.TaskName
|
|
if taskName == "" {
|
|
taskName = SanitizeDirName(app.Name) + "_" + time.Now().Format("20060102_150405")
|
|
}
|
|
|
|
// 6. Marshal values
|
|
valuesJSON := json.RawMessage(`{}`)
|
|
if len(req.Values) > 0 {
|
|
b, err := json.Marshal(req.Values)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("marshal values: %w", err)
|
|
}
|
|
valuesJSON = b
|
|
}
|
|
|
|
// 7. Marshal input_file_ids
|
|
fileIDsJSON := json.RawMessage(`[]`)
|
|
if len(fileIDs) > 0 {
|
|
b, err := json.Marshal(fileIDs)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("marshal file ids: %w", err)
|
|
}
|
|
fileIDsJSON = b
|
|
}
|
|
|
|
// 8. Create task record
|
|
task := &model.Task{
|
|
TaskName: taskName,
|
|
AppID: app.ID,
|
|
AppName: app.Name,
|
|
Status: model.TaskStatusSubmitted,
|
|
Values: valuesJSON,
|
|
InputFileIDs: fileIDsJSON,
|
|
SubmittedAt: time.Now(),
|
|
Partition: derefStr(req.Partition),
|
|
Cpus: req.Cpus,
|
|
MemoryPerNode: req.MemoryPerNode,
|
|
MemoryPerCpu: req.MemoryPerCpu,
|
|
TimeLimit: req.TimeLimit,
|
|
QOS: req.QOS,
|
|
JobName: req.JobName,
|
|
Nodes: req.Nodes,
|
|
Tasks: req.Tasks,
|
|
CpusPerTask: req.CpusPerTask,
|
|
Constraints: req.Constraints,
|
|
Reservation: req.Reservation,
|
|
Account: req.Account,
|
|
Nice: req.Nice,
|
|
MailType: req.MailType,
|
|
MailUser: req.MailUser,
|
|
StandardOutput: req.StandardOutput,
|
|
StandardError: req.StandardError,
|
|
StandardInput: req.StandardInput,
|
|
RequiredNodes: req.RequiredNodes,
|
|
ExcludedNodes: req.ExcludedNodes,
|
|
BeginTime: req.BeginTime,
|
|
Deadline: req.Deadline,
|
|
Array: req.Array,
|
|
Dependency: req.Dependency,
|
|
Requeue: req.Requeue,
|
|
KillOnNodeFail: req.KillOnNodeFail,
|
|
}
|
|
|
|
taskID, err := s.taskStore.Create(ctx, task)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("create task: %w", err)
|
|
}
|
|
task.ID = taskID
|
|
|
|
return task, nil
|
|
}
|
|
|
|
// ProcessTask runs the full synchronous processing pipeline for a task.
|
|
func (s *TaskService) ProcessTask(ctx context.Context, taskID int64) error {
|
|
s.inflight.Store(taskID, struct{}{})
|
|
defer s.inflight.Delete(taskID)
|
|
|
|
// 1. Fetch task
|
|
task, err := s.taskStore.GetByID(ctx, taskID)
|
|
if err != nil {
|
|
return fmt.Errorf("get task: %w", err)
|
|
}
|
|
if task == nil {
|
|
return fmt.Errorf("task %d not found", taskID)
|
|
}
|
|
|
|
// Defense-in-depth against duplicate processing. When the same taskID enters
|
|
// taskCh multiple times (e.g. submitted normally + RecoverStuckTasks also
|
|
// enqueues it before the worker picks up the first copy), the worker processes
|
|
// them sequentially. The first invocation changes status from "submitted" to
|
|
// "preparing"; the second invocation reads the latest DB status, sees
|
|
// non-submitted, and safely skips.
|
|
//
|
|
// This does NOT block retries: processWithRetry sets status back to "submitted"
|
|
// before re-enqueueing, so the retried invocation passes this check and
|
|
// continues from the saved currentStep.
|
|
if task.Status != model.TaskStatusSubmitted {
|
|
s.logger.Debug("skipping task with non-submitted status",
|
|
zap.Int64("task_id", taskID),
|
|
zap.String("status", string(task.Status)),
|
|
)
|
|
return nil
|
|
}
|
|
|
|
fail := func(step, msg string) error {
|
|
_ = s.taskStore.UpdateStatus(ctx, taskID, model.TaskStatusFailed, msg)
|
|
_ = s.taskStore.UpdateRetryState(ctx, taskID, model.TaskStatusFailed, step, task.RetryCount)
|
|
return fmt.Errorf("%s", msg)
|
|
}
|
|
|
|
currentStep := task.CurrentStep
|
|
|
|
var workDir string
|
|
var app *model.Application
|
|
|
|
if currentStep == "" || currentStep == model.TaskStepPreparing {
|
|
// 2. Set preparing
|
|
if err := s.taskStore.UpdateRetryState(ctx, taskID, model.TaskStatusPreparing, model.TaskStepPreparing, 0); err != nil {
|
|
return fail(model.TaskStepPreparing, fmt.Sprintf("update status to preparing: %v", err))
|
|
}
|
|
|
|
// 3. Fetch app
|
|
app, err = s.appStore.GetByID(ctx, task.AppID)
|
|
if err != nil {
|
|
return fail(model.TaskStepPreparing, fmt.Sprintf("get application: %v", err))
|
|
}
|
|
if app == nil {
|
|
return fail(model.TaskStepPreparing, fmt.Sprintf("application %d not found", task.AppID))
|
|
}
|
|
|
|
// 4-5. Create work directory
|
|
workDir = filepath.Join(s.workDirBase, SanitizeDirName(app.Name), time.Now().Format("20060102_150405")+"_"+RandomSuffix(4))
|
|
if err := os.MkdirAll(workDir, 0777); err != nil {
|
|
return fail(model.TaskStepPreparing, fmt.Sprintf("create work directory %s: %v", workDir, err))
|
|
}
|
|
|
|
// 6. CHMOD traversal — critical for multi-user HPC
|
|
for dir := workDir; dir != s.workDirBase; dir = filepath.Dir(dir) {
|
|
os.Chmod(dir, 0777)
|
|
}
|
|
os.Chmod(s.workDirBase, 0777)
|
|
|
|
// 7. UpdateWorkDir
|
|
if err := s.taskStore.UpdateWorkDir(ctx, taskID, workDir); err != nil {
|
|
return fail(model.TaskStepPreparing, fmt.Sprintf("update work dir: %v", err))
|
|
}
|
|
} else {
|
|
app, err = s.appStore.GetByID(ctx, task.AppID)
|
|
if err != nil {
|
|
return fail(currentStep, fmt.Sprintf("get application: %v", err))
|
|
}
|
|
if app == nil {
|
|
return fail(currentStep, fmt.Sprintf("application %d not found", task.AppID))
|
|
}
|
|
workDir = task.WorkDir
|
|
}
|
|
|
|
if currentStep == "" || currentStep == model.TaskStepPreparing || currentStep == model.TaskStepDownloading {
|
|
if currentStep == model.TaskStepDownloading && workDir != "" {
|
|
matches, _ := filepath.Glob(filepath.Join(workDir, "*"))
|
|
for _, f := range matches {
|
|
os.Remove(f)
|
|
}
|
|
}
|
|
|
|
// 8. Set downloading
|
|
if err := s.taskStore.UpdateRetryState(ctx, taskID, model.TaskStatusDownloading, model.TaskStepDownloading, 0); err != nil {
|
|
return fail(model.TaskStepDownloading, fmt.Sprintf("update status to downloading: %v", err))
|
|
}
|
|
|
|
// 9. Parse input_file_ids
|
|
var fileIDs []int64
|
|
if len(task.InputFileIDs) > 0 {
|
|
if err := json.Unmarshal(task.InputFileIDs, &fileIDs); err != nil {
|
|
return fail(model.TaskStepDownloading, fmt.Sprintf("parse input file ids: %v", err))
|
|
}
|
|
}
|
|
|
|
// 10-12. Download files
|
|
if len(fileIDs) > 0 {
|
|
if s.stagingSvc == nil {
|
|
return fail(model.TaskStepDownloading, "MinIO unavailable, cannot stage files")
|
|
}
|
|
if err := s.stagingSvc.DownloadFilesToDir(ctx, fileIDs, workDir); err != nil {
|
|
return fail(model.TaskStepDownloading, fmt.Sprintf("download files: %v", err))
|
|
}
|
|
}
|
|
}
|
|
|
|
// 13-14. Set ready + submitting (guard: skip if already submitted to Slurm)
|
|
if task.SlurmJobID != nil {
|
|
s.logger.Info("task already has slurm job, skipping submission",
|
|
zap.Int64("task_id", taskID),
|
|
zap.Int32("slurm_job_id", *task.SlurmJobID),
|
|
)
|
|
return nil
|
|
}
|
|
|
|
if err := s.taskStore.UpdateRetryState(ctx, taskID, model.TaskStatusReady, model.TaskStepSubmitting, 0); err != nil {
|
|
return fail(model.TaskStepSubmitting, fmt.Sprintf("update status to ready: %v", err))
|
|
}
|
|
|
|
// 15. Parse app parameters
|
|
var params []model.ParameterSchema
|
|
if len(app.Parameters) > 0 {
|
|
if err := json.Unmarshal(app.Parameters, ¶ms); err != nil {
|
|
return fail(model.TaskStepSubmitting, fmt.Sprintf("parse parameters: %v", err))
|
|
}
|
|
}
|
|
|
|
// 15a. Parse app environment
|
|
var appEnv map[string]string
|
|
if len(app.Environment) > 0 {
|
|
if err := json.Unmarshal(app.Environment, &appEnv); err != nil {
|
|
return fail(model.TaskStepSubmitting, fmt.Sprintf("parse application environment: %v", err))
|
|
}
|
|
}
|
|
|
|
// 16. Parse task values
|
|
values := make(map[string]string)
|
|
if len(task.Values) > 0 {
|
|
if err := json.Unmarshal(task.Values, &values); err != nil {
|
|
return fail(model.TaskStepSubmitting, fmt.Sprintf("parse values: %v", err))
|
|
}
|
|
}
|
|
|
|
// 16a. Auto-inject WORK_DIR if the app defines it as a parameter.
|
|
// The work directory is created by the server, not provided by the user.
|
|
for _, p := range params {
|
|
if p.Name == "WORK_DIR" {
|
|
values["WORK_DIR"] = workDir
|
|
break
|
|
}
|
|
}
|
|
|
|
// 16b. Map input_file_ids to file-type parameters by order.
|
|
// User selects files via FilePicker; we assign their IDs to file/directory
|
|
// params sequentially so the backend can resolve them to filenames later.
|
|
var inputFileIDs []int64
|
|
if len(task.InputFileIDs) > 0 {
|
|
if err := json.Unmarshal(task.InputFileIDs, &inputFileIDs); err != nil {
|
|
return fail(model.TaskStepSubmitting, fmt.Sprintf("parse input file ids: %v", err))
|
|
}
|
|
}
|
|
if len(inputFileIDs) > 0 {
|
|
fileParamIdx := 0
|
|
for _, p := range params {
|
|
if p.Type != model.ParamTypeFile && p.Type != model.ParamTypeDirectory {
|
|
continue
|
|
}
|
|
if fileParamIdx < len(inputFileIDs) {
|
|
values[p.Name] = strconv.FormatInt(inputFileIDs[fileParamIdx], 10)
|
|
fileParamIdx++
|
|
}
|
|
}
|
|
}
|
|
|
|
// 16b-3. Auto-inject scheduling params based on scheduling_map.
|
|
// If an Application parameter declares scheduling_map, the corresponding
|
|
// scheduling field value overrides any user-provided value.
|
|
for _, p := range params {
|
|
if p.SchedulingMap == "" {
|
|
continue
|
|
}
|
|
if val := ResolveSchedulingMap(p.SchedulingMap, task); val != "" {
|
|
values[p.Name] = val
|
|
}
|
|
}
|
|
|
|
// 16c. Validate all params (WORK_DIR and file params now have values).
|
|
if err := ValidateParams(params, values); err != nil {
|
|
return fail(model.TaskStepSubmitting, err.Error())
|
|
}
|
|
|
|
// 16d. Resolve file-type parameter values: file_id → filename.
|
|
var fileLookupIDs []int64
|
|
for _, p := range params {
|
|
if p.Type != model.ParamTypeFile && p.Type != model.ParamTypeDirectory {
|
|
continue
|
|
}
|
|
val, ok := values[p.Name]
|
|
if !ok || val == "" {
|
|
continue
|
|
}
|
|
fileID, err := strconv.ParseInt(val, 10, 64)
|
|
if err != nil {
|
|
return fail(model.TaskStepSubmitting, fmt.Sprintf("parameter %q: invalid file_id %q, expected numeric file ID", p.Name, val))
|
|
}
|
|
fileLookupIDs = append(fileLookupIDs, fileID)
|
|
}
|
|
|
|
if len(fileLookupIDs) > 0 && s.fileStore != nil {
|
|
fetchedFiles, err := s.fileStore.GetByIDs(ctx, fileLookupIDs)
|
|
if err != nil {
|
|
return fail(model.TaskStepSubmitting, fmt.Sprintf("fetch file names for parameter resolution: %v", err))
|
|
}
|
|
fileMap := make(map[int64]string, len(fetchedFiles))
|
|
for _, f := range fetchedFiles {
|
|
fileMap[f.ID] = f.Name
|
|
}
|
|
for _, p := range params {
|
|
if p.Type != model.ParamTypeFile && p.Type != model.ParamTypeDirectory {
|
|
continue
|
|
}
|
|
val, ok := values[p.Name]
|
|
if !ok || val == "" {
|
|
continue
|
|
}
|
|
fileID, _ := strconv.ParseInt(val, 10, 64)
|
|
filename, found := fileMap[fileID]
|
|
if !found {
|
|
return fail(model.TaskStepSubmitting, fmt.Sprintf("parameter %q: file_id %d not found", p.Name, fileID))
|
|
}
|
|
values[p.Name] = filename
|
|
}
|
|
}
|
|
|
|
// 注入默认调度参数(仅在内存中,不持久化到数据库)
|
|
if task.TimeLimit == nil {
|
|
task.TimeLimit = int32Ptr(10080) // 168 小时
|
|
}
|
|
if task.StandardOutput == nil {
|
|
task.StandardOutput = strToPtrOrNil(filepath.Join(workDir, "slurm-%j.out"))
|
|
}
|
|
if task.StandardError == nil {
|
|
task.StandardError = strToPtrOrNil(filepath.Join(workDir, "slurm-%j.err"))
|
|
}
|
|
|
|
// 17. Render script
|
|
rendered := RenderScript(app.ScriptTemplate, params, values)
|
|
s.logger.Info("rendered script",
|
|
zap.Int64("task_id", taskID),
|
|
zap.String("work_dir", workDir),
|
|
zap.String("script", rendered),
|
|
)
|
|
|
|
// 18. Submit to Slurm
|
|
jobResp, err := s.jobSvc.SubmitJob(ctx, &model.SubmitJobRequest{
|
|
Script: rendered,
|
|
WorkDir: workDir,
|
|
Partition: task.Partition,
|
|
CPUs: derefInt32(task.Cpus),
|
|
TimeLimit: derefInt32ToStr(task.TimeLimit),
|
|
QOS: derefStr(task.QOS),
|
|
JobName: derefStr(task.JobName),
|
|
MemoryPerNode: task.MemoryPerNode,
|
|
MemoryPerCpu: task.MemoryPerCpu,
|
|
Nodes: task.Nodes,
|
|
Tasks: task.Tasks,
|
|
CpusPerTask: task.CpusPerTask,
|
|
Constraints: task.Constraints,
|
|
Reservation: task.Reservation,
|
|
Account: task.Account,
|
|
Nice: task.Nice,
|
|
MailType: task.MailType,
|
|
MailUser: task.MailUser,
|
|
StandardOutput: task.StandardOutput,
|
|
StandardError: task.StandardError,
|
|
StandardInput: task.StandardInput,
|
|
RequiredNodes: task.RequiredNodes,
|
|
ExcludedNodes: task.ExcludedNodes,
|
|
BeginTime: task.BeginTime,
|
|
Deadline: task.Deadline,
|
|
Array: task.Array,
|
|
Dependency: task.Dependency,
|
|
Requeue: task.Requeue,
|
|
KillOnNodeFail: task.KillOnNodeFail,
|
|
Environment: appEnv,
|
|
})
|
|
if err != nil {
|
|
return fail(model.TaskStepSubmitting, fmt.Sprintf("submit job: %v", err))
|
|
}
|
|
|
|
// 19. Update slurm_job_id and status to queued
|
|
if err := s.taskStore.UpdateSlurmJobID(ctx, taskID, &jobResp.JobID); err != nil {
|
|
return fail(model.TaskStepSubmitting, fmt.Sprintf("update slurm job id: %v", err))
|
|
}
|
|
if err := s.taskStore.UpdateStatus(ctx, taskID, model.TaskStatusQueued, ""); err != nil {
|
|
return fail(model.TaskStepSubmitting, fmt.Sprintf("update status to queued: %v", err))
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// ListTasks returns a paginated list of tasks.
|
|
func (s *TaskService) ListTasks(ctx context.Context, query *model.TaskListQuery) ([]model.Task, int64, error) {
|
|
return s.taskStore.List(ctx, query)
|
|
}
|
|
|
|
// ProcessTaskSync creates and processes a task synchronously, returning a JobResponse
|
|
// for old API compatibility.
|
|
func (s *TaskService) ProcessTaskSync(ctx context.Context, req *model.CreateTaskRequest) (*model.JobResponse, error) {
|
|
// 1. Create task
|
|
task, err := s.CreateTask(ctx, req)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// 2. Process synchronously
|
|
if err := s.ProcessTask(ctx, task.ID); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// 3. Re-fetch to get updated slurm_job_id
|
|
task, err = s.taskStore.GetByID(ctx, task.ID)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("re-fetch task: %w", err)
|
|
}
|
|
if task == nil || task.SlurmJobID == nil {
|
|
return nil, fmt.Errorf("task has no slurm job id after processing")
|
|
}
|
|
|
|
// 4. Return JobResponse
|
|
return &model.JobResponse{JobID: *task.SlurmJobID}, nil
|
|
}
|
|
|
|
// uniqueInt64s returns the distinct values of ids in ascending order.
// An empty or nil input yields nil. The input slice is not modified.
func uniqueInt64s(ids []int64) []int64 {
	if len(ids) == 0 {
		return nil
	}

	visited := make(map[int64]struct{}, len(ids))
	unique := make([]int64, 0, len(ids))
	for i := 0; i < len(ids); i++ {
		v := ids[i]
		if _, dup := visited[v]; dup {
			continue
		}
		visited[v] = struct{}{}
		unique = append(unique, v)
	}

	sort.Slice(unique, func(a, b int) bool {
		return unique[a] < unique[b]
	})
	return unique
}
|
|
|
|
func (s *TaskService) mapSlurmStateToTaskStatus(slurmState []string) string {
|
|
if len(slurmState) == 0 {
|
|
return ""
|
|
}
|
|
|
|
state := strings.ToUpper(slurmState[0])
|
|
switch state {
|
|
case "PENDING":
|
|
return model.TaskStatusQueued
|
|
case "RUNNING", "CONFIGURING", "COMPLETING", "SPECIAL_EXIT":
|
|
return model.TaskStatusRunning
|
|
case "COMPLETED":
|
|
return model.TaskStatusCompleted
|
|
case "FAILED", "CANCELLED", "TIMEOUT", "NODE_FAIL", "OUT_OF_MEMORY", "PREEMPTED":
|
|
return model.TaskStatusFailed
|
|
default:
|
|
s.logger.Warn("unrecognized slurm state, skipping update", zap.String("state", state))
|
|
return ""
|
|
}
|
|
}
|
|
|
|
func (s *TaskService) refreshTaskStatus(ctx context.Context, taskID int64) error {
|
|
task, err := s.taskStore.GetByID(ctx, taskID)
|
|
if err != nil {
|
|
s.logger.Error("failed to fetch task for refresh",
|
|
zap.Int64("task_id", taskID),
|
|
zap.Error(err),
|
|
)
|
|
return err
|
|
}
|
|
if task == nil || task.SlurmJobID == nil {
|
|
return nil
|
|
}
|
|
|
|
jobResp, err := s.jobSvc.GetJob(ctx, strconv.FormatInt(int64(*task.SlurmJobID), 10))
|
|
if err != nil {
|
|
s.logger.Warn("failed to query slurm job status during refresh",
|
|
zap.Int64("task_id", taskID),
|
|
zap.Int32("slurm_job_id", *task.SlurmJobID),
|
|
zap.Error(err),
|
|
)
|
|
return nil
|
|
}
|
|
if jobResp == nil {
|
|
return nil
|
|
}
|
|
|
|
newStatus := s.mapSlurmStateToTaskStatus(jobResp.State)
|
|
if newStatus == "" || newStatus == task.Status {
|
|
return nil
|
|
}
|
|
|
|
s.logger.Info("updating task status from slurm",
|
|
zap.Int64("task_id", taskID),
|
|
zap.String("old_status", task.Status),
|
|
zap.String("new_status", newStatus),
|
|
)
|
|
return s.taskStore.UpdateStatus(ctx, taskID, newStatus, "")
|
|
}
|
|
|
|
func (s *TaskService) RefreshStaleTasks(ctx context.Context) error {
|
|
staleThreshold := 30 * time.Second
|
|
nonTerminal := []string{model.TaskStatusQueued, model.TaskStatusRunning}
|
|
|
|
for _, status := range nonTerminal {
|
|
tasks, _, err := s.taskStore.List(ctx, &model.TaskListQuery{
|
|
Status: status,
|
|
Page: 1,
|
|
PageSize: 1000,
|
|
})
|
|
if err != nil {
|
|
s.logger.Warn("failed to list tasks for stale refresh",
|
|
zap.String("status", status),
|
|
zap.Error(err),
|
|
)
|
|
continue
|
|
}
|
|
|
|
cutoff := time.Now().Add(-staleThreshold)
|
|
for i := range tasks {
|
|
if tasks[i].UpdatedAt.Before(cutoff) {
|
|
if err := s.refreshTaskStatus(ctx, tasks[i].ID); err != nil {
|
|
s.logger.Warn("failed to refresh stale task",
|
|
zap.Int64("task_id", tasks[i].ID),
|
|
zap.Error(err),
|
|
)
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (s *TaskService) StartProcessor(ctx context.Context) {
|
|
s.mu.Lock()
|
|
if s.started {
|
|
s.mu.Unlock()
|
|
return
|
|
}
|
|
s.started = true
|
|
s.mu.Unlock()
|
|
|
|
ctx, s.cancelFn = context.WithCancel(ctx)
|
|
|
|
s.wg.Add(1)
|
|
go func() {
|
|
defer s.wg.Done()
|
|
defer func() {
|
|
if r := recover(); r != nil {
|
|
s.logger.Error("processor panic", zap.Any("panic", r))
|
|
}
|
|
}()
|
|
for {
|
|
select {
|
|
case <-ctx.Done():
|
|
return
|
|
case taskID, ok := <-s.taskCh:
|
|
if !ok {
|
|
return
|
|
}
|
|
taskCtx, cancel := context.WithTimeout(ctx, 10*time.Minute)
|
|
s.processWithRetry(taskCtx, taskID)
|
|
cancel()
|
|
}
|
|
}
|
|
}()
|
|
|
|
s.RecoverStuckTasks(ctx)
|
|
}
|
|
|
|
func (s *TaskService) SubmitAsync(ctx context.Context, req *model.CreateTaskRequest) (int64, error) {
|
|
task, err := s.CreateTask(ctx, req)
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
|
|
s.mu.Lock()
|
|
if s.stopped {
|
|
s.mu.Unlock()
|
|
return 0, fmt.Errorf("processor stopped, cannot submit task")
|
|
}
|
|
select {
|
|
case s.taskCh <- task.ID:
|
|
default:
|
|
s.logger.Warn("task channel full, submit dropped", zap.Int64("taskID", task.ID))
|
|
}
|
|
s.mu.Unlock()
|
|
|
|
return task.ID, nil
|
|
}
|
|
|
|
func (s *TaskService) StopProcessor() {
|
|
s.mu.Lock()
|
|
if s.stopped {
|
|
s.mu.Unlock()
|
|
return
|
|
}
|
|
s.stopped = true
|
|
close(s.taskCh)
|
|
s.mu.Unlock()
|
|
|
|
if s.cancelFn != nil {
|
|
s.cancelFn()
|
|
}
|
|
s.wg.Wait()
|
|
|
|
s.mu.Lock()
|
|
drainCh := s.taskCh
|
|
s.taskCh = make(chan int64, 10000)
|
|
s.mu.Unlock()
|
|
|
|
for taskID := range drainCh {
|
|
_ = s.taskStore.UpdateStatus(context.Background(), taskID, model.TaskStatusSubmitted, "")
|
|
}
|
|
}
|
|
|
|
func (s *TaskService) processWithRetry(ctx context.Context, taskID int64) {
|
|
err := s.ProcessTask(ctx, taskID)
|
|
if err == nil {
|
|
return
|
|
}
|
|
|
|
task, fetchErr := s.taskStore.GetByID(ctx, taskID)
|
|
if fetchErr != nil || task == nil {
|
|
return
|
|
}
|
|
|
|
if task.RetryCount < 3 {
|
|
_ = s.taskStore.UpdateRetryState(ctx, taskID, model.TaskStatusSubmitted, task.CurrentStep, task.RetryCount+1)
|
|
s.mu.Lock()
|
|
if !s.stopped {
|
|
select {
|
|
case s.taskCh <- taskID:
|
|
default:
|
|
s.logger.Warn("task channel full, retry dropped", zap.Int64("taskID", taskID))
|
|
}
|
|
}
|
|
s.mu.Unlock()
|
|
}
|
|
}
|
|
|
|
func (s *TaskService) RecoverStuckTasks(ctx context.Context) {
|
|
// RecoverStuckTasks recovers tasks that are "stuck" — they exist in the DB
|
|
// with a non-terminal status but are not being processed.
|
|
//
|
|
// Scenarios that create stuck tasks:
|
|
//
|
|
// 1. Server restart: taskCh is an in-memory Go channel, all pending IDs are
|
|
// lost on process exit. Tasks that were queued but never picked up by the
|
|
// worker remain in "submitted" status in DB with no one to process them.
|
|
//
|
|
// 2. Server crash mid-processing: the worker had advanced a task to
|
|
// "preparing"/"downloading" and then died. The task sits in that
|
|
// intermediate state with no SlurmJobID and no worker to continue.
|
|
//
|
|
// 3. Channel full: SubmitAsync dropped a task because taskCh was at
|
|
// capacity. The task stays "submitted" but was never enqueued.
|
|
//
|
|
// The bug this fix addresses:
|
|
//
|
|
// GetStuckTasks queries: status NOT IN (completed, failed) AND updated_at <
|
|
// 5min ago. This also matches tasks currently being processed by the worker
|
|
// whose step is slow (>5 min, e.g. downloading large files) and hasn't
|
|
// refreshed updated_at. Without the inflight check below, this function
|
|
// would reset such a task to "submitted" and re-enqueue it, causing the
|
|
// same task to be processed by two concurrent invocations of ProcessTask.
|
|
//
|
|
// Fix: the inflight sync.Map tracks taskIDs currently inside ProcessTask.
|
|
// Tasks found in inflight are skipped here. On server restart inflight is
|
|
// empty (it's in-memory), so all genuinely stuck tasks from scenarios 1-3
|
|
// above are correctly recovered.
|
|
|
|
tasks, err := s.taskStore.GetStuckTasks(ctx, 5*time.Minute)
|
|
if err != nil {
|
|
s.logger.Error("failed to get stuck tasks", zap.Error(err))
|
|
return
|
|
}
|
|
for i := range tasks {
|
|
if tasks[i].SlurmJobID != nil {
|
|
s.logger.Info("skipping stuck task recovery, already in slurm",
|
|
zap.Int64("taskID", tasks[i].ID),
|
|
zap.Int32("slurm_job_id", *tasks[i].SlurmJobID),
|
|
)
|
|
continue
|
|
}
|
|
if _, ok := s.inflight.Load(tasks[i].ID); ok {
|
|
s.logger.Debug("skipping in-flight task",
|
|
zap.Int64("taskID", tasks[i].ID),
|
|
)
|
|
continue
|
|
}
|
|
_ = s.taskStore.UpdateStatus(ctx, tasks[i].ID, model.TaskStatusSubmitted, "")
|
|
s.mu.Lock()
|
|
if !s.stopped {
|
|
select {
|
|
case s.taskCh <- tasks[i].ID:
|
|
default:
|
|
s.logger.Warn("task channel full, stuck task recovery dropped", zap.Int64("taskID", tasks[i].ID))
|
|
}
|
|
}
|
|
s.mu.Unlock()
|
|
}
|
|
}
|