Compare commits

...

50 Commits

Author SHA1 Message Date
dailz
9092278d26 refactor(test): update test expectations for removed submit route
- Comment out submit route assertions in main_test.go and server_test.go

- Comment out TestTask_OldAPICompatibility in task_test.go

- Update expected route count 31→30 in testenv env_test.go

Ultraworked with [Sisyphus](https://github.com/code-yeongyu/oh-my-openagent)

Co-authored-by: Sisyphus <clio-agent@sisyphuslabs.ai>
2026-04-16 15:16:07 +08:00
dailz
7c374f4fd5 refactor(handler,server): disable SubmitApplication endpoint, replaced by POST /tasks
- Comment out SubmitApplication handler method

- Comment out route registration in server.go (interface + router + placeholder)

- Comment out related handler tests

Ultraworked with [Sisyphus](https://github.com/code-yeongyu/oh-my-openagent)

Co-authored-by: Sisyphus <clio-agent@sisyphuslabs.ai>
2026-04-16 15:15:55 +08:00
dailz
36d842350c refactor(service): disable SubmitFromApplication fallback, fully replaced by POST /tasks
- Comment out SubmitFromApplication method and its fallback path

- Comment out 5 tests that tested the old direct-submission code

- Remove unused imports after commenting out the method

Ultraworked with [Sisyphus](https://github.com/code-yeongyu/oh-my-openagent)

Co-authored-by: Sisyphus <clio-agent@sisyphuslabs.ai>
2026-04-16 15:15:42 +08:00
dailz
80f2bd32d9 docs(openapi): update spec to match code — add Tasks, fix schemas, remove submit endpoint
- Add Tasks tag, /tasks paths, and Task schemas (CreateTaskRequest, TaskResponse, TaskListResponse)

- Fix SubmitJobRequest.work_dir, InitUploadRequest mime_type/chunk_size, UploadSessionResponse.created_at

- Fix FolderResponse: add file_count/subfolder_count, remove updated_at

- Fix response wrapping for File/Upload/Folder endpoints to use ApiResponseSuccess

- Remove /applications/{id}/submit path and ApplicationSubmitRequest schema

- Update Applications tag description

Ultraworked with [Sisyphus](https://github.com/code-yeongyu/oh-my-openagent)

Co-authored-by: Sisyphus <clio-agent@sisyphuslabs.ai>
2026-04-16 15:15:31 +08:00
dailz
52a34e2cb0 feat(client): add CLI client entry point 2026-04-16 13:24:21 +08:00
dailz
b9b2f0d9b4 feat(testutil): add MockSlurm, MockMinIO, TestEnv and 37 integration tests
- mockminio: in-memory ObjectStorage with all 11 methods, thread-safe, SHA256 ETag, Range support
- mockslurm: httptest server with 11 Slurm REST API endpoints, job eviction from active to history queue
- testenv: one-line test environment factory (SQLite + MockSlurm + MockMinIO + all stores/services/handlers + httptest server)
- integration tests: 37 tests covering Jobs(5), Cluster(5), App(6), Upload(5), File(4), Folder(4), Task(4), E2E(1)
- no external dependencies, no existing files modified
2026-04-16 13:23:27 +08:00
dailz
73504f9fdb feat(app): add TaskPoller, wire DI, and add task integration tests 2026-04-15 21:31:17 +08:00
dailz
3f8a680c99 feat(handler): add TaskHandler endpoints and register task routes 2026-04-15 21:31:11 +08:00
dailz
ec64300ff2 feat(service): add TaskService, FileStagingService, and refactor ApplicationService for task submission 2026-04-15 21:31:02 +08:00
dailz
acf8c1d62b feat(store): add TaskStore CRUD and batch query methods for files and blobs 2026-04-15 21:30:51 +08:00
dailz
d46a784efb feat(model): add Task model, DTOs, and status constants for task submission system 2026-04-15 21:30:44 +08:00
dailz
79870333cb fix(service): tolerate concurrent pending-to-uploading status race in UploadChunk
When multiple chunk uploads race on the pending→uploading transition, ignore ErrRecordNotFound from UpdateSessionStatus since another request already completed the update.

Ultraworked with [Sisyphus](https://github.com/code-yeongyu/oh-my-openagent)

Co-authored-by: Sisyphus <clio-agent@sisyphuslabs.ai>
2026-04-15 10:27:12 +08:00
dailz
d9a60c3511 fix(model): rename Application table to hpc_applications
Avoid table name collision with other systems.

Ultraworked with [Sisyphus](https://github.com/code-yeongyu/oh-my-openagent)

Co-authored-by: Sisyphus <clio-agent@sisyphuslabs.ai>
2026-04-15 09:32:11 +08:00
dailz
20576bc325 docs(openapi): add file storage API specifications
Add 13 endpoints for chunked upload, file management, and folder CRUD with 6 new schemas.

Ultraworked with [Sisyphus](https://github.com/code-yeongyu/oh-my-openagent)

Co-authored-by: Sisyphus <clio-agent@sisyphuslabs.ai>
2026-04-15 09:30:18 +08:00
dailz
c0176d7764 feat(app): wire file storage DI, cleanup worker, and integration tests
Add DI wiring with graceful MinIO fallback, background cleanup worker for expired sessions and leaked multipart uploads, and end-to-end integration tests.

Ultraworked with [Sisyphus](https://github.com/code-yeongyu/oh-my-openagent)

Co-authored-by: Sisyphus <clio-agent@sisyphuslabs.ai>
2026-04-15 09:23:25 +08:00
dailz
2298e92516 feat(handler): add upload, file, and folder handlers with routes
Add UploadHandler (5 endpoints), FileHandler (4 endpoints), FolderHandler (4 endpoints) with Gin route registration in server.go.

Ultraworked with [Sisyphus](https://github.com/code-yeongyu/oh-my-openagent)

Co-authored-by: Sisyphus <clio-agent@sisyphuslabs.ai>
2026-04-15 09:23:17 +08:00
dailz
f0847d3978 feat(service): add upload, download, file, and folder services
Add UploadService (dedup, chunk lifecycle, ComposeObject), DownloadService (Range support), FileService (ref counting), FolderService (path validation).

Ultraworked with [Sisyphus](https://github.com/code-yeongyu/oh-my-openagent)

Co-authored-by: Sisyphus <clio-agent@sisyphuslabs.ai>
2026-04-15 09:23:09 +08:00
dailz
a114821615 feat(server): add streaming response helpers for file download
Add ParseRange, StreamFile, StreamRange for full and partial content delivery.

Ultraworked with [Sisyphus](https://github.com/code-yeongyu/oh-my-openagent)

Co-authored-by: Sisyphus <clio-agent@sisyphuslabs.ai>
2026-04-15 09:22:58 +08:00
dailz
bf89de12f0 feat(store): add blob, file, folder, and upload stores
Add BlobStore (ref counting), FileStore (soft delete + pagination), FolderStore (materialized path), UploadStore (idempotent upsert), and update AutoMigrate.

Ultraworked with [Sisyphus](https://github.com/code-yeongyu/oh-my-openagent)

Co-authored-by: Sisyphus <clio-agent@sisyphuslabs.ai>
2026-04-15 09:22:44 +08:00
dailz
c861ff3adf feat(storage): add ObjectStorage interface and MinIO client
Add ObjectStorage interface (11 methods) with MinioClient implementation using minio-go Core.

Ultraworked with [Sisyphus](https://github.com/code-yeongyu/oh-my-openagent)

Co-authored-by: Sisyphus <clio-agent@sisyphuslabs.ai>
2026-04-15 09:22:33 +08:00
dailz
0e4f523746 feat(model): add file storage GORM models and DTOs
Add FileBlob, File, Folder, UploadSession, UploadChunk models with validators.

Ultraworked with [Sisyphus](https://github.com/code-yeongyu/oh-my-openagent)

Co-authored-by: Sisyphus <clio-agent@sisyphuslabs.ai>
2026-04-15 09:22:25 +08:00
dailz
44895214d4 feat(config): add MinIO object storage configuration
Add MinioConfig struct with connection, bucket, chunk size, and session TTL settings.

Ultraworked with [Sisyphus](https://github.com/code-yeongyu/oh-my-openagent)

Co-authored-by: Sisyphus <clio-agent@sisyphuslabs.ai>
2026-04-15 09:22:18 +08:00
dailz
a65c8762af fix(service): add environment variables and fix work directory permissions for Slurm job submission
Slurm requires environment variables in job submission; without them it returns 'batch job cannot run without an environment'. Also chmod the entire directory path to 0777 to bypass umask, ensuring Slurm and compute node users can write.

Ultraworked with [Sisyphus](https://github.com/code-yeongyu/oh-my-openagent)

Co-authored-by: Sisyphus <clio-agent@sisyphuslabs.ai>
2026-04-14 13:06:51 +08:00
dailz
04f99cc1c4 docs(openapi): update spec for Application Definition
Add 6 application endpoints and schemas to OpenAPI spec. Update .gitignore.

Ultraworked with [Sisyphus](https://github.com/code-yeongyu/oh-my-openagent)

Co-authored-by: Sisyphus <clio-agent@sisyphuslabs.ai>
2026-04-13 17:13:02 +08:00
dailz
32f5792b68 feat(service): pass work directory to Slurm job submission
Add WorkDir to SubmitJobRequest and pass it as CurrentWorkingDirectory to Slurm REST API. Fixes Slurm 500 error when working directory is not specified.

Ultraworked with [Sisyphus](https://github.com/code-yeongyu/oh-my-openagent)

Co-authored-by: Sisyphus <clio-agent@sisyphuslabs.ai>
2026-04-13 17:12:28 +08:00
dailz
328691adff feat(config): add WorkDirBase for application job working directory
Add WorkDirBase config field for auto-generated job working directories. Pattern: {base}/{app_name}/{timestamp}_{random}/

Ultraworked with [Sisyphus](https://github.com/code-yeongyu/oh-my-openagent)

Co-authored-by: Sisyphus <clio-agent@sisyphuslabs.ai>
2026-04-13 17:11:48 +08:00
dailz
10bb15e5b2 feat(handler): add Application handler, routes, and wiring
Add ApplicationHandler with CRUD + Submit endpoints. Register 6 routes, wire in app.go, update main_test.go references. 22 handler tests.

Ultraworked with [Sisyphus](https://github.com/code-yeongyu/oh-my-openagent)

Co-authored-by: Sisyphus <clio-agent@sisyphuslabs.ai>
2026-04-13 17:10:54 +08:00
dailz
d3eb728c2f feat(service): add Application service with parameter validation and script rendering
Add ApplicationService with ValidateParams, RenderScript, SubmitFromApplication. Includes shell escaping, longest-first parameter replacement, and work directory generation. 15 tests.

Ultraworked with [Sisyphus](https://github.com/code-yeongyu/oh-my-openagent)

Co-authored-by: Sisyphus <clio-agent@sisyphuslabs.ai>
2026-04-13 17:10:09 +08:00
dailz
4a8153aa6c feat(model): add Application model and store
Add Application and ParameterSchema models with CRUD store. Includes 10 store tests and ParamType constants.

Ultraworked with [Sisyphus](https://github.com/code-yeongyu/oh-my-openagent)

Co-authored-by: Sisyphus <clio-agent@sisyphuslabs.ai>
2026-04-13 17:08:24 +08:00
dailz
dd8d226e78 refactor: remove JobTemplate production code
Remove all JobTemplate model, store, handler, migrations, and wiring. Replaced by Application Definition system.

Ultraworked with [Sisyphus](https://github.com/code-yeongyu/oh-my-openagent)

Co-authored-by: Sisyphus <clio-agent@sisyphuslabs.ai>
2026-04-13 17:07:46 +08:00
dailz
62e458cb7a docs(openapi): update GET /jobs with pagination and JobListResponse
Add page/page_size query parameters, change response from JobResponse[] to JobListResponse, add 400 error code.

Ultraworked with [Sisyphus](https://github.com/code-yeongyu/oh-my-openagent)

Co-authored-by: Sisyphus <clio-agent@sisyphuslabs.ai>
2026-04-10 15:15:24 +08:00
dailz
2cb6fbecdd feat(service): add pagination to GetJobs endpoint
GetJobs now accepts page/page_size query parameters and returns JobListResponse instead of raw array. Uses in-memory pagination matching GetJobHistory pattern.

Ultraworked with [Sisyphus](https://github.com/code-yeongyu/oh-my-openagent)

Co-authored-by: Sisyphus <clio-agent@sisyphuslabs.ai>
2026-04-10 15:14:56 +08:00
dailz
35a4017b8e docs(model): add Chinese field comments to all model structs
Add inline comments to SubmitJobRequest, JobListResponse, JobHistoryQuery, JobTemplate, CreateTemplateRequest, and UpdateTemplateRequest fields, consistent with existing cluster.go and JobResponse style.

Ultraworked with [Sisyphus](https://github.com/code-yeongyu/oh-my-openagent)

Co-authored-by: Sisyphus <clio-agent@sisyphuslabs.ai>
2026-04-10 13:53:54 +08:00
dailz
f4177dd287 feat(service): add GetJob fallback to SlurmDBD history and expand query params
GetJob now falls back to SlurmDBD history when active queue returns 404 or empty jobs. Expand JobHistoryQuery from 7 to 16 filter params (add SubmitTime, Cluster, Qos, Constraints, ExitCode, Node, Reservation, Groups, Wckey).

Ultraworked with [Sisyphus](https://github.com/code-yeongyu/oh-my-openagent)

Co-authored-by: Sisyphus <clio-agent@sisyphuslabs.ai>
2026-04-10 13:43:31 +08:00
dailz
b3d787c97b fix(slurm): parse structured errors from non-2xx Slurm API responses
Replace ErrorResponse with SlurmAPIError that extracts structured errors/warnings from JSON body when Slurm returns non-2xx (e.g. 404 with valid JSON). Add IsNotFound helper for fallback logic.

Ultraworked with [Sisyphus](https://github.com/code-yeongyu/oh-my-openagent)

Co-authored-by: Sisyphus <clio-agent@sisyphuslabs.ai>
2026-04-10 13:43:17 +08:00
dailz
30f0fbc34b fix(slurm): correct PartitionInfoMaximums CpusPerNode/CpusPerSocket types to Uint32NoVal
Ultraworked with [Sisyphus](https://github.com/code-yeongyu/oh-my-openagent)

Co-authored-by: Sisyphus <clio-agent@sisyphuslabs.ai>
2026-04-10 11:39:29 +08:00
dailz
34ba617cbf fix(test): update log assertions for debug logging and field expansion
Ultraworked with [Sisyphus](https://github.com/code-yeongyu/oh-my-openagent)

Co-authored-by: Sisyphus <clio-agent@sisyphuslabs.ai>
2026-04-10 11:13:13 +08:00
dailz
824d9e816f feat(service): map additional Slurm SDK fields and fix ExitCode/Default bugs
Ultraworked with [Sisyphus](https://github.com/code-yeongyu/oh-my-openagent)

Co-authored-by: Sisyphus <clio-agent@sisyphuslabs.ai>
2026-04-10 11:12:51 +08:00
dailz
85901fe18a feat(model): expand API response fields to expose full Slurm data
Ultraworked with [Sisyphus](https://github.com/code-yeongyu/oh-my-openagent)

Co-authored-by: Sisyphus <clio-agent@sisyphuslabs.ai>
2026-04-10 11:12:33 +08:00
dailz
270552ba9a feat(service): add debug logging for Slurm API calls with request/response body and latency
Ultraworked with [Sisyphus](https://github.com/code-yeongyu/oh-my-openagent)

Co-authored-by: Sisyphus <clio-agent@sisyphuslabs.ai>
2026-04-10 10:28:58 +08:00
dailz
347b0e1229 fix: remove redundant binding tags and clarify logger compress logic
- Remove binding:"required" from model fields that are manually validated in handlers. - Add parentheses to logger compress default to clarify operator precedence.

Ultraworked with [Sisyphus](https://github.com/code-yeongyu/oh-my-openagent)

Co-authored-by: Sisyphus <clio-agent@sisyphuslabs.ai>
2026-04-10 09:25:46 +08:00
dailz
c070dd8abc fix(slurm): add default 30s timeout to HTTP client
Replaces http.DefaultClient with a client that has a 30s timeout to prevent indefinite hangs when the Slurm REST API is unresponsive.

Ultraworked with [Sisyphus](https://github.com/code-yeongyu/oh-my-openagent)

Co-authored-by: Sisyphus <clio-agent@sisyphuslabs.ai>
2026-04-10 09:25:35 +08:00
dailz
1359730300 fix(store): return ErrRecordNotFound when updating non-existent template
RowsAffected == 0 now returns gorm.ErrRecordNotFound so the handler can respond with 404 instead of silently returning 200.

Ultraworked with [Sisyphus](https://github.com/code-yeongyu/oh-my-openagent)

Co-authored-by: Sisyphus <clio-agent@sisyphuslabs.ai>
2026-04-10 09:21:03 +08:00
dailz
4ff02d4a80 fix: 移除 main() 中多余的 defer application.Close()
Run() 在所有退出路径中已调用 Close(),main 中的 defer 是冗余的。

Ultraworked with [Sisyphus](https://github.com/code-yeongyu/oh-my-openagent)

Co-authored-by: Sisyphus <clio-agent@sisyphuslabs.ai>
2026-04-10 08:48:42 +08:00
dailz
1784331969 feat: 添加应用骨架,配置化 zap 日志贯穿全链路
- cmd/server/main.go: 使用 logger.NewLogger(cfg.Log) 替代 zap.NewProduction()

- internal/app: 依赖注入组装 DB/Slurm/Service/Handler,传递 logger

- internal/middleware: RequestLogger 请求日志中间件

- internal/server: 统一响应格式和路由注册

- go.mod: module 更名为 gcy_hpc_server,添加 gin/zap/lumberjack/gorm 依赖

- 日志初始化失败时 fail fast (os.Exit(1))

- GormLevel 从配置传递到 NewGormDB,支持 YAML 独立配置

Ultraworked with [Sisyphus](https://github.com/code-yeongyu/oh-my-openagent)

Co-authored-by: Sisyphus <clio-agent@sisyphuslabs.ai>
2026-04-10 08:40:16 +08:00
dailz
e6162063ca feat: 添加 HTTP 处理层和结构化日志
- JobHandler: 提交/查询/取消/历史,5xx Error + 4xx Warn 日志

- ClusterHandler: 节点/分区/诊断,错误和未找到日志

- TemplateHandler: CRUD 操作,创建/更新/删除 Info + 未找到 Warn

- 不记录成功响应(由 middleware.RequestLogger 处理)

- 不记录请求体和模板内容(安全考虑)

- 完整 TDD 测试,使用 zaptest/observer 验证日志级别和字段

Ultraworked with [Sisyphus](https://github.com/code-yeongyu/oh-my-openagent)

Co-authored-by: Sisyphus <clio-agent@sisyphuslabs.ai>
2026-04-10 08:40:06 +08:00
dailz
4903f7d07f feat: 添加业务服务层和结构化日志
- JobService: 提交、查询、取消、历史记录,记录关键操作日志

- ClusterService: 节点、分区、诊断查询,记录错误日志

- NewSlurmClient: JWT 认证 HTTP 客户端工厂

- 所有构造函数接受 *zap.Logger 参数实现依赖注入

- 提交/取消成功记录 Info,API 错误记录 Error

- 完整 TDD 测试,使用 zaptest/observer 验证日志输出

Ultraworked with [Sisyphus](https://github.com/code-yeongyu/oh-my-openagent)

Co-authored-by: Sisyphus <clio-agent@sisyphuslabs.ai>
2026-04-10 08:39:46 +08:00
dailz
fbfd5c5f42 feat: 添加数据模型和存储层
- model: JobTemplate、SubmitJobRequest、JobHistoryQuery 等模型定义

- store: NewGormDB MySQL 连接池,使用 zap 日志替代 GORM 默认日志

- store: TemplateStore CRUD 操作,支持 GORM AutoMigrate

- NewGormDB 接受 gormLevel 参数,由上层传入配置值

- 完整 TDD 测试覆盖

Ultraworked with [Sisyphus](https://github.com/code-yeongyu/oh-my-openagent)

Co-authored-by: Sisyphus <clio-agent@sisyphuslabs.ai>
2026-04-10 08:39:30 +08:00
dailz
f7a21ee455 feat: 添加 zap 日志工厂和 GORM 日志桥接
- NewLogger 工厂函数:支持 JSON/Console 编码、stdout/文件/多输出、lumberjack 轮转

- NewGormLogger 实现 gorm.Interface:Trace 区分错误/慢查询/正常查询

- output_stdout 用 *bool 三态处理(nil=true, true, false)

- 默认值:level=info, encoding=json, max_size=100, max_backups=5, max_age=30

- 慢查询阈值 200ms,ErrRecordNotFound 不视为错误

- 编译时接口检查: var _ gormlogger.Interface = (*GormLogger)(nil)

- 完整 TDD 测试覆盖

Ultraworked with [Sisyphus](https://github.com/code-yeongyu/oh-my-openagent)

Co-authored-by: Sisyphus <clio-agent@sisyphuslabs.ai>
2026-04-10 08:39:21 +08:00
dailz
7550e75945 feat: 添加配置加载和日志配置支持
- 新增 LogConfig 结构体,支持 9 个日志配置字段(level, encoding, output_stdout, file_path, max_size, max_backups, max_age, compress, gorm_level)

- Config 结构体新增 Log 字段,支持 YAML 解析

- output_stdout 使用 *bool 指针类型,nil 默认为 true

- 更新 config.example.yaml 添加完整 log 配置段

- 新增 TDD 测试:日志配置解析、向后兼容、字段完整性

Ultraworked with [Sisyphus](https://github.com/code-yeongyu/oh-my-openagent)

Co-authored-by: Sisyphus <clio-agent@sisyphuslabs.ai>
2026-04-10 08:39:09 +08:00
109 changed files with 28149 additions and 32 deletions

1
.gitignore vendored
View File

@@ -1,3 +1,2 @@
bin/
*.exe
.sisyphus/

662
cmd/client/main.go Normal file
View File

@@ -0,0 +1,662 @@
package main
import (
"bytes"
"crypto/sha256"
"encoding/hex"
"encoding/json"
"fmt"
"io"
"mime/multipart"
"net/http"
"os"
"path/filepath"
"strconv"
"strings"
"sync"
"sync/atomic"
"time"
)
// ── API response types ───────────────────────────────────────────────
type apiResponse struct {
Success bool `json:"success"`
Data json.RawMessage `json:"data"`
Error string `json:"error,omitempty"`
}
type sessionResponse struct {
ID int64 `json:"id"`
FileName string `json:"file_name"`
FileSize int64 `json:"file_size"`
ChunkSize int64 `json:"chunk_size"`
TotalChunks int `json:"total_chunks"`
SHA256 string `json:"sha256"`
Status string `json:"status"`
UploadedChunks []int `json:"uploaded_chunks"`
}
type fileResponse struct {
ID int64 `json:"id"`
Name string `json:"name"`
Size int64 `json:"size"`
MimeType string `json:"mime_type"`
SHA256 string `json:"sha256"`
CreatedAt string `json:"created_at"`
}
type folderResponse struct {
ID int64 `json:"id"`
Name string `json:"name"`
ParentID *int64 `json:"parent_id,omitempty"`
Path string `json:"path"`
CreatedAt string `json:"created_at"`
}
type listFilesResponse struct {
Files []fileResponse `json:"files"`
Total int64 `json:"total"`
Page int `json:"page"`
PageSize int `json:"page_size"`
}
// ── Helpers ──────────────────────────────────────────────────────────
func formatSize(b int64) string {
const (
KB = 1024
MB = KB * 1024
GB = MB * 1024
)
switch {
case b >= GB:
return fmt.Sprintf("%.2f GB", float64(b)/float64(GB))
case b >= MB:
return fmt.Sprintf("%.2f MB", float64(b)/float64(MB))
case b >= KB:
return fmt.Sprintf("%.2f KB", float64(b)/float64(KB))
default:
return fmt.Sprintf("%d B", b)
}
}
func doRequest(server, method, path string, body io.Reader, contentType string) (*apiResponse, error) {
url := server + path
req, err := http.NewRequest(method, url, body)
if err != nil {
return nil, fmt.Errorf("creating request: %w", err)
}
if contentType != "" {
req.Header.Set("Content-Type", contentType)
}
resp, err := http.DefaultClient.Do(req)
if err != nil {
return nil, fmt.Errorf("sending request: %w", err)
}
defer resp.Body.Close()
raw, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("reading response: %w", err)
}
if resp.StatusCode >= 400 {
return nil, fmt.Errorf("HTTP %d: %s", resp.StatusCode, string(raw))
}
var apiResp apiResponse
if err := json.Unmarshal(raw, &apiResp); err != nil {
return nil, fmt.Errorf("parsing response: %w\nbody: %s", err, string(raw))
}
if !apiResp.Success {
return &apiResp, fmt.Errorf("API error: %s", apiResp.Error)
}
return &apiResp, nil
}
func computeSHA256(path string) (string, error) {
f, err := os.Open(path)
if err != nil {
return "", err
}
defer f.Close()
h := sha256.New()
if _, err := io.Copy(h, f); err != nil {
return "", err
}
return hex.EncodeToString(h.Sum(nil)), nil
}
// ── Commands ─────────────────────────────────────────────────────────
func cmdMkdir(server string, args []string) {
if len(args) == 0 {
fmt.Fprintln(os.Stderr, "Usage: client -server <addr> mkdir <name> [-parent <id>]")
os.Exit(1)
}
name := args[0]
var parentID *int64
for i := 1; i+1 < len(args); i += 2 {
if args[i] == "-parent" {
v, err := strconv.ParseInt(args[i+1], 10, 64)
if err != nil {
fmt.Fprintf(os.Stderr, "invalid parent id: %s\n", args[i+1])
os.Exit(1)
}
parentID = &v
}
}
payload := map[string]interface{}{"name": name}
if parentID != nil {
payload["parent_id"] = *parentID
}
body, _ := json.Marshal(payload)
resp, err := doRequest(server, http.MethodPost, "/api/v1/files/folders", bytes.NewReader(body), "application/json")
if err != nil {
fmt.Fprintln(os.Stderr, err)
os.Exit(1)
}
var folder folderResponse
if err := json.Unmarshal(resp.Data, &folder); err != nil {
fmt.Fprintf(os.Stderr, "parsing folder response: %v\n", err)
os.Exit(1)
}
fmt.Printf("Folder created: id=%d name=%q path=%q\n", folder.ID, folder.Name, folder.Path)
}
func cmdListFolders(server string, args []string) {
path := "/api/v1/files/folders?"
for i := 0; i+1 < len(args); i += 2 {
if args[i] == "-parent" {
path += "parent_id=" + args[i+1] + "&"
}
}
resp, err := doRequest(server, http.MethodGet, path, nil, "")
if err != nil {
fmt.Fprintln(os.Stderr, err)
os.Exit(1)
}
var folders []folderResponse
if err := json.Unmarshal(resp.Data, &folders); err != nil {
fmt.Fprintf(os.Stderr, "parsing folders response: %v\n", err)
os.Exit(1)
}
if len(folders) == 0 {
fmt.Println("No folders found.")
return
}
fmt.Printf("%-8s %-30s %-10s %s\n", "ID", "Name", "ParentID", "Path")
fmt.Println(strings.Repeat("-", 80))
for _, f := range folders {
pid := "<root>"
if f.ParentID != nil {
pid = strconv.FormatInt(*f.ParentID, 10)
}
fmt.Printf("%-8d %-30s %-10s %s\n", f.ID, f.Name, pid, f.Path)
}
}
func cmdDeleteFolder(server string, args []string) {
if len(args) == 0 {
fmt.Fprintln(os.Stderr, "Usage: client -server <addr> delete-folder <id>")
os.Exit(1)
}
id := args[0]
_, err := doRequest(server, http.MethodDelete, "/api/v1/files/folders/"+id, nil, "")
if err != nil {
fmt.Fprintln(os.Stderr, err)
os.Exit(1)
}
fmt.Printf("Folder %s deleted.\n", id)
}
func cmdUpload(server string, args []string) {
if len(args) == 0 {
fmt.Fprintln(os.Stderr, "Usage: client -server <addr> upload <file> [-folder <id>] [-chunk-size <bytes>]")
os.Exit(1)
}
filePath := args[0]
var folderID *int64
chunkSize := int64(8 * 1024 * 1024) // 8 MB default
for i := 1; i+1 < len(args); i += 2 {
switch args[i] {
case "-folder":
v, err := strconv.ParseInt(args[i+1], 10, 64)
if err != nil {
fmt.Fprintf(os.Stderr, "invalid folder id: %s\n", args[i+1])
os.Exit(1)
}
folderID = &v
case "-chunk-size":
v, err := strconv.ParseInt(args[i+1], 10, 64)
if err != nil || v <= 0 {
fmt.Fprintf(os.Stderr, "invalid chunk size: %s\n", args[i+1])
os.Exit(1)
}
chunkSize = v
}
}
start := time.Now()
// Open file and get size
f, err := os.Open(filePath)
if err != nil {
fmt.Fprintf(os.Stderr, "cannot open file: %v\n", err)
os.Exit(1)
}
defer f.Close()
fi, err := f.Stat()
if err != nil {
fmt.Fprintf(os.Stderr, "cannot stat file: %v\n", err)
os.Exit(1)
}
fileSize := fi.Size()
fileName := filepath.Base(filePath)
// Compute SHA256
fmt.Printf("Computing SHA256 for %s (%s)...\n", fileName, formatSize(fileSize))
sha, err := computeSHA256(filePath)
if err != nil {
fmt.Fprintf(os.Stderr, "sha256 error: %v\n", err)
os.Exit(1)
}
fmt.Printf("SHA256: %s\n", sha)
// Init upload
payload := map[string]interface{}{
"file_name": fileName,
"file_size": fileSize,
"sha256": sha,
"chunk_size": chunkSize,
}
if folderID != nil {
payload["folder_id"] = *folderID
}
body, _ := json.Marshal(payload)
resp, err := doRequest(server, http.MethodPost, "/api/v1/files/uploads", bytes.NewReader(body), "application/json")
if err != nil {
fmt.Fprintln(os.Stderr, err)
os.Exit(1)
}
// Check for dedup (秒传) — response is a file, not a session
var sess sessionResponse
if err := json.Unmarshal(resp.Data, &sess); err == nil && sess.Status != "" && sess.TotalChunks > 0 {
// It's a session, proceed with chunk uploads
} else {
// Try to parse as file (dedup hit)
var fr fileResponse
if err := json.Unmarshal(resp.Data, &fr); err == nil && fr.Name != "" {
elapsed := time.Since(start)
fmt.Printf("⚡ 秒传 (instant upload)! File already exists.\n")
fmt.Printf(" ID: %d\n", fr.ID)
fmt.Printf(" Name: %s\n", fr.Name)
fmt.Printf(" Size: %s\n", formatSize(fr.Size))
fmt.Printf(" SHA256: %s\n", fr.SHA256)
fmt.Printf(" Elapsed: %s\n", elapsed.Truncate(time.Millisecond))
return
}
// If neither, just try the session path anyway
if err := json.Unmarshal(resp.Data, &sess); err != nil {
fmt.Fprintf(os.Stderr, "unexpected response: %s\n", string(resp.Data))
os.Exit(1)
}
}
sessionID := sess.ID
totalChunks := sess.TotalChunks
fmt.Printf("Upload session created: id=%d total_chunks=%d chunk_size=%s\n",
sessionID, totalChunks, formatSize(chunkSize))
// Upload chunks (4 concurrent workers)
const uploadWorkers = 4
type chunkResult struct {
index int
err error
}
work := make(chan int, totalChunks)
results := make(chan chunkResult, totalChunks)
for i := 0; i < totalChunks; i++ {
work <- i
}
close(work)
var uploaded int64 // atomic counter for progress
var wg sync.WaitGroup
for w := 0; w < uploadWorkers; w++ {
wg.Add(1)
go func() {
defer wg.Done()
for idx := range work {
offset := int64(idx) * chunkSize
remaining := fileSize - offset
thisChunkSize := chunkSize
if remaining < chunkSize {
thisChunkSize = remaining
}
sectionReader := io.NewSectionReader(f, offset, thisChunkSize)
var buf bytes.Buffer
writer := multipart.NewWriter(&buf)
part, err := writer.CreateFormFile("chunk", fileName)
if err != nil {
results <- chunkResult{index: idx, err: err}
continue
}
if _, err := io.Copy(part, sectionReader); err != nil {
results <- chunkResult{index: idx, err: err}
continue
}
writer.Close()
chunkURL := fmt.Sprintf("%s/api/v1/files/uploads/%d/chunks/%d", server, sessionID, idx)
chunkReq, err := http.NewRequest(http.MethodPut, chunkURL, &buf)
if err != nil {
results <- chunkResult{index: idx, err: err}
continue
}
chunkReq.Header.Set("Content-Type", writer.FormDataContentType())
chunkResp, err := http.DefaultClient.Do(chunkReq)
if err != nil {
results <- chunkResult{index: idx, err: err}
continue
}
raw, _ := io.ReadAll(chunkResp.Body)
chunkResp.Body.Close()
if chunkResp.StatusCode >= 400 {
results <- chunkResult{index: idx, err: fmt.Errorf("HTTP %d: %s", chunkResp.StatusCode, string(raw))}
continue
}
var chunkAPIResp apiResponse
if err := json.Unmarshal(raw, &chunkAPIResp); err == nil && !chunkAPIResp.Success {
results <- chunkResult{index: idx, err: fmt.Errorf("%s", chunkAPIResp.Error)}
continue
}
results <- chunkResult{index: idx, err: nil}
}
}()
}
go func() {
wg.Wait()
close(results)
}()
var firstErr error
for res := range results {
if res.err != nil {
if firstErr == nil {
firstErr = res.err
}
fmt.Fprintf(os.Stderr, "chunk %d failed: %v\n", res.index, res.err)
}
done := atomic.AddInt64(&uploaded, 1)
pct := float64(done) / float64(totalChunks) * 100
doneBytes := done * chunkSize
if doneBytes > fileSize {
doneBytes = fileSize
}
fmt.Printf("\rUploading: %.1f%% (%s / %s)", pct, formatSize(doneBytes), formatSize(fileSize))
}
fmt.Println()
if firstErr != nil {
fmt.Fprintf(os.Stderr, "upload failed: %v\n", firstErr)
os.Exit(1)
}
// Complete upload
completeURL := fmt.Sprintf("/api/v1/files/uploads/%d/complete", sessionID)
resp, err = doRequest(server, http.MethodPost, completeURL, nil, "")
if err != nil {
fmt.Fprintln(os.Stderr, err)
os.Exit(1)
}
var result fileResponse
if err := json.Unmarshal(resp.Data, &result); err != nil {
fmt.Fprintf(os.Stderr, "parsing complete response: %v\n", err)
os.Exit(1)
}
elapsed := time.Since(start)
speed := float64(fileSize) / elapsed.Seconds()
fmt.Println("✓ Upload complete!")
fmt.Printf(" ID: %d\n", result.ID)
fmt.Printf(" Name: %s\n", result.Name)
fmt.Printf(" Size: %s\n", formatSize(result.Size))
fmt.Printf(" SHA256: %s\n", result.SHA256)
fmt.Printf(" Elapsed: %s\n", elapsed.Truncate(time.Millisecond))
fmt.Printf(" Speed: %s/s\n", formatSize(int64(speed)))
}
func cmdDownload(server string, args []string) {
if len(args) == 0 {
fmt.Fprintln(os.Stderr, "Usage: client -server <addr> download <file_id> [-o <output>]")
os.Exit(1)
}
fileID := args[0]
var output string
for i := 1; i+1 < len(args); i += 2 {
if args[i] == "-o" {
output = args[i+1]
}
}
start := time.Now()
url := server + "/api/v1/files/" + fileID + "/download"
resp, err := http.Get(url)
if err != nil {
fmt.Fprintf(os.Stderr, "download error: %v\n", err)
os.Exit(1)
}
defer resp.Body.Close()
if resp.StatusCode >= 400 {
body, _ := io.ReadAll(resp.Body)
fmt.Fprintf(os.Stderr, "download failed (HTTP %d): %s\n", resp.StatusCode, string(body))
os.Exit(1)
}
// Determine output filename
if output == "" {
// Try Content-Disposition header
if cd := resp.Header.Get("Content-Disposition"); cd != "" {
if idx := strings.Index(cd, "filename="); idx != -1 {
output = strings.Trim(cd[idx+9:], `"`)
}
}
if output == "" {
output = "download_" + fileID
}
}
outFile, err := os.Create(output)
if err != nil {
fmt.Fprintf(os.Stderr, "cannot create output file: %v\n", err)
os.Exit(1)
}
defer outFile.Close()
written, err := io.Copy(outFile, resp.Body)
if err != nil {
fmt.Fprintf(os.Stderr, "download write error: %v\n", err)
os.Exit(1)
}
elapsed := time.Since(start)
speed := float64(written) / elapsed.Seconds()
fmt.Println("✓ Download complete!")
fmt.Printf(" File: %s\n", output)
fmt.Printf(" Size: %s\n", formatSize(written))
fmt.Printf(" Elapsed: %s\n", elapsed.Truncate(time.Millisecond))
fmt.Printf(" Speed: %s/s\n", formatSize(int64(speed)))
}
func cmdListFiles(server string, args []string) {
var folderID, page, pageSize string
for i := 0; i+1 < len(args); i += 2 {
switch args[i] {
case "-folder":
folderID = args[i+1]
case "-page":
page = args[i+1]
case "-page-size":
pageSize = args[i+1]
}
}
path := "/api/v1/files?"
if folderID != "" {
path += "folder_id=" + folderID + "&"
}
if page != "" {
path += "page=" + page + "&"
}
if pageSize != "" {
path += "page_size=" + pageSize + "&"
}
resp, err := doRequest(server, http.MethodGet, path, nil, "")
if err != nil {
fmt.Fprintln(os.Stderr, err)
os.Exit(1)
}
var list listFilesResponse
if err := json.Unmarshal(resp.Data, &list); err != nil {
fmt.Fprintf(os.Stderr, "parsing files response: %v\n", err)
os.Exit(1)
}
if len(list.Files) == 0 {
fmt.Println("No files found.")
return
}
fmt.Printf("%-8s %-30s %-12s %-10s %s\n", "ID", "Name", "Size", "MIME", "Created")
fmt.Println(strings.Repeat("-", 90))
for _, f := range list.Files {
name := f.Name
if len(name) > 28 {
name = name[:25] + "..."
}
fmt.Printf("%-8d %-30s %-12s %-10s %s\n", f.ID, name, formatSize(f.Size), f.MimeType, f.CreatedAt)
}
fmt.Printf("\nTotal: %d Page: %d PageSize: %d\n", list.Total, list.Page, list.PageSize)
}
func cmdDeleteFile(server string, args []string) {
if len(args) == 0 {
fmt.Fprintln(os.Stderr, "Usage: client -server <addr> delete-file <file_id>")
os.Exit(1)
}
fileID := args[0]
_, err := doRequest(server, http.MethodDelete, "/api/v1/files/"+fileID, nil, "")
if err != nil {
fmt.Fprintln(os.Stderr, err)
os.Exit(1)
}
fmt.Printf("File %s deleted.\n", fileID)
}
// ── Main ─────────────────────────────────────────────────────────────
func usage() {
fmt.Fprintf(os.Stderr, `HPC File Storage Client
Usage: client -server <addr> <command> [args]
Commands:
mkdir <name> [-parent <id>] 创建文件夹
ls-folders [-parent <id>] 列出文件夹
delete-folder <id> 删除文件夹
upload <file> [-folder <id>] [-chunk-size <bytes>] 上传文件(分片)
download <file_id> [-o <output>] 下载文件
ls-files [-folder <id>] [-page <n>] [-page-size <n>] 列出文件
delete-file <file_id> 删除文件
Global flags:
-server <addr> 服务器地址 (默认 http://localhost:8080)
`)
}
func main() {
server := "http://localhost:8080"
var command string
var cmdArgs []string
args := os.Args[1:]
i := 0
for i < len(args) {
if args[i] == "-server" && i+1 < len(args) {
server = args[i+1]
i += 2
continue
}
if command == "" {
command = args[i]
} else {
cmdArgs = append(cmdArgs, args[i])
}
i++
}
if command == "" {
usage()
os.Exit(1)
}
switch command {
case "mkdir":
cmdMkdir(server, cmdArgs)
case "ls-folders":
cmdListFolders(server, cmdArgs)
case "delete-folder":
cmdDeleteFolder(server, cmdArgs)
case "upload":
cmdUpload(server, cmdArgs)
case "download":
cmdDownload(server, cmdArgs)
case "ls-files":
cmdListFiles(server, cmdArgs)
case "delete-file":
cmdDeleteFile(server, cmdArgs)
default:
fmt.Fprintf(os.Stderr, "Unknown command: %s\n", command)
usage()
os.Exit(1)
}
}

773
cmd/server/file_test.go Normal file
View File

@@ -0,0 +1,773 @@
package main
import (
"bytes"
"context"
"crypto/sha256"
"encoding/hex"
"encoding/json"
"fmt"
"io"
"mime/multipart"
"net/http"
"net/http/httptest"
"sort"
"strings"
"sync"
"testing"
"gcy_hpc_server/internal/config"
"gcy_hpc_server/internal/handler"
"gcy_hpc_server/internal/model"
"gcy_hpc_server/internal/service"
"gcy_hpc_server/internal/storage"
"gcy_hpc_server/internal/store"
"github.com/gin-gonic/gin"
"go.uber.org/zap"
"gorm.io/driver/sqlite"
"gorm.io/gorm"
"gorm.io/gorm/logger"
)
// ---------------------------------------------------------------------------
// In-memory ObjectStorage mock
// ---------------------------------------------------------------------------
type inMemoryStorage struct {
mu sync.RWMutex
objects map[string][]byte
bucket string
}
var _ storage.ObjectStorage = (*inMemoryStorage)(nil)
func (s *inMemoryStorage) PutObject(_ context.Context, _, key string, reader io.Reader, _ int64, _ storage.PutObjectOptions) (storage.UploadInfo, error) {
data, err := io.ReadAll(reader)
if err != nil {
return storage.UploadInfo{}, fmt.Errorf("read all: %w", err)
}
s.mu.Lock()
s.objects[key] = data
s.mu.Unlock()
h := sha256.Sum256(data)
return storage.UploadInfo{ETag: hex.EncodeToString(h[:]), Size: int64(len(data))}, nil
}
func (s *inMemoryStorage) GetObject(_ context.Context, _, key string, opts storage.GetOptions) (io.ReadCloser, storage.ObjectInfo, error) {
s.mu.RLock()
data, ok := s.objects[key]
s.mu.RUnlock()
if !ok {
return nil, storage.ObjectInfo{}, fmt.Errorf("object %s not found", key)
}
size := int64(len(data))
start := int64(0)
end := size - 1
if opts.Start != nil {
start = *opts.Start
}
if opts.End != nil {
end = *opts.End
}
if end >= size {
end = size - 1
}
section := io.NewSectionReader(bytes.NewReader(data), start, end-start+1)
info := storage.ObjectInfo{Key: key, Size: size}
return io.NopCloser(section), info, nil
}
func (s *inMemoryStorage) ComposeObject(_ context.Context, _, dst string, sources []string) (storage.UploadInfo, error) {
s.mu.Lock()
defer s.mu.Unlock()
var buf bytes.Buffer
for _, src := range sources {
data, ok := s.objects[src]
if !ok {
return storage.UploadInfo{}, fmt.Errorf("source object %s not found", src)
}
buf.Write(data)
}
combined := buf.Bytes()
s.objects[dst] = combined
h := sha256.Sum256(combined)
return storage.UploadInfo{ETag: hex.EncodeToString(h[:]), Size: int64(len(combined))}, nil
}
func (s *inMemoryStorage) AbortMultipartUpload(_ context.Context, _, _, _ string) error {
return nil
}
func (s *inMemoryStorage) RemoveIncompleteUpload(_ context.Context, _, _ string) error {
return nil
}
func (s *inMemoryStorage) RemoveObject(_ context.Context, _, key string, _ storage.RemoveObjectOptions) error {
s.mu.Lock()
delete(s.objects, key)
s.mu.Unlock()
return nil
}
func (s *inMemoryStorage) ListObjects(_ context.Context, _, prefix string, _ bool) ([]storage.ObjectInfo, error) {
s.mu.RLock()
defer s.mu.RUnlock()
var result []storage.ObjectInfo
for k, v := range s.objects {
if strings.HasPrefix(k, prefix) {
result = append(result, storage.ObjectInfo{Key: k, Size: int64(len(v))})
}
}
sort.Slice(result, func(i, j int) bool { return result[i].Key < result[j].Key })
return result, nil
}
func (s *inMemoryStorage) RemoveObjects(_ context.Context, _ string, keys []string, _ storage.RemoveObjectsOptions) error {
s.mu.Lock()
for _, k := range keys {
delete(s.objects, k)
}
s.mu.Unlock()
return nil
}
func (s *inMemoryStorage) BucketExists(_ context.Context, _ string) (bool, error) {
return true, nil
}
func (s *inMemoryStorage) MakeBucket(_ context.Context, _ string, _ storage.MakeBucketOptions) error {
return nil
}
func (s *inMemoryStorage) StatObject(_ context.Context, _, key string, _ storage.StatObjectOptions) (storage.ObjectInfo, error) {
s.mu.RLock()
data, ok := s.objects[key]
s.mu.RUnlock()
if !ok {
return storage.ObjectInfo{}, fmt.Errorf("object %s not found", key)
}
return storage.ObjectInfo{Key: key, Size: int64(len(data))}, nil
}
// ---------------------------------------------------------------------------
// Test helpers
// ---------------------------------------------------------------------------
func setupFileTestRouter(t *testing.T) (*gin.Engine, *gorm.DB, *inMemoryStorage) {
t.Helper()
gin.SetMode(gin.TestMode)
dbName := fmt.Sprintf("file:%s?mode=memory&cache=shared", t.Name())
db, err := gorm.Open(sqlite.Open(dbName), &gorm.Config{Logger: logger.Default.LogMode(logger.Silent)})
if err != nil {
t.Fatal(err)
}
sqlDB, _ := db.DB()
sqlDB.SetMaxOpenConns(1)
db.AutoMigrate(&model.FileBlob{}, &model.File{}, &model.Folder{}, &model.UploadSession{}, &model.UploadChunk{})
memStore := &inMemoryStorage{objects: make(map[string][]byte)}
blobStore := store.NewBlobStore(db)
fileStore := store.NewFileStore(db)
folderStore := store.NewFolderStore(db)
uploadStore := store.NewUploadStore(db)
cfg := config.MinioConfig{
ChunkSize: 16 << 20,
MaxFileSize: 50 << 30,
MinChunkSize: 5 << 20,
SessionTTL: 48,
Bucket: "files",
}
uploadSvc := service.NewUploadService(memStore, blobStore, fileStore, uploadStore, cfg, db, zap.NewNop())
_ = service.NewDownloadService(memStore, blobStore, fileStore, "files", zap.NewNop())
folderSvc := service.NewFolderService(folderStore, fileStore, zap.NewNop())
fileSvc := service.NewFileService(memStore, blobStore, fileStore, "files", db, zap.NewNop())
uploadH := handler.NewUploadHandler(uploadSvc, zap.NewNop())
fileH := handler.NewFileHandler(fileSvc, zap.NewNop())
folderH := handler.NewFolderHandler(folderSvc, zap.NewNop())
r := gin.New()
r.Use(gin.Recovery())
v1 := r.Group("/api/v1")
files := v1.Group("/files")
uploads := files.Group("/uploads")
uploads.POST("", uploadH.InitUpload)
uploads.GET("/:id", uploadH.GetUploadStatus)
uploads.PUT("/:id/chunks/:index", uploadH.UploadChunk)
uploads.POST("/:id/complete", uploadH.CompleteUpload)
uploads.DELETE("/:id", uploadH.CancelUpload)
files.GET("", fileH.ListFiles)
files.GET("/:id", fileH.GetFile)
files.GET("/:id/download", fileH.DownloadFile)
files.DELETE("/:id", fileH.DeleteFile)
folders := files.Group("/folders")
folders.POST("", folderH.CreateFolder)
folders.GET("", folderH.ListFolders)
folders.GET("/:id", folderH.GetFolder)
folders.DELETE("/:id", folderH.DeleteFolder)
return r, db, memStore
}
// apiResponse mirrors server.APIResponse for decoding.
type apiResponse struct {
Success bool `json:"success"`
Data json.RawMessage `json:"data,omitempty"`
Error string `json:"error,omitempty"`
}
func decodeResponse(t *testing.T, w *httptest.ResponseRecorder) apiResponse {
t.Helper()
var resp apiResponse
if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
t.Fatalf("failed to decode response: %v, body: %s", err, w.Body.String())
}
return resp
}
func createChunkRequest(t *testing.T, url string, data []byte) *http.Request {
t.Helper()
var buf bytes.Buffer
writer := multipart.NewWriter(&buf)
part, err := writer.CreateFormFile("chunk", "chunk.bin")
if err != nil {
t.Fatal(err)
}
part.Write(data)
writer.Close()
req, err := http.NewRequest("PUT", url, &buf)
if err != nil {
t.Fatal(err)
}
req.Header.Set("Content-Type", writer.FormDataContentType())
return req
}
// helperUploadFile performs a full upload lifecycle: init → upload chunks → complete.
// Returns the file ID from the completed upload response.
func helperUploadFile(t *testing.T, router *gin.Engine, fileName string, fileData []byte, sha256Hash string, folderID *int64, chunkSize int64) int64 {
t.Helper()
fileSize := int64(len(fileData))
initBody := model.InitUploadRequest{
FileName: fileName,
FileSize: fileSize,
SHA256: sha256Hash,
FolderID: folderID,
ChunkSize: &chunkSize,
}
initJSON, _ := json.Marshal(initBody)
w := httptest.NewRecorder()
req, _ := http.NewRequest("POST", "/api/v1/files/uploads", bytes.NewReader(initJSON))
req.Header.Set("Content-Type", "application/json")
router.ServeHTTP(w, req)
resp := decodeResponse(t, w)
if !resp.Success {
t.Fatalf("init upload failed: %s", resp.Error)
}
if w.Code == http.StatusOK {
var fileResp model.FileResponse
if err := json.Unmarshal(resp.Data, &fileResp); err != nil {
t.Fatalf("decode dedup file response: %v", err)
}
return fileResp.ID
}
if w.Code != http.StatusCreated {
t.Fatalf("init upload: expected 201, got %d: %s", w.Code, w.Body.String())
}
var session model.UploadSessionResponse
if err := json.Unmarshal(resp.Data, &session); err != nil {
t.Fatalf("failed to decode session: %v", err)
}
totalChunks := session.TotalChunks
for i := 0; i < totalChunks; i++ {
start := int64(i) * chunkSize
end := start + chunkSize
if end > fileSize {
end = fileSize
}
chunkData := fileData[start:end]
url := fmt.Sprintf("/api/v1/files/uploads/%d/chunks/%d", session.ID, i)
cw := httptest.NewRecorder()
creq := createChunkRequest(t, url, chunkData)
router.ServeHTTP(cw, creq)
if cw.Code != http.StatusOK {
t.Fatalf("upload chunk %d: expected 200, got %d: %s", i, cw.Code, cw.Body.String())
}
}
w = httptest.NewRecorder()
req, _ = http.NewRequest("POST", fmt.Sprintf("/api/v1/files/uploads/%d/complete", session.ID), nil)
router.ServeHTTP(w, req)
if w.Code != http.StatusCreated {
t.Fatalf("complete upload: expected 201, got %d: %s", w.Code, w.Body.String())
}
resp = decodeResponse(t, w)
var fileResp model.FileResponse
if err := json.Unmarshal(resp.Data, &fileResp); err != nil {
t.Fatalf("failed to decode file response: %v", err)
}
return fileResp.ID
}
// ---------------------------------------------------------------------------
// Tests
// ---------------------------------------------------------------------------
func TestFileFullLifecycle(t *testing.T) {
router, _, _ := setupFileTestRouter(t)
// Create test file data
fileData := []byte("Hello, World! This is a test file for the full lifecycle integration test.")
fileSize := int64(len(fileData))
h := sha256.Sum256(fileData)
sha256Hash := hex.EncodeToString(h[:])
chunkSize := int64(5 << 20) // 5MB min chunk size
// 1. Init upload
initBody := model.InitUploadRequest{
FileName: "test.txt",
FileSize: fileSize,
SHA256: sha256Hash,
ChunkSize: &chunkSize,
}
initJSON, _ := json.Marshal(initBody)
w := httptest.NewRecorder()
req, _ := http.NewRequest("POST", "/api/v1/files/uploads", bytes.NewReader(initJSON))
req.Header.Set("Content-Type", "application/json")
router.ServeHTTP(w, req)
if w.Code != http.StatusCreated {
t.Fatalf("init upload: expected 201, got %d: %s", w.Code, w.Body.String())
}
resp := decodeResponse(t, w)
var session model.UploadSessionResponse
if err := json.Unmarshal(resp.Data, &session); err != nil {
t.Fatalf("failed to decode session: %v", err)
}
if session.Status != "pending" {
t.Fatalf("expected status pending, got %s", session.Status)
}
if session.TotalChunks != 1 {
t.Fatalf("expected 1 chunk, got %d", session.TotalChunks)
}
// 2. Upload chunk 0
url := fmt.Sprintf("/api/v1/files/uploads/%d/chunks/0", session.ID)
w = httptest.NewRecorder()
req = createChunkRequest(t, url, fileData)
router.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("upload chunk: expected 200, got %d: %s", w.Code, w.Body.String())
}
// 3. Complete upload
w = httptest.NewRecorder()
req, _ = http.NewRequest("POST", fmt.Sprintf("/api/v1/files/uploads/%d/complete", session.ID), nil)
router.ServeHTTP(w, req)
if w.Code != http.StatusCreated {
t.Fatalf("complete upload: expected 201, got %d: %s", w.Code, w.Body.String())
}
resp = decodeResponse(t, w)
var fileResp model.FileResponse
if err := json.Unmarshal(resp.Data, &fileResp); err != nil {
t.Fatalf("failed to decode file response: %v", err)
}
if fileResp.Name != "test.txt" {
t.Fatalf("expected name test.txt, got %s", fileResp.Name)
}
// 4. Download file
w = httptest.NewRecorder()
req, _ = http.NewRequest("GET", fmt.Sprintf("/api/v1/files/%d/download", fileResp.ID), nil)
router.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("download: expected 200, got %d: %s", w.Code, w.Body.String())
}
if !bytes.Equal(w.Body.Bytes(), fileData) {
t.Fatalf("downloaded data mismatch: got %q, want %q", w.Body.String(), string(fileData))
}
// 5. Delete file
w = httptest.NewRecorder()
req, _ = http.NewRequest("DELETE", fmt.Sprintf("/api/v1/files/%d", fileResp.ID), nil)
router.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("delete: expected 200, got %d: %s", w.Code, w.Body.String())
}
// 6. Verify file is gone (download should fail)
w = httptest.NewRecorder()
req, _ = http.NewRequest("GET", fmt.Sprintf("/api/v1/files/%d/download", fileResp.ID), nil)
router.ServeHTTP(w, req)
if w.Code == http.StatusOK {
t.Fatal("expected download to fail after delete, got 200")
}
}
func TestFileDedup(t *testing.T) {
router, db, _ := setupFileTestRouter(t)
fileData := []byte("Duplicate content for dedup test.")
h := sha256.Sum256(fileData)
sha256Hash := hex.EncodeToString(h[:])
chunkSize := int64(5 << 20)
// Upload file A (full lifecycle)
fileAID := helperUploadFile(t, router, "fileA.txt", fileData, sha256Hash, nil, chunkSize)
if fileAID == 0 {
t.Fatal("file A ID should not be 0")
}
// Upload file B with same SHA256 → should be instant dedup
fileBID := helperUploadFile(t, router, "fileB.txt", fileData, sha256Hash, nil, chunkSize)
if fileBID == 0 {
t.Fatal("file B ID should not be 0")
}
if fileBID == fileAID {
t.Fatal("file B should have a different ID than file A")
}
// Verify blob ref_count = 2
var blob model.FileBlob
if err := db.Where("sha256 = ?", sha256Hash).First(&blob).Error; err != nil {
t.Fatalf("blob not found: %v", err)
}
if blob.RefCount != 2 {
t.Fatalf("expected ref_count 2, got %d", blob.RefCount)
}
// Both files should be downloadable
for _, id := range []int64{fileAID, fileBID} {
w := httptest.NewRecorder()
req, _ := http.NewRequest("GET", fmt.Sprintf("/api/v1/files/%d/download", id), nil)
router.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("download file %d: expected 200, got %d", id, w.Code)
}
if !bytes.Equal(w.Body.Bytes(), fileData) {
t.Fatalf("downloaded data mismatch for file %d", id)
}
}
// Delete file A — file B should still be downloadable
w := httptest.NewRecorder()
req, _ := http.NewRequest("DELETE", fmt.Sprintf("/api/v1/files/%d", fileAID), nil)
router.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("delete file A: expected 200, got %d: %s", w.Code, w.Body.String())
}
// Verify blob still exists with ref_count = 1
if err := db.Where("sha256 = ?", sha256Hash).First(&blob).Error; err != nil {
t.Fatalf("blob should still exist: %v", err)
}
if blob.RefCount != 2 {
t.Fatalf("expected ref_count 2 (blob still shared), got %d", blob.RefCount)
}
// File B still downloadable
w = httptest.NewRecorder()
req, _ = http.NewRequest("GET", fmt.Sprintf("/api/v1/files/%d/download", fileBID), nil)
router.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("file B should still be downloadable, got %d", w.Code)
}
}
func TestFileResumeUpload(t *testing.T) {
router, _, _ := setupFileTestRouter(t)
// Create data that spans 2 chunks (min chunk = 5MB, use that)
chunkSize := int64(5 << 20) // 5MB
data1 := bytes.Repeat([]byte("A"), int(chunkSize))
data2 := bytes.Repeat([]byte("B"), int(chunkSize))
fileData := append(data1, data2...)
fileSize := int64(len(fileData))
h := sha256.Sum256(fileData)
sha256Hash := hex.EncodeToString(h[:])
// 1. Init upload
initBody := model.InitUploadRequest{
FileName: "resume.bin",
FileSize: fileSize,
SHA256: sha256Hash,
ChunkSize: &chunkSize,
}
initJSON, _ := json.Marshal(initBody)
w := httptest.NewRecorder()
req, _ := http.NewRequest("POST", "/api/v1/files/uploads", bytes.NewReader(initJSON))
req.Header.Set("Content-Type", "application/json")
router.ServeHTTP(w, req)
if w.Code != http.StatusCreated {
t.Fatalf("init: expected 201, got %d: %s", w.Code, w.Body.String())
}
resp := decodeResponse(t, w)
var session model.UploadSessionResponse
if err := json.Unmarshal(resp.Data, &session); err != nil {
t.Fatalf("decode session: %v", err)
}
if session.TotalChunks != 2 {
t.Fatalf("expected 2 chunks, got %d", session.TotalChunks)
}
// 2. Upload chunk 0 only
w = httptest.NewRecorder()
req = createChunkRequest(t, fmt.Sprintf("/api/v1/files/uploads/%d/chunks/0", session.ID), data1)
router.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("upload chunk 0: expected 200, got %d: %s", w.Code, w.Body.String())
}
// 3. Get status — should show chunk 0 uploaded
w = httptest.NewRecorder()
req, _ = http.NewRequest("GET", fmt.Sprintf("/api/v1/files/uploads/%d", session.ID), nil)
router.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("get status: expected 200, got %d: %s", w.Code, w.Body.String())
}
resp = decodeResponse(t, w)
var status model.UploadSessionResponse
if err := json.Unmarshal(resp.Data, &status); err != nil {
t.Fatalf("decode status: %v", err)
}
if len(status.UploadedChunks) != 1 || status.UploadedChunks[0] != 0 {
t.Fatalf("expected uploaded_chunks=[0], got %v", status.UploadedChunks)
}
// 4. Upload chunk 1 (resume)
w = httptest.NewRecorder()
req = createChunkRequest(t, fmt.Sprintf("/api/v1/files/uploads/%d/chunks/1", session.ID), data2)
router.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("upload chunk 1: expected 200, got %d: %s", w.Code, w.Body.String())
}
// 5. Complete
w = httptest.NewRecorder()
req, _ = http.NewRequest("POST", fmt.Sprintf("/api/v1/files/uploads/%d/complete", session.ID), nil)
router.ServeHTTP(w, req)
if w.Code != http.StatusCreated {
t.Fatalf("complete: expected 201, got %d: %s", w.Code, w.Body.String())
}
resp = decodeResponse(t, w)
var fileResp model.FileResponse
if err := json.Unmarshal(resp.Data, &fileResp); err != nil {
t.Fatalf("decode file response: %v", err)
}
// 6. Download and verify
w = httptest.NewRecorder()
req, _ = http.NewRequest("GET", fmt.Sprintf("/api/v1/files/%d/download", fileResp.ID), nil)
router.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("download: expected 200, got %d: %s", w.Code, w.Body.String())
}
if !bytes.Equal(w.Body.Bytes(), fileData) {
t.Fatal("downloaded data does not match original")
}
}
func TestFileFolderOperations(t *testing.T) {
router, _, _ := setupFileTestRouter(t)
// 1. Create folder
folderBody := model.CreateFolderRequest{Name: "test-folder"}
folderJSON, _ := json.Marshal(folderBody)
w := httptest.NewRecorder()
req, _ := http.NewRequest("POST", "/api/v1/files/folders", bytes.NewReader(folderJSON))
req.Header.Set("Content-Type", "application/json")
router.ServeHTTP(w, req)
if w.Code != http.StatusCreated {
t.Fatalf("create folder: expected 201, got %d: %s", w.Code, w.Body.String())
}
resp := decodeResponse(t, w)
var folderResp model.FolderResponse
if err := json.Unmarshal(resp.Data, &folderResp); err != nil {
t.Fatalf("decode folder: %v", err)
}
if folderResp.Name != "test-folder" {
t.Fatalf("expected name test-folder, got %s", folderResp.Name)
}
if folderResp.Path != "/test-folder/" {
t.Fatalf("expected path /test-folder/, got %s", folderResp.Path)
}
// 2. Upload a file into the folder
fileData := []byte("File inside folder")
h := sha256.Sum256(fileData)
sha256Hash := hex.EncodeToString(h[:])
chunkSize := int64(5 << 20)
fileID := helperUploadFile(t, router, "folder_file.txt", fileData, sha256Hash, &folderResp.ID, chunkSize)
if fileID == 0 {
t.Fatal("file ID should not be 0")
}
// 3. List files in folder
w = httptest.NewRecorder()
req, _ = http.NewRequest("GET", fmt.Sprintf("/api/v1/files?folder_id=%d", folderResp.ID), nil)
router.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("list files: expected 200, got %d: %s", w.Code, w.Body.String())
}
resp = decodeResponse(t, w)
var listResp model.ListFilesResponse
if err := json.Unmarshal(resp.Data, &listResp); err != nil {
t.Fatalf("decode list: %v", err)
}
if listResp.Total != 1 {
t.Fatalf("expected 1 file in folder, got %d", listResp.Total)
}
if listResp.Files[0].Name != "folder_file.txt" {
t.Fatalf("expected file name folder_file.txt, got %s", listResp.Files[0].Name)
}
// 4. List folders (root)
w = httptest.NewRecorder()
req, _ = http.NewRequest("GET", "/api/v1/files/folders", nil)
router.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("list folders: expected 200, got %d: %s", w.Code, w.Body.String())
}
// 5. Try delete folder (should fail — not empty)
w = httptest.NewRecorder()
req, _ = http.NewRequest("DELETE", fmt.Sprintf("/api/v1/files/folders/%d", folderResp.ID), nil)
router.ServeHTTP(w, req)
if w.Code != http.StatusBadRequest {
t.Fatalf("delete non-empty folder: expected 400, got %d: %s", w.Code, w.Body.String())
}
// 6. Delete the file first
w = httptest.NewRecorder()
req, _ = http.NewRequest("DELETE", fmt.Sprintf("/api/v1/files/%d", fileID), nil)
router.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("delete file: expected 200, got %d: %s", w.Code, w.Body.String())
}
// 7. Now delete empty folder
w = httptest.NewRecorder()
req, _ = http.NewRequest("DELETE", fmt.Sprintf("/api/v1/files/folders/%d", folderResp.ID), nil)
router.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("delete empty folder: expected 200, got %d: %s", w.Code, w.Body.String())
}
}
func TestFileRangeDownload(t *testing.T) {
router, _, _ := setupFileTestRouter(t)
fileData := []byte("0123456789ABCDEF") // 16 bytes
h := sha256.Sum256(fileData)
sha256Hash := hex.EncodeToString(h[:])
chunkSize := int64(5 << 20)
fileID := helperUploadFile(t, router, "range_test.bin", fileData, sha256Hash, nil, chunkSize)
// Download range bytes=4-9 → "456789"
w := httptest.NewRecorder()
req, _ := http.NewRequest("GET", fmt.Sprintf("/api/v1/files/%d/download", fileID), nil)
req.Header.Set("Range", "bytes=4-9")
router.ServeHTTP(w, req)
if w.Code != http.StatusPartialContent {
t.Fatalf("range download: expected 206, got %d: %s", w.Code, w.Body.String())
}
contentRange := w.Header().Get("Content-Range")
expectedRange := fmt.Sprintf("bytes 4-9/%d", len(fileData))
if contentRange != expectedRange {
t.Fatalf("content-range: expected %q, got %q", expectedRange, contentRange)
}
if !bytes.Equal(w.Body.Bytes(), []byte("456789")) {
t.Fatalf("range content: expected '456789', got %q", w.Body.String())
}
// Full download still works
w = httptest.NewRecorder()
req, _ = http.NewRequest("GET", fmt.Sprintf("/api/v1/files/%d/download", fileID), nil)
router.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("full download: expected 200, got %d", w.Code)
}
if !bytes.Equal(w.Body.Bytes(), fileData) {
t.Fatalf("full download data mismatch")
}
}
func TestFileDeleteOneRefOtherStillDownloadable(t *testing.T) {
router, _, _ := setupFileTestRouter(t)
fileData := []byte("Shared blob content for multi-ref test")
h := sha256.Sum256(fileData)
sha256Hash := hex.EncodeToString(h[:])
chunkSize := int64(5 << 20)
// Upload file A
fileAID := helperUploadFile(t, router, "refA.txt", fileData, sha256Hash, nil, chunkSize)
// Upload file B with same content → dedup instant
fileBID := helperUploadFile(t, router, "refB.txt", fileData, sha256Hash, nil, chunkSize)
// Delete file A
w := httptest.NewRecorder()
req, _ := http.NewRequest("DELETE", fmt.Sprintf("/api/v1/files/%d", fileAID), nil)
router.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("delete file A: expected 200, got %d: %s", w.Code, w.Body.String())
}
// Verify file A is gone
w = httptest.NewRecorder()
req, _ = http.NewRequest("GET", fmt.Sprintf("/api/v1/files/%d", fileAID), nil)
router.ServeHTTP(w, req)
// File is soft-deleted, but GetByID may still find it depending on soft-delete handling
// The important check is that file B is still downloadable
// File B should still be downloadable
w = httptest.NewRecorder()
req, _ = http.NewRequest("GET", fmt.Sprintf("/api/v1/files/%d/download", fileBID), nil)
router.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("file B should still be downloadable after deleting file A, got %d: %s", w.Code, w.Body.String())
}
if !bytes.Equal(w.Body.Bytes(), fileData) {
t.Fatal("file B download data mismatch")
}
// Delete file B — now blob should be fully removed
w = httptest.NewRecorder()
req, _ = http.NewRequest("DELETE", fmt.Sprintf("/api/v1/files/%d", fileBID), nil)
router.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("delete file B: expected 200, got %d: %s", w.Code, w.Body.String())
}
// File B should now be gone
w = httptest.NewRecorder()
req, _ = http.NewRequest("GET", fmt.Sprintf("/api/v1/files/%d/download", fileBID), nil)
router.ServeHTTP(w, req)
if w.Code == http.StatusOK {
t.Fatal("file B should not be downloadable after deletion")
}
}

View File

@@ -0,0 +1,257 @@
package main
import (
"bytes"
"encoding/json"
"fmt"
"io"
"net/http"
"testing"
"gcy_hpc_server/internal/testutil/testenv"
)
// appListData mirrors the list endpoint response data structure.
type appListData struct {
Applications []json.RawMessage `json:"applications"`
Total int64 `json:"total"`
Page int `json:"page"`
PageSize int `json:"page_size"`
}
// appCreatedData mirrors the create endpoint response data structure.
type appCreatedData struct {
ID int64 `json:"id"`
}
// appMessageData mirrors the update/delete endpoint response data structure.
type appMessageData struct {
Message string `json:"message"`
}
// appData mirrors the application model returned by GET.
type appData struct {
ID int64 `json:"id"`
Name string `json:"name"`
ScriptTemplate string `json:"script_template"`
Parameters json.RawMessage `json:"parameters,omitempty"`
}
// appDoRequest is a small wrapper that marshals body and calls env.DoRequest.
func appDoRequest(env *testenv.TestEnv, method, path string, body interface{}) *http.Response {
var r io.Reader
if body != nil {
b, err := json.Marshal(body)
if err != nil {
panic(fmt.Sprintf("appDoRequest marshal: %v", err))
}
r = bytes.NewReader(b)
}
return env.DoRequest(method, path, r)
}
// appDecodeAll decodes the response and also reads the HTTP status.
func appDecodeAll(env *testenv.TestEnv, resp *http.Response) (statusCode int, success bool, data json.RawMessage, err error) {
statusCode = resp.StatusCode
success, data, err = env.DecodeResponse(resp)
return
}
// appSeedApp creates an app via the service (bypasses HTTP) and returns its ID.
func appSeedApp(env *testenv.TestEnv, name string) int64 {
id, err := env.CreateApp(name, "#!/bin/bash\necho hello", json.RawMessage(`[]`))
if err != nil {
panic(fmt.Sprintf("appSeedApp: %v", err))
}
return id
}
// TestIntegration_App_List verifies GET /api/v1/applications returns an empty list initially.
func TestIntegration_App_List(t *testing.T) {
env := testenv.NewTestEnv(t)
resp := env.DoRequest(http.MethodGet, "/api/v1/applications", nil)
status, success, data, err := appDecodeAll(env, resp)
if err != nil {
t.Fatalf("decode response: %v", err)
}
if status != http.StatusOK {
t.Fatalf("expected status 200, got %d", status)
}
if !success {
t.Fatal("expected success=true")
}
var list appListData
if err := json.Unmarshal(data, &list); err != nil {
t.Fatalf("unmarshal list data: %v", err)
}
if list.Total != 0 {
t.Fatalf("expected total=0, got %d", list.Total)
}
if len(list.Applications) != 0 {
t.Fatalf("expected 0 applications, got %d", len(list.Applications))
}
}
// TestIntegration_App_Create verifies POST /api/v1/applications creates an application.
func TestIntegration_App_Create(t *testing.T) {
env := testenv.NewTestEnv(t)
body := map[string]interface{}{
"name": "test-app-create",
"script_template": "#!/bin/bash\necho hello",
"parameters": []interface{}{},
}
resp := appDoRequest(env, http.MethodPost, "/api/v1/applications", body)
status, success, data, err := appDecodeAll(env, resp)
if err != nil {
t.Fatalf("decode response: %v", err)
}
if status != http.StatusCreated {
t.Fatalf("expected status 201, got %d", status)
}
if !success {
t.Fatal("expected success=true")
}
var created appCreatedData
if err := json.Unmarshal(data, &created); err != nil {
t.Fatalf("unmarshal created data: %v", err)
}
if created.ID <= 0 {
t.Fatalf("expected positive id, got %d", created.ID)
}
}
// TestIntegration_App_Get verifies GET /api/v1/applications/:id returns the correct application.
func TestIntegration_App_Get(t *testing.T) {
env := testenv.NewTestEnv(t)
id := appSeedApp(env, "test-app-get")
path := fmt.Sprintf("/api/v1/applications/%d", id)
resp := env.DoRequest(http.MethodGet, path, nil)
status, success, data, err := appDecodeAll(env, resp)
if err != nil {
t.Fatalf("decode response: %v", err)
}
if status != http.StatusOK {
t.Fatalf("expected status 200, got %d", status)
}
if !success {
t.Fatal("expected success=true")
}
var app appData
if err := json.Unmarshal(data, &app); err != nil {
t.Fatalf("unmarshal app data: %v", err)
}
if app.ID != id {
t.Fatalf("expected id=%d, got %d", id, app.ID)
}
if app.Name != "test-app-get" {
t.Fatalf("expected name=test-app-get, got %s", app.Name)
}
}
// TestIntegration_App_Update verifies PUT /api/v1/applications/:id updates an application.
func TestIntegration_App_Update(t *testing.T) {
env := testenv.NewTestEnv(t)
id := appSeedApp(env, "test-app-update-before")
newName := "test-app-update-after"
body := map[string]interface{}{
"name": newName,
}
path := fmt.Sprintf("/api/v1/applications/%d", id)
resp := appDoRequest(env, http.MethodPut, path, body)
status, success, data, err := appDecodeAll(env, resp)
if err != nil {
t.Fatalf("decode response: %v", err)
}
if status != http.StatusOK {
t.Fatalf("expected status 200, got %d", status)
}
if !success {
t.Fatal("expected success=true")
}
var msg appMessageData
if err := json.Unmarshal(data, &msg); err != nil {
t.Fatalf("unmarshal message data: %v", err)
}
if msg.Message != "application updated" {
t.Fatalf("expected message 'application updated', got %q", msg.Message)
}
getResp := env.DoRequest(http.MethodGet, path, nil)
_, _, getData, gErr := appDecodeAll(env, getResp)
if gErr != nil {
t.Fatalf("decode get response: %v", gErr)
}
var updated appData
if err := json.Unmarshal(getData, &updated); err != nil {
t.Fatalf("unmarshal updated app: %v", err)
}
if updated.Name != newName {
t.Fatalf("expected updated name=%q, got %q", newName, updated.Name)
}
}
// TestIntegration_App_Delete verifies DELETE /api/v1/applications/:id removes an application.
func TestIntegration_App_Delete(t *testing.T) {
env := testenv.NewTestEnv(t)
id := appSeedApp(env, "test-app-delete")
path := fmt.Sprintf("/api/v1/applications/%d", id)
resp := env.DoRequest(http.MethodDelete, path, nil)
status, success, data, err := appDecodeAll(env, resp)
if err != nil {
t.Fatalf("decode response: %v", err)
}
if status != http.StatusOK {
t.Fatalf("expected status 200, got %d", status)
}
if !success {
t.Fatal("expected success=true")
}
var msg appMessageData
if err := json.Unmarshal(data, &msg); err != nil {
t.Fatalf("unmarshal message data: %v", err)
}
if msg.Message != "application deleted" {
t.Fatalf("expected message 'application deleted', got %q", msg.Message)
}
// Verify deletion returns 404.
getResp := env.DoRequest(http.MethodGet, path, nil)
getStatus, getSuccess, _, _ := appDecodeAll(env, getResp)
if getStatus != http.StatusNotFound {
t.Fatalf("expected status 404 after delete, got %d", getStatus)
}
if getSuccess {
t.Fatal("expected success=false after delete")
}
}
// TestIntegration_App_CreateValidation verifies POST /api/v1/applications with empty name returns error.
func TestIntegration_App_CreateValidation(t *testing.T) {
env := testenv.NewTestEnv(t)
body := map[string]interface{}{
"name": "",
"script_template": "#!/bin/bash\necho hello",
}
resp := appDoRequest(env, http.MethodPost, "/api/v1/applications", body)
status, success, _, err := appDecodeAll(env, resp)
if err != nil {
t.Fatalf("decode response: %v", err)
}
if status != http.StatusBadRequest {
t.Fatalf("expected status 400, got %d", status)
}
if success {
t.Fatal("expected success=false for validation error")
}
}

View File

@@ -0,0 +1,186 @@
package main
import (
"encoding/json"
"net/http"
"testing"
"gcy_hpc_server/internal/testutil/testenv"
)
// clusterNodeData mirrors the NodeResponse DTO returned by the API.
type clusterNodeData struct {
Name string `json:"name"`
State []string `json:"state"`
CPUs int32 `json:"cpus"`
RealMemory int64 `json:"real_memory"`
}
// clusterPartitionData mirrors the PartitionResponse DTO returned by the API.
type clusterPartitionData struct {
Name string `json:"name"`
State []string `json:"state"`
TotalNodes int32 `json:"total_nodes,omitempty"`
TotalCPUs int32 `json:"total_cpus,omitempty"`
}
// clusterDiagStat mirrors a single entry from the diag statistics.
type clusterDiagStat struct {
Parts []struct {
Param string `json:"param"`
} `json:"parts,omitempty"`
}
// clusterDecodeAll decodes the response and returns status, success, and raw data.
func clusterDecodeAll(env *testenv.TestEnv, resp *http.Response) (statusCode int, success bool, data json.RawMessage, err error) {
statusCode = resp.StatusCode
success, data, err = env.DecodeResponse(resp)
return
}
// TestIntegration_Cluster_Nodes verifies GET /api/v1/nodes returns the 3 pre-loaded mock nodes.
func TestIntegration_Cluster_Nodes(t *testing.T) {
env := testenv.NewTestEnv(t)
resp := env.DoRequest(http.MethodGet, "/api/v1/nodes", nil)
status, success, data, err := clusterDecodeAll(env, resp)
if err != nil {
t.Fatalf("decode response: %v", err)
}
if status != http.StatusOK {
t.Fatalf("expected status 200, got %d", status)
}
if !success {
t.Fatal("expected success=true")
}
var nodes []clusterNodeData
if err := json.Unmarshal(data, &nodes); err != nil {
t.Fatalf("unmarshal nodes: %v", err)
}
if len(nodes) != 3 {
t.Fatalf("expected 3 nodes, got %d", len(nodes))
}
names := make(map[string]bool, len(nodes))
for _, n := range nodes {
names[n.Name] = true
}
for _, expected := range []string{"node01", "node02", "node03"} {
if !names[expected] {
t.Errorf("missing expected node %q", expected)
}
}
}
// TestIntegration_Cluster_NodeByName verifies GET /api/v1/nodes/:name returns a single node.
func TestIntegration_Cluster_NodeByName(t *testing.T) {
env := testenv.NewTestEnv(t)
resp := env.DoRequest(http.MethodGet, "/api/v1/nodes/node01", nil)
status, success, data, err := clusterDecodeAll(env, resp)
if err != nil {
t.Fatalf("decode response: %v", err)
}
if status != http.StatusOK {
t.Fatalf("expected status 200, got %d", status)
}
if !success {
t.Fatal("expected success=true")
}
var node clusterNodeData
if err := json.Unmarshal(data, &node); err != nil {
t.Fatalf("unmarshal node: %v", err)
}
if node.Name != "node01" {
t.Fatalf("expected name=node01, got %q", node.Name)
}
}
// TestIntegration_Cluster_Partitions verifies GET /api/v1/partitions returns the 2 pre-loaded partitions.
func TestIntegration_Cluster_Partitions(t *testing.T) {
env := testenv.NewTestEnv(t)
resp := env.DoRequest(http.MethodGet, "/api/v1/partitions", nil)
status, success, data, err := clusterDecodeAll(env, resp)
if err != nil {
t.Fatalf("decode response: %v", err)
}
if status != http.StatusOK {
t.Fatalf("expected status 200, got %d", status)
}
if !success {
t.Fatal("expected success=true")
}
var partitions []clusterPartitionData
if err := json.Unmarshal(data, &partitions); err != nil {
t.Fatalf("unmarshal partitions: %v", err)
}
if len(partitions) != 2 {
t.Fatalf("expected 2 partitions, got %d", len(partitions))
}
names := make(map[string]bool, len(partitions))
for _, p := range partitions {
names[p.Name] = true
}
if !names["normal"] {
t.Error("missing expected partition \"normal\"")
}
if !names["gpu"] {
t.Error("missing expected partition \"gpu\"")
}
}
// TestIntegration_Cluster_PartitionByName verifies GET /api/v1/partitions/:name returns a single partition.
func TestIntegration_Cluster_PartitionByName(t *testing.T) {
env := testenv.NewTestEnv(t)
resp := env.DoRequest(http.MethodGet, "/api/v1/partitions/normal", nil)
status, success, data, err := clusterDecodeAll(env, resp)
if err != nil {
t.Fatalf("decode response: %v", err)
}
if status != http.StatusOK {
t.Fatalf("expected status 200, got %d", status)
}
if !success {
t.Fatal("expected success=true")
}
var part clusterPartitionData
if err := json.Unmarshal(data, &part); err != nil {
t.Fatalf("unmarshal partition: %v", err)
}
if part.Name != "normal" {
t.Fatalf("expected name=normal, got %q", part.Name)
}
}
// TestIntegration_Cluster_Diag verifies GET /api/v1/diag returns diagnostics data.
func TestIntegration_Cluster_Diag(t *testing.T) {
env := testenv.NewTestEnv(t)
resp := env.DoRequest(http.MethodGet, "/api/v1/diag", nil)
status, success, data, err := clusterDecodeAll(env, resp)
if err != nil {
t.Fatalf("decode response: %v", err)
}
if status != http.StatusOK {
t.Fatalf("expected status 200, got %d", status)
}
if !success {
t.Fatal("expected success=true")
}
// Verify the response contains a "statistics" field (non-empty JSON object).
var raw map[string]json.RawMessage
if err := json.Unmarshal(data, &raw); err != nil {
t.Fatalf("unmarshal diag top-level: %v", err)
}
if _, ok := raw["statistics"]; !ok {
t.Fatal("diag response missing \"statistics\" field")
}
}

View File

@@ -0,0 +1,202 @@
package main
import (
"encoding/json"
"fmt"
"io"
"net/http"
"strings"
"testing"
"time"
"gcy_hpc_server/internal/testutil/testenv"
)
// e2eResponse mirrors the unified API response structure.
type e2eResponse struct {
Success bool `json:"success"`
Data json.RawMessage `json:"data,omitempty"`
Error string `json:"error,omitempty"`
}
// e2eTaskCreatedData mirrors the POST /api/v1/tasks response data.
type e2eTaskCreatedData struct {
ID int64 `json:"id"`
}
// e2eTaskItem mirrors a single task in the list response.
type e2eTaskItem struct {
ID int64 `json:"id"`
TaskName string `json:"task_name"`
Status string `json:"status"`
WorkDir string `json:"work_dir"`
ErrorMessage string `json:"error_message"`
}
// e2eTaskListData mirrors the list endpoint response data.
type e2eTaskListData struct {
Items []e2eTaskItem `json:"items"`
Total int64 `json:"total"`
}
// e2eSendRequest sends an HTTP request via the test env and returns the response.
func e2eSendRequest(env *testenv.TestEnv, method, path string, body string) *http.Response {
var r io.Reader
if body != "" {
r = strings.NewReader(body)
}
return env.DoRequest(method, path, r)
}
// e2eParseResponse decodes an HTTP response into e2eResponse.
func e2eParseResponse(resp *http.Response) (int, e2eResponse) {
b, err := io.ReadAll(resp.Body)
if err != nil {
panic(fmt.Sprintf("e2eParseResponse read: %v", err))
}
resp.Body.Close()
var result e2eResponse
if err := json.Unmarshal(b, &result); err != nil {
panic(fmt.Sprintf("e2eParseResponse unmarshal: %v (body: %s)", err, string(b)))
}
return resp.StatusCode, result
}
// TestIntegration_E2E_CompleteWorkflow verifies the full lifecycle:
// create app → upload file → submit task → queued → running → completed.
func TestIntegration_E2E_CompleteWorkflow(t *testing.T) {
t.Log("========== E2E 全链路测试开始 ==========")
t.Log("")
env := testenv.NewTestEnv(t)
t.Log("✓ 测试环境创建完成 (SQLite + MockSlurm + MockMinIO + Router + Poller)")
t.Log("")
// Step 1: Create Application with script template and parameters.
t.Log("【步骤 1】创建应用")
appID, err := env.CreateApp("e2e-app", "#!/bin/bash\necho {{.np}}",
json.RawMessage(`[{"name":"np","type":"string","default":"1"}]`))
if err != nil {
t.Fatalf("step 1 create app: %v", err)
}
t.Logf(" → 应用创建成功, appID=%d, 脚本模板='#!/bin/bash echo {{.np}}', 参数=[np]", appID)
t.Log("")
// Step 2: Upload input file.
t.Log("【步骤 2】上传输入文件")
fileID, _ := env.UploadTestData("input.txt", []byte("test input data"))
t.Logf(" → 文件上传成功, fileID=%d, 内容='test input data' (存入 MockMinIO + SQLite)", fileID)
t.Log("")
// Step 3: Submit Task via API.
t.Log("【步骤 3】通过 HTTP API 提交任务")
body := fmt.Sprintf(
`{"app_id": %d, "task_name": "e2e-task", "values": {"np": "4"}, "file_ids": [%d]}`,
appID, fileID,
)
t.Logf(" → POST /api/v1/tasks body=%s", body)
resp := e2eSendRequest(env, http.MethodPost, "/api/v1/tasks", body)
status, result := e2eParseResponse(resp)
if status != http.StatusCreated {
t.Fatalf("step 3 submit task: status=%d, success=%v, error=%q", status, result.Success, result.Error)
}
var created e2eTaskCreatedData
if err := json.Unmarshal(result.Data, &created); err != nil {
t.Fatalf("step 3 parse task id: %v", err)
}
taskID := created.ID
if taskID <= 0 {
t.Fatalf("step 3: expected positive task id, got %d", taskID)
}
t.Logf(" → HTTP 201 Created, taskID=%d", taskID)
t.Log("")
// Step 4: Wait for queued status.
t.Log("【步骤 4】等待 TaskProcessor 异步提交到 MockSlurm")
t.Log(" → 后台流程: submitted → preparing → downloading → ready → queued")
if err := env.WaitForTaskStatus(taskID, "queued", 5*time.Second); err != nil {
taskStatus, _ := e2eFetchTaskStatus(env, taskID)
t.Fatalf("step 4 wait for queued: %v (current status via API: %q)", err, taskStatus)
}
t.Logf(" → 任务状态变为 'queued' (TaskProcessor 已提交到 Slurm)")
t.Log("")
// Step 5: Get slurmJobID.
t.Log("【步骤 5】查询数据库获取 Slurm Job ID")
slurmJobID, err := env.GetTaskSlurmJobID(taskID)
if err != nil {
t.Fatalf("step 5 get slurm job id: %v", err)
}
t.Logf(" → slurmJobID=%d (MockSlurm 中的作业号)", slurmJobID)
t.Log("")
// Step 6: Transition to RUNNING.
t.Log("【步骤 6】模拟 Slurm: 作业开始运行")
t.Logf(" → MockSlurm.SetJobState(%d, 'RUNNING')", slurmJobID)
env.MockSlurm.SetJobState(slurmJobID, "RUNNING")
t.Logf(" → MakeTaskStale(%d) — 绕过 30s 等待,让 poller 立即刷新", taskID)
if err := env.MakeTaskStale(taskID); err != nil {
t.Fatalf("step 6 make task stale: %v", err)
}
if err := env.WaitForTaskStatus(taskID, "running", 5*time.Second); err != nil {
taskStatus, _ := e2eFetchTaskStatus(env, taskID)
t.Fatalf("step 6 wait for running: %v (current status via API: %q)", err, taskStatus)
}
t.Logf(" → 任务状态变为 'running'")
t.Log("")
// Step 7: Transition to COMPLETED — job evicted from activeJobs to historyJobs.
t.Log("【步骤 7】模拟 Slurm: 作业运行完成")
t.Logf(" → MockSlurm.SetJobState(%d, 'COMPLETED') — 作业从 activeJobs 淘汰到 historyJobs", slurmJobID)
env.MockSlurm.SetJobState(slurmJobID, "COMPLETED")
t.Log(" → MakeTaskStale + WaitForTaskStatus...")
if err := env.MakeTaskStale(taskID); err != nil {
t.Fatalf("step 7 make task stale: %v", err)
}
if err := env.WaitForTaskStatus(taskID, "completed", 5*time.Second); err != nil {
taskStatus, _ := e2eFetchTaskStatus(env, taskID)
t.Fatalf("step 7 wait for completed: %v (current status via API: %q)", err, taskStatus)
}
t.Logf(" → 任务状态变为 'completed' (通过 SlurmDB 历史回退路径获取)")
t.Log("")
// Step 8: Verify final state via GET /api/v1/tasks.
t.Log("【步骤 8】通过 HTTP API 验证最终状态")
finalStatus, finalItem := e2eFetchTaskStatus(env, taskID)
if finalStatus != "completed" {
t.Fatalf("step 8: expected status completed, got %q (error: %q)", finalStatus, finalItem.ErrorMessage)
}
t.Logf(" → GET /api/v1/tasks 返回 status='completed'")
t.Logf(" → task_name='%s', work_dir='%s'", finalItem.TaskName, finalItem.WorkDir)
t.Logf(" → MockSlurm activeJobs=%d, historyJobs=%d",
len(env.MockSlurm.GetAllActiveJobs()), len(env.MockSlurm.GetAllHistoryJobs()))
t.Log("")
// Step 9: Verify WorkDir exists and contains the input file.
t.Log("【步骤 9】验证工作目录")
if finalItem.WorkDir == "" {
t.Fatal("step 9: expected non-empty work_dir")
}
t.Logf(" → work_dir='%s' (非空TaskProcessor 已创建)", finalItem.WorkDir)
t.Log("")
t.Log("========== E2E 全链路测试通过 ✓ ==========")
}
// e2eFetchTaskStatus fetches a single task's status from the list API.
func e2eFetchTaskStatus(env *testenv.TestEnv, taskID int64) (string, e2eTaskItem) {
resp := e2eSendRequest(env, http.MethodGet, "/api/v1/tasks", "")
_, result := e2eParseResponse(resp)
var list e2eTaskListData
if err := json.Unmarshal(result.Data, &list); err != nil {
return "", e2eTaskItem{}
}
for _, item := range list.Items {
if item.ID == taskID {
return item.Status, item
}
}
return "", e2eTaskItem{}
}

View File

@@ -0,0 +1,170 @@
package main
import (
"encoding/json"
"fmt"
"io"
"net/http"
"testing"
"gcy_hpc_server/internal/model"
"gcy_hpc_server/internal/testutil/testenv"
)
// fileAPIResp mirrors server.APIResponse for file integration tests.
type fileAPIResp struct {
Success bool `json:"success"`
Data json.RawMessage `json:"data,omitempty"`
Error string `json:"error,omitempty"`
}
// fileDecode parses an HTTP response body into fileAPIResp.
func fileDecode(t *testing.T, body io.Reader) fileAPIResp {
t.Helper()
data, err := io.ReadAll(body)
if err != nil {
t.Fatalf("fileDecode: read body: %v", err)
}
var r fileAPIResp
if err := json.Unmarshal(data, &r); err != nil {
t.Fatalf("fileDecode: unmarshal: %v (body: %s)", err, string(data))
}
return r
}
func TestIntegration_File_List(t *testing.T) {
env := testenv.NewTestEnv(t)
// Upload a file so the list is non-empty.
env.UploadTestData("list_test.txt", []byte("hello list"))
resp := env.DoRequest("GET", "/api/v1/files", nil)
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
t.Fatalf("expected 200, got %d", resp.StatusCode)
}
r := fileDecode(t, resp.Body)
if !r.Success {
t.Fatalf("response not success: %s", r.Error)
}
var listResp model.ListFilesResponse
if err := json.Unmarshal(r.Data, &listResp); err != nil {
t.Fatalf("unmarshal list response: %v", err)
}
if len(listResp.Files) == 0 {
t.Fatal("expected at least 1 file in list, got 0")
}
found := false
for _, f := range listResp.Files {
if f.Name == "list_test.txt" {
found = true
break
}
}
if !found {
t.Fatal("expected to find list_test.txt in file list")
}
if listResp.Total < 1 {
t.Fatalf("expected total >= 1, got %d", listResp.Total)
}
if listResp.Page < 1 {
t.Fatalf("expected page >= 1, got %d", listResp.Page)
}
}
func TestIntegration_File_Get(t *testing.T) {
env := testenv.NewTestEnv(t)
fileID, _ := env.UploadTestData("get_test.txt", []byte("hello get"))
resp := env.DoRequest("GET", fmt.Sprintf("/api/v1/files/%d", fileID), nil)
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
t.Fatalf("expected 200, got %d", resp.StatusCode)
}
r := fileDecode(t, resp.Body)
if !r.Success {
t.Fatalf("response not success: %s", r.Error)
}
var fileResp model.FileResponse
if err := json.Unmarshal(r.Data, &fileResp); err != nil {
t.Fatalf("unmarshal file response: %v", err)
}
if fileResp.ID != fileID {
t.Fatalf("expected file ID %d, got %d", fileID, fileResp.ID)
}
if fileResp.Name != "get_test.txt" {
t.Fatalf("expected name get_test.txt, got %s", fileResp.Name)
}
if fileResp.Size != int64(len("hello get")) {
t.Fatalf("expected size %d, got %d", len("hello get"), fileResp.Size)
}
}
func TestIntegration_File_Download(t *testing.T) {
env := testenv.NewTestEnv(t)
content := []byte("hello world")
fileID, _ := env.UploadTestData("download_test.txt", content)
resp := env.DoRequest("GET", fmt.Sprintf("/api/v1/files/%d/download", fileID), nil)
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
t.Fatalf("expected 200, got %d", resp.StatusCode)
}
body, err := io.ReadAll(resp.Body)
if err != nil {
t.Fatalf("read download body: %v", err)
}
if string(body) != string(content) {
t.Fatalf("downloaded content mismatch: got %q, want %q", string(body), string(content))
}
contentType := resp.Header.Get("Content-Type")
if contentType == "" {
t.Fatal("expected Content-Type header to be set")
}
}
func TestIntegration_File_Delete(t *testing.T) {
env := testenv.NewTestEnv(t)
fileID, _ := env.UploadTestData("delete_test.txt", []byte("hello delete"))
// Delete the file.
resp := env.DoRequest("DELETE", fmt.Sprintf("/api/v1/files/%d", fileID), nil)
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
t.Fatalf("expected 200, got %d", resp.StatusCode)
}
r := fileDecode(t, resp.Body)
if !r.Success {
t.Fatalf("delete response not success: %s", r.Error)
}
// Verify the file is gone — GET should return 500 (internal error) or 404.
getResp := env.DoRequest("GET", fmt.Sprintf("/api/v1/files/%d", fileID), nil)
defer getResp.Body.Close()
if getResp.StatusCode == http.StatusOK {
gr := fileDecode(t, getResp.Body)
if gr.Success {
t.Fatal("expected file to be deleted, but GET still returns success")
}
}
}

View File

@@ -0,0 +1,193 @@
package main
import (
"bytes"
"encoding/json"
"fmt"
"io"
"net/http"
"testing"
"gcy_hpc_server/internal/testutil/testenv"
)
// folderData mirrors the FolderResponse DTO returned by the API.
type folderData struct {
ID int64 `json:"id"`
Name string `json:"name"`
ParentID *int64 `json:"parent_id,omitempty"`
Path string `json:"path"`
FileCount int64 `json:"file_count"`
SubFolderCount int64 `json:"subfolder_count"`
}
// folderMessageData mirrors the delete endpoint response data structure.
type folderMessageData struct {
Message string `json:"message"`
}
// folderDoRequest marshals body and calls env.DoRequest.
func folderDoRequest(env *testenv.TestEnv, method, path string, body interface{}) *http.Response {
var r io.Reader
if body != nil {
b, err := json.Marshal(body)
if err != nil {
panic(fmt.Sprintf("folderDoRequest marshal: %v", err))
}
r = bytes.NewReader(b)
}
return env.DoRequest(method, path, r)
}
// folderDecodeAll decodes the response and returns status, success, and raw data.
func folderDecodeAll(env *testenv.TestEnv, resp *http.Response) (statusCode int, success bool, data json.RawMessage, err error) {
statusCode = resp.StatusCode
success, data, err = env.DecodeResponse(resp)
return
}
// folderSeed creates a folder via HTTP and returns its ID.
func folderSeed(env *testenv.TestEnv, name string) int64 {
body := map[string]interface{}{"name": name}
resp := folderDoRequest(env, http.MethodPost, "/api/v1/files/folders", body)
status, success, data, err := folderDecodeAll(env, resp)
if err != nil {
panic(fmt.Sprintf("folderSeed decode: %v", err))
}
if status != http.StatusCreated {
panic(fmt.Sprintf("folderSeed: expected 201, got %d", status))
}
if !success {
panic("folderSeed: expected success=true")
}
var f folderData
if err := json.Unmarshal(data, &f); err != nil {
panic(fmt.Sprintf("folderSeed unmarshal: %v", err))
}
return f.ID
}
// TestIntegration_Folder_Create verifies POST /api/v1/files/folders creates a folder.
func TestIntegration_Folder_Create(t *testing.T) {
env := testenv.NewTestEnv(t)
body := map[string]interface{}{"name": "test-folder-create"}
resp := folderDoRequest(env, http.MethodPost, "/api/v1/files/folders", body)
status, success, data, err := folderDecodeAll(env, resp)
if err != nil {
t.Fatalf("decode response: %v", err)
}
if status != http.StatusCreated {
t.Fatalf("expected status 201, got %d", status)
}
if !success {
t.Fatal("expected success=true")
}
var created folderData
if err := json.Unmarshal(data, &created); err != nil {
t.Fatalf("unmarshal created data: %v", err)
}
if created.ID <= 0 {
t.Fatalf("expected positive id, got %d", created.ID)
}
if created.Name != "test-folder-create" {
t.Fatalf("expected name=test-folder-create, got %s", created.Name)
}
}
// TestIntegration_Folder_List verifies GET /api/v1/files/folders returns a list.
func TestIntegration_Folder_List(t *testing.T) {
env := testenv.NewTestEnv(t)
// Seed two folders.
folderSeed(env, "list-folder-1")
folderSeed(env, "list-folder-2")
resp := env.DoRequest(http.MethodGet, "/api/v1/files/folders", nil)
status, success, data, err := folderDecodeAll(env, resp)
if err != nil {
t.Fatalf("decode response: %v", err)
}
if status != http.StatusOK {
t.Fatalf("expected status 200, got %d", status)
}
if !success {
t.Fatal("expected success=true")
}
var folders []folderData
if err := json.Unmarshal(data, &folders); err != nil {
t.Fatalf("unmarshal list data: %v", err)
}
if len(folders) < 2 {
t.Fatalf("expected at least 2 folders, got %d", len(folders))
}
}
// TestIntegration_Folder_Get verifies GET /api/v1/files/folders/:id returns folder details.
func TestIntegration_Folder_Get(t *testing.T) {
env := testenv.NewTestEnv(t)
id := folderSeed(env, "test-folder-get")
path := fmt.Sprintf("/api/v1/files/folders/%d", id)
resp := env.DoRequest(http.MethodGet, path, nil)
status, success, data, err := folderDecodeAll(env, resp)
if err != nil {
t.Fatalf("decode response: %v", err)
}
if status != http.StatusOK {
t.Fatalf("expected status 200, got %d", status)
}
if !success {
t.Fatal("expected success=true")
}
var f folderData
if err := json.Unmarshal(data, &f); err != nil {
t.Fatalf("unmarshal folder data: %v", err)
}
if f.ID != id {
t.Fatalf("expected id=%d, got %d", id, f.ID)
}
if f.Name != "test-folder-get" {
t.Fatalf("expected name=test-folder-get, got %s", f.Name)
}
}
// TestIntegration_Folder_Delete verifies DELETE /api/v1/files/folders/:id removes a folder.
func TestIntegration_Folder_Delete(t *testing.T) {
env := testenv.NewTestEnv(t)
id := folderSeed(env, "test-folder-delete")
path := fmt.Sprintf("/api/v1/files/folders/%d", id)
resp := env.DoRequest(http.MethodDelete, path, nil)
status, success, data, err := folderDecodeAll(env, resp)
if err != nil {
t.Fatalf("decode response: %v", err)
}
if status != http.StatusOK {
t.Fatalf("expected status 200, got %d", status)
}
if !success {
t.Fatal("expected success=true")
}
var msg folderMessageData
if err := json.Unmarshal(data, &msg); err != nil {
t.Fatalf("unmarshal message data: %v", err)
}
if msg.Message != "folder deleted" {
t.Fatalf("expected message 'folder deleted', got %q", msg.Message)
}
// Verify it's gone via GET → 404.
getResp := env.DoRequest(http.MethodGet, path, nil)
getStatus, getSuccess, _, _ := folderDecodeAll(env, getResp)
if getStatus != http.StatusNotFound {
t.Fatalf("expected status 404 after delete, got %d", getStatus)
}
if getSuccess {
t.Fatal("expected success=false after delete")
}
}

View File

@@ -0,0 +1,222 @@
package main
import (
"bytes"
"encoding/json"
"fmt"
"net/http"
"testing"
"gcy_hpc_server/internal/testutil/testenv"
)
// jobItemData mirrors the JobResponse DTO for a single job.
type jobItemData struct {
JobID int32 `json:"job_id"`
Name string `json:"name"`
State []string `json:"job_state"`
Partition string `json:"partition"`
}
// jobListData mirrors the paginated JobListResponse DTO.
type jobListData struct {
Jobs []jobItemData `json:"jobs"`
Total int `json:"total"`
Page int `json:"page"`
PageSize int `json:"page_size"`
}
// jobCancelData mirrors the cancel response message.
type jobCancelData struct {
Message string `json:"message"`
}
// jobDecodeAll decodes the response and returns status, success, and raw data.
func jobDecodeAll(env *testenv.TestEnv, resp *http.Response) (statusCode int, success bool, data json.RawMessage, err error) {
statusCode = resp.StatusCode
success, data, err = env.DecodeResponse(resp)
return
}
// jobSubmitBody builds a JSON body for job submit requests.
func jobSubmitBody(script string) *bytes.Reader {
body, _ := json.Marshal(map[string]string{"script": script})
return bytes.NewReader(body)
}
// jobSubmitViaAPI submits a job and returns the job ID. Fatals on failure.
func jobSubmitViaAPI(t *testing.T, env *testenv.TestEnv, script string) int32 {
t.Helper()
resp := env.DoRequest(http.MethodPost, "/api/v1/jobs/submit", jobSubmitBody(script))
status, success, data, err := jobDecodeAll(env, resp)
if err != nil {
t.Fatalf("submit job decode: %v", err)
}
if status != http.StatusCreated {
t.Fatalf("expected status 201, got %d", status)
}
if !success {
t.Fatal("expected success=true on submit")
}
var job jobItemData
if err := json.Unmarshal(data, &job); err != nil {
t.Fatalf("unmarshal submitted job: %v", err)
}
return job.JobID
}
// TestIntegration_Jobs_Submit verifies POST /api/v1/jobs/submit creates a new job.
func TestIntegration_Jobs_Submit(t *testing.T) {
env := testenv.NewTestEnv(t)
script := "#!/bin/bash\necho hello"
resp := env.DoRequest(http.MethodPost, "/api/v1/jobs/submit", jobSubmitBody(script))
status, success, data, err := jobDecodeAll(env, resp)
if err != nil {
t.Fatalf("decode response: %v", err)
}
if status != http.StatusCreated {
t.Fatalf("expected status 201, got %d", status)
}
if !success {
t.Fatal("expected success=true")
}
var job jobItemData
if err := json.Unmarshal(data, &job); err != nil {
t.Fatalf("unmarshal job: %v", err)
}
if job.JobID <= 0 {
t.Fatalf("expected positive job_id, got %d", job.JobID)
}
}
// TestIntegration_Jobs_List verifies GET /api/v1/jobs returns a paginated job list.
func TestIntegration_Jobs_List(t *testing.T) {
env := testenv.NewTestEnv(t)
// Submit a job so the list is not empty.
jobSubmitViaAPI(t, env, "#!/bin/bash\necho list-test")
resp := env.DoRequest(http.MethodGet, "/api/v1/jobs", nil)
status, success, data, err := jobDecodeAll(env, resp)
if err != nil {
t.Fatalf("decode response: %v", err)
}
if status != http.StatusOK {
t.Fatalf("expected status 200, got %d", status)
}
if !success {
t.Fatal("expected success=true")
}
var list jobListData
if err := json.Unmarshal(data, &list); err != nil {
t.Fatalf("unmarshal job list: %v", err)
}
if list.Total < 1 {
t.Fatalf("expected at least 1 job, got total=%d", list.Total)
}
if list.Page != 1 {
t.Fatalf("expected page=1, got %d", list.Page)
}
}
// TestIntegration_Jobs_Get verifies GET /api/v1/jobs/:id returns a single job.
func TestIntegration_Jobs_Get(t *testing.T) {
env := testenv.NewTestEnv(t)
jobID := jobSubmitViaAPI(t, env, "#!/bin/bash\necho get-test")
path := fmt.Sprintf("/api/v1/jobs/%d", jobID)
resp := env.DoRequest(http.MethodGet, path, nil)
status, success, data, err := jobDecodeAll(env, resp)
if err != nil {
t.Fatalf("decode response: %v", err)
}
if status != http.StatusOK {
t.Fatalf("expected status 200, got %d", status)
}
if !success {
t.Fatal("expected success=true")
}
var job jobItemData
if err := json.Unmarshal(data, &job); err != nil {
t.Fatalf("unmarshal job: %v", err)
}
if job.JobID != jobID {
t.Fatalf("expected job_id=%d, got %d", jobID, job.JobID)
}
}
// TestIntegration_Jobs_Cancel verifies DELETE /api/v1/jobs/:id cancels a job.
func TestIntegration_Jobs_Cancel(t *testing.T) {
env := testenv.NewTestEnv(t)
jobID := jobSubmitViaAPI(t, env, "#!/bin/bash\necho cancel-test")
path := fmt.Sprintf("/api/v1/jobs/%d", jobID)
resp := env.DoRequest(http.MethodDelete, path, nil)
status, success, data, err := jobDecodeAll(env, resp)
if err != nil {
t.Fatalf("decode response: %v", err)
}
if status != http.StatusOK {
t.Fatalf("expected status 200, got %d", status)
}
if !success {
t.Fatal("expected success=true")
}
var msg jobCancelData
if err := json.Unmarshal(data, &msg); err != nil {
t.Fatalf("unmarshal cancel response: %v", err)
}
if msg.Message == "" {
t.Fatal("expected non-empty cancel message")
}
}
// TestIntegration_Jobs_History verifies GET /api/v1/jobs/history returns historical jobs.
func TestIntegration_Jobs_History(t *testing.T) {
env := testenv.NewTestEnv(t)
// Submit and cancel a job so it moves from active to history queue.
jobID := jobSubmitViaAPI(t, env, "#!/bin/bash\necho history-test")
path := fmt.Sprintf("/api/v1/jobs/%d", jobID)
env.DoRequest(http.MethodDelete, path, nil)
resp := env.DoRequest(http.MethodGet, "/api/v1/jobs/history", nil)
status, success, data, err := jobDecodeAll(env, resp)
if err != nil {
t.Fatalf("decode response: %v", err)
}
if status != http.StatusOK {
t.Fatalf("expected status 200, got %d", status)
}
if !success {
t.Fatal("expected success=true")
}
var list jobListData
if err := json.Unmarshal(data, &list); err != nil {
t.Fatalf("unmarshal history: %v", err)
}
if list.Total < 1 {
t.Fatalf("expected at least 1 history job, got total=%d", list.Total)
}
// Verify the cancelled job appears in history.
found := false
for _, j := range list.Jobs {
if j.JobID == jobID {
found = true
break
}
}
if !found {
t.Fatalf("cancelled job %d not found in history", jobID)
}
}

View File

@@ -0,0 +1,261 @@
package main
import (
"encoding/json"
"fmt"
"io"
"net/http"
"strings"
"testing"
"time"
"gcy_hpc_server/internal/testutil/testenv"
)
// taskAPIResponse decodes the unified API response envelope.
type taskAPIResponse struct {
Success bool `json:"success"`
Data json.RawMessage `json:"data,omitempty"`
Error string `json:"error,omitempty"`
}
// taskCreateData is the data payload from a successful task creation.
type taskCreateData struct {
ID int64 `json:"id"`
}
// taskListData is the data payload from listing tasks.
type taskListData struct {
Items []taskListItem `json:"items"`
Total int64 `json:"total"`
}
type taskListItem struct {
ID int64 `json:"id"`
TaskName string `json:"task_name"`
AppID int64 `json:"app_id"`
Status string `json:"status"`
SlurmJobID *int32 `json:"slurm_job_id"`
}
// taskSendReq sends an HTTP request via the test env and returns the response.
func taskSendReq(t *testing.T, env *testenv.TestEnv, method, path string, body string) *http.Response {
t.Helper()
var r io.Reader
if body != "" {
r = strings.NewReader(body)
}
resp := env.DoRequest(method, path, r)
return resp
}
// taskParseResp decodes the response body into a taskAPIResponse.
func taskParseResp(t *testing.T, resp *http.Response) taskAPIResponse {
t.Helper()
b, err := io.ReadAll(resp.Body)
resp.Body.Close()
if err != nil {
t.Fatalf("read response body: %v", err)
}
var result taskAPIResponse
if err := json.Unmarshal(b, &result); err != nil {
t.Fatalf("unmarshal response: %v (body: %s)", err, string(b))
}
return result
}
// taskCreateViaAPI creates a task via the HTTP API and returns the task ID.
func taskCreateViaAPI(t *testing.T, env *testenv.TestEnv, appID int64, taskName string) int64 {
t.Helper()
body := fmt.Sprintf(`{"app_id":%d,"task_name":"%s","values":{},"file_ids":[]}`, appID, taskName)
resp := taskSendReq(t, env, http.MethodPost, "/api/v1/tasks", body)
defer resp.Body.Close()
if resp.StatusCode != http.StatusCreated {
b, _ := io.ReadAll(resp.Body)
t.Fatalf("expected 201, got %d: %s", resp.StatusCode, string(b))
}
parsed := taskParseResp(t, resp)
if !parsed.Success {
t.Fatalf("expected success=true, got error: %s", parsed.Error)
}
var data taskCreateData
if err := json.Unmarshal(parsed.Data, &data); err != nil {
t.Fatalf("unmarshal create data: %v", err)
}
if data.ID == 0 {
t.Fatal("expected non-zero task ID")
}
return data.ID
}
// ---------- Tests ----------
func TestIntegration_Task_Create(t *testing.T) {
env := testenv.NewTestEnv(t)
// Create application
appID, err := env.CreateApp("task-create-app", "#!/bin/bash\necho hello", nil)
if err != nil {
t.Fatalf("create app: %v", err)
}
// Create task via API
taskID := taskCreateViaAPI(t, env, appID, "test-task-create")
// Verify the task ID is positive
if taskID <= 0 {
t.Fatalf("expected positive task ID, got %d", taskID)
}
// Wait briefly for async processing, then verify task exists in DB via list
time.Sleep(200 * time.Millisecond)
resp := taskSendReq(t, env, http.MethodGet, "/api/v1/tasks", "")
defer resp.Body.Close()
parsed := taskParseResp(t, resp)
var listData taskListData
if err := json.Unmarshal(parsed.Data, &listData); err != nil {
t.Fatalf("unmarshal list data: %v", err)
}
found := false
for _, item := range listData.Items {
if item.ID == taskID {
found = true
if item.TaskName != "test-task-create" {
t.Errorf("expected task_name=test-task-create, got %s", item.TaskName)
}
if item.AppID != appID {
t.Errorf("expected app_id=%d, got %d", appID, item.AppID)
}
break
}
}
if !found {
t.Fatalf("task %d not found in list", taskID)
}
}
func TestIntegration_Task_List(t *testing.T) {
env := testenv.NewTestEnv(t)
// Create application
appID, err := env.CreateApp("task-list-app", "#!/bin/bash\necho hello", nil)
if err != nil {
t.Fatalf("create app: %v", err)
}
// Create 3 tasks
taskCreateViaAPI(t, env, appID, "list-task-1")
taskCreateViaAPI(t, env, appID, "list-task-2")
taskCreateViaAPI(t, env, appID, "list-task-3")
// Allow async processing
time.Sleep(200 * time.Millisecond)
// List tasks
resp := taskSendReq(t, env, http.MethodGet, "/api/v1/tasks", "")
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
b, _ := io.ReadAll(resp.Body)
t.Fatalf("expected 200, got %d: %s", resp.StatusCode, string(b))
}
parsed := taskParseResp(t, resp)
if !parsed.Success {
t.Fatalf("expected success, got error: %s", parsed.Error)
}
var listData taskListData
if err := json.Unmarshal(parsed.Data, &listData); err != nil {
t.Fatalf("unmarshal list data: %v", err)
}
if listData.Total < 3 {
t.Fatalf("expected at least 3 tasks, got %d", listData.Total)
}
// Verify each created task has required fields
for _, item := range listData.Items {
if item.ID == 0 {
t.Error("expected non-zero ID")
}
if item.Status == "" {
t.Error("expected non-empty status")
}
if item.AppID == 0 {
t.Error("expected non-zero app_id")
}
}
}
func TestIntegration_Task_PollerLifecycle(t *testing.T) {
env := testenv.NewTestEnv(t)
// 1. Create application
appID, err := env.CreateApp("poller-lifecycle-app", "#!/bin/bash\necho hello", nil)
if err != nil {
t.Fatalf("create app: %v", err)
}
// 2. Submit task via API
taskID := taskCreateViaAPI(t, env, appID, "poller-lifecycle-task")
// 3. Wait for queued — TaskProcessor submits to MockSlurm asynchronously.
// Intermediate states (submitted→preparing→downloading→ready→queued) are
// non-deterministic; only assert the final "queued" state.
if err := env.WaitForTaskStatus(taskID, "queued", 5*time.Second); err != nil {
t.Fatalf("wait for queued: %v", err)
}
// 4. Get slurm job ID from DB (not returned by API)
slurmJobID, err := env.GetTaskSlurmJobID(taskID)
if err != nil {
t.Fatalf("get slurm job id: %v", err)
}
// 5. Transition: queued → running
// ORDER IS CRITICAL: SetJobState BEFORE MakeTaskStale
env.MockSlurm.SetJobState(slurmJobID, "RUNNING")
if err := env.MakeTaskStale(taskID); err != nil {
t.Fatalf("make task stale (running): %v", err)
}
if err := env.WaitForTaskStatus(taskID, "running", 5*time.Second); err != nil {
t.Fatalf("wait for running: %v", err)
}
// 6. Transition: running → completed
// ORDER IS CRITICAL: SetJobState BEFORE MakeTaskStale
env.MockSlurm.SetJobState(slurmJobID, "COMPLETED")
if err := env.MakeTaskStale(taskID); err != nil {
t.Fatalf("make task stale (completed): %v", err)
}
if err := env.WaitForTaskStatus(taskID, "completed", 5*time.Second); err != nil {
t.Fatalf("wait for completed: %v", err)
}
}
func TestIntegration_Task_Validation(t *testing.T) {
env := testenv.NewTestEnv(t)
// Missing required app_id
resp := taskSendReq(t, env, http.MethodPost, "/api/v1/tasks", `{"task_name":"no-app-id"}`)
defer resp.Body.Close()
if resp.StatusCode != http.StatusBadRequest {
t.Fatalf("expected 400 for missing app_id, got %d", resp.StatusCode)
}
parsed := taskParseResp(t, resp)
if parsed.Success {
t.Fatal("expected success=false for validation error")
}
if parsed.Error == "" {
t.Error("expected non-empty error message")
}
}

View File

@@ -0,0 +1,279 @@
package main
import (
"bytes"
"crypto/sha256"
"encoding/hex"
"encoding/json"
"fmt"
"io"
"mime/multipart"
"net/http"
"testing"
"gcy_hpc_server/internal/model"
"gcy_hpc_server/internal/testutil/testenv"
)
// uploadResponse mirrors server.APIResponse for upload integration tests.
type uploadAPIResp struct {
Success bool `json:"success"`
Data json.RawMessage `json:"data,omitempty"`
Error string `json:"error,omitempty"`
}
// uploadDecode parses an HTTP response body into uploadAPIResp.
func uploadDecode(t *testing.T, body io.Reader) uploadAPIResp {
t.Helper()
data, err := io.ReadAll(body)
if err != nil {
t.Fatalf("uploadDecode: read body: %v", err)
}
var r uploadAPIResp
if err := json.Unmarshal(data, &r); err != nil {
t.Fatalf("uploadDecode: unmarshal: %v (body: %s)", err, string(data))
}
return r
}
// uploadInitSession calls InitUpload and returns the created session.
// Uses the real HTTP server from testenv.
func uploadInitSession(t *testing.T, env *testenv.TestEnv, fileName string, fileSize int64, sha256Hash string) model.UploadSessionResponse {
t.Helper()
reqBody := model.InitUploadRequest{
FileName: fileName,
FileSize: fileSize,
SHA256: sha256Hash,
}
body, _ := json.Marshal(reqBody)
resp := env.DoRequest("POST", "/api/v1/files/uploads", bytes.NewReader(body))
defer resp.Body.Close()
if resp.StatusCode != http.StatusCreated {
t.Fatalf("uploadInitSession: expected 201, got %d", resp.StatusCode)
}
r := uploadDecode(t, resp.Body)
if !r.Success {
t.Fatalf("uploadInitSession: response not success: %s", r.Error)
}
var session model.UploadSessionResponse
if err := json.Unmarshal(r.Data, &session); err != nil {
t.Fatalf("uploadInitSession: unmarshal session: %v", err)
}
return session
}
// uploadSendChunk sends a single chunk via multipart form data.
// Uses raw HTTP client to set the correct multipart content type.
func uploadSendChunk(t *testing.T, env *testenv.TestEnv, sessionID int64, chunkIndex int, chunkData []byte) {
t.Helper()
url := fmt.Sprintf("%s/api/v1/files/uploads/%d/chunks/%d", env.URL(), sessionID, chunkIndex)
var buf bytes.Buffer
writer := multipart.NewWriter(&buf)
part, err := writer.CreateFormFile("chunk", "chunk.bin")
if err != nil {
t.Fatalf("uploadSendChunk: create form file: %v", err)
}
part.Write(chunkData)
writer.Close()
req, err := http.NewRequest("PUT", url, &buf)
if err != nil {
t.Fatalf("uploadSendChunk: new request: %v", err)
}
req.Header.Set("Content-Type", writer.FormDataContentType())
resp, err := http.DefaultClient.Do(req)
if err != nil {
t.Fatalf("uploadSendChunk: do request: %v", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
bodyBytes, _ := io.ReadAll(resp.Body)
t.Fatalf("uploadSendChunk: expected 200, got %d: %s", resp.StatusCode, string(bodyBytes))
}
}
func TestIntegration_Upload_Init(t *testing.T) {
env := testenv.NewTestEnv(t)
fileData := []byte("integration test upload init content")
h := sha256.Sum256(fileData)
sha256Hash := hex.EncodeToString(h[:])
session := uploadInitSession(t, env, "init_test.txt", int64(len(fileData)), sha256Hash)
if session.ID <= 0 {
t.Fatalf("expected positive session ID, got %d", session.ID)
}
if session.FileName != "init_test.txt" {
t.Fatalf("expected file_name init_test.txt, got %s", session.FileName)
}
if session.Status != "pending" {
t.Fatalf("expected status pending, got %s", session.Status)
}
if session.TotalChunks != 1 {
t.Fatalf("expected 1 chunk for small file, got %d", session.TotalChunks)
}
if session.FileSize != int64(len(fileData)) {
t.Fatalf("expected file_size %d, got %d", len(fileData), session.FileSize)
}
if session.SHA256 != sha256Hash {
t.Fatalf("expected sha256 %s, got %s", sha256Hash, session.SHA256)
}
}
func TestIntegration_Upload_Status(t *testing.T) {
env := testenv.NewTestEnv(t)
fileData := []byte("integration test status content")
h := sha256.Sum256(fileData)
sha256Hash := hex.EncodeToString(h[:])
session := uploadInitSession(t, env, "status_test.txt", int64(len(fileData)), sha256Hash)
resp := env.DoRequest("GET", fmt.Sprintf("/api/v1/files/uploads/%d", session.ID), nil)
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
t.Fatalf("expected 200, got %d", resp.StatusCode)
}
r := uploadDecode(t, resp.Body)
if !r.Success {
t.Fatalf("response not success: %s", r.Error)
}
var status model.UploadSessionResponse
if err := json.Unmarshal(r.Data, &status); err != nil {
t.Fatalf("unmarshal status: %v", err)
}
if status.ID != session.ID {
t.Fatalf("expected session ID %d, got %d", session.ID, status.ID)
}
if status.Status != "pending" {
t.Fatalf("expected status pending, got %s", status.Status)
}
if status.FileName != "status_test.txt" {
t.Fatalf("expected file_name status_test.txt, got %s", status.FileName)
}
}
func TestIntegration_Upload_Chunk(t *testing.T) {
env := testenv.NewTestEnv(t)
fileData := []byte("integration test chunk upload data")
h := sha256.Sum256(fileData)
sha256Hash := hex.EncodeToString(h[:])
session := uploadInitSession(t, env, "chunk_test.txt", int64(len(fileData)), sha256Hash)
uploadSendChunk(t, env, session.ID, 0, fileData)
// Verify chunk appears in uploaded_chunks via status endpoint
resp := env.DoRequest("GET", fmt.Sprintf("/api/v1/files/uploads/%d", session.ID), nil)
defer resp.Body.Close()
r := uploadDecode(t, resp.Body)
var status model.UploadSessionResponse
if err := json.Unmarshal(r.Data, &status); err != nil {
t.Fatalf("unmarshal status after chunk: %v", err)
}
if len(status.UploadedChunks) != 1 {
t.Fatalf("expected 1 uploaded chunk, got %d", len(status.UploadedChunks))
}
if status.UploadedChunks[0] != 0 {
t.Fatalf("expected uploaded chunk index 0, got %d", status.UploadedChunks[0])
}
}
func TestIntegration_Upload_Complete(t *testing.T) {
env := testenv.NewTestEnv(t)
fileData := []byte("integration test complete upload data")
h := sha256.Sum256(fileData)
sha256Hash := hex.EncodeToString(h[:])
session := uploadInitSession(t, env, "complete_test.txt", int64(len(fileData)), sha256Hash)
// Upload all chunks
for i := 0; i < session.TotalChunks; i++ {
uploadSendChunk(t, env, session.ID, i, fileData)
}
// Complete upload
resp := env.DoRequest("POST", fmt.Sprintf("/api/v1/files/uploads/%d/complete", session.ID), nil)
defer resp.Body.Close()
if resp.StatusCode != http.StatusCreated {
bodyBytes, _ := io.ReadAll(resp.Body)
t.Fatalf("expected 201, got %d: %s", resp.StatusCode, string(bodyBytes))
}
r := uploadDecode(t, resp.Body)
if !r.Success {
t.Fatalf("complete response not success: %s", r.Error)
}
var fileResp model.FileResponse
if err := json.Unmarshal(r.Data, &fileResp); err != nil {
t.Fatalf("unmarshal file response: %v", err)
}
if fileResp.ID <= 0 {
t.Fatalf("expected positive file ID, got %d", fileResp.ID)
}
if fileResp.Name != "complete_test.txt" {
t.Fatalf("expected name complete_test.txt, got %s", fileResp.Name)
}
if fileResp.Size != int64(len(fileData)) {
t.Fatalf("expected size %d, got %d", len(fileData), fileResp.Size)
}
if fileResp.SHA256 != sha256Hash {
t.Fatalf("expected sha256 %s, got %s", sha256Hash, fileResp.SHA256)
}
}
func TestIntegration_Upload_Cancel(t *testing.T) {
env := testenv.NewTestEnv(t)
fileData := []byte("integration test cancel upload data")
h := sha256.Sum256(fileData)
sha256Hash := hex.EncodeToString(h[:])
session := uploadInitSession(t, env, "cancel_test.txt", int64(len(fileData)), sha256Hash)
// Cancel the upload
resp := env.DoRequest("DELETE", fmt.Sprintf("/api/v1/files/uploads/%d", session.ID), nil)
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
bodyBytes, _ := io.ReadAll(resp.Body)
t.Fatalf("expected 200, got %d: %s", resp.StatusCode, string(bodyBytes))
}
r := uploadDecode(t, resp.Body)
if !r.Success {
t.Fatalf("cancel response not success: %s", r.Error)
}
// Verify session is no longer in pending state by checking status
statusResp := env.DoRequest("GET", fmt.Sprintf("/api/v1/files/uploads/%d", session.ID), nil)
defer statusResp.Body.Close()
sr := uploadDecode(t, statusResp.Body)
if sr.Success {
var status model.UploadSessionResponse
if err := json.Unmarshal(sr.Data, &status); err == nil {
if status.Status == "pending" {
t.Fatal("expected status to not be pending after cancel")
}
}
}
}

39
cmd/server/main.go Normal file
View File

@@ -0,0 +1,39 @@
package main
import (
"fmt"
"os"
"gcy_hpc_server/internal/app"
"gcy_hpc_server/internal/config"
"gcy_hpc_server/internal/logger"
"go.uber.org/zap"
)
func main() {
cfgPath := ""
if len(os.Args) > 1 {
cfgPath = os.Args[1]
}
cfg, err := config.Load(cfgPath)
if err != nil {
fmt.Fprintf(os.Stderr, "failed to load config: %v\n", err)
os.Exit(1)
}
zapLogger, err := logger.NewLogger(cfg.Log)
if err != nil {
fmt.Fprintf(os.Stderr, "failed to init logger: %v\n", err)
os.Exit(1)
}
defer zapLogger.Sync()
application, err := app.NewApp(cfg, zapLogger)
if err != nil {
zapLogger.Fatal("failed to initialize application", zap.Error(err))
}
if err := application.Run(); err != nil {
zapLogger.Fatal("application error", zap.Error(err))
}
}

128
cmd/server/main_test.go Normal file
View File

@@ -0,0 +1,128 @@
package main
import (
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"gcy_hpc_server/internal/handler"
"gcy_hpc_server/internal/model"
"gcy_hpc_server/internal/server"
"gcy_hpc_server/internal/service"
"gcy_hpc_server/internal/slurm"
"gcy_hpc_server/internal/store"
"go.uber.org/zap"
"gorm.io/driver/sqlite"
"gorm.io/gorm"
"gorm.io/gorm/logger"
)
func newTestDB() *gorm.DB {
db, _ := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{Logger: logger.Default.LogMode(logger.Silent)})
db.AutoMigrate(&model.Application{})
return db
}
func TestRouterRegistration(t *testing.T) {
slurmSrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
json.NewEncoder(w).Encode(map[string]interface{}{"jobs": []interface{}{}})
}))
defer slurmSrv.Close()
client, _ := slurm.NewClientWithOpts(slurmSrv.URL, slurm.WithHTTPClient(slurmSrv.Client()))
jobSvc := service.NewJobService(client, zap.NewNop())
appStore := store.NewApplicationStore(newTestDB())
appSvc := service.NewApplicationService(appStore, jobSvc, "", zap.NewNop())
appH := handler.NewApplicationHandler(appSvc, zap.NewNop())
router := server.NewRouter(
handler.NewJobHandler(jobSvc, zap.NewNop()),
handler.NewClusterHandler(service.NewClusterService(client, zap.NewNop()), zap.NewNop()),
appH,
nil, nil, nil,
nil,
nil,
)
routes := router.Routes()
expected := []struct {
method string
path string
}{
{"POST", "/api/v1/jobs/submit"},
{"GET", "/api/v1/jobs"},
{"GET", "/api/v1/jobs/history"},
{"GET", "/api/v1/jobs/:id"},
{"DELETE", "/api/v1/jobs/:id"},
{"GET", "/api/v1/nodes"},
{"GET", "/api/v1/nodes/:name"},
{"GET", "/api/v1/partitions"},
{"GET", "/api/v1/partitions/:name"},
{"GET", "/api/v1/diag"},
{"GET", "/api/v1/applications"},
{"POST", "/api/v1/applications"},
{"GET", "/api/v1/applications/:id"},
{"PUT", "/api/v1/applications/:id"},
{"DELETE", "/api/v1/applications/:id"},
// {"POST", "/api/v1/applications/:id/submit"}, // [已禁用] 已被 POST /tasks 取代
}
routeMap := map[string]bool{}
for _, r := range routes {
routeMap[r.Method+" "+r.Path] = true
}
for _, exp := range expected {
key := exp.method + " " + exp.path
if !routeMap[key] {
t.Errorf("missing route: %s", key)
}
}
if len(routes) < len(expected) {
t.Errorf("expected at least %d routes, got %d", len(expected), len(routes))
}
}
func TestSmokeGetJobsEndpoint(t *testing.T) {
slurmSrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(map[string]interface{}{"jobs": []interface{}{}})
}))
defer slurmSrv.Close()
client, _ := slurm.NewClientWithOpts(slurmSrv.URL, slurm.WithHTTPClient(slurmSrv.Client()))
jobSvc := service.NewJobService(client, zap.NewNop())
appStore := store.NewApplicationStore(newTestDB())
appSvc := service.NewApplicationService(appStore, jobSvc, "", zap.NewNop())
appH := handler.NewApplicationHandler(appSvc, zap.NewNop())
router := server.NewRouter(
handler.NewJobHandler(jobSvc, zap.NewNop()),
handler.NewClusterHandler(service.NewClusterService(client, zap.NewNop()), zap.NewNop()),
appH,
nil, nil, nil,
nil,
nil,
)
w := httptest.NewRecorder()
req, _ := http.NewRequest(http.MethodGet, "/api/v1/jobs", nil)
router.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("expected 200, got %d: %s", w.Code, w.Body.String())
}
var resp map[string]interface{}
if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
t.Fatalf("failed to parse response: %v", err)
}
if success, ok := resp["success"].(bool); !ok || !success {
t.Fatalf("expected success=true, got %v", resp["success"])
}
}

434
cmd/server/task_test.go Normal file
View File

@@ -0,0 +1,434 @@
package main
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"net/http/httptest"
"os"
"path/filepath"
"sync/atomic"
"testing"
"time"
"gcy_hpc_server/internal/handler"
"gcy_hpc_server/internal/model"
"gcy_hpc_server/internal/server"
"gcy_hpc_server/internal/service"
"gcy_hpc_server/internal/slurm"
"gcy_hpc_server/internal/store"
"go.uber.org/zap"
"gorm.io/driver/sqlite"
"gorm.io/gorm"
"gorm.io/gorm/logger"
)
func newTaskTestDB(t *testing.T) *gorm.DB {
t.Helper()
db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{Logger: logger.Default.LogMode(logger.Silent)})
if err != nil {
t.Fatalf("open sqlite: %v", err)
}
sqlDB, _ := db.DB()
sqlDB.SetMaxOpenConns(1)
db.AutoMigrate(
&model.Application{},
&model.File{},
&model.FileBlob{},
&model.Task{},
)
t.Cleanup(func() {
sqlDB.Close()
})
return db
}
type mockSlurmHandler struct {
submitFn func(w http.ResponseWriter, r *http.Request)
}
func newMockSlurmServer(t *testing.T, h *mockSlurmHandler) *httptest.Server {
t.Helper()
mux := http.NewServeMux()
mux.HandleFunc("/slurm/v0.0.40/job/submit", func(w http.ResponseWriter, r *http.Request) {
if h != nil && h.submitFn != nil {
h.submitFn(w, r)
return
}
w.Header().Set("Content-Type", "application/json")
fmt.Fprint(w, `{"job_id": 42, "step_id": "0", "result": {"job_id": 42, "step_id": "0"}}`)
})
mux.HandleFunc("/slurm/v0.0.40/jobs", func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(map[string]interface{}{"jobs": []interface{}{}})
})
mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(map[string]interface{}{})
})
srv := httptest.NewServer(mux)
t.Cleanup(srv.Close)
return srv
}
func setupTaskTestServer(t *testing.T) (*httptest.Server, *store.TaskStore, func()) {
t.Helper()
return setupTaskTestServerWithSlurm(t, nil)
}
func setupTaskTestServerWithSlurm(t *testing.T, slurmHandler *mockSlurmHandler) (*httptest.Server, *store.TaskStore, func()) {
t.Helper()
db := newTaskTestDB(t)
slurmSrv := newMockSlurmServer(t, slurmHandler)
client, err := slurm.NewClientWithOpts(slurmSrv.URL, slurm.WithHTTPClient(slurmSrv.Client()))
if err != nil {
t.Fatalf("slurm client: %v", err)
}
log := zap.NewNop()
jobSvc := service.NewJobService(client, log)
appStore := store.NewApplicationStore(db)
taskStore := store.NewTaskStore(db)
tmpDir, err := os.MkdirTemp("", "task-test-workdir-*")
if err != nil {
t.Fatalf("temp dir: %v", err)
}
taskSvc := service.NewTaskService(
taskStore, appStore,
nil, nil, nil,
jobSvc,
tmpDir,
log,
)
ctx, cancel := context.WithCancel(context.Background())
taskSvc.StartProcessor(ctx)
appSvc := service.NewApplicationService(appStore, jobSvc, tmpDir, log, taskSvc)
appH := handler.NewApplicationHandler(appSvc, log)
taskH := handler.NewTaskHandler(taskSvc, log)
router := server.NewRouter(
handler.NewJobHandler(jobSvc, log),
handler.NewClusterHandler(service.NewClusterService(client, log), log),
appH,
nil, nil, nil,
taskH,
log,
)
httpSrv := httptest.NewServer(router)
cleanup := func() {
taskSvc.StopProcessor()
cancel()
httpSrv.Close()
os.RemoveAll(tmpDir)
}
return httpSrv, taskStore, cleanup
}
func createTestApp(t *testing.T, srvURL string) int64 {
t.Helper()
body := `{"name":"test-app","script_template":"#!/bin/bash\necho {{.np}}","parameters":[{"name":"np","type":"string","default":"1"}]}`
resp, err := http.Post(srvURL+"/api/v1/applications", "application/json", bytes.NewReader([]byte(body)))
if err != nil {
t.Fatalf("create app: %v", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusCreated {
b, _ := io.ReadAll(resp.Body)
t.Fatalf("create app: status %d, body %s", resp.StatusCode, b)
}
var result struct {
Data struct {
ID int64 `json:"id"`
} `json:"data"`
}
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
t.Fatalf("decode app response: %v", err)
}
return result.Data.ID
}
func postTask(t *testing.T, srvURL string, body string) (*http.Response, map[string]interface{}) {
t.Helper()
resp, err := http.Post(srvURL+"/api/v1/tasks", "application/json", bytes.NewReader([]byte(body)))
if err != nil {
t.Fatalf("post task: %v", err)
}
defer resp.Body.Close()
b, _ := io.ReadAll(resp.Body)
var result map[string]interface{}
json.Unmarshal(b, &result)
return resp, result
}
func getTasks(t *testing.T, srvURL string, query string) (int, []interface{}) {
t.Helper()
url := srvURL + "/api/v1/tasks"
if query != "" {
url += "?" + query
}
resp, err := http.Get(url)
if err != nil {
t.Fatalf("get tasks: %v", err)
}
defer resp.Body.Close()
b, _ := io.ReadAll(resp.Body)
var result map[string]interface{}
json.Unmarshal(b, &result)
data, _ := result["data"].(map[string]interface{})
items, _ := data["items"].([]interface{})
total, _ := data["total"].(float64)
return int(total), items
}
func TestTask_FullLifecycle(t *testing.T) {
srv, _, cleanup := setupTaskTestServer(t)
defer cleanup()
appID := createTestApp(t, srv.URL)
resp, result := postTask(t, srv.URL, fmt.Sprintf(
`{"app_id":%d,"task_name":"lifecycle-test","values":{"np":"4"}}`, appID,
))
if resp.StatusCode != http.StatusCreated {
t.Fatalf("expected 201, got %d: %v", resp.StatusCode, result)
}
data, _ := result["data"].(map[string]interface{})
if data["id"] == nil {
t.Fatal("expected id in response")
}
time.Sleep(300 * time.Millisecond)
total, items := getTasks(t, srv.URL, "")
if total < 1 {
t.Fatalf("expected at least 1 task, got %d", total)
}
task := items[0].(map[string]interface{})
if task["status"] == nil || task["status"] == "" {
t.Error("expected non-empty status")
}
if task["task_name"] != "lifecycle-test" {
t.Errorf("expected task_name=lifecycle-test, got %v", task["task_name"])
}
if task["app_id"] != float64(appID) {
t.Errorf("expected app_id=%d, got %v", appID, task["app_id"])
}
}
func TestTask_CreateWithMissingApp(t *testing.T) {
srv, _, cleanup := setupTaskTestServer(t)
defer cleanup()
resp, result := postTask(t, srv.URL, `{"app_id":9999,"task_name":"no-app"}`)
if resp.StatusCode != http.StatusNotFound {
t.Fatalf("expected 404, got %d: %v", resp.StatusCode, result)
}
}
func TestTask_CreateWithInvalidBody(t *testing.T) {
srv, _, cleanup := setupTaskTestServer(t)
defer cleanup()
resp, result := postTask(t, srv.URL, `{}`)
if resp.StatusCode != http.StatusBadRequest {
t.Fatalf("expected 400, got %d: %v", resp.StatusCode, result)
}
}
func TestTask_FileLimitExceeded(t *testing.T) {
srv, _, cleanup := setupTaskTestServer(t)
defer cleanup()
appID := createTestApp(t, srv.URL)
fileIDs := make([]int64, 101)
for i := range fileIDs {
fileIDs[i] = int64(i + 1)
}
idsJSON, _ := json.Marshal(fileIDs)
body := fmt.Sprintf(`{"app_id":%d,"file_ids":%s}`, appID, string(idsJSON))
resp, result := postTask(t, srv.URL, body)
if resp.StatusCode != http.StatusBadRequest {
t.Fatalf("expected 400 for file limit, got %d: %v", resp.StatusCode, result)
}
}
func TestTask_RetryScenario(t *testing.T) {
var failCount int32
slurmH := &mockSlurmHandler{
submitFn: func(w http.ResponseWriter, r *http.Request) {
if atomic.AddInt32(&failCount, 1) <= 2 {
w.WriteHeader(http.StatusInternalServerError)
fmt.Fprint(w, `{"errors":[{"error":"internal error"}]}`)
return
}
w.Header().Set("Content-Type", "application/json")
fmt.Fprint(w, `{"job_id": 99, "step_id": "0", "result": {"job_id": 99, "step_id": "0"}}`)
},
}
srv, taskStore, cleanup := setupTaskTestServerWithSlurm(t, slurmH)
defer cleanup()
appID := createTestApp(t, srv.URL)
resp, result := postTask(t, srv.URL, fmt.Sprintf(
`{"app_id":%d,"task_name":"retry-test","values":{"np":"2"}}`, appID,
))
if resp.StatusCode != http.StatusCreated {
t.Fatalf("expected 201, got %d: %v", resp.StatusCode, result)
}
taskID := int64(result["data"].(map[string]interface{})["id"].(float64))
deadline := time.Now().Add(5 * time.Second)
for time.Now().Before(deadline) {
task, _ := taskStore.GetByID(context.Background(), taskID)
if task != nil && task.Status == model.TaskStatusQueued {
if task.SlurmJobID != nil && *task.SlurmJobID == 99 {
return
}
}
time.Sleep(100 * time.Millisecond)
}
task, _ := taskStore.GetByID(context.Background(), taskID)
if task == nil {
t.Fatalf("task %d not found after deadline", taskID)
}
t.Fatalf("task did not reach queued with slurm_job_id=99; status=%s retry_count=%d slurm_job_id=%v",
task.Status, task.RetryCount, task.SlurmJobID)
}
// [已禁用] 前端已全部迁移到 POST /tasks 接口,旧 API 兼容性测试不再需要。
/*
func TestTask_OldAPICompatibility(t *testing.T) {
srv, _, cleanup := setupTaskTestServer(t)
defer cleanup()
appID := createTestApp(t, srv.URL)
body := `{"values":{"np":"8"}}`
url := fmt.Sprintf("%s/api/v1/applications/%d/submit", srv.URL, appID)
resp, err := http.Post(url, "application/json", bytes.NewReader([]byte(body)))
if err != nil {
t.Fatalf("post submit: %v", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusCreated {
b, _ := io.ReadAll(resp.Body)
t.Fatalf("expected 201, got %d: %s", resp.StatusCode, b)
}
var result map[string]interface{}
json.NewDecoder(resp.Body).Decode(&result)
data, _ := result["data"].(map[string]interface{})
if data == nil {
t.Fatal("expected data in response")
}
if data["job_id"] == nil {
t.Errorf("expected job_id in old API response, got: %v", data)
}
jobID, ok := data["job_id"].(float64)
if !ok || jobID == 0 {
t.Errorf("expected non-zero job_id, got %v", data["job_id"])
}
}
*/
func TestTask_ListWithFilters(t *testing.T) {
srv, taskStore, cleanup := setupTaskTestServer(t)
defer cleanup()
appID := createTestApp(t, srv.URL)
now := time.Now()
taskStore.Create(context.Background(), &model.Task{
TaskName: "task-completed", AppID: appID, AppName: "test-app",
Status: model.TaskStatusCompleted, SubmittedAt: now,
})
taskStore.Create(context.Background(), &model.Task{
TaskName: "task-failed", AppID: appID, AppName: "test-app",
Status: model.TaskStatusFailed, ErrorMessage: "boom", SubmittedAt: now,
})
taskStore.Create(context.Background(), &model.Task{
TaskName: "task-queued", AppID: appID, AppName: "test-app",
Status: model.TaskStatusQueued, SubmittedAt: now,
})
total, items := getTasks(t, srv.URL, "status=completed")
if total != 1 {
t.Fatalf("expected 1 completed, got %d", total)
}
task := items[0].(map[string]interface{})
if task["status"] != "completed" {
t.Errorf("expected status=completed, got %v", task["status"])
}
total, items = getTasks(t, srv.URL, "status=failed")
if total != 1 {
t.Fatalf("expected 1 failed, got %d", total)
}
total, _ = getTasks(t, srv.URL, "")
if total != 3 {
t.Fatalf("expected 3 total, got %d", total)
}
}
func TestTask_WorkDirCreated(t *testing.T) {
srv, taskStore, cleanup := setupTaskTestServer(t)
defer cleanup()
appID := createTestApp(t, srv.URL)
resp, result := postTask(t, srv.URL, fmt.Sprintf(
`{"app_id":%d,"task_name":"workdir-test","values":{"np":"1"}}`, appID,
))
if resp.StatusCode != http.StatusCreated {
t.Fatalf("expected 201, got %d: %v", resp.StatusCode, result)
}
taskID := int64(result["data"].(map[string]interface{})["id"].(float64))
deadline := time.Now().Add(5 * time.Second)
for time.Now().Before(deadline) {
task, _ := taskStore.GetByID(context.Background(), taskID)
if task != nil && task.WorkDir != "" {
if _, err := os.Stat(task.WorkDir); os.IsNotExist(err) {
t.Fatalf("work dir %s not created", task.WorkDir)
}
if !filepath.IsAbs(task.WorkDir) {
t.Errorf("expected absolute work dir, got %s", task.WorkDir)
}
return
}
time.Sleep(50 * time.Millisecond)
}
task, _ := taskStore.GetByID(context.Background(), taskID)
if task == nil {
t.Fatalf("task %d not found after deadline", taskID)
}
t.Fatalf("work dir never set; status=%s workdir=%s", task.Status, task.WorkDir)
}

66
go.mod
View File

@@ -1,3 +1,65 @@
module slurm-client
module gcy_hpc_server
go 1.22
go 1.25.0
require (
github.com/gin-gonic/gin v1.12.0
go.uber.org/zap v1.27.1
gopkg.in/natefinch/lumberjack.v2 v2.2.1
gopkg.in/yaml.v3 v3.0.1
gorm.io/driver/mysql v1.6.0
gorm.io/driver/sqlite v1.6.0
gorm.io/gorm v1.31.1
)
require (
filippo.io/edwards25519 v1.1.0 // indirect
github.com/bytedance/gopkg v0.1.3 // indirect
github.com/bytedance/sonic v1.15.0 // indirect
github.com/bytedance/sonic/loader v0.5.0 // indirect
github.com/cloudwego/base64x v0.1.6 // indirect
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect
github.com/dustin/go-humanize v1.0.1 // indirect
github.com/gabriel-vasile/mimetype v1.4.12 // indirect
github.com/gin-contrib/sse v1.1.0 // indirect
github.com/go-ini/ini v1.67.0 // indirect
github.com/go-playground/locales v0.14.1 // indirect
github.com/go-playground/universal-translator v0.18.1 // indirect
github.com/go-playground/validator/v10 v10.30.1 // indirect
github.com/go-sql-driver/mysql v1.9.3 // indirect
github.com/goccy/go-json v0.10.5 // indirect
github.com/goccy/go-yaml v1.19.2 // indirect
github.com/google/uuid v1.6.0 // indirect
github.com/jinzhu/inflection v1.0.0 // indirect
github.com/jinzhu/now v1.1.5 // indirect
github.com/json-iterator/go v1.1.12 // indirect
github.com/klauspost/compress v1.18.2 // indirect
github.com/klauspost/cpuid/v2 v2.3.0 // indirect
github.com/klauspost/crc32 v1.3.0 // indirect
github.com/leodido/go-urn v1.4.0 // indirect
github.com/mattn/go-isatty v0.0.20 // indirect
github.com/mattn/go-sqlite3 v1.14.22 // indirect
github.com/minio/crc64nvme v1.1.1 // indirect
github.com/minio/md5-simd v1.1.2 // indirect
github.com/minio/minio-go/v7 v7.0.100 // indirect
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
github.com/modern-go/reflect2 v1.0.2 // indirect
github.com/pelletier/go-toml/v2 v2.2.4 // indirect
github.com/philhofer/fwd v1.2.0 // indirect
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect
github.com/quic-go/qpack v0.6.0 // indirect
github.com/quic-go/quic-go v0.59.0 // indirect
github.com/rs/xid v1.6.0 // indirect
github.com/tinylib/msgp v1.6.1 // indirect
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
github.com/ugorji/go/codec v1.3.1 // indirect
go.mongodb.org/mongo-driver/v2 v2.5.0 // indirect
go.uber.org/multierr v1.10.0 // indirect
go.yaml.in/yaml/v3 v3.0.4 // indirect
golang.org/x/arch v0.22.0 // indirect
golang.org/x/crypto v0.48.0 // indirect
golang.org/x/net v0.51.0 // indirect
golang.org/x/sys v0.41.0 // indirect
golang.org/x/text v0.34.0 // indirect
google.golang.org/protobuf v1.36.10 // indirect
)

148
go.sum Normal file
View File

@@ -0,0 +1,148 @@
filippo.io/edwards25519 v1.1.0 h1:FNf4tywRC1HmFuKW5xopWpigGjJKiJSV0Cqo0cJWDaA=
filippo.io/edwards25519 v1.1.0/go.mod h1:BxyFTGdWcka3PhytdK4V28tE5sGfRvvvRV7EaN4VDT4=
github.com/bytedance/gopkg v0.1.3 h1:TPBSwH8RsouGCBcMBktLt1AymVo2TVsBVCY4b6TnZ/M=
github.com/bytedance/gopkg v0.1.3/go.mod h1:576VvJ+eJgyCzdjS+c4+77QF3p7ubbtiKARP3TxducM=
github.com/bytedance/sonic v1.15.0 h1:/PXeWFaR5ElNcVE84U0dOHjiMHQOwNIx3K4ymzh/uSE=
github.com/bytedance/sonic v1.15.0/go.mod h1:tFkWrPz0/CUCLEF4ri4UkHekCIcdnkqXw9VduqpJh0k=
github.com/bytedance/sonic/loader v0.5.0 h1:gXH3KVnatgY7loH5/TkeVyXPfESoqSBSBEiDd5VjlgE=
github.com/bytedance/sonic/loader v0.5.0/go.mod h1:AR4NYCk5DdzZizZ5djGqQ92eEhCCcdf5x77udYiSJRo=
github.com/cloudwego/base64x v0.1.6 h1:t11wG9AECkCDk5fMSoxmufanudBtJ+/HemLstXDLI2M=
github.com/cloudwego/base64x v0.1.6/go.mod h1:OFcloc187FXDaYHvrNIjxSe8ncn0OOM8gEHfghB2IPU=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM=
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY=
github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto=
github.com/gabriel-vasile/mimetype v1.4.12 h1:e9hWvmLYvtp846tLHam2o++qitpguFiYCKbn0w9jyqw=
github.com/gabriel-vasile/mimetype v1.4.12/go.mod h1:d+9Oxyo1wTzWdyVUPMmXFvp4F9tea18J8ufA774AB3s=
github.com/gin-contrib/sse v1.1.0 h1:n0w2GMuUpWDVp7qSpvze6fAu9iRxJY4Hmj6AmBOU05w=
github.com/gin-contrib/sse v1.1.0/go.mod h1:hxRZ5gVpWMT7Z0B0gSNYqqsSCNIJMjzvm6fqCz9vjwM=
github.com/gin-gonic/gin v1.12.0 h1:b3YAbrZtnf8N//yjKeU2+MQsh2mY5htkZidOM7O0wG8=
github.com/gin-gonic/gin v1.12.0/go.mod h1:VxccKfsSllpKshkBWgVgRniFFAzFb9csfngsqANjnLc=
github.com/go-ini/ini v1.67.0 h1:z6ZrTEZqSWOTyH2FlglNbNgARyHG8oLW9gMELqKr06A=
github.com/go-ini/ini v1.67.0/go.mod h1:ByCAeIL28uOIIG0E3PJtZPDL8WnHpFKFOtgjp+3Ies8=
github.com/go-playground/assert/v2 v2.2.0 h1:JvknZsQTYeFEAhQwI4qEt9cyV5ONwRHC+lYKSsYSR8s=
github.com/go-playground/assert/v2 v2.2.0/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4=
github.com/go-playground/locales v0.14.1 h1:EWaQ/wswjilfKLTECiXz7Rh+3BjFhfDFKv/oXslEjJA=
github.com/go-playground/locales v0.14.1/go.mod h1:hxrqLVvrK65+Rwrd5Fc6F2O76J/NuW9t0sjnWqG1slY=
github.com/go-playground/universal-translator v0.18.1 h1:Bcnm0ZwsGyWbCzImXv+pAJnYK9S473LQFuzCbDbfSFY=
github.com/go-playground/universal-translator v0.18.1/go.mod h1:xekY+UJKNuX9WP91TpwSH2VMlDf28Uj24BCp08ZFTUY=
github.com/go-playground/validator/v10 v10.30.1 h1:f3zDSN/zOma+w6+1Wswgd9fLkdwy06ntQJp0BBvFG0w=
github.com/go-playground/validator/v10 v10.30.1/go.mod h1:oSuBIQzuJxL//3MelwSLD5hc2Tu889bF0Idm9Dg26cM=
github.com/go-sql-driver/mysql v1.9.3 h1:U/N249h2WzJ3Ukj8SowVFjdtZKfu9vlLZxjPXV1aweo=
github.com/go-sql-driver/mysql v1.9.3/go.mod h1:qn46aNg1333BRMNU69Lq93t8du/dwxI64Gl8i5p1WMU=
github.com/goccy/go-json v0.10.5 h1:Fq85nIqj+gXn/S5ahsiTlK3TmC85qgirsdTP/+DeaC4=
github.com/goccy/go-json v0.10.5/go.mod h1:oq7eo15ShAhp70Anwd5lgX2pLfOS3QCiwU/PULtXL6M=
github.com/goccy/go-yaml v1.19.2 h1:PmFC1S6h8ljIz6gMRBopkjP1TVT7xuwrButHID66PoM=
github.com/goccy/go-yaml v1.19.2/go.mod h1:XBurs7gK8ATbW4ZPGKgcbrY1Br56PdM69F7LkFRi1kA=
github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU=
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E=
github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc=
github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ=
github.com/jinzhu/now v1.1.5/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8=
github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM=
github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo=
github.com/klauspost/compress v1.18.2 h1:iiPHWW0YrcFgpBYhsA6D1+fqHssJscY/Tm/y2Uqnapk=
github.com/klauspost/compress v1.18.2/go.mod h1:R0h/fSBs8DE4ENlcrlib3PsXS61voFxhIs2DeRhCvJ4=
github.com/klauspost/cpuid/v2 v2.0.1/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg=
github.com/klauspost/cpuid/v2 v2.3.0 h1:S4CRMLnYUhGeDFDqkGriYKdfoFlDnMtqTiI/sFzhA9Y=
github.com/klauspost/cpuid/v2 v2.3.0/go.mod h1:hqwkgyIinND0mEev00jJYCxPNVRVXFQeu1XKlok6oO0=
github.com/klauspost/crc32 v1.3.0 h1:sSmTt3gUt81RP655XGZPElI0PelVTZ6YwCRnPSupoFM=
github.com/klauspost/crc32 v1.3.0/go.mod h1:D7kQaZhnkX/Y0tstFGf8VUzv2UofNGqCjnC3zdHB0Hw=
github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE=
github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk=
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
github.com/leodido/go-urn v1.4.0 h1:WT9HwE9SGECu3lg4d/dIA+jxlljEa1/ffXKmRjqdmIQ=
github.com/leodido/go-urn v1.4.0/go.mod h1:bvxc+MVxLKB4z00jd1z+Dvzr47oO32F/QSNjSBOlFxI=
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
github.com/mattn/go-sqlite3 v1.14.22 h1:2gZY6PC6kBnID23Tichd1K+Z0oS6nE/XwU+Vz/5o4kU=
github.com/mattn/go-sqlite3 v1.14.22/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y=
github.com/minio/crc64nvme v1.1.1 h1:8dwx/Pz49suywbO+auHCBpCtlW1OfpcLN7wYgVR6wAI=
github.com/minio/crc64nvme v1.1.1/go.mod h1:eVfm2fAzLlxMdUGc0EEBGSMmPwmXD5XiNRpnu9J3bvg=
github.com/minio/md5-simd v1.1.2 h1:Gdi1DZK69+ZVMoNHRXJyNcxrMA4dSxoYHZSQbirFg34=
github.com/minio/md5-simd v1.1.2/go.mod h1:MzdKDxYpY2BT9XQFocsiZf/NKVtR7nkE4RoEpN+20RM=
github.com/minio/minio-go/v7 v7.0.100 h1:ShkWi8Tyj9RtU57OQB2HIXKz4bFgtVib0bbT1sbtLI8=
github.com/minio/minio-go/v7 v7.0.100/go.mod h1:EtGNKtlX20iL2yaYnxEigaIvj0G0GwSDnifnG8ClIdw=
github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w8PVh93nsPXa1VrQ6jlwL5oN8l14QlcNfg=
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9Gz0M=
github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk=
github.com/pelletier/go-toml/v2 v2.2.4 h1:mye9XuhQ6gvn5h28+VilKrrPoQVanw5PMw/TB0t5Ec4=
github.com/pelletier/go-toml/v2 v2.2.4/go.mod h1:2gIqNv+qfxSVS7cM2xJQKtLSTLUE9V8t9Stt+h56mCY=
github.com/philhofer/fwd v1.2.0 h1:e6DnBTl7vGY+Gz322/ASL4Gyp1FspeMvx1RNDoToZuM=
github.com/philhofer/fwd v1.2.0/go.mod h1:RqIHx9QI14HlwKwm98g9Re5prTQ6LdeRQn+gXJFxsJM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRIccs7FGNTlIRMkT8wgtp5eCXdBlqhYGL6U=
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/quic-go/qpack v0.6.0 h1:g7W+BMYynC1LbYLSqRt8PBg5Tgwxn214ZZR34VIOjz8=
github.com/quic-go/qpack v0.6.0/go.mod h1:lUpLKChi8njB4ty2bFLX2x4gzDqXwUpaO1DP9qMDZII=
github.com/quic-go/quic-go v0.59.0 h1:OLJkp1Mlm/aS7dpKgTc6cnpynnD2Xg7C1pwL6vy/SAw=
github.com/quic-go/quic-go v0.59.0/go.mod h1:upnsH4Ju1YkqpLXC305eW3yDZ4NfnNbmQRCMWS58IKU=
github.com/rogpeppe/go-internal v1.10.0 h1:TMyTOH3F/DB16zRVcYyreMH6GnZZrwQVAoYjRBZyWFQ=
github.com/rogpeppe/go-internal v1.10.0/go.mod h1:UQnix2H7Ngw/k4C5ijL5+65zddjncjaFoBhdsK/akog=
github.com/rs/xid v1.6.0 h1:fV591PaemRlL6JfRxGDEPl69wICngIQ3shQtzfy2gxU=
github.com/rs/xid v1.6.0/go.mod h1:7XoLgs4eV+QndskICGsho+ADou8ySMSjJKDIan90Nz0=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA=
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
github.com/tinylib/msgp v1.6.1 h1:ESRv8eL3u+DNHUoSAAQRE50Hm162zqAnBoGv9PzScPY=
github.com/tinylib/msgp v1.6.1/go.mod h1:RSp0LW9oSxFut3KzESt5Voq4GVWyS+PSulT77roAqEA=
github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS4MhqMhdFk5YI=
github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08=
github.com/ugorji/go/codec v1.3.1 h1:waO7eEiFDwidsBN6agj1vJQ4AG7lh2yqXyOXqhgQuyY=
github.com/ugorji/go/codec v1.3.1/go.mod h1:pRBVtBSKl77K30Bv8R2P+cLSGaTtex6fsA2Wjqmfxj4=
go.mongodb.org/mongo-driver/v2 v2.5.0 h1:yXUhImUjjAInNcpTcAlPHiT7bIXhshCTL3jVBkF3xaE=
go.mongodb.org/mongo-driver/v2 v2.5.0/go.mod h1:yOI9kBsufol30iFsl1slpdq1I0eHPzybRWdyYUs8K/0=
go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto=
go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE=
go.uber.org/mock v0.6.0 h1:hyF9dfmbgIX5EfOdasqLsWD6xqpNZlXblLB/Dbnwv3Y=
go.uber.org/mock v0.6.0/go.mod h1:KiVJ4BqZJaMj4svdfmHM0AUx4NJYO8ZNpPnZn1Z+BBU=
go.uber.org/multierr v1.10.0 h1:S0h4aNzvfcFsC3dRF1jLoaov7oRaKqRGC/pUEJ2yvPQ=
go.uber.org/multierr v1.10.0/go.mod h1:20+QtiLqy0Nd6FdQB9TLXag12DsQkrbs3htMFfDN80Y=
go.uber.org/zap v1.27.1 h1:08RqriUEv8+ArZRYSTXy1LeBScaMpVSTBhCeaZYfMYc=
go.uber.org/zap v1.27.1/go.mod h1:GB2qFLM7cTU87MWRP2mPIjqfIDnGu+VIO4V/SdhGo2E=
go.yaml.in/yaml/v3 v3.0.4 h1:tfq32ie2Jv2UxXFdLJdh3jXuOzWiL1fo0bu/FbuKpbc=
go.yaml.in/yaml/v3 v3.0.4/go.mod h1:DhzuOOF2ATzADvBadXxruRBLzYTpT36CKvDb3+aBEFg=
golang.org/x/arch v0.22.0 h1:c/Zle32i5ttqRXjdLyyHZESLD/bB90DCU1g9l/0YBDI=
golang.org/x/arch v0.22.0/go.mod h1:dNHoOeKiyja7GTvF9NJS1l3Z2yntpQNzgrjh1cU103A=
golang.org/x/crypto v0.48.0 h1:/VRzVqiRSggnhY7gNRxPauEQ5Drw9haKdM0jqfcCFts=
golang.org/x/crypto v0.48.0/go.mod h1:r0kV5h3qnFPlQnBSrULhlsRfryS2pmewsg+XfMgkVos=
golang.org/x/net v0.51.0 h1:94R/GTO7mt3/4wIKpcR5gkGmRLOuE/2hNGeWq/GBIFo=
golang.org/x/net v0.51.0/go.mod h1:aamm+2QF5ogm02fjy5Bb7CQ0WMt1/WVM7FtyaTLlA9Y=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.41.0 h1:Ivj+2Cp/ylzLiEU89QhWblYnOE9zerudt9Ftecq2C6k=
golang.org/x/sys v0.41.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
golang.org/x/text v0.34.0 h1:oL/Qq0Kdaqxa1KbNeMKwQq0reLCCaFtqu2eNuSeNHbk=
golang.org/x/text v0.34.0/go.mod h1:homfLqTYRFyVYemLBFl5GgL/DWEiH5wcsQ5gSh1yziA=
google.golang.org/protobuf v1.36.10 h1:AYd7cD/uASjIL6Q9LiTjz8JLcrh/88q5UObnmY3aOOE=
google.golang.org/protobuf v1.36.10/go.mod h1:HTf+CrKn2C3g5S8VImy6tdcUvCska2kB7j23XfzDpco=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q=
gopkg.in/natefinch/lumberjack.v2 v2.2.1 h1:bBRl1b0OH9s/DuPhuXpNl+VtCaJXFZ5/uEFST95x9zc=
gopkg.in/natefinch/lumberjack.v2 v2.2.1/go.mod h1:YD8tP3GAjkrDg1eZH7EGmyESg/lsYskCTPBJVb9jqSc=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gorm.io/driver/mysql v1.6.0 h1:eNbLmNTpPpTOVZi8MMxCi2aaIm0ZpInbORNXDwyLGvg=
gorm.io/driver/mysql v1.6.0/go.mod h1:D/oCC2GWK3M/dqoLxnOlaNKmXz8WNTfcS9y5ovaSqKo=
gorm.io/driver/sqlite v1.6.0 h1:WHRRrIiulaPiPFmDcod6prc4l2VGVWHz80KspNsxSfQ=
gorm.io/driver/sqlite v1.6.0/go.mod h1:AO9V1qIQddBESngQUKWL9yoH93HIeA1X6V633rBwyT8=
gorm.io/gorm v1.31.1 h1:7CA8FTFz/gRfgqgpeKIBcervUn3xSyPUmr6B2WXJ7kg=
gorm.io/gorm v1.31.1/go.mod h1:XyQVbO2k6YkOis7C2437jSit3SsDK72s7n7rsSHd+Gs=

2954
hpc_server_openapi.json Normal file

File diff suppressed because it is too large Load Diff

227
internal/app/app.go Normal file
View File

@@ -0,0 +1,227 @@
package app
import (
"context"
"errors"
"fmt"
"net/http"
"os"
"os/signal"
"syscall"
"time"
"gcy_hpc_server/internal/config"
"gcy_hpc_server/internal/handler"
"gcy_hpc_server/internal/server"
"gcy_hpc_server/internal/service"
"gcy_hpc_server/internal/slurm"
"gcy_hpc_server/internal/storage"
"gcy_hpc_server/internal/store"
"go.uber.org/zap"
"gorm.io/gorm"
)
type App struct {
cfg *config.Config
logger *zap.Logger
db *gorm.DB
server *http.Server
cancelCleanup context.CancelFunc
taskSvc *service.TaskService
taskPoller *TaskPoller
}
// NewApp initializes all application dependencies: DB, Slurm client, services, handlers, router.
func NewApp(cfg *config.Config, logger *zap.Logger) (*App, error) {
gormDB, err := initDB(cfg, logger)
if err != nil {
return nil, err
}
slurmClient, err := initSlurmClient(cfg)
if err != nil {
closeDB(gormDB)
return nil, err
}
srv, cancelCleanup, taskSvc, taskPoller := initHTTPServer(cfg, gormDB, slurmClient, logger)
return &App{
cfg: cfg,
logger: logger,
db: gormDB,
server: srv,
cancelCleanup: cancelCleanup,
taskSvc: taskSvc,
taskPoller: taskPoller,
}, nil
}
// Run starts the HTTP server and blocks until a shutdown signal or server error.
func (a *App) Run() error {
errCh := make(chan error, 1)
go func() {
a.logger.Info("starting server", zap.String("addr", a.server.Addr))
if err := a.server.ListenAndServe(); err != nil && err != http.ErrServerClosed {
errCh <- fmt.Errorf("server listen: %w", err)
}
}()
quit := make(chan os.Signal, 1)
signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM)
select {
case err := <-errCh:
// Server crashed before receiving a signal — clean up resources before
// returning, because the caller may call os.Exit and skip deferred Close().
a.logger.Error("server exited unexpectedly", zap.Error(err))
_ = a.Close()
return err
case sig := <-quit:
a.logger.Info("received shutdown signal", zap.String("signal", sig.String()))
}
a.logger.Info("shutting down server...")
return a.Close()
}
// Close cleans up all resources: HTTP server and database connections.
func (a *App) Close() error {
var errs []error
if a.taskSvc != nil {
a.taskSvc.StopProcessor()
}
if a.taskPoller != nil {
a.taskPoller.Stop()
}
if a.cancelCleanup != nil {
a.cancelCleanup()
}
if a.server != nil {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
if err := a.server.Shutdown(ctx); err != nil && err != http.ErrServerClosed {
errs = append(errs, fmt.Errorf("shutdown http server: %w", err))
}
}
if a.db != nil {
sqlDB, err := a.db.DB()
if err != nil {
errs = append(errs, fmt.Errorf("get underlying sql.DB: %w", err))
} else if err := sqlDB.Close(); err != nil {
errs = append(errs, fmt.Errorf("close database: %w", err))
}
}
return errors.Join(errs...)
}
// ---------------------------------------------------------------------------
// Initialization helpers
// ---------------------------------------------------------------------------
func initDB(cfg *config.Config, logger *zap.Logger) (*gorm.DB, error) {
gormDB, err := store.NewGormDB(cfg.MySQLDSN, logger, cfg.Log.GormLevel)
if err != nil {
return nil, fmt.Errorf("init database: %w", err)
}
if err := store.AutoMigrate(gormDB); err != nil {
closeDB(gormDB)
return nil, fmt.Errorf("run migrations: %w", err)
}
return gormDB, nil
}
func closeDB(db *gorm.DB) {
if db == nil {
return
}
if sqlDB, err := db.DB(); err == nil {
_ = sqlDB.Close()
}
}
func initSlurmClient(cfg *config.Config) (*slurm.Client, error) {
client, err := service.NewSlurmClient(cfg.SlurmAPIURL, cfg.SlurmUserName, cfg.SlurmJWTKeyPath)
if err != nil {
return nil, fmt.Errorf("init slurm client: %w", err)
}
return client, nil
}
func initHTTPServer(cfg *config.Config, db *gorm.DB, slurmClient *slurm.Client, logger *zap.Logger) (*http.Server, context.CancelFunc, *service.TaskService, *TaskPoller) {
ctx := context.Background()
jobSvc := service.NewJobService(slurmClient, logger)
clusterSvc := service.NewClusterService(slurmClient, logger)
jobH := handler.NewJobHandler(jobSvc, logger)
clusterH := handler.NewClusterHandler(clusterSvc, logger)
appStore := store.NewApplicationStore(db)
// File storage initialization
minioClient, err := storage.NewMinioClient(cfg.Minio)
if err != nil {
logger.Warn("failed to initialize MinIO client, file storage disabled", zap.Error(err))
}
var uploadH *handler.UploadHandler
var fileH *handler.FileHandler
var folderH *handler.FolderHandler
taskStore := store.NewTaskStore(db)
fileStore := store.NewFileStore(db)
blobStore := store.NewBlobStore(db)
var stagingSvc *service.FileStagingService
if minioClient != nil {
folderStore := store.NewFolderStore(db)
uploadStore := store.NewUploadStore(db)
uploadSvc := service.NewUploadService(minioClient, blobStore, fileStore, uploadStore, cfg.Minio, db, logger)
folderSvc := service.NewFolderService(folderStore, fileStore, logger)
fileSvc := service.NewFileService(minioClient, blobStore, fileStore, cfg.Minio.Bucket, db, logger)
uploadH = handler.NewUploadHandler(uploadSvc, logger)
fileH = handler.NewFileHandler(fileSvc, logger)
folderH = handler.NewFolderHandler(folderSvc, logger)
stagingSvc = service.NewFileStagingService(fileStore, blobStore, minioClient, cfg.Minio.Bucket, logger)
}
taskSvc := service.NewTaskService(taskStore, appStore, fileStore, blobStore, stagingSvc, jobSvc, cfg.WorkDirBase, logger)
taskSvc.StartProcessor(ctx)
appSvc := service.NewApplicationService(appStore, jobSvc, cfg.WorkDirBase, logger, taskSvc)
appH := handler.NewApplicationHandler(appSvc, logger)
poller := NewTaskPoller(taskSvc, 10*time.Second, logger)
poller.Start(ctx)
taskH := handler.NewTaskHandler(taskSvc, logger)
var cancelCleanup context.CancelFunc
if minioClient != nil {
cleanupCtx, cancel := context.WithCancel(context.Background())
cancelCleanup = cancel
go startCleanupWorker(cleanupCtx, store.NewUploadStore(db), minioClient, cfg.Minio.Bucket, logger)
}
router := server.NewRouter(jobH, clusterH, appH, uploadH, fileH, folderH, taskH, logger)
addr := ":" + cfg.ServerPort
return &http.Server{
Addr: addr,
Handler: router,
}, cancelCleanup, taskSvc, poller
}

25
internal/app/app_test.go Normal file
View File

@@ -0,0 +1,25 @@
package app
import (
"testing"
"gcy_hpc_server/internal/config"
"go.uber.org/zap/zaptest"
)
func TestNewApp_InvalidDB(t *testing.T) {
cfg := &config.Config{
ServerPort: "8080",
MySQLDSN: "invalid:dsn@tcp(localhost:99999)/nonexistent?parseTime=true",
SlurmAPIURL: "http://localhost:6820",
SlurmUserName: "root",
SlurmJWTKeyPath: "/nonexistent/jwt.key",
}
logger := zaptest.NewLogger(t)
_, err := NewApp(cfg, logger)
if err == nil {
t.Fatal("expected error for invalid DSN, got nil")
}
}

83
internal/app/cleanup.go Normal file
View File

@@ -0,0 +1,83 @@
package app
import (
"context"
"time"
"gcy_hpc_server/internal/storage"
"gcy_hpc_server/internal/store"
"go.uber.org/zap"
)
// startCleanupWorker runs a background goroutine that periodically cleans up:
// 1. Expired upload sessions (mark → delete MinIO chunks → delete DB records)
// 2. Leaked multipart uploads from failed ComposeObject calls
func startCleanupWorker(ctx context.Context, uploadStore *store.UploadStore, objStorage storage.ObjectStorage, bucket string, logger *zap.Logger) {
ticker := time.NewTicker(1 * time.Hour)
defer ticker.Stop()
cleanupExpiredSessions(ctx, uploadStore, objStorage, bucket, logger)
cleanupLeakedMultipartUploads(ctx, objStorage, bucket, logger)
for {
select {
case <-ctx.Done():
logger.Info("cleanup worker stopped")
return
case <-ticker.C:
cleanupExpiredSessions(ctx, uploadStore, objStorage, bucket, logger)
cleanupLeakedMultipartUploads(ctx, objStorage, bucket, logger)
}
}
}
// cleanupExpiredSessions performs three-phase cleanup of expired upload sessions:
// Phase 1: Find and mark expired sessions
// Phase 2: Delete MinIO temp chunks for each session
// Phase 3: Delete DB records (session + chunks)
func cleanupExpiredSessions(ctx context.Context, uploadStore *store.UploadStore, objStorage storage.ObjectStorage, bucket string, logger *zap.Logger) {
sessions, err := uploadStore.ListExpiredSessions(ctx)
if err != nil {
logger.Error("failed to list expired sessions", zap.Error(err))
return
}
if len(sessions) == 0 {
return
}
logger.Info("cleaning up expired sessions", zap.Int("count", len(sessions)))
for i := range sessions {
session := &sessions[i]
if err := uploadStore.UpdateSessionStatus(ctx, session.ID, "expired"); err != nil {
logger.Error("failed to mark session expired", zap.Int64("session_id", session.ID), zap.Error(err))
continue
}
objects, err := objStorage.ListObjects(ctx, bucket, session.MinioPrefix, true)
if err != nil {
logger.Error("failed to list session objects", zap.Int64("session_id", session.ID), zap.Error(err))
} else if len(objects) > 0 {
keys := make([]string, len(objects))
for j, obj := range objects {
keys[j] = obj.Key
}
if err := objStorage.RemoveObjects(ctx, bucket, keys, storage.RemoveObjectsOptions{}); err != nil {
logger.Error("failed to remove session objects", zap.Int64("session_id", session.ID), zap.Error(err))
}
}
if err := uploadStore.DeleteSession(ctx, session.ID); err != nil {
logger.Error("failed to delete session", zap.Int64("session_id", session.ID), zap.Error(err))
}
}
}
func cleanupLeakedMultipartUploads(ctx context.Context, objStorage storage.ObjectStorage, bucket string, logger *zap.Logger) {
if err := objStorage.RemoveIncompleteUpload(ctx, bucket, "uploads/"); err != nil {
logger.Error("failed to cleanup leaked multipart uploads", zap.Error(err))
}
}

View File

@@ -0,0 +1,266 @@
package app
import (
"context"
"io"
"testing"
"time"
"gcy_hpc_server/internal/model"
"gcy_hpc_server/internal/storage"
"gcy_hpc_server/internal/store"
"go.uber.org/zap"
"gorm.io/driver/sqlite"
"gorm.io/gorm"
)
// mockCleanupStorage implements ObjectStorage for cleanup tests.
type mockCleanupStorage struct {
listObjectsFn func(ctx context.Context, bucket, prefix string, recursive bool) ([]storage.ObjectInfo, error)
removeObjectsFn func(ctx context.Context, bucket string, keys []string, opts storage.RemoveObjectsOptions) error
removeIncompleteFn func(ctx context.Context, bucket, object string) error
}
func (m *mockCleanupStorage) PutObject(_ context.Context, _ string, _ string, _ io.Reader, _ int64, _ storage.PutObjectOptions) (storage.UploadInfo, error) {
return storage.UploadInfo{}, nil
}
func (m *mockCleanupStorage) GetObject(_ context.Context, _ string, _ string, _ storage.GetOptions) (io.ReadCloser, storage.ObjectInfo, error) {
return nil, storage.ObjectInfo{}, nil
}
func (m *mockCleanupStorage) ComposeObject(_ context.Context, _ string, _ string, _ []string) (storage.UploadInfo, error) {
return storage.UploadInfo{}, nil
}
func (m *mockCleanupStorage) AbortMultipartUpload(_ context.Context, _ string, _ string, _ string) error {
return nil
}
func (m *mockCleanupStorage) RemoveIncompleteUpload(_ context.Context, _ string, _ string) error {
if m.removeIncompleteFn != nil {
return m.removeIncompleteFn(context.Background(), "", "")
}
return nil
}
func (m *mockCleanupStorage) RemoveObject(_ context.Context, _ string, _ string, _ storage.RemoveObjectOptions) error {
return nil
}
func (m *mockCleanupStorage) ListObjects(ctx context.Context, bucket string, prefix string, recursive bool) ([]storage.ObjectInfo, error) {
if m.listObjectsFn != nil {
return m.listObjectsFn(ctx, bucket, prefix, recursive)
}
return nil, nil
}
func (m *mockCleanupStorage) RemoveObjects(ctx context.Context, bucket string, keys []string, opts storage.RemoveObjectsOptions) error {
if m.removeObjectsFn != nil {
return m.removeObjectsFn(ctx, bucket, keys, opts)
}
return nil
}
func (m *mockCleanupStorage) BucketExists(_ context.Context, _ string) (bool, error) {
return true, nil
}
func (m *mockCleanupStorage) MakeBucket(_ context.Context, _ string, _ storage.MakeBucketOptions) error {
return nil
}
func (m *mockCleanupStorage) StatObject(_ context.Context, _ string, _ string, _ storage.StatObjectOptions) (storage.ObjectInfo, error) {
return storage.ObjectInfo{}, nil
}
func setupCleanupTestDB(t *testing.T) (*gorm.DB, *store.UploadStore) {
t.Helper()
db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{})
if err != nil {
t.Fatalf("open sqlite: %v", err)
}
if err := db.AutoMigrate(&model.UploadSession{}, &model.UploadChunk{}); err != nil {
t.Fatalf("migrate: %v", err)
}
return db, store.NewUploadStore(db)
}
func TestCleanupExpiredSessions(t *testing.T) {
_, uploadStore := setupCleanupTestDB(t)
ctx := context.Background()
past := time.Now().Add(-1 * time.Hour)
err := uploadStore.CreateSession(ctx, &model.UploadSession{
FileName: "expired.bin",
FileSize: 1024,
ChunkSize: 16 << 20,
TotalChunks: 1,
SHA256: "expired_hash",
Status: "pending",
MinioPrefix: "uploads/99/",
ExpiresAt: past,
})
if err != nil {
t.Fatalf("create session: %v", err)
}
var listedPrefix string
var removedKeys []string
mockStore := &mockCleanupStorage{
listObjectsFn: func(_ context.Context, _, prefix string, _ bool) ([]storage.ObjectInfo, error) {
listedPrefix = prefix
return []storage.ObjectInfo{{Key: "uploads/99/chunk_00000", Size: 100}}, nil
},
removeObjectsFn: func(_ context.Context, _ string, keys []string, _ storage.RemoveObjectsOptions) error {
removedKeys = keys
return nil
},
}
cleanupExpiredSessions(ctx, uploadStore, mockStore, "test-bucket", zap.NewNop())
if listedPrefix != "uploads/99/" {
t.Errorf("listed prefix = %q, want %q", listedPrefix, "uploads/99/")
}
if len(removedKeys) != 1 || removedKeys[0] != "uploads/99/chunk_00000" {
t.Errorf("removed keys = %v, want [uploads/99/chunk_00000]", removedKeys)
}
session, err := uploadStore.GetSession(ctx, 1)
if err != nil {
t.Fatalf("get session: %v", err)
}
if session != nil {
t.Error("session should be deleted after cleanup")
}
}
func TestCleanupExpiredSessions_Empty(t *testing.T) {
_, uploadStore := setupCleanupTestDB(t)
ctx := context.Background()
called := false
mockStore := &mockCleanupStorage{
listObjectsFn: func(_ context.Context, _, _ string, _ bool) ([]storage.ObjectInfo, error) {
called = true
return nil, nil
},
}
cleanupExpiredSessions(ctx, uploadStore, mockStore, "test-bucket", zap.NewNop())
if called {
t.Error("ListObjects should not be called when no expired sessions exist")
}
}
func TestCleanupExpiredSessions_CompletedNotCleaned(t *testing.T) {
_, uploadStore := setupCleanupTestDB(t)
ctx := context.Background()
past := time.Now().Add(-1 * time.Hour)
err := uploadStore.CreateSession(ctx, &model.UploadSession{
FileName: "completed.bin",
FileSize: 1024,
ChunkSize: 16 << 20,
TotalChunks: 1,
SHA256: "completed_hash",
Status: "completed",
MinioPrefix: "uploads/100/",
ExpiresAt: past,
})
if err != nil {
t.Fatalf("create session: %v", err)
}
called := false
mockStore := &mockCleanupStorage{
listObjectsFn: func(_ context.Context, _, _ string, _ bool) ([]storage.ObjectInfo, error) {
called = true
return nil, nil
},
}
cleanupExpiredSessions(ctx, uploadStore, mockStore, "test-bucket", zap.NewNop())
if called {
t.Error("ListObjects should not be called for completed sessions")
}
session, _ := uploadStore.GetSession(ctx, 1)
if session == nil {
t.Error("completed session should not be deleted")
}
}
func TestCleanupWorker_StopsOnContextCancel(t *testing.T) {
_, uploadStore := setupCleanupTestDB(t)
ctx, cancel := context.WithCancel(context.Background())
mockStore := &mockCleanupStorage{}
done := make(chan struct{})
go func() {
startCleanupWorker(ctx, uploadStore, mockStore, "test-bucket", zap.NewNop())
close(done)
}()
time.Sleep(100 * time.Millisecond)
cancel()
select {
case <-done:
case <-time.After(2 * time.Second):
t.Error("worker did not stop after context cancel")
}
}
func TestCleanupExpiredSessions_NoObjects(t *testing.T) {
_, uploadStore := setupCleanupTestDB(t)
ctx := context.Background()
past := time.Now().Add(-1 * time.Hour)
err := uploadStore.CreateSession(ctx, &model.UploadSession{
FileName: "empty.bin",
FileSize: 1024,
ChunkSize: 16 << 20,
TotalChunks: 1,
SHA256: "empty_hash",
Status: "pending",
MinioPrefix: "uploads/200/",
ExpiresAt: past,
})
if err != nil {
t.Fatalf("create session: %v", err)
}
listCalled := false
removeCalled := false
mockStore := &mockCleanupStorage{
listObjectsFn: func(_ context.Context, _, _ string, _ bool) ([]storage.ObjectInfo, error) {
listCalled = true
return nil, nil
},
removeObjectsFn: func(_ context.Context, _ string, _ []string, _ storage.RemoveObjectsOptions) error {
removeCalled = true
return nil
},
}
cleanupExpiredSessions(ctx, uploadStore, mockStore, "test-bucket", zap.NewNop())
if !listCalled {
t.Error("ListObjects should be called")
}
if removeCalled {
t.Error("RemoveObjects should not be called when no objects found")
}
session, _ := uploadStore.GetSession(ctx, 1)
if session != nil {
t.Error("session should be deleted even when no objects found")
}
}

View File

@@ -0,0 +1,61 @@
package app
import (
"context"
"sync"
"time"
"go.uber.org/zap"
)
// TaskPollable defines the interface for refreshing stale task statuses.
type TaskPollable interface {
RefreshStaleTasks(ctx context.Context) error
}
// TaskPoller periodically polls Slurm for task status updates via TaskPollable.
type TaskPoller struct {
taskSvc TaskPollable
interval time.Duration
cancel context.CancelFunc
wg sync.WaitGroup
logger *zap.Logger
}
// NewTaskPoller creates a new TaskPoller with the given service, interval, and logger.
func NewTaskPoller(taskSvc TaskPollable, interval time.Duration, logger *zap.Logger) *TaskPoller {
return &TaskPoller{
taskSvc: taskSvc,
interval: interval,
logger: logger,
}
}
// Start launches the background goroutine that periodically refreshes stale tasks.
func (p *TaskPoller) Start(ctx context.Context) {
ctx, p.cancel = context.WithCancel(ctx)
p.wg.Add(1)
go func() {
defer p.wg.Done()
ticker := time.NewTicker(p.interval)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
return
case <-ticker.C:
if err := p.taskSvc.RefreshStaleTasks(ctx); err != nil {
p.logger.Error("failed to refresh stale tasks", zap.Error(err))
}
}
}
}()
}
// Stop cancels the background goroutine and waits for it to finish.
func (p *TaskPoller) Stop() {
if p.cancel != nil {
p.cancel()
}
p.wg.Wait()
}

View File

@@ -0,0 +1,70 @@
package app
import (
"context"
"sync"
"testing"
"time"
"go.uber.org/zap"
)
type mockTaskPollable struct {
refreshFunc func(ctx context.Context) error
callCount int
mu sync.Mutex
}
func (m *mockTaskPollable) RefreshStaleTasks(ctx context.Context) error {
m.mu.Lock()
defer m.mu.Unlock()
m.callCount++
if m.refreshFunc != nil {
return m.refreshFunc(ctx)
}
return nil
}
func (m *mockTaskPollable) getCallCount() int {
m.mu.Lock()
defer m.mu.Unlock()
return m.callCount
}
func TestTaskPoller_StartStop(t *testing.T) {
mock := &mockTaskPollable{}
logger := zap.NewNop()
poller := NewTaskPoller(mock, 1*time.Second, logger)
poller.Start(context.Background())
time.Sleep(100 * time.Millisecond)
poller.Stop()
// No goroutine leak — Stop() returned means wg.Wait() completed.
}
func TestTaskPoller_RefreshesStaleTasks(t *testing.T) {
mock := &mockTaskPollable{}
logger := zap.NewNop()
poller := NewTaskPoller(mock, 50*time.Millisecond, logger)
poller.Start(context.Background())
defer poller.Stop()
time.Sleep(300 * time.Millisecond)
if count := mock.getCallCount(); count < 1 {
t.Errorf("expected RefreshStaleTasks to be called at least once, got %d", count)
}
}
func TestTaskPoller_StopsCleanly(t *testing.T) {
mock := &mockTaskPollable{}
logger := zap.NewNop()
poller := NewTaskPoller(mock, 1*time.Second, logger)
poller.Start(context.Background())
poller.Stop()
// No panic and WaitGroup is done — Stop returned successfully.
}

View File

@@ -0,0 +1,28 @@
server_port: "8080"
slurm_api_url: "http://localhost:6820"
slurm_user_name: "root"
slurm_jwt_key_path: "/etc/slurm/jwt_hs256.key"
mysql_dsn: "root:@tcp(127.0.0.1:3306)/hpc_platform?parseTime=true"
work_dir_base: "/mnt/nfs_mount/platform" # 作业工作目录根路径,留空则不自动创建
log:
level: "info" # debug, info, warn, error
encoding: "json" # json, console
output_stdout: true # 是否输出日志到终端
file_path: "" # 日志文件路径,留空则不写文件
max_size: 100 # max MB per log file
max_backups: 5 # number of old log files to retain
max_age: 30 # days to retain old log files
compress: true # gzip rotated log files
gorm_level: "warn" # GORM SQL log level: silent, error, warn, info
minio:
endpoint: "http://fnos.dailz.cn:15001" # MinIO server address
access_key: "3dgDu9ncwflLoRQW2OeP" # access key
secret_key: "g2GLBNTPxJ9sdFwh37jtfilRSacEO5yQepMkDrnV" # secret key
bucket: "test" # bucket name
use_ssl: false # use TLS connection
chunk_size: 16777216 # upload chunk size in bytes (default: 16MB)
max_file_size: 53687091200 # max file size in bytes (default: 50GB)
min_chunk_size: 5242880 # minimum chunk size in bytes (default: 5MB)
session_ttl: 48 # session TTL in hours (default: 48)

79
internal/config/config.go Normal file
View File

@@ -0,0 +1,79 @@
package config
import (
"fmt"
"os"
"gopkg.in/yaml.v3"
)
// LogConfig holds logging configuration values.
type LogConfig struct {
Level string `yaml:"level"` // debug, info, warn, error (default: info)
Encoding string `yaml:"encoding"` // json, console (default: json)
OutputStdout *bool `yaml:"output_stdout"` // 输出到终端 (default: true)
FilePath string `yaml:"file_path"` // log file path (for rotation)
MaxSize int `yaml:"max_size"` // MB per file (default: 100)
MaxBackups int `yaml:"max_backups"` // retained files (default: 5)
MaxAge int `yaml:"max_age"` // days to retain (default: 30)
Compress bool `yaml:"compress"` // gzip old files (default: true)
GormLevel string `yaml:"gorm_level"` // GORM SQL log level (default: warn)
}
// MinioConfig holds MinIO object storage configuration values.
type MinioConfig struct {
Endpoint string `yaml:"endpoint"` // MinIO server address
AccessKey string `yaml:"access_key"` // access key
SecretKey string `yaml:"secret_key"` // secret key
Bucket string `yaml:"bucket"` // bucket name
UseSSL bool `yaml:"use_ssl"` // use TLS connection
ChunkSize int64 `yaml:"chunk_size"` // upload chunk size in bytes (default: 16MB)
MaxFileSize int64 `yaml:"max_file_size"` // max file size in bytes (default: 50GB)
MinChunkSize int64 `yaml:"min_chunk_size"` // minimum chunk size in bytes (default: 5MB)
SessionTTL int `yaml:"session_ttl"` // session TTL in hours (default: 48)
}
// Config holds all application configuration values.
type Config struct {
ServerPort string `yaml:"server_port"`
SlurmAPIURL string `yaml:"slurm_api_url"`
SlurmUserName string `yaml:"slurm_user_name"`
SlurmJWTKeyPath string `yaml:"slurm_jwt_key_path"`
MySQLDSN string `yaml:"mysql_dsn"`
WorkDirBase string `yaml:"work_dir_base"` // base directory for job work dirs
Log LogConfig `yaml:"log"`
Minio MinioConfig `yaml:"minio"`
}
// Load reads a YAML configuration file and returns a parsed Config.
// If path is empty, it defaults to "./config.yaml".
func Load(path string) (*Config, error) {
if path == "" {
path = "./config.yaml"
}
data, err := os.ReadFile(path)
if err != nil {
return nil, fmt.Errorf("read config file %s: %w", path, err)
}
var cfg Config
if err := yaml.Unmarshal(data, &cfg); err != nil {
return nil, fmt.Errorf("parse config file %s: %w", path, err)
}
if cfg.Minio.ChunkSize == 0 {
cfg.Minio.ChunkSize = 16 << 20 // 16MB
}
if cfg.Minio.MaxFileSize == 0 {
cfg.Minio.MaxFileSize = 50 << 30 // 50GB
}
if cfg.Minio.MinChunkSize == 0 {
cfg.Minio.MinChunkSize = 5 << 20 // 5MB
}
if cfg.Minio.SessionTTL == 0 {
cfg.Minio.SessionTTL = 48
}
return &cfg, nil
}

View File

@@ -0,0 +1,388 @@
package config
import (
"os"
"path/filepath"
"testing"
)
func TestLoad(t *testing.T) {
content := []byte(`server_port: "9090"
slurm_api_url: "http://slurm.example.com:6820"
slurm_user_name: "admin"
slurm_jwt_key_path: "/etc/slurm/jwt.key"
mysql_dsn: "user:pass@tcp(10.0.0.1:3306)/testdb?parseTime=true"
`)
dir := t.TempDir()
path := filepath.Join(dir, "config.yaml")
if err := os.WriteFile(path, content, 0644); err != nil {
t.Fatalf("write temp config: %v", err)
}
cfg, err := Load(path)
if err != nil {
t.Fatalf("Load() error = %v", err)
}
if cfg.ServerPort != "9090" {
t.Errorf("ServerPort = %q, want %q", cfg.ServerPort, "9090")
}
if cfg.SlurmAPIURL != "http://slurm.example.com:6820" {
t.Errorf("SlurmAPIURL = %q, want %q", cfg.SlurmAPIURL, "http://slurm.example.com:6820")
}
if cfg.SlurmUserName != "admin" {
t.Errorf("SlurmUserName = %q, want %q", cfg.SlurmUserName, "admin")
}
if cfg.SlurmJWTKeyPath != "/etc/slurm/jwt.key" {
t.Errorf("SlurmJWTKeyPath = %q, want %q", cfg.SlurmJWTKeyPath, "/etc/slurm/jwt.key")
}
if cfg.MySQLDSN != "user:pass@tcp(10.0.0.1:3306)/testdb?parseTime=true" {
t.Errorf("MySQLDSN = %q, want %q", cfg.MySQLDSN, "user:pass@tcp(10.0.0.1:3306)/testdb?parseTime=true")
}
}
func TestLoadDefaultPath(t *testing.T) {
content := []byte(`server_port: "8080"
slurm_api_url: "http://localhost:6820"
slurm_user_name: "root"
slurm_jwt_key_path: "/etc/slurm/jwt_hs256.key"
mysql_dsn: "root:@tcp(127.0.0.1:3306)/hpc_platform?parseTime=true"
`)
dir := t.TempDir()
path := filepath.Join(dir, "config.yaml")
if err := os.WriteFile(path, content, 0644); err != nil {
t.Fatalf("write temp config: %v", err)
}
cfg, err := Load(path)
if err != nil {
t.Fatalf("Load() error = %v", err)
}
if cfg == nil {
t.Fatal("Load() returned nil config")
}
}
func TestLoadWithLogConfig(t *testing.T) {
content := []byte(`server_port: "9090"
slurm_api_url: "http://slurm.example.com:6820"
slurm_user_name: "admin"
slurm_jwt_key_path: "/etc/slurm/jwt.key"
mysql_dsn: "user:pass@tcp(10.0.0.1:3306)/testdb?parseTime=true"
log:
level: "debug"
encoding: "console"
output_stdout: true
file_path: "/var/log/app.log"
max_size: 200
max_backups: 10
max_age: 60
compress: false
gorm_level: "info"
`)
dir := t.TempDir()
path := filepath.Join(dir, "config.yaml")
if err := os.WriteFile(path, content, 0644); err != nil {
t.Fatalf("write temp config: %v", err)
}
cfg, err := Load(path)
if err != nil {
t.Fatalf("Load() error = %v", err)
}
if cfg.Log.Level != "debug" {
t.Errorf("Log.Level = %q, want %q", cfg.Log.Level, "debug")
}
if cfg.Log.Encoding != "console" {
t.Errorf("Log.Encoding = %q, want %q", cfg.Log.Encoding, "console")
}
if cfg.Log.OutputStdout == nil || *cfg.Log.OutputStdout != true {
t.Errorf("Log.OutputStdout = %v, want true", cfg.Log.OutputStdout)
}
if cfg.Log.FilePath != "/var/log/app.log" {
t.Errorf("Log.FilePath = %q, want %q", cfg.Log.FilePath, "/var/log/app.log")
}
if cfg.Log.MaxSize != 200 {
t.Errorf("Log.MaxSize = %d, want %d", cfg.Log.MaxSize, 200)
}
if cfg.Log.MaxBackups != 10 {
t.Errorf("Log.MaxBackups = %d, want %d", cfg.Log.MaxBackups, 10)
}
if cfg.Log.MaxAge != 60 {
t.Errorf("Log.MaxAge = %d, want %d", cfg.Log.MaxAge, 60)
}
if cfg.Log.Compress != false {
t.Errorf("Log.Compress = %v, want %v", cfg.Log.Compress, false)
}
if cfg.Log.GormLevel != "info" {
t.Errorf("Log.GormLevel = %q, want %q", cfg.Log.GormLevel, "info")
}
}
func TestLoadWithoutLogConfig(t *testing.T) {
content := []byte(`server_port: "8080"
slurm_api_url: "http://localhost:6820"
slurm_user_name: "root"
slurm_jwt_key_path: "/etc/slurm/jwt_hs256.key"
mysql_dsn: "root:@tcp(127.0.0.1:3306)/hpc_platform?parseTime=true"
`)
dir := t.TempDir()
path := filepath.Join(dir, "config.yaml")
if err := os.WriteFile(path, content, 0644); err != nil {
t.Fatalf("write temp config: %v", err)
}
cfg, err := Load(path)
if err != nil {
t.Fatalf("Load() error = %v", err)
}
if cfg.Log.Level != "" {
t.Errorf("Log.Level = %q, want empty string", cfg.Log.Level)
}
if cfg.Log.Encoding != "" {
t.Errorf("Log.Encoding = %q, want empty string", cfg.Log.Encoding)
}
if cfg.Log.OutputStdout != nil {
t.Errorf("Log.OutputStdout = %v, want nil", cfg.Log.OutputStdout)
}
if cfg.Log.FilePath != "" {
t.Errorf("Log.FilePath = %q, want empty string", cfg.Log.FilePath)
}
if cfg.Log.MaxSize != 0 {
t.Errorf("Log.MaxSize = %d, want 0", cfg.Log.MaxSize)
}
if cfg.Log.MaxBackups != 0 {
t.Errorf("Log.MaxBackups = %d, want 0", cfg.Log.MaxBackups)
}
if cfg.Log.MaxAge != 0 {
t.Errorf("Log.MaxAge = %d, want 0", cfg.Log.MaxAge)
}
if cfg.Log.Compress != false {
t.Errorf("Log.Compress = %v, want false", cfg.Log.Compress)
}
if cfg.Log.GormLevel != "" {
t.Errorf("Log.GormLevel = %q, want empty string", cfg.Log.GormLevel)
}
}
func TestLoadExistingFieldsWithLogConfig(t *testing.T) {
content := []byte(`server_port: "7070"
slurm_api_url: "http://slurm2.example.com:6820"
slurm_user_name: "testuser"
slurm_jwt_key_path: "/keys/jwt.key"
mysql_dsn: "root:secret@tcp(db:3306)/mydb?parseTime=true"
log:
level: "warn"
encoding: "json"
`)
dir := t.TempDir()
path := filepath.Join(dir, "config.yaml")
if err := os.WriteFile(path, content, 0644); err != nil {
t.Fatalf("write temp config: %v", err)
}
cfg, err := Load(path)
if err != nil {
t.Fatalf("Load() error = %v", err)
}
if cfg.ServerPort != "7070" {
t.Errorf("ServerPort = %q, want %q", cfg.ServerPort, "7070")
}
if cfg.SlurmAPIURL != "http://slurm2.example.com:6820" {
t.Errorf("SlurmAPIURL = %q, want %q", cfg.SlurmAPIURL, "http://slurm2.example.com:6820")
}
if cfg.SlurmUserName != "testuser" {
t.Errorf("SlurmUserName = %q, want %q", cfg.SlurmUserName, "testuser")
}
if cfg.SlurmJWTKeyPath != "/keys/jwt.key" {
t.Errorf("SlurmJWTKeyPath = %q, want %q", cfg.SlurmJWTKeyPath, "/keys/jwt.key")
}
if cfg.MySQLDSN != "root:secret@tcp(db:3306)/mydb?parseTime=true" {
t.Errorf("MySQLDSN = %q, want %q", cfg.MySQLDSN, "root:secret@tcp(db:3306)/mydb?parseTime=true")
}
if cfg.Log.Level != "warn" {
t.Errorf("Log.Level = %q, want %q", cfg.Log.Level, "warn")
}
if cfg.Log.Encoding != "json" {
t.Errorf("Log.Encoding = %q, want %q", cfg.Log.Encoding, "json")
}
}
func TestLoadNonExistentFile(t *testing.T) {
_, err := Load("/nonexistent/path/config.yaml")
if err == nil {
t.Fatal("Load() expected error for non-existent file, got nil")
}
}
func TestLoadWithOutputStdoutFalse(t *testing.T) {
content := []byte(`server_port: "9090"
slurm_api_url: "http://slurm.example.com:6820"
slurm_user_name: "admin"
slurm_jwt_key_path: "/etc/slurm/jwt.key"
mysql_dsn: "user:pass@tcp(10.0.0.1:3306)/testdb?parseTime=true"
log:
level: "info"
encoding: "json"
output_stdout: false
file_path: "/var/log/app.log"
`)
dir := t.TempDir()
path := filepath.Join(dir, "config.yaml")
if err := os.WriteFile(path, content, 0644); err != nil {
t.Fatalf("write temp config: %v", err)
}
cfg, err := Load(path)
if err != nil {
t.Fatalf("Load() error = %v", err)
}
if cfg.Log.OutputStdout == nil || *cfg.Log.OutputStdout != false {
t.Errorf("Log.OutputStdout = %v, want false", cfg.Log.OutputStdout)
}
if cfg.Log.FilePath != "/var/log/app.log" {
t.Errorf("Log.FilePath = %q, want %q", cfg.Log.FilePath, "/var/log/app.log")
}
}
func TestLoadWithMinioConfig(t *testing.T) {
content := []byte(`server_port: "9090"
slurm_api_url: "http://slurm.example.com:6820"
slurm_user_name: "admin"
slurm_jwt_key_path: "/etc/slurm/jwt.key"
mysql_dsn: "user:pass@tcp(10.0.0.1:3306)/testdb?parseTime=true"
minio:
endpoint: "minio.example.com:9000"
access_key: "myaccesskey"
secret_key: "mysecretkey"
bucket: "test-bucket"
use_ssl: true
chunk_size: 33554432
max_file_size: 107374182400
min_chunk_size: 10485760
session_ttl: 24
`)
dir := t.TempDir()
path := filepath.Join(dir, "config.yaml")
if err := os.WriteFile(path, content, 0644); err != nil {
t.Fatalf("write temp config: %v", err)
}
cfg, err := Load(path)
if err != nil {
t.Fatalf("Load() error = %v", err)
}
if cfg.Minio.Endpoint != "minio.example.com:9000" {
t.Errorf("Minio.Endpoint = %q, want %q", cfg.Minio.Endpoint, "minio.example.com:9000")
}
if cfg.Minio.AccessKey != "myaccesskey" {
t.Errorf("Minio.AccessKey = %q, want %q", cfg.Minio.AccessKey, "myaccesskey")
}
if cfg.Minio.SecretKey != "mysecretkey" {
t.Errorf("Minio.SecretKey = %q, want %q", cfg.Minio.SecretKey, "mysecretkey")
}
if cfg.Minio.Bucket != "test-bucket" {
t.Errorf("Minio.Bucket = %q, want %q", cfg.Minio.Bucket, "test-bucket")
}
if cfg.Minio.UseSSL != true {
t.Errorf("Minio.UseSSL = %v, want %v", cfg.Minio.UseSSL, true)
}
if cfg.Minio.ChunkSize != 33554432 {
t.Errorf("Minio.ChunkSize = %d, want %d", cfg.Minio.ChunkSize, 33554432)
}
if cfg.Minio.MaxFileSize != 107374182400 {
t.Errorf("Minio.MaxFileSize = %d, want %d", cfg.Minio.MaxFileSize, 107374182400)
}
if cfg.Minio.MinChunkSize != 10485760 {
t.Errorf("Minio.MinChunkSize = %d, want %d", cfg.Minio.MinChunkSize, 10485760)
}
if cfg.Minio.SessionTTL != 24 {
t.Errorf("Minio.SessionTTL = %d, want %d", cfg.Minio.SessionTTL, 24)
}
}
func TestLoadWithoutMinioConfig(t *testing.T) {
content := []byte(`server_port: "8080"
slurm_api_url: "http://localhost:6820"
slurm_user_name: "root"
slurm_jwt_key_path: "/etc/slurm/jwt_hs256.key"
mysql_dsn: "root:@tcp(127.0.0.1:3306)/hpc_platform?parseTime=true"
`)
dir := t.TempDir()
path := filepath.Join(dir, "config.yaml")
if err := os.WriteFile(path, content, 0644); err != nil {
t.Fatalf("write temp config: %v", err)
}
cfg, err := Load(path)
if err != nil {
t.Fatalf("Load() error = %v", err)
}
if cfg.Minio.Endpoint != "" {
t.Errorf("Minio.Endpoint = %q, want empty string", cfg.Minio.Endpoint)
}
if cfg.Minio.AccessKey != "" {
t.Errorf("Minio.AccessKey = %q, want empty string", cfg.Minio.AccessKey)
}
if cfg.Minio.SecretKey != "" {
t.Errorf("Minio.SecretKey = %q, want empty string", cfg.Minio.SecretKey)
}
if cfg.Minio.Bucket != "" {
t.Errorf("Minio.Bucket = %q, want empty string", cfg.Minio.Bucket)
}
if cfg.Minio.UseSSL != false {
t.Errorf("Minio.UseSSL = %v, want false", cfg.Minio.UseSSL)
}
}
func TestLoadMinioDefaults(t *testing.T) {
content := []byte(`server_port: "8080"
slurm_api_url: "http://localhost:6820"
slurm_user_name: "root"
slurm_jwt_key_path: "/etc/slurm/jwt_hs256.key"
mysql_dsn: "root:@tcp(127.0.0.1:3306)/hpc_platform?parseTime=true"
minio:
endpoint: "localhost:9000"
access_key: "minioadmin"
secret_key: "minioadmin"
bucket: "uploads"
`)
dir := t.TempDir()
path := filepath.Join(dir, "config.yaml")
if err := os.WriteFile(path, content, 0644); err != nil {
t.Fatalf("write temp config: %v", err)
}
cfg, err := Load(path)
if err != nil {
t.Fatalf("Load() error = %v", err)
}
if cfg.Minio.ChunkSize != 16<<20 {
t.Errorf("Minio.ChunkSize = %d, want %d", cfg.Minio.ChunkSize, 16<<20)
}
if cfg.Minio.MaxFileSize != 50<<30 {
t.Errorf("Minio.MaxFileSize = %d, want %d", cfg.Minio.MaxFileSize, 50<<30)
}
if cfg.Minio.MinChunkSize != 5<<20 {
t.Errorf("Minio.MinChunkSize = %d, want %d", cfg.Minio.MinChunkSize, 5<<20)
}
if cfg.Minio.SessionTTL != 48 {
t.Errorf("Minio.SessionTTL = %d, want %d", cfg.Minio.SessionTTL, 48)
}
}

View File

@@ -0,0 +1,174 @@
package handler
import (
"encoding/json"
"errors"
"strconv"
"gcy_hpc_server/internal/model"
"gcy_hpc_server/internal/server"
"gcy_hpc_server/internal/service"
"github.com/gin-gonic/gin"
"go.uber.org/zap"
"gorm.io/gorm"
)
type ApplicationHandler struct {
appSvc *service.ApplicationService
logger *zap.Logger
}
func NewApplicationHandler(appSvc *service.ApplicationService, logger *zap.Logger) *ApplicationHandler {
return &ApplicationHandler{appSvc: appSvc, logger: logger}
}
func (h *ApplicationHandler) ListApplications(c *gin.Context) {
page, _ := strconv.Atoi(c.DefaultQuery("page", "1"))
pageSize, _ := strconv.Atoi(c.DefaultQuery("page_size", "20"))
if page < 1 {
page = 1
}
if pageSize < 1 {
pageSize = 20
}
apps, total, err := h.appSvc.ListApplications(c.Request.Context(), page, pageSize)
if err != nil {
h.logger.Error("failed to list applications", zap.Error(err))
server.InternalError(c, err.Error())
return
}
server.OK(c, gin.H{
"applications": apps,
"total": total,
"page": page,
"page_size": pageSize,
})
}
func (h *ApplicationHandler) CreateApplication(c *gin.Context) {
var req model.CreateApplicationRequest
if err := c.ShouldBindJSON(&req); err != nil {
h.logger.Warn("invalid request body for create application", zap.Error(err))
server.BadRequest(c, "invalid request body")
return
}
if req.Name == "" || req.ScriptTemplate == "" {
h.logger.Warn("missing required fields for create application")
server.BadRequest(c, "name and script_template are required")
return
}
if len(req.Parameters) == 0 {
req.Parameters = json.RawMessage(`[]`)
}
id, err := h.appSvc.CreateApplication(c.Request.Context(), &req)
if err != nil {
h.logger.Error("failed to create application", zap.Error(err))
server.InternalError(c, err.Error())
return
}
h.logger.Info("application created", zap.Int64("id", id))
server.Created(c, gin.H{"id": id})
}
func (h *ApplicationHandler) GetApplication(c *gin.Context) {
id, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
h.logger.Warn("invalid application id", zap.String("id", c.Param("id")))
server.BadRequest(c, "invalid id")
return
}
app, err := h.appSvc.GetApplication(c.Request.Context(), id)
if err != nil {
h.logger.Error("failed to get application", zap.Int64("id", id), zap.Error(err))
server.InternalError(c, err.Error())
return
}
if app == nil {
h.logger.Warn("application not found", zap.Int64("id", id))
server.NotFound(c, "application not found")
return
}
server.OK(c, app)
}
func (h *ApplicationHandler) UpdateApplication(c *gin.Context) {
id, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
h.logger.Warn("invalid application id for update", zap.String("id", c.Param("id")))
server.BadRequest(c, "invalid id")
return
}
var req model.UpdateApplicationRequest
if err := c.ShouldBindJSON(&req); err != nil {
h.logger.Warn("invalid request body for update application", zap.Int64("id", id), zap.Error(err))
server.BadRequest(c, "invalid request body")
return
}
if err := h.appSvc.UpdateApplication(c.Request.Context(), id, &req); err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
h.logger.Warn("application not found for update", zap.Int64("id", id))
server.NotFound(c, "application not found")
return
}
h.logger.Error("failed to update application", zap.Int64("id", id), zap.Error(err))
server.InternalError(c, "failed to update application")
return
}
h.logger.Info("application updated", zap.Int64("id", id))
server.OK(c, gin.H{"message": "application updated"})
}
func (h *ApplicationHandler) DeleteApplication(c *gin.Context) {
id, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
h.logger.Warn("invalid application id for delete", zap.String("id", c.Param("id")))
server.BadRequest(c, "invalid id")
return
}
if err := h.appSvc.DeleteApplication(c.Request.Context(), id); err != nil {
h.logger.Error("failed to delete application", zap.Int64("id", id), zap.Error(err))
server.InternalError(c, err.Error())
return
}
h.logger.Info("application deleted", zap.Int64("id", id))
server.OK(c, gin.H{"message": "application deleted"})
}
// [已禁用] 前端已全部迁移到 POST /tasks 接口,此端点不再使用。
/* func (h *ApplicationHandler) SubmitApplication(c *gin.Context) {
id, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
h.logger.Warn("invalid application id for submit", zap.String("id", c.Param("id")))
server.BadRequest(c, "invalid id")
return
}
var req model.ApplicationSubmitRequest
if err := c.ShouldBindJSON(&req); err != nil {
h.logger.Warn("invalid request body for submit application", zap.Int64("id", id), zap.Error(err))
server.BadRequest(c, "invalid request body")
return
}
resp, err := h.appSvc.SubmitFromApplication(c.Request.Context(), id, req.Values)
if err != nil {
errStr := err.Error()
if strings.Contains(errStr, "not found") {
h.logger.Warn("application not found for submit", zap.Int64("id", id))
server.NotFound(c, errStr)
return
}
if strings.Contains(errStr, "validation") {
h.logger.Warn("application submit validation failed", zap.Int64("id", id), zap.Error(err))
server.BadRequest(c, errStr)
return
}
h.logger.Error("failed to submit application", zap.Int64("id", id), zap.Error(err))
server.InternalError(c, errStr)
return
}
h.logger.Info("application submitted", zap.Int64("id", id), zap.Int32("job_id", resp.JobID))
server.Created(c, resp)
} */

View File

@@ -0,0 +1,642 @@
package handler
import (
"bytes"
"encoding/json"
"fmt"
"net/http"
"net/http/httptest"
"strings"
"testing"
"gcy_hpc_server/internal/model"
"gcy_hpc_server/internal/service"
"gcy_hpc_server/internal/slurm"
"gcy_hpc_server/internal/store"
"github.com/gin-gonic/gin"
"go.uber.org/zap"
"go.uber.org/zap/zapcore"
"go.uber.org/zap/zaptest/observer"
"gorm.io/driver/sqlite"
"gorm.io/gorm"
gormlogger "gorm.io/gorm/logger"
)
func itoa(id int64) string {
return fmt.Sprintf("%d", id)
}
func setupApplicationHandler() (*ApplicationHandler, *gorm.DB) {
db, _ := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{Logger: gormlogger.Default.LogMode(gormlogger.Silent)})
db.AutoMigrate(&model.Application{})
appStore := store.NewApplicationStore(db)
appSvc := service.NewApplicationService(appStore, nil, "", zap.NewNop())
h := NewApplicationHandler(appSvc, zap.NewNop())
return h, db
}
func setupApplicationRouter(h *ApplicationHandler) *gin.Engine {
gin.SetMode(gin.TestMode)
r := gin.New()
v1 := r.Group("/api/v1")
apps := v1.Group("/applications")
apps.GET("", h.ListApplications)
apps.POST("", h.CreateApplication)
apps.GET("/:id", h.GetApplication)
apps.PUT("/:id", h.UpdateApplication)
apps.DELETE("/:id", h.DeleteApplication)
// apps.POST("/:id/submit", h.SubmitApplication) // [已禁用] 已被 POST /tasks 取代
return r
}
func setupApplicationHandlerWithSlurm(slurmHandler http.HandlerFunc) (*ApplicationHandler, func()) {
srv := httptest.NewServer(slurmHandler)
client, _ := slurm.NewClient(srv.URL, srv.Client())
db, _ := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{Logger: gormlogger.Default.LogMode(gormlogger.Silent)})
db.AutoMigrate(&model.Application{})
jobSvc := service.NewJobService(client, zap.NewNop())
appStore := store.NewApplicationStore(db)
appSvc := service.NewApplicationService(appStore, jobSvc, "", zap.NewNop())
h := NewApplicationHandler(appSvc, zap.NewNop())
return h, srv.Close
}
func setupApplicationHandlerWithObserver() (*ApplicationHandler, *gorm.DB, *observer.ObservedLogs) {
core, recorded := observer.New(zapcore.DebugLevel)
l := zap.New(core)
db, _ := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{Logger: gormlogger.Default.LogMode(gormlogger.Silent)})
db.AutoMigrate(&model.Application{})
appStore := store.NewApplicationStore(db)
appSvc := service.NewApplicationService(appStore, nil, "", l)
return NewApplicationHandler(appSvc, l), db, recorded
}
func createTestApplication(h *ApplicationHandler, r *gin.Engine) int64 {
body, _ := json.Marshal(model.CreateApplicationRequest{
Name: "test-app",
ScriptTemplate: "#!/bin/bash\necho hello",
})
w := httptest.NewRecorder()
req, _ := http.NewRequest(http.MethodPost, "/api/v1/applications", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
r.ServeHTTP(w, req)
var resp map[string]interface{}
json.Unmarshal(w.Body.Bytes(), &resp)
data := resp["data"].(map[string]interface{})
return int64(data["id"].(float64))
}
// ---- CRUD Tests ----
func TestCreateApplication_Success(t *testing.T) {
h, _ := setupApplicationHandler()
r := setupApplicationRouter(h)
body, _ := json.Marshal(model.CreateApplicationRequest{
Name: "my-app",
ScriptTemplate: "#!/bin/bash\necho hello",
})
w := httptest.NewRecorder()
req, _ := http.NewRequest(http.MethodPost, "/api/v1/applications", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
r.ServeHTTP(w, req)
if w.Code != http.StatusCreated {
t.Fatalf("expected 201, got %d: %s", w.Code, w.Body.String())
}
var resp map[string]interface{}
json.Unmarshal(w.Body.Bytes(), &resp)
if !resp["success"].(bool) {
t.Fatal("expected success=true")
}
data := resp["data"].(map[string]interface{})
if _, ok := data["id"]; !ok {
t.Fatal("expected id in response data")
}
}
func TestCreateApplication_MissingName(t *testing.T) {
h, _ := setupApplicationHandler()
r := setupApplicationRouter(h)
body, _ := json.Marshal(model.CreateApplicationRequest{
ScriptTemplate: "#!/bin/bash\necho hello",
})
w := httptest.NewRecorder()
req, _ := http.NewRequest(http.MethodPost, "/api/v1/applications", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
r.ServeHTTP(w, req)
if w.Code != http.StatusBadRequest {
t.Fatalf("expected 400, got %d", w.Code)
}
}
func TestCreateApplication_MissingScriptTemplate(t *testing.T) {
h, _ := setupApplicationHandler()
r := setupApplicationRouter(h)
body, _ := json.Marshal(model.CreateApplicationRequest{
Name: "my-app",
})
w := httptest.NewRecorder()
req, _ := http.NewRequest(http.MethodPost, "/api/v1/applications", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
r.ServeHTTP(w, req)
if w.Code != http.StatusBadRequest {
t.Fatalf("expected 400, got %d", w.Code)
}
}
func TestCreateApplication_EmptyParameters(t *testing.T) {
h, db := setupApplicationHandler()
r := setupApplicationRouter(h)
body := `{"name":"empty-params-app","script_template":"#!/bin/bash\necho hello"}`
w := httptest.NewRecorder()
req, _ := http.NewRequest(http.MethodPost, "/api/v1/applications", strings.NewReader(body))
req.Header.Set("Content-Type", "application/json")
r.ServeHTTP(w, req)
if w.Code != http.StatusCreated {
t.Fatalf("expected 201, got %d: %s", w.Code, w.Body.String())
}
var app model.Application
db.First(&app)
if string(app.Parameters) != "[]" {
t.Fatalf("expected parameters to default to [], got %s", string(app.Parameters))
}
}
func TestListApplications_Success(t *testing.T) {
h, _ := setupApplicationHandler()
r := setupApplicationRouter(h)
createTestApplication(h, r)
w := httptest.NewRecorder()
req, _ := http.NewRequest(http.MethodGet, "/api/v1/applications", nil)
r.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("expected 200, got %d", w.Code)
}
var resp map[string]interface{}
json.Unmarshal(w.Body.Bytes(), &resp)
if !resp["success"].(bool) {
t.Fatal("expected success=true")
}
data := resp["data"].(map[string]interface{})
if data["total"].(float64) < 1 {
t.Fatal("expected at least 1 application")
}
}
func TestListApplications_Pagination(t *testing.T) {
h, _ := setupApplicationHandler()
r := setupApplicationRouter(h)
for i := 0; i < 5; i++ {
body, _ := json.Marshal(model.CreateApplicationRequest{
Name: fmt.Sprintf("app-%d", i),
ScriptTemplate: "#!/bin/bash\necho " + fmt.Sprintf("app-%d", i),
})
w := httptest.NewRecorder()
req, _ := http.NewRequest(http.MethodPost, "/api/v1/applications", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
r.ServeHTTP(w, req)
}
w := httptest.NewRecorder()
req, _ := http.NewRequest(http.MethodGet, "/api/v1/applications?page=1&page_size=2", nil)
r.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("expected 200, got %d", w.Code)
}
var resp map[string]interface{}
json.Unmarshal(w.Body.Bytes(), &resp)
data := resp["data"].(map[string]interface{})
if data["total"].(float64) != 5 {
t.Fatalf("expected total=5, got %v", data["total"])
}
if data["page"].(float64) != 1 {
t.Fatalf("expected page=1, got %v", data["page"])
}
if data["page_size"].(float64) != 2 {
t.Fatalf("expected page_size=2, got %v", data["page_size"])
}
}
func TestGetApplication_Success(t *testing.T) {
h, _ := setupApplicationHandler()
r := setupApplicationRouter(h)
id := createTestApplication(h, r)
w := httptest.NewRecorder()
req, _ := http.NewRequest(http.MethodGet, "/api/v1/applications/"+itoa(id), nil)
r.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("expected 200, got %d: %s", w.Code, w.Body.String())
}
var resp map[string]interface{}
json.Unmarshal(w.Body.Bytes(), &resp)
data := resp["data"].(map[string]interface{})
if data["name"] != "test-app" {
t.Fatalf("expected name=test-app, got %v", data["name"])
}
}
func TestGetApplication_NotFound(t *testing.T) {
h, _ := setupApplicationHandler()
r := setupApplicationRouter(h)
w := httptest.NewRecorder()
req, _ := http.NewRequest(http.MethodGet, "/api/v1/applications/999", nil)
r.ServeHTTP(w, req)
if w.Code != http.StatusNotFound {
t.Fatalf("expected 404, got %d", w.Code)
}
}
func TestGetApplication_InvalidID(t *testing.T) {
h, _ := setupApplicationHandler()
r := setupApplicationRouter(h)
w := httptest.NewRecorder()
req, _ := http.NewRequest(http.MethodGet, "/api/v1/applications/abc", nil)
r.ServeHTTP(w, req)
if w.Code != http.StatusBadRequest {
t.Fatalf("expected 400, got %d", w.Code)
}
}
func TestUpdateApplication_Success(t *testing.T) {
h, _ := setupApplicationHandler()
r := setupApplicationRouter(h)
id := createTestApplication(h, r)
newName := "updated-app"
body, _ := json.Marshal(model.UpdateApplicationRequest{
Name: &newName,
})
w := httptest.NewRecorder()
req, _ := http.NewRequest(http.MethodPut, "/api/v1/applications/"+itoa(id), bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
r.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("expected 200, got %d: %s", w.Code, w.Body.String())
}
}
func TestUpdateApplication_NotFound(t *testing.T) {
h, _ := setupApplicationHandler()
r := setupApplicationRouter(h)
newName := "updated-app"
body, _ := json.Marshal(model.UpdateApplicationRequest{
Name: &newName,
})
w := httptest.NewRecorder()
req, _ := http.NewRequest(http.MethodPut, "/api/v1/applications/999", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
r.ServeHTTP(w, req)
if w.Code != http.StatusNotFound {
t.Fatalf("expected 404, got %d", w.Code)
}
}
func TestDeleteApplication_Success(t *testing.T) {
h, _ := setupApplicationHandler()
r := setupApplicationRouter(h)
id := createTestApplication(h, r)
w := httptest.NewRecorder()
req, _ := http.NewRequest(http.MethodDelete, "/api/v1/applications/"+itoa(id), nil)
r.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("expected 200, got %d: %s", w.Code, w.Body.String())
}
w2 := httptest.NewRecorder()
req2, _ := http.NewRequest(http.MethodGet, "/api/v1/applications/"+itoa(id), nil)
r.ServeHTTP(w2, req2)
if w2.Code != http.StatusNotFound {
t.Fatalf("expected 404 after delete, got %d", w2.Code)
}
}
// ---- Submit Tests ----
// [已禁用] 测试的是旧的直接提交路径,该路径已被注释掉
/*
func TestSubmitApplication_Success(t *testing.T) {
slurmHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusCreated)
json.NewEncoder(w).Encode(map[string]interface{}{
"job_id": 12345,
})
})
h, cleanup := setupApplicationHandlerWithSlurm(slurmHandler)
defer cleanup()
r := setupApplicationRouter(h)
params := `[{"name":"COUNT","type":"integer","required":true}]`
body := `{"name":"submit-app","script_template":"#!/bin/bash\necho $COUNT","parameters":` + params + `}`
w := httptest.NewRecorder()
req, _ := http.NewRequest(http.MethodPost, "/api/v1/applications", strings.NewReader(body))
req.Header.Set("Content-Type", "application/json")
r.ServeHTTP(w, req)
var createResp map[string]interface{}
json.Unmarshal(w.Body.Bytes(), &createResp)
data := createResp["data"].(map[string]interface{})
id := int64(data["id"].(float64))
submitBody := `{"values":{"COUNT":"5"}}`
w2 := httptest.NewRecorder()
req2, _ := http.NewRequest(http.MethodPost, "/api/v1/applications/"+itoa(id)+"/submit", strings.NewReader(submitBody))
req2.Header.Set("Content-Type", "application/json")
r.ServeHTTP(w2, req2)
if w2.Code != http.StatusCreated {
t.Fatalf("expected 201, got %d: %s", w2.Code, w2.Body.String())
}
}
*/
// [已禁用] 测试的是旧的直接提交路径,该路径已被注释掉
/*
func TestSubmitApplication_AppNotFound(t *testing.T) {
slurmHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(map[string]interface{}{})
})
h, cleanup := setupApplicationHandlerWithSlurm(slurmHandler)
defer cleanup()
r := setupApplicationRouter(h)
submitBody := `{"values":{"COUNT":"5"}}`
w := httptest.NewRecorder()
req, _ := http.NewRequest(http.MethodPost, "/api/v1/applications/999/submit", strings.NewReader(submitBody))
req.Header.Set("Content-Type", "application/json")
r.ServeHTTP(w, req)
if w.Code != http.StatusNotFound {
t.Fatalf("expected 404, got %d: %s", w.Code, w.Body.String())
}
}
*/
// [已禁用] 测试的是旧的直接提交路径,该路径已被注释掉
/*
func TestSubmitApplication_ValidationFail(t *testing.T) {
slurmHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(map[string]interface{}{})
})
h, cleanup := setupApplicationHandlerWithSlurm(slurmHandler)
defer cleanup()
r := setupApplicationRouter(h)
params := `[{"name":"COUNT","type":"integer","required":true}]`
body := `{"name":"val-app","script_template":"#!/bin/bash\necho $COUNT","parameters":` + params + `}`
w := httptest.NewRecorder()
req, _ := http.NewRequest(http.MethodPost, "/api/v1/applications", strings.NewReader(body))
req.Header.Set("Content-Type", "application/json")
r.ServeHTTP(w, req)
var createResp map[string]interface{}
json.Unmarshal(w.Body.Bytes(), &createResp)
data := createResp["data"].(map[string]interface{})
id := int64(data["id"].(float64))
submitBody := `{"values":{"COUNT":"not-a-number"}}`
w2 := httptest.NewRecorder()
req2, _ := http.NewRequest(http.MethodPost, "/api/v1/applications/"+itoa(id)+"/submit", strings.NewReader(submitBody))
req2.Header.Set("Content-Type", "application/json")
r.ServeHTTP(w2, req2)
if w2.Code != http.StatusBadRequest {
t.Fatalf("expected 400, got %d: %s", w2.Code, w2.Body.String())
}
}
*/
// ---- Logging Tests ----
func TestApplicationLogging_CreateSuccess_LogsInfoWithID(t *testing.T) {
h, _, recorded := setupApplicationHandlerWithObserver()
r := setupApplicationRouter(h)
body, _ := json.Marshal(model.CreateApplicationRequest{
Name: "log-app",
ScriptTemplate: "#!/bin/bash\necho hello",
})
w := httptest.NewRecorder()
req, _ := http.NewRequest(http.MethodPost, "/api/v1/applications", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
r.ServeHTTP(w, req)
found := false
for _, entry := range recorded.All() {
if entry.Message == "application created" {
found = true
break
}
}
if !found {
t.Fatal("expected 'application created' log message")
}
}
func TestApplicationLogging_GetNotFound_LogsWarnWithID(t *testing.T) {
h, _, recorded := setupApplicationHandlerWithObserver()
r := setupApplicationRouter(h)
w := httptest.NewRecorder()
req, _ := http.NewRequest(http.MethodGet, "/api/v1/applications/999", nil)
r.ServeHTTP(w, req)
found := false
for _, entry := range recorded.All() {
if entry.Message == "application not found" {
found = true
break
}
}
if !found {
t.Fatal("expected 'application not found' log message")
}
}
func TestApplicationLogging_UpdateSuccess_LogsInfoWithID(t *testing.T) {
h, _, recorded := setupApplicationHandlerWithObserver()
r := setupApplicationRouter(h)
id := createTestApplication(h, r)
recorded.TakeAll()
newName := "updated-log-app"
body, _ := json.Marshal(model.UpdateApplicationRequest{
Name: &newName,
})
w := httptest.NewRecorder()
req, _ := http.NewRequest(http.MethodPut, "/api/v1/applications/"+itoa(id), bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
r.ServeHTTP(w, req)
found := false
for _, entry := range recorded.All() {
if entry.Message == "application updated" {
found = true
break
}
}
if !found {
t.Fatal("expected 'application updated' log message")
}
}
func TestApplicationLogging_DeleteSuccess_LogsInfoWithID(t *testing.T) {
h, _, recorded := setupApplicationHandlerWithObserver()
r := setupApplicationRouter(h)
id := createTestApplication(h, r)
recorded.TakeAll()
w := httptest.NewRecorder()
req, _ := http.NewRequest(http.MethodDelete, "/api/v1/applications/"+itoa(id), nil)
r.ServeHTTP(w, req)
found := false
for _, entry := range recorded.All() {
if entry.Message == "application deleted" {
found = true
break
}
}
if !found {
t.Fatal("expected 'application deleted' log message")
}
}
// [已禁用] 测试的是旧的直接提交路径,该路径已被注释掉
/*
func TestApplicationLogging_SubmitSuccess_LogsInfoWithID(t *testing.T) {
core, recorded := observer.New(zapcore.DebugLevel)
l := zap.New(core)
slurmSrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusCreated)
json.NewEncoder(w).Encode(map[string]interface{}{
"job_id": 42,
})
}))
defer slurmSrv.Close()
client, _ := slurm.NewClient(slurmSrv.URL, slurmSrv.Client())
db, _ := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{Logger: gormlogger.Default.LogMode(gormlogger.Silent)})
db.AutoMigrate(&model.Application{})
jobSvc := service.NewJobService(client, l)
appStore := store.NewApplicationStore(db)
appSvc := service.NewApplicationService(appStore, jobSvc, "", l)
h := NewApplicationHandler(appSvc, l)
r := setupApplicationRouter(h)
params := `[{"name":"X","type":"string","required":false}]`
body := `{"name":"sub-log-app","script_template":"#!/bin/bash\necho $X","parameters":` + params + `}`
w := httptest.NewRecorder()
req, _ := http.NewRequest(http.MethodPost, "/api/v1/applications", strings.NewReader(body))
req.Header.Set("Content-Type", "application/json")
r.ServeHTTP(w, req)
var createResp map[string]interface{}
json.Unmarshal(w.Body.Bytes(), &createResp)
data := createResp["data"].(map[string]interface{})
id := int64(data["id"].(float64))
recorded.TakeAll()
submitBody := `{"values":{"X":"val"}}`
w2 := httptest.NewRecorder()
req2, _ := http.NewRequest(http.MethodPost, "/api/v1/applications/"+itoa(id)+"/submit", strings.NewReader(submitBody))
req2.Header.Set("Content-Type", "application/json")
r.ServeHTTP(w2, req2)
found := false
for _, entry := range recorded.All() {
if entry.Message == "application submitted" {
found = true
break
}
}
if !found {
t.Fatal("expected 'application submitted' log message")
}
}
*/
func TestApplicationLogging_CreateBadRequest_LogsWarn(t *testing.T) {
h, _, recorded := setupApplicationHandlerWithObserver()
r := setupApplicationRouter(h)
w := httptest.NewRecorder()
req, _ := http.NewRequest(http.MethodPost, "/api/v1/applications", strings.NewReader(`{"name":""}`))
req.Header.Set("Content-Type", "application/json")
r.ServeHTTP(w, req)
found := false
for _, entry := range recorded.All() {
if entry.Message == "invalid request body for create application" {
found = true
break
}
}
if !found {
t.Fatal("expected 'invalid request body for create application' log message")
}
}
func TestApplicationLogging_LogsDoNotContainApplicationContent(t *testing.T) {
h, _, recorded := setupApplicationHandlerWithObserver()
r := setupApplicationRouter(h)
body, _ := json.Marshal(model.CreateApplicationRequest{
Name: "secret-app",
ScriptTemplate: "#!/bin/bash\necho secret_password_here",
})
w := httptest.NewRecorder()
req, _ := http.NewRequest(http.MethodPost, "/api/v1/applications", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
r.ServeHTTP(w, req)
for _, entry := range recorded.All() {
msg := entry.Message
if strings.Contains(msg, "secret_password_here") {
t.Fatalf("log message contains application content: %s", msg)
}
for _, field := range entry.Context {
if strings.Contains(field.String, "secret_password_here") {
t.Fatalf("log field contains application content: %s", field.String)
}
}
}
}

View File

@@ -0,0 +1,94 @@
package handler
import (
"gcy_hpc_server/internal/server"
"gcy_hpc_server/internal/service"
"github.com/gin-gonic/gin"
"go.uber.org/zap"
)
// ClusterHandler handles HTTP requests for cluster operations (nodes, partitions, diag).
type ClusterHandler struct {
clusterSvc *service.ClusterService
logger *zap.Logger
}
// NewClusterHandler creates a new ClusterHandler with the given ClusterService.
func NewClusterHandler(clusterSvc *service.ClusterService, logger *zap.Logger) *ClusterHandler {
return &ClusterHandler{clusterSvc: clusterSvc, logger: logger}
}
// GetNodes handles GET /api/v1/nodes.
func (h *ClusterHandler) GetNodes(c *gin.Context) {
nodes, err := h.clusterSvc.GetNodes(c.Request.Context())
if err != nil {
h.logger.Error("handler error", zap.String("method", "GetNodes"), zap.Error(err))
server.InternalError(c, err.Error())
return
}
server.OK(c, nodes)
}
// GetNode handles GET /api/v1/nodes/:name.
func (h *ClusterHandler) GetNode(c *gin.Context) {
name := c.Param("name")
resp, err := h.clusterSvc.GetNode(c.Request.Context(), name)
if err != nil {
h.logger.Error("handler error", zap.String("method", "GetNode"), zap.Error(err))
server.InternalError(c, err.Error())
return
}
if resp == nil {
h.logger.Warn("not found", zap.String("method", "GetNode"), zap.String("name", name))
server.NotFound(c, "node not found")
return
}
server.OK(c, resp)
}
// GetPartitions handles GET /api/v1/partitions.
func (h *ClusterHandler) GetPartitions(c *gin.Context) {
partitions, err := h.clusterSvc.GetPartitions(c.Request.Context())
if err != nil {
h.logger.Error("handler error", zap.String("method", "GetPartitions"), zap.Error(err))
server.InternalError(c, err.Error())
return
}
server.OK(c, partitions)
}
// GetPartition handles GET /api/v1/partitions/:name.
func (h *ClusterHandler) GetPartition(c *gin.Context) {
name := c.Param("name")
resp, err := h.clusterSvc.GetPartition(c.Request.Context(), name)
if err != nil {
h.logger.Error("handler error", zap.String("method", "GetPartition"), zap.Error(err))
server.InternalError(c, err.Error())
return
}
if resp == nil {
h.logger.Warn("not found", zap.String("method", "GetPartition"), zap.String("name", name))
server.NotFound(c, "partition not found")
return
}
server.OK(c, resp)
}
// GetDiag handles GET /api/v1/diag.
func (h *ClusterHandler) GetDiag(c *gin.Context) {
resp, err := h.clusterSvc.GetDiag(c.Request.Context())
if err != nil {
h.logger.Error("handler error", zap.String("method", "GetDiag"), zap.Error(err))
server.InternalError(c, err.Error())
return
}
server.OK(c, resp)
}

View File

@@ -0,0 +1,634 @@
package handler
import (
"encoding/json"
"fmt"
"net/http"
"net/http/httptest"
"testing"
"gcy_hpc_server/internal/service"
"gcy_hpc_server/internal/slurm"
"github.com/gin-gonic/gin"
"go.uber.org/zap"
"go.uber.org/zap/zapcore"
"go.uber.org/zap/zaptest/observer"
)
func setupClusterHandler(slurmHandler http.HandlerFunc) (*httptest.Server, *ClusterHandler) {
srv := httptest.NewServer(slurmHandler)
client, _ := slurm.NewClient(srv.URL, srv.Client())
clusterSvc := service.NewClusterService(client, zap.NewNop())
return srv, NewClusterHandler(clusterSvc, zap.NewNop())
}
func setupClusterHandlerWithObserver(slurmHandler http.HandlerFunc) (*httptest.Server, *ClusterHandler, *observer.ObservedLogs) {
core, recorded := observer.New(zapcore.DebugLevel)
l := zap.New(core)
srv := httptest.NewServer(slurmHandler)
client, _ := slurm.NewClient(srv.URL, srv.Client())
clusterSvc := service.NewClusterService(client, l)
return srv, NewClusterHandler(clusterSvc, l), recorded
}
func setupClusterRouter(h *ClusterHandler) *gin.Engine {
gin.SetMode(gin.TestMode)
r := gin.New()
v1 := r.Group("/api/v1")
v1.GET("/nodes", h.GetNodes)
v1.GET("/nodes/:name", h.GetNode)
v1.GET("/partitions", h.GetPartitions)
v1.GET("/partitions/:name", h.GetPartition)
v1.GET("/diag", h.GetDiag)
return r
}
func TestGetNodes_Success(t *testing.T) {
srv, h := setupClusterHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(map[string]interface{}{
"nodes": []map[string]interface{}{
{"name": "node1", "state": []string{"IDLE"}, "cpus": 64, "real_memory": 128000},
},
"last_update": map[string]interface{}{},
})
}))
defer srv.Close()
router := setupClusterRouter(h)
w := httptest.NewRecorder()
req, _ := http.NewRequest("GET", "/api/v1/nodes", nil)
router.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("expected 200, got %d", w.Code)
}
var resp map[string]interface{}
if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
t.Fatalf("failed to parse response: %v", err)
}
if resp["success"] != true {
t.Fatal("expected success=true")
}
}
func TestGetNode_Success(t *testing.T) {
srv, h := setupClusterHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(map[string]interface{}{
"nodes": []map[string]interface{}{
{"name": "node1", "state": []string{"IDLE"}, "cpus": 64, "real_memory": 128000},
},
"last_update": map[string]interface{}{},
})
}))
defer srv.Close()
router := setupClusterRouter(h)
w := httptest.NewRecorder()
req, _ := http.NewRequest("GET", "/api/v1/nodes/node1", nil)
router.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("expected 200, got %d", w.Code)
}
var resp map[string]interface{}
if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
t.Fatalf("failed to parse response: %v", err)
}
if resp["success"] != true {
t.Fatal("expected success=true")
}
}
func TestGetNode_NotFound(t *testing.T) {
srv, h := setupClusterHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(map[string]interface{}{
"nodes": []map[string]interface{}{},
"last_update": map[string]interface{}{},
})
}))
defer srv.Close()
router := setupClusterRouter(h)
w := httptest.NewRecorder()
req, _ := http.NewRequest("GET", "/api/v1/nodes/nonexistent", nil)
router.ServeHTTP(w, req)
if w.Code != http.StatusNotFound {
t.Fatalf("expected 404, got %d", w.Code)
}
}
func TestGetPartitions_Success(t *testing.T) {
srv, h := setupClusterHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(map[string]interface{}{
"partitions": []map[string]interface{}{
{
"name": "normal",
"partition": map[string]interface{}{
"state": []string{"UP"},
},
"nodes": map[string]interface{}{
"configured": "node[1-10]",
"total": int32(10),
},
"cpus": map[string]interface{}{
"total": int32(640),
},
"maximums": map[string]interface{}{
"time": map[string]interface{}{"number": int64(60)},
},
},
},
"last_update": map[string]interface{}{},
})
}))
defer srv.Close()
router := setupClusterRouter(h)
w := httptest.NewRecorder()
req, _ := http.NewRequest("GET", "/api/v1/partitions", nil)
router.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("expected 200, got %d", w.Code)
}
var resp map[string]interface{}
if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
t.Fatalf("failed to parse response: %v", err)
}
if resp["success"] != true {
t.Fatal("expected success=true")
}
}
func TestGetPartition_Success(t *testing.T) {
srv, h := setupClusterHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(map[string]interface{}{
"partitions": []map[string]interface{}{
{
"name": "normal",
"partition": map[string]interface{}{
"state": []string{"UP"},
},
"nodes": map[string]interface{}{
"configured": "node[1-10]",
"total": int32(10),
},
"cpus": map[string]interface{}{
"total": int32(640),
},
"maximums": map[string]interface{}{
"time": map[string]interface{}{"number": int64(60)},
},
},
},
"last_update": map[string]interface{}{},
})
}))
defer srv.Close()
router := setupClusterRouter(h)
w := httptest.NewRecorder()
req, _ := http.NewRequest("GET", "/api/v1/partitions/normal", nil)
router.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("expected 200, got %d", w.Code)
}
var resp map[string]interface{}
if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
t.Fatalf("failed to parse response: %v", err)
}
if resp["success"] != true {
t.Fatal("expected success=true")
}
}
func TestGetPartition_NotFound(t *testing.T) {
srv, h := setupClusterHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(map[string]interface{}{
"partitions": []map[string]interface{}{},
"last_update": map[string]interface{}{},
})
}))
defer srv.Close()
router := setupClusterRouter(h)
w := httptest.NewRecorder()
req, _ := http.NewRequest("GET", "/api/v1/partitions/nonexistent", nil)
router.ServeHTTP(w, req)
if w.Code != http.StatusNotFound {
t.Fatalf("expected 404, got %d", w.Code)
}
}
func TestGetDiag_Success(t *testing.T) {
srv, h := setupClusterHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(map[string]interface{}{
"statistics": map[string]interface{}{
"server_thread_count": 3,
"agent_queue_size": 0,
"jobs_submitted": 100,
"jobs_started": 90,
"jobs_completed": 85,
"schedule_cycle_last": 10,
"schedule_cycle_total": 500,
},
})
}))
defer srv.Close()
router := setupClusterRouter(h)
w := httptest.NewRecorder()
req, _ := http.NewRequest("GET", "/api/v1/diag", nil)
router.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("expected 200, got %d", w.Code)
}
var resp map[string]interface{}
if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
t.Fatalf("failed to parse response: %v", err)
}
if resp["success"] != true {
t.Fatal("expected success=true")
}
}
// --- Logging tests ---
func TestClusterHandler_GetNodes_InternalError_LogsError(t *testing.T) {
srv, h, recorded := setupClusterHandlerWithObserver(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusInternalServerError)
fmt.Fprint(w, `{"errors":[{"error":"internal"}]}`)
}))
defer srv.Close()
router := setupClusterRouter(h)
w := httptest.NewRecorder()
req, _ := http.NewRequest("GET", "/api/v1/nodes", nil)
router.ServeHTTP(w, req)
if w.Code != http.StatusInternalServerError {
t.Fatalf("expected 500, got %d", w.Code)
}
handlerLogs := recorded.FilterMessage("handler error")
if handlerLogs.Len() != 1 {
t.Fatalf("expected 1 handler error log, got %d", handlerLogs.Len())
}
entry := handlerLogs.All()[0]
if entry.Level != zapcore.ErrorLevel {
t.Fatalf("expected Error level, got %v", entry.Level)
}
assertField(t, entry.Context, "method", "GetNodes")
}
func TestClusterHandler_GetNodes_Success_NoLogs(t *testing.T) {
srv, h, recorded := setupClusterHandlerWithObserver(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(map[string]interface{}{
"nodes": []map[string]interface{}{
{"name": "node1"},
},
"last_update": map[string]interface{}{},
})
}))
defer srv.Close()
router := setupClusterRouter(h)
w := httptest.NewRecorder()
req, _ := http.NewRequest("GET", "/api/v1/nodes", nil)
router.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("expected 200, got %d", w.Code)
}
if recorded.Len() != 2 {
t.Fatalf("expected 2 log entries on success, got %d", recorded.Len())
}
}
func TestClusterHandler_GetNode_InternalError_LogsError(t *testing.T) {
srv, h, recorded := setupClusterHandlerWithObserver(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusInternalServerError)
fmt.Fprint(w, `{"errors":[{"error":"internal"}]}`)
}))
defer srv.Close()
router := setupClusterRouter(h)
w := httptest.NewRecorder()
req, _ := http.NewRequest("GET", "/api/v1/nodes/node1", nil)
router.ServeHTTP(w, req)
if w.Code != http.StatusInternalServerError {
t.Fatalf("expected 500, got %d", w.Code)
}
handlerLogs := recorded.FilterMessage("handler error")
if handlerLogs.Len() != 1 {
t.Fatalf("expected 1 handler error log, got %d", handlerLogs.Len())
}
entry := handlerLogs.All()[0]
if entry.Level != zapcore.ErrorLevel {
t.Fatalf("expected Error level, got %v", entry.Level)
}
assertField(t, entry.Context, "method", "GetNode")
}
func TestClusterHandler_GetNode_NotFound_LogsWarn(t *testing.T) {
srv, h, recorded := setupClusterHandlerWithObserver(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(map[string]interface{}{
"nodes": []map[string]interface{}{},
"last_update": map[string]interface{}{},
})
}))
defer srv.Close()
router := setupClusterRouter(h)
w := httptest.NewRecorder()
req, _ := http.NewRequest("GET", "/api/v1/nodes/nonexistent", nil)
router.ServeHTTP(w, req)
if w.Code != http.StatusNotFound {
t.Fatalf("expected 404, got %d", w.Code)
}
if recorded.Len() != 3 {
t.Fatalf("expected 3 log entries, got %d", recorded.Len())
}
entry := recorded.All()[2]
if entry.Level != zapcore.WarnLevel {
t.Fatalf("expected Warn level, got %v", entry.Level)
}
if entry.Message != "not found" {
t.Fatalf("expected message 'not found', got %q", entry.Message)
}
assertField(t, entry.Context, "method", "GetNode")
assertField(t, entry.Context, "name", "nonexistent")
}
func TestClusterHandler_GetNode_Success_NoLogs(t *testing.T) {
srv, h, recorded := setupClusterHandlerWithObserver(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(map[string]interface{}{
"nodes": []map[string]interface{}{
{"name": "node1", "state": []string{"IDLE"}, "cpus": 64, "real_memory": 128000},
},
"last_update": map[string]interface{}{},
})
}))
defer srv.Close()
router := setupClusterRouter(h)
w := httptest.NewRecorder()
req, _ := http.NewRequest("GET", "/api/v1/nodes/node1", nil)
router.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("expected 200, got %d", w.Code)
}
if recorded.Len() != 2 {
t.Fatalf("expected 2 log entries on success, got %d", recorded.Len())
}
}
func TestClusterHandler_GetPartitions_InternalError_LogsError(t *testing.T) {
srv, h, recorded := setupClusterHandlerWithObserver(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusInternalServerError)
fmt.Fprint(w, `{"errors":[{"error":"internal"}]}`)
}))
defer srv.Close()
router := setupClusterRouter(h)
w := httptest.NewRecorder()
req, _ := http.NewRequest("GET", "/api/v1/partitions", nil)
router.ServeHTTP(w, req)
if w.Code != http.StatusInternalServerError {
t.Fatalf("expected 500, got %d", w.Code)
}
handlerLogs := recorded.FilterMessage("handler error")
if handlerLogs.Len() != 1 {
t.Fatalf("expected 1 handler error log, got %d", handlerLogs.Len())
}
entry := handlerLogs.All()[0]
if entry.Level != zapcore.ErrorLevel {
t.Fatalf("expected Error level, got %v", entry.Level)
}
assertField(t, entry.Context, "method", "GetPartitions")
}
func TestClusterHandler_GetPartitions_Success_NoLogs(t *testing.T) {
srv, h, recorded := setupClusterHandlerWithObserver(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(map[string]interface{}{
"partitions": []map[string]interface{}{
{"name": "normal"},
},
"last_update": map[string]interface{}{},
})
}))
defer srv.Close()
router := setupClusterRouter(h)
w := httptest.NewRecorder()
req, _ := http.NewRequest("GET", "/api/v1/partitions", nil)
router.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("expected 200, got %d", w.Code)
}
if recorded.Len() != 2 {
t.Fatalf("expected 2 log entries on success, got %d", recorded.Len())
}
}
func TestClusterHandler_GetPartition_InternalError_LogsError(t *testing.T) {
srv, h, recorded := setupClusterHandlerWithObserver(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusInternalServerError)
fmt.Fprint(w, `{"errors":[{"error":"internal"}]}`)
}))
defer srv.Close()
router := setupClusterRouter(h)
w := httptest.NewRecorder()
req, _ := http.NewRequest("GET", "/api/v1/partitions/normal", nil)
router.ServeHTTP(w, req)
if w.Code != http.StatusInternalServerError {
t.Fatalf("expected 500, got %d", w.Code)
}
handlerLogs := recorded.FilterMessage("handler error")
if handlerLogs.Len() != 1 {
t.Fatalf("expected 1 handler error log, got %d", handlerLogs.Len())
}
entry := handlerLogs.All()[0]
if entry.Level != zapcore.ErrorLevel {
t.Fatalf("expected Error level, got %v", entry.Level)
}
assertField(t, entry.Context, "method", "GetPartition")
}
func TestClusterHandler_GetPartition_NotFound_LogsWarn(t *testing.T) {
srv, h, recorded := setupClusterHandlerWithObserver(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(map[string]interface{}{
"partitions": []map[string]interface{}{},
"last_update": map[string]interface{}{},
})
}))
defer srv.Close()
router := setupClusterRouter(h)
w := httptest.NewRecorder()
req, _ := http.NewRequest("GET", "/api/v1/partitions/nonexistent", nil)
router.ServeHTTP(w, req)
if w.Code != http.StatusNotFound {
t.Fatalf("expected 404, got %d", w.Code)
}
if recorded.Len() != 3 {
t.Fatalf("expected 3 log entries, got %d", recorded.Len())
}
entry := recorded.All()[2]
if entry.Level != zapcore.WarnLevel {
t.Fatalf("expected Warn level, got %v", entry.Level)
}
if entry.Message != "not found" {
t.Fatalf("expected message 'not found', got %q", entry.Message)
}
assertField(t, entry.Context, "method", "GetPartition")
assertField(t, entry.Context, "name", "nonexistent")
}
func TestClusterHandler_GetPartition_Success_NoLogs(t *testing.T) {
srv, h, recorded := setupClusterHandlerWithObserver(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(map[string]interface{}{
"partitions": []map[string]interface{}{
{
"name": "normal",
"partition": map[string]interface{}{
"state": []string{"UP"},
},
"nodes": map[string]interface{}{
"configured": "node[1-10]",
"total": int32(10),
},
"cpus": map[string]interface{}{
"total": int32(640),
},
"maximums": map[string]interface{}{
"time": map[string]interface{}{"number": int64(60)},
},
},
},
"last_update": map[string]interface{}{},
})
}))
defer srv.Close()
router := setupClusterRouter(h)
w := httptest.NewRecorder()
req, _ := http.NewRequest("GET", "/api/v1/partitions/normal", nil)
router.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("expected 200, got %d", w.Code)
}
if recorded.Len() != 2 {
t.Fatalf("expected 2 log entries on success, got %d", recorded.Len())
}
}
func TestClusterHandler_GetDiag_InternalError_LogsError(t *testing.T) {
srv, h, recorded := setupClusterHandlerWithObserver(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusInternalServerError)
fmt.Fprint(w, `{"errors":[{"error":"internal"}]}`)
}))
defer srv.Close()
router := setupClusterRouter(h)
w := httptest.NewRecorder()
req, _ := http.NewRequest("GET", "/api/v1/diag", nil)
router.ServeHTTP(w, req)
if w.Code != http.StatusInternalServerError {
t.Fatalf("expected 500, got %d", w.Code)
}
handlerLogs := recorded.FilterMessage("handler error")
if handlerLogs.Len() != 1 {
t.Fatalf("expected 1 handler error log, got %d", handlerLogs.Len())
}
entry := handlerLogs.All()[0]
if entry.Level != zapcore.ErrorLevel {
t.Fatalf("expected Error level, got %v", entry.Level)
}
assertField(t, entry.Context, "method", "GetDiag")
}
func TestClusterHandler_GetDiag_Success_NoLogs(t *testing.T) {
srv, h, recorded := setupClusterHandlerWithObserver(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(map[string]interface{}{
"statistics": map[string]interface{}{
"server_thread_count": 3,
"agent_queue_size": 0,
"jobs_submitted": 100,
"jobs_started": 90,
"jobs_completed": 85,
"schedule_cycle_last": 10,
"schedule_cycle_total": 500,
},
})
}))
defer srv.Close()
router := setupClusterRouter(h)
w := httptest.NewRecorder()
req, _ := http.NewRequest("GET", "/api/v1/diag", nil)
router.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("expected 200, got %d", w.Code)
}
if recorded.Len() != 2 {
t.Fatalf("expected 2 log entries on success, got %d", recorded.Len())
}
}
// assertField checks that a zap Field slice contains a string field with the given key and value.
func assertField(t *testing.T, fields []zapcore.Field, key, value string) {
t.Helper()
for _, f := range fields {
if f.Key == key && f.String == value {
return
}
}
t.Fatalf("expected field %q=%q in context, got %v", key, value, fields)
}

View File

@@ -0,0 +1,138 @@
package handler
import (
"context"
"io"
"strconv"
"gcy_hpc_server/internal/model"
"gcy_hpc_server/internal/server"
"gcy_hpc_server/internal/service"
"github.com/gin-gonic/gin"
"go.uber.org/zap"
)
type fileServiceProvider interface {
ListFiles(ctx context.Context, folderID *int64, page, pageSize int, search string) ([]model.FileResponse, int64, error)
GetFileMetadata(ctx context.Context, fileID int64) (*model.File, *model.FileBlob, error)
DownloadFile(ctx context.Context, fileID int64, rangeHeader string) (io.ReadCloser, *model.File, *model.FileBlob, int64, int64, error)
DeleteFile(ctx context.Context, fileID int64) error
}
type FileHandler struct {
svc fileServiceProvider
logger *zap.Logger
}
func NewFileHandler(svc *service.FileService, logger *zap.Logger) *FileHandler {
return &FileHandler{svc: svc, logger: logger}
}
func (h *FileHandler) ListFiles(c *gin.Context) {
page, _ := strconv.Atoi(c.DefaultQuery("page", "1"))
pageSize, _ := strconv.Atoi(c.DefaultQuery("page_size", "10"))
if page < 1 {
page = 1
}
if pageSize < 1 {
pageSize = 10
}
var folderID *int64
if v := c.Query("folder_id"); v != "" {
id, err := strconv.ParseInt(v, 10, 64)
if err != nil {
h.logger.Warn("invalid folder_id", zap.String("folder_id", v))
server.BadRequest(c, "invalid folder_id")
return
}
folderID = &id
}
search := c.Query("search")
files, total, err := h.svc.ListFiles(c.Request.Context(), folderID, page, pageSize, search)
if err != nil {
h.logger.Error("failed to list files", zap.Error(err))
server.InternalError(c, err.Error())
return
}
server.OK(c, model.ListFilesResponse{
Files: files,
Total: total,
Page: page,
PageSize: pageSize,
})
}
func (h *FileHandler) GetFile(c *gin.Context) {
id, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
h.logger.Warn("invalid file id", zap.String("id", c.Param("id")))
server.BadRequest(c, "invalid id")
return
}
file, blob, err := h.svc.GetFileMetadata(c.Request.Context(), id)
if err != nil {
h.logger.Error("failed to get file", zap.Int64("id", id), zap.Error(err))
server.InternalError(c, err.Error())
return
}
resp := model.FileResponse{
ID: file.ID,
Name: file.Name,
FolderID: file.FolderID,
Size: blob.FileSize,
MimeType: blob.MimeType,
SHA256: file.BlobSHA256,
CreatedAt: file.CreatedAt,
UpdatedAt: file.UpdatedAt,
}
server.OK(c, resp)
}
func (h *FileHandler) DownloadFile(c *gin.Context) {
id, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
h.logger.Warn("invalid file id for download", zap.String("id", c.Param("id")))
server.BadRequest(c, "invalid id")
return
}
rangeHeader := c.GetHeader("Range")
reader, file, blob, start, end, err := h.svc.DownloadFile(c.Request.Context(), id, rangeHeader)
if err != nil {
h.logger.Error("failed to download file", zap.Int64("id", id), zap.Error(err))
server.InternalError(c, err.Error())
return
}
if rangeHeader != "" {
server.StreamRange(c, reader, start, end, blob.FileSize, blob.MimeType)
} else {
server.StreamFile(c, reader, file.Name, blob.FileSize, blob.MimeType)
}
}
func (h *FileHandler) DeleteFile(c *gin.Context) {
id, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
h.logger.Warn("invalid file id for delete", zap.String("id", c.Param("id")))
server.BadRequest(c, "invalid id")
return
}
if err := h.svc.DeleteFile(c.Request.Context(), id); err != nil {
h.logger.Error("failed to delete file", zap.Int64("id", id), zap.Error(err))
server.InternalError(c, err.Error())
return
}
h.logger.Info("file deleted", zap.Int64("id", id))
server.OK(c, gin.H{"message": "file deleted"})
}

View File

@@ -0,0 +1,369 @@
package handler
import (
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"
"gcy_hpc_server/internal/model"
"github.com/gin-gonic/gin"
"go.uber.org/zap"
)
type mockFileService struct {
listFilesFn func(ctx context.Context, folderID *int64, page, pageSize int, search string) ([]model.FileResponse, int64, error)
getFileMetadataFn func(ctx context.Context, fileID int64) (*model.File, *model.FileBlob, error)
downloadFileFn func(ctx context.Context, fileID int64, rangeHeader string) (io.ReadCloser, *model.File, *model.FileBlob, int64, int64, error)
deleteFileFn func(ctx context.Context, fileID int64) error
}
func (m *mockFileService) ListFiles(ctx context.Context, folderID *int64, page, pageSize int, search string) ([]model.FileResponse, int64, error) {
return m.listFilesFn(ctx, folderID, page, pageSize, search)
}
func (m *mockFileService) GetFileMetadata(ctx context.Context, fileID int64) (*model.File, *model.FileBlob, error) {
return m.getFileMetadataFn(ctx, fileID)
}
func (m *mockFileService) DownloadFile(ctx context.Context, fileID int64, rangeHeader string) (io.ReadCloser, *model.File, *model.FileBlob, int64, int64, error) {
return m.downloadFileFn(ctx, fileID, rangeHeader)
}
func (m *mockFileService) DeleteFile(ctx context.Context, fileID int64) error {
return m.deleteFileFn(ctx, fileID)
}
type fileHandlerSetup struct {
handler *FileHandler
mock *mockFileService
router *gin.Engine
}
func newFileHandlerSetup() *fileHandlerSetup {
gin.SetMode(gin.TestMode)
mock := &mockFileService{}
h := &FileHandler{
svc: mock,
logger: zap.NewNop(),
}
r := gin.New()
v1 := r.Group("/api/v1")
files := v1.Group("/files")
files.GET("", h.ListFiles)
files.GET("/:id", h.GetFile)
files.GET("/:id/download", h.DownloadFile)
files.DELETE("/:id", h.DeleteFile)
return &fileHandlerSetup{handler: h, mock: mock, router: r}
}
// ---- ListFiles Tests ----
func TestListFiles_Empty(t *testing.T) {
s := newFileHandlerSetup()
s.mock.listFilesFn = func(ctx context.Context, folderID *int64, page, pageSize int, search string) ([]model.FileResponse, int64, error) {
return []model.FileResponse{}, 0, nil
}
w := httptest.NewRecorder()
req, _ := http.NewRequest(http.MethodGet, "/api/v1/files", nil)
s.router.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("expected 200, got %d: %s", w.Code, w.Body.String())
}
var resp map[string]interface{}
json.Unmarshal(w.Body.Bytes(), &resp)
if !resp["success"].(bool) {
t.Fatal("expected success=true")
}
data := resp["data"].(map[string]interface{})
files := data["files"].([]interface{})
if len(files) != 0 {
t.Fatalf("expected empty files list, got %d", len(files))
}
if data["total"].(float64) != 0 {
t.Fatalf("expected total=0, got %v", data["total"])
}
}
func TestListFiles_WithFiles(t *testing.T) {
s := newFileHandlerSetup()
now := time.Now()
s.mock.listFilesFn = func(ctx context.Context, folderID *int64, page, pageSize int, search string) ([]model.FileResponse, int64, error) {
return []model.FileResponse{
{ID: 1, Name: "a.txt", Size: 100, MimeType: "text/plain", SHA256: "abc123", CreatedAt: now, UpdatedAt: now},
{ID: 2, Name: "b.pdf", Size: 200, MimeType: "application/pdf", SHA256: "def456", CreatedAt: now, UpdatedAt: now},
}, 2, nil
}
w := httptest.NewRecorder()
req, _ := http.NewRequest(http.MethodGet, "/api/v1/files?page=1&page_size=10", nil)
s.router.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("expected 200, got %d: %s", w.Code, w.Body.String())
}
var resp map[string]interface{}
json.Unmarshal(w.Body.Bytes(), &resp)
data := resp["data"].(map[string]interface{})
files := data["files"].([]interface{})
if len(files) != 2 {
t.Fatalf("expected 2 files, got %d", len(files))
}
if data["total"].(float64) != 2 {
t.Fatalf("expected total=2, got %v", data["total"])
}
if data["page"].(float64) != 1 {
t.Fatalf("expected page=1, got %v", data["page"])
}
if data["page_size"].(float64) != 10 {
t.Fatalf("expected page_size=10, got %v", data["page_size"])
}
}
func TestListFiles_WithFolderID(t *testing.T) {
s := newFileHandlerSetup()
var capturedFolderID *int64
s.mock.listFilesFn = func(ctx context.Context, folderID *int64, page, pageSize int, search string) ([]model.FileResponse, int64, error) {
capturedFolderID = folderID
return []model.FileResponse{}, 0, nil
}
w := httptest.NewRecorder()
req, _ := http.NewRequest(http.MethodGet, "/api/v1/files?folder_id=5", nil)
s.router.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("expected 200, got %d: %s", w.Code, w.Body.String())
}
if capturedFolderID == nil || *capturedFolderID != 5 {
t.Fatalf("expected folder_id=5, got %v", capturedFolderID)
}
}
func TestListFiles_ServiceError(t *testing.T) {
s := newFileHandlerSetup()
s.mock.listFilesFn = func(ctx context.Context, folderID *int64, page, pageSize int, search string) ([]model.FileResponse, int64, error) {
return nil, 0, fmt.Errorf("db error")
}
w := httptest.NewRecorder()
req, _ := http.NewRequest(http.MethodGet, "/api/v1/files", nil)
s.router.ServeHTTP(w, req)
if w.Code != http.StatusInternalServerError {
t.Fatalf("expected 500, got %d", w.Code)
}
}
// ---- GetFile Tests ----
func TestGetFile_Found(t *testing.T) {
s := newFileHandlerSetup()
now := time.Now()
s.mock.getFileMetadataFn = func(ctx context.Context, fileID int64) (*model.File, *model.FileBlob, error) {
return &model.File{
ID: 1, Name: "test.txt", BlobSHA256: "abc123", CreatedAt: now, UpdatedAt: now,
}, &model.FileBlob{
ID: 1, SHA256: "abc123", FileSize: 1024, MimeType: "text/plain", CreatedAt: now,
}, nil
}
w := httptest.NewRecorder()
req, _ := http.NewRequest(http.MethodGet, "/api/v1/files/1", nil)
s.router.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("expected 200, got %d: %s", w.Code, w.Body.String())
}
var resp map[string]interface{}
json.Unmarshal(w.Body.Bytes(), &resp)
if !resp["success"].(bool) {
t.Fatal("expected success=true")
}
}
func TestGetFile_NotFound(t *testing.T) {
s := newFileHandlerSetup()
s.mock.getFileMetadataFn = func(ctx context.Context, fileID int64) (*model.File, *model.FileBlob, error) {
return nil, nil, fmt.Errorf("file not found: 999")
}
w := httptest.NewRecorder()
req, _ := http.NewRequest(http.MethodGet, "/api/v1/files/999", nil)
s.router.ServeHTTP(w, req)
if w.Code != http.StatusInternalServerError {
t.Fatalf("expected 500, got %d: %s", w.Code, w.Body.String())
}
}
func TestGetFile_InvalidID(t *testing.T) {
s := newFileHandlerSetup()
w := httptest.NewRecorder()
req, _ := http.NewRequest(http.MethodGet, "/api/v1/files/abc", nil)
s.router.ServeHTTP(w, req)
if w.Code != http.StatusBadRequest {
t.Fatalf("expected 400, got %d", w.Code)
}
}
// ---- DownloadFile Tests ----
func TestDownloadFile_Full(t *testing.T) {
s := newFileHandlerSetup()
content := "hello world file content"
now := time.Now()
s.mock.downloadFileFn = func(ctx context.Context, fileID int64, rangeHeader string) (io.ReadCloser, *model.File, *model.FileBlob, int64, int64, error) {
reader := io.NopCloser(strings.NewReader(content))
return reader,
&model.File{ID: 1, Name: "test.txt", CreatedAt: now},
&model.FileBlob{FileSize: int64(len(content)), MimeType: "text/plain"},
0, int64(len(content)) - 1, nil
}
w := httptest.NewRecorder()
req, _ := http.NewRequest(http.MethodGet, "/api/v1/files/1/download", nil)
s.router.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("expected 200, got %d: %s", w.Code, w.Body.String())
}
// Check streaming headers
if got := w.Header().Get("Content-Disposition"); !strings.Contains(got, "test.txt") {
t.Fatalf("expected Content-Disposition to contain 'test.txt', got %s", got)
}
if got := w.Header().Get("Content-Type"); got != "text/plain" {
t.Fatalf("expected Content-Type=text/plain, got %s", got)
}
if got := w.Header().Get("Accept-Ranges"); got != "bytes" {
t.Fatalf("expected Accept-Ranges=bytes, got %s", got)
}
if w.Body.String() != content {
t.Fatalf("expected body %q, got %q", content, w.Body.String())
}
}
func TestDownloadFile_WithRange(t *testing.T) {
s := newFileHandlerSetup()
content := "hello world"
now := time.Now()
s.mock.downloadFileFn = func(ctx context.Context, fileID int64, rangeHeader string) (io.ReadCloser, *model.File, *model.FileBlob, int64, int64, error) {
reader := io.NopCloser(strings.NewReader(content[0:5]))
return reader,
&model.File{ID: 1, Name: "test.txt", CreatedAt: now},
&model.FileBlob{FileSize: int64(len(content)), MimeType: "text/plain"},
0, 4, nil
}
w := httptest.NewRecorder()
req, _ := http.NewRequest(http.MethodGet, "/api/v1/files/1/download", nil)
req.Header.Set("Range", "bytes=0-4")
s.router.ServeHTTP(w, req)
if w.Code != http.StatusPartialContent {
t.Fatalf("expected 206, got %d: %s", w.Code, w.Body.String())
}
// Check Content-Range header
if got := w.Header().Get("Content-Range"); got != "bytes 0-4/11" {
t.Fatalf("expected Content-Range 'bytes 0-4/11', got %s", got)
}
if got := w.Header().Get("Content-Type"); got != "text/plain" {
t.Fatalf("expected Content-Type=text/plain, got %s", got)
}
if got := w.Header().Get("Content-Length"); got != "5" {
t.Fatalf("expected Content-Length=5, got %s", got)
}
}
func TestDownloadFile_InvalidID(t *testing.T) {
s := newFileHandlerSetup()
w := httptest.NewRecorder()
req, _ := http.NewRequest(http.MethodGet, "/api/v1/files/abc/download", nil)
s.router.ServeHTTP(w, req)
if w.Code != http.StatusBadRequest {
t.Fatalf("expected 400, got %d", w.Code)
}
}
func TestDownloadFile_ServiceError(t *testing.T) {
s := newFileHandlerSetup()
s.mock.downloadFileFn = func(ctx context.Context, fileID int64, rangeHeader string) (io.ReadCloser, *model.File, *model.FileBlob, int64, int64, error) {
return nil, nil, nil, 0, 0, fmt.Errorf("file not found: 999")
}
w := httptest.NewRecorder()
req, _ := http.NewRequest(http.MethodGet, "/api/v1/files/999/download", nil)
s.router.ServeHTTP(w, req)
if w.Code != http.StatusInternalServerError {
t.Fatalf("expected 500, got %d: %s", w.Code, w.Body.String())
}
}
// ---- DeleteFile Tests ----
func TestDeleteFile_Success(t *testing.T) {
s := newFileHandlerSetup()
s.mock.deleteFileFn = func(ctx context.Context, fileID int64) error {
return nil
}
w := httptest.NewRecorder()
req, _ := http.NewRequest(http.MethodDelete, "/api/v1/files/1", nil)
s.router.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("expected 200, got %d: %s", w.Code, w.Body.String())
}
var resp map[string]interface{}
json.Unmarshal(w.Body.Bytes(), &resp)
if !resp["success"].(bool) {
t.Fatal("expected success=true")
}
data := resp["data"].(map[string]interface{})
if data["message"] != "file deleted" {
t.Fatalf("expected message 'file deleted', got %v", data["message"])
}
}
func TestDeleteFile_InvalidID(t *testing.T) {
s := newFileHandlerSetup()
w := httptest.NewRecorder()
req, _ := http.NewRequest(http.MethodDelete, "/api/v1/files/abc", nil)
s.router.ServeHTTP(w, req)
if w.Code != http.StatusBadRequest {
t.Fatalf("expected 400, got %d", w.Code)
}
}
func TestDeleteFile_ServiceError(t *testing.T) {
s := newFileHandlerSetup()
s.mock.deleteFileFn = func(ctx context.Context, fileID int64) error {
return fmt.Errorf("file not found: 999")
}
w := httptest.NewRecorder()
req, _ := http.NewRequest(http.MethodDelete, "/api/v1/files/999", nil)
s.router.ServeHTTP(w, req)
if w.Code != http.StatusInternalServerError {
t.Fatalf("expected 500, got %d: %s", w.Code, w.Body.String())
}
}

View File

@@ -0,0 +1,133 @@
package handler
import (
"context"
"strconv"
"strings"
"gcy_hpc_server/internal/model"
"gcy_hpc_server/internal/server"
"gcy_hpc_server/internal/service"
"github.com/gin-gonic/gin"
"go.uber.org/zap"
)
type folderServiceProvider interface {
CreateFolder(ctx context.Context, name string, parentID *int64) (*model.FolderResponse, error)
GetFolder(ctx context.Context, id int64) (*model.FolderResponse, error)
ListFolders(ctx context.Context, parentID *int64) ([]model.FolderResponse, error)
DeleteFolder(ctx context.Context, id int64) error
}
type FolderHandler struct {
svc folderServiceProvider
logger *zap.Logger
}
func NewFolderHandler(svc *service.FolderService, logger *zap.Logger) *FolderHandler {
return &FolderHandler{svc: svc, logger: logger}
}
func (h *FolderHandler) CreateFolder(c *gin.Context) {
var req model.CreateFolderRequest
if err := c.ShouldBindJSON(&req); err != nil {
h.logger.Warn("invalid request body for create folder", zap.Error(err))
server.BadRequest(c, "invalid request body")
return
}
if req.Name == "" {
h.logger.Warn("missing folder name")
server.BadRequest(c, "name is required")
return
}
resp, err := h.svc.CreateFolder(c.Request.Context(), req.Name, req.ParentID)
if err != nil {
errStr := err.Error()
if strings.Contains(errStr, "invalid folder name") || strings.Contains(errStr, "cannot be") {
h.logger.Warn("invalid folder name", zap.String("name", req.Name), zap.Error(err))
server.BadRequest(c, errStr)
return
}
h.logger.Error("failed to create folder", zap.Error(err))
server.InternalError(c, errStr)
return
}
h.logger.Info("folder created", zap.Int64("id", resp.ID))
server.Created(c, resp)
}
func (h *FolderHandler) GetFolder(c *gin.Context) {
id, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
h.logger.Warn("invalid folder id", zap.String("id", c.Param("id")))
server.BadRequest(c, "invalid id")
return
}
resp, err := h.svc.GetFolder(c.Request.Context(), id)
if err != nil {
if strings.Contains(err.Error(), "not found") {
h.logger.Warn("folder not found", zap.Int64("id", id))
server.NotFound(c, "folder not found")
return
}
h.logger.Error("failed to get folder", zap.Int64("id", id), zap.Error(err))
server.InternalError(c, err.Error())
return
}
server.OK(c, resp)
}
func (h *FolderHandler) ListFolders(c *gin.Context) {
var parentID *int64
if q := c.Query("parent_id"); q != "" {
pid, err := strconv.ParseInt(q, 10, 64)
if err != nil {
h.logger.Warn("invalid parent_id query param", zap.String("parent_id", q))
server.BadRequest(c, "invalid parent_id")
return
}
parentID = &pid
}
folders, err := h.svc.ListFolders(c.Request.Context(), parentID)
if err != nil {
h.logger.Error("failed to list folders", zap.Error(err))
server.InternalError(c, err.Error())
return
}
if folders == nil {
folders = []model.FolderResponse{}
}
server.OK(c, folders)
}
func (h *FolderHandler) DeleteFolder(c *gin.Context) {
id, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
h.logger.Warn("invalid folder id for delete", zap.String("id", c.Param("id")))
server.BadRequest(c, "invalid id")
return
}
if err := h.svc.DeleteFolder(c.Request.Context(), id); err != nil {
errStr := err.Error()
if strings.Contains(errStr, "not empty") {
h.logger.Warn("cannot delete non-empty folder", zap.Int64("id", id))
server.BadRequest(c, "folder is not empty")
return
}
if strings.Contains(errStr, "not found") {
h.logger.Warn("folder not found for delete", zap.Int64("id", id))
server.NotFound(c, "folder not found")
return
}
h.logger.Error("failed to delete folder", zap.Int64("id", id), zap.Error(err))
server.InternalError(c, errStr)
return
}
h.logger.Info("folder deleted", zap.Int64("id", id))
server.OK(c, gin.H{"message": "folder deleted"})
}

View File

@@ -0,0 +1,206 @@
package handler
import (
"bytes"
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"gcy_hpc_server/internal/model"
"gcy_hpc_server/internal/service"
"gcy_hpc_server/internal/store"
"github.com/gin-gonic/gin"
"go.uber.org/zap"
"gorm.io/driver/sqlite"
"gorm.io/gorm"
gormlogger "gorm.io/gorm/logger"
)
func setupFolderHandler() (*FolderHandler, *gorm.DB) {
db, _ := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{Logger: gormlogger.Default.LogMode(gormlogger.Silent)})
db.AutoMigrate(&model.Folder{}, &model.File{})
folderStore := store.NewFolderStore(db)
fileStore := store.NewFileStore(db)
svc := service.NewFolderService(folderStore, fileStore, zap.NewNop())
h := NewFolderHandler(svc, zap.NewNop())
return h, db
}
func setupFolderRouter(h *FolderHandler) *gin.Engine {
gin.SetMode(gin.TestMode)
r := gin.New()
v1 := r.Group("/api/v1")
folders := v1.Group("/files/folders")
folders.POST("", h.CreateFolder)
folders.GET("/:id", h.GetFolder)
folders.GET("", h.ListFolders)
folders.DELETE("/:id", h.DeleteFolder)
return r
}
func createTestFolder(h *FolderHandler, r *gin.Engine) int64 {
body, _ := json.Marshal(model.CreateFolderRequest{Name: "test-folder"})
w := httptest.NewRecorder()
req, _ := http.NewRequest(http.MethodPost, "/api/v1/files/folders", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
r.ServeHTTP(w, req)
var resp map[string]interface{}
json.Unmarshal(w.Body.Bytes(), &resp)
data := resp["data"].(map[string]interface{})
return int64(data["id"].(float64))
}
func TestCreateFolder_Success(t *testing.T) {
h, _ := setupFolderHandler()
r := setupFolderRouter(h)
body, _ := json.Marshal(model.CreateFolderRequest{Name: "my-folder"})
w := httptest.NewRecorder()
req, _ := http.NewRequest(http.MethodPost, "/api/v1/files/folders", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
r.ServeHTTP(w, req)
if w.Code != http.StatusCreated {
t.Fatalf("expected 201, got %d: %s", w.Code, w.Body.String())
}
var resp map[string]interface{}
json.Unmarshal(w.Body.Bytes(), &resp)
if !resp["success"].(bool) {
t.Fatal("expected success=true")
}
data := resp["data"].(map[string]interface{})
if data["name"] != "my-folder" {
t.Fatalf("expected name=my-folder, got %v", data["name"])
}
if data["path"] != "/my-folder/" {
t.Fatalf("expected path=/my-folder/, got %v", data["path"])
}
}
func TestCreateFolder_PathTraversal(t *testing.T) {
h, _ := setupFolderHandler()
r := setupFolderRouter(h)
body, _ := json.Marshal(model.CreateFolderRequest{Name: ".."})
w := httptest.NewRecorder()
req, _ := http.NewRequest(http.MethodPost, "/api/v1/files/folders", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
r.ServeHTTP(w, req)
if w.Code != http.StatusBadRequest {
t.Fatalf("expected 400, got %d: %s", w.Code, w.Body.String())
}
}
func TestCreateFolder_MissingName(t *testing.T) {
h, _ := setupFolderHandler()
r := setupFolderRouter(h)
body, _ := json.Marshal(map[string]interface{}{})
w := httptest.NewRecorder()
req, _ := http.NewRequest(http.MethodPost, "/api/v1/files/folders", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
r.ServeHTTP(w, req)
if w.Code != http.StatusBadRequest {
t.Fatalf("expected 400, got %d: %s", w.Code, w.Body.String())
}
}
func TestGetFolder_Success(t *testing.T) {
h, _ := setupFolderHandler()
r := setupFolderRouter(h)
id := createTestFolder(h, r)
w := httptest.NewRecorder()
req, _ := http.NewRequest(http.MethodGet, "/api/v1/files/folders/"+itoa(id), nil)
r.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("expected 200, got %d: %s", w.Code, w.Body.String())
}
var resp map[string]interface{}
json.Unmarshal(w.Body.Bytes(), &resp)
data := resp["data"].(map[string]interface{})
if data["name"] != "test-folder" {
t.Fatalf("expected name=test-folder, got %v", data["name"])
}
}
func TestGetFolder_NotFound(t *testing.T) {
h, _ := setupFolderHandler()
r := setupFolderRouter(h)
w := httptest.NewRecorder()
req, _ := http.NewRequest(http.MethodGet, "/api/v1/files/folders/999", nil)
r.ServeHTTP(w, req)
if w.Code != http.StatusNotFound {
t.Fatalf("expected 404, got %d: %s", w.Code, w.Body.String())
}
}
func TestListFolders_Success(t *testing.T) {
h, _ := setupFolderHandler()
r := setupFolderRouter(h)
createTestFolder(h, r)
w := httptest.NewRecorder()
req, _ := http.NewRequest(http.MethodGet, "/api/v1/files/folders", nil)
r.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("expected 200, got %d: %s", w.Code, w.Body.String())
}
var resp map[string]interface{}
json.Unmarshal(w.Body.Bytes(), &resp)
if !resp["success"].(bool) {
t.Fatal("expected success=true")
}
data := resp["data"].([]interface{})
if len(data) < 1 {
t.Fatal("expected at least 1 folder")
}
}
func TestDeleteFolder_Success(t *testing.T) {
h, _ := setupFolderHandler()
r := setupFolderRouter(h)
id := createTestFolder(h, r)
w := httptest.NewRecorder()
req, _ := http.NewRequest(http.MethodDelete, "/api/v1/files/folders/"+itoa(id), nil)
r.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("expected 200, got %d: %s", w.Code, w.Body.String())
}
}
func TestDeleteFolder_NonEmpty(t *testing.T) {
h, db := setupFolderHandler()
r := setupFolderRouter(h)
parentID := createTestFolder(h, r)
child := &model.Folder{
Name: "child-folder",
ParentID: &parentID,
Path: "/test-folder/child-folder/",
}
db.Create(child)
w := httptest.NewRecorder()
req, _ := http.NewRequest(http.MethodDelete, "/api/v1/files/folders/"+itoa(parentID), nil)
r.ServeHTTP(w, req)
if w.Code != http.StatusBadRequest {
t.Fatalf("expected 400, got %d: %s", w.Code, w.Body.String())
}
}

132
internal/handler/job.go Normal file
View File

@@ -0,0 +1,132 @@
package handler
import (
"net/http"
"gcy_hpc_server/internal/model"
"gcy_hpc_server/internal/server"
"gcy_hpc_server/internal/service"
"github.com/gin-gonic/gin"
"go.uber.org/zap"
)
// JobHandler handles HTTP requests for job operations.
type JobHandler struct {
jobSvc *service.JobService
logger *zap.Logger
}
// NewJobHandler creates a new JobHandler with the given JobService.
func NewJobHandler(jobSvc *service.JobService, logger *zap.Logger) *JobHandler {
return &JobHandler{jobSvc: jobSvc, logger: logger}
}
// SubmitJob handles POST /api/v1/jobs/submit.
func (h *JobHandler) SubmitJob(c *gin.Context) {
var req model.SubmitJobRequest
if err := c.ShouldBindJSON(&req); err != nil {
h.logger.Warn("bad request", zap.String("method", "SubmitJob"), zap.String("error", "invalid request body"))
server.BadRequest(c, "invalid request body")
return
}
if req.Script == "" {
h.logger.Warn("bad request", zap.String("method", "SubmitJob"), zap.String("error", "script is required"))
server.BadRequest(c, "script is required")
return
}
resp, err := h.jobSvc.SubmitJob(c.Request.Context(), &req)
if err != nil {
h.logger.Error("handler error", zap.String("method", "SubmitJob"), zap.Int("status", http.StatusBadGateway), zap.Error(err))
server.ErrorWithStatus(c, http.StatusBadGateway, "slurm error: "+err.Error())
return
}
server.Created(c, resp)
}
// GetJobs handles GET /api/v1/jobs with pagination.
func (h *JobHandler) GetJobs(c *gin.Context) {
var query model.JobListQuery
if err := c.ShouldBindQuery(&query); err != nil {
h.logger.Warn("bad request", zap.String("method", "GetJobs"), zap.String("error", "invalid query params"))
server.BadRequest(c, "invalid query params")
return
}
if query.Page < 1 {
query.Page = 1
}
if query.PageSize < 1 {
query.PageSize = 20
}
resp, err := h.jobSvc.GetJobs(c.Request.Context(), &query)
if err != nil {
h.logger.Error("handler error", zap.String("method", "GetJobs"), zap.Int("status", http.StatusInternalServerError), zap.Error(err))
server.InternalError(c, err.Error())
return
}
server.OK(c, resp)
}
// GetJob handles GET /api/v1/jobs/:id.
func (h *JobHandler) GetJob(c *gin.Context) {
jobID := c.Param("id")
resp, err := h.jobSvc.GetJob(c.Request.Context(), jobID)
if err != nil {
h.logger.Error("handler error", zap.String("method", "GetJob"), zap.Int("status", http.StatusInternalServerError), zap.Error(err))
server.InternalError(c, err.Error())
return
}
if resp == nil {
h.logger.Warn("bad request", zap.String("method", "GetJob"), zap.String("error", "job not found"))
server.NotFound(c, "job not found")
return
}
server.OK(c, resp)
}
// CancelJob handles DELETE /api/v1/jobs/:id.
func (h *JobHandler) CancelJob(c *gin.Context) {
jobID := c.Param("id")
err := h.jobSvc.CancelJob(c.Request.Context(), jobID)
if err != nil {
h.logger.Error("handler error", zap.String("method", "CancelJob"), zap.Int("status", http.StatusBadGateway), zap.Error(err))
server.ErrorWithStatus(c, http.StatusBadGateway, err.Error())
return
}
server.OK(c, gin.H{"message": "job cancelled"})
}
// GetJobHistory handles GET /api/v1/jobs/history.
func (h *JobHandler) GetJobHistory(c *gin.Context) {
var query model.JobHistoryQuery
if err := c.ShouldBindQuery(&query); err != nil {
h.logger.Warn("bad request", zap.String("method", "GetJobHistory"), zap.String("error", "invalid query params"))
server.BadRequest(c, "invalid query params")
return
}
if query.Page < 1 {
query.Page = 1
}
if query.PageSize < 1 {
query.PageSize = 20
}
resp, err := h.jobSvc.GetJobHistory(c.Request.Context(), &query)
if err != nil {
h.logger.Error("handler error", zap.String("method", "GetJobHistory"), zap.Int("status", http.StatusInternalServerError), zap.Error(err))
server.InternalError(c, err.Error())
return
}
server.OK(c, resp)
}

View File

@@ -0,0 +1,908 @@
package handler
import (
"bytes"
"encoding/json"
"fmt"
"net/http"
"net/http/httptest"
"testing"
"gcy_hpc_server/internal/service"
"gcy_hpc_server/internal/slurm"
"github.com/gin-gonic/gin"
"go.uber.org/zap"
"go.uber.org/zap/zapcore"
"go.uber.org/zap/zaptest/observer"
)
func setupJobRouter(h *JobHandler) *gin.Engine {
gin.SetMode(gin.TestMode)
r := gin.New()
v1 := r.Group("/api/v1")
jobs := v1.Group("/jobs")
{
jobs.POST("/submit", h.SubmitJob)
jobs.GET("", h.GetJobs)
jobs.GET("/history", h.GetJobHistory)
jobs.GET("/:id", h.GetJob)
jobs.DELETE("/:id", h.CancelJob)
}
return r
}
func setupJobHandler(mux *http.ServeMux) (*httptest.Server, *JobHandler) {
srv := httptest.NewServer(mux)
client, _ := slurm.NewClient(srv.URL, srv.Client())
jobSvc := service.NewJobService(client, zap.NewNop())
return srv, NewJobHandler(jobSvc, zap.NewNop())
}
func setupJobHandlerWithObserver(mux *http.ServeMux) (*httptest.Server, *JobHandler, *observer.ObservedLogs) {
core, recorded := observer.New(zapcore.DebugLevel)
l := zap.New(core)
srv := httptest.NewServer(mux)
client, _ := slurm.NewClient(srv.URL, srv.Client())
jobSvc := service.NewJobService(client, l)
return srv, NewJobHandler(jobSvc, l), recorded
}
func handlerLogs(logs *observer.ObservedLogs) []observer.LoggedEntry {
var handler []observer.LoggedEntry
for _, e := range logs.All() {
for _, f := range e.Context {
if f.Key == "method" {
handler = append(handler, e)
break
}
}
}
return handler
}
func TestSubmitJob_Success(t *testing.T) {
mux := http.NewServeMux()
mux.HandleFunc("/slurm/v0.0.40/job/submit", func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(slurm.OpenapiJobSubmitResponse{
Result: &slurm.JobSubmitResponseMsg{JobID: slurm.Ptr(int32(123))},
})
})
srv, handler := setupJobHandler(mux)
defer srv.Close()
router := setupJobRouter(handler)
body := `{"script":"#!/bin/bash\necho hello"}`
req := httptest.NewRequest(http.MethodPost, "/api/v1/jobs/submit", bytes.NewBufferString(body))
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
if w.Code != http.StatusCreated {
t.Fatalf("expected 201, got %d: %s", w.Code, w.Body.String())
}
var resp map[string]interface{}
json.Unmarshal(w.Body.Bytes(), &resp)
if !resp["success"].(bool) {
t.Fatal("expected success=true")
}
data := resp["data"].(map[string]interface{})
if int(data["job_id"].(float64)) != 123 {
t.Errorf("expected job_id=123, got %v", data["job_id"])
}
}
func TestSubmitJob_MissingScript(t *testing.T) {
mux := http.NewServeMux()
srv, handler := setupJobHandler(mux)
defer srv.Close()
router := setupJobRouter(handler)
body := `{"partition":"normal"}`
req := httptest.NewRequest(http.MethodPost, "/api/v1/jobs/submit", bytes.NewBufferString(body))
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
if w.Code != http.StatusBadRequest {
t.Fatalf("expected 400, got %d: %s", w.Code, w.Body.String())
}
var resp map[string]interface{}
json.Unmarshal(w.Body.Bytes(), &resp)
if resp["success"].(bool) {
t.Fatal("expected success=false")
}
if resp["error"] != "invalid request body" && resp["error"] != "script is required" {
t.Errorf("expected validation error, got %v", resp["error"])
}
}
func TestSubmitJob_EmptyScript(t *testing.T) {
mux := http.NewServeMux()
srv, handler := setupJobHandler(mux)
defer srv.Close()
router := setupJobRouter(handler)
body := `{"script":""}`
req := httptest.NewRequest(http.MethodPost, "/api/v1/jobs/submit", bytes.NewBufferString(body))
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
if w.Code != http.StatusBadRequest {
t.Fatalf("expected 400, got %d: %s", w.Code, w.Body.String())
}
var resp map[string]interface{}
json.Unmarshal(w.Body.Bytes(), &resp)
if resp["success"].(bool) {
t.Fatal("expected success=false")
}
}
func TestSubmitJob_SlurmError(t *testing.T) {
mux := http.NewServeMux()
mux.HandleFunc("/slurm/v0.0.40/job/submit", func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusInternalServerError)
fmt.Fprint(w, `{"errors":[{"error":"internal error"}]}`)
})
srv, handler := setupJobHandler(mux)
defer srv.Close()
router := setupJobRouter(handler)
body := `{"script":"#!/bin/bash\necho hello"}`
req := httptest.NewRequest(http.MethodPost, "/api/v1/jobs/submit", bytes.NewBufferString(body))
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
if w.Code != http.StatusBadGateway {
t.Fatalf("expected 502, got %d: %s", w.Code, w.Body.String())
}
}
func TestGetJobs_Success(t *testing.T) {
mux := http.NewServeMux()
mux.HandleFunc("/slurm/v0.0.40/jobs", func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(slurm.OpenapiJobInfoResp{
Jobs: []slurm.JobInfo{
{JobID: slurm.Ptr(int32(1)), Name: slurm.Ptr("job1")},
{JobID: slurm.Ptr(int32(2)), Name: slurm.Ptr("job2")},
},
})
})
srv, handler := setupJobHandler(mux)
defer srv.Close()
router := setupJobRouter(handler)
req := httptest.NewRequest(http.MethodGet, "/api/v1/jobs?page=1&page_size=10", nil)
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("expected 200, got %d: %s", w.Code, w.Body.String())
}
var resp map[string]interface{}
json.Unmarshal(w.Body.Bytes(), &resp)
if !resp["success"].(bool) {
t.Fatal("expected success=true")
}
data := resp["data"].(map[string]interface{})
jobs := data["jobs"].([]interface{})
if len(jobs) != 2 {
t.Fatalf("expected 2 jobs, got %d", len(jobs))
}
if int(data["total"].(float64)) != 2 {
t.Errorf("expected total=2, got %v", data["total"])
}
if int(data["page"].(float64)) != 1 {
t.Errorf("expected page=1, got %v", data["page"])
}
if int(data["page_size"].(float64)) != 10 {
t.Errorf("expected page_size=10, got %v", data["page_size"])
}
}
func TestGetJobs_Pagination(t *testing.T) {
mux := http.NewServeMux()
mux.HandleFunc("/slurm/v0.0.40/jobs", func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(slurm.OpenapiJobInfoResp{
Jobs: []slurm.JobInfo{
{JobID: slurm.Ptr(int32(1)), Name: slurm.Ptr("job1")},
{JobID: slurm.Ptr(int32(2)), Name: slurm.Ptr("job2")},
{JobID: slurm.Ptr(int32(3)), Name: slurm.Ptr("job3")},
},
})
})
srv, handler := setupJobHandler(mux)
defer srv.Close()
router := setupJobRouter(handler)
req := httptest.NewRequest(http.MethodGet, "/api/v1/jobs?page=2&page_size=1", nil)
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("expected 200, got %d: %s", w.Code, w.Body.String())
}
var resp map[string]interface{}
json.Unmarshal(w.Body.Bytes(), &resp)
data := resp["data"].(map[string]interface{})
jobs := data["jobs"].([]interface{})
if len(jobs) != 1 {
t.Fatalf("expected 1 job on page 2, got %d", len(jobs))
}
if int(data["total"].(float64)) != 3 {
t.Errorf("expected total=3, got %v", data["total"])
}
jobData := jobs[0].(map[string]interface{})
if int(jobData["job_id"].(float64)) != 2 {
t.Errorf("expected job_id=2 on page 2, got %v", jobData["job_id"])
}
}
func TestGetJobs_DefaultPagination(t *testing.T) {
mux := http.NewServeMux()
mux.HandleFunc("/slurm/v0.0.40/jobs", func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(slurm.OpenapiJobInfoResp{Jobs: []slurm.JobInfo{}})
})
srv, handler := setupJobHandler(mux)
defer srv.Close()
router := setupJobRouter(handler)
req := httptest.NewRequest(http.MethodGet, "/api/v1/jobs", nil)
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("expected 200, got %d: %s", w.Code, w.Body.String())
}
var resp map[string]interface{}
json.Unmarshal(w.Body.Bytes(), &resp)
data := resp["data"].(map[string]interface{})
if int(data["page"].(float64)) != 1 {
t.Errorf("expected default page=1, got %v", data["page"])
}
if int(data["page_size"].(float64)) != 20 {
t.Errorf("expected default page_size=20, got %v", data["page_size"])
}
}
func TestGetJob_Success(t *testing.T) {
mux := http.NewServeMux()
mux.HandleFunc("/slurm/v0.0.40/job/42", func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(slurm.OpenapiJobInfoResp{
Jobs: []slurm.JobInfo{
{JobID: slurm.Ptr(int32(42)), Name: slurm.Ptr("test-job")},
},
})
})
srv, handler := setupJobHandler(mux)
defer srv.Close()
router := setupJobRouter(handler)
req := httptest.NewRequest(http.MethodGet, "/api/v1/jobs/42", nil)
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("expected 200, got %d: %s", w.Code, w.Body.String())
}
var resp map[string]interface{}
json.Unmarshal(w.Body.Bytes(), &resp)
data := resp["data"].(map[string]interface{})
if int(data["job_id"].(float64)) != 42 {
t.Errorf("expected job_id=42, got %v", data["job_id"])
}
}
func TestGetJob_NotFound(t *testing.T) {
mux := http.NewServeMux()
mux.HandleFunc("/slurm/v0.0.40/job/999", func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(slurm.OpenapiJobInfoResp{
Jobs: []slurm.JobInfo{},
})
})
srv, handler := setupJobHandler(mux)
defer srv.Close()
router := setupJobRouter(handler)
req := httptest.NewRequest(http.MethodGet, "/api/v1/jobs/999", nil)
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
if w.Code != http.StatusNotFound {
t.Fatalf("expected 404, got %d: %s", w.Code, w.Body.String())
}
var resp map[string]interface{}
json.Unmarshal(w.Body.Bytes(), &resp)
if resp["error"].(string) != "job not found" {
t.Errorf("expected 'job not found' error, got %v", resp["error"])
}
}
func TestCancelJob_Success(t *testing.T) {
mux := http.NewServeMux()
mux.HandleFunc("/slurm/v0.0.40/job/42", func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(slurm.OpenapiResp{})
})
srv, handler := setupJobHandler(mux)
defer srv.Close()
router := setupJobRouter(handler)
req := httptest.NewRequest(http.MethodDelete, "/api/v1/jobs/42", nil)
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("expected 200, got %d: %s", w.Code, w.Body.String())
}
var resp map[string]interface{}
json.Unmarshal(w.Body.Bytes(), &resp)
if !resp["success"].(bool) {
t.Fatal("expected success=true")
}
data := resp["data"].(map[string]interface{})
if data["message"].(string) != "job cancelled" {
t.Errorf("expected 'job cancelled', got %v", data["message"])
}
}
func TestGetJobHistory_Success(t *testing.T) {
mux := http.NewServeMux()
mux.HandleFunc("/slurmdb/v0.0.40/jobs", func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(slurm.OpenapiSlurmdbdJobsResp{
Jobs: []slurm.Job{
{JobID: slurm.Ptr(int32(1)), Name: slurm.Ptr("hist1")},
{JobID: slurm.Ptr(int32(2)), Name: slurm.Ptr("hist2")},
{JobID: slurm.Ptr(int32(3)), Name: slurm.Ptr("hist3")},
},
})
})
srv, handler := setupJobHandler(mux)
defer srv.Close()
router := setupJobRouter(handler)
req := httptest.NewRequest(http.MethodGet, "/api/v1/jobs/history?page=1&page_size=2", nil)
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("expected 200, got %d: %s", w.Code, w.Body.String())
}
var resp map[string]interface{}
json.Unmarshal(w.Body.Bytes(), &resp)
data := resp["data"].(map[string]interface{})
jobs := data["jobs"].([]interface{})
if len(jobs) != 2 {
t.Fatalf("expected 2 jobs on page 1, got %d", len(jobs))
}
if int(data["total"].(float64)) != 3 {
t.Errorf("expected total=3, got %v", data["total"])
}
if int(data["page"].(float64)) != 1 {
t.Errorf("expected page=1, got %v", data["page"])
}
if int(data["page_size"].(float64)) != 2 {
t.Errorf("expected page_size=2, got %v", data["page_size"])
}
}
func TestGetJobHistory_DefaultPagination(t *testing.T) {
mux := http.NewServeMux()
mux.HandleFunc("/slurmdb/v0.0.40/jobs", func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(slurm.OpenapiSlurmdbdJobsResp{
Jobs: []slurm.Job{
{JobID: slurm.Ptr(int32(1)), Name: slurm.Ptr("h1")},
},
})
})
srv, handler := setupJobHandler(mux)
defer srv.Close()
router := setupJobRouter(handler)
req := httptest.NewRequest(http.MethodGet, "/api/v1/jobs/history", nil)
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("expected 200, got %d: %s", w.Code, w.Body.String())
}
var resp map[string]interface{}
json.Unmarshal(w.Body.Bytes(), &resp)
data := resp["data"].(map[string]interface{})
if int(data["page"].(float64)) != 1 {
t.Errorf("expected default page=1, got %v", data["page"])
}
if int(data["page_size"].(float64)) != 20 {
t.Errorf("expected default page_size=20, got %v", data["page_size"])
}
}
func TestSubmitJob_InvalidBody(t *testing.T) {
mux := http.NewServeMux()
srv, handler := setupJobHandler(mux)
defer srv.Close()
router := setupJobRouter(handler)
req := httptest.NewRequest(http.MethodPost, "/api/v1/jobs/submit", bytes.NewBufferString(`not json`))
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
if w.Code != http.StatusBadRequest {
t.Fatalf("expected 400, got %d: %s", w.Code, w.Body.String())
}
}
// --- Logging verification tests ---
func TestSubmitJob_InvalidBody_LogsWarn(t *testing.T) {
mux := http.NewServeMux()
srv, handler, logs := setupJobHandlerWithObserver(mux)
defer srv.Close()
router := setupJobRouter(handler)
req := httptest.NewRequest(http.MethodPost, "/api/v1/jobs/submit", bytes.NewBufferString(`not json`))
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
if w.Code != http.StatusBadRequest {
t.Fatalf("expected 400, got %d", w.Code)
}
hLogs := handlerLogs(logs)
if len(hLogs) != 1 {
t.Fatalf("expected 1 handler log entry, got %d", len(hLogs))
}
entry := hLogs[0]
if entry.Level != zapcore.WarnLevel {
t.Errorf("expected Warn level, got %v", entry.Level)
}
if entry.Context[0].Key != "method" || entry.Context[0].String != "SubmitJob" {
t.Errorf("expected method=SubmitJob, got %v", entry.Context[0])
}
}
func TestSubmitJob_EmptyScript_LogsWarn(t *testing.T) {
mux := http.NewServeMux()
srv, handler, logs := setupJobHandlerWithObserver(mux)
defer srv.Close()
router := setupJobRouter(handler)
req := httptest.NewRequest(http.MethodPost, "/api/v1/jobs/submit", bytes.NewBufferString(`{"script":""}`))
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
if w.Code != http.StatusBadRequest {
t.Fatalf("expected 400, got %d", w.Code)
}
hLogs := handlerLogs(logs)
if len(hLogs) != 1 {
t.Fatalf("expected 1 handler log entry, got %d", len(hLogs))
}
entry := hLogs[0]
if entry.Level != zapcore.WarnLevel {
t.Errorf("expected Warn level, got %v", entry.Level)
}
}
func TestSubmitJob_SlurmError_LogsError(t *testing.T) {
mux := http.NewServeMux()
mux.HandleFunc("/slurm/v0.0.40/job/submit", func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusInternalServerError)
fmt.Fprint(w, `{"errors":[{"error":"internal error"}]}`)
})
srv, handler, logs := setupJobHandlerWithObserver(mux)
defer srv.Close()
router := setupJobRouter(handler)
req := httptest.NewRequest(http.MethodPost, "/api/v1/jobs/submit", bytes.NewBufferString(`{"script":"#!/bin/bash\necho hello"}`))
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
if w.Code != http.StatusBadGateway {
t.Fatalf("expected 502, got %d", w.Code)
}
hLogs := handlerLogs(logs)
if len(hLogs) != 1 {
t.Fatalf("expected 1 handler log entry, got %d", len(hLogs))
}
entry := hLogs[0]
if entry.Level != zapcore.ErrorLevel {
t.Errorf("expected Error level, got %v", entry.Level)
}
foundMethod := false
foundStatus := false
for _, f := range entry.Context {
if f.Key == "method" && f.String == "SubmitJob" {
foundMethod = true
}
if f.Key == "status" && f.Integer == http.StatusBadGateway {
foundStatus = true
}
}
if !foundMethod {
t.Error("expected method=SubmitJob in log fields")
}
if !foundStatus {
t.Error("expected status=502 in log fields")
}
}
func TestSubmitJob_Success_NoHandlerLogs(t *testing.T) {
mux := http.NewServeMux()
mux.HandleFunc("/slurm/v0.0.40/job/submit", func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(slurm.OpenapiJobSubmitResponse{
Result: &slurm.JobSubmitResponseMsg{JobID: slurm.Ptr(int32(123))},
})
})
srv, handler, logs := setupJobHandlerWithObserver(mux)
defer srv.Close()
router := setupJobRouter(handler)
req := httptest.NewRequest(http.MethodPost, "/api/v1/jobs/submit", bytes.NewBufferString(`{"script":"#!/bin/bash\necho hello"}`))
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
if w.Code != http.StatusCreated {
t.Fatalf("expected 201, got %d", w.Code)
}
hLogs := handlerLogs(logs)
if len(hLogs) != 0 {
t.Errorf("expected no handler log entries on success, got %d", len(hLogs))
}
}
func TestGetJobs_Error_LogsError(t *testing.T) {
mux := http.NewServeMux()
mux.HandleFunc("/slurm/v0.0.40/jobs", func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusInternalServerError)
})
srv, handler, logs := setupJobHandlerWithObserver(mux)
defer srv.Close()
router := setupJobRouter(handler)
req := httptest.NewRequest(http.MethodGet, "/api/v1/jobs", nil)
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
if w.Code != http.StatusInternalServerError {
t.Fatalf("expected 500, got %d", w.Code)
}
hLogs := handlerLogs(logs)
if len(hLogs) != 1 {
t.Fatalf("expected 1 handler log entry, got %d", len(hLogs))
}
entry := hLogs[0]
if entry.Level != zapcore.ErrorLevel {
t.Errorf("expected Error level, got %v", entry.Level)
}
foundMethod := false
for _, f := range entry.Context {
if f.Key == "method" && f.String == "GetJobs" {
foundMethod = true
}
}
if !foundMethod {
t.Error("expected method=GetJobs in log fields")
}
}
func TestGetJob_NotFound_LogsWarn(t *testing.T) {
mux := http.NewServeMux()
mux.HandleFunc("/slurm/v0.0.40/job/999", func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(slurm.OpenapiJobInfoResp{Jobs: []slurm.JobInfo{}})
})
srv, handler, logs := setupJobHandlerWithObserver(mux)
defer srv.Close()
router := setupJobRouter(handler)
req := httptest.NewRequest(http.MethodGet, "/api/v1/jobs/999", nil)
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
if w.Code != http.StatusNotFound {
t.Fatalf("expected 404, got %d", w.Code)
}
hLogs := handlerLogs(logs)
if len(hLogs) != 1 {
t.Fatalf("expected 1 handler log entry, got %d", len(hLogs))
}
entry := hLogs[0]
if entry.Level != zapcore.WarnLevel {
t.Errorf("expected Warn level, got %v", entry.Level)
}
foundMethod := false
for _, f := range entry.Context {
if f.Key == "method" && f.String == "GetJob" {
foundMethod = true
}
}
if !foundMethod {
t.Error("expected method=GetJob in log fields")
}
}
func TestGetJob_Error_LogsError(t *testing.T) {
mux := http.NewServeMux()
mux.HandleFunc("/slurm/v0.0.40/job/42", func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusInternalServerError)
})
srv, handler, logs := setupJobHandlerWithObserver(mux)
defer srv.Close()
router := setupJobRouter(handler)
req := httptest.NewRequest(http.MethodGet, "/api/v1/jobs/42", nil)
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
if w.Code != http.StatusInternalServerError {
t.Fatalf("expected 500, got %d", w.Code)
}
hLogs := handlerLogs(logs)
if len(hLogs) != 1 {
t.Fatalf("expected 1 handler log entry, got %d", len(hLogs))
}
if hLogs[0].Level != zapcore.ErrorLevel {
t.Errorf("expected Error level, got %v", hLogs[0].Level)
}
}
func TestCancelJob_SlurmError_LogsError(t *testing.T) {
mux := http.NewServeMux()
mux.HandleFunc("/slurm/v0.0.40/job/42", func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusInternalServerError)
})
srv, handler, logs := setupJobHandlerWithObserver(mux)
defer srv.Close()
router := setupJobRouter(handler)
req := httptest.NewRequest(http.MethodDelete, "/api/v1/jobs/42", nil)
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
if w.Code != http.StatusBadGateway {
t.Fatalf("expected 502, got %d", w.Code)
}
hLogs := handlerLogs(logs)
if len(hLogs) != 1 {
t.Fatalf("expected 1 handler log entry, got %d", len(hLogs))
}
entry := hLogs[0]
if entry.Level != zapcore.ErrorLevel {
t.Errorf("expected Error level, got %v", entry.Level)
}
foundMethod := false
for _, f := range entry.Context {
if f.Key == "method" && f.String == "CancelJob" {
foundMethod = true
}
}
if !foundMethod {
t.Error("expected method=CancelJob in log fields")
}
}
func TestGetJobHistory_InvalidQuery_LogsWarn(t *testing.T) {
mux := http.NewServeMux()
mux.HandleFunc("/slurmdb/v0.0.40/jobs", func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(slurm.OpenapiSlurmdbdJobsResp{})
})
srv, handler, logs := setupJobHandlerWithObserver(mux)
defer srv.Close()
router := setupJobRouter(handler)
req := httptest.NewRequest(http.MethodGet, "/api/v1/jobs/history?page=abc", nil)
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
if w.Code != http.StatusBadRequest {
t.Fatalf("expected 400, got %d: %s", w.Code, w.Body.String())
}
hLogs := handlerLogs(logs)
if len(hLogs) != 1 {
t.Fatalf("expected 1 handler log entry, got %d", len(hLogs))
}
entry := hLogs[0]
if entry.Level != zapcore.WarnLevel {
t.Errorf("expected Warn level, got %v", entry.Level)
}
foundMethod := false
for _, f := range entry.Context {
if f.Key == "method" && f.String == "GetJobHistory" {
foundMethod = true
}
}
if !foundMethod {
t.Error("expected method=GetJobHistory in log fields")
}
}
func TestGetJobHistory_Error_LogsError(t *testing.T) {
mux := http.NewServeMux()
mux.HandleFunc("/slurmdb/v0.0.40/jobs", func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusInternalServerError)
})
srv, handler, logs := setupJobHandlerWithObserver(mux)
defer srv.Close()
router := setupJobRouter(handler)
req := httptest.NewRequest(http.MethodGet, "/api/v1/jobs/history?page=1&page_size=10", nil)
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
if w.Code != http.StatusInternalServerError {
t.Fatalf("expected 500, got %d", w.Code)
}
hLogs := handlerLogs(logs)
if len(hLogs) != 1 {
t.Fatalf("expected 1 handler log entry, got %d", len(hLogs))
}
if hLogs[0].Level != zapcore.ErrorLevel {
t.Errorf("expected Error level, got %v", hLogs[0].Level)
}
}
func TestGetJobHistory_Success_NoHandlerLogs(t *testing.T) {
mux := http.NewServeMux()
mux.HandleFunc("/slurmdb/v0.0.40/jobs", func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(slurm.OpenapiSlurmdbdJobsResp{
Jobs: []slurm.Job{
{JobID: slurm.Ptr(int32(1)), Name: slurm.Ptr("h1")},
},
})
})
srv, handler, logs := setupJobHandlerWithObserver(mux)
defer srv.Close()
router := setupJobRouter(handler)
req := httptest.NewRequest(http.MethodGet, "/api/v1/jobs/history?page=1&page_size=10", nil)
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("expected 200, got %d", w.Code)
}
hLogs := handlerLogs(logs)
if len(hLogs) != 0 {
t.Errorf("expected no handler log entries on success, got %d", len(hLogs))
}
}
func TestGetJobs_Success_NoHandlerLogs(t *testing.T) {
mux := http.NewServeMux()
mux.HandleFunc("/slurm/v0.0.40/jobs", func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(slurm.OpenapiJobInfoResp{Jobs: []slurm.JobInfo{}})
})
srv, handler, logs := setupJobHandlerWithObserver(mux)
defer srv.Close()
router := setupJobRouter(handler)
req := httptest.NewRequest(http.MethodGet, "/api/v1/jobs", nil)
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("expected 200, got %d", w.Code)
}
hLogs := handlerLogs(logs)
if len(hLogs) != 0 {
t.Errorf("expected no handler log entries on success, got %d", len(hLogs))
}
}
func TestCancelJob_Success_NoHandlerLogs(t *testing.T) {
mux := http.NewServeMux()
mux.HandleFunc("/slurm/v0.0.40/job/42", func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(slurm.OpenapiResp{})
})
srv, handler, logs := setupJobHandlerWithObserver(mux)
defer srv.Close()
router := setupJobRouter(handler)
req := httptest.NewRequest(http.MethodDelete, "/api/v1/jobs/42", nil)
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("expected 200, got %d", w.Code)
}
hLogs := handlerLogs(logs)
if len(hLogs) != 0 {
t.Errorf("expected no handler log entries on success, got %d", len(hLogs))
}
}

View File

@@ -0,0 +1,98 @@
package handler
import (
"context"
"strings"
"gcy_hpc_server/internal/model"
"gcy_hpc_server/internal/server"
"github.com/gin-gonic/gin"
"go.uber.org/zap"
)
type taskServiceProvider interface {
SubmitAsync(ctx context.Context, req *model.CreateTaskRequest) (int64, error)
ListTasks(ctx context.Context, query *model.TaskListQuery) ([]model.Task, int64, error)
}
type TaskHandler struct {
svc taskServiceProvider
logger *zap.Logger
}
func NewTaskHandler(svc taskServiceProvider, logger *zap.Logger) *TaskHandler {
return &TaskHandler{svc: svc, logger: logger}
}
func (h *TaskHandler) CreateTask(c *gin.Context) {
var req model.CreateTaskRequest
if err := c.ShouldBindJSON(&req); err != nil {
h.logger.Warn("invalid request body for create task", zap.Error(err))
server.BadRequest(c, err.Error())
return
}
taskID, err := h.svc.SubmitAsync(c.Request.Context(), &req)
if err != nil {
errStr := err.Error()
if strings.Contains(errStr, "not found") {
h.logger.Warn("task submit target not found", zap.Error(err))
server.NotFound(c, errStr)
return
}
if strings.Contains(errStr, "exceeds limit") || strings.Contains(errStr, "validation") {
h.logger.Warn("task submit validation failed", zap.Error(err))
server.BadRequest(c, errStr)
return
}
h.logger.Error("failed to create task", zap.Error(err))
server.InternalError(c, errStr)
return
}
h.logger.Info("task created", zap.Int64("id", taskID))
server.Created(c, gin.H{"id": taskID})
}
func (h *TaskHandler) ListTasks(c *gin.Context) {
var query model.TaskListQuery
_ = c.ShouldBindQuery(&query)
if query.Page < 1 {
query.Page = 1
}
if query.PageSize < 1 || query.PageSize > 100 {
query.PageSize = 10
}
tasks, total, err := h.svc.ListTasks(c.Request.Context(), &query)
if err != nil {
h.logger.Error("failed to list tasks", zap.Error(err))
server.InternalError(c, err.Error())
return
}
responses := make([]model.TaskResponse, 0, len(tasks))
for i := range tasks {
responses = append(responses, model.TaskResponse{
ID: tasks[i].ID,
TaskName: tasks[i].TaskName,
AppID: tasks[i].AppID,
AppName: tasks[i].AppName,
Status: tasks[i].Status,
CurrentStep: tasks[i].CurrentStep,
RetryCount: tasks[i].RetryCount,
SlurmJobID: tasks[i].SlurmJobID,
WorkDir: tasks[i].WorkDir,
ErrorMessage: tasks[i].ErrorMessage,
CreatedAt: tasks[i].CreatedAt,
UpdatedAt: tasks[i].UpdatedAt,
})
}
server.OK(c, model.TaskListResponse{
Items: responses,
Total: total,
})
}

View File

@@ -0,0 +1,286 @@
package handler
import (
"bytes"
"context"
"encoding/json"
"fmt"
"net/http"
"net/http/httptest"
"os"
"path/filepath"
"sync/atomic"
"testing"
"time"
"gcy_hpc_server/internal/model"
"gcy_hpc_server/internal/service"
"gcy_hpc_server/internal/slurm"
"gcy_hpc_server/internal/store"
"github.com/gin-gonic/gin"
"go.uber.org/zap"
"gorm.io/driver/sqlite"
"gorm.io/gorm"
gormlogger "gorm.io/gorm/logger"
)
var taskDBCounter atomic.Int64
func setupTaskHandler(t *testing.T, slurmSrv *httptest.Server) (*TaskHandler, *gorm.DB) {
t.Helper()
dbFile := filepath.Join(t.TempDir(), fmt.Sprintf("test-%d.db", taskDBCounter.Add(1)))
db, err := gorm.Open(sqlite.Open(dbFile), &gorm.Config{Logger: gormlogger.Default.LogMode(gormlogger.Silent)})
if err != nil {
t.Fatalf("open db: %v", err)
}
db.AutoMigrate(&model.Task{}, &model.Application{})
t.Cleanup(func() { os.Remove(dbFile) })
taskStore := store.NewTaskStore(db)
appStore := store.NewApplicationStore(db)
var jobSvc *service.JobService
if slurmSrv != nil {
client, _ := slurm.NewClient(slurmSrv.URL, slurmSrv.Client())
jobSvc = service.NewJobService(client, zap.NewNop())
}
workDir := filepath.Join(t.TempDir(), "work")
taskSvc := service.NewTaskService(taskStore, appStore, nil, nil, nil, jobSvc, workDir, zap.NewNop())
h := NewTaskHandler(taskSvc, zap.NewNop())
return h, db
}
func setupTaskRouter(h *TaskHandler) *gin.Engine {
gin.SetMode(gin.TestMode)
r := gin.New()
v1 := r.Group("/api/v1")
tasks := v1.Group("/tasks")
tasks.POST("", h.CreateTask)
tasks.GET("", h.ListTasks)
return r
}
func createTestAppForTask(db *gorm.DB) int64 {
app := &model.Application{
Name: "test-app",
ScriptTemplate: "#!/bin/bash\necho hello",
Parameters: json.RawMessage(`[]`),
}
db.Create(app)
return app.ID
}
// ---- CreateTask Tests ----
func TestTaskHandler_CreateTask_Success(t *testing.T) {
slurmSrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusCreated)
json.NewEncoder(w).Encode(map[string]interface{}{"job_id": 12345})
}))
defer slurmSrv.Close()
h, db := setupTaskHandler(t, slurmSrv)
r := setupTaskRouter(h)
appID := createTestAppForTask(db)
taskSvc := h.svc.(*service.TaskService)
ctx := context.Background()
taskSvc.StartProcessor(ctx)
defer taskSvc.StopProcessor()
body, _ := json.Marshal(model.CreateTaskRequest{
AppID: appID,
TaskName: "my-task",
})
w := httptest.NewRecorder()
req, _ := http.NewRequest(http.MethodPost, "/api/v1/tasks", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
r.ServeHTTP(w, req)
if w.Code != http.StatusCreated {
t.Fatalf("expected 201, got %d: %s", w.Code, w.Body.String())
}
var resp map[string]interface{}
json.Unmarshal(w.Body.Bytes(), &resp)
if !resp["success"].(bool) {
t.Fatal("expected success=true")
}
data := resp["data"].(map[string]interface{})
if _, ok := data["id"]; !ok {
t.Fatal("expected id in response data")
}
}
func TestTaskHandler_CreateTask_MissingAppID(t *testing.T) {
h, _ := setupTaskHandler(t, nil)
r := setupTaskRouter(h)
body := `{"task_name":"no-app"}`
w := httptest.NewRecorder()
req, _ := http.NewRequest(http.MethodPost, "/api/v1/tasks", bytes.NewReader([]byte(body)))
req.Header.Set("Content-Type", "application/json")
r.ServeHTTP(w, req)
if w.Code != http.StatusBadRequest {
t.Fatalf("expected 400, got %d: %s", w.Code, w.Body.String())
}
}
func TestTaskHandler_CreateTask_InvalidJSON(t *testing.T) {
h, _ := setupTaskHandler(t, nil)
r := setupTaskRouter(h)
w := httptest.NewRecorder()
req, _ := http.NewRequest(http.MethodPost, "/api/v1/tasks", bytes.NewReader([]byte("not-json")))
req.Header.Set("Content-Type", "application/json")
r.ServeHTTP(w, req)
if w.Code != http.StatusBadRequest {
t.Fatalf("expected 400, got %d: %s", w.Code, w.Body.String())
}
}
// ---- ListTasks Tests ----
func TestTaskHandler_ListTasks_Pagination(t *testing.T) {
slurmSrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusCreated)
json.NewEncoder(w).Encode(map[string]interface{}{"job_id": int32(100)})
}))
defer slurmSrv.Close()
h, db := setupTaskHandler(t, slurmSrv)
r := setupTaskRouter(h)
appID := createTestAppForTask(db)
taskSvc := h.svc.(*service.TaskService)
ctx := context.Background()
taskSvc.StartProcessor(ctx)
defer taskSvc.StopProcessor()
for i := 0; i < 5; i++ {
body, _ := json.Marshal(model.CreateTaskRequest{
AppID: appID,
TaskName: fmt.Sprintf("task-%d", i),
})
w := httptest.NewRecorder()
req, _ := http.NewRequest(http.MethodPost, "/api/v1/tasks", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
r.ServeHTTP(w, req)
}
// Wait for async processing
time.Sleep(200 * time.Millisecond)
w := httptest.NewRecorder()
req, _ := http.NewRequest(http.MethodGet, "/api/v1/tasks?page=1&page_size=3", nil)
r.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("expected 200, got %d: %s", w.Code, w.Body.String())
}
var resp map[string]interface{}
json.Unmarshal(w.Body.Bytes(), &resp)
data := resp["data"].(map[string]interface{})
if data["total"].(float64) != 5 {
t.Fatalf("expected total=5, got %v", data["total"])
}
items := data["items"].([]interface{})
if len(items) != 3 {
t.Fatalf("expected 3 items, got %d", len(items))
}
}
func TestTaskHandler_ListTasks_StatusFilter(t *testing.T) {
slurmSrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusCreated)
json.NewEncoder(w).Encode(map[string]interface{}{"job_id": int32(200)})
}))
defer slurmSrv.Close()
h, db := setupTaskHandler(t, slurmSrv)
r := setupTaskRouter(h)
appID := createTestAppForTask(db)
taskSvc := h.svc.(*service.TaskService)
ctx := context.Background()
taskSvc.StartProcessor(ctx)
defer taskSvc.StopProcessor()
for i := 0; i < 3; i++ {
body, _ := json.Marshal(model.CreateTaskRequest{
AppID: appID,
TaskName: fmt.Sprintf("filter-task-%d", i),
})
w := httptest.NewRecorder()
req, _ := http.NewRequest(http.MethodPost, "/api/v1/tasks", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
r.ServeHTTP(w, req)
}
// Wait for async processing
time.Sleep(200 * time.Millisecond)
w := httptest.NewRecorder()
req, _ := http.NewRequest(http.MethodGet, "/api/v1/tasks?status=queued", nil)
r.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("expected 200, got %d: %s", w.Code, w.Body.String())
}
var resp map[string]interface{}
json.Unmarshal(w.Body.Bytes(), &resp)
data := resp["data"].(map[string]interface{})
items := data["items"].([]interface{})
for _, item := range items {
m := item.(map[string]interface{})
if m["status"] != "queued" {
t.Fatalf("expected status=queued, got %v", m["status"])
}
}
}
func TestTaskHandler_ListTasks_DefaultPagination(t *testing.T) {
h, db := setupTaskHandler(t, nil)
r := setupTaskRouter(h)
_ = createTestAppForTask(db)
// Directly insert tasks via DB to avoid needing processor
for i := 0; i < 15; i++ {
task := &model.Task{
TaskName: fmt.Sprintf("default-task-%d", i),
AppID: 1,
AppName: "test-app",
Status: model.TaskStatusSubmitted,
SubmittedAt: time.Now(),
}
db.Create(task)
}
w := httptest.NewRecorder()
req, _ := http.NewRequest(http.MethodGet, "/api/v1/tasks", nil)
r.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("expected 200, got %d: %s", w.Code, w.Body.String())
}
var resp map[string]interface{}
json.Unmarshal(w.Body.Bytes(), &resp)
data := resp["data"].(map[string]interface{})
if data["total"].(float64) != 15 {
t.Fatalf("expected total=15, got %v", data["total"])
}
items := data["items"].([]interface{})
if len(items) != 10 {
t.Fatalf("expected 10 items (default page_size), got %d", len(items))
}
}

View File

@@ -0,0 +1,154 @@
package handler
import (
"context"
"io"
"strconv"
"gcy_hpc_server/internal/model"
"gcy_hpc_server/internal/server"
"gcy_hpc_server/internal/service"
"github.com/gin-gonic/gin"
"go.uber.org/zap"
)
// uploadServiceProvider defines the interface for upload operations.
type uploadServiceProvider interface {
InitUpload(ctx context.Context, req model.InitUploadRequest) (interface{}, error)
UploadChunk(ctx context.Context, sessionID int64, chunkIndex int, reader io.Reader, size int64) error
CompleteUpload(ctx context.Context, sessionID int64) (*model.FileResponse, error)
GetUploadStatus(ctx context.Context, sessionID int64) (*model.UploadSessionResponse, error)
CancelUpload(ctx context.Context, sessionID int64) error
}
// UploadHandler handles HTTP requests for chunked file uploads.
type UploadHandler struct {
svc uploadServiceProvider
logger *zap.Logger
}
// NewUploadHandler creates a new UploadHandler.
func NewUploadHandler(svc *service.UploadService, logger *zap.Logger) *UploadHandler {
return &UploadHandler{svc: svc, logger: logger}
}
// InitUpload initiates a new chunked upload session.
// POST /api/v1/files/uploads
func (h *UploadHandler) InitUpload(c *gin.Context) {
var req model.InitUploadRequest
if err := c.ShouldBindJSON(&req); err != nil {
h.logger.Warn("invalid request body for init upload", zap.Error(err))
server.BadRequest(c, err.Error())
return
}
result, err := h.svc.InitUpload(c.Request.Context(), req)
if err != nil {
h.logger.Error("failed to init upload", zap.Error(err))
server.BadRequest(c, err.Error())
return
}
switch resp := result.(type) {
case model.FileResponse:
server.OK(c, resp)
case model.UploadSessionResponse:
server.Created(c, resp)
default:
server.Created(c, resp)
}
}
// UploadChunk uploads a single chunk of an upload session.
// PUT /api/v1/files/uploads/:id/chunks/:index
func (h *UploadHandler) UploadChunk(c *gin.Context) {
sessionID, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
h.logger.Warn("invalid session id", zap.String("id", c.Param("id")))
server.BadRequest(c, "invalid session id")
return
}
chunkIndex, err := strconv.Atoi(c.Param("index"))
if err != nil {
h.logger.Warn("invalid chunk index", zap.String("index", c.Param("index")))
server.BadRequest(c, "invalid chunk index")
return
}
file, header, err := c.Request.FormFile("chunk")
if err != nil {
h.logger.Warn("missing chunk file in request", zap.Error(err))
server.BadRequest(c, "missing chunk file")
return
}
defer file.Close()
if err := h.svc.UploadChunk(c.Request.Context(), sessionID, chunkIndex, file, header.Size); err != nil {
h.logger.Error("failed to upload chunk", zap.Int64("session_id", sessionID), zap.Int("chunk_index", chunkIndex), zap.Error(err))
server.BadRequest(c, err.Error())
return
}
server.OK(c, gin.H{"message": "chunk uploaded"})
}
// CompleteUpload finalizes an upload session and assembles the file.
// POST /api/v1/files/uploads/:id/complete
func (h *UploadHandler) CompleteUpload(c *gin.Context) {
sessionID, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
h.logger.Warn("invalid session id for complete", zap.String("id", c.Param("id")))
server.BadRequest(c, "invalid session id")
return
}
resp, err := h.svc.CompleteUpload(c.Request.Context(), sessionID)
if err != nil {
h.logger.Error("failed to complete upload", zap.Int64("session_id", sessionID), zap.Error(err))
server.InternalError(c, err.Error())
return
}
server.Created(c, resp)
}
// GetUploadStatus returns the current status of an upload session.
// GET /api/v1/files/uploads/:id
func (h *UploadHandler) GetUploadStatus(c *gin.Context) {
sessionID, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
h.logger.Warn("invalid session id for status", zap.String("id", c.Param("id")))
server.BadRequest(c, "invalid session id")
return
}
resp, err := h.svc.GetUploadStatus(c.Request.Context(), sessionID)
if err != nil {
h.logger.Error("failed to get upload status", zap.Int64("session_id", sessionID), zap.Error(err))
server.InternalError(c, err.Error())
return
}
server.OK(c, resp)
}
// CancelUpload cancels and cleans up an upload session.
// DELETE /api/v1/files/uploads/:id
func (h *UploadHandler) CancelUpload(c *gin.Context) {
sessionID, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
h.logger.Warn("invalid session id for cancel", zap.String("id", c.Param("id")))
server.BadRequest(c, "invalid session id")
return
}
if err := h.svc.CancelUpload(c.Request.Context(), sessionID); err != nil {
h.logger.Error("failed to cancel upload", zap.Int64("session_id", sessionID), zap.Error(err))
server.InternalError(c, err.Error())
return
}
server.OK(c, gin.H{"message": "upload cancelled"})
}

View File

@@ -0,0 +1,307 @@
package handler
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"mime/multipart"
"net/http"
"net/http/httptest"
"testing"
"time"
"gcy_hpc_server/internal/model"
"github.com/gin-gonic/gin"
"go.uber.org/zap"
)
type mockUploadService struct {
initUploadFn func(ctx context.Context, req model.InitUploadRequest) (interface{}, error)
uploadChunkFn func(ctx context.Context, sessionID int64, chunkIndex int, reader io.Reader, size int64) error
completeUploadFn func(ctx context.Context, sessionID int64) (*model.FileResponse, error)
getUploadStatusFn func(ctx context.Context, sessionID int64) (*model.UploadSessionResponse, error)
cancelUploadFn func(ctx context.Context, sessionID int64) error
}
func (m *mockUploadService) InitUpload(ctx context.Context, req model.InitUploadRequest) (interface{}, error) {
return m.initUploadFn(ctx, req)
}
func (m *mockUploadService) UploadChunk(ctx context.Context, sessionID int64, chunkIndex int, reader io.Reader, size int64) error {
return m.uploadChunkFn(ctx, sessionID, chunkIndex, reader, size)
}
func (m *mockUploadService) CompleteUpload(ctx context.Context, sessionID int64) (*model.FileResponse, error) {
return m.completeUploadFn(ctx, sessionID)
}
func (m *mockUploadService) GetUploadStatus(ctx context.Context, sessionID int64) (*model.UploadSessionResponse, error) {
return m.getUploadStatusFn(ctx, sessionID)
}
func (m *mockUploadService) CancelUpload(ctx context.Context, sessionID int64) error {
return m.cancelUploadFn(ctx, sessionID)
}
func setupUploadHandler(mock uploadServiceProvider) (*UploadHandler, *gin.Engine) {
gin.SetMode(gin.TestMode)
h := &UploadHandler{svc: mock, logger: zap.NewNop()}
r := gin.New()
v1 := r.Group("/api/v1")
uploads := v1.Group("/files/uploads")
uploads.POST("", h.InitUpload)
uploads.PUT("/:id/chunks/:index", h.UploadChunk)
uploads.POST("/:id/complete", h.CompleteUpload)
uploads.GET("/:id", h.GetUploadStatus)
uploads.DELETE("/:id", h.CancelUpload)
return h, r
}
func TestInitUpload_NewSession(t *testing.T) {
mock := &mockUploadService{
initUploadFn: func(ctx context.Context, req model.InitUploadRequest) (interface{}, error) {
return model.UploadSessionResponse{
ID: 1,
FileName: req.FileName,
FileSize: req.FileSize,
ChunkSize: 5 * 1024 * 1024,
TotalChunks: 2,
SHA256: req.SHA256,
Status: "pending",
UploadedChunks: []int{},
ExpiresAt: time.Now().Add(24 * time.Hour),
CreatedAt: time.Now(),
}, nil
},
}
_, r := setupUploadHandler(mock)
body, _ := json.Marshal(model.InitUploadRequest{
FileName: "test.bin",
FileSize: 10 * 1024 * 1024,
SHA256: "abc123",
})
w := httptest.NewRecorder()
req, _ := http.NewRequest(http.MethodPost, "/api/v1/files/uploads", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
r.ServeHTTP(w, req)
if w.Code != http.StatusCreated {
t.Fatalf("expected 201, got %d: %s", w.Code, w.Body.String())
}
var resp map[string]interface{}
json.Unmarshal(w.Body.Bytes(), &resp)
if !resp["success"].(bool) {
t.Fatal("expected success=true")
}
data := resp["data"].(map[string]interface{})
if data["status"] != "pending" {
t.Fatalf("expected status=pending, got %v", data["status"])
}
}
func TestInitUpload_DedupHit(t *testing.T) {
mock := &mockUploadService{
initUploadFn: func(ctx context.Context, req model.InitUploadRequest) (interface{}, error) {
return model.FileResponse{
ID: 42,
Name: req.FileName,
Size: req.FileSize,
SHA256: req.SHA256,
MimeType: "application/octet-stream",
CreatedAt: time.Now(),
UpdatedAt: time.Now(),
}, nil
},
}
_, r := setupUploadHandler(mock)
body, _ := json.Marshal(model.InitUploadRequest{
FileName: "existing.bin",
FileSize: 1024,
SHA256: "existinghash",
})
w := httptest.NewRecorder()
req, _ := http.NewRequest(http.MethodPost, "/api/v1/files/uploads", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
r.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("expected 200, got %d: %s", w.Code, w.Body.String())
}
var resp map[string]interface{}
json.Unmarshal(w.Body.Bytes(), &resp)
if !resp["success"].(bool) {
t.Fatal("expected success=true")
}
data := resp["data"].(map[string]interface{})
if data["sha256"] != "existinghash" {
t.Fatalf("expected sha256=existinghash, got %v", data["sha256"])
}
}
func TestInitUpload_MissingFields(t *testing.T) {
mock := &mockUploadService{
initUploadFn: func(ctx context.Context, req model.InitUploadRequest) (interface{}, error) {
return nil, nil
},
}
_, r := setupUploadHandler(mock)
w := httptest.NewRecorder()
req, _ := http.NewRequest(http.MethodPost, "/api/v1/files/uploads", bytes.NewReader([]byte(`{}`)))
req.Header.Set("Content-Type", "application/json")
r.ServeHTTP(w, req)
if w.Code != http.StatusBadRequest {
t.Fatalf("expected 400, got %d: %s", w.Code, w.Body.String())
}
}
func TestUploadChunk_Success(t *testing.T) {
var receivedSessionID int64
var receivedChunkIndex int
mock := &mockUploadService{
uploadChunkFn: func(ctx context.Context, sessionID int64, chunkIndex int, reader io.Reader, size int64) error {
receivedSessionID = sessionID
receivedChunkIndex = chunkIndex
readBytes, _ := io.ReadAll(reader)
if len(readBytes) == 0 {
t.Error("expected non-empty chunk data")
}
return nil
},
}
_, r := setupUploadHandler(mock)
var buf bytes.Buffer
writer := multipart.NewWriter(&buf)
part, _ := writer.CreateFormFile("chunk", "test.bin")
fmt.Fprintf(part, "chunk data here")
writer.Close()
w := httptest.NewRecorder()
req, _ := http.NewRequest(http.MethodPut, "/api/v1/files/uploads/1/chunks/0", &buf)
req.Header.Set("Content-Type", writer.FormDataContentType())
r.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("expected 200, got %d: %s", w.Code, w.Body.String())
}
if receivedSessionID != 1 {
t.Fatalf("expected session_id=1, got %d", receivedSessionID)
}
if receivedChunkIndex != 0 {
t.Fatalf("expected chunk_index=0, got %d", receivedChunkIndex)
}
}
func TestCompleteUpload_Success(t *testing.T) {
mock := &mockUploadService{
completeUploadFn: func(ctx context.Context, sessionID int64) (*model.FileResponse, error) {
return &model.FileResponse{
ID: 10,
Name: "completed.bin",
Size: 2048,
SHA256: "completehash",
MimeType: "application/octet-stream",
CreatedAt: time.Now(),
UpdatedAt: time.Now(),
}, nil
},
}
_, r := setupUploadHandler(mock)
w := httptest.NewRecorder()
req, _ := http.NewRequest(http.MethodPost, "/api/v1/files/uploads/5/complete", nil)
r.ServeHTTP(w, req)
if w.Code != http.StatusCreated {
t.Fatalf("expected 201, got %d: %s", w.Code, w.Body.String())
}
var resp map[string]interface{}
json.Unmarshal(w.Body.Bytes(), &resp)
if !resp["success"].(bool) {
t.Fatal("expected success=true")
}
data := resp["data"].(map[string]interface{})
if data["name"] != "completed.bin" {
t.Fatalf("expected name=completed.bin, got %v", data["name"])
}
}
func TestGetUploadStatus_Success(t *testing.T) {
mock := &mockUploadService{
getUploadStatusFn: func(ctx context.Context, sessionID int64) (*model.UploadSessionResponse, error) {
return &model.UploadSessionResponse{
ID: sessionID,
FileName: "status.bin",
FileSize: 4096,
ChunkSize: 2048,
TotalChunks: 2,
SHA256: "statushash",
Status: "uploading",
UploadedChunks: []int{0},
ExpiresAt: time.Now().Add(24 * time.Hour),
CreatedAt: time.Now(),
}, nil
},
}
_, r := setupUploadHandler(mock)
w := httptest.NewRecorder()
req, _ := http.NewRequest(http.MethodGet, "/api/v1/files/uploads/3", nil)
r.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("expected 200, got %d: %s", w.Code, w.Body.String())
}
var resp map[string]interface{}
json.Unmarshal(w.Body.Bytes(), &resp)
if !resp["success"].(bool) {
t.Fatal("expected success=true")
}
data := resp["data"].(map[string]interface{})
if data["status"] != "uploading" {
t.Fatalf("expected status=uploading, got %v", data["status"])
}
}
func TestCancelUpload_Success(t *testing.T) {
var receivedSessionID int64
mock := &mockUploadService{
cancelUploadFn: func(ctx context.Context, sessionID int64) error {
receivedSessionID = sessionID
return nil
},
}
_, r := setupUploadHandler(mock)
w := httptest.NewRecorder()
req, _ := http.NewRequest(http.MethodDelete, "/api/v1/files/uploads/7", nil)
r.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("expected 200, got %d: %s", w.Code, w.Body.String())
}
if receivedSessionID != 7 {
t.Fatalf("expected session_id=7, got %d", receivedSessionID)
}
var resp map[string]interface{}
json.Unmarshal(w.Body.Bytes(), &resp)
data := resp["data"].(map[string]interface{})
if data["message"] != "upload cancelled" {
t.Fatalf("expected message='upload cancelled', got %v", data["message"])
}
}

148
internal/logger/gorm.go Normal file
View File

@@ -0,0 +1,148 @@
package logger
import (
"context"
"errors"
"time"
"go.uber.org/zap"
"go.uber.org/zap/zapcore"
"gorm.io/gorm"
gormlogger "gorm.io/gorm/logger"
)
const slowQueryThreshold = 200 * time.Millisecond
// GormLogger implements gorm's logger.Interface backed by zap.
type GormLogger struct {
logger *zap.Logger
level zapcore.Level
silent bool
}
// Compile-time interface check.
var _ gormlogger.Interface = (*GormLogger)(nil)
// NewGormLogger creates a new GormLogger wrapping the given zap logger.
// The level string maps to zap levels; empty defaults to "warn".
// The special value "silent" suppresses all output.
func NewGormLogger(zapLogger *zap.Logger, level string) gormlogger.Interface {
lvl := parseGormLevel(level)
silent := level == "silent"
return &GormLogger{
logger: zapLogger,
level: lvl,
silent: silent,
}
}
// LogMode returns a new GormLogger with the given gorm log level.
// It does NOT mutate the receiver.
func (l *GormLogger) LogMode(level gormlogger.LogLevel) gormlogger.Interface {
newLogger := &GormLogger{
logger: l.logger,
level: l.level,
silent: l.silent,
}
switch level {
case gormlogger.Silent:
newLogger.silent = true
case gormlogger.Error:
newLogger.level = zapcore.ErrorLevel
newLogger.silent = false
case gormlogger.Warn:
newLogger.level = zapcore.WarnLevel
newLogger.silent = false
case gormlogger.Info:
newLogger.level = zapcore.InfoLevel
newLogger.silent = false
}
return newLogger
}
// Info logs at zap.InfoLevel with structured fields from key-value pairs.
func (l *GormLogger) Info(ctx context.Context, msg string, args ...any) {
if l.silent || l.level > zapcore.InfoLevel {
return
}
l.logger.Info(msg, argsToFields(args)...)
}
// Warn logs at zap.WarnLevel with structured fields from key-value pairs.
func (l *GormLogger) Warn(ctx context.Context, msg string, args ...any) {
if l.silent || l.level > zapcore.WarnLevel {
return
}
l.logger.Warn(msg, argsToFields(args)...)
}
// Error logs at zap.ErrorLevel with structured fields from key-value pairs.
func (l *GormLogger) Error(ctx context.Context, msg string, args ...any) {
if l.silent || l.level > zapcore.ErrorLevel {
return
}
l.logger.Error(msg, argsToFields(args)...)
}
// Trace logs SQL query information based on execution results.
func (l *GormLogger) Trace(ctx context.Context, begin time.Time, fc func() (sql string, rowsAffected int64), err error) {
if l.silent {
return
}
elapsed := time.Since(begin)
sql, rows := fc()
switch {
case err != nil && !errors.Is(err, gorm.ErrRecordNotFound):
l.logger.Error("gorm query error",
zap.Error(err),
zap.String("sql", sql),
zap.Int64("rows", rows),
zap.Duration("elapsed", elapsed),
)
case elapsed > slowQueryThreshold:
l.logger.Warn("gorm slow query",
zap.String("sql", sql),
zap.Int64("rows", rows),
zap.Float64("elapsed_ms", float64(elapsed.Nanoseconds())/1e6),
)
default:
if l.level > zapcore.InfoLevel {
return
}
l.logger.Info("gorm query",
zap.String("sql", sql),
zap.Int64("rows", rows),
zap.Duration("elapsed", elapsed),
)
}
}
// parseGormLevel parses a level string into a zapcore.Level.
// Defaults to zapcore.WarnLevel.
func parseGormLevel(level string) zapcore.Level {
if level == "" || level == "silent" {
return zapcore.WarnLevel
}
var lvl zapcore.Level
if err := lvl.UnmarshalText([]byte(level)); err != nil {
return zapcore.WarnLevel
}
return lvl
}
// argsToFields converts alternating key-value pairs into zap fields.
func argsToFields(args []any) []zap.Field {
fields := make([]zap.Field, 0, len(args)/2)
for i := 0; i+1 < len(args); i += 2 {
key, ok := args[i].(string)
if !ok {
continue
}
fields = append(fields, zap.Any(key, args[i+1]))
}
return fields
}

View File

@@ -0,0 +1,509 @@
package logger
import (
"context"
"errors"
"testing"
"time"
"go.uber.org/zap"
"go.uber.org/zap/zapcore"
"go.uber.org/zap/zaptest/observer"
"gorm.io/gorm"
gormlogger "gorm.io/gorm/logger"
)
// newObservedLogger creates a zap logger backed by an observer for test assertions.
func newObservedLogger() (*zap.Logger, *observer.ObservedLogs) {
core, recorded := observer.New(zapcore.DebugLevel)
return zap.New(core), recorded
}
func TestNewGormLogger_DefaultLevel(t *testing.T) {
log, _ := newObservedLogger()
gl := NewGormLogger(log, "")
g, ok := gl.(*GormLogger)
if !ok {
t.Fatal("expected *GormLogger")
}
if g.level != zapcore.WarnLevel {
t.Fatalf("expected warn level, got %v", g.level)
}
if g.silent {
t.Fatal("expected silent=false")
}
}
func TestNewGormLogger_Silent(t *testing.T) {
log, _ := newObservedLogger()
gl := NewGormLogger(log, "silent")
g, ok := gl.(*GormLogger)
if !ok {
t.Fatal("expected *GormLogger")
}
if !g.silent {
t.Fatal("expected silent=true")
}
}
func TestNewGormLogger_ExplicitLevel(t *testing.T) {
log, _ := newObservedLogger()
gl := NewGormLogger(log, "error")
g, ok := gl.(*GormLogger)
if !ok {
t.Fatal("expected *GormLogger")
}
if g.level != zapcore.ErrorLevel {
t.Fatalf("expected error level, got %v", g.level)
}
}
func TestNewGormLogger_InvalidLevel(t *testing.T) {
log, _ := newObservedLogger()
gl := NewGormLogger(log, "bogus")
g, ok := gl.(*GormLogger)
if !ok {
t.Fatal("expected *GormLogger")
}
// Should default to warn
if g.level != zapcore.WarnLevel {
t.Fatalf("expected warn level fallback, got %v", g.level)
}
}
func TestGormLogger_Info(t *testing.T) {
log, recorded := newObservedLogger()
gl := NewGormLogger(log, "info")
gl.Info(context.Background(), "test info", "key", "value")
entries := recorded.All()
if len(entries) != 1 {
t.Fatalf("expected 1 log entry, got %d", len(entries))
}
if entries[0].Message != "test info" {
t.Fatalf("expected message 'test info', got %q", entries[0].Message)
}
if entries[0].Level != zapcore.InfoLevel {
t.Fatalf("expected InfoLevel, got %v", entries[0].Level)
}
}
func TestGormLogger_Warn(t *testing.T) {
log, recorded := newObservedLogger()
gl := NewGormLogger(log, "warn")
gl.Warn(context.Background(), "test warn", "code", 42)
entries := recorded.All()
if len(entries) != 1 {
t.Fatalf("expected 1 log entry, got %d", len(entries))
}
if entries[0].Message != "test warn" {
t.Fatalf("expected message 'test warn', got %q", entries[0].Message)
}
if entries[0].Level != zapcore.WarnLevel {
t.Fatalf("expected WarnLevel, got %v", entries[0].Level)
}
}
func TestGormLogger_Error(t *testing.T) {
log, recorded := newObservedLogger()
gl := NewGormLogger(log, "error")
gl.Error(context.Background(), "test error", "module", "gorm")
entries := recorded.All()
if len(entries) != 1 {
t.Fatalf("expected 1 log entry, got %d", len(entries))
}
if entries[0].Message != "test error" {
t.Fatalf("expected message 'test error', got %q", entries[0].Message)
}
if entries[0].Level != zapcore.ErrorLevel {
t.Fatalf("expected ErrorLevel, got %v", entries[0].Level)
}
}
func TestGormLogger_LevelFiltering(t *testing.T) {
log, recorded := newObservedLogger()
gl := NewGormLogger(log, "warn")
// Info should be suppressed at warn level
gl.Info(context.Background(), "should be suppressed")
if len(recorded.All()) != 0 {
t.Fatal("info should be suppressed at warn level")
}
// Warn should pass
gl.Warn(context.Background(), "should pass")
if len(recorded.All()) != 1 {
t.Fatal("warn should pass at warn level")
}
}
func TestGormLogger_SilentSuppressesAll(t *testing.T) {
log, recorded := newObservedLogger()
gl := NewGormLogger(log, "silent")
gl.Info(context.Background(), "info msg")
gl.Warn(context.Background(), "warn msg")
gl.Error(context.Background(), "error msg")
if len(recorded.All()) != 0 {
t.Fatalf("silent mode should suppress all logs, got %d", len(recorded.All()))
}
}
func TestGormLogger_LogMode_ReturnsNewInstance(t *testing.T) {
log, _ := newObservedLogger()
original := NewGormLogger(log, "warn").(*GormLogger)
modified := original.LogMode(gormlogger.Info)
// Must be a different instance
if modified == original {
t.Fatal("LogMode should return a new instance")
}
// Original should be unchanged
if original.level != zapcore.WarnLevel {
t.Fatalf("original level should remain warn, got %v", original.level)
}
// New instance should have InfoLevel
mod := modified.(*GormLogger)
if mod.level != zapcore.InfoLevel {
t.Fatalf("modified level should be info, got %v", mod.level)
}
}
func TestGormLogger_LogMode_Silent(t *testing.T) {
log, _ := newObservedLogger()
gl := NewGormLogger(log, "info").(*GormLogger)
silent := gl.LogMode(gormlogger.Silent).(*GormLogger)
if !silent.silent {
t.Fatal("LogMode(Silent) should set silent=true")
}
// Original should not be silent
if gl.silent {
t.Fatal("original should not be affected")
}
}
func TestGormLogger_LogMode_ErrorLevel(t *testing.T) {
log, _ := newObservedLogger()
gl := NewGormLogger(log, "info").(*GormLogger)
modified := gl.LogMode(gormlogger.Error).(*GormLogger)
if modified.level != zapcore.ErrorLevel {
t.Fatalf("expected ErrorLevel, got %v", modified.level)
}
}
func TestGormLogger_LogMode_WarnLevel(t *testing.T) {
log, _ := newObservedLogger()
gl := NewGormLogger(log, "info").(*GormLogger)
modified := gl.LogMode(gormlogger.Warn).(*GormLogger)
if modified.level != zapcore.WarnLevel {
t.Fatalf("expected WarnLevel, got %v", modified.level)
}
}
func TestGormLogger_Trace_WithError(t *testing.T) {
log, recorded := newObservedLogger()
gl := NewGormLogger(log, "warn")
begin := time.Now()
fc := func() (string, int64) { return "SELECT * FROM users", 0 }
err := errors.New("connection refused")
gl.Trace(context.Background(), begin, fc, err)
entries := recorded.All()
if len(entries) != 1 {
t.Fatalf("expected 1 entry, got %d", len(entries))
}
if entries[0].Level != zapcore.ErrorLevel {
t.Fatalf("expected ErrorLevel for real errors, got %v", entries[0].Level)
}
if entries[0].Message != "gorm query error" {
t.Fatalf("unexpected message: %q", entries[0].Message)
}
}
func TestGormLogger_Trace_ErrRecordNotFound(t *testing.T) {
log, recorded := newObservedLogger()
gl := NewGormLogger(log, "info")
begin := time.Now()
fc := func() (string, int64) { return "SELECT * FROM users WHERE id = ?", 0 }
gl.Trace(context.Background(), begin, fc, gorm.ErrRecordNotFound)
// ErrRecordNotFound should NOT be logged as error
// At default level with no slow query, it should log as info
entries := recorded.All()
if len(entries) != 1 {
t.Fatalf("expected 1 entry, got %d", len(entries))
}
if entries[0].Level == zapcore.ErrorLevel {
t.Fatal("ErrRecordNotFound should NOT be logged at ErrorLevel")
}
if entries[0].Message != "gorm query" {
t.Fatalf("expected 'gorm query', got %q", entries[0].Message)
}
}
func TestGormLogger_Trace_ErrRecordNotFound_SuppressedAtWarnLevel(t *testing.T) {
log, recorded := newObservedLogger()
gl := NewGormLogger(log, "warn")
begin := time.Now()
fc := func() (string, int64) { return "SELECT * FROM users WHERE id = ?", 0 }
gl.Trace(context.Background(), begin, fc, gorm.ErrRecordNotFound)
// ErrRecordNotFound is not a real error, so it falls through to the default path.
// At warn level, the default path (info-level) is suppressed.
entries := recorded.All()
if len(entries) != 0 {
t.Fatalf("expected 0 entries at warn level, got %d", len(entries))
}
}
func TestGormLogger_Trace_SlowQuery(t *testing.T) {
log, recorded := newObservedLogger()
gl := NewGormLogger(log, "warn")
// Simulate a begin time far enough in the past to exceed the threshold
begin := time.Now().Add(-500 * time.Millisecond)
fc := func() (string, int64) { return "SELECT SLEEP(1)", 1 }
gl.Trace(context.Background(), begin, fc, nil)
entries := recorded.All()
if len(entries) != 1 {
t.Fatalf("expected 1 entry, got %d", len(entries))
}
if entries[0].Level != zapcore.WarnLevel {
t.Fatalf("expected WarnLevel for slow query, got %v", entries[0].Level)
}
if entries[0].Message != "gorm slow query" {
t.Fatalf("expected 'gorm slow query', got %q", entries[0].Message)
}
}
func TestGormLogger_Trace_NormalQuery(t *testing.T) {
log, recorded := newObservedLogger()
gl := NewGormLogger(log, "info")
begin := time.Now()
fc := func() (string, int64) { return "SELECT 1", 1 }
gl.Trace(context.Background(), begin, fc, nil)
entries := recorded.All()
if len(entries) != 1 {
t.Fatalf("expected 1 entry, got %d", len(entries))
}
if entries[0].Level != zapcore.InfoLevel {
t.Fatalf("expected InfoLevel for normal query, got %v", entries[0].Level)
}
if entries[0].Message != "gorm query" {
t.Fatalf("expected 'gorm query', got %q", entries[0].Message)
}
}
func TestGormLogger_Trace_NormalQuerySuppressedAtWarnLevel(t *testing.T) {
log, recorded := newObservedLogger()
gl := NewGormLogger(log, "warn")
begin := time.Now()
fc := func() (string, int64) { return "SELECT 1", 1 }
gl.Trace(context.Background(), begin, fc, nil)
// Normal queries are info-level, suppressed at warn
if len(recorded.All()) != 0 {
t.Fatal("normal query should be suppressed at warn level")
}
}
func TestGormLogger_Trace_SilentSuppressesAll(t *testing.T) {
log, recorded := newObservedLogger()
gl := NewGormLogger(log, "silent")
begin := time.Now().Add(-500 * time.Millisecond)
fc := func() (string, int64) { return "SELECT SLEEP(1)", 1 }
err := errors.New("some error")
gl.Trace(context.Background(), begin, fc, err)
if len(recorded.All()) != 0 {
t.Fatal("silent mode should suppress even Trace errors")
}
}
func TestGormLogger_Trace_StructuredFields(t *testing.T) {
log, recorded := newObservedLogger()
gl := NewGormLogger(log, "info")
begin := time.Now()
fc := func() (string, int64) { return "SELECT * FROM users", 42 }
gl.Trace(context.Background(), begin, fc, nil)
entries := recorded.All()
if len(entries) != 1 {
t.Fatalf("expected 1 entry, got %d", len(entries))
}
// Verify structured fields
fields := entries[0].ContextMap()
if fields["sql"] != "SELECT * FROM users" {
t.Fatalf("expected sql field, got %v", fields["sql"])
}
if fields["rows"] != int64(42) {
t.Fatalf("expected rows=42, got %v", fields["rows"])
}
if _, ok := fields["elapsed"]; !ok {
t.Fatal("expected elapsed field")
}
}
func TestGormLogger_Trace_ErrorFields(t *testing.T) {
log, recorded := newObservedLogger()
gl := NewGormLogger(log, "warn")
begin := time.Now()
fc := func() (string, int64) { return "INSERT INTO users", 0 }
err := errors.New("duplicate key")
gl.Trace(context.Background(), begin, fc, err)
entries := recorded.All()
if len(entries) != 1 {
t.Fatalf("expected 1 entry, got %d", len(entries))
}
fields := entries[0].ContextMap()
if fields["sql"] != "INSERT INTO users" {
t.Fatalf("expected sql field, got %v", fields["sql"])
}
if _, ok := fields["error"]; !ok {
t.Fatal("expected error field")
}
}
func TestGormLogger_Trace_SlowQueryFields(t *testing.T) {
log, recorded := newObservedLogger()
gl := NewGormLogger(log, "warn")
begin := time.Now().Add(-500 * time.Millisecond)
fc := func() (string, int64) { return "SELECT * FROM large_table", 1000 }
gl.Trace(context.Background(), begin, fc, nil)
entries := recorded.All()
if len(entries) != 1 {
t.Fatalf("expected 1 entry, got %d", len(entries))
}
fields := entries[0].ContextMap()
if fields["sql"] != "SELECT * FROM large_table" {
t.Fatalf("expected sql field, got %v", fields["sql"])
}
if fields["rows"] != int64(1000) {
t.Fatalf("expected rows=1000, got %v", fields["rows"])
}
if _, ok := fields["elapsed_ms"]; !ok {
t.Fatal("expected elapsed_ms field for slow query")
}
}
func TestArgsToFields(t *testing.T) {
tests := []struct {
name string
args []interface{}
expected int
}{
{"empty", nil, 0},
{"single_pair", []interface{}{"key", "value"}, 1},
{"two_pairs", []interface{}{"a", 1, "b", 2}, 2},
{"odd_args_ignores_last", []interface{}{"key", "value", "orphan"}, 1},
{"non_string_key_ignored", []interface{}{123, "value"}, 0},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
fields := argsToFields(tt.args)
if len(fields) != tt.expected {
t.Fatalf("expected %d fields, got %d", tt.expected, len(fields))
}
})
}
}
func TestArgsToFields_FieldValues(t *testing.T) {
fields := argsToFields([]interface{}{"name", "test", "count", 42})
if len(fields) != 2 {
t.Fatalf("expected 2 fields, got %d", len(fields))
}
if fields[0].Key != "name" {
t.Fatalf("expected key 'name', got %q", fields[0].Key)
}
if fields[1].Key != "count" {
t.Fatalf("expected key 'count', got %q", fields[1].Key)
}
}
func TestParseGormLevel(t *testing.T) {
tests := []struct {
input string
expected zapcore.Level
}{
{"debug", zapcore.DebugLevel},
{"info", zapcore.InfoLevel},
{"warn", zapcore.WarnLevel},
{"error", zapcore.ErrorLevel},
{"", zapcore.WarnLevel},
{"silent", zapcore.WarnLevel},
{"invalid", zapcore.WarnLevel},
}
for _, tt := range tests {
t.Run(tt.input, func(t *testing.T) {
got := parseGormLevel(tt.input)
if got != tt.expected {
t.Fatalf("expected %v, got %v", tt.expected, got)
}
})
}
}
func TestGormLogger_InfoWithMultipleFields(t *testing.T) {
log, recorded := newObservedLogger()
gl := NewGormLogger(log, "info")
gl.Info(context.Background(), "multi fields", "key1", "val1", "key2", 123, "key3", true)
entries := recorded.All()
if len(entries) != 1 {
t.Fatalf("expected 1 entry, got %d", len(entries))
}
fields := entries[0].ContextMap()
if fields["key1"] != "val1" {
t.Fatalf("expected key1=val1, got %v", fields["key1"])
}
if fields["key2"] != int64(123) {
t.Fatalf("expected key2=123, got %v", fields["key2"])
}
if fields["key3"] != true {
t.Fatalf("expected key3=true, got %v", fields["key3"])
}
}

93
internal/logger/logger.go Normal file
View File

@@ -0,0 +1,93 @@
package logger
import (
"fmt"
"os"
"gcy_hpc_server/internal/config"
"go.uber.org/zap"
"go.uber.org/zap/zapcore"
"gopkg.in/natefinch/lumberjack.v2"
)
func NewLogger(cfg config.LogConfig) (*zap.Logger, error) {
level := applyDefault(cfg.Level, "info")
var zapLevel zapcore.Level
if err := zapLevel.UnmarshalText([]byte(level)); err != nil {
return nil, fmt.Errorf("invalid log level %q: %w", level, err)
}
encoding := applyDefault(cfg.Encoding, "json")
encoderConfig := zap.NewProductionEncoderConfig()
encoderConfig.TimeKey = "ts"
encoderConfig.EncodeTime = zapcore.ISO8601TimeEncoder
encoderConfig.EncodeLevel = zapcore.CapitalLevelEncoder
var encoder zapcore.Encoder
switch encoding {
case "console":
encoder = zapcore.NewConsoleEncoder(encoderConfig)
default:
encoder = zapcore.NewJSONEncoder(encoderConfig)
}
var syncers []zapcore.WriteSyncer
stdout := true
if cfg.OutputStdout != nil {
stdout = *cfg.OutputStdout
}
if stdout {
syncers = append(syncers, zapcore.AddSync(os.Stdout))
}
if cfg.FilePath != "" {
maxSize := applyDefaultInt(cfg.MaxSize, 100)
maxBackups := applyDefaultInt(cfg.MaxBackups, 5)
maxAge := applyDefaultInt(cfg.MaxAge, 30)
compress := cfg.Compress || (cfg.MaxSize == 0 && cfg.MaxBackups == 0 && cfg.MaxAge == 0)
lj := &lumberjack.Logger{
Filename: cfg.FilePath,
MaxSize: maxSize,
MaxBackups: maxBackups,
MaxAge: maxAge,
Compress: compress,
}
syncers = append(syncers, zapcore.AddSync(lj))
}
if len(syncers) == 0 {
syncers = append(syncers, zapcore.AddSync(os.Stdout))
}
writeSyncer := syncers[0]
if len(syncers) > 1 {
writeSyncer = zapcore.NewMultiWriteSyncer(syncers...)
}
core := zapcore.NewCore(encoder, writeSyncer, zapLevel)
opts := []zap.Option{
zap.AddCaller(),
zap.AddStacktrace(zapcore.ErrorLevel),
}
return zap.New(core, opts...), nil
}
func applyDefault(val, def string) string {
if val == "" {
return def
}
return val
}
func applyDefaultInt(val, def int) int {
if val == 0 {
return def
}
return val
}

View File

@@ -0,0 +1,286 @@
package logger
import (
"os"
"path/filepath"
"strings"
"testing"
"gcy_hpc_server/internal/config"
"go.uber.org/zap"
"go.uber.org/zap/zapcore"
"go.uber.org/zap/zaptest"
)
func ptrBool(v bool) *bool { return &v }
// TestNewLogger_JSONConfig creates a logger with JSON encoding and verifies
// that log entries are emitted successfully.
func TestNewLogger_JSONConfig(t *testing.T) {
cfg := config.LogConfig{
Level: "debug",
Encoding: "json",
OutputStdout: ptrBool(true),
}
log, err := NewLogger(cfg)
if err != nil {
t.Fatalf("NewLogger returned error: %v", err)
}
defer log.Sync()
// Should not panic when logging
log.Info("json logger test", zap.String("key", "value"))
}
// TestNewLogger_ConsoleConfig creates a logger with console encoding.
func TestNewLogger_ConsoleConfig(t *testing.T) {
cfg := config.LogConfig{
Level: "info",
Encoding: "console",
OutputStdout: ptrBool(true),
}
log, err := NewLogger(cfg)
if err != nil {
t.Fatalf("NewLogger returned error: %v", err)
}
defer log.Sync()
log.Info("console logger test", zap.Int("num", 42))
}
// TestNewLogger_InvalidLevel verifies that an invalid log level returns an error.
func TestNewLogger_InvalidLevel(t *testing.T) {
cfg := config.LogConfig{
Level: "bogus",
Encoding: "json",
}
_, err := NewLogger(cfg)
if err == nil {
t.Fatal("expected error for invalid log level, got nil")
}
}
// TestNewLogger_EmptyConfig verifies defaults are applied when config is zero-value.
func TestNewLogger_EmptyConfig(t *testing.T) {
cfg := config.LogConfig{} // all zero values
log, err := NewLogger(cfg)
if err != nil {
t.Fatalf("NewLogger returned error: %v", err)
}
defer log.Sync()
log.Info("default config test")
}
// TestNewLogger_FileOutput verifies that file output with rotation config works.
func TestNewLogger_FileOutput(t *testing.T) {
tmpDir := t.TempDir()
logFile := filepath.Join(tmpDir, "test.log")
cfg := config.LogConfig{
Level: "info",
Encoding: "json",
FilePath: logFile,
MaxSize: 10,
MaxBackups: 3,
MaxAge: 7,
Compress: true,
}
log, err := NewLogger(cfg)
if err != nil {
t.Fatalf("NewLogger returned error: %v", err)
}
log.Info("file output test", zap.String("msg", "hello"))
log.Sync()
data, err := os.ReadFile(logFile)
if err != nil {
t.Fatalf("failed to read log file: %v", err)
}
if len(data) == 0 {
t.Fatal("log file is empty, expected output")
}
if !strings.Contains(string(data), "file output test") {
t.Fatalf("log file content does not contain expected message;\ngot: %s", string(data))
}
}
// TestNewLogger_MultiWriter verifies that both stdout and file output work together.
func TestNewLogger_MultiWriter(t *testing.T) {
tmpDir := t.TempDir()
logFile := filepath.Join(tmpDir, "multi.log")
cfg := config.LogConfig{
Level: "info",
Encoding: "json",
OutputStdout: ptrBool(true),
FilePath: logFile,
}
log, err := NewLogger(cfg)
if err != nil {
t.Fatalf("NewLogger returned error: %v", err)
}
log.Info("multi writer test", zap.String("writer", "both"))
log.Sync()
data, err := os.ReadFile(logFile)
if err != nil {
t.Fatalf("failed to read log file: %v", err)
}
if !strings.Contains(string(data), "multi writer test") {
t.Fatalf("log file content does not contain expected message;\ngot: %s", string(data))
}
}
// TestNewLogger_Observer verifies actual log output content using zaptest.
func TestNewLogger_Observer(t *testing.T) {
// Use zaptest.NewLogger to capture logs in test output
log := zaptest.NewLogger(t,
zaptest.WrapOptions(zap.AddCaller(), zap.AddStacktrace(zapcore.ErrorLevel)),
zaptest.Level(zapcore.DebugLevel),
)
// These should all succeed without panicking
log.Debug("debug msg", zap.String("k", "v"))
log.Info("info msg", zap.Int("n", 1))
log.Warn("warn msg")
log.Error("error msg")
}
// TestNewLogger_AllLevels verifies all valid log levels parse correctly.
func TestNewLogger_AllLevels(t *testing.T) {
levels := []string{"debug", "info", "warn", "error", "dpanic", "panic", "fatal"}
for _, level := range levels {
t.Run(level, func(t *testing.T) {
cfg := config.LogConfig{
Level: level,
Encoding: "json",
OutputStdout: ptrBool(true),
}
log, err := NewLogger(cfg)
if err != nil {
t.Fatalf("level %q: NewLogger returned error: %v", level, err)
}
log.Sync()
})
}
}
// TestNewLogger_InvalidEncoding falls back gracefully — the factory should
// treat an unrecognized encoding as an error or default to JSON.
func TestNewLogger_InvalidEncoding(t *testing.T) {
cfg := config.LogConfig{
Level: "info",
Encoding: "xml",
OutputStdout: ptrBool(true),
}
// The implementation should default to JSON for unknown encoding.
log, err := NewLogger(cfg)
if err != nil {
t.Fatalf("unexpected error for invalid encoding: %v", err)
}
defer log.Sync()
log.Info("invalid encoding test")
}
// TestNewLogger_DefaultRotation verifies rotation defaults are applied.
func TestNewLogger_DefaultRotation(t *testing.T) {
tmpDir := t.TempDir()
logFile := filepath.Join(tmpDir, "rotation.log")
cfg := config.LogConfig{
Level: "info",
Encoding: "json",
FilePath: logFile,
// MaxSize, MaxBackups, MaxAge, Compress all zero → defaults apply
}
log, err := NewLogger(cfg)
if err != nil {
t.Fatalf("NewLogger returned error: %v", err)
}
log.Info("rotation defaults test")
log.Sync()
data, err := os.ReadFile(logFile)
if err != nil {
t.Fatalf("failed to read log file: %v", err)
}
if len(data) == 0 {
t.Fatal("log file is empty")
}
}
func TestNewLogger_OutputStdoutNil(t *testing.T) {
cfg := config.LogConfig{
Level: "info",
Encoding: "json",
}
log, err := NewLogger(cfg)
if err != nil {
t.Fatalf("NewLogger returned error: %v", err)
}
defer log.Sync()
log.Info("default stdout test")
}
func TestNewLogger_OutputStdoutFalseWithFile(t *testing.T) {
tmpDir := t.TempDir()
logFile := filepath.Join(tmpDir, "nostdout.log")
cfg := config.LogConfig{
Level: "info",
Encoding: "json",
OutputStdout: ptrBool(false),
FilePath: logFile,
}
log, err := NewLogger(cfg)
if err != nil {
t.Fatalf("NewLogger returned error: %v", err)
}
log.Info("file only test")
log.Sync()
data, err := os.ReadFile(logFile)
if err != nil {
t.Fatalf("failed to read log file: %v", err)
}
if !strings.Contains(string(data), "file only test") {
t.Fatalf("log file content does not contain expected message;\ngot: %s", string(data))
}
}
func TestNewLogger_OutputStdoutFalseFallback(t *testing.T) {
cfg := config.LogConfig{
Level: "info",
Encoding: "json",
OutputStdout: ptrBool(false),
}
log, err := NewLogger(cfg)
if err != nil {
t.Fatalf("NewLogger returned error: %v", err)
}
defer log.Sync()
log.Info("fallback stdout test")
}

View File

@@ -0,0 +1,25 @@
package middleware
import (
"time"
"github.com/gin-gonic/gin"
"go.uber.org/zap"
)
// RequestLogger returns a Gin middleware that logs each request using zap.
func RequestLogger(logger *zap.Logger) gin.HandlerFunc {
return func(c *gin.Context) {
start := time.Now()
c.Next()
logger.Info("request",
zap.String("method", c.Request.Method),
zap.String("path", c.Request.URL.Path),
zap.Int("status", c.Writer.Status()),
zap.Duration("latency", time.Since(start)),
zap.String("client_ip", c.ClientIP()),
)
}
}

View File

@@ -0,0 +1,76 @@
package model
import (
"encoding/json"
"time"
)
// Parameter type constants for ParameterSchema.Type.
const (
ParamTypeString = "string"
ParamTypeInteger = "integer"
ParamTypeEnum = "enum"
ParamTypeFile = "file"
ParamTypeDirectory = "directory"
ParamTypeBoolean = "boolean"
)
// Application represents a parameterized application definition for HPC job submission.
type Application struct {
ID int64 `gorm:"primaryKey;autoIncrement" json:"id"`
Name string `gorm:"uniqueIndex;size:255;not null" json:"name"`
Description string `gorm:"type:text" json:"description,omitempty"`
Icon string `gorm:"size:255" json:"icon,omitempty"`
Category string `gorm:"size:255" json:"category,omitempty"`
ScriptTemplate string `gorm:"type:text;not null" json:"script_template"`
Parameters json.RawMessage `gorm:"type:json" json:"parameters,omitempty"`
Scope string `gorm:"size:50;default:'system'" json:"scope,omitempty"`
CreatedBy int64 `json:"created_by,omitempty"`
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
}
func (Application) TableName() string {
return "hpc_applications"
}
// ParameterSchema defines a single parameter in an application's form schema.
type ParameterSchema struct {
Name string `json:"name"`
Label string `json:"label,omitempty"`
Type string `json:"type"`
Required bool `json:"required,omitempty"`
Default string `json:"default,omitempty"`
Options []string `json:"options,omitempty"`
Description string `json:"description,omitempty"`
}
// CreateApplicationRequest is the DTO for creating a new application.
type CreateApplicationRequest struct {
Name string `json:"name" binding:"required"`
Description string `json:"description,omitempty"`
Icon string `json:"icon,omitempty"`
Category string `json:"category,omitempty"`
ScriptTemplate string `json:"script_template" binding:"required"`
Parameters json.RawMessage `json:"parameters,omitempty"`
Scope string `json:"scope,omitempty"`
}
// UpdateApplicationRequest is the DTO for updating an existing application.
// All fields are optional. Parameters uses *json.RawMessage to distinguish
// between "not provided" (nil) and "set to empty" (non-nil).
type UpdateApplicationRequest struct {
Name *string `json:"name,omitempty"`
Description *string `json:"description,omitempty"`
Icon *string `json:"icon,omitempty"`
Category *string `json:"category,omitempty"`
ScriptTemplate *string `json:"script_template,omitempty"`
Parameters *json.RawMessage `json:"parameters,omitempty"`
Scope *string `json:"scope,omitempty"`
}
// ApplicationSubmitRequest is the DTO for submitting a job from an application.
// ApplicationID is parsed from the URL :id parameter, not included in the body.
type ApplicationSubmitRequest struct {
Values map[string]string `json:"values" binding:"required"`
}

73
internal/model/cluster.go Normal file
View File

@@ -0,0 +1,73 @@
package model
// NodeResponse is the API response for a node.
type NodeResponse struct {
// Identity
Name string `json:"name"` // 节点主机名
State []string `json:"state"` // 节点状态 (e.g. ["IDLE"], ["ALLOCATED","COMPLETING"])
Reason string `json:"reason,omitempty"` // 节点 DOWN/DRAIN 的原因
ReasonSetByUser string `json:"reason_set_by_user,omitempty"` // 设置原因的用户
// CPU Resources
CPUs int32 `json:"cpus"` // 总 CPU 核数
AllocCpus *int32 `json:"alloc_cpus,omitempty"` // 已分配 CPU 核数
Cores *int32 `json:"cores,omitempty"` // 物理核心数
Sockets *int32 `json:"sockets,omitempty"` // CPU 插槽数
Threads *int32 `json:"threads,omitempty"` // 每核线程数
CpuLoad *int32 `json:"cpu_load,omitempty"` // CPU 负载 (内核 nice 值乘以 100)
// Memory (MiB)
RealMemory int64 `json:"real_memory"` // 物理内存总量
AllocMemory int64 `json:"alloc_memory,omitempty"` // 已分配内存
FreeMem *int64 `json:"free_mem,omitempty"` // 空闲内存
// Hardware
Arch string `json:"architecture,omitempty"` // 系统架构 (e.g. x86_64)
OS string `json:"operating_system,omitempty"` // 操作系统版本
Gres string `json:"gres,omitempty"` // 可用通用资源 (e.g. "gpu:4")
GresUsed string `json:"gres_used,omitempty"` // 已使用的通用资源 (e.g. "gpu:2")
// Network
Address string `json:"address,omitempty"` // 节点地址 (IP)
Hostname string `json:"hostname,omitempty"` // 节点主机名 (可能与 Name 不同)
// Scheduling
Weight *int32 `json:"weight,omitempty"` // 调度权重
Features string `json:"features,omitempty"` // 节点特性标签 (可修改)
ActiveFeatures string `json:"active_features,omitempty"` // 当前生效的特性标签 (只读)
}
// PartitionResponse is the API response for a partition.
type PartitionResponse struct {
// Identity
Name string `json:"name"` // 分区名称
State []string `json:"state"` // 分区状态 (e.g. ["UP"], ["DOWN","DRAIN"])
Default bool `json:"default,omitempty"` // 是否为默认分区
// Nodes
Nodes string `json:"nodes,omitempty"` // 分区包含的节点列表
TotalNodes int32 `json:"total_nodes,omitempty"` // 节点总数
// CPUs
TotalCPUs int32 `json:"total_cpus,omitempty"` // CPU 总核数
MaxCPUsPerNode *int32 `json:"max_cpus_per_node,omitempty"` // 每节点最大 CPU 核数
// Limits
MaxTime string `json:"max_time,omitempty"` // 最大运行时间 (分钟,"UNLIMITED" 表示无限)
MaxNodes *int32 `json:"max_nodes,omitempty"` // 单作业最大节点数
MinNodes *int32 `json:"min_nodes,omitempty"` // 单作业最小节点数
DefaultTime string `json:"default_time,omitempty"` // 默认运行时间限制
GraceTime *int32 `json:"grace_time,omitempty"` // 作业抢占后的宽限时间 (秒)
// Priority
Priority *int32 `json:"priority,omitempty"` // 分区内作业优先级因子
// Access Control - QOS
QOSAllowed string `json:"qos_allowed,omitempty"` // 允许使用的 QOS 列表
QOSDeny string `json:"qos_deny,omitempty"` // 禁止使用的 QOS 列表
QOSAssigned string `json:"qos_assigned,omitempty"` // 分区默认分配的 QOS
// Access Control - Accounts
AccountsAllowed string `json:"accounts_allowed,omitempty"` // 允许使用的账户列表
AccountsDeny string `json:"accounts_deny,omitempty"` // 禁止使用的账户列表
}

193
internal/model/file.go Normal file
View File

@@ -0,0 +1,193 @@
package model
import (
"fmt"
"strings"
"time"
"unicode"
"gorm.io/gorm"
)
// FileBlob represents a physical file stored in MinIO, deduplicated by SHA256.
type FileBlob struct {
ID int64 `gorm:"primaryKey;autoIncrement" json:"id"`
SHA256 string `gorm:"uniqueIndex;size:64;not null" json:"sha256"`
MinioKey string `gorm:"size:255;not null" json:"minio_key"`
FileSize int64 `gorm:"not null" json:"file_size"`
MimeType string `gorm:"size:255" json:"mime_type"`
RefCount int `gorm:"not null;default:0" json:"ref_count"`
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
}
func (FileBlob) TableName() string {
return "hpc_file_blobs"
}
// File represents a logical file visible to users, backed by a FileBlob.
type File struct {
ID int64 `gorm:"primaryKey;autoIncrement" json:"id"`
Name string `gorm:"size:255;not null" json:"name"`
FolderID *int64 `gorm:"index" json:"folder_id,omitempty"`
BlobSHA256 string `gorm:"size:64;not null" json:"blob_sha256"`
UserID *int64 `gorm:"index" json:"user_id,omitempty"`
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
DeletedAt gorm.DeletedAt `gorm:"index" json:"deleted_at,omitempty"`
}
func (File) TableName() string {
return "hpc_files"
}
// Folder represents a directory in the virtual file system.
type Folder struct {
ID int64 `gorm:"primaryKey;autoIncrement" json:"id"`
Name string `gorm:"size:255;not null" json:"name"`
ParentID *int64 `gorm:"index" json:"parent_id,omitempty"`
Path string `gorm:"uniqueIndex;size:768;not null" json:"path"`
UserID *int64 `gorm:"index" json:"user_id,omitempty"`
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
DeletedAt gorm.DeletedAt `gorm:"index" json:"deleted_at,omitempty"`
}
func (Folder) TableName() string {
return "hpc_folders"
}
// UploadSession represents an in-progress chunked upload.
// State transitions: pending→uploading, pending→completed(zero-byte), uploading→merging,
// uploading→cancelled, merging→completed, merging→failed, any→expired
type UploadSession struct {
ID int64 `gorm:"primaryKey;autoIncrement" json:"id"`
FileName string `gorm:"size:255;not null" json:"file_name"`
FileSize int64 `gorm:"not null" json:"file_size"`
ChunkSize int64 `gorm:"not null" json:"chunk_size"`
TotalChunks int `gorm:"not null" json:"total_chunks"`
SHA256 string `gorm:"size:64;not null" json:"sha256"`
FolderID *int64 `gorm:"index" json:"folder_id,omitempty"`
Status string `gorm:"size:20;not null;default:pending" json:"status"`
MinioPrefix string `gorm:"size:255;not null" json:"minio_prefix"`
MimeType string `gorm:"size:255;default:'application/octet-stream'" json:"mime_type"`
UserID *int64 `gorm:"index" json:"user_id,omitempty"`
ExpiresAt time.Time `gorm:"not null" json:"expires_at"`
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
}
func (UploadSession) TableName() string {
return "hpc_upload_sessions"
}
// UploadChunk represents a single chunk of an upload session.
type UploadChunk struct {
ID int64 `gorm:"primaryKey;autoIncrement" json:"id"`
SessionID int64 `gorm:"not null;uniqueIndex:idx_session_chunk" json:"session_id"`
ChunkIndex int `gorm:"not null;uniqueIndex:idx_session_chunk" json:"chunk_index"`
MinioKey string `gorm:"size:255;not null" json:"minio_key"`
SHA256 string `gorm:"size:64" json:"sha256,omitempty"`
Size int64 `gorm:"not null" json:"size"`
Status string `gorm:"size:20;not null;default:pending" json:"status"`
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
}
func (UploadChunk) TableName() string {
return "hpc_upload_chunks"
}
// InitUploadRequest is the DTO for initiating a chunked upload.
type InitUploadRequest struct {
FileName string `json:"file_name" binding:"required"`
FileSize int64 `json:"file_size" binding:"required"`
SHA256 string `json:"sha256" binding:"required"`
FolderID *int64 `json:"folder_id,omitempty"`
ChunkSize *int64 `json:"chunk_size,omitempty"`
MimeType string `json:"mime_type,omitempty"`
}
// CreateFolderRequest is the DTO for creating a new folder.
type CreateFolderRequest struct {
Name string `json:"name" binding:"required"`
ParentID *int64 `json:"parent_id,omitempty"`
}
// UploadSessionResponse is the DTO returned when creating/querying an upload session.
type UploadSessionResponse struct {
ID int64 `json:"id"`
FileName string `json:"file_name"`
FileSize int64 `json:"file_size"`
ChunkSize int64 `json:"chunk_size"`
TotalChunks int `json:"total_chunks"`
SHA256 string `json:"sha256"`
Status string `json:"status"`
UploadedChunks []int `json:"uploaded_chunks"`
ExpiresAt time.Time `json:"expires_at"`
CreatedAt time.Time `json:"created_at"`
}
// FileResponse is the DTO for a file in API responses.
type FileResponse struct {
ID int64 `json:"id"`
Name string `json:"name"`
FolderID *int64 `json:"folder_id,omitempty"`
Size int64 `json:"size"`
MimeType string `json:"mime_type"`
SHA256 string `json:"sha256"`
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
}
// FolderResponse is the DTO for a folder in API responses.
type FolderResponse struct {
ID int64 `json:"id"`
Name string `json:"name"`
ParentID *int64 `json:"parent_id,omitempty"`
Path string `json:"path"`
FileCount int64 `json:"file_count"`
SubFolderCount int64 `json:"subfolder_count"`
CreatedAt time.Time `json:"created_at"`
}
// ListFilesResponse is the paginated response for listing files.
type ListFilesResponse struct {
Files []FileResponse `json:"files"`
Total int64 `json:"total"`
Page int `json:"page"`
PageSize int `json:"page_size"`
}
// ValidateFileName rejects empty, "..", "/", "\", null bytes, control chars, leading/trailing spaces.
func ValidateFileName(name string) error {
if name == "" {
return fmt.Errorf("file name cannot be empty")
}
if strings.TrimSpace(name) != name {
return fmt.Errorf("file name cannot have leading or trailing spaces")
}
if name == ".." {
return fmt.Errorf("file name cannot be '..'")
}
if strings.Contains(name, "/") || strings.Contains(name, "\\") {
return fmt.Errorf("file name cannot contain '/' or '\\'")
}
for _, r := range name {
if r == 0 {
return fmt.Errorf("file name cannot contain null bytes")
}
if unicode.IsControl(r) {
return fmt.Errorf("file name cannot contain control characters")
}
}
return nil
}
// ValidateFolderName rejects same as ValidateFileName plus ".".
func ValidateFolderName(name string) error {
if name == "." {
return fmt.Errorf("folder name cannot be '.'")
}
return ValidateFileName(name)
}

96
internal/model/job.go Normal file
View File

@@ -0,0 +1,96 @@
package model
// SubmitJobRequest is the API request for submitting a job.
type SubmitJobRequest struct {
Script string `json:"script"` // 作业脚本内容
Partition string `json:"partition,omitempty"` // 提交到的分区
QOS string `json:"qos,omitempty"` // 使用的 QOS 策略
CPUs int32 `json:"cpus,omitempty"` // 请求的 CPU 核数
Memory string `json:"memory,omitempty"` // 请求的内存大小
TimeLimit string `json:"time_limit,omitempty"` // 运行时间限制 (分钟)
JobName string `json:"job_name,omitempty"` // 作业名称
Environment map[string]string `json:"environment,omitempty"` // 环境变量键值对
WorkDir string `json:"work_dir,omitempty"` // 作业工作目录
}
// JobResponse is the API response for a job.
type JobResponse struct {
// Identity
JobID int32 `json:"job_id"` // Slurm 作业 ID
Name string `json:"name"` // 作业名称
State []string `json:"job_state"` // 作业当前状态 (e.g. ["RUNNING"], ["PENDING","REQUEUED"])
StateReason string `json:"state_reason,omitempty"` // 作业等待/失败的原因
// Scheduling
Partition string `json:"partition"` // 所属分区
QOS string `json:"qos,omitempty"` // 使用的 QOS 策略
Priority *int32 `json:"priority,omitempty"` // 作业优先级
TimeLimit string `json:"time_limit,omitempty"` // 运行时间限制 (分钟,"UNLIMITED" 表示无限)
// Ownership
Account string `json:"account,omitempty"` // 计费账户
User string `json:"user,omitempty"` // 提交用户
Cluster string `json:"cluster,omitempty"` // 所属集群
// Resources
Cpus *int32 `json:"cpus,omitempty"` // 分配/请求的 CPU 核数
Tasks *int32 `json:"tasks,omitempty"` // 任务数
NodeCount *int32 `json:"node_count,omitempty"` // 节点数
Nodes string `json:"nodes,omitempty"` // 分配的节点列表
BatchHost string `json:"batch_host,omitempty"` // 批处理主节点
// Timing (Unix timestamp)
SubmitTime *int64 `json:"submit_time,omitempty"` // 提交时间
StartTime *int64 `json:"start_time,omitempty"` // 开始运行时间
EndTime *int64 `json:"end_time,omitempty"` // 结束/预计结束时间
// Result
ExitCode *int32 `json:"exit_code,omitempty"` // 退出码 (nil 表示未结束)
// IO Paths
StdOut string `json:"standard_output,omitempty"` // 标准输出文件路径
StdErr string `json:"standard_error,omitempty"` // 标准错误文件路径
StdIn string `json:"standard_input,omitempty"` // 标准输入文件路径
WorkDir string `json:"working_directory,omitempty"` // 工作目录
Command string `json:"command,omitempty"` // 执行的命令
// Array Job
ArrayJobID *int32 `json:"array_job_id,omitempty"` // 数组作业的父 Job ID
ArrayTaskID *int32 `json:"array_task_id,omitempty"` // 数组作业中的子任务 ID
}
// JobListResponse is the paginated response for job listings.
type JobListResponse struct {
Jobs []JobResponse `json:"jobs"` // 作业列表
Total int `json:"total"` // 符合条件的作业总数
Page int `json:"page"` // 当前页码 (从 1 开始)
PageSize int `json:"page_size"` // 每页条数
}
// JobListQuery contains pagination parameters for active job listing.
type JobListQuery struct {
Page int `form:"page,default=1" json:"page,omitempty"` // 页码 (从 1 开始)
PageSize int `form:"page_size,default=20" json:"page_size,omitempty"` // 每页条数
}
// JobHistoryQuery contains query parameters for job history.
type JobHistoryQuery struct {
Users string `form:"users" json:"users,omitempty"` // 按用户名过滤 (逗号分隔)
StartTime string `form:"start_time" json:"start_time,omitempty"` // 作业开始时间下限 (Unix 时间戳)
EndTime string `form:"end_time" json:"end_time,omitempty"` // 作业结束时间上限 (Unix 时间戳)
SubmitTime string `form:"submit_time" json:"submit_time,omitempty"` // 作业提交时间过滤 (Unix 时间戳)
Account string `form:"account" json:"account,omitempty"` // 按计费账户过滤
Partition string `form:"partition" json:"partition,omitempty"` // 按分区过滤
State string `form:"state" json:"state,omitempty"` // 按作业状态过滤 (e.g. "COMPLETED", "FAILED")
JobName string `form:"job_name" json:"job_name,omitempty"` // 按作业名称过滤
Cluster string `form:"cluster" json:"cluster,omitempty"` // 按集群名称过滤
Qos string `form:"qos" json:"qos,omitempty"` // 按 QOS 策略过滤
Constraints string `form:"constraints" json:"constraints,omitempty"` // 按节点约束过滤
ExitCode string `form:"exit_code" json:"exit_code,omitempty"` // 按退出码过滤
Node string `form:"node" json:"node,omitempty"` // 按分配节点过滤
Reservation string `form:"reservation" json:"reservation,omitempty"` // 按预约名称过滤
Groups string `form:"groups" json:"groups,omitempty"` // 按用户组过滤
Wckey string `form:"wckey" json:"wckey,omitempty"` // 按 WCKey (Workload Characterization Key) 过滤
Page int `form:"page,default=1" json:"page,omitempty"` // 页码 (从 1 开始)
PageSize int `form:"page_size,default=20" json:"page_size,omitempty"` // 每页条数
}

93
internal/model/task.go Normal file
View File

@@ -0,0 +1,93 @@
package model
import (
"encoding/json"
"time"
"gorm.io/gorm"
)
// Task status constants.
const (
TaskStatusSubmitted = "submitted"
TaskStatusPreparing = "preparing"
TaskStatusDownloading = "downloading"
TaskStatusReady = "ready"
TaskStatusQueued = "queued"
TaskStatusRunning = "running"
TaskStatusCompleted = "completed"
TaskStatusFailed = "failed"
)
// Task step constants for step-level retry tracking.
const (
TaskStepPreparing = "preparing"
TaskStepDownloading = "downloading"
TaskStepSubmitting = "submitting"
)
// Task represents an HPC task submitted through the application framework.
type Task struct {
ID int64 `gorm:"primaryKey;autoIncrement" json:"id"`
TaskName string `gorm:"size:255" json:"task_name"`
AppID int64 `json:"app_id"`
AppName string `gorm:"size:255" json:"app_name"`
Status string `json:"status"`
CurrentStep string `json:"current_step"`
RetryCount int `json:"retry_count"`
Values json.RawMessage `gorm:"type:text" json:"values,omitempty"`
InputFileIDs json.RawMessage `json:"input_file_ids" gorm:"column:input_file_ids;type:text"`
Script string `json:"script,omitempty"`
SlurmJobID *int32 `json:"slurm_job_id,omitempty"`
WorkDir string `json:"work_dir,omitempty"`
Partition string `json:"partition,omitempty"`
ErrorMessage string `json:"error_message,omitempty"`
UserID string `json:"user_id"`
SubmittedAt time.Time `json:"submitted_at"`
StartedAt *time.Time `json:"started_at,omitempty"`
FinishedAt *time.Time `json:"finished_at,omitempty"`
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
DeletedAt gorm.DeletedAt `gorm:"index" json:"deleted_at,omitempty"`
}
func (Task) TableName() string {
return "hpc_tasks"
}
// CreateTaskRequest is the DTO for creating a new task.
type CreateTaskRequest struct {
AppID int64 `json:"app_id" binding:"required"`
TaskName string `json:"task_name"`
Values map[string]string `json:"values"`
InputFileIDs []int64 `json:"file_ids"`
}
// TaskResponse is the DTO returned in API responses.
type TaskResponse struct {
ID int64 `json:"id"`
TaskName string `json:"task_name"`
AppID int64 `json:"app_id"`
AppName string `json:"app_name"`
Status string `json:"status"`
CurrentStep string `json:"current_step"`
RetryCount int `json:"retry_count"`
SlurmJobID *int32 `json:"slurm_job_id"`
WorkDir string `json:"work_dir"`
ErrorMessage string `json:"error_message"`
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
}
// TaskListResponse is the paginated response for listing tasks.
type TaskListResponse struct {
Items []TaskResponse `json:"items"`
Total int64 `json:"total"`
}
// TaskListQuery contains query parameters for listing tasks.
type TaskListQuery struct {
Page int `form:"page" json:"page,omitempty"`
PageSize int `form:"page_size" json:"page_size,omitempty"`
Status string `form:"status" json:"status,omitempty"`
}

104
internal/model/task_test.go Normal file
View File

@@ -0,0 +1,104 @@
package model
import (
"encoding/json"
"testing"
"time"
)
func TestTask_TableName(t *testing.T) {
task := Task{}
if got := task.TableName(); got != "hpc_tasks" {
t.Errorf("Task.TableName() = %q, want %q", got, "hpc_tasks")
}
}
func TestTask_JSONRoundTrip(t *testing.T) {
now := time.Now().UTC().Truncate(time.Second)
jobID := int32(42)
task := Task{
ID: 1,
TaskName: "test task",
AppID: 10,
AppName: "GROMACS",
Status: TaskStatusRunning,
CurrentStep: TaskStepSubmitting,
RetryCount: 1,
Values: json.RawMessage(`{"np":"4"}`),
InputFileIDs: json.RawMessage(`[1,2,3]`),
Script: "#!/bin/bash",
SlurmJobID: &jobID,
WorkDir: "/data/work",
Partition: "gpu",
ErrorMessage: "",
UserID: "user1",
SubmittedAt: now,
StartedAt: &now,
FinishedAt: nil,
CreatedAt: now,
UpdatedAt: now,
}
data, err := json.Marshal(task)
if err != nil {
t.Fatalf("marshal task: %v", err)
}
var got Task
if err := json.Unmarshal(data, &got); err != nil {
t.Fatalf("unmarshal task: %v", err)
}
if got.ID != task.ID {
t.Errorf("ID = %d, want %d", got.ID, task.ID)
}
if got.TaskName != task.TaskName {
t.Errorf("TaskName = %q, want %q", got.TaskName, task.TaskName)
}
if got.Status != task.Status {
t.Errorf("Status = %q, want %q", got.Status, task.Status)
}
if got.CurrentStep != task.CurrentStep {
t.Errorf("CurrentStep = %q, want %q", got.CurrentStep, task.CurrentStep)
}
if got.RetryCount != task.RetryCount {
t.Errorf("RetryCount = %d, want %d", got.RetryCount, task.RetryCount)
}
if got.SlurmJobID == nil || *got.SlurmJobID != jobID {
t.Errorf("SlurmJobID = %v, want %d", got.SlurmJobID, jobID)
}
if got.UserID != task.UserID {
t.Errorf("UserID = %q, want %q", got.UserID, task.UserID)
}
if string(got.Values) != string(task.Values) {
t.Errorf("Values = %s, want %s", got.Values, task.Values)
}
if string(got.InputFileIDs) != string(task.InputFileIDs) {
t.Errorf("InputFileIDs = %s, want %s", got.InputFileIDs, task.InputFileIDs)
}
if got.FinishedAt != nil {
t.Errorf("FinishedAt = %v, want nil", got.FinishedAt)
}
}
func TestCreateTaskRequest_JSONBinding(t *testing.T) {
payload := `{"app_id":5,"task_name":"my task","values":{"np":"8"},"file_ids":[10,20]}`
var req CreateTaskRequest
if err := json.Unmarshal([]byte(payload), &req); err != nil {
t.Fatalf("unmarshal CreateTaskRequest: %v", err)
}
if req.AppID != 5 {
t.Errorf("AppID = %d, want 5", req.AppID)
}
if req.TaskName != "my task" {
t.Errorf("TaskName = %q, want %q", req.TaskName, "my task")
}
if v, ok := req.Values["np"]; !ok || v != "8" {
t.Errorf("Values[\"np\"] = %q, want %q", v, "8")
}
if len(req.InputFileIDs) != 2 || req.InputFileIDs[0] != 10 || req.InputFileIDs[1] != 20 {
t.Errorf("InputFileIDs = %v, want [10 20]", req.InputFileIDs)
}
}

142
internal/server/response.go Normal file
View File

@@ -0,0 +1,142 @@
package server
import (
"fmt"
"io"
"net/http"
"strconv"
"strings"
"github.com/gin-gonic/gin"
)
// APIResponse represents the unified JSON structure for all API responses.
type APIResponse struct {
Success bool `json:"success"`
Data interface{} `json:"data,omitempty"`
Error string `json:"error,omitempty"`
}
// OK responds with 200 and success data.
func OK(c *gin.Context, data interface{}) {
c.JSON(http.StatusOK, APIResponse{Success: true, Data: data})
}
// Created responds with 201 and success data.
func Created(c *gin.Context, data interface{}) {
c.JSON(http.StatusCreated, APIResponse{Success: true, Data: data})
}
// BadRequest responds with 400 and an error message.
func BadRequest(c *gin.Context, msg string) {
c.JSON(http.StatusBadRequest, APIResponse{Success: false, Error: msg})
}
// NotFound responds with 404 and an error message.
func NotFound(c *gin.Context, msg string) {
c.JSON(http.StatusNotFound, APIResponse{Success: false, Error: msg})
}
// InternalError responds with 500 and an error message.
func InternalError(c *gin.Context, msg string) {
c.JSON(http.StatusInternalServerError, APIResponse{Success: false, Error: msg})
}
// ErrorWithStatus responds with a custom status code and an error message.
func ErrorWithStatus(c *gin.Context, code int, msg string) {
c.JSON(code, APIResponse{Success: false, Error: msg})
}
// ParseRange parses an HTTP Range header (RFC 7233).
// Only single-part ranges are supported: bytes=start-end, bytes=start-, bytes=-suffix.
// Multi-part ranges (bytes=0-100,200-300) return an error.
func ParseRange(rangeHeader string, fileSize int64) (start, end int64, err error) {
if rangeHeader == "" {
return 0, 0, fmt.Errorf("empty range header")
}
if !strings.HasPrefix(rangeHeader, "bytes=") {
return 0, 0, fmt.Errorf("invalid range unit: %s", rangeHeader)
}
rangeSpec := strings.TrimPrefix(rangeHeader, "bytes=")
if strings.Contains(rangeSpec, ",") {
return 0, 0, fmt.Errorf("multi-part ranges are not supported")
}
rangeSpec = strings.TrimSpace(rangeSpec)
parts := strings.Split(rangeSpec, "-")
if len(parts) != 2 {
return 0, 0, fmt.Errorf("invalid range format: %s", rangeSpec)
}
if parts[0] == "" {
suffix, parseErr := strconv.ParseInt(parts[1], 10, 64)
if parseErr != nil {
return 0, 0, fmt.Errorf("invalid suffix range: %s", parts[1])
}
if suffix <= 0 || suffix > fileSize {
return 0, 0, fmt.Errorf("suffix range %d exceeds file size %d", suffix, fileSize)
}
start = fileSize - suffix
end = fileSize - 1
} else if parts[1] == "" {
start, err = strconv.ParseInt(parts[0], 10, 64)
if err != nil {
return 0, 0, fmt.Errorf("invalid range start: %s", parts[0])
}
if start >= fileSize {
return 0, 0, fmt.Errorf("range start %d exceeds file size %d", start, fileSize)
}
end = fileSize - 1
} else {
start, err = strconv.ParseInt(parts[0], 10, 64)
if err != nil {
return 0, 0, fmt.Errorf("invalid range start: %s", parts[0])
}
end, err = strconv.ParseInt(parts[1], 10, 64)
if err != nil {
return 0, 0, fmt.Errorf("invalid range end: %s", parts[1])
}
if start > end {
return 0, 0, fmt.Errorf("range start %d > end %d", start, end)
}
if start >= fileSize {
return 0, 0, fmt.Errorf("range start %d exceeds file size %d", start, fileSize)
}
if end >= fileSize {
end = fileSize - 1
}
}
return start, end, nil
}
// StreamFile sends a full file as an HTTP response with proper headers.
func StreamFile(c *gin.Context, reader io.ReadCloser, filename string, fileSize int64, contentType string) {
defer reader.Close()
c.Header("Content-Disposition", fmt.Sprintf(`attachment; filename="%s"`, filename))
c.Header("Content-Type", contentType)
c.Header("Content-Length", strconv.FormatInt(fileSize, 10))
c.Header("Accept-Ranges", "bytes")
c.Status(http.StatusOK)
io.Copy(c.Writer, reader)
}
// StreamRange sends a partial content response (206) for a byte range.
func StreamRange(c *gin.Context, reader io.ReadCloser, start, end, totalSize int64, contentType string) {
defer reader.Close()
contentLength := end - start + 1
c.Header("Content-Range", fmt.Sprintf("bytes %d-%d/%d", start, end, totalSize))
c.Header("Content-Type", contentType)
c.Header("Content-Length", strconv.FormatInt(contentLength, 10))
c.Header("Accept-Ranges", "bytes")
c.Status(http.StatusPartialContent)
io.Copy(c.Writer, reader)
}

View File

@@ -0,0 +1,178 @@
package server
import (
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"github.com/gin-gonic/gin"
)
func setupTestContext() (*httptest.ResponseRecorder, *gin.Context) {
gin.SetMode(gin.TestMode)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
return w, c
}
func parseResponse(t *testing.T, w *httptest.ResponseRecorder) APIResponse {
t.Helper()
var resp APIResponse
if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
t.Fatalf("failed to parse response body: %v", err)
}
return resp
}
func TestOK(t *testing.T) {
w, c := setupTestContext()
OK(c, map[string]string{"msg": "hello"})
if w.Code != http.StatusOK {
t.Fatalf("expected 200, got %d", w.Code)
}
resp := parseResponse(t, w)
if !resp.Success {
t.Fatal("expected success=true")
}
}
func TestCreated(t *testing.T) {
w, c := setupTestContext()
Created(c, map[string]int{"id": 1})
if w.Code != http.StatusCreated {
t.Fatalf("expected 201, got %d", w.Code)
}
resp := parseResponse(t, w)
if !resp.Success {
t.Fatal("expected success=true")
}
}
func TestBadRequest(t *testing.T) {
w, c := setupTestContext()
BadRequest(c, "invalid input")
if w.Code != http.StatusBadRequest {
t.Fatalf("expected 400, got %d", w.Code)
}
resp := parseResponse(t, w)
if resp.Success {
t.Fatal("expected success=false")
}
if resp.Error != "invalid input" {
t.Fatalf("expected error 'invalid input', got '%s'", resp.Error)
}
}
func TestNotFound(t *testing.T) {
w, c := setupTestContext()
NotFound(c, "resource missing")
if w.Code != http.StatusNotFound {
t.Fatalf("expected 404, got %d", w.Code)
}
resp := parseResponse(t, w)
if resp.Success {
t.Fatal("expected success=false")
}
if resp.Error != "resource missing" {
t.Fatalf("expected error 'resource missing', got '%s'", resp.Error)
}
}
func TestInternalError(t *testing.T) {
w, c := setupTestContext()
InternalError(c, "something broke")
if w.Code != http.StatusInternalServerError {
t.Fatalf("expected 500, got %d", w.Code)
}
resp := parseResponse(t, w)
if resp.Success {
t.Fatal("expected success=false")
}
if resp.Error != "something broke" {
t.Fatalf("expected error 'something broke', got '%s'", resp.Error)
}
}
func TestErrorWithStatus(t *testing.T) {
w, c := setupTestContext()
ErrorWithStatus(c, http.StatusConflict, "already exists")
if w.Code != http.StatusConflict {
t.Fatalf("expected 409, got %d", w.Code)
}
resp := parseResponse(t, w)
if resp.Success {
t.Fatal("expected success=false")
}
if resp.Error != "already exists" {
t.Fatalf("expected error 'already exists', got '%s'", resp.Error)
}
}
func TestParseRangeStandard(t *testing.T) {
tests := []struct {
rangeHeader string
fileSize int64
wantStart int64
wantEnd int64
wantErr bool
}{
{"bytes=0-1023", 10000, 0, 1023, false},
{"bytes=1024-", 10000, 1024, 9999, false},
{"bytes=-1024", 10000, 8976, 9999, false},
{"bytes=0-0", 10000, 0, 0, false},
{"bytes=9999-", 10000, 9999, 9999, false},
}
for _, tt := range tests {
start, end, err := ParseRange(tt.rangeHeader, tt.fileSize)
if (err != nil) != tt.wantErr {
t.Errorf("ParseRange(%q, %d) error = %v, wantErr %v", tt.rangeHeader, tt.fileSize, err, tt.wantErr)
continue
}
if !tt.wantErr {
if start != tt.wantStart || end != tt.wantEnd {
t.Errorf("ParseRange(%q, %d) = (%d, %d), want (%d, %d)", tt.rangeHeader, tt.fileSize, start, end, tt.wantStart, tt.wantEnd)
}
}
}
}
func TestParseRangeInvalidAndMultiPart(t *testing.T) {
tests := []struct {
rangeHeader string
fileSize int64
}{
{"", 10000},
{"bytes=9999-0", 10000},
{"bytes=20000-", 10000},
{"bytes=0-100,200-300", 10000},
{"bytes=0-100, 400-500", 10000},
{"bytes=", 10000},
{"chars=0-100", 10000},
}
for _, tt := range tests {
_, _, err := ParseRange(tt.rangeHeader, tt.fileSize)
if err == nil {
t.Errorf("ParseRange(%q, %d) expected error, got nil", tt.rangeHeader, tt.fileSize)
}
}
}
func TestParseRangeEdgeCases(t *testing.T) {
start, end, err := ParseRange("bytes=0-99999", 10000)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if end != 9999 {
t.Errorf("end = %d, want 9999 (clamped to fileSize-1)", end)
}
if start != 0 {
t.Errorf("start = %d, want 0", start)
}
}

198
internal/server/server.go Normal file
View File

@@ -0,0 +1,198 @@
package server
import (
"net/http"
"gcy_hpc_server/internal/middleware"
"github.com/gin-gonic/gin"
"go.uber.org/zap"
)
type JobHandler interface {
SubmitJob(c *gin.Context)
GetJobs(c *gin.Context)
GetJobHistory(c *gin.Context)
GetJob(c *gin.Context)
CancelJob(c *gin.Context)
}
type ClusterHandler interface {
GetNodes(c *gin.Context)
GetNode(c *gin.Context)
GetPartitions(c *gin.Context)
GetPartition(c *gin.Context)
GetDiag(c *gin.Context)
}
type ApplicationHandler interface {
ListApplications(c *gin.Context)
CreateApplication(c *gin.Context)
GetApplication(c *gin.Context)
UpdateApplication(c *gin.Context)
DeleteApplication(c *gin.Context)
// SubmitApplication(c *gin.Context) // [已禁用] 已被 POST /tasks 取代
}
type UploadHandler interface {
InitUpload(c *gin.Context)
GetUploadStatus(c *gin.Context)
UploadChunk(c *gin.Context)
CompleteUpload(c *gin.Context)
CancelUpload(c *gin.Context)
}
type FileHandler interface {
ListFiles(c *gin.Context)
GetFile(c *gin.Context)
DownloadFile(c *gin.Context)
DeleteFile(c *gin.Context)
}
type FolderHandler interface {
CreateFolder(c *gin.Context)
GetFolder(c *gin.Context)
ListFolders(c *gin.Context)
DeleteFolder(c *gin.Context)
}
type TaskHandler interface {
CreateTask(c *gin.Context)
ListTasks(c *gin.Context)
}
// NewRouter creates a Gin engine with all API v1 routes registered with real handlers.
func NewRouter(jobH JobHandler, clusterH ClusterHandler, appH ApplicationHandler, uploadH UploadHandler, fileH FileHandler, folderH FolderHandler, taskH TaskHandler, logger *zap.Logger) *gin.Engine {
gin.SetMode(gin.ReleaseMode)
r := gin.New()
r.Use(gin.Recovery())
if logger != nil {
r.Use(middleware.RequestLogger(logger))
}
v1 := r.Group("/api/v1")
jobs := v1.Group("/jobs")
jobs.POST("/submit", jobH.SubmitJob)
jobs.GET("", jobH.GetJobs)
jobs.GET("/history", jobH.GetJobHistory)
jobs.GET("/:id", jobH.GetJob)
jobs.DELETE("/:id", jobH.CancelJob)
v1.GET("/nodes", clusterH.GetNodes)
v1.GET("/nodes/:name", clusterH.GetNode)
v1.GET("/partitions", clusterH.GetPartitions)
v1.GET("/partitions/:name", clusterH.GetPartition)
v1.GET("/diag", clusterH.GetDiag)
apps := v1.Group("/applications")
apps.GET("", appH.ListApplications)
apps.POST("", appH.CreateApplication)
apps.GET("/:id", appH.GetApplication)
apps.PUT("/:id", appH.UpdateApplication)
apps.DELETE("/:id", appH.DeleteApplication)
// apps.POST("/:id/submit", appH.SubmitApplication) // [已禁用] 已被 POST /tasks 取代
files := v1.Group("/files")
if uploadH != nil {
uploads := files.Group("/uploads")
uploads.POST("", uploadH.InitUpload)
uploads.GET("/:id", uploadH.GetUploadStatus)
uploads.PUT("/:id/chunks/:index", uploadH.UploadChunk)
uploads.POST("/:id/complete", uploadH.CompleteUpload)
uploads.DELETE("/:id", uploadH.CancelUpload)
}
if fileH != nil {
files.GET("", fileH.ListFiles)
files.GET("/:id", fileH.GetFile)
files.GET("/:id/download", fileH.DownloadFile)
files.DELETE("/:id", fileH.DeleteFile)
}
if folderH != nil {
folders := files.Group("/folders")
folders.POST("", folderH.CreateFolder)
folders.GET("", folderH.ListFolders)
folders.GET("/:id", folderH.GetFolder)
folders.DELETE("/:id", folderH.DeleteFolder)
}
if taskH != nil {
tasks := v1.Group("/tasks")
{
tasks.POST("", taskH.CreateTask)
tasks.GET("", taskH.ListTasks)
}
}
return r
}
// NewTestRouter creates a router for testing without real handlers.
func NewTestRouter() *gin.Engine {
gin.SetMode(gin.TestMode)
r := gin.New()
r.Use(gin.Recovery())
v1 := r.Group("/api/v1")
registerPlaceholderRoutes(v1)
return r
}
func registerPlaceholderRoutes(v1 *gin.RouterGroup) {
jobs := v1.Group("/jobs")
jobs.POST("/submit", notImplemented)
jobs.GET("", notImplemented)
jobs.GET("/history", notImplemented)
jobs.GET("/:id", notImplemented)
jobs.DELETE("/:id", notImplemented)
v1.GET("/nodes", notImplemented)
v1.GET("/nodes/:name", notImplemented)
v1.GET("/partitions", notImplemented)
v1.GET("/partitions/:name", notImplemented)
v1.GET("/diag", notImplemented)
apps := v1.Group("/applications")
apps.GET("", notImplemented)
apps.POST("", notImplemented)
apps.GET("/:id", notImplemented)
apps.PUT("/:id", notImplemented)
apps.DELETE("/:id", notImplemented)
// apps.POST("/:id/submit", notImplemented) // [已禁用] 已被 POST /tasks 取代
files := v1.Group("/files")
uploads := files.Group("/uploads")
uploads.POST("", notImplemented)
uploads.GET("/:id", notImplemented)
uploads.PUT("/:id/chunks/:index", notImplemented)
uploads.POST("/:id/complete", notImplemented)
uploads.DELETE("/:id", notImplemented)
files.GET("", notImplemented)
files.GET("/:id", notImplemented)
files.GET("/:id/download", notImplemented)
files.DELETE("/:id", notImplemented)
folders := files.Group("/folders")
folders.POST("", notImplemented)
folders.GET("", notImplemented)
folders.GET("/:id", notImplemented)
folders.DELETE("/:id", notImplemented)
v1.POST("/tasks", notImplemented)
v1.GET("/tasks", notImplemented)
}
func notImplemented(c *gin.Context) {
c.JSON(http.StatusNotImplemented, APIResponse{
Success: false,
Error: "not implemented",
})
}

View File

@@ -0,0 +1,109 @@
package server
import (
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"github.com/gin-gonic/gin"
)
func TestAllRoutesRegistered(t *testing.T) {
r := NewTestRouter()
routes := r.Routes()
expected := []struct {
method string
path string
}{
{"POST", "/api/v1/jobs/submit"},
{"GET", "/api/v1/jobs"},
{"GET", "/api/v1/jobs/history"},
{"GET", "/api/v1/jobs/:id"},
{"DELETE", "/api/v1/jobs/:id"},
{"GET", "/api/v1/nodes"},
{"GET", "/api/v1/nodes/:name"},
{"GET", "/api/v1/partitions"},
{"GET", "/api/v1/partitions/:name"},
{"GET", "/api/v1/diag"},
{"GET", "/api/v1/applications"},
{"POST", "/api/v1/applications"},
{"GET", "/api/v1/applications/:id"},
{"PUT", "/api/v1/applications/:id"},
{"DELETE", "/api/v1/applications/:id"},
// {"POST", "/api/v1/applications/:id/submit"}, // [已禁用] 已被 POST /tasks 取代
}
routeMap := map[string]bool{}
for _, route := range routes {
key := route.Method + " " + route.Path
routeMap[key] = true
}
for _, exp := range expected {
key := exp.method + " " + exp.path
if !routeMap[key] {
t.Errorf("missing route: %s", key)
}
}
if len(routes) < len(expected) {
t.Errorf("expected at least %d routes, got %d", len(expected), len(routes))
}
}
func TestUnregisteredPathReturns404(t *testing.T) {
r := NewTestRouter()
w := httptest.NewRecorder()
req, _ := http.NewRequest(http.MethodGet, "/api/v1/nonexistent", nil)
r.ServeHTTP(w, req)
if w.Code != http.StatusNotFound {
t.Fatalf("expected 404 for unregistered path, got %d", w.Code)
}
}
func TestRegisteredPathReturns501(t *testing.T) {
r := NewTestRouter()
endpoints := []struct {
method string
path string
}{
{"GET", "/api/v1/jobs"},
{"GET", "/api/v1/nodes"},
{"GET", "/api/v1/partitions"},
{"GET", "/api/v1/diag"},
{"GET", "/api/v1/applications"},
}
for _, ep := range endpoints {
w := httptest.NewRecorder()
req, _ := http.NewRequest(ep.method, ep.path, nil)
r.ServeHTTP(w, req)
if w.Code != http.StatusNotImplemented {
t.Fatalf("%s %s: expected 501, got %d", ep.method, ep.path, w.Code)
}
var resp APIResponse
if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
t.Fatalf("failed to parse response: %v", err)
}
if resp.Success {
t.Fatal("expected success=false")
}
if resp.Error != "not implemented" {
t.Fatalf("expected error 'not implemented', got '%s'", resp.Error)
}
}
}
func TestRouterUsesGinMode(t *testing.T) {
gin.SetMode(gin.TestMode)
r := NewTestRouter()
if r == nil {
t.Fatal("NewRouter returned nil")
}
}

View File

@@ -0,0 +1,111 @@
package service
import (
"context"
"gcy_hpc_server/internal/model"
"gcy_hpc_server/internal/store"
"go.uber.org/zap"
)
// ApplicationService handles parameter validation, script rendering, and job
// submission for parameterized HPC applications.
type ApplicationService struct {
store *store.ApplicationStore
jobSvc *JobService
workDirBase string
logger *zap.Logger
taskSvc *TaskService
}
func NewApplicationService(store *store.ApplicationStore, jobSvc *JobService, workDirBase string, logger *zap.Logger, taskSvc ...*TaskService) *ApplicationService {
var ts *TaskService
if len(taskSvc) > 0 {
ts = taskSvc[0]
}
return &ApplicationService{store: store, jobSvc: jobSvc, workDirBase: workDirBase, logger: logger, taskSvc: ts}
}
// ListApplications delegates to the store.
func (s *ApplicationService) ListApplications(ctx context.Context, page, pageSize int) ([]model.Application, int, error) {
return s.store.List(ctx, page, pageSize)
}
// CreateApplication delegates to the store.
func (s *ApplicationService) CreateApplication(ctx context.Context, req *model.CreateApplicationRequest) (int64, error) {
return s.store.Create(ctx, req)
}
// GetApplication delegates to the store.
func (s *ApplicationService) GetApplication(ctx context.Context, id int64) (*model.Application, error) {
return s.store.GetByID(ctx, id)
}
// UpdateApplication delegates to the store.
func (s *ApplicationService) UpdateApplication(ctx context.Context, id int64, req *model.UpdateApplicationRequest) error {
return s.store.Update(ctx, id, req)
}
// DeleteApplication delegates to the store.
func (s *ApplicationService) DeleteApplication(ctx context.Context, id int64) error {
return s.store.Delete(ctx, id)
}
// [已禁用] 前端已全部迁移到 POST /tasks 接口,此方法不再被调用。
/* // SubmitFromApplication orchestrates the full submission flow.
// When TaskService is available, it delegates to ProcessTaskSync which creates
// an hpc_tasks record and runs the full pipeline. Otherwise falls back to the
// original direct implementation.
func (s *ApplicationService) SubmitFromApplication(ctx context.Context, applicationID int64, values map[string]string) (*model.JobResponse, error) {
// [已禁用] 旧的直接提交路径,已被 TaskService 管道取代。生产环境中 taskSvc 始终非 nil此分支不会执行。
// if s.taskSvc != nil {
req := &model.CreateTaskRequest{
AppID: applicationID,
Values: values,
InputFileIDs: nil, // old API has no file_ids concept
TaskName: "",
}
return s.taskSvc.ProcessTaskSync(ctx, req)
// }
// // Fallback: original direct logic when TaskService not available
// app, err := s.store.GetByID(ctx, applicationID)
// if err != nil {
// return nil, fmt.Errorf("get application: %w", err)
// }
// if app == nil {
// return nil, fmt.Errorf("application %d not found", applicationID)
// }
//
// var params []model.ParameterSchema
// if len(app.Parameters) > 0 {
// if err := json.Unmarshal(app.Parameters, &params); err != nil {
// return nil, fmt.Errorf("parse parameters: %w", err)
// }
// }
//
// if err := ValidateParams(params, values); err != nil {
// return nil, err
// }
//
// rendered := RenderScript(app.ScriptTemplate, params, values)
//
// workDir := ""
// if s.workDirBase != "" {
// safeName := SanitizeDirName(app.Name)
// subDir := time.Now().Format("20060102_150405") + "_" + RandomSuffix(4)
// workDir = filepath.Join(s.workDirBase, safeName, subDir)
// if err := os.MkdirAll(workDir, 0777); err != nil {
// return nil, fmt.Errorf("create work directory %s: %w", workDir, err)
// }
// // 绕过 umask确保整条路径都有写权限
// for dir := workDir; dir != s.workDirBase; dir = filepath.Dir(dir) {
// os.Chmod(dir, 0777)
// }
// os.Chmod(s.workDirBase, 0777)
// }
//
// req := &model.SubmitJobRequest{Script: rendered, WorkDir: workDir}
// return s.jobSvc.SubmitJob(ctx, req)
} */

View File

@@ -0,0 +1,367 @@
package service
import (
"net/http"
"net/http/httptest"
"strings"
"testing"
"gcy_hpc_server/internal/model"
"gcy_hpc_server/internal/slurm"
"gcy_hpc_server/internal/store"
"go.uber.org/zap"
gormlogger "gorm.io/gorm/logger"
"gorm.io/driver/sqlite"
"gorm.io/gorm"
)
func setupApplicationService(t *testing.T, slurmHandler http.HandlerFunc) (*ApplicationService, func()) {
t.Helper()
srv := httptest.NewServer(slurmHandler)
client, _ := slurm.NewClient(srv.URL, srv.Client())
db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{
Logger: gormlogger.Default.LogMode(gormlogger.Silent),
})
if err != nil {
t.Fatalf("open sqlite: %v", err)
}
if err := db.AutoMigrate(&model.Application{}); err != nil {
t.Fatalf("auto migrate: %v", err)
}
jobSvc := NewJobService(client, zap.NewNop())
appStore := store.NewApplicationStore(db)
appSvc := NewApplicationService(appStore, jobSvc, "", zap.NewNop())
return appSvc, srv.Close
}
func TestValidateParams_AllRequired(t *testing.T) {
params := []model.ParameterSchema{
{Name: "NAME", Type: model.ParamTypeString, Required: true},
{Name: "COUNT", Type: model.ParamTypeInteger, Required: true},
}
values := map[string]string{"NAME": "hello", "COUNT": "5"}
if err := ValidateParams(params, values); err != nil {
t.Errorf("expected no error, got %v", err)
}
}
func TestValidateParams_MissingRequired(t *testing.T) {
params := []model.ParameterSchema{
{Name: "NAME", Type: model.ParamTypeString, Required: true},
}
values := map[string]string{}
err := ValidateParams(params, values)
if err == nil {
t.Fatal("expected error for missing required param")
}
if !strings.Contains(err.Error(), "NAME") {
t.Errorf("error should mention param name, got: %v", err)
}
}
func TestValidateParams_InvalidInteger(t *testing.T) {
params := []model.ParameterSchema{
{Name: "COUNT", Type: model.ParamTypeInteger, Required: true},
}
values := map[string]string{"COUNT": "abc"}
err := ValidateParams(params, values)
if err == nil {
t.Fatal("expected error for invalid integer")
}
if !strings.Contains(err.Error(), "integer") {
t.Errorf("error should mention integer, got: %v", err)
}
}
func TestValidateParams_InvalidEnum(t *testing.T) {
params := []model.ParameterSchema{
{Name: "MODE", Type: model.ParamTypeEnum, Required: true, Options: []string{"fast", "slow"}},
}
values := map[string]string{"MODE": "medium"}
err := ValidateParams(params, values)
if err == nil {
t.Fatal("expected error for invalid enum value")
}
if !strings.Contains(err.Error(), "MODE") {
t.Errorf("error should mention param name, got: %v", err)
}
}
func TestValidateParams_BooleanValues(t *testing.T) {
params := []model.ParameterSchema{
{Name: "FLAG", Type: model.ParamTypeBoolean, Required: true},
}
for _, val := range []string{"true", "false", "1", "0"} {
err := ValidateParams(params, map[string]string{"FLAG": val})
if err != nil {
t.Errorf("boolean value %q should be valid, got error: %v", val, err)
}
}
}
func TestRenderScript_SimpleReplacement(t *testing.T) {
params := []model.ParameterSchema{{Name: "INPUT", Type: model.ParamTypeString}}
values := map[string]string{"INPUT": "data.txt"}
result := RenderScript("echo $INPUT", params, values)
expected := "echo 'data.txt'"
if result != expected {
t.Errorf("got %q, want %q", result, expected)
}
}
func TestRenderScript_DefaultValues(t *testing.T) {
params := []model.ParameterSchema{{Name: "OUTPUT", Type: model.ParamTypeString, Default: "out.log"}}
values := map[string]string{}
result := RenderScript("cat $OUTPUT", params, values)
expected := "cat 'out.log'"
if result != expected {
t.Errorf("got %q, want %q", result, expected)
}
}
func TestRenderScript_PreservesUnknownVars(t *testing.T) {
params := []model.ParameterSchema{{Name: "INPUT", Type: model.ParamTypeString}}
values := map[string]string{"INPUT": "data.txt"}
result := RenderScript("export HOME=$HOME\necho $INPUT\necho $PATH", params, values)
if !strings.Contains(result, "$HOME") {
t.Error("$HOME should be preserved")
}
if !strings.Contains(result, "$PATH") {
t.Error("$PATH should be preserved")
}
if !strings.Contains(result, "'data.txt'") {
t.Error("$INPUT should be replaced")
}
}
func TestRenderScript_ShellEscaping(t *testing.T) {
params := []model.ParameterSchema{{Name: "INPUT", Type: model.ParamTypeString}}
tests := []struct {
name string
value string
expected string
}{
{"semicolon injection", "; rm -rf /", "'; rm -rf /'"},
{"command substitution", "$(cat /etc/passwd)", "'$(cat /etc/passwd)'"},
{"single quote", "hello'world", "'hello'\\''world'"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := RenderScript("$INPUT", params, map[string]string{"INPUT": tt.value})
if result != tt.expected {
t.Errorf("got %q, want %q", result, tt.expected)
}
})
}
}
func TestRenderScript_OverlappingParams(t *testing.T) {
template := "$JOB_NAME and $JOB"
params := []model.ParameterSchema{
{Name: "JOB", Type: model.ParamTypeString},
{Name: "JOB_NAME", Type: model.ParamTypeString},
}
values := map[string]string{"JOB": "myjob", "JOB_NAME": "my-test-job"}
result := RenderScript(template, params, values)
if strings.Contains(result, "$JOB_NAME") {
t.Error("$JOB_NAME was not replaced")
}
if strings.Contains(result, "$JOB") {
t.Error("$JOB was not replaced")
}
if !strings.Contains(result, "'my-test-job'") {
t.Errorf("expected 'my-test-job' in result, got: %s", result)
}
if !strings.Contains(result, "'myjob'") {
t.Errorf("expected 'myjob' in result, got: %s", result)
}
}
func TestRenderScript_NewlineInValue(t *testing.T) {
params := []model.ParameterSchema{{Name: "CMD", Type: model.ParamTypeString}}
values := map[string]string{"CMD": "line1\nline2"}
result := RenderScript("echo $CMD", params, values)
expected := "echo 'line1\nline2'"
if result != expected {
t.Errorf("got %q, want %q", result, expected)
}
}
// [已禁用] 测试的是旧的直接提交路径,该路径已被注释掉
/*
func TestSubmitFromApplication_Success(t *testing.T) {
jobID := int32(42)
appSvc, cleanup := setupApplicationService(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
json.NewEncoder(w).Encode(slurm.OpenapiJobSubmitResponse{
Result: &slurm.JobSubmitResponseMsg{JobID: &jobID},
})
}))
defer cleanup()
id, err := appSvc.store.Create(context.Background(), &model.CreateApplicationRequest{
Name: "test-app",
ScriptTemplate: "#!/bin/bash\n#SBATCH --job-name=$JOB_NAME\necho $INPUT",
Parameters: json.RawMessage(`[{"name":"JOB_NAME","type":"string","required":true},{"name":"INPUT","type":"string","required":true}]`),
})
if err != nil {
t.Fatalf("create app: %v", err)
}
resp, err := appSvc.SubmitFromApplication(context.Background(), id, map[string]string{
"JOB_NAME": "my-job",
"INPUT": "hello",
})
if err != nil {
t.Fatalf("SubmitFromApplication() error = %v", err)
}
if resp.JobID != 42 {
t.Errorf("JobID = %d, want 42", resp.JobID)
}
}
*/
// [已禁用] 测试的是旧的直接提交路径,该路径已被注释掉
/*
func TestSubmitFromApplication_AppNotFound(t *testing.T) {
appSvc, cleanup := setupApplicationService(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}))
defer cleanup()
_, err := appSvc.SubmitFromApplication(context.Background(), 99999, map[string]string{})
if err == nil {
t.Fatal("expected error for non-existent app")
}
if !strings.Contains(err.Error(), "not found") {
t.Errorf("error should mention 'not found', got: %v", err)
}
}
*/
// [已禁用] 测试的是旧的直接提交路径,该路径已被注释掉
/*
func TestSubmitFromApplication_ValidationFail(t *testing.T) {
appSvc, cleanup := setupApplicationService(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}))
defer cleanup()
_, err := appSvc.store.Create(context.Background(), &model.CreateApplicationRequest{
Name: "valid-app",
ScriptTemplate: "#!/bin/bash\necho $INPUT",
Parameters: json.RawMessage(`[{"name":"INPUT","type":"string","required":true}]`),
})
_, err = appSvc.SubmitFromApplication(context.Background(), 1, map[string]string{})
if err == nil {
t.Fatal("expected validation error for missing required param")
}
if !strings.Contains(err.Error(), "missing") {
t.Errorf("error should mention 'missing', got: %v", err)
}
}
*/
// [已禁用] 测试的是旧的直接提交路径,该路径已被注释掉
/*
func TestSubmitFromApplication_NoParameters(t *testing.T) {
jobID := int32(99)
appSvc, cleanup := setupApplicationService(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
json.NewEncoder(w).Encode(slurm.OpenapiJobSubmitResponse{
Result: &slurm.JobSubmitResponseMsg{JobID: &jobID},
})
}))
defer cleanup()
id, err := appSvc.store.Create(context.Background(), &model.CreateApplicationRequest{
Name: "simple-app",
ScriptTemplate: "#!/bin/bash\necho hello",
})
if err != nil {
t.Fatalf("create app: %v", err)
}
resp, err := appSvc.SubmitFromApplication(context.Background(), id, map[string]string{})
if err != nil {
t.Fatalf("SubmitFromApplication() error = %v", err)
}
if resp.JobID != 99 {
t.Errorf("JobID = %d, want 99", resp.JobID)
}
}
*/
// [已禁用] 前端已全部迁移到 POST /tasks 接口,此测试不再需要。
/*
func TestSubmitFromApplication_DelegatesToTaskService(t *testing.T) {
jobID := int32(77)
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
json.NewEncoder(w).Encode(slurm.OpenapiJobSubmitResponse{
Result: &slurm.JobSubmitResponseMsg{JobID: &jobID},
})
}))
defer srv.Close()
client, _ := slurm.NewClient(srv.URL, srv.Client())
jobSvc := NewJobService(client, zap.NewNop())
db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{
Logger: gormlogger.Default.LogMode(gormlogger.Silent),
})
if err != nil {
t.Fatalf("open sqlite: %v", err)
}
if err := db.AutoMigrate(&model.Application{}, &model.Task{}, &model.File{}, &model.FileBlob{}); err != nil {
t.Fatalf("auto migrate: %v", err)
}
appStore := store.NewApplicationStore(db)
taskStore := store.NewTaskStore(db)
fileStore := store.NewFileStore(db)
blobStore := store.NewBlobStore(db)
workDirBase := filepath.Join(t.TempDir(), "workdir")
os.MkdirAll(workDirBase, 0777)
taskSvc := NewTaskService(taskStore, appStore, fileStore, blobStore, nil, jobSvc, workDirBase, zap.NewNop())
appSvc := NewApplicationService(appStore, jobSvc, workDirBase, zap.NewNop(), taskSvc)
id, err := appStore.Create(context.Background(), &model.CreateApplicationRequest{
Name: "delegated-app",
ScriptTemplate: "#!/bin/bash\n#SBATCH --job-name=$JOB_NAME\necho $INPUT",
Parameters: json.RawMessage(`[{"name":"JOB_NAME","type":"string","required":true},{"name":"INPUT","type":"string","required":true}]`),
})
if err != nil {
t.Fatalf("create app: %v", err)
}
resp, err := appSvc.SubmitFromApplication(context.Background(), id, map[string]string{
"JOB_NAME": "delegated-job",
"INPUT": "test-data",
})
if err != nil {
t.Fatalf("SubmitFromApplication() error = %v", err)
}
if resp.JobID != 77 {
t.Errorf("JobID = %d, want 77", resp.JobID)
}
var task model.Task
if err := db.Where("app_id = ?", id).First(&task).Error; err != nil {
t.Fatalf("no hpc_tasks record found for app_id %d: %v", id, err)
}
if task.SlurmJobID == nil || *task.SlurmJobID != 77 {
t.Errorf("task SlurmJobID = %v, want 77", task.SlurmJobID)
}
if task.Status != model.TaskStatusQueued {
t.Errorf("task Status = %q, want %q", task.Status, model.TaskStatusQueued)
}
}
*/

View File

@@ -0,0 +1,356 @@
package service
import (
"context"
"fmt"
"strconv"
"time"
"gcy_hpc_server/internal/model"
"gcy_hpc_server/internal/slurm"
"go.uber.org/zap"
)
func derefStr(s *string) string {
if s == nil {
return ""
}
return *s
}
func derefInt32(i *int32) int32 {
if i == nil {
return 0
}
return *i
}
func derefInt64(i *int64) int64 {
if i == nil {
return 0
}
return *i
}
func uint32NoValString(v *slurm.Uint32NoVal) string {
if v == nil {
return ""
}
if v.Infinite != nil && *v.Infinite {
return "UNLIMITED"
}
if v.Number != nil {
return strconv.FormatInt(*v.Number, 10)
}
return ""
}
func derefUint64NoValInt64(v *slurm.Uint64NoVal) *int64 {
if v != nil && v.Number != nil {
return v.Number
}
return nil
}
func derefCSVString(cs *slurm.CSVString) string {
if cs == nil || len(*cs) == 0 {
return ""
}
result := ""
for i, s := range *cs {
if i > 0 {
result += ","
}
result += s
}
return result
}
type ClusterService struct {
client *slurm.Client
logger *zap.Logger
}
func NewClusterService(client *slurm.Client, logger *zap.Logger) *ClusterService {
return &ClusterService{client: client, logger: logger}
}
func (s *ClusterService) GetNodes(ctx context.Context) ([]model.NodeResponse, error) {
s.logger.Debug("slurm API request",
zap.String("operation", "GetNodes"),
)
start := time.Now()
resp, _, err := s.client.Nodes.GetNodes(ctx, nil)
took := time.Since(start)
if err != nil {
s.logger.Debug("slurm API error response",
zap.String("operation", "GetNodes"),
zap.Duration("took", took),
zap.Error(err),
)
s.logger.Error("failed to get nodes", zap.Error(err))
return nil, fmt.Errorf("get nodes: %w", err)
}
s.logger.Debug("slurm API response",
zap.String("operation", "GetNodes"),
zap.Duration("took", took),
zap.Any("body", resp),
)
if resp.Nodes == nil {
return nil, nil
}
result := make([]model.NodeResponse, 0, len(*resp.Nodes))
for _, n := range *resp.Nodes {
result = append(result, mapNode(n))
}
return result, nil
}
func (s *ClusterService) GetNode(ctx context.Context, name string) (*model.NodeResponse, error) {
s.logger.Debug("slurm API request",
zap.String("operation", "GetNode"),
zap.String("node_name", name),
)
start := time.Now()
resp, _, err := s.client.Nodes.GetNode(ctx, name, nil)
took := time.Since(start)
if err != nil {
s.logger.Debug("slurm API error response",
zap.String("operation", "GetNode"),
zap.String("node_name", name),
zap.Duration("took", took),
zap.Error(err),
)
s.logger.Error("failed to get node", zap.String("name", name), zap.Error(err))
return nil, fmt.Errorf("get node %s: %w", name, err)
}
s.logger.Debug("slurm API response",
zap.String("operation", "GetNode"),
zap.String("node_name", name),
zap.Duration("took", took),
zap.Any("body", resp),
)
if resp.Nodes == nil || len(*resp.Nodes) == 0 {
return nil, nil
}
n := (*resp.Nodes)[0]
mapped := mapNode(n)
return &mapped, nil
}
func (s *ClusterService) GetPartitions(ctx context.Context) ([]model.PartitionResponse, error) {
s.logger.Debug("slurm API request",
zap.String("operation", "GetPartitions"),
)
start := time.Now()
resp, _, err := s.client.Partitions.GetPartitions(ctx, nil)
took := time.Since(start)
if err != nil {
s.logger.Debug("slurm API error response",
zap.String("operation", "GetPartitions"),
zap.Duration("took", took),
zap.Error(err),
)
s.logger.Error("failed to get partitions", zap.Error(err))
return nil, fmt.Errorf("get partitions: %w", err)
}
s.logger.Debug("slurm API response",
zap.String("operation", "GetPartitions"),
zap.Duration("took", took),
zap.Any("body", resp),
)
if resp.Partitions == nil {
return nil, nil
}
result := make([]model.PartitionResponse, 0, len(*resp.Partitions))
for _, pi := range *resp.Partitions {
result = append(result, mapPartition(pi))
}
return result, nil
}
func (s *ClusterService) GetPartition(ctx context.Context, name string) (*model.PartitionResponse, error) {
s.logger.Debug("slurm API request",
zap.String("operation", "GetPartition"),
zap.String("partition_name", name),
)
start := time.Now()
resp, _, err := s.client.Partitions.GetPartition(ctx, name, nil)
took := time.Since(start)
if err != nil {
s.logger.Debug("slurm API error response",
zap.String("operation", "GetPartition"),
zap.String("partition_name", name),
zap.Duration("took", took),
zap.Error(err),
)
s.logger.Error("failed to get partition", zap.String("name", name), zap.Error(err))
return nil, fmt.Errorf("get partition %s: %w", name, err)
}
s.logger.Debug("slurm API response",
zap.String("operation", "GetPartition"),
zap.String("partition_name", name),
zap.Duration("took", took),
zap.Any("body", resp),
)
if resp.Partitions == nil || len(*resp.Partitions) == 0 {
return nil, nil
}
p := (*resp.Partitions)[0]
mapped := mapPartition(p)
return &mapped, nil
}
func (s *ClusterService) GetDiag(ctx context.Context) (*slurm.OpenapiDiagResp, error) {
s.logger.Debug("slurm API request",
zap.String("operation", "GetDiag"),
)
start := time.Now()
resp, _, err := s.client.Diag.GetDiag(ctx)
took := time.Since(start)
if err != nil {
s.logger.Debug("slurm API error response",
zap.String("operation", "GetDiag"),
zap.Duration("took", took),
zap.Error(err),
)
s.logger.Error("failed to get diag", zap.Error(err))
return nil, fmt.Errorf("get diag: %w", err)
}
s.logger.Debug("slurm API response",
zap.String("operation", "GetDiag"),
zap.Duration("took", took),
zap.Any("body", resp),
)
return resp, nil
}
func mapNode(n slurm.Node) model.NodeResponse {
return model.NodeResponse{
Name: derefStr(n.Name),
State: n.State,
CPUs: derefInt32(n.Cpus),
AllocCpus: n.AllocCpus,
Cores: n.Cores,
Sockets: n.Sockets,
Threads: n.Threads,
RealMemory: derefInt64(n.RealMemory),
AllocMemory: derefInt64(n.AllocMemory),
FreeMem: derefUint64NoValInt64(n.FreeMem),
CpuLoad: n.CpuLoad,
Arch: derefStr(n.Architecture),
OS: derefStr(n.OperatingSystem),
Gres: derefStr(n.Gres),
GresUsed: derefStr(n.GresUsed),
Reason: derefStr(n.Reason),
ReasonSetByUser: derefStr(n.ReasonSetByUser),
Address: derefStr(n.Address),
Hostname: derefStr(n.Hostname),
Weight: n.Weight,
Features: derefCSVString(n.Features),
ActiveFeatures: derefCSVString(n.ActiveFeatures),
}
}
func mapPartition(pi slurm.PartitionInfo) model.PartitionResponse {
var state []string
var isDefault bool
if pi.Partition != nil {
state = pi.Partition.State
for _, s := range state {
if s == "DEFAULT" {
isDefault = true
break
}
}
}
var nodes string
if pi.Nodes != nil {
nodes = derefStr(pi.Nodes.Configured)
}
var totalCPUs int32
if pi.CPUs != nil {
totalCPUs = derefInt32(pi.CPUs.Total)
}
var totalNodes int32
if pi.Nodes != nil {
totalNodes = derefInt32(pi.Nodes.Total)
}
var maxTime string
if pi.Maximums != nil {
maxTime = uint32NoValString(pi.Maximums.Time)
}
var maxNodes *int32
if pi.Maximums != nil {
maxNodes = mapUint32NoValToInt32(pi.Maximums.Nodes)
}
var maxCPUsPerNode *int32
if pi.Maximums != nil {
maxCPUsPerNode = mapUint32NoValToInt32(pi.Maximums.CpusPerNode)
}
var minNodes *int32
if pi.Minimums != nil {
minNodes = pi.Minimums.Nodes
}
var defaultTime string
if pi.Defaults != nil {
defaultTime = uint32NoValString(pi.Defaults.Time)
}
var graceTime *int32 = pi.GraceTime
var priority *int32
if pi.Priority != nil {
priority = pi.Priority.JobFactor
}
var qosAllowed, qosDeny, qosAssigned string
if pi.QOS != nil {
qosAllowed = derefStr(pi.QOS.Allowed)
qosDeny = derefStr(pi.QOS.Deny)
qosAssigned = derefStr(pi.QOS.Assigned)
}
var accountsAllowed, accountsDeny string
if pi.Accounts != nil {
accountsAllowed = derefStr(pi.Accounts.Allowed)
accountsDeny = derefStr(pi.Accounts.Deny)
}
return model.PartitionResponse{
Name: derefStr(pi.Name),
State: state,
Default: isDefault,
Nodes: nodes,
TotalNodes: totalNodes,
TotalCPUs: totalCPUs,
MaxTime: maxTime,
MaxNodes: maxNodes,
MaxCPUsPerNode: maxCPUsPerNode,
MinNodes: minNodes,
DefaultTime: defaultTime,
GraceTime: graceTime,
Priority: priority,
QOSAllowed: qosAllowed,
QOSDeny: qosDeny,
QOSAssigned: qosAssigned,
AccountsAllowed: accountsAllowed,
AccountsDeny: accountsDeny,
}
}

View File

@@ -0,0 +1,467 @@
package service
import (
"context"
"encoding/json"
"net/http"
"net/http/httptest"
"os"
"path/filepath"
"testing"
"gcy_hpc_server/internal/slurm"
"go.uber.org/zap"
"go.uber.org/zap/zapcore"
"go.uber.org/zap/zaptest/observer"
)
func mockServer(handler http.HandlerFunc) (*slurm.Client, func()) {
srv := httptest.NewServer(handler)
client, _ := slurm.NewClient(srv.URL, srv.Client())
return client, srv.Close
}
func TestGetNodes(t *testing.T) {
client, cleanup := mockServer(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/slurm/v0.0.40/nodes" {
t.Errorf("unexpected path: %s", r.URL.Path)
}
resp := map[string]interface{}{
"nodes": []map[string]interface{}{
{
"name": "node1",
"state": []string{"IDLE"},
"cpus": 64,
"real_memory": 256000,
"alloc_memory": 0,
"architecture": "x86_64",
"operating_system": "Linux 5.15",
},
},
}
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(resp)
})
defer cleanup()
svc := NewClusterService(client, zap.NewNop())
nodes, err := svc.GetNodes(context.Background())
if err != nil {
t.Fatalf("GetNodes returned error: %v", err)
}
if len(nodes) != 1 {
t.Fatalf("expected 1 node, got %d", len(nodes))
}
n := nodes[0]
if n.Name != "node1" {
t.Errorf("expected name node1, got %s", n.Name)
}
if len(n.State) != 1 || n.State[0] != "IDLE" {
t.Errorf("expected state [IDLE], got %v", n.State)
}
if n.CPUs != 64 {
t.Errorf("expected 64 CPUs, got %d", n.CPUs)
}
if n.RealMemory != 256000 {
t.Errorf("expected real_memory 256000, got %d", n.RealMemory)
}
if n.Arch != "x86_64" {
t.Errorf("expected arch x86_64, got %s", n.Arch)
}
if n.OS != "Linux 5.15" {
t.Errorf("expected OS 'Linux 5.15', got %s", n.OS)
}
}
func TestGetNodes_Empty(t *testing.T) {
client, cleanup := mockServer(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(map[string]interface{}{})
})
defer cleanup()
svc := NewClusterService(client, zap.NewNop())
nodes, err := svc.GetNodes(context.Background())
if err != nil {
t.Fatalf("GetNodes returned error: %v", err)
}
if nodes != nil {
t.Errorf("expected nil for empty response, got %v", nodes)
}
}
func TestGetNode(t *testing.T) {
client, cleanup := mockServer(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/slurm/v0.0.40/node/node1" {
t.Errorf("unexpected path: %s", r.URL.Path)
}
resp := map[string]interface{}{
"nodes": []map[string]interface{}{
{"name": "node1", "state": []string{"ALLOCATED"}, "cpus": 32},
},
}
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(resp)
})
defer cleanup()
svc := NewClusterService(client, zap.NewNop())
node, err := svc.GetNode(context.Background(), "node1")
if err != nil {
t.Fatalf("GetNode returned error: %v", err)
}
if node == nil {
t.Fatal("expected node, got nil")
}
if node.Name != "node1" {
t.Errorf("expected name node1, got %s", node.Name)
}
}
func TestGetNode_NotFound(t *testing.T) {
client, cleanup := mockServer(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(map[string]interface{}{})
})
defer cleanup()
svc := NewClusterService(client, zap.NewNop())
node, err := svc.GetNode(context.Background(), "missing")
if err != nil {
t.Fatalf("GetNode returned error: %v", err)
}
if node != nil {
t.Errorf("expected nil for missing node, got %+v", node)
}
}
func TestGetPartitions(t *testing.T) {
client, cleanup := mockServer(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/slurm/v0.0.40/partitions" {
t.Errorf("unexpected path: %s", r.URL.Path)
}
resp := map[string]interface{}{
"partitions": []map[string]interface{}{
{
"name": "normal",
"partition": map[string]interface{}{
"state": []string{"UP"},
},
"nodes": map[string]interface{}{
"configured": "node[1-10]",
"total": 10,
},
"cpus": map[string]interface{}{
"total": 640,
},
"maximums": map[string]interface{}{
"time": map[string]interface{}{
"set": true,
"infinite": false,
"number": 86400,
},
},
},
},
}
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(resp)
})
defer cleanup()
svc := NewClusterService(client, zap.NewNop())
partitions, err := svc.GetPartitions(context.Background())
if err != nil {
t.Fatalf("GetPartitions returned error: %v", err)
}
if len(partitions) != 1 {
t.Fatalf("expected 1 partition, got %d", len(partitions))
}
p := partitions[0]
if p.Name != "normal" {
t.Errorf("expected name normal, got %s", p.Name)
}
if len(p.State) != 1 || p.State[0] != "UP" {
t.Errorf("expected state [UP], got %v", p.State)
}
if p.Nodes != "node[1-10]" {
t.Errorf("expected nodes 'node[1-10]', got %s", p.Nodes)
}
if p.TotalCPUs != 640 {
t.Errorf("expected 640 total CPUs, got %d", p.TotalCPUs)
}
if p.TotalNodes != 10 {
t.Errorf("expected 10 total nodes, got %d", p.TotalNodes)
}
if p.MaxTime != "86400" {
t.Errorf("expected max_time '86400', got %s", p.MaxTime)
}
}
func TestGetPartitions_Empty(t *testing.T) {
client, cleanup := mockServer(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(map[string]interface{}{})
})
defer cleanup()
svc := NewClusterService(client, zap.NewNop())
partitions, err := svc.GetPartitions(context.Background())
if err != nil {
t.Fatalf("GetPartitions returned error: %v", err)
}
if partitions != nil {
t.Errorf("expected nil for empty response, got %v", partitions)
}
}
func TestGetPartition(t *testing.T) {
client, cleanup := mockServer(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/slurm/v0.0.40/partition/gpu" {
t.Errorf("unexpected path: %s", r.URL.Path)
}
resp := map[string]interface{}{
"partitions": []map[string]interface{}{
{
"name": "gpu",
"partition": map[string]interface{}{
"state": []string{"UP"},
},
"nodes": map[string]interface{}{
"configured": "gpu[1-4]",
"total": 4,
},
"maximums": map[string]interface{}{
"time": map[string]interface{}{
"set": true,
"infinite": true,
},
},
},
},
}
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(resp)
})
defer cleanup()
svc := NewClusterService(client, zap.NewNop())
part, err := svc.GetPartition(context.Background(), "gpu")
if err != nil {
t.Fatalf("GetPartition returned error: %v", err)
}
if part == nil {
t.Fatal("expected partition, got nil")
}
if part.Name != "gpu" {
t.Errorf("expected name gpu, got %s", part.Name)
}
if part.MaxTime != "UNLIMITED" {
t.Errorf("expected max_time UNLIMITED, got %s", part.MaxTime)
}
}
func TestGetPartition_NotFound(t *testing.T) {
client, cleanup := mockServer(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(map[string]interface{}{})
})
defer cleanup()
svc := NewClusterService(client, zap.NewNop())
part, err := svc.GetPartition(context.Background(), "missing")
if err != nil {
t.Fatalf("GetPartition returned error: %v", err)
}
if part != nil {
t.Errorf("expected nil for missing partition, got %+v", part)
}
}
func TestGetDiag(t *testing.T) {
client, cleanup := mockServer(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/slurm/v0.0.40/diag" {
t.Errorf("unexpected path: %s", r.URL.Path)
}
resp := map[string]interface{}{
"statistics": map[string]interface{}{
"server_thread_count": 10,
"agent_queue_size": 5,
"jobs_submitted": 100,
"jobs_running": 20,
"schedule_queue_length": 3,
},
}
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(resp)
})
defer cleanup()
svc := NewClusterService(client, zap.NewNop())
diag, err := svc.GetDiag(context.Background())
if err != nil {
t.Fatalf("GetDiag returned error: %v", err)
}
if diag == nil {
t.Fatal("expected diag response, got nil")
}
if diag.Statistics == nil {
t.Fatal("expected statistics, got nil")
}
if diag.Statistics.ServerThreadCount == nil || *diag.Statistics.ServerThreadCount != 10 {
t.Errorf("expected server_thread_count 10, got %v", diag.Statistics.ServerThreadCount)
}
}
func TestNewSlurmClient(t *testing.T) {
dir := t.TempDir()
keyPath := filepath.Join(dir, "jwt.key")
os.WriteFile(keyPath, make([]byte, 32), 0644)
client, err := NewSlurmClient("http://localhost:6820", "root", keyPath)
if err != nil {
t.Fatalf("NewSlurmClient returned error: %v", err)
}
if client == nil {
t.Fatal("expected client, got nil")
}
}
func newClusterServiceWithObserver(srv *httptest.Server) (*ClusterService, *observer.ObservedLogs) {
core, recorded := observer.New(zapcore.DebugLevel)
l := zap.New(core)
client, _ := slurm.NewClient(srv.URL, srv.Client())
return NewClusterService(client, l), recorded
}
func errorServer() *httptest.Server {
return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusInternalServerError)
w.Write([]byte(`{"errors": [{"error": "internal server error"}]}`))
}))
}
func TestClusterService_GetNodes_ErrorLogging(t *testing.T) {
srv := errorServer()
defer srv.Close()
svc, logs := newClusterServiceWithObserver(srv)
_, err := svc.GetNodes(context.Background())
if err == nil {
t.Fatal("expected error, got nil")
}
if logs.Len() != 3 {
t.Fatalf("expected 3 log entries, got %d", logs.Len())
}
entry := logs.All()[2]
if entry.Level != zapcore.ErrorLevel {
t.Errorf("expected ErrorLevel, got %v", entry.Level)
}
if len(entry.Context) == 0 {
t.Error("expected structured fields in log entry")
}
}
func TestClusterService_GetNode_ErrorLogging(t *testing.T) {
srv := errorServer()
defer srv.Close()
svc, logs := newClusterServiceWithObserver(srv)
_, err := svc.GetNode(context.Background(), "test-node")
if err == nil {
t.Fatal("expected error, got nil")
}
if logs.Len() != 3 {
t.Fatalf("expected 3 log entries, got %d", logs.Len())
}
entry := logs.All()[2]
if entry.Level != zapcore.ErrorLevel {
t.Errorf("expected ErrorLevel, got %v", entry.Level)
}
hasName := false
for _, f := range entry.Context {
if f.Key == "name" && f.String == "test-node" {
hasName = true
}
}
if !hasName {
t.Error("expected 'name' field with value 'test-node' in log entry")
}
}
func TestClusterService_GetPartitions_ErrorLogging(t *testing.T) {
srv := errorServer()
defer srv.Close()
svc, logs := newClusterServiceWithObserver(srv)
_, err := svc.GetPartitions(context.Background())
if err == nil {
t.Fatal("expected error, got nil")
}
if logs.Len() != 3 {
t.Fatalf("expected 3 log entries, got %d", logs.Len())
}
entry := logs.All()[2]
if entry.Level != zapcore.ErrorLevel {
t.Errorf("expected ErrorLevel, got %v", entry.Level)
}
if len(entry.Context) == 0 {
t.Error("expected structured fields in log entry")
}
}
func TestClusterService_GetPartition_ErrorLogging(t *testing.T) {
srv := errorServer()
defer srv.Close()
svc, logs := newClusterServiceWithObserver(srv)
_, err := svc.GetPartition(context.Background(), "test-partition")
if err == nil {
t.Fatal("expected error, got nil")
}
if logs.Len() != 3 {
t.Fatalf("expected 3 log entries, got %d", logs.Len())
}
entry := logs.All()[2]
if entry.Level != zapcore.ErrorLevel {
t.Errorf("expected ErrorLevel, got %v", entry.Level)
}
hasName := false
for _, f := range entry.Context {
if f.Key == "name" && f.String == "test-partition" {
hasName = true
}
}
if !hasName {
t.Error("expected 'name' field with value 'test-partition' in log entry")
}
}
func TestClusterService_GetDiag_ErrorLogging(t *testing.T) {
srv := errorServer()
defer srv.Close()
svc, logs := newClusterServiceWithObserver(srv)
_, err := svc.GetDiag(context.Background())
if err == nil {
t.Fatal("expected error, got nil")
}
if logs.Len() != 3 {
t.Fatalf("expected 3 log entries, got %d", logs.Len())
}
entry := logs.All()[2]
if entry.Level != zapcore.ErrorLevel {
t.Errorf("expected ErrorLevel, got %v", entry.Level)
}
if len(entry.Context) == 0 {
t.Error("expected structured fields in log entry")
}
}

View File

@@ -0,0 +1,98 @@
package service
import (
"context"
"fmt"
"io"
"gcy_hpc_server/internal/model"
"gcy_hpc_server/internal/server"
"gcy_hpc_server/internal/storage"
"gcy_hpc_server/internal/store"
"go.uber.org/zap"
)
// DownloadService handles file downloads with streaming and Range support.
type DownloadService struct {
storage storage.ObjectStorage
blobStore *store.BlobStore
fileStore *store.FileStore
bucket string
logger *zap.Logger
}
// NewDownloadService creates a new DownloadService.
func NewDownloadService(storage storage.ObjectStorage, blobStore *store.BlobStore, fileStore *store.FileStore, bucket string, logger *zap.Logger) *DownloadService {
return &DownloadService{
storage: storage,
blobStore: blobStore,
fileStore: fileStore,
bucket: bucket,
logger: logger,
}
}
// Download returns a stream reader for the given file, optionally limited to a byte range.
// Returns (reader, file, blob, start, end, error).
func (s *DownloadService) Download(ctx context.Context, fileID int64, rangeHeader string) (io.ReadCloser, *model.File, *model.FileBlob, int64, int64, error) {
file, err := s.fileStore.GetByID(ctx, fileID)
if err != nil {
return nil, nil, nil, 0, 0, fmt.Errorf("get file: %w", err)
}
if file == nil {
return nil, nil, nil, 0, 0, fmt.Errorf("file not found")
}
blob, err := s.blobStore.GetBySHA256(ctx, file.BlobSHA256)
if err != nil {
return nil, nil, nil, 0, 0, fmt.Errorf("get blob: %w", err)
}
if blob == nil {
return nil, nil, nil, 0, 0, fmt.Errorf("blob not found")
}
var start, end int64
if rangeHeader != "" {
start, end, err = server.ParseRange(rangeHeader, blob.FileSize)
if err != nil {
return nil, nil, nil, 0, 0, fmt.Errorf("parse range: %w", err)
}
} else {
start = 0
end = blob.FileSize - 1
}
opts := storage.GetOptions{
Start: &start,
End: &end,
}
reader, _, err := s.storage.GetObject(ctx, s.bucket, blob.MinioKey, opts)
if err != nil {
return nil, nil, nil, 0, 0, fmt.Errorf("get object: %w", err)
}
return reader, file, blob, start, end, nil
}
// GetFileMetadata returns the file and its associated blob metadata.
func (s *DownloadService) GetFileMetadata(ctx context.Context, fileID int64) (*model.File, *model.FileBlob, error) {
file, err := s.fileStore.GetByID(ctx, fileID)
if err != nil {
return nil, nil, fmt.Errorf("get file: %w", err)
}
if file == nil {
return nil, nil, fmt.Errorf("file not found")
}
blob, err := s.blobStore.GetBySHA256(ctx, file.BlobSHA256)
if err != nil {
return nil, nil, fmt.Errorf("get blob: %w", err)
}
if blob == nil {
return nil, nil, fmt.Errorf("blob not found")
}
return file, blob, nil
}

View File

@@ -0,0 +1,260 @@
package service
import (
"bytes"
"context"
"io"
"testing"
"time"
"gcy_hpc_server/internal/model"
"gcy_hpc_server/internal/storage"
"gcy_hpc_server/internal/store"
"go.uber.org/zap"
gormlogger "gorm.io/gorm/logger"
"gorm.io/driver/sqlite"
"gorm.io/gorm"
)
type mockDownloadStorage struct {
getObjectFn func(ctx context.Context, bucket, key string, opts storage.GetOptions) (io.ReadCloser, storage.ObjectInfo, error)
}
func (m *mockDownloadStorage) PutObject(_ context.Context, _ string, _ string, _ io.Reader, _ int64, _ storage.PutObjectOptions) (storage.UploadInfo, error) {
return storage.UploadInfo{}, nil
}
func (m *mockDownloadStorage) GetObject(ctx context.Context, bucket, key string, opts storage.GetOptions) (io.ReadCloser, storage.ObjectInfo, error) {
if m.getObjectFn != nil {
return m.getObjectFn(ctx, bucket, key, opts)
}
return nil, storage.ObjectInfo{}, nil
}
func (m *mockDownloadStorage) ComposeObject(_ context.Context, _ string, _ string, _ []string) (storage.UploadInfo, error) {
return storage.UploadInfo{}, nil
}
func (m *mockDownloadStorage) AbortMultipartUpload(_ context.Context, _ string, _ string, _ string) error {
return nil
}
func (m *mockDownloadStorage) RemoveIncompleteUpload(_ context.Context, _ string, _ string) error {
return nil
}
func (m *mockDownloadStorage) RemoveObject(_ context.Context, _ string, _ string, _ storage.RemoveObjectOptions) error {
return nil
}
func (m *mockDownloadStorage) ListObjects(_ context.Context, _ string, _ string, _ bool) ([]storage.ObjectInfo, error) {
return nil, nil
}
func (m *mockDownloadStorage) RemoveObjects(_ context.Context, _ string, _ []string, _ storage.RemoveObjectsOptions) error {
return nil
}
func (m *mockDownloadStorage) BucketExists(_ context.Context, _ string) (bool, error) {
return true, nil
}
func (m *mockDownloadStorage) MakeBucket(_ context.Context, _ string, _ storage.MakeBucketOptions) error {
return nil
}
func (m *mockDownloadStorage) StatObject(_ context.Context, _ string, _ string, _ storage.StatObjectOptions) (storage.ObjectInfo, error) {
return storage.ObjectInfo{}, nil
}
func setupDownloadService(t *testing.T) (*DownloadService, *store.FileStore, *store.BlobStore, *mockDownloadStorage) {
t.Helper()
db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{
Logger: gormlogger.Default.LogMode(gormlogger.Silent),
})
if err != nil {
t.Fatalf("open sqlite: %v", err)
}
if err := db.AutoMigrate(&model.File{}, &model.FileBlob{}); err != nil {
t.Fatalf("auto migrate: %v", err)
}
mockStorage := &mockDownloadStorage{}
blobStore := store.NewBlobStore(db)
fileStore := store.NewFileStore(db)
svc := NewDownloadService(mockStorage, blobStore, fileStore, "test-bucket", zap.NewNop())
return svc, fileStore, blobStore, mockStorage
}
func createTestFileAndBlob(t *testing.T, fileStore *store.FileStore, blobStore *store.BlobStore) (*model.File, *model.FileBlob) {
t.Helper()
blob := &model.FileBlob{
SHA256: "abc123def456abc123def456abc123def456abc123def456abc123def456abcd",
MinioKey: "chunks/session1/part-0",
FileSize: 5000,
MimeType: "application/octet-stream",
RefCount: 1,
}
if err := blobStore.Create(context.Background(), blob); err != nil {
t.Fatalf("create blob: %v", err)
}
file := &model.File{
Name: "test.dat",
BlobSHA256: blob.SHA256,
CreatedAt: time.Now(),
UpdatedAt: time.Now(),
}
if err := fileStore.Create(context.Background(), file); err != nil {
t.Fatalf("create file: %v", err)
}
return file, blob
}
func TestDownload_FullFile(t *testing.T) {
svc, fileStore, blobStore, mockStorage := setupDownloadService(t)
file, blob := createTestFileAndBlob(t, fileStore, blobStore)
content := make([]byte, blob.FileSize)
for i := range content {
content[i] = byte(i % 256)
}
mockStorage.getObjectFn = func(_ context.Context, _, _ string, opts storage.GetOptions) (io.ReadCloser, storage.ObjectInfo, error) {
if opts.Start == nil || opts.End == nil {
t.Fatal("expected Start and End to be set")
}
if *opts.Start != 0 {
t.Fatalf("expected start=0, got %d", *opts.Start)
}
if *opts.End != blob.FileSize-1 {
t.Fatalf("expected end=%d, got %d", blob.FileSize-1, *opts.End)
}
return io.NopCloser(bytes.NewReader(content)), storage.ObjectInfo{Size: blob.FileSize}, nil
}
reader, gotFile, gotBlob, start, end, err := svc.Download(context.Background(), file.ID, "")
if err != nil {
t.Fatalf("Download: %v", err)
}
defer reader.Close()
if gotFile.ID != file.ID {
t.Fatalf("expected file ID %d, got %d", file.ID, gotFile.ID)
}
if gotBlob.SHA256 != blob.SHA256 {
t.Fatalf("expected blob SHA256 %s, got %s", blob.SHA256, gotBlob.SHA256)
}
if start != 0 {
t.Fatalf("expected start=0, got %d", start)
}
if end != blob.FileSize-1 {
t.Fatalf("expected end=%d, got %d", blob.FileSize-1, end)
}
read, _ := io.ReadAll(reader)
if int64(len(read)) != blob.FileSize {
t.Fatalf("expected %d bytes, got %d", blob.FileSize, len(read))
}
}
func TestDownload_WithRange(t *testing.T) {
svc, fileStore, blobStore, mockStorage := setupDownloadService(t)
file, _ := createTestFileAndBlob(t, fileStore, blobStore)
content := make([]byte, 1024)
for i := range content {
content[i] = byte(i % 256)
}
mockStorage.getObjectFn = func(_ context.Context, _, _ string, opts storage.GetOptions) (io.ReadCloser, storage.ObjectInfo, error) {
if opts.Start == nil || opts.End == nil {
t.Fatal("expected Start and End to be set")
}
if *opts.Start != 0 {
t.Fatalf("expected start=0, got %d", *opts.Start)
}
if *opts.End != 1023 {
t.Fatalf("expected end=1023, got %d", *opts.End)
}
return io.NopCloser(bytes.NewReader(content[:1024])), storage.ObjectInfo{Size: 1024}, nil
}
reader, _, _, start, end, err := svc.Download(context.Background(), file.ID, "bytes=0-1023")
if err != nil {
t.Fatalf("Download: %v", err)
}
defer reader.Close()
if start != 0 {
t.Fatalf("expected start=0, got %d", start)
}
if end != 1023 {
t.Fatalf("expected end=1023, got %d", end)
}
read, _ := io.ReadAll(reader)
if len(read) != 1024 {
t.Fatalf("expected 1024 bytes, got %d", len(read))
}
}
func TestDownload_FileNotFound(t *testing.T) {
svc, _, _, _ := setupDownloadService(t)
_, _, _, _, _, err := svc.Download(context.Background(), 99999, "")
if err == nil {
t.Fatal("expected error for missing file")
}
if err.Error() != "file not found" {
t.Fatalf("unexpected error: %v", err)
}
}
func TestDownload_BlobNotFound(t *testing.T) {
svc, fileStore, _, _ := setupDownloadService(t)
file := &model.File{
Name: "orphan.dat",
BlobSHA256: "nonexistent_hash_0000000000000000000000000000000000000000",
CreatedAt: time.Now(),
UpdatedAt: time.Now(),
}
if err := fileStore.Create(context.Background(), file); err != nil {
t.Fatalf("create file: %v", err)
}
_, _, _, _, _, err := svc.Download(context.Background(), file.ID, "")
if err == nil {
t.Fatal("expected error for missing blob")
}
if err.Error() != "blob not found" {
t.Fatalf("unexpected error: %v", err)
}
}
func TestGetFileMetadata(t *testing.T) {
svc, fileStore, blobStore, _ := setupDownloadService(t)
file, _ := createTestFileAndBlob(t, fileStore, blobStore)
gotFile, gotBlob, err := svc.GetFileMetadata(context.Background(), file.ID)
if err != nil {
t.Fatalf("GetFileMetadata: %v", err)
}
if gotFile.ID != file.ID {
t.Fatalf("expected file ID %d, got %d", file.ID, gotFile.ID)
}
if gotBlob.FileSize != 5000 {
t.Fatalf("expected file size 5000, got %d", gotBlob.FileSize)
}
if gotBlob.MimeType != "application/octet-stream" {
t.Fatalf("expected mime type application/octet-stream, got %s", gotBlob.MimeType)
}
}

View File

@@ -0,0 +1,178 @@
package service
import (
"context"
"fmt"
"io"
"gcy_hpc_server/internal/model"
"gcy_hpc_server/internal/server"
"gcy_hpc_server/internal/storage"
"gcy_hpc_server/internal/store"
"go.uber.org/zap"
"gorm.io/gorm"
)
// FileService handles file listing, metadata, download, and deletion operations.
type FileService struct {
storage storage.ObjectStorage
blobStore *store.BlobStore
fileStore *store.FileStore
bucket string
db *gorm.DB
logger *zap.Logger
}
// NewFileService creates a new FileService.
func NewFileService(storage storage.ObjectStorage, blobStore *store.BlobStore, fileStore *store.FileStore, bucket string, db *gorm.DB, logger *zap.Logger) *FileService {
return &FileService{
storage: storage,
blobStore: blobStore,
fileStore: fileStore,
bucket: bucket,
db: db,
logger: logger,
}
}
// ListFiles returns a paginated list of files, optionally filtered by folder or search query.
func (s *FileService) ListFiles(ctx context.Context, folderID *int64, page, pageSize int, search string) ([]model.FileResponse, int64, error) {
var files []model.File
var total int64
var err error
if search != "" {
files, total, err = s.fileStore.Search(ctx, search, page, pageSize)
} else {
files, total, err = s.fileStore.List(ctx, folderID, page, pageSize)
}
if err != nil {
return nil, 0, fmt.Errorf("list files: %w", err)
}
responses := make([]model.FileResponse, 0, len(files))
for _, f := range files {
blob, err := s.blobStore.GetBySHA256(ctx, f.BlobSHA256)
if err != nil {
return nil, 0, fmt.Errorf("get blob for file %d: %w", f.ID, err)
}
if blob == nil {
return nil, 0, fmt.Errorf("blob not found for file %d", f.ID)
}
responses = append(responses, model.FileResponse{
ID: f.ID,
Name: f.Name,
FolderID: f.FolderID,
Size: blob.FileSize,
MimeType: blob.MimeType,
SHA256: f.BlobSHA256,
CreatedAt: f.CreatedAt,
UpdatedAt: f.UpdatedAt,
})
}
return responses, total, nil
}
// GetFileMetadata returns the file and its associated blob metadata.
func (s *FileService) GetFileMetadata(ctx context.Context, fileID int64) (*model.File, *model.FileBlob, error) {
file, err := s.fileStore.GetByID(ctx, fileID)
if err != nil {
return nil, nil, fmt.Errorf("get file: %w", err)
}
if file == nil {
return nil, nil, fmt.Errorf("file not found: %d", fileID)
}
blob, err := s.blobStore.GetBySHA256(ctx, file.BlobSHA256)
if err != nil {
return nil, nil, fmt.Errorf("get blob: %w", err)
}
if blob == nil {
return nil, nil, fmt.Errorf("blob not found for file %d", fileID)
}
return file, blob, nil
}
// DownloadFile returns a reader for the file content, along with file and blob metadata.
// If rangeHeader is non-empty, it parses the range and returns partial content.
func (s *FileService) DownloadFile(ctx context.Context, fileID int64, rangeHeader string) (io.ReadCloser, *model.File, *model.FileBlob, int64, int64, error) {
file, blob, err := s.GetFileMetadata(ctx, fileID)
if err != nil {
return nil, nil, nil, 0, 0, err
}
var start, end int64
if rangeHeader != "" {
start, end, err = server.ParseRange(rangeHeader, blob.FileSize)
if err != nil {
return nil, nil, nil, 0, 0, fmt.Errorf("parse range: %w", err)
}
} else {
start = 0
end = blob.FileSize - 1
}
reader, _, err := s.storage.GetObject(ctx, s.bucket, blob.MinioKey, storage.GetOptions{
Start: &start,
End: &end,
})
if err != nil {
return nil, nil, nil, 0, 0, fmt.Errorf("get object: %w", err)
}
return reader, file, blob, start, end, nil
}
// DeleteFile soft-deletes a file. If no other active files reference the same blob,
// it decrements the blob ref count and removes the object from storage when ref count reaches 0.
func (s *FileService) DeleteFile(ctx context.Context, fileID int64) error {
return s.db.Transaction(func(tx *gorm.DB) error {
txFileStore := store.NewFileStore(tx)
txBlobStore := store.NewBlobStore(tx)
blobSHA256, err := txFileStore.GetBlobSHA256ByID(ctx, fileID)
if err != nil {
return fmt.Errorf("get blob sha256: %w", err)
}
if blobSHA256 == "" {
return fmt.Errorf("file not found: %d", fileID)
}
if err := tx.Delete(&model.File{}, fileID).Error; err != nil {
return fmt.Errorf("soft delete file: %w", err)
}
activeCount, err := txFileStore.CountByBlobSHA256(ctx, blobSHA256)
if err != nil {
return fmt.Errorf("count active refs: %w", err)
}
if activeCount == 0 {
newRefCount, err := txBlobStore.DecrementRef(ctx, blobSHA256)
if err != nil {
return fmt.Errorf("decrement ref: %w", err)
}
if newRefCount == 0 {
blob, err := txBlobStore.GetBySHA256(ctx, blobSHA256)
if err != nil {
return fmt.Errorf("get blob for cleanup: %w", err)
}
if blob != nil {
if err := s.storage.RemoveObject(ctx, s.bucket, blob.MinioKey, storage.RemoveObjectOptions{}); err != nil {
return fmt.Errorf("remove object: %w", err)
}
if err := txBlobStore.Delete(ctx, blobSHA256); err != nil {
return fmt.Errorf("delete blob: %w", err)
}
}
}
}
return nil
})
}

View File

@@ -0,0 +1,484 @@
package service
import (
"bytes"
"context"
"errors"
"io"
"strings"
"testing"
"gcy_hpc_server/internal/model"
"gcy_hpc_server/internal/storage"
"gcy_hpc_server/internal/store"
"go.uber.org/zap"
"gorm.io/driver/sqlite"
"gorm.io/gorm"
)
type mockFileStorage struct {
getObjectFn func(ctx context.Context, bucket, key string, opts storage.GetOptions) (io.ReadCloser, storage.ObjectInfo, error)
removeObjectFn func(ctx context.Context, bucket, key string, opts storage.RemoveObjectOptions) error
}
func (m *mockFileStorage) PutObject(_ context.Context, _, _ string, _ io.Reader, _ int64, _ storage.PutObjectOptions) (storage.UploadInfo, error) {
return storage.UploadInfo{}, nil
}
func (m *mockFileStorage) GetObject(ctx context.Context, bucket, key string, opts storage.GetOptions) (io.ReadCloser, storage.ObjectInfo, error) {
if m.getObjectFn != nil {
return m.getObjectFn(ctx, bucket, key, opts)
}
return io.NopCloser(strings.NewReader("data")), storage.ObjectInfo{}, nil
}
func (m *mockFileStorage) ComposeObject(_ context.Context, _ string, _ string, _ []string) (storage.UploadInfo, error) {
return storage.UploadInfo{}, nil
}
func (m *mockFileStorage) AbortMultipartUpload(_ context.Context, _, _, _ string) error {
return nil
}
func (m *mockFileStorage) RemoveIncompleteUpload(_ context.Context, _, _ string) error {
return nil
}
func (m *mockFileStorage) RemoveObject(ctx context.Context, bucket, key string, opts storage.RemoveObjectOptions) error {
if m.removeObjectFn != nil {
return m.removeObjectFn(ctx, bucket, key, opts)
}
return nil
}
func (m *mockFileStorage) ListObjects(_ context.Context, _ string, _ string, _ bool) ([]storage.ObjectInfo, error) {
return nil, nil
}
func (m *mockFileStorage) RemoveObjects(_ context.Context, _ string, _ []string, _ storage.RemoveObjectsOptions) error {
return nil
}
func (m *mockFileStorage) BucketExists(_ context.Context, _ string) (bool, error) {
return true, nil
}
func (m *mockFileStorage) MakeBucket(_ context.Context, _ string, _ storage.MakeBucketOptions) error {
return nil
}
func (m *mockFileStorage) StatObject(_ context.Context, _, _ string, _ storage.StatObjectOptions) (storage.ObjectInfo, error) {
return storage.ObjectInfo{}, nil
}
func setupFileTestDB(t *testing.T) *gorm.DB {
t.Helper()
db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{})
if err != nil {
t.Fatalf("open sqlite: %v", err)
}
if err := db.AutoMigrate(&model.File{}, &model.FileBlob{}); err != nil {
t.Fatalf("migrate: %v", err)
}
return db
}
func setupFileService(t *testing.T) (*FileService, *mockFileStorage, *gorm.DB) {
t.Helper()
db := setupFileTestDB(t)
ms := &mockFileStorage{}
svc := NewFileService(ms, store.NewBlobStore(db), store.NewFileStore(db), "test-bucket", db, zap.NewNop())
return svc, ms, db
}
func createTestBlob(t *testing.T, db *gorm.DB, sha256, minioKey, mimeType string, fileSize int64, refCount int) *model.FileBlob {
t.Helper()
blob := &model.FileBlob{
SHA256: sha256,
MinioKey: minioKey,
FileSize: fileSize,
MimeType: mimeType,
RefCount: refCount,
}
if err := db.Create(blob).Error; err != nil {
t.Fatalf("create blob: %v", err)
}
return blob
}
func createTestFile(t *testing.T, db *gorm.DB, name, blobSHA256 string, folderID *int64) *model.File {
t.Helper()
file := &model.File{
Name: name,
FolderID: folderID,
BlobSHA256: blobSHA256,
}
if err := db.Create(file).Error; err != nil {
t.Fatalf("create file: %v", err)
}
return file
}
func TestListFiles_Empty(t *testing.T) {
svc, _, _ := setupFileService(t)
files, total, err := svc.ListFiles(context.Background(), nil, 1, 10, "")
if err != nil {
t.Fatalf("ListFiles: %v", err)
}
if total != 0 {
t.Errorf("expected total 0, got %d", total)
}
if len(files) != 0 {
t.Errorf("expected empty files, got %d", len(files))
}
}
func TestListFiles_WithFiles(t *testing.T) {
svc, _, db := setupFileService(t)
blob := createTestBlob(t, db, "sha256abc", "blobs/abc", "text/plain", 1024, 2)
createTestFile(t, db, "file1.txt", blob.SHA256, nil)
createTestFile(t, db, "file2.txt", blob.SHA256, nil)
files, total, err := svc.ListFiles(context.Background(), nil, 1, 10, "")
if err != nil {
t.Fatalf("ListFiles: %v", err)
}
if total != 2 {
t.Errorf("expected total 2, got %d", total)
}
if len(files) != 2 {
t.Fatalf("expected 2 files, got %d", len(files))
}
for _, f := range files {
if f.Size != 1024 {
t.Errorf("expected size 1024, got %d", f.Size)
}
if f.MimeType != "text/plain" {
t.Errorf("expected mime text/plain, got %s", f.MimeType)
}
if f.SHA256 != "sha256abc" {
t.Errorf("expected sha256 sha256abc, got %s", f.SHA256)
}
}
}
func TestListFiles_Search(t *testing.T) {
svc, _, db := setupFileService(t)
blob := createTestBlob(t, db, "sha256search", "blobs/search", "image/png", 2048, 1)
createTestFile(t, db, "photo.png", blob.SHA256, nil)
createTestFile(t, db, "document.pdf", "sha256other", nil)
createTestBlob(t, db, "sha256other", "blobs/other", "application/pdf", 512, 1)
files, total, err := svc.ListFiles(context.Background(), nil, 1, 10, "photo")
if err != nil {
t.Fatalf("ListFiles: %v", err)
}
if total != 1 {
t.Errorf("expected total 1, got %d", total)
}
if len(files) != 1 {
t.Fatalf("expected 1 file, got %d", len(files))
}
if files[0].Name != "photo.png" {
t.Errorf("expected photo.png, got %s", files[0].Name)
}
}
func TestGetFileMetadata_Found(t *testing.T) {
svc, _, db := setupFileService(t)
blob := createTestBlob(t, db, "sha256meta", "blobs/meta", "application/json", 42, 1)
file := createTestFile(t, db, "data.json", blob.SHA256, nil)
gotFile, gotBlob, err := svc.GetFileMetadata(context.Background(), file.ID)
if err != nil {
t.Fatalf("GetFileMetadata: %v", err)
}
if gotFile.ID != file.ID {
t.Errorf("expected file id %d, got %d", file.ID, gotFile.ID)
}
if gotBlob.SHA256 != blob.SHA256 {
t.Errorf("expected blob sha256 %s, got %s", blob.SHA256, gotBlob.SHA256)
}
}
func TestGetFileMetadata_NotFound(t *testing.T) {
svc, _, _ := setupFileService(t)
_, _, err := svc.GetFileMetadata(context.Background(), 9999)
if err == nil {
t.Fatal("expected error for missing file")
}
}
func TestDownloadFile_Full(t *testing.T) {
svc, ms, db := setupFileService(t)
blob := createTestBlob(t, db, "sha256dl", "blobs/dl", "text/plain", 100, 1)
file := createTestFile(t, db, "download.txt", blob.SHA256, nil)
content := []byte("hello world")
ms.getObjectFn = func(_ context.Context, _ string, _ string, opts storage.GetOptions) (io.ReadCloser, storage.ObjectInfo, error) {
if opts.Start == nil || opts.End == nil {
t.Error("expected start and end to be set")
} else if *opts.Start != 0 || *opts.End != 99 {
t.Errorf("expected range 0-99, got %d-%d", *opts.Start, *opts.End)
}
return io.NopCloser(bytes.NewReader(content)), storage.ObjectInfo{}, nil
}
reader, gotFile, gotBlob, start, end, err := svc.DownloadFile(context.Background(), file.ID, "")
if err != nil {
t.Fatalf("DownloadFile: %v", err)
}
defer reader.Close()
if gotFile.ID != file.ID {
t.Errorf("expected file id %d, got %d", file.ID, gotFile.ID)
}
if gotBlob.SHA256 != blob.SHA256 {
t.Errorf("expected blob sha256 %s, got %s", blob.SHA256, gotBlob.SHA256)
}
if start != 0 {
t.Errorf("expected start 0, got %d", start)
}
if end != 99 {
t.Errorf("expected end 99, got %d", end)
}
data, _ := io.ReadAll(reader)
if string(data) != "hello world" {
t.Errorf("expected 'hello world', got %q", string(data))
}
}
func TestDownloadFile_WithRange(t *testing.T) {
svc, ms, db := setupFileService(t)
blob := createTestBlob(t, db, "sha256range", "blobs/range", "text/plain", 1000, 1)
file := createTestFile(t, db, "range.txt", blob.SHA256, nil)
ms.getObjectFn = func(_ context.Context, _ string, _ string, opts storage.GetOptions) (io.ReadCloser, storage.ObjectInfo, error) {
if opts.Start != nil && *opts.Start != 100 {
t.Errorf("expected start 100, got %d", *opts.Start)
}
if opts.End != nil && *opts.End != 199 {
t.Errorf("expected end 199, got %d", *opts.End)
}
return io.NopCloser(strings.NewReader("partial")), storage.ObjectInfo{}, nil
}
reader, _, _, start, end, err := svc.DownloadFile(context.Background(), file.ID, "bytes=100-199")
if err != nil {
t.Fatalf("DownloadFile: %v", err)
}
defer reader.Close()
if start != 100 {
t.Errorf("expected start 100, got %d", start)
}
if end != 199 {
t.Errorf("expected end 199, got %d", end)
}
}
func TestDeleteFile_LastRef(t *testing.T) {
svc, ms, db := setupFileService(t)
blob := createTestBlob(t, db, "sha256del", "blobs/del", "text/plain", 50, 1)
file := createTestFile(t, db, "delete-me.txt", blob.SHA256, nil)
removed := false
ms.removeObjectFn = func(_ context.Context, bucket, key string, _ storage.RemoveObjectOptions) error {
if bucket != "test-bucket" {
t.Errorf("expected bucket 'test-bucket', got %q", bucket)
}
if key != "blobs/del" {
t.Errorf("expected key 'blobs/del', got %q", key)
}
removed = true
return nil
}
if err := svc.DeleteFile(context.Background(), file.ID); err != nil {
t.Fatalf("DeleteFile: %v", err)
}
if !removed {
t.Error("expected RemoveObject to be called")
}
var count int64
db.Model(&model.FileBlob{}).Where("sha256 = ?", "sha256del").Count(&count)
if count != 0 {
t.Errorf("expected blob to be hard deleted, found %d records", count)
}
var fileCount int64
db.Unscoped().Model(&model.File{}).Where("id = ?", file.ID).Count(&fileCount)
if fileCount != 1 {
t.Errorf("expected file to still exist (soft deleted), found %d", fileCount)
}
var deletedFile model.File
db.Unscoped().First(&deletedFile, file.ID)
if deletedFile.DeletedAt.Time.IsZero() {
t.Error("expected deleted_at to be set")
}
}
func TestDeleteFile_OtherRefsExist(t *testing.T) {
svc, ms, db := setupFileService(t)
blob := createTestBlob(t, db, "sha256multi", "blobs/multi", "text/plain", 50, 3)
file1 := createTestFile(t, db, "ref1.txt", blob.SHA256, nil)
createTestFile(t, db, "ref2.txt", blob.SHA256, nil)
createTestFile(t, db, "ref3.txt", blob.SHA256, nil)
removed := false
ms.removeObjectFn = func(_ context.Context, _, _ string, _ storage.RemoveObjectOptions) error {
removed = true
return nil
}
if err := svc.DeleteFile(context.Background(), file1.ID); err != nil {
t.Fatalf("DeleteFile: %v", err)
}
if removed {
t.Error("expected RemoveObject NOT to be called since other refs exist")
}
var updatedBlob model.FileBlob
db.Where("sha256 = ?", "sha256multi").First(&updatedBlob)
if updatedBlob.RefCount != 3 {
t.Errorf("expected ref_count to remain 3, got %d", updatedBlob.RefCount)
}
}
func TestDeleteFile_SoftDeleteNotAffectRefcount(t *testing.T) {
svc, _, db := setupFileService(t)
blob := createTestBlob(t, db, "sha256soft", "blobs/soft", "text/plain", 50, 2)
file1 := createTestFile(t, db, "soft1.txt", blob.SHA256, nil)
file2 := createTestFile(t, db, "soft2.txt", blob.SHA256, nil)
if err := svc.DeleteFile(context.Background(), file1.ID); err != nil {
t.Fatalf("DeleteFile: %v", err)
}
var updatedBlob model.FileBlob
db.Where("sha256 = ?", "sha256soft").First(&updatedBlob)
if updatedBlob.RefCount != 2 {
t.Errorf("expected ref_count to remain 2 (soft delete should not decrement), got %d", updatedBlob.RefCount)
}
activeCount, err := store.NewFileStore(db).CountByBlobSHA256(context.Background(), "sha256soft")
if err != nil {
t.Fatalf("CountByBlobSHA256: %v", err)
}
if activeCount != 1 {
t.Errorf("expected 1 active ref after soft delete, got %d", activeCount)
}
var allFiles []model.File
db.Unscoped().Where("blob_sha256 = ?", "sha256soft").Find(&allFiles)
if len(allFiles) != 2 {
t.Errorf("expected 2 total files (one soft deleted), got %d", len(allFiles))
}
if err := svc.DeleteFile(context.Background(), file2.ID); err != nil {
t.Fatalf("DeleteFile second: %v", err)
}
activeCount2, err := store.NewFileStore(db).CountByBlobSHA256(context.Background(), "sha256soft")
if err != nil {
t.Fatalf("CountByBlobSHA256 after second delete: %v", err)
}
if activeCount2 != 0 {
t.Errorf("expected 0 active refs after both deleted, got %d", activeCount2)
}
var finalBlob model.FileBlob
db.Where("sha256 = ?", "sha256soft").First(&finalBlob)
if finalBlob.RefCount != 1 {
t.Errorf("expected ref_count=1 (decremented once from 2), got %d", finalBlob.RefCount)
}
}
func TestDeleteFile_NotFound(t *testing.T) {
svc, _, _ := setupFileService(t)
err := svc.DeleteFile(context.Background(), 9999)
if err == nil {
t.Fatal("expected error for missing file")
}
}
func TestDownloadFile_StorageError(t *testing.T) {
svc, ms, db := setupFileService(t)
blob := createTestBlob(t, db, "sha256err", "blobs/err", "text/plain", 100, 1)
file := createTestFile(t, db, "error.txt", blob.SHA256, nil)
ms.getObjectFn = func(_ context.Context, _ string, _ string, _ storage.GetOptions) (io.ReadCloser, storage.ObjectInfo, error) {
return nil, storage.ObjectInfo{}, errors.New("storage unavailable")
}
_, _, _, _, _, err := svc.DownloadFile(context.Background(), file.ID, "")
if err == nil {
t.Fatal("expected error from storage")
}
if !strings.Contains(err.Error(), "storage unavailable") {
t.Errorf("expected storage error, got: %v", err)
}
}
func TestListFiles_WithFolderFilter(t *testing.T) {
svc, _, db := setupFileService(t)
blob := createTestBlob(t, db, "sha256folder", "blobs/folder", "text/plain", 100, 2)
folderID := int64(1)
createTestFile(t, db, "in_folder.txt", blob.SHA256, &folderID)
createTestFile(t, db, "root.txt", blob.SHA256, nil)
files, total, err := svc.ListFiles(context.Background(), &folderID, 1, 10, "")
if err != nil {
t.Fatalf("ListFiles: %v", err)
}
if total != 1 {
t.Errorf("expected total 1, got %d", total)
}
if len(files) != 1 {
t.Fatalf("expected 1 file, got %d", len(files))
}
if files[0].Name != "in_folder.txt" {
t.Errorf("expected in_folder.txt, got %s", files[0].Name)
}
rootFiles, rootTotal, err := svc.ListFiles(context.Background(), nil, 1, 10, "")
if err != nil {
t.Fatalf("ListFiles root: %v", err)
}
if rootTotal != 1 {
t.Errorf("expected root total 1, got %d", rootTotal)
}
if len(rootFiles) != 1 || rootFiles[0].Name != "root.txt" {
t.Errorf("expected root.txt in root listing")
}
}
func TestGetFileMetadata_BlobMissing(t *testing.T) {
svc, _, db := setupFileService(t)
file := createTestFile(t, db, "orphan.txt", "nonexistent_sha256", nil)
_, _, err := svc.GetFileMetadata(context.Background(), file.ID)
if err == nil {
t.Fatal("expected error when blob is missing")
}
}

View File

@@ -0,0 +1,145 @@
package service
import (
"context"
"fmt"
"io"
"os"
"path/filepath"
"gcy_hpc_server/internal/model"
"gcy_hpc_server/internal/storage"
"gcy_hpc_server/internal/store"
"go.uber.org/zap"
)
// FileStagingService batch downloads files from MinIO to a local (NFS) directory,
// deduplicating by blob SHA256 so each unique blob is fetched only once.
type FileStagingService struct {
fileStore *store.FileStore
blobStore *store.BlobStore
storage storage.ObjectStorage
bucket string
logger *zap.Logger
}
func NewFileStagingService(fileStore *store.FileStore, blobStore *store.BlobStore, st storage.ObjectStorage, bucket string, logger *zap.Logger) *FileStagingService {
return &FileStagingService{
fileStore: fileStore,
blobStore: blobStore,
storage: st,
bucket: bucket,
logger: logger,
}
}
// DownloadFilesToDir downloads the given files into destDir.
// Files sharing the same blob SHA256 are deduplicated: the blob is fetched once
// and then copied to each filename. Filenames are sanitized with filepath.Base
// to prevent path traversal.
func (s *FileStagingService) DownloadFilesToDir(ctx context.Context, fileIDs []int64, destDir string) error {
if len(fileIDs) == 0 {
return nil
}
files, err := s.fileStore.GetByIDs(ctx, fileIDs)
if err != nil {
return fmt.Errorf("fetch files: %w", err)
}
type group struct {
primary *model.File // first file — written via io.Copy from MinIO
others []*model.File // remaining files — local copy of primary
}
groups := make(map[string]*group)
for i := range files {
f := &files[i]
g, ok := groups[f.BlobSHA256]
if !ok {
groups[f.BlobSHA256] = &group{primary: f}
} else {
g.others = append(g.others, f)
}
}
sha256s := make([]string, 0, len(groups))
for sh := range groups {
sha256s = append(sha256s, sh)
}
blobs, err := s.blobStore.GetBySHA256s(ctx, sha256s)
if err != nil {
return fmt.Errorf("fetch blobs: %w", err)
}
blobMap := make(map[string]*model.FileBlob, len(blobs))
for i := range blobs {
blobMap[blobs[i].SHA256] = &blobs[i]
}
for sha256, g := range groups {
blob, ok := blobMap[sha256]
if !ok {
return fmt.Errorf("blob %s not found", sha256)
}
reader, _, err := s.storage.GetObject(ctx, s.bucket, blob.MinioKey, storage.GetOptions{})
if err != nil {
return fmt.Errorf("get object %s: %w", blob.MinioKey, err)
}
// TODO: handle filename collisions when multiple files have the same Name (low risk without user auth, revisit when auth is added)
primaryName := filepath.Base(g.primary.Name)
primaryPath := filepath.Join(destDir, primaryName)
if err := writeFile(primaryPath, reader); err != nil {
reader.Close()
os.Remove(primaryPath)
return fmt.Errorf("write file %s: %w", primaryName, err)
}
reader.Close()
for _, other := range g.others {
otherName := filepath.Base(other.Name)
otherPath := filepath.Join(destDir, otherName)
if err := copyFile(primaryPath, otherPath); err != nil {
return fmt.Errorf("copy %s to %s: %w", primaryName, otherName, err)
}
}
}
return nil
}
func writeFile(path string, reader io.Reader) error {
f, err := os.Create(path)
if err != nil {
return err
}
defer f.Close()
if _, err := io.Copy(f, reader); err != nil {
return err
}
return nil
}
func copyFile(src, dst string) error {
in, err := os.Open(src)
if err != nil {
return err
}
defer in.Close()
out, err := os.Create(dst)
if err != nil {
return err
}
defer out.Close()
if _, err := io.Copy(out, in); err != nil {
return err
}
return nil
}

View File

@@ -0,0 +1,232 @@
package service
import (
"bytes"
"context"
"fmt"
"io"
"os"
"path/filepath"
"strings"
"sync/atomic"
"testing"
"gcy_hpc_server/internal/model"
"gcy_hpc_server/internal/storage"
"gcy_hpc_server/internal/store"
"go.uber.org/zap"
"gorm.io/driver/sqlite"
"gorm.io/gorm"
)
type stagingMockStorage struct {
getObjectFn func(ctx context.Context, bucket, key string, opts storage.GetOptions) (io.ReadCloser, storage.ObjectInfo, error)
}
func (m *stagingMockStorage) GetObject(ctx context.Context, bucket, key string, opts storage.GetOptions) (io.ReadCloser, storage.ObjectInfo, error) {
if m.getObjectFn != nil {
return m.getObjectFn(ctx, bucket, key, opts)
}
return nil, storage.ObjectInfo{}, nil
}
func (m *stagingMockStorage) PutObject(ctx context.Context, bucket, key string, reader io.Reader, size int64, opts storage.PutObjectOptions) (storage.UploadInfo, error) {
return storage.UploadInfo{}, nil
}
func (m *stagingMockStorage) ComposeObject(ctx context.Context, bucket, dst string, sources []string) (storage.UploadInfo, error) {
return storage.UploadInfo{}, nil
}
func (m *stagingMockStorage) AbortMultipartUpload(ctx context.Context, bucket, object, uploadID string) error {
return nil
}
func (m *stagingMockStorage) RemoveIncompleteUpload(ctx context.Context, bucket, object string) error {
return nil
}
func (m *stagingMockStorage) RemoveObject(ctx context.Context, bucket, key string, opts storage.RemoveObjectOptions) error {
return nil
}
func (m *stagingMockStorage) ListObjects(ctx context.Context, bucket, prefix string, recursive bool) ([]storage.ObjectInfo, error) {
return nil, nil
}
func (m *stagingMockStorage) RemoveObjects(ctx context.Context, bucket string, keys []string, opts storage.RemoveObjectsOptions) error {
return nil
}
func (m *stagingMockStorage) BucketExists(ctx context.Context, bucket string) (bool, error) {
return true, nil
}
func (m *stagingMockStorage) MakeBucket(ctx context.Context, bucket string, opts storage.MakeBucketOptions) error {
return nil
}
func (m *stagingMockStorage) StatObject(ctx context.Context, bucket, key string, opts storage.StatObjectOptions) (storage.ObjectInfo, error) {
return storage.ObjectInfo{}, nil
}
func setupStagingTestDB(t *testing.T) *gorm.DB {
t.Helper()
db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{})
if err != nil {
t.Fatalf("open sqlite: %v", err)
}
if err := db.AutoMigrate(&model.FileBlob{}, &model.File{}); err != nil {
t.Fatalf("migrate: %v", err)
}
return db
}
func newStagingService(t *testing.T, st storage.ObjectStorage, db *gorm.DB) *FileStagingService {
t.Helper()
return NewFileStagingService(
store.NewFileStore(db),
store.NewBlobStore(db),
st,
"test-bucket",
zap.NewNop(),
)
}
func TestFileStaging_DownloadWithDedup(t *testing.T) {
db := setupStagingTestDB(t)
sha1 := "aaa111"
sha2 := "bbb222"
db.Create(&model.FileBlob{SHA256: sha1, MinioKey: "blobs/aaa111", FileSize: 5, MimeType: "text/plain", RefCount: 2})
db.Create(&model.FileBlob{SHA256: sha2, MinioKey: "blobs/bbb222", FileSize: 3, MimeType: "text/plain", RefCount: 1})
db.Create(&model.File{Name: "file1.txt", BlobSHA256: sha1})
db.Create(&model.File{Name: "file2.txt", BlobSHA256: sha1})
db.Create(&model.File{Name: "file3.txt", BlobSHA256: sha2})
var files []model.File
db.Find(&files)
if len(files) < 3 {
t.Fatalf("need 3 files, got %d", len(files))
}
var getObjCalls int32
st := &stagingMockStorage{}
st.getObjectFn = func(ctx context.Context, bucket, key string, opts storage.GetOptions) (io.ReadCloser, storage.ObjectInfo, error) {
atomic.AddInt32(&getObjCalls, 1)
var content string
switch key {
case "blobs/aaa111":
content = "content-a"
case "blobs/bbb222":
content = "content-b"
default:
return nil, storage.ObjectInfo{}, fmt.Errorf("unexpected key %s", key)
}
return io.NopCloser(bytes.NewReader([]byte(content))), storage.ObjectInfo{Key: key}, nil
}
destDir := t.TempDir()
svc := newStagingService(t, st, db)
err := svc.DownloadFilesToDir(context.Background(), []int64{files[0].ID, files[1].ID, files[2].ID}, destDir)
if err != nil {
t.Fatalf("DownloadFilesToDir: %v", err)
}
if calls := atomic.LoadInt32(&getObjCalls); calls != 2 {
t.Errorf("GetObject called %d times, want 2", calls)
}
expected := map[string]string{
"file1.txt": "content-a",
"file2.txt": "content-a",
"file3.txt": "content-b",
}
for name, want := range expected {
p := filepath.Join(destDir, name)
data, err := os.ReadFile(p)
if err != nil {
t.Errorf("read %s: %v", name, err)
continue
}
if string(data) != want {
t.Errorf("%s content = %q, want %q", name, data, want)
}
}
}
func TestFileStaging_PathTraversal(t *testing.T) {
db := setupStagingTestDB(t)
sha := "traversal123"
db.Create(&model.FileBlob{SHA256: sha, MinioKey: "blobs/traversal", FileSize: 4, MimeType: "text/plain", RefCount: 1})
db.Create(&model.File{Name: "../../../etc/passwd", BlobSHA256: sha})
var file model.File
db.First(&file)
st := &stagingMockStorage{}
st.getObjectFn = func(ctx context.Context, bucket, key string, opts storage.GetOptions) (io.ReadCloser, storage.ObjectInfo, error) {
return io.NopCloser(bytes.NewReader([]byte("safe"))), storage.ObjectInfo{Key: key}, nil
}
destDir := t.TempDir()
svc := newStagingService(t, st, db)
err := svc.DownloadFilesToDir(context.Background(), []int64{file.ID}, destDir)
if err != nil {
t.Fatalf("DownloadFilesToDir: %v", err)
}
sanitized := filepath.Join(destDir, "passwd")
data, err := os.ReadFile(sanitized)
if err != nil {
t.Fatalf("read sanitized file: %v", err)
}
if string(data) != "safe" {
t.Errorf("content = %q, want %q", data, "safe")
}
entries, err := os.ReadDir(destDir)
if err != nil {
t.Fatalf("readdir: %v", err)
}
for _, e := range entries {
if e.Name() != "passwd" {
t.Errorf("unexpected file in destDir: %s", e.Name())
}
}
}
func TestFileStaging_EmptyList(t *testing.T) {
db := setupStagingTestDB(t)
st := &stagingMockStorage{}
svc := newStagingService(t, st, db)
err := svc.DownloadFilesToDir(context.Background(), []int64{}, t.TempDir())
if err != nil {
t.Errorf("expected nil for empty list, got %v", err)
}
}
func TestFileStaging_GetObjectFails(t *testing.T) {
db := setupStagingTestDB(t)
sha := "fail123"
db.Create(&model.FileBlob{SHA256: sha, MinioKey: "blobs/fail", FileSize: 5, MimeType: "text/plain", RefCount: 1})
db.Create(&model.File{Name: "willfail.txt", BlobSHA256: sha})
var file model.File
db.First(&file)
st := &stagingMockStorage{}
st.getObjectFn = func(ctx context.Context, bucket, key string, opts storage.GetOptions) (io.ReadCloser, storage.ObjectInfo, error) {
return nil, storage.ObjectInfo{}, fmt.Errorf("minio down")
}
destDir := t.TempDir()
svc := newStagingService(t, st, db)
err := svc.DownloadFilesToDir(context.Background(), []int64{file.ID}, destDir)
if err == nil {
t.Fatal("expected error when GetObject fails")
}
if !strings.Contains(err.Error(), "minio down") {
t.Errorf("error = %q, want 'minio down'", err.Error())
}
}

View File

@@ -0,0 +1,142 @@
package service
import (
"context"
"fmt"
"gcy_hpc_server/internal/model"
"gcy_hpc_server/internal/store"
"go.uber.org/zap"
)
// FolderService provides CRUD operations for folders with path validation
// and directory tree management.
type FolderService struct {
folderStore *store.FolderStore
fileStore *store.FileStore
logger *zap.Logger
}
// NewFolderService creates a new FolderService.
func NewFolderService(folderStore *store.FolderStore, fileStore *store.FileStore, logger *zap.Logger) *FolderService {
return &FolderService{
folderStore: folderStore,
fileStore: fileStore,
logger: logger,
}
}
// CreateFolder validates the name, computes a materialized path, checks for
// duplicates, and persists the folder.
func (s *FolderService) CreateFolder(ctx context.Context, name string, parentID *int64) (*model.FolderResponse, error) {
if err := model.ValidateFolderName(name); err != nil {
return nil, fmt.Errorf("invalid folder name: %w", err)
}
var path string
if parentID == nil {
path = "/" + name + "/"
} else {
parent, err := s.folderStore.GetByID(ctx, *parentID)
if err != nil {
return nil, fmt.Errorf("get parent folder: %w", err)
}
if parent == nil {
return nil, fmt.Errorf("parent folder %d not found", *parentID)
}
path = parent.Path + name + "/"
}
existing, err := s.folderStore.GetByPath(ctx, path)
if err != nil {
return nil, fmt.Errorf("check duplicate path: %w", err)
}
if existing != nil {
return nil, fmt.Errorf("folder with path %q already exists", path)
}
folder := &model.Folder{
Name: name,
ParentID: parentID,
Path: path,
}
if err := s.folderStore.Create(ctx, folder); err != nil {
return nil, fmt.Errorf("create folder: %w", err)
}
return s.toFolderResponse(ctx, folder)
}
// GetFolder retrieves a folder by ID with file and subfolder counts.
func (s *FolderService) GetFolder(ctx context.Context, id int64) (*model.FolderResponse, error) {
folder, err := s.folderStore.GetByID(ctx, id)
if err != nil {
return nil, fmt.Errorf("get folder: %w", err)
}
if folder == nil {
return nil, fmt.Errorf("folder %d not found", id)
}
return s.toFolderResponse(ctx, folder)
}
// ListFolders returns all direct children of the given parent folder (or root
// if parentID is nil).
func (s *FolderService) ListFolders(ctx context.Context, parentID *int64) ([]model.FolderResponse, error) {
folders, err := s.folderStore.ListByParentID(ctx, parentID)
if err != nil {
return nil, fmt.Errorf("list folders: %w", err)
}
result := make([]model.FolderResponse, 0, len(folders))
for i := range folders {
resp, err := s.toFolderResponse(ctx, &folders[i])
if err != nil {
return nil, err
}
result = append(result, *resp)
}
return result, nil
}
// DeleteFolder soft-deletes a folder only if it has no children (sub-folders
// or files).
func (s *FolderService) DeleteFolder(ctx context.Context, id int64) error {
hasChildren, err := s.folderStore.HasChildren(ctx, id)
if err != nil {
return fmt.Errorf("check children: %w", err)
}
if hasChildren {
return fmt.Errorf("folder is not empty")
}
if err := s.folderStore.Delete(ctx, id); err != nil {
return fmt.Errorf("delete folder: %w", err)
}
return nil
}
// toFolderResponse converts a Folder model into a FolderResponse DTO with
// computed file and subfolder counts.
func (s *FolderService) toFolderResponse(ctx context.Context, f *model.Folder) (*model.FolderResponse, error) {
subFolders, err := s.folderStore.ListByParentID(ctx, &f.ID)
if err != nil {
return nil, fmt.Errorf("count subfolders: %w", err)
}
_, fileCount, err := s.fileStore.List(ctx, &f.ID, 1, 1)
if err != nil {
return nil, fmt.Errorf("count files: %w", err)
}
return &model.FolderResponse{
ID: f.ID,
Name: f.Name,
ParentID: f.ParentID,
Path: f.Path,
FileCount: fileCount,
SubFolderCount: int64(len(subFolders)),
CreatedAt: f.CreatedAt,
}, nil
}

View File

@@ -0,0 +1,230 @@
package service
import (
"context"
"strings"
"testing"
"gcy_hpc_server/internal/model"
"gcy_hpc_server/internal/store"
"go.uber.org/zap"
gormlogger "gorm.io/gorm/logger"
"gorm.io/driver/sqlite"
"gorm.io/gorm"
)
func setupFolderService(t *testing.T) *FolderService {
t.Helper()
db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{
Logger: gormlogger.Default.LogMode(gormlogger.Silent),
})
if err != nil {
t.Fatalf("open sqlite: %v", err)
}
if err := db.AutoMigrate(&model.Folder{}, &model.File{}); err != nil {
t.Fatalf("auto migrate: %v", err)
}
return NewFolderService(
store.NewFolderStore(db),
store.NewFileStore(db),
zap.NewNop(),
)
}
func TestCreateFolder_ValidName(t *testing.T) {
svc := setupFolderService(t)
resp, err := svc.CreateFolder(context.Background(), "datasets", nil)
if err != nil {
t.Fatalf("CreateFolder() error = %v", err)
}
if resp.Name != "datasets" {
t.Errorf("Name = %q, want %q", resp.Name, "datasets")
}
if resp.Path != "/datasets/" {
t.Errorf("Path = %q, want %q", resp.Path, "/datasets/")
}
if resp.ParentID != nil {
t.Errorf("ParentID should be nil for root folder, got %d", *resp.ParentID)
}
}
func TestCreateFolder_SubFolder(t *testing.T) {
svc := setupFolderService(t)
parent, err := svc.CreateFolder(context.Background(), "datasets", nil)
if err != nil {
t.Fatalf("create parent: %v", err)
}
child, err := svc.CreateFolder(context.Background(), "images", &parent.ID)
if err != nil {
t.Fatalf("CreateFolder() error = %v", err)
}
if child.Path != "/datasets/images/" {
t.Errorf("Path = %q, want %q", child.Path, "/datasets/images/")
}
if child.ParentID == nil || *child.ParentID != parent.ID {
t.Errorf("ParentID = %v, want %d", child.ParentID, parent.ID)
}
}
func TestCreateFolder_RejectPathTraversal(t *testing.T) {
svc := setupFolderService(t)
for _, name := range []string{"..", "../etc", "/absolute", "a/b"} {
_, err := svc.CreateFolder(context.Background(), name, nil)
if err == nil {
t.Errorf("expected error for folder name %q, got nil", name)
}
}
}
func TestCreateFolder_DuplicatePath(t *testing.T) {
svc := setupFolderService(t)
_, err := svc.CreateFolder(context.Background(), "datasets", nil)
if err != nil {
t.Fatalf("first create: %v", err)
}
_, err = svc.CreateFolder(context.Background(), "datasets", nil)
if err == nil {
t.Fatal("expected error for duplicate folder name")
}
if !strings.Contains(err.Error(), "already exists") {
t.Errorf("error should mention 'already exists', got: %v", err)
}
}
func TestCreateFolder_ParentNotFound(t *testing.T) {
svc := setupFolderService(t)
badID := int64(99999)
_, err := svc.CreateFolder(context.Background(), "orphan", &badID)
if err == nil {
t.Fatal("expected error for non-existent parent")
}
if !strings.Contains(err.Error(), "not found") {
t.Errorf("error should mention 'not found', got: %v", err)
}
}
func TestGetFolder(t *testing.T) {
svc := setupFolderService(t)
created, err := svc.CreateFolder(context.Background(), "datasets", nil)
if err != nil {
t.Fatalf("CreateFolder() error = %v", err)
}
resp, err := svc.GetFolder(context.Background(), created.ID)
if err != nil {
t.Fatalf("GetFolder() error = %v", err)
}
if resp.ID != created.ID {
t.Errorf("ID = %d, want %d", resp.ID, created.ID)
}
if resp.Name != "datasets" {
t.Errorf("Name = %q, want %q", resp.Name, "datasets")
}
if resp.FileCount != 0 {
t.Errorf("FileCount = %d, want 0", resp.FileCount)
}
if resp.SubFolderCount != 0 {
t.Errorf("SubFolderCount = %d, want 0", resp.SubFolderCount)
}
}
func TestGetFolder_NotFound(t *testing.T) {
svc := setupFolderService(t)
_, err := svc.GetFolder(context.Background(), 99999)
if err == nil {
t.Fatal("expected error for non-existent folder")
}
if !strings.Contains(err.Error(), "not found") {
t.Errorf("error should mention 'not found', got: %v", err)
}
}
func TestListFolders(t *testing.T) {
svc := setupFolderService(t)
parent, err := svc.CreateFolder(context.Background(), "root", nil)
if err != nil {
t.Fatalf("create root: %v", err)
}
_, err = svc.CreateFolder(context.Background(), "child1", &parent.ID)
if err != nil {
t.Fatalf("create child1: %v", err)
}
_, err = svc.CreateFolder(context.Background(), "child2", &parent.ID)
if err != nil {
t.Fatalf("create child2: %v", err)
}
list, err := svc.ListFolders(context.Background(), &parent.ID)
if err != nil {
t.Fatalf("ListFolders() error = %v", err)
}
if len(list) != 2 {
t.Fatalf("len(list) = %d, want 2", len(list))
}
names := make(map[string]bool)
for _, f := range list {
names[f.Name] = true
}
if !names["child1"] || !names["child2"] {
t.Errorf("expected child1 and child2, got %v", names)
}
}
func TestListFolders_Root(t *testing.T) {
svc := setupFolderService(t)
_, err := svc.CreateFolder(context.Background(), "alpha", nil)
if err != nil {
t.Fatalf("create alpha: %v", err)
}
_, err = svc.CreateFolder(context.Background(), "beta", nil)
if err != nil {
t.Fatalf("create beta: %v", err)
}
list, err := svc.ListFolders(context.Background(), nil)
if err != nil {
t.Fatalf("ListFolders() error = %v", err)
}
if len(list) != 2 {
t.Errorf("len(list) = %d, want 2", len(list))
}
}
func TestDeleteFolder_Success(t *testing.T) {
svc := setupFolderService(t)
created, err := svc.CreateFolder(context.Background(), "temp", nil)
if err != nil {
t.Fatalf("CreateFolder() error = %v", err)
}
if err := svc.DeleteFolder(context.Background(), created.ID); err != nil {
t.Fatalf("DeleteFolder() error = %v", err)
}
_, err = svc.GetFolder(context.Background(), created.ID)
if err == nil {
t.Error("expected error after deletion")
}
}
func TestDeleteFolder_NonEmpty(t *testing.T) {
svc := setupFolderService(t)
parent, err := svc.CreateFolder(context.Background(), "haschild", nil)
if err != nil {
t.Fatalf("create parent: %v", err)
}
_, err = svc.CreateFolder(context.Background(), "child", &parent.ID)
if err != nil {
t.Fatalf("create child: %v", err)
}
err = svc.DeleteFolder(context.Background(), parent.ID)
if err == nil {
t.Fatal("expected error when deleting non-empty folder")
}
if !strings.Contains(err.Error(), "not empty") {
t.Errorf("error should mention 'not empty', got: %v", err)
}
}

View File

@@ -0,0 +1,496 @@
package service
import (
"context"
"fmt"
"strconv"
"time"
"gcy_hpc_server/internal/model"
"gcy_hpc_server/internal/slurm"
"go.uber.org/zap"
)
// JobService wraps Slurm SDK job operations with model mapping and pagination.
type JobService struct {
client *slurm.Client
logger *zap.Logger
}
// NewJobService creates a new JobService with the given Slurm SDK client.
func NewJobService(client *slurm.Client, logger *zap.Logger) *JobService {
return &JobService{client: client, logger: logger}
}
// SubmitJob submits a new job to Slurm and returns the job ID.
func (s *JobService) SubmitJob(ctx context.Context, req *model.SubmitJobRequest) (*model.JobResponse, error) {
script := req.Script
jobDesc := &slurm.JobDescMsg{
Script: &script,
Partition: strToPtrOrNil(req.Partition),
Qos: strToPtrOrNil(req.QOS),
Name: strToPtrOrNil(req.JobName),
}
if req.WorkDir != "" {
jobDesc.CurrentWorkingDirectory = &req.WorkDir
}
if req.CPUs > 0 {
jobDesc.MinimumCpus = slurm.Ptr(req.CPUs)
}
if req.TimeLimit != "" {
if mins, err := strconv.ParseInt(req.TimeLimit, 10, 64); err == nil {
jobDesc.TimeLimit = &slurm.Uint32NoVal{Number: &mins}
}
}
jobDesc.Environment = slurm.StringArray{
"PATH=/usr/local/bin:/usr/bin:/bin",
"HOME=/root",
}
submitReq := &slurm.JobSubmitReq{
Script: &script,
Job: jobDesc,
}
s.logger.Debug("slurm API request",
zap.String("operation", "SubmitJob"),
zap.Any("body", submitReq),
)
start := time.Now()
result, _, err := s.client.Jobs.SubmitJob(ctx, submitReq)
took := time.Since(start)
if err != nil {
s.logger.Debug("slurm API error response",
zap.String("operation", "SubmitJob"),
zap.Duration("took", took),
zap.Error(err),
)
s.logger.Error("failed to submit job", zap.Error(err), zap.String("operation", "submit"))
return nil, fmt.Errorf("submit job: %w", err)
}
s.logger.Debug("slurm API response",
zap.String("operation", "SubmitJob"),
zap.Duration("took", took),
zap.Any("body", result),
)
resp := &model.JobResponse{}
if result.Result != nil && result.Result.JobID != nil {
resp.JobID = *result.Result.JobID
} else if result.JobID != nil {
resp.JobID = *result.JobID
}
s.logger.Info("job submitted", zap.String("job_name", req.JobName), zap.Int32("job_id", resp.JobID))
return resp, nil
}
// GetJobs lists all current jobs from Slurm with in-memory pagination.
func (s *JobService) GetJobs(ctx context.Context, query *model.JobListQuery) (*model.JobListResponse, error) {
s.logger.Debug("slurm API request",
zap.String("operation", "GetJobs"),
)
start := time.Now()
result, _, err := s.client.Jobs.GetJobs(ctx, nil)
took := time.Since(start)
if err != nil {
s.logger.Debug("slurm API error response",
zap.String("operation", "GetJobs"),
zap.Duration("took", took),
zap.Error(err),
)
s.logger.Error("failed to get jobs", zap.Error(err), zap.String("operation", "get_jobs"))
return nil, fmt.Errorf("get jobs: %w", err)
}
s.logger.Debug("slurm API response",
zap.String("operation", "GetJobs"),
zap.Duration("took", took),
zap.Int("job_count", len(result.Jobs)),
zap.Any("body", result),
)
allJobs := make([]model.JobResponse, 0, len(result.Jobs))
for i := range result.Jobs {
allJobs = append(allJobs, mapJobInfo(&result.Jobs[i]))
}
total := len(allJobs)
page := query.Page
pageSize := query.PageSize
if page < 1 {
page = 1
}
if pageSize < 1 {
pageSize = 20
}
startIdx := (page - 1) * pageSize
end := startIdx + pageSize
if startIdx > total {
startIdx = total
}
if end > total {
end = total
}
return &model.JobListResponse{
Jobs: allJobs[startIdx:end],
Total: total,
Page: page,
PageSize: pageSize,
}, nil
}
// GetJob retrieves a single job by ID. If the job is not found in the active
// queue (404 or empty result), it falls back to querying SlurmDBD history.
func (s *JobService) GetJob(ctx context.Context, jobID string) (*model.JobResponse, error) {
s.logger.Debug("slurm API request",
zap.String("operation", "GetJob"),
zap.String("job_id", jobID),
)
start := time.Now()
result, _, err := s.client.Jobs.GetJob(ctx, jobID, nil)
took := time.Since(start)
if err != nil {
if slurm.IsNotFound(err) {
s.logger.Debug("job not in active queue, querying history",
zap.String("job_id", jobID),
)
return s.getJobFromHistory(ctx, jobID)
}
s.logger.Debug("slurm API error response",
zap.String("operation", "GetJob"),
zap.String("job_id", jobID),
zap.Duration("took", took),
zap.Error(err),
)
s.logger.Error("failed to get job", zap.Error(err), zap.String("job_id", jobID), zap.String("operation", "get_job"))
return nil, fmt.Errorf("get job %s: %w", jobID, err)
}
s.logger.Debug("slurm API response",
zap.String("operation", "GetJob"),
zap.String("job_id", jobID),
zap.Duration("took", took),
zap.Any("body", result),
)
if len(result.Jobs) == 0 {
s.logger.Debug("empty jobs response, querying history",
zap.String("job_id", jobID),
)
return s.getJobFromHistory(ctx, jobID)
}
resp := mapJobInfo(&result.Jobs[0])
return &resp, nil
}
func (s *JobService) getJobFromHistory(ctx context.Context, jobID string) (*model.JobResponse, error) {
start := time.Now()
result, _, err := s.client.SlurmdbJobs.GetJob(ctx, jobID)
took := time.Since(start)
if err != nil {
s.logger.Debug("slurmdb API error response",
zap.String("operation", "getJobFromHistory"),
zap.String("job_id", jobID),
zap.Duration("took", took),
zap.Error(err),
)
if slurm.IsNotFound(err) {
return nil, nil
}
return nil, fmt.Errorf("get job history %s: %w", jobID, err)
}
s.logger.Debug("slurmdb API response",
zap.String("operation", "getJobFromHistory"),
zap.String("job_id", jobID),
zap.Duration("took", took),
zap.Any("body", result),
)
if len(result.Jobs) == 0 {
return nil, nil
}
resp := mapSlurmdbJob(&result.Jobs[0])
return &resp, nil
}
// CancelJob cancels a job by ID.
func (s *JobService) CancelJob(ctx context.Context, jobID string) error {
s.logger.Debug("slurm API request",
zap.String("operation", "CancelJob"),
zap.String("job_id", jobID),
)
start := time.Now()
result, _, err := s.client.Jobs.DeleteJob(ctx, jobID, nil)
took := time.Since(start)
if err != nil {
s.logger.Debug("slurm API error response",
zap.String("operation", "CancelJob"),
zap.String("job_id", jobID),
zap.Duration("took", took),
zap.Error(err),
)
s.logger.Error("failed to cancel job", zap.Error(err), zap.String("job_id", jobID), zap.String("operation", "cancel"))
return fmt.Errorf("cancel job %s: %w", jobID, err)
}
s.logger.Debug("slurm API response",
zap.String("operation", "CancelJob"),
zap.String("job_id", jobID),
zap.Duration("took", took),
zap.Any("body", result),
)
s.logger.Info("job cancelled", zap.String("job_id", jobID))
return nil
}
// GetJobHistory queries SlurmDBD for historical jobs with pagination.
func (s *JobService) GetJobHistory(ctx context.Context, query *model.JobHistoryQuery) (*model.JobListResponse, error) {
opts := &slurm.GetSlurmdbJobsOptions{}
if query.Users != "" {
opts.Users = strToPtr(query.Users)
}
if query.Account != "" {
opts.Account = strToPtr(query.Account)
}
if query.Partition != "" {
opts.Partition = strToPtr(query.Partition)
}
if query.State != "" {
opts.State = strToPtr(query.State)
}
if query.JobName != "" {
opts.JobName = strToPtr(query.JobName)
}
if query.StartTime != "" {
opts.StartTime = strToPtr(query.StartTime)
}
if query.EndTime != "" {
opts.EndTime = strToPtr(query.EndTime)
}
if query.SubmitTime != "" {
opts.SubmitTime = strToPtr(query.SubmitTime)
}
if query.Cluster != "" {
opts.Cluster = strToPtr(query.Cluster)
}
if query.Qos != "" {
opts.Qos = strToPtr(query.Qos)
}
if query.Constraints != "" {
opts.Constraints = strToPtr(query.Constraints)
}
if query.ExitCode != "" {
opts.ExitCode = strToPtr(query.ExitCode)
}
if query.Node != "" {
opts.Node = strToPtr(query.Node)
}
if query.Reservation != "" {
opts.Reservation = strToPtr(query.Reservation)
}
if query.Groups != "" {
opts.Groups = strToPtr(query.Groups)
}
if query.Wckey != "" {
opts.Wckey = strToPtr(query.Wckey)
}
s.logger.Debug("slurm API request",
zap.String("operation", "GetJobHistory"),
zap.Any("body", opts),
)
start := time.Now()
result, _, err := s.client.SlurmdbJobs.GetJobs(ctx, opts)
took := time.Since(start)
if err != nil {
s.logger.Debug("slurm API error response",
zap.String("operation", "GetJobHistory"),
zap.Duration("took", took),
zap.Error(err),
)
s.logger.Error("failed to get job history", zap.Error(err), zap.String("operation", "get_job_history"))
return nil, fmt.Errorf("get job history: %w", err)
}
s.logger.Debug("slurm API response",
zap.String("operation", "GetJobHistory"),
zap.Duration("took", took),
zap.Int("job_count", len(result.Jobs)),
zap.Any("body", result),
)
allJobs := make([]model.JobResponse, 0, len(result.Jobs))
for i := range result.Jobs {
allJobs = append(allJobs, mapSlurmdbJob(&result.Jobs[i]))
}
total := len(allJobs)
page := query.Page
pageSize := query.PageSize
if page < 1 {
page = 1
}
if pageSize < 1 {
pageSize = 20
}
startIdx := (page - 1) * pageSize
end := startIdx + pageSize
if startIdx > total {
startIdx = total
}
if end > total {
end = total
}
return &model.JobListResponse{
Jobs: allJobs[startIdx:end],
Total: total,
Page: page,
PageSize: pageSize,
}, nil
}
// ---------------------------------------------------------------------------
// Helper functions
// ---------------------------------------------------------------------------
func strToPtr(s string) *string { return &s }
// strPtrOrNil returns a pointer to s if non-empty, otherwise nil.
func strToPtrOrNil(s string) *string {
if s == "" {
return nil
}
return &s
}
func mapUint32NoValToInt32(v *slurm.Uint32NoVal) *int32 {
if v != nil && v.Number != nil {
n := int32(*v.Number)
return &n
}
return nil
}
// mapJobInfo maps SDK JobInfo to API JobResponse.
func mapJobInfo(ji *slurm.JobInfo) model.JobResponse {
resp := model.JobResponse{}
if ji.JobID != nil {
resp.JobID = *ji.JobID
}
if ji.Name != nil {
resp.Name = *ji.Name
}
resp.State = ji.JobState
if ji.Partition != nil {
resp.Partition = *ji.Partition
}
resp.Account = derefStr(ji.Account)
resp.User = derefStr(ji.UserName)
resp.Cluster = derefStr(ji.Cluster)
resp.QOS = derefStr(ji.Qos)
resp.Priority = mapUint32NoValToInt32(ji.Priority)
resp.TimeLimit = uint32NoValString(ji.TimeLimit)
resp.StateReason = derefStr(ji.StateReason)
resp.Cpus = mapUint32NoValToInt32(ji.Cpus)
resp.Tasks = mapUint32NoValToInt32(ji.Tasks)
resp.NodeCount = mapUint32NoValToInt32(ji.NodeCount)
resp.BatchHost = derefStr(ji.BatchHost)
if ji.SubmitTime != nil && ji.SubmitTime.Number != nil {
resp.SubmitTime = ji.SubmitTime.Number
}
if ji.StartTime != nil && ji.StartTime.Number != nil {
resp.StartTime = ji.StartTime.Number
}
if ji.EndTime != nil && ji.EndTime.Number != nil {
resp.EndTime = ji.EndTime.Number
}
if ji.ExitCode != nil && ji.ExitCode.ReturnCode != nil && ji.ExitCode.ReturnCode.Number != nil {
code := int32(*ji.ExitCode.ReturnCode.Number)
resp.ExitCode = &code
}
if ji.Nodes != nil {
resp.Nodes = *ji.Nodes
}
resp.StdOut = derefStr(ji.StandardOutput)
resp.StdErr = derefStr(ji.StandardError)
resp.StdIn = derefStr(ji.StandardInput)
resp.WorkDir = derefStr(ji.CurrentWorkingDirectory)
resp.Command = derefStr(ji.Command)
resp.ArrayJobID = mapUint32NoValToInt32(ji.ArrayJobID)
resp.ArrayTaskID = mapUint32NoValToInt32(ji.ArrayTaskID)
return resp
}
// mapSlurmdbJob maps SDK SlurmDBD Job to API JobResponse.
func mapSlurmdbJob(j *slurm.Job) model.JobResponse {
resp := model.JobResponse{}
if j.JobID != nil {
resp.JobID = *j.JobID
}
if j.Name != nil {
resp.Name = *j.Name
}
if j.State != nil {
resp.State = j.State.Current
resp.StateReason = derefStr(j.State.Reason)
}
if j.Partition != nil {
resp.Partition = *j.Partition
}
resp.Account = derefStr(j.Account)
if j.User != nil {
resp.User = *j.User
}
resp.Cluster = derefStr(j.Cluster)
resp.QOS = derefStr(j.Qos)
resp.Priority = mapUint32NoValToInt32(j.Priority)
if j.Time != nil {
resp.TimeLimit = uint32NoValString(j.Time.Limit)
if j.Time.Submission != nil {
resp.SubmitTime = j.Time.Submission
}
if j.Time.Start != nil {
resp.StartTime = j.Time.Start
}
if j.Time.End != nil {
resp.EndTime = j.Time.End
}
}
if j.ExitCode != nil && j.ExitCode.ReturnCode != nil && j.ExitCode.ReturnCode.Number != nil {
code := int32(*j.ExitCode.ReturnCode.Number)
resp.ExitCode = &code
}
if j.Nodes != nil {
resp.Nodes = *j.Nodes
}
if j.Required != nil {
resp.Cpus = j.Required.CPUs
}
if j.AllocationNodes != nil {
resp.NodeCount = j.AllocationNodes
}
resp.WorkDir = derefStr(j.WorkingDirectory)
return resp
}

View File

@@ -0,0 +1,867 @@
package service
import (
"context"
"encoding/json"
"net/http"
"net/http/httptest"
"strings"
"testing"
"gcy_hpc_server/internal/model"
"gcy_hpc_server/internal/slurm"
"go.uber.org/zap"
"go.uber.org/zap/zapcore"
"go.uber.org/zap/zaptest/observer"
)
func mockJobServer(handler http.HandlerFunc) (*slurm.Client, func()) {
srv := httptest.NewServer(handler)
client, _ := slurm.NewClient(srv.URL, srv.Client())
return client, srv.Close
}
func TestSubmitJob(t *testing.T) {
jobID := int32(123)
client, cleanup := mockJobServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
t.Errorf("expected POST, got %s", r.Method)
}
if r.URL.Path != "/slurm/v0.0.40/job/submit" {
t.Errorf("unexpected path: %s", r.URL.Path)
}
var body slurm.JobSubmitReq
if err := json.NewDecoder(r.Body).Decode(&body); err != nil {
t.Fatalf("decode body: %v", err)
}
if body.Job == nil || body.Job.Script == nil || *body.Job.Script != "#!/bin/bash\necho hello" {
t.Errorf("unexpected script in request body")
}
resp := slurm.OpenapiJobSubmitResponse{
Result: &slurm.JobSubmitResponseMsg{
JobID: &jobID,
},
}
json.NewEncoder(w).Encode(resp)
}))
defer cleanup()
svc := NewJobService(client, zap.NewNop())
resp, err := svc.SubmitJob(context.Background(), &model.SubmitJobRequest{
Script: "#!/bin/bash\necho hello",
Partition: "normal",
QOS: "high",
JobName: "test-job",
CPUs: 4,
TimeLimit: "60",
})
if err != nil {
t.Fatalf("SubmitJob: %v", err)
}
if resp.JobID != 123 {
t.Errorf("expected JobID 123, got %d", resp.JobID)
}
}
func TestSubmitJob_WithOptionalFields(t *testing.T) {
client, cleanup := mockJobServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
var body slurm.JobSubmitReq
if err := json.NewDecoder(r.Body).Decode(&body); err != nil {
t.Fatalf("decode body: %v", err)
}
if body.Job == nil {
t.Fatal("job desc is nil")
}
if body.Job.Partition != nil {
t.Error("expected partition nil for empty string")
}
if body.Job.MinimumCpus != nil {
t.Error("expected minimum_cpus nil when CPUs=0")
}
jobID := int32(456)
resp := slurm.OpenapiJobSubmitResponse{
Result: &slurm.JobSubmitResponseMsg{JobID: &jobID},
}
json.NewEncoder(w).Encode(resp)
}))
defer cleanup()
svc := NewJobService(client, zap.NewNop())
resp, err := svc.SubmitJob(context.Background(), &model.SubmitJobRequest{
Script: "echo hi",
})
if err != nil {
t.Fatalf("SubmitJob: %v", err)
}
if resp.JobID != 456 {
t.Errorf("expected JobID 456, got %d", resp.JobID)
}
}
func TestSubmitJob_Error(t *testing.T) {
client, cleanup := mockJobServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusInternalServerError)
w.Write([]byte(`{"error":"internal"}`))
}))
defer cleanup()
svc := NewJobService(client, zap.NewNop())
_, err := svc.SubmitJob(context.Background(), &model.SubmitJobRequest{
Script: "echo fail",
})
if err == nil {
t.Fatal("expected error, got nil")
}
}
func TestGetJobs(t *testing.T) {
jobID := int32(100)
name := "my-job"
partition := "gpu"
ts := int64(1700000000)
nodes := "node01"
client, cleanup := mockJobServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodGet {
t.Errorf("expected GET, got %s", r.Method)
}
resp := slurm.OpenapiJobInfoResp{
Jobs: slurm.JobInfoMsg{
{
JobID: &jobID,
Name: &name,
JobState: []string{"RUNNING"},
Partition: &partition,
SubmitTime: &slurm.Uint64NoVal{Number: &ts},
StartTime: &slurm.Uint64NoVal{Number: &ts},
EndTime: &slurm.Uint64NoVal{Number: &ts},
Nodes: &nodes,
},
},
}
json.NewEncoder(w).Encode(resp)
}))
defer cleanup()
svc := NewJobService(client, zap.NewNop())
result, err := svc.GetJobs(context.Background(), &model.JobListQuery{Page: 1, PageSize: 20})
if err != nil {
t.Fatalf("GetJobs: %v", err)
}
if result.Total != 1 {
t.Fatalf("expected total 1, got %d", result.Total)
}
if len(result.Jobs) != 1 {
t.Fatalf("expected 1 job, got %d", len(result.Jobs))
}
j := result.Jobs[0]
if j.JobID != 100 {
t.Errorf("expected JobID 100, got %d", j.JobID)
}
if j.Name != "my-job" {
t.Errorf("expected Name my-job, got %s", j.Name)
}
if len(j.State) != 1 || j.State[0] != "RUNNING" {
t.Errorf("expected State [RUNNING], got %v", j.State)
}
if j.Partition != "gpu" {
t.Errorf("expected Partition gpu, got %s", j.Partition)
}
if j.SubmitTime == nil || *j.SubmitTime != ts {
t.Errorf("expected SubmitTime %d, got %v", ts, j.SubmitTime)
}
if j.Nodes != "node01" {
t.Errorf("expected Nodes node01, got %s", j.Nodes)
}
if result.Page != 1 {
t.Errorf("expected Page 1, got %d", result.Page)
}
if result.PageSize != 20 {
t.Errorf("expected PageSize 20, got %d", result.PageSize)
}
}
func TestGetJob(t *testing.T) {
jobID := int32(200)
name := "single-job"
client, cleanup := mockJobServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
resp := slurm.OpenapiJobInfoResp{
Jobs: slurm.JobInfoMsg{
{
JobID: &jobID,
Name: &name,
},
},
}
json.NewEncoder(w).Encode(resp)
}))
defer cleanup()
svc := NewJobService(client, zap.NewNop())
job, err := svc.GetJob(context.Background(), "200")
if err != nil {
t.Fatalf("GetJob: %v", err)
}
if job == nil {
t.Fatal("expected job, got nil")
}
if job.JobID != 200 {
t.Errorf("expected JobID 200, got %d", job.JobID)
}
}
func TestGetJob_NotFound(t *testing.T) {
client, cleanup := mockJobServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
resp := slurm.OpenapiJobInfoResp{Jobs: slurm.JobInfoMsg{}}
json.NewEncoder(w).Encode(resp)
}))
defer cleanup()
svc := NewJobService(client, zap.NewNop())
job, err := svc.GetJob(context.Background(), "999")
if err != nil {
t.Fatalf("GetJob: %v", err)
}
if job != nil {
t.Errorf("expected nil for not found, got %+v", job)
}
}
func TestCancelJob(t *testing.T) {
client, cleanup := mockJobServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodDelete {
t.Errorf("expected DELETE, got %s", r.Method)
}
resp := slurm.OpenapiResp{}
json.NewEncoder(w).Encode(resp)
}))
defer cleanup()
svc := NewJobService(client, zap.NewNop())
err := svc.CancelJob(context.Background(), "300")
if err != nil {
t.Fatalf("CancelJob: %v", err)
}
}
func TestCancelJob_Error(t *testing.T) {
client, cleanup := mockJobServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusNotFound)
w.Write([]byte(`not found`))
}))
defer cleanup()
svc := NewJobService(client, zap.NewNop())
err := svc.CancelJob(context.Background(), "999")
if err == nil {
t.Fatal("expected error, got nil")
}
}
func TestGetJobHistory(t *testing.T) {
jobID1 := int32(10)
jobID2 := int32(20)
jobID3 := int32(30)
name1 := "hist-1"
name2 := "hist-2"
name3 := "hist-3"
submission1 := int64(1700000000)
submission2 := int64(1700001000)
submission3 := int64(1700002000)
partition := "normal"
client, cleanup := mockJobServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodGet {
t.Errorf("expected GET, got %s", r.Method)
}
users := r.URL.Query().Get("users")
if users != "testuser" {
t.Errorf("expected users=testuser, got %s", users)
}
resp := slurm.OpenapiSlurmdbdJobsResp{
Jobs: slurm.JobList{
{
JobID: &jobID1,
Name: &name1,
Partition: &partition,
State: &slurm.JobState{Current: []string{"COMPLETED"}},
Time: &slurm.JobTime{Submission: &submission1},
},
{
JobID: &jobID2,
Name: &name2,
Partition: &partition,
State: &slurm.JobState{Current: []string{"FAILED"}},
Time: &slurm.JobTime{Submission: &submission2},
},
{
JobID: &jobID3,
Name: &name3,
Partition: &partition,
State: &slurm.JobState{Current: []string{"CANCELLED"}},
Time: &slurm.JobTime{Submission: &submission3},
},
},
}
json.NewEncoder(w).Encode(resp)
}))
defer cleanup()
svc := NewJobService(client, zap.NewNop())
result, err := svc.GetJobHistory(context.Background(), &model.JobHistoryQuery{
Users: "testuser",
Page: 1,
PageSize: 2,
})
if err != nil {
t.Fatalf("GetJobHistory: %v", err)
}
if result.Total != 3 {
t.Errorf("expected Total 3, got %d", result.Total)
}
if result.Page != 1 {
t.Errorf("expected Page 1, got %d", result.Page)
}
if result.PageSize != 2 {
t.Errorf("expected PageSize 2, got %d", result.PageSize)
}
if len(result.Jobs) != 2 {
t.Fatalf("expected 2 jobs on page 1, got %d", len(result.Jobs))
}
if result.Jobs[0].JobID != 10 {
t.Errorf("expected first job ID 10, got %d", result.Jobs[0].JobID)
}
if result.Jobs[1].JobID != 20 {
t.Errorf("expected second job ID 20, got %d", result.Jobs[1].JobID)
}
if len(result.Jobs[0].State) != 1 || result.Jobs[0].State[0] != "COMPLETED" {
t.Errorf("expected state [COMPLETED], got %v", result.Jobs[0].State)
}
}
func TestGetJobHistory_Page2(t *testing.T) {
jobID1 := int32(10)
jobID2 := int32(20)
name1 := "a"
name2 := "b"
client, cleanup := mockJobServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
resp := slurm.OpenapiSlurmdbdJobsResp{
Jobs: slurm.JobList{
{JobID: &jobID1, Name: &name1},
{JobID: &jobID2, Name: &name2},
},
}
json.NewEncoder(w).Encode(resp)
}))
defer cleanup()
svc := NewJobService(client, zap.NewNop())
result, err := svc.GetJobHistory(context.Background(), &model.JobHistoryQuery{
Page: 2,
PageSize: 1,
})
if err != nil {
t.Fatalf("GetJobHistory: %v", err)
}
if result.Total != 2 {
t.Errorf("expected Total 2, got %d", result.Total)
}
if len(result.Jobs) != 1 {
t.Fatalf("expected 1 job on page 2, got %d", len(result.Jobs))
}
if result.Jobs[0].JobID != 20 {
t.Errorf("expected job ID 20, got %d", result.Jobs[0].JobID)
}
}
func TestGetJobHistory_DefaultPagination(t *testing.T) {
client, cleanup := mockJobServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
resp := slurm.OpenapiSlurmdbdJobsResp{Jobs: slurm.JobList{}}
json.NewEncoder(w).Encode(resp)
}))
defer cleanup()
svc := NewJobService(client, zap.NewNop())
result, err := svc.GetJobHistory(context.Background(), &model.JobHistoryQuery{})
if err != nil {
t.Fatalf("GetJobHistory: %v", err)
}
if result.Page != 1 {
t.Errorf("expected default page 1, got %d", result.Page)
}
if result.PageSize != 20 {
t.Errorf("expected default pageSize 20, got %d", result.PageSize)
}
}
func TestGetJobHistory_QueryMapping(t *testing.T) {
client, cleanup := mockJobServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
q := r.URL.Query()
if v := q.Get("account"); v != "proj1" {
t.Errorf("expected account=proj1, got %s", v)
}
if v := q.Get("partition"); v != "gpu" {
t.Errorf("expected partition=gpu, got %s", v)
}
if v := q.Get("state"); v != "COMPLETED" {
t.Errorf("expected state=COMPLETED, got %s", v)
}
if v := q.Get("job_name"); v != "myjob" {
t.Errorf("expected job_name=myjob, got %s", v)
}
if v := q.Get("start_time"); v != "1700000000" {
t.Errorf("expected start_time=1700000000, got %s", v)
}
if v := q.Get("end_time"); v != "1700099999" {
t.Errorf("expected end_time=1700099999, got %s", v)
}
resp := slurm.OpenapiSlurmdbdJobsResp{Jobs: slurm.JobList{}}
json.NewEncoder(w).Encode(resp)
}))
defer cleanup()
svc := NewJobService(client, zap.NewNop())
_, err := svc.GetJobHistory(context.Background(), &model.JobHistoryQuery{
Users: "testuser",
Account: "proj1",
Partition: "gpu",
State: "COMPLETED",
JobName: "myjob",
StartTime: "1700000000",
EndTime: "1700099999",
})
if err != nil {
t.Fatalf("GetJobHistory: %v", err)
}
}
func TestGetJobHistory_Error(t *testing.T) {
client, cleanup := mockJobServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusInternalServerError)
w.Write([]byte(`{"error":"db down"}`))
}))
defer cleanup()
svc := NewJobService(client, zap.NewNop())
_, err := svc.GetJobHistory(context.Background(), &model.JobHistoryQuery{})
if err == nil {
t.Fatal("expected error, got nil")
}
}
func TestMapJobInfo_ExitCode(t *testing.T) {
returnCode := int64(2)
ji := &slurm.JobInfo{
ExitCode: &slurm.ProcessExitCodeVerbose{
ReturnCode: &slurm.Uint32NoVal{Number: &returnCode},
},
}
resp := mapJobInfo(ji)
if resp.ExitCode == nil || *resp.ExitCode != 2 {
t.Errorf("expected exit code 2, got %v", resp.ExitCode)
}
}
func TestMapSlurmdbJob_NilFields(t *testing.T) {
j := &slurm.Job{}
resp := mapSlurmdbJob(j)
if resp.JobID != 0 {
t.Errorf("expected JobID 0, got %d", resp.JobID)
}
if resp.State != nil {
t.Errorf("expected nil State, got %v", resp.State)
}
if resp.SubmitTime != nil {
t.Errorf("expected nil SubmitTime, got %v", resp.SubmitTime)
}
}
// ---------------------------------------------------------------------------
// Structured logging tests using zaptest/observer
// ---------------------------------------------------------------------------
func newJobServiceWithObserver(srv *httptest.Server) (*JobService, *observer.ObservedLogs) {
core, recorded := observer.New(zapcore.DebugLevel)
l := zap.New(core)
client, _ := slurm.NewClient(srv.URL, srv.Client())
return NewJobService(client, l), recorded
}
func TestJobService_SubmitJob_SuccessLog(t *testing.T) {
jobID := int32(789)
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
resp := slurm.OpenapiJobSubmitResponse{
Result: &slurm.JobSubmitResponseMsg{JobID: &jobID},
}
json.NewEncoder(w).Encode(resp)
}))
defer srv.Close()
svc, recorded := newJobServiceWithObserver(srv)
_, err := svc.SubmitJob(context.Background(), &model.SubmitJobRequest{
Script: "echo hi",
JobName: "log-test-job",
})
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
entries := recorded.All()
if len(entries) != 3 {
t.Fatalf("expected 3 log entries, got %d", len(entries))
}
if entries[2].Level != zapcore.InfoLevel {
t.Errorf("expected InfoLevel, got %v", entries[2].Level)
}
fields := entries[2].ContextMap()
if fields["job_name"] != "log-test-job" {
t.Errorf("expected job_name=log-test-job, got %v", fields["job_name"])
}
gotJobID, ok := fields["job_id"]
if !ok {
t.Fatal("expected job_id field in log entry")
}
if gotJobID != int32(789) && gotJobID != int64(789) {
t.Errorf("expected job_id=789, got %v (%T)", gotJobID, gotJobID)
}
}
func TestJobService_SubmitJob_ErrorLog(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusInternalServerError)
w.Write([]byte(`{"error":"internal"}`))
}))
defer srv.Close()
svc, recorded := newJobServiceWithObserver(srv)
_, err := svc.SubmitJob(context.Background(), &model.SubmitJobRequest{Script: "echo fail"})
if err == nil {
t.Fatal("expected error, got nil")
}
entries := recorded.All()
if len(entries) != 3 {
t.Fatalf("expected 3 log entries, got %d", len(entries))
}
if entries[2].Level != zapcore.ErrorLevel {
t.Errorf("expected ErrorLevel, got %v", entries[2].Level)
}
fields := entries[2].ContextMap()
if fields["operation"] != "submit" {
t.Errorf("expected operation=submit, got %v", fields["operation"])
}
if _, ok := fields["error"]; !ok {
t.Error("expected error field in log entry")
}
}
func TestJobService_CancelJob_SuccessLog(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
resp := slurm.OpenapiResp{}
json.NewEncoder(w).Encode(resp)
}))
defer srv.Close()
svc, recorded := newJobServiceWithObserver(srv)
err := svc.CancelJob(context.Background(), "555")
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
entries := recorded.All()
if len(entries) != 3 {
t.Fatalf("expected 3 log entries, got %d", len(entries))
}
if entries[2].Level != zapcore.InfoLevel {
t.Errorf("expected InfoLevel, got %v", entries[2].Level)
}
fields := entries[2].ContextMap()
if fields["job_id"] != "555" {
t.Errorf("expected job_id=555, got %v", fields["job_id"])
}
}
func TestJobService_CancelJob_ErrorLog(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusNotFound)
w.Write([]byte(`not found`))
}))
defer srv.Close()
svc, recorded := newJobServiceWithObserver(srv)
err := svc.CancelJob(context.Background(), "999")
if err == nil {
t.Fatal("expected error, got nil")
}
entries := recorded.All()
if len(entries) != 3 {
t.Fatalf("expected 3 log entries, got %d", len(entries))
}
if entries[2].Level != zapcore.ErrorLevel {
t.Errorf("expected ErrorLevel, got %v", entries[2].Level)
}
fields := entries[2].ContextMap()
if fields["operation"] != "cancel" {
t.Errorf("expected operation=cancel, got %v", fields["operation"])
}
if fields["job_id"] != "999" {
t.Errorf("expected job_id=999, got %v", fields["job_id"])
}
if _, ok := fields["error"]; !ok {
t.Error("expected error field in log entry")
}
}
func TestJobService_GetJobs_ErrorLog(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusInternalServerError)
w.Write([]byte(`{"error":"down"}`))
}))
defer srv.Close()
svc, recorded := newJobServiceWithObserver(srv)
_, err := svc.GetJobs(context.Background(), &model.JobListQuery{Page: 1, PageSize: 20})
if err == nil {
t.Fatal("expected error, got nil")
}
entries := recorded.All()
if len(entries) != 3 {
t.Fatalf("expected 3 log entries, got %d", len(entries))
}
if entries[2].Level != zapcore.ErrorLevel {
t.Errorf("expected ErrorLevel, got %v", entries[2].Level)
}
fields := entries[2].ContextMap()
if fields["operation"] != "get_jobs" {
t.Errorf("expected operation=get_jobs, got %v", fields["operation"])
}
if _, ok := fields["error"]; !ok {
t.Error("expected error field in log entry")
}
}
func TestJobService_GetJob_ErrorLog(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusInternalServerError)
w.Write([]byte(`{"error":"down"}`))
}))
defer srv.Close()
svc, recorded := newJobServiceWithObserver(srv)
_, err := svc.GetJob(context.Background(), "200")
if err == nil {
t.Fatal("expected error, got nil")
}
entries := recorded.All()
if len(entries) != 3 {
t.Fatalf("expected 3 log entries, got %d", len(entries))
}
if entries[2].Level != zapcore.ErrorLevel {
t.Errorf("expected ErrorLevel, got %v", entries[2].Level)
}
fields := entries[2].ContextMap()
if fields["operation"] != "get_job" {
t.Errorf("expected operation=get_job, got %v", fields["operation"])
}
if fields["job_id"] != "200" {
t.Errorf("expected job_id=200, got %v", fields["job_id"])
}
if _, ok := fields["error"]; !ok {
t.Error("expected error field in log entry")
}
}
func TestJobService_GetJobHistory_ErrorLog(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusInternalServerError)
w.Write([]byte(`{"error":"db down"}`))
}))
defer srv.Close()
svc, recorded := newJobServiceWithObserver(srv)
_, err := svc.GetJobHistory(context.Background(), &model.JobHistoryQuery{})
if err == nil {
t.Fatal("expected error, got nil")
}
entries := recorded.All()
if len(entries) != 3 {
t.Fatalf("expected 3 log entries, got %d", len(entries))
}
if entries[2].Level != zapcore.ErrorLevel {
t.Errorf("expected ErrorLevel, got %v", entries[2].Level)
}
fields := entries[2].ContextMap()
if fields["operation"] != "get_job_history" {
t.Errorf("expected operation=get_job_history, got %v", fields["operation"])
}
if _, ok := fields["error"]; !ok {
t.Error("expected error field in log entry")
}
}
// ---------------------------------------------------------------------------
// Fallback to SlurmDBD history tests
// ---------------------------------------------------------------------------
func TestGetJob_FallbackToHistory_Found(t *testing.T) {
jobID := int32(198)
name := "hist-job"
ts := int64(1700000000)
client, cleanup := mockJobServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
switch r.URL.Path {
case "/slurm/v0.0.40/job/198":
w.WriteHeader(http.StatusNotFound)
json.NewEncoder(w).Encode(map[string]interface{}{
"errors": []map[string]interface{}{
{
"description": "Unable to query JobId=198",
"error_number": float64(2017),
"error": "Invalid job id specified",
"source": "_handle_job_get",
},
},
"jobs": []interface{}{},
})
case "/slurmdb/v0.0.40/job/198":
resp := slurm.OpenapiSlurmdbdJobsResp{
Jobs: slurm.JobList{
{
JobID: &jobID,
Name: &name,
State: &slurm.JobState{Current: []string{"COMPLETED"}},
Time: &slurm.JobTime{Submission: &ts, Start: &ts, End: &ts},
},
},
}
json.NewEncoder(w).Encode(resp)
default:
w.WriteHeader(http.StatusNotFound)
}
}))
defer cleanup()
svc := NewJobService(client, zap.NewNop())
job, err := svc.GetJob(context.Background(), "198")
if err != nil {
t.Fatalf("GetJob: %v", err)
}
if job == nil {
t.Fatal("expected job, got nil")
}
if job.JobID != 198 {
t.Errorf("expected JobID 198, got %d", job.JobID)
}
if job.Name != "hist-job" {
t.Errorf("expected Name hist-job, got %s", job.Name)
}
}
func TestGetJob_FallbackToHistory_NotFound(t *testing.T) {
client, cleanup := mockJobServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusNotFound)
}))
defer cleanup()
svc := NewJobService(client, zap.NewNop())
job, err := svc.GetJob(context.Background(), "999")
if err != nil {
t.Fatalf("GetJob: %v", err)
}
if job != nil {
t.Errorf("expected nil, got %+v", job)
}
}
func TestGetJob_FallbackToHistory_HistoryError(t *testing.T) {
client, cleanup := mockJobServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
switch r.URL.Path {
case "/slurm/v0.0.40/job/500":
w.WriteHeader(http.StatusNotFound)
json.NewEncoder(w).Encode(map[string]interface{}{
"errors": []map[string]interface{}{
{
"description": "Unable to query JobId=500",
"error_number": float64(2017),
"error": "Invalid job id specified",
"source": "_handle_job_get",
},
},
"jobs": []interface{}{},
})
case "/slurmdb/v0.0.40/job/500":
w.WriteHeader(http.StatusInternalServerError)
w.Write([]byte(`{"errors":[{"error":"db error"}]}`))
default:
w.WriteHeader(http.StatusNotFound)
}
}))
defer cleanup()
svc := NewJobService(client, zap.NewNop())
job, err := svc.GetJob(context.Background(), "500")
if err == nil {
t.Fatal("expected error, got nil")
}
if job != nil {
t.Errorf("expected nil job, got %+v", job)
}
if !strings.Contains(err.Error(), "get job history") {
t.Errorf("expected error to contain 'get job history', got %s", err.Error())
}
}
func TestGetJob_FallbackToHistory_EmptyHistory(t *testing.T) {
client, cleanup := mockJobServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
switch r.URL.Path {
case "/slurm/v0.0.40/job/777":
w.WriteHeader(http.StatusNotFound)
json.NewEncoder(w).Encode(map[string]interface{}{
"errors": []map[string]interface{}{
{
"description": "Unable to query JobId=777",
"error_number": float64(2017),
"error": "Invalid job id specified",
"source": "_handle_job_get",
},
},
"jobs": []interface{}{},
})
case "/slurmdb/v0.0.40/job/777":
resp := slurm.OpenapiSlurmdbdJobsResp{Jobs: slurm.JobList{}}
json.NewEncoder(w).Encode(resp)
default:
w.WriteHeader(http.StatusNotFound)
}
}))
defer cleanup()
svc := NewJobService(client, zap.NewNop())
job, err := svc.GetJob(context.Background(), "777")
if err != nil {
t.Fatalf("GetJob: %v", err)
}
if job != nil {
t.Errorf("expected nil, got %+v", job)
}
}

View File

@@ -0,0 +1,112 @@
package service
import (
"fmt"
"math/rand"
"regexp"
"sort"
"strconv"
"strings"
"gcy_hpc_server/internal/model"
)
var paramNameRegex = regexp.MustCompile(`^[A-Za-z_][A-Za-z0-9_]*$`)
// ValidateParams checks that all required parameters are present and values match their types.
// Parameters not in the schema are silently ignored.
func ValidateParams(params []model.ParameterSchema, values map[string]string) error {
var errs []string
for _, p := range params {
if !paramNameRegex.MatchString(p.Name) {
errs = append(errs, fmt.Sprintf("invalid parameter name %q: must match ^[A-Za-z_][A-Za-z0-9_]*$", p.Name))
continue
}
val, ok := values[p.Name]
if p.Required && !ok {
errs = append(errs, fmt.Sprintf("required parameter %q is missing", p.Name))
continue
}
if !ok {
continue
}
switch p.Type {
case model.ParamTypeInteger:
if _, err := strconv.Atoi(val); err != nil {
errs = append(errs, fmt.Sprintf("parameter %q must be an integer, got %q", p.Name, val))
}
case model.ParamTypeBoolean:
if val != "true" && val != "false" && val != "1" && val != "0" {
errs = append(errs, fmt.Sprintf("parameter %q must be a boolean (true/false/1/0), got %q", p.Name, val))
}
case model.ParamTypeEnum:
if len(p.Options) > 0 {
found := false
for _, opt := range p.Options {
if val == opt {
found = true
break
}
}
if !found {
errs = append(errs, fmt.Sprintf("parameter %q must be one of %v, got %q", p.Name, p.Options, val))
}
}
case model.ParamTypeFile, model.ParamTypeDirectory:
case model.ParamTypeString:
}
}
if len(errs) > 0 {
return fmt.Errorf("parameter validation failed: %s", strings.Join(errs, "; "))
}
return nil
}
// RenderScript replaces $PARAM tokens in the template with user-provided values.
// Only tokens defined in the schema are replaced. Replacement is done longest-name-first
// to avoid partial matches (e.g., $JOB_NAME before $JOB).
// All values are shell-escaped using single-quote wrapping.
func RenderScript(template string, params []model.ParameterSchema, values map[string]string) string {
sorted := make([]model.ParameterSchema, len(params))
copy(sorted, params)
sort.Slice(sorted, func(i, j int) bool {
return len(sorted[i].Name) > len(sorted[j].Name)
})
result := template
for _, p := range sorted {
val, ok := values[p.Name]
if !ok {
if p.Default != "" {
val = p.Default
} else {
continue
}
}
escaped := "'" + strings.ReplaceAll(val, "'", "'\\''") + "'"
result = strings.ReplaceAll(result, "$"+p.Name, escaped)
}
return result
}
// SanitizeDirName sanitizes a directory name.
func SanitizeDirName(name string) string {
replacer := strings.NewReplacer(" ", "_", "/", "_", "\\", "_", ":", "_", "*", "_", "?", "_", "\"", "_", "<", "_", ">", "_", "|", "_")
return replacer.Replace(name)
}
// RandomSuffix generates a random suffix of length n.
func RandomSuffix(n int) string {
const charset = "abcdefghijklmnopqrstuvwxyz0123456789"
b := make([]byte, n)
for i := range b {
b[i] = charset[rand.Intn(len(charset))]
}
return string(b)
}

View File

@@ -0,0 +1,15 @@
package service
import (
"gcy_hpc_server/internal/slurm"
)
// NewSlurmClient creates a Slurm SDK client with JWT authentication.
// It reads the JWT key from the given keyPath and signs tokens automatically.
func NewSlurmClient(apiURL, userName, jwtKeyPath string) (*slurm.Client, error) {
return slurm.NewClientWithOpts(
apiURL,
slurm.WithUsername(userName),
slurm.WithJWTKey(jwtKeyPath),
)
}

View File

@@ -0,0 +1,554 @@
package service
import (
"context"
"encoding/json"
"fmt"
"os"
"path/filepath"
"sort"
"strconv"
"strings"
"sync"
"time"
"gcy_hpc_server/internal/model"
"gcy_hpc_server/internal/store"
"go.uber.org/zap"
)
type TaskService struct {
taskStore *store.TaskStore
appStore *store.ApplicationStore
fileStore *store.FileStore // nil ok
blobStore *store.BlobStore // nil ok
stagingSvc *FileStagingService // nil ok — MinIO unavailable
jobSvc *JobService
workDirBase string
logger *zap.Logger
// async processing
taskCh chan int64 // buffered channel, cap=16
cancelFn context.CancelFunc
wg sync.WaitGroup
mu sync.Mutex // protects taskCh from send-on-closed
started bool // prevent double-start
stopped bool
}
func NewTaskService(
taskStore *store.TaskStore,
appStore *store.ApplicationStore,
fileStore *store.FileStore,
blobStore *store.BlobStore,
stagingSvc *FileStagingService,
jobSvc *JobService,
workDirBase string,
logger *zap.Logger,
) *TaskService {
return &TaskService{
taskStore: taskStore,
appStore: appStore,
fileStore: fileStore,
blobStore: blobStore,
stagingSvc: stagingSvc,
jobSvc: jobSvc,
workDirBase: workDirBase,
logger: logger,
taskCh: make(chan int64, 16),
}
}
func (s *TaskService) CreateTask(ctx context.Context, req *model.CreateTaskRequest) (*model.Task, error) {
app, err := s.appStore.GetByID(ctx, req.AppID)
if err != nil {
return nil, fmt.Errorf("get application: %w", err)
}
if app == nil {
return nil, fmt.Errorf("application %d not found", req.AppID)
}
// 2. Validate file limit
if len(req.InputFileIDs) > 100 {
return nil, fmt.Errorf("input file count %d exceeds limit of 100", len(req.InputFileIDs))
}
// 3. Deduplicate file IDs
fileIDs := uniqueInt64s(req.InputFileIDs)
// 4. Validate file IDs exist
if s.fileStore != nil && len(fileIDs) > 0 {
files, err := s.fileStore.GetByIDs(ctx, fileIDs)
if err != nil {
return nil, fmt.Errorf("validate file ids: %w", err)
}
found := make(map[int64]bool, len(files))
for _, f := range files {
found[f.ID] = true
}
for _, id := range fileIDs {
if !found[id] {
return nil, fmt.Errorf("file %d not found", id)
}
}
}
// 5. Auto-generate task name if empty
taskName := req.TaskName
if taskName == "" {
taskName = SanitizeDirName(app.Name) + "_" + time.Now().Format("20060102_150405")
}
// 6. Marshal values
valuesJSON := json.RawMessage(`{}`)
if len(req.Values) > 0 {
b, err := json.Marshal(req.Values)
if err != nil {
return nil, fmt.Errorf("marshal values: %w", err)
}
valuesJSON = b
}
// 7. Marshal input_file_ids
fileIDsJSON := json.RawMessage(`[]`)
if len(fileIDs) > 0 {
b, err := json.Marshal(fileIDs)
if err != nil {
return nil, fmt.Errorf("marshal file ids: %w", err)
}
fileIDsJSON = b
}
// 8. Create task record
task := &model.Task{
TaskName: taskName,
AppID: app.ID,
AppName: app.Name,
Status: model.TaskStatusSubmitted,
Values: valuesJSON,
InputFileIDs: fileIDsJSON,
SubmittedAt: time.Now(),
}
taskID, err := s.taskStore.Create(ctx, task)
if err != nil {
return nil, fmt.Errorf("create task: %w", err)
}
task.ID = taskID
return task, nil
}
// ProcessTask runs the full synchronous processing pipeline for a task.
func (s *TaskService) ProcessTask(ctx context.Context, taskID int64) error {
// 1. Fetch task
task, err := s.taskStore.GetByID(ctx, taskID)
if err != nil {
return fmt.Errorf("get task: %w", err)
}
if task == nil {
return fmt.Errorf("task %d not found", taskID)
}
fail := func(step, msg string) error {
_ = s.taskStore.UpdateStatus(ctx, taskID, model.TaskStatusFailed, msg)
_ = s.taskStore.UpdateRetryState(ctx, taskID, model.TaskStatusFailed, step, task.RetryCount)
return fmt.Errorf("%s", msg)
}
currentStep := task.CurrentStep
var workDir string
var app *model.Application
if currentStep == "" || currentStep == model.TaskStepPreparing {
// 2. Set preparing
if err := s.taskStore.UpdateRetryState(ctx, taskID, model.TaskStatusPreparing, model.TaskStepPreparing, 0); err != nil {
return fail(model.TaskStepPreparing, fmt.Sprintf("update status to preparing: %v", err))
}
// 3. Fetch app
app, err = s.appStore.GetByID(ctx, task.AppID)
if err != nil {
return fail(model.TaskStepPreparing, fmt.Sprintf("get application: %v", err))
}
if app == nil {
return fail(model.TaskStepPreparing, fmt.Sprintf("application %d not found", task.AppID))
}
// 4-5. Create work directory
workDir = filepath.Join(s.workDirBase, SanitizeDirName(app.Name), time.Now().Format("20060102_150405")+"_"+RandomSuffix(4))
if err := os.MkdirAll(workDir, 0777); err != nil {
return fail(model.TaskStepPreparing, fmt.Sprintf("create work directory %s: %v", workDir, err))
}
// 6. CHMOD traversal — critical for multi-user HPC
for dir := workDir; dir != s.workDirBase; dir = filepath.Dir(dir) {
os.Chmod(dir, 0777)
}
os.Chmod(s.workDirBase, 0777)
// 7. UpdateWorkDir
if err := s.taskStore.UpdateWorkDir(ctx, taskID, workDir); err != nil {
return fail(model.TaskStepPreparing, fmt.Sprintf("update work dir: %v", err))
}
} else {
app, err = s.appStore.GetByID(ctx, task.AppID)
if err != nil {
return fail(currentStep, fmt.Sprintf("get application: %v", err))
}
if app == nil {
return fail(currentStep, fmt.Sprintf("application %d not found", task.AppID))
}
workDir = task.WorkDir
}
if currentStep == "" || currentStep == model.TaskStepPreparing || currentStep == model.TaskStepDownloading {
if currentStep == model.TaskStepDownloading && workDir != "" {
matches, _ := filepath.Glob(filepath.Join(workDir, "*"))
for _, f := range matches {
os.Remove(f)
}
}
// 8. Set downloading
if err := s.taskStore.UpdateRetryState(ctx, taskID, model.TaskStatusDownloading, model.TaskStepDownloading, 0); err != nil {
return fail(model.TaskStepDownloading, fmt.Sprintf("update status to downloading: %v", err))
}
// 9. Parse input_file_ids
var fileIDs []int64
if len(task.InputFileIDs) > 0 {
if err := json.Unmarshal(task.InputFileIDs, &fileIDs); err != nil {
return fail(model.TaskStepDownloading, fmt.Sprintf("parse input file ids: %v", err))
}
}
// 10-12. Download files
if len(fileIDs) > 0 {
if s.stagingSvc == nil {
return fail(model.TaskStepDownloading, "MinIO unavailable, cannot stage files")
}
if err := s.stagingSvc.DownloadFilesToDir(ctx, fileIDs, workDir); err != nil {
return fail(model.TaskStepDownloading, fmt.Sprintf("download files: %v", err))
}
}
}
// 13-14. Set ready + submitting
if err := s.taskStore.UpdateRetryState(ctx, taskID, model.TaskStatusReady, model.TaskStepSubmitting, 0); err != nil {
return fail(model.TaskStepSubmitting, fmt.Sprintf("update status to ready: %v", err))
}
// 15. Parse app parameters
var params []model.ParameterSchema
if len(app.Parameters) > 0 {
if err := json.Unmarshal(app.Parameters, &params); err != nil {
return fail(model.TaskStepSubmitting, fmt.Sprintf("parse parameters: %v", err))
}
}
// 16. Parse task values
values := make(map[string]string)
if len(task.Values) > 0 {
if err := json.Unmarshal(task.Values, &values); err != nil {
return fail(model.TaskStepSubmitting, fmt.Sprintf("parse values: %v", err))
}
}
if err := ValidateParams(params, values); err != nil {
return fail(model.TaskStepSubmitting, err.Error())
}
// 17. Render script
rendered := RenderScript(app.ScriptTemplate, params, values)
// 18. Submit to Slurm
jobResp, err := s.jobSvc.SubmitJob(ctx, &model.SubmitJobRequest{
Script: rendered,
WorkDir: workDir,
})
if err != nil {
return fail(model.TaskStepSubmitting, fmt.Sprintf("submit job: %v", err))
}
// 19. Update slurm_job_id and status to queued
if err := s.taskStore.UpdateSlurmJobID(ctx, taskID, &jobResp.JobID); err != nil {
return fail(model.TaskStepSubmitting, fmt.Sprintf("update slurm job id: %v", err))
}
if err := s.taskStore.UpdateStatus(ctx, taskID, model.TaskStatusQueued, ""); err != nil {
return fail(model.TaskStepSubmitting, fmt.Sprintf("update status to queued: %v", err))
}
return nil
}
// ListTasks returns a paginated list of tasks.
func (s *TaskService) ListTasks(ctx context.Context, query *model.TaskListQuery) ([]model.Task, int64, error) {
return s.taskStore.List(ctx, query)
}
// ProcessTaskSync creates and processes a task synchronously, returning a JobResponse
// for old API compatibility.
func (s *TaskService) ProcessTaskSync(ctx context.Context, req *model.CreateTaskRequest) (*model.JobResponse, error) {
// 1. Create task
task, err := s.CreateTask(ctx, req)
if err != nil {
return nil, err
}
// 2. Process synchronously
if err := s.ProcessTask(ctx, task.ID); err != nil {
return nil, err
}
// 3. Re-fetch to get updated slurm_job_id
task, err = s.taskStore.GetByID(ctx, task.ID)
if err != nil {
return nil, fmt.Errorf("re-fetch task: %w", err)
}
if task == nil || task.SlurmJobID == nil {
return nil, fmt.Errorf("task has no slurm job id after processing")
}
// 4. Return JobResponse
return &model.JobResponse{JobID: *task.SlurmJobID}, nil
}
// uniqueInt64s deduplicates and sorts a slice of int64.
func uniqueInt64s(ids []int64) []int64 {
if len(ids) == 0 {
return nil
}
seen := make(map[int64]bool, len(ids))
result := make([]int64, 0, len(ids))
for _, id := range ids {
if !seen[id] {
seen[id] = true
result = append(result, id)
}
}
sort.Slice(result, func(i, j int) bool { return result[i] < result[j] })
return result
}
func (s *TaskService) mapSlurmStateToTaskStatus(slurmState []string) string {
if len(slurmState) == 0 {
return model.TaskStatusRunning
}
state := strings.ToUpper(slurmState[0])
switch state {
case "PENDING":
return model.TaskStatusQueued
case "RUNNING", "CONFIGURING", "COMPLETING", "SPECIAL_EXIT":
return model.TaskStatusRunning
case "COMPLETED":
return model.TaskStatusCompleted
case "FAILED", "CANCELLED", "TIMEOUT", "NODE_FAIL", "OUT_OF_MEMORY", "PREEMPTED":
return model.TaskStatusFailed
default:
return model.TaskStatusRunning
}
}
func (s *TaskService) refreshTaskStatus(ctx context.Context, taskID int64) error {
task, err := s.taskStore.GetByID(ctx, taskID)
if err != nil {
s.logger.Error("failed to fetch task for refresh",
zap.Int64("task_id", taskID),
zap.Error(err),
)
return err
}
if task == nil || task.SlurmJobID == nil {
return nil
}
jobResp, err := s.jobSvc.GetJob(ctx, strconv.FormatInt(int64(*task.SlurmJobID), 10))
if err != nil {
s.logger.Warn("failed to query slurm job status during refresh",
zap.Int64("task_id", taskID),
zap.Int32("slurm_job_id", *task.SlurmJobID),
zap.Error(err),
)
return nil
}
if jobResp == nil {
return nil
}
newStatus := s.mapSlurmStateToTaskStatus(jobResp.State)
if newStatus != task.Status {
s.logger.Info("updating task status from slurm",
zap.Int64("task_id", taskID),
zap.String("old_status", task.Status),
zap.String("new_status", newStatus),
)
return s.taskStore.UpdateStatus(ctx, taskID, newStatus, "")
}
return nil
}
func (s *TaskService) RefreshStaleTasks(ctx context.Context) error {
staleThreshold := 30 * time.Second
nonTerminal := []string{model.TaskStatusQueued, model.TaskStatusRunning}
for _, status := range nonTerminal {
tasks, _, err := s.taskStore.List(ctx, &model.TaskListQuery{
Status: status,
Page: 1,
PageSize: 1000,
})
if err != nil {
s.logger.Warn("failed to list tasks for stale refresh",
zap.String("status", status),
zap.Error(err),
)
continue
}
cutoff := time.Now().Add(-staleThreshold)
for i := range tasks {
if tasks[i].UpdatedAt.Before(cutoff) {
if err := s.refreshTaskStatus(ctx, tasks[i].ID); err != nil {
s.logger.Warn("failed to refresh stale task",
zap.Int64("task_id", tasks[i].ID),
zap.Error(err),
)
}
}
}
}
return nil
}
func (s *TaskService) StartProcessor(ctx context.Context) {
s.mu.Lock()
if s.started {
s.mu.Unlock()
return
}
s.started = true
s.mu.Unlock()
ctx, s.cancelFn = context.WithCancel(ctx)
s.wg.Add(1)
go func() {
defer s.wg.Done()
defer func() {
if r := recover(); r != nil {
s.logger.Error("processor panic", zap.Any("panic", r))
}
}()
for {
select {
case <-ctx.Done():
return
case taskID, ok := <-s.taskCh:
if !ok {
return
}
taskCtx, cancel := context.WithTimeout(ctx, 10*time.Minute)
s.processWithRetry(taskCtx, taskID)
cancel()
}
}
}()
s.RecoverStuckTasks(ctx)
}
func (s *TaskService) SubmitAsync(ctx context.Context, req *model.CreateTaskRequest) (int64, error) {
task, err := s.CreateTask(ctx, req)
if err != nil {
return 0, err
}
s.mu.Lock()
if s.stopped {
s.mu.Unlock()
return 0, fmt.Errorf("processor stopped, cannot submit task")
}
select {
case s.taskCh <- task.ID:
default:
s.logger.Warn("task channel full, submit dropped", zap.Int64("taskID", task.ID))
}
s.mu.Unlock()
return task.ID, nil
}
func (s *TaskService) StopProcessor() {
s.mu.Lock()
if s.stopped {
s.mu.Unlock()
return
}
s.stopped = true
close(s.taskCh)
s.mu.Unlock()
if s.cancelFn != nil {
s.cancelFn()
}
s.wg.Wait()
s.mu.Lock()
drainCh := s.taskCh
s.taskCh = make(chan int64, 16)
s.mu.Unlock()
for taskID := range drainCh {
_ = s.taskStore.UpdateStatus(context.Background(), taskID, model.TaskStatusSubmitted, "")
}
}
func (s *TaskService) processWithRetry(ctx context.Context, taskID int64) {
err := s.ProcessTask(ctx, taskID)
if err == nil {
return
}
task, fetchErr := s.taskStore.GetByID(ctx, taskID)
if fetchErr != nil || task == nil {
return
}
if task.RetryCount < 3 {
_ = s.taskStore.UpdateRetryState(ctx, taskID, model.TaskStatusSubmitted, task.CurrentStep, task.RetryCount+1)
s.mu.Lock()
if !s.stopped {
select {
case s.taskCh <- taskID:
default:
s.logger.Warn("task channel full, retry dropped", zap.Int64("taskID", taskID))
}
}
s.mu.Unlock()
}
}
func (s *TaskService) RecoverStuckTasks(ctx context.Context) {
tasks, err := s.taskStore.GetStuckTasks(ctx, 5*time.Minute)
if err != nil {
s.logger.Error("failed to get stuck tasks", zap.Error(err))
return
}
for i := range tasks {
_ = s.taskStore.UpdateStatus(ctx, tasks[i].ID, model.TaskStatusSubmitted, "")
s.mu.Lock()
if !s.stopped {
select {
case s.taskCh <- tasks[i].ID:
default:
s.logger.Warn("task channel full, stuck task recovery dropped", zap.Int64("taskID", tasks[i].ID))
}
}
s.mu.Unlock()
}
}

View File

@@ -0,0 +1,416 @@
package service
import (
"context"
"encoding/json"
"fmt"
"net/http"
"net/http/httptest"
"os"
"path/filepath"
"sync/atomic"
"testing"
"time"
"gcy_hpc_server/internal/model"
"gcy_hpc_server/internal/slurm"
"gcy_hpc_server/internal/store"
"go.uber.org/zap"
"gorm.io/driver/sqlite"
gormlogger "gorm.io/gorm/logger"
"gorm.io/gorm"
)
func setupAsyncTestDB(t *testing.T) *gorm.DB {
t.Helper()
db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{
Logger: gormlogger.Default.LogMode(gormlogger.Silent),
})
if err != nil {
t.Fatalf("open sqlite: %v", err)
}
if err := db.AutoMigrate(&model.Task{}, &model.Application{}, &model.File{}, &model.FileBlob{}); err != nil {
t.Fatalf("auto migrate: %v", err)
}
return db
}
type asyncTestEnv struct {
taskStore *store.TaskStore
appStore *store.ApplicationStore
svc *TaskService
srv *httptest.Server
db *gorm.DB
workDirBase string
}
func newAsyncTestEnv(t *testing.T, slurmHandler http.HandlerFunc) *asyncTestEnv {
t.Helper()
db := setupAsyncTestDB(t)
ts := store.NewTaskStore(db)
as := store.NewApplicationStore(db)
srv := httptest.NewServer(slurmHandler)
client, _ := slurm.NewClient(srv.URL, srv.Client())
jobSvc := NewJobService(client, zap.NewNop())
workDirBase := filepath.Join(t.TempDir(), "workdir")
os.MkdirAll(workDirBase, 0777)
svc := NewTaskService(ts, as, nil, nil, nil, jobSvc, workDirBase, zap.NewNop())
return &asyncTestEnv{
taskStore: ts,
appStore: as,
svc: svc,
srv: srv,
db: db,
workDirBase: workDirBase,
}
}
func (e *asyncTestEnv) close() {
e.srv.Close()
}
func (e *asyncTestEnv) createApp(t *testing.T, name, script string) int64 {
t.Helper()
id, err := e.appStore.Create(context.Background(), &model.CreateApplicationRequest{
Name: name,
ScriptTemplate: script,
Parameters: json.RawMessage(`[]`),
})
if err != nil {
t.Fatalf("create app: %v", err)
}
return id
}
func TestTaskService_Async_SubmitAndProcess(t *testing.T) {
jobID := int32(42)
env := newAsyncTestEnv(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
json.NewEncoder(w).Encode(slurm.OpenapiJobSubmitResponse{
Result: &slurm.JobSubmitResponseMsg{JobID: &jobID},
})
}))
defer env.close()
appID := env.createApp(t, "async-app", "#!/bin/bash\necho hello")
ctx := context.Background()
env.svc.StartProcessor(ctx)
taskID, err := env.svc.SubmitAsync(ctx, &model.CreateTaskRequest{
AppID: appID,
TaskName: "async-test",
})
if err != nil {
t.Fatalf("SubmitAsync: %v", err)
}
if taskID == 0 {
t.Fatal("expected non-zero task ID")
}
time.Sleep(500 * time.Millisecond)
task, err := env.taskStore.GetByID(ctx, taskID)
if err != nil {
t.Fatalf("GetByID: %v", err)
}
if task.Status != model.TaskStatusQueued {
t.Errorf("Status = %q, want %q", task.Status, model.TaskStatusQueued)
}
env.svc.StopProcessor()
}
func TestTaskService_Retry_MaxExhaustion(t *testing.T) {
callCount := int32(0)
env := newAsyncTestEnv(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
atomic.AddInt32(&callCount, 1)
w.WriteHeader(http.StatusInternalServerError)
w.Write([]byte(`{"error":"slurm down"}`))
}))
defer env.close()
appID := env.createApp(t, "retry-app", "#!/bin/bash\necho hello")
ctx := context.Background()
env.svc.StartProcessor(ctx)
taskID, err := env.svc.SubmitAsync(ctx, &model.CreateTaskRequest{
AppID: appID,
TaskName: "retry-test",
})
if err != nil {
t.Fatalf("SubmitAsync: %v", err)
}
time.Sleep(2 * time.Second)
task, _ := env.taskStore.GetByID(ctx, taskID)
if task.Status != model.TaskStatusFailed {
t.Errorf("Status = %q, want %q", task.Status, model.TaskStatusFailed)
}
if task.RetryCount < 3 {
t.Errorf("RetryCount = %d, want >= 3", task.RetryCount)
}
env.svc.StopProcessor()
}
func TestTaskService_Recover_StuckTasks(t *testing.T) {
jobID := int32(99)
env := newAsyncTestEnv(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
json.NewEncoder(w).Encode(slurm.OpenapiJobSubmitResponse{
Result: &slurm.JobSubmitResponseMsg{JobID: &jobID},
})
}))
defer env.close()
appID := env.createApp(t, "stuck-app", "#!/bin/bash\necho hello")
ctx := context.Background()
task := &model.Task{
TaskName: "stuck-task",
AppID: appID,
AppName: "stuck-app",
Status: model.TaskStatusPreparing,
CurrentStep: model.TaskStepPreparing,
RetryCount: 0,
SubmittedAt: time.Now(),
}
taskID, err := env.taskStore.Create(ctx, task)
if err != nil {
t.Fatalf("Create stuck task: %v", err)
}
staleTime := time.Now().Add(-10 * time.Minute)
env.db.Exec("UPDATE hpc_tasks SET updated_at = ? WHERE id = ?", staleTime, taskID)
env.svc.StartProcessor(ctx)
time.Sleep(1 * time.Second)
updated, _ := env.taskStore.GetByID(ctx, taskID)
if updated.Status != model.TaskStatusQueued {
t.Errorf("Status = %q, want %q", updated.Status, model.TaskStatusQueued)
}
env.svc.StopProcessor()
}
func TestTaskService_Shutdown_InFlight(t *testing.T) {
env := newAsyncTestEnv(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
time.Sleep(200 * time.Millisecond)
jobID := int32(77)
json.NewEncoder(w).Encode(slurm.OpenapiJobSubmitResponse{
Result: &slurm.JobSubmitResponseMsg{JobID: &jobID},
})
}))
defer env.close()
appID := env.createApp(t, "shutdown-app", "#!/bin/bash\necho hello")
ctx := context.Background()
env.svc.StartProcessor(ctx)
taskID, _ := env.svc.SubmitAsync(ctx, &model.CreateTaskRequest{
AppID: appID,
TaskName: "shutdown-test",
})
done := make(chan struct{})
go func() {
env.svc.StopProcessor()
close(done)
}()
select {
case <-done:
case <-time.After(5 * time.Second):
t.Fatal("StopProcessor did not complete within timeout")
}
task, _ := env.taskStore.GetByID(ctx, taskID)
if task.Status != model.TaskStatusQueued && task.Status != model.TaskStatusSubmitted {
t.Logf("task status after shutdown: %q (acceptable)", task.Status)
}
}
func TestTaskService_PanicRecovery(t *testing.T) {
jobID := int32(55)
panicDone := int32(0)
env := newAsyncTestEnv(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if atomic.CompareAndSwapInt32(&panicDone, 0, 1) {
panic("intentional test panic")
}
json.NewEncoder(w).Encode(slurm.OpenapiJobSubmitResponse{
Result: &slurm.JobSubmitResponseMsg{JobID: &jobID},
})
}))
defer env.close()
appID := env.createApp(t, "panic-app", "#!/bin/bash\necho hello")
ctx := context.Background()
env.svc.StartProcessor(ctx)
taskID, _ := env.svc.SubmitAsync(ctx, &model.CreateTaskRequest{
AppID: appID,
TaskName: "panic-test",
})
time.Sleep(1 * time.Second)
atomic.StoreInt32(&panicDone, 1)
env.svc.StopProcessor()
_ = taskID
}
func TestTaskService_SubmitAsync_DuringShutdown(t *testing.T) {
env := newAsyncTestEnv(t, nil)
defer env.close()
appID := env.createApp(t, "shutdown-err-app", "#!/bin/bash\necho hello")
ctx := context.Background()
env.svc.StartProcessor(ctx)
env.svc.StopProcessor()
_, err := env.svc.SubmitAsync(ctx, &model.CreateTaskRequest{
AppID: appID,
TaskName: "after-shutdown",
})
if err == nil {
t.Fatal("expected error when submitting after shutdown")
}
}
// TestTaskService_SubmitAsync_ChannelFull_NonBlocking verifies SubmitAsync
// returns without blocking when the task channel buffer (cap=16) is full.
// Before fix: SubmitAsync holds s.mu while blocking on full channel → deadlock.
// After fix: non-blocking select returns immediately.
func TestTaskService_SubmitAsync_ChannelFull_NonBlocking(t *testing.T) {
jobID := int32(42)
env := newAsyncTestEnv(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
time.Sleep(5 * time.Second)
json.NewEncoder(w).Encode(slurm.OpenapiJobSubmitResponse{
Result: &slurm.JobSubmitResponseMsg{JobID: &jobID},
})
}))
defer env.close()
appID := env.createApp(t, "channel-full-app", "#!/bin/bash\necho hello")
ctx := context.Background()
// Create all tasks BEFORE starting processor to avoid SQLite concurrent-write contention
taskIDs := make([]int64, 17)
for i := range taskIDs {
id, err := env.taskStore.Create(ctx, &model.Task{
TaskName: fmt.Sprintf("fill-%d", i),
AppID: appID,
AppName: "channel-full-app",
Status: model.TaskStatusSubmitted,
CurrentStep: model.TaskStepSubmitting,
SubmittedAt: time.Now(),
})
if err != nil {
t.Fatalf("create fill task %d: %v", i, err)
}
taskIDs[i] = id
}
env.svc.StartProcessor(ctx)
defer env.svc.StopProcessor()
// Consumer grabs first ID immediately; remaining 15 sit in channel.
// Push one more to fill buffer to 16 (full).
for _, id := range taskIDs {
env.svc.taskCh <- id
}
// Overflow submit: must return within 3s (non-blocking after fix)
done := make(chan error, 1)
go func() {
_, submitErr := env.svc.SubmitAsync(ctx, &model.CreateTaskRequest{
AppID: appID,
TaskName: "overflow-task",
})
done <- submitErr
}()
select {
case err := <-done:
if err != nil {
t.Logf("SubmitAsync returned error (acceptable after fix): %v", err)
} else {
t.Log("SubmitAsync returned without blocking — channel send is non-blocking")
}
case <-time.After(3 * time.Second):
t.Fatal("SubmitAsync blocked for >3s — channel send is blocking, potential deadlock")
}
}
// TestTaskService_Retry_ChannelFull_NonBlocking verifies processWithRetry
// does not deadlock when re-enqueuing a failed task into a full channel.
// Before fix: processWithRetry holds s.mu while blocking on s.taskCh <- taskID → deadlock.
// After fix: non-blocking select drops the retry with a Warn log.
func TestTaskService_Retry_ChannelFull_NonBlocking(t *testing.T) {
env := newAsyncTestEnv(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
time.Sleep(1 * time.Second)
w.WriteHeader(http.StatusInternalServerError)
w.Write([]byte(`{"error":"slurm down"}`))
}))
defer env.close()
appID := env.createApp(t, "retry-full-app", "#!/bin/bash\necho hello")
ctx := context.Background()
// Create all tasks BEFORE starting processor to avoid SQLite concurrent-write contention
taskIDs := make([]int64, 17)
for i := range taskIDs {
id, err := env.taskStore.Create(ctx, &model.Task{
TaskName: fmt.Sprintf("retry-%d", i),
AppID: appID,
AppName: "retry-full-app",
Status: model.TaskStatusSubmitted,
CurrentStep: model.TaskStepSubmitting,
RetryCount: 0,
SubmittedAt: time.Now(),
})
if err != nil {
t.Fatalf("create retry task %d: %v", i, err)
}
taskIDs[i] = id
}
env.svc.StartProcessor(ctx)
// Push all 17 IDs: consumer grabs one (processing ~1s), 16 fill the buffer
for _, id := range taskIDs {
env.svc.taskCh <- id
}
// Wait for consumer to finish first task and attempt retry into full channel
time.Sleep(2 * time.Second)
// If processWithRetry deadlocked holding s.mu, StopProcessor hangs on mutex acquisition
done := make(chan struct{})
go func() {
env.svc.StopProcessor()
close(done)
}()
select {
case <-done:
t.Log("StopProcessor completed — retry channel send is non-blocking")
case <-time.After(5 * time.Second):
t.Fatal("StopProcessor did not complete within 5s — deadlock from retry channel send")
}
}

View File

@@ -0,0 +1,294 @@
package service
import (
"context"
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"time"
"gcy_hpc_server/internal/model"
"gcy_hpc_server/internal/slurm"
"gcy_hpc_server/internal/store"
"go.uber.org/zap"
"gorm.io/driver/sqlite"
"gorm.io/gorm"
"gorm.io/gorm/logger"
)
func newTaskSvcTestDB(t *testing.T) *gorm.DB {
t.Helper()
db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{
Logger: logger.Default.LogMode(logger.Silent),
})
if err != nil {
t.Fatalf("open sqlite: %v", err)
}
if err := db.AutoMigrate(&model.Task{}); err != nil {
t.Fatalf("auto migrate: %v", err)
}
return db
}
type taskSvcTestEnv struct {
taskStore *store.TaskStore
jobSvc *JobService
svc *TaskService
srv *httptest.Server
db *gorm.DB
}
func newTaskSvcTestEnv(t *testing.T, handler http.HandlerFunc) *taskSvcTestEnv {
t.Helper()
db := newTaskSvcTestDB(t)
ts := store.NewTaskStore(db)
srv := httptest.NewServer(handler)
client, _ := slurm.NewClient(srv.URL, srv.Client())
jobSvc := NewJobService(client, zap.NewNop())
svc := NewTaskService(ts, nil, nil, nil, nil, jobSvc, "/tmp", zap.NewNop())
return &taskSvcTestEnv{
taskStore: ts,
jobSvc: jobSvc,
svc: svc,
srv: srv,
db: db,
}
}
func (e *taskSvcTestEnv) close() {
e.srv.Close()
}
func makeTaskForTest(name, status string, slurmJobID *int32) *model.Task {
return &model.Task{
TaskName: name,
AppID: 1,
AppName: "test-app",
Status: status,
CurrentStep: "",
RetryCount: 0,
UserID: "user1",
SubmittedAt: time.Now(),
SlurmJobID: slurmJobID,
}
}
func TestTaskService_MapSlurmState_AllStates(t *testing.T) {
env := newTaskSvcTestEnv(t, nil)
defer env.close()
cases := []struct {
input []string
expected string
}{
{[]string{"PENDING"}, model.TaskStatusQueued},
{[]string{"RUNNING"}, model.TaskStatusRunning},
{[]string{"CONFIGURING"}, model.TaskStatusRunning},
{[]string{"COMPLETING"}, model.TaskStatusRunning},
{[]string{"COMPLETED"}, model.TaskStatusCompleted},
{[]string{"FAILED"}, model.TaskStatusFailed},
{[]string{"CANCELLED"}, model.TaskStatusFailed},
{[]string{"TIMEOUT"}, model.TaskStatusFailed},
{[]string{"NODE_FAIL"}, model.TaskStatusFailed},
{[]string{"OUT_OF_MEMORY"}, model.TaskStatusFailed},
{[]string{"PREEMPTED"}, model.TaskStatusFailed},
{[]string{"SPECIAL_EXIT"}, model.TaskStatusRunning},
{[]string{"unknown_state"}, model.TaskStatusRunning},
{[]string{"pending"}, model.TaskStatusQueued},
{[]string{"Running"}, model.TaskStatusRunning},
}
for _, tc := range cases {
got := env.svc.mapSlurmStateToTaskStatus(tc.input)
if got != tc.expected {
t.Errorf("mapSlurmStateToTaskStatus(%v) = %q, want %q", tc.input, got, tc.expected)
}
}
}
func TestTaskService_MapSlurmState_Empty(t *testing.T) {
env := newTaskSvcTestEnv(t, nil)
defer env.close()
got := env.svc.mapSlurmStateToTaskStatus([]string{})
if got != model.TaskStatusRunning {
t.Errorf("mapSlurmStateToTaskStatus([]) = %q, want %q", got, model.TaskStatusRunning)
}
got = env.svc.mapSlurmStateToTaskStatus(nil)
if got != model.TaskStatusRunning {
t.Errorf("mapSlurmStateToTaskStatus(nil) = %q, want %q", got, model.TaskStatusRunning)
}
}
func TestTaskService_RefreshTaskStatus_UpdatesDB(t *testing.T) {
jobID := int32(42)
env := newTaskSvcTestEnv(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
resp := slurm.OpenapiJobInfoResp{
Jobs: slurm.JobInfoMsg{
{
JobID: &jobID,
JobState: []string{"RUNNING"},
},
},
}
json.NewEncoder(w).Encode(resp)
}))
defer env.close()
ctx := context.Background()
task := makeTaskForTest("refresh-test", model.TaskStatusQueued, &jobID)
id, err := env.taskStore.Create(ctx, task)
if err != nil {
t.Fatalf("Create: %v", err)
}
err = env.svc.refreshTaskStatus(ctx, id)
if err != nil {
t.Fatalf("refreshTaskStatus: %v", err)
}
updated, _ := env.taskStore.GetByID(ctx, id)
if updated.Status != model.TaskStatusRunning {
t.Errorf("status = %q, want %q", updated.Status, model.TaskStatusRunning)
}
}
func TestTaskService_RefreshTaskStatus_NoSlurmJobID(t *testing.T) {
env := newTaskSvcTestEnv(t, nil)
defer env.close()
ctx := context.Background()
task := makeTaskForTest("no-slurm", model.TaskStatusQueued, nil)
id, _ := env.taskStore.Create(ctx, task)
err := env.svc.refreshTaskStatus(ctx, id)
if err != nil {
t.Fatalf("expected no error, got %v", err)
}
got, _ := env.taskStore.GetByID(ctx, id)
if got.Status != model.TaskStatusQueued {
t.Errorf("status should remain unchanged, got %q", got.Status)
}
}
func TestTaskService_RefreshTaskStatus_SlurmError(t *testing.T) {
jobID := int32(42)
env := newTaskSvcTestEnv(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusInternalServerError)
w.Write([]byte(`{"error":"down"}`))
}))
defer env.close()
ctx := context.Background()
task := makeTaskForTest("slurm-err", model.TaskStatusQueued, &jobID)
id, _ := env.taskStore.Create(ctx, task)
err := env.svc.refreshTaskStatus(ctx, id)
if err != nil {
t.Fatalf("expected no error (soft fail), got %v", err)
}
got, _ := env.taskStore.GetByID(ctx, id)
if got.Status != model.TaskStatusQueued {
t.Errorf("status should remain unchanged on slurm error, got %q", got.Status)
}
}
func TestTaskService_RefreshTaskStatus_NoChange(t *testing.T) {
jobID := int32(42)
env := newTaskSvcTestEnv(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
resp := slurm.OpenapiJobInfoResp{
Jobs: slurm.JobInfoMsg{
{
JobID: &jobID,
JobState: []string{"RUNNING"},
},
},
}
json.NewEncoder(w).Encode(resp)
}))
defer env.close()
ctx := context.Background()
task := makeTaskForTest("no-change", model.TaskStatusRunning, &jobID)
id, _ := env.taskStore.Create(ctx, task)
err := env.svc.refreshTaskStatus(ctx, id)
if err != nil {
t.Fatalf("refreshTaskStatus: %v", err)
}
got, _ := env.taskStore.GetByID(ctx, id)
if got.Status != model.TaskStatusRunning {
t.Errorf("status = %q, want %q", got.Status, model.TaskStatusRunning)
}
}
func TestTaskService_RefreshStaleTasks_SkipsFresh(t *testing.T) {
jobID := int32(42)
slurmQueried := false
env := newTaskSvcTestEnv(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
slurmQueried = true
w.WriteHeader(http.StatusInternalServerError)
}))
defer env.close()
ctx := context.Background()
task := makeTaskForTest("fresh-task", model.TaskStatusQueued, &jobID)
id, _ := env.taskStore.Create(ctx, task)
freshTask, _ := env.taskStore.GetByID(ctx, id)
if freshTask == nil {
t.Fatal("task not found")
}
env.db.Exec("UPDATE hpc_tasks SET updated_at = ? WHERE id = ?", time.Now(), id)
err := env.svc.RefreshStaleTasks(ctx)
if err != nil {
t.Fatalf("RefreshStaleTasks: %v", err)
}
if slurmQueried {
t.Error("expected no Slurm query for fresh task")
}
}
func TestTaskService_RefreshStaleTasks_RefreshesStale(t *testing.T) {
jobID := int32(42)
env := newTaskSvcTestEnv(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
resp := slurm.OpenapiJobInfoResp{
Jobs: slurm.JobInfoMsg{
{
JobID: &jobID,
JobState: []string{"COMPLETED"},
},
},
}
json.NewEncoder(w).Encode(resp)
}))
defer env.close()
ctx := context.Background()
task := makeTaskForTest("stale-task", model.TaskStatusRunning, &jobID)
id, _ := env.taskStore.Create(ctx, task)
staleTime := time.Now().Add(-60 * time.Second)
env.db.Exec("UPDATE hpc_tasks SET updated_at = ? WHERE id = ?", staleTime, id)
err := env.svc.RefreshStaleTasks(ctx)
if err != nil {
t.Fatalf("RefreshStaleTasks: %v", err)
}
got, _ := env.taskStore.GetByID(ctx, id)
if got.Status != model.TaskStatusCompleted {
t.Errorf("status = %q, want %q", got.Status, model.TaskStatusCompleted)
}
}

View File

@@ -0,0 +1,538 @@
package service
import (
"context"
"encoding/json"
"net/http"
"net/http/httptest"
"os"
"path/filepath"
"strings"
"testing"
"gcy_hpc_server/internal/model"
"gcy_hpc_server/internal/slurm"
"gcy_hpc_server/internal/store"
"go.uber.org/zap"
"gorm.io/driver/sqlite"
gormlogger "gorm.io/gorm/logger"
"gorm.io/gorm"
)
func setupTaskTestDB(t *testing.T) *gorm.DB {
t.Helper()
db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{
Logger: gormlogger.Default.LogMode(gormlogger.Silent),
})
if err != nil {
t.Fatalf("open sqlite: %v", err)
}
if err := db.AutoMigrate(&model.Task{}, &model.Application{}, &model.File{}, &model.FileBlob{}); err != nil {
t.Fatalf("auto migrate: %v", err)
}
return db
}
type taskTestEnv struct {
taskStore *store.TaskStore
appStore *store.ApplicationStore
fileStore *store.FileStore
blobStore *store.BlobStore
svc *TaskService
srv *httptest.Server
db *gorm.DB
workDirBase string
}
func newTaskTestEnv(t *testing.T, slurmHandler http.HandlerFunc) *taskTestEnv {
t.Helper()
db := setupTaskTestDB(t)
ts := store.NewTaskStore(db)
as := store.NewApplicationStore(db)
fs := store.NewFileStore(db)
bs := store.NewBlobStore(db)
srv := httptest.NewServer(slurmHandler)
client, _ := slurm.NewClient(srv.URL, srv.Client())
jobSvc := NewJobService(client, zap.NewNop())
workDirBase := filepath.Join(t.TempDir(), "workdir")
os.MkdirAll(workDirBase, 0777)
svc := NewTaskService(ts, as, fs, bs, nil, jobSvc, workDirBase, zap.NewNop())
return &taskTestEnv{
taskStore: ts,
appStore: as,
fileStore: fs,
blobStore: bs,
svc: svc,
srv: srv,
db: db,
workDirBase: workDirBase,
}
}
func (e *taskTestEnv) close() {
e.srv.Close()
}
func (e *taskTestEnv) createApp(t *testing.T, name, script string, params json.RawMessage) int64 {
t.Helper()
id, err := e.appStore.Create(context.Background(), &model.CreateApplicationRequest{
Name: name,
ScriptTemplate: script,
Parameters: params,
})
if err != nil {
t.Fatalf("create app: %v", err)
}
return id
}
func TestTaskService_CreateTask_Success(t *testing.T) {
env := newTaskTestEnv(t, nil)
defer env.close()
appID := env.createApp(t, "my-app", "#!/bin/bash\necho hello", json.RawMessage(`[]`))
task, err := env.svc.CreateTask(context.Background(), &model.CreateTaskRequest{
AppID: appID,
TaskName: "test-task",
Values: map[string]string{"KEY": "val"},
})
if err != nil {
t.Fatalf("CreateTask: %v", err)
}
if task.ID == 0 {
t.Error("expected non-zero task ID")
}
if task.AppID != appID {
t.Errorf("AppID = %d, want %d", task.AppID, appID)
}
if task.AppName != "my-app" {
t.Errorf("AppName = %q, want %q", task.AppName, "my-app")
}
if task.Status != model.TaskStatusSubmitted {
t.Errorf("Status = %q, want %q", task.Status, model.TaskStatusSubmitted)
}
if task.TaskName != "test-task" {
t.Errorf("TaskName = %q, want %q", task.TaskName, "test-task")
}
var values map[string]string
if err := json.Unmarshal(task.Values, &values); err != nil {
t.Fatalf("unmarshal values: %v", err)
}
if values["KEY"] != "val" {
t.Errorf("values[KEY] = %q, want %q", values["KEY"], "val")
}
}
func TestTaskService_CreateTask_InvalidAppID(t *testing.T) {
env := newTaskTestEnv(t, nil)
defer env.close()
_, err := env.svc.CreateTask(context.Background(), &model.CreateTaskRequest{
AppID: 999,
})
if err == nil {
t.Fatal("expected error for invalid app_id")
}
if !strings.Contains(err.Error(), "not found") {
t.Errorf("error should mention 'not found', got: %v", err)
}
}
func TestTaskService_CreateTask_ExceedsFileLimit(t *testing.T) {
env := newTaskTestEnv(t, nil)
defer env.close()
appID := env.createApp(t, "app", "#!/bin/bash\necho hi", nil)
fileIDs := make([]int64, 101)
for i := range fileIDs {
fileIDs[i] = int64(i + 1)
}
_, err := env.svc.CreateTask(context.Background(), &model.CreateTaskRequest{
AppID: appID,
InputFileIDs: fileIDs,
})
if err == nil {
t.Fatal("expected error for exceeding file limit")
}
if !strings.Contains(err.Error(), "exceeds limit") {
t.Errorf("error should mention limit, got: %v", err)
}
}
func TestTaskService_CreateTask_DuplicateFileIDs(t *testing.T) {
env := newTaskTestEnv(t, nil)
defer env.close()
appID := env.createApp(t, "app", "#!/bin/bash\necho hi", nil)
ctx := context.Background()
for _, id := range []int64{1, 2} {
f := &model.File{
Name: "file.txt",
BlobSHA256: "abc123",
}
if err := env.fileStore.Create(ctx, f); err != nil {
t.Fatalf("create file: %v", err)
}
if f.ID != id {
t.Fatalf("expected file ID %d, got %d", id, f.ID)
}
}
task, err := env.svc.CreateTask(ctx, &model.CreateTaskRequest{
AppID: appID,
InputFileIDs: []int64{1, 1, 2, 2},
})
if err != nil {
t.Fatalf("CreateTask: %v", err)
}
var fileIDs []int64
if err := json.Unmarshal(task.InputFileIDs, &fileIDs); err != nil {
t.Fatalf("unmarshal file ids: %v", err)
}
if len(fileIDs) != 2 {
t.Fatalf("expected 2 deduplicated file IDs, got %d: %v", len(fileIDs), fileIDs)
}
if fileIDs[0] != 1 || fileIDs[1] != 2 {
t.Errorf("expected [1,2], got %v", fileIDs)
}
}
func TestTaskService_CreateTask_AutoName(t *testing.T) {
env := newTaskTestEnv(t, nil)
defer env.close()
appID := env.createApp(t, "My Cool App", "#!/bin/bash\necho hi", nil)
task, err := env.svc.CreateTask(context.Background(), &model.CreateTaskRequest{
AppID: appID,
})
if err != nil {
t.Fatalf("CreateTask: %v", err)
}
if !strings.HasPrefix(task.TaskName, "My_Cool_App_") {
t.Errorf("auto-generated name should start with 'My_Cool_App_', got %q", task.TaskName)
}
}
func TestTaskService_CreateTask_NilValues(t *testing.T) {
env := newTaskTestEnv(t, nil)
defer env.close()
appID := env.createApp(t, "app", "#!/bin/bash\necho hi", nil)
task, err := env.svc.CreateTask(context.Background(), &model.CreateTaskRequest{
AppID: appID,
Values: nil,
})
if err != nil {
t.Fatalf("CreateTask: %v", err)
}
if string(task.Values) != `{}` {
t.Errorf("Values = %q, want {}", string(task.Values))
}
}
func TestTaskService_ProcessTask_Success(t *testing.T) {
jobID := int32(42)
env := newTaskTestEnv(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
json.NewEncoder(w).Encode(slurm.OpenapiJobSubmitResponse{
Result: &slurm.JobSubmitResponseMsg{JobID: &jobID},
})
}))
defer env.close()
appID := env.createApp(t, "test-app", "#!/bin/bash\necho $INPUT", json.RawMessage(`[{"name":"INPUT","type":"string","required":true}]`))
task, err := env.svc.CreateTask(context.Background(), &model.CreateTaskRequest{
AppID: appID,
Values: map[string]string{"INPUT": "hello"},
})
if err != nil {
t.Fatalf("CreateTask: %v", err)
}
err = env.svc.ProcessTask(context.Background(), task.ID)
if err != nil {
t.Fatalf("ProcessTask: %v", err)
}
updated, _ := env.taskStore.GetByID(context.Background(), task.ID)
if updated.Status != model.TaskStatusQueued {
t.Errorf("Status = %q, want %q", updated.Status, model.TaskStatusQueued)
}
if updated.SlurmJobID == nil || *updated.SlurmJobID != 42 {
t.Errorf("SlurmJobID = %v, want 42", updated.SlurmJobID)
}
if updated.WorkDir == "" {
t.Error("WorkDir should not be empty")
}
if !strings.HasPrefix(updated.WorkDir, env.workDirBase) {
t.Errorf("WorkDir = %q, should start with %q", updated.WorkDir, env.workDirBase)
}
info, err := os.Stat(updated.WorkDir)
if err != nil {
t.Fatalf("stat workdir: %v", err)
}
if !info.IsDir() {
t.Error("WorkDir should be a directory")
}
}
func TestTaskService_ProcessTask_TaskNotFound(t *testing.T) {
env := newTaskTestEnv(t, nil)
defer env.close()
err := env.svc.ProcessTask(context.Background(), 999)
if err == nil {
t.Fatal("expected error for non-existent task")
}
if !strings.Contains(err.Error(), "not found") {
t.Errorf("error should mention 'not found', got: %v", err)
}
}
func TestTaskService_ProcessTask_SlurmError(t *testing.T) {
env := newTaskTestEnv(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusInternalServerError)
w.Write([]byte(`{"error":"slurm down"}`))
}))
defer env.close()
appID := env.createApp(t, "test-app", "#!/bin/bash\necho hello", nil)
task, err := env.svc.CreateTask(context.Background(), &model.CreateTaskRequest{
AppID: appID,
})
if err != nil {
t.Fatalf("CreateTask: %v", err)
}
err = env.svc.ProcessTask(context.Background(), task.ID)
if err == nil {
t.Fatal("expected error from Slurm")
}
updated, _ := env.taskStore.GetByID(context.Background(), task.ID)
if updated.Status != model.TaskStatusFailed {
t.Errorf("Status = %q, want %q", updated.Status, model.TaskStatusFailed)
}
if updated.CurrentStep != model.TaskStepSubmitting {
t.Errorf("CurrentStep = %q, want %q", updated.CurrentStep, model.TaskStepSubmitting)
}
if !strings.Contains(updated.ErrorMessage, "submit job") {
t.Errorf("ErrorMessage should mention 'submit job', got: %q", updated.ErrorMessage)
}
}
func TestTaskService_ProcessTaskSync(t *testing.T) {
jobID := int32(42)
env := newTaskTestEnv(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
json.NewEncoder(w).Encode(slurm.OpenapiJobSubmitResponse{
Result: &slurm.JobSubmitResponseMsg{JobID: &jobID},
})
}))
defer env.close()
appID := env.createApp(t, "sync-app", "#!/bin/bash\necho hello", nil)
resp, err := env.svc.ProcessTaskSync(context.Background(), &model.CreateTaskRequest{
AppID: appID,
})
if err != nil {
t.Fatalf("ProcessTaskSync: %v", err)
}
if resp.JobID != 42 {
t.Errorf("JobID = %d, want 42", resp.JobID)
}
}
func TestTaskService_ProcessTaskSync_NoMinIO(t *testing.T) {
jobID := int32(42)
env := newTaskTestEnv(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
json.NewEncoder(w).Encode(slurm.OpenapiJobSubmitResponse{
Result: &slurm.JobSubmitResponseMsg{JobID: &jobID},
})
}))
defer env.close()
appID := env.createApp(t, "no-minio-app", "#!/bin/bash\necho hello", nil)
resp, err := env.svc.ProcessTaskSync(context.Background(), &model.CreateTaskRequest{
AppID: appID,
InputFileIDs: nil,
})
if err != nil {
t.Fatalf("ProcessTaskSync: %v", err)
}
if resp.JobID != 42 {
t.Errorf("JobID = %d, want 42", resp.JobID)
}
}
func TestTaskService_ProcessTask_NilValues(t *testing.T) {
jobID := int32(55)
env := newTaskTestEnv(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
json.NewEncoder(w).Encode(slurm.OpenapiJobSubmitResponse{
Result: &slurm.JobSubmitResponseMsg{JobID: &jobID},
})
}))
defer env.close()
appID := env.createApp(t, "nil-val-app", "#!/bin/bash\necho hello", nil)
task, err := env.svc.CreateTask(context.Background(), &model.CreateTaskRequest{
AppID: appID,
Values: nil,
})
if err != nil {
t.Fatalf("CreateTask: %v", err)
}
err = env.svc.ProcessTask(context.Background(), task.ID)
if err != nil {
t.Fatalf("ProcessTask: %v", err)
}
updated, _ := env.taskStore.GetByID(context.Background(), task.ID)
if updated.Status != model.TaskStatusQueued {
t.Errorf("Status = %q, want %q", updated.Status, model.TaskStatusQueued)
}
}
func TestTaskService_ListTasks(t *testing.T) {
env := newTaskTestEnv(t, nil)
defer env.close()
appID := env.createApp(t, "list-app", "#!/bin/bash\necho hi", nil)
for i := 0; i < 3; i++ {
_, err := env.svc.CreateTask(context.Background(), &model.CreateTaskRequest{
AppID: appID,
TaskName: "task-" + string(rune('A'+i)),
})
if err != nil {
t.Fatalf("CreateTask %d: %v", i, err)
}
}
tasks, total, err := env.svc.ListTasks(context.Background(), &model.TaskListQuery{
Page: 1,
PageSize: 10,
})
if err != nil {
t.Fatalf("ListTasks: %v", err)
}
if total != 3 {
t.Errorf("total = %d, want 3", total)
}
if len(tasks) != 3 {
t.Errorf("len(tasks) = %d, want 3", len(tasks))
}
}
func TestTaskService_ProcessTask_ValidateParams_MissingRequired(t *testing.T) {
jobID := int32(42)
env := newTaskTestEnv(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
json.NewEncoder(w).Encode(slurm.OpenapiJobSubmitResponse{
Result: &slurm.JobSubmitResponseMsg{JobID: &jobID},
})
}))
defer env.close()
// App requires INPUT param, but we submit without it
appID := env.createApp(t, "validation-app", "#!/bin/bash\necho $INPUT", json.RawMessage(`[{"name":"INPUT","type":"string","required":true}]`))
task, err := env.svc.CreateTask(context.Background(), &model.CreateTaskRequest{
AppID: appID,
Values: map[string]string{}, // missing required INPUT
})
if err != nil {
t.Fatalf("CreateTask: %v", err)
}
err = env.svc.ProcessTask(context.Background(), task.ID)
if err == nil {
t.Fatal("expected error for missing required parameter, got nil — ValidateParams is not being called in ProcessTask pipeline")
}
errStr := err.Error()
if !strings.Contains(errStr, "validation") && !strings.Contains(errStr, "missing") && !strings.Contains(errStr, "INPUT") {
t.Errorf("error should mention 'validation', 'missing', or 'INPUT', got: %v", err)
}
}
func TestTaskService_ProcessTask_ValidateParams_InvalidInteger(t *testing.T) {
jobID := int32(42)
env := newTaskTestEnv(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
json.NewEncoder(w).Encode(slurm.OpenapiJobSubmitResponse{
Result: &slurm.JobSubmitResponseMsg{JobID: &jobID},
})
}))
defer env.close()
// App expects integer param NUM, but we submit "abc"
appID := env.createApp(t, "int-validation-app", "#!/bin/bash\necho $NUM", json.RawMessage(`[{"name":"NUM","type":"integer","required":true}]`))
task, err := env.svc.CreateTask(context.Background(), &model.CreateTaskRequest{
AppID: appID,
Values: map[string]string{"NUM": "abc"}, // invalid integer
})
if err != nil {
t.Fatalf("CreateTask: %v", err)
}
err = env.svc.ProcessTask(context.Background(), task.ID)
if err == nil {
t.Fatal("expected error for invalid integer parameter, got nil — ValidateParams is not being called in ProcessTask pipeline")
}
errStr := err.Error()
if !strings.Contains(errStr, "integer") && !strings.Contains(errStr, "validation") && !strings.Contains(errStr, "NUM") {
t.Errorf("error should mention 'integer', 'validation', or 'NUM', got: %v", err)
}
}
func TestTaskService_ProcessTask_ValidateParams_ValidParamsSucceed(t *testing.T) {
jobID := int32(99)
env := newTaskTestEnv(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
json.NewEncoder(w).Encode(slurm.OpenapiJobSubmitResponse{
Result: &slurm.JobSubmitResponseMsg{JobID: &jobID},
})
}))
defer env.close()
appID := env.createApp(t, "valid-params-app", "#!/bin/bash\necho $INPUT", json.RawMessage(`[{"name":"INPUT","type":"string","required":true}]`))
task, err := env.svc.CreateTask(context.Background(), &model.CreateTaskRequest{
AppID: appID,
Values: map[string]string{"INPUT": "hello"},
})
if err != nil {
t.Fatalf("CreateTask: %v", err)
}
err = env.svc.ProcessTask(context.Background(), task.ID)
if err != nil {
t.Fatalf("ProcessTask with valid params: %v", err)
}
updated, _ := env.taskStore.GetByID(context.Background(), task.ID)
if updated.Status != model.TaskStatusQueued {
t.Errorf("Status = %q, want %q", updated.Status, model.TaskStatusQueued)
}
if updated.SlurmJobID == nil || *updated.SlurmJobID != 99 {
t.Errorf("SlurmJobID = %v, want 99", updated.SlurmJobID)
}
}

View File

@@ -0,0 +1,443 @@
package service
import (
"context"
"crypto/sha256"
"encoding/hex"
"errors"
"fmt"
"io"
"time"
"gcy_hpc_server/internal/config"
"gcy_hpc_server/internal/model"
"gcy_hpc_server/internal/storage"
"gcy_hpc_server/internal/store"
"go.uber.org/zap"
"gorm.io/gorm"
"gorm.io/gorm/clause"
)
type UploadService struct {
storage storage.ObjectStorage
blobStore *store.BlobStore
fileStore *store.FileStore
uploadStore *store.UploadStore
cfg config.MinioConfig
db *gorm.DB
logger *zap.Logger
}
func NewUploadService(
st storage.ObjectStorage,
blobStore *store.BlobStore,
fileStore *store.FileStore,
uploadStore *store.UploadStore,
cfg config.MinioConfig,
db *gorm.DB,
logger *zap.Logger,
) *UploadService {
return &UploadService{
storage: st,
blobStore: blobStore,
fileStore: fileStore,
uploadStore: uploadStore,
cfg: cfg,
db: db,
logger: logger,
}
}
func (s *UploadService) InitUpload(ctx context.Context, req model.InitUploadRequest) (interface{}, error) {
if err := model.ValidateFileName(req.FileName); err != nil {
return nil, fmt.Errorf("invalid file name: %w", err)
}
if req.FileSize < 0 {
return nil, fmt.Errorf("file size cannot be negative")
}
if req.FileSize > s.cfg.MaxFileSize {
return nil, fmt.Errorf("file size %d exceeds maximum %d", req.FileSize, s.cfg.MaxFileSize)
}
chunkSize := s.cfg.ChunkSize
if req.ChunkSize != nil {
chunkSize = *req.ChunkSize
}
if chunkSize < s.cfg.MinChunkSize {
return nil, fmt.Errorf("chunk size %d is below minimum %d", chunkSize, s.cfg.MinChunkSize)
}
totalChunks := int((req.FileSize + chunkSize - 1) / chunkSize)
if req.FileSize == 0 {
totalChunks = 0
}
if totalChunks > 10000 {
return nil, fmt.Errorf("total chunks %d exceeds limit of 10000", totalChunks)
}
blob, err := s.blobStore.GetBySHA256(ctx, req.SHA256)
if err != nil {
return nil, fmt.Errorf("check dedup: %w", err)
}
if blob != nil {
if err := s.blobStore.IncrementRef(ctx, req.SHA256); err != nil {
return nil, fmt.Errorf("increment ref: %w", err)
}
file := &model.File{
Name: req.FileName,
FolderID: req.FolderID,
BlobSHA256: req.SHA256,
}
if err := s.fileStore.Create(ctx, file); err != nil {
return nil, fmt.Errorf("create file: %w", err)
}
return model.FileResponse{
ID: file.ID,
Name: file.Name,
FolderID: file.FolderID,
Size: blob.FileSize,
MimeType: blob.MimeType,
SHA256: blob.SHA256,
CreatedAt: file.CreatedAt,
UpdatedAt: file.UpdatedAt,
}, nil
}
mimeType := req.MimeType
if mimeType == "" {
mimeType = "application/octet-stream"
}
session := &model.UploadSession{
FileName: req.FileName,
FileSize: req.FileSize,
ChunkSize: chunkSize,
TotalChunks: totalChunks,
SHA256: req.SHA256,
FolderID: req.FolderID,
Status: "pending",
MinioPrefix: fmt.Sprintf("uploads/%d/", time.Now().UnixNano()),
MimeType: mimeType,
ExpiresAt: time.Now().Add(time.Duration(s.cfg.SessionTTL) * time.Hour),
}
if err := s.uploadStore.CreateSession(ctx, session); err != nil {
return nil, fmt.Errorf("create session: %w", err)
}
return model.UploadSessionResponse{
ID: session.ID,
FileName: session.FileName,
FileSize: session.FileSize,
ChunkSize: session.ChunkSize,
TotalChunks: session.TotalChunks,
SHA256: session.SHA256,
Status: session.Status,
UploadedChunks: []int{},
ExpiresAt: session.ExpiresAt,
CreatedAt: session.CreatedAt,
}, nil
}
func (s *UploadService) UploadChunk(ctx context.Context, sessionID int64, chunkIndex int, reader io.Reader, size int64) error {
session, err := s.uploadStore.GetSession(ctx, sessionID)
if err != nil {
return fmt.Errorf("get session: %w", err)
}
if session == nil {
return fmt.Errorf("session not found")
}
switch session.Status {
case "pending", "uploading", "failed":
default:
return fmt.Errorf("cannot upload to session with status %q", session.Status)
}
if chunkIndex < 0 || chunkIndex >= session.TotalChunks {
return fmt.Errorf("chunk index %d out of range [0, %d)", chunkIndex, session.TotalChunks)
}
key := fmt.Sprintf("%schunk_%05d", session.MinioPrefix, chunkIndex)
hasher := sha256.New()
teeReader := io.TeeReader(reader, hasher)
_, err = s.storage.PutObject(ctx, s.cfg.Bucket, key, teeReader, size, storage.PutObjectOptions{
DisableMultipart: true,
})
if err != nil {
return fmt.Errorf("put object: %w", err)
}
chunkSHA256 := hex.EncodeToString(hasher.Sum(nil))
chunk := &model.UploadChunk{
SessionID: sessionID,
ChunkIndex: chunkIndex,
MinioKey: key,
SHA256: chunkSHA256,
Size: size,
Status: "uploaded",
}
if err := s.uploadStore.UpsertChunk(ctx, chunk); err != nil {
return fmt.Errorf("upsert chunk: %w", err)
}
if session.Status == "pending" {
if err := s.uploadStore.UpdateSessionStatus(ctx, sessionID, "uploading"); err != nil {
if !errors.Is(err, gorm.ErrRecordNotFound) {
return fmt.Errorf("update status to uploading: %w", err)
}
}
}
return nil
}
func (s *UploadService) GetUploadStatus(ctx context.Context, sessionID int64) (*model.UploadSessionResponse, error) {
session, chunks, err := s.uploadStore.GetSessionWithChunks(ctx, sessionID)
if err != nil {
return nil, fmt.Errorf("get session: %w", err)
}
if session == nil {
return nil, fmt.Errorf("session not found")
}
uploadedChunks := make([]int, 0, len(chunks))
for _, c := range chunks {
if c.Status == "uploaded" {
uploadedChunks = append(uploadedChunks, c.ChunkIndex)
}
}
return &model.UploadSessionResponse{
ID: session.ID,
FileName: session.FileName,
FileSize: session.FileSize,
ChunkSize: session.ChunkSize,
TotalChunks: session.TotalChunks,
SHA256: session.SHA256,
Status: session.Status,
UploadedChunks: uploadedChunks,
ExpiresAt: session.ExpiresAt,
CreatedAt: session.CreatedAt,
}, nil
}
func (s *UploadService) CompleteUpload(ctx context.Context, sessionID int64) (*model.FileResponse, error) {
var fileResp *model.FileResponse
var totalChunks int
var minioPrefix string
err := s.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
var session model.UploadSession
query := tx.WithContext(ctx)
if !isSQLite(tx) {
query = query.Clauses(clause.Locking{Strength: "UPDATE"})
}
if err := query.First(&session, sessionID).Error; err != nil {
return fmt.Errorf("get session: %w", err)
}
switch session.Status {
case "uploading", "failed":
default:
return fmt.Errorf("cannot complete session with status %q", session.Status)
}
if session.TotalChunks > 0 {
var chunkCount int64
if err := tx.WithContext(ctx).Model(&model.UploadChunk{}).
Where("session_id = ? AND status = ?", sessionID, "uploaded").
Count(&chunkCount).Error; err != nil {
return fmt.Errorf("count chunks: %w", err)
}
if int(chunkCount) != session.TotalChunks {
return fmt.Errorf("not all chunks uploaded: %d/%d", chunkCount, session.TotalChunks)
}
}
totalChunks = session.TotalChunks
minioPrefix = session.MinioPrefix
if err := updateStatusTx(tx, ctx, sessionID, "merging"); err != nil {
return fmt.Errorf("update status to merging: %w", err)
}
blob, err := s.blobStore.GetBySHA256ForUpdate(ctx, tx, session.SHA256)
if err != nil {
return fmt.Errorf("get blob for update: %w", err)
}
if blob != nil {
result := tx.WithContext(ctx).Model(&model.FileBlob{}).
Where("sha256 = ?", session.SHA256).
UpdateColumn("ref_count", gorm.Expr("ref_count + 1"))
if result.Error != nil {
return fmt.Errorf("increment ref: %w", result.Error)
}
if result.RowsAffected == 0 {
return fmt.Errorf("blob not found for ref increment")
}
} else {
if session.TotalChunks > 0 {
chunkKeys := make([]string, session.TotalChunks)
for i := 0; i < session.TotalChunks; i++ {
chunkKeys[i] = fmt.Sprintf("%schunk_%05d", session.MinioPrefix, i)
}
dstKey := "files/" + session.SHA256
if _, err := s.storage.ComposeObject(ctx, s.cfg.Bucket, dstKey, chunkKeys); err != nil {
return errComposeFailed{err: err}
}
}
blobRecord := &model.FileBlob{
SHA256: session.SHA256,
MinioKey: "files/" + session.SHA256,
FileSize: session.FileSize,
MimeType: session.MimeType,
RefCount: 1,
}
if session.TotalChunks == 0 {
blobRecord.MinioKey = "files/" + session.SHA256
blobRecord.FileSize = 0
}
if err := tx.WithContext(ctx).Create(blobRecord).Error; err != nil {
return fmt.Errorf("create blob: %w", err)
}
}
file := &model.File{
Name: session.FileName,
FolderID: session.FolderID,
BlobSHA256: session.SHA256,
}
if err := tx.WithContext(ctx).Create(file).Error; err != nil {
return fmt.Errorf("create file: %w", err)
}
if err := updateStatusTx(tx, ctx, sessionID, "completed"); err != nil {
return fmt.Errorf("update status to completed: %w", err)
}
fileResp = &model.FileResponse{
ID: file.ID,
Name: file.Name,
FolderID: file.FolderID,
Size: session.FileSize,
MimeType: session.MimeType,
SHA256: session.SHA256,
CreatedAt: file.CreatedAt,
UpdatedAt: file.UpdatedAt,
}
return nil
})
if err != nil {
var cfe errComposeFailed
if errors.As(err, &cfe) {
if statusErr := s.uploadStore.UpdateSessionStatus(ctx, sessionID, "failed"); statusErr != nil {
s.logger.Warn("failed to mark session as failed", zap.Int64("session_id", sessionID), zap.Error(statusErr))
}
return nil, fmt.Errorf("compose object: %w", cfe.err)
}
return nil, err
}
if totalChunks > 0 {
keys := make([]string, totalChunks)
for i := 0; i < totalChunks; i++ {
keys[i] = fmt.Sprintf("%schunk_%05d", minioPrefix, i)
}
go func() {
bgCtx := context.Background()
if delErr := s.storage.RemoveObjects(bgCtx, s.cfg.Bucket, keys, storage.RemoveObjectsOptions{}); delErr != nil {
s.logger.Warn("delete temp chunks", zap.Error(delErr))
}
}()
}
return fileResp, nil
}
func (s *UploadService) CancelUpload(ctx context.Context, sessionID int64) error {
session, err := s.uploadStore.GetSession(ctx, sessionID)
if err != nil {
return fmt.Errorf("get session: %w", err)
}
if session == nil {
return fmt.Errorf("session not found")
}
if err := s.uploadStore.UpdateSessionStatus(ctx, sessionID, "cancelled"); err != nil {
return fmt.Errorf("update status to cancelled: %w", err)
}
if session.TotalChunks > 0 {
keys, listErr := s.listChunkKeys(ctx, sessionID)
if listErr != nil {
s.logger.Warn("list chunk keys for cancel", zap.Error(listErr))
} else if len(keys) > 0 {
if delErr := s.storage.RemoveObjects(ctx, s.cfg.Bucket, keys, storage.RemoveObjectsOptions{}); delErr != nil {
s.logger.Warn("remove chunk objects for cancel", zap.Error(delErr))
}
}
}
if err := s.uploadStore.DeleteSession(ctx, sessionID); err != nil {
return fmt.Errorf("delete session: %w", err)
}
return nil
}
func (s *UploadService) listChunkKeys(ctx context.Context, sessionID int64) ([]string, error) {
session, err := s.uploadStore.GetSession(ctx, sessionID)
if err != nil || session == nil {
return nil, fmt.Errorf("get session: %w", err)
}
keys := make([]string, session.TotalChunks)
for i := 0; i < session.TotalChunks; i++ {
keys[i] = fmt.Sprintf("%schunk_%05d", session.MinioPrefix, i)
}
return keys, nil
}
func updateStatusTx(tx *gorm.DB, ctx context.Context, id int64, status string) error {
result := tx.WithContext(ctx).Model(&model.UploadSession{}).Where("id = ?", id).Update("status", status)
if result.Error != nil {
return result.Error
}
if result.RowsAffected == 0 {
return gorm.ErrRecordNotFound
}
return nil
}
func isSQLite(db *gorm.DB) bool {
return db.Dialector.Name() == "sqlite"
}
func objectKeys(objects []storage.ObjectInfo) []string {
keys := make([]string, len(objects))
for i, o := range objects {
keys[i] = o.Key
}
return keys
}
type errComposeFailed struct {
err error
}
func (e errComposeFailed) Error() string {
return fmt.Sprintf("compose failed: %v", e.err)
}

View File

@@ -0,0 +1,678 @@
package service
import (
"bytes"
"context"
"crypto/sha256"
"encoding/hex"
"fmt"
"io"
"strings"
"testing"
"time"
"gcy_hpc_server/internal/config"
"gcy_hpc_server/internal/model"
"gcy_hpc_server/internal/storage"
"gcy_hpc_server/internal/store"
"go.uber.org/zap"
"gorm.io/driver/sqlite"
"gorm.io/gorm"
)
type uploadMockStorage struct {
putObjectFn func(ctx context.Context, bucket, key string, reader io.Reader, size int64, opts storage.PutObjectOptions) (storage.UploadInfo, error)
composeObjectFn func(ctx context.Context, bucket, dst string, sources []string) (storage.UploadInfo, error)
listObjectsFn func(ctx context.Context, bucket, prefix string, recursive bool) ([]storage.ObjectInfo, error)
removeObjectsFn func(ctx context.Context, bucket string, keys []string, opts storage.RemoveObjectsOptions) error
removeObjectFn func(ctx context.Context, bucket, key string, opts storage.RemoveObjectOptions) error
getObjectFn func(ctx context.Context, bucket, key string, opts storage.GetOptions) (io.ReadCloser, storage.ObjectInfo, error)
bucketExistsFn func(ctx context.Context, bucket string) (bool, error)
makeBucketFn func(ctx context.Context, bucket string, opts storage.MakeBucketOptions) error
statObjectFn func(ctx context.Context, bucket, key string, opts storage.StatObjectOptions) (storage.ObjectInfo, error)
abortMultipartFn func(ctx context.Context, bucket, object, uploadID string) error
removeIncompleteFn func(ctx context.Context, bucket, object string) error
}
func (m *uploadMockStorage) PutObject(ctx context.Context, bucket, key string, reader io.Reader, size int64, opts storage.PutObjectOptions) (storage.UploadInfo, error) {
if m.putObjectFn != nil {
return m.putObjectFn(ctx, bucket, key, reader, size, opts)
}
io.Copy(io.Discard, reader)
return storage.UploadInfo{ETag: "etag", Size: size}, nil
}
func (m *uploadMockStorage) GetObject(ctx context.Context, bucket, key string, opts storage.GetOptions) (io.ReadCloser, storage.ObjectInfo, error) {
if m.getObjectFn != nil {
return m.getObjectFn(ctx, bucket, key, opts)
}
return nil, storage.ObjectInfo{}, nil
}
func (m *uploadMockStorage) ComposeObject(ctx context.Context, bucket, dst string, sources []string) (storage.UploadInfo, error) {
if m.composeObjectFn != nil {
return m.composeObjectFn(ctx, bucket, dst, sources)
}
return storage.UploadInfo{ETag: "composed", Size: 0}, nil
}
func (m *uploadMockStorage) AbortMultipartUpload(ctx context.Context, bucket, object, uploadID string) error {
if m.abortMultipartFn != nil {
return m.abortMultipartFn(ctx, bucket, object, uploadID)
}
return nil
}
func (m *uploadMockStorage) RemoveIncompleteUpload(ctx context.Context, bucket, object string) error {
if m.removeIncompleteFn != nil {
return m.removeIncompleteFn(ctx, bucket, object)
}
return nil
}
func (m *uploadMockStorage) RemoveObject(ctx context.Context, bucket, key string, opts storage.RemoveObjectOptions) error {
if m.removeObjectFn != nil {
return m.removeObjectFn(ctx, bucket, key, opts)
}
return nil
}
func (m *uploadMockStorage) ListObjects(ctx context.Context, bucket, prefix string, recursive bool) ([]storage.ObjectInfo, error) {
if m.listObjectsFn != nil {
return m.listObjectsFn(ctx, bucket, prefix, recursive)
}
return nil, nil
}
func (m *uploadMockStorage) RemoveObjects(ctx context.Context, bucket string, keys []string, opts storage.RemoveObjectsOptions) error {
if m.removeObjectsFn != nil {
return m.removeObjectsFn(ctx, bucket, keys, opts)
}
return nil
}
func (m *uploadMockStorage) BucketExists(ctx context.Context, bucket string) (bool, error) {
if m.bucketExistsFn != nil {
return m.bucketExistsFn(ctx, bucket)
}
return true, nil
}
func (m *uploadMockStorage) MakeBucket(ctx context.Context, bucket string, opts storage.MakeBucketOptions) error {
if m.makeBucketFn != nil {
return m.makeBucketFn(ctx, bucket, opts)
}
return nil
}
func (m *uploadMockStorage) StatObject(ctx context.Context, bucket, key string, opts storage.StatObjectOptions) (storage.ObjectInfo, error) {
if m.statObjectFn != nil {
return m.statObjectFn(ctx, bucket, key, opts)
}
return storage.ObjectInfo{}, nil
}
func setupUploadTestDB(t *testing.T) *gorm.DB {
t.Helper()
db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{})
if err != nil {
t.Fatalf("open sqlite: %v", err)
}
if err := db.AutoMigrate(&model.FileBlob{}, &model.File{}, &model.UploadSession{}, &model.UploadChunk{}); err != nil {
t.Fatalf("migrate: %v", err)
}
return db
}
func uploadTestConfig() config.MinioConfig {
return config.MinioConfig{
Bucket: "test-bucket",
ChunkSize: 16 << 20,
MaxFileSize: 50 << 30,
MinChunkSize: 5 << 20,
SessionTTL: 48,
}
}
func newUploadTestService(t *testing.T, st storage.ObjectStorage, db *gorm.DB) *UploadService {
t.Helper()
return NewUploadService(
st,
store.NewBlobStore(db),
store.NewFileStore(db),
store.NewUploadStore(db),
uploadTestConfig(),
db,
zap.NewNop(),
)
}
func sha256Of(data []byte) string {
h := sha256.Sum256(data)
return hex.EncodeToString(h[:])
}
func TestInitUpload_CreatesSession(t *testing.T) {
db := setupUploadTestDB(t)
st := &uploadMockStorage{}
svc := newUploadTestService(t, st, db)
resp, err := svc.InitUpload(context.Background(), model.InitUploadRequest{
FileName: "test.txt",
FileSize: 32 << 20,
SHA256: "abc123",
})
if err != nil {
t.Fatalf("InitUpload: %v", err)
}
sessResp, ok := resp.(model.UploadSessionResponse)
if !ok {
t.Fatalf("expected UploadSessionResponse, got %T", resp)
}
if sessResp.Status != "pending" {
t.Errorf("status = %q, want pending", sessResp.Status)
}
if sessResp.TotalChunks != 2 {
t.Errorf("TotalChunks = %d, want 2", sessResp.TotalChunks)
}
if sessResp.FileName != "test.txt" {
t.Errorf("FileName = %q, want test.txt", sessResp.FileName)
}
}
func TestInitUpload_DedupBlobExists(t *testing.T) {
db := setupUploadTestDB(t)
st := &uploadMockStorage{}
svc := newUploadTestService(t, st, db)
blobSHA := sha256Of([]byte("hello"))
blob := &model.FileBlob{
SHA256: blobSHA,
MinioKey: "files/" + blobSHA,
FileSize: 5,
MimeType: "text/plain",
RefCount: 1,
}
if err := db.Create(blob).Error; err != nil {
t.Fatalf("create blob: %v", err)
}
resp, err := svc.InitUpload(context.Background(), model.InitUploadRequest{
FileName: "mydoc.txt",
FileSize: 5,
SHA256: blobSHA,
})
if err != nil {
t.Fatalf("InitUpload: %v", err)
}
fileResp, ok := resp.(model.FileResponse)
if !ok {
t.Fatalf("expected FileResponse, got %T", resp)
}
if fileResp.Name != "mydoc.txt" {
t.Errorf("Name = %q, want mydoc.txt", fileResp.Name)
}
if fileResp.SHA256 != blobSHA {
t.Errorf("SHA256 mismatch")
}
var after model.FileBlob
db.First(&after, blob.ID)
if after.RefCount != 2 {
t.Errorf("RefCount = %d, want 2", after.RefCount)
}
}
func TestInitUpload_ChunkSizeTooSmall(t *testing.T) {
db := setupUploadTestDB(t)
st := &uploadMockStorage{}
svc := newUploadTestService(t, st, db)
tiny := int64(1024)
_, err := svc.InitUpload(context.Background(), model.InitUploadRequest{
FileName: "test.txt",
FileSize: 10 << 20,
SHA256: "abc",
ChunkSize: &tiny,
})
if err == nil {
t.Fatal("expected error for chunk size too small")
}
if !strings.Contains(err.Error(), "below minimum") {
t.Errorf("error = %q, want 'below minimum'", err.Error())
}
}
func TestInitUpload_TooManyChunks(t *testing.T) {
db := setupUploadTestDB(t)
st := &uploadMockStorage{}
cfg := uploadTestConfig()
cfg.ChunkSize = 1
cfg.MinChunkSize = 1
svc2 := NewUploadService(st, store.NewBlobStore(db), store.NewFileStore(db), store.NewUploadStore(db), cfg, db, zap.NewNop())
_, err := svc2.InitUpload(context.Background(), model.InitUploadRequest{
FileName: "big.txt",
FileSize: 10002,
SHA256: "abc",
})
if err == nil {
t.Fatal("expected error for too many chunks")
}
if !strings.Contains(err.Error(), "exceeds limit") {
t.Errorf("error = %q, want 'exceeds limit'", err.Error())
}
}
func TestInitUpload_DangerousFilename(t *testing.T) {
db := setupUploadTestDB(t)
st := &uploadMockStorage{}
svc := newUploadTestService(t, st, db)
for _, name := range []string{"", "..", "foo/bar", "foo\\bar", " name"} {
_, err := svc.InitUpload(context.Background(), model.InitUploadRequest{
FileName: name,
FileSize: 100,
SHA256: "abc",
})
if err == nil {
t.Errorf("expected error for filename %q", name)
}
}
}
func TestUploadChunk_UploadsAndStoresSHA256(t *testing.T) {
db := setupUploadTestDB(t)
st := &uploadMockStorage{}
svc := newUploadTestService(t, st, db)
resp, _ := svc.InitUpload(context.Background(), model.InitUploadRequest{
FileName: "test.bin",
FileSize: 10 << 20,
SHA256: "deadbeef",
})
sessResp := resp.(model.UploadSessionResponse)
data := []byte("chunk data here")
chunkSHA := sha256Of(data)
err := svc.UploadChunk(context.Background(), sessResp.ID, 0, bytes.NewReader(data), int64(len(data)))
if err != nil {
t.Fatalf("UploadChunk: %v", err)
}
var chunk model.UploadChunk
db.Where("session_id = ? AND chunk_index = ?", sessResp.ID, 0).First(&chunk)
if chunk.SHA256 != chunkSHA {
t.Errorf("chunk SHA256 = %q, want %q", chunk.SHA256, chunkSHA)
}
if chunk.Status != "uploaded" {
t.Errorf("chunk status = %q, want uploaded", chunk.Status)
}
session, _ := svc.uploadStore.GetSession(context.Background(), sessResp.ID)
if session.Status != "uploading" {
t.Errorf("session status = %q, want uploading", session.Status)
}
}
func TestUploadChunk_RejectsCompletedSession(t *testing.T) {
db := setupUploadTestDB(t)
st := &uploadMockStorage{}
svc := newUploadTestService(t, st, db)
resp, _ := svc.InitUpload(context.Background(), model.InitUploadRequest{
FileName: "test.bin",
FileSize: 10 << 20,
SHA256: "deadbeef",
})
sessResp := resp.(model.UploadSessionResponse)
svc.uploadStore.UpdateSessionStatus(context.Background(), sessResp.ID, "completed")
err := svc.UploadChunk(context.Background(), sessResp.ID, 0, bytes.NewReader([]byte("x")), 1)
if err == nil {
t.Fatal("expected error for completed session")
}
if !strings.Contains(err.Error(), "cannot upload") {
t.Errorf("error = %q", err.Error())
}
}
func TestUploadChunk_RejectsExpiredSession(t *testing.T) {
db := setupUploadTestDB(t)
st := &uploadMockStorage{}
svc := newUploadTestService(t, st, db)
resp, _ := svc.InitUpload(context.Background(), model.InitUploadRequest{
FileName: "test.bin",
FileSize: 10 << 20,
SHA256: "deadbeef",
})
sessResp := resp.(model.UploadSessionResponse)
svc.uploadStore.UpdateSessionStatus(context.Background(), sessResp.ID, "expired")
err := svc.UploadChunk(context.Background(), sessResp.ID, 0, bytes.NewReader([]byte("x")), 1)
if err == nil {
t.Fatal("expected error for expired session")
}
}
func TestGetUploadStatus_ReturnsChunkIndices(t *testing.T) {
db := setupUploadTestDB(t)
st := &uploadMockStorage{}
svc := newUploadTestService(t, st, db)
resp, _ := svc.InitUpload(context.Background(), model.InitUploadRequest{
FileName: "test.bin",
FileSize: 32 << 20,
SHA256: "abc",
})
sessResp := resp.(model.UploadSessionResponse)
data := []byte("chunk0data")
svc.UploadChunk(context.Background(), sessResp.ID, 0, bytes.NewReader(data), int64(len(data)))
status, err := svc.GetUploadStatus(context.Background(), sessResp.ID)
if err != nil {
t.Fatalf("GetUploadStatus: %v", err)
}
if len(status.UploadedChunks) != 1 || status.UploadedChunks[0] != 0 {
t.Errorf("UploadedChunks = %v, want [0]", status.UploadedChunks)
}
if status.TotalChunks != 2 {
t.Errorf("TotalChunks = %d, want 2", status.TotalChunks)
}
}
func TestCompleteUpload_CreatesBlobAndFile(t *testing.T) {
db := setupUploadTestDB(t)
cfg := uploadTestConfig()
cfg.ChunkSize = 5 << 20
cfg.MinChunkSize = 1
st := &uploadMockStorage{}
svc := NewUploadService(st, store.NewBlobStore(db), store.NewFileStore(db), store.NewUploadStore(db), cfg, db, zap.NewNop())
fileSHA := "aaa111"
sess := &model.UploadSession{
FileName: "test.bin",
FileSize: 10 << 20,
ChunkSize: 5 << 20,
TotalChunks: 2,
SHA256: fileSHA,
Status: "uploading",
MinioPrefix: "uploads/testcomplete/",
MimeType: "application/octet-stream",
ExpiresAt: time.Now().Add(48 * time.Hour),
}
db.Create(sess)
db.Create(&model.UploadChunk{SessionID: sess.ID, ChunkIndex: 0, MinioKey: "uploads/testcomplete/chunk_00000", SHA256: "c0", Size: 5 << 20, Status: "uploaded"})
db.Create(&model.UploadChunk{SessionID: sess.ID, ChunkIndex: 1, MinioKey: "uploads/testcomplete/chunk_00001", SHA256: "c1", Size: 5 << 20, Status: "uploaded"})
fileResp, err := svc.CompleteUpload(context.Background(), sess.ID)
if err != nil {
t.Fatalf("CompleteUpload: %v", err)
}
if fileResp.Name != "test.bin" {
t.Errorf("Name = %q, want test.bin", fileResp.Name)
}
if fileResp.SHA256 != fileSHA {
t.Errorf("SHA256 = %q, want %q", fileResp.SHA256, fileSHA)
}
var blob model.FileBlob
db.Where("sha256 = ?", fileSHA).First(&blob)
if blob.RefCount != 1 {
t.Errorf("blob RefCount = %d, want 1", blob.RefCount)
}
session, _ := svc.uploadStore.GetSession(context.Background(), sess.ID)
if session == nil {
t.Fatal("session should exist")
}
if session.Status != "completed" {
t.Errorf("session status = %q, want completed", session.Status)
}
}
func TestCompleteUpload_ReusesExistingBlob(t *testing.T) {
db := setupUploadTestDB(t)
cfg := uploadTestConfig()
cfg.ChunkSize = 5 << 20
cfg.MinChunkSize = 1
st := &uploadMockStorage{}
composeCalled := false
st.composeObjectFn = func(ctx context.Context, bucket, dst string, sources []string) (storage.UploadInfo, error) {
composeCalled = true
return storage.UploadInfo{}, nil
}
svc := NewUploadService(st, store.NewBlobStore(db), store.NewFileStore(db), store.NewUploadStore(db), cfg, db, zap.NewNop())
fileSHA := "reuse123"
db.Create(&model.FileBlob{
SHA256: fileSHA,
MinioKey: "files/" + fileSHA,
FileSize: 10 << 20,
MimeType: "application/octet-stream",
RefCount: 1,
})
sess := &model.UploadSession{
FileName: "reuse.bin",
FileSize: 10 << 20,
ChunkSize: 5 << 20,
TotalChunks: 2,
SHA256: fileSHA,
Status: "uploading",
MinioPrefix: "uploads/reuse/",
MimeType: "application/octet-stream",
ExpiresAt: time.Now().Add(48 * time.Hour),
}
db.Create(sess)
db.Create(&model.UploadChunk{SessionID: sess.ID, ChunkIndex: 0, MinioKey: "uploads/reuse/chunk_00000", SHA256: "c0", Size: 5 << 20, Status: "uploaded"})
db.Create(&model.UploadChunk{SessionID: sess.ID, ChunkIndex: 1, MinioKey: "uploads/reuse/chunk_00001", SHA256: "c1", Size: 5 << 20, Status: "uploaded"})
fileResp, err := svc.CompleteUpload(context.Background(), sess.ID)
if err != nil {
t.Fatalf("CompleteUpload: %v", err)
}
if fileResp.SHA256 != fileSHA {
t.Errorf("SHA256 mismatch")
}
if composeCalled {
t.Error("ComposeObject should not be called when blob exists")
}
var blob model.FileBlob
db.Where("sha256 = ?", fileSHA).First(&blob)
if blob.RefCount != 2 {
t.Errorf("RefCount = %d, want 2", blob.RefCount)
}
}
func TestCompleteUpload_NotAllChunks(t *testing.T) {
db := setupUploadTestDB(t)
cfg := uploadTestConfig()
cfg.ChunkSize = 5 << 20
cfg.MinChunkSize = 1
st := &uploadMockStorage{}
svc := NewUploadService(st, store.NewBlobStore(db), store.NewFileStore(db), store.NewUploadStore(db), cfg, db, zap.NewNop())
sess := &model.UploadSession{
FileName: "partial.bin",
FileSize: 10 << 20,
ChunkSize: 5 << 20,
TotalChunks: 2,
SHA256: "partial123",
Status: "uploading",
MinioPrefix: "uploads/partial/",
MimeType: "application/octet-stream",
ExpiresAt: time.Now().Add(48 * time.Hour),
}
db.Create(sess)
db.Create(&model.UploadChunk{SessionID: sess.ID, ChunkIndex: 0, MinioKey: "uploads/partial/chunk_00000", SHA256: "c0", Size: 5 << 20, Status: "uploaded"})
_, err := svc.CompleteUpload(context.Background(), sess.ID)
if err == nil {
t.Fatal("expected error for incomplete chunks")
}
if !strings.Contains(err.Error(), "not all chunks uploaded") {
t.Errorf("error = %q", err.Error())
}
}
func TestCompleteUpload_ComposeObjectFails(t *testing.T) {
db := setupUploadTestDB(t)
cfg := uploadTestConfig()
cfg.ChunkSize = 5 << 20
cfg.MinChunkSize = 1
st := &uploadMockStorage{}
st.composeObjectFn = func(ctx context.Context, bucket, dst string, sources []string) (storage.UploadInfo, error) {
return storage.UploadInfo{}, fmt.Errorf("compose failed")
}
svc := NewUploadService(st, store.NewBlobStore(db), store.NewFileStore(db), store.NewUploadStore(db), cfg, db, zap.NewNop())
fileSHA := "fail123"
sess := &model.UploadSession{
FileName: "fail.bin",
FileSize: 10 << 20,
ChunkSize: 5 << 20,
TotalChunks: 2,
SHA256: fileSHA,
Status: "uploading",
MinioPrefix: "uploads/fail/",
MimeType: "application/octet-stream",
ExpiresAt: time.Now().Add(48 * time.Hour),
}
db.Create(sess)
db.Create(&model.UploadChunk{SessionID: sess.ID, ChunkIndex: 0, MinioKey: "uploads/fail/chunk_00000", SHA256: "c0", Size: 5 << 20, Status: "uploaded"})
db.Create(&model.UploadChunk{SessionID: sess.ID, ChunkIndex: 1, MinioKey: "uploads/fail/chunk_00001", SHA256: "c1", Size: 5 << 20, Status: "uploaded"})
_, err := svc.CompleteUpload(context.Background(), sess.ID)
if err == nil {
t.Fatal("expected error for compose failure")
}
session, _ := svc.uploadStore.GetSession(context.Background(), sess.ID)
if session.Status != "failed" {
t.Errorf("session status = %q, want failed", session.Status)
}
}
func TestCompleteUpload_RetriesFailedSession(t *testing.T) {
db := setupUploadTestDB(t)
cfg := uploadTestConfig()
cfg.ChunkSize = 5 << 20
cfg.MinChunkSize = 1
st := &uploadMockStorage{}
svc := NewUploadService(st, store.NewBlobStore(db), store.NewFileStore(db), store.NewUploadStore(db), cfg, db, zap.NewNop())
fileSHA := "retry123"
sess := &model.UploadSession{
FileName: "retry.bin",
FileSize: 10 << 20,
ChunkSize: 5 << 20,
TotalChunks: 2,
SHA256: fileSHA,
Status: "failed",
MinioPrefix: "uploads/retry/",
MimeType: "application/octet-stream",
ExpiresAt: time.Now().Add(48 * time.Hour),
}
db.Create(sess)
db.Create(&model.UploadChunk{SessionID: sess.ID, ChunkIndex: 0, MinioKey: "uploads/retry/chunk_00000", SHA256: "c0", Size: 5 << 20, Status: "uploaded"})
db.Create(&model.UploadChunk{SessionID: sess.ID, ChunkIndex: 1, MinioKey: "uploads/retry/chunk_00001", SHA256: "c1", Size: 5 << 20, Status: "uploaded"})
fileResp, err := svc.CompleteUpload(context.Background(), sess.ID)
if err != nil {
t.Fatalf("CompleteUpload on retry: %v", err)
}
if fileResp.SHA256 != fileSHA {
t.Errorf("SHA256 mismatch")
}
session, _ := svc.uploadStore.GetSession(context.Background(), sess.ID)
if session.Status != "completed" {
t.Errorf("session status = %q, want completed", session.Status)
}
}
func TestCancelUpload_CleansUp(t *testing.T) {
db := setupUploadTestDB(t)
st := &uploadMockStorage{}
svc := newUploadTestService(t, st, db)
resp, _ := svc.InitUpload(context.Background(), model.InitUploadRequest{
FileName: "cancel.bin",
FileSize: 10 << 20,
SHA256: "cancel123",
})
sessResp := resp.(model.UploadSessionResponse)
err := svc.CancelUpload(context.Background(), sessResp.ID)
if err != nil {
t.Fatalf("CancelUpload: %v", err)
}
session, _ := svc.uploadStore.GetSession(context.Background(), sessResp.ID)
if session != nil {
t.Error("session should be deleted")
}
}
func TestZeroByteFile_CompletesImmediately(t *testing.T) {
db := setupUploadTestDB(t)
st := &uploadMockStorage{}
svc := newUploadTestService(t, st, db)
fileSHA := "empty123"
sess := &model.UploadSession{
FileName: "empty.bin",
FileSize: 0,
ChunkSize: 16 << 20,
TotalChunks: 0,
SHA256: fileSHA,
Status: "uploading",
MinioPrefix: "uploads/empty/",
MimeType: "application/octet-stream",
ExpiresAt: time.Now().Add(48 * time.Hour),
}
db.Create(sess)
fileResp, err := svc.CompleteUpload(context.Background(), sess.ID)
if err != nil {
t.Fatalf("CompleteUpload zero byte: %v", err)
}
if fileResp.SHA256 != fileSHA {
t.Errorf("SHA256 mismatch")
}
if fileResp.Size != 0 {
t.Errorf("Size = %d, want 0", fileResp.Size)
}
var blob model.FileBlob
db.Where("sha256 = ?", fileSHA).First(&blob)
if blob.RefCount != 1 {
t.Errorf("RefCount = %d, want 1", blob.RefCount)
}
uploadStore := store.NewUploadStore(db)
session, _ := uploadStore.GetSession(context.Background(), sess.ID)
if session == nil {
t.Fatal("session should exist")
}
if session.Status != "completed" {
t.Errorf("session status = %q, want completed", session.Status)
}
}

View File

@@ -9,6 +9,7 @@ import (
"net/http"
"net/url"
"strings"
"time"
)
const (
@@ -16,6 +17,8 @@ const (
DefaultBaseURL = "http://localhost:6820/"
// DefaultUserAgent is the default User-Agent header value.
DefaultUserAgent = "slurm-go-sdk"
// DefaultTimeout is the default HTTP request timeout.
DefaultTimeout = 30 * time.Second
)
// Client manages communication with the Slurm REST API.
@@ -85,7 +88,7 @@ type Response struct {
// http.DefaultClient is used.
func NewClient(baseURL string, httpClient *http.Client) (*Client, error) {
if httpClient == nil {
httpClient = http.DefaultClient
httpClient = &http.Client{Timeout: DefaultTimeout}
}
parsedURL, err := url.Parse(baseURL)

View File

@@ -20,7 +20,7 @@ func TestNewClient(t *testing.T) {
t.Errorf("expected UserAgent %q, got %q", DefaultUserAgent, client.UserAgent)
}
if client.client == nil {
t.Error("expected http.Client to be initialized (nil httpClient should default to http.DefaultClient)")
t.Error("expected http.Client to be initialized (nil httpClient should create a client with default timeout)")
}
_, err = NewClient("://invalid", nil)
@@ -137,11 +137,11 @@ func TestClient_ErrorHandling(t *testing.T) {
t.Fatal("expected error for 500 response")
}
errorResp, ok := err.(*ErrorResponse)
errorResp, ok := err.(*SlurmAPIError)
if !ok {
t.Fatalf("expected *ErrorResponse, got %T", err)
t.Fatalf("expected *SlurmAPIError, got %T", err)
}
if errorResp.Response.StatusCode != 500 {
t.Errorf("expected status 500, got %d", errorResp.Response.StatusCode)
if errorResp.StatusCode != 500 {
t.Errorf("expected status 500, got %d", errorResp.StatusCode)
}
}

View File

@@ -1,38 +1,85 @@
package slurm
import (
"encoding/json"
"fmt"
"io"
"net/http"
)
// ErrorResponse represents an error returned by the Slurm REST API.
type ErrorResponse struct {
Response *http.Response
Message string
// errorResponseFields is used to parse errors/warnings from a Slurm API error body.
type errorResponseFields struct {
Errors OpenapiErrors `json:"errors,omitempty"`
Warnings OpenapiWarnings `json:"warnings,omitempty"`
}
func (r *ErrorResponse) Error() string {
// SlurmAPIError represents a structured error returned by the Slurm REST API.
// It captures both the HTTP details and the parsed Slurm error array when available.
type SlurmAPIError struct {
Response *http.Response
StatusCode int
Errors OpenapiErrors
Warnings OpenapiWarnings
Message string // raw body fallback when JSON parsing fails
}
func (e *SlurmAPIError) Error() string {
if len(e.Errors) > 0 {
first := e.Errors[0]
detail := ""
if first.Error != nil {
detail = *first.Error
} else if first.Description != nil {
detail = *first.Description
}
if detail != "" {
return fmt.Sprintf("%v %v: %d %s",
e.Response.Request.Method, e.Response.Request.URL,
e.StatusCode, detail)
}
}
return fmt.Sprintf("%v %v: %d %s",
r.Response.Request.Method, r.Response.Request.URL,
r.Response.StatusCode, r.Message)
e.Response.Request.Method, e.Response.Request.URL,
e.StatusCode, e.Message)
}
// IsNotFound reports whether err is a SlurmAPIError with HTTP 404 status.
func IsNotFound(err error) bool {
if apiErr, ok := err.(*SlurmAPIError); ok {
return apiErr.StatusCode == http.StatusNotFound
}
return false
}
// CheckResponse checks the API response for errors. It returns nil if the
// response is a 2xx status code. For non-2xx codes, it reads the response
// body and returns an ErrorResponse.
// body, attempts to parse structured Slurm errors, and returns a SlurmAPIError.
func CheckResponse(r *http.Response) error {
if c := r.StatusCode; c >= 200 && c <= 299 {
return nil
}
errorResponse := &ErrorResponse{Response: r}
data, err := io.ReadAll(r.Body)
if err != nil || len(data) == 0 {
errorResponse.Message = r.Status
return errorResponse
return &SlurmAPIError{
Response: r,
StatusCode: r.StatusCode,
Message: r.Status,
}
}
errorResponse.Message = string(data)
return errorResponse
apiErr := &SlurmAPIError{
Response: r,
StatusCode: r.StatusCode,
Message: string(data),
}
// Try to extract structured errors/warnings from JSON body.
var fields errorResponseFields
if json.Unmarshal(data, &fields) == nil {
apiErr.Errors = fields.Errors
apiErr.Warnings = fields.Warnings
}
return apiErr
}

View File

@@ -0,0 +1,220 @@
package slurm
import (
"fmt"
"net/http"
"net/http/httptest"
"strings"
"testing"
)
func TestCheckResponse(t *testing.T) {
tests := []struct {
name string
statusCode int
body string
wantErr bool
wantStatusCode int
wantErrors int
wantWarnings int
wantMessageContains string
}{
{
name: "2xx response returns nil",
statusCode: http.StatusOK,
body: `{"meta":{}}`,
wantErr: false,
},
{
name: "404 with valid JSON error body",
statusCode: http.StatusNotFound,
body: `{"errors":[{"description":"Unable to query JobId=198","error_number":2017,"error":"Invalid job id specified","source":"_handle_job_get"}],"warnings":[]}`,
wantErr: true,
wantStatusCode: 404,
wantErrors: 1,
wantWarnings: 0,
wantMessageContains: "Invalid job id specified",
},
{
name: "500 with non-JSON body",
statusCode: http.StatusInternalServerError,
body: "internal server error",
wantErr: true,
wantStatusCode: 500,
wantErrors: 0,
wantWarnings: 0,
wantMessageContains: "internal server error",
},
{
name: "503 with empty body returns http.Status text",
statusCode: http.StatusServiceUnavailable,
body: "",
wantErr: true,
wantStatusCode: 503,
wantErrors: 0,
wantWarnings: 0,
wantMessageContains: "503 Service Unavailable",
},
{
name: "400 with valid JSON but empty errors array",
statusCode: http.StatusBadRequest,
body: `{"errors":[],"warnings":[]}`,
wantErr: true,
wantStatusCode: 400,
wantErrors: 0,
wantWarnings: 0,
wantMessageContains: `{"errors":[],"warnings":[]}`,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
rec := httptest.NewRecorder()
rec.WriteHeader(tt.statusCode)
if tt.body != "" {
rec.Body.WriteString(tt.body)
}
err := CheckResponse(rec.Result())
if (err != nil) != tt.wantErr {
t.Fatalf("CheckResponse() error = %v, wantErr %v", err, tt.wantErr)
}
if !tt.wantErr {
return
}
apiErr, ok := err.(*SlurmAPIError)
if !ok {
t.Fatalf("expected *SlurmAPIError, got %T", err)
}
if apiErr.StatusCode != tt.wantStatusCode {
t.Errorf("StatusCode = %d, want %d", apiErr.StatusCode, tt.wantStatusCode)
}
if len(apiErr.Errors) != tt.wantErrors {
t.Errorf("len(Errors) = %d, want %d", len(apiErr.Errors), tt.wantErrors)
}
if len(apiErr.Warnings) != tt.wantWarnings {
t.Errorf("len(Warnings) = %d, want %d", len(apiErr.Warnings), tt.wantWarnings)
}
if !strings.Contains(apiErr.Message, tt.wantMessageContains) {
t.Errorf("Message = %q, want to contain %q", apiErr.Message, tt.wantMessageContains)
}
})
}
}
func TestIsNotFound(t *testing.T) {
tests := []struct {
name string
err error
want bool
}{
{
name: "404 SlurmAPIError returns true",
err: &SlurmAPIError{StatusCode: http.StatusNotFound},
want: true,
},
{
name: "500 SlurmAPIError returns false",
err: &SlurmAPIError{StatusCode: http.StatusInternalServerError},
want: false,
},
{
name: "200 SlurmAPIError returns false",
err: &SlurmAPIError{StatusCode: http.StatusOK},
want: false,
},
{
name: "plain error returns false",
err: fmt.Errorf("some error"),
want: false,
},
{
name: "nil returns false",
err: nil,
want: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := IsNotFound(tt.err); got != tt.want {
t.Errorf("IsNotFound() = %v, want %v", got, tt.want)
}
})
}
}
func TestSlurmAPIError_Error(t *testing.T) {
fakeReq := httptest.NewRequest("GET", "http://localhost/slurm/v0.0.40/job/123", nil)
tests := []struct {
name string
err *SlurmAPIError
wantContains []string
}{
{
name: "with Error field set",
err: &SlurmAPIError{
Response: &http.Response{Request: fakeReq},
StatusCode: http.StatusNotFound,
Errors: OpenapiErrors{{Error: Ptr("Job not found")}},
Message: "raw body",
},
wantContains: []string{"404", "Job not found"},
},
{
name: "with Description field set when Error is nil",
err: &SlurmAPIError{
Response: &http.Response{Request: fakeReq},
StatusCode: http.StatusBadRequest,
Errors: OpenapiErrors{{Description: Ptr("Unable to query")}},
Message: "raw body",
},
wantContains: []string{"400", "Unable to query"},
},
{
name: "with both Error and Description nil falls through to Message",
err: &SlurmAPIError{
Response: &http.Response{Request: fakeReq},
StatusCode: http.StatusInternalServerError,
Errors: OpenapiErrors{{}},
Message: "something went wrong",
},
wantContains: []string{"500", "something went wrong"},
},
{
name: "with empty Errors slice falls through to Message",
err: &SlurmAPIError{
Response: &http.Response{Request: fakeReq},
StatusCode: http.StatusServiceUnavailable,
Errors: OpenapiErrors{},
Message: "service unavailable fallback",
},
wantContains: []string{"503", "service unavailable fallback"},
},
{
name: "with non-empty Errors but empty detail string falls through to Message",
err: &SlurmAPIError{
Response: &http.Response{Request: fakeReq},
StatusCode: http.StatusBadGateway,
Errors: OpenapiErrors{{ErrorNumber: Ptr(int32(42))}},
Message: "gateway error detail",
},
wantContains: []string{"502", "gateway error detail"},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := tt.err.Error()
for _, substr := range tt.wantContains {
if !strings.Contains(got, substr) {
t.Errorf("Error() = %q, want to contain %q", got, substr)
}
}
})
}
}

View File

@@ -24,6 +24,8 @@ func defaultClientConfig() *clientConfig {
}
}
const defaultHTTPTimeout = 30 * time.Second
// WithJWTKey specifies the path to the JWT key file.
func WithJWTKey(path string) ClientOption {
return func(c *clientConfig) error {
@@ -89,11 +91,12 @@ func NewClientWithOpts(baseURL string, opts ...ClientOption) (*Client, error) {
}
tr := NewJWTAuthTransport(cfg.username, key, transportOpts...)
httpClient = tr.Client()
httpClient = &http.Client{
Transport: tr,
Timeout: defaultHTTPTimeout,
}
} else if cfg.httpClient != nil {
httpClient = cfg.httpClient
} else {
httpClient = http.DefaultClient
}
return NewClient(baseURL, httpClient)

View File

@@ -2,7 +2,6 @@ package slurm
import (
"crypto/rand"
"net/http"
"os"
"path/filepath"
"testing"
@@ -61,8 +60,8 @@ func TestNewClientWithOpts_BackwardCompatible(t *testing.T) {
if client == nil {
t.Fatal("expected non-nil client")
}
if client.client != http.DefaultClient {
t.Error("expected http.DefaultClient when no options provided")
if client.client.Timeout != DefaultTimeout {
t.Errorf("expected Timeout=%v, got %v", DefaultTimeout, client.client.Timeout)
}
}

View File

@@ -54,8 +54,8 @@ type PartitionInfoMaximumsOversubscribe struct {
// PartitionInfoMaximums represents maximum resource limits for a partition (v0.0.40_partition_info.maximums).
type PartitionInfoMaximums struct {
CpusPerNode *int32 `json:"cpus_per_node,omitempty"`
CpusPerSocket *int32 `json:"cpus_per_socket,omitempty"`
CpusPerNode *Uint32NoVal `json:"cpus_per_node,omitempty"`
CpusPerSocket *Uint32NoVal `json:"cpus_per_socket,omitempty"`
MemoryPerCPU *int64 `json:"memory_per_cpu,omitempty"`
PartitionMemoryPerCPU *Uint64NoVal `json:"partition_memory_per_cpu,omitempty"`
PartitionMemoryPerNode *Uint64NoVal `json:"partition_memory_per_node,omitempty"`

View File

@@ -43,8 +43,8 @@ func TestPartitionInfoRoundTrip(t *testing.T) {
},
GraceTime: Ptr(int32(300)),
Maximums: &PartitionInfoMaximums{
CpusPerNode: Ptr(int32(128)),
CpusPerSocket: Ptr(int32(64)),
CpusPerNode: &Uint32NoVal{Set: Ptr(true), Number: Ptr(int64(128))},
CpusPerSocket: &Uint32NoVal{Set: Ptr(true), Number: Ptr(int64(64))},
MemoryPerCPU: Ptr(int64(8192)),
PartitionMemoryPerCPU: &Uint64NoVal{Set: Ptr(true), Number: Ptr(int64(8192))},
PartitionMemoryPerNode: &Uint64NoVal{Set: Ptr(true), Number: Ptr(int64(262144))},

286
internal/storage/minio.go Normal file
View File

@@ -0,0 +1,286 @@
package storage
import (
"context"
"fmt"
"io"
"time"
"gcy_hpc_server/internal/config"
"github.com/minio/minio-go/v7"
"github.com/minio/minio-go/v7/pkg/credentials"
)
// ObjectInfo contains metadata about a stored object.
type ObjectInfo struct {
Key string
Size int64
LastModified time.Time
ETag string
ContentType string
}
// UploadInfo contains metadata about an uploaded object.
type UploadInfo struct {
ETag string
Size int64
}
// GetOptions specifies parameters for GetObject, including optional Range.
type GetOptions struct {
Start *int64 // Range start byte offset (nil = no range)
End *int64 // Range end byte offset (nil = no range)
}
// MultipartUpload represents an incomplete multipart upload.
type MultipartUpload struct {
ObjectName string
UploadID string
Initiated time.Time
}
// RemoveObjectsOptions specifies options for removing multiple objects.
type RemoveObjectsOptions struct {
ForceDelete bool
}
// PutObjectOptions for PutObject.
type PutObjectOptions struct {
ContentType string
DisableMultipart bool // true for small chunks (already pre-split)
}
// RemoveObjectOptions for RemoveObject.
type RemoveObjectOptions struct {
ForceDelete bool
}
// MakeBucketOptions for MakeBucket.
type MakeBucketOptions struct {
Region string
}
// StatObjectOptions for StatObject.
type StatObjectOptions struct{}
// ObjectStorage defines the interface for object storage operations.
// Implementations should wrap MinIO SDK calls with custom transfer types.
type ObjectStorage interface {
// PutObject uploads an object from a reader.
PutObject(ctx context.Context, bucket, key string, reader io.Reader, size int64, opts PutObjectOptions) (UploadInfo, error)
// GetObject retrieves an object. opts may contain Range parameters.
GetObject(ctx context.Context, bucket, key string, opts GetOptions) (io.ReadCloser, ObjectInfo, error)
// ComposeObject merges multiple source objects into a single destination.
ComposeObject(ctx context.Context, bucket, dst string, sources []string) (UploadInfo, error)
// AbortMultipartUpload aborts an incomplete multipart upload by upload ID.
AbortMultipartUpload(ctx context.Context, bucket, object, uploadID string) error
// RemoveIncompleteUpload removes all incomplete uploads for an object.
// This is the preferred cleanup method — it encapsulates list + abort logic.
RemoveIncompleteUpload(ctx context.Context, bucket, object string) error
// RemoveObject deletes a single object.
RemoveObject(ctx context.Context, bucket, key string, opts RemoveObjectOptions) error
// ListObjects lists objects with a given prefix.
ListObjects(ctx context.Context, bucket, prefix string, recursive bool) ([]ObjectInfo, error)
// RemoveObjects deletes multiple objects.
RemoveObjects(ctx context.Context, bucket string, keys []string, opts RemoveObjectsOptions) error
// BucketExists checks if a bucket exists.
BucketExists(ctx context.Context, bucket string) (bool, error)
// MakeBucket creates a new bucket.
MakeBucket(ctx context.Context, bucket string, opts MakeBucketOptions) error
// StatObject gets metadata about an object without downloading it.
StatObject(ctx context.Context, bucket, key string, opts StatObjectOptions) (ObjectInfo, error)
}
var _ ObjectStorage = (*MinioClient)(nil)
type MinioClient struct {
core *minio.Core
bucket string
}
func NewMinioClient(cfg config.MinioConfig) (*MinioClient, error) {
transport, err := minio.DefaultTransport(cfg.UseSSL)
if err != nil {
return nil, fmt.Errorf("create default transport: %w", err)
}
transport.MaxIdleConnsPerHost = 100
transport.IdleConnTimeout = 90 * time.Second
core, err := minio.NewCore(cfg.Endpoint, &minio.Options{
Creds: credentials.NewStaticV4(cfg.AccessKey, cfg.SecretKey, ""),
Secure: cfg.UseSSL,
Transport: transport,
})
if err != nil {
return nil, fmt.Errorf("create minio client: %w", err)
}
mc := &MinioClient{core: core, bucket: cfg.Bucket}
ctx := context.Background()
exists, err := core.BucketExists(ctx, cfg.Bucket)
if err != nil {
return nil, fmt.Errorf("check bucket: %w", err)
}
if !exists {
if err := core.MakeBucket(ctx, cfg.Bucket, minio.MakeBucketOptions{}); err != nil {
return nil, fmt.Errorf("create bucket: %w", err)
}
}
return mc, nil
}
func (m *MinioClient) PutObject(ctx context.Context, bucket, key string, reader io.Reader, size int64, opts PutObjectOptions) (UploadInfo, error) {
info, err := m.core.PutObject(ctx, bucket, key, reader, size, "", "", minio.PutObjectOptions{
ContentType: opts.ContentType,
DisableMultipart: opts.DisableMultipart,
})
if err != nil {
return UploadInfo{}, fmt.Errorf("put object %s/%s: %w", bucket, key, err)
}
return UploadInfo{ETag: info.ETag, Size: info.Size}, nil
}
func (m *MinioClient) GetObject(ctx context.Context, bucket, key string, opts GetOptions) (io.ReadCloser, ObjectInfo, error) {
var gopts minio.GetObjectOptions
if opts.Start != nil || opts.End != nil {
start := int64(0)
end := int64(0)
if opts.Start != nil {
start = *opts.Start
}
if opts.End != nil {
end = *opts.End
}
if err := gopts.SetRange(start, end); err != nil {
return nil, ObjectInfo{}, fmt.Errorf("set range: %w", err)
}
}
body, info, _, err := m.core.GetObject(ctx, bucket, key, gopts)
if err != nil {
return nil, ObjectInfo{}, fmt.Errorf("get object %s/%s: %w", bucket, key, err)
}
return body, toObjectInfo(info), nil
}
func (m *MinioClient) ComposeObject(ctx context.Context, bucket, dst string, sources []string) (UploadInfo, error) {
srcs := make([]minio.CopySrcOptions, len(sources))
for i, src := range sources {
srcs[i] = minio.CopySrcOptions{Bucket: bucket, Object: src}
}
do := minio.CopyDestOptions{
Bucket: bucket,
Object: dst,
ReplaceMetadata: true,
}
info, err := m.core.ComposeObject(ctx, do, srcs...)
if err != nil {
return UploadInfo{}, fmt.Errorf("compose object %s/%s: %w", bucket, dst, err)
}
return UploadInfo{ETag: info.ETag, Size: info.Size}, nil
}
func (m *MinioClient) AbortMultipartUpload(ctx context.Context, bucket, object, uploadID string) error {
if err := m.core.AbortMultipartUpload(ctx, bucket, object, uploadID); err != nil {
return fmt.Errorf("abort multipart upload %s/%s %s: %w", bucket, object, uploadID, err)
}
return nil
}
func (m *MinioClient) RemoveIncompleteUpload(ctx context.Context, bucket, object string) error {
if err := m.core.RemoveIncompleteUpload(ctx, bucket, object); err != nil {
return fmt.Errorf("remove incomplete upload %s/%s: %w", bucket, object, err)
}
return nil
}
func (m *MinioClient) RemoveObject(ctx context.Context, bucket, key string, opts RemoveObjectOptions) error {
if err := m.core.RemoveObject(ctx, bucket, key, minio.RemoveObjectOptions{
ForceDelete: opts.ForceDelete,
}); err != nil {
return fmt.Errorf("remove object %s/%s: %w", bucket, key, err)
}
return nil
}
func (m *MinioClient) ListObjects(ctx context.Context, bucket, prefix string, recursive bool) ([]ObjectInfo, error) {
ch := m.core.Client.ListObjects(ctx, bucket, minio.ListObjectsOptions{
Prefix: prefix,
Recursive: recursive,
})
var result []ObjectInfo
for obj := range ch {
if obj.Err != nil {
return result, fmt.Errorf("list objects %s/%s: %w", bucket, prefix, obj.Err)
}
result = append(result, toObjectInfo(obj))
}
return result, nil
}
func (m *MinioClient) RemoveObjects(ctx context.Context, bucket string, keys []string, opts RemoveObjectsOptions) error {
objectsCh := make(chan minio.ObjectInfo, len(keys))
for _, key := range keys {
objectsCh <- minio.ObjectInfo{Key: key}
}
close(objectsCh)
errCh := m.core.RemoveObjects(ctx, bucket, objectsCh, minio.RemoveObjectsOptions{})
for err := range errCh {
if err.Err != nil {
return fmt.Errorf("remove object %s: %w", err.ObjectName, err.Err)
}
}
return nil
}
func (m *MinioClient) BucketExists(ctx context.Context, bucket string) (bool, error) {
ok, err := m.core.BucketExists(ctx, bucket)
if err != nil {
return false, fmt.Errorf("bucket exists %s: %w", bucket, err)
}
return ok, nil
}
func (m *MinioClient) MakeBucket(ctx context.Context, bucket string, opts MakeBucketOptions) error {
if err := m.core.MakeBucket(ctx, bucket, minio.MakeBucketOptions{
Region: opts.Region,
}); err != nil {
return fmt.Errorf("make bucket %s: %w", bucket, err)
}
return nil
}
func (m *MinioClient) StatObject(ctx context.Context, bucket, key string, _ StatObjectOptions) (ObjectInfo, error) {
info, err := m.core.StatObject(ctx, bucket, key, minio.StatObjectOptions{})
if err != nil {
return ObjectInfo{}, fmt.Errorf("stat object %s/%s: %w", bucket, key, err)
}
return toObjectInfo(info), nil
}
func toObjectInfo(info minio.ObjectInfo) ObjectInfo {
return ObjectInfo{
Key: info.Key,
Size: info.Size,
LastModified: info.LastModified,
ETag: info.ETag,
ContentType: info.ContentType,
}
}

View File

@@ -0,0 +1,7 @@
package storage
import "testing"
func TestMinioClientImplementsObjectStorage(t *testing.T) {
var _ ObjectStorage = (*MinioClient)(nil)
}

View File

@@ -0,0 +1,114 @@
package store
import (
"context"
"encoding/json"
"errors"
"gcy_hpc_server/internal/model"
"gorm.io/gorm"
)
type ApplicationStore struct {
db *gorm.DB
}
func NewApplicationStore(db *gorm.DB) *ApplicationStore {
return &ApplicationStore{db: db}
}
func (s *ApplicationStore) List(ctx context.Context, page, pageSize int) ([]model.Application, int, error) {
var apps []model.Application
var total int64
if err := s.db.WithContext(ctx).Model(&model.Application{}).Count(&total).Error; err != nil {
return nil, 0, err
}
offset := (page - 1) * pageSize
if err := s.db.WithContext(ctx).Order("id DESC").Limit(pageSize).Offset(offset).Find(&apps).Error; err != nil {
return nil, 0, err
}
return apps, int(total), nil
}
func (s *ApplicationStore) GetByID(ctx context.Context, id int64) (*model.Application, error) {
var app model.Application
err := s.db.WithContext(ctx).First(&app, id).Error
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, nil
}
if err != nil {
return nil, err
}
return &app, nil
}
func (s *ApplicationStore) Create(ctx context.Context, req *model.CreateApplicationRequest) (int64, error) {
params := req.Parameters
if len(params) == 0 {
params = json.RawMessage(`[]`)
}
app := &model.Application{
Name: req.Name,
Description: req.Description,
Icon: req.Icon,
Category: req.Category,
ScriptTemplate: req.ScriptTemplate,
Parameters: params,
Scope: req.Scope,
}
if err := s.db.WithContext(ctx).Create(app).Error; err != nil {
return 0, err
}
return app.ID, nil
}
func (s *ApplicationStore) Update(ctx context.Context, id int64, req *model.UpdateApplicationRequest) error {
updates := map[string]interface{}{}
if req.Name != nil {
updates["name"] = *req.Name
}
if req.Description != nil {
updates["description"] = *req.Description
}
if req.Icon != nil {
updates["icon"] = *req.Icon
}
if req.Category != nil {
updates["category"] = *req.Category
}
if req.ScriptTemplate != nil {
updates["script_template"] = *req.ScriptTemplate
}
if req.Parameters != nil {
updates["parameters"] = *req.Parameters
}
if req.Scope != nil {
updates["scope"] = *req.Scope
}
if len(updates) == 0 {
return nil
}
result := s.db.WithContext(ctx).Model(&model.Application{}).Where("id = ?", id).Updates(updates)
if result.Error != nil {
return result.Error
}
if result.RowsAffected == 0 {
return gorm.ErrRecordNotFound
}
return nil
}
func (s *ApplicationStore) Delete(ctx context.Context, id int64) error {
result := s.db.WithContext(ctx).Delete(&model.Application{}, id)
if result.Error != nil {
return result.Error
}
return nil
}

View File

@@ -0,0 +1,227 @@
package store
import (
"context"
"testing"
"gcy_hpc_server/internal/model"
"gorm.io/driver/sqlite"
"gorm.io/gorm"
"gorm.io/gorm/logger"
)
func newAppTestDB(t *testing.T) *gorm.DB {
t.Helper()
db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{
Logger: logger.Default.LogMode(logger.Silent),
})
if err != nil {
t.Fatalf("open sqlite: %v", err)
}
if err := db.AutoMigrate(&model.Application{}); err != nil {
t.Fatalf("auto migrate: %v", err)
}
return db
}
func TestApplicationStore_Create_Success(t *testing.T) {
db := newAppTestDB(t)
s := NewApplicationStore(db)
id, err := s.Create(context.Background(), &model.CreateApplicationRequest{
Name: "gromacs",
Description: "Molecular dynamics simulator",
Category: "simulation",
ScriptTemplate: "#!/bin/bash\nmodule load gromacs",
Parameters: []byte(`[{"name":"ntasks","type":"number","required":true}]`),
Scope: "system",
})
if err != nil {
t.Fatalf("Create() error = %v", err)
}
if id <= 0 {
t.Errorf("Create() id = %d, want positive", id)
}
}
func TestApplicationStore_GetByID_Success(t *testing.T) {
db := newAppTestDB(t)
s := NewApplicationStore(db)
id, _ := s.Create(context.Background(), &model.CreateApplicationRequest{
Name: "lammps",
ScriptTemplate: "#!/bin/bash\nmodule load lammps",
})
app, err := s.GetByID(context.Background(), id)
if err != nil {
t.Fatalf("GetByID() error = %v", err)
}
if app == nil {
t.Fatal("GetByID() returned nil, expected application")
}
if app.Name != "lammps" {
t.Errorf("Name = %q, want %q", app.Name, "lammps")
}
if app.ID != id {
t.Errorf("ID = %d, want %d", app.ID, id)
}
}
func TestApplicationStore_GetByID_NotFound(t *testing.T) {
db := newAppTestDB(t)
s := NewApplicationStore(db)
app, err := s.GetByID(context.Background(), 99999)
if err != nil {
t.Fatalf("GetByID() error = %v", err)
}
if app != nil {
t.Error("GetByID() expected nil for not-found, got non-nil")
}
}
func TestApplicationStore_List_Pagination(t *testing.T) {
db := newAppTestDB(t)
s := NewApplicationStore(db)
for i := 0; i < 5; i++ {
s.Create(context.Background(), &model.CreateApplicationRequest{
Name: "app-" + string(rune('A'+i)),
ScriptTemplate: "#!/bin/bash\necho " + string(rune('A'+i)),
})
}
apps, total, err := s.List(context.Background(), 1, 3)
if err != nil {
t.Fatalf("List() error = %v", err)
}
if total != 5 {
t.Errorf("total = %d, want 5", total)
}
if len(apps) != 3 {
t.Errorf("len(apps) = %d, want 3", len(apps))
}
apps2, total2, err := s.List(context.Background(), 2, 3)
if err != nil {
t.Fatalf("List() page 2 error = %v", err)
}
if total2 != 5 {
t.Errorf("total2 = %d, want 5", total2)
}
if len(apps2) != 2 {
t.Errorf("len(apps2) = %d, want 2", len(apps2))
}
}
func TestApplicationStore_Update_Success(t *testing.T) {
db := newAppTestDB(t)
s := NewApplicationStore(db)
id, _ := s.Create(context.Background(), &model.CreateApplicationRequest{
Name: "orig",
ScriptTemplate: "#!/bin/bash\necho original",
})
newName := "updated"
newDesc := "updated description"
err := s.Update(context.Background(), id, &model.UpdateApplicationRequest{
Name: &newName,
Description: &newDesc,
})
if err != nil {
t.Fatalf("Update() error = %v", err)
}
app, _ := s.GetByID(context.Background(), id)
if app.Name != "updated" {
t.Errorf("Name = %q, want %q", app.Name, "updated")
}
if app.Description != "updated description" {
t.Errorf("Description = %q, want %q", app.Description, "updated description")
}
}
func TestApplicationStore_Update_NotFound(t *testing.T) {
db := newAppTestDB(t)
s := NewApplicationStore(db)
name := "nope"
err := s.Update(context.Background(), 99999, &model.UpdateApplicationRequest{
Name: &name,
})
if err == nil {
t.Fatal("Update() expected error for not-found, got nil")
}
}
func TestApplicationStore_Delete_Success(t *testing.T) {
db := newAppTestDB(t)
s := NewApplicationStore(db)
id, _ := s.Create(context.Background(), &model.CreateApplicationRequest{
Name: "to-delete",
ScriptTemplate: "#!/bin/bash\necho bye",
})
err := s.Delete(context.Background(), id)
if err != nil {
t.Fatalf("Delete() error = %v", err)
}
app, _ := s.GetByID(context.Background(), id)
if app != nil {
t.Error("GetByID() after delete returned non-nil")
}
}
func TestApplicationStore_Delete_Idempotent(t *testing.T) {
db := newAppTestDB(t)
s := NewApplicationStore(db)
err := s.Delete(context.Background(), 99999)
if err != nil {
t.Fatalf("Delete() non-existent error = %v, want nil (idempotent)", err)
}
}
func TestApplicationStore_Create_DuplicateName(t *testing.T) {
db := newAppTestDB(t)
s := NewApplicationStore(db)
_, err := s.Create(context.Background(), &model.CreateApplicationRequest{
Name: "dup-app",
ScriptTemplate: "#!/bin/bash\necho 1",
})
if err != nil {
t.Fatalf("first Create() error = %v", err)
}
_, err = s.Create(context.Background(), &model.CreateApplicationRequest{
Name: "dup-app",
ScriptTemplate: "#!/bin/bash\necho 2",
})
if err == nil {
t.Fatal("expected error for duplicate name, got nil")
}
}
func TestApplicationStore_Create_EmptyParameters(t *testing.T) {
db := newAppTestDB(t)
s := NewApplicationStore(db)
id, err := s.Create(context.Background(), &model.CreateApplicationRequest{
Name: "no-params",
ScriptTemplate: "#!/bin/bash\necho hello",
})
if err != nil {
t.Fatalf("Create() error = %v", err)
}
app, _ := s.GetByID(context.Background(), id)
if string(app.Parameters) != "[]" {
t.Errorf("Parameters = %q, want []", string(app.Parameters))
}
}

View File

@@ -0,0 +1,108 @@
package store
import (
"context"
"errors"
"gcy_hpc_server/internal/model"
"gorm.io/gorm"
"gorm.io/gorm/clause"
)
// BlobStore manages physical file blobs with reference counting.
type BlobStore struct {
db *gorm.DB
}
// NewBlobStore creates a new BlobStore.
func NewBlobStore(db *gorm.DB) *BlobStore {
return &BlobStore{db: db}
}
// Create inserts a new FileBlob record.
func (s *BlobStore) Create(ctx context.Context, blob *model.FileBlob) error {
return s.db.WithContext(ctx).Create(blob).Error
}
// GetBySHA256 returns the FileBlob with the given SHA256 hash.
// Returns (nil, nil) if not found.
func (s *BlobStore) GetBySHA256(ctx context.Context, sha256 string) (*model.FileBlob, error) {
var blob model.FileBlob
err := s.db.WithContext(ctx).Where("sha256 = ?", sha256).First(&blob).Error
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, nil
}
if err != nil {
return nil, err
}
return &blob, nil
}
// IncrementRef atomically increments the ref_count for the blob with the given SHA256.
func (s *BlobStore) IncrementRef(ctx context.Context, sha256 string) error {
result := s.db.WithContext(ctx).Model(&model.FileBlob{}).
Where("sha256 = ?", sha256).
UpdateColumn("ref_count", gorm.Expr("ref_count + 1"))
if result.Error != nil {
return result.Error
}
if result.RowsAffected == 0 {
return gorm.ErrRecordNotFound
}
return nil
}
// DecrementRef atomically decrements the ref_count for the blob with the given SHA256.
// Returns the new ref_count after decrementing.
func (s *BlobStore) DecrementRef(ctx context.Context, sha256 string) (int64, error) {
result := s.db.WithContext(ctx).Model(&model.FileBlob{}).
Where("sha256 = ? AND ref_count > 0", sha256).
UpdateColumn("ref_count", gorm.Expr("ref_count - 1"))
if result.Error != nil {
return 0, result.Error
}
if result.RowsAffected == 0 {
return 0, gorm.ErrRecordNotFound
}
var blob model.FileBlob
if err := s.db.WithContext(ctx).Where("sha256 = ?", sha256).First(&blob).Error; err != nil {
return 0, err
}
return int64(blob.RefCount), nil
}
// Delete removes a FileBlob record by SHA256 (hard delete).
func (s *BlobStore) Delete(ctx context.Context, sha256 string) error {
result := s.db.WithContext(ctx).Where("sha256 = ?", sha256).Delete(&model.FileBlob{})
if result.Error != nil {
return result.Error
}
return nil
}
func (s *BlobStore) GetBySHA256s(ctx context.Context, sha256s []string) ([]model.FileBlob, error) {
var blobs []model.FileBlob
if len(sha256s) == 0 {
return blobs, nil
}
err := s.db.WithContext(ctx).Where("sha256 IN ?", sha256s).Find(&blobs).Error
return blobs, err
}
// GetBySHA256ForUpdate returns the FileBlob with a SELECT ... FOR UPDATE lock.
// Returns (nil, nil) if not found.
func (s *BlobStore) GetBySHA256ForUpdate(ctx context.Context, tx *gorm.DB, sha256 string) (*model.FileBlob, error) {
var blob model.FileBlob
err := tx.WithContext(ctx).
Clauses(clause.Locking{Strength: "UPDATE"}).
Where("sha256 = ?", sha256).First(&blob).Error
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, nil
}
if err != nil {
return nil, err
}
return &blob, nil
}

View File

@@ -0,0 +1,199 @@
package store
import (
"context"
"testing"
"gcy_hpc_server/internal/model"
"gorm.io/driver/sqlite"
"gorm.io/gorm"
"gorm.io/gorm/logger"
)
func setupBlobTestDB(t *testing.T) *gorm.DB {
t.Helper()
db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{
Logger: logger.Default.LogMode(logger.Silent),
})
if err != nil {
t.Fatalf("open sqlite: %v", err)
}
if err := db.AutoMigrate(&model.FileBlob{}); err != nil {
t.Fatalf("migrate: %v", err)
}
return db
}
func TestBlobStore_Create(t *testing.T) {
db := setupBlobTestDB(t)
store := NewBlobStore(db)
ctx := context.Background()
blob := &model.FileBlob{
SHA256: "abc123",
MinioKey: "files/abc123",
FileSize: 1024,
MimeType: "application/octet-stream",
RefCount: 0,
}
if err := store.Create(ctx, blob); err != nil {
t.Fatalf("Create() error = %v", err)
}
if blob.ID == 0 {
t.Error("Create() did not set ID")
}
}
func TestBlobStore_GetBySHA256(t *testing.T) {
db := setupBlobTestDB(t)
store := NewBlobStore(db)
ctx := context.Background()
store.Create(ctx, &model.FileBlob{SHA256: "abc", MinioKey: "files/abc", FileSize: 100, RefCount: 0})
blob, err := store.GetBySHA256(ctx, "abc")
if err != nil {
t.Fatalf("GetBySHA256() error = %v", err)
}
if blob == nil {
t.Fatal("GetBySHA256() returned nil")
}
if blob.SHA256 != "abc" {
t.Errorf("SHA256 = %q, want %q", blob.SHA256, "abc")
}
if blob.RefCount != 0 {
t.Errorf("RefCount = %d, want 0", blob.RefCount)
}
}
func TestBlobStore_GetBySHA256_NotFound(t *testing.T) {
db := setupBlobTestDB(t)
store := NewBlobStore(db)
ctx := context.Background()
blob, err := store.GetBySHA256(ctx, "nonexistent")
if err != nil {
t.Fatalf("GetBySHA256() error = %v", err)
}
if blob != nil {
t.Error("GetBySHA256() should return nil for not found")
}
}
func TestBlobStore_IncrementDecrementRef(t *testing.T) {
db := setupBlobTestDB(t)
store := NewBlobStore(db)
ctx := context.Background()
store.Create(ctx, &model.FileBlob{SHA256: "abc", MinioKey: "files/abc", FileSize: 100, RefCount: 0})
if err := store.IncrementRef(ctx, "abc"); err != nil {
t.Fatalf("IncrementRef() error = %v", err)
}
blob, _ := store.GetBySHA256(ctx, "abc")
if blob.RefCount != 1 {
t.Errorf("RefCount after 1st increment = %d, want 1", blob.RefCount)
}
store.IncrementRef(ctx, "abc")
blob, _ = store.GetBySHA256(ctx, "abc")
if blob.RefCount != 2 {
t.Errorf("RefCount after 2nd increment = %d, want 2", blob.RefCount)
}
refCount, err := store.DecrementRef(ctx, "abc")
if err != nil {
t.Fatalf("DecrementRef() error = %v", err)
}
if refCount != 1 {
t.Errorf("DecrementRef() returned %d, want 1", refCount)
}
refCount, _ = store.DecrementRef(ctx, "abc")
if refCount != 0 {
t.Errorf("DecrementRef() returned %d, want 0", refCount)
}
}
func TestBlobStore_Delete(t *testing.T) {
db := setupBlobTestDB(t)
store := NewBlobStore(db)
ctx := context.Background()
store.Create(ctx, &model.FileBlob{SHA256: "abc", MinioKey: "files/abc", FileSize: 100, RefCount: 0})
if err := store.Delete(ctx, "abc"); err != nil {
t.Fatalf("Delete() error = %v", err)
}
blob, _ := store.GetBySHA256(ctx, "abc")
if blob != nil {
t.Error("Delete() did not remove blob")
}
}
func TestBlobStore_GetBySHA256s(t *testing.T) {
db := setupBlobTestDB(t)
store := NewBlobStore(db)
ctx := context.Background()
store.Create(ctx, &model.FileBlob{SHA256: "h1", MinioKey: "files/h1", FileSize: 100})
store.Create(ctx, &model.FileBlob{SHA256: "h2", MinioKey: "files/h2", FileSize: 200})
store.Create(ctx, &model.FileBlob{SHA256: "h3", MinioKey: "files/h3", FileSize: 300})
blobs, err := store.GetBySHA256s(ctx, []string{"h1", "h3"})
if err != nil {
t.Fatalf("GetBySHA256s() error = %v", err)
}
if len(blobs) != 2 {
t.Fatalf("len(blobs) = %d, want 2", len(blobs))
}
keys := map[string]bool{}
for _, b := range blobs {
keys[b.SHA256] = true
}
if !keys["h1"] || !keys["h3"] {
t.Errorf("expected h1 and h3, got %v", blobs)
}
}
func TestBlobStore_GetBySHA256s_Empty(t *testing.T) {
db := setupBlobTestDB(t)
store := NewBlobStore(db)
ctx := context.Background()
blobs, err := store.GetBySHA256s(ctx, []string{})
if err != nil {
t.Fatalf("GetBySHA256s() error = %v", err)
}
if len(blobs) != 0 {
t.Errorf("len(blobs) = %d, want 0", len(blobs))
}
}
func TestBlobStore_GetBySHA256s_NotFound(t *testing.T) {
db := setupBlobTestDB(t)
store := NewBlobStore(db)
ctx := context.Background()
blobs, err := store.GetBySHA256s(ctx, []string{"nonexistent"})
if err != nil {
t.Fatalf("GetBySHA256s() error = %v", err)
}
if len(blobs) != 0 {
t.Errorf("len(blobs) = %d, want 0 for non-existent SHA256s", len(blobs))
}
}
func TestBlobStore_SHA256_UniqueConstraint(t *testing.T) {
db := setupBlobTestDB(t)
store := NewBlobStore(db)
ctx := context.Background()
store.Create(ctx, &model.FileBlob{SHA256: "dup", MinioKey: "files/dup1", FileSize: 100})
err := store.Create(ctx, &model.FileBlob{SHA256: "dup", MinioKey: "files/dup2", FileSize: 200})
if err == nil {
t.Error("expected error for duplicate SHA256, got nil")
}
}

View File

@@ -0,0 +1,108 @@
package store
import (
"context"
"errors"
"gcy_hpc_server/internal/model"
"gorm.io/gorm"
)
type FileStore struct {
db *gorm.DB
}
func NewFileStore(db *gorm.DB) *FileStore {
return &FileStore{db: db}
}
func (s *FileStore) Create(ctx context.Context, file *model.File) error {
return s.db.WithContext(ctx).Create(file).Error
}
func (s *FileStore) GetByID(ctx context.Context, id int64) (*model.File, error) {
var file model.File
err := s.db.WithContext(ctx).First(&file, id).Error
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, nil
}
if err != nil {
return nil, err
}
return &file, nil
}
func (s *FileStore) List(ctx context.Context, folderID *int64, page, pageSize int) ([]model.File, int64, error) {
query := s.db.WithContext(ctx).Model(&model.File{})
if folderID == nil {
query = query.Where("folder_id IS NULL")
} else {
query = query.Where("folder_id = ?", *folderID)
}
var total int64
if err := query.Count(&total).Error; err != nil {
return nil, 0, err
}
var files []model.File
offset := (page - 1) * pageSize
if err := query.Order("id DESC").Limit(pageSize).Offset(offset).Find(&files).Error; err != nil {
return nil, 0, err
}
return files, total, nil
}
func (s *FileStore) Search(ctx context.Context, queryStr string, page, pageSize int) ([]model.File, int64, error) {
query := s.db.WithContext(ctx).Model(&model.File{}).Where("name LIKE ?", "%"+queryStr+"%")
var total int64
if err := query.Count(&total).Error; err != nil {
return nil, 0, err
}
var files []model.File
offset := (page - 1) * pageSize
if err := query.Order("id DESC").Limit(pageSize).Offset(offset).Find(&files).Error; err != nil {
return nil, 0, err
}
return files, total, nil
}
func (s *FileStore) Delete(ctx context.Context, id int64) error {
result := s.db.WithContext(ctx).Delete(&model.File{}, id)
if result.Error != nil {
return result.Error
}
return nil
}
func (s *FileStore) CountByBlobSHA256(ctx context.Context, blobSHA256 string) (int64, error) {
var count int64
err := s.db.WithContext(ctx).Model(&model.File{}).
Where("blob_sha256 = ?", blobSHA256).
Count(&count).Error
return count, err
}
func (s *FileStore) GetByIDs(ctx context.Context, ids []int64) ([]model.File, error) {
var files []model.File
if len(ids) == 0 {
return files, nil
}
err := s.db.WithContext(ctx).Where("id IN ?", ids).Find(&files).Error
return files, err
}
func (s *FileStore) GetBlobSHA256ByID(ctx context.Context, id int64) (string, error) {
var file model.File
err := s.db.WithContext(ctx).Select("blob_sha256").First(&file, id).Error
if errors.Is(err, gorm.ErrRecordNotFound) {
return "", nil
}
if err != nil {
return "", err
}
return file.BlobSHA256, nil
}

View File

@@ -0,0 +1,323 @@
package store
import (
"context"
"testing"
"gcy_hpc_server/internal/model"
"gorm.io/driver/sqlite"
"gorm.io/gorm"
)
func setupFileTestDB(t *testing.T) *gorm.DB {
t.Helper()
db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{})
if err != nil {
t.Fatalf("open sqlite: %v", err)
}
if err := db.AutoMigrate(&model.File{}, &model.FileBlob{}); err != nil {
t.Fatalf("migrate: %v", err)
}
return db
}
func TestFileStore_CreateAndGetByID(t *testing.T) {
db := setupFileTestDB(t)
store := NewFileStore(db)
ctx := context.Background()
file := &model.File{
Name: "test.bin",
BlobSHA256: "abc123",
}
if err := store.Create(ctx, file); err != nil {
t.Fatalf("Create() error = %v", err)
}
if file.ID == 0 {
t.Fatal("Create() did not set ID")
}
got, err := store.GetByID(ctx, file.ID)
if err != nil {
t.Fatalf("GetByID() error = %v", err)
}
if got == nil {
t.Fatal("GetByID() returned nil")
}
if got.Name != "test.bin" {
t.Errorf("Name = %q, want %q", got.Name, "test.bin")
}
if got.BlobSHA256 != "abc123" {
t.Errorf("BlobSHA256 = %q, want %q", got.BlobSHA256, "abc123")
}
}
func TestFileStore_GetByID_NotFound(t *testing.T) {
db := setupFileTestDB(t)
store := NewFileStore(db)
ctx := context.Background()
got, err := store.GetByID(ctx, 999)
if err != nil {
t.Fatalf("GetByID() error = %v", err)
}
if got != nil {
t.Error("GetByID() should return nil for not found")
}
}
func TestFileStore_ListByFolder(t *testing.T) {
db := setupFileTestDB(t)
store := NewFileStore(db)
ctx := context.Background()
folderID := int64(1)
store.Create(ctx, &model.File{Name: "f1.bin", BlobSHA256: "a1", FolderID: &folderID})
store.Create(ctx, &model.File{Name: "f2.bin", BlobSHA256: "a2", FolderID: &folderID})
store.Create(ctx, &model.File{Name: "root.bin", BlobSHA256: "a3"}) // root (folder_id=nil)
files, total, err := store.List(ctx, &folderID, 1, 10)
if err != nil {
t.Fatalf("List() error = %v", err)
}
if total != 2 {
t.Errorf("total = %d, want 2", total)
}
if len(files) != 2 {
t.Errorf("len(files) = %d, want 2", len(files))
}
}
func TestFileStore_ListRootFolder(t *testing.T) {
db := setupFileTestDB(t)
store := NewFileStore(db)
ctx := context.Background()
store.Create(ctx, &model.File{Name: "root.bin", BlobSHA256: "a1"})
folderID := int64(1)
store.Create(ctx, &model.File{Name: "sub.bin", BlobSHA256: "a2", FolderID: &folderID})
files, total, err := store.List(ctx, nil, 1, 10)
if err != nil {
t.Fatalf("List() error = %v", err)
}
if total != 1 {
t.Errorf("total = %d, want 1", total)
}
if len(files) != 1 {
t.Errorf("len(files) = %d, want 1", len(files))
}
if files[0].Name != "root.bin" {
t.Errorf("files[0].Name = %q, want %q", files[0].Name, "root.bin")
}
}
func TestFileStore_Pagination(t *testing.T) {
db := setupFileTestDB(t)
store := NewFileStore(db)
ctx := context.Background()
for i := 0; i < 25; i++ {
store.Create(ctx, &model.File{Name: "file.bin", BlobSHA256: "hash"})
}
files, total, err := store.List(ctx, nil, 1, 10)
if err != nil {
t.Fatalf("List() error = %v", err)
}
if total != 25 {
t.Errorf("total = %d, want 25", total)
}
if len(files) != 10 {
t.Errorf("page 1 len = %d, want 10", len(files))
}
files, _, _ = store.List(ctx, nil, 3, 10)
if len(files) != 5 {
t.Errorf("page 3 len = %d, want 5", len(files))
}
}
func TestFileStore_Search(t *testing.T) {
db := setupFileTestDB(t)
store := NewFileStore(db)
ctx := context.Background()
store.Create(ctx, &model.File{Name: "experiment_results.csv", BlobSHA256: "a1"})
store.Create(ctx, &model.File{Name: "training_log.txt", BlobSHA256: "a2"})
store.Create(ctx, &model.File{Name: "model_weights.bin", BlobSHA256: "a3"})
files, total, err := store.Search(ctx, "results", 1, 10)
if err != nil {
t.Fatalf("Search() error = %v", err)
}
if total != 1 {
t.Errorf("total = %d, want 1", total)
}
if len(files) != 1 || files[0].Name != "experiment_results.csv" {
t.Errorf("expected experiment_results.csv, got %v", files)
}
}
func TestFileStore_Delete_SoftDelete(t *testing.T) {
db := setupFileTestDB(t)
store := NewFileStore(db)
ctx := context.Background()
file := &model.File{Name: "deleteme.bin", BlobSHA256: "abc"}
store.Create(ctx, file)
if err := store.Delete(ctx, file.ID); err != nil {
t.Fatalf("Delete() error = %v", err)
}
// GetByID should return nil (soft deleted)
got, err := store.GetByID(ctx, file.ID)
if err != nil {
t.Fatalf("GetByID() error = %v", err)
}
if got != nil {
t.Error("GetByID() should return nil after soft delete")
}
// List should not include soft deleted
_, total, _ := store.List(ctx, nil, 1, 10)
if total != 0 {
t.Errorf("total after delete = %d, want 0", total)
}
}
func TestFileStore_CountByBlobSHA256(t *testing.T) {
db := setupFileTestDB(t)
store := NewFileStore(db)
ctx := context.Background()
store.Create(ctx, &model.File{Name: "a.bin", BlobSHA256: "shared_hash"})
store.Create(ctx, &model.File{Name: "b.bin", BlobSHA256: "shared_hash"})
store.Create(ctx, &model.File{Name: "c.bin", BlobSHA256: "shared_hash"})
count, err := store.CountByBlobSHA256(ctx, "shared_hash")
if err != nil {
t.Fatalf("CountByBlobSHA256() error = %v", err)
}
if count != 3 {
t.Errorf("count = %d, want 3", count)
}
// Soft delete one
store.Delete(ctx, 1)
count, _ = store.CountByBlobSHA256(ctx, "shared_hash")
if count != 2 {
t.Errorf("count after soft delete = %d, want 2", count)
}
}
func TestFileStore_GetBlobSHA256ByID(t *testing.T) {
db := setupFileTestDB(t)
store := NewFileStore(db)
ctx := context.Background()
file := &model.File{Name: "test.bin", BlobSHA256: "my_hash"}
store.Create(ctx, file)
sha256, err := store.GetBlobSHA256ByID(ctx, file.ID)
if err != nil {
t.Fatalf("GetBlobSHA256ByID() error = %v", err)
}
if sha256 != "my_hash" {
t.Errorf("sha256 = %q, want %q", sha256, "my_hash")
}
}
func TestFileStore_GetByIDs(t *testing.T) {
db := setupFileTestDB(t)
store := NewFileStore(db)
ctx := context.Background()
store.Create(ctx, &model.File{Name: "a.bin", BlobSHA256: "h1"})
store.Create(ctx, &model.File{Name: "b.bin", BlobSHA256: "h2"})
store.Create(ctx, &model.File{Name: "c.bin", BlobSHA256: "h3"})
files, err := store.GetByIDs(ctx, []int64{1, 3})
if err != nil {
t.Fatalf("GetByIDs() error = %v", err)
}
if len(files) != 2 {
t.Fatalf("len(files) = %d, want 2", len(files))
}
names := map[string]bool{}
for _, f := range files {
names[f.Name] = true
}
if !names["a.bin"] || !names["c.bin"] {
t.Errorf("expected a.bin and c.bin, got %v", files)
}
}
func TestFileStore_GetByIDs_Empty(t *testing.T) {
db := setupFileTestDB(t)
store := NewFileStore(db)
ctx := context.Background()
files, err := store.GetByIDs(ctx, []int64{})
if err != nil {
t.Fatalf("GetByIDs() error = %v", err)
}
if len(files) != 0 {
t.Errorf("len(files) = %d, want 0", len(files))
}
}
func TestFileStore_GetByIDs_NotFound(t *testing.T) {
db := setupFileTestDB(t)
store := NewFileStore(db)
ctx := context.Background()
files, err := store.GetByIDs(ctx, []int64{999})
if err != nil {
t.Fatalf("GetByIDs() error = %v", err)
}
if len(files) != 0 {
t.Errorf("len(files) = %d, want 0 for non-existent IDs", len(files))
}
}
func TestFileStore_GetByIDs_SoftDeleteExcluded(t *testing.T) {
db := setupFileTestDB(t)
store := NewFileStore(db)
ctx := context.Background()
store.Create(ctx, &model.File{Name: "a.bin", BlobSHA256: "h1"})
store.Create(ctx, &model.File{Name: "b.bin", BlobSHA256: "h2"})
store.Create(ctx, &model.File{Name: "c.bin", BlobSHA256: "h3"})
store.Delete(ctx, 2)
files, err := store.GetByIDs(ctx, []int64{1, 2, 3})
if err != nil {
t.Fatalf("GetByIDs() error = %v", err)
}
if len(files) != 2 {
t.Fatalf("len(files) = %d, want 2 (soft-deleted excluded)", len(files))
}
for _, f := range files {
if f.ID == 2 {
t.Error("soft-deleted file ID 2 should not appear")
}
}
}
func TestFileStore_GetBlobSHA256ByID_NotFound(t *testing.T) {
db := setupFileTestDB(t)
store := NewFileStore(db)
ctx := context.Background()
sha256, err := store.GetBlobSHA256ByID(ctx, 999)
if err != nil {
t.Fatalf("GetBlobSHA256ByID() error = %v", err)
}
if sha256 != "" {
t.Errorf("sha256 = %q, want empty for not found", sha256)
}
}

View File

@@ -0,0 +1,105 @@
package store
import (
"context"
"errors"
"fmt"
"gcy_hpc_server/internal/model"
"gorm.io/gorm"
)
type FolderStore struct {
db *gorm.DB
}
func NewFolderStore(db *gorm.DB) *FolderStore {
return &FolderStore{db: db}
}
func (s *FolderStore) Create(ctx context.Context, folder *model.Folder) error {
return s.db.WithContext(ctx).Create(folder).Error
}
func (s *FolderStore) GetByID(ctx context.Context, id int64) (*model.Folder, error) {
var folder model.Folder
err := s.db.WithContext(ctx).First(&folder, id).Error
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, nil
}
if err != nil {
return nil, err
}
return &folder, nil
}
func (s *FolderStore) GetByPath(ctx context.Context, path string) (*model.Folder, error) {
var folder model.Folder
err := s.db.WithContext(ctx).Where("path = ?", path).First(&folder).Error
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, nil
}
if err != nil {
return nil, err
}
return &folder, nil
}
func (s *FolderStore) ListByParentID(ctx context.Context, parentID *int64) ([]model.Folder, error) {
var folders []model.Folder
query := s.db.WithContext(ctx)
if parentID == nil {
query = query.Where("parent_id IS NULL")
} else {
query = query.Where("parent_id = ?", *parentID)
}
if err := query.Order("name ASC").Find(&folders).Error; err != nil {
return nil, err
}
return folders, nil
}
// GetSubTree returns all folders whose path starts with the given prefix.
func (s *FolderStore) GetSubTree(ctx context.Context, path string) ([]model.Folder, error) {
var folders []model.Folder
if err := s.db.WithContext(ctx).Where("path LIKE ?", path+"%").Find(&folders).Error; err != nil {
return nil, err
}
return folders, nil
}
// HasChildren checks if a folder has sub-folders or files.
func (s *FolderStore) HasChildren(ctx context.Context, id int64) (bool, error) {
folder, err := s.GetByID(ctx, id)
if err != nil {
return false, err
}
if folder == nil {
return false, nil
}
// Check for sub-folders
var subFolderCount int64
if err := s.db.WithContext(ctx).Model(&model.Folder{}).Where("parent_id = ?", id).Count(&subFolderCount).Error; err != nil {
return false, fmt.Errorf("count sub-folders: %w", err)
}
if subFolderCount > 0 {
return true, nil
}
// Check for files
var fileCount int64
if err := s.db.WithContext(ctx).Model(&model.File{}).Where("folder_id = ?", id).Count(&fileCount).Error; err != nil {
return false, fmt.Errorf("count files: %w", err)
}
return fileCount > 0, nil
}
func (s *FolderStore) Delete(ctx context.Context, id int64) error {
result := s.db.WithContext(ctx).Delete(&model.Folder{}, id)
if result.Error != nil {
return result.Error
}
return nil
}

View File

@@ -0,0 +1,294 @@
package store
import (
"context"
"testing"
"gcy_hpc_server/internal/model"
"gorm.io/driver/sqlite"
"gorm.io/gorm"
"gorm.io/gorm/logger"
)
func setupFolderTestDB(t *testing.T) *gorm.DB {
t.Helper()
db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{
Logger: logger.Default.LogMode(logger.Silent),
})
if err != nil {
t.Fatalf("open sqlite: %v", err)
}
if err := db.AutoMigrate(&model.Folder{}, &model.File{}); err != nil {
t.Fatalf("migrate: %v", err)
}
return db
}
func TestFolderStore_CreateAndGetByID(t *testing.T) {
db := setupFolderTestDB(t)
s := NewFolderStore(db)
folder := &model.Folder{
Name: "data",
Path: "/data/",
}
if err := s.Create(context.Background(), folder); err != nil {
t.Fatalf("Create() error = %v", err)
}
if folder.ID <= 0 {
t.Fatalf("Create() id = %d, want positive", folder.ID)
}
got, err := s.GetByID(context.Background(), folder.ID)
if err != nil {
t.Fatalf("GetByID() error = %v", err)
}
if got == nil {
t.Fatal("GetByID() returned nil")
}
if got.Name != "data" {
t.Errorf("Name = %q, want %q", got.Name, "data")
}
if got.Path != "/data/" {
t.Errorf("Path = %q, want %q", got.Path, "/data/")
}
}
func TestFolderStore_GetByID_NotFound(t *testing.T) {
db := setupFolderTestDB(t)
s := NewFolderStore(db)
got, err := s.GetByID(context.Background(), 99999)
if err != nil {
t.Fatalf("GetByID() error = %v", err)
}
if got != nil {
t.Error("GetByID() expected nil for not-found")
}
}
func TestFolderStore_GetByPath(t *testing.T) {
db := setupFolderTestDB(t)
s := NewFolderStore(db)
folder := &model.Folder{
Name: "data",
Path: "/data/",
}
if err := s.Create(context.Background(), folder); err != nil {
t.Fatalf("Create() error = %v", err)
}
got, err := s.GetByPath(context.Background(), "/data/")
if err != nil {
t.Fatalf("GetByPath() error = %v", err)
}
if got == nil {
t.Fatal("GetByPath() returned nil")
}
if got.ID != folder.ID {
t.Errorf("ID = %d, want %d", got.ID, folder.ID)
}
got, err = s.GetByPath(context.Background(), "/nonexistent/")
if err != nil {
t.Fatalf("GetByPath() nonexistent error = %v", err)
}
if got != nil {
t.Error("GetByPath() expected nil for nonexistent path")
}
}
func TestFolderStore_ListByParentID(t *testing.T) {
db := setupFolderTestDB(t)
s := NewFolderStore(db)
root1 := &model.Folder{Name: "alpha", Path: "/alpha/"}
root2 := &model.Folder{Name: "beta", Path: "/beta/"}
if err := s.Create(context.Background(), root1); err != nil {
t.Fatalf("Create root1: %v", err)
}
if err := s.Create(context.Background(), root2); err != nil {
t.Fatalf("Create root2: %v", err)
}
sub := &model.Folder{Name: "sub", Path: "/alpha/sub/", ParentID: &root1.ID}
if err := s.Create(context.Background(), sub); err != nil {
t.Fatalf("Create sub: %v", err)
}
roots, err := s.ListByParentID(context.Background(), nil)
if err != nil {
t.Fatalf("ListByParentID(nil) error = %v", err)
}
if len(roots) != 2 {
t.Fatalf("root folders = %d, want 2", len(roots))
}
if roots[0].Name != "alpha" {
t.Errorf("roots[0].Name = %q, want %q (alphabetical)", roots[0].Name, "alpha")
}
children, err := s.ListByParentID(context.Background(), &root1.ID)
if err != nil {
t.Fatalf("ListByParentID(root1) error = %v", err)
}
if len(children) != 1 {
t.Fatalf("children = %d, want 1", len(children))
}
if children[0].Name != "sub" {
t.Errorf("children[0].Name = %q, want %q", children[0].Name, "sub")
}
}
func TestFolderStore_GetSubTree(t *testing.T) {
db := setupFolderTestDB(t)
s := NewFolderStore(db)
data := &model.Folder{Name: "data", Path: "/data/"}
if err := s.Create(context.Background(), data); err != nil {
t.Fatalf("Create data: %v", err)
}
results := &model.Folder{Name: "results", Path: "/data/results/", ParentID: &data.ID}
if err := s.Create(context.Background(), results); err != nil {
t.Fatalf("Create results: %v", err)
}
other := &model.Folder{Name: "other", Path: "/other/"}
if err := s.Create(context.Background(), other); err != nil {
t.Fatalf("Create other: %v", err)
}
subtree, err := s.GetSubTree(context.Background(), "/data/")
if err != nil {
t.Fatalf("GetSubTree() error = %v", err)
}
if len(subtree) != 2 {
t.Fatalf("subtree = %d, want 2", len(subtree))
}
}
func TestFolderStore_HasChildren_WithSubFolders(t *testing.T) {
db := setupFolderTestDB(t)
s := NewFolderStore(db)
parent := &model.Folder{Name: "parent", Path: "/parent/"}
if err := s.Create(context.Background(), parent); err != nil {
t.Fatalf("Create parent: %v", err)
}
child := &model.Folder{Name: "child", Path: "/parent/child/", ParentID: &parent.ID}
if err := s.Create(context.Background(), child); err != nil {
t.Fatalf("Create child: %v", err)
}
has, err := s.HasChildren(context.Background(), parent.ID)
if err != nil {
t.Fatalf("HasChildren() error = %v", err)
}
if !has {
t.Error("HasChildren() = false, want true (has sub-folders)")
}
}
func TestFolderStore_HasChildren_WithFiles(t *testing.T) {
db := setupFolderTestDB(t)
s := NewFolderStore(db)
folder := &model.Folder{Name: "docs", Path: "/docs/"}
if err := s.Create(context.Background(), folder); err != nil {
t.Fatalf("Create folder: %v", err)
}
file := &model.File{
Name: "readme.txt",
FolderID: &folder.ID,
BlobSHA256: "abc123",
}
if err := db.Create(file).Error; err != nil {
t.Fatalf("Create file: %v", err)
}
has, err := s.HasChildren(context.Background(), folder.ID)
if err != nil {
t.Fatalf("HasChildren() error = %v", err)
}
if !has {
t.Error("HasChildren() = false, want true (has files)")
}
}
func TestFolderStore_HasChildren_Empty(t *testing.T) {
db := setupFolderTestDB(t)
s := NewFolderStore(db)
folder := &model.Folder{Name: "empty", Path: "/empty/"}
if err := s.Create(context.Background(), folder); err != nil {
t.Fatalf("Create folder: %v", err)
}
has, err := s.HasChildren(context.Background(), folder.ID)
if err != nil {
t.Fatalf("HasChildren() error = %v", err)
}
if has {
t.Error("HasChildren() = true, want false (empty folder)")
}
}
func TestFolderStore_HasChildren_NotFound(t *testing.T) {
db := setupFolderTestDB(t)
s := NewFolderStore(db)
has, err := s.HasChildren(context.Background(), 99999)
if err != nil {
t.Fatalf("HasChildren() error = %v", err)
}
if has {
t.Error("HasChildren() = true for nonexistent, want false")
}
}
func TestFolderStore_Delete(t *testing.T) {
db := setupFolderTestDB(t)
s := NewFolderStore(db)
folder := &model.Folder{Name: "temp", Path: "/temp/"}
if err := s.Create(context.Background(), folder); err != nil {
t.Fatalf("Create() error = %v", err)
}
if err := s.Delete(context.Background(), folder.ID); err != nil {
t.Fatalf("Delete() error = %v", err)
}
got, err := s.GetByID(context.Background(), folder.ID)
if err != nil {
t.Fatalf("GetByID() after delete error = %v", err)
}
if got != nil {
t.Error("GetByID() after delete returned non-nil, expected soft-deleted")
}
}
func TestFolderStore_Delete_Idempotent(t *testing.T) {
db := setupFolderTestDB(t)
s := NewFolderStore(db)
if err := s.Delete(context.Background(), 99999); err != nil {
t.Fatalf("Delete() non-existent error = %v, want nil (idempotent)", err)
}
}
func TestFolderStore_Path_UniqueConstraint(t *testing.T) {
db := setupFolderTestDB(t)
s := NewFolderStore(db)
f1 := &model.Folder{Name: "data", Path: "/data/"}
if err := s.Create(context.Background(), f1); err != nil {
t.Fatalf("Create first: %v", err)
}
f2 := &model.Folder{Name: "data2", Path: "/data/"}
if err := s.Create(context.Background(), f2); err == nil {
t.Fatal("expected error for duplicate path, got nil")
}
}

52
internal/store/mysql.go Normal file
View File

@@ -0,0 +1,52 @@
package store
import (
"fmt"
"time"
"gcy_hpc_server/internal/logger"
"gcy_hpc_server/internal/model"
"go.uber.org/zap"
"gorm.io/driver/mysql"
"gorm.io/gorm"
)
// NewGormDB opens a GORM MySQL connection with sensible defaults.
func NewGormDB(dsn string, zapLogger *zap.Logger, gormLevel string) (*gorm.DB, error) {
gormCfg := &gorm.Config{
Logger: logger.NewGormLogger(zapLogger, gormLevel),
}
db, err := gorm.Open(mysql.Open(dsn), gormCfg)
if err != nil {
return nil, fmt.Errorf("failed to open gorm mysql: %w", err)
}
sqlDB, err := db.DB()
if err != nil {
return nil, fmt.Errorf("failed to get underlying sql.DB: %w", err)
}
sqlDB.SetMaxOpenConns(25)
sqlDB.SetMaxIdleConns(5)
sqlDB.SetConnMaxLifetime(5 * time.Minute)
if err := sqlDB.Ping(); err != nil {
return nil, fmt.Errorf("failed to ping mysql: %w", err)
}
return db, nil
}
// AutoMigrate runs GORM auto-migration for all models.
func AutoMigrate(db *gorm.DB) error {
return db.AutoMigrate(
&model.Application{},
&model.FileBlob{},
&model.File{},
&model.Folder{},
&model.UploadSession{},
&model.UploadChunk{},
&model.Task{},
)
}

View File

@@ -0,0 +1,14 @@
package store
import (
"testing"
"go.uber.org/zap"
)
func TestNewGormDBInvalidDSN(t *testing.T) {
_, err := NewGormDB("invalid:dsn@tcp(localhost:99999)/nonexistent?parseTime=true", zap.NewNop(), "warn")
if err == nil {
t.Fatal("expected error for invalid DSN, got nil")
}
}

View File

@@ -0,0 +1,141 @@
package store
import (
"context"
"errors"
"fmt"
"time"
"gcy_hpc_server/internal/model"
"gorm.io/gorm"
)
type TaskStore struct {
db *gorm.DB
}
func NewTaskStore(db *gorm.DB) *TaskStore {
return &TaskStore{db: db}
}
func (s *TaskStore) Create(ctx context.Context, task *model.Task) (int64, error) {
if err := s.db.WithContext(ctx).Create(task).Error; err != nil {
return 0, err
}
return task.ID, nil
}
func (s *TaskStore) GetByID(ctx context.Context, id int64) (*model.Task, error) {
var task model.Task
err := s.db.WithContext(ctx).First(&task, id).Error
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, nil
}
if err != nil {
return nil, err
}
return &task, nil
}
func (s *TaskStore) List(ctx context.Context, query *model.TaskListQuery) ([]model.Task, int64, error) {
page := query.Page
pageSize := query.PageSize
if page <= 0 {
page = 1
}
if pageSize <= 0 {
pageSize = 10
}
q := s.db.WithContext(ctx).Model(&model.Task{})
if query.Status != "" {
q = q.Where("status = ?", query.Status)
}
var total int64
if err := q.Count(&total).Error; err != nil {
return nil, 0, err
}
var tasks []model.Task
offset := (page - 1) * pageSize
if err := q.Order("id DESC").Limit(pageSize).Offset(offset).Find(&tasks).Error; err != nil {
return nil, 0, err
}
return tasks, total, nil
}
func (s *TaskStore) UpdateStatus(ctx context.Context, id int64, status, errorMsg string) error {
updates := map[string]interface{}{
"status": status,
"error_message": errorMsg,
}
now := time.Now()
switch status {
case model.TaskStatusPreparing, model.TaskStatusDownloading, model.TaskStatusReady,
model.TaskStatusQueued, model.TaskStatusRunning:
updates["started_at"] = &now
case model.TaskStatusCompleted, model.TaskStatusFailed:
updates["finished_at"] = &now
}
result := s.db.WithContext(ctx).Model(&model.Task{}).Where("id = ?", id).Updates(updates)
if result.Error != nil {
return result.Error
}
if result.RowsAffected == 0 {
return fmt.Errorf("task %d not found", id)
}
return nil
}
func (s *TaskStore) UpdateSlurmJobID(ctx context.Context, id int64, slurmJobID *int32) error {
result := s.db.WithContext(ctx).Model(&model.Task{}).Where("id = ?", id).
Update("slurm_job_id", slurmJobID)
if result.Error != nil {
return result.Error
}
if result.RowsAffected == 0 {
return fmt.Errorf("task %d not found", id)
}
return nil
}
func (s *TaskStore) UpdateWorkDir(ctx context.Context, id int64, workDir string) error {
result := s.db.WithContext(ctx).Model(&model.Task{}).Where("id = ?", id).
Update("work_dir", workDir)
if result.Error != nil {
return result.Error
}
if result.RowsAffected == 0 {
return fmt.Errorf("task %d not found", id)
}
return nil
}
func (s *TaskStore) UpdateRetryState(ctx context.Context, id int64, status, currentStep string, retryCount int) error {
result := s.db.WithContext(ctx).Model(&model.Task{}).Where("id = ?", id).Updates(map[string]interface{}{
"status": status,
"current_step": currentStep,
"retry_count": retryCount,
})
if result.Error != nil {
return result.Error
}
if result.RowsAffected == 0 {
return fmt.Errorf("task %d not found", id)
}
return nil
}
func (s *TaskStore) GetStuckTasks(ctx context.Context, maxAge time.Duration) ([]model.Task, error) {
cutoff := time.Now().Add(-maxAge)
var tasks []model.Task
err := s.db.WithContext(ctx).
Where("status NOT IN ?", []string{model.TaskStatusCompleted, model.TaskStatusFailed}).
Where("updated_at < ?", cutoff).
Find(&tasks).Error
return tasks, err
}

View File

@@ -0,0 +1,229 @@
package store
import (
"context"
"testing"
"time"
"gcy_hpc_server/internal/model"
"gorm.io/driver/sqlite"
"gorm.io/gorm"
"gorm.io/gorm/logger"
)
func newTaskTestDB(t *testing.T) *gorm.DB {
t.Helper()
db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{
Logger: logger.Default.LogMode(logger.Silent),
})
if err != nil {
t.Fatalf("open sqlite: %v", err)
}
if err := db.AutoMigrate(&model.Task{}); err != nil {
t.Fatalf("auto migrate: %v", err)
}
return db
}
func makeTestTask(name, status string) *model.Task {
return &model.Task{
TaskName: name,
AppID: 1,
AppName: "test-app",
Status: status,
CurrentStep: "",
RetryCount: 0,
UserID: "user1",
SubmittedAt: time.Now(),
}
}
func TestTaskStore_CreateAndGetByID(t *testing.T) {
db := newTaskTestDB(t)
s := NewTaskStore(db)
ctx := context.Background()
task := makeTestTask("test-task", model.TaskStatusSubmitted)
id, err := s.Create(ctx, task)
if err != nil {
t.Fatalf("Create() error = %v", err)
}
if id <= 0 {
t.Errorf("Create() id = %d, want positive", id)
}
got, err := s.GetByID(ctx, id)
if err != nil {
t.Fatalf("GetByID() error = %v", err)
}
if got == nil {
t.Fatal("GetByID() returned nil")
}
if got.TaskName != "test-task" {
t.Errorf("TaskName = %q, want %q", got.TaskName, "test-task")
}
if got.Status != model.TaskStatusSubmitted {
t.Errorf("Status = %q, want %q", got.Status, model.TaskStatusSubmitted)
}
if got.ID != id {
t.Errorf("ID = %d, want %d", got.ID, id)
}
}
func TestTaskStore_GetByID_NotFound(t *testing.T) {
db := newTaskTestDB(t)
s := NewTaskStore(db)
got, err := s.GetByID(context.Background(), 999)
if err != nil {
t.Fatalf("GetByID() error = %v", err)
}
if got != nil {
t.Error("GetByID() expected nil for not-found, got non-nil")
}
}
func TestTaskStore_ListPagination(t *testing.T) {
db := newTaskTestDB(t)
s := NewTaskStore(db)
ctx := context.Background()
for i := 0; i < 5; i++ {
s.Create(ctx, makeTestTask("task-"+string(rune('A'+i)), model.TaskStatusSubmitted))
}
tasks, total, err := s.List(ctx, &model.TaskListQuery{Page: 1, PageSize: 3})
if err != nil {
t.Fatalf("List() error = %v", err)
}
if total != 5 {
t.Errorf("total = %d, want 5", total)
}
if len(tasks) != 3 {
t.Errorf("len(tasks) = %d, want 3", len(tasks))
}
tasks2, total2, err := s.List(ctx, &model.TaskListQuery{Page: 2, PageSize: 3})
if err != nil {
t.Fatalf("List() page 2 error = %v", err)
}
if total2 != 5 {
t.Errorf("total2 = %d, want 5", total2)
}
if len(tasks2) != 2 {
t.Errorf("len(tasks2) = %d, want 2", len(tasks2))
}
}
func TestTaskStore_ListStatusFilter(t *testing.T) {
db := newTaskTestDB(t)
s := NewTaskStore(db)
ctx := context.Background()
s.Create(ctx, makeTestTask("running-1", model.TaskStatusRunning))
s.Create(ctx, makeTestTask("running-2", model.TaskStatusRunning))
s.Create(ctx, makeTestTask("completed-1", model.TaskStatusCompleted))
tasks, total, err := s.List(ctx, &model.TaskListQuery{Page: 1, PageSize: 10, Status: model.TaskStatusRunning})
if err != nil {
t.Fatalf("List() error = %v", err)
}
if total != 2 {
t.Errorf("total = %d, want 2", total)
}
for _, t2 := range tasks {
if t2.Status != model.TaskStatusRunning {
t.Errorf("Status = %q, want %q", t2.Status, model.TaskStatusRunning)
}
}
}
func TestTaskStore_UpdateStatus(t *testing.T) {
db := newTaskTestDB(t)
s := NewTaskStore(db)
ctx := context.Background()
id, _ := s.Create(ctx, makeTestTask("status-test", model.TaskStatusSubmitted))
err := s.UpdateStatus(ctx, id, model.TaskStatusPreparing, "")
if err != nil {
t.Fatalf("UpdateStatus() error = %v", err)
}
got, _ := s.GetByID(ctx, id)
if got.Status != model.TaskStatusPreparing {
t.Errorf("Status = %q, want %q", got.Status, model.TaskStatusPreparing)
}
if got.StartedAt == nil {
t.Error("StartedAt expected non-nil after preparing status")
}
err = s.UpdateStatus(ctx, id, model.TaskStatusFailed, "something broke")
if err != nil {
t.Fatalf("UpdateStatus(failed) error = %v", err)
}
got, _ = s.GetByID(ctx, id)
if got.ErrorMessage != "something broke" {
t.Errorf("ErrorMessage = %q, want %q", got.ErrorMessage, "something broke")
}
if got.FinishedAt == nil {
t.Error("FinishedAt expected non-nil after failed status")
}
}
func TestTaskStore_UpdateRetryState(t *testing.T) {
db := newTaskTestDB(t)
s := NewTaskStore(db)
ctx := context.Background()
task := makeTestTask("retry-test", model.TaskStatusFailed)
task.CurrentStep = model.TaskStepDownloading
task.RetryCount = 1
id, _ := s.Create(ctx, task)
err := s.UpdateRetryState(ctx, id, model.TaskStatusSubmitted, model.TaskStepDownloading, 2)
if err != nil {
t.Fatalf("UpdateRetryState() error = %v", err)
}
got, _ := s.GetByID(ctx, id)
if got.Status != model.TaskStatusSubmitted {
t.Errorf("Status = %q, want %q", got.Status, model.TaskStatusSubmitted)
}
if got.CurrentStep != model.TaskStepDownloading {
t.Errorf("CurrentStep = %q, want %q", got.CurrentStep, model.TaskStepDownloading)
}
if got.RetryCount != 2 {
t.Errorf("RetryCount = %d, want 2", got.RetryCount)
}
}
func TestTaskStore_GetStuckTasks(t *testing.T) {
db := newTaskTestDB(t)
s := NewTaskStore(db)
ctx := context.Background()
stuck := makeTestTask("stuck-1", model.TaskStatusDownloading)
stuck.UpdatedAt = time.Now().Add(-1 * time.Hour)
s.Create(ctx, stuck)
recent := makeTestTask("recent-1", model.TaskStatusDownloading)
s.Create(ctx, recent)
done := makeTestTask("done-1", model.TaskStatusCompleted)
done.UpdatedAt = time.Now().Add(-2 * time.Hour)
s.Create(ctx, done)
tasks, err := s.GetStuckTasks(ctx, 30*time.Minute)
if err != nil {
t.Fatalf("GetStuckTasks() error = %v", err)
}
if len(tasks) != 1 {
t.Fatalf("len(tasks) = %d, want 1", len(tasks))
}
if tasks[0].TaskName != "stuck-1" {
t.Errorf("TaskName = %q, want %q", tasks[0].TaskName, "stuck-1")
}
}

View File

@@ -0,0 +1,117 @@
package store
import (
"context"
"errors"
"fmt"
"time"
"gcy_hpc_server/internal/model"
"gorm.io/gorm"
"gorm.io/gorm/clause"
)
// UploadStore manages upload sessions and chunks with idempotent upsert support.
type UploadStore struct {
db *gorm.DB
}
// NewUploadStore creates a new UploadStore.
func NewUploadStore(db *gorm.DB) *UploadStore {
return &UploadStore{db: db}
}
// CreateSession inserts a new upload session.
func (s *UploadStore) CreateSession(ctx context.Context, session *model.UploadSession) error {
return s.db.WithContext(ctx).Create(session).Error
}
// GetSession returns an upload session by ID. Returns (nil, nil) if not found.
func (s *UploadStore) GetSession(ctx context.Context, id int64) (*model.UploadSession, error) {
var session model.UploadSession
err := s.db.WithContext(ctx).First(&session, id).Error
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, nil
}
if err != nil {
return nil, err
}
return &session, nil
}
// GetSessionWithChunks returns an upload session along with its chunks ordered by chunk_index.
func (s *UploadStore) GetSessionWithChunks(ctx context.Context, id int64) (*model.UploadSession, []model.UploadChunk, error) {
session, err := s.GetSession(ctx, id)
if err != nil {
return nil, nil, err
}
if session == nil {
return nil, nil, nil
}
var chunks []model.UploadChunk
if err := s.db.WithContext(ctx).Where("session_id = ?", id).Order("chunk_index ASC").Find(&chunks).Error; err != nil {
return nil, nil, err
}
return session, chunks, nil
}
// UpdateSessionStatus updates the status field of an upload session.
// Returns gorm.ErrRecordNotFound if no row was affected.
func (s *UploadStore) UpdateSessionStatus(ctx context.Context, id int64, status string) error {
result := s.db.WithContext(ctx).Model(&model.UploadSession{}).Where("id = ?", id).Update("status", status)
if result.Error != nil {
return result.Error
}
if result.RowsAffected == 0 {
return gorm.ErrRecordNotFound
}
return nil
}
// ListExpiredSessions returns sessions that are not in a terminal state and have expired.
func (s *UploadStore) ListExpiredSessions(ctx context.Context) ([]model.UploadSession, error) {
var sessions []model.UploadSession
err := s.db.WithContext(ctx).
Where("status NOT IN ?", []string{"completed", "cancelled", "expired"}).
Where("expires_at < ?", time.Now()).
Find(&sessions).Error
return sessions, err
}
// DeleteSession removes all chunks for a session, then the session itself.
func (s *UploadStore) DeleteSession(ctx context.Context, id int64) error {
if err := s.db.WithContext(ctx).Where("session_id = ?", id).Delete(&model.UploadChunk{}).Error; err != nil {
return fmt.Errorf("delete chunks: %w", err)
}
result := s.db.WithContext(ctx).Delete(&model.UploadSession{}, id)
return result.Error
}
// UpsertChunk inserts a chunk or updates it if the (session_id, chunk_index) pair already exists.
// Uses GORM clause.OnConflict for dialect-neutral upsert (works with both SQLite and MySQL).
func (s *UploadStore) UpsertChunk(ctx context.Context, chunk *model.UploadChunk) error {
return s.db.WithContext(ctx).Clauses(clause.OnConflict{
Columns: []clause.Column{{Name: "session_id"}, {Name: "chunk_index"}},
DoUpdates: clause.AssignmentColumns([]string{"minio_key", "sha256", "size", "status", "updated_at"}),
}).Create(chunk).Error
}
// GetUploadedChunkIndices returns the chunk indices that have been successfully uploaded.
func (s *UploadStore) GetUploadedChunkIndices(ctx context.Context, sessionID int64) ([]int, error) {
var indices []int
err := s.db.WithContext(ctx).Model(&model.UploadChunk{}).
Where("session_id = ? AND status = ?", sessionID, "uploaded").
Pluck("chunk_index", &indices).Error
return indices, err
}
// CountUploadedChunks returns the number of chunks with status "uploaded" for a session.
func (s *UploadStore) CountUploadedChunks(ctx context.Context, sessionID int64) (int, error) {
var count int64
err := s.db.WithContext(ctx).Model(&model.UploadChunk{}).
Where("session_id = ? AND status = ?", sessionID, "uploaded").
Count(&count).Error
return int(count), err
}

Some files were not shown because too many files have changed in this diff Show More