Compare commits

..

7 Commits

Author SHA1 Message Date
dailz
4ff02d4a80 fix: 移除 main() 中多余的 defer application.Close()
Run() 在所有退出路径中已调用 Close(),main 中的 defer 是冗余的。

Ultraworked with [Sisyphus](https://github.com/code-yeongyu/oh-my-openagent)

Co-authored-by: Sisyphus <clio-agent@sisyphuslabs.ai>
2026-04-10 08:48:42 +08:00
dailz
1784331969 feat: 添加应用骨架,配置化 zap 日志贯穿全链路
- cmd/server/main.go: 使用 logger.NewLogger(cfg.Log) 替代 zap.NewProduction()

- internal/app: 依赖注入组装 DB/Slurm/Service/Handler,传递 logger

- internal/middleware: RequestLogger 请求日志中间件

- internal/server: 统一响应格式和路由注册

- go.mod: module 更名为 gcy_hpc_server,添加 gin/zap/lumberjack/gorm 依赖

- 日志初始化失败时 fail fast (os.Exit(1))

- GormLevel 从配置传递到 NewGormDB,支持 YAML 独立配置

Ultraworked with [Sisyphus](https://github.com/code-yeongyu/oh-my-openagent)

Co-authored-by: Sisyphus <clio-agent@sisyphuslabs.ai>
2026-04-10 08:40:16 +08:00
dailz
e6162063ca feat: 添加 HTTP 处理层和结构化日志
- JobHandler: 提交/查询/取消/历史,5xx Error + 4xx Warn 日志

- ClusterHandler: 节点/分区/诊断,错误和未找到日志

- TemplateHandler: CRUD 操作,创建/更新/删除 Info + 未找到 Warn

- 不记录成功响应(由 middleware.RequestLogger 处理)

- 不记录请求体和模板内容(安全考虑)

- 完整 TDD 测试,使用 zaptest/observer 验证日志级别和字段

Ultraworked with [Sisyphus](https://github.com/code-yeongyu/oh-my-openagent)

Co-authored-by: Sisyphus <clio-agent@sisyphuslabs.ai>
2026-04-10 08:40:06 +08:00
dailz
4903f7d07f feat: 添加业务服务层和结构化日志
- JobService: 提交、查询、取消、历史记录,记录关键操作日志

- ClusterService: 节点、分区、诊断查询,记录错误日志

- NewSlurmClient: JWT 认证 HTTP 客户端工厂

- 所有构造函数接受 *zap.Logger 参数实现依赖注入

- 提交/取消成功记录 Info,API 错误记录 Error

- 完整 TDD 测试,使用 zaptest/observer 验证日志输出

Ultraworked with [Sisyphus](https://github.com/code-yeongyu/oh-my-openagent)

Co-authored-by: Sisyphus <clio-agent@sisyphuslabs.ai>
2026-04-10 08:39:46 +08:00
dailz
fbfd5c5f42 feat: 添加数据模型和存储层
- model: JobTemplate、SubmitJobRequest、JobHistoryQuery 等模型定义

- store: NewGormDB MySQL 连接池,使用 zap 日志替代 GORM 默认日志

- store: TemplateStore CRUD 操作,支持 GORM AutoMigrate

- NewGormDB 接受 gormLevel 参数,由上层传入配置值

- 完整 TDD 测试覆盖

Ultraworked with [Sisyphus](https://github.com/code-yeongyu/oh-my-openagent)

Co-authored-by: Sisyphus <clio-agent@sisyphuslabs.ai>
2026-04-10 08:39:30 +08:00
dailz
f7a21ee455 feat: 添加 zap 日志工厂和 GORM 日志桥接
- NewLogger 工厂函数:支持 JSON/Console 编码、stdout/文件/多输出、lumberjack 轮转

- NewGormLogger 实现 gorm.Interface:Trace 区分错误/慢查询/正常查询

- output_stdout 用 *bool 三态处理(nil=true, true, false)

- 默认值:level=info, encoding=json, max_size=100, max_backups=5, max_age=30

- 慢查询阈值 200ms,ErrRecordNotFound 不视为错误

- 编译时接口检查: var _ gormlogger.Interface = (*GormLogger)(nil)

- 完整 TDD 测试覆盖

Ultraworked with [Sisyphus](https://github.com/code-yeongyu/oh-my-openagent)

Co-authored-by: Sisyphus <clio-agent@sisyphuslabs.ai>
2026-04-10 08:39:21 +08:00
dailz
7550e75945 feat: 添加配置加载和日志配置支持
- 新增 LogConfig 结构体,支持 9 个日志配置字段(level, encoding, output_stdout, file_path, max_size, max_backups, max_age, compress, gorm_level)

- Config 结构体新增 Log 字段,支持 YAML 解析

- output_stdout 使用 *bool 指针类型,nil 默认为 true

- 更新 config.example.yaml 添加完整 log 配置段

- 新增 TDD 测试:日志配置解析、向后兼容、字段完整性

Ultraworked with [Sisyphus](https://github.com/code-yeongyu/oh-my-openagent)

Co-authored-by: Sisyphus <clio-agent@sisyphuslabs.ai>
2026-04-10 08:39:09 +08:00
38 changed files with 6575 additions and 2 deletions

39
cmd/server/main.go Normal file
View File

@@ -0,0 +1,39 @@
package main
import (
"fmt"
"os"
"gcy_hpc_server/internal/app"
"gcy_hpc_server/internal/config"
"gcy_hpc_server/internal/logger"
"go.uber.org/zap"
)
// main is the process entry point: it loads configuration (path optionally
// given as the first CLI argument), builds the zap logger, wires the
// application together, and runs it until shutdown or fatal error.
// Config/logger failures are reported to stderr because no logger exists yet.
func main() {
	var configPath string
	if len(os.Args) > 1 {
		configPath = os.Args[1]
	}

	cfg, err := config.Load(configPath)
	if err != nil {
		fmt.Fprintf(os.Stderr, "failed to load config: %v\n", err)
		os.Exit(1)
	}

	zapLogger, err := logger.NewLogger(cfg.Log)
	if err != nil {
		fmt.Fprintf(os.Stderr, "failed to init logger: %v\n", err)
		os.Exit(1)
	}
	// Best-effort flush of buffered log entries on normal return.
	defer zapLogger.Sync()

	application, err := app.NewApp(cfg, zapLogger)
	if err != nil {
		zapLogger.Fatal("failed to initialize application", zap.Error(err))
	}
	if err := application.Run(); err != nil {
		zapLogger.Fatal("application error", zap.Error(err))
	}
}

117
cmd/server/main_test.go Normal file
View File

@@ -0,0 +1,117 @@
package main
import (
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"gcy_hpc_server/internal/handler"
"gcy_hpc_server/internal/model"
"gcy_hpc_server/internal/server"
"gcy_hpc_server/internal/service"
"gcy_hpc_server/internal/slurm"
"gcy_hpc_server/internal/store"
"go.uber.org/zap"
"gorm.io/driver/sqlite"
"gorm.io/gorm"
"gorm.io/gorm/logger"
)
// newTestDB opens an in-memory SQLite database with GORM logging silenced
// and migrates the JobTemplate schema for use as a test fixture.
//
// The original ignored both the gorm.Open and AutoMigrate errors, so a
// broken fixture surfaced later as a confusing nil-pointer crash inside a
// test body; panicking here fails fast at the point of the real problem.
func newTestDB() *gorm.DB {
	db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{Logger: logger.Default.LogMode(logger.Silent)})
	if err != nil {
		panic(err) // fixture setup failure: in-memory sqlite should always open
	}
	if err := db.AutoMigrate(&model.JobTemplate{}); err != nil {
		panic(err) // schema migration failure makes every dependent test meaningless
	}
	return db
}
// TestRouterRegistration builds the real router against a stub Slurm backend
// and asserts that every expected "METHOD /path" pair is registered, and
// that at least that many routes exist overall.
func TestRouterRegistration(t *testing.T) {
	stubSlurm := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.Header().Set("Content-Type", "application/json")
		w.WriteHeader(http.StatusOK)
		json.NewEncoder(w).Encode(map[string]interface{}{"jobs": []interface{}{}})
	}))
	defer stubSlurm.Close()

	client, _ := slurm.NewClientWithOpts(stubSlurm.URL, slurm.WithHTTPClient(stubSlurm.Client()))
	router := server.NewRouter(
		handler.NewJobHandler(service.NewJobService(client, zap.NewNop()), zap.NewNop()),
		handler.NewClusterHandler(service.NewClusterService(client, zap.NewNop()), zap.NewNop()),
		handler.NewTemplateHandler(store.NewTemplateStore(newTestDB()), zap.NewNop()),
		nil,
	)

	routes := router.Routes()
	registered := make(map[string]bool, len(routes))
	for _, route := range routes {
		registered[route.Method+" "+route.Path] = true
	}

	wanted := []string{
		"POST /api/v1/jobs/submit",
		"GET /api/v1/jobs",
		"GET /api/v1/jobs/history",
		"GET /api/v1/jobs/:id",
		"DELETE /api/v1/jobs/:id",
		"GET /api/v1/nodes",
		"GET /api/v1/nodes/:name",
		"GET /api/v1/partitions",
		"GET /api/v1/partitions/:name",
		"GET /api/v1/diag",
		"GET /api/v1/templates",
		"POST /api/v1/templates",
		"GET /api/v1/templates/:id",
		"PUT /api/v1/templates/:id",
		"DELETE /api/v1/templates/:id",
	}
	for _, key := range wanted {
		if !registered[key] {
			t.Errorf("missing route: %s", key)
		}
	}
	if len(routes) < len(wanted) {
		t.Errorf("expected at least %d routes, got %d", len(wanted), len(routes))
	}
}
// TestSmokeGetJobsEndpoint drives a full GET /api/v1/jobs request through the
// assembled router against a stub Slurm backend, expecting HTTP 200 and a
// response envelope with success=true.
func TestSmokeGetJobsEndpoint(t *testing.T) {
	stubSlurm := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.Header().Set("Content-Type", "application/json")
		json.NewEncoder(w).Encode(map[string]interface{}{"jobs": []interface{}{}})
	}))
	defer stubSlurm.Close()

	client, _ := slurm.NewClientWithOpts(stubSlurm.URL, slurm.WithHTTPClient(stubSlurm.Client()))
	router := server.NewRouter(
		handler.NewJobHandler(service.NewJobService(client, zap.NewNop()), zap.NewNop()),
		handler.NewClusterHandler(service.NewClusterService(client, zap.NewNop()), zap.NewNop()),
		handler.NewTemplateHandler(store.NewTemplateStore(newTestDB()), zap.NewNop()),
		nil,
	)

	rec := httptest.NewRecorder()
	req, _ := http.NewRequest(http.MethodGet, "/api/v1/jobs", nil)
	router.ServeHTTP(rec, req)

	if rec.Code != http.StatusOK {
		t.Fatalf("expected 200, got %d: %s", rec.Code, rec.Body.String())
	}
	var resp map[string]interface{}
	if err := json.Unmarshal(rec.Body.Bytes(), &resp); err != nil {
		t.Fatalf("failed to parse response: %v", err)
	}
	if success, ok := resp["success"].(bool); !ok || !success {
		t.Fatalf("expected success=true, got %v", resp["success"])
	}
}

54
go.mod
View File

@@ -1,3 +1,53 @@
module slurm-client
module gcy_hpc_server
go 1.22
go 1.25.0
require (
github.com/gin-gonic/gin v1.12.0
go.uber.org/zap v1.27.1
gopkg.in/natefinch/lumberjack.v2 v2.2.1
gopkg.in/yaml.v3 v3.0.1
gorm.io/driver/mysql v1.6.0
gorm.io/driver/sqlite v1.6.0
gorm.io/gorm v1.31.1
)
require (
filippo.io/edwards25519 v1.1.0 // indirect
github.com/bytedance/gopkg v0.1.3 // indirect
github.com/bytedance/sonic v1.15.0 // indirect
github.com/bytedance/sonic/loader v0.5.0 // indirect
github.com/cloudwego/base64x v0.1.6 // indirect
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect
github.com/gabriel-vasile/mimetype v1.4.12 // indirect
github.com/gin-contrib/sse v1.1.0 // indirect
github.com/go-playground/locales v0.14.1 // indirect
github.com/go-playground/universal-translator v0.18.1 // indirect
github.com/go-playground/validator/v10 v10.30.1 // indirect
github.com/go-sql-driver/mysql v1.9.3 // indirect
github.com/goccy/go-json v0.10.5 // indirect
github.com/goccy/go-yaml v1.19.2 // indirect
github.com/jinzhu/inflection v1.0.0 // indirect
github.com/jinzhu/now v1.1.5 // indirect
github.com/json-iterator/go v1.1.12 // indirect
github.com/klauspost/cpuid/v2 v2.3.0 // indirect
github.com/leodido/go-urn v1.4.0 // indirect
github.com/mattn/go-isatty v0.0.20 // indirect
github.com/mattn/go-sqlite3 v1.14.22 // indirect
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
github.com/modern-go/reflect2 v1.0.2 // indirect
github.com/pelletier/go-toml/v2 v2.2.4 // indirect
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect
github.com/quic-go/qpack v0.6.0 // indirect
github.com/quic-go/quic-go v0.59.0 // indirect
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
github.com/ugorji/go/codec v1.3.1 // indirect
go.mongodb.org/mongo-driver/v2 v2.5.0 // indirect
go.uber.org/multierr v1.10.0 // indirect
golang.org/x/arch v0.22.0 // indirect
golang.org/x/crypto v0.48.0 // indirect
golang.org/x/net v0.51.0 // indirect
golang.org/x/sys v0.41.0 // indirect
golang.org/x/text v0.34.0 // indirect
google.golang.org/protobuf v1.36.10 // indirect
)

123
go.sum Normal file
View File

@@ -0,0 +1,123 @@
filippo.io/edwards25519 v1.1.0 h1:FNf4tywRC1HmFuKW5xopWpigGjJKiJSV0Cqo0cJWDaA=
filippo.io/edwards25519 v1.1.0/go.mod h1:BxyFTGdWcka3PhytdK4V28tE5sGfRvvvRV7EaN4VDT4=
github.com/bytedance/gopkg v0.1.3 h1:TPBSwH8RsouGCBcMBktLt1AymVo2TVsBVCY4b6TnZ/M=
github.com/bytedance/gopkg v0.1.3/go.mod h1:576VvJ+eJgyCzdjS+c4+77QF3p7ubbtiKARP3TxducM=
github.com/bytedance/sonic v1.15.0 h1:/PXeWFaR5ElNcVE84U0dOHjiMHQOwNIx3K4ymzh/uSE=
github.com/bytedance/sonic v1.15.0/go.mod h1:tFkWrPz0/CUCLEF4ri4UkHekCIcdnkqXw9VduqpJh0k=
github.com/bytedance/sonic/loader v0.5.0 h1:gXH3KVnatgY7loH5/TkeVyXPfESoqSBSBEiDd5VjlgE=
github.com/bytedance/sonic/loader v0.5.0/go.mod h1:AR4NYCk5DdzZizZ5djGqQ92eEhCCcdf5x77udYiSJRo=
github.com/cloudwego/base64x v0.1.6 h1:t11wG9AECkCDk5fMSoxmufanudBtJ+/HemLstXDLI2M=
github.com/cloudwego/base64x v0.1.6/go.mod h1:OFcloc187FXDaYHvrNIjxSe8ncn0OOM8gEHfghB2IPU=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM=
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/gabriel-vasile/mimetype v1.4.12 h1:e9hWvmLYvtp846tLHam2o++qitpguFiYCKbn0w9jyqw=
github.com/gabriel-vasile/mimetype v1.4.12/go.mod h1:d+9Oxyo1wTzWdyVUPMmXFvp4F9tea18J8ufA774AB3s=
github.com/gin-contrib/sse v1.1.0 h1:n0w2GMuUpWDVp7qSpvze6fAu9iRxJY4Hmj6AmBOU05w=
github.com/gin-contrib/sse v1.1.0/go.mod h1:hxRZ5gVpWMT7Z0B0gSNYqqsSCNIJMjzvm6fqCz9vjwM=
github.com/gin-gonic/gin v1.12.0 h1:b3YAbrZtnf8N//yjKeU2+MQsh2mY5htkZidOM7O0wG8=
github.com/gin-gonic/gin v1.12.0/go.mod h1:VxccKfsSllpKshkBWgVgRniFFAzFb9csfngsqANjnLc=
github.com/go-playground/assert/v2 v2.2.0 h1:JvknZsQTYeFEAhQwI4qEt9cyV5ONwRHC+lYKSsYSR8s=
github.com/go-playground/assert/v2 v2.2.0/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4=
github.com/go-playground/locales v0.14.1 h1:EWaQ/wswjilfKLTECiXz7Rh+3BjFhfDFKv/oXslEjJA=
github.com/go-playground/locales v0.14.1/go.mod h1:hxrqLVvrK65+Rwrd5Fc6F2O76J/NuW9t0sjnWqG1slY=
github.com/go-playground/universal-translator v0.18.1 h1:Bcnm0ZwsGyWbCzImXv+pAJnYK9S473LQFuzCbDbfSFY=
github.com/go-playground/universal-translator v0.18.1/go.mod h1:xekY+UJKNuX9WP91TpwSH2VMlDf28Uj24BCp08ZFTUY=
github.com/go-playground/validator/v10 v10.30.1 h1:f3zDSN/zOma+w6+1Wswgd9fLkdwy06ntQJp0BBvFG0w=
github.com/go-playground/validator/v10 v10.30.1/go.mod h1:oSuBIQzuJxL//3MelwSLD5hc2Tu889bF0Idm9Dg26cM=
github.com/go-sql-driver/mysql v1.9.3 h1:U/N249h2WzJ3Ukj8SowVFjdtZKfu9vlLZxjPXV1aweo=
github.com/go-sql-driver/mysql v1.9.3/go.mod h1:qn46aNg1333BRMNU69Lq93t8du/dwxI64Gl8i5p1WMU=
github.com/goccy/go-json v0.10.5 h1:Fq85nIqj+gXn/S5ahsiTlK3TmC85qgirsdTP/+DeaC4=
github.com/goccy/go-json v0.10.5/go.mod h1:oq7eo15ShAhp70Anwd5lgX2pLfOS3QCiwU/PULtXL6M=
github.com/goccy/go-yaml v1.19.2 h1:PmFC1S6h8ljIz6gMRBopkjP1TVT7xuwrButHID66PoM=
github.com/goccy/go-yaml v1.19.2/go.mod h1:XBurs7gK8ATbW4ZPGKgcbrY1Br56PdM69F7LkFRi1kA=
github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU=
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E=
github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc=
github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ=
github.com/jinzhu/now v1.1.5/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8=
github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM=
github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo=
github.com/klauspost/cpuid/v2 v2.3.0 h1:S4CRMLnYUhGeDFDqkGriYKdfoFlDnMtqTiI/sFzhA9Y=
github.com/klauspost/cpuid/v2 v2.3.0/go.mod h1:hqwkgyIinND0mEev00jJYCxPNVRVXFQeu1XKlok6oO0=
github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE=
github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk=
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
github.com/leodido/go-urn v1.4.0 h1:WT9HwE9SGECu3lg4d/dIA+jxlljEa1/ffXKmRjqdmIQ=
github.com/leodido/go-urn v1.4.0/go.mod h1:bvxc+MVxLKB4z00jd1z+Dvzr47oO32F/QSNjSBOlFxI=
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
github.com/mattn/go-sqlite3 v1.14.22 h1:2gZY6PC6kBnID23Tichd1K+Z0oS6nE/XwU+Vz/5o4kU=
github.com/mattn/go-sqlite3 v1.14.22/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y=
github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w8PVh93nsPXa1VrQ6jlwL5oN8l14QlcNfg=
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9Gz0M=
github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk=
github.com/pelletier/go-toml/v2 v2.2.4 h1:mye9XuhQ6gvn5h28+VilKrrPoQVanw5PMw/TB0t5Ec4=
github.com/pelletier/go-toml/v2 v2.2.4/go.mod h1:2gIqNv+qfxSVS7cM2xJQKtLSTLUE9V8t9Stt+h56mCY=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRIccs7FGNTlIRMkT8wgtp5eCXdBlqhYGL6U=
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/quic-go/qpack v0.6.0 h1:g7W+BMYynC1LbYLSqRt8PBg5Tgwxn214ZZR34VIOjz8=
github.com/quic-go/qpack v0.6.0/go.mod h1:lUpLKChi8njB4ty2bFLX2x4gzDqXwUpaO1DP9qMDZII=
github.com/quic-go/quic-go v0.59.0 h1:OLJkp1Mlm/aS7dpKgTc6cnpynnD2Xg7C1pwL6vy/SAw=
github.com/quic-go/quic-go v0.59.0/go.mod h1:upnsH4Ju1YkqpLXC305eW3yDZ4NfnNbmQRCMWS58IKU=
github.com/rogpeppe/go-internal v1.10.0 h1:TMyTOH3F/DB16zRVcYyreMH6GnZZrwQVAoYjRBZyWFQ=
github.com/rogpeppe/go-internal v1.10.0/go.mod h1:UQnix2H7Ngw/k4C5ijL5+65zddjncjaFoBhdsK/akog=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA=
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS4MhqMhdFk5YI=
github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08=
github.com/ugorji/go/codec v1.3.1 h1:waO7eEiFDwidsBN6agj1vJQ4AG7lh2yqXyOXqhgQuyY=
github.com/ugorji/go/codec v1.3.1/go.mod h1:pRBVtBSKl77K30Bv8R2P+cLSGaTtex6fsA2Wjqmfxj4=
go.mongodb.org/mongo-driver/v2 v2.5.0 h1:yXUhImUjjAInNcpTcAlPHiT7bIXhshCTL3jVBkF3xaE=
go.mongodb.org/mongo-driver/v2 v2.5.0/go.mod h1:yOI9kBsufol30iFsl1slpdq1I0eHPzybRWdyYUs8K/0=
go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto=
go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE=
go.uber.org/mock v0.6.0 h1:hyF9dfmbgIX5EfOdasqLsWD6xqpNZlXblLB/Dbnwv3Y=
go.uber.org/mock v0.6.0/go.mod h1:KiVJ4BqZJaMj4svdfmHM0AUx4NJYO8ZNpPnZn1Z+BBU=
go.uber.org/multierr v1.10.0 h1:S0h4aNzvfcFsC3dRF1jLoaov7oRaKqRGC/pUEJ2yvPQ=
go.uber.org/multierr v1.10.0/go.mod h1:20+QtiLqy0Nd6FdQB9TLXag12DsQkrbs3htMFfDN80Y=
go.uber.org/zap v1.27.1 h1:08RqriUEv8+ArZRYSTXy1LeBScaMpVSTBhCeaZYfMYc=
go.uber.org/zap v1.27.1/go.mod h1:GB2qFLM7cTU87MWRP2mPIjqfIDnGu+VIO4V/SdhGo2E=
golang.org/x/arch v0.22.0 h1:c/Zle32i5ttqRXjdLyyHZESLD/bB90DCU1g9l/0YBDI=
golang.org/x/arch v0.22.0/go.mod h1:dNHoOeKiyja7GTvF9NJS1l3Z2yntpQNzgrjh1cU103A=
golang.org/x/crypto v0.48.0 h1:/VRzVqiRSggnhY7gNRxPauEQ5Drw9haKdM0jqfcCFts=
golang.org/x/crypto v0.48.0/go.mod h1:r0kV5h3qnFPlQnBSrULhlsRfryS2pmewsg+XfMgkVos=
golang.org/x/net v0.51.0 h1:94R/GTO7mt3/4wIKpcR5gkGmRLOuE/2hNGeWq/GBIFo=
golang.org/x/net v0.51.0/go.mod h1:aamm+2QF5ogm02fjy5Bb7CQ0WMt1/WVM7FtyaTLlA9Y=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.41.0 h1:Ivj+2Cp/ylzLiEU89QhWblYnOE9zerudt9Ftecq2C6k=
golang.org/x/sys v0.41.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
golang.org/x/text v0.34.0 h1:oL/Qq0Kdaqxa1KbNeMKwQq0reLCCaFtqu2eNuSeNHbk=
golang.org/x/text v0.34.0/go.mod h1:homfLqTYRFyVYemLBFl5GgL/DWEiH5wcsQ5gSh1yziA=
google.golang.org/protobuf v1.36.10 h1:AYd7cD/uASjIL6Q9LiTjz8JLcrh/88q5UObnmY3aOOE=
google.golang.org/protobuf v1.36.10/go.mod h1:HTf+CrKn2C3g5S8VImy6tdcUvCska2kB7j23XfzDpco=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q=
gopkg.in/natefinch/lumberjack.v2 v2.2.1 h1:bBRl1b0OH9s/DuPhuXpNl+VtCaJXFZ5/uEFST95x9zc=
gopkg.in/natefinch/lumberjack.v2 v2.2.1/go.mod h1:YD8tP3GAjkrDg1eZH7EGmyESg/lsYskCTPBJVb9jqSc=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gorm.io/driver/mysql v1.6.0 h1:eNbLmNTpPpTOVZi8MMxCi2aaIm0ZpInbORNXDwyLGvg=
gorm.io/driver/mysql v1.6.0/go.mod h1:D/oCC2GWK3M/dqoLxnOlaNKmXz8WNTfcS9y5ovaSqKo=
gorm.io/driver/sqlite v1.6.0 h1:WHRRrIiulaPiPFmDcod6prc4l2VGVWHz80KspNsxSfQ=
gorm.io/driver/sqlite v1.6.0/go.mod h1:AO9V1qIQddBESngQUKWL9yoH93HIeA1X6V633rBwyT8=
gorm.io/gorm v1.31.1 h1:7CA8FTFz/gRfgqgpeKIBcervUn3xSyPUmr6B2WXJ7kg=
gorm.io/gorm v1.31.1/go.mod h1:XyQVbO2k6YkOis7C2437jSit3SsDK72s7n7rsSHd+Gs=

159
internal/app/app.go Normal file
View File

@@ -0,0 +1,159 @@
package app
import (
"context"
"errors"
"fmt"
"net/http"
"os"
"os/signal"
"syscall"
"time"
"gcy_hpc_server/internal/config"
"gcy_hpc_server/internal/handler"
"gcy_hpc_server/internal/server"
"gcy_hpc_server/internal/service"
"gcy_hpc_server/internal/slurm"
"gcy_hpc_server/internal/store"
"go.uber.org/zap"
"gorm.io/gorm"
)
// App encapsulates the entire application lifecycle.
//
// It bundles the parsed configuration, the shared zap logger, the GORM
// database handle, and the HTTP server; Run drives the server and Close
// releases the server and database resources.
type App struct {
	cfg    *config.Config // parsed application configuration
	logger *zap.Logger    // shared structured logger, injected by the caller
	db     *gorm.DB       // database handle; its sql.DB is closed in Close
	server *http.Server   // HTTP server; gracefully shut down in Close
}
// NewApp wires together every application dependency — database, Slurm
// client, services, handlers, and router — and returns the assembled App.
// If a later step fails, resources opened by earlier steps are released.
func NewApp(cfg *config.Config, logger *zap.Logger) (*App, error) {
	db, err := initDB(cfg, logger)
	if err != nil {
		return nil, err
	}

	slurmClient, err := initSlurmClient(cfg)
	if err != nil {
		// The DB connection pool is already open; don't leak it.
		closeDB(db)
		return nil, err
	}

	return &App{
		cfg:    cfg,
		logger: logger,
		db:     db,
		server: initHTTPServer(cfg, db, slurmClient, logger),
	}, nil
}
// Run starts the HTTP server in a background goroutine and blocks until
// either the server fails or a SIGINT/SIGTERM arrives, then releases all
// resources via Close. The returned error is the server failure, if any,
// otherwise whatever Close reports.
func (a *App) Run() error {
	serverErr := make(chan error, 1) // buffered so the goroutine never blocks
	go func() {
		a.logger.Info("starting server", zap.String("addr", a.server.Addr))
		err := a.server.ListenAndServe()
		// ErrServerClosed is the normal result of a graceful Shutdown.
		if err != nil && err != http.ErrServerClosed {
			serverErr <- fmt.Errorf("server listen: %w", err)
		}
	}()

	signals := make(chan os.Signal, 1)
	signal.Notify(signals, syscall.SIGINT, syscall.SIGTERM)

	select {
	case sig := <-signals:
		a.logger.Info("received shutdown signal", zap.String("signal", sig.String()))
	case err := <-serverErr:
		// Server crashed before receiving a signal — clean up resources before
		// returning, because the caller may call os.Exit and skip deferred Close().
		a.logger.Error("server exited unexpectedly", zap.Error(err))
		_ = a.Close()
		return err
	}

	a.logger.Info("shutting down server...")
	return a.Close()
}
// Close cleans up all resources: it gracefully shuts down the HTTP server
// with a 5-second deadline and closes the underlying database connection,
// joining any errors encountered into a single error. Nil fields are
// skipped, so Close is safe on a partially constructed App.
func (a *App) Close() error {
	var problems []error

	if a.server != nil {
		ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
		defer cancel()
		if err := a.server.Shutdown(ctx); err != nil && err != http.ErrServerClosed {
			problems = append(problems, fmt.Errorf("shutdown http server: %w", err))
		}
	}

	if a.db != nil {
		sqlDB, err := a.db.DB()
		switch {
		case err != nil:
			problems = append(problems, fmt.Errorf("get underlying sql.DB: %w", err))
		default:
			if cerr := sqlDB.Close(); cerr != nil {
				problems = append(problems, fmt.Errorf("close database: %w", cerr))
			}
		}
	}

	return errors.Join(problems...)
}
// ---------------------------------------------------------------------------
// Initialization helpers
// ---------------------------------------------------------------------------

// initDB opens the GORM database connection using the configured DSN and
// GORM log level, then runs schema migrations. If migrations fail, the
// freshly opened connection is closed again before returning the error.
func initDB(cfg *config.Config, logger *zap.Logger) (*gorm.DB, error) {
	db, err := store.NewGormDB(cfg.MySQLDSN, logger, cfg.Log.GormLevel)
	if err != nil {
		return nil, fmt.Errorf("init database: %w", err)
	}
	if migErr := store.AutoMigrate(db); migErr != nil {
		closeDB(db)
		return nil, fmt.Errorf("run migrations: %w", migErr)
	}
	return db, nil
}
// closeDB best-effort closes the sql.DB underlying a GORM handle.
// A nil handle, or a failure to obtain the sql.DB, is silently ignored:
// this helper only runs on error/cleanup paths where the primary error
// has already been reported.
func closeDB(db *gorm.DB) {
	if db == nil {
		return
	}
	sqlDB, err := db.DB()
	if err != nil {
		return
	}
	_ = sqlDB.Close()
}
// initSlurmClient builds the JWT-authenticated Slurm REST client from the
// configured API URL, user name, and JWT key path.
func initSlurmClient(cfg *config.Config) (*slurm.Client, error) {
	c, err := service.NewSlurmClient(cfg.SlurmAPIURL, cfg.SlurmUserName, cfg.SlurmJWTKeyPath)
	if err != nil {
		return nil, fmt.Errorf("init slurm client: %w", err)
	}
	return c, nil
}
// initHTTPServer assembles the service layer, HTTP handlers, and router,
// and returns a configured *http.Server listening on cfg.ServerPort.
//
// The original server had no timeouts at all; a ReadHeaderTimeout is added
// so a slow or malicious client cannot hold a connection open indefinitely
// while trickling request headers (slowloris). Handler/body time is left
// unbounded on purpose, since job-related requests may legitimately be slow.
func initHTTPServer(cfg *config.Config, db *gorm.DB, slurmClient *slurm.Client, logger *zap.Logger) *http.Server {
	jobSvc := service.NewJobService(slurmClient, logger)
	clusterSvc := service.NewClusterService(slurmClient, logger)
	templateStore := store.NewTemplateStore(db)

	jobH := handler.NewJobHandler(jobSvc, logger)
	clusterH := handler.NewClusterHandler(clusterSvc, logger)
	templateH := handler.NewTemplateHandler(templateStore, logger)

	router := server.NewRouter(jobH, clusterH, templateH, logger)

	return &http.Server{
		Addr:              ":" + cfg.ServerPort,
		Handler:           router,
		ReadHeaderTimeout: 10 * time.Second, // mitigate slowloris-style header stalls
	}
}

25
internal/app/app_test.go Normal file
View File

