From 4903f7d07fe8cb75cf048e8935b2864de98ec6ff Mon Sep 17 00:00:00 2001 From: dailz Date: Fri, 10 Apr 2026 08:39:46 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E6=B7=BB=E5=8A=A0=E4=B8=9A=E5=8A=A1?= =?UTF-8?q?=E6=9C=8D=E5=8A=A1=E5=B1=82=E5=92=8C=E7=BB=93=E6=9E=84=E5=8C=96?= =?UTF-8?q?=E6=97=A5=E5=BF=97?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - JobService: 提交、查询、取消、历史记录,记录关键操作日志 - ClusterService: 节点、分区、诊断查询,记录错误日志 - NewSlurmClient: JWT 认证 HTTP 客户端工厂 - 所有构造函数接受 *zap.Logger 参数实现依赖注入 - 提交/取消成功记录 Info,API 错误记录 Error - 完整 TDD 测试,使用 zaptest/observer 验证日志输出 Ultraworked with [Sisyphus](https://github.com/code-yeongyu/oh-my-openagent) Co-authored-by: Sisyphus --- internal/service/cluster_service.go | 167 ++++++ internal/service/cluster_service_test.go | 467 +++++++++++++++ internal/service/job_service.go | 246 ++++++++ internal/service/job_service_test.go | 703 +++++++++++++++++++++++ internal/service/slurm_client.go | 15 + 5 files changed, 1598 insertions(+) create mode 100644 internal/service/cluster_service.go create mode 100644 internal/service/cluster_service_test.go create mode 100644 internal/service/job_service.go create mode 100644 internal/service/job_service_test.go create mode 100644 internal/service/slurm_client.go diff --git a/internal/service/cluster_service.go b/internal/service/cluster_service.go new file mode 100644 index 0000000..3d1cdfa --- /dev/null +++ b/internal/service/cluster_service.go @@ -0,0 +1,167 @@ +package service + +import ( + "context" + "fmt" + "strconv" + + "gcy_hpc_server/internal/model" + "gcy_hpc_server/internal/slurm" + + "go.uber.org/zap" +) + +func derefStr(s *string) string { + if s == nil { + return "" + } + return *s +} + +func derefInt32(i *int32) int32 { + if i == nil { + return 0 + } + return *i +} + +func derefInt64(i *int64) int64 { + if i == nil { + return 0 + } + return *i +} + +func uint32NoValString(v *slurm.Uint32NoVal) string { + if v == nil { + return "" + } + if 
v.Infinite != nil && *v.Infinite { + return "UNLIMITED" + } + if v.Number != nil { + return strconv.FormatInt(*v.Number, 10) + } + return "" +} + +type ClusterService struct { + client *slurm.Client + logger *zap.Logger +} + +func NewClusterService(client *slurm.Client, logger *zap.Logger) *ClusterService { + return &ClusterService{client: client, logger: logger} +} + +func (s *ClusterService) GetNodes(ctx context.Context) ([]model.NodeResponse, error) { + resp, _, err := s.client.Nodes.GetNodes(ctx, nil) + if err != nil { + s.logger.Error("failed to get nodes", zap.Error(err)) + return nil, fmt.Errorf("get nodes: %w", err) + } + if resp.Nodes == nil { + return nil, nil + } + result := make([]model.NodeResponse, 0, len(*resp.Nodes)) + for _, n := range *resp.Nodes { + result = append(result, mapNode(n)) + } + return result, nil +} + +func (s *ClusterService) GetNode(ctx context.Context, name string) (*model.NodeResponse, error) { + resp, _, err := s.client.Nodes.GetNode(ctx, name, nil) + if err != nil { + s.logger.Error("failed to get node", zap.String("name", name), zap.Error(err)) + return nil, fmt.Errorf("get node %s: %w", name, err) + } + if resp.Nodes == nil || len(*resp.Nodes) == 0 { + return nil, nil + } + n := (*resp.Nodes)[0] + mapped := mapNode(n) + return &mapped, nil +} + +func (s *ClusterService) GetPartitions(ctx context.Context) ([]model.PartitionResponse, error) { + resp, _, err := s.client.Partitions.GetPartitions(ctx, nil) + if err != nil { + s.logger.Error("failed to get partitions", zap.Error(err)) + return nil, fmt.Errorf("get partitions: %w", err) + } + if resp.Partitions == nil { + return nil, nil + } + result := make([]model.PartitionResponse, 0, len(*resp.Partitions)) + for _, pi := range *resp.Partitions { + result = append(result, mapPartition(pi)) + } + return result, nil +} + +func (s *ClusterService) GetPartition(ctx context.Context, name string) (*model.PartitionResponse, error) { + resp, _, err := 
s.client.Partitions.GetPartition(ctx, name, nil) + if err != nil { + s.logger.Error("failed to get partition", zap.String("name", name), zap.Error(err)) + return nil, fmt.Errorf("get partition %s: %w", name, err) + } + if resp.Partitions == nil || len(*resp.Partitions) == 0 { + return nil, nil + } + p := (*resp.Partitions)[0] + mapped := mapPartition(p) + return &mapped, nil +} + +func (s *ClusterService) GetDiag(ctx context.Context) (*slurm.OpenapiDiagResp, error) { + resp, _, err := s.client.Diag.GetDiag(ctx) + if err != nil { + s.logger.Error("failed to get diag", zap.Error(err)) + return nil, fmt.Errorf("get diag: %w", err) + } + return resp, nil +} + +func mapNode(n slurm.Node) model.NodeResponse { + return model.NodeResponse{ + Name: derefStr(n.Name), + State: n.State, + CPUs: derefInt32(n.Cpus), + RealMemory: derefInt64(n.RealMemory), + AllocMem: derefInt64(n.AllocMemory), + Arch: derefStr(n.Architecture), + OS: derefStr(n.OperatingSystem), + } +} + +func mapPartition(pi slurm.PartitionInfo) model.PartitionResponse { + var state []string + if pi.Partition != nil { + state = pi.Partition.State + } + var nodes string + if pi.Nodes != nil { + nodes = derefStr(pi.Nodes.Configured) + } + var totalCPUs int32 + if pi.CPUs != nil { + totalCPUs = derefInt32(pi.CPUs.Total) + } + var totalNodes int32 + if pi.Nodes != nil { + totalNodes = derefInt32(pi.Nodes.Total) + } + var maxTime string + if pi.Maximums != nil { + maxTime = uint32NoValString(pi.Maximums.Time) + } + return model.PartitionResponse{ + Name: derefStr(pi.Name), + State: state, + Nodes: nodes, + TotalCPUs: totalCPUs, + TotalNodes: totalNodes, + MaxTime: maxTime, + } +} diff --git a/internal/service/cluster_service_test.go b/internal/service/cluster_service_test.go new file mode 100644 index 0000000..cff6c16 --- /dev/null +++ b/internal/service/cluster_service_test.go @@ -0,0 +1,467 @@ +package service + +import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "os" + "path/filepath" + 
"testing" + + "gcy_hpc_server/internal/slurm" + + "go.uber.org/zap" + "go.uber.org/zap/zapcore" + "go.uber.org/zap/zaptest/observer" +) + +func mockServer(handler http.HandlerFunc) (*slurm.Client, func()) { + srv := httptest.NewServer(handler) + client, _ := slurm.NewClient(srv.URL, srv.Client()) + return client, srv.Close +} + +func TestGetNodes(t *testing.T) { + client, cleanup := mockServer(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/slurm/v0.0.40/nodes" { + t.Errorf("unexpected path: %s", r.URL.Path) + } + resp := map[string]interface{}{ + "nodes": []map[string]interface{}{ + { + "name": "node1", + "state": []string{"IDLE"}, + "cpus": 64, + "real_memory": 256000, + "alloc_memory": 0, + "architecture": "x86_64", + "operating_system": "Linux 5.15", + }, + }, + } + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(resp) + }) + defer cleanup() + + svc := NewClusterService(client, zap.NewNop()) + nodes, err := svc.GetNodes(context.Background()) + if err != nil { + t.Fatalf("GetNodes returned error: %v", err) + } + if len(nodes) != 1 { + t.Fatalf("expected 1 node, got %d", len(nodes)) + } + n := nodes[0] + if n.Name != "node1" { + t.Errorf("expected name node1, got %s", n.Name) + } + if len(n.State) != 1 || n.State[0] != "IDLE" { + t.Errorf("expected state [IDLE], got %v", n.State) + } + if n.CPUs != 64 { + t.Errorf("expected 64 CPUs, got %d", n.CPUs) + } + if n.RealMemory != 256000 { + t.Errorf("expected real_memory 256000, got %d", n.RealMemory) + } + if n.Arch != "x86_64" { + t.Errorf("expected arch x86_64, got %s", n.Arch) + } + if n.OS != "Linux 5.15" { + t.Errorf("expected OS 'Linux 5.15', got %s", n.OS) + } +} + +func TestGetNodes_Empty(t *testing.T) { + client, cleanup := mockServer(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(map[string]interface{}{}) + }) + defer cleanup() + + svc := NewClusterService(client, zap.NewNop()) 
+ nodes, err := svc.GetNodes(context.Background()) + if err != nil { + t.Fatalf("GetNodes returned error: %v", err) + } + if nodes != nil { + t.Errorf("expected nil for empty response, got %v", nodes) + } +} + +func TestGetNode(t *testing.T) { + client, cleanup := mockServer(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/slurm/v0.0.40/node/node1" { + t.Errorf("unexpected path: %s", r.URL.Path) + } + resp := map[string]interface{}{ + "nodes": []map[string]interface{}{ + {"name": "node1", "state": []string{"ALLOCATED"}, "cpus": 32}, + }, + } + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(resp) + }) + defer cleanup() + + svc := NewClusterService(client, zap.NewNop()) + node, err := svc.GetNode(context.Background(), "node1") + if err != nil { + t.Fatalf("GetNode returned error: %v", err) + } + if node == nil { + t.Fatal("expected node, got nil") + } + if node.Name != "node1" { + t.Errorf("expected name node1, got %s", node.Name) + } +} + +func TestGetNode_NotFound(t *testing.T) { + client, cleanup := mockServer(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(map[string]interface{}{}) + }) + defer cleanup() + + svc := NewClusterService(client, zap.NewNop()) + node, err := svc.GetNode(context.Background(), "missing") + if err != nil { + t.Fatalf("GetNode returned error: %v", err) + } + if node != nil { + t.Errorf("expected nil for missing node, got %+v", node) + } +} + +func TestGetPartitions(t *testing.T) { + client, cleanup := mockServer(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/slurm/v0.0.40/partitions" { + t.Errorf("unexpected path: %s", r.URL.Path) + } + resp := map[string]interface{}{ + "partitions": []map[string]interface{}{ + { + "name": "normal", + "partition": map[string]interface{}{ + "state": []string{"UP"}, + }, + "nodes": map[string]interface{}{ + "configured": "node[1-10]", + "total": 10, + }, + 
"cpus": map[string]interface{}{ + "total": 640, + }, + "maximums": map[string]interface{}{ + "time": map[string]interface{}{ + "set": true, + "infinite": false, + "number": 86400, + }, + }, + }, + }, + } + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(resp) + }) + defer cleanup() + + svc := NewClusterService(client, zap.NewNop()) + partitions, err := svc.GetPartitions(context.Background()) + if err != nil { + t.Fatalf("GetPartitions returned error: %v", err) + } + if len(partitions) != 1 { + t.Fatalf("expected 1 partition, got %d", len(partitions)) + } + p := partitions[0] + if p.Name != "normal" { + t.Errorf("expected name normal, got %s", p.Name) + } + if len(p.State) != 1 || p.State[0] != "UP" { + t.Errorf("expected state [UP], got %v", p.State) + } + if p.Nodes != "node[1-10]" { + t.Errorf("expected nodes 'node[1-10]', got %s", p.Nodes) + } + if p.TotalCPUs != 640 { + t.Errorf("expected 640 total CPUs, got %d", p.TotalCPUs) + } + if p.TotalNodes != 10 { + t.Errorf("expected 10 total nodes, got %d", p.TotalNodes) + } + if p.MaxTime != "86400" { + t.Errorf("expected max_time '86400', got %s", p.MaxTime) + } +} + +func TestGetPartitions_Empty(t *testing.T) { + client, cleanup := mockServer(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(map[string]interface{}{}) + }) + defer cleanup() + + svc := NewClusterService(client, zap.NewNop()) + partitions, err := svc.GetPartitions(context.Background()) + if err != nil { + t.Fatalf("GetPartitions returned error: %v", err) + } + if partitions != nil { + t.Errorf("expected nil for empty response, got %v", partitions) + } +} + +func TestGetPartition(t *testing.T) { + client, cleanup := mockServer(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/slurm/v0.0.40/partition/gpu" { + t.Errorf("unexpected path: %s", r.URL.Path) + } + resp := map[string]interface{}{ + "partitions": 
[]map[string]interface{}{ + { + "name": "gpu", + "partition": map[string]interface{}{ + "state": []string{"UP"}, + }, + "nodes": map[string]interface{}{ + "configured": "gpu[1-4]", + "total": 4, + }, + "maximums": map[string]interface{}{ + "time": map[string]interface{}{ + "set": true, + "infinite": true, + }, + }, + }, + }, + } + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(resp) + }) + defer cleanup() + + svc := NewClusterService(client, zap.NewNop()) + part, err := svc.GetPartition(context.Background(), "gpu") + if err != nil { + t.Fatalf("GetPartition returned error: %v", err) + } + if part == nil { + t.Fatal("expected partition, got nil") + } + if part.Name != "gpu" { + t.Errorf("expected name gpu, got %s", part.Name) + } + if part.MaxTime != "UNLIMITED" { + t.Errorf("expected max_time UNLIMITED, got %s", part.MaxTime) + } +} + +func TestGetPartition_NotFound(t *testing.T) { + client, cleanup := mockServer(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(map[string]interface{}{}) + }) + defer cleanup() + + svc := NewClusterService(client, zap.NewNop()) + part, err := svc.GetPartition(context.Background(), "missing") + if err != nil { + t.Fatalf("GetPartition returned error: %v", err) + } + if part != nil { + t.Errorf("expected nil for missing partition, got %+v", part) + } +} + +func TestGetDiag(t *testing.T) { + client, cleanup := mockServer(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/slurm/v0.0.40/diag" { + t.Errorf("unexpected path: %s", r.URL.Path) + } + resp := map[string]interface{}{ + "statistics": map[string]interface{}{ + "server_thread_count": 10, + "agent_queue_size": 5, + "jobs_submitted": 100, + "jobs_running": 20, + "schedule_queue_length": 3, + }, + } + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(resp) + }) + defer cleanup() + + svc := NewClusterService(client, zap.NewNop()) + 
diag, err := svc.GetDiag(context.Background()) + if err != nil { + t.Fatalf("GetDiag returned error: %v", err) + } + if diag == nil { + t.Fatal("expected diag response, got nil") + } + if diag.Statistics == nil { + t.Fatal("expected statistics, got nil") + } + if diag.Statistics.ServerThreadCount == nil || *diag.Statistics.ServerThreadCount != 10 { + t.Errorf("expected server_thread_count 10, got %v", diag.Statistics.ServerThreadCount) + } +} + +func TestNewSlurmClient(t *testing.T) { + dir := t.TempDir() + keyPath := filepath.Join(dir, "jwt.key") + os.WriteFile(keyPath, make([]byte, 32), 0644) + + client, err := NewSlurmClient("http://localhost:6820", "root", keyPath) + if err != nil { + t.Fatalf("NewSlurmClient returned error: %v", err) + } + if client == nil { + t.Fatal("expected client, got nil") + } +} + +func newClusterServiceWithObserver(srv *httptest.Server) (*ClusterService, *observer.ObservedLogs) { + core, recorded := observer.New(zapcore.DebugLevel) + l := zap.New(core) + client, _ := slurm.NewClient(srv.URL, srv.Client()) + return NewClusterService(client, l), recorded +} + +func errorServer() *httptest.Server { + return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusInternalServerError) + w.Write([]byte(`{"errors": [{"error": "internal server error"}]}`)) + })) +} + +func TestClusterService_GetNodes_ErrorLogging(t *testing.T) { + srv := errorServer() + defer srv.Close() + + svc, logs := newClusterServiceWithObserver(srv) + _, err := svc.GetNodes(context.Background()) + if err == nil { + t.Fatal("expected error, got nil") + } + + if logs.Len() != 1 { + t.Fatalf("expected 1 log entry, got %d", logs.Len()) + } + entry := logs.All()[0] + if entry.Level != zapcore.ErrorLevel { + t.Errorf("expected ErrorLevel, got %v", entry.Level) + } + if len(entry.Context) == 0 { + t.Error("expected structured fields in log entry") + } +} + +func TestClusterService_GetNode_ErrorLogging(t *testing.T) { + srv 
:= errorServer() + defer srv.Close() + + svc, logs := newClusterServiceWithObserver(srv) + _, err := svc.GetNode(context.Background(), "test-node") + if err == nil { + t.Fatal("expected error, got nil") + } + + if logs.Len() != 1 { + t.Fatalf("expected 1 log entry, got %d", logs.Len()) + } + entry := logs.All()[0] + if entry.Level != zapcore.ErrorLevel { + t.Errorf("expected ErrorLevel, got %v", entry.Level) + } + + hasName := false + for _, f := range entry.Context { + if f.Key == "name" && f.String == "test-node" { + hasName = true + } + } + if !hasName { + t.Error("expected 'name' field with value 'test-node' in log entry") + } +} + +func TestClusterService_GetPartitions_ErrorLogging(t *testing.T) { + srv := errorServer() + defer srv.Close() + + svc, logs := newClusterServiceWithObserver(srv) + _, err := svc.GetPartitions(context.Background()) + if err == nil { + t.Fatal("expected error, got nil") + } + + if logs.Len() != 1 { + t.Fatalf("expected 1 log entry, got %d", logs.Len()) + } + entry := logs.All()[0] + if entry.Level != zapcore.ErrorLevel { + t.Errorf("expected ErrorLevel, got %v", entry.Level) + } + if len(entry.Context) == 0 { + t.Error("expected structured fields in log entry") + } +} + +func TestClusterService_GetPartition_ErrorLogging(t *testing.T) { + srv := errorServer() + defer srv.Close() + + svc, logs := newClusterServiceWithObserver(srv) + _, err := svc.GetPartition(context.Background(), "test-partition") + if err == nil { + t.Fatal("expected error, got nil") + } + + if logs.Len() != 1 { + t.Fatalf("expected 1 log entry, got %d", logs.Len()) + } + entry := logs.All()[0] + if entry.Level != zapcore.ErrorLevel { + t.Errorf("expected ErrorLevel, got %v", entry.Level) + } + + hasName := false + for _, f := range entry.Context { + if f.Key == "name" && f.String == "test-partition" { + hasName = true + } + } + if !hasName { + t.Error("expected 'name' field with value 'test-partition' in log entry") + } +} + +func 
TestClusterService_GetDiag_ErrorLogging(t *testing.T) { + srv := errorServer() + defer srv.Close() + + svc, logs := newClusterServiceWithObserver(srv) + _, err := svc.GetDiag(context.Background()) + if err == nil { + t.Fatal("expected error, got nil") + } + + if logs.Len() != 1 { + t.Fatalf("expected 1 log entry, got %d", logs.Len()) + } + entry := logs.All()[0] + if entry.Level != zapcore.ErrorLevel { + t.Errorf("expected ErrorLevel, got %v", entry.Level) + } + if len(entry.Context) == 0 { + t.Error("expected structured fields in log entry") + } +} diff --git a/internal/service/job_service.go b/internal/service/job_service.go new file mode 100644 index 0000000..636ee99 --- /dev/null +++ b/internal/service/job_service.go @@ -0,0 +1,246 @@ +package service + +import ( + "context" + "fmt" + "strconv" + + "gcy_hpc_server/internal/model" + "gcy_hpc_server/internal/slurm" + + "go.uber.org/zap" +) + +// JobService wraps Slurm SDK job operations with model mapping and pagination. +type JobService struct { + client *slurm.Client + logger *zap.Logger +} + +// NewJobService creates a new JobService with the given Slurm SDK client. +func NewJobService(client *slurm.Client, logger *zap.Logger) *JobService { + return &JobService{client: client, logger: logger} +} + +// SubmitJob submits a new job to Slurm and returns the job ID. 
+func (s *JobService) SubmitJob(ctx context.Context, req *model.SubmitJobRequest) (*model.JobResponse, error) { + script := req.Script + jobDesc := &slurm.JobDescMsg{ + Script: &script, + Partition: strToPtrOrNil(req.Partition), + Qos: strToPtrOrNil(req.QOS), + Name: strToPtrOrNil(req.JobName), + } + if req.CPUs > 0 { + jobDesc.MinimumCpus = slurm.Ptr(req.CPUs) + } + if req.TimeLimit != "" { + if mins, err := strconv.ParseInt(req.TimeLimit, 10, 64); err == nil { + jobDesc.TimeLimit = &slurm.Uint32NoVal{Number: &mins} + } + } + + submitReq := &slurm.JobSubmitReq{ + Script: &script, + Job: jobDesc, + } + + result, _, err := s.client.Jobs.SubmitJob(ctx, submitReq) + if err != nil { + s.logger.Error("failed to submit job", zap.Error(err), zap.String("operation", "submit")) + return nil, fmt.Errorf("submit job: %w", err) + } + + resp := &model.JobResponse{} + if result.Result != nil && result.Result.JobID != nil { + resp.JobID = *result.Result.JobID + } else if result.JobID != nil { + resp.JobID = *result.JobID + } + + s.logger.Info("job submitted", zap.String("job_name", req.JobName), zap.Int32("job_id", resp.JobID)) + return resp, nil +} + +// GetJobs lists all current jobs from Slurm. +func (s *JobService) GetJobs(ctx context.Context) ([]model.JobResponse, error) { + result, _, err := s.client.Jobs.GetJobs(ctx, nil) + if err != nil { + s.logger.Error("failed to get jobs", zap.Error(err), zap.String("operation", "get_jobs")) + return nil, fmt.Errorf("get jobs: %w", err) + } + + jobs := make([]model.JobResponse, 0, len(result.Jobs)) + for i := range result.Jobs { + jobs = append(jobs, mapJobInfo(&result.Jobs[i])) + } + return jobs, nil +} + +// GetJob retrieves a single job by ID. 
+func (s *JobService) GetJob(ctx context.Context, jobID string) (*model.JobResponse, error) { + result, _, err := s.client.Jobs.GetJob(ctx, jobID, nil) + if err != nil { + s.logger.Error("failed to get job", zap.Error(err), zap.String("job_id", jobID), zap.String("operation", "get_job")) + return nil, fmt.Errorf("get job %s: %w", jobID, err) + } + + if len(result.Jobs) == 0 { + return nil, nil + } + + resp := mapJobInfo(&result.Jobs[0]) + return &resp, nil +} + +// CancelJob cancels a job by ID. +func (s *JobService) CancelJob(ctx context.Context, jobID string) error { + _, _, err := s.client.Jobs.DeleteJob(ctx, jobID, nil) + if err != nil { + s.logger.Error("failed to cancel job", zap.Error(err), zap.String("job_id", jobID), zap.String("operation", "cancel")) + return fmt.Errorf("cancel job %s: %w", jobID, err) + } + s.logger.Info("job cancelled", zap.String("job_id", jobID)) + return nil +} + +// GetJobHistory queries SlurmDBD for historical jobs with pagination. +func (s *JobService) GetJobHistory(ctx context.Context, query *model.JobHistoryQuery) (*model.JobListResponse, error) { + opts := &slurm.GetSlurmdbJobsOptions{} + if query.Users != "" { + opts.Users = strToPtr(query.Users) + } + if query.Account != "" { + opts.Account = strToPtr(query.Account) + } + if query.Partition != "" { + opts.Partition = strToPtr(query.Partition) + } + if query.State != "" { + opts.State = strToPtr(query.State) + } + if query.JobName != "" { + opts.JobName = strToPtr(query.JobName) + } + if query.StartTime != "" { + opts.StartTime = strToPtr(query.StartTime) + } + if query.EndTime != "" { + opts.EndTime = strToPtr(query.EndTime) + } + + result, _, err := s.client.SlurmdbJobs.GetJobs(ctx, opts) + if err != nil { + s.logger.Error("failed to get job history", zap.Error(err), zap.String("operation", "get_job_history")) + return nil, fmt.Errorf("get job history: %w", err) + } + + allJobs := make([]model.JobResponse, 0, len(result.Jobs)) + for i := range result.Jobs { + allJobs = 
append(allJobs, mapSlurmdbJob(&result.Jobs[i])) + } + + total := len(allJobs) + page := query.Page + pageSize := query.PageSize + if page < 1 { + page = 1 + } + if pageSize < 1 { + pageSize = 20 + } + + start := (page - 1) * pageSize + end := start + pageSize + if start > total { + start = total + } + if end > total { + end = total + } + + return &model.JobListResponse{ + Jobs: allJobs[start:end], + Total: total, + Page: page, + PageSize: pageSize, + }, nil +} + +// --------------------------------------------------------------------------- +// Helper functions +// --------------------------------------------------------------------------- + +func strToPtr(s string) *string { return &s } + +// strPtrOrNil returns a pointer to s if non-empty, otherwise nil. +func strToPtrOrNil(s string) *string { + if s == "" { + return nil + } + return &s +} + +// mapJobInfo maps SDK JobInfo to API JobResponse. +func mapJobInfo(ji *slurm.JobInfo) model.JobResponse { + resp := model.JobResponse{} + if ji.JobID != nil { + resp.JobID = *ji.JobID + } + if ji.Name != nil { + resp.Name = *ji.Name + } + resp.State = ji.JobState + if ji.Partition != nil { + resp.Partition = *ji.Partition + } + if ji.SubmitTime != nil && ji.SubmitTime.Number != nil { + resp.SubmitTime = ji.SubmitTime.Number + } + if ji.StartTime != nil && ji.StartTime.Number != nil { + resp.StartTime = ji.StartTime.Number + } + if ji.EndTime != nil && ji.EndTime.Number != nil { + resp.EndTime = ji.EndTime.Number + } + if ji.ExitCode != nil && ji.ExitCode.ReturnCode != nil && ji.ExitCode.ReturnCode.Number != nil { + code := int32(*ji.ExitCode.ReturnCode.Number) + resp.ExitCode = &code + } + if ji.Nodes != nil { + resp.Nodes = *ji.Nodes + } + return resp +} + +// mapSlurmdbJob maps SDK SlurmDBD Job to API JobResponse. 
+func mapSlurmdbJob(j *slurm.Job) model.JobResponse { + resp := model.JobResponse{} + if j.JobID != nil { + resp.JobID = *j.JobID + } + if j.Name != nil { + resp.Name = *j.Name + } + if j.State != nil { + resp.State = j.State.Current + } + if j.Partition != nil { + resp.Partition = *j.Partition + } + if j.Time != nil { + if j.Time.Submission != nil { + resp.SubmitTime = j.Time.Submission + } + if j.Time.Start != nil { + resp.StartTime = j.Time.Start + } + if j.Time.End != nil { + resp.EndTime = j.Time.End + } + } + if j.Nodes != nil { + resp.Nodes = *j.Nodes + } + return resp +} diff --git a/internal/service/job_service_test.go b/internal/service/job_service_test.go new file mode 100644 index 0000000..9e2e2aa --- /dev/null +++ b/internal/service/job_service_test.go @@ -0,0 +1,703 @@ +package service + +import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + + "gcy_hpc_server/internal/model" + "gcy_hpc_server/internal/slurm" + + "go.uber.org/zap" + "go.uber.org/zap/zapcore" + "go.uber.org/zap/zaptest/observer" +) + +func mockJobServer(handler http.HandlerFunc) (*slurm.Client, func()) { + srv := httptest.NewServer(handler) + client, _ := slurm.NewClient(srv.URL, srv.Client()) + return client, srv.Close +} + +func TestSubmitJob(t *testing.T) { + jobID := int32(123) + client, cleanup := mockJobServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + t.Errorf("expected POST, got %s", r.Method) + } + if r.URL.Path != "/slurm/v0.0.40/job/submit" { + t.Errorf("unexpected path: %s", r.URL.Path) + } + + var body slurm.JobSubmitReq + if err := json.NewDecoder(r.Body).Decode(&body); err != nil { + t.Fatalf("decode body: %v", err) + } + if body.Job == nil || body.Job.Script == nil || *body.Job.Script != "#!/bin/bash\necho hello" { + t.Errorf("unexpected script in request body") + } + + resp := slurm.OpenapiJobSubmitResponse{ + Result: &slurm.JobSubmitResponseMsg{ + JobID: &jobID, + }, + } + 
		json.NewEncoder(w).Encode(resp)
	}))
	defer cleanup()

	svc := NewJobService(client, zap.NewNop())
	resp, err := svc.SubmitJob(context.Background(), &model.SubmitJobRequest{
		Script:    "#!/bin/bash\necho hello",
		Partition: "normal",
		QOS:       "high",
		JobName:   "test-job",
		CPUs:      4,
		TimeLimit: "60",
	})
	if err != nil {
		t.Fatalf("SubmitJob: %v", err)
	}
	if resp.JobID != 123 {
		t.Errorf("expected JobID 123, got %d", resp.JobID)
	}
}

