diff --git a/cmd/server/main_test.go b/cmd/server/main_test.go
index 5e531a1..7cb179a 100644
--- a/cmd/server/main_test.go
+++ b/cmd/server/main_test.go
@@ -45,6 +45,7 @@ func TestRouterRegistration(t *testing.T) {
 		appH,
 		nil, nil, nil, nil,
+		nil,
 	)
 
 	routes := router.Routes()
 
@@ -106,6 +107,7 @@ func TestSmokeGetJobsEndpoint(t *testing.T) {
 		appH,
 		nil, nil, nil, nil,
+		nil,
 	)
 
 	w := httptest.NewRecorder()
 
diff --git a/cmd/server/task_test.go b/cmd/server/task_test.go
new file mode 100644
index 0000000..e2ce2e1
--- /dev/null
+++ b/cmd/server/task_test.go
@@ -0,0 +1,441 @@
+package main
+
+import (
+	"bytes"
+	"context"
+	"encoding/json"
+	"fmt"
+	"io"
+	"net/http"
+	"net/http/httptest"
+	"os"
+	"path/filepath"
+	"sync/atomic"
+	"testing"
+	"time"
+
+	"gcy_hpc_server/internal/handler"
+	"gcy_hpc_server/internal/model"
+	"gcy_hpc_server/internal/server"
+	"gcy_hpc_server/internal/service"
+	"gcy_hpc_server/internal/slurm"
+	"gcy_hpc_server/internal/store"
+
+	"go.uber.org/zap"
+	"gorm.io/driver/sqlite"
+	"gorm.io/gorm"
+	"gorm.io/gorm/logger"
+)
+
+// newTaskTestDB returns an in-memory SQLite database with the task-related tables migrated.
+func newTaskTestDB(t *testing.T) *gorm.DB {
+	t.Helper()
+	db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{Logger: logger.Default.LogMode(logger.Silent)})
+	if err != nil {
+		t.Fatalf("open sqlite: %v", err)
+	}
+	sqlDB, _ := db.DB()
+	sqlDB.SetMaxOpenConns(1)
+	if err := db.AutoMigrate(
+		&model.Application{},
+		&model.File{},
+		&model.FileBlob{},
+		&model.Task{},
+	); err != nil {
+		t.Fatalf("auto-migrate: %v", err)
+	}
+	t.Cleanup(func() {
+		sqlDB.Close()
+	})
+	return db
+}
+
+// mockSlurmHandler lets a test override the job-submit response.
+type mockSlurmHandler struct {
+	submitFn func(w http.ResponseWriter, r *http.Request)
+}
+
+// newMockSlurmServer serves a minimal slurmrestd-compatible API for tests.
+func newMockSlurmServer(t *testing.T, h *mockSlurmHandler) *httptest.Server {
+	t.Helper()
+	mux := http.NewServeMux()
+	mux.HandleFunc("/slurm/v0.0.40/job/submit", func(w http.ResponseWriter, r *http.Request) {
+		if h != nil && h.submitFn != nil {
+			h.submitFn(w, r)
+			return
+		}
+		w.Header().Set("Content-Type", "application/json")
+		fmt.Fprint(w, `{"job_id": 42, "step_id": "0", "result": {"job_id": 42, "step_id": "0"}}`)
+	})
+	mux.HandleFunc("/slurm/v0.0.40/jobs", func(w http.ResponseWriter, r *http.Request) {
+		w.Header().Set("Content-Type", "application/json")
+		json.NewEncoder(w).Encode(map[string]interface{}{"jobs": []interface{}{}})
+	})
+	mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
+		w.Header().Set("Content-Type", "application/json")
+		json.NewEncoder(w).Encode(map[string]interface{}{})
+	})
+	srv := httptest.NewServer(mux)
+	t.Cleanup(srv.Close)
+	return srv
+}
+
+// setupTaskTestServer starts the full HTTP stack against the default mock Slurm backend.
+func setupTaskTestServer(t *testing.T) (*httptest.Server, *store.TaskStore, func()) {
+	t.Helper()
+	return setupTaskTestServerWithSlurm(t, nil)
+}
+
+// setupTaskTestServerWithSlurm starts the full HTTP stack against a custom mock Slurm handler.
+func setupTaskTestServerWithSlurm(t *testing.T, slurmHandler *mockSlurmHandler) (*httptest.Server, *store.TaskStore, func()) {
+	t.Helper()
+
+	db := newTaskTestDB(t)
+	slurmSrv := newMockSlurmServer(t, slurmHandler)
+
+	client, err := slurm.NewClientWithOpts(slurmSrv.URL, slurm.WithHTTPClient(slurmSrv.Client()))
+	if err != nil {
+		t.Fatalf("slurm client: %v", err)
+	}
+
+	log := zap.NewNop()
+	jobSvc := service.NewJobService(client, log)
+	appStore := store.NewApplicationStore(db)
+	taskStore := store.NewTaskStore(db)
+
+	tmpDir, err := os.MkdirTemp("", "task-test-workdir-*")
+	if err != nil {
+		t.Fatalf("temp dir: %v", err)
+	}
+
+	taskSvc := service.NewTaskService(
+		taskStore, appStore,
+		nil, nil, nil,
+		jobSvc,
+		tmpDir,
+		log,
+	)
+
+	ctx, cancel := context.WithCancel(context.Background())
+	taskSvc.StartProcessor(ctx)
+
+	appSvc := service.NewApplicationService(appStore, jobSvc, tmpDir, log, taskSvc)
+	appH := handler.NewApplicationHandler(appSvc, log)
+	taskH := handler.NewTaskHandler(taskSvc, log)
+
+	router := server.NewRouter(
+		handler.NewJobHandler(jobSvc, log),
+		handler.NewClusterHandler(service.NewClusterService(client, log), log),
+		appH,
+		nil, nil, nil,
+		taskH,
+		log,
+	)
+
+	httpSrv := httptest.NewServer(router)
+	cleanup := func() {
+		taskSvc.StopProcessor()
+		cancel()
+		httpSrv.Close()
+		os.RemoveAll(tmpDir)
+	}
+
+	return httpSrv, taskStore, cleanup
+}
+
+// createTestApp registers a minimal application and returns its ID.
+func createTestApp(t *testing.T, srvURL string) int64 {
+	t.Helper()
+	body := `{"name":"test-app","script_template":"#!/bin/bash\necho {{.np}}","parameters":[{"name":"np","type":"string","default":"1"}]}`
+	resp, err := http.Post(srvURL+"/api/v1/applications", "application/json", bytes.NewReader([]byte(body)))
+	if err != nil {
+		t.Fatalf("create app: %v", err)
+	}
+	defer resp.Body.Close()
+	if resp.StatusCode != http.StatusCreated {
+		b, _ := io.ReadAll(resp.Body)
+		t.Fatalf("create app: status %d, body %s", resp.StatusCode, b)
+	}
+	var result struct {
+		Data struct {
+			ID int64 `json:"id"`
+		} `json:"data"`
+	}
+	if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
+		t.Fatalf("decode app response: %v", err)
+	}
+	return result.Data.ID
+}
+
+// postTask submits a task-creation request and returns the response plus its decoded body.
+func postTask(t *testing.T, srvURL string, body string) (*http.Response, map[string]interface{}) {
+	t.Helper()
+	resp, err := http.Post(srvURL+"/api/v1/tasks", "application/json", bytes.NewReader([]byte(body)))
+	if err != nil {
+		t.Fatalf("post task: %v", err)
+	}
+	defer resp.Body.Close()
+	b, _ := io.ReadAll(resp.Body)
+	var result map[string]interface{}
+	json.Unmarshal(b, &result)
+	return resp, result
+}
+
+// getTasks queries the task list endpoint and returns the total count and the items page.
+func getTasks(t *testing.T, srvURL string, query string) (int, []interface{}) {
+	t.Helper()
+	url := srvURL + "/api/v1/tasks"
+	if query != "" {
+		url += "?" + query
+	}
+	resp, err := http.Get(url)
+	if err != nil {
+		t.Fatalf("get tasks: %v", err)
+	}
+	defer resp.Body.Close()
+	b, _ := io.ReadAll(resp.Body)
+	var result map[string]interface{}
+	json.Unmarshal(b, &result)
+	data, _ := result["data"].(map[string]interface{})
+	items, _ := data["items"].([]interface{})
+	total, _ := data["total"].(float64)
+	return int(total), items
+}
+
+func TestTask_FullLifecycle(t *testing.T) {
+	srv, _, cleanup := setupTaskTestServer(t)
+	defer cleanup()
+
+	appID := createTestApp(t, srv.URL)
+
+	resp, result := postTask(t, srv.URL, fmt.Sprintf(
+		`{"app_id":%d,"task_name":"lifecycle-test","values":{"np":"4"}}`, appID,
+	))
+	if resp.StatusCode != http.StatusCreated {
+		t.Fatalf("expected 201, got %d: %v", resp.StatusCode, result)
+	}
+
+	data, _ := result["data"].(map[string]interface{})
+	if data["id"] == nil {
+		t.Fatal("expected id in response")
+	}
+
+	time.Sleep(300 * time.Millisecond)
+
+	total, items := getTasks(t, srv.URL, "")
+	if total < 1 {
+		t.Fatalf("expected at least 1 task, got %d", total)
+	}
+
+	task := items[0].(map[string]interface{})
+	if task["status"] == nil || task["status"] == "" {
+		t.Error("expected non-empty status")
+	}
+	if task["task_name"] != "lifecycle-test" {
+		t.Errorf("expected task_name=lifecycle-test, got %v", task["task_name"])
+	}
+	if task["app_id"] != float64(appID) {
+		t.Errorf("expected app_id=%d, got %v", appID, task["app_id"])
+	}
+}
+
+func TestTask_CreateWithMissingApp(t *testing.T) {
+	srv, _, cleanup := setupTaskTestServer(t)
+	defer cleanup()
+
+	resp, result := postTask(t, srv.URL, `{"app_id":9999,"task_name":"no-app"}`)
+
+	if resp.StatusCode != http.StatusNotFound {
+		t.Fatalf("expected 404, got %d: %v", resp.StatusCode, result)
+	}
+}
+
+func TestTask_CreateWithInvalidBody(t *testing.T) {
+	srv, _, cleanup := setupTaskTestServer(t)
+	defer cleanup()
+
+	resp, result := postTask(t, srv.URL, `{}`)
+
+	if resp.StatusCode != http.StatusBadRequest {
+		t.Fatalf("expected 400, got %d: %v", resp.StatusCode, result)
+	}
+}
+
+func TestTask_FileLimitExceeded(t *testing.T) {
+	srv, _, cleanup := setupTaskTestServer(t)
+	defer cleanup()
+
+	appID := createTestApp(t, srv.URL)
+
+	fileIDs := make([]int64, 101)
+	for i := range fileIDs {
+		fileIDs[i] = int64(i + 1)
+	}
+	idsJSON, _ := json.Marshal(fileIDs)
+
+	body := fmt.Sprintf(`{"app_id":%d,"file_ids":%s}`, appID, string(idsJSON))
+	resp, result := postTask(t, srv.URL, body)
+
+	if resp.StatusCode != http.StatusBadRequest {
+		t.Fatalf("expected 400 for file limit, got %d: %v", resp.StatusCode, result)
+	}
+}
+
+func TestTask_RetryScenario(t *testing.T) {
+	var failCount int32
+	slurmH := &mockSlurmHandler{
+		submitFn: func(w http.ResponseWriter, r *http.Request) {
+			if atomic.AddInt32(&failCount, 1) <= 2 {
+				w.WriteHeader(http.StatusInternalServerError)
+				fmt.Fprint(w, `{"errors":[{"error":"internal error"}]}`)
+				return
+			}
+			w.Header().Set("Content-Type", "application/json")
+			fmt.Fprint(w, `{"job_id": 99, "step_id": "0", "result": {"job_id": 99, "step_id": "0"}}`)
+		},
+	}
+
+	srv, taskStore, cleanup := setupTaskTestServerWithSlurm(t, slurmH)
+	defer cleanup()
+
+	appID := createTestApp(t, srv.URL)
+
+	resp, result := postTask(t, srv.URL, fmt.Sprintf(
+		`{"app_id":%d,"task_name":"retry-test","values":{"np":"2"}}`, appID,
+	))
+	if resp.StatusCode != http.StatusCreated {
+		t.Fatalf("expected 201, got %d: %v", resp.StatusCode, result)
+	}
+
+	taskID := int64(result["data"].(map[string]interface{})["id"].(float64))
+
+	deadline := time.Now().Add(5 * time.Second)
+	for time.Now().Before(deadline) {
+		task, _ := taskStore.GetByID(context.Background(), taskID)
+		if task != nil && task.Status == model.TaskStatusQueued {
+			if task.SlurmJobID != nil && *task.SlurmJobID == 99 {
+				return
+			}
+		}
+		time.Sleep(100 * time.Millisecond)
+	}
+
+	task, _ := taskStore.GetByID(context.Background(), taskID)
+	if task == nil {
+		t.Fatalf("task %d not found after deadline", taskID)
+	}
+	t.Fatalf("task did not reach queued with slurm_job_id=99; status=%s retry_count=%d slurm_job_id=%v",
+		task.Status, task.RetryCount, task.SlurmJobID)
+}
+
+func TestTask_OldAPICompatibility(t *testing.T) {
+	srv, _, cleanup := setupTaskTestServer(t)
+	defer cleanup()
+
+	appID := createTestApp(t, srv.URL)
+
+	body := `{"values":{"np":"8"}}`
+	url := fmt.Sprintf("%s/api/v1/applications/%d/submit", srv.URL, appID)
+	resp, err := http.Post(url, "application/json", bytes.NewReader([]byte(body)))
+	if err != nil {
+		t.Fatalf("post submit: %v", err)
+	}
+	defer resp.Body.Close()
+
+	if resp.StatusCode != http.StatusCreated {
+		b, _ := io.ReadAll(resp.Body)
+		t.Fatalf("expected 201, got %d: %s", resp.StatusCode, b)
+	}
+
+	var result map[string]interface{}
+	json.NewDecoder(resp.Body).Decode(&result)
+
+	data, _ := result["data"].(map[string]interface{})
+	if data == nil {
+		t.Fatal("expected data in response")
+	}
+	if data["job_id"] == nil {
+		t.Errorf("expected job_id in old API response, got: %v", data)
+	}
+	jobID, ok := data["job_id"].(float64)
+	if !ok || jobID == 0 {
+		t.Errorf("expected non-zero job_id, got %v", data["job_id"])
+	}
+}
+
+func TestTask_ListWithFilters(t *testing.T) {
+	srv, taskStore, cleanup := setupTaskTestServer(t)
+	defer cleanup()
+
+	appID := createTestApp(t, srv.URL)
+
+	now := time.Now()
+	taskStore.Create(context.Background(), &model.Task{
+		TaskName: "task-completed", AppID: appID, AppName: "test-app",
+		Status: model.TaskStatusCompleted, SubmittedAt: now,
+	})
+	taskStore.Create(context.Background(), &model.Task{
+		TaskName: "task-failed", AppID: appID, AppName: "test-app",
+		Status: model.TaskStatusFailed, ErrorMessage: "boom", SubmittedAt: now,
+	})
+	taskStore.Create(context.Background(), &model.Task{
+		TaskName: "task-queued", AppID: appID, AppName: "test-app",
+		Status: model.TaskStatusQueued, SubmittedAt: now,
+	})
+
+	total, items := getTasks(t, srv.URL, "status=completed")
+	if total != 1 {
+		t.Fatalf("expected 1 completed, got %d", total)
+	}
+	task := items[0].(map[string]interface{})
+	if task["status"] != "completed" {
+		t.Errorf("expected status=completed, got %v", task["status"])
+	}
+
+	total, _ = getTasks(t, srv.URL, "status=failed")
+	if total != 1 {
+		t.Fatalf("expected 1 failed, got %d", total)
+	}
+
+	total, _ = getTasks(t, srv.URL, "")
+	if total != 3 {
+		t.Fatalf("expected 3 total, got %d", total)
+	}
+}
+
+func TestTask_WorkDirCreated(t *testing.T) {
+	srv, taskStore, cleanup := setupTaskTestServer(t)
+	defer cleanup()
+
+	appID := createTestApp(t, srv.URL)
+
+	resp, result := postTask(t, srv.URL, fmt.Sprintf(
+		`{"app_id":%d,"task_name":"workdir-test","values":{"np":"1"}}`, appID,
+	))
+	if resp.StatusCode != http.StatusCreated {
+		t.Fatalf("expected 201, got %d: %v", resp.StatusCode, result)
+	}
+
+	taskID := int64(result["data"].(map[string]interface{})["id"].(float64))
+
+	deadline := time.Now().Add(5 * time.Second)
+	for time.Now().Before(deadline) {
+		task, _ := taskStore.GetByID(context.Background(), taskID)
+		if task != nil && task.WorkDir != "" {
+			if _, err := os.Stat(task.WorkDir); os.IsNotExist(err) {
+				t.Fatalf("work dir %s not created", task.WorkDir)
dir %s not created", task.WorkDir) + } + if !filepath.IsAbs(task.WorkDir) { + t.Errorf("expected absolute work dir, got %s", task.WorkDir) + } + return + } + time.Sleep(50 * time.Millisecond) + } + + task, _ := taskStore.GetByID(context.Background(), taskID) + if task == nil { + t.Fatalf("task %d not found after deadline", taskID) + } + t.Fatalf("work dir never set; status=%s workdir=%s", task.Status, task.WorkDir) +} diff --git a/internal/app/app.go b/internal/app/app.go index e8fdfbc..785b6b9 100644 --- a/internal/app/app.go +++ b/internal/app/app.go @@ -28,6 +28,8 @@ type App struct { db *gorm.DB server *http.Server cancelCleanup context.CancelFunc + taskSvc *service.TaskService + taskPoller *TaskPoller } // NewApp initializes all application dependencies: DB, Slurm client, services, handlers, router. @@ -43,7 +45,7 @@ func NewApp(cfg *config.Config, logger *zap.Logger) (*App, error) { return nil, err } - srv, cancelCleanup := initHTTPServer(cfg, gormDB, slurmClient, logger) + srv, cancelCleanup, taskSvc, taskPoller := initHTTPServer(cfg, gormDB, slurmClient, logger) return &App{ cfg: cfg, @@ -51,6 +53,8 @@ func NewApp(cfg *config.Config, logger *zap.Logger) (*App, error) { db: gormDB, server: srv, cancelCleanup: cancelCleanup, + taskSvc: taskSvc, + taskPoller: taskPoller, }, nil } @@ -86,6 +90,14 @@ func (a *App) Run() error { func (a *App) Close() error { var errs []error + if a.taskSvc != nil { + a.taskSvc.StopProcessor() + } + + if a.taskPoller != nil { + a.taskPoller.Stop() + } + if a.cancelCleanup != nil { a.cancelCleanup() } @@ -145,15 +157,15 @@ func initSlurmClient(cfg *config.Config) (*slurm.Client, error) { return client, nil } -func initHTTPServer(cfg *config.Config, db *gorm.DB, slurmClient *slurm.Client, logger *zap.Logger) (*http.Server, context.CancelFunc) { +func initHTTPServer(cfg *config.Config, db *gorm.DB, slurmClient *slurm.Client, logger *zap.Logger) (*http.Server, context.CancelFunc, *service.TaskService, *TaskPoller) { + ctx := context.Background() + jobSvc := service.NewJobService(slurmClient, logger) clusterSvc := service.NewClusterService(slurmClient, logger) jobH := handler.NewJobHandler(jobSvc, logger) clusterH := handler.NewClusterHandler(clusterSvc, logger) appStore := store.NewApplicationStore(db) - appSvc := service.NewApplicationService(appStore, jobSvc, cfg.WorkDirBase, logger) - appH := handler.NewApplicationHandler(appSvc, logger) // File storage initialization minioClient, err := storage.NewMinioClient(cfg.Minio) @@ -165,9 +177,12 @@ func initHTTPServer(cfg *config.Config, db *gorm.DB, slurmClient *slurm.Client, var fileH *handler.FileHandler var folderH *handler.FolderHandler + taskStore := store.NewTaskStore(db) + fileStore := store.NewFileStore(db) + blobStore := store.NewBlobStore(db) + + var stagingSvc *service.FileStagingService if minioClient != nil { - blobStore := store.NewBlobStore(db) - fileStore := store.NewFileStore(db) folderStore := store.NewFolderStore(db) uploadStore := store.NewUploadStore(db) @@ -179,25 +194,34 @@ func initHTTPServer(cfg *config.Config, db *gorm.DB, slurmClient *slurm.Client, fileH = handler.NewFileHandler(fileSvc, logger) folderH = handler.NewFolderHandler(folderSvc, logger) - cleanupCtx, cancelCleanup := context.WithCancel(context.Background()) - go startCleanupWorker(cleanupCtx, uploadStore, minioClient, cfg.Minio.Bucket, logger) - - router := server.NewRouter(jobH, clusterH, appH, uploadH, fileH, folderH, logger) - - addr := ":" + cfg.ServerPort - - return &http.Server{ - Addr: addr, - Handler: router, - }, 
cancelCleanup + stagingSvc = service.NewFileStagingService(fileStore, blobStore, minioClient, cfg.Minio.Bucket, logger) } - router := server.NewRouter(jobH, clusterH, appH, uploadH, fileH, folderH, logger) + taskSvc := service.NewTaskService(taskStore, appStore, fileStore, blobStore, stagingSvc, jobSvc, cfg.WorkDirBase, logger) + taskSvc.StartProcessor(ctx) + + appSvc := service.NewApplicationService(appStore, jobSvc, cfg.WorkDirBase, logger, taskSvc) + appH := handler.NewApplicationHandler(appSvc, logger) + + poller := NewTaskPoller(taskSvc, 10*time.Second, logger) + poller.Start(ctx) + + taskH := handler.NewTaskHandler(taskSvc, logger) + + var cancelCleanup context.CancelFunc + + if minioClient != nil { + cleanupCtx, cancel := context.WithCancel(context.Background()) + cancelCleanup = cancel + go startCleanupWorker(cleanupCtx, store.NewUploadStore(db), minioClient, cfg.Minio.Bucket, logger) + } + + router := server.NewRouter(jobH, clusterH, appH, uploadH, fileH, folderH, taskH, logger) addr := ":" + cfg.ServerPort return &http.Server{ Addr: addr, Handler: router, - }, nil + }, cancelCleanup, taskSvc, poller } diff --git a/internal/app/task_poller.go b/internal/app/task_poller.go new file mode 100644 index 0000000..4c2c313 --- /dev/null +++ b/internal/app/task_poller.go @@ -0,0 +1,61 @@ +package app + +import ( + "context" + "sync" + "time" + + "go.uber.org/zap" +) + +// TaskPollable defines the interface for refreshing stale task statuses. +type TaskPollable interface { + RefreshStaleTasks(ctx context.Context) error +} + +// TaskPoller periodically polls Slurm for task status updates via TaskPollable. +type TaskPoller struct { + taskSvc TaskPollable + interval time.Duration + cancel context.CancelFunc + wg sync.WaitGroup + logger *zap.Logger +} + +// NewTaskPoller creates a new TaskPoller with the given service, interval, and logger. +func NewTaskPoller(taskSvc TaskPollable, interval time.Duration, logger *zap.Logger) *TaskPoller { + return &TaskPoller{ + taskSvc: taskSvc, + interval: interval, + logger: logger, + } +} + +// Start launches the background goroutine that periodically refreshes stale tasks. +func (p *TaskPoller) Start(ctx context.Context) { + ctx, p.cancel = context.WithCancel(ctx) + p.wg.Add(1) + go func() { + defer p.wg.Done() + ticker := time.NewTicker(p.interval) + defer ticker.Stop() + for { + select { + case <-ctx.Done(): + return + case <-ticker.C: + if err := p.taskSvc.RefreshStaleTasks(ctx); err != nil { + p.logger.Error("failed to refresh stale tasks", zap.Error(err)) + } + } + } + }() +} + +// Stop cancels the background goroutine and waits for it to finish. 
+func (p *TaskPoller) Stop() { + if p.cancel != nil { + p.cancel() + } + p.wg.Wait() +} diff --git a/internal/app/task_poller_test.go b/internal/app/task_poller_test.go new file mode 100644 index 0000000..9786c19 --- /dev/null +++ b/internal/app/task_poller_test.go @@ -0,0 +1,70 @@ +package app + +import ( + "context" + "sync" + "testing" + "time" + + "go.uber.org/zap" +) + +type mockTaskPollable struct { + refreshFunc func(ctx context.Context) error + callCount int + mu sync.Mutex +} + +func (m *mockTaskPollable) RefreshStaleTasks(ctx context.Context) error { + m.mu.Lock() + defer m.mu.Unlock() + m.callCount++ + if m.refreshFunc != nil { + return m.refreshFunc(ctx) + } + return nil +} + +func (m *mockTaskPollable) getCallCount() int { + m.mu.Lock() + defer m.mu.Unlock() + return m.callCount +} + +func TestTaskPoller_StartStop(t *testing.T) { + mock := &mockTaskPollable{} + logger := zap.NewNop() + poller := NewTaskPoller(mock, 1*time.Second, logger) + + poller.Start(context.Background()) + time.Sleep(100 * time.Millisecond) + poller.Stop() + + // No goroutine leak — Stop() returned means wg.Wait() completed. +} + +func TestTaskPoller_RefreshesStaleTasks(t *testing.T) { + mock := &mockTaskPollable{} + logger := zap.NewNop() + poller := NewTaskPoller(mock, 50*time.Millisecond, logger) + + poller.Start(context.Background()) + defer poller.Stop() + + time.Sleep(300 * time.Millisecond) + + if count := mock.getCallCount(); count < 1 { + t.Errorf("expected RefreshStaleTasks to be called at least once, got %d", count) + } +} + +func TestTaskPoller_StopsCleanly(t *testing.T) { + mock := &mockTaskPollable{} + logger := zap.NewNop() + poller := NewTaskPoller(mock, 1*time.Second, logger) + + poller.Start(context.Background()) + poller.Stop() + + // No panic and WaitGroup is done — Stop returned successfully. +}
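
Note for reviewers, separate from the diff itself: a minimal sketch of the poller lifecycle this change relies on. It assumes only TaskPollable, NewTaskPoller, Start, and Stop from internal/app/task_poller.go; the refreshFunc adapter is a hypothetical helper introduced here for illustration and is not part of the change.

package app

import (
	"context"
	"time"

	"go.uber.org/zap"
)

// refreshFunc adapts a plain function to the TaskPollable interface
// (hypothetical helper, for illustration only).
type refreshFunc func(ctx context.Context) error

func (f refreshFunc) RefreshStaleTasks(ctx context.Context) error { return f(ctx) }

// ExampleTaskPoller_lifecycle shows the intended ordering: Start launches the
// ticker goroutine; Stop cancels its context and then blocks on the WaitGroup
// until the goroutine has exited.
func ExampleTaskPoller_lifecycle() {
	p := NewTaskPoller(refreshFunc(func(ctx context.Context) error {
		// A real TaskService would re-query Slurm for stale tasks here.
		return nil
	}), 50*time.Millisecond, zap.NewNop())

	p.Start(context.Background())
	time.Sleep(120 * time.Millisecond) // allow a couple of ticks to fire
	p.Stop()                           // returns only after the goroutine exits
	// Output:
}

Because Stop blocks on the WaitGroup, App.Close can call it before shutting down the HTTP server without leaking the ticker goroutine.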