@@ -0,0 +1,25 @@
package app
import (
"testing"
"gcy_hpc_server/internal/config"
"go.uber.org/zap/zaptest"
)
// TestNewApp_InvalidDB ensures NewApp surfaces an error for an unreachable
// database instead of returning a partially wired application.
func TestNewApp_InvalidDB(t *testing.T) {
	badCfg := &config.Config{
		ServerPort:      "8080",
		MySQLDSN:        "invalid:dsn@tcp(localhost:99999)/nonexistent?parseTime=true",
		SlurmAPIURL:     "http://localhost:6820",
		SlurmUserName:   "root",
		SlurmJWTKeyPath: "/nonexistent/jwt.key",
	}
	if _, err := NewApp(badCfg, zaptest.NewLogger(t)); err == nil {
		t.Fatal("expected error for invalid DSN, got nil")
	}
}

View File

@@ -0,0 +1,16 @@
# Example application configuration; copy to config.yaml and adjust.
server_port: "8080"
slurm_api_url: "http://localhost:6820"
slurm_user_name: "root"
slurm_jwt_key_path: "/etc/slurm/jwt_hs256.key"
mysql_dsn: "root:@tcp(127.0.0.1:3306)/hpc_platform?parseTime=true"
log:
  level: "info" # debug, info, warn, error
  encoding: "json" # json, console
  output_stdout: true # also write logs to stdout
  file_path: "" # log file path; leave empty to disable file output
  max_size: 100 # max MB per log file
  max_backups: 5 # number of old log files to retain
  max_age: 30 # days to retain old log files
  compress: true # gzip rotated log files
  gorm_level: "warn" # GORM SQL log level: silent, error, warn, info

51
internal/config/config.go Normal file
View File

@@ -0,0 +1,51 @@
package config
import (
"fmt"
"os"
"gopkg.in/yaml.v3"
)
// LogConfig holds logging configuration values.
//
// OutputStdout is a *bool rather than bool so that an absent YAML key (nil)
// can be distinguished from an explicit false; nil is treated as true.
type LogConfig struct {
	Level        string `yaml:"level"`         // debug, info, warn, error (default: info)
	Encoding     string `yaml:"encoding"`      // json, console (default: json)
	OutputStdout *bool  `yaml:"output_stdout"` // also log to stdout (default: true when unset)
	FilePath     string `yaml:"file_path"`     // log file path (enables file output + rotation)
	MaxSize      int    `yaml:"max_size"`      // MB per file (default: 100)
	MaxBackups   int    `yaml:"max_backups"`   // retained files (default: 5)
	MaxAge       int    `yaml:"max_age"`       // days to retain (default: 30)
	Compress     bool   `yaml:"compress"`      // gzip old files (default: true)
	GormLevel    string `yaml:"gorm_level"`    // GORM SQL log level (default: warn)
}
// Config holds all application configuration values, populated from a
// YAML file by Load. All fields are optional at parse time; validation of
// individual values happens where they are consumed.
type Config struct {
	ServerPort      string    `yaml:"server_port"`        // HTTP listen port (no colon)
	SlurmAPIURL     string    `yaml:"slurm_api_url"`      // base URL of the Slurm REST API
	SlurmUserName   string    `yaml:"slurm_user_name"`    // user name for Slurm JWT auth
	SlurmJWTKeyPath string    `yaml:"slurm_jwt_key_path"` // path to the JWT signing key
	MySQLDSN        string    `yaml:"mysql_dsn"`          // MySQL data source name
	Log             LogConfig `yaml:"log"`                // logging configuration section
}
// Load reads a YAML configuration file and returns the parsed Config.
// An empty path falls back to "./config.yaml". Errors wrap the underlying
// read or parse failure together with the offending path.
func Load(path string) (*Config, error) {
	if path == "" {
		path = "./config.yaml"
	}

	raw, err := os.ReadFile(path)
	if err != nil {
		return nil, fmt.Errorf("read config file %s: %w", path, err)
	}

	cfg := &Config{}
	if err := yaml.Unmarshal(raw, cfg); err != nil {
		return nil, fmt.Errorf("parse config file %s: %w", path, err)
	}
	return cfg, nil
}

View File

@@ -0,0 +1,256 @@
package config
import (
"os"
"path/filepath"
"testing"
)
// TestLoad checks that every top-level field is parsed from a YAML file.
func TestLoad(t *testing.T) {
	doc := []byte(`server_port: "9090"
slurm_api_url: "http://slurm.example.com:6820"
slurm_user_name: "admin"
slurm_jwt_key_path: "/etc/slurm/jwt.key"
mysql_dsn: "user:pass@tcp(10.0.0.1:3306)/testdb?parseTime=true"
`)
	cfgPath := filepath.Join(t.TempDir(), "config.yaml")
	if err := os.WriteFile(cfgPath, doc, 0o644); err != nil {
		t.Fatalf("write temp config: %v", err)
	}

	cfg, err := Load(cfgPath)
	if err != nil {
		t.Fatalf("Load() error = %v", err)
	}

	if cfg.ServerPort != "9090" {
		t.Errorf("ServerPort = %q, want %q", cfg.ServerPort, "9090")
	}
	if cfg.SlurmAPIURL != "http://slurm.example.com:6820" {
		t.Errorf("SlurmAPIURL = %q, want %q", cfg.SlurmAPIURL, "http://slurm.example.com:6820")
	}
	if cfg.SlurmUserName != "admin" {
		t.Errorf("SlurmUserName = %q, want %q", cfg.SlurmUserName, "admin")
	}
	if cfg.SlurmJWTKeyPath != "/etc/slurm/jwt.key" {
		t.Errorf("SlurmJWTKeyPath = %q, want %q", cfg.SlurmJWTKeyPath, "/etc/slurm/jwt.key")
	}
	if cfg.MySQLDSN != "user:pass@tcp(10.0.0.1:3306)/testdb?parseTime=true" {
		t.Errorf("MySQLDSN = %q, want %q", cfg.MySQLDSN, "user:pass@tcp(10.0.0.1:3306)/testdb?parseTime=true")
	}
}
// TestLoadDefaultPath confirms a minimal config without a log section loads
// successfully and yields a non-nil Config.
func TestLoadDefaultPath(t *testing.T) {
	doc := []byte(`server_port: "8080"
slurm_api_url: "http://localhost:6820"
slurm_user_name: "root"
slurm_jwt_key_path: "/etc/slurm/jwt_hs256.key"
mysql_dsn: "root:@tcp(127.0.0.1:3306)/hpc_platform?parseTime=true"
`)
	cfgPath := filepath.Join(t.TempDir(), "config.yaml")
	if err := os.WriteFile(cfgPath, doc, 0o644); err != nil {
		t.Fatalf("write temp config: %v", err)
	}

	cfg, err := Load(cfgPath)
	if err != nil {
		t.Fatalf("Load() error = %v", err)
	}
	if cfg == nil {
		t.Fatal("Load() returned nil config")
	}
}
// TestLoadWithLogConfig verifies that every field of the log: section is
// parsed into LogConfig, including the *bool output_stdout tri-state.
func TestLoadWithLogConfig(t *testing.T) {
	doc := []byte(`server_port: "9090"
slurm_api_url: "http://slurm.example.com:6820"
slurm_user_name: "admin"
slurm_jwt_key_path: "/etc/slurm/jwt.key"
mysql_dsn: "user:pass@tcp(10.0.0.1:3306)/testdb?parseTime=true"
log:
  level: "debug"
  encoding: "console"
  output_stdout: true
  file_path: "/var/log/app.log"
  max_size: 200
  max_backups: 10
  max_age: 60
  compress: false
  gorm_level: "info"
`)
	cfgPath := filepath.Join(t.TempDir(), "config.yaml")
	if err := os.WriteFile(cfgPath, doc, 0o644); err != nil {
		t.Fatalf("write temp config: %v", err)
	}

	cfg, err := Load(cfgPath)
	if err != nil {
		t.Fatalf("Load() error = %v", err)
	}

	if cfg.Log.Level != "debug" {
		t.Errorf("Log.Level = %q, want %q", cfg.Log.Level, "debug")
	}
	if cfg.Log.Encoding != "console" {
		t.Errorf("Log.Encoding = %q, want %q", cfg.Log.Encoding, "console")
	}
	if cfg.Log.OutputStdout == nil || *cfg.Log.OutputStdout != true {
		t.Errorf("Log.OutputStdout = %v, want true", cfg.Log.OutputStdout)
	}
	if cfg.Log.FilePath != "/var/log/app.log" {
		t.Errorf("Log.FilePath = %q, want %q", cfg.Log.FilePath, "/var/log/app.log")
	}
	if cfg.Log.MaxSize != 200 {
		t.Errorf("Log.MaxSize = %d, want %d", cfg.Log.MaxSize, 200)
	}
	if cfg.Log.MaxBackups != 10 {
		t.Errorf("Log.MaxBackups = %d, want %d", cfg.Log.MaxBackups, 10)
	}
	if cfg.Log.MaxAge != 60 {
		t.Errorf("Log.MaxAge = %d, want %d", cfg.Log.MaxAge, 60)
	}
	if cfg.Log.Compress != false {
		t.Errorf("Log.Compress = %v, want %v", cfg.Log.Compress, false)
	}
	if cfg.Log.GormLevel != "info" {
		t.Errorf("Log.GormLevel = %q, want %q", cfg.Log.GormLevel, "info")
	}
}
// TestLoadWithoutLogConfig verifies backward compatibility: a config with no
// log: section yields zero values everywhere, and crucially a nil
// OutputStdout so downstream code can apply its "nil means true" default.
func TestLoadWithoutLogConfig(t *testing.T) {
	doc := []byte(`server_port: "8080"
slurm_api_url: "http://localhost:6820"
slurm_user_name: "root"
slurm_jwt_key_path: "/etc/slurm/jwt_hs256.key"
mysql_dsn: "root:@tcp(127.0.0.1:3306)/hpc_platform?parseTime=true"
`)
	cfgPath := filepath.Join(t.TempDir(), "config.yaml")
	if err := os.WriteFile(cfgPath, doc, 0o644); err != nil {
		t.Fatalf("write temp config: %v", err)
	}

	cfg, err := Load(cfgPath)
	if err != nil {
		t.Fatalf("Load() error = %v", err)
	}

	if cfg.Log.Level != "" {
		t.Errorf("Log.Level = %q, want empty string", cfg.Log.Level)
	}
	if cfg.Log.Encoding != "" {
		t.Errorf("Log.Encoding = %q, want empty string", cfg.Log.Encoding)
	}
	if cfg.Log.OutputStdout != nil {
		t.Errorf("Log.OutputStdout = %v, want nil", cfg.Log.OutputStdout)
	}
	if cfg.Log.FilePath != "" {
		t.Errorf("Log.FilePath = %q, want empty string", cfg.Log.FilePath)
	}
	if cfg.Log.MaxSize != 0 {
		t.Errorf("Log.MaxSize = %d, want 0", cfg.Log.MaxSize)
	}
	if cfg.Log.MaxBackups != 0 {
		t.Errorf("Log.MaxBackups = %d, want 0", cfg.Log.MaxBackups)
	}
	if cfg.Log.MaxAge != 0 {
		t.Errorf("Log.MaxAge = %d, want 0", cfg.Log.MaxAge)
	}
	if cfg.Log.Compress != false {
		t.Errorf("Log.Compress = %v, want false", cfg.Log.Compress)
	}
	if cfg.Log.GormLevel != "" {
		t.Errorf("Log.GormLevel = %q, want empty string", cfg.Log.GormLevel)
	}
}
// TestLoadExistingFieldsWithLogConfig verifies that adding a `log:` section
// does not disturb the pre-existing top-level fields, and that a partial log
// section (only level/encoding here) is parsed into Log.
func TestLoadExistingFieldsWithLogConfig(t *testing.T) {
	content := []byte(`server_port: "7070"
slurm_api_url: "http://slurm2.example.com:6820"
slurm_user_name: "testuser"
slurm_jwt_key_path: "/keys/jwt.key"
mysql_dsn: "root:secret@tcp(db:3306)/mydb?parseTime=true"
log:
  level: "warn"
  encoding: "json"
`)
	dir := t.TempDir()
	path := filepath.Join(dir, "config.yaml")
	if err := os.WriteFile(path, content, 0644); err != nil {
		t.Fatalf("write temp config: %v", err)
	}
	cfg, err := Load(path)
	if err != nil {
		t.Fatalf("Load() error = %v", err)
	}
	// Top-level fields must round-trip unchanged.
	if cfg.ServerPort != "7070" {
		t.Errorf("ServerPort = %q, want %q", cfg.ServerPort, "7070")
	}
	if cfg.SlurmAPIURL != "http://slurm2.example.com:6820" {
		t.Errorf("SlurmAPIURL = %q, want %q", cfg.SlurmAPIURL, "http://slurm2.example.com:6820")
	}
	if cfg.SlurmUserName != "testuser" {
		t.Errorf("SlurmUserName = %q, want %q", cfg.SlurmUserName, "testuser")
	}
	if cfg.SlurmJWTKeyPath != "/keys/jwt.key" {
		t.Errorf("SlurmJWTKeyPath = %q, want %q", cfg.SlurmJWTKeyPath, "/keys/jwt.key")
	}
	if cfg.MySQLDSN != "root:secret@tcp(db:3306)/mydb?parseTime=true" {
		t.Errorf("MySQLDSN = %q, want %q", cfg.MySQLDSN, "root:secret@tcp(db:3306)/mydb?parseTime=true")
	}
	// And the nested log section must be picked up.
	if cfg.Log.Level != "warn" {
		t.Errorf("Log.Level = %q, want %q", cfg.Log.Level, "warn")
	}
	if cfg.Log.Encoding != "json" {
		t.Errorf("Log.Encoding = %q, want %q", cfg.Log.Encoding, "json")
	}
}
// TestLoadNonExistentFile verifies that Load reports an error for a path
// that does not exist rather than returning an empty config.
func TestLoadNonExistentFile(t *testing.T) {
	if _, err := Load("/nonexistent/path/config.yaml"); err == nil {
		t.Fatal("Load() expected error for non-existent file, got nil")
	}
}
// TestLoadWithOutputStdoutFalse verifies that an explicit `output_stdout: false`
// parses into a non-nil *bool holding false — distinguishing it from the
// field simply being absent (which TestLoadWithoutLogConfig covers as nil).
func TestLoadWithOutputStdoutFalse(t *testing.T) {
	content := []byte(`server_port: "9090"
slurm_api_url: "http://slurm.example.com:6820"
slurm_user_name: "admin"
slurm_jwt_key_path: "/etc/slurm/jwt.key"
mysql_dsn: "user:pass@tcp(10.0.0.1:3306)/testdb?parseTime=true"
log:
  level: "info"
  encoding: "json"
  output_stdout: false
  file_path: "/var/log/app.log"
`)
	dir := t.TempDir()
	path := filepath.Join(dir, "config.yaml")
	if err := os.WriteFile(path, content, 0644); err != nil {
		t.Fatalf("write temp config: %v", err)
	}
	cfg, err := Load(path)
	if err != nil {
		t.Fatalf("Load() error = %v", err)
	}
	// Pointer must be set AND dereference to false.
	if cfg.Log.OutputStdout == nil || *cfg.Log.OutputStdout != false {
		t.Errorf("Log.OutputStdout = %v, want false", cfg.Log.OutputStdout)
	}
	if cfg.Log.FilePath != "/var/log/app.log" {
		t.Errorf("Log.FilePath = %q, want %q", cfg.Log.FilePath, "/var/log/app.log")
	}
}

View File

@@ -0,0 +1,94 @@
package handler
import (
"gcy_hpc_server/internal/server"
"gcy_hpc_server/internal/service"
"github.com/gin-gonic/gin"
"go.uber.org/zap"
)
// ClusterHandler serves the cluster-inspection endpoints (nodes, partitions,
// scheduler diagnostics), delegating all Slurm access to a ClusterService.
type ClusterHandler struct {
	clusterSvc *service.ClusterService
	logger     *zap.Logger
}

// NewClusterHandler wires a ClusterService and a logger into a handler.
func NewClusterHandler(clusterSvc *service.ClusterService, logger *zap.Logger) *ClusterHandler {
	handler := &ClusterHandler{
		clusterSvc: clusterSvc,
		logger:     logger,
	}
	return handler
}
// GetNodes handles GET /api/v1/nodes: returns every node reported by Slurm,
// or a 500 with the upstream error message.
func (h *ClusterHandler) GetNodes(c *gin.Context) {
	ctx := c.Request.Context()
	nodes, err := h.clusterSvc.GetNodes(ctx)
	if err != nil {
		h.logger.Error("handler error", zap.String("method", "GetNodes"), zap.Error(err))
		server.InternalError(c, err.Error())
		return
	}
	server.OK(c, nodes)
}
// GetNode handles GET /api/v1/nodes/:name: returns one node by name.
// Responds 500 on upstream errors and 404 when the service yields nil.
func (h *ClusterHandler) GetNode(c *gin.Context) {
	name := c.Param("name")
	node, err := h.clusterSvc.GetNode(c.Request.Context(), name)
	switch {
	case err != nil:
		h.logger.Error("handler error", zap.String("method", "GetNode"), zap.Error(err))
		server.InternalError(c, err.Error())
	case node == nil:
		h.logger.Warn("not found", zap.String("method", "GetNode"), zap.String("name", name))
		server.NotFound(c, "node not found")
	default:
		server.OK(c, node)
	}
}
// GetPartitions handles GET /api/v1/partitions: returns all partitions,
// or a 500 with the upstream error message.
func (h *ClusterHandler) GetPartitions(c *gin.Context) {
	ctx := c.Request.Context()
	partitions, err := h.clusterSvc.GetPartitions(ctx)
	if err != nil {
		h.logger.Error("handler error", zap.String("method", "GetPartitions"), zap.Error(err))
		server.InternalError(c, err.Error())
		return
	}
	server.OK(c, partitions)
}
// GetPartition handles GET /api/v1/partitions/:name: returns one partition
// by name. Responds 500 on upstream errors and 404 when the service yields nil.
func (h *ClusterHandler) GetPartition(c *gin.Context) {
	name := c.Param("name")
	partition, err := h.clusterSvc.GetPartition(c.Request.Context(), name)
	switch {
	case err != nil:
		h.logger.Error("handler error", zap.String("method", "GetPartition"), zap.Error(err))
		server.InternalError(c, err.Error())
	case partition == nil:
		h.logger.Warn("not found", zap.String("method", "GetPartition"), zap.String("name", name))
		server.NotFound(c, "partition not found")
	default:
		server.OK(c, partition)
	}
}
// GetDiag handles GET /api/v1/diag: returns scheduler diagnostics,
// or a 500 with the upstream error message.
func (h *ClusterHandler) GetDiag(c *gin.Context) {
	ctx := c.Request.Context()
	diag, err := h.clusterSvc.GetDiag(ctx)
	if err != nil {
		h.logger.Error("handler error", zap.String("method", "GetDiag"), zap.Error(err))
		server.InternalError(c, err.Error())
		return
	}
	server.OK(c, diag)
}

View File

@@ -0,0 +1,634 @@
package handler
import (
"encoding/json"
"fmt"
"net/http"
"net/http/httptest"
"testing"
"gcy_hpc_server/internal/service"
"gcy_hpc_server/internal/slurm"
"github.com/gin-gonic/gin"
"go.uber.org/zap"
"go.uber.org/zap/zapcore"
"go.uber.org/zap/zaptest/observer"
)
// setupClusterHandler builds a ClusterHandler backed by an httptest server
// acting as the Slurm REST API. Logging is discarded (zap.NewNop).
// The caller must Close the returned server.
func setupClusterHandler(slurmHandler http.HandlerFunc) (*httptest.Server, *ClusterHandler) {
	srv := httptest.NewServer(slurmHandler)
	client, _ := slurm.NewClient(srv.URL, srv.Client())
	clusterSvc := service.NewClusterService(client, zap.NewNop())
	return srv, NewClusterHandler(clusterSvc, zap.NewNop())
}

// setupClusterHandlerWithObserver is like setupClusterHandler but wires an
// observed zap logger (Debug and above) into both the service and the handler
// so tests can assert on emitted log entries.
func setupClusterHandlerWithObserver(slurmHandler http.HandlerFunc) (*httptest.Server, *ClusterHandler, *observer.ObservedLogs) {
	core, recorded := observer.New(zapcore.DebugLevel)
	l := zap.New(core)
	srv := httptest.NewServer(slurmHandler)
	client, _ := slurm.NewClient(srv.URL, srv.Client())
	clusterSvc := service.NewClusterService(client, l)
	return srv, NewClusterHandler(clusterSvc, l), recorded
}

// setupClusterRouter registers the cluster routes on a fresh test-mode gin
// engine, mirroring the production route layout under /api/v1.
func setupClusterRouter(h *ClusterHandler) *gin.Engine {
	gin.SetMode(gin.TestMode)
	r := gin.New()
	v1 := r.Group("/api/v1")
	v1.GET("/nodes", h.GetNodes)
	v1.GET("/nodes/:name", h.GetNode)
	v1.GET("/partitions", h.GetPartitions)
	v1.GET("/partitions/:name", h.GetPartition)
	v1.GET("/diag", h.GetDiag)
	return r
}
// TestGetNodes_Success: a healthy Slurm nodes payload yields 200 and success=true.
func TestGetNodes_Success(t *testing.T) {
	srv, h := setupClusterHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.Header().Set("Content-Type", "application/json")
		json.NewEncoder(w).Encode(map[string]interface{}{
			"nodes": []map[string]interface{}{
				{"name": "node1", "state": []string{"IDLE"}, "cpus": 64, "real_memory": 128000},
			},
			"last_update": map[string]interface{}{},
		})
	}))
	defer srv.Close()
	router := setupClusterRouter(h)
	w := httptest.NewRecorder()
	req, _ := http.NewRequest("GET", "/api/v1/nodes", nil)
	router.ServeHTTP(w, req)
	if w.Code != http.StatusOK {
		t.Fatalf("expected 200, got %d", w.Code)
	}
	var resp map[string]interface{}
	if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
		t.Fatalf("failed to parse response: %v", err)
	}
	if resp["success"] != true {
		t.Fatal("expected success=true")
	}
}

// TestGetNode_Success: requesting an existing node by name yields 200 and success=true.
func TestGetNode_Success(t *testing.T) {
	srv, h := setupClusterHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.Header().Set("Content-Type", "application/json")
		json.NewEncoder(w).Encode(map[string]interface{}{
			"nodes": []map[string]interface{}{
				{"name": "node1", "state": []string{"IDLE"}, "cpus": 64, "real_memory": 128000},
			},
			"last_update": map[string]interface{}{},
		})
	}))
	defer srv.Close()
	router := setupClusterRouter(h)
	w := httptest.NewRecorder()
	req, _ := http.NewRequest("GET", "/api/v1/nodes/node1", nil)
	router.ServeHTTP(w, req)
	if w.Code != http.StatusOK {
		t.Fatalf("expected 200, got %d", w.Code)
	}
	var resp map[string]interface{}
	if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
		t.Fatalf("failed to parse response: %v", err)
	}
	if resp["success"] != true {
		t.Fatal("expected success=true")
	}
}

// TestGetNode_NotFound: an empty nodes list from Slurm maps to HTTP 404.
func TestGetNode_NotFound(t *testing.T) {
	srv, h := setupClusterHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.Header().Set("Content-Type", "application/json")
		json.NewEncoder(w).Encode(map[string]interface{}{
			"nodes":       []map[string]interface{}{},
			"last_update": map[string]interface{}{},
		})
	}))
	defer srv.Close()
	router := setupClusterRouter(h)
	w := httptest.NewRecorder()
	req, _ := http.NewRequest("GET", "/api/v1/nodes/nonexistent", nil)
	router.ServeHTTP(w, req)
	if w.Code != http.StatusNotFound {
		t.Fatalf("expected 404, got %d", w.Code)
	}
}

// TestGetPartitions_Success: a healthy partitions payload yields 200 and success=true.
func TestGetPartitions_Success(t *testing.T) {
	srv, h := setupClusterHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.Header().Set("Content-Type", "application/json")
		json.NewEncoder(w).Encode(map[string]interface{}{
			"partitions": []map[string]interface{}{
				{
					"name": "normal",
					"partition": map[string]interface{}{
						"state": []string{"UP"},
					},
					"nodes": map[string]interface{}{
						"configured": "node[1-10]",
						"total":      int32(10),
					},
					"cpus": map[string]interface{}{
						"total": int32(640),
					},
					"maximums": map[string]interface{}{
						"time": map[string]interface{}{"number": int64(60)},
					},
				},
			},
			"last_update": map[string]interface{}{},
		})
	}))
	defer srv.Close()
	router := setupClusterRouter(h)
	w := httptest.NewRecorder()
	req, _ := http.NewRequest("GET", "/api/v1/partitions", nil)
	router.ServeHTTP(w, req)
	if w.Code != http.StatusOK {
		t.Fatalf("expected 200, got %d", w.Code)
	}
	var resp map[string]interface{}
	if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
		t.Fatalf("failed to parse response: %v", err)
	}
	if resp["success"] != true {
		t.Fatal("expected success=true")
	}
}

// TestGetPartition_Success: requesting an existing partition by name yields 200.
func TestGetPartition_Success(t *testing.T) {
	srv, h := setupClusterHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.Header().Set("Content-Type", "application/json")
		json.NewEncoder(w).Encode(map[string]interface{}{
			"partitions": []map[string]interface{}{
				{
					"name": "normal",
					"partition": map[string]interface{}{
						"state": []string{"UP"},
					},
					"nodes": map[string]interface{}{
						"configured": "node[1-10]",
						"total":      int32(10),
					},
					"cpus": map[string]interface{}{
						"total": int32(640),
					},
					"maximums": map[string]interface{}{
						"time": map[string]interface{}{"number": int64(60)},
					},
				},
			},
			"last_update": map[string]interface{}{},
		})
	}))
	defer srv.Close()
	router := setupClusterRouter(h)
	w := httptest.NewRecorder()
	req, _ := http.NewRequest("GET", "/api/v1/partitions/normal", nil)
	router.ServeHTTP(w, req)
	if w.Code != http.StatusOK {
		t.Fatalf("expected 200, got %d", w.Code)
	}
	var resp map[string]interface{}
	if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
		t.Fatalf("failed to parse response: %v", err)
	}
	if resp["success"] != true {
		t.Fatal("expected success=true")
	}
}

// TestGetPartition_NotFound: an empty partitions list from Slurm maps to HTTP 404.
func TestGetPartition_NotFound(t *testing.T) {
	srv, h := setupClusterHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.Header().Set("Content-Type", "application/json")
		json.NewEncoder(w).Encode(map[string]interface{}{
			"partitions":  []map[string]interface{}{},
			"last_update": map[string]interface{}{},
		})
	}))
	defer srv.Close()
	router := setupClusterRouter(h)
	w := httptest.NewRecorder()
	req, _ := http.NewRequest("GET", "/api/v1/partitions/nonexistent", nil)
	router.ServeHTTP(w, req)
	if w.Code != http.StatusNotFound {
		t.Fatalf("expected 404, got %d", w.Code)
	}
}

// TestGetDiag_Success: a healthy diagnostics payload yields 200 and success=true.
func TestGetDiag_Success(t *testing.T) {
	srv, h := setupClusterHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.Header().Set("Content-Type", "application/json")
		json.NewEncoder(w).Encode(map[string]interface{}{
			"statistics": map[string]interface{}{
				"server_thread_count":  3,
				"agent_queue_size":     0,
				"jobs_submitted":       100,
				"jobs_started":         90,
				"jobs_completed":       85,
				"schedule_cycle_last":  10,
				"schedule_cycle_total": 500,
			},
		})
	}))
	defer srv.Close()
	router := setupClusterRouter(h)
	w := httptest.NewRecorder()
	req, _ := http.NewRequest("GET", "/api/v1/diag", nil)
	router.ServeHTTP(w, req)
	if w.Code != http.StatusOK {
		t.Fatalf("expected 200, got %d", w.Code)
	}
	var resp map[string]interface{}
	if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
		t.Fatalf("failed to parse response: %v", err)
	}
	if resp["success"] != true {
		t.Fatal("expected success=true")
	}
}
// --- Logging tests ---
// TestClusterHandler_GetNodes_InternalError_LogsError: an upstream 5xx from
// Slurm produces exactly one Error-level "handler error" entry tagged with
// method=GetNodes.
func TestClusterHandler_GetNodes_InternalError_LogsError(t *testing.T) {
	srv, h, recorded := setupClusterHandlerWithObserver(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusInternalServerError)
		fmt.Fprint(w, `{"errors":[{"error":"internal"}]}`)
	}))
	defer srv.Close()
	router := setupClusterRouter(h)
	w := httptest.NewRecorder()
	req, _ := http.NewRequest("GET", "/api/v1/nodes", nil)
	router.ServeHTTP(w, req)
	if w.Code != http.StatusInternalServerError {
		t.Fatalf("expected 500, got %d", w.Code)
	}
	handlerLogs := recorded.FilterMessage("handler error")
	if handlerLogs.Len() != 1 {
		t.Fatalf("expected 1 handler error log, got %d", handlerLogs.Len())
	}
	entry := handlerLogs.All()[0]
	if entry.Level != zapcore.ErrorLevel {
		t.Fatalf("expected Error level, got %v", entry.Level)
	}
	assertField(t, entry.Context, "method", "GetNodes")
}

// TestClusterHandler_GetNodes_Success_NoLogs: a successful request emits no
// handler/service log entries (access logging lives in middleware).
func TestClusterHandler_GetNodes_Success_NoLogs(t *testing.T) {
	srv, h, recorded := setupClusterHandlerWithObserver(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.Header().Set("Content-Type", "application/json")
		json.NewEncoder(w).Encode(map[string]interface{}{
			"nodes": []map[string]interface{}{
				{"name": "node1"},
			},
			"last_update": map[string]interface{}{},
		})
	}))
	defer srv.Close()
	router := setupClusterRouter(h)
	w := httptest.NewRecorder()
	req, _ := http.NewRequest("GET", "/api/v1/nodes", nil)
	router.ServeHTTP(w, req)
	if w.Code != http.StatusOK {
		t.Fatalf("expected 200, got %d", w.Code)
	}
	if recorded.Len() != 0 {
		t.Fatalf("expected 0 log entries on success, got %d", recorded.Len())
	}
}

// TestClusterHandler_GetNode_InternalError_LogsError: upstream 5xx on the
// single-node endpoint logs one Error entry tagged method=GetNode.
func TestClusterHandler_GetNode_InternalError_LogsError(t *testing.T) {
	srv, h, recorded := setupClusterHandlerWithObserver(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusInternalServerError)
		fmt.Fprint(w, `{"errors":[{"error":"internal"}]}`)
	}))
	defer srv.Close()
	router := setupClusterRouter(h)
	w := httptest.NewRecorder()
	req, _ := http.NewRequest("GET", "/api/v1/nodes/node1", nil)
	router.ServeHTTP(w, req)
	if w.Code != http.StatusInternalServerError {
		t.Fatalf("expected 500, got %d", w.Code)
	}
	handlerLogs := recorded.FilterMessage("handler error")
	if handlerLogs.Len() != 1 {
		t.Fatalf("expected 1 handler error log, got %d", handlerLogs.Len())
	}
	entry := handlerLogs.All()[0]
	if entry.Level != zapcore.ErrorLevel {
		t.Fatalf("expected Error level, got %v", entry.Level)
	}
	assertField(t, entry.Context, "method", "GetNode")
}

// TestClusterHandler_GetNode_NotFound_LogsWarn: a missing node produces
// exactly one Warn "not found" entry carrying the method and node name.
func TestClusterHandler_GetNode_NotFound_LogsWarn(t *testing.T) {
	srv, h, recorded := setupClusterHandlerWithObserver(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.Header().Set("Content-Type", "application/json")
		json.NewEncoder(w).Encode(map[string]interface{}{
			"nodes":       []map[string]interface{}{},
			"last_update": map[string]interface{}{},
		})
	}))
	defer srv.Close()
	router := setupClusterRouter(h)
	w := httptest.NewRecorder()
	req, _ := http.NewRequest("GET", "/api/v1/nodes/nonexistent", nil)
	router.ServeHTTP(w, req)
	if w.Code != http.StatusNotFound {
		t.Fatalf("expected 404, got %d", w.Code)
	}
	if recorded.Len() != 1 {
		t.Fatalf("expected 1 log entry, got %d", recorded.Len())
	}
	entry := recorded.All()[0]
	if entry.Level != zapcore.WarnLevel {
		t.Fatalf("expected Warn level, got %v", entry.Level)
	}
	if entry.Message != "not found" {
		t.Fatalf("expected message 'not found', got %q", entry.Message)
	}
	assertField(t, entry.Context, "method", "GetNode")
	assertField(t, entry.Context, "name", "nonexistent")
}