// TestSubmitJob_WithOptionalFields: empty Partition and zero CPUs must be
// encoded as absent (nil) fields in the request body.
func TestSubmitJob_WithOptionalFields(t *testing.T) {
	client, cleanup := mockJobServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		var body slurm.JobSubmitReq
		if err := json.NewDecoder(r.Body).Decode(&body); err != nil {
			t.Fatalf("decode body: %v", err)
		}
		if body.Job == nil {
			t.Fatal("job desc is nil")
		}
		if body.Job.Partition != nil {
			t.Error("expected partition nil for empty string")
		}
		if body.Job.MinimumCpus != nil {
			t.Error("expected minimum_cpus nil when CPUs=0")
		}

		jobID := int32(456)
		resp := slurm.OpenapiJobSubmitResponse{
			Result: &slurm.JobSubmitResponseMsg{JobID: &jobID},
		}
		json.NewEncoder(w).Encode(resp)
	}))
	defer cleanup()

	svc := NewJobService(client, zap.NewNop())
	resp, err := svc.SubmitJob(context.Background(), &model.SubmitJobRequest{
		Script: "echo hi",
	})
	if err != nil {
		t.Fatalf("SubmitJob: %v", err)
	}
	if resp.JobID != 456 {
		t.Errorf("expected JobID 456, got %d", resp.JobID)
	}
}

