package store

import (
	"context"
	"errors"
	"fmt"
	"time"

	"gcy_hpc_server/internal/model"

	"gorm.io/gorm"
)

// TaskStore persists model.Task rows through GORM.
type TaskStore struct {
	db *gorm.DB
}

// NewTaskStore returns a TaskStore that issues its queries on db.
func NewTaskStore(db *gorm.DB) *TaskStore {
	return &TaskStore{db: db}
}

// Create inserts task and returns its ID. GORM writes the generated primary
// key back into task, so task.ID is also populated on success.
func (s *TaskStore) Create(ctx context.Context, task *model.Task) (int64, error) {
	if err := s.db.WithContext(ctx).Create(task).Error; err != nil {
		return 0, err
	}
	return task.ID, nil
}

// GetByID loads the task with the given primary key. A missing task is not
// treated as an error: it returns (nil, nil).
func (s *TaskStore) GetByID(ctx context.Context, id int64) (*model.Task, error) {
	var task model.Task
	err := s.db.WithContext(ctx).First(&task, id).Error
	if errors.Is(err, gorm.ErrRecordNotFound) {
		return nil, nil
	}
	if err != nil {
		return nil, err
	}
	return &task, nil
}

// List returns one page of tasks ordered by descending ID, together with the
// total number of tasks matching the filter regardless of pagination.
//
// A non-positive query.Page defaults to 1 and a non-positive query.PageSize
// defaults to 10. When query.Status is non-empty, only tasks with that status
// are counted and returned. A nil query is treated like an empty one.
func (s *TaskStore) List(ctx context.Context, query *model.TaskListQuery) ([]model.Task, int64, error) {
	if query == nil {
		query = &model.TaskListQuery{}
	}
	page := query.Page
	if page <= 0 {
		page = 1
	}
	pageSize := query.PageSize
	if pageSize <= 0 {
		pageSize = 10
	}

	q := s.db.WithContext(ctx).Model(&model.Task{})
	if query.Status != "" {
		q = q.Where("status = ?", query.Status)
	}
	// The filtered query is used twice (count, then page fetch). A Session copy
	// makes each use start from a clone of the filter instead of mutating a
	// shared statement.
	base := q.Session(&gorm.Session{})

	var total int64
	if err := base.Count(&total).Error; err != nil {
		return nil, 0, err
	}

	var tasks []model.Task
	offset := (page - 1) * pageSize
	if err := base.Order("id DESC").Limit(pageSize).Offset(offset).Find(&tasks).Error; err != nil {
		return nil, 0, err
	}
	return tasks, total, nil
}

// UpdateStatus sets the task's status and error_message, and stamps lifecycle
// timestamps. An empty errorMsg clears any previous message.
//
// Entering preparing, downloading, ready, queued or running records
// started_at only if it is still NULL. The column thus keeps the time the task
// first became active rather than drifting forward on every later transition.
// Entering completed or failed sets finished_at to the current time.
//
// It returns a "task <id> not found" error if no task has the given id.
func (s *TaskStore) UpdateStatus(ctx context.Context, id int64, status, errorMsg string) error {
	updates := map[string]interface{}{
		"status":        status,
		"error_message": errorMsg,
	}
	now := time.Now()
	switch status {
	case model.TaskStatusPreparing, model.TaskStatusDownloading, model.TaskStatusReady,
		model.TaskStatusQueued, model.TaskStatusRunning:
		updates["started_at"] = gorm.Expr("COALESCE(started_at, ?)", now)
	case model.TaskStatusCompleted, model.TaskStatusFailed:
		updates["finished_at"] = &now
	}
	return s.updateByID(ctx, id, updates)
}

// UpdateSlurmJobID stores the Slurm job ID for the task. A nil slurmJobID is
// passed through to the driver, which database/sql sends as NULL.
// It returns a "task <id> not found" error if no task has the given id.
func (s *TaskStore) UpdateSlurmJobID(ctx context.Context, id int64, slurmJobID *int32) error {
	return s.updateByID(ctx, id, map[string]interface{}{
		"slurm_job_id": slurmJobID,
	})
}

// UpdateWorkDir stores the task's working directory.
// It returns a "task <id> not found" error if no task has the given id.
func (s *TaskStore) UpdateWorkDir(ctx context.Context, id int64, workDir string) error {
	return s.updateByID(ctx, id, map[string]interface{}{
		"work_dir": workDir,
	})
}

// UpdateRetryState sets status, current_step and retry_count in one statement.
// It returns a "task <id> not found" error if no task has the given id.
func (s *TaskStore) UpdateRetryState(ctx context.Context, id int64, status, currentStep string, retryCount int) error {
	return s.updateByID(ctx, id, map[string]interface{}{
		"status":       status,
		"current_step": currentStep,
		"retry_count":  retryCount,
	})
}

// GetStuckTasks returns tasks whose status is neither completed nor failed and
// whose updated_at is earlier than now minus maxAge. On error it returns a nil
// slice.
func (s *TaskStore) GetStuckTasks(ctx context.Context, maxAge time.Duration) ([]model.Task, error) {
	cutoff := time.Now().Add(-maxAge)
	var tasks []model.Task
	err := s.db.WithContext(ctx).
		Where("status NOT IN ?", []string{model.TaskStatusCompleted, model.TaskStatusFailed}).
		Where("updated_at < ?", cutoff).
		Find(&tasks).Error
	if err != nil {
		return nil, err
	}
	return tasks, nil
}

// updateByID applies updates to the task row with the given id.
//
// A zero RowsAffected does not by itself prove the row is missing. Some
// drivers (e.g. go-sql-driver/mysql without clientFoundRows) count only rows
// whose values actually changed. So before reporting not-found, the row's
// existence is checked explicitly.
func (s *TaskStore) updateByID(ctx context.Context, id int64, updates map[string]interface{}) error {
	db := s.db.WithContext(ctx)
	result := db.Model(&model.Task{}).Where("id = ?", id).Updates(updates)
	if result.Error != nil {
		return result.Error
	}
	if result.RowsAffected > 0 {
		return nil
	}

	var count int64
	if err := db.Model(&model.Task{}).Where("id = ?", id).Count(&count).Error; err != nil {
		return err
	}
	if count == 0 {
		return fmt.Errorf("task %d not found", id)
	}
	return nil
}