// TestClusterHandler_GetNode_Success_NoLogs: a successful node lookup emits
// no log entries.
func TestClusterHandler_GetNode_Success_NoLogs(t *testing.T) {
	srv, h, recorded := setupClusterHandlerWithObserver(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.Header().Set("Content-Type", "application/json")
		json.NewEncoder(w).Encode(map[string]interface{}{
			"nodes": []map[string]interface{}{
				{"name": "node1", "state": []string{"IDLE"}, "cpus": 64, "real_memory": 128000},
			},
			"last_update": map[string]interface{}{},
		})
	}))
	defer srv.Close()
	router := setupClusterRouter(h)
	w := httptest.NewRecorder()
	req, _ := http.NewRequest("GET", "/api/v1/nodes/node1", nil)
	router.ServeHTTP(w, req)
	if w.Code != http.StatusOK {
		t.Fatalf("expected 200, got %d", w.Code)
	}
	if recorded.Len() != 0 {
		t.Fatalf("expected 0 log entries on success, got %d", recorded.Len())
	}
}

// TestClusterHandler_GetPartitions_InternalError_LogsError: upstream 5xx on
// the partitions list logs one Error entry tagged method=GetPartitions.
func TestClusterHandler_GetPartitions_InternalError_LogsError(t *testing.T) {
	srv, h, recorded := setupClusterHandlerWithObserver(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusInternalServerError)
		fmt.Fprint(w, `{"errors":[{"error":"internal"}]}`)
	}))
	defer srv.Close()
	router := setupClusterRouter(h)
	w := httptest.NewRecorder()
	req, _ := http.NewRequest("GET", "/api/v1/partitions", nil)
	router.ServeHTTP(w, req)
	if w.Code != http.StatusInternalServerError {
		t.Fatalf("expected 500, got %d", w.Code)
	}
	handlerLogs := recorded.FilterMessage("handler error")
	if handlerLogs.Len() != 1 {
		t.Fatalf("expected 1 handler error log, got %d", handlerLogs.Len())
	}
	entry := handlerLogs.All()[0]
	if entry.Level != zapcore.ErrorLevel {
		t.Fatalf("expected Error level, got %v", entry.Level)
	}
	assertField(t, entry.Context, "method", "GetPartitions")
}

// TestClusterHandler_GetPartitions_Success_NoLogs: a successful partitions
// list emits no log entries.
func TestClusterHandler_GetPartitions_Success_NoLogs(t *testing.T) {
	srv, h, recorded := setupClusterHandlerWithObserver(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.Header().Set("Content-Type", "application/json")
		json.NewEncoder(w).Encode(map[string]interface{}{
			"partitions": []map[string]interface{}{
				{"name": "normal"},
			},
			"last_update": map[string]interface{}{},
		})
	}))
	defer srv.Close()
	router := setupClusterRouter(h)
	w := httptest.NewRecorder()
	req, _ := http.NewRequest("GET", "/api/v1/partitions", nil)
	router.ServeHTTP(w, req)
	if w.Code != http.StatusOK {
		t.Fatalf("expected 200, got %d", w.Code)
	}
	if recorded.Len() != 0 {
		t.Fatalf("expected 0 log entries on success, got %d", recorded.Len())
	}
}

// TestClusterHandler_GetPartition_InternalError_LogsError: upstream 5xx on a
// single partition logs one Error entry tagged method=GetPartition.
func TestClusterHandler_GetPartition_InternalError_LogsError(t *testing.T) {
	srv, h, recorded := setupClusterHandlerWithObserver(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusInternalServerError)
		fmt.Fprint(w, `{"errors":[{"error":"internal"}]}`)
	}))
	defer srv.Close()
	router := setupClusterRouter(h)
	w := httptest.NewRecorder()
	req, _ := http.NewRequest("GET", "/api/v1/partitions/normal", nil)
	router.ServeHTTP(w, req)
	if w.Code != http.StatusInternalServerError {
		t.Fatalf("expected 500, got %d", w.Code)
	}
	handlerLogs := recorded.FilterMessage("handler error")
	if handlerLogs.Len() != 1 {
		t.Fatalf("expected 1 handler error log, got %d", handlerLogs.Len())
	}
	entry := handlerLogs.All()[0]
	if entry.Level != zapcore.ErrorLevel {
		t.Fatalf("expected Error level, got %v", entry.Level)
	}
	assertField(t, entry.Context, "method", "GetPartition")
}

// TestClusterHandler_GetPartition_NotFound_LogsWarn: a missing partition
// produces one Warn "not found" entry carrying the method and name.
func TestClusterHandler_GetPartition_NotFound_LogsWarn(t *testing.T) {
	srv, h, recorded := setupClusterHandlerWithObserver(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.Header().Set("Content-Type", "application/json")
		json.NewEncoder(w).Encode(map[string]interface{}{
			"partitions":  []map[string]interface{}{},
			"last_update": map[string]interface{}{},
		})
	}))
	defer srv.Close()
	router := setupClusterRouter(h)
	w := httptest.NewRecorder()
	req, _ := http.NewRequest("GET", "/api/v1/partitions/nonexistent", nil)
	router.ServeHTTP(w, req)
	if w.Code != http.StatusNotFound {
		t.Fatalf("expected 404, got %d", w.Code)
	}
	if recorded.Len() != 1 {
		t.Fatalf("expected 1 log entry, got %d", recorded.Len())
	}
	entry := recorded.All()[0]
	if entry.Level != zapcore.WarnLevel {
		t.Fatalf("expected Warn level, got %v", entry.Level)
	}
	if entry.Message != "not found" {
		t.Fatalf("expected message 'not found', got %q", entry.Message)
	}
	assertField(t, entry.Context, "method", "GetPartition")
	assertField(t, entry.Context, "name", "nonexistent")
}

// TestClusterHandler_GetPartition_Success_NoLogs: a successful partition
// lookup emits no log entries.
func TestClusterHandler_GetPartition_Success_NoLogs(t *testing.T) {
	srv, h, recorded := setupClusterHandlerWithObserver(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.Header().Set("Content-Type", "application/json")
		json.NewEncoder(w).Encode(map[string]interface{}{
			"partitions": []map[string]interface{}{
				{
					"name": "normal",
					"partition": map[string]interface{}{
						"state": []string{"UP"},
					},
					"nodes": map[string]interface{}{
						"configured": "node[1-10]",
						"total":      int32(10),
					},
					"cpus": map[string]interface{}{
						"total": int32(640),
					},
					"maximums": map[string]interface{}{
						"time": map[string]interface{}{"number": int64(60)},
					},
				},
			},
			"last_update": map[string]interface{}{},
		})
	}))
	defer srv.Close()
	router := setupClusterRouter(h)
	w := httptest.NewRecorder()
	req, _ := http.NewRequest("GET", "/api/v1/partitions/normal", nil)
	router.ServeHTTP(w, req)
	if w.Code != http.StatusOK {
		t.Fatalf("expected 200, got %d", w.Code)
	}
	if recorded.Len() != 0 {
		t.Fatalf("expected 0 log entries on success, got %d", recorded.Len())
	}
}

// TestClusterHandler_GetDiag_InternalError_LogsError: upstream 5xx on diag
// logs one Error entry tagged method=GetDiag.
func TestClusterHandler_GetDiag_InternalError_LogsError(t *testing.T) {
	srv, h, recorded := setupClusterHandlerWithObserver(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusInternalServerError)
		fmt.Fprint(w, `{"errors":[{"error":"internal"}]}`)
	}))
	defer srv.Close()
	router := setupClusterRouter(h)
	w := httptest.NewRecorder()
	req, _ := http.NewRequest("GET", "/api/v1/diag", nil)
	router.ServeHTTP(w, req)
	if w.Code != http.StatusInternalServerError {
		t.Fatalf("expected 500, got %d", w.Code)
	}
	handlerLogs := recorded.FilterMessage("handler error")
	if handlerLogs.Len() != 1 {
		t.Fatalf("expected 1 handler error log, got %d", handlerLogs.Len())
	}
	entry := handlerLogs.All()[0]
	if entry.Level != zapcore.ErrorLevel {
		t.Fatalf("expected Error level, got %v", entry.Level)
	}
	assertField(t, entry.Context, "method", "GetDiag")
}

// TestClusterHandler_GetDiag_Success_NoLogs: a successful diag request emits
// no log entries.
func TestClusterHandler_GetDiag_Success_NoLogs(t *testing.T) {
	srv, h, recorded := setupClusterHandlerWithObserver(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.Header().Set("Content-Type", "application/json")
		json.NewEncoder(w).Encode(map[string]interface{}{
			"statistics": map[string]interface{}{
				"server_thread_count":  3,
				"agent_queue_size":     0,
				"jobs_submitted":       100,
				"jobs_started":         90,
				"jobs_completed":       85,
				"schedule_cycle_last":  10,
				"schedule_cycle_total": 500,
			},
		})
	}))
	defer srv.Close()
	router := setupClusterRouter(h)
	w := httptest.NewRecorder()
	req, _ := http.NewRequest("GET", "/api/v1/diag", nil)
	router.ServeHTTP(w, req)
	if w.Code != http.StatusOK {
		t.Fatalf("expected 200, got %d", w.Code)
	}
	if recorded.Len() != 0 {
		t.Fatalf("expected 0 log entries on success, got %d", recorded.Len())
	}
}
// assertField fails the test unless fields contains a zap string field whose
// key and string value match the given pair.
func assertField(t *testing.T, fields []zapcore.Field, key, value string) {
	t.Helper()
	found := false
	for _, field := range fields {
		if field.Key == key && field.String == value {
			found = true
			break
		}
	}
	if !found {
		t.Fatalf("expected field %q=%q in context, got %v", key, value, fields)
	}
}

118
internal/handler/job.go Normal file
View File

@@ -0,0 +1,118 @@
package handler
import (
"net/http"
"gcy_hpc_server/internal/model"
"gcy_hpc_server/internal/server"
"gcy_hpc_server/internal/service"
"github.com/gin-gonic/gin"
"go.uber.org/zap"
)
// JobHandler serves the job endpoints: submission, listing, lookup,
// cancellation and history, delegating the work to a service.JobService.
type JobHandler struct {
	jobSvc *service.JobService
	logger *zap.Logger
}

// NewJobHandler wires a JobService and a logger into a JobHandler.
func NewJobHandler(jobSvc *service.JobService, logger *zap.Logger) *JobHandler {
	handler := &JobHandler{
		jobSvc: jobSvc,
		logger: logger,
	}
	return handler
}
// SubmitJob handles POST /api/v1/jobs/submit.
// Validation failures answer 400 with a Warn log; a Slurm-side failure
// answers 502 with an Error log. The request body is never logged.
func (h *JobHandler) SubmitJob(c *gin.Context) {
	// reject logs a validation failure and answers 400 with msg.
	reject := func(msg string) {
		h.logger.Warn("bad request", zap.String("method", "SubmitJob"), zap.String("error", msg))
		server.BadRequest(c, msg)
	}
	var req model.SubmitJobRequest
	if err := c.ShouldBindJSON(&req); err != nil {
		reject("invalid request body")
		return
	}
	if req.Script == "" {
		reject("script is required")
		return
	}
	resp, err := h.jobSvc.SubmitJob(c.Request.Context(), &req)
	if err != nil {
		h.logger.Error("handler error", zap.String("method", "SubmitJob"), zap.Int("status", http.StatusBadGateway), zap.Error(err))
		server.ErrorWithStatus(c, http.StatusBadGateway, "slurm error: "+err.Error())
		return
	}
	server.Created(c, resp)
}
// GetJobs handles GET /api/v1/jobs: returns the current job queue,
// or a 500 with the upstream error message.
func (h *JobHandler) GetJobs(c *gin.Context) {
	ctx := c.Request.Context()
	jobList, err := h.jobSvc.GetJobs(ctx)
	if err != nil {
		h.logger.Error("handler error", zap.String("method", "GetJobs"), zap.Int("status", http.StatusInternalServerError), zap.Error(err))
		server.InternalError(c, err.Error())
		return
	}
	server.OK(c, jobList)
}
// GetJob handles GET /api/v1/jobs/:id.
// Responds 500 on upstream errors and 404 when the service returns nil.
func (h *JobHandler) GetJob(c *gin.Context) {
	jobID := c.Param("id")
	resp, err := h.jobSvc.GetJob(c.Request.Context(), jobID)
	if err != nil {
		h.logger.Error("handler error", zap.String("method", "GetJob"), zap.Int("status", http.StatusInternalServerError), zap.Error(err))
		server.InternalError(c, err.Error())
		return
	}
	if resp == nil {
		// NOTE(review): the log message "bad request" is misleading here —
		// this branch is a 404, and ClusterHandler logs "not found" with a
		// name field for the analogous case. Consider aligning the message,
		// after confirming no log-assertion test pins this exact string.
		h.logger.Warn("bad request", zap.String("method", "GetJob"), zap.String("error", "job not found"))
		server.NotFound(c, "job not found")
		return
	}
	server.OK(c, resp)
}
// CancelJob handles DELETE /api/v1/jobs/:id.
// A cancellation failure is reported as 502 because the error originates
// from the upstream Slurm REST API.
func (h *JobHandler) CancelJob(c *gin.Context) {
	jobID := c.Param("id")
	if err := h.jobSvc.CancelJob(c.Request.Context(), jobID); err != nil {
		h.logger.Error("handler error", zap.String("method", "CancelJob"), zap.Int("status", http.StatusBadGateway), zap.Error(err))
		server.ErrorWithStatus(c, http.StatusBadGateway, err.Error())
		return
	}
	server.OK(c, gin.H{"message": "job cancelled"})
}
// GetJobHistory handles GET /api/v1/jobs/history.
// Invalid query parameters answer 400; paging values are normalized before
// the service call (pages start at 1, default page size is 20).
func (h *JobHandler) GetJobHistory(c *gin.Context) {
	var query model.JobHistoryQuery
	if err := c.ShouldBindQuery(&query); err != nil {
		h.logger.Warn("bad request", zap.String("method", "GetJobHistory"), zap.String("error", "invalid query params"))
		server.BadRequest(c, "invalid query params")
		return
	}
	// Clamp paging to sane values.
	if query.Page < 1 {
		query.Page = 1
	}
	if query.PageSize < 1 {
		query.PageSize = 20
	}
	history, err := h.jobSvc.GetJobHistory(c.Request.Context(), &query)
	if err != nil {
		h.logger.Error("handler error", zap.String("method", "GetJobHistory"), zap.Int("status", http.StatusInternalServerError), zap.Error(err))
		server.InternalError(c, err.Error())
		return
	}
	server.OK(c, history)
}

View File

@@ -0,0 +1,821 @@
package handler
import (
"bytes"
"encoding/json"
"fmt"
"net/http"
"net/http/httptest"
"testing"
"gcy_hpc_server/internal/service"
"gcy_hpc_server/internal/slurm"
"github.com/gin-gonic/gin"
"go.uber.org/zap"
"go.uber.org/zap/zapcore"
"go.uber.org/zap/zaptest/observer"
)
// setupJobRouter registers the job routes on a fresh test-mode gin engine.
// Note the static /history route is registered before the /:id wildcard —
// presumably so "history" is not captured as a job id; confirm against the
// production route registration.
func setupJobRouter(h *JobHandler) *gin.Engine {
	gin.SetMode(gin.TestMode)
	r := gin.New()
	v1 := r.Group("/api/v1")
	jobs := v1.Group("/jobs")
	{
		jobs.POST("/submit", h.SubmitJob)
		jobs.GET("", h.GetJobs)
		jobs.GET("/history", h.GetJobHistory)
		jobs.GET("/:id", h.GetJob)
		jobs.DELETE("/:id", h.CancelJob)
	}
	return r
}
// setupJobHandler builds a JobHandler backed by an httptest server acting as
// the Slurm REST API. Logging is discarded (zap.NewNop).
// The caller must Close the returned server.
func setupJobHandler(mux *http.ServeMux) (*httptest.Server, *JobHandler) {
	srv := httptest.NewServer(mux)
	client, _ := slurm.NewClient(srv.URL, srv.Client())
	jobSvc := service.NewJobService(client, zap.NewNop())
	return srv, NewJobHandler(jobSvc, zap.NewNop())
}

// setupJobHandlerWithObserver is like setupJobHandler but wires an observed
// zap logger (Debug and above) into both service and handler so tests can
// assert on emitted log entries.
func setupJobHandlerWithObserver(mux *http.ServeMux) (*httptest.Server, *JobHandler, *observer.ObservedLogs) {
	core, recorded := observer.New(zapcore.DebugLevel)
	l := zap.New(core)
	srv := httptest.NewServer(mux)
	client, _ := slurm.NewClient(srv.URL, srv.Client())
	jobSvc := service.NewJobService(client, l)
	return srv, NewJobHandler(jobSvc, l), recorded
}
// handlerLogs returns the observed entries that carry a "method" context
// field — i.e. entries emitted by the handler layer, which tags every log
// with the handler method name.
func handlerLogs(logs *observer.ObservedLogs) []observer.LoggedEntry {
	var fromHandler []observer.LoggedEntry
	for _, entry := range logs.All() {
		hasMethod := false
		for _, field := range entry.Context {
			if field.Key == "method" {
				hasMethod = true
				break
			}
		}
		if hasMethod {
			fromHandler = append(fromHandler, entry)
		}
	}
	return fromHandler
}
func TestSubmitJob_Success(t *testing.T) {
mux := http.NewServeMux()
mux.HandleFunc("/slurm/v0.0.40/job/submit", func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(slurm.OpenapiJobSubmitResponse{
Result: &slurm.JobSubmitResponseMsg{JobID: slurm.Ptr(int32(123))},
})
})
srv, handler := setupJobHandler(mux)
defer srv.Close()
router := setupJobRouter(handler)
body := `{"script":"#!/bin/bash\necho hello"}`
req := httptest.NewRequest(http.MethodPost, "/api/v1/jobs/submit", bytes.NewBufferString(body))
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
if w.Code != http.StatusCreated {
t.Fatalf("expected 201, got %d: %s", w.Code, w.Body.String())
}
var resp map[string]interface{}
json.Unmarshal(w.Body.Bytes(), &resp)
if !resp["success"].(bool) {
t.Fatal("expected success=true")
}
data := resp["data"].(map[string]interface{})
if int(data["job_id"].(float64)) != 123 {
t.Errorf("expected job_id=123, got %v", data["job_id"])
}
}
func TestSubmitJob_MissingScript(t *testing.T) {
mux := http.NewServeMux()
srv, handler := setupJobHandler(mux)
defer srv.Close()
router := setupJobRouter(handler)
body := `{"partition":"normal"}`
req := httptest.NewRequest(http.MethodPost, "/api/v1/jobs/submit", bytes.NewBufferString(body))
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
if w.Code != http.StatusBadRequest {
t.Fatalf("expected 400, got %d: %s", w.Code, w.Body.String())
}
var resp map[string]interface{}
json.Unmarshal(w.Body.Bytes(), &resp)
if resp["success"].(bool) {
t.Fatal("expected success=false")
}
if resp["error"] != "invalid request body" && resp["error"] != "script is required" {
t.Errorf("expected validation error, got %v", resp["error"])
}
}
func TestSubmitJob_EmptyScript(t *testing.T) {
mux := http.NewServeMux()
srv, handler := setupJobHandler(mux)
defer srv.Close()
router := setupJobRouter(handler)
body := `{"script":""}`
req := httptest.NewRequest(http.MethodPost, "/api/v1/jobs/submit", bytes.NewBufferString(body))
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
if w.Code != http.StatusBadRequest {
t.Fatalf("expected 400, got %d: %s", w.Code, w.Body.String())
}
var resp map[string]interface{}
json.Unmarshal(w.Body.Bytes(), &resp)
if resp["success"].(bool) {
t.Fatal("expected success=false")
}
}
func TestSubmitJob_SlurmError(t *testing.T) {
mux := http.NewServeMux()
mux.HandleFunc("/slurm/v0.0.40/job/submit", func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusInternalServerError)
fmt.Fprint(w, `{"errors":[{"error":"internal error"}]}`)
})
srv, handler := setupJobHandler(mux)
defer srv.Close()
router := setupJobRouter(handler)
body := `{"script":"#!/bin/bash\necho hello"}`
req := httptest.NewRequest(http.MethodPost, "/api/v1/jobs/submit", bytes.NewBufferString(body))
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
if w.Code != http.StatusBadGateway {
t.Fatalf("expected 502, got %d: %s", w.Code, w.Body.String())
}
}
// TestGetJobs_Success verifies GET /jobs returns 200 with success=true when
// the slurm API responds with a job list.
func TestGetJobs_Success(t *testing.T) {
	mux := http.NewServeMux()
	mux.HandleFunc("/slurm/v0.0.40/jobs", func(w http.ResponseWriter, r *http.Request) {
		w.Header().Set("Content-Type", "application/json")
		json.NewEncoder(w).Encode(slurm.OpenapiJobInfoResp{
			Jobs: []slurm.JobInfo{
				{JobID: slurm.Ptr(int32(1)), Name: slurm.Ptr("job1")},
				{JobID: slurm.Ptr(int32(2)), Name: slurm.Ptr("job2")},
			},
		})
	})
	srv, handler := setupJobHandler(mux)
	defer srv.Close()
	router := setupJobRouter(handler)
	req := httptest.NewRequest(http.MethodGet, "/api/v1/jobs", nil)
	w := httptest.NewRecorder()
	router.ServeHTTP(w, req)
	if w.Code != http.StatusOK {
		t.Fatalf("expected 200, got %d: %s", w.Code, w.Body.String())
	}
	var resp map[string]interface{}
	// Fix: check the unmarshal error and the type assertion so a malformed
	// response fails the test instead of panicking.
	if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
		t.Fatalf("failed to decode response: %v", err)
	}
	success, ok := resp["success"].(bool)
	if !ok || !success {
		t.Fatalf("expected success=true, got: %s", w.Body.String())
	}
}
// TestGetJob_Success verifies GET /jobs/:id returns the matching job's data.
func TestGetJob_Success(t *testing.T) {
	mux := http.NewServeMux()
	mux.HandleFunc("/slurm/v0.0.40/job/42", func(w http.ResponseWriter, r *http.Request) {
		w.Header().Set("Content-Type", "application/json")
		json.NewEncoder(w).Encode(slurm.OpenapiJobInfoResp{
			Jobs: []slurm.JobInfo{
				{JobID: slurm.Ptr(int32(42)), Name: slurm.Ptr("test-job")},
			},
		})
	})
	srv, handler := setupJobHandler(mux)
	defer srv.Close()
	router := setupJobRouter(handler)
	req := httptest.NewRequest(http.MethodGet, "/api/v1/jobs/42", nil)
	w := httptest.NewRecorder()
	router.ServeHTTP(w, req)
	if w.Code != http.StatusOK {
		t.Fatalf("expected 200, got %d: %s", w.Code, w.Body.String())
	}
	var resp map[string]interface{}
	// Fix: decode errors and missing keys previously caused type-assertion
	// panics instead of clear test failures.
	if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
		t.Fatalf("failed to decode response: %v", err)
	}
	data, ok := resp["data"].(map[string]interface{})
	if !ok {
		t.Fatalf("response missing 'data' object: %s", w.Body.String())
	}
	jobID, ok := data["job_id"].(float64)
	if !ok {
		t.Fatalf("data missing numeric 'job_id': %v", data)
	}
	if int(jobID) != 42 {
		t.Errorf("expected job_id=42, got %v", data["job_id"])
	}
}
// TestGetJob_NotFound verifies that an empty slurm job list maps to 404 with
// the "job not found" error message.
func TestGetJob_NotFound(t *testing.T) {
	mux := http.NewServeMux()
	mux.HandleFunc("/slurm/v0.0.40/job/999", func(w http.ResponseWriter, r *http.Request) {
		w.Header().Set("Content-Type", "application/json")
		json.NewEncoder(w).Encode(slurm.OpenapiJobInfoResp{
			Jobs: []slurm.JobInfo{},
		})
	})
	srv, handler := setupJobHandler(mux)
	defer srv.Close()
	router := setupJobRouter(handler)
	req := httptest.NewRequest(http.MethodGet, "/api/v1/jobs/999", nil)
	w := httptest.NewRecorder()
	router.ServeHTTP(w, req)
	if w.Code != http.StatusNotFound {
		t.Fatalf("expected 404, got %d: %s", w.Code, w.Body.String())
	}
	var resp map[string]interface{}
	// Fix: check the unmarshal error and the string assertion so a malformed
	// response fails the test instead of panicking.
	if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
		t.Fatalf("failed to decode response: %v", err)
	}
	errMsg, ok := resp["error"].(string)
	if !ok {
		t.Fatalf("response missing string 'error' field: %s", w.Body.String())
	}
	if errMsg != "job not found" {
		t.Errorf("expected 'job not found' error, got %v", resp["error"])
	}
}
// TestCancelJob_Success verifies DELETE /jobs/:id returns 200 with the
// "job cancelled" confirmation message.
func TestCancelJob_Success(t *testing.T) {
	mux := http.NewServeMux()
	mux.HandleFunc("/slurm/v0.0.40/job/42", func(w http.ResponseWriter, r *http.Request) {
		w.Header().Set("Content-Type", "application/json")
		json.NewEncoder(w).Encode(slurm.OpenapiResp{})
	})
	srv, handler := setupJobHandler(mux)
	defer srv.Close()
	router := setupJobRouter(handler)
	req := httptest.NewRequest(http.MethodDelete, "/api/v1/jobs/42", nil)
	w := httptest.NewRecorder()
	router.ServeHTTP(w, req)
	if w.Code != http.StatusOK {
		t.Fatalf("expected 200, got %d: %s", w.Code, w.Body.String())
	}
	var resp map[string]interface{}
	// Fix: check the unmarshal error and each assertion so malformed payloads
	// report a failure rather than panicking mid-test.
	if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
		t.Fatalf("failed to decode response: %v", err)
	}
	success, ok := resp["success"].(bool)
	if !ok || !success {
		t.Fatal("expected success=true")
	}
	data, ok := resp["data"].(map[string]interface{})
	if !ok {
		t.Fatalf("response missing 'data' object: %s", w.Body.String())
	}
	msg, ok := data["message"].(string)
	if !ok || msg != "job cancelled" {
		t.Errorf("expected 'job cancelled', got %v", data["message"])
	}
}
// TestGetJobHistory_Success verifies the history endpoint paginates the
// slurmdb job list: page 1 of size 2 over 3 jobs returns 2 jobs, total=3.
func TestGetJobHistory_Success(t *testing.T) {
	mux := http.NewServeMux()
	mux.HandleFunc("/slurmdb/v0.0.40/jobs", func(w http.ResponseWriter, r *http.Request) {
		w.Header().Set("Content-Type", "application/json")
		json.NewEncoder(w).Encode(slurm.OpenapiSlurmdbdJobsResp{
			Jobs: []slurm.Job{
				{JobID: slurm.Ptr(int32(1)), Name: slurm.Ptr("hist1")},
				{JobID: slurm.Ptr(int32(2)), Name: slurm.Ptr("hist2")},
				{JobID: slurm.Ptr(int32(3)), Name: slurm.Ptr("hist3")},
			},
		})
	})
	srv, handler := setupJobHandler(mux)
	defer srv.Close()
	router := setupJobRouter(handler)
	req := httptest.NewRequest(http.MethodGet, "/api/v1/jobs/history?page=1&page_size=2", nil)
	w := httptest.NewRecorder()
	router.ServeHTTP(w, req)
	if w.Code != http.StatusOK {
		t.Fatalf("expected 200, got %d: %s", w.Code, w.Body.String())
	}
	var resp map[string]interface{}
	// Fix: check the unmarshal error and the map/slice assertions so payload
	// drift fails the test instead of panicking.
	if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
		t.Fatalf("failed to decode response: %v", err)
	}
	data, ok := resp["data"].(map[string]interface{})
	if !ok {
		t.Fatalf("response missing 'data' object: %s", w.Body.String())
	}
	jobs, ok := data["jobs"].([]interface{})
	if !ok {
		t.Fatalf("data missing 'jobs' array: %v", data)
	}
	if len(jobs) != 2 {
		t.Fatalf("expected 2 jobs on page 1, got %d", len(jobs))
	}
	if total, ok := data["total"].(float64); !ok || int(total) != 3 {
		t.Errorf("expected total=3, got %v", data["total"])
	}
	if page, ok := data["page"].(float64); !ok || int(page) != 1 {
		t.Errorf("expected page=1, got %v", data["page"])
	}
	if pageSize, ok := data["page_size"].(float64); !ok || int(pageSize) != 2 {
		t.Errorf("expected page_size=2, got %v", data["page_size"])
	}
}
// TestGetJobHistory_DefaultPagination verifies the history endpoint falls
// back to page=1 and page_size=20 when no query parameters are given.
func TestGetJobHistory_DefaultPagination(t *testing.T) {
	mux := http.NewServeMux()
	mux.HandleFunc("/slurmdb/v0.0.40/jobs", func(w http.ResponseWriter, r *http.Request) {
		w.Header().Set("Content-Type", "application/json")
		json.NewEncoder(w).Encode(slurm.OpenapiSlurmdbdJobsResp{
			Jobs: []slurm.Job{
				{JobID: slurm.Ptr(int32(1)), Name: slurm.Ptr("h1")},
			},
		})
	})
	srv, handler := setupJobHandler(mux)
	defer srv.Close()
	router := setupJobRouter(handler)
	req := httptest.NewRequest(http.MethodGet, "/api/v1/jobs/history", nil)
	w := httptest.NewRecorder()
	router.ServeHTTP(w, req)
	if w.Code != http.StatusOK {
		t.Fatalf("expected 200, got %d: %s", w.Code, w.Body.String())
	}
	var resp map[string]interface{}
	// Fix: check the unmarshal error and the type assertions so a malformed
	// response fails the test instead of panicking.
	if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
		t.Fatalf("failed to decode response: %v", err)
	}
	data, ok := resp["data"].(map[string]interface{})
	if !ok {
		t.Fatalf("response missing 'data' object: %s", w.Body.String())
	}
	if page, ok := data["page"].(float64); !ok || int(page) != 1 {
		t.Errorf("expected default page=1, got %v", data["page"])
	}
	if pageSize, ok := data["page_size"].(float64); !ok || int(pageSize) != 20 {
		t.Errorf("expected default page_size=20, got %v", data["page_size"])
	}
}
// TestSubmitJob_InvalidBody verifies malformed JSON in the submit body is
// rejected with 400.
func TestSubmitJob_InvalidBody(t *testing.T) {
	srv, handler := setupJobHandler(http.NewServeMux())
	defer srv.Close()
	req := httptest.NewRequest(http.MethodPost, "/api/v1/jobs/submit", bytes.NewBufferString(`not json`))
	req.Header.Set("Content-Type", "application/json")
	rec := httptest.NewRecorder()
	setupJobRouter(handler).ServeHTTP(rec, req)
	if rec.Code != http.StatusBadRequest {
		t.Fatalf("expected 400, got %d: %s", rec.Code, rec.Body.String())
	}
}
// --- Logging verification tests ---
// TestSubmitJob_InvalidBody_LogsWarn verifies a bind failure emits exactly
// one Warn-level handler log entry tagged method=SubmitJob.
func TestSubmitJob_InvalidBody_LogsWarn(t *testing.T) {
	mux := http.NewServeMux()
	srv, handler, logs := setupJobHandlerWithObserver(mux)
	defer srv.Close()
	router := setupJobRouter(handler)
	req := httptest.NewRequest(http.MethodPost, "/api/v1/jobs/submit", bytes.NewBufferString(`not json`))
	req.Header.Set("Content-Type", "application/json")
	w := httptest.NewRecorder()
	router.ServeHTTP(w, req)
	if w.Code != http.StatusBadRequest {
		t.Fatalf("expected 400, got %d", w.Code)
	}
	hLogs := handlerLogs(logs)
	if len(hLogs) != 1 {
		t.Fatalf("expected 1 handler log entry, got %d", len(hLogs))
	}
	entry := hLogs[0]
	if entry.Level != zapcore.WarnLevel {
		t.Errorf("expected Warn level, got %v", entry.Level)
	}
	// Fix: scan all fields instead of assuming "method" is Context[0];
	// positional indexing panics on an empty field list and breaks if field
	// order changes. Sibling tests already use this scan pattern.
	foundMethod := false
	for _, f := range entry.Context {
		if f.Key == "method" && f.String == "SubmitJob" {
			foundMethod = true
		}
	}
	if !foundMethod {
		t.Error("expected method=SubmitJob in log fields")
	}
}
// TestSubmitJob_EmptyScript_LogsWarn verifies the empty-script validation
// failure produces exactly one Warn-level handler log entry.
func TestSubmitJob_EmptyScript_LogsWarn(t *testing.T) {
	srv, handler, logs := setupJobHandlerWithObserver(http.NewServeMux())
	defer srv.Close()
	req := httptest.NewRequest(http.MethodPost, "/api/v1/jobs/submit", bytes.NewBufferString(`{"script":""}`))
	req.Header.Set("Content-Type", "application/json")
	rec := httptest.NewRecorder()
	setupJobRouter(handler).ServeHTTP(rec, req)
	if rec.Code != http.StatusBadRequest {
		t.Fatalf("expected 400, got %d", rec.Code)
	}
	entries := handlerLogs(logs)
	if len(entries) != 1 {
		t.Fatalf("expected 1 handler log entry, got %d", len(entries))
	}
	if got := entries[0].Level; got != zapcore.WarnLevel {
		t.Errorf("expected Warn level, got %v", got)
	}
}
// TestSubmitJob_SlurmError_LogsError verifies an upstream slurm 5xx produces
// exactly one Error-level handler log tagged method=SubmitJob and status=502.
func TestSubmitJob_SlurmError_LogsError(t *testing.T) {
	mux := http.NewServeMux()
	mux.HandleFunc("/slurm/v0.0.40/job/submit", func(w http.ResponseWriter, _ *http.Request) {
		w.WriteHeader(http.StatusInternalServerError)
		fmt.Fprint(w, `{"errors":[{"error":"internal error"}]}`)
	})
	srv, handler, logs := setupJobHandlerWithObserver(mux)
	defer srv.Close()
	req := httptest.NewRequest(http.MethodPost, "/api/v1/jobs/submit", bytes.NewBufferString(`{"script":"#!/bin/bash\necho hello"}`))
	req.Header.Set("Content-Type", "application/json")
	rec := httptest.NewRecorder()
	setupJobRouter(handler).ServeHTTP(rec, req)
	if rec.Code != http.StatusBadGateway {
		t.Fatalf("expected 502, got %d", rec.Code)
	}
	entries := handlerLogs(logs)
	if len(entries) != 1 {
		t.Fatalf("expected 1 handler log entry, got %d", len(entries))
	}
	entry := entries[0]
	if entry.Level != zapcore.ErrorLevel {
		t.Errorf("expected Error level, got %v", entry.Level)
	}
	var foundMethod, foundStatus bool
	for _, f := range entry.Context {
		switch {
		case f.Key == "method" && f.String == "SubmitJob":
			foundMethod = true
		case f.Key == "status" && f.Integer == http.StatusBadGateway:
			foundStatus = true
		}
	}
	if !foundMethod {
		t.Error("expected method=SubmitJob in log fields")
	}
	if !foundStatus {
		t.Error("expected status=502 in log fields")
	}
}
// TestSubmitJob_Success_NoHandlerLogs verifies a successful submission emits
// no handler-level log entries (access logging is the middleware's job).
func TestSubmitJob_Success_NoHandlerLogs(t *testing.T) {
	mux := http.NewServeMux()
	mux.HandleFunc("/slurm/v0.0.40/job/submit", func(w http.ResponseWriter, _ *http.Request) {
		w.Header().Set("Content-Type", "application/json")
		json.NewEncoder(w).Encode(slurm.OpenapiJobSubmitResponse{
			Result: &slurm.JobSubmitResponseMsg{JobID: slurm.Ptr(int32(123))},
		})
	})
	srv, handler, logs := setupJobHandlerWithObserver(mux)
	defer srv.Close()
	req := httptest.NewRequest(http.MethodPost, "/api/v1/jobs/submit", bytes.NewBufferString(`{"script":"#!/bin/bash\necho hello"}`))
	req.Header.Set("Content-Type", "application/json")
	rec := httptest.NewRecorder()
	setupJobRouter(handler).ServeHTTP(rec, req)
	if rec.Code != http.StatusCreated {
		t.Fatalf("expected 201, got %d", rec.Code)
	}
	if entries := handlerLogs(logs); len(entries) != 0 {
		t.Errorf("expected no handler log entries on success, got %d", len(entries))
	}
}
// TestGetJobs_Error_LogsError verifies an upstream 500 while listing jobs
// produces exactly one Error-level handler log tagged method=GetJobs.
func TestGetJobs_Error_LogsError(t *testing.T) {
	mux := http.NewServeMux()
	mux.HandleFunc("/slurm/v0.0.40/jobs", func(w http.ResponseWriter, _ *http.Request) {
		w.WriteHeader(http.StatusInternalServerError)
	})
	srv, handler, logs := setupJobHandlerWithObserver(mux)
	defer srv.Close()
	rec := httptest.NewRecorder()
	setupJobRouter(handler).ServeHTTP(rec, httptest.NewRequest(http.MethodGet, "/api/v1/jobs", nil))
	if rec.Code != http.StatusInternalServerError {
		t.Fatalf("expected 500, got %d", rec.Code)
	}
	entries := handlerLogs(logs)
	if len(entries) != 1 {
		t.Fatalf("expected 1 handler log entry, got %d", len(entries))
	}
	entry := entries[0]
	if entry.Level != zapcore.ErrorLevel {
		t.Errorf("expected Error level, got %v", entry.Level)
	}
	foundMethod := false
	for _, f := range entry.Context {
		if f.Key == "method" && f.String == "GetJobs" {
			foundMethod = true
			break
		}
	}
	if !foundMethod {
		t.Error("expected method=GetJobs in log fields")
	}
}
// TestGetJob_NotFound_LogsWarn verifies a missing job yields exactly one
// Warn-level handler log tagged method=GetJob.
func TestGetJob_NotFound_LogsWarn(t *testing.T) {
	mux := http.NewServeMux()
	mux.HandleFunc("/slurm/v0.0.40/job/999", func(w http.ResponseWriter, _ *http.Request) {
		w.Header().Set("Content-Type", "application/json")
		json.NewEncoder(w).Encode(slurm.OpenapiJobInfoResp{Jobs: []slurm.JobInfo{}})
	})
	srv, handler, logs := setupJobHandlerWithObserver(mux)
	defer srv.Close()
	rec := httptest.NewRecorder()
	setupJobRouter(handler).ServeHTTP(rec, httptest.NewRequest(http.MethodGet, "/api/v1/jobs/999", nil))
	if rec.Code != http.StatusNotFound {
		t.Fatalf("expected 404, got %d", rec.Code)
	}
	entries := handlerLogs(logs)
	if len(entries) != 1 {
		t.Fatalf("expected 1 handler log entry, got %d", len(entries))
	}
	entry := entries[0]
	if entry.Level != zapcore.WarnLevel {
		t.Errorf("expected Warn level, got %v", entry.Level)
	}
	foundMethod := false
	for _, f := range entry.Context {
		if f.Key == "method" && f.String == "GetJob" {
			foundMethod = true
			break
		}
	}
	if !foundMethod {
		t.Error("expected method=GetJob in log fields")
	}
}
// TestGetJob_Error_LogsError verifies an upstream 500 fetching one job
// produces exactly one Error-level handler log entry.
func TestGetJob_Error_LogsError(t *testing.T) {
	mux := http.NewServeMux()
	mux.HandleFunc("/slurm/v0.0.40/job/42", func(w http.ResponseWriter, _ *http.Request) {
		w.WriteHeader(http.StatusInternalServerError)
	})
	srv, handler, logs := setupJobHandlerWithObserver(mux)
	defer srv.Close()
	rec := httptest.NewRecorder()
	setupJobRouter(handler).ServeHTTP(rec, httptest.NewRequest(http.MethodGet, "/api/v1/jobs/42", nil))
	if rec.Code != http.StatusInternalServerError {
		t.Fatalf("expected 500, got %d", rec.Code)
	}
	entries := handlerLogs(logs)
	if len(entries) != 1 {
		t.Fatalf("expected 1 handler log entry, got %d", len(entries))
	}
	if got := entries[0].Level; got != zapcore.ErrorLevel {
		t.Errorf("expected Error level, got %v", got)
	}
}
// TestCancelJob_SlurmError_LogsError verifies a failed cancellation produces
// exactly one Error-level handler log tagged method=CancelJob.
func TestCancelJob_SlurmError_LogsError(t *testing.T) {
	mux := http.NewServeMux()
	mux.HandleFunc("/slurm/v0.0.40/job/42", func(w http.ResponseWriter, _ *http.Request) {
		w.WriteHeader(http.StatusInternalServerError)
	})
	srv, handler, logs := setupJobHandlerWithObserver(mux)
	defer srv.Close()
	rec := httptest.NewRecorder()
	setupJobRouter(handler).ServeHTTP(rec, httptest.NewRequest(http.MethodDelete, "/api/v1/jobs/42", nil))
	if rec.Code != http.StatusBadGateway {
		t.Fatalf("expected 502, got %d", rec.Code)
	}
	entries := handlerLogs(logs)
	if len(entries) != 1 {
		t.Fatalf("expected 1 handler log entry, got %d", len(entries))
	}
	entry := entries[0]
	if entry.Level != zapcore.ErrorLevel {
		t.Errorf("expected Error level, got %v", entry.Level)
	}
	foundMethod := false
	for _, f := range entry.Context {
		if f.Key == "method" && f.String == "CancelJob" {
			foundMethod = true
			break
		}
	}
	if !foundMethod {
		t.Error("expected method=CancelJob in log fields")
	}
}
// TestGetJobHistory_InvalidQuery_LogsWarn verifies an unparsable pagination
// query yields 400 and exactly one Warn log tagged method=GetJobHistory.
func TestGetJobHistory_InvalidQuery_LogsWarn(t *testing.T) {
	mux := http.NewServeMux()
	mux.HandleFunc("/slurmdb/v0.0.40/jobs", func(w http.ResponseWriter, _ *http.Request) {
		w.Header().Set("Content-Type", "application/json")
		json.NewEncoder(w).Encode(slurm.OpenapiSlurmdbdJobsResp{})
	})
	srv, handler, logs := setupJobHandlerWithObserver(mux)
	defer srv.Close()
	rec := httptest.NewRecorder()
	setupJobRouter(handler).ServeHTTP(rec, httptest.NewRequest(http.MethodGet, "/api/v1/jobs/history?page=abc", nil))
	if rec.Code != http.StatusBadRequest {
		t.Fatalf("expected 400, got %d: %s", rec.Code, rec.Body.String())
	}
	entries := handlerLogs(logs)
	if len(entries) != 1 {
		t.Fatalf("expected 1 handler log entry, got %d", len(entries))
	}
	entry := entries[0]
	if entry.Level != zapcore.WarnLevel {
		t.Errorf("expected Warn level, got %v", entry.Level)
	}
	foundMethod := false
	for _, f := range entry.Context {
		if f.Key == "method" && f.String == "GetJobHistory" {
			foundMethod = true
			break
		}
	}
	if !foundMethod {
		t.Error("expected method=GetJobHistory in log fields")
	}
}
// TestGetJobHistory_Error_LogsError verifies an upstream slurmdb 500 produces
// exactly one Error-level handler log entry.
func TestGetJobHistory_Error_LogsError(t *testing.T) {
	mux := http.NewServeMux()
	mux.HandleFunc("/slurmdb/v0.0.40/jobs", func(w http.ResponseWriter, _ *http.Request) {
		w.WriteHeader(http.StatusInternalServerError)
	})
	srv, handler, logs := setupJobHandlerWithObserver(mux)
	defer srv.Close()
	rec := httptest.NewRecorder()
	setupJobRouter(handler).ServeHTTP(rec, httptest.NewRequest(http.MethodGet, "/api/v1/jobs/history?page=1&page_size=10", nil))
	if rec.Code != http.StatusInternalServerError {
		t.Fatalf("expected 500, got %d", rec.Code)
	}
	entries := handlerLogs(logs)
	if len(entries) != 1 {
		t.Fatalf("expected 1 handler log entry, got %d", len(entries))
	}
	if got := entries[0].Level; got != zapcore.ErrorLevel {
		t.Errorf("expected Error level, got %v", got)
	}
}
// TestGetJobHistory_Success_NoHandlerLogs verifies a successful history query
// emits no handler-level log entries.
func TestGetJobHistory_Success_NoHandlerLogs(t *testing.T) {
	mux := http.NewServeMux()
	mux.HandleFunc("/slurmdb/v0.0.40/jobs", func(w http.ResponseWriter, _ *http.Request) {
		w.Header().Set("Content-Type", "application/json")
		json.NewEncoder(w).Encode(slurm.OpenapiSlurmdbdJobsResp{
			Jobs: []slurm.Job{
				{JobID: slurm.Ptr(int32(1)), Name: slurm.Ptr("h1")},
			},
		})
	})
	srv, handler, logs := setupJobHandlerWithObserver(mux)
	defer srv.Close()
	rec := httptest.NewRecorder()
	setupJobRouter(handler).ServeHTTP(rec, httptest.NewRequest(http.MethodGet, "/api/v1/jobs/history?page=1&page_size=10", nil))
	if rec.Code != http.StatusOK {
		t.Fatalf("expected 200, got %d", rec.Code)
	}
	if entries := handlerLogs(logs); len(entries) != 0 {
		t.Errorf("expected no handler log entries on success, got %d", len(entries))
	}
}
// TestGetJobs_Success_NoHandlerLogs verifies a successful jobs listing emits
// no handler-level log entries.
func TestGetJobs_Success_NoHandlerLogs(t *testing.T) {
	mux := http.NewServeMux()
	mux.HandleFunc("/slurm/v0.0.40/jobs", func(w http.ResponseWriter, _ *http.Request) {
		w.Header().Set("Content-Type", "application/json")
		json.NewEncoder(w).Encode(slurm.OpenapiJobInfoResp{Jobs: []slurm.JobInfo{}})
	})
	srv, handler, logs := setupJobHandlerWithObserver(mux)
	defer srv.Close()
	rec := httptest.NewRecorder()
	setupJobRouter(handler).ServeHTTP(rec, httptest.NewRequest(http.MethodGet, "/api/v1/jobs", nil))
	if rec.Code != http.StatusOK {
		t.Fatalf("expected 200, got %d", rec.Code)
	}
	if entries := handlerLogs(logs); len(entries) != 0 {
		t.Errorf("expected no handler log entries on success, got %d", len(entries))
	}
}
// TestCancelJob_Success_NoHandlerLogs verifies a successful cancellation
// emits no handler-level log entries.
func TestCancelJob_Success_NoHandlerLogs(t *testing.T) {
	mux := http.NewServeMux()
	mux.HandleFunc("/slurm/v0.0.40/job/42", func(w http.ResponseWriter, _ *http.Request) {
		w.Header().Set("Content-Type", "application/json")
		json.NewEncoder(w).Encode(slurm.OpenapiResp{})
	})
	srv, handler, logs := setupJobHandlerWithObserver(mux)
	defer srv.Close()
	rec := httptest.NewRecorder()
	setupJobRouter(handler).ServeHTTP(rec, httptest.NewRequest(http.MethodDelete, "/api/v1/jobs/42", nil))
	if rec.Code != http.StatusOK {
		t.Fatalf("expected 200, got %d", rec.Code)
	}
	if entries := handlerLogs(logs); len(entries) != 0 {
		t.Errorf("expected no handler log entries on success, got %d", len(entries))
	}
}