// TestSubmitJob_Error: a 500 response must surface as an error.
func TestSubmitJob_Error(t *testing.T) {
	client, cleanup := mockJobServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusInternalServerError)
		w.Write([]byte(`{"error":"internal"}`))
	}))
	defer cleanup()

	svc := NewJobService(client, zap.NewNop())
	_, err := svc.SubmitJob(context.Background(), &model.SubmitJobRequest{
		Script: "echo fail",
	})
	if err == nil {
		t.Fatal("expected error, got nil")
	}
}

// TestGetJobs verifies full field mapping of a running job, including the
// Uint64NoVal timestamps.
func TestGetJobs(t *testing.T) {
	jobID := int32(100)
	name := "my-job"
	partition := "gpu"
	ts := int64(1700000000)
	nodes := "node01"

	client, cleanup := mockJobServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if r.Method != http.MethodGet {
			t.Errorf("expected GET, got %s", r.Method)
		}

		resp := slurm.OpenapiJobInfoResp{
			Jobs: slurm.JobInfoMsg{
				{
					JobID:      &jobID,
					Name:       &name,
					JobState:   []string{"RUNNING"},
					Partition:  &partition,
					SubmitTime: &slurm.Uint64NoVal{Number: &ts},
					StartTime:  &slurm.Uint64NoVal{Number: &ts},
					EndTime:    &slurm.Uint64NoVal{Number: &ts},
					Nodes:      &nodes,
				},
			},
		}
		json.NewEncoder(w).Encode(resp)
	}))
	defer cleanup()

	svc := NewJobService(client, zap.NewNop())
	jobs, err := svc.GetJobs(context.Background())
	if err != nil {
		t.Fatalf("GetJobs: %v", err)
	}
	if len(jobs) != 1 {
		t.Fatalf("expected 1 job, got %d", len(jobs))
	}
	j := jobs[0]
	if j.JobID != 100 {
		t.Errorf("expected JobID 100, got %d", j.JobID)
	}
	if j.Name != "my-job" {
		t.Errorf("expected Name my-job, got %s", j.Name)
	}
	if len(j.State) != 1 || j.State[0] != "RUNNING" {
		t.Errorf("expected State [RUNNING], got %v", j.State)
	}
	if j.Partition != "gpu" {
		t.Errorf("expected Partition gpu, got %s", j.Partition)
	}
	if j.SubmitTime == nil || *j.SubmitTime != ts {
		t.Errorf("expected SubmitTime %d, got %v", ts, j.SubmitTime)
	}
	if j.Nodes != "node01" {
		t.Errorf("expected Nodes node01, got %s", j.Nodes)
	}
}

// TestGetJob: single-job lookup returns the mapped record.
func TestGetJob(t *testing.T) {
	jobID := int32(200)
	name := "single-job"

	client, cleanup := mockJobServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		resp := slurm.OpenapiJobInfoResp{
			Jobs: slurm.JobInfoMsg{
				{
					JobID: &jobID,
					Name:  &name,
				},
			},
		}
		json.NewEncoder(w).Encode(resp)
	}))
	defer cleanup()

	svc := NewJobService(client, zap.NewNop())
	job, err := svc.GetJob(context.Background(), "200")
	if err != nil {
		t.Fatalf("GetJob: %v", err)
	}
	if job == nil {
		t.Fatal("expected job, got nil")
	}
	if job.JobID != 200 {
		t.Errorf("expected JobID 200, got %d", job.JobID)
	}
}

// TestGetJob_NotFound: an empty job list maps to (nil, nil).
func TestGetJob_NotFound(t *testing.T) {
	client, cleanup := mockJobServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		resp := slurm.OpenapiJobInfoResp{Jobs: slurm.JobInfoMsg{}}
		json.NewEncoder(w).Encode(resp)
	}))
	defer cleanup()

	svc := NewJobService(client, zap.NewNop())
	job, err := svc.GetJob(context.Background(), "999")
	if err != nil {
		t.Fatalf("GetJob: %v", err)
	}
	if job != nil {
		t.Errorf("expected nil for not found, got %+v", job)
	}
}

// TestCancelJob: cancellation issues DELETE and succeeds on a 200.
func TestCancelJob(t *testing.T) {
	client, cleanup := mockJobServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if r.Method != http.MethodDelete {
			t.Errorf("expected DELETE, got %s", r.Method)
		}
		resp := slurm.OpenapiResp{}
		json.NewEncoder(w).Encode(resp)
	}))
	defer cleanup()

	svc := NewJobService(client, zap.NewNop())
	err := svc.CancelJob(context.Background(), "300")
	if err != nil {
		t.Fatalf("CancelJob: %v", err)
	}
}

// TestCancelJob_Error: a 404 surfaces as an error.
func TestCancelJob_Error(t *testing.T) {
	client, cleanup := mockJobServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusNotFound)
		w.Write([]byte(`not found`))
	}))
	defer cleanup()

	svc := NewJobService(client, zap.NewNop())
	err := svc.CancelJob(context.Background(), "999")
	if err == nil {
		t.Fatal("expected error, got nil")
	}
}

// TestGetJobHistory: verifies the users filter is forwarded as a query param
// and that in-memory pagination returns page 1 of 2 with correct Total.
func TestGetJobHistory(t *testing.T) {
	jobID1 := int32(10)
	jobID2 := int32(20)
	jobID3 := int32(30)
	name1 := "hist-1"
	name2 := "hist-2"
	name3 := "hist-3"
	submission1 := int64(1700000000)
	submission2 := int64(1700001000)
	submission3 := int64(1700002000)
	partition := "normal"

	client, cleanup := mockJobServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if r.Method != http.MethodGet {
			t.Errorf("expected GET, got %s", r.Method)
		}
		users := r.URL.Query().Get("users")
		if users != "testuser" {
			t.Errorf("expected users=testuser, got %s", users)
		}

		resp := slurm.OpenapiSlurmdbdJobsResp{
			Jobs: slurm.JobList{
				{
					JobID:     &jobID1,
					Name:      &name1,
					Partition: &partition,
					State:     &slurm.JobState{Current: []string{"COMPLETED"}},
					Time:      &slurm.JobTime{Submission: &submission1},
				},
				{
					JobID:     &jobID2,
					Name:      &name2,
					Partition: &partition,
					State:     &slurm.JobState{Current: []string{"FAILED"}},
					Time:      &slurm.JobTime{Submission: &submission2},
				},
				{
					JobID:     &jobID3,
					Name:      &name3,
					Partition: &partition,
					State:     &slurm.JobState{Current: []string{"CANCELLED"}},
					Time:      &slurm.JobTime{Submission: &submission3},
				},
			},
		}
		json.NewEncoder(w).Encode(resp)
	}))
	defer cleanup()

	svc := NewJobService(client, zap.NewNop())
	result, err := svc.GetJobHistory(context.Background(), &model.JobHistoryQuery{
		Users:    "testuser",
		Page:     1,
		PageSize: 2,
	})
	if err != nil {
		t.Fatalf("GetJobHistory: %v", err)
	}
	if result.Total != 3 {
		t.Errorf("expected Total 3, got %d", result.Total)
	}
	if result.Page != 1 {
		t.Errorf("expected Page 1, got %d", result.Page)
	}
	if result.PageSize != 2 {
		t.Errorf("expected PageSize 2, got %d", result.PageSize)
	}
	if len(result.Jobs) != 2 {
		t.Fatalf("expected 2 jobs on page 1, got %d", len(result.Jobs))
	}
	if result.Jobs[0].JobID != 10 {
		t.Errorf("expected first job ID 10, got %d", result.Jobs[0].JobID)
	}
	if result.Jobs[1].JobID != 20 {
		t.Errorf("expected second job ID 20, got %d", result.Jobs[1].JobID)
	}
	if len(result.Jobs[0].State) != 1 || result.Jobs[0].State[0] != "COMPLETED" {
		t.Errorf("expected state [COMPLETED], got %v", result.Jobs[0].State)
	}
}

// TestGetJobHistory_Page2 — truncated in this chunk; the fragment below is
// reproduced as-is and continues past the end of the visible source.
func TestGetJobHistory_Page2(t *testing.T) {
	jobID1 := int32(10)
	jobID2 := int32(20)
	name1 := "a"
	name2 := "b"

	client, cleanup := mockJobServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		resp := slurm.OpenapiSlurmdbdJobsResp{
			Jobs: slurm.JobList{
				{JobID: &jobID1, Name: &name1},
{JobID: &jobID2, Name: &name2}, + }, + } + json.NewEncoder(w).Encode(resp) + })) + defer cleanup() + + svc := NewJobService(client, zap.NewNop()) + result, err := svc.GetJobHistory(context.Background(), &model.JobHistoryQuery{ + Page: 2, + PageSize: 1, + }) + if err != nil { + t.Fatalf("GetJobHistory: %v", err) + } + if result.Total != 2 { + t.Errorf("expected Total 2, got %d", result.Total) + } + if len(result.Jobs) != 1 { + t.Fatalf("expected 1 job on page 2, got %d", len(result.Jobs)) + } + if result.Jobs[0].JobID != 20 { + t.Errorf("expected job ID 20, got %d", result.Jobs[0].JobID) + } +} + +func TestGetJobHistory_DefaultPagination(t *testing.T) { + client, cleanup := mockJobServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + resp := slurm.OpenapiSlurmdbdJobsResp{Jobs: slurm.JobList{}} + json.NewEncoder(w).Encode(resp) + })) + defer cleanup() + + svc := NewJobService(client, zap.NewNop()) + result, err := svc.GetJobHistory(context.Background(), &model.JobHistoryQuery{}) + if err != nil { + t.Fatalf("GetJobHistory: %v", err) + } + if result.Page != 1 { + t.Errorf("expected default page 1, got %d", result.Page) + } + if result.PageSize != 20 { + t.Errorf("expected default pageSize 20, got %d", result.PageSize) + } +} + +func TestGetJobHistory_QueryMapping(t *testing.T) { + client, cleanup := mockJobServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + q := r.URL.Query() + if v := q.Get("account"); v != "proj1" { + t.Errorf("expected account=proj1, got %s", v) + } + if v := q.Get("partition"); v != "gpu" { + t.Errorf("expected partition=gpu, got %s", v) + } + if v := q.Get("state"); v != "COMPLETED" { + t.Errorf("expected state=COMPLETED, got %s", v) + } + if v := q.Get("job_name"); v != "myjob" { + t.Errorf("expected job_name=myjob, got %s", v) + } + if v := q.Get("start_time"); v != "1700000000" { + t.Errorf("expected start_time=1700000000, got %s", v) + } + if v := q.Get("end_time"); v != "1700099999" { + 
t.Errorf("expected end_time=1700099999, got %s", v) + } + + resp := slurm.OpenapiSlurmdbdJobsResp{Jobs: slurm.JobList{}} + json.NewEncoder(w).Encode(resp) + })) + defer cleanup() + + svc := NewJobService(client, zap.NewNop()) + _, err := svc.GetJobHistory(context.Background(), &model.JobHistoryQuery{ + Users: "testuser", + Account: "proj1", + Partition: "gpu", + State: "COMPLETED", + JobName: "myjob", + StartTime: "1700000000", + EndTime: "1700099999", + }) + if err != nil { + t.Fatalf("GetJobHistory: %v", err) + } +} + +func TestGetJobHistory_Error(t *testing.T) { + client, cleanup := mockJobServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusInternalServerError) + w.Write([]byte(`{"error":"db down"}`)) + })) + defer cleanup() + + svc := NewJobService(client, zap.NewNop()) + _, err := svc.GetJobHistory(context.Background(), &model.JobHistoryQuery{}) + if err == nil { + t.Fatal("expected error, got nil") + } +} + +func TestMapJobInfo_ExitCode(t *testing.T) { + returnCode := int64(2) + ji := &slurm.JobInfo{ + ExitCode: &slurm.ProcessExitCodeVerbose{ + ReturnCode: &slurm.Uint32NoVal{Number: &returnCode}, + }, + } + resp := mapJobInfo(ji) + if resp.ExitCode == nil || *resp.ExitCode != 2 { + t.Errorf("expected exit code 2, got %v", resp.ExitCode) + } +} + +func TestMapSlurmdbJob_NilFields(t *testing.T) { + j := &slurm.Job{} + resp := mapSlurmdbJob(j) + if resp.JobID != 0 { + t.Errorf("expected JobID 0, got %d", resp.JobID) + } + if resp.State != nil { + t.Errorf("expected nil State, got %v", resp.State) + } + if resp.SubmitTime != nil { + t.Errorf("expected nil SubmitTime, got %v", resp.SubmitTime) + } +} + +// --------------------------------------------------------------------------- +// Structured logging tests using zaptest/observer +// --------------------------------------------------------------------------- + +func newJobServiceWithObserver(srv *httptest.Server) (*JobService, *observer.ObservedLogs) { + core, 
recorded := observer.New(zapcore.DebugLevel) + l := zap.New(core) + client, _ := slurm.NewClient(srv.URL, srv.Client()) + return NewJobService(client, l), recorded +} + +func TestJobService_SubmitJob_SuccessLog(t *testing.T) { + jobID := int32(789) + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + resp := slurm.OpenapiJobSubmitResponse{ + Result: &slurm.JobSubmitResponseMsg{JobID: &jobID}, + } + json.NewEncoder(w).Encode(resp) + })) + defer srv.Close() + + svc, recorded := newJobServiceWithObserver(srv) + _, err := svc.SubmitJob(context.Background(), &model.SubmitJobRequest{ + Script: "echo hi", + JobName: "log-test-job", + }) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + entries := recorded.All() + if len(entries) != 1 { + t.Fatalf("expected 1 log entry, got %d", len(entries)) + } + if entries[0].Level != zapcore.InfoLevel { + t.Errorf("expected InfoLevel, got %v", entries[0].Level) + } + fields := entries[0].ContextMap() + if fields["job_name"] != "log-test-job" { + t.Errorf("expected job_name=log-test-job, got %v", fields["job_name"]) + } + gotJobID, ok := fields["job_id"] + if !ok { + t.Fatal("expected job_id field in log entry") + } + if gotJobID != int32(789) && gotJobID != int64(789) { + t.Errorf("expected job_id=789, got %v (%T)", gotJobID, gotJobID) + } +} + +func TestJobService_SubmitJob_ErrorLog(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusInternalServerError) + w.Write([]byte(`{"error":"internal"}`)) + })) + defer srv.Close() + + svc, recorded := newJobServiceWithObserver(srv) + _, err := svc.SubmitJob(context.Background(), &model.SubmitJobRequest{Script: "echo fail"}) + if err == nil { + t.Fatal("expected error, got nil") + } + + entries := recorded.All() + if len(entries) != 1 { + t.Fatalf("expected 1 log entry, got %d", len(entries)) + } + if entries[0].Level != zapcore.ErrorLevel { + 
t.Errorf("expected ErrorLevel, got %v", entries[0].Level) + } + fields := entries[0].ContextMap() + if fields["operation"] != "submit" { + t.Errorf("expected operation=submit, got %v", fields["operation"]) + } + if _, ok := fields["error"]; !ok { + t.Error("expected error field in log entry") + } +} + +func TestJobService_CancelJob_SuccessLog(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + resp := slurm.OpenapiResp{} + json.NewEncoder(w).Encode(resp) + })) + defer srv.Close() + + svc, recorded := newJobServiceWithObserver(srv) + err := svc.CancelJob(context.Background(), "555") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + entries := recorded.All() + if len(entries) != 1 { + t.Fatalf("expected 1 log entry, got %d", len(entries)) + } + if entries[0].Level != zapcore.InfoLevel { + t.Errorf("expected InfoLevel, got %v", entries[0].Level) + } + fields := entries[0].ContextMap() + if fields["job_id"] != "555" { + t.Errorf("expected job_id=555, got %v", fields["job_id"]) + } +} + +func TestJobService_CancelJob_ErrorLog(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusNotFound) + w.Write([]byte(`not found`)) + })) + defer srv.Close() + + svc, recorded := newJobServiceWithObserver(srv) + err := svc.CancelJob(context.Background(), "999") + if err == nil { + t.Fatal("expected error, got nil") + } + + entries := recorded.All() + if len(entries) != 1 { + t.Fatalf("expected 1 log entry, got %d", len(entries)) + } + if entries[0].Level != zapcore.ErrorLevel { + t.Errorf("expected ErrorLevel, got %v", entries[0].Level) + } + fields := entries[0].ContextMap() + if fields["operation"] != "cancel" { + t.Errorf("expected operation=cancel, got %v", fields["operation"]) + } + if fields["job_id"] != "999" { + t.Errorf("expected job_id=999, got %v", fields["job_id"]) + } + if _, ok := fields["error"]; !ok { + 
t.Error("expected error field in log entry") + } +} + +func TestJobService_GetJobs_ErrorLog(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusInternalServerError) + w.Write([]byte(`{"error":"down"}`)) + })) + defer srv.Close() + + svc, recorded := newJobServiceWithObserver(srv) + _, err := svc.GetJobs(context.Background()) + if err == nil { + t.Fatal("expected error, got nil") + } + + entries := recorded.All() + if len(entries) != 1 { + t.Fatalf("expected 1 log entry, got %d", len(entries)) + } + if entries[0].Level != zapcore.ErrorLevel { + t.Errorf("expected ErrorLevel, got %v", entries[0].Level) + } + fields := entries[0].ContextMap() + if fields["operation"] != "get_jobs" { + t.Errorf("expected operation=get_jobs, got %v", fields["operation"]) + } + if _, ok := fields["error"]; !ok { + t.Error("expected error field in log entry") + } +} + +func TestJobService_GetJob_ErrorLog(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusInternalServerError) + w.Write([]byte(`{"error":"down"}`)) + })) + defer srv.Close() + + svc, recorded := newJobServiceWithObserver(srv) + _, err := svc.GetJob(context.Background(), "200") + if err == nil { + t.Fatal("expected error, got nil") + } + + entries := recorded.All() + if len(entries) != 1 { + t.Fatalf("expected 1 log entry, got %d", len(entries)) + } + if entries[0].Level != zapcore.ErrorLevel { + t.Errorf("expected ErrorLevel, got %v", entries[0].Level) + } + fields := entries[0].ContextMap() + if fields["operation"] != "get_job" { + t.Errorf("expected operation=get_job, got %v", fields["operation"]) + } + if fields["job_id"] != "200" { + t.Errorf("expected job_id=200, got %v", fields["job_id"]) + } + if _, ok := fields["error"]; !ok { + t.Error("expected error field in log entry") + } +} + +func TestJobService_GetJobHistory_ErrorLog(t *testing.T) { + srv := 
httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		w.WriteHeader(http.StatusInternalServerError)
+		w.Write([]byte(`{"error":"db down"}`))
+	}))
+	defer srv.Close()
+
+	svc, recorded := newJobServiceWithObserver(srv)
+	_, err := svc.GetJobHistory(context.Background(), &model.JobHistoryQuery{})
+	if err == nil {
+		t.Fatal("expected error, got nil")
+	}
+
+	entries := recorded.All()
+	if len(entries) != 1 {
+		t.Fatalf("expected 1 log entry, got %d", len(entries))
+	}
+	if entries[0].Level != zapcore.ErrorLevel {
+		t.Errorf("expected ErrorLevel, got %v", entries[0].Level)
+	}
+	fields := entries[0].ContextMap()
+	if fields["operation"] != "get_job_history" {
+		t.Errorf("expected operation=get_job_history, got %v", fields["operation"])
+	}
+	if _, ok := fields["error"]; !ok {
+		t.Error("expected error field in log entry")
+	}
+}
diff --git a/internal/service/slurm_client.go b/internal/service/slurm_client.go
new file mode 100644
index 0000000..121311e
--- /dev/null
+++ b/internal/service/slurm_client.go
@@ -0,0 +1,15 @@
+package service
+
+import (
+	"gcy_hpc_server/internal/slurm"
+)
+
+// NewSlurmClient creates a Slurm SDK client with JWT authentication.
+// It reads the JWT key from the given jwtKeyPath and signs tokens automatically.
+func NewSlurmClient(apiURL, userName, jwtKeyPath string) (*slurm.Client, error) {
+	return slurm.NewClientWithOpts(
+		apiURL,
+		slurm.WithUsername(userName),
+		slurm.WithJWTKey(jwtKeyPath),
+	)
+}