Files
hpc/internal/service/job_service.go

348 lines
8.7 KiB
Go

package service
import (
"context"
"fmt"
"strconv"
"time"
"gcy_hpc_server/internal/model"
"gcy_hpc_server/internal/slurm"
"go.uber.org/zap"
)
// JobService exposes Slurm job operations (submit, list, get, cancel and
// accounting history) on top of the Slurm SDK client, translating SDK types
// into the API models and paginating history results.
type JobService struct {
	// client is the Slurm SDK client every operation is delegated to.
	client *slurm.Client
	// logger receives request/response debug traces and operation outcomes.
	logger *zap.Logger
}
// NewJobService builds a JobService that talks to Slurm through client and
// reports its activity through logger. Neither argument is validated here.
func NewJobService(client *slurm.Client, logger *zap.Logger) *JobService {
	svc := new(JobService)
	svc.client = client
	svc.logger = logger
	return svc
}
// SubmitJob submits a new job to Slurm and returns a response carrying its
// job ID.
//
// The request's script is always sent. Partition, QOS and job name are only
// set when non-empty, and the minimum CPU count only when req.CPUs is
// positive. req.TimeLimit is parsed as a base-10 integer (NOTE(review):
// presumably minutes — confirm against the Slurm API); a value that does not
// parse is reported at warn level and the job is submitted without an
// explicit limit rather than being rejected.
//
// An error is returned when req is nil or when the Slurm call fails. On
// success the job ID is read from the response's Result.JobID when present,
// otherwise from its top-level JobID; if neither is set the returned JobID
// keeps its zero value.
func (s *JobService) SubmitJob(ctx context.Context, req *model.SubmitJobRequest) (*model.JobResponse, error) {
	if req == nil {
		return nil, fmt.Errorf("submit job: nil request")
	}

	script := req.Script
	jobDesc := &slurm.JobDescMsg{
		Script:    &script,
		Partition: strToPtrOrNil(req.Partition),
		Qos:       strToPtrOrNil(req.QOS),
		Name:      strToPtrOrNil(req.JobName),
	}
	if req.CPUs > 0 {
		jobDesc.MinimumCpus = slurm.Ptr(req.CPUs)
	}
	if req.TimeLimit != "" {
		mins, err := strconv.ParseInt(req.TimeLimit, 10, 64)
		if err != nil {
			// Keep the best-effort submission, but make the dropped limit
			// visible instead of discarding the parse error silently.
			s.logger.Warn("ignoring unparseable time limit",
				zap.String("time_limit", req.TimeLimit),
				zap.Error(err),
				zap.String("operation", "submit"),
			)
		} else {
			jobDesc.TimeLimit = &slurm.Uint32NoVal{Number: &mins}
		}
	}

	submitReq := &slurm.JobSubmitReq{
		Script: &script,
		Job:    jobDesc,
	}

	s.logger.Debug("slurm API request",
		zap.String("operation", "SubmitJob"),
		zap.Any("body", submitReq),
	)

	start := time.Now()
	result, _, err := s.client.Jobs.SubmitJob(ctx, submitReq)
	took := time.Since(start)

	if err != nil {
		s.logger.Debug("slurm API error response",
			zap.String("operation", "SubmitJob"),
			zap.Duration("took", took),
			zap.Error(err),
		)
		s.logger.Error("failed to submit job", zap.Error(err), zap.String("operation", "submit"))
		return nil, fmt.Errorf("submit job: %w", err)
	}

	s.logger.Debug("slurm API response",
		zap.String("operation", "SubmitJob"),
		zap.Duration("took", took),
		zap.Any("body", result),
	)

	// Prefer the nested result's ID and fall back to the top-level one.
	resp := &model.JobResponse{}
	if result.Result != nil && result.Result.JobID != nil {
		resp.JobID = *result.Result.JobID
	} else if result.JobID != nil {
		resp.JobID = *result.JobID
	}

	s.logger.Info("job submitted", zap.String("job_name", req.JobName), zap.Int32("job_id", resp.JobID))
	return resp, nil
}
// GetJobs fetches the jobs Slurm currently reports and maps each one to the
// API JobResponse model. A failed Slurm call is logged and returned wrapped.
func (s *JobService) GetJobs(ctx context.Context) ([]model.JobResponse, error) {
	opField := zap.String("operation", "GetJobs")
	s.logger.Debug("slurm API request", opField)

	began := time.Now()
	result, _, err := s.client.Jobs.GetJobs(ctx, nil)
	tookField := zap.Duration("took", time.Since(began))

	if err != nil {
		s.logger.Debug("slurm API error response", opField, tookField, zap.Error(err))
		s.logger.Error("failed to get jobs", zap.Error(err), zap.String("operation", "get_jobs"))
		return nil, fmt.Errorf("get jobs: %w", err)
	}

	count := len(result.Jobs)
	s.logger.Debug("slurm API response",
		opField,
		tookField,
		zap.Int("job_count", count),
		zap.Any("body", result),
	)

	// Map by index so each SDK entry is passed by pointer, not copied.
	mapped := make([]model.JobResponse, count)
	for i := range result.Jobs {
		mapped[i] = mapJobInfo(&result.Jobs[i])
	}
	return mapped, nil
}
// GetJob looks up one job by its ID and returns it mapped to the API model.
//
// When the Slurm response contains no jobs it returns (nil, nil), so callers
// must check the pointer as well as the error. When several jobs come back,
// only the first is returned. A failed Slurm call is logged and returned
// wrapped with the job ID.
func (s *JobService) GetJob(ctx context.Context, jobID string) (*model.JobResponse, error) {
	opField := zap.String("operation", "GetJob")
	idField := zap.String("job_id", jobID)
	s.logger.Debug("slurm API request", opField, idField)

	began := time.Now()
	result, _, err := s.client.Jobs.GetJob(ctx, jobID, nil)
	tookField := zap.Duration("took", time.Since(began))

	if err != nil {
		s.logger.Debug("slurm API error response", opField, idField, tookField, zap.Error(err))
		s.logger.Error("failed to get job", zap.Error(err), zap.String("job_id", jobID), zap.String("operation", "get_job"))
		return nil, fmt.Errorf("get job %s: %w", jobID, err)
	}

	s.logger.Debug("slurm API response", opField, idField, tookField, zap.Any("body", result))

	if len(result.Jobs) > 0 {
		found := mapJobInfo(&result.Jobs[0])
		return &found, nil
	}
	return nil, nil
}
// CancelJob asks Slurm to cancel the job with the given ID by issuing a
// DeleteJob call. It returns nil on success and a wrapped error, including
// the job ID, when the call fails.
func (s *JobService) CancelJob(ctx context.Context, jobID string) error {
	opField := zap.String("operation", "CancelJob")
	idField := zap.String("job_id", jobID)
	s.logger.Debug("slurm API request", opField, idField)

	began := time.Now()
	result, _, err := s.client.Jobs.DeleteJob(ctx, jobID, nil)
	tookField := zap.Duration("took", time.Since(began))

	if err != nil {
		s.logger.Debug("slurm API error response", opField, idField, tookField, zap.Error(err))
		s.logger.Error("failed to cancel job", zap.Error(err), zap.String("job_id", jobID), zap.String("operation", "cancel"))
		return fmt.Errorf("cancel job %s: %w", jobID, err)
	}

	s.logger.Debug("slurm API response", opField, idField, tookField, zap.Any("body", result))
	s.logger.Info("job cancelled", zap.String("job_id", jobID))
	return nil
}
// GetJobHistory queries SlurmDBD for historical jobs and returns one page of
// the mapped results.
//
// Every non-empty field of query (users, account, partition, state, job name,
// start time, end time) is forwarded as a filter; empty fields are omitted. A
// nil query is treated as an empty one. The full filtered result set is
// fetched and mapped first, and pagination is applied in memory afterwards.
//
// query.Page defaults to 1 and query.PageSize to 20 when they are below 1.
// A page past the end of the results yields an empty Jobs slice. Total in the
// response is the number of jobs before pagination, and Page/PageSize are the
// effective (defaulted) values. A failed Slurm call is logged and returned
// wrapped.
func (s *JobService) GetJobHistory(ctx context.Context, query *model.JobHistoryQuery) (*model.JobListResponse, error) {
	if query == nil {
		query = &model.JobHistoryQuery{}
	}

	// strToPtrOrNil leaves a filter unset (nil) when its query field is empty.
	opts := &slurm.GetSlurmdbJobsOptions{
		Users:     strToPtrOrNil(query.Users),
		Account:   strToPtrOrNil(query.Account),
		Partition: strToPtrOrNil(query.Partition),
		State:     strToPtrOrNil(query.State),
		JobName:   strToPtrOrNil(query.JobName),
		StartTime: strToPtrOrNil(query.StartTime),
		EndTime:   strToPtrOrNil(query.EndTime),
	}

	s.logger.Debug("slurm API request",
		zap.String("operation", "GetJobHistory"),
		zap.Any("body", opts),
	)

	start := time.Now()
	result, _, err := s.client.SlurmdbJobs.GetJobs(ctx, opts)
	took := time.Since(start)

	if err != nil {
		s.logger.Debug("slurm API error response",
			zap.String("operation", "GetJobHistory"),
			zap.Duration("took", took),
			zap.Error(err),
		)
		s.logger.Error("failed to get job history", zap.Error(err), zap.String("operation", "get_job_history"))
		return nil, fmt.Errorf("get job history: %w", err)
	}

	s.logger.Debug("slurm API response",
		zap.String("operation", "GetJobHistory"),
		zap.Duration("took", took),
		zap.Int("job_count", len(result.Jobs)),
		zap.Any("body", result),
	)

	allJobs := make([]model.JobResponse, 0, len(result.Jobs))
	for i := range result.Jobs {
		allJobs = append(allJobs, mapSlurmdbJob(&result.Jobs[i]))
	}

	total := len(allJobs)
	page := query.Page
	pageSize := query.PageSize
	if page < 1 {
		page = 1
	}
	if pageSize < 1 {
		pageSize = 20
	}

	// Clamp the window to [0, total] without ever forming a product or sum
	// that can exceed int. Computing (page-1)*pageSize or startIdx+pageSize
	// directly wraps to a negative value for very large caller-supplied
	// page/pageSize, which would make the slice expression below panic.
	startIdx := total
	if page-1 <= total/pageSize {
		// Here (page-1)*pageSize <= total, so the multiplication is safe.
		startIdx = (page - 1) * pageSize
	}
	end := total
	if pageSize < total-startIdx {
		end = startIdx + pageSize
	}

	return &model.JobListResponse{
		Jobs:     allJobs[startIdx:end],
		Total:    total,
		Page:     page,
		PageSize: pageSize,
	}, nil
}
// ---------------------------------------------------------------------------
// Helper functions
// ---------------------------------------------------------------------------
// strToPtr returns a pointer to a copy of s. The result is never nil, even
// when s is the empty string.
func strToPtr(s string) *string {
	p := new(string)
	*p = s
	return p
}
// strToPtrOrNil returns nil for the empty string and a pointer to a copy of
// s otherwise, so optional fields stay unset when no value was supplied.
func strToPtrOrNil(s string) *string {
	if len(s) > 0 {
		return &s
	}
	return nil
}
// mapJobInfo converts an SDK JobInfo (as returned by the live job endpoints)
// into the API JobResponse. Optional SDK fields that are nil leave the
// corresponding response field at its zero value. Timestamps are copied as
// the SDK's own pointers rather than dereferenced, and the exit code is taken
// from ExitCode.ReturnCode.Number, converted to int32.
func mapJobInfo(ji *slurm.JobInfo) model.JobResponse {
	var out model.JobResponse

	out.State = ji.JobState

	if id := ji.JobID; id != nil {
		out.JobID = *id
	}
	if name := ji.Name; name != nil {
		out.Name = *name
	}
	if part := ji.Partition; part != nil {
		out.Partition = *part
	}
	if nodes := ji.Nodes; nodes != nil {
		out.Nodes = *nodes
	}

	if t := ji.SubmitTime; t != nil && t.Number != nil {
		out.SubmitTime = t.Number
	}
	if t := ji.StartTime; t != nil && t.Number != nil {
		out.StartTime = t.Number
	}
	if t := ji.EndTime; t != nil && t.Number != nil {
		out.EndTime = t.Number
	}

	if ec := ji.ExitCode; ec != nil && ec.ReturnCode != nil && ec.ReturnCode.Number != nil {
		rc := int32(*ec.ReturnCode.Number)
		out.ExitCode = &rc
	}

	return out
}
// mapSlurmdbJob converts an SDK SlurmDBD Job (an accounting record) into the
// API JobResponse. Optional SDK fields that are nil leave the corresponding
// response field at its zero value. State comes from State.Current and the
// timestamps from the Time sub-struct. No exit code is mapped here.
func mapSlurmdbJob(j *slurm.Job) model.JobResponse {
	var out model.JobResponse

	if id := j.JobID; id != nil {
		out.JobID = *id
	}
	if name := j.Name; name != nil {
		out.Name = *name
	}
	if st := j.State; st != nil {
		out.State = st.Current
	}
	if part := j.Partition; part != nil {
		out.Partition = *part
	}
	if nodes := j.Nodes; nodes != nil {
		out.Nodes = *nodes
	}

	if times := j.Time; times != nil {
		if times.Submission != nil {
			out.SubmitTime = times.Submission
		}
		if times.Start != nil {
			out.StartTime = times.Start
		}
		if times.End != nil {
			out.EndTime = times.End
		}
	}

	return out
}