View File

@@ -0,0 +1,139 @@
package handler
import (
"strconv"
"gcy_hpc_server/internal/model"
"gcy_hpc_server/internal/server"
"gcy_hpc_server/internal/store"
"github.com/gin-gonic/gin"
"go.uber.org/zap"
)
// TemplateHandler serves the /templates CRUD HTTP endpoints.
type TemplateHandler struct {
	store *store.TemplateStore // persistence layer for job templates
	logger *zap.Logger // structured logger injected by the app wiring
}
// NewTemplateHandler builds a TemplateHandler from its store and logger
// dependencies.
func NewTemplateHandler(s *store.TemplateStore, logger *zap.Logger) *TemplateHandler {
	return &TemplateHandler{store: s, logger: logger}
}
// ListTemplates handles GET /templates with page/page_size pagination;
// out-of-range values fall back to page=1 and page_size=20.
func (h *TemplateHandler) ListTemplates(c *gin.Context) {
	pageStr := c.DefaultQuery("page", "1")
	sizeStr := c.DefaultQuery("page_size", "20")
	page, _ := strconv.Atoi(pageStr)
	pageSize, _ := strconv.Atoi(sizeStr)
	if page < 1 {
		page = 1
	}
	if pageSize < 1 {
		pageSize = 20
	}
	items, total, err := h.store.List(c.Request.Context(), page, pageSize)
	if err != nil {
		h.logger.Error("failed to list templates", zap.Error(err))
		server.InternalError(c, err.Error())
		return
	}
	server.OK(c, gin.H{
		"templates": items,
		"total": total,
		"page": page,
		"page_size": pageSize,
	})
}
// CreateTemplate handles POST /templates; name and script are mandatory.
// On success it logs the new id (never the template content) and returns 201.
func (h *TemplateHandler) CreateTemplate(c *gin.Context) {
	var payload model.CreateTemplateRequest
	if err := c.ShouldBindJSON(&payload); err != nil {
		h.logger.Warn("invalid request body for create template", zap.Error(err))
		server.BadRequest(c, "invalid request body")
		return
	}
	if payload.Name == "" || payload.Script == "" {
		h.logger.Warn("missing required fields for create template")
		server.BadRequest(c, "name and script are required")
		return
	}
	newID, err := h.store.Create(c.Request.Context(), &payload)
	if err != nil {
		h.logger.Error("failed to create template", zap.Error(err))
		server.InternalError(c, err.Error())
		return
	}
	h.logger.Info("template created", zap.Int64("id", newID))
	server.Created(c, gin.H{"id": newID})
}
// GetTemplate handles GET /templates/:id, returning 400 for a non-numeric id
// and 404 when no row matches.
func (h *TemplateHandler) GetTemplate(c *gin.Context) {
	rawID := c.Param("id")
	tplID, err := strconv.ParseInt(rawID, 10, 64)
	if err != nil {
		h.logger.Warn("invalid template id", zap.String("id", rawID))
		server.BadRequest(c, "invalid id")
		return
	}
	found, err := h.store.GetByID(c.Request.Context(), tplID)
	if err != nil {
		h.logger.Error("failed to get template", zap.Int64("id", tplID), zap.Error(err))
		server.InternalError(c, err.Error())
		return
	}
	if found == nil {
		h.logger.Warn("template not found", zap.Int64("id", tplID))
		server.NotFound(c, "template not found")
		return
	}
	server.OK(c, found)
}
// UpdateTemplate handles PUT /templates/:id with a partial update payload.
func (h *TemplateHandler) UpdateTemplate(c *gin.Context) {
	rawID := c.Param("id")
	tplID, err := strconv.ParseInt(rawID, 10, 64)
	if err != nil {
		h.logger.Warn("invalid template id for update", zap.String("id", rawID))
		server.BadRequest(c, "invalid id")
		return
	}
	var payload model.UpdateTemplateRequest
	if err := c.ShouldBindJSON(&payload); err != nil {
		h.logger.Warn("invalid request body for update template", zap.Int64("id", tplID), zap.Error(err))
		server.BadRequest(c, "invalid request body")
		return
	}
	if err := h.store.Update(c.Request.Context(), tplID, &payload); err != nil {
		h.logger.Error("failed to update template", zap.Int64("id", tplID), zap.Error(err))
		server.InternalError(c, err.Error())
		return
	}
	h.logger.Info("template updated", zap.Int64("id", tplID))
	server.OK(c, gin.H{"message": "template updated"})
}
// DeleteTemplate handles DELETE /templates/:id.
func (h *TemplateHandler) DeleteTemplate(c *gin.Context) {
	rawID := c.Param("id")
	tplID, err := strconv.ParseInt(rawID, 10, 64)
	if err != nil {
		h.logger.Warn("invalid template id for delete", zap.String("id", rawID))
		server.BadRequest(c, "invalid id")
		return
	}
	if err := h.store.Delete(c.Request.Context(), tplID); err != nil {
		h.logger.Error("failed to delete template", zap.Int64("id", tplID), zap.Error(err))
		server.InternalError(c, err.Error())
		return
	}
	h.logger.Info("template deleted", zap.Int64("id", tplID))
	server.OK(c, gin.H{"message": "template deleted"})
}

View File

@@ -0,0 +1,387 @@
package handler
import (
	"bytes"
	"context"
	"encoding/json"
	"fmt"
	"net/http"
	"net/http/httptest"
	"strconv"
	"strings"
	"testing"

	"gcy_hpc_server/internal/model"
	"gcy_hpc_server/internal/server"
	"gcy_hpc_server/internal/store"

	"github.com/gin-gonic/gin"
	"go.uber.org/zap"
	"go.uber.org/zap/zapcore"
	"go.uber.org/zap/zaptest/observer"
	"gorm.io/driver/sqlite"
	"gorm.io/gorm"
	"gorm.io/gorm/logger"
)
// setupTemplateHandler builds a TemplateHandler over an in-memory sqlite DB
// with a no-op logger; the DB is returned for direct inspection.
func setupTemplateHandler() (*TemplateHandler, *gorm.DB) {
	db, _ := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{Logger: logger.Default.LogMode(logger.Silent)})
	db.AutoMigrate(&model.JobTemplate{})
	return NewTemplateHandler(store.NewTemplateStore(db), zap.NewNop()), db
}
// setupTemplateRouter registers the template CRUD routes on a test-mode gin
// engine.
func setupTemplateRouter(h *TemplateHandler) *gin.Engine {
	gin.SetMode(gin.TestMode)
	engine := gin.New()
	group := engine.Group("/api/v1").Group("/templates")
	group.GET("", h.ListTemplates)
	group.POST("", h.CreateTemplate)
	group.GET("/:id", h.GetTemplate)
	group.PUT("/:id", h.UpdateTemplate)
	group.DELETE("/:id", h.DeleteTemplate)
	return engine
}
// TestListTemplates_Success verifies GET /templates returns 200 with a
// success envelope once a template exists.
func TestListTemplates_Success(t *testing.T) {
	h, _ := setupTemplateHandler()
	router := setupTemplateRouter(h)
	// Seed one template directly through the store.
	h.store.Create(context.Background(), &model.CreateTemplateRequest{Name: "test-tpl", Script: "echo hi", Partition: "normal", QOS: "high", CPUs: 4, Memory: "4GB"})
	req, _ := http.NewRequest("GET", "/api/v1/templates", nil)
	rec := httptest.NewRecorder()
	router.ServeHTTP(rec, req)
	if rec.Code != http.StatusOK {
		t.Fatalf("expected 200, got %d: %s", rec.Code, rec.Body.String())
	}
	var resp server.APIResponse
	json.Unmarshal(rec.Body.Bytes(), &resp)
	if !resp.Success {
		t.Fatalf("expected success=true, got: %s", rec.Body.String())
	}
}
// TestCreateTemplate_Success verifies POST /templates returns 201 with a
// success envelope for a valid payload.
func TestCreateTemplate_Success(t *testing.T) {
	h, _ := setupTemplateHandler()
	router := setupTemplateRouter(h)
	payload := `{"name":"my-tpl","description":"desc","script":"echo hello","partition":"gpu"}`
	req, _ := http.NewRequest("POST", "/api/v1/templates", bytes.NewBufferString(payload))
	req.Header.Set("Content-Type", "application/json")
	rec := httptest.NewRecorder()
	router.ServeHTTP(rec, req)
	if rec.Code != http.StatusCreated {
		t.Fatalf("expected 201, got %d: %s", rec.Code, rec.Body.String())
	}
	var resp server.APIResponse
	json.Unmarshal(rec.Body.Bytes(), &resp)
	if !resp.Success {
		t.Fatalf("expected success=true, got: %s", rec.Body.String())
	}
}
// TestCreateTemplate_MissingFields verifies empty name/script are rejected
// with 400.
func TestCreateTemplate_MissingFields(t *testing.T) {
	h, _ := setupTemplateHandler()
	router := setupTemplateRouter(h)
	req, _ := http.NewRequest("POST", "/api/v1/templates", bytes.NewBufferString(`{"name":"","script":""}`))
	req.Header.Set("Content-Type", "application/json")
	rec := httptest.NewRecorder()
	router.ServeHTTP(rec, req)
	if rec.Code != http.StatusBadRequest {
		t.Fatalf("expected 400, got %d: %s", rec.Code, rec.Body.String())
	}
}
// TestGetTemplate_Success verifies GET /templates/:id returns 200 for an
// existing template.
func TestGetTemplate_Success(t *testing.T) {
	h, _ := setupTemplateHandler()
	router := setupTemplateRouter(h)
	// Seed one template and fetch it by the returned id.
	id, _ := h.store.Create(context.Background(), &model.CreateTemplateRequest{Name: "test-tpl", Script: "echo hi", Partition: "normal"})
	req, _ := http.NewRequest("GET", "/api/v1/templates/"+itoa(id), nil)
	rec := httptest.NewRecorder()
	router.ServeHTTP(rec, req)
	if rec.Code != http.StatusOK {
		t.Fatalf("expected 200, got %d: %s", rec.Code, rec.Body.String())
	}
	var resp server.APIResponse
	json.Unmarshal(rec.Body.Bytes(), &resp)
	if !resp.Success {
		t.Fatalf("expected success=true, got: %s", rec.Body.String())
	}
}
// TestGetTemplate_NotFound verifies an unknown id yields 404.
func TestGetTemplate_NotFound(t *testing.T) {
	h, _ := setupTemplateHandler()
	router := setupTemplateRouter(h)
	req, _ := http.NewRequest("GET", "/api/v1/templates/999", nil)
	rec := httptest.NewRecorder()
	router.ServeHTTP(rec, req)
	if rec.Code != http.StatusNotFound {
		t.Fatalf("expected 404, got %d: %s", rec.Code, rec.Body.String())
	}
}
// TestGetTemplate_InvalidID verifies a non-numeric id yields 400.
func TestGetTemplate_InvalidID(t *testing.T) {
	h, _ := setupTemplateHandler()
	router := setupTemplateRouter(h)
	req, _ := http.NewRequest("GET", "/api/v1/templates/abc", nil)
	rec := httptest.NewRecorder()
	router.ServeHTTP(rec, req)
	if rec.Code != http.StatusBadRequest {
		t.Fatalf("expected 400, got %d: %s", rec.Code, rec.Body.String())
	}
}
// TestUpdateTemplate_Success verifies PUT /templates/:id returns 200 with a
// success envelope.
func TestUpdateTemplate_Success(t *testing.T) {
	h, _ := setupTemplateHandler()
	router := setupTemplateRouter(h)
	// Seed a template to mutate.
	id, _ := h.store.Create(context.Background(), &model.CreateTemplateRequest{Name: "old", Script: "echo hi"})
	req, _ := http.NewRequest("PUT", "/api/v1/templates/"+itoa(id), bytes.NewBufferString(`{"name":"updated"}`))
	req.Header.Set("Content-Type", "application/json")
	rec := httptest.NewRecorder()
	router.ServeHTTP(rec, req)
	if rec.Code != http.StatusOK {
		t.Fatalf("expected 200, got %d: %s", rec.Code, rec.Body.String())
	}
	var resp server.APIResponse
	json.Unmarshal(rec.Body.Bytes(), &resp)
	if !resp.Success {
		t.Fatalf("expected success=true, got: %s", rec.Body.String())
	}
}
// TestDeleteTemplate_Success verifies DELETE /templates/:id returns 200 with
// a success envelope.
func TestDeleteTemplate_Success(t *testing.T) {
	h, _ := setupTemplateHandler()
	router := setupTemplateRouter(h)
	// Seed a template to remove.
	id, _ := h.store.Create(context.Background(), &model.CreateTemplateRequest{Name: "to-delete", Script: "echo hi"})
	req, _ := http.NewRequest("DELETE", "/api/v1/templates/"+itoa(id), nil)
	rec := httptest.NewRecorder()
	router.ServeHTTP(rec, req)
	if rec.Code != http.StatusOK {
		t.Fatalf("expected 200, got %d: %s", rec.Code, rec.Body.String())
	}
	var resp server.APIResponse
	json.Unmarshal(rec.Body.Bytes(), &resp)
	if !resp.Success {
		t.Fatalf("expected success=true, got: %s", rec.Body.String())
	}
}
// itoa converts int64 to string for URL path construction.
func itoa(id int64) string {
return fmt.Sprintf("%d", id)
}
// setupTemplateHandlerWithObserver builds a TemplateHandler whose zap output
// is captured by an observer core, for asserting on log entries.
func setupTemplateHandlerWithObserver() (*TemplateHandler, *gorm.DB, *observer.ObservedLogs) {
	core, recorded := observer.New(zapcore.DebugLevel)
	db, _ := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{Logger: logger.Default.LogMode(logger.Silent)})
	db.AutoMigrate(&model.JobTemplate{})
	h := NewTemplateHandler(store.NewTemplateStore(db), zap.New(core))
	return h, db, recorded
}
// TestTemplateLogging_CreateSuccess_LogsInfoWithID verifies creation logs one
// Info entry "template created" carrying the new id.
func TestTemplateLogging_CreateSuccess_LogsInfoWithID(t *testing.T) {
	h, _, recorded := setupTemplateHandlerWithObserver()
	router := setupTemplateRouter(h)
	req, _ := http.NewRequest("POST", "/api/v1/templates", bytes.NewBufferString(`{"name":"log-tpl","script":"echo hi"}`))
	req.Header.Set("Content-Type", "application/json")
	rec := httptest.NewRecorder()
	router.ServeHTTP(rec, req)
	if rec.Code != http.StatusCreated {
		t.Fatalf("expected 201, got %d: %s", rec.Code, rec.Body.String())
	}
	got := recorded.FilterMessage("template created").FilterLevelExact(zapcore.InfoLevel).All()
	if len(got) != 1 {
		t.Fatalf("expected 1 info log for 'template created', got %d", len(got))
	}
	if got[0].ContextMap()["id"] == nil {
		t.Fatal("expected log entry to contain 'id' field")
	}
}
// TestTemplateLogging_UpdateSuccess_LogsInfoWithID verifies an update logs
// one Info entry "template updated" carrying the id.
func TestTemplateLogging_UpdateSuccess_LogsInfoWithID(t *testing.T) {
	h, _, recorded := setupTemplateHandlerWithObserver()
	router := setupTemplateRouter(h)
	id, _ := h.store.Create(context.Background(), &model.CreateTemplateRequest{Name: "old", Script: "echo hi"})
	req, _ := http.NewRequest("PUT", "/api/v1/templates/"+itoa(id), bytes.NewBufferString(`{"name":"updated"}`))
	req.Header.Set("Content-Type", "application/json")
	rec := httptest.NewRecorder()
	router.ServeHTTP(rec, req)
	if rec.Code != http.StatusOK {
		t.Fatalf("expected 200, got %d: %s", rec.Code, rec.Body.String())
	}
	got := recorded.FilterMessage("template updated").FilterLevelExact(zapcore.InfoLevel).All()
	if len(got) != 1 {
		t.Fatalf("expected 1 info log for 'template updated', got %d", len(got))
	}
	if got[0].ContextMap()["id"] == nil {
		t.Fatal("expected log entry to contain 'id' field")
	}
}
// TestTemplateLogging_DeleteSuccess_LogsInfoWithID verifies a delete logs one
// Info entry "template deleted" carrying the id.
func TestTemplateLogging_DeleteSuccess_LogsInfoWithID(t *testing.T) {
	h, _, recorded := setupTemplateHandlerWithObserver()
	router := setupTemplateRouter(h)
	id, _ := h.store.Create(context.Background(), &model.CreateTemplateRequest{Name: "to-delete", Script: "echo hi"})
	req, _ := http.NewRequest("DELETE", "/api/v1/templates/"+itoa(id), nil)
	rec := httptest.NewRecorder()
	router.ServeHTTP(rec, req)
	if rec.Code != http.StatusOK {
		t.Fatalf("expected 200, got %d: %s", rec.Code, rec.Body.String())
	}
	got := recorded.FilterMessage("template deleted").FilterLevelExact(zapcore.InfoLevel).All()
	if len(got) != 1 {
		t.Fatalf("expected 1 info log for 'template deleted', got %d", len(got))
	}
	if got[0].ContextMap()["id"] == nil {
		t.Fatal("expected log entry to contain 'id' field")
	}
}
// TestTemplateLogging_GetNotFound_LogsWarnWithID verifies a 404 lookup logs
// one Warn entry "template not found" carrying the requested id.
func TestTemplateLogging_GetNotFound_LogsWarnWithID(t *testing.T) {
	h, _, recorded := setupTemplateHandlerWithObserver()
	router := setupTemplateRouter(h)
	req, _ := http.NewRequest("GET", "/api/v1/templates/999", nil)
	rec := httptest.NewRecorder()
	router.ServeHTTP(rec, req)
	if rec.Code != http.StatusNotFound {
		t.Fatalf("expected 404, got %d: %s", rec.Code, rec.Body.String())
	}
	got := recorded.FilterMessage("template not found").FilterLevelExact(zapcore.WarnLevel).All()
	if len(got) != 1 {
		t.Fatalf("expected 1 warn log for 'template not found', got %d", len(got))
	}
	if got[0].ContextMap()["id"] == nil {
		t.Fatal("expected log entry to contain 'id' field")
	}
}
// TestTemplateLogging_CreateBadRequest_LogsWarn verifies a validation failure
// on create emits at least one Warn-level log entry.
func TestTemplateLogging_CreateBadRequest_LogsWarn(t *testing.T) {
	h, _, recorded := setupTemplateHandlerWithObserver()
	router := setupTemplateRouter(h)
	req, _ := http.NewRequest("POST", "/api/v1/templates", bytes.NewBufferString(`{"name":"","script":""}`))
	req.Header.Set("Content-Type", "application/json")
	rec := httptest.NewRecorder()
	router.ServeHTTP(rec, req)
	if rec.Code != http.StatusBadRequest {
		t.Fatalf("expected 400, got %d: %s", rec.Code, rec.Body.String())
	}
	if got := recorded.FilterLevelExact(zapcore.WarnLevel).All(); len(got) == 0 {
		t.Fatal("expected at least 1 warn log for bad request")
	}
}
// TestTemplateLogging_InvalidID_LogsWarn verifies a non-numeric id emits at
// least one Warn-level log entry.
func TestTemplateLogging_InvalidID_LogsWarn(t *testing.T) {
	h, _, recorded := setupTemplateHandlerWithObserver()
	router := setupTemplateRouter(h)
	req, _ := http.NewRequest("GET", "/api/v1/templates/abc", nil)
	rec := httptest.NewRecorder()
	router.ServeHTTP(rec, req)
	if rec.Code != http.StatusBadRequest {
		t.Fatalf("expected 400, got %d: %s", rec.Code, rec.Body.String())
	}
	if got := recorded.FilterLevelExact(zapcore.WarnLevel).All(); len(got) == 0 {
		t.Fatal("expected at least 1 warn log for invalid id")
	}
}
// TestTemplateLogging_ListSuccess_NoInfoLog verifies a successful list emits
// no Info-level log entries (success logging belongs to the middleware).
func TestTemplateLogging_ListSuccess_NoInfoLog(t *testing.T) {
	h, _, recorded := setupTemplateHandlerWithObserver()
	router := setupTemplateRouter(h)
	// Seed one template, then drop the entries the seeding produced.
	h.store.Create(context.Background(), &model.CreateTemplateRequest{Name: "test-tpl", Script: "echo hi"})
	recorded.TakeAll()
	req, _ := http.NewRequest("GET", "/api/v1/templates", nil)
	rec := httptest.NewRecorder()
	router.ServeHTTP(rec, req)
	if rec.Code != http.StatusOK {
		t.Fatalf("expected 200, got %d: %s", rec.Code, rec.Body.String())
	}
	got := recorded.FilterLevelExact(zapcore.InfoLevel).All()
	if len(got) != 0 {
		t.Fatalf("expected 0 info logs for list success, got %d: %+v", len(got), got)
	}
}
// TestTemplateLogging_LogsDoNotContainTemplateContent verifies that no log
// entry leaks the template's name, script, or partition values.
func TestTemplateLogging_LogsDoNotContainTemplateContent(t *testing.T) {
	h, _, recorded := setupTemplateHandlerWithObserver()
	router := setupTemplateRouter(h)
	req, _ := http.NewRequest("POST", "/api/v1/templates", bytes.NewBufferString(`{"name":"secret-name","script":"secret-script","partition":"secret-partition"}`))
	req.Header.Set("Content-Type", "application/json")
	rec := httptest.NewRecorder()
	router.ServeHTTP(rec, req)
	if rec.Code != http.StatusCreated {
		t.Fatalf("expected 201, got %d: %s", rec.Code, rec.Body.String())
	}
	secrets := []string{"secret-name", "secret-script", "secret-partition"}
	for _, e := range recorded.All() {
		logStr := e.Message + " " + fmt.Sprintf("%v", e.ContextMap())
		for _, s := range secrets {
			if strings.Contains(logStr, s) {
				t.Fatalf("log entry contains sensitive template content: %s", logStr)
			}
		}
	}
}

