feat(service): add TaskService, FileStagingService, and refactor ApplicationService for task submission

This commit is contained in:
dailz
2026-04-15 21:31:02 +08:00
parent acf8c1d62b
commit ec64300ff2
9 changed files with 2394 additions and 136 deletions

View File

@@ -0,0 +1,294 @@
package service
import (
"context"
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"time"
"gcy_hpc_server/internal/model"
"gcy_hpc_server/internal/slurm"
"gcy_hpc_server/internal/store"
"go.uber.org/zap"
"gorm.io/driver/sqlite"
"gorm.io/gorm"
"gorm.io/gorm/logger"
)
func newTaskSvcTestDB(t *testing.T) *gorm.DB {
t.Helper()
db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{
Logger: logger.Default.LogMode(logger.Silent),
})
if err != nil {
t.Fatalf("open sqlite: %v", err)
}
if err := db.AutoMigrate(&model.Task{}); err != nil {
t.Fatalf("auto migrate: %v", err)
}
return db
}
type taskSvcTestEnv struct {
taskStore *store.TaskStore
jobSvc *JobService
svc *TaskService
srv *httptest.Server
db *gorm.DB
}
func newTaskSvcTestEnv(t *testing.T, handler http.HandlerFunc) *taskSvcTestEnv {
t.Helper()
db := newTaskSvcTestDB(t)
ts := store.NewTaskStore(db)
srv := httptest.NewServer(handler)
client, _ := slurm.NewClient(srv.URL, srv.Client())
jobSvc := NewJobService(client, zap.NewNop())
svc := NewTaskService(ts, nil, nil, nil, nil, jobSvc, "/tmp", zap.NewNop())
return &taskSvcTestEnv{
taskStore: ts,
jobSvc: jobSvc,
svc: svc,
srv: srv,
db: db,
}
}
func (e *taskSvcTestEnv) close() {
e.srv.Close()
}
func makeTaskForTest(name, status string, slurmJobID *int32) *model.Task {
return &model.Task{
TaskName: name,
AppID: 1,
AppName: "test-app",
Status: status,
CurrentStep: "",
RetryCount: 0,
UserID: "user1",
SubmittedAt: time.Now(),
SlurmJobID: slurmJobID,
}
}
func TestTaskService_MapSlurmState_AllStates(t *testing.T) {
env := newTaskSvcTestEnv(t, nil)
defer env.close()
cases := []struct {
input []string
expected string
}{
{[]string{"PENDING"}, model.TaskStatusQueued},
{[]string{"RUNNING"}, model.TaskStatusRunning},
{[]string{"CONFIGURING"}, model.TaskStatusRunning},
{[]string{"COMPLETING"}, model.TaskStatusRunning},
{[]string{"COMPLETED"}, model.TaskStatusCompleted},
{[]string{"FAILED"}, model.TaskStatusFailed},
{[]string{"CANCELLED"}, model.TaskStatusFailed},
{[]string{"TIMEOUT"}, model.TaskStatusFailed},
{[]string{"NODE_FAIL"}, model.TaskStatusFailed},
{[]string{"OUT_OF_MEMORY"}, model.TaskStatusFailed},
{[]string{"PREEMPTED"}, model.TaskStatusFailed},
{[]string{"SPECIAL_EXIT"}, model.TaskStatusRunning},
{[]string{"unknown_state"}, model.TaskStatusRunning},
{[]string{"pending"}, model.TaskStatusQueued},
{[]string{"Running"}, model.TaskStatusRunning},
}
for _, tc := range cases {
got := env.svc.mapSlurmStateToTaskStatus(tc.input)
if got != tc.expected {
t.Errorf("mapSlurmStateToTaskStatus(%v) = %q, want %q", tc.input, got, tc.expected)
}
}
}
func TestTaskService_MapSlurmState_Empty(t *testing.T) {
env := newTaskSvcTestEnv(t, nil)
defer env.close()
got := env.svc.mapSlurmStateToTaskStatus([]string{})
if got != model.TaskStatusRunning {
t.Errorf("mapSlurmStateToTaskStatus([]) = %q, want %q", got, model.TaskStatusRunning)
}
got = env.svc.mapSlurmStateToTaskStatus(nil)
if got != model.TaskStatusRunning {
t.Errorf("mapSlurmStateToTaskStatus(nil) = %q, want %q", got, model.TaskStatusRunning)
}
}
func TestTaskService_RefreshTaskStatus_UpdatesDB(t *testing.T) {
jobID := int32(42)
env := newTaskSvcTestEnv(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
resp := slurm.OpenapiJobInfoResp{
Jobs: slurm.JobInfoMsg{
{
JobID: &jobID,
JobState: []string{"RUNNING"},
},
},
}
json.NewEncoder(w).Encode(resp)
}))
defer env.close()
ctx := context.Background()
task := makeTaskForTest("refresh-test", model.TaskStatusQueued, &jobID)
id, err := env.taskStore.Create(ctx, task)
if err != nil {
t.Fatalf("Create: %v", err)
}
err = env.svc.refreshTaskStatus(ctx, id)
if err != nil {
t.Fatalf("refreshTaskStatus: %v", err)
}
updated, _ := env.taskStore.GetByID(ctx, id)
if updated.Status != model.TaskStatusRunning {
t.Errorf("status = %q, want %q", updated.Status, model.TaskStatusRunning)
}
}
func TestTaskService_RefreshTaskStatus_NoSlurmJobID(t *testing.T) {
env := newTaskSvcTestEnv(t, nil)
defer env.close()
ctx := context.Background()
task := makeTaskForTest("no-slurm", model.TaskStatusQueued, nil)
id, _ := env.taskStore.Create(ctx, task)
err := env.svc.refreshTaskStatus(ctx, id)
if err != nil {
t.Fatalf("expected no error, got %v", err)
}
got, _ := env.taskStore.GetByID(ctx, id)
if got.Status != model.TaskStatusQueued {
t.Errorf("status should remain unchanged, got %q", got.Status)
}
}
func TestTaskService_RefreshTaskStatus_SlurmError(t *testing.T) {
jobID := int32(42)
env := newTaskSvcTestEnv(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusInternalServerError)
w.Write([]byte(`{"error":"down"}`))
}))
defer env.close()
ctx := context.Background()
task := makeTaskForTest("slurm-err", model.TaskStatusQueued, &jobID)
id, _ := env.taskStore.Create(ctx, task)
err := env.svc.refreshTaskStatus(ctx, id)
if err != nil {
t.Fatalf("expected no error (soft fail), got %v", err)
}
got, _ := env.taskStore.GetByID(ctx, id)
if got.Status != model.TaskStatusQueued {
t.Errorf("status should remain unchanged on slurm error, got %q", got.Status)
}
}
func TestTaskService_RefreshTaskStatus_NoChange(t *testing.T) {
jobID := int32(42)
env := newTaskSvcTestEnv(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
resp := slurm.OpenapiJobInfoResp{
Jobs: slurm.JobInfoMsg{
{
JobID: &jobID,
JobState: []string{"RUNNING"},
},
},
}
json.NewEncoder(w).Encode(resp)
}))
defer env.close()
ctx := context.Background()
task := makeTaskForTest("no-change", model.TaskStatusRunning, &jobID)
id, _ := env.taskStore.Create(ctx, task)
err := env.svc.refreshTaskStatus(ctx, id)
if err != nil {
t.Fatalf("refreshTaskStatus: %v", err)
}
got, _ := env.taskStore.GetByID(ctx, id)
if got.Status != model.TaskStatusRunning {
t.Errorf("status = %q, want %q", got.Status, model.TaskStatusRunning)
}
}
func TestTaskService_RefreshStaleTasks_SkipsFresh(t *testing.T) {
jobID := int32(42)
slurmQueried := false
env := newTaskSvcTestEnv(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
slurmQueried = true
w.WriteHeader(http.StatusInternalServerError)
}))
defer env.close()
ctx := context.Background()
task := makeTaskForTest("fresh-task", model.TaskStatusQueued, &jobID)
id, _ := env.taskStore.Create(ctx, task)
freshTask, _ := env.taskStore.GetByID(ctx, id)
if freshTask == nil {
t.Fatal("task not found")
}
env.db.Exec("UPDATE hpc_tasks SET updated_at = ? WHERE id = ?", time.Now(), id)
err := env.svc.RefreshStaleTasks(ctx)
if err != nil {
t.Fatalf("RefreshStaleTasks: %v", err)
}
if slurmQueried {
t.Error("expected no Slurm query for fresh task")
}
}
func TestTaskService_RefreshStaleTasks_RefreshesStale(t *testing.T) {
jobID := int32(42)
env := newTaskSvcTestEnv(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
resp := slurm.OpenapiJobInfoResp{
Jobs: slurm.JobInfoMsg{
{
JobID: &jobID,
JobState: []string{"COMPLETED"},
},
},
}
json.NewEncoder(w).Encode(resp)
}))
defer env.close()
ctx := context.Background()
task := makeTaskForTest("stale-task", model.TaskStatusRunning, &jobID)
id, _ := env.taskStore.Create(ctx, task)
staleTime := time.Now().Add(-60 * time.Second)
env.db.Exec("UPDATE hpc_tasks SET updated_at = ? WHERE id = ?", staleTime, id)
err := env.svc.RefreshStaleTasks(ctx)
if err != nil {
t.Fatalf("RefreshStaleTasks: %v", err)
}
got, _ := env.taskStore.GetByID(ctx, id)
if got.Status != model.TaskStatusCompleted {
t.Errorf("status = %q, want %q", got.Status, model.TaskStatusCompleted)
}
}