Compare commits

...

40 Commits

Author SHA1 Message Date
dailz
9092278d26 refactor(test): update test expectations for removed submit route
- Comment out submit route assertions in main_test.go and server_test.go

- Comment out TestTask_OldAPICompatibility in task_test.go

- Update expected route count 31→30 in testenv env_test.go

Ultraworked with [Sisyphus](https://github.com/code-yeongyu/oh-my-openagent)

Co-authored-by: Sisyphus <clio-agent@sisyphuslabs.ai>
2026-04-16 15:16:07 +08:00
dailz
7c374f4fd5 refactor(handler,server): disable SubmitApplication endpoint, replaced by POST /tasks
- Comment out SubmitApplication handler method

- Comment out route registration in server.go (interface + router + placeholder)

- Comment out related handler tests

Ultraworked with [Sisyphus](https://github.com/code-yeongyu/oh-my-openagent)

Co-authored-by: Sisyphus <clio-agent@sisyphuslabs.ai>
2026-04-16 15:15:55 +08:00
dailz
36d842350c refactor(service): disable SubmitFromApplication fallback, fully replaced by POST /tasks
- Comment out SubmitFromApplication method and its fallback path

- Comment out 5 tests that tested the old direct-submission code

- Remove unused imports after commenting out the method

Ultraworked with [Sisyphus](https://github.com/code-yeongyu/oh-my-openagent)

Co-authored-by: Sisyphus <clio-agent@sisyphuslabs.ai>
2026-04-16 15:15:42 +08:00
dailz
80f2bd32d9 docs(openapi): update spec to match code — add Tasks, fix schemas, remove submit endpoint
- Add Tasks tag, /tasks paths, and Task schemas (CreateTaskRequest, TaskResponse, TaskListResponse)

- Fix SubmitJobRequest.work_dir, InitUploadRequest mime_type/chunk_size, UploadSessionResponse.created_at

- Fix FolderResponse: add file_count/subfolder_count, remove updated_at

- Fix response wrapping for File/Upload/Folder endpoints to use ApiResponseSuccess

- Remove /applications/{id}/submit path and ApplicationSubmitRequest schema

- Update Applications tag description

Ultraworked with [Sisyphus](https://github.com/code-yeongyu/oh-my-openagent)

Co-authored-by: Sisyphus <clio-agent@sisyphuslabs.ai>
2026-04-16 15:15:31 +08:00
dailz
52a34e2cb0 feat(client): add CLI client entry point
2026-04-16 13:24:21 +08:00
dailz
b9b2f0d9b4 feat(testutil): add MockSlurm, MockMinIO, TestEnv and 37 integration tests
- mockminio: in-memory ObjectStorage with all 11 methods, thread-safe, SHA256 ETag, Range support
- mockslurm: httptest server with 11 Slurm REST API endpoints, job eviction from active to history queue
- testenv: one-line test environment factory (SQLite + MockSlurm + MockMinIO + all stores/services/handlers + httptest server)
- integration tests: 37 tests covering Jobs(5), Cluster(5), App(6), Upload(5), File(4), Folder(4), Task(4), E2E(1)
- no external dependencies, no existing files modified
2026-04-16 13:23:27 +08:00
dailz
73504f9fdb feat(app): add TaskPoller, wire DI, and add task integration tests
2026-04-15 21:31:17 +08:00
dailz
3f8a680c99 feat(handler): add TaskHandler endpoints and register task routes
2026-04-15 21:31:11 +08:00
dailz
ec64300ff2 feat(service): add TaskService, FileStagingService, and refactor ApplicationService for task submission
2026-04-15 21:31:02 +08:00
dailz
acf8c1d62b feat(store): add TaskStore CRUD and batch query methods for files and blobs
2026-04-15 21:30:51 +08:00
dailz
d46a784efb feat(model): add Task model, DTOs, and status constants for task submission system
2026-04-15 21:30:44 +08:00
dailz
79870333cb fix(service): tolerate concurrent pending-to-uploading status race in UploadChunk
When multiple chunk uploads race on the pending→uploading transition, ignore ErrRecordNotFound from UpdateSessionStatus since another request already completed the update.

Ultraworked with [Sisyphus](https://github.com/code-yeongyu/oh-my-openagent)

Co-authored-by: Sisyphus <clio-agent@sisyphuslabs.ai>
2026-04-15 10:27:12 +08:00
dailz
d9a60c3511 fix(model): rename Application table to hpc_applications
Avoid table name collision with other systems.

Ultraworked with [Sisyphus](https://github.com/code-yeongyu/oh-my-openagent)

Co-authored-by: Sisyphus <clio-agent@sisyphuslabs.ai>
2026-04-15 09:32:11 +08:00
dailz
20576bc325 docs(openapi): add file storage API specifications
Add 13 endpoints for chunked upload, file management, and folder CRUD with 6 new schemas.

Ultraworked with [Sisyphus](https://github.com/code-yeongyu/oh-my-openagent)

Co-authored-by: Sisyphus <clio-agent@sisyphuslabs.ai>
2026-04-15 09:30:18 +08:00
dailz
c0176d7764 feat(app): wire file storage DI, cleanup worker, and integration tests
Add DI wiring with graceful MinIO fallback, background cleanup worker for expired sessions and leaked multipart uploads, and end-to-end integration tests.

Ultraworked with [Sisyphus](https://github.com/code-yeongyu/oh-my-openagent)

Co-authored-by: Sisyphus <clio-agent@sisyphuslabs.ai>
2026-04-15 09:23:25 +08:00
dailz
2298e92516 feat(handler): add upload, file, and folder handlers with routes
Add UploadHandler (5 endpoints), FileHandler (4 endpoints), FolderHandler (4 endpoints) with Gin route registration in server.go.

Ultraworked with [Sisyphus](https://github.com/code-yeongyu/oh-my-openagent)

Co-authored-by: Sisyphus <clio-agent@sisyphuslabs.ai>
2026-04-15 09:23:17 +08:00
dailz
f0847d3978 feat(service): add upload, download, file, and folder services
Add UploadService (dedup, chunk lifecycle, ComposeObject), DownloadService (Range support), FileService (ref counting), FolderService (path validation).

Ultraworked with [Sisyphus](https://github.com/code-yeongyu/oh-my-openagent)

Co-authored-by: Sisyphus <clio-agent@sisyphuslabs.ai>
2026-04-15 09:23:09 +08:00
dailz
a114821615 feat(server): add streaming response helpers for file download
Add ParseRange, StreamFile, StreamRange for full and partial content delivery.

Ultraworked with [Sisyphus](https://github.com/code-yeongyu/oh-my-openagent)

Co-authored-by: Sisyphus <clio-agent@sisyphuslabs.ai>
2026-04-15 09:22:58 +08:00
dailz
bf89de12f0 feat(store): add blob, file, folder, and upload stores
Add BlobStore (ref counting), FileStore (soft delete + pagination), FolderStore (materialized path), UploadStore (idempotent upsert), and update AutoMigrate.

Ultraworked with [Sisyphus](https://github.com/code-yeongyu/oh-my-openagent)

Co-authored-by: Sisyphus <clio-agent@sisyphuslabs.ai>
2026-04-15 09:22:44 +08:00
dailz
c861ff3adf feat(storage): add ObjectStorage interface and MinIO client
Add ObjectStorage interface (11 methods) with MinioClient implementation using minio-go Core.

Ultraworked with [Sisyphus](https://github.com/code-yeongyu/oh-my-openagent)

Co-authored-by: Sisyphus <clio-agent@sisyphuslabs.ai>
2026-04-15 09:22:33 +08:00
dailz
0e4f523746 feat(model): add file storage GORM models and DTOs
Add FileBlob, File, Folder, UploadSession, UploadChunk models with validators.

Ultraworked with [Sisyphus](https://github.com/code-yeongyu/oh-my-openagent)

Co-authored-by: Sisyphus <clio-agent@sisyphuslabs.ai>
2026-04-15 09:22:25 +08:00
dailz
44895214d4 feat(config): add MinIO object storage configuration
Add MinioConfig struct with connection, bucket, chunk size, and session TTL settings.

Ultraworked with [Sisyphus](https://github.com/code-yeongyu/oh-my-openagent)

Co-authored-by: Sisyphus <clio-agent@sisyphuslabs.ai>
2026-04-15 09:22:18 +08:00
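A hypothetical shape for the MinioConfig this commit describes, covering the settings it lists (connection, bucket, chunk size, session TTL); field names and tags are assumptions, not the actual struct:

```go
// MinioConfig sketch — illustrative fields only.
type MinioConfig struct {
	Endpoint   string        `yaml:"endpoint"`    // host:port of the MinIO server
	AccessKey  string        `yaml:"access_key"`
	SecretKey  string        `yaml:"secret_key"`
	UseSSL     bool          `yaml:"use_ssl"`
	Bucket     string        `yaml:"bucket"`      // target bucket for file blobs
	ChunkSize  int64         `yaml:"chunk_size"`  // default upload chunk size in bytes
	SessionTTL time.Duration `yaml:"session_ttl"` // expiry for pending upload sessions
}
```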
dailz
a65c8762af fix(service): add environment variables and fix work directory permissions for Slurm job submission
Slurm requires environment variables in job submission; without them it returns 'batch job cannot run without an environment'. Also chmod the entire directory path to 0777 to bypass umask, ensuring Slurm and compute node users can write.

Ultraworked with [Sisyphus](https://github.com/code-yeongyu/oh-my-openagent)

Co-authored-by: Sisyphus <clio-agent@sisyphuslabs.ai>
2026-04-14 13:06:51 +08:00
dailz
04f99cc1c4 docs(openapi): update spec for Application Definition
Add 6 application endpoints and schemas to OpenAPI spec. Update .gitignore.

Ultraworked with [Sisyphus](https://github.com/code-yeongyu/oh-my-openagent)

Co-authored-by: Sisyphus <clio-agent@sisyphuslabs.ai>
2026-04-13 17:13:02 +08:00
dailz
32f5792b68 feat(service): pass work directory to Slurm job submission
Add WorkDir to SubmitJobRequest and pass it as CurrentWorkingDirectory to Slurm REST API. Fixes Slurm 500 error when working directory is not specified.

Ultraworked with [Sisyphus](https://github.com/code-yeongyu/oh-my-openagent)

Co-authored-by: Sisyphus <clio-agent@sisyphuslabs.ai>
2026-04-13 17:12:28 +08:00
dailz
328691adff feat(config): add WorkDirBase for application job working directory
Add WorkDirBase config field for auto-generated job working directories. Pattern: {base}/{app_name}/{timestamp}_{random}/

Ultraworked with [Sisyphus](https://github.com/code-yeongyu/oh-my-openagent)

Co-authored-by: Sisyphus <clio-agent@sisyphuslabs.ai>
2026-04-13 17:11:48 +08:00
dailz
10bb15e5b2 feat(handler): add Application handler, routes, and wiring
Add ApplicationHandler with CRUD + Submit endpoints. Register 6 routes, wire in app.go, update main_test.go references. 22 handler tests.

Ultraworked with [Sisyphus](https://github.com/code-yeongyu/oh-my-openagent)

Co-authored-by: Sisyphus <clio-agent@sisyphuslabs.ai>
2026-04-13 17:10:54 +08:00
dailz
d3eb728c2f feat(service): add Application service with parameter validation and script rendering
Add ApplicationService with ValidateParams, RenderScript, SubmitFromApplication. Includes shell escaping, longest-first parameter replacement, and work directory generation. 15 tests.

Ultraworked with [Sisyphus](https://github.com/code-yeongyu/oh-my-openagent)

Co-authored-by: Sisyphus <clio-agent@sisyphuslabs.ai>
2026-04-13 17:10:09 +08:00
dailz
4a8153aa6c feat(model): add Application model and store
Add Application and ParameterSchema models with CRUD store. Includes 10 store tests and ParamType constants.

Ultraworked with [Sisyphus](https://github.com/code-yeongyu/oh-my-openagent)

Co-authored-by: Sisyphus <clio-agent@sisyphuslabs.ai>
2026-04-13 17:08:24 +08:00
dailz
dd8d226e78 refactor: remove JobTemplate production code
Remove all JobTemplate model, store, handler, migrations, and wiring. Replaced by Application Definition system.

Ultraworked with [Sisyphus](https://github.com/code-yeongyu/oh-my-openagent)

Co-authored-by: Sisyphus <clio-agent@sisyphuslabs.ai>
2026-04-13 17:07:46 +08:00
dailz
62e458cb7a docs(openapi): update GET /jobs with pagination and JobListResponse
Add page/page_size query parameters, change response from JobResponse[] to JobListResponse, add 400 error code.

Ultraworked with [Sisyphus](https://github.com/code-yeongyu/oh-my-openagent)

Co-authored-by: Sisyphus <clio-agent@sisyphuslabs.ai>
2026-04-10 15:15:24 +08:00
dailz
2cb6fbecdd feat(service): add pagination to GetJobs endpoint
GetJobs now accepts page/page_size query parameters and returns JobListResponse instead of raw array. Uses in-memory pagination matching GetJobHistory pattern.

Ultraworked with [Sisyphus](https://github.com/code-yeongyu/oh-my-openagent)

Co-authored-by: Sisyphus <clio-agent@sisyphuslabs.ai>
2026-04-10 15:14:56 +08:00
dailz
35a4017b8e docs(model): add Chinese field comments to all model structs
Add inline comments to SubmitJobRequest, JobListResponse, JobHistoryQuery, JobTemplate, CreateTemplateRequest, and UpdateTemplateRequest fields, consistent with existing cluster.go and JobResponse style.

Ultraworked with [Sisyphus](https://github.com/code-yeongyu/oh-my-openagent)

Co-authored-by: Sisyphus <clio-agent@sisyphuslabs.ai>
2026-04-10 13:53:54 +08:00
dailz
f4177dd287 feat(service): add GetJob fallback to SlurmDBD history and expand query params
GetJob now falls back to SlurmDBD history when active queue returns 404 or empty jobs. Expand JobHistoryQuery from 7 to 16 filter params (add SubmitTime, Cluster, Qos, Constraints, ExitCode, Node, Reservation, Groups, Wckey).

Ultraworked with [Sisyphus](https://github.com/code-yeongyu/oh-my-openagent)

Co-authored-by: Sisyphus <clio-agent@sisyphuslabs.ai>
2026-04-10 13:43:31 +08:00
dailz
b3d787c97b fix(slurm): parse structured errors from non-2xx Slurm API responses
Replace ErrorResponse with SlurmAPIError that extracts structured errors/warnings from JSON body when Slurm returns non-2xx (e.g. 404 with valid JSON). Add IsNotFound helper for fallback logic.

Ultraworked with [Sisyphus](https://github.com/code-yeongyu/oh-my-openagent)

Co-authored-by: Sisyphus <clio-agent@sisyphuslabs.ai>
2026-04-10 13:43:17 +08:00
dailz
30f0fbc34b fix(slurm): correct PartitionInfoMaximums CpusPerNode/CpusPerSocket types to Uint32NoVal
Ultraworked with [Sisyphus](https://github.com/code-yeongyu/oh-my-openagent)

Co-authored-by: Sisyphus <clio-agent@sisyphuslabs.ai>
2026-04-10 11:39:29 +08:00
dailz
34ba617cbf fix(test): update log assertions for debug logging and field expansion
Ultraworked with [Sisyphus](https://github.com/code-yeongyu/oh-my-openagent)

Co-authored-by: Sisyphus <clio-agent@sisyphuslabs.ai>
2026-04-10 11:13:13 +08:00
dailz
824d9e816f feat(service): map additional Slurm SDK fields and fix ExitCode/Default bugs
Ultraworked with [Sisyphus](https://github.com/code-yeongyu/oh-my-openagent)

Co-authored-by: Sisyphus <clio-agent@sisyphuslabs.ai>
2026-04-10 11:12:51 +08:00
dailz
85901fe18a feat(model): expand API response fields to expose full Slurm data
Ultraworked with [Sisyphus](https://github.com/code-yeongyu/oh-my-openagent)

Co-authored-by: Sisyphus <clio-agent@sisyphuslabs.ai>
2026-04-10 11:12:33 +08:00
dailz
270552ba9a feat(service): add debug logging for Slurm API calls with request/response body and latency
Ultraworked with [Sisyphus](https://github.com/code-yeongyu/oh-my-openagent)

Co-authored-by: Sisyphus <clio-agent@sisyphuslabs.ai>
2026-04-10 10:28:58 +08:00
103 changed files with 22673 additions and 1147 deletions

.gitignore (vendored, 1 changed line)

@@ -1,3 +1,2 @@
bin/
*.exe
.sisyphus/

cmd/client/main.go (new file, 662 lines)

@@ -0,0 +1,662 @@
package main
import (
"bytes"
"crypto/sha256"
"encoding/hex"
"encoding/json"
"fmt"
"io"
"mime/multipart"
"net/http"
"os"
"path/filepath"
"strconv"
"strings"
"sync"
"sync/atomic"
"time"
)
// ── API response types ───────────────────────────────────────────────
type apiResponse struct {
Success bool `json:"success"`
Data json.RawMessage `json:"data"`
Error string `json:"error,omitempty"`
}
type sessionResponse struct {
ID int64 `json:"id"`
FileName string `json:"file_name"`
FileSize int64 `json:"file_size"`
ChunkSize int64 `json:"chunk_size"`
TotalChunks int `json:"total_chunks"`
SHA256 string `json:"sha256"`
Status string `json:"status"`
UploadedChunks []int `json:"uploaded_chunks"`
}
type fileResponse struct {
ID int64 `json:"id"`
Name string `json:"name"`
Size int64 `json:"size"`
MimeType string `json:"mime_type"`
SHA256 string `json:"sha256"`
CreatedAt string `json:"created_at"`
}
type folderResponse struct {
ID int64 `json:"id"`
Name string `json:"name"`
ParentID *int64 `json:"parent_id,omitempty"`
Path string `json:"path"`
CreatedAt string `json:"created_at"`
}
type listFilesResponse struct {
Files []fileResponse `json:"files"`
Total int64 `json:"total"`
Page int `json:"page"`
PageSize int `json:"page_size"`
}
// ── Helpers ──────────────────────────────────────────────────────────
func formatSize(b int64) string {
const (
KB = 1024
MB = KB * 1024
GB = MB * 1024
)
switch {
case b >= GB:
return fmt.Sprintf("%.2f GB", float64(b)/float64(GB))
case b >= MB:
return fmt.Sprintf("%.2f MB", float64(b)/float64(MB))
case b >= KB:
return fmt.Sprintf("%.2f KB", float64(b)/float64(KB))
default:
return fmt.Sprintf("%d B", b)
}
}
func doRequest(server, method, path string, body io.Reader, contentType string) (*apiResponse, error) {
url := server + path
req, err := http.NewRequest(method, url, body)
if err != nil {
return nil, fmt.Errorf("creating request: %w", err)
}
if contentType != "" {
req.Header.Set("Content-Type", contentType)
}
resp, err := http.DefaultClient.Do(req)
if err != nil {
return nil, fmt.Errorf("sending request: %w", err)
}
defer resp.Body.Close()
raw, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("reading response: %w", err)
}
if resp.StatusCode >= 400 {
return nil, fmt.Errorf("HTTP %d: %s", resp.StatusCode, string(raw))
}
var apiResp apiResponse
if err := json.Unmarshal(raw, &apiResp); err != nil {
return nil, fmt.Errorf("parsing response: %w\nbody: %s", err, string(raw))
}
if !apiResp.Success {
return &apiResp, fmt.Errorf("API error: %s", apiResp.Error)
}
return &apiResp, nil
}
func computeSHA256(path string) (string, error) {
f, err := os.Open(path)
if err != nil {
return "", err
}
defer f.Close()
h := sha256.New()
if _, err := io.Copy(h, f); err != nil {
return "", err
}
return hex.EncodeToString(h.Sum(nil)), nil
}
// ── Commands ─────────────────────────────────────────────────────────
func cmdMkdir(server string, args []string) {
if len(args) == 0 {
fmt.Fprintln(os.Stderr, "Usage: client -server <addr> mkdir <name> [-parent <id>]")
os.Exit(1)
}
name := args[0]
var parentID *int64
for i := 1; i+1 < len(args); i += 2 {
if args[i] == "-parent" {
v, err := strconv.ParseInt(args[i+1], 10, 64)
if err != nil {
fmt.Fprintf(os.Stderr, "invalid parent id: %s\n", args[i+1])
os.Exit(1)
}
parentID = &v
}
}
payload := map[string]interface{}{"name": name}
if parentID != nil {
payload["parent_id"] = *parentID
}
body, _ := json.Marshal(payload)
resp, err := doRequest(server, http.MethodPost, "/api/v1/files/folders", bytes.NewReader(body), "application/json")
if err != nil {
fmt.Fprintln(os.Stderr, err)
os.Exit(1)
}
var folder folderResponse
if err := json.Unmarshal(resp.Data, &folder); err != nil {
fmt.Fprintf(os.Stderr, "parsing folder response: %v\n", err)
os.Exit(1)
}
fmt.Printf("Folder created: id=%d name=%q path=%q\n", folder.ID, folder.Name, folder.Path)
}
func cmdListFolders(server string, args []string) {
path := "/api/v1/files/folders?"
for i := 0; i+1 < len(args); i += 2 {
if args[i] == "-parent" {
path += "parent_id=" + args[i+1] + "&"
}
}
resp, err := doRequest(server, http.MethodGet, path, nil, "")
if err != nil {
fmt.Fprintln(os.Stderr, err)
os.Exit(1)
}
var folders []folderResponse
if err := json.Unmarshal(resp.Data, &folders); err != nil {
fmt.Fprintf(os.Stderr, "parsing folders response: %v\n", err)
os.Exit(1)
}
if len(folders) == 0 {
fmt.Println("No folders found.")
return
}
fmt.Printf("%-8s %-30s %-10s %s\n", "ID", "Name", "ParentID", "Path")
fmt.Println(strings.Repeat("-", 80))
for _, f := range folders {
pid := "<root>"
if f.ParentID != nil {
pid = strconv.FormatInt(*f.ParentID, 10)
}
fmt.Printf("%-8d %-30s %-10s %s\n", f.ID, f.Name, pid, f.Path)
}
}
func cmdDeleteFolder(server string, args []string) {
if len(args) == 0 {
fmt.Fprintln(os.Stderr, "Usage: client -server <addr> delete-folder <id>")
os.Exit(1)
}
id := args[0]
_, err := doRequest(server, http.MethodDelete, "/api/v1/files/folders/"+id, nil, "")
if err != nil {
fmt.Fprintln(os.Stderr, err)
os.Exit(1)
}
fmt.Printf("Folder %s deleted.\n", id)
}
func cmdUpload(server string, args []string) {
if len(args) == 0 {
fmt.Fprintln(os.Stderr, "Usage: client -server <addr> upload <file> [-folder <id>] [-chunk-size <bytes>]")
os.Exit(1)
}
filePath := args[0]
var folderID *int64
chunkSize := int64(8 * 1024 * 1024) // 8 MB default
for i := 1; i+1 < len(args); i += 2 {
switch args[i] {
case "-folder":
v, err := strconv.ParseInt(args[i+1], 10, 64)
if err != nil {
fmt.Fprintf(os.Stderr, "invalid folder id: %s\n", args[i+1])
os.Exit(1)
}
folderID = &v
case "-chunk-size":
v, err := strconv.ParseInt(args[i+1], 10, 64)
if err != nil || v <= 0 {
fmt.Fprintf(os.Stderr, "invalid chunk size: %s\n", args[i+1])
os.Exit(1)
}
chunkSize = v
}
}
start := time.Now()
// Open file and get size
f, err := os.Open(filePath)
if err != nil {
fmt.Fprintf(os.Stderr, "cannot open file: %v\n", err)
os.Exit(1)
}
defer f.Close()
fi, err := f.Stat()
if err != nil {
fmt.Fprintf(os.Stderr, "cannot stat file: %v\n", err)
os.Exit(1)
}
fileSize := fi.Size()
fileName := filepath.Base(filePath)
// Compute SHA256
fmt.Printf("Computing SHA256 for %s (%s)...\n", fileName, formatSize(fileSize))
sha, err := computeSHA256(filePath)
if err != nil {
fmt.Fprintf(os.Stderr, "sha256 error: %v\n", err)
os.Exit(1)
}
fmt.Printf("SHA256: %s\n", sha)
// Init upload
payload := map[string]interface{}{
"file_name": fileName,
"file_size": fileSize,
"sha256": sha,
"chunk_size": chunkSize,
}
if folderID != nil {
payload["folder_id"] = *folderID
}
body, _ := json.Marshal(payload)
resp, err := doRequest(server, http.MethodPost, "/api/v1/files/uploads", bytes.NewReader(body), "application/json")
if err != nil {
fmt.Fprintln(os.Stderr, err)
os.Exit(1)
}
// Check for dedup (秒传) — response is a file, not a session
var sess sessionResponse
if err := json.Unmarshal(resp.Data, &sess); err == nil && sess.Status != "" && sess.TotalChunks > 0 {
// It's a session, proceed with chunk uploads
} else {
// Try to parse as file (dedup hit)
var fr fileResponse
if err := json.Unmarshal(resp.Data, &fr); err == nil && fr.Name != "" {
elapsed := time.Since(start)
fmt.Printf("⚡ 秒传 (instant upload)! File already exists.\n")
fmt.Printf(" ID: %d\n", fr.ID)
fmt.Printf(" Name: %s\n", fr.Name)
fmt.Printf(" Size: %s\n", formatSize(fr.Size))
fmt.Printf(" SHA256: %s\n", fr.SHA256)
fmt.Printf(" Elapsed: %s\n", elapsed.Truncate(time.Millisecond))
return
}
// If neither, just try the session path anyway
if err := json.Unmarshal(resp.Data, &sess); err != nil {
fmt.Fprintf(os.Stderr, "unexpected response: %s\n", string(resp.Data))
os.Exit(1)
}
}
sessionID := sess.ID
totalChunks := sess.TotalChunks
fmt.Printf("Upload session created: id=%d total_chunks=%d chunk_size=%s\n",
sessionID, totalChunks, formatSize(chunkSize))
// Upload chunks (4 concurrent workers)
const uploadWorkers = 4
type chunkResult struct {
index int
err error
}
work := make(chan int, totalChunks)
results := make(chan chunkResult, totalChunks)
for i := 0; i < totalChunks; i++ {
work <- i
}
close(work)
var uploaded int64 // atomic counter for progress
var wg sync.WaitGroup
for w := 0; w < uploadWorkers; w++ {
wg.Add(1)
go func() {
defer wg.Done()
for idx := range work {
offset := int64(idx) * chunkSize
remaining := fileSize - offset
thisChunkSize := chunkSize
if remaining < chunkSize {
thisChunkSize = remaining
}
sectionReader := io.NewSectionReader(f, offset, thisChunkSize)
var buf bytes.Buffer
writer := multipart.NewWriter(&buf)
part, err := writer.CreateFormFile("chunk", fileName)
if err != nil {
results <- chunkResult{index: idx, err: err}
continue
}
if _, err := io.Copy(part, sectionReader); err != nil {
results <- chunkResult{index: idx, err: err}
continue
}
writer.Close()
chunkURL := fmt.Sprintf("%s/api/v1/files/uploads/%d/chunks/%d", server, sessionID, idx)
chunkReq, err := http.NewRequest(http.MethodPut, chunkURL, &buf)
if err != nil {
results <- chunkResult{index: idx, err: err}
continue
}
chunkReq.Header.Set("Content-Type", writer.FormDataContentType())
chunkResp, err := http.DefaultClient.Do(chunkReq)
if err != nil {
results <- chunkResult{index: idx, err: err}
continue
}
raw, _ := io.ReadAll(chunkResp.Body)
chunkResp.Body.Close()
if chunkResp.StatusCode >= 400 {
results <- chunkResult{index: idx, err: fmt.Errorf("HTTP %d: %s", chunkResp.StatusCode, string(raw))}
continue
}
var chunkAPIResp apiResponse
if err := json.Unmarshal(raw, &chunkAPIResp); err == nil && !chunkAPIResp.Success {
results <- chunkResult{index: idx, err: fmt.Errorf("%s", chunkAPIResp.Error)}
continue
}
results <- chunkResult{index: idx, err: nil}
}
}()
}
go func() {
wg.Wait()
close(results)
}()
var firstErr error
for res := range results {
if res.err != nil {
if firstErr == nil {
firstErr = res.err
}
fmt.Fprintf(os.Stderr, "chunk %d failed: %v\n", res.index, res.err)
}
done := atomic.AddInt64(&uploaded, 1)
pct := float64(done) / float64(totalChunks) * 100
doneBytes := done * chunkSize
if doneBytes > fileSize {
doneBytes = fileSize
}
fmt.Printf("\rUploading: %.1f%% (%s / %s)", pct, formatSize(doneBytes), formatSize(fileSize))
}
fmt.Println()
if firstErr != nil {
fmt.Fprintf(os.Stderr, "upload failed: %v\n", firstErr)
os.Exit(1)
}
// Complete upload
completeURL := fmt.Sprintf("/api/v1/files/uploads/%d/complete", sessionID)
resp, err = doRequest(server, http.MethodPost, completeURL, nil, "")
if err != nil {
fmt.Fprintln(os.Stderr, err)
os.Exit(1)
}
var result fileResponse
if err := json.Unmarshal(resp.Data, &result); err != nil {
fmt.Fprintf(os.Stderr, "parsing complete response: %v\n", err)
os.Exit(1)
}
elapsed := time.Since(start)
speed := float64(fileSize) / elapsed.Seconds()
fmt.Println("✓ Upload complete!")
fmt.Printf(" ID: %d\n", result.ID)
fmt.Printf(" Name: %s\n", result.Name)
fmt.Printf(" Size: %s\n", formatSize(result.Size))
fmt.Printf(" SHA256: %s\n", result.SHA256)
fmt.Printf(" Elapsed: %s\n", elapsed.Truncate(time.Millisecond))
fmt.Printf(" Speed: %s/s\n", formatSize(int64(speed)))
}
func cmdDownload(server string, args []string) {
if len(args) == 0 {
fmt.Fprintln(os.Stderr, "Usage: client -server <addr> download <file_id> [-o <output>]")
os.Exit(1)
}
fileID := args[0]
var output string
for i := 1; i+1 < len(args); i += 2 {
if args[i] == "-o" {
output = args[i+1]
}
}
start := time.Now()
url := server + "/api/v1/files/" + fileID + "/download"
resp, err := http.Get(url)
if err != nil {
fmt.Fprintf(os.Stderr, "download error: %v\n", err)
os.Exit(1)
}
defer resp.Body.Close()
if resp.StatusCode >= 400 {
body, _ := io.ReadAll(resp.Body)
fmt.Fprintf(os.Stderr, "download failed (HTTP %d): %s\n", resp.StatusCode, string(body))
os.Exit(1)
}
// Determine output filename
if output == "" {
// Try Content-Disposition header
if cd := resp.Header.Get("Content-Disposition"); cd != "" {
if idx := strings.Index(cd, "filename="); idx != -1 {
output = strings.Trim(cd[idx+9:], `"`)
}
}
if output == "" {
output = "download_" + fileID
}
}
outFile, err := os.Create(output)
if err != nil {
fmt.Fprintf(os.Stderr, "cannot create output file: %v\n", err)
os.Exit(1)
}
defer outFile.Close()
written, err := io.Copy(outFile, resp.Body)
if err != nil {
fmt.Fprintf(os.Stderr, "download write error: %v\n", err)
os.Exit(1)
}
elapsed := time.Since(start)
speed := float64(written) / elapsed.Seconds()
fmt.Println("✓ Download complete!")
fmt.Printf(" File: %s\n", output)
fmt.Printf(" Size: %s\n", formatSize(written))
fmt.Printf(" Elapsed: %s\n", elapsed.Truncate(time.Millisecond))
fmt.Printf(" Speed: %s/s\n", formatSize(int64(speed)))
}
func cmdListFiles(server string, args []string) {
var folderID, page, pageSize string
for i := 0; i+1 < len(args); i += 2 {
switch args[i] {
case "-folder":
folderID = args[i+1]
case "-page":
page = args[i+1]
case "-page-size":
pageSize = args[i+1]
}
}
path := "/api/v1/files?"
if folderID != "" {
path += "folder_id=" + folderID + "&"
}
if page != "" {
path += "page=" + page + "&"
}
if pageSize != "" {
path += "page_size=" + pageSize + "&"
}
resp, err := doRequest(server, http.MethodGet, path, nil, "")
if err != nil {
fmt.Fprintln(os.Stderr, err)
os.Exit(1)
}
var list listFilesResponse
if err := json.Unmarshal(resp.Data, &list); err != nil {
fmt.Fprintf(os.Stderr, "parsing files response: %v\n", err)
os.Exit(1)
}
if len(list.Files) == 0 {
fmt.Println("No files found.")
return
}
fmt.Printf("%-8s %-30s %-12s %-10s %s\n", "ID", "Name", "Size", "MIME", "Created")
fmt.Println(strings.Repeat("-", 90))
for _, f := range list.Files {
name := f.Name
if len(name) > 28 {
name = name[:25] + "..."
}
fmt.Printf("%-8d %-30s %-12s %-10s %s\n", f.ID, name, formatSize(f.Size), f.MimeType, f.CreatedAt)
}
fmt.Printf("\nTotal: %d Page: %d PageSize: %d\n", list.Total, list.Page, list.PageSize)
}
func cmdDeleteFile(server string, args []string) {
if len(args) == 0 {
fmt.Fprintln(os.Stderr, "Usage: client -server <addr> delete-file <file_id>")
os.Exit(1)
}
fileID := args[0]
_, err := doRequest(server, http.MethodDelete, "/api/v1/files/"+fileID, nil, "")
if err != nil {
fmt.Fprintln(os.Stderr, err)
os.Exit(1)
}
fmt.Printf("File %s deleted.\n", fileID)
}
// ── Main ─────────────────────────────────────────────────────────────
func usage() {
fmt.Fprintf(os.Stderr, `HPC File Storage Client
Usage: client -server <addr> <command> [args]
Commands:
mkdir <name> [-parent <id>]                           create a folder
ls-folders [-parent <id>]                             list folders
delete-folder <id>                                    delete a folder
upload <file> [-folder <id>] [-chunk-size <bytes>]    upload a file (chunked)
download <file_id> [-o <output>]                      download a file
ls-files [-folder <id>] [-page <n>] [-page-size <n>]  list files
delete-file <file_id>                                 delete a file
Global flags:
-server <addr>   server address (default http://localhost:8080)
`)
}
func main() {
server := "http://localhost:8080"
var command string
var cmdArgs []string
args := os.Args[1:]
i := 0
for i < len(args) {
if args[i] == "-server" && i+1 < len(args) {
server = args[i+1]
i += 2
continue
}
if command == "" {
command = args[i]
} else {
cmdArgs = append(cmdArgs, args[i])
}
i++
}
if command == "" {
usage()
os.Exit(1)
}
switch command {
case "mkdir":
cmdMkdir(server, cmdArgs)
case "ls-folders":
cmdListFolders(server, cmdArgs)
case "delete-folder":
cmdDeleteFolder(server, cmdArgs)
case "upload":
cmdUpload(server, cmdArgs)
case "download":
cmdDownload(server, cmdArgs)
case "ls-files":
cmdListFiles(server, cmdArgs)
case "delete-file":
cmdDeleteFile(server, cmdArgs)
default:
fmt.Fprintf(os.Stderr, "Unknown command: %s\n", command)
usage()
os.Exit(1)
}
}

cmd/server/file_test.go (new file, 773 lines)

@@ -0,0 +1,773 @@
package main
import (
"bytes"
"context"
"crypto/sha256"
"encoding/hex"
"encoding/json"
"fmt"
"io"
"mime/multipart"
"net/http"
"net/http/httptest"
"sort"
"strings"
"sync"
"testing"
"gcy_hpc_server/internal/config"
"gcy_hpc_server/internal/handler"
"gcy_hpc_server/internal/model"
"gcy_hpc_server/internal/service"
"gcy_hpc_server/internal/storage"
"gcy_hpc_server/internal/store"
"github.com/gin-gonic/gin"
"go.uber.org/zap"
"gorm.io/driver/sqlite"
"gorm.io/gorm"
"gorm.io/gorm/logger"
)
// ---------------------------------------------------------------------------
// In-memory ObjectStorage mock
// ---------------------------------------------------------------------------
type inMemoryStorage struct {
mu sync.RWMutex
objects map[string][]byte
bucket string
}
var _ storage.ObjectStorage = (*inMemoryStorage)(nil)
func (s *inMemoryStorage) PutObject(_ context.Context, _, key string, reader io.Reader, _ int64, _ storage.PutObjectOptions) (storage.UploadInfo, error) {
data, err := io.ReadAll(reader)
if err != nil {
return storage.UploadInfo{}, fmt.Errorf("read all: %w", err)
}
s.mu.Lock()
s.objects[key] = data
s.mu.Unlock()
h := sha256.Sum256(data)
return storage.UploadInfo{ETag: hex.EncodeToString(h[:]), Size: int64(len(data))}, nil
}
func (s *inMemoryStorage) GetObject(_ context.Context, _, key string, opts storage.GetOptions) (io.ReadCloser, storage.ObjectInfo, error) {
s.mu.RLock()
data, ok := s.objects[key]
s.mu.RUnlock()
if !ok {
return nil, storage.ObjectInfo{}, fmt.Errorf("object %s not found", key)
}
size := int64(len(data))
start := int64(0)
end := size - 1
if opts.Start != nil {
start = *opts.Start
}
if opts.End != nil {
end = *opts.End
}
if end >= size {
end = size - 1
}
section := io.NewSectionReader(bytes.NewReader(data), start, end-start+1)
info := storage.ObjectInfo{Key: key, Size: size}
return io.NopCloser(section), info, nil
}
func (s *inMemoryStorage) ComposeObject(_ context.Context, _, dst string, sources []string) (storage.UploadInfo, error) {
s.mu.Lock()
defer s.mu.Unlock()
var buf bytes.Buffer
for _, src := range sources {
data, ok := s.objects[src]
if !ok {
return storage.UploadInfo{}, fmt.Errorf("source object %s not found", src)
}
buf.Write(data)
}
combined := buf.Bytes()
s.objects[dst] = combined
h := sha256.Sum256(combined)
return storage.UploadInfo{ETag: hex.EncodeToString(h[:]), Size: int64(len(combined))}, nil
}
func (s *inMemoryStorage) AbortMultipartUpload(_ context.Context, _, _, _ string) error {
return nil
}
func (s *inMemoryStorage) RemoveIncompleteUpload(_ context.Context, _, _ string) error {
return nil
}
func (s *inMemoryStorage) RemoveObject(_ context.Context, _, key string, _ storage.RemoveObjectOptions) error {
s.mu.Lock()
delete(s.objects, key)
s.mu.Unlock()
return nil
}
func (s *inMemoryStorage) ListObjects(_ context.Context, _, prefix string, _ bool) ([]storage.ObjectInfo, error) {
s.mu.RLock()
defer s.mu.RUnlock()
var result []storage.ObjectInfo
for k, v := range s.objects {
if strings.HasPrefix(k, prefix) {
result = append(result, storage.ObjectInfo{Key: k, Size: int64(len(v))})
}
}
sort.Slice(result, func(i, j int) bool { return result[i].Key < result[j].Key })
return result, nil
}
func (s *inMemoryStorage) RemoveObjects(_ context.Context, _ string, keys []string, _ storage.RemoveObjectsOptions) error {
s.mu.Lock()
for _, k := range keys {
delete(s.objects, k)
}
s.mu.Unlock()
return nil
}
func (s *inMemoryStorage) BucketExists(_ context.Context, _ string) (bool, error) {
return true, nil
}
func (s *inMemoryStorage) MakeBucket(_ context.Context, _ string, _ storage.MakeBucketOptions) error {
return nil
}
func (s *inMemoryStorage) StatObject(_ context.Context, _, key string, _ storage.StatObjectOptions) (storage.ObjectInfo, error) {
s.mu.RLock()
data, ok := s.objects[key]
s.mu.RUnlock()
if !ok {
return storage.ObjectInfo{}, fmt.Errorf("object %s not found", key)
}
return storage.ObjectInfo{Key: key, Size: int64(len(data))}, nil
}
// ---------------------------------------------------------------------------
// Test helpers
// ---------------------------------------------------------------------------
func setupFileTestRouter(t *testing.T) (*gin.Engine, *gorm.DB, *inMemoryStorage) {
t.Helper()
gin.SetMode(gin.TestMode)
dbName := fmt.Sprintf("file:%s?mode=memory&cache=shared", t.Name())
db, err := gorm.Open(sqlite.Open(dbName), &gorm.Config{Logger: logger.Default.LogMode(logger.Silent)})
if err != nil {
t.Fatal(err)
}
sqlDB, err := db.DB()
if err != nil {
t.Fatal(err)
}
sqlDB.SetMaxOpenConns(1)
if err := db.AutoMigrate(&model.FileBlob{}, &model.File{}, &model.Folder{}, &model.UploadSession{}, &model.UploadChunk{}); err != nil {
t.Fatal(err)
}
memStore := &inMemoryStorage{objects: make(map[string][]byte)}
blobStore := store.NewBlobStore(db)
fileStore := store.NewFileStore(db)
folderStore := store.NewFolderStore(db)
uploadStore := store.NewUploadStore(db)
cfg := config.MinioConfig{
ChunkSize: 16 << 20,
MaxFileSize: 50 << 30,
MinChunkSize: 5 << 20,
SessionTTL: 48,
Bucket: "files",
}
uploadSvc := service.NewUploadService(memStore, blobStore, fileStore, uploadStore, cfg, db, zap.NewNop())
_ = service.NewDownloadService(memStore, blobStore, fileStore, "files", zap.NewNop())
folderSvc := service.NewFolderService(folderStore, fileStore, zap.NewNop())
fileSvc := service.NewFileService(memStore, blobStore, fileStore, "files", db, zap.NewNop())
uploadH := handler.NewUploadHandler(uploadSvc, zap.NewNop())
fileH := handler.NewFileHandler(fileSvc, zap.NewNop())
folderH := handler.NewFolderHandler(folderSvc, zap.NewNop())
r := gin.New()
r.Use(gin.Recovery())
v1 := r.Group("/api/v1")
files := v1.Group("/files")
uploads := files.Group("/uploads")
uploads.POST("", uploadH.InitUpload)
uploads.GET("/:id", uploadH.GetUploadStatus)
uploads.PUT("/:id/chunks/:index", uploadH.UploadChunk)
uploads.POST("/:id/complete", uploadH.CompleteUpload)
uploads.DELETE("/:id", uploadH.CancelUpload)
files.GET("", fileH.ListFiles)
files.GET("/:id", fileH.GetFile)
files.GET("/:id/download", fileH.DownloadFile)
files.DELETE("/:id", fileH.DeleteFile)
folders := files.Group("/folders")
folders.POST("", folderH.CreateFolder)
folders.GET("", folderH.ListFolders)
folders.GET("/:id", folderH.GetFolder)
folders.DELETE("/:id", folderH.DeleteFolder)
return r, db, memStore
}
// apiResponse mirrors server.APIResponse for decoding.
type apiResponse struct {
Success bool `json:"success"`
Data json.RawMessage `json:"data,omitempty"`
Error string `json:"error,omitempty"`
}
func decodeResponse(t *testing.T, w *httptest.ResponseRecorder) apiResponse {
t.Helper()
var resp apiResponse
if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
t.Fatalf("failed to decode response: %v, body: %s", err, w.Body.String())
}
return resp
}
func createChunkRequest(t *testing.T, url string, data []byte) *http.Request {
t.Helper()
var buf bytes.Buffer
writer := multipart.NewWriter(&buf)
part, err := writer.CreateFormFile("chunk", "chunk.bin")
if err != nil {
t.Fatal(err)
}
if _, err := part.Write(data); err != nil {
t.Fatal(err)
}
if err := writer.Close(); err != nil {
t.Fatal(err)
}
req, err := http.NewRequest("PUT", url, &buf)
if err != nil {
t.Fatal(err)
}
req.Header.Set("Content-Type", writer.FormDataContentType())
return req
}
// helperUploadFile performs a full upload lifecycle: init → upload chunks → complete.
// Returns the file ID from the completed upload response.
func helperUploadFile(t *testing.T, router *gin.Engine, fileName string, fileData []byte, sha256Hash string, folderID *int64, chunkSize int64) int64 {
t.Helper()
fileSize := int64(len(fileData))
initBody := model.InitUploadRequest{
FileName: fileName,
FileSize: fileSize,
SHA256: sha256Hash,
FolderID: folderID,
ChunkSize: &chunkSize,
}
initJSON, _ := json.Marshal(initBody)
w := httptest.NewRecorder()
req, _ := http.NewRequest("POST", "/api/v1/files/uploads", bytes.NewReader(initJSON))
req.Header.Set("Content-Type", "application/json")
router.ServeHTTP(w, req)
resp := decodeResponse(t, w)
if !resp.Success {
t.Fatalf("init upload failed: %s", resp.Error)
}
if w.Code == http.StatusOK {
var fileResp model.FileResponse
if err := json.Unmarshal(resp.Data, &fileResp); err != nil {
t.Fatalf("decode dedup file response: %v", err)
}
return fileResp.ID
}
if w.Code != http.StatusCreated {
t.Fatalf("init upload: expected 201, got %d: %s", w.Code, w.Body.String())
}
var session model.UploadSessionResponse
if err := json.Unmarshal(resp.Data, &session); err != nil {
t.Fatalf("failed to decode session: %v", err)
}
totalChunks := session.TotalChunks
for i := 0; i < totalChunks; i++ {
start := int64(i) * chunkSize
end := start + chunkSize
if end > fileSize {
end = fileSize
}
chunkData := fileData[start:end]
url := fmt.Sprintf("/api/v1/files/uploads/%d/chunks/%d", session.ID, i)
cw := httptest.NewRecorder()
creq := createChunkRequest(t, url, chunkData)
router.ServeHTTP(cw, creq)
if cw.Code != http.StatusOK {
t.Fatalf("upload chunk %d: expected 200, got %d: %s", i, cw.Code, cw.Body.String())
}
}
w = httptest.NewRecorder()
req, _ = http.NewRequest("POST", fmt.Sprintf("/api/v1/files/uploads/%d/complete", session.ID), nil)
router.ServeHTTP(w, req)
if w.Code != http.StatusCreated {
t.Fatalf("complete upload: expected 201, got %d: %s", w.Code, w.Body.String())
}
resp = decodeResponse(t, w)
var fileResp model.FileResponse
if err := json.Unmarshal(resp.Data, &fileResp); err != nil {
t.Fatalf("failed to decode file response: %v", err)
}
return fileResp.ID
}
// ---------------------------------------------------------------------------
// Tests
// ---------------------------------------------------------------------------
func TestFileFullLifecycle(t *testing.T) {
router, _, _ := setupFileTestRouter(t)
// Create test file data
fileData := []byte("Hello, World! This is a test file for the full lifecycle integration test.")
fileSize := int64(len(fileData))
h := sha256.Sum256(fileData)
sha256Hash := hex.EncodeToString(h[:])
chunkSize := int64(5 << 20) // 5MB min chunk size
// 1. Init upload
initBody := model.InitUploadRequest{
FileName: "test.txt",
FileSize: fileSize,
SHA256: sha256Hash,
ChunkSize: &chunkSize,
}
initJSON, _ := json.Marshal(initBody)
w := httptest.NewRecorder()
req, _ := http.NewRequest("POST", "/api/v1/files/uploads", bytes.NewReader(initJSON))
req.Header.Set("Content-Type", "application/json")
router.ServeHTTP(w, req)
if w.Code != http.StatusCreated {
t.Fatalf("init upload: expected 201, got %d: %s", w.Code, w.Body.String())
}
resp := decodeResponse(t, w)
var session model.UploadSessionResponse
if err := json.Unmarshal(resp.Data, &session); err != nil {
t.Fatalf("failed to decode session: %v", err)
}
if session.Status != "pending" {
t.Fatalf("expected status pending, got %s", session.Status)
}
if session.TotalChunks != 1 {
t.Fatalf("expected 1 chunk, got %d", session.TotalChunks)
}
// 2. Upload chunk 0
url := fmt.Sprintf("/api/v1/files/uploads/%d/chunks/0", session.ID)
w = httptest.NewRecorder()
req = createChunkRequest(t, url, fileData)
router.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("upload chunk: expected 200, got %d: %s", w.Code, w.Body.String())
}
// 3. Complete upload
w = httptest.NewRecorder()
req, _ = http.NewRequest("POST", fmt.Sprintf("/api/v1/files/uploads/%d/complete", session.ID), nil)
router.ServeHTTP(w, req)
if w.Code != http.StatusCreated {
t.Fatalf("complete upload: expected 201, got %d: %s", w.Code, w.Body.String())
}
resp = decodeResponse(t, w)
var fileResp model.FileResponse
if err := json.Unmarshal(resp.Data, &fileResp); err != nil {
t.Fatalf("failed to decode file response: %v", err)
}
if fileResp.Name != "test.txt" {
t.Fatalf("expected name test.txt, got %s", fileResp.Name)
}
// 4. Download file
w = httptest.NewRecorder()
req, _ = http.NewRequest("GET", fmt.Sprintf("/api/v1/files/%d/download", fileResp.ID), nil)
router.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("download: expected 200, got %d: %s", w.Code, w.Body.String())
}
if !bytes.Equal(w.Body.Bytes(), fileData) {
t.Fatalf("downloaded data mismatch: got %q, want %q", w.Body.String(), string(fileData))
}
// 5. Delete file
w = httptest.NewRecorder()
req, _ = http.NewRequest("DELETE", fmt.Sprintf("/api/v1/files/%d", fileResp.ID), nil)
router.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("delete: expected 200, got %d: %s", w.Code, w.Body.String())
}
// 6. Verify file is gone (download should fail)
w = httptest.NewRecorder()
req, _ = http.NewRequest("GET", fmt.Sprintf("/api/v1/files/%d/download", fileResp.ID), nil)
router.ServeHTTP(w, req)
if w.Code == http.StatusOK {
t.Fatal("expected download to fail after delete, got 200")
}
}
func TestFileDedup(t *testing.T) {
router, db, _ := setupFileTestRouter(t)
fileData := []byte("Duplicate content for dedup test.")
h := sha256.Sum256(fileData)
sha256Hash := hex.EncodeToString(h[:])
chunkSize := int64(5 << 20)
// Upload file A (full lifecycle)
fileAID := helperUploadFile(t, router, "fileA.txt", fileData, sha256Hash, nil, chunkSize)
if fileAID == 0 {
t.Fatal("file A ID should not be 0")
}
// Upload file B with same SHA256 → should be instant dedup
fileBID := helperUploadFile(t, router, "fileB.txt", fileData, sha256Hash, nil, chunkSize)
if fileBID == 0 {
t.Fatal("file B ID should not be 0")
}
if fileBID == fileAID {
t.Fatal("file B should have a different ID than file A")
}
// Verify blob ref_count = 2
var blob model.FileBlob
if err := db.Where("sha256 = ?", sha256Hash).First(&blob).Error; err != nil {
t.Fatalf("blob not found: %v", err)
}
if blob.RefCount != 2 {
t.Fatalf("expected ref_count 2, got %d", blob.RefCount)
}
// Both files should be downloadable
for _, id := range []int64{fileAID, fileBID} {
w := httptest.NewRecorder()
req, _ := http.NewRequest("GET", fmt.Sprintf("/api/v1/files/%d/download", id), nil)
router.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("download file %d: expected 200, got %d", id, w.Code)
}
if !bytes.Equal(w.Body.Bytes(), fileData) {
t.Fatalf("downloaded data mismatch for file %d", id)
}
}
// Delete file A — file B should still be downloadable
w := httptest.NewRecorder()
req, _ := http.NewRequest("DELETE", fmt.Sprintf("/api/v1/files/%d", fileAID), nil)
router.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("delete file A: expected 200, got %d: %s", w.Code, w.Body.String())
}
// Verify blob still exists with ref_count = 1
if err := db.Where("sha256 = ?", sha256Hash).First(&blob).Error; err != nil {
t.Fatalf("blob should still exist: %v", err)
}
if blob.RefCount != 1 {
t.Fatalf("expected ref_count 1 after deleting file A, got %d", blob.RefCount)
}
// File B still downloadable
w = httptest.NewRecorder()
req, _ = http.NewRequest("GET", fmt.Sprintf("/api/v1/files/%d/download", fileBID), nil)
router.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("file B should still be downloadable, got %d", w.Code)
}
}
func TestFileResumeUpload(t *testing.T) {
router, _, _ := setupFileTestRouter(t)
// Create data spanning 2 chunks, using the 5MB minimum chunk size
chunkSize := int64(5 << 20) // 5MB
data1 := bytes.Repeat([]byte("A"), int(chunkSize))
data2 := bytes.Repeat([]byte("B"), int(chunkSize))
fileData := append(data1, data2...)
fileSize := int64(len(fileData))
h := sha256.Sum256(fileData)
sha256Hash := hex.EncodeToString(h[:])
// 1. Init upload
initBody := model.InitUploadRequest{
FileName: "resume.bin",
FileSize: fileSize,
SHA256: sha256Hash,
ChunkSize: &chunkSize,
}
initJSON, _ := json.Marshal(initBody)
w := httptest.NewRecorder()
req, _ := http.NewRequest("POST", "/api/v1/files/uploads", bytes.NewReader(initJSON))
req.Header.Set("Content-Type", "application/json")
router.ServeHTTP(w, req)
if w.Code != http.StatusCreated {
t.Fatalf("init: expected 201, got %d: %s", w.Code, w.Body.String())
}
resp := decodeResponse(t, w)
var session model.UploadSessionResponse
if err := json.Unmarshal(resp.Data, &session); err != nil {
t.Fatalf("decode session: %v", err)
}
if session.TotalChunks != 2 {
t.Fatalf("expected 2 chunks, got %d", session.TotalChunks)
}
// 2. Upload chunk 0 only
w = httptest.NewRecorder()
req = createChunkRequest(t, fmt.Sprintf("/api/v1/files/uploads/%d/chunks/0", session.ID), data1)
router.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("upload chunk 0: expected 200, got %d: %s", w.Code, w.Body.String())
}
// 3. Get status — should show chunk 0 uploaded
w = httptest.NewRecorder()
req, _ = http.NewRequest("GET", fmt.Sprintf("/api/v1/files/uploads/%d", session.ID), nil)
router.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("get status: expected 200, got %d: %s", w.Code, w.Body.String())
}
resp = decodeResponse(t, w)
var status model.UploadSessionResponse
if err := json.Unmarshal(resp.Data, &status); err != nil {
t.Fatalf("decode status: %v", err)
}
if len(status.UploadedChunks) != 1 || status.UploadedChunks[0] != 0 {
t.Fatalf("expected uploaded_chunks=[0], got %v", status.UploadedChunks)
}
// 4. Upload chunk 1 (resume)
w = httptest.NewRecorder()
req = createChunkRequest(t, fmt.Sprintf("/api/v1/files/uploads/%d/chunks/1", session.ID), data2)
router.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("upload chunk 1: expected 200, got %d: %s", w.Code, w.Body.String())
}
// 5. Complete
w = httptest.NewRecorder()
req, _ = http.NewRequest("POST", fmt.Sprintf("/api/v1/files/uploads/%d/complete", session.ID), nil)
router.ServeHTTP(w, req)
if w.Code != http.StatusCreated {
t.Fatalf("complete: expected 201, got %d: %s", w.Code, w.Body.String())
}
resp = decodeResponse(t, w)
var fileResp model.FileResponse
if err := json.Unmarshal(resp.Data, &fileResp); err != nil {
t.Fatalf("decode file response: %v", err)
}
// 6. Download and verify
w = httptest.NewRecorder()
req, _ = http.NewRequest("GET", fmt.Sprintf("/api/v1/files/%d/download", fileResp.ID), nil)
router.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("download: expected 200, got %d: %s", w.Code, w.Body.String())
}
if !bytes.Equal(w.Body.Bytes(), fileData) {
t.Fatal("downloaded data does not match original")
}
}
func TestFileFolderOperations(t *testing.T) {
router, _, _ := setupFileTestRouter(t)
// 1. Create folder
folderBody := model.CreateFolderRequest{Name: "test-folder"}
folderJSON, _ := json.Marshal(folderBody)
w := httptest.NewRecorder()
req, _ := http.NewRequest("POST", "/api/v1/files/folders", bytes.NewReader(folderJSON))
req.Header.Set("Content-Type", "application/json")
router.ServeHTTP(w, req)
if w.Code != http.StatusCreated {
t.Fatalf("create folder: expected 201, got %d: %s", w.Code, w.Body.String())
}
resp := decodeResponse(t, w)
var folderResp model.FolderResponse
if err := json.Unmarshal(resp.Data, &folderResp); err != nil {
t.Fatalf("decode folder: %v", err)
}
if folderResp.Name != "test-folder" {
t.Fatalf("expected name test-folder, got %s", folderResp.Name)
}
if folderResp.Path != "/test-folder/" {
t.Fatalf("expected path /test-folder/, got %s", folderResp.Path)
}
// 2. Upload a file into the folder
fileData := []byte("File inside folder")
h := sha256.Sum256(fileData)
sha256Hash := hex.EncodeToString(h[:])
chunkSize := int64(5 << 20)
fileID := helperUploadFile(t, router, "folder_file.txt", fileData, sha256Hash, &folderResp.ID, chunkSize)
if fileID == 0 {
t.Fatal("file ID should not be 0")
}
// 3. List files in folder
w = httptest.NewRecorder()
req, _ = http.NewRequest("GET", fmt.Sprintf("/api/v1/files?folder_id=%d", folderResp.ID), nil)
router.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("list files: expected 200, got %d: %s", w.Code, w.Body.String())
}
resp = decodeResponse(t, w)
var listResp model.ListFilesResponse
if err := json.Unmarshal(resp.Data, &listResp); err != nil {
t.Fatalf("decode list: %v", err)
}
if listResp.Total != 1 {
t.Fatalf("expected 1 file in folder, got %d", listResp.Total)
}
if listResp.Files[0].Name != "folder_file.txt" {
t.Fatalf("expected file name folder_file.txt, got %s", listResp.Files[0].Name)
}
// 4. List folders (root)
w = httptest.NewRecorder()
req, _ = http.NewRequest("GET", "/api/v1/files/folders", nil)
router.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("list folders: expected 200, got %d: %s", w.Code, w.Body.String())
}
// 5. Try delete folder (should fail — not empty)
w = httptest.NewRecorder()
req, _ = http.NewRequest("DELETE", fmt.Sprintf("/api/v1/files/folders/%d", folderResp.ID), nil)
router.ServeHTTP(w, req)
if w.Code != http.StatusBadRequest {
t.Fatalf("delete non-empty folder: expected 400, got %d: %s", w.Code, w.Body.String())
}
// 6. Delete the file first
w = httptest.NewRecorder()
req, _ = http.NewRequest("DELETE", fmt.Sprintf("/api/v1/files/%d", fileID), nil)
router.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("delete file: expected 200, got %d: %s", w.Code, w.Body.String())
}
// 7. Now delete empty folder
w = httptest.NewRecorder()
req, _ = http.NewRequest("DELETE", fmt.Sprintf("/api/v1/files/folders/%d", folderResp.ID), nil)
router.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("delete empty folder: expected 200, got %d: %s", w.Code, w.Body.String())
}
}
func TestFileRangeDownload(t *testing.T) {
router, _, _ := setupFileTestRouter(t)
fileData := []byte("0123456789ABCDEF") // 16 bytes
h := sha256.Sum256(fileData)
sha256Hash := hex.EncodeToString(h[:])
chunkSize := int64(5 << 20)
fileID := helperUploadFile(t, router, "range_test.bin", fileData, sha256Hash, nil, chunkSize)
// Download range bytes=4-9 → "456789"
w := httptest.NewRecorder()
req, _ := http.NewRequest("GET", fmt.Sprintf("/api/v1/files/%d/download", fileID), nil)
req.Header.Set("Range", "bytes=4-9")
router.ServeHTTP(w, req)
if w.Code != http.StatusPartialContent {
t.Fatalf("range download: expected 206, got %d: %s", w.Code, w.Body.String())
}
contentRange := w.Header().Get("Content-Range")
expectedRange := fmt.Sprintf("bytes 4-9/%d", len(fileData))
if contentRange != expectedRange {
t.Fatalf("content-range: expected %q, got %q", expectedRange, contentRange)
}
if !bytes.Equal(w.Body.Bytes(), []byte("456789")) {
t.Fatalf("range content: expected '456789', got %q", w.Body.String())
}
// Full download still works
w = httptest.NewRecorder()
req, _ = http.NewRequest("GET", fmt.Sprintf("/api/v1/files/%d/download", fileID), nil)
router.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("full download: expected 200, got %d", w.Code)
}
if !bytes.Equal(w.Body.Bytes(), fileData) {
t.Fatal("full download data mismatch")
}
}
func TestFileDeleteOneRefOtherStillDownloadable(t *testing.T) {
router, _, _ := setupFileTestRouter(t)
fileData := []byte("Shared blob content for multi-ref test")
h := sha256.Sum256(fileData)
sha256Hash := hex.EncodeToString(h[:])
chunkSize := int64(5 << 20)
// Upload file A
fileAID := helperUploadFile(t, router, "refA.txt", fileData, sha256Hash, nil, chunkSize)
// Upload file B with same content → dedup instant
fileBID := helperUploadFile(t, router, "refB.txt", fileData, sha256Hash, nil, chunkSize)
// Delete file A
w := httptest.NewRecorder()
req, _ := http.NewRequest("DELETE", fmt.Sprintf("/api/v1/files/%d", fileAID), nil)
router.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("delete file A: expected 200, got %d: %s", w.Code, w.Body.String())
}
// File A is soft-deleted; whether GET still finds it depends on the
// soft-delete handling, so we don't assert on this response.
w = httptest.NewRecorder()
req, _ = http.NewRequest("GET", fmt.Sprintf("/api/v1/files/%d", fileAID), nil)
router.ServeHTTP(w, req)
// The important check: file B should still be downloadable
w = httptest.NewRecorder()
req, _ = http.NewRequest("GET", fmt.Sprintf("/api/v1/files/%d/download", fileBID), nil)
router.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("file B should still be downloadable after deleting file A, got %d: %s", w.Code, w.Body.String())
}
if !bytes.Equal(w.Body.Bytes(), fileData) {
t.Fatal("file B download data mismatch")
}
// Delete file B — now blob should be fully removed
w = httptest.NewRecorder()
req, _ = http.NewRequest("DELETE", fmt.Sprintf("/api/v1/files/%d", fileBID), nil)
router.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("delete file B: expected 200, got %d: %s", w.Code, w.Body.String())
}
// File B should now be gone
w = httptest.NewRecorder()
req, _ = http.NewRequest("GET", fmt.Sprintf("/api/v1/files/%d/download", fileBID), nil)
router.ServeHTTP(w, req)
if w.Code == http.StatusOK {
t.Fatal("file B should not be downloadable after deletion")
}
}


@@ -0,0 +1,257 @@
package main
import (
"bytes"
"encoding/json"
"fmt"
"io"
"net/http"
"testing"
"gcy_hpc_server/internal/testutil/testenv"
)
// appListData mirrors the list endpoint response data structure.
type appListData struct {
Applications []json.RawMessage `json:"applications"`
Total int64 `json:"total"`
Page int `json:"page"`
PageSize int `json:"page_size"`
}
// appCreatedData mirrors the create endpoint response data structure.
type appCreatedData struct {
ID int64 `json:"id"`
}
// appMessageData mirrors the update/delete endpoint response data structure.
type appMessageData struct {
Message string `json:"message"`
}
// appData mirrors the application model returned by GET.
type appData struct {
ID int64 `json:"id"`
Name string `json:"name"`
ScriptTemplate string `json:"script_template"`
Parameters json.RawMessage `json:"parameters,omitempty"`
}
// appDoRequest is a small wrapper that marshals body and calls env.DoRequest.
func appDoRequest(env *testenv.TestEnv, method, path string, body interface{}) *http.Response {
var r io.Reader
if body != nil {
b, err := json.Marshal(body)
if err != nil {
panic(fmt.Sprintf("appDoRequest marshal: %v", err))
}
r = bytes.NewReader(b)
}
return env.DoRequest(method, path, r)
}
// appDecodeAll decodes the response and also reads the HTTP status.
func appDecodeAll(env *testenv.TestEnv, resp *http.Response) (statusCode int, success bool, data json.RawMessage, err error) {
statusCode = resp.StatusCode
success, data, err = env.DecodeResponse(resp)
return
}
// appSeedApp creates an app via the service (bypasses HTTP) and returns its ID.
func appSeedApp(env *testenv.TestEnv, name string) int64 {
id, err := env.CreateApp(name, "#!/bin/bash\necho hello", json.RawMessage(`[]`))
if err != nil {
panic(fmt.Sprintf("appSeedApp: %v", err))
}
return id
}
// TestIntegration_App_List verifies GET /api/v1/applications returns an empty list initially.
func TestIntegration_App_List(t *testing.T) {
env := testenv.NewTestEnv(t)
resp := env.DoRequest(http.MethodGet, "/api/v1/applications", nil)
status, success, data, err := appDecodeAll(env, resp)
if err != nil {
t.Fatalf("decode response: %v", err)
}
if status != http.StatusOK {
t.Fatalf("expected status 200, got %d", status)
}
if !success {
t.Fatal("expected success=true")
}
var list appListData
if err := json.Unmarshal(data, &list); err != nil {
t.Fatalf("unmarshal list data: %v", err)
}
if list.Total != 0 {
t.Fatalf("expected total=0, got %d", list.Total)
}
if len(list.Applications) != 0 {
t.Fatalf("expected 0 applications, got %d", len(list.Applications))
}
}
// TestIntegration_App_Create verifies POST /api/v1/applications creates an application.
func TestIntegration_App_Create(t *testing.T) {
env := testenv.NewTestEnv(t)
body := map[string]interface{}{
"name": "test-app-create",
"script_template": "#!/bin/bash\necho hello",
"parameters": []interface{}{},
}
resp := appDoRequest(env, http.MethodPost, "/api/v1/applications", body)
status, success, data, err := appDecodeAll(env, resp)
if err != nil {
t.Fatalf("decode response: %v", err)
}
if status != http.StatusCreated {
t.Fatalf("expected status 201, got %d", status)
}
if !success {
t.Fatal("expected success=true")
}
var created appCreatedData
if err := json.Unmarshal(data, &created); err != nil {
t.Fatalf("unmarshal created data: %v", err)
}
if created.ID <= 0 {
t.Fatalf("expected positive id, got %d", created.ID)
}
}
// TestIntegration_App_Get verifies GET /api/v1/applications/:id returns the correct application.
func TestIntegration_App_Get(t *testing.T) {
env := testenv.NewTestEnv(t)
id := appSeedApp(env, "test-app-get")
path := fmt.Sprintf("/api/v1/applications/%d", id)
resp := env.DoRequest(http.MethodGet, path, nil)
status, success, data, err := appDecodeAll(env, resp)
if err != nil {
t.Fatalf("decode response: %v", err)
}
if status != http.StatusOK {
t.Fatalf("expected status 200, got %d", status)
}
if !success {
t.Fatal("expected success=true")
}
var app appData
if err := json.Unmarshal(data, &app); err != nil {
t.Fatalf("unmarshal app data: %v", err)
}
if app.ID != id {
t.Fatalf("expected id=%d, got %d", id, app.ID)
}
if app.Name != "test-app-get" {
t.Fatalf("expected name=test-app-get, got %s", app.Name)
}
}
// TestIntegration_App_Update verifies PUT /api/v1/applications/:id updates an application.
func TestIntegration_App_Update(t *testing.T) {
env := testenv.NewTestEnv(t)
id := appSeedApp(env, "test-app-update-before")
newName := "test-app-update-after"
body := map[string]interface{}{
"name": newName,
}
path := fmt.Sprintf("/api/v1/applications/%d", id)
resp := appDoRequest(env, http.MethodPut, path, body)
status, success, data, err := appDecodeAll(env, resp)
if err != nil {
t.Fatalf("decode response: %v", err)
}
if status != http.StatusOK {
t.Fatalf("expected status 200, got %d", status)
}
if !success {
t.Fatal("expected success=true")
}
var msg appMessageData
if err := json.Unmarshal(data, &msg); err != nil {
t.Fatalf("unmarshal message data: %v", err)
}
if msg.Message != "application updated" {
t.Fatalf("expected message 'application updated', got %q", msg.Message)
}
getResp := env.DoRequest(http.MethodGet, path, nil)
_, _, getData, gErr := appDecodeAll(env, getResp)
if gErr != nil {
t.Fatalf("decode get response: %v", gErr)
}
var updated appData
if err := json.Unmarshal(getData, &updated); err != nil {
t.Fatalf("unmarshal updated app: %v", err)
}
if updated.Name != newName {
t.Fatalf("expected updated name=%q, got %q", newName, updated.Name)
}
}
// TestIntegration_App_Delete verifies DELETE /api/v1/applications/:id removes an application.
func TestIntegration_App_Delete(t *testing.T) {
env := testenv.NewTestEnv(t)
id := appSeedApp(env, "test-app-delete")
path := fmt.Sprintf("/api/v1/applications/%d", id)
resp := env.DoRequest(http.MethodDelete, path, nil)
status, success, data, err := appDecodeAll(env, resp)
if err != nil {
t.Fatalf("decode response: %v", err)
}
if status != http.StatusOK {
t.Fatalf("expected status 200, got %d", status)
}
if !success {
t.Fatal("expected success=true")
}
var msg appMessageData
if err := json.Unmarshal(data, &msg); err != nil {
t.Fatalf("unmarshal message data: %v", err)
}
if msg.Message != "application deleted" {
t.Fatalf("expected message 'application deleted', got %q", msg.Message)
}
// Verify a subsequent GET returns 404.
getResp := env.DoRequest(http.MethodGet, path, nil)
getStatus, getSuccess, _, _ := appDecodeAll(env, getResp)
if getStatus != http.StatusNotFound {
t.Fatalf("expected status 404 after delete, got %d", getStatus)
}
if getSuccess {
t.Fatal("expected success=false after delete")
}
}
// TestIntegration_App_CreateValidation verifies POST /api/v1/applications with empty name returns error.
func TestIntegration_App_CreateValidation(t *testing.T) {
env := testenv.NewTestEnv(t)
body := map[string]interface{}{
"name": "",
"script_template": "#!/bin/bash\necho hello",
}
resp := appDoRequest(env, http.MethodPost, "/api/v1/applications", body)
status, success, _, err := appDecodeAll(env, resp)
if err != nil {
t.Fatalf("decode response: %v", err)
}
if status != http.StatusBadRequest {
t.Fatalf("expected status 400, got %d", status)
}
if success {
t.Fatal("expected success=false for validation error")
}
}


@@ -0,0 +1,186 @@
package main
import (
"encoding/json"
"net/http"
"testing"
"gcy_hpc_server/internal/testutil/testenv"
)
// clusterNodeData mirrors the NodeResponse DTO returned by the API.
type clusterNodeData struct {
Name string `json:"name"`
State []string `json:"state"`
CPUs int32 `json:"cpus"`
RealMemory int64 `json:"real_memory"`
}
// clusterPartitionData mirrors the PartitionResponse DTO returned by the API.
type clusterPartitionData struct {
Name string `json:"name"`
State []string `json:"state"`
TotalNodes int32 `json:"total_nodes,omitempty"`
TotalCPUs int32 `json:"total_cpus,omitempty"`
}
// clusterDiagStat mirrors a single entry from the diag statistics.
type clusterDiagStat struct {
Parts []struct {
Param string `json:"param"`
} `json:"parts,omitempty"`
}
// clusterDecodeAll decodes the response and returns status, success, and raw data.
func clusterDecodeAll(env *testenv.TestEnv, resp *http.Response) (statusCode int, success bool, data json.RawMessage, err error) {
statusCode = resp.StatusCode
success, data, err = env.DecodeResponse(resp)
return
}
// TestIntegration_Cluster_Nodes verifies GET /api/v1/nodes returns the 3 pre-loaded mock nodes.
func TestIntegration_Cluster_Nodes(t *testing.T) {
env := testenv.NewTestEnv(t)
resp := env.DoRequest(http.MethodGet, "/api/v1/nodes", nil)
status, success, data, err := clusterDecodeAll(env, resp)
if err != nil {
t.Fatalf("decode response: %v", err)
}
if status != http.StatusOK {
t.Fatalf("expected status 200, got %d", status)
}
if !success {
t.Fatal("expected success=true")
}
var nodes []clusterNodeData
if err := json.Unmarshal(data, &nodes); err != nil {
t.Fatalf("unmarshal nodes: %v", err)
}
if len(nodes) != 3 {
t.Fatalf("expected 3 nodes, got %d", len(nodes))
}
names := make(map[string]bool, len(nodes))
for _, n := range nodes {
names[n.Name] = true
}
for _, expected := range []string{"node01", "node02", "node03"} {
if !names[expected] {
t.Errorf("missing expected node %q", expected)
}
}
}
// TestIntegration_Cluster_NodeByName verifies GET /api/v1/nodes/:name returns a single node.
func TestIntegration_Cluster_NodeByName(t *testing.T) {
env := testenv.NewTestEnv(t)
resp := env.DoRequest(http.MethodGet, "/api/v1/nodes/node01", nil)
status, success, data, err := clusterDecodeAll(env, resp)
if err != nil {
t.Fatalf("decode response: %v", err)
}
if status != http.StatusOK {
t.Fatalf("expected status 200, got %d", status)
}
if !success {
t.Fatal("expected success=true")
}
var node clusterNodeData
if err := json.Unmarshal(data, &node); err != nil {
t.Fatalf("unmarshal node: %v", err)
}
if node.Name != "node01" {
t.Fatalf("expected name=node01, got %q", node.Name)
}
}
// TestIntegration_Cluster_Partitions verifies GET /api/v1/partitions returns the 2 pre-loaded partitions.
func TestIntegration_Cluster_Partitions(t *testing.T) {
env := testenv.NewTestEnv(t)
resp := env.DoRequest(http.MethodGet, "/api/v1/partitions", nil)
status, success, data, err := clusterDecodeAll(env, resp)
if err != nil {
t.Fatalf("decode response: %v", err)
}
if status != http.StatusOK {
t.Fatalf("expected status 200, got %d", status)
}
if !success {
t.Fatal("expected success=true")
}
var partitions []clusterPartitionData
if err := json.Unmarshal(data, &partitions); err != nil {
t.Fatalf("unmarshal partitions: %v", err)
}
if len(partitions) != 2 {
t.Fatalf("expected 2 partitions, got %d", len(partitions))
}
names := make(map[string]bool, len(partitions))
for _, p := range partitions {
names[p.Name] = true
}
if !names["normal"] {
t.Error("missing expected partition \"normal\"")
}
if !names["gpu"] {
t.Error("missing expected partition \"gpu\"")
}
}
// TestIntegration_Cluster_PartitionByName verifies GET /api/v1/partitions/:name returns a single partition.
func TestIntegration_Cluster_PartitionByName(t *testing.T) {
env := testenv.NewTestEnv(t)
resp := env.DoRequest(http.MethodGet, "/api/v1/partitions/normal", nil)
status, success, data, err := clusterDecodeAll(env, resp)
if err != nil {
t.Fatalf("decode response: %v", err)
}
if status != http.StatusOK {
t.Fatalf("expected status 200, got %d", status)
}
if !success {
t.Fatal("expected success=true")
}
var part clusterPartitionData
if err := json.Unmarshal(data, &part); err != nil {
t.Fatalf("unmarshal partition: %v", err)
}
if part.Name != "normal" {
t.Fatalf("expected name=normal, got %q", part.Name)
}
}
// TestIntegration_Cluster_Diag verifies GET /api/v1/diag returns diagnostics data.
func TestIntegration_Cluster_Diag(t *testing.T) {
env := testenv.NewTestEnv(t)
resp := env.DoRequest(http.MethodGet, "/api/v1/diag", nil)
status, success, data, err := clusterDecodeAll(env, resp)
if err != nil {
t.Fatalf("decode response: %v", err)
}
if status != http.StatusOK {
t.Fatalf("expected status 200, got %d", status)
}
if !success {
t.Fatal("expected success=true")
}
// Verify the response contains a "statistics" field (non-empty JSON object).
var raw map[string]json.RawMessage
if err := json.Unmarshal(data, &raw); err != nil {
t.Fatalf("unmarshal diag top-level: %v", err)
}
if _, ok := raw["statistics"]; !ok {
t.Fatal("diag response missing \"statistics\" field")
}
}


@@ -0,0 +1,202 @@
package main
import (
"encoding/json"
"fmt"
"io"
"net/http"
"strings"
"testing"
"time"
"gcy_hpc_server/internal/testutil/testenv"
)
// e2eResponse mirrors the unified API response structure.
type e2eResponse struct {
Success bool `json:"success"`
Data json.RawMessage `json:"data,omitempty"`
Error string `json:"error,omitempty"`
}
// e2eTaskCreatedData mirrors the POST /api/v1/tasks response data.
type e2eTaskCreatedData struct {
ID int64 `json:"id"`
}
// e2eTaskItem mirrors a single task in the list response.
type e2eTaskItem struct {
ID int64 `json:"id"`
TaskName string `json:"task_name"`
Status string `json:"status"`
WorkDir string `json:"work_dir"`
ErrorMessage string `json:"error_message"`
}
// e2eTaskListData mirrors the list endpoint response data.
type e2eTaskListData struct {
Items []e2eTaskItem `json:"items"`
Total int64 `json:"total"`
}
// e2eSendRequest sends an HTTP request via the test env and returns the response.
func e2eSendRequest(env *testenv.TestEnv, method, path string, body string) *http.Response {
var r io.Reader
if body != "" {
r = strings.NewReader(body)
}
return env.DoRequest(method, path, r)
}
// e2eParseResponse decodes an HTTP response into e2eResponse.
func e2eParseResponse(resp *http.Response) (int, e2eResponse) {
b, err := io.ReadAll(resp.Body)
if err != nil {
panic(fmt.Sprintf("e2eParseResponse read: %v", err))
}
resp.Body.Close()
var result e2eResponse
if err := json.Unmarshal(b, &result); err != nil {
panic(fmt.Sprintf("e2eParseResponse unmarshal: %v (body: %s)", err, string(b)))
}
return resp.StatusCode, result
}
// TestIntegration_E2E_CompleteWorkflow verifies the full lifecycle:
// create app → upload file → submit task → queued → running → completed.
func TestIntegration_E2E_CompleteWorkflow(t *testing.T) {
t.Log("========== E2E full-workflow test start ==========")
t.Log("")
env := testenv.NewTestEnv(t)
t.Log("✓ test environment ready (SQLite + MockSlurm + MockMinIO + Router + Poller)")
t.Log("")
// Step 1: Create Application with script template and parameters.
t.Log("[Step 1] create the application")
appID, err := env.CreateApp("e2e-app", "#!/bin/bash\necho {{.np}}",
json.RawMessage(`[{"name":"np","type":"string","default":"1"}]`))
if err != nil {
t.Fatalf("step 1 create app: %v", err)
}
t.Logf(" → application created, appID=%d, script template='#!/bin/bash echo {{.np}}', params=[np]", appID)
t.Log("")
// Step 2: Upload input file.
t.Log("[Step 2] upload the input file")
fileID, _ := env.UploadTestData("input.txt", []byte("test input data"))
t.Logf(" → file uploaded, fileID=%d, content='test input data' (stored in MockMinIO + SQLite)", fileID)
t.Log("")
// Step 3: Submit Task via API.
t.Log("[Step 3] submit the task via the HTTP API")
body := fmt.Sprintf(
`{"app_id": %d, "task_name": "e2e-task", "values": {"np": "4"}, "file_ids": [%d]}`,
appID, fileID,
)
t.Logf(" → POST /api/v1/tasks body=%s", body)
resp := e2eSendRequest(env, http.MethodPost, "/api/v1/tasks", body)
status, result := e2eParseResponse(resp)
if status != http.StatusCreated {
t.Fatalf("step 3 submit task: status=%d, success=%v, error=%q", status, result.Success, result.Error)
}
var created e2eTaskCreatedData
if err := json.Unmarshal(result.Data, &created); err != nil {
t.Fatalf("step 3 parse task id: %v", err)
}
taskID := created.ID
if taskID <= 0 {
t.Fatalf("step 3: expected positive task id, got %d", taskID)
}
t.Logf(" → HTTP 201 Created, taskID=%d", taskID)
t.Log("")
// Step 4: Wait for queued status.
t.Log("[Step 4] wait for TaskProcessor to submit to MockSlurm asynchronously")
t.Log(" → background flow: submitted → preparing → downloading → ready → queued")
if err := env.WaitForTaskStatus(taskID, "queued", 5*time.Second); err != nil {
taskStatus, _ := e2eFetchTaskStatus(env, taskID)
t.Fatalf("step 4 wait for queued: %v (current status via API: %q)", err, taskStatus)
}
t.Log(" → task status is now 'queued' (TaskProcessor has submitted to Slurm)")
t.Log("")
// Step 5: Get slurmJobID.
t.Log("[Step 5] query the database for the Slurm job ID")
slurmJobID, err := env.GetTaskSlurmJobID(taskID)
if err != nil {
t.Fatalf("step 5 get slurm job id: %v", err)
}
t.Logf(" → slurmJobID=%d (job number inside MockSlurm)", slurmJobID)
t.Log("")
// Step 6: Transition to RUNNING.
t.Log("[Step 6] simulate Slurm: the job starts running")
t.Logf(" → MockSlurm.SetJobState(%d, 'RUNNING')", slurmJobID)
env.MockSlurm.SetJobState(slurmJobID, "RUNNING")
t.Logf(" → MakeTaskStale(%d): skip the 30s wait so the poller refreshes immediately", taskID)
if err := env.MakeTaskStale(taskID); err != nil {
t.Fatalf("step 6 make task stale: %v", err)
}
if err := env.WaitForTaskStatus(taskID, "running", 5*time.Second); err != nil {
taskStatus, _ := e2eFetchTaskStatus(env, taskID)
t.Fatalf("step 6 wait for running: %v (current status via API: %q)", err, taskStatus)
}
t.Log(" → task status is now 'running'")
t.Log("")
// Step 7: Transition to COMPLETED — job evicted from activeJobs to historyJobs.
t.Log("[Step 7] simulate Slurm: the job finishes")
t.Logf(" → MockSlurm.SetJobState(%d, 'COMPLETED'): job evicted from activeJobs to historyJobs", slurmJobID)
env.MockSlurm.SetJobState(slurmJobID, "COMPLETED")
t.Log(" → MakeTaskStale + WaitForTaskStatus...")
if err := env.MakeTaskStale(taskID); err != nil {
t.Fatalf("step 7 make task stale: %v", err)
}
if err := env.WaitForTaskStatus(taskID, "completed", 5*time.Second); err != nil {
taskStatus, _ := e2eFetchTaskStatus(env, taskID)
t.Fatalf("step 7 wait for completed: %v (current status via API: %q)", err, taskStatus)
}
t.Log(" → task status is now 'completed' (fetched via the SlurmDB history fallback path)")
t.Log("")
// Step 8: Verify final state via GET /api/v1/tasks.
t.Log("[Step 8] verify the final state via the HTTP API")
finalStatus, finalItem := e2eFetchTaskStatus(env, taskID)
if finalStatus != "completed" {
t.Fatalf("step 8: expected status completed, got %q (error: %q)", finalStatus, finalItem.ErrorMessage)
}
t.Log(" → GET /api/v1/tasks returns status='completed'")
t.Logf(" → task_name='%s', work_dir='%s'", finalItem.TaskName, finalItem.WorkDir)
t.Logf(" → MockSlurm activeJobs=%d, historyJobs=%d",
len(env.MockSlurm.GetAllActiveJobs()), len(env.MockSlurm.GetAllHistoryJobs()))
t.Log("")
// Step 9: Verify WorkDir exists and contains the input file.
t.Log("[Step 9] verify the work directory")
if finalItem.WorkDir == "" {
t.Fatal("step 9: expected non-empty work_dir")
}
t.Logf(" → work_dir='%s' (non-empty, created by TaskProcessor)", finalItem.WorkDir)
t.Log("")
t.Log("========== E2E full-workflow test passed ✓ ==========")
}
// e2eFetchTaskStatus fetches a single task's status from the list API.
func e2eFetchTaskStatus(env *testenv.TestEnv, taskID int64) (string, e2eTaskItem) {
resp := e2eSendRequest(env, http.MethodGet, "/api/v1/tasks", "")
_, result := e2eParseResponse(resp)
var list e2eTaskListData
if err := json.Unmarshal(result.Data, &list); err != nil {
return "", e2eTaskItem{}
}
for _, item := range list.Items {
if item.ID == taskID {
return item.Status, item
}
}
return "", e2eTaskItem{}
}


@@ -0,0 +1,170 @@
package main
import (
"encoding/json"
"fmt"
"io"
"net/http"
"testing"
"gcy_hpc_server/internal/model"
"gcy_hpc_server/internal/testutil/testenv"
)
// fileAPIResp mirrors server.APIResponse for file integration tests.
type fileAPIResp struct {
Success bool `json:"success"`
Data json.RawMessage `json:"data,omitempty"`
Error string `json:"error,omitempty"`
}
// fileDecode parses an HTTP response body into fileAPIResp.
func fileDecode(t *testing.T, body io.Reader) fileAPIResp {
t.Helper()
data, err := io.ReadAll(body)
if err != nil {
t.Fatalf("fileDecode: read body: %v", err)
}
var r fileAPIResp
if err := json.Unmarshal(data, &r); err != nil {
t.Fatalf("fileDecode: unmarshal: %v (body: %s)", err, string(data))
}
return r
}
func TestIntegration_File_List(t *testing.T) {
env := testenv.NewTestEnv(t)
// Upload a file so the list is non-empty.
env.UploadTestData("list_test.txt", []byte("hello list"))
resp := env.DoRequest("GET", "/api/v1/files", nil)
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
t.Fatalf("expected 200, got %d", resp.StatusCode)
}
r := fileDecode(t, resp.Body)
if !r.Success {
t.Fatalf("response not success: %s", r.Error)
}
var listResp model.ListFilesResponse
if err := json.Unmarshal(r.Data, &listResp); err != nil {
t.Fatalf("unmarshal list response: %v", err)
}
if len(listResp.Files) == 0 {
t.Fatal("expected at least 1 file in list, got 0")
}
found := false
for _, f := range listResp.Files {
if f.Name == "list_test.txt" {
found = true
break
}
}
if !found {
t.Fatal("expected to find list_test.txt in file list")
}
if listResp.Total < 1 {
t.Fatalf("expected total >= 1, got %d", listResp.Total)
}
if listResp.Page < 1 {
t.Fatalf("expected page >= 1, got %d", listResp.Page)
}
}
func TestIntegration_File_Get(t *testing.T) {
env := testenv.NewTestEnv(t)
fileID, _ := env.UploadTestData("get_test.txt", []byte("hello get"))
resp := env.DoRequest("GET", fmt.Sprintf("/api/v1/files/%d", fileID), nil)
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
t.Fatalf("expected 200, got %d", resp.StatusCode)
}
r := fileDecode(t, resp.Body)
if !r.Success {
t.Fatalf("response not success: %s", r.Error)
}
var fileResp model.FileResponse
if err := json.Unmarshal(r.Data, &fileResp); err != nil {
t.Fatalf("unmarshal file response: %v", err)
}
if fileResp.ID != fileID {
t.Fatalf("expected file ID %d, got %d", fileID, fileResp.ID)
}
if fileResp.Name != "get_test.txt" {
t.Fatalf("expected name get_test.txt, got %s", fileResp.Name)
}
if fileResp.Size != int64(len("hello get")) {
t.Fatalf("expected size %d, got %d", len("hello get"), fileResp.Size)
}
}
func TestIntegration_File_Download(t *testing.T) {
env := testenv.NewTestEnv(t)
content := []byte("hello world")
fileID, _ := env.UploadTestData("download_test.txt", content)
resp := env.DoRequest("GET", fmt.Sprintf("/api/v1/files/%d/download", fileID), nil)
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
t.Fatalf("expected 200, got %d", resp.StatusCode)
}
body, err := io.ReadAll(resp.Body)
if err != nil {
t.Fatalf("read download body: %v", err)
}
if string(body) != string(content) {
t.Fatalf("downloaded content mismatch: got %q, want %q", string(body), string(content))
}
contentType := resp.Header.Get("Content-Type")
if contentType == "" {
t.Fatal("expected Content-Type header to be set")
}
}
func TestIntegration_File_Delete(t *testing.T) {
env := testenv.NewTestEnv(t)
fileID, _ := env.UploadTestData("delete_test.txt", []byte("hello delete"))
// Delete the file.
resp := env.DoRequest("DELETE", fmt.Sprintf("/api/v1/files/%d", fileID), nil)
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
t.Fatalf("expected 200, got %d", resp.StatusCode)
}
r := fileDecode(t, resp.Body)
if !r.Success {
t.Fatalf("delete response not success: %s", r.Error)
}
// Verify the file is gone — a subsequent GET must not return a successful response
// (the server may answer 404 or a non-2xx internal error).
getResp := env.DoRequest("GET", fmt.Sprintf("/api/v1/files/%d", fileID), nil)
defer getResp.Body.Close()
if getResp.StatusCode == http.StatusOK {
gr := fileDecode(t, getResp.Body)
if gr.Success {
t.Fatal("expected file to be deleted, but GET still returns success")
}
}
}


@@ -0,0 +1,193 @@
package main
import (
"bytes"
"encoding/json"
"fmt"
"io"
"net/http"
"testing"
"gcy_hpc_server/internal/testutil/testenv"
)
// folderData mirrors the FolderResponse DTO returned by the API.
type folderData struct {
ID int64 `json:"id"`
Name string `json:"name"`
ParentID *int64 `json:"parent_id,omitempty"`
Path string `json:"path"`
FileCount int64 `json:"file_count"`
SubFolderCount int64 `json:"subfolder_count"`
}
// folderMessageData mirrors the delete endpoint response data structure.
type folderMessageData struct {
Message string `json:"message"`
}
// folderDoRequest marshals body and calls env.DoRequest.
func folderDoRequest(env *testenv.TestEnv, method, path string, body interface{}) *http.Response {
var r io.Reader
if body != nil {
b, err := json.Marshal(body)
if err != nil {
panic(fmt.Sprintf("folderDoRequest marshal: %v", err))
}
r = bytes.NewReader(b)
}
return env.DoRequest(method, path, r)
}
// folderDecodeAll decodes the response and returns status, success, and raw data.
func folderDecodeAll(env *testenv.TestEnv, resp *http.Response) (statusCode int, success bool, data json.RawMessage, err error) {
statusCode = resp.StatusCode
success, data, err = env.DecodeResponse(resp)
return
}
// folderSeed creates a folder via HTTP and returns its ID.
func folderSeed(env *testenv.TestEnv, name string) int64 {
body := map[string]interface{}{"name": name}
resp := folderDoRequest(env, http.MethodPost, "/api/v1/files/folders", body)
status, success, data, err := folderDecodeAll(env, resp)
if err != nil {
panic(fmt.Sprintf("folderSeed decode: %v", err))
}
if status != http.StatusCreated {
panic(fmt.Sprintf("folderSeed: expected 201, got %d", status))
}
if !success {
panic("folderSeed: expected success=true")
}
var f folderData
if err := json.Unmarshal(data, &f); err != nil {
panic(fmt.Sprintf("folderSeed unmarshal: %v", err))
}
return f.ID
}
// TestIntegration_Folder_Create verifies POST /api/v1/files/folders creates a folder.
func TestIntegration_Folder_Create(t *testing.T) {
env := testenv.NewTestEnv(t)
body := map[string]interface{}{"name": "test-folder-create"}
resp := folderDoRequest(env, http.MethodPost, "/api/v1/files/folders", body)
status, success, data, err := folderDecodeAll(env, resp)
if err != nil {
t.Fatalf("decode response: %v", err)
}
if status != http.StatusCreated {
t.Fatalf("expected status 201, got %d", status)
}
if !success {
t.Fatal("expected success=true")
}
var created folderData
if err := json.Unmarshal(data, &created); err != nil {
t.Fatalf("unmarshal created data: %v", err)
}
if created.ID <= 0 {
t.Fatalf("expected positive id, got %d", created.ID)
}
if created.Name != "test-folder-create" {
t.Fatalf("expected name=test-folder-create, got %s", created.Name)
}
}
// TestIntegration_Folder_List verifies GET /api/v1/files/folders returns a list.
func TestIntegration_Folder_List(t *testing.T) {
env := testenv.NewTestEnv(t)
// Seed two folders.
folderSeed(env, "list-folder-1")
folderSeed(env, "list-folder-2")
resp := env.DoRequest(http.MethodGet, "/api/v1/files/folders", nil)
status, success, data, err := folderDecodeAll(env, resp)
if err != nil {
t.Fatalf("decode response: %v", err)
}
if status != http.StatusOK {
t.Fatalf("expected status 200, got %d", status)
}
if !success {
t.Fatal("expected success=true")
}
var folders []folderData
if err := json.Unmarshal(data, &folders); err != nil {
t.Fatalf("unmarshal list data: %v", err)
}
if len(folders) < 2 {
t.Fatalf("expected at least 2 folders, got %d", len(folders))
}
}
// TestIntegration_Folder_Get verifies GET /api/v1/files/folders/:id returns folder details.
func TestIntegration_Folder_Get(t *testing.T) {
env := testenv.NewTestEnv(t)
id := folderSeed(env, "test-folder-get")
path := fmt.Sprintf("/api/v1/files/folders/%d", id)
resp := env.DoRequest(http.MethodGet, path, nil)
status, success, data, err := folderDecodeAll(env, resp)
if err != nil {
t.Fatalf("decode response: %v", err)
}
if status != http.StatusOK {
t.Fatalf("expected status 200, got %d", status)
}
if !success {
t.Fatal("expected success=true")
}
var f folderData
if err := json.Unmarshal(data, &f); err != nil {
t.Fatalf("unmarshal folder data: %v", err)
}
if f.ID != id {
t.Fatalf("expected id=%d, got %d", id, f.ID)
}
if f.Name != "test-folder-get" {
t.Fatalf("expected name=test-folder-get, got %s", f.Name)
}
}
// TestIntegration_Folder_Delete verifies DELETE /api/v1/files/folders/:id removes a folder.
func TestIntegration_Folder_Delete(t *testing.T) {
env := testenv.NewTestEnv(t)
id := folderSeed(env, "test-folder-delete")
path := fmt.Sprintf("/api/v1/files/folders/%d", id)
resp := env.DoRequest(http.MethodDelete, path, nil)
status, success, data, err := folderDecodeAll(env, resp)
if err != nil {
t.Fatalf("decode response: %v", err)
}
if status != http.StatusOK {
t.Fatalf("expected status 200, got %d", status)
}
if !success {
t.Fatal("expected success=true")
}
var msg folderMessageData
if err := json.Unmarshal(data, &msg); err != nil {
t.Fatalf("unmarshal message data: %v", err)
}
if msg.Message != "folder deleted" {
t.Fatalf("expected message 'folder deleted', got %q", msg.Message)
}
// Verify it's gone via GET → 404.
getResp := env.DoRequest(http.MethodGet, path, nil)
getStatus, getSuccess, _, _ := folderDecodeAll(env, getResp)
if getStatus != http.StatusNotFound {
t.Fatalf("expected status 404 after delete, got %d", getStatus)
}
if getSuccess {
t.Fatal("expected success=false after delete")
}
}


@@ -0,0 +1,222 @@
package main
import (
"bytes"
"encoding/json"
"fmt"
"net/http"
"testing"
"gcy_hpc_server/internal/testutil/testenv"
)
// jobItemData mirrors the JobResponse DTO for a single job.
type jobItemData struct {
JobID int32 `json:"job_id"`
Name string `json:"name"`
State []string `json:"job_state"`
Partition string `json:"partition"`
}
// jobListData mirrors the paginated JobListResponse DTO.
type jobListData struct {
Jobs []jobItemData `json:"jobs"`
Total int `json:"total"`
Page int `json:"page"`
PageSize int `json:"page_size"`
}
// jobCancelData mirrors the cancel response message.
type jobCancelData struct {
Message string `json:"message"`
}
// jobDecodeAll decodes the response and returns status, success, and raw data.
func jobDecodeAll(env *testenv.TestEnv, resp *http.Response) (statusCode int, success bool, data json.RawMessage, err error) {
statusCode = resp.StatusCode
success, data, err = env.DecodeResponse(resp)
return
}
// jobSubmitBody builds a JSON body for job submit requests.
func jobSubmitBody(script string) *bytes.Reader {
body, _ := json.Marshal(map[string]string{"script": script})
return bytes.NewReader(body)
}
// jobSubmitViaAPI submits a job and returns the job ID. Fatals on failure.
func jobSubmitViaAPI(t *testing.T, env *testenv.TestEnv, script string) int32 {
t.Helper()
resp := env.DoRequest(http.MethodPost, "/api/v1/jobs/submit", jobSubmitBody(script))
status, success, data, err := jobDecodeAll(env, resp)
if err != nil {
t.Fatalf("submit job decode: %v", err)
}
if status != http.StatusCreated {
t.Fatalf("expected status 201, got %d", status)
}
if !success {
t.Fatal("expected success=true on submit")
}
var job jobItemData
if err := json.Unmarshal(data, &job); err != nil {
t.Fatalf("unmarshal submitted job: %v", err)
}
return job.JobID
}
// TestIntegration_Jobs_Submit verifies POST /api/v1/jobs/submit creates a new job.
func TestIntegration_Jobs_Submit(t *testing.T) {
env := testenv.NewTestEnv(t)
script := "#!/bin/bash\necho hello"
resp := env.DoRequest(http.MethodPost, "/api/v1/jobs/submit", jobSubmitBody(script))
status, success, data, err := jobDecodeAll(env, resp)
if err != nil {
t.Fatalf("decode response: %v", err)
}
if status != http.StatusCreated {
t.Fatalf("expected status 201, got %d", status)
}
if !success {
t.Fatal("expected success=true")
}
var job jobItemData
if err := json.Unmarshal(data, &job); err != nil {
t.Fatalf("unmarshal job: %v", err)
}
if job.JobID <= 0 {
t.Fatalf("expected positive job_id, got %d", job.JobID)
}
}
// TestIntegration_Jobs_List verifies GET /api/v1/jobs returns a paginated job list.
func TestIntegration_Jobs_List(t *testing.T) {
env := testenv.NewTestEnv(t)
// Submit a job so the list is not empty.
jobSubmitViaAPI(t, env, "#!/bin/bash\necho list-test")
resp := env.DoRequest(http.MethodGet, "/api/v1/jobs", nil)
status, success, data, err := jobDecodeAll(env, resp)
if err != nil {
t.Fatalf("decode response: %v", err)
}
if status != http.StatusOK {
t.Fatalf("expected status 200, got %d", status)
}
if !success {
t.Fatal("expected success=true")
}
var list jobListData
if err := json.Unmarshal(data, &list); err != nil {
t.Fatalf("unmarshal job list: %v", err)
}
if list.Total < 1 {
t.Fatalf("expected at least 1 job, got total=%d", list.Total)
}
if list.Page != 1 {
t.Fatalf("expected page=1, got %d", list.Page)
}
}
// TestIntegration_Jobs_Get verifies GET /api/v1/jobs/:id returns a single job.
func TestIntegration_Jobs_Get(t *testing.T) {
env := testenv.NewTestEnv(t)
jobID := jobSubmitViaAPI(t, env, "#!/bin/bash\necho get-test")
path := fmt.Sprintf("/api/v1/jobs/%d", jobID)
resp := env.DoRequest(http.MethodGet, path, nil)
status, success, data, err := jobDecodeAll(env, resp)
if err != nil {
t.Fatalf("decode response: %v", err)
}
if status != http.StatusOK {
t.Fatalf("expected status 200, got %d", status)
}
if !success {
t.Fatal("expected success=true")
}
var job jobItemData
if err := json.Unmarshal(data, &job); err != nil {
t.Fatalf("unmarshal job: %v", err)
}
if job.JobID != jobID {
t.Fatalf("expected job_id=%d, got %d", jobID, job.JobID)
}
}
// TestIntegration_Jobs_Cancel verifies DELETE /api/v1/jobs/:id cancels a job.
func TestIntegration_Jobs_Cancel(t *testing.T) {
env := testenv.NewTestEnv(t)
jobID := jobSubmitViaAPI(t, env, "#!/bin/bash\necho cancel-test")
path := fmt.Sprintf("/api/v1/jobs/%d", jobID)
resp := env.DoRequest(http.MethodDelete, path, nil)
status, success, data, err := jobDecodeAll(env, resp)
if err != nil {
t.Fatalf("decode response: %v", err)
}
if status != http.StatusOK {
t.Fatalf("expected status 200, got %d", status)
}
if !success {
t.Fatal("expected success=true")
}
var msg jobCancelData
if err := json.Unmarshal(data, &msg); err != nil {
t.Fatalf("unmarshal cancel response: %v", err)
}
if msg.Message == "" {
t.Fatal("expected non-empty cancel message")
}
}
// TestIntegration_Jobs_History verifies GET /api/v1/jobs/history returns historical jobs.
func TestIntegration_Jobs_History(t *testing.T) {
env := testenv.NewTestEnv(t)
// Submit and cancel a job so it moves from active to history queue.
jobID := jobSubmitViaAPI(t, env, "#!/bin/bash\necho history-test")
path := fmt.Sprintf("/api/v1/jobs/%d", jobID)
env.DoRequest(http.MethodDelete, path, nil)
resp := env.DoRequest(http.MethodGet, "/api/v1/jobs/history", nil)
status, success, data, err := jobDecodeAll(env, resp)
if err != nil {
t.Fatalf("decode response: %v", err)
}
if status != http.StatusOK {
t.Fatalf("expected status 200, got %d", status)
}
if !success {
t.Fatal("expected success=true")
}
var list jobListData
if err := json.Unmarshal(data, &list); err != nil {
t.Fatalf("unmarshal history: %v", err)
}
if list.Total < 1 {
t.Fatalf("expected at least 1 history job, got total=%d", list.Total)
}
// Verify the cancelled job appears in history.
found := false
for _, j := range list.Jobs {
if j.JobID == jobID {
found = true
break
}
}
if !found {
t.Fatalf("cancelled job %d not found in history", jobID)
}
}


@@ -0,0 +1,261 @@
package main
import (
"encoding/json"
"fmt"
"io"
"net/http"
"strings"
"testing"
"time"
"gcy_hpc_server/internal/testutil/testenv"
)
// taskAPIResponse decodes the unified API response envelope.
type taskAPIResponse struct {
Success bool `json:"success"`
Data json.RawMessage `json:"data,omitempty"`
Error string `json:"error,omitempty"`
}
// taskCreateData is the data payload from a successful task creation.
type taskCreateData struct {
ID int64 `json:"id"`
}
// taskListData is the data payload from listing tasks.
type taskListData struct {
Items []taskListItem `json:"items"`
Total int64 `json:"total"`
}
type taskListItem struct {
ID int64 `json:"id"`
TaskName string `json:"task_name"`
AppID int64 `json:"app_id"`
Status string `json:"status"`
SlurmJobID *int32 `json:"slurm_job_id"`
}
// taskSendReq sends an HTTP request via the test env and returns the response.
func taskSendReq(t *testing.T, env *testenv.TestEnv, method, path string, body string) *http.Response {
t.Helper()
var r io.Reader
if body != "" {
r = strings.NewReader(body)
}
resp := env.DoRequest(method, path, r)
return resp
}
// taskParseResp decodes the response body into a taskAPIResponse.
func taskParseResp(t *testing.T, resp *http.Response) taskAPIResponse {
t.Helper()
b, err := io.ReadAll(resp.Body)
resp.Body.Close()
if err != nil {
t.Fatalf("read response body: %v", err)
}
var result taskAPIResponse
if err := json.Unmarshal(b, &result); err != nil {
t.Fatalf("unmarshal response: %v (body: %s)", err, string(b))
}
return result
}
// taskCreateViaAPI creates a task via the HTTP API and returns the task ID.
func taskCreateViaAPI(t *testing.T, env *testenv.TestEnv, appID int64, taskName string) int64 {
t.Helper()
body := fmt.Sprintf(`{"app_id":%d,"task_name":"%s","values":{},"file_ids":[]}`, appID, taskName)
resp := taskSendReq(t, env, http.MethodPost, "/api/v1/tasks", body)
defer resp.Body.Close()
if resp.StatusCode != http.StatusCreated {
b, _ := io.ReadAll(resp.Body)
t.Fatalf("expected 201, got %d: %s", resp.StatusCode, string(b))
}
parsed := taskParseResp(t, resp)
if !parsed.Success {
t.Fatalf("expected success=true, got error: %s", parsed.Error)
}
var data taskCreateData
if err := json.Unmarshal(parsed.Data, &data); err != nil {
t.Fatalf("unmarshal create data: %v", err)
}
if data.ID == 0 {
t.Fatal("expected non-zero task ID")
}
return data.ID
}
// ---------- Tests ----------
func TestIntegration_Task_Create(t *testing.T) {
env := testenv.NewTestEnv(t)
// Create application
appID, err := env.CreateApp("task-create-app", "#!/bin/bash\necho hello", nil)
if err != nil {
t.Fatalf("create app: %v", err)
}
// Create task via API
taskID := taskCreateViaAPI(t, env, appID, "test-task-create")
// Verify the task ID is positive
if taskID <= 0 {
t.Fatalf("expected positive task ID, got %d", taskID)
}
// Wait briefly for async processing, then verify task exists in DB via list
time.Sleep(200 * time.Millisecond)
resp := taskSendReq(t, env, http.MethodGet, "/api/v1/tasks", "")
defer resp.Body.Close()
parsed := taskParseResp(t, resp)
var listData taskListData
if err := json.Unmarshal(parsed.Data, &listData); err != nil {
t.Fatalf("unmarshal list data: %v", err)
}
found := false
for _, item := range listData.Items {
if item.ID == taskID {
found = true
if item.TaskName != "test-task-create" {
t.Errorf("expected task_name=test-task-create, got %s", item.TaskName)
}
if item.AppID != appID {
t.Errorf("expected app_id=%d, got %d", appID, item.AppID)
}
break
}
}
if !found {
t.Fatalf("task %d not found in list", taskID)
}
}
func TestIntegration_Task_List(t *testing.T) {
env := testenv.NewTestEnv(t)
// Create application
appID, err := env.CreateApp("task-list-app", "#!/bin/bash\necho hello", nil)
if err != nil {
t.Fatalf("create app: %v", err)
}
// Create 3 tasks
taskCreateViaAPI(t, env, appID, "list-task-1")
taskCreateViaAPI(t, env, appID, "list-task-2")
taskCreateViaAPI(t, env, appID, "list-task-3")
// Allow async processing
time.Sleep(200 * time.Millisecond)
// List tasks
resp := taskSendReq(t, env, http.MethodGet, "/api/v1/tasks", "")
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
b, _ := io.ReadAll(resp.Body)
t.Fatalf("expected 200, got %d: %s", resp.StatusCode, string(b))
}
parsed := taskParseResp(t, resp)
if !parsed.Success {
t.Fatalf("expected success, got error: %s", parsed.Error)
}
var listData taskListData
if err := json.Unmarshal(parsed.Data, &listData); err != nil {
t.Fatalf("unmarshal list data: %v", err)
}
if listData.Total < 3 {
t.Fatalf("expected at least 3 tasks, got %d", listData.Total)
}
// Verify each created task has required fields
for _, item := range listData.Items {
if item.ID == 0 {
t.Error("expected non-zero ID")
}
if item.Status == "" {
t.Error("expected non-empty status")
}
if item.AppID == 0 {
t.Error("expected non-zero app_id")
}
}
}
func TestIntegration_Task_PollerLifecycle(t *testing.T) {
env := testenv.NewTestEnv(t)
// 1. Create application
appID, err := env.CreateApp("poller-lifecycle-app", "#!/bin/bash\necho hello", nil)
if err != nil {
t.Fatalf("create app: %v", err)
}
// 2. Submit task via API
taskID := taskCreateViaAPI(t, env, appID, "poller-lifecycle-task")
// 3. Wait for queued — TaskProcessor submits to MockSlurm asynchronously.
// Intermediate states (submitted→preparing→downloading→ready→queued) are
// non-deterministic; only assert the final "queued" state.
if err := env.WaitForTaskStatus(taskID, "queued", 5*time.Second); err != nil {
t.Fatalf("wait for queued: %v", err)
}
// 4. Get slurm job ID from DB (not returned by API)
slurmJobID, err := env.GetTaskSlurmJobID(taskID)
if err != nil {
t.Fatalf("get slurm job id: %v", err)
}
// 5. Transition: queued → running
// ORDER IS CRITICAL: SetJobState BEFORE MakeTaskStale
env.MockSlurm.SetJobState(slurmJobID, "RUNNING")
if err := env.MakeTaskStale(taskID); err != nil {
t.Fatalf("make task stale (running): %v", err)
}
if err := env.WaitForTaskStatus(taskID, "running", 5*time.Second); err != nil {
t.Fatalf("wait for running: %v", err)
}
// 6. Transition: running → completed
// ORDER IS CRITICAL: SetJobState BEFORE MakeTaskStale
env.MockSlurm.SetJobState(slurmJobID, "COMPLETED")
if err := env.MakeTaskStale(taskID); err != nil {
t.Fatalf("make task stale (completed): %v", err)
}
if err := env.WaitForTaskStatus(taskID, "completed", 5*time.Second); err != nil {
t.Fatalf("wait for completed: %v", err)
}
}
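The ordering the lifecycle test asserts (submitted → preparing → downloading → ready → queued → running → completed) can be captured as a small transition table; this is a sketch derived only from the sequence asserted in these tests, not the server's actual state machine:

```go
package main

// next maps each task status to the statuses expected to follow it,
// per the lifecycle exercised by the integration tests above.
var next = map[string][]string{
	"submitted":   {"preparing"},
	"preparing":   {"downloading"},
	"downloading": {"ready"},
	"ready":       {"queued"},
	"queued":      {"running"},
	"running":     {"completed"},
}

// validPath reports whether each adjacent pair in statuses is an allowed transition.
func validPath(statuses []string) bool {
	for i := 0; i+1 < len(statuses); i++ {
		ok := false
		for _, n := range next[statuses[i]] {
			if n == statuses[i+1] {
				ok = true
				break
			}
		}
		if !ok {
			return false
		}
	}
	return true
}
```

Because the intermediate states are non-deterministic under the poller, the tests only wait on the terminal states of each phase rather than asserting every hop.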
func TestIntegration_Task_Validation(t *testing.T) {
env := testenv.NewTestEnv(t)
// Missing required app_id
resp := taskSendReq(t, env, http.MethodPost, "/api/v1/tasks", `{"task_name":"no-app-id"}`)
defer resp.Body.Close()
if resp.StatusCode != http.StatusBadRequest {
t.Fatalf("expected 400 for missing app_id, got %d", resp.StatusCode)
}
parsed := taskParseResp(t, resp)
if parsed.Success {
t.Fatal("expected success=false for validation error")
}
if parsed.Error == "" {
t.Error("expected non-empty error message")
}
}


@@ -0,0 +1,279 @@
package main
import (
"bytes"
"crypto/sha256"
"encoding/hex"
"encoding/json"
"fmt"
"io"
"mime/multipart"
"net/http"
"testing"
"gcy_hpc_server/internal/model"
"gcy_hpc_server/internal/testutil/testenv"
)
// uploadResponse mirrors server.APIResponse for upload integration tests.
type uploadAPIResp struct {
Success bool `json:"success"`
Data json.RawMessage `json:"data,omitempty"`
Error string `json:"error,omitempty"`
}
// uploadDecode parses an HTTP response body into uploadAPIResp.
func uploadDecode(t *testing.T, body io.Reader) uploadAPIResp {
t.Helper()
data, err := io.ReadAll(body)
if err != nil {
t.Fatalf("uploadDecode: read body: %v", err)
}
var r uploadAPIResp
if err := json.Unmarshal(data, &r); err != nil {
t.Fatalf("uploadDecode: unmarshal: %v (body: %s)", err, string(data))
}
return r
}
// uploadInitSession calls InitUpload and returns the created session.
// Uses the real HTTP server from testenv.
func uploadInitSession(t *testing.T, env *testenv.TestEnv, fileName string, fileSize int64, sha256Hash string) model.UploadSessionResponse {
t.Helper()
reqBody := model.InitUploadRequest{
FileName: fileName,
FileSize: fileSize,
SHA256: sha256Hash,
}
body, _ := json.Marshal(reqBody)
resp := env.DoRequest("POST", "/api/v1/files/uploads", bytes.NewReader(body))
defer resp.Body.Close()
if resp.StatusCode != http.StatusCreated {
t.Fatalf("uploadInitSession: expected 201, got %d", resp.StatusCode)
}
r := uploadDecode(t, resp.Body)
if !r.Success {
t.Fatalf("uploadInitSession: response not success: %s", r.Error)
}
var session model.UploadSessionResponse
if err := json.Unmarshal(r.Data, &session); err != nil {
t.Fatalf("uploadInitSession: unmarshal session: %v", err)
}
return session
}
// uploadSendChunk sends a single chunk via multipart form data.
// Uses raw HTTP client to set the correct multipart content type.
func uploadSendChunk(t *testing.T, env *testenv.TestEnv, sessionID int64, chunkIndex int, chunkData []byte) {
t.Helper()
url := fmt.Sprintf("%s/api/v1/files/uploads/%d/chunks/%d", env.URL(), sessionID, chunkIndex)
var buf bytes.Buffer
writer := multipart.NewWriter(&buf)
part, err := writer.CreateFormFile("chunk", "chunk.bin")
if err != nil {
t.Fatalf("uploadSendChunk: create form file: %v", err)
}
if _, err := part.Write(chunkData); err != nil {
t.Fatalf("uploadSendChunk: write chunk: %v", err)
}
if err := writer.Close(); err != nil {
t.Fatalf("uploadSendChunk: close writer: %v", err)
}
req, err := http.NewRequest("PUT", url, &buf)
if err != nil {
t.Fatalf("uploadSendChunk: new request: %v", err)
}
req.Header.Set("Content-Type", writer.FormDataContentType())
resp, err := http.DefaultClient.Do(req)
if err != nil {
t.Fatalf("uploadSendChunk: do request: %v", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
bodyBytes, _ := io.ReadAll(resp.Body)
t.Fatalf("uploadSendChunk: expected 200, got %d: %s", resp.StatusCode, string(bodyBytes))
}
}
func TestIntegration_Upload_Init(t *testing.T) {
env := testenv.NewTestEnv(t)
fileData := []byte("integration test upload init content")
h := sha256.Sum256(fileData)
sha256Hash := hex.EncodeToString(h[:])
session := uploadInitSession(t, env, "init_test.txt", int64(len(fileData)), sha256Hash)
if session.ID <= 0 {
t.Fatalf("expected positive session ID, got %d", session.ID)
}
if session.FileName != "init_test.txt" {
t.Fatalf("expected file_name init_test.txt, got %s", session.FileName)
}
if session.Status != "pending" {
t.Fatalf("expected status pending, got %s", session.Status)
}
if session.TotalChunks != 1 {
t.Fatalf("expected 1 chunk for small file, got %d", session.TotalChunks)
}
if session.FileSize != int64(len(fileData)) {
t.Fatalf("expected file_size %d, got %d", len(fileData), session.FileSize)
}
if session.SHA256 != sha256Hash {
t.Fatalf("expected sha256 %s, got %s", sha256Hash, session.SHA256)
}
}
func TestIntegration_Upload_Status(t *testing.T) {
env := testenv.NewTestEnv(t)
fileData := []byte("integration test status content")
h := sha256.Sum256(fileData)
sha256Hash := hex.EncodeToString(h[:])
session := uploadInitSession(t, env, "status_test.txt", int64(len(fileData)), sha256Hash)
resp := env.DoRequest("GET", fmt.Sprintf("/api/v1/files/uploads/%d", session.ID), nil)
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
t.Fatalf("expected 200, got %d", resp.StatusCode)
}
r := uploadDecode(t, resp.Body)
if !r.Success {
t.Fatalf("response not success: %s", r.Error)
}
var status model.UploadSessionResponse
if err := json.Unmarshal(r.Data, &status); err != nil {
t.Fatalf("unmarshal status: %v", err)
}
if status.ID != session.ID {
t.Fatalf("expected session ID %d, got %d", session.ID, status.ID)
}
if status.Status != "pending" {
t.Fatalf("expected status pending, got %s", status.Status)
}
if status.FileName != "status_test.txt" {
t.Fatalf("expected file_name status_test.txt, got %s", status.FileName)
}
}
func TestIntegration_Upload_Chunk(t *testing.T) {
env := testenv.NewTestEnv(t)
fileData := []byte("integration test chunk upload data")
h := sha256.Sum256(fileData)
sha256Hash := hex.EncodeToString(h[:])
session := uploadInitSession(t, env, "chunk_test.txt", int64(len(fileData)), sha256Hash)
uploadSendChunk(t, env, session.ID, 0, fileData)
// Verify chunk appears in uploaded_chunks via status endpoint
resp := env.DoRequest("GET", fmt.Sprintf("/api/v1/files/uploads/%d", session.ID), nil)
defer resp.Body.Close()
r := uploadDecode(t, resp.Body)
var status model.UploadSessionResponse
if err := json.Unmarshal(r.Data, &status); err != nil {
t.Fatalf("unmarshal status after chunk: %v", err)
}
if len(status.UploadedChunks) != 1 {
t.Fatalf("expected 1 uploaded chunk, got %d", len(status.UploadedChunks))
}
if status.UploadedChunks[0] != 0 {
t.Fatalf("expected uploaded chunk index 0, got %d", status.UploadedChunks[0])
}
}
func TestIntegration_Upload_Complete(t *testing.T) {
env := testenv.NewTestEnv(t)
fileData := []byte("integration test complete upload data")
h := sha256.Sum256(fileData)
sha256Hash := hex.EncodeToString(h[:])
session := uploadInitSession(t, env, "complete_test.txt", int64(len(fileData)), sha256Hash)
// Upload all chunks
for i := 0; i < session.TotalChunks; i++ {
uploadSendChunk(t, env, session.ID, i, fileData)
}
// Complete upload
resp := env.DoRequest("POST", fmt.Sprintf("/api/v1/files/uploads/%d/complete", session.ID), nil)
defer resp.Body.Close()
if resp.StatusCode != http.StatusCreated {
bodyBytes, _ := io.ReadAll(resp.Body)
t.Fatalf("expected 201, got %d: %s", resp.StatusCode, string(bodyBytes))
}
r := uploadDecode(t, resp.Body)
if !r.Success {
t.Fatalf("complete response not success: %s", r.Error)
}
var fileResp model.FileResponse
if err := json.Unmarshal(r.Data, &fileResp); err != nil {
t.Fatalf("unmarshal file response: %v", err)
}
if fileResp.ID <= 0 {
t.Fatalf("expected positive file ID, got %d", fileResp.ID)
}
if fileResp.Name != "complete_test.txt" {
t.Fatalf("expected name complete_test.txt, got %s", fileResp.Name)
}
if fileResp.Size != int64(len(fileData)) {
t.Fatalf("expected size %d, got %d", len(fileData), fileResp.Size)
}
if fileResp.SHA256 != sha256Hash {
t.Fatalf("expected sha256 %s, got %s", sha256Hash, fileResp.SHA256)
}
}
func TestIntegration_Upload_Cancel(t *testing.T) {
env := testenv.NewTestEnv(t)
fileData := []byte("integration test cancel upload data")
h := sha256.Sum256(fileData)
sha256Hash := hex.EncodeToString(h[:])
session := uploadInitSession(t, env, "cancel_test.txt", int64(len(fileData)), sha256Hash)
// Cancel the upload
resp := env.DoRequest("DELETE", fmt.Sprintf("/api/v1/files/uploads/%d", session.ID), nil)
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
bodyBytes, _ := io.ReadAll(resp.Body)
t.Fatalf("expected 200, got %d: %s", resp.StatusCode, string(bodyBytes))
}
r := uploadDecode(t, resp.Body)
if !r.Success {
t.Fatalf("cancel response not success: %s", r.Error)
}
// Verify session is no longer in pending state by checking status
statusResp := env.DoRequest("GET", fmt.Sprintf("/api/v1/files/uploads/%d", session.ID), nil)
defer statusResp.Body.Close()
sr := uploadDecode(t, statusResp.Body)
if sr.Success {
var status model.UploadSessionResponse
if err := json.Unmarshal(sr.Data, &status); err == nil {
if status.Status == "pending" {
t.Fatal("expected status to not be pending after cancel")
}
}
}
}
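The five tests above exercise the chunked-upload protocol end to end: init a session, PUT chunks, then complete (or cancel). The client-side arithmetic behind it can be sketched as below; the 5 MiB chunk size is an assumption for illustration, not the server's actual default:

```go
package main

import (
	"crypto/sha256"
	"encoding/hex"
	"fmt"
)

// chunkPlan computes what a client needs before calling InitUpload:
// the total chunk count (ceiling division by chunk size) and the
// whole-file SHA-256 that the server verifies on complete.
func chunkPlan(data []byte, chunkSize int) (totalChunks int, sum string) {
	totalChunks = (len(data) + chunkSize - 1) / chunkSize
	h := sha256.Sum256(data)
	return totalChunks, hex.EncodeToString(h[:])
}

func main() {
	data := []byte("integration test upload init content")
	n, sum := chunkPlan(data, 5<<20) // assumed 5 MiB chunk size
	fmt.Printf("chunks=%d sha256=%s\n", n, sum)
}
```

This matches why `TestIntegration_Upload_Init` expects `TotalChunks == 1` for a small payload: any file no larger than one chunk needs exactly one PUT.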


@@ -21,7 +21,7 @@ import (
func newTestDB() *gorm.DB {
db, _ := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{Logger: logger.Default.LogMode(logger.Silent)})
db.AutoMigrate(&model.JobTemplate{})
db.AutoMigrate(&model.Application{})
return db
}
@@ -34,12 +34,17 @@ func TestRouterRegistration(t *testing.T) {
defer slurmSrv.Close()
client, _ := slurm.NewClientWithOpts(slurmSrv.URL, slurm.WithHTTPClient(slurmSrv.Client()))
templateStore := store.NewTemplateStore(newTestDB())
jobSvc := service.NewJobService(client, zap.NewNop())
appStore := store.NewApplicationStore(newTestDB())
appSvc := service.NewApplicationService(appStore, jobSvc, "", zap.NewNop())
appH := handler.NewApplicationHandler(appSvc, zap.NewNop())
router := server.NewRouter(
handler.NewJobHandler(service.NewJobService(client, zap.NewNop()), zap.NewNop()),
handler.NewJobHandler(jobSvc, zap.NewNop()),
handler.NewClusterHandler(service.NewClusterService(client, zap.NewNop()), zap.NewNop()),
handler.NewTemplateHandler(templateStore, zap.NewNop()),
appH,
nil, nil, nil,
nil,
nil,
)
@@ -58,11 +63,12 @@ func TestRouterRegistration(t *testing.T) {
{"GET", "/api/v1/partitions"},
{"GET", "/api/v1/partitions/:name"},
{"GET", "/api/v1/diag"},
{"GET", "/api/v1/templates"},
{"POST", "/api/v1/templates"},
{"GET", "/api/v1/templates/:id"},
{"PUT", "/api/v1/templates/:id"},
{"DELETE", "/api/v1/templates/:id"},
{"GET", "/api/v1/applications"},
{"POST", "/api/v1/applications"},
{"GET", "/api/v1/applications/:id"},
{"PUT", "/api/v1/applications/:id"},
{"DELETE", "/api/v1/applications/:id"},
// {"POST", "/api/v1/applications/:id/submit"}, // [disabled] replaced by POST /tasks
}
routeMap := map[string]bool{}
@@ -90,12 +96,17 @@ func TestSmokeGetJobsEndpoint(t *testing.T) {
defer slurmSrv.Close()
client, _ := slurm.NewClientWithOpts(slurmSrv.URL, slurm.WithHTTPClient(slurmSrv.Client()))
templateStore := store.NewTemplateStore(newTestDB())
jobSvc := service.NewJobService(client, zap.NewNop())
appStore := store.NewApplicationStore(newTestDB())
appSvc := service.NewApplicationService(appStore, jobSvc, "", zap.NewNop())
appH := handler.NewApplicationHandler(appSvc, zap.NewNop())
router := server.NewRouter(
handler.NewJobHandler(service.NewJobService(client, zap.NewNop()), zap.NewNop()),
handler.NewJobHandler(jobSvc, zap.NewNop()),
handler.NewClusterHandler(service.NewClusterService(client, zap.NewNop()), zap.NewNop()),
handler.NewTemplateHandler(templateStore, zap.NewNop()),
appH,
nil, nil, nil,
nil,
nil,
)

cmd/server/task_test.go

@@ -0,0 +1,434 @@
package main
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"net/http/httptest"
"os"
"path/filepath"
"sync/atomic"
"testing"
"time"
"gcy_hpc_server/internal/handler"
"gcy_hpc_server/internal/model"
"gcy_hpc_server/internal/server"
"gcy_hpc_server/internal/service"
"gcy_hpc_server/internal/slurm"
"gcy_hpc_server/internal/store"
"go.uber.org/zap"
"gorm.io/driver/sqlite"
"gorm.io/gorm"
"gorm.io/gorm/logger"
)
func newTaskTestDB(t *testing.T) *gorm.DB {
t.Helper()
db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{Logger: logger.Default.LogMode(logger.Silent)})
if err != nil {
t.Fatalf("open sqlite: %v", err)
}
sqlDB, err := db.DB()
if err != nil {
t.Fatalf("db handle: %v", err)
}
sqlDB.SetMaxOpenConns(1) // a shared :memory: SQLite DB needs a single connection
if err := db.AutoMigrate(
&model.Application{},
&model.File{},
&model.FileBlob{},
&model.Task{},
); err != nil {
t.Fatalf("auto-migrate: %v", err)
}
t.Cleanup(func() {
sqlDB.Close()
})
return db
}
type mockSlurmHandler struct {
submitFn func(w http.ResponseWriter, r *http.Request)
}
func newMockSlurmServer(t *testing.T, h *mockSlurmHandler) *httptest.Server {
t.Helper()
mux := http.NewServeMux()
mux.HandleFunc("/slurm/v0.0.40/job/submit", func(w http.ResponseWriter, r *http.Request) {
if h != nil && h.submitFn != nil {
h.submitFn(w, r)
return
}
w.Header().Set("Content-Type", "application/json")
fmt.Fprint(w, `{"job_id": 42, "step_id": "0", "result": {"job_id": 42, "step_id": "0"}}`)
})
mux.HandleFunc("/slurm/v0.0.40/jobs", func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(map[string]interface{}{"jobs": []interface{}{}})
})
mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(map[string]interface{}{})
})
srv := httptest.NewServer(mux)
t.Cleanup(srv.Close)
return srv
}
func setupTaskTestServer(t *testing.T) (*httptest.Server, *store.TaskStore, func()) {
t.Helper()
return setupTaskTestServerWithSlurm(t, nil)
}
func setupTaskTestServerWithSlurm(t *testing.T, slurmHandler *mockSlurmHandler) (*httptest.Server, *store.TaskStore, func()) {
t.Helper()
db := newTaskTestDB(t)
slurmSrv := newMockSlurmServer(t, slurmHandler)
client, err := slurm.NewClientWithOpts(slurmSrv.URL, slurm.WithHTTPClient(slurmSrv.Client()))
if err != nil {
t.Fatalf("slurm client: %v", err)
}
log := zap.NewNop()
jobSvc := service.NewJobService(client, log)
appStore := store.NewApplicationStore(db)
taskStore := store.NewTaskStore(db)
tmpDir, err := os.MkdirTemp("", "task-test-workdir-*")
if err != nil {
t.Fatalf("temp dir: %v", err)
}
taskSvc := service.NewTaskService(
taskStore, appStore,
nil, nil, nil,
jobSvc,
tmpDir,
log,
)
ctx, cancel := context.WithCancel(context.Background())
taskSvc.StartProcessor(ctx)
appSvc := service.NewApplicationService(appStore, jobSvc, tmpDir, log, taskSvc)
appH := handler.NewApplicationHandler(appSvc, log)
taskH := handler.NewTaskHandler(taskSvc, log)
router := server.NewRouter(
handler.NewJobHandler(jobSvc, log),
handler.NewClusterHandler(service.NewClusterService(client, log), log),
appH,
nil, nil, nil,
taskH,
log,
)
httpSrv := httptest.NewServer(router)
cleanup := func() {
taskSvc.StopProcessor()
cancel()
httpSrv.Close()
os.RemoveAll(tmpDir)
}
return httpSrv, taskStore, cleanup
}
func createTestApp(t *testing.T, srvURL string) int64 {
t.Helper()
body := `{"name":"test-app","script_template":"#!/bin/bash\necho {{.np}}","parameters":[{"name":"np","type":"string","default":"1"}]}`
resp, err := http.Post(srvURL+"/api/v1/applications", "application/json", bytes.NewReader([]byte(body)))
if err != nil {
t.Fatalf("create app: %v", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusCreated {
b, _ := io.ReadAll(resp.Body)
t.Fatalf("create app: status %d, body %s", resp.StatusCode, b)
}
var result struct {
Data struct {
ID int64 `json:"id"`
} `json:"data"`
}
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
t.Fatalf("decode app response: %v", err)
}
return result.Data.ID
}
func postTask(t *testing.T, srvURL string, body string) (*http.Response, map[string]interface{}) {
t.Helper()
resp, err := http.Post(srvURL+"/api/v1/tasks", "application/json", bytes.NewReader([]byte(body)))
if err != nil {
t.Fatalf("post task: %v", err)
}
defer resp.Body.Close()
b, _ := io.ReadAll(resp.Body)
var result map[string]interface{}
json.Unmarshal(b, &result)
return resp, result
}
func getTasks(t *testing.T, srvURL string, query string) (int, []interface{}) {
t.Helper()
url := srvURL + "/api/v1/tasks"
if query != "" {
url += "?" + query
}
resp, err := http.Get(url)
if err != nil {
t.Fatalf("get tasks: %v", err)
}
defer resp.Body.Close()
b, _ := io.ReadAll(resp.Body)
var result map[string]interface{}
json.Unmarshal(b, &result)
data, _ := result["data"].(map[string]interface{})
items, _ := data["items"].([]interface{})
total, _ := data["total"].(float64)
return int(total), items
}
func TestTask_FullLifecycle(t *testing.T) {
srv, _, cleanup := setupTaskTestServer(t)
defer cleanup()
appID := createTestApp(t, srv.URL)
resp, result := postTask(t, srv.URL, fmt.Sprintf(
`{"app_id":%d,"task_name":"lifecycle-test","values":{"np":"4"}}`, appID,
))
if resp.StatusCode != http.StatusCreated {
t.Fatalf("expected 201, got %d: %v", resp.StatusCode, result)
}
data, _ := result["data"].(map[string]interface{})
if data["id"] == nil {
t.Fatal("expected id in response")
}
time.Sleep(300 * time.Millisecond) // give the async task processor time to pick up the task
total, items := getTasks(t, srv.URL, "")
if total < 1 {
t.Fatalf("expected at least 1 task, got %d", total)
}
task := items[0].(map[string]interface{})
if task["status"] == nil || task["status"] == "" {
t.Error("expected non-empty status")
}
if task["task_name"] != "lifecycle-test" {
t.Errorf("expected task_name=lifecycle-test, got %v", task["task_name"])
}
if task["app_id"] != float64(appID) {
t.Errorf("expected app_id=%d, got %v", appID, task["app_id"])
}
}
func TestTask_CreateWithMissingApp(t *testing.T) {
srv, _, cleanup := setupTaskTestServer(t)
defer cleanup()
resp, result := postTask(t, srv.URL, `{"app_id":9999,"task_name":"no-app"}`)
if resp.StatusCode != http.StatusNotFound {
t.Fatalf("expected 404, got %d: %v", resp.StatusCode, result)
}
}
func TestTask_CreateWithInvalidBody(t *testing.T) {
srv, _, cleanup := setupTaskTestServer(t)
defer cleanup()
resp, result := postTask(t, srv.URL, `{}`)
if resp.StatusCode != http.StatusBadRequest {
t.Fatalf("expected 400, got %d: %v", resp.StatusCode, result)
}
}
func TestTask_FileLimitExceeded(t *testing.T) {
srv, _, cleanup := setupTaskTestServer(t)
defer cleanup()
appID := createTestApp(t, srv.URL)
fileIDs := make([]int64, 101)
for i := range fileIDs {
fileIDs[i] = int64(i + 1)
}
idsJSON, _ := json.Marshal(fileIDs)
body := fmt.Sprintf(`{"app_id":%d,"file_ids":%s}`, appID, string(idsJSON))
resp, result := postTask(t, srv.URL, body)
if resp.StatusCode != http.StatusBadRequest {
t.Fatalf("expected 400 for file limit, got %d: %v", resp.StatusCode, result)
}
}
func TestTask_RetryScenario(t *testing.T) {
var failCount int32
slurmH := &mockSlurmHandler{
submitFn: func(w http.ResponseWriter, r *http.Request) {
if atomic.AddInt32(&failCount, 1) <= 2 {
w.WriteHeader(http.StatusInternalServerError)
fmt.Fprint(w, `{"errors":[{"error":"internal error"}]}`)
return
}
w.Header().Set("Content-Type", "application/json")
fmt.Fprint(w, `{"job_id": 99, "step_id": "0", "result": {"job_id": 99, "step_id": "0"}}`)
},
}
srv, taskStore, cleanup := setupTaskTestServerWithSlurm(t, slurmH)
defer cleanup()
appID := createTestApp(t, srv.URL)
resp, result := postTask(t, srv.URL, fmt.Sprintf(
`{"app_id":%d,"task_name":"retry-test","values":{"np":"2"}}`, appID,
))
if resp.StatusCode != http.StatusCreated {
t.Fatalf("expected 201, got %d: %v", resp.StatusCode, result)
}
taskID := int64(result["data"].(map[string]interface{})["id"].(float64))
deadline := time.Now().Add(5 * time.Second)
for time.Now().Before(deadline) {
task, _ := taskStore.GetByID(context.Background(), taskID)
if task != nil && task.Status == model.TaskStatusQueued {
if task.SlurmJobID != nil && *task.SlurmJobID == 99 {
return
}
}
time.Sleep(100 * time.Millisecond)
}
task, _ := taskStore.GetByID(context.Background(), taskID)
if task == nil {
t.Fatalf("task %d not found after deadline", taskID)
}
t.Fatalf("task did not reach queued with slurm_job_id=99; status=%s retry_count=%d slurm_job_id=%v",
task.Status, task.RetryCount, task.SlurmJobID)
}
// [disabled] The frontend has fully migrated to the POST /tasks endpoint; the old-API compatibility test is no longer needed.
/*
func TestTask_OldAPICompatibility(t *testing.T) {
srv, _, cleanup := setupTaskTestServer(t)
defer cleanup()
appID := createTestApp(t, srv.URL)
body := `{"values":{"np":"8"}}`
url := fmt.Sprintf("%s/api/v1/applications/%d/submit", srv.URL, appID)
resp, err := http.Post(url, "application/json", bytes.NewReader([]byte(body)))
if err != nil {
t.Fatalf("post submit: %v", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusCreated {
b, _ := io.ReadAll(resp.Body)
t.Fatalf("expected 201, got %d: %s", resp.StatusCode, b)
}
var result map[string]interface{}
json.NewDecoder(resp.Body).Decode(&result)
data, _ := result["data"].(map[string]interface{})
if data == nil {
t.Fatal("expected data in response")
}
if data["job_id"] == nil {
t.Errorf("expected job_id in old API response, got: %v", data)
}
jobID, ok := data["job_id"].(float64)
if !ok || jobID == 0 {
t.Errorf("expected non-zero job_id, got %v", data["job_id"])
}
}
*/
func TestTask_ListWithFilters(t *testing.T) {
srv, taskStore, cleanup := setupTaskTestServer(t)
defer cleanup()
appID := createTestApp(t, srv.URL)
now := time.Now()
taskStore.Create(context.Background(), &model.Task{
TaskName: "task-completed", AppID: appID, AppName: "test-app",
Status: model.TaskStatusCompleted, SubmittedAt: now,
})
taskStore.Create(context.Background(), &model.Task{
TaskName: "task-failed", AppID: appID, AppName: "test-app",
Status: model.TaskStatusFailed, ErrorMessage: "boom", SubmittedAt: now,
})
taskStore.Create(context.Background(), &model.Task{
TaskName: "task-queued", AppID: appID, AppName: "test-app",
Status: model.TaskStatusQueued, SubmittedAt: now,
})
total, items := getTasks(t, srv.URL, "status=completed")
if total != 1 {
t.Fatalf("expected 1 completed, got %d", total)
}
task := items[0].(map[string]interface{})
if task["status"] != "completed" {
t.Errorf("expected status=completed, got %v", task["status"])
}
total, items = getTasks(t, srv.URL, "status=failed")
if total != 1 {
t.Fatalf("expected 1 failed, got %d", total)
}
task = items[0].(map[string]interface{})
if task["status"] != "failed" {
t.Errorf("expected status=failed, got %v", task["status"])
}
total, _ = getTasks(t, srv.URL, "")
if total != 3 {
t.Fatalf("expected 3 total, got %d", total)
}
}
func TestTask_WorkDirCreated(t *testing.T) {
srv, taskStore, cleanup := setupTaskTestServer(t)
defer cleanup()
appID := createTestApp(t, srv.URL)
resp, result := postTask(t, srv.URL, fmt.Sprintf(
`{"app_id":%d,"task_name":"workdir-test","values":{"np":"1"}}`, appID,
))
if resp.StatusCode != http.StatusCreated {
t.Fatalf("expected 201, got %d: %v", resp.StatusCode, result)
}
taskID := int64(result["data"].(map[string]interface{})["id"].(float64))
deadline := time.Now().Add(5 * time.Second)
for time.Now().Before(deadline) {
task, _ := taskStore.GetByID(context.Background(), taskID)
if task != nil && task.WorkDir != "" {
if _, err := os.Stat(task.WorkDir); os.IsNotExist(err) {
t.Fatalf("work dir %s not created", task.WorkDir)
}
if !filepath.IsAbs(task.WorkDir) {
t.Errorf("expected absolute work dir, got %s", task.WorkDir)
}
return
}
time.Sleep(50 * time.Millisecond)
}
task, _ := taskStore.GetByID(context.Background(), taskID)
if task == nil {
t.Fatalf("task %d not found after deadline", taskID)
}
t.Fatalf("work dir never set; status=%s workdir=%s", task.Status, task.WorkDir)
}

go.mod

@@ -19,31 +19,43 @@ require (
github.com/bytedance/sonic/loader v0.5.0 // indirect
github.com/cloudwego/base64x v0.1.6 // indirect
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect
github.com/dustin/go-humanize v1.0.1 // indirect
github.com/gabriel-vasile/mimetype v1.4.12 // indirect
github.com/gin-contrib/sse v1.1.0 // indirect
github.com/go-ini/ini v1.67.0 // indirect
github.com/go-playground/locales v0.14.1 // indirect
github.com/go-playground/universal-translator v0.18.1 // indirect
github.com/go-playground/validator/v10 v10.30.1 // indirect
github.com/go-sql-driver/mysql v1.9.3 // indirect
github.com/goccy/go-json v0.10.5 // indirect
github.com/goccy/go-yaml v1.19.2 // indirect
github.com/google/uuid v1.6.0 // indirect
github.com/jinzhu/inflection v1.0.0 // indirect
github.com/jinzhu/now v1.1.5 // indirect
github.com/json-iterator/go v1.1.12 // indirect
github.com/klauspost/compress v1.18.2 // indirect
github.com/klauspost/cpuid/v2 v2.3.0 // indirect
github.com/klauspost/crc32 v1.3.0 // indirect
github.com/leodido/go-urn v1.4.0 // indirect
github.com/mattn/go-isatty v0.0.20 // indirect
github.com/mattn/go-sqlite3 v1.14.22 // indirect
github.com/minio/crc64nvme v1.1.1 // indirect
github.com/minio/md5-simd v1.1.2 // indirect
github.com/minio/minio-go/v7 v7.0.100 // indirect
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
github.com/modern-go/reflect2 v1.0.2 // indirect
github.com/pelletier/go-toml/v2 v2.2.4 // indirect
github.com/philhofer/fwd v1.2.0 // indirect
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect
github.com/quic-go/qpack v0.6.0 // indirect
github.com/quic-go/quic-go v0.59.0 // indirect
github.com/rs/xid v1.6.0 // indirect
github.com/tinylib/msgp v1.6.1 // indirect
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
github.com/ugorji/go/codec v1.3.1 // indirect
go.mongodb.org/mongo-driver/v2 v2.5.0 // indirect
go.uber.org/multierr v1.10.0 // indirect
go.yaml.in/yaml/v3 v3.0.4 // indirect
golang.org/x/arch v0.22.0 // indirect
golang.org/x/crypto v0.48.0 // indirect
golang.org/x/net v0.51.0 // indirect

go.sum

@@ -12,12 +12,16 @@ github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSs
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM=
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY=
github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto=
github.com/gabriel-vasile/mimetype v1.4.12 h1:e9hWvmLYvtp846tLHam2o++qitpguFiYCKbn0w9jyqw=
github.com/gabriel-vasile/mimetype v1.4.12/go.mod h1:d+9Oxyo1wTzWdyVUPMmXFvp4F9tea18J8ufA774AB3s=
github.com/gin-contrib/sse v1.1.0 h1:n0w2GMuUpWDVp7qSpvze6fAu9iRxJY4Hmj6AmBOU05w=
github.com/gin-contrib/sse v1.1.0/go.mod h1:hxRZ5gVpWMT7Z0B0gSNYqqsSCNIJMjzvm6fqCz9vjwM=
github.com/gin-gonic/gin v1.12.0 h1:b3YAbrZtnf8N//yjKeU2+MQsh2mY5htkZidOM7O0wG8=
github.com/gin-gonic/gin v1.12.0/go.mod h1:VxccKfsSllpKshkBWgVgRniFFAzFb9csfngsqANjnLc=
github.com/go-ini/ini v1.67.0 h1:z6ZrTEZqSWOTyH2FlglNbNgARyHG8oLW9gMELqKr06A=
github.com/go-ini/ini v1.67.0/go.mod h1:ByCAeIL28uOIIG0E3PJtZPDL8WnHpFKFOtgjp+3Ies8=
github.com/go-playground/assert/v2 v2.2.0 h1:JvknZsQTYeFEAhQwI4qEt9cyV5ONwRHC+lYKSsYSR8s=
github.com/go-playground/assert/v2 v2.2.0/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4=
github.com/go-playground/locales v0.14.1 h1:EWaQ/wswjilfKLTECiXz7Rh+3BjFhfDFKv/oXslEjJA=
@@ -35,14 +39,21 @@ github.com/goccy/go-yaml v1.19.2/go.mod h1:XBurs7gK8ATbW4ZPGKgcbrY1Br56PdM69F7Lk
github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU=
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E=
github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc=
github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ=
github.com/jinzhu/now v1.1.5/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8=
github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM=
github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo=
github.com/klauspost/compress v1.18.2 h1:iiPHWW0YrcFgpBYhsA6D1+fqHssJscY/Tm/y2Uqnapk=
github.com/klauspost/compress v1.18.2/go.mod h1:R0h/fSBs8DE4ENlcrlib3PsXS61voFxhIs2DeRhCvJ4=
github.com/klauspost/cpuid/v2 v2.0.1/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg=
github.com/klauspost/cpuid/v2 v2.3.0 h1:S4CRMLnYUhGeDFDqkGriYKdfoFlDnMtqTiI/sFzhA9Y=
github.com/klauspost/cpuid/v2 v2.3.0/go.mod h1:hqwkgyIinND0mEev00jJYCxPNVRVXFQeu1XKlok6oO0=
github.com/klauspost/crc32 v1.3.0 h1:sSmTt3gUt81RP655XGZPElI0PelVTZ6YwCRnPSupoFM=
github.com/klauspost/crc32 v1.3.0/go.mod h1:D7kQaZhnkX/Y0tstFGf8VUzv2UofNGqCjnC3zdHB0Hw=
github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE=
github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk=
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
@@ -53,6 +64,12 @@ github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWE
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
github.com/mattn/go-sqlite3 v1.14.22 h1:2gZY6PC6kBnID23Tichd1K+Z0oS6nE/XwU+Vz/5o4kU=
github.com/mattn/go-sqlite3 v1.14.22/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y=
github.com/minio/crc64nvme v1.1.1 h1:8dwx/Pz49suywbO+auHCBpCtlW1OfpcLN7wYgVR6wAI=
github.com/minio/crc64nvme v1.1.1/go.mod h1:eVfm2fAzLlxMdUGc0EEBGSMmPwmXD5XiNRpnu9J3bvg=
github.com/minio/md5-simd v1.1.2 h1:Gdi1DZK69+ZVMoNHRXJyNcxrMA4dSxoYHZSQbirFg34=
github.com/minio/md5-simd v1.1.2/go.mod h1:MzdKDxYpY2BT9XQFocsiZf/NKVtR7nkE4RoEpN+20RM=
github.com/minio/minio-go/v7 v7.0.100 h1:ShkWi8Tyj9RtU57OQB2HIXKz4bFgtVib0bbT1sbtLI8=
github.com/minio/minio-go/v7 v7.0.100/go.mod h1:EtGNKtlX20iL2yaYnxEigaIvj0G0GwSDnifnG8ClIdw=
github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w8PVh93nsPXa1VrQ6jlwL5oN8l14QlcNfg=
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
@@ -60,6 +77,8 @@ github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9G
github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk=
github.com/pelletier/go-toml/v2 v2.2.4 h1:mye9XuhQ6gvn5h28+VilKrrPoQVanw5PMw/TB0t5Ec4=
github.com/pelletier/go-toml/v2 v2.2.4/go.mod h1:2gIqNv+qfxSVS7cM2xJQKtLSTLUE9V8t9Stt+h56mCY=
github.com/philhofer/fwd v1.2.0 h1:e6DnBTl7vGY+Gz322/ASL4Gyp1FspeMvx1RNDoToZuM=
github.com/philhofer/fwd v1.2.0/go.mod h1:RqIHx9QI14HlwKwm98g9Re5prTQ6LdeRQn+gXJFxsJM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRIccs7FGNTlIRMkT8wgtp5eCXdBlqhYGL6U=
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
@@ -69,6 +88,8 @@ github.com/quic-go/quic-go v0.59.0 h1:OLJkp1Mlm/aS7dpKgTc6cnpynnD2Xg7C1pwL6vy/SA
github.com/quic-go/quic-go v0.59.0/go.mod h1:upnsH4Ju1YkqpLXC305eW3yDZ4NfnNbmQRCMWS58IKU=
github.com/rogpeppe/go-internal v1.10.0 h1:TMyTOH3F/DB16zRVcYyreMH6GnZZrwQVAoYjRBZyWFQ=
github.com/rogpeppe/go-internal v1.10.0/go.mod h1:UQnix2H7Ngw/k4C5ijL5+65zddjncjaFoBhdsK/akog=
github.com/rs/xid v1.6.0 h1:fV591PaemRlL6JfRxGDEPl69wICngIQ3shQtzfy2gxU=
github.com/rs/xid v1.6.0/go.mod h1:7XoLgs4eV+QndskICGsho+ADou8ySMSjJKDIan90Nz0=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
@@ -80,6 +101,8 @@ github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXl
github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
github.com/tinylib/msgp v1.6.1 h1:ESRv8eL3u+DNHUoSAAQRE50Hm162zqAnBoGv9PzScPY=
github.com/tinylib/msgp v1.6.1/go.mod h1:RSp0LW9oSxFut3KzESt5Voq4GVWyS+PSulT77roAqEA=
github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS4MhqMhdFk5YI=
github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08=
github.com/ugorji/go/codec v1.3.1 h1:waO7eEiFDwidsBN6agj1vJQ4AG7lh2yqXyOXqhgQuyY=
@@ -94,6 +117,8 @@ go.uber.org/multierr v1.10.0 h1:S0h4aNzvfcFsC3dRF1jLoaov7oRaKqRGC/pUEJ2yvPQ=
go.uber.org/multierr v1.10.0/go.mod h1:20+QtiLqy0Nd6FdQB9TLXag12DsQkrbs3htMFfDN80Y=
go.uber.org/zap v1.27.1 h1:08RqriUEv8+ArZRYSTXy1LeBScaMpVSTBhCeaZYfMYc=
go.uber.org/zap v1.27.1/go.mod h1:GB2qFLM7cTU87MWRP2mPIjqfIDnGu+VIO4V/SdhGo2E=
go.yaml.in/yaml/v3 v3.0.4 h1:tfq32ie2Jv2UxXFdLJdh3jXuOzWiL1fo0bu/FbuKpbc=
go.yaml.in/yaml/v3 v3.0.4/go.mod h1:DhzuOOF2ATzADvBadXxruRBLzYTpT36CKvDb3+aBEFg=
golang.org/x/arch v0.22.0 h1:c/Zle32i5ttqRXjdLyyHZESLD/bB90DCU1g9l/0YBDI=
golang.org/x/arch v0.22.0/go.mod h1:dNHoOeKiyja7GTvF9NJS1l3Z2yntpQNzgrjh1cU103A=
golang.org/x/crypto v0.48.0 h1:/VRzVqiRSggnhY7gNRxPauEQ5Drw9haKdM0jqfcCFts=

hpc_server_openapi.json

File diff suppressed because it is too large


@@ -15,18 +15,21 @@ import (
"gcy_hpc_server/internal/server"
"gcy_hpc_server/internal/service"
"gcy_hpc_server/internal/slurm"
"gcy_hpc_server/internal/storage"
"gcy_hpc_server/internal/store"
"go.uber.org/zap"
"gorm.io/gorm"
)
// App encapsulates the entire application lifecycle.
type App struct {
cfg *config.Config
logger *zap.Logger
db *gorm.DB
server *http.Server
cancelCleanup context.CancelFunc
taskSvc *service.TaskService
taskPoller *TaskPoller
}
// NewApp initializes all application dependencies: DB, Slurm client, services, handlers, router.
@@ -42,13 +45,16 @@ func NewApp(cfg *config.Config, logger *zap.Logger) (*App, error) {
return nil, err
}
srv := initHTTPServer(cfg, gormDB, slurmClient, logger)
srv, cancelCleanup, taskSvc, taskPoller := initHTTPServer(cfg, gormDB, slurmClient, logger)
return &App{
cfg: cfg,
logger: logger,
db: gormDB,
server: srv,
cancelCleanup: cancelCleanup,
taskSvc: taskSvc,
taskPoller: taskPoller,
}, nil
}
@@ -84,6 +90,18 @@ func (a *App) Run() error {
func (a *App) Close() error {
var errs []error
if a.taskSvc != nil {
a.taskSvc.StopProcessor()
}
if a.taskPoller != nil {
a.taskPoller.Stop()
}
if a.cancelCleanup != nil {
a.cancelCleanup()
}
if a.server != nil {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
@@ -139,21 +157,71 @@ func initSlurmClient(cfg *config.Config) (*slurm.Client, error) {
return client, nil
}
func initHTTPServer(cfg *config.Config, db *gorm.DB, slurmClient *slurm.Client, logger *zap.Logger) *http.Server {
func initHTTPServer(cfg *config.Config, db *gorm.DB, slurmClient *slurm.Client, logger *zap.Logger) (*http.Server, context.CancelFunc, *service.TaskService, *TaskPoller) {
ctx := context.Background()
jobSvc := service.NewJobService(slurmClient, logger)
clusterSvc := service.NewClusterService(slurmClient, logger)
templateStore := store.NewTemplateStore(db)
jobH := handler.NewJobHandler(jobSvc, logger)
clusterH := handler.NewClusterHandler(clusterSvc, logger)
templateH := handler.NewTemplateHandler(templateStore, logger)
router := server.NewRouter(jobH, clusterH, templateH, logger)
appStore := store.NewApplicationStore(db)
// File storage initialization
minioClient, err := storage.NewMinioClient(cfg.Minio)
if err != nil {
logger.Warn("failed to initialize MinIO client, file storage disabled", zap.Error(err))
}
var uploadH *handler.UploadHandler
var fileH *handler.FileHandler
var folderH *handler.FolderHandler
taskStore := store.NewTaskStore(db)
fileStore := store.NewFileStore(db)
blobStore := store.NewBlobStore(db)
var stagingSvc *service.FileStagingService
if minioClient != nil {
folderStore := store.NewFolderStore(db)
uploadStore := store.NewUploadStore(db)
uploadSvc := service.NewUploadService(minioClient, blobStore, fileStore, uploadStore, cfg.Minio, db, logger)
folderSvc := service.NewFolderService(folderStore, fileStore, logger)
fileSvc := service.NewFileService(minioClient, blobStore, fileStore, cfg.Minio.Bucket, db, logger)
uploadH = handler.NewUploadHandler(uploadSvc, logger)
fileH = handler.NewFileHandler(fileSvc, logger)
folderH = handler.NewFolderHandler(folderSvc, logger)
stagingSvc = service.NewFileStagingService(fileStore, blobStore, minioClient, cfg.Minio.Bucket, logger)
}
taskSvc := service.NewTaskService(taskStore, appStore, fileStore, blobStore, stagingSvc, jobSvc, cfg.WorkDirBase, logger)
taskSvc.StartProcessor(ctx)
appSvc := service.NewApplicationService(appStore, jobSvc, cfg.WorkDirBase, logger, taskSvc)
appH := handler.NewApplicationHandler(appSvc, logger)
poller := NewTaskPoller(taskSvc, 10*time.Second, logger)
poller.Start(ctx)
taskH := handler.NewTaskHandler(taskSvc, logger)
var cancelCleanup context.CancelFunc
if minioClient != nil {
cleanupCtx, cancel := context.WithCancel(context.Background())
cancelCleanup = cancel
go startCleanupWorker(cleanupCtx, store.NewUploadStore(db), minioClient, cfg.Minio.Bucket, logger)
}
router := server.NewRouter(jobH, clusterH, appH, uploadH, fileH, folderH, taskH, logger)
addr := ":" + cfg.ServerPort
return &http.Server{
Addr: addr,
Handler: router,
}
}, cancelCleanup, taskSvc, poller
}

internal/app/cleanup.go Normal file

@@ -0,0 +1,83 @@
package app
import (
"context"
"time"
"gcy_hpc_server/internal/storage"
"gcy_hpc_server/internal/store"
"go.uber.org/zap"
)
// startCleanupWorker runs a background goroutine that periodically cleans up:
// 1. Expired upload sessions (mark → delete MinIO chunks → delete DB records)
// 2. Leaked multipart uploads from failed ComposeObject calls
func startCleanupWorker(ctx context.Context, uploadStore *store.UploadStore, objStorage storage.ObjectStorage, bucket string, logger *zap.Logger) {
ticker := time.NewTicker(1 * time.Hour)
defer ticker.Stop()
cleanupExpiredSessions(ctx, uploadStore, objStorage, bucket, logger)
cleanupLeakedMultipartUploads(ctx, objStorage, bucket, logger)
for {
select {
case <-ctx.Done():
logger.Info("cleanup worker stopped")
return
case <-ticker.C:
cleanupExpiredSessions(ctx, uploadStore, objStorage, bucket, logger)
cleanupLeakedMultipartUploads(ctx, objStorage, bucket, logger)
}
}
}
// cleanupExpiredSessions performs three-phase cleanup of expired upload sessions:
// Phase 1: Find and mark expired sessions
// Phase 2: Delete MinIO temp chunks for each session
// Phase 3: Delete DB records (session + chunks)
func cleanupExpiredSessions(ctx context.Context, uploadStore *store.UploadStore, objStorage storage.ObjectStorage, bucket string, logger *zap.Logger) {
sessions, err := uploadStore.ListExpiredSessions(ctx)
if err != nil {
logger.Error("failed to list expired sessions", zap.Error(err))
return
}
if len(sessions) == 0 {
return
}
logger.Info("cleaning up expired sessions", zap.Int("count", len(sessions)))
for i := range sessions {
session := &sessions[i]
if err := uploadStore.UpdateSessionStatus(ctx, session.ID, "expired"); err != nil {
logger.Error("failed to mark session expired", zap.Int64("session_id", session.ID), zap.Error(err))
continue
}
objects, err := objStorage.ListObjects(ctx, bucket, session.MinioPrefix, true)
if err != nil {
logger.Error("failed to list session objects", zap.Int64("session_id", session.ID), zap.Error(err))
} else if len(objects) > 0 {
keys := make([]string, len(objects))
for j, obj := range objects {
keys[j] = obj.Key
}
if err := objStorage.RemoveObjects(ctx, bucket, keys, storage.RemoveObjectsOptions{}); err != nil {
logger.Error("failed to remove session objects", zap.Int64("session_id", session.ID), zap.Error(err))
}
}
if err := uploadStore.DeleteSession(ctx, session.ID); err != nil {
logger.Error("failed to delete session", zap.Int64("session_id", session.ID), zap.Error(err))
}
}
}
func cleanupLeakedMultipartUploads(ctx context.Context, objStorage storage.ObjectStorage, bucket string, logger *zap.Logger) {
if err := objStorage.RemoveIncompleteUpload(ctx, bucket, "uploads/"); err != nil {
logger.Error("failed to cleanup leaked multipart uploads", zap.Error(err))
}
}


@@ -0,0 +1,266 @@
package app
import (
"context"
"io"
"testing"
"time"
"gcy_hpc_server/internal/model"
"gcy_hpc_server/internal/storage"
"gcy_hpc_server/internal/store"
"go.uber.org/zap"
"gorm.io/driver/sqlite"
"gorm.io/gorm"
)
// mockCleanupStorage implements ObjectStorage for cleanup tests.
type mockCleanupStorage struct {
listObjectsFn func(ctx context.Context, bucket, prefix string, recursive bool) ([]storage.ObjectInfo, error)
removeObjectsFn func(ctx context.Context, bucket string, keys []string, opts storage.RemoveObjectsOptions) error
removeIncompleteFn func(ctx context.Context, bucket, object string) error
}
func (m *mockCleanupStorage) PutObject(_ context.Context, _ string, _ string, _ io.Reader, _ int64, _ storage.PutObjectOptions) (storage.UploadInfo, error) {
return storage.UploadInfo{}, nil
}
func (m *mockCleanupStorage) GetObject(_ context.Context, _ string, _ string, _ storage.GetOptions) (io.ReadCloser, storage.ObjectInfo, error) {
return nil, storage.ObjectInfo{}, nil
}
func (m *mockCleanupStorage) ComposeObject(_ context.Context, _ string, _ string, _ []string) (storage.UploadInfo, error) {
return storage.UploadInfo{}, nil
}
func (m *mockCleanupStorage) AbortMultipartUpload(_ context.Context, _ string, _ string, _ string) error {
return nil
}
func (m *mockCleanupStorage) RemoveIncompleteUpload(ctx context.Context, bucket, object string) error {
if m.removeIncompleteFn != nil {
// Forward the real arguments so tests can observe the bucket and prefix.
return m.removeIncompleteFn(ctx, bucket, object)
}
return nil
}
func (m *mockCleanupStorage) RemoveObject(_ context.Context, _ string, _ string, _ storage.RemoveObjectOptions) error {
return nil
}
func (m *mockCleanupStorage) ListObjects(ctx context.Context, bucket string, prefix string, recursive bool) ([]storage.ObjectInfo, error) {
if m.listObjectsFn != nil {
return m.listObjectsFn(ctx, bucket, prefix, recursive)
}
return nil, nil
}
func (m *mockCleanupStorage) RemoveObjects(ctx context.Context, bucket string, keys []string, opts storage.RemoveObjectsOptions) error {
if m.removeObjectsFn != nil {
return m.removeObjectsFn(ctx, bucket, keys, opts)
}
return nil
}
func (m *mockCleanupStorage) BucketExists(_ context.Context, _ string) (bool, error) {
return true, nil
}
func (m *mockCleanupStorage) MakeBucket(_ context.Context, _ string, _ storage.MakeBucketOptions) error {
return nil
}
func (m *mockCleanupStorage) StatObject(_ context.Context, _ string, _ string, _ storage.StatObjectOptions) (storage.ObjectInfo, error) {
return storage.ObjectInfo{}, nil
}
func setupCleanupTestDB(t *testing.T) (*gorm.DB, *store.UploadStore) {
t.Helper()
db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{})
if err != nil {
t.Fatalf("open sqlite: %v", err)
}
if err := db.AutoMigrate(&model.UploadSession{}, &model.UploadChunk{}); err != nil {
t.Fatalf("migrate: %v", err)
}
return db, store.NewUploadStore(db)
}
func TestCleanupExpiredSessions(t *testing.T) {
_, uploadStore := setupCleanupTestDB(t)
ctx := context.Background()
past := time.Now().Add(-1 * time.Hour)
err := uploadStore.CreateSession(ctx, &model.UploadSession{
FileName: "expired.bin",
FileSize: 1024,
ChunkSize: 16 << 20,
TotalChunks: 1,
SHA256: "expired_hash",
Status: "pending",
MinioPrefix: "uploads/99/",
ExpiresAt: past,
})
if err != nil {
t.Fatalf("create session: %v", err)
}
var listedPrefix string
var removedKeys []string
mockStore := &mockCleanupStorage{
listObjectsFn: func(_ context.Context, _, prefix string, _ bool) ([]storage.ObjectInfo, error) {
listedPrefix = prefix
return []storage.ObjectInfo{{Key: "uploads/99/chunk_00000", Size: 100}}, nil
},
removeObjectsFn: func(_ context.Context, _ string, keys []string, _ storage.RemoveObjectsOptions) error {
removedKeys = keys
return nil
},
}
cleanupExpiredSessions(ctx, uploadStore, mockStore, "test-bucket", zap.NewNop())
if listedPrefix != "uploads/99/" {
t.Errorf("listed prefix = %q, want %q", listedPrefix, "uploads/99/")
}
if len(removedKeys) != 1 || removedKeys[0] != "uploads/99/chunk_00000" {
t.Errorf("removed keys = %v, want [uploads/99/chunk_00000]", removedKeys)
}
session, err := uploadStore.GetSession(ctx, 1)
if err != nil {
t.Fatalf("get session: %v", err)
}
if session != nil {
t.Error("session should be deleted after cleanup")
}
}
func TestCleanupExpiredSessions_Empty(t *testing.T) {
_, uploadStore := setupCleanupTestDB(t)
ctx := context.Background()
called := false
mockStore := &mockCleanupStorage{
listObjectsFn: func(_ context.Context, _, _ string, _ bool) ([]storage.ObjectInfo, error) {
called = true
return nil, nil
},
}
cleanupExpiredSessions(ctx, uploadStore, mockStore, "test-bucket", zap.NewNop())
if called {
t.Error("ListObjects should not be called when no expired sessions exist")
}
}
func TestCleanupExpiredSessions_CompletedNotCleaned(t *testing.T) {
_, uploadStore := setupCleanupTestDB(t)
ctx := context.Background()
past := time.Now().Add(-1 * time.Hour)
err := uploadStore.CreateSession(ctx, &model.UploadSession{
FileName: "completed.bin",
FileSize: 1024,
ChunkSize: 16 << 20,
TotalChunks: 1,
SHA256: "completed_hash",
Status: "completed",
MinioPrefix: "uploads/100/",
ExpiresAt: past,
})
if err != nil {
t.Fatalf("create session: %v", err)
}
called := false
mockStore := &mockCleanupStorage{
listObjectsFn: func(_ context.Context, _, _ string, _ bool) ([]storage.ObjectInfo, error) {
called = true
return nil, nil
},
}
cleanupExpiredSessions(ctx, uploadStore, mockStore, "test-bucket", zap.NewNop())
if called {
t.Error("ListObjects should not be called for completed sessions")
}
session, _ := uploadStore.GetSession(ctx, 1)
if session == nil {
t.Error("completed session should not be deleted")
}
}
func TestCleanupWorker_StopsOnContextCancel(t *testing.T) {
_, uploadStore := setupCleanupTestDB(t)
ctx, cancel := context.WithCancel(context.Background())
mockStore := &mockCleanupStorage{}
done := make(chan struct{})
go func() {
startCleanupWorker(ctx, uploadStore, mockStore, "test-bucket", zap.NewNop())
close(done)
}()
time.Sleep(100 * time.Millisecond)
cancel()
select {
case <-done:
case <-time.After(2 * time.Second):
t.Error("worker did not stop after context cancel")
}
}
func TestCleanupExpiredSessions_NoObjects(t *testing.T) {
_, uploadStore := setupCleanupTestDB(t)
ctx := context.Background()
past := time.Now().Add(-1 * time.Hour)
err := uploadStore.CreateSession(ctx, &model.UploadSession{
FileName: "empty.bin",
FileSize: 1024,
ChunkSize: 16 << 20,
TotalChunks: 1,
SHA256: "empty_hash",
Status: "pending",
MinioPrefix: "uploads/200/",
ExpiresAt: past,
})
if err != nil {
t.Fatalf("create session: %v", err)
}
listCalled := false
removeCalled := false
mockStore := &mockCleanupStorage{
listObjectsFn: func(_ context.Context, _, _ string, _ bool) ([]storage.ObjectInfo, error) {
listCalled = true
return nil, nil
},
removeObjectsFn: func(_ context.Context, _ string, _ []string, _ storage.RemoveObjectsOptions) error {
removeCalled = true
return nil
},
}
cleanupExpiredSessions(ctx, uploadStore, mockStore, "test-bucket", zap.NewNop())
if !listCalled {
t.Error("ListObjects should be called")
}
if removeCalled {
t.Error("RemoveObjects should not be called when no objects found")
}
session, _ := uploadStore.GetSession(ctx, 1)
if session != nil {
t.Error("session should be deleted even when no objects found")
}
}


@@ -0,0 +1,61 @@
package app
import (
"context"
"sync"
"time"
"go.uber.org/zap"
)
// TaskPollable defines the interface for refreshing stale task statuses.
type TaskPollable interface {
RefreshStaleTasks(ctx context.Context) error
}
// TaskPoller periodically polls Slurm for task status updates via TaskPollable.
type TaskPoller struct {
taskSvc TaskPollable
interval time.Duration
cancel context.CancelFunc
wg sync.WaitGroup
logger *zap.Logger
}
// NewTaskPoller creates a new TaskPoller with the given service, interval, and logger.
func NewTaskPoller(taskSvc TaskPollable, interval time.Duration, logger *zap.Logger) *TaskPoller {
return &TaskPoller{
taskSvc: taskSvc,
interval: interval,
logger: logger,
}
}
// Start launches the background goroutine that periodically refreshes stale tasks.
func (p *TaskPoller) Start(ctx context.Context) {
ctx, p.cancel = context.WithCancel(ctx)
p.wg.Add(1)
go func() {
defer p.wg.Done()
ticker := time.NewTicker(p.interval)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
return
case <-ticker.C:
if err := p.taskSvc.RefreshStaleTasks(ctx); err != nil {
p.logger.Error("failed to refresh stale tasks", zap.Error(err))
}
}
}
}()
}
// Stop cancels the background goroutine and waits for it to finish.
func (p *TaskPoller) Stop() {
if p.cancel != nil {
p.cancel()
}
p.wg.Wait()
}


@@ -0,0 +1,70 @@
package app
import (
"context"
"sync"
"testing"
"time"
"go.uber.org/zap"
)
type mockTaskPollable struct {
refreshFunc func(ctx context.Context) error
callCount int
mu sync.Mutex
}
func (m *mockTaskPollable) RefreshStaleTasks(ctx context.Context) error {
m.mu.Lock()
defer m.mu.Unlock()
m.callCount++
if m.refreshFunc != nil {
return m.refreshFunc(ctx)
}
return nil
}
func (m *mockTaskPollable) getCallCount() int {
m.mu.Lock()
defer m.mu.Unlock()
return m.callCount
}
func TestTaskPoller_StartStop(t *testing.T) {
mock := &mockTaskPollable{}
logger := zap.NewNop()
poller := NewTaskPoller(mock, 1*time.Second, logger)
poller.Start(context.Background())
time.Sleep(100 * time.Millisecond)
poller.Stop()
// No goroutine leak — Stop() returned means wg.Wait() completed.
}
func TestTaskPoller_RefreshesStaleTasks(t *testing.T) {
mock := &mockTaskPollable{}
logger := zap.NewNop()
poller := NewTaskPoller(mock, 50*time.Millisecond, logger)
poller.Start(context.Background())
defer poller.Stop()
time.Sleep(300 * time.Millisecond)
if count := mock.getCallCount(); count < 1 {
t.Errorf("expected RefreshStaleTasks to be called at least once, got %d", count)
}
}
func TestTaskPoller_StopsCleanly(t *testing.T) {
mock := &mockTaskPollable{}
logger := zap.NewNop()
poller := NewTaskPoller(mock, 1*time.Second, logger)
poller.Start(context.Background())
poller.Stop()
// No panic and WaitGroup is done — Stop returned successfully.
}


@@ -3,6 +3,7 @@ slurm_api_url: "http://localhost:6820"
slurm_user_name: "root"
slurm_jwt_key_path: "/etc/slurm/jwt_hs256.key"
mysql_dsn: "root:@tcp(127.0.0.1:3306)/hpc_platform?parseTime=true"
work_dir_base: "/mnt/nfs_mount/platform" # root path for job work dirs; leave empty to disable auto-creation
log:
level: "info" # debug, info, warn, error
@@ -14,3 +15,14 @@ log:
max_age: 30 # days to retain old log files
compress: true # gzip rotated log files
gorm_level: "warn" # GORM SQL log level: silent, error, warn, info
minio:
endpoint: "http://fnos.dailz.cn:15001" # MinIO server address
access_key: "3dgDu9ncwflLoRQW2OeP" # access key
secret_key: "g2GLBNTPxJ9sdFwh37jtfilRSacEO5yQepMkDrnV" # secret key
bucket: "test" # bucket name
use_ssl: false # use TLS connection
chunk_size: 16777216 # upload chunk size in bytes (default: 16MB)
max_file_size: 53687091200 # max file size in bytes (default: 50GB)
min_chunk_size: 5242880 # minimum chunk size in bytes (default: 5MB)
session_ttl: 48 # session TTL in hours (default: 48)


@@ -20,6 +20,19 @@ type LogConfig struct {
GormLevel string `yaml:"gorm_level"` // GORM SQL log level (default: warn)
}
// MinioConfig holds MinIO object storage configuration values.
type MinioConfig struct {
Endpoint string `yaml:"endpoint"` // MinIO server address
AccessKey string `yaml:"access_key"` // access key
SecretKey string `yaml:"secret_key"` // secret key
Bucket string `yaml:"bucket"` // bucket name
UseSSL bool `yaml:"use_ssl"` // use TLS connection
ChunkSize int64 `yaml:"chunk_size"` // upload chunk size in bytes (default: 16MB)
MaxFileSize int64 `yaml:"max_file_size"` // max file size in bytes (default: 50GB)
MinChunkSize int64 `yaml:"min_chunk_size"` // minimum chunk size in bytes (default: 5MB)
SessionTTL int `yaml:"session_ttl"` // session TTL in hours (default: 48)
}
// Config holds all application configuration values.
type Config struct {
ServerPort string `yaml:"server_port"`
@@ -27,7 +40,9 @@ type Config struct {
SlurmUserName string `yaml:"slurm_user_name"`
SlurmJWTKeyPath string `yaml:"slurm_jwt_key_path"`
MySQLDSN string `yaml:"mysql_dsn"`
WorkDirBase string `yaml:"work_dir_base"` // base directory for job work dirs
Log LogConfig `yaml:"log"`
Minio MinioConfig `yaml:"minio"`
}
// Load reads a YAML configuration file and returns a parsed Config.
@@ -47,5 +62,18 @@ func Load(path string) (*Config, error) {
return nil, fmt.Errorf("parse config file %s: %w", path, err)
}
if cfg.Minio.ChunkSize == 0 {
cfg.Minio.ChunkSize = 16 << 20 // 16MB
}
if cfg.Minio.MaxFileSize == 0 {
cfg.Minio.MaxFileSize = 50 << 30 // 50GB
}
if cfg.Minio.MinChunkSize == 0 {
cfg.Minio.MinChunkSize = 5 << 20 // 5MB
}
if cfg.Minio.SessionTTL == 0 {
cfg.Minio.SessionTTL = 48
}
return &cfg, nil
}


@@ -254,3 +254,135 @@ log:
t.Errorf("Log.FilePath = %q, want %q", cfg.Log.FilePath, "/var/log/app.log")
}
}
func TestLoadWithMinioConfig(t *testing.T) {
content := []byte(`server_port: "9090"
slurm_api_url: "http://slurm.example.com:6820"
slurm_user_name: "admin"
slurm_jwt_key_path: "/etc/slurm/jwt.key"
mysql_dsn: "user:pass@tcp(10.0.0.1:3306)/testdb?parseTime=true"
minio:
endpoint: "minio.example.com:9000"
access_key: "myaccesskey"
secret_key: "mysecretkey"
bucket: "test-bucket"
use_ssl: true
chunk_size: 33554432
max_file_size: 107374182400
min_chunk_size: 10485760
session_ttl: 24
`)
dir := t.TempDir()
path := filepath.Join(dir, "config.yaml")
if err := os.WriteFile(path, content, 0644); err != nil {
t.Fatalf("write temp config: %v", err)
}
cfg, err := Load(path)
if err != nil {
t.Fatalf("Load() error = %v", err)
}
if cfg.Minio.Endpoint != "minio.example.com:9000" {
t.Errorf("Minio.Endpoint = %q, want %q", cfg.Minio.Endpoint, "minio.example.com:9000")
}
if cfg.Minio.AccessKey != "myaccesskey" {
t.Errorf("Minio.AccessKey = %q, want %q", cfg.Minio.AccessKey, "myaccesskey")
}
if cfg.Minio.SecretKey != "mysecretkey" {
t.Errorf("Minio.SecretKey = %q, want %q", cfg.Minio.SecretKey, "mysecretkey")
}
if cfg.Minio.Bucket != "test-bucket" {
t.Errorf("Minio.Bucket = %q, want %q", cfg.Minio.Bucket, "test-bucket")
}
if !cfg.Minio.UseSSL {
t.Errorf("Minio.UseSSL = %v, want %v", cfg.Minio.UseSSL, true)
}
if cfg.Minio.ChunkSize != 33554432 {
t.Errorf("Minio.ChunkSize = %d, want %d", cfg.Minio.ChunkSize, 33554432)
}
if cfg.Minio.MaxFileSize != 107374182400 {
t.Errorf("Minio.MaxFileSize = %d, want %d", cfg.Minio.MaxFileSize, 107374182400)
}
if cfg.Minio.MinChunkSize != 10485760 {
t.Errorf("Minio.MinChunkSize = %d, want %d", cfg.Minio.MinChunkSize, 10485760)
}
if cfg.Minio.SessionTTL != 24 {
t.Errorf("Minio.SessionTTL = %d, want %d", cfg.Minio.SessionTTL, 24)
}
}
func TestLoadWithoutMinioConfig(t *testing.T) {
content := []byte(`server_port: "8080"
slurm_api_url: "http://localhost:6820"
slurm_user_name: "root"
slurm_jwt_key_path: "/etc/slurm/jwt_hs256.key"
mysql_dsn: "root:@tcp(127.0.0.1:3306)/hpc_platform?parseTime=true"
`)
dir := t.TempDir()
path := filepath.Join(dir, "config.yaml")
if err := os.WriteFile(path, content, 0644); err != nil {
t.Fatalf("write temp config: %v", err)
}
cfg, err := Load(path)
if err != nil {
t.Fatalf("Load() error = %v", err)
}
if cfg.Minio.Endpoint != "" {
t.Errorf("Minio.Endpoint = %q, want empty string", cfg.Minio.Endpoint)
}
if cfg.Minio.AccessKey != "" {
t.Errorf("Minio.AccessKey = %q, want empty string", cfg.Minio.AccessKey)
}
if cfg.Minio.SecretKey != "" {
t.Errorf("Minio.SecretKey = %q, want empty string", cfg.Minio.SecretKey)
}
if cfg.Minio.Bucket != "" {
t.Errorf("Minio.Bucket = %q, want empty string", cfg.Minio.Bucket)
}
if cfg.Minio.UseSSL {
t.Errorf("Minio.UseSSL = %v, want false", cfg.Minio.UseSSL)
}
}
func TestLoadMinioDefaults(t *testing.T) {
content := []byte(`server_port: "8080"
slurm_api_url: "http://localhost:6820"
slurm_user_name: "root"
slurm_jwt_key_path: "/etc/slurm/jwt_hs256.key"
mysql_dsn: "root:@tcp(127.0.0.1:3306)/hpc_platform?parseTime=true"
minio:
endpoint: "localhost:9000"
access_key: "minioadmin"
secret_key: "minioadmin"
bucket: "uploads"
`)
dir := t.TempDir()
path := filepath.Join(dir, "config.yaml")
if err := os.WriteFile(path, content, 0644); err != nil {
t.Fatalf("write temp config: %v", err)
}
cfg, err := Load(path)
if err != nil {
t.Fatalf("Load() error = %v", err)
}
if cfg.Minio.ChunkSize != 16<<20 {
t.Errorf("Minio.ChunkSize = %d, want %d", cfg.Minio.ChunkSize, 16<<20)
}
if cfg.Minio.MaxFileSize != 50<<30 {
t.Errorf("Minio.MaxFileSize = %d, want %d", cfg.Minio.MaxFileSize, 50<<30)
}
if cfg.Minio.MinChunkSize != 5<<20 {
t.Errorf("Minio.MinChunkSize = %d, want %d", cfg.Minio.MinChunkSize, 5<<20)
}
if cfg.Minio.SessionTTL != 48 {
t.Errorf("Minio.SessionTTL = %d, want %d", cfg.Minio.SessionTTL, 48)
}
}


@@ -0,0 +1,174 @@
package handler
import (
"encoding/json"
"errors"
"strconv"
"gcy_hpc_server/internal/model"
"gcy_hpc_server/internal/server"
"gcy_hpc_server/internal/service"
"github.com/gin-gonic/gin"
"go.uber.org/zap"
"gorm.io/gorm"
)
type ApplicationHandler struct {
appSvc *service.ApplicationService
logger *zap.Logger
}
func NewApplicationHandler(appSvc *service.ApplicationService, logger *zap.Logger) *ApplicationHandler {
return &ApplicationHandler{appSvc: appSvc, logger: logger}
}
func (h *ApplicationHandler) ListApplications(c *gin.Context) {
page, _ := strconv.Atoi(c.DefaultQuery("page", "1"))
pageSize, _ := strconv.Atoi(c.DefaultQuery("page_size", "20"))
if page < 1 {
page = 1
}
if pageSize < 1 {
pageSize = 20
}
apps, total, err := h.appSvc.ListApplications(c.Request.Context(), page, pageSize)
if err != nil {
h.logger.Error("failed to list applications", zap.Error(err))
server.InternalError(c, err.Error())
return
}
server.OK(c, gin.H{
"applications": apps,
"total": total,
"page": page,
"page_size": pageSize,
})
}
func (h *ApplicationHandler) CreateApplication(c *gin.Context) {
var req model.CreateApplicationRequest
if err := c.ShouldBindJSON(&req); err != nil {
h.logger.Warn("invalid request body for create application", zap.Error(err))
server.BadRequest(c, "invalid request body")
return
}
if req.Name == "" || req.ScriptTemplate == "" {
h.logger.Warn("missing required fields for create application")
server.BadRequest(c, "name and script_template are required")
return
}
if len(req.Parameters) == 0 {
req.Parameters = json.RawMessage(`[]`)
}
id, err := h.appSvc.CreateApplication(c.Request.Context(), &req)
if err != nil {
h.logger.Error("failed to create application", zap.Error(err))
server.InternalError(c, err.Error())
return
}
h.logger.Info("application created", zap.Int64("id", id))
server.Created(c, gin.H{"id": id})
}
func (h *ApplicationHandler) GetApplication(c *gin.Context) {
id, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
h.logger.Warn("invalid application id", zap.String("id", c.Param("id")))
server.BadRequest(c, "invalid id")
return
}
app, err := h.appSvc.GetApplication(c.Request.Context(), id)
if err != nil {
h.logger.Error("failed to get application", zap.Int64("id", id), zap.Error(err))
server.InternalError(c, err.Error())
return
}
if app == nil {
h.logger.Warn("application not found", zap.Int64("id", id))
server.NotFound(c, "application not found")
return
}
server.OK(c, app)
}
func (h *ApplicationHandler) UpdateApplication(c *gin.Context) {
id, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
h.logger.Warn("invalid application id for update", zap.String("id", c.Param("id")))
server.BadRequest(c, "invalid id")
return
}
var req model.UpdateApplicationRequest
if err := c.ShouldBindJSON(&req); err != nil {
h.logger.Warn("invalid request body for update application", zap.Int64("id", id), zap.Error(err))
server.BadRequest(c, "invalid request body")
return
}
if err := h.appSvc.UpdateApplication(c.Request.Context(), id, &req); err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
h.logger.Warn("application not found for update", zap.Int64("id", id))
server.NotFound(c, "application not found")
return
}
h.logger.Error("failed to update application", zap.Int64("id", id), zap.Error(err))
server.InternalError(c, "failed to update application")
return
}
h.logger.Info("application updated", zap.Int64("id", id))
server.OK(c, gin.H{"message": "application updated"})
}
func (h *ApplicationHandler) DeleteApplication(c *gin.Context) {
id, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
h.logger.Warn("invalid application id for delete", zap.String("id", c.Param("id")))
server.BadRequest(c, "invalid id")
return
}
if err := h.appSvc.DeleteApplication(c.Request.Context(), id); err != nil {
h.logger.Error("failed to delete application", zap.Int64("id", id), zap.Error(err))
server.InternalError(c, err.Error())
return
}
h.logger.Info("application deleted", zap.Int64("id", id))
server.OK(c, gin.H{"message": "application deleted"})
}
// [Disabled] The frontend has fully migrated to the POST /tasks endpoint; this endpoint is no longer used.
/* func (h *ApplicationHandler) SubmitApplication(c *gin.Context) {
id, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
h.logger.Warn("invalid application id for submit", zap.String("id", c.Param("id")))
server.BadRequest(c, "invalid id")
return
}
var req model.ApplicationSubmitRequest
if err := c.ShouldBindJSON(&req); err != nil {
h.logger.Warn("invalid request body for submit application", zap.Int64("id", id), zap.Error(err))
server.BadRequest(c, "invalid request body")
return
}
resp, err := h.appSvc.SubmitFromApplication(c.Request.Context(), id, req.Values)
if err != nil {
errStr := err.Error()
if strings.Contains(errStr, "not found") {
h.logger.Warn("application not found for submit", zap.Int64("id", id))
server.NotFound(c, errStr)
return
}
if strings.Contains(errStr, "validation") {
h.logger.Warn("application submit validation failed", zap.Int64("id", id), zap.Error(err))
server.BadRequest(c, errStr)
return
}
h.logger.Error("failed to submit application", zap.Int64("id", id), zap.Error(err))
server.InternalError(c, errStr)
return
}
h.logger.Info("application submitted", zap.Int64("id", id), zap.Int32("job_id", resp.JobID))
server.Created(c, resp)
} */


@@ -0,0 +1,642 @@
package handler
import (
"bytes"
"encoding/json"
"fmt"
"net/http"
"net/http/httptest"
"strings"
"testing"
"gcy_hpc_server/internal/model"
"gcy_hpc_server/internal/service"
"gcy_hpc_server/internal/slurm"
"gcy_hpc_server/internal/store"
"github.com/gin-gonic/gin"
"go.uber.org/zap"
"go.uber.org/zap/zapcore"
"go.uber.org/zap/zaptest/observer"
"gorm.io/driver/sqlite"
"gorm.io/gorm"
gormlogger "gorm.io/gorm/logger"
)
func itoa(id int64) string {
return fmt.Sprintf("%d", id)
}
func setupApplicationHandler() (*ApplicationHandler, *gorm.DB) {
db, _ := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{Logger: gormlogger.Default.LogMode(gormlogger.Silent)})
db.AutoMigrate(&model.Application{})
appStore := store.NewApplicationStore(db)
appSvc := service.NewApplicationService(appStore, nil, "", zap.NewNop())
h := NewApplicationHandler(appSvc, zap.NewNop())
return h, db
}
func setupApplicationRouter(h *ApplicationHandler) *gin.Engine {
gin.SetMode(gin.TestMode)
r := gin.New()
v1 := r.Group("/api/v1")
apps := v1.Group("/applications")
apps.GET("", h.ListApplications)
apps.POST("", h.CreateApplication)
apps.GET("/:id", h.GetApplication)
apps.PUT("/:id", h.UpdateApplication)
apps.DELETE("/:id", h.DeleteApplication)
// apps.POST("/:id/submit", h.SubmitApplication) // [Disabled] Replaced by POST /tasks
return r
}
func setupApplicationHandlerWithSlurm(slurmHandler http.HandlerFunc) (*ApplicationHandler, func()) {
srv := httptest.NewServer(slurmHandler)
client, _ := slurm.NewClient(srv.URL, srv.Client())
db, _ := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{Logger: gormlogger.Default.LogMode(gormlogger.Silent)})
db.AutoMigrate(&model.Application{})
jobSvc := service.NewJobService(client, zap.NewNop())
appStore := store.NewApplicationStore(db)
appSvc := service.NewApplicationService(appStore, jobSvc, "", zap.NewNop())
h := NewApplicationHandler(appSvc, zap.NewNop())
return h, srv.Close
}
func setupApplicationHandlerWithObserver() (*ApplicationHandler, *gorm.DB, *observer.ObservedLogs) {
core, recorded := observer.New(zapcore.DebugLevel)
l := zap.New(core)
db, _ := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{Logger: gormlogger.Default.LogMode(gormlogger.Silent)})
db.AutoMigrate(&model.Application{})
appStore := store.NewApplicationStore(db)
appSvc := service.NewApplicationService(appStore, nil, "", l)
return NewApplicationHandler(appSvc, l), db, recorded
}
func createTestApplication(h *ApplicationHandler, r *gin.Engine) int64 {
body, _ := json.Marshal(model.CreateApplicationRequest{
Name: "test-app",
ScriptTemplate: "#!/bin/bash\necho hello",
})
w := httptest.NewRecorder()
req, _ := http.NewRequest(http.MethodPost, "/api/v1/applications", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
r.ServeHTTP(w, req)
var resp map[string]interface{}
json.Unmarshal(w.Body.Bytes(), &resp)
data := resp["data"].(map[string]interface{})
return int64(data["id"].(float64))
}
// ---- CRUD Tests ----
func TestCreateApplication_Success(t *testing.T) {
h, _ := setupApplicationHandler()
r := setupApplicationRouter(h)
body, _ := json.Marshal(model.CreateApplicationRequest{
Name: "my-app",
ScriptTemplate: "#!/bin/bash\necho hello",
})
w := httptest.NewRecorder()
req, _ := http.NewRequest(http.MethodPost, "/api/v1/applications", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
r.ServeHTTP(w, req)
if w.Code != http.StatusCreated {
t.Fatalf("expected 201, got %d: %s", w.Code, w.Body.String())
}
var resp map[string]interface{}
json.Unmarshal(w.Body.Bytes(), &resp)
if !resp["success"].(bool) {
t.Fatal("expected success=true")
}
data := resp["data"].(map[string]interface{})
if _, ok := data["id"]; !ok {
t.Fatal("expected id in response data")
}
}
func TestCreateApplication_MissingName(t *testing.T) {
h, _ := setupApplicationHandler()
r := setupApplicationRouter(h)
body, _ := json.Marshal(model.CreateApplicationRequest{
ScriptTemplate: "#!/bin/bash\necho hello",
})
w := httptest.NewRecorder()
req, _ := http.NewRequest(http.MethodPost, "/api/v1/applications", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
r.ServeHTTP(w, req)
if w.Code != http.StatusBadRequest {
t.Fatalf("expected 400, got %d", w.Code)
}
}
func TestCreateApplication_MissingScriptTemplate(t *testing.T) {
h, _ := setupApplicationHandler()
r := setupApplicationRouter(h)
body, _ := json.Marshal(model.CreateApplicationRequest{
Name: "my-app",
})
w := httptest.NewRecorder()
req, _ := http.NewRequest(http.MethodPost, "/api/v1/applications", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
r.ServeHTTP(w, req)
if w.Code != http.StatusBadRequest {
t.Fatalf("expected 400, got %d", w.Code)
}
}
func TestCreateApplication_EmptyParameters(t *testing.T) {
h, db := setupApplicationHandler()
r := setupApplicationRouter(h)
body := `{"name":"empty-params-app","script_template":"#!/bin/bash\necho hello"}`
w := httptest.NewRecorder()
req, _ := http.NewRequest(http.MethodPost, "/api/v1/applications", strings.NewReader(body))
req.Header.Set("Content-Type", "application/json")
r.ServeHTTP(w, req)
if w.Code != http.StatusCreated {
t.Fatalf("expected 201, got %d: %s", w.Code, w.Body.String())
}
var app model.Application
db.First(&app)
if string(app.Parameters) != "[]" {
t.Fatalf("expected parameters to default to [], got %s", string(app.Parameters))
}
}
func TestListApplications_Success(t *testing.T) {
h, _ := setupApplicationHandler()
r := setupApplicationRouter(h)
createTestApplication(h, r)
w := httptest.NewRecorder()
req, _ := http.NewRequest(http.MethodGet, "/api/v1/applications", nil)
r.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("expected 200, got %d", w.Code)
}
var resp map[string]interface{}
json.Unmarshal(w.Body.Bytes(), &resp)
if !resp["success"].(bool) {
t.Fatal("expected success=true")
}
data := resp["data"].(map[string]interface{})
if data["total"].(float64) < 1 {
t.Fatal("expected at least 1 application")
}
}
func TestListApplications_Pagination(t *testing.T) {
h, _ := setupApplicationHandler()
r := setupApplicationRouter(h)
for i := 0; i < 5; i++ {
body, _ := json.Marshal(model.CreateApplicationRequest{
Name: fmt.Sprintf("app-%d", i),
ScriptTemplate: "#!/bin/bash\necho " + fmt.Sprintf("app-%d", i),
})
w := httptest.NewRecorder()
req, _ := http.NewRequest(http.MethodPost, "/api/v1/applications", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
r.ServeHTTP(w, req)
}
w := httptest.NewRecorder()
req, _ := http.NewRequest(http.MethodGet, "/api/v1/applications?page=1&page_size=2", nil)
r.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("expected 200, got %d", w.Code)
}
var resp map[string]interface{}
json.Unmarshal(w.Body.Bytes(), &resp)
data := resp["data"].(map[string]interface{})
if data["total"].(float64) != 5 {
t.Fatalf("expected total=5, got %v", data["total"])
}
if data["page"].(float64) != 1 {
t.Fatalf("expected page=1, got %v", data["page"])
}
if data["page_size"].(float64) != 2 {
t.Fatalf("expected page_size=2, got %v", data["page_size"])
}
}
func TestGetApplication_Success(t *testing.T) {
h, _ := setupApplicationHandler()
r := setupApplicationRouter(h)
id := createTestApplication(h, r)
w := httptest.NewRecorder()
req, _ := http.NewRequest(http.MethodGet, "/api/v1/applications/"+itoa(id), nil)
r.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("expected 200, got %d: %s", w.Code, w.Body.String())
}
var resp map[string]interface{}
json.Unmarshal(w.Body.Bytes(), &resp)
data := resp["data"].(map[string]interface{})
if data["name"] != "test-app" {
t.Fatalf("expected name=test-app, got %v", data["name"])
}
}
func TestGetApplication_NotFound(t *testing.T) {
h, _ := setupApplicationHandler()
r := setupApplicationRouter(h)
w := httptest.NewRecorder()
req, _ := http.NewRequest(http.MethodGet, "/api/v1/applications/999", nil)
r.ServeHTTP(w, req)
if w.Code != http.StatusNotFound {
t.Fatalf("expected 404, got %d", w.Code)
}
}
func TestGetApplication_InvalidID(t *testing.T) {
h, _ := setupApplicationHandler()
r := setupApplicationRouter(h)
w := httptest.NewRecorder()
req, _ := http.NewRequest(http.MethodGet, "/api/v1/applications/abc", nil)
r.ServeHTTP(w, req)
if w.Code != http.StatusBadRequest {
t.Fatalf("expected 400, got %d", w.Code)
}
}
func TestUpdateApplication_Success(t *testing.T) {
h, _ := setupApplicationHandler()
r := setupApplicationRouter(h)
id := createTestApplication(h, r)
newName := "updated-app"
body, _ := json.Marshal(model.UpdateApplicationRequest{
Name: &newName,
})
w := httptest.NewRecorder()
req, _ := http.NewRequest(http.MethodPut, "/api/v1/applications/"+itoa(id), bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
r.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("expected 200, got %d: %s", w.Code, w.Body.String())
}
}
func TestUpdateApplication_NotFound(t *testing.T) {
h, _ := setupApplicationHandler()
r := setupApplicationRouter(h)
newName := "updated-app"
body, _ := json.Marshal(model.UpdateApplicationRequest{
Name: &newName,
})
w := httptest.NewRecorder()
req, _ := http.NewRequest(http.MethodPut, "/api/v1/applications/999", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
r.ServeHTTP(w, req)
if w.Code != http.StatusNotFound {
t.Fatalf("expected 404, got %d", w.Code)
}
}
func TestDeleteApplication_Success(t *testing.T) {
h, _ := setupApplicationHandler()
r := setupApplicationRouter(h)
id := createTestApplication(h, r)
w := httptest.NewRecorder()
req, _ := http.NewRequest(http.MethodDelete, "/api/v1/applications/"+itoa(id), nil)
r.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("expected 200, got %d: %s", w.Code, w.Body.String())
}
w2 := httptest.NewRecorder()
req2, _ := http.NewRequest(http.MethodGet, "/api/v1/applications/"+itoa(id), nil)
r.ServeHTTP(w2, req2)
if w2.Code != http.StatusNotFound {
t.Fatalf("expected 404 after delete, got %d", w2.Code)
}
}
// ---- Submit Tests ----
// [Disabled] Tested the old direct-submission path, which has been commented out.
/*
func TestSubmitApplication_Success(t *testing.T) {
slurmHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusCreated)
json.NewEncoder(w).Encode(map[string]interface{}{
"job_id": 12345,
})
})
h, cleanup := setupApplicationHandlerWithSlurm(slurmHandler)
defer cleanup()
r := setupApplicationRouter(h)
params := `[{"name":"COUNT","type":"integer","required":true}]`
body := `{"name":"submit-app","script_template":"#!/bin/bash\necho $COUNT","parameters":` + params + `}`
w := httptest.NewRecorder()
req, _ := http.NewRequest(http.MethodPost, "/api/v1/applications", strings.NewReader(body))
req.Header.Set("Content-Type", "application/json")
r.ServeHTTP(w, req)
var createResp map[string]interface{}
json.Unmarshal(w.Body.Bytes(), &createResp)
data := createResp["data"].(map[string]interface{})
id := int64(data["id"].(float64))
submitBody := `{"values":{"COUNT":"5"}}`
w2 := httptest.NewRecorder()
req2, _ := http.NewRequest(http.MethodPost, "/api/v1/applications/"+itoa(id)+"/submit", strings.NewReader(submitBody))
req2.Header.Set("Content-Type", "application/json")
r.ServeHTTP(w2, req2)
if w2.Code != http.StatusCreated {
t.Fatalf("expected 201, got %d: %s", w2.Code, w2.Body.String())
}
}
*/
// [Disabled] Tested the old direct-submission path, which has been commented out.
/*
func TestSubmitApplication_AppNotFound(t *testing.T) {
slurmHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(map[string]interface{}{})
})
h, cleanup := setupApplicationHandlerWithSlurm(slurmHandler)
defer cleanup()
r := setupApplicationRouter(h)
submitBody := `{"values":{"COUNT":"5"}}`
w := httptest.NewRecorder()
req, _ := http.NewRequest(http.MethodPost, "/api/v1/applications/999/submit", strings.NewReader(submitBody))
req.Header.Set("Content-Type", "application/json")
r.ServeHTTP(w, req)
if w.Code != http.StatusNotFound {
t.Fatalf("expected 404, got %d: %s", w.Code, w.Body.String())
}
}
*/
// [Disabled] Tested the old direct-submission path, which has been commented out.
/*
func TestSubmitApplication_ValidationFail(t *testing.T) {
slurmHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(map[string]interface{}{})
})
h, cleanup := setupApplicationHandlerWithSlurm(slurmHandler)
defer cleanup()
r := setupApplicationRouter(h)
params := `[{"name":"COUNT","type":"integer","required":true}]`
body := `{"name":"val-app","script_template":"#!/bin/bash\necho $COUNT","parameters":` + params + `}`
w := httptest.NewRecorder()
req, _ := http.NewRequest(http.MethodPost, "/api/v1/applications", strings.NewReader(body))
req.Header.Set("Content-Type", "application/json")
r.ServeHTTP(w, req)
var createResp map[string]interface{}
json.Unmarshal(w.Body.Bytes(), &createResp)
data := createResp["data"].(map[string]interface{})
id := int64(data["id"].(float64))
submitBody := `{"values":{"COUNT":"not-a-number"}}`
w2 := httptest.NewRecorder()
req2, _ := http.NewRequest(http.MethodPost, "/api/v1/applications/"+itoa(id)+"/submit", strings.NewReader(submitBody))
req2.Header.Set("Content-Type", "application/json")
r.ServeHTTP(w2, req2)
if w2.Code != http.StatusBadRequest {
t.Fatalf("expected 400, got %d: %s", w2.Code, w2.Body.String())
}
}
*/
// ---- Logging Tests ----
func TestApplicationLogging_CreateSuccess_LogsInfoWithID(t *testing.T) {
h, _, recorded := setupApplicationHandlerWithObserver()
r := setupApplicationRouter(h)
body, _ := json.Marshal(model.CreateApplicationRequest{
Name: "log-app",
ScriptTemplate: "#!/bin/bash\necho hello",
})
w := httptest.NewRecorder()
req, _ := http.NewRequest(http.MethodPost, "/api/v1/applications", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
r.ServeHTTP(w, req)
found := false
for _, entry := range recorded.All() {
if entry.Message == "application created" {
found = true
break
}
}
if !found {
t.Fatal("expected 'application created' log message")
}
}
func TestApplicationLogging_GetNotFound_LogsWarnWithID(t *testing.T) {
h, _, recorded := setupApplicationHandlerWithObserver()
r := setupApplicationRouter(h)
w := httptest.NewRecorder()
req, _ := http.NewRequest(http.MethodGet, "/api/v1/applications/999", nil)
r.ServeHTTP(w, req)
found := false
for _, entry := range recorded.All() {
if entry.Message == "application not found" {
found = true
break
}
}
if !found {
t.Fatal("expected 'application not found' log message")
}
}
func TestApplicationLogging_UpdateSuccess_LogsInfoWithID(t *testing.T) {
h, _, recorded := setupApplicationHandlerWithObserver()
r := setupApplicationRouter(h)
id := createTestApplication(h, r)
recorded.TakeAll()
newName := "updated-log-app"
body, _ := json.Marshal(model.UpdateApplicationRequest{
Name: &newName,
})
w := httptest.NewRecorder()
req, _ := http.NewRequest(http.MethodPut, "/api/v1/applications/"+itoa(id), bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
r.ServeHTTP(w, req)
found := false
for _, entry := range recorded.All() {
if entry.Message == "application updated" {
found = true
break
}
}
if !found {
t.Fatal("expected 'application updated' log message")
}
}
func TestApplicationLogging_DeleteSuccess_LogsInfoWithID(t *testing.T) {
h, _, recorded := setupApplicationHandlerWithObserver()
r := setupApplicationRouter(h)
id := createTestApplication(h, r)
recorded.TakeAll()
w := httptest.NewRecorder()
req, _ := http.NewRequest(http.MethodDelete, "/api/v1/applications/"+itoa(id), nil)
r.ServeHTTP(w, req)
found := false
for _, entry := range recorded.All() {
if entry.Message == "application deleted" {
found = true
break
}
}
if !found {
t.Fatal("expected 'application deleted' log message")
}
}
// [Disabled] Tested the old direct-submission path, which has been commented out.
/*
func TestApplicationLogging_SubmitSuccess_LogsInfoWithID(t *testing.T) {
core, recorded := observer.New(zapcore.DebugLevel)
l := zap.New(core)
slurmSrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusCreated)
json.NewEncoder(w).Encode(map[string]interface{}{
"job_id": 42,
})
}))
defer slurmSrv.Close()
client, _ := slurm.NewClient(slurmSrv.URL, slurmSrv.Client())
db, _ := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{Logger: gormlogger.Default.LogMode(gormlogger.Silent)})
db.AutoMigrate(&model.Application{})
jobSvc := service.NewJobService(client, l)
appStore := store.NewApplicationStore(db)
appSvc := service.NewApplicationService(appStore, jobSvc, "", l)
h := NewApplicationHandler(appSvc, l)
r := setupApplicationRouter(h)
params := `[{"name":"X","type":"string","required":false}]`
body := `{"name":"sub-log-app","script_template":"#!/bin/bash\necho $X","parameters":` + params + `}`
w := httptest.NewRecorder()
req, _ := http.NewRequest(http.MethodPost, "/api/v1/applications", strings.NewReader(body))
req.Header.Set("Content-Type", "application/json")
r.ServeHTTP(w, req)
var createResp map[string]interface{}
json.Unmarshal(w.Body.Bytes(), &createResp)
data := createResp["data"].(map[string]interface{})
id := int64(data["id"].(float64))
recorded.TakeAll()
submitBody := `{"values":{"X":"val"}}`
w2 := httptest.NewRecorder()
req2, _ := http.NewRequest(http.MethodPost, "/api/v1/applications/"+itoa(id)+"/submit", strings.NewReader(submitBody))
req2.Header.Set("Content-Type", "application/json")
r.ServeHTTP(w2, req2)
found := false
for _, entry := range recorded.All() {
if entry.Message == "application submitted" {
found = true
break
}
}
if !found {
t.Fatal("expected 'application submitted' log message")
}
}
*/
func TestApplicationLogging_CreateBadRequest_LogsWarn(t *testing.T) {
h, _, recorded := setupApplicationHandlerWithObserver()
r := setupApplicationRouter(h)
w := httptest.NewRecorder()
req, _ := http.NewRequest(http.MethodPost, "/api/v1/applications", strings.NewReader(`{"name":""}`))
req.Header.Set("Content-Type", "application/json")
r.ServeHTTP(w, req)
found := false
for _, entry := range recorded.All() {
if entry.Message == "invalid request body for create application" {
found = true
break
}
}
if !found {
t.Fatal("expected 'invalid request body for create application' log message")
}
}
func TestApplicationLogging_LogsDoNotContainApplicationContent(t *testing.T) {
h, _, recorded := setupApplicationHandlerWithObserver()
r := setupApplicationRouter(h)
body, _ := json.Marshal(model.CreateApplicationRequest{
Name: "secret-app",
ScriptTemplate: "#!/bin/bash\necho secret_password_here",
})
w := httptest.NewRecorder()
req, _ := http.NewRequest(http.MethodPost, "/api/v1/applications", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
r.ServeHTTP(w, req)
for _, entry := range recorded.All() {
msg := entry.Message
if strings.Contains(msg, "secret_password_here") {
t.Fatalf("log message contains application content: %s", msg)
}
for _, field := range entry.Context {
if strings.Contains(field.String, "secret_password_here") {
t.Fatalf("log field contains application content: %s", field.String)
}
}
}
}

View File

@@ -319,8 +319,8 @@ func TestClusterHandler_GetNodes_Success_NoLogs(t *testing.T) {
t.Fatalf("expected 200, got %d", w.Code)
}
-if recorded.Len() != 0 {
-t.Fatalf("expected 0 log entries on success, got %d", recorded.Len())
+if recorded.Len() != 2 {
+t.Fatalf("expected 2 log entries on success, got %d", recorded.Len())
}
}
@@ -370,10 +370,10 @@ func TestClusterHandler_GetNode_NotFound_LogsWarn(t *testing.T) {
t.Fatalf("expected 404, got %d", w.Code)
}
-if recorded.Len() != 1 {
-t.Fatalf("expected 1 log entry, got %d", recorded.Len())
+if recorded.Len() != 3 {
+t.Fatalf("expected 3 log entries, got %d", recorded.Len())
}
-entry := recorded.All()[0]
+entry := recorded.All()[2]
if entry.Level != zapcore.WarnLevel {
t.Fatalf("expected Warn level, got %v", entry.Level)
}
@@ -405,8 +405,8 @@ func TestClusterHandler_GetNode_Success_NoLogs(t *testing.T) {
t.Fatalf("expected 200, got %d", w.Code)
}
-if recorded.Len() != 0 {
-t.Fatalf("expected 0 log entries on success, got %d", recorded.Len())
+if recorded.Len() != 2 {
+t.Fatalf("expected 2 log entries on success, got %d", recorded.Len())
}
}
@@ -458,8 +458,8 @@ func TestClusterHandler_GetPartitions_Success_NoLogs(t *testing.T) {
t.Fatalf("expected 200, got %d", w.Code)
}
-if recorded.Len() != 0 {
-t.Fatalf("expected 0 log entries on success, got %d", recorded.Len())
+if recorded.Len() != 2 {
+t.Fatalf("expected 2 log entries on success, got %d", recorded.Len())
}
}
@@ -509,10 +509,10 @@ func TestClusterHandler_GetPartition_NotFound_LogsWarn(t *testing.T) {
t.Fatalf("expected 404, got %d", w.Code)
}
-if recorded.Len() != 1 {
-t.Fatalf("expected 1 log entry, got %d", recorded.Len())
+if recorded.Len() != 3 {
+t.Fatalf("expected 3 log entries, got %d", recorded.Len())
}
-entry := recorded.All()[0]
+entry := recorded.All()[2]
if entry.Level != zapcore.WarnLevel {
t.Fatalf("expected Warn level, got %v", entry.Level)
}
@@ -559,8 +559,8 @@ func TestClusterHandler_GetPartition_Success_NoLogs(t *testing.T) {
t.Fatalf("expected 200, got %d", w.Code)
}
-if recorded.Len() != 0 {
-t.Fatalf("expected 0 log entries on success, got %d", recorded.Len())
+if recorded.Len() != 2 {
+t.Fatalf("expected 2 log entries on success, got %d", recorded.Len())
}
}
@@ -617,8 +617,8 @@ func TestClusterHandler_GetDiag_Success_NoLogs(t *testing.T) {
t.Fatalf("expected 200, got %d", w.Code)
}
-if recorded.Len() != 0 {
-t.Fatalf("expected 0 log entries on success, got %d", recorded.Len())
+if recorded.Len() != 2 {
+t.Fatalf("expected 2 log entries on success, got %d", recorded.Len())
}
}

View File

@@ -0,0 +1,138 @@
package handler
import (
"context"
"io"
"strconv"
"gcy_hpc_server/internal/model"
"gcy_hpc_server/internal/server"
"gcy_hpc_server/internal/service"
"github.com/gin-gonic/gin"
"go.uber.org/zap"
)
type fileServiceProvider interface {
ListFiles(ctx context.Context, folderID *int64, page, pageSize int, search string) ([]model.FileResponse, int64, error)
GetFileMetadata(ctx context.Context, fileID int64) (*model.File, *model.FileBlob, error)
DownloadFile(ctx context.Context, fileID int64, rangeHeader string) (io.ReadCloser, *model.File, *model.FileBlob, int64, int64, error)
DeleteFile(ctx context.Context, fileID int64) error
}
type FileHandler struct {
svc fileServiceProvider
logger *zap.Logger
}
func NewFileHandler(svc *service.FileService, logger *zap.Logger) *FileHandler {
return &FileHandler{svc: svc, logger: logger}
}
func (h *FileHandler) ListFiles(c *gin.Context) {
page, _ := strconv.Atoi(c.DefaultQuery("page", "1"))
pageSize, _ := strconv.Atoi(c.DefaultQuery("page_size", "10"))
if page < 1 {
page = 1
}
if pageSize < 1 {
pageSize = 10
}
var folderID *int64
if v := c.Query("folder_id"); v != "" {
id, err := strconv.ParseInt(v, 10, 64)
if err != nil {
h.logger.Warn("invalid folder_id", zap.String("folder_id", v))
server.BadRequest(c, "invalid folder_id")
return
}
folderID = &id
}
search := c.Query("search")
files, total, err := h.svc.ListFiles(c.Request.Context(), folderID, page, pageSize, search)
if err != nil {
h.logger.Error("failed to list files", zap.Error(err))
server.InternalError(c, err.Error())
return
}
server.OK(c, model.ListFilesResponse{
Files: files,
Total: total,
Page: page,
PageSize: pageSize,
})
}
func (h *FileHandler) GetFile(c *gin.Context) {
id, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
h.logger.Warn("invalid file id", zap.String("id", c.Param("id")))
server.BadRequest(c, "invalid id")
return
}
file, blob, err := h.svc.GetFileMetadata(c.Request.Context(), id)
if err != nil {
h.logger.Error("failed to get file", zap.Int64("id", id), zap.Error(err))
server.InternalError(c, err.Error())
return
}
resp := model.FileResponse{
ID: file.ID,
Name: file.Name,
FolderID: file.FolderID,
Size: blob.FileSize,
MimeType: blob.MimeType,
SHA256: file.BlobSHA256,
CreatedAt: file.CreatedAt,
UpdatedAt: file.UpdatedAt,
}
server.OK(c, resp)
}
func (h *FileHandler) DownloadFile(c *gin.Context) {
id, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
h.logger.Warn("invalid file id for download", zap.String("id", c.Param("id")))
server.BadRequest(c, "invalid id")
return
}
rangeHeader := c.GetHeader("Range")
reader, file, blob, start, end, err := h.svc.DownloadFile(c.Request.Context(), id, rangeHeader)
if err != nil {
h.logger.Error("failed to download file", zap.Int64("id", id), zap.Error(err))
server.InternalError(c, err.Error())
return
}
if rangeHeader != "" {
server.StreamRange(c, reader, start, end, blob.FileSize, blob.MimeType)
} else {
server.StreamFile(c, reader, file.Name, blob.FileSize, blob.MimeType)
}
}
func (h *FileHandler) DeleteFile(c *gin.Context) {
id, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
h.logger.Warn("invalid file id for delete", zap.String("id", c.Param("id")))
server.BadRequest(c, "invalid id")
return
}
if err := h.svc.DeleteFile(c.Request.Context(), id); err != nil {
h.logger.Error("failed to delete file", zap.Int64("id", id), zap.Error(err))
server.InternalError(c, err.Error())
return
}
h.logger.Info("file deleted", zap.Int64("id", id))
server.OK(c, gin.H{"message": "file deleted"})
}

View File

@@ -0,0 +1,369 @@
package handler
import (
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"
"gcy_hpc_server/internal/model"
"github.com/gin-gonic/gin"
"go.uber.org/zap"
)
type mockFileService struct {
listFilesFn func(ctx context.Context, folderID *int64, page, pageSize int, search string) ([]model.FileResponse, int64, error)
getFileMetadataFn func(ctx context.Context, fileID int64) (*model.File, *model.FileBlob, error)
downloadFileFn func(ctx context.Context, fileID int64, rangeHeader string) (io.ReadCloser, *model.File, *model.FileBlob, int64, int64, error)
deleteFileFn func(ctx context.Context, fileID int64) error
}
func (m *mockFileService) ListFiles(ctx context.Context, folderID *int64, page, pageSize int, search string) ([]model.FileResponse, int64, error) {
return m.listFilesFn(ctx, folderID, page, pageSize, search)
}
func (m *mockFileService) GetFileMetadata(ctx context.Context, fileID int64) (*model.File, *model.FileBlob, error) {
return m.getFileMetadataFn(ctx, fileID)
}
func (m *mockFileService) DownloadFile(ctx context.Context, fileID int64, rangeHeader string) (io.ReadCloser, *model.File, *model.FileBlob, int64, int64, error) {
return m.downloadFileFn(ctx, fileID, rangeHeader)
}
func (m *mockFileService) DeleteFile(ctx context.Context, fileID int64) error {
return m.deleteFileFn(ctx, fileID)
}
type fileHandlerSetup struct {
handler *FileHandler
mock *mockFileService
router *gin.Engine
}
func newFileHandlerSetup() *fileHandlerSetup {
gin.SetMode(gin.TestMode)
mock := &mockFileService{}
h := &FileHandler{
svc: mock,
logger: zap.NewNop(),
}
r := gin.New()
v1 := r.Group("/api/v1")
files := v1.Group("/files")
files.GET("", h.ListFiles)
files.GET("/:id", h.GetFile)
files.GET("/:id/download", h.DownloadFile)
files.DELETE("/:id", h.DeleteFile)
return &fileHandlerSetup{handler: h, mock: mock, router: r}
}
// ---- ListFiles Tests ----
func TestListFiles_Empty(t *testing.T) {
s := newFileHandlerSetup()
s.mock.listFilesFn = func(ctx context.Context, folderID *int64, page, pageSize int, search string) ([]model.FileResponse, int64, error) {
return []model.FileResponse{}, 0, nil
}
w := httptest.NewRecorder()
req, _ := http.NewRequest(http.MethodGet, "/api/v1/files", nil)
s.router.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("expected 200, got %d: %s", w.Code, w.Body.String())
}
var resp map[string]interface{}
json.Unmarshal(w.Body.Bytes(), &resp)
if !resp["success"].(bool) {
t.Fatal("expected success=true")
}
data := resp["data"].(map[string]interface{})
files := data["files"].([]interface{})
if len(files) != 0 {
t.Fatalf("expected empty files list, got %d", len(files))
}
if data["total"].(float64) != 0 {
t.Fatalf("expected total=0, got %v", data["total"])
}
}
func TestListFiles_WithFiles(t *testing.T) {
s := newFileHandlerSetup()
now := time.Now()
s.mock.listFilesFn = func(ctx context.Context, folderID *int64, page, pageSize int, search string) ([]model.FileResponse, int64, error) {
return []model.FileResponse{
{ID: 1, Name: "a.txt", Size: 100, MimeType: "text/plain", SHA256: "abc123", CreatedAt: now, UpdatedAt: now},
{ID: 2, Name: "b.pdf", Size: 200, MimeType: "application/pdf", SHA256: "def456", CreatedAt: now, UpdatedAt: now},
}, 2, nil
}
w := httptest.NewRecorder()
req, _ := http.NewRequest(http.MethodGet, "/api/v1/files?page=1&page_size=10", nil)
s.router.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("expected 200, got %d: %s", w.Code, w.Body.String())
}
var resp map[string]interface{}
json.Unmarshal(w.Body.Bytes(), &resp)
data := resp["data"].(map[string]interface{})
files := data["files"].([]interface{})
if len(files) != 2 {
t.Fatalf("expected 2 files, got %d", len(files))
}
if data["total"].(float64) != 2 {
t.Fatalf("expected total=2, got %v", data["total"])
}
if data["page"].(float64) != 1 {
t.Fatalf("expected page=1, got %v", data["page"])
}
if data["page_size"].(float64) != 10 {
t.Fatalf("expected page_size=10, got %v", data["page_size"])
}
}
func TestListFiles_WithFolderID(t *testing.T) {
s := newFileHandlerSetup()
var capturedFolderID *int64
s.mock.listFilesFn = func(ctx context.Context, folderID *int64, page, pageSize int, search string) ([]model.FileResponse, int64, error) {
capturedFolderID = folderID
return []model.FileResponse{}, 0, nil
}
w := httptest.NewRecorder()
req, _ := http.NewRequest(http.MethodGet, "/api/v1/files?folder_id=5", nil)
s.router.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("expected 200, got %d: %s", w.Code, w.Body.String())
}
if capturedFolderID == nil || *capturedFolderID != 5 {
t.Fatalf("expected folder_id=5, got %v", capturedFolderID)
}
}
func TestListFiles_ServiceError(t *testing.T) {
s := newFileHandlerSetup()
s.mock.listFilesFn = func(ctx context.Context, folderID *int64, page, pageSize int, search string) ([]model.FileResponse, int64, error) {
return nil, 0, fmt.Errorf("db error")
}
w := httptest.NewRecorder()
req, _ := http.NewRequest(http.MethodGet, "/api/v1/files", nil)
s.router.ServeHTTP(w, req)
if w.Code != http.StatusInternalServerError {
t.Fatalf("expected 500, got %d", w.Code)
}
}
// ---- GetFile Tests ----
func TestGetFile_Found(t *testing.T) {
s := newFileHandlerSetup()
now := time.Now()
s.mock.getFileMetadataFn = func(ctx context.Context, fileID int64) (*model.File, *model.FileBlob, error) {
return &model.File{
ID: 1, Name: "test.txt", BlobSHA256: "abc123", CreatedAt: now, UpdatedAt: now,
}, &model.FileBlob{
ID: 1, SHA256: "abc123", FileSize: 1024, MimeType: "text/plain", CreatedAt: now,
}, nil
}
w := httptest.NewRecorder()
req, _ := http.NewRequest(http.MethodGet, "/api/v1/files/1", nil)
s.router.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("expected 200, got %d: %s", w.Code, w.Body.String())
}
var resp map[string]interface{}
json.Unmarshal(w.Body.Bytes(), &resp)
if !resp["success"].(bool) {
t.Fatal("expected success=true")
}
}
func TestGetFile_NotFound(t *testing.T) {
s := newFileHandlerSetup()
s.mock.getFileMetadataFn = func(ctx context.Context, fileID int64) (*model.File, *model.FileBlob, error) {
return nil, nil, fmt.Errorf("file not found: 999")
}
w := httptest.NewRecorder()
req, _ := http.NewRequest(http.MethodGet, "/api/v1/files/999", nil)
s.router.ServeHTTP(w, req)
if w.Code != http.StatusInternalServerError {
t.Fatalf("expected 500, got %d: %s", w.Code, w.Body.String())
}
}
func TestGetFile_InvalidID(t *testing.T) {
s := newFileHandlerSetup()
w := httptest.NewRecorder()
req, _ := http.NewRequest(http.MethodGet, "/api/v1/files/abc", nil)
s.router.ServeHTTP(w, req)
if w.Code != http.StatusBadRequest {
t.Fatalf("expected 400, got %d", w.Code)
}
}
// ---- DownloadFile Tests ----
func TestDownloadFile_Full(t *testing.T) {
s := newFileHandlerSetup()
content := "hello world file content"
now := time.Now()
s.mock.downloadFileFn = func(ctx context.Context, fileID int64, rangeHeader string) (io.ReadCloser, *model.File, *model.FileBlob, int64, int64, error) {
reader := io.NopCloser(strings.NewReader(content))
return reader,
&model.File{ID: 1, Name: "test.txt", CreatedAt: now},
&model.FileBlob{FileSize: int64(len(content)), MimeType: "text/plain"},
0, int64(len(content)) - 1, nil
}
w := httptest.NewRecorder()
req, _ := http.NewRequest(http.MethodGet, "/api/v1/files/1/download", nil)
s.router.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("expected 200, got %d: %s", w.Code, w.Body.String())
}
// Check streaming headers
if got := w.Header().Get("Content-Disposition"); !strings.Contains(got, "test.txt") {
t.Fatalf("expected Content-Disposition to contain 'test.txt', got %s", got)
}
if got := w.Header().Get("Content-Type"); got != "text/plain" {
t.Fatalf("expected Content-Type=text/plain, got %s", got)
}
if got := w.Header().Get("Accept-Ranges"); got != "bytes" {
t.Fatalf("expected Accept-Ranges=bytes, got %s", got)
}
if w.Body.String() != content {
t.Fatalf("expected body %q, got %q", content, w.Body.String())
}
}
func TestDownloadFile_WithRange(t *testing.T) {
s := newFileHandlerSetup()
content := "hello world"
now := time.Now()
s.mock.downloadFileFn = func(ctx context.Context, fileID int64, rangeHeader string) (io.ReadCloser, *model.File, *model.FileBlob, int64, int64, error) {
reader := io.NopCloser(strings.NewReader(content[0:5]))
return reader,
&model.File{ID: 1, Name: "test.txt", CreatedAt: now},
&model.FileBlob{FileSize: int64(len(content)), MimeType: "text/plain"},
0, 4, nil
}
w := httptest.NewRecorder()
req, _ := http.NewRequest(http.MethodGet, "/api/v1/files/1/download", nil)
req.Header.Set("Range", "bytes=0-4")
s.router.ServeHTTP(w, req)
if w.Code != http.StatusPartialContent {
t.Fatalf("expected 206, got %d: %s", w.Code, w.Body.String())
}
// Check Content-Range header
if got := w.Header().Get("Content-Range"); got != "bytes 0-4/11" {
t.Fatalf("expected Content-Range 'bytes 0-4/11', got %s", got)
}
if got := w.Header().Get("Content-Type"); got != "text/plain" {
t.Fatalf("expected Content-Type=text/plain, got %s", got)
}
if got := w.Header().Get("Content-Length"); got != "5" {
t.Fatalf("expected Content-Length=5, got %s", got)
}
}
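The test above expects `Content-Range: bytes 0-4/11` and `Content-Length: 5` for `Range: bytes=0-4`. The handler delegates the actual parsing to the service, which this diff does not show; a hedged sketch of how a single-range `bytes=start-end` header could map to those offsets (`parseRange` is a hypothetical helper, not this repo's implementation):

```go
package main

import (
	"fmt"
	"strconv"
	"strings"
)

// parseRange is an assumed sketch of single-range parsing. Offsets are
// inclusive, so the served length is end-start+1 and Content-Range reads
// "bytes start-end/size".
func parseRange(header string, size int64) (start, end int64, err error) {
	const prefix = "bytes="
	if !strings.HasPrefix(header, prefix) {
		return 0, 0, fmt.Errorf("malformed range header: %q", header)
	}
	parts := strings.SplitN(strings.TrimPrefix(header, prefix), "-", 2)
	if len(parts) != 2 {
		return 0, 0, fmt.Errorf("malformed range header: %q", header)
	}
	if start, err = strconv.ParseInt(parts[0], 10, 64); err != nil {
		return 0, 0, err
	}
	if parts[1] == "" {
		end = size - 1 // open-ended range: "bytes=N-"
	} else if end, err = strconv.ParseInt(parts[1], 10, 64); err != nil {
		return 0, 0, err
	}
	if start < 0 || start > end || end >= size {
		return 0, 0, fmt.Errorf("range %d-%d out of bounds for size %d", start, end, size)
	}
	return start, end, nil
}

func main() {
	start, end, _ := parseRange("bytes=0-4", 11)
	fmt.Printf("bytes %d-%d/11, length %d\n", start, end, end-start+1)
}
```

With `content = "hello world"` (11 bytes), `bytes=0-4` yields offsets 0 through 4, matching the test's expected `Content-Range` and the 5-byte body.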
func TestDownloadFile_InvalidID(t *testing.T) {
s := newFileHandlerSetup()
w := httptest.NewRecorder()
req, _ := http.NewRequest(http.MethodGet, "/api/v1/files/abc/download", nil)
s.router.ServeHTTP(w, req)
if w.Code != http.StatusBadRequest {
t.Fatalf("expected 400, got %d", w.Code)
}
}
func TestDownloadFile_ServiceError(t *testing.T) {
s := newFileHandlerSetup()
s.mock.downloadFileFn = func(ctx context.Context, fileID int64, rangeHeader string) (io.ReadCloser, *model.File, *model.FileBlob, int64, int64, error) {
return nil, nil, nil, 0, 0, fmt.Errorf("file not found: 999")
}
w := httptest.NewRecorder()
req, _ := http.NewRequest(http.MethodGet, "/api/v1/files/999/download", nil)
s.router.ServeHTTP(w, req)
if w.Code != http.StatusInternalServerError {
t.Fatalf("expected 500, got %d: %s", w.Code, w.Body.String())
}
}
// ---- DeleteFile Tests ----
func TestDeleteFile_Success(t *testing.T) {
s := newFileHandlerSetup()
s.mock.deleteFileFn = func(ctx context.Context, fileID int64) error {
return nil
}
w := httptest.NewRecorder()
req, _ := http.NewRequest(http.MethodDelete, "/api/v1/files/1", nil)
s.router.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("expected 200, got %d: %s", w.Code, w.Body.String())
}
var resp map[string]interface{}
json.Unmarshal(w.Body.Bytes(), &resp)
if !resp["success"].(bool) {
t.Fatal("expected success=true")
}
data := resp["data"].(map[string]interface{})
if data["message"] != "file deleted" {
t.Fatalf("expected message 'file deleted', got %v", data["message"])
}
}
func TestDeleteFile_InvalidID(t *testing.T) {
s := newFileHandlerSetup()
w := httptest.NewRecorder()
req, _ := http.NewRequest(http.MethodDelete, "/api/v1/files/abc", nil)
s.router.ServeHTTP(w, req)
if w.Code != http.StatusBadRequest {
t.Fatalf("expected 400, got %d", w.Code)
}
}
func TestDeleteFile_ServiceError(t *testing.T) {
s := newFileHandlerSetup()
s.mock.deleteFileFn = func(ctx context.Context, fileID int64) error {
return fmt.Errorf("file not found: 999")
}
w := httptest.NewRecorder()
req, _ := http.NewRequest(http.MethodDelete, "/api/v1/files/999", nil)
s.router.ServeHTTP(w, req)
if w.Code != http.StatusInternalServerError {
t.Fatalf("expected 500, got %d: %s", w.Code, w.Body.String())
}
}
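
The range-download test above expects `Content-Range: bytes 0-4/11` and `Content-Length: 5` for an 11-byte file. As a standalone sketch of that arithmetic (the helper below is illustrative, not the handler's actual parser):

```go
package main

import (
	"fmt"
	"strconv"
	"strings"
)

// parseRange parses a single-range header like "bytes=0-4" against a file of
// totalSize bytes, returning the start/end offsets and the Content-Range
// header value. Illustrative only: the real handler's parser may differ.
func parseRange(header string, totalSize int64) (start, end int64, contentRange string, err error) {
	if !strings.HasPrefix(header, "bytes=") {
		return 0, 0, "", fmt.Errorf("unsupported range unit")
	}
	parts := strings.SplitN(strings.TrimPrefix(header, "bytes="), "-", 2)
	if len(parts) != 2 {
		return 0, 0, "", fmt.Errorf("malformed range")
	}
	if start, err = strconv.ParseInt(parts[0], 10, 64); err != nil {
		return 0, 0, "", err
	}
	if end, err = strconv.ParseInt(parts[1], 10, 64); err != nil {
		return 0, 0, "", err
	}
	if start > end || end >= totalSize {
		return 0, 0, "", fmt.Errorf("range out of bounds")
	}
	return start, end, fmt.Sprintf("bytes %d-%d/%d", start, end, totalSize), nil
}

func main() {
	start, end, cr, err := parseRange("bytes=0-4", 11)
	if err != nil {
		panic(err)
	}
	// Content-Length for a satisfied range is end-start+1 bytes.
	fmt.Printf("%s length=%d\n", cr, end-start+1) // bytes 0-4/11 length=5
	_ = start
}
```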

View File

@@ -0,0 +1,133 @@
package handler
import (
"context"
"strconv"
"strings"
"gcy_hpc_server/internal/model"
"gcy_hpc_server/internal/server"
"gcy_hpc_server/internal/service"
"github.com/gin-gonic/gin"
"go.uber.org/zap"
)
type folderServiceProvider interface {
CreateFolder(ctx context.Context, name string, parentID *int64) (*model.FolderResponse, error)
GetFolder(ctx context.Context, id int64) (*model.FolderResponse, error)
ListFolders(ctx context.Context, parentID *int64) ([]model.FolderResponse, error)
DeleteFolder(ctx context.Context, id int64) error
}
type FolderHandler struct {
svc folderServiceProvider
logger *zap.Logger
}
func NewFolderHandler(svc *service.FolderService, logger *zap.Logger) *FolderHandler {
return &FolderHandler{svc: svc, logger: logger}
}
func (h *FolderHandler) CreateFolder(c *gin.Context) {
var req model.CreateFolderRequest
if err := c.ShouldBindJSON(&req); err != nil {
h.logger.Warn("invalid request body for create folder", zap.Error(err))
server.BadRequest(c, "invalid request body")
return
}
if req.Name == "" {
h.logger.Warn("missing folder name")
server.BadRequest(c, "name is required")
return
}
resp, err := h.svc.CreateFolder(c.Request.Context(), req.Name, req.ParentID)
if err != nil {
errStr := err.Error()
if strings.Contains(errStr, "invalid folder name") || strings.Contains(errStr, "cannot be") {
h.logger.Warn("invalid folder name", zap.String("name", req.Name), zap.Error(err))
server.BadRequest(c, errStr)
return
}
h.logger.Error("failed to create folder", zap.Error(err))
server.InternalError(c, errStr)
return
}
h.logger.Info("folder created", zap.Int64("id", resp.ID))
server.Created(c, resp)
}
func (h *FolderHandler) GetFolder(c *gin.Context) {
id, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
h.logger.Warn("invalid folder id", zap.String("id", c.Param("id")))
server.BadRequest(c, "invalid id")
return
}
resp, err := h.svc.GetFolder(c.Request.Context(), id)
if err != nil {
if strings.Contains(err.Error(), "not found") {
h.logger.Warn("folder not found", zap.Int64("id", id))
server.NotFound(c, "folder not found")
return
}
h.logger.Error("failed to get folder", zap.Int64("id", id), zap.Error(err))
server.InternalError(c, err.Error())
return
}
server.OK(c, resp)
}
func (h *FolderHandler) ListFolders(c *gin.Context) {
var parentID *int64
if q := c.Query("parent_id"); q != "" {
pid, err := strconv.ParseInt(q, 10, 64)
if err != nil {
h.logger.Warn("invalid parent_id query param", zap.String("parent_id", q))
server.BadRequest(c, "invalid parent_id")
return
}
parentID = &pid
}
folders, err := h.svc.ListFolders(c.Request.Context(), parentID)
if err != nil {
h.logger.Error("failed to list folders", zap.Error(err))
server.InternalError(c, err.Error())
return
}
if folders == nil {
folders = []model.FolderResponse{}
}
server.OK(c, folders)
}
func (h *FolderHandler) DeleteFolder(c *gin.Context) {
id, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
h.logger.Warn("invalid folder id for delete", zap.String("id", c.Param("id")))
server.BadRequest(c, "invalid id")
return
}
if err := h.svc.DeleteFolder(c.Request.Context(), id); err != nil {
errStr := err.Error()
if strings.Contains(errStr, "not empty") {
h.logger.Warn("cannot delete non-empty folder", zap.Int64("id", id))
server.BadRequest(c, "folder is not empty")
return
}
if strings.Contains(errStr, "not found") {
h.logger.Warn("folder not found for delete", zap.Int64("id", id))
server.NotFound(c, "folder not found")
return
}
h.logger.Error("failed to delete folder", zap.Int64("id", id), zap.Error(err))
server.InternalError(c, errStr)
return
}
h.logger.Info("folder deleted", zap.Int64("id", id))
server.OK(c, gin.H{"message": "folder deleted"})
}

View File

@@ -0,0 +1,206 @@
package handler
import (
"bytes"
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"gcy_hpc_server/internal/model"
"gcy_hpc_server/internal/service"
"gcy_hpc_server/internal/store"
"github.com/gin-gonic/gin"
"go.uber.org/zap"
"gorm.io/driver/sqlite"
"gorm.io/gorm"
gormlogger "gorm.io/gorm/logger"
)
func setupFolderHandler() (*FolderHandler, *gorm.DB) {
db, _ := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{Logger: gormlogger.Default.LogMode(gormlogger.Silent)})
db.AutoMigrate(&model.Folder{}, &model.File{})
folderStore := store.NewFolderStore(db)
fileStore := store.NewFileStore(db)
svc := service.NewFolderService(folderStore, fileStore, zap.NewNop())
h := NewFolderHandler(svc, zap.NewNop())
return h, db
}
func setupFolderRouter(h *FolderHandler) *gin.Engine {
gin.SetMode(gin.TestMode)
r := gin.New()
v1 := r.Group("/api/v1")
folders := v1.Group("/files/folders")
folders.POST("", h.CreateFolder)
folders.GET("/:id", h.GetFolder)
folders.GET("", h.ListFolders)
folders.DELETE("/:id", h.DeleteFolder)
return r
}
func createTestFolder(h *FolderHandler, r *gin.Engine) int64 {
body, _ := json.Marshal(model.CreateFolderRequest{Name: "test-folder"})
w := httptest.NewRecorder()
req, _ := http.NewRequest(http.MethodPost, "/api/v1/files/folders", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
r.ServeHTTP(w, req)
var resp map[string]interface{}
json.Unmarshal(w.Body.Bytes(), &resp)
data := resp["data"].(map[string]interface{})
return int64(data["id"].(float64))
}
func TestCreateFolder_Success(t *testing.T) {
h, _ := setupFolderHandler()
r := setupFolderRouter(h)
body, _ := json.Marshal(model.CreateFolderRequest{Name: "my-folder"})
w := httptest.NewRecorder()
req, _ := http.NewRequest(http.MethodPost, "/api/v1/files/folders", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
r.ServeHTTP(w, req)
if w.Code != http.StatusCreated {
t.Fatalf("expected 201, got %d: %s", w.Code, w.Body.String())
}
var resp map[string]interface{}
json.Unmarshal(w.Body.Bytes(), &resp)
if !resp["success"].(bool) {
t.Fatal("expected success=true")
}
data := resp["data"].(map[string]interface{})
if data["name"] != "my-folder" {
t.Fatalf("expected name=my-folder, got %v", data["name"])
}
if data["path"] != "/my-folder/" {
t.Fatalf("expected path=/my-folder/, got %v", data["path"])
}
}
func TestCreateFolder_PathTraversal(t *testing.T) {
h, _ := setupFolderHandler()
r := setupFolderRouter(h)
body, _ := json.Marshal(model.CreateFolderRequest{Name: ".."})
w := httptest.NewRecorder()
req, _ := http.NewRequest(http.MethodPost, "/api/v1/files/folders", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
r.ServeHTTP(w, req)
if w.Code != http.StatusBadRequest {
t.Fatalf("expected 400, got %d: %s", w.Code, w.Body.String())
}
}
func TestCreateFolder_MissingName(t *testing.T) {
h, _ := setupFolderHandler()
r := setupFolderRouter(h)
body, _ := json.Marshal(map[string]interface{}{})
w := httptest.NewRecorder()
req, _ := http.NewRequest(http.MethodPost, "/api/v1/files/folders", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
r.ServeHTTP(w, req)
if w.Code != http.StatusBadRequest {
t.Fatalf("expected 400, got %d: %s", w.Code, w.Body.String())
}
}
func TestGetFolder_Success(t *testing.T) {
h, _ := setupFolderHandler()
r := setupFolderRouter(h)
id := createTestFolder(h, r)
w := httptest.NewRecorder()
req, _ := http.NewRequest(http.MethodGet, "/api/v1/files/folders/"+itoa(id), nil)
r.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("expected 200, got %d: %s", w.Code, w.Body.String())
}
var resp map[string]interface{}
json.Unmarshal(w.Body.Bytes(), &resp)
data := resp["data"].(map[string]interface{})
if data["name"] != "test-folder" {
t.Fatalf("expected name=test-folder, got %v", data["name"])
}
}
func TestGetFolder_NotFound(t *testing.T) {
h, _ := setupFolderHandler()
r := setupFolderRouter(h)
w := httptest.NewRecorder()
req, _ := http.NewRequest(http.MethodGet, "/api/v1/files/folders/999", nil)
r.ServeHTTP(w, req)
if w.Code != http.StatusNotFound {
t.Fatalf("expected 404, got %d: %s", w.Code, w.Body.String())
}
}
func TestListFolders_Success(t *testing.T) {
h, _ := setupFolderHandler()
r := setupFolderRouter(h)
createTestFolder(h, r)
w := httptest.NewRecorder()
req, _ := http.NewRequest(http.MethodGet, "/api/v1/files/folders", nil)
r.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("expected 200, got %d: %s", w.Code, w.Body.String())
}
var resp map[string]interface{}
json.Unmarshal(w.Body.Bytes(), &resp)
if !resp["success"].(bool) {
t.Fatal("expected success=true")
}
data := resp["data"].([]interface{})
if len(data) < 1 {
t.Fatal("expected at least 1 folder")
}
}
func TestDeleteFolder_Success(t *testing.T) {
h, _ := setupFolderHandler()
r := setupFolderRouter(h)
id := createTestFolder(h, r)
w := httptest.NewRecorder()
req, _ := http.NewRequest(http.MethodDelete, "/api/v1/files/folders/"+itoa(id), nil)
r.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("expected 200, got %d: %s", w.Code, w.Body.String())
}
}
func TestDeleteFolder_NonEmpty(t *testing.T) {
h, db := setupFolderHandler()
r := setupFolderRouter(h)
parentID := createTestFolder(h, r)
child := &model.Folder{
Name: "child-folder",
ParentID: &parentID,
Path: "/test-folder/child-folder/",
}
db.Create(child)
w := httptest.NewRecorder()
req, _ := http.NewRequest(http.MethodDelete, "/api/v1/files/folders/"+itoa(parentID), nil)
r.ServeHTTP(w, req)
if w.Code != http.StatusBadRequest {
t.Fatalf("expected 400, got %d: %s", w.Code, w.Body.String())
}
}

View File

@@ -46,16 +46,30 @@ func (h *JobHandler) SubmitJob(c *gin.Context) {
 	server.Created(c, resp)
 }
-// GetJobs handles GET /api/v1/jobs.
+// GetJobs handles GET /api/v1/jobs with pagination.
 func (h *JobHandler) GetJobs(c *gin.Context) {
-	jobs, err := h.jobSvc.GetJobs(c.Request.Context())
+	var query model.JobListQuery
+	if err := c.ShouldBindQuery(&query); err != nil {
+		h.logger.Warn("bad request", zap.String("method", "GetJobs"), zap.String("error", "invalid query params"))
+		server.BadRequest(c, "invalid query params")
+		return
+	}
+	if query.Page < 1 {
+		query.Page = 1
+	}
+	if query.PageSize < 1 {
+		query.PageSize = 20
+	}
+	resp, err := h.jobSvc.GetJobs(c.Request.Context(), &query)
 	if err != nil {
 		h.logger.Error("handler error", zap.String("method", "GetJobs"), zap.Int("status", http.StatusInternalServerError), zap.Error(err))
 		server.InternalError(c, err.Error())
 		return
 	}
-	server.OK(c, jobs)
+	server.OK(c, resp)
 }
 // GetJob handles GET /api/v1/jobs/:id.

View File

@@ -188,7 +188,7 @@ func TestGetJobs_Success(t *testing.T) {
 	router := setupJobRouter(handler)
-	req := httptest.NewRequest(http.MethodGet, "/api/v1/jobs", nil)
+	req := httptest.NewRequest(http.MethodGet, "/api/v1/jobs?page=1&page_size=10", nil)
 	w := httptest.NewRecorder()
 	router.ServeHTTP(w, req)
@@ -202,6 +202,93 @@ func TestGetJobs_Success(t *testing.T) {
if !resp["success"].(bool) {
t.Fatal("expected success=true")
}
data := resp["data"].(map[string]interface{})
jobs := data["jobs"].([]interface{})
if len(jobs) != 2 {
t.Fatalf("expected 2 jobs, got %d", len(jobs))
}
if int(data["total"].(float64)) != 2 {
t.Errorf("expected total=2, got %v", data["total"])
}
if int(data["page"].(float64)) != 1 {
t.Errorf("expected page=1, got %v", data["page"])
}
if int(data["page_size"].(float64)) != 10 {
t.Errorf("expected page_size=10, got %v", data["page_size"])
}
}
func TestGetJobs_Pagination(t *testing.T) {
mux := http.NewServeMux()
mux.HandleFunc("/slurm/v0.0.40/jobs", func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(slurm.OpenapiJobInfoResp{
Jobs: []slurm.JobInfo{
{JobID: slurm.Ptr(int32(1)), Name: slurm.Ptr("job1")},
{JobID: slurm.Ptr(int32(2)), Name: slurm.Ptr("job2")},
{JobID: slurm.Ptr(int32(3)), Name: slurm.Ptr("job3")},
},
})
})
srv, handler := setupJobHandler(mux)
defer srv.Close()
router := setupJobRouter(handler)
req := httptest.NewRequest(http.MethodGet, "/api/v1/jobs?page=2&page_size=1", nil)
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("expected 200, got %d: %s", w.Code, w.Body.String())
}
var resp map[string]interface{}
json.Unmarshal(w.Body.Bytes(), &resp)
data := resp["data"].(map[string]interface{})
jobs := data["jobs"].([]interface{})
if len(jobs) != 1 {
t.Fatalf("expected 1 job on page 2, got %d", len(jobs))
}
if int(data["total"].(float64)) != 3 {
t.Errorf("expected total=3, got %v", data["total"])
}
jobData := jobs[0].(map[string]interface{})
if int(jobData["job_id"].(float64)) != 2 {
t.Errorf("expected job_id=2 on page 2, got %v", jobData["job_id"])
}
}
func TestGetJobs_DefaultPagination(t *testing.T) {
mux := http.NewServeMux()
mux.HandleFunc("/slurm/v0.0.40/jobs", func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(slurm.OpenapiJobInfoResp{Jobs: []slurm.JobInfo{}})
})
srv, handler := setupJobHandler(mux)
defer srv.Close()
router := setupJobRouter(handler)
req := httptest.NewRequest(http.MethodGet, "/api/v1/jobs", nil)
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("expected 200, got %d: %s", w.Code, w.Body.String())
}
var resp map[string]interface{}
json.Unmarshal(w.Body.Bytes(), &resp)
data := resp["data"].(map[string]interface{})
if int(data["page"].(float64)) != 1 {
t.Errorf("expected default page=1, got %v", data["page"])
}
if int(data["page_size"].(float64)) != 20 {
t.Errorf("expected default page_size=20, got %v", data["page_size"])
}
}
func TestGetJob_Success(t *testing.T) {

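TestGetJobs_Pagination expects page 2 with page_size 1 over three jobs to return only job 2, and TestGetJobs_Success expects total to report the full count. The slicing that behaviour implies can be sketched as follows (illustrative; the service's real implementation may differ):

```go
package main

import "fmt"

// paginate returns the items for a 1-based page of size pageSize, clamping at
// the end of the slice and returning an empty slice for out-of-range pages.
func paginate[T any](items []T, page, pageSize int) []T {
	start := (page - 1) * pageSize
	if start >= len(items) {
		return []T{}
	}
	end := start + pageSize
	if end > len(items) {
		end = len(items)
	}
	return items[start:end]
}

func main() {
	jobs := []int{1, 2, 3}
	fmt.Println(paginate(jobs, 2, 1)) // [2]  (second job on page 2, size 1)
	fmt.Println(paginate(jobs, 1, 20)) // [1 2 3]
}
```

Note that `total` in the response reflects `len(items)` before slicing, which is why page 2 of size 1 still reports total=3 in the test.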
View File

@@ -0,0 +1,98 @@
package handler
import (
"context"
"strings"
"gcy_hpc_server/internal/model"
"gcy_hpc_server/internal/server"
"github.com/gin-gonic/gin"
"go.uber.org/zap"
)
type taskServiceProvider interface {
SubmitAsync(ctx context.Context, req *model.CreateTaskRequest) (int64, error)
ListTasks(ctx context.Context, query *model.TaskListQuery) ([]model.Task, int64, error)
}
type TaskHandler struct {
svc taskServiceProvider
logger *zap.Logger
}
func NewTaskHandler(svc taskServiceProvider, logger *zap.Logger) *TaskHandler {
return &TaskHandler{svc: svc, logger: logger}
}
func (h *TaskHandler) CreateTask(c *gin.Context) {
var req model.CreateTaskRequest
if err := c.ShouldBindJSON(&req); err != nil {
h.logger.Warn("invalid request body for create task", zap.Error(err))
server.BadRequest(c, err.Error())
return
}
taskID, err := h.svc.SubmitAsync(c.Request.Context(), &req)
if err != nil {
errStr := err.Error()
if strings.Contains(errStr, "not found") {
h.logger.Warn("task submit target not found", zap.Error(err))
server.NotFound(c, errStr)
return
}
if strings.Contains(errStr, "exceeds limit") || strings.Contains(errStr, "validation") {
h.logger.Warn("task submit validation failed", zap.Error(err))
server.BadRequest(c, errStr)
return
}
h.logger.Error("failed to create task", zap.Error(err))
server.InternalError(c, errStr)
return
}
h.logger.Info("task created", zap.Int64("id", taskID))
server.Created(c, gin.H{"id": taskID})
}
func (h *TaskHandler) ListTasks(c *gin.Context) {
var query model.TaskListQuery
_ = c.ShouldBindQuery(&query)
if query.Page < 1 {
query.Page = 1
}
if query.PageSize < 1 || query.PageSize > 100 {
query.PageSize = 10
}
tasks, total, err := h.svc.ListTasks(c.Request.Context(), &query)
if err != nil {
h.logger.Error("failed to list tasks", zap.Error(err))
server.InternalError(c, err.Error())
return
}
responses := make([]model.TaskResponse, 0, len(tasks))
for i := range tasks {
responses = append(responses, model.TaskResponse{
ID: tasks[i].ID,
TaskName: tasks[i].TaskName,
AppID: tasks[i].AppID,
AppName: tasks[i].AppName,
Status: tasks[i].Status,
CurrentStep: tasks[i].CurrentStep,
RetryCount: tasks[i].RetryCount,
SlurmJobID: tasks[i].SlurmJobID,
WorkDir: tasks[i].WorkDir,
ErrorMessage: tasks[i].ErrorMessage,
CreatedAt: tasks[i].CreatedAt,
UpdatedAt: tasks[i].UpdatedAt,
})
}
server.OK(c, model.TaskListResponse{
Items: responses,
Total: total,
})
}

View File

@@ -0,0 +1,286 @@
package handler
import (
"bytes"
"context"
"encoding/json"
"fmt"
"net/http"
"net/http/httptest"
"os"
"path/filepath"
"sync/atomic"
"testing"
"time"
"gcy_hpc_server/internal/model"
"gcy_hpc_server/internal/service"
"gcy_hpc_server/internal/slurm"
"gcy_hpc_server/internal/store"
"github.com/gin-gonic/gin"
"go.uber.org/zap"
"gorm.io/driver/sqlite"
"gorm.io/gorm"
gormlogger "gorm.io/gorm/logger"
)
var taskDBCounter atomic.Int64
func setupTaskHandler(t *testing.T, slurmSrv *httptest.Server) (*TaskHandler, *gorm.DB) {
t.Helper()
dbFile := filepath.Join(t.TempDir(), fmt.Sprintf("test-%d.db", taskDBCounter.Add(1)))
db, err := gorm.Open(sqlite.Open(dbFile), &gorm.Config{Logger: gormlogger.Default.LogMode(gormlogger.Silent)})
if err != nil {
t.Fatalf("open db: %v", err)
}
db.AutoMigrate(&model.Task{}, &model.Application{})
t.Cleanup(func() { os.Remove(dbFile) })
taskStore := store.NewTaskStore(db)
appStore := store.NewApplicationStore(db)
var jobSvc *service.JobService
if slurmSrv != nil {
client, _ := slurm.NewClient(slurmSrv.URL, slurmSrv.Client())
jobSvc = service.NewJobService(client, zap.NewNop())
}
workDir := filepath.Join(t.TempDir(), "work")
taskSvc := service.NewTaskService(taskStore, appStore, nil, nil, nil, jobSvc, workDir, zap.NewNop())
h := NewTaskHandler(taskSvc, zap.NewNop())
return h, db
}
func setupTaskRouter(h *TaskHandler) *gin.Engine {
gin.SetMode(gin.TestMode)
r := gin.New()
v1 := r.Group("/api/v1")
tasks := v1.Group("/tasks")
tasks.POST("", h.CreateTask)
tasks.GET("", h.ListTasks)
return r
}
func createTestAppForTask(db *gorm.DB) int64 {
app := &model.Application{
Name: "test-app",
ScriptTemplate: "#!/bin/bash\necho hello",
Parameters: json.RawMessage(`[]`),
}
db.Create(app)
return app.ID
}
// ---- CreateTask Tests ----
func TestTaskHandler_CreateTask_Success(t *testing.T) {
slurmSrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusCreated)
json.NewEncoder(w).Encode(map[string]interface{}{"job_id": 12345})
}))
defer slurmSrv.Close()
h, db := setupTaskHandler(t, slurmSrv)
r := setupTaskRouter(h)
appID := createTestAppForTask(db)
taskSvc := h.svc.(*service.TaskService)
ctx := context.Background()
taskSvc.StartProcessor(ctx)
defer taskSvc.StopProcessor()
body, _ := json.Marshal(model.CreateTaskRequest{
AppID: appID,
TaskName: "my-task",
})
w := httptest.NewRecorder()
req, _ := http.NewRequest(http.MethodPost, "/api/v1/tasks", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
r.ServeHTTP(w, req)
if w.Code != http.StatusCreated {
t.Fatalf("expected 201, got %d: %s", w.Code, w.Body.String())
}
var resp map[string]interface{}
json.Unmarshal(w.Body.Bytes(), &resp)
if !resp["success"].(bool) {
t.Fatal("expected success=true")
}
data := resp["data"].(map[string]interface{})
if _, ok := data["id"]; !ok {
t.Fatal("expected id in response data")
}
}
func TestTaskHandler_CreateTask_MissingAppID(t *testing.T) {
h, _ := setupTaskHandler(t, nil)
r := setupTaskRouter(h)
body := `{"task_name":"no-app"}`
w := httptest.NewRecorder()
req, _ := http.NewRequest(http.MethodPost, "/api/v1/tasks", bytes.NewReader([]byte(body)))
req.Header.Set("Content-Type", "application/json")
r.ServeHTTP(w, req)
if w.Code != http.StatusBadRequest {
t.Fatalf("expected 400, got %d: %s", w.Code, w.Body.String())
}
}
func TestTaskHandler_CreateTask_InvalidJSON(t *testing.T) {
h, _ := setupTaskHandler(t, nil)
r := setupTaskRouter(h)
w := httptest.NewRecorder()
req, _ := http.NewRequest(http.MethodPost, "/api/v1/tasks", bytes.NewReader([]byte("not-json")))
req.Header.Set("Content-Type", "application/json")
r.ServeHTTP(w, req)
if w.Code != http.StatusBadRequest {
t.Fatalf("expected 400, got %d: %s", w.Code, w.Body.String())
}
}
// ---- ListTasks Tests ----
func TestTaskHandler_ListTasks_Pagination(t *testing.T) {
slurmSrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusCreated)
json.NewEncoder(w).Encode(map[string]interface{}{"job_id": int32(100)})
}))
defer slurmSrv.Close()
h, db := setupTaskHandler(t, slurmSrv)
r := setupTaskRouter(h)
appID := createTestAppForTask(db)
taskSvc := h.svc.(*service.TaskService)
ctx := context.Background()
taskSvc.StartProcessor(ctx)
defer taskSvc.StopProcessor()
for i := 0; i < 5; i++ {
body, _ := json.Marshal(model.CreateTaskRequest{
AppID: appID,
TaskName: fmt.Sprintf("task-%d", i),
})
w := httptest.NewRecorder()
req, _ := http.NewRequest(http.MethodPost, "/api/v1/tasks", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
r.ServeHTTP(w, req)
}
// Wait for async processing
time.Sleep(200 * time.Millisecond)
w := httptest.NewRecorder()
req, _ := http.NewRequest(http.MethodGet, "/api/v1/tasks?page=1&page_size=3", nil)
r.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("expected 200, got %d: %s", w.Code, w.Body.String())
}
var resp map[string]interface{}
json.Unmarshal(w.Body.Bytes(), &resp)
data := resp["data"].(map[string]interface{})
if data["total"].(float64) != 5 {
t.Fatalf("expected total=5, got %v", data["total"])
}
items := data["items"].([]interface{})
if len(items) != 3 {
t.Fatalf("expected 3 items, got %d", len(items))
}
}
func TestTaskHandler_ListTasks_StatusFilter(t *testing.T) {
slurmSrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusCreated)
json.NewEncoder(w).Encode(map[string]interface{}{"job_id": int32(200)})
}))
defer slurmSrv.Close()
h, db := setupTaskHandler(t, slurmSrv)
r := setupTaskRouter(h)
appID := createTestAppForTask(db)
taskSvc := h.svc.(*service.TaskService)
ctx := context.Background()
taskSvc.StartProcessor(ctx)
defer taskSvc.StopProcessor()
for i := 0; i < 3; i++ {
body, _ := json.Marshal(model.CreateTaskRequest{
AppID: appID,
TaskName: fmt.Sprintf("filter-task-%d", i),
})
w := httptest.NewRecorder()
req, _ := http.NewRequest(http.MethodPost, "/api/v1/tasks", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
r.ServeHTTP(w, req)
}
// Wait for async processing
time.Sleep(200 * time.Millisecond)
w := httptest.NewRecorder()
req, _ := http.NewRequest(http.MethodGet, "/api/v1/tasks?status=queued", nil)
r.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("expected 200, got %d: %s", w.Code, w.Body.String())
}
var resp map[string]interface{}
json.Unmarshal(w.Body.Bytes(), &resp)
data := resp["data"].(map[string]interface{})
items := data["items"].([]interface{})
for _, item := range items {
m := item.(map[string]interface{})
if m["status"] != "queued" {
t.Fatalf("expected status=queued, got %v", m["status"])
}
}
}
func TestTaskHandler_ListTasks_DefaultPagination(t *testing.T) {
h, db := setupTaskHandler(t, nil)
r := setupTaskRouter(h)
_ = createTestAppForTask(db)
// Directly insert tasks via DB to avoid needing processor
for i := 0; i < 15; i++ {
task := &model.Task{
TaskName: fmt.Sprintf("default-task-%d", i),
AppID: 1,
AppName: "test-app",
Status: model.TaskStatusSubmitted,
SubmittedAt: time.Now(),
}
db.Create(task)
}
w := httptest.NewRecorder()
req, _ := http.NewRequest(http.MethodGet, "/api/v1/tasks", nil)
r.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("expected 200, got %d: %s", w.Code, w.Body.String())
}
var resp map[string]interface{}
json.Unmarshal(w.Body.Bytes(), &resp)
data := resp["data"].(map[string]interface{})
if data["total"].(float64) != 15 {
t.Fatalf("expected total=15, got %v", data["total"])
}
items := data["items"].([]interface{})
if len(items) != 10 {
t.Fatalf("expected 10 items (default page_size), got %d", len(items))
}
}

View File

@@ -1,146 +0,0 @@
package handler
import (
"errors"
"strconv"
"gcy_hpc_server/internal/model"
"gcy_hpc_server/internal/server"
"gcy_hpc_server/internal/store"
"github.com/gin-gonic/gin"
"go.uber.org/zap"
"gorm.io/gorm"
)
type TemplateHandler struct {
store *store.TemplateStore
logger *zap.Logger
}
func NewTemplateHandler(s *store.TemplateStore, logger *zap.Logger) *TemplateHandler {
return &TemplateHandler{store: s, logger: logger}
}
func (h *TemplateHandler) ListTemplates(c *gin.Context) {
page, _ := strconv.Atoi(c.DefaultQuery("page", "1"))
pageSize, _ := strconv.Atoi(c.DefaultQuery("page_size", "20"))
if page < 1 {
page = 1
}
if pageSize < 1 {
pageSize = 20
}
templates, total, err := h.store.List(c.Request.Context(), page, pageSize)
if err != nil {
h.logger.Error("failed to list templates", zap.Error(err))
server.InternalError(c, err.Error())
return
}
server.OK(c, gin.H{
"templates": templates,
"total": total,
"page": page,
"page_size": pageSize,
})
}
func (h *TemplateHandler) CreateTemplate(c *gin.Context) {
var req model.CreateTemplateRequest
if err := c.ShouldBindJSON(&req); err != nil {
h.logger.Warn("invalid request body for create template", zap.Error(err))
server.BadRequest(c, "invalid request body")
return
}
if req.Name == "" || req.Script == "" {
h.logger.Warn("missing required fields for create template")
server.BadRequest(c, "name and script are required")
return
}
id, err := h.store.Create(c.Request.Context(), &req)
if err != nil {
h.logger.Error("failed to create template", zap.Error(err))
server.InternalError(c, err.Error())
return
}
h.logger.Info("template created", zap.Int64("id", id))
server.Created(c, gin.H{"id": id})
}
func (h *TemplateHandler) GetTemplate(c *gin.Context) {
id, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
h.logger.Warn("invalid template id", zap.String("id", c.Param("id")))
server.BadRequest(c, "invalid id")
return
}
tmpl, err := h.store.GetByID(c.Request.Context(), id)
if err != nil {
h.logger.Error("failed to get template", zap.Int64("id", id), zap.Error(err))
server.InternalError(c, err.Error())
return
}
if tmpl == nil {
h.logger.Warn("template not found", zap.Int64("id", id))
server.NotFound(c, "template not found")
return
}
server.OK(c, tmpl)
}
func (h *TemplateHandler) UpdateTemplate(c *gin.Context) {
id, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
h.logger.Warn("invalid template id for update", zap.String("id", c.Param("id")))
server.BadRequest(c, "invalid id")
return
}
var req model.UpdateTemplateRequest
if err := c.ShouldBindJSON(&req); err != nil {
h.logger.Warn("invalid request body for update template", zap.Int64("id", id), zap.Error(err))
server.BadRequest(c, "invalid request body")
return
}
if err := h.store.Update(c.Request.Context(), id, &req); err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
h.logger.Warn("template not found for update", zap.Int64("id", id))
server.NotFound(c, "template not found")
return
}
h.logger.Error("failed to update template", zap.Int64("id", id), zap.Error(err))
server.InternalError(c, "failed to update template")
return
}
h.logger.Info("template updated", zap.Int64("id", id))
server.OK(c, gin.H{"message": "template updated"})
}
func (h *TemplateHandler) DeleteTemplate(c *gin.Context) {
id, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
h.logger.Warn("invalid template id for delete", zap.String("id", c.Param("id")))
server.BadRequest(c, "invalid id")
return
}
if err := h.store.Delete(c.Request.Context(), id); err != nil {
h.logger.Error("failed to delete template", zap.Int64("id", id), zap.Error(err))
server.InternalError(c, err.Error())
return
}
h.logger.Info("template deleted", zap.Int64("id", id))
server.OK(c, gin.H{"message": "template deleted"})
}

View File

@@ -1,387 +0,0 @@
package handler
import (
"bytes"
"context"
"encoding/json"
"fmt"
"net/http"
"net/http/httptest"
"strings"
"testing"
"gcy_hpc_server/internal/model"
"gcy_hpc_server/internal/server"
"gcy_hpc_server/internal/store"
"github.com/gin-gonic/gin"
"go.uber.org/zap"
"go.uber.org/zap/zapcore"
"go.uber.org/zap/zaptest/observer"
"gorm.io/driver/sqlite"
"gorm.io/gorm"
"gorm.io/gorm/logger"
)
func setupTemplateHandler() (*TemplateHandler, *gorm.DB) {
db, _ := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{Logger: logger.Default.LogMode(logger.Silent)})
db.AutoMigrate(&model.JobTemplate{})
s := store.NewTemplateStore(db)
h := NewTemplateHandler(s, zap.NewNop())
return h, db
}
func setupTemplateRouter(h *TemplateHandler) *gin.Engine {
gin.SetMode(gin.TestMode)
r := gin.New()
v1 := r.Group("/api/v1")
templates := v1.Group("/templates")
templates.GET("", h.ListTemplates)
templates.POST("", h.CreateTemplate)
templates.GET("/:id", h.GetTemplate)
templates.PUT("/:id", h.UpdateTemplate)
templates.DELETE("/:id", h.DeleteTemplate)
return r
}
func TestListTemplates_Success(t *testing.T) {
h, _ := setupTemplateHandler()
r := setupTemplateRouter(h)
// Seed data
h.store.Create(context.Background(), &model.CreateTemplateRequest{Name: "test-tpl", Script: "echo hi", Partition: "normal", QOS: "high", CPUs: 4, Memory: "4GB"})
w := httptest.NewRecorder()
req, _ := http.NewRequest("GET", "/api/v1/templates", nil)
r.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("expected 200, got %d: %s", w.Code, w.Body.String())
}
var resp server.APIResponse
json.Unmarshal(w.Body.Bytes(), &resp)
if !resp.Success {
t.Fatalf("expected success=true, got: %s", w.Body.String())
}
}
func TestCreateTemplate_Success(t *testing.T) {
h, _ := setupTemplateHandler()
r := setupTemplateRouter(h)
body := `{"name":"my-tpl","description":"desc","script":"echo hello","partition":"gpu"}`
w := httptest.NewRecorder()
req, _ := http.NewRequest("POST", "/api/v1/templates", bytes.NewBufferString(body))
req.Header.Set("Content-Type", "application/json")
r.ServeHTTP(w, req)
if w.Code != http.StatusCreated {
t.Fatalf("expected 201, got %d: %s", w.Code, w.Body.String())
}
var resp server.APIResponse
json.Unmarshal(w.Body.Bytes(), &resp)
if !resp.Success {
t.Fatalf("expected success=true, got: %s", w.Body.String())
}
}
func TestCreateTemplate_MissingFields(t *testing.T) {
h, _ := setupTemplateHandler()
r := setupTemplateRouter(h)
body := `{"name":"","script":""}`
w := httptest.NewRecorder()
req, _ := http.NewRequest("POST", "/api/v1/templates", bytes.NewBufferString(body))
req.Header.Set("Content-Type", "application/json")
r.ServeHTTP(w, req)
if w.Code != http.StatusBadRequest {
t.Fatalf("expected 400, got %d: %s", w.Code, w.Body.String())
}
}
func TestGetTemplate_Success(t *testing.T) {
h, _ := setupTemplateHandler()
r := setupTemplateRouter(h)
// Seed data
id, _ := h.store.Create(context.Background(), &model.CreateTemplateRequest{Name: "test-tpl", Script: "echo hi", Partition: "normal"})
w := httptest.NewRecorder()
req, _ := http.NewRequest("GET", "/api/v1/templates/"+itoa(id), nil)
r.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("expected 200, got %d: %s", w.Code, w.Body.String())
}
var resp server.APIResponse
json.Unmarshal(w.Body.Bytes(), &resp)
if !resp.Success {
t.Fatalf("expected success=true, got: %s", w.Body.String())
}
}
func TestGetTemplate_NotFound(t *testing.T) {
h, _ := setupTemplateHandler()
r := setupTemplateRouter(h)
w := httptest.NewRecorder()
req, _ := http.NewRequest("GET", "/api/v1/templates/999", nil)
r.ServeHTTP(w, req)
if w.Code != http.StatusNotFound {
t.Fatalf("expected 404, got %d: %s", w.Code, w.Body.String())
}
}
func TestGetTemplate_InvalidID(t *testing.T) {
h, _ := setupTemplateHandler()
r := setupTemplateRouter(h)
w := httptest.NewRecorder()
req, _ := http.NewRequest("GET", "/api/v1/templates/abc", nil)
r.ServeHTTP(w, req)
if w.Code != http.StatusBadRequest {
t.Fatalf("expected 400, got %d: %s", w.Code, w.Body.String())
}
}
func TestUpdateTemplate_Success(t *testing.T) {
h, _ := setupTemplateHandler()
r := setupTemplateRouter(h)
// Seed data
id, _ := h.store.Create(context.Background(), &model.CreateTemplateRequest{Name: "old", Script: "echo hi"})
body := `{"name":"updated"}`
w := httptest.NewRecorder()
req, _ := http.NewRequest("PUT", "/api/v1/templates/"+itoa(id), bytes.NewBufferString(body))
req.Header.Set("Content-Type", "application/json")
r.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("expected 200, got %d: %s", w.Code, w.Body.String())
}
var resp server.APIResponse
json.Unmarshal(w.Body.Bytes(), &resp)
if !resp.Success {
t.Fatalf("expected success=true, got: %s", w.Body.String())
}
}
func TestDeleteTemplate_Success(t *testing.T) {
h, _ := setupTemplateHandler()
r := setupTemplateRouter(h)
// Seed data
id, _ := h.store.Create(context.Background(), &model.CreateTemplateRequest{Name: "to-delete", Script: "echo hi"})
w := httptest.NewRecorder()
req, _ := http.NewRequest("DELETE", "/api/v1/templates/"+itoa(id), nil)
r.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("expected 200, got %d: %s", w.Code, w.Body.String())
}
var resp server.APIResponse
json.Unmarshal(w.Body.Bytes(), &resp)
if !resp.Success {
t.Fatalf("expected success=true, got: %s", w.Body.String())
}
}
// itoa converts int64 to string for URL path construction.
func itoa(id int64) string {
return fmt.Sprintf("%d", id)
}
func setupTemplateHandlerWithObserver() (*TemplateHandler, *gorm.DB, *observer.ObservedLogs) {
core, recorded := observer.New(zapcore.DebugLevel)
l := zap.New(core)
db, _ := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{Logger: logger.Default.LogMode(logger.Silent)})
db.AutoMigrate(&model.JobTemplate{})
s := store.NewTemplateStore(db)
return NewTemplateHandler(s, l), db, recorded
}
func TestTemplateLogging_CreateSuccess_LogsInfoWithID(t *testing.T) {
h, _, recorded := setupTemplateHandlerWithObserver()
r := setupTemplateRouter(h)
body := `{"name":"log-tpl","script":"echo hi"}`
w := httptest.NewRecorder()
req, _ := http.NewRequest("POST", "/api/v1/templates", bytes.NewBufferString(body))
req.Header.Set("Content-Type", "application/json")
r.ServeHTTP(w, req)
if w.Code != http.StatusCreated {
t.Fatalf("expected 201, got %d: %s", w.Code, w.Body.String())
}
entries := recorded.FilterMessage("template created").FilterLevelExact(zapcore.InfoLevel).All()
if len(entries) != 1 {
t.Fatalf("expected 1 info log for 'template created', got %d", len(entries))
}
fields := entries[0].ContextMap()
if fields["id"] == nil {
t.Fatal("expected log entry to contain 'id' field")
}
}
func TestTemplateLogging_UpdateSuccess_LogsInfoWithID(t *testing.T) {
h, _, recorded := setupTemplateHandlerWithObserver()
r := setupTemplateRouter(h)
id, _ := h.store.Create(context.Background(), &model.CreateTemplateRequest{Name: "old", Script: "echo hi"})
body := `{"name":"updated"}`
w := httptest.NewRecorder()
req, _ := http.NewRequest("PUT", "/api/v1/templates/"+itoa(id), bytes.NewBufferString(body))
req.Header.Set("Content-Type", "application/json")
r.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("expected 200, got %d: %s", w.Code, w.Body.String())
}
entries := recorded.FilterMessage("template updated").FilterLevelExact(zapcore.InfoLevel).All()
if len(entries) != 1 {
t.Fatalf("expected 1 info log for 'template updated', got %d", len(entries))
}
fields := entries[0].ContextMap()
if fields["id"] == nil {
t.Fatal("expected log entry to contain 'id' field")
}
}
func TestTemplateLogging_DeleteSuccess_LogsInfoWithID(t *testing.T) {
h, _, recorded := setupTemplateHandlerWithObserver()
r := setupTemplateRouter(h)
id, _ := h.store.Create(context.Background(), &model.CreateTemplateRequest{Name: "to-delete", Script: "echo hi"})
w := httptest.NewRecorder()
req, _ := http.NewRequest("DELETE", "/api/v1/templates/"+itoa(id), nil)
r.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("expected 200, got %d: %s", w.Code, w.Body.String())
}
entries := recorded.FilterMessage("template deleted").FilterLevelExact(zapcore.InfoLevel).All()
if len(entries) != 1 {
t.Fatalf("expected 1 info log for 'template deleted', got %d", len(entries))
}
fields := entries[0].ContextMap()
if fields["id"] == nil {
t.Fatal("expected log entry to contain 'id' field")
}
}
func TestTemplateLogging_GetNotFound_LogsWarnWithID(t *testing.T) {
h, _, recorded := setupTemplateHandlerWithObserver()
r := setupTemplateRouter(h)
w := httptest.NewRecorder()
req, _ := http.NewRequest("GET", "/api/v1/templates/999", nil)
r.ServeHTTP(w, req)
if w.Code != http.StatusNotFound {
t.Fatalf("expected 404, got %d: %s", w.Code, w.Body.String())
}
entries := recorded.FilterMessage("template not found").FilterLevelExact(zapcore.WarnLevel).All()
if len(entries) != 1 {
t.Fatalf("expected 1 warn log for 'template not found', got %d", len(entries))
}
fields := entries[0].ContextMap()
if fields["id"] == nil {
t.Fatal("expected log entry to contain 'id' field")
}
}
func TestTemplateLogging_CreateBadRequest_LogsWarn(t *testing.T) {
h, _, recorded := setupTemplateHandlerWithObserver()
r := setupTemplateRouter(h)
body := `{"name":"","script":""}`
w := httptest.NewRecorder()
req, _ := http.NewRequest("POST", "/api/v1/templates", bytes.NewBufferString(body))
req.Header.Set("Content-Type", "application/json")
r.ServeHTTP(w, req)
if w.Code != http.StatusBadRequest {
t.Fatalf("expected 400, got %d: %s", w.Code, w.Body.String())
}
warnEntries := recorded.FilterLevelExact(zapcore.WarnLevel).All()
if len(warnEntries) == 0 {
t.Fatal("expected at least 1 warn log for bad request")
}
}
func TestTemplateLogging_InvalidID_LogsWarn(t *testing.T) {
h, _, recorded := setupTemplateHandlerWithObserver()
r := setupTemplateRouter(h)
w := httptest.NewRecorder()
req, _ := http.NewRequest("GET", "/api/v1/templates/abc", nil)
r.ServeHTTP(w, req)
if w.Code != http.StatusBadRequest {
t.Fatalf("expected 400, got %d: %s", w.Code, w.Body.String())
}
warnEntries := recorded.FilterLevelExact(zapcore.WarnLevel).All()
if len(warnEntries) == 0 {
t.Fatal("expected at least 1 warn log for invalid id")
}
}
func TestTemplateLogging_ListSuccess_NoInfoLog(t *testing.T) {
h, _, recorded := setupTemplateHandlerWithObserver()
r := setupTemplateRouter(h)
// Seed data
h.store.Create(context.Background(), &model.CreateTemplateRequest{Name: "test-tpl", Script: "echo hi"})
// Reset recorded logs so the create log doesn't interfere
recorded.TakeAll()
w := httptest.NewRecorder()
req, _ := http.NewRequest("GET", "/api/v1/templates", nil)
r.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("expected 200, got %d: %s", w.Code, w.Body.String())
}
infoEntries := recorded.FilterLevelExact(zapcore.InfoLevel).All()
if len(infoEntries) != 0 {
t.Fatalf("expected 0 info logs for list success, got %d: %+v", len(infoEntries), infoEntries)
}
}
func TestTemplateLogging_LogsDoNotContainTemplateContent(t *testing.T) {
h, _, recorded := setupTemplateHandlerWithObserver()
r := setupTemplateRouter(h)
body := `{"name":"secret-name","script":"secret-script","partition":"secret-partition"}`
w := httptest.NewRecorder()
req, _ := http.NewRequest("POST", "/api/v1/templates", bytes.NewBufferString(body))
req.Header.Set("Content-Type", "application/json")
r.ServeHTTP(w, req)
if w.Code != http.StatusCreated {
t.Fatalf("expected 201, got %d: %s", w.Code, w.Body.String())
}
entries := recorded.All()
for _, e := range entries {
logStr := e.Message + " " + fmt.Sprintf("%v", e.ContextMap())
if strings.Contains(logStr, "secret-name") || strings.Contains(logStr, "secret-script") || strings.Contains(logStr, "secret-partition") {
t.Fatalf("log entry contains sensitive template content: %s", logStr)
}
}
}


@@ -0,0 +1,154 @@
package handler
import (
"context"
"io"
"strconv"
"gcy_hpc_server/internal/model"
"gcy_hpc_server/internal/server"
"gcy_hpc_server/internal/service"
"github.com/gin-gonic/gin"
"go.uber.org/zap"
)
// uploadServiceProvider defines the interface for upload operations.
type uploadServiceProvider interface {
InitUpload(ctx context.Context, req model.InitUploadRequest) (interface{}, error)
UploadChunk(ctx context.Context, sessionID int64, chunkIndex int, reader io.Reader, size int64) error
CompleteUpload(ctx context.Context, sessionID int64) (*model.FileResponse, error)
GetUploadStatus(ctx context.Context, sessionID int64) (*model.UploadSessionResponse, error)
CancelUpload(ctx context.Context, sessionID int64) error
}
// UploadHandler handles HTTP requests for chunked file uploads.
type UploadHandler struct {
svc uploadServiceProvider
logger *zap.Logger
}
// NewUploadHandler creates a new UploadHandler.
func NewUploadHandler(svc *service.UploadService, logger *zap.Logger) *UploadHandler {
return &UploadHandler{svc: svc, logger: logger}
}
// InitUpload initiates a new chunked upload session.
// POST /api/v1/files/uploads
func (h *UploadHandler) InitUpload(c *gin.Context) {
var req model.InitUploadRequest
if err := c.ShouldBindJSON(&req); err != nil {
h.logger.Warn("invalid request body for init upload", zap.Error(err))
server.BadRequest(c, err.Error())
return
}
result, err := h.svc.InitUpload(c.Request.Context(), req)
if err != nil {
h.logger.Error("failed to init upload", zap.Error(err))
server.BadRequest(c, err.Error())
return
}
switch resp := result.(type) {
case model.FileResponse:
server.OK(c, resp)
case model.UploadSessionResponse:
server.Created(c, resp)
default:
server.Created(c, resp)
}
}
// UploadChunk uploads a single chunk of an upload session.
// PUT /api/v1/files/uploads/:id/chunks/:index
func (h *UploadHandler) UploadChunk(c *gin.Context) {
sessionID, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
h.logger.Warn("invalid session id", zap.String("id", c.Param("id")))
server.BadRequest(c, "invalid session id")
return
}
chunkIndex, err := strconv.Atoi(c.Param("index"))
if err != nil {
h.logger.Warn("invalid chunk index", zap.String("index", c.Param("index")))
server.BadRequest(c, "invalid chunk index")
return
}
file, header, err := c.Request.FormFile("chunk")
if err != nil {
h.logger.Warn("missing chunk file in request", zap.Error(err))
server.BadRequest(c, "missing chunk file")
return
}
defer file.Close()
if err := h.svc.UploadChunk(c.Request.Context(), sessionID, chunkIndex, file, header.Size); err != nil {
h.logger.Error("failed to upload chunk", zap.Int64("session_id", sessionID), zap.Int("chunk_index", chunkIndex), zap.Error(err))
server.BadRequest(c, err.Error())
return
}
server.OK(c, gin.H{"message": "chunk uploaded"})
}
// CompleteUpload finalizes an upload session and assembles the file.
// POST /api/v1/files/uploads/:id/complete
func (h *UploadHandler) CompleteUpload(c *gin.Context) {
sessionID, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
h.logger.Warn("invalid session id for complete", zap.String("id", c.Param("id")))
server.BadRequest(c, "invalid session id")
return
}
resp, err := h.svc.CompleteUpload(c.Request.Context(), sessionID)
if err != nil {
h.logger.Error("failed to complete upload", zap.Int64("session_id", sessionID), zap.Error(err))
server.InternalError(c, err.Error())
return
}
server.Created(c, resp)
}
// GetUploadStatus returns the current status of an upload session.
// GET /api/v1/files/uploads/:id
func (h *UploadHandler) GetUploadStatus(c *gin.Context) {
sessionID, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
h.logger.Warn("invalid session id for status", zap.String("id", c.Param("id")))
server.BadRequest(c, "invalid session id")
return
}
resp, err := h.svc.GetUploadStatus(c.Request.Context(), sessionID)
if err != nil {
h.logger.Error("failed to get upload status", zap.Int64("session_id", sessionID), zap.Error(err))
server.InternalError(c, err.Error())
return
}
server.OK(c, resp)
}
// CancelUpload cancels and cleans up an upload session.
// DELETE /api/v1/files/uploads/:id
func (h *UploadHandler) CancelUpload(c *gin.Context) {
sessionID, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
h.logger.Warn("invalid session id for cancel", zap.String("id", c.Param("id")))
server.BadRequest(c, "invalid session id")
return
}
if err := h.svc.CancelUpload(c.Request.Context(), sessionID); err != nil {
h.logger.Error("failed to cancel upload", zap.Int64("session_id", sessionID), zap.Error(err))
server.InternalError(c, err.Error())
return
}
server.OK(c, gin.H{"message": "upload cancelled"})
}
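The handler returns whatever session the service computed; the `total_chunks` value seen in the tests (2 chunks for a 10 MiB file at a 5 MiB chunk size) implies a ceil division somewhere in the service layer. A minimal sketch of that calculation, where the 5 MiB default is an assumption taken from the test fixtures rather than from code shown in this diff:

```go
package main

import "fmt"

// defaultChunkSize mirrors the 5 MiB value used in the handler tests;
// the real default presumably lives in the UploadService configuration.
const defaultChunkSize = 5 * 1024 * 1024

// totalChunks computes how many chunks a file of fileSize bytes needs at the
// given chunk size, rounding up. A zero-byte file needs 0 chunks (the session
// model allows pending→completed directly for zero-byte files).
func totalChunks(fileSize, chunkSize int64) int {
	if fileSize <= 0 {
		return 0
	}
	return int((fileSize + chunkSize - 1) / chunkSize)
}

func main() {
	fmt.Println(totalChunks(10*1024*1024, defaultChunkSize))   // 2
	fmt.Println(totalChunks(10*1024*1024+1, defaultChunkSize)) // 3
	fmt.Println(totalChunks(0, defaultChunkSize))              // 0
}
```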


@@ -0,0 +1,307 @@
package handler
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"mime/multipart"
"net/http"
"net/http/httptest"
"testing"
"time"
"gcy_hpc_server/internal/model"
"github.com/gin-gonic/gin"
"go.uber.org/zap"
)
type mockUploadService struct {
initUploadFn func(ctx context.Context, req model.InitUploadRequest) (interface{}, error)
uploadChunkFn func(ctx context.Context, sessionID int64, chunkIndex int, reader io.Reader, size int64) error
completeUploadFn func(ctx context.Context, sessionID int64) (*model.FileResponse, error)
getUploadStatusFn func(ctx context.Context, sessionID int64) (*model.UploadSessionResponse, error)
cancelUploadFn func(ctx context.Context, sessionID int64) error
}
func (m *mockUploadService) InitUpload(ctx context.Context, req model.InitUploadRequest) (interface{}, error) {
return m.initUploadFn(ctx, req)
}
func (m *mockUploadService) UploadChunk(ctx context.Context, sessionID int64, chunkIndex int, reader io.Reader, size int64) error {
return m.uploadChunkFn(ctx, sessionID, chunkIndex, reader, size)
}
func (m *mockUploadService) CompleteUpload(ctx context.Context, sessionID int64) (*model.FileResponse, error) {
return m.completeUploadFn(ctx, sessionID)
}
func (m *mockUploadService) GetUploadStatus(ctx context.Context, sessionID int64) (*model.UploadSessionResponse, error) {
return m.getUploadStatusFn(ctx, sessionID)
}
func (m *mockUploadService) CancelUpload(ctx context.Context, sessionID int64) error {
return m.cancelUploadFn(ctx, sessionID)
}
func setupUploadHandler(mock uploadServiceProvider) (*UploadHandler, *gin.Engine) {
gin.SetMode(gin.TestMode)
h := &UploadHandler{svc: mock, logger: zap.NewNop()}
r := gin.New()
v1 := r.Group("/api/v1")
uploads := v1.Group("/files/uploads")
uploads.POST("", h.InitUpload)
uploads.PUT("/:id/chunks/:index", h.UploadChunk)
uploads.POST("/:id/complete", h.CompleteUpload)
uploads.GET("/:id", h.GetUploadStatus)
uploads.DELETE("/:id", h.CancelUpload)
return h, r
}
func TestInitUpload_NewSession(t *testing.T) {
mock := &mockUploadService{
initUploadFn: func(ctx context.Context, req model.InitUploadRequest) (interface{}, error) {
return model.UploadSessionResponse{
ID: 1,
FileName: req.FileName,
FileSize: req.FileSize,
ChunkSize: 5 * 1024 * 1024,
TotalChunks: 2,
SHA256: req.SHA256,
Status: "pending",
UploadedChunks: []int{},
ExpiresAt: time.Now().Add(24 * time.Hour),
CreatedAt: time.Now(),
}, nil
},
}
_, r := setupUploadHandler(mock)
body, _ := json.Marshal(model.InitUploadRequest{
FileName: "test.bin",
FileSize: 10 * 1024 * 1024,
SHA256: "abc123",
})
w := httptest.NewRecorder()
req, _ := http.NewRequest(http.MethodPost, "/api/v1/files/uploads", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
r.ServeHTTP(w, req)
if w.Code != http.StatusCreated {
t.Fatalf("expected 201, got %d: %s", w.Code, w.Body.String())
}
var resp map[string]interface{}
json.Unmarshal(w.Body.Bytes(), &resp)
if !resp["success"].(bool) {
t.Fatal("expected success=true")
}
data := resp["data"].(map[string]interface{})
if data["status"] != "pending" {
t.Fatalf("expected status=pending, got %v", data["status"])
}
}
func TestInitUpload_DedupHit(t *testing.T) {
mock := &mockUploadService{
initUploadFn: func(ctx context.Context, req model.InitUploadRequest) (interface{}, error) {
return model.FileResponse{
ID: 42,
Name: req.FileName,
Size: req.FileSize,
SHA256: req.SHA256,
MimeType: "application/octet-stream",
CreatedAt: time.Now(),
UpdatedAt: time.Now(),
}, nil
},
}
_, r := setupUploadHandler(mock)
body, _ := json.Marshal(model.InitUploadRequest{
FileName: "existing.bin",
FileSize: 1024,
SHA256: "existinghash",
})
w := httptest.NewRecorder()
req, _ := http.NewRequest(http.MethodPost, "/api/v1/files/uploads", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
r.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("expected 200, got %d: %s", w.Code, w.Body.String())
}
var resp map[string]interface{}
json.Unmarshal(w.Body.Bytes(), &resp)
if !resp["success"].(bool) {
t.Fatal("expected success=true")
}
data := resp["data"].(map[string]interface{})
if data["sha256"] != "existinghash" {
t.Fatalf("expected sha256=existinghash, got %v", data["sha256"])
}
}
func TestInitUpload_MissingFields(t *testing.T) {
mock := &mockUploadService{
initUploadFn: func(ctx context.Context, req model.InitUploadRequest) (interface{}, error) {
return nil, nil
},
}
_, r := setupUploadHandler(mock)
w := httptest.NewRecorder()
req, _ := http.NewRequest(http.MethodPost, "/api/v1/files/uploads", bytes.NewReader([]byte(`{}`)))
req.Header.Set("Content-Type", "application/json")
r.ServeHTTP(w, req)
if w.Code != http.StatusBadRequest {
t.Fatalf("expected 400, got %d: %s", w.Code, w.Body.String())
}
}
func TestUploadChunk_Success(t *testing.T) {
var receivedSessionID int64
var receivedChunkIndex int
mock := &mockUploadService{
uploadChunkFn: func(ctx context.Context, sessionID int64, chunkIndex int, reader io.Reader, size int64) error {
receivedSessionID = sessionID
receivedChunkIndex = chunkIndex
readBytes, _ := io.ReadAll(reader)
if len(readBytes) == 0 {
t.Error("expected non-empty chunk data")
}
return nil
},
}
_, r := setupUploadHandler(mock)
var buf bytes.Buffer
writer := multipart.NewWriter(&buf)
part, _ := writer.CreateFormFile("chunk", "test.bin")
fmt.Fprintf(part, "chunk data here")
writer.Close()
w := httptest.NewRecorder()
req, _ := http.NewRequest(http.MethodPut, "/api/v1/files/uploads/1/chunks/0", &buf)
req.Header.Set("Content-Type", writer.FormDataContentType())
r.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("expected 200, got %d: %s", w.Code, w.Body.String())
}
if receivedSessionID != 1 {
t.Fatalf("expected session_id=1, got %d", receivedSessionID)
}
if receivedChunkIndex != 0 {
t.Fatalf("expected chunk_index=0, got %d", receivedChunkIndex)
}
}
func TestCompleteUpload_Success(t *testing.T) {
mock := &mockUploadService{
completeUploadFn: func(ctx context.Context, sessionID int64) (*model.FileResponse, error) {
return &model.FileResponse{
ID: 10,
Name: "completed.bin",
Size: 2048,
SHA256: "completehash",
MimeType: "application/octet-stream",
CreatedAt: time.Now(),
UpdatedAt: time.Now(),
}, nil
},
}
_, r := setupUploadHandler(mock)
w := httptest.NewRecorder()
req, _ := http.NewRequest(http.MethodPost, "/api/v1/files/uploads/5/complete", nil)
r.ServeHTTP(w, req)
if w.Code != http.StatusCreated {
t.Fatalf("expected 201, got %d: %s", w.Code, w.Body.String())
}
var resp map[string]interface{}
json.Unmarshal(w.Body.Bytes(), &resp)
if !resp["success"].(bool) {
t.Fatal("expected success=true")
}
data := resp["data"].(map[string]interface{})
if data["name"] != "completed.bin" {
t.Fatalf("expected name=completed.bin, got %v", data["name"])
}
}
func TestGetUploadStatus_Success(t *testing.T) {
mock := &mockUploadService{
getUploadStatusFn: func(ctx context.Context, sessionID int64) (*model.UploadSessionResponse, error) {
return &model.UploadSessionResponse{
ID: sessionID,
FileName: "status.bin",
FileSize: 4096,
ChunkSize: 2048,
TotalChunks: 2,
SHA256: "statushash",
Status: "uploading",
UploadedChunks: []int{0},
ExpiresAt: time.Now().Add(24 * time.Hour),
CreatedAt: time.Now(),
}, nil
},
}
_, r := setupUploadHandler(mock)
w := httptest.NewRecorder()
req, _ := http.NewRequest(http.MethodGet, "/api/v1/files/uploads/3", nil)
r.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("expected 200, got %d: %s", w.Code, w.Body.String())
}
var resp map[string]interface{}
json.Unmarshal(w.Body.Bytes(), &resp)
if !resp["success"].(bool) {
t.Fatal("expected success=true")
}
data := resp["data"].(map[string]interface{})
if data["status"] != "uploading" {
t.Fatalf("expected status=uploading, got %v", data["status"])
}
}
func TestCancelUpload_Success(t *testing.T) {
var receivedSessionID int64
mock := &mockUploadService{
cancelUploadFn: func(ctx context.Context, sessionID int64) error {
receivedSessionID = sessionID
return nil
},
}
_, r := setupUploadHandler(mock)
w := httptest.NewRecorder()
req, _ := http.NewRequest(http.MethodDelete, "/api/v1/files/uploads/7", nil)
r.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("expected 200, got %d: %s", w.Code, w.Body.String())
}
if receivedSessionID != 7 {
t.Fatalf("expected session_id=7, got %d", receivedSessionID)
}
var resp map[string]interface{}
json.Unmarshal(w.Body.Bytes(), &resp)
data := resp["data"].(map[string]interface{})
if data["message"] != "upload cancelled" {
t.Fatalf("expected message='upload cancelled', got %v", data["message"])
}
}


@@ -0,0 +1,76 @@
package model
import (
"encoding/json"
"time"
)
// Parameter type constants for ParameterSchema.Type.
const (
ParamTypeString = "string"
ParamTypeInteger = "integer"
ParamTypeEnum = "enum"
ParamTypeFile = "file"
ParamTypeDirectory = "directory"
ParamTypeBoolean = "boolean"
)
// Application represents a parameterized application definition for HPC job submission.
type Application struct {
ID int64 `gorm:"primaryKey;autoIncrement" json:"id"`
Name string `gorm:"uniqueIndex;size:255;not null" json:"name"`
Description string `gorm:"type:text" json:"description,omitempty"`
Icon string `gorm:"size:255" json:"icon,omitempty"`
Category string `gorm:"size:255" json:"category,omitempty"`
ScriptTemplate string `gorm:"type:text;not null" json:"script_template"`
Parameters json.RawMessage `gorm:"type:json" json:"parameters,omitempty"`
Scope string `gorm:"size:50;default:'system'" json:"scope,omitempty"`
CreatedBy int64 `json:"created_by,omitempty"`
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
}
func (Application) TableName() string {
return "hpc_applications"
}
// ParameterSchema defines a single parameter in an application's form schema.
type ParameterSchema struct {
Name string `json:"name"`
Label string `json:"label,omitempty"`
Type string `json:"type"`
Required bool `json:"required,omitempty"`
Default string `json:"default,omitempty"`
Options []string `json:"options,omitempty"`
Description string `json:"description,omitempty"`
}
// CreateApplicationRequest is the DTO for creating a new application.
type CreateApplicationRequest struct {
Name string `json:"name" binding:"required"`
Description string `json:"description,omitempty"`
Icon string `json:"icon,omitempty"`
Category string `json:"category,omitempty"`
ScriptTemplate string `json:"script_template" binding:"required"`
Parameters json.RawMessage `json:"parameters,omitempty"`
Scope string `json:"scope,omitempty"`
}
// UpdateApplicationRequest is the DTO for updating an existing application.
// All fields are optional. Parameters uses *json.RawMessage to distinguish
// between "not provided" (nil) and "set to empty" (non-nil).
type UpdateApplicationRequest struct {
Name *string `json:"name,omitempty"`
Description *string `json:"description,omitempty"`
Icon *string `json:"icon,omitempty"`
Category *string `json:"category,omitempty"`
ScriptTemplate *string `json:"script_template,omitempty"`
Parameters *json.RawMessage `json:"parameters,omitempty"`
Scope *string `json:"scope,omitempty"`
}
// ApplicationSubmitRequest is the DTO for submitting a job from an application.
// ApplicationID is parsed from the URL :id parameter, not included in the body.
type ApplicationSubmitRequest struct {
Values map[string]string `json:"values" binding:"required"`
}


@@ -1,23 +1,73 @@
package model
// NodeResponse is the simplified API response for a node.
// NodeResponse is the API response for a node.
type NodeResponse struct {
Name string `json:"name"`
State []string `json:"state"`
CPUs int32 `json:"cpus"`
RealMemory int64 `json:"real_memory"`
AllocMem int64 `json:"alloc_memory,omitempty"`
Arch string `json:"architecture,omitempty"`
OS string `json:"operating_system,omitempty"`
// Identity
Name string `json:"name"` // node hostname
State []string `json:"state"` // node state (e.g. ["IDLE"], ["ALLOCATED","COMPLETING"])
Reason string `json:"reason,omitempty"` // reason the node is DOWN/DRAIN
ReasonSetByUser string `json:"reason_set_by_user,omitempty"` // user who set the reason
// CPU Resources
CPUs int32 `json:"cpus"` // total CPU cores
AllocCpus *int32 `json:"alloc_cpus,omitempty"` // allocated CPU cores
Cores *int32 `json:"cores,omitempty"` // physical core count
Sockets *int32 `json:"sockets,omitempty"` // CPU socket count
Threads *int32 `json:"threads,omitempty"` // threads per core
CpuLoad *int32 `json:"cpu_load,omitempty"` // CPU load (load average × 100)
// Memory (MiB)
RealMemory int64 `json:"real_memory"` // total physical memory
AllocMemory int64 `json:"alloc_memory,omitempty"` // allocated memory
FreeMem *int64 `json:"free_mem,omitempty"` // free memory
// Hardware
Arch string `json:"architecture,omitempty"` // system architecture (e.g. x86_64)
OS string `json:"operating_system,omitempty"` // operating system version
Gres string `json:"gres,omitempty"` // available generic resources (e.g. "gpu:4")
GresUsed string `json:"gres_used,omitempty"` // generic resources in use (e.g. "gpu:2")
// Network
Address string `json:"address,omitempty"` // node address (IP)
Hostname string `json:"hostname,omitempty"` // node hostname (may differ from Name)
// Scheduling
Weight *int32 `json:"weight,omitempty"` // scheduling weight
Features string `json:"features,omitempty"` // node feature tags (modifiable)
ActiveFeatures string `json:"active_features,omitempty"` // currently active feature tags (read-only)
}
// PartitionResponse is the simplified API response for a partition.
// PartitionResponse is the API response for a partition.
type PartitionResponse struct {
Name string `json:"name"`
State []string `json:"state"`
Nodes string `json:"nodes,omitempty"`
TotalCPUs int32 `json:"total_cpus,omitempty"`
TotalNodes int32 `json:"total_nodes,omitempty"`
MaxTime string `json:"max_time,omitempty"`
Default bool `json:"default,omitempty"`
// Identity
Name string `json:"name"` // partition name
State []string `json:"state"` // partition state (e.g. ["UP"], ["DOWN","DRAIN"])
Default bool `json:"default,omitempty"` // whether this is the default partition
// Nodes
Nodes string `json:"nodes,omitempty"` // node list belonging to the partition
TotalNodes int32 `json:"total_nodes,omitempty"` // total node count
// CPUs
TotalCPUs int32 `json:"total_cpus,omitempty"` // total CPU cores
MaxCPUsPerNode *int32 `json:"max_cpus_per_node,omitempty"` // max CPU cores per node
// Limits
MaxTime string `json:"max_time,omitempty"` // maximum run time in minutes ("UNLIMITED" means no limit)
MaxNodes *int32 `json:"max_nodes,omitempty"` // max nodes per job
MinNodes *int32 `json:"min_nodes,omitempty"` // min nodes per job
DefaultTime string `json:"default_time,omitempty"` // default run-time limit
GraceTime *int32 `json:"grace_time,omitempty"` // grace period after job preemption (seconds)
// Priority
Priority *int32 `json:"priority,omitempty"` // job priority factor within the partition
// Access Control - QOS
QOSAllowed string `json:"qos_allowed,omitempty"` // QOS allowed to use the partition
QOSDeny string `json:"qos_deny,omitempty"` // QOS denied from the partition
QOSAssigned string `json:"qos_assigned,omitempty"` // default QOS assigned by the partition
// Access Control - Accounts
AccountsAllowed string `json:"accounts_allowed,omitempty"` // accounts allowed to use the partition
AccountsDeny string `json:"accounts_deny,omitempty"` // accounts denied from the partition
}

internal/model/file.go

@@ -0,0 +1,193 @@
package model
import (
"fmt"
"strings"
"time"
"unicode"
"gorm.io/gorm"
)
// FileBlob represents a physical file stored in MinIO, deduplicated by SHA256.
type FileBlob struct {
ID int64 `gorm:"primaryKey;autoIncrement" json:"id"`
SHA256 string `gorm:"uniqueIndex;size:64;not null" json:"sha256"`
MinioKey string `gorm:"size:255;not null" json:"minio_key"`
FileSize int64 `gorm:"not null" json:"file_size"`
MimeType string `gorm:"size:255" json:"mime_type"`
RefCount int `gorm:"not null;default:0" json:"ref_count"`
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
}
func (FileBlob) TableName() string {
return "hpc_file_blobs"
}
// File represents a logical file visible to users, backed by a FileBlob.
type File struct {
ID int64 `gorm:"primaryKey;autoIncrement" json:"id"`
Name string `gorm:"size:255;not null" json:"name"`
FolderID *int64 `gorm:"index" json:"folder_id,omitempty"`
BlobSHA256 string `gorm:"size:64;not null" json:"blob_sha256"`
UserID *int64 `gorm:"index" json:"user_id,omitempty"`
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
DeletedAt gorm.DeletedAt `gorm:"index" json:"deleted_at,omitempty"`
}
func (File) TableName() string {
return "hpc_files"
}
// Folder represents a directory in the virtual file system.
type Folder struct {
ID int64 `gorm:"primaryKey;autoIncrement" json:"id"`
Name string `gorm:"size:255;not null" json:"name"`
ParentID *int64 `gorm:"index" json:"parent_id,omitempty"`
Path string `gorm:"uniqueIndex;size:768;not null" json:"path"`
UserID *int64 `gorm:"index" json:"user_id,omitempty"`
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
DeletedAt gorm.DeletedAt `gorm:"index" json:"deleted_at,omitempty"`
}
func (Folder) TableName() string {
return "hpc_folders"
}
// UploadSession represents an in-progress chunked upload.
// State transitions: pending→uploading, pending→completed(zero-byte), uploading→merging,
// uploading→cancelled, merging→completed, merging→failed, any→expired
type UploadSession struct {
ID int64 `gorm:"primaryKey;autoIncrement" json:"id"`
FileName string `gorm:"size:255;not null" json:"file_name"`
FileSize int64 `gorm:"not null" json:"file_size"`
ChunkSize int64 `gorm:"not null" json:"chunk_size"`
TotalChunks int `gorm:"not null" json:"total_chunks"`
SHA256 string `gorm:"size:64;not null" json:"sha256"`
FolderID *int64 `gorm:"index" json:"folder_id,omitempty"`
Status string `gorm:"size:20;not null;default:pending" json:"status"`
MinioPrefix string `gorm:"size:255;not null" json:"minio_prefix"`
MimeType string `gorm:"size:255;default:'application/octet-stream'" json:"mime_type"`
UserID *int64 `gorm:"index" json:"user_id,omitempty"`
ExpiresAt time.Time `gorm:"not null" json:"expires_at"`
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
}
func (UploadSession) TableName() string {
return "hpc_upload_sessions"
}
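The state transitions documented on UploadSession can be enforced with a small lookup table. A sketch of such a guard, derived only from the comment above (whether the service actually validates transitions is not shown in this diff):

```go
package main

import "fmt"

// validTransitions encodes the documented UploadSession state machine:
// pending→uploading, pending→completed (zero-byte), uploading→merging,
// uploading→cancelled, merging→completed, merging→failed, any→expired.
var validTransitions = map[string][]string{
	"pending":   {"uploading", "completed"},
	"uploading": {"merging", "cancelled"},
	"merging":   {"completed", "failed"},
}

// canTransition reports whether moving from one status to another is allowed.
func canTransition(from, to string) bool {
	if to == "expired" { // any state may expire
		return true
	}
	for _, next := range validTransitions[from] {
		if next == to {
			return true
		}
	}
	return false
}

func main() {
	fmt.Println(canTransition("pending", "uploading"))   // true
	fmt.Println(canTransition("uploading", "completed")) // false: must pass through merging
	fmt.Println(canTransition("merging", "expired"))     // true
}
```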
// UploadChunk represents a single chunk of an upload session.
type UploadChunk struct {
ID int64 `gorm:"primaryKey;autoIncrement" json:"id"`
SessionID int64 `gorm:"not null;uniqueIndex:idx_session_chunk" json:"session_id"`
ChunkIndex int `gorm:"not null;uniqueIndex:idx_session_chunk" json:"chunk_index"`
MinioKey string `gorm:"size:255;not null" json:"minio_key"`
SHA256 string `gorm:"size:64" json:"sha256,omitempty"`
Size int64 `gorm:"not null" json:"size"`
Status string `gorm:"size:20;not null;default:pending" json:"status"`
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
}
func (UploadChunk) TableName() string {
return "hpc_upload_chunks"
}
// InitUploadRequest is the DTO for initiating a chunked upload.
type InitUploadRequest struct {
FileName string `json:"file_name" binding:"required"`
FileSize int64 `json:"file_size" binding:"required"`
SHA256 string `json:"sha256" binding:"required"`
FolderID *int64 `json:"folder_id,omitempty"`
ChunkSize *int64 `json:"chunk_size,omitempty"`
MimeType string `json:"mime_type,omitempty"`
}
// CreateFolderRequest is the DTO for creating a new folder.
type CreateFolderRequest struct {
Name string `json:"name" binding:"required"`
ParentID *int64 `json:"parent_id,omitempty"`
}
// UploadSessionResponse is the DTO returned when creating/querying an upload session.
type UploadSessionResponse struct {
ID int64 `json:"id"`
FileName string `json:"file_name"`
FileSize int64 `json:"file_size"`
ChunkSize int64 `json:"chunk_size"`
TotalChunks int `json:"total_chunks"`
SHA256 string `json:"sha256"`
Status string `json:"status"`
UploadedChunks []int `json:"uploaded_chunks"`
ExpiresAt time.Time `json:"expires_at"`
CreatedAt time.Time `json:"created_at"`
}
// FileResponse is the DTO for a file in API responses.
type FileResponse struct {
ID int64 `json:"id"`
Name string `json:"name"`
FolderID *int64 `json:"folder_id,omitempty"`
Size int64 `json:"size"`
MimeType string `json:"mime_type"`
SHA256 string `json:"sha256"`
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
}
// FolderResponse is the DTO for a folder in API responses.
type FolderResponse struct {
ID int64 `json:"id"`
Name string `json:"name"`
ParentID *int64 `json:"parent_id,omitempty"`
Path string `json:"path"`
FileCount int64 `json:"file_count"`
SubFolderCount int64 `json:"subfolder_count"`
CreatedAt time.Time `json:"created_at"`
}
// ListFilesResponse is the paginated response for listing files.
type ListFilesResponse struct {
Files []FileResponse `json:"files"`
Total int64 `json:"total"`
Page int `json:"page"`
PageSize int `json:"page_size"`
}
// ValidateFileName rejects empty, "..", "/", "\", null bytes, control chars, leading/trailing spaces.
func ValidateFileName(name string) error {
if name == "" {
return fmt.Errorf("file name cannot be empty")
}
if strings.TrimSpace(name) != name {
return fmt.Errorf("file name cannot have leading or trailing spaces")
}
if name == ".." {
return fmt.Errorf("file name cannot be '..'")
}
if strings.Contains(name, "/") || strings.Contains(name, "\\") {
return fmt.Errorf("file name cannot contain '/' or '\\'")
}
for _, r := range name {
if r == 0 {
return fmt.Errorf("file name cannot contain null bytes")
}
if unicode.IsControl(r) {
return fmt.Errorf("file name cannot contain control characters")
}
}
return nil
}
// ValidateFolderName rejects everything ValidateFileName rejects, plus ".".
func ValidateFolderName(name string) error {
if name == "." {
return fmt.Errorf("folder name cannot be '.'")
}
return ValidateFileName(name)
}
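A quick sketch of how these validators behave on typical inputs. This is a condensed reimplementation for illustration (the lowercase validateFileName is a stand-in name); the functions above are the source of truth:

```go
package main

import (
	"fmt"
	"strings"
	"unicode"
)

// validateFileName mirrors the checks in ValidateFileName above: reject
// empty names, surrounding spaces, "..", path separators, and control
// characters (which includes null bytes).
func validateFileName(name string) error {
	switch {
	case name == "":
		return fmt.Errorf("empty name")
	case strings.TrimSpace(name) != name:
		return fmt.Errorf("leading or trailing spaces")
	case name == "..":
		return fmt.Errorf("'..' not allowed")
	case strings.ContainsAny(name, "/\\"):
		return fmt.Errorf("path separator not allowed")
	}
	for _, r := range name {
		if r == 0 || unicode.IsControl(r) {
			return fmt.Errorf("control character not allowed")
		}
	}
	return nil
}

func main() {
	// Non-ASCII names are fine; traversal, control chars, and padding are not.
	for _, name := range []string{"report.txt", "数据.csv", "../etc", "a\tb", " padded "} {
		fmt.Printf("%q -> %v\n", name, validateFileName(name))
	}
}
```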


@@ -2,46 +2,95 @@ package model
// SubmitJobRequest is the API request for submitting a job.
type SubmitJobRequest struct {
Script string `json:"script"`
Partition string `json:"partition,omitempty"`
QOS string `json:"qos,omitempty"`
CPUs int32 `json:"cpus,omitempty"`
Memory string `json:"memory,omitempty"`
TimeLimit string `json:"time_limit,omitempty"`
JobName string `json:"job_name,omitempty"`
Environment map[string]string `json:"environment,omitempty"`
Script string `json:"script"` // job script content
Partition string `json:"partition,omitempty"` // target partition
QOS string `json:"qos,omitempty"` // QOS policy to use
CPUs int32 `json:"cpus,omitempty"` // requested CPU cores
Memory string `json:"memory,omitempty"` // requested memory size
TimeLimit string `json:"time_limit,omitempty"` // run time limit (minutes)
JobName string `json:"job_name,omitempty"` // job name
Environment map[string]string `json:"environment,omitempty"` // environment variable key-value pairs
WorkDir string `json:"work_dir,omitempty"` // job working directory
}
// JobResponse is the simplified API response for a job.
// JobResponse is the API response for a job.
type JobResponse struct {
JobID int32 `json:"job_id"`
Name string `json:"name"`
State []string `json:"job_state"`
Partition string `json:"partition"`
SubmitTime *int64 `json:"submit_time,omitempty"`
StartTime *int64 `json:"start_time,omitempty"`
EndTime *int64 `json:"end_time,omitempty"`
ExitCode *int32 `json:"exit_code,omitempty"`
Nodes string `json:"nodes,omitempty"`
// Identity
JobID int32 `json:"job_id"` // Slurm job ID
Name string `json:"name"` // job name
State []string `json:"job_state"` // current job state (e.g. ["RUNNING"], ["PENDING","REQUEUED"])
StateReason string `json:"state_reason,omitempty"` // reason the job is pending or failed
// Scheduling
Partition string `json:"partition"` // partition the job belongs to
QOS string `json:"qos,omitempty"` // QOS policy in use
Priority *int32 `json:"priority,omitempty"` // job priority
TimeLimit string `json:"time_limit,omitempty"` // run time limit (minutes; "UNLIMITED" means no limit)
// Ownership
Account string `json:"account,omitempty"` // billing account
User string `json:"user,omitempty"` // submitting user
Cluster string `json:"cluster,omitempty"` // cluster the job belongs to
// Resources
Cpus *int32 `json:"cpus,omitempty"` // allocated/requested CPU cores
Tasks *int32 `json:"tasks,omitempty"` // number of tasks
NodeCount *int32 `json:"node_count,omitempty"` // number of nodes
Nodes string `json:"nodes,omitempty"` // list of allocated nodes
BatchHost string `json:"batch_host,omitempty"` // batch host node
// Timing (Unix timestamps)
SubmitTime *int64 `json:"submit_time,omitempty"` // submission time
StartTime *int64 `json:"start_time,omitempty"` // start time
EndTime *int64 `json:"end_time,omitempty"` // end time (or expected end time)
// Result
ExitCode *int32 `json:"exit_code,omitempty"` // exit code (nil while the job is unfinished)
// IO Paths
StdOut string `json:"standard_output,omitempty"` // stdout file path
StdErr string `json:"standard_error,omitempty"` // stderr file path
StdIn string `json:"standard_input,omitempty"` // stdin file path
WorkDir string `json:"working_directory,omitempty"` // working directory
Command string `json:"command,omitempty"` // executed command
// Array Job
ArrayJobID *int32 `json:"array_job_id,omitempty"` // parent job ID of the array job
ArrayTaskID *int32 `json:"array_task_id,omitempty"` // task ID within the array job
}
// JobListResponse is the paginated response for job listings.
type JobListResponse struct {
Jobs []JobResponse `json:"jobs"`
Total int `json:"total"`
Page int `json:"page"`
PageSize int `json:"page_size"`
Jobs []JobResponse `json:"jobs"` // list of jobs
Total int `json:"total"` // total number of matching jobs
Page int `json:"page"` // current page number (1-based)
PageSize int `json:"page_size"` // items per page
}
// JobListQuery contains pagination parameters for active job listing.
type JobListQuery struct {
Page int `form:"page,default=1" json:"page,omitempty"` // page number (1-based)
PageSize int `form:"page_size,default=20" json:"page_size,omitempty"` // items per page
}
// JobHistoryQuery contains query parameters for job history.
type JobHistoryQuery struct {
Users string `form:"users" json:"users,omitempty"`
StartTime string `form:"start_time" json:"start_time,omitempty"`
EndTime string `form:"end_time" json:"end_time,omitempty"`
Account string `form:"account" json:"account,omitempty"`
Partition string `form:"partition" json:"partition,omitempty"`
State string `form:"state" json:"state,omitempty"`
JobName string `form:"job_name" json:"job_name,omitempty"`
Page int `form:"page,default=1" json:"page,omitempty"`
PageSize int `form:"page_size,default=20" json:"page_size,omitempty"`
Users string `form:"users" json:"users,omitempty"` // filter by user names (comma-separated)
StartTime string `form:"start_time" json:"start_time,omitempty"` // lower bound on job start time (Unix timestamp)
EndTime string `form:"end_time" json:"end_time,omitempty"` // upper bound on job end time (Unix timestamp)
SubmitTime string `form:"submit_time" json:"submit_time,omitempty"` // filter by job submit time (Unix timestamp)
Account string `form:"account" json:"account,omitempty"` // filter by billing account
Partition string `form:"partition" json:"partition,omitempty"` // filter by partition
State string `form:"state" json:"state,omitempty"` // filter by job state (e.g. "COMPLETED", "FAILED")
JobName string `form:"job_name" json:"job_name,omitempty"` // filter by job name
Cluster string `form:"cluster" json:"cluster,omitempty"` // filter by cluster name
Qos string `form:"qos" json:"qos,omitempty"` // filter by QOS policy
Constraints string `form:"constraints" json:"constraints,omitempty"` // filter by node constraints
ExitCode string `form:"exit_code" json:"exit_code,omitempty"` // filter by exit code
Node string `form:"node" json:"node,omitempty"` // filter by allocated node
Reservation string `form:"reservation" json:"reservation,omitempty"` // filter by reservation name
Groups string `form:"groups" json:"groups,omitempty"` // filter by user groups
Wckey string `form:"wckey" json:"wckey,omitempty"` // filter by WCKey (Workload Characterization Key)
Page int `form:"page,default=1" json:"page,omitempty"` // page number (1-based)
PageSize int `form:"page_size,default=20" json:"page_size,omitempty"` // items per page
}

internal/model/task.go Normal file

@@ -0,0 +1,93 @@
package model
import (
"encoding/json"
"time"
"gorm.io/gorm"
)
// Task status constants.
const (
TaskStatusSubmitted = "submitted"
TaskStatusPreparing = "preparing"
TaskStatusDownloading = "downloading"
TaskStatusReady = "ready"
TaskStatusQueued = "queued"
TaskStatusRunning = "running"
TaskStatusCompleted = "completed"
TaskStatusFailed = "failed"
)
// Task step constants for step-level retry tracking.
const (
TaskStepPreparing = "preparing"
TaskStepDownloading = "downloading"
TaskStepSubmitting = "submitting"
)
// Task represents an HPC task submitted through the application framework.
type Task struct {
ID int64 `gorm:"primaryKey;autoIncrement" json:"id"`
TaskName string `gorm:"size:255" json:"task_name"`
AppID int64 `json:"app_id"`
AppName string `gorm:"size:255" json:"app_name"`
Status string `json:"status"`
CurrentStep string `json:"current_step"`
RetryCount int `json:"retry_count"`
Values json.RawMessage `gorm:"type:text" json:"values,omitempty"`
InputFileIDs json.RawMessage `json:"input_file_ids" gorm:"column:input_file_ids;type:text"`
Script string `json:"script,omitempty"`
SlurmJobID *int32 `json:"slurm_job_id,omitempty"`
WorkDir string `json:"work_dir,omitempty"`
Partition string `json:"partition,omitempty"`
ErrorMessage string `json:"error_message,omitempty"`
UserID string `json:"user_id"`
SubmittedAt time.Time `json:"submitted_at"`
StartedAt *time.Time `json:"started_at,omitempty"`
FinishedAt *time.Time `json:"finished_at,omitempty"`
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
DeletedAt gorm.DeletedAt `gorm:"index" json:"deleted_at,omitempty"`
}
func (Task) TableName() string {
return "hpc_tasks"
}
// CreateTaskRequest is the DTO for creating a new task.
type CreateTaskRequest struct {
AppID int64 `json:"app_id" binding:"required"`
TaskName string `json:"task_name"`
Values map[string]string `json:"values"`
InputFileIDs []int64 `json:"file_ids"`
}
// TaskResponse is the DTO returned in API responses.
type TaskResponse struct {
ID int64 `json:"id"`
TaskName string `json:"task_name"`
AppID int64 `json:"app_id"`
AppName string `json:"app_name"`
Status string `json:"status"`
CurrentStep string `json:"current_step"`
RetryCount int `json:"retry_count"`
SlurmJobID *int32 `json:"slurm_job_id"`
WorkDir string `json:"work_dir"`
ErrorMessage string `json:"error_message"`
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
}
// TaskListResponse is the paginated response for listing tasks.
type TaskListResponse struct {
Items []TaskResponse `json:"items"`
Total int64 `json:"total"`
}
// TaskListQuery contains query parameters for listing tasks.
type TaskListQuery struct {
Page int `form:"page" json:"page,omitempty"`
PageSize int `form:"page_size" json:"page_size,omitempty"`
Status string `form:"status" json:"status,omitempty"`
}

internal/model/task_test.go Normal file

@@ -0,0 +1,104 @@
package model
import (
"encoding/json"
"testing"
"time"
)
func TestTask_TableName(t *testing.T) {
task := Task{}
if got := task.TableName(); got != "hpc_tasks" {
t.Errorf("Task.TableName() = %q, want %q", got, "hpc_tasks")
}
}
func TestTask_JSONRoundTrip(t *testing.T) {
now := time.Now().UTC().Truncate(time.Second)
jobID := int32(42)
task := Task{
ID: 1,
TaskName: "test task",
AppID: 10,
AppName: "GROMACS",
Status: TaskStatusRunning,
CurrentStep: TaskStepSubmitting,
RetryCount: 1,
Values: json.RawMessage(`{"np":"4"}`),
InputFileIDs: json.RawMessage(`[1,2,3]`),
Script: "#!/bin/bash",
SlurmJobID: &jobID,
WorkDir: "/data/work",
Partition: "gpu",
ErrorMessage: "",
UserID: "user1",
SubmittedAt: now,
StartedAt: &now,
FinishedAt: nil,
CreatedAt: now,
UpdatedAt: now,
}
data, err := json.Marshal(task)
if err != nil {
t.Fatalf("marshal task: %v", err)
}
var got Task
if err := json.Unmarshal(data, &got); err != nil {
t.Fatalf("unmarshal task: %v", err)
}
if got.ID != task.ID {
t.Errorf("ID = %d, want %d", got.ID, task.ID)
}
if got.TaskName != task.TaskName {
t.Errorf("TaskName = %q, want %q", got.TaskName, task.TaskName)
}
if got.Status != task.Status {
t.Errorf("Status = %q, want %q", got.Status, task.Status)
}
if got.CurrentStep != task.CurrentStep {
t.Errorf("CurrentStep = %q, want %q", got.CurrentStep, task.CurrentStep)
}
if got.RetryCount != task.RetryCount {
t.Errorf("RetryCount = %d, want %d", got.RetryCount, task.RetryCount)
}
if got.SlurmJobID == nil || *got.SlurmJobID != jobID {
t.Errorf("SlurmJobID = %v, want %d", got.SlurmJobID, jobID)
}
if got.UserID != task.UserID {
t.Errorf("UserID = %q, want %q", got.UserID, task.UserID)
}
if string(got.Values) != string(task.Values) {
t.Errorf("Values = %s, want %s", got.Values, task.Values)
}
if string(got.InputFileIDs) != string(task.InputFileIDs) {
t.Errorf("InputFileIDs = %s, want %s", got.InputFileIDs, task.InputFileIDs)
}
if got.FinishedAt != nil {
t.Errorf("FinishedAt = %v, want nil", got.FinishedAt)
}
}
func TestCreateTaskRequest_JSONBinding(t *testing.T) {
payload := `{"app_id":5,"task_name":"my task","values":{"np":"8"},"file_ids":[10,20]}`
var req CreateTaskRequest
if err := json.Unmarshal([]byte(payload), &req); err != nil {
t.Fatalf("unmarshal CreateTaskRequest: %v", err)
}
if req.AppID != 5 {
t.Errorf("AppID = %d, want 5", req.AppID)
}
if req.TaskName != "my task" {
t.Errorf("TaskName = %q, want %q", req.TaskName, "my task")
}
if v, ok := req.Values["np"]; !ok || v != "8" {
t.Errorf("Values[\"np\"] = %q, want %q", v, "8")
}
if len(req.InputFileIDs) != 2 || req.InputFileIDs[0] != 10 || req.InputFileIDs[1] != 20 {
t.Errorf("InputFileIDs = %v, want [10 20]", req.InputFileIDs)
}
}


@@ -1,45 +0,0 @@
package model
import "time"
// JobTemplate represents a saved job template.
type JobTemplate struct {
ID int64 `json:"id" gorm:"primaryKey;autoIncrement"`
Name string `json:"name" gorm:"uniqueIndex;size:255;not null"`
Description string `json:"description,omitempty" gorm:"type:text"`
Script string `json:"script" gorm:"type:text;not null"`
Partition string `json:"partition,omitempty" gorm:"size:255"`
QOS string `json:"qos,omitempty" gorm:"column:qos;size:255"`
CPUs int `json:"cpus,omitempty" gorm:"column:cpus"`
Memory string `json:"memory,omitempty" gorm:"size:50"`
TimeLimit string `json:"time_limit,omitempty" gorm:"column:time_limit;size:50"`
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
}
// TableName specifies the database table name for GORM.
func (JobTemplate) TableName() string { return "job_templates" }
// CreateTemplateRequest is the API request for creating a template.
type CreateTemplateRequest struct {
Name string `json:"name"`
Description string `json:"description,omitempty"`
Script string `json:"script"`
Partition string `json:"partition,omitempty"`
QOS string `json:"qos,omitempty"`
CPUs int `json:"cpus,omitempty"`
Memory string `json:"memory,omitempty"`
TimeLimit string `json:"time_limit,omitempty"`
}
// UpdateTemplateRequest is the API request for updating a template.
type UpdateTemplateRequest struct {
Name string `json:"name,omitempty"`
Description string `json:"description,omitempty"`
Script string `json:"script,omitempty"`
Partition string `json:"partition,omitempty"`
QOS string `json:"qos,omitempty"`
CPUs int `json:"cpus,omitempty"`
Memory string `json:"memory,omitempty"`
TimeLimit string `json:"time_limit,omitempty"`
}


@@ -1,7 +1,11 @@
package server
import (
"fmt"
"io"
"net/http"
"strconv"
"strings"
"github.com/gin-gonic/gin"
)
@@ -42,3 +46,97 @@ func InternalError(c *gin.Context, msg string) {
func ErrorWithStatus(c *gin.Context, code int, msg string) {
c.JSON(code, APIResponse{Success: false, Error: msg})
}
// ParseRange parses an HTTP Range header (RFC 7233).
// Only single-part ranges are supported: bytes=start-end, bytes=start-, bytes=-suffix.
// Multi-part ranges (bytes=0-100,200-300) return an error.
func ParseRange(rangeHeader string, fileSize int64) (start, end int64, err error) {
if rangeHeader == "" {
return 0, 0, fmt.Errorf("empty range header")
}
if !strings.HasPrefix(rangeHeader, "bytes=") {
return 0, 0, fmt.Errorf("invalid range unit: %s", rangeHeader)
}
rangeSpec := strings.TrimPrefix(rangeHeader, "bytes=")
if strings.Contains(rangeSpec, ",") {
return 0, 0, fmt.Errorf("multi-part ranges are not supported")
}
rangeSpec = strings.TrimSpace(rangeSpec)
parts := strings.Split(rangeSpec, "-")
if len(parts) != 2 {
return 0, 0, fmt.Errorf("invalid range format: %s", rangeSpec)
}
if parts[0] == "" {
suffix, parseErr := strconv.ParseInt(parts[1], 10, 64)
if parseErr != nil {
return 0, 0, fmt.Errorf("invalid suffix range: %s", parts[1])
}
if suffix <= 0 || suffix > fileSize {
return 0, 0, fmt.Errorf("invalid suffix length %d for file size %d", suffix, fileSize)
}
start = fileSize - suffix
end = fileSize - 1
} else if parts[1] == "" {
start, err = strconv.ParseInt(parts[0], 10, 64)
if err != nil {
return 0, 0, fmt.Errorf("invalid range start: %s", parts[0])
}
if start >= fileSize {
return 0, 0, fmt.Errorf("range start %d exceeds file size %d", start, fileSize)
}
end = fileSize - 1
} else {
start, err = strconv.ParseInt(parts[0], 10, 64)
if err != nil {
return 0, 0, fmt.Errorf("invalid range start: %s", parts[0])
}
end, err = strconv.ParseInt(parts[1], 10, 64)
if err != nil {
return 0, 0, fmt.Errorf("invalid range end: %s", parts[1])
}
if start > end {
return 0, 0, fmt.Errorf("range start %d > end %d", start, end)
}
if start >= fileSize {
return 0, 0, fmt.Errorf("range start %d exceeds file size %d", start, fileSize)
}
if end >= fileSize {
end = fileSize - 1
}
}
return start, end, nil
}
// StreamFile sends a full file as an HTTP response with proper headers.
func StreamFile(c *gin.Context, reader io.ReadCloser, filename string, fileSize int64, contentType string) {
defer reader.Close()
c.Header("Content-Disposition", fmt.Sprintf(`attachment; filename="%s"`, filename))
c.Header("Content-Type", contentType)
c.Header("Content-Length", strconv.FormatInt(fileSize, 10))
c.Header("Accept-Ranges", "bytes")
c.Status(http.StatusOK)
io.Copy(c.Writer, reader)
}
// StreamRange sends a partial content response (206) for a byte range.
func StreamRange(c *gin.Context, reader io.ReadCloser, start, end, totalSize int64, contentType string) {
defer reader.Close()
contentLength := end - start + 1
c.Header("Content-Range", fmt.Sprintf("bytes %d-%d/%d", start, end, totalSize))
c.Header("Content-Type", contentType)
c.Header("Content-Length", strconv.FormatInt(contentLength, 10))
c.Header("Accept-Ranges", "bytes")
c.Status(http.StatusPartialContent)
io.Copy(c.Writer, reader)
}


@@ -114,3 +114,65 @@ func TestErrorWithStatus(t *testing.T) {
t.Fatalf("expected error 'already exists', got '%s'", resp.Error)
}
}
func TestParseRangeStandard(t *testing.T) {
tests := []struct {
rangeHeader string
fileSize int64
wantStart int64
wantEnd int64
wantErr bool
}{
{"bytes=0-1023", 10000, 0, 1023, false},
{"bytes=1024-", 10000, 1024, 9999, false},
{"bytes=-1024", 10000, 8976, 9999, false},
{"bytes=0-0", 10000, 0, 0, false},
{"bytes=9999-", 10000, 9999, 9999, false},
}
for _, tt := range tests {
start, end, err := ParseRange(tt.rangeHeader, tt.fileSize)
if (err != nil) != tt.wantErr {
t.Errorf("ParseRange(%q, %d) error = %v, wantErr %v", tt.rangeHeader, tt.fileSize, err, tt.wantErr)
continue
}
if !tt.wantErr {
if start != tt.wantStart || end != tt.wantEnd {
t.Errorf("ParseRange(%q, %d) = (%d, %d), want (%d, %d)", tt.rangeHeader, tt.fileSize, start, end, tt.wantStart, tt.wantEnd)
}
}
}
}
func TestParseRangeInvalidAndMultiPart(t *testing.T) {
tests := []struct {
rangeHeader string
fileSize int64
}{
{"", 10000},
{"bytes=9999-0", 10000},
{"bytes=20000-", 10000},
{"bytes=0-100,200-300", 10000},
{"bytes=0-100, 400-500", 10000},
{"bytes=", 10000},
{"chars=0-100", 10000},
}
for _, tt := range tests {
_, _, err := ParseRange(tt.rangeHeader, tt.fileSize)
if err == nil {
t.Errorf("ParseRange(%q, %d) expected error, got nil", tt.rangeHeader, tt.fileSize)
}
}
}
func TestParseRangeEdgeCases(t *testing.T) {
start, end, err := ParseRange("bytes=0-99999", 10000)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if end != 9999 {
t.Errorf("end = %d, want 9999 (clamped to fileSize-1)", end)
}
if start != 0 {
t.Errorf("start = %d, want 0", start)
}
}


@@ -25,16 +25,44 @@ type ClusterHandler interface {
GetDiag(c *gin.Context)
}
type TemplateHandler interface {
ListTemplates(c *gin.Context)
CreateTemplate(c *gin.Context)
GetTemplate(c *gin.Context)
UpdateTemplate(c *gin.Context)
DeleteTemplate(c *gin.Context)
type ApplicationHandler interface {
ListApplications(c *gin.Context)
CreateApplication(c *gin.Context)
GetApplication(c *gin.Context)
UpdateApplication(c *gin.Context)
DeleteApplication(c *gin.Context)
// SubmitApplication(c *gin.Context) // [disabled] replaced by POST /tasks
}
type UploadHandler interface {
InitUpload(c *gin.Context)
GetUploadStatus(c *gin.Context)
UploadChunk(c *gin.Context)
CompleteUpload(c *gin.Context)
CancelUpload(c *gin.Context)
}
type FileHandler interface {
ListFiles(c *gin.Context)
GetFile(c *gin.Context)
DownloadFile(c *gin.Context)
DeleteFile(c *gin.Context)
}
type FolderHandler interface {
CreateFolder(c *gin.Context)
GetFolder(c *gin.Context)
ListFolders(c *gin.Context)
DeleteFolder(c *gin.Context)
}
type TaskHandler interface {
CreateTask(c *gin.Context)
ListTasks(c *gin.Context)
}
// NewRouter creates a Gin engine with all API v1 routes registered with real handlers.
func NewRouter(jobH JobHandler, clusterH ClusterHandler, templateH TemplateHandler, logger *zap.Logger) *gin.Engine {
func NewRouter(jobH JobHandler, clusterH ClusterHandler, appH ApplicationHandler, uploadH UploadHandler, fileH FileHandler, folderH FolderHandler, taskH TaskHandler, logger *zap.Logger) *gin.Engine {
gin.SetMode(gin.ReleaseMode)
r := gin.New()
r.Use(gin.Recovery())
@@ -59,12 +87,47 @@ func NewRouter(jobH JobHandler, clusterH ClusterHandler, templateH TemplateHandl
v1.GET("/diag", clusterH.GetDiag)
templates := v1.Group("/templates")
templates.GET("", templateH.ListTemplates)
templates.POST("", templateH.CreateTemplate)
templates.GET("/:id", templateH.GetTemplate)
templates.PUT("/:id", templateH.UpdateTemplate)
templates.DELETE("/:id", templateH.DeleteTemplate)
apps := v1.Group("/applications")
apps.GET("", appH.ListApplications)
apps.POST("", appH.CreateApplication)
apps.GET("/:id", appH.GetApplication)
apps.PUT("/:id", appH.UpdateApplication)
apps.DELETE("/:id", appH.DeleteApplication)
// apps.POST("/:id/submit", appH.SubmitApplication) // [disabled] replaced by POST /tasks
files := v1.Group("/files")
if uploadH != nil {
uploads := files.Group("/uploads")
uploads.POST("", uploadH.InitUpload)
uploads.GET("/:id", uploadH.GetUploadStatus)
uploads.PUT("/:id/chunks/:index", uploadH.UploadChunk)
uploads.POST("/:id/complete", uploadH.CompleteUpload)
uploads.DELETE("/:id", uploadH.CancelUpload)
}
if fileH != nil {
files.GET("", fileH.ListFiles)
files.GET("/:id", fileH.GetFile)
files.GET("/:id/download", fileH.DownloadFile)
files.DELETE("/:id", fileH.DeleteFile)
}
if folderH != nil {
folders := files.Group("/folders")
folders.POST("", folderH.CreateFolder)
folders.GET("", folderH.ListFolders)
folders.GET("/:id", folderH.GetFolder)
folders.DELETE("/:id", folderH.DeleteFolder)
}
if taskH != nil {
tasks := v1.Group("/tasks")
{
tasks.POST("", taskH.CreateTask)
tasks.GET("", taskH.ListTasks)
}
}
return r
}
@@ -95,12 +158,36 @@ func registerPlaceholderRoutes(v1 *gin.RouterGroup) {
v1.GET("/diag", notImplemented)
templates := v1.Group("/templates")
templates.GET("", notImplemented)
templates.POST("", notImplemented)
templates.GET("/:id", notImplemented)
templates.PUT("/:id", notImplemented)
templates.DELETE("/:id", notImplemented)
apps := v1.Group("/applications")
apps.GET("", notImplemented)
apps.POST("", notImplemented)
apps.GET("/:id", notImplemented)
apps.PUT("/:id", notImplemented)
apps.DELETE("/:id", notImplemented)
// apps.POST("/:id/submit", notImplemented) // [disabled] replaced by POST /tasks
files := v1.Group("/files")
uploads := files.Group("/uploads")
uploads.POST("", notImplemented)
uploads.GET("/:id", notImplemented)
uploads.PUT("/:id/chunks/:index", notImplemented)
uploads.POST("/:id/complete", notImplemented)
uploads.DELETE("/:id", notImplemented)
files.GET("", notImplemented)
files.GET("/:id", notImplemented)
files.GET("/:id/download", notImplemented)
files.DELETE("/:id", notImplemented)
folders := files.Group("/folders")
folders.POST("", notImplemented)
folders.GET("", notImplemented)
folders.GET("/:id", notImplemented)
folders.DELETE("/:id", notImplemented)
v1.POST("/tasks", notImplemented)
v1.GET("/tasks", notImplemented)
}
func notImplemented(c *gin.Context) {


@@ -27,11 +27,12 @@ func TestAllRoutesRegistered(t *testing.T) {
{"GET", "/api/v1/partitions"},
{"GET", "/api/v1/partitions/:name"},
{"GET", "/api/v1/diag"},
{"GET", "/api/v1/templates"},
{"POST", "/api/v1/templates"},
{"GET", "/api/v1/templates/:id"},
{"PUT", "/api/v1/templates/:id"},
{"DELETE", "/api/v1/templates/:id"},
{"GET", "/api/v1/applications"},
{"POST", "/api/v1/applications"},
{"GET", "/api/v1/applications/:id"},
{"PUT", "/api/v1/applications/:id"},
{"DELETE", "/api/v1/applications/:id"},
// {"POST", "/api/v1/applications/:id/submit"}, // [已禁用] 已被 POST /tasks 取代
}
routeMap := map[string]bool{}
@@ -74,7 +75,7 @@ func TestRegisteredPathReturns501(t *testing.T) {
{"GET", "/api/v1/nodes"},
{"GET", "/api/v1/partitions"},
{"GET", "/api/v1/diag"},
{"GET", "/api/v1/templates"},
{"GET", "/api/v1/applications"},
}
for _, ep := range endpoints {


@@ -0,0 +1,111 @@
package service
import (
"context"
"gcy_hpc_server/internal/model"
"gcy_hpc_server/internal/store"
"go.uber.org/zap"
)
// ApplicationService handles parameter validation, script rendering, and job
// submission for parameterized HPC applications.
type ApplicationService struct {
store *store.ApplicationStore
jobSvc *JobService
workDirBase string
logger *zap.Logger
taskSvc *TaskService
}
func NewApplicationService(store *store.ApplicationStore, jobSvc *JobService, workDirBase string, logger *zap.Logger, taskSvc ...*TaskService) *ApplicationService {
var ts *TaskService
if len(taskSvc) > 0 {
ts = taskSvc[0]
}
return &ApplicationService{store: store, jobSvc: jobSvc, workDirBase: workDirBase, logger: logger, taskSvc: ts}
}
// ListApplications delegates to the store.
func (s *ApplicationService) ListApplications(ctx context.Context, page, pageSize int) ([]model.Application, int, error) {
return s.store.List(ctx, page, pageSize)
}
// CreateApplication delegates to the store.
func (s *ApplicationService) CreateApplication(ctx context.Context, req *model.CreateApplicationRequest) (int64, error) {
return s.store.Create(ctx, req)
}
// GetApplication delegates to the store.
func (s *ApplicationService) GetApplication(ctx context.Context, id int64) (*model.Application, error) {
return s.store.GetByID(ctx, id)
}
// UpdateApplication delegates to the store.
func (s *ApplicationService) UpdateApplication(ctx context.Context, id int64, req *model.UpdateApplicationRequest) error {
return s.store.Update(ctx, id, req)
}
// DeleteApplication delegates to the store.
func (s *ApplicationService) DeleteApplication(ctx context.Context, id int64) error {
return s.store.Delete(ctx, id)
}
// [disabled] The frontend has fully migrated to the POST /tasks endpoint; this method is no longer called.
/* // SubmitFromApplication orchestrates the full submission flow.
// When TaskService is available, it delegates to ProcessTaskSync which creates
// an hpc_tasks record and runs the full pipeline. Otherwise falls back to the
// original direct implementation.
func (s *ApplicationService) SubmitFromApplication(ctx context.Context, applicationID int64, values map[string]string) (*model.JobResponse, error) {
// [disabled] Old direct-submission path, replaced by the TaskService pipeline. In production taskSvc is always non-nil, so this branch never runs.
// if s.taskSvc != nil {
req := &model.CreateTaskRequest{
AppID: applicationID,
Values: values,
InputFileIDs: nil, // old API has no file_ids concept
TaskName: "",
}
return s.taskSvc.ProcessTaskSync(ctx, req)
// }
// // Fallback: original direct logic when TaskService not available
// app, err := s.store.GetByID(ctx, applicationID)
// if err != nil {
// return nil, fmt.Errorf("get application: %w", err)
// }
// if app == nil {
// return nil, fmt.Errorf("application %d not found", applicationID)
// }
//
// var params []model.ParameterSchema
// if len(app.Parameters) > 0 {
// if err := json.Unmarshal(app.Parameters, &params); err != nil {
// return nil, fmt.Errorf("parse parameters: %w", err)
// }
// }
//
// if err := ValidateParams(params, values); err != nil {
// return nil, err
// }
//
// rendered := RenderScript(app.ScriptTemplate, params, values)
//
// workDir := ""
// if s.workDirBase != "" {
// safeName := SanitizeDirName(app.Name)
// subDir := time.Now().Format("20060102_150405") + "_" + RandomSuffix(4)
// workDir = filepath.Join(s.workDirBase, safeName, subDir)
// if err := os.MkdirAll(workDir, 0777); err != nil {
// return nil, fmt.Errorf("create work directory %s: %w", workDir, err)
// }
// // bypass umask to make sure the whole path is writable
// for dir := workDir; dir != s.workDirBase; dir = filepath.Dir(dir) {
// os.Chmod(dir, 0777)
// }
// os.Chmod(s.workDirBase, 0777)
// }
//
// req := &model.SubmitJobRequest{Script: rendered, WorkDir: workDir}
// return s.jobSvc.SubmitJob(ctx, req)
} */


@@ -0,0 +1,367 @@
package service
import (
"net/http"
"net/http/httptest"
"strings"
"testing"
"gcy_hpc_server/internal/model"
"gcy_hpc_server/internal/slurm"
"gcy_hpc_server/internal/store"
"go.uber.org/zap"
gormlogger "gorm.io/gorm/logger"
"gorm.io/driver/sqlite"
"gorm.io/gorm"
)
func setupApplicationService(t *testing.T, slurmHandler http.HandlerFunc) (*ApplicationService, func()) {
t.Helper()
srv := httptest.NewServer(slurmHandler)
client, _ := slurm.NewClient(srv.URL, srv.Client())
db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{
Logger: gormlogger.Default.LogMode(gormlogger.Silent),
})
if err != nil {
t.Fatalf("open sqlite: %v", err)
}
if err := db.AutoMigrate(&model.Application{}); err != nil {
t.Fatalf("auto migrate: %v", err)
}
jobSvc := NewJobService(client, zap.NewNop())
appStore := store.NewApplicationStore(db)
appSvc := NewApplicationService(appStore, jobSvc, "", zap.NewNop())
return appSvc, srv.Close
}
func TestValidateParams_AllRequired(t *testing.T) {
params := []model.ParameterSchema{
{Name: "NAME", Type: model.ParamTypeString, Required: true},
{Name: "COUNT", Type: model.ParamTypeInteger, Required: true},
}
values := map[string]string{"NAME": "hello", "COUNT": "5"}
if err := ValidateParams(params, values); err != nil {
t.Errorf("expected no error, got %v", err)
}
}
func TestValidateParams_MissingRequired(t *testing.T) {
params := []model.ParameterSchema{
{Name: "NAME", Type: model.ParamTypeString, Required: true},
}
values := map[string]string{}
err := ValidateParams(params, values)
if err == nil {
t.Fatal("expected error for missing required param")
}
if !strings.Contains(err.Error(), "NAME") {
t.Errorf("error should mention param name, got: %v", err)
}
}
func TestValidateParams_InvalidInteger(t *testing.T) {
params := []model.ParameterSchema{
{Name: "COUNT", Type: model.ParamTypeInteger, Required: true},
}
values := map[string]string{"COUNT": "abc"}
err := ValidateParams(params, values)
if err == nil {
t.Fatal("expected error for invalid integer")
}
if !strings.Contains(err.Error(), "integer") {
t.Errorf("error should mention integer, got: %v", err)
}
}
func TestValidateParams_InvalidEnum(t *testing.T) {
params := []model.ParameterSchema{
{Name: "MODE", Type: model.ParamTypeEnum, Required: true, Options: []string{"fast", "slow"}},
}
values := map[string]string{"MODE": "medium"}
err := ValidateParams(params, values)
if err == nil {
t.Fatal("expected error for invalid enum value")
}
if !strings.Contains(err.Error(), "MODE") {
t.Errorf("error should mention param name, got: %v", err)
}
}
func TestValidateParams_BooleanValues(t *testing.T) {
params := []model.ParameterSchema{
{Name: "FLAG", Type: model.ParamTypeBoolean, Required: true},
}
for _, val := range []string{"true", "false", "1", "0"} {
err := ValidateParams(params, map[string]string{"FLAG": val})
if err != nil {
t.Errorf("boolean value %q should be valid, got error: %v", val, err)
}
}
}
func TestRenderScript_SimpleReplacement(t *testing.T) {
params := []model.ParameterSchema{{Name: "INPUT", Type: model.ParamTypeString}}
values := map[string]string{"INPUT": "data.txt"}
result := RenderScript("echo $INPUT", params, values)
expected := "echo 'data.txt'"
if result != expected {
t.Errorf("got %q, want %q", result, expected)
}
}
func TestRenderScript_DefaultValues(t *testing.T) {
params := []model.ParameterSchema{{Name: "OUTPUT", Type: model.ParamTypeString, Default: "out.log"}}
values := map[string]string{}
result := RenderScript("cat $OUTPUT", params, values)
expected := "cat 'out.log'"
if result != expected {
t.Errorf("got %q, want %q", result, expected)
}
}
func TestRenderScript_PreservesUnknownVars(t *testing.T) {
params := []model.ParameterSchema{{Name: "INPUT", Type: model.ParamTypeString}}
values := map[string]string{"INPUT": "data.txt"}
result := RenderScript("export HOME=$HOME\necho $INPUT\necho $PATH", params, values)
if !strings.Contains(result, "$HOME") {
t.Error("$HOME should be preserved")
}
if !strings.Contains(result, "$PATH") {
t.Error("$PATH should be preserved")
}
if !strings.Contains(result, "'data.txt'") {
t.Error("$INPUT should be replaced")
}
}
func TestRenderScript_ShellEscaping(t *testing.T) {
params := []model.ParameterSchema{{Name: "INPUT", Type: model.ParamTypeString}}
tests := []struct {
name string
value string
expected string
}{
{"semicolon injection", "; rm -rf /", "'; rm -rf /'"},
{"command substitution", "$(cat /etc/passwd)", "'$(cat /etc/passwd)'"},
{"single quote", "hello'world", "'hello'\\''world'"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := RenderScript("$INPUT", params, map[string]string{"INPUT": tt.value})
if result != tt.expected {
t.Errorf("got %q, want %q", result, tt.expected)
}
})
}
}
func TestRenderScript_OverlappingParams(t *testing.T) {
template := "$JOB_NAME and $JOB"
params := []model.ParameterSchema{
{Name: "JOB", Type: model.ParamTypeString},
{Name: "JOB_NAME", Type: model.ParamTypeString},
}
values := map[string]string{"JOB": "myjob", "JOB_NAME": "my-test-job"}
result := RenderScript(template, params, values)
if strings.Contains(result, "$JOB_NAME") {
t.Error("$JOB_NAME was not replaced")
}
if strings.Contains(result, "$JOB") {
t.Error("$JOB was not replaced")
}
if !strings.Contains(result, "'my-test-job'") {
t.Errorf("expected 'my-test-job' in result, got: %s", result)
}
if !strings.Contains(result, "'myjob'") {
t.Errorf("expected 'myjob' in result, got: %s", result)
}
}
func TestRenderScript_NewlineInValue(t *testing.T) {
params := []model.ParameterSchema{{Name: "CMD", Type: model.ParamTypeString}}
values := map[string]string{"CMD": "line1\nline2"}
result := RenderScript("echo $CMD", params, values)
expected := "echo 'line1\nline2'"
if result != expected {
t.Errorf("got %q, want %q", result, expected)
}
}
// [Disabled] Tested the old direct-submission path, which has been commented out.
// [Disabled] Tested the old direct-submission path, which has been commented out.
/*
func TestSubmitFromApplication_Success(t *testing.T) {
jobID := int32(42)
appSvc, cleanup := setupApplicationService(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
json.NewEncoder(w).Encode(slurm.OpenapiJobSubmitResponse{
Result: &slurm.JobSubmitResponseMsg{JobID: &jobID},
})
}))
defer cleanup()
id, err := appSvc.store.Create(context.Background(), &model.CreateApplicationRequest{
Name: "test-app",
ScriptTemplate: "#!/bin/bash\n#SBATCH --job-name=$JOB_NAME\necho $INPUT",
Parameters: json.RawMessage(`[{"name":"JOB_NAME","type":"string","required":true},{"name":"INPUT","type":"string","required":true}]`),
})
if err != nil {
t.Fatalf("create app: %v", err)
}
resp, err := appSvc.SubmitFromApplication(context.Background(), id, map[string]string{
"JOB_NAME": "my-job",
"INPUT": "hello",
})
if err != nil {
t.Fatalf("SubmitFromApplication() error = %v", err)
}
if resp.JobID != 42 {
t.Errorf("JobID = %d, want 42", resp.JobID)
}
}
*/
// [Disabled] Tested the old direct-submission path, which has been commented out.
/*
func TestSubmitFromApplication_AppNotFound(t *testing.T) {
appSvc, cleanup := setupApplicationService(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}))
defer cleanup()
_, err := appSvc.SubmitFromApplication(context.Background(), 99999, map[string]string{})
if err == nil {
t.Fatal("expected error for non-existent app")
}
if !strings.Contains(err.Error(), "not found") {
t.Errorf("error should mention 'not found', got: %v", err)
}
}
*/
// [Disabled] Tested the old direct-submission path, which has been commented out.
/*
func TestSubmitFromApplication_ValidationFail(t *testing.T) {
appSvc, cleanup := setupApplicationService(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}))
defer cleanup()
_, err := appSvc.store.Create(context.Background(), &model.CreateApplicationRequest{
Name: "valid-app",
ScriptTemplate: "#!/bin/bash\necho $INPUT",
Parameters: json.RawMessage(`[{"name":"INPUT","type":"string","required":true}]`),
})
_, err = appSvc.SubmitFromApplication(context.Background(), 1, map[string]string{})
if err == nil {
t.Fatal("expected validation error for missing required param")
}
if !strings.Contains(err.Error(), "missing") {
t.Errorf("error should mention 'missing', got: %v", err)
}
}
*/
// [Disabled] Tested the old direct-submission path, which has been commented out.
/*
func TestSubmitFromApplication_NoParameters(t *testing.T) {
jobID := int32(99)
appSvc, cleanup := setupApplicationService(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
json.NewEncoder(w).Encode(slurm.OpenapiJobSubmitResponse{
Result: &slurm.JobSubmitResponseMsg{JobID: &jobID},
})
}))
defer cleanup()
id, err := appSvc.store.Create(context.Background(), &model.CreateApplicationRequest{
Name: "simple-app",
ScriptTemplate: "#!/bin/bash\necho hello",
})
if err != nil {
t.Fatalf("create app: %v", err)
}
resp, err := appSvc.SubmitFromApplication(context.Background(), id, map[string]string{})
if err != nil {
t.Fatalf("SubmitFromApplication() error = %v", err)
}
if resp.JobID != 99 {
t.Errorf("JobID = %d, want 99", resp.JobID)
}
}
*/
// [Disabled] The frontend has fully migrated to the POST /tasks endpoint; this test is no longer needed.
/*
func TestSubmitFromApplication_DelegatesToTaskService(t *testing.T) {
jobID := int32(77)
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
json.NewEncoder(w).Encode(slurm.OpenapiJobSubmitResponse{
Result: &slurm.JobSubmitResponseMsg{JobID: &jobID},
})
}))
defer srv.Close()
client, _ := slurm.NewClient(srv.URL, srv.Client())
jobSvc := NewJobService(client, zap.NewNop())
db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{
Logger: gormlogger.Default.LogMode(gormlogger.Silent),
})
if err != nil {
t.Fatalf("open sqlite: %v", err)
}
if err := db.AutoMigrate(&model.Application{}, &model.Task{}, &model.File{}, &model.FileBlob{}); err != nil {
t.Fatalf("auto migrate: %v", err)
}
appStore := store.NewApplicationStore(db)
taskStore := store.NewTaskStore(db)
fileStore := store.NewFileStore(db)
blobStore := store.NewBlobStore(db)
workDirBase := filepath.Join(t.TempDir(), "workdir")
os.MkdirAll(workDirBase, 0777)
taskSvc := NewTaskService(taskStore, appStore, fileStore, blobStore, nil, jobSvc, workDirBase, zap.NewNop())
appSvc := NewApplicationService(appStore, jobSvc, workDirBase, zap.NewNop(), taskSvc)
id, err := appStore.Create(context.Background(), &model.CreateApplicationRequest{
Name: "delegated-app",
ScriptTemplate: "#!/bin/bash\n#SBATCH --job-name=$JOB_NAME\necho $INPUT",
Parameters: json.RawMessage(`[{"name":"JOB_NAME","type":"string","required":true},{"name":"INPUT","type":"string","required":true}]`),
})
if err != nil {
t.Fatalf("create app: %v", err)
}
resp, err := appSvc.SubmitFromApplication(context.Background(), id, map[string]string{
"JOB_NAME": "delegated-job",
"INPUT": "test-data",
})
if err != nil {
t.Fatalf("SubmitFromApplication() error = %v", err)
}
if resp.JobID != 77 {
t.Errorf("JobID = %d, want 77", resp.JobID)
}
var task model.Task
if err := db.Where("app_id = ?", id).First(&task).Error; err != nil {
t.Fatalf("no hpc_tasks record found for app_id %d: %v", id, err)
}
if task.SlurmJobID == nil || *task.SlurmJobID != 77 {
t.Errorf("task SlurmJobID = %v, want 77", task.SlurmJobID)
}
if task.Status != model.TaskStatusQueued {
t.Errorf("task Status = %q, want %q", task.Status, model.TaskStatusQueued)
}
}
*/

View File

@@ -4,6 +4,7 @@ import (
"context"
"fmt"
"strconv"
"time"
"gcy_hpc_server/internal/model"
"gcy_hpc_server/internal/slurm"
@@ -45,6 +46,27 @@ func uint32NoValString(v *slurm.Uint32NoVal) string {
return ""
}
func derefUint64NoValInt64(v *slurm.Uint64NoVal) *int64 {
if v != nil && v.Number != nil {
return v.Number
}
return nil
}
func derefCSVString(cs *slurm.CSVString) string {
if cs == nil || len(*cs) == 0 {
return ""
}
result := ""
for i, s := range *cs {
if i > 0 {
result += ","
}
result += s
}
return result
}
type ClusterService struct {
client *slurm.Client
logger *zap.Logger
@@ -55,11 +77,30 @@ func NewClusterService(client *slurm.Client, logger *zap.Logger) *ClusterService
}
func (s *ClusterService) GetNodes(ctx context.Context) ([]model.NodeResponse, error) {
s.logger.Debug("slurm API request",
zap.String("operation", "GetNodes"),
)
start := time.Now()
resp, _, err := s.client.Nodes.GetNodes(ctx, nil)
took := time.Since(start)
if err != nil {
s.logger.Debug("slurm API error response",
zap.String("operation", "GetNodes"),
zap.Duration("took", took),
zap.Error(err),
)
s.logger.Error("failed to get nodes", zap.Error(err))
return nil, fmt.Errorf("get nodes: %w", err)
}
s.logger.Debug("slurm API response",
zap.String("operation", "GetNodes"),
zap.Duration("took", took),
zap.Any("body", resp),
)
if resp.Nodes == nil {
return nil, nil
}
@@ -71,11 +112,33 @@ func (s *ClusterService) GetNodes(ctx context.Context) ([]model.NodeResponse, er
}
func (s *ClusterService) GetNode(ctx context.Context, name string) (*model.NodeResponse, error) {
s.logger.Debug("slurm API request",
zap.String("operation", "GetNode"),
zap.String("node_name", name),
)
start := time.Now()
resp, _, err := s.client.Nodes.GetNode(ctx, name, nil)
took := time.Since(start)
if err != nil {
s.logger.Debug("slurm API error response",
zap.String("operation", "GetNode"),
zap.String("node_name", name),
zap.Duration("took", took),
zap.Error(err),
)
s.logger.Error("failed to get node", zap.String("name", name), zap.Error(err))
return nil, fmt.Errorf("get node %s: %w", name, err)
}
s.logger.Debug("slurm API response",
zap.String("operation", "GetNode"),
zap.String("node_name", name),
zap.Duration("took", took),
zap.Any("body", resp),
)
if resp.Nodes == nil || len(*resp.Nodes) == 0 {
return nil, nil
}
@@ -85,11 +148,30 @@ func (s *ClusterService) GetNode(ctx context.Context, name string) (*model.NodeR
}
func (s *ClusterService) GetPartitions(ctx context.Context) ([]model.PartitionResponse, error) {
s.logger.Debug("slurm API request",
zap.String("operation", "GetPartitions"),
)
start := time.Now()
resp, _, err := s.client.Partitions.GetPartitions(ctx, nil)
took := time.Since(start)
if err != nil {
s.logger.Debug("slurm API error response",
zap.String("operation", "GetPartitions"),
zap.Duration("took", took),
zap.Error(err),
)
s.logger.Error("failed to get partitions", zap.Error(err))
return nil, fmt.Errorf("get partitions: %w", err)
}
s.logger.Debug("slurm API response",
zap.String("operation", "GetPartitions"),
zap.Duration("took", took),
zap.Any("body", resp),
)
if resp.Partitions == nil {
return nil, nil
}
@@ -101,11 +183,33 @@ func (s *ClusterService) GetPartitions(ctx context.Context) ([]model.PartitionRe
}
func (s *ClusterService) GetPartition(ctx context.Context, name string) (*model.PartitionResponse, error) {
s.logger.Debug("slurm API request",
zap.String("operation", "GetPartition"),
zap.String("partition_name", name),
)
start := time.Now()
resp, _, err := s.client.Partitions.GetPartition(ctx, name, nil)
took := time.Since(start)
if err != nil {
s.logger.Debug("slurm API error response",
zap.String("operation", "GetPartition"),
zap.String("partition_name", name),
zap.Duration("took", took),
zap.Error(err),
)
s.logger.Error("failed to get partition", zap.String("name", name), zap.Error(err))
return nil, fmt.Errorf("get partition %s: %w", name, err)
}
s.logger.Debug("slurm API response",
zap.String("operation", "GetPartition"),
zap.String("partition_name", name),
zap.Duration("took", took),
zap.Any("body", resp),
)
if resp.Partitions == nil || len(*resp.Partitions) == 0 {
return nil, nil
}
@@ -115,11 +219,30 @@ func (s *ClusterService) GetPartition(ctx context.Context, name string) (*model.
}
func (s *ClusterService) GetDiag(ctx context.Context) (*slurm.OpenapiDiagResp, error) {
s.logger.Debug("slurm API request",
zap.String("operation", "GetDiag"),
)
start := time.Now()
resp, _, err := s.client.Diag.GetDiag(ctx)
took := time.Since(start)
if err != nil {
s.logger.Debug("slurm API error response",
zap.String("operation", "GetDiag"),
zap.Duration("took", took),
zap.Error(err),
)
s.logger.Error("failed to get diag", zap.Error(err))
return nil, fmt.Errorf("get diag: %w", err)
}
s.logger.Debug("slurm API response",
zap.String("operation", "GetDiag"),
zap.Duration("took", took),
zap.Any("body", resp),
)
return resp, nil
}
@@ -128,17 +251,39 @@ func mapNode(n slurm.Node) model.NodeResponse {
Name: derefStr(n.Name),
State: n.State,
CPUs: derefInt32(n.Cpus),
AllocCpus: n.AllocCpus,
Cores: n.Cores,
Sockets: n.Sockets,
Threads: n.Threads,
RealMemory: derefInt64(n.RealMemory),
-AllocMem: derefInt64(n.AllocMemory),
+AllocMemory: derefInt64(n.AllocMemory),
FreeMem: derefUint64NoValInt64(n.FreeMem),
CpuLoad: n.CpuLoad,
Arch: derefStr(n.Architecture),
OS: derefStr(n.OperatingSystem),
Gres: derefStr(n.Gres),
GresUsed: derefStr(n.GresUsed),
Reason: derefStr(n.Reason),
ReasonSetByUser: derefStr(n.ReasonSetByUser),
Address: derefStr(n.Address),
Hostname: derefStr(n.Hostname),
Weight: n.Weight,
Features: derefCSVString(n.Features),
ActiveFeatures: derefCSVString(n.ActiveFeatures),
}
}
func mapPartition(pi slurm.PartitionInfo) model.PartitionResponse {
var state []string
var isDefault bool
if pi.Partition != nil {
state = pi.Partition.State
for _, s := range state {
if s == "DEFAULT" {
isDefault = true
break
}
}
}
var nodes string
if pi.Nodes != nil {
@@ -156,12 +301,56 @@ func mapPartition(pi slurm.PartitionInfo) model.PartitionResponse {
if pi.Maximums != nil {
maxTime = uint32NoValString(pi.Maximums.Time)
}
var maxNodes *int32
if pi.Maximums != nil {
maxNodes = mapUint32NoValToInt32(pi.Maximums.Nodes)
}
var maxCPUsPerNode *int32
if pi.Maximums != nil {
maxCPUsPerNode = mapUint32NoValToInt32(pi.Maximums.CpusPerNode)
}
var minNodes *int32
if pi.Minimums != nil {
minNodes = pi.Minimums.Nodes
}
var defaultTime string
if pi.Defaults != nil {
defaultTime = uint32NoValString(pi.Defaults.Time)
}
graceTime := pi.GraceTime
var priority *int32
if pi.Priority != nil {
priority = pi.Priority.JobFactor
}
var qosAllowed, qosDeny, qosAssigned string
if pi.QOS != nil {
qosAllowed = derefStr(pi.QOS.Allowed)
qosDeny = derefStr(pi.QOS.Deny)
qosAssigned = derefStr(pi.QOS.Assigned)
}
var accountsAllowed, accountsDeny string
if pi.Accounts != nil {
accountsAllowed = derefStr(pi.Accounts.Allowed)
accountsDeny = derefStr(pi.Accounts.Deny)
}
return model.PartitionResponse{
Name: derefStr(pi.Name),
State: state,
Default: isDefault,
Nodes: nodes,
-TotalCPUs: totalCPUs,
TotalNodes: totalNodes,
+TotalCPUs: totalCPUs,
MaxTime: maxTime,
MaxNodes: maxNodes,
MaxCPUsPerNode: maxCPUsPerNode,
MinNodes: minNodes,
DefaultTime: defaultTime,
GraceTime: graceTime,
Priority: priority,
QOSAllowed: qosAllowed,
QOSDeny: qosDeny,
QOSAssigned: qosAssigned,
AccountsAllowed: accountsAllowed,
AccountsDeny: accountsDeny,
}
}

View File

@@ -352,10 +352,10 @@ func TestClusterService_GetNodes_ErrorLogging(t *testing.T) {
t.Fatal("expected error, got nil")
}
-if logs.Len() != 1 {
-t.Fatalf("expected 1 log entry, got %d", logs.Len())
+if logs.Len() != 3 {
+t.Fatalf("expected 3 log entries, got %d", logs.Len())
}
-entry := logs.All()[0]
+entry := logs.All()[2]
if entry.Level != zapcore.ErrorLevel {
t.Errorf("expected ErrorLevel, got %v", entry.Level)
}
@@ -374,10 +374,10 @@ func TestClusterService_GetNode_ErrorLogging(t *testing.T) {
t.Fatal("expected error, got nil")
}
-if logs.Len() != 1 {
-t.Fatalf("expected 1 log entry, got %d", logs.Len())
+if logs.Len() != 3 {
+t.Fatalf("expected 3 log entries, got %d", logs.Len())
}
-entry := logs.All()[0]
+entry := logs.All()[2]
if entry.Level != zapcore.ErrorLevel {
t.Errorf("expected ErrorLevel, got %v", entry.Level)
}
@@ -403,10 +403,10 @@ func TestClusterService_GetPartitions_ErrorLogging(t *testing.T) {
t.Fatal("expected error, got nil")
}
-if logs.Len() != 1 {
-t.Fatalf("expected 1 log entry, got %d", logs.Len())
+if logs.Len() != 3 {
+t.Fatalf("expected 3 log entries, got %d", logs.Len())
}
-entry := logs.All()[0]
+entry := logs.All()[2]
if entry.Level != zapcore.ErrorLevel {
t.Errorf("expected ErrorLevel, got %v", entry.Level)
}
@@ -425,10 +425,10 @@ func TestClusterService_GetPartition_ErrorLogging(t *testing.T) {
t.Fatal("expected error, got nil")
}
-if logs.Len() != 1 {
-t.Fatalf("expected 1 log entry, got %d", logs.Len())
+if logs.Len() != 3 {
+t.Fatalf("expected 3 log entries, got %d", logs.Len())
}
-entry := logs.All()[0]
+entry := logs.All()[2]
if entry.Level != zapcore.ErrorLevel {
t.Errorf("expected ErrorLevel, got %v", entry.Level)
}
@@ -454,10 +454,10 @@ func TestClusterService_GetDiag_ErrorLogging(t *testing.T) {
t.Fatal("expected error, got nil")
}
-if logs.Len() != 1 {
-t.Fatalf("expected 1 log entry, got %d", logs.Len())
+if logs.Len() != 3 {
+t.Fatalf("expected 3 log entries, got %d", logs.Len())
}
-entry := logs.All()[0]
+entry := logs.All()[2]
if entry.Level != zapcore.ErrorLevel {
t.Errorf("expected ErrorLevel, got %v", entry.Level)
}

View File

@@ -0,0 +1,98 @@
package service
import (
"context"
"fmt"
"io"
"gcy_hpc_server/internal/model"
"gcy_hpc_server/internal/server"
"gcy_hpc_server/internal/storage"
"gcy_hpc_server/internal/store"
"go.uber.org/zap"
)
// DownloadService handles file downloads with streaming and Range support.
type DownloadService struct {
storage storage.ObjectStorage
blobStore *store.BlobStore
fileStore *store.FileStore
bucket string
logger *zap.Logger
}
// NewDownloadService creates a new DownloadService.
func NewDownloadService(storage storage.ObjectStorage, blobStore *store.BlobStore, fileStore *store.FileStore, bucket string, logger *zap.Logger) *DownloadService {
return &DownloadService{
storage: storage,
blobStore: blobStore,
fileStore: fileStore,
bucket: bucket,
logger: logger,
}
}
// Download returns a stream reader for the given file, optionally limited to a byte range.
// Returns (reader, file, blob, start, end, error).
func (s *DownloadService) Download(ctx context.Context, fileID int64, rangeHeader string) (io.ReadCloser, *model.File, *model.FileBlob, int64, int64, error) {
file, err := s.fileStore.GetByID(ctx, fileID)
if err != nil {
return nil, nil, nil, 0, 0, fmt.Errorf("get file: %w", err)
}
if file == nil {
return nil, nil, nil, 0, 0, fmt.Errorf("file not found")
}
blob, err := s.blobStore.GetBySHA256(ctx, file.BlobSHA256)
if err != nil {
return nil, nil, nil, 0, 0, fmt.Errorf("get blob: %w", err)
}
if blob == nil {
return nil, nil, nil, 0, 0, fmt.Errorf("blob not found")
}
var start, end int64
if rangeHeader != "" {
start, end, err = server.ParseRange(rangeHeader, blob.FileSize)
if err != nil {
return nil, nil, nil, 0, 0, fmt.Errorf("parse range: %w", err)
}
} else {
start = 0
end = blob.FileSize - 1
}
opts := storage.GetOptions{
Start: &start,
End: &end,
}
reader, _, err := s.storage.GetObject(ctx, s.bucket, blob.MinioKey, opts)
if err != nil {
return nil, nil, nil, 0, 0, fmt.Errorf("get object: %w", err)
}
return reader, file, blob, start, end, nil
}
// GetFileMetadata returns the file and its associated blob metadata.
func (s *DownloadService) GetFileMetadata(ctx context.Context, fileID int64) (*model.File, *model.FileBlob, error) {
file, err := s.fileStore.GetByID(ctx, fileID)
if err != nil {
return nil, nil, fmt.Errorf("get file: %w", err)
}
if file == nil {
return nil, nil, fmt.Errorf("file not found")
}
blob, err := s.blobStore.GetBySHA256(ctx, file.BlobSHA256)
if err != nil {
return nil, nil, fmt.Errorf("get blob: %w", err)
}
if blob == nil {
return nil, nil, fmt.Errorf("blob not found")
}
return file, blob, nil
}

View File

@@ -0,0 +1,260 @@
package service
import (
"bytes"
"context"
"io"
"testing"
"time"
"gcy_hpc_server/internal/model"
"gcy_hpc_server/internal/storage"
"gcy_hpc_server/internal/store"
"go.uber.org/zap"
gormlogger "gorm.io/gorm/logger"
"gorm.io/driver/sqlite"
"gorm.io/gorm"
)
type mockDownloadStorage struct {
getObjectFn func(ctx context.Context, bucket, key string, opts storage.GetOptions) (io.ReadCloser, storage.ObjectInfo, error)
}
func (m *mockDownloadStorage) PutObject(_ context.Context, _ string, _ string, _ io.Reader, _ int64, _ storage.PutObjectOptions) (storage.UploadInfo, error) {
return storage.UploadInfo{}, nil
}
func (m *mockDownloadStorage) GetObject(ctx context.Context, bucket, key string, opts storage.GetOptions) (io.ReadCloser, storage.ObjectInfo, error) {
if m.getObjectFn != nil {
return m.getObjectFn(ctx, bucket, key, opts)
}
return nil, storage.ObjectInfo{}, nil
}
func (m *mockDownloadStorage) ComposeObject(_ context.Context, _ string, _ string, _ []string) (storage.UploadInfo, error) {
return storage.UploadInfo{}, nil
}
func (m *mockDownloadStorage) AbortMultipartUpload(_ context.Context, _ string, _ string, _ string) error {
return nil
}
func (m *mockDownloadStorage) RemoveIncompleteUpload(_ context.Context, _ string, _ string) error {
return nil
}
func (m *mockDownloadStorage) RemoveObject(_ context.Context, _ string, _ string, _ storage.RemoveObjectOptions) error {
return nil
}
func (m *mockDownloadStorage) ListObjects(_ context.Context, _ string, _ string, _ bool) ([]storage.ObjectInfo, error) {
return nil, nil
}
func (m *mockDownloadStorage) RemoveObjects(_ context.Context, _ string, _ []string, _ storage.RemoveObjectsOptions) error {
return nil
}
func (m *mockDownloadStorage) BucketExists(_ context.Context, _ string) (bool, error) {
return true, nil
}
func (m *mockDownloadStorage) MakeBucket(_ context.Context, _ string, _ storage.MakeBucketOptions) error {
return nil
}
func (m *mockDownloadStorage) StatObject(_ context.Context, _ string, _ string, _ storage.StatObjectOptions) (storage.ObjectInfo, error) {
return storage.ObjectInfo{}, nil
}
func setupDownloadService(t *testing.T) (*DownloadService, *store.FileStore, *store.BlobStore, *mockDownloadStorage) {
t.Helper()
db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{
Logger: gormlogger.Default.LogMode(gormlogger.Silent),
})
if err != nil {
t.Fatalf("open sqlite: %v", err)
}
if err := db.AutoMigrate(&model.File{}, &model.FileBlob{}); err != nil {
t.Fatalf("auto migrate: %v", err)
}
mockStorage := &mockDownloadStorage{}
blobStore := store.NewBlobStore(db)
fileStore := store.NewFileStore(db)
svc := NewDownloadService(mockStorage, blobStore, fileStore, "test-bucket", zap.NewNop())
return svc, fileStore, blobStore, mockStorage
}
func createTestFileAndBlob(t *testing.T, fileStore *store.FileStore, blobStore *store.BlobStore) (*model.File, *model.FileBlob) {
t.Helper()
blob := &model.FileBlob{
SHA256: "abc123def456abc123def456abc123def456abc123def456abc123def456abcd",
MinioKey: "chunks/session1/part-0",
FileSize: 5000,
MimeType: "application/octet-stream",
RefCount: 1,
}
if err := blobStore.Create(context.Background(), blob); err != nil {
t.Fatalf("create blob: %v", err)
}
file := &model.File{
Name: "test.dat",
BlobSHA256: blob.SHA256,
CreatedAt: time.Now(),
UpdatedAt: time.Now(),
}
if err := fileStore.Create(context.Background(), file); err != nil {
t.Fatalf("create file: %v", err)
}
return file, blob
}
func TestDownload_FullFile(t *testing.T) {
svc, fileStore, blobStore, mockStorage := setupDownloadService(t)
file, blob := createTestFileAndBlob(t, fileStore, blobStore)
content := make([]byte, blob.FileSize)
for i := range content {
content[i] = byte(i % 256)
}
mockStorage.getObjectFn = func(_ context.Context, _, _ string, opts storage.GetOptions) (io.ReadCloser, storage.ObjectInfo, error) {
if opts.Start == nil || opts.End == nil {
t.Fatal("expected Start and End to be set")
}
if *opts.Start != 0 {
t.Fatalf("expected start=0, got %d", *opts.Start)
}
if *opts.End != blob.FileSize-1 {
t.Fatalf("expected end=%d, got %d", blob.FileSize-1, *opts.End)
}
return io.NopCloser(bytes.NewReader(content)), storage.ObjectInfo{Size: blob.FileSize}, nil
}
reader, gotFile, gotBlob, start, end, err := svc.Download(context.Background(), file.ID, "")
if err != nil {
t.Fatalf("Download: %v", err)
}
defer reader.Close()
if gotFile.ID != file.ID {
t.Fatalf("expected file ID %d, got %d", file.ID, gotFile.ID)
}
if gotBlob.SHA256 != blob.SHA256 {
t.Fatalf("expected blob SHA256 %s, got %s", blob.SHA256, gotBlob.SHA256)
}
if start != 0 {
t.Fatalf("expected start=0, got %d", start)
}
if end != blob.FileSize-1 {
t.Fatalf("expected end=%d, got %d", blob.FileSize-1, end)
}
read, _ := io.ReadAll(reader)
if int64(len(read)) != blob.FileSize {
t.Fatalf("expected %d bytes, got %d", blob.FileSize, len(read))
}
}
func TestDownload_WithRange(t *testing.T) {
svc, fileStore, blobStore, mockStorage := setupDownloadService(t)
file, _ := createTestFileAndBlob(t, fileStore, blobStore)
content := make([]byte, 1024)
for i := range content {
content[i] = byte(i % 256)
}
mockStorage.getObjectFn = func(_ context.Context, _, _ string, opts storage.GetOptions) (io.ReadCloser, storage.ObjectInfo, error) {
if opts.Start == nil || opts.End == nil {
t.Fatal("expected Start and End to be set")
}
if *opts.Start != 0 {
t.Fatalf("expected start=0, got %d", *opts.Start)
}
if *opts.End != 1023 {
t.Fatalf("expected end=1023, got %d", *opts.End)
}
return io.NopCloser(bytes.NewReader(content[:1024])), storage.ObjectInfo{Size: 1024}, nil
}
reader, _, _, start, end, err := svc.Download(context.Background(), file.ID, "bytes=0-1023")
if err != nil {
t.Fatalf("Download: %v", err)
}
defer reader.Close()
if start != 0 {
t.Fatalf("expected start=0, got %d", start)
}
if end != 1023 {
t.Fatalf("expected end=1023, got %d", end)
}
read, _ := io.ReadAll(reader)
if len(read) != 1024 {
t.Fatalf("expected 1024 bytes, got %d", len(read))
}
}
func TestDownload_FileNotFound(t *testing.T) {
svc, _, _, _ := setupDownloadService(t)
_, _, _, _, _, err := svc.Download(context.Background(), 99999, "")
if err == nil {
t.Fatal("expected error for missing file")
}
if err.Error() != "file not found" {
t.Fatalf("unexpected error: %v", err)
}
}
func TestDownload_BlobNotFound(t *testing.T) {
svc, fileStore, _, _ := setupDownloadService(t)
file := &model.File{
Name: "orphan.dat",
BlobSHA256: "nonexistent_hash_0000000000000000000000000000000000000000",
CreatedAt: time.Now(),
UpdatedAt: time.Now(),
}
if err := fileStore.Create(context.Background(), file); err != nil {
t.Fatalf("create file: %v", err)
}
_, _, _, _, _, err := svc.Download(context.Background(), file.ID, "")
if err == nil {
t.Fatal("expected error for missing blob")
}
if err.Error() != "blob not found" {
t.Fatalf("unexpected error: %v", err)
}
}
func TestGetFileMetadata(t *testing.T) {
svc, fileStore, blobStore, _ := setupDownloadService(t)
file, _ := createTestFileAndBlob(t, fileStore, blobStore)
gotFile, gotBlob, err := svc.GetFileMetadata(context.Background(), file.ID)
if err != nil {
t.Fatalf("GetFileMetadata: %v", err)
}
if gotFile.ID != file.ID {
t.Fatalf("expected file ID %d, got %d", file.ID, gotFile.ID)
}
if gotBlob.FileSize != 5000 {
t.Fatalf("expected file size 5000, got %d", gotBlob.FileSize)
}
if gotBlob.MimeType != "application/octet-stream" {
t.Fatalf("expected mime type application/octet-stream, got %s", gotBlob.MimeType)
}
}

View File

@@ -0,0 +1,178 @@
package service
import (
"context"
"fmt"
"io"
"gcy_hpc_server/internal/model"
"gcy_hpc_server/internal/server"
"gcy_hpc_server/internal/storage"
"gcy_hpc_server/internal/store"
"go.uber.org/zap"
"gorm.io/gorm"
)
// FileService handles file listing, metadata, download, and deletion operations.
type FileService struct {
storage storage.ObjectStorage
blobStore *store.BlobStore
fileStore *store.FileStore
bucket string
db *gorm.DB
logger *zap.Logger
}
// NewFileService creates a new FileService.
func NewFileService(storage storage.ObjectStorage, blobStore *store.BlobStore, fileStore *store.FileStore, bucket string, db *gorm.DB, logger *zap.Logger) *FileService {
return &FileService{
storage: storage,
blobStore: blobStore,
fileStore: fileStore,
bucket: bucket,
db: db,
logger: logger,
}
}
// ListFiles returns a paginated list of files, optionally filtered by folder or search query.
func (s *FileService) ListFiles(ctx context.Context, folderID *int64, page, pageSize int, search string) ([]model.FileResponse, int64, error) {
var files []model.File
var total int64
var err error
if search != "" {
files, total, err = s.fileStore.Search(ctx, search, page, pageSize)
} else {
files, total, err = s.fileStore.List(ctx, folderID, page, pageSize)
}
if err != nil {
return nil, 0, fmt.Errorf("list files: %w", err)
}
responses := make([]model.FileResponse, 0, len(files))
for _, f := range files {
blob, err := s.blobStore.GetBySHA256(ctx, f.BlobSHA256)
if err != nil {
return nil, 0, fmt.Errorf("get blob for file %d: %w", f.ID, err)
}
if blob == nil {
return nil, 0, fmt.Errorf("blob not found for file %d", f.ID)
}
responses = append(responses, model.FileResponse{
ID: f.ID,
Name: f.Name,
FolderID: f.FolderID,
Size: blob.FileSize,
MimeType: blob.MimeType,
SHA256: f.BlobSHA256,
CreatedAt: f.CreatedAt,
UpdatedAt: f.UpdatedAt,
})
}
return responses, total, nil
}
// GetFileMetadata returns the file and its associated blob metadata.
func (s *FileService) GetFileMetadata(ctx context.Context, fileID int64) (*model.File, *model.FileBlob, error) {
file, err := s.fileStore.GetByID(ctx, fileID)
if err != nil {
return nil, nil, fmt.Errorf("get file: %w", err)
}
if file == nil {
return nil, nil, fmt.Errorf("file not found: %d", fileID)
}
blob, err := s.blobStore.GetBySHA256(ctx, file.BlobSHA256)
if err != nil {
return nil, nil, fmt.Errorf("get blob: %w", err)
}
if blob == nil {
return nil, nil, fmt.Errorf("blob not found for file %d", fileID)
}
return file, blob, nil
}
// DownloadFile returns a reader for the file content, along with file and blob metadata.
// If rangeHeader is non-empty, it parses the range and returns partial content.
func (s *FileService) DownloadFile(ctx context.Context, fileID int64, rangeHeader string) (io.ReadCloser, *model.File, *model.FileBlob, int64, int64, error) {
file, blob, err := s.GetFileMetadata(ctx, fileID)
if err != nil {
return nil, nil, nil, 0, 0, err
}
var start, end int64
if rangeHeader != "" {
start, end, err = server.ParseRange(rangeHeader, blob.FileSize)
if err != nil {
return nil, nil, nil, 0, 0, fmt.Errorf("parse range: %w", err)
}
} else {
start = 0
end = blob.FileSize - 1
}
reader, _, err := s.storage.GetObject(ctx, s.bucket, blob.MinioKey, storage.GetOptions{
Start: &start,
End: &end,
})
if err != nil {
return nil, nil, nil, 0, 0, fmt.Errorf("get object: %w", err)
}
return reader, file, blob, start, end, nil
}
// DeleteFile soft-deletes a file. If no other active files reference the same blob,
// it decrements the blob ref count and removes the object from storage when ref count reaches 0.
func (s *FileService) DeleteFile(ctx context.Context, fileID int64) error {
return s.db.Transaction(func(tx *gorm.DB) error {
txFileStore := store.NewFileStore(tx)
txBlobStore := store.NewBlobStore(tx)
blobSHA256, err := txFileStore.GetBlobSHA256ByID(ctx, fileID)
if err != nil {
return fmt.Errorf("get blob sha256: %w", err)
}
if blobSHA256 == "" {
return fmt.Errorf("file not found: %d", fileID)
}
if err := tx.Delete(&model.File{}, fileID).Error; err != nil {
return fmt.Errorf("soft delete file: %w", err)
}
activeCount, err := txFileStore.CountByBlobSHA256(ctx, blobSHA256)
if err != nil {
return fmt.Errorf("count active refs: %w", err)
}
if activeCount == 0 {
newRefCount, err := txBlobStore.DecrementRef(ctx, blobSHA256)
if err != nil {
return fmt.Errorf("decrement ref: %w", err)
}
if newRefCount == 0 {
blob, err := txBlobStore.GetBySHA256(ctx, blobSHA256)
if err != nil {
return fmt.Errorf("get blob for cleanup: %w", err)
}
if blob != nil {
if err := s.storage.RemoveObject(ctx, s.bucket, blob.MinioKey, storage.RemoveObjectOptions{}); err != nil {
return fmt.Errorf("remove object: %w", err)
}
if err := txBlobStore.Delete(ctx, blobSHA256); err != nil {
return fmt.Errorf("delete blob: %w", err)
}
}
}
}
return nil
})
}

View File

@@ -0,0 +1,484 @@
package service
import (
"bytes"
"context"
"errors"
"io"
"strings"
"testing"
"gcy_hpc_server/internal/model"
"gcy_hpc_server/internal/storage"
"gcy_hpc_server/internal/store"
"go.uber.org/zap"
"gorm.io/driver/sqlite"
"gorm.io/gorm"
)
type mockFileStorage struct {
getObjectFn func(ctx context.Context, bucket, key string, opts storage.GetOptions) (io.ReadCloser, storage.ObjectInfo, error)
removeObjectFn func(ctx context.Context, bucket, key string, opts storage.RemoveObjectOptions) error
}
func (m *mockFileStorage) PutObject(_ context.Context, _, _ string, _ io.Reader, _ int64, _ storage.PutObjectOptions) (storage.UploadInfo, error) {
return storage.UploadInfo{}, nil
}
func (m *mockFileStorage) GetObject(ctx context.Context, bucket, key string, opts storage.GetOptions) (io.ReadCloser, storage.ObjectInfo, error) {
if m.getObjectFn != nil {
return m.getObjectFn(ctx, bucket, key, opts)
}
return io.NopCloser(strings.NewReader("data")), storage.ObjectInfo{}, nil
}
func (m *mockFileStorage) ComposeObject(_ context.Context, _ string, _ string, _ []string) (storage.UploadInfo, error) {
return storage.UploadInfo{}, nil
}
func (m *mockFileStorage) AbortMultipartUpload(_ context.Context, _, _, _ string) error {
return nil
}
func (m *mockFileStorage) RemoveIncompleteUpload(_ context.Context, _, _ string) error {
return nil
}
func (m *mockFileStorage) RemoveObject(ctx context.Context, bucket, key string, opts storage.RemoveObjectOptions) error {
if m.removeObjectFn != nil {
return m.removeObjectFn(ctx, bucket, key, opts)
}
return nil
}
func (m *mockFileStorage) ListObjects(_ context.Context, _ string, _ string, _ bool) ([]storage.ObjectInfo, error) {
return nil, nil
}
func (m *mockFileStorage) RemoveObjects(_ context.Context, _ string, _ []string, _ storage.RemoveObjectsOptions) error {
return nil
}
func (m *mockFileStorage) BucketExists(_ context.Context, _ string) (bool, error) {
return true, nil
}
func (m *mockFileStorage) MakeBucket(_ context.Context, _ string, _ storage.MakeBucketOptions) error {
return nil
}
func (m *mockFileStorage) StatObject(_ context.Context, _, _ string, _ storage.StatObjectOptions) (storage.ObjectInfo, error) {
return storage.ObjectInfo{}, nil
}
func setupFileTestDB(t *testing.T) *gorm.DB {
t.Helper()
db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{})
if err != nil {
t.Fatalf("open sqlite: %v", err)
}
if err := db.AutoMigrate(&model.File{}, &model.FileBlob{}); err != nil {
t.Fatalf("migrate: %v", err)
}
return db
}
func setupFileService(t *testing.T) (*FileService, *mockFileStorage, *gorm.DB) {
t.Helper()
db := setupFileTestDB(t)
ms := &mockFileStorage{}
svc := NewFileService(ms, store.NewBlobStore(db), store.NewFileStore(db), "test-bucket", db, zap.NewNop())
return svc, ms, db
}
func createTestBlob(t *testing.T, db *gorm.DB, sha256, minioKey, mimeType string, fileSize int64, refCount int) *model.FileBlob {
t.Helper()
blob := &model.FileBlob{
SHA256: sha256,
MinioKey: minioKey,
FileSize: fileSize,
MimeType: mimeType,
RefCount: refCount,
}
if err := db.Create(blob).Error; err != nil {
t.Fatalf("create blob: %v", err)
}
return blob
}
func createTestFile(t *testing.T, db *gorm.DB, name, blobSHA256 string, folderID *int64) *model.File {
t.Helper()
file := &model.File{
Name: name,
FolderID: folderID,
BlobSHA256: blobSHA256,
}
if err := db.Create(file).Error; err != nil {
t.Fatalf("create file: %v", err)
}
return file
}
func TestListFiles_Empty(t *testing.T) {
svc, _, _ := setupFileService(t)
files, total, err := svc.ListFiles(context.Background(), nil, 1, 10, "")
if err != nil {
t.Fatalf("ListFiles: %v", err)
}
if total != 0 {
t.Errorf("expected total 0, got %d", total)
}
if len(files) != 0 {
t.Errorf("expected empty files, got %d", len(files))
}
}
func TestListFiles_WithFiles(t *testing.T) {
svc, _, db := setupFileService(t)
blob := createTestBlob(t, db, "sha256abc", "blobs/abc", "text/plain", 1024, 2)
createTestFile(t, db, "file1.txt", blob.SHA256, nil)
createTestFile(t, db, "file2.txt", blob.SHA256, nil)
files, total, err := svc.ListFiles(context.Background(), nil, 1, 10, "")
if err != nil {
t.Fatalf("ListFiles: %v", err)
}
if total != 2 {
t.Errorf("expected total 2, got %d", total)
}
if len(files) != 2 {
t.Fatalf("expected 2 files, got %d", len(files))
}
for _, f := range files {
if f.Size != 1024 {
t.Errorf("expected size 1024, got %d", f.Size)
}
if f.MimeType != "text/plain" {
t.Errorf("expected mime text/plain, got %s", f.MimeType)
}
if f.SHA256 != "sha256abc" {
t.Errorf("expected sha256 sha256abc, got %s", f.SHA256)
}
}
}
func TestListFiles_Search(t *testing.T) {
svc, _, db := setupFileService(t)
blob := createTestBlob(t, db, "sha256search", "blobs/search", "image/png", 2048, 1)
createTestFile(t, db, "photo.png", blob.SHA256, nil)
createTestBlob(t, db, "sha256other", "blobs/other", "application/pdf", 512, 1)
createTestFile(t, db, "document.pdf", "sha256other", nil)
files, total, err := svc.ListFiles(context.Background(), nil, 1, 10, "photo")
if err != nil {
t.Fatalf("ListFiles: %v", err)
}
if total != 1 {
t.Errorf("expected total 1, got %d", total)
}
if len(files) != 1 {
t.Fatalf("expected 1 file, got %d", len(files))
}
if files[0].Name != "photo.png" {
t.Errorf("expected photo.png, got %s", files[0].Name)
}
}
func TestGetFileMetadata_Found(t *testing.T) {
svc, _, db := setupFileService(t)
blob := createTestBlob(t, db, "sha256meta", "blobs/meta", "application/json", 42, 1)
file := createTestFile(t, db, "data.json", blob.SHA256, nil)
gotFile, gotBlob, err := svc.GetFileMetadata(context.Background(), file.ID)
if err != nil {
t.Fatalf("GetFileMetadata: %v", err)
}
if gotFile.ID != file.ID {
t.Errorf("expected file id %d, got %d", file.ID, gotFile.ID)
}
if gotBlob.SHA256 != blob.SHA256 {
t.Errorf("expected blob sha256 %s, got %s", blob.SHA256, gotBlob.SHA256)
}
}
func TestGetFileMetadata_NotFound(t *testing.T) {
svc, _, _ := setupFileService(t)
_, _, err := svc.GetFileMetadata(context.Background(), 9999)
if err == nil {
t.Fatal("expected error for missing file")
}
}
func TestDownloadFile_Full(t *testing.T) {
svc, ms, db := setupFileService(t)
blob := createTestBlob(t, db, "sha256dl", "blobs/dl", "text/plain", 100, 1)
file := createTestFile(t, db, "download.txt", blob.SHA256, nil)
content := []byte("hello world")
ms.getObjectFn = func(_ context.Context, _ string, _ string, opts storage.GetOptions) (io.ReadCloser, storage.ObjectInfo, error) {
if opts.Start == nil || opts.End == nil {
t.Error("expected start and end to be set")
} else if *opts.Start != 0 || *opts.End != 99 {
t.Errorf("expected range 0-99, got %d-%d", *opts.Start, *opts.End)
}
return io.NopCloser(bytes.NewReader(content)), storage.ObjectInfo{}, nil
}
reader, gotFile, gotBlob, start, end, err := svc.DownloadFile(context.Background(), file.ID, "")
if err != nil {
t.Fatalf("DownloadFile: %v", err)
}
defer reader.Close()
if gotFile.ID != file.ID {
t.Errorf("expected file id %d, got %d", file.ID, gotFile.ID)
}
if gotBlob.SHA256 != blob.SHA256 {
t.Errorf("expected blob sha256 %s, got %s", blob.SHA256, gotBlob.SHA256)
}
if start != 0 {
t.Errorf("expected start 0, got %d", start)
}
if end != 99 {
t.Errorf("expected end 99, got %d", end)
}
data, _ := io.ReadAll(reader)
if string(data) != "hello world" {
t.Errorf("expected 'hello world', got %q", string(data))
}
}
func TestDownloadFile_WithRange(t *testing.T) {
svc, ms, db := setupFileService(t)
blob := createTestBlob(t, db, "sha256range", "blobs/range", "text/plain", 1000, 1)
file := createTestFile(t, db, "range.txt", blob.SHA256, nil)
ms.getObjectFn = func(_ context.Context, _ string, _ string, opts storage.GetOptions) (io.ReadCloser, storage.ObjectInfo, error) {
if opts.Start != nil && *opts.Start != 100 {
t.Errorf("expected start 100, got %d", *opts.Start)
}
if opts.End != nil && *opts.End != 199 {
t.Errorf("expected end 199, got %d", *opts.End)
}
return io.NopCloser(strings.NewReader("partial")), storage.ObjectInfo{}, nil
}
reader, _, _, start, end, err := svc.DownloadFile(context.Background(), file.ID, "bytes=100-199")
if err != nil {
t.Fatalf("DownloadFile: %v", err)
}
defer reader.Close()
if start != 100 {
t.Errorf("expected start 100, got %d", start)
}
if end != 199 {
t.Errorf("expected end 199, got %d", end)
}
}
func TestDeleteFile_LastRef(t *testing.T) {
svc, ms, db := setupFileService(t)
blob := createTestBlob(t, db, "sha256del", "blobs/del", "text/plain", 50, 1)
file := createTestFile(t, db, "delete-me.txt", blob.SHA256, nil)
removed := false
ms.removeObjectFn = func(_ context.Context, bucket, key string, _ storage.RemoveObjectOptions) error {
if bucket != "test-bucket" {
t.Errorf("expected bucket 'test-bucket', got %q", bucket)
}
if key != "blobs/del" {
t.Errorf("expected key 'blobs/del', got %q", key)
}
removed = true
return nil
}
if err := svc.DeleteFile(context.Background(), file.ID); err != nil {
t.Fatalf("DeleteFile: %v", err)
}
if !removed {
t.Error("expected RemoveObject to be called")
}
var count int64
db.Model(&model.FileBlob{}).Where("sha256 = ?", "sha256del").Count(&count)
if count != 0 {
t.Errorf("expected blob to be hard deleted, found %d records", count)
}
var fileCount int64
db.Unscoped().Model(&model.File{}).Where("id = ?", file.ID).Count(&fileCount)
if fileCount != 1 {
t.Errorf("expected file to still exist (soft deleted), found %d", fileCount)
}
var deletedFile model.File
db.Unscoped().First(&deletedFile, file.ID)
if deletedFile.DeletedAt.Time.IsZero() {
t.Error("expected deleted_at to be set")
}
}
func TestDeleteFile_OtherRefsExist(t *testing.T) {
svc, ms, db := setupFileService(t)
blob := createTestBlob(t, db, "sha256multi", "blobs/multi", "text/plain", 50, 3)
file1 := createTestFile(t, db, "ref1.txt", blob.SHA256, nil)
createTestFile(t, db, "ref2.txt", blob.SHA256, nil)
createTestFile(t, db, "ref3.txt", blob.SHA256, nil)
removed := false
ms.removeObjectFn = func(_ context.Context, _, _ string, _ storage.RemoveObjectOptions) error {
removed = true
return nil
}
if err := svc.DeleteFile(context.Background(), file1.ID); err != nil {
t.Fatalf("DeleteFile: %v", err)
}
if removed {
t.Error("expected RemoveObject NOT to be called since other refs exist")
}
var updatedBlob model.FileBlob
db.Where("sha256 = ?", "sha256multi").First(&updatedBlob)
if updatedBlob.RefCount != 3 {
t.Errorf("expected ref_count to remain 3, got %d", updatedBlob.RefCount)
}
}
func TestDeleteFile_SoftDeleteNotAffectRefcount(t *testing.T) {
svc, _, db := setupFileService(t)
blob := createTestBlob(t, db, "sha256soft", "blobs/soft", "text/plain", 50, 2)
file1 := createTestFile(t, db, "soft1.txt", blob.SHA256, nil)
file2 := createTestFile(t, db, "soft2.txt", blob.SHA256, nil)
if err := svc.DeleteFile(context.Background(), file1.ID); err != nil {
t.Fatalf("DeleteFile: %v", err)
}
var updatedBlob model.FileBlob
db.Where("sha256 = ?", "sha256soft").First(&updatedBlob)
if updatedBlob.RefCount != 2 {
t.Errorf("expected ref_count to remain 2 (soft delete should not decrement), got %d", updatedBlob.RefCount)
}
activeCount, err := store.NewFileStore(db).CountByBlobSHA256(context.Background(), "sha256soft")
if err != nil {
t.Fatalf("CountByBlobSHA256: %v", err)
}
if activeCount != 1 {
t.Errorf("expected 1 active ref after soft delete, got %d", activeCount)
}
var allFiles []model.File
db.Unscoped().Where("blob_sha256 = ?", "sha256soft").Find(&allFiles)
if len(allFiles) != 2 {
t.Errorf("expected 2 total files (one soft deleted), got %d", len(allFiles))
}
if err := svc.DeleteFile(context.Background(), file2.ID); err != nil {
t.Fatalf("DeleteFile second: %v", err)
}
activeCount2, err := store.NewFileStore(db).CountByBlobSHA256(context.Background(), "sha256soft")
if err != nil {
t.Fatalf("CountByBlobSHA256 after second delete: %v", err)
}
if activeCount2 != 0 {
t.Errorf("expected 0 active refs after both deleted, got %d", activeCount2)
}
var finalBlob model.FileBlob
db.Where("sha256 = ?", "sha256soft").First(&finalBlob)
if finalBlob.RefCount != 1 {
t.Errorf("expected ref_count=1 (decremented once from 2), got %d", finalBlob.RefCount)
}
}
func TestDeleteFile_NotFound(t *testing.T) {
svc, _, _ := setupFileService(t)
err := svc.DeleteFile(context.Background(), 9999)
if err == nil {
t.Fatal("expected error for missing file")
}
}
func TestDownloadFile_StorageError(t *testing.T) {
svc, ms, db := setupFileService(t)
blob := createTestBlob(t, db, "sha256err", "blobs/err", "text/plain", 100, 1)
file := createTestFile(t, db, "error.txt", blob.SHA256, nil)
ms.getObjectFn = func(_ context.Context, _ string, _ string, _ storage.GetOptions) (io.ReadCloser, storage.ObjectInfo, error) {
return nil, storage.ObjectInfo{}, errors.New("storage unavailable")
}
_, _, _, _, _, err := svc.DownloadFile(context.Background(), file.ID, "")
if err == nil {
t.Fatal("expected error from storage")
}
if !strings.Contains(err.Error(), "storage unavailable") {
t.Errorf("expected storage error, got: %v", err)
}
}
func TestListFiles_WithFolderFilter(t *testing.T) {
svc, _, db := setupFileService(t)
blob := createTestBlob(t, db, "sha256folder", "blobs/folder", "text/plain", 100, 2)
folderID := int64(1)
createTestFile(t, db, "in_folder.txt", blob.SHA256, &folderID)
createTestFile(t, db, "root.txt", blob.SHA256, nil)
files, total, err := svc.ListFiles(context.Background(), &folderID, 1, 10, "")
if err != nil {
t.Fatalf("ListFiles: %v", err)
}
if total != 1 {
t.Errorf("expected total 1, got %d", total)
}
if len(files) != 1 {
t.Fatalf("expected 1 file, got %d", len(files))
}
if files[0].Name != "in_folder.txt" {
t.Errorf("expected in_folder.txt, got %s", files[0].Name)
}
rootFiles, rootTotal, err := svc.ListFiles(context.Background(), nil, 1, 10, "")
if err != nil {
t.Fatalf("ListFiles root: %v", err)
}
if rootTotal != 1 {
t.Errorf("expected root total 1, got %d", rootTotal)
}
if len(rootFiles) != 1 || rootFiles[0].Name != "root.txt" {
t.Errorf("expected root.txt in root listing")
}
}
func TestGetFileMetadata_BlobMissing(t *testing.T) {
svc, _, db := setupFileService(t)
file := createTestFile(t, db, "orphan.txt", "nonexistent_sha256", nil)
_, _, err := svc.GetFileMetadata(context.Background(), file.ID)
if err == nil {
t.Fatal("expected error when blob is missing")
}
}

View File

@@ -0,0 +1,145 @@
package service
import (
"context"
"fmt"
"io"
"os"
"path/filepath"
"gcy_hpc_server/internal/model"
"gcy_hpc_server/internal/storage"
"gcy_hpc_server/internal/store"
"go.uber.org/zap"
)
// FileStagingService batch downloads files from MinIO to a local (NFS) directory,
// deduplicating by blob SHA256 so each unique blob is fetched only once.
type FileStagingService struct {
fileStore *store.FileStore
blobStore *store.BlobStore
storage storage.ObjectStorage
bucket string
logger *zap.Logger
}
func NewFileStagingService(fileStore *store.FileStore, blobStore *store.BlobStore, st storage.ObjectStorage, bucket string, logger *zap.Logger) *FileStagingService {
return &FileStagingService{
fileStore: fileStore,
blobStore: blobStore,
storage: st,
bucket: bucket,
logger: logger,
}
}
// DownloadFilesToDir downloads the given files into destDir.
// Files sharing the same blob SHA256 are deduplicated: the blob is fetched once
// and then copied to each filename. Filenames are sanitized with filepath.Base
// to prevent path traversal.
func (s *FileStagingService) DownloadFilesToDir(ctx context.Context, fileIDs []int64, destDir string) error {
if len(fileIDs) == 0 {
return nil
}
files, err := s.fileStore.GetByIDs(ctx, fileIDs)
if err != nil {
return fmt.Errorf("fetch files: %w", err)
}
type group struct {
primary *model.File // first file — written via io.Copy from MinIO
others []*model.File // remaining files — local copy of primary
}
groups := make(map[string]*group)
for i := range files {
f := &files[i]
g, ok := groups[f.BlobSHA256]
if !ok {
groups[f.BlobSHA256] = &group{primary: f}
} else {
g.others = append(g.others, f)
}
}
sha256s := make([]string, 0, len(groups))
for sh := range groups {
sha256s = append(sha256s, sh)
}
blobs, err := s.blobStore.GetBySHA256s(ctx, sha256s)
if err != nil {
return fmt.Errorf("fetch blobs: %w", err)
}
blobMap := make(map[string]*model.FileBlob, len(blobs))
for i := range blobs {
blobMap[blobs[i].SHA256] = &blobs[i]
}
for sha256, g := range groups {
blob, ok := blobMap[sha256]
if !ok {
return fmt.Errorf("blob %s not found", sha256)
}
reader, _, err := s.storage.GetObject(ctx, s.bucket, blob.MinioKey, storage.GetOptions{})
if err != nil {
return fmt.Errorf("get object %s: %w", blob.MinioKey, err)
}
// TODO: handle filename collisions when multiple files have the same Name (low risk without user auth, revisit when auth is added)
primaryName := filepath.Base(g.primary.Name)
primaryPath := filepath.Join(destDir, primaryName)
if err := writeFile(primaryPath, reader); err != nil {
reader.Close()
os.Remove(primaryPath)
return fmt.Errorf("write file %s: %w", primaryName, err)
}
reader.Close()
for _, other := range g.others {
otherName := filepath.Base(other.Name)
otherPath := filepath.Join(destDir, otherName)
if err := copyFile(primaryPath, otherPath); err != nil {
return fmt.Errorf("copy %s to %s: %w", primaryName, otherName, err)
}
}
}
return nil
}
func writeFile(path string, reader io.Reader) error {
f, err := os.Create(path)
if err != nil {
return err
}
defer f.Close()
if _, err := io.Copy(f, reader); err != nil {
return err
}
return nil
}
func copyFile(src, dst string) error {
in, err := os.Open(src)
if err != nil {
return err
}
defer in.Close()
out, err := os.Create(dst)
if err != nil {
return err
}
defer out.Close()
if _, err := io.Copy(out, in); err != nil {
return err
}
return nil
}

View File

@@ -0,0 +1,232 @@
package service
import (
"bytes"
"context"
"fmt"
"io"
"os"
"path/filepath"
"strings"
"sync/atomic"
"testing"
"gcy_hpc_server/internal/model"
"gcy_hpc_server/internal/storage"
"gcy_hpc_server/internal/store"
"go.uber.org/zap"
"gorm.io/driver/sqlite"
"gorm.io/gorm"
)
type stagingMockStorage struct {
getObjectFn func(ctx context.Context, bucket, key string, opts storage.GetOptions) (io.ReadCloser, storage.ObjectInfo, error)
}
func (m *stagingMockStorage) GetObject(ctx context.Context, bucket, key string, opts storage.GetOptions) (io.ReadCloser, storage.ObjectInfo, error) {
if m.getObjectFn != nil {
return m.getObjectFn(ctx, bucket, key, opts)
}
return io.NopCloser(strings.NewReader("")), storage.ObjectInfo{}, nil
}
func (m *stagingMockStorage) PutObject(ctx context.Context, bucket, key string, reader io.Reader, size int64, opts storage.PutObjectOptions) (storage.UploadInfo, error) {
return storage.UploadInfo{}, nil
}
func (m *stagingMockStorage) ComposeObject(ctx context.Context, bucket, dst string, sources []string) (storage.UploadInfo, error) {
return storage.UploadInfo{}, nil
}
func (m *stagingMockStorage) AbortMultipartUpload(ctx context.Context, bucket, object, uploadID string) error {
return nil
}
func (m *stagingMockStorage) RemoveIncompleteUpload(ctx context.Context, bucket, object string) error {
return nil
}
func (m *stagingMockStorage) RemoveObject(ctx context.Context, bucket, key string, opts storage.RemoveObjectOptions) error {
return nil
}
func (m *stagingMockStorage) ListObjects(ctx context.Context, bucket, prefix string, recursive bool) ([]storage.ObjectInfo, error) {
return nil, nil
}
func (m *stagingMockStorage) RemoveObjects(ctx context.Context, bucket string, keys []string, opts storage.RemoveObjectsOptions) error {
return nil
}
func (m *stagingMockStorage) BucketExists(ctx context.Context, bucket string) (bool, error) {
return true, nil
}
func (m *stagingMockStorage) MakeBucket(ctx context.Context, bucket string, opts storage.MakeBucketOptions) error {
return nil
}
func (m *stagingMockStorage) StatObject(ctx context.Context, bucket, key string, opts storage.StatObjectOptions) (storage.ObjectInfo, error) {
return storage.ObjectInfo{}, nil
}
func setupStagingTestDB(t *testing.T) *gorm.DB {
t.Helper()
db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{})
if err != nil {
t.Fatalf("open sqlite: %v", err)
}
if err := db.AutoMigrate(&model.FileBlob{}, &model.File{}); err != nil {
t.Fatalf("migrate: %v", err)
}
return db
}
func newStagingService(t *testing.T, st storage.ObjectStorage, db *gorm.DB) *FileStagingService {
t.Helper()
return NewFileStagingService(
store.NewFileStore(db),
store.NewBlobStore(db),
st,
"test-bucket",
zap.NewNop(),
)
}
func TestFileStaging_DownloadWithDedup(t *testing.T) {
db := setupStagingTestDB(t)
sha1 := "aaa111"
sha2 := "bbb222"
db.Create(&model.FileBlob{SHA256: sha1, MinioKey: "blobs/aaa111", FileSize: 5, MimeType: "text/plain", RefCount: 2})
db.Create(&model.FileBlob{SHA256: sha2, MinioKey: "blobs/bbb222", FileSize: 3, MimeType: "text/plain", RefCount: 1})
db.Create(&model.File{Name: "file1.txt", BlobSHA256: sha1})
db.Create(&model.File{Name: "file2.txt", BlobSHA256: sha1})
db.Create(&model.File{Name: "file3.txt", BlobSHA256: sha2})
var files []model.File
db.Find(&files)
if len(files) < 3 {
t.Fatalf("need 3 files, got %d", len(files))
}
var getObjCalls int32
st := &stagingMockStorage{}
st.getObjectFn = func(ctx context.Context, bucket, key string, opts storage.GetOptions) (io.ReadCloser, storage.ObjectInfo, error) {
atomic.AddInt32(&getObjCalls, 1)
var content string
switch key {
case "blobs/aaa111":
content = "content-a"
case "blobs/bbb222":
content = "content-b"
default:
return nil, storage.ObjectInfo{}, fmt.Errorf("unexpected key %s", key)
}
return io.NopCloser(bytes.NewReader([]byte(content))), storage.ObjectInfo{Key: key}, nil
}
destDir := t.TempDir()
svc := newStagingService(t, st, db)
err := svc.DownloadFilesToDir(context.Background(), []int64{files[0].ID, files[1].ID, files[2].ID}, destDir)
if err != nil {
t.Fatalf("DownloadFilesToDir: %v", err)
}
if calls := atomic.LoadInt32(&getObjCalls); calls != 2 {
t.Errorf("GetObject called %d times, want 2", calls)
}
expected := map[string]string{
"file1.txt": "content-a",
"file2.txt": "content-a",
"file3.txt": "content-b",
}
for name, want := range expected {
p := filepath.Join(destDir, name)
data, err := os.ReadFile(p)
if err != nil {
t.Errorf("read %s: %v", name, err)
continue
}
if string(data) != want {
t.Errorf("%s content = %q, want %q", name, data, want)
}
}
}
func TestFileStaging_PathTraversal(t *testing.T) {
db := setupStagingTestDB(t)
sha := "traversal123"
db.Create(&model.FileBlob{SHA256: sha, MinioKey: "blobs/traversal", FileSize: 4, MimeType: "text/plain", RefCount: 1})
db.Create(&model.File{Name: "../../../etc/passwd", BlobSHA256: sha})
var file model.File
db.First(&file)
st := &stagingMockStorage{}
st.getObjectFn = func(ctx context.Context, bucket, key string, opts storage.GetOptions) (io.ReadCloser, storage.ObjectInfo, error) {
return io.NopCloser(bytes.NewReader([]byte("safe"))), storage.ObjectInfo{Key: key}, nil
}
destDir := t.TempDir()
svc := newStagingService(t, st, db)
err := svc.DownloadFilesToDir(context.Background(), []int64{file.ID}, destDir)
if err != nil {
t.Fatalf("DownloadFilesToDir: %v", err)
}
sanitized := filepath.Join(destDir, "passwd")
data, err := os.ReadFile(sanitized)
if err != nil {
t.Fatalf("read sanitized file: %v", err)
}
if string(data) != "safe" {
t.Errorf("content = %q, want %q", data, "safe")
}
entries, err := os.ReadDir(destDir)
if err != nil {
t.Fatalf("readdir: %v", err)
}
for _, e := range entries {
if e.Name() != "passwd" {
t.Errorf("unexpected file in destDir: %s", e.Name())
}
}
}
func TestFileStaging_EmptyList(t *testing.T) {
db := setupStagingTestDB(t)
st := &stagingMockStorage{}
svc := newStagingService(t, st, db)
err := svc.DownloadFilesToDir(context.Background(), []int64{}, t.TempDir())
if err != nil {
t.Errorf("expected nil for empty list, got %v", err)
}
}
func TestFileStaging_GetObjectFails(t *testing.T) {
db := setupStagingTestDB(t)
sha := "fail123"
db.Create(&model.FileBlob{SHA256: sha, MinioKey: "blobs/fail", FileSize: 5, MimeType: "text/plain", RefCount: 1})
db.Create(&model.File{Name: "willfail.txt", BlobSHA256: sha})
var file model.File
db.First(&file)
st := &stagingMockStorage{}
st.getObjectFn = func(ctx context.Context, bucket, key string, opts storage.GetOptions) (io.ReadCloser, storage.ObjectInfo, error) {
return nil, storage.ObjectInfo{}, fmt.Errorf("minio down")
}
destDir := t.TempDir()
svc := newStagingService(t, st, db)
err := svc.DownloadFilesToDir(context.Background(), []int64{file.ID}, destDir)
if err == nil {
t.Fatal("expected error when GetObject fails")
}
if !strings.Contains(err.Error(), "minio down") {
t.Errorf("error = %q, want 'minio down'", err.Error())
}
}

View File

@@ -0,0 +1,142 @@
package service
import (
"context"
"fmt"
"gcy_hpc_server/internal/model"
"gcy_hpc_server/internal/store"
"go.uber.org/zap"
)
// FolderService provides CRUD operations for folders with path validation
// and directory tree management.
type FolderService struct {
folderStore *store.FolderStore
fileStore *store.FileStore
logger *zap.Logger
}
// NewFolderService creates a new FolderService.
func NewFolderService(folderStore *store.FolderStore, fileStore *store.FileStore, logger *zap.Logger) *FolderService {
return &FolderService{
folderStore: folderStore,
fileStore: fileStore,
logger: logger,
}
}
// CreateFolder validates the name, computes a materialized path, checks for
// duplicates, and persists the folder.
func (s *FolderService) CreateFolder(ctx context.Context, name string, parentID *int64) (*model.FolderResponse, error) {
if err := model.ValidateFolderName(name); err != nil {
return nil, fmt.Errorf("invalid folder name: %w", err)
}
var path string
if parentID == nil {
path = "/" + name + "/"
} else {
parent, err := s.folderStore.GetByID(ctx, *parentID)
if err != nil {
return nil, fmt.Errorf("get parent folder: %w", err)
}
if parent == nil {
return nil, fmt.Errorf("parent folder %d not found", *parentID)
}
path = parent.Path + name + "/"
}
existing, err := s.folderStore.GetByPath(ctx, path)
if err != nil {
return nil, fmt.Errorf("check duplicate path: %w", err)
}
if existing != nil {
return nil, fmt.Errorf("folder with path %q already exists", path)
}
folder := &model.Folder{
Name: name,
ParentID: parentID,
Path: path,
}
if err := s.folderStore.Create(ctx, folder); err != nil {
return nil, fmt.Errorf("create folder: %w", err)
}
return s.toFolderResponse(ctx, folder)
}
// GetFolder retrieves a folder by ID with file and subfolder counts.
func (s *FolderService) GetFolder(ctx context.Context, id int64) (*model.FolderResponse, error) {
folder, err := s.folderStore.GetByID(ctx, id)
if err != nil {
return nil, fmt.Errorf("get folder: %w", err)
}
if folder == nil {
return nil, fmt.Errorf("folder %d not found", id)
}
return s.toFolderResponse(ctx, folder)
}
// ListFolders returns all direct children of the given parent folder (or root
// if parentID is nil).
func (s *FolderService) ListFolders(ctx context.Context, parentID *int64) ([]model.FolderResponse, error) {
folders, err := s.folderStore.ListByParentID(ctx, parentID)
if err != nil {
return nil, fmt.Errorf("list folders: %w", err)
}
result := make([]model.FolderResponse, 0, len(folders))
for i := range folders {
resp, err := s.toFolderResponse(ctx, &folders[i])
if err != nil {
return nil, err
}
result = append(result, *resp)
}
return result, nil
}
// DeleteFolder soft-deletes a folder only if it has no children (sub-folders
// or files).
func (s *FolderService) DeleteFolder(ctx context.Context, id int64) error {
hasChildren, err := s.folderStore.HasChildren(ctx, id)
if err != nil {
return fmt.Errorf("check children: %w", err)
}
if hasChildren {
return fmt.Errorf("folder is not empty")
}
if err := s.folderStore.Delete(ctx, id); err != nil {
return fmt.Errorf("delete folder: %w", err)
}
return nil
}
// toFolderResponse converts a Folder model into a FolderResponse DTO with
// computed file and subfolder counts.
func (s *FolderService) toFolderResponse(ctx context.Context, f *model.Folder) (*model.FolderResponse, error) {
subFolders, err := s.folderStore.ListByParentID(ctx, &f.ID)
if err != nil {
return nil, fmt.Errorf("count subfolders: %w", err)
}
_, fileCount, err := s.fileStore.List(ctx, &f.ID, 1, 1) // page size 1: only the total count is needed
if err != nil {
return nil, fmt.Errorf("count files: %w", err)
}
return &model.FolderResponse{
ID: f.ID,
Name: f.Name,
ParentID: f.ParentID,
Path: f.Path,
FileCount: fileCount,
SubFolderCount: int64(len(subFolders)),
CreatedAt: f.CreatedAt,
}, nil
}

View File

@@ -0,0 +1,230 @@
package service
import (
"context"
"strings"
"testing"
"gcy_hpc_server/internal/model"
"gcy_hpc_server/internal/store"
"go.uber.org/zap"
gormlogger "gorm.io/gorm/logger"
"gorm.io/driver/sqlite"
"gorm.io/gorm"
)
func setupFolderService(t *testing.T) *FolderService {
t.Helper()
db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{
Logger: gormlogger.Default.LogMode(gormlogger.Silent),
})
if err != nil {
t.Fatalf("open sqlite: %v", err)
}
if err := db.AutoMigrate(&model.Folder{}, &model.File{}); err != nil {
t.Fatalf("auto migrate: %v", err)
}
return NewFolderService(
store.NewFolderStore(db),
store.NewFileStore(db),
zap.NewNop(),
)
}
func TestCreateFolder_ValidName(t *testing.T) {
svc := setupFolderService(t)
resp, err := svc.CreateFolder(context.Background(), "datasets", nil)
if err != nil {
t.Fatalf("CreateFolder() error = %v", err)
}
if resp.Name != "datasets" {
t.Errorf("Name = %q, want %q", resp.Name, "datasets")
}
if resp.Path != "/datasets/" {
t.Errorf("Path = %q, want %q", resp.Path, "/datasets/")
}
if resp.ParentID != nil {
t.Errorf("ParentID should be nil for root folder, got %d", *resp.ParentID)
}
}
func TestCreateFolder_SubFolder(t *testing.T) {
svc := setupFolderService(t)
parent, err := svc.CreateFolder(context.Background(), "datasets", nil)
if err != nil {
t.Fatalf("create parent: %v", err)
}
child, err := svc.CreateFolder(context.Background(), "images", &parent.ID)
if err != nil {
t.Fatalf("CreateFolder() error = %v", err)
}
if child.Path != "/datasets/images/" {
t.Errorf("Path = %q, want %q", child.Path, "/datasets/images/")
}
if child.ParentID == nil || *child.ParentID != parent.ID {
t.Errorf("ParentID = %v, want %d", child.ParentID, parent.ID)
}
}
func TestCreateFolder_RejectPathTraversal(t *testing.T) {
svc := setupFolderService(t)
for _, name := range []string{"..", "../etc", "/absolute", "a/b"} {
_, err := svc.CreateFolder(context.Background(), name, nil)
if err == nil {
t.Errorf("expected error for folder name %q, got nil", name)
}
}
}
func TestCreateFolder_DuplicatePath(t *testing.T) {
svc := setupFolderService(t)
_, err := svc.CreateFolder(context.Background(), "datasets", nil)
if err != nil {
t.Fatalf("first create: %v", err)
}
_, err = svc.CreateFolder(context.Background(), "datasets", nil)
if err == nil {
t.Fatal("expected error for duplicate folder name")
}
if !strings.Contains(err.Error(), "already exists") {
t.Errorf("error should mention 'already exists', got: %v", err)
}
}
func TestCreateFolder_ParentNotFound(t *testing.T) {
svc := setupFolderService(t)
badID := int64(99999)
_, err := svc.CreateFolder(context.Background(), "orphan", &badID)
if err == nil {
t.Fatal("expected error for non-existent parent")
}
if !strings.Contains(err.Error(), "not found") {
t.Errorf("error should mention 'not found', got: %v", err)
}
}
func TestGetFolder(t *testing.T) {
svc := setupFolderService(t)
created, err := svc.CreateFolder(context.Background(), "datasets", nil)
if err != nil {
t.Fatalf("CreateFolder() error = %v", err)
}
resp, err := svc.GetFolder(context.Background(), created.ID)
if err != nil {
t.Fatalf("GetFolder() error = %v", err)
}
if resp.ID != created.ID {
t.Errorf("ID = %d, want %d", resp.ID, created.ID)
}
if resp.Name != "datasets" {
t.Errorf("Name = %q, want %q", resp.Name, "datasets")
}
if resp.FileCount != 0 {
t.Errorf("FileCount = %d, want 0", resp.FileCount)
}
if resp.SubFolderCount != 0 {
t.Errorf("SubFolderCount = %d, want 0", resp.SubFolderCount)
}
}
func TestGetFolder_NotFound(t *testing.T) {
svc := setupFolderService(t)
_, err := svc.GetFolder(context.Background(), 99999)
if err == nil {
t.Fatal("expected error for non-existent folder")
}
if !strings.Contains(err.Error(), "not found") {
t.Errorf("error should mention 'not found', got: %v", err)
}
}
func TestListFolders(t *testing.T) {
svc := setupFolderService(t)
parent, err := svc.CreateFolder(context.Background(), "root", nil)
if err != nil {
t.Fatalf("create root: %v", err)
}
_, err = svc.CreateFolder(context.Background(), "child1", &parent.ID)
if err != nil {
t.Fatalf("create child1: %v", err)
}
_, err = svc.CreateFolder(context.Background(), "child2", &parent.ID)
if err != nil {
t.Fatalf("create child2: %v", err)
}
list, err := svc.ListFolders(context.Background(), &parent.ID)
if err != nil {
t.Fatalf("ListFolders() error = %v", err)
}
if len(list) != 2 {
t.Fatalf("len(list) = %d, want 2", len(list))
}
names := make(map[string]bool)
for _, f := range list {
names[f.Name] = true
}
if !names["child1"] || !names["child2"] {
t.Errorf("expected child1 and child2, got %v", names)
}
}
func TestListFolders_Root(t *testing.T) {
svc := setupFolderService(t)
_, err := svc.CreateFolder(context.Background(), "alpha", nil)
if err != nil {
t.Fatalf("create alpha: %v", err)
}
_, err = svc.CreateFolder(context.Background(), "beta", nil)
if err != nil {
t.Fatalf("create beta: %v", err)
}
list, err := svc.ListFolders(context.Background(), nil)
if err != nil {
t.Fatalf("ListFolders() error = %v", err)
}
if len(list) != 2 {
t.Errorf("len(list) = %d, want 2", len(list))
}
}
func TestDeleteFolder_Success(t *testing.T) {
svc := setupFolderService(t)
created, err := svc.CreateFolder(context.Background(), "temp", nil)
if err != nil {
t.Fatalf("CreateFolder() error = %v", err)
}
if err := svc.DeleteFolder(context.Background(), created.ID); err != nil {
t.Fatalf("DeleteFolder() error = %v", err)
}
_, err = svc.GetFolder(context.Background(), created.ID)
if err == nil {
t.Error("expected error after deletion")
}
}
func TestDeleteFolder_NonEmpty(t *testing.T) {
svc := setupFolderService(t)
parent, err := svc.CreateFolder(context.Background(), "haschild", nil)
if err != nil {
t.Fatalf("create parent: %v", err)
}
_, err = svc.CreateFolder(context.Background(), "child", &parent.ID)
if err != nil {
t.Fatalf("create child: %v", err)
}
err = svc.DeleteFolder(context.Background(), parent.ID)
if err == nil {
t.Fatal("expected error when deleting non-empty folder")
}
if !strings.Contains(err.Error(), "not empty") {
t.Errorf("error should mention 'not empty', got: %v", err)
}
}

View File

@@ -4,6 +4,7 @@ import (
"context"
"fmt"
"strconv"
"time"
"gcy_hpc_server/internal/model"
"gcy_hpc_server/internal/slurm"
@@ -31,6 +32,9 @@ func (s *JobService) SubmitJob(ctx context.Context, req *model.SubmitJobRequest)
Qos: strToPtrOrNil(req.QOS),
Name: strToPtrOrNil(req.JobName),
}
if req.WorkDir != "" {
jobDesc.CurrentWorkingDirectory = &req.WorkDir
}
if req.CPUs > 0 {
jobDesc.MinimumCpus = slurm.Ptr(req.CPUs)
}
@@ -40,17 +44,41 @@ func (s *JobService) SubmitJob(ctx context.Context, req *model.SubmitJobRequest)
}
}
jobDesc.Environment = slurm.StringArray{
"PATH=/usr/local/bin:/usr/bin:/bin",
"HOME=/root",
}
submitReq := &slurm.JobSubmitReq{
Script: &script,
Job: jobDesc,
}
s.logger.Debug("slurm API request",
zap.String("operation", "SubmitJob"),
zap.Any("body", submitReq),
)
start := time.Now()
result, _, err := s.client.Jobs.SubmitJob(ctx, submitReq)
took := time.Since(start)
if err != nil {
s.logger.Debug("slurm API error response",
zap.String("operation", "SubmitJob"),
zap.Duration("took", took),
zap.Error(err),
)
s.logger.Error("failed to submit job", zap.Error(err), zap.String("operation", "submit"))
return nil, fmt.Errorf("submit job: %w", err)
}
s.logger.Debug("slurm API response",
zap.String("operation", "SubmitJob"),
zap.Duration("took", took),
zap.Any("body", result),
)
resp := &model.JobResponse{}
if result.Result != nil && result.Result.JobID != nil {
resp.JobID = *result.Result.JobID
@@ -62,44 +90,173 @@ func (s *JobService) SubmitJob(ctx context.Context, req *model.SubmitJobRequest)
return resp, nil
}
// GetJobs lists all current jobs from Slurm.
func (s *JobService) GetJobs(ctx context.Context) ([]model.JobResponse, error) {
// GetJobs lists all current jobs from Slurm with in-memory pagination.
func (s *JobService) GetJobs(ctx context.Context, query *model.JobListQuery) (*model.JobListResponse, error) {
s.logger.Debug("slurm API request",
zap.String("operation", "GetJobs"),
)
start := time.Now()
result, _, err := s.client.Jobs.GetJobs(ctx, nil)
took := time.Since(start)
if err != nil {
s.logger.Debug("slurm API error response",
zap.String("operation", "GetJobs"),
zap.Duration("took", took),
zap.Error(err),
)
s.logger.Error("failed to get jobs", zap.Error(err), zap.String("operation", "get_jobs"))
return nil, fmt.Errorf("get jobs: %w", err)
}
jobs := make([]model.JobResponse, 0, len(result.Jobs))
s.logger.Debug("slurm API response",
zap.String("operation", "GetJobs"),
zap.Duration("took", took),
zap.Int("job_count", len(result.Jobs)),
zap.Any("body", result),
)
allJobs := make([]model.JobResponse, 0, len(result.Jobs))
for i := range result.Jobs {
jobs = append(jobs, mapJobInfo(&result.Jobs[i]))
allJobs = append(allJobs, mapJobInfo(&result.Jobs[i]))
}
return jobs, nil
total := len(allJobs)
page := query.Page
pageSize := query.PageSize
if page < 1 {
page = 1
}
if pageSize < 1 {
pageSize = 20
}
startIdx := (page - 1) * pageSize
end := startIdx + pageSize
if startIdx > total {
startIdx = total
}
if end > total {
end = total
}
return &model.JobListResponse{
Jobs: allJobs[startIdx:end],
Total: total,
Page: page,
PageSize: pageSize,
}, nil
}
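The in-memory pagination above defaults `page` to 1 and `pageSize` to 20, then clamps both slice bounds against `total`, so an out-of-range page yields an empty window rather than an out-of-bounds slice. A minimal standalone sketch of that arithmetic (`pageBounds` is an illustrative name, not part of the service):

```go
package main

import "fmt"

// pageBounds clamps (page, pageSize) against total, mirroring the
// clamp logic in GetJobs: defaults applied first, then both bounds
// capped at total so slicing allJobs[start:end] is always safe.
func pageBounds(total, page, pageSize int) (start, end int) {
	if page < 1 {
		page = 1
	}
	if pageSize < 1 {
		pageSize = 20
	}
	start = (page - 1) * pageSize
	end = start + pageSize
	if start > total {
		start = total
	}
	if end > total {
		end = total
	}
	return start, end
}

func main() {
	s, e := pageBounds(45, 3, 20)
	fmt.Println(s, e) // third page of 45 items: window [40, 45)
}
```

Clamping `start` (not just `end`) is what keeps a request for page 99 from panicking: both bounds collapse to `total` and the response carries an empty job slice with the true `Total`.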
// GetJob retrieves a single job by ID.
// GetJob retrieves a single job by ID. If the job is not found in the active
// queue (404 or empty result), it falls back to querying SlurmDBD history.
func (s *JobService) GetJob(ctx context.Context, jobID string) (*model.JobResponse, error) {
s.logger.Debug("slurm API request",
zap.String("operation", "GetJob"),
zap.String("job_id", jobID),
)
start := time.Now()
result, _, err := s.client.Jobs.GetJob(ctx, jobID, nil)
took := time.Since(start)
if err != nil {
if slurm.IsNotFound(err) {
s.logger.Debug("job not in active queue, querying history",
zap.String("job_id", jobID),
)
return s.getJobFromHistory(ctx, jobID)
}
s.logger.Debug("slurm API error response",
zap.String("operation", "GetJob"),
zap.String("job_id", jobID),
zap.Duration("took", took),
zap.Error(err),
)
s.logger.Error("failed to get job", zap.Error(err), zap.String("job_id", jobID), zap.String("operation", "get_job"))
return nil, fmt.Errorf("get job %s: %w", jobID, err)
}
s.logger.Debug("slurm API response",
zap.String("operation", "GetJob"),
zap.String("job_id", jobID),
zap.Duration("took", took),
zap.Any("body", result),
)
if len(result.Jobs) == 0 {
return nil, nil
s.logger.Debug("empty jobs response, querying history",
zap.String("job_id", jobID),
)
return s.getJobFromHistory(ctx, jobID)
}
resp := mapJobInfo(&result.Jobs[0])
return &resp, nil
}
func (s *JobService) getJobFromHistory(ctx context.Context, jobID string) (*model.JobResponse, error) {
start := time.Now()
result, _, err := s.client.SlurmdbJobs.GetJob(ctx, jobID)
took := time.Since(start)
if err != nil {
s.logger.Debug("slurmdb API error response",
zap.String("operation", "getJobFromHistory"),
zap.String("job_id", jobID),
zap.Duration("took", took),
zap.Error(err),
)
if slurm.IsNotFound(err) {
return nil, nil
}
return nil, fmt.Errorf("get job history %s: %w", jobID, err)
}
s.logger.Debug("slurmdb API response",
zap.String("operation", "getJobFromHistory"),
zap.String("job_id", jobID),
zap.Duration("took", took),
zap.Any("body", result),
)
if len(result.Jobs) == 0 {
return nil, nil
}
resp := mapSlurmdbJob(&result.Jobs[0])
return &resp, nil
}
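GetJob treats both a 404 from the active queue and an empty job list as "try SlurmDBD history", and a miss in history as a clean `(nil, nil)`. A schematic of that two-tier lookup, with stubbed lookup functions standing in for the real Slurm clients (names here are illustrative):

```go
package main

import (
	"errors"
	"fmt"
)

var errNotFound = errors.New("not found")

// lookup is a stand-in for a Slurm query: it may fail, or return no rows.
type lookup func(id string) ([]string, error)

// getWithFallback mirrors GetJob's shape: consult the active queue first,
// fall back to history on not-found or an empty result, and return ""
// (the analogue of a nil response) when the job is unknown everywhere.
func getWithFallback(active, history lookup, id string) (string, error) {
	rows, err := active(id)
	if err != nil {
		if !errors.Is(err, errNotFound) {
			return "", fmt.Errorf("get job %s: %w", id, err)
		}
		// not found in the active queue: fall through to history
	} else if len(rows) > 0 {
		return rows[0], nil
	}
	rows, err = history(id)
	if err != nil {
		if errors.Is(err, errNotFound) {
			return "", nil
		}
		return "", fmt.Errorf("get job history %s: %w", id, err)
	}
	if len(rows) == 0 {
		return "", nil
	}
	return rows[0], nil
}

func main() {
	activeMiss := func(string) ([]string, error) { return nil, errNotFound }
	histHit := func(string) ([]string, error) { return []string{"COMPLETED"}, nil }
	got, _ := getWithFallback(activeMiss, histHit, "198")
	fmt.Println(got) // COMPLETED
}
```

The key asymmetry, matching the code above: a non-404 error from the active queue aborts immediately, while a 404 anywhere is never surfaced as an error.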
// CancelJob cancels a job by ID.
func (s *JobService) CancelJob(ctx context.Context, jobID string) error {
_, _, err := s.client.Jobs.DeleteJob(ctx, jobID, nil)
s.logger.Debug("slurm API request",
zap.String("operation", "CancelJob"),
zap.String("job_id", jobID),
)
start := time.Now()
result, _, err := s.client.Jobs.DeleteJob(ctx, jobID, nil)
took := time.Since(start)
if err != nil {
s.logger.Debug("slurm API error response",
zap.String("operation", "CancelJob"),
zap.String("job_id", jobID),
zap.Duration("took", took),
zap.Error(err),
)
s.logger.Error("failed to cancel job", zap.Error(err), zap.String("job_id", jobID), zap.String("operation", "cancel"))
return fmt.Errorf("cancel job %s: %w", jobID, err)
}
s.logger.Debug("slurm API response",
zap.String("operation", "CancelJob"),
zap.String("job_id", jobID),
zap.Duration("took", took),
zap.Any("body", result),
)
s.logger.Info("job cancelled", zap.String("job_id", jobID))
return nil
}
@@ -128,13 +285,60 @@ func (s *JobService) GetJobHistory(ctx context.Context, query *model.JobHistoryQ
if query.EndTime != "" {
opts.EndTime = strToPtr(query.EndTime)
}
if query.SubmitTime != "" {
opts.SubmitTime = strToPtr(query.SubmitTime)
}
if query.Cluster != "" {
opts.Cluster = strToPtr(query.Cluster)
}
if query.Qos != "" {
opts.Qos = strToPtr(query.Qos)
}
if query.Constraints != "" {
opts.Constraints = strToPtr(query.Constraints)
}
if query.ExitCode != "" {
opts.ExitCode = strToPtr(query.ExitCode)
}
if query.Node != "" {
opts.Node = strToPtr(query.Node)
}
if query.Reservation != "" {
opts.Reservation = strToPtr(query.Reservation)
}
if query.Groups != "" {
opts.Groups = strToPtr(query.Groups)
}
if query.Wckey != "" {
opts.Wckey = strToPtr(query.Wckey)
}
s.logger.Debug("slurm API request",
zap.String("operation", "GetJobHistory"),
zap.Any("body", opts),
)
start := time.Now()
result, _, err := s.client.SlurmdbJobs.GetJobs(ctx, opts)
took := time.Since(start)
if err != nil {
s.logger.Debug("slurm API error response",
zap.String("operation", "GetJobHistory"),
zap.Duration("took", took),
zap.Error(err),
)
s.logger.Error("failed to get job history", zap.Error(err), zap.String("operation", "get_job_history"))
return nil, fmt.Errorf("get job history: %w", err)
}
s.logger.Debug("slurm API response",
zap.String("operation", "GetJobHistory"),
zap.Duration("took", took),
zap.Int("job_count", len(result.Jobs)),
zap.Any("body", result),
)
allJobs := make([]model.JobResponse, 0, len(result.Jobs))
for i := range result.Jobs {
allJobs = append(allJobs, mapSlurmdbJob(&result.Jobs[i]))
@@ -150,17 +354,17 @@ func (s *JobService) GetJobHistory(ctx context.Context, query *model.JobHistoryQ
pageSize = 20
}
start := (page - 1) * pageSize
end := start + pageSize
if start > total {
start = total
startIdx := (page - 1) * pageSize
end := startIdx + pageSize
if startIdx > total {
startIdx = total
}
if end > total {
end = total
}
return &model.JobListResponse{
Jobs: allJobs[start:end],
Jobs: allJobs[startIdx:end],
Total: total,
Page: page,
PageSize: pageSize,
@@ -181,6 +385,14 @@ func strToPtrOrNil(s string) *string {
return &s
}
func mapUint32NoValToInt32(v *slurm.Uint32NoVal) *int32 {
if v != nil && v.Number != nil {
n := int32(*v.Number)
return &n
}
return nil
}
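Slurm's `Uint32NoVal` carries two levels of optionality: the wrapper pointer may be nil, and its `Number` may be unset. The helper above collapses both into a single nullable `int32`. A standalone illustration, with the SDK type reduced to the one field used here:

```go
package main

import "fmt"

// uint32NoVal mimics the shape of slurm.Uint32NoVal for this sketch:
// the wrapper itself may be absent, and Number may be unset.
type uint32NoVal struct {
	Number *uint32
}

// toInt32 mirrors mapUint32NoValToInt32: only a present wrapper with a
// present number yields a value; every other combination maps to nil.
func toInt32(v *uint32NoVal) *int32 {
	if v != nil && v.Number != nil {
		n := int32(*v.Number)
		return &n
	}
	return nil
}

func main() {
	n := uint32(8)
	fmt.Println(*toInt32(&uint32NoVal{Number: &n})) // 8
	fmt.Println(toInt32(&uint32NoVal{}) == nil)     // true
	fmt.Println(toInt32(nil) == nil)                // true
}
```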
// mapJobInfo maps SDK JobInfo to API JobResponse.
func mapJobInfo(ji *slurm.JobInfo) model.JobResponse {
resp := model.JobResponse{}
@@ -194,6 +406,17 @@ func mapJobInfo(ji *slurm.JobInfo) model.JobResponse {
if ji.Partition != nil {
resp.Partition = *ji.Partition
}
resp.Account = derefStr(ji.Account)
resp.User = derefStr(ji.UserName)
resp.Cluster = derefStr(ji.Cluster)
resp.QOS = derefStr(ji.Qos)
resp.Priority = mapUint32NoValToInt32(ji.Priority)
resp.TimeLimit = uint32NoValString(ji.TimeLimit)
resp.StateReason = derefStr(ji.StateReason)
resp.Cpus = mapUint32NoValToInt32(ji.Cpus)
resp.Tasks = mapUint32NoValToInt32(ji.Tasks)
resp.NodeCount = mapUint32NoValToInt32(ji.NodeCount)
resp.BatchHost = derefStr(ji.BatchHost)
if ji.SubmitTime != nil && ji.SubmitTime.Number != nil {
resp.SubmitTime = ji.SubmitTime.Number
}
@@ -210,6 +433,13 @@ func mapJobInfo(ji *slurm.JobInfo) model.JobResponse {
if ji.Nodes != nil {
resp.Nodes = *ji.Nodes
}
resp.StdOut = derefStr(ji.StandardOutput)
resp.StdErr = derefStr(ji.StandardError)
resp.StdIn = derefStr(ji.StandardInput)
resp.WorkDir = derefStr(ji.CurrentWorkingDirectory)
resp.Command = derefStr(ji.Command)
resp.ArrayJobID = mapUint32NoValToInt32(ji.ArrayJobID)
resp.ArrayTaskID = mapUint32NoValToInt32(ji.ArrayTaskID)
return resp
}
@@ -224,11 +454,20 @@ func mapSlurmdbJob(j *slurm.Job) model.JobResponse {
}
if j.State != nil {
resp.State = j.State.Current
resp.StateReason = derefStr(j.State.Reason)
}
if j.Partition != nil {
resp.Partition = *j.Partition
}
resp.Account = derefStr(j.Account)
if j.User != nil {
resp.User = *j.User
}
resp.Cluster = derefStr(j.Cluster)
resp.QOS = derefStr(j.Qos)
resp.Priority = mapUint32NoValToInt32(j.Priority)
if j.Time != nil {
resp.TimeLimit = uint32NoValString(j.Time.Limit)
if j.Time.Submission != nil {
resp.SubmitTime = j.Time.Submission
}
@@ -239,8 +478,19 @@ func mapSlurmdbJob(j *slurm.Job) model.JobResponse {
resp.EndTime = j.Time.End
}
}
if j.ExitCode != nil && j.ExitCode.ReturnCode != nil && j.ExitCode.ReturnCode.Number != nil {
code := int32(*j.ExitCode.ReturnCode.Number)
resp.ExitCode = &code
}
if j.Nodes != nil {
resp.Nodes = *j.Nodes
}
if j.Required != nil {
resp.Cpus = j.Required.CPUs
}
if j.AllocationNodes != nil {
resp.NodeCount = j.AllocationNodes
}
resp.WorkDir = derefStr(j.WorkingDirectory)
return resp
}

View File

@@ -5,6 +5,7 @@ import (
"encoding/json"
"net/http"
"net/http/httptest"
"strings"
"testing"
"gcy_hpc_server/internal/model"
@@ -148,14 +149,17 @@ func TestGetJobs(t *testing.T) {
defer cleanup()
svc := NewJobService(client, zap.NewNop())
jobs, err := svc.GetJobs(context.Background())
result, err := svc.GetJobs(context.Background(), &model.JobListQuery{Page: 1, PageSize: 20})
if err != nil {
t.Fatalf("GetJobs: %v", err)
}
if len(jobs) != 1 {
t.Fatalf("expected 1 job, got %d", len(jobs))
if result.Total != 1 {
t.Fatalf("expected total 1, got %d", result.Total)
}
j := jobs[0]
if len(result.Jobs) != 1 {
t.Fatalf("expected 1 job, got %d", len(result.Jobs))
}
j := result.Jobs[0]
if j.JobID != 100 {
t.Errorf("expected JobID 100, got %d", j.JobID)
}
@@ -174,6 +178,12 @@ func TestGetJobs(t *testing.T) {
if j.Nodes != "node01" {
t.Errorf("expected Nodes node01, got %s", j.Nodes)
}
if result.Page != 1 {
t.Errorf("expected Page 1, got %d", result.Page)
}
if result.PageSize != 20 {
t.Errorf("expected PageSize 20, got %d", result.PageSize)
}
}
func TestGetJob(t *testing.T) {
@@ -506,13 +516,13 @@ func TestJobService_SubmitJob_SuccessLog(t *testing.T) {
}
entries := recorded.All()
if len(entries) != 1 {
t.Fatalf("expected 1 log entry, got %d", len(entries))
if len(entries) != 3 {
t.Fatalf("expected 3 log entries, got %d", len(entries))
}
if entries[0].Level != zapcore.InfoLevel {
t.Errorf("expected InfoLevel, got %v", entries[0].Level)
if entries[2].Level != zapcore.InfoLevel {
t.Errorf("expected InfoLevel, got %v", entries[2].Level)
}
fields := entries[0].ContextMap()
fields := entries[2].ContextMap()
if fields["job_name"] != "log-test-job" {
t.Errorf("expected job_name=log-test-job, got %v", fields["job_name"])
}
@@ -539,13 +549,13 @@ func TestJobService_SubmitJob_ErrorLog(t *testing.T) {
}
entries := recorded.All()
if len(entries) != 1 {
t.Fatalf("expected 1 log entry, got %d", len(entries))
if len(entries) != 3 {
t.Fatalf("expected 3 log entries, got %d", len(entries))
}
if entries[0].Level != zapcore.ErrorLevel {
t.Errorf("expected ErrorLevel, got %v", entries[0].Level)
if entries[2].Level != zapcore.ErrorLevel {
t.Errorf("expected ErrorLevel, got %v", entries[2].Level)
}
fields := entries[0].ContextMap()
fields := entries[2].ContextMap()
if fields["operation"] != "submit" {
t.Errorf("expected operation=submit, got %v", fields["operation"])
}
@@ -568,13 +578,13 @@ func TestJobService_CancelJob_SuccessLog(t *testing.T) {
}
entries := recorded.All()
if len(entries) != 1 {
t.Fatalf("expected 1 log entry, got %d", len(entries))
if len(entries) != 3 {
t.Fatalf("expected 3 log entries, got %d", len(entries))
}
if entries[0].Level != zapcore.InfoLevel {
t.Errorf("expected InfoLevel, got %v", entries[0].Level)
if entries[2].Level != zapcore.InfoLevel {
t.Errorf("expected InfoLevel, got %v", entries[2].Level)
}
fields := entries[0].ContextMap()
fields := entries[2].ContextMap()
if fields["job_id"] != "555" {
t.Errorf("expected job_id=555, got %v", fields["job_id"])
}
@@ -594,13 +604,13 @@ func TestJobService_CancelJob_ErrorLog(t *testing.T) {
}
entries := recorded.All()
if len(entries) != 1 {
t.Fatalf("expected 1 log entry, got %d", len(entries))
if len(entries) != 3 {
t.Fatalf("expected 3 log entries, got %d", len(entries))
}
if entries[0].Level != zapcore.ErrorLevel {
t.Errorf("expected ErrorLevel, got %v", entries[0].Level)
if entries[2].Level != zapcore.ErrorLevel {
t.Errorf("expected ErrorLevel, got %v", entries[2].Level)
}
fields := entries[0].ContextMap()
fields := entries[2].ContextMap()
if fields["operation"] != "cancel" {
t.Errorf("expected operation=cancel, got %v", fields["operation"])
}
@@ -620,19 +630,19 @@ func TestJobService_GetJobs_ErrorLog(t *testing.T) {
defer srv.Close()
svc, recorded := newJobServiceWithObserver(srv)
_, err := svc.GetJobs(context.Background())
_, err := svc.GetJobs(context.Background(), &model.JobListQuery{Page: 1, PageSize: 20})
if err == nil {
t.Fatal("expected error, got nil")
}
entries := recorded.All()
if len(entries) != 1 {
t.Fatalf("expected 1 log entry, got %d", len(entries))
if len(entries) != 3 {
t.Fatalf("expected 3 log entries, got %d", len(entries))
}
if entries[0].Level != zapcore.ErrorLevel {
t.Errorf("expected ErrorLevel, got %v", entries[0].Level)
if entries[2].Level != zapcore.ErrorLevel {
t.Errorf("expected ErrorLevel, got %v", entries[2].Level)
}
fields := entries[0].ContextMap()
fields := entries[2].ContextMap()
if fields["operation"] != "get_jobs" {
t.Errorf("expected operation=get_jobs, got %v", fields["operation"])
}
@@ -655,13 +665,13 @@ func TestJobService_GetJob_ErrorLog(t *testing.T) {
}
entries := recorded.All()
if len(entries) != 1 {
t.Fatalf("expected 1 log entry, got %d", len(entries))
if len(entries) != 3 {
t.Fatalf("expected 3 log entries, got %d", len(entries))
}
if entries[0].Level != zapcore.ErrorLevel {
t.Errorf("expected ErrorLevel, got %v", entries[0].Level)
if entries[2].Level != zapcore.ErrorLevel {
t.Errorf("expected ErrorLevel, got %v", entries[2].Level)
}
fields := entries[0].ContextMap()
fields := entries[2].ContextMap()
if fields["operation"] != "get_job" {
t.Errorf("expected operation=get_job, got %v", fields["operation"])
}
@@ -687,13 +697,13 @@ func TestJobService_GetJobHistory_ErrorLog(t *testing.T) {
}
entries := recorded.All()
if len(entries) != 1 {
t.Fatalf("expected 1 log entry, got %d", len(entries))
if len(entries) != 3 {
t.Fatalf("expected 3 log entries, got %d", len(entries))
}
if entries[0].Level != zapcore.ErrorLevel {
t.Errorf("expected ErrorLevel, got %v", entries[0].Level)
if entries[2].Level != zapcore.ErrorLevel {
t.Errorf("expected ErrorLevel, got %v", entries[2].Level)
}
fields := entries[0].ContextMap()
fields := entries[2].ContextMap()
if fields["operation"] != "get_job_history" {
t.Errorf("expected operation=get_job_history, got %v", fields["operation"])
}
@@ -701,3 +711,157 @@ func TestJobService_GetJobHistory_ErrorLog(t *testing.T) {
t.Error("expected error field in log entry")
}
}
// ---------------------------------------------------------------------------
// Fallback to SlurmDBD history tests
// ---------------------------------------------------------------------------
func TestGetJob_FallbackToHistory_Found(t *testing.T) {
jobID := int32(198)
name := "hist-job"
ts := int64(1700000000)
client, cleanup := mockJobServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
switch r.URL.Path {
case "/slurm/v0.0.40/job/198":
w.WriteHeader(http.StatusNotFound)
json.NewEncoder(w).Encode(map[string]interface{}{
"errors": []map[string]interface{}{
{
"description": "Unable to query JobId=198",
"error_number": float64(2017),
"error": "Invalid job id specified",
"source": "_handle_job_get",
},
},
"jobs": []interface{}{},
})
case "/slurmdb/v0.0.40/job/198":
resp := slurm.OpenapiSlurmdbdJobsResp{
Jobs: slurm.JobList{
{
JobID: &jobID,
Name: &name,
State: &slurm.JobState{Current: []string{"COMPLETED"}},
Time: &slurm.JobTime{Submission: &ts, Start: &ts, End: &ts},
},
},
}
json.NewEncoder(w).Encode(resp)
default:
w.WriteHeader(http.StatusNotFound)
}
}))
defer cleanup()
svc := NewJobService(client, zap.NewNop())
job, err := svc.GetJob(context.Background(), "198")
if err != nil {
t.Fatalf("GetJob: %v", err)
}
if job == nil {
t.Fatal("expected job, got nil")
}
if job.JobID != 198 {
t.Errorf("expected JobID 198, got %d", job.JobID)
}
if job.Name != "hist-job" {
t.Errorf("expected Name hist-job, got %s", job.Name)
}
}
func TestGetJob_FallbackToHistory_NotFound(t *testing.T) {
client, cleanup := mockJobServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusNotFound)
}))
defer cleanup()
svc := NewJobService(client, zap.NewNop())
job, err := svc.GetJob(context.Background(), "999")
if err != nil {
t.Fatalf("GetJob: %v", err)
}
if job != nil {
t.Errorf("expected nil, got %+v", job)
}
}
func TestGetJob_FallbackToHistory_HistoryError(t *testing.T) {
client, cleanup := mockJobServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
switch r.URL.Path {
case "/slurm/v0.0.40/job/500":
w.WriteHeader(http.StatusNotFound)
json.NewEncoder(w).Encode(map[string]interface{}{
"errors": []map[string]interface{}{
{
"description": "Unable to query JobId=500",
"error_number": float64(2017),
"error": "Invalid job id specified",
"source": "_handle_job_get",
},
},
"jobs": []interface{}{},
})
case "/slurmdb/v0.0.40/job/500":
w.WriteHeader(http.StatusInternalServerError)
w.Write([]byte(`{"errors":[{"error":"db error"}]}`))
default:
w.WriteHeader(http.StatusNotFound)
}
}))
defer cleanup()
svc := NewJobService(client, zap.NewNop())
job, err := svc.GetJob(context.Background(), "500")
if err == nil {
t.Fatal("expected error, got nil")
}
if job != nil {
t.Errorf("expected nil job, got %+v", job)
}
if !strings.Contains(err.Error(), "get job history") {
t.Errorf("expected error to contain 'get job history', got %s", err.Error())
}
}
func TestGetJob_FallbackToHistory_EmptyHistory(t *testing.T) {
client, cleanup := mockJobServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
switch r.URL.Path {
case "/slurm/v0.0.40/job/777":
w.WriteHeader(http.StatusNotFound)
json.NewEncoder(w).Encode(map[string]interface{}{
"errors": []map[string]interface{}{
{
"description": "Unable to query JobId=777",
"error_number": float64(2017),
"error": "Invalid job id specified",
"source": "_handle_job_get",
},
},
"jobs": []interface{}{},
})
case "/slurmdb/v0.0.40/job/777":
resp := slurm.OpenapiSlurmdbdJobsResp{Jobs: slurm.JobList{}}
json.NewEncoder(w).Encode(resp)
default:
w.WriteHeader(http.StatusNotFound)
}
}))
defer cleanup()
svc := NewJobService(client, zap.NewNop())
job, err := svc.GetJob(context.Background(), "777")
if err != nil {
t.Fatalf("GetJob: %v", err)
}
if job != nil {
t.Errorf("expected nil, got %+v", job)
}
}

View File

@@ -0,0 +1,112 @@
package service
import (
"fmt"
"math/rand"
"regexp"
"sort"
"strconv"
"strings"
"gcy_hpc_server/internal/model"
)
var paramNameRegex = regexp.MustCompile(`^[A-Za-z_][A-Za-z0-9_]*$`)
// ValidateParams checks that all required parameters are present and values match their types.
// Parameters not in the schema are silently ignored.
func ValidateParams(params []model.ParameterSchema, values map[string]string) error {
var errs []string
for _, p := range params {
if !paramNameRegex.MatchString(p.Name) {
errs = append(errs, fmt.Sprintf("invalid parameter name %q: must match ^[A-Za-z_][A-Za-z0-9_]*$", p.Name))
continue
}
val, ok := values[p.Name]
if p.Required && !ok {
errs = append(errs, fmt.Sprintf("required parameter %q is missing", p.Name))
continue
}
if !ok {
continue
}
switch p.Type {
case model.ParamTypeInteger:
if _, err := strconv.Atoi(val); err != nil {
errs = append(errs, fmt.Sprintf("parameter %q must be an integer, got %q", p.Name, val))
}
case model.ParamTypeBoolean:
if val != "true" && val != "false" && val != "1" && val != "0" {
errs = append(errs, fmt.Sprintf("parameter %q must be a boolean (true/false/1/0), got %q", p.Name, val))
}
case model.ParamTypeEnum:
if len(p.Options) > 0 {
found := false
for _, opt := range p.Options {
if val == opt {
found = true
break
}
}
if !found {
errs = append(errs, fmt.Sprintf("parameter %q must be one of %v, got %q", p.Name, p.Options, val))
}
}
case model.ParamTypeFile, model.ParamTypeDirectory:
case model.ParamTypeString:
}
}
if len(errs) > 0 {
return fmt.Errorf("parameter validation failed: %s", strings.Join(errs, "; "))
}
return nil
}
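The per-type checks inside ValidateParams can be exercised in isolation. This self-contained sketch re-implements just the single-value rules, with the type tag reduced to a plain string standing in for `model.ParamType` (a simplification for illustration, not the real API):

```go
package main

import (
	"fmt"
	"strconv"
)

// checkValue applies the same per-type rules as ValidateParams to one
// value: integers must parse, booleans accept true/false/1/0, enums must
// match an option when options are declared; string/file/directory
// values pass through unchecked, as in the switch above.
func checkValue(typ, name, val string, options []string) error {
	switch typ {
	case "integer":
		if _, err := strconv.Atoi(val); err != nil {
			return fmt.Errorf("parameter %q must be an integer, got %q", name, val)
		}
	case "boolean":
		if val != "true" && val != "false" && val != "1" && val != "0" {
			return fmt.Errorf("parameter %q must be a boolean (true/false/1/0), got %q", name, val)
		}
	case "enum":
		for _, opt := range options {
			if val == opt {
				return nil
			}
		}
		if len(options) > 0 {
			return fmt.Errorf("parameter %q must be one of %v, got %q", name, options, val)
		}
	}
	return nil
}

func main() {
	fmt.Println(checkValue("integer", "NCPUS", "8", nil) == nil)                    // true
	fmt.Println(checkValue("boolean", "VERBOSE", "yes", nil) == nil)                // false
	fmt.Println(checkValue("enum", "MODE", "fast", []string{"fast", "safe"}) == nil) // true
}
```

Note the enum rule matches the original exactly: an enum with an empty options list accepts any value, since the membership check only runs when options are declared.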
// RenderScript replaces $PARAM tokens in the template with user-provided values.
// Only tokens defined in the schema are replaced. Replacement is done longest-name-first
// to avoid partial matches (e.g., $JOB_NAME before $JOB).
// All values are shell-escaped using single-quote wrapping.
func RenderScript(template string, params []model.ParameterSchema, values map[string]string) string {
sorted := make([]model.ParameterSchema, len(params))
copy(sorted, params)
sort.Slice(sorted, func(i, j int) bool {
return len(sorted[i].Name) > len(sorted[j].Name)
})
result := template
for _, p := range sorted {
val, ok := values[p.Name]
if !ok {
if p.Default != "" {
val = p.Default
} else {
continue
}
}
escaped := "'" + strings.ReplaceAll(val, "'", "'\\''") + "'"
result = strings.ReplaceAll(result, "$"+p.Name, escaped)
}
return result
}
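The two load-bearing details of RenderScript are the longest-name-first ordering and the single-quote escaping. This standalone sketch isolates exactly those, simplified to substitute straight from a values map (it ignores schema filtering and defaults, which the real function also handles):

```go
package main

import (
	"fmt"
	"sort"
	"strings"
)

// renderTokens replaces $NAME tokens longest-name-first so that
// $JOB_NAME is substituted before $JOB would partially match it, and
// wraps each value in single quotes with embedded quotes escaped as
// '\'' (close quote, escaped quote, reopen quote).
func renderTokens(template string, values map[string]string) string {
	names := make([]string, 0, len(values))
	for n := range values {
		names = append(names, n)
	}
	sort.Slice(names, func(i, j int) bool { return len(names[i]) > len(names[j]) })
	out := template
	for _, n := range names {
		escaped := "'" + strings.ReplaceAll(values[n], "'", `'\''`) + "'"
		out = strings.ReplaceAll(out, "$"+n, escaped)
	}
	return out
}

func main() {
	tpl := "echo $JOB_NAME on $JOB"
	fmt.Println(renderTokens(tpl, map[string]string{"JOB": "42", "JOB_NAME": "it's a test"}))
}
```

Without the length sort, substituting `$JOB` first would corrupt `$JOB_NAME` into `'42'_NAME`; the sort is what makes the replacement order safe.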
// SanitizeDirName sanitizes a directory name.
func SanitizeDirName(name string) string {
replacer := strings.NewReplacer(" ", "_", "/", "_", "\\", "_", ":", "_", "*", "_", "?", "_", "\"", "_", "<", "_", ">", "_", "|", "_")
return replacer.Replace(name)
}
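SanitizeDirName maps every filesystem-hostile character (spaces, path separators, shell and Windows metacharacters) to an underscore. A quick self-contained demonstration using the same replacer table:

```go
package main

import (
	"fmt"
	"strings"
)

// sanitizeDirName mirrors SanitizeDirName above: each unsafe character
// becomes an underscore, leaving already-safe names untouched.
func sanitizeDirName(name string) string {
	replacer := strings.NewReplacer(" ", "_", "/", "_", "\\", "_", ":", "_", "*", "_", "?", "_", "\"", "_", "<", "_", ">", "_", "|", "_")
	return replacer.Replace(name)
}

func main() {
	fmt.Println(sanitizeDirName("LAMMPS: melt/2d run")) // LAMMPS__melt_2d_run
}
```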
// RandomSuffix generates a random suffix of length n.
func RandomSuffix(n int) string {
const charset = "abcdefghijklmnopqrstuvwxyz0123456789"
b := make([]byte, n)
for i := range b {
b[i] = charset[rand.Intn(len(charset))]
}
return string(b)
}

View File

@@ -0,0 +1,554 @@
package service
import (
"context"
"encoding/json"
"fmt"
"os"
"path/filepath"
"sort"
"strconv"
"strings"
"sync"
"time"
"gcy_hpc_server/internal/model"
"gcy_hpc_server/internal/store"
"go.uber.org/zap"
)
type TaskService struct {
taskStore *store.TaskStore
appStore *store.ApplicationStore
fileStore *store.FileStore // nil ok
blobStore *store.BlobStore // nil ok
stagingSvc *FileStagingService // nil ok — MinIO unavailable
jobSvc *JobService
workDirBase string
logger *zap.Logger
// async processing
taskCh chan int64 // buffered channel, cap=16
cancelFn context.CancelFunc
wg sync.WaitGroup
mu sync.Mutex // protects taskCh from send-on-closed
started bool // prevent double-start
stopped bool
}
func NewTaskService(
taskStore *store.TaskStore,
appStore *store.ApplicationStore,
fileStore *store.FileStore,
blobStore *store.BlobStore,
stagingSvc *FileStagingService,
jobSvc *JobService,
workDirBase string,
logger *zap.Logger,
) *TaskService {
return &TaskService{
taskStore: taskStore,
appStore: appStore,
fileStore: fileStore,
blobStore: blobStore,
stagingSvc: stagingSvc,
jobSvc: jobSvc,
workDirBase: workDirBase,
logger: logger,
taskCh: make(chan int64, 16),
}
}
func (s *TaskService) CreateTask(ctx context.Context, req *model.CreateTaskRequest) (*model.Task, error) {
// 1. Fetch application
app, err := s.appStore.GetByID(ctx, req.AppID)
if err != nil {
return nil, fmt.Errorf("get application: %w", err)
}
if app == nil {
return nil, fmt.Errorf("application %d not found", req.AppID)
}
// 2. Validate file limit
if len(req.InputFileIDs) > 100 {
return nil, fmt.Errorf("input file count %d exceeds limit of 100", len(req.InputFileIDs))
}
// 3. Deduplicate file IDs
fileIDs := uniqueInt64s(req.InputFileIDs)
// 4. Validate file IDs exist
if s.fileStore != nil && len(fileIDs) > 0 {
files, err := s.fileStore.GetByIDs(ctx, fileIDs)
if err != nil {
return nil, fmt.Errorf("validate file ids: %w", err)
}
found := make(map[int64]bool, len(files))
for _, f := range files {
found[f.ID] = true
}
for _, id := range fileIDs {
if !found[id] {
return nil, fmt.Errorf("file %d not found", id)
}
}
}
// 5. Auto-generate task name if empty
taskName := req.TaskName
if taskName == "" {
taskName = SanitizeDirName(app.Name) + "_" + time.Now().Format("20060102_150405")
}
// 6. Marshal values
valuesJSON := json.RawMessage(`{}`)
if len(req.Values) > 0 {
b, err := json.Marshal(req.Values)
if err != nil {
return nil, fmt.Errorf("marshal values: %w", err)
}
valuesJSON = b
}
// 7. Marshal input_file_ids
fileIDsJSON := json.RawMessage(`[]`)
if len(fileIDs) > 0 {
b, err := json.Marshal(fileIDs)
if err != nil {
return nil, fmt.Errorf("marshal file ids: %w", err)
}
fileIDsJSON = b
}
// 8. Create task record
task := &model.Task{
TaskName: taskName,
AppID: app.ID,
AppName: app.Name,
Status: model.TaskStatusSubmitted,
Values: valuesJSON,
InputFileIDs: fileIDsJSON,
SubmittedAt: time.Now(),
}
taskID, err := s.taskStore.Create(ctx, task)
if err != nil {
return nil, fmt.Errorf("create task: %w", err)
}
task.ID = taskID
return task, nil
}
// ProcessTask runs the full synchronous processing pipeline for a task.
func (s *TaskService) ProcessTask(ctx context.Context, taskID int64) error {
// 1. Fetch task
task, err := s.taskStore.GetByID(ctx, taskID)
if err != nil {
return fmt.Errorf("get task: %w", err)
}
if task == nil {
return fmt.Errorf("task %d not found", taskID)
}
fail := func(step, msg string) error {
_ = s.taskStore.UpdateStatus(ctx, taskID, model.TaskStatusFailed, msg)
_ = s.taskStore.UpdateRetryState(ctx, taskID, model.TaskStatusFailed, step, task.RetryCount)
return fmt.Errorf("%s", msg)
}
currentStep := task.CurrentStep
var workDir string
var app *model.Application
if currentStep == "" || currentStep == model.TaskStepPreparing {
// 2. Set preparing
if err := s.taskStore.UpdateRetryState(ctx, taskID, model.TaskStatusPreparing, model.TaskStepPreparing, 0); err != nil {
return fail(model.TaskStepPreparing, fmt.Sprintf("update status to preparing: %v", err))
}
// 3. Fetch app
app, err = s.appStore.GetByID(ctx, task.AppID)
if err != nil {
return fail(model.TaskStepPreparing, fmt.Sprintf("get application: %v", err))
}
if app == nil {
return fail(model.TaskStepPreparing, fmt.Sprintf("application %d not found", task.AppID))
}
// 4-5. Create work directory
workDir = filepath.Join(s.workDirBase, SanitizeDirName(app.Name), time.Now().Format("20060102_150405")+"_"+RandomSuffix(4))
if err := os.MkdirAll(workDir, 0777); err != nil {
return fail(model.TaskStepPreparing, fmt.Sprintf("create work directory %s: %v", workDir, err))
}
// 6. CHMOD traversal — critical for multi-user HPC
for dir := workDir; dir != s.workDirBase; dir = filepath.Dir(dir) {
os.Chmod(dir, 0777)
}
os.Chmod(s.workDirBase, 0777)
// 7. UpdateWorkDir
if err := s.taskStore.UpdateWorkDir(ctx, taskID, workDir); err != nil {
return fail(model.TaskStepPreparing, fmt.Sprintf("update work dir: %v", err))
}
} else {
app, err = s.appStore.GetByID(ctx, task.AppID)
if err != nil {
return fail(currentStep, fmt.Sprintf("get application: %v", err))
}
if app == nil {
return fail(currentStep, fmt.Sprintf("application %d not found", task.AppID))
}
workDir = task.WorkDir
}
if currentStep == "" || currentStep == model.TaskStepPreparing || currentStep == model.TaskStepDownloading {
if currentStep == model.TaskStepDownloading && workDir != "" {
matches, _ := filepath.Glob(filepath.Join(workDir, "*"))
for _, f := range matches {
os.Remove(f)
}
}
// 8. Set downloading
if err := s.taskStore.UpdateRetryState(ctx, taskID, model.TaskStatusDownloading, model.TaskStepDownloading, 0); err != nil {
return fail(model.TaskStepDownloading, fmt.Sprintf("update status to downloading: %v", err))
}
// 9. Parse input_file_ids
var fileIDs []int64
if len(task.InputFileIDs) > 0 {
if err := json.Unmarshal(task.InputFileIDs, &fileIDs); err != nil {
return fail(model.TaskStepDownloading, fmt.Sprintf("parse input file ids: %v", err))
}
}
// 10-12. Download files
if len(fileIDs) > 0 {
if s.stagingSvc == nil {
return fail(model.TaskStepDownloading, "MinIO unavailable, cannot stage files")
}
if err := s.stagingSvc.DownloadFilesToDir(ctx, fileIDs, workDir); err != nil {
return fail(model.TaskStepDownloading, fmt.Sprintf("download files: %v", err))
}
}
}
// 13-14. Set ready + submitting
if err := s.taskStore.UpdateRetryState(ctx, taskID, model.TaskStatusReady, model.TaskStepSubmitting, 0); err != nil {
return fail(model.TaskStepSubmitting, fmt.Sprintf("update status to ready: %v", err))
}
// 15. Parse app parameters
var params []model.ParameterSchema
if len(app.Parameters) > 0 {
if err := json.Unmarshal(app.Parameters, &params); err != nil {
return fail(model.TaskStepSubmitting, fmt.Sprintf("parse parameters: %v", err))
}
}
// 16. Parse task values
values := make(map[string]string)
if len(task.Values) > 0 {
if err := json.Unmarshal(task.Values, &values); err != nil {
return fail(model.TaskStepSubmitting, fmt.Sprintf("parse values: %v", err))
}
}
if err := ValidateParams(params, values); err != nil {
return fail(model.TaskStepSubmitting, err.Error())
}
// 17. Render script
rendered := RenderScript(app.ScriptTemplate, params, values)
// 18. Submit to Slurm
jobResp, err := s.jobSvc.SubmitJob(ctx, &model.SubmitJobRequest{
Script: rendered,
WorkDir: workDir,
})
if err != nil {
return fail(model.TaskStepSubmitting, fmt.Sprintf("submit job: %v", err))
}
// 19. Update slurm_job_id and status to queued
if err := s.taskStore.UpdateSlurmJobID(ctx, taskID, &jobResp.JobID); err != nil {
return fail(model.TaskStepSubmitting, fmt.Sprintf("update slurm job id: %v", err))
}
if err := s.taskStore.UpdateStatus(ctx, taskID, model.TaskStatusQueued, ""); err != nil {
return fail(model.TaskStepSubmitting, fmt.Sprintf("update status to queued: %v", err))
}
return nil
}
// ListTasks returns a paginated list of tasks.
func (s *TaskService) ListTasks(ctx context.Context, query *model.TaskListQuery) ([]model.Task, int64, error) {
return s.taskStore.List(ctx, query)
}
// ProcessTaskSync creates and processes a task synchronously, returning a JobResponse
// for old API compatibility.
func (s *TaskService) ProcessTaskSync(ctx context.Context, req *model.CreateTaskRequest) (*model.JobResponse, error) {
// 1. Create task
task, err := s.CreateTask(ctx, req)
if err != nil {
return nil, err
}
// 2. Process synchronously
if err := s.ProcessTask(ctx, task.ID); err != nil {
return nil, err
}
// 3. Re-fetch to get updated slurm_job_id
task, err = s.taskStore.GetByID(ctx, task.ID)
if err != nil {
return nil, fmt.Errorf("re-fetch task: %w", err)
}
if task == nil || task.SlurmJobID == nil {
return nil, fmt.Errorf("task has no slurm job id after processing")
}
// 4. Return JobResponse
return &model.JobResponse{JobID: *task.SlurmJobID}, nil
}
// uniqueInt64s deduplicates and sorts a slice of int64.
func uniqueInt64s(ids []int64) []int64 {
if len(ids) == 0 {
return nil
}
seen := make(map[int64]bool, len(ids))
result := make([]int64, 0, len(ids))
for _, id := range ids {
if !seen[id] {
seen[id] = true
result = append(result, id)
}
}
sort.Slice(result, func(i, j int) bool { return result[i] < result[j] })
return result
}
func (s *TaskService) mapSlurmStateToTaskStatus(slurmState []string) string {
if len(slurmState) == 0 {
return model.TaskStatusRunning
}
state := strings.ToUpper(slurmState[0])
switch state {
case "PENDING":
return model.TaskStatusQueued
case "RUNNING", "CONFIGURING", "COMPLETING", "SPECIAL_EXIT":
return model.TaskStatusRunning
case "COMPLETED":
return model.TaskStatusCompleted
case "FAILED", "CANCELLED", "TIMEOUT", "NODE_FAIL", "OUT_OF_MEMORY", "PREEMPTED":
return model.TaskStatusFailed
default:
return model.TaskStatusRunning
}
}
func (s *TaskService) refreshTaskStatus(ctx context.Context, taskID int64) error {
task, err := s.taskStore.GetByID(ctx, taskID)
if err != nil {
s.logger.Error("failed to fetch task for refresh",
zap.Int64("task_id", taskID),
zap.Error(err),
)
return err
}
if task == nil || task.SlurmJobID == nil {
return nil
}
jobResp, err := s.jobSvc.GetJob(ctx, strconv.FormatInt(int64(*task.SlurmJobID), 10))
if err != nil {
s.logger.Warn("failed to query slurm job status during refresh",
zap.Int64("task_id", taskID),
zap.Int32("slurm_job_id", *task.SlurmJobID),
zap.Error(err),
)
return nil
}
if jobResp == nil {
return nil
}
newStatus := s.mapSlurmStateToTaskStatus(jobResp.State)
if newStatus != task.Status {
s.logger.Info("updating task status from slurm",
zap.Int64("task_id", taskID),
zap.String("old_status", task.Status),
zap.String("new_status", newStatus),
)
return s.taskStore.UpdateStatus(ctx, taskID, newStatus, "")
}
return nil
}
func (s *TaskService) RefreshStaleTasks(ctx context.Context) error {
staleThreshold := 30 * time.Second
nonTerminal := []string{model.TaskStatusQueued, model.TaskStatusRunning}
for _, status := range nonTerminal {
tasks, _, err := s.taskStore.List(ctx, &model.TaskListQuery{
Status: status,
Page: 1,
PageSize: 1000,
})
if err != nil {
s.logger.Warn("failed to list tasks for stale refresh",
zap.String("status", status),
zap.Error(err),
)
continue
}
cutoff := time.Now().Add(-staleThreshold)
for i := range tasks {
if tasks[i].UpdatedAt.Before(cutoff) {
if err := s.refreshTaskStatus(ctx, tasks[i].ID); err != nil {
s.logger.Warn("failed to refresh stale task",
zap.Int64("task_id", tasks[i].ID),
zap.Error(err),
)
}
}
}
}
return nil
}
func (s *TaskService) StartProcessor(ctx context.Context) {
s.mu.Lock()
if s.started {
s.mu.Unlock()
return
}
s.started = true
// Assign cancelFn before releasing the lock so a concurrent StopProcessor
// never observes started=true with a nil cancelFn.
ctx, s.cancelFn = context.WithCancel(ctx)
s.mu.Unlock()
s.wg.Add(1)
go func() {
defer s.wg.Done()
defer func() {
if r := recover(); r != nil {
s.logger.Error("processor panic", zap.Any("panic", r))
}
}()
for {
select {
case <-ctx.Done():
return
case taskID, ok := <-s.taskCh:
if !ok {
return
}
taskCtx, cancel := context.WithTimeout(ctx, 10*time.Minute)
s.processWithRetry(taskCtx, taskID)
cancel()
}
}
}()
s.RecoverStuckTasks(ctx)
}
func (s *TaskService) SubmitAsync(ctx context.Context, req *model.CreateTaskRequest) (int64, error) {
task, err := s.CreateTask(ctx, req)
if err != nil {
return 0, err
}
s.mu.Lock()
if s.stopped {
s.mu.Unlock()
return 0, fmt.Errorf("processor stopped, cannot submit task")
}
select {
case s.taskCh <- task.ID:
default:
s.logger.Warn("task channel full, submit dropped", zap.Int64("taskID", task.ID))
}
s.mu.Unlock()
return task.ID, nil
}
func (s *TaskService) StopProcessor() {
s.mu.Lock()
if s.stopped {
s.mu.Unlock()
return
}
s.stopped = true
close(s.taskCh)
s.mu.Unlock()
if s.cancelFn != nil {
s.cancelFn()
}
s.wg.Wait()
s.mu.Lock()
drainCh := s.taskCh
s.taskCh = make(chan int64, 16)
s.mu.Unlock()
for taskID := range drainCh {
_ = s.taskStore.UpdateStatus(context.Background(), taskID, model.TaskStatusSubmitted, "")
}
}
func (s *TaskService) processWithRetry(ctx context.Context, taskID int64) {
err := s.ProcessTask(ctx, taskID)
if err == nil {
return
}
task, fetchErr := s.taskStore.GetByID(ctx, taskID)
if fetchErr != nil || task == nil {
return
}
if task.RetryCount < 3 {
_ = s.taskStore.UpdateRetryState(ctx, taskID, model.TaskStatusSubmitted, task.CurrentStep, task.RetryCount+1)
s.mu.Lock()
if !s.stopped {
select {
case s.taskCh <- taskID:
default:
s.logger.Warn("task channel full, retry dropped", zap.Int64("taskID", taskID))
}
}
s.mu.Unlock()
}
}
func (s *TaskService) RecoverStuckTasks(ctx context.Context) {
tasks, err := s.taskStore.GetStuckTasks(ctx, 5*time.Minute)
if err != nil {
s.logger.Error("failed to get stuck tasks", zap.Error(err))
return
}
for i := range tasks {
_ = s.taskStore.UpdateStatus(ctx, tasks[i].ID, model.TaskStatusSubmitted, "")
s.mu.Lock()
if !s.stopped {
select {
case s.taskCh <- tasks[i].ID:
default:
s.logger.Warn("task channel full, stuck task recovery dropped", zap.Int64("taskID", tasks[i].ID))
}
}
s.mu.Unlock()
}
}


@@ -0,0 +1,416 @@
package service
import (
"context"
"encoding/json"
"fmt"
"net/http"
"net/http/httptest"
"os"
"path/filepath"
"sync/atomic"
"testing"
"time"
"gcy_hpc_server/internal/model"
"gcy_hpc_server/internal/slurm"
"gcy_hpc_server/internal/store"
"go.uber.org/zap"
"gorm.io/driver/sqlite"
gormlogger "gorm.io/gorm/logger"
"gorm.io/gorm"
)
func setupAsyncTestDB(t *testing.T) *gorm.DB {
t.Helper()
db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{
Logger: gormlogger.Default.LogMode(gormlogger.Silent),
})
if err != nil {
t.Fatalf("open sqlite: %v", err)
}
if err := db.AutoMigrate(&model.Task{}, &model.Application{}, &model.File{}, &model.FileBlob{}); err != nil {
t.Fatalf("auto migrate: %v", err)
}
return db
}
type asyncTestEnv struct {
taskStore *store.TaskStore
appStore *store.ApplicationStore
svc *TaskService
srv *httptest.Server
db *gorm.DB
workDirBase string
}
func newAsyncTestEnv(t *testing.T, slurmHandler http.HandlerFunc) *asyncTestEnv {
t.Helper()
db := setupAsyncTestDB(t)
ts := store.NewTaskStore(db)
as := store.NewApplicationStore(db)
srv := httptest.NewServer(slurmHandler)
client, _ := slurm.NewClient(srv.URL, srv.Client())
jobSvc := NewJobService(client, zap.NewNop())
workDirBase := filepath.Join(t.TempDir(), "workdir")
os.MkdirAll(workDirBase, 0777)
svc := NewTaskService(ts, as, nil, nil, nil, jobSvc, workDirBase, zap.NewNop())
return &asyncTestEnv{
taskStore: ts,
appStore: as,
svc: svc,
srv: srv,
db: db,
workDirBase: workDirBase,
}
}
func (e *asyncTestEnv) close() {
e.srv.Close()
}
func (e *asyncTestEnv) createApp(t *testing.T, name, script string) int64 {
t.Helper()
id, err := e.appStore.Create(context.Background(), &model.CreateApplicationRequest{
Name: name,
ScriptTemplate: script,
Parameters: json.RawMessage(`[]`),
})
if err != nil {
t.Fatalf("create app: %v", err)
}
return id
}
func TestTaskService_Async_SubmitAndProcess(t *testing.T) {
jobID := int32(42)
env := newAsyncTestEnv(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
json.NewEncoder(w).Encode(slurm.OpenapiJobSubmitResponse{
Result: &slurm.JobSubmitResponseMsg{JobID: &jobID},
})
}))
defer env.close()
appID := env.createApp(t, "async-app", "#!/bin/bash\necho hello")
ctx := context.Background()
env.svc.StartProcessor(ctx)
taskID, err := env.svc.SubmitAsync(ctx, &model.CreateTaskRequest{
AppID: appID,
TaskName: "async-test",
})
if err != nil {
t.Fatalf("SubmitAsync: %v", err)
}
if taskID == 0 {
t.Fatal("expected non-zero task ID")
}
time.Sleep(500 * time.Millisecond)
task, err := env.taskStore.GetByID(ctx, taskID)
if err != nil {
t.Fatalf("GetByID: %v", err)
}
if task.Status != model.TaskStatusQueued {
t.Errorf("Status = %q, want %q", task.Status, model.TaskStatusQueued)
}
env.svc.StopProcessor()
}
func TestTaskService_Retry_MaxExhaustion(t *testing.T) {
callCount := int32(0)
env := newAsyncTestEnv(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
atomic.AddInt32(&callCount, 1)
w.WriteHeader(http.StatusInternalServerError)
w.Write([]byte(`{"error":"slurm down"}`))
}))
defer env.close()
appID := env.createApp(t, "retry-app", "#!/bin/bash\necho hello")
ctx := context.Background()
env.svc.StartProcessor(ctx)
taskID, err := env.svc.SubmitAsync(ctx, &model.CreateTaskRequest{
AppID: appID,
TaskName: "retry-test",
})
if err != nil {
t.Fatalf("SubmitAsync: %v", err)
}
time.Sleep(2 * time.Second)
task, _ := env.taskStore.GetByID(ctx, taskID)
if task.Status != model.TaskStatusFailed {
t.Errorf("Status = %q, want %q", task.Status, model.TaskStatusFailed)
}
if task.RetryCount < 3 {
t.Errorf("RetryCount = %d, want >= 3", task.RetryCount)
}
env.svc.StopProcessor()
}
func TestTaskService_Recover_StuckTasks(t *testing.T) {
jobID := int32(99)
env := newAsyncTestEnv(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
json.NewEncoder(w).Encode(slurm.OpenapiJobSubmitResponse{
Result: &slurm.JobSubmitResponseMsg{JobID: &jobID},
})
}))
defer env.close()
appID := env.createApp(t, "stuck-app", "#!/bin/bash\necho hello")
ctx := context.Background()
task := &model.Task{
TaskName: "stuck-task",
AppID: appID,
AppName: "stuck-app",
Status: model.TaskStatusPreparing,
CurrentStep: model.TaskStepPreparing,
RetryCount: 0,
SubmittedAt: time.Now(),
}
taskID, err := env.taskStore.Create(ctx, task)
if err != nil {
t.Fatalf("Create stuck task: %v", err)
}
staleTime := time.Now().Add(-10 * time.Minute)
env.db.Exec("UPDATE hpc_tasks SET updated_at = ? WHERE id = ?", staleTime, taskID)
env.svc.StartProcessor(ctx)
time.Sleep(1 * time.Second)
updated, _ := env.taskStore.GetByID(ctx, taskID)
if updated.Status != model.TaskStatusQueued {
t.Errorf("Status = %q, want %q", updated.Status, model.TaskStatusQueued)
}
env.svc.StopProcessor()
}
func TestTaskService_Shutdown_InFlight(t *testing.T) {
env := newAsyncTestEnv(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
time.Sleep(200 * time.Millisecond)
jobID := int32(77)
json.NewEncoder(w).Encode(slurm.OpenapiJobSubmitResponse{
Result: &slurm.JobSubmitResponseMsg{JobID: &jobID},
})
}))
defer env.close()
appID := env.createApp(t, "shutdown-app", "#!/bin/bash\necho hello")
ctx := context.Background()
env.svc.StartProcessor(ctx)
taskID, _ := env.svc.SubmitAsync(ctx, &model.CreateTaskRequest{
AppID: appID,
TaskName: "shutdown-test",
})
done := make(chan struct{})
go func() {
env.svc.StopProcessor()
close(done)
}()
select {
case <-done:
case <-time.After(5 * time.Second):
t.Fatal("StopProcessor did not complete within timeout")
}
task, _ := env.taskStore.GetByID(ctx, taskID)
if task.Status != model.TaskStatusQueued && task.Status != model.TaskStatusSubmitted {
t.Logf("task status after shutdown: %q (acceptable)", task.Status)
}
}
func TestTaskService_PanicRecovery(t *testing.T) {
jobID := int32(55)
panicDone := int32(0)
env := newAsyncTestEnv(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if atomic.CompareAndSwapInt32(&panicDone, 0, 1) {
panic("intentional test panic")
}
json.NewEncoder(w).Encode(slurm.OpenapiJobSubmitResponse{
Result: &slurm.JobSubmitResponseMsg{JobID: &jobID},
})
}))
defer env.close()
appID := env.createApp(t, "panic-app", "#!/bin/bash\necho hello")
ctx := context.Background()
env.svc.StartProcessor(ctx)
taskID, _ := env.svc.SubmitAsync(ctx, &model.CreateTaskRequest{
AppID: appID,
TaskName: "panic-test",
})
time.Sleep(1 * time.Second)
atomic.StoreInt32(&panicDone, 1)
env.svc.StopProcessor()
_ = taskID
}
func TestTaskService_SubmitAsync_DuringShutdown(t *testing.T) {
env := newAsyncTestEnv(t, nil)
defer env.close()
appID := env.createApp(t, "shutdown-err-app", "#!/bin/bash\necho hello")
ctx := context.Background()
env.svc.StartProcessor(ctx)
env.svc.StopProcessor()
_, err := env.svc.SubmitAsync(ctx, &model.CreateTaskRequest{
AppID: appID,
TaskName: "after-shutdown",
})
if err == nil {
t.Fatal("expected error when submitting after shutdown")
}
}
// TestTaskService_SubmitAsync_ChannelFull_NonBlocking verifies SubmitAsync
// returns without blocking when the task channel buffer (cap=16) is full.
// Before fix: SubmitAsync holds s.mu while blocking on full channel → deadlock.
// After fix: non-blocking select returns immediately.
func TestTaskService_SubmitAsync_ChannelFull_NonBlocking(t *testing.T) {
jobID := int32(42)
env := newAsyncTestEnv(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
time.Sleep(5 * time.Second)
json.NewEncoder(w).Encode(slurm.OpenapiJobSubmitResponse{
Result: &slurm.JobSubmitResponseMsg{JobID: &jobID},
})
}))
defer env.close()
appID := env.createApp(t, "channel-full-app", "#!/bin/bash\necho hello")
ctx := context.Background()
// Create all tasks BEFORE starting processor to avoid SQLite concurrent-write contention
taskIDs := make([]int64, 17)
for i := range taskIDs {
id, err := env.taskStore.Create(ctx, &model.Task{
TaskName: fmt.Sprintf("fill-%d", i),
AppID: appID,
AppName: "channel-full-app",
Status: model.TaskStatusSubmitted,
CurrentStep: model.TaskStepSubmitting,
SubmittedAt: time.Now(),
})
if err != nil {
t.Fatalf("create fill task %d: %v", i, err)
}
taskIDs[i] = id
}
env.svc.StartProcessor(ctx)
defer env.svc.StopProcessor()
// Push all 17 IDs: the consumer grabs the first immediately and blocks in the
// slow Slurm handler; the remaining 16 fill the channel buffer (cap=16) exactly.
for _, id := range taskIDs {
env.svc.taskCh <- id
}
// Overflow submit: must return within 3s (non-blocking after fix)
done := make(chan error, 1)
go func() {
_, submitErr := env.svc.SubmitAsync(ctx, &model.CreateTaskRequest{
AppID: appID,
TaskName: "overflow-task",
})
done <- submitErr
}()
select {
case err := <-done:
if err != nil {
t.Logf("SubmitAsync returned error (acceptable after fix): %v", err)
} else {
t.Log("SubmitAsync returned without blocking — channel send is non-blocking")
}
case <-time.After(3 * time.Second):
t.Fatal("SubmitAsync blocked for >3s — channel send is blocking, potential deadlock")
}
}
// TestTaskService_Retry_ChannelFull_NonBlocking verifies processWithRetry
// does not deadlock when re-enqueuing a failed task into a full channel.
// Before fix: processWithRetry holds s.mu while blocking on s.taskCh <- taskID → deadlock.
// After fix: non-blocking select drops the retry with a Warn log.
func TestTaskService_Retry_ChannelFull_NonBlocking(t *testing.T) {
env := newAsyncTestEnv(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
time.Sleep(1 * time.Second)
w.WriteHeader(http.StatusInternalServerError)
w.Write([]byte(`{"error":"slurm down"}`))
}))
defer env.close()
appID := env.createApp(t, "retry-full-app", "#!/bin/bash\necho hello")
ctx := context.Background()
// Create all tasks BEFORE starting processor to avoid SQLite concurrent-write contention
taskIDs := make([]int64, 17)
for i := range taskIDs {
id, err := env.taskStore.Create(ctx, &model.Task{
TaskName: fmt.Sprintf("retry-%d", i),
AppID: appID,
AppName: "retry-full-app",
Status: model.TaskStatusSubmitted,
CurrentStep: model.TaskStepSubmitting,
RetryCount: 0,
SubmittedAt: time.Now(),
})
if err != nil {
t.Fatalf("create retry task %d: %v", i, err)
}
taskIDs[i] = id
}
env.svc.StartProcessor(ctx)
// Push all 17 IDs: consumer grabs one (processing ~1s), 16 fill the buffer
for _, id := range taskIDs {
env.svc.taskCh <- id
}
// Wait for consumer to finish first task and attempt retry into full channel
time.Sleep(2 * time.Second)
// If processWithRetry deadlocked holding s.mu, StopProcessor hangs on mutex acquisition
done := make(chan struct{})
go func() {
env.svc.StopProcessor()
close(done)
}()
select {
case <-done:
t.Log("StopProcessor completed — retry channel send is non-blocking")
case <-time.After(5 * time.Second):
t.Fatal("StopProcessor did not complete within 5s — deadlock from retry channel send")
}
}


@@ -0,0 +1,294 @@
package service
import (
"context"
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"time"
"gcy_hpc_server/internal/model"
"gcy_hpc_server/internal/slurm"
"gcy_hpc_server/internal/store"
"go.uber.org/zap"
"gorm.io/driver/sqlite"
"gorm.io/gorm"
"gorm.io/gorm/logger"
)
func newTaskSvcTestDB(t *testing.T) *gorm.DB {
t.Helper()
db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{
Logger: logger.Default.LogMode(logger.Silent),
})
if err != nil {
t.Fatalf("open sqlite: %v", err)
}
if err := db.AutoMigrate(&model.Task{}); err != nil {
t.Fatalf("auto migrate: %v", err)
}
return db
}
type taskSvcTestEnv struct {
taskStore *store.TaskStore
jobSvc *JobService
svc *TaskService
srv *httptest.Server
db *gorm.DB
}
func newTaskSvcTestEnv(t *testing.T, handler http.HandlerFunc) *taskSvcTestEnv {
t.Helper()
db := newTaskSvcTestDB(t)
ts := store.NewTaskStore(db)
srv := httptest.NewServer(handler)
client, _ := slurm.NewClient(srv.URL, srv.Client())
jobSvc := NewJobService(client, zap.NewNop())
svc := NewTaskService(ts, nil, nil, nil, nil, jobSvc, "/tmp", zap.NewNop())
return &taskSvcTestEnv{
taskStore: ts,
jobSvc: jobSvc,
svc: svc,
srv: srv,
db: db,
}
}
func (e *taskSvcTestEnv) close() {
e.srv.Close()
}
func makeTaskForTest(name, status string, slurmJobID *int32) *model.Task {
return &model.Task{
TaskName: name,
AppID: 1,
AppName: "test-app",
Status: status,
CurrentStep: "",
RetryCount: 0,
UserID: "user1",
SubmittedAt: time.Now(),
SlurmJobID: slurmJobID,
}
}
func TestTaskService_MapSlurmState_AllStates(t *testing.T) {
env := newTaskSvcTestEnv(t, nil)
defer env.close()
cases := []struct {
input []string
expected string
}{
{[]string{"PENDING"}, model.TaskStatusQueued},
{[]string{"RUNNING"}, model.TaskStatusRunning},
{[]string{"CONFIGURING"}, model.TaskStatusRunning},
{[]string{"COMPLETING"}, model.TaskStatusRunning},
{[]string{"COMPLETED"}, model.TaskStatusCompleted},
{[]string{"FAILED"}, model.TaskStatusFailed},
{[]string{"CANCELLED"}, model.TaskStatusFailed},
{[]string{"TIMEOUT"}, model.TaskStatusFailed},
{[]string{"NODE_FAIL"}, model.TaskStatusFailed},
{[]string{"OUT_OF_MEMORY"}, model.TaskStatusFailed},
{[]string{"PREEMPTED"}, model.TaskStatusFailed},
{[]string{"SPECIAL_EXIT"}, model.TaskStatusRunning},
{[]string{"unknown_state"}, model.TaskStatusRunning},
{[]string{"pending"}, model.TaskStatusQueued},
{[]string{"Running"}, model.TaskStatusRunning},
}
for _, tc := range cases {
got := env.svc.mapSlurmStateToTaskStatus(tc.input)
if got != tc.expected {
t.Errorf("mapSlurmStateToTaskStatus(%v) = %q, want %q", tc.input, got, tc.expected)
}
}
}
func TestTaskService_MapSlurmState_Empty(t *testing.T) {
env := newTaskSvcTestEnv(t, nil)
defer env.close()
got := env.svc.mapSlurmStateToTaskStatus([]string{})
if got != model.TaskStatusRunning {
t.Errorf("mapSlurmStateToTaskStatus([]) = %q, want %q", got, model.TaskStatusRunning)
}
got = env.svc.mapSlurmStateToTaskStatus(nil)
if got != model.TaskStatusRunning {
t.Errorf("mapSlurmStateToTaskStatus(nil) = %q, want %q", got, model.TaskStatusRunning)
}
}
func TestTaskService_RefreshTaskStatus_UpdatesDB(t *testing.T) {
jobID := int32(42)
env := newTaskSvcTestEnv(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
resp := slurm.OpenapiJobInfoResp{
Jobs: slurm.JobInfoMsg{
{
JobID: &jobID,
JobState: []string{"RUNNING"},
},
},
}
json.NewEncoder(w).Encode(resp)
}))
defer env.close()
ctx := context.Background()
task := makeTaskForTest("refresh-test", model.TaskStatusQueued, &jobID)
id, err := env.taskStore.Create(ctx, task)
if err != nil {
t.Fatalf("Create: %v", err)
}
err = env.svc.refreshTaskStatus(ctx, id)
if err != nil {
t.Fatalf("refreshTaskStatus: %v", err)
}
updated, _ := env.taskStore.GetByID(ctx, id)
if updated.Status != model.TaskStatusRunning {
t.Errorf("status = %q, want %q", updated.Status, model.TaskStatusRunning)
}
}
func TestTaskService_RefreshTaskStatus_NoSlurmJobID(t *testing.T) {
env := newTaskSvcTestEnv(t, nil)
defer env.close()
ctx := context.Background()
task := makeTaskForTest("no-slurm", model.TaskStatusQueued, nil)
id, _ := env.taskStore.Create(ctx, task)
err := env.svc.refreshTaskStatus(ctx, id)
if err != nil {
t.Fatalf("expected no error, got %v", err)
}
got, _ := env.taskStore.GetByID(ctx, id)
if got.Status != model.TaskStatusQueued {
t.Errorf("status should remain unchanged, got %q", got.Status)
}
}
func TestTaskService_RefreshTaskStatus_SlurmError(t *testing.T) {
jobID := int32(42)
env := newTaskSvcTestEnv(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusInternalServerError)
w.Write([]byte(`{"error":"down"}`))
}))
defer env.close()
ctx := context.Background()
task := makeTaskForTest("slurm-err", model.TaskStatusQueued, &jobID)
id, _ := env.taskStore.Create(ctx, task)
err := env.svc.refreshTaskStatus(ctx, id)
if err != nil {
t.Fatalf("expected no error (soft fail), got %v", err)
}
got, _ := env.taskStore.GetByID(ctx, id)
if got.Status != model.TaskStatusQueued {
t.Errorf("status should remain unchanged on slurm error, got %q", got.Status)
}
}
func TestTaskService_RefreshTaskStatus_NoChange(t *testing.T) {
jobID := int32(42)
env := newTaskSvcTestEnv(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
resp := slurm.OpenapiJobInfoResp{
Jobs: slurm.JobInfoMsg{
{
JobID: &jobID,
JobState: []string{"RUNNING"},
},
},
}
json.NewEncoder(w).Encode(resp)
}))
defer env.close()
ctx := context.Background()
task := makeTaskForTest("no-change", model.TaskStatusRunning, &jobID)
id, _ := env.taskStore.Create(ctx, task)
err := env.svc.refreshTaskStatus(ctx, id)
if err != nil {
t.Fatalf("refreshTaskStatus: %v", err)
}
got, _ := env.taskStore.GetByID(ctx, id)
if got.Status != model.TaskStatusRunning {
t.Errorf("status = %q, want %q", got.Status, model.TaskStatusRunning)
}
}
func TestTaskService_RefreshStaleTasks_SkipsFresh(t *testing.T) {
jobID := int32(42)
slurmQueried := false
env := newTaskSvcTestEnv(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
slurmQueried = true
w.WriteHeader(http.StatusInternalServerError)
}))
defer env.close()
ctx := context.Background()
task := makeTaskForTest("fresh-task", model.TaskStatusQueued, &jobID)
id, _ := env.taskStore.Create(ctx, task)
freshTask, _ := env.taskStore.GetByID(ctx, id)
if freshTask == nil {
t.Fatal("task not found")
}
env.db.Exec("UPDATE hpc_tasks SET updated_at = ? WHERE id = ?", time.Now(), id)
err := env.svc.RefreshStaleTasks(ctx)
if err != nil {
t.Fatalf("RefreshStaleTasks: %v", err)
}
if slurmQueried {
t.Error("expected no Slurm query for fresh task")
}
}
func TestTaskService_RefreshStaleTasks_RefreshesStale(t *testing.T) {
jobID := int32(42)
env := newTaskSvcTestEnv(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
resp := slurm.OpenapiJobInfoResp{
Jobs: slurm.JobInfoMsg{
{
JobID: &jobID,
JobState: []string{"COMPLETED"},
},
},
}
json.NewEncoder(w).Encode(resp)
}))
defer env.close()
ctx := context.Background()
task := makeTaskForTest("stale-task", model.TaskStatusRunning, &jobID)
id, _ := env.taskStore.Create(ctx, task)
staleTime := time.Now().Add(-60 * time.Second)
env.db.Exec("UPDATE hpc_tasks SET updated_at = ? WHERE id = ?", staleTime, id)
err := env.svc.RefreshStaleTasks(ctx)
if err != nil {
t.Fatalf("RefreshStaleTasks: %v", err)
}
got, _ := env.taskStore.GetByID(ctx, id)
if got.Status != model.TaskStatusCompleted {
t.Errorf("status = %q, want %q", got.Status, model.TaskStatusCompleted)
}
}


@@ -0,0 +1,538 @@
package service
import (
"context"
"encoding/json"
"net/http"
"net/http/httptest"
"os"
"path/filepath"
"strings"
"testing"
"gcy_hpc_server/internal/model"
"gcy_hpc_server/internal/slurm"
"gcy_hpc_server/internal/store"
"go.uber.org/zap"
"gorm.io/driver/sqlite"
gormlogger "gorm.io/gorm/logger"
"gorm.io/gorm"
)
func setupTaskTestDB(t *testing.T) *gorm.DB {
t.Helper()
db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{
Logger: gormlogger.Default.LogMode(gormlogger.Silent),
})
if err != nil {
t.Fatalf("open sqlite: %v", err)
}
if err := db.AutoMigrate(&model.Task{}, &model.Application{}, &model.File{}, &model.FileBlob{}); err != nil {
t.Fatalf("auto migrate: %v", err)
}
return db
}
type taskTestEnv struct {
taskStore *store.TaskStore
appStore *store.ApplicationStore
fileStore *store.FileStore
blobStore *store.BlobStore
svc *TaskService
srv *httptest.Server
db *gorm.DB
workDirBase string
}
func newTaskTestEnv(t *testing.T, slurmHandler http.HandlerFunc) *taskTestEnv {
t.Helper()
db := setupTaskTestDB(t)
ts := store.NewTaskStore(db)
as := store.NewApplicationStore(db)
fs := store.NewFileStore(db)
bs := store.NewBlobStore(db)
srv := httptest.NewServer(slurmHandler)
client, _ := slurm.NewClient(srv.URL, srv.Client())
jobSvc := NewJobService(client, zap.NewNop())
workDirBase := filepath.Join(t.TempDir(), "workdir")
os.MkdirAll(workDirBase, 0777)
svc := NewTaskService(ts, as, fs, bs, nil, jobSvc, workDirBase, zap.NewNop())
return &taskTestEnv{
taskStore: ts,
appStore: as,
fileStore: fs,
blobStore: bs,
svc: svc,
srv: srv,
db: db,
workDirBase: workDirBase,
}
}
func (e *taskTestEnv) close() {
e.srv.Close()
}
func (e *taskTestEnv) createApp(t *testing.T, name, script string, params json.RawMessage) int64 {
t.Helper()
id, err := e.appStore.Create(context.Background(), &model.CreateApplicationRequest{
Name: name,
ScriptTemplate: script,
Parameters: params,
})
if err != nil {
t.Fatalf("create app: %v", err)
}
return id
}
func TestTaskService_CreateTask_Success(t *testing.T) {
env := newTaskTestEnv(t, nil)
defer env.close()
appID := env.createApp(t, "my-app", "#!/bin/bash\necho hello", json.RawMessage(`[]`))
task, err := env.svc.CreateTask(context.Background(), &model.CreateTaskRequest{
AppID: appID,
TaskName: "test-task",
Values: map[string]string{"KEY": "val"},
})
if err != nil {
t.Fatalf("CreateTask: %v", err)
}
if task.ID == 0 {
t.Error("expected non-zero task ID")
}
if task.AppID != appID {
t.Errorf("AppID = %d, want %d", task.AppID, appID)
}
if task.AppName != "my-app" {
t.Errorf("AppName = %q, want %q", task.AppName, "my-app")
}
if task.Status != model.TaskStatusSubmitted {
t.Errorf("Status = %q, want %q", task.Status, model.TaskStatusSubmitted)
}
if task.TaskName != "test-task" {
t.Errorf("TaskName = %q, want %q", task.TaskName, "test-task")
}
var values map[string]string
if err := json.Unmarshal(task.Values, &values); err != nil {
t.Fatalf("unmarshal values: %v", err)
}
if values["KEY"] != "val" {
t.Errorf("values[KEY] = %q, want %q", values["KEY"], "val")
}
}
func TestTaskService_CreateTask_InvalidAppID(t *testing.T) {
env := newTaskTestEnv(t, nil)
defer env.close()
_, err := env.svc.CreateTask(context.Background(), &model.CreateTaskRequest{
AppID: 999,
})
if err == nil {
t.Fatal("expected error for invalid app_id")
}
if !strings.Contains(err.Error(), "not found") {
t.Errorf("error should mention 'not found', got: %v", err)
}
}
func TestTaskService_CreateTask_ExceedsFileLimit(t *testing.T) {
env := newTaskTestEnv(t, nil)
defer env.close()
appID := env.createApp(t, "app", "#!/bin/bash\necho hi", nil)
fileIDs := make([]int64, 101)
for i := range fileIDs {
fileIDs[i] = int64(i + 1)
}
_, err := env.svc.CreateTask(context.Background(), &model.CreateTaskRequest{
AppID: appID,
InputFileIDs: fileIDs,
})
if err == nil {
t.Fatal("expected error for exceeding file limit")
}
if !strings.Contains(err.Error(), "exceeds limit") {
t.Errorf("error should mention limit, got: %v", err)
}
}
func TestTaskService_CreateTask_DuplicateFileIDs(t *testing.T) {
env := newTaskTestEnv(t, nil)
defer env.close()
appID := env.createApp(t, "app", "#!/bin/bash\necho hi", nil)
ctx := context.Background()
for _, id := range []int64{1, 2} {
f := &model.File{
Name: "file.txt",
BlobSHA256: "abc123",
}
if err := env.fileStore.Create(ctx, f); err != nil {
t.Fatalf("create file: %v", err)
}
if f.ID != id {
t.Fatalf("expected file ID %d, got %d", id, f.ID)
}
}
task, err := env.svc.CreateTask(ctx, &model.CreateTaskRequest{
AppID: appID,
InputFileIDs: []int64{1, 1, 2, 2},
})
if err != nil {
t.Fatalf("CreateTask: %v", err)
}
var fileIDs []int64
if err := json.Unmarshal(task.InputFileIDs, &fileIDs); err != nil {
t.Fatalf("unmarshal file ids: %v", err)
}
if len(fileIDs) != 2 {
t.Fatalf("expected 2 deduplicated file IDs, got %d: %v", len(fileIDs), fileIDs)
}
if fileIDs[0] != 1 || fileIDs[1] != 2 {
t.Errorf("expected [1,2], got %v", fileIDs)
}
}
func TestTaskService_CreateTask_AutoName(t *testing.T) {
env := newTaskTestEnv(t, nil)
defer env.close()
appID := env.createApp(t, "My Cool App", "#!/bin/bash\necho hi", nil)
task, err := env.svc.CreateTask(context.Background(), &model.CreateTaskRequest{
AppID: appID,
})
if err != nil {
t.Fatalf("CreateTask: %v", err)
}
if !strings.HasPrefix(task.TaskName, "My_Cool_App_") {
t.Errorf("auto-generated name should start with 'My_Cool_App_', got %q", task.TaskName)
}
}
func TestTaskService_CreateTask_NilValues(t *testing.T) {
env := newTaskTestEnv(t, nil)
defer env.close()
appID := env.createApp(t, "app", "#!/bin/bash\necho hi", nil)
task, err := env.svc.CreateTask(context.Background(), &model.CreateTaskRequest{
AppID: appID,
Values: nil,
})
if err != nil {
t.Fatalf("CreateTask: %v", err)
}
if string(task.Values) != `{}` {
t.Errorf("Values = %q, want {}", string(task.Values))
}
}
func TestTaskService_ProcessTask_Success(t *testing.T) {
jobID := int32(42)
env := newTaskTestEnv(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
json.NewEncoder(w).Encode(slurm.OpenapiJobSubmitResponse{
Result: &slurm.JobSubmitResponseMsg{JobID: &jobID},
})
}))
defer env.close()
appID := env.createApp(t, "test-app", "#!/bin/bash\necho $INPUT", json.RawMessage(`[{"name":"INPUT","type":"string","required":true}]`))
task, err := env.svc.CreateTask(context.Background(), &model.CreateTaskRequest{
AppID: appID,
Values: map[string]string{"INPUT": "hello"},
})
if err != nil {
t.Fatalf("CreateTask: %v", err)
}
err = env.svc.ProcessTask(context.Background(), task.ID)
if err != nil {
t.Fatalf("ProcessTask: %v", err)
}
updated, _ := env.taskStore.GetByID(context.Background(), task.ID)
if updated.Status != model.TaskStatusQueued {
t.Errorf("Status = %q, want %q", updated.Status, model.TaskStatusQueued)
}
if updated.SlurmJobID == nil || *updated.SlurmJobID != 42 {
t.Errorf("SlurmJobID = %v, want 42", updated.SlurmJobID)
}
if updated.WorkDir == "" {
t.Error("WorkDir should not be empty")
}
if !strings.HasPrefix(updated.WorkDir, env.workDirBase) {
t.Errorf("WorkDir = %q, should start with %q", updated.WorkDir, env.workDirBase)
}
info, err := os.Stat(updated.WorkDir)
if err != nil {
t.Fatalf("stat workdir: %v", err)
}
if !info.IsDir() {
t.Error("WorkDir should be a directory")
}
}
func TestTaskService_ProcessTask_TaskNotFound(t *testing.T) {
env := newTaskTestEnv(t, nil)
defer env.close()
err := env.svc.ProcessTask(context.Background(), 999)
if err == nil {
t.Fatal("expected error for non-existent task")
}
if !strings.Contains(err.Error(), "not found") {
t.Errorf("error should mention 'not found', got: %v", err)
}
}
func TestTaskService_ProcessTask_SlurmError(t *testing.T) {
env := newTaskTestEnv(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusInternalServerError)
w.Write([]byte(`{"error":"slurm down"}`))
}))
defer env.close()
appID := env.createApp(t, "test-app", "#!/bin/bash\necho hello", nil)
task, err := env.svc.CreateTask(context.Background(), &model.CreateTaskRequest{
AppID: appID,
})
if err != nil {
t.Fatalf("CreateTask: %v", err)
}
err = env.svc.ProcessTask(context.Background(), task.ID)
if err == nil {
t.Fatal("expected error from Slurm")
}
updated, _ := env.taskStore.GetByID(context.Background(), task.ID)
if updated.Status != model.TaskStatusFailed {
t.Errorf("Status = %q, want %q", updated.Status, model.TaskStatusFailed)
}
if updated.CurrentStep != model.TaskStepSubmitting {
t.Errorf("CurrentStep = %q, want %q", updated.CurrentStep, model.TaskStepSubmitting)
}
if !strings.Contains(updated.ErrorMessage, "submit job") {
t.Errorf("ErrorMessage should mention 'submit job', got: %q", updated.ErrorMessage)
}
}
func TestTaskService_ProcessTaskSync(t *testing.T) {
jobID := int32(42)
env := newTaskTestEnv(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
json.NewEncoder(w).Encode(slurm.OpenapiJobSubmitResponse{
Result: &slurm.JobSubmitResponseMsg{JobID: &jobID},
})
}))
defer env.close()
appID := env.createApp(t, "sync-app", "#!/bin/bash\necho hello", nil)
resp, err := env.svc.ProcessTaskSync(context.Background(), &model.CreateTaskRequest{
AppID: appID,
})
if err != nil {
t.Fatalf("ProcessTaskSync: %v", err)
}
if resp.JobID != 42 {
t.Errorf("JobID = %d, want 42", resp.JobID)
}
}
func TestTaskService_ProcessTaskSync_NoMinIO(t *testing.T) {
jobID := int32(42)
env := newTaskTestEnv(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
json.NewEncoder(w).Encode(slurm.OpenapiJobSubmitResponse{
Result: &slurm.JobSubmitResponseMsg{JobID: &jobID},
})
}))
defer env.close()
appID := env.createApp(t, "no-minio-app", "#!/bin/bash\necho hello", nil)
resp, err := env.svc.ProcessTaskSync(context.Background(), &model.CreateTaskRequest{
AppID: appID,
InputFileIDs: nil,
})
if err != nil {
t.Fatalf("ProcessTaskSync: %v", err)
}
if resp.JobID != 42 {
t.Errorf("JobID = %d, want 42", resp.JobID)
}
}
func TestTaskService_ProcessTask_NilValues(t *testing.T) {
jobID := int32(55)
env := newTaskTestEnv(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
json.NewEncoder(w).Encode(slurm.OpenapiJobSubmitResponse{
Result: &slurm.JobSubmitResponseMsg{JobID: &jobID},
})
}))
defer env.close()
appID := env.createApp(t, "nil-val-app", "#!/bin/bash\necho hello", nil)
task, err := env.svc.CreateTask(context.Background(), &model.CreateTaskRequest{
AppID: appID,
Values: nil,
})
if err != nil {
t.Fatalf("CreateTask: %v", err)
}
err = env.svc.ProcessTask(context.Background(), task.ID)
if err != nil {
t.Fatalf("ProcessTask: %v", err)
}
updated, _ := env.taskStore.GetByID(context.Background(), task.ID)
if updated.Status != model.TaskStatusQueued {
t.Errorf("Status = %q, want %q", updated.Status, model.TaskStatusQueued)
}
}
func TestTaskService_ListTasks(t *testing.T) {
env := newTaskTestEnv(t, nil)
defer env.close()
appID := env.createApp(t, "list-app", "#!/bin/bash\necho hi", nil)
for i := 0; i < 3; i++ {
_, err := env.svc.CreateTask(context.Background(), &model.CreateTaskRequest{
AppID: appID,
TaskName: "task-" + string(rune('A'+i)),
})
if err != nil {
t.Fatalf("CreateTask %d: %v", i, err)
}
}
tasks, total, err := env.svc.ListTasks(context.Background(), &model.TaskListQuery{
Page: 1,
PageSize: 10,
})
if err != nil {
t.Fatalf("ListTasks: %v", err)
}
if total != 3 {
t.Errorf("total = %d, want 3", total)
}
if len(tasks) != 3 {
t.Errorf("len(tasks) = %d, want 3", len(tasks))
}
}
func TestTaskService_ProcessTask_ValidateParams_MissingRequired(t *testing.T) {
jobID := int32(42)
env := newTaskTestEnv(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
json.NewEncoder(w).Encode(slurm.OpenapiJobSubmitResponse{
Result: &slurm.JobSubmitResponseMsg{JobID: &jobID},
})
}))
defer env.close()
// App requires INPUT param, but we submit without it
appID := env.createApp(t, "validation-app", "#!/bin/bash\necho $INPUT", json.RawMessage(`[{"name":"INPUT","type":"string","required":true}]`))
task, err := env.svc.CreateTask(context.Background(), &model.CreateTaskRequest{
AppID: appID,
Values: map[string]string{}, // missing required INPUT
})
if err != nil {
t.Fatalf("CreateTask: %v", err)
}
err = env.svc.ProcessTask(context.Background(), task.ID)
if err == nil {
t.Fatal("expected error for missing required parameter, got nil — ValidateParams is not being called in ProcessTask pipeline")
}
errStr := err.Error()
if !strings.Contains(errStr, "validation") && !strings.Contains(errStr, "missing") && !strings.Contains(errStr, "INPUT") {
t.Errorf("error should mention 'validation', 'missing', or 'INPUT', got: %v", err)
}
}
func TestTaskService_ProcessTask_ValidateParams_InvalidInteger(t *testing.T) {
jobID := int32(42)
env := newTaskTestEnv(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
json.NewEncoder(w).Encode(slurm.OpenapiJobSubmitResponse{
Result: &slurm.JobSubmitResponseMsg{JobID: &jobID},
})
}))
defer env.close()
// App expects integer param NUM, but we submit "abc"
appID := env.createApp(t, "int-validation-app", "#!/bin/bash\necho $NUM", json.RawMessage(`[{"name":"NUM","type":"integer","required":true}]`))
task, err := env.svc.CreateTask(context.Background(), &model.CreateTaskRequest{
AppID: appID,
Values: map[string]string{"NUM": "abc"}, // invalid integer
})
if err != nil {
t.Fatalf("CreateTask: %v", err)
}
err = env.svc.ProcessTask(context.Background(), task.ID)
if err == nil {
t.Fatal("expected error for invalid integer parameter, got nil — ValidateParams is not being called in ProcessTask pipeline")
}
errStr := err.Error()
if !strings.Contains(errStr, "integer") && !strings.Contains(errStr, "validation") && !strings.Contains(errStr, "NUM") {
t.Errorf("error should mention 'integer', 'validation', or 'NUM', got: %v", err)
}
}
func TestTaskService_ProcessTask_ValidateParams_ValidParamsSucceed(t *testing.T) {
jobID := int32(99)
env := newTaskTestEnv(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
json.NewEncoder(w).Encode(slurm.OpenapiJobSubmitResponse{
Result: &slurm.JobSubmitResponseMsg{JobID: &jobID},
})
}))
defer env.close()
appID := env.createApp(t, "valid-params-app", "#!/bin/bash\necho $INPUT", json.RawMessage(`[{"name":"INPUT","type":"string","required":true}]`))
task, err := env.svc.CreateTask(context.Background(), &model.CreateTaskRequest{
AppID: appID,
Values: map[string]string{"INPUT": "hello"},
})
if err != nil {
t.Fatalf("CreateTask: %v", err)
}
err = env.svc.ProcessTask(context.Background(), task.ID)
if err != nil {
t.Fatalf("ProcessTask with valid params: %v", err)
}
updated, _ := env.taskStore.GetByID(context.Background(), task.ID)
if updated.Status != model.TaskStatusQueued {
t.Errorf("Status = %q, want %q", updated.Status, model.TaskStatusQueued)
}
if updated.SlurmJobID == nil || *updated.SlurmJobID != 99 {
t.Errorf("SlurmJobID = %v, want 99", updated.SlurmJobID)
}
}


@@ -0,0 +1,443 @@
package service
import (
"context"
"crypto/sha256"
"encoding/hex"
"errors"
"fmt"
"io"
"time"
"gcy_hpc_server/internal/config"
"gcy_hpc_server/internal/model"
"gcy_hpc_server/internal/storage"
"gcy_hpc_server/internal/store"
"go.uber.org/zap"
"gorm.io/gorm"
"gorm.io/gorm/clause"
)
type UploadService struct {
storage storage.ObjectStorage
blobStore *store.BlobStore
fileStore *store.FileStore
uploadStore *store.UploadStore
cfg config.MinioConfig
db *gorm.DB
logger *zap.Logger
}
func NewUploadService(
st storage.ObjectStorage,
blobStore *store.BlobStore,
fileStore *store.FileStore,
uploadStore *store.UploadStore,
cfg config.MinioConfig,
db *gorm.DB,
logger *zap.Logger,
) *UploadService {
return &UploadService{
storage: st,
blobStore: blobStore,
fileStore: fileStore,
uploadStore: uploadStore,
cfg: cfg,
db: db,
logger: logger,
}
}
func (s *UploadService) InitUpload(ctx context.Context, req model.InitUploadRequest) (interface{}, error) {
if err := model.ValidateFileName(req.FileName); err != nil {
return nil, fmt.Errorf("invalid file name: %w", err)
}
if req.FileSize < 0 {
return nil, fmt.Errorf("file size cannot be negative")
}
if req.FileSize > s.cfg.MaxFileSize {
return nil, fmt.Errorf("file size %d exceeds maximum %d", req.FileSize, s.cfg.MaxFileSize)
}
chunkSize := s.cfg.ChunkSize
if req.ChunkSize != nil {
chunkSize = *req.ChunkSize
}
if chunkSize < s.cfg.MinChunkSize {
return nil, fmt.Errorf("chunk size %d is below minimum %d", chunkSize, s.cfg.MinChunkSize)
}
totalChunks := int((req.FileSize + chunkSize - 1) / chunkSize)
if req.FileSize == 0 {
totalChunks = 0
}
if totalChunks > 10000 {
return nil, fmt.Errorf("total chunks %d exceeds limit of 10000", totalChunks)
}
blob, err := s.blobStore.GetBySHA256(ctx, req.SHA256)
if err != nil {
return nil, fmt.Errorf("check dedup: %w", err)
}
if blob != nil {
if err := s.blobStore.IncrementRef(ctx, req.SHA256); err != nil {
return nil, fmt.Errorf("increment ref: %w", err)
}
file := &model.File{
Name: req.FileName,
FolderID: req.FolderID,
BlobSHA256: req.SHA256,
}
if err := s.fileStore.Create(ctx, file); err != nil {
return nil, fmt.Errorf("create file: %w", err)
}
return model.FileResponse{
ID: file.ID,
Name: file.Name,
FolderID: file.FolderID,
Size: blob.FileSize,
MimeType: blob.MimeType,
SHA256: blob.SHA256,
CreatedAt: file.CreatedAt,
UpdatedAt: file.UpdatedAt,
}, nil
}
mimeType := req.MimeType
if mimeType == "" {
mimeType = "application/octet-stream"
}
session := &model.UploadSession{
FileName: req.FileName,
FileSize: req.FileSize,
ChunkSize: chunkSize,
TotalChunks: totalChunks,
SHA256: req.SHA256,
FolderID: req.FolderID,
Status: "pending",
MinioPrefix: fmt.Sprintf("uploads/%d/", time.Now().UnixNano()),
MimeType: mimeType,
ExpiresAt: time.Now().Add(time.Duration(s.cfg.SessionTTL) * time.Hour),
}
if err := s.uploadStore.CreateSession(ctx, session); err != nil {
return nil, fmt.Errorf("create session: %w", err)
}
return model.UploadSessionResponse{
ID: session.ID,
FileName: session.FileName,
FileSize: session.FileSize,
ChunkSize: session.ChunkSize,
TotalChunks: session.TotalChunks,
SHA256: session.SHA256,
Status: session.Status,
UploadedChunks: []int{},
ExpiresAt: session.ExpiresAt,
CreatedAt: session.CreatedAt,
}, nil
}
func (s *UploadService) UploadChunk(ctx context.Context, sessionID int64, chunkIndex int, reader io.Reader, size int64) error {
session, err := s.uploadStore.GetSession(ctx, sessionID)
if err != nil {
return fmt.Errorf("get session: %w", err)
}
if session == nil {
return fmt.Errorf("session not found")
}
switch session.Status {
case "pending", "uploading", "failed":
default:
return fmt.Errorf("cannot upload to session with status %q", session.Status)
}
if chunkIndex < 0 || chunkIndex >= session.TotalChunks {
return fmt.Errorf("chunk index %d out of range [0, %d)", chunkIndex, session.TotalChunks)
}
key := fmt.Sprintf("%schunk_%05d", session.MinioPrefix, chunkIndex)
hasher := sha256.New()
teeReader := io.TeeReader(reader, hasher)
_, err = s.storage.PutObject(ctx, s.cfg.Bucket, key, teeReader, size, storage.PutObjectOptions{
DisableMultipart: true,
})
if err != nil {
return fmt.Errorf("put object: %w", err)
}
chunkSHA256 := hex.EncodeToString(hasher.Sum(nil))
chunk := &model.UploadChunk{
SessionID: sessionID,
ChunkIndex: chunkIndex,
MinioKey: key,
SHA256: chunkSHA256,
Size: size,
Status: "uploaded",
}
if err := s.uploadStore.UpsertChunk(ctx, chunk); err != nil {
return fmt.Errorf("upsert chunk: %w", err)
}
if session.Status == "pending" {
// Best-effort transition to "uploading": the chunk is already stored, and
// ErrRecordNotFound means the session was removed concurrently, so only
// surface other errors.
if err := s.uploadStore.UpdateSessionStatus(ctx, sessionID, "uploading"); err != nil {
if !errors.Is(err, gorm.ErrRecordNotFound) {
return fmt.Errorf("update status to uploading: %w", err)
}
}
}
return nil
}
func (s *UploadService) GetUploadStatus(ctx context.Context, sessionID int64) (*model.UploadSessionResponse, error) {
session, chunks, err := s.uploadStore.GetSessionWithChunks(ctx, sessionID)
if err != nil {
return nil, fmt.Errorf("get session: %w", err)
}
if session == nil {
return nil, fmt.Errorf("session not found")
}
uploadedChunks := make([]int, 0, len(chunks))
for _, c := range chunks {
if c.Status == "uploaded" {
uploadedChunks = append(uploadedChunks, c.ChunkIndex)
}
}
return &model.UploadSessionResponse{
ID: session.ID,
FileName: session.FileName,
FileSize: session.FileSize,
ChunkSize: session.ChunkSize,
TotalChunks: session.TotalChunks,
SHA256: session.SHA256,
Status: session.Status,
UploadedChunks: uploadedChunks,
ExpiresAt: session.ExpiresAt,
CreatedAt: session.CreatedAt,
}, nil
}
func (s *UploadService) CompleteUpload(ctx context.Context, sessionID int64) (*model.FileResponse, error) {
var fileResp *model.FileResponse
var totalChunks int
var minioPrefix string
err := s.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
var session model.UploadSession
query := tx.WithContext(ctx)
// SQLite does not support SELECT ... FOR UPDATE; its database-level write
// lock already serializes the transaction, so skip the row lock there.
if !isSQLite(tx) {
query = query.Clauses(clause.Locking{Strength: "UPDATE"})
}
if err := query.First(&session, sessionID).Error; err != nil {
return fmt.Errorf("get session: %w", err)
}
switch session.Status {
case "uploading", "failed":
default:
return fmt.Errorf("cannot complete session with status %q", session.Status)
}
if session.TotalChunks > 0 {
var chunkCount int64
if err := tx.WithContext(ctx).Model(&model.UploadChunk{}).
Where("session_id = ? AND status = ?", sessionID, "uploaded").
Count(&chunkCount).Error; err != nil {
return fmt.Errorf("count chunks: %w", err)
}
if int(chunkCount) != session.TotalChunks {
return fmt.Errorf("not all chunks uploaded: %d/%d", chunkCount, session.TotalChunks)
}
}
totalChunks = session.TotalChunks
minioPrefix = session.MinioPrefix
if err := updateStatusTx(tx, ctx, sessionID, "merging"); err != nil {
return fmt.Errorf("update status to merging: %w", err)
}
blob, err := s.blobStore.GetBySHA256ForUpdate(ctx, tx, session.SHA256)
if err != nil {
return fmt.Errorf("get blob for update: %w", err)
}
if blob != nil {
result := tx.WithContext(ctx).Model(&model.FileBlob{}).
Where("sha256 = ?", session.SHA256).
UpdateColumn("ref_count", gorm.Expr("ref_count + 1"))
if result.Error != nil {
return fmt.Errorf("increment ref: %w", result.Error)
}
if result.RowsAffected == 0 {
return fmt.Errorf("blob not found for ref increment")
}
} else {
if session.TotalChunks > 0 {
chunkKeys := make([]string, session.TotalChunks)
for i := 0; i < session.TotalChunks; i++ {
chunkKeys[i] = fmt.Sprintf("%schunk_%05d", session.MinioPrefix, i)
}
dstKey := "files/" + session.SHA256
if _, err := s.storage.ComposeObject(ctx, s.cfg.Bucket, dstKey, chunkKeys); err != nil {
return errComposeFailed{err: err}
}
}
blobRecord := &model.FileBlob{
SHA256: session.SHA256,
MinioKey: "files/" + session.SHA256,
FileSize: session.FileSize,
MimeType: session.MimeType,
RefCount: 1,
}
if err := tx.WithContext(ctx).Create(blobRecord).Error; err != nil {
return fmt.Errorf("create blob: %w", err)
}
}
file := &model.File{
Name: session.FileName,
FolderID: session.FolderID,
BlobSHA256: session.SHA256,
}
if err := tx.WithContext(ctx).Create(file).Error; err != nil {
return fmt.Errorf("create file: %w", err)
}
if err := updateStatusTx(tx, ctx, sessionID, "completed"); err != nil {
return fmt.Errorf("update status to completed: %w", err)
}
fileResp = &model.FileResponse{
ID: file.ID,
Name: file.Name,
FolderID: file.FolderID,
Size: session.FileSize,
MimeType: session.MimeType,
SHA256: session.SHA256,
CreatedAt: file.CreatedAt,
UpdatedAt: file.UpdatedAt,
}
return nil
})
if err != nil {
var cfe errComposeFailed
if errors.As(err, &cfe) {
if statusErr := s.uploadStore.UpdateSessionStatus(ctx, sessionID, "failed"); statusErr != nil {
s.logger.Warn("failed to mark session as failed", zap.Int64("session_id", sessionID), zap.Error(statusErr))
}
return nil, fmt.Errorf("compose object: %w", cfe.err)
}
return nil, err
}
if totalChunks > 0 {
keys := make([]string, totalChunks)
for i := 0; i < totalChunks; i++ {
keys[i] = fmt.Sprintf("%schunk_%05d", minioPrefix, i)
}
go func() {
bgCtx := context.Background()
if delErr := s.storage.RemoveObjects(bgCtx, s.cfg.Bucket, keys, storage.RemoveObjectsOptions{}); delErr != nil {
s.logger.Warn("delete temp chunks", zap.Error(delErr))
}
}()
}
return fileResp, nil
}
func (s *UploadService) CancelUpload(ctx context.Context, sessionID int64) error {
session, err := s.uploadStore.GetSession(ctx, sessionID)
if err != nil {
return fmt.Errorf("get session: %w", err)
}
if session == nil {
return fmt.Errorf("session not found")
}
if err := s.uploadStore.UpdateSessionStatus(ctx, sessionID, "cancelled"); err != nil {
return fmt.Errorf("update status to cancelled: %w", err)
}
if session.TotalChunks > 0 {
keys, listErr := s.listChunkKeys(ctx, sessionID)
if listErr != nil {
s.logger.Warn("list chunk keys for cancel", zap.Error(listErr))
} else if len(keys) > 0 {
if delErr := s.storage.RemoveObjects(ctx, s.cfg.Bucket, keys, storage.RemoveObjectsOptions{}); delErr != nil {
s.logger.Warn("remove chunk objects for cancel", zap.Error(delErr))
}
}
}
if err := s.uploadStore.DeleteSession(ctx, sessionID); err != nil {
return fmt.Errorf("delete session: %w", err)
}
return nil
}
func (s *UploadService) listChunkKeys(ctx context.Context, sessionID int64) ([]string, error) {
session, err := s.uploadStore.GetSession(ctx, sessionID)
if err != nil {
return nil, fmt.Errorf("get session: %w", err)
}
if session == nil {
return nil, fmt.Errorf("session not found")
}
keys := make([]string, session.TotalChunks)
for i := 0; i < session.TotalChunks; i++ {
keys[i] = fmt.Sprintf("%schunk_%05d", session.MinioPrefix, i)
}
return keys, nil
}
func updateStatusTx(tx *gorm.DB, ctx context.Context, id int64, status string) error {
result := tx.WithContext(ctx).Model(&model.UploadSession{}).Where("id = ?", id).Update("status", status)
if result.Error != nil {
return result.Error
}
if result.RowsAffected == 0 {
return gorm.ErrRecordNotFound
}
return nil
}
func isSQLite(db *gorm.DB) bool {
return db.Dialector.Name() == "sqlite"
}
func objectKeys(objects []storage.ObjectInfo) []string {
keys := make([]string, len(objects))
for i, o := range objects {
keys[i] = o.Key
}
return keys
}
type errComposeFailed struct {
err error
}
func (e errComposeFailed) Error() string {
return fmt.Sprintf("compose failed: %v", e.err)
}


@@ -0,0 +1,678 @@
package service
import (
"bytes"
"context"
"crypto/sha256"
"encoding/hex"
"fmt"
"io"
"strings"
"testing"
"time"
"gcy_hpc_server/internal/config"
"gcy_hpc_server/internal/model"
"gcy_hpc_server/internal/storage"
"gcy_hpc_server/internal/store"
"go.uber.org/zap"
"gorm.io/driver/sqlite"
"gorm.io/gorm"
)
type uploadMockStorage struct {
putObjectFn func(ctx context.Context, bucket, key string, reader io.Reader, size int64, opts storage.PutObjectOptions) (storage.UploadInfo, error)
composeObjectFn func(ctx context.Context, bucket, dst string, sources []string) (storage.UploadInfo, error)
listObjectsFn func(ctx context.Context, bucket, prefix string, recursive bool) ([]storage.ObjectInfo, error)
removeObjectsFn func(ctx context.Context, bucket string, keys []string, opts storage.RemoveObjectsOptions) error
removeObjectFn func(ctx context.Context, bucket, key string, opts storage.RemoveObjectOptions) error
getObjectFn func(ctx context.Context, bucket, key string, opts storage.GetOptions) (io.ReadCloser, storage.ObjectInfo, error)
bucketExistsFn func(ctx context.Context, bucket string) (bool, error)
makeBucketFn func(ctx context.Context, bucket string, opts storage.MakeBucketOptions) error
statObjectFn func(ctx context.Context, bucket, key string, opts storage.StatObjectOptions) (storage.ObjectInfo, error)
abortMultipartFn func(ctx context.Context, bucket, object, uploadID string) error
removeIncompleteFn func(ctx context.Context, bucket, object string) error
}
func (m *uploadMockStorage) PutObject(ctx context.Context, bucket, key string, reader io.Reader, size int64, opts storage.PutObjectOptions) (storage.UploadInfo, error) {
if m.putObjectFn != nil {
return m.putObjectFn(ctx, bucket, key, reader, size, opts)
}
io.Copy(io.Discard, reader)
return storage.UploadInfo{ETag: "etag", Size: size}, nil
}
func (m *uploadMockStorage) GetObject(ctx context.Context, bucket, key string, opts storage.GetOptions) (io.ReadCloser, storage.ObjectInfo, error) {
if m.getObjectFn != nil {
return m.getObjectFn(ctx, bucket, key, opts)
}
return nil, storage.ObjectInfo{}, nil
}
func (m *uploadMockStorage) ComposeObject(ctx context.Context, bucket, dst string, sources []string) (storage.UploadInfo, error) {
if m.composeObjectFn != nil {
return m.composeObjectFn(ctx, bucket, dst, sources)
}
return storage.UploadInfo{ETag: "composed", Size: 0}, nil
}
func (m *uploadMockStorage) AbortMultipartUpload(ctx context.Context, bucket, object, uploadID string) error {
if m.abortMultipartFn != nil {
return m.abortMultipartFn(ctx, bucket, object, uploadID)
}
return nil
}
func (m *uploadMockStorage) RemoveIncompleteUpload(ctx context.Context, bucket, object string) error {
if m.removeIncompleteFn != nil {
return m.removeIncompleteFn(ctx, bucket, object)
}
return nil
}
func (m *uploadMockStorage) RemoveObject(ctx context.Context, bucket, key string, opts storage.RemoveObjectOptions) error {
if m.removeObjectFn != nil {
return m.removeObjectFn(ctx, bucket, key, opts)
}
return nil
}
func (m *uploadMockStorage) ListObjects(ctx context.Context, bucket, prefix string, recursive bool) ([]storage.ObjectInfo, error) {
if m.listObjectsFn != nil {
return m.listObjectsFn(ctx, bucket, prefix, recursive)
}
return nil, nil
}
func (m *uploadMockStorage) RemoveObjects(ctx context.Context, bucket string, keys []string, opts storage.RemoveObjectsOptions) error {
if m.removeObjectsFn != nil {
return m.removeObjectsFn(ctx, bucket, keys, opts)
}
return nil
}
func (m *uploadMockStorage) BucketExists(ctx context.Context, bucket string) (bool, error) {
if m.bucketExistsFn != nil {
return m.bucketExistsFn(ctx, bucket)
}
return true, nil
}
func (m *uploadMockStorage) MakeBucket(ctx context.Context, bucket string, opts storage.MakeBucketOptions) error {
if m.makeBucketFn != nil {
return m.makeBucketFn(ctx, bucket, opts)
}
return nil
}
func (m *uploadMockStorage) StatObject(ctx context.Context, bucket, key string, opts storage.StatObjectOptions) (storage.ObjectInfo, error) {
if m.statObjectFn != nil {
return m.statObjectFn(ctx, bucket, key, opts)
}
return storage.ObjectInfo{}, nil
}
func setupUploadTestDB(t *testing.T) *gorm.DB {
t.Helper()
db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{})
if err != nil {
t.Fatalf("open sqlite: %v", err)
}
if err := db.AutoMigrate(&model.FileBlob{}, &model.File{}, &model.UploadSession{}, &model.UploadChunk{}); err != nil {
t.Fatalf("migrate: %v", err)
}
return db
}
func uploadTestConfig() config.MinioConfig {
return config.MinioConfig{
Bucket: "test-bucket",
ChunkSize: 16 << 20,
MaxFileSize: 50 << 30,
MinChunkSize: 5 << 20,
SessionTTL: 48,
}
}
func newUploadTestService(t *testing.T, st storage.ObjectStorage, db *gorm.DB) *UploadService {
t.Helper()
return NewUploadService(
st,
store.NewBlobStore(db),
store.NewFileStore(db),
store.NewUploadStore(db),
uploadTestConfig(),
db,
zap.NewNop(),
)
}
func sha256Of(data []byte) string {
h := sha256.Sum256(data)
return hex.EncodeToString(h[:])
}
func TestInitUpload_CreatesSession(t *testing.T) {
db := setupUploadTestDB(t)
st := &uploadMockStorage{}
svc := newUploadTestService(t, st, db)
resp, err := svc.InitUpload(context.Background(), model.InitUploadRequest{
FileName: "test.txt",
FileSize: 32 << 20,
SHA256: "abc123",
})
if err != nil {
t.Fatalf("InitUpload: %v", err)
}
sessResp, ok := resp.(model.UploadSessionResponse)
if !ok {
t.Fatalf("expected UploadSessionResponse, got %T", resp)
}
if sessResp.Status != "pending" {
t.Errorf("status = %q, want pending", sessResp.Status)
}
if sessResp.TotalChunks != 2 {
t.Errorf("TotalChunks = %d, want 2", sessResp.TotalChunks)
}
if sessResp.FileName != "test.txt" {
t.Errorf("FileName = %q, want test.txt", sessResp.FileName)
}
}
func TestInitUpload_DedupBlobExists(t *testing.T) {
db := setupUploadTestDB(t)
st := &uploadMockStorage{}
svc := newUploadTestService(t, st, db)
blobSHA := sha256Of([]byte("hello"))
blob := &model.FileBlob{
SHA256: blobSHA,
MinioKey: "files/" + blobSHA,
FileSize: 5,
MimeType: "text/plain",
RefCount: 1,
}
if err := db.Create(blob).Error; err != nil {
t.Fatalf("create blob: %v", err)
}
resp, err := svc.InitUpload(context.Background(), model.InitUploadRequest{
FileName: "mydoc.txt",
FileSize: 5,
SHA256: blobSHA,
})
if err != nil {
t.Fatalf("InitUpload: %v", err)
}
fileResp, ok := resp.(model.FileResponse)
if !ok {
t.Fatalf("expected FileResponse, got %T", resp)
}
if fileResp.Name != "mydoc.txt" {
t.Errorf("Name = %q, want mydoc.txt", fileResp.Name)
}
if fileResp.SHA256 != blobSHA {
t.Errorf("SHA256 mismatch")
}
var after model.FileBlob
db.First(&after, blob.ID)
if after.RefCount != 2 {
t.Errorf("RefCount = %d, want 2", after.RefCount)
}
}
func TestInitUpload_ChunkSizeTooSmall(t *testing.T) {
db := setupUploadTestDB(t)
st := &uploadMockStorage{}
svc := newUploadTestService(t, st, db)
tiny := int64(1024)
_, err := svc.InitUpload(context.Background(), model.InitUploadRequest{
FileName: "test.txt",
FileSize: 10 << 20,
SHA256: "abc",
ChunkSize: &tiny,
})
if err == nil {
t.Fatal("expected error for chunk size too small")
}
if !strings.Contains(err.Error(), "below minimum") {
t.Errorf("error = %q, want 'below minimum'", err.Error())
}
}
func TestInitUpload_TooManyChunks(t *testing.T) {
db := setupUploadTestDB(t)
st := &uploadMockStorage{}
cfg := uploadTestConfig()
cfg.ChunkSize = 1
cfg.MinChunkSize = 1
svc2 := NewUploadService(st, store.NewBlobStore(db), store.NewFileStore(db), store.NewUploadStore(db), cfg, db, zap.NewNop())
_, err := svc2.InitUpload(context.Background(), model.InitUploadRequest{
FileName: "big.txt",
FileSize: 10002,
SHA256: "abc",
})
if err == nil {
t.Fatal("expected error for too many chunks")
}
if !strings.Contains(err.Error(), "exceeds limit") {
t.Errorf("error = %q, want 'exceeds limit'", err.Error())
}
}
func TestInitUpload_DangerousFilename(t *testing.T) {
db := setupUploadTestDB(t)
st := &uploadMockStorage{}
svc := newUploadTestService(t, st, db)
for _, name := range []string{"", "..", "foo/bar", "foo\\bar", " name"} {
_, err := svc.InitUpload(context.Background(), model.InitUploadRequest{
FileName: name,
FileSize: 100,
SHA256: "abc",
})
if err == nil {
t.Errorf("expected error for filename %q", name)
}
}
}
func TestUploadChunk_UploadsAndStoresSHA256(t *testing.T) {
db := setupUploadTestDB(t)
st := &uploadMockStorage{}
svc := newUploadTestService(t, st, db)
resp, _ := svc.InitUpload(context.Background(), model.InitUploadRequest{
FileName: "test.bin",
FileSize: 10 << 20,
SHA256: "deadbeef",
})
sessResp := resp.(model.UploadSessionResponse)
data := []byte("chunk data here")
chunkSHA := sha256Of(data)
err := svc.UploadChunk(context.Background(), sessResp.ID, 0, bytes.NewReader(data), int64(len(data)))
if err != nil {
t.Fatalf("UploadChunk: %v", err)
}
var chunk model.UploadChunk
db.Where("session_id = ? AND chunk_index = ?", sessResp.ID, 0).First(&chunk)
if chunk.SHA256 != chunkSHA {
t.Errorf("chunk SHA256 = %q, want %q", chunk.SHA256, chunkSHA)
}
if chunk.Status != "uploaded" {
t.Errorf("chunk status = %q, want uploaded", chunk.Status)
}
session, _ := svc.uploadStore.GetSession(context.Background(), sessResp.ID)
if session.Status != "uploading" {
t.Errorf("session status = %q, want uploading", session.Status)
}
}
func TestUploadChunk_RejectsCompletedSession(t *testing.T) {
db := setupUploadTestDB(t)
st := &uploadMockStorage{}
svc := newUploadTestService(t, st, db)
resp, _ := svc.InitUpload(context.Background(), model.InitUploadRequest{
FileName: "test.bin",
FileSize: 10 << 20,
SHA256: "deadbeef",
})
sessResp := resp.(model.UploadSessionResponse)
svc.uploadStore.UpdateSessionStatus(context.Background(), sessResp.ID, "completed")
err := svc.UploadChunk(context.Background(), sessResp.ID, 0, bytes.NewReader([]byte("x")), 1)
if err == nil {
t.Fatal("expected error for completed session")
}
if !strings.Contains(err.Error(), "cannot upload") {
t.Errorf("error = %q", err.Error())
}
}
func TestUploadChunk_RejectsExpiredSession(t *testing.T) {
db := setupUploadTestDB(t)
st := &uploadMockStorage{}
svc := newUploadTestService(t, st, db)
resp, _ := svc.InitUpload(context.Background(), model.InitUploadRequest{
FileName: "test.bin",
FileSize: 10 << 20,
SHA256: "deadbeef",
})
sessResp := resp.(model.UploadSessionResponse)
svc.uploadStore.UpdateSessionStatus(context.Background(), sessResp.ID, "expired")
err := svc.UploadChunk(context.Background(), sessResp.ID, 0, bytes.NewReader([]byte("x")), 1)
if err == nil {
t.Fatal("expected error for expired session")
}
}
func TestGetUploadStatus_ReturnsChunkIndices(t *testing.T) {
db := setupUploadTestDB(t)
st := &uploadMockStorage{}
svc := newUploadTestService(t, st, db)
resp, _ := svc.InitUpload(context.Background(), model.InitUploadRequest{
FileName: "test.bin",
FileSize: 32 << 20,
SHA256: "abc",
})
sessResp := resp.(model.UploadSessionResponse)
data := []byte("chunk0data")
svc.UploadChunk(context.Background(), sessResp.ID, 0, bytes.NewReader(data), int64(len(data)))
status, err := svc.GetUploadStatus(context.Background(), sessResp.ID)
if err != nil {
t.Fatalf("GetUploadStatus: %v", err)
}
if len(status.UploadedChunks) != 1 || status.UploadedChunks[0] != 0 {
t.Errorf("UploadedChunks = %v, want [0]", status.UploadedChunks)
}
if status.TotalChunks != 2 {
t.Errorf("TotalChunks = %d, want 2", status.TotalChunks)
}
}
func TestCompleteUpload_CreatesBlobAndFile(t *testing.T) {
db := setupUploadTestDB(t)
cfg := uploadTestConfig()
cfg.ChunkSize = 5 << 20
cfg.MinChunkSize = 1
st := &uploadMockStorage{}
svc := NewUploadService(st, store.NewBlobStore(db), store.NewFileStore(db), store.NewUploadStore(db), cfg, db, zap.NewNop())
fileSHA := "aaa111"
sess := &model.UploadSession{
FileName: "test.bin",
FileSize: 10 << 20,
ChunkSize: 5 << 20,
TotalChunks: 2,
SHA256: fileSHA,
Status: "uploading",
MinioPrefix: "uploads/testcomplete/",
MimeType: "application/octet-stream",
ExpiresAt: time.Now().Add(48 * time.Hour),
}
db.Create(sess)
db.Create(&model.UploadChunk{SessionID: sess.ID, ChunkIndex: 0, MinioKey: "uploads/testcomplete/chunk_00000", SHA256: "c0", Size: 5 << 20, Status: "uploaded"})
db.Create(&model.UploadChunk{SessionID: sess.ID, ChunkIndex: 1, MinioKey: "uploads/testcomplete/chunk_00001", SHA256: "c1", Size: 5 << 20, Status: "uploaded"})
fileResp, err := svc.CompleteUpload(context.Background(), sess.ID)
if err != nil {
t.Fatalf("CompleteUpload: %v", err)
}
if fileResp.Name != "test.bin" {
t.Errorf("Name = %q, want test.bin", fileResp.Name)
}
if fileResp.SHA256 != fileSHA {
t.Errorf("SHA256 = %q, want %q", fileResp.SHA256, fileSHA)
}
var blob model.FileBlob
db.Where("sha256 = ?", fileSHA).First(&blob)
if blob.RefCount != 1 {
t.Errorf("blob RefCount = %d, want 1", blob.RefCount)
}
session, _ := svc.uploadStore.GetSession(context.Background(), sess.ID)
if session == nil {
t.Fatal("session should exist")
}
if session.Status != "completed" {
t.Errorf("session status = %q, want completed", session.Status)
}
}
func TestCompleteUpload_ReusesExistingBlob(t *testing.T) {
db := setupUploadTestDB(t)
cfg := uploadTestConfig()
cfg.ChunkSize = 5 << 20
cfg.MinChunkSize = 1
st := &uploadMockStorage{}
composeCalled := false
st.composeObjectFn = func(ctx context.Context, bucket, dst string, sources []string) (storage.UploadInfo, error) {
composeCalled = true
return storage.UploadInfo{}, nil
}
svc := NewUploadService(st, store.NewBlobStore(db), store.NewFileStore(db), store.NewUploadStore(db), cfg, db, zap.NewNop())
fileSHA := "reuse123"
db.Create(&model.FileBlob{
SHA256: fileSHA,
MinioKey: "files/" + fileSHA,
FileSize: 10 << 20,
MimeType: "application/octet-stream",
RefCount: 1,
})
sess := &model.UploadSession{
FileName: "reuse.bin",
FileSize: 10 << 20,
ChunkSize: 5 << 20,
TotalChunks: 2,
SHA256: fileSHA,
Status: "uploading",
MinioPrefix: "uploads/reuse/",
MimeType: "application/octet-stream",
ExpiresAt: time.Now().Add(48 * time.Hour),
}
db.Create(sess)
db.Create(&model.UploadChunk{SessionID: sess.ID, ChunkIndex: 0, MinioKey: "uploads/reuse/chunk_00000", SHA256: "c0", Size: 5 << 20, Status: "uploaded"})
db.Create(&model.UploadChunk{SessionID: sess.ID, ChunkIndex: 1, MinioKey: "uploads/reuse/chunk_00001", SHA256: "c1", Size: 5 << 20, Status: "uploaded"})
fileResp, err := svc.CompleteUpload(context.Background(), sess.ID)
if err != nil {
t.Fatalf("CompleteUpload: %v", err)
}
if fileResp.SHA256 != fileSHA {
t.Errorf("SHA256 mismatch")
}
if composeCalled {
t.Error("ComposeObject should not be called when blob exists")
}
var blob model.FileBlob
db.Where("sha256 = ?", fileSHA).First(&blob)
if blob.RefCount != 2 {
t.Errorf("RefCount = %d, want 2", blob.RefCount)
}
}
func TestCompleteUpload_NotAllChunks(t *testing.T) {
db := setupUploadTestDB(t)
cfg := uploadTestConfig()
cfg.ChunkSize = 5 << 20
cfg.MinChunkSize = 1
st := &uploadMockStorage{}
svc := NewUploadService(st, store.NewBlobStore(db), store.NewFileStore(db), store.NewUploadStore(db), cfg, db, zap.NewNop())
sess := &model.UploadSession{
FileName: "partial.bin",
FileSize: 10 << 20,
ChunkSize: 5 << 20,
TotalChunks: 2,
SHA256: "partial123",
Status: "uploading",
MinioPrefix: "uploads/partial/",
MimeType: "application/octet-stream",
ExpiresAt: time.Now().Add(48 * time.Hour),
}
db.Create(sess)
db.Create(&model.UploadChunk{SessionID: sess.ID, ChunkIndex: 0, MinioKey: "uploads/partial/chunk_00000", SHA256: "c0", Size: 5 << 20, Status: "uploaded"})
_, err := svc.CompleteUpload(context.Background(), sess.ID)
if err == nil {
t.Fatal("expected error for incomplete chunks")
}
if !strings.Contains(err.Error(), "not all chunks uploaded") {
t.Errorf("error = %q", err.Error())
}
}
func TestCompleteUpload_ComposeObjectFails(t *testing.T) {
db := setupUploadTestDB(t)
cfg := uploadTestConfig()
cfg.ChunkSize = 5 << 20
cfg.MinChunkSize = 1
st := &uploadMockStorage{}
st.composeObjectFn = func(ctx context.Context, bucket, dst string, sources []string) (storage.UploadInfo, error) {
return storage.UploadInfo{}, fmt.Errorf("compose failed")
}
svc := NewUploadService(st, store.NewBlobStore(db), store.NewFileStore(db), store.NewUploadStore(db), cfg, db, zap.NewNop())
fileSHA := "fail123"
sess := &model.UploadSession{
FileName: "fail.bin",
FileSize: 10 << 20,
ChunkSize: 5 << 20,
TotalChunks: 2,
SHA256: fileSHA,
Status: "uploading",
MinioPrefix: "uploads/fail/",
MimeType: "application/octet-stream",
ExpiresAt: time.Now().Add(48 * time.Hour),
}
db.Create(sess)
db.Create(&model.UploadChunk{SessionID: sess.ID, ChunkIndex: 0, MinioKey: "uploads/fail/chunk_00000", SHA256: "c0", Size: 5 << 20, Status: "uploaded"})
db.Create(&model.UploadChunk{SessionID: sess.ID, ChunkIndex: 1, MinioKey: "uploads/fail/chunk_00001", SHA256: "c1", Size: 5 << 20, Status: "uploaded"})
_, err := svc.CompleteUpload(context.Background(), sess.ID)
if err == nil {
t.Fatal("expected error for compose failure")
}
session, _ := svc.uploadStore.GetSession(context.Background(), sess.ID)
if session.Status != "failed" {
t.Errorf("session status = %q, want failed", session.Status)
}
}
func TestCompleteUpload_RetriesFailedSession(t *testing.T) {
db := setupUploadTestDB(t)
cfg := uploadTestConfig()
cfg.ChunkSize = 5 << 20
cfg.MinChunkSize = 1
st := &uploadMockStorage{}
svc := NewUploadService(st, store.NewBlobStore(db), store.NewFileStore(db), store.NewUploadStore(db), cfg, db, zap.NewNop())
fileSHA := "retry123"
sess := &model.UploadSession{
FileName: "retry.bin",
FileSize: 10 << 20,
ChunkSize: 5 << 20,
TotalChunks: 2,
SHA256: fileSHA,
Status: "failed",
MinioPrefix: "uploads/retry/",
MimeType: "application/octet-stream",
ExpiresAt: time.Now().Add(48 * time.Hour),
}
db.Create(sess)
db.Create(&model.UploadChunk{SessionID: sess.ID, ChunkIndex: 0, MinioKey: "uploads/retry/chunk_00000", SHA256: "c0", Size: 5 << 20, Status: "uploaded"})
db.Create(&model.UploadChunk{SessionID: sess.ID, ChunkIndex: 1, MinioKey: "uploads/retry/chunk_00001", SHA256: "c1", Size: 5 << 20, Status: "uploaded"})
fileResp, err := svc.CompleteUpload(context.Background(), sess.ID)
if err != nil {
t.Fatalf("CompleteUpload on retry: %v", err)
}
if fileResp.SHA256 != fileSHA {
t.Errorf("SHA256 mismatch")
}
session, _ := svc.uploadStore.GetSession(context.Background(), sess.ID)
if session.Status != "completed" {
t.Errorf("session status = %q, want completed", session.Status)
}
}
func TestCancelUpload_CleansUp(t *testing.T) {
db := setupUploadTestDB(t)
st := &uploadMockStorage{}
svc := newUploadTestService(t, st, db)
resp, _ := svc.InitUpload(context.Background(), model.InitUploadRequest{
FileName: "cancel.bin",
FileSize: 10 << 20,
SHA256: "cancel123",
})
sessResp := resp.(model.UploadSessionResponse)
err := svc.CancelUpload(context.Background(), sessResp.ID)
if err != nil {
t.Fatalf("CancelUpload: %v", err)
}
session, _ := svc.uploadStore.GetSession(context.Background(), sessResp.ID)
if session != nil {
t.Error("session should be deleted")
}
}
func TestZeroByteFile_CompletesImmediately(t *testing.T) {
db := setupUploadTestDB(t)
st := &uploadMockStorage{}
svc := newUploadTestService(t, st, db)
fileSHA := "empty123"
sess := &model.UploadSession{
FileName: "empty.bin",
FileSize: 0,
ChunkSize: 16 << 20,
TotalChunks: 0,
SHA256: fileSHA,
Status: "uploading",
MinioPrefix: "uploads/empty/",
MimeType: "application/octet-stream",
ExpiresAt: time.Now().Add(48 * time.Hour),
}
db.Create(sess)
fileResp, err := svc.CompleteUpload(context.Background(), sess.ID)
if err != nil {
t.Fatalf("CompleteUpload zero byte: %v", err)
}
if fileResp.SHA256 != fileSHA {
t.Errorf("SHA256 mismatch")
}
if fileResp.Size != 0 {
t.Errorf("Size = %d, want 0", fileResp.Size)
}
var blob model.FileBlob
db.Where("sha256 = ?", fileSHA).First(&blob)
if blob.RefCount != 1 {
t.Errorf("RefCount = %d, want 1", blob.RefCount)
}
uploadStore := store.NewUploadStore(db)
session, _ := uploadStore.GetSession(context.Background(), sess.ID)
if session == nil {
t.Fatal("session should exist")
}
if session.Status != "completed" {
t.Errorf("session status = %q, want completed", session.Status)
}
}
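The `RefCount` assertions above (1 after a fresh `CompleteUpload`, 2 when a second session reuses an existing blob, and no `ComposeObject` call on reuse) describe content-addressed deduplication. A self-contained sketch of that idea — the `blobStore` type here is illustrative, not the real GORM-backed store:

```go
package main

import "fmt"

// blob mirrors the FileBlob ref-counting idea from the tests above:
// one physical object per SHA-256, shared by many logical files.
type blob struct {
	SHA256   string
	RefCount int
}

type blobStore struct{ byHash map[string]*blob }

// addRef reuses an existing blob for a known hash, or records a new one.
// It reports whether the physical object still needs to be written
// (i.e. whether a ComposeObject-style merge is required).
func (s *blobStore) addRef(sha string) (needUpload bool) {
	if b, ok := s.byHash[sha]; ok {
		b.RefCount++
		return false
	}
	s.byHash[sha] = &blob{SHA256: sha, RefCount: 1}
	return true
}

func main() {
	s := &blobStore{byHash: map[string]*blob{}}
	fmt.Println(s.addRef("reuse123")) // true: first upload must compose the object
	fmt.Println(s.addRef("reuse123")) // false: second upload only bumps the ref count
	fmt.Println(s.byHash["reuse123"].RefCount)
}
```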

View File

@@ -137,11 +137,11 @@ func TestClient_ErrorHandling(t *testing.T) {
 t.Fatal("expected error for 500 response")
 }
-errorResp, ok := err.(*ErrorResponse)
+errorResp, ok := err.(*SlurmAPIError)
 if !ok {
-t.Fatalf("expected *ErrorResponse, got %T", err)
+t.Fatalf("expected *SlurmAPIError, got %T", err)
 }
-if errorResp.Response.StatusCode != 500 {
-t.Errorf("expected status 500, got %d", errorResp.Response.StatusCode)
+if errorResp.StatusCode != 500 {
+t.Errorf("expected status 500, got %d", errorResp.StatusCode)
 }
 }

View File

@@ -1,38 +1,85 @@
 package slurm
 import (
+"encoding/json"
 "fmt"
 "io"
 "net/http"
 )
-// ErrorResponse represents an error returned by the Slurm REST API.
-type ErrorResponse struct {
-Response *http.Response
-Message string
-}
-func (r *ErrorResponse) Error() string {
-return fmt.Sprintf("%v %v: %d %s",
-r.Response.Request.Method, r.Response.Request.URL,
-r.Response.StatusCode, r.Message)
-}
+// errorResponseFields is used to parse errors/warnings from a Slurm API error body.
+type errorResponseFields struct {
+Errors OpenapiErrors `json:"errors,omitempty"`
+Warnings OpenapiWarnings `json:"warnings,omitempty"`
+}
+// SlurmAPIError represents a structured error returned by the Slurm REST API.
+// It captures both the HTTP details and the parsed Slurm error array when available.
+type SlurmAPIError struct {
+Response *http.Response
+StatusCode int
+Errors OpenapiErrors
+Warnings OpenapiWarnings
+Message string // raw body fallback when JSON parsing fails
+}
+func (e *SlurmAPIError) Error() string {
+if len(e.Errors) > 0 {
+first := e.Errors[0]
+detail := ""
+if first.Error != nil {
+detail = *first.Error
+} else if first.Description != nil {
+detail = *first.Description
+}
+if detail != "" {
+return fmt.Sprintf("%v %v: %d %s",
+e.Response.Request.Method, e.Response.Request.URL,
+e.StatusCode, detail)
+}
+}
+return fmt.Sprintf("%v %v: %d %s",
+e.Response.Request.Method, e.Response.Request.URL,
+e.StatusCode, e.Message)
+}
+// IsNotFound reports whether err is a SlurmAPIError with HTTP 404 status.
+func IsNotFound(err error) bool {
+if apiErr, ok := err.(*SlurmAPIError); ok {
+return apiErr.StatusCode == http.StatusNotFound
+}
+return false
+}
 // CheckResponse checks the API response for errors. It returns nil if the
 // response is a 2xx status code. For non-2xx codes, it reads the response
-// body and returns an ErrorResponse.
+// body, attempts to parse structured Slurm errors, and returns a SlurmAPIError.
 func CheckResponse(r *http.Response) error {
 if c := r.StatusCode; c >= 200 && c <= 299 {
 return nil
 }
-errorResponse := &ErrorResponse{Response: r}
 data, err := io.ReadAll(r.Body)
 if err != nil || len(data) == 0 {
-errorResponse.Message = r.Status
-return errorResponse
+return &SlurmAPIError{
+Response: r,
+StatusCode: r.StatusCode,
+Message: r.Status,
+}
 }
-errorResponse.Message = string(data)
-return errorResponse
+apiErr := &SlurmAPIError{
+Response: r,
+StatusCode: r.StatusCode,
+Message: string(data),
+}
+// Try to extract structured errors/warnings from JSON body.
+var fields errorResponseFields
+if json.Unmarshal(data, &fields) == nil {
+apiErr.Errors = fields.Errors
+apiErr.Warnings = fields.Warnings
+}
+return apiErr
 }

View File

@@ -0,0 +1,220 @@
package slurm
import (
"fmt"
"net/http"
"net/http/httptest"
"strings"
"testing"
)
func TestCheckResponse(t *testing.T) {
tests := []struct {
name string
statusCode int
body string
wantErr bool
wantStatusCode int
wantErrors int
wantWarnings int
wantMessageContains string
}{
{
name: "2xx response returns nil",
statusCode: http.StatusOK,
body: `{"meta":{}}`,
wantErr: false,
},
{
name: "404 with valid JSON error body",
statusCode: http.StatusNotFound,
body: `{"errors":[{"description":"Unable to query JobId=198","error_number":2017,"error":"Invalid job id specified","source":"_handle_job_get"}],"warnings":[]}`,
wantErr: true,
wantStatusCode: 404,
wantErrors: 1,
wantWarnings: 0,
wantMessageContains: "Invalid job id specified",
},
{
name: "500 with non-JSON body",
statusCode: http.StatusInternalServerError,
body: "internal server error",
wantErr: true,
wantStatusCode: 500,
wantErrors: 0,
wantWarnings: 0,
wantMessageContains: "internal server error",
},
{
name: "503 with empty body returns http.Status text",
statusCode: http.StatusServiceUnavailable,
body: "",
wantErr: true,
wantStatusCode: 503,
wantErrors: 0,
wantWarnings: 0,
wantMessageContains: "503 Service Unavailable",
},
{
name: "400 with valid JSON but empty errors array",
statusCode: http.StatusBadRequest,
body: `{"errors":[],"warnings":[]}`,
wantErr: true,
wantStatusCode: 400,
wantErrors: 0,
wantWarnings: 0,
wantMessageContains: `{"errors":[],"warnings":[]}`,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
rec := httptest.NewRecorder()
rec.WriteHeader(tt.statusCode)
if tt.body != "" {
rec.Body.WriteString(tt.body)
}
err := CheckResponse(rec.Result())
if (err != nil) != tt.wantErr {
t.Fatalf("CheckResponse() error = %v, wantErr %v", err, tt.wantErr)
}
if !tt.wantErr {
return
}
apiErr, ok := err.(*SlurmAPIError)
if !ok {
t.Fatalf("expected *SlurmAPIError, got %T", err)
}
if apiErr.StatusCode != tt.wantStatusCode {
t.Errorf("StatusCode = %d, want %d", apiErr.StatusCode, tt.wantStatusCode)
}
if len(apiErr.Errors) != tt.wantErrors {
t.Errorf("len(Errors) = %d, want %d", len(apiErr.Errors), tt.wantErrors)
}
if len(apiErr.Warnings) != tt.wantWarnings {
t.Errorf("len(Warnings) = %d, want %d", len(apiErr.Warnings), tt.wantWarnings)
}
if !strings.Contains(apiErr.Message, tt.wantMessageContains) {
t.Errorf("Message = %q, want to contain %q", apiErr.Message, tt.wantMessageContains)
}
})
}
}
func TestIsNotFound(t *testing.T) {
tests := []struct {
name string
err error
want bool
}{
{
name: "404 SlurmAPIError returns true",
err: &SlurmAPIError{StatusCode: http.StatusNotFound},
want: true,
},
{
name: "500 SlurmAPIError returns false",
err: &SlurmAPIError{StatusCode: http.StatusInternalServerError},
want: false,
},
{
name: "200 SlurmAPIError returns false",
err: &SlurmAPIError{StatusCode: http.StatusOK},
want: false,
},
{
name: "plain error returns false",
err: fmt.Errorf("some error"),
want: false,
},
{
name: "nil returns false",
err: nil,
want: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := IsNotFound(tt.err); got != tt.want {
t.Errorf("IsNotFound() = %v, want %v", got, tt.want)
}
})
}
}
func TestSlurmAPIError_Error(t *testing.T) {
fakeReq := httptest.NewRequest("GET", "http://localhost/slurm/v0.0.40/job/123", nil)
tests := []struct {
name string
err *SlurmAPIError
wantContains []string
}{
{
name: "with Error field set",
err: &SlurmAPIError{
Response: &http.Response{Request: fakeReq},
StatusCode: http.StatusNotFound,
Errors: OpenapiErrors{{Error: Ptr("Job not found")}},
Message: "raw body",
},
wantContains: []string{"404", "Job not found"},
},
{
name: "with Description field set when Error is nil",
err: &SlurmAPIError{
Response: &http.Response{Request: fakeReq},
StatusCode: http.StatusBadRequest,
Errors: OpenapiErrors{{Description: Ptr("Unable to query")}},
Message: "raw body",
},
wantContains: []string{"400", "Unable to query"},
},
{
name: "with both Error and Description nil falls through to Message",
err: &SlurmAPIError{
Response: &http.Response{Request: fakeReq},
StatusCode: http.StatusInternalServerError,
Errors: OpenapiErrors{{}},
Message: "something went wrong",
},
wantContains: []string{"500", "something went wrong"},
},
{
name: "with empty Errors slice falls through to Message",
err: &SlurmAPIError{
Response: &http.Response{Request: fakeReq},
StatusCode: http.StatusServiceUnavailable,
Errors: OpenapiErrors{},
Message: "service unavailable fallback",
},
wantContains: []string{"503", "service unavailable fallback"},
},
{
name: "with non-empty Errors but empty detail string falls through to Message",
err: &SlurmAPIError{
Response: &http.Response{Request: fakeReq},
StatusCode: http.StatusBadGateway,
Errors: OpenapiErrors{{ErrorNumber: Ptr(int32(42))}},
Message: "gateway error detail",
},
wantContains: []string{"502", "gateway error detail"},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := tt.err.Error()
for _, substr := range tt.wantContains {
if !strings.Contains(got, substr) {
t.Errorf("Error() = %q, want to contain %q", got, substr)
}
}
})
}
}

View File

@@ -54,8 +54,8 @@ type PartitionInfoMaximumsOversubscribe struct {
 // PartitionInfoMaximums represents maximum resource limits for a partition (v0.0.40_partition_info.maximums).
 type PartitionInfoMaximums struct {
-CpusPerNode *int32 `json:"cpus_per_node,omitempty"`
-CpusPerSocket *int32 `json:"cpus_per_socket,omitempty"`
+CpusPerNode *Uint32NoVal `json:"cpus_per_node,omitempty"`
+CpusPerSocket *Uint32NoVal `json:"cpus_per_socket,omitempty"`
 MemoryPerCPU *int64 `json:"memory_per_cpu,omitempty"`
 PartitionMemoryPerCPU *Uint64NoVal `json:"partition_memory_per_cpu,omitempty"`
 PartitionMemoryPerNode *Uint64NoVal `json:"partition_memory_per_node,omitempty"`

View File

@@ -43,8 +43,8 @@ func TestPartitionInfoRoundTrip(t *testing.T) {
 },
 GraceTime: Ptr(int32(300)),
 Maximums: &PartitionInfoMaximums{
-CpusPerNode: Ptr(int32(128)),
-CpusPerSocket: Ptr(int32(64)),
+CpusPerNode: &Uint32NoVal{Set: Ptr(true), Number: Ptr(int64(128))},
+CpusPerSocket: &Uint32NoVal{Set: Ptr(true), Number: Ptr(int64(64))},
 MemoryPerCPU: Ptr(int64(8192)),
 PartitionMemoryPerCPU: &Uint64NoVal{Set: Ptr(true), Number: Ptr(int64(8192))},
 PartitionMemoryPerNode: &Uint64NoVal{Set: Ptr(true), Number: Ptr(int64(262144))},

internal/storage/minio.go (new file, 286 lines)
View File

@@ -0,0 +1,286 @@
package storage
import (
"context"
"fmt"
"io"
"time"
"gcy_hpc_server/internal/config"
"github.com/minio/minio-go/v7"
"github.com/minio/minio-go/v7/pkg/credentials"
)
// ObjectInfo contains metadata about a stored object.
type ObjectInfo struct {
Key string
Size int64
LastModified time.Time
ETag string
ContentType string
}
// UploadInfo contains metadata about an uploaded object.
type UploadInfo struct {
ETag string
Size int64
}
// GetOptions specifies parameters for GetObject, including optional Range.
type GetOptions struct {
Start *int64 // Range start byte offset (nil = no range)
End *int64 // Range end byte offset (nil = no range)
}
// MultipartUpload represents an incomplete multipart upload.
type MultipartUpload struct {
ObjectName string
UploadID string
Initiated time.Time
}
// RemoveObjectsOptions specifies options for removing multiple objects.
type RemoveObjectsOptions struct {
ForceDelete bool
}
// PutObjectOptions for PutObject.
type PutObjectOptions struct {
ContentType string
DisableMultipart bool // true for small chunks (already pre-split)
}
// RemoveObjectOptions for RemoveObject.
type RemoveObjectOptions struct {
ForceDelete bool
}
// MakeBucketOptions for MakeBucket.
type MakeBucketOptions struct {
Region string
}
// StatObjectOptions for StatObject.
type StatObjectOptions struct{}
// ObjectStorage defines the interface for object storage operations.
// Implementations should wrap MinIO SDK calls with custom transfer types.
type ObjectStorage interface {
// PutObject uploads an object from a reader.
PutObject(ctx context.Context, bucket, key string, reader io.Reader, size int64, opts PutObjectOptions) (UploadInfo, error)
// GetObject retrieves an object. opts may contain Range parameters.
GetObject(ctx context.Context, bucket, key string, opts GetOptions) (io.ReadCloser, ObjectInfo, error)
// ComposeObject merges multiple source objects into a single destination.
ComposeObject(ctx context.Context, bucket, dst string, sources []string) (UploadInfo, error)
// AbortMultipartUpload aborts an incomplete multipart upload by upload ID.
AbortMultipartUpload(ctx context.Context, bucket, object, uploadID string) error
// RemoveIncompleteUpload removes all incomplete uploads for an object.
// This is the preferred cleanup method — it encapsulates list + abort logic.
RemoveIncompleteUpload(ctx context.Context, bucket, object string) error
// RemoveObject deletes a single object.
RemoveObject(ctx context.Context, bucket, key string, opts RemoveObjectOptions) error
// ListObjects lists objects with a given prefix.
ListObjects(ctx context.Context, bucket, prefix string, recursive bool) ([]ObjectInfo, error)
// RemoveObjects deletes multiple objects.
RemoveObjects(ctx context.Context, bucket string, keys []string, opts RemoveObjectsOptions) error
// BucketExists checks if a bucket exists.
BucketExists(ctx context.Context, bucket string) (bool, error)
// MakeBucket creates a new bucket.
MakeBucket(ctx context.Context, bucket string, opts MakeBucketOptions) error
// StatObject gets metadata about an object without downloading it.
StatObject(ctx context.Context, bucket, key string, opts StatObjectOptions) (ObjectInfo, error)
}
var _ ObjectStorage = (*MinioClient)(nil)
type MinioClient struct {
core *minio.Core
bucket string
}
func NewMinioClient(cfg config.MinioConfig) (*MinioClient, error) {
transport, err := minio.DefaultTransport(cfg.UseSSL)
if err != nil {
return nil, fmt.Errorf("create default transport: %w", err)
}
transport.MaxIdleConnsPerHost = 100
transport.IdleConnTimeout = 90 * time.Second
core, err := minio.NewCore(cfg.Endpoint, &minio.Options{
Creds: credentials.NewStaticV4(cfg.AccessKey, cfg.SecretKey, ""),
Secure: cfg.UseSSL,
Transport: transport,
})
if err != nil {
return nil, fmt.Errorf("create minio client: %w", err)
}
mc := &MinioClient{core: core, bucket: cfg.Bucket}
ctx := context.Background()
exists, err := core.BucketExists(ctx, cfg.Bucket)
if err != nil {
return nil, fmt.Errorf("check bucket: %w", err)
}
if !exists {
if err := core.MakeBucket(ctx, cfg.Bucket, minio.MakeBucketOptions{}); err != nil {
return nil, fmt.Errorf("create bucket: %w", err)
}
}
return mc, nil
}
func (m *MinioClient) PutObject(ctx context.Context, bucket, key string, reader io.Reader, size int64, opts PutObjectOptions) (UploadInfo, error) {
info, err := m.core.PutObject(ctx, bucket, key, reader, size, "", "", minio.PutObjectOptions{
ContentType: opts.ContentType,
DisableMultipart: opts.DisableMultipart,
})
if err != nil {
return UploadInfo{}, fmt.Errorf("put object %s/%s: %w", bucket, key, err)
}
return UploadInfo{ETag: info.ETag, Size: info.Size}, nil
}
func (m *MinioClient) GetObject(ctx context.Context, bucket, key string, opts GetOptions) (io.ReadCloser, ObjectInfo, error) {
var gopts minio.GetObjectOptions
if opts.Start != nil || opts.End != nil {
start := int64(0)
end := int64(0)
if opts.Start != nil {
start = *opts.Start
}
if opts.End != nil {
end = *opts.End
}
if err := gopts.SetRange(start, end); err != nil {
return nil, ObjectInfo{}, fmt.Errorf("set range: %w", err)
}
}
body, info, _, err := m.core.GetObject(ctx, bucket, key, gopts)
if err != nil {
return nil, ObjectInfo{}, fmt.Errorf("get object %s/%s: %w", bucket, key, err)
}
return body, toObjectInfo(info), nil
}
func (m *MinioClient) ComposeObject(ctx context.Context, bucket, dst string, sources []string) (UploadInfo, error) {
srcs := make([]minio.CopySrcOptions, len(sources))
for i, src := range sources {
srcs[i] = minio.CopySrcOptions{Bucket: bucket, Object: src}
}
do := minio.CopyDestOptions{
Bucket: bucket,
Object: dst,
ReplaceMetadata: true,
}
info, err := m.core.ComposeObject(ctx, do, srcs...)
if err != nil {
return UploadInfo{}, fmt.Errorf("compose object %s/%s: %w", bucket, dst, err)
}
return UploadInfo{ETag: info.ETag, Size: info.Size}, nil
}
func (m *MinioClient) AbortMultipartUpload(ctx context.Context, bucket, object, uploadID string) error {
if err := m.core.AbortMultipartUpload(ctx, bucket, object, uploadID); err != nil {
return fmt.Errorf("abort multipart upload %s/%s %s: %w", bucket, object, uploadID, err)
}
return nil
}
func (m *MinioClient) RemoveIncompleteUpload(ctx context.Context, bucket, object string) error {
if err := m.core.RemoveIncompleteUpload(ctx, bucket, object); err != nil {
return fmt.Errorf("remove incomplete upload %s/%s: %w", bucket, object, err)
}
return nil
}
func (m *MinioClient) RemoveObject(ctx context.Context, bucket, key string, opts RemoveObjectOptions) error {
if err := m.core.RemoveObject(ctx, bucket, key, minio.RemoveObjectOptions{
ForceDelete: opts.ForceDelete,
}); err != nil {
return fmt.Errorf("remove object %s/%s: %w", bucket, key, err)
}
return nil
}
func (m *MinioClient) ListObjects(ctx context.Context, bucket, prefix string, recursive bool) ([]ObjectInfo, error) {
ch := m.core.Client.ListObjects(ctx, bucket, minio.ListObjectsOptions{
Prefix: prefix,
Recursive: recursive,
})
var result []ObjectInfo
for obj := range ch {
if obj.Err != nil {
return result, fmt.Errorf("list objects %s/%s: %w", bucket, prefix, obj.Err)
}
result = append(result, toObjectInfo(obj))
}
return result, nil
}
func (m *MinioClient) RemoveObjects(ctx context.Context, bucket string, keys []string, opts RemoveObjectsOptions) error {
objectsCh := make(chan minio.ObjectInfo, len(keys))
for _, key := range keys {
objectsCh <- minio.ObjectInfo{Key: key}
}
close(objectsCh)
errCh := m.core.RemoveObjects(ctx, bucket, objectsCh, minio.RemoveObjectsOptions{})
for err := range errCh {
if err.Err != nil {
return fmt.Errorf("remove object %s: %w", err.ObjectName, err.Err)
}
}
return nil
}
func (m *MinioClient) BucketExists(ctx context.Context, bucket string) (bool, error) {
ok, err := m.core.BucketExists(ctx, bucket)
if err != nil {
return false, fmt.Errorf("bucket exists %s: %w", bucket, err)
}
return ok, nil
}
func (m *MinioClient) MakeBucket(ctx context.Context, bucket string, opts MakeBucketOptions) error {
if err := m.core.MakeBucket(ctx, bucket, minio.MakeBucketOptions{
Region: opts.Region,
}); err != nil {
return fmt.Errorf("make bucket %s: %w", bucket, err)
}
return nil
}
func (m *MinioClient) StatObject(ctx context.Context, bucket, key string, _ StatObjectOptions) (ObjectInfo, error) {
info, err := m.core.StatObject(ctx, bucket, key, minio.StatObjectOptions{})
if err != nil {
return ObjectInfo{}, fmt.Errorf("stat object %s/%s: %w", bucket, key, err)
}
return toObjectInfo(info), nil
}
func toObjectInfo(info minio.ObjectInfo) ObjectInfo {
return ObjectInfo{
Key: info.Key,
Size: info.Size,
LastModified: info.LastModified,
ETag: info.ETag,
ContentType: info.ContentType,
}
}
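The `Start`/`End` pointers in `GetOptions` above translate to an HTTP byte-range request (via the SDK's `SetRange`, which treats a nil start as 0). A self-contained sketch of that mapping — `GetOptions` is reproduced locally and `rangeHeader` is an illustrative helper, not part of the storage package:

```go
package main

import "fmt"

// GetOptions mirrors the storage.GetOptions shape above.
type GetOptions struct {
	Start *int64 // Range start byte offset (nil = 0)
	End   *int64 // Range end byte offset, inclusive (nil = 0)
}

// rangeHeader shows how Start/End would render as an HTTP Range header
// value; no range options at all means no header.
func rangeHeader(opts GetOptions) string {
	if opts.Start == nil && opts.End == nil {
		return ""
	}
	start, end := int64(0), int64(0)
	if opts.Start != nil {
		start = *opts.Start
	}
	if opts.End != nil {
		end = *opts.End
	}
	return fmt.Sprintf("bytes=%d-%d", start, end)
}

func main() {
	// Request the first 4 MiB of an object (end offset is inclusive).
	s, e := int64(0), int64(4<<20-1)
	fmt.Println(rangeHeader(GetOptions{Start: &s, End: &e})) // bytes=0-4194303
}
```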

View File

@@ -0,0 +1,7 @@
package storage
import "testing"
func TestMinioClientImplementsObjectStorage(t *testing.T) {
var _ ObjectStorage = (*MinioClient)(nil)
}

View File

@@ -0,0 +1,114 @@
package store
import (
"context"
"encoding/json"
"errors"
"gcy_hpc_server/internal/model"
"gorm.io/gorm"
)
// ApplicationStore provides CRUD persistence for Application records.
type ApplicationStore struct {
db *gorm.DB
}
// NewApplicationStore creates a new ApplicationStore.
func NewApplicationStore(db *gorm.DB) *ApplicationStore {
return &ApplicationStore{db: db}
}
func (s *ApplicationStore) List(ctx context.Context, page, pageSize int) ([]model.Application, int, error) {
var apps []model.Application
var total int64
if err := s.db.WithContext(ctx).Model(&model.Application{}).Count(&total).Error; err != nil {
return nil, 0, err
}
offset := (page - 1) * pageSize
if err := s.db.WithContext(ctx).Order("id DESC").Limit(pageSize).Offset(offset).Find(&apps).Error; err != nil {
return nil, 0, err
}
return apps, int(total), nil
}
func (s *ApplicationStore) GetByID(ctx context.Context, id int64) (*model.Application, error) {
var app model.Application
err := s.db.WithContext(ctx).First(&app, id).Error
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, nil
}
if err != nil {
return nil, err
}
return &app, nil
}
func (s *ApplicationStore) Create(ctx context.Context, req *model.CreateApplicationRequest) (int64, error) {
params := req.Parameters
if len(params) == 0 {
params = json.RawMessage(`[]`)
}
app := &model.Application{
Name: req.Name,
Description: req.Description,
Icon: req.Icon,
Category: req.Category,
ScriptTemplate: req.ScriptTemplate,
Parameters: params,
Scope: req.Scope,
}
if err := s.db.WithContext(ctx).Create(app).Error; err != nil {
return 0, err
}
return app.ID, nil
}
func (s *ApplicationStore) Update(ctx context.Context, id int64, req *model.UpdateApplicationRequest) error {
updates := map[string]interface{}{}
if req.Name != nil {
updates["name"] = *req.Name
}
if req.Description != nil {
updates["description"] = *req.Description
}
if req.Icon != nil {
updates["icon"] = *req.Icon
}
if req.Category != nil {
updates["category"] = *req.Category
}
if req.ScriptTemplate != nil {
updates["script_template"] = *req.ScriptTemplate
}
if req.Parameters != nil {
updates["parameters"] = *req.Parameters
}
if req.Scope != nil {
updates["scope"] = *req.Scope
}
if len(updates) == 0 {
return nil
}
result := s.db.WithContext(ctx).Model(&model.Application{}).Where("id = ?", id).Updates(updates)
if result.Error != nil {
return result.Error
}
if result.RowsAffected == 0 {
return gorm.ErrRecordNotFound
}
return nil
}
func (s *ApplicationStore) Delete(ctx context.Context, id int64) error {
result := s.db.WithContext(ctx).Delete(&model.Application{}, id)
if result.Error != nil {
return result.Error
}
return nil
}
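The `Update` method above uses the pointer-field pattern: only fields the caller actually set become part of the `Updates` map, so unset fields are left untouched in the database. A self-contained sketch of that pattern — the type and helper names here are illustrative, not the real model:

```go
package main

import "fmt"

// UpdateRequest mirrors the pointer-field style of UpdateApplicationRequest:
// a nil pointer means "leave this column alone".
type UpdateRequest struct {
	Name        *string
	Description *string
}

// buildUpdates collects only the fields the caller set, producing the
// map that would be handed to GORM's Updates().
func buildUpdates(req UpdateRequest) map[string]interface{} {
	updates := map[string]interface{}{}
	if req.Name != nil {
		updates["name"] = *req.Name
	}
	if req.Description != nil {
		updates["description"] = *req.Description
	}
	return updates
}

func main() {
	name := "updated"
	// Only "name" appears; Description stays out of the update entirely.
	fmt.Println(buildUpdates(UpdateRequest{Name: &name}))
}
```

Passing the map (rather than a struct) to GORM's `Updates` also avoids GORM's zero-value skipping, which is why an empty string can still be written when the caller explicitly sets it.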

View File

@@ -0,0 +1,227 @@
package store
import (
"context"
"testing"
"gcy_hpc_server/internal/model"
"gorm.io/driver/sqlite"
"gorm.io/gorm"
"gorm.io/gorm/logger"
)
func newAppTestDB(t *testing.T) *gorm.DB {
t.Helper()
db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{
Logger: logger.Default.LogMode(logger.Silent),
})
if err != nil {
t.Fatalf("open sqlite: %v", err)
}
if err := db.AutoMigrate(&model.Application{}); err != nil {
t.Fatalf("auto migrate: %v", err)
}
return db
}
func TestApplicationStore_Create_Success(t *testing.T) {
db := newAppTestDB(t)
s := NewApplicationStore(db)
id, err := s.Create(context.Background(), &model.CreateApplicationRequest{
Name: "gromacs",
Description: "Molecular dynamics simulator",
Category: "simulation",
ScriptTemplate: "#!/bin/bash\nmodule load gromacs",
Parameters: []byte(`[{"name":"ntasks","type":"number","required":true}]`),
Scope: "system",
})
if err != nil {
t.Fatalf("Create() error = %v", err)
}
if id <= 0 {
t.Errorf("Create() id = %d, want positive", id)
}
}
func TestApplicationStore_GetByID_Success(t *testing.T) {
db := newAppTestDB(t)
s := NewApplicationStore(db)
id, _ := s.Create(context.Background(), &model.CreateApplicationRequest{
Name: "lammps",
ScriptTemplate: "#!/bin/bash\nmodule load lammps",
})
app, err := s.GetByID(context.Background(), id)
if err != nil {
t.Fatalf("GetByID() error = %v", err)
}
if app == nil {
t.Fatal("GetByID() returned nil, expected application")
}
if app.Name != "lammps" {
t.Errorf("Name = %q, want %q", app.Name, "lammps")
}
if app.ID != id {
t.Errorf("ID = %d, want %d", app.ID, id)
}
}
func TestApplicationStore_GetByID_NotFound(t *testing.T) {
db := newAppTestDB(t)
s := NewApplicationStore(db)
app, err := s.GetByID(context.Background(), 99999)
if err != nil {
t.Fatalf("GetByID() error = %v", err)
}
if app != nil {
t.Error("GetByID() expected nil for not-found, got non-nil")
}
}
func TestApplicationStore_List_Pagination(t *testing.T) {
db := newAppTestDB(t)
s := NewApplicationStore(db)
for i := 0; i < 5; i++ {
s.Create(context.Background(), &model.CreateApplicationRequest{
Name: "app-" + string(rune('A'+i)),
ScriptTemplate: "#!/bin/bash\necho " + string(rune('A'+i)),
})
}
apps, total, err := s.List(context.Background(), 1, 3)
if err != nil {
t.Fatalf("List() error = %v", err)
}
if total != 5 {
t.Errorf("total = %d, want 5", total)
}
if len(apps) != 3 {
t.Errorf("len(apps) = %d, want 3", len(apps))
}
apps2, total2, err := s.List(context.Background(), 2, 3)
if err != nil {
t.Fatalf("List() page 2 error = %v", err)
}
if total2 != 5 {
t.Errorf("total2 = %d, want 5", total2)
}
if len(apps2) != 2 {
t.Errorf("len(apps2) = %d, want 2", len(apps2))
}
}
func TestApplicationStore_Update_Success(t *testing.T) {
db := newAppTestDB(t)
s := NewApplicationStore(db)
id, _ := s.Create(context.Background(), &model.CreateApplicationRequest{
Name: "orig",
ScriptTemplate: "#!/bin/bash\necho original",
})
newName := "updated"
newDesc := "updated description"
err := s.Update(context.Background(), id, &model.UpdateApplicationRequest{
Name: &newName,
Description: &newDesc,
})
if err != nil {
t.Fatalf("Update() error = %v", err)
}
app, _ := s.GetByID(context.Background(), id)
if app.Name != "updated" {
t.Errorf("Name = %q, want %q", app.Name, "updated")
}
if app.Description != "updated description" {
t.Errorf("Description = %q, want %q", app.Description, "updated description")
}
}
func TestApplicationStore_Update_NotFound(t *testing.T) {
db := newAppTestDB(t)
s := NewApplicationStore(db)
name := "nope"
err := s.Update(context.Background(), 99999, &model.UpdateApplicationRequest{
Name: &name,
})
if err == nil {
t.Fatal("Update() expected error for not-found, got nil")
}
}
func TestApplicationStore_Delete_Success(t *testing.T) {
db := newAppTestDB(t)
s := NewApplicationStore(db)
id, _ := s.Create(context.Background(), &model.CreateApplicationRequest{
Name: "to-delete",
ScriptTemplate: "#!/bin/bash\necho bye",
})
err := s.Delete(context.Background(), id)
if err != nil {
t.Fatalf("Delete() error = %v", err)
}
app, _ := s.GetByID(context.Background(), id)
if app != nil {
t.Error("GetByID() after delete returned non-nil")
}
}
func TestApplicationStore_Delete_Idempotent(t *testing.T) {
db := newAppTestDB(t)
s := NewApplicationStore(db)
err := s.Delete(context.Background(), 99999)
if err != nil {
t.Fatalf("Delete() non-existent error = %v, want nil (idempotent)", err)
}
}
func TestApplicationStore_Create_DuplicateName(t *testing.T) {
db := newAppTestDB(t)
s := NewApplicationStore(db)
_, err := s.Create(context.Background(), &model.CreateApplicationRequest{
Name: "dup-app",
ScriptTemplate: "#!/bin/bash\necho 1",
})
if err != nil {
t.Fatalf("first Create() error = %v", err)
}
_, err = s.Create(context.Background(), &model.CreateApplicationRequest{
Name: "dup-app",
ScriptTemplate: "#!/bin/bash\necho 2",
})
if err == nil {
t.Fatal("expected error for duplicate name, got nil")
}
}
func TestApplicationStore_Create_EmptyParameters(t *testing.T) {
db := newAppTestDB(t)
s := NewApplicationStore(db)
id, err := s.Create(context.Background(), &model.CreateApplicationRequest{
Name: "no-params",
ScriptTemplate: "#!/bin/bash\necho hello",
})
if err != nil {
t.Fatalf("Create() error = %v", err)
}
app, _ := s.GetByID(context.Background(), id)
if string(app.Parameters) != "[]" {
t.Errorf("Parameters = %q, want []", string(app.Parameters))
}
}

View File

@@ -0,0 +1,108 @@
package store
import (
"context"
"errors"
"gcy_hpc_server/internal/model"
"gorm.io/gorm"
"gorm.io/gorm/clause"
)
// BlobStore manages physical file blobs with reference counting.
type BlobStore struct {
db *gorm.DB
}
// NewBlobStore creates a new BlobStore.
func NewBlobStore(db *gorm.DB) *BlobStore {
return &BlobStore{db: db}
}
// Create inserts a new FileBlob record.
func (s *BlobStore) Create(ctx context.Context, blob *model.FileBlob) error {
return s.db.WithContext(ctx).Create(blob).Error
}
// GetBySHA256 returns the FileBlob with the given SHA256 hash.
// Returns (nil, nil) if not found.
func (s *BlobStore) GetBySHA256(ctx context.Context, sha256 string) (*model.FileBlob, error) {
var blob model.FileBlob
err := s.db.WithContext(ctx).Where("sha256 = ?", sha256).First(&blob).Error
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, nil
}
if err != nil {
return nil, err
}
return &blob, nil
}
// IncrementRef atomically increments the ref_count for the blob with the given SHA256.
func (s *BlobStore) IncrementRef(ctx context.Context, sha256 string) error {
result := s.db.WithContext(ctx).Model(&model.FileBlob{}).
Where("sha256 = ?", sha256).
UpdateColumn("ref_count", gorm.Expr("ref_count + 1"))
if result.Error != nil {
return result.Error
}
if result.RowsAffected == 0 {
return gorm.ErrRecordNotFound
}
return nil
}
// DecrementRef atomically decrements the ref_count for the blob with the given SHA256.
// Returns the new ref_count after decrementing.
func (s *BlobStore) DecrementRef(ctx context.Context, sha256 string) (int64, error) {
result := s.db.WithContext(ctx).Model(&model.FileBlob{}).
Where("sha256 = ? AND ref_count > 0", sha256).
UpdateColumn("ref_count", gorm.Expr("ref_count - 1"))
if result.Error != nil {
return 0, result.Error
}
if result.RowsAffected == 0 {
return 0, gorm.ErrRecordNotFound
}
var blob model.FileBlob
if err := s.db.WithContext(ctx).Where("sha256 = ?", sha256).First(&blob).Error; err != nil {
return 0, err
}
return int64(blob.RefCount), nil
}
// Delete removes a FileBlob record by SHA256 (hard delete).
func (s *BlobStore) Delete(ctx context.Context, sha256 string) error {
result := s.db.WithContext(ctx).Where("sha256 = ?", sha256).Delete(&model.FileBlob{})
if result.Error != nil {
return result.Error
}
return nil
}
// GetBySHA256s returns the FileBlobs matching the given SHA256 hashes.
// Hashes with no matching blob are silently omitted from the result.
func (s *BlobStore) GetBySHA256s(ctx context.Context, sha256s []string) ([]model.FileBlob, error) {
var blobs []model.FileBlob
if len(sha256s) == 0 {
return blobs, nil
}
err := s.db.WithContext(ctx).Where("sha256 IN ?", sha256s).Find(&blobs).Error
return blobs, err
}
// GetBySHA256ForUpdate returns the FileBlob with a SELECT ... FOR UPDATE lock.
// Returns (nil, nil) if not found.
func (s *BlobStore) GetBySHA256ForUpdate(ctx context.Context, tx *gorm.DB, sha256 string) (*model.FileBlob, error) {
var blob model.FileBlob
err := tx.WithContext(ctx).
Clauses(clause.Locking{Strength: "UPDATE"}).
Where("sha256 = ?", sha256).First(&blob).Error
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, nil
}
if err != nil {
return nil, err
}
return &blob, nil
}

View File

@@ -0,0 +1,199 @@
package store
import (
"context"
"testing"
"gcy_hpc_server/internal/model"
"gorm.io/driver/sqlite"
"gorm.io/gorm"
"gorm.io/gorm/logger"
)
func setupBlobTestDB(t *testing.T) *gorm.DB {
t.Helper()
db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{
Logger: logger.Default.LogMode(logger.Silent),
})
if err != nil {
t.Fatalf("open sqlite: %v", err)
}
if err := db.AutoMigrate(&model.FileBlob{}); err != nil {
t.Fatalf("migrate: %v", err)
}
return db
}
func TestBlobStore_Create(t *testing.T) {
db := setupBlobTestDB(t)
store := NewBlobStore(db)
ctx := context.Background()
blob := &model.FileBlob{
SHA256: "abc123",
MinioKey: "files/abc123",
FileSize: 1024,
MimeType: "application/octet-stream",
RefCount: 0,
}
if err := store.Create(ctx, blob); err != nil {
t.Fatalf("Create() error = %v", err)
}
if blob.ID == 0 {
t.Error("Create() did not set ID")
}
}
func TestBlobStore_GetBySHA256(t *testing.T) {
db := setupBlobTestDB(t)
store := NewBlobStore(db)
ctx := context.Background()
store.Create(ctx, &model.FileBlob{SHA256: "abc", MinioKey: "files/abc", FileSize: 100, RefCount: 0})
blob, err := store.GetBySHA256(ctx, "abc")
if err != nil {
t.Fatalf("GetBySHA256() error = %v", err)
}
if blob == nil {
t.Fatal("GetBySHA256() returned nil")
}
if blob.SHA256 != "abc" {
t.Errorf("SHA256 = %q, want %q", blob.SHA256, "abc")
}
if blob.RefCount != 0 {
t.Errorf("RefCount = %d, want 0", blob.RefCount)
}
}
func TestBlobStore_GetBySHA256_NotFound(t *testing.T) {
db := setupBlobTestDB(t)
store := NewBlobStore(db)
ctx := context.Background()
blob, err := store.GetBySHA256(ctx, "nonexistent")
if err != nil {
t.Fatalf("GetBySHA256() error = %v", err)
}
if blob != nil {
t.Error("GetBySHA256() should return nil for not found")
}
}
func TestBlobStore_IncrementDecrementRef(t *testing.T) {
db := setupBlobTestDB(t)
store := NewBlobStore(db)
ctx := context.Background()
store.Create(ctx, &model.FileBlob{SHA256: "abc", MinioKey: "files/abc", FileSize: 100, RefCount: 0})
if err := store.IncrementRef(ctx, "abc"); err != nil {
t.Fatalf("IncrementRef() error = %v", err)
}
blob, _ := store.GetBySHA256(ctx, "abc")
if blob.RefCount != 1 {
t.Errorf("RefCount after 1st increment = %d, want 1", blob.RefCount)
}
store.IncrementRef(ctx, "abc")
blob, _ = store.GetBySHA256(ctx, "abc")
if blob.RefCount != 2 {
t.Errorf("RefCount after 2nd increment = %d, want 2", blob.RefCount)
}
refCount, err := store.DecrementRef(ctx, "abc")
if err != nil {
t.Fatalf("DecrementRef() error = %v", err)
}
if refCount != 1 {
t.Errorf("DecrementRef() returned %d, want 1", refCount)
}
refCount, _ = store.DecrementRef(ctx, "abc")
if refCount != 0 {
t.Errorf("DecrementRef() returned %d, want 0", refCount)
}
}
func TestBlobStore_Delete(t *testing.T) {
db := setupBlobTestDB(t)
store := NewBlobStore(db)
ctx := context.Background()
store.Create(ctx, &model.FileBlob{SHA256: "abc", MinioKey: "files/abc", FileSize: 100, RefCount: 0})
if err := store.Delete(ctx, "abc"); err != nil {
t.Fatalf("Delete() error = %v", err)
}
blob, _ := store.GetBySHA256(ctx, "abc")
if blob != nil {
t.Error("Delete() did not remove blob")
}
}
func TestBlobStore_GetBySHA256s(t *testing.T) {
db := setupBlobTestDB(t)
store := NewBlobStore(db)
ctx := context.Background()
store.Create(ctx, &model.FileBlob{SHA256: "h1", MinioKey: "files/h1", FileSize: 100})
store.Create(ctx, &model.FileBlob{SHA256: "h2", MinioKey: "files/h2", FileSize: 200})
store.Create(ctx, &model.FileBlob{SHA256: "h3", MinioKey: "files/h3", FileSize: 300})
blobs, err := store.GetBySHA256s(ctx, []string{"h1", "h3"})
if err != nil {
t.Fatalf("GetBySHA256s() error = %v", err)
}
if len(blobs) != 2 {
t.Fatalf("len(blobs) = %d, want 2", len(blobs))
}
keys := map[string]bool{}
for _, b := range blobs {
keys[b.SHA256] = true
}
if !keys["h1"] || !keys["h3"] {
t.Errorf("expected h1 and h3, got %v", blobs)
}
}
func TestBlobStore_GetBySHA256s_Empty(t *testing.T) {
db := setupBlobTestDB(t)
store := NewBlobStore(db)
ctx := context.Background()
blobs, err := store.GetBySHA256s(ctx, []string{})
if err != nil {
t.Fatalf("GetBySHA256s() error = %v", err)
}
if len(blobs) != 0 {
t.Errorf("len(blobs) = %d, want 0", len(blobs))
}
}
func TestBlobStore_GetBySHA256s_NotFound(t *testing.T) {
db := setupBlobTestDB(t)
store := NewBlobStore(db)
ctx := context.Background()
blobs, err := store.GetBySHA256s(ctx, []string{"nonexistent"})
if err != nil {
t.Fatalf("GetBySHA256s() error = %v", err)
}
if len(blobs) != 0 {
t.Errorf("len(blobs) = %d, want 0 for non-existent SHA256s", len(blobs))
}
}
func TestBlobStore_SHA256_UniqueConstraint(t *testing.T) {
db := setupBlobTestDB(t)
store := NewBlobStore(db)
ctx := context.Background()
store.Create(ctx, &model.FileBlob{SHA256: "dup", MinioKey: "files/dup1", FileSize: 100})
err := store.Create(ctx, &model.FileBlob{SHA256: "dup", MinioKey: "files/dup2", FileSize: 200})
if err == nil {
t.Error("expected error for duplicate SHA256, got nil")
}
}

View File

@@ -0,0 +1,108 @@
package store
import (
"context"
"errors"
"gcy_hpc_server/internal/model"
"gorm.io/gorm"
)
type FileStore struct {
db *gorm.DB
}
func NewFileStore(db *gorm.DB) *FileStore {
return &FileStore{db: db}
}
func (s *FileStore) Create(ctx context.Context, file *model.File) error {
return s.db.WithContext(ctx).Create(file).Error
}
func (s *FileStore) GetByID(ctx context.Context, id int64) (*model.File, error) {
var file model.File
err := s.db.WithContext(ctx).First(&file, id).Error
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, nil
}
if err != nil {
return nil, err
}
return &file, nil
}
func (s *FileStore) List(ctx context.Context, folderID *int64, page, pageSize int) ([]model.File, int64, error) {
query := s.db.WithContext(ctx).Model(&model.File{})
if folderID == nil {
query = query.Where("folder_id IS NULL")
} else {
query = query.Where("folder_id = ?", *folderID)
}
var total int64
if err := query.Count(&total).Error; err != nil {
return nil, 0, err
}
var files []model.File
offset := (page - 1) * pageSize
if err := query.Order("id DESC").Limit(pageSize).Offset(offset).Find(&files).Error; err != nil {
return nil, 0, err
}
return files, total, nil
}
func (s *FileStore) Search(ctx context.Context, queryStr string, page, pageSize int) ([]model.File, int64, error) {
query := s.db.WithContext(ctx).Model(&model.File{}).Where("name LIKE ?", "%"+queryStr+"%")
var total int64
if err := query.Count(&total).Error; err != nil {
return nil, 0, err
}
var files []model.File
offset := (page - 1) * pageSize
if err := query.Order("id DESC").Limit(pageSize).Offset(offset).Find(&files).Error; err != nil {
return nil, 0, err
}
return files, total, nil
}
func (s *FileStore) Delete(ctx context.Context, id int64) error {
result := s.db.WithContext(ctx).Delete(&model.File{}, id)
if result.Error != nil {
return result.Error
}
return nil
}
func (s *FileStore) CountByBlobSHA256(ctx context.Context, blobSHA256 string) (int64, error) {
var count int64
err := s.db.WithContext(ctx).Model(&model.File{}).
Where("blob_sha256 = ?", blobSHA256).
Count(&count).Error
return count, err
}
func (s *FileStore) GetByIDs(ctx context.Context, ids []int64) ([]model.File, error) {
var files []model.File
if len(ids) == 0 {
return files, nil
}
err := s.db.WithContext(ctx).Where("id IN ?", ids).Find(&files).Error
return files, err
}
func (s *FileStore) GetBlobSHA256ByID(ctx context.Context, id int64) (string, error) {
var file model.File
err := s.db.WithContext(ctx).Select("blob_sha256").First(&file, id).Error
if errors.Is(err, gorm.ErrRecordNotFound) {
return "", nil
}
if err != nil {
return "", err
}
return file.BlobSHA256, nil
}
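Both `List` and `Search` share the same page math, `offset = (page-1)*pageSize` with `total` counted before the window is applied. A minimal in-memory sketch of that contract (`paginate` is an illustrative helper, not part of the repository):

```go
package main

import "fmt"

// paginate applies the same window math as FileStore.List/Search:
// total is the full match count, the returned slice is one page.
func paginate[T any](items []T, page, pageSize int) (pageItems []T, total int) {
	total = len(items)
	offset := (page - 1) * pageSize
	if offset >= total {
		return nil, total // past the end: empty page, total unchanged
	}
	end := offset + pageSize
	if end > total {
		end = total // last page may be short
	}
	return items[offset:end], total
}

func main() {
	ids := make([]int, 25)
	for i := range ids {
		ids[i] = i + 1
	}
	p1, total := paginate(ids, 1, 10)
	fmt.Println(len(p1), total) // 10 25
	p3, _ := paginate(ids, 3, 10)
	fmt.Println(len(p3)) // 5
}
```

This mirrors the expectations in `TestFileStore_Pagination`: 25 rows at a page size of 10 yield a full first page and a 5-row third page.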

View File

@@ -0,0 +1,323 @@
package store
import (
"context"
"testing"
"gcy_hpc_server/internal/model"
"gorm.io/driver/sqlite"
"gorm.io/gorm"
)
func setupFileTestDB(t *testing.T) *gorm.DB {
t.Helper()
db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{})
if err != nil {
t.Fatalf("open sqlite: %v", err)
}
if err := db.AutoMigrate(&model.File{}, &model.FileBlob{}); err != nil {
t.Fatalf("migrate: %v", err)
}
return db
}
func TestFileStore_CreateAndGetByID(t *testing.T) {
db := setupFileTestDB(t)
store := NewFileStore(db)
ctx := context.Background()
file := &model.File{
Name: "test.bin",
BlobSHA256: "abc123",
}
if err := store.Create(ctx, file); err != nil {
t.Fatalf("Create() error = %v", err)
}
if file.ID == 0 {
t.Fatal("Create() did not set ID")
}
got, err := store.GetByID(ctx, file.ID)
if err != nil {
t.Fatalf("GetByID() error = %v", err)
}
if got == nil {
t.Fatal("GetByID() returned nil")
}
if got.Name != "test.bin" {
t.Errorf("Name = %q, want %q", got.Name, "test.bin")
}
if got.BlobSHA256 != "abc123" {
t.Errorf("BlobSHA256 = %q, want %q", got.BlobSHA256, "abc123")
}
}
func TestFileStore_GetByID_NotFound(t *testing.T) {
db := setupFileTestDB(t)
store := NewFileStore(db)
ctx := context.Background()
got, err := store.GetByID(ctx, 999)
if err != nil {
t.Fatalf("GetByID() error = %v", err)
}
if got != nil {
t.Error("GetByID() should return nil for not found")
}
}
func TestFileStore_ListByFolder(t *testing.T) {
db := setupFileTestDB(t)
store := NewFileStore(db)
ctx := context.Background()
folderID := int64(1)
store.Create(ctx, &model.File{Name: "f1.bin", BlobSHA256: "a1", FolderID: &folderID})
store.Create(ctx, &model.File{Name: "f2.bin", BlobSHA256: "a2", FolderID: &folderID})
store.Create(ctx, &model.File{Name: "root.bin", BlobSHA256: "a3"}) // root (folder_id=nil)
files, total, err := store.List(ctx, &folderID, 1, 10)
if err != nil {
t.Fatalf("List() error = %v", err)
}
if total != 2 {
t.Errorf("total = %d, want 2", total)
}
if len(files) != 2 {
t.Errorf("len(files) = %d, want 2", len(files))
}
}
func TestFileStore_ListRootFolder(t *testing.T) {
db := setupFileTestDB(t)
store := NewFileStore(db)
ctx := context.Background()
store.Create(ctx, &model.File{Name: "root.bin", BlobSHA256: "a1"})
folderID := int64(1)
store.Create(ctx, &model.File{Name: "sub.bin", BlobSHA256: "a2", FolderID: &folderID})
files, total, err := store.List(ctx, nil, 1, 10)
if err != nil {
t.Fatalf("List() error = %v", err)
}
if total != 1 {
t.Errorf("total = %d, want 1", total)
}
if len(files) != 1 {
t.Errorf("len(files) = %d, want 1", len(files))
}
if files[0].Name != "root.bin" {
t.Errorf("files[0].Name = %q, want %q", files[0].Name, "root.bin")
}
}
func TestFileStore_Pagination(t *testing.T) {
db := setupFileTestDB(t)
store := NewFileStore(db)
ctx := context.Background()
for i := 0; i < 25; i++ {
store.Create(ctx, &model.File{Name: "file.bin", BlobSHA256: "hash"})
}
files, total, err := store.List(ctx, nil, 1, 10)
if err != nil {
t.Fatalf("List() error = %v", err)
}
if total != 25 {
t.Errorf("total = %d, want 25", total)
}
if len(files) != 10 {
t.Errorf("page 1 len = %d, want 10", len(files))
}
files, _, _ = store.List(ctx, nil, 3, 10)
if len(files) != 5 {
t.Errorf("page 3 len = %d, want 5", len(files))
}
}
func TestFileStore_Search(t *testing.T) {
db := setupFileTestDB(t)
store := NewFileStore(db)
ctx := context.Background()
store.Create(ctx, &model.File{Name: "experiment_results.csv", BlobSHA256: "a1"})
store.Create(ctx, &model.File{Name: "training_log.txt", BlobSHA256: "a2"})
store.Create(ctx, &model.File{Name: "model_weights.bin", BlobSHA256: "a3"})
files, total, err := store.Search(ctx, "results", 1, 10)
if err != nil {
t.Fatalf("Search() error = %v", err)
}
if total != 1 {
t.Errorf("total = %d, want 1", total)
}
if len(files) != 1 || files[0].Name != "experiment_results.csv" {
t.Errorf("expected experiment_results.csv, got %v", files)
}
}
func TestFileStore_Delete_SoftDelete(t *testing.T) {
db := setupFileTestDB(t)
store := NewFileStore(db)
ctx := context.Background()
file := &model.File{Name: "deleteme.bin", BlobSHA256: "abc"}
store.Create(ctx, file)
if err := store.Delete(ctx, file.ID); err != nil {
t.Fatalf("Delete() error = %v", err)
}
// GetByID should return nil (soft deleted)
got, err := store.GetByID(ctx, file.ID)
if err != nil {
t.Fatalf("GetByID() error = %v", err)
}
if got != nil {
t.Error("GetByID() should return nil after soft delete")
}
// List should not include soft deleted
_, total, _ := store.List(ctx, nil, 1, 10)
if total != 0 {
t.Errorf("total after delete = %d, want 0", total)
}
}
func TestFileStore_CountByBlobSHA256(t *testing.T) {
db := setupFileTestDB(t)
store := NewFileStore(db)
ctx := context.Background()
store.Create(ctx, &model.File{Name: "a.bin", BlobSHA256: "shared_hash"})
store.Create(ctx, &model.File{Name: "b.bin", BlobSHA256: "shared_hash"})
store.Create(ctx, &model.File{Name: "c.bin", BlobSHA256: "shared_hash"})
count, err := store.CountByBlobSHA256(ctx, "shared_hash")
if err != nil {
t.Fatalf("CountByBlobSHA256() error = %v", err)
}
if count != 3 {
t.Errorf("count = %d, want 3", count)
}
// Soft delete one
store.Delete(ctx, 1)
count, _ = store.CountByBlobSHA256(ctx, "shared_hash")
if count != 2 {
t.Errorf("count after soft delete = %d, want 2", count)
}
}
func TestFileStore_GetBlobSHA256ByID(t *testing.T) {
db := setupFileTestDB(t)
store := NewFileStore(db)
ctx := context.Background()
file := &model.File{Name: "test.bin", BlobSHA256: "my_hash"}
store.Create(ctx, file)
sha256, err := store.GetBlobSHA256ByID(ctx, file.ID)
if err != nil {
t.Fatalf("GetBlobSHA256ByID() error = %v", err)
}
if sha256 != "my_hash" {
t.Errorf("sha256 = %q, want %q", sha256, "my_hash")
}
}
func TestFileStore_GetByIDs(t *testing.T) {
db := setupFileTestDB(t)
store := NewFileStore(db)
ctx := context.Background()
store.Create(ctx, &model.File{Name: "a.bin", BlobSHA256: "h1"})
store.Create(ctx, &model.File{Name: "b.bin", BlobSHA256: "h2"})
store.Create(ctx, &model.File{Name: "c.bin", BlobSHA256: "h3"})
files, err := store.GetByIDs(ctx, []int64{1, 3})
if err != nil {
t.Fatalf("GetByIDs() error = %v", err)
}
if len(files) != 2 {
t.Fatalf("len(files) = %d, want 2", len(files))
}
names := map[string]bool{}
for _, f := range files {
names[f.Name] = true
}
if !names["a.bin"] || !names["c.bin"] {
t.Errorf("expected a.bin and c.bin, got %v", files)
}
}
func TestFileStore_GetByIDs_Empty(t *testing.T) {
db := setupFileTestDB(t)
store := NewFileStore(db)
ctx := context.Background()
files, err := store.GetByIDs(ctx, []int64{})
if err != nil {
t.Fatalf("GetByIDs() error = %v", err)
}
if len(files) != 0 {
t.Errorf("len(files) = %d, want 0", len(files))
}
}
func TestFileStore_GetByIDs_NotFound(t *testing.T) {
db := setupFileTestDB(t)
store := NewFileStore(db)
ctx := context.Background()
files, err := store.GetByIDs(ctx, []int64{999})
if err != nil {
t.Fatalf("GetByIDs() error = %v", err)
}
if len(files) != 0 {
t.Errorf("len(files) = %d, want 0 for non-existent IDs", len(files))
}
}
func TestFileStore_GetByIDs_SoftDeleteExcluded(t *testing.T) {
db := setupFileTestDB(t)
store := NewFileStore(db)
ctx := context.Background()
store.Create(ctx, &model.File{Name: "a.bin", BlobSHA256: "h1"})
store.Create(ctx, &model.File{Name: "b.bin", BlobSHA256: "h2"})
store.Create(ctx, &model.File{Name: "c.bin", BlobSHA256: "h3"})
store.Delete(ctx, 2)
files, err := store.GetByIDs(ctx, []int64{1, 2, 3})
if err != nil {
t.Fatalf("GetByIDs() error = %v", err)
}
if len(files) != 2 {
t.Fatalf("len(files) = %d, want 2 (soft-deleted excluded)", len(files))
}
for _, f := range files {
if f.ID == 2 {
t.Error("soft-deleted file ID 2 should not appear")
}
}
}
func TestFileStore_GetBlobSHA256ByID_NotFound(t *testing.T) {
db := setupFileTestDB(t)
store := NewFileStore(db)
ctx := context.Background()
sha256, err := store.GetBlobSHA256ByID(ctx, 999)
if err != nil {
t.Fatalf("GetBlobSHA256ByID() error = %v", err)
}
if sha256 != "" {
t.Errorf("sha256 = %q, want empty for not found", sha256)
}
}

View File

@@ -0,0 +1,105 @@
package store
import (
"context"
"errors"
"fmt"
"gcy_hpc_server/internal/model"
"gorm.io/gorm"
)
type FolderStore struct {
db *gorm.DB
}
func NewFolderStore(db *gorm.DB) *FolderStore {
return &FolderStore{db: db}
}
func (s *FolderStore) Create(ctx context.Context, folder *model.Folder) error {
return s.db.WithContext(ctx).Create(folder).Error
}
func (s *FolderStore) GetByID(ctx context.Context, id int64) (*model.Folder, error) {
var folder model.Folder
err := s.db.WithContext(ctx).First(&folder, id).Error
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, nil
}
if err != nil {
return nil, err
}
return &folder, nil
}
func (s *FolderStore) GetByPath(ctx context.Context, path string) (*model.Folder, error) {
var folder model.Folder
err := s.db.WithContext(ctx).Where("path = ?", path).First(&folder).Error
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, nil
}
if err != nil {
return nil, err
}
return &folder, nil
}
func (s *FolderStore) ListByParentID(ctx context.Context, parentID *int64) ([]model.Folder, error) {
var folders []model.Folder
query := s.db.WithContext(ctx)
if parentID == nil {
query = query.Where("parent_id IS NULL")
} else {
query = query.Where("parent_id = ?", *parentID)
}
if err := query.Order("name ASC").Find(&folders).Error; err != nil {
return nil, err
}
return folders, nil
}
// GetSubTree returns all folders whose path starts with the given prefix.
func (s *FolderStore) GetSubTree(ctx context.Context, path string) ([]model.Folder, error) {
var folders []model.Folder
if err := s.db.WithContext(ctx).Where("path LIKE ?", path+"%").Find(&folders).Error; err != nil {
return nil, err
}
return folders, nil
}
// HasChildren checks if a folder has sub-folders or files.
func (s *FolderStore) HasChildren(ctx context.Context, id int64) (bool, error) {
folder, err := s.GetByID(ctx, id)
if err != nil {
return false, err
}
if folder == nil {
return false, nil
}
// Check for sub-folders
var subFolderCount int64
if err := s.db.WithContext(ctx).Model(&model.Folder{}).Where("parent_id = ?", id).Count(&subFolderCount).Error; err != nil {
return false, fmt.Errorf("count sub-folders: %w", err)
}
if subFolderCount > 0 {
return true, nil
}
// Check for files
var fileCount int64
if err := s.db.WithContext(ctx).Model(&model.File{}).Where("folder_id = ?", id).Count(&fileCount).Error; err != nil {
return false, fmt.Errorf("count files: %w", err)
}
return fileCount > 0, nil
}
func (s *FolderStore) Delete(ctx context.Context, id int64) error {
result := s.db.WithContext(ctx).Delete(&model.Folder{}, id)
if result.Error != nil {
return result.Error
}
return nil
}
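`GetSubTree` relies on materialized paths and a `LIKE 'prefix%'` filter. Because paths are stored with trailing slashes (as the tests show, e.g. `/data/` and `/data/results/`), the prefix `/data/` cannot accidentally match a sibling such as `/database/`. A small in-memory sketch of that filter (`subTree` is an illustrative helper, not part of the repository):

```go
package main

import (
	"fmt"
	"strings"
)

// subTree mirrors FolderStore.GetSubTree's LIKE 'path%' filter in memory.
// The trailing slash on the prefix is what keeps sibling folders out.
func subTree(paths []string, prefix string) []string {
	var out []string
	for _, p := range paths {
		if strings.HasPrefix(p, prefix) {
			out = append(out, p)
		}
	}
	return out
}

func main() {
	paths := []string{"/data/", "/data/results/", "/database/", "/other/"}
	fmt.Println(subTree(paths, "/data/")) // [/data/ /data/results/]
}
```

The subtree includes the folder itself, matching `TestFolderStore_GetSubTree`, which expects 2 rows for the prefix `/data/`.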

View File

@@ -0,0 +1,294 @@
package store
import (
"context"
"testing"
"gcy_hpc_server/internal/model"
"gorm.io/driver/sqlite"
"gorm.io/gorm"
"gorm.io/gorm/logger"
)
func setupFolderTestDB(t *testing.T) *gorm.DB {
t.Helper()
db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{
Logger: logger.Default.LogMode(logger.Silent),
})
if err != nil {
t.Fatalf("open sqlite: %v", err)
}
if err := db.AutoMigrate(&model.Folder{}, &model.File{}); err != nil {
t.Fatalf("migrate: %v", err)
}
return db
}
func TestFolderStore_CreateAndGetByID(t *testing.T) {
db := setupFolderTestDB(t)
s := NewFolderStore(db)
folder := &model.Folder{
Name: "data",
Path: "/data/",
}
if err := s.Create(context.Background(), folder); err != nil {
t.Fatalf("Create() error = %v", err)
}
if folder.ID <= 0 {
t.Fatalf("Create() id = %d, want positive", folder.ID)
}
got, err := s.GetByID(context.Background(), folder.ID)
if err != nil {
t.Fatalf("GetByID() error = %v", err)
}
if got == nil {
t.Fatal("GetByID() returned nil")
}
if got.Name != "data" {
t.Errorf("Name = %q, want %q", got.Name, "data")
}
if got.Path != "/data/" {
t.Errorf("Path = %q, want %q", got.Path, "/data/")
}
}
func TestFolderStore_GetByID_NotFound(t *testing.T) {
db := setupFolderTestDB(t)
s := NewFolderStore(db)
got, err := s.GetByID(context.Background(), 99999)
if err != nil {
t.Fatalf("GetByID() error = %v", err)
}
if got != nil {
t.Error("GetByID() expected nil for not-found")
}
}
func TestFolderStore_GetByPath(t *testing.T) {
db := setupFolderTestDB(t)
s := NewFolderStore(db)
folder := &model.Folder{
Name: "data",
Path: "/data/",
}
if err := s.Create(context.Background(), folder); err != nil {
t.Fatalf("Create() error = %v", err)
}
got, err := s.GetByPath(context.Background(), "/data/")
if err != nil {
t.Fatalf("GetByPath() error = %v", err)
}
if got == nil {
t.Fatal("GetByPath() returned nil")
}
if got.ID != folder.ID {
t.Errorf("ID = %d, want %d", got.ID, folder.ID)
}
got, err = s.GetByPath(context.Background(), "/nonexistent/")
if err != nil {
t.Fatalf("GetByPath() nonexistent error = %v", err)
}
if got != nil {
t.Error("GetByPath() expected nil for nonexistent path")
}
}
func TestFolderStore_ListByParentID(t *testing.T) {
db := setupFolderTestDB(t)
s := NewFolderStore(db)
root1 := &model.Folder{Name: "alpha", Path: "/alpha/"}
root2 := &model.Folder{Name: "beta", Path: "/beta/"}
if err := s.Create(context.Background(), root1); err != nil {
t.Fatalf("Create root1: %v", err)
}
if err := s.Create(context.Background(), root2); err != nil {
t.Fatalf("Create root2: %v", err)
}
sub := &model.Folder{Name: "sub", Path: "/alpha/sub/", ParentID: &root1.ID}
if err := s.Create(context.Background(), sub); err != nil {
t.Fatalf("Create sub: %v", err)
}
roots, err := s.ListByParentID(context.Background(), nil)
if err != nil {
t.Fatalf("ListByParentID(nil) error = %v", err)
}
if len(roots) != 2 {
t.Fatalf("root folders = %d, want 2", len(roots))
}
if roots[0].Name != "alpha" {
t.Errorf("roots[0].Name = %q, want %q (alphabetical)", roots[0].Name, "alpha")
}
children, err := s.ListByParentID(context.Background(), &root1.ID)
if err != nil {
t.Fatalf("ListByParentID(root1) error = %v", err)
}
if len(children) != 1 {
t.Fatalf("children = %d, want 1", len(children))
}
if children[0].Name != "sub" {
t.Errorf("children[0].Name = %q, want %q", children[0].Name, "sub")
}
}
func TestFolderStore_GetSubTree(t *testing.T) {
db := setupFolderTestDB(t)
s := NewFolderStore(db)
data := &model.Folder{Name: "data", Path: "/data/"}
if err := s.Create(context.Background(), data); err != nil {
t.Fatalf("Create data: %v", err)
}
results := &model.Folder{Name: "results", Path: "/data/results/", ParentID: &data.ID}
if err := s.Create(context.Background(), results); err != nil {
t.Fatalf("Create results: %v", err)
}
other := &model.Folder{Name: "other", Path: "/other/"}
if err := s.Create(context.Background(), other); err != nil {
t.Fatalf("Create other: %v", err)
}
subtree, err := s.GetSubTree(context.Background(), "/data/")
if err != nil {
t.Fatalf("GetSubTree() error = %v", err)
}
if len(subtree) != 2 {
t.Fatalf("subtree = %d, want 2", len(subtree))
}
}
func TestFolderStore_HasChildren_WithSubFolders(t *testing.T) {
db := setupFolderTestDB(t)
s := NewFolderStore(db)
parent := &model.Folder{Name: "parent", Path: "/parent/"}
if err := s.Create(context.Background(), parent); err != nil {
t.Fatalf("Create parent: %v", err)
}
child := &model.Folder{Name: "child", Path: "/parent/child/", ParentID: &parent.ID}
if err := s.Create(context.Background(), child); err != nil {
t.Fatalf("Create child: %v", err)
}
has, err := s.HasChildren(context.Background(), parent.ID)
if err != nil {
t.Fatalf("HasChildren() error = %v", err)
}
if !has {
t.Error("HasChildren() = false, want true (has sub-folders)")
}
}
func TestFolderStore_HasChildren_WithFiles(t *testing.T) {
db := setupFolderTestDB(t)
s := NewFolderStore(db)
folder := &model.Folder{Name: "docs", Path: "/docs/"}
if err := s.Create(context.Background(), folder); err != nil {
t.Fatalf("Create folder: %v", err)
}
file := &model.File{
Name: "readme.txt",
FolderID: &folder.ID,
BlobSHA256: "abc123",
}
if err := db.Create(file).Error; err != nil {
t.Fatalf("Create file: %v", err)
}
has, err := s.HasChildren(context.Background(), folder.ID)
if err != nil {
t.Fatalf("HasChildren() error = %v", err)
}
if !has {
t.Error("HasChildren() = false, want true (has files)")
}
}
func TestFolderStore_HasChildren_Empty(t *testing.T) {
db := setupFolderTestDB(t)
s := NewFolderStore(db)
folder := &model.Folder{Name: "empty", Path: "/empty/"}
if err := s.Create(context.Background(), folder); err != nil {
t.Fatalf("Create folder: %v", err)
}
has, err := s.HasChildren(context.Background(), folder.ID)
if err != nil {
t.Fatalf("HasChildren() error = %v", err)
}
if has {
t.Error("HasChildren() = true, want false (empty folder)")
}
}
func TestFolderStore_HasChildren_NotFound(t *testing.T) {
db := setupFolderTestDB(t)
s := NewFolderStore(db)
has, err := s.HasChildren(context.Background(), 99999)
if err != nil {
t.Fatalf("HasChildren() error = %v", err)
}
if has {
t.Error("HasChildren() = true for nonexistent, want false")
}
}
func TestFolderStore_Delete(t *testing.T) {
db := setupFolderTestDB(t)
s := NewFolderStore(db)
folder := &model.Folder{Name: "temp", Path: "/temp/"}
if err := s.Create(context.Background(), folder); err != nil {
t.Fatalf("Create() error = %v", err)
}
if err := s.Delete(context.Background(), folder.ID); err != nil {
t.Fatalf("Delete() error = %v", err)
}
got, err := s.GetByID(context.Background(), folder.ID)
if err != nil {
t.Fatalf("GetByID() after delete error = %v", err)
}
if got != nil {
t.Error("GetByID() after delete returned non-nil, expected soft-deleted")
}
}
func TestFolderStore_Delete_Idempotent(t *testing.T) {
db := setupFolderTestDB(t)
s := NewFolderStore(db)
if err := s.Delete(context.Background(), 99999); err != nil {
t.Fatalf("Delete() non-existent error = %v, want nil (idempotent)", err)
}
}
func TestFolderStore_Path_UniqueConstraint(t *testing.T) {
db := setupFolderTestDB(t)
s := NewFolderStore(db)
f1 := &model.Folder{Name: "data", Path: "/data/"}
if err := s.Create(context.Background(), f1); err != nil {
t.Fatalf("Create first: %v", err)
}
f2 := &model.Folder{Name: "data2", Path: "/data/"}
if err := s.Create(context.Background(), f2); err == nil {
t.Fatal("expected error for duplicate path, got nil")
}
}

View File

@@ -1 +0,0 @@
DROP TABLE IF EXISTS job_templates;

View File

@@ -1,14 +0,0 @@
CREATE TABLE IF NOT EXISTS job_templates (
id BIGINT UNSIGNED AUTO_INCREMENT PRIMARY KEY,
name VARCHAR(255) NOT NULL,
description TEXT,
script TEXT NOT NULL,
partition VARCHAR(255),
qos VARCHAR(255),
cpus INT UNSIGNED,
memory VARCHAR(50),
time_limit VARCHAR(50),
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
UNIQUE KEY idx_name (name)
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;

View File

@@ -40,5 +40,13 @@ func NewGormDB(dsn string, zapLogger *zap.Logger, gormLevel string) (*gorm.DB, e
// AutoMigrate runs GORM auto-migration for all models.
func AutoMigrate(db *gorm.DB) error {
return db.AutoMigrate(
&model.Application{},
&model.FileBlob{},
&model.File{},
&model.Folder{},
&model.UploadSession{},
&model.UploadChunk{},
&model.Task{},
)
}

View File

@@ -0,0 +1,141 @@
package store
import (
"context"
"errors"
"fmt"
"time"
"gcy_hpc_server/internal/model"
"gorm.io/gorm"
)
type TaskStore struct {
db *gorm.DB
}
func NewTaskStore(db *gorm.DB) *TaskStore {
return &TaskStore{db: db}
}
func (s *TaskStore) Create(ctx context.Context, task *model.Task) (int64, error) {
if err := s.db.WithContext(ctx).Create(task).Error; err != nil {
return 0, err
}
return task.ID, nil
}
func (s *TaskStore) GetByID(ctx context.Context, id int64) (*model.Task, error) {
var task model.Task
err := s.db.WithContext(ctx).First(&task, id).Error
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, nil
}
if err != nil {
return nil, err
}
return &task, nil
}
// List returns a page of tasks, optionally filtered by status, plus the total count.
// Page defaults to 1 and page size to 10 when non-positive.
func (s *TaskStore) List(ctx context.Context, query *model.TaskListQuery) ([]model.Task, int64, error) {
page := query.Page
pageSize := query.PageSize
if page <= 0 {
page = 1
}
if pageSize <= 0 {
pageSize = 10
}
q := s.db.WithContext(ctx).Model(&model.Task{})
if query.Status != "" {
q = q.Where("status = ?", query.Status)
}
var total int64
if err := q.Count(&total).Error; err != nil {
return nil, 0, err
}
var tasks []model.Task
offset := (page - 1) * pageSize
if err := q.Order("id DESC").Limit(pageSize).Offset(offset).Find(&tasks).Error; err != nil {
return nil, 0, err
}
return tasks, total, nil
}
// UpdateStatus sets the status and error message, stamping started_at for
// in-progress statuses and finished_at for terminal ones.
func (s *TaskStore) UpdateStatus(ctx context.Context, id int64, status, errorMsg string) error {
updates := map[string]interface{}{
"status": status,
"error_message": errorMsg,
}
now := time.Now()
switch status {
case model.TaskStatusPreparing, model.TaskStatusDownloading, model.TaskStatusReady,
model.TaskStatusQueued, model.TaskStatusRunning:
updates["started_at"] = &now
case model.TaskStatusCompleted, model.TaskStatusFailed:
updates["finished_at"] = &now
}
result := s.db.WithContext(ctx).Model(&model.Task{}).Where("id = ?", id).Updates(updates)
if result.Error != nil {
return result.Error
}
if result.RowsAffected == 0 {
return fmt.Errorf("task %d not found", id)
}
return nil
}
// UpdateSlurmJobID records the Slurm job ID assigned to a task.
func (s *TaskStore) UpdateSlurmJobID(ctx context.Context, id int64, slurmJobID *int32) error {
result := s.db.WithContext(ctx).Model(&model.Task{}).Where("id = ?", id).
Update("slurm_job_id", slurmJobID)
if result.Error != nil {
return result.Error
}
if result.RowsAffected == 0 {
return fmt.Errorf("task %d not found", id)
}
return nil
}
// UpdateWorkDir records the working directory prepared for a task.
func (s *TaskStore) UpdateWorkDir(ctx context.Context, id int64, workDir string) error {
result := s.db.WithContext(ctx).Model(&model.Task{}).Where("id = ?", id).
Update("work_dir", workDir)
if result.Error != nil {
return result.Error
}
if result.RowsAffected == 0 {
return fmt.Errorf("task %d not found", id)
}
return nil
}
// UpdateRetryState resets status, current step, and retry count when a task is retried.
func (s *TaskStore) UpdateRetryState(ctx context.Context, id int64, status, currentStep string, retryCount int) error {
result := s.db.WithContext(ctx).Model(&model.Task{}).Where("id = ?", id).Updates(map[string]interface{}{
"status": status,
"current_step": currentStep,
"retry_count": retryCount,
})
if result.Error != nil {
return result.Error
}
if result.RowsAffected == 0 {
return fmt.Errorf("task %d not found", id)
}
return nil
}
// GetStuckTasks returns non-terminal tasks whose updated_at is older than maxAge.
func (s *TaskStore) GetStuckTasks(ctx context.Context, maxAge time.Duration) ([]model.Task, error) {
cutoff := time.Now().Add(-maxAge)
var tasks []model.Task
err := s.db.WithContext(ctx).
Where("status NOT IN ?", []string{model.TaskStatusCompleted, model.TaskStatusFailed}).
Where("updated_at < ?", cutoff).
Find(&tasks).Error
return tasks, err
}
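The defaulting and offset arithmetic that `List` applies can be sketched in isolation; `normalizePage` below is a hypothetical helper, not part of the store:

```go
package main

import "fmt"

// normalizePage applies the same defaults List uses: page and pageSize
// fall back to 1 and 10 when non-positive, and the SQL offset is
// (page-1)*pageSize.
func normalizePage(page, pageSize int) (p, size, offset int) {
	if page <= 0 {
		page = 1
	}
	if pageSize <= 0 {
		pageSize = 10
	}
	return page, pageSize, (page - 1) * pageSize
}

func main() {
	p, size, off := normalizePage(0, 0)
	fmt.Println(p, size, off) // defaults: 1 10 0
	p, size, off = normalizePage(2, 3)
	fmt.Println(p, size, off) // second page of three: 2 3 3
}
```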

View File

@@ -0,0 +1,229 @@
package store
import (
"context"
"testing"
"time"
"gcy_hpc_server/internal/model"
"gorm.io/driver/sqlite"
"gorm.io/gorm"
"gorm.io/gorm/logger"
)
func newTaskTestDB(t *testing.T) *gorm.DB {
t.Helper()
db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{
Logger: logger.Default.LogMode(logger.Silent),
})
if err != nil {
t.Fatalf("open sqlite: %v", err)
}
if err := db.AutoMigrate(&model.Task{}); err != nil {
t.Fatalf("auto migrate: %v", err)
}
return db
}
func makeTestTask(name, status string) *model.Task {
return &model.Task{
TaskName: name,
AppID: 1,
AppName: "test-app",
Status: status,
CurrentStep: "",
RetryCount: 0,
UserID: "user1",
SubmittedAt: time.Now(),
}
}
func TestTaskStore_CreateAndGetByID(t *testing.T) {
db := newTaskTestDB(t)
s := NewTaskStore(db)
ctx := context.Background()
task := makeTestTask("test-task", model.TaskStatusSubmitted)
id, err := s.Create(ctx, task)
if err != nil {
t.Fatalf("Create() error = %v", err)
}
if id <= 0 {
t.Errorf("Create() id = %d, want positive", id)
}
got, err := s.GetByID(ctx, id)
if err != nil {
t.Fatalf("GetByID() error = %v", err)
}
if got == nil {
t.Fatal("GetByID() returned nil")
}
if got.TaskName != "test-task" {
t.Errorf("TaskName = %q, want %q", got.TaskName, "test-task")
}
if got.Status != model.TaskStatusSubmitted {
t.Errorf("Status = %q, want %q", got.Status, model.TaskStatusSubmitted)
}
if got.ID != id {
t.Errorf("ID = %d, want %d", got.ID, id)
}
}
func TestTaskStore_GetByID_NotFound(t *testing.T) {
db := newTaskTestDB(t)
s := NewTaskStore(db)
got, err := s.GetByID(context.Background(), 999)
if err != nil {
t.Fatalf("GetByID() error = %v", err)
}
if got != nil {
t.Error("GetByID() expected nil for not-found, got non-nil")
}
}
func TestTaskStore_ListPagination(t *testing.T) {
db := newTaskTestDB(t)
s := NewTaskStore(db)
ctx := context.Background()
for i := 0; i < 5; i++ {
s.Create(ctx, makeTestTask("task-"+string(rune('A'+i)), model.TaskStatusSubmitted))
}
tasks, total, err := s.List(ctx, &model.TaskListQuery{Page: 1, PageSize: 3})
if err != nil {
t.Fatalf("List() error = %v", err)
}
if total != 5 {
t.Errorf("total = %d, want 5", total)
}
if len(tasks) != 3 {
t.Errorf("len(tasks) = %d, want 3", len(tasks))
}
tasks2, total2, err := s.List(ctx, &model.TaskListQuery{Page: 2, PageSize: 3})
if err != nil {
t.Fatalf("List() page 2 error = %v", err)
}
if total2 != 5 {
t.Errorf("total2 = %d, want 5", total2)
}
if len(tasks2) != 2 {
t.Errorf("len(tasks2) = %d, want 2", len(tasks2))
}
}
func TestTaskStore_ListStatusFilter(t *testing.T) {
db := newTaskTestDB(t)
s := NewTaskStore(db)
ctx := context.Background()
s.Create(ctx, makeTestTask("running-1", model.TaskStatusRunning))
s.Create(ctx, makeTestTask("running-2", model.TaskStatusRunning))
s.Create(ctx, makeTestTask("completed-1", model.TaskStatusCompleted))
tasks, total, err := s.List(ctx, &model.TaskListQuery{Page: 1, PageSize: 10, Status: model.TaskStatusRunning})
if err != nil {
t.Fatalf("List() error = %v", err)
}
if total != 2 {
t.Errorf("total = %d, want 2", total)
}
for _, t2 := range tasks {
if t2.Status != model.TaskStatusRunning {
t.Errorf("Status = %q, want %q", t2.Status, model.TaskStatusRunning)
}
}
}
func TestTaskStore_UpdateStatus(t *testing.T) {
db := newTaskTestDB(t)
s := NewTaskStore(db)
ctx := context.Background()
id, _ := s.Create(ctx, makeTestTask("status-test", model.TaskStatusSubmitted))
err := s.UpdateStatus(ctx, id, model.TaskStatusPreparing, "")
if err != nil {
t.Fatalf("UpdateStatus() error = %v", err)
}
got, _ := s.GetByID(ctx, id)
if got.Status != model.TaskStatusPreparing {
t.Errorf("Status = %q, want %q", got.Status, model.TaskStatusPreparing)
}
if got.StartedAt == nil {
t.Error("StartedAt expected non-nil after preparing status")
}
err = s.UpdateStatus(ctx, id, model.TaskStatusFailed, "something broke")
if err != nil {
t.Fatalf("UpdateStatus(failed) error = %v", err)
}
got, _ = s.GetByID(ctx, id)
if got.ErrorMessage != "something broke" {
t.Errorf("ErrorMessage = %q, want %q", got.ErrorMessage, "something broke")
}
if got.FinishedAt == nil {
t.Error("FinishedAt expected non-nil after failed status")
}
}
func TestTaskStore_UpdateRetryState(t *testing.T) {
db := newTaskTestDB(t)
s := NewTaskStore(db)
ctx := context.Background()
task := makeTestTask("retry-test", model.TaskStatusFailed)
task.CurrentStep = model.TaskStepDownloading
task.RetryCount = 1
id, _ := s.Create(ctx, task)
err := s.UpdateRetryState(ctx, id, model.TaskStatusSubmitted, model.TaskStepDownloading, 2)
if err != nil {
t.Fatalf("UpdateRetryState() error = %v", err)
}
got, _ := s.GetByID(ctx, id)
if got.Status != model.TaskStatusSubmitted {
t.Errorf("Status = %q, want %q", got.Status, model.TaskStatusSubmitted)
}
if got.CurrentStep != model.TaskStepDownloading {
t.Errorf("CurrentStep = %q, want %q", got.CurrentStep, model.TaskStepDownloading)
}
if got.RetryCount != 2 {
t.Errorf("RetryCount = %d, want 2", got.RetryCount)
}
}
func TestTaskStore_GetStuckTasks(t *testing.T) {
db := newTaskTestDB(t)
s := NewTaskStore(db)
ctx := context.Background()
stuck := makeTestTask("stuck-1", model.TaskStatusDownloading)
stuck.UpdatedAt = time.Now().Add(-1 * time.Hour)
s.Create(ctx, stuck)
recent := makeTestTask("recent-1", model.TaskStatusDownloading)
s.Create(ctx, recent)
done := makeTestTask("done-1", model.TaskStatusCompleted)
done.UpdatedAt = time.Now().Add(-2 * time.Hour)
s.Create(ctx, done)
tasks, err := s.GetStuckTasks(ctx, 30*time.Minute)
if err != nil {
t.Fatalf("GetStuckTasks() error = %v", err)
}
if len(tasks) != 1 {
t.Fatalf("len(tasks) = %d, want 1", len(tasks))
}
if tasks[0].TaskName != "stuck-1" {
t.Errorf("TaskName = %q, want %q", tasks[0].TaskName, "stuck-1")
}
}

View File

@@ -1,119 +0,0 @@
package store
import (
"context"
"errors"
"gorm.io/gorm"
"gcy_hpc_server/internal/model"
)
// TemplateStore provides CRUD operations for job templates via GORM.
type TemplateStore struct {
db *gorm.DB
}
// NewTemplateStore creates a new TemplateStore.
func NewTemplateStore(db *gorm.DB) *TemplateStore {
return &TemplateStore{db: db}
}
// List returns a paginated list of job templates and the total count.
func (s *TemplateStore) List(ctx context.Context, page, pageSize int) ([]model.JobTemplate, int, error) {
var templates []model.JobTemplate
var total int64
if err := s.db.WithContext(ctx).Model(&model.JobTemplate{}).Count(&total).Error; err != nil {
return nil, 0, err
}
offset := (page - 1) * pageSize
if err := s.db.WithContext(ctx).Order("id DESC").Limit(pageSize).Offset(offset).Find(&templates).Error; err != nil {
return nil, 0, err
}
return templates, int(total), nil
}
// GetByID returns a single job template by ID. Returns nil, nil when not found.
func (s *TemplateStore) GetByID(ctx context.Context, id int64) (*model.JobTemplate, error) {
var t model.JobTemplate
err := s.db.WithContext(ctx).First(&t, id).Error
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, nil
}
if err != nil {
return nil, err
}
return &t, nil
}
// Create inserts a new job template and returns the generated ID.
func (s *TemplateStore) Create(ctx context.Context, req *model.CreateTemplateRequest) (int64, error) {
t := &model.JobTemplate{
Name: req.Name,
Description: req.Description,
Script: req.Script,
Partition: req.Partition,
QOS: req.QOS,
CPUs: req.CPUs,
Memory: req.Memory,
TimeLimit: req.TimeLimit,
}
if err := s.db.WithContext(ctx).Create(t).Error; err != nil {
return 0, err
}
return t.ID, nil
}
// Update modifies an existing job template. Only non-empty/non-zero fields are updated.
func (s *TemplateStore) Update(ctx context.Context, id int64, req *model.UpdateTemplateRequest) error {
updates := map[string]interface{}{}
if req.Name != "" {
updates["name"] = req.Name
}
if req.Description != "" {
updates["description"] = req.Description
}
if req.Script != "" {
updates["script"] = req.Script
}
if req.Partition != "" {
updates["partition"] = req.Partition
}
if req.QOS != "" {
updates["qos"] = req.QOS
}
if req.CPUs > 0 {
updates["cpus"] = req.CPUs
}
if req.Memory != "" {
updates["memory"] = req.Memory
}
if req.TimeLimit != "" {
updates["time_limit"] = req.TimeLimit
}
if len(updates) == 0 {
return nil // nothing to update
}
result := s.db.WithContext(ctx).Model(&model.JobTemplate{}).Where("id = ?", id).Updates(updates)
if result.Error != nil {
return result.Error
}
if result.RowsAffected == 0 {
return gorm.ErrRecordNotFound
}
return nil
}
// Delete removes a job template by ID. Idempotent — returns nil even if the row doesn't exist.
func (s *TemplateStore) Delete(ctx context.Context, id int64) error {
result := s.db.WithContext(ctx).Delete(&model.JobTemplate{}, id)
if result.Error != nil {
return result.Error
}
return nil
}

View File

@@ -1,205 +0,0 @@
package store
import (
"context"
"testing"
"gorm.io/driver/sqlite"
"gorm.io/gorm"
"gorm.io/gorm/logger"
"gcy_hpc_server/internal/model"
)
func newTestDB(t *testing.T) *gorm.DB {
t.Helper()
db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{
Logger: logger.Default.LogMode(logger.Silent),
})
if err != nil {
t.Fatalf("open sqlite: %v", err)
}
if err := db.AutoMigrate(&model.JobTemplate{}); err != nil {
t.Fatalf("auto migrate: %v", err)
}
return db
}
func TestTemplateStore_List(t *testing.T) {
db := newTestDB(t)
s := NewTemplateStore(db)
s.Create(context.Background(), &model.CreateTemplateRequest{Name: "job-1", Script: "echo 1"})
s.Create(context.Background(), &model.CreateTemplateRequest{Name: "job-2", Script: "echo 2"})
templates, total, err := s.List(context.Background(), 1, 10)
if err != nil {
t.Fatalf("List() error = %v", err)
}
if total != 2 {
t.Errorf("total = %d, want 2", total)
}
if len(templates) != 2 {
t.Fatalf("len(templates) = %d, want 2", len(templates))
}
// DESC order, so job-2 is first
if templates[0].Name != "job-2" {
t.Errorf("templates[0].Name = %q, want %q", templates[0].Name, "job-2")
}
}
func TestTemplateStore_List_Page2(t *testing.T) {
db := newTestDB(t)
s := NewTemplateStore(db)
for i := 0; i < 15; i++ {
s.Create(context.Background(), &model.CreateTemplateRequest{
Name: "job-" + string(rune('A'+i)), Script: "echo",
})
}
templates, total, err := s.List(context.Background(), 2, 10)
if err != nil {
t.Fatalf("List() error = %v", err)
}
if total != 15 {
t.Errorf("total = %d, want 15", total)
}
if len(templates) != 5 {
t.Fatalf("len(templates) = %d, want 5", len(templates))
}
}
func TestTemplateStore_GetByID(t *testing.T) {
db := newTestDB(t)
s := NewTemplateStore(db)
id, _ := s.Create(context.Background(), &model.CreateTemplateRequest{
Name: "test-job", Script: "echo hi", Partition: "batch", QOS: "normal", CPUs: 2, Memory: "4G",
})
tpl, err := s.GetByID(context.Background(), id)
if err != nil {
t.Fatalf("GetByID() error = %v", err)
}
if tpl == nil {
t.Fatal("GetByID() returned nil")
}
if tpl.Name != "test-job" {
t.Errorf("Name = %q, want %q", tpl.Name, "test-job")
}
if tpl.CPUs != 2 {
t.Errorf("CPUs = %d, want 2", tpl.CPUs)
}
}
func TestTemplateStore_GetByID_NotFound(t *testing.T) {
db := newTestDB(t)
s := NewTemplateStore(db)
tpl, err := s.GetByID(context.Background(), 999)
if err != nil {
t.Fatalf("GetByID() error = %v, want nil", err)
}
if tpl != nil {
t.Fatal("GetByID() should return nil for not found")
}
}
func TestTemplateStore_Create(t *testing.T) {
db := newTestDB(t)
s := NewTemplateStore(db)
id, err := s.Create(context.Background(), &model.CreateTemplateRequest{
Name: "new-job", Script: "echo", Partition: "gpu",
})
if err != nil {
t.Fatalf("Create() error = %v", err)
}
if id == 0 {
t.Fatal("Create() returned id=0")
}
}
func TestTemplateStore_Update(t *testing.T) {
db := newTestDB(t)
s := NewTemplateStore(db)
id, _ := s.Create(context.Background(), &model.CreateTemplateRequest{
Name: "old", Script: "echo",
})
err := s.Update(context.Background(), id, &model.UpdateTemplateRequest{
Name: "updated",
Script: "echo new",
CPUs: 8,
})
if err != nil {
t.Fatalf("Update() error = %v", err)
}
tpl, _ := s.GetByID(context.Background(), id)
if tpl.Name != "updated" {
t.Errorf("Name = %q, want %q", tpl.Name, "updated")
}
if tpl.CPUs != 8 {
t.Errorf("CPUs = %d, want 8", tpl.CPUs)
}
}
func TestTemplateStore_Update_Partial(t *testing.T) {
db := newTestDB(t)
s := NewTemplateStore(db)
id, _ := s.Create(context.Background(), &model.CreateTemplateRequest{
Name: "original", Script: "echo orig", Partition: "batch",
})
err := s.Update(context.Background(), id, &model.UpdateTemplateRequest{
Name: "renamed",
})
if err != nil {
t.Fatalf("Update() error = %v", err)
}
tpl, _ := s.GetByID(context.Background(), id)
if tpl.Name != "renamed" {
t.Errorf("Name = %q, want %q", tpl.Name, "renamed")
}
// Script and Partition should be unchanged
if tpl.Script != "echo orig" {
t.Errorf("Script = %q, want %q", tpl.Script, "echo orig")
}
if tpl.Partition != "batch" {
t.Errorf("Partition = %q, want %q", tpl.Partition, "batch")
}
}
func TestTemplateStore_Delete(t *testing.T) {
db := newTestDB(t)
s := NewTemplateStore(db)
id, _ := s.Create(context.Background(), &model.CreateTemplateRequest{
Name: "to-delete", Script: "echo",
})
err := s.Delete(context.Background(), id)
if err != nil {
t.Fatalf("Delete() error = %v", err)
}
tpl, _ := s.GetByID(context.Background(), id)
if tpl != nil {
t.Fatal("Delete() did not remove the record")
}
}
func TestTemplateStore_Delete_NotFound(t *testing.T) {
db := newTestDB(t)
s := NewTemplateStore(db)
err := s.Delete(context.Background(), 999)
if err != nil {
t.Fatalf("Delete() should not error for non-existent record, got: %v", err)
}
}

View File

@@ -0,0 +1,117 @@
package store
import (
"context"
"errors"
"fmt"
"time"
"gcy_hpc_server/internal/model"
"gorm.io/gorm"
"gorm.io/gorm/clause"
)
// UploadStore manages upload sessions and chunks with idempotent upsert support.
type UploadStore struct {
db *gorm.DB
}
// NewUploadStore creates a new UploadStore.
func NewUploadStore(db *gorm.DB) *UploadStore {
return &UploadStore{db: db}
}
// CreateSession inserts a new upload session.
func (s *UploadStore) CreateSession(ctx context.Context, session *model.UploadSession) error {
return s.db.WithContext(ctx).Create(session).Error
}
// GetSession returns an upload session by ID. Returns (nil, nil) if not found.
func (s *UploadStore) GetSession(ctx context.Context, id int64) (*model.UploadSession, error) {
var session model.UploadSession
err := s.db.WithContext(ctx).First(&session, id).Error
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, nil
}
if err != nil {
return nil, err
}
return &session, nil
}
// GetSessionWithChunks returns an upload session along with its chunks ordered by chunk_index.
func (s *UploadStore) GetSessionWithChunks(ctx context.Context, id int64) (*model.UploadSession, []model.UploadChunk, error) {
session, err := s.GetSession(ctx, id)
if err != nil {
return nil, nil, err
}
if session == nil {
return nil, nil, nil
}
var chunks []model.UploadChunk
if err := s.db.WithContext(ctx).Where("session_id = ?", id).Order("chunk_index ASC").Find(&chunks).Error; err != nil {
return nil, nil, err
}
return session, chunks, nil
}
// UpdateSessionStatus updates the status field of an upload session.
// Returns gorm.ErrRecordNotFound if no row was affected.
func (s *UploadStore) UpdateSessionStatus(ctx context.Context, id int64, status string) error {
result := s.db.WithContext(ctx).Model(&model.UploadSession{}).Where("id = ?", id).Update("status", status)
if result.Error != nil {
return result.Error
}
if result.RowsAffected == 0 {
return gorm.ErrRecordNotFound
}
return nil
}
// ListExpiredSessions returns sessions that are not in a terminal state and have expired.
func (s *UploadStore) ListExpiredSessions(ctx context.Context) ([]model.UploadSession, error) {
var sessions []model.UploadSession
err := s.db.WithContext(ctx).
Where("status NOT IN ?", []string{"completed", "cancelled", "expired"}).
Where("expires_at < ?", time.Now()).
Find(&sessions).Error
return sessions, err
}
// DeleteSession removes all chunks for a session, then the session itself.
func (s *UploadStore) DeleteSession(ctx context.Context, id int64) error {
if err := s.db.WithContext(ctx).Where("session_id = ?", id).Delete(&model.UploadChunk{}).Error; err != nil {
return fmt.Errorf("delete chunks: %w", err)
}
result := s.db.WithContext(ctx).Delete(&model.UploadSession{}, id)
return result.Error
}
// UpsertChunk inserts a chunk or updates it if the (session_id, chunk_index) pair already exists.
// Uses GORM clause.OnConflict for dialect-neutral upsert (works with both SQLite and MySQL).
func (s *UploadStore) UpsertChunk(ctx context.Context, chunk *model.UploadChunk) error {
return s.db.WithContext(ctx).Clauses(clause.OnConflict{
Columns: []clause.Column{{Name: "session_id"}, {Name: "chunk_index"}},
DoUpdates: clause.AssignmentColumns([]string{"minio_key", "sha256", "size", "status", "updated_at"}),
}).Create(chunk).Error
}
// GetUploadedChunkIndices returns the chunk indices that have been successfully uploaded.
func (s *UploadStore) GetUploadedChunkIndices(ctx context.Context, sessionID int64) ([]int, error) {
var indices []int
err := s.db.WithContext(ctx).Model(&model.UploadChunk{}).
Where("session_id = ? AND status = ?", sessionID, "uploaded").
Pluck("chunk_index", &indices).Error
return indices, err
}
// CountUploadedChunks returns the number of chunks with status "uploaded" for a session.
func (s *UploadStore) CountUploadedChunks(ctx context.Context, sessionID int64) (int, error) {
var count int64
err := s.db.WithContext(ctx).Model(&model.UploadChunk{}).
Where("session_id = ? AND status = ?", sessionID, "uploaded").
Count(&count).Error
return int(count), err
}
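The (session_id, chunk_index) semantics that `clause.OnConflict` gives `UpsertChunk` can be modeled with a plain map keyed on the composite uniqueness columns; a sketch with hypothetical `chunkKey`/`chunkRow` types, not the real model:

```go
package main

import "fmt"

// chunkKey is the composite key the upsert conflicts on.
type chunkKey struct {
	SessionID  int64
	ChunkIndex int
}

// chunkRow holds the columns the upsert overwrites on conflict.
type chunkRow struct {
	SHA256 string
	Size   int64
	Status string
}

// upsert inserts or overwrites: retrying the same chunk replaces the
// previous row instead of creating a duplicate.
func upsert(table map[chunkKey]chunkRow, k chunkKey, row chunkRow) {
	table[k] = row
}

func main() {
	table := map[chunkKey]chunkRow{}
	k := chunkKey{SessionID: 1, ChunkIndex: 0}
	upsert(table, k, chunkRow{SHA256: "hash_v1", Size: 1024, Status: "uploaded"})
	upsert(table, k, chunkRow{SHA256: "hash_v2", Size: 2048, Status: "uploaded"})
	fmt.Println(len(table), table[k].SHA256) // 1 hash_v2
}
```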

View File

@@ -0,0 +1,329 @@
package store
import (
"context"
"testing"
"time"
"gcy_hpc_server/internal/model"
"gorm.io/driver/sqlite"
"gorm.io/gorm"
"gorm.io/gorm/logger"
)
func setupUploadTestDB(t *testing.T) *gorm.DB {
t.Helper()
db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{
Logger: logger.Default.LogMode(logger.Silent),
})
if err != nil {
t.Fatalf("open sqlite: %v", err)
}
if err := db.AutoMigrate(&model.UploadSession{}, &model.UploadChunk{}); err != nil {
t.Fatalf("migrate: %v", err)
}
return db
}
func newTestSession(status string, expiresAt time.Time) *model.UploadSession {
return &model.UploadSession{
FileName: "test.bin",
FileSize: 100 * 1024 * 1024,
ChunkSize: 16 << 20,
TotalChunks: 7,
SHA256: "abc123",
Status: status,
MinioPrefix: "uploads/1/",
ExpiresAt: expiresAt,
}
}
func newTestChunk(sessionID int64, index int, status string) *model.UploadChunk {
return &model.UploadChunk{
SessionID: sessionID,
ChunkIndex: index,
MinioKey: "uploads/1/chunk_00000",
SHA256: "chunk_hash_0",
Size: 16 << 20,
Status: status,
}
}
func TestUploadStore_CreateSession(t *testing.T) {
db := setupUploadTestDB(t)
s := NewUploadStore(db)
session := newTestSession("pending", time.Now().Add(48*time.Hour))
if err := s.CreateSession(context.Background(), session); err != nil {
t.Fatalf("CreateSession() error = %v", err)
}
if session.ID <= 0 {
t.Errorf("ID = %d, want positive", session.ID)
}
if session.FileName != "test.bin" {
t.Errorf("FileName = %q, want %q", session.FileName, "test.bin")
}
if session.ExpiresAt.IsZero() {
t.Error("ExpiresAt is zero, want a real timestamp")
}
}
func TestUploadStore_GetSession(t *testing.T) {
db := setupUploadTestDB(t)
s := NewUploadStore(db)
session := newTestSession("pending", time.Now().Add(48*time.Hour))
s.CreateSession(context.Background(), session)
got, err := s.GetSession(context.Background(), session.ID)
if err != nil {
t.Fatalf("GetSession() error = %v", err)
}
if got == nil {
t.Fatal("GetSession() returned nil")
}
if got.FileName != session.FileName {
t.Errorf("FileName = %q, want %q", got.FileName, session.FileName)
}
if got.Status != "pending" {
t.Errorf("Status = %q, want %q", got.Status, "pending")
}
}
func TestUploadStore_GetSession_NotFound(t *testing.T) {
db := setupUploadTestDB(t)
s := NewUploadStore(db)
got, err := s.GetSession(context.Background(), 99999)
if err != nil {
t.Fatalf("GetSession() error = %v", err)
}
if got != nil {
t.Error("GetSession() expected nil for not-found")
}
}
func TestUploadStore_UpdateSessionStatus(t *testing.T) {
db := setupUploadTestDB(t)
s := NewUploadStore(db)
session := newTestSession("pending", time.Now().Add(48*time.Hour))
s.CreateSession(context.Background(), session)
if err := s.UpdateSessionStatus(context.Background(), session.ID, "uploading"); err != nil {
t.Fatalf("UpdateSessionStatus() error = %v", err)
}
got, _ := s.GetSession(context.Background(), session.ID)
if got.Status != "uploading" {
t.Errorf("Status = %q, want %q", got.Status, "uploading")
}
}
func TestUploadStore_UpdateSessionStatus_NotFound(t *testing.T) {
db := setupUploadTestDB(t)
s := NewUploadStore(db)
err := s.UpdateSessionStatus(context.Background(), 99999, "uploading")
if err == nil {
t.Fatal("UpdateSessionStatus() expected error for not-found, got nil")
}
}
func TestUploadStore_GetSessionWithChunks(t *testing.T) {
db := setupUploadTestDB(t)
s := NewUploadStore(db)
session := newTestSession("uploading", time.Now().Add(48*time.Hour))
s.CreateSession(context.Background(), session)
s.UpsertChunk(context.Background(), &model.UploadChunk{
SessionID: session.ID, ChunkIndex: 0, MinioKey: "uploads/1/chunk_0", SHA256: "h0", Size: 16 << 20, Status: "uploaded",
})
s.UpsertChunk(context.Background(), &model.UploadChunk{
SessionID: session.ID, ChunkIndex: 2, MinioKey: "uploads/1/chunk_2", SHA256: "h2", Size: 16 << 20, Status: "uploaded",
})
gotSession, chunks, err := s.GetSessionWithChunks(context.Background(), session.ID)
if err != nil {
t.Fatalf("GetSessionWithChunks() error = %v", err)
}
if gotSession == nil {
t.Fatal("GetSessionWithChunks() session is nil")
}
if len(chunks) != 2 {
t.Fatalf("len(chunks) = %d, want 2", len(chunks))
}
if chunks[0].ChunkIndex != 0 {
t.Errorf("chunks[0].ChunkIndex = %d, want 0", chunks[0].ChunkIndex)
}
if chunks[1].ChunkIndex != 2 {
t.Errorf("chunks[1].ChunkIndex = %d, want 2", chunks[1].ChunkIndex)
}
}
func TestUploadStore_GetSessionWithChunks_NotFound(t *testing.T) {
db := setupUploadTestDB(t)
s := NewUploadStore(db)
gotSession, chunks, err := s.GetSessionWithChunks(context.Background(), 99999)
if err != nil {
t.Fatalf("GetSessionWithChunks() error = %v", err)
}
if gotSession != nil {
t.Error("expected nil session for not-found")
}
if chunks != nil {
t.Error("expected nil chunks for not-found")
}
}
func TestUploadStore_UpsertChunk_Idempotent(t *testing.T) {
db := setupUploadTestDB(t)
s := NewUploadStore(db)
session := newTestSession("uploading", time.Now().Add(48*time.Hour))
s.CreateSession(context.Background(), session)
chunk := &model.UploadChunk{
SessionID: session.ID, ChunkIndex: 0, MinioKey: "uploads/1/chunk_0", SHA256: "hash_v1", Size: 1024, Status: "uploaded",
}
if err := s.UpsertChunk(context.Background(), chunk); err != nil {
t.Fatalf("first UpsertChunk() error = %v", err)
}
chunk2 := &model.UploadChunk{
SessionID: session.ID, ChunkIndex: 0, MinioKey: "uploads/1/chunk_0", SHA256: "hash_v2", Size: 2048, Status: "uploaded",
}
if err := s.UpsertChunk(context.Background(), chunk2); err != nil {
t.Fatalf("second UpsertChunk() error = %v", err)
}
indices, _ := s.GetUploadedChunkIndices(context.Background(), session.ID)
if len(indices) != 1 {
t.Errorf("len(indices) = %d, want 1 (idempotent)", len(indices))
}
var got model.UploadChunk
db.Where("session_id = ? AND chunk_index = ?", session.ID, 0).First(&got)
if got.SHA256 != "hash_v2" {
t.Errorf("SHA256 = %q, want %q (updated on conflict)", got.SHA256, "hash_v2")
}
}
func TestUploadStore_GetUploadedChunkIndices(t *testing.T) {
db := setupUploadTestDB(t)
s := NewUploadStore(db)
session := newTestSession("uploading", time.Now().Add(48*time.Hour))
s.CreateSession(context.Background(), session)
s.UpsertChunk(context.Background(), &model.UploadChunk{
SessionID: session.ID, ChunkIndex: 0, MinioKey: "k0", Size: 100, Status: "uploaded",
})
s.UpsertChunk(context.Background(), &model.UploadChunk{
SessionID: session.ID, ChunkIndex: 1, MinioKey: "k1", Size: 100, Status: "pending",
})
s.UpsertChunk(context.Background(), &model.UploadChunk{
SessionID: session.ID, ChunkIndex: 2, MinioKey: "k2", Size: 100, Status: "uploaded",
})
indices, err := s.GetUploadedChunkIndices(context.Background(), session.ID)
if err != nil {
t.Fatalf("GetUploadedChunkIndices() error = %v", err)
}
if len(indices) != 2 {
t.Fatalf("len(indices) = %d, want 2", len(indices))
}
if indices[0] != 0 || indices[1] != 2 {
t.Errorf("indices = %v, want [0 2]", indices)
}
}
func TestUploadStore_CountUploadedChunks(t *testing.T) {
db := setupUploadTestDB(t)
s := NewUploadStore(db)
session := newTestSession("uploading", time.Now().Add(48*time.Hour))
s.CreateSession(context.Background(), session)
s.UpsertChunk(context.Background(), &model.UploadChunk{
SessionID: session.ID, ChunkIndex: 0, MinioKey: "k0", Size: 100, Status: "uploaded",
})
s.UpsertChunk(context.Background(), &model.UploadChunk{
SessionID: session.ID, ChunkIndex: 1, MinioKey: "k1", Size: 100, Status: "uploaded",
})
s.UpsertChunk(context.Background(), &model.UploadChunk{
SessionID: session.ID, ChunkIndex: 2, MinioKey: "k2", Size: 100, Status: "pending",
})
count, err := s.CountUploadedChunks(context.Background(), session.ID)
if err != nil {
t.Fatalf("CountUploadedChunks() error = %v", err)
}
if count != 2 {
t.Errorf("count = %d, want 2", count)
}
}
func TestUploadStore_ListExpiredSessions(t *testing.T) {
db := setupUploadTestDB(t)
s := NewUploadStore(db)
past := time.Now().Add(-1 * time.Hour)
future := time.Now().Add(48 * time.Hour)
expired := newTestSession("pending", past)
expired.FileName = "expired.bin"
s.CreateSession(context.Background(), expired)
active := newTestSession("uploading", future)
active.FileName = "active.bin"
s.CreateSession(context.Background(), active)
completed := newTestSession("completed", past)
completed.FileName = "completed.bin"
s.CreateSession(context.Background(), completed)
sessions, err := s.ListExpiredSessions(context.Background())
if err != nil {
t.Fatalf("ListExpiredSessions() error = %v", err)
}
if len(sessions) != 1 {
t.Fatalf("len(sessions) = %d, want 1", len(sessions))
}
if sessions[0].FileName != "expired.bin" {
t.Errorf("FileName = %q, want %q", sessions[0].FileName, "expired.bin")
}
}
func TestUploadStore_DeleteSession(t *testing.T) {
db := setupUploadTestDB(t)
s := NewUploadStore(db)
session := newTestSession("uploading", time.Now().Add(48*time.Hour))
s.CreateSession(context.Background(), session)
s.UpsertChunk(context.Background(), &model.UploadChunk{
SessionID: session.ID, ChunkIndex: 0, MinioKey: "k0", Size: 100, Status: "uploaded",
})
s.UpsertChunk(context.Background(), &model.UploadChunk{
SessionID: session.ID, ChunkIndex: 1, MinioKey: "k1", Size: 100, Status: "uploaded",
})
if err := s.DeleteSession(context.Background(), session.ID); err != nil {
t.Fatalf("DeleteSession() error = %v", err)
}
got, _ := s.GetSession(context.Background(), session.ID)
if got != nil {
t.Error("session still exists after delete")
}
var chunkCount int64
db.Model(&model.UploadChunk{}).Where("session_id = ?", session.ID).Count(&chunkCount)
if chunkCount != 0 {
t.Errorf("chunkCount = %d, want 0 after delete", chunkCount)
}
}

View File

@@ -0,0 +1,239 @@
// Package mockminio provides an in-memory implementation of storage.ObjectStorage
// for use in tests. It is thread-safe and supports Range reads.
package mockminio
import (
"bytes"
"context"
"crypto/sha256"
"encoding/hex"
"fmt"
"io"
"sort"
"strings"
"sync"
"time"
"gcy_hpc_server/internal/storage"
)
// Compile-time interface check.
var _ storage.ObjectStorage = (*InMemoryStorage)(nil)
// objectMeta holds metadata for a stored object.
type objectMeta struct {
size int64
etag string
lastModified time.Time
contentType string
}
// InMemoryStorage is a thread-safe, in-memory implementation of
// storage.ObjectStorage. All data is kept in memory; no network or disk I/O
// is performed.
type InMemoryStorage struct {
mu sync.RWMutex
objects map[string][]byte
meta map[string]objectMeta
buckets map[string]bool
}
// NewInMemoryStorage returns a ready-to-use InMemoryStorage.
func NewInMemoryStorage() *InMemoryStorage {
return &InMemoryStorage{
objects: make(map[string][]byte),
meta: make(map[string]objectMeta),
buckets: make(map[string]bool),
}
}
// PutObject reads all bytes from reader and stores them under key.
// The ETag is the SHA-256 hash of the data, formatted as hex.
func (s *InMemoryStorage) PutObject(_ context.Context, _, key string, reader io.Reader, _ int64, opts storage.PutObjectOptions) (storage.UploadInfo, error) {
data, err := io.ReadAll(reader)
if err != nil {
return storage.UploadInfo{}, fmt.Errorf("read all: %w", err)
}
h := sha256.Sum256(data)
etag := hex.EncodeToString(h[:])
s.mu.Lock()
s.objects[key] = data
s.meta[key] = objectMeta{
size: int64(len(data)),
etag: etag,
lastModified: time.Now(),
contentType: opts.ContentType,
}
s.mu.Unlock()
return storage.UploadInfo{ETag: etag, Size: int64(len(data))}, nil
}
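The ETag that PutObject (and ComposeObject below) stores is just the hex-encoded SHA-256 of the full contents; in isolation:

```go
package main

import (
	"crypto/sha256"
	"encoding/hex"
	"fmt"
)

// etag computes the same value PutObject stores: hex-encoded SHA-256
// of the object's full contents.
func etag(data []byte) string {
	h := sha256.Sum256(data)
	return hex.EncodeToString(h[:])
}

func main() {
	fmt.Println(etag([]byte("hello")))
	// 2cf24dba5fb0a30e26e83b2ac5b9e29e1b161e5c1fa7425e73043362938b9824
}
```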
// GetObject retrieves an object. opts.Start and opts.End control byte-range
// reads. Four cases are supported:
// 1. No range (both nil) → return entire object
// 2. Start only (End nil) → from start to end of object
// 3. End only (Start nil) → from byte 0 to end
// 4. Start + End → standard byte range
func (s *InMemoryStorage) GetObject(_ context.Context, _, key string, opts storage.GetOptions) (io.ReadCloser, storage.ObjectInfo, error) {
s.mu.RLock()
data, ok := s.objects[key]
meta := s.meta[key]
s.mu.RUnlock()
if !ok {
return nil, storage.ObjectInfo{}, fmt.Errorf("object %s not found", key)
}
size := int64(len(data))
// Full object info (Size is always the total object size).
info := storage.ObjectInfo{
Key: key,
Size: size,
ETag: meta.etag,
LastModified: meta.lastModified,
ContentType: meta.contentType,
}
// No range requested → return everything.
if opts.Start == nil && opts.End == nil {
return io.NopCloser(bytes.NewReader(data)), info, nil
}
// Build range. Check each pointer individually to avoid nil dereference.
start := int64(0)
if opts.Start != nil {
start = *opts.Start
}
end := size - 1
if opts.End != nil {
end = *opts.End
}
// Clamp end to last byte.
if end >= size {
end = size - 1
}
if start > end || start < 0 {
return nil, storage.ObjectInfo{}, fmt.Errorf("invalid range: start=%d, end=%d, size=%d", start, end, size)
}
section := io.NewSectionReader(bytes.NewReader(data), start, end-start+1)
return io.NopCloser(section), info, nil
}
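The four range cases enumerated in the GetObject comment reduce to a small pure function. The sketch below is a simplified, standalone copy of that clamping logic for illustration (`resolveRange` is hypothetical, not the package's API):

```go
package main

import "fmt"

// resolveRange mirrors GetObject's range handling: nil pointers default
// to the object boundaries, end is clamped to the last byte, and a
// negative or inverted range is rejected. Returns the inclusive byte
// range [s, e], or ok=false for an invalid range.
func resolveRange(start, end *int64, size int64) (s, e int64, ok bool) {
	s = 0
	if start != nil {
		s = *start
	}
	e = size - 1
	if end != nil {
		e = *end
	}
	if e >= size { // clamp to last byte
		e = size - 1
	}
	if s > e || s < 0 {
		return 0, 0, false
	}
	return s, e, true
}

func main() {
	five := int64(5)
	s, e, _ := resolveRange(&five, nil, 10) // case 2: start only
	fmt.Println(s, e)                       // → 5 9
}
```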
// ComposeObject concatenates source objects (in order) into dst.
func (s *InMemoryStorage) ComposeObject(_ context.Context, _, dst string, sources []string) (storage.UploadInfo, error) {
s.mu.Lock()
defer s.mu.Unlock()
var buf bytes.Buffer
for _, src := range sources {
data, ok := s.objects[src]
if !ok {
return storage.UploadInfo{}, fmt.Errorf("source object %s not found", src)
}
buf.Write(data)
}
combined := buf.Bytes()
h := sha256.Sum256(combined)
etag := hex.EncodeToString(h[:])
s.objects[dst] = combined
s.meta[dst] = objectMeta{
size: int64(len(combined)),
etag: etag,
lastModified: time.Now(),
}
return storage.UploadInfo{ETag: etag, Size: int64(len(combined))}, nil
}
// RemoveObject deletes a single object.
func (s *InMemoryStorage) RemoveObject(_ context.Context, _, key string, _ storage.RemoveObjectOptions) error {
s.mu.Lock()
delete(s.objects, key)
delete(s.meta, key)
s.mu.Unlock()
return nil
}
// RemoveObjects deletes multiple objects by key.
func (s *InMemoryStorage) RemoveObjects(_ context.Context, _ string, keys []string, _ storage.RemoveObjectsOptions) error {
s.mu.Lock()
for _, k := range keys {
delete(s.objects, k)
delete(s.meta, k)
}
s.mu.Unlock()
return nil
}
// ListObjects returns info for all objects whose key matches prefix, sorted
// by key. The recursive flag is ignored; results are always fully recursive.
func (s *InMemoryStorage) ListObjects(_ context.Context, _, prefix string, _ bool) ([]storage.ObjectInfo, error) {
s.mu.RLock()
defer s.mu.RUnlock()
var result []storage.ObjectInfo
for k, m := range s.meta {
if strings.HasPrefix(k, prefix) {
result = append(result, storage.ObjectInfo{
Key: k,
Size: m.size,
ETag: m.etag,
LastModified: m.lastModified,
ContentType: m.contentType,
})
}
}
sort.Slice(result, func(i, j int) bool { return result[i].Key < result[j].Key })
return result, nil
}
// BucketExists reports whether the named bucket exists.
func (s *InMemoryStorage) BucketExists(_ context.Context, bucket string) (bool, error) {
s.mu.RLock()
defer s.mu.RUnlock()
return s.buckets[bucket], nil
}
// MakeBucket creates a bucket.
func (s *InMemoryStorage) MakeBucket(_ context.Context, bucket string, _ storage.MakeBucketOptions) error {
s.mu.Lock()
s.buckets[bucket] = true
s.mu.Unlock()
return nil
}
// StatObject returns metadata about an object without downloading it.
func (s *InMemoryStorage) StatObject(_ context.Context, _, key string, _ storage.StatObjectOptions) (storage.ObjectInfo, error) {
s.mu.RLock()
m, ok := s.meta[key]
s.mu.RUnlock()
if !ok {
return storage.ObjectInfo{}, fmt.Errorf("object %s not found", key)
}
return storage.ObjectInfo{
Key: key,
Size: m.size,
ETag: m.etag,
LastModified: m.lastModified,
ContentType: m.contentType,
}, nil
}
// AbortMultipartUpload is a no-op for the in-memory implementation.
func (s *InMemoryStorage) AbortMultipartUpload(_ context.Context, _, _, _ string) error {
return nil
}
// RemoveIncompleteUpload is a no-op for the in-memory implementation.
func (s *InMemoryStorage) RemoveIncompleteUpload(_ context.Context, _, _ string) error {
return nil
}


@@ -0,0 +1,378 @@
package mockminio
import (
"bytes"
"context"
"crypto/sha256"
"encoding/hex"
"fmt"
"io"
"sync"
"testing"
"gcy_hpc_server/internal/storage"
)
func sha256Hex(data []byte) string {
h := sha256.Sum256(data)
return hex.EncodeToString(h[:])
}
func TestNewInMemoryStorage_ReturnsInitialized(t *testing.T) {
s := NewInMemoryStorage()
if s == nil {
t.Fatal("expected non-nil storage")
}
}
func TestPutObject_StoresData(t *testing.T) {
s := NewInMemoryStorage()
ctx := context.Background()
data := []byte("hello world")
info, err := s.PutObject(ctx, "bucket", "key1", bytes.NewReader(data), int64(len(data)), storage.PutObjectOptions{ContentType: "text/plain"})
if err != nil {
t.Fatalf("PutObject: %v", err)
}
wantETag := sha256Hex(data)
if info.ETag != wantETag {
t.Errorf("ETag = %q, want %q", info.ETag, wantETag)
}
if info.Size != int64(len(data)) {
t.Errorf("Size = %d, want %d", info.Size, len(data))
}
}
func TestGetObject_FullObject(t *testing.T) {
s := NewInMemoryStorage()
ctx := context.Background()
data := []byte("hello world")
s.PutObject(ctx, "bucket", "key1", bytes.NewReader(data), int64(len(data)), storage.PutObjectOptions{})
rc, info, err := s.GetObject(ctx, "bucket", "key1", storage.GetOptions{})
if err != nil {
t.Fatalf("GetObject: %v", err)
}
defer rc.Close()
got, err := io.ReadAll(rc)
if err != nil {
t.Fatalf("ReadAll: %v", err)
}
if !bytes.Equal(got, data) {
t.Errorf("got %q, want %q", got, data)
}
if info.Size != int64(len(data)) {
t.Errorf("info.Size = %d, want %d", info.Size, len(data))
}
}
func TestGetObject_RangeStartOnly(t *testing.T) {
s := NewInMemoryStorage()
ctx := context.Background()
data := []byte("0123456789")
s.PutObject(ctx, "bucket", "key1", bytes.NewReader(data), int64(len(data)), storage.PutObjectOptions{})
start := int64(5)
rc, _, err := s.GetObject(ctx, "bucket", "key1", storage.GetOptions{Start: &start})
if err != nil {
t.Fatalf("GetObject: %v", err)
}
defer rc.Close()
got, err := io.ReadAll(rc)
if err != nil {
t.Fatalf("ReadAll: %v", err)
}
want := data[5:]
if !bytes.Equal(got, want) {
t.Errorf("got %q, want %q", got, want)
}
}
func TestGetObject_RangeEndOnly(t *testing.T) {
s := NewInMemoryStorage()
ctx := context.Background()
data := []byte("0123456789")
s.PutObject(ctx, "bucket", "key1", bytes.NewReader(data), int64(len(data)), storage.PutObjectOptions{})
end := int64(4)
rc, _, err := s.GetObject(ctx, "bucket", "key1", storage.GetOptions{End: &end})
if err != nil {
t.Fatalf("GetObject: %v", err)
}
defer rc.Close()
got, err := io.ReadAll(rc)
if err != nil {
t.Fatalf("ReadAll: %v", err)
}
want := data[:5]
if !bytes.Equal(got, want) {
t.Errorf("got %q, want %q", got, want)
}
}
func TestGetObject_RangeStartAndEnd(t *testing.T) {
s := NewInMemoryStorage()
ctx := context.Background()
data := []byte("0123456789")
s.PutObject(ctx, "bucket", "key1", bytes.NewReader(data), int64(len(data)), storage.PutObjectOptions{})
start := int64(2)
end := int64(5)
rc, _, err := s.GetObject(ctx, "bucket", "key1", storage.GetOptions{Start: &start, End: &end})
if err != nil {
t.Fatalf("GetObject: %v", err)
}
defer rc.Close()
got, err := io.ReadAll(rc)
if err != nil {
t.Fatalf("ReadAll: %v", err)
}
want := data[2:6]
if !bytes.Equal(got, want) {
t.Errorf("got %q, want %q", got, want)
}
}
func TestGetObject_NotFound(t *testing.T) {
s := NewInMemoryStorage()
ctx := context.Background()
_, _, err := s.GetObject(ctx, "bucket", "nonexistent", storage.GetOptions{})
if err == nil {
t.Fatal("expected error for missing object")
}
}
func TestComposeObject_ConcatenatesSources(t *testing.T) {
s := NewInMemoryStorage()
ctx := context.Background()
s.PutObject(ctx, "bucket", "part1", bytes.NewReader([]byte("hello ")), 6, storage.PutObjectOptions{})
s.PutObject(ctx, "bucket", "part2", bytes.NewReader([]byte("world")), 5, storage.PutObjectOptions{})
info, err := s.ComposeObject(ctx, "bucket", "combined", []string{"part1", "part2"})
if err != nil {
t.Fatalf("ComposeObject: %v", err)
}
want := []byte("hello world")
if info.Size != int64(len(want)) {
t.Errorf("Size = %d, want %d", info.Size, len(want))
}
wantETag := sha256Hex(want)
if info.ETag != wantETag {
t.Errorf("ETag = %q, want %q", info.ETag, wantETag)
}
rc, _, err := s.GetObject(ctx, "bucket", "combined", storage.GetOptions{})
if err != nil {
t.Fatalf("GetObject combined: %v", err)
}
defer rc.Close()
got, _ := io.ReadAll(rc)
if !bytes.Equal(got, want) {
t.Errorf("got %q, want %q", got, want)
}
}
func TestComposeObject_MissingSource(t *testing.T) {
s := NewInMemoryStorage()
ctx := context.Background()
_, err := s.ComposeObject(ctx, "bucket", "dst", []string{"missing"})
if err == nil {
t.Fatal("expected error for missing source")
}
}
func TestRemoveObject(t *testing.T) {
s := NewInMemoryStorage()
ctx := context.Background()
s.PutObject(ctx, "bucket", "key1", bytes.NewReader([]byte("data")), 4, storage.PutObjectOptions{})
err := s.RemoveObject(ctx, "bucket", "key1", storage.RemoveObjectOptions{})
if err != nil {
t.Fatalf("RemoveObject: %v", err)
}
_, _, err = s.GetObject(ctx, "bucket", "key1", storage.GetOptions{})
if err == nil {
t.Fatal("expected error after removal")
}
}
func TestRemoveObjects(t *testing.T) {
s := NewInMemoryStorage()
ctx := context.Background()
for i := 0; i < 5; i++ {
key := fmt.Sprintf("key%d", i)
s.PutObject(ctx, "bucket", key, bytes.NewReader([]byte(key)), int64(len(key)), storage.PutObjectOptions{})
}
err := s.RemoveObjects(ctx, "bucket", []string{"key1", "key3"}, storage.RemoveObjectsOptions{})
if err != nil {
t.Fatalf("RemoveObjects: %v", err)
}
objects, _ := s.ListObjects(ctx, "bucket", "", true)
if len(objects) != 3 {
t.Errorf("got %d objects, want 3", len(objects))
}
}
func TestListObjects(t *testing.T) {
s := NewInMemoryStorage()
ctx := context.Background()
s.PutObject(ctx, "bucket", "dir/a", bytes.NewReader([]byte("a")), 1, storage.PutObjectOptions{})
s.PutObject(ctx, "bucket", "dir/b", bytes.NewReader([]byte("bb")), 2, storage.PutObjectOptions{})
s.PutObject(ctx, "bucket", "other/c", bytes.NewReader([]byte("ccc")), 3, storage.PutObjectOptions{})
objects, err := s.ListObjects(ctx, "bucket", "dir/", true)
if err != nil {
t.Fatalf("ListObjects: %v", err)
}
if len(objects) != 2 {
t.Fatalf("got %d objects, want 2", len(objects))
}
if objects[0].Key != "dir/a" || objects[1].Key != "dir/b" {
t.Errorf("keys = %v, want [dir/a dir/b]", []string{objects[0].Key, objects[1].Key})
}
if objects[0].Size != 1 || objects[1].Size != 2 {
t.Errorf("sizes = %v, want [1 2]", []int64{objects[0].Size, objects[1].Size})
}
}
func TestBucketExists(t *testing.T) {
s := NewInMemoryStorage()
ctx := context.Background()
ok, _ := s.BucketExists(ctx, "mybucket")
if ok {
t.Error("bucket should not exist yet")
}
s.MakeBucket(ctx, "mybucket", storage.MakeBucketOptions{})
ok, _ = s.BucketExists(ctx, "mybucket")
if !ok {
t.Error("bucket should exist after MakeBucket")
}
}
func TestMakeBucket(t *testing.T) {
s := NewInMemoryStorage()
ctx := context.Background()
err := s.MakeBucket(ctx, "test-bucket", storage.MakeBucketOptions{Region: "us-east-1"})
if err != nil {
t.Fatalf("MakeBucket: %v", err)
}
ok, _ := s.BucketExists(ctx, "test-bucket")
if !ok {
t.Error("bucket should exist")
}
}
func TestStatObject(t *testing.T) {
s := NewInMemoryStorage()
ctx := context.Background()
data := []byte("test data")
s.PutObject(ctx, "bucket", "key1", bytes.NewReader(data), int64(len(data)), storage.PutObjectOptions{ContentType: "text/plain"})
info, err := s.StatObject(ctx, "bucket", "key1", storage.StatObjectOptions{})
if err != nil {
t.Fatalf("StatObject: %v", err)
}
wantETag := sha256Hex(data)
if info.Key != "key1" {
t.Errorf("Key = %q, want %q", info.Key, "key1")
}
if info.Size != int64(len(data)) {
t.Errorf("Size = %d, want %d", info.Size, len(data))
}
if info.ETag != wantETag {
t.Errorf("ETag = %q, want %q", info.ETag, wantETag)
}
if info.ContentType != "text/plain" {
t.Errorf("ContentType = %q, want %q", info.ContentType, "text/plain")
}
if info.LastModified.IsZero() {
t.Error("LastModified should not be zero")
}
}
func TestStatObject_NotFound(t *testing.T) {
s := NewInMemoryStorage()
ctx := context.Background()
_, err := s.StatObject(ctx, "bucket", "nonexistent", storage.StatObjectOptions{})
if err == nil {
t.Fatal("expected error for missing object")
}
}
func TestAbortMultipartUpload(t *testing.T) {
s := NewInMemoryStorage()
ctx := context.Background()
err := s.AbortMultipartUpload(ctx, "bucket", "key", "upload-id")
if err != nil {
t.Fatalf("AbortMultipartUpload: %v", err)
}
}
func TestRemoveIncompleteUpload(t *testing.T) {
s := NewInMemoryStorage()
ctx := context.Background()
err := s.RemoveIncompleteUpload(ctx, "bucket", "key")
if err != nil {
t.Fatalf("RemoveIncompleteUpload: %v", err)
}
}
func TestConcurrentAccess(t *testing.T) {
s := NewInMemoryStorage()
ctx := context.Background()
const goroutines = 50
var wg sync.WaitGroup
wg.Add(goroutines * 2)
for i := 0; i < goroutines; i++ {
go func(i int) {
defer wg.Done()
key := fmt.Sprintf("key%d", i%10)
data := []byte(fmt.Sprintf("data-%d", i))
s.PutObject(ctx, "bucket", key, bytes.NewReader(data), int64(len(data)), storage.PutObjectOptions{})
}(i)
go func(i int) {
defer wg.Done()
key := fmt.Sprintf("key%d", i%10)
rc, _, _ := s.GetObject(ctx, "bucket", key, storage.GetOptions{})
if rc != nil {
rc.Close()
}
}(i)
}
wg.Wait()
}


@@ -0,0 +1,544 @@
// Package mockslurm provides a complete HTTP mock server for the Slurm REST API.
// It supports all 11 endpoints (P0: 4 job + P1: 7 cluster/history) and includes
// job eviction from active to history queue on terminal states.
package mockslurm
import (
"encoding/json"
"net/http"
"net/http/httptest"
"strconv"
"strings"
"sync"
"time"
"gcy_hpc_server/internal/slurm"
)
// MockJob represents a job tracked by the mock server.
type MockJob struct {
JobID int32
Name string
State string // single state string for internal tracking
Script string
Partition string
WorkDir string
SubmitTime time.Time
StartTime *time.Time
EndTime *time.Time
ExitCode *int32
}
// MockNode represents a node tracked by the mock server.
type MockNode struct {
Name string
}
// MockPartition represents a partition tracked by the mock server.
type MockPartition struct {
Name string
}
// MockSlurm is the mock Slurm API server controller.
type MockSlurm struct {
mu sync.RWMutex
activeJobs map[int32]*MockJob
historyJobs map[int32]*MockJob
nodes []MockNode
partitions []MockPartition
nextID int32
server *httptest.Server
}
// NewMockSlurmServer creates and starts a mock Slurm REST API server.
// Returns the httptest.Server and the MockSlurm controller.
func NewMockSlurmServer() (*httptest.Server, *MockSlurm) {
m := &MockSlurm{
activeJobs: make(map[int32]*MockJob),
historyJobs: make(map[int32]*MockJob),
nodes: []MockNode{
{Name: "node01"},
{Name: "node02"},
{Name: "node03"},
},
partitions: []MockPartition{
{Name: "normal"},
{Name: "gpu"},
},
nextID: 1,
}
mux := http.NewServeMux()
// P0: exact paths. ServeMux's longest-pattern match gives these precedence
// over the /job/ prefix pattern registered below, regardless of order.
mux.HandleFunc("/slurm/v0.0.40/job/submit", m.handleJobSubmit)
mux.HandleFunc("/slurm/v0.0.40/jobs", m.handleGetJobs)
// P0: Prefix path for /job/{id} — GET and DELETE
mux.HandleFunc("/slurm/v0.0.40/job/", m.handleJobByID)
// P1: Cluster endpoints
mux.HandleFunc("/slurm/v0.0.40/nodes", m.handleGetNodes)
mux.HandleFunc("/slurm/v0.0.40/node/", m.handleGetNode)
mux.HandleFunc("/slurm/v0.0.40/partitions", m.handleGetPartitions)
mux.HandleFunc("/slurm/v0.0.40/partition/", m.handleGetPartition)
mux.HandleFunc("/slurm/v0.0.40/diag", m.handleDiag)
// P1: SlurmDB endpoints
mux.HandleFunc("/slurmdb/v0.0.40/jobs", m.handleSlurmdbJobs)
mux.HandleFunc("/slurmdb/v0.0.40/job/", m.handleSlurmdbJob)
srv := httptest.NewServer(mux)
m.server = srv
return srv, m
}
// Server returns the underlying httptest.Server.
func (m *MockSlurm) Server() *httptest.Server {
return m.server
}
// ---------------------------------------------------------------------------
// Controller methods
// ---------------------------------------------------------------------------
// SetJobState transitions a job to the given state.
// Terminal states (COMPLETED/FAILED/CANCELLED/TIMEOUT) evict the job from
// activeJobs into historyJobs. RUNNING sets StartTime and stays active.
// PENDING stays in activeJobs.
func (m *MockSlurm) SetJobState(id int32, state string) {
m.mu.Lock()
defer m.mu.Unlock()
mj, ok := m.activeJobs[id]
if !ok {
return
}
now := time.Now()
switch state {
case "RUNNING":
mj.State = state
mj.StartTime = &now
case "COMPLETED", "FAILED", "CANCELLED", "TIMEOUT":
mj.State = state
mj.EndTime = &now
exitCode := int32(0)
if state != "COMPLETED" {
exitCode = 1
}
mj.ExitCode = &exitCode
delete(m.activeJobs, id)
m.historyJobs[id] = mj
case "PENDING":
mj.State = state
}
}
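The active-to-history eviction in SetJobState can be sketched as a tiny state machine on two maps. A standalone illustration under the same terminal-state set (`terminal` is a hypothetical helper):

```go
package main

import "fmt"

// terminal reports whether a Slurm state evicts a job from the active
// queue into history, matching SetJobState's switch above.
func terminal(state string) bool {
	switch state {
	case "COMPLETED", "FAILED", "CANCELLED", "TIMEOUT":
		return true
	}
	return false
}

func main() {
	active := map[int32]string{1: "RUNNING"}
	history := map[int32]string{}
	// A terminal transition moves the job between queues;
	// RUNNING and PENDING would leave it in active.
	if terminal("COMPLETED") {
		history[1] = "COMPLETED"
		delete(active, 1)
	}
	fmt.Println(len(active), len(history)) // → 0 1
}
```

This split is what lets the active endpoints (`/slurm/.../jobs`) and the history endpoints (`/slurmdb/.../jobs`) serve disjoint views of the same job population.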
// GetJobState returns the current state of the job with the given ID.
// Returns empty string if the job is not found.
func (m *MockSlurm) GetJobState(id int32) string {
m.mu.RLock()
defer m.mu.RUnlock()
if mj, ok := m.activeJobs[id]; ok {
return mj.State
}
if mj, ok := m.historyJobs[id]; ok {
return mj.State
}
return ""
}
// GetAllActiveJobs returns all jobs currently in the active queue.
func (m *MockSlurm) GetAllActiveJobs() []*MockJob {
m.mu.RLock()
defer m.mu.RUnlock()
jobs := make([]*MockJob, 0, len(m.activeJobs))
for _, mj := range m.activeJobs {
jobs = append(jobs, mj)
}
return jobs
}
// GetAllHistoryJobs returns all jobs in the history queue.
func (m *MockSlurm) GetAllHistoryJobs() []*MockJob {
m.mu.RLock()
defer m.mu.RUnlock()
jobs := make([]*MockJob, 0, len(m.historyJobs))
for _, mj := range m.historyJobs {
jobs = append(jobs, mj)
}
return jobs
}
// ---------------------------------------------------------------------------
// P0: Job Core Endpoints
// ---------------------------------------------------------------------------
// POST /slurm/v0.0.40/job/submit
func (m *MockSlurm) handleJobSubmit(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
return
}
var req slurm.JobSubmitReq
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
m.mu.Lock()
jobID := m.nextID
m.nextID++
job := &MockJob{
JobID: jobID,
State: "PENDING", // MUST be non-empty for mapSlurmStateToTaskStatus
SubmitTime: time.Now(),
}
if req.Script != nil {
job.Script = *req.Script
}
if req.Job != nil {
if req.Job.Name != nil {
job.Name = *req.Job.Name
}
if req.Job.Partition != nil {
job.Partition = *req.Job.Partition
}
if req.Job.CurrentWorkingDirectory != nil {
job.WorkDir = *req.Job.CurrentWorkingDirectory
}
if req.Job.Script != nil {
job.Script = *req.Job.Script
}
}
m.activeJobs[jobID] = job
m.mu.Unlock()
resp := NewSubmitResponse(jobID)
writeJSON(w, http.StatusOK, resp)
}
// GET /slurm/v0.0.40/jobs
func (m *MockSlurm) handleGetJobs(w http.ResponseWriter, r *http.Request) {
m.mu.RLock()
jobs := make([]slurm.JobInfo, 0, len(m.activeJobs))
for _, mj := range m.activeJobs {
jobs = append(jobs, m.mockJobToJobInfo(mj))
}
m.mu.RUnlock()
resp := NewJobInfoResponse(jobs)
writeJSON(w, http.StatusOK, resp)
}
// GET/DELETE /slurm/v0.0.40/job/{id}
func (m *MockSlurm) handleJobByID(w http.ResponseWriter, r *http.Request) {
segments := strings.Split(strings.TrimRight(r.URL.Path, "/"), "/")
// /slurm/v0.0.40/job/{id} → segments[0]="", [1]="slurm", [2]="v0.0.40", [3]="job", [4]=id
if len(segments) < 5 {
m.writeError(w, http.StatusBadRequest, "missing job id")
return
}
last := segments[4]
// Safety net: if "submit" leaks through prefix match, forward to submit handler
if last == "submit" {
m.handleJobSubmit(w, r)
return
}
id, err := strconv.ParseInt(last, 10, 32)
if err != nil {
m.writeError(w, http.StatusBadRequest, "invalid job id")
return
}
switch r.Method {
case http.MethodGet:
m.handleGetJobByID(w, int32(id))
case http.MethodDelete:
m.handleDeleteJobByID(w, int32(id))
default:
http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
}
}
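The slash-split indexing used by handleJobByID (and the node/partition handlers) is easy to sanity-check in isolation. A sketch of the same extraction (`lastSegment` is a hypothetical name for this write-up):

```go
package main

import (
	"fmt"
	"strings"
)

// lastSegment extracts the {id} element the handlers above parse:
// index 4 of the slash-split path, after stripping any trailing slash.
// For "/slurm/v0.0.40/job/42" the split yields
// ["", "slurm", "v0.0.40", "job", "42"].
func lastSegment(path string) (string, bool) {
	segments := strings.Split(strings.TrimRight(path, "/"), "/")
	if len(segments) < 5 {
		return "", false
	}
	return segments[4], true
}

func main() {
	id, ok := lastSegment("/slurm/v0.0.40/job/42")
	fmt.Println(id, ok) // → 42 true
}
```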
func (m *MockSlurm) handleGetJobByID(w http.ResponseWriter, jobID int32) {
m.mu.RLock()
mj, ok := m.activeJobs[jobID]
m.mu.RUnlock()
if !ok {
m.writeError(w, http.StatusNotFound, "job not found")
return
}
ji := m.mockJobToJobInfo(mj)
resp := NewJobInfoResponse([]slurm.JobInfo{ji})
writeJSON(w, http.StatusOK, resp)
}
func (m *MockSlurm) handleDeleteJobByID(w http.ResponseWriter, jobID int32) {
m.SetJobState(jobID, "CANCELLED")
resp := NewDeleteResponse()
writeJSON(w, http.StatusOK, resp)
}
// ---------------------------------------------------------------------------
// P1: Cluster/History Endpoints
// ---------------------------------------------------------------------------
// GET /slurm/v0.0.40/nodes
func (m *MockSlurm) handleGetNodes(w http.ResponseWriter, r *http.Request) {
nodes := make([]slurm.Node, len(m.nodes))
for i, n := range m.nodes {
nodes[i] = slurm.Node{Name: slurm.Ptr(n.Name)}
}
resp := NewNodeResponse(nodes)
writeJSON(w, http.StatusOK, resp)
}
// GET /slurm/v0.0.40/node/{name}
func (m *MockSlurm) handleGetNode(w http.ResponseWriter, r *http.Request) {
segments := strings.Split(strings.TrimRight(r.URL.Path, "/"), "/")
if len(segments) < 5 {
m.writeError(w, http.StatusBadRequest, "missing node name")
return
}
nodeName := segments[4]
var found *slurm.Node
for _, n := range m.nodes {
if n.Name == nodeName {
found = &slurm.Node{Name: slurm.Ptr(n.Name)}
break
}
}
if found == nil {
m.writeError(w, http.StatusNotFound, "node not found")
return
}
resp := NewNodeResponse([]slurm.Node{*found})
writeJSON(w, http.StatusOK, resp)
}
// GET /slurm/v0.0.40/partitions
func (m *MockSlurm) handleGetPartitions(w http.ResponseWriter, r *http.Request) {
parts := make([]slurm.PartitionInfo, len(m.partitions))
for i, p := range m.partitions {
parts[i] = slurm.PartitionInfo{Name: slurm.Ptr(p.Name)}
}
resp := NewPartitionResponse(parts)
writeJSON(w, http.StatusOK, resp)
}
// GET /slurm/v0.0.40/partition/{name}
func (m *MockSlurm) handleGetPartition(w http.ResponseWriter, r *http.Request) {
segments := strings.Split(strings.TrimRight(r.URL.Path, "/"), "/")
if len(segments) < 5 {
m.writeError(w, http.StatusBadRequest, "missing partition name")
return
}
partName := segments[4]
var found *slurm.PartitionInfo
for _, p := range m.partitions {
if p.Name == partName {
found = &slurm.PartitionInfo{Name: slurm.Ptr(p.Name)}
break
}
}
if found == nil {
m.writeError(w, http.StatusNotFound, "partition not found")
return
}
resp := NewPartitionResponse([]slurm.PartitionInfo{*found})
writeJSON(w, http.StatusOK, resp)
}
// GET /slurm/v0.0.40/diag
func (m *MockSlurm) handleDiag(w http.ResponseWriter, r *http.Request) {
resp := NewDiagResponse()
writeJSON(w, http.StatusOK, resp)
}
// GET /slurmdb/v0.0.40/jobs — supports filter params: job_name, start_time, end_time
func (m *MockSlurm) handleSlurmdbJobs(w http.ResponseWriter, r *http.Request) {
m.mu.RLock()
defer m.mu.RUnlock()
jobs := make([]slurm.Job, 0)
for _, mj := range m.historyJobs {
// Filter by job_name
if name := r.URL.Query().Get("job_name"); name != "" && mj.Name != name {
continue
}
// Filter by start_time (job start must be >= filter start)
if startStr := r.URL.Query().Get("start_time"); startStr != "" {
if st, err := strconv.ParseInt(startStr, 10, 64); err == nil {
if mj.StartTime == nil || mj.StartTime.Unix() < st {
continue
}
}
}
// Filter by end_time (job end must be <= filter end)
if endStr := r.URL.Query().Get("end_time"); endStr != "" {
if et, err := strconv.ParseInt(endStr, 10, 64); err == nil {
if mj.EndTime == nil || mj.EndTime.Unix() > et {
continue
}
}
}
jobs = append(jobs, m.mockJobToSlurmDBJob(mj))
}
resp := NewJobHistoryResponse(jobs)
writeJSON(w, http.StatusOK, resp)
}
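The time-window filter above admits a history job only when its start is at or after the requested start and its end is at or before the requested end; a nil job timestamp never matches a set bound. A standalone sketch of that predicate with fixed bounds (`inWindow` is hypothetical):

```go
package main

import "fmt"

// inWindow mirrors the slurmdb history filter: the job's start must be
// >= filterStart and its end <= filterEnd; missing timestamps exclude
// the job whenever the corresponding bound is applied.
func inWindow(jobStart, jobEnd *int64, filterStart, filterEnd int64) bool {
	if jobStart == nil || *jobStart < filterStart {
		return false
	}
	if jobEnd == nil || *jobEnd > filterEnd {
		return false
	}
	return true
}

func main() {
	s, e := int64(100), int64(200)
	fmt.Println(inWindow(&s, &e, 50, 250)) // → true
}
```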
// GET /slurmdb/v0.0.40/job/{id} — returns OpenapiSlurmdbdJobsResp (with jobs array wrapper)
func (m *MockSlurm) handleSlurmdbJob(w http.ResponseWriter, r *http.Request) {
segments := strings.Split(strings.TrimRight(r.URL.Path, "/"), "/")
if len(segments) < 5 {
m.writeError(w, http.StatusNotFound, "job not found")
return
}
id, err := strconv.ParseInt(segments[4], 10, 32)
if err != nil {
m.writeError(w, http.StatusNotFound, "job not found")
return
}
m.mu.RLock()
mj, ok := m.historyJobs[int32(id)]
m.mu.RUnlock()
if !ok {
m.writeError(w, http.StatusNotFound, "job not found")
return
}
dbJob := m.mockJobToSlurmDBJob(mj)
resp := NewJobHistoryResponse([]slurm.Job{dbJob})
writeJSON(w, http.StatusOK, resp)
}
// ---------------------------------------------------------------------------
// Conversion helpers
// ---------------------------------------------------------------------------
// mockJobToJobInfo converts a MockJob to an active-endpoint JobInfo.
// Uses buildActiveJobState for flat []string state format: ["RUNNING"].
func (m *MockSlurm) mockJobToJobInfo(mj *MockJob) slurm.JobInfo {
ji := slurm.JobInfo{
JobID: slurm.Ptr(mj.JobID),
JobState: buildActiveJobState(mj.State), // MUST be non-empty []string
Name: slurm.Ptr(mj.Name),
Partition: slurm.Ptr(mj.Partition),
CurrentWorkingDirectory: slurm.Ptr(mj.WorkDir),
SubmitTime: &slurm.Uint64NoVal{Number: slurm.Ptr(mj.SubmitTime.Unix())},
}
if mj.StartTime != nil {
ji.StartTime = &slurm.Uint64NoVal{Number: slurm.Ptr(mj.StartTime.Unix())}
}
if mj.EndTime != nil {
ji.EndTime = &slurm.Uint64NoVal{Number: slurm.Ptr(mj.EndTime.Unix())}
}
if mj.ExitCode != nil {
ji.ExitCode = &slurm.ProcessExitCodeVerbose{
ReturnCode: &slurm.Uint32NoVal{Number: slurm.Ptr(int64(*mj.ExitCode))},
}
}
return ji
}
// mockJobToSlurmDBJob converts a MockJob to a SlurmDB history Job.
// Uses buildHistoryJobState for nested state format: {current: ["COMPLETED"], reason: ""}.
func (m *MockSlurm) mockJobToSlurmDBJob(mj *MockJob) slurm.Job {
dbJob := slurm.Job{
JobID: slurm.Ptr(mj.JobID),
Name: slurm.Ptr(mj.Name),
Partition: slurm.Ptr(mj.Partition),
WorkingDirectory: slurm.Ptr(mj.WorkDir),
Script: slurm.Ptr(mj.Script),
State: buildHistoryJobState(mj.State),
Time: &slurm.JobTime{
Submission: slurm.Ptr(mj.SubmitTime.Unix()),
},
}
if mj.StartTime != nil {
dbJob.Time.Start = slurm.Ptr(mj.StartTime.Unix())
}
if mj.EndTime != nil {
dbJob.Time.End = slurm.Ptr(mj.EndTime.Unix())
}
if mj.ExitCode != nil {
dbJob.ExitCode = &slurm.ProcessExitCodeVerbose{
ReturnCode: &slurm.Uint32NoVal{Number: slurm.Ptr(int64(*mj.ExitCode))},
}
}
return dbJob
}
// ---------------------------------------------------------------------------
// Error helpers
// ---------------------------------------------------------------------------
// writeJSON writes a JSON response with the given status code.
func writeJSON(w http.ResponseWriter, code int, v interface{}) {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(code)
json.NewEncoder(w).Encode(v)
}
// writeError writes an HTTP error with an OpenapiResp body containing
// meta and errors fields. This is critical for CheckResponse/IsNotFound
// to work correctly — the response body must be parseable as OpenapiResp.
func (m *MockSlurm) writeError(w http.ResponseWriter, statusCode int, message string) {
meta := slurm.OpenapiMeta{
Plugin: &slurm.MetaPlugin{
Type: slurm.Ptr("openapi/v0.0.40"),
Name: slurm.Ptr(""),
},
Slurm: &slurm.MetaSlurm{
Version: &slurm.MetaSlurmVersion{
Major: slurm.Ptr("24"),
Micro: slurm.Ptr("0"),
Minor: slurm.Ptr("5"),
},
Release: slurm.Ptr("24.05.0"),
},
}
resp := slurm.OpenapiResp{
Meta: &meta,
Errors: slurm.OpenapiErrors{
{
ErrorNumber: slurm.Ptr(int32(0)),
Error: slurm.Ptr(message),
},
},
Warnings: slurm.OpenapiWarnings{},
}
writeJSON(w, statusCode, resp)
}


@@ -0,0 +1,679 @@
package mockslurm
import (
"context"
"encoding/json"
"strconv"
"strings"
"testing"
"gcy_hpc_server/internal/slurm"
)
func setupTestClient(t *testing.T) (*slurm.Client, *MockSlurm) {
t.Helper()
srv, mock := NewMockSlurmServer()
t.Cleanup(srv.Close)
client, err := slurm.NewClientWithOpts(srv.URL, slurm.WithHTTPClient(srv.Client()))
if err != nil {
t.Fatalf("failed to create client: %v", err)
}
return client, mock
}
func submitTestJob(t *testing.T, client *slurm.Client, name, partition, workDir, script string) int32 {
t.Helper()
ctx := context.Background()
resp, _, err := client.Jobs.SubmitJob(ctx, &slurm.JobSubmitReq{
Script: slurm.Ptr(script),
Job: &slurm.JobDescMsg{
Name: slurm.Ptr(name),
Partition: slurm.Ptr(partition),
CurrentWorkingDirectory: slurm.Ptr(workDir),
},
})
if err != nil {
t.Fatalf("SubmitJob failed: %v", err)
}
if resp.JobID == nil {
t.Fatal("SubmitJob returned nil JobID")
}
return *resp.JobID
}
// ---------------------------------------------------------------------------
// P0: Submit Job
// ---------------------------------------------------------------------------
func TestSubmitJob(t *testing.T) {
client, _ := setupTestClient(t)
ctx := context.Background()
resp, _, err := client.Jobs.SubmitJob(ctx, &slurm.JobSubmitReq{
Script: slurm.Ptr("#!/bin/bash\necho hello"),
Job: &slurm.JobDescMsg{
Name: slurm.Ptr("test-job"),
Partition: slurm.Ptr("normal"),
CurrentWorkingDirectory: slurm.Ptr("/tmp/work"),
},
})
if err != nil {
t.Fatalf("SubmitJob failed: %v", err)
}
if resp.JobID == nil || *resp.JobID != 1 {
t.Errorf("JobID = %v, want 1", resp.JobID)
}
if resp.StepID == nil || *resp.StepID != "Scalar" {
t.Errorf("StepID = %v, want Scalar", resp.StepID)
}
if resp.Result == nil || resp.Result.JobID == nil || *resp.Result.JobID != 1 {
t.Errorf("Result.JobID = %v, want 1", resp.Result)
}
}
func TestSubmitJobAutoIncrement(t *testing.T) {
client, _ := setupTestClient(t)
ctx := context.Background()
for i := 1; i <= 3; i++ {
resp, _, err := client.Jobs.SubmitJob(ctx, &slurm.JobSubmitReq{
Script: slurm.Ptr("#!/bin/bash\necho " + strconv.Itoa(i)),
})
if err != nil {
t.Fatalf("SubmitJob %d failed: %v", i, err)
}
if resp.JobID == nil || *resp.JobID != int32(i) {
t.Errorf("job %d: JobID = %v, want %d", i, resp.JobID, i)
}
}
}
// ---------------------------------------------------------------------------
// P0: Get All Jobs
// ---------------------------------------------------------------------------
func TestGetJobsEmpty(t *testing.T) {
client, _ := setupTestClient(t)
ctx := context.Background()
resp, _, err := client.Jobs.GetJobs(ctx, nil)
if err != nil {
t.Fatalf("GetJobs failed: %v", err)
}
if len(resp.Jobs) != 0 {
t.Errorf("len(Jobs) = %d, want 0", len(resp.Jobs))
}
}
func TestGetJobsWithSubmitted(t *testing.T) {
client, _ := setupTestClient(t)
ctx := context.Background()
submitTestJob(t, client, "job-a", "normal", "/tmp/a", "#!/bin/bash\ntrue")
submitTestJob(t, client, "job-b", "gpu", "/tmp/b", "#!/bin/bash\nfalse")
resp, _, err := client.Jobs.GetJobs(ctx, nil)
if err != nil {
t.Fatalf("GetJobs failed: %v", err)
}
if len(resp.Jobs) != 2 {
t.Fatalf("len(Jobs) = %d, want 2", len(resp.Jobs))
}
names := map[string]bool{}
for _, j := range resp.Jobs {
if j.Name != nil {
names[*j.Name] = true
}
if len(j.JobState) == 0 || j.JobState[0] != "PENDING" {
t.Errorf("JobState = %v, want [PENDING]", j.JobState)
}
}
if !names["job-a"] || !names["job-b"] {
t.Errorf("expected job-a and job-b, got %v", names)
}
}
// ---------------------------------------------------------------------------
// P0: Get Job By ID
// ---------------------------------------------------------------------------
func TestGetJobByID(t *testing.T) {
client, _ := setupTestClient(t)
ctx := context.Background()
jobID := submitTestJob(t, client, "single-job", "normal", "/tmp/work", "#!/bin/bash\necho hi")
resp, _, err := client.Jobs.GetJob(ctx, strconv.Itoa(int(jobID)), nil)
if err != nil {
t.Fatalf("GetJob failed: %v", err)
}
if len(resp.Jobs) != 1 {
t.Fatalf("len(Jobs) = %d, want 1", len(resp.Jobs))
}
job := resp.Jobs[0]
if job.JobID == nil || *job.JobID != jobID {
t.Errorf("JobID = %v, want %d", job.JobID, jobID)
}
if job.Name == nil || *job.Name != "single-job" {
t.Errorf("Name = %v, want single-job", job.Name)
}
if job.Partition == nil || *job.Partition != "normal" {
t.Errorf("Partition = %v, want normal", job.Partition)
}
if job.CurrentWorkingDirectory == nil || *job.CurrentWorkingDirectory != "/tmp/work" {
t.Errorf("CurrentWorkingDirectory = %v, want /tmp/work", job.CurrentWorkingDirectory)
}
if job.SubmitTime == nil || job.SubmitTime.Number == nil {
t.Error("SubmitTime should be non-nil")
}
}
func TestGetJobByIDNotFound(t *testing.T) {
client, _ := setupTestClient(t)
ctx := context.Background()
_, _, err := client.Jobs.GetJob(ctx, "999", nil)
if err == nil {
t.Fatal("expected error for unknown job ID, got nil")
}
if !slurm.IsNotFound(err) {
t.Errorf("error type = %T, want SlurmAPIError with 404", err)
}
}
// ---------------------------------------------------------------------------
// P0: Delete Job (triggers eviction)
// ---------------------------------------------------------------------------
func TestDeleteJob(t *testing.T) {
client, mock := setupTestClient(t)
ctx := context.Background()
jobID := submitTestJob(t, client, "cancel-me", "normal", "/tmp", "#!/bin/bash\nsleep 99")
resp, _, err := client.Jobs.DeleteJob(ctx, strconv.Itoa(int(jobID)), nil)
if err != nil {
t.Fatalf("DeleteJob failed: %v", err)
}
if resp == nil {
t.Fatal("DeleteJob returned nil response")
}
if len(mock.GetAllActiveJobs()) != 0 {
t.Error("active jobs should be empty after delete")
}
if len(mock.GetAllHistoryJobs()) != 1 {
t.Error("history should contain 1 job after delete")
}
if mock.GetJobState(jobID) != "CANCELLED" {
t.Errorf("job state = %q, want CANCELLED", mock.GetJobState(jobID))
}
}
func TestDeleteJobEvictsFromActive(t *testing.T) {
client, _ := setupTestClient(t)
ctx := context.Background()
jobID := submitTestJob(t, client, "to-delete", "normal", "/tmp", "#!/bin/bash\ntrue")
_, _, err := client.Jobs.DeleteJob(ctx, strconv.Itoa(int(jobID)), nil)
if err != nil {
t.Fatalf("DeleteJob failed: %v", err)
}
_, _, err = client.Jobs.GetJob(ctx, strconv.Itoa(int(jobID)), nil)
if err == nil {
t.Fatal("expected 404 after delete, got nil error")
}
if !slurm.IsNotFound(err) {
t.Errorf("error = %v, want not-found", err)
}
}
// ---------------------------------------------------------------------------
// P0: Job State format ([]string, not bare string)
// ---------------------------------------------------------------------------
func TestJobStateIsStringArray(t *testing.T) {
client, _ := setupTestClient(t)
ctx := context.Background()
submitTestJob(t, client, "state-test", "normal", "/tmp", "#!/bin/bash\necho")
resp, _, err := client.Jobs.GetJobs(ctx, nil)
if err != nil {
t.Fatalf("GetJobs failed: %v", err)
}
if len(resp.Jobs) == 0 {
t.Fatal("expected at least one job")
}
job := resp.Jobs[0]
if len(job.JobState) == 0 {
t.Fatal("JobState is empty — must be non-empty []string to avoid mapSlurmStateToTaskStatus silent failure")
}
if job.JobState[0] != "PENDING" {
t.Errorf("JobState[0] = %q, want %q", job.JobState[0], "PENDING")
}
raw, err := json.Marshal(job)
if err != nil {
t.Fatalf("Marshal job: %v", err)
}
if !strings.Contains(string(raw), `"job_state":["PENDING"]`) {
t.Errorf("JobState JSON = %s, want array format [\"PENDING\"]", string(raw))
}
}
// ---------------------------------------------------------------------------
// P0: Full Job Lifecycle
// ---------------------------------------------------------------------------
func TestJobLifecycle(t *testing.T) {
client, mock := setupTestClient(t)
ctx := context.Background()
jobID := submitTestJob(t, client, "lifecycle", "normal", "/tmp/lc", "#!/bin/bash\necho lifecycle")
// Verify PENDING in active
resp, _, err := client.Jobs.GetJob(ctx, strconv.Itoa(int(jobID)), nil)
if err != nil {
t.Fatalf("GetJob PENDING: %v", err)
}
if resp.Jobs[0].JobState[0] != "PENDING" {
t.Errorf("initial state = %v, want PENDING", resp.Jobs[0].JobState)
}
if len(mock.GetAllActiveJobs()) != 1 {
t.Error("should have 1 active job")
}
// Transition to RUNNING
mock.SetJobState(jobID, "RUNNING")
resp, _, err = client.Jobs.GetJob(ctx, strconv.Itoa(int(jobID)), nil)
if err != nil {
t.Fatalf("GetJob RUNNING: %v", err)
}
if resp.Jobs[0].JobState[0] != "RUNNING" {
t.Errorf("running state = %v, want RUNNING", resp.Jobs[0].JobState)
}
if resp.Jobs[0].StartTime == nil || resp.Jobs[0].StartTime.Number == nil {
t.Error("StartTime should be set for RUNNING job")
}
if len(mock.GetAllActiveJobs()) != 1 {
t.Error("should still have 1 active job after RUNNING")
}
// Transition to COMPLETED — triggers eviction
mock.SetJobState(jobID, "COMPLETED")
_, _, err = client.Jobs.GetJob(ctx, strconv.Itoa(int(jobID)), nil)
if err == nil {
t.Fatal("expected 404 after COMPLETED (evicted from active)")
}
if !slurm.IsNotFound(err) {
t.Errorf("error = %v, want not-found", err)
}
if len(mock.GetAllActiveJobs()) != 0 {
t.Error("active jobs should be empty after COMPLETED")
}
if len(mock.GetAllHistoryJobs()) != 1 {
t.Error("history should contain 1 job after COMPLETED")
}
if mock.GetJobState(jobID) != "COMPLETED" {
t.Errorf("state = %q, want COMPLETED", mock.GetJobState(jobID))
}
// Verify history endpoint returns the job
histResp, _, err := client.SlurmdbJobs.GetJob(ctx, strconv.Itoa(int(jobID)))
if err != nil {
t.Fatalf("SlurmdbJobs.GetJob: %v", err)
}
if len(histResp.Jobs) != 1 {
t.Fatalf("history jobs = %d, want 1", len(histResp.Jobs))
}
histJob := histResp.Jobs[0]
if histJob.State == nil || len(histJob.State.Current) == 0 || histJob.State.Current[0] != "COMPLETED" {
t.Errorf("history state = %v, want current=[COMPLETED]", histJob.State)
}
if histJob.ExitCode == nil || histJob.ExitCode.ReturnCode == nil || histJob.ExitCode.ReturnCode.Number == nil {
t.Error("history ExitCode should be set")
} else if *histJob.ExitCode.ReturnCode.Number != 0 {
t.Errorf("exit code = %d, want 0 for COMPLETED", *histJob.ExitCode.ReturnCode.Number)
}
}
// ---------------------------------------------------------------------------
// P1: Nodes
// ---------------------------------------------------------------------------
func TestGetNodes(t *testing.T) {
client, mock := setupTestClient(t)
_ = mock
ctx := context.Background()
resp, _, err := client.Nodes.GetNodes(ctx, nil)
if err != nil {
t.Fatalf("GetNodes failed: %v", err)
}
if resp.Nodes == nil {
t.Fatal("Nodes is nil")
}
if len(*resp.Nodes) != 3 {
t.Errorf("len(Nodes) = %d, want 3", len(*resp.Nodes))
}
names := make([]string, len(*resp.Nodes))
for i, n := range *resp.Nodes {
if n.Name == nil {
t.Errorf("node %d: Name is nil", i)
} else {
names[i] = *n.Name
}
}
for _, expected := range []string{"node01", "node02", "node03"} {
found := false
for _, n := range names {
if n == expected {
found = true
break
}
}
if !found {
t.Errorf("missing node %q in %v", expected, names)
}
}
}
func TestGetNode(t *testing.T) {
client, _ := setupTestClient(t)
ctx := context.Background()
resp, _, err := client.Nodes.GetNode(ctx, "node02", nil)
if err != nil {
t.Fatalf("GetNode failed: %v", err)
}
if resp.Nodes == nil || len(*resp.Nodes) != 1 {
t.Fatalf("expected 1 node, got %v", resp.Nodes)
}
if (*resp.Nodes)[0].Name == nil || *(*resp.Nodes)[0].Name != "node02" {
t.Errorf("Name = %v, want node02", (*resp.Nodes)[0].Name)
}
}
func TestGetNodeNotFound(t *testing.T) {
client, _ := setupTestClient(t)
ctx := context.Background()
_, _, err := client.Nodes.GetNode(ctx, "nonexistent", nil)
if err == nil {
t.Fatal("expected error for unknown node, got nil")
}
if !slurm.IsNotFound(err) {
t.Errorf("error = %v, want not-found", err)
}
}
// ---------------------------------------------------------------------------
// P1: Partitions
// ---------------------------------------------------------------------------
func TestGetPartitions(t *testing.T) {
client, _ := setupTestClient(t)
ctx := context.Background()
resp, _, err := client.Partitions.GetPartitions(ctx, nil)
if err != nil {
t.Fatalf("GetPartitions failed: %v", err)
}
if resp.Partitions == nil {
t.Fatal("Partitions is nil")
}
if len(*resp.Partitions) != 2 {
t.Errorf("len(Partitions) = %d, want 2", len(*resp.Partitions))
}
names := map[string]bool{}
for _, p := range *resp.Partitions {
if p.Name != nil {
names[*p.Name] = true
}
}
if !names["normal"] || !names["gpu"] {
t.Errorf("expected normal and gpu partitions, got %v", names)
}
}
func TestGetPartition(t *testing.T) {
client, _ := setupTestClient(t)
ctx := context.Background()
resp, _, err := client.Partitions.GetPartition(ctx, "gpu", nil)
if err != nil {
t.Fatalf("GetPartition failed: %v", err)
}
if resp.Partitions == nil || len(*resp.Partitions) != 1 {
t.Fatalf("expected 1 partition, got %v", resp.Partitions)
}
if (*resp.Partitions)[0].Name == nil || *(*resp.Partitions)[0].Name != "gpu" {
t.Errorf("Name = %v, want gpu", (*resp.Partitions)[0].Name)
}
}
func TestGetPartitionNotFound(t *testing.T) {
client, _ := setupTestClient(t)
ctx := context.Background()
_, _, err := client.Partitions.GetPartition(ctx, "nonexistent", nil)
if err == nil {
t.Fatal("expected error for unknown partition, got nil")
}
if !slurm.IsNotFound(err) {
t.Errorf("error = %v, want not-found", err)
}
}
// ---------------------------------------------------------------------------
// P1: Diag
// ---------------------------------------------------------------------------
func TestGetDiag(t *testing.T) {
client, _ := setupTestClient(t)
ctx := context.Background()
resp, _, err := client.Diag.GetDiag(ctx)
if err != nil {
t.Fatalf("GetDiag failed: %v", err)
}
if resp.Statistics == nil {
t.Fatal("Statistics is nil")
}
if resp.Statistics.ServerThreadCount == nil || *resp.Statistics.ServerThreadCount != 3 {
t.Errorf("ServerThreadCount = %v, want 3", resp.Statistics.ServerThreadCount)
}
if resp.Statistics.AgentQueueSize == nil || *resp.Statistics.AgentQueueSize != 0 {
t.Errorf("AgentQueueSize = %v, want 0", resp.Statistics.AgentQueueSize)
}
}
// ---------------------------------------------------------------------------
// P1: SlurmDB Jobs
// ---------------------------------------------------------------------------
func TestSlurmdbGetJobsEmpty(t *testing.T) {
client, _ := setupTestClient(t)
ctx := context.Background()
resp, _, err := client.SlurmdbJobs.GetJobs(ctx, nil)
if err != nil {
t.Fatalf("GetJobs failed: %v", err)
}
if len(resp.Jobs) != 0 {
t.Errorf("len(Jobs) = %d, want 0 (no history)", len(resp.Jobs))
}
}
func TestSlurmdbGetJobsAfterEviction(t *testing.T) {
client, mock := setupTestClient(t)
ctx := context.Background()
jobID := submitTestJob(t, client, "hist-job", "normal", "/tmp/h", "#!/bin/bash\necho hist")
mock.SetJobState(jobID, "RUNNING")
mock.SetJobState(jobID, "COMPLETED")
resp, _, err := client.SlurmdbJobs.GetJobs(ctx, nil)
if err != nil {
t.Fatalf("GetJobs failed: %v", err)
}
if len(resp.Jobs) != 1 {
t.Fatalf("len(Jobs) = %d, want 1", len(resp.Jobs))
}
job := resp.Jobs[0]
if job.Name == nil || *job.Name != "hist-job" {
t.Errorf("Name = %v, want hist-job", job.Name)
}
if job.State == nil || len(job.State.Current) == 0 || job.State.Current[0] != "COMPLETED" {
t.Errorf("State = %v, want current=[COMPLETED]", job.State)
}
if job.Time == nil || job.Time.Submission == nil {
t.Error("Time.Submission should be set")
}
}
func TestSlurmdbGetJobByID(t *testing.T) {
client, mock := setupTestClient(t)
ctx := context.Background()
jobID := submitTestJob(t, client, "single-hist", "normal", "/tmp/sh", "#!/bin/bash\nexit 1")
mock.SetJobState(jobID, "FAILED")
resp, _, err := client.SlurmdbJobs.GetJob(ctx, strconv.Itoa(int(jobID)))
if err != nil {
t.Fatalf("GetJob failed: %v", err)
}
if len(resp.Jobs) != 1 {
t.Fatalf("len(Jobs) = %d, want 1", len(resp.Jobs))
}
job := resp.Jobs[0]
if job.JobID == nil || *job.JobID != jobID {
t.Errorf("JobID = %v, want %d", job.JobID, jobID)
}
if job.State == nil || len(job.State.Current) == 0 || job.State.Current[0] != "FAILED" {
t.Errorf("State = %v, want current=[FAILED]", job.State)
}
if job.ExitCode == nil || job.ExitCode.ReturnCode == nil || job.ExitCode.ReturnCode.Number == nil {
t.Error("ExitCode should be set")
} else if *job.ExitCode.ReturnCode.Number != 1 {
t.Errorf("exit code = %d, want 1 for FAILED", *job.ExitCode.ReturnCode.Number)
}
}
func TestSlurmdbGetJobNotFound(t *testing.T) {
client, _ := setupTestClient(t)
ctx := context.Background()
_, _, err := client.SlurmdbJobs.GetJob(ctx, "999")
if err == nil {
t.Fatal("expected error for unknown history job, got nil")
}
if !slurm.IsNotFound(err) {
t.Errorf("error = %v, want not-found", err)
}
}
func TestSlurmdbJobStateIsNested(t *testing.T) {
client, mock := setupTestClient(t)
ctx := context.Background()
jobID := submitTestJob(t, client, "nested-state", "gpu", "/tmp/ns", "#!/bin/bash\ntrue")
mock.SetJobState(jobID, "COMPLETED")
resp, _, err := client.SlurmdbJobs.GetJob(ctx, strconv.Itoa(int(jobID)))
if err != nil {
t.Fatalf("GetJob failed: %v", err)
}
job := resp.Jobs[0]
if job.State == nil {
t.Fatal("State is nil — must be nested {current: [...], reason: \"\"}")
}
if len(job.State.Current) == 0 {
t.Fatal("State.Current is empty")
}
if job.State.Current[0] != "COMPLETED" {
t.Errorf("State.Current[0] = %q, want COMPLETED", job.State.Current[0])
}
if job.State.Reason == nil || *job.State.Reason != "" {
t.Errorf("State.Reason = %v, want empty string", job.State.Reason)
}
raw, err := json.Marshal(job)
if err != nil {
t.Fatalf("Marshal: %v", err)
}
rawStr := string(raw)
if !strings.Contains(rawStr, `"state":{"current":["COMPLETED"]`) {
t.Errorf("state JSON should use nested format, got: %s", rawStr)
}
}
func TestSlurmdbJobsFilterByName(t *testing.T) {
client, mock := setupTestClient(t)
ctx := context.Background()
id1 := submitTestJob(t, client, "match-me", "normal", "/tmp", "#!/bin/bash\ntrue")
id2 := submitTestJob(t, client, "other-job", "normal", "/tmp", "#!/bin/bash\ntrue")
mock.SetJobState(id1, "COMPLETED")
mock.SetJobState(id2, "COMPLETED")
resp, _, err := client.SlurmdbJobs.GetJobs(ctx, &slurm.GetSlurmdbJobsOptions{
JobName: slurm.Ptr("match-me"),
})
if err != nil {
t.Fatalf("GetJobs with filter: %v", err)
}
if len(resp.Jobs) != 1 {
t.Fatalf("len(Jobs) = %d, want 1 (filtered by name)", len(resp.Jobs))
}
if resp.Jobs[0].Name == nil || *resp.Jobs[0].Name != "match-me" {
t.Errorf("Name = %v, want match-me", resp.Jobs[0].Name)
}
}
// ---------------------------------------------------------------------------
// SetJobState terminal state exit codes
// ---------------------------------------------------------------------------
func TestSetJobStateExitCodes(t *testing.T) {
client, mock := setupTestClient(t)
ctx := context.Background()
cases := []struct {
state string
wantExit int64
}{
{"COMPLETED", 0},
{"FAILED", 1},
{"CANCELLED", 1},
{"TIMEOUT", 1},
}
for i, tc := range cases {
jobID := submitTestJob(t, client, "exit-"+strconv.Itoa(i), "normal", "/tmp", "#!/bin/bash\ntrue")
mock.SetJobState(jobID, tc.state)
resp, _, err := client.SlurmdbJobs.GetJob(ctx, strconv.Itoa(int(jobID)))
if err != nil {
t.Fatalf("GetJob(%d) %s: %v", jobID, tc.state, err)
}
job := resp.Jobs[0]
if job.ExitCode == nil || job.ExitCode.ReturnCode == nil || job.ExitCode.ReturnCode.Number == nil {
t.Errorf("%s: ExitCode not set", tc.state)
continue
}
if *job.ExitCode.ReturnCode.Number != tc.wantExit {
t.Errorf("%s: exit code = %d, want %d", tc.state, *job.ExitCode.ReturnCode.Number, tc.wantExit)
}
}
}


@@ -0,0 +1,134 @@
// Package mockslurm provides response builder helpers that generate JSON
// matching Openapi* types from the internal/slurm package.
package mockslurm
import (
"gcy_hpc_server/internal/slurm"
)
// newMeta returns standard OpenapiMeta with plugin type "openapi/v0.0.40"
// and Slurm version 24.05.0.
func newMeta() slurm.OpenapiMeta {
return slurm.OpenapiMeta{
Plugin: &slurm.MetaPlugin{
Type: slurm.Ptr("openapi/v0.0.40"),
Name: slurm.Ptr("slurmrestd"),
DataParser: slurm.Ptr("json/v0.0.40"),
},
Slurm: &slurm.MetaSlurm{
Version: &slurm.MetaSlurmVersion{
Major: slurm.Ptr("24"),
Micro: slurm.Ptr("0"),
Minor: slurm.Ptr("5"),
},
Release: slurm.Ptr("24.05.0"),
},
}
}
// NewSubmitResponse builds an OpenapiJobSubmitResponse with the given jobID.
func NewSubmitResponse(jobID int32) slurm.OpenapiJobSubmitResponse {
return slurm.OpenapiJobSubmitResponse{
Result: &slurm.JobSubmitResponseMsg{
JobID: slurm.Ptr(jobID),
},
JobID: slurm.Ptr(jobID),
StepID: slurm.Ptr("Scalar"),
Meta: &slurm.OpenapiMeta{},
Errors: slurm.OpenapiErrors{},
Warnings: slurm.OpenapiWarnings{},
}
}
// NewJobInfoResponse builds an OpenapiJobInfoResp wrapping the given jobs.
func NewJobInfoResponse(jobs []slurm.JobInfo) slurm.OpenapiJobInfoResp {
meta := newMeta()
return slurm.OpenapiJobInfoResp{
Jobs: jobs,
Meta: &meta,
Errors: slurm.OpenapiErrors{},
Warnings: slurm.OpenapiWarnings{},
}
}
// NewJobListResponse is a convenience wrapper around NewJobInfoResponse.
func NewJobListResponse(jobs []slurm.JobInfo) slurm.OpenapiJobInfoResp {
return NewJobInfoResponse(jobs)
}
// NewDeleteResponse builds an OpenapiResp with meta and empty errors/warnings.
func NewDeleteResponse() slurm.OpenapiResp {
meta := newMeta()
return slurm.OpenapiResp{
Meta: &meta,
Errors: slurm.OpenapiErrors{},
Warnings: slurm.OpenapiWarnings{},
}
}
// NewNodeResponse builds an OpenapiNodesResp wrapping the given nodes.
func NewNodeResponse(nodes []slurm.Node) slurm.OpenapiNodesResp {
meta := newMeta()
n := slurm.Nodes(nodes)
return slurm.OpenapiNodesResp{
Nodes: &n,
Meta: &meta,
Errors: slurm.OpenapiErrors{},
Warnings: slurm.OpenapiWarnings{},
}
}
// NewPartitionResponse builds an OpenapiPartitionResp wrapping the given partitions.
func NewPartitionResponse(partitions []slurm.PartitionInfo) slurm.OpenapiPartitionResp {
meta := newMeta()
p := slurm.PartitionInfoMsg(partitions)
return slurm.OpenapiPartitionResp{
Partitions: &p,
Meta: &meta,
Errors: slurm.OpenapiErrors{},
Warnings: slurm.OpenapiWarnings{},
}
}
// NewDiagResponse builds an OpenapiDiagResp with stats and meta.
func NewDiagResponse() slurm.OpenapiDiagResp {
meta := newMeta()
return slurm.OpenapiDiagResp{
Statistics: &slurm.StatsMsg{
ServerThreadCount: slurm.Ptr(int32(3)),
AgentQueueSize: slurm.Ptr(int32(0)),
JobsRunning: slurm.Ptr(int32(0)),
JobsPending: slurm.Ptr(int32(0)),
ScheduleQueueLength: slurm.Ptr(int32(0)),
},
Meta: &meta,
Errors: slurm.OpenapiErrors{},
Warnings: slurm.OpenapiWarnings{},
}
}
// NewJobHistoryResponse builds an OpenapiSlurmdbdJobsResp wrapping the given SlurmDBD jobs.
func NewJobHistoryResponse(jobs []slurm.Job) slurm.OpenapiSlurmdbdJobsResp {
meta := newMeta()
return slurm.OpenapiSlurmdbdJobsResp{
Jobs: jobs,
Meta: &meta,
Errors: slurm.OpenapiErrors{},
Warnings: slurm.OpenapiWarnings{},
}
}
// buildActiveJobState returns a flat string array for the active endpoint
// job_state field (e.g. ["RUNNING"]).
func buildActiveJobState(states ...string) []string {
return states
}
// buildHistoryJobState returns a nested JobState object for the SlurmDB
// history endpoint (e.g. {current: ["COMPLETED"], reason: ""}).
func buildHistoryJobState(states ...string) *slurm.JobState {
return &slurm.JobState{
Current: states,
Reason: slurm.Ptr(""),
}
}