148
internal/logger/gorm.go Normal file
View File

@@ -0,0 +1,148 @@
package logger
import (
"context"
"errors"
"time"
"go.uber.org/zap"
"go.uber.org/zap/zapcore"
"gorm.io/gorm"
gormlogger "gorm.io/gorm/logger"
)
// slowQueryThreshold is the elapsed duration beyond which Trace reports a
// query as slow at warn level.
const slowQueryThreshold = 200 * time.Millisecond
// GormLogger implements gorm's logger.Interface backed by zap.
type GormLogger struct {
logger *zap.Logger // destination zap logger
level zapcore.Level // minimum level emitted by Info/Warn/Error and Trace's default path
silent bool // when true, all output (including Trace) is suppressed
}
// Compile-time interface check.
var _ gormlogger.Interface = (*GormLogger)(nil)
// NewGormLogger creates a GormLogger that forwards gorm's log output to the
// given zap logger. The level string maps to zap levels and defaults to
// "warn" when empty; the special value "silent" suppresses all output.
func NewGormLogger(zapLogger *zap.Logger, level string) gormlogger.Interface {
	return &GormLogger{
		logger: zapLogger,
		level:  parseGormLevel(level),
		silent: level == "silent",
	}
}
// LogMode returns a copy of the logger adjusted to the given gorm log level.
// The receiver itself is never mutated, so shared loggers stay safe.
func (l *GormLogger) LogMode(level gormlogger.LogLevel) gormlogger.Interface {
	clone := &GormLogger{
		logger: l.logger,
		level:  l.level,
		silent: l.silent,
	}
	if level == gormlogger.Silent {
		clone.silent = true
		return clone
	}
	switch level {
	case gormlogger.Error:
		clone.level = zapcore.ErrorLevel
		clone.silent = false
	case gormlogger.Warn:
		clone.level = zapcore.WarnLevel
		clone.silent = false
	case gormlogger.Info:
		clone.level = zapcore.InfoLevel
		clone.silent = false
	}
	// Unknown levels leave the clone's settings unchanged.
	return clone
}
// Info logs at zap.InfoLevel with structured fields from key-value pairs.
func (l *GormLogger) Info(ctx context.Context, msg string, args ...any) {
	if !l.silent && l.level <= zapcore.InfoLevel {
		l.logger.Info(msg, argsToFields(args)...)
	}
}
// Warn logs at zap.WarnLevel with structured fields from key-value pairs.
func (l *GormLogger) Warn(ctx context.Context, msg string, args ...any) {
	if !l.silent && l.level <= zapcore.WarnLevel {
		l.logger.Warn(msg, argsToFields(args)...)
	}
}
// Error logs at zap.ErrorLevel with structured fields from key-value pairs.
func (l *GormLogger) Error(ctx context.Context, msg string, args ...any) {
	if !l.silent && l.level <= zapcore.ErrorLevel {
		l.logger.Error(msg, argsToFields(args)...)
	}
}
// Trace logs SQL query information based on execution results: real errors
// (anything but gorm.ErrRecordNotFound) at error level, slow queries at warn
// level, and everything else at info level subject to the configured minimum.
func (l *GormLogger) Trace(ctx context.Context, begin time.Time, fc func() (sql string, rowsAffected int64), err error) {
	if l.silent {
		return
	}
	elapsed := time.Since(begin)
	sql, rows := fc()
	// Real query failures always log, regardless of the configured level.
	if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) {
		l.logger.Error("gorm query error",
			zap.Error(err),
			zap.String("sql", sql),
			zap.Int64("rows", rows),
			zap.Duration("elapsed", elapsed),
		)
		return
	}
	// Slow queries are surfaced at warn with the elapsed time in milliseconds.
	if elapsed > slowQueryThreshold {
		l.logger.Warn("gorm slow query",
			zap.String("sql", sql),
			zap.Int64("rows", rows),
			zap.Float64("elapsed_ms", float64(elapsed.Nanoseconds())/1e6),
		)
		return
	}
	// Normal queries are info-level noise; honor the configured minimum.
	if l.level <= zapcore.InfoLevel {
		l.logger.Info("gorm query",
			zap.String("sql", sql),
			zap.Int64("rows", rows),
			zap.Duration("elapsed", elapsed),
		)
	}
}
// parseGormLevel parses a level string into a zapcore.Level. Empty strings,
// the special value "silent", and unrecognized input all fall back to
// zapcore.WarnLevel.
func parseGormLevel(level string) zapcore.Level {
	var lvl zapcore.Level
	if level != "" && level != "silent" && lvl.UnmarshalText([]byte(level)) == nil {
		return lvl
	}
	return zapcore.WarnLevel
}
// argsToFields converts alternating key-value pairs into zap fields. Pairs
// whose key is not a string are skipped, as is a trailing key with no value.
func argsToFields(args []any) []zap.Field {
	fields := make([]zap.Field, 0, len(args)/2)
	for i := 1; i < len(args); i += 2 {
		if key, ok := args[i-1].(string); ok {
			fields = append(fields, zap.Any(key, args[i]))
		}
	}
	return fields
}

View File

@@ -0,0 +1,509 @@
package logger
import (
"context"
"errors"
"testing"
"time"
"go.uber.org/zap"
"go.uber.org/zap/zapcore"
"go.uber.org/zap/zaptest/observer"
"gorm.io/gorm"
gormlogger "gorm.io/gorm/logger"
)
// newObservedLogger creates a zap logger backed by an in-memory observer so
// tests can assert on emitted entries.
func newObservedLogger() (*zap.Logger, *observer.ObservedLogs) {
	core, logs := observer.New(zapcore.DebugLevel)
	return zap.New(core), logs
}
// TestNewGormLogger_DefaultLevel: an empty level string defaults to warn.
func TestNewGormLogger_DefaultLevel(t *testing.T) {
log, _ := newObservedLogger()
gl := NewGormLogger(log, "")
g, ok := gl.(*GormLogger)
if !ok {
t.Fatal("expected *GormLogger")
}
if g.level != zapcore.WarnLevel {
t.Fatalf("expected warn level, got %v", g.level)
}
if g.silent {
t.Fatal("expected silent=false")
}
}
// TestNewGormLogger_Silent: the "silent" level string sets the silent flag.
func TestNewGormLogger_Silent(t *testing.T) {
log, _ := newObservedLogger()
gl := NewGormLogger(log, "silent")
g, ok := gl.(*GormLogger)
if !ok {
t.Fatal("expected *GormLogger")
}
if !g.silent {
t.Fatal("expected silent=true")
}
}
// TestNewGormLogger_ExplicitLevel: a valid level string is parsed as-is.
func TestNewGormLogger_ExplicitLevel(t *testing.T) {
log, _ := newObservedLogger()
gl := NewGormLogger(log, "error")
g, ok := gl.(*GormLogger)
if !ok {
t.Fatal("expected *GormLogger")
}
if g.level != zapcore.ErrorLevel {
t.Fatalf("expected error level, got %v", g.level)
}
}
// TestNewGormLogger_InvalidLevel: unparseable level strings fall back to warn.
func TestNewGormLogger_InvalidLevel(t *testing.T) {
log, _ := newObservedLogger()
gl := NewGormLogger(log, "bogus")
g, ok := gl.(*GormLogger)
if !ok {
t.Fatal("expected *GormLogger")
}
// Should default to warn
if g.level != zapcore.WarnLevel {
t.Fatalf("expected warn level fallback, got %v", g.level)
}
}
// TestGormLogger_Info: Info emits one info-level entry with the given message.
func TestGormLogger_Info(t *testing.T) {
log, recorded := newObservedLogger()
gl := NewGormLogger(log, "info")
gl.Info(context.Background(), "test info", "key", "value")
entries := recorded.All()
if len(entries) != 1 {
t.Fatalf("expected 1 log entry, got %d", len(entries))
}
if entries[0].Message != "test info" {
t.Fatalf("expected message 'test info', got %q", entries[0].Message)
}
if entries[0].Level != zapcore.InfoLevel {
t.Fatalf("expected InfoLevel, got %v", entries[0].Level)
}
}
// TestGormLogger_Warn: Warn emits one warn-level entry.
func TestGormLogger_Warn(t *testing.T) {
log, recorded := newObservedLogger()
gl := NewGormLogger(log, "warn")
gl.Warn(context.Background(), "test warn", "code", 42)
entries := recorded.All()
if len(entries) != 1 {
t.Fatalf("expected 1 log entry, got %d", len(entries))
}
if entries[0].Message != "test warn" {
t.Fatalf("expected message 'test warn', got %q", entries[0].Message)
}
if entries[0].Level != zapcore.WarnLevel {
t.Fatalf("expected WarnLevel, got %v", entries[0].Level)
}
}
// TestGormLogger_Error: Error emits one error-level entry.
func TestGormLogger_Error(t *testing.T) {
log, recorded := newObservedLogger()
gl := NewGormLogger(log, "error")
gl.Error(context.Background(), "test error", "module", "gorm")
entries := recorded.All()
if len(entries) != 1 {
t.Fatalf("expected 1 log entry, got %d", len(entries))
}
if entries[0].Message != "test error" {
t.Fatalf("expected message 'test error', got %q", entries[0].Message)
}
if entries[0].Level != zapcore.ErrorLevel {
t.Fatalf("expected ErrorLevel, got %v", entries[0].Level)
}
}
// TestGormLogger_LevelFiltering: entries below the configured level are dropped.
func TestGormLogger_LevelFiltering(t *testing.T) {
log, recorded := newObservedLogger()
gl := NewGormLogger(log, "warn")
// Info should be suppressed at warn level
gl.Info(context.Background(), "should be suppressed")
if len(recorded.All()) != 0 {
t.Fatal("info should be suppressed at warn level")
}
// Warn should pass
gl.Warn(context.Background(), "should pass")
if len(recorded.All()) != 1 {
t.Fatal("warn should pass at warn level")
}
}
// TestGormLogger_SilentSuppressesAll: silent mode drops every level.
func TestGormLogger_SilentSuppressesAll(t *testing.T) {
log, recorded := newObservedLogger()
gl := NewGormLogger(log, "silent")
gl.Info(context.Background(), "info msg")
gl.Warn(context.Background(), "warn msg")
gl.Error(context.Background(), "error msg")
if len(recorded.All()) != 0 {
t.Fatalf("silent mode should suppress all logs, got %d", len(recorded.All()))
}
}
// TestGormLogger_LogMode_ReturnsNewInstance: LogMode must copy, not mutate.
func TestGormLogger_LogMode_ReturnsNewInstance(t *testing.T) {
log, _ := newObservedLogger()
original := NewGormLogger(log, "warn").(*GormLogger)
modified := original.LogMode(gormlogger.Info)
// Must be a different instance
if modified == original {
t.Fatal("LogMode should return a new instance")
}
// Original should be unchanged
if original.level != zapcore.WarnLevel {
t.Fatalf("original level should remain warn, got %v", original.level)
}
// New instance should have InfoLevel
mod := modified.(*GormLogger)
if mod.level != zapcore.InfoLevel {
t.Fatalf("modified level should be info, got %v", mod.level)
}
}
// TestGormLogger_LogMode_Silent: LogMode(Silent) silences only the copy.
func TestGormLogger_LogMode_Silent(t *testing.T) {
log, _ := newObservedLogger()
gl := NewGormLogger(log, "info").(*GormLogger)
silent := gl.LogMode(gormlogger.Silent).(*GormLogger)
if !silent.silent {
t.Fatal("LogMode(Silent) should set silent=true")
}
// Original should not be silent
if gl.silent {
t.Fatal("original should not be affected")
}
}
// TestGormLogger_LogMode_ErrorLevel: gorm Error maps to zap ErrorLevel.
func TestGormLogger_LogMode_ErrorLevel(t *testing.T) {
log, _ := newObservedLogger()
gl := NewGormLogger(log, "info").(*GormLogger)
modified := gl.LogMode(gormlogger.Error).(*GormLogger)
if modified.level != zapcore.ErrorLevel {
t.Fatalf("expected ErrorLevel, got %v", modified.level)
}
}
// TestGormLogger_LogMode_WarnLevel: gorm Warn maps to zap WarnLevel.
func TestGormLogger_LogMode_WarnLevel(t *testing.T) {
log, _ := newObservedLogger()
gl := NewGormLogger(log, "info").(*GormLogger)
modified := gl.LogMode(gormlogger.Warn).(*GormLogger)
if modified.level != zapcore.WarnLevel {
t.Fatalf("expected WarnLevel, got %v", modified.level)
}
}
// TestGormLogger_Trace_WithError: real query errors log at error level.
func TestGormLogger_Trace_WithError(t *testing.T) {
log, recorded := newObservedLogger()
gl := NewGormLogger(log, "warn")
begin := time.Now()
fc := func() (string, int64) { return "SELECT * FROM users", 0 }
err := errors.New("connection refused")
gl.Trace(context.Background(), begin, fc, err)
entries := recorded.All()
if len(entries) != 1 {
t.Fatalf("expected 1 entry, got %d", len(entries))
}
if entries[0].Level != zapcore.ErrorLevel {
t.Fatalf("expected ErrorLevel for real errors, got %v", entries[0].Level)
}
if entries[0].Message != "gorm query error" {
t.Fatalf("unexpected message: %q", entries[0].Message)
}
}
// TestGormLogger_Trace_ErrRecordNotFound: not-found is not treated as an error.
func TestGormLogger_Trace_ErrRecordNotFound(t *testing.T) {
log, recorded := newObservedLogger()
gl := NewGormLogger(log, "info")
begin := time.Now()
fc := func() (string, int64) { return "SELECT * FROM users WHERE id = ?", 0 }
gl.Trace(context.Background(), begin, fc, gorm.ErrRecordNotFound)
// ErrRecordNotFound should NOT be logged as error
// At default level with no slow query, it should log as info
entries := recorded.All()
if len(entries) != 1 {
t.Fatalf("expected 1 entry, got %d", len(entries))
}
if entries[0].Level == zapcore.ErrorLevel {
t.Fatal("ErrRecordNotFound should NOT be logged at ErrorLevel")
}
if entries[0].Message != "gorm query" {
t.Fatalf("expected 'gorm query', got %q", entries[0].Message)
}
}
// TestGormLogger_Trace_ErrRecordNotFound_SuppressedAtWarnLevel: not-found
// falls through to the info path, which the warn level suppresses.
func TestGormLogger_Trace_ErrRecordNotFound_SuppressedAtWarnLevel(t *testing.T) {
log, recorded := newObservedLogger()
gl := NewGormLogger(log, "warn")
begin := time.Now()
fc := func() (string, int64) { return "SELECT * FROM users WHERE id = ?", 0 }
gl.Trace(context.Background(), begin, fc, gorm.ErrRecordNotFound)
// ErrRecordNotFound is not a real error, so it falls through to the default path.
// At warn level, the default path (info-level) is suppressed.
entries := recorded.All()
if len(entries) != 0 {
t.Fatalf("expected 0 entries at warn level, got %d", len(entries))
}
}
// TestGormLogger_Trace_SlowQuery: queries over the threshold log at warn.
func TestGormLogger_Trace_SlowQuery(t *testing.T) {
log, recorded := newObservedLogger()
gl := NewGormLogger(log, "warn")
// Simulate a begin time far enough in the past to exceed the threshold
begin := time.Now().Add(-500 * time.Millisecond)
fc := func() (string, int64) { return "SELECT SLEEP(1)", 1 }
gl.Trace(context.Background(), begin, fc, nil)
entries := recorded.All()
if len(entries) != 1 {
t.Fatalf("expected 1 entry, got %d", len(entries))
}
if entries[0].Level != zapcore.WarnLevel {
t.Fatalf("expected WarnLevel for slow query, got %v", entries[0].Level)
}
if entries[0].Message != "gorm slow query" {
t.Fatalf("expected 'gorm slow query', got %q", entries[0].Message)
}
}
// TestGormLogger_Trace_NormalQuery: fast, successful queries log at info.
func TestGormLogger_Trace_NormalQuery(t *testing.T) {
log, recorded := newObservedLogger()
gl := NewGormLogger(log, "info")
begin := time.Now()
fc := func() (string, int64) { return "SELECT 1", 1 }
gl.Trace(context.Background(), begin, fc, nil)
entries := recorded.All()
if len(entries) != 1 {
t.Fatalf("expected 1 entry, got %d", len(entries))
}
if entries[0].Level != zapcore.InfoLevel {
t.Fatalf("expected InfoLevel for normal query, got %v", entries[0].Level)
}
if entries[0].Message != "gorm query" {
t.Fatalf("expected 'gorm query', got %q", entries[0].Message)
}
}
// TestGormLogger_Trace_NormalQuerySuppressedAtWarnLevel: the info-level
// default path is dropped when the configured level is warn.
func TestGormLogger_Trace_NormalQuerySuppressedAtWarnLevel(t *testing.T) {
log, recorded := newObservedLogger()
gl := NewGormLogger(log, "warn")
begin := time.Now()
fc := func() (string, int64) { return "SELECT 1", 1 }
gl.Trace(context.Background(), begin, fc, nil)
// Normal queries are info-level, suppressed at warn
if len(recorded.All()) != 0 {
t.Fatal("normal query should be suppressed at warn level")
}
}
// TestGormLogger_Trace_SilentSuppressesAll: silent even drops Trace errors.
func TestGormLogger_Trace_SilentSuppressesAll(t *testing.T) {
log, recorded := newObservedLogger()
gl := NewGormLogger(log, "silent")
begin := time.Now().Add(-500 * time.Millisecond)
fc := func() (string, int64) { return "SELECT SLEEP(1)", 1 }
err := errors.New("some error")
gl.Trace(context.Background(), begin, fc, err)
if len(recorded.All()) != 0 {
t.Fatal("silent mode should suppress even Trace errors")
}
}
// TestGormLogger_Trace_StructuredFields: normal queries carry sql/rows/elapsed.
func TestGormLogger_Trace_StructuredFields(t *testing.T) {
log, recorded := newObservedLogger()
gl := NewGormLogger(log, "info")
begin := time.Now()
fc := func() (string, int64) { return "SELECT * FROM users", 42 }
gl.Trace(context.Background(), begin, fc, nil)
entries := recorded.All()
if len(entries) != 1 {
t.Fatalf("expected 1 entry, got %d", len(entries))
}
// Verify structured fields
fields := entries[0].ContextMap()
if fields["sql"] != "SELECT * FROM users" {
t.Fatalf("expected sql field, got %v", fields["sql"])
}
if fields["rows"] != int64(42) {
t.Fatalf("expected rows=42, got %v", fields["rows"])
}
if _, ok := fields["elapsed"]; !ok {
t.Fatal("expected elapsed field")
}
}
// TestGormLogger_Trace_ErrorFields: error entries carry sql and error fields.
func TestGormLogger_Trace_ErrorFields(t *testing.T) {
log, recorded := newObservedLogger()
gl := NewGormLogger(log, "warn")
begin := time.Now()
fc := func() (string, int64) { return "INSERT INTO users", 0 }
err := errors.New("duplicate key")
gl.Trace(context.Background(), begin, fc, err)
entries := recorded.All()
if len(entries) != 1 {
t.Fatalf("expected 1 entry, got %d", len(entries))
}
fields := entries[0].ContextMap()
if fields["sql"] != "INSERT INTO users" {
t.Fatalf("expected sql field, got %v", fields["sql"])
}
if _, ok := fields["error"]; !ok {
t.Fatal("expected error field")
}
}
// TestGormLogger_Trace_SlowQueryFields: slow entries use elapsed_ms, not elapsed.
func TestGormLogger_Trace_SlowQueryFields(t *testing.T) {
log, recorded := newObservedLogger()
gl := NewGormLogger(log, "warn")
begin := time.Now().Add(-500 * time.Millisecond)
fc := func() (string, int64) { return "SELECT * FROM large_table", 1000 }
gl.Trace(context.Background(), begin, fc, nil)
entries := recorded.All()
if len(entries) != 1 {
t.Fatalf("expected 1 entry, got %d", len(entries))
}
fields := entries[0].ContextMap()
if fields["sql"] != "SELECT * FROM large_table" {
t.Fatalf("expected sql field, got %v", fields["sql"])
}
if fields["rows"] != int64(1000) {
t.Fatalf("expected rows=1000, got %v", fields["rows"])
}
if _, ok := fields["elapsed_ms"]; !ok {
t.Fatal("expected elapsed_ms field for slow query")
}
}
// TestArgsToFields: table-driven check of the pair-count behavior,
// including odd argument counts and non-string keys.
func TestArgsToFields(t *testing.T) {
tests := []struct {
name string
args []interface{}
expected int
}{
{"empty", nil, 0},
{"single_pair", []interface{}{"key", "value"}, 1},
{"two_pairs", []interface{}{"a", 1, "b", 2}, 2},
{"odd_args_ignores_last", []interface{}{"key", "value", "orphan"}, 1},
{"non_string_key_ignored", []interface{}{123, "value"}, 0},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
fields := argsToFields(tt.args)
if len(fields) != tt.expected {
t.Fatalf("expected %d fields, got %d", tt.expected, len(fields))
}
})
}
}
// TestArgsToFields_FieldValues: keys are preserved in order.
func TestArgsToFields_FieldValues(t *testing.T) {
fields := argsToFields([]interface{}{"name", "test", "count", 42})
if len(fields) != 2 {
t.Fatalf("expected 2 fields, got %d", len(fields))
}
if fields[0].Key != "name" {
t.Fatalf("expected key 'name', got %q", fields[0].Key)
}
if fields[1].Key != "count" {
t.Fatalf("expected key 'count', got %q", fields[1].Key)
}
}
// TestParseGormLevel: valid names parse; "", "silent", and junk default to warn.
func TestParseGormLevel(t *testing.T) {
tests := []struct {
input string
expected zapcore.Level
}{
{"debug", zapcore.DebugLevel},
{"info", zapcore.InfoLevel},
{"warn", zapcore.WarnLevel},
{"error", zapcore.ErrorLevel},
{"", zapcore.WarnLevel},
{"silent", zapcore.WarnLevel},
{"invalid", zapcore.WarnLevel},
}
for _, tt := range tests {
t.Run(tt.input, func(t *testing.T) {
got := parseGormLevel(tt.input)
if got != tt.expected {
t.Fatalf("expected %v, got %v", tt.expected, got)
}
})
}
}
// TestGormLogger_InfoWithMultipleFields: all key-value pairs become fields.
func TestGormLogger_InfoWithMultipleFields(t *testing.T) {
log, recorded := newObservedLogger()
gl := NewGormLogger(log, "info")
gl.Info(context.Background(), "multi fields", "key1", "val1", "key2", 123, "key3", true)
entries := recorded.All()
if len(entries) != 1 {
t.Fatalf("expected 1 entry, got %d", len(entries))
}
fields := entries[0].ContextMap()
if fields["key1"] != "val1" {
t.Fatalf("expected key1=val1, got %v", fields["key1"])
}
if fields["key2"] != int64(123) {
t.Fatalf("expected key2=123, got %v", fields["key2"])
}
if fields["key3"] != true {
t.Fatalf("expected key3=true, got %v", fields["key3"])
}
}

93
internal/logger/logger.go Normal file
View File

@@ -0,0 +1,93 @@
package logger
import (
"fmt"
"os"
"gcy_hpc_server/internal/config"
"go.uber.org/zap"
"go.uber.org/zap/zapcore"
"gopkg.in/natefinch/lumberjack.v2"
)
// NewLogger builds a *zap.Logger from the given configuration.
//
// Defaults when fields are zero-valued: level "info", encoding "json",
// stdout output enabled, and (for file output) MaxSize 100, MaxBackups 5,
// MaxAge 30. An invalid level string returns an error; an unknown encoding
// falls back to JSON. Caller info is attached to every entry and stack
// traces are attached at error level and above.
func NewLogger(cfg config.LogConfig) (*zap.Logger, error) {
	level := applyDefault(cfg.Level, "info")
	var zapLevel zapcore.Level
	if err := zapLevel.UnmarshalText([]byte(level)); err != nil {
		return nil, fmt.Errorf("invalid log level %q: %w", level, err)
	}
	core := zapcore.NewCore(newEncoder(cfg), newWriteSyncer(cfg), zapLevel)
	opts := []zap.Option{
		zap.AddCaller(),
		zap.AddStacktrace(zapcore.ErrorLevel),
	}
	return zap.New(core, opts...), nil
}
// newEncoder builds the entry encoder: console when configured, JSON for
// the default and any unrecognized encoding.
func newEncoder(cfg config.LogConfig) zapcore.Encoder {
	encoderConfig := zap.NewProductionEncoderConfig()
	encoderConfig.TimeKey = "ts"
	encoderConfig.EncodeTime = zapcore.ISO8601TimeEncoder
	encoderConfig.EncodeLevel = zapcore.CapitalLevelEncoder
	if applyDefault(cfg.Encoding, "json") == "console" {
		return zapcore.NewConsoleEncoder(encoderConfig)
	}
	return zapcore.NewJSONEncoder(encoderConfig)
}
// newWriteSyncer assembles output destinations: stdout (enabled unless
// OutputStdout is explicitly false) and a lumberjack-rotated file when
// FilePath is set. If both are disabled it falls back to stdout so logs are
// never silently dropped.
func newWriteSyncer(cfg config.LogConfig) zapcore.WriteSyncer {
	var syncers []zapcore.WriteSyncer
	stdout := true
	if cfg.OutputStdout != nil {
		stdout = *cfg.OutputStdout
	}
	if stdout {
		syncers = append(syncers, zapcore.AddSync(os.Stdout))
	}
	if cfg.FilePath != "" {
		// Compression is on when explicitly requested, or by default when no
		// rotation settings were provided at all. The parentheses make the
		// original && / || precedence explicit.
		compress := cfg.Compress || (cfg.MaxSize == 0 && cfg.MaxBackups == 0 && cfg.MaxAge == 0)
		syncers = append(syncers, zapcore.AddSync(&lumberjack.Logger{
			Filename:   cfg.FilePath,
			MaxSize:    applyDefaultInt(cfg.MaxSize, 100),
			MaxBackups: applyDefaultInt(cfg.MaxBackups, 5),
			MaxAge:     applyDefaultInt(cfg.MaxAge, 30),
			Compress:   compress,
		}))
	}
	switch len(syncers) {
	case 0:
		return zapcore.AddSync(os.Stdout)
	case 1:
		return syncers[0]
	default:
		return zapcore.NewMultiWriteSyncer(syncers...)
	}
}
// applyDefault returns val, or def when val is the empty string.
func applyDefault(val, def string) string {
	if val != "" {
		return val
	}
	return def
}
// applyDefaultInt returns val, or def when val is zero.
func applyDefaultInt(val, def int) int {
	if val != 0 {
		return val
	}
	return def
}

View File

