feat(service): add TaskService, FileStagingService, and refactor ApplicationService for task submission
This commit is contained in:
@@ -4,13 +4,8 @@ import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"regexp"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"gcy_hpc_server/internal/model"
|
||||
@@ -19,8 +14,6 @@ import (
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
var paramNameRegex = regexp.MustCompile(`^[A-Za-z_][A-Za-z0-9_]*$`)
|
||||
|
||||
// ApplicationService handles parameter validation, script rendering, and job
|
||||
// submission for parameterized HPC applications.
|
||||
type ApplicationService struct {
|
||||
@@ -28,92 +21,15 @@ type ApplicationService struct {
|
||||
jobSvc *JobService
|
||||
workDirBase string
|
||||
logger *zap.Logger
|
||||
taskSvc *TaskService
|
||||
}
|
||||
|
||||
func NewApplicationService(store *store.ApplicationStore, jobSvc *JobService, workDirBase string, logger *zap.Logger) *ApplicationService {
|
||||
return &ApplicationService{store: store, jobSvc: jobSvc, workDirBase: workDirBase, logger: logger}
|
||||
}
|
||||
|
||||
// ValidateParams checks that all required parameters are present and values match their types.
|
||||
// Parameters not in the schema are silently ignored.
|
||||
func (s *ApplicationService) ValidateParams(params []model.ParameterSchema, values map[string]string) error {
|
||||
var errs []string
|
||||
|
||||
for _, p := range params {
|
||||
if !paramNameRegex.MatchString(p.Name) {
|
||||
errs = append(errs, fmt.Sprintf("invalid parameter name %q: must match ^[A-Za-z_][A-Za-z0-9_]*$", p.Name))
|
||||
continue
|
||||
func NewApplicationService(store *store.ApplicationStore, jobSvc *JobService, workDirBase string, logger *zap.Logger, taskSvc ...*TaskService) *ApplicationService {
|
||||
var ts *TaskService
|
||||
if len(taskSvc) > 0 {
|
||||
ts = taskSvc[0]
|
||||
}
|
||||
|
||||
val, ok := values[p.Name]
|
||||
|
||||
if p.Required && !ok {
|
||||
errs = append(errs, fmt.Sprintf("required parameter %q is missing", p.Name))
|
||||
continue
|
||||
}
|
||||
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
switch p.Type {
|
||||
case model.ParamTypeInteger:
|
||||
if _, err := strconv.Atoi(val); err != nil {
|
||||
errs = append(errs, fmt.Sprintf("parameter %q must be an integer, got %q", p.Name, val))
|
||||
}
|
||||
case model.ParamTypeBoolean:
|
||||
if val != "true" && val != "false" && val != "1" && val != "0" {
|
||||
errs = append(errs, fmt.Sprintf("parameter %q must be a boolean (true/false/1/0), got %q", p.Name, val))
|
||||
}
|
||||
case model.ParamTypeEnum:
|
||||
if len(p.Options) > 0 {
|
||||
found := false
|
||||
for _, opt := range p.Options {
|
||||
if val == opt {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
errs = append(errs, fmt.Sprintf("parameter %q must be one of %v, got %q", p.Name, p.Options, val))
|
||||
}
|
||||
}
|
||||
case model.ParamTypeFile, model.ParamTypeDirectory:
|
||||
case model.ParamTypeString:
|
||||
}
|
||||
}
|
||||
|
||||
if len(errs) > 0 {
|
||||
return fmt.Errorf("parameter validation failed: %s", strings.Join(errs, "; "))
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// RenderScript replaces $PARAM tokens in the template with user-provided values.
|
||||
// Only tokens defined in the schema are replaced. Replacement is done longest-name-first
|
||||
// to avoid partial matches (e.g., $JOB_NAME before $JOB).
|
||||
// All values are shell-escaped using single-quote wrapping.
|
||||
func (s *ApplicationService) RenderScript(template string, params []model.ParameterSchema, values map[string]string) string {
|
||||
sorted := make([]model.ParameterSchema, len(params))
|
||||
copy(sorted, params)
|
||||
sort.Slice(sorted, func(i, j int) bool {
|
||||
return len(sorted[i].Name) > len(sorted[j].Name)
|
||||
})
|
||||
|
||||
result := template
|
||||
for _, p := range sorted {
|
||||
val, ok := values[p.Name]
|
||||
if !ok {
|
||||
if p.Default != "" {
|
||||
val = p.Default
|
||||
} else {
|
||||
continue
|
||||
}
|
||||
}
|
||||
escaped := "'" + strings.ReplaceAll(val, "'", "'\\''") + "'"
|
||||
result = strings.ReplaceAll(result, "$"+p.Name, escaped)
|
||||
}
|
||||
return result
|
||||
return &ApplicationService{store: store, jobSvc: jobSvc, workDirBase: workDirBase, logger: logger, taskSvc: ts}
|
||||
}
|
||||
|
||||
// ListApplications delegates to the store.
|
||||
@@ -141,13 +57,22 @@ func (s *ApplicationService) DeleteApplication(ctx context.Context, id int64) er
|
||||
return s.store.Delete(ctx, id)
|
||||
}
|
||||
|
||||
// SubmitFromApplication orchestrates the full submission flow:
|
||||
// 1. Fetch application by ID
|
||||
// 2. Parse parameters schema
|
||||
// 3. Validate parameter values
|
||||
// 4. Render script template
|
||||
// 5. Submit job via JobService
|
||||
// SubmitFromApplication orchestrates the full submission flow.
|
||||
// When TaskService is available, it delegates to ProcessTaskSync which creates
|
||||
// an hpc_tasks record and runs the full pipeline. Otherwise falls back to the
|
||||
// original direct implementation.
|
||||
func (s *ApplicationService) SubmitFromApplication(ctx context.Context, applicationID int64, values map[string]string) (*model.JobResponse, error) {
|
||||
if s.taskSvc != nil {
|
||||
req := &model.CreateTaskRequest{
|
||||
AppID: applicationID,
|
||||
Values: values,
|
||||
InputFileIDs: nil, // old API has no file_ids concept
|
||||
TaskName: "",
|
||||
}
|
||||
return s.taskSvc.ProcessTaskSync(ctx, req)
|
||||
}
|
||||
|
||||
// Fallback: original direct logic when TaskService not available
|
||||
app, err := s.store.GetByID(ctx, applicationID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get application: %w", err)
|
||||
@@ -163,16 +88,16 @@ func (s *ApplicationService) SubmitFromApplication(ctx context.Context, applicat
|
||||
}
|
||||
}
|
||||
|
||||
if err := s.ValidateParams(params, values); err != nil {
|
||||
if err := ValidateParams(params, values); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
rendered := s.RenderScript(app.ScriptTemplate, params, values)
|
||||
rendered := RenderScript(app.ScriptTemplate, params, values)
|
||||
|
||||
workDir := ""
|
||||
if s.workDirBase != "" {
|
||||
safeName := sanitizeDirName(app.Name)
|
||||
subDir := time.Now().Format("20060102_150405") + "_" + randomSuffix(4)
|
||||
safeName := SanitizeDirName(app.Name)
|
||||
subDir := time.Now().Format("20060102_150405") + "_" + RandomSuffix(4)
|
||||
workDir = filepath.Join(s.workDirBase, safeName, subDir)
|
||||
if err := os.MkdirAll(workDir, 0777); err != nil {
|
||||
return nil, fmt.Errorf("create work directory %s: %w", workDir, err)
|
||||
@@ -187,17 +112,3 @@ func (s *ApplicationService) SubmitFromApplication(ctx context.Context, applicat
|
||||
req := &model.SubmitJobRequest{Script: rendered, WorkDir: workDir}
|
||||
return s.jobSvc.SubmitJob(ctx, req)
|
||||
}
|
||||
|
||||
// sanitizeDirName makes an application name safe to use as a directory
// component by mapping spaces, path separators, and Windows-reserved
// punctuation to underscores. All other runes pass through unchanged.
func sanitizeDirName(name string) string {
	return strings.Map(func(r rune) rune {
		switch r {
		case ' ', '/', '\\', ':', '*', '?', '"', '<', '>', '|':
			return '_'
		}
		return r
	}, name)
}
|
||||
|
||||
// randomSuffix returns n random lowercase-alphanumeric characters, used to
// make generated work-directory names unique.
func randomSuffix(n int) string {
	const alphabet = "abcdefghijklmnopqrstuvwxyz0123456789"
	var sb strings.Builder
	sb.Grow(n)
	for i := 0; i < n; i++ {
		sb.WriteByte(alphabet[rand.Intn(len(alphabet))])
	}
	return sb.String()
}
|
||||
|
||||
@@ -5,6 +5,8 @@ import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
@@ -42,24 +44,22 @@ func setupApplicationService(t *testing.T, slurmHandler http.HandlerFunc) (*Appl
|
||||
}
|
||||
|
||||
func TestValidateParams_AllRequired(t *testing.T) {
|
||||
svc := NewApplicationService(nil, nil, "", zap.NewNop())
|
||||
params := []model.ParameterSchema{
|
||||
{Name: "NAME", Type: model.ParamTypeString, Required: true},
|
||||
{Name: "COUNT", Type: model.ParamTypeInteger, Required: true},
|
||||
}
|
||||
values := map[string]string{"NAME": "hello", "COUNT": "5"}
|
||||
if err := svc.ValidateParams(params, values); err != nil {
|
||||
if err := ValidateParams(params, values); err != nil {
|
||||
t.Errorf("expected no error, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateParams_MissingRequired(t *testing.T) {
|
||||
svc := NewApplicationService(nil, nil, "", zap.NewNop())
|
||||
params := []model.ParameterSchema{
|
||||
{Name: "NAME", Type: model.ParamTypeString, Required: true},
|
||||
}
|
||||
values := map[string]string{}
|
||||
err := svc.ValidateParams(params, values)
|
||||
err := ValidateParams(params, values)
|
||||
if err == nil {
|
||||
t.Fatal("expected error for missing required param")
|
||||
}
|
||||
@@ -69,12 +69,11 @@ func TestValidateParams_MissingRequired(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestValidateParams_InvalidInteger(t *testing.T) {
|
||||
svc := NewApplicationService(nil, nil, "", zap.NewNop())
|
||||
params := []model.ParameterSchema{
|
||||
{Name: "COUNT", Type: model.ParamTypeInteger, Required: true},
|
||||
}
|
||||
values := map[string]string{"COUNT": "abc"}
|
||||
err := svc.ValidateParams(params, values)
|
||||
err := ValidateParams(params, values)
|
||||
if err == nil {
|
||||
t.Fatal("expected error for invalid integer")
|
||||
}
|
||||
@@ -84,12 +83,11 @@ func TestValidateParams_InvalidInteger(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestValidateParams_InvalidEnum(t *testing.T) {
|
||||
svc := NewApplicationService(nil, nil, "", zap.NewNop())
|
||||
params := []model.ParameterSchema{
|
||||
{Name: "MODE", Type: model.ParamTypeEnum, Required: true, Options: []string{"fast", "slow"}},
|
||||
}
|
||||
values := map[string]string{"MODE": "medium"}
|
||||
err := svc.ValidateParams(params, values)
|
||||
err := ValidateParams(params, values)
|
||||
if err == nil {
|
||||
t.Fatal("expected error for invalid enum value")
|
||||
}
|
||||
@@ -99,12 +97,11 @@ func TestValidateParams_InvalidEnum(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestValidateParams_BooleanValues(t *testing.T) {
|
||||
svc := NewApplicationService(nil, nil, "", zap.NewNop())
|
||||
params := []model.ParameterSchema{
|
||||
{Name: "FLAG", Type: model.ParamTypeBoolean, Required: true},
|
||||
}
|
||||
for _, val := range []string{"true", "false", "1", "0"} {
|
||||
err := svc.ValidateParams(params, map[string]string{"FLAG": val})
|
||||
err := ValidateParams(params, map[string]string{"FLAG": val})
|
||||
if err != nil {
|
||||
t.Errorf("boolean value %q should be valid, got error: %v", val, err)
|
||||
}
|
||||
@@ -112,10 +109,9 @@ func TestValidateParams_BooleanValues(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestRenderScript_SimpleReplacement(t *testing.T) {
|
||||
svc := NewApplicationService(nil, nil, "", zap.NewNop())
|
||||
params := []model.ParameterSchema{{Name: "INPUT", Type: model.ParamTypeString}}
|
||||
values := map[string]string{"INPUT": "data.txt"}
|
||||
result := svc.RenderScript("echo $INPUT", params, values)
|
||||
result := RenderScript("echo $INPUT", params, values)
|
||||
expected := "echo 'data.txt'"
|
||||
if result != expected {
|
||||
t.Errorf("got %q, want %q", result, expected)
|
||||
@@ -123,10 +119,9 @@ func TestRenderScript_SimpleReplacement(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestRenderScript_DefaultValues(t *testing.T) {
|
||||
svc := NewApplicationService(nil, nil, "", zap.NewNop())
|
||||
params := []model.ParameterSchema{{Name: "OUTPUT", Type: model.ParamTypeString, Default: "out.log"}}
|
||||
values := map[string]string{}
|
||||
result := svc.RenderScript("cat $OUTPUT", params, values)
|
||||
result := RenderScript("cat $OUTPUT", params, values)
|
||||
expected := "cat 'out.log'"
|
||||
if result != expected {
|
||||
t.Errorf("got %q, want %q", result, expected)
|
||||
@@ -134,10 +129,9 @@ func TestRenderScript_DefaultValues(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestRenderScript_PreservesUnknownVars(t *testing.T) {
|
||||
svc := NewApplicationService(nil, nil, "", zap.NewNop())
|
||||
params := []model.ParameterSchema{{Name: "INPUT", Type: model.ParamTypeString}}
|
||||
values := map[string]string{"INPUT": "data.txt"}
|
||||
result := svc.RenderScript("export HOME=$HOME\necho $INPUT\necho $PATH", params, values)
|
||||
result := RenderScript("export HOME=$HOME\necho $INPUT\necho $PATH", params, values)
|
||||
if !strings.Contains(result, "$HOME") {
|
||||
t.Error("$HOME should be preserved")
|
||||
}
|
||||
@@ -150,7 +144,6 @@ func TestRenderScript_PreservesUnknownVars(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestRenderScript_ShellEscaping(t *testing.T) {
|
||||
svc := NewApplicationService(nil, nil, "", zap.NewNop())
|
||||
params := []model.ParameterSchema{{Name: "INPUT", Type: model.ParamTypeString}}
|
||||
|
||||
tests := []struct {
|
||||
@@ -165,7 +158,7 @@ func TestRenderScript_ShellEscaping(t *testing.T) {
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := svc.RenderScript("$INPUT", params, map[string]string{"INPUT": tt.value})
|
||||
result := RenderScript("$INPUT", params, map[string]string{"INPUT": tt.value})
|
||||
if result != tt.expected {
|
||||
t.Errorf("got %q, want %q", result, tt.expected)
|
||||
}
|
||||
@@ -174,14 +167,13 @@ func TestRenderScript_ShellEscaping(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestRenderScript_OverlappingParams(t *testing.T) {
|
||||
svc := NewApplicationService(nil, nil, "", zap.NewNop())
|
||||
template := "$JOB_NAME and $JOB"
|
||||
params := []model.ParameterSchema{
|
||||
{Name: "JOB", Type: model.ParamTypeString},
|
||||
{Name: "JOB_NAME", Type: model.ParamTypeString},
|
||||
}
|
||||
values := map[string]string{"JOB": "myjob", "JOB_NAME": "my-test-job"}
|
||||
result := svc.RenderScript(template, params, values)
|
||||
result := RenderScript(template, params, values)
|
||||
if strings.Contains(result, "$JOB_NAME") {
|
||||
t.Error("$JOB_NAME was not replaced")
|
||||
}
|
||||
@@ -197,10 +189,9 @@ func TestRenderScript_OverlappingParams(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestRenderScript_NewlineInValue(t *testing.T) {
|
||||
svc := NewApplicationService(nil, nil, "", zap.NewNop())
|
||||
params := []model.ParameterSchema{{Name: "CMD", Type: model.ParamTypeString}}
|
||||
values := map[string]string{"CMD": "line1\nline2"}
|
||||
result := svc.RenderScript("echo $CMD", params, values)
|
||||
result := RenderScript("echo $CMD", params, values)
|
||||
expected := "echo 'line1\nline2'"
|
||||
if result != expected {
|
||||
t.Errorf("got %q, want %q", result, expected)
|
||||
@@ -298,3 +289,68 @@ func TestSubmitFromApplication_NoParameters(t *testing.T) {
|
||||
t.Errorf("JobID = %d, want 99", resp.JobID)
|
||||
}
|
||||
}
|
||||
|
||||
// TestSubmitFromApplication_DelegatesToTaskService verifies that when an
// ApplicationService is built WITH a TaskService, SubmitFromApplication goes
// through the task pipeline: the job is submitted to (a fake) Slurm AND an
// hpc_tasks row is recorded with the Slurm job ID and queued status.
func TestSubmitFromApplication_DelegatesToTaskService(t *testing.T) {
	// Fake Slurm REST endpoint that always reports job ID 77.
	jobID := int32(77)
	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		json.NewEncoder(w).Encode(slurm.OpenapiJobSubmitResponse{
			Result: &slurm.JobSubmitResponseMsg{JobID: &jobID},
		})
	}))
	defer srv.Close()

	client, _ := slurm.NewClient(srv.URL, srv.Client())
	jobSvc := NewJobService(client, zap.NewNop())

	// In-memory DB with all tables the task pipeline touches.
	db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{
		Logger: gormlogger.Default.LogMode(gormlogger.Silent),
	})
	if err != nil {
		t.Fatalf("open sqlite: %v", err)
	}
	if err := db.AutoMigrate(&model.Application{}, &model.Task{}, &model.File{}, &model.FileBlob{}); err != nil {
		t.Fatalf("auto migrate: %v", err)
	}

	appStore := store.NewApplicationStore(db)
	taskStore := store.NewTaskStore(db)
	fileStore := store.NewFileStore(db)
	blobStore := store.NewBlobStore(db)

	workDirBase := filepath.Join(t.TempDir(), "workdir")
	os.MkdirAll(workDirBase, 0777)

	// NOTE: staging service is nil here; this flow submits without file inputs.
	taskSvc := NewTaskService(taskStore, appStore, fileStore, blobStore, nil, jobSvc, workDirBase, zap.NewNop())
	appSvc := NewApplicationService(appStore, jobSvc, workDirBase, zap.NewNop(), taskSvc)

	id, err := appStore.Create(context.Background(), &model.CreateApplicationRequest{
		Name:           "delegated-app",
		ScriptTemplate: "#!/bin/bash\n#SBATCH --job-name=$JOB_NAME\necho $INPUT",
		Parameters:     json.RawMessage(`[{"name":"JOB_NAME","type":"string","required":true},{"name":"INPUT","type":"string","required":true}]`),
	})
	if err != nil {
		t.Fatalf("create app: %v", err)
	}

	resp, err := appSvc.SubmitFromApplication(context.Background(), id, map[string]string{
		"JOB_NAME": "delegated-job",
		"INPUT":    "test-data",
	})
	if err != nil {
		t.Fatalf("SubmitFromApplication() error = %v", err)
	}
	if resp.JobID != 77 {
		t.Errorf("JobID = %d, want 77", resp.JobID)
	}

	// The delegation's observable side effect: a task row tied to the app,
	// carrying the Slurm job ID and left in queued state after submission.
	var task model.Task
	if err := db.Where("app_id = ?", id).First(&task).Error; err != nil {
		t.Fatalf("no hpc_tasks record found for app_id %d: %v", id, err)
	}
	if task.SlurmJobID == nil || *task.SlurmJobID != 77 {
		t.Errorf("task SlurmJobID = %v, want 77", task.SlurmJobID)
	}
	if task.Status != model.TaskStatusQueued {
		t.Errorf("task Status = %q, want %q", task.Status, model.TaskStatusQueued)
	}
}
|
||||
|
||||
145
internal/service/file_staging_service.go
Normal file
145
internal/service/file_staging_service.go
Normal file
@@ -0,0 +1,145 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
"gcy_hpc_server/internal/model"
|
||||
"gcy_hpc_server/internal/storage"
|
||||
"gcy_hpc_server/internal/store"
|
||||
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// FileStagingService batch downloads files from MinIO to a local (NFS) directory,
// deduplicating by blob SHA256 so each unique blob is fetched only once.
type FileStagingService struct {
	fileStore *store.FileStore      // logical file records (name + blob hash)
	blobStore *store.BlobStore      // content-addressed blob records (SHA256 -> MinIO key)
	storage   storage.ObjectStorage // object-store client used to fetch blob bytes
	bucket    string                // bucket all blob objects are stored under
	logger    *zap.Logger
}
|
||||
|
||||
func NewFileStagingService(fileStore *store.FileStore, blobStore *store.BlobStore, st storage.ObjectStorage, bucket string, logger *zap.Logger) *FileStagingService {
|
||||
return &FileStagingService{
|
||||
fileStore: fileStore,
|
||||
blobStore: blobStore,
|
||||
storage: st,
|
||||
bucket: bucket,
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
// DownloadFilesToDir downloads the given files into destDir.
// Files sharing the same blob SHA256 are deduplicated: the blob is fetched once
// and then copied to each filename. Filenames are sanitized with filepath.Base
// to prevent path traversal.
func (s *FileStagingService) DownloadFilesToDir(ctx context.Context, fileIDs []int64, destDir string) error {
	// Nothing to stage — succeed without touching the DB or object store.
	if len(fileIDs) == 0 {
		return nil
	}

	files, err := s.fileStore.GetByIDs(ctx, fileIDs)
	if err != nil {
		return fmt.Errorf("fetch files: %w", err)
	}

	// Group files by blob hash. The first file seen for a hash becomes the
	// "primary" and is streamed from MinIO; the rest are local copies of it.
	type group struct {
		primary *model.File // first file — written via io.Copy from MinIO
		others []*model.File // remaining files — local copy of primary
	}
	groups := make(map[string]*group)
	for i := range files {
		// Take a pointer into the slice (not the range variable) so the
		// stored *model.File stays valid after the loop.
		f := &files[i]
		g, ok := groups[f.BlobSHA256]
		if !ok {
			groups[f.BlobSHA256] = &group{primary: f}
		} else {
			g.others = append(g.others, f)
		}
	}

	// Resolve every distinct hash to its blob record in one batch query.
	sha256s := make([]string, 0, len(groups))
	for sh := range groups {
		sha256s = append(sha256s, sh)
	}
	blobs, err := s.blobStore.GetBySHA256s(ctx, sha256s)
	if err != nil {
		return fmt.Errorf("fetch blobs: %w", err)
	}

	blobMap := make(map[string]*model.FileBlob, len(blobs))
	for i := range blobs {
		blobMap[blobs[i].SHA256] = &blobs[i]
	}

	for sha256, g := range groups {
		blob, ok := blobMap[sha256]
		if !ok {
			// A file row points at a hash with no blob record — data
			// inconsistency; abort rather than stage a partial set.
			return fmt.Errorf("blob %s not found", sha256)
		}

		reader, _, err := s.storage.GetObject(ctx, s.bucket, blob.MinioKey, storage.GetOptions{})
		if err != nil {
			return fmt.Errorf("get object %s: %w", blob.MinioKey, err)
		}

		// TODO: handle filename collisions when multiple files have the same Name (low risk without user auth, revisit when auth is added)
		primaryName := filepath.Base(g.primary.Name)
		primaryPath := filepath.Join(destDir, primaryName)

		// Close the reader explicitly (not defer) so readers don't pile up
		// across loop iterations; remove the partial file on write failure.
		if err := writeFile(primaryPath, reader); err != nil {
			reader.Close()
			os.Remove(primaryPath)
			return fmt.Errorf("write file %s: %w", primaryName, err)
		}
		reader.Close()

		// Fan the already-downloaded primary out to the duplicate filenames.
		for _, other := range g.others {
			otherName := filepath.Base(other.Name)
			otherPath := filepath.Join(destDir, otherName)

			if err := copyFile(primaryPath, otherPath); err != nil {
				return fmt.Errorf("copy %s to %s: %w", primaryName, otherName, err)
			}
		}
	}

	return nil
}
|
||||
|
||||
// writeFile creates (or truncates) the file at path and streams reader into
// it. The Close error is propagated instead of being swallowed by a defer:
// on network filesystems (the staging target is NFS) buffered data may only
// be flushed at close time, so ignoring it could silently leave a truncated
// file on disk.
func writeFile(path string, reader io.Reader) error {
	f, err := os.Create(path)
	if err != nil {
		return err
	}
	if _, err := io.Copy(f, reader); err != nil {
		f.Close() // best-effort cleanup; the copy error is the primary failure
		return err
	}
	return f.Close()
}
|
||||
|
||||
// copyFile copies src to dst (creating or truncating dst). The destination's
// Close error is returned rather than dropped by a defer, since a failed
// close on a network filesystem can mean the copied bytes never reached disk.
// The source is read-only, so its deferred Close can be safely ignored.
func copyFile(src, dst string) error {
	in, err := os.Open(src)
	if err != nil {
		return err
	}
	defer in.Close()

	out, err := os.Create(dst)
	if err != nil {
		return err
	}
	if _, err := io.Copy(out, in); err != nil {
		out.Close() // best-effort cleanup; the copy error is the primary failure
		return err
	}
	return out.Close()
}
|
||||
232
internal/service/file_staging_service_test.go
Normal file
232
internal/service/file_staging_service_test.go
Normal file
@@ -0,0 +1,232 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
|
||||
"gcy_hpc_server/internal/model"
|
||||
"gcy_hpc_server/internal/storage"
|
||||
"gcy_hpc_server/internal/store"
|
||||
|
||||
"go.uber.org/zap"
|
||||
"gorm.io/driver/sqlite"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// stagingMockStorage is a test stub for storage.ObjectStorage. Only GetObject
// is configurable (via getObjectFn); every other method is an inert no-op
// present solely to satisfy the interface.
type stagingMockStorage struct {
	// getObjectFn, when set, handles GetObject calls; nil yields zero values.
	getObjectFn func(ctx context.Context, bucket, key string, opts storage.GetOptions) (io.ReadCloser, storage.ObjectInfo, error)
}

// GetObject delegates to getObjectFn when configured.
func (m *stagingMockStorage) GetObject(ctx context.Context, bucket, key string, opts storage.GetOptions) (io.ReadCloser, storage.ObjectInfo, error) {
	if m.getObjectFn != nil {
		return m.getObjectFn(ctx, bucket, key, opts)
	}
	return nil, storage.ObjectInfo{}, nil
}

// The remaining methods are interface-compliance stubs and always succeed.

func (m *stagingMockStorage) PutObject(ctx context.Context, bucket, key string, reader io.Reader, size int64, opts storage.PutObjectOptions) (storage.UploadInfo, error) {
	return storage.UploadInfo{}, nil
}
func (m *stagingMockStorage) ComposeObject(ctx context.Context, bucket, dst string, sources []string) (storage.UploadInfo, error) {
	return storage.UploadInfo{}, nil
}
func (m *stagingMockStorage) AbortMultipartUpload(ctx context.Context, bucket, object, uploadID string) error {
	return nil
}
func (m *stagingMockStorage) RemoveIncompleteUpload(ctx context.Context, bucket, object string) error {
	return nil
}
func (m *stagingMockStorage) RemoveObject(ctx context.Context, bucket, key string, opts storage.RemoveObjectOptions) error {
	return nil
}
func (m *stagingMockStorage) ListObjects(ctx context.Context, bucket, prefix string, recursive bool) ([]storage.ObjectInfo, error) {
	return nil, nil
}
func (m *stagingMockStorage) RemoveObjects(ctx context.Context, bucket string, keys []string, opts storage.RemoveObjectsOptions) error {
	return nil
}
func (m *stagingMockStorage) BucketExists(ctx context.Context, bucket string) (bool, error) {
	return true, nil
}
func (m *stagingMockStorage) MakeBucket(ctx context.Context, bucket string, opts storage.MakeBucketOptions) error {
	return nil
}
func (m *stagingMockStorage) StatObject(ctx context.Context, bucket, key string, opts storage.StatObjectOptions) (storage.ObjectInfo, error) {
	return storage.ObjectInfo{}, nil
}
|
||||
|
||||
func setupStagingTestDB(t *testing.T) *gorm.DB {
|
||||
t.Helper()
|
||||
db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{})
|
||||
if err != nil {
|
||||
t.Fatalf("open sqlite: %v", err)
|
||||
}
|
||||
if err := db.AutoMigrate(&model.FileBlob{}, &model.File{}); err != nil {
|
||||
t.Fatalf("migrate: %v", err)
|
||||
}
|
||||
return db
|
||||
}
|
||||
|
||||
func newStagingService(t *testing.T, st storage.ObjectStorage, db *gorm.DB) *FileStagingService {
|
||||
t.Helper()
|
||||
return NewFileStagingService(
|
||||
store.NewFileStore(db),
|
||||
store.NewBlobStore(db),
|
||||
st,
|
||||
"test-bucket",
|
||||
zap.NewNop(),
|
||||
)
|
||||
}
|
||||
|
||||
// TestFileStaging_DownloadWithDedup verifies blob-level deduplication: three
// files backed by two distinct blobs must cause exactly two GetObject calls,
// while every filename still ends up with the correct content.
func TestFileStaging_DownloadWithDedup(t *testing.T) {
	db := setupStagingTestDB(t)

	sha1 := "aaa111"
	sha2 := "bbb222"

	// Two blobs: sha1 is shared by two files (RefCount 2), sha2 by one.
	db.Create(&model.FileBlob{SHA256: sha1, MinioKey: "blobs/aaa111", FileSize: 5, MimeType: "text/plain", RefCount: 2})
	db.Create(&model.FileBlob{SHA256: sha2, MinioKey: "blobs/bbb222", FileSize: 3, MimeType: "text/plain", RefCount: 1})

	db.Create(&model.File{Name: "file1.txt", BlobSHA256: sha1})
	db.Create(&model.File{Name: "file2.txt", BlobSHA256: sha1})
	db.Create(&model.File{Name: "file3.txt", BlobSHA256: sha2})

	var files []model.File
	db.Find(&files)
	if len(files) < 3 {
		t.Fatalf("need 3 files, got %d", len(files))
	}

	// Count fetches atomically so the assertion holds even if the service
	// were to fetch concurrently.
	var getObjCalls int32
	st := &stagingMockStorage{}
	st.getObjectFn = func(ctx context.Context, bucket, key string, opts storage.GetOptions) (io.ReadCloser, storage.ObjectInfo, error) {
		atomic.AddInt32(&getObjCalls, 1)
		var content string
		switch key {
		case "blobs/aaa111":
			content = "content-a"
		case "blobs/bbb222":
			content = "content-b"
		default:
			return nil, storage.ObjectInfo{}, fmt.Errorf("unexpected key %s", key)
		}
		return io.NopCloser(bytes.NewReader([]byte(content))), storage.ObjectInfo{Key: key}, nil
	}

	destDir := t.TempDir()
	svc := newStagingService(t, st, db)

	err := svc.DownloadFilesToDir(context.Background(), []int64{files[0].ID, files[1].ID, files[2].ID}, destDir)
	if err != nil {
		t.Fatalf("DownloadFilesToDir: %v", err)
	}

	// One fetch per unique blob — not per file.
	if calls := atomic.LoadInt32(&getObjCalls); calls != 2 {
		t.Errorf("GetObject called %d times, want 2", calls)
	}

	// Dedup must not mix up which content lands under which filename.
	expected := map[string]string{
		"file1.txt": "content-a",
		"file2.txt": "content-a",
		"file3.txt": "content-b",
	}
	for name, want := range expected {
		p := filepath.Join(destDir, name)
		data, err := os.ReadFile(p)
		if err != nil {
			t.Errorf("read %s: %v", name, err)
			continue
		}
		if string(data) != want {
			t.Errorf("%s content = %q, want %q", name, data, want)
		}
	}
}
|
||||
|
||||
// TestFileStaging_PathTraversal verifies that a stored filename containing
// "../" segments cannot escape destDir: only the base name ("passwd") is
// written, inside the destination directory, and nothing else is created.
func TestFileStaging_PathTraversal(t *testing.T) {
	db := setupStagingTestDB(t)

	sha := "traversal123"
	db.Create(&model.FileBlob{SHA256: sha, MinioKey: "blobs/traversal", FileSize: 4, MimeType: "text/plain", RefCount: 1})
	// Hostile name attempting to climb out of the staging directory.
	db.Create(&model.File{Name: "../../../etc/passwd", BlobSHA256: sha})

	var file model.File
	db.First(&file)

	st := &stagingMockStorage{}
	st.getObjectFn = func(ctx context.Context, bucket, key string, opts storage.GetOptions) (io.ReadCloser, storage.ObjectInfo, error) {
		return io.NopCloser(bytes.NewReader([]byte("safe"))), storage.ObjectInfo{Key: key}, nil
	}

	destDir := t.TempDir()
	svc := newStagingService(t, st, db)

	err := svc.DownloadFilesToDir(context.Background(), []int64{file.ID}, destDir)
	if err != nil {
		t.Fatalf("DownloadFilesToDir: %v", err)
	}

	// The sanitized result must be destDir/passwd with the fetched content.
	sanitized := filepath.Join(destDir, "passwd")
	data, err := os.ReadFile(sanitized)
	if err != nil {
		t.Fatalf("read sanitized file: %v", err)
	}
	if string(data) != "safe" {
		t.Errorf("content = %q, want %q", data, "safe")
	}

	// And destDir must contain nothing besides the sanitized file.
	entries, err := os.ReadDir(destDir)
	if err != nil {
		t.Fatalf("readdir: %v", err)
	}
	for _, e := range entries {
		if e.Name() != "passwd" {
			t.Errorf("unexpected file in destDir: %s", e.Name())
		}
	}
}
|
||||
|
||||
func TestFileStaging_EmptyList(t *testing.T) {
|
||||
db := setupStagingTestDB(t)
|
||||
st := &stagingMockStorage{}
|
||||
svc := newStagingService(t, st, db)
|
||||
|
||||
err := svc.DownloadFilesToDir(context.Background(), []int64{}, t.TempDir())
|
||||
if err != nil {
|
||||
t.Errorf("expected nil for empty list, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestFileStaging_GetObjectFails verifies that an object-store failure is
// surfaced to the caller with the underlying cause kept in the message.
func TestFileStaging_GetObjectFails(t *testing.T) {
	db := setupStagingTestDB(t)

	sha := "fail123"
	db.Create(&model.FileBlob{SHA256: sha, MinioKey: "blobs/fail", FileSize: 5, MimeType: "text/plain", RefCount: 1})
	db.Create(&model.File{Name: "willfail.txt", BlobSHA256: sha})

	var file model.File
	db.First(&file)

	// Storage mock that always errors, simulating MinIO being unreachable.
	st := &stagingMockStorage{}
	st.getObjectFn = func(ctx context.Context, bucket, key string, opts storage.GetOptions) (io.ReadCloser, storage.ObjectInfo, error) {
		return nil, storage.ObjectInfo{}, fmt.Errorf("minio down")
	}

	destDir := t.TempDir()
	svc := newStagingService(t, st, db)

	err := svc.DownloadFilesToDir(context.Background(), []int64{file.ID}, destDir)
	if err == nil {
		t.Fatal("expected error when GetObject fails")
	}
	// The wrapped error chain must retain the root cause text.
	if !strings.Contains(err.Error(), "minio down") {
		t.Errorf("error = %q, want 'minio down'", err.Error())
	}
}
|
||||
112
internal/service/script_utils.go
Normal file
112
internal/service/script_utils.go
Normal file
@@ -0,0 +1,112 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"regexp"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"gcy_hpc_server/internal/model"
|
||||
)
|
||||
|
||||
var paramNameRegex = regexp.MustCompile(`^[A-Za-z_][A-Za-z0-9_]*$`)
|
||||
|
||||
// ValidateParams checks that all required parameters are present and values match their types.
|
||||
// Parameters not in the schema are silently ignored.
|
||||
func ValidateParams(params []model.ParameterSchema, values map[string]string) error {
|
||||
var errs []string
|
||||
|
||||
for _, p := range params {
|
||||
if !paramNameRegex.MatchString(p.Name) {
|
||||
errs = append(errs, fmt.Sprintf("invalid parameter name %q: must match ^[A-Za-z_][A-Za-z0-9_]*$", p.Name))
|
||||
continue
|
||||
}
|
||||
|
||||
val, ok := values[p.Name]
|
||||
|
||||
if p.Required && !ok {
|
||||
errs = append(errs, fmt.Sprintf("required parameter %q is missing", p.Name))
|
||||
continue
|
||||
}
|
||||
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
switch p.Type {
|
||||
case model.ParamTypeInteger:
|
||||
if _, err := strconv.Atoi(val); err != nil {
|
||||
errs = append(errs, fmt.Sprintf("parameter %q must be an integer, got %q", p.Name, val))
|
||||
}
|
||||
case model.ParamTypeBoolean:
|
||||
if val != "true" && val != "false" && val != "1" && val != "0" {
|
||||
errs = append(errs, fmt.Sprintf("parameter %q must be a boolean (true/false/1/0), got %q", p.Name, val))
|
||||
}
|
||||
case model.ParamTypeEnum:
|
||||
if len(p.Options) > 0 {
|
||||
found := false
|
||||
for _, opt := range p.Options {
|
||||
if val == opt {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
errs = append(errs, fmt.Sprintf("parameter %q must be one of %v, got %q", p.Name, p.Options, val))
|
||||
}
|
||||
}
|
||||
case model.ParamTypeFile, model.ParamTypeDirectory:
|
||||
case model.ParamTypeString:
|
||||
}
|
||||
}
|
||||
|
||||
if len(errs) > 0 {
|
||||
return fmt.Errorf("parameter validation failed: %s", strings.Join(errs, "; "))
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// RenderScript replaces $PARAM tokens in the template with user-provided values.
|
||||
// Only tokens defined in the schema are replaced. Replacement is done longest-name-first
|
||||
// to avoid partial matches (e.g., $JOB_NAME before $JOB).
|
||||
// All values are shell-escaped using single-quote wrapping.
|
||||
func RenderScript(template string, params []model.ParameterSchema, values map[string]string) string {
|
||||
sorted := make([]model.ParameterSchema, len(params))
|
||||
copy(sorted, params)
|
||||
sort.Slice(sorted, func(i, j int) bool {
|
||||
return len(sorted[i].Name) > len(sorted[j].Name)
|
||||
})
|
||||
|
||||
result := template
|
||||
for _, p := range sorted {
|
||||
val, ok := values[p.Name]
|
||||
if !ok {
|
||||
if p.Default != "" {
|
||||
val = p.Default
|
||||
} else {
|
||||
continue
|
||||
}
|
||||
}
|
||||
escaped := "'" + strings.ReplaceAll(val, "'", "'\\''") + "'"
|
||||
result = strings.ReplaceAll(result, "$"+p.Name, escaped)
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// SanitizeDirName sanitizes a directory name by mapping characters that are
// unsafe in file paths (space, path separators, and common filesystem
// metacharacters) to underscores; all other runes pass through unchanged.
func SanitizeDirName(name string) string {
	return strings.Map(func(r rune) rune {
		switch r {
		case ' ', '/', '\\', ':', '*', '?', '"', '<', '>', '|':
			return '_'
		}
		return r
	}, name)
}
|
||||
|
||||
// RandomSuffix generates a random suffix of length n drawn from lowercase
// letters and digits. Not cryptographically secure (math/rand); intended for
// directory-name uniqueness, not secrets.
func RandomSuffix(n int) string {
	const alphabet = "abcdefghijklmnopqrstuvwxyz0123456789"
	var sb strings.Builder
	sb.Grow(n)
	for i := 0; i < n; i++ {
		sb.WriteByte(alphabet[rand.Intn(len(alphabet))])
	}
	return sb.String()
}
|
||||
554
internal/service/task_service.go
Normal file
554
internal/service/task_service.go
Normal file
@@ -0,0 +1,554 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"gcy_hpc_server/internal/model"
|
||||
"gcy_hpc_server/internal/store"
|
||||
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// TaskService orchestrates the lifecycle of parameterized HPC tasks:
// creation, asynchronous processing (work-dir setup, file staging, script
// rendering, Slurm submission), status refresh from Slurm, bounded retry,
// and recovery of tasks stuck by a crash or overflow.
type TaskService struct {
	taskStore   *store.TaskStore
	appStore    *store.ApplicationStore
	fileStore   *store.FileStore    // nil ok
	blobStore   *store.BlobStore    // nil ok
	stagingSvc  *FileStagingService // nil ok — MinIO unavailable
	jobSvc      *JobService
	workDirBase string
	logger      *zap.Logger

	// async processing
	taskCh   chan int64 // buffered channel, cap=16
	cancelFn context.CancelFunc
	wg       sync.WaitGroup
	mu       sync.Mutex // protects taskCh from send-on-closed
	started  bool       // prevent double-start
	stopped  bool
}
|
||||
|
||||
func NewTaskService(
|
||||
taskStore *store.TaskStore,
|
||||
appStore *store.ApplicationStore,
|
||||
fileStore *store.FileStore,
|
||||
blobStore *store.BlobStore,
|
||||
stagingSvc *FileStagingService,
|
||||
jobSvc *JobService,
|
||||
workDirBase string,
|
||||
logger *zap.Logger,
|
||||
) *TaskService {
|
||||
return &TaskService{
|
||||
taskStore: taskStore,
|
||||
appStore: appStore,
|
||||
fileStore: fileStore,
|
||||
blobStore: blobStore,
|
||||
stagingSvc: stagingSvc,
|
||||
jobSvc: jobSvc,
|
||||
workDirBase: workDirBase,
|
||||
logger: logger,
|
||||
taskCh: make(chan int64, 16),
|
||||
}
|
||||
}
|
||||
|
||||
// CreateTask validates the request against its application and input files,
// persists the task in status "submitted", and returns it with its new ID.
// It does NOT start processing — callers use ProcessTask, ProcessTaskSync,
// or SubmitAsync for that.
func (s *TaskService) CreateTask(ctx context.Context, req *model.CreateTaskRequest) (*model.Task, error) {
	// 1. The referenced application must exist.
	app, err := s.appStore.GetByID(ctx, req.AppID)
	if err != nil {
		return nil, fmt.Errorf("get application: %w", err)
	}
	if app == nil {
		return nil, fmt.Errorf("application %d not found", req.AppID)
	}

	// 2. Validate file limit
	if len(req.InputFileIDs) > 100 {
		return nil, fmt.Errorf("input file count %d exceeds limit of 100", len(req.InputFileIDs))
	}

	// 3. Deduplicate file IDs
	fileIDs := uniqueInt64s(req.InputFileIDs)

	// 4. Validate file IDs exist (skipped when no file store is configured).
	if s.fileStore != nil && len(fileIDs) > 0 {
		files, err := s.fileStore.GetByIDs(ctx, fileIDs)
		if err != nil {
			return nil, fmt.Errorf("validate file ids: %w", err)
		}
		found := make(map[int64]bool, len(files))
		for _, f := range files {
			found[f.ID] = true
		}
		for _, id := range fileIDs {
			if !found[id] {
				return nil, fmt.Errorf("file %d not found", id)
			}
		}
	}

	// 5. Auto-generate task name if empty
	taskName := req.TaskName
	if taskName == "" {
		taskName = SanitizeDirName(app.Name) + "_" + time.Now().Format("20060102_150405")
	}

	// 6. Marshal values (defaults to the empty JSON object).
	valuesJSON := json.RawMessage(`{}`)
	if len(req.Values) > 0 {
		b, err := json.Marshal(req.Values)
		if err != nil {
			return nil, fmt.Errorf("marshal values: %w", err)
		}
		valuesJSON = b
	}

	// 7. Marshal input_file_ids (defaults to the empty JSON array).
	fileIDsJSON := json.RawMessage(`[]`)
	if len(fileIDs) > 0 {
		b, err := json.Marshal(fileIDs)
		if err != nil {
			return nil, fmt.Errorf("marshal file ids: %w", err)
		}
		fileIDsJSON = b
	}

	// 8. Create task record
	task := &model.Task{
		TaskName:     taskName,
		AppID:        app.ID,
		AppName:      app.Name,
		Status:       model.TaskStatusSubmitted,
		Values:       valuesJSON,
		InputFileIDs: fileIDsJSON,
		SubmittedAt:  time.Now(),
	}

	taskID, err := s.taskStore.Create(ctx, task)
	if err != nil {
		return nil, fmt.Errorf("create task: %w", err)
	}
	task.ID = taskID

	return task, nil
}
|
||||
|
||||
// ProcessTask runs the full synchronous processing pipeline for a task.
// The pipeline is resumable: task.CurrentStep decides which stages to
// (re-)execute, so a retried task skips work that already completed.
// Stages: preparing (work dir) → downloading (file staging) → submitting
// (validate, render, submit to Slurm). On any failure the task is marked
// failed with the offending step recorded, and the error is returned.
func (s *TaskService) ProcessTask(ctx context.Context, taskID int64) error {
	// 1. Fetch task
	task, err := s.taskStore.GetByID(ctx, taskID)
	if err != nil {
		return fmt.Errorf("get task: %w", err)
	}
	if task == nil {
		return fmt.Errorf("task %d not found", taskID)
	}

	// fail records the failure in the store and returns it as an error.
	// NOTE(review): it writes task.RetryCount as captured at fetch time —
	// confirm a concurrent retry-count update cannot be clobbered here.
	fail := func(step, msg string) error {
		_ = s.taskStore.UpdateStatus(ctx, taskID, model.TaskStatusFailed, msg)
		_ = s.taskStore.UpdateRetryState(ctx, taskID, model.TaskStatusFailed, step, task.RetryCount)
		return fmt.Errorf("%s", msg)
	}

	currentStep := task.CurrentStep

	var workDir string
	var app *model.Application

	if currentStep == "" || currentStep == model.TaskStepPreparing {
		// 2. Set preparing
		if err := s.taskStore.UpdateRetryState(ctx, taskID, model.TaskStatusPreparing, model.TaskStepPreparing, 0); err != nil {
			return fail(model.TaskStepPreparing, fmt.Sprintf("update status to preparing: %v", err))
		}

		// 3. Fetch app
		app, err = s.appStore.GetByID(ctx, task.AppID)
		if err != nil {
			return fail(model.TaskStepPreparing, fmt.Sprintf("get application: %v", err))
		}
		if app == nil {
			return fail(model.TaskStepPreparing, fmt.Sprintf("application %d not found", task.AppID))
		}

		// 4-5. Create work directory (timestamp + random suffix for uniqueness)
		workDir = filepath.Join(s.workDirBase, SanitizeDirName(app.Name), time.Now().Format("20060102_150405")+"_"+RandomSuffix(4))
		if err := os.MkdirAll(workDir, 0777); err != nil {
			return fail(model.TaskStepPreparing, fmt.Sprintf("create work directory %s: %v", workDir, err))
		}

		// 6. CHMOD traversal — critical for multi-user HPC
		// (MkdirAll honors umask, so explicitly open up each level).
		for dir := workDir; dir != s.workDirBase; dir = filepath.Dir(dir) {
			os.Chmod(dir, 0777)
		}
		os.Chmod(s.workDirBase, 0777)

		// 7. UpdateWorkDir
		if err := s.taskStore.UpdateWorkDir(ctx, taskID, workDir); err != nil {
			return fail(model.TaskStepPreparing, fmt.Sprintf("update work dir: %v", err))
		}
	} else {
		// Resuming a later step: re-load the app and reuse the stored work dir.
		app, err = s.appStore.GetByID(ctx, task.AppID)
		if err != nil {
			return fail(currentStep, fmt.Sprintf("get application: %v", err))
		}
		if app == nil {
			return fail(currentStep, fmt.Sprintf("application %d not found", task.AppID))
		}
		workDir = task.WorkDir
	}

	if currentStep == "" || currentStep == model.TaskStepPreparing || currentStep == model.TaskStepDownloading {
		// On a download retry, clear any partially staged files first.
		if currentStep == model.TaskStepDownloading && workDir != "" {
			matches, _ := filepath.Glob(filepath.Join(workDir, "*"))
			for _, f := range matches {
				os.Remove(f)
			}
		}

		// 8. Set downloading
		if err := s.taskStore.UpdateRetryState(ctx, taskID, model.TaskStatusDownloading, model.TaskStepDownloading, 0); err != nil {
			return fail(model.TaskStepDownloading, fmt.Sprintf("update status to downloading: %v", err))
		}

		// 9. Parse input_file_ids
		var fileIDs []int64
		if len(task.InputFileIDs) > 0 {
			if err := json.Unmarshal(task.InputFileIDs, &fileIDs); err != nil {
				return fail(model.TaskStepDownloading, fmt.Sprintf("parse input file ids: %v", err))
			}
		}

		// 10-12. Download files (requires the staging service to be configured)
		if len(fileIDs) > 0 {
			if s.stagingSvc == nil {
				return fail(model.TaskStepDownloading, "MinIO unavailable, cannot stage files")
			}
			if err := s.stagingSvc.DownloadFilesToDir(ctx, fileIDs, workDir); err != nil {
				return fail(model.TaskStepDownloading, fmt.Sprintf("download files: %v", err))
			}
		}
	}

	// 13-14. Set ready + submitting
	if err := s.taskStore.UpdateRetryState(ctx, taskID, model.TaskStatusReady, model.TaskStepSubmitting, 0); err != nil {
		return fail(model.TaskStepSubmitting, fmt.Sprintf("update status to ready: %v", err))
	}

	// 15. Parse app parameters
	var params []model.ParameterSchema
	if len(app.Parameters) > 0 {
		if err := json.Unmarshal(app.Parameters, &params); err != nil {
			return fail(model.TaskStepSubmitting, fmt.Sprintf("parse parameters: %v", err))
		}
	}

	// 16. Parse task values
	values := make(map[string]string)
	if len(task.Values) > 0 {
		if err := json.Unmarshal(task.Values, &values); err != nil {
			return fail(model.TaskStepSubmitting, fmt.Sprintf("parse values: %v", err))
		}
	}

	if err := ValidateParams(params, values); err != nil {
		return fail(model.TaskStepSubmitting, err.Error())
	}

	// 17. Render script
	rendered := RenderScript(app.ScriptTemplate, params, values)

	// 18. Submit to Slurm
	jobResp, err := s.jobSvc.SubmitJob(ctx, &model.SubmitJobRequest{
		Script:  rendered,
		WorkDir: workDir,
	})
	if err != nil {
		return fail(model.TaskStepSubmitting, fmt.Sprintf("submit job: %v", err))
	}

	// 19. Update slurm_job_id and status to queued
	if err := s.taskStore.UpdateSlurmJobID(ctx, taskID, &jobResp.JobID); err != nil {
		return fail(model.TaskStepSubmitting, fmt.Sprintf("update slurm job id: %v", err))
	}
	if err := s.taskStore.UpdateStatus(ctx, taskID, model.TaskStatusQueued, ""); err != nil {
		return fail(model.TaskStepSubmitting, fmt.Sprintf("update status to queued: %v", err))
	}

	return nil
}
|
||||
|
||||
// ListTasks returns a paginated list of tasks matching the query, along with
// the total matching row count. Thin delegation to the task store.
func (s *TaskService) ListTasks(ctx context.Context, query *model.TaskListQuery) ([]model.Task, int64, error) {
	return s.taskStore.List(ctx, query)
}
|
||||
|
||||
// ProcessTaskSync creates and processes a task synchronously, returning a JobResponse
|
||||
// for old API compatibility.
|
||||
func (s *TaskService) ProcessTaskSync(ctx context.Context, req *model.CreateTaskRequest) (*model.JobResponse, error) {
|
||||
// 1. Create task
|
||||
task, err := s.CreateTask(ctx, req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 2. Process synchronously
|
||||
if err := s.ProcessTask(ctx, task.ID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 3. Re-fetch to get updated slurm_job_id
|
||||
task, err = s.taskStore.GetByID(ctx, task.ID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("re-fetch task: %w", err)
|
||||
}
|
||||
if task == nil || task.SlurmJobID == nil {
|
||||
return nil, fmt.Errorf("task has no slurm job id after processing")
|
||||
}
|
||||
|
||||
// 4. Return JobResponse
|
||||
return &model.JobResponse{JobID: *task.SlurmJobID}, nil
|
||||
}
|
||||
|
||||
// uniqueInt64s deduplicates and sorts a slice of int64. An empty or nil input
// yields nil.
func uniqueInt64s(ids []int64) []int64 {
	if len(ids) == 0 {
		return nil
	}
	// Collect distinct values into a set, then emit them in ascending order.
	set := make(map[int64]struct{}, len(ids))
	for _, id := range ids {
		set[id] = struct{}{}
	}
	out := make([]int64, 0, len(set))
	for id := range set {
		out = append(out, id)
	}
	sort.Slice(out, func(a, b int) bool { return out[a] < out[b] })
	return out
}
|
||||
|
||||
func (s *TaskService) mapSlurmStateToTaskStatus(slurmState []string) string {
|
||||
if len(slurmState) == 0 {
|
||||
return model.TaskStatusRunning
|
||||
}
|
||||
|
||||
state := strings.ToUpper(slurmState[0])
|
||||
switch state {
|
||||
case "PENDING":
|
||||
return model.TaskStatusQueued
|
||||
case "RUNNING", "CONFIGURING", "COMPLETING", "SPECIAL_EXIT":
|
||||
return model.TaskStatusRunning
|
||||
case "COMPLETED":
|
||||
return model.TaskStatusCompleted
|
||||
case "FAILED", "CANCELLED", "TIMEOUT", "NODE_FAIL", "OUT_OF_MEMORY", "PREEMPTED":
|
||||
return model.TaskStatusFailed
|
||||
default:
|
||||
return model.TaskStatusRunning
|
||||
}
|
||||
}
|
||||
|
||||
// refreshTaskStatus re-queries Slurm for a task's current job state and
// persists the mapped status when it changed. Best-effort: Slurm query
// failures are logged and return nil; only store errors propagate. Tasks
// without a Slurm job ID are silently skipped.
func (s *TaskService) refreshTaskStatus(ctx context.Context, taskID int64) error {
	task, err := s.taskStore.GetByID(ctx, taskID)
	if err != nil {
		s.logger.Error("failed to fetch task for refresh",
			zap.Int64("task_id", taskID),
			zap.Error(err),
		)
		return err
	}
	// Not yet submitted to Slurm (or deleted) — nothing to refresh.
	if task == nil || task.SlurmJobID == nil {
		return nil
	}

	jobResp, err := s.jobSvc.GetJob(ctx, strconv.FormatInt(int64(*task.SlurmJobID), 10))
	if err != nil {
		// Transient Slurm errors must not fail the sweep; just log and move on.
		s.logger.Warn("failed to query slurm job status during refresh",
			zap.Int64("task_id", taskID),
			zap.Int32("slurm_job_id", *task.SlurmJobID),
			zap.Error(err),
		)
		return nil
	}
	if jobResp == nil {
		return nil
	}

	newStatus := s.mapSlurmStateToTaskStatus(jobResp.State)
	// Only touch the store (and the updated_at timestamp) on a real change.
	if newStatus != task.Status {
		s.logger.Info("updating task status from slurm",
			zap.Int64("task_id", taskID),
			zap.String("old_status", task.Status),
			zap.String("new_status", newStatus),
		)
		return s.taskStore.UpdateStatus(ctx, taskID, newStatus, "")
	}
	return nil
}
|
||||
|
||||
// RefreshStaleTasks sweeps queued and running tasks whose record has not been
// updated within the last 30 seconds and refreshes each one's status from
// Slurm. Per-status list failures and per-task refresh failures are logged
// and skipped; the sweep itself always returns nil.
// NOTE(review): each status page is capped at 1000 rows — confirm this bound
// exceeds the realistic number of concurrently active tasks.
func (s *TaskService) RefreshStaleTasks(ctx context.Context) error {
	staleThreshold := 30 * time.Second
	nonTerminal := []string{model.TaskStatusQueued, model.TaskStatusRunning}

	for _, status := range nonTerminal {
		tasks, _, err := s.taskStore.List(ctx, &model.TaskListQuery{
			Status:   status,
			Page:     1,
			PageSize: 1000,
		})
		if err != nil {
			s.logger.Warn("failed to list tasks for stale refresh",
				zap.String("status", status),
				zap.Error(err),
			)
			continue
		}

		cutoff := time.Now().Add(-staleThreshold)
		for i := range tasks {
			// Only refresh tasks that have gone quiet past the threshold.
			if tasks[i].UpdatedAt.Before(cutoff) {
				if err := s.refreshTaskStatus(ctx, tasks[i].ID); err != nil {
					s.logger.Warn("failed to refresh stale task",
						zap.Int64("task_id", tasks[i].ID),
						zap.Error(err),
					)
				}
			}
		}
	}

	return nil
}
|
||||
|
||||
// StartProcessor launches the single background worker goroutine that drains
// taskCh and processes each task with retry. Idempotent: a second call is a
// no-op. After starting the worker it re-enqueues tasks left stuck by a
// previous run.
func (s *TaskService) StartProcessor(ctx context.Context) {
	s.mu.Lock()
	if s.started {
		s.mu.Unlock()
		return
	}
	s.started = true
	s.mu.Unlock()

	// Derive a cancellable context so StopProcessor can interrupt the worker.
	ctx, s.cancelFn = context.WithCancel(ctx)

	s.wg.Add(1)
	go func() {
		defer s.wg.Done()
		// A panicking task must not take down the worker silently; log it.
		// NOTE(review): after a recovered panic the loop exits and is not
		// restarted — confirm that is the intended policy.
		defer func() {
			if r := recover(); r != nil {
				s.logger.Error("processor panic", zap.Any("panic", r))
			}
		}()
		for {
			select {
			case <-ctx.Done():
				return
			case taskID, ok := <-s.taskCh:
				if !ok {
					// Channel closed by StopProcessor.
					return
				}
				// Bound each task's processing to 10 minutes.
				taskCtx, cancel := context.WithTimeout(ctx, 10*time.Minute)
				s.processWithRetry(taskCtx, taskID)
				cancel()
			}
		}
	}()

	s.RecoverStuckTasks(ctx)
}
|
||||
|
||||
// SubmitAsync persists a task and hands its ID to the background processor.
// The channel send is deliberately non-blocking: if the buffer (cap 16) is
// full the enqueue is dropped with a warning — the task remains in status
// "submitted" and is picked up by a later RecoverStuckTasks pass. The task ID
// is returned with a nil error in that case too.
func (s *TaskService) SubmitAsync(ctx context.Context, req *model.CreateTaskRequest) (int64, error) {
	task, err := s.CreateTask(ctx, req)
	if err != nil {
		return 0, err
	}

	// Hold mu across the stopped-check AND the send so we can never send on a
	// channel that StopProcessor has already closed.
	s.mu.Lock()
	if s.stopped {
		s.mu.Unlock()
		return 0, fmt.Errorf("processor stopped, cannot submit task")
	}
	select {
	case s.taskCh <- task.ID:
	default:
		s.logger.Warn("task channel full, submit dropped", zap.Int64("taskID", task.ID))
	}
	s.mu.Unlock()

	return task.ID, nil
}
|
||||
|
||||
// StopProcessor shuts the worker down: it marks the service stopped and
// closes taskCh under mu (so no sender can race the close), cancels the
// worker context, waits for any in-flight task to finish, then drains the
// IDs left in the closed channel back to status "submitted" so a future
// RecoverStuckTasks can re-process them. Idempotent.
func (s *TaskService) StopProcessor() {
	s.mu.Lock()
	if s.stopped {
		s.mu.Unlock()
		return
	}
	s.stopped = true
	close(s.taskCh)
	s.mu.Unlock()

	if s.cancelFn != nil {
		s.cancelFn()
	}
	s.wg.Wait()

	// Swap in a fresh channel so any later reader of s.taskCh sees a valid
	// (unused) channel, then drain the closed one outside the lock.
	s.mu.Lock()
	drainCh := s.taskCh
	s.taskCh = make(chan int64, 16)
	s.mu.Unlock()

	// Ranging over a closed channel yields the buffered leftovers, then ends.
	for taskID := range drainCh {
		_ = s.taskStore.UpdateStatus(context.Background(), taskID, model.TaskStatusSubmitted, "")
	}
}
|
||||
|
||||
// processWithRetry runs ProcessTask once and, on failure, re-enqueues the
// task for another attempt (preserving CurrentStep so processing resumes
// where it failed), up to 3 retries. The re-enqueue is a non-blocking send:
// a full channel drops the retry with a warning, leaving the task for
// stuck-task recovery.
func (s *TaskService) processWithRetry(ctx context.Context, taskID int64) {
	err := s.ProcessTask(ctx, taskID)
	if err == nil {
		return
	}

	// Re-read the task to get its current retry count and step.
	task, fetchErr := s.taskStore.GetByID(ctx, taskID)
	if fetchErr != nil || task == nil {
		return
	}

	if task.RetryCount < 3 {
		_ = s.taskStore.UpdateRetryState(ctx, taskID, model.TaskStatusSubmitted, task.CurrentStep, task.RetryCount+1)
		// Same mu-guarded non-blocking send discipline as SubmitAsync.
		s.mu.Lock()
		if !s.stopped {
			select {
			case s.taskCh <- taskID:
			default:
				s.logger.Warn("task channel full, retry dropped", zap.Int64("taskID", taskID))
			}
		}
		s.mu.Unlock()
	}
}
|
||||
|
||||
// RecoverStuckTasks finds tasks that have sat in a non-terminal processing
// state for over five minutes (per the store's GetStuckTasks query), resets
// each to "submitted", and re-enqueues it with a non-blocking send; overflow
// is dropped with a warning and retried on a later recovery pass.
func (s *TaskService) RecoverStuckTasks(ctx context.Context) {
	tasks, err := s.taskStore.GetStuckTasks(ctx, 5*time.Minute)
	if err != nil {
		s.logger.Error("failed to get stuck tasks", zap.Error(err))
		return
	}
	for i := range tasks {
		_ = s.taskStore.UpdateStatus(ctx, tasks[i].ID, model.TaskStatusSubmitted, "")
		// Guard the send against a concurrent StopProcessor closing taskCh.
		s.mu.Lock()
		if !s.stopped {
			select {
			case s.taskCh <- tasks[i].ID:
			default:
				s.logger.Warn("task channel full, stuck task recovery dropped", zap.Int64("taskID", tasks[i].ID))
			}
		}
		s.mu.Unlock()
	}
}
|
||||
416
internal/service/task_service_async_test.go
Normal file
416
internal/service/task_service_async_test.go
Normal file
@@ -0,0 +1,416 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"gcy_hpc_server/internal/model"
|
||||
"gcy_hpc_server/internal/slurm"
|
||||
"gcy_hpc_server/internal/store"
|
||||
|
||||
"go.uber.org/zap"
|
||||
"gorm.io/driver/sqlite"
|
||||
gormlogger "gorm.io/gorm/logger"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// setupAsyncTestDB opens an in-memory SQLite database (silent logger) and
// migrates the models the async task tests touch.
func setupAsyncTestDB(t *testing.T) *gorm.DB {
	t.Helper()
	db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{
		Logger: gormlogger.Default.LogMode(gormlogger.Silent),
	})
	if err != nil {
		t.Fatalf("open sqlite: %v", err)
	}
	if err := db.AutoMigrate(&model.Task{}, &model.Application{}, &model.File{}, &model.FileBlob{}); err != nil {
		t.Fatalf("auto migrate: %v", err)
	}
	return db
}
|
||||
|
||||
// asyncTestEnv bundles the stores, service, fake Slurm server, DB handle,
// and work-dir root shared by the async TaskService tests.
type asyncTestEnv struct {
	taskStore   *store.TaskStore
	appStore    *store.ApplicationStore
	svc         *TaskService
	srv         *httptest.Server
	db          *gorm.DB
	workDirBase string
}
|
||||
|
||||
// newAsyncTestEnv builds a TaskService wired to an in-memory DB and a fake
// Slurm HTTP server driven by slurmHandler. File/blob stores and the staging
// service are nil, so tests exercise the no-file-staging path.
func newAsyncTestEnv(t *testing.T, slurmHandler http.HandlerFunc) *asyncTestEnv {
	t.Helper()
	db := setupAsyncTestDB(t)

	ts := store.NewTaskStore(db)
	as := store.NewApplicationStore(db)

	srv := httptest.NewServer(slurmHandler)
	client, _ := slurm.NewClient(srv.URL, srv.Client())
	jobSvc := NewJobService(client, zap.NewNop())

	workDirBase := filepath.Join(t.TempDir(), "workdir")
	os.MkdirAll(workDirBase, 0777)

	svc := NewTaskService(ts, as, nil, nil, nil, jobSvc, workDirBase, zap.NewNop())

	return &asyncTestEnv{
		taskStore:   ts,
		appStore:    as,
		svc:         svc,
		srv:         srv,
		db:          db,
		workDirBase: workDirBase,
	}
}
|
||||
|
||||
// close shuts down the fake Slurm HTTP server.
func (e *asyncTestEnv) close() {
	e.srv.Close()
}
|
||||
|
||||
// createApp inserts an application with the given name and script template
// (empty parameter schema) and returns its new ID, failing the test on error.
func (e *asyncTestEnv) createApp(t *testing.T, name, script string) int64 {
	t.Helper()
	id, err := e.appStore.Create(context.Background(), &model.CreateApplicationRequest{
		Name:           name,
		ScriptTemplate: script,
		Parameters:     json.RawMessage(`[]`),
	})
	if err != nil {
		t.Fatalf("create app: %v", err)
	}
	return id
}
|
||||
|
||||
// TestTaskService_Async_SubmitAndProcess checks the happy path: an async
// submission is picked up by the background worker and reaches status
// "queued" after the fake Slurm server accepts the job.
func TestTaskService_Async_SubmitAndProcess(t *testing.T) {
	jobID := int32(42)
	env := newAsyncTestEnv(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		json.NewEncoder(w).Encode(slurm.OpenapiJobSubmitResponse{
			Result: &slurm.JobSubmitResponseMsg{JobID: &jobID},
		})
	}))
	defer env.close()

	appID := env.createApp(t, "async-app", "#!/bin/bash\necho hello")

	ctx := context.Background()
	env.svc.StartProcessor(ctx)

	taskID, err := env.svc.SubmitAsync(ctx, &model.CreateTaskRequest{
		AppID:    appID,
		TaskName: "async-test",
	})
	if err != nil {
		t.Fatalf("SubmitAsync: %v", err)
	}
	if taskID == 0 {
		t.Fatal("expected non-zero task ID")
	}

	// Give the background worker time to run the pipeline.
	time.Sleep(500 * time.Millisecond)

	task, err := env.taskStore.GetByID(ctx, taskID)
	if err != nil {
		t.Fatalf("GetByID: %v", err)
	}
	if task.Status != model.TaskStatusQueued {
		t.Errorf("Status = %q, want %q", task.Status, model.TaskStatusQueued)
	}

	env.svc.StopProcessor()
}
|
||||
|
||||
// TestTaskService_Retry_MaxExhaustion checks that a task whose Slurm
// submission always fails ends up in status "failed" after exhausting the
// retry budget (at least 3 attempts recorded).
func TestTaskService_Retry_MaxExhaustion(t *testing.T) {
	callCount := int32(0)
	env := newAsyncTestEnv(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		atomic.AddInt32(&callCount, 1)
		// Every submission attempt fails with a 500.
		w.WriteHeader(http.StatusInternalServerError)
		w.Write([]byte(`{"error":"slurm down"}`))
	}))
	defer env.close()

	appID := env.createApp(t, "retry-app", "#!/bin/bash\necho hello")

	ctx := context.Background()
	env.svc.StartProcessor(ctx)

	taskID, err := env.svc.SubmitAsync(ctx, &model.CreateTaskRequest{
		AppID:    appID,
		TaskName: "retry-test",
	})
	if err != nil {
		t.Fatalf("SubmitAsync: %v", err)
	}

	// Allow time for the initial attempt plus all retries to run.
	time.Sleep(2 * time.Second)

	task, _ := env.taskStore.GetByID(ctx, taskID)
	if task.Status != model.TaskStatusFailed {
		t.Errorf("Status = %q, want %q", task.Status, model.TaskStatusFailed)
	}
	if task.RetryCount < 3 {
		t.Errorf("RetryCount = %d, want >= 3", task.RetryCount)
	}

	env.svc.StopProcessor()
}
|
||||
|
||||
// TestTaskService_Recover_StuckTasks seeds a task stranded mid-pipeline with
// a stale updated_at, then verifies StartProcessor's recovery pass re-enqueues
// it and it is processed through to "queued".
func TestTaskService_Recover_StuckTasks(t *testing.T) {
	jobID := int32(99)
	env := newAsyncTestEnv(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		json.NewEncoder(w).Encode(slurm.OpenapiJobSubmitResponse{
			Result: &slurm.JobSubmitResponseMsg{JobID: &jobID},
		})
	}))
	defer env.close()

	appID := env.createApp(t, "stuck-app", "#!/bin/bash\necho hello")

	ctx := context.Background()

	// A task abandoned in the "preparing" step, as if a previous process died.
	task := &model.Task{
		TaskName:    "stuck-task",
		AppID:       appID,
		AppName:     "stuck-app",
		Status:      model.TaskStatusPreparing,
		CurrentStep: model.TaskStepPreparing,
		RetryCount:  0,
		SubmittedAt: time.Now(),
	}
	taskID, err := env.taskStore.Create(ctx, task)
	if err != nil {
		t.Fatalf("Create stuck task: %v", err)
	}

	// Backdate updated_at past the 5-minute stuck threshold.
	staleTime := time.Now().Add(-10 * time.Minute)
	env.db.Exec("UPDATE hpc_tasks SET updated_at = ? WHERE id = ?", staleTime, taskID)

	env.svc.StartProcessor(ctx)

	time.Sleep(1 * time.Second)

	updated, _ := env.taskStore.GetByID(ctx, taskID)
	if updated.Status != model.TaskStatusQueued {
		t.Errorf("Status = %q, want %q", updated.Status, model.TaskStatusQueued)
	}

	env.svc.StopProcessor()
}
|
||||
|
||||
// TestTaskService_Shutdown_InFlight verifies StopProcessor completes within a
// bounded time while a task is mid-submission (the fake Slurm server delays
// 200ms), i.e. shutdown waits for in-flight work rather than hanging.
func TestTaskService_Shutdown_InFlight(t *testing.T) {
	env := newAsyncTestEnv(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		// Keep the in-flight submission busy so shutdown overlaps it.
		time.Sleep(200 * time.Millisecond)
		jobID := int32(77)
		json.NewEncoder(w).Encode(slurm.OpenapiJobSubmitResponse{
			Result: &slurm.JobSubmitResponseMsg{JobID: &jobID},
		})
	}))
	defer env.close()

	appID := env.createApp(t, "shutdown-app", "#!/bin/bash\necho hello")

	ctx := context.Background()
	env.svc.StartProcessor(ctx)

	taskID, _ := env.svc.SubmitAsync(ctx, &model.CreateTaskRequest{
		AppID:    appID,
		TaskName: "shutdown-test",
	})

	done := make(chan struct{})
	go func() {
		env.svc.StopProcessor()
		close(done)
	}()

	select {
	case <-done:
	case <-time.After(5 * time.Second):
		t.Fatal("StopProcessor did not complete within timeout")
	}

	// Either outcome is acceptable: completed submission (queued) or drained
	// back to submitted for later recovery.
	task, _ := env.taskStore.GetByID(ctx, taskID)
	if task.Status != model.TaskStatusQueued && task.Status != model.TaskStatusSubmitted {
		t.Logf("task status after shutdown: %q (acceptable)", task.Status)
	}
}
|
||||
|
||||
// TestTaskService_PanicRecovery drives one request into a handler panic and
// verifies the service survives: StartProcessor/StopProcessor still complete
// cleanly (the test asserts no hang/crash rather than a specific status).
func TestTaskService_PanicRecovery(t *testing.T) {
	jobID := int32(55)
	panicDone := int32(0)
	env := newAsyncTestEnv(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		// First request panics inside the fake server; later ones succeed.
		if atomic.CompareAndSwapInt32(&panicDone, 0, 1) {
			panic("intentional test panic")
		}
		json.NewEncoder(w).Encode(slurm.OpenapiJobSubmitResponse{
			Result: &slurm.JobSubmitResponseMsg{JobID: &jobID},
		})
	}))
	defer env.close()

	appID := env.createApp(t, "panic-app", "#!/bin/bash\necho hello")

	ctx := context.Background()
	env.svc.StartProcessor(ctx)

	taskID, _ := env.svc.SubmitAsync(ctx, &model.CreateTaskRequest{
		AppID:    appID,
		TaskName: "panic-test",
	})

	time.Sleep(1 * time.Second)

	atomic.StoreInt32(&panicDone, 1)

	env.svc.StopProcessor()
	_ = taskID
}
|
||||
|
||||
// TestTaskService_SubmitAsync_DuringShutdown verifies SubmitAsync rejects new
// submissions once StopProcessor has run, instead of sending on a closed
// channel.
func TestTaskService_SubmitAsync_DuringShutdown(t *testing.T) {
	env := newAsyncTestEnv(t, nil)
	defer env.close()

	appID := env.createApp(t, "shutdown-err-app", "#!/bin/bash\necho hello")

	ctx := context.Background()
	env.svc.StartProcessor(ctx)
	env.svc.StopProcessor()

	_, err := env.svc.SubmitAsync(ctx, &model.CreateTaskRequest{
		AppID:    appID,
		TaskName: "after-shutdown",
	})
	if err == nil {
		t.Fatal("expected error when submitting after shutdown")
	}
}
|
||||
|
||||
// TestTaskService_SubmitAsync_ChannelFull_NonBlocking verifies SubmitAsync
// returns without blocking when the task channel buffer (cap=16) is full.
// Before fix: SubmitAsync holds s.mu while blocking on full channel → deadlock.
// After fix: non-blocking select returns immediately.
func TestTaskService_SubmitAsync_ChannelFull_NonBlocking(t *testing.T) {
	jobID := int32(42)
	env := newAsyncTestEnv(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		// Slow handler pins the worker on its first task so the buffer fills.
		time.Sleep(5 * time.Second)
		json.NewEncoder(w).Encode(slurm.OpenapiJobSubmitResponse{
			Result: &slurm.JobSubmitResponseMsg{JobID: &jobID},
		})
	}))
	defer env.close()

	appID := env.createApp(t, "channel-full-app", "#!/bin/bash\necho hello")
	ctx := context.Background()

	// Create all tasks BEFORE starting processor to avoid SQLite concurrent-write contention
	taskIDs := make([]int64, 17)
	for i := range taskIDs {
		id, err := env.taskStore.Create(ctx, &model.Task{
			TaskName:    fmt.Sprintf("fill-%d", i),
			AppID:       appID,
			AppName:     "channel-full-app",
			Status:      model.TaskStatusSubmitted,
			CurrentStep: model.TaskStepSubmitting,
			SubmittedAt: time.Now(),
		})
		if err != nil {
			t.Fatalf("create fill task %d: %v", i, err)
		}
		taskIDs[i] = id
	}

	env.svc.StartProcessor(ctx)
	defer env.svc.StopProcessor()

	// Consumer grabs first ID immediately; remaining 15 sit in channel.
	// Push one more to fill buffer to 16 (full).
	for _, id := range taskIDs {
		env.svc.taskCh <- id
	}

	// Overflow submit: must return within 3s (non-blocking after fix)
	done := make(chan error, 1)
	go func() {
		_, submitErr := env.svc.SubmitAsync(ctx, &model.CreateTaskRequest{
			AppID:    appID,
			TaskName: "overflow-task",
		})
		done <- submitErr
	}()

	select {
	case err := <-done:
		if err != nil {
			t.Logf("SubmitAsync returned error (acceptable after fix): %v", err)
		} else {
			t.Log("SubmitAsync returned without blocking — channel send is non-blocking")
		}
	case <-time.After(3 * time.Second):
		t.Fatal("SubmitAsync blocked for >3s — channel send is blocking, potential deadlock")
	}
}
|
||||
|
||||
// TestTaskService_Retry_ChannelFull_NonBlocking verifies processWithRetry
// does not deadlock when re-enqueuing a failed task into a full channel.
// Before fix: processWithRetry holds s.mu while blocking on s.taskCh <- taskID → deadlock.
// After fix: non-blocking select drops the retry with a Warn log.
func TestTaskService_Retry_ChannelFull_NonBlocking(t *testing.T) {
	// The fake Slurm endpoint always fails (forcing retries) and is slow
	// enough (~1s) that the buffered channel stays full while it works.
	env := newAsyncTestEnv(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		time.Sleep(1 * time.Second)
		w.WriteHeader(http.StatusInternalServerError)
		w.Write([]byte(`{"error":"slurm down"}`))
	}))
	defer env.close()

	appID := env.createApp(t, "retry-full-app", "#!/bin/bash\necho hello")
	ctx := context.Background()

	// Create all tasks BEFORE starting processor to avoid SQLite concurrent-write contention
	// NOTE: 17 = 1 in-flight + 16 channel slots — assumes taskCh has a
	// buffer of 16; keep in sync with the TaskService channel size.
	taskIDs := make([]int64, 17)
	for i := range taskIDs {
		id, err := env.taskStore.Create(ctx, &model.Task{
			TaskName:    fmt.Sprintf("retry-%d", i),
			AppID:       appID,
			AppName:     "retry-full-app",
			Status:      model.TaskStatusSubmitted,
			CurrentStep: model.TaskStepSubmitting,
			RetryCount:  0,
			SubmittedAt: time.Now(),
		})
		if err != nil {
			t.Fatalf("create retry task %d: %v", i, err)
		}
		taskIDs[i] = id
	}

	env.svc.StartProcessor(ctx)

	// Push all 17 IDs: consumer grabs one (processing ~1s), 16 fill the buffer
	for _, id := range taskIDs {
		env.svc.taskCh <- id
	}

	// Wait for consumer to finish first task and attempt retry into full channel
	time.Sleep(2 * time.Second)

	// If processWithRetry deadlocked holding s.mu, StopProcessor hangs on mutex acquisition
	done := make(chan struct{})
	go func() {
		env.svc.StopProcessor()
		close(done)
	}()

	// Bounded wait: a deadlock shows up as StopProcessor never returning.
	select {
	case <-done:
		t.Log("StopProcessor completed — retry channel send is non-blocking")
	case <-time.After(5 * time.Second):
		t.Fatal("StopProcessor did not complete within 5s — deadlock from retry channel send")
	}
}
|
||||
--- New file: internal/service/task_service_status_test.go (294 lines added) ---
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"gcy_hpc_server/internal/model"
|
||||
"gcy_hpc_server/internal/slurm"
|
||||
"gcy_hpc_server/internal/store"
|
||||
|
||||
"go.uber.org/zap"
|
||||
"gorm.io/driver/sqlite"
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/logger"
|
||||
)
|
||||
|
||||
func newTaskSvcTestDB(t *testing.T) *gorm.DB {
|
||||
t.Helper()
|
||||
db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{
|
||||
Logger: logger.Default.LogMode(logger.Silent),
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("open sqlite: %v", err)
|
||||
}
|
||||
if err := db.AutoMigrate(&model.Task{}); err != nil {
|
||||
t.Fatalf("auto migrate: %v", err)
|
||||
}
|
||||
return db
|
||||
}
|
||||
|
||||
// taskSvcTestEnv bundles the store, services, fake Slurm server, and raw DB
// handle used by the task-service status tests.
type taskSvcTestEnv struct {
	taskStore *store.TaskStore // direct DB access for seeding and inspecting tasks
	jobSvc    *JobService      // job service wired to the fake Slurm server
	svc       *TaskService     // service under test
	srv       *httptest.Server // fake Slurm REST endpoint
	db        *gorm.DB         // in-memory SQLite handle for raw SQL tweaks (e.g. updated_at)
}
|
||||
|
||||
func newTaskSvcTestEnv(t *testing.T, handler http.HandlerFunc) *taskSvcTestEnv {
|
||||
t.Helper()
|
||||
db := newTaskSvcTestDB(t)
|
||||
ts := store.NewTaskStore(db)
|
||||
|
||||
srv := httptest.NewServer(handler)
|
||||
client, _ := slurm.NewClient(srv.URL, srv.Client())
|
||||
jobSvc := NewJobService(client, zap.NewNop())
|
||||
svc := NewTaskService(ts, nil, nil, nil, nil, jobSvc, "/tmp", zap.NewNop())
|
||||
|
||||
return &taskSvcTestEnv{
|
||||
taskStore: ts,
|
||||
jobSvc: jobSvc,
|
||||
svc: svc,
|
||||
srv: srv,
|
||||
db: db,
|
||||
}
|
||||
}
|
||||
|
||||
// close shuts down the fake Slurm test server.
func (e *taskSvcTestEnv) close() {
	e.srv.Close()
}
|
||||
|
||||
func makeTaskForTest(name, status string, slurmJobID *int32) *model.Task {
|
||||
return &model.Task{
|
||||
TaskName: name,
|
||||
AppID: 1,
|
||||
AppName: "test-app",
|
||||
Status: status,
|
||||
CurrentStep: "",
|
||||
RetryCount: 0,
|
||||
UserID: "user1",
|
||||
SubmittedAt: time.Now(),
|
||||
SlurmJobID: slurmJobID,
|
||||
}
|
||||
}
|
||||
|
||||
func TestTaskService_MapSlurmState_AllStates(t *testing.T) {
|
||||
env := newTaskSvcTestEnv(t, nil)
|
||||
defer env.close()
|
||||
|
||||
cases := []struct {
|
||||
input []string
|
||||
expected string
|
||||
}{
|
||||
{[]string{"PENDING"}, model.TaskStatusQueued},
|
||||
{[]string{"RUNNING"}, model.TaskStatusRunning},
|
||||
{[]string{"CONFIGURING"}, model.TaskStatusRunning},
|
||||
{[]string{"COMPLETING"}, model.TaskStatusRunning},
|
||||
{[]string{"COMPLETED"}, model.TaskStatusCompleted},
|
||||
{[]string{"FAILED"}, model.TaskStatusFailed},
|
||||
{[]string{"CANCELLED"}, model.TaskStatusFailed},
|
||||
{[]string{"TIMEOUT"}, model.TaskStatusFailed},
|
||||
{[]string{"NODE_FAIL"}, model.TaskStatusFailed},
|
||||
{[]string{"OUT_OF_MEMORY"}, model.TaskStatusFailed},
|
||||
{[]string{"PREEMPTED"}, model.TaskStatusFailed},
|
||||
{[]string{"SPECIAL_EXIT"}, model.TaskStatusRunning},
|
||||
{[]string{"unknown_state"}, model.TaskStatusRunning},
|
||||
{[]string{"pending"}, model.TaskStatusQueued},
|
||||
{[]string{"Running"}, model.TaskStatusRunning},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
got := env.svc.mapSlurmStateToTaskStatus(tc.input)
|
||||
if got != tc.expected {
|
||||
t.Errorf("mapSlurmStateToTaskStatus(%v) = %q, want %q", tc.input, got, tc.expected)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestTaskService_MapSlurmState_Empty(t *testing.T) {
|
||||
env := newTaskSvcTestEnv(t, nil)
|
||||
defer env.close()
|
||||
|
||||
got := env.svc.mapSlurmStateToTaskStatus([]string{})
|
||||
if got != model.TaskStatusRunning {
|
||||
t.Errorf("mapSlurmStateToTaskStatus([]) = %q, want %q", got, model.TaskStatusRunning)
|
||||
}
|
||||
|
||||
got = env.svc.mapSlurmStateToTaskStatus(nil)
|
||||
if got != model.TaskStatusRunning {
|
||||
t.Errorf("mapSlurmStateToTaskStatus(nil) = %q, want %q", got, model.TaskStatusRunning)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTaskService_RefreshTaskStatus_UpdatesDB(t *testing.T) {
|
||||
jobID := int32(42)
|
||||
env := newTaskSvcTestEnv(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
resp := slurm.OpenapiJobInfoResp{
|
||||
Jobs: slurm.JobInfoMsg{
|
||||
{
|
||||
JobID: &jobID,
|
||||
JobState: []string{"RUNNING"},
|
||||
},
|
||||
},
|
||||
}
|
||||
json.NewEncoder(w).Encode(resp)
|
||||
}))
|
||||
defer env.close()
|
||||
|
||||
ctx := context.Background()
|
||||
task := makeTaskForTest("refresh-test", model.TaskStatusQueued, &jobID)
|
||||
id, err := env.taskStore.Create(ctx, task)
|
||||
if err != nil {
|
||||
t.Fatalf("Create: %v", err)
|
||||
}
|
||||
|
||||
err = env.svc.refreshTaskStatus(ctx, id)
|
||||
if err != nil {
|
||||
t.Fatalf("refreshTaskStatus: %v", err)
|
||||
}
|
||||
|
||||
updated, _ := env.taskStore.GetByID(ctx, id)
|
||||
if updated.Status != model.TaskStatusRunning {
|
||||
t.Errorf("status = %q, want %q", updated.Status, model.TaskStatusRunning)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTaskService_RefreshTaskStatus_NoSlurmJobID(t *testing.T) {
|
||||
env := newTaskSvcTestEnv(t, nil)
|
||||
defer env.close()
|
||||
|
||||
ctx := context.Background()
|
||||
task := makeTaskForTest("no-slurm", model.TaskStatusQueued, nil)
|
||||
id, _ := env.taskStore.Create(ctx, task)
|
||||
|
||||
err := env.svc.refreshTaskStatus(ctx, id)
|
||||
if err != nil {
|
||||
t.Fatalf("expected no error, got %v", err)
|
||||
}
|
||||
|
||||
got, _ := env.taskStore.GetByID(ctx, id)
|
||||
if got.Status != model.TaskStatusQueued {
|
||||
t.Errorf("status should remain unchanged, got %q", got.Status)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTaskService_RefreshTaskStatus_SlurmError(t *testing.T) {
|
||||
jobID := int32(42)
|
||||
env := newTaskSvcTestEnv(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
w.Write([]byte(`{"error":"down"}`))
|
||||
}))
|
||||
defer env.close()
|
||||
|
||||
ctx := context.Background()
|
||||
task := makeTaskForTest("slurm-err", model.TaskStatusQueued, &jobID)
|
||||
id, _ := env.taskStore.Create(ctx, task)
|
||||
|
||||
err := env.svc.refreshTaskStatus(ctx, id)
|
||||
if err != nil {
|
||||
t.Fatalf("expected no error (soft fail), got %v", err)
|
||||
}
|
||||
|
||||
got, _ := env.taskStore.GetByID(ctx, id)
|
||||
if got.Status != model.TaskStatusQueued {
|
||||
t.Errorf("status should remain unchanged on slurm error, got %q", got.Status)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTaskService_RefreshTaskStatus_NoChange(t *testing.T) {
|
||||
jobID := int32(42)
|
||||
env := newTaskSvcTestEnv(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
resp := slurm.OpenapiJobInfoResp{
|
||||
Jobs: slurm.JobInfoMsg{
|
||||
{
|
||||
JobID: &jobID,
|
||||
JobState: []string{"RUNNING"},
|
||||
},
|
||||
},
|
||||
}
|
||||
json.NewEncoder(w).Encode(resp)
|
||||
}))
|
||||
defer env.close()
|
||||
|
||||
ctx := context.Background()
|
||||
task := makeTaskForTest("no-change", model.TaskStatusRunning, &jobID)
|
||||
id, _ := env.taskStore.Create(ctx, task)
|
||||
|
||||
err := env.svc.refreshTaskStatus(ctx, id)
|
||||
if err != nil {
|
||||
t.Fatalf("refreshTaskStatus: %v", err)
|
||||
}
|
||||
|
||||
got, _ := env.taskStore.GetByID(ctx, id)
|
||||
if got.Status != model.TaskStatusRunning {
|
||||
t.Errorf("status = %q, want %q", got.Status, model.TaskStatusRunning)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTaskService_RefreshStaleTasks_SkipsFresh(t *testing.T) {
|
||||
jobID := int32(42)
|
||||
slurmQueried := false
|
||||
env := newTaskSvcTestEnv(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
slurmQueried = true
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
}))
|
||||
defer env.close()
|
||||
|
||||
ctx := context.Background()
|
||||
task := makeTaskForTest("fresh-task", model.TaskStatusQueued, &jobID)
|
||||
id, _ := env.taskStore.Create(ctx, task)
|
||||
|
||||
freshTask, _ := env.taskStore.GetByID(ctx, id)
|
||||
if freshTask == nil {
|
||||
t.Fatal("task not found")
|
||||
}
|
||||
|
||||
env.db.Exec("UPDATE hpc_tasks SET updated_at = ? WHERE id = ?", time.Now(), id)
|
||||
|
||||
err := env.svc.RefreshStaleTasks(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("RefreshStaleTasks: %v", err)
|
||||
}
|
||||
|
||||
if slurmQueried {
|
||||
t.Error("expected no Slurm query for fresh task")
|
||||
}
|
||||
}
|
||||
|
||||
func TestTaskService_RefreshStaleTasks_RefreshesStale(t *testing.T) {
|
||||
jobID := int32(42)
|
||||
env := newTaskSvcTestEnv(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
resp := slurm.OpenapiJobInfoResp{
|
||||
Jobs: slurm.JobInfoMsg{
|
||||
{
|
||||
JobID: &jobID,
|
||||
JobState: []string{"COMPLETED"},
|
||||
},
|
||||
},
|
||||
}
|
||||
json.NewEncoder(w).Encode(resp)
|
||||
}))
|
||||
defer env.close()
|
||||
|
||||
ctx := context.Background()
|
||||
task := makeTaskForTest("stale-task", model.TaskStatusRunning, &jobID)
|
||||
id, _ := env.taskStore.Create(ctx, task)
|
||||
|
||||
staleTime := time.Now().Add(-60 * time.Second)
|
||||
env.db.Exec("UPDATE hpc_tasks SET updated_at = ? WHERE id = ?", staleTime, id)
|
||||
|
||||
err := env.svc.RefreshStaleTasks(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("RefreshStaleTasks: %v", err)
|
||||
}
|
||||
|
||||
got, _ := env.taskStore.GetByID(ctx, id)
|
||||
if got.Status != model.TaskStatusCompleted {
|
||||
t.Errorf("status = %q, want %q", got.Status, model.TaskStatusCompleted)
|
||||
}
|
||||
}
|
||||
--- New file: internal/service/task_service_test.go (538 lines added) ---
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"gcy_hpc_server/internal/model"
|
||||
"gcy_hpc_server/internal/slurm"
|
||||
"gcy_hpc_server/internal/store"
|
||||
|
||||
"go.uber.org/zap"
|
||||
"gorm.io/driver/sqlite"
|
||||
gormlogger "gorm.io/gorm/logger"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
func setupTaskTestDB(t *testing.T) *gorm.DB {
|
||||
t.Helper()
|
||||
db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{
|
||||
Logger: gormlogger.Default.LogMode(gormlogger.Silent),
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("open sqlite: %v", err)
|
||||
}
|
||||
if err := db.AutoMigrate(&model.Task{}, &model.Application{}, &model.File{}, &model.FileBlob{}); err != nil {
|
||||
t.Fatalf("auto migrate: %v", err)
|
||||
}
|
||||
return db
|
||||
}
|
||||
|
||||
// taskTestEnv bundles the stores, service under test, fake Slurm server, and
// scratch work directory used by the TaskService unit tests.
type taskTestEnv struct {
	taskStore   *store.TaskStore        // direct DB access for inspecting tasks
	appStore    *store.ApplicationStore // used to seed applications
	fileStore   *store.FileStore        // used to seed input-file rows
	blobStore   *store.BlobStore        // blob storage backing the file store
	svc         *TaskService            // service under test
	srv         *httptest.Server        // fake Slurm REST endpoint
	db          *gorm.DB                // in-memory SQLite handle
	workDirBase string                  // base dir where ProcessTask creates per-task workdirs
}
|
||||
|
||||
func newTaskTestEnv(t *testing.T, slurmHandler http.HandlerFunc) *taskTestEnv {
|
||||
t.Helper()
|
||||
db := setupTaskTestDB(t)
|
||||
|
||||
ts := store.NewTaskStore(db)
|
||||
as := store.NewApplicationStore(db)
|
||||
fs := store.NewFileStore(db)
|
||||
bs := store.NewBlobStore(db)
|
||||
|
||||
srv := httptest.NewServer(slurmHandler)
|
||||
client, _ := slurm.NewClient(srv.URL, srv.Client())
|
||||
jobSvc := NewJobService(client, zap.NewNop())
|
||||
|
||||
workDirBase := filepath.Join(t.TempDir(), "workdir")
|
||||
os.MkdirAll(workDirBase, 0777)
|
||||
|
||||
svc := NewTaskService(ts, as, fs, bs, nil, jobSvc, workDirBase, zap.NewNop())
|
||||
|
||||
return &taskTestEnv{
|
||||
taskStore: ts,
|
||||
appStore: as,
|
||||
fileStore: fs,
|
||||
blobStore: bs,
|
||||
svc: svc,
|
||||
srv: srv,
|
||||
db: db,
|
||||
workDirBase: workDirBase,
|
||||
}
|
||||
}
|
||||
|
||||
// close shuts down the fake Slurm test server.
func (e *taskTestEnv) close() {
	e.srv.Close()
}
|
||||
|
||||
func (e *taskTestEnv) createApp(t *testing.T, name, script string, params json.RawMessage) int64 {
|
||||
t.Helper()
|
||||
id, err := e.appStore.Create(context.Background(), &model.CreateApplicationRequest{
|
||||
Name: name,
|
||||
ScriptTemplate: script,
|
||||
Parameters: params,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("create app: %v", err)
|
||||
}
|
||||
return id
|
||||
}
|
||||
|
||||
func TestTaskService_CreateTask_Success(t *testing.T) {
|
||||
env := newTaskTestEnv(t, nil)
|
||||
defer env.close()
|
||||
|
||||
appID := env.createApp(t, "my-app", "#!/bin/bash\necho hello", json.RawMessage(`[]`))
|
||||
|
||||
task, err := env.svc.CreateTask(context.Background(), &model.CreateTaskRequest{
|
||||
AppID: appID,
|
||||
TaskName: "test-task",
|
||||
Values: map[string]string{"KEY": "val"},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("CreateTask: %v", err)
|
||||
}
|
||||
if task.ID == 0 {
|
||||
t.Error("expected non-zero task ID")
|
||||
}
|
||||
if task.AppID != appID {
|
||||
t.Errorf("AppID = %d, want %d", task.AppID, appID)
|
||||
}
|
||||
if task.AppName != "my-app" {
|
||||
t.Errorf("AppName = %q, want %q", task.AppName, "my-app")
|
||||
}
|
||||
if task.Status != model.TaskStatusSubmitted {
|
||||
t.Errorf("Status = %q, want %q", task.Status, model.TaskStatusSubmitted)
|
||||
}
|
||||
if task.TaskName != "test-task" {
|
||||
t.Errorf("TaskName = %q, want %q", task.TaskName, "test-task")
|
||||
}
|
||||
|
||||
var values map[string]string
|
||||
if err := json.Unmarshal(task.Values, &values); err != nil {
|
||||
t.Fatalf("unmarshal values: %v", err)
|
||||
}
|
||||
if values["KEY"] != "val" {
|
||||
t.Errorf("values[KEY] = %q, want %q", values["KEY"], "val")
|
||||
}
|
||||
}
|
||||
|
||||
func TestTaskService_CreateTask_InvalidAppID(t *testing.T) {
|
||||
env := newTaskTestEnv(t, nil)
|
||||
defer env.close()
|
||||
|
||||
_, err := env.svc.CreateTask(context.Background(), &model.CreateTaskRequest{
|
||||
AppID: 999,
|
||||
})
|
||||
if err == nil {
|
||||
t.Fatal("expected error for invalid app_id")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "not found") {
|
||||
t.Errorf("error should mention 'not found', got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTaskService_CreateTask_ExceedsFileLimit(t *testing.T) {
|
||||
env := newTaskTestEnv(t, nil)
|
||||
defer env.close()
|
||||
|
||||
appID := env.createApp(t, "app", "#!/bin/bash\necho hi", nil)
|
||||
|
||||
fileIDs := make([]int64, 101)
|
||||
for i := range fileIDs {
|
||||
fileIDs[i] = int64(i + 1)
|
||||
}
|
||||
|
||||
_, err := env.svc.CreateTask(context.Background(), &model.CreateTaskRequest{
|
||||
AppID: appID,
|
||||
InputFileIDs: fileIDs,
|
||||
})
|
||||
if err == nil {
|
||||
t.Fatal("expected error for exceeding file limit")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "exceeds limit") {
|
||||
t.Errorf("error should mention limit, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTaskService_CreateTask_DuplicateFileIDs(t *testing.T) {
|
||||
env := newTaskTestEnv(t, nil)
|
||||
defer env.close()
|
||||
|
||||
appID := env.createApp(t, "app", "#!/bin/bash\necho hi", nil)
|
||||
|
||||
ctx := context.Background()
|
||||
for _, id := range []int64{1, 2} {
|
||||
f := &model.File{
|
||||
Name: "file.txt",
|
||||
BlobSHA256: "abc123",
|
||||
}
|
||||
if err := env.fileStore.Create(ctx, f); err != nil {
|
||||
t.Fatalf("create file: %v", err)
|
||||
}
|
||||
if f.ID != id {
|
||||
t.Fatalf("expected file ID %d, got %d", id, f.ID)
|
||||
}
|
||||
}
|
||||
|
||||
task, err := env.svc.CreateTask(ctx, &model.CreateTaskRequest{
|
||||
AppID: appID,
|
||||
InputFileIDs: []int64{1, 1, 2, 2},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("CreateTask: %v", err)
|
||||
}
|
||||
|
||||
var fileIDs []int64
|
||||
if err := json.Unmarshal(task.InputFileIDs, &fileIDs); err != nil {
|
||||
t.Fatalf("unmarshal file ids: %v", err)
|
||||
}
|
||||
if len(fileIDs) != 2 {
|
||||
t.Fatalf("expected 2 deduplicated file IDs, got %d: %v", len(fileIDs), fileIDs)
|
||||
}
|
||||
if fileIDs[0] != 1 || fileIDs[1] != 2 {
|
||||
t.Errorf("expected [1,2], got %v", fileIDs)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTaskService_CreateTask_AutoName(t *testing.T) {
|
||||
env := newTaskTestEnv(t, nil)
|
||||
defer env.close()
|
||||
|
||||
appID := env.createApp(t, "My Cool App", "#!/bin/bash\necho hi", nil)
|
||||
|
||||
task, err := env.svc.CreateTask(context.Background(), &model.CreateTaskRequest{
|
||||
AppID: appID,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("CreateTask: %v", err)
|
||||
}
|
||||
if !strings.HasPrefix(task.TaskName, "My_Cool_App_") {
|
||||
t.Errorf("auto-generated name should start with 'My_Cool_App_', got %q", task.TaskName)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTaskService_CreateTask_NilValues(t *testing.T) {
|
||||
env := newTaskTestEnv(t, nil)
|
||||
defer env.close()
|
||||
|
||||
appID := env.createApp(t, "app", "#!/bin/bash\necho hi", nil)
|
||||
|
||||
task, err := env.svc.CreateTask(context.Background(), &model.CreateTaskRequest{
|
||||
AppID: appID,
|
||||
Values: nil,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("CreateTask: %v", err)
|
||||
}
|
||||
if string(task.Values) != `{}` {
|
||||
t.Errorf("Values = %q, want {}", string(task.Values))
|
||||
}
|
||||
}
|
||||
|
||||
func TestTaskService_ProcessTask_Success(t *testing.T) {
|
||||
jobID := int32(42)
|
||||
env := newTaskTestEnv(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
json.NewEncoder(w).Encode(slurm.OpenapiJobSubmitResponse{
|
||||
Result: &slurm.JobSubmitResponseMsg{JobID: &jobID},
|
||||
})
|
||||
}))
|
||||
defer env.close()
|
||||
|
||||
appID := env.createApp(t, "test-app", "#!/bin/bash\necho $INPUT", json.RawMessage(`[{"name":"INPUT","type":"string","required":true}]`))
|
||||
|
||||
task, err := env.svc.CreateTask(context.Background(), &model.CreateTaskRequest{
|
||||
AppID: appID,
|
||||
Values: map[string]string{"INPUT": "hello"},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("CreateTask: %v", err)
|
||||
}
|
||||
|
||||
err = env.svc.ProcessTask(context.Background(), task.ID)
|
||||
if err != nil {
|
||||
t.Fatalf("ProcessTask: %v", err)
|
||||
}
|
||||
|
||||
updated, _ := env.taskStore.GetByID(context.Background(), task.ID)
|
||||
if updated.Status != model.TaskStatusQueued {
|
||||
t.Errorf("Status = %q, want %q", updated.Status, model.TaskStatusQueued)
|
||||
}
|
||||
if updated.SlurmJobID == nil || *updated.SlurmJobID != 42 {
|
||||
t.Errorf("SlurmJobID = %v, want 42", updated.SlurmJobID)
|
||||
}
|
||||
if updated.WorkDir == "" {
|
||||
t.Error("WorkDir should not be empty")
|
||||
}
|
||||
if !strings.HasPrefix(updated.WorkDir, env.workDirBase) {
|
||||
t.Errorf("WorkDir = %q, should start with %q", updated.WorkDir, env.workDirBase)
|
||||
}
|
||||
|
||||
info, err := os.Stat(updated.WorkDir)
|
||||
if err != nil {
|
||||
t.Fatalf("stat workdir: %v", err)
|
||||
}
|
||||
if !info.IsDir() {
|
||||
t.Error("WorkDir should be a directory")
|
||||
}
|
||||
}
|
||||
|
||||
func TestTaskService_ProcessTask_TaskNotFound(t *testing.T) {
|
||||
env := newTaskTestEnv(t, nil)
|
||||
defer env.close()
|
||||
|
||||
err := env.svc.ProcessTask(context.Background(), 999)
|
||||
if err == nil {
|
||||
t.Fatal("expected error for non-existent task")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "not found") {
|
||||
t.Errorf("error should mention 'not found', got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTaskService_ProcessTask_SlurmError(t *testing.T) {
|
||||
env := newTaskTestEnv(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
w.Write([]byte(`{"error":"slurm down"}`))
|
||||
}))
|
||||
defer env.close()
|
||||
|
||||
appID := env.createApp(t, "test-app", "#!/bin/bash\necho hello", nil)
|
||||
|
||||
task, err := env.svc.CreateTask(context.Background(), &model.CreateTaskRequest{
|
||||
AppID: appID,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("CreateTask: %v", err)
|
||||
}
|
||||
|
||||
err = env.svc.ProcessTask(context.Background(), task.ID)
|
||||
if err == nil {
|
||||
t.Fatal("expected error from Slurm")
|
||||
}
|
||||
|
||||
updated, _ := env.taskStore.GetByID(context.Background(), task.ID)
|
||||
if updated.Status != model.TaskStatusFailed {
|
||||
t.Errorf("Status = %q, want %q", updated.Status, model.TaskStatusFailed)
|
||||
}
|
||||
if updated.CurrentStep != model.TaskStepSubmitting {
|
||||
t.Errorf("CurrentStep = %q, want %q", updated.CurrentStep, model.TaskStepSubmitting)
|
||||
}
|
||||
if !strings.Contains(updated.ErrorMessage, "submit job") {
|
||||
t.Errorf("ErrorMessage should mention 'submit job', got: %q", updated.ErrorMessage)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTaskService_ProcessTaskSync(t *testing.T) {
|
||||
jobID := int32(42)
|
||||
env := newTaskTestEnv(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
json.NewEncoder(w).Encode(slurm.OpenapiJobSubmitResponse{
|
||||
Result: &slurm.JobSubmitResponseMsg{JobID: &jobID},
|
||||
})
|
||||
}))
|
||||
defer env.close()
|
||||
|
||||
appID := env.createApp(t, "sync-app", "#!/bin/bash\necho hello", nil)
|
||||
|
||||
resp, err := env.svc.ProcessTaskSync(context.Background(), &model.CreateTaskRequest{
|
||||
AppID: appID,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("ProcessTaskSync: %v", err)
|
||||
}
|
||||
if resp.JobID != 42 {
|
||||
t.Errorf("JobID = %d, want 42", resp.JobID)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTaskService_ProcessTaskSync_NoMinIO(t *testing.T) {
|
||||
jobID := int32(42)
|
||||
env := newTaskTestEnv(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
json.NewEncoder(w).Encode(slurm.OpenapiJobSubmitResponse{
|
||||
Result: &slurm.JobSubmitResponseMsg{JobID: &jobID},
|
||||
})
|
||||
}))
|
||||
defer env.close()
|
||||
|
||||
appID := env.createApp(t, "no-minio-app", "#!/bin/bash\necho hello", nil)
|
||||
|
||||
resp, err := env.svc.ProcessTaskSync(context.Background(), &model.CreateTaskRequest{
|
||||
AppID: appID,
|
||||
InputFileIDs: nil,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("ProcessTaskSync: %v", err)
|
||||
}
|
||||
if resp.JobID != 42 {
|
||||
t.Errorf("JobID = %d, want 42", resp.JobID)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTaskService_ProcessTask_NilValues(t *testing.T) {
|
||||
jobID := int32(55)
|
||||
env := newTaskTestEnv(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
json.NewEncoder(w).Encode(slurm.OpenapiJobSubmitResponse{
|
||||
Result: &slurm.JobSubmitResponseMsg{JobID: &jobID},
|
||||
})
|
||||
}))
|
||||
defer env.close()
|
||||
|
||||
appID := env.createApp(t, "nil-val-app", "#!/bin/bash\necho hello", nil)
|
||||
|
||||
task, err := env.svc.CreateTask(context.Background(), &model.CreateTaskRequest{
|
||||
AppID: appID,
|
||||
Values: nil,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("CreateTask: %v", err)
|
||||
}
|
||||
|
||||
err = env.svc.ProcessTask(context.Background(), task.ID)
|
||||
if err != nil {
|
||||
t.Fatalf("ProcessTask: %v", err)
|
||||
}
|
||||
|
||||
updated, _ := env.taskStore.GetByID(context.Background(), task.ID)
|
||||
if updated.Status != model.TaskStatusQueued {
|
||||
t.Errorf("Status = %q, want %q", updated.Status, model.TaskStatusQueued)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTaskService_ListTasks(t *testing.T) {
|
||||
env := newTaskTestEnv(t, nil)
|
||||
defer env.close()
|
||||
|
||||
appID := env.createApp(t, "list-app", "#!/bin/bash\necho hi", nil)
|
||||
|
||||
for i := 0; i < 3; i++ {
|
||||
_, err := env.svc.CreateTask(context.Background(), &model.CreateTaskRequest{
|
||||
AppID: appID,
|
||||
TaskName: "task-" + string(rune('A'+i)),
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("CreateTask %d: %v", i, err)
|
||||
}
|
||||
}
|
||||
|
||||
tasks, total, err := env.svc.ListTasks(context.Background(), &model.TaskListQuery{
|
||||
Page: 1,
|
||||
PageSize: 10,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("ListTasks: %v", err)
|
||||
}
|
||||
if total != 3 {
|
||||
t.Errorf("total = %d, want 3", total)
|
||||
}
|
||||
if len(tasks) != 3 {
|
||||
t.Errorf("len(tasks) = %d, want 3", len(tasks))
|
||||
}
|
||||
}
|
||||
|
||||
func TestTaskService_ProcessTask_ValidateParams_MissingRequired(t *testing.T) {
|
||||
jobID := int32(42)
|
||||
env := newTaskTestEnv(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
json.NewEncoder(w).Encode(slurm.OpenapiJobSubmitResponse{
|
||||
Result: &slurm.JobSubmitResponseMsg{JobID: &jobID},
|
||||
})
|
||||
}))
|
||||
defer env.close()
|
||||
|
||||
// App requires INPUT param, but we submit without it
|
||||
appID := env.createApp(t, "validation-app", "#!/bin/bash\necho $INPUT", json.RawMessage(`[{"name":"INPUT","type":"string","required":true}]`))
|
||||
|
||||
task, err := env.svc.CreateTask(context.Background(), &model.CreateTaskRequest{
|
||||
AppID: appID,
|
||||
Values: map[string]string{}, // missing required INPUT
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("CreateTask: %v", err)
|
||||
}
|
||||
|
||||
err = env.svc.ProcessTask(context.Background(), task.ID)
|
||||
if err == nil {
|
||||
t.Fatal("expected error for missing required parameter, got nil — ValidateParams is not being called in ProcessTask pipeline")
|
||||
}
|
||||
errStr := err.Error()
|
||||
if !strings.Contains(errStr, "validation") && !strings.Contains(errStr, "missing") && !strings.Contains(errStr, "INPUT") {
|
||||
t.Errorf("error should mention 'validation', 'missing', or 'INPUT', got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTaskService_ProcessTask_ValidateParams_InvalidInteger(t *testing.T) {
|
||||
jobID := int32(42)
|
||||
env := newTaskTestEnv(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
json.NewEncoder(w).Encode(slurm.OpenapiJobSubmitResponse{
|
||||
Result: &slurm.JobSubmitResponseMsg{JobID: &jobID},
|
||||
})
|
||||
}))
|
||||
defer env.close()
|
||||
|
||||
// App expects integer param NUM, but we submit "abc"
|
||||
appID := env.createApp(t, "int-validation-app", "#!/bin/bash\necho $NUM", json.RawMessage(`[{"name":"NUM","type":"integer","required":true}]`))
|
||||
|
||||
task, err := env.svc.CreateTask(context.Background(), &model.CreateTaskRequest{
|
||||
AppID: appID,
|
||||
Values: map[string]string{"NUM": "abc"}, // invalid integer
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("CreateTask: %v", err)
|
||||
}
|
||||
|
||||
err = env.svc.ProcessTask(context.Background(), task.ID)
|
||||
if err == nil {
|
||||
t.Fatal("expected error for invalid integer parameter, got nil — ValidateParams is not being called in ProcessTask pipeline")
|
||||
}
|
||||
errStr := err.Error()
|
||||
if !strings.Contains(errStr, "integer") && !strings.Contains(errStr, "validation") && !strings.Contains(errStr, "NUM") {
|
||||
t.Errorf("error should mention 'integer', 'validation', or 'NUM', got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTaskService_ProcessTask_ValidateParams_ValidParamsSucceed(t *testing.T) {
|
||||
jobID := int32(99)
|
||||
env := newTaskTestEnv(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
json.NewEncoder(w).Encode(slurm.OpenapiJobSubmitResponse{
|
||||
Result: &slurm.JobSubmitResponseMsg{JobID: &jobID},
|
||||
})
|
||||
}))
|
||||
defer env.close()
|
||||
|
||||
appID := env.createApp(t, "valid-params-app", "#!/bin/bash\necho $INPUT", json.RawMessage(`[{"name":"INPUT","type":"string","required":true}]`))
|
||||
|
||||
task, err := env.svc.CreateTask(context.Background(), &model.CreateTaskRequest{
|
||||
AppID: appID,
|
||||
Values: map[string]string{"INPUT": "hello"},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("CreateTask: %v", err)
|
||||
}
|
||||
|
||||
err = env.svc.ProcessTask(context.Background(), task.ID)
|
||||
if err != nil {
|
||||
t.Fatalf("ProcessTask with valid params: %v", err)
|
||||
}
|
||||
|
||||
updated, _ := env.taskStore.GetByID(context.Background(), task.ID)
|
||||
if updated.Status != model.TaskStatusQueued {
|
||||
t.Errorf("Status = %q, want %q", updated.Status, model.TaskStatusQueued)
|
||||
}
|
||||
if updated.SlurmJobID == nil || *updated.SlurmJobID != 99 {
|
||||
t.Errorf("SlurmJobID = %v, want 99", updated.SlurmJobID)
|
||||
}
|
||||
}
|
||||
--- end of diff ---