Compare commits
7 Commits
246c19c052
...
4ff02d4a80
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
4ff02d4a80 | ||
|
|
1784331969 | ||
|
|
e6162063ca | ||
|
|
4903f7d07f | ||
|
|
fbfd5c5f42 | ||
|
|
f7a21ee455 | ||
|
|
7550e75945 |
39
cmd/server/main.go
Normal file
39
cmd/server/main.go
Normal file
@@ -0,0 +1,39 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
|
||||||
|
"gcy_hpc_server/internal/app"
|
||||||
|
"gcy_hpc_server/internal/config"
|
||||||
|
"gcy_hpc_server/internal/logger"
|
||||||
|
|
||||||
|
"go.uber.org/zap"
|
||||||
|
)
|
||||||
|
|
||||||
|
func main() {
|
||||||
|
cfgPath := ""
|
||||||
|
if len(os.Args) > 1 {
|
||||||
|
cfgPath = os.Args[1]
|
||||||
|
}
|
||||||
|
cfg, err := config.Load(cfgPath)
|
||||||
|
if err != nil {
|
||||||
|
fmt.Fprintf(os.Stderr, "failed to load config: %v\n", err)
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
|
||||||
|
zapLogger, err := logger.NewLogger(cfg.Log)
|
||||||
|
if err != nil {
|
||||||
|
fmt.Fprintf(os.Stderr, "failed to init logger: %v\n", err)
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
defer zapLogger.Sync()
|
||||||
|
|
||||||
|
application, err := app.NewApp(cfg, zapLogger)
|
||||||
|
if err != nil {
|
||||||
|
zapLogger.Fatal("failed to initialize application", zap.Error(err))
|
||||||
|
}
|
||||||
|
if err := application.Run(); err != nil {
|
||||||
|
zapLogger.Fatal("application error", zap.Error(err))
|
||||||
|
}
|
||||||
|
}
|
||||||
117
cmd/server/main_test.go
Normal file
117
cmd/server/main_test.go
Normal file
@@ -0,0 +1,117 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"gcy_hpc_server/internal/handler"
|
||||||
|
"gcy_hpc_server/internal/model"
|
||||||
|
"gcy_hpc_server/internal/server"
|
||||||
|
"gcy_hpc_server/internal/service"
|
||||||
|
"gcy_hpc_server/internal/slurm"
|
||||||
|
"gcy_hpc_server/internal/store"
|
||||||
|
|
||||||
|
"go.uber.org/zap"
|
||||||
|
"gorm.io/driver/sqlite"
|
||||||
|
"gorm.io/gorm"
|
||||||
|
"gorm.io/gorm/logger"
|
||||||
|
)
|
||||||
|
|
||||||
|
func newTestDB() *gorm.DB {
|
||||||
|
db, _ := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{Logger: logger.Default.LogMode(logger.Silent)})
|
||||||
|
db.AutoMigrate(&model.JobTemplate{})
|
||||||
|
return db
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRouterRegistration(t *testing.T) {
|
||||||
|
slurmSrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
json.NewEncoder(w).Encode(map[string]interface{}{"jobs": []interface{}{}})
|
||||||
|
}))
|
||||||
|
defer slurmSrv.Close()
|
||||||
|
|
||||||
|
client, _ := slurm.NewClientWithOpts(slurmSrv.URL, slurm.WithHTTPClient(slurmSrv.Client()))
|
||||||
|
templateStore := store.NewTemplateStore(newTestDB())
|
||||||
|
|
||||||
|
router := server.NewRouter(
|
||||||
|
handler.NewJobHandler(service.NewJobService(client, zap.NewNop()), zap.NewNop()),
|
||||||
|
handler.NewClusterHandler(service.NewClusterService(client, zap.NewNop()), zap.NewNop()),
|
||||||
|
handler.NewTemplateHandler(templateStore, zap.NewNop()),
|
||||||
|
nil,
|
||||||
|
)
|
||||||
|
|
||||||
|
routes := router.Routes()
|
||||||
|
expected := []struct {
|
||||||
|
method string
|
||||||
|
path string
|
||||||
|
}{
|
||||||
|
{"POST", "/api/v1/jobs/submit"},
|
||||||
|
{"GET", "/api/v1/jobs"},
|
||||||
|
{"GET", "/api/v1/jobs/history"},
|
||||||
|
{"GET", "/api/v1/jobs/:id"},
|
||||||
|
{"DELETE", "/api/v1/jobs/:id"},
|
||||||
|
{"GET", "/api/v1/nodes"},
|
||||||
|
{"GET", "/api/v1/nodes/:name"},
|
||||||
|
{"GET", "/api/v1/partitions"},
|
||||||
|
{"GET", "/api/v1/partitions/:name"},
|
||||||
|
{"GET", "/api/v1/diag"},
|
||||||
|
{"GET", "/api/v1/templates"},
|
||||||
|
{"POST", "/api/v1/templates"},
|
||||||
|
{"GET", "/api/v1/templates/:id"},
|
||||||
|
{"PUT", "/api/v1/templates/:id"},
|
||||||
|
{"DELETE", "/api/v1/templates/:id"},
|
||||||
|
}
|
||||||
|
|
||||||
|
routeMap := map[string]bool{}
|
||||||
|
for _, r := range routes {
|
||||||
|
routeMap[r.Method+" "+r.Path] = true
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, exp := range expected {
|
||||||
|
key := exp.method + " " + exp.path
|
||||||
|
if !routeMap[key] {
|
||||||
|
t.Errorf("missing route: %s", key)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(routes) < len(expected) {
|
||||||
|
t.Errorf("expected at least %d routes, got %d", len(expected), len(routes))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSmokeGetJobsEndpoint(t *testing.T) {
|
||||||
|
slurmSrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
json.NewEncoder(w).Encode(map[string]interface{}{"jobs": []interface{}{}})
|
||||||
|
}))
|
||||||
|
defer slurmSrv.Close()
|
||||||
|
|
||||||
|
client, _ := slurm.NewClientWithOpts(slurmSrv.URL, slurm.WithHTTPClient(slurmSrv.Client()))
|
||||||
|
templateStore := store.NewTemplateStore(newTestDB())
|
||||||
|
|
||||||
|
router := server.NewRouter(
|
||||||
|
handler.NewJobHandler(service.NewJobService(client, zap.NewNop()), zap.NewNop()),
|
||||||
|
handler.NewClusterHandler(service.NewClusterService(client, zap.NewNop()), zap.NewNop()),
|
||||||
|
handler.NewTemplateHandler(templateStore, zap.NewNop()),
|
||||||
|
nil,
|
||||||
|
)
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
req, _ := http.NewRequest(http.MethodGet, "/api/v1/jobs", nil)
|
||||||
|
router.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Fatalf("expected 200, got %d: %s", w.Code, w.Body.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
var resp map[string]interface{}
|
||||||
|
if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
|
||||||
|
t.Fatalf("failed to parse response: %v", err)
|
||||||
|
}
|
||||||
|
if success, ok := resp["success"].(bool); !ok || !success {
|
||||||
|
t.Fatalf("expected success=true, got %v", resp["success"])
|
||||||
|
}
|
||||||
|
}
|
||||||
54
go.mod
54
go.mod
@@ -1,3 +1,53 @@
|
|||||||
module slurm-client
|
module gcy_hpc_server
|
||||||
|
|
||||||
go 1.22
|
go 1.25.0
|
||||||
|
|
||||||
|
require (
|
||||||
|
github.com/gin-gonic/gin v1.12.0
|
||||||
|
go.uber.org/zap v1.27.1
|
||||||
|
gopkg.in/natefinch/lumberjack.v2 v2.2.1
|
||||||
|
gopkg.in/yaml.v3 v3.0.1
|
||||||
|
gorm.io/driver/mysql v1.6.0
|
||||||
|
gorm.io/driver/sqlite v1.6.0
|
||||||
|
gorm.io/gorm v1.31.1
|
||||||
|
)
|
||||||
|
|
||||||
|
require (
|
||||||
|
filippo.io/edwards25519 v1.1.0 // indirect
|
||||||
|
github.com/bytedance/gopkg v0.1.3 // indirect
|
||||||
|
github.com/bytedance/sonic v1.15.0 // indirect
|
||||||
|
github.com/bytedance/sonic/loader v0.5.0 // indirect
|
||||||
|
github.com/cloudwego/base64x v0.1.6 // indirect
|
||||||
|
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect
|
||||||
|
github.com/gabriel-vasile/mimetype v1.4.12 // indirect
|
||||||
|
github.com/gin-contrib/sse v1.1.0 // indirect
|
||||||
|
github.com/go-playground/locales v0.14.1 // indirect
|
||||||
|
github.com/go-playground/universal-translator v0.18.1 // indirect
|
||||||
|
github.com/go-playground/validator/v10 v10.30.1 // indirect
|
||||||
|
github.com/go-sql-driver/mysql v1.9.3 // indirect
|
||||||
|
github.com/goccy/go-json v0.10.5 // indirect
|
||||||
|
github.com/goccy/go-yaml v1.19.2 // indirect
|
||||||
|
github.com/jinzhu/inflection v1.0.0 // indirect
|
||||||
|
github.com/jinzhu/now v1.1.5 // indirect
|
||||||
|
github.com/json-iterator/go v1.1.12 // indirect
|
||||||
|
github.com/klauspost/cpuid/v2 v2.3.0 // indirect
|
||||||
|
github.com/leodido/go-urn v1.4.0 // indirect
|
||||||
|
github.com/mattn/go-isatty v0.0.20 // indirect
|
||||||
|
github.com/mattn/go-sqlite3 v1.14.22 // indirect
|
||||||
|
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
|
||||||
|
github.com/modern-go/reflect2 v1.0.2 // indirect
|
||||||
|
github.com/pelletier/go-toml/v2 v2.2.4 // indirect
|
||||||
|
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect
|
||||||
|
github.com/quic-go/qpack v0.6.0 // indirect
|
||||||
|
github.com/quic-go/quic-go v0.59.0 // indirect
|
||||||
|
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
|
||||||
|
github.com/ugorji/go/codec v1.3.1 // indirect
|
||||||
|
go.mongodb.org/mongo-driver/v2 v2.5.0 // indirect
|
||||||
|
go.uber.org/multierr v1.10.0 // indirect
|
||||||
|
golang.org/x/arch v0.22.0 // indirect
|
||||||
|
golang.org/x/crypto v0.48.0 // indirect
|
||||||
|
golang.org/x/net v0.51.0 // indirect
|
||||||
|
golang.org/x/sys v0.41.0 // indirect
|
||||||
|
golang.org/x/text v0.34.0 // indirect
|
||||||
|
google.golang.org/protobuf v1.36.10 // indirect
|
||||||
|
)
|
||||||
|
|||||||
123
go.sum
Normal file
123
go.sum
Normal file
@@ -0,0 +1,123 @@
|
|||||||
|
filippo.io/edwards25519 v1.1.0 h1:FNf4tywRC1HmFuKW5xopWpigGjJKiJSV0Cqo0cJWDaA=
|
||||||
|
filippo.io/edwards25519 v1.1.0/go.mod h1:BxyFTGdWcka3PhytdK4V28tE5sGfRvvvRV7EaN4VDT4=
|
||||||
|
github.com/bytedance/gopkg v0.1.3 h1:TPBSwH8RsouGCBcMBktLt1AymVo2TVsBVCY4b6TnZ/M=
|
||||||
|
github.com/bytedance/gopkg v0.1.3/go.mod h1:576VvJ+eJgyCzdjS+c4+77QF3p7ubbtiKARP3TxducM=
|
||||||
|
github.com/bytedance/sonic v1.15.0 h1:/PXeWFaR5ElNcVE84U0dOHjiMHQOwNIx3K4ymzh/uSE=
|
||||||
|
github.com/bytedance/sonic v1.15.0/go.mod h1:tFkWrPz0/CUCLEF4ri4UkHekCIcdnkqXw9VduqpJh0k=
|
||||||
|
github.com/bytedance/sonic/loader v0.5.0 h1:gXH3KVnatgY7loH5/TkeVyXPfESoqSBSBEiDd5VjlgE=
|
||||||
|
github.com/bytedance/sonic/loader v0.5.0/go.mod h1:AR4NYCk5DdzZizZ5djGqQ92eEhCCcdf5x77udYiSJRo=
|
||||||
|
github.com/cloudwego/base64x v0.1.6 h1:t11wG9AECkCDk5fMSoxmufanudBtJ+/HemLstXDLI2M=
|
||||||
|
github.com/cloudwego/base64x v0.1.6/go.mod h1:OFcloc187FXDaYHvrNIjxSe8ncn0OOM8gEHfghB2IPU=
|
||||||
|
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||||
|
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||||
|
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM=
|
||||||
|
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||||
|
github.com/gabriel-vasile/mimetype v1.4.12 h1:e9hWvmLYvtp846tLHam2o++qitpguFiYCKbn0w9jyqw=
|
||||||
|
github.com/gabriel-vasile/mimetype v1.4.12/go.mod h1:d+9Oxyo1wTzWdyVUPMmXFvp4F9tea18J8ufA774AB3s=
|
||||||
|
github.com/gin-contrib/sse v1.1.0 h1:n0w2GMuUpWDVp7qSpvze6fAu9iRxJY4Hmj6AmBOU05w=
|
||||||
|
github.com/gin-contrib/sse v1.1.0/go.mod h1:hxRZ5gVpWMT7Z0B0gSNYqqsSCNIJMjzvm6fqCz9vjwM=
|
||||||
|
github.com/gin-gonic/gin v1.12.0 h1:b3YAbrZtnf8N//yjKeU2+MQsh2mY5htkZidOM7O0wG8=
|
||||||
|
github.com/gin-gonic/gin v1.12.0/go.mod h1:VxccKfsSllpKshkBWgVgRniFFAzFb9csfngsqANjnLc=
|
||||||
|
github.com/go-playground/assert/v2 v2.2.0 h1:JvknZsQTYeFEAhQwI4qEt9cyV5ONwRHC+lYKSsYSR8s=
|
||||||
|
github.com/go-playground/assert/v2 v2.2.0/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4=
|
||||||
|
github.com/go-playground/locales v0.14.1 h1:EWaQ/wswjilfKLTECiXz7Rh+3BjFhfDFKv/oXslEjJA=
|
||||||
|
github.com/go-playground/locales v0.14.1/go.mod h1:hxrqLVvrK65+Rwrd5Fc6F2O76J/NuW9t0sjnWqG1slY=
|
||||||
|
github.com/go-playground/universal-translator v0.18.1 h1:Bcnm0ZwsGyWbCzImXv+pAJnYK9S473LQFuzCbDbfSFY=
|
||||||
|
github.com/go-playground/universal-translator v0.18.1/go.mod h1:xekY+UJKNuX9WP91TpwSH2VMlDf28Uj24BCp08ZFTUY=
|
||||||
|
github.com/go-playground/validator/v10 v10.30.1 h1:f3zDSN/zOma+w6+1Wswgd9fLkdwy06ntQJp0BBvFG0w=
|
||||||
|
github.com/go-playground/validator/v10 v10.30.1/go.mod h1:oSuBIQzuJxL//3MelwSLD5hc2Tu889bF0Idm9Dg26cM=
|
||||||
|
github.com/go-sql-driver/mysql v1.9.3 h1:U/N249h2WzJ3Ukj8SowVFjdtZKfu9vlLZxjPXV1aweo=
|
||||||
|
github.com/go-sql-driver/mysql v1.9.3/go.mod h1:qn46aNg1333BRMNU69Lq93t8du/dwxI64Gl8i5p1WMU=
|
||||||
|
github.com/goccy/go-json v0.10.5 h1:Fq85nIqj+gXn/S5ahsiTlK3TmC85qgirsdTP/+DeaC4=
|
||||||
|
github.com/goccy/go-json v0.10.5/go.mod h1:oq7eo15ShAhp70Anwd5lgX2pLfOS3QCiwU/PULtXL6M=
|
||||||
|
github.com/goccy/go-yaml v1.19.2 h1:PmFC1S6h8ljIz6gMRBopkjP1TVT7xuwrButHID66PoM=
|
||||||
|
github.com/goccy/go-yaml v1.19.2/go.mod h1:XBurs7gK8ATbW4ZPGKgcbrY1Br56PdM69F7LkFRi1kA=
|
||||||
|
github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
|
||||||
|
github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU=
|
||||||
|
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
|
||||||
|
github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E=
|
||||||
|
github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc=
|
||||||
|
github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ=
|
||||||
|
github.com/jinzhu/now v1.1.5/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8=
|
||||||
|
github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM=
|
||||||
|
github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo=
|
||||||
|
github.com/klauspost/cpuid/v2 v2.3.0 h1:S4CRMLnYUhGeDFDqkGriYKdfoFlDnMtqTiI/sFzhA9Y=
|
||||||
|
github.com/klauspost/cpuid/v2 v2.3.0/go.mod h1:hqwkgyIinND0mEev00jJYCxPNVRVXFQeu1XKlok6oO0=
|
||||||
|
github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE=
|
||||||
|
github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk=
|
||||||
|
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
|
||||||
|
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
|
||||||
|
github.com/leodido/go-urn v1.4.0 h1:WT9HwE9SGECu3lg4d/dIA+jxlljEa1/ffXKmRjqdmIQ=
|
||||||
|
github.com/leodido/go-urn v1.4.0/go.mod h1:bvxc+MVxLKB4z00jd1z+Dvzr47oO32F/QSNjSBOlFxI=
|
||||||
|
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
|
||||||
|
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
|
||||||
|
github.com/mattn/go-sqlite3 v1.14.22 h1:2gZY6PC6kBnID23Tichd1K+Z0oS6nE/XwU+Vz/5o4kU=
|
||||||
|
github.com/mattn/go-sqlite3 v1.14.22/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y=
|
||||||
|
github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
|
||||||
|
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w8PVh93nsPXa1VrQ6jlwL5oN8l14QlcNfg=
|
||||||
|
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
|
||||||
|
github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9Gz0M=
|
||||||
|
github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk=
|
||||||
|
github.com/pelletier/go-toml/v2 v2.2.4 h1:mye9XuhQ6gvn5h28+VilKrrPoQVanw5PMw/TB0t5Ec4=
|
||||||
|
github.com/pelletier/go-toml/v2 v2.2.4/go.mod h1:2gIqNv+qfxSVS7cM2xJQKtLSTLUE9V8t9Stt+h56mCY=
|
||||||
|
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||||
|
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRIccs7FGNTlIRMkT8wgtp5eCXdBlqhYGL6U=
|
||||||
|
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||||
|
github.com/quic-go/qpack v0.6.0 h1:g7W+BMYynC1LbYLSqRt8PBg5Tgwxn214ZZR34VIOjz8=
|
||||||
|
github.com/quic-go/qpack v0.6.0/go.mod h1:lUpLKChi8njB4ty2bFLX2x4gzDqXwUpaO1DP9qMDZII=
|
||||||
|
github.com/quic-go/quic-go v0.59.0 h1:OLJkp1Mlm/aS7dpKgTc6cnpynnD2Xg7C1pwL6vy/SAw=
|
||||||
|
github.com/quic-go/quic-go v0.59.0/go.mod h1:upnsH4Ju1YkqpLXC305eW3yDZ4NfnNbmQRCMWS58IKU=
|
||||||
|
github.com/rogpeppe/go-internal v1.10.0 h1:TMyTOH3F/DB16zRVcYyreMH6GnZZrwQVAoYjRBZyWFQ=
|
||||||
|
github.com/rogpeppe/go-internal v1.10.0/go.mod h1:UQnix2H7Ngw/k4C5ijL5+65zddjncjaFoBhdsK/akog=
|
||||||
|
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
|
||||||
|
github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
|
||||||
|
github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
|
||||||
|
github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA=
|
||||||
|
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
|
||||||
|
github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
|
||||||
|
github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
|
||||||
|
github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
|
||||||
|
github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
|
||||||
|
github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
|
||||||
|
github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
|
||||||
|
github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS4MhqMhdFk5YI=
|
||||||
|
github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08=
|
||||||
|
github.com/ugorji/go/codec v1.3.1 h1:waO7eEiFDwidsBN6agj1vJQ4AG7lh2yqXyOXqhgQuyY=
|
||||||
|
github.com/ugorji/go/codec v1.3.1/go.mod h1:pRBVtBSKl77K30Bv8R2P+cLSGaTtex6fsA2Wjqmfxj4=
|
||||||
|
go.mongodb.org/mongo-driver/v2 v2.5.0 h1:yXUhImUjjAInNcpTcAlPHiT7bIXhshCTL3jVBkF3xaE=
|
||||||
|
go.mongodb.org/mongo-driver/v2 v2.5.0/go.mod h1:yOI9kBsufol30iFsl1slpdq1I0eHPzybRWdyYUs8K/0=
|
||||||
|
go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto=
|
||||||
|
go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE=
|
||||||
|
go.uber.org/mock v0.6.0 h1:hyF9dfmbgIX5EfOdasqLsWD6xqpNZlXblLB/Dbnwv3Y=
|
||||||
|
go.uber.org/mock v0.6.0/go.mod h1:KiVJ4BqZJaMj4svdfmHM0AUx4NJYO8ZNpPnZn1Z+BBU=
|
||||||
|
go.uber.org/multierr v1.10.0 h1:S0h4aNzvfcFsC3dRF1jLoaov7oRaKqRGC/pUEJ2yvPQ=
|
||||||
|
go.uber.org/multierr v1.10.0/go.mod h1:20+QtiLqy0Nd6FdQB9TLXag12DsQkrbs3htMFfDN80Y=
|
||||||
|
go.uber.org/zap v1.27.1 h1:08RqriUEv8+ArZRYSTXy1LeBScaMpVSTBhCeaZYfMYc=
|
||||||
|
go.uber.org/zap v1.27.1/go.mod h1:GB2qFLM7cTU87MWRP2mPIjqfIDnGu+VIO4V/SdhGo2E=
|
||||||
|
golang.org/x/arch v0.22.0 h1:c/Zle32i5ttqRXjdLyyHZESLD/bB90DCU1g9l/0YBDI=
|
||||||
|
golang.org/x/arch v0.22.0/go.mod h1:dNHoOeKiyja7GTvF9NJS1l3Z2yntpQNzgrjh1cU103A=
|
||||||
|
golang.org/x/crypto v0.48.0 h1:/VRzVqiRSggnhY7gNRxPauEQ5Drw9haKdM0jqfcCFts=
|
||||||
|
golang.org/x/crypto v0.48.0/go.mod h1:r0kV5h3qnFPlQnBSrULhlsRfryS2pmewsg+XfMgkVos=
|
||||||
|
golang.org/x/net v0.51.0 h1:94R/GTO7mt3/4wIKpcR5gkGmRLOuE/2hNGeWq/GBIFo=
|
||||||
|
golang.org/x/net v0.51.0/go.mod h1:aamm+2QF5ogm02fjy5Bb7CQ0WMt1/WVM7FtyaTLlA9Y=
|
||||||
|
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||||
|
golang.org/x/sys v0.41.0 h1:Ivj+2Cp/ylzLiEU89QhWblYnOE9zerudt9Ftecq2C6k=
|
||||||
|
golang.org/x/sys v0.41.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
|
||||||
|
golang.org/x/text v0.34.0 h1:oL/Qq0Kdaqxa1KbNeMKwQq0reLCCaFtqu2eNuSeNHbk=
|
||||||
|
golang.org/x/text v0.34.0/go.mod h1:homfLqTYRFyVYemLBFl5GgL/DWEiH5wcsQ5gSh1yziA=
|
||||||
|
google.golang.org/protobuf v1.36.10 h1:AYd7cD/uASjIL6Q9LiTjz8JLcrh/88q5UObnmY3aOOE=
|
||||||
|
google.golang.org/protobuf v1.36.10/go.mod h1:HTf+CrKn2C3g5S8VImy6tdcUvCska2kB7j23XfzDpco=
|
||||||
|
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||||
|
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=
|
||||||
|
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q=
|
||||||
|
gopkg.in/natefinch/lumberjack.v2 v2.2.1 h1:bBRl1b0OH9s/DuPhuXpNl+VtCaJXFZ5/uEFST95x9zc=
|
||||||
|
gopkg.in/natefinch/lumberjack.v2 v2.2.1/go.mod h1:YD8tP3GAjkrDg1eZH7EGmyESg/lsYskCTPBJVb9jqSc=
|
||||||
|
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||||
|
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
||||||
|
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||||
|
gorm.io/driver/mysql v1.6.0 h1:eNbLmNTpPpTOVZi8MMxCi2aaIm0ZpInbORNXDwyLGvg=
|
||||||
|
gorm.io/driver/mysql v1.6.0/go.mod h1:D/oCC2GWK3M/dqoLxnOlaNKmXz8WNTfcS9y5ovaSqKo=
|
||||||
|
gorm.io/driver/sqlite v1.6.0 h1:WHRRrIiulaPiPFmDcod6prc4l2VGVWHz80KspNsxSfQ=
|
||||||
|
gorm.io/driver/sqlite v1.6.0/go.mod h1:AO9V1qIQddBESngQUKWL9yoH93HIeA1X6V633rBwyT8=
|
||||||
|
gorm.io/gorm v1.31.1 h1:7CA8FTFz/gRfgqgpeKIBcervUn3xSyPUmr6B2WXJ7kg=
|
||||||
|
gorm.io/gorm v1.31.1/go.mod h1:XyQVbO2k6YkOis7C2437jSit3SsDK72s7n7rsSHd+Gs=
|
||||||
159
internal/app/app.go
Normal file
159
internal/app/app.go
Normal file
@@ -0,0 +1,159 @@
|
|||||||
|
package app
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"os"
|
||||||
|
"os/signal"
|
||||||
|
"syscall"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"gcy_hpc_server/internal/config"
|
||||||
|
"gcy_hpc_server/internal/handler"
|
||||||
|
"gcy_hpc_server/internal/server"
|
||||||
|
"gcy_hpc_server/internal/service"
|
||||||
|
"gcy_hpc_server/internal/slurm"
|
||||||
|
"gcy_hpc_server/internal/store"
|
||||||
|
|
||||||
|
"go.uber.org/zap"
|
||||||
|
"gorm.io/gorm"
|
||||||
|
)
|
||||||
|
|
||||||
|
// App encapsulates the entire application lifecycle.
|
||||||
|
type App struct {
|
||||||
|
cfg *config.Config
|
||||||
|
logger *zap.Logger
|
||||||
|
db *gorm.DB
|
||||||
|
server *http.Server
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewApp initializes all application dependencies: DB, Slurm client, services, handlers, router.
|
||||||
|
func NewApp(cfg *config.Config, logger *zap.Logger) (*App, error) {
|
||||||
|
gormDB, err := initDB(cfg, logger)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
slurmClient, err := initSlurmClient(cfg)
|
||||||
|
if err != nil {
|
||||||
|
closeDB(gormDB)
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
srv := initHTTPServer(cfg, gormDB, slurmClient, logger)
|
||||||
|
|
||||||
|
return &App{
|
||||||
|
cfg: cfg,
|
||||||
|
logger: logger,
|
||||||
|
db: gormDB,
|
||||||
|
server: srv,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Run starts the HTTP server and blocks until a shutdown signal or server error.
|
||||||
|
func (a *App) Run() error {
|
||||||
|
errCh := make(chan error, 1)
|
||||||
|
go func() {
|
||||||
|
a.logger.Info("starting server", zap.String("addr", a.server.Addr))
|
||||||
|
if err := a.server.ListenAndServe(); err != nil && err != http.ErrServerClosed {
|
||||||
|
errCh <- fmt.Errorf("server listen: %w", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
quit := make(chan os.Signal, 1)
|
||||||
|
signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM)
|
||||||
|
|
||||||
|
select {
|
||||||
|
case err := <-errCh:
|
||||||
|
// Server crashed before receiving a signal — clean up resources before
|
||||||
|
// returning, because the caller may call os.Exit and skip deferred Close().
|
||||||
|
a.logger.Error("server exited unexpectedly", zap.Error(err))
|
||||||
|
_ = a.Close()
|
||||||
|
return err
|
||||||
|
case sig := <-quit:
|
||||||
|
a.logger.Info("received shutdown signal", zap.String("signal", sig.String()))
|
||||||
|
}
|
||||||
|
|
||||||
|
a.logger.Info("shutting down server...")
|
||||||
|
return a.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close cleans up all resources: HTTP server and database connections.
|
||||||
|
func (a *App) Close() error {
|
||||||
|
var errs []error
|
||||||
|
|
||||||
|
if a.server != nil {
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
if err := a.server.Shutdown(ctx); err != nil && err != http.ErrServerClosed {
|
||||||
|
errs = append(errs, fmt.Errorf("shutdown http server: %w", err))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if a.db != nil {
|
||||||
|
sqlDB, err := a.db.DB()
|
||||||
|
if err != nil {
|
||||||
|
errs = append(errs, fmt.Errorf("get underlying sql.DB: %w", err))
|
||||||
|
} else if err := sqlDB.Close(); err != nil {
|
||||||
|
errs = append(errs, fmt.Errorf("close database: %w", err))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return errors.Join(errs...)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
// Initialization helpers
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
func initDB(cfg *config.Config, logger *zap.Logger) (*gorm.DB, error) {
|
||||||
|
gormDB, err := store.NewGormDB(cfg.MySQLDSN, logger, cfg.Log.GormLevel)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("init database: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := store.AutoMigrate(gormDB); err != nil {
|
||||||
|
closeDB(gormDB)
|
||||||
|
return nil, fmt.Errorf("run migrations: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return gormDB, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func closeDB(db *gorm.DB) {
|
||||||
|
if db == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if sqlDB, err := db.DB(); err == nil {
|
||||||
|
_ = sqlDB.Close()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func initSlurmClient(cfg *config.Config) (*slurm.Client, error) {
|
||||||
|
client, err := service.NewSlurmClient(cfg.SlurmAPIURL, cfg.SlurmUserName, cfg.SlurmJWTKeyPath)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("init slurm client: %w", err)
|
||||||
|
}
|
||||||
|
return client, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func initHTTPServer(cfg *config.Config, db *gorm.DB, slurmClient *slurm.Client, logger *zap.Logger) *http.Server {
|
||||||
|
jobSvc := service.NewJobService(slurmClient, logger)
|
||||||
|
clusterSvc := service.NewClusterService(slurmClient, logger)
|
||||||
|
templateStore := store.NewTemplateStore(db)
|
||||||
|
|
||||||
|
jobH := handler.NewJobHandler(jobSvc, logger)
|
||||||
|
clusterH := handler.NewClusterHandler(clusterSvc, logger)
|
||||||
|
templateH := handler.NewTemplateHandler(templateStore, logger)
|
||||||
|
|
||||||
|
router := server.NewRouter(jobH, clusterH, templateH, logger)
|
||||||
|
|
||||||
|
addr := ":" + cfg.ServerPort
|
||||||
|
|
||||||
|
return &http.Server{
|
||||||
|
Addr: addr,
|
||||||
|
Handler: router,
|
||||||
|
}
|
||||||
|
}
|
||||||
25
internal/app/app_test.go
Normal file
25
internal/app/app_test.go
Normal file
@@ -0,0 +1,25 @@
|
|||||||
|
package app
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"gcy_hpc_server/internal/config"
|
||||||
|
|
||||||
|
"go.uber.org/zap/zaptest"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestNewApp_InvalidDB(t *testing.T) {
|
||||||
|
cfg := &config.Config{
|
||||||
|
ServerPort: "8080",
|
||||||
|
MySQLDSN: "invalid:dsn@tcp(localhost:99999)/nonexistent?parseTime=true",
|
||||||
|
SlurmAPIURL: "http://localhost:6820",
|
||||||
|
SlurmUserName: "root",
|
||||||
|
SlurmJWTKeyPath: "/nonexistent/jwt.key",
|
||||||
|
}
|
||||||
|
logger := zaptest.NewLogger(t)
|
||||||
|
|
||||||
|
_, err := NewApp(cfg, logger)
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected error for invalid DSN, got nil")
|
||||||
|
}
|
||||||
|
}
|
||||||
16
internal/config/config.example.yaml
Normal file
16
internal/config/config.example.yaml
Normal file
@@ -0,0 +1,16 @@
|
|||||||
|
server_port: "8080"
|
||||||
|
slurm_api_url: "http://localhost:6820"
|
||||||
|
slurm_user_name: "root"
|
||||||
|
slurm_jwt_key_path: "/etc/slurm/jwt_hs256.key"
|
||||||
|
mysql_dsn: "root:@tcp(127.0.0.1:3306)/hpc_platform?parseTime=true"
|
||||||
|
|
||||||
|
log:
|
||||||
|
level: "info" # debug, info, warn, error
|
||||||
|
encoding: "json" # json, console
|
||||||
|
output_stdout: true # 是否输出日志到终端
|
||||||
|
file_path: "" # 日志文件路径,留空则不写文件
|
||||||
|
max_size: 100 # max MB per log file
|
||||||
|
max_backups: 5 # number of old log files to retain
|
||||||
|
max_age: 30 # days to retain old log files
|
||||||
|
compress: true # gzip rotated log files
|
||||||
|
gorm_level: "warn" # GORM SQL log level: silent, error, warn, info
|
||||||
51
internal/config/config.go
Normal file
51
internal/config/config.go
Normal file
@@ -0,0 +1,51 @@
|
|||||||
|
package config
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
|
||||||
|
"gopkg.in/yaml.v3"
|
||||||
|
)
|
||||||
|
|
||||||
|
// LogConfig holds logging configuration values.
|
||||||
|
type LogConfig struct {
|
||||||
|
Level string `yaml:"level"` // debug, info, warn, error (default: info)
|
||||||
|
Encoding string `yaml:"encoding"` // json, console (default: json)
|
||||||
|
OutputStdout *bool `yaml:"output_stdout"` // 输出到终端 (default: true)
|
||||||
|
FilePath string `yaml:"file_path"` // log file path (for rotation)
|
||||||
|
MaxSize int `yaml:"max_size"` // MB per file (default: 100)
|
||||||
|
MaxBackups int `yaml:"max_backups"` // retained files (default: 5)
|
||||||
|
MaxAge int `yaml:"max_age"` // days to retain (default: 30)
|
||||||
|
Compress bool `yaml:"compress"` // gzip old files (default: true)
|
||||||
|
GormLevel string `yaml:"gorm_level"` // GORM SQL log level (default: warn)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Config holds all application configuration values.
|
||||||
|
type Config struct {
|
||||||
|
ServerPort string `yaml:"server_port"`
|
||||||
|
SlurmAPIURL string `yaml:"slurm_api_url"`
|
||||||
|
SlurmUserName string `yaml:"slurm_user_name"`
|
||||||
|
SlurmJWTKeyPath string `yaml:"slurm_jwt_key_path"`
|
||||||
|
MySQLDSN string `yaml:"mysql_dsn"`
|
||||||
|
Log LogConfig `yaml:"log"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// Load reads a YAML configuration file and returns a parsed Config.
|
||||||
|
// If path is empty, it defaults to "./config.yaml".
|
||||||
|
func Load(path string) (*Config, error) {
|
||||||
|
if path == "" {
|
||||||
|
path = "./config.yaml"
|
||||||
|
}
|
||||||
|
|
||||||
|
data, err := os.ReadFile(path)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("read config file %s: %w", path, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var cfg Config
|
||||||
|
if err := yaml.Unmarshal(data, &cfg); err != nil {
|
||||||
|
return nil, fmt.Errorf("parse config file %s: %w", path, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return &cfg, nil
|
||||||
|
}
|
||||||
256
internal/config/config_test.go
Normal file
256
internal/config/config_test.go
Normal file
@@ -0,0 +1,256 @@
|
|||||||
|
package config
|
||||||
|
|
||||||
|
import (
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestLoad(t *testing.T) {
|
||||||
|
content := []byte(`server_port: "9090"
|
||||||
|
slurm_api_url: "http://slurm.example.com:6820"
|
||||||
|
slurm_user_name: "admin"
|
||||||
|
slurm_jwt_key_path: "/etc/slurm/jwt.key"
|
||||||
|
mysql_dsn: "user:pass@tcp(10.0.0.1:3306)/testdb?parseTime=true"
|
||||||
|
`)
|
||||||
|
|
||||||
|
dir := t.TempDir()
|
||||||
|
path := filepath.Join(dir, "config.yaml")
|
||||||
|
if err := os.WriteFile(path, content, 0644); err != nil {
|
||||||
|
t.Fatalf("write temp config: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
cfg, err := Load(path)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Load() error = %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if cfg.ServerPort != "9090" {
|
||||||
|
t.Errorf("ServerPort = %q, want %q", cfg.ServerPort, "9090")
|
||||||
|
}
|
||||||
|
if cfg.SlurmAPIURL != "http://slurm.example.com:6820" {
|
||||||
|
t.Errorf("SlurmAPIURL = %q, want %q", cfg.SlurmAPIURL, "http://slurm.example.com:6820")
|
||||||
|
}
|
||||||
|
if cfg.SlurmUserName != "admin" {
|
||||||
|
t.Errorf("SlurmUserName = %q, want %q", cfg.SlurmUserName, "admin")
|
||||||
|
}
|
||||||
|
if cfg.SlurmJWTKeyPath != "/etc/slurm/jwt.key" {
|
||||||
|
t.Errorf("SlurmJWTKeyPath = %q, want %q", cfg.SlurmJWTKeyPath, "/etc/slurm/jwt.key")
|
||||||
|
}
|
||||||
|
if cfg.MySQLDSN != "user:pass@tcp(10.0.0.1:3306)/testdb?parseTime=true" {
|
||||||
|
t.Errorf("MySQLDSN = %q, want %q", cfg.MySQLDSN, "user:pass@tcp(10.0.0.1:3306)/testdb?parseTime=true")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLoadDefaultPath(t *testing.T) {
|
||||||
|
content := []byte(`server_port: "8080"
|
||||||
|
slurm_api_url: "http://localhost:6820"
|
||||||
|
slurm_user_name: "root"
|
||||||
|
slurm_jwt_key_path: "/etc/slurm/jwt_hs256.key"
|
||||||
|
mysql_dsn: "root:@tcp(127.0.0.1:3306)/hpc_platform?parseTime=true"
|
||||||
|
`)
|
||||||
|
|
||||||
|
dir := t.TempDir()
|
||||||
|
path := filepath.Join(dir, "config.yaml")
|
||||||
|
if err := os.WriteFile(path, content, 0644); err != nil {
|
||||||
|
t.Fatalf("write temp config: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
cfg, err := Load(path)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Load() error = %v", err)
|
||||||
|
}
|
||||||
|
if cfg == nil {
|
||||||
|
t.Fatal("Load() returned nil config")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLoadWithLogConfig(t *testing.T) {
|
||||||
|
content := []byte(`server_port: "9090"
|
||||||
|
slurm_api_url: "http://slurm.example.com:6820"
|
||||||
|
slurm_user_name: "admin"
|
||||||
|
slurm_jwt_key_path: "/etc/slurm/jwt.key"
|
||||||
|
mysql_dsn: "user:pass@tcp(10.0.0.1:3306)/testdb?parseTime=true"
|
||||||
|
log:
|
||||||
|
level: "debug"
|
||||||
|
encoding: "console"
|
||||||
|
output_stdout: true
|
||||||
|
file_path: "/var/log/app.log"
|
||||||
|
max_size: 200
|
||||||
|
max_backups: 10
|
||||||
|
max_age: 60
|
||||||
|
compress: false
|
||||||
|
gorm_level: "info"
|
||||||
|
`)
|
||||||
|
|
||||||
|
dir := t.TempDir()
|
||||||
|
path := filepath.Join(dir, "config.yaml")
|
||||||
|
if err := os.WriteFile(path, content, 0644); err != nil {
|
||||||
|
t.Fatalf("write temp config: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
cfg, err := Load(path)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Load() error = %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if cfg.Log.Level != "debug" {
|
||||||
|
t.Errorf("Log.Level = %q, want %q", cfg.Log.Level, "debug")
|
||||||
|
}
|
||||||
|
if cfg.Log.Encoding != "console" {
|
||||||
|
t.Errorf("Log.Encoding = %q, want %q", cfg.Log.Encoding, "console")
|
||||||
|
}
|
||||||
|
if cfg.Log.OutputStdout == nil || *cfg.Log.OutputStdout != true {
|
||||||
|
t.Errorf("Log.OutputStdout = %v, want true", cfg.Log.OutputStdout)
|
||||||
|
}
|
||||||
|
if cfg.Log.FilePath != "/var/log/app.log" {
|
||||||
|
t.Errorf("Log.FilePath = %q, want %q", cfg.Log.FilePath, "/var/log/app.log")
|
||||||
|
}
|
||||||
|
if cfg.Log.MaxSize != 200 {
|
||||||
|
t.Errorf("Log.MaxSize = %d, want %d", cfg.Log.MaxSize, 200)
|
||||||
|
}
|
||||||
|
if cfg.Log.MaxBackups != 10 {
|
||||||
|
t.Errorf("Log.MaxBackups = %d, want %d", cfg.Log.MaxBackups, 10)
|
||||||
|
}
|
||||||
|
if cfg.Log.MaxAge != 60 {
|
||||||
|
t.Errorf("Log.MaxAge = %d, want %d", cfg.Log.MaxAge, 60)
|
||||||
|
}
|
||||||
|
if cfg.Log.Compress != false {
|
||||||
|
t.Errorf("Log.Compress = %v, want %v", cfg.Log.Compress, false)
|
||||||
|
}
|
||||||
|
if cfg.Log.GormLevel != "info" {
|
||||||
|
t.Errorf("Log.GormLevel = %q, want %q", cfg.Log.GormLevel, "info")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLoadWithoutLogConfig(t *testing.T) {
|
||||||
|
content := []byte(`server_port: "8080"
|
||||||
|
slurm_api_url: "http://localhost:6820"
|
||||||
|
slurm_user_name: "root"
|
||||||
|
slurm_jwt_key_path: "/etc/slurm/jwt_hs256.key"
|
||||||
|
mysql_dsn: "root:@tcp(127.0.0.1:3306)/hpc_platform?parseTime=true"
|
||||||
|
`)
|
||||||
|
|
||||||
|
dir := t.TempDir()
|
||||||
|
path := filepath.Join(dir, "config.yaml")
|
||||||
|
if err := os.WriteFile(path, content, 0644); err != nil {
|
||||||
|
t.Fatalf("write temp config: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
cfg, err := Load(path)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Load() error = %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if cfg.Log.Level != "" {
|
||||||
|
t.Errorf("Log.Level = %q, want empty string", cfg.Log.Level)
|
||||||
|
}
|
||||||
|
if cfg.Log.Encoding != "" {
|
||||||
|
t.Errorf("Log.Encoding = %q, want empty string", cfg.Log.Encoding)
|
||||||
|
}
|
||||||
|
if cfg.Log.OutputStdout != nil {
|
||||||
|
t.Errorf("Log.OutputStdout = %v, want nil", cfg.Log.OutputStdout)
|
||||||
|
}
|
||||||
|
if cfg.Log.FilePath != "" {
|
||||||
|
t.Errorf("Log.FilePath = %q, want empty string", cfg.Log.FilePath)
|
||||||
|
}
|
||||||
|
if cfg.Log.MaxSize != 0 {
|
||||||
|
t.Errorf("Log.MaxSize = %d, want 0", cfg.Log.MaxSize)
|
||||||
|
}
|
||||||
|
if cfg.Log.MaxBackups != 0 {
|
||||||
|
t.Errorf("Log.MaxBackups = %d, want 0", cfg.Log.MaxBackups)
|
||||||
|
}
|
||||||
|
if cfg.Log.MaxAge != 0 {
|
||||||
|
t.Errorf("Log.MaxAge = %d, want 0", cfg.Log.MaxAge)
|
||||||
|
}
|
||||||
|
if cfg.Log.Compress != false {
|
||||||
|
t.Errorf("Log.Compress = %v, want false", cfg.Log.Compress)
|
||||||
|
}
|
||||||
|
if cfg.Log.GormLevel != "" {
|
||||||
|
t.Errorf("Log.GormLevel = %q, want empty string", cfg.Log.GormLevel)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLoadExistingFieldsWithLogConfig(t *testing.T) {
|
||||||
|
content := []byte(`server_port: "7070"
|
||||||
|
slurm_api_url: "http://slurm2.example.com:6820"
|
||||||
|
slurm_user_name: "testuser"
|
||||||
|
slurm_jwt_key_path: "/keys/jwt.key"
|
||||||
|
mysql_dsn: "root:secret@tcp(db:3306)/mydb?parseTime=true"
|
||||||
|
log:
|
||||||
|
level: "warn"
|
||||||
|
encoding: "json"
|
||||||
|
`)
|
||||||
|
|
||||||
|
dir := t.TempDir()
|
||||||
|
path := filepath.Join(dir, "config.yaml")
|
||||||
|
if err := os.WriteFile(path, content, 0644); err != nil {
|
||||||
|
t.Fatalf("write temp config: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
cfg, err := Load(path)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Load() error = %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if cfg.ServerPort != "7070" {
|
||||||
|
t.Errorf("ServerPort = %q, want %q", cfg.ServerPort, "7070")
|
||||||
|
}
|
||||||
|
if cfg.SlurmAPIURL != "http://slurm2.example.com:6820" {
|
||||||
|
t.Errorf("SlurmAPIURL = %q, want %q", cfg.SlurmAPIURL, "http://slurm2.example.com:6820")
|
||||||
|
}
|
||||||
|
if cfg.SlurmUserName != "testuser" {
|
||||||
|
t.Errorf("SlurmUserName = %q, want %q", cfg.SlurmUserName, "testuser")
|
||||||
|
}
|
||||||
|
if cfg.SlurmJWTKeyPath != "/keys/jwt.key" {
|
||||||
|
t.Errorf("SlurmJWTKeyPath = %q, want %q", cfg.SlurmJWTKeyPath, "/keys/jwt.key")
|
||||||
|
}
|
||||||
|
if cfg.MySQLDSN != "root:secret@tcp(db:3306)/mydb?parseTime=true" {
|
||||||
|
t.Errorf("MySQLDSN = %q, want %q", cfg.MySQLDSN, "root:secret@tcp(db:3306)/mydb?parseTime=true")
|
||||||
|
}
|
||||||
|
if cfg.Log.Level != "warn" {
|
||||||
|
t.Errorf("Log.Level = %q, want %q", cfg.Log.Level, "warn")
|
||||||
|
}
|
||||||
|
if cfg.Log.Encoding != "json" {
|
||||||
|
t.Errorf("Log.Encoding = %q, want %q", cfg.Log.Encoding, "json")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLoadNonExistentFile(t *testing.T) {
|
||||||
|
_, err := Load("/nonexistent/path/config.yaml")
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("Load() expected error for non-existent file, got nil")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLoadWithOutputStdoutFalse(t *testing.T) {
|
||||||
|
content := []byte(`server_port: "9090"
|
||||||
|
slurm_api_url: "http://slurm.example.com:6820"
|
||||||
|
slurm_user_name: "admin"
|
||||||
|
slurm_jwt_key_path: "/etc/slurm/jwt.key"
|
||||||
|
mysql_dsn: "user:pass@tcp(10.0.0.1:3306)/testdb?parseTime=true"
|
||||||
|
log:
|
||||||
|
level: "info"
|
||||||
|
encoding: "json"
|
||||||
|
output_stdout: false
|
||||||
|
file_path: "/var/log/app.log"
|
||||||
|
`)
|
||||||
|
|
||||||
|
dir := t.TempDir()
|
||||||
|
path := filepath.Join(dir, "config.yaml")
|
||||||
|
if err := os.WriteFile(path, content, 0644); err != nil {
|
||||||
|
t.Fatalf("write temp config: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
cfg, err := Load(path)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Load() error = %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if cfg.Log.OutputStdout == nil || *cfg.Log.OutputStdout != false {
|
||||||
|
t.Errorf("Log.OutputStdout = %v, want false", cfg.Log.OutputStdout)
|
||||||
|
}
|
||||||
|
if cfg.Log.FilePath != "/var/log/app.log" {
|
||||||
|
t.Errorf("Log.FilePath = %q, want %q", cfg.Log.FilePath, "/var/log/app.log")
|
||||||
|
}
|
||||||
|
}
|
||||||
94
internal/handler/cluster.go
Normal file
94
internal/handler/cluster.go
Normal file
@@ -0,0 +1,94 @@
|
|||||||
|
package handler
|
||||||
|
|
||||||
|
import (
|
||||||
|
"gcy_hpc_server/internal/server"
|
||||||
|
"gcy_hpc_server/internal/service"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"go.uber.org/zap"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ClusterHandler handles HTTP requests for cluster operations (nodes, partitions, diag).
|
||||||
|
type ClusterHandler struct {
|
||||||
|
clusterSvc *service.ClusterService
|
||||||
|
logger *zap.Logger
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewClusterHandler creates a new ClusterHandler with the given ClusterService.
|
||||||
|
func NewClusterHandler(clusterSvc *service.ClusterService, logger *zap.Logger) *ClusterHandler {
|
||||||
|
return &ClusterHandler{clusterSvc: clusterSvc, logger: logger}
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetNodes handles GET /api/v1/nodes.
|
||||||
|
func (h *ClusterHandler) GetNodes(c *gin.Context) {
|
||||||
|
nodes, err := h.clusterSvc.GetNodes(c.Request.Context())
|
||||||
|
if err != nil {
|
||||||
|
h.logger.Error("handler error", zap.String("method", "GetNodes"), zap.Error(err))
|
||||||
|
server.InternalError(c, err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
server.OK(c, nodes)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetNode handles GET /api/v1/nodes/:name.
|
||||||
|
func (h *ClusterHandler) GetNode(c *gin.Context) {
|
||||||
|
name := c.Param("name")
|
||||||
|
|
||||||
|
resp, err := h.clusterSvc.GetNode(c.Request.Context(), name)
|
||||||
|
if err != nil {
|
||||||
|
h.logger.Error("handler error", zap.String("method", "GetNode"), zap.Error(err))
|
||||||
|
server.InternalError(c, err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if resp == nil {
|
||||||
|
h.logger.Warn("not found", zap.String("method", "GetNode"), zap.String("name", name))
|
||||||
|
server.NotFound(c, "node not found")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
server.OK(c, resp)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetPartitions handles GET /api/v1/partitions.
|
||||||
|
func (h *ClusterHandler) GetPartitions(c *gin.Context) {
|
||||||
|
partitions, err := h.clusterSvc.GetPartitions(c.Request.Context())
|
||||||
|
if err != nil {
|
||||||
|
h.logger.Error("handler error", zap.String("method", "GetPartitions"), zap.Error(err))
|
||||||
|
server.InternalError(c, err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
server.OK(c, partitions)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetPartition handles GET /api/v1/partitions/:name.
|
||||||
|
func (h *ClusterHandler) GetPartition(c *gin.Context) {
|
||||||
|
name := c.Param("name")
|
||||||
|
|
||||||
|
resp, err := h.clusterSvc.GetPartition(c.Request.Context(), name)
|
||||||
|
if err != nil {
|
||||||
|
h.logger.Error("handler error", zap.String("method", "GetPartition"), zap.Error(err))
|
||||||
|
server.InternalError(c, err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if resp == nil {
|
||||||
|
h.logger.Warn("not found", zap.String("method", "GetPartition"), zap.String("name", name))
|
||||||
|
server.NotFound(c, "partition not found")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
server.OK(c, resp)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetDiag handles GET /api/v1/diag.
|
||||||
|
func (h *ClusterHandler) GetDiag(c *gin.Context) {
|
||||||
|
resp, err := h.clusterSvc.GetDiag(c.Request.Context())
|
||||||
|
if err != nil {
|
||||||
|
h.logger.Error("handler error", zap.String("method", "GetDiag"), zap.Error(err))
|
||||||
|
server.InternalError(c, err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
server.OK(c, resp)
|
||||||
|
}
|
||||||
634
internal/handler/cluster_test.go
Normal file
634
internal/handler/cluster_test.go
Normal file
@@ -0,0 +1,634 @@
|
|||||||
|
package handler
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"gcy_hpc_server/internal/service"
|
||||||
|
"gcy_hpc_server/internal/slurm"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"go.uber.org/zap"
|
||||||
|
"go.uber.org/zap/zapcore"
|
||||||
|
"go.uber.org/zap/zaptest/observer"
|
||||||
|
)
|
||||||
|
|
||||||
|
func setupClusterHandler(slurmHandler http.HandlerFunc) (*httptest.Server, *ClusterHandler) {
|
||||||
|
srv := httptest.NewServer(slurmHandler)
|
||||||
|
client, _ := slurm.NewClient(srv.URL, srv.Client())
|
||||||
|
clusterSvc := service.NewClusterService(client, zap.NewNop())
|
||||||
|
return srv, NewClusterHandler(clusterSvc, zap.NewNop())
|
||||||
|
}
|
||||||
|
|
||||||
|
func setupClusterHandlerWithObserver(slurmHandler http.HandlerFunc) (*httptest.Server, *ClusterHandler, *observer.ObservedLogs) {
|
||||||
|
core, recorded := observer.New(zapcore.DebugLevel)
|
||||||
|
l := zap.New(core)
|
||||||
|
srv := httptest.NewServer(slurmHandler)
|
||||||
|
client, _ := slurm.NewClient(srv.URL, srv.Client())
|
||||||
|
clusterSvc := service.NewClusterService(client, l)
|
||||||
|
return srv, NewClusterHandler(clusterSvc, l), recorded
|
||||||
|
}
|
||||||
|
|
||||||
|
func setupClusterRouter(h *ClusterHandler) *gin.Engine {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
r := gin.New()
|
||||||
|
v1 := r.Group("/api/v1")
|
||||||
|
v1.GET("/nodes", h.GetNodes)
|
||||||
|
v1.GET("/nodes/:name", h.GetNode)
|
||||||
|
v1.GET("/partitions", h.GetPartitions)
|
||||||
|
v1.GET("/partitions/:name", h.GetPartition)
|
||||||
|
v1.GET("/diag", h.GetDiag)
|
||||||
|
return r
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetNodes_Success(t *testing.T) {
|
||||||
|
srv, h := setupClusterHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
json.NewEncoder(w).Encode(map[string]interface{}{
|
||||||
|
"nodes": []map[string]interface{}{
|
||||||
|
{"name": "node1", "state": []string{"IDLE"}, "cpus": 64, "real_memory": 128000},
|
||||||
|
},
|
||||||
|
"last_update": map[string]interface{}{},
|
||||||
|
})
|
||||||
|
}))
|
||||||
|
defer srv.Close()
|
||||||
|
|
||||||
|
router := setupClusterRouter(h)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
req, _ := http.NewRequest("GET", "/api/v1/nodes", nil)
|
||||||
|
router.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Fatalf("expected 200, got %d", w.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
var resp map[string]interface{}
|
||||||
|
if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
|
||||||
|
t.Fatalf("failed to parse response: %v", err)
|
||||||
|
}
|
||||||
|
if resp["success"] != true {
|
||||||
|
t.Fatal("expected success=true")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetNode_Success(t *testing.T) {
|
||||||
|
srv, h := setupClusterHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
json.NewEncoder(w).Encode(map[string]interface{}{
|
||||||
|
"nodes": []map[string]interface{}{
|
||||||
|
{"name": "node1", "state": []string{"IDLE"}, "cpus": 64, "real_memory": 128000},
|
||||||
|
},
|
||||||
|
"last_update": map[string]interface{}{},
|
||||||
|
})
|
||||||
|
}))
|
||||||
|
defer srv.Close()
|
||||||
|
|
||||||
|
router := setupClusterRouter(h)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
req, _ := http.NewRequest("GET", "/api/v1/nodes/node1", nil)
|
||||||
|
router.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Fatalf("expected 200, got %d", w.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
var resp map[string]interface{}
|
||||||
|
if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
|
||||||
|
t.Fatalf("failed to parse response: %v", err)
|
||||||
|
}
|
||||||
|
if resp["success"] != true {
|
||||||
|
t.Fatal("expected success=true")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetNode_NotFound(t *testing.T) {
|
||||||
|
srv, h := setupClusterHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
json.NewEncoder(w).Encode(map[string]interface{}{
|
||||||
|
"nodes": []map[string]interface{}{},
|
||||||
|
"last_update": map[string]interface{}{},
|
||||||
|
})
|
||||||
|
}))
|
||||||
|
defer srv.Close()
|
||||||
|
|
||||||
|
router := setupClusterRouter(h)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
req, _ := http.NewRequest("GET", "/api/v1/nodes/nonexistent", nil)
|
||||||
|
router.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
if w.Code != http.StatusNotFound {
|
||||||
|
t.Fatalf("expected 404, got %d", w.Code)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetPartitions_Success(t *testing.T) {
|
||||||
|
srv, h := setupClusterHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
json.NewEncoder(w).Encode(map[string]interface{}{
|
||||||
|
"partitions": []map[string]interface{}{
|
||||||
|
{
|
||||||
|
"name": "normal",
|
||||||
|
"partition": map[string]interface{}{
|
||||||
|
"state": []string{"UP"},
|
||||||
|
},
|
||||||
|
"nodes": map[string]interface{}{
|
||||||
|
"configured": "node[1-10]",
|
||||||
|
"total": int32(10),
|
||||||
|
},
|
||||||
|
"cpus": map[string]interface{}{
|
||||||
|
"total": int32(640),
|
||||||
|
},
|
||||||
|
"maximums": map[string]interface{}{
|
||||||
|
"time": map[string]interface{}{"number": int64(60)},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"last_update": map[string]interface{}{},
|
||||||
|
})
|
||||||
|
}))
|
||||||
|
defer srv.Close()
|
||||||
|
|
||||||
|
router := setupClusterRouter(h)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
req, _ := http.NewRequest("GET", "/api/v1/partitions", nil)
|
||||||
|
router.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Fatalf("expected 200, got %d", w.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
var resp map[string]interface{}
|
||||||
|
if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
|
||||||
|
t.Fatalf("failed to parse response: %v", err)
|
||||||
|
}
|
||||||
|
if resp["success"] != true {
|
||||||
|
t.Fatal("expected success=true")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetPartition_Success(t *testing.T) {
|
||||||
|
srv, h := setupClusterHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
json.NewEncoder(w).Encode(map[string]interface{}{
|
||||||
|
"partitions": []map[string]interface{}{
|
||||||
|
{
|
||||||
|
"name": "normal",
|
||||||
|
"partition": map[string]interface{}{
|
||||||
|
"state": []string{"UP"},
|
||||||
|
},
|
||||||
|
"nodes": map[string]interface{}{
|
||||||
|
"configured": "node[1-10]",
|
||||||
|
"total": int32(10),
|
||||||
|
},
|
||||||
|
"cpus": map[string]interface{}{
|
||||||
|
"total": int32(640),
|
||||||
|
},
|
||||||
|
"maximums": map[string]interface{}{
|
||||||
|
"time": map[string]interface{}{"number": int64(60)},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"last_update": map[string]interface{}{},
|
||||||
|
})
|
||||||
|
}))
|
||||||
|
defer srv.Close()
|
||||||
|
|
||||||
|
router := setupClusterRouter(h)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
req, _ := http.NewRequest("GET", "/api/v1/partitions/normal", nil)
|
||||||
|
router.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Fatalf("expected 200, got %d", w.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
var resp map[string]interface{}
|
||||||
|
if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
|
||||||
|
t.Fatalf("failed to parse response: %v", err)
|
||||||
|
}
|
||||||
|
if resp["success"] != true {
|
||||||
|
t.Fatal("expected success=true")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetPartition_NotFound(t *testing.T) {
|
||||||
|
srv, h := setupClusterHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
json.NewEncoder(w).Encode(map[string]interface{}{
|
||||||
|
"partitions": []map[string]interface{}{},
|
||||||
|
"last_update": map[string]interface{}{},
|
||||||
|
})
|
||||||
|
}))
|
||||||
|
defer srv.Close()
|
||||||
|
|
||||||
|
router := setupClusterRouter(h)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
req, _ := http.NewRequest("GET", "/api/v1/partitions/nonexistent", nil)
|
||||||
|
router.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
if w.Code != http.StatusNotFound {
|
||||||
|
t.Fatalf("expected 404, got %d", w.Code)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetDiag_Success(t *testing.T) {
|
||||||
|
srv, h := setupClusterHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
json.NewEncoder(w).Encode(map[string]interface{}{
|
||||||
|
"statistics": map[string]interface{}{
|
||||||
|
"server_thread_count": 3,
|
||||||
|
"agent_queue_size": 0,
|
||||||
|
"jobs_submitted": 100,
|
||||||
|
"jobs_started": 90,
|
||||||
|
"jobs_completed": 85,
|
||||||
|
"schedule_cycle_last": 10,
|
||||||
|
"schedule_cycle_total": 500,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
}))
|
||||||
|
defer srv.Close()
|
||||||
|
|
||||||
|
router := setupClusterRouter(h)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
req, _ := http.NewRequest("GET", "/api/v1/diag", nil)
|
||||||
|
router.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Fatalf("expected 200, got %d", w.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
var resp map[string]interface{}
|
||||||
|
if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
|
||||||
|
t.Fatalf("failed to parse response: %v", err)
|
||||||
|
}
|
||||||
|
if resp["success"] != true {
|
||||||
|
t.Fatal("expected success=true")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- Logging tests ---
|
||||||
|
|
||||||
|
func TestClusterHandler_GetNodes_InternalError_LogsError(t *testing.T) {
|
||||||
|
srv, h, recorded := setupClusterHandlerWithObserver(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(http.StatusInternalServerError)
|
||||||
|
fmt.Fprint(w, `{"errors":[{"error":"internal"}]}`)
|
||||||
|
}))
|
||||||
|
defer srv.Close()
|
||||||
|
|
||||||
|
router := setupClusterRouter(h)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
req, _ := http.NewRequest("GET", "/api/v1/nodes", nil)
|
||||||
|
router.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
if w.Code != http.StatusInternalServerError {
|
||||||
|
t.Fatalf("expected 500, got %d", w.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
handlerLogs := recorded.FilterMessage("handler error")
|
||||||
|
if handlerLogs.Len() != 1 {
|
||||||
|
t.Fatalf("expected 1 handler error log, got %d", handlerLogs.Len())
|
||||||
|
}
|
||||||
|
entry := handlerLogs.All()[0]
|
||||||
|
if entry.Level != zapcore.ErrorLevel {
|
||||||
|
t.Fatalf("expected Error level, got %v", entry.Level)
|
||||||
|
}
|
||||||
|
assertField(t, entry.Context, "method", "GetNodes")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestClusterHandler_GetNodes_Success_NoLogs(t *testing.T) {
|
||||||
|
srv, h, recorded := setupClusterHandlerWithObserver(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
json.NewEncoder(w).Encode(map[string]interface{}{
|
||||||
|
"nodes": []map[string]interface{}{
|
||||||
|
{"name": "node1"},
|
||||||
|
},
|
||||||
|
"last_update": map[string]interface{}{},
|
||||||
|
})
|
||||||
|
}))
|
||||||
|
defer srv.Close()
|
||||||
|
|
||||||
|
router := setupClusterRouter(h)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
req, _ := http.NewRequest("GET", "/api/v1/nodes", nil)
|
||||||
|
router.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Fatalf("expected 200, got %d", w.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
if recorded.Len() != 0 {
|
||||||
|
t.Fatalf("expected 0 log entries on success, got %d", recorded.Len())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestClusterHandler_GetNode_InternalError_LogsError(t *testing.T) {
|
||||||
|
srv, h, recorded := setupClusterHandlerWithObserver(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(http.StatusInternalServerError)
|
||||||
|
fmt.Fprint(w, `{"errors":[{"error":"internal"}]}`)
|
||||||
|
}))
|
||||||
|
defer srv.Close()
|
||||||
|
|
||||||
|
router := setupClusterRouter(h)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
req, _ := http.NewRequest("GET", "/api/v1/nodes/node1", nil)
|
||||||
|
router.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
if w.Code != http.StatusInternalServerError {
|
||||||
|
t.Fatalf("expected 500, got %d", w.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
handlerLogs := recorded.FilterMessage("handler error")
|
||||||
|
if handlerLogs.Len() != 1 {
|
||||||
|
t.Fatalf("expected 1 handler error log, got %d", handlerLogs.Len())
|
||||||
|
}
|
||||||
|
entry := handlerLogs.All()[0]
|
||||||
|
if entry.Level != zapcore.ErrorLevel {
|
||||||
|
t.Fatalf("expected Error level, got %v", entry.Level)
|
||||||
|
}
|
||||||
|
assertField(t, entry.Context, "method", "GetNode")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestClusterHandler_GetNode_NotFound_LogsWarn(t *testing.T) {
|
||||||
|
srv, h, recorded := setupClusterHandlerWithObserver(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
json.NewEncoder(w).Encode(map[string]interface{}{
|
||||||
|
"nodes": []map[string]interface{}{},
|
||||||
|
"last_update": map[string]interface{}{},
|
||||||
|
})
|
||||||
|
}))
|
||||||
|
defer srv.Close()
|
||||||
|
|
||||||
|
router := setupClusterRouter(h)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
req, _ := http.NewRequest("GET", "/api/v1/nodes/nonexistent", nil)
|
||||||
|
router.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
if w.Code != http.StatusNotFound {
|
||||||
|
t.Fatalf("expected 404, got %d", w.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
if recorded.Len() != 1 {
|
||||||
|
t.Fatalf("expected 1 log entry, got %d", recorded.Len())
|
||||||
|
}
|
||||||
|
entry := recorded.All()[0]
|
||||||
|
if entry.Level != zapcore.WarnLevel {
|
||||||
|
t.Fatalf("expected Warn level, got %v", entry.Level)
|
||||||
|
}
|
||||||
|
if entry.Message != "not found" {
|
||||||
|
t.Fatalf("expected message 'not found', got %q", entry.Message)
|
||||||
|
}
|
||||||
|
assertField(t, entry.Context, "method", "GetNode")
|
||||||
|
assertField(t, entry.Context, "name", "nonexistent")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestClusterHandler_GetNode_Success_NoLogs(t *testing.T) {
|
||||||
|
srv, h, recorded := setupClusterHandlerWithObserver(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
json.NewEncoder(w).Encode(map[string]interface{}{
|
||||||
|
"nodes": []map[string]interface{}{
|
||||||
|
{"name": "node1", "state": []string{"IDLE"}, "cpus": 64, "real_memory": 128000},
|
||||||
|
},
|
||||||
|
"last_update": map[string]interface{}{},
|
||||||
|
})
|
||||||
|
}))
|
||||||
|
defer srv.Close()
|
||||||
|
|
||||||
|
router := setupClusterRouter(h)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
req, _ := http.NewRequest("GET", "/api/v1/nodes/node1", nil)
|
||||||
|
router.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Fatalf("expected 200, got %d", w.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
if recorded.Len() != 0 {
|
||||||
|
t.Fatalf("expected 0 log entries on success, got %d", recorded.Len())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestClusterHandler_GetPartitions_InternalError_LogsError(t *testing.T) {
|
||||||
|
srv, h, recorded := setupClusterHandlerWithObserver(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(http.StatusInternalServerError)
|
||||||
|
fmt.Fprint(w, `{"errors":[{"error":"internal"}]}`)
|
||||||
|
}))
|
||||||
|
defer srv.Close()
|
||||||
|
|
||||||
|
router := setupClusterRouter(h)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
req, _ := http.NewRequest("GET", "/api/v1/partitions", nil)
|
||||||
|
router.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
if w.Code != http.StatusInternalServerError {
|
||||||
|
t.Fatalf("expected 500, got %d", w.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
handlerLogs := recorded.FilterMessage("handler error")
|
||||||
|
if handlerLogs.Len() != 1 {
|
||||||
|
t.Fatalf("expected 1 handler error log, got %d", handlerLogs.Len())
|
||||||
|
}
|
||||||
|
entry := handlerLogs.All()[0]
|
||||||
|
if entry.Level != zapcore.ErrorLevel {
|
||||||
|
t.Fatalf("expected Error level, got %v", entry.Level)
|
||||||
|
}
|
||||||
|
assertField(t, entry.Context, "method", "GetPartitions")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestClusterHandler_GetPartitions_Success_NoLogs(t *testing.T) {
|
||||||
|
srv, h, recorded := setupClusterHandlerWithObserver(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
json.NewEncoder(w).Encode(map[string]interface{}{
|
||||||
|
"partitions": []map[string]interface{}{
|
||||||
|
{"name": "normal"},
|
||||||
|
},
|
||||||
|
"last_update": map[string]interface{}{},
|
||||||
|
})
|
||||||
|
}))
|
||||||
|
defer srv.Close()
|
||||||
|
|
||||||
|
router := setupClusterRouter(h)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
req, _ := http.NewRequest("GET", "/api/v1/partitions", nil)
|
||||||
|
router.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Fatalf("expected 200, got %d", w.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
if recorded.Len() != 0 {
|
||||||
|
t.Fatalf("expected 0 log entries on success, got %d", recorded.Len())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestClusterHandler_GetPartition_InternalError_LogsError(t *testing.T) {
|
||||||
|
srv, h, recorded := setupClusterHandlerWithObserver(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(http.StatusInternalServerError)
|
||||||
|
fmt.Fprint(w, `{"errors":[{"error":"internal"}]}`)
|
||||||
|
}))
|
||||||
|
defer srv.Close()
|
||||||
|
|
||||||
|
router := setupClusterRouter(h)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
req, _ := http.NewRequest("GET", "/api/v1/partitions/normal", nil)
|
||||||
|
router.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
if w.Code != http.StatusInternalServerError {
|
||||||
|
t.Fatalf("expected 500, got %d", w.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
handlerLogs := recorded.FilterMessage("handler error")
|
||||||
|
if handlerLogs.Len() != 1 {
|
||||||
|
t.Fatalf("expected 1 handler error log, got %d", handlerLogs.Len())
|
||||||
|
}
|
||||||
|
entry := handlerLogs.All()[0]
|
||||||
|
if entry.Level != zapcore.ErrorLevel {
|
||||||
|
t.Fatalf("expected Error level, got %v", entry.Level)
|
||||||
|
}
|
||||||
|
assertField(t, entry.Context, "method", "GetPartition")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestClusterHandler_GetPartition_NotFound_LogsWarn(t *testing.T) {
|
||||||
|
srv, h, recorded := setupClusterHandlerWithObserver(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
json.NewEncoder(w).Encode(map[string]interface{}{
|
||||||
|
"partitions": []map[string]interface{}{},
|
||||||
|
"last_update": map[string]interface{}{},
|
||||||
|
})
|
||||||
|
}))
|
||||||
|
defer srv.Close()
|
||||||
|
|
||||||
|
router := setupClusterRouter(h)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
req, _ := http.NewRequest("GET", "/api/v1/partitions/nonexistent", nil)
|
||||||
|
router.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
if w.Code != http.StatusNotFound {
|
||||||
|
t.Fatalf("expected 404, got %d", w.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
if recorded.Len() != 1 {
|
||||||
|
t.Fatalf("expected 1 log entry, got %d", recorded.Len())
|
||||||
|
}
|
||||||
|
entry := recorded.All()[0]
|
||||||
|
if entry.Level != zapcore.WarnLevel {
|
||||||
|
t.Fatalf("expected Warn level, got %v", entry.Level)
|
||||||
|
}
|
||||||
|
if entry.Message != "not found" {
|
||||||
|
t.Fatalf("expected message 'not found', got %q", entry.Message)
|
||||||
|
}
|
||||||
|
assertField(t, entry.Context, "method", "GetPartition")
|
||||||
|
assertField(t, entry.Context, "name", "nonexistent")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestClusterHandler_GetPartition_Success_NoLogs(t *testing.T) {
|
||||||
|
srv, h, recorded := setupClusterHandlerWithObserver(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
json.NewEncoder(w).Encode(map[string]interface{}{
|
||||||
|
"partitions": []map[string]interface{}{
|
||||||
|
{
|
||||||
|
"name": "normal",
|
||||||
|
"partition": map[string]interface{}{
|
||||||
|
"state": []string{"UP"},
|
||||||
|
},
|
||||||
|
"nodes": map[string]interface{}{
|
||||||
|
"configured": "node[1-10]",
|
||||||
|
"total": int32(10),
|
||||||
|
},
|
||||||
|
"cpus": map[string]interface{}{
|
||||||
|
"total": int32(640),
|
||||||
|
},
|
||||||
|
"maximums": map[string]interface{}{
|
||||||
|
"time": map[string]interface{}{"number": int64(60)},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"last_update": map[string]interface{}{},
|
||||||
|
})
|
||||||
|
}))
|
||||||
|
defer srv.Close()
|
||||||
|
|
||||||
|
router := setupClusterRouter(h)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
req, _ := http.NewRequest("GET", "/api/v1/partitions/normal", nil)
|
||||||
|
router.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Fatalf("expected 200, got %d", w.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
if recorded.Len() != 0 {
|
||||||
|
t.Fatalf("expected 0 log entries on success, got %d", recorded.Len())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestClusterHandler_GetDiag_InternalError_LogsError(t *testing.T) {
|
||||||
|
srv, h, recorded := setupClusterHandlerWithObserver(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(http.StatusInternalServerError)
|
||||||
|
fmt.Fprint(w, `{"errors":[{"error":"internal"}]}`)
|
||||||
|
}))
|
||||||
|
defer srv.Close()
|
||||||
|
|
||||||
|
router := setupClusterRouter(h)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
req, _ := http.NewRequest("GET", "/api/v1/diag", nil)
|
||||||
|
router.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
if w.Code != http.StatusInternalServerError {
|
||||||
|
t.Fatalf("expected 500, got %d", w.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
handlerLogs := recorded.FilterMessage("handler error")
|
||||||
|
if handlerLogs.Len() != 1 {
|
||||||
|
t.Fatalf("expected 1 handler error log, got %d", handlerLogs.Len())
|
||||||
|
}
|
||||||
|
entry := handlerLogs.All()[0]
|
||||||
|
if entry.Level != zapcore.ErrorLevel {
|
||||||
|
t.Fatalf("expected Error level, got %v", entry.Level)
|
||||||
|
}
|
||||||
|
assertField(t, entry.Context, "method", "GetDiag")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestClusterHandler_GetDiag_Success_NoLogs(t *testing.T) {
|
||||||
|
srv, h, recorded := setupClusterHandlerWithObserver(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
json.NewEncoder(w).Encode(map[string]interface{}{
|
||||||
|
"statistics": map[string]interface{}{
|
||||||
|
"server_thread_count": 3,
|
||||||
|
"agent_queue_size": 0,
|
||||||
|
"jobs_submitted": 100,
|
||||||
|
"jobs_started": 90,
|
||||||
|
"jobs_completed": 85,
|
||||||
|
"schedule_cycle_last": 10,
|
||||||
|
"schedule_cycle_total": 500,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
}))
|
||||||
|
defer srv.Close()
|
||||||
|
|
||||||
|
router := setupClusterRouter(h)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
req, _ := http.NewRequest("GET", "/api/v1/diag", nil)
|
||||||
|
router.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Fatalf("expected 200, got %d", w.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
if recorded.Len() != 0 {
|
||||||
|
t.Fatalf("expected 0 log entries on success, got %d", recorded.Len())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// assertField checks that a zap Field slice contains a string field with the given key and value.
|
||||||
|
func assertField(t *testing.T, fields []zapcore.Field, key, value string) {
|
||||||
|
t.Helper()
|
||||||
|
for _, f := range fields {
|
||||||
|
if f.Key == key && f.String == value {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
t.Fatalf("expected field %q=%q in context, got %v", key, value, fields)
|
||||||
|
}
|
||||||
118
internal/handler/job.go
Normal file
118
internal/handler/job.go
Normal file
@@ -0,0 +1,118 @@
|
|||||||
|
package handler
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
|
||||||
|
"gcy_hpc_server/internal/model"
|
||||||
|
"gcy_hpc_server/internal/server"
|
||||||
|
"gcy_hpc_server/internal/service"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"go.uber.org/zap"
|
||||||
|
)
|
||||||
|
|
||||||
|
// JobHandler handles HTTP requests for job operations.
|
||||||
|
type JobHandler struct {
|
||||||
|
jobSvc *service.JobService
|
||||||
|
logger *zap.Logger
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewJobHandler creates a new JobHandler with the given JobService.
|
||||||
|
func NewJobHandler(jobSvc *service.JobService, logger *zap.Logger) *JobHandler {
|
||||||
|
return &JobHandler{jobSvc: jobSvc, logger: logger}
|
||||||
|
}
|
||||||
|
|
||||||
|
// SubmitJob handles POST /api/v1/jobs/submit.
|
||||||
|
func (h *JobHandler) SubmitJob(c *gin.Context) {
|
||||||
|
var req model.SubmitJobRequest
|
||||||
|
if err := c.ShouldBindJSON(&req); err != nil {
|
||||||
|
h.logger.Warn("bad request", zap.String("method", "SubmitJob"), zap.String("error", "invalid request body"))
|
||||||
|
server.BadRequest(c, "invalid request body")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if req.Script == "" {
|
||||||
|
h.logger.Warn("bad request", zap.String("method", "SubmitJob"), zap.String("error", "script is required"))
|
||||||
|
server.BadRequest(c, "script is required")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := h.jobSvc.SubmitJob(c.Request.Context(), &req)
|
||||||
|
if err != nil {
|
||||||
|
h.logger.Error("handler error", zap.String("method", "SubmitJob"), zap.Int("status", http.StatusBadGateway), zap.Error(err))
|
||||||
|
server.ErrorWithStatus(c, http.StatusBadGateway, "slurm error: "+err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
server.Created(c, resp)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetJobs handles GET /api/v1/jobs.
|
||||||
|
func (h *JobHandler) GetJobs(c *gin.Context) {
|
||||||
|
jobs, err := h.jobSvc.GetJobs(c.Request.Context())
|
||||||
|
if err != nil {
|
||||||
|
h.logger.Error("handler error", zap.String("method", "GetJobs"), zap.Int("status", http.StatusInternalServerError), zap.Error(err))
|
||||||
|
server.InternalError(c, err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
server.OK(c, jobs)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetJob handles GET /api/v1/jobs/:id.
|
||||||
|
func (h *JobHandler) GetJob(c *gin.Context) {
|
||||||
|
jobID := c.Param("id")
|
||||||
|
|
||||||
|
resp, err := h.jobSvc.GetJob(c.Request.Context(), jobID)
|
||||||
|
if err != nil {
|
||||||
|
h.logger.Error("handler error", zap.String("method", "GetJob"), zap.Int("status", http.StatusInternalServerError), zap.Error(err))
|
||||||
|
server.InternalError(c, err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if resp == nil {
|
||||||
|
h.logger.Warn("bad request", zap.String("method", "GetJob"), zap.String("error", "job not found"))
|
||||||
|
server.NotFound(c, "job not found")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
server.OK(c, resp)
|
||||||
|
}
|
||||||
|
|
||||||
|
// CancelJob handles DELETE /api/v1/jobs/:id.
|
||||||
|
func (h *JobHandler) CancelJob(c *gin.Context) {
|
||||||
|
jobID := c.Param("id")
|
||||||
|
|
||||||
|
err := h.jobSvc.CancelJob(c.Request.Context(), jobID)
|
||||||
|
if err != nil {
|
||||||
|
h.logger.Error("handler error", zap.String("method", "CancelJob"), zap.Int("status", http.StatusBadGateway), zap.Error(err))
|
||||||
|
server.ErrorWithStatus(c, http.StatusBadGateway, err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
server.OK(c, gin.H{"message": "job cancelled"})
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetJobHistory handles GET /api/v1/jobs/history.
|
||||||
|
func (h *JobHandler) GetJobHistory(c *gin.Context) {
|
||||||
|
var query model.JobHistoryQuery
|
||||||
|
if err := c.ShouldBindQuery(&query); err != nil {
|
||||||
|
h.logger.Warn("bad request", zap.String("method", "GetJobHistory"), zap.String("error", "invalid query params"))
|
||||||
|
server.BadRequest(c, "invalid query params")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if query.Page < 1 {
|
||||||
|
query.Page = 1
|
||||||
|
}
|
||||||
|
if query.PageSize < 1 {
|
||||||
|
query.PageSize = 20
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := h.jobSvc.GetJobHistory(c.Request.Context(), &query)
|
||||||
|
if err != nil {
|
||||||
|
h.logger.Error("handler error", zap.String("method", "GetJobHistory"), zap.Int("status", http.StatusInternalServerError), zap.Error(err))
|
||||||
|
server.InternalError(c, err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
server.OK(c, resp)
|
||||||
|
}
|
||||||
821
internal/handler/job_test.go
Normal file
821
internal/handler/job_test.go
Normal file
@@ -0,0 +1,821 @@
|
|||||||
|
package handler
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"gcy_hpc_server/internal/service"
|
||||||
|
"gcy_hpc_server/internal/slurm"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"go.uber.org/zap"
|
||||||
|
"go.uber.org/zap/zapcore"
|
||||||
|
"go.uber.org/zap/zaptest/observer"
|
||||||
|
)
|
||||||
|
|
||||||
|
func setupJobRouter(h *JobHandler) *gin.Engine {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
r := gin.New()
|
||||||
|
v1 := r.Group("/api/v1")
|
||||||
|
jobs := v1.Group("/jobs")
|
||||||
|
{
|
||||||
|
jobs.POST("/submit", h.SubmitJob)
|
||||||
|
jobs.GET("", h.GetJobs)
|
||||||
|
jobs.GET("/history", h.GetJobHistory)
|
||||||
|
jobs.GET("/:id", h.GetJob)
|
||||||
|
jobs.DELETE("/:id", h.CancelJob)
|
||||||
|
}
|
||||||
|
return r
|
||||||
|
}
|
||||||
|
|
||||||
|
func setupJobHandler(mux *http.ServeMux) (*httptest.Server, *JobHandler) {
|
||||||
|
srv := httptest.NewServer(mux)
|
||||||
|
client, _ := slurm.NewClient(srv.URL, srv.Client())
|
||||||
|
jobSvc := service.NewJobService(client, zap.NewNop())
|
||||||
|
return srv, NewJobHandler(jobSvc, zap.NewNop())
|
||||||
|
}
|
||||||
|
|
||||||
|
func setupJobHandlerWithObserver(mux *http.ServeMux) (*httptest.Server, *JobHandler, *observer.ObservedLogs) {
|
||||||
|
core, recorded := observer.New(zapcore.DebugLevel)
|
||||||
|
l := zap.New(core)
|
||||||
|
srv := httptest.NewServer(mux)
|
||||||
|
client, _ := slurm.NewClient(srv.URL, srv.Client())
|
||||||
|
jobSvc := service.NewJobService(client, l)
|
||||||
|
return srv, NewJobHandler(jobSvc, l), recorded
|
||||||
|
}
|
||||||
|
|
||||||
|
func handlerLogs(logs *observer.ObservedLogs) []observer.LoggedEntry {
|
||||||
|
var handler []observer.LoggedEntry
|
||||||
|
for _, e := range logs.All() {
|
||||||
|
for _, f := range e.Context {
|
||||||
|
if f.Key == "method" {
|
||||||
|
handler = append(handler, e)
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return handler
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSubmitJob_Success(t *testing.T) {
|
||||||
|
mux := http.NewServeMux()
|
||||||
|
mux.HandleFunc("/slurm/v0.0.40/job/submit", func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
json.NewEncoder(w).Encode(slurm.OpenapiJobSubmitResponse{
|
||||||
|
Result: &slurm.JobSubmitResponseMsg{JobID: slurm.Ptr(int32(123))},
|
||||||
|
})
|
||||||
|
})
|
||||||
|
srv, handler := setupJobHandler(mux)
|
||||||
|
defer srv.Close()
|
||||||
|
|
||||||
|
router := setupJobRouter(handler)
|
||||||
|
|
||||||
|
body := `{"script":"#!/bin/bash\necho hello"}`
|
||||||
|
req := httptest.NewRequest(http.MethodPost, "/api/v1/jobs/submit", bytes.NewBufferString(body))
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
|
router.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
if w.Code != http.StatusCreated {
|
||||||
|
t.Fatalf("expected 201, got %d: %s", w.Code, w.Body.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
var resp map[string]interface{}
|
||||||
|
json.Unmarshal(w.Body.Bytes(), &resp)
|
||||||
|
if !resp["success"].(bool) {
|
||||||
|
t.Fatal("expected success=true")
|
||||||
|
}
|
||||||
|
data := resp["data"].(map[string]interface{})
|
||||||
|
if int(data["job_id"].(float64)) != 123 {
|
||||||
|
t.Errorf("expected job_id=123, got %v", data["job_id"])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSubmitJob_MissingScript(t *testing.T) {
|
||||||
|
mux := http.NewServeMux()
|
||||||
|
srv, handler := setupJobHandler(mux)
|
||||||
|
defer srv.Close()
|
||||||
|
|
||||||
|
router := setupJobRouter(handler)
|
||||||
|
|
||||||
|
body := `{"partition":"normal"}`
|
||||||
|
req := httptest.NewRequest(http.MethodPost, "/api/v1/jobs/submit", bytes.NewBufferString(body))
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
|
router.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
if w.Code != http.StatusBadRequest {
|
||||||
|
t.Fatalf("expected 400, got %d: %s", w.Code, w.Body.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
var resp map[string]interface{}
|
||||||
|
json.Unmarshal(w.Body.Bytes(), &resp)
|
||||||
|
if resp["success"].(bool) {
|
||||||
|
t.Fatal("expected success=false")
|
||||||
|
}
|
||||||
|
if resp["error"] != "invalid request body" && resp["error"] != "script is required" {
|
||||||
|
t.Errorf("expected validation error, got %v", resp["error"])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSubmitJob_EmptyScript(t *testing.T) {
|
||||||
|
mux := http.NewServeMux()
|
||||||
|
srv, handler := setupJobHandler(mux)
|
||||||
|
defer srv.Close()
|
||||||
|
|
||||||
|
router := setupJobRouter(handler)
|
||||||
|
|
||||||
|
body := `{"script":""}`
|
||||||
|
req := httptest.NewRequest(http.MethodPost, "/api/v1/jobs/submit", bytes.NewBufferString(body))
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
|
router.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
if w.Code != http.StatusBadRequest {
|
||||||
|
t.Fatalf("expected 400, got %d: %s", w.Code, w.Body.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
var resp map[string]interface{}
|
||||||
|
json.Unmarshal(w.Body.Bytes(), &resp)
|
||||||
|
if resp["success"].(bool) {
|
||||||
|
t.Fatal("expected success=false")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSubmitJob_SlurmError(t *testing.T) {
|
||||||
|
mux := http.NewServeMux()
|
||||||
|
mux.HandleFunc("/slurm/v0.0.40/job/submit", func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(http.StatusInternalServerError)
|
||||||
|
fmt.Fprint(w, `{"errors":[{"error":"internal error"}]}`)
|
||||||
|
})
|
||||||
|
srv, handler := setupJobHandler(mux)
|
||||||
|
defer srv.Close()
|
||||||
|
|
||||||
|
router := setupJobRouter(handler)
|
||||||
|
|
||||||
|
body := `{"script":"#!/bin/bash\necho hello"}`
|
||||||
|
req := httptest.NewRequest(http.MethodPost, "/api/v1/jobs/submit", bytes.NewBufferString(body))
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
|
router.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
if w.Code != http.StatusBadGateway {
|
||||||
|
t.Fatalf("expected 502, got %d: %s", w.Code, w.Body.String())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetJobs_Success(t *testing.T) {
|
||||||
|
mux := http.NewServeMux()
|
||||||
|
mux.HandleFunc("/slurm/v0.0.40/jobs", func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
json.NewEncoder(w).Encode(slurm.OpenapiJobInfoResp{
|
||||||
|
Jobs: []slurm.JobInfo{
|
||||||
|
{JobID: slurm.Ptr(int32(1)), Name: slurm.Ptr("job1")},
|
||||||
|
{JobID: slurm.Ptr(int32(2)), Name: slurm.Ptr("job2")},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
})
|
||||||
|
srv, handler := setupJobHandler(mux)
|
||||||
|
defer srv.Close()
|
||||||
|
|
||||||
|
router := setupJobRouter(handler)
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/api/v1/jobs", nil)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
|
router.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Fatalf("expected 200, got %d: %s", w.Code, w.Body.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
var resp map[string]interface{}
|
||||||
|
json.Unmarshal(w.Body.Bytes(), &resp)
|
||||||
|
if !resp["success"].(bool) {
|
||||||
|
t.Fatal("expected success=true")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetJob_Success(t *testing.T) {
|
||||||
|
mux := http.NewServeMux()
|
||||||
|
mux.HandleFunc("/slurm/v0.0.40/job/42", func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
json.NewEncoder(w).Encode(slurm.OpenapiJobInfoResp{
|
||||||
|
Jobs: []slurm.JobInfo{
|
||||||
|
{JobID: slurm.Ptr(int32(42)), Name: slurm.Ptr("test-job")},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
})
|
||||||
|
srv, handler := setupJobHandler(mux)
|
||||||
|
defer srv.Close()
|
||||||
|
|
||||||
|
router := setupJobRouter(handler)
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/api/v1/jobs/42", nil)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
|
router.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Fatalf("expected 200, got %d: %s", w.Code, w.Body.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
var resp map[string]interface{}
|
||||||
|
json.Unmarshal(w.Body.Bytes(), &resp)
|
||||||
|
data := resp["data"].(map[string]interface{})
|
||||||
|
if int(data["job_id"].(float64)) != 42 {
|
||||||
|
t.Errorf("expected job_id=42, got %v", data["job_id"])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetJob_NotFound(t *testing.T) {
|
||||||
|
mux := http.NewServeMux()
|
||||||
|
mux.HandleFunc("/slurm/v0.0.40/job/999", func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
json.NewEncoder(w).Encode(slurm.OpenapiJobInfoResp{
|
||||||
|
Jobs: []slurm.JobInfo{},
|
||||||
|
})
|
||||||
|
})
|
||||||
|
srv, handler := setupJobHandler(mux)
|
||||||
|
defer srv.Close()
|
||||||
|
|
||||||
|
router := setupJobRouter(handler)
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/api/v1/jobs/999", nil)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
|
router.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
if w.Code != http.StatusNotFound {
|
||||||
|
t.Fatalf("expected 404, got %d: %s", w.Code, w.Body.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
var resp map[string]interface{}
|
||||||
|
json.Unmarshal(w.Body.Bytes(), &resp)
|
||||||
|
if resp["error"].(string) != "job not found" {
|
||||||
|
t.Errorf("expected 'job not found' error, got %v", resp["error"])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCancelJob_Success(t *testing.T) {
|
||||||
|
mux := http.NewServeMux()
|
||||||
|
mux.HandleFunc("/slurm/v0.0.40/job/42", func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
json.NewEncoder(w).Encode(slurm.OpenapiResp{})
|
||||||
|
})
|
||||||
|
srv, handler := setupJobHandler(mux)
|
||||||
|
defer srv.Close()
|
||||||
|
|
||||||
|
router := setupJobRouter(handler)
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodDelete, "/api/v1/jobs/42", nil)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
|
router.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Fatalf("expected 200, got %d: %s", w.Code, w.Body.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
var resp map[string]interface{}
|
||||||
|
json.Unmarshal(w.Body.Bytes(), &resp)
|
||||||
|
if !resp["success"].(bool) {
|
||||||
|
t.Fatal("expected success=true")
|
||||||
|
}
|
||||||
|
data := resp["data"].(map[string]interface{})
|
||||||
|
if data["message"].(string) != "job cancelled" {
|
||||||
|
t.Errorf("expected 'job cancelled', got %v", data["message"])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetJobHistory_Success(t *testing.T) {
|
||||||
|
mux := http.NewServeMux()
|
||||||
|
mux.HandleFunc("/slurmdb/v0.0.40/jobs", func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
json.NewEncoder(w).Encode(slurm.OpenapiSlurmdbdJobsResp{
|
||||||
|
Jobs: []slurm.Job{
|
||||||
|
{JobID: slurm.Ptr(int32(1)), Name: slurm.Ptr("hist1")},
|
||||||
|
{JobID: slurm.Ptr(int32(2)), Name: slurm.Ptr("hist2")},
|
||||||
|
{JobID: slurm.Ptr(int32(3)), Name: slurm.Ptr("hist3")},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
})
|
||||||
|
srv, handler := setupJobHandler(mux)
|
||||||
|
defer srv.Close()
|
||||||
|
|
||||||
|
router := setupJobRouter(handler)
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/api/v1/jobs/history?page=1&page_size=2", nil)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
|
router.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Fatalf("expected 200, got %d: %s", w.Code, w.Body.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
var resp map[string]interface{}
|
||||||
|
json.Unmarshal(w.Body.Bytes(), &resp)
|
||||||
|
data := resp["data"].(map[string]interface{})
|
||||||
|
jobs := data["jobs"].([]interface{})
|
||||||
|
if len(jobs) != 2 {
|
||||||
|
t.Fatalf("expected 2 jobs on page 1, got %d", len(jobs))
|
||||||
|
}
|
||||||
|
if int(data["total"].(float64)) != 3 {
|
||||||
|
t.Errorf("expected total=3, got %v", data["total"])
|
||||||
|
}
|
||||||
|
if int(data["page"].(float64)) != 1 {
|
||||||
|
t.Errorf("expected page=1, got %v", data["page"])
|
||||||
|
}
|
||||||
|
if int(data["page_size"].(float64)) != 2 {
|
||||||
|
t.Errorf("expected page_size=2, got %v", data["page_size"])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetJobHistory_DefaultPagination(t *testing.T) {
|
||||||
|
mux := http.NewServeMux()
|
||||||
|
mux.HandleFunc("/slurmdb/v0.0.40/jobs", func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
json.NewEncoder(w).Encode(slurm.OpenapiSlurmdbdJobsResp{
|
||||||
|
Jobs: []slurm.Job{
|
||||||
|
{JobID: slurm.Ptr(int32(1)), Name: slurm.Ptr("h1")},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
})
|
||||||
|
srv, handler := setupJobHandler(mux)
|
||||||
|
defer srv.Close()
|
||||||
|
|
||||||
|
router := setupJobRouter(handler)
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/api/v1/jobs/history", nil)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
|
router.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Fatalf("expected 200, got %d: %s", w.Code, w.Body.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
var resp map[string]interface{}
|
||||||
|
json.Unmarshal(w.Body.Bytes(), &resp)
|
||||||
|
data := resp["data"].(map[string]interface{})
|
||||||
|
if int(data["page"].(float64)) != 1 {
|
||||||
|
t.Errorf("expected default page=1, got %v", data["page"])
|
||||||
|
}
|
||||||
|
if int(data["page_size"].(float64)) != 20 {
|
||||||
|
t.Errorf("expected default page_size=20, got %v", data["page_size"])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSubmitJob_InvalidBody(t *testing.T) {
|
||||||
|
mux := http.NewServeMux()
|
||||||
|
srv, handler := setupJobHandler(mux)
|
||||||
|
defer srv.Close()
|
||||||
|
|
||||||
|
router := setupJobRouter(handler)
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodPost, "/api/v1/jobs/submit", bytes.NewBufferString(`not json`))
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
|
router.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
if w.Code != http.StatusBadRequest {
|
||||||
|
t.Fatalf("expected 400, got %d: %s", w.Code, w.Body.String())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- Logging verification tests ---
|
||||||
|
|
||||||
|
func TestSubmitJob_InvalidBody_LogsWarn(t *testing.T) {
|
||||||
|
mux := http.NewServeMux()
|
||||||
|
srv, handler, logs := setupJobHandlerWithObserver(mux)
|
||||||
|
defer srv.Close()
|
||||||
|
|
||||||
|
router := setupJobRouter(handler)
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodPost, "/api/v1/jobs/submit", bytes.NewBufferString(`not json`))
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
|
router.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
if w.Code != http.StatusBadRequest {
|
||||||
|
t.Fatalf("expected 400, got %d", w.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
hLogs := handlerLogs(logs)
|
||||||
|
if len(hLogs) != 1 {
|
||||||
|
t.Fatalf("expected 1 handler log entry, got %d", len(hLogs))
|
||||||
|
}
|
||||||
|
entry := hLogs[0]
|
||||||
|
if entry.Level != zapcore.WarnLevel {
|
||||||
|
t.Errorf("expected Warn level, got %v", entry.Level)
|
||||||
|
}
|
||||||
|
if entry.Context[0].Key != "method" || entry.Context[0].String != "SubmitJob" {
|
||||||
|
t.Errorf("expected method=SubmitJob, got %v", entry.Context[0])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSubmitJob_EmptyScript_LogsWarn(t *testing.T) {
|
||||||
|
mux := http.NewServeMux()
|
||||||
|
srv, handler, logs := setupJobHandlerWithObserver(mux)
|
||||||
|
defer srv.Close()
|
||||||
|
|
||||||
|
router := setupJobRouter(handler)
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodPost, "/api/v1/jobs/submit", bytes.NewBufferString(`{"script":""}`))
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
|
router.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
if w.Code != http.StatusBadRequest {
|
||||||
|
t.Fatalf("expected 400, got %d", w.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
hLogs := handlerLogs(logs)
|
||||||
|
if len(hLogs) != 1 {
|
||||||
|
t.Fatalf("expected 1 handler log entry, got %d", len(hLogs))
|
||||||
|
}
|
||||||
|
entry := hLogs[0]
|
||||||
|
if entry.Level != zapcore.WarnLevel {
|
||||||
|
t.Errorf("expected Warn level, got %v", entry.Level)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSubmitJob_SlurmError_LogsError(t *testing.T) {
|
||||||
|
mux := http.NewServeMux()
|
||||||
|
mux.HandleFunc("/slurm/v0.0.40/job/submit", func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(http.StatusInternalServerError)
|
||||||
|
fmt.Fprint(w, `{"errors":[{"error":"internal error"}]}`)
|
||||||
|
})
|
||||||
|
srv, handler, logs := setupJobHandlerWithObserver(mux)
|
||||||
|
defer srv.Close()
|
||||||
|
|
||||||
|
router := setupJobRouter(handler)
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodPost, "/api/v1/jobs/submit", bytes.NewBufferString(`{"script":"#!/bin/bash\necho hello"}`))
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
|
router.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
if w.Code != http.StatusBadGateway {
|
||||||
|
t.Fatalf("expected 502, got %d", w.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
hLogs := handlerLogs(logs)
|
||||||
|
if len(hLogs) != 1 {
|
||||||
|
t.Fatalf("expected 1 handler log entry, got %d", len(hLogs))
|
||||||
|
}
|
||||||
|
entry := hLogs[0]
|
||||||
|
if entry.Level != zapcore.ErrorLevel {
|
||||||
|
t.Errorf("expected Error level, got %v", entry.Level)
|
||||||
|
}
|
||||||
|
foundMethod := false
|
||||||
|
foundStatus := false
|
||||||
|
for _, f := range entry.Context {
|
||||||
|
if f.Key == "method" && f.String == "SubmitJob" {
|
||||||
|
foundMethod = true
|
||||||
|
}
|
||||||
|
if f.Key == "status" && f.Integer == http.StatusBadGateway {
|
||||||
|
foundStatus = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !foundMethod {
|
||||||
|
t.Error("expected method=SubmitJob in log fields")
|
||||||
|
}
|
||||||
|
if !foundStatus {
|
||||||
|
t.Error("expected status=502 in log fields")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSubmitJob_Success_NoHandlerLogs(t *testing.T) {
|
||||||
|
mux := http.NewServeMux()
|
||||||
|
mux.HandleFunc("/slurm/v0.0.40/job/submit", func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
json.NewEncoder(w).Encode(slurm.OpenapiJobSubmitResponse{
|
||||||
|
Result: &slurm.JobSubmitResponseMsg{JobID: slurm.Ptr(int32(123))},
|
||||||
|
})
|
||||||
|
})
|
||||||
|
srv, handler, logs := setupJobHandlerWithObserver(mux)
|
||||||
|
defer srv.Close()
|
||||||
|
|
||||||
|
router := setupJobRouter(handler)
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodPost, "/api/v1/jobs/submit", bytes.NewBufferString(`{"script":"#!/bin/bash\necho hello"}`))
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
|
router.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
if w.Code != http.StatusCreated {
|
||||||
|
t.Fatalf("expected 201, got %d", w.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
hLogs := handlerLogs(logs)
|
||||||
|
if len(hLogs) != 0 {
|
||||||
|
t.Errorf("expected no handler log entries on success, got %d", len(hLogs))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetJobs_Error_LogsError(t *testing.T) {
|
||||||
|
mux := http.NewServeMux()
|
||||||
|
mux.HandleFunc("/slurm/v0.0.40/jobs", func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(http.StatusInternalServerError)
|
||||||
|
})
|
||||||
|
srv, handler, logs := setupJobHandlerWithObserver(mux)
|
||||||
|
defer srv.Close()
|
||||||
|
|
||||||
|
router := setupJobRouter(handler)
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/api/v1/jobs", nil)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
|
router.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
if w.Code != http.StatusInternalServerError {
|
||||||
|
t.Fatalf("expected 500, got %d", w.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
hLogs := handlerLogs(logs)
|
||||||
|
if len(hLogs) != 1 {
|
||||||
|
t.Fatalf("expected 1 handler log entry, got %d", len(hLogs))
|
||||||
|
}
|
||||||
|
entry := hLogs[0]
|
||||||
|
if entry.Level != zapcore.ErrorLevel {
|
||||||
|
t.Errorf("expected Error level, got %v", entry.Level)
|
||||||
|
}
|
||||||
|
foundMethod := false
|
||||||
|
for _, f := range entry.Context {
|
||||||
|
if f.Key == "method" && f.String == "GetJobs" {
|
||||||
|
foundMethod = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !foundMethod {
|
||||||
|
t.Error("expected method=GetJobs in log fields")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetJob_NotFound_LogsWarn(t *testing.T) {
|
||||||
|
mux := http.NewServeMux()
|
||||||
|
mux.HandleFunc("/slurm/v0.0.40/job/999", func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
json.NewEncoder(w).Encode(slurm.OpenapiJobInfoResp{Jobs: []slurm.JobInfo{}})
|
||||||
|
})
|
||||||
|
srv, handler, logs := setupJobHandlerWithObserver(mux)
|
||||||
|
defer srv.Close()
|
||||||
|
|
||||||
|
router := setupJobRouter(handler)
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/api/v1/jobs/999", nil)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
|
router.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
if w.Code != http.StatusNotFound {
|
||||||
|
t.Fatalf("expected 404, got %d", w.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
hLogs := handlerLogs(logs)
|
||||||
|
if len(hLogs) != 1 {
|
||||||
|
t.Fatalf("expected 1 handler log entry, got %d", len(hLogs))
|
||||||
|
}
|
||||||
|
entry := hLogs[0]
|
||||||
|
if entry.Level != zapcore.WarnLevel {
|
||||||
|
t.Errorf("expected Warn level, got %v", entry.Level)
|
||||||
|
}
|
||||||
|
foundMethod := false
|
||||||
|
for _, f := range entry.Context {
|
||||||
|
if f.Key == "method" && f.String == "GetJob" {
|
||||||
|
foundMethod = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !foundMethod {
|
||||||
|
t.Error("expected method=GetJob in log fields")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetJob_Error_LogsError(t *testing.T) {
|
||||||
|
mux := http.NewServeMux()
|
||||||
|
mux.HandleFunc("/slurm/v0.0.40/job/42", func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(http.StatusInternalServerError)
|
||||||
|
})
|
||||||
|
srv, handler, logs := setupJobHandlerWithObserver(mux)
|
||||||
|
defer srv.Close()
|
||||||
|
|
||||||
|
router := setupJobRouter(handler)
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/api/v1/jobs/42", nil)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
|
router.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
if w.Code != http.StatusInternalServerError {
|
||||||
|
t.Fatalf("expected 500, got %d", w.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
hLogs := handlerLogs(logs)
|
||||||
|
if len(hLogs) != 1 {
|
||||||
|
t.Fatalf("expected 1 handler log entry, got %d", len(hLogs))
|
||||||
|
}
|
||||||
|
if hLogs[0].Level != zapcore.ErrorLevel {
|
||||||
|
t.Errorf("expected Error level, got %v", hLogs[0].Level)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCancelJob_SlurmError_LogsError(t *testing.T) {
|
||||||
|
mux := http.NewServeMux()
|
||||||
|
mux.HandleFunc("/slurm/v0.0.40/job/42", func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(http.StatusInternalServerError)
|
||||||
|
})
|
||||||
|
srv, handler, logs := setupJobHandlerWithObserver(mux)
|
||||||
|
defer srv.Close()
|
||||||
|
|
||||||
|
router := setupJobRouter(handler)
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodDelete, "/api/v1/jobs/42", nil)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
|
router.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
if w.Code != http.StatusBadGateway {
|
||||||
|
t.Fatalf("expected 502, got %d", w.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
hLogs := handlerLogs(logs)
|
||||||
|
if len(hLogs) != 1 {
|
||||||
|
t.Fatalf("expected 1 handler log entry, got %d", len(hLogs))
|
||||||
|
}
|
||||||
|
entry := hLogs[0]
|
||||||
|
if entry.Level != zapcore.ErrorLevel {
|
||||||
|
t.Errorf("expected Error level, got %v", entry.Level)
|
||||||
|
}
|
||||||
|
foundMethod := false
|
||||||
|
for _, f := range entry.Context {
|
||||||
|
if f.Key == "method" && f.String == "CancelJob" {
|
||||||
|
foundMethod = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !foundMethod {
|
||||||
|
t.Error("expected method=CancelJob in log fields")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetJobHistory_InvalidQuery_LogsWarn(t *testing.T) {
|
||||||
|
mux := http.NewServeMux()
|
||||||
|
mux.HandleFunc("/slurmdb/v0.0.40/jobs", func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
json.NewEncoder(w).Encode(slurm.OpenapiSlurmdbdJobsResp{})
|
||||||
|
})
|
||||||
|
srv, handler, logs := setupJobHandlerWithObserver(mux)
|
||||||
|
defer srv.Close()
|
||||||
|
|
||||||
|
router := setupJobRouter(handler)
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/api/v1/jobs/history?page=abc", nil)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
|
router.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
if w.Code != http.StatusBadRequest {
|
||||||
|
t.Fatalf("expected 400, got %d: %s", w.Code, w.Body.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
hLogs := handlerLogs(logs)
|
||||||
|
if len(hLogs) != 1 {
|
||||||
|
t.Fatalf("expected 1 handler log entry, got %d", len(hLogs))
|
||||||
|
}
|
||||||
|
entry := hLogs[0]
|
||||||
|
if entry.Level != zapcore.WarnLevel {
|
||||||
|
t.Errorf("expected Warn level, got %v", entry.Level)
|
||||||
|
}
|
||||||
|
foundMethod := false
|
||||||
|
for _, f := range entry.Context {
|
||||||
|
if f.Key == "method" && f.String == "GetJobHistory" {
|
||||||
|
foundMethod = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !foundMethod {
|
||||||
|
t.Error("expected method=GetJobHistory in log fields")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetJobHistory_Error_LogsError(t *testing.T) {
|
||||||
|
mux := http.NewServeMux()
|
||||||
|
mux.HandleFunc("/slurmdb/v0.0.40/jobs", func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(http.StatusInternalServerError)
|
||||||
|
})
|
||||||
|
srv, handler, logs := setupJobHandlerWithObserver(mux)
|
||||||
|
defer srv.Close()
|
||||||
|
|
||||||
|
router := setupJobRouter(handler)
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/api/v1/jobs/history?page=1&page_size=10", nil)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
|
router.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
if w.Code != http.StatusInternalServerError {
|
||||||
|
t.Fatalf("expected 500, got %d", w.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
hLogs := handlerLogs(logs)
|
||||||
|
if len(hLogs) != 1 {
|
||||||
|
t.Fatalf("expected 1 handler log entry, got %d", len(hLogs))
|
||||||
|
}
|
||||||
|
if hLogs[0].Level != zapcore.ErrorLevel {
|
||||||
|
t.Errorf("expected Error level, got %v", hLogs[0].Level)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetJobHistory_Success_NoHandlerLogs(t *testing.T) {
|
||||||
|
mux := http.NewServeMux()
|
||||||
|
mux.HandleFunc("/slurmdb/v0.0.40/jobs", func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
json.NewEncoder(w).Encode(slurm.OpenapiSlurmdbdJobsResp{
|
||||||
|
Jobs: []slurm.Job{
|
||||||
|
{JobID: slurm.Ptr(int32(1)), Name: slurm.Ptr("h1")},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
})
|
||||||
|
srv, handler, logs := setupJobHandlerWithObserver(mux)
|
||||||
|
defer srv.Close()
|
||||||
|
|
||||||
|
router := setupJobRouter(handler)
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/api/v1/jobs/history?page=1&page_size=10", nil)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
|
router.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Fatalf("expected 200, got %d", w.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
hLogs := handlerLogs(logs)
|
||||||
|
if len(hLogs) != 0 {
|
||||||
|
t.Errorf("expected no handler log entries on success, got %d", len(hLogs))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetJobs_Success_NoHandlerLogs(t *testing.T) {
|
||||||
|
mux := http.NewServeMux()
|
||||||
|
mux.HandleFunc("/slurm/v0.0.40/jobs", func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
json.NewEncoder(w).Encode(slurm.OpenapiJobInfoResp{Jobs: []slurm.JobInfo{}})
|
||||||
|
})
|
||||||
|
srv, handler, logs := setupJobHandlerWithObserver(mux)
|
||||||
|
defer srv.Close()
|
||||||
|
|
||||||
|
router := setupJobRouter(handler)
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/api/v1/jobs", nil)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
|
router.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Fatalf("expected 200, got %d", w.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
hLogs := handlerLogs(logs)
|
||||||
|
if len(hLogs) != 0 {
|
||||||
|
t.Errorf("expected no handler log entries on success, got %d", len(hLogs))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCancelJob_Success_NoHandlerLogs(t *testing.T) {
|
||||||
|
mux := http.NewServeMux()
|
||||||
|
mux.HandleFunc("/slurm/v0.0.40/job/42", func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
json.NewEncoder(w).Encode(slurm.OpenapiResp{})
|
||||||
|
})
|
||||||
|
srv, handler, logs := setupJobHandlerWithObserver(mux)
|
||||||
|
defer srv.Close()
|
||||||
|
|
||||||
|
router := setupJobRouter(handler)
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodDelete, "/api/v1/jobs/42", nil)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
|
router.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Fatalf("expected 200, got %d", w.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
hLogs := handlerLogs(logs)
|
||||||
|
if len(hLogs) != 0 {
|
||||||
|
t.Errorf("expected no handler log entries on success, got %d", len(hLogs))
|
||||||
|
}
|
||||||
|
}
|
||||||
139
internal/handler/template.go
Normal file
139
internal/handler/template.go
Normal file
@@ -0,0 +1,139 @@
|
|||||||
|
package handler
|
||||||
|
|
||||||
|
import (
|
||||||
|
"strconv"
|
||||||
|
|
||||||
|
"gcy_hpc_server/internal/model"
|
||||||
|
"gcy_hpc_server/internal/server"
|
||||||
|
"gcy_hpc_server/internal/store"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"go.uber.org/zap"
|
||||||
|
)
|
||||||
|
|
||||||
|
type TemplateHandler struct {
|
||||||
|
store *store.TemplateStore
|
||||||
|
logger *zap.Logger
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewTemplateHandler(s *store.TemplateStore, logger *zap.Logger) *TemplateHandler {
|
||||||
|
return &TemplateHandler{store: s, logger: logger}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *TemplateHandler) ListTemplates(c *gin.Context) {
|
||||||
|
page, _ := strconv.Atoi(c.DefaultQuery("page", "1"))
|
||||||
|
pageSize, _ := strconv.Atoi(c.DefaultQuery("page_size", "20"))
|
||||||
|
|
||||||
|
if page < 1 {
|
||||||
|
page = 1
|
||||||
|
}
|
||||||
|
if pageSize < 1 {
|
||||||
|
pageSize = 20
|
||||||
|
}
|
||||||
|
|
||||||
|
templates, total, err := h.store.List(c.Request.Context(), page, pageSize)
|
||||||
|
if err != nil {
|
||||||
|
h.logger.Error("failed to list templates", zap.Error(err))
|
||||||
|
server.InternalError(c, err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
server.OK(c, gin.H{
|
||||||
|
"templates": templates,
|
||||||
|
"total": total,
|
||||||
|
"page": page,
|
||||||
|
"page_size": pageSize,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *TemplateHandler) CreateTemplate(c *gin.Context) {
|
||||||
|
var req model.CreateTemplateRequest
|
||||||
|
if err := c.ShouldBindJSON(&req); err != nil {
|
||||||
|
h.logger.Warn("invalid request body for create template", zap.Error(err))
|
||||||
|
server.BadRequest(c, "invalid request body")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if req.Name == "" || req.Script == "" {
|
||||||
|
h.logger.Warn("missing required fields for create template")
|
||||||
|
server.BadRequest(c, "name and script are required")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
id, err := h.store.Create(c.Request.Context(), &req)
|
||||||
|
if err != nil {
|
||||||
|
h.logger.Error("failed to create template", zap.Error(err))
|
||||||
|
server.InternalError(c, err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
h.logger.Info("template created", zap.Int64("id", id))
|
||||||
|
server.Created(c, gin.H{"id": id})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *TemplateHandler) GetTemplate(c *gin.Context) {
|
||||||
|
id, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||||
|
if err != nil {
|
||||||
|
h.logger.Warn("invalid template id", zap.String("id", c.Param("id")))
|
||||||
|
server.BadRequest(c, "invalid id")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
tmpl, err := h.store.GetByID(c.Request.Context(), id)
|
||||||
|
if err != nil {
|
||||||
|
h.logger.Error("failed to get template", zap.Int64("id", id), zap.Error(err))
|
||||||
|
server.InternalError(c, err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if tmpl == nil {
|
||||||
|
h.logger.Warn("template not found", zap.Int64("id", id))
|
||||||
|
server.NotFound(c, "template not found")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
server.OK(c, tmpl)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *TemplateHandler) UpdateTemplate(c *gin.Context) {
|
||||||
|
id, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||||
|
if err != nil {
|
||||||
|
h.logger.Warn("invalid template id for update", zap.String("id", c.Param("id")))
|
||||||
|
server.BadRequest(c, "invalid id")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var req model.UpdateTemplateRequest
|
||||||
|
if err := c.ShouldBindJSON(&req); err != nil {
|
||||||
|
h.logger.Warn("invalid request body for update template", zap.Int64("id", id), zap.Error(err))
|
||||||
|
server.BadRequest(c, "invalid request body")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := h.store.Update(c.Request.Context(), id, &req); err != nil {
|
||||||
|
h.logger.Error("failed to update template", zap.Int64("id", id), zap.Error(err))
|
||||||
|
server.InternalError(c, err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
h.logger.Info("template updated", zap.Int64("id", id))
|
||||||
|
server.OK(c, gin.H{"message": "template updated"})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *TemplateHandler) DeleteTemplate(c *gin.Context) {
|
||||||
|
id, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||||
|
if err != nil {
|
||||||
|
h.logger.Warn("invalid template id for delete", zap.String("id", c.Param("id")))
|
||||||
|
server.BadRequest(c, "invalid id")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := h.store.Delete(c.Request.Context(), id); err != nil {
|
||||||
|
h.logger.Error("failed to delete template", zap.Int64("id", id), zap.Error(err))
|
||||||
|
server.InternalError(c, err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
h.logger.Info("template deleted", zap.Int64("id", id))
|
||||||
|
server.OK(c, gin.H{"message": "template deleted"})
|
||||||
|
}
|
||||||
387
internal/handler/template_test.go
Normal file
387
internal/handler/template_test.go
Normal file
@@ -0,0 +1,387 @@
|
|||||||
|
package handler
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"gcy_hpc_server/internal/model"
|
||||||
|
"gcy_hpc_server/internal/server"
|
||||||
|
"gcy_hpc_server/internal/store"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"go.uber.org/zap"
|
||||||
|
"go.uber.org/zap/zapcore"
|
||||||
|
"go.uber.org/zap/zaptest/observer"
|
||||||
|
"gorm.io/driver/sqlite"
|
||||||
|
"gorm.io/gorm"
|
||||||
|
"gorm.io/gorm/logger"
|
||||||
|
)
|
||||||
|
|
||||||
|
func setupTemplateHandler() (*TemplateHandler, *gorm.DB) {
|
||||||
|
db, _ := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{Logger: logger.Default.LogMode(logger.Silent)})
|
||||||
|
db.AutoMigrate(&model.JobTemplate{})
|
||||||
|
s := store.NewTemplateStore(db)
|
||||||
|
h := NewTemplateHandler(s, zap.NewNop())
|
||||||
|
return h, db
|
||||||
|
}
|
||||||
|
|
||||||
|
func setupTemplateRouter(h *TemplateHandler) *gin.Engine {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
r := gin.New()
|
||||||
|
v1 := r.Group("/api/v1")
|
||||||
|
templates := v1.Group("/templates")
|
||||||
|
templates.GET("", h.ListTemplates)
|
||||||
|
templates.POST("", h.CreateTemplate)
|
||||||
|
templates.GET("/:id", h.GetTemplate)
|
||||||
|
templates.PUT("/:id", h.UpdateTemplate)
|
||||||
|
templates.DELETE("/:id", h.DeleteTemplate)
|
||||||
|
return r
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestListTemplates_Success(t *testing.T) {
|
||||||
|
h, _ := setupTemplateHandler()
|
||||||
|
r := setupTemplateRouter(h)
|
||||||
|
|
||||||
|
// Seed data
|
||||||
|
h.store.Create(context.Background(), &model.CreateTemplateRequest{Name: "test-tpl", Script: "echo hi", Partition: "normal", QOS: "high", CPUs: 4, Memory: "4GB"})
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
req, _ := http.NewRequest("GET", "/api/v1/templates", nil)
|
||||||
|
r.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Fatalf("expected 200, got %d: %s", w.Code, w.Body.String())
|
||||||
|
}
|
||||||
|
var resp server.APIResponse
|
||||||
|
json.Unmarshal(w.Body.Bytes(), &resp)
|
||||||
|
if !resp.Success {
|
||||||
|
t.Fatalf("expected success=true, got: %s", w.Body.String())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCreateTemplate_Success(t *testing.T) {
|
||||||
|
h, _ := setupTemplateHandler()
|
||||||
|
r := setupTemplateRouter(h)
|
||||||
|
|
||||||
|
body := `{"name":"my-tpl","description":"desc","script":"echo hello","partition":"gpu"}`
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
req, _ := http.NewRequest("POST", "/api/v1/templates", bytes.NewBufferString(body))
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
r.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
if w.Code != http.StatusCreated {
|
||||||
|
t.Fatalf("expected 201, got %d: %s", w.Code, w.Body.String())
|
||||||
|
}
|
||||||
|
var resp server.APIResponse
|
||||||
|
json.Unmarshal(w.Body.Bytes(), &resp)
|
||||||
|
if !resp.Success {
|
||||||
|
t.Fatalf("expected success=true, got: %s", w.Body.String())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCreateTemplate_MissingFields(t *testing.T) {
|
||||||
|
h, _ := setupTemplateHandler()
|
||||||
|
r := setupTemplateRouter(h)
|
||||||
|
|
||||||
|
body := `{"name":"","script":""}`
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
req, _ := http.NewRequest("POST", "/api/v1/templates", bytes.NewBufferString(body))
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
r.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
if w.Code != http.StatusBadRequest {
|
||||||
|
t.Fatalf("expected 400, got %d: %s", w.Code, w.Body.String())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetTemplate_Success(t *testing.T) {
|
||||||
|
h, _ := setupTemplateHandler()
|
||||||
|
r := setupTemplateRouter(h)
|
||||||
|
|
||||||
|
// Seed data
|
||||||
|
id, _ := h.store.Create(context.Background(), &model.CreateTemplateRequest{Name: "test-tpl", Script: "echo hi", Partition: "normal"})
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
req, _ := http.NewRequest("GET", "/api/v1/templates/"+itoa(id), nil)
|
||||||
|
r.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Fatalf("expected 200, got %d: %s", w.Code, w.Body.String())
|
||||||
|
}
|
||||||
|
var resp server.APIResponse
|
||||||
|
json.Unmarshal(w.Body.Bytes(), &resp)
|
||||||
|
if !resp.Success {
|
||||||
|
t.Fatalf("expected success=true, got: %s", w.Body.String())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetTemplate_NotFound(t *testing.T) {
|
||||||
|
h, _ := setupTemplateHandler()
|
||||||
|
r := setupTemplateRouter(h)
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
req, _ := http.NewRequest("GET", "/api/v1/templates/999", nil)
|
||||||
|
r.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
if w.Code != http.StatusNotFound {
|
||||||
|
t.Fatalf("expected 404, got %d: %s", w.Code, w.Body.String())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetTemplate_InvalidID(t *testing.T) {
|
||||||
|
h, _ := setupTemplateHandler()
|
||||||
|
r := setupTemplateRouter(h)
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
req, _ := http.NewRequest("GET", "/api/v1/templates/abc", nil)
|
||||||
|
r.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
if w.Code != http.StatusBadRequest {
|
||||||
|
t.Fatalf("expected 400, got %d: %s", w.Code, w.Body.String())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUpdateTemplate_Success(t *testing.T) {
|
||||||
|
h, _ := setupTemplateHandler()
|
||||||
|
r := setupTemplateRouter(h)
|
||||||
|
|
||||||
|
// Seed data
|
||||||
|
id, _ := h.store.Create(context.Background(), &model.CreateTemplateRequest{Name: "old", Script: "echo hi"})
|
||||||
|
|
||||||
|
body := `{"name":"updated"}`
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
req, _ := http.NewRequest("PUT", "/api/v1/templates/"+itoa(id), bytes.NewBufferString(body))
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
r.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Fatalf("expected 200, got %d: %s", w.Code, w.Body.String())
|
||||||
|
}
|
||||||
|
var resp server.APIResponse
|
||||||
|
json.Unmarshal(w.Body.Bytes(), &resp)
|
||||||
|
if !resp.Success {
|
||||||
|
t.Fatalf("expected success=true, got: %s", w.Body.String())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDeleteTemplate_Success(t *testing.T) {
|
||||||
|
h, _ := setupTemplateHandler()
|
||||||
|
r := setupTemplateRouter(h)
|
||||||
|
|
||||||
|
// Seed data
|
||||||
|
id, _ := h.store.Create(context.Background(), &model.CreateTemplateRequest{Name: "to-delete", Script: "echo hi"})
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
req, _ := http.NewRequest("DELETE", "/api/v1/templates/"+itoa(id), nil)
|
||||||
|
r.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Fatalf("expected 200, got %d: %s", w.Code, w.Body.String())
|
||||||
|
}
|
||||||
|
var resp server.APIResponse
|
||||||
|
json.Unmarshal(w.Body.Bytes(), &resp)
|
||||||
|
if !resp.Success {
|
||||||
|
t.Fatalf("expected success=true, got: %s", w.Body.String())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// itoa converts int64 to string for URL path construction.
|
||||||
|
func itoa(id int64) string {
|
||||||
|
return fmt.Sprintf("%d", id)
|
||||||
|
}
|
||||||
|
|
||||||
|
func setupTemplateHandlerWithObserver() (*TemplateHandler, *gorm.DB, *observer.ObservedLogs) {
|
||||||
|
core, recorded := observer.New(zapcore.DebugLevel)
|
||||||
|
l := zap.New(core)
|
||||||
|
db, _ := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{Logger: logger.Default.LogMode(logger.Silent)})
|
||||||
|
db.AutoMigrate(&model.JobTemplate{})
|
||||||
|
s := store.NewTemplateStore(db)
|
||||||
|
return NewTemplateHandler(s, l), db, recorded
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTemplateLogging_CreateSuccess_LogsInfoWithID(t *testing.T) {
|
||||||
|
h, _, recorded := setupTemplateHandlerWithObserver()
|
||||||
|
r := setupTemplateRouter(h)
|
||||||
|
|
||||||
|
body := `{"name":"log-tpl","script":"echo hi"}`
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
req, _ := http.NewRequest("POST", "/api/v1/templates", bytes.NewBufferString(body))
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
r.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
if w.Code != http.StatusCreated {
|
||||||
|
t.Fatalf("expected 201, got %d: %s", w.Code, w.Body.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
entries := recorded.FilterMessage("template created").FilterLevelExact(zapcore.InfoLevel).All()
|
||||||
|
if len(entries) != 1 {
|
||||||
|
t.Fatalf("expected 1 info log for 'template created', got %d", len(entries))
|
||||||
|
}
|
||||||
|
fields := entries[0].ContextMap()
|
||||||
|
if fields["id"] == nil {
|
||||||
|
t.Fatal("expected log entry to contain 'id' field")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTemplateLogging_UpdateSuccess_LogsInfoWithID(t *testing.T) {
|
||||||
|
h, _, recorded := setupTemplateHandlerWithObserver()
|
||||||
|
r := setupTemplateRouter(h)
|
||||||
|
|
||||||
|
id, _ := h.store.Create(context.Background(), &model.CreateTemplateRequest{Name: "old", Script: "echo hi"})
|
||||||
|
|
||||||
|
body := `{"name":"updated"}`
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
req, _ := http.NewRequest("PUT", "/api/v1/templates/"+itoa(id), bytes.NewBufferString(body))
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
r.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Fatalf("expected 200, got %d: %s", w.Code, w.Body.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
entries := recorded.FilterMessage("template updated").FilterLevelExact(zapcore.InfoLevel).All()
|
||||||
|
if len(entries) != 1 {
|
||||||
|
t.Fatalf("expected 1 info log for 'template updated', got %d", len(entries))
|
||||||
|
}
|
||||||
|
fields := entries[0].ContextMap()
|
||||||
|
if fields["id"] == nil {
|
||||||
|
t.Fatal("expected log entry to contain 'id' field")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTemplateLogging_DeleteSuccess_LogsInfoWithID(t *testing.T) {
|
||||||
|
h, _, recorded := setupTemplateHandlerWithObserver()
|
||||||
|
r := setupTemplateRouter(h)
|
||||||
|
|
||||||
|
id, _ := h.store.Create(context.Background(), &model.CreateTemplateRequest{Name: "to-delete", Script: "echo hi"})
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
req, _ := http.NewRequest("DELETE", "/api/v1/templates/"+itoa(id), nil)
|
||||||
|
r.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Fatalf("expected 200, got %d: %s", w.Code, w.Body.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
entries := recorded.FilterMessage("template deleted").FilterLevelExact(zapcore.InfoLevel).All()
|
||||||
|
if len(entries) != 1 {
|
||||||
|
t.Fatalf("expected 1 info log for 'template deleted', got %d", len(entries))
|
||||||
|
}
|
||||||
|
fields := entries[0].ContextMap()
|
||||||
|
if fields["id"] == nil {
|
||||||
|
t.Fatal("expected log entry to contain 'id' field")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTemplateLogging_GetNotFound_LogsWarnWithID(t *testing.T) {
|
||||||
|
h, _, recorded := setupTemplateHandlerWithObserver()
|
||||||
|
r := setupTemplateRouter(h)
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
req, _ := http.NewRequest("GET", "/api/v1/templates/999", nil)
|
||||||
|
r.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
if w.Code != http.StatusNotFound {
|
||||||
|
t.Fatalf("expected 404, got %d: %s", w.Code, w.Body.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
entries := recorded.FilterMessage("template not found").FilterLevelExact(zapcore.WarnLevel).All()
|
||||||
|
if len(entries) != 1 {
|
||||||
|
t.Fatalf("expected 1 warn log for 'template not found', got %d", len(entries))
|
||||||
|
}
|
||||||
|
fields := entries[0].ContextMap()
|
||||||
|
if fields["id"] == nil {
|
||||||
|
t.Fatal("expected log entry to contain 'id' field")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTemplateLogging_CreateBadRequest_LogsWarn(t *testing.T) {
|
||||||
|
h, _, recorded := setupTemplateHandlerWithObserver()
|
||||||
|
r := setupTemplateRouter(h)
|
||||||
|
|
||||||
|
body := `{"name":"","script":""}`
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
req, _ := http.NewRequest("POST", "/api/v1/templates", bytes.NewBufferString(body))
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
r.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
if w.Code != http.StatusBadRequest {
|
||||||
|
t.Fatalf("expected 400, got %d: %s", w.Code, w.Body.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
warnEntries := recorded.FilterLevelExact(zapcore.WarnLevel).All()
|
||||||
|
if len(warnEntries) == 0 {
|
||||||
|
t.Fatal("expected at least 1 warn log for bad request")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTemplateLogging_InvalidID_LogsWarn(t *testing.T) {
|
||||||
|
h, _, recorded := setupTemplateHandlerWithObserver()
|
||||||
|
r := setupTemplateRouter(h)
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
req, _ := http.NewRequest("GET", "/api/v1/templates/abc", nil)
|
||||||
|
r.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
if w.Code != http.StatusBadRequest {
|
||||||
|
t.Fatalf("expected 400, got %d: %s", w.Code, w.Body.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
warnEntries := recorded.FilterLevelExact(zapcore.WarnLevel).All()
|
||||||
|
if len(warnEntries) == 0 {
|
||||||
|
t.Fatal("expected at least 1 warn log for invalid id")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTemplateLogging_ListSuccess_NoInfoLog(t *testing.T) {
|
||||||
|
h, _, recorded := setupTemplateHandlerWithObserver()
|
||||||
|
r := setupTemplateRouter(h)
|
||||||
|
|
||||||
|
// Seed data
|
||||||
|
h.store.Create(context.Background(), &model.CreateTemplateRequest{Name: "test-tpl", Script: "echo hi"})
|
||||||
|
|
||||||
|
// Reset recorded logs so the create log doesn't interfere
|
||||||
|
recorded.TakeAll()
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
req, _ := http.NewRequest("GET", "/api/v1/templates", nil)
|
||||||
|
r.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Fatalf("expected 200, got %d: %s", w.Code, w.Body.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
infoEntries := recorded.FilterLevelExact(zapcore.InfoLevel).All()
|
||||||
|
if len(infoEntries) != 0 {
|
||||||
|
t.Fatalf("expected 0 info logs for list success, got %d: %+v", len(infoEntries), infoEntries)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTemplateLogging_LogsDoNotContainTemplateContent(t *testing.T) {
|
||||||
|
h, _, recorded := setupTemplateHandlerWithObserver()
|
||||||
|
r := setupTemplateRouter(h)
|
||||||
|
|
||||||
|
body := `{"name":"secret-name","script":"secret-script","partition":"secret-partition"}`
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
req, _ := http.NewRequest("POST", "/api/v1/templates", bytes.NewBufferString(body))
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
r.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
if w.Code != http.StatusCreated {
|
||||||
|
t.Fatalf("expected 201, got %d: %s", w.Code, w.Body.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
entries := recorded.All()
|
||||||
|
for _, e := range entries {
|
||||||
|
logStr := e.Message + " " + fmt.Sprintf("%v", e.ContextMap())
|
||||||
|
if strings.Contains(logStr, "secret-name") || strings.Contains(logStr, "secret-script") || strings.Contains(logStr, "secret-partition") {
|
||||||
|
t.Fatalf("log entry contains sensitive template content: %s", logStr)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
148
internal/logger/gorm.go
Normal file
148
internal/logger/gorm.go
Normal file
@@ -0,0 +1,148 @@
|
|||||||
|
package logger
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"go.uber.org/zap"
|
||||||
|
"go.uber.org/zap/zapcore"
|
||||||
|
"gorm.io/gorm"
|
||||||
|
gormlogger "gorm.io/gorm/logger"
|
||||||
|
)
|
||||||
|
|
||||||
|
const slowQueryThreshold = 200 * time.Millisecond
|
||||||
|
|
||||||
|
// GormLogger implements gorm's logger.Interface backed by zap.
|
||||||
|
type GormLogger struct {
|
||||||
|
logger *zap.Logger
|
||||||
|
level zapcore.Level
|
||||||
|
silent bool
|
||||||
|
}
|
||||||
|
|
||||||
|
// Compile-time interface check.
|
||||||
|
var _ gormlogger.Interface = (*GormLogger)(nil)
|
||||||
|
|
||||||
|
// NewGormLogger creates a new GormLogger wrapping the given zap logger.
|
||||||
|
// The level string maps to zap levels; empty defaults to "warn".
|
||||||
|
// The special value "silent" suppresses all output.
|
||||||
|
func NewGormLogger(zapLogger *zap.Logger, level string) gormlogger.Interface {
|
||||||
|
lvl := parseGormLevel(level)
|
||||||
|
silent := level == "silent"
|
||||||
|
return &GormLogger{
|
||||||
|
logger: zapLogger,
|
||||||
|
level: lvl,
|
||||||
|
silent: silent,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// LogMode returns a new GormLogger with the given gorm log level.
|
||||||
|
// It does NOT mutate the receiver.
|
||||||
|
func (l *GormLogger) LogMode(level gormlogger.LogLevel) gormlogger.Interface {
|
||||||
|
newLogger := &GormLogger{
|
||||||
|
logger: l.logger,
|
||||||
|
level: l.level,
|
||||||
|
silent: l.silent,
|
||||||
|
}
|
||||||
|
|
||||||
|
switch level {
|
||||||
|
case gormlogger.Silent:
|
||||||
|
newLogger.silent = true
|
||||||
|
case gormlogger.Error:
|
||||||
|
newLogger.level = zapcore.ErrorLevel
|
||||||
|
newLogger.silent = false
|
||||||
|
case gormlogger.Warn:
|
||||||
|
newLogger.level = zapcore.WarnLevel
|
||||||
|
newLogger.silent = false
|
||||||
|
case gormlogger.Info:
|
||||||
|
newLogger.level = zapcore.InfoLevel
|
||||||
|
newLogger.silent = false
|
||||||
|
}
|
||||||
|
|
||||||
|
return newLogger
|
||||||
|
}
|
||||||
|
|
||||||
|
// Info logs at zap.InfoLevel with structured fields from key-value pairs.
|
||||||
|
func (l *GormLogger) Info(ctx context.Context, msg string, args ...any) {
|
||||||
|
if l.silent || l.level > zapcore.InfoLevel {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
l.logger.Info(msg, argsToFields(args)...)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Warn logs at zap.WarnLevel with structured fields from key-value pairs.
|
||||||
|
func (l *GormLogger) Warn(ctx context.Context, msg string, args ...any) {
|
||||||
|
if l.silent || l.level > zapcore.WarnLevel {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
l.logger.Warn(msg, argsToFields(args)...)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Error logs at zap.ErrorLevel with structured fields from key-value pairs.
|
||||||
|
func (l *GormLogger) Error(ctx context.Context, msg string, args ...any) {
|
||||||
|
if l.silent || l.level > zapcore.ErrorLevel {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
l.logger.Error(msg, argsToFields(args)...)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Trace logs SQL query information based on execution results.
|
||||||
|
func (l *GormLogger) Trace(ctx context.Context, begin time.Time, fc func() (sql string, rowsAffected int64), err error) {
|
||||||
|
if l.silent {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
elapsed := time.Since(begin)
|
||||||
|
sql, rows := fc()
|
||||||
|
|
||||||
|
switch {
|
||||||
|
case err != nil && !errors.Is(err, gorm.ErrRecordNotFound):
|
||||||
|
l.logger.Error("gorm query error",
|
||||||
|
zap.Error(err),
|
||||||
|
zap.String("sql", sql),
|
||||||
|
zap.Int64("rows", rows),
|
||||||
|
zap.Duration("elapsed", elapsed),
|
||||||
|
)
|
||||||
|
case elapsed > slowQueryThreshold:
|
||||||
|
l.logger.Warn("gorm slow query",
|
||||||
|
zap.String("sql", sql),
|
||||||
|
zap.Int64("rows", rows),
|
||||||
|
zap.Float64("elapsed_ms", float64(elapsed.Nanoseconds())/1e6),
|
||||||
|
)
|
||||||
|
default:
|
||||||
|
if l.level > zapcore.InfoLevel {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
l.logger.Info("gorm query",
|
||||||
|
zap.String("sql", sql),
|
||||||
|
zap.Int64("rows", rows),
|
||||||
|
zap.Duration("elapsed", elapsed),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// parseGormLevel parses a level string into a zapcore.Level.
|
||||||
|
// Defaults to zapcore.WarnLevel.
|
||||||
|
func parseGormLevel(level string) zapcore.Level {
|
||||||
|
if level == "" || level == "silent" {
|
||||||
|
return zapcore.WarnLevel
|
||||||
|
}
|
||||||
|
var lvl zapcore.Level
|
||||||
|
if err := lvl.UnmarshalText([]byte(level)); err != nil {
|
||||||
|
return zapcore.WarnLevel
|
||||||
|
}
|
||||||
|
return lvl
|
||||||
|
}
|
||||||
|
|
||||||
|
// argsToFields converts alternating key-value pairs into zap fields.
|
||||||
|
func argsToFields(args []any) []zap.Field {
|
||||||
|
fields := make([]zap.Field, 0, len(args)/2)
|
||||||
|
for i := 0; i+1 < len(args); i += 2 {
|
||||||
|
key, ok := args[i].(string)
|
||||||
|
if !ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
fields = append(fields, zap.Any(key, args[i+1]))
|
||||||
|
}
|
||||||
|
return fields
|
||||||
|
}
|
||||||
509
internal/logger/gorm_test.go
Normal file
509
internal/logger/gorm_test.go
Normal file
@@ -0,0 +1,509 @@
|
|||||||
|
package logger
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"go.uber.org/zap"
|
||||||
|
"go.uber.org/zap/zapcore"
|
||||||
|
"go.uber.org/zap/zaptest/observer"
|
||||||
|
"gorm.io/gorm"
|
||||||
|
gormlogger "gorm.io/gorm/logger"
|
||||||
|
)
|
||||||
|
|
||||||
|
// newObservedLogger creates a zap logger backed by an observer for test assertions.
|
||||||
|
func newObservedLogger() (*zap.Logger, *observer.ObservedLogs) {
|
||||||
|
core, recorded := observer.New(zapcore.DebugLevel)
|
||||||
|
return zap.New(core), recorded
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNewGormLogger_DefaultLevel(t *testing.T) {
|
||||||
|
log, _ := newObservedLogger()
|
||||||
|
gl := NewGormLogger(log, "")
|
||||||
|
g, ok := gl.(*GormLogger)
|
||||||
|
if !ok {
|
||||||
|
t.Fatal("expected *GormLogger")
|
||||||
|
}
|
||||||
|
if g.level != zapcore.WarnLevel {
|
||||||
|
t.Fatalf("expected warn level, got %v", g.level)
|
||||||
|
}
|
||||||
|
if g.silent {
|
||||||
|
t.Fatal("expected silent=false")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNewGormLogger_Silent(t *testing.T) {
|
||||||
|
log, _ := newObservedLogger()
|
||||||
|
gl := NewGormLogger(log, "silent")
|
||||||
|
g, ok := gl.(*GormLogger)
|
||||||
|
if !ok {
|
||||||
|
t.Fatal("expected *GormLogger")
|
||||||
|
}
|
||||||
|
if !g.silent {
|
||||||
|
t.Fatal("expected silent=true")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNewGormLogger_ExplicitLevel(t *testing.T) {
|
||||||
|
log, _ := newObservedLogger()
|
||||||
|
gl := NewGormLogger(log, "error")
|
||||||
|
g, ok := gl.(*GormLogger)
|
||||||
|
if !ok {
|
||||||
|
t.Fatal("expected *GormLogger")
|
||||||
|
}
|
||||||
|
if g.level != zapcore.ErrorLevel {
|
||||||
|
t.Fatalf("expected error level, got %v", g.level)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNewGormLogger_InvalidLevel(t *testing.T) {
|
||||||
|
log, _ := newObservedLogger()
|
||||||
|
gl := NewGormLogger(log, "bogus")
|
||||||
|
g, ok := gl.(*GormLogger)
|
||||||
|
if !ok {
|
||||||
|
t.Fatal("expected *GormLogger")
|
||||||
|
}
|
||||||
|
// Should default to warn
|
||||||
|
if g.level != zapcore.WarnLevel {
|
||||||
|
t.Fatalf("expected warn level fallback, got %v", g.level)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGormLogger_Info(t *testing.T) {
|
||||||
|
log, recorded := newObservedLogger()
|
||||||
|
gl := NewGormLogger(log, "info")
|
||||||
|
|
||||||
|
gl.Info(context.Background(), "test info", "key", "value")
|
||||||
|
|
||||||
|
entries := recorded.All()
|
||||||
|
if len(entries) != 1 {
|
||||||
|
t.Fatalf("expected 1 log entry, got %d", len(entries))
|
||||||
|
}
|
||||||
|
if entries[0].Message != "test info" {
|
||||||
|
t.Fatalf("expected message 'test info', got %q", entries[0].Message)
|
||||||
|
}
|
||||||
|
if entries[0].Level != zapcore.InfoLevel {
|
||||||
|
t.Fatalf("expected InfoLevel, got %v", entries[0].Level)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGormLogger_Warn(t *testing.T) {
|
||||||
|
log, recorded := newObservedLogger()
|
||||||
|
gl := NewGormLogger(log, "warn")
|
||||||
|
|
||||||
|
gl.Warn(context.Background(), "test warn", "code", 42)
|
||||||
|
|
||||||
|
entries := recorded.All()
|
||||||
|
if len(entries) != 1 {
|
||||||
|
t.Fatalf("expected 1 log entry, got %d", len(entries))
|
||||||
|
}
|
||||||
|
if entries[0].Message != "test warn" {
|
||||||
|
t.Fatalf("expected message 'test warn', got %q", entries[0].Message)
|
||||||
|
}
|
||||||
|
if entries[0].Level != zapcore.WarnLevel {
|
||||||
|
t.Fatalf("expected WarnLevel, got %v", entries[0].Level)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGormLogger_Error(t *testing.T) {
|
||||||
|
log, recorded := newObservedLogger()
|
||||||
|
gl := NewGormLogger(log, "error")
|
||||||
|
|
||||||
|
gl.Error(context.Background(), "test error", "module", "gorm")
|
||||||
|
|
||||||
|
entries := recorded.All()
|
||||||
|
if len(entries) != 1 {
|
||||||
|
t.Fatalf("expected 1 log entry, got %d", len(entries))
|
||||||
|
}
|
||||||
|
if entries[0].Message != "test error" {
|
||||||
|
t.Fatalf("expected message 'test error', got %q", entries[0].Message)
|
||||||
|
}
|
||||||
|
if entries[0].Level != zapcore.ErrorLevel {
|
||||||
|
t.Fatalf("expected ErrorLevel, got %v", entries[0].Level)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGormLogger_LevelFiltering(t *testing.T) {
|
||||||
|
log, recorded := newObservedLogger()
|
||||||
|
gl := NewGormLogger(log, "warn")
|
||||||
|
|
||||||
|
// Info should be suppressed at warn level
|
||||||
|
gl.Info(context.Background(), "should be suppressed")
|
||||||
|
if len(recorded.All()) != 0 {
|
||||||
|
t.Fatal("info should be suppressed at warn level")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Warn should pass
|
||||||
|
gl.Warn(context.Background(), "should pass")
|
||||||
|
if len(recorded.All()) != 1 {
|
||||||
|
t.Fatal("warn should pass at warn level")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGormLogger_SilentSuppressesAll(t *testing.T) {
|
||||||
|
log, recorded := newObservedLogger()
|
||||||
|
gl := NewGormLogger(log, "silent")
|
||||||
|
|
||||||
|
gl.Info(context.Background(), "info msg")
|
||||||
|
gl.Warn(context.Background(), "warn msg")
|
||||||
|
gl.Error(context.Background(), "error msg")
|
||||||
|
|
||||||
|
if len(recorded.All()) != 0 {
|
||||||
|
t.Fatalf("silent mode should suppress all logs, got %d", len(recorded.All()))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGormLogger_LogMode_ReturnsNewInstance(t *testing.T) {
|
||||||
|
log, _ := newObservedLogger()
|
||||||
|
original := NewGormLogger(log, "warn").(*GormLogger)
|
||||||
|
|
||||||
|
modified := original.LogMode(gormlogger.Info)
|
||||||
|
|
||||||
|
// Must be a different instance
|
||||||
|
if modified == original {
|
||||||
|
t.Fatal("LogMode should return a new instance")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Original should be unchanged
|
||||||
|
if original.level != zapcore.WarnLevel {
|
||||||
|
t.Fatalf("original level should remain warn, got %v", original.level)
|
||||||
|
}
|
||||||
|
|
||||||
|
// New instance should have InfoLevel
|
||||||
|
mod := modified.(*GormLogger)
|
||||||
|
if mod.level != zapcore.InfoLevel {
|
||||||
|
t.Fatalf("modified level should be info, got %v", mod.level)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGormLogger_LogMode_Silent(t *testing.T) {
|
||||||
|
log, _ := newObservedLogger()
|
||||||
|
gl := NewGormLogger(log, "info").(*GormLogger)
|
||||||
|
|
||||||
|
silent := gl.LogMode(gormlogger.Silent).(*GormLogger)
|
||||||
|
if !silent.silent {
|
||||||
|
t.Fatal("LogMode(Silent) should set silent=true")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Original should not be silent
|
||||||
|
if gl.silent {
|
||||||
|
t.Fatal("original should not be affected")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGormLogger_LogMode_ErrorLevel(t *testing.T) {
|
||||||
|
log, _ := newObservedLogger()
|
||||||
|
gl := NewGormLogger(log, "info").(*GormLogger)
|
||||||
|
|
||||||
|
modified := gl.LogMode(gormlogger.Error).(*GormLogger)
|
||||||
|
if modified.level != zapcore.ErrorLevel {
|
||||||
|
t.Fatalf("expected ErrorLevel, got %v", modified.level)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGormLogger_LogMode_WarnLevel(t *testing.T) {
|
||||||
|
log, _ := newObservedLogger()
|
||||||
|
gl := NewGormLogger(log, "info").(*GormLogger)
|
||||||
|
|
||||||
|
modified := gl.LogMode(gormlogger.Warn).(*GormLogger)
|
||||||
|
if modified.level != zapcore.WarnLevel {
|
||||||
|
t.Fatalf("expected WarnLevel, got %v", modified.level)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGormLogger_Trace_WithError(t *testing.T) {
|
||||||
|
log, recorded := newObservedLogger()
|
||||||
|
gl := NewGormLogger(log, "warn")
|
||||||
|
|
||||||
|
begin := time.Now()
|
||||||
|
fc := func() (string, int64) { return "SELECT * FROM users", 0 }
|
||||||
|
err := errors.New("connection refused")
|
||||||
|
|
||||||
|
gl.Trace(context.Background(), begin, fc, err)
|
||||||
|
|
||||||
|
entries := recorded.All()
|
||||||
|
if len(entries) != 1 {
|
||||||
|
t.Fatalf("expected 1 entry, got %d", len(entries))
|
||||||
|
}
|
||||||
|
if entries[0].Level != zapcore.ErrorLevel {
|
||||||
|
t.Fatalf("expected ErrorLevel for real errors, got %v", entries[0].Level)
|
||||||
|
}
|
||||||
|
if entries[0].Message != "gorm query error" {
|
||||||
|
t.Fatalf("unexpected message: %q", entries[0].Message)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGormLogger_Trace_ErrRecordNotFound(t *testing.T) {
|
||||||
|
log, recorded := newObservedLogger()
|
||||||
|
gl := NewGormLogger(log, "info")
|
||||||
|
|
||||||
|
begin := time.Now()
|
||||||
|
fc := func() (string, int64) { return "SELECT * FROM users WHERE id = ?", 0 }
|
||||||
|
|
||||||
|
gl.Trace(context.Background(), begin, fc, gorm.ErrRecordNotFound)
|
||||||
|
|
||||||
|
// ErrRecordNotFound should NOT be logged as error
|
||||||
|
// At default level with no slow query, it should log as info
|
||||||
|
entries := recorded.All()
|
||||||
|
if len(entries) != 1 {
|
||||||
|
t.Fatalf("expected 1 entry, got %d", len(entries))
|
||||||
|
}
|
||||||
|
if entries[0].Level == zapcore.ErrorLevel {
|
||||||
|
t.Fatal("ErrRecordNotFound should NOT be logged at ErrorLevel")
|
||||||
|
}
|
||||||
|
if entries[0].Message != "gorm query" {
|
||||||
|
t.Fatalf("expected 'gorm query', got %q", entries[0].Message)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGormLogger_Trace_ErrRecordNotFound_SuppressedAtWarnLevel(t *testing.T) {
|
||||||
|
log, recorded := newObservedLogger()
|
||||||
|
gl := NewGormLogger(log, "warn")
|
||||||
|
|
||||||
|
begin := time.Now()
|
||||||
|
fc := func() (string, int64) { return "SELECT * FROM users WHERE id = ?", 0 }
|
||||||
|
|
||||||
|
gl.Trace(context.Background(), begin, fc, gorm.ErrRecordNotFound)
|
||||||
|
|
||||||
|
// ErrRecordNotFound is not a real error, so it falls through to the default path.
|
||||||
|
// At warn level, the default path (info-level) is suppressed.
|
||||||
|
entries := recorded.All()
|
||||||
|
if len(entries) != 0 {
|
||||||
|
t.Fatalf("expected 0 entries at warn level, got %d", len(entries))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGormLogger_Trace_SlowQuery(t *testing.T) {
|
||||||
|
log, recorded := newObservedLogger()
|
||||||
|
gl := NewGormLogger(log, "warn")
|
||||||
|
|
||||||
|
// Simulate a begin time far enough in the past to exceed the threshold
|
||||||
|
begin := time.Now().Add(-500 * time.Millisecond)
|
||||||
|
fc := func() (string, int64) { return "SELECT SLEEP(1)", 1 }
|
||||||
|
|
||||||
|
gl.Trace(context.Background(), begin, fc, nil)
|
||||||
|
|
||||||
|
entries := recorded.All()
|
||||||
|
if len(entries) != 1 {
|
||||||
|
t.Fatalf("expected 1 entry, got %d", len(entries))
|
||||||
|
}
|
||||||
|
if entries[0].Level != zapcore.WarnLevel {
|
||||||
|
t.Fatalf("expected WarnLevel for slow query, got %v", entries[0].Level)
|
||||||
|
}
|
||||||
|
if entries[0].Message != "gorm slow query" {
|
||||||
|
t.Fatalf("expected 'gorm slow query', got %q", entries[0].Message)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGormLogger_Trace_NormalQuery(t *testing.T) {
|
||||||
|
log, recorded := newObservedLogger()
|
||||||
|
gl := NewGormLogger(log, "info")
|
||||||
|
|
||||||
|
begin := time.Now()
|
||||||
|
fc := func() (string, int64) { return "SELECT 1", 1 }
|
||||||
|
|
||||||
|
gl.Trace(context.Background(), begin, fc, nil)
|
||||||
|
|
||||||
|
entries := recorded.All()
|
||||||
|
if len(entries) != 1 {
|
||||||
|
t.Fatalf("expected 1 entry, got %d", len(entries))
|
||||||
|
}
|
||||||
|
if entries[0].Level != zapcore.InfoLevel {
|
||||||
|
t.Fatalf("expected InfoLevel for normal query, got %v", entries[0].Level)
|
||||||
|
}
|
||||||
|
if entries[0].Message != "gorm query" {
|
||||||
|
t.Fatalf("expected 'gorm query', got %q", entries[0].Message)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGormLogger_Trace_NormalQuerySuppressedAtWarnLevel(t *testing.T) {
|
||||||
|
log, recorded := newObservedLogger()
|
||||||
|
gl := NewGormLogger(log, "warn")
|
||||||
|
|
||||||
|
begin := time.Now()
|
||||||
|
fc := func() (string, int64) { return "SELECT 1", 1 }
|
||||||
|
|
||||||
|
gl.Trace(context.Background(), begin, fc, nil)
|
||||||
|
|
||||||
|
// Normal queries are info-level, suppressed at warn
|
||||||
|
if len(recorded.All()) != 0 {
|
||||||
|
t.Fatal("normal query should be suppressed at warn level")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGormLogger_Trace_SilentSuppressesAll(t *testing.T) {
|
||||||
|
log, recorded := newObservedLogger()
|
||||||
|
gl := NewGormLogger(log, "silent")
|
||||||
|
|
||||||
|
begin := time.Now().Add(-500 * time.Millisecond)
|
||||||
|
fc := func() (string, int64) { return "SELECT SLEEP(1)", 1 }
|
||||||
|
err := errors.New("some error")
|
||||||
|
|
||||||
|
gl.Trace(context.Background(), begin, fc, err)
|
||||||
|
|
||||||
|
if len(recorded.All()) != 0 {
|
||||||
|
t.Fatal("silent mode should suppress even Trace errors")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGormLogger_Trace_StructuredFields(t *testing.T) {
|
||||||
|
log, recorded := newObservedLogger()
|
||||||
|
gl := NewGormLogger(log, "info")
|
||||||
|
|
||||||
|
begin := time.Now()
|
||||||
|
fc := func() (string, int64) { return "SELECT * FROM users", 42 }
|
||||||
|
|
||||||
|
gl.Trace(context.Background(), begin, fc, nil)
|
||||||
|
|
||||||
|
entries := recorded.All()
|
||||||
|
if len(entries) != 1 {
|
||||||
|
t.Fatalf("expected 1 entry, got %d", len(entries))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify structured fields
|
||||||
|
fields := entries[0].ContextMap()
|
||||||
|
if fields["sql"] != "SELECT * FROM users" {
|
||||||
|
t.Fatalf("expected sql field, got %v", fields["sql"])
|
||||||
|
}
|
||||||
|
if fields["rows"] != int64(42) {
|
||||||
|
t.Fatalf("expected rows=42, got %v", fields["rows"])
|
||||||
|
}
|
||||||
|
if _, ok := fields["elapsed"]; !ok {
|
||||||
|
t.Fatal("expected elapsed field")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGormLogger_Trace_ErrorFields(t *testing.T) {
|
||||||
|
log, recorded := newObservedLogger()
|
||||||
|
gl := NewGormLogger(log, "warn")
|
||||||
|
|
||||||
|
begin := time.Now()
|
||||||
|
fc := func() (string, int64) { return "INSERT INTO users", 0 }
|
||||||
|
err := errors.New("duplicate key")
|
||||||
|
|
||||||
|
gl.Trace(context.Background(), begin, fc, err)
|
||||||
|
|
||||||
|
entries := recorded.All()
|
||||||
|
if len(entries) != 1 {
|
||||||
|
t.Fatalf("expected 1 entry, got %d", len(entries))
|
||||||
|
}
|
||||||
|
|
||||||
|
fields := entries[0].ContextMap()
|
||||||
|
if fields["sql"] != "INSERT INTO users" {
|
||||||
|
t.Fatalf("expected sql field, got %v", fields["sql"])
|
||||||
|
}
|
||||||
|
if _, ok := fields["error"]; !ok {
|
||||||
|
t.Fatal("expected error field")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGormLogger_Trace_SlowQueryFields(t *testing.T) {
|
||||||
|
log, recorded := newObservedLogger()
|
||||||
|
gl := NewGormLogger(log, "warn")
|
||||||
|
|
||||||
|
begin := time.Now().Add(-500 * time.Millisecond)
|
||||||
|
fc := func() (string, int64) { return "SELECT * FROM large_table", 1000 }
|
||||||
|
|
||||||
|
gl.Trace(context.Background(), begin, fc, nil)
|
||||||
|
|
||||||
|
entries := recorded.All()
|
||||||
|
if len(entries) != 1 {
|
||||||
|
t.Fatalf("expected 1 entry, got %d", len(entries))
|
||||||
|
}
|
||||||
|
|
||||||
|
fields := entries[0].ContextMap()
|
||||||
|
if fields["sql"] != "SELECT * FROM large_table" {
|
||||||
|
t.Fatalf("expected sql field, got %v", fields["sql"])
|
||||||
|
}
|
||||||
|
if fields["rows"] != int64(1000) {
|
||||||
|
t.Fatalf("expected rows=1000, got %v", fields["rows"])
|
||||||
|
}
|
||||||
|
if _, ok := fields["elapsed_ms"]; !ok {
|
||||||
|
t.Fatal("expected elapsed_ms field for slow query")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestArgsToFields(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
args []interface{}
|
||||||
|
expected int
|
||||||
|
}{
|
||||||
|
{"empty", nil, 0},
|
||||||
|
{"single_pair", []interface{}{"key", "value"}, 1},
|
||||||
|
{"two_pairs", []interface{}{"a", 1, "b", 2}, 2},
|
||||||
|
{"odd_args_ignores_last", []interface{}{"key", "value", "orphan"}, 1},
|
||||||
|
{"non_string_key_ignored", []interface{}{123, "value"}, 0},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
fields := argsToFields(tt.args)
|
||||||
|
if len(fields) != tt.expected {
|
||||||
|
t.Fatalf("expected %d fields, got %d", tt.expected, len(fields))
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestArgsToFields_FieldValues(t *testing.T) {
|
||||||
|
fields := argsToFields([]interface{}{"name", "test", "count", 42})
|
||||||
|
if len(fields) != 2 {
|
||||||
|
t.Fatalf("expected 2 fields, got %d", len(fields))
|
||||||
|
}
|
||||||
|
if fields[0].Key != "name" {
|
||||||
|
t.Fatalf("expected key 'name', got %q", fields[0].Key)
|
||||||
|
}
|
||||||
|
if fields[1].Key != "count" {
|
||||||
|
t.Fatalf("expected key 'count', got %q", fields[1].Key)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseGormLevel(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
input string
|
||||||
|
expected zapcore.Level
|
||||||
|
}{
|
||||||
|
{"debug", zapcore.DebugLevel},
|
||||||
|
{"info", zapcore.InfoLevel},
|
||||||
|
{"warn", zapcore.WarnLevel},
|
||||||
|
{"error", zapcore.ErrorLevel},
|
||||||
|
{"", zapcore.WarnLevel},
|
||||||
|
{"silent", zapcore.WarnLevel},
|
||||||
|
{"invalid", zapcore.WarnLevel},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.input, func(t *testing.T) {
|
||||||
|
got := parseGormLevel(tt.input)
|
||||||
|
if got != tt.expected {
|
||||||
|
t.Fatalf("expected %v, got %v", tt.expected, got)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGormLogger_InfoWithMultipleFields(t *testing.T) {
|
||||||
|
log, recorded := newObservedLogger()
|
||||||
|
gl := NewGormLogger(log, "info")
|
||||||
|
|
||||||
|
gl.Info(context.Background(), "multi fields", "key1", "val1", "key2", 123, "key3", true)
|
||||||
|
|
||||||
|
entries := recorded.All()
|
||||||
|
if len(entries) != 1 {
|
||||||
|
t.Fatalf("expected 1 entry, got %d", len(entries))
|
||||||
|
}
|
||||||
|
|
||||||
|
fields := entries[0].ContextMap()
|
||||||
|
if fields["key1"] != "val1" {
|
||||||
|
t.Fatalf("expected key1=val1, got %v", fields["key1"])
|
||||||
|
}
|
||||||
|
if fields["key2"] != int64(123) {
|
||||||
|
t.Fatalf("expected key2=123, got %v", fields["key2"])
|
||||||
|
}
|
||||||
|
if fields["key3"] != true {
|
||||||
|
t.Fatalf("expected key3=true, got %v", fields["key3"])
|
||||||
|
}
|
||||||
|
}
|
||||||
93
internal/logger/logger.go
Normal file
93
internal/logger/logger.go
Normal file
@@ -0,0 +1,93 @@
|
|||||||
|
package logger
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
|
||||||
|
"gcy_hpc_server/internal/config"
|
||||||
|
|
||||||
|
"go.uber.org/zap"
|
||||||
|
"go.uber.org/zap/zapcore"
|
||||||
|
"gopkg.in/natefinch/lumberjack.v2"
|
||||||
|
)
|
||||||
|
|
||||||
|
func NewLogger(cfg config.LogConfig) (*zap.Logger, error) {
|
||||||
|
level := applyDefault(cfg.Level, "info")
|
||||||
|
|
||||||
|
var zapLevel zapcore.Level
|
||||||
|
if err := zapLevel.UnmarshalText([]byte(level)); err != nil {
|
||||||
|
return nil, fmt.Errorf("invalid log level %q: %w", level, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
encoding := applyDefault(cfg.Encoding, "json")
|
||||||
|
encoderConfig := zap.NewProductionEncoderConfig()
|
||||||
|
encoderConfig.TimeKey = "ts"
|
||||||
|
encoderConfig.EncodeTime = zapcore.ISO8601TimeEncoder
|
||||||
|
encoderConfig.EncodeLevel = zapcore.CapitalLevelEncoder
|
||||||
|
|
||||||
|
var encoder zapcore.Encoder
|
||||||
|
switch encoding {
|
||||||
|
case "console":
|
||||||
|
encoder = zapcore.NewConsoleEncoder(encoderConfig)
|
||||||
|
default:
|
||||||
|
encoder = zapcore.NewJSONEncoder(encoderConfig)
|
||||||
|
}
|
||||||
|
|
||||||
|
var syncers []zapcore.WriteSyncer
|
||||||
|
|
||||||
|
stdout := true
|
||||||
|
if cfg.OutputStdout != nil {
|
||||||
|
stdout = *cfg.OutputStdout
|
||||||
|
}
|
||||||
|
if stdout {
|
||||||
|
syncers = append(syncers, zapcore.AddSync(os.Stdout))
|
||||||
|
}
|
||||||
|
|
||||||
|
if cfg.FilePath != "" {
|
||||||
|
maxSize := applyDefaultInt(cfg.MaxSize, 100)
|
||||||
|
maxBackups := applyDefaultInt(cfg.MaxBackups, 5)
|
||||||
|
maxAge := applyDefaultInt(cfg.MaxAge, 30)
|
||||||
|
compress := cfg.Compress || cfg.MaxSize == 0 && cfg.MaxBackups == 0 && cfg.MaxAge == 0
|
||||||
|
|
||||||
|
lj := &lumberjack.Logger{
|
||||||
|
Filename: cfg.FilePath,
|
||||||
|
MaxSize: maxSize,
|
||||||
|
MaxBackups: maxBackups,
|
||||||
|
MaxAge: maxAge,
|
||||||
|
Compress: compress,
|
||||||
|
}
|
||||||
|
syncers = append(syncers, zapcore.AddSync(lj))
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(syncers) == 0 {
|
||||||
|
syncers = append(syncers, zapcore.AddSync(os.Stdout))
|
||||||
|
}
|
||||||
|
|
||||||
|
writeSyncer := syncers[0]
|
||||||
|
if len(syncers) > 1 {
|
||||||
|
writeSyncer = zapcore.NewMultiWriteSyncer(syncers...)
|
||||||
|
}
|
||||||
|
|
||||||
|
core := zapcore.NewCore(encoder, writeSyncer, zapLevel)
|
||||||
|
|
||||||
|
opts := []zap.Option{
|
||||||
|
zap.AddCaller(),
|
||||||
|
zap.AddStacktrace(zapcore.ErrorLevel),
|
||||||
|
}
|
||||||
|
|
||||||
|
return zap.New(core, opts...), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func applyDefault(val, def string) string {
|
||||||
|
if val == "" {
|
||||||
|
return def
|
||||||
|
}
|
||||||
|
return val
|
||||||
|
}
|
||||||
|
|
||||||
|
func applyDefaultInt(val, def int) int {
|
||||||
|
if val == 0 {
|
||||||
|
return def
|
||||||
|
}
|
||||||
|
return val
|
||||||
|
}
|
||||||
286
internal/logger/logger_test.go
Normal file
286
internal/logger/logger_test.go
Normal file
@@ -0,0 +1,286 @@
|
|||||||
|
package logger
|
||||||
|
|
||||||
|
import (
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"gcy_hpc_server/internal/config"
|
||||||
|
|
||||||
|
"go.uber.org/zap"
|
||||||
|
"go.uber.org/zap/zapcore"
|
||||||
|
"go.uber.org/zap/zaptest"
|
||||||
|
)
|
||||||
|
|
||||||
|
func ptrBool(v bool) *bool { return &v }
|
||||||
|
|
||||||
|
// TestNewLogger_JSONConfig creates a logger with JSON encoding and verifies
|
||||||
|
// that log entries are emitted successfully.
|
||||||
|
func TestNewLogger_JSONConfig(t *testing.T) {
|
||||||
|
cfg := config.LogConfig{
|
||||||
|
Level: "debug",
|
||||||
|
Encoding: "json",
|
||||||
|
OutputStdout: ptrBool(true),
|
||||||
|
}
|
||||||
|
|
||||||
|
log, err := NewLogger(cfg)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("NewLogger returned error: %v", err)
|
||||||
|
}
|
||||||
|
defer log.Sync()
|
||||||
|
|
||||||
|
// Should not panic when logging
|
||||||
|
log.Info("json logger test", zap.String("key", "value"))
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestNewLogger_ConsoleConfig creates a logger with console encoding.
|
||||||
|
func TestNewLogger_ConsoleConfig(t *testing.T) {
|
||||||
|
cfg := config.LogConfig{
|
||||||
|
Level: "info",
|
||||||
|
Encoding: "console",
|
||||||
|
OutputStdout: ptrBool(true),
|
||||||
|
}
|
||||||
|
|
||||||
|
log, err := NewLogger(cfg)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("NewLogger returned error: %v", err)
|
||||||
|
}
|
||||||
|
defer log.Sync()
|
||||||
|
|
||||||
|
log.Info("console logger test", zap.Int("num", 42))
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestNewLogger_InvalidLevel verifies that an invalid log level returns an error.
|
||||||
|
func TestNewLogger_InvalidLevel(t *testing.T) {
|
||||||
|
cfg := config.LogConfig{
|
||||||
|
Level: "bogus",
|
||||||
|
Encoding: "json",
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err := NewLogger(cfg)
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected error for invalid log level, got nil")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestNewLogger_EmptyConfig verifies defaults are applied when config is zero-value.
|
||||||
|
func TestNewLogger_EmptyConfig(t *testing.T) {
|
||||||
|
cfg := config.LogConfig{} // all zero values
|
||||||
|
|
||||||
|
log, err := NewLogger(cfg)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("NewLogger returned error: %v", err)
|
||||||
|
}
|
||||||
|
defer log.Sync()
|
||||||
|
|
||||||
|
log.Info("default config test")
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestNewLogger_FileOutput verifies that file output with rotation config works.
|
||||||
|
func TestNewLogger_FileOutput(t *testing.T) {
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
logFile := filepath.Join(tmpDir, "test.log")
|
||||||
|
|
||||||
|
cfg := config.LogConfig{
|
||||||
|
Level: "info",
|
||||||
|
Encoding: "json",
|
||||||
|
FilePath: logFile,
|
||||||
|
MaxSize: 10,
|
||||||
|
MaxBackups: 3,
|
||||||
|
MaxAge: 7,
|
||||||
|
Compress: true,
|
||||||
|
}
|
||||||
|
|
||||||
|
log, err := NewLogger(cfg)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("NewLogger returned error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Info("file output test", zap.String("msg", "hello"))
|
||||||
|
log.Sync()
|
||||||
|
|
||||||
|
data, err := os.ReadFile(logFile)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to read log file: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(data) == 0 {
|
||||||
|
t.Fatal("log file is empty, expected output")
|
||||||
|
}
|
||||||
|
|
||||||
|
if !strings.Contains(string(data), "file output test") {
|
||||||
|
t.Fatalf("log file content does not contain expected message;\ngot: %s", string(data))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestNewLogger_MultiWriter verifies that both stdout and file output work together.
|
||||||
|
func TestNewLogger_MultiWriter(t *testing.T) {
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
logFile := filepath.Join(tmpDir, "multi.log")
|
||||||
|
|
||||||
|
cfg := config.LogConfig{
|
||||||
|
Level: "info",
|
||||||
|
Encoding: "json",
|
||||||
|
OutputStdout: ptrBool(true),
|
||||||
|
FilePath: logFile,
|
||||||
|
}
|
||||||
|
|
||||||
|
log, err := NewLogger(cfg)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("NewLogger returned error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Info("multi writer test", zap.String("writer", "both"))
|
||||||
|
log.Sync()
|
||||||
|
|
||||||
|
data, err := os.ReadFile(logFile)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to read log file: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !strings.Contains(string(data), "multi writer test") {
|
||||||
|
t.Fatalf("log file content does not contain expected message;\ngot: %s", string(data))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestNewLogger_Observer verifies actual log output content using zaptest.
|
||||||
|
func TestNewLogger_Observer(t *testing.T) {
|
||||||
|
// Use zaptest.NewLogger to capture logs in test output
|
||||||
|
log := zaptest.NewLogger(t,
|
||||||
|
zaptest.WrapOptions(zap.AddCaller(), zap.AddStacktrace(zapcore.ErrorLevel)),
|
||||||
|
zaptest.Level(zapcore.DebugLevel),
|
||||||
|
)
|
||||||
|
|
||||||
|
// These should all succeed without panicking
|
||||||
|
log.Debug("debug msg", zap.String("k", "v"))
|
||||||
|
log.Info("info msg", zap.Int("n", 1))
|
||||||
|
log.Warn("warn msg")
|
||||||
|
log.Error("error msg")
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestNewLogger_AllLevels verifies all valid log levels parse correctly.
|
||||||
|
func TestNewLogger_AllLevels(t *testing.T) {
|
||||||
|
levels := []string{"debug", "info", "warn", "error", "dpanic", "panic", "fatal"}
|
||||||
|
for _, level := range levels {
|
||||||
|
t.Run(level, func(t *testing.T) {
|
||||||
|
cfg := config.LogConfig{
|
||||||
|
Level: level,
|
||||||
|
Encoding: "json",
|
||||||
|
OutputStdout: ptrBool(true),
|
||||||
|
}
|
||||||
|
log, err := NewLogger(cfg)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("level %q: NewLogger returned error: %v", level, err)
|
||||||
|
}
|
||||||
|
log.Sync()
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestNewLogger_InvalidEncoding falls back gracefully — the factory should
|
||||||
|
// treat an unrecognized encoding as an error or default to JSON.
|
||||||
|
func TestNewLogger_InvalidEncoding(t *testing.T) {
|
||||||
|
cfg := config.LogConfig{
|
||||||
|
Level: "info",
|
||||||
|
Encoding: "xml",
|
||||||
|
OutputStdout: ptrBool(true),
|
||||||
|
}
|
||||||
|
|
||||||
|
// The implementation should default to JSON for unknown encoding.
|
||||||
|
log, err := NewLogger(cfg)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error for invalid encoding: %v", err)
|
||||||
|
}
|
||||||
|
defer log.Sync()
|
||||||
|
|
||||||
|
log.Info("invalid encoding test")
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestNewLogger_DefaultRotation verifies rotation defaults are applied.
|
||||||
|
func TestNewLogger_DefaultRotation(t *testing.T) {
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
logFile := filepath.Join(tmpDir, "rotation.log")
|
||||||
|
|
||||||
|
cfg := config.LogConfig{
|
||||||
|
Level: "info",
|
||||||
|
Encoding: "json",
|
||||||
|
FilePath: logFile,
|
||||||
|
// MaxSize, MaxBackups, MaxAge, Compress all zero → defaults apply
|
||||||
|
}
|
||||||
|
|
||||||
|
log, err := NewLogger(cfg)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("NewLogger returned error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Info("rotation defaults test")
|
||||||
|
log.Sync()
|
||||||
|
|
||||||
|
data, err := os.ReadFile(logFile)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to read log file: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(data) == 0 {
|
||||||
|
t.Fatal("log file is empty")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNewLogger_OutputStdoutNil(t *testing.T) {
|
||||||
|
cfg := config.LogConfig{
|
||||||
|
Level: "info",
|
||||||
|
Encoding: "json",
|
||||||
|
}
|
||||||
|
|
||||||
|
log, err := NewLogger(cfg)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("NewLogger returned error: %v", err)
|
||||||
|
}
|
||||||
|
defer log.Sync()
|
||||||
|
|
||||||
|
log.Info("default stdout test")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNewLogger_OutputStdoutFalseWithFile(t *testing.T) {
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
logFile := filepath.Join(tmpDir, "nostdout.log")
|
||||||
|
|
||||||
|
cfg := config.LogConfig{
|
||||||
|
Level: "info",
|
||||||
|
Encoding: "json",
|
||||||
|
OutputStdout: ptrBool(false),
|
||||||
|
FilePath: logFile,
|
||||||
|
}
|
||||||
|
|
||||||
|
log, err := NewLogger(cfg)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("NewLogger returned error: %v", err)
|
||||||
|
}
|
||||||
|
log.Info("file only test")
|
||||||
|
log.Sync()
|
||||||
|
|
||||||
|
data, err := os.ReadFile(logFile)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to read log file: %v", err)
|
||||||
|
}
|
||||||
|
if !strings.Contains(string(data), "file only test") {
|
||||||
|
t.Fatalf("log file content does not contain expected message;\ngot: %s", string(data))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNewLogger_OutputStdoutFalseFallback(t *testing.T) {
|
||||||
|
cfg := config.LogConfig{
|
||||||
|
Level: "info",
|
||||||
|
Encoding: "json",
|
||||||
|
OutputStdout: ptrBool(false),
|
||||||
|
}
|
||||||
|
|
||||||
|
log, err := NewLogger(cfg)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("NewLogger returned error: %v", err)
|
||||||
|
}
|
||||||
|
defer log.Sync()
|
||||||
|
|
||||||
|
log.Info("fallback stdout test")
|
||||||
|
}
|
||||||
25
internal/middleware/logger.go
Normal file
25
internal/middleware/logger.go
Normal file
@@ -0,0 +1,25 @@
|
|||||||
|
package middleware
|
||||||
|
|
||||||
|
import (
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"go.uber.org/zap"
|
||||||
|
)
|
||||||
|
|
||||||
|
// RequestLogger returns a Gin middleware that logs each request using zap.
|
||||||
|
func RequestLogger(logger *zap.Logger) gin.HandlerFunc {
|
||||||
|
return func(c *gin.Context) {
|
||||||
|
start := time.Now()
|
||||||
|
|
||||||
|
c.Next()
|
||||||
|
|
||||||
|
logger.Info("request",
|
||||||
|
zap.String("method", c.Request.Method),
|
||||||
|
zap.String("path", c.Request.URL.Path),
|
||||||
|
zap.Int("status", c.Writer.Status()),
|
||||||
|
zap.Duration("latency", time.Since(start)),
|
||||||
|
zap.String("client_ip", c.ClientIP()),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
23
internal/model/cluster.go
Normal file
23
internal/model/cluster.go
Normal file
@@ -0,0 +1,23 @@
|
|||||||
|
package model
|
||||||
|
|
||||||
|
// NodeResponse is the simplified API response for a node.
|
||||||
|
type NodeResponse struct {
|
||||||
|
Name string `json:"name"`
|
||||||
|
State []string `json:"state"`
|
||||||
|
CPUs int32 `json:"cpus"`
|
||||||
|
RealMemory int64 `json:"real_memory"`
|
||||||
|
AllocMem int64 `json:"alloc_memory,omitempty"`
|
||||||
|
Arch string `json:"architecture,omitempty"`
|
||||||
|
OS string `json:"operating_system,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// PartitionResponse is the simplified API response for a partition.
|
||||||
|
type PartitionResponse struct {
|
||||||
|
Name string `json:"name"`
|
||||||
|
State []string `json:"state"`
|
||||||
|
Nodes string `json:"nodes,omitempty"`
|
||||||
|
TotalCPUs int32 `json:"total_cpus,omitempty"`
|
||||||
|
TotalNodes int32 `json:"total_nodes,omitempty"`
|
||||||
|
MaxTime string `json:"max_time,omitempty"`
|
||||||
|
Default bool `json:"default,omitempty"`
|
||||||
|
}
|
||||||
47
internal/model/job.go
Normal file
47
internal/model/job.go
Normal file
@@ -0,0 +1,47 @@
|
|||||||
|
package model
|
||||||
|
|
||||||
|
// SubmitJobRequest is the API request for submitting a job.
|
||||||
|
type SubmitJobRequest struct {
|
||||||
|
Script string `json:"script" binding:"required"`
|
||||||
|
Partition string `json:"partition,omitempty"`
|
||||||
|
QOS string `json:"qos,omitempty"`
|
||||||
|
CPUs int32 `json:"cpus,omitempty"`
|
||||||
|
Memory string `json:"memory,omitempty"`
|
||||||
|
TimeLimit string `json:"time_limit,omitempty"`
|
||||||
|
JobName string `json:"job_name,omitempty"`
|
||||||
|
Environment map[string]string `json:"environment,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// JobResponse is the simplified API response for a job.
|
||||||
|
type JobResponse struct {
|
||||||
|
JobID int32 `json:"job_id"`
|
||||||
|
Name string `json:"name"`
|
||||||
|
State []string `json:"job_state"`
|
||||||
|
Partition string `json:"partition"`
|
||||||
|
SubmitTime *int64 `json:"submit_time,omitempty"`
|
||||||
|
StartTime *int64 `json:"start_time,omitempty"`
|
||||||
|
EndTime *int64 `json:"end_time,omitempty"`
|
||||||
|
ExitCode *int32 `json:"exit_code,omitempty"`
|
||||||
|
Nodes string `json:"nodes,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// JobListResponse is the paginated response for job listings.
|
||||||
|
type JobListResponse struct {
|
||||||
|
Jobs []JobResponse `json:"jobs"`
|
||||||
|
Total int `json:"total"`
|
||||||
|
Page int `json:"page"`
|
||||||
|
PageSize int `json:"page_size"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// JobHistoryQuery contains query parameters for job history.
|
||||||
|
type JobHistoryQuery struct {
|
||||||
|
Users string `form:"users" json:"users,omitempty"`
|
||||||
|
StartTime string `form:"start_time" json:"start_time,omitempty"`
|
||||||
|
EndTime string `form:"end_time" json:"end_time,omitempty"`
|
||||||
|
Account string `form:"account" json:"account,omitempty"`
|
||||||
|
Partition string `form:"partition" json:"partition,omitempty"`
|
||||||
|
State string `form:"state" json:"state,omitempty"`
|
||||||
|
JobName string `form:"job_name" json:"job_name,omitempty"`
|
||||||
|
Page int `form:"page,default=1" json:"page,omitempty"`
|
||||||
|
PageSize int `form:"page_size,default=20" json:"page_size,omitempty"`
|
||||||
|
}
|
||||||
45
internal/model/template.go
Normal file
45
internal/model/template.go
Normal file
@@ -0,0 +1,45 @@
|
|||||||
|
package model
|
||||||
|
|
||||||
|
import "time"
|
||||||
|
|
||||||
|
// JobTemplate represents a saved job template.
|
||||||
|
type JobTemplate struct {
|
||||||
|
ID int64 `json:"id" gorm:"primaryKey;autoIncrement"`
|
||||||
|
Name string `json:"name" gorm:"uniqueIndex;size:255;not null"`
|
||||||
|
Description string `json:"description,omitempty" gorm:"type:text"`
|
||||||
|
Script string `json:"script" gorm:"type:text;not null"`
|
||||||
|
Partition string `json:"partition,omitempty" gorm:"size:255"`
|
||||||
|
QOS string `json:"qos,omitempty" gorm:"column:qos;size:255"`
|
||||||
|
CPUs int `json:"cpus,omitempty" gorm:"column:cpus"`
|
||||||
|
Memory string `json:"memory,omitempty" gorm:"size:50"`
|
||||||
|
TimeLimit string `json:"time_limit,omitempty" gorm:"column:time_limit;size:50"`
|
||||||
|
CreatedAt time.Time `json:"created_at"`
|
||||||
|
UpdatedAt time.Time `json:"updated_at"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// TableName specifies the database table name for GORM.
|
||||||
|
func (JobTemplate) TableName() string { return "job_templates" }
|
||||||
|
|
||||||
|
// CreateTemplateRequest is the API request for creating a template.
|
||||||
|
type CreateTemplateRequest struct {
|
||||||
|
Name string `json:"name" binding:"required"`
|
||||||
|
Description string `json:"description,omitempty"`
|
||||||
|
Script string `json:"script" binding:"required"`
|
||||||
|
Partition string `json:"partition,omitempty"`
|
||||||
|
QOS string `json:"qos,omitempty"`
|
||||||
|
CPUs int `json:"cpus,omitempty"`
|
||||||
|
Memory string `json:"memory,omitempty"`
|
||||||
|
TimeLimit string `json:"time_limit,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateTemplateRequest is the API request for updating a template.
|
||||||
|
type UpdateTemplateRequest struct {
|
||||||
|
Name string `json:"name,omitempty"`
|
||||||
|
Description string `json:"description,omitempty"`
|
||||||
|
Script string `json:"script,omitempty"`
|
||||||
|
Partition string `json:"partition,omitempty"`
|
||||||
|
QOS string `json:"qos,omitempty"`
|
||||||
|
CPUs int `json:"cpus,omitempty"`
|
||||||
|
Memory string `json:"memory,omitempty"`
|
||||||
|
TimeLimit string `json:"time_limit,omitempty"`
|
||||||
|
}
|
||||||
44
internal/server/response.go
Normal file
44
internal/server/response.go
Normal file
@@ -0,0 +1,44 @@
|
|||||||
|
package server
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
)
|
||||||
|
|
||||||
|
// APIResponse represents the unified JSON structure for all API responses.
|
||||||
|
type APIResponse struct {
|
||||||
|
Success bool `json:"success"`
|
||||||
|
Data interface{} `json:"data,omitempty"`
|
||||||
|
Error string `json:"error,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// OK responds with 200 and success data.
|
||||||
|
func OK(c *gin.Context, data interface{}) {
|
||||||
|
c.JSON(http.StatusOK, APIResponse{Success: true, Data: data})
|
||||||
|
}
|
||||||
|
|
||||||
|
// Created responds with 201 and success data.
|
||||||
|
func Created(c *gin.Context, data interface{}) {
|
||||||
|
c.JSON(http.StatusCreated, APIResponse{Success: true, Data: data})
|
||||||
|
}
|
||||||
|
|
||||||
|
// BadRequest responds with 400 and an error message.
|
||||||
|
func BadRequest(c *gin.Context, msg string) {
|
||||||
|
c.JSON(http.StatusBadRequest, APIResponse{Success: false, Error: msg})
|
||||||
|
}
|
||||||
|
|
||||||
|
// NotFound responds with 404 and an error message.
|
||||||
|
func NotFound(c *gin.Context, msg string) {
|
||||||
|
c.JSON(http.StatusNotFound, APIResponse{Success: false, Error: msg})
|
||||||
|
}
|
||||||
|
|
||||||
|
// InternalError responds with 500 and an error message.
|
||||||
|
func InternalError(c *gin.Context, msg string) {
|
||||||
|
c.JSON(http.StatusInternalServerError, APIResponse{Success: false, Error: msg})
|
||||||
|
}
|
||||||
|
|
||||||
|
// ErrorWithStatus responds with a custom status code and an error message.
|
||||||
|
func ErrorWithStatus(c *gin.Context, code int, msg string) {
|
||||||
|
c.JSON(code, APIResponse{Success: false, Error: msg})
|
||||||
|
}
|
||||||
116
internal/server/response_test.go
Normal file
116
internal/server/response_test.go
Normal file
@@ -0,0 +1,116 @@
|
|||||||
|
package server
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
)
|
||||||
|
|
||||||
|
func setupTestContext() (*httptest.ResponseRecorder, *gin.Context) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(w)
|
||||||
|
return w, c
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseResponse(t *testing.T, w *httptest.ResponseRecorder) APIResponse {
|
||||||
|
t.Helper()
|
||||||
|
var resp APIResponse
|
||||||
|
if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
|
||||||
|
t.Fatalf("failed to parse response body: %v", err)
|
||||||
|
}
|
||||||
|
return resp
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestOK(t *testing.T) {
|
||||||
|
w, c := setupTestContext()
|
||||||
|
OK(c, map[string]string{"msg": "hello"})
|
||||||
|
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Fatalf("expected 200, got %d", w.Code)
|
||||||
|
}
|
||||||
|
resp := parseResponse(t, w)
|
||||||
|
if !resp.Success {
|
||||||
|
t.Fatal("expected success=true")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCreated(t *testing.T) {
|
||||||
|
w, c := setupTestContext()
|
||||||
|
Created(c, map[string]int{"id": 1})
|
||||||
|
|
||||||
|
if w.Code != http.StatusCreated {
|
||||||
|
t.Fatalf("expected 201, got %d", w.Code)
|
||||||
|
}
|
||||||
|
resp := parseResponse(t, w)
|
||||||
|
if !resp.Success {
|
||||||
|
t.Fatal("expected success=true")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBadRequest(t *testing.T) {
|
||||||
|
w, c := setupTestContext()
|
||||||
|
BadRequest(c, "invalid input")
|
||||||
|
|
||||||
|
if w.Code != http.StatusBadRequest {
|
||||||
|
t.Fatalf("expected 400, got %d", w.Code)
|
||||||
|
}
|
||||||
|
resp := parseResponse(t, w)
|
||||||
|
if resp.Success {
|
||||||
|
t.Fatal("expected success=false")
|
||||||
|
}
|
||||||
|
if resp.Error != "invalid input" {
|
||||||
|
t.Fatalf("expected error 'invalid input', got '%s'", resp.Error)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNotFound(t *testing.T) {
|
||||||
|
w, c := setupTestContext()
|
||||||
|
NotFound(c, "resource missing")
|
||||||
|
|
||||||
|
if w.Code != http.StatusNotFound {
|
||||||
|
t.Fatalf("expected 404, got %d", w.Code)
|
||||||
|
}
|
||||||
|
resp := parseResponse(t, w)
|
||||||
|
if resp.Success {
|
||||||
|
t.Fatal("expected success=false")
|
||||||
|
}
|
||||||
|
if resp.Error != "resource missing" {
|
||||||
|
t.Fatalf("expected error 'resource missing', got '%s'", resp.Error)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestInternalError(t *testing.T) {
|
||||||
|
w, c := setupTestContext()
|
||||||
|
InternalError(c, "something broke")
|
||||||
|
|
||||||
|
if w.Code != http.StatusInternalServerError {
|
||||||
|
t.Fatalf("expected 500, got %d", w.Code)
|
||||||
|
}
|
||||||
|
resp := parseResponse(t, w)
|
||||||
|
if resp.Success {
|
||||||
|
t.Fatal("expected success=false")
|
||||||
|
}
|
||||||
|
if resp.Error != "something broke" {
|
||||||
|
t.Fatalf("expected error 'something broke', got '%s'", resp.Error)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestErrorWithStatus(t *testing.T) {
|
||||||
|
w, c := setupTestContext()
|
||||||
|
ErrorWithStatus(c, http.StatusConflict, "already exists")
|
||||||
|
|
||||||
|
if w.Code != http.StatusConflict {
|
||||||
|
t.Fatalf("expected 409, got %d", w.Code)
|
||||||
|
}
|
||||||
|
resp := parseResponse(t, w)
|
||||||
|
if resp.Success {
|
||||||
|
t.Fatal("expected success=false")
|
||||||
|
}
|
||||||
|
if resp.Error != "already exists" {
|
||||||
|
t.Fatalf("expected error 'already exists', got '%s'", resp.Error)
|
||||||
|
}
|
||||||
|
}
|
||||||
111
internal/server/server.go
Normal file
111
internal/server/server.go
Normal file
@@ -0,0 +1,111 @@
|
|||||||
|
package server
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
|
||||||
|
"gcy_hpc_server/internal/middleware"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"go.uber.org/zap"
|
||||||
|
)
|
||||||
|
|
||||||
|
type JobHandler interface {
|
||||||
|
SubmitJob(c *gin.Context)
|
||||||
|
GetJobs(c *gin.Context)
|
||||||
|
GetJobHistory(c *gin.Context)
|
||||||
|
GetJob(c *gin.Context)
|
||||||
|
CancelJob(c *gin.Context)
|
||||||
|
}
|
||||||
|
|
||||||
|
type ClusterHandler interface {
|
||||||
|
GetNodes(c *gin.Context)
|
||||||
|
GetNode(c *gin.Context)
|
||||||
|
GetPartitions(c *gin.Context)
|
||||||
|
GetPartition(c *gin.Context)
|
||||||
|
GetDiag(c *gin.Context)
|
||||||
|
}
|
||||||
|
|
||||||
|
type TemplateHandler interface {
|
||||||
|
ListTemplates(c *gin.Context)
|
||||||
|
CreateTemplate(c *gin.Context)
|
||||||
|
GetTemplate(c *gin.Context)
|
||||||
|
UpdateTemplate(c *gin.Context)
|
||||||
|
DeleteTemplate(c *gin.Context)
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewRouter creates a Gin engine with all API v1 routes registered with real handlers.
|
||||||
|
func NewRouter(jobH JobHandler, clusterH ClusterHandler, templateH TemplateHandler, logger *zap.Logger) *gin.Engine {
|
||||||
|
gin.SetMode(gin.ReleaseMode)
|
||||||
|
r := gin.New()
|
||||||
|
r.Use(gin.Recovery())
|
||||||
|
if logger != nil {
|
||||||
|
r.Use(middleware.RequestLogger(logger))
|
||||||
|
}
|
||||||
|
|
||||||
|
v1 := r.Group("/api/v1")
|
||||||
|
|
||||||
|
jobs := v1.Group("/jobs")
|
||||||
|
jobs.POST("/submit", jobH.SubmitJob)
|
||||||
|
jobs.GET("", jobH.GetJobs)
|
||||||
|
jobs.GET("/history", jobH.GetJobHistory)
|
||||||
|
jobs.GET("/:id", jobH.GetJob)
|
||||||
|
jobs.DELETE("/:id", jobH.CancelJob)
|
||||||
|
|
||||||
|
v1.GET("/nodes", clusterH.GetNodes)
|
||||||
|
v1.GET("/nodes/:name", clusterH.GetNode)
|
||||||
|
|
||||||
|
v1.GET("/partitions", clusterH.GetPartitions)
|
||||||
|
v1.GET("/partitions/:name", clusterH.GetPartition)
|
||||||
|
|
||||||
|
v1.GET("/diag", clusterH.GetDiag)
|
||||||
|
|
||||||
|
templates := v1.Group("/templates")
|
||||||
|
templates.GET("", templateH.ListTemplates)
|
||||||
|
templates.POST("", templateH.CreateTemplate)
|
||||||
|
templates.GET("/:id", templateH.GetTemplate)
|
||||||
|
templates.PUT("/:id", templateH.UpdateTemplate)
|
||||||
|
templates.DELETE("/:id", templateH.DeleteTemplate)
|
||||||
|
|
||||||
|
return r
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewTestRouter creates a router for testing without real handlers.
|
||||||
|
func NewTestRouter() *gin.Engine {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
r := gin.New()
|
||||||
|
r.Use(gin.Recovery())
|
||||||
|
v1 := r.Group("/api/v1")
|
||||||
|
registerPlaceholderRoutes(v1)
|
||||||
|
return r
|
||||||
|
}
|
||||||
|
|
||||||
|
func registerPlaceholderRoutes(v1 *gin.RouterGroup) {
|
||||||
|
jobs := v1.Group("/jobs")
|
||||||
|
jobs.POST("/submit", notImplemented)
|
||||||
|
jobs.GET("", notImplemented)
|
||||||
|
jobs.GET("/history", notImplemented)
|
||||||
|
jobs.GET("/:id", notImplemented)
|
||||||
|
jobs.DELETE("/:id", notImplemented)
|
||||||
|
|
||||||
|
v1.GET("/nodes", notImplemented)
|
||||||
|
v1.GET("/nodes/:name", notImplemented)
|
||||||
|
|
||||||
|
v1.GET("/partitions", notImplemented)
|
||||||
|
v1.GET("/partitions/:name", notImplemented)
|
||||||
|
|
||||||
|
v1.GET("/diag", notImplemented)
|
||||||
|
|
||||||
|
templates := v1.Group("/templates")
|
||||||
|
templates.GET("", notImplemented)
|
||||||
|
templates.POST("", notImplemented)
|
||||||
|
templates.GET("/:id", notImplemented)
|
||||||
|
templates.PUT("/:id", notImplemented)
|
||||||
|
templates.DELETE("/:id", notImplemented)
|
||||||
|
}
|
||||||
|
|
||||||
|
func notImplemented(c *gin.Context) {
|
||||||
|
c.JSON(http.StatusNotImplemented, APIResponse{
|
||||||
|
Success: false,
|
||||||
|
Error: "not implemented",
|
||||||
|
})
|
||||||
|
}
|
||||||
108
internal/server/server_test.go
Normal file
108
internal/server/server_test.go
Normal file
@@ -0,0 +1,108 @@
|
|||||||
|
package server
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestAllRoutesRegistered(t *testing.T) {
|
||||||
|
r := NewTestRouter()
|
||||||
|
routes := r.Routes()
|
||||||
|
|
||||||
|
expected := []struct {
|
||||||
|
method string
|
||||||
|
path string
|
||||||
|
}{
|
||||||
|
{"POST", "/api/v1/jobs/submit"},
|
||||||
|
{"GET", "/api/v1/jobs"},
|
||||||
|
{"GET", "/api/v1/jobs/history"},
|
||||||
|
{"GET", "/api/v1/jobs/:id"},
|
||||||
|
{"DELETE", "/api/v1/jobs/:id"},
|
||||||
|
{"GET", "/api/v1/nodes"},
|
||||||
|
{"GET", "/api/v1/nodes/:name"},
|
||||||
|
{"GET", "/api/v1/partitions"},
|
||||||
|
{"GET", "/api/v1/partitions/:name"},
|
||||||
|
{"GET", "/api/v1/diag"},
|
||||||
|
{"GET", "/api/v1/templates"},
|
||||||
|
{"POST", "/api/v1/templates"},
|
||||||
|
{"GET", "/api/v1/templates/:id"},
|
||||||
|
{"PUT", "/api/v1/templates/:id"},
|
||||||
|
{"DELETE", "/api/v1/templates/:id"},
|
||||||
|
}
|
||||||
|
|
||||||
|
routeMap := map[string]bool{}
|
||||||
|
for _, route := range routes {
|
||||||
|
key := route.Method + " " + route.Path
|
||||||
|
routeMap[key] = true
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, exp := range expected {
|
||||||
|
key := exp.method + " " + exp.path
|
||||||
|
if !routeMap[key] {
|
||||||
|
t.Errorf("missing route: %s", key)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(routes) < len(expected) {
|
||||||
|
t.Errorf("expected at least %d routes, got %d", len(expected), len(routes))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUnregisteredPathReturns404(t *testing.T) {
|
||||||
|
r := NewTestRouter()
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
req, _ := http.NewRequest(http.MethodGet, "/api/v1/nonexistent", nil)
|
||||||
|
r.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
if w.Code != http.StatusNotFound {
|
||||||
|
t.Fatalf("expected 404 for unregistered path, got %d", w.Code)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRegisteredPathReturns501(t *testing.T) {
|
||||||
|
r := NewTestRouter()
|
||||||
|
|
||||||
|
endpoints := []struct {
|
||||||
|
method string
|
||||||
|
path string
|
||||||
|
}{
|
||||||
|
{"GET", "/api/v1/jobs"},
|
||||||
|
{"GET", "/api/v1/nodes"},
|
||||||
|
{"GET", "/api/v1/partitions"},
|
||||||
|
{"GET", "/api/v1/diag"},
|
||||||
|
{"GET", "/api/v1/templates"},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, ep := range endpoints {
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
req, _ := http.NewRequest(ep.method, ep.path, nil)
|
||||||
|
r.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
if w.Code != http.StatusNotImplemented {
|
||||||
|
t.Fatalf("%s %s: expected 501, got %d", ep.method, ep.path, w.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
var resp APIResponse
|
||||||
|
if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
|
||||||
|
t.Fatalf("failed to parse response: %v", err)
|
||||||
|
}
|
||||||
|
if resp.Success {
|
||||||
|
t.Fatal("expected success=false")
|
||||||
|
}
|
||||||
|
if resp.Error != "not implemented" {
|
||||||
|
t.Fatalf("expected error 'not implemented', got '%s'", resp.Error)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRouterUsesGinMode(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
r := NewTestRouter()
|
||||||
|
if r == nil {
|
||||||
|
t.Fatal("NewRouter returned nil")
|
||||||
|
}
|
||||||
|
}
|
||||||
167
internal/service/cluster_service.go
Normal file
167
internal/service/cluster_service.go
Normal file
@@ -0,0 +1,167 @@
|
|||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"strconv"
|
||||||
|
|
||||||
|
"gcy_hpc_server/internal/model"
|
||||||
|
"gcy_hpc_server/internal/slurm"
|
||||||
|
|
||||||
|
"go.uber.org/zap"
|
||||||
|
)
|
||||||
|
|
||||||
|
func derefStr(s *string) string {
|
||||||
|
if s == nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
return *s
|
||||||
|
}
|
||||||
|
|
||||||
|
func derefInt32(i *int32) int32 {
|
||||||
|
if i == nil {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
return *i
|
||||||
|
}
|
||||||
|
|
||||||
|
func derefInt64(i *int64) int64 {
|
||||||
|
if i == nil {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
return *i
|
||||||
|
}
|
||||||
|
|
||||||
|
func uint32NoValString(v *slurm.Uint32NoVal) string {
|
||||||
|
if v == nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
if v.Infinite != nil && *v.Infinite {
|
||||||
|
return "UNLIMITED"
|
||||||
|
}
|
||||||
|
if v.Number != nil {
|
||||||
|
return strconv.FormatInt(*v.Number, 10)
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
type ClusterService struct {
|
||||||
|
client *slurm.Client
|
||||||
|
logger *zap.Logger
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewClusterService(client *slurm.Client, logger *zap.Logger) *ClusterService {
|
||||||
|
return &ClusterService{client: client, logger: logger}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *ClusterService) GetNodes(ctx context.Context) ([]model.NodeResponse, error) {
|
||||||
|
resp, _, err := s.client.Nodes.GetNodes(ctx, nil)
|
||||||
|
if err != nil {
|
||||||
|
s.logger.Error("failed to get nodes", zap.Error(err))
|
||||||
|
return nil, fmt.Errorf("get nodes: %w", err)
|
||||||
|
}
|
||||||
|
if resp.Nodes == nil {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
result := make([]model.NodeResponse, 0, len(*resp.Nodes))
|
||||||
|
for _, n := range *resp.Nodes {
|
||||||
|
result = append(result, mapNode(n))
|
||||||
|
}
|
||||||
|
return result, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *ClusterService) GetNode(ctx context.Context, name string) (*model.NodeResponse, error) {
|
||||||
|
resp, _, err := s.client.Nodes.GetNode(ctx, name, nil)
|
||||||
|
if err != nil {
|
||||||
|
s.logger.Error("failed to get node", zap.String("name", name), zap.Error(err))
|
||||||
|
return nil, fmt.Errorf("get node %s: %w", name, err)
|
||||||
|
}
|
||||||
|
if resp.Nodes == nil || len(*resp.Nodes) == 0 {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
n := (*resp.Nodes)[0]
|
||||||
|
mapped := mapNode(n)
|
||||||
|
return &mapped, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *ClusterService) GetPartitions(ctx context.Context) ([]model.PartitionResponse, error) {
|
||||||
|
resp, _, err := s.client.Partitions.GetPartitions(ctx, nil)
|
||||||
|
if err != nil {
|
||||||
|
s.logger.Error("failed to get partitions", zap.Error(err))
|
||||||
|
return nil, fmt.Errorf("get partitions: %w", err)
|
||||||
|
}
|
||||||
|
if resp.Partitions == nil {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
result := make([]model.PartitionResponse, 0, len(*resp.Partitions))
|
||||||
|
for _, pi := range *resp.Partitions {
|
||||||
|
result = append(result, mapPartition(pi))
|
||||||
|
}
|
||||||
|
return result, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *ClusterService) GetPartition(ctx context.Context, name string) (*model.PartitionResponse, error) {
|
||||||
|
resp, _, err := s.client.Partitions.GetPartition(ctx, name, nil)
|
||||||
|
if err != nil {
|
||||||
|
s.logger.Error("failed to get partition", zap.String("name", name), zap.Error(err))
|
||||||
|
return nil, fmt.Errorf("get partition %s: %w", name, err)
|
||||||
|
}
|
||||||
|
if resp.Partitions == nil || len(*resp.Partitions) == 0 {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
p := (*resp.Partitions)[0]
|
||||||
|
mapped := mapPartition(p)
|
||||||
|
return &mapped, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *ClusterService) GetDiag(ctx context.Context) (*slurm.OpenapiDiagResp, error) {
|
||||||
|
resp, _, err := s.client.Diag.GetDiag(ctx)
|
||||||
|
if err != nil {
|
||||||
|
s.logger.Error("failed to get diag", zap.Error(err))
|
||||||
|
return nil, fmt.Errorf("get diag: %w", err)
|
||||||
|
}
|
||||||
|
return resp, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func mapNode(n slurm.Node) model.NodeResponse {
|
||||||
|
return model.NodeResponse{
|
||||||
|
Name: derefStr(n.Name),
|
||||||
|
State: n.State,
|
||||||
|
CPUs: derefInt32(n.Cpus),
|
||||||
|
RealMemory: derefInt64(n.RealMemory),
|
||||||
|
AllocMem: derefInt64(n.AllocMemory),
|
||||||
|
Arch: derefStr(n.Architecture),
|
||||||
|
OS: derefStr(n.OperatingSystem),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func mapPartition(pi slurm.PartitionInfo) model.PartitionResponse {
|
||||||
|
var state []string
|
||||||
|
if pi.Partition != nil {
|
||||||
|
state = pi.Partition.State
|
||||||
|
}
|
||||||
|
var nodes string
|
||||||
|
if pi.Nodes != nil {
|
||||||
|
nodes = derefStr(pi.Nodes.Configured)
|
||||||
|
}
|
||||||
|
var totalCPUs int32
|
||||||
|
if pi.CPUs != nil {
|
||||||
|
totalCPUs = derefInt32(pi.CPUs.Total)
|
||||||
|
}
|
||||||
|
var totalNodes int32
|
||||||
|
if pi.Nodes != nil {
|
||||||
|
totalNodes = derefInt32(pi.Nodes.Total)
|
||||||
|
}
|
||||||
|
var maxTime string
|
||||||
|
if pi.Maximums != nil {
|
||||||
|
maxTime = uint32NoValString(pi.Maximums.Time)
|
||||||
|
}
|
||||||
|
return model.PartitionResponse{
|
||||||
|
Name: derefStr(pi.Name),
|
||||||
|
State: state,
|
||||||
|
Nodes: nodes,
|
||||||
|
TotalCPUs: totalCPUs,
|
||||||
|
TotalNodes: totalNodes,
|
||||||
|
MaxTime: maxTime,
|
||||||
|
}
|
||||||
|
}
|
||||||
467
internal/service/cluster_service_test.go
Normal file
467
internal/service/cluster_service_test.go
Normal file
@@ -0,0 +1,467 @@
|
|||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"gcy_hpc_server/internal/slurm"
|
||||||
|
|
||||||
|
"go.uber.org/zap"
|
||||||
|
"go.uber.org/zap/zapcore"
|
||||||
|
"go.uber.org/zap/zaptest/observer"
|
||||||
|
)
|
||||||
|
|
||||||
|
func mockServer(handler http.HandlerFunc) (*slurm.Client, func()) {
|
||||||
|
srv := httptest.NewServer(handler)
|
||||||
|
client, _ := slurm.NewClient(srv.URL, srv.Client())
|
||||||
|
return client, srv.Close
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetNodes(t *testing.T) {
|
||||||
|
client, cleanup := mockServer(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if r.URL.Path != "/slurm/v0.0.40/nodes" {
|
||||||
|
t.Errorf("unexpected path: %s", r.URL.Path)
|
||||||
|
}
|
||||||
|
resp := map[string]interface{}{
|
||||||
|
"nodes": []map[string]interface{}{
|
||||||
|
{
|
||||||
|
"name": "node1",
|
||||||
|
"state": []string{"IDLE"},
|
||||||
|
"cpus": 64,
|
||||||
|
"real_memory": 256000,
|
||||||
|
"alloc_memory": 0,
|
||||||
|
"architecture": "x86_64",
|
||||||
|
"operating_system": "Linux 5.15",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
json.NewEncoder(w).Encode(resp)
|
||||||
|
})
|
||||||
|
defer cleanup()
|
||||||
|
|
||||||
|
svc := NewClusterService(client, zap.NewNop())
|
||||||
|
nodes, err := svc.GetNodes(context.Background())
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("GetNodes returned error: %v", err)
|
||||||
|
}
|
||||||
|
if len(nodes) != 1 {
|
||||||
|
t.Fatalf("expected 1 node, got %d", len(nodes))
|
||||||
|
}
|
||||||
|
n := nodes[0]
|
||||||
|
if n.Name != "node1" {
|
||||||
|
t.Errorf("expected name node1, got %s", n.Name)
|
||||||
|
}
|
||||||
|
if len(n.State) != 1 || n.State[0] != "IDLE" {
|
||||||
|
t.Errorf("expected state [IDLE], got %v", n.State)
|
||||||
|
}
|
||||||
|
if n.CPUs != 64 {
|
||||||
|
t.Errorf("expected 64 CPUs, got %d", n.CPUs)
|
||||||
|
}
|
||||||
|
if n.RealMemory != 256000 {
|
||||||
|
t.Errorf("expected real_memory 256000, got %d", n.RealMemory)
|
||||||
|
}
|
||||||
|
if n.Arch != "x86_64" {
|
||||||
|
t.Errorf("expected arch x86_64, got %s", n.Arch)
|
||||||
|
}
|
||||||
|
if n.OS != "Linux 5.15" {
|
||||||
|
t.Errorf("expected OS 'Linux 5.15', got %s", n.OS)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetNodes_Empty(t *testing.T) {
|
||||||
|
client, cleanup := mockServer(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
json.NewEncoder(w).Encode(map[string]interface{}{})
|
||||||
|
})
|
||||||
|
defer cleanup()
|
||||||
|
|
||||||
|
svc := NewClusterService(client, zap.NewNop())
|
||||||
|
nodes, err := svc.GetNodes(context.Background())
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("GetNodes returned error: %v", err)
|
||||||
|
}
|
||||||
|
if nodes != nil {
|
||||||
|
t.Errorf("expected nil for empty response, got %v", nodes)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetNode(t *testing.T) {
|
||||||
|
client, cleanup := mockServer(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if r.URL.Path != "/slurm/v0.0.40/node/node1" {
|
||||||
|
t.Errorf("unexpected path: %s", r.URL.Path)
|
||||||
|
}
|
||||||
|
resp := map[string]interface{}{
|
||||||
|
"nodes": []map[string]interface{}{
|
||||||
|
{"name": "node1", "state": []string{"ALLOCATED"}, "cpus": 32},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
json.NewEncoder(w).Encode(resp)
|
||||||
|
})
|
||||||
|
defer cleanup()
|
||||||
|
|
||||||
|
svc := NewClusterService(client, zap.NewNop())
|
||||||
|
node, err := svc.GetNode(context.Background(), "node1")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("GetNode returned error: %v", err)
|
||||||
|
}
|
||||||
|
if node == nil {
|
||||||
|
t.Fatal("expected node, got nil")
|
||||||
|
}
|
||||||
|
if node.Name != "node1" {
|
||||||
|
t.Errorf("expected name node1, got %s", node.Name)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetNode_NotFound(t *testing.T) {
|
||||||
|
client, cleanup := mockServer(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
json.NewEncoder(w).Encode(map[string]interface{}{})
|
||||||
|
})
|
||||||
|
defer cleanup()
|
||||||
|
|
||||||
|
svc := NewClusterService(client, zap.NewNop())
|
||||||
|
node, err := svc.GetNode(context.Background(), "missing")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("GetNode returned error: %v", err)
|
||||||
|
}
|
||||||
|
if node != nil {
|
||||||
|
t.Errorf("expected nil for missing node, got %+v", node)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetPartitions(t *testing.T) {
|
||||||
|
client, cleanup := mockServer(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if r.URL.Path != "/slurm/v0.0.40/partitions" {
|
||||||
|
t.Errorf("unexpected path: %s", r.URL.Path)
|
||||||
|
}
|
||||||
|
resp := map[string]interface{}{
|
||||||
|
"partitions": []map[string]interface{}{
|
||||||
|
{
|
||||||
|
"name": "normal",
|
||||||
|
"partition": map[string]interface{}{
|
||||||
|
"state": []string{"UP"},
|
||||||
|
},
|
||||||
|
"nodes": map[string]interface{}{
|
||||||
|
"configured": "node[1-10]",
|
||||||
|
"total": 10,
|
||||||
|
},
|
||||||
|
"cpus": map[string]interface{}{
|
||||||
|
"total": 640,
|
||||||
|
},
|
||||||
|
"maximums": map[string]interface{}{
|
||||||
|
"time": map[string]interface{}{
|
||||||
|
"set": true,
|
||||||
|
"infinite": false,
|
||||||
|
"number": 86400,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
json.NewEncoder(w).Encode(resp)
|
||||||
|
})
|
||||||
|
defer cleanup()
|
||||||
|
|
||||||
|
svc := NewClusterService(client, zap.NewNop())
|
||||||
|
partitions, err := svc.GetPartitions(context.Background())
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("GetPartitions returned error: %v", err)
|
||||||
|
}
|
||||||
|
if len(partitions) != 1 {
|
||||||
|
t.Fatalf("expected 1 partition, got %d", len(partitions))
|
||||||
|
}
|
||||||
|
p := partitions[0]
|
||||||
|
if p.Name != "normal" {
|
||||||
|
t.Errorf("expected name normal, got %s", p.Name)
|
||||||
|
}
|
||||||
|
if len(p.State) != 1 || p.State[0] != "UP" {
|
||||||
|
t.Errorf("expected state [UP], got %v", p.State)
|
||||||
|
}
|
||||||
|
if p.Nodes != "node[1-10]" {
|
||||||
|
t.Errorf("expected nodes 'node[1-10]', got %s", p.Nodes)
|
||||||
|
}
|
||||||
|
if p.TotalCPUs != 640 {
|
||||||
|
t.Errorf("expected 640 total CPUs, got %d", p.TotalCPUs)
|
||||||
|
}
|
||||||
|
if p.TotalNodes != 10 {
|
||||||
|
t.Errorf("expected 10 total nodes, got %d", p.TotalNodes)
|
||||||
|
}
|
||||||
|
if p.MaxTime != "86400" {
|
||||||
|
t.Errorf("expected max_time '86400', got %s", p.MaxTime)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetPartitions_Empty(t *testing.T) {
|
||||||
|
client, cleanup := mockServer(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
json.NewEncoder(w).Encode(map[string]interface{}{})
|
||||||
|
})
|
||||||
|
defer cleanup()
|
||||||
|
|
||||||
|
svc := NewClusterService(client, zap.NewNop())
|
||||||
|
partitions, err := svc.GetPartitions(context.Background())
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("GetPartitions returned error: %v", err)
|
||||||
|
}
|
||||||
|
if partitions != nil {
|
||||||
|
t.Errorf("expected nil for empty response, got %v", partitions)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetPartition(t *testing.T) {
|
||||||
|
client, cleanup := mockServer(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if r.URL.Path != "/slurm/v0.0.40/partition/gpu" {
|
||||||
|
t.Errorf("unexpected path: %s", r.URL.Path)
|
||||||
|
}
|
||||||
|
resp := map[string]interface{}{
|
||||||
|
"partitions": []map[string]interface{}{
|
||||||
|
{
|
||||||
|
"name": "gpu",
|
||||||
|
"partition": map[string]interface{}{
|
||||||
|
"state": []string{"UP"},
|
||||||
|
},
|
||||||
|
"nodes": map[string]interface{}{
|
||||||
|
"configured": "gpu[1-4]",
|
||||||
|
"total": 4,
|
||||||
|
},
|
||||||
|
"maximums": map[string]interface{}{
|
||||||
|
"time": map[string]interface{}{
|
||||||
|
"set": true,
|
||||||
|
"infinite": true,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
json.NewEncoder(w).Encode(resp)
|
||||||
|
})
|
||||||
|
defer cleanup()
|
||||||
|
|
||||||
|
svc := NewClusterService(client, zap.NewNop())
|
||||||
|
part, err := svc.GetPartition(context.Background(), "gpu")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("GetPartition returned error: %v", err)
|
||||||
|
}
|
||||||
|
if part == nil {
|
||||||
|
t.Fatal("expected partition, got nil")
|
||||||
|
}
|
||||||
|
if part.Name != "gpu" {
|
||||||
|
t.Errorf("expected name gpu, got %s", part.Name)
|
||||||
|
}
|
||||||
|
if part.MaxTime != "UNLIMITED" {
|
||||||
|
t.Errorf("expected max_time UNLIMITED, got %s", part.MaxTime)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetPartition_NotFound(t *testing.T) {
|
||||||
|
client, cleanup := mockServer(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
json.NewEncoder(w).Encode(map[string]interface{}{})
|
||||||
|
})
|
||||||
|
defer cleanup()
|
||||||
|
|
||||||
|
svc := NewClusterService(client, zap.NewNop())
|
||||||
|
part, err := svc.GetPartition(context.Background(), "missing")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("GetPartition returned error: %v", err)
|
||||||
|
}
|
||||||
|
if part != nil {
|
||||||
|
t.Errorf("expected nil for missing partition, got %+v", part)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetDiag(t *testing.T) {
|
||||||
|
client, cleanup := mockServer(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if r.URL.Path != "/slurm/v0.0.40/diag" {
|
||||||
|
t.Errorf("unexpected path: %s", r.URL.Path)
|
||||||
|
}
|
||||||
|
resp := map[string]interface{}{
|
||||||
|
"statistics": map[string]interface{}{
|
||||||
|
"server_thread_count": 10,
|
||||||
|
"agent_queue_size": 5,
|
||||||
|
"jobs_submitted": 100,
|
||||||
|
"jobs_running": 20,
|
||||||
|
"schedule_queue_length": 3,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
json.NewEncoder(w).Encode(resp)
|
||||||
|
})
|
||||||
|
defer cleanup()
|
||||||
|
|
||||||
|
svc := NewClusterService(client, zap.NewNop())
|
||||||
|
diag, err := svc.GetDiag(context.Background())
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("GetDiag returned error: %v", err)
|
||||||
|
}
|
||||||
|
if diag == nil {
|
||||||
|
t.Fatal("expected diag response, got nil")
|
||||||
|
}
|
||||||
|
if diag.Statistics == nil {
|
||||||
|
t.Fatal("expected statistics, got nil")
|
||||||
|
}
|
||||||
|
if diag.Statistics.ServerThreadCount == nil || *diag.Statistics.ServerThreadCount != 10 {
|
||||||
|
t.Errorf("expected server_thread_count 10, got %v", diag.Statistics.ServerThreadCount)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNewSlurmClient(t *testing.T) {
|
||||||
|
dir := t.TempDir()
|
||||||
|
keyPath := filepath.Join(dir, "jwt.key")
|
||||||
|
os.WriteFile(keyPath, make([]byte, 32), 0644)
|
||||||
|
|
||||||
|
client, err := NewSlurmClient("http://localhost:6820", "root", keyPath)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("NewSlurmClient returned error: %v", err)
|
||||||
|
}
|
||||||
|
if client == nil {
|
||||||
|
t.Fatal("expected client, got nil")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func newClusterServiceWithObserver(srv *httptest.Server) (*ClusterService, *observer.ObservedLogs) {
|
||||||
|
core, recorded := observer.New(zapcore.DebugLevel)
|
||||||
|
l := zap.New(core)
|
||||||
|
client, _ := slurm.NewClient(srv.URL, srv.Client())
|
||||||
|
return NewClusterService(client, l), recorded
|
||||||
|
}
|
||||||
|
|
||||||
|
func errorServer() *httptest.Server {
|
||||||
|
return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(http.StatusInternalServerError)
|
||||||
|
w.Write([]byte(`{"errors": [{"error": "internal server error"}]}`))
|
||||||
|
}))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestClusterService_GetNodes_ErrorLogging(t *testing.T) {
|
||||||
|
srv := errorServer()
|
||||||
|
defer srv.Close()
|
||||||
|
|
||||||
|
svc, logs := newClusterServiceWithObserver(srv)
|
||||||
|
_, err := svc.GetNodes(context.Background())
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected error, got nil")
|
||||||
|
}
|
||||||
|
|
||||||
|
if logs.Len() != 1 {
|
||||||
|
t.Fatalf("expected 1 log entry, got %d", logs.Len())
|
||||||
|
}
|
||||||
|
entry := logs.All()[0]
|
||||||
|
if entry.Level != zapcore.ErrorLevel {
|
||||||
|
t.Errorf("expected ErrorLevel, got %v", entry.Level)
|
||||||
|
}
|
||||||
|
if len(entry.Context) == 0 {
|
||||||
|
t.Error("expected structured fields in log entry")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestClusterService_GetNode_ErrorLogging(t *testing.T) {
|
||||||
|
srv := errorServer()
|
||||||
|
defer srv.Close()
|
||||||
|
|
||||||
|
svc, logs := newClusterServiceWithObserver(srv)
|
||||||
|
_, err := svc.GetNode(context.Background(), "test-node")
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected error, got nil")
|
||||||
|
}
|
||||||
|
|
||||||
|
if logs.Len() != 1 {
|
||||||
|
t.Fatalf("expected 1 log entry, got %d", logs.Len())
|
||||||
|
}
|
||||||
|
entry := logs.All()[0]
|
||||||
|
if entry.Level != zapcore.ErrorLevel {
|
||||||
|
t.Errorf("expected ErrorLevel, got %v", entry.Level)
|
||||||
|
}
|
||||||
|
|
||||||
|
hasName := false
|
||||||
|
for _, f := range entry.Context {
|
||||||
|
if f.Key == "name" && f.String == "test-node" {
|
||||||
|
hasName = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !hasName {
|
||||||
|
t.Error("expected 'name' field with value 'test-node' in log entry")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestClusterService_GetPartitions_ErrorLogging(t *testing.T) {
|
||||||
|
srv := errorServer()
|
||||||
|
defer srv.Close()
|
||||||
|
|
||||||
|
svc, logs := newClusterServiceWithObserver(srv)
|
||||||
|
_, err := svc.GetPartitions(context.Background())
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected error, got nil")
|
||||||
|
}
|
||||||
|
|
||||||
|
if logs.Len() != 1 {
|
||||||
|
t.Fatalf("expected 1 log entry, got %d", logs.Len())
|
||||||
|
}
|
||||||
|
entry := logs.All()[0]
|
||||||
|
if entry.Level != zapcore.ErrorLevel {
|
||||||
|
t.Errorf("expected ErrorLevel, got %v", entry.Level)
|
||||||
|
}
|
||||||
|
if len(entry.Context) == 0 {
|
||||||
|
t.Error("expected structured fields in log entry")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestClusterService_GetPartition_ErrorLogging(t *testing.T) {
|
||||||
|
srv := errorServer()
|
||||||
|
defer srv.Close()
|
||||||
|
|
||||||
|
svc, logs := newClusterServiceWithObserver(srv)
|
||||||
|
_, err := svc.GetPartition(context.Background(), "test-partition")
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected error, got nil")
|
||||||
|
}
|
||||||
|
|
||||||
|
if logs.Len() != 1 {
|
||||||
|
t.Fatalf("expected 1 log entry, got %d", logs.Len())
|
||||||
|
}
|
||||||
|
entry := logs.All()[0]
|
||||||
|
if entry.Level != zapcore.ErrorLevel {
|
||||||
|
t.Errorf("expected ErrorLevel, got %v", entry.Level)
|
||||||
|
}
|
||||||
|
|
||||||
|
hasName := false
|
||||||
|
for _, f := range entry.Context {
|
||||||
|
if f.Key == "name" && f.String == "test-partition" {
|
||||||
|
hasName = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !hasName {
|
||||||
|
t.Error("expected 'name' field with value 'test-partition' in log entry")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestClusterService_GetDiag_ErrorLogging(t *testing.T) {
|
||||||
|
srv := errorServer()
|
||||||
|
defer srv.Close()
|
||||||
|
|
||||||
|
svc, logs := newClusterServiceWithObserver(srv)
|
||||||
|
_, err := svc.GetDiag(context.Background())
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected error, got nil")
|
||||||
|
}
|
||||||
|
|
||||||
|
if logs.Len() != 1 {
|
||||||
|
t.Fatalf("expected 1 log entry, got %d", logs.Len())
|
||||||
|
}
|
||||||
|
entry := logs.All()[0]
|
||||||
|
if entry.Level != zapcore.ErrorLevel {
|
||||||
|
t.Errorf("expected ErrorLevel, got %v", entry.Level)
|
||||||
|
}
|
||||||
|
if len(entry.Context) == 0 {
|
||||||
|
t.Error("expected structured fields in log entry")
|
||||||
|
}
|
||||||
|
}
|
||||||
246
internal/service/job_service.go
Normal file
246
internal/service/job_service.go
Normal file
@@ -0,0 +1,246 @@
|
|||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"strconv"
|
||||||
|
|
||||||
|
"gcy_hpc_server/internal/model"
|
||||||
|
"gcy_hpc_server/internal/slurm"
|
||||||
|
|
||||||
|
"go.uber.org/zap"
|
||||||
|
)
|
||||||
|
|
||||||
|
// JobService wraps Slurm SDK job operations with model mapping and pagination.
|
||||||
|
type JobService struct {
|
||||||
|
client *slurm.Client
|
||||||
|
logger *zap.Logger
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewJobService creates a new JobService with the given Slurm SDK client.
|
||||||
|
func NewJobService(client *slurm.Client, logger *zap.Logger) *JobService {
|
||||||
|
return &JobService{client: client, logger: logger}
|
||||||
|
}
|
||||||
|
|
||||||
|
// SubmitJob submits a new job to Slurm and returns the job ID.
|
||||||
|
func (s *JobService) SubmitJob(ctx context.Context, req *model.SubmitJobRequest) (*model.JobResponse, error) {
|
||||||
|
script := req.Script
|
||||||
|
jobDesc := &slurm.JobDescMsg{
|
||||||
|
Script: &script,
|
||||||
|
Partition: strToPtrOrNil(req.Partition),
|
||||||
|
Qos: strToPtrOrNil(req.QOS),
|
||||||
|
Name: strToPtrOrNil(req.JobName),
|
||||||
|
}
|
||||||
|
if req.CPUs > 0 {
|
||||||
|
jobDesc.MinimumCpus = slurm.Ptr(req.CPUs)
|
||||||
|
}
|
||||||
|
if req.TimeLimit != "" {
|
||||||
|
if mins, err := strconv.ParseInt(req.TimeLimit, 10, 64); err == nil {
|
||||||
|
jobDesc.TimeLimit = &slurm.Uint32NoVal{Number: &mins}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
submitReq := &slurm.JobSubmitReq{
|
||||||
|
Script: &script,
|
||||||
|
Job: jobDesc,
|
||||||
|
}
|
||||||
|
|
||||||
|
result, _, err := s.client.Jobs.SubmitJob(ctx, submitReq)
|
||||||
|
if err != nil {
|
||||||
|
s.logger.Error("failed to submit job", zap.Error(err), zap.String("operation", "submit"))
|
||||||
|
return nil, fmt.Errorf("submit job: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
resp := &model.JobResponse{}
|
||||||
|
if result.Result != nil && result.Result.JobID != nil {
|
||||||
|
resp.JobID = *result.Result.JobID
|
||||||
|
} else if result.JobID != nil {
|
||||||
|
resp.JobID = *result.JobID
|
||||||
|
}
|
||||||
|
|
||||||
|
s.logger.Info("job submitted", zap.String("job_name", req.JobName), zap.Int32("job_id", resp.JobID))
|
||||||
|
return resp, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetJobs lists all current jobs from Slurm.
|
||||||
|
func (s *JobService) GetJobs(ctx context.Context) ([]model.JobResponse, error) {
|
||||||
|
result, _, err := s.client.Jobs.GetJobs(ctx, nil)
|
||||||
|
if err != nil {
|
||||||
|
s.logger.Error("failed to get jobs", zap.Error(err), zap.String("operation", "get_jobs"))
|
||||||
|
return nil, fmt.Errorf("get jobs: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
jobs := make([]model.JobResponse, 0, len(result.Jobs))
|
||||||
|
for i := range result.Jobs {
|
||||||
|
jobs = append(jobs, mapJobInfo(&result.Jobs[i]))
|
||||||
|
}
|
||||||
|
return jobs, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetJob retrieves a single job by ID.
|
||||||
|
func (s *JobService) GetJob(ctx context.Context, jobID string) (*model.JobResponse, error) {
|
||||||
|
result, _, err := s.client.Jobs.GetJob(ctx, jobID, nil)
|
||||||
|
if err != nil {
|
||||||
|
s.logger.Error("failed to get job", zap.Error(err), zap.String("job_id", jobID), zap.String("operation", "get_job"))
|
||||||
|
return nil, fmt.Errorf("get job %s: %w", jobID, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(result.Jobs) == 0 {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
resp := mapJobInfo(&result.Jobs[0])
|
||||||
|
return &resp, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// CancelJob cancels a job by ID.
|
||||||
|
func (s *JobService) CancelJob(ctx context.Context, jobID string) error {
|
||||||
|
_, _, err := s.client.Jobs.DeleteJob(ctx, jobID, nil)
|
||||||
|
if err != nil {
|
||||||
|
s.logger.Error("failed to cancel job", zap.Error(err), zap.String("job_id", jobID), zap.String("operation", "cancel"))
|
||||||
|
return fmt.Errorf("cancel job %s: %w", jobID, err)
|
||||||
|
}
|
||||||
|
s.logger.Info("job cancelled", zap.String("job_id", jobID))
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetJobHistory queries SlurmDBD for historical jobs with pagination.
|
||||||
|
func (s *JobService) GetJobHistory(ctx context.Context, query *model.JobHistoryQuery) (*model.JobListResponse, error) {
|
||||||
|
opts := &slurm.GetSlurmdbJobsOptions{}
|
||||||
|
if query.Users != "" {
|
||||||
|
opts.Users = strToPtr(query.Users)
|
||||||
|
}
|
||||||
|
if query.Account != "" {
|
||||||
|
opts.Account = strToPtr(query.Account)
|
||||||
|
}
|
||||||
|
if query.Partition != "" {
|
||||||
|
opts.Partition = strToPtr(query.Partition)
|
||||||
|
}
|
||||||
|
if query.State != "" {
|
||||||
|
opts.State = strToPtr(query.State)
|
||||||
|
}
|
||||||
|
if query.JobName != "" {
|
||||||
|
opts.JobName = strToPtr(query.JobName)
|
||||||
|
}
|
||||||
|
if query.StartTime != "" {
|
||||||
|
opts.StartTime = strToPtr(query.StartTime)
|
||||||
|
}
|
||||||
|
if query.EndTime != "" {
|
||||||
|
opts.EndTime = strToPtr(query.EndTime)
|
||||||
|
}
|
||||||
|
|
||||||
|
result, _, err := s.client.SlurmdbJobs.GetJobs(ctx, opts)
|
||||||
|
if err != nil {
|
||||||
|
s.logger.Error("failed to get job history", zap.Error(err), zap.String("operation", "get_job_history"))
|
||||||
|
return nil, fmt.Errorf("get job history: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
allJobs := make([]model.JobResponse, 0, len(result.Jobs))
|
||||||
|
for i := range result.Jobs {
|
||||||
|
allJobs = append(allJobs, mapSlurmdbJob(&result.Jobs[i]))
|
||||||
|
}
|
||||||
|
|
||||||
|
total := len(allJobs)
|
||||||
|
page := query.Page
|
||||||
|
pageSize := query.PageSize
|
||||||
|
if page < 1 {
|
||||||
|
page = 1
|
||||||
|
}
|
||||||
|
if pageSize < 1 {
|
||||||
|
pageSize = 20
|
||||||
|
}
|
||||||
|
|
||||||
|
start := (page - 1) * pageSize
|
||||||
|
end := start + pageSize
|
||||||
|
if start > total {
|
||||||
|
start = total
|
||||||
|
}
|
||||||
|
if end > total {
|
||||||
|
end = total
|
||||||
|
}
|
||||||
|
|
||||||
|
return &model.JobListResponse{
|
||||||
|
Jobs: allJobs[start:end],
|
||||||
|
Total: total,
|
||||||
|
Page: page,
|
||||||
|
PageSize: pageSize,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
// Helper functions
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
func strToPtr(s string) *string { return &s }
|
||||||
|
|
||||||
|
// strPtrOrNil returns a pointer to s if non-empty, otherwise nil.
|
||||||
|
func strToPtrOrNil(s string) *string {
|
||||||
|
if s == "" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return &s
|
||||||
|
}
|
||||||
|
|
||||||
|
// mapJobInfo maps SDK JobInfo to API JobResponse.
|
||||||
|
func mapJobInfo(ji *slurm.JobInfo) model.JobResponse {
|
||||||
|
resp := model.JobResponse{}
|
||||||
|
if ji.JobID != nil {
|
||||||
|
resp.JobID = *ji.JobID
|
||||||
|
}
|
||||||
|
if ji.Name != nil {
|
||||||
|
resp.Name = *ji.Name
|
||||||
|
}
|
||||||
|
resp.State = ji.JobState
|
||||||
|
if ji.Partition != nil {
|
||||||
|
resp.Partition = *ji.Partition
|
||||||
|
}
|
||||||
|
if ji.SubmitTime != nil && ji.SubmitTime.Number != nil {
|
||||||
|
resp.SubmitTime = ji.SubmitTime.Number
|
||||||
|
}
|
||||||
|
if ji.StartTime != nil && ji.StartTime.Number != nil {
|
||||||
|
resp.StartTime = ji.StartTime.Number
|
||||||
|
}
|
||||||
|
if ji.EndTime != nil && ji.EndTime.Number != nil {
|
||||||
|
resp.EndTime = ji.EndTime.Number
|
||||||
|
}
|
||||||
|
if ji.ExitCode != nil && ji.ExitCode.ReturnCode != nil && ji.ExitCode.ReturnCode.Number != nil {
|
||||||
|
code := int32(*ji.ExitCode.ReturnCode.Number)
|
||||||
|
resp.ExitCode = &code
|
||||||
|
}
|
||||||
|
if ji.Nodes != nil {
|
||||||
|
resp.Nodes = *ji.Nodes
|
||||||
|
}
|
||||||
|
return resp
|
||||||
|
}
|
||||||
|
|
||||||
|
// mapSlurmdbJob maps SDK SlurmDBD Job to API JobResponse.
|
||||||
|
func mapSlurmdbJob(j *slurm.Job) model.JobResponse {
|
||||||
|
resp := model.JobResponse{}
|
||||||
|
if j.JobID != nil {
|
||||||
|
resp.JobID = *j.JobID
|
||||||
|
}
|
||||||
|
if j.Name != nil {
|
||||||
|
resp.Name = *j.Name
|
||||||
|
}
|
||||||
|
if j.State != nil {
|
||||||
|
resp.State = j.State.Current
|
||||||
|
}
|
||||||
|
if j.Partition != nil {
|
||||||
|
resp.Partition = *j.Partition
|
||||||
|
}
|
||||||
|
if j.Time != nil {
|
||||||
|
if j.Time.Submission != nil {
|
||||||
|
resp.SubmitTime = j.Time.Submission
|
||||||
|
}
|
||||||
|
if j.Time.Start != nil {
|
||||||
|
resp.StartTime = j.Time.Start
|
||||||
|
}
|
||||||
|
if j.Time.End != nil {
|
||||||
|
resp.EndTime = j.Time.End
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if j.Nodes != nil {
|
||||||
|
resp.Nodes = *j.Nodes
|
||||||
|
}
|
||||||
|
return resp
|
||||||
|
}
|
||||||
703
internal/service/job_service_test.go
Normal file
703
internal/service/job_service_test.go
Normal file
@@ -0,0 +1,703 @@
|
|||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"gcy_hpc_server/internal/model"
|
||||||
|
"gcy_hpc_server/internal/slurm"
|
||||||
|
|
||||||
|
"go.uber.org/zap"
|
||||||
|
"go.uber.org/zap/zapcore"
|
||||||
|
"go.uber.org/zap/zaptest/observer"
|
||||||
|
)
|
||||||
|
|
||||||
|
func mockJobServer(handler http.HandlerFunc) (*slurm.Client, func()) {
|
||||||
|
srv := httptest.NewServer(handler)
|
||||||
|
client, _ := slurm.NewClient(srv.URL, srv.Client())
|
||||||
|
return client, srv.Close
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSubmitJob(t *testing.T) {
|
||||||
|
jobID := int32(123)
|
||||||
|
client, cleanup := mockJobServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if r.Method != http.MethodPost {
|
||||||
|
t.Errorf("expected POST, got %s", r.Method)
|
||||||
|
}
|
||||||
|
if r.URL.Path != "/slurm/v0.0.40/job/submit" {
|
||||||
|
t.Errorf("unexpected path: %s", r.URL.Path)
|
||||||
|
}
|
||||||
|
|
||||||
|
var body slurm.JobSubmitReq
|
||||||
|
if err := json.NewDecoder(r.Body).Decode(&body); err != nil {
|
||||||
|
t.Fatalf("decode body: %v", err)
|
||||||
|
}
|
||||||
|
if body.Job == nil || body.Job.Script == nil || *body.Job.Script != "#!/bin/bash\necho hello" {
|
||||||
|
t.Errorf("unexpected script in request body")
|
||||||
|
}
|
||||||
|
|
||||||
|
resp := slurm.OpenapiJobSubmitResponse{
|
||||||
|
Result: &slurm.JobSubmitResponseMsg{
|
||||||
|
JobID: &jobID,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
json.NewEncoder(w).Encode(resp)
|
||||||
|
}))
|
||||||
|
defer cleanup()
|
||||||
|
|
||||||
|
svc := NewJobService(client, zap.NewNop())
|
||||||
|
resp, err := svc.SubmitJob(context.Background(), &model.SubmitJobRequest{
|
||||||
|
Script: "#!/bin/bash\necho hello",
|
||||||
|
Partition: "normal",
|
||||||
|
QOS: "high",
|
||||||
|
JobName: "test-job",
|
||||||
|
CPUs: 4,
|
||||||
|
TimeLimit: "60",
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("SubmitJob: %v", err)
|
||||||
|
}
|
||||||
|
if resp.JobID != 123 {
|
||||||
|
t.Errorf("expected JobID 123, got %d", resp.JobID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSubmitJob_WithOptionalFields(t *testing.T) {
|
||||||
|
client, cleanup := mockJobServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
var body slurm.JobSubmitReq
|
||||||
|
if err := json.NewDecoder(r.Body).Decode(&body); err != nil {
|
||||||
|
t.Fatalf("decode body: %v", err)
|
||||||
|
}
|
||||||
|
if body.Job == nil {
|
||||||
|
t.Fatal("job desc is nil")
|
||||||
|
}
|
||||||
|
if body.Job.Partition != nil {
|
||||||
|
t.Error("expected partition nil for empty string")
|
||||||
|
}
|
||||||
|
if body.Job.MinimumCpus != nil {
|
||||||
|
t.Error("expected minimum_cpus nil when CPUs=0")
|
||||||
|
}
|
||||||
|
|
||||||
|
jobID := int32(456)
|
||||||
|
resp := slurm.OpenapiJobSubmitResponse{
|
||||||
|
Result: &slurm.JobSubmitResponseMsg{JobID: &jobID},
|
||||||
|
}
|
||||||
|
json.NewEncoder(w).Encode(resp)
|
||||||
|
}))
|
||||||
|
defer cleanup()
|
||||||
|
|
||||||
|
svc := NewJobService(client, zap.NewNop())
|
||||||
|
resp, err := svc.SubmitJob(context.Background(), &model.SubmitJobRequest{
|
||||||
|
Script: "echo hi",
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("SubmitJob: %v", err)
|
||||||
|
}
|
||||||
|
if resp.JobID != 456 {
|
||||||
|
t.Errorf("expected JobID 456, got %d", resp.JobID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSubmitJob_Error(t *testing.T) {
|
||||||
|
client, cleanup := mockJobServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(http.StatusInternalServerError)
|
||||||
|
w.Write([]byte(`{"error":"internal"}`))
|
||||||
|
}))
|
||||||
|
defer cleanup()
|
||||||
|
|
||||||
|
svc := NewJobService(client, zap.NewNop())
|
||||||
|
_, err := svc.SubmitJob(context.Background(), &model.SubmitJobRequest{
|
||||||
|
Script: "echo fail",
|
||||||
|
})
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected error, got nil")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetJobs(t *testing.T) {
|
||||||
|
jobID := int32(100)
|
||||||
|
name := "my-job"
|
||||||
|
partition := "gpu"
|
||||||
|
ts := int64(1700000000)
|
||||||
|
nodes := "node01"
|
||||||
|
|
||||||
|
client, cleanup := mockJobServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if r.Method != http.MethodGet {
|
||||||
|
t.Errorf("expected GET, got %s", r.Method)
|
||||||
|
}
|
||||||
|
|
||||||
|
resp := slurm.OpenapiJobInfoResp{
|
||||||
|
Jobs: slurm.JobInfoMsg{
|
||||||
|
{
|
||||||
|
JobID: &jobID,
|
||||||
|
Name: &name,
|
||||||
|
JobState: []string{"RUNNING"},
|
||||||
|
Partition: &partition,
|
||||||
|
SubmitTime: &slurm.Uint64NoVal{Number: &ts},
|
||||||
|
StartTime: &slurm.Uint64NoVal{Number: &ts},
|
||||||
|
EndTime: &slurm.Uint64NoVal{Number: &ts},
|
||||||
|
Nodes: &nodes,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
json.NewEncoder(w).Encode(resp)
|
||||||
|
}))
|
||||||
|
defer cleanup()
|
||||||
|
|
||||||
|
svc := NewJobService(client, zap.NewNop())
|
||||||
|
jobs, err := svc.GetJobs(context.Background())
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("GetJobs: %v", err)
|
||||||
|
}
|
||||||
|
if len(jobs) != 1 {
|
||||||
|
t.Fatalf("expected 1 job, got %d", len(jobs))
|
||||||
|
}
|
||||||
|
j := jobs[0]
|
||||||
|
if j.JobID != 100 {
|
||||||
|
t.Errorf("expected JobID 100, got %d", j.JobID)
|
||||||
|
}
|
||||||
|
if j.Name != "my-job" {
|
||||||
|
t.Errorf("expected Name my-job, got %s", j.Name)
|
||||||
|
}
|
||||||
|
if len(j.State) != 1 || j.State[0] != "RUNNING" {
|
||||||
|
t.Errorf("expected State [RUNNING], got %v", j.State)
|
||||||
|
}
|
||||||
|
if j.Partition != "gpu" {
|
||||||
|
t.Errorf("expected Partition gpu, got %s", j.Partition)
|
||||||
|
}
|
||||||
|
if j.SubmitTime == nil || *j.SubmitTime != ts {
|
||||||
|
t.Errorf("expected SubmitTime %d, got %v", ts, j.SubmitTime)
|
||||||
|
}
|
||||||
|
if j.Nodes != "node01" {
|
||||||
|
t.Errorf("expected Nodes node01, got %s", j.Nodes)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetJob(t *testing.T) {
|
||||||
|
jobID := int32(200)
|
||||||
|
name := "single-job"
|
||||||
|
|
||||||
|
client, cleanup := mockJobServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
resp := slurm.OpenapiJobInfoResp{
|
||||||
|
Jobs: slurm.JobInfoMsg{
|
||||||
|
{
|
||||||
|
JobID: &jobID,
|
||||||
|
Name: &name,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
json.NewEncoder(w).Encode(resp)
|
||||||
|
}))
|
||||||
|
defer cleanup()
|
||||||
|
|
||||||
|
svc := NewJobService(client, zap.NewNop())
|
||||||
|
job, err := svc.GetJob(context.Background(), "200")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("GetJob: %v", err)
|
||||||
|
}
|
||||||
|
if job == nil {
|
||||||
|
t.Fatal("expected job, got nil")
|
||||||
|
}
|
||||||
|
if job.JobID != 200 {
|
||||||
|
t.Errorf("expected JobID 200, got %d", job.JobID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetJob_NotFound(t *testing.T) {
|
||||||
|
client, cleanup := mockJobServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
resp := slurm.OpenapiJobInfoResp{Jobs: slurm.JobInfoMsg{}}
|
||||||
|
json.NewEncoder(w).Encode(resp)
|
||||||
|
}))
|
||||||
|
defer cleanup()
|
||||||
|
|
||||||
|
svc := NewJobService(client, zap.NewNop())
|
||||||
|
job, err := svc.GetJob(context.Background(), "999")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("GetJob: %v", err)
|
||||||
|
}
|
||||||
|
if job != nil {
|
||||||
|
t.Errorf("expected nil for not found, got %+v", job)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCancelJob(t *testing.T) {
|
||||||
|
client, cleanup := mockJobServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if r.Method != http.MethodDelete {
|
||||||
|
t.Errorf("expected DELETE, got %s", r.Method)
|
||||||
|
}
|
||||||
|
resp := slurm.OpenapiResp{}
|
||||||
|
json.NewEncoder(w).Encode(resp)
|
||||||
|
}))
|
||||||
|
defer cleanup()
|
||||||
|
|
||||||
|
svc := NewJobService(client, zap.NewNop())
|
||||||
|
err := svc.CancelJob(context.Background(), "300")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("CancelJob: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCancelJob_Error(t *testing.T) {
|
||||||
|
client, cleanup := mockJobServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(http.StatusNotFound)
|
||||||
|
w.Write([]byte(`not found`))
|
||||||
|
}))
|
||||||
|
defer cleanup()
|
||||||
|
|
||||||
|
svc := NewJobService(client, zap.NewNop())
|
||||||
|
err := svc.CancelJob(context.Background(), "999")
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected error, got nil")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetJobHistory(t *testing.T) {
|
||||||
|
jobID1 := int32(10)
|
||||||
|
jobID2 := int32(20)
|
||||||
|
jobID3 := int32(30)
|
||||||
|
name1 := "hist-1"
|
||||||
|
name2 := "hist-2"
|
||||||
|
name3 := "hist-3"
|
||||||
|
submission1 := int64(1700000000)
|
||||||
|
submission2 := int64(1700001000)
|
||||||
|
submission3 := int64(1700002000)
|
||||||
|
partition := "normal"
|
||||||
|
|
||||||
|
client, cleanup := mockJobServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if r.Method != http.MethodGet {
|
||||||
|
t.Errorf("expected GET, got %s", r.Method)
|
||||||
|
}
|
||||||
|
users := r.URL.Query().Get("users")
|
||||||
|
if users != "testuser" {
|
||||||
|
t.Errorf("expected users=testuser, got %s", users)
|
||||||
|
}
|
||||||
|
|
||||||
|
resp := slurm.OpenapiSlurmdbdJobsResp{
|
||||||
|
Jobs: slurm.JobList{
|
||||||
|
{
|
||||||
|
JobID: &jobID1,
|
||||||
|
Name: &name1,
|
||||||
|
Partition: &partition,
|
||||||
|
State: &slurm.JobState{Current: []string{"COMPLETED"}},
|
||||||
|
Time: &slurm.JobTime{Submission: &submission1},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
JobID: &jobID2,
|
||||||
|
Name: &name2,
|
||||||
|
Partition: &partition,
|
||||||
|
State: &slurm.JobState{Current: []string{"FAILED"}},
|
||||||
|
Time: &slurm.JobTime{Submission: &submission2},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
JobID: &jobID3,
|
||||||
|
Name: &name3,
|
||||||
|
Partition: &partition,
|
||||||
|
State: &slurm.JobState{Current: []string{"CANCELLED"}},
|
||||||
|
Time: &slurm.JobTime{Submission: &submission3},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
json.NewEncoder(w).Encode(resp)
|
||||||
|
}))
|
||||||
|
defer cleanup()
|
||||||
|
|
||||||
|
svc := NewJobService(client, zap.NewNop())
|
||||||
|
result, err := svc.GetJobHistory(context.Background(), &model.JobHistoryQuery{
|
||||||
|
Users: "testuser",
|
||||||
|
Page: 1,
|
||||||
|
PageSize: 2,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("GetJobHistory: %v", err)
|
||||||
|
}
|
||||||
|
if result.Total != 3 {
|
||||||
|
t.Errorf("expected Total 3, got %d", result.Total)
|
||||||
|
}
|
||||||
|
if result.Page != 1 {
|
||||||
|
t.Errorf("expected Page 1, got %d", result.Page)
|
||||||
|
}
|
||||||
|
if result.PageSize != 2 {
|
||||||
|
t.Errorf("expected PageSize 2, got %d", result.PageSize)
|
||||||
|
}
|
||||||
|
if len(result.Jobs) != 2 {
|
||||||
|
t.Fatalf("expected 2 jobs on page 1, got %d", len(result.Jobs))
|
||||||
|
}
|
||||||
|
if result.Jobs[0].JobID != 10 {
|
||||||
|
t.Errorf("expected first job ID 10, got %d", result.Jobs[0].JobID)
|
||||||
|
}
|
||||||
|
if result.Jobs[1].JobID != 20 {
|
||||||
|
t.Errorf("expected second job ID 20, got %d", result.Jobs[1].JobID)
|
||||||
|
}
|
||||||
|
if len(result.Jobs[0].State) != 1 || result.Jobs[0].State[0] != "COMPLETED" {
|
||||||
|
t.Errorf("expected state [COMPLETED], got %v", result.Jobs[0].State)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetJobHistory_Page2(t *testing.T) {
|
||||||
|
jobID1 := int32(10)
|
||||||
|
jobID2 := int32(20)
|
||||||
|
name1 := "a"
|
||||||
|
name2 := "b"
|
||||||
|
|
||||||
|
client, cleanup := mockJobServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
resp := slurm.OpenapiSlurmdbdJobsResp{
|
||||||
|
Jobs: slurm.JobList{
|
||||||
|
{JobID: &jobID1, Name: &name1},
|
||||||
|
{JobID: &jobID2, Name: &name2},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
json.NewEncoder(w).Encode(resp)
|
||||||
|
}))
|
||||||
|
defer cleanup()
|
||||||
|
|
||||||
|
svc := NewJobService(client, zap.NewNop())
|
||||||
|
result, err := svc.GetJobHistory(context.Background(), &model.JobHistoryQuery{
|
||||||
|
Page: 2,
|
||||||
|
PageSize: 1,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("GetJobHistory: %v", err)
|
||||||
|
}
|
||||||
|
if result.Total != 2 {
|
||||||
|
t.Errorf("expected Total 2, got %d", result.Total)
|
||||||
|
}
|
||||||
|
if len(result.Jobs) != 1 {
|
||||||
|
t.Fatalf("expected 1 job on page 2, got %d", len(result.Jobs))
|
||||||
|
}
|
||||||
|
if result.Jobs[0].JobID != 20 {
|
||||||
|
t.Errorf("expected job ID 20, got %d", result.Jobs[0].JobID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetJobHistory_DefaultPagination(t *testing.T) {
|
||||||
|
client, cleanup := mockJobServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
resp := slurm.OpenapiSlurmdbdJobsResp{Jobs: slurm.JobList{}}
|
||||||
|
json.NewEncoder(w).Encode(resp)
|
||||||
|
}))
|
||||||
|
defer cleanup()
|
||||||
|
|
||||||
|
svc := NewJobService(client, zap.NewNop())
|
||||||
|
result, err := svc.GetJobHistory(context.Background(), &model.JobHistoryQuery{})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("GetJobHistory: %v", err)
|
||||||
|
}
|
||||||
|
if result.Page != 1 {
|
||||||
|
t.Errorf("expected default page 1, got %d", result.Page)
|
||||||
|
}
|
||||||
|
if result.PageSize != 20 {
|
||||||
|
t.Errorf("expected default pageSize 20, got %d", result.PageSize)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetJobHistory_QueryMapping(t *testing.T) {
|
||||||
|
client, cleanup := mockJobServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
q := r.URL.Query()
|
||||||
|
if v := q.Get("account"); v != "proj1" {
|
||||||
|
t.Errorf("expected account=proj1, got %s", v)
|
||||||
|
}
|
||||||
|
if v := q.Get("partition"); v != "gpu" {
|
||||||
|
t.Errorf("expected partition=gpu, got %s", v)
|
||||||
|
}
|
||||||
|
if v := q.Get("state"); v != "COMPLETED" {
|
||||||
|
t.Errorf("expected state=COMPLETED, got %s", v)
|
||||||
|
}
|
||||||
|
if v := q.Get("job_name"); v != "myjob" {
|
||||||
|
t.Errorf("expected job_name=myjob, got %s", v)
|
||||||
|
}
|
||||||
|
if v := q.Get("start_time"); v != "1700000000" {
|
||||||
|
t.Errorf("expected start_time=1700000000, got %s", v)
|
||||||
|
}
|
||||||
|
if v := q.Get("end_time"); v != "1700099999" {
|
||||||
|
t.Errorf("expected end_time=1700099999, got %s", v)
|
||||||
|
}
|
||||||
|
|
||||||
|
resp := slurm.OpenapiSlurmdbdJobsResp{Jobs: slurm.JobList{}}
|
||||||
|
json.NewEncoder(w).Encode(resp)
|
||||||
|
}))
|
||||||
|
defer cleanup()
|
||||||
|
|
||||||
|
svc := NewJobService(client, zap.NewNop())
|
||||||
|
_, err := svc.GetJobHistory(context.Background(), &model.JobHistoryQuery{
|
||||||
|
Users: "testuser",
|
||||||
|
Account: "proj1",
|
||||||
|
Partition: "gpu",
|
||||||
|
State: "COMPLETED",
|
||||||
|
JobName: "myjob",
|
||||||
|
StartTime: "1700000000",
|
||||||
|
EndTime: "1700099999",
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("GetJobHistory: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetJobHistory_Error(t *testing.T) {
|
||||||
|
client, cleanup := mockJobServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(http.StatusInternalServerError)
|
||||||
|
w.Write([]byte(`{"error":"db down"}`))
|
||||||
|
}))
|
||||||
|
defer cleanup()
|
||||||
|
|
||||||
|
svc := NewJobService(client, zap.NewNop())
|
||||||
|
_, err := svc.GetJobHistory(context.Background(), &model.JobHistoryQuery{})
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected error, got nil")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMapJobInfo_ExitCode(t *testing.T) {
|
||||||
|
returnCode := int64(2)
|
||||||
|
ji := &slurm.JobInfo{
|
||||||
|
ExitCode: &slurm.ProcessExitCodeVerbose{
|
||||||
|
ReturnCode: &slurm.Uint32NoVal{Number: &returnCode},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
resp := mapJobInfo(ji)
|
||||||
|
if resp.ExitCode == nil || *resp.ExitCode != 2 {
|
||||||
|
t.Errorf("expected exit code 2, got %v", resp.ExitCode)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMapSlurmdbJob_NilFields(t *testing.T) {
|
||||||
|
j := &slurm.Job{}
|
||||||
|
resp := mapSlurmdbJob(j)
|
||||||
|
if resp.JobID != 0 {
|
||||||
|
t.Errorf("expected JobID 0, got %d", resp.JobID)
|
||||||
|
}
|
||||||
|
if resp.State != nil {
|
||||||
|
t.Errorf("expected nil State, got %v", resp.State)
|
||||||
|
}
|
||||||
|
if resp.SubmitTime != nil {
|
||||||
|
t.Errorf("expected nil SubmitTime, got %v", resp.SubmitTime)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
// Structured logging tests using zaptest/observer
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
func newJobServiceWithObserver(srv *httptest.Server) (*JobService, *observer.ObservedLogs) {
|
||||||
|
core, recorded := observer.New(zapcore.DebugLevel)
|
||||||
|
l := zap.New(core)
|
||||||
|
client, _ := slurm.NewClient(srv.URL, srv.Client())
|
||||||
|
return NewJobService(client, l), recorded
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestJobService_SubmitJob_SuccessLog(t *testing.T) {
|
||||||
|
jobID := int32(789)
|
||||||
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
resp := slurm.OpenapiJobSubmitResponse{
|
||||||
|
Result: &slurm.JobSubmitResponseMsg{JobID: &jobID},
|
||||||
|
}
|
||||||
|
json.NewEncoder(w).Encode(resp)
|
||||||
|
}))
|
||||||
|
defer srv.Close()
|
||||||
|
|
||||||
|
svc, recorded := newJobServiceWithObserver(srv)
|
||||||
|
_, err := svc.SubmitJob(context.Background(), &model.SubmitJobRequest{
|
||||||
|
Script: "echo hi",
|
||||||
|
JobName: "log-test-job",
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
entries := recorded.All()
|
||||||
|
if len(entries) != 1 {
|
||||||
|
t.Fatalf("expected 1 log entry, got %d", len(entries))
|
||||||
|
}
|
||||||
|
if entries[0].Level != zapcore.InfoLevel {
|
||||||
|
t.Errorf("expected InfoLevel, got %v", entries[0].Level)
|
||||||
|
}
|
||||||
|
fields := entries[0].ContextMap()
|
||||||
|
if fields["job_name"] != "log-test-job" {
|
||||||
|
t.Errorf("expected job_name=log-test-job, got %v", fields["job_name"])
|
||||||
|
}
|
||||||
|
gotJobID, ok := fields["job_id"]
|
||||||
|
if !ok {
|
||||||
|
t.Fatal("expected job_id field in log entry")
|
||||||
|
}
|
||||||
|
if gotJobID != int32(789) && gotJobID != int64(789) {
|
||||||
|
t.Errorf("expected job_id=789, got %v (%T)", gotJobID, gotJobID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestJobService_SubmitJob_ErrorLog(t *testing.T) {
|
||||||
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(http.StatusInternalServerError)
|
||||||
|
w.Write([]byte(`{"error":"internal"}`))
|
||||||
|
}))
|
||||||
|
defer srv.Close()
|
||||||
|
|
||||||
|
svc, recorded := newJobServiceWithObserver(srv)
|
||||||
|
_, err := svc.SubmitJob(context.Background(), &model.SubmitJobRequest{Script: "echo fail"})
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected error, got nil")
|
||||||
|
}
|
||||||
|
|
||||||
|
entries := recorded.All()
|
||||||
|
if len(entries) != 1 {
|
||||||
|
t.Fatalf("expected 1 log entry, got %d", len(entries))
|
||||||
|
}
|
||||||
|
if entries[0].Level != zapcore.ErrorLevel {
|
||||||
|
t.Errorf("expected ErrorLevel, got %v", entries[0].Level)
|
||||||
|
}
|
||||||
|
fields := entries[0].ContextMap()
|
||||||
|
if fields["operation"] != "submit" {
|
||||||
|
t.Errorf("expected operation=submit, got %v", fields["operation"])
|
||||||
|
}
|
||||||
|
if _, ok := fields["error"]; !ok {
|
||||||
|
t.Error("expected error field in log entry")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestJobService_CancelJob_SuccessLog(t *testing.T) {
|
||||||
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
resp := slurm.OpenapiResp{}
|
||||||
|
json.NewEncoder(w).Encode(resp)
|
||||||
|
}))
|
||||||
|
defer srv.Close()
|
||||||
|
|
||||||
|
svc, recorded := newJobServiceWithObserver(srv)
|
||||||
|
err := svc.CancelJob(context.Background(), "555")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
entries := recorded.All()
|
||||||
|
if len(entries) != 1 {
|
||||||
|
t.Fatalf("expected 1 log entry, got %d", len(entries))
|
||||||
|
}
|
||||||
|
if entries[0].Level != zapcore.InfoLevel {
|
||||||
|
t.Errorf("expected InfoLevel, got %v", entries[0].Level)
|
||||||
|
}
|
||||||
|
fields := entries[0].ContextMap()
|
||||||
|
if fields["job_id"] != "555" {
|
||||||
|
t.Errorf("expected job_id=555, got %v", fields["job_id"])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestJobService_CancelJob_ErrorLog(t *testing.T) {
|
||||||
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(http.StatusNotFound)
|
||||||
|
w.Write([]byte(`not found`))
|
||||||
|
}))
|
||||||
|
defer srv.Close()
|
||||||
|
|
||||||
|
svc, recorded := newJobServiceWithObserver(srv)
|
||||||
|
err := svc.CancelJob(context.Background(), "999")
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected error, got nil")
|
||||||
|
}
|
||||||
|
|
||||||
|
entries := recorded.All()
|
||||||
|
if len(entries) != 1 {
|
||||||
|
t.Fatalf("expected 1 log entry, got %d", len(entries))
|
||||||
|
}
|
||||||
|
if entries[0].Level != zapcore.ErrorLevel {
|
||||||
|
t.Errorf("expected ErrorLevel, got %v", entries[0].Level)
|
||||||
|
}
|
||||||
|
fields := entries[0].ContextMap()
|
||||||
|
if fields["operation"] != "cancel" {
|
||||||
|
t.Errorf("expected operation=cancel, got %v", fields["operation"])
|
||||||
|
}
|
||||||
|
if fields["job_id"] != "999" {
|
||||||
|
t.Errorf("expected job_id=999, got %v", fields["job_id"])
|
||||||
|
}
|
||||||
|
if _, ok := fields["error"]; !ok {
|
||||||
|
t.Error("expected error field in log entry")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestJobService_GetJobs_ErrorLog(t *testing.T) {
|
||||||
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(http.StatusInternalServerError)
|
||||||
|
w.Write([]byte(`{"error":"down"}`))
|
||||||
|
}))
|
||||||
|
defer srv.Close()
|
||||||
|
|
||||||
|
svc, recorded := newJobServiceWithObserver(srv)
|
||||||
|
_, err := svc.GetJobs(context.Background())
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected error, got nil")
|
||||||
|
}
|
||||||
|
|
||||||
|
entries := recorded.All()
|
||||||
|
if len(entries) != 1 {
|
||||||
|
t.Fatalf("expected 1 log entry, got %d", len(entries))
|
||||||
|
}
|
||||||
|
if entries[0].Level != zapcore.ErrorLevel {
|
||||||
|
t.Errorf("expected ErrorLevel, got %v", entries[0].Level)
|
||||||
|
}
|
||||||
|
fields := entries[0].ContextMap()
|
||||||
|
if fields["operation"] != "get_jobs" {
|
||||||
|
t.Errorf("expected operation=get_jobs, got %v", fields["operation"])
|
||||||
|
}
|
||||||
|
if _, ok := fields["error"]; !ok {
|
||||||
|
t.Error("expected error field in log entry")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestJobService_GetJob_ErrorLog(t *testing.T) {
|
||||||
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(http.StatusInternalServerError)
|
||||||
|
w.Write([]byte(`{"error":"down"}`))
|
||||||
|
}))
|
||||||
|
defer srv.Close()
|
||||||
|
|
||||||
|
svc, recorded := newJobServiceWithObserver(srv)
|
||||||
|
_, err := svc.GetJob(context.Background(), "200")
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected error, got nil")
|
||||||
|
}
|
||||||
|
|
||||||
|
entries := recorded.All()
|
||||||
|
if len(entries) != 1 {
|
||||||
|
t.Fatalf("expected 1 log entry, got %d", len(entries))
|
||||||
|
}
|
||||||
|
if entries[0].Level != zapcore.ErrorLevel {
|
||||||
|
t.Errorf("expected ErrorLevel, got %v", entries[0].Level)
|
||||||
|
}
|
||||||
|
fields := entries[0].ContextMap()
|
||||||
|
if fields["operation"] != "get_job" {
|
||||||
|
t.Errorf("expected operation=get_job, got %v", fields["operation"])
|
||||||
|
}
|
||||||
|
if fields["job_id"] != "200" {
|
||||||
|
t.Errorf("expected job_id=200, got %v", fields["job_id"])
|
||||||
|
}
|
||||||
|
if _, ok := fields["error"]; !ok {
|
||||||
|
t.Error("expected error field in log entry")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestJobService_GetJobHistory_ErrorLog(t *testing.T) {
|
||||||
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(http.StatusInternalServerError)
|
||||||
|
w.Write([]byte(`{"error":"db down"}`))
|
||||||
|
}))
|
||||||
|
defer srv.Close()
|
||||||
|
|
||||||
|
svc, recorded := newJobServiceWithObserver(srv)
|
||||||
|
_, err := svc.GetJobHistory(context.Background(), &model.JobHistoryQuery{})
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected error, got nil")
|
||||||
|
}
|
||||||
|
|
||||||
|
entries := recorded.All()
|
||||||
|
if len(entries) != 1 {
|
||||||
|
t.Fatalf("expected 1 log entry, got %d", len(entries))
|
||||||
|
}
|
||||||
|
if entries[0].Level != zapcore.ErrorLevel {
|
||||||
|
t.Errorf("expected ErrorLevel, got %v", entries[0].Level)
|
||||||
|
}
|
||||||
|
fields := entries[0].ContextMap()
|
||||||
|
if fields["operation"] != "get_job_history" {
|
||||||
|
t.Errorf("expected operation=get_job_history, got %v", fields["operation"])
|
||||||
|
}
|
||||||
|
if _, ok := fields["error"]; !ok {
|
||||||
|
t.Error("expected error field in log entry")
|
||||||
|
}
|
||||||
|
}
|
||||||
15
internal/service/slurm_client.go
Normal file
15
internal/service/slurm_client.go
Normal file
@@ -0,0 +1,15 @@
|
|||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"gcy_hpc_server/internal/slurm"
|
||||||
|
)
|
||||||
|
|
||||||
|
// NewSlurmClient creates a Slurm SDK client with JWT authentication.
|
||||||
|
// It reads the JWT key from the given keyPath and signs tokens automatically.
|
||||||
|
func NewSlurmClient(apiURL, userName, jwtKeyPath string) (*slurm.Client, error) {
|
||||||
|
return slurm.NewClientWithOpts(
|
||||||
|
apiURL,
|
||||||
|
slurm.WithUsername(userName),
|
||||||
|
slurm.WithJWTKey(jwtKeyPath),
|
||||||
|
)
|
||||||
|
}
|
||||||
@@ -0,0 +1 @@
|
|||||||
|
DROP TABLE IF EXISTS job_templates;
|
||||||
14
internal/store/migrations/001_create_job_templates.up.sql
Normal file
14
internal/store/migrations/001_create_job_templates.up.sql
Normal file
@@ -0,0 +1,14 @@
|
|||||||
|
CREATE TABLE IF NOT EXISTS job_templates (
|
||||||
|
id BIGINT UNSIGNED AUTO_INCREMENT PRIMARY KEY,
|
||||||
|
name VARCHAR(255) NOT NULL,
|
||||||
|
description TEXT,
|
||||||
|
script TEXT NOT NULL,
|
||||||
|
partition VARCHAR(255),
|
||||||
|
qos VARCHAR(255),
|
||||||
|
cpus INT UNSIGNED,
|
||||||
|
memory VARCHAR(50),
|
||||||
|
time_limit VARCHAR(50),
|
||||||
|
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
||||||
|
updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
|
||||||
|
UNIQUE KEY idx_name (name)
|
||||||
|
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;
|
||||||
44
internal/store/mysql.go
Normal file
44
internal/store/mysql.go
Normal file
@@ -0,0 +1,44 @@
|
|||||||
|
package store
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"gcy_hpc_server/internal/logger"
|
||||||
|
"gcy_hpc_server/internal/model"
|
||||||
|
|
||||||
|
"go.uber.org/zap"
|
||||||
|
"gorm.io/driver/mysql"
|
||||||
|
"gorm.io/gorm"
|
||||||
|
)
|
||||||
|
|
||||||
|
// NewGormDB opens a GORM MySQL connection with sensible defaults.
|
||||||
|
func NewGormDB(dsn string, zapLogger *zap.Logger, gormLevel string) (*gorm.DB, error) {
|
||||||
|
gormCfg := &gorm.Config{
|
||||||
|
Logger: logger.NewGormLogger(zapLogger, gormLevel),
|
||||||
|
}
|
||||||
|
db, err := gorm.Open(mysql.Open(dsn), gormCfg)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to open gorm mysql: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
sqlDB, err := db.DB()
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to get underlying sql.DB: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
sqlDB.SetMaxOpenConns(25)
|
||||||
|
sqlDB.SetMaxIdleConns(5)
|
||||||
|
sqlDB.SetConnMaxLifetime(5 * time.Minute)
|
||||||
|
|
||||||
|
if err := sqlDB.Ping(); err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to ping mysql: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return db, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// AutoMigrate runs GORM auto-migration for all models.
|
||||||
|
func AutoMigrate(db *gorm.DB) error {
|
||||||
|
return db.AutoMigrate(&model.JobTemplate{})
|
||||||
|
}
|
||||||
14
internal/store/mysql_test.go
Normal file
14
internal/store/mysql_test.go
Normal file
@@ -0,0 +1,14 @@
|
|||||||
|
package store
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"go.uber.org/zap"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestNewGormDBInvalidDSN(t *testing.T) {
|
||||||
|
_, err := NewGormDB("invalid:dsn@tcp(localhost:99999)/nonexistent?parseTime=true", zap.NewNop(), "warn")
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected error for invalid DSN, got nil")
|
||||||
|
}
|
||||||
|
}
|
||||||
113
internal/store/template_store.go
Normal file
113
internal/store/template_store.go
Normal file
@@ -0,0 +1,113 @@
|
|||||||
|
package store
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
|
||||||
|
"gorm.io/gorm"
|
||||||
|
|
||||||
|
"gcy_hpc_server/internal/model"
|
||||||
|
)
|
||||||
|
|
||||||
|
// TemplateStore provides CRUD operations for job templates via GORM.
|
||||||
|
type TemplateStore struct {
|
||||||
|
db *gorm.DB
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewTemplateStore creates a new TemplateStore.
|
||||||
|
func NewTemplateStore(db *gorm.DB) *TemplateStore {
|
||||||
|
return &TemplateStore{db: db}
|
||||||
|
}
|
||||||
|
|
||||||
|
// List returns a paginated list of job templates and the total count.
|
||||||
|
func (s *TemplateStore) List(ctx context.Context, page, pageSize int) ([]model.JobTemplate, int, error) {
|
||||||
|
var templates []model.JobTemplate
|
||||||
|
var total int64
|
||||||
|
|
||||||
|
if err := s.db.WithContext(ctx).Model(&model.JobTemplate{}).Count(&total).Error; err != nil {
|
||||||
|
return nil, 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
offset := (page - 1) * pageSize
|
||||||
|
if err := s.db.WithContext(ctx).Order("id DESC").Limit(pageSize).Offset(offset).Find(&templates).Error; err != nil {
|
||||||
|
return nil, 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return templates, int(total), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetByID returns a single job template by ID. Returns nil, nil when not found.
|
||||||
|
func (s *TemplateStore) GetByID(ctx context.Context, id int64) (*model.JobTemplate, error) {
|
||||||
|
var t model.JobTemplate
|
||||||
|
err := s.db.WithContext(ctx).First(&t, id).Error
|
||||||
|
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &t, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create inserts a new job template and returns the generated ID.
|
||||||
|
func (s *TemplateStore) Create(ctx context.Context, req *model.CreateTemplateRequest) (int64, error) {
|
||||||
|
t := &model.JobTemplate{
|
||||||
|
Name: req.Name,
|
||||||
|
Description: req.Description,
|
||||||
|
Script: req.Script,
|
||||||
|
Partition: req.Partition,
|
||||||
|
QOS: req.QOS,
|
||||||
|
CPUs: req.CPUs,
|
||||||
|
Memory: req.Memory,
|
||||||
|
TimeLimit: req.TimeLimit,
|
||||||
|
}
|
||||||
|
if err := s.db.WithContext(ctx).Create(t).Error; err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
return t.ID, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Update modifies an existing job template. Only non-empty/non-zero fields are updated.
|
||||||
|
func (s *TemplateStore) Update(ctx context.Context, id int64, req *model.UpdateTemplateRequest) error {
|
||||||
|
updates := map[string]interface{}{}
|
||||||
|
if req.Name != "" {
|
||||||
|
updates["name"] = req.Name
|
||||||
|
}
|
||||||
|
if req.Description != "" {
|
||||||
|
updates["description"] = req.Description
|
||||||
|
}
|
||||||
|
if req.Script != "" {
|
||||||
|
updates["script"] = req.Script
|
||||||
|
}
|
||||||
|
if req.Partition != "" {
|
||||||
|
updates["partition"] = req.Partition
|
||||||
|
}
|
||||||
|
if req.QOS != "" {
|
||||||
|
updates["qos"] = req.QOS
|
||||||
|
}
|
||||||
|
if req.CPUs > 0 {
|
||||||
|
updates["cpus"] = req.CPUs
|
||||||
|
}
|
||||||
|
if req.Memory != "" {
|
||||||
|
updates["memory"] = req.Memory
|
||||||
|
}
|
||||||
|
if req.TimeLimit != "" {
|
||||||
|
updates["time_limit"] = req.TimeLimit
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(updates) == 0 {
|
||||||
|
return nil // nothing to update
|
||||||
|
}
|
||||||
|
|
||||||
|
result := s.db.WithContext(ctx).Model(&model.JobTemplate{}).Where("id = ?", id).Updates(updates)
|
||||||
|
return result.Error
|
||||||
|
}
|
||||||
|
|
||||||
|
// Delete removes a job template by ID. Idempotent — returns nil even if the row doesn't exist.
|
||||||
|
func (s *TemplateStore) Delete(ctx context.Context, id int64) error {
|
||||||
|
result := s.db.WithContext(ctx).Delete(&model.JobTemplate{}, id)
|
||||||
|
if result.Error != nil {
|
||||||
|
return result.Error
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
205
internal/store/template_store_test.go
Normal file
205
internal/store/template_store_test.go
Normal file
@@ -0,0 +1,205 @@
|
|||||||
|
package store
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"gorm.io/driver/sqlite"
|
||||||
|
"gorm.io/gorm"
|
||||||
|
"gorm.io/gorm/logger"
|
||||||
|
|
||||||
|
"gcy_hpc_server/internal/model"
|
||||||
|
)
|
||||||
|
|
||||||
|
func newTestDB(t *testing.T) *gorm.DB {
|
||||||
|
t.Helper()
|
||||||
|
db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{
|
||||||
|
Logger: logger.Default.LogMode(logger.Silent),
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("open sqlite: %v", err)
|
||||||
|
}
|
||||||
|
if err := db.AutoMigrate(&model.JobTemplate{}); err != nil {
|
||||||
|
t.Fatalf("auto migrate: %v", err)
|
||||||
|
}
|
||||||
|
return db
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTemplateStore_List(t *testing.T) {
|
||||||
|
db := newTestDB(t)
|
||||||
|
s := NewTemplateStore(db)
|
||||||
|
|
||||||
|
s.Create(context.Background(), &model.CreateTemplateRequest{Name: "job-1", Script: "echo 1"})
|
||||||
|
s.Create(context.Background(), &model.CreateTemplateRequest{Name: "job-2", Script: "echo 2"})
|
||||||
|
|
||||||
|
templates, total, err := s.List(context.Background(), 1, 10)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("List() error = %v", err)
|
||||||
|
}
|
||||||
|
if total != 2 {
|
||||||
|
t.Errorf("total = %d, want 2", total)
|
||||||
|
}
|
||||||
|
if len(templates) != 2 {
|
||||||
|
t.Fatalf("len(templates) = %d, want 2", len(templates))
|
||||||
|
}
|
||||||
|
// DESC order, so job-2 is first
|
||||||
|
if templates[0].Name != "job-2" {
|
||||||
|
t.Errorf("templates[0].Name = %q, want %q", templates[0].Name, "job-2")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTemplateStore_List_Page2(t *testing.T) {
|
||||||
|
db := newTestDB(t)
|
||||||
|
s := NewTemplateStore(db)
|
||||||
|
|
||||||
|
for i := 0; i < 15; i++ {
|
||||||
|
s.Create(context.Background(), &model.CreateTemplateRequest{
|
||||||
|
Name: "job-" + string(rune('A'+i)), Script: "echo",
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
templates, total, err := s.List(context.Background(), 2, 10)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("List() error = %v", err)
|
||||||
|
}
|
||||||
|
if total != 15 {
|
||||||
|
t.Errorf("total = %d, want 15", total)
|
||||||
|
}
|
||||||
|
if len(templates) != 5 {
|
||||||
|
t.Fatalf("len(templates) = %d, want 5", len(templates))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTemplateStore_GetByID(t *testing.T) {
|
||||||
|
db := newTestDB(t)
|
||||||
|
s := NewTemplateStore(db)
|
||||||
|
|
||||||
|
id, _ := s.Create(context.Background(), &model.CreateTemplateRequest{
|
||||||
|
Name: "test-job", Script: "echo hi", Partition: "batch", QOS: "normal", CPUs: 2, Memory: "4G",
|
||||||
|
})
|
||||||
|
|
||||||
|
tpl, err := s.GetByID(context.Background(), id)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("GetByID() error = %v", err)
|
||||||
|
}
|
||||||
|
if tpl == nil {
|
||||||
|
t.Fatal("GetByID() returned nil")
|
||||||
|
}
|
||||||
|
if tpl.Name != "test-job" {
|
||||||
|
t.Errorf("Name = %q, want %q", tpl.Name, "test-job")
|
||||||
|
}
|
||||||
|
if tpl.CPUs != 2 {
|
||||||
|
t.Errorf("CPUs = %d, want 2", tpl.CPUs)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTemplateStore_GetByID_NotFound(t *testing.T) {
|
||||||
|
db := newTestDB(t)
|
||||||
|
s := NewTemplateStore(db)
|
||||||
|
|
||||||
|
tpl, err := s.GetByID(context.Background(), 999)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("GetByID() error = %v, want nil", err)
|
||||||
|
}
|
||||||
|
if tpl != nil {
|
||||||
|
t.Fatal("GetByID() should return nil for not found")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTemplateStore_Create(t *testing.T) {
|
||||||
|
db := newTestDB(t)
|
||||||
|
s := NewTemplateStore(db)
|
||||||
|
|
||||||
|
id, err := s.Create(context.Background(), &model.CreateTemplateRequest{
|
||||||
|
Name: "new-job", Script: "echo", Partition: "gpu",
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Create() error = %v", err)
|
||||||
|
}
|
||||||
|
if id == 0 {
|
||||||
|
t.Fatal("Create() returned id=0")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTemplateStore_Update(t *testing.T) {
|
||||||
|
db := newTestDB(t)
|
||||||
|
s := NewTemplateStore(db)
|
||||||
|
|
||||||
|
id, _ := s.Create(context.Background(), &model.CreateTemplateRequest{
|
||||||
|
Name: "old", Script: "echo",
|
||||||
|
})
|
||||||
|
|
||||||
|
err := s.Update(context.Background(), id, &model.UpdateTemplateRequest{
|
||||||
|
Name: "updated",
|
||||||
|
Script: "echo new",
|
||||||
|
CPUs: 8,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Update() error = %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
tpl, _ := s.GetByID(context.Background(), id)
|
||||||
|
if tpl.Name != "updated" {
|
||||||
|
t.Errorf("Name = %q, want %q", tpl.Name, "updated")
|
||||||
|
}
|
||||||
|
if tpl.CPUs != 8 {
|
||||||
|
t.Errorf("CPUs = %d, want 8", tpl.CPUs)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTemplateStore_Update_Partial(t *testing.T) {
|
||||||
|
db := newTestDB(t)
|
||||||
|
s := NewTemplateStore(db)
|
||||||
|
|
||||||
|
id, _ := s.Create(context.Background(), &model.CreateTemplateRequest{
|
||||||
|
Name: "original", Script: "echo orig", Partition: "batch",
|
||||||
|
})
|
||||||
|
|
||||||
|
err := s.Update(context.Background(), id, &model.UpdateTemplateRequest{
|
||||||
|
Name: "renamed",
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Update() error = %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
tpl, _ := s.GetByID(context.Background(), id)
|
||||||
|
if tpl.Name != "renamed" {
|
||||||
|
t.Errorf("Name = %q, want %q", tpl.Name, "renamed")
|
||||||
|
}
|
||||||
|
// Script and Partition should be unchanged
|
||||||
|
if tpl.Script != "echo orig" {
|
||||||
|
t.Errorf("Script = %q, want %q", tpl.Script, "echo orig")
|
||||||
|
}
|
||||||
|
if tpl.Partition != "batch" {
|
||||||
|
t.Errorf("Partition = %q, want %q", tpl.Partition, "batch")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTemplateStore_Delete(t *testing.T) {
|
||||||
|
db := newTestDB(t)
|
||||||
|
s := NewTemplateStore(db)
|
||||||
|
|
||||||
|
id, _ := s.Create(context.Background(), &model.CreateTemplateRequest{
|
||||||
|
Name: "to-delete", Script: "echo",
|
||||||
|
})
|
||||||
|
|
||||||
|
err := s.Delete(context.Background(), id)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Delete() error = %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
tpl, _ := s.GetByID(context.Background(), id)
|
||||||
|
if tpl != nil {
|
||||||
|
t.Fatal("Delete() did not remove the record")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTemplateStore_Delete_NotFound(t *testing.T) {
|
||||||
|
db := newTestDB(t)
|
||||||
|
s := NewTemplateStore(db)
|
||||||
|
|
||||||
|
err := s.Delete(context.Background(), 999)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Delete() should not error for non-existent record, got: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user