@@ -0,0 +1,286 @@
package logger
import (
"os"
"path/filepath"
"strings"
"testing"
"gcy_hpc_server/internal/config"
"go.uber.org/zap"
"go.uber.org/zap/zapcore"
"go.uber.org/zap/zaptest"
)
// ptrBool returns a pointer to the given bool, for optional config fields.
func ptrBool(v bool) *bool {
	b := v
	return &b
}
// TestNewLogger_JSONConfig creates a logger with JSON encoding and verifies
// that log entries are emitted successfully.
func TestNewLogger_JSONConfig(t *testing.T) {
cfg := config.LogConfig{
Level: "debug",
Encoding: "json",
OutputStdout: ptrBool(true),
}
log, err := NewLogger(cfg)
if err != nil {
t.Fatalf("NewLogger returned error: %v", err)
}
defer log.Sync()
// Should not panic when logging
log.Info("json logger test", zap.String("key", "value"))
}
// TestNewLogger_ConsoleConfig creates a logger with console encoding.
func TestNewLogger_ConsoleConfig(t *testing.T) {
cfg := config.LogConfig{
Level: "info",
Encoding: "console",
OutputStdout: ptrBool(true),
}
log, err := NewLogger(cfg)
if err != nil {
t.Fatalf("NewLogger returned error: %v", err)
}
defer log.Sync()
log.Info("console logger test", zap.Int("num", 42))
}
// TestNewLogger_InvalidLevel verifies that an invalid log level returns an error.
func TestNewLogger_InvalidLevel(t *testing.T) {
cfg := config.LogConfig{
Level: "bogus",
Encoding: "json",
}
_, err := NewLogger(cfg)
if err == nil {
t.Fatal("expected error for invalid log level, got nil")
}
}
// TestNewLogger_EmptyConfig verifies defaults are applied when config is zero-value.
func TestNewLogger_EmptyConfig(t *testing.T) {
cfg := config.LogConfig{} // all zero values
log, err := NewLogger(cfg)
if err != nil {
t.Fatalf("NewLogger returned error: %v", err)
}
defer log.Sync()
log.Info("default config test")
}
// TestNewLogger_FileOutput verifies that file output with rotation config works.
func TestNewLogger_FileOutput(t *testing.T) {
tmpDir := t.TempDir()
logFile := filepath.Join(tmpDir, "test.log")
cfg := config.LogConfig{
Level: "info",
Encoding: "json",
FilePath: logFile,
MaxSize: 10,
MaxBackups: 3,
MaxAge: 7,
Compress: true,
}
log, err := NewLogger(cfg)
if err != nil {
t.Fatalf("NewLogger returned error: %v", err)
}
log.Info("file output test", zap.String("msg", "hello"))
log.Sync()
data, err := os.ReadFile(logFile)
if err != nil {
t.Fatalf("failed to read log file: %v", err)
}
if len(data) == 0 {
t.Fatal("log file is empty, expected output")
}
if !strings.Contains(string(data), "file output test") {
t.Fatalf("log file content does not contain expected message;\ngot: %s", string(data))
}
}
// TestNewLogger_MultiWriter verifies that both stdout and file output work together.
func TestNewLogger_MultiWriter(t *testing.T) {
tmpDir := t.TempDir()
logFile := filepath.Join(tmpDir, "multi.log")
cfg := config.LogConfig{
Level: "info",
Encoding: "json",
OutputStdout: ptrBool(true),
FilePath: logFile,
}
log, err := NewLogger(cfg)
if err != nil {
t.Fatalf("NewLogger returned error: %v", err)
}
log.Info("multi writer test", zap.String("writer", "both"))
log.Sync()
data, err := os.ReadFile(logFile)
if err != nil {
t.Fatalf("failed to read log file: %v", err)
}
if !strings.Contains(string(data), "multi writer test") {
t.Fatalf("log file content does not contain expected message;\ngot: %s", string(data))
}
}
// TestNewLogger_Observer verifies actual log output content using zaptest.
func TestNewLogger_Observer(t *testing.T) {
// Use zaptest.NewLogger to capture logs in test output
log := zaptest.NewLogger(t,
zaptest.WrapOptions(zap.AddCaller(), zap.AddStacktrace(zapcore.ErrorLevel)),
zaptest.Level(zapcore.DebugLevel),
)
// These should all succeed without panicking
log.Debug("debug msg", zap.String("k", "v"))
log.Info("info msg", zap.Int("n", 1))
log.Warn("warn msg")
log.Error("error msg")
}
// TestNewLogger_AllLevels verifies all valid log levels parse correctly.
func TestNewLogger_AllLevels(t *testing.T) {
levels := []string{"debug", "info", "warn", "error", "dpanic", "panic", "fatal"}
for _, level := range levels {
t.Run(level, func(t *testing.T) {
cfg := config.LogConfig{
Level: level,
Encoding: "json",
OutputStdout: ptrBool(true),
}
log, err := NewLogger(cfg)
if err != nil {
t.Fatalf("level %q: NewLogger returned error: %v", level, err)
}
log.Sync()
})
}
}
// TestNewLogger_InvalidEncoding falls back gracefully — the factory should
// treat an unrecognized encoding as an error or default to JSON.
func TestNewLogger_InvalidEncoding(t *testing.T) {
cfg := config.LogConfig{
Level: "info",
Encoding: "xml",
OutputStdout: ptrBool(true),
}
// The implementation should default to JSON for unknown encoding.
log, err := NewLogger(cfg)
if err != nil {
t.Fatalf("unexpected error for invalid encoding: %v", err)
}
defer log.Sync()
log.Info("invalid encoding test")
}
// TestNewLogger_DefaultRotation verifies rotation defaults are applied.
func TestNewLogger_DefaultRotation(t *testing.T) {
tmpDir := t.TempDir()
logFile := filepath.Join(tmpDir, "rotation.log")
cfg := config.LogConfig{
Level: "info",
Encoding: "json",
FilePath: logFile,
// MaxSize, MaxBackups, MaxAge, Compress all zero → defaults apply
}
log, err := NewLogger(cfg)
if err != nil {
t.Fatalf("NewLogger returned error: %v", err)
}
log.Info("rotation defaults test")
log.Sync()
data, err := os.ReadFile(logFile)
if err != nil {
t.Fatalf("failed to read log file: %v", err)
}
if len(data) == 0 {
t.Fatal("log file is empty")
}
}
// TestNewLogger_OutputStdoutNil: a nil OutputStdout defaults to stdout on.
func TestNewLogger_OutputStdoutNil(t *testing.T) {
cfg := config.LogConfig{
Level: "info",
Encoding: "json",
}
log, err := NewLogger(cfg)
if err != nil {
t.Fatalf("NewLogger returned error: %v", err)
}
defer log.Sync()
log.Info("default stdout test")
}
// TestNewLogger_OutputStdoutFalseWithFile: stdout can be disabled when a
// file destination is configured; output still reaches the file.
func TestNewLogger_OutputStdoutFalseWithFile(t *testing.T) {
tmpDir := t.TempDir()
logFile := filepath.Join(tmpDir, "nostdout.log")
cfg := config.LogConfig{
Level: "info",
Encoding: "json",
OutputStdout: ptrBool(false),
FilePath: logFile,
}
log, err := NewLogger(cfg)
if err != nil {
t.Fatalf("NewLogger returned error: %v", err)
}
log.Info("file only test")
log.Sync()
data, err := os.ReadFile(logFile)
if err != nil {
t.Fatalf("failed to read log file: %v", err)
}
if !strings.Contains(string(data), "file only test") {
t.Fatalf("log file content does not contain expected message;\ngot: %s", string(data))
}
}
// TestNewLogger_OutputStdoutFalseFallback: disabling stdout with no file
// configured must still yield a working logger (stdout fallback).
func TestNewLogger_OutputStdoutFalseFallback(t *testing.T) {
cfg := config.LogConfig{
Level: "info",
Encoding: "json",
OutputStdout: ptrBool(false),
}
log, err := NewLogger(cfg)
if err != nil {
t.Fatalf("NewLogger returned error: %v", err)
}
defer log.Sync()
log.Info("fallback stdout test")
}

View File

@@ -0,0 +1,25 @@
package middleware
import (
"time"
"github.com/gin-gonic/gin"
"go.uber.org/zap"
)
// RequestLogger returns a Gin middleware that logs each request using zap.
// After the handler chain finishes it emits one info-level "request" entry
// with the method, path, response status, latency, and client IP.
func RequestLogger(logger *zap.Logger) gin.HandlerFunc {
	return func(c *gin.Context) {
		begin := time.Now()
		c.Next()
		fields := []zap.Field{
			zap.String("method", c.Request.Method),
			zap.String("path", c.Request.URL.Path),
			zap.Int("status", c.Writer.Status()),
			zap.Duration("latency", time.Since(begin)),
			zap.String("client_ip", c.ClientIP()),
		}
		logger.Info("request", fields...)
	}
}

23
internal/model/cluster.go Normal file
View File

@@ -0,0 +1,23 @@
package model
// NodeResponse is the simplified API response for a node.
type NodeResponse struct {
Name string `json:"name"` // node name
State []string `json:"state"` // node state flags
CPUs int32 `json:"cpus"`
RealMemory int64 `json:"real_memory"`
AllocMem int64 `json:"alloc_memory,omitempty"`
Arch string `json:"architecture,omitempty"`
OS string `json:"operating_system,omitempty"`
}
// PartitionResponse is the simplified API response for a partition.
type PartitionResponse struct {
Name string `json:"name"`
State []string `json:"state"`
Nodes string `json:"nodes,omitempty"` // node list expression
TotalCPUs int32 `json:"total_cpus,omitempty"`
TotalNodes int32 `json:"total_nodes,omitempty"`
MaxTime string `json:"max_time,omitempty"`
Default bool `json:"default,omitempty"` // true when this is the default partition
}

47
internal/model/job.go Normal file
View File

@@ -0,0 +1,47 @@
package model
// SubmitJobRequest is the API request for submitting a job.
// Script is the only required field; everything else is optional.
type SubmitJobRequest struct {
Script string `json:"script" binding:"required"`
Partition string `json:"partition,omitempty"`
QOS string `json:"qos,omitempty"`
CPUs int32 `json:"cpus,omitempty"`
Memory string `json:"memory,omitempty"`
TimeLimit string `json:"time_limit,omitempty"`
JobName string `json:"job_name,omitempty"`
Environment map[string]string `json:"environment,omitempty"`
}
// JobResponse is the simplified API response for a job.
// Pointer fields are omitted from JSON when unset.
type JobResponse struct {
JobID int32 `json:"job_id"`
Name string `json:"name"`
State []string `json:"job_state"`
Partition string `json:"partition"`
SubmitTime *int64 `json:"submit_time,omitempty"`
StartTime *int64 `json:"start_time,omitempty"`
EndTime *int64 `json:"end_time,omitempty"`
ExitCode *int32 `json:"exit_code,omitempty"`
Nodes string `json:"nodes,omitempty"`
}
// JobListResponse is the paginated response for job listings.
type JobListResponse struct {
Jobs []JobResponse `json:"jobs"`
Total int `json:"total"`
Page int `json:"page"`
PageSize int `json:"page_size"`
}
// JobHistoryQuery contains query parameters for job history.
// Page and PageSize carry gin form defaults of 1 and 20 respectively.
type JobHistoryQuery struct {
Users string `form:"users" json:"users,omitempty"`
StartTime string `form:"start_time" json:"start_time,omitempty"`
EndTime string `form:"end_time" json:"end_time,omitempty"`
Account string `form:"account" json:"account,omitempty"`
Partition string `form:"partition" json:"partition,omitempty"`
State string `form:"state" json:"state,omitempty"`
JobName string `form:"job_name" json:"job_name,omitempty"`
Page int `form:"page,default=1" json:"page,omitempty"`
PageSize int `form:"page_size,default=20" json:"page_size,omitempty"`
}

View File

@@ -0,0 +1,45 @@
package model
import "time"
// JobTemplate represents a saved job template.
// Name is unique across templates; CreatedAt/UpdatedAt are maintained by GORM.
type JobTemplate struct {
ID int64 `json:"id" gorm:"primaryKey;autoIncrement"`
Name string `json:"name" gorm:"uniqueIndex;size:255;not null"`
Description string `json:"description,omitempty" gorm:"type:text"`
Script string `json:"script" gorm:"type:text;not null"`
Partition string `json:"partition,omitempty" gorm:"size:255"`
QOS string `json:"qos,omitempty" gorm:"column:qos;size:255"`
CPUs int `json:"cpus,omitempty" gorm:"column:cpus"`
Memory string `json:"memory,omitempty" gorm:"size:50"`
TimeLimit string `json:"time_limit,omitempty" gorm:"column:time_limit;size:50"`
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
}
// TableName specifies the database table name for GORM.
func (JobTemplate) TableName() string { return "job_templates" }
// CreateTemplateRequest is the API request for creating a template.
// Name and Script are required; the rest are optional.
type CreateTemplateRequest struct {
Name string `json:"name" binding:"required"`
Description string `json:"description,omitempty"`
Script string `json:"script" binding:"required"`
Partition string `json:"partition,omitempty"`
QOS string `json:"qos,omitempty"`
CPUs int `json:"cpus,omitempty"`
Memory string `json:"memory,omitempty"`
TimeLimit string `json:"time_limit,omitempty"`
}
// UpdateTemplateRequest is the API request for updating a template.
// All fields are optional; empty fields leave the stored value unchanged.
type UpdateTemplateRequest struct {
Name string `json:"name,omitempty"`
Description string `json:"description,omitempty"`
Script string `json:"script,omitempty"`
Partition string `json:"partition,omitempty"`
QOS string `json:"qos,omitempty"`
CPUs int `json:"cpus,omitempty"`
Memory string `json:"memory,omitempty"`
TimeLimit string `json:"time_limit,omitempty"`
}

View File

@@ -0,0 +1,44 @@
package server
import (
"net/http"
"github.com/gin-gonic/gin"
)
// APIResponse represents the unified JSON structure for all API responses.
// Exactly one of Data (on success) or Error (on failure) is populated.
type APIResponse struct {
Success bool `json:"success"`
Data interface{} `json:"data,omitempty"`
Error string `json:"error,omitempty"`
}
// OK responds with 200 and success data.
func OK(c *gin.Context, data interface{}) {
c.JSON(http.StatusOK, APIResponse{Success: true, Data: data})
}
// Created responds with 201 and success data.
func Created(c *gin.Context, data interface{}) {
c.JSON(http.StatusCreated, APIResponse{Success: true, Data: data})
}
// BadRequest responds with 400 and an error message.
func BadRequest(c *gin.Context, msg string) {
	ErrorWithStatus(c, http.StatusBadRequest, msg)
}
// NotFound responds with 404 and an error message.
func NotFound(c *gin.Context, msg string) {
	ErrorWithStatus(c, http.StatusNotFound, msg)
}
// InternalError responds with 500 and an error message.
func InternalError(c *gin.Context, msg string) {
	ErrorWithStatus(c, http.StatusInternalServerError, msg)
}
// ErrorWithStatus responds with a custom status code and an error message.
// It is the generic failure responder; BadRequest/NotFound/InternalError are
// convenience wrappers for common status codes.
func ErrorWithStatus(c *gin.Context, code int, msg string) {
c.JSON(code, APIResponse{Success: false, Error: msg})
}

View File

@@ -0,0 +1,116 @@
package server
import (
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"github.com/gin-gonic/gin"
)
// setupTestContext builds a Gin test context backed by a response recorder.
func setupTestContext() (*httptest.ResponseRecorder, *gin.Context) {
	gin.SetMode(gin.TestMode)
	w := httptest.NewRecorder()
	c, _ := gin.CreateTestContext(w)
	return w, c
}

// parseResponse decodes the recorded body into an APIResponse, failing the
// test on malformed JSON.
func parseResponse(t *testing.T, w *httptest.ResponseRecorder) APIResponse {
	t.Helper()
	var resp APIResponse
	if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
		t.Fatalf("failed to parse response body: %v", err)
	}
	return resp
}

// TestOK verifies OK writes HTTP 200 with success=true.
func TestOK(t *testing.T) {
	w, c := setupTestContext()
	OK(c, map[string]string{"msg": "hello"})
	if w.Code != http.StatusOK {
		t.Fatalf("expected 200, got %d", w.Code)
	}
	resp := parseResponse(t, w)
	if !resp.Success {
		t.Fatal("expected success=true")
	}
}

// TestCreated verifies Created writes HTTP 201 with success=true.
func TestCreated(t *testing.T) {
	w, c := setupTestContext()
	Created(c, map[string]int{"id": 1})
	if w.Code != http.StatusCreated {
		t.Fatalf("expected 201, got %d", w.Code)
	}
	resp := parseResponse(t, w)
	if !resp.Success {
		t.Fatal("expected success=true")
	}
}

// TestBadRequest verifies BadRequest writes HTTP 400 with the given message.
func TestBadRequest(t *testing.T) {
	w, c := setupTestContext()
	BadRequest(c, "invalid input")
	if w.Code != http.StatusBadRequest {
		t.Fatalf("expected 400, got %d", w.Code)
	}
	resp := parseResponse(t, w)
	if resp.Success {
		t.Fatal("expected success=false")
	}
	if resp.Error != "invalid input" {
		t.Fatalf("expected error 'invalid input', got '%s'", resp.Error)
	}
}

// TestNotFound verifies NotFound writes HTTP 404 with the given message.
func TestNotFound(t *testing.T) {
	w, c := setupTestContext()
	NotFound(c, "resource missing")
	if w.Code != http.StatusNotFound {
		t.Fatalf("expected 404, got %d", w.Code)
	}
	resp := parseResponse(t, w)
	if resp.Success {
		t.Fatal("expected success=false")
	}
	if resp.Error != "resource missing" {
		t.Fatalf("expected error 'resource missing', got '%s'", resp.Error)
	}
}

// TestInternalError verifies InternalError writes HTTP 500 with the message.
func TestInternalError(t *testing.T) {
	w, c := setupTestContext()
	InternalError(c, "something broke")
	if w.Code != http.StatusInternalServerError {
		t.Fatalf("expected 500, got %d", w.Code)
	}
	resp := parseResponse(t, w)
	if resp.Success {
		t.Fatal("expected success=false")
	}
	if resp.Error != "something broke" {
		t.Fatalf("expected error 'something broke', got '%s'", resp.Error)
	}
}

// TestErrorWithStatus verifies ErrorWithStatus honors an arbitrary status code.
func TestErrorWithStatus(t *testing.T) {
	w, c := setupTestContext()
	ErrorWithStatus(c, http.StatusConflict, "already exists")
	if w.Code != http.StatusConflict {
		t.Fatalf("expected 409, got %d", w.Code)
	}
	resp := parseResponse(t, w)
	if resp.Success {
		t.Fatal("expected success=false")
	}
	if resp.Error != "already exists" {
		t.Fatalf("expected error 'already exists', got '%s'", resp.Error)
	}
}

111
internal/server/server.go Normal file
View File

@@ -0,0 +1,111 @@
package server
import (
"net/http"
"gcy_hpc_server/internal/middleware"
"github.com/gin-gonic/gin"
"go.uber.org/zap"
)
// JobHandler describes the HTTP handlers for job submission and lifecycle
// endpoints. Implemented by the handler layer; consumed by NewRouter.
type JobHandler interface {
	SubmitJob(c *gin.Context)
	GetJobs(c *gin.Context)
	GetJobHistory(c *gin.Context)
	GetJob(c *gin.Context)
	CancelJob(c *gin.Context)
}

// ClusterHandler describes the HTTP handlers for cluster inspection
// endpoints (nodes, partitions, diagnostics).
type ClusterHandler interface {
	GetNodes(c *gin.Context)
	GetNode(c *gin.Context)
	GetPartitions(c *gin.Context)
	GetPartition(c *gin.Context)
	GetDiag(c *gin.Context)
}

// TemplateHandler describes the HTTP handlers for job-template CRUD
// endpoints.
type TemplateHandler interface {
	ListTemplates(c *gin.Context)
	CreateTemplate(c *gin.Context)
	GetTemplate(c *gin.Context)
	UpdateTemplate(c *gin.Context)
	DeleteTemplate(c *gin.Context)
}
// NewRouter creates a Gin engine with all API v1 routes registered with real handlers.
// The logger is optional: when nil, the request-logging middleware is skipped,
// which lets tests construct a silent router.
func NewRouter(jobH JobHandler, clusterH ClusterHandler, templateH TemplateHandler, logger *zap.Logger) *gin.Engine {
	gin.SetMode(gin.ReleaseMode)
	engine := gin.New()
	engine.Use(gin.Recovery())
	if logger != nil {
		engine.Use(middleware.RequestLogger(logger))
	}
	api := engine.Group("/api/v1")

	jobGroup := api.Group("/jobs")
	jobGroup.POST("/submit", jobH.SubmitJob)
	jobGroup.GET("", jobH.GetJobs)
	jobGroup.GET("/history", jobH.GetJobHistory)
	jobGroup.GET("/:id", jobH.GetJob)
	jobGroup.DELETE("/:id", jobH.CancelJob)

	api.GET("/nodes", clusterH.GetNodes)
	api.GET("/nodes/:name", clusterH.GetNode)
	api.GET("/partitions", clusterH.GetPartitions)
	api.GET("/partitions/:name", clusterH.GetPartition)
	api.GET("/diag", clusterH.GetDiag)

	tplGroup := api.Group("/templates")
	tplGroup.GET("", templateH.ListTemplates)
	tplGroup.POST("", templateH.CreateTemplate)
	tplGroup.GET("/:id", templateH.GetTemplate)
	tplGroup.PUT("/:id", templateH.UpdateTemplate)
	tplGroup.DELETE("/:id", templateH.DeleteTemplate)
	return engine
}
// NewTestRouter creates a router for testing without real handlers.
// Every v1 endpoint responds with a 501 placeholder.
func NewTestRouter() *gin.Engine {
	gin.SetMode(gin.TestMode)
	engine := gin.New()
	engine.Use(gin.Recovery())
	registerPlaceholderRoutes(engine.Group("/api/v1"))
	return engine
}
// registerPlaceholderRoutes wires every v1 endpoint to the notImplemented
// stub, driven by a route table so the endpoint list is easy to audit.
func registerPlaceholderRoutes(v1 *gin.RouterGroup) {
	routes := []struct {
		method string
		path   string
	}{
		{http.MethodPost, "/jobs/submit"},
		{http.MethodGet, "/jobs"},
		{http.MethodGet, "/jobs/history"},
		{http.MethodGet, "/jobs/:id"},
		{http.MethodDelete, "/jobs/:id"},
		{http.MethodGet, "/nodes"},
		{http.MethodGet, "/nodes/:name"},
		{http.MethodGet, "/partitions"},
		{http.MethodGet, "/partitions/:name"},
		{http.MethodGet, "/diag"},
		{http.MethodGet, "/templates"},
		{http.MethodPost, "/templates"},
		{http.MethodGet, "/templates/:id"},
		{http.MethodPut, "/templates/:id"},
		{http.MethodDelete, "/templates/:id"},
	}
	for _, rt := range routes {
		v1.Handle(rt.method, rt.path, notImplemented)
	}
}
// notImplemented is the placeholder handler used by NewTestRouter; it
// returns 501 with the unified error envelope.
func notImplemented(c *gin.Context) {
	resp := APIResponse{Success: false, Error: "not implemented"}
	c.JSON(http.StatusNotImplemented, resp)
}

View File

@@ -0,0 +1,108 @@
package server
import (
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"github.com/gin-gonic/gin"
)
// TestAllRoutesRegistered asserts the test router exposes the full expected
// v1 route table (method + path pairs) and nothing is missing.
func TestAllRoutesRegistered(t *testing.T) {
	r := NewTestRouter()
	routes := r.Routes()
	expected := []struct {
		method string
		path   string
	}{
		{"POST", "/api/v1/jobs/submit"},
		{"GET", "/api/v1/jobs"},
		{"GET", "/api/v1/jobs/history"},
		{"GET", "/api/v1/jobs/:id"},
		{"DELETE", "/api/v1/jobs/:id"},
		{"GET", "/api/v1/nodes"},
		{"GET", "/api/v1/nodes/:name"},
		{"GET", "/api/v1/partitions"},
		{"GET", "/api/v1/partitions/:name"},
		{"GET", "/api/v1/diag"},
		{"GET", "/api/v1/templates"},
		{"POST", "/api/v1/templates"},
		{"GET", "/api/v1/templates/:id"},
		{"PUT", "/api/v1/templates/:id"},
		{"DELETE", "/api/v1/templates/:id"},
	}
	// Index registered routes by "METHOD path" for O(1) membership checks.
	routeMap := map[string]bool{}
	for _, route := range routes {
		key := route.Method + " " + route.Path
		routeMap[key] = true
	}
	for _, exp := range expected {
		key := exp.method + " " + exp.path
		if !routeMap[key] {
			t.Errorf("missing route: %s", key)
		}
	}
	if len(routes) < len(expected) {
		t.Errorf("expected at least %d routes, got %d", len(expected), len(routes))
	}
}

// TestUnregisteredPathReturns404 asserts unknown paths fall through to 404.
func TestUnregisteredPathReturns404(t *testing.T) {
	r := NewTestRouter()
	w := httptest.NewRecorder()
	req, _ := http.NewRequest(http.MethodGet, "/api/v1/nonexistent", nil)
	r.ServeHTTP(w, req)
	if w.Code != http.StatusNotFound {
		t.Fatalf("expected 404 for unregistered path, got %d", w.Code)
	}
}

// TestRegisteredPathReturns501 asserts placeholder endpoints answer 501 with
// the unified error envelope.
func TestRegisteredPathReturns501(t *testing.T) {
	r := NewTestRouter()
	endpoints := []struct {
		method string
		path   string
	}{
		{"GET", "/api/v1/jobs"},
		{"GET", "/api/v1/nodes"},
		{"GET", "/api/v1/partitions"},
		{"GET", "/api/v1/diag"},
		{"GET", "/api/v1/templates"},
	}
	for _, ep := range endpoints {
		w := httptest.NewRecorder()
		req, _ := http.NewRequest(ep.method, ep.path, nil)
		r.ServeHTTP(w, req)
		if w.Code != http.StatusNotImplemented {
			t.Fatalf("%s %s: expected 501, got %d", ep.method, ep.path, w.Code)
		}
		var resp APIResponse
		if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
			t.Fatalf("failed to parse response: %v", err)
		}
		if resp.Success {
			t.Fatal("expected success=false")
		}
		if resp.Error != "not implemented" {
			t.Fatalf("expected error 'not implemented', got '%s'", resp.Error)
		}
	}
}
// TestRouterUsesGinMode ensures NewTestRouter constructs a usable engine
// after the Gin mode has been set explicitly.
func TestRouterUsesGinMode(t *testing.T) {
	gin.SetMode(gin.TestMode)
	r := NewTestRouter()
	if r == nil {
		// Bug fix: the message previously named NewRouter, but this test
		// exercises NewTestRouter.
		t.Fatal("NewTestRouter returned nil")
	}
}

View File

@@ -0,0 +1,167 @@
package service
import (
"context"
"fmt"
"strconv"
"gcy_hpc_server/internal/model"
"gcy_hpc_server/internal/slurm"
"go.uber.org/zap"
)
// derefStr safely dereferences s, returning "" for nil.
func derefStr(s *string) string {
	if s != nil {
		return *s
	}
	return ""
}
// derefInt32 safely dereferences i, returning 0 for nil.
func derefInt32(i *int32) int32 {
	if i != nil {
		return *i
	}
	return 0
}
// derefInt64 safely dereferences i, returning 0 for nil.
func derefInt64(i *int64) int64 {
	if i != nil {
		return *i
	}
	return 0
}
// uint32NoValString renders a Slurm "number or infinite" value as a string:
// "UNLIMITED" when the infinite flag is set, the decimal number when present,
// and "" otherwise (including a nil value).
func uint32NoValString(v *slurm.Uint32NoVal) string {
	switch {
	case v == nil:
		return ""
	case v.Infinite != nil && *v.Infinite:
		return "UNLIMITED"
	case v.Number != nil:
		return strconv.FormatInt(*v.Number, 10)
	default:
		return ""
	}
}
// ClusterService exposes read-only cluster information (nodes, partitions,
// scheduler diagnostics) backed by the Slurm REST client.
type ClusterService struct {
	client *slurm.Client // Slurm REST SDK client
	logger *zap.Logger   // structured logger, injected by the caller
}

// NewClusterService constructs a ClusterService with its dependencies
// injected. Both arguments are expected to be non-nil.
func NewClusterService(client *slurm.Client, logger *zap.Logger) *ClusterService {
	return &ClusterService{client: client, logger: logger}
}
// GetNodes lists all cluster nodes mapped to the API model.
// Returns nil (not an empty slice) when Slurm reports no node list.
func (s *ClusterService) GetNodes(ctx context.Context) ([]model.NodeResponse, error) {
	resp, _, err := s.client.Nodes.GetNodes(ctx, nil)
	if err != nil {
		s.logger.Error("failed to get nodes", zap.Error(err))
		return nil, fmt.Errorf("get nodes: %w", err)
	}
	if resp.Nodes == nil {
		return nil, nil
	}
	nodes := *resp.Nodes
	out := make([]model.NodeResponse, len(nodes))
	for i := range nodes {
		out[i] = mapNode(nodes[i])
	}
	return out, nil
}
// GetNode fetches a single node by name.
// Returns (nil, nil) when the node does not exist in the response.
func (s *ClusterService) GetNode(ctx context.Context, name string) (*model.NodeResponse, error) {
	resp, _, err := s.client.Nodes.GetNode(ctx, name, nil)
	if err != nil {
		s.logger.Error("failed to get node", zap.String("name", name), zap.Error(err))
		return nil, fmt.Errorf("get node %s: %w", name, err)
	}
	if resp.Nodes != nil && len(*resp.Nodes) > 0 {
		node := mapNode((*resp.Nodes)[0])
		return &node, nil
	}
	return nil, nil
}
// GetPartitions lists all partitions mapped to the API model.
// Returns nil (not an empty slice) when Slurm reports no partition list.
func (s *ClusterService) GetPartitions(ctx context.Context) ([]model.PartitionResponse, error) {
	resp, _, err := s.client.Partitions.GetPartitions(ctx, nil)
	if err != nil {
		s.logger.Error("failed to get partitions", zap.Error(err))
		return nil, fmt.Errorf("get partitions: %w", err)
	}
	if resp.Partitions == nil {
		return nil, nil
	}
	parts := *resp.Partitions
	out := make([]model.PartitionResponse, len(parts))
	for i := range parts {
		out[i] = mapPartition(parts[i])
	}
	return out, nil
}
// GetPartition fetches a single partition by name.
// Returns (nil, nil) when the partition does not exist in the response.
func (s *ClusterService) GetPartition(ctx context.Context, name string) (*model.PartitionResponse, error) {
	resp, _, err := s.client.Partitions.GetPartition(ctx, name, nil)
	if err != nil {
		s.logger.Error("failed to get partition", zap.String("name", name), zap.Error(err))
		return nil, fmt.Errorf("get partition %s: %w", name, err)
	}
	if resp.Partitions != nil && len(*resp.Partitions) > 0 {
		part := mapPartition((*resp.Partitions)[0])
		return &part, nil
	}
	return nil, nil
}
// GetDiag returns the raw scheduler diagnostics response from Slurm.
func (s *ClusterService) GetDiag(ctx context.Context) (*slurm.OpenapiDiagResp, error) {
	resp, _, err := s.client.Diag.GetDiag(ctx)
	if err == nil {
		return resp, nil
	}
	s.logger.Error("failed to get diag", zap.Error(err))
	return nil, fmt.Errorf("get diag: %w", err)
}
// mapNode converts an SDK node into the API response model, treating nil
// pointer fields as zero values.
func mapNode(node slurm.Node) model.NodeResponse {
	out := model.NodeResponse{State: node.State}
	out.Name = derefStr(node.Name)
	out.CPUs = derefInt32(node.Cpus)
	out.RealMemory = derefInt64(node.RealMemory)
	out.AllocMem = derefInt64(node.AllocMemory)
	out.Arch = derefStr(node.Architecture)
	out.OS = derefStr(node.OperatingSystem)
	return out
}
// mapPartition converts an SDK partition into the API response model.
// Each nested sub-struct may be nil, in which case the corresponding
// output fields stay at their zero values.
func mapPartition(info slurm.PartitionInfo) model.PartitionResponse {
	out := model.PartitionResponse{Name: derefStr(info.Name)}
	if info.Partition != nil {
		out.State = info.Partition.State
	}
	if info.Nodes != nil {
		out.Nodes = derefStr(info.Nodes.Configured)
		out.TotalNodes = derefInt32(info.Nodes.Total)
	}
	if info.CPUs != nil {
		out.TotalCPUs = derefInt32(info.CPUs.Total)
	}
	if info.Maximums != nil {
		out.MaxTime = uint32NoValString(info.Maximums.Time)
	}
	return out
}

View File

