fix(task): prevent duplicate Slurm job submission on backend restart
RecoverStuckTasks now skips tasks that already have a slurm_job_id, and ProcessTask adds a guard before the submitting step to prevent re-submission even if a task is incorrectly re-enqueued. Also deprecates POST /api/v1/jobs/submit endpoint (replaced by POST /tasks) and comments out related handlers and tests.
This commit is contained in:
@@ -3,7 +3,6 @@ package main
|
|||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
|
||||||
"net/http"
|
"net/http"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
@@ -66,6 +65,8 @@ func jobSubmitViaAPI(t *testing.T, env *testenv.TestEnv, script string) int32 {
|
|||||||
return job.JobID
|
return job.JobID
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// [已弃用] 以下测试依赖 POST /api/v1/jobs/submit,该接口已被 POST /tasks 取代。
|
||||||
|
/*
|
||||||
// TestIntegration_Jobs_Submit verifies POST /api/v1/jobs/submit creates a new job.
|
// TestIntegration_Jobs_Submit verifies POST /api/v1/jobs/submit creates a new job.
|
||||||
func TestIntegration_Jobs_Submit(t *testing.T) {
|
func TestIntegration_Jobs_Submit(t *testing.T) {
|
||||||
env := testenv.NewTestEnv(t)
|
env := testenv.NewTestEnv(t)
|
||||||
@@ -220,3 +221,4 @@ func TestIntegration_Jobs_History(t *testing.T) {
|
|||||||
t.Fatalf("cancelled job %d not found in history", jobID)
|
t.Fatalf("cancelled job %d not found in history", jobID)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
*/
|
||||||
|
|||||||
@@ -53,7 +53,7 @@ func TestRouterRegistration(t *testing.T) {
|
|||||||
method string
|
method string
|
||||||
path string
|
path string
|
||||||
}{
|
}{
|
||||||
{"POST", "/api/v1/jobs/submit"},
|
// {"POST", "/api/v1/jobs/submit"}, // [已弃用] 已被 POST /tasks 取代
|
||||||
{"GET", "/api/v1/jobs"},
|
{"GET", "/api/v1/jobs"},
|
||||||
{"GET", "/api/v1/jobs/history"},
|
{"GET", "/api/v1/jobs/history"},
|
||||||
{"GET", "/api/v1/jobs/:id"},
|
{"GET", "/api/v1/jobs/:id"},
|
||||||
|
|||||||
@@ -22,29 +22,31 @@ func NewJobHandler(jobSvc *service.JobService, logger *zap.Logger) *JobHandler {
|
|||||||
return &JobHandler{jobSvc: jobSvc, logger: logger}
|
return &JobHandler{jobSvc: jobSvc, logger: logger}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// [已弃用] SubmitJob 已被 POST /tasks 取代。
|
||||||
|
// 保留方法体以防需要回滚。
|
||||||
// SubmitJob handles POST /api/v1/jobs/submit.
|
// SubmitJob handles POST /api/v1/jobs/submit.
|
||||||
func (h *JobHandler) SubmitJob(c *gin.Context) {
|
// func (h *JobHandler) SubmitJob(c *gin.Context) {
|
||||||
var req model.SubmitJobRequest
|
// var req model.SubmitJobRequest
|
||||||
if err := c.ShouldBindJSON(&req); err != nil {
|
// if err := c.ShouldBindJSON(&req); err != nil {
|
||||||
h.logger.Warn("bad request", zap.String("method", "SubmitJob"), zap.String("error", "invalid request body"))
|
// h.logger.Warn("bad request", zap.String("method", "SubmitJob"), zap.String("error", "invalid request body"))
|
||||||
server.BadRequest(c, "invalid request body")
|
// server.BadRequest(c, "invalid request body")
|
||||||
return
|
// return
|
||||||
}
|
// }
|
||||||
if req.Script == "" {
|
// if req.Script == "" {
|
||||||
h.logger.Warn("bad request", zap.String("method", "SubmitJob"), zap.String("error", "script is required"))
|
// h.logger.Warn("bad request", zap.String("method", "SubmitJob"), zap.String("error", "script is required"))
|
||||||
server.BadRequest(c, "script is required")
|
// server.BadRequest(c, "script is required")
|
||||||
return
|
// return
|
||||||
}
|
// }
|
||||||
|
//
|
||||||
resp, err := h.jobSvc.SubmitJob(c.Request.Context(), &req)
|
// resp, err := h.jobSvc.SubmitJob(c.Request.Context(), &req)
|
||||||
if err != nil {
|
// if err != nil {
|
||||||
h.logger.Error("handler error", zap.String("method", "SubmitJob"), zap.Int("status", http.StatusBadGateway), zap.Error(err))
|
// h.logger.Error("handler error", zap.String("method", "SubmitJob"), zap.Int("status", http.StatusBadGateway), zap.Error(err))
|
||||||
server.ErrorWithStatus(c, http.StatusBadGateway, "slurm error: "+err.Error())
|
// server.ErrorWithStatus(c, http.StatusBadGateway, "slurm error: "+err.Error())
|
||||||
return
|
// return
|
||||||
}
|
// }
|
||||||
|
//
|
||||||
server.Created(c, resp)
|
// server.Created(c, resp)
|
||||||
}
|
// }
|
||||||
|
|
||||||
// GetJobs handles GET /api/v1/jobs with pagination.
|
// GetJobs handles GET /api/v1/jobs with pagination.
|
||||||
func (h *JobHandler) GetJobs(c *gin.Context) {
|
func (h *JobHandler) GetJobs(c *gin.Context) {
|
||||||
|
|||||||
@@ -1,9 +1,7 @@
|
|||||||
package handler
|
package handler
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
"testing"
|
"testing"
|
||||||
@@ -23,7 +21,7 @@ func setupJobRouter(h *JobHandler) *gin.Engine {
|
|||||||
v1 := r.Group("/api/v1")
|
v1 := r.Group("/api/v1")
|
||||||
jobs := v1.Group("/jobs")
|
jobs := v1.Group("/jobs")
|
||||||
{
|
{
|
||||||
jobs.POST("/submit", h.SubmitJob)
|
// jobs.POST("/submit", h.SubmitJob) // [已弃用] 已被 POST /tasks 取代
|
||||||
jobs.GET("", h.GetJobs)
|
jobs.GET("", h.GetJobs)
|
||||||
jobs.GET("/history", h.GetJobHistory)
|
jobs.GET("/history", h.GetJobHistory)
|
||||||
jobs.GET("/:id", h.GetJob)
|
jobs.GET("/:id", h.GetJob)
|
||||||
@@ -61,6 +59,8 @@ func handlerLogs(logs *observer.ObservedLogs) []observer.LoggedEntry {
|
|||||||
return handler
|
return handler
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// [已弃用] SubmitJob 相关测试已被禁用,该接口已被 POST /tasks 取代。
|
||||||
|
/*
|
||||||
func TestSubmitJob_Success(t *testing.T) {
|
func TestSubmitJob_Success(t *testing.T) {
|
||||||
mux := http.NewServeMux()
|
mux := http.NewServeMux()
|
||||||
mux.HandleFunc("/slurm/v0.0.40/job/submit", func(w http.ResponseWriter, r *http.Request) {
|
mux.HandleFunc("/slurm/v0.0.40/job/submit", func(w http.ResponseWriter, r *http.Request) {
|
||||||
@@ -171,6 +171,9 @@ func TestSubmitJob_SlurmError(t *testing.T) {
|
|||||||
t.Fatalf("expected 502, got %d: %s", w.Code, w.Body.String())
|
t.Fatalf("expected 502, got %d: %s", w.Code, w.Body.String())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
*/
|
||||||
|
|
||||||
|
// --- Logging verification tests ---
|
||||||
|
|
||||||
func TestGetJobs_Success(t *testing.T) {
|
func TestGetJobs_Success(t *testing.T) {
|
||||||
mux := http.NewServeMux()
|
mux := http.NewServeMux()
|
||||||
@@ -462,6 +465,7 @@ func TestGetJobHistory_DefaultPagination(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/*
|
||||||
func TestSubmitJob_InvalidBody(t *testing.T) {
|
func TestSubmitJob_InvalidBody(t *testing.T) {
|
||||||
mux := http.NewServeMux()
|
mux := http.NewServeMux()
|
||||||
srv, handler := setupJobHandler(mux)
|
srv, handler := setupJobHandler(mux)
|
||||||
@@ -479,9 +483,11 @@ func TestSubmitJob_InvalidBody(t *testing.T) {
|
|||||||
t.Fatalf("expected 400, got %d: %s", w.Code, w.Body.String())
|
t.Fatalf("expected 400, got %d: %s", w.Code, w.Body.String())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
*/
|
||||||
|
|
||||||
// --- Logging verification tests ---
|
// --- Logging verification tests ---
|
||||||
|
|
||||||
|
/*
|
||||||
func TestSubmitJob_InvalidBody_LogsWarn(t *testing.T) {
|
func TestSubmitJob_InvalidBody_LogsWarn(t *testing.T) {
|
||||||
mux := http.NewServeMux()
|
mux := http.NewServeMux()
|
||||||
srv, handler, logs := setupJobHandlerWithObserver(mux)
|
srv, handler, logs := setupJobHandlerWithObserver(mux)
|
||||||
@@ -614,6 +620,7 @@ func TestSubmitJob_Success_NoHandlerLogs(t *testing.T) {
|
|||||||
t.Errorf("expected no handler log entries on success, got %d", len(hLogs))
|
t.Errorf("expected no handler log entries on success, got %d", len(hLogs))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
*/
|
||||||
|
|
||||||
func TestGetJobs_Error_LogsError(t *testing.T) {
|
func TestGetJobs_Error_LogsError(t *testing.T) {
|
||||||
mux := http.NewServeMux()
|
mux := http.NewServeMux()
|
||||||
|
|||||||
@@ -10,7 +10,7 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
type JobHandler interface {
|
type JobHandler interface {
|
||||||
SubmitJob(c *gin.Context)
|
// SubmitJob(c *gin.Context) // [已弃用] 已被 POST /tasks 取代
|
||||||
GetJobs(c *gin.Context)
|
GetJobs(c *gin.Context)
|
||||||
GetJobHistory(c *gin.Context)
|
GetJobHistory(c *gin.Context)
|
||||||
GetJob(c *gin.Context)
|
GetJob(c *gin.Context)
|
||||||
@@ -73,7 +73,7 @@ func NewRouter(jobH JobHandler, clusterH ClusterHandler, appH ApplicationHandler
|
|||||||
v1 := r.Group("/api/v1")
|
v1 := r.Group("/api/v1")
|
||||||
|
|
||||||
jobs := v1.Group("/jobs")
|
jobs := v1.Group("/jobs")
|
||||||
jobs.POST("/submit", jobH.SubmitJob)
|
// jobs.POST("/submit", jobH.SubmitJob) // [已弃用] 已被 POST /tasks 取代
|
||||||
jobs.GET("", jobH.GetJobs)
|
jobs.GET("", jobH.GetJobs)
|
||||||
jobs.GET("/history", jobH.GetJobHistory)
|
jobs.GET("/history", jobH.GetJobHistory)
|
||||||
jobs.GET("/:id", jobH.GetJob)
|
jobs.GET("/:id", jobH.GetJob)
|
||||||
@@ -144,7 +144,7 @@ func NewTestRouter() *gin.Engine {
|
|||||||
|
|
||||||
func registerPlaceholderRoutes(v1 *gin.RouterGroup) {
|
func registerPlaceholderRoutes(v1 *gin.RouterGroup) {
|
||||||
jobs := v1.Group("/jobs")
|
jobs := v1.Group("/jobs")
|
||||||
jobs.POST("/submit", notImplemented)
|
// jobs.POST("/submit", notImplemented) // [已弃用] 已被 POST /tasks 取代
|
||||||
jobs.GET("", notImplemented)
|
jobs.GET("", notImplemented)
|
||||||
jobs.GET("/history", notImplemented)
|
jobs.GET("/history", notImplemented)
|
||||||
jobs.GET("/:id", notImplemented)
|
jobs.GET("/:id", notImplemented)
|
||||||
|
|||||||
@@ -17,7 +17,7 @@ func TestAllRoutesRegistered(t *testing.T) {
|
|||||||
method string
|
method string
|
||||||
path string
|
path string
|
||||||
}{
|
}{
|
||||||
{"POST", "/api/v1/jobs/submit"},
|
// {"POST", "/api/v1/jobs/submit"}, // [已弃用] 已被 POST /tasks 取代
|
||||||
{"GET", "/api/v1/jobs"},
|
{"GET", "/api/v1/jobs"},
|
||||||
{"GET", "/api/v1/jobs/history"},
|
{"GET", "/api/v1/jobs/history"},
|
||||||
{"GET", "/api/v1/jobs/:id"},
|
{"GET", "/api/v1/jobs/:id"},
|
||||||
|
|||||||
@@ -263,7 +263,15 @@ func (s *TaskService) ProcessTask(ctx context.Context, taskID int64) error {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// 13-14. Set ready + submitting
|
// 13-14. Set ready + submitting (guard: skip if already submitted to Slurm)
|
||||||
|
if task.SlurmJobID != nil {
|
||||||
|
s.logger.Info("task already has slurm job, skipping submission",
|
||||||
|
zap.Int64("task_id", taskID),
|
||||||
|
zap.Int32("slurm_job_id", *task.SlurmJobID),
|
||||||
|
)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
if err := s.taskStore.UpdateRetryState(ctx, taskID, model.TaskStatusReady, model.TaskStepSubmitting, 0); err != nil {
|
if err := s.taskStore.UpdateRetryState(ctx, taskID, model.TaskStatusReady, model.TaskStepSubmitting, 0); err != nil {
|
||||||
return fail(model.TaskStepSubmitting, fmt.Sprintf("update status to ready: %v", err))
|
return fail(model.TaskStepSubmitting, fmt.Sprintf("update status to ready: %v", err))
|
||||||
}
|
}
|
||||||
@@ -694,6 +702,13 @@ func (s *TaskService) RecoverStuckTasks(ctx context.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
for i := range tasks {
|
for i := range tasks {
|
||||||
|
if tasks[i].SlurmJobID != nil {
|
||||||
|
s.logger.Info("skipping stuck task recovery, already in slurm",
|
||||||
|
zap.Int64("taskID", tasks[i].ID),
|
||||||
|
zap.Int32("slurm_job_id", *tasks[i].SlurmJobID),
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
}
|
||||||
_ = s.taskStore.UpdateStatus(ctx, tasks[i].ID, model.TaskStatusSubmitted, "")
|
_ = s.taskStore.UpdateStatus(ctx, tasks[i].ID, model.TaskStatusSubmitted, "")
|
||||||
s.mu.Lock()
|
s.mu.Lock()
|
||||||
if !s.stopped {
|
if !s.stopped {
|
||||||
|
|||||||
@@ -49,7 +49,7 @@ func TestAllRoutesRegistered(t *testing.T) {
|
|||||||
method string
|
method string
|
||||||
path string
|
path string
|
||||||
}{
|
}{
|
||||||
{"POST", "/api/v1/jobs/submit"},
|
// {"POST", "/api/v1/jobs/submit"}, // [已弃用] 已被 POST /tasks 取代
|
||||||
{"GET", "/api/v1/jobs"},
|
{"GET", "/api/v1/jobs"},
|
||||||
{"GET", "/api/v1/jobs/history"},
|
{"GET", "/api/v1/jobs/history"},
|
||||||
{"GET", "/api/v1/jobs/1"},
|
{"GET", "/api/v1/jobs/1"},
|
||||||
@@ -82,8 +82,8 @@ func TestAllRoutesRegistered(t *testing.T) {
|
|||||||
{"GET", "/api/v1/tasks"},
|
{"GET", "/api/v1/tasks"},
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(routes) != 30 {
|
if len(routes) != 29 {
|
||||||
t.Fatalf("expected 31 routes, got %d", len(routes))
|
t.Fatalf("expected 30 routes, got %d", len(routes))
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, r := range routes {
|
for _, r := range routes {
|
||||||
|
|||||||
Reference in New Issue
Block a user