feat(handler): add TaskHandler endpoints and register task routes

This commit is contained in:
dailz
2026-04-15 21:31:11 +08:00
parent ec64300ff2
commit 3f8a680c99
3 changed files with 401 additions and 1 deletions

View File

@@ -0,0 +1,98 @@
package handler
import (
"context"
"strings"
"gcy_hpc_server/internal/model"
"gcy_hpc_server/internal/server"
"github.com/gin-gonic/gin"
"go.uber.org/zap"
)
type taskServiceProvider interface {
SubmitAsync(ctx context.Context, req *model.CreateTaskRequest) (int64, error)
ListTasks(ctx context.Context, query *model.TaskListQuery) ([]model.Task, int64, error)
}
type TaskHandler struct {
svc taskServiceProvider
logger *zap.Logger
}
func NewTaskHandler(svc taskServiceProvider, logger *zap.Logger) *TaskHandler {
return &TaskHandler{svc: svc, logger: logger}
}
func (h *TaskHandler) CreateTask(c *gin.Context) {
var req model.CreateTaskRequest
if err := c.ShouldBindJSON(&req); err != nil {
h.logger.Warn("invalid request body for create task", zap.Error(err))
server.BadRequest(c, err.Error())
return
}
taskID, err := h.svc.SubmitAsync(c.Request.Context(), &req)
if err != nil {
errStr := err.Error()
if strings.Contains(errStr, "not found") {
h.logger.Warn("task submit target not found", zap.Error(err))
server.NotFound(c, errStr)
return
}
if strings.Contains(errStr, "exceeds limit") || strings.Contains(errStr, "validation") {
h.logger.Warn("task submit validation failed", zap.Error(err))
server.BadRequest(c, errStr)
return
}
h.logger.Error("failed to create task", zap.Error(err))
server.InternalError(c, errStr)
return
}
h.logger.Info("task created", zap.Int64("id", taskID))
server.Created(c, gin.H{"id": taskID})
}
func (h *TaskHandler) ListTasks(c *gin.Context) {
var query model.TaskListQuery
_ = c.ShouldBindQuery(&query)
if query.Page < 1 {
query.Page = 1
}
if query.PageSize < 1 || query.PageSize > 100 {
query.PageSize = 10
}
tasks, total, err := h.svc.ListTasks(c.Request.Context(), &query)
if err != nil {
h.logger.Error("failed to list tasks", zap.Error(err))
server.InternalError(c, err.Error())
return
}
responses := make([]model.TaskResponse, 0, len(tasks))
for i := range tasks {
responses = append(responses, model.TaskResponse{
ID: tasks[i].ID,
TaskName: tasks[i].TaskName,
AppID: tasks[i].AppID,
AppName: tasks[i].AppName,
Status: tasks[i].Status,
CurrentStep: tasks[i].CurrentStep,
RetryCount: tasks[i].RetryCount,
SlurmJobID: tasks[i].SlurmJobID,
WorkDir: tasks[i].WorkDir,
ErrorMessage: tasks[i].ErrorMessage,
CreatedAt: tasks[i].CreatedAt,
UpdatedAt: tasks[i].UpdatedAt,
})
}
server.OK(c, model.TaskListResponse{
Items: responses,
Total: total,
})
}

View File

@@ -0,0 +1,286 @@
package handler
import (
"bytes"
"context"
"encoding/json"
"fmt"
"net/http"
"net/http/httptest"
"os"
"path/filepath"
"sync/atomic"
"testing"
"time"
"gcy_hpc_server/internal/model"
"gcy_hpc_server/internal/service"
"gcy_hpc_server/internal/slurm"
"gcy_hpc_server/internal/store"
"github.com/gin-gonic/gin"
"go.uber.org/zap"
"gorm.io/driver/sqlite"
"gorm.io/gorm"
gormlogger "gorm.io/gorm/logger"
)
var taskDBCounter atomic.Int64
func setupTaskHandler(t *testing.T, slurmSrv *httptest.Server) (*TaskHandler, *gorm.DB) {
t.Helper()
dbFile := filepath.Join(t.TempDir(), fmt.Sprintf("test-%d.db", taskDBCounter.Add(1)))
db, err := gorm.Open(sqlite.Open(dbFile), &gorm.Config{Logger: gormlogger.Default.LogMode(gormlogger.Silent)})
if err != nil {
t.Fatalf("open db: %v", err)
}
db.AutoMigrate(&model.Task{}, &model.Application{})
t.Cleanup(func() { os.Remove(dbFile) })
taskStore := store.NewTaskStore(db)
appStore := store.NewApplicationStore(db)
var jobSvc *service.JobService
if slurmSrv != nil {
client, _ := slurm.NewClient(slurmSrv.URL, slurmSrv.Client())
jobSvc = service.NewJobService(client, zap.NewNop())
}
workDir := filepath.Join(t.TempDir(), "work")
taskSvc := service.NewTaskService(taskStore, appStore, nil, nil, nil, jobSvc, workDir, zap.NewNop())
h := NewTaskHandler(taskSvc, zap.NewNop())
return h, db
}
func setupTaskRouter(h *TaskHandler) *gin.Engine {
gin.SetMode(gin.TestMode)
r := gin.New()
v1 := r.Group("/api/v1")
tasks := v1.Group("/tasks")
tasks.POST("", h.CreateTask)
tasks.GET("", h.ListTasks)
return r
}
func createTestAppForTask(db *gorm.DB) int64 {
app := &model.Application{
Name: "test-app",
ScriptTemplate: "#!/bin/bash\necho hello",
Parameters: json.RawMessage(`[]`),
}
db.Create(app)
return app.ID
}
// ---- CreateTask Tests ----
func TestTaskHandler_CreateTask_Success(t *testing.T) {
slurmSrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusCreated)
json.NewEncoder(w).Encode(map[string]interface{}{"job_id": 12345})
}))
defer slurmSrv.Close()
h, db := setupTaskHandler(t, slurmSrv)
r := setupTaskRouter(h)
appID := createTestAppForTask(db)
taskSvc := h.svc.(*service.TaskService)
ctx := context.Background()
taskSvc.StartProcessor(ctx)
defer taskSvc.StopProcessor()
body, _ := json.Marshal(model.CreateTaskRequest{
AppID: appID,
TaskName: "my-task",
})
w := httptest.NewRecorder()
req, _ := http.NewRequest(http.MethodPost, "/api/v1/tasks", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
r.ServeHTTP(w, req)
if w.Code != http.StatusCreated {
t.Fatalf("expected 201, got %d: %s", w.Code, w.Body.String())
}
var resp map[string]interface{}
json.Unmarshal(w.Body.Bytes(), &resp)
if !resp["success"].(bool) {
t.Fatal("expected success=true")
}
data := resp["data"].(map[string]interface{})
if _, ok := data["id"]; !ok {
t.Fatal("expected id in response data")
}
}
func TestTaskHandler_CreateTask_MissingAppID(t *testing.T) {
h, _ := setupTaskHandler(t, nil)
r := setupTaskRouter(h)
body := `{"task_name":"no-app"}`
w := httptest.NewRecorder()
req, _ := http.NewRequest(http.MethodPost, "/api/v1/tasks", bytes.NewReader([]byte(body)))
req.Header.Set("Content-Type", "application/json")
r.ServeHTTP(w, req)
if w.Code != http.StatusBadRequest {
t.Fatalf("expected 400, got %d: %s", w.Code, w.Body.String())
}
}
func TestTaskHandler_CreateTask_InvalidJSON(t *testing.T) {
h, _ := setupTaskHandler(t, nil)
r := setupTaskRouter(h)
w := httptest.NewRecorder()
req, _ := http.NewRequest(http.MethodPost, "/api/v1/tasks", bytes.NewReader([]byte("not-json")))
req.Header.Set("Content-Type", "application/json")
r.ServeHTTP(w, req)
if w.Code != http.StatusBadRequest {
t.Fatalf("expected 400, got %d: %s", w.Code, w.Body.String())
}
}
// ---- ListTasks Tests ----
func TestTaskHandler_ListTasks_Pagination(t *testing.T) {
slurmSrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusCreated)
json.NewEncoder(w).Encode(map[string]interface{}{"job_id": int32(100)})
}))
defer slurmSrv.Close()
h, db := setupTaskHandler(t, slurmSrv)
r := setupTaskRouter(h)
appID := createTestAppForTask(db)
taskSvc := h.svc.(*service.TaskService)
ctx := context.Background()
taskSvc.StartProcessor(ctx)
defer taskSvc.StopProcessor()
for i := 0; i < 5; i++ {
body, _ := json.Marshal(model.CreateTaskRequest{
AppID: appID,
TaskName: fmt.Sprintf("task-%d", i),
})
w := httptest.NewRecorder()
req, _ := http.NewRequest(http.MethodPost, "/api/v1/tasks", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
r.ServeHTTP(w, req)
}
// Wait for async processing
time.Sleep(200 * time.Millisecond)
w := httptest.NewRecorder()
req, _ := http.NewRequest(http.MethodGet, "/api/v1/tasks?page=1&page_size=3", nil)
r.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("expected 200, got %d: %s", w.Code, w.Body.String())
}
var resp map[string]interface{}
json.Unmarshal(w.Body.Bytes(), &resp)
data := resp["data"].(map[string]interface{})
if data["total"].(float64) != 5 {
t.Fatalf("expected total=5, got %v", data["total"])
}
items := data["items"].([]interface{})
if len(items) != 3 {
t.Fatalf("expected 3 items, got %d", len(items))
}
}
func TestTaskHandler_ListTasks_StatusFilter(t *testing.T) {
slurmSrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusCreated)
json.NewEncoder(w).Encode(map[string]interface{}{"job_id": int32(200)})
}))
defer slurmSrv.Close()
h, db := setupTaskHandler(t, slurmSrv)
r := setupTaskRouter(h)
appID := createTestAppForTask(db)
taskSvc := h.svc.(*service.TaskService)
ctx := context.Background()
taskSvc.StartProcessor(ctx)
defer taskSvc.StopProcessor()
for i := 0; i < 3; i++ {
body, _ := json.Marshal(model.CreateTaskRequest{
AppID: appID,
TaskName: fmt.Sprintf("filter-task-%d", i),
})
w := httptest.NewRecorder()
req, _ := http.NewRequest(http.MethodPost, "/api/v1/tasks", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
r.ServeHTTP(w, req)
}
// Wait for async processing
time.Sleep(200 * time.Millisecond)
w := httptest.NewRecorder()
req, _ := http.NewRequest(http.MethodGet, "/api/v1/tasks?status=queued", nil)
r.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("expected 200, got %d: %s", w.Code, w.Body.String())
}
var resp map[string]interface{}
json.Unmarshal(w.Body.Bytes(), &resp)
data := resp["data"].(map[string]interface{})
items := data["items"].([]interface{})
for _, item := range items {
m := item.(map[string]interface{})
if m["status"] != "queued" {
t.Fatalf("expected status=queued, got %v", m["status"])
}
}
}
func TestTaskHandler_ListTasks_DefaultPagination(t *testing.T) {
h, db := setupTaskHandler(t, nil)
r := setupTaskRouter(h)
_ = createTestAppForTask(db)
// Directly insert tasks via DB to avoid needing processor
for i := 0; i < 15; i++ {
task := &model.Task{
TaskName: fmt.Sprintf("default-task-%d", i),
AppID: 1,
AppName: "test-app",
Status: model.TaskStatusSubmitted,
SubmittedAt: time.Now(),
}
db.Create(task)
}
w := httptest.NewRecorder()
req, _ := http.NewRequest(http.MethodGet, "/api/v1/tasks", nil)
r.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("expected 200, got %d: %s", w.Code, w.Body.String())
}
var resp map[string]interface{}
json.Unmarshal(w.Body.Bytes(), &resp)
data := resp["data"].(map[string]interface{})
if data["total"].(float64) != 15 {
t.Fatalf("expected total=15, got %v", data["total"])
}
items := data["items"].([]interface{})
if len(items) != 10 {
t.Fatalf("expected 10 items (default page_size), got %d", len(items))
}
}

View File

@@ -56,8 +56,13 @@ type FolderHandler interface {
DeleteFolder(c *gin.Context) DeleteFolder(c *gin.Context)
} }
type TaskHandler interface {
CreateTask(c *gin.Context)
ListTasks(c *gin.Context)
}
// NewRouter creates a Gin engine with all API v1 routes registered with real handlers. // NewRouter creates a Gin engine with all API v1 routes registered with real handlers.
func NewRouter(jobH JobHandler, clusterH ClusterHandler, appH ApplicationHandler, uploadH UploadHandler, fileH FileHandler, folderH FolderHandler, logger *zap.Logger) *gin.Engine { func NewRouter(jobH JobHandler, clusterH ClusterHandler, appH ApplicationHandler, uploadH UploadHandler, fileH FileHandler, folderH FolderHandler, taskH TaskHandler, logger *zap.Logger) *gin.Engine {
gin.SetMode(gin.ReleaseMode) gin.SetMode(gin.ReleaseMode)
r := gin.New() r := gin.New()
r.Use(gin.Recovery()) r.Use(gin.Recovery())
@@ -116,6 +121,14 @@ func NewRouter(jobH JobHandler, clusterH ClusterHandler, appH ApplicationHandler
folders.DELETE("/:id", folderH.DeleteFolder) folders.DELETE("/:id", folderH.DeleteFolder)
} }
if taskH != nil {
tasks := v1.Group("/tasks")
{
tasks.POST("", taskH.CreateTask)
tasks.GET("", taskH.ListTasks)
}
}
return r return r
} }
@@ -172,6 +185,9 @@ func registerPlaceholderRoutes(v1 *gin.RouterGroup) {
folders.GET("", notImplemented) folders.GET("", notImplemented)
folders.GET("/:id", notImplemented) folders.GET("/:id", notImplemented)
folders.DELETE("/:id", notImplemented) folders.DELETE("/:id", notImplemented)
v1.POST("/tasks", notImplemented)
v1.GET("/tasks", notImplemented)
} }
func notImplemented(c *gin.Context) { func notImplemented(c *gin.Context) {