142 lines
3.5 KiB
Go
142 lines
3.5 KiB
Go
package store
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"fmt"
|
|
"time"
|
|
|
|
"gcy_hpc_server/internal/model"
|
|
|
|
"gorm.io/gorm"
|
|
)
|
|
|
|
type TaskStore struct {
|
|
db *gorm.DB
|
|
}
|
|
|
|
func NewTaskStore(db *gorm.DB) *TaskStore {
|
|
return &TaskStore{db: db}
|
|
}
|
|
|
|
func (s *TaskStore) Create(ctx context.Context, task *model.Task) (int64, error) {
|
|
if err := s.db.WithContext(ctx).Create(task).Error; err != nil {
|
|
return 0, err
|
|
}
|
|
return task.ID, nil
|
|
}
|
|
|
|
func (s *TaskStore) GetByID(ctx context.Context, id int64) (*model.Task, error) {
|
|
var task model.Task
|
|
err := s.db.WithContext(ctx).First(&task, id).Error
|
|
if errors.Is(err, gorm.ErrRecordNotFound) {
|
|
return nil, nil
|
|
}
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return &task, nil
|
|
}
|
|
|
|
func (s *TaskStore) List(ctx context.Context, query *model.TaskListQuery) ([]model.Task, int64, error) {
|
|
page := query.Page
|
|
pageSize := query.PageSize
|
|
if page <= 0 {
|
|
page = 1
|
|
}
|
|
if pageSize <= 0 {
|
|
pageSize = 10
|
|
}
|
|
|
|
q := s.db.WithContext(ctx).Model(&model.Task{})
|
|
if query.Status != "" {
|
|
q = q.Where("status = ?", query.Status)
|
|
}
|
|
|
|
var total int64
|
|
if err := q.Count(&total).Error; err != nil {
|
|
return nil, 0, err
|
|
}
|
|
|
|
var tasks []model.Task
|
|
offset := (page - 1) * pageSize
|
|
if err := q.Order("id DESC").Limit(pageSize).Offset(offset).Find(&tasks).Error; err != nil {
|
|
return nil, 0, err
|
|
}
|
|
return tasks, total, nil
|
|
}
|
|
|
|
func (s *TaskStore) UpdateStatus(ctx context.Context, id int64, status, errorMsg string) error {
|
|
updates := map[string]interface{}{
|
|
"status": status,
|
|
"error_message": errorMsg,
|
|
}
|
|
|
|
now := time.Now()
|
|
switch status {
|
|
case model.TaskStatusPreparing, model.TaskStatusDownloading, model.TaskStatusReady,
|
|
model.TaskStatusQueued, model.TaskStatusRunning:
|
|
updates["started_at"] = &now
|
|
case model.TaskStatusCompleted, model.TaskStatusFailed:
|
|
updates["finished_at"] = &now
|
|
}
|
|
|
|
result := s.db.WithContext(ctx).Model(&model.Task{}).Where("id = ?", id).Updates(updates)
|
|
if result.Error != nil {
|
|
return result.Error
|
|
}
|
|
if result.RowsAffected == 0 {
|
|
return fmt.Errorf("task %d not found", id)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (s *TaskStore) UpdateSlurmJobID(ctx context.Context, id int64, slurmJobID *int32) error {
|
|
result := s.db.WithContext(ctx).Model(&model.Task{}).Where("id = ?", id).
|
|
Update("slurm_job_id", slurmJobID)
|
|
if result.Error != nil {
|
|
return result.Error
|
|
}
|
|
if result.RowsAffected == 0 {
|
|
return fmt.Errorf("task %d not found", id)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (s *TaskStore) UpdateWorkDir(ctx context.Context, id int64, workDir string) error {
|
|
result := s.db.WithContext(ctx).Model(&model.Task{}).Where("id = ?", id).
|
|
Update("work_dir", workDir)
|
|
if result.Error != nil {
|
|
return result.Error
|
|
}
|
|
if result.RowsAffected == 0 {
|
|
return fmt.Errorf("task %d not found", id)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (s *TaskStore) UpdateRetryState(ctx context.Context, id int64, status, currentStep string, retryCount int) error {
|
|
result := s.db.WithContext(ctx).Model(&model.Task{}).Where("id = ?", id).Updates(map[string]interface{}{
|
|
"status": status,
|
|
"current_step": currentStep,
|
|
"retry_count": retryCount,
|
|
})
|
|
if result.Error != nil {
|
|
return result.Error
|
|
}
|
|
if result.RowsAffected == 0 {
|
|
return fmt.Errorf("task %d not found", id)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (s *TaskStore) GetStuckTasks(ctx context.Context, maxAge time.Duration) ([]model.Task, error) {
|
|
cutoff := time.Now().Add(-maxAge)
|
|
var tasks []model.Task
|
|
err := s.db.WithContext(ctx).
|
|
Where("status NOT IN ?", []string{model.TaskStatusCompleted, model.TaskStatusFailed}).
|
|
Where("updated_at < ?", cutoff).
|
|
Find(&tasks).Error
|
|
return tasks, err
|
|
}
|