@@ -0,0 +1,467 @@
package service
import (
"context"
"encoding/json"
"net/http"
"net/http/httptest"
"os"
"path/filepath"
"testing"
"gcy_hpc_server/internal/slurm"
"go.uber.org/zap"
"go.uber.org/zap/zapcore"
"go.uber.org/zap/zaptest/observer"
)
// mockServer starts an httptest server wrapped in a Slurm SDK client and
// returns the client plus a cleanup func that shuts the server down.
func mockServer(handler http.HandlerFunc) (*slurm.Client, func()) {
	srv := httptest.NewServer(handler)
	client, _ := slurm.NewClient(srv.URL, srv.Client())
	return client, srv.Close
}
// TestGetNodes verifies node listing hits the expected endpoint and maps
// every field of the SDK payload into the API model.
func TestGetNodes(t *testing.T) {
	client, cleanup := mockServer(func(w http.ResponseWriter, r *http.Request) {
		if r.URL.Path != "/slurm/v0.0.40/nodes" {
			t.Errorf("unexpected path: %s", r.URL.Path)
		}
		resp := map[string]interface{}{
			"nodes": []map[string]interface{}{
				{
					"name":             "node1",
					"state":            []string{"IDLE"},
					"cpus":             64,
					"real_memory":      256000,
					"alloc_memory":     0,
					"architecture":     "x86_64",
					"operating_system": "Linux 5.15",
				},
			},
		}
		w.Header().Set("Content-Type", "application/json")
		json.NewEncoder(w).Encode(resp)
	})
	defer cleanup()
	svc := NewClusterService(client, zap.NewNop())
	nodes, err := svc.GetNodes(context.Background())
	if err != nil {
		t.Fatalf("GetNodes returned error: %v", err)
	}
	if len(nodes) != 1 {
		t.Fatalf("expected 1 node, got %d", len(nodes))
	}
	n := nodes[0]
	if n.Name != "node1" {
		t.Errorf("expected name node1, got %s", n.Name)
	}
	if len(n.State) != 1 || n.State[0] != "IDLE" {
		t.Errorf("expected state [IDLE], got %v", n.State)
	}
	if n.CPUs != 64 {
		t.Errorf("expected 64 CPUs, got %d", n.CPUs)
	}
	if n.RealMemory != 256000 {
		t.Errorf("expected real_memory 256000, got %d", n.RealMemory)
	}
	if n.Arch != "x86_64" {
		t.Errorf("expected arch x86_64, got %s", n.Arch)
	}
	if n.OS != "Linux 5.15" {
		t.Errorf("expected OS 'Linux 5.15', got %s", n.OS)
	}
}

// TestGetNodes_Empty verifies an empty payload yields nil, not a slice.
func TestGetNodes_Empty(t *testing.T) {
	client, cleanup := mockServer(func(w http.ResponseWriter, r *http.Request) {
		w.Header().Set("Content-Type", "application/json")
		json.NewEncoder(w).Encode(map[string]interface{}{})
	})
	defer cleanup()
	svc := NewClusterService(client, zap.NewNop())
	nodes, err := svc.GetNodes(context.Background())
	if err != nil {
		t.Fatalf("GetNodes returned error: %v", err)
	}
	if nodes != nil {
		t.Errorf("expected nil for empty response, got %v", nodes)
	}
}

// TestGetNode verifies single-node lookup uses the per-node endpoint.
func TestGetNode(t *testing.T) {
	client, cleanup := mockServer(func(w http.ResponseWriter, r *http.Request) {
		if r.URL.Path != "/slurm/v0.0.40/node/node1" {
			t.Errorf("unexpected path: %s", r.URL.Path)
		}
		resp := map[string]interface{}{
			"nodes": []map[string]interface{}{
				{"name": "node1", "state": []string{"ALLOCATED"}, "cpus": 32},
			},
		}
		w.Header().Set("Content-Type", "application/json")
		json.NewEncoder(w).Encode(resp)
	})
	defer cleanup()
	svc := NewClusterService(client, zap.NewNop())
	node, err := svc.GetNode(context.Background(), "node1")
	if err != nil {
		t.Fatalf("GetNode returned error: %v", err)
	}
	if node == nil {
		t.Fatal("expected node, got nil")
	}
	if node.Name != "node1" {
		t.Errorf("expected name node1, got %s", node.Name)
	}
}

// TestGetNode_NotFound verifies an empty payload maps to (nil, nil).
func TestGetNode_NotFound(t *testing.T) {
	client, cleanup := mockServer(func(w http.ResponseWriter, r *http.Request) {
		w.Header().Set("Content-Type", "application/json")
		json.NewEncoder(w).Encode(map[string]interface{}{})
	})
	defer cleanup()
	svc := NewClusterService(client, zap.NewNop())
	node, err := svc.GetNode(context.Background(), "missing")
	if err != nil {
		t.Fatalf("GetNode returned error: %v", err)
	}
	if node != nil {
		t.Errorf("expected nil for missing node, got %+v", node)
	}
}
// TestGetPartitions verifies partition listing maps name, state, node range,
// totals, and a finite maximum time ("86400" minutes rendered as a string).
func TestGetPartitions(t *testing.T) {
	client, cleanup := mockServer(func(w http.ResponseWriter, r *http.Request) {
		if r.URL.Path != "/slurm/v0.0.40/partitions" {
			t.Errorf("unexpected path: %s", r.URL.Path)
		}
		resp := map[string]interface{}{
			"partitions": []map[string]interface{}{
				{
					"name": "normal",
					"partition": map[string]interface{}{
						"state": []string{"UP"},
					},
					"nodes": map[string]interface{}{
						"configured": "node[1-10]",
						"total":      10,
					},
					"cpus": map[string]interface{}{
						"total": 640,
					},
					"maximums": map[string]interface{}{
						"time": map[string]interface{}{
							"set":      true,
							"infinite": false,
							"number":   86400,
						},
					},
				},
			},
		}
		w.Header().Set("Content-Type", "application/json")
		json.NewEncoder(w).Encode(resp)
	})
	defer cleanup()
	svc := NewClusterService(client, zap.NewNop())
	partitions, err := svc.GetPartitions(context.Background())
	if err != nil {
		t.Fatalf("GetPartitions returned error: %v", err)
	}
	if len(partitions) != 1 {
		t.Fatalf("expected 1 partition, got %d", len(partitions))
	}
	p := partitions[0]
	if p.Name != "normal" {
		t.Errorf("expected name normal, got %s", p.Name)
	}
	if len(p.State) != 1 || p.State[0] != "UP" {
		t.Errorf("expected state [UP], got %v", p.State)
	}
	if p.Nodes != "node[1-10]" {
		t.Errorf("expected nodes 'node[1-10]', got %s", p.Nodes)
	}
	if p.TotalCPUs != 640 {
		t.Errorf("expected 640 total CPUs, got %d", p.TotalCPUs)
	}
	if p.TotalNodes != 10 {
		t.Errorf("expected 10 total nodes, got %d", p.TotalNodes)
	}
	if p.MaxTime != "86400" {
		t.Errorf("expected max_time '86400', got %s", p.MaxTime)
	}
}

// TestGetPartitions_Empty verifies an empty payload yields nil, not a slice.
func TestGetPartitions_Empty(t *testing.T) {
	client, cleanup := mockServer(func(w http.ResponseWriter, r *http.Request) {
		w.Header().Set("Content-Type", "application/json")
		json.NewEncoder(w).Encode(map[string]interface{}{})
	})
	defer cleanup()
	svc := NewClusterService(client, zap.NewNop())
	partitions, err := svc.GetPartitions(context.Background())
	if err != nil {
		t.Fatalf("GetPartitions returned error: %v", err)
	}
	if partitions != nil {
		t.Errorf("expected nil for empty response, got %v", partitions)
	}
}

// TestGetPartition verifies single-partition lookup and that an infinite
// maximum time renders as "UNLIMITED".
func TestGetPartition(t *testing.T) {
	client, cleanup := mockServer(func(w http.ResponseWriter, r *http.Request) {
		if r.URL.Path != "/slurm/v0.0.40/partition/gpu" {
			t.Errorf("unexpected path: %s", r.URL.Path)
		}
		resp := map[string]interface{}{
			"partitions": []map[string]interface{}{
				{
					"name": "gpu",
					"partition": map[string]interface{}{
						"state": []string{"UP"},
					},
					"nodes": map[string]interface{}{
						"configured": "gpu[1-4]",
						"total":      4,
					},
					"maximums": map[string]interface{}{
						"time": map[string]interface{}{
							"set":      true,
							"infinite": true,
						},
					},
				},
			},
		}
		w.Header().Set("Content-Type", "application/json")
		json.NewEncoder(w).Encode(resp)
	})
	defer cleanup()
	svc := NewClusterService(client, zap.NewNop())
	part, err := svc.GetPartition(context.Background(), "gpu")
	if err != nil {
		t.Fatalf("GetPartition returned error: %v", err)
	}
	if part == nil {
		t.Fatal("expected partition, got nil")
	}
	if part.Name != "gpu" {
		t.Errorf("expected name gpu, got %s", part.Name)
	}
	if part.MaxTime != "UNLIMITED" {
		t.Errorf("expected max_time UNLIMITED, got %s", part.MaxTime)
	}
}

// TestGetPartition_NotFound verifies an empty payload maps to (nil, nil).
func TestGetPartition_NotFound(t *testing.T) {
	client, cleanup := mockServer(func(w http.ResponseWriter, r *http.Request) {
		w.Header().Set("Content-Type", "application/json")
		json.NewEncoder(w).Encode(map[string]interface{}{})
	})
	defer cleanup()
	svc := NewClusterService(client, zap.NewNop())
	part, err := svc.GetPartition(context.Background(), "missing")
	if err != nil {
		t.Fatalf("GetPartition returned error: %v", err)
	}
	if part != nil {
		t.Errorf("expected nil for missing partition, got %+v", part)
	}
}
// TestGetDiag verifies diagnostics pass through the raw SDK response and
// preserve nested statistics values.
func TestGetDiag(t *testing.T) {
	client, cleanup := mockServer(func(w http.ResponseWriter, r *http.Request) {
		if r.URL.Path != "/slurm/v0.0.40/diag" {
			t.Errorf("unexpected path: %s", r.URL.Path)
		}
		resp := map[string]interface{}{
			"statistics": map[string]interface{}{
				"server_thread_count":   10,
				"agent_queue_size":      5,
				"jobs_submitted":        100,
				"jobs_running":          20,
				"schedule_queue_length": 3,
			},
		}
		w.Header().Set("Content-Type", "application/json")
		json.NewEncoder(w).Encode(resp)
	})
	defer cleanup()
	svc := NewClusterService(client, zap.NewNop())
	diag, err := svc.GetDiag(context.Background())
	if err != nil {
		t.Fatalf("GetDiag returned error: %v", err)
	}
	if diag == nil {
		t.Fatal("expected diag response, got nil")
	}
	if diag.Statistics == nil {
		t.Fatal("expected statistics, got nil")
	}
	if diag.Statistics.ServerThreadCount == nil || *diag.Statistics.ServerThreadCount != 10 {
		t.Errorf("expected server_thread_count 10, got %v", diag.Statistics.ServerThreadCount)
	}
}
// TestNewSlurmClient verifies the JWT-authenticated client factory succeeds
// when given a readable key file.
func TestNewSlurmClient(t *testing.T) {
	dir := t.TempDir()
	keyPath := filepath.Join(dir, "jwt.key")
	// Bug fix: the WriteFile error was previously ignored; a failed write
	// would make the factory error message misleading.
	if err := os.WriteFile(keyPath, make([]byte, 32), 0644); err != nil {
		t.Fatalf("write key file: %v", err)
	}
	client, err := NewSlurmClient("http://localhost:6820", "root", keyPath)
	if err != nil {
		t.Fatalf("NewSlurmClient returned error: %v", err)
	}
	if client == nil {
		t.Fatal("expected client, got nil")
	}
}
// newClusterServiceWithObserver builds a ClusterService whose logger records
// entries into an in-memory observer for assertions.
func newClusterServiceWithObserver(srv *httptest.Server) (*ClusterService, *observer.ObservedLogs) {
	core, recorded := observer.New(zapcore.DebugLevel)
	l := zap.New(core)
	client, _ := slurm.NewClient(srv.URL, srv.Client())
	return NewClusterService(client, l), recorded
}

// errorServer returns a server that always answers 500 with a Slurm-style
// error body, to drive the error-logging paths.
func errorServer() *httptest.Server {
	return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusInternalServerError)
		w.Write([]byte(`{"errors": [{"error": "internal server error"}]}`))
	}))
}
// TestClusterService_GetNodes_ErrorLogging asserts a failed node listing
// produces exactly one Error-level entry with structured fields.
func TestClusterService_GetNodes_ErrorLogging(t *testing.T) {
	srv := errorServer()
	defer srv.Close()
	svc, logs := newClusterServiceWithObserver(srv)
	_, err := svc.GetNodes(context.Background())
	if err == nil {
		t.Fatal("expected error, got nil")
	}
	if logs.Len() != 1 {
		t.Fatalf("expected 1 log entry, got %d", logs.Len())
	}
	entry := logs.All()[0]
	if entry.Level != zapcore.ErrorLevel {
		t.Errorf("expected ErrorLevel, got %v", entry.Level)
	}
	if len(entry.Context) == 0 {
		t.Error("expected structured fields in log entry")
	}
}

// TestClusterService_GetNode_ErrorLogging asserts the node name is attached
// as a structured "name" field on the error log.
func TestClusterService_GetNode_ErrorLogging(t *testing.T) {
	srv := errorServer()
	defer srv.Close()
	svc, logs := newClusterServiceWithObserver(srv)
	_, err := svc.GetNode(context.Background(), "test-node")
	if err == nil {
		t.Fatal("expected error, got nil")
	}
	if logs.Len() != 1 {
		t.Fatalf("expected 1 log entry, got %d", logs.Len())
	}
	entry := logs.All()[0]
	if entry.Level != zapcore.ErrorLevel {
		t.Errorf("expected ErrorLevel, got %v", entry.Level)
	}
	hasName := false
	for _, f := range entry.Context {
		if f.Key == "name" && f.String == "test-node" {
			hasName = true
		}
	}
	if !hasName {
		t.Error("expected 'name' field with value 'test-node' in log entry")
	}
}

// TestClusterService_GetPartitions_ErrorLogging asserts a failed partition
// listing produces exactly one Error-level entry with structured fields.
func TestClusterService_GetPartitions_ErrorLogging(t *testing.T) {
	srv := errorServer()
	defer srv.Close()
	svc, logs := newClusterServiceWithObserver(srv)
	_, err := svc.GetPartitions(context.Background())
	if err == nil {
		t.Fatal("expected error, got nil")
	}
	if logs.Len() != 1 {
		t.Fatalf("expected 1 log entry, got %d", logs.Len())
	}
	entry := logs.All()[0]
	if entry.Level != zapcore.ErrorLevel {
		t.Errorf("expected ErrorLevel, got %v", entry.Level)
	}
	if len(entry.Context) == 0 {
		t.Error("expected structured fields in log entry")
	}
}

// TestClusterService_GetPartition_ErrorLogging asserts the partition name is
// attached as a structured "name" field on the error log.
func TestClusterService_GetPartition_ErrorLogging(t *testing.T) {
	srv := errorServer()
	defer srv.Close()
	svc, logs := newClusterServiceWithObserver(srv)
	_, err := svc.GetPartition(context.Background(), "test-partition")
	if err == nil {
		t.Fatal("expected error, got nil")
	}
	if logs.Len() != 1 {
		t.Fatalf("expected 1 log entry, got %d", logs.Len())
	}
	entry := logs.All()[0]
	if entry.Level != zapcore.ErrorLevel {
		t.Errorf("expected ErrorLevel, got %v", entry.Level)
	}
	hasName := false
	for _, f := range entry.Context {
		if f.Key == "name" && f.String == "test-partition" {
			hasName = true
		}
	}
	if !hasName {
		t.Error("expected 'name' field with value 'test-partition' in log entry")
	}
}

// TestClusterService_GetDiag_ErrorLogging asserts a failed diag fetch
// produces exactly one Error-level entry with structured fields.
func TestClusterService_GetDiag_ErrorLogging(t *testing.T) {
	srv := errorServer()
	defer srv.Close()
	svc, logs := newClusterServiceWithObserver(srv)
	_, err := svc.GetDiag(context.Background())
	if err == nil {
		t.Fatal("expected error, got nil")
	}
	if logs.Len() != 1 {
		t.Fatalf("expected 1 log entry, got %d", logs.Len())
	}
	entry := logs.All()[0]
	if entry.Level != zapcore.ErrorLevel {
		t.Errorf("expected ErrorLevel, got %v", entry.Level)
	}
	if len(entry.Context) == 0 {
		t.Error("expected structured fields in log entry")
	}
}

View File

@@ -0,0 +1,246 @@
package service
import (
"context"
"fmt"
"strconv"
"gcy_hpc_server/internal/model"
"gcy_hpc_server/internal/slurm"
"go.uber.org/zap"
)
// JobService wraps Slurm SDK job operations with model mapping and pagination.
type JobService struct {
	client *slurm.Client // Slurm REST SDK client
	logger *zap.Logger   // structured logger, injected by the caller
}

// NewJobService creates a new JobService with the given Slurm SDK client.
// Both arguments are expected to be non-nil.
func NewJobService(client *slurm.Client, logger *zap.Logger) *JobService {
	return &JobService{client: client, logger: logger}
}
// SubmitJob submits a new job to Slurm and returns the assigned job ID.
//
// Optional request fields (partition, QOS, name, CPUs, time limit) are
// forwarded only when set, so Slurm applies its own defaults otherwise.
func (s *JobService) SubmitJob(ctx context.Context, req *model.SubmitJobRequest) (*model.JobResponse, error) {
	script := req.Script
	jobDesc := &slurm.JobDescMsg{
		Script:    &script,
		Partition: strToPtrOrNil(req.Partition),
		Qos:       strToPtrOrNil(req.QOS),
		Name:      strToPtrOrNil(req.JobName),
	}
	if req.CPUs > 0 {
		jobDesc.MinimumCpus = slurm.Ptr(req.CPUs)
	}
	if req.TimeLimit != "" {
		// Bug fix: a non-numeric time limit was previously dropped silently.
		// It is still ignored (preserving behavior for callers), but now
		// surfaces as a warning so the bad input is visible.
		mins, err := strconv.ParseInt(req.TimeLimit, 10, 64)
		if err != nil {
			s.logger.Warn("ignoring unparsable time_limit",
				zap.String("time_limit", req.TimeLimit), zap.Error(err))
		} else {
			jobDesc.TimeLimit = &slurm.Uint32NoVal{Number: &mins}
		}
	}
	// The script is populated both at the top level and inside the job
	// description, mirroring what the REST endpoint accepts.
	submitReq := &slurm.JobSubmitReq{
		Script: &script,
		Job:    jobDesc,
	}
	result, _, err := s.client.Jobs.SubmitJob(ctx, submitReq)
	if err != nil {
		s.logger.Error("failed to submit job", zap.Error(err), zap.String("operation", "submit"))
		return nil, fmt.Errorf("submit job: %w", err)
	}
	// The job ID may arrive either nested under Result or at the top level,
	// depending on the API version.
	resp := &model.JobResponse{}
	if result.Result != nil && result.Result.JobID != nil {
		resp.JobID = *result.Result.JobID
	} else if result.JobID != nil {
		resp.JobID = *result.JobID
	}
	s.logger.Info("job submitted", zap.String("job_name", req.JobName), zap.Int32("job_id", resp.JobID))
	return resp, nil
}
// GetJobs lists all current jobs from Slurm, mapped to the API model.
func (s *JobService) GetJobs(ctx context.Context) ([]model.JobResponse, error) {
	result, _, err := s.client.Jobs.GetJobs(ctx, nil)
	if err != nil {
		s.logger.Error("failed to get jobs", zap.Error(err), zap.String("operation", "get_jobs"))
		return nil, fmt.Errorf("get jobs: %w", err)
	}
	out := make([]model.JobResponse, len(result.Jobs))
	for i := range result.Jobs {
		out[i] = mapJobInfo(&result.Jobs[i])
	}
	return out, nil
}
// GetJob retrieves a single job by ID. Returns (nil, nil) when the job is
// not present in the response.
func (s *JobService) GetJob(ctx context.Context, jobID string) (*model.JobResponse, error) {
	result, _, err := s.client.Jobs.GetJob(ctx, jobID, nil)
	if err != nil {
		s.logger.Error("failed to get job", zap.Error(err), zap.String("job_id", jobID), zap.String("operation", "get_job"))
		return nil, fmt.Errorf("get job %s: %w", jobID, err)
	}
	if len(result.Jobs) > 0 {
		job := mapJobInfo(&result.Jobs[0])
		return &job, nil
	}
	return nil, nil
}
// CancelJob cancels a job by ID and logs the outcome.
func (s *JobService) CancelJob(ctx context.Context, jobID string) error {
	if _, _, err := s.client.Jobs.DeleteJob(ctx, jobID, nil); err != nil {
		s.logger.Error("failed to cancel job", zap.Error(err), zap.String("job_id", jobID), zap.String("operation", "cancel"))
		return fmt.Errorf("cancel job %s: %w", jobID, err)
	}
	s.logger.Info("job cancelled", zap.String("job_id", jobID))
	return nil
}
// GetJobHistory queries SlurmDBD for historical jobs and paginates the
// result in memory. Page defaults to 1 and page size to 20.
func (s *JobService) GetJobHistory(ctx context.Context, query *model.JobHistoryQuery) (*model.JobListResponse, error) {
	// Empty filters become nil pointers, which the SDK treats as unset —
	// identical to not assigning the field at all.
	opts := &slurm.GetSlurmdbJobsOptions{
		Users:     strToPtrOrNil(query.Users),
		Account:   strToPtrOrNil(query.Account),
		Partition: strToPtrOrNil(query.Partition),
		State:     strToPtrOrNil(query.State),
		JobName:   strToPtrOrNil(query.JobName),
		StartTime: strToPtrOrNil(query.StartTime),
		EndTime:   strToPtrOrNil(query.EndTime),
	}
	result, _, err := s.client.SlurmdbJobs.GetJobs(ctx, opts)
	if err != nil {
		s.logger.Error("failed to get job history", zap.Error(err), zap.String("operation", "get_job_history"))
		return nil, fmt.Errorf("get job history: %w", err)
	}
	all := make([]model.JobResponse, 0, len(result.Jobs))
	for i := range result.Jobs {
		all = append(all, mapSlurmdbJob(&result.Jobs[i]))
	}
	page, size := query.Page, query.PageSize
	if page < 1 {
		page = 1
	}
	if size < 1 {
		size = 20
	}
	total := len(all)
	// Clamp the window to the available range so out-of-range pages return
	// an empty slice rather than panicking.
	lo := (page - 1) * size
	if lo > total {
		lo = total
	}
	hi := lo + size
	if hi > total {
		hi = total
	}
	return &model.JobListResponse{
		Jobs:     all[lo:hi],
		Total:    total,
		Page:     page,
		PageSize: size,
	}, nil
}
// ---------------------------------------------------------------------------
// Helper functions
// ---------------------------------------------------------------------------
// strToPtr returns a pointer to a copy of s.
func strToPtr(s string) *string {
	p := s
	return &p
}
// strToPtrOrNil returns a pointer to s if non-empty, otherwise nil.
// (Comment fixed: it previously named the function "strPtrOrNil".)
func strToPtrOrNil(s string) *string {
	if s == "" {
		return nil
	}
	return &s
}
// mapJobInfo maps SDK JobInfo to API JobResponse, treating nil pointer
// fields as absent and leaving the corresponding output fields zeroed.
func mapJobInfo(info *slurm.JobInfo) model.JobResponse {
	out := model.JobResponse{State: info.JobState}
	if info.JobID != nil {
		out.JobID = *info.JobID
	}
	if info.Name != nil {
		out.Name = *info.Name
	}
	if info.Partition != nil {
		out.Partition = *info.Partition
	}
	if info.Nodes != nil {
		out.Nodes = *info.Nodes
	}
	if info.SubmitTime != nil && info.SubmitTime.Number != nil {
		out.SubmitTime = info.SubmitTime.Number
	}
	if info.StartTime != nil && info.StartTime.Number != nil {
		out.StartTime = info.StartTime.Number
	}
	if info.EndTime != nil && info.EndTime.Number != nil {
		out.EndTime = info.EndTime.Number
	}
	if info.ExitCode != nil && info.ExitCode.ReturnCode != nil && info.ExitCode.ReturnCode.Number != nil {
		code := int32(*info.ExitCode.ReturnCode.Number)
		out.ExitCode = &code
	}
	return out
}
// mapSlurmdbJob maps an SDK SlurmDBD Job to API JobResponse. Unlike the
// live JobInfo shape, timestamps live under the nested Time struct and the
// state under State.Current.
func mapSlurmdbJob(job *slurm.Job) model.JobResponse {
	out := model.JobResponse{}
	if job.JobID != nil {
		out.JobID = *job.JobID
	}
	if job.Name != nil {
		out.Name = *job.Name
	}
	if job.State != nil {
		out.State = job.State.Current
	}
	if job.Partition != nil {
		out.Partition = *job.Partition
	}
	if job.Nodes != nil {
		out.Nodes = *job.Nodes
	}
	if t := job.Time; t != nil {
		if t.Submission != nil {
			out.SubmitTime = t.Submission
		}
		if t.Start != nil {
			out.StartTime = t.Start
		}
		if t.End != nil {
			out.EndTime = t.End
		}
	}
	return out
}

View File

@@ -0,0 +1,703 @@
package service
import (
"context"
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"gcy_hpc_server/internal/model"
"gcy_hpc_server/internal/slurm"
"go.uber.org/zap"
"go.uber.org/zap/zapcore"
"go.uber.org/zap/zaptest/observer"
)
// mockJobServer starts an httptest server wrapped in a Slurm SDK client and
// returns the client plus a cleanup func that shuts the server down.
func mockJobServer(handler http.HandlerFunc) (*slurm.Client, func()) {
	srv := httptest.NewServer(handler)
	client, _ := slurm.NewClient(srv.URL, srv.Client())
	return client, srv.Close
}
// TestSubmitJob verifies SubmitJob posts the script to the submit endpoint
// and returns the job ID from the response.
func TestSubmitJob(t *testing.T) {
	jobID := int32(123)
	client, cleanup := mockJobServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if r.Method != http.MethodPost {
			t.Errorf("expected POST, got %s", r.Method)
		}
		if r.URL.Path != "/slurm/v0.0.40/job/submit" {
			t.Errorf("unexpected path: %s", r.URL.Path)
		}
		var body slurm.JobSubmitReq
		if err := json.NewDecoder(r.Body).Decode(&body); err != nil {
			// Bug fix: t.Fatalf must only run on the test goroutine; the
			// httptest handler runs on a server goroutine, so report with
			// Errorf and bail out of the handler instead.
			t.Errorf("decode body: %v", err)
			return
		}
		if body.Job == nil || body.Job.Script == nil || *body.Job.Script != "#!/bin/bash\necho hello" {
			t.Errorf("unexpected script in request body")
		}
		resp := slurm.OpenapiJobSubmitResponse{
			Result: &slurm.JobSubmitResponseMsg{
				JobID: &jobID,
			},
		}
		json.NewEncoder(w).Encode(resp)
	}))
	defer cleanup()
	svc := NewJobService(client, zap.NewNop())
	resp, err := svc.SubmitJob(context.Background(), &model.SubmitJobRequest{
		Script:    "#!/bin/bash\necho hello",
		Partition: "normal",
		QOS:       "high",
		JobName:   "test-job",
		CPUs:      4,
		TimeLimit: "60",
	})
	if err != nil {
		t.Fatalf("SubmitJob: %v", err)
	}
	if resp.JobID != 123 {
		t.Errorf("expected JobID 123, got %d", resp.JobID)
	}
}
// TestSubmitJob_WithOptionalFields verifies that empty optional fields are
// omitted (nil) in the request sent to Slurm.
func TestSubmitJob_WithOptionalFields(t *testing.T) {
	client, cleanup := mockJobServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		var body slurm.JobSubmitReq
		if err := json.NewDecoder(r.Body).Decode(&body); err != nil {
			// Bug fix: Fatal/Fatalf are invalid off the test goroutine
			// (this handler runs on a server goroutine); use Errorf and
			// return instead.
			t.Errorf("decode body: %v", err)
			return
		}
		if body.Job == nil {
			t.Error("job desc is nil")
			return
		}
		if body.Job.Partition != nil {
			t.Error("expected partition nil for empty string")
		}
		if body.Job.MinimumCpus != nil {
			t.Error("expected minimum_cpus nil when CPUs=0")
		}
		jobID := int32(456)
		resp := slurm.OpenapiJobSubmitResponse{
			Result: &slurm.JobSubmitResponseMsg{JobID: &jobID},
		}
		json.NewEncoder(w).Encode(resp)
	}))
	defer cleanup()
	svc := NewJobService(client, zap.NewNop())
	resp, err := svc.SubmitJob(context.Background(), &model.SubmitJobRequest{
		Script: "echo hi",
	})
	if err != nil {
		t.Fatalf("SubmitJob: %v", err)
	}
	if resp.JobID != 456 {
		t.Errorf("expected JobID 456, got %d", resp.JobID)
	}
}
// TestSubmitJob_Error ensures that a 5xx response from slurmrestd surfaces as
// a non-nil error from SubmitJob.
func TestSubmitJob_Error(t *testing.T) {
	handler := func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusInternalServerError)
		w.Write([]byte(`{"error":"internal"}`))
	}
	client, cleanup := mockJobServer(http.HandlerFunc(handler))
	defer cleanup()

	svc := NewJobService(client, zap.NewNop())
	if _, err := svc.SubmitJob(context.Background(), &model.SubmitJobRequest{Script: "echo fail"}); err == nil {
		t.Fatal("expected error, got nil")
	}
}
// TestGetJobs verifies that GetJobs issues a GET request and maps each field
// of a slurm JobInfo record — id, name, state list, partition, the NoVal
// timestamp wrappers, and node list — onto the service's job response type.
func TestGetJobs(t *testing.T) {
	jobID := int32(100)
	name := "my-job"
	partition := "gpu"
	ts := int64(1700000000)
	nodes := "node01"
	client, cleanup := mockJobServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if r.Method != http.MethodGet {
			t.Errorf("expected GET, got %s", r.Method)
		}
		// One fully-populated job; submit/start/end share the same timestamp
		// because only SubmitTime is asserted below.
		resp := slurm.OpenapiJobInfoResp{
			Jobs: slurm.JobInfoMsg{
				{
					JobID:      &jobID,
					Name:       &name,
					JobState:   []string{"RUNNING"},
					Partition:  &partition,
					SubmitTime: &slurm.Uint64NoVal{Number: &ts},
					StartTime:  &slurm.Uint64NoVal{Number: &ts},
					EndTime:    &slurm.Uint64NoVal{Number: &ts},
					Nodes:      &nodes,
				},
			},
		}
		json.NewEncoder(w).Encode(resp)
	}))
	defer cleanup()
	svc := NewJobService(client, zap.NewNop())
	jobs, err := svc.GetJobs(context.Background())
	if err != nil {
		t.Fatalf("GetJobs: %v", err)
	}
	if len(jobs) != 1 {
		t.Fatalf("expected 1 job, got %d", len(jobs))
	}
	// Field-by-field mapping assertions on the single returned job.
	j := jobs[0]
	if j.JobID != 100 {
		t.Errorf("expected JobID 100, got %d", j.JobID)
	}
	if j.Name != "my-job" {
		t.Errorf("expected Name my-job, got %s", j.Name)
	}
	if len(j.State) != 1 || j.State[0] != "RUNNING" {
		t.Errorf("expected State [RUNNING], got %v", j.State)
	}
	if j.Partition != "gpu" {
		t.Errorf("expected Partition gpu, got %s", j.Partition)
	}
	if j.SubmitTime == nil || *j.SubmitTime != ts {
		t.Errorf("expected SubmitTime %d, got %v", ts, j.SubmitTime)
	}
	if j.Nodes != "node01" {
		t.Errorf("expected Nodes node01, got %s", j.Nodes)
	}
}
// TestGetJob verifies that GetJob resolves a single job by its ID string.
func TestGetJob(t *testing.T) {
	var (
		id      = int32(200)
		jobName = "single-job"
	)
	client, cleanup := mockJobServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		payload := slurm.OpenapiJobInfoResp{
			Jobs: slurm.JobInfoMsg{{JobID: &id, Name: &jobName}},
		}
		json.NewEncoder(w).Encode(payload)
	}))
	defer cleanup()

	svc := NewJobService(client, zap.NewNop())
	got, err := svc.GetJob(context.Background(), "200")
	switch {
	case err != nil:
		t.Fatalf("GetJob: %v", err)
	case got == nil:
		t.Fatal("expected job, got nil")
	case got.JobID != 200:
		t.Errorf("expected JobID 200, got %d", got.JobID)
	}
}
// TestGetJob_NotFound verifies that an empty job list maps to the (nil, nil)
// not-found contract rather than an error.
func TestGetJob_NotFound(t *testing.T) {
	client, cleanup := mockJobServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		json.NewEncoder(w).Encode(slurm.OpenapiJobInfoResp{Jobs: slurm.JobInfoMsg{}})
	}))
	defer cleanup()

	svc := NewJobService(client, zap.NewNop())
	got, err := svc.GetJob(context.Background(), "999")
	if err != nil {
		t.Fatalf("GetJob: %v", err)
	}
	if got != nil {
		t.Errorf("expected nil for not found, got %+v", got)
	}
}
// TestCancelJob verifies that CancelJob issues a DELETE request and succeeds
// on a well-formed response.
func TestCancelJob(t *testing.T) {
	client, cleanup := mockJobServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if got := r.Method; got != http.MethodDelete {
			t.Errorf("expected DELETE, got %s", got)
		}
		json.NewEncoder(w).Encode(slurm.OpenapiResp{})
	}))
	defer cleanup()

	svc := NewJobService(client, zap.NewNop())
	if err := svc.CancelJob(context.Background(), "300"); err != nil {
		t.Fatalf("CancelJob: %v", err)
	}
}
// TestCancelJob_Error ensures an HTTP 404 from slurmrestd is reported as an
// error by CancelJob.
func TestCancelJob_Error(t *testing.T) {
	handler := func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusNotFound)
		w.Write([]byte(`not found`))
	}
	client, cleanup := mockJobServer(http.HandlerFunc(handler))
	defer cleanup()

	svc := NewJobService(client, zap.NewNop())
	if err := svc.CancelJob(context.Background(), "999"); err == nil {
		t.Fatal("expected error, got nil")
	}
}
// TestGetJobHistory verifies that GetJobHistory forwards the Users filter as
// the "users" query parameter and applies in-memory pagination to the slurmdbd
// result set: three accounting records with PageSize=2 yield page 1 holding
// the first two jobs while Total still reports 3.
func TestGetJobHistory(t *testing.T) {
	jobID1 := int32(10)
	jobID2 := int32(20)
	jobID3 := int32(30)
	name1 := "hist-1"
	name2 := "hist-2"
	name3 := "hist-3"
	submission1 := int64(1700000000)
	submission2 := int64(1700001000)
	submission3 := int64(1700002000)
	partition := "normal"
	client, cleanup := mockJobServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if r.Method != http.MethodGet {
			t.Errorf("expected GET, got %s", r.Method)
		}
		// The Users field must travel as the "users" URL query parameter.
		users := r.URL.Query().Get("users")
		if users != "testuser" {
			t.Errorf("expected users=testuser, got %s", users)
		}
		// Three accounting records in three distinct terminal states.
		resp := slurm.OpenapiSlurmdbdJobsResp{
			Jobs: slurm.JobList{
				{
					JobID:     &jobID1,
					Name:      &name1,
					Partition: &partition,
					State:     &slurm.JobState{Current: []string{"COMPLETED"}},
					Time:      &slurm.JobTime{Submission: &submission1},
				},
				{
					JobID:     &jobID2,
					Name:      &name2,
					Partition: &partition,
					State:     &slurm.JobState{Current: []string{"FAILED"}},
					Time:      &slurm.JobTime{Submission: &submission2},
				},
				{
					JobID:     &jobID3,
					Name:      &name3,
					Partition: &partition,
					State:     &slurm.JobState{Current: []string{"CANCELLED"}},
					Time:      &slurm.JobTime{Submission: &submission3},
				},
			},
		}
		json.NewEncoder(w).Encode(resp)
	}))
	defer cleanup()
	svc := NewJobService(client, zap.NewNop())
	result, err := svc.GetJobHistory(context.Background(), &model.JobHistoryQuery{
		Users:    "testuser",
		Page:     1,
		PageSize: 2,
	})
	if err != nil {
		t.Fatalf("GetJobHistory: %v", err)
	}
	if result.Total != 3 {
		t.Errorf("expected Total 3, got %d", result.Total)
	}
	if result.Page != 1 {
		t.Errorf("expected Page 1, got %d", result.Page)
	}
	if result.PageSize != 2 {
		t.Errorf("expected PageSize 2, got %d", result.PageSize)
	}
	if len(result.Jobs) != 2 {
		t.Fatalf("expected 2 jobs on page 1, got %d", len(result.Jobs))
	}
	// Page 1 keeps source order: jobs 10 and 20; job 30 falls on page 2.
	if result.Jobs[0].JobID != 10 {
		t.Errorf("expected first job ID 10, got %d", result.Jobs[0].JobID)
	}
	if result.Jobs[1].JobID != 20 {
		t.Errorf("expected second job ID 20, got %d", result.Jobs[1].JobID)
	}
	if len(result.Jobs[0].State) != 1 || result.Jobs[0].State[0] != "COMPLETED" {
		t.Errorf("expected state [COMPLETED], got %v", result.Jobs[0].State)
	}
}
// TestGetJobHistory_Page2 verifies offset-based pagination: with PageSize=1,
// page 2 holds only the second record while Total counts both.
func TestGetJobHistory_Page2(t *testing.T) {
	var (
		idA, idB     = int32(10), int32(20)
		nameA, nameB = "a", "b"
	)
	client, cleanup := mockJobServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		json.NewEncoder(w).Encode(slurm.OpenapiSlurmdbdJobsResp{
			Jobs: slurm.JobList{
				{JobID: &idA, Name: &nameA},
				{JobID: &idB, Name: &nameB},
			},
		})
	}))
	defer cleanup()

	svc := NewJobService(client, zap.NewNop())
	got, err := svc.GetJobHistory(context.Background(), &model.JobHistoryQuery{Page: 2, PageSize: 1})
	if err != nil {
		t.Fatalf("GetJobHistory: %v", err)
	}
	if got.Total != 2 {
		t.Errorf("expected Total 2, got %d", got.Total)
	}
	if len(got.Jobs) != 1 {
		t.Fatalf("expected 1 job on page 2, got %d", len(got.Jobs))
	}
	if got.Jobs[0].JobID != 20 {
		t.Errorf("expected job ID 20, got %d", got.Jobs[0].JobID)
	}
}
// TestGetJobHistory_DefaultPagination verifies that an empty query falls back
// to page 1 with a page size of 20.
func TestGetJobHistory_DefaultPagination(t *testing.T) {
	client, cleanup := mockJobServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		json.NewEncoder(w).Encode(slurm.OpenapiSlurmdbdJobsResp{Jobs: slurm.JobList{}})
	}))
	defer cleanup()

	svc := NewJobService(client, zap.NewNop())
	got, err := svc.GetJobHistory(context.Background(), &model.JobHistoryQuery{})
	if err != nil {
		t.Fatalf("GetJobHistory: %v", err)
	}
	if got.Page != 1 {
		t.Errorf("expected default page 1, got %d", got.Page)
	}
	if got.PageSize != 20 {
		t.Errorf("expected default pageSize 20, got %d", got.PageSize)
	}
}
// TestGetJobHistory_QueryMapping verifies that every JobHistoryQuery filter
// field is forwarded to slurmdbd as its matching URL query parameter.
func TestGetJobHistory_QueryMapping(t *testing.T) {
	want := map[string]string{
		"account":    "proj1",
		"partition":  "gpu",
		"state":      "COMPLETED",
		"job_name":   "myjob",
		"start_time": "1700000000",
		"end_time":   "1700099999",
	}
	client, cleanup := mockJobServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		q := r.URL.Query()
		for param, expected := range want {
			if got := q.Get(param); got != expected {
				t.Errorf("expected %s=%s, got %s", param, expected, got)
			}
		}
		json.NewEncoder(w).Encode(slurm.OpenapiSlurmdbdJobsResp{Jobs: slurm.JobList{}})
	}))
	defer cleanup()

	svc := NewJobService(client, zap.NewNop())
	_, err := svc.GetJobHistory(context.Background(), &model.JobHistoryQuery{
		Users:     "testuser",
		Account:   "proj1",
		Partition: "gpu",
		State:     "COMPLETED",
		JobName:   "myjob",
		StartTime: "1700000000",
		EndTime:   "1700099999",
	})
	if err != nil {
		t.Fatalf("GetJobHistory: %v", err)
	}
}
// TestGetJobHistory_Error ensures a 5xx from the accounting backend surfaces
// as an error from GetJobHistory.
func TestGetJobHistory_Error(t *testing.T) {
	handler := func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusInternalServerError)
		w.Write([]byte(`{"error":"db down"}`))
	}
	client, cleanup := mockJobServer(http.HandlerFunc(handler))
	defer cleanup()

	svc := NewJobService(client, zap.NewNop())
	if _, err := svc.GetJobHistory(context.Background(), &model.JobHistoryQuery{}); err == nil {
		t.Fatal("expected error, got nil")
	}
}
func TestMapJobInfo_ExitCode(t *testing.T) {
returnCode := int64(2)
ji := &slurm.JobInfo{
ExitCode: &slurm.ProcessExitCodeVerbose{
ReturnCode: &slurm.Uint32NoVal{Number: &returnCode},
},
}
resp := mapJobInfo(ji)
if resp.ExitCode == nil || *resp.ExitCode != 2 {
t.Errorf("expected exit code 2, got %v", resp.ExitCode)
}
}
func TestMapSlurmdbJob_NilFields(t *testing.T) {
j := &slurm.Job{}
resp := mapSlurmdbJob(j)
if resp.JobID != 0 {
t.Errorf("expected JobID 0, got %d", resp.JobID)
}
if resp.State != nil {
t.Errorf("expected nil State, got %v", resp.State)
}
if resp.SubmitTime != nil {
t.Errorf("expected nil SubmitTime, got %v", resp.SubmitTime)
}
}
// ---------------------------------------------------------------------------
// Structured logging tests using zaptest/observer
// ---------------------------------------------------------------------------
// newJobServiceWithObserver builds a JobService pointed at srv whose logger
// records entries into the returned observer, so tests can assert on log
// levels and fields.
func newJobServiceWithObserver(srv *httptest.Server) (*JobService, *observer.ObservedLogs) {
	core, recorded := observer.New(zapcore.DebugLevel)
	client, err := slurm.NewClient(srv.URL, srv.Client())
	if err != nil {
		// Previously discarded with `_`, which would surface later as a
		// confusing nil-client failure inside the test body. Fail loudly here
		// instead; panicking is acceptable in a test-only helper.
		panic("newJobServiceWithObserver: " + err.Error())
	}
	return NewJobService(client, zap.New(core)), recorded
}
// TestJobService_SubmitJob_SuccessLog asserts that a successful submit emits
// exactly one Info entry carrying job_name and job_id fields.
func TestJobService_SubmitJob_SuccessLog(t *testing.T) {
	id := int32(789)
	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		json.NewEncoder(w).Encode(slurm.OpenapiJobSubmitResponse{
			Result: &slurm.JobSubmitResponseMsg{JobID: &id},
		})
	}))
	defer srv.Close()

	svc, logs := newJobServiceWithObserver(srv)
	if _, err := svc.SubmitJob(context.Background(), &model.SubmitJobRequest{
		Script:  "echo hi",
		JobName: "log-test-job",
	}); err != nil {
		t.Fatalf("unexpected error: %v", err)
	}

	all := logs.All()
	if len(all) != 1 {
		t.Fatalf("expected 1 log entry, got %d", len(all))
	}
	entry := all[0]
	if entry.Level != zapcore.InfoLevel {
		t.Errorf("expected InfoLevel, got %v", entry.Level)
	}
	fm := entry.ContextMap()
	if fm["job_name"] != "log-test-job" {
		t.Errorf("expected job_name=log-test-job, got %v", fm["job_name"])
	}
	v, ok := fm["job_id"]
	if !ok {
		t.Fatal("expected job_id field in log entry")
	}
	// zap may widen the int32 field; accept either representation.
	if v != int32(789) && v != int64(789) {
		t.Errorf("expected job_id=789, got %v (%T)", v, v)
	}
}
// TestJobService_SubmitJob_ErrorLog asserts that a failed submit emits exactly
// one Error entry with operation=submit and an error field.
func TestJobService_SubmitJob_ErrorLog(t *testing.T) {
	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusInternalServerError)
		w.Write([]byte(`{"error":"internal"}`))
	}))
	defer srv.Close()

	svc, logs := newJobServiceWithObserver(srv)
	if _, err := svc.SubmitJob(context.Background(), &model.SubmitJobRequest{Script: "echo fail"}); err == nil {
		t.Fatal("expected error, got nil")
	}

	all := logs.All()
	if len(all) != 1 {
		t.Fatalf("expected 1 log entry, got %d", len(all))
	}
	entry := all[0]
	if entry.Level != zapcore.ErrorLevel {
		t.Errorf("expected ErrorLevel, got %v", entry.Level)
	}
	fm := entry.ContextMap()
	if fm["operation"] != "submit" {
		t.Errorf("expected operation=submit, got %v", fm["operation"])
	}
	if _, ok := fm["error"]; !ok {
		t.Error("expected error field in log entry")
	}
}
// TestJobService_CancelJob_SuccessLog asserts that a successful cancel emits
// exactly one Info entry carrying the cancelled job's ID.
func TestJobService_CancelJob_SuccessLog(t *testing.T) {
	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		json.NewEncoder(w).Encode(slurm.OpenapiResp{})
	}))
	defer srv.Close()

	svc, logs := newJobServiceWithObserver(srv)
	if err := svc.CancelJob(context.Background(), "555"); err != nil {
		t.Fatalf("unexpected error: %v", err)
	}

	all := logs.All()
	if len(all) != 1 {
		t.Fatalf("expected 1 log entry, got %d", len(all))
	}
	entry := all[0]
	if entry.Level != zapcore.InfoLevel {
		t.Errorf("expected InfoLevel, got %v", entry.Level)
	}
	if fm := entry.ContextMap(); fm["job_id"] != "555" {
		t.Errorf("expected job_id=555, got %v", fm["job_id"])
	}
}
// TestJobService_CancelJob_ErrorLog asserts that a failed cancel emits one
// Error entry with operation=cancel, the job id, and an error field.
func TestJobService_CancelJob_ErrorLog(t *testing.T) {
	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusNotFound)
		w.Write([]byte(`not found`))
	}))
	defer srv.Close()

	svc, logs := newJobServiceWithObserver(srv)
	if err := svc.CancelJob(context.Background(), "999"); err == nil {
		t.Fatal("expected error, got nil")
	}

	all := logs.All()
	if len(all) != 1 {
		t.Fatalf("expected 1 log entry, got %d", len(all))
	}
	entry := all[0]
	if entry.Level != zapcore.ErrorLevel {
		t.Errorf("expected ErrorLevel, got %v", entry.Level)
	}
	fm := entry.ContextMap()
	if fm["operation"] != "cancel" {
		t.Errorf("expected operation=cancel, got %v", fm["operation"])
	}
	if fm["job_id"] != "999" {
		t.Errorf("expected job_id=999, got %v", fm["job_id"])
	}
	if _, ok := fm["error"]; !ok {
		t.Error("expected error field in log entry")
	}
}
// TestJobService_GetJobs_ErrorLog asserts that a failed job listing emits one
// Error entry with operation=get_jobs and an error field.
func TestJobService_GetJobs_ErrorLog(t *testing.T) {
	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusInternalServerError)
		w.Write([]byte(`{"error":"down"}`))
	}))
	defer srv.Close()

	svc, logs := newJobServiceWithObserver(srv)
	if _, err := svc.GetJobs(context.Background()); err == nil {
		t.Fatal("expected error, got nil")
	}

	all := logs.All()
	if len(all) != 1 {
		t.Fatalf("expected 1 log entry, got %d", len(all))
	}
	entry := all[0]
	if entry.Level != zapcore.ErrorLevel {
		t.Errorf("expected ErrorLevel, got %v", entry.Level)
	}
	fm := entry.ContextMap()
	if fm["operation"] != "get_jobs" {
		t.Errorf("expected operation=get_jobs, got %v", fm["operation"])
	}
	if _, ok := fm["error"]; !ok {
		t.Error("expected error field in log entry")
	}
}
// TestJobService_GetJob_ErrorLog asserts that a failed single-job lookup emits
// one Error entry with operation=get_job, the job id, and an error field.
func TestJobService_GetJob_ErrorLog(t *testing.T) {
	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusInternalServerError)
		w.Write([]byte(`{"error":"down"}`))
	}))
	defer srv.Close()

	svc, logs := newJobServiceWithObserver(srv)
	if _, err := svc.GetJob(context.Background(), "200"); err == nil {
		t.Fatal("expected error, got nil")
	}

	all := logs.All()
	if len(all) != 1 {
		t.Fatalf("expected 1 log entry, got %d", len(all))
	}
	entry := all[0]
	if entry.Level != zapcore.ErrorLevel {
		t.Errorf("expected ErrorLevel, got %v", entry.Level)
	}
	fm := entry.ContextMap()
	if fm["operation"] != "get_job" {
		t.Errorf("expected operation=get_job, got %v", fm["operation"])
	}
	if fm["job_id"] != "200" {
		t.Errorf("expected job_id=200, got %v", fm["job_id"])
	}
	if _, ok := fm["error"]; !ok {
		t.Error("expected error field in log entry")
	}
}
// TestJobService_GetJobHistory_ErrorLog asserts that a failed history query
// emits one Error entry with operation=get_job_history and an error field.
func TestJobService_GetJobHistory_ErrorLog(t *testing.T) {
	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusInternalServerError)
		w.Write([]byte(`{"error":"db down"}`))
	}))
	defer srv.Close()

	svc, logs := newJobServiceWithObserver(srv)
	if _, err := svc.GetJobHistory(context.Background(), &model.JobHistoryQuery{}); err == nil {
		t.Fatal("expected error, got nil")
	}

	all := logs.All()
	if len(all) != 1 {
		t.Fatalf("expected 1 log entry, got %d", len(all))
	}
	entry := all[0]
	if entry.Level != zapcore.ErrorLevel {
		t.Errorf("expected ErrorLevel, got %v", entry.Level)
	}
	fm := entry.ContextMap()
	if fm["operation"] != "get_job_history" {
		t.Errorf("expected operation=get_job_history, got %v", fm["operation"])
	}
	if _, ok := fm["error"]; !ok {
		t.Error("expected error field in log entry")
	}
}

View File

@@ -0,0 +1,15 @@
package service
import (
"gcy_hpc_server/internal/slurm"
)
// NewSlurmClient creates a Slurm SDK client with JWT authentication.
// It reads the JWT key from the given keyPath and signs tokens automatically.
//
// Parameters:
//   - apiURL: base URL of the slurmrestd HTTP API.
//   - userName: Slurm user the generated tokens are issued for.
//   - jwtKeyPath: filesystem path to the JWT signing key.
//
// Returns the configured client, or whatever error slurm.NewClientWithOpts
// reports (exact failure modes are defined by the SDK constructor).
func NewSlurmClient(apiURL, userName, jwtKeyPath string) (*slurm.Client, error) {
	return slurm.NewClientWithOpts(
		apiURL,
		slurm.WithUsername(userName),
		slurm.WithJWTKey(jwtKeyPath),
	)
}

View File

@@ -0,0 +1 @@
-- Down migration: drops the job_templates table created by the matching "up"
-- migration. IF EXISTS keeps the rollback idempotent.
DROP TABLE IF EXISTS job_templates;

View File

@@ -0,0 +1,14 @@
-- Job templates: reusable job-submission presets (script plus resource
-- defaults). Columns mirror the GORM model.JobTemplate used by AutoMigrate.
CREATE TABLE IF NOT EXISTS job_templates (
    id BIGINT UNSIGNED AUTO_INCREMENT PRIMARY KEY,
    name VARCHAR(255) NOT NULL,
    description TEXT,
    script TEXT NOT NULL,          -- batch script body
    partition VARCHAR(255),        -- default Slurm partition
    qos VARCHAR(255),              -- default quality-of-service
    cpus INT UNSIGNED,
    memory VARCHAR(50),            -- stored as text (e.g. "4G" in tests)
    time_limit VARCHAR(50),
    created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
    updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
    UNIQUE KEY idx_name (name)     -- template names must be unique
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;

44
internal/store/mysql.go Normal file
View File

@@ -0,0 +1,44 @@
package store
import (
"fmt"
"time"
"gcy_hpc_server/internal/logger"
"gcy_hpc_server/internal/model"
"go.uber.org/zap"
"gorm.io/driver/mysql"
"gorm.io/gorm"
)
// NewGormDB opens a GORM MySQL connection with sensible pool defaults and
// verifies connectivity with a ping before returning.
//
// Parameters:
//   - dsn: MySQL DSN passed straight to the driver.
//   - zapLogger: logger backing the GORM logger adapter.
//   - gormLevel: GORM log level name, forwarded to logger.NewGormLogger.
func NewGormDB(dsn string, zapLogger *zap.Logger, gormLevel string) (*gorm.DB, error) {
	gormCfg := &gorm.Config{
		Logger: logger.NewGormLogger(zapLogger, gormLevel),
	}
	db, err := gorm.Open(mysql.Open(dsn), gormCfg)
	if err != nil {
		return nil, fmt.Errorf("failed to open gorm mysql: %w", err)
	}
	sqlDB, err := db.DB()
	if err != nil {
		return nil, fmt.Errorf("failed to get underlying sql.DB: %w", err)
	}
	// Conservative pool defaults; revisit via config if they become limiting.
	sqlDB.SetMaxOpenConns(25)
	sqlDB.SetMaxIdleConns(5)
	sqlDB.SetConnMaxLifetime(5 * time.Minute)
	if err := sqlDB.Ping(); err != nil {
		// Close the pool on failure so an unreachable server does not leak
		// the resources held by the opened handle.
		_ = sqlDB.Close()
		return nil, fmt.Errorf("failed to ping mysql: %w", err)
	}
	return db, nil
}
// AutoMigrate runs GORM auto-migration for all persistent models.
// Currently only model.JobTemplate is managed; register new models here as
// the schema grows.
func AutoMigrate(db *gorm.DB) error {
	return db.AutoMigrate(&model.JobTemplate{})
}

View File

@@ -0,0 +1,14 @@
package store
import (
"testing"
"go.uber.org/zap"
)
// TestNewGormDBInvalidDSN verifies that an unreachable/malformed DSN produces
// an error rather than a half-initialized handle.
func TestNewGormDBInvalidDSN(t *testing.T) {
	const badDSN = "invalid:dsn@tcp(localhost:99999)/nonexistent?parseTime=true"
	if _, err := NewGormDB(badDSN, zap.NewNop(), "warn"); err == nil {
		t.Fatal("expected error for invalid DSN, got nil")
	}
}

View File

@@ -0,0 +1,113 @@
package store
import (
"context"
"errors"
"gorm.io/gorm"
"gcy_hpc_server/internal/model"
)
// TemplateStore provides CRUD operations for job templates via GORM.
// It is a thin data-access layer: no validation or business rules live here.
type TemplateStore struct {
	db *gorm.DB // GORM handle supplied by the caller (e.g. NewGormDB)
}

// NewTemplateStore creates a new TemplateStore backed by db.
func NewTemplateStore(db *gorm.DB) *TemplateStore {
	return &TemplateStore{db: db}
}
// List returns one page of job templates (newest first, by id) along with the
// total row count across all pages.
func (s *TemplateStore) List(ctx context.Context, page, pageSize int) ([]model.JobTemplate, int, error) {
	var total int64
	if err := s.db.WithContext(ctx).Model(&model.JobTemplate{}).Count(&total).Error; err != nil {
		return nil, 0, err
	}
	var rows []model.JobTemplate
	err := s.db.WithContext(ctx).
		Order("id DESC").
		Limit(pageSize).
		Offset((page - 1) * pageSize).
		Find(&rows).Error
	if err != nil {
		return nil, 0, err
	}
	return rows, int(total), nil
}
// GetByID returns a single job template by ID. A missing row is reported as
// (nil, nil) rather than an error, so callers can distinguish "not found"
// from a real database failure.
func (s *TemplateStore) GetByID(ctx context.Context, id int64) (*model.JobTemplate, error) {
	var tpl model.JobTemplate
	switch err := s.db.WithContext(ctx).First(&tpl, id).Error; {
	case errors.Is(err, gorm.ErrRecordNotFound):
		return nil, nil
	case err != nil:
		return nil, err
	default:
		return &tpl, nil
	}
}
// Create inserts a new job template and returns its generated primary key.
func (s *TemplateStore) Create(ctx context.Context, req *model.CreateTemplateRequest) (int64, error) {
	row := model.JobTemplate{
		Name:        req.Name,
		Description: req.Description,
		Script:      req.Script,
		Partition:   req.Partition,
		QOS:         req.QOS,
		CPUs:        req.CPUs,
		Memory:      req.Memory,
		TimeLimit:   req.TimeLimit,
	}
	if err := s.db.WithContext(ctx).Create(&row).Error; err != nil {
		return 0, err
	}
	return row.ID, nil
}
// Update modifies an existing job template. Only non-empty string fields and
// a positive CPUs value are written; zero values leave columns untouched.
func (s *TemplateStore) Update(ctx context.Context, id int64, req *model.UpdateTemplateRequest) error {
	updates := make(map[string]interface{})
	setStr := func(column, value string) {
		if value != "" {
			updates[column] = value
		}
	}
	setStr("name", req.Name)
	setStr("description", req.Description)
	setStr("script", req.Script)
	setStr("partition", req.Partition)
	setStr("qos", req.QOS)
	setStr("memory", req.Memory)
	setStr("time_limit", req.TimeLimit)
	if req.CPUs > 0 {
		updates["cpus"] = req.CPUs
	}
	if len(updates) == 0 {
		// Nothing requested — treat as a successful no-op.
		return nil
	}
	return s.db.WithContext(ctx).
		Model(&model.JobTemplate{}).
		Where("id = ?", id).
		Updates(updates).Error
}
// Delete removes a job template by ID. Idempotent — deleting a row that does
// not exist is not an error.
func (s *TemplateStore) Delete(ctx context.Context, id int64) error {
	return s.db.WithContext(ctx).Delete(&model.JobTemplate{}, id).Error
}

View File

@@ -0,0 +1,205 @@
package store
import (
"context"
"testing"
"gorm.io/driver/sqlite"
"gorm.io/gorm"
"gorm.io/gorm/logger"
"gcy_hpc_server/internal/model"
)
// newTestDB opens an in-memory SQLite database with the JobTemplate schema
// migrated, scoped to a single test's lifetime.
func newTestDB(t *testing.T) *gorm.DB {
	t.Helper()
	cfg := &gorm.Config{Logger: logger.Default.LogMode(logger.Silent)}
	db, err := gorm.Open(sqlite.Open(":memory:"), cfg)
	if err != nil {
		t.Fatalf("open sqlite: %v", err)
	}
	if err = db.AutoMigrate(&model.JobTemplate{}); err != nil {
		t.Fatalf("auto migrate: %v", err)
	}
	return db
}
// TestTemplateStore_List verifies pagination metadata and the id-DESC order.
func TestTemplateStore_List(t *testing.T) {
	db := newTestDB(t)
	s := NewTemplateStore(db)
	for _, req := range []*model.CreateTemplateRequest{
		{Name: "job-1", Script: "echo 1"},
		{Name: "job-2", Script: "echo 2"},
	} {
		// Seed failures were previously unchecked and surfaced as confusing
		// count mismatches below; fail fast instead.
		if _, err := s.Create(context.Background(), req); err != nil {
			t.Fatalf("seed Create(%s): %v", req.Name, err)
		}
	}
	templates, total, err := s.List(context.Background(), 1, 10)
	if err != nil {
		t.Fatalf("List() error = %v", err)
	}
	if total != 2 {
		t.Errorf("total = %d, want 2", total)
	}
	if len(templates) != 2 {
		t.Fatalf("len(templates) = %d, want 2", len(templates))
	}
	// DESC order, so job-2 is first
	if templates[0].Name != "job-2" {
		t.Errorf("templates[0].Name = %q, want %q", templates[0].Name, "job-2")
	}
}
// TestTemplateStore_List_Page2 verifies that the last, partial page returns
// only the remaining rows while the total still reflects all rows.
func TestTemplateStore_List_Page2(t *testing.T) {
	db := newTestDB(t)
	s := NewTemplateStore(db)
	for i := 0; i < 15; i++ {
		_, err := s.Create(context.Background(), &model.CreateTemplateRequest{
			Name: "job-" + string(rune('A'+i)), Script: "echo",
		})
		// Previously ignored; a failed seed would skew total/len below.
		if err != nil {
			t.Fatalf("seed Create #%d: %v", i, err)
		}
	}
	templates, total, err := s.List(context.Background(), 2, 10)
	if err != nil {
		t.Fatalf("List() error = %v", err)
	}
	if total != 15 {
		t.Errorf("total = %d, want 15", total)
	}
	if len(templates) != 5 {
		t.Fatalf("len(templates) = %d, want 5", len(templates))
	}
}
// TestTemplateStore_GetByID verifies that a created template round-trips with
// its fields intact.
func TestTemplateStore_GetByID(t *testing.T) {
	db := newTestDB(t)
	s := NewTemplateStore(db)
	id, err := s.Create(context.Background(), &model.CreateTemplateRequest{
		Name: "test-job", Script: "echo hi", Partition: "batch", QOS: "normal", CPUs: 2, Memory: "4G",
	})
	if err != nil {
		t.Fatalf("Create() error = %v", err) // previously ignored
	}
	tpl, err := s.GetByID(context.Background(), id)
	if err != nil {
		t.Fatalf("GetByID() error = %v", err)
	}
	if tpl == nil {
		t.Fatal("GetByID() returned nil")
	}
	if tpl.Name != "test-job" {
		t.Errorf("Name = %q, want %q", tpl.Name, "test-job")
	}
	if tpl.CPUs != 2 {
		t.Errorf("CPUs = %d, want 2", tpl.CPUs)
	}
}
func TestTemplateStore_GetByID_NotFound(t *testing.T) {
db := newTestDB(t)
s := NewTemplateStore(db)
tpl, err := s.GetByID(context.Background(), 999)
if err != nil {
t.Fatalf("GetByID() error = %v, want nil", err)
}
if tpl != nil {
t.Fatal("GetByID() should return nil for not found")
}
}
func TestTemplateStore_Create(t *testing.T) {
db := newTestDB(t)
s := NewTemplateStore(db)
id, err := s.Create(context.Background(), &model.CreateTemplateRequest{
Name: "new-job", Script: "echo", Partition: "gpu",
})
if err != nil {
t.Fatalf("Create() error = %v", err)
}
if id == 0 {
t.Fatal("Create() returned id=0")
}
}
func TestTemplateStore_Update(t *testing.T) {
db := newTestDB(t)
s := NewTemplateStore(db)
id, _ := s.Create(context.Background(), &model.CreateTemplateRequest{
Name: "old", Script: "echo",
})
err := s.Update(context.Background(), id, &model.UpdateTemplateRequest{
Name: "updated",
Script: "echo new",
CPUs: 8,
})
if err != nil {
t.Fatalf("Update() error = %v", err)
}
tpl, _ := s.GetByID(context.Background(), id)
if tpl.Name != "updated" {
t.Errorf("Name = %q, want %q", tpl.Name, "updated")
}
if tpl.CPUs != 8 {
t.Errorf("CPUs = %d, want 8", tpl.CPUs)
}
}
func TestTemplateStore_Update_Partial(t *testing.T) {
db := newTestDB(t)
s := NewTemplateStore(db)
id, _ := s.Create(context.Background(), &model.CreateTemplateRequest{
Name: "original", Script: "echo orig", Partition: "batch",
})
err := s.Update(context.Background(), id, &model.UpdateTemplateRequest{
Name: "renamed",
})
if err != nil {
t.Fatalf("Update() error = %v", err)
}
tpl, _ := s.GetByID(context.Background(), id)
if tpl.Name != "renamed" {
t.Errorf("Name = %q, want %q", tpl.Name, "renamed")
}
// Script and Partition should be unchanged
if tpl.Script != "echo orig" {
t.Errorf("Script = %q, want %q", tpl.Script, "echo orig")
}
if tpl.Partition != "batch" {
t.Errorf("Partition = %q, want %q", tpl.Partition, "batch")
}
}
func TestTemplateStore_Delete(t *testing.T) {
db := newTestDB(t)
s := NewTemplateStore(db)
id, _ := s.Create(context.Background(), &model.CreateTemplateRequest{
Name: "to-delete", Script: "echo",
})
err := s.Delete(context.Background(), id)
if err != nil {
t.Fatalf("Delete() error = %v", err)
}
tpl, _ := s.GetByID(context.Background(), id)
if tpl != nil {
t.Fatal("Delete() did not remove the record")
}
}
func TestTemplateStore_Delete_NotFound(t *testing.T) {
db := newTestDB(t)
s := NewTemplateStore(db)
err := s.Delete(context.Background(), 999)
if err != nil {
t.Fatalf("Delete() should not error for non-existent record, got: %v", err)
}
}