diff --git a/internal/handler/task_handler.go b/internal/handler/task_handler.go new file mode 100644 index 0000000..d5027be --- /dev/null +++ b/internal/handler/task_handler.go @@ -0,0 +1,98 @@ +package handler + +import ( + "context" + "strings" + + "gcy_hpc_server/internal/model" + "gcy_hpc_server/internal/server" + + "github.com/gin-gonic/gin" + "go.uber.org/zap" +) + +type taskServiceProvider interface { + SubmitAsync(ctx context.Context, req *model.CreateTaskRequest) (int64, error) + ListTasks(ctx context.Context, query *model.TaskListQuery) ([]model.Task, int64, error) +} + +type TaskHandler struct { + svc taskServiceProvider + logger *zap.Logger +} + +func NewTaskHandler(svc taskServiceProvider, logger *zap.Logger) *TaskHandler { + return &TaskHandler{svc: svc, logger: logger} +} + +func (h *TaskHandler) CreateTask(c *gin.Context) { + var req model.CreateTaskRequest + if err := c.ShouldBindJSON(&req); err != nil { + h.logger.Warn("invalid request body for create task", zap.Error(err)) + server.BadRequest(c, err.Error()) + return + } + + taskID, err := h.svc.SubmitAsync(c.Request.Context(), &req) + if err != nil { + errStr := err.Error() + if strings.Contains(errStr, "not found") { + h.logger.Warn("task submit target not found", zap.Error(err)) + server.NotFound(c, errStr) + return + } + if strings.Contains(errStr, "exceeds limit") || strings.Contains(errStr, "validation") { + h.logger.Warn("task submit validation failed", zap.Error(err)) + server.BadRequest(c, errStr) + return + } + h.logger.Error("failed to create task", zap.Error(err)) + server.InternalError(c, errStr) + return + } + + h.logger.Info("task created", zap.Int64("id", taskID)) + server.Created(c, gin.H{"id": taskID}) +} + +func (h *TaskHandler) ListTasks(c *gin.Context) { + var query model.TaskListQuery + _ = c.ShouldBindQuery(&query) + + if query.Page < 1 { + query.Page = 1 + } + if query.PageSize < 1 || query.PageSize > 100 { + query.PageSize = 10 + } + + tasks, total, err := h.svc.ListTasks(c.Request.Context(), &query) + if err != nil { + h.logger.Error("failed to list tasks", zap.Error(err)) + server.InternalError(c, err.Error()) + return + } + + responses := make([]model.TaskResponse, 0, len(tasks)) + for i := range tasks { + responses = append(responses, model.TaskResponse{ + ID: tasks[i].ID, + TaskName: tasks[i].TaskName, + AppID: tasks[i].AppID, + AppName: tasks[i].AppName, + Status: tasks[i].Status, + CurrentStep: tasks[i].CurrentStep, + RetryCount: tasks[i].RetryCount, + SlurmJobID: tasks[i].SlurmJobID, + WorkDir: tasks[i].WorkDir, + ErrorMessage: tasks[i].ErrorMessage, + CreatedAt: tasks[i].CreatedAt, + UpdatedAt: tasks[i].UpdatedAt, + }) + } + + server.OK(c, model.TaskListResponse{ + Items: responses, + Total: total, + }) +} diff --git a/internal/handler/task_handler_test.go b/internal/handler/task_handler_test.go new file mode 100644 index 0000000..166f81a --- /dev/null +++ b/internal/handler/task_handler_test.go @@ -0,0 +1,286 @@ +package handler + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "net/http" + "net/http/httptest" + "os" + "path/filepath" + "sync/atomic" + "testing" + "time" + + "gcy_hpc_server/internal/model" + "gcy_hpc_server/internal/service" + "gcy_hpc_server/internal/slurm" + "gcy_hpc_server/internal/store" + + "github.com/gin-gonic/gin" + "go.uber.org/zap" + "gorm.io/driver/sqlite" + "gorm.io/gorm" + gormlogger "gorm.io/gorm/logger" +) + +var taskDBCounter atomic.Int64 + +func setupTaskHandler(t *testing.T, slurmSrv *httptest.Server) (*TaskHandler, *gorm.DB) { + t.Helper() + dbFile := filepath.Join(t.TempDir(), fmt.Sprintf("test-%d.db", taskDBCounter.Add(1))) + db, err := gorm.Open(sqlite.Open(dbFile), &gorm.Config{Logger: gormlogger.Default.LogMode(gormlogger.Silent)}) + if err != nil { + t.Fatalf("open db: %v", err) + } + db.AutoMigrate(&model.Task{}, &model.Application{}) + t.Cleanup(func() { os.Remove(dbFile) }) + + taskStore := store.NewTaskStore(db) + appStore := store.NewApplicationStore(db) + + var jobSvc *service.JobService + if slurmSrv != nil { + client, _ := slurm.NewClient(slurmSrv.URL, slurmSrv.Client()) + jobSvc = service.NewJobService(client, zap.NewNop()) + } + + workDir := filepath.Join(t.TempDir(), "work") + taskSvc := service.NewTaskService(taskStore, appStore, nil, nil, nil, jobSvc, workDir, zap.NewNop()) + h := NewTaskHandler(taskSvc, zap.NewNop()) + return h, db +} + +func setupTaskRouter(h *TaskHandler) *gin.Engine { + gin.SetMode(gin.TestMode) + r := gin.New() + v1 := r.Group("/api/v1") + tasks := v1.Group("/tasks") + tasks.POST("", h.CreateTask) + tasks.GET("", h.ListTasks) + return r +} + +func createTestAppForTask(db *gorm.DB) int64 { + app := &model.Application{ + Name: "test-app", + ScriptTemplate: "#!/bin/bash\necho hello", + Parameters: json.RawMessage(`[]`), + } + db.Create(app) + return app.ID +} + +// ---- CreateTask Tests ---- + +func TestTaskHandler_CreateTask_Success(t *testing.T) { + slurmSrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusCreated) + json.NewEncoder(w).Encode(map[string]interface{}{"job_id": 12345}) + })) + defer slurmSrv.Close() + + h, db := setupTaskHandler(t, slurmSrv) + r := setupTaskRouter(h) + + appID := createTestAppForTask(db) + + taskSvc := h.svc.(*service.TaskService) + ctx := context.Background() + taskSvc.StartProcessor(ctx) + defer taskSvc.StopProcessor() + + body, _ := json.Marshal(model.CreateTaskRequest{ + AppID: appID, + TaskName: "my-task", + }) + w := httptest.NewRecorder() + req, _ := http.NewRequest(http.MethodPost, "/api/v1/tasks", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + r.ServeHTTP(w, req) + + if w.Code != http.StatusCreated { + t.Fatalf("expected 201, got %d: %s", w.Code, w.Body.String()) + } + var resp map[string]interface{} + json.Unmarshal(w.Body.Bytes(), &resp) + if !resp["success"].(bool) { + t.Fatal("expected success=true") + } + data := resp["data"].(map[string]interface{}) + if _, ok := data["id"]; !ok { + t.Fatal("expected id in response data") + } +} + +func TestTaskHandler_CreateTask_MissingAppID(t *testing.T) { + h, _ := setupTaskHandler(t, nil) + r := setupTaskRouter(h) + + body := `{"task_name":"no-app"}` + w := httptest.NewRecorder() + req, _ := http.NewRequest(http.MethodPost, "/api/v1/tasks", bytes.NewReader([]byte(body))) + req.Header.Set("Content-Type", "application/json") + r.ServeHTTP(w, req) + + if w.Code != http.StatusBadRequest { + t.Fatalf("expected 400, got %d: %s", w.Code, w.Body.String()) + } +} + +func TestTaskHandler_CreateTask_InvalidJSON(t *testing.T) { + h, _ := setupTaskHandler(t, nil) + r := setupTaskRouter(h) + + w := httptest.NewRecorder() + req, _ := http.NewRequest(http.MethodPost, "/api/v1/tasks", bytes.NewReader([]byte("not-json"))) + req.Header.Set("Content-Type", "application/json") + r.ServeHTTP(w, req) + + if w.Code != http.StatusBadRequest { + t.Fatalf("expected 400, got %d: %s", w.Code, w.Body.String()) + } +} + +// ---- ListTasks Tests ---- + +func TestTaskHandler_ListTasks_Pagination(t *testing.T) { + slurmSrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusCreated) + json.NewEncoder(w).Encode(map[string]interface{}{"job_id": int32(100)}) + })) + defer slurmSrv.Close() + + h, db := setupTaskHandler(t, slurmSrv) + r := setupTaskRouter(h) + + appID := createTestAppForTask(db) + + taskSvc := h.svc.(*service.TaskService) + ctx := context.Background() + taskSvc.StartProcessor(ctx) + defer taskSvc.StopProcessor() + + for i := 0; i < 5; i++ { + body, _ := json.Marshal(model.CreateTaskRequest{ + AppID: appID, + TaskName: fmt.Sprintf("task-%d", i), + }) + w := httptest.NewRecorder() + req, _ := http.NewRequest(http.MethodPost, "/api/v1/tasks", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + r.ServeHTTP(w, req) + } + + // Wait for async processing + time.Sleep(200 * time.Millisecond) + + w := httptest.NewRecorder() + req, _ := http.NewRequest(http.MethodGet, "/api/v1/tasks?page=1&page_size=3", nil) + r.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("expected 200, got %d: %s", w.Code, w.Body.String()) + } + var resp map[string]interface{} + json.Unmarshal(w.Body.Bytes(), &resp) + data := resp["data"].(map[string]interface{}) + if data["total"].(float64) != 5 { + t.Fatalf("expected total=5, got %v", data["total"]) + } + items := data["items"].([]interface{}) + if len(items) != 3 { + t.Fatalf("expected 3 items, got %d", len(items)) + } +} + +func TestTaskHandler_ListTasks_StatusFilter(t *testing.T) { + slurmSrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusCreated) + json.NewEncoder(w).Encode(map[string]interface{}{"job_id": int32(200)}) + })) + defer slurmSrv.Close() + + h, db := setupTaskHandler(t, slurmSrv) + r := setupTaskRouter(h) + + appID := createTestAppForTask(db) + + taskSvc := h.svc.(*service.TaskService) + ctx := context.Background() + taskSvc.StartProcessor(ctx) + defer taskSvc.StopProcessor() + + for i := 0; i < 3; i++ { + body, _ := json.Marshal(model.CreateTaskRequest{ + AppID: appID, + TaskName: fmt.Sprintf("filter-task-%d", i), + }) + w := httptest.NewRecorder() + req, _ := http.NewRequest(http.MethodPost, "/api/v1/tasks", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + r.ServeHTTP(w, req) + } + + // Wait for async processing + time.Sleep(200 * time.Millisecond) + + w := httptest.NewRecorder() + req, _ := http.NewRequest(http.MethodGet, "/api/v1/tasks?status=queued", nil) + r.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("expected 200, got %d: %s", w.Code, w.Body.String()) + } + var resp map[string]interface{} + json.Unmarshal(w.Body.Bytes(), &resp) + data := resp["data"].(map[string]interface{}) + items := data["items"].([]interface{}) + for _, item := range items { + m := item.(map[string]interface{}) + if m["status"] != "queued" { + t.Fatalf("expected status=queued, got %v", m["status"]) + } + } +} + +func TestTaskHandler_ListTasks_DefaultPagination(t *testing.T) { + h, db := setupTaskHandler(t, nil) + r := setupTaskRouter(h) + + _ = createTestAppForTask(db) + + // Directly insert tasks via DB to avoid needing processor + for i := 0; i < 15; i++ { + task := &model.Task{ + TaskName: fmt.Sprintf("default-task-%d", i), + AppID: 1, + AppName: "test-app", + Status: model.TaskStatusSubmitted, + SubmittedAt: time.Now(), + } + db.Create(task) + } + + w := httptest.NewRecorder() + req, _ := http.NewRequest(http.MethodGet, "/api/v1/tasks", nil) + r.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("expected 200, got %d: %s", w.Code, w.Body.String()) + } + var resp map[string]interface{} + json.Unmarshal(w.Body.Bytes(), &resp) + data := resp["data"].(map[string]interface{}) + if data["total"].(float64) != 15 { + t.Fatalf("expected total=15, got %v", data["total"]) + } + items := data["items"].([]interface{}) + if len(items) != 10 { + t.Fatalf("expected 10 items (default page_size), got %d", len(items)) + } +} diff --git a/internal/server/server.go b/internal/server/server.go index ab1cb46..e74c119 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -56,8 +56,13 @@ type FolderHandler interface { DeleteFolder(c *gin.Context) } +type TaskHandler interface { + CreateTask(c *gin.Context) + ListTasks(c *gin.Context) +} + // NewRouter creates a Gin engine with all API v1 routes registered with real handlers. -func NewRouter(jobH JobHandler, clusterH ClusterHandler, appH ApplicationHandler, uploadH UploadHandler, fileH FileHandler, folderH FolderHandler, logger *zap.Logger) *gin.Engine { +func NewRouter(jobH JobHandler, clusterH ClusterHandler, appH ApplicationHandler, uploadH UploadHandler, fileH FileHandler, folderH FolderHandler, taskH TaskHandler, logger *zap.Logger) *gin.Engine { gin.SetMode(gin.ReleaseMode) r := gin.New() r.Use(gin.Recovery()) @@ -116,6 +121,14 @@ func NewRouter(jobH JobHandler, clusterH ClusterHandler, appH ApplicationHandler folders.DELETE("/:id", folderH.DeleteFolder) } + if taskH != nil { + tasks := v1.Group("/tasks") + { + tasks.POST("", taskH.CreateTask) + tasks.GET("", taskH.ListTasks) + } + } + return r } @@ -172,6 +185,9 @@ func registerPlaceholderRoutes(v1 *gin.RouterGroup) { folders.GET("", notImplemented) folders.GET("/:id", notImplemented) folders.DELETE("/:id", notImplemented) + + v1.POST("/tasks", notImplemented) + v1.GET("/tasks", notImplemented) } func notImplemented(c *gin.Context) {