Files
hpc/internal/service/job_service.go
dailz 32f5792b68 feat(service): pass work directory to Slurm job submission
Add WorkDir to SubmitJobRequest and pass it as CurrentWorkingDirectory to Slurm REST API. Fixes Slurm 500 error when working directory is not specified.

Ultraworked with [Sisyphus](https://github.com/code-yeongyu/oh-my-openagent)

Co-authored-by: Sisyphus <clio-agent@sisyphuslabs.ai>
2026-04-13 17:12:28 +08:00

492 lines
13 KiB
Go

package service
import (
"context"
"fmt"
"strconv"
"time"
"gcy_hpc_server/internal/model"
"gcy_hpc_server/internal/slurm"
"go.uber.org/zap"
)
// JobService adapts Slurm SDK job calls (active-queue Jobs and SlurmDBD
// history) to the API model types, adding request/response logging and
// in-memory pagination of list results.
type JobService struct {
	// client is the Slurm SDK client used for every REST call.
	client *slurm.Client
	// logger receives debug traces of API traffic and error/info events.
	logger *zap.Logger
}
// NewJobService builds a JobService that issues requests through client and
// reports through logger. Both are stored as-is; neither is validated.
func NewJobService(client *slurm.Client, logger *zap.Logger) *JobService {
	svc := new(JobService)
	svc.client = client
	svc.logger = logger
	return svc
}
// SubmitJob submits a batch job to Slurm and returns a JobResponse carrying
// the job ID Slurm assigned.
//
// Optional request fields (partition, QOS, job name, work directory) are sent
// only when non-empty, and CPUs only when positive. TimeLimit is interpreted
// as a whole number of minutes; a value that is not a non-negative integer in
// uint32 range is ignored (with a warning) instead of being forwarded.
//
// It returns a wrapped error if the Slurm API call fails or yields no body.
func (s *JobService) SubmitJob(ctx context.Context, req *model.SubmitJobRequest) (*model.JobResponse, error) {
	script := req.Script
	jobDesc := &slurm.JobDescMsg{
		Script:    &script,
		Partition: strToPtrOrNil(req.Partition),
		Qos:       strToPtrOrNil(req.QOS),
		Name:      strToPtrOrNil(req.JobName),
		// strToPtrOrNil copies the string, so the SDK struct does not alias
		// the caller's request (the previous &req.WorkDir did).
		CurrentWorkingDirectory: strToPtrOrNil(req.WorkDir),
	}
	if req.CPUs > 0 {
		jobDesc.MinimumCpus = slurm.Ptr(req.CPUs)
	}
	if req.TimeLimit != "" {
		// The SDK wraps time_limit in a Uint32NoVal, so parse as uint32:
		// this rejects negative and overflowing values that ParseInt(…, 64)
		// would have passed straight through to Slurm.
		if n, err := strconv.ParseUint(req.TimeLimit, 10, 32); err == nil {
			mins := int64(n)
			jobDesc.TimeLimit = &slurm.Uint32NoVal{Number: &mins}
		} else {
			// Previously this was dropped silently, so users got no time
			// limit without any indication why.
			s.logger.Warn("ignoring invalid time limit",
				zap.String("time_limit", req.TimeLimit),
				zap.Error(err),
			)
		}
	}
	submitReq := &slurm.JobSubmitReq{
		Script: &script,
		Job:    jobDesc,
	}
	s.logger.Debug("slurm API request",
		zap.String("operation", "SubmitJob"),
		zap.Any("body", submitReq),
	)
	start := time.Now()
	result, _, err := s.client.Jobs.SubmitJob(ctx, submitReq)
	took := time.Since(start)
	if err != nil {
		s.logger.Debug("slurm API error response",
			zap.String("operation", "SubmitJob"),
			zap.Duration("took", took),
			zap.Error(err),
		)
		s.logger.Error("failed to submit job", zap.Error(err), zap.String("operation", "submit"))
		return nil, fmt.Errorf("submit job: %w", err)
	}
	s.logger.Debug("slurm API response",
		zap.String("operation", "SubmitJob"),
		zap.Duration("took", took),
		zap.Any("body", result),
	)
	// Guard against a nil body with a nil error; dereferencing it below
	// would otherwise panic.
	if result == nil {
		s.logger.Error("failed to submit job", zap.String("operation", "submit"), zap.String("reason", "empty response"))
		return nil, fmt.Errorf("submit job: empty response from slurm")
	}
	resp := &model.JobResponse{}
	// The job ID may be reported under result.Result or at the top level;
	// prefer the nested one, as before.
	switch {
	case result.Result != nil && result.Result.JobID != nil:
		resp.JobID = *result.Result.JobID
	case result.JobID != nil:
		resp.JobID = *result.JobID
	default:
		s.logger.Warn("slurm response contained no job ID", zap.String("job_name", req.JobName))
	}
	s.logger.Info("job submitted", zap.String("job_name", req.JobName), zap.Int32("job_id", resp.JobID))
	return resp, nil
}
// GetJobs lists all jobs returned by the Slurm Jobs API and returns one page
// of them.
//
// Pagination is done in memory: the full list is fetched and mapped, then
// sliced. A nil query, page < 1, or page size < 1 fall back to page 1 and
// 20 items per page. A page past the end yields an empty Jobs slice, while
// Total still reports the full job count.
func (s *JobService) GetJobs(ctx context.Context, query *model.JobListQuery) (*model.JobListResponse, error) {
	if query == nil {
		// Treat a missing query as "all defaults" rather than panicking.
		query = &model.JobListQuery{}
	}
	s.logger.Debug("slurm API request",
		zap.String("operation", "GetJobs"),
	)
	start := time.Now()
	result, _, err := s.client.Jobs.GetJobs(ctx, nil)
	took := time.Since(start)
	if err != nil {
		s.logger.Debug("slurm API error response",
			zap.String("operation", "GetJobs"),
			zap.Duration("took", took),
			zap.Error(err),
		)
		s.logger.Error("failed to get jobs", zap.Error(err), zap.String("operation", "get_jobs"))
		return nil, fmt.Errorf("get jobs: %w", err)
	}
	s.logger.Debug("slurm API response",
		zap.String("operation", "GetJobs"),
		zap.Duration("took", took),
		zap.Int("job_count", len(result.Jobs)),
		zap.Any("body", result),
	)
	allJobs := make([]model.JobResponse, 0, len(result.Jobs))
	for i := range result.Jobs {
		allJobs = append(allJobs, mapJobInfo(&result.Jobs[i]))
	}
	total := len(allJobs)

	page := query.Page
	if page < 1 {
		page = 1
	}
	pageSize := query.PageSize
	if pageSize < 1 {
		pageSize = 20
	}

	// Compute the window without overflow. (page-1)*pageSize can exceed int
	// for a huge page number, and the resulting negative index would panic
	// when slicing. Comparing against total/pageSize first guarantees the
	// product is <= total. The end bound is likewise derived without
	// evaluating startIdx+pageSize when pageSize is huge.
	startIdx, end := total, total
	if page-1 <= total/pageSize {
		startIdx = (page - 1) * pageSize
	}
	if pageSize < total-startIdx {
		end = startIdx + pageSize
	}
	return &model.JobListResponse{
		Jobs:     allJobs[startIdx:end],
		Total:    total,
		Page:     page,
		PageSize: pageSize,
	}, nil
}
// GetJob fetches one job by ID from the active Slurm queue.
//
// When the active-queue lookup reports not-found, or succeeds but returns no
// jobs, the lookup is retried against SlurmDBD history via getJobFromHistory,
// and that call's result (which may be nil, nil) is returned unchanged. Any
// other API error is returned wrapped with the job ID.
func (s *JobService) GetJob(ctx context.Context, jobID string) (*model.JobResponse, error) {
	idField := zap.String("job_id", jobID)
	s.logger.Debug("slurm API request",
		zap.String("operation", "GetJob"),
		idField,
	)

	began := time.Now()
	active, _, err := s.client.Jobs.GetJob(ctx, jobID, nil)
	elapsed := time.Since(began)

	switch {
	case err != nil && slurm.IsNotFound(err):
		s.logger.Debug("job not in active queue, querying history",
			idField,
		)
		return s.getJobFromHistory(ctx, jobID)
	case err != nil:
		s.logger.Debug("slurm API error response",
			zap.String("operation", "GetJob"),
			idField,
			zap.Duration("took", elapsed),
			zap.Error(err),
		)
		s.logger.Error("failed to get job", zap.Error(err), idField, zap.String("operation", "get_job"))
		return nil, fmt.Errorf("get job %s: %w", jobID, err)
	}

	s.logger.Debug("slurm API response",
		zap.String("operation", "GetJob"),
		idField,
		zap.Duration("took", elapsed),
		zap.Any("body", active),
	)

	if len(active.Jobs) > 0 {
		job := mapJobInfo(&active.Jobs[0])
		return &job, nil
	}

	s.logger.Debug("empty jobs response, querying history",
		idField,
	)
	return s.getJobFromHistory(ctx, jobID)
}
// getJobFromHistory looks up a single job in SlurmDBD accounting history.
//
// It returns (nil, nil) when SlurmDBD reports not-found or returns an empty
// job list, so callers must nil-check the response. Other API errors are
// returned wrapped with the job ID.
func (s *JobService) getJobFromHistory(ctx context.Context, jobID string) (*model.JobResponse, error) {
	began := time.Now()
	hist, _, err := s.client.SlurmdbJobs.GetJob(ctx, jobID)
	elapsed := time.Since(began)

	if err != nil {
		s.logger.Debug("slurmdb API error response",
			zap.String("operation", "getJobFromHistory"),
			zap.String("job_id", jobID),
			zap.Duration("took", elapsed),
			zap.Error(err),
		)
		if !slurm.IsNotFound(err) {
			return nil, fmt.Errorf("get job history %s: %w", jobID, err)
		}
		// Unknown to SlurmDBD as well: signal "no such job" with nil, nil.
		return nil, nil
	}

	s.logger.Debug("slurmdb API response",
		zap.String("operation", "getJobFromHistory"),
		zap.String("job_id", jobID),
		zap.Duration("took", elapsed),
		zap.Any("body", hist),
	)

	if len(hist.Jobs) > 0 {
		job := mapSlurmdbJob(&hist.Jobs[0])
		return &job, nil
	}
	return nil, nil
}
// CancelJob cancels the job with the given ID by issuing a Slurm DeleteJob
// request. It returns a wrapped error when the request fails.
func (s *JobService) CancelJob(ctx context.Context, jobID string) error {
	idField := zap.String("job_id", jobID)
	s.logger.Debug("slurm API request",
		zap.String("operation", "CancelJob"),
		idField,
	)

	began := time.Now()
	reply, _, err := s.client.Jobs.DeleteJob(ctx, jobID, nil)
	elapsed := time.Since(began)

	if err != nil {
		s.logger.Debug("slurm API error response",
			zap.String("operation", "CancelJob"),
			idField,
			zap.Duration("took", elapsed),
			zap.Error(err),
		)
		s.logger.Error("failed to cancel job", zap.Error(err), idField, zap.String("operation", "cancel"))
		return fmt.Errorf("cancel job %s: %w", jobID, err)
	}

	s.logger.Debug("slurm API response",
		zap.String("operation", "CancelJob"),
		idField,
		zap.Duration("took", elapsed),
		zap.Any("body", reply),
	)
	s.logger.Info("job cancelled", idField)
	return nil
}
// GetJobHistory queries SlurmDBD for historical jobs matching the query's
// filters and returns one page of them.
//
// Each non-empty string filter is forwarded to the SlurmDBD request. Empty
// filters are omitted. Pagination is done in memory after fetching all
// matches. A nil query, page < 1, or page size < 1 fall back to page 1 and
// 20 items per page. A page past the end yields an empty Jobs slice, while
// Total still reports the full match count.
func (s *JobService) GetJobHistory(ctx context.Context, query *model.JobHistoryQuery) (*model.JobListResponse, error) {
	if query == nil {
		// Treat a missing query as "no filters, default page" rather than
		// panicking on the first field access.
		query = &model.JobHistoryQuery{}
	}
	// strToPtrOrNil yields nil for "", which leaves the option unset, exactly
	// like the zero value of an omitted field.
	opts := &slurm.GetSlurmdbJobsOptions{
		Users:       strToPtrOrNil(query.Users),
		Account:     strToPtrOrNil(query.Account),
		Partition:   strToPtrOrNil(query.Partition),
		State:       strToPtrOrNil(query.State),
		JobName:     strToPtrOrNil(query.JobName),
		StartTime:   strToPtrOrNil(query.StartTime),
		EndTime:     strToPtrOrNil(query.EndTime),
		SubmitTime:  strToPtrOrNil(query.SubmitTime),
		Cluster:     strToPtrOrNil(query.Cluster),
		Qos:         strToPtrOrNil(query.Qos),
		Constraints: strToPtrOrNil(query.Constraints),
		ExitCode:    strToPtrOrNil(query.ExitCode),
		Node:        strToPtrOrNil(query.Node),
		Reservation: strToPtrOrNil(query.Reservation),
		Groups:      strToPtrOrNil(query.Groups),
		Wckey:       strToPtrOrNil(query.Wckey),
	}
	s.logger.Debug("slurm API request",
		zap.String("operation", "GetJobHistory"),
		zap.Any("body", opts),
	)
	start := time.Now()
	result, _, err := s.client.SlurmdbJobs.GetJobs(ctx, opts)
	took := time.Since(start)
	if err != nil {
		s.logger.Debug("slurm API error response",
			zap.String("operation", "GetJobHistory"),
			zap.Duration("took", took),
			zap.Error(err),
		)
		s.logger.Error("failed to get job history", zap.Error(err), zap.String("operation", "get_job_history"))
		return nil, fmt.Errorf("get job history: %w", err)
	}
	s.logger.Debug("slurm API response",
		zap.String("operation", "GetJobHistory"),
		zap.Duration("took", took),
		zap.Int("job_count", len(result.Jobs)),
		zap.Any("body", result),
	)
	allJobs := make([]model.JobResponse, 0, len(result.Jobs))
	for i := range result.Jobs {
		allJobs = append(allJobs, mapSlurmdbJob(&result.Jobs[i]))
	}
	total := len(allJobs)

	page := query.Page
	if page < 1 {
		page = 1
	}
	pageSize := query.PageSize
	if pageSize < 1 {
		pageSize = 20
	}

	// Compute the window without overflow. (page-1)*pageSize can exceed int
	// for a huge page number, and the resulting negative index would panic
	// when slicing. Comparing against total/pageSize first guarantees the
	// product is <= total. The end bound is likewise derived without
	// evaluating startIdx+pageSize when pageSize is huge.
	startIdx, end := total, total
	if page-1 <= total/pageSize {
		startIdx = (page - 1) * pageSize
	}
	if pageSize < total-startIdx {
		end = startIdx + pageSize
	}
	return &model.JobListResponse{
		Jobs:     allJobs[startIdx:end],
		Total:    total,
		Page:     page,
		PageSize: pageSize,
	}, nil
}
// ---------------------------------------------------------------------------
// Helper functions
// ---------------------------------------------------------------------------
// strToPtr returns a pointer to a fresh copy of s. Unlike strToPtrOrNil,
// it never returns nil, even for the empty string.
func strToPtr(s string) *string {
	v := s
	return &v
}
// strToPtrOrNil returns a pointer to a copy of s when s is non-empty, and
// nil otherwise. It is used to leave optional string fields unset instead of
// sending an explicit empty string.
func strToPtrOrNil(s string) *string {
	if s == "" {
		return nil
	}
	return &s
}
// mapUint32NoValToInt32 unwraps a Slurm Uint32NoVal into an *int32.
// It returns nil when the wrapper or its Number is nil; otherwise it returns
// a pointer to a new int32 holding the converted value.
// NOTE(review): Number appears to be a wider integer type than int32; values
// above math.MaxInt32 would wrap on conversion. Confirm the SDK's value range.
func mapUint32NoValToInt32(v *slurm.Uint32NoVal) *int32 {
	if v == nil || v.Number == nil {
		return nil
	}
	out := int32(*v.Number)
	return &out
}
// mapJobInfo converts an active-queue JobInfo from the Slurm SDK into the API
// JobResponse. String-valued pointer fields go through derefStr and NoVal
// wrappers through mapUint32NoValToInt32 / uint32NoValString. Pointer fields
// that are nil in the SDK struct leave the response field at its zero value.
// The response's time and nodes fields reuse (alias) the SDK's pointers or
// values directly.
func mapJobInfo(ji *slurm.JobInfo) model.JobResponse {
	out := model.JobResponse{
		State:       ji.JobState,
		Account:     derefStr(ji.Account),
		User:        derefStr(ji.UserName),
		Cluster:     derefStr(ji.Cluster),
		QOS:         derefStr(ji.Qos),
		Priority:    mapUint32NoValToInt32(ji.Priority),
		TimeLimit:   uint32NoValString(ji.TimeLimit),
		StateReason: derefStr(ji.StateReason),
		Cpus:        mapUint32NoValToInt32(ji.Cpus),
		Tasks:       mapUint32NoValToInt32(ji.Tasks),
		NodeCount:   mapUint32NoValToInt32(ji.NodeCount),
		BatchHost:   derefStr(ji.BatchHost),
		StdOut:      derefStr(ji.StandardOutput),
		StdErr:      derefStr(ji.StandardError),
		StdIn:       derefStr(ji.StandardInput),
		WorkDir:     derefStr(ji.CurrentWorkingDirectory),
		Command:     derefStr(ji.Command),
		ArrayJobID:  mapUint32NoValToInt32(ji.ArrayJobID),
		ArrayTaskID: mapUint32NoValToInt32(ji.ArrayTaskID),
	}

	if ji.JobID != nil {
		out.JobID = *ji.JobID
	}
	if ji.Name != nil {
		out.Name = *ji.Name
	}
	if ji.Partition != nil {
		out.Partition = *ji.Partition
	}
	if ji.Nodes != nil {
		out.Nodes = *ji.Nodes
	}

	// Timestamps: copy the inner Number pointer only when both the wrapper
	// and its Number are present.
	if ts := ji.SubmitTime; ts != nil && ts.Number != nil {
		out.SubmitTime = ts.Number
	}
	if ts := ji.StartTime; ts != nil && ts.Number != nil {
		out.StartTime = ts.Number
	}
	if ts := ji.EndTime; ts != nil && ts.Number != nil {
		out.EndTime = ts.Number
	}

	if ec := ji.ExitCode; ec != nil && ec.ReturnCode != nil && ec.ReturnCode.Number != nil {
		rc := int32(*ec.ReturnCode.Number)
		out.ExitCode = &rc
	}
	return out
}
// mapSlurmdbJob converts a SlurmDBD (accounting history) Job from the SDK
// into the API JobResponse. Nested structs (State, Time, Required) are read
// only when present. Nil pointer fields leave the response field at its zero
// value. Unlike mapJobInfo, no stdio paths, command, task or array fields are
// populated, because this Job type's fields for them are not read here.
func mapSlurmdbJob(j *slurm.Job) model.JobResponse {
	out := model.JobResponse{
		Account:  derefStr(j.Account),
		Cluster:  derefStr(j.Cluster),
		QOS:      derefStr(j.Qos),
		Priority: mapUint32NoValToInt32(j.Priority),
		// A nil AllocationNodes simply leaves NodeCount nil.
		NodeCount: j.AllocationNodes,
		WorkDir:   derefStr(j.WorkingDirectory),
	}

	if j.JobID != nil {
		out.JobID = *j.JobID
	}
	if j.Name != nil {
		out.Name = *j.Name
	}
	if j.Partition != nil {
		out.Partition = *j.Partition
	}
	if j.User != nil {
		out.User = *j.User
	}
	if j.Nodes != nil {
		out.Nodes = *j.Nodes
	}

	if st := j.State; st != nil {
		out.State = st.Current
		out.StateReason = derefStr(st.Reason)
	}

	if tm := j.Time; tm != nil {
		out.TimeLimit = uint32NoValString(tm.Limit)
		// Nil timestamps leave the (already nil) response pointers unchanged.
		out.SubmitTime = tm.Submission
		out.StartTime = tm.Start
		out.EndTime = tm.End
	}

	if ec := j.ExitCode; ec != nil && ec.ReturnCode != nil && ec.ReturnCode.Number != nil {
		rc := int32(*ec.ReturnCode.Number)
		out.ExitCode = &rc
	}

	if req := j.Required; req != nil {
		out.Cpus = req.CPUs
	}
	return out
}