feat(app): add TaskPoller, wire DI, and add task integration tests
This commit is contained in:
@@ -45,6 +45,7 @@ func TestRouterRegistration(t *testing.T) {
|
|||||||
appH,
|
appH,
|
||||||
nil, nil, nil,
|
nil, nil, nil,
|
||||||
nil,
|
nil,
|
||||||
|
nil,
|
||||||
)
|
)
|
||||||
|
|
||||||
routes := router.Routes()
|
routes := router.Routes()
|
||||||
@@ -106,6 +107,7 @@ func TestSmokeGetJobsEndpoint(t *testing.T) {
|
|||||||
appH,
|
appH,
|
||||||
nil, nil, nil,
|
nil, nil, nil,
|
||||||
nil,
|
nil,
|
||||||
|
nil,
|
||||||
)
|
)
|
||||||
|
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
|
|||||||
431
cmd/server/task_test.go
Normal file
431
cmd/server/task_test.go
Normal file
@@ -0,0 +1,431 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"sync/atomic"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"gcy_hpc_server/internal/handler"
|
||||||
|
"gcy_hpc_server/internal/model"
|
||||||
|
"gcy_hpc_server/internal/server"
|
||||||
|
"gcy_hpc_server/internal/service"
|
||||||
|
"gcy_hpc_server/internal/slurm"
|
||||||
|
"gcy_hpc_server/internal/store"
|
||||||
|
|
||||||
|
"go.uber.org/zap"
|
||||||
|
"gorm.io/driver/sqlite"
|
||||||
|
"gorm.io/gorm"
|
||||||
|
"gorm.io/gorm/logger"
|
||||||
|
)
|
||||||
|
|
||||||
|
func newTaskTestDB(t *testing.T) *gorm.DB {
|
||||||
|
t.Helper()
|
||||||
|
db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{Logger: logger.Default.LogMode(logger.Silent)})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("open sqlite: %v", err)
|
||||||
|
}
|
||||||
|
sqlDB, _ := db.DB()
|
||||||
|
sqlDB.SetMaxOpenConns(1)
|
||||||
|
db.AutoMigrate(
|
||||||
|
&model.Application{},
|
||||||
|
&model.File{},
|
||||||
|
&model.FileBlob{},
|
||||||
|
&model.Task{},
|
||||||
|
)
|
||||||
|
t.Cleanup(func() {
|
||||||
|
sqlDB.Close()
|
||||||
|
})
|
||||||
|
return db
|
||||||
|
}
|
||||||
|
|
||||||
|
type mockSlurmHandler struct {
|
||||||
|
submitFn func(w http.ResponseWriter, r *http.Request)
|
||||||
|
}
|
||||||
|
|
||||||
|
func newMockSlurmServer(t *testing.T, h *mockSlurmHandler) *httptest.Server {
|
||||||
|
t.Helper()
|
||||||
|
mux := http.NewServeMux()
|
||||||
|
mux.HandleFunc("/slurm/v0.0.40/job/submit", func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if h != nil && h.submitFn != nil {
|
||||||
|
h.submitFn(w, r)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
fmt.Fprint(w, `{"job_id": 42, "step_id": "0", "result": {"job_id": 42, "step_id": "0"}}`)
|
||||||
|
})
|
||||||
|
mux.HandleFunc("/slurm/v0.0.40/jobs", func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
json.NewEncoder(w).Encode(map[string]interface{}{"jobs": []interface{}{}})
|
||||||
|
})
|
||||||
|
mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
json.NewEncoder(w).Encode(map[string]interface{}{})
|
||||||
|
})
|
||||||
|
srv := httptest.NewServer(mux)
|
||||||
|
t.Cleanup(srv.Close)
|
||||||
|
return srv
|
||||||
|
}
|
||||||
|
|
||||||
|
func setupTaskTestServer(t *testing.T) (*httptest.Server, *store.TaskStore, func()) {
|
||||||
|
t.Helper()
|
||||||
|
return setupTaskTestServerWithSlurm(t, nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
func setupTaskTestServerWithSlurm(t *testing.T, slurmHandler *mockSlurmHandler) (*httptest.Server, *store.TaskStore, func()) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
db := newTaskTestDB(t)
|
||||||
|
slurmSrv := newMockSlurmServer(t, slurmHandler)
|
||||||
|
|
||||||
|
client, err := slurm.NewClientWithOpts(slurmSrv.URL, slurm.WithHTTPClient(slurmSrv.Client()))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("slurm client: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
log := zap.NewNop()
|
||||||
|
jobSvc := service.NewJobService(client, log)
|
||||||
|
appStore := store.NewApplicationStore(db)
|
||||||
|
taskStore := store.NewTaskStore(db)
|
||||||
|
|
||||||
|
tmpDir, err := os.MkdirTemp("", "task-test-workdir-*")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("temp dir: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
taskSvc := service.NewTaskService(
|
||||||
|
taskStore, appStore,
|
||||||
|
nil, nil, nil,
|
||||||
|
jobSvc,
|
||||||
|
tmpDir,
|
||||||
|
log,
|
||||||
|
)
|
||||||
|
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
taskSvc.StartProcessor(ctx)
|
||||||
|
|
||||||
|
appSvc := service.NewApplicationService(appStore, jobSvc, tmpDir, log, taskSvc)
|
||||||
|
appH := handler.NewApplicationHandler(appSvc, log)
|
||||||
|
taskH := handler.NewTaskHandler(taskSvc, log)
|
||||||
|
|
||||||
|
router := server.NewRouter(
|
||||||
|
handler.NewJobHandler(jobSvc, log),
|
||||||
|
handler.NewClusterHandler(service.NewClusterService(client, log), log),
|
||||||
|
appH,
|
||||||
|
nil, nil, nil,
|
||||||
|
taskH,
|
||||||
|
log,
|
||||||
|
)
|
||||||
|
|
||||||
|
httpSrv := httptest.NewServer(router)
|
||||||
|
cleanup := func() {
|
||||||
|
taskSvc.StopProcessor()
|
||||||
|
cancel()
|
||||||
|
httpSrv.Close()
|
||||||
|
os.RemoveAll(tmpDir)
|
||||||
|
}
|
||||||
|
|
||||||
|
return httpSrv, taskStore, cleanup
|
||||||
|
}
|
||||||
|
|
||||||
|
func createTestApp(t *testing.T, srvURL string) int64 {
|
||||||
|
t.Helper()
|
||||||
|
body := `{"name":"test-app","script_template":"#!/bin/bash\necho {{.np}}","parameters":[{"name":"np","type":"string","default":"1"}]}`
|
||||||
|
resp, err := http.Post(srvURL+"/api/v1/applications", "application/json", bytes.NewReader([]byte(body)))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("create app: %v", err)
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
if resp.StatusCode != http.StatusCreated {
|
||||||
|
b, _ := io.ReadAll(resp.Body)
|
||||||
|
t.Fatalf("create app: status %d, body %s", resp.StatusCode, b)
|
||||||
|
}
|
||||||
|
var result struct {
|
||||||
|
Data struct {
|
||||||
|
ID int64 `json:"id"`
|
||||||
|
} `json:"data"`
|
||||||
|
}
|
||||||
|
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
|
||||||
|
t.Fatalf("decode app response: %v", err)
|
||||||
|
}
|
||||||
|
return result.Data.ID
|
||||||
|
}
|
||||||
|
|
||||||
|
func postTask(t *testing.T, srvURL string, body string) (*http.Response, map[string]interface{}) {
|
||||||
|
t.Helper()
|
||||||
|
resp, err := http.Post(srvURL+"/api/v1/tasks", "application/json", bytes.NewReader([]byte(body)))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("post task: %v", err)
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
b, _ := io.ReadAll(resp.Body)
|
||||||
|
var result map[string]interface{}
|
||||||
|
json.Unmarshal(b, &result)
|
||||||
|
return resp, result
|
||||||
|
}
|
||||||
|
|
||||||
|
func getTasks(t *testing.T, srvURL string, query string) (int, []interface{}) {
|
||||||
|
t.Helper()
|
||||||
|
url := srvURL + "/api/v1/tasks"
|
||||||
|
if query != "" {
|
||||||
|
url += "?" + query
|
||||||
|
}
|
||||||
|
resp, err := http.Get(url)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("get tasks: %v", err)
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
b, _ := io.ReadAll(resp.Body)
|
||||||
|
var result map[string]interface{}
|
||||||
|
json.Unmarshal(b, &result)
|
||||||
|
data, _ := result["data"].(map[string]interface{})
|
||||||
|
items, _ := data["items"].([]interface{})
|
||||||
|
total, _ := data["total"].(float64)
|
||||||
|
return int(total), items
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTask_FullLifecycle(t *testing.T) {
|
||||||
|
srv, _, cleanup := setupTaskTestServer(t)
|
||||||
|
defer cleanup()
|
||||||
|
|
||||||
|
appID := createTestApp(t, srv.URL)
|
||||||
|
|
||||||
|
resp, result := postTask(t, srv.URL, fmt.Sprintf(
|
||||||
|
`{"app_id":%d,"task_name":"lifecycle-test","values":{"np":"4"}}`, appID,
|
||||||
|
))
|
||||||
|
if resp.StatusCode != http.StatusCreated {
|
||||||
|
t.Fatalf("expected 201, got %d: %v", resp.StatusCode, result)
|
||||||
|
}
|
||||||
|
|
||||||
|
data, _ := result["data"].(map[string]interface{})
|
||||||
|
if data["id"] == nil {
|
||||||
|
t.Fatal("expected id in response")
|
||||||
|
}
|
||||||
|
|
||||||
|
time.Sleep(300 * time.Millisecond)
|
||||||
|
|
||||||
|
total, items := getTasks(t, srv.URL, "")
|
||||||
|
if total < 1 {
|
||||||
|
t.Fatalf("expected at least 1 task, got %d", total)
|
||||||
|
}
|
||||||
|
|
||||||
|
task := items[0].(map[string]interface{})
|
||||||
|
if task["status"] == nil || task["status"] == "" {
|
||||||
|
t.Error("expected non-empty status")
|
||||||
|
}
|
||||||
|
if task["task_name"] != "lifecycle-test" {
|
||||||
|
t.Errorf("expected task_name=lifecycle-test, got %v", task["task_name"])
|
||||||
|
}
|
||||||
|
if task["app_id"] != float64(appID) {
|
||||||
|
t.Errorf("expected app_id=%d, got %v", appID, task["app_id"])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTask_CreateWithMissingApp(t *testing.T) {
|
||||||
|
srv, _, cleanup := setupTaskTestServer(t)
|
||||||
|
defer cleanup()
|
||||||
|
|
||||||
|
resp, result := postTask(t, srv.URL, `{"app_id":9999,"task_name":"no-app"}`)
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusNotFound {
|
||||||
|
t.Fatalf("expected 404, got %d: %v", resp.StatusCode, result)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTask_CreateWithInvalidBody(t *testing.T) {
|
||||||
|
srv, _, cleanup := setupTaskTestServer(t)
|
||||||
|
defer cleanup()
|
||||||
|
|
||||||
|
resp, result := postTask(t, srv.URL, `{}`)
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusBadRequest {
|
||||||
|
t.Fatalf("expected 400, got %d: %v", resp.StatusCode, result)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTask_FileLimitExceeded(t *testing.T) {
|
||||||
|
srv, _, cleanup := setupTaskTestServer(t)
|
||||||
|
defer cleanup()
|
||||||
|
|
||||||
|
appID := createTestApp(t, srv.URL)
|
||||||
|
|
||||||
|
fileIDs := make([]int64, 101)
|
||||||
|
for i := range fileIDs {
|
||||||
|
fileIDs[i] = int64(i + 1)
|
||||||
|
}
|
||||||
|
idsJSON, _ := json.Marshal(fileIDs)
|
||||||
|
|
||||||
|
body := fmt.Sprintf(`{"app_id":%d,"file_ids":%s}`, appID, string(idsJSON))
|
||||||
|
resp, result := postTask(t, srv.URL, body)
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusBadRequest {
|
||||||
|
t.Fatalf("expected 400 for file limit, got %d: %v", resp.StatusCode, result)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTask_RetryScenario(t *testing.T) {
|
||||||
|
var failCount int32
|
||||||
|
slurmH := &mockSlurmHandler{
|
||||||
|
submitFn: func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if atomic.AddInt32(&failCount, 1) <= 2 {
|
||||||
|
w.WriteHeader(http.StatusInternalServerError)
|
||||||
|
fmt.Fprint(w, `{"errors":[{"error":"internal error"}]}`)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
fmt.Fprint(w, `{"job_id": 99, "step_id": "0", "result": {"job_id": 99, "step_id": "0"}}`)
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
srv, taskStore, cleanup := setupTaskTestServerWithSlurm(t, slurmH)
|
||||||
|
defer cleanup()
|
||||||
|
|
||||||
|
appID := createTestApp(t, srv.URL)
|
||||||
|
|
||||||
|
resp, result := postTask(t, srv.URL, fmt.Sprintf(
|
||||||
|
`{"app_id":%d,"task_name":"retry-test","values":{"np":"2"}}`, appID,
|
||||||
|
))
|
||||||
|
if resp.StatusCode != http.StatusCreated {
|
||||||
|
t.Fatalf("expected 201, got %d: %v", resp.StatusCode, result)
|
||||||
|
}
|
||||||
|
|
||||||
|
taskID := int64(result["data"].(map[string]interface{})["id"].(float64))
|
||||||
|
|
||||||
|
deadline := time.Now().Add(5 * time.Second)
|
||||||
|
for time.Now().Before(deadline) {
|
||||||
|
task, _ := taskStore.GetByID(context.Background(), taskID)
|
||||||
|
if task != nil && task.Status == model.TaskStatusQueued {
|
||||||
|
if task.SlurmJobID != nil && *task.SlurmJobID == 99 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
time.Sleep(100 * time.Millisecond)
|
||||||
|
}
|
||||||
|
|
||||||
|
task, _ := taskStore.GetByID(context.Background(), taskID)
|
||||||
|
if task == nil {
|
||||||
|
t.Fatalf("task %d not found after deadline", taskID)
|
||||||
|
}
|
||||||
|
t.Fatalf("task did not reach queued with slurm_job_id=99; status=%s retry_count=%d slurm_job_id=%v",
|
||||||
|
task.Status, task.RetryCount, task.SlurmJobID)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTask_OldAPICompatibility(t *testing.T) {
|
||||||
|
srv, _, cleanup := setupTaskTestServer(t)
|
||||||
|
defer cleanup()
|
||||||
|
|
||||||
|
appID := createTestApp(t, srv.URL)
|
||||||
|
|
||||||
|
body := `{"values":{"np":"8"}}`
|
||||||
|
url := fmt.Sprintf("%s/api/v1/applications/%d/submit", srv.URL, appID)
|
||||||
|
resp, err := http.Post(url, "application/json", bytes.NewReader([]byte(body)))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("post submit: %v", err)
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusCreated {
|
||||||
|
b, _ := io.ReadAll(resp.Body)
|
||||||
|
t.Fatalf("expected 201, got %d: %s", resp.StatusCode, b)
|
||||||
|
}
|
||||||
|
|
||||||
|
var result map[string]interface{}
|
||||||
|
json.NewDecoder(resp.Body).Decode(&result)
|
||||||
|
|
||||||
|
data, _ := result["data"].(map[string]interface{})
|
||||||
|
if data == nil {
|
||||||
|
t.Fatal("expected data in response")
|
||||||
|
}
|
||||||
|
if data["job_id"] == nil {
|
||||||
|
t.Errorf("expected job_id in old API response, got: %v", data)
|
||||||
|
}
|
||||||
|
jobID, ok := data["job_id"].(float64)
|
||||||
|
if !ok || jobID == 0 {
|
||||||
|
t.Errorf("expected non-zero job_id, got %v", data["job_id"])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTask_ListWithFilters(t *testing.T) {
|
||||||
|
srv, taskStore, cleanup := setupTaskTestServer(t)
|
||||||
|
defer cleanup()
|
||||||
|
|
||||||
|
appID := createTestApp(t, srv.URL)
|
||||||
|
|
||||||
|
now := time.Now()
|
||||||
|
taskStore.Create(context.Background(), &model.Task{
|
||||||
|
TaskName: "task-completed", AppID: appID, AppName: "test-app",
|
||||||
|
Status: model.TaskStatusCompleted, SubmittedAt: now,
|
||||||
|
})
|
||||||
|
taskStore.Create(context.Background(), &model.Task{
|
||||||
|
TaskName: "task-failed", AppID: appID, AppName: "test-app",
|
||||||
|
Status: model.TaskStatusFailed, ErrorMessage: "boom", SubmittedAt: now,
|
||||||
|
})
|
||||||
|
taskStore.Create(context.Background(), &model.Task{
|
||||||
|
TaskName: "task-queued", AppID: appID, AppName: "test-app",
|
||||||
|
Status: model.TaskStatusQueued, SubmittedAt: now,
|
||||||
|
})
|
||||||
|
|
||||||
|
total, items := getTasks(t, srv.URL, "status=completed")
|
||||||
|
if total != 1 {
|
||||||
|
t.Fatalf("expected 1 completed, got %d", total)
|
||||||
|
}
|
||||||
|
task := items[0].(map[string]interface{})
|
||||||
|
if task["status"] != "completed" {
|
||||||
|
t.Errorf("expected status=completed, got %v", task["status"])
|
||||||
|
}
|
||||||
|
|
||||||
|
total, items = getTasks(t, srv.URL, "status=failed")
|
||||||
|
if total != 1 {
|
||||||
|
t.Fatalf("expected 1 failed, got %d", total)
|
||||||
|
}
|
||||||
|
|
||||||
|
total, _ = getTasks(t, srv.URL, "")
|
||||||
|
if total != 3 {
|
||||||
|
t.Fatalf("expected 3 total, got %d", total)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTask_WorkDirCreated(t *testing.T) {
|
||||||
|
srv, taskStore, cleanup := setupTaskTestServer(t)
|
||||||
|
defer cleanup()
|
||||||
|
|
||||||
|
appID := createTestApp(t, srv.URL)
|
||||||
|
|
||||||
|
resp, result := postTask(t, srv.URL, fmt.Sprintf(
|
||||||
|
`{"app_id":%d,"task_name":"workdir-test","values":{"np":"1"}}`, appID,
|
||||||
|
))
|
||||||
|
if resp.StatusCode != http.StatusCreated {
|
||||||
|
t.Fatalf("expected 201, got %d: %v", resp.StatusCode, result)
|
||||||
|
}
|
||||||
|
|
||||||
|
taskID := int64(result["data"].(map[string]interface{})["id"].(float64))
|
||||||
|
|
||||||
|
deadline := time.Now().Add(5 * time.Second)
|
||||||
|
for time.Now().Before(deadline) {
|
||||||
|
task, _ := taskStore.GetByID(context.Background(), taskID)
|
||||||
|
if task != nil && task.WorkDir != "" {
|
||||||
|
if _, err := os.Stat(task.WorkDir); os.IsNotExist(err) {
|
||||||
|
t.Fatalf("work dir %s not created", task.WorkDir)
|
||||||
|
}
|
||||||
|
if !filepath.IsAbs(task.WorkDir) {
|
||||||
|
t.Errorf("expected absolute work dir, got %s", task.WorkDir)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
time.Sleep(50 * time.Millisecond)
|
||||||
|
}
|
||||||
|
|
||||||
|
task, _ := taskStore.GetByID(context.Background(), taskID)
|
||||||
|
if task == nil {
|
||||||
|
t.Fatalf("task %d not found after deadline", taskID)
|
||||||
|
}
|
||||||
|
t.Fatalf("work dir never set; status=%s workdir=%s", task.Status, task.WorkDir)
|
||||||
|
}
|
||||||
@@ -28,6 +28,8 @@ type App struct {
|
|||||||
db *gorm.DB
|
db *gorm.DB
|
||||||
server *http.Server
|
server *http.Server
|
||||||
cancelCleanup context.CancelFunc
|
cancelCleanup context.CancelFunc
|
||||||
|
taskSvc *service.TaskService
|
||||||
|
taskPoller *TaskPoller
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewApp initializes all application dependencies: DB, Slurm client, services, handlers, router.
|
// NewApp initializes all application dependencies: DB, Slurm client, services, handlers, router.
|
||||||
@@ -43,7 +45,7 @@ func NewApp(cfg *config.Config, logger *zap.Logger) (*App, error) {
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
srv, cancelCleanup := initHTTPServer(cfg, gormDB, slurmClient, logger)
|
srv, cancelCleanup, taskSvc, taskPoller := initHTTPServer(cfg, gormDB, slurmClient, logger)
|
||||||
|
|
||||||
return &App{
|
return &App{
|
||||||
cfg: cfg,
|
cfg: cfg,
|
||||||
@@ -51,6 +53,8 @@ func NewApp(cfg *config.Config, logger *zap.Logger) (*App, error) {
|
|||||||
db: gormDB,
|
db: gormDB,
|
||||||
server: srv,
|
server: srv,
|
||||||
cancelCleanup: cancelCleanup,
|
cancelCleanup: cancelCleanup,
|
||||||
|
taskSvc: taskSvc,
|
||||||
|
taskPoller: taskPoller,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -86,6 +90,14 @@ func (a *App) Run() error {
|
|||||||
func (a *App) Close() error {
|
func (a *App) Close() error {
|
||||||
var errs []error
|
var errs []error
|
||||||
|
|
||||||
|
if a.taskSvc != nil {
|
||||||
|
a.taskSvc.StopProcessor()
|
||||||
|
}
|
||||||
|
|
||||||
|
if a.taskPoller != nil {
|
||||||
|
a.taskPoller.Stop()
|
||||||
|
}
|
||||||
|
|
||||||
if a.cancelCleanup != nil {
|
if a.cancelCleanup != nil {
|
||||||
a.cancelCleanup()
|
a.cancelCleanup()
|
||||||
}
|
}
|
||||||
@@ -145,15 +157,15 @@ func initSlurmClient(cfg *config.Config) (*slurm.Client, error) {
|
|||||||
return client, nil
|
return client, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func initHTTPServer(cfg *config.Config, db *gorm.DB, slurmClient *slurm.Client, logger *zap.Logger) (*http.Server, context.CancelFunc) {
|
func initHTTPServer(cfg *config.Config, db *gorm.DB, slurmClient *slurm.Client, logger *zap.Logger) (*http.Server, context.CancelFunc, *service.TaskService, *TaskPoller) {
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
jobSvc := service.NewJobService(slurmClient, logger)
|
jobSvc := service.NewJobService(slurmClient, logger)
|
||||||
clusterSvc := service.NewClusterService(slurmClient, logger)
|
clusterSvc := service.NewClusterService(slurmClient, logger)
|
||||||
jobH := handler.NewJobHandler(jobSvc, logger)
|
jobH := handler.NewJobHandler(jobSvc, logger)
|
||||||
clusterH := handler.NewClusterHandler(clusterSvc, logger)
|
clusterH := handler.NewClusterHandler(clusterSvc, logger)
|
||||||
|
|
||||||
appStore := store.NewApplicationStore(db)
|
appStore := store.NewApplicationStore(db)
|
||||||
appSvc := service.NewApplicationService(appStore, jobSvc, cfg.WorkDirBase, logger)
|
|
||||||
appH := handler.NewApplicationHandler(appSvc, logger)
|
|
||||||
|
|
||||||
// File storage initialization
|
// File storage initialization
|
||||||
minioClient, err := storage.NewMinioClient(cfg.Minio)
|
minioClient, err := storage.NewMinioClient(cfg.Minio)
|
||||||
@@ -165,9 +177,12 @@ func initHTTPServer(cfg *config.Config, db *gorm.DB, slurmClient *slurm.Client,
|
|||||||
var fileH *handler.FileHandler
|
var fileH *handler.FileHandler
|
||||||
var folderH *handler.FolderHandler
|
var folderH *handler.FolderHandler
|
||||||
|
|
||||||
|
taskStore := store.NewTaskStore(db)
|
||||||
|
fileStore := store.NewFileStore(db)
|
||||||
|
blobStore := store.NewBlobStore(db)
|
||||||
|
|
||||||
|
var stagingSvc *service.FileStagingService
|
||||||
if minioClient != nil {
|
if minioClient != nil {
|
||||||
blobStore := store.NewBlobStore(db)
|
|
||||||
fileStore := store.NewFileStore(db)
|
|
||||||
folderStore := store.NewFolderStore(db)
|
folderStore := store.NewFolderStore(db)
|
||||||
uploadStore := store.NewUploadStore(db)
|
uploadStore := store.NewUploadStore(db)
|
||||||
|
|
||||||
@@ -179,25 +194,34 @@ func initHTTPServer(cfg *config.Config, db *gorm.DB, slurmClient *slurm.Client,
|
|||||||
fileH = handler.NewFileHandler(fileSvc, logger)
|
fileH = handler.NewFileHandler(fileSvc, logger)
|
||||||
folderH = handler.NewFolderHandler(folderSvc, logger)
|
folderH = handler.NewFolderHandler(folderSvc, logger)
|
||||||
|
|
||||||
cleanupCtx, cancelCleanup := context.WithCancel(context.Background())
|
stagingSvc = service.NewFileStagingService(fileStore, blobStore, minioClient, cfg.Minio.Bucket, logger)
|
||||||
go startCleanupWorker(cleanupCtx, uploadStore, minioClient, cfg.Minio.Bucket, logger)
|
|
||||||
|
|
||||||
router := server.NewRouter(jobH, clusterH, appH, uploadH, fileH, folderH, logger)
|
|
||||||
|
|
||||||
addr := ":" + cfg.ServerPort
|
|
||||||
|
|
||||||
return &http.Server{
|
|
||||||
Addr: addr,
|
|
||||||
Handler: router,
|
|
||||||
}, cancelCleanup
|
|
||||||
}
|
}
|
||||||
|
|
||||||
router := server.NewRouter(jobH, clusterH, appH, uploadH, fileH, folderH, logger)
|
taskSvc := service.NewTaskService(taskStore, appStore, fileStore, blobStore, stagingSvc, jobSvc, cfg.WorkDirBase, logger)
|
||||||
|
taskSvc.StartProcessor(ctx)
|
||||||
|
|
||||||
|
appSvc := service.NewApplicationService(appStore, jobSvc, cfg.WorkDirBase, logger, taskSvc)
|
||||||
|
appH := handler.NewApplicationHandler(appSvc, logger)
|
||||||
|
|
||||||
|
poller := NewTaskPoller(taskSvc, 10*time.Second, logger)
|
||||||
|
poller.Start(ctx)
|
||||||
|
|
||||||
|
taskH := handler.NewTaskHandler(taskSvc, logger)
|
||||||
|
|
||||||
|
var cancelCleanup context.CancelFunc
|
||||||
|
|
||||||
|
if minioClient != nil {
|
||||||
|
cleanupCtx, cancel := context.WithCancel(context.Background())
|
||||||
|
cancelCleanup = cancel
|
||||||
|
go startCleanupWorker(cleanupCtx, store.NewUploadStore(db), minioClient, cfg.Minio.Bucket, logger)
|
||||||
|
}
|
||||||
|
|
||||||
|
router := server.NewRouter(jobH, clusterH, appH, uploadH, fileH, folderH, taskH, logger)
|
||||||
|
|
||||||
addr := ":" + cfg.ServerPort
|
addr := ":" + cfg.ServerPort
|
||||||
|
|
||||||
return &http.Server{
|
return &http.Server{
|
||||||
Addr: addr,
|
Addr: addr,
|
||||||
Handler: router,
|
Handler: router,
|
||||||
}, nil
|
}, cancelCleanup, taskSvc, poller
|
||||||
}
|
}
|
||||||
|
|||||||
61
internal/app/task_poller.go
Normal file
61
internal/app/task_poller.go
Normal file
@@ -0,0 +1,61 @@
|
|||||||
|
package app
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"go.uber.org/zap"
|
||||||
|
)
|
||||||
|
|
||||||
|
// TaskPollable defines the interface for refreshing stale task statuses.
|
||||||
|
type TaskPollable interface {
|
||||||
|
RefreshStaleTasks(ctx context.Context) error
|
||||||
|
}
|
||||||
|
|
||||||
|
// TaskPoller periodically polls Slurm for task status updates via TaskPollable.
|
||||||
|
type TaskPoller struct {
|
||||||
|
taskSvc TaskPollable
|
||||||
|
interval time.Duration
|
||||||
|
cancel context.CancelFunc
|
||||||
|
wg sync.WaitGroup
|
||||||
|
logger *zap.Logger
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewTaskPoller creates a new TaskPoller with the given service, interval, and logger.
|
||||||
|
func NewTaskPoller(taskSvc TaskPollable, interval time.Duration, logger *zap.Logger) *TaskPoller {
|
||||||
|
return &TaskPoller{
|
||||||
|
taskSvc: taskSvc,
|
||||||
|
interval: interval,
|
||||||
|
logger: logger,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Start launches the background goroutine that periodically refreshes stale tasks.
|
||||||
|
func (p *TaskPoller) Start(ctx context.Context) {
|
||||||
|
ctx, p.cancel = context.WithCancel(ctx)
|
||||||
|
p.wg.Add(1)
|
||||||
|
go func() {
|
||||||
|
defer p.wg.Done()
|
||||||
|
ticker := time.NewTicker(p.interval)
|
||||||
|
defer ticker.Stop()
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return
|
||||||
|
case <-ticker.C:
|
||||||
|
if err := p.taskSvc.RefreshStaleTasks(ctx); err != nil {
|
||||||
|
p.logger.Error("failed to refresh stale tasks", zap.Error(err))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Stop cancels the background goroutine and waits for it to finish.
|
||||||
|
func (p *TaskPoller) Stop() {
|
||||||
|
if p.cancel != nil {
|
||||||
|
p.cancel()
|
||||||
|
}
|
||||||
|
p.wg.Wait()
|
||||||
|
}
|
||||||
70
internal/app/task_poller_test.go
Normal file
70
internal/app/task_poller_test.go
Normal file
@@ -0,0 +1,70 @@
|
|||||||
|
package app
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"sync"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"go.uber.org/zap"
|
||||||
|
)
|
||||||
|
|
||||||
|
type mockTaskPollable struct {
|
||||||
|
refreshFunc func(ctx context.Context) error
|
||||||
|
callCount int
|
||||||
|
mu sync.Mutex
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockTaskPollable) RefreshStaleTasks(ctx context.Context) error {
|
||||||
|
m.mu.Lock()
|
||||||
|
defer m.mu.Unlock()
|
||||||
|
m.callCount++
|
||||||
|
if m.refreshFunc != nil {
|
||||||
|
return m.refreshFunc(ctx)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockTaskPollable) getCallCount() int {
|
||||||
|
m.mu.Lock()
|
||||||
|
defer m.mu.Unlock()
|
||||||
|
return m.callCount
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTaskPoller_StartStop(t *testing.T) {
|
||||||
|
mock := &mockTaskPollable{}
|
||||||
|
logger := zap.NewNop()
|
||||||
|
poller := NewTaskPoller(mock, 1*time.Second, logger)
|
||||||
|
|
||||||
|
poller.Start(context.Background())
|
||||||
|
time.Sleep(100 * time.Millisecond)
|
||||||
|
poller.Stop()
|
||||||
|
|
||||||
|
// No goroutine leak — Stop() returned means wg.Wait() completed.
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTaskPoller_RefreshesStaleTasks(t *testing.T) {
|
||||||
|
mock := &mockTaskPollable{}
|
||||||
|
logger := zap.NewNop()
|
||||||
|
poller := NewTaskPoller(mock, 50*time.Millisecond, logger)
|
||||||
|
|
||||||
|
poller.Start(context.Background())
|
||||||
|
defer poller.Stop()
|
||||||
|
|
||||||
|
time.Sleep(300 * time.Millisecond)
|
||||||
|
|
||||||
|
if count := mock.getCallCount(); count < 1 {
|
||||||
|
t.Errorf("expected RefreshStaleTasks to be called at least once, got %d", count)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTaskPoller_StopsCleanly(t *testing.T) {
|
||||||
|
mock := &mockTaskPollable{}
|
||||||
|
logger := zap.NewNop()
|
||||||
|
poller := NewTaskPoller(mock, 1*time.Second, logger)
|
||||||
|
|
||||||
|
poller.Start(context.Background())
|
||||||
|
poller.Stop()
|
||||||
|
|
||||||
|
// No panic and WaitGroup is done — Stop returned successfully.
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user