feat(handler): add TaskHandler endpoints and register task routes
This commit is contained in:
98
internal/handler/task_handler.go
Normal file
98
internal/handler/task_handler.go
Normal file
@@ -0,0 +1,98 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
|
||||
"gcy_hpc_server/internal/model"
|
||||
"gcy_hpc_server/internal/server"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
type taskServiceProvider interface {
|
||||
SubmitAsync(ctx context.Context, req *model.CreateTaskRequest) (int64, error)
|
||||
ListTasks(ctx context.Context, query *model.TaskListQuery) ([]model.Task, int64, error)
|
||||
}
|
||||
|
||||
type TaskHandler struct {
|
||||
svc taskServiceProvider
|
||||
logger *zap.Logger
|
||||
}
|
||||
|
||||
func NewTaskHandler(svc taskServiceProvider, logger *zap.Logger) *TaskHandler {
|
||||
return &TaskHandler{svc: svc, logger: logger}
|
||||
}
|
||||
|
||||
func (h *TaskHandler) CreateTask(c *gin.Context) {
|
||||
var req model.CreateTaskRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
h.logger.Warn("invalid request body for create task", zap.Error(err))
|
||||
server.BadRequest(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
taskID, err := h.svc.SubmitAsync(c.Request.Context(), &req)
|
||||
if err != nil {
|
||||
errStr := err.Error()
|
||||
if strings.Contains(errStr, "not found") {
|
||||
h.logger.Warn("task submit target not found", zap.Error(err))
|
||||
server.NotFound(c, errStr)
|
||||
return
|
||||
}
|
||||
if strings.Contains(errStr, "exceeds limit") || strings.Contains(errStr, "validation") {
|
||||
h.logger.Warn("task submit validation failed", zap.Error(err))
|
||||
server.BadRequest(c, errStr)
|
||||
return
|
||||
}
|
||||
h.logger.Error("failed to create task", zap.Error(err))
|
||||
server.InternalError(c, errStr)
|
||||
return
|
||||
}
|
||||
|
||||
h.logger.Info("task created", zap.Int64("id", taskID))
|
||||
server.Created(c, gin.H{"id": taskID})
|
||||
}
|
||||
|
||||
func (h *TaskHandler) ListTasks(c *gin.Context) {
|
||||
var query model.TaskListQuery
|
||||
_ = c.ShouldBindQuery(&query)
|
||||
|
||||
if query.Page < 1 {
|
||||
query.Page = 1
|
||||
}
|
||||
if query.PageSize < 1 || query.PageSize > 100 {
|
||||
query.PageSize = 10
|
||||
}
|
||||
|
||||
tasks, total, err := h.svc.ListTasks(c.Request.Context(), &query)
|
||||
if err != nil {
|
||||
h.logger.Error("failed to list tasks", zap.Error(err))
|
||||
server.InternalError(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
responses := make([]model.TaskResponse, 0, len(tasks))
|
||||
for i := range tasks {
|
||||
responses = append(responses, model.TaskResponse{
|
||||
ID: tasks[i].ID,
|
||||
TaskName: tasks[i].TaskName,
|
||||
AppID: tasks[i].AppID,
|
||||
AppName: tasks[i].AppName,
|
||||
Status: tasks[i].Status,
|
||||
CurrentStep: tasks[i].CurrentStep,
|
||||
RetryCount: tasks[i].RetryCount,
|
||||
SlurmJobID: tasks[i].SlurmJobID,
|
||||
WorkDir: tasks[i].WorkDir,
|
||||
ErrorMessage: tasks[i].ErrorMessage,
|
||||
CreatedAt: tasks[i].CreatedAt,
|
||||
UpdatedAt: tasks[i].UpdatedAt,
|
||||
})
|
||||
}
|
||||
|
||||
server.OK(c, model.TaskListResponse{
|
||||
Items: responses,
|
||||
Total: total,
|
||||
})
|
||||
}
|
||||
286
internal/handler/task_handler_test.go
Normal file
286
internal/handler/task_handler_test.go
Normal file
@@ -0,0 +1,286 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"gcy_hpc_server/internal/model"
|
||||
"gcy_hpc_server/internal/service"
|
||||
"gcy_hpc_server/internal/slurm"
|
||||
"gcy_hpc_server/internal/store"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"go.uber.org/zap"
|
||||
"gorm.io/driver/sqlite"
|
||||
"gorm.io/gorm"
|
||||
gormlogger "gorm.io/gorm/logger"
|
||||
)
|
||||
|
||||
var taskDBCounter atomic.Int64
|
||||
|
||||
func setupTaskHandler(t *testing.T, slurmSrv *httptest.Server) (*TaskHandler, *gorm.DB) {
|
||||
t.Helper()
|
||||
dbFile := filepath.Join(t.TempDir(), fmt.Sprintf("test-%d.db", taskDBCounter.Add(1)))
|
||||
db, err := gorm.Open(sqlite.Open(dbFile), &gorm.Config{Logger: gormlogger.Default.LogMode(gormlogger.Silent)})
|
||||
if err != nil {
|
||||
t.Fatalf("open db: %v", err)
|
||||
}
|
||||
db.AutoMigrate(&model.Task{}, &model.Application{})
|
||||
t.Cleanup(func() { os.Remove(dbFile) })
|
||||
|
||||
taskStore := store.NewTaskStore(db)
|
||||
appStore := store.NewApplicationStore(db)
|
||||
|
||||
var jobSvc *service.JobService
|
||||
if slurmSrv != nil {
|
||||
client, _ := slurm.NewClient(slurmSrv.URL, slurmSrv.Client())
|
||||
jobSvc = service.NewJobService(client, zap.NewNop())
|
||||
}
|
||||
|
||||
workDir := filepath.Join(t.TempDir(), "work")
|
||||
taskSvc := service.NewTaskService(taskStore, appStore, nil, nil, nil, jobSvc, workDir, zap.NewNop())
|
||||
h := NewTaskHandler(taskSvc, zap.NewNop())
|
||||
return h, db
|
||||
}
|
||||
|
||||
func setupTaskRouter(h *TaskHandler) *gin.Engine {
|
||||
gin.SetMode(gin.TestMode)
|
||||
r := gin.New()
|
||||
v1 := r.Group("/api/v1")
|
||||
tasks := v1.Group("/tasks")
|
||||
tasks.POST("", h.CreateTask)
|
||||
tasks.GET("", h.ListTasks)
|
||||
return r
|
||||
}
|
||||
|
||||
func createTestAppForTask(db *gorm.DB) int64 {
|
||||
app := &model.Application{
|
||||
Name: "test-app",
|
||||
ScriptTemplate: "#!/bin/bash\necho hello",
|
||||
Parameters: json.RawMessage(`[]`),
|
||||
}
|
||||
db.Create(app)
|
||||
return app.ID
|
||||
}
|
||||
|
||||
// ---- CreateTask Tests ----
|
||||
|
||||
func TestTaskHandler_CreateTask_Success(t *testing.T) {
|
||||
slurmSrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusCreated)
|
||||
json.NewEncoder(w).Encode(map[string]interface{}{"job_id": 12345})
|
||||
}))
|
||||
defer slurmSrv.Close()
|
||||
|
||||
h, db := setupTaskHandler(t, slurmSrv)
|
||||
r := setupTaskRouter(h)
|
||||
|
||||
appID := createTestAppForTask(db)
|
||||
|
||||
taskSvc := h.svc.(*service.TaskService)
|
||||
ctx := context.Background()
|
||||
taskSvc.StartProcessor(ctx)
|
||||
defer taskSvc.StopProcessor()
|
||||
|
||||
body, _ := json.Marshal(model.CreateTaskRequest{
|
||||
AppID: appID,
|
||||
TaskName: "my-task",
|
||||
})
|
||||
w := httptest.NewRecorder()
|
||||
req, _ := http.NewRequest(http.MethodPost, "/api/v1/tasks", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusCreated {
|
||||
t.Fatalf("expected 201, got %d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
var resp map[string]interface{}
|
||||
json.Unmarshal(w.Body.Bytes(), &resp)
|
||||
if !resp["success"].(bool) {
|
||||
t.Fatal("expected success=true")
|
||||
}
|
||||
data := resp["data"].(map[string]interface{})
|
||||
if _, ok := data["id"]; !ok {
|
||||
t.Fatal("expected id in response data")
|
||||
}
|
||||
}
|
||||
|
||||
func TestTaskHandler_CreateTask_MissingAppID(t *testing.T) {
|
||||
h, _ := setupTaskHandler(t, nil)
|
||||
r := setupTaskRouter(h)
|
||||
|
||||
body := `{"task_name":"no-app"}`
|
||||
w := httptest.NewRecorder()
|
||||
req, _ := http.NewRequest(http.MethodPost, "/api/v1/tasks", bytes.NewReader([]byte(body)))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusBadRequest {
|
||||
t.Fatalf("expected 400, got %d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestTaskHandler_CreateTask_InvalidJSON(t *testing.T) {
|
||||
h, _ := setupTaskHandler(t, nil)
|
||||
r := setupTaskRouter(h)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req, _ := http.NewRequest(http.MethodPost, "/api/v1/tasks", bytes.NewReader([]byte("not-json")))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusBadRequest {
|
||||
t.Fatalf("expected 400, got %d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
// ---- ListTasks Tests ----
|
||||
|
||||
func TestTaskHandler_ListTasks_Pagination(t *testing.T) {
|
||||
slurmSrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusCreated)
|
||||
json.NewEncoder(w).Encode(map[string]interface{}{"job_id": int32(100)})
|
||||
}))
|
||||
defer slurmSrv.Close()
|
||||
|
||||
h, db := setupTaskHandler(t, slurmSrv)
|
||||
r := setupTaskRouter(h)
|
||||
|
||||
appID := createTestAppForTask(db)
|
||||
|
||||
taskSvc := h.svc.(*service.TaskService)
|
||||
ctx := context.Background()
|
||||
taskSvc.StartProcessor(ctx)
|
||||
defer taskSvc.StopProcessor()
|
||||
|
||||
for i := 0; i < 5; i++ {
|
||||
body, _ := json.Marshal(model.CreateTaskRequest{
|
||||
AppID: appID,
|
||||
TaskName: fmt.Sprintf("task-%d", i),
|
||||
})
|
||||
w := httptest.NewRecorder()
|
||||
req, _ := http.NewRequest(http.MethodPost, "/api/v1/tasks", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
r.ServeHTTP(w, req)
|
||||
}
|
||||
|
||||
// Wait for async processing
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req, _ := http.NewRequest(http.MethodGet, "/api/v1/tasks?page=1&page_size=3", nil)
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("expected 200, got %d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
var resp map[string]interface{}
|
||||
json.Unmarshal(w.Body.Bytes(), &resp)
|
||||
data := resp["data"].(map[string]interface{})
|
||||
if data["total"].(float64) != 5 {
|
||||
t.Fatalf("expected total=5, got %v", data["total"])
|
||||
}
|
||||
items := data["items"].([]interface{})
|
||||
if len(items) != 3 {
|
||||
t.Fatalf("expected 3 items, got %d", len(items))
|
||||
}
|
||||
}
|
||||
|
||||
func TestTaskHandler_ListTasks_StatusFilter(t *testing.T) {
|
||||
slurmSrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusCreated)
|
||||
json.NewEncoder(w).Encode(map[string]interface{}{"job_id": int32(200)})
|
||||
}))
|
||||
defer slurmSrv.Close()
|
||||
|
||||
h, db := setupTaskHandler(t, slurmSrv)
|
||||
r := setupTaskRouter(h)
|
||||
|
||||
appID := createTestAppForTask(db)
|
||||
|
||||
taskSvc := h.svc.(*service.TaskService)
|
||||
ctx := context.Background()
|
||||
taskSvc.StartProcessor(ctx)
|
||||
defer taskSvc.StopProcessor()
|
||||
|
||||
for i := 0; i < 3; i++ {
|
||||
body, _ := json.Marshal(model.CreateTaskRequest{
|
||||
AppID: appID,
|
||||
TaskName: fmt.Sprintf("filter-task-%d", i),
|
||||
})
|
||||
w := httptest.NewRecorder()
|
||||
req, _ := http.NewRequest(http.MethodPost, "/api/v1/tasks", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
r.ServeHTTP(w, req)
|
||||
}
|
||||
|
||||
// Wait for async processing
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req, _ := http.NewRequest(http.MethodGet, "/api/v1/tasks?status=queued", nil)
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("expected 200, got %d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
var resp map[string]interface{}
|
||||
json.Unmarshal(w.Body.Bytes(), &resp)
|
||||
data := resp["data"].(map[string]interface{})
|
||||
items := data["items"].([]interface{})
|
||||
for _, item := range items {
|
||||
m := item.(map[string]interface{})
|
||||
if m["status"] != "queued" {
|
||||
t.Fatalf("expected status=queued, got %v", m["status"])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestTaskHandler_ListTasks_DefaultPagination(t *testing.T) {
|
||||
h, db := setupTaskHandler(t, nil)
|
||||
r := setupTaskRouter(h)
|
||||
|
||||
_ = createTestAppForTask(db)
|
||||
|
||||
// Directly insert tasks via DB to avoid needing processor
|
||||
for i := 0; i < 15; i++ {
|
||||
task := &model.Task{
|
||||
TaskName: fmt.Sprintf("default-task-%d", i),
|
||||
AppID: 1,
|
||||
AppName: "test-app",
|
||||
Status: model.TaskStatusSubmitted,
|
||||
SubmittedAt: time.Now(),
|
||||
}
|
||||
db.Create(task)
|
||||
}
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req, _ := http.NewRequest(http.MethodGet, "/api/v1/tasks", nil)
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("expected 200, got %d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
var resp map[string]interface{}
|
||||
json.Unmarshal(w.Body.Bytes(), &resp)
|
||||
data := resp["data"].(map[string]interface{})
|
||||
if data["total"].(float64) != 15 {
|
||||
t.Fatalf("expected total=15, got %v", data["total"])
|
||||
}
|
||||
items := data["items"].([]interface{})
|
||||
if len(items) != 10 {
|
||||
t.Fatalf("expected 10 items (default page_size), got %d", len(items))
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user