feat(service): add TaskService, FileStagingService, and refactor ApplicationService for task submission
This commit is contained in:
416
internal/service/task_service_async_test.go
Normal file
416
internal/service/task_service_async_test.go
Normal file
@@ -0,0 +1,416 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"gcy_hpc_server/internal/model"
|
||||
"gcy_hpc_server/internal/slurm"
|
||||
"gcy_hpc_server/internal/store"
|
||||
|
||||
"go.uber.org/zap"
|
||||
"gorm.io/driver/sqlite"
|
||||
gormlogger "gorm.io/gorm/logger"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
func setupAsyncTestDB(t *testing.T) *gorm.DB {
|
||||
t.Helper()
|
||||
db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{
|
||||
Logger: gormlogger.Default.LogMode(gormlogger.Silent),
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("open sqlite: %v", err)
|
||||
}
|
||||
if err := db.AutoMigrate(&model.Task{}, &model.Application{}, &model.File{}, &model.FileBlob{}); err != nil {
|
||||
t.Fatalf("auto migrate: %v", err)
|
||||
}
|
||||
return db
|
||||
}
|
||||
|
||||
// asyncTestEnv bundles everything the async task-service tests share:
// the stores for direct DB inspection, the service under test, and the
// fake Slurm HTTP server backing it.
type asyncTestEnv struct {
	taskStore   *store.TaskStore        // direct access to task rows for assertions
	appStore    *store.ApplicationStore // direct access to application rows
	svc         *TaskService            // service under test
	srv         *httptest.Server        // fake Slurm REST endpoint; closed by close()
	db          *gorm.DB                // raw handle for test-only SQL tweaks (e.g. backdating)
	workDirBase string                  // root directory for task working dirs
}
|
||||
|
||||
func newAsyncTestEnv(t *testing.T, slurmHandler http.HandlerFunc) *asyncTestEnv {
|
||||
t.Helper()
|
||||
db := setupAsyncTestDB(t)
|
||||
|
||||
ts := store.NewTaskStore(db)
|
||||
as := store.NewApplicationStore(db)
|
||||
|
||||
srv := httptest.NewServer(slurmHandler)
|
||||
client, _ := slurm.NewClient(srv.URL, srv.Client())
|
||||
jobSvc := NewJobService(client, zap.NewNop())
|
||||
|
||||
workDirBase := filepath.Join(t.TempDir(), "workdir")
|
||||
os.MkdirAll(workDirBase, 0777)
|
||||
|
||||
svc := NewTaskService(ts, as, nil, nil, nil, jobSvc, workDirBase, zap.NewNop())
|
||||
|
||||
return &asyncTestEnv{
|
||||
taskStore: ts,
|
||||
appStore: as,
|
||||
svc: svc,
|
||||
srv: srv,
|
||||
db: db,
|
||||
workDirBase: workDirBase,
|
||||
}
|
||||
}
|
||||
|
||||
// close releases the resources held by the test environment (currently
// just the fake Slurm HTTP server). Call via defer in every test that
// builds an env.
func (e *asyncTestEnv) close() {
	e.srv.Close()
}
|
||||
|
||||
func (e *asyncTestEnv) createApp(t *testing.T, name, script string) int64 {
|
||||
t.Helper()
|
||||
id, err := e.appStore.Create(context.Background(), &model.CreateApplicationRequest{
|
||||
Name: name,
|
||||
ScriptTemplate: script,
|
||||
Parameters: json.RawMessage(`[]`),
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("create app: %v", err)
|
||||
}
|
||||
return id
|
||||
}
|
||||
|
||||
func TestTaskService_Async_SubmitAndProcess(t *testing.T) {
|
||||
jobID := int32(42)
|
||||
env := newAsyncTestEnv(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
json.NewEncoder(w).Encode(slurm.OpenapiJobSubmitResponse{
|
||||
Result: &slurm.JobSubmitResponseMsg{JobID: &jobID},
|
||||
})
|
||||
}))
|
||||
defer env.close()
|
||||
|
||||
appID := env.createApp(t, "async-app", "#!/bin/bash\necho hello")
|
||||
|
||||
ctx := context.Background()
|
||||
env.svc.StartProcessor(ctx)
|
||||
|
||||
taskID, err := env.svc.SubmitAsync(ctx, &model.CreateTaskRequest{
|
||||
AppID: appID,
|
||||
TaskName: "async-test",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("SubmitAsync: %v", err)
|
||||
}
|
||||
if taskID == 0 {
|
||||
t.Fatal("expected non-zero task ID")
|
||||
}
|
||||
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
|
||||
task, err := env.taskStore.GetByID(ctx, taskID)
|
||||
if err != nil {
|
||||
t.Fatalf("GetByID: %v", err)
|
||||
}
|
||||
if task.Status != model.TaskStatusQueued {
|
||||
t.Errorf("Status = %q, want %q", task.Status, model.TaskStatusQueued)
|
||||
}
|
||||
|
||||
env.svc.StopProcessor()
|
||||
}
|
||||
|
||||
func TestTaskService_Retry_MaxExhaustion(t *testing.T) {
|
||||
callCount := int32(0)
|
||||
env := newAsyncTestEnv(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
atomic.AddInt32(&callCount, 1)
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
w.Write([]byte(`{"error":"slurm down"}`))
|
||||
}))
|
||||
defer env.close()
|
||||
|
||||
appID := env.createApp(t, "retry-app", "#!/bin/bash\necho hello")
|
||||
|
||||
ctx := context.Background()
|
||||
env.svc.StartProcessor(ctx)
|
||||
|
||||
taskID, err := env.svc.SubmitAsync(ctx, &model.CreateTaskRequest{
|
||||
AppID: appID,
|
||||
TaskName: "retry-test",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("SubmitAsync: %v", err)
|
||||
}
|
||||
|
||||
time.Sleep(2 * time.Second)
|
||||
|
||||
task, _ := env.taskStore.GetByID(ctx, taskID)
|
||||
if task.Status != model.TaskStatusFailed {
|
||||
t.Errorf("Status = %q, want %q", task.Status, model.TaskStatusFailed)
|
||||
}
|
||||
if task.RetryCount < 3 {
|
||||
t.Errorf("RetryCount = %d, want >= 3", task.RetryCount)
|
||||
}
|
||||
|
||||
env.svc.StopProcessor()
|
||||
}
|
||||
|
||||
func TestTaskService_Recover_StuckTasks(t *testing.T) {
|
||||
jobID := int32(99)
|
||||
env := newAsyncTestEnv(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
json.NewEncoder(w).Encode(slurm.OpenapiJobSubmitResponse{
|
||||
Result: &slurm.JobSubmitResponseMsg{JobID: &jobID},
|
||||
})
|
||||
}))
|
||||
defer env.close()
|
||||
|
||||
appID := env.createApp(t, "stuck-app", "#!/bin/bash\necho hello")
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
task := &model.Task{
|
||||
TaskName: "stuck-task",
|
||||
AppID: appID,
|
||||
AppName: "stuck-app",
|
||||
Status: model.TaskStatusPreparing,
|
||||
CurrentStep: model.TaskStepPreparing,
|
||||
RetryCount: 0,
|
||||
SubmittedAt: time.Now(),
|
||||
}
|
||||
taskID, err := env.taskStore.Create(ctx, task)
|
||||
if err != nil {
|
||||
t.Fatalf("Create stuck task: %v", err)
|
||||
}
|
||||
|
||||
staleTime := time.Now().Add(-10 * time.Minute)
|
||||
env.db.Exec("UPDATE hpc_tasks SET updated_at = ? WHERE id = ?", staleTime, taskID)
|
||||
|
||||
env.svc.StartProcessor(ctx)
|
||||
|
||||
time.Sleep(1 * time.Second)
|
||||
|
||||
updated, _ := env.taskStore.GetByID(ctx, taskID)
|
||||
if updated.Status != model.TaskStatusQueued {
|
||||
t.Errorf("Status = %q, want %q", updated.Status, model.TaskStatusQueued)
|
||||
}
|
||||
|
||||
env.svc.StopProcessor()
|
||||
}
|
||||
|
||||
func TestTaskService_Shutdown_InFlight(t *testing.T) {
|
||||
env := newAsyncTestEnv(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
jobID := int32(77)
|
||||
json.NewEncoder(w).Encode(slurm.OpenapiJobSubmitResponse{
|
||||
Result: &slurm.JobSubmitResponseMsg{JobID: &jobID},
|
||||
})
|
||||
}))
|
||||
defer env.close()
|
||||
|
||||
appID := env.createApp(t, "shutdown-app", "#!/bin/bash\necho hello")
|
||||
|
||||
ctx := context.Background()
|
||||
env.svc.StartProcessor(ctx)
|
||||
|
||||
taskID, _ := env.svc.SubmitAsync(ctx, &model.CreateTaskRequest{
|
||||
AppID: appID,
|
||||
TaskName: "shutdown-test",
|
||||
})
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
env.svc.StopProcessor()
|
||||
close(done)
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(5 * time.Second):
|
||||
t.Fatal("StopProcessor did not complete within timeout")
|
||||
}
|
||||
|
||||
task, _ := env.taskStore.GetByID(ctx, taskID)
|
||||
if task.Status != model.TaskStatusQueued && task.Status != model.TaskStatusSubmitted {
|
||||
t.Logf("task status after shutdown: %q (acceptable)", task.Status)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTaskService_PanicRecovery(t *testing.T) {
|
||||
jobID := int32(55)
|
||||
panicDone := int32(0)
|
||||
env := newAsyncTestEnv(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if atomic.CompareAndSwapInt32(&panicDone, 0, 1) {
|
||||
panic("intentional test panic")
|
||||
}
|
||||
json.NewEncoder(w).Encode(slurm.OpenapiJobSubmitResponse{
|
||||
Result: &slurm.JobSubmitResponseMsg{JobID: &jobID},
|
||||
})
|
||||
}))
|
||||
defer env.close()
|
||||
|
||||
appID := env.createApp(t, "panic-app", "#!/bin/bash\necho hello")
|
||||
|
||||
ctx := context.Background()
|
||||
env.svc.StartProcessor(ctx)
|
||||
|
||||
taskID, _ := env.svc.SubmitAsync(ctx, &model.CreateTaskRequest{
|
||||
AppID: appID,
|
||||
TaskName: "panic-test",
|
||||
})
|
||||
|
||||
time.Sleep(1 * time.Second)
|
||||
|
||||
atomic.StoreInt32(&panicDone, 1)
|
||||
|
||||
env.svc.StopProcessor()
|
||||
_ = taskID
|
||||
}
|
||||
|
||||
func TestTaskService_SubmitAsync_DuringShutdown(t *testing.T) {
|
||||
env := newAsyncTestEnv(t, nil)
|
||||
defer env.close()
|
||||
|
||||
appID := env.createApp(t, "shutdown-err-app", "#!/bin/bash\necho hello")
|
||||
|
||||
ctx := context.Background()
|
||||
env.svc.StartProcessor(ctx)
|
||||
env.svc.StopProcessor()
|
||||
|
||||
_, err := env.svc.SubmitAsync(ctx, &model.CreateTaskRequest{
|
||||
AppID: appID,
|
||||
TaskName: "after-shutdown",
|
||||
})
|
||||
if err == nil {
|
||||
t.Fatal("expected error when submitting after shutdown")
|
||||
}
|
||||
}
|
||||
|
||||
// TestTaskService_SubmitAsync_ChannelFull_NonBlocking verifies SubmitAsync
// returns without blocking when the task channel buffer (cap=16) is full.
// Before fix: SubmitAsync holds s.mu while blocking on full channel → deadlock.
// After fix: non-blocking select returns immediately.
func TestTaskService_SubmitAsync_ChannelFull_NonBlocking(t *testing.T) {
	jobID := int32(42)
	// Slow Slurm handler: each submission takes 5s, keeping the consumer
	// busy on its first task while the channel buffer fills behind it.
	env := newAsyncTestEnv(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		time.Sleep(5 * time.Second)
		json.NewEncoder(w).Encode(slurm.OpenapiJobSubmitResponse{
			Result: &slurm.JobSubmitResponseMsg{JobID: &jobID},
		})
	}))
	defer env.close()

	appID := env.createApp(t, "channel-full-app", "#!/bin/bash\necho hello")
	ctx := context.Background()

	// Create all tasks BEFORE starting processor to avoid SQLite concurrent-write contention
	taskIDs := make([]int64, 17)
	for i := range taskIDs {
		id, err := env.taskStore.Create(ctx, &model.Task{
			TaskName:    fmt.Sprintf("fill-%d", i),
			AppID:       appID,
			AppName:     "channel-full-app",
			Status:      model.TaskStatusSubmitted,
			CurrentStep: model.TaskStepSubmitting,
			SubmittedAt: time.Now(),
		})
		if err != nil {
			t.Fatalf("create fill task %d: %v", i, err)
		}
		taskIDs[i] = id
	}

	env.svc.StartProcessor(ctx)
	defer env.svc.StopProcessor()

	// Consumer grabs first ID immediately; remaining 15 sit in channel.
	// Push one more to fill buffer to 16 (full).
	// NOTE(review): reaches into the unexported taskCh directly — this is a
	// white-box test of the channel-send path and relies on cap(taskCh)==16.
	for _, id := range taskIDs {
		env.svc.taskCh <- id
	}

	// Overflow submit: must return within 3s (non-blocking after fix)
	done := make(chan error, 1)
	go func() {
		_, submitErr := env.svc.SubmitAsync(ctx, &model.CreateTaskRequest{
			AppID:    appID,
			TaskName: "overflow-task",
		})
		done <- submitErr
	}()

	select {
	case err := <-done:
		// Either outcome (error or success) is fine — the property under
		// test is only that SubmitAsync returned promptly.
		if err != nil {
			t.Logf("SubmitAsync returned error (acceptable after fix): %v", err)
		} else {
			t.Log("SubmitAsync returned without blocking — channel send is non-blocking")
		}
	case <-time.After(3 * time.Second):
		t.Fatal("SubmitAsync blocked for >3s — channel send is blocking, potential deadlock")
	}
}
|
||||
|
||||
// TestTaskService_Retry_ChannelFull_NonBlocking verifies processWithRetry
// does not deadlock when re-enqueuing a failed task into a full channel.
// Before fix: processWithRetry holds s.mu while blocking on s.taskCh <- taskID → deadlock.
// After fix: non-blocking select drops the retry with a Warn log.
func TestTaskService_Retry_ChannelFull_NonBlocking(t *testing.T) {
	// Every submission fails after 1s, forcing the consumer down the retry
	// path while the channel buffer behind it is kept full.
	env := newAsyncTestEnv(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		time.Sleep(1 * time.Second)
		w.WriteHeader(http.StatusInternalServerError)
		w.Write([]byte(`{"error":"slurm down"}`))
	}))
	defer env.close()

	appID := env.createApp(t, "retry-full-app", "#!/bin/bash\necho hello")
	ctx := context.Background()

	// Create all tasks BEFORE starting processor to avoid SQLite concurrent-write contention
	taskIDs := make([]int64, 17)
	for i := range taskIDs {
		id, err := env.taskStore.Create(ctx, &model.Task{
			TaskName:    fmt.Sprintf("retry-%d", i),
			AppID:       appID,
			AppName:     "retry-full-app",
			Status:      model.TaskStatusSubmitted,
			CurrentStep: model.TaskStepSubmitting,
			RetryCount:  0,
			SubmittedAt: time.Now(),
		})
		if err != nil {
			t.Fatalf("create retry task %d: %v", i, err)
		}
		taskIDs[i] = id
	}

	env.svc.StartProcessor(ctx)

	// Push all 17 IDs: consumer grabs one (processing ~1s), 16 fill the buffer
	// NOTE(review): white-box send into the unexported taskCh; relies on
	// cap(taskCh)==16 and on the processor consuming the first ID promptly.
	for _, id := range taskIDs {
		env.svc.taskCh <- id
	}

	// Wait for consumer to finish first task and attempt retry into full channel
	time.Sleep(2 * time.Second)

	// If processWithRetry deadlocked holding s.mu, StopProcessor hangs on mutex acquisition
	done := make(chan struct{})
	go func() {
		env.svc.StopProcessor()
		close(done)
	}()

	select {
	case <-done:
		t.Log("StopProcessor completed — retry channel send is non-blocking")
	case <-time.After(5 * time.Second):
		t.Fatal("StopProcessor did not complete within 5s — deadlock from retry channel send")
	}
}
|
||||
Reference in New Issue
Block a user