feat(handler): add TaskHandler endpoints and register task routes
This commit is contained in:
98
internal/handler/task_handler.go
Normal file
98
internal/handler/task_handler.go
Normal file
@@ -0,0 +1,98 @@
|
|||||||
|
package handler
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"gcy_hpc_server/internal/model"
|
||||||
|
"gcy_hpc_server/internal/server"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"go.uber.org/zap"
|
||||||
|
)
|
||||||
|
|
||||||
|
type taskServiceProvider interface {
|
||||||
|
SubmitAsync(ctx context.Context, req *model.CreateTaskRequest) (int64, error)
|
||||||
|
ListTasks(ctx context.Context, query *model.TaskListQuery) ([]model.Task, int64, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
type TaskHandler struct {
|
||||||
|
svc taskServiceProvider
|
||||||
|
logger *zap.Logger
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewTaskHandler(svc taskServiceProvider, logger *zap.Logger) *TaskHandler {
|
||||||
|
return &TaskHandler{svc: svc, logger: logger}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *TaskHandler) CreateTask(c *gin.Context) {
|
||||||
|
var req model.CreateTaskRequest
|
||||||
|
if err := c.ShouldBindJSON(&req); err != nil {
|
||||||
|
h.logger.Warn("invalid request body for create task", zap.Error(err))
|
||||||
|
server.BadRequest(c, err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
taskID, err := h.svc.SubmitAsync(c.Request.Context(), &req)
|
||||||
|
if err != nil {
|
||||||
|
errStr := err.Error()
|
||||||
|
if strings.Contains(errStr, "not found") {
|
||||||
|
h.logger.Warn("task submit target not found", zap.Error(err))
|
||||||
|
server.NotFound(c, errStr)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if strings.Contains(errStr, "exceeds limit") || strings.Contains(errStr, "validation") {
|
||||||
|
h.logger.Warn("task submit validation failed", zap.Error(err))
|
||||||
|
server.BadRequest(c, errStr)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
h.logger.Error("failed to create task", zap.Error(err))
|
||||||
|
server.InternalError(c, errStr)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
h.logger.Info("task created", zap.Int64("id", taskID))
|
||||||
|
server.Created(c, gin.H{"id": taskID})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *TaskHandler) ListTasks(c *gin.Context) {
|
||||||
|
var query model.TaskListQuery
|
||||||
|
_ = c.ShouldBindQuery(&query)
|
||||||
|
|
||||||
|
if query.Page < 1 {
|
||||||
|
query.Page = 1
|
||||||
|
}
|
||||||
|
if query.PageSize < 1 || query.PageSize > 100 {
|
||||||
|
query.PageSize = 10
|
||||||
|
}
|
||||||
|
|
||||||
|
tasks, total, err := h.svc.ListTasks(c.Request.Context(), &query)
|
||||||
|
if err != nil {
|
||||||
|
h.logger.Error("failed to list tasks", zap.Error(err))
|
||||||
|
server.InternalError(c, err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
responses := make([]model.TaskResponse, 0, len(tasks))
|
||||||
|
for i := range tasks {
|
||||||
|
responses = append(responses, model.TaskResponse{
|
||||||
|
ID: tasks[i].ID,
|
||||||
|
TaskName: tasks[i].TaskName,
|
||||||
|
AppID: tasks[i].AppID,
|
||||||
|
AppName: tasks[i].AppName,
|
||||||
|
Status: tasks[i].Status,
|
||||||
|
CurrentStep: tasks[i].CurrentStep,
|
||||||
|
RetryCount: tasks[i].RetryCount,
|
||||||
|
SlurmJobID: tasks[i].SlurmJobID,
|
||||||
|
WorkDir: tasks[i].WorkDir,
|
||||||
|
ErrorMessage: tasks[i].ErrorMessage,
|
||||||
|
CreatedAt: tasks[i].CreatedAt,
|
||||||
|
UpdatedAt: tasks[i].UpdatedAt,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
server.OK(c, model.TaskListResponse{
|
||||||
|
Items: responses,
|
||||||
|
Total: total,
|
||||||
|
})
|
||||||
|
}
|
||||||
286
internal/handler/task_handler_test.go
Normal file
286
internal/handler/task_handler_test.go
Normal file
@@ -0,0 +1,286 @@
|
|||||||
|
package handler
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"sync/atomic"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"gcy_hpc_server/internal/model"
|
||||||
|
"gcy_hpc_server/internal/service"
|
||||||
|
"gcy_hpc_server/internal/slurm"
|
||||||
|
"gcy_hpc_server/internal/store"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"go.uber.org/zap"
|
||||||
|
"gorm.io/driver/sqlite"
|
||||||
|
"gorm.io/gorm"
|
||||||
|
gormlogger "gorm.io/gorm/logger"
|
||||||
|
)
|
||||||
|
|
||||||
|
var taskDBCounter atomic.Int64
|
||||||
|
|
||||||
|
func setupTaskHandler(t *testing.T, slurmSrv *httptest.Server) (*TaskHandler, *gorm.DB) {
|
||||||
|
t.Helper()
|
||||||
|
dbFile := filepath.Join(t.TempDir(), fmt.Sprintf("test-%d.db", taskDBCounter.Add(1)))
|
||||||
|
db, err := gorm.Open(sqlite.Open(dbFile), &gorm.Config{Logger: gormlogger.Default.LogMode(gormlogger.Silent)})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("open db: %v", err)
|
||||||
|
}
|
||||||
|
db.AutoMigrate(&model.Task{}, &model.Application{})
|
||||||
|
t.Cleanup(func() { os.Remove(dbFile) })
|
||||||
|
|
||||||
|
taskStore := store.NewTaskStore(db)
|
||||||
|
appStore := store.NewApplicationStore(db)
|
||||||
|
|
||||||
|
var jobSvc *service.JobService
|
||||||
|
if slurmSrv != nil {
|
||||||
|
client, _ := slurm.NewClient(slurmSrv.URL, slurmSrv.Client())
|
||||||
|
jobSvc = service.NewJobService(client, zap.NewNop())
|
||||||
|
}
|
||||||
|
|
||||||
|
workDir := filepath.Join(t.TempDir(), "work")
|
||||||
|
taskSvc := service.NewTaskService(taskStore, appStore, nil, nil, nil, jobSvc, workDir, zap.NewNop())
|
||||||
|
h := NewTaskHandler(taskSvc, zap.NewNop())
|
||||||
|
return h, db
|
||||||
|
}
|
||||||
|
|
||||||
|
func setupTaskRouter(h *TaskHandler) *gin.Engine {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
r := gin.New()
|
||||||
|
v1 := r.Group("/api/v1")
|
||||||
|
tasks := v1.Group("/tasks")
|
||||||
|
tasks.POST("", h.CreateTask)
|
||||||
|
tasks.GET("", h.ListTasks)
|
||||||
|
return r
|
||||||
|
}
|
||||||
|
|
||||||
|
func createTestAppForTask(db *gorm.DB) int64 {
|
||||||
|
app := &model.Application{
|
||||||
|
Name: "test-app",
|
||||||
|
ScriptTemplate: "#!/bin/bash\necho hello",
|
||||||
|
Parameters: json.RawMessage(`[]`),
|
||||||
|
}
|
||||||
|
db.Create(app)
|
||||||
|
return app.ID
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---- CreateTask Tests ----
|
||||||
|
|
||||||
|
func TestTaskHandler_CreateTask_Success(t *testing.T) {
|
||||||
|
slurmSrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
w.WriteHeader(http.StatusCreated)
|
||||||
|
json.NewEncoder(w).Encode(map[string]interface{}{"job_id": 12345})
|
||||||
|
}))
|
||||||
|
defer slurmSrv.Close()
|
||||||
|
|
||||||
|
h, db := setupTaskHandler(t, slurmSrv)
|
||||||
|
r := setupTaskRouter(h)
|
||||||
|
|
||||||
|
appID := createTestAppForTask(db)
|
||||||
|
|
||||||
|
taskSvc := h.svc.(*service.TaskService)
|
||||||
|
ctx := context.Background()
|
||||||
|
taskSvc.StartProcessor(ctx)
|
||||||
|
defer taskSvc.StopProcessor()
|
||||||
|
|
||||||
|
body, _ := json.Marshal(model.CreateTaskRequest{
|
||||||
|
AppID: appID,
|
||||||
|
TaskName: "my-task",
|
||||||
|
})
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
req, _ := http.NewRequest(http.MethodPost, "/api/v1/tasks", bytes.NewReader(body))
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
r.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
if w.Code != http.StatusCreated {
|
||||||
|
t.Fatalf("expected 201, got %d: %s", w.Code, w.Body.String())
|
||||||
|
}
|
||||||
|
var resp map[string]interface{}
|
||||||
|
json.Unmarshal(w.Body.Bytes(), &resp)
|
||||||
|
if !resp["success"].(bool) {
|
||||||
|
t.Fatal("expected success=true")
|
||||||
|
}
|
||||||
|
data := resp["data"].(map[string]interface{})
|
||||||
|
if _, ok := data["id"]; !ok {
|
||||||
|
t.Fatal("expected id in response data")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTaskHandler_CreateTask_MissingAppID(t *testing.T) {
|
||||||
|
h, _ := setupTaskHandler(t, nil)
|
||||||
|
r := setupTaskRouter(h)
|
||||||
|
|
||||||
|
body := `{"task_name":"no-app"}`
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
req, _ := http.NewRequest(http.MethodPost, "/api/v1/tasks", bytes.NewReader([]byte(body)))
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
r.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
if w.Code != http.StatusBadRequest {
|
||||||
|
t.Fatalf("expected 400, got %d: %s", w.Code, w.Body.String())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTaskHandler_CreateTask_InvalidJSON(t *testing.T) {
|
||||||
|
h, _ := setupTaskHandler(t, nil)
|
||||||
|
r := setupTaskRouter(h)
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
req, _ := http.NewRequest(http.MethodPost, "/api/v1/tasks", bytes.NewReader([]byte("not-json")))
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
r.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
if w.Code != http.StatusBadRequest {
|
||||||
|
t.Fatalf("expected 400, got %d: %s", w.Code, w.Body.String())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---- ListTasks Tests ----
|
||||||
|
|
||||||
|
func TestTaskHandler_ListTasks_Pagination(t *testing.T) {
|
||||||
|
slurmSrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
w.WriteHeader(http.StatusCreated)
|
||||||
|
json.NewEncoder(w).Encode(map[string]interface{}{"job_id": int32(100)})
|
||||||
|
}))
|
||||||
|
defer slurmSrv.Close()
|
||||||
|
|
||||||
|
h, db := setupTaskHandler(t, slurmSrv)
|
||||||
|
r := setupTaskRouter(h)
|
||||||
|
|
||||||
|
appID := createTestAppForTask(db)
|
||||||
|
|
||||||
|
taskSvc := h.svc.(*service.TaskService)
|
||||||
|
ctx := context.Background()
|
||||||
|
taskSvc.StartProcessor(ctx)
|
||||||
|
defer taskSvc.StopProcessor()
|
||||||
|
|
||||||
|
for i := 0; i < 5; i++ {
|
||||||
|
body, _ := json.Marshal(model.CreateTaskRequest{
|
||||||
|
AppID: appID,
|
||||||
|
TaskName: fmt.Sprintf("task-%d", i),
|
||||||
|
})
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
req, _ := http.NewRequest(http.MethodPost, "/api/v1/tasks", bytes.NewReader(body))
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
r.ServeHTTP(w, req)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Wait for async processing
|
||||||
|
time.Sleep(200 * time.Millisecond)
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
req, _ := http.NewRequest(http.MethodGet, "/api/v1/tasks?page=1&page_size=3", nil)
|
||||||
|
r.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Fatalf("expected 200, got %d: %s", w.Code, w.Body.String())
|
||||||
|
}
|
||||||
|
var resp map[string]interface{}
|
||||||
|
json.Unmarshal(w.Body.Bytes(), &resp)
|
||||||
|
data := resp["data"].(map[string]interface{})
|
||||||
|
if data["total"].(float64) != 5 {
|
||||||
|
t.Fatalf("expected total=5, got %v", data["total"])
|
||||||
|
}
|
||||||
|
items := data["items"].([]interface{})
|
||||||
|
if len(items) != 3 {
|
||||||
|
t.Fatalf("expected 3 items, got %d", len(items))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTaskHandler_ListTasks_StatusFilter(t *testing.T) {
|
||||||
|
slurmSrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
w.WriteHeader(http.StatusCreated)
|
||||||
|
json.NewEncoder(w).Encode(map[string]interface{}{"job_id": int32(200)})
|
||||||
|
}))
|
||||||
|
defer slurmSrv.Close()
|
||||||
|
|
||||||
|
h, db := setupTaskHandler(t, slurmSrv)
|
||||||
|
r := setupTaskRouter(h)
|
||||||
|
|
||||||
|
appID := createTestAppForTask(db)
|
||||||
|
|
||||||
|
taskSvc := h.svc.(*service.TaskService)
|
||||||
|
ctx := context.Background()
|
||||||
|
taskSvc.StartProcessor(ctx)
|
||||||
|
defer taskSvc.StopProcessor()
|
||||||
|
|
||||||
|
for i := 0; i < 3; i++ {
|
||||||
|
body, _ := json.Marshal(model.CreateTaskRequest{
|
||||||
|
AppID: appID,
|
||||||
|
TaskName: fmt.Sprintf("filter-task-%d", i),
|
||||||
|
})
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
req, _ := http.NewRequest(http.MethodPost, "/api/v1/tasks", bytes.NewReader(body))
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
r.ServeHTTP(w, req)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Wait for async processing
|
||||||
|
time.Sleep(200 * time.Millisecond)
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
req, _ := http.NewRequest(http.MethodGet, "/api/v1/tasks?status=queued", nil)
|
||||||
|
r.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Fatalf("expected 200, got %d: %s", w.Code, w.Body.String())
|
||||||
|
}
|
||||||
|
var resp map[string]interface{}
|
||||||
|
json.Unmarshal(w.Body.Bytes(), &resp)
|
||||||
|
data := resp["data"].(map[string]interface{})
|
||||||
|
items := data["items"].([]interface{})
|
||||||
|
for _, item := range items {
|
||||||
|
m := item.(map[string]interface{})
|
||||||
|
if m["status"] != "queued" {
|
||||||
|
t.Fatalf("expected status=queued, got %v", m["status"])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTaskHandler_ListTasks_DefaultPagination(t *testing.T) {
|
||||||
|
h, db := setupTaskHandler(t, nil)
|
||||||
|
r := setupTaskRouter(h)
|
||||||
|
|
||||||
|
_ = createTestAppForTask(db)
|
||||||
|
|
||||||
|
// Directly insert tasks via DB to avoid needing processor
|
||||||
|
for i := 0; i < 15; i++ {
|
||||||
|
task := &model.Task{
|
||||||
|
TaskName: fmt.Sprintf("default-task-%d", i),
|
||||||
|
AppID: 1,
|
||||||
|
AppName: "test-app",
|
||||||
|
Status: model.TaskStatusSubmitted,
|
||||||
|
SubmittedAt: time.Now(),
|
||||||
|
}
|
||||||
|
db.Create(task)
|
||||||
|
}
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
req, _ := http.NewRequest(http.MethodGet, "/api/v1/tasks", nil)
|
||||||
|
r.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Fatalf("expected 200, got %d: %s", w.Code, w.Body.String())
|
||||||
|
}
|
||||||
|
var resp map[string]interface{}
|
||||||
|
json.Unmarshal(w.Body.Bytes(), &resp)
|
||||||
|
data := resp["data"].(map[string]interface{})
|
||||||
|
if data["total"].(float64) != 15 {
|
||||||
|
t.Fatalf("expected total=15, got %v", data["total"])
|
||||||
|
}
|
||||||
|
items := data["items"].([]interface{})
|
||||||
|
if len(items) != 10 {
|
||||||
|
t.Fatalf("expected 10 items (default page_size), got %d", len(items))
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -56,8 +56,13 @@ type FolderHandler interface {
|
|||||||
DeleteFolder(c *gin.Context)
|
DeleteFolder(c *gin.Context)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type TaskHandler interface {
|
||||||
|
CreateTask(c *gin.Context)
|
||||||
|
ListTasks(c *gin.Context)
|
||||||
|
}
|
||||||
|
|
||||||
// NewRouter creates a Gin engine with all API v1 routes registered with real handlers.
|
// NewRouter creates a Gin engine with all API v1 routes registered with real handlers.
|
||||||
func NewRouter(jobH JobHandler, clusterH ClusterHandler, appH ApplicationHandler, uploadH UploadHandler, fileH FileHandler, folderH FolderHandler, logger *zap.Logger) *gin.Engine {
|
func NewRouter(jobH JobHandler, clusterH ClusterHandler, appH ApplicationHandler, uploadH UploadHandler, fileH FileHandler, folderH FolderHandler, taskH TaskHandler, logger *zap.Logger) *gin.Engine {
|
||||||
gin.SetMode(gin.ReleaseMode)
|
gin.SetMode(gin.ReleaseMode)
|
||||||
r := gin.New()
|
r := gin.New()
|
||||||
r.Use(gin.Recovery())
|
r.Use(gin.Recovery())
|
||||||
@@ -116,6 +121,14 @@ func NewRouter(jobH JobHandler, clusterH ClusterHandler, appH ApplicationHandler
|
|||||||
folders.DELETE("/:id", folderH.DeleteFolder)
|
folders.DELETE("/:id", folderH.DeleteFolder)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if taskH != nil {
|
||||||
|
tasks := v1.Group("/tasks")
|
||||||
|
{
|
||||||
|
tasks.POST("", taskH.CreateTask)
|
||||||
|
tasks.GET("", taskH.ListTasks)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
return r
|
return r
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -172,6 +185,9 @@ func registerPlaceholderRoutes(v1 *gin.RouterGroup) {
|
|||||||
folders.GET("", notImplemented)
|
folders.GET("", notImplemented)
|
||||||
folders.GET("/:id", notImplemented)
|
folders.GET("/:id", notImplemented)
|
||||||
folders.DELETE("/:id", notImplemented)
|
folders.DELETE("/:id", notImplemented)
|
||||||
|
|
||||||
|
v1.POST("/tasks", notImplemented)
|
||||||
|
v1.GET("/tasks", notImplemented)
|
||||||
}
|
}
|
||||||
|
|
||||||
func notImplemented(c *gin.Context) {
|
func notImplemented(c *gin.Context) {
|
||||||
|
|||||||
Reference in New Issue
Block a user