Compare commits
43 Commits
4ff02d4a80
...
master
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
9092278d26 | ||
|
|
7c374f4fd5 | ||
|
|
36d842350c | ||
|
|
80f2bd32d9 | ||
|
|
52a34e2cb0 | ||
|
|
b9b2f0d9b4 | ||
|
|
73504f9fdb | ||
|
|
3f8a680c99 | ||
|
|
ec64300ff2 | ||
|
|
acf8c1d62b | ||
|
|
d46a784efb | ||
|
|
79870333cb | ||
|
|
d9a60c3511 | ||
|
|
20576bc325 | ||
|
|
c0176d7764 | ||
|
|
2298e92516 | ||
|
|
f0847d3978 | ||
|
|
a114821615 | ||
|
|
bf89de12f0 | ||
|
|
c861ff3adf | ||
|
|
0e4f523746 | ||
|
|
44895214d4 | ||
|
|
a65c8762af | ||
|
|
04f99cc1c4 | ||
|
|
32f5792b68 | ||
|
|
328691adff | ||
|
|
10bb15e5b2 | ||
|
|
d3eb728c2f | ||
|
|
4a8153aa6c | ||
|
|
dd8d226e78 | ||
|
|
62e458cb7a | ||
|
|
2cb6fbecdd | ||
|
|
35a4017b8e | ||
|
|
f4177dd287 | ||
|
|
b3d787c97b | ||
|
|
30f0fbc34b | ||
|
|
34ba617cbf | ||
|
|
824d9e816f | ||
|
|
85901fe18a | ||
|
|
270552ba9a | ||
|
|
347b0e1229 | ||
|
|
c070dd8abc | ||
|
|
1359730300 |
1
.gitignore
vendored
1
.gitignore
vendored
@@ -1,3 +1,2 @@
|
|||||||
bin/
|
bin/
|
||||||
*.exe
|
*.exe
|
||||||
.sisyphus/
|
|
||||||
|
|||||||
662
cmd/client/main.go
Normal file
662
cmd/client/main.go
Normal file
@@ -0,0 +1,662 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"crypto/sha256"
|
||||||
|
"encoding/hex"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"mime/multipart"
|
||||||
|
"net/http"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
"sync/atomic"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ── API response types ───────────────────────────────────────────────
|
||||||
|
|
||||||
|
type apiResponse struct {
|
||||||
|
Success bool `json:"success"`
|
||||||
|
Data json.RawMessage `json:"data"`
|
||||||
|
Error string `json:"error,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type sessionResponse struct {
|
||||||
|
ID int64 `json:"id"`
|
||||||
|
FileName string `json:"file_name"`
|
||||||
|
FileSize int64 `json:"file_size"`
|
||||||
|
ChunkSize int64 `json:"chunk_size"`
|
||||||
|
TotalChunks int `json:"total_chunks"`
|
||||||
|
SHA256 string `json:"sha256"`
|
||||||
|
Status string `json:"status"`
|
||||||
|
UploadedChunks []int `json:"uploaded_chunks"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type fileResponse struct {
|
||||||
|
ID int64 `json:"id"`
|
||||||
|
Name string `json:"name"`
|
||||||
|
Size int64 `json:"size"`
|
||||||
|
MimeType string `json:"mime_type"`
|
||||||
|
SHA256 string `json:"sha256"`
|
||||||
|
CreatedAt string `json:"created_at"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type folderResponse struct {
|
||||||
|
ID int64 `json:"id"`
|
||||||
|
Name string `json:"name"`
|
||||||
|
ParentID *int64 `json:"parent_id,omitempty"`
|
||||||
|
Path string `json:"path"`
|
||||||
|
CreatedAt string `json:"created_at"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type listFilesResponse struct {
|
||||||
|
Files []fileResponse `json:"files"`
|
||||||
|
Total int64 `json:"total"`
|
||||||
|
Page int `json:"page"`
|
||||||
|
PageSize int `json:"page_size"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// ── Helpers ──────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
func formatSize(b int64) string {
|
||||||
|
const (
|
||||||
|
KB = 1024
|
||||||
|
MB = KB * 1024
|
||||||
|
GB = MB * 1024
|
||||||
|
)
|
||||||
|
switch {
|
||||||
|
case b >= GB:
|
||||||
|
return fmt.Sprintf("%.2f GB", float64(b)/float64(GB))
|
||||||
|
case b >= MB:
|
||||||
|
return fmt.Sprintf("%.2f MB", float64(b)/float64(MB))
|
||||||
|
case b >= KB:
|
||||||
|
return fmt.Sprintf("%.2f KB", float64(b)/float64(KB))
|
||||||
|
default:
|
||||||
|
return fmt.Sprintf("%d B", b)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func doRequest(server, method, path string, body io.Reader, contentType string) (*apiResponse, error) {
|
||||||
|
url := server + path
|
||||||
|
req, err := http.NewRequest(method, url, body)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("creating request: %w", err)
|
||||||
|
}
|
||||||
|
if contentType != "" {
|
||||||
|
req.Header.Set("Content-Type", contentType)
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := http.DefaultClient.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("sending request: %w", err)
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
raw, err := io.ReadAll(resp.Body)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("reading response: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if resp.StatusCode >= 400 {
|
||||||
|
return nil, fmt.Errorf("HTTP %d: %s", resp.StatusCode, string(raw))
|
||||||
|
}
|
||||||
|
|
||||||
|
var apiResp apiResponse
|
||||||
|
if err := json.Unmarshal(raw, &apiResp); err != nil {
|
||||||
|
return nil, fmt.Errorf("parsing response: %w\nbody: %s", err, string(raw))
|
||||||
|
}
|
||||||
|
|
||||||
|
if !apiResp.Success {
|
||||||
|
return &apiResp, fmt.Errorf("API error: %s", apiResp.Error)
|
||||||
|
}
|
||||||
|
|
||||||
|
return &apiResp, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func computeSHA256(path string) (string, error) {
|
||||||
|
f, err := os.Open(path)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
defer f.Close()
|
||||||
|
|
||||||
|
h := sha256.New()
|
||||||
|
if _, err := io.Copy(h, f); err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
return hex.EncodeToString(h.Sum(nil)), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ── Commands ─────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
func cmdMkdir(server string, args []string) {
|
||||||
|
if len(args) == 0 {
|
||||||
|
fmt.Fprintln(os.Stderr, "Usage: client -server <addr> mkdir <name> [-parent <id>]")
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
|
||||||
|
name := args[0]
|
||||||
|
var parentID *int64
|
||||||
|
for i := 1; i+1 < len(args); i += 2 {
|
||||||
|
if args[i] == "-parent" {
|
||||||
|
v, err := strconv.ParseInt(args[i+1], 10, 64)
|
||||||
|
if err != nil {
|
||||||
|
fmt.Fprintf(os.Stderr, "invalid parent id: %s\n", args[i+1])
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
parentID = &v
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
payload := map[string]interface{}{"name": name}
|
||||||
|
if parentID != nil {
|
||||||
|
payload["parent_id"] = *parentID
|
||||||
|
}
|
||||||
|
|
||||||
|
body, _ := json.Marshal(payload)
|
||||||
|
resp, err := doRequest(server, http.MethodPost, "/api/v1/files/folders", bytes.NewReader(body), "application/json")
|
||||||
|
if err != nil {
|
||||||
|
fmt.Fprintln(os.Stderr, err)
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
|
||||||
|
var folder folderResponse
|
||||||
|
if err := json.Unmarshal(resp.Data, &folder); err != nil {
|
||||||
|
fmt.Fprintf(os.Stderr, "parsing folder response: %v\n", err)
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
|
||||||
|
fmt.Printf("Folder created: id=%d name=%q path=%q\n", folder.ID, folder.Name, folder.Path)
|
||||||
|
}
|
||||||
|
|
||||||
|
func cmdListFolders(server string, args []string) {
|
||||||
|
path := "/api/v1/files/folders?"
|
||||||
|
for i := 0; i+1 < len(args); i += 2 {
|
||||||
|
if args[i] == "-parent" {
|
||||||
|
path += "parent_id=" + args[i+1] + "&"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := doRequest(server, http.MethodGet, path, nil, "")
|
||||||
|
if err != nil {
|
||||||
|
fmt.Fprintln(os.Stderr, err)
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
|
||||||
|
var folders []folderResponse
|
||||||
|
if err := json.Unmarshal(resp.Data, &folders); err != nil {
|
||||||
|
fmt.Fprintf(os.Stderr, "parsing folders response: %v\n", err)
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(folders) == 0 {
|
||||||
|
fmt.Println("No folders found.")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
fmt.Printf("%-8s %-30s %-10s %s\n", "ID", "Name", "ParentID", "Path")
|
||||||
|
fmt.Println(strings.Repeat("-", 80))
|
||||||
|
for _, f := range folders {
|
||||||
|
pid := "<root>"
|
||||||
|
if f.ParentID != nil {
|
||||||
|
pid = strconv.FormatInt(*f.ParentID, 10)
|
||||||
|
}
|
||||||
|
fmt.Printf("%-8d %-30s %-10s %s\n", f.ID, f.Name, pid, f.Path)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func cmdDeleteFolder(server string, args []string) {
|
||||||
|
if len(args) == 0 {
|
||||||
|
fmt.Fprintln(os.Stderr, "Usage: client -server <addr> delete-folder <id>")
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
id := args[0]
|
||||||
|
|
||||||
|
_, err := doRequest(server, http.MethodDelete, "/api/v1/files/folders/"+id, nil, "")
|
||||||
|
if err != nil {
|
||||||
|
fmt.Fprintln(os.Stderr, err)
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
|
||||||
|
fmt.Printf("Folder %s deleted.\n", id)
|
||||||
|
}
|
||||||
|
|
||||||
|
func cmdUpload(server string, args []string) {
|
||||||
|
if len(args) == 0 {
|
||||||
|
fmt.Fprintln(os.Stderr, "Usage: client -server <addr> upload <file> [-folder <id>] [-chunk-size <bytes>]")
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
|
||||||
|
filePath := args[0]
|
||||||
|
var folderID *int64
|
||||||
|
chunkSize := int64(8 * 1024 * 1024) // 8 MB default
|
||||||
|
|
||||||
|
for i := 1; i+1 < len(args); i += 2 {
|
||||||
|
switch args[i] {
|
||||||
|
case "-folder":
|
||||||
|
v, err := strconv.ParseInt(args[i+1], 10, 64)
|
||||||
|
if err != nil {
|
||||||
|
fmt.Fprintf(os.Stderr, "invalid folder id: %s\n", args[i+1])
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
folderID = &v
|
||||||
|
case "-chunk-size":
|
||||||
|
v, err := strconv.ParseInt(args[i+1], 10, 64)
|
||||||
|
if err != nil || v <= 0 {
|
||||||
|
fmt.Fprintf(os.Stderr, "invalid chunk size: %s\n", args[i+1])
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
chunkSize = v
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
start := time.Now()
|
||||||
|
|
||||||
|
// Open file and get size
|
||||||
|
f, err := os.Open(filePath)
|
||||||
|
if err != nil {
|
||||||
|
fmt.Fprintf(os.Stderr, "cannot open file: %v\n", err)
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
defer f.Close()
|
||||||
|
|
||||||
|
fi, err := f.Stat()
|
||||||
|
if err != nil {
|
||||||
|
fmt.Fprintf(os.Stderr, "cannot stat file: %v\n", err)
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
fileSize := fi.Size()
|
||||||
|
fileName := filepath.Base(filePath)
|
||||||
|
|
||||||
|
// Compute SHA256
|
||||||
|
fmt.Printf("Computing SHA256 for %s (%s)...\n", fileName, formatSize(fileSize))
|
||||||
|
sha, err := computeSHA256(filePath)
|
||||||
|
if err != nil {
|
||||||
|
fmt.Fprintf(os.Stderr, "sha256 error: %v\n", err)
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
fmt.Printf("SHA256: %s\n", sha)
|
||||||
|
|
||||||
|
// Init upload
|
||||||
|
payload := map[string]interface{}{
|
||||||
|
"file_name": fileName,
|
||||||
|
"file_size": fileSize,
|
||||||
|
"sha256": sha,
|
||||||
|
"chunk_size": chunkSize,
|
||||||
|
}
|
||||||
|
if folderID != nil {
|
||||||
|
payload["folder_id"] = *folderID
|
||||||
|
}
|
||||||
|
|
||||||
|
body, _ := json.Marshal(payload)
|
||||||
|
resp, err := doRequest(server, http.MethodPost, "/api/v1/files/uploads", bytes.NewReader(body), "application/json")
|
||||||
|
if err != nil {
|
||||||
|
fmt.Fprintln(os.Stderr, err)
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check for dedup (秒传) — response is a file, not a session
|
||||||
|
var sess sessionResponse
|
||||||
|
if err := json.Unmarshal(resp.Data, &sess); err == nil && sess.Status != "" && sess.TotalChunks > 0 {
|
||||||
|
// It's a session, proceed with chunk uploads
|
||||||
|
} else {
|
||||||
|
// Try to parse as file (dedup hit)
|
||||||
|
var fr fileResponse
|
||||||
|
if err := json.Unmarshal(resp.Data, &fr); err == nil && fr.Name != "" {
|
||||||
|
elapsed := time.Since(start)
|
||||||
|
fmt.Printf("⚡ 秒传 (instant upload)! File already exists.\n")
|
||||||
|
fmt.Printf(" ID: %d\n", fr.ID)
|
||||||
|
fmt.Printf(" Name: %s\n", fr.Name)
|
||||||
|
fmt.Printf(" Size: %s\n", formatSize(fr.Size))
|
||||||
|
fmt.Printf(" SHA256: %s\n", fr.SHA256)
|
||||||
|
fmt.Printf(" Elapsed: %s\n", elapsed.Truncate(time.Millisecond))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// If neither, just try the session path anyway
|
||||||
|
if err := json.Unmarshal(resp.Data, &sess); err != nil {
|
||||||
|
fmt.Fprintf(os.Stderr, "unexpected response: %s\n", string(resp.Data))
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
sessionID := sess.ID
|
||||||
|
totalChunks := sess.TotalChunks
|
||||||
|
fmt.Printf("Upload session created: id=%d total_chunks=%d chunk_size=%s\n",
|
||||||
|
sessionID, totalChunks, formatSize(chunkSize))
|
||||||
|
|
||||||
|
// Upload chunks (4 concurrent workers)
|
||||||
|
const uploadWorkers = 4
|
||||||
|
|
||||||
|
type chunkResult struct {
|
||||||
|
index int
|
||||||
|
err error
|
||||||
|
}
|
||||||
|
|
||||||
|
work := make(chan int, totalChunks)
|
||||||
|
results := make(chan chunkResult, totalChunks)
|
||||||
|
|
||||||
|
for i := 0; i < totalChunks; i++ {
|
||||||
|
work <- i
|
||||||
|
}
|
||||||
|
close(work)
|
||||||
|
|
||||||
|
var uploaded int64 // atomic counter for progress
|
||||||
|
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
for w := 0; w < uploadWorkers; w++ {
|
||||||
|
wg.Add(1)
|
||||||
|
go func() {
|
||||||
|
defer wg.Done()
|
||||||
|
for idx := range work {
|
||||||
|
offset := int64(idx) * chunkSize
|
||||||
|
remaining := fileSize - offset
|
||||||
|
thisChunkSize := chunkSize
|
||||||
|
if remaining < chunkSize {
|
||||||
|
thisChunkSize = remaining
|
||||||
|
}
|
||||||
|
sectionReader := io.NewSectionReader(f, offset, thisChunkSize)
|
||||||
|
|
||||||
|
var buf bytes.Buffer
|
||||||
|
writer := multipart.NewWriter(&buf)
|
||||||
|
part, err := writer.CreateFormFile("chunk", fileName)
|
||||||
|
if err != nil {
|
||||||
|
results <- chunkResult{index: idx, err: err}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if _, err := io.Copy(part, sectionReader); err != nil {
|
||||||
|
results <- chunkResult{index: idx, err: err}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
writer.Close()
|
||||||
|
|
||||||
|
chunkURL := fmt.Sprintf("%s/api/v1/files/uploads/%d/chunks/%d", server, sessionID, idx)
|
||||||
|
chunkReq, err := http.NewRequest(http.MethodPut, chunkURL, &buf)
|
||||||
|
if err != nil {
|
||||||
|
results <- chunkResult{index: idx, err: err}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
chunkReq.Header.Set("Content-Type", writer.FormDataContentType())
|
||||||
|
|
||||||
|
chunkResp, err := http.DefaultClient.Do(chunkReq)
|
||||||
|
if err != nil {
|
||||||
|
results <- chunkResult{index: idx, err: err}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
raw, _ := io.ReadAll(chunkResp.Body)
|
||||||
|
chunkResp.Body.Close()
|
||||||
|
|
||||||
|
if chunkResp.StatusCode >= 400 {
|
||||||
|
results <- chunkResult{index: idx, err: fmt.Errorf("HTTP %d: %s", chunkResp.StatusCode, string(raw))}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
var chunkAPIResp apiResponse
|
||||||
|
if err := json.Unmarshal(raw, &chunkAPIResp); err == nil && !chunkAPIResp.Success {
|
||||||
|
results <- chunkResult{index: idx, err: fmt.Errorf("%s", chunkAPIResp.Error)}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
results <- chunkResult{index: idx, err: nil}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
wg.Wait()
|
||||||
|
close(results)
|
||||||
|
}()
|
||||||
|
|
||||||
|
var firstErr error
|
||||||
|
for res := range results {
|
||||||
|
if res.err != nil {
|
||||||
|
if firstErr == nil {
|
||||||
|
firstErr = res.err
|
||||||
|
}
|
||||||
|
fmt.Fprintf(os.Stderr, "chunk %d failed: %v\n", res.index, res.err)
|
||||||
|
}
|
||||||
|
done := atomic.AddInt64(&uploaded, 1)
|
||||||
|
pct := float64(done) / float64(totalChunks) * 100
|
||||||
|
doneBytes := done * chunkSize
|
||||||
|
if doneBytes > fileSize {
|
||||||
|
doneBytes = fileSize
|
||||||
|
}
|
||||||
|
fmt.Printf("\rUploading: %.1f%% (%s / %s)", pct, formatSize(doneBytes), formatSize(fileSize))
|
||||||
|
}
|
||||||
|
fmt.Println()
|
||||||
|
|
||||||
|
if firstErr != nil {
|
||||||
|
fmt.Fprintf(os.Stderr, "upload failed: %v\n", firstErr)
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Complete upload
|
||||||
|
completeURL := fmt.Sprintf("/api/v1/files/uploads/%d/complete", sessionID)
|
||||||
|
resp, err = doRequest(server, http.MethodPost, completeURL, nil, "")
|
||||||
|
if err != nil {
|
||||||
|
fmt.Fprintln(os.Stderr, err)
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
|
||||||
|
var result fileResponse
|
||||||
|
if err := json.Unmarshal(resp.Data, &result); err != nil {
|
||||||
|
fmt.Fprintf(os.Stderr, "parsing complete response: %v\n", err)
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
|
||||||
|
elapsed := time.Since(start)
|
||||||
|
speed := float64(fileSize) / elapsed.Seconds()
|
||||||
|
fmt.Println("✓ Upload complete!")
|
||||||
|
fmt.Printf(" ID: %d\n", result.ID)
|
||||||
|
fmt.Printf(" Name: %s\n", result.Name)
|
||||||
|
fmt.Printf(" Size: %s\n", formatSize(result.Size))
|
||||||
|
fmt.Printf(" SHA256: %s\n", result.SHA256)
|
||||||
|
fmt.Printf(" Elapsed: %s\n", elapsed.Truncate(time.Millisecond))
|
||||||
|
fmt.Printf(" Speed: %s/s\n", formatSize(int64(speed)))
|
||||||
|
}
|
||||||
|
|
||||||
|
func cmdDownload(server string, args []string) {
|
||||||
|
if len(args) == 0 {
|
||||||
|
fmt.Fprintln(os.Stderr, "Usage: client -server <addr> download <file_id> [-o <output>]")
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
|
||||||
|
fileID := args[0]
|
||||||
|
var output string
|
||||||
|
for i := 1; i+1 < len(args); i += 2 {
|
||||||
|
if args[i] == "-o" {
|
||||||
|
output = args[i+1]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
start := time.Now()
|
||||||
|
|
||||||
|
url := server + "/api/v1/files/" + fileID + "/download"
|
||||||
|
resp, err := http.Get(url)
|
||||||
|
if err != nil {
|
||||||
|
fmt.Fprintf(os.Stderr, "download error: %v\n", err)
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
if resp.StatusCode >= 400 {
|
||||||
|
body, _ := io.ReadAll(resp.Body)
|
||||||
|
fmt.Fprintf(os.Stderr, "download failed (HTTP %d): %s\n", resp.StatusCode, string(body))
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Determine output filename
|
||||||
|
if output == "" {
|
||||||
|
// Try Content-Disposition header
|
||||||
|
if cd := resp.Header.Get("Content-Disposition"); cd != "" {
|
||||||
|
if idx := strings.Index(cd, "filename="); idx != -1 {
|
||||||
|
output = strings.Trim(cd[idx+9:], `"`)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if output == "" {
|
||||||
|
output = "download_" + fileID
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
outFile, err := os.Create(output)
|
||||||
|
if err != nil {
|
||||||
|
fmt.Fprintf(os.Stderr, "cannot create output file: %v\n", err)
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
defer outFile.Close()
|
||||||
|
|
||||||
|
written, err := io.Copy(outFile, resp.Body)
|
||||||
|
if err != nil {
|
||||||
|
fmt.Fprintf(os.Stderr, "download write error: %v\n", err)
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
|
||||||
|
elapsed := time.Since(start)
|
||||||
|
speed := float64(written) / elapsed.Seconds()
|
||||||
|
fmt.Println("✓ Download complete!")
|
||||||
|
fmt.Printf(" File: %s\n", output)
|
||||||
|
fmt.Printf(" Size: %s\n", formatSize(written))
|
||||||
|
fmt.Printf(" Elapsed: %s\n", elapsed.Truncate(time.Millisecond))
|
||||||
|
fmt.Printf(" Speed: %s/s\n", formatSize(int64(speed)))
|
||||||
|
}
|
||||||
|
|
||||||
|
func cmdListFiles(server string, args []string) {
|
||||||
|
var folderID, page, pageSize string
|
||||||
|
for i := 0; i+1 < len(args); i += 2 {
|
||||||
|
switch args[i] {
|
||||||
|
case "-folder":
|
||||||
|
folderID = args[i+1]
|
||||||
|
case "-page":
|
||||||
|
page = args[i+1]
|
||||||
|
case "-page-size":
|
||||||
|
pageSize = args[i+1]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
path := "/api/v1/files?"
|
||||||
|
if folderID != "" {
|
||||||
|
path += "folder_id=" + folderID + "&"
|
||||||
|
}
|
||||||
|
if page != "" {
|
||||||
|
path += "page=" + page + "&"
|
||||||
|
}
|
||||||
|
if pageSize != "" {
|
||||||
|
path += "page_size=" + pageSize + "&"
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := doRequest(server, http.MethodGet, path, nil, "")
|
||||||
|
if err != nil {
|
||||||
|
fmt.Fprintln(os.Stderr, err)
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
|
||||||
|
var list listFilesResponse
|
||||||
|
if err := json.Unmarshal(resp.Data, &list); err != nil {
|
||||||
|
fmt.Fprintf(os.Stderr, "parsing files response: %v\n", err)
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(list.Files) == 0 {
|
||||||
|
fmt.Println("No files found.")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
fmt.Printf("%-8s %-30s %-12s %-10s %s\n", "ID", "Name", "Size", "MIME", "Created")
|
||||||
|
fmt.Println(strings.Repeat("-", 90))
|
||||||
|
for _, f := range list.Files {
|
||||||
|
name := f.Name
|
||||||
|
if len(name) > 28 {
|
||||||
|
name = name[:25] + "..."
|
||||||
|
}
|
||||||
|
fmt.Printf("%-8d %-30s %-12s %-10s %s\n", f.ID, name, formatSize(f.Size), f.MimeType, f.CreatedAt)
|
||||||
|
}
|
||||||
|
fmt.Printf("\nTotal: %d Page: %d PageSize: %d\n", list.Total, list.Page, list.PageSize)
|
||||||
|
}
|
||||||
|
|
||||||
|
func cmdDeleteFile(server string, args []string) {
|
||||||
|
if len(args) == 0 {
|
||||||
|
fmt.Fprintln(os.Stderr, "Usage: client -server <addr> delete-file <file_id>")
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
fileID := args[0]
|
||||||
|
|
||||||
|
_, err := doRequest(server, http.MethodDelete, "/api/v1/files/"+fileID, nil, "")
|
||||||
|
if err != nil {
|
||||||
|
fmt.Fprintln(os.Stderr, err)
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
|
||||||
|
fmt.Printf("File %s deleted.\n", fileID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ── Main ─────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
func usage() {
|
||||||
|
fmt.Fprintf(os.Stderr, `HPC File Storage Client
|
||||||
|
|
||||||
|
Usage: client -server <addr> <command> [args]
|
||||||
|
|
||||||
|
Commands:
|
||||||
|
mkdir <name> [-parent <id>] 创建文件夹
|
||||||
|
ls-folders [-parent <id>] 列出文件夹
|
||||||
|
delete-folder <id> 删除文件夹
|
||||||
|
upload <file> [-folder <id>] [-chunk-size <bytes>] 上传文件(分片)
|
||||||
|
download <file_id> [-o <output>] 下载文件
|
||||||
|
ls-files [-folder <id>] [-page <n>] [-page-size <n>] 列出文件
|
||||||
|
delete-file <file_id> 删除文件
|
||||||
|
|
||||||
|
Global flags:
|
||||||
|
-server <addr> 服务器地址 (默认 http://localhost:8080)
|
||||||
|
`)
|
||||||
|
}
|
||||||
|
|
||||||
|
func main() {
|
||||||
|
server := "http://localhost:8080"
|
||||||
|
var command string
|
||||||
|
var cmdArgs []string
|
||||||
|
|
||||||
|
args := os.Args[1:]
|
||||||
|
i := 0
|
||||||
|
for i < len(args) {
|
||||||
|
if args[i] == "-server" && i+1 < len(args) {
|
||||||
|
server = args[i+1]
|
||||||
|
i += 2
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if command == "" {
|
||||||
|
command = args[i]
|
||||||
|
} else {
|
||||||
|
cmdArgs = append(cmdArgs, args[i])
|
||||||
|
}
|
||||||
|
i++
|
||||||
|
}
|
||||||
|
|
||||||
|
if command == "" {
|
||||||
|
usage()
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
|
||||||
|
switch command {
|
||||||
|
case "mkdir":
|
||||||
|
cmdMkdir(server, cmdArgs)
|
||||||
|
case "ls-folders":
|
||||||
|
cmdListFolders(server, cmdArgs)
|
||||||
|
case "delete-folder":
|
||||||
|
cmdDeleteFolder(server, cmdArgs)
|
||||||
|
case "upload":
|
||||||
|
cmdUpload(server, cmdArgs)
|
||||||
|
case "download":
|
||||||
|
cmdDownload(server, cmdArgs)
|
||||||
|
case "ls-files":
|
||||||
|
cmdListFiles(server, cmdArgs)
|
||||||
|
case "delete-file":
|
||||||
|
cmdDeleteFile(server, cmdArgs)
|
||||||
|
default:
|
||||||
|
fmt.Fprintf(os.Stderr, "Unknown command: %s\n", command)
|
||||||
|
usage()
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
}
|
||||||
773
cmd/server/file_test.go
Normal file
773
cmd/server/file_test.go
Normal file
@@ -0,0 +1,773 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"crypto/sha256"
|
||||||
|
"encoding/hex"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"mime/multipart"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"sort"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"gcy_hpc_server/internal/config"
|
||||||
|
"gcy_hpc_server/internal/handler"
|
||||||
|
"gcy_hpc_server/internal/model"
|
||||||
|
"gcy_hpc_server/internal/service"
|
||||||
|
"gcy_hpc_server/internal/storage"
|
||||||
|
"gcy_hpc_server/internal/store"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"go.uber.org/zap"
|
||||||
|
"gorm.io/driver/sqlite"
|
||||||
|
"gorm.io/gorm"
|
||||||
|
"gorm.io/gorm/logger"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
// In-memory ObjectStorage mock
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
type inMemoryStorage struct {
|
||||||
|
mu sync.RWMutex
|
||||||
|
objects map[string][]byte
|
||||||
|
bucket string
|
||||||
|
}
|
||||||
|
|
||||||
|
var _ storage.ObjectStorage = (*inMemoryStorage)(nil)
|
||||||
|
|
||||||
|
func (s *inMemoryStorage) PutObject(_ context.Context, _, key string, reader io.Reader, _ int64, _ storage.PutObjectOptions) (storage.UploadInfo, error) {
|
||||||
|
data, err := io.ReadAll(reader)
|
||||||
|
if err != nil {
|
||||||
|
return storage.UploadInfo{}, fmt.Errorf("read all: %w", err)
|
||||||
|
}
|
||||||
|
s.mu.Lock()
|
||||||
|
s.objects[key] = data
|
||||||
|
s.mu.Unlock()
|
||||||
|
h := sha256.Sum256(data)
|
||||||
|
return storage.UploadInfo{ETag: hex.EncodeToString(h[:]), Size: int64(len(data))}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *inMemoryStorage) GetObject(_ context.Context, _, key string, opts storage.GetOptions) (io.ReadCloser, storage.ObjectInfo, error) {
|
||||||
|
s.mu.RLock()
|
||||||
|
data, ok := s.objects[key]
|
||||||
|
s.mu.RUnlock()
|
||||||
|
if !ok {
|
||||||
|
return nil, storage.ObjectInfo{}, fmt.Errorf("object %s not found", key)
|
||||||
|
}
|
||||||
|
size := int64(len(data))
|
||||||
|
start := int64(0)
|
||||||
|
end := size - 1
|
||||||
|
if opts.Start != nil {
|
||||||
|
start = *opts.Start
|
||||||
|
}
|
||||||
|
if opts.End != nil {
|
||||||
|
end = *opts.End
|
||||||
|
}
|
||||||
|
if end >= size {
|
||||||
|
end = size - 1
|
||||||
|
}
|
||||||
|
section := io.NewSectionReader(bytes.NewReader(data), start, end-start+1)
|
||||||
|
info := storage.ObjectInfo{Key: key, Size: size}
|
||||||
|
return io.NopCloser(section), info, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *inMemoryStorage) ComposeObject(_ context.Context, _, dst string, sources []string) (storage.UploadInfo, error) {
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
var buf bytes.Buffer
|
||||||
|
for _, src := range sources {
|
||||||
|
data, ok := s.objects[src]
|
||||||
|
if !ok {
|
||||||
|
return storage.UploadInfo{}, fmt.Errorf("source object %s not found", src)
|
||||||
|
}
|
||||||
|
buf.Write(data)
|
||||||
|
}
|
||||||
|
combined := buf.Bytes()
|
||||||
|
s.objects[dst] = combined
|
||||||
|
h := sha256.Sum256(combined)
|
||||||
|
return storage.UploadInfo{ETag: hex.EncodeToString(h[:]), Size: int64(len(combined))}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *inMemoryStorage) AbortMultipartUpload(_ context.Context, _, _, _ string) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *inMemoryStorage) RemoveIncompleteUpload(_ context.Context, _, _ string) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *inMemoryStorage) RemoveObject(_ context.Context, _, key string, _ storage.RemoveObjectOptions) error {
|
||||||
|
s.mu.Lock()
|
||||||
|
delete(s.objects, key)
|
||||||
|
s.mu.Unlock()
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *inMemoryStorage) ListObjects(_ context.Context, _, prefix string, _ bool) ([]storage.ObjectInfo, error) {
|
||||||
|
s.mu.RLock()
|
||||||
|
defer s.mu.RUnlock()
|
||||||
|
var result []storage.ObjectInfo
|
||||||
|
for k, v := range s.objects {
|
||||||
|
if strings.HasPrefix(k, prefix) {
|
||||||
|
result = append(result, storage.ObjectInfo{Key: k, Size: int64(len(v))})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
sort.Slice(result, func(i, j int) bool { return result[i].Key < result[j].Key })
|
||||||
|
return result, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *inMemoryStorage) RemoveObjects(_ context.Context, _ string, keys []string, _ storage.RemoveObjectsOptions) error {
|
||||||
|
s.mu.Lock()
|
||||||
|
for _, k := range keys {
|
||||||
|
delete(s.objects, k)
|
||||||
|
}
|
||||||
|
s.mu.Unlock()
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *inMemoryStorage) BucketExists(_ context.Context, _ string) (bool, error) {
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *inMemoryStorage) MakeBucket(_ context.Context, _ string, _ storage.MakeBucketOptions) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *inMemoryStorage) StatObject(_ context.Context, _, key string, _ storage.StatObjectOptions) (storage.ObjectInfo, error) {
|
||||||
|
s.mu.RLock()
|
||||||
|
data, ok := s.objects[key]
|
||||||
|
s.mu.RUnlock()
|
||||||
|
if !ok {
|
||||||
|
return storage.ObjectInfo{}, fmt.Errorf("object %s not found", key)
|
||||||
|
}
|
||||||
|
return storage.ObjectInfo{Key: key, Size: int64(len(data))}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
// Test helpers
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
func setupFileTestRouter(t *testing.T) (*gin.Engine, *gorm.DB, *inMemoryStorage) {
|
||||||
|
t.Helper()
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
|
dbName := fmt.Sprintf("file:%s?mode=memory&cache=shared", t.Name())
|
||||||
|
db, err := gorm.Open(sqlite.Open(dbName), &gorm.Config{Logger: logger.Default.LogMode(logger.Silent)})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
sqlDB, _ := db.DB()
|
||||||
|
sqlDB.SetMaxOpenConns(1)
|
||||||
|
db.AutoMigrate(&model.FileBlob{}, &model.File{}, &model.Folder{}, &model.UploadSession{}, &model.UploadChunk{})
|
||||||
|
|
||||||
|
memStore := &inMemoryStorage{objects: make(map[string][]byte)}
|
||||||
|
|
||||||
|
blobStore := store.NewBlobStore(db)
|
||||||
|
fileStore := store.NewFileStore(db)
|
||||||
|
folderStore := store.NewFolderStore(db)
|
||||||
|
uploadStore := store.NewUploadStore(db)
|
||||||
|
|
||||||
|
cfg := config.MinioConfig{
|
||||||
|
ChunkSize: 16 << 20,
|
||||||
|
MaxFileSize: 50 << 30,
|
||||||
|
MinChunkSize: 5 << 20,
|
||||||
|
SessionTTL: 48,
|
||||||
|
Bucket: "files",
|
||||||
|
}
|
||||||
|
|
||||||
|
uploadSvc := service.NewUploadService(memStore, blobStore, fileStore, uploadStore, cfg, db, zap.NewNop())
|
||||||
|
_ = service.NewDownloadService(memStore, blobStore, fileStore, "files", zap.NewNop())
|
||||||
|
folderSvc := service.NewFolderService(folderStore, fileStore, zap.NewNop())
|
||||||
|
fileSvc := service.NewFileService(memStore, blobStore, fileStore, "files", db, zap.NewNop())
|
||||||
|
|
||||||
|
uploadH := handler.NewUploadHandler(uploadSvc, zap.NewNop())
|
||||||
|
fileH := handler.NewFileHandler(fileSvc, zap.NewNop())
|
||||||
|
folderH := handler.NewFolderHandler(folderSvc, zap.NewNop())
|
||||||
|
|
||||||
|
r := gin.New()
|
||||||
|
r.Use(gin.Recovery())
|
||||||
|
v1 := r.Group("/api/v1")
|
||||||
|
files := v1.Group("/files")
|
||||||
|
|
||||||
|
uploads := files.Group("/uploads")
|
||||||
|
uploads.POST("", uploadH.InitUpload)
|
||||||
|
uploads.GET("/:id", uploadH.GetUploadStatus)
|
||||||
|
uploads.PUT("/:id/chunks/:index", uploadH.UploadChunk)
|
||||||
|
uploads.POST("/:id/complete", uploadH.CompleteUpload)
|
||||||
|
uploads.DELETE("/:id", uploadH.CancelUpload)
|
||||||
|
|
||||||
|
files.GET("", fileH.ListFiles)
|
||||||
|
files.GET("/:id", fileH.GetFile)
|
||||||
|
files.GET("/:id/download", fileH.DownloadFile)
|
||||||
|
files.DELETE("/:id", fileH.DeleteFile)
|
||||||
|
|
||||||
|
folders := files.Group("/folders")
|
||||||
|
folders.POST("", folderH.CreateFolder)
|
||||||
|
folders.GET("", folderH.ListFolders)
|
||||||
|
folders.GET("/:id", folderH.GetFolder)
|
||||||
|
folders.DELETE("/:id", folderH.DeleteFolder)
|
||||||
|
|
||||||
|
return r, db, memStore
|
||||||
|
}
|
||||||
|
|
||||||
|
// apiResponse mirrors server.APIResponse for decoding.
|
||||||
|
type apiResponse struct {
|
||||||
|
Success bool `json:"success"`
|
||||||
|
Data json.RawMessage `json:"data,omitempty"`
|
||||||
|
Error string `json:"error,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func decodeResponse(t *testing.T, w *httptest.ResponseRecorder) apiResponse {
|
||||||
|
t.Helper()
|
||||||
|
var resp apiResponse
|
||||||
|
if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
|
||||||
|
t.Fatalf("failed to decode response: %v, body: %s", err, w.Body.String())
|
||||||
|
}
|
||||||
|
return resp
|
||||||
|
}
|
||||||
|
|
||||||
|
func createChunkRequest(t *testing.T, url string, data []byte) *http.Request {
|
||||||
|
t.Helper()
|
||||||
|
var buf bytes.Buffer
|
||||||
|
writer := multipart.NewWriter(&buf)
|
||||||
|
part, err := writer.CreateFormFile("chunk", "chunk.bin")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
part.Write(data)
|
||||||
|
writer.Close()
|
||||||
|
req, err := http.NewRequest("PUT", url, &buf)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
req.Header.Set("Content-Type", writer.FormDataContentType())
|
||||||
|
return req
|
||||||
|
}
|
||||||
|
|
||||||
|
// helperUploadFile performs a full upload lifecycle: init → upload chunks → complete.
|
||||||
|
// Returns the file ID from the completed upload response.
|
||||||
|
func helperUploadFile(t *testing.T, router *gin.Engine, fileName string, fileData []byte, sha256Hash string, folderID *int64, chunkSize int64) int64 {
|
||||||
|
t.Helper()
|
||||||
|
fileSize := int64(len(fileData))
|
||||||
|
|
||||||
|
initBody := model.InitUploadRequest{
|
||||||
|
FileName: fileName,
|
||||||
|
FileSize: fileSize,
|
||||||
|
SHA256: sha256Hash,
|
||||||
|
FolderID: folderID,
|
||||||
|
ChunkSize: &chunkSize,
|
||||||
|
}
|
||||||
|
initJSON, _ := json.Marshal(initBody)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
req, _ := http.NewRequest("POST", "/api/v1/files/uploads", bytes.NewReader(initJSON))
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
router.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
resp := decodeResponse(t, w)
|
||||||
|
if !resp.Success {
|
||||||
|
t.Fatalf("init upload failed: %s", resp.Error)
|
||||||
|
}
|
||||||
|
|
||||||
|
if w.Code == http.StatusOK {
|
||||||
|
var fileResp model.FileResponse
|
||||||
|
if err := json.Unmarshal(resp.Data, &fileResp); err != nil {
|
||||||
|
t.Fatalf("decode dedup file response: %v", err)
|
||||||
|
}
|
||||||
|
return fileResp.ID
|
||||||
|
}
|
||||||
|
|
||||||
|
if w.Code != http.StatusCreated {
|
||||||
|
t.Fatalf("init upload: expected 201, got %d: %s", w.Code, w.Body.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
var session model.UploadSessionResponse
|
||||||
|
if err := json.Unmarshal(resp.Data, &session); err != nil {
|
||||||
|
t.Fatalf("failed to decode session: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
totalChunks := session.TotalChunks
|
||||||
|
for i := 0; i < totalChunks; i++ {
|
||||||
|
start := int64(i) * chunkSize
|
||||||
|
end := start + chunkSize
|
||||||
|
if end > fileSize {
|
||||||
|
end = fileSize
|
||||||
|
}
|
||||||
|
chunkData := fileData[start:end]
|
||||||
|
url := fmt.Sprintf("/api/v1/files/uploads/%d/chunks/%d", session.ID, i)
|
||||||
|
cw := httptest.NewRecorder()
|
||||||
|
creq := createChunkRequest(t, url, chunkData)
|
||||||
|
router.ServeHTTP(cw, creq)
|
||||||
|
if cw.Code != http.StatusOK {
|
||||||
|
t.Fatalf("upload chunk %d: expected 200, got %d: %s", i, cw.Code, cw.Body.String())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
w = httptest.NewRecorder()
|
||||||
|
req, _ = http.NewRequest("POST", fmt.Sprintf("/api/v1/files/uploads/%d/complete", session.ID), nil)
|
||||||
|
router.ServeHTTP(w, req)
|
||||||
|
if w.Code != http.StatusCreated {
|
||||||
|
t.Fatalf("complete upload: expected 201, got %d: %s", w.Code, w.Body.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
resp = decodeResponse(t, w)
|
||||||
|
var fileResp model.FileResponse
|
||||||
|
if err := json.Unmarshal(resp.Data, &fileResp); err != nil {
|
||||||
|
t.Fatalf("failed to decode file response: %v", err)
|
||||||
|
}
|
||||||
|
return fileResp.ID
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
// Tests
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
func TestFileFullLifecycle(t *testing.T) {
|
||||||
|
router, _, _ := setupFileTestRouter(t)
|
||||||
|
|
||||||
|
// Create test file data
|
||||||
|
fileData := []byte("Hello, World! This is a test file for the full lifecycle integration test.")
|
||||||
|
fileSize := int64(len(fileData))
|
||||||
|
h := sha256.Sum256(fileData)
|
||||||
|
sha256Hash := hex.EncodeToString(h[:])
|
||||||
|
chunkSize := int64(5 << 20) // 5MB min chunk size
|
||||||
|
|
||||||
|
// 1. Init upload
|
||||||
|
initBody := model.InitUploadRequest{
|
||||||
|
FileName: "test.txt",
|
||||||
|
FileSize: fileSize,
|
||||||
|
SHA256: sha256Hash,
|
||||||
|
ChunkSize: &chunkSize,
|
||||||
|
}
|
||||||
|
initJSON, _ := json.Marshal(initBody)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
req, _ := http.NewRequest("POST", "/api/v1/files/uploads", bytes.NewReader(initJSON))
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
router.ServeHTTP(w, req)
|
||||||
|
if w.Code != http.StatusCreated {
|
||||||
|
t.Fatalf("init upload: expected 201, got %d: %s", w.Code, w.Body.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
resp := decodeResponse(t, w)
|
||||||
|
var session model.UploadSessionResponse
|
||||||
|
if err := json.Unmarshal(resp.Data, &session); err != nil {
|
||||||
|
t.Fatalf("failed to decode session: %v", err)
|
||||||
|
}
|
||||||
|
if session.Status != "pending" {
|
||||||
|
t.Fatalf("expected status pending, got %s", session.Status)
|
||||||
|
}
|
||||||
|
if session.TotalChunks != 1 {
|
||||||
|
t.Fatalf("expected 1 chunk, got %d", session.TotalChunks)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 2. Upload chunk 0
|
||||||
|
url := fmt.Sprintf("/api/v1/files/uploads/%d/chunks/0", session.ID)
|
||||||
|
w = httptest.NewRecorder()
|
||||||
|
req = createChunkRequest(t, url, fileData)
|
||||||
|
router.ServeHTTP(w, req)
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Fatalf("upload chunk: expected 200, got %d: %s", w.Code, w.Body.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
// 3. Complete upload
|
||||||
|
w = httptest.NewRecorder()
|
||||||
|
req, _ = http.NewRequest("POST", fmt.Sprintf("/api/v1/files/uploads/%d/complete", session.ID), nil)
|
||||||
|
router.ServeHTTP(w, req)
|
||||||
|
if w.Code != http.StatusCreated {
|
||||||
|
t.Fatalf("complete upload: expected 201, got %d: %s", w.Code, w.Body.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
resp = decodeResponse(t, w)
|
||||||
|
var fileResp model.FileResponse
|
||||||
|
if err := json.Unmarshal(resp.Data, &fileResp); err != nil {
|
||||||
|
t.Fatalf("failed to decode file response: %v", err)
|
||||||
|
}
|
||||||
|
if fileResp.Name != "test.txt" {
|
||||||
|
t.Fatalf("expected name test.txt, got %s", fileResp.Name)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 4. Download file
|
||||||
|
w = httptest.NewRecorder()
|
||||||
|
req, _ = http.NewRequest("GET", fmt.Sprintf("/api/v1/files/%d/download", fileResp.ID), nil)
|
||||||
|
router.ServeHTTP(w, req)
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Fatalf("download: expected 200, got %d: %s", w.Code, w.Body.String())
|
||||||
|
}
|
||||||
|
if !bytes.Equal(w.Body.Bytes(), fileData) {
|
||||||
|
t.Fatalf("downloaded data mismatch: got %q, want %q", w.Body.String(), string(fileData))
|
||||||
|
}
|
||||||
|
|
||||||
|
// 5. Delete file
|
||||||
|
w = httptest.NewRecorder()
|
||||||
|
req, _ = http.NewRequest("DELETE", fmt.Sprintf("/api/v1/files/%d", fileResp.ID), nil)
|
||||||
|
router.ServeHTTP(w, req)
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Fatalf("delete: expected 200, got %d: %s", w.Code, w.Body.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
// 6. Verify file is gone (download should fail)
|
||||||
|
w = httptest.NewRecorder()
|
||||||
|
req, _ = http.NewRequest("GET", fmt.Sprintf("/api/v1/files/%d/download", fileResp.ID), nil)
|
||||||
|
router.ServeHTTP(w, req)
|
||||||
|
if w.Code == http.StatusOK {
|
||||||
|
t.Fatal("expected download to fail after delete, got 200")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFileDedup(t *testing.T) {
|
||||||
|
router, db, _ := setupFileTestRouter(t)
|
||||||
|
|
||||||
|
fileData := []byte("Duplicate content for dedup test.")
|
||||||
|
h := sha256.Sum256(fileData)
|
||||||
|
sha256Hash := hex.EncodeToString(h[:])
|
||||||
|
chunkSize := int64(5 << 20)
|
||||||
|
|
||||||
|
// Upload file A (full lifecycle)
|
||||||
|
fileAID := helperUploadFile(t, router, "fileA.txt", fileData, sha256Hash, nil, chunkSize)
|
||||||
|
if fileAID == 0 {
|
||||||
|
t.Fatal("file A ID should not be 0")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Upload file B with same SHA256 → should be instant dedup
|
||||||
|
fileBID := helperUploadFile(t, router, "fileB.txt", fileData, sha256Hash, nil, chunkSize)
|
||||||
|
if fileBID == 0 {
|
||||||
|
t.Fatal("file B ID should not be 0")
|
||||||
|
}
|
||||||
|
if fileBID == fileAID {
|
||||||
|
t.Fatal("file B should have a different ID than file A")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify blob ref_count = 2
|
||||||
|
var blob model.FileBlob
|
||||||
|
if err := db.Where("sha256 = ?", sha256Hash).First(&blob).Error; err != nil {
|
||||||
|
t.Fatalf("blob not found: %v", err)
|
||||||
|
}
|
||||||
|
if blob.RefCount != 2 {
|
||||||
|
t.Fatalf("expected ref_count 2, got %d", blob.RefCount)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Both files should be downloadable
|
||||||
|
for _, id := range []int64{fileAID, fileBID} {
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
req, _ := http.NewRequest("GET", fmt.Sprintf("/api/v1/files/%d/download", id), nil)
|
||||||
|
router.ServeHTTP(w, req)
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Fatalf("download file %d: expected 200, got %d", id, w.Code)
|
||||||
|
}
|
||||||
|
if !bytes.Equal(w.Body.Bytes(), fileData) {
|
||||||
|
t.Fatalf("downloaded data mismatch for file %d", id)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Delete file A — file B should still be downloadable
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
req, _ := http.NewRequest("DELETE", fmt.Sprintf("/api/v1/files/%d", fileAID), nil)
|
||||||
|
router.ServeHTTP(w, req)
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Fatalf("delete file A: expected 200, got %d: %s", w.Code, w.Body.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify blob still exists with ref_count = 1
|
||||||
|
if err := db.Where("sha256 = ?", sha256Hash).First(&blob).Error; err != nil {
|
||||||
|
t.Fatalf("blob should still exist: %v", err)
|
||||||
|
}
|
||||||
|
if blob.RefCount != 2 {
|
||||||
|
t.Fatalf("expected ref_count 2 (blob still shared), got %d", blob.RefCount)
|
||||||
|
}
|
||||||
|
|
||||||
|
// File B still downloadable
|
||||||
|
w = httptest.NewRecorder()
|
||||||
|
req, _ = http.NewRequest("GET", fmt.Sprintf("/api/v1/files/%d/download", fileBID), nil)
|
||||||
|
router.ServeHTTP(w, req)
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Fatalf("file B should still be downloadable, got %d", w.Code)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFileResumeUpload(t *testing.T) {
|
||||||
|
router, _, _ := setupFileTestRouter(t)
|
||||||
|
|
||||||
|
// Create data that spans 2 chunks (min chunk = 5MB, use that)
|
||||||
|
chunkSize := int64(5 << 20) // 5MB
|
||||||
|
data1 := bytes.Repeat([]byte("A"), int(chunkSize))
|
||||||
|
data2 := bytes.Repeat([]byte("B"), int(chunkSize))
|
||||||
|
fileData := append(data1, data2...)
|
||||||
|
fileSize := int64(len(fileData))
|
||||||
|
h := sha256.Sum256(fileData)
|
||||||
|
sha256Hash := hex.EncodeToString(h[:])
|
||||||
|
|
||||||
|
// 1. Init upload
|
||||||
|
initBody := model.InitUploadRequest{
|
||||||
|
FileName: "resume.bin",
|
||||||
|
FileSize: fileSize,
|
||||||
|
SHA256: sha256Hash,
|
||||||
|
ChunkSize: &chunkSize,
|
||||||
|
}
|
||||||
|
initJSON, _ := json.Marshal(initBody)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
req, _ := http.NewRequest("POST", "/api/v1/files/uploads", bytes.NewReader(initJSON))
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
router.ServeHTTP(w, req)
|
||||||
|
if w.Code != http.StatusCreated {
|
||||||
|
t.Fatalf("init: expected 201, got %d: %s", w.Code, w.Body.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
resp := decodeResponse(t, w)
|
||||||
|
var session model.UploadSessionResponse
|
||||||
|
if err := json.Unmarshal(resp.Data, &session); err != nil {
|
||||||
|
t.Fatalf("decode session: %v", err)
|
||||||
|
}
|
||||||
|
if session.TotalChunks != 2 {
|
||||||
|
t.Fatalf("expected 2 chunks, got %d", session.TotalChunks)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 2. Upload chunk 0 only
|
||||||
|
w = httptest.NewRecorder()
|
||||||
|
req = createChunkRequest(t, fmt.Sprintf("/api/v1/files/uploads/%d/chunks/0", session.ID), data1)
|
||||||
|
router.ServeHTTP(w, req)
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Fatalf("upload chunk 0: expected 200, got %d: %s", w.Code, w.Body.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
// 3. Get status — should show chunk 0 uploaded
|
||||||
|
w = httptest.NewRecorder()
|
||||||
|
req, _ = http.NewRequest("GET", fmt.Sprintf("/api/v1/files/uploads/%d", session.ID), nil)
|
||||||
|
router.ServeHTTP(w, req)
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Fatalf("get status: expected 200, got %d: %s", w.Code, w.Body.String())
|
||||||
|
}
|
||||||
|
resp = decodeResponse(t, w)
|
||||||
|
var status model.UploadSessionResponse
|
||||||
|
if err := json.Unmarshal(resp.Data, &status); err != nil {
|
||||||
|
t.Fatalf("decode status: %v", err)
|
||||||
|
}
|
||||||
|
if len(status.UploadedChunks) != 1 || status.UploadedChunks[0] != 0 {
|
||||||
|
t.Fatalf("expected uploaded_chunks=[0], got %v", status.UploadedChunks)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 4. Upload chunk 1 (resume)
|
||||||
|
w = httptest.NewRecorder()
|
||||||
|
req = createChunkRequest(t, fmt.Sprintf("/api/v1/files/uploads/%d/chunks/1", session.ID), data2)
|
||||||
|
router.ServeHTTP(w, req)
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Fatalf("upload chunk 1: expected 200, got %d: %s", w.Code, w.Body.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
// 5. Complete
|
||||||
|
w = httptest.NewRecorder()
|
||||||
|
req, _ = http.NewRequest("POST", fmt.Sprintf("/api/v1/files/uploads/%d/complete", session.ID), nil)
|
||||||
|
router.ServeHTTP(w, req)
|
||||||
|
if w.Code != http.StatusCreated {
|
||||||
|
t.Fatalf("complete: expected 201, got %d: %s", w.Code, w.Body.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
resp = decodeResponse(t, w)
|
||||||
|
var fileResp model.FileResponse
|
||||||
|
if err := json.Unmarshal(resp.Data, &fileResp); err != nil {
|
||||||
|
t.Fatalf("decode file response: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 6. Download and verify
|
||||||
|
w = httptest.NewRecorder()
|
||||||
|
req, _ = http.NewRequest("GET", fmt.Sprintf("/api/v1/files/%d/download", fileResp.ID), nil)
|
||||||
|
router.ServeHTTP(w, req)
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Fatalf("download: expected 200, got %d: %s", w.Code, w.Body.String())
|
||||||
|
}
|
||||||
|
if !bytes.Equal(w.Body.Bytes(), fileData) {
|
||||||
|
t.Fatal("downloaded data does not match original")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFileFolderOperations(t *testing.T) {
|
||||||
|
router, _, _ := setupFileTestRouter(t)
|
||||||
|
|
||||||
|
// 1. Create folder
|
||||||
|
folderBody := model.CreateFolderRequest{Name: "test-folder"}
|
||||||
|
folderJSON, _ := json.Marshal(folderBody)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
req, _ := http.NewRequest("POST", "/api/v1/files/folders", bytes.NewReader(folderJSON))
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
router.ServeHTTP(w, req)
|
||||||
|
if w.Code != http.StatusCreated {
|
||||||
|
t.Fatalf("create folder: expected 201, got %d: %s", w.Code, w.Body.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
resp := decodeResponse(t, w)
|
||||||
|
var folderResp model.FolderResponse
|
||||||
|
if err := json.Unmarshal(resp.Data, &folderResp); err != nil {
|
||||||
|
t.Fatalf("decode folder: %v", err)
|
||||||
|
}
|
||||||
|
if folderResp.Name != "test-folder" {
|
||||||
|
t.Fatalf("expected name test-folder, got %s", folderResp.Name)
|
||||||
|
}
|
||||||
|
if folderResp.Path != "/test-folder/" {
|
||||||
|
t.Fatalf("expected path /test-folder/, got %s", folderResp.Path)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 2. Upload a file into the folder
|
||||||
|
fileData := []byte("File inside folder")
|
||||||
|
h := sha256.Sum256(fileData)
|
||||||
|
sha256Hash := hex.EncodeToString(h[:])
|
||||||
|
chunkSize := int64(5 << 20)
|
||||||
|
|
||||||
|
fileID := helperUploadFile(t, router, "folder_file.txt", fileData, sha256Hash, &folderResp.ID, chunkSize)
|
||||||
|
if fileID == 0 {
|
||||||
|
t.Fatal("file ID should not be 0")
|
||||||
|
}
|
||||||
|
|
||||||
|
// 3. List files in folder
|
||||||
|
w = httptest.NewRecorder()
|
||||||
|
req, _ = http.NewRequest("GET", fmt.Sprintf("/api/v1/files?folder_id=%d", folderResp.ID), nil)
|
||||||
|
router.ServeHTTP(w, req)
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Fatalf("list files: expected 200, got %d: %s", w.Code, w.Body.String())
|
||||||
|
}
|
||||||
|
resp = decodeResponse(t, w)
|
||||||
|
var listResp model.ListFilesResponse
|
||||||
|
if err := json.Unmarshal(resp.Data, &listResp); err != nil {
|
||||||
|
t.Fatalf("decode list: %v", err)
|
||||||
|
}
|
||||||
|
if listResp.Total != 1 {
|
||||||
|
t.Fatalf("expected 1 file in folder, got %d", listResp.Total)
|
||||||
|
}
|
||||||
|
if listResp.Files[0].Name != "folder_file.txt" {
|
||||||
|
t.Fatalf("expected file name folder_file.txt, got %s", listResp.Files[0].Name)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 4. List folders (root)
|
||||||
|
w = httptest.NewRecorder()
|
||||||
|
req, _ = http.NewRequest("GET", "/api/v1/files/folders", nil)
|
||||||
|
router.ServeHTTP(w, req)
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Fatalf("list folders: expected 200, got %d: %s", w.Code, w.Body.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
// 5. Try delete folder (should fail — not empty)
|
||||||
|
w = httptest.NewRecorder()
|
||||||
|
req, _ = http.NewRequest("DELETE", fmt.Sprintf("/api/v1/files/folders/%d", folderResp.ID), nil)
|
||||||
|
router.ServeHTTP(w, req)
|
||||||
|
if w.Code != http.StatusBadRequest {
|
||||||
|
t.Fatalf("delete non-empty folder: expected 400, got %d: %s", w.Code, w.Body.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
// 6. Delete the file first
|
||||||
|
w = httptest.NewRecorder()
|
||||||
|
req, _ = http.NewRequest("DELETE", fmt.Sprintf("/api/v1/files/%d", fileID), nil)
|
||||||
|
router.ServeHTTP(w, req)
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Fatalf("delete file: expected 200, got %d: %s", w.Code, w.Body.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
// 7. Now delete empty folder
|
||||||
|
w = httptest.NewRecorder()
|
||||||
|
req, _ = http.NewRequest("DELETE", fmt.Sprintf("/api/v1/files/folders/%d", folderResp.ID), nil)
|
||||||
|
router.ServeHTTP(w, req)
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Fatalf("delete empty folder: expected 200, got %d: %s", w.Code, w.Body.String())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFileRangeDownload(t *testing.T) {
|
||||||
|
router, _, _ := setupFileTestRouter(t)
|
||||||
|
|
||||||
|
fileData := []byte("0123456789ABCDEF") // 16 bytes
|
||||||
|
h := sha256.Sum256(fileData)
|
||||||
|
sha256Hash := hex.EncodeToString(h[:])
|
||||||
|
chunkSize := int64(5 << 20)
|
||||||
|
|
||||||
|
fileID := helperUploadFile(t, router, "range_test.bin", fileData, sha256Hash, nil, chunkSize)
|
||||||
|
|
||||||
|
// Download range bytes=4-9 → "456789"
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
req, _ := http.NewRequest("GET", fmt.Sprintf("/api/v1/files/%d/download", fileID), nil)
|
||||||
|
req.Header.Set("Range", "bytes=4-9")
|
||||||
|
router.ServeHTTP(w, req)
|
||||||
|
if w.Code != http.StatusPartialContent {
|
||||||
|
t.Fatalf("range download: expected 206, got %d: %s", w.Code, w.Body.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
contentRange := w.Header().Get("Content-Range")
|
||||||
|
expectedRange := fmt.Sprintf("bytes 4-9/%d", len(fileData))
|
||||||
|
if contentRange != expectedRange {
|
||||||
|
t.Fatalf("content-range: expected %q, got %q", expectedRange, contentRange)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !bytes.Equal(w.Body.Bytes(), []byte("456789")) {
|
||||||
|
t.Fatalf("range content: expected '456789', got %q", w.Body.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
// Full download still works
|
||||||
|
w = httptest.NewRecorder()
|
||||||
|
req, _ = http.NewRequest("GET", fmt.Sprintf("/api/v1/files/%d/download", fileID), nil)
|
||||||
|
router.ServeHTTP(w, req)
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Fatalf("full download: expected 200, got %d", w.Code)
|
||||||
|
}
|
||||||
|
if !bytes.Equal(w.Body.Bytes(), fileData) {
|
||||||
|
t.Fatalf("full download data mismatch")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFileDeleteOneRefOtherStillDownloadable(t *testing.T) {
|
||||||
|
router, _, _ := setupFileTestRouter(t)
|
||||||
|
|
||||||
|
fileData := []byte("Shared blob content for multi-ref test")
|
||||||
|
h := sha256.Sum256(fileData)
|
||||||
|
sha256Hash := hex.EncodeToString(h[:])
|
||||||
|
chunkSize := int64(5 << 20)
|
||||||
|
|
||||||
|
// Upload file A
|
||||||
|
fileAID := helperUploadFile(t, router, "refA.txt", fileData, sha256Hash, nil, chunkSize)
|
||||||
|
|
||||||
|
// Upload file B with same content → dedup instant
|
||||||
|
fileBID := helperUploadFile(t, router, "refB.txt", fileData, sha256Hash, nil, chunkSize)
|
||||||
|
|
||||||
|
// Delete file A
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
req, _ := http.NewRequest("DELETE", fmt.Sprintf("/api/v1/files/%d", fileAID), nil)
|
||||||
|
router.ServeHTTP(w, req)
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Fatalf("delete file A: expected 200, got %d: %s", w.Code, w.Body.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify file A is gone
|
||||||
|
w = httptest.NewRecorder()
|
||||||
|
req, _ = http.NewRequest("GET", fmt.Sprintf("/api/v1/files/%d", fileAID), nil)
|
||||||
|
router.ServeHTTP(w, req)
|
||||||
|
// File is soft-deleted, but GetByID may still find it depending on soft-delete handling
|
||||||
|
// The important check is that file B is still downloadable
|
||||||
|
|
||||||
|
// File B should still be downloadable
|
||||||
|
w = httptest.NewRecorder()
|
||||||
|
req, _ = http.NewRequest("GET", fmt.Sprintf("/api/v1/files/%d/download", fileBID), nil)
|
||||||
|
router.ServeHTTP(w, req)
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Fatalf("file B should still be downloadable after deleting file A, got %d: %s", w.Code, w.Body.String())
|
||||||
|
}
|
||||||
|
if !bytes.Equal(w.Body.Bytes(), fileData) {
|
||||||
|
t.Fatal("file B download data mismatch")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Delete file B — now blob should be fully removed
|
||||||
|
w = httptest.NewRecorder()
|
||||||
|
req, _ = http.NewRequest("DELETE", fmt.Sprintf("/api/v1/files/%d", fileBID), nil)
|
||||||
|
router.ServeHTTP(w, req)
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Fatalf("delete file B: expected 200, got %d: %s", w.Code, w.Body.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
// File B should now be gone
|
||||||
|
w = httptest.NewRecorder()
|
||||||
|
req, _ = http.NewRequest("GET", fmt.Sprintf("/api/v1/files/%d/download", fileBID), nil)
|
||||||
|
router.ServeHTTP(w, req)
|
||||||
|
if w.Code == http.StatusOK {
|
||||||
|
t.Fatal("file B should not be downloadable after deletion")
|
||||||
|
}
|
||||||
|
}
|
||||||
257
cmd/server/integration_app_test.go
Normal file
257
cmd/server/integration_app_test.go
Normal file
@@ -0,0 +1,257 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"gcy_hpc_server/internal/testutil/testenv"
|
||||||
|
)
|
||||||
|
|
||||||
|
// appListData mirrors the list endpoint response data structure.
|
||||||
|
type appListData struct {
|
||||||
|
Applications []json.RawMessage `json:"applications"`
|
||||||
|
Total int64 `json:"total"`
|
||||||
|
Page int `json:"page"`
|
||||||
|
PageSize int `json:"page_size"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// appCreatedData mirrors the create endpoint response data structure.
|
||||||
|
type appCreatedData struct {
|
||||||
|
ID int64 `json:"id"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// appMessageData mirrors the update/delete endpoint response data structure.
|
||||||
|
type appMessageData struct {
|
||||||
|
Message string `json:"message"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// appData mirrors the application model returned by GET.
|
||||||
|
type appData struct {
|
||||||
|
ID int64 `json:"id"`
|
||||||
|
Name string `json:"name"`
|
||||||
|
ScriptTemplate string `json:"script_template"`
|
||||||
|
Parameters json.RawMessage `json:"parameters,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// appDoRequest is a small wrapper that marshals body and calls env.DoRequest.
|
||||||
|
func appDoRequest(env *testenv.TestEnv, method, path string, body interface{}) *http.Response {
|
||||||
|
var r io.Reader
|
||||||
|
if body != nil {
|
||||||
|
b, err := json.Marshal(body)
|
||||||
|
if err != nil {
|
||||||
|
panic(fmt.Sprintf("appDoRequest marshal: %v", err))
|
||||||
|
}
|
||||||
|
r = bytes.NewReader(b)
|
||||||
|
}
|
||||||
|
return env.DoRequest(method, path, r)
|
||||||
|
}
|
||||||
|
|
||||||
|
// appDecodeAll decodes the response and also reads the HTTP status.
|
||||||
|
func appDecodeAll(env *testenv.TestEnv, resp *http.Response) (statusCode int, success bool, data json.RawMessage, err error) {
|
||||||
|
statusCode = resp.StatusCode
|
||||||
|
success, data, err = env.DecodeResponse(resp)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// appSeedApp creates an app via the service (bypasses HTTP) and returns its ID.
|
||||||
|
func appSeedApp(env *testenv.TestEnv, name string) int64 {
|
||||||
|
id, err := env.CreateApp(name, "#!/bin/bash\necho hello", json.RawMessage(`[]`))
|
||||||
|
if err != nil {
|
||||||
|
panic(fmt.Sprintf("appSeedApp: %v", err))
|
||||||
|
}
|
||||||
|
return id
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestIntegration_App_List verifies GET /api/v1/applications returns an empty list initially.
|
||||||
|
func TestIntegration_App_List(t *testing.T) {
|
||||||
|
env := testenv.NewTestEnv(t)
|
||||||
|
|
||||||
|
resp := env.DoRequest(http.MethodGet, "/api/v1/applications", nil)
|
||||||
|
status, success, data, err := appDecodeAll(env, resp)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("decode response: %v", err)
|
||||||
|
}
|
||||||
|
if status != http.StatusOK {
|
||||||
|
t.Fatalf("expected status 200, got %d", status)
|
||||||
|
}
|
||||||
|
if !success {
|
||||||
|
t.Fatal("expected success=true")
|
||||||
|
}
|
||||||
|
|
||||||
|
var list appListData
|
||||||
|
if err := json.Unmarshal(data, &list); err != nil {
|
||||||
|
t.Fatalf("unmarshal list data: %v", err)
|
||||||
|
}
|
||||||
|
if list.Total != 0 {
|
||||||
|
t.Fatalf("expected total=0, got %d", list.Total)
|
||||||
|
}
|
||||||
|
if len(list.Applications) != 0 {
|
||||||
|
t.Fatalf("expected 0 applications, got %d", len(list.Applications))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestIntegration_App_Create verifies POST /api/v1/applications creates an application.
|
||||||
|
func TestIntegration_App_Create(t *testing.T) {
|
||||||
|
env := testenv.NewTestEnv(t)
|
||||||
|
|
||||||
|
body := map[string]interface{}{
|
||||||
|
"name": "test-app-create",
|
||||||
|
"script_template": "#!/bin/bash\necho hello",
|
||||||
|
"parameters": []interface{}{},
|
||||||
|
}
|
||||||
|
resp := appDoRequest(env, http.MethodPost, "/api/v1/applications", body)
|
||||||
|
status, success, data, err := appDecodeAll(env, resp)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("decode response: %v", err)
|
||||||
|
}
|
||||||
|
if status != http.StatusCreated {
|
||||||
|
t.Fatalf("expected status 201, got %d", status)
|
||||||
|
}
|
||||||
|
if !success {
|
||||||
|
t.Fatal("expected success=true")
|
||||||
|
}
|
||||||
|
|
||||||
|
var created appCreatedData
|
||||||
|
if err := json.Unmarshal(data, &created); err != nil {
|
||||||
|
t.Fatalf("unmarshal created data: %v", err)
|
||||||
|
}
|
||||||
|
if created.ID <= 0 {
|
||||||
|
t.Fatalf("expected positive id, got %d", created.ID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestIntegration_App_Get verifies GET /api/v1/applications/:id returns the correct application.
|
||||||
|
func TestIntegration_App_Get(t *testing.T) {
|
||||||
|
env := testenv.NewTestEnv(t)
|
||||||
|
id := appSeedApp(env, "test-app-get")
|
||||||
|
|
||||||
|
path := fmt.Sprintf("/api/v1/applications/%d", id)
|
||||||
|
resp := env.DoRequest(http.MethodGet, path, nil)
|
||||||
|
status, success, data, err := appDecodeAll(env, resp)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("decode response: %v", err)
|
||||||
|
}
|
||||||
|
if status != http.StatusOK {
|
||||||
|
t.Fatalf("expected status 200, got %d", status)
|
||||||
|
}
|
||||||
|
if !success {
|
||||||
|
t.Fatal("expected success=true")
|
||||||
|
}
|
||||||
|
|
||||||
|
var app appData
|
||||||
|
if err := json.Unmarshal(data, &app); err != nil {
|
||||||
|
t.Fatalf("unmarshal app data: %v", err)
|
||||||
|
}
|
||||||
|
if app.ID != id {
|
||||||
|
t.Fatalf("expected id=%d, got %d", id, app.ID)
|
||||||
|
}
|
||||||
|
if app.Name != "test-app-get" {
|
||||||
|
t.Fatalf("expected name=test-app-get, got %s", app.Name)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestIntegration_App_Update verifies PUT /api/v1/applications/:id updates an application.
|
||||||
|
func TestIntegration_App_Update(t *testing.T) {
|
||||||
|
env := testenv.NewTestEnv(t)
|
||||||
|
id := appSeedApp(env, "test-app-update-before")
|
||||||
|
|
||||||
|
newName := "test-app-update-after"
|
||||||
|
body := map[string]interface{}{
|
||||||
|
"name": newName,
|
||||||
|
}
|
||||||
|
path := fmt.Sprintf("/api/v1/applications/%d", id)
|
||||||
|
resp := appDoRequest(env, http.MethodPut, path, body)
|
||||||
|
status, success, data, err := appDecodeAll(env, resp)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("decode response: %v", err)
|
||||||
|
}
|
||||||
|
if status != http.StatusOK {
|
||||||
|
t.Fatalf("expected status 200, got %d", status)
|
||||||
|
}
|
||||||
|
if !success {
|
||||||
|
t.Fatal("expected success=true")
|
||||||
|
}
|
||||||
|
|
||||||
|
var msg appMessageData
|
||||||
|
if err := json.Unmarshal(data, &msg); err != nil {
|
||||||
|
t.Fatalf("unmarshal message data: %v", err)
|
||||||
|
}
|
||||||
|
if msg.Message != "application updated" {
|
||||||
|
t.Fatalf("expected message 'application updated', got %q", msg.Message)
|
||||||
|
}
|
||||||
|
|
||||||
|
getResp := env.DoRequest(http.MethodGet, path, nil)
|
||||||
|
_, _, getData, gErr := appDecodeAll(env, getResp)
|
||||||
|
if gErr != nil {
|
||||||
|
t.Fatalf("decode get response: %v", gErr)
|
||||||
|
}
|
||||||
|
var updated appData
|
||||||
|
if err := json.Unmarshal(getData, &updated); err != nil {
|
||||||
|
t.Fatalf("unmarshal updated app: %v", err)
|
||||||
|
}
|
||||||
|
if updated.Name != newName {
|
||||||
|
t.Fatalf("expected updated name=%q, got %q", newName, updated.Name)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestIntegration_App_Delete verifies DELETE /api/v1/applications/:id removes an application.
|
||||||
|
func TestIntegration_App_Delete(t *testing.T) {
|
||||||
|
env := testenv.NewTestEnv(t)
|
||||||
|
id := appSeedApp(env, "test-app-delete")
|
||||||
|
|
||||||
|
path := fmt.Sprintf("/api/v1/applications/%d", id)
|
||||||
|
resp := env.DoRequest(http.MethodDelete, path, nil)
|
||||||
|
status, success, data, err := appDecodeAll(env, resp)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("decode response: %v", err)
|
||||||
|
}
|
||||||
|
if status != http.StatusOK {
|
||||||
|
t.Fatalf("expected status 200, got %d", status)
|
||||||
|
}
|
||||||
|
if !success {
|
||||||
|
t.Fatal("expected success=true")
|
||||||
|
}
|
||||||
|
|
||||||
|
var msg appMessageData
|
||||||
|
if err := json.Unmarshal(data, &msg); err != nil {
|
||||||
|
t.Fatalf("unmarshal message data: %v", err)
|
||||||
|
}
|
||||||
|
if msg.Message != "application deleted" {
|
||||||
|
t.Fatalf("expected message 'application deleted', got %q", msg.Message)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify deletion returns 404.
|
||||||
|
getResp := env.DoRequest(http.MethodGet, path, nil)
|
||||||
|
getStatus, getSuccess, _, _ := appDecodeAll(env, getResp)
|
||||||
|
if getStatus != http.StatusNotFound {
|
||||||
|
t.Fatalf("expected status 404 after delete, got %d", getStatus)
|
||||||
|
}
|
||||||
|
if getSuccess {
|
||||||
|
t.Fatal("expected success=false after delete")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestIntegration_App_CreateValidation verifies POST /api/v1/applications with empty name returns error.
|
||||||
|
func TestIntegration_App_CreateValidation(t *testing.T) {
|
||||||
|
env := testenv.NewTestEnv(t)
|
||||||
|
|
||||||
|
body := map[string]interface{}{
|
||||||
|
"name": "",
|
||||||
|
"script_template": "#!/bin/bash\necho hello",
|
||||||
|
}
|
||||||
|
resp := appDoRequest(env, http.MethodPost, "/api/v1/applications", body)
|
||||||
|
status, success, _, err := appDecodeAll(env, resp)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("decode response: %v", err)
|
||||||
|
}
|
||||||
|
if status != http.StatusBadRequest {
|
||||||
|
t.Fatalf("expected status 400, got %d", status)
|
||||||
|
}
|
||||||
|
if success {
|
||||||
|
t.Fatal("expected success=false for validation error")
|
||||||
|
}
|
||||||
|
}
|
||||||
186
cmd/server/integration_cluster_test.go
Normal file
186
cmd/server/integration_cluster_test.go
Normal file
@@ -0,0 +1,186 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"net/http"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"gcy_hpc_server/internal/testutil/testenv"
|
||||||
|
)
|
||||||
|
|
||||||
|
// clusterNodeData mirrors the NodeResponse DTO returned by the API.
|
||||||
|
type clusterNodeData struct {
|
||||||
|
Name string `json:"name"`
|
||||||
|
State []string `json:"state"`
|
||||||
|
CPUs int32 `json:"cpus"`
|
||||||
|
RealMemory int64 `json:"real_memory"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// clusterPartitionData mirrors the PartitionResponse DTO returned by the API.
|
||||||
|
type clusterPartitionData struct {
|
||||||
|
Name string `json:"name"`
|
||||||
|
State []string `json:"state"`
|
||||||
|
TotalNodes int32 `json:"total_nodes,omitempty"`
|
||||||
|
TotalCPUs int32 `json:"total_cpus,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// clusterDiagStat mirrors a single entry from the diag statistics.
|
||||||
|
type clusterDiagStat struct {
|
||||||
|
Parts []struct {
|
||||||
|
Param string `json:"param"`
|
||||||
|
} `json:"parts,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// clusterDecodeAll decodes the response and returns status, success, and raw data.
|
||||||
|
func clusterDecodeAll(env *testenv.TestEnv, resp *http.Response) (statusCode int, success bool, data json.RawMessage, err error) {
|
||||||
|
statusCode = resp.StatusCode
|
||||||
|
success, data, err = env.DecodeResponse(resp)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestIntegration_Cluster_Nodes verifies GET /api/v1/nodes returns the 3 pre-loaded mock nodes.
|
||||||
|
func TestIntegration_Cluster_Nodes(t *testing.T) {
|
||||||
|
env := testenv.NewTestEnv(t)
|
||||||
|
|
||||||
|
resp := env.DoRequest(http.MethodGet, "/api/v1/nodes", nil)
|
||||||
|
status, success, data, err := clusterDecodeAll(env, resp)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("decode response: %v", err)
|
||||||
|
}
|
||||||
|
if status != http.StatusOK {
|
||||||
|
t.Fatalf("expected status 200, got %d", status)
|
||||||
|
}
|
||||||
|
if !success {
|
||||||
|
t.Fatal("expected success=true")
|
||||||
|
}
|
||||||
|
|
||||||
|
var nodes []clusterNodeData
|
||||||
|
if err := json.Unmarshal(data, &nodes); err != nil {
|
||||||
|
t.Fatalf("unmarshal nodes: %v", err)
|
||||||
|
}
|
||||||
|
if len(nodes) != 3 {
|
||||||
|
t.Fatalf("expected 3 nodes, got %d", len(nodes))
|
||||||
|
}
|
||||||
|
|
||||||
|
names := make(map[string]bool, len(nodes))
|
||||||
|
for _, n := range nodes {
|
||||||
|
names[n.Name] = true
|
||||||
|
}
|
||||||
|
for _, expected := range []string{"node01", "node02", "node03"} {
|
||||||
|
if !names[expected] {
|
||||||
|
t.Errorf("missing expected node %q", expected)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestIntegration_Cluster_NodeByName verifies GET /api/v1/nodes/:name returns a single node.
|
||||||
|
func TestIntegration_Cluster_NodeByName(t *testing.T) {
|
||||||
|
env := testenv.NewTestEnv(t)
|
||||||
|
|
||||||
|
resp := env.DoRequest(http.MethodGet, "/api/v1/nodes/node01", nil)
|
||||||
|
status, success, data, err := clusterDecodeAll(env, resp)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("decode response: %v", err)
|
||||||
|
}
|
||||||
|
if status != http.StatusOK {
|
||||||
|
t.Fatalf("expected status 200, got %d", status)
|
||||||
|
}
|
||||||
|
if !success {
|
||||||
|
t.Fatal("expected success=true")
|
||||||
|
}
|
||||||
|
|
||||||
|
var node clusterNodeData
|
||||||
|
if err := json.Unmarshal(data, &node); err != nil {
|
||||||
|
t.Fatalf("unmarshal node: %v", err)
|
||||||
|
}
|
||||||
|
if node.Name != "node01" {
|
||||||
|
t.Fatalf("expected name=node01, got %q", node.Name)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestIntegration_Cluster_Partitions verifies GET /api/v1/partitions returns the 2 pre-loaded partitions.
|
||||||
|
func TestIntegration_Cluster_Partitions(t *testing.T) {
|
||||||
|
env := testenv.NewTestEnv(t)
|
||||||
|
|
||||||
|
resp := env.DoRequest(http.MethodGet, "/api/v1/partitions", nil)
|
||||||
|
status, success, data, err := clusterDecodeAll(env, resp)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("decode response: %v", err)
|
||||||
|
}
|
||||||
|
if status != http.StatusOK {
|
||||||
|
t.Fatalf("expected status 200, got %d", status)
|
||||||
|
}
|
||||||
|
if !success {
|
||||||
|
t.Fatal("expected success=true")
|
||||||
|
}
|
||||||
|
|
||||||
|
var partitions []clusterPartitionData
|
||||||
|
if err := json.Unmarshal(data, &partitions); err != nil {
|
||||||
|
t.Fatalf("unmarshal partitions: %v", err)
|
||||||
|
}
|
||||||
|
if len(partitions) != 2 {
|
||||||
|
t.Fatalf("expected 2 partitions, got %d", len(partitions))
|
||||||
|
}
|
||||||
|
|
||||||
|
names := make(map[string]bool, len(partitions))
|
||||||
|
for _, p := range partitions {
|
||||||
|
names[p.Name] = true
|
||||||
|
}
|
||||||
|
if !names["normal"] {
|
||||||
|
t.Error("missing expected partition \"normal\"")
|
||||||
|
}
|
||||||
|
if !names["gpu"] {
|
||||||
|
t.Error("missing expected partition \"gpu\"")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestIntegration_Cluster_PartitionByName verifies GET /api/v1/partitions/:name returns a single partition.
|
||||||
|
func TestIntegration_Cluster_PartitionByName(t *testing.T) {
|
||||||
|
env := testenv.NewTestEnv(t)
|
||||||
|
|
||||||
|
resp := env.DoRequest(http.MethodGet, "/api/v1/partitions/normal", nil)
|
||||||
|
status, success, data, err := clusterDecodeAll(env, resp)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("decode response: %v", err)
|
||||||
|
}
|
||||||
|
if status != http.StatusOK {
|
||||||
|
t.Fatalf("expected status 200, got %d", status)
|
||||||
|
}
|
||||||
|
if !success {
|
||||||
|
t.Fatal("expected success=true")
|
||||||
|
}
|
||||||
|
|
||||||
|
var part clusterPartitionData
|
||||||
|
if err := json.Unmarshal(data, &part); err != nil {
|
||||||
|
t.Fatalf("unmarshal partition: %v", err)
|
||||||
|
}
|
||||||
|
if part.Name != "normal" {
|
||||||
|
t.Fatalf("expected name=normal, got %q", part.Name)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestIntegration_Cluster_Diag verifies GET /api/v1/diag returns diagnostics data.
|
||||||
|
func TestIntegration_Cluster_Diag(t *testing.T) {
|
||||||
|
env := testenv.NewTestEnv(t)
|
||||||
|
|
||||||
|
resp := env.DoRequest(http.MethodGet, "/api/v1/diag", nil)
|
||||||
|
status, success, data, err := clusterDecodeAll(env, resp)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("decode response: %v", err)
|
||||||
|
}
|
||||||
|
if status != http.StatusOK {
|
||||||
|
t.Fatalf("expected status 200, got %d", status)
|
||||||
|
}
|
||||||
|
if !success {
|
||||||
|
t.Fatal("expected success=true")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify the response contains a "statistics" field (non-empty JSON object).
|
||||||
|
var raw map[string]json.RawMessage
|
||||||
|
if err := json.Unmarshal(data, &raw); err != nil {
|
||||||
|
t.Fatalf("unmarshal diag top-level: %v", err)
|
||||||
|
}
|
||||||
|
if _, ok := raw["statistics"]; !ok {
|
||||||
|
t.Fatal("diag response missing \"statistics\" field")
|
||||||
|
}
|
||||||
|
}
|
||||||
202
cmd/server/integration_e2e_test.go
Normal file
202
cmd/server/integration_e2e_test.go
Normal file
@@ -0,0 +1,202 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"gcy_hpc_server/internal/testutil/testenv"
|
||||||
|
)
|
||||||
|
|
||||||
|
// e2eResponse mirrors the unified API response structure.
|
||||||
|
type e2eResponse struct {
|
||||||
|
Success bool `json:"success"`
|
||||||
|
Data json.RawMessage `json:"data,omitempty"`
|
||||||
|
Error string `json:"error,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// e2eTaskCreatedData mirrors the POST /api/v1/tasks response data.
|
||||||
|
type e2eTaskCreatedData struct {
|
||||||
|
ID int64 `json:"id"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// e2eTaskItem mirrors a single task in the list response.
|
||||||
|
type e2eTaskItem struct {
|
||||||
|
ID int64 `json:"id"`
|
||||||
|
TaskName string `json:"task_name"`
|
||||||
|
Status string `json:"status"`
|
||||||
|
WorkDir string `json:"work_dir"`
|
||||||
|
ErrorMessage string `json:"error_message"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// e2eTaskListData mirrors the list endpoint response data.
|
||||||
|
type e2eTaskListData struct {
|
||||||
|
Items []e2eTaskItem `json:"items"`
|
||||||
|
Total int64 `json:"total"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// e2eSendRequest sends an HTTP request via the test env and returns the response.
|
||||||
|
func e2eSendRequest(env *testenv.TestEnv, method, path string, body string) *http.Response {
|
||||||
|
var r io.Reader
|
||||||
|
if body != "" {
|
||||||
|
r = strings.NewReader(body)
|
||||||
|
}
|
||||||
|
return env.DoRequest(method, path, r)
|
||||||
|
}
|
||||||
|
|
||||||
|
// e2eParseResponse decodes an HTTP response into e2eResponse.
|
||||||
|
func e2eParseResponse(resp *http.Response) (int, e2eResponse) {
|
||||||
|
b, err := io.ReadAll(resp.Body)
|
||||||
|
if err != nil {
|
||||||
|
panic(fmt.Sprintf("e2eParseResponse read: %v", err))
|
||||||
|
}
|
||||||
|
resp.Body.Close()
|
||||||
|
|
||||||
|
var result e2eResponse
|
||||||
|
if err := json.Unmarshal(b, &result); err != nil {
|
||||||
|
panic(fmt.Sprintf("e2eParseResponse unmarshal: %v (body: %s)", err, string(b)))
|
||||||
|
}
|
||||||
|
return resp.StatusCode, result
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestIntegration_E2E_CompleteWorkflow verifies the full lifecycle:
|
||||||
|
// create app → upload file → submit task → queued → running → completed.
|
||||||
|
func TestIntegration_E2E_CompleteWorkflow(t *testing.T) {
|
||||||
|
t.Log("========== E2E 全链路测试开始 ==========")
|
||||||
|
t.Log("")
|
||||||
|
env := testenv.NewTestEnv(t)
|
||||||
|
t.Log("✓ 测试环境创建完成 (SQLite + MockSlurm + MockMinIO + Router + Poller)")
|
||||||
|
t.Log("")
|
||||||
|
|
||||||
|
// Step 1: Create Application with script template and parameters.
|
||||||
|
t.Log("【步骤 1】创建应用")
|
||||||
|
appID, err := env.CreateApp("e2e-app", "#!/bin/bash\necho {{.np}}",
|
||||||
|
json.RawMessage(`[{"name":"np","type":"string","default":"1"}]`))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("step 1 create app: %v", err)
|
||||||
|
}
|
||||||
|
t.Logf(" → 应用创建成功, appID=%d, 脚本模板='#!/bin/bash echo {{.np}}', 参数=[np]", appID)
|
||||||
|
t.Log("")
|
||||||
|
|
||||||
|
// Step 2: Upload input file.
|
||||||
|
t.Log("【步骤 2】上传输入文件")
|
||||||
|
fileID, _ := env.UploadTestData("input.txt", []byte("test input data"))
|
||||||
|
t.Logf(" → 文件上传成功, fileID=%d, 内容='test input data' (存入 MockMinIO + SQLite)", fileID)
|
||||||
|
t.Log("")
|
||||||
|
|
||||||
|
// Step 3: Submit Task via API.
|
||||||
|
t.Log("【步骤 3】通过 HTTP API 提交任务")
|
||||||
|
body := fmt.Sprintf(
|
||||||
|
`{"app_id": %d, "task_name": "e2e-task", "values": {"np": "4"}, "file_ids": [%d]}`,
|
||||||
|
appID, fileID,
|
||||||
|
)
|
||||||
|
t.Logf(" → POST /api/v1/tasks body=%s", body)
|
||||||
|
resp := e2eSendRequest(env, http.MethodPost, "/api/v1/tasks", body)
|
||||||
|
status, result := e2eParseResponse(resp)
|
||||||
|
if status != http.StatusCreated {
|
||||||
|
t.Fatalf("step 3 submit task: status=%d, success=%v, error=%q", status, result.Success, result.Error)
|
||||||
|
}
|
||||||
|
|
||||||
|
var created e2eTaskCreatedData
|
||||||
|
if err := json.Unmarshal(result.Data, &created); err != nil {
|
||||||
|
t.Fatalf("step 3 parse task id: %v", err)
|
||||||
|
}
|
||||||
|
taskID := created.ID
|
||||||
|
if taskID <= 0 {
|
||||||
|
t.Fatalf("step 3: expected positive task id, got %d", taskID)
|
||||||
|
}
|
||||||
|
t.Logf(" → HTTP 201 Created, taskID=%d", taskID)
|
||||||
|
t.Log("")
|
||||||
|
|
||||||
|
// Step 4: Wait for queued status.
|
||||||
|
t.Log("【步骤 4】等待 TaskProcessor 异步提交到 MockSlurm")
|
||||||
|
t.Log(" → 后台流程: submitted → preparing → downloading → ready → queued")
|
||||||
|
if err := env.WaitForTaskStatus(taskID, "queued", 5*time.Second); err != nil {
|
||||||
|
taskStatus, _ := e2eFetchTaskStatus(env, taskID)
|
||||||
|
t.Fatalf("step 4 wait for queued: %v (current status via API: %q)", err, taskStatus)
|
||||||
|
}
|
||||||
|
t.Logf(" → 任务状态变为 'queued' (TaskProcessor 已提交到 Slurm)")
|
||||||
|
t.Log("")
|
||||||
|
|
||||||
|
// Step 5: Get slurmJobID.
|
||||||
|
t.Log("【步骤 5】查询数据库获取 Slurm Job ID")
|
||||||
|
slurmJobID, err := env.GetTaskSlurmJobID(taskID)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("step 5 get slurm job id: %v", err)
|
||||||
|
}
|
||||||
|
t.Logf(" → slurmJobID=%d (MockSlurm 中的作业号)", slurmJobID)
|
||||||
|
t.Log("")
|
||||||
|
|
||||||
|
// Step 6: Transition to RUNNING.
|
||||||
|
t.Log("【步骤 6】模拟 Slurm: 作业开始运行")
|
||||||
|
t.Logf(" → MockSlurm.SetJobState(%d, 'RUNNING')", slurmJobID)
|
||||||
|
env.MockSlurm.SetJobState(slurmJobID, "RUNNING")
|
||||||
|
t.Logf(" → MakeTaskStale(%d) — 绕过 30s 等待,让 poller 立即刷新", taskID)
|
||||||
|
if err := env.MakeTaskStale(taskID); err != nil {
|
||||||
|
t.Fatalf("step 6 make task stale: %v", err)
|
||||||
|
}
|
||||||
|
if err := env.WaitForTaskStatus(taskID, "running", 5*time.Second); err != nil {
|
||||||
|
taskStatus, _ := e2eFetchTaskStatus(env, taskID)
|
||||||
|
t.Fatalf("step 6 wait for running: %v (current status via API: %q)", err, taskStatus)
|
||||||
|
}
|
||||||
|
t.Logf(" → 任务状态变为 'running'")
|
||||||
|
t.Log("")
|
||||||
|
|
||||||
|
// Step 7: Transition to COMPLETED — job evicted from activeJobs to historyJobs.
|
||||||
|
t.Log("【步骤 7】模拟 Slurm: 作业运行完成")
|
||||||
|
t.Logf(" → MockSlurm.SetJobState(%d, 'COMPLETED') — 作业从 activeJobs 淘汰到 historyJobs", slurmJobID)
|
||||||
|
env.MockSlurm.SetJobState(slurmJobID, "COMPLETED")
|
||||||
|
t.Log(" → MakeTaskStale + WaitForTaskStatus...")
|
||||||
|
if err := env.MakeTaskStale(taskID); err != nil {
|
||||||
|
t.Fatalf("step 7 make task stale: %v", err)
|
||||||
|
}
|
||||||
|
if err := env.WaitForTaskStatus(taskID, "completed", 5*time.Second); err != nil {
|
||||||
|
taskStatus, _ := e2eFetchTaskStatus(env, taskID)
|
||||||
|
t.Fatalf("step 7 wait for completed: %v (current status via API: %q)", err, taskStatus)
|
||||||
|
}
|
||||||
|
t.Logf(" → 任务状态变为 'completed' (通过 SlurmDB 历史回退路径获取)")
|
||||||
|
t.Log("")
|
||||||
|
|
||||||
|
// Step 8: Verify final state via GET /api/v1/tasks.
|
||||||
|
t.Log("【步骤 8】通过 HTTP API 验证最终状态")
|
||||||
|
finalStatus, finalItem := e2eFetchTaskStatus(env, taskID)
|
||||||
|
if finalStatus != "completed" {
|
||||||
|
t.Fatalf("step 8: expected status completed, got %q (error: %q)", finalStatus, finalItem.ErrorMessage)
|
||||||
|
}
|
||||||
|
t.Logf(" → GET /api/v1/tasks 返回 status='completed'")
|
||||||
|
t.Logf(" → task_name='%s', work_dir='%s'", finalItem.TaskName, finalItem.WorkDir)
|
||||||
|
t.Logf(" → MockSlurm activeJobs=%d, historyJobs=%d",
|
||||||
|
len(env.MockSlurm.GetAllActiveJobs()), len(env.MockSlurm.GetAllHistoryJobs()))
|
||||||
|
t.Log("")
|
||||||
|
|
||||||
|
// Step 9: Verify WorkDir exists and contains the input file.
|
||||||
|
t.Log("【步骤 9】验证工作目录")
|
||||||
|
if finalItem.WorkDir == "" {
|
||||||
|
t.Fatal("step 9: expected non-empty work_dir")
|
||||||
|
}
|
||||||
|
t.Logf(" → work_dir='%s' (非空,TaskProcessor 已创建)", finalItem.WorkDir)
|
||||||
|
t.Log("")
|
||||||
|
t.Log("========== E2E 全链路测试通过 ✓ ==========")
|
||||||
|
}
|
||||||
|
|
||||||
|
// e2eFetchTaskStatus fetches a single task's status from the list API.
|
||||||
|
func e2eFetchTaskStatus(env *testenv.TestEnv, taskID int64) (string, e2eTaskItem) {
|
||||||
|
resp := e2eSendRequest(env, http.MethodGet, "/api/v1/tasks", "")
|
||||||
|
_, result := e2eParseResponse(resp)
|
||||||
|
|
||||||
|
var list e2eTaskListData
|
||||||
|
if err := json.Unmarshal(result.Data, &list); err != nil {
|
||||||
|
return "", e2eTaskItem{}
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, item := range list.Items {
|
||||||
|
if item.ID == taskID {
|
||||||
|
return item.Status, item
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return "", e2eTaskItem{}
|
||||||
|
}
|
||||||
170
cmd/server/integration_file_test.go
Normal file
170
cmd/server/integration_file_test.go
Normal file
@@ -0,0 +1,170 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"gcy_hpc_server/internal/model"
|
||||||
|
"gcy_hpc_server/internal/testutil/testenv"
|
||||||
|
)
|
||||||
|
|
||||||
|
// fileAPIResp mirrors server.APIResponse for file integration tests.
|
||||||
|
type fileAPIResp struct {
|
||||||
|
Success bool `json:"success"`
|
||||||
|
Data json.RawMessage `json:"data,omitempty"`
|
||||||
|
Error string `json:"error,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// fileDecode parses an HTTP response body into fileAPIResp.
|
||||||
|
func fileDecode(t *testing.T, body io.Reader) fileAPIResp {
|
||||||
|
t.Helper()
|
||||||
|
data, err := io.ReadAll(body)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("fileDecode: read body: %v", err)
|
||||||
|
}
|
||||||
|
var r fileAPIResp
|
||||||
|
if err := json.Unmarshal(data, &r); err != nil {
|
||||||
|
t.Fatalf("fileDecode: unmarshal: %v (body: %s)", err, string(data))
|
||||||
|
}
|
||||||
|
return r
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestIntegration_File_List(t *testing.T) {
|
||||||
|
env := testenv.NewTestEnv(t)
|
||||||
|
|
||||||
|
// Upload a file so the list is non-empty.
|
||||||
|
env.UploadTestData("list_test.txt", []byte("hello list"))
|
||||||
|
|
||||||
|
resp := env.DoRequest("GET", "/api/v1/files", nil)
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
t.Fatalf("expected 200, got %d", resp.StatusCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
r := fileDecode(t, resp.Body)
|
||||||
|
if !r.Success {
|
||||||
|
t.Fatalf("response not success: %s", r.Error)
|
||||||
|
}
|
||||||
|
|
||||||
|
var listResp model.ListFilesResponse
|
||||||
|
if err := json.Unmarshal(r.Data, &listResp); err != nil {
|
||||||
|
t.Fatalf("unmarshal list response: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(listResp.Files) == 0 {
|
||||||
|
t.Fatal("expected at least 1 file in list, got 0")
|
||||||
|
}
|
||||||
|
|
||||||
|
found := false
|
||||||
|
for _, f := range listResp.Files {
|
||||||
|
if f.Name == "list_test.txt" {
|
||||||
|
found = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !found {
|
||||||
|
t.Fatal("expected to find list_test.txt in file list")
|
||||||
|
}
|
||||||
|
|
||||||
|
if listResp.Total < 1 {
|
||||||
|
t.Fatalf("expected total >= 1, got %d", listResp.Total)
|
||||||
|
}
|
||||||
|
if listResp.Page < 1 {
|
||||||
|
t.Fatalf("expected page >= 1, got %d", listResp.Page)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestIntegration_File_Get(t *testing.T) {
|
||||||
|
env := testenv.NewTestEnv(t)
|
||||||
|
|
||||||
|
fileID, _ := env.UploadTestData("get_test.txt", []byte("hello get"))
|
||||||
|
|
||||||
|
resp := env.DoRequest("GET", fmt.Sprintf("/api/v1/files/%d", fileID), nil)
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
t.Fatalf("expected 200, got %d", resp.StatusCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
r := fileDecode(t, resp.Body)
|
||||||
|
if !r.Success {
|
||||||
|
t.Fatalf("response not success: %s", r.Error)
|
||||||
|
}
|
||||||
|
|
||||||
|
var fileResp model.FileResponse
|
||||||
|
if err := json.Unmarshal(r.Data, &fileResp); err != nil {
|
||||||
|
t.Fatalf("unmarshal file response: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if fileResp.ID != fileID {
|
||||||
|
t.Fatalf("expected file ID %d, got %d", fileID, fileResp.ID)
|
||||||
|
}
|
||||||
|
if fileResp.Name != "get_test.txt" {
|
||||||
|
t.Fatalf("expected name get_test.txt, got %s", fileResp.Name)
|
||||||
|
}
|
||||||
|
if fileResp.Size != int64(len("hello get")) {
|
||||||
|
t.Fatalf("expected size %d, got %d", len("hello get"), fileResp.Size)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestIntegration_File_Download(t *testing.T) {
|
||||||
|
env := testenv.NewTestEnv(t)
|
||||||
|
|
||||||
|
content := []byte("hello world")
|
||||||
|
fileID, _ := env.UploadTestData("download_test.txt", content)
|
||||||
|
|
||||||
|
resp := env.DoRequest("GET", fmt.Sprintf("/api/v1/files/%d/download", fileID), nil)
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
t.Fatalf("expected 200, got %d", resp.StatusCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
body, err := io.ReadAll(resp.Body)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("read download body: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if string(body) != string(content) {
|
||||||
|
t.Fatalf("downloaded content mismatch: got %q, want %q", string(body), string(content))
|
||||||
|
}
|
||||||
|
|
||||||
|
contentType := resp.Header.Get("Content-Type")
|
||||||
|
if contentType == "" {
|
||||||
|
t.Fatal("expected Content-Type header to be set")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestIntegration_File_Delete(t *testing.T) {
|
||||||
|
env := testenv.NewTestEnv(t)
|
||||||
|
|
||||||
|
fileID, _ := env.UploadTestData("delete_test.txt", []byte("hello delete"))
|
||||||
|
|
||||||
|
// Delete the file.
|
||||||
|
resp := env.DoRequest("DELETE", fmt.Sprintf("/api/v1/files/%d", fileID), nil)
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
t.Fatalf("expected 200, got %d", resp.StatusCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
r := fileDecode(t, resp.Body)
|
||||||
|
if !r.Success {
|
||||||
|
t.Fatalf("delete response not success: %s", r.Error)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify the file is gone — GET should return 500 (internal error) or 404.
|
||||||
|
getResp := env.DoRequest("GET", fmt.Sprintf("/api/v1/files/%d", fileID), nil)
|
||||||
|
defer getResp.Body.Close()
|
||||||
|
|
||||||
|
if getResp.StatusCode == http.StatusOK {
|
||||||
|
gr := fileDecode(t, getResp.Body)
|
||||||
|
if gr.Success {
|
||||||
|
t.Fatal("expected file to be deleted, but GET still returns success")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
193
cmd/server/integration_folder_test.go
Normal file
193
cmd/server/integration_folder_test.go
Normal file
@@ -0,0 +1,193 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"gcy_hpc_server/internal/testutil/testenv"
|
||||||
|
)
|
||||||
|
|
||||||
|
// folderData mirrors the FolderResponse DTO returned by the API.
|
||||||
|
type folderData struct {
|
||||||
|
ID int64 `json:"id"`
|
||||||
|
Name string `json:"name"`
|
||||||
|
ParentID *int64 `json:"parent_id,omitempty"`
|
||||||
|
Path string `json:"path"`
|
||||||
|
FileCount int64 `json:"file_count"`
|
||||||
|
SubFolderCount int64 `json:"subfolder_count"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// folderMessageData mirrors the delete endpoint response data structure.
|
||||||
|
type folderMessageData struct {
|
||||||
|
Message string `json:"message"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// folderDoRequest marshals body and calls env.DoRequest.
|
||||||
|
func folderDoRequest(env *testenv.TestEnv, method, path string, body interface{}) *http.Response {
|
||||||
|
var r io.Reader
|
||||||
|
if body != nil {
|
||||||
|
b, err := json.Marshal(body)
|
||||||
|
if err != nil {
|
||||||
|
panic(fmt.Sprintf("folderDoRequest marshal: %v", err))
|
||||||
|
}
|
||||||
|
r = bytes.NewReader(b)
|
||||||
|
}
|
||||||
|
return env.DoRequest(method, path, r)
|
||||||
|
}
|
||||||
|
|
||||||
|
// folderDecodeAll decodes the response and returns status, success, and raw data.
|
||||||
|
func folderDecodeAll(env *testenv.TestEnv, resp *http.Response) (statusCode int, success bool, data json.RawMessage, err error) {
|
||||||
|
statusCode = resp.StatusCode
|
||||||
|
success, data, err = env.DecodeResponse(resp)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// folderSeed creates a folder via HTTP and returns its ID.
|
||||||
|
func folderSeed(env *testenv.TestEnv, name string) int64 {
|
||||||
|
body := map[string]interface{}{"name": name}
|
||||||
|
resp := folderDoRequest(env, http.MethodPost, "/api/v1/files/folders", body)
|
||||||
|
status, success, data, err := folderDecodeAll(env, resp)
|
||||||
|
if err != nil {
|
||||||
|
panic(fmt.Sprintf("folderSeed decode: %v", err))
|
||||||
|
}
|
||||||
|
if status != http.StatusCreated {
|
||||||
|
panic(fmt.Sprintf("folderSeed: expected 201, got %d", status))
|
||||||
|
}
|
||||||
|
if !success {
|
||||||
|
panic("folderSeed: expected success=true")
|
||||||
|
}
|
||||||
|
var f folderData
|
||||||
|
if err := json.Unmarshal(data, &f); err != nil {
|
||||||
|
panic(fmt.Sprintf("folderSeed unmarshal: %v", err))
|
||||||
|
}
|
||||||
|
return f.ID
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestIntegration_Folder_Create verifies POST /api/v1/files/folders creates a folder.
|
||||||
|
func TestIntegration_Folder_Create(t *testing.T) {
|
||||||
|
env := testenv.NewTestEnv(t)
|
||||||
|
|
||||||
|
body := map[string]interface{}{"name": "test-folder-create"}
|
||||||
|
resp := folderDoRequest(env, http.MethodPost, "/api/v1/files/folders", body)
|
||||||
|
status, success, data, err := folderDecodeAll(env, resp)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("decode response: %v", err)
|
||||||
|
}
|
||||||
|
if status != http.StatusCreated {
|
||||||
|
t.Fatalf("expected status 201, got %d", status)
|
||||||
|
}
|
||||||
|
if !success {
|
||||||
|
t.Fatal("expected success=true")
|
||||||
|
}
|
||||||
|
|
||||||
|
var created folderData
|
||||||
|
if err := json.Unmarshal(data, &created); err != nil {
|
||||||
|
t.Fatalf("unmarshal created data: %v", err)
|
||||||
|
}
|
||||||
|
if created.ID <= 0 {
|
||||||
|
t.Fatalf("expected positive id, got %d", created.ID)
|
||||||
|
}
|
||||||
|
if created.Name != "test-folder-create" {
|
||||||
|
t.Fatalf("expected name=test-folder-create, got %s", created.Name)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestIntegration_Folder_List verifies GET /api/v1/files/folders returns a list.
|
||||||
|
func TestIntegration_Folder_List(t *testing.T) {
|
||||||
|
env := testenv.NewTestEnv(t)
|
||||||
|
|
||||||
|
// Seed two folders.
|
||||||
|
folderSeed(env, "list-folder-1")
|
||||||
|
folderSeed(env, "list-folder-2")
|
||||||
|
|
||||||
|
resp := env.DoRequest(http.MethodGet, "/api/v1/files/folders", nil)
|
||||||
|
status, success, data, err := folderDecodeAll(env, resp)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("decode response: %v", err)
|
||||||
|
}
|
||||||
|
if status != http.StatusOK {
|
||||||
|
t.Fatalf("expected status 200, got %d", status)
|
||||||
|
}
|
||||||
|
if !success {
|
||||||
|
t.Fatal("expected success=true")
|
||||||
|
}
|
||||||
|
|
||||||
|
var folders []folderData
|
||||||
|
if err := json.Unmarshal(data, &folders); err != nil {
|
||||||
|
t.Fatalf("unmarshal list data: %v", err)
|
||||||
|
}
|
||||||
|
if len(folders) < 2 {
|
||||||
|
t.Fatalf("expected at least 2 folders, got %d", len(folders))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestIntegration_Folder_Get verifies GET /api/v1/files/folders/:id returns folder details.
|
||||||
|
func TestIntegration_Folder_Get(t *testing.T) {
|
||||||
|
env := testenv.NewTestEnv(t)
|
||||||
|
id := folderSeed(env, "test-folder-get")
|
||||||
|
|
||||||
|
path := fmt.Sprintf("/api/v1/files/folders/%d", id)
|
||||||
|
resp := env.DoRequest(http.MethodGet, path, nil)
|
||||||
|
status, success, data, err := folderDecodeAll(env, resp)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("decode response: %v", err)
|
||||||
|
}
|
||||||
|
if status != http.StatusOK {
|
||||||
|
t.Fatalf("expected status 200, got %d", status)
|
||||||
|
}
|
||||||
|
if !success {
|
||||||
|
t.Fatal("expected success=true")
|
||||||
|
}
|
||||||
|
|
||||||
|
var f folderData
|
||||||
|
if err := json.Unmarshal(data, &f); err != nil {
|
||||||
|
t.Fatalf("unmarshal folder data: %v", err)
|
||||||
|
}
|
||||||
|
if f.ID != id {
|
||||||
|
t.Fatalf("expected id=%d, got %d", id, f.ID)
|
||||||
|
}
|
||||||
|
if f.Name != "test-folder-get" {
|
||||||
|
t.Fatalf("expected name=test-folder-get, got %s", f.Name)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestIntegration_Folder_Delete verifies DELETE /api/v1/files/folders/:id removes a folder.
|
||||||
|
func TestIntegration_Folder_Delete(t *testing.T) {
|
||||||
|
env := testenv.NewTestEnv(t)
|
||||||
|
id := folderSeed(env, "test-folder-delete")
|
||||||
|
|
||||||
|
path := fmt.Sprintf("/api/v1/files/folders/%d", id)
|
||||||
|
resp := env.DoRequest(http.MethodDelete, path, nil)
|
||||||
|
status, success, data, err := folderDecodeAll(env, resp)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("decode response: %v", err)
|
||||||
|
}
|
||||||
|
if status != http.StatusOK {
|
||||||
|
t.Fatalf("expected status 200, got %d", status)
|
||||||
|
}
|
||||||
|
if !success {
|
||||||
|
t.Fatal("expected success=true")
|
||||||
|
}
|
||||||
|
|
||||||
|
var msg folderMessageData
|
||||||
|
if err := json.Unmarshal(data, &msg); err != nil {
|
||||||
|
t.Fatalf("unmarshal message data: %v", err)
|
||||||
|
}
|
||||||
|
if msg.Message != "folder deleted" {
|
||||||
|
t.Fatalf("expected message 'folder deleted', got %q", msg.Message)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify it's gone via GET → 404.
|
||||||
|
getResp := env.DoRequest(http.MethodGet, path, nil)
|
||||||
|
getStatus, getSuccess, _, _ := folderDecodeAll(env, getResp)
|
||||||
|
if getStatus != http.StatusNotFound {
|
||||||
|
t.Fatalf("expected status 404 after delete, got %d", getStatus)
|
||||||
|
}
|
||||||
|
if getSuccess {
|
||||||
|
t.Fatal("expected success=false after delete")
|
||||||
|
}
|
||||||
|
}
|
||||||
222
cmd/server/integration_job_test.go
Normal file
222
cmd/server/integration_job_test.go
Normal file
@@ -0,0 +1,222 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"gcy_hpc_server/internal/testutil/testenv"
|
||||||
|
)
|
||||||
|
|
||||||
|
// jobItemData mirrors the JobResponse DTO for a single job.
|
||||||
|
type jobItemData struct {
|
||||||
|
JobID int32 `json:"job_id"`
|
||||||
|
Name string `json:"name"`
|
||||||
|
State []string `json:"job_state"`
|
||||||
|
Partition string `json:"partition"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// jobListData mirrors the paginated JobListResponse DTO.
|
||||||
|
type jobListData struct {
|
||||||
|
Jobs []jobItemData `json:"jobs"`
|
||||||
|
Total int `json:"total"`
|
||||||
|
Page int `json:"page"`
|
||||||
|
PageSize int `json:"page_size"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// jobCancelData mirrors the cancel response message.
|
||||||
|
type jobCancelData struct {
|
||||||
|
Message string `json:"message"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// jobDecodeAll decodes the response and returns status, success, and raw data.
|
||||||
|
func jobDecodeAll(env *testenv.TestEnv, resp *http.Response) (statusCode int, success bool, data json.RawMessage, err error) {
|
||||||
|
statusCode = resp.StatusCode
|
||||||
|
success, data, err = env.DecodeResponse(resp)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// jobSubmitBody builds a JSON body for job submit requests.
|
||||||
|
func jobSubmitBody(script string) *bytes.Reader {
|
||||||
|
body, _ := json.Marshal(map[string]string{"script": script})
|
||||||
|
return bytes.NewReader(body)
|
||||||
|
}
|
||||||
|
|
||||||
|
// jobSubmitViaAPI submits a job and returns the job ID. Fatals on failure.
|
||||||
|
func jobSubmitViaAPI(t *testing.T, env *testenv.TestEnv, script string) int32 {
|
||||||
|
t.Helper()
|
||||||
|
resp := env.DoRequest(http.MethodPost, "/api/v1/jobs/submit", jobSubmitBody(script))
|
||||||
|
status, success, data, err := jobDecodeAll(env, resp)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("submit job decode: %v", err)
|
||||||
|
}
|
||||||
|
if status != http.StatusCreated {
|
||||||
|
t.Fatalf("expected status 201, got %d", status)
|
||||||
|
}
|
||||||
|
if !success {
|
||||||
|
t.Fatal("expected success=true on submit")
|
||||||
|
}
|
||||||
|
|
||||||
|
var job jobItemData
|
||||||
|
if err := json.Unmarshal(data, &job); err != nil {
|
||||||
|
t.Fatalf("unmarshal submitted job: %v", err)
|
||||||
|
}
|
||||||
|
return job.JobID
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestIntegration_Jobs_Submit verifies POST /api/v1/jobs/submit creates a new job.
|
||||||
|
func TestIntegration_Jobs_Submit(t *testing.T) {
|
||||||
|
env := testenv.NewTestEnv(t)
|
||||||
|
|
||||||
|
script := "#!/bin/bash\necho hello"
|
||||||
|
resp := env.DoRequest(http.MethodPost, "/api/v1/jobs/submit", jobSubmitBody(script))
|
||||||
|
status, success, data, err := jobDecodeAll(env, resp)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("decode response: %v", err)
|
||||||
|
}
|
||||||
|
if status != http.StatusCreated {
|
||||||
|
t.Fatalf("expected status 201, got %d", status)
|
||||||
|
}
|
||||||
|
if !success {
|
||||||
|
t.Fatal("expected success=true")
|
||||||
|
}
|
||||||
|
|
||||||
|
var job jobItemData
|
||||||
|
if err := json.Unmarshal(data, &job); err != nil {
|
||||||
|
t.Fatalf("unmarshal job: %v", err)
|
||||||
|
}
|
||||||
|
if job.JobID <= 0 {
|
||||||
|
t.Fatalf("expected positive job_id, got %d", job.JobID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestIntegration_Jobs_List verifies GET /api/v1/jobs returns a paginated job list.
|
||||||
|
func TestIntegration_Jobs_List(t *testing.T) {
|
||||||
|
env := testenv.NewTestEnv(t)
|
||||||
|
|
||||||
|
// Submit a job so the list is not empty.
|
||||||
|
jobSubmitViaAPI(t, env, "#!/bin/bash\necho list-test")
|
||||||
|
|
||||||
|
resp := env.DoRequest(http.MethodGet, "/api/v1/jobs", nil)
|
||||||
|
status, success, data, err := jobDecodeAll(env, resp)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("decode response: %v", err)
|
||||||
|
}
|
||||||
|
if status != http.StatusOK {
|
||||||
|
t.Fatalf("expected status 200, got %d", status)
|
||||||
|
}
|
||||||
|
if !success {
|
||||||
|
t.Fatal("expected success=true")
|
||||||
|
}
|
||||||
|
|
||||||
|
var list jobListData
|
||||||
|
if err := json.Unmarshal(data, &list); err != nil {
|
||||||
|
t.Fatalf("unmarshal job list: %v", err)
|
||||||
|
}
|
||||||
|
if list.Total < 1 {
|
||||||
|
t.Fatalf("expected at least 1 job, got total=%d", list.Total)
|
||||||
|
}
|
||||||
|
if list.Page != 1 {
|
||||||
|
t.Fatalf("expected page=1, got %d", list.Page)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestIntegration_Jobs_Get verifies GET /api/v1/jobs/:id returns a single job.
|
||||||
|
func TestIntegration_Jobs_Get(t *testing.T) {
|
||||||
|
env := testenv.NewTestEnv(t)
|
||||||
|
|
||||||
|
jobID := jobSubmitViaAPI(t, env, "#!/bin/bash\necho get-test")
|
||||||
|
|
||||||
|
path := fmt.Sprintf("/api/v1/jobs/%d", jobID)
|
||||||
|
resp := env.DoRequest(http.MethodGet, path, nil)
|
||||||
|
status, success, data, err := jobDecodeAll(env, resp)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("decode response: %v", err)
|
||||||
|
}
|
||||||
|
if status != http.StatusOK {
|
||||||
|
t.Fatalf("expected status 200, got %d", status)
|
||||||
|
}
|
||||||
|
if !success {
|
||||||
|
t.Fatal("expected success=true")
|
||||||
|
}
|
||||||
|
|
||||||
|
var job jobItemData
|
||||||
|
if err := json.Unmarshal(data, &job); err != nil {
|
||||||
|
t.Fatalf("unmarshal job: %v", err)
|
||||||
|
}
|
||||||
|
if job.JobID != jobID {
|
||||||
|
t.Fatalf("expected job_id=%d, got %d", jobID, job.JobID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestIntegration_Jobs_Cancel verifies DELETE /api/v1/jobs/:id cancels a job.
|
||||||
|
func TestIntegration_Jobs_Cancel(t *testing.T) {
|
||||||
|
env := testenv.NewTestEnv(t)
|
||||||
|
|
||||||
|
jobID := jobSubmitViaAPI(t, env, "#!/bin/bash\necho cancel-test")
|
||||||
|
|
||||||
|
path := fmt.Sprintf("/api/v1/jobs/%d", jobID)
|
||||||
|
resp := env.DoRequest(http.MethodDelete, path, nil)
|
||||||
|
status, success, data, err := jobDecodeAll(env, resp)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("decode response: %v", err)
|
||||||
|
}
|
||||||
|
if status != http.StatusOK {
|
||||||
|
t.Fatalf("expected status 200, got %d", status)
|
||||||
|
}
|
||||||
|
if !success {
|
||||||
|
t.Fatal("expected success=true")
|
||||||
|
}
|
||||||
|
|
||||||
|
var msg jobCancelData
|
||||||
|
if err := json.Unmarshal(data, &msg); err != nil {
|
||||||
|
t.Fatalf("unmarshal cancel response: %v", err)
|
||||||
|
}
|
||||||
|
if msg.Message == "" {
|
||||||
|
t.Fatal("expected non-empty cancel message")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestIntegration_Jobs_History verifies GET /api/v1/jobs/history returns historical jobs.
|
||||||
|
func TestIntegration_Jobs_History(t *testing.T) {
|
||||||
|
env := testenv.NewTestEnv(t)
|
||||||
|
|
||||||
|
// Submit and cancel a job so it moves from active to history queue.
|
||||||
|
jobID := jobSubmitViaAPI(t, env, "#!/bin/bash\necho history-test")
|
||||||
|
path := fmt.Sprintf("/api/v1/jobs/%d", jobID)
|
||||||
|
env.DoRequest(http.MethodDelete, path, nil)
|
||||||
|
|
||||||
|
resp := env.DoRequest(http.MethodGet, "/api/v1/jobs/history", nil)
|
||||||
|
status, success, data, err := jobDecodeAll(env, resp)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("decode response: %v", err)
|
||||||
|
}
|
||||||
|
if status != http.StatusOK {
|
||||||
|
t.Fatalf("expected status 200, got %d", status)
|
||||||
|
}
|
||||||
|
if !success {
|
||||||
|
t.Fatal("expected success=true")
|
||||||
|
}
|
||||||
|
|
||||||
|
var list jobListData
|
||||||
|
if err := json.Unmarshal(data, &list); err != nil {
|
||||||
|
t.Fatalf("unmarshal history: %v", err)
|
||||||
|
}
|
||||||
|
if list.Total < 1 {
|
||||||
|
t.Fatalf("expected at least 1 history job, got total=%d", list.Total)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify the cancelled job appears in history.
|
||||||
|
found := false
|
||||||
|
for _, j := range list.Jobs {
|
||||||
|
if j.JobID == jobID {
|
||||||
|
found = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !found {
|
||||||
|
t.Fatalf("cancelled job %d not found in history", jobID)
|
||||||
|
}
|
||||||
|
}
|
||||||
261
cmd/server/integration_task_test.go
Normal file
261
cmd/server/integration_task_test.go
Normal file
@@ -0,0 +1,261 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"gcy_hpc_server/internal/testutil/testenv"
|
||||||
|
)
|
||||||
|
|
||||||
|
// taskAPIResponse decodes the unified API response envelope.
|
||||||
|
type taskAPIResponse struct {
|
||||||
|
Success bool `json:"success"`
|
||||||
|
Data json.RawMessage `json:"data,omitempty"`
|
||||||
|
Error string `json:"error,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// taskCreateData is the data payload from a successful task creation.
|
||||||
|
type taskCreateData struct {
|
||||||
|
ID int64 `json:"id"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// taskListData is the data payload from listing tasks.
|
||||||
|
type taskListData struct {
|
||||||
|
Items []taskListItem `json:"items"`
|
||||||
|
Total int64 `json:"total"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type taskListItem struct {
|
||||||
|
ID int64 `json:"id"`
|
||||||
|
TaskName string `json:"task_name"`
|
||||||
|
AppID int64 `json:"app_id"`
|
||||||
|
Status string `json:"status"`
|
||||||
|
SlurmJobID *int32 `json:"slurm_job_id"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// taskSendReq sends an HTTP request via the test env and returns the response.
|
||||||
|
func taskSendReq(t *testing.T, env *testenv.TestEnv, method, path string, body string) *http.Response {
|
||||||
|
t.Helper()
|
||||||
|
var r io.Reader
|
||||||
|
if body != "" {
|
||||||
|
r = strings.NewReader(body)
|
||||||
|
}
|
||||||
|
resp := env.DoRequest(method, path, r)
|
||||||
|
return resp
|
||||||
|
}
|
||||||
|
|
||||||
|
// taskParseResp decodes the response body into a taskAPIResponse.
|
||||||
|
func taskParseResp(t *testing.T, resp *http.Response) taskAPIResponse {
|
||||||
|
t.Helper()
|
||||||
|
b, err := io.ReadAll(resp.Body)
|
||||||
|
resp.Body.Close()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("read response body: %v", err)
|
||||||
|
}
|
||||||
|
var result taskAPIResponse
|
||||||
|
if err := json.Unmarshal(b, &result); err != nil {
|
||||||
|
t.Fatalf("unmarshal response: %v (body: %s)", err, string(b))
|
||||||
|
}
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
// taskCreateViaAPI creates a task via the HTTP API and returns the task ID.
|
||||||
|
func taskCreateViaAPI(t *testing.T, env *testenv.TestEnv, appID int64, taskName string) int64 {
|
||||||
|
t.Helper()
|
||||||
|
body := fmt.Sprintf(`{"app_id":%d,"task_name":"%s","values":{},"file_ids":[]}`, appID, taskName)
|
||||||
|
resp := taskSendReq(t, env, http.MethodPost, "/api/v1/tasks", body)
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusCreated {
|
||||||
|
b, _ := io.ReadAll(resp.Body)
|
||||||
|
t.Fatalf("expected 201, got %d: %s", resp.StatusCode, string(b))
|
||||||
|
}
|
||||||
|
|
||||||
|
parsed := taskParseResp(t, resp)
|
||||||
|
if !parsed.Success {
|
||||||
|
t.Fatalf("expected success=true, got error: %s", parsed.Error)
|
||||||
|
}
|
||||||
|
|
||||||
|
var data taskCreateData
|
||||||
|
if err := json.Unmarshal(parsed.Data, &data); err != nil {
|
||||||
|
t.Fatalf("unmarshal create data: %v", err)
|
||||||
|
}
|
||||||
|
if data.ID == 0 {
|
||||||
|
t.Fatal("expected non-zero task ID")
|
||||||
|
}
|
||||||
|
return data.ID
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---------- Tests ----------
|
||||||
|
|
||||||
|
func TestIntegration_Task_Create(t *testing.T) {
|
||||||
|
env := testenv.NewTestEnv(t)
|
||||||
|
|
||||||
|
// Create application
|
||||||
|
appID, err := env.CreateApp("task-create-app", "#!/bin/bash\necho hello", nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("create app: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create task via API
|
||||||
|
taskID := taskCreateViaAPI(t, env, appID, "test-task-create")
|
||||||
|
|
||||||
|
// Verify the task ID is positive
|
||||||
|
if taskID <= 0 {
|
||||||
|
t.Fatalf("expected positive task ID, got %d", taskID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Wait briefly for async processing, then verify task exists in DB via list
|
||||||
|
time.Sleep(200 * time.Millisecond)
|
||||||
|
|
||||||
|
resp := taskSendReq(t, env, http.MethodGet, "/api/v1/tasks", "")
|
||||||
|
defer resp.Body.Close()
|
||||||
|
parsed := taskParseResp(t, resp)
|
||||||
|
|
||||||
|
var listData taskListData
|
||||||
|
if err := json.Unmarshal(parsed.Data, &listData); err != nil {
|
||||||
|
t.Fatalf("unmarshal list data: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
found := false
|
||||||
|
for _, item := range listData.Items {
|
||||||
|
if item.ID == taskID {
|
||||||
|
found = true
|
||||||
|
if item.TaskName != "test-task-create" {
|
||||||
|
t.Errorf("expected task_name=test-task-create, got %s", item.TaskName)
|
||||||
|
}
|
||||||
|
if item.AppID != appID {
|
||||||
|
t.Errorf("expected app_id=%d, got %d", appID, item.AppID)
|
||||||
|
}
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !found {
|
||||||
|
t.Fatalf("task %d not found in list", taskID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestIntegration_Task_List(t *testing.T) {
|
||||||
|
env := testenv.NewTestEnv(t)
|
||||||
|
|
||||||
|
// Create application
|
||||||
|
appID, err := env.CreateApp("task-list-app", "#!/bin/bash\necho hello", nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("create app: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create 3 tasks
|
||||||
|
taskCreateViaAPI(t, env, appID, "list-task-1")
|
||||||
|
taskCreateViaAPI(t, env, appID, "list-task-2")
|
||||||
|
taskCreateViaAPI(t, env, appID, "list-task-3")
|
||||||
|
|
||||||
|
// Allow async processing
|
||||||
|
time.Sleep(200 * time.Millisecond)
|
||||||
|
|
||||||
|
// List tasks
|
||||||
|
resp := taskSendReq(t, env, http.MethodGet, "/api/v1/tasks", "")
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
b, _ := io.ReadAll(resp.Body)
|
||||||
|
t.Fatalf("expected 200, got %d: %s", resp.StatusCode, string(b))
|
||||||
|
}
|
||||||
|
|
||||||
|
parsed := taskParseResp(t, resp)
|
||||||
|
if !parsed.Success {
|
||||||
|
t.Fatalf("expected success, got error: %s", parsed.Error)
|
||||||
|
}
|
||||||
|
|
||||||
|
var listData taskListData
|
||||||
|
if err := json.Unmarshal(parsed.Data, &listData); err != nil {
|
||||||
|
t.Fatalf("unmarshal list data: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if listData.Total < 3 {
|
||||||
|
t.Fatalf("expected at least 3 tasks, got %d", listData.Total)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify each created task has required fields
|
||||||
|
for _, item := range listData.Items {
|
||||||
|
if item.ID == 0 {
|
||||||
|
t.Error("expected non-zero ID")
|
||||||
|
}
|
||||||
|
if item.Status == "" {
|
||||||
|
t.Error("expected non-empty status")
|
||||||
|
}
|
||||||
|
if item.AppID == 0 {
|
||||||
|
t.Error("expected non-zero app_id")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestIntegration_Task_PollerLifecycle(t *testing.T) {
|
||||||
|
env := testenv.NewTestEnv(t)
|
||||||
|
|
||||||
|
// 1. Create application
|
||||||
|
appID, err := env.CreateApp("poller-lifecycle-app", "#!/bin/bash\necho hello", nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("create app: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 2. Submit task via API
|
||||||
|
taskID := taskCreateViaAPI(t, env, appID, "poller-lifecycle-task")
|
||||||
|
|
||||||
|
// 3. Wait for queued — TaskProcessor submits to MockSlurm asynchronously.
|
||||||
|
// Intermediate states (submitted→preparing→downloading→ready→queued) are
|
||||||
|
// non-deterministic; only assert the final "queued" state.
|
||||||
|
if err := env.WaitForTaskStatus(taskID, "queued", 5*time.Second); err != nil {
|
||||||
|
t.Fatalf("wait for queued: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 4. Get slurm job ID from DB (not returned by API)
|
||||||
|
slurmJobID, err := env.GetTaskSlurmJobID(taskID)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("get slurm job id: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 5. Transition: queued → running
|
||||||
|
// ORDER IS CRITICAL: SetJobState BEFORE MakeTaskStale
|
||||||
|
env.MockSlurm.SetJobState(slurmJobID, "RUNNING")
|
||||||
|
if err := env.MakeTaskStale(taskID); err != nil {
|
||||||
|
t.Fatalf("make task stale (running): %v", err)
|
||||||
|
}
|
||||||
|
if err := env.WaitForTaskStatus(taskID, "running", 5*time.Second); err != nil {
|
||||||
|
t.Fatalf("wait for running: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 6. Transition: running → completed
|
||||||
|
// ORDER IS CRITICAL: SetJobState BEFORE MakeTaskStale
|
||||||
|
env.MockSlurm.SetJobState(slurmJobID, "COMPLETED")
|
||||||
|
if err := env.MakeTaskStale(taskID); err != nil {
|
||||||
|
t.Fatalf("make task stale (completed): %v", err)
|
||||||
|
}
|
||||||
|
if err := env.WaitForTaskStatus(taskID, "completed", 5*time.Second); err != nil {
|
||||||
|
t.Fatalf("wait for completed: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestIntegration_Task_Validation(t *testing.T) {
|
||||||
|
env := testenv.NewTestEnv(t)
|
||||||
|
|
||||||
|
// Missing required app_id
|
||||||
|
resp := taskSendReq(t, env, http.MethodPost, "/api/v1/tasks", `{"task_name":"no-app-id"}`)
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusBadRequest {
|
||||||
|
t.Fatalf("expected 400 for missing app_id, got %d", resp.StatusCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
parsed := taskParseResp(t, resp)
|
||||||
|
if parsed.Success {
|
||||||
|
t.Fatal("expected success=false for validation error")
|
||||||
|
}
|
||||||
|
if parsed.Error == "" {
|
||||||
|
t.Error("expected non-empty error message")
|
||||||
|
}
|
||||||
|
}
|
||||||
279
cmd/server/integration_upload_test.go
Normal file
279
cmd/server/integration_upload_test.go
Normal file
@@ -0,0 +1,279 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"crypto/sha256"
|
||||||
|
"encoding/hex"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"mime/multipart"
|
||||||
|
"net/http"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"gcy_hpc_server/internal/model"
|
||||||
|
"gcy_hpc_server/internal/testutil/testenv"
|
||||||
|
)
|
||||||
|
|
||||||
|
// uploadResponse mirrors server.APIResponse for upload integration tests.
|
||||||
|
type uploadAPIResp struct {
|
||||||
|
Success bool `json:"success"`
|
||||||
|
Data json.RawMessage `json:"data,omitempty"`
|
||||||
|
Error string `json:"error,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// uploadDecode parses an HTTP response body into uploadAPIResp.
|
||||||
|
func uploadDecode(t *testing.T, body io.Reader) uploadAPIResp {
|
||||||
|
t.Helper()
|
||||||
|
data, err := io.ReadAll(body)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("uploadDecode: read body: %v", err)
|
||||||
|
}
|
||||||
|
var r uploadAPIResp
|
||||||
|
if err := json.Unmarshal(data, &r); err != nil {
|
||||||
|
t.Fatalf("uploadDecode: unmarshal: %v (body: %s)", err, string(data))
|
||||||
|
}
|
||||||
|
return r
|
||||||
|
}
|
||||||
|
|
||||||
|
// uploadInitSession calls InitUpload and returns the created session.
|
||||||
|
// Uses the real HTTP server from testenv.
|
||||||
|
func uploadInitSession(t *testing.T, env *testenv.TestEnv, fileName string, fileSize int64, sha256Hash string) model.UploadSessionResponse {
|
||||||
|
t.Helper()
|
||||||
|
reqBody := model.InitUploadRequest{
|
||||||
|
FileName: fileName,
|
||||||
|
FileSize: fileSize,
|
||||||
|
SHA256: sha256Hash,
|
||||||
|
}
|
||||||
|
body, _ := json.Marshal(reqBody)
|
||||||
|
resp := env.DoRequest("POST", "/api/v1/files/uploads", bytes.NewReader(body))
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusCreated {
|
||||||
|
t.Fatalf("uploadInitSession: expected 201, got %d", resp.StatusCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
r := uploadDecode(t, resp.Body)
|
||||||
|
if !r.Success {
|
||||||
|
t.Fatalf("uploadInitSession: response not success: %s", r.Error)
|
||||||
|
}
|
||||||
|
|
||||||
|
var session model.UploadSessionResponse
|
||||||
|
if err := json.Unmarshal(r.Data, &session); err != nil {
|
||||||
|
t.Fatalf("uploadInitSession: unmarshal session: %v", err)
|
||||||
|
}
|
||||||
|
return session
|
||||||
|
}
|
||||||
|
|
||||||
|
// uploadSendChunk sends a single chunk via multipart form data.
|
||||||
|
// Uses raw HTTP client to set the correct multipart content type.
|
||||||
|
func uploadSendChunk(t *testing.T, env *testenv.TestEnv, sessionID int64, chunkIndex int, chunkData []byte) {
|
||||||
|
t.Helper()
|
||||||
|
url := fmt.Sprintf("%s/api/v1/files/uploads/%d/chunks/%d", env.URL(), sessionID, chunkIndex)
|
||||||
|
|
||||||
|
var buf bytes.Buffer
|
||||||
|
writer := multipart.NewWriter(&buf)
|
||||||
|
part, err := writer.CreateFormFile("chunk", "chunk.bin")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("uploadSendChunk: create form file: %v", err)
|
||||||
|
}
|
||||||
|
part.Write(chunkData)
|
||||||
|
writer.Close()
|
||||||
|
|
||||||
|
req, err := http.NewRequest("PUT", url, &buf)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("uploadSendChunk: new request: %v", err)
|
||||||
|
}
|
||||||
|
req.Header.Set("Content-Type", writer.FormDataContentType())
|
||||||
|
|
||||||
|
resp, err := http.DefaultClient.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("uploadSendChunk: do request: %v", err)
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
bodyBytes, _ := io.ReadAll(resp.Body)
|
||||||
|
t.Fatalf("uploadSendChunk: expected 200, got %d: %s", resp.StatusCode, string(bodyBytes))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestIntegration_Upload_Init(t *testing.T) {
|
||||||
|
env := testenv.NewTestEnv(t)
|
||||||
|
|
||||||
|
fileData := []byte("integration test upload init content")
|
||||||
|
h := sha256.Sum256(fileData)
|
||||||
|
sha256Hash := hex.EncodeToString(h[:])
|
||||||
|
|
||||||
|
session := uploadInitSession(t, env, "init_test.txt", int64(len(fileData)), sha256Hash)
|
||||||
|
|
||||||
|
if session.ID <= 0 {
|
||||||
|
t.Fatalf("expected positive session ID, got %d", session.ID)
|
||||||
|
}
|
||||||
|
if session.FileName != "init_test.txt" {
|
||||||
|
t.Fatalf("expected file_name init_test.txt, got %s", session.FileName)
|
||||||
|
}
|
||||||
|
if session.Status != "pending" {
|
||||||
|
t.Fatalf("expected status pending, got %s", session.Status)
|
||||||
|
}
|
||||||
|
if session.TotalChunks != 1 {
|
||||||
|
t.Fatalf("expected 1 chunk for small file, got %d", session.TotalChunks)
|
||||||
|
}
|
||||||
|
if session.FileSize != int64(len(fileData)) {
|
||||||
|
t.Fatalf("expected file_size %d, got %d", len(fileData), session.FileSize)
|
||||||
|
}
|
||||||
|
if session.SHA256 != sha256Hash {
|
||||||
|
t.Fatalf("expected sha256 %s, got %s", sha256Hash, session.SHA256)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestIntegration_Upload_Status(t *testing.T) {
|
||||||
|
env := testenv.NewTestEnv(t)
|
||||||
|
|
||||||
|
fileData := []byte("integration test status content")
|
||||||
|
h := sha256.Sum256(fileData)
|
||||||
|
sha256Hash := hex.EncodeToString(h[:])
|
||||||
|
|
||||||
|
session := uploadInitSession(t, env, "status_test.txt", int64(len(fileData)), sha256Hash)
|
||||||
|
|
||||||
|
resp := env.DoRequest("GET", fmt.Sprintf("/api/v1/files/uploads/%d", session.ID), nil)
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
t.Fatalf("expected 200, got %d", resp.StatusCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
r := uploadDecode(t, resp.Body)
|
||||||
|
if !r.Success {
|
||||||
|
t.Fatalf("response not success: %s", r.Error)
|
||||||
|
}
|
||||||
|
|
||||||
|
var status model.UploadSessionResponse
|
||||||
|
if err := json.Unmarshal(r.Data, &status); err != nil {
|
||||||
|
t.Fatalf("unmarshal status: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if status.ID != session.ID {
|
||||||
|
t.Fatalf("expected session ID %d, got %d", session.ID, status.ID)
|
||||||
|
}
|
||||||
|
if status.Status != "pending" {
|
||||||
|
t.Fatalf("expected status pending, got %s", status.Status)
|
||||||
|
}
|
||||||
|
if status.FileName != "status_test.txt" {
|
||||||
|
t.Fatalf("expected file_name status_test.txt, got %s", status.FileName)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestIntegration_Upload_Chunk(t *testing.T) {
|
||||||
|
env := testenv.NewTestEnv(t)
|
||||||
|
|
||||||
|
fileData := []byte("integration test chunk upload data")
|
||||||
|
h := sha256.Sum256(fileData)
|
||||||
|
sha256Hash := hex.EncodeToString(h[:])
|
||||||
|
|
||||||
|
session := uploadInitSession(t, env, "chunk_test.txt", int64(len(fileData)), sha256Hash)
|
||||||
|
|
||||||
|
uploadSendChunk(t, env, session.ID, 0, fileData)
|
||||||
|
|
||||||
|
// Verify chunk appears in uploaded_chunks via status endpoint
|
||||||
|
resp := env.DoRequest("GET", fmt.Sprintf("/api/v1/files/uploads/%d", session.ID), nil)
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
r := uploadDecode(t, resp.Body)
|
||||||
|
var status model.UploadSessionResponse
|
||||||
|
if err := json.Unmarshal(r.Data, &status); err != nil {
|
||||||
|
t.Fatalf("unmarshal status after chunk: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(status.UploadedChunks) != 1 {
|
||||||
|
t.Fatalf("expected 1 uploaded chunk, got %d", len(status.UploadedChunks))
|
||||||
|
}
|
||||||
|
if status.UploadedChunks[0] != 0 {
|
||||||
|
t.Fatalf("expected uploaded chunk index 0, got %d", status.UploadedChunks[0])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestIntegration_Upload_Complete(t *testing.T) {
|
||||||
|
env := testenv.NewTestEnv(t)
|
||||||
|
|
||||||
|
fileData := []byte("integration test complete upload data")
|
||||||
|
h := sha256.Sum256(fileData)
|
||||||
|
sha256Hash := hex.EncodeToString(h[:])
|
||||||
|
|
||||||
|
session := uploadInitSession(t, env, "complete_test.txt", int64(len(fileData)), sha256Hash)
|
||||||
|
|
||||||
|
// Upload all chunks
|
||||||
|
for i := 0; i < session.TotalChunks; i++ {
|
||||||
|
uploadSendChunk(t, env, session.ID, i, fileData)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Complete upload
|
||||||
|
resp := env.DoRequest("POST", fmt.Sprintf("/api/v1/files/uploads/%d/complete", session.ID), nil)
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusCreated {
|
||||||
|
bodyBytes, _ := io.ReadAll(resp.Body)
|
||||||
|
t.Fatalf("expected 201, got %d: %s", resp.StatusCode, string(bodyBytes))
|
||||||
|
}
|
||||||
|
|
||||||
|
r := uploadDecode(t, resp.Body)
|
||||||
|
if !r.Success {
|
||||||
|
t.Fatalf("complete response not success: %s", r.Error)
|
||||||
|
}
|
||||||
|
|
||||||
|
var fileResp model.FileResponse
|
||||||
|
if err := json.Unmarshal(r.Data, &fileResp); err != nil {
|
||||||
|
t.Fatalf("unmarshal file response: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if fileResp.ID <= 0 {
|
||||||
|
t.Fatalf("expected positive file ID, got %d", fileResp.ID)
|
||||||
|
}
|
||||||
|
if fileResp.Name != "complete_test.txt" {
|
||||||
|
t.Fatalf("expected name complete_test.txt, got %s", fileResp.Name)
|
||||||
|
}
|
||||||
|
if fileResp.Size != int64(len(fileData)) {
|
||||||
|
t.Fatalf("expected size %d, got %d", len(fileData), fileResp.Size)
|
||||||
|
}
|
||||||
|
if fileResp.SHA256 != sha256Hash {
|
||||||
|
t.Fatalf("expected sha256 %s, got %s", sha256Hash, fileResp.SHA256)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestIntegration_Upload_Cancel(t *testing.T) {
|
||||||
|
env := testenv.NewTestEnv(t)
|
||||||
|
|
||||||
|
fileData := []byte("integration test cancel upload data")
|
||||||
|
h := sha256.Sum256(fileData)
|
||||||
|
sha256Hash := hex.EncodeToString(h[:])
|
||||||
|
|
||||||
|
session := uploadInitSession(t, env, "cancel_test.txt", int64(len(fileData)), sha256Hash)
|
||||||
|
|
||||||
|
// Cancel the upload
|
||||||
|
resp := env.DoRequest("DELETE", fmt.Sprintf("/api/v1/files/uploads/%d", session.ID), nil)
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
bodyBytes, _ := io.ReadAll(resp.Body)
|
||||||
|
t.Fatalf("expected 200, got %d: %s", resp.StatusCode, string(bodyBytes))
|
||||||
|
}
|
||||||
|
|
||||||
|
r := uploadDecode(t, resp.Body)
|
||||||
|
if !r.Success {
|
||||||
|
t.Fatalf("cancel response not success: %s", r.Error)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify session is no longer in pending state by checking status
|
||||||
|
statusResp := env.DoRequest("GET", fmt.Sprintf("/api/v1/files/uploads/%d", session.ID), nil)
|
||||||
|
defer statusResp.Body.Close()
|
||||||
|
|
||||||
|
sr := uploadDecode(t, statusResp.Body)
|
||||||
|
if sr.Success {
|
||||||
|
var status model.UploadSessionResponse
|
||||||
|
if err := json.Unmarshal(sr.Data, &status); err == nil {
|
||||||
|
if status.Status == "pending" {
|
||||||
|
t.Fatal("expected status to not be pending after cancel")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -21,7 +21,7 @@ import (
|
|||||||
|
|
||||||
func newTestDB() *gorm.DB {
|
func newTestDB() *gorm.DB {
|
||||||
db, _ := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{Logger: logger.Default.LogMode(logger.Silent)})
|
db, _ := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{Logger: logger.Default.LogMode(logger.Silent)})
|
||||||
db.AutoMigrate(&model.JobTemplate{})
|
db.AutoMigrate(&model.Application{})
|
||||||
return db
|
return db
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -34,12 +34,17 @@ func TestRouterRegistration(t *testing.T) {
|
|||||||
defer slurmSrv.Close()
|
defer slurmSrv.Close()
|
||||||
|
|
||||||
client, _ := slurm.NewClientWithOpts(slurmSrv.URL, slurm.WithHTTPClient(slurmSrv.Client()))
|
client, _ := slurm.NewClientWithOpts(slurmSrv.URL, slurm.WithHTTPClient(slurmSrv.Client()))
|
||||||
templateStore := store.NewTemplateStore(newTestDB())
|
jobSvc := service.NewJobService(client, zap.NewNop())
|
||||||
|
appStore := store.NewApplicationStore(newTestDB())
|
||||||
|
appSvc := service.NewApplicationService(appStore, jobSvc, "", zap.NewNop())
|
||||||
|
appH := handler.NewApplicationHandler(appSvc, zap.NewNop())
|
||||||
|
|
||||||
router := server.NewRouter(
|
router := server.NewRouter(
|
||||||
handler.NewJobHandler(service.NewJobService(client, zap.NewNop()), zap.NewNop()),
|
handler.NewJobHandler(jobSvc, zap.NewNop()),
|
||||||
handler.NewClusterHandler(service.NewClusterService(client, zap.NewNop()), zap.NewNop()),
|
handler.NewClusterHandler(service.NewClusterService(client, zap.NewNop()), zap.NewNop()),
|
||||||
handler.NewTemplateHandler(templateStore, zap.NewNop()),
|
appH,
|
||||||
|
nil, nil, nil,
|
||||||
|
nil,
|
||||||
nil,
|
nil,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -58,11 +63,12 @@ func TestRouterRegistration(t *testing.T) {
|
|||||||
{"GET", "/api/v1/partitions"},
|
{"GET", "/api/v1/partitions"},
|
||||||
{"GET", "/api/v1/partitions/:name"},
|
{"GET", "/api/v1/partitions/:name"},
|
||||||
{"GET", "/api/v1/diag"},
|
{"GET", "/api/v1/diag"},
|
||||||
{"GET", "/api/v1/templates"},
|
{"GET", "/api/v1/applications"},
|
||||||
{"POST", "/api/v1/templates"},
|
{"POST", "/api/v1/applications"},
|
||||||
{"GET", "/api/v1/templates/:id"},
|
{"GET", "/api/v1/applications/:id"},
|
||||||
{"PUT", "/api/v1/templates/:id"},
|
{"PUT", "/api/v1/applications/:id"},
|
||||||
{"DELETE", "/api/v1/templates/:id"},
|
{"DELETE", "/api/v1/applications/:id"},
|
||||||
|
// {"POST", "/api/v1/applications/:id/submit"}, // [已禁用] 已被 POST /tasks 取代
|
||||||
}
|
}
|
||||||
|
|
||||||
routeMap := map[string]bool{}
|
routeMap := map[string]bool{}
|
||||||
@@ -90,12 +96,17 @@ func TestSmokeGetJobsEndpoint(t *testing.T) {
|
|||||||
defer slurmSrv.Close()
|
defer slurmSrv.Close()
|
||||||
|
|
||||||
client, _ := slurm.NewClientWithOpts(slurmSrv.URL, slurm.WithHTTPClient(slurmSrv.Client()))
|
client, _ := slurm.NewClientWithOpts(slurmSrv.URL, slurm.WithHTTPClient(slurmSrv.Client()))
|
||||||
templateStore := store.NewTemplateStore(newTestDB())
|
jobSvc := service.NewJobService(client, zap.NewNop())
|
||||||
|
appStore := store.NewApplicationStore(newTestDB())
|
||||||
|
appSvc := service.NewApplicationService(appStore, jobSvc, "", zap.NewNop())
|
||||||
|
appH := handler.NewApplicationHandler(appSvc, zap.NewNop())
|
||||||
|
|
||||||
router := server.NewRouter(
|
router := server.NewRouter(
|
||||||
handler.NewJobHandler(service.NewJobService(client, zap.NewNop()), zap.NewNop()),
|
handler.NewJobHandler(jobSvc, zap.NewNop()),
|
||||||
handler.NewClusterHandler(service.NewClusterService(client, zap.NewNop()), zap.NewNop()),
|
handler.NewClusterHandler(service.NewClusterService(client, zap.NewNop()), zap.NewNop()),
|
||||||
handler.NewTemplateHandler(templateStore, zap.NewNop()),
|
appH,
|
||||||
|
nil, nil, nil,
|
||||||
|
nil,
|
||||||
nil,
|
nil,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
434
cmd/server/task_test.go
Normal file
434
cmd/server/task_test.go
Normal file
@@ -0,0 +1,434 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"sync/atomic"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"gcy_hpc_server/internal/handler"
|
||||||
|
"gcy_hpc_server/internal/model"
|
||||||
|
"gcy_hpc_server/internal/server"
|
||||||
|
"gcy_hpc_server/internal/service"
|
||||||
|
"gcy_hpc_server/internal/slurm"
|
||||||
|
"gcy_hpc_server/internal/store"
|
||||||
|
|
||||||
|
"go.uber.org/zap"
|
||||||
|
"gorm.io/driver/sqlite"
|
||||||
|
"gorm.io/gorm"
|
||||||
|
"gorm.io/gorm/logger"
|
||||||
|
)
|
||||||
|
|
||||||
|
func newTaskTestDB(t *testing.T) *gorm.DB {
|
||||||
|
t.Helper()
|
||||||
|
db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{Logger: logger.Default.LogMode(logger.Silent)})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("open sqlite: %v", err)
|
||||||
|
}
|
||||||
|
sqlDB, _ := db.DB()
|
||||||
|
sqlDB.SetMaxOpenConns(1)
|
||||||
|
db.AutoMigrate(
|
||||||
|
&model.Application{},
|
||||||
|
&model.File{},
|
||||||
|
&model.FileBlob{},
|
||||||
|
&model.Task{},
|
||||||
|
)
|
||||||
|
t.Cleanup(func() {
|
||||||
|
sqlDB.Close()
|
||||||
|
})
|
||||||
|
return db
|
||||||
|
}
|
||||||
|
|
||||||
|
type mockSlurmHandler struct {
|
||||||
|
submitFn func(w http.ResponseWriter, r *http.Request)
|
||||||
|
}
|
||||||
|
|
||||||
|
func newMockSlurmServer(t *testing.T, h *mockSlurmHandler) *httptest.Server {
|
||||||
|
t.Helper()
|
||||||
|
mux := http.NewServeMux()
|
||||||
|
mux.HandleFunc("/slurm/v0.0.40/job/submit", func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if h != nil && h.submitFn != nil {
|
||||||
|
h.submitFn(w, r)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
fmt.Fprint(w, `{"job_id": 42, "step_id": "0", "result": {"job_id": 42, "step_id": "0"}}`)
|
||||||
|
})
|
||||||
|
mux.HandleFunc("/slurm/v0.0.40/jobs", func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
json.NewEncoder(w).Encode(map[string]interface{}{"jobs": []interface{}{}})
|
||||||
|
})
|
||||||
|
mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
json.NewEncoder(w).Encode(map[string]interface{}{})
|
||||||
|
})
|
||||||
|
srv := httptest.NewServer(mux)
|
||||||
|
t.Cleanup(srv.Close)
|
||||||
|
return srv
|
||||||
|
}
|
||||||
|
|
||||||
|
func setupTaskTestServer(t *testing.T) (*httptest.Server, *store.TaskStore, func()) {
|
||||||
|
t.Helper()
|
||||||
|
return setupTaskTestServerWithSlurm(t, nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
func setupTaskTestServerWithSlurm(t *testing.T, slurmHandler *mockSlurmHandler) (*httptest.Server, *store.TaskStore, func()) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
db := newTaskTestDB(t)
|
||||||
|
slurmSrv := newMockSlurmServer(t, slurmHandler)
|
||||||
|
|
||||||
|
client, err := slurm.NewClientWithOpts(slurmSrv.URL, slurm.WithHTTPClient(slurmSrv.Client()))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("slurm client: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
log := zap.NewNop()
|
||||||
|
jobSvc := service.NewJobService(client, log)
|
||||||
|
appStore := store.NewApplicationStore(db)
|
||||||
|
taskStore := store.NewTaskStore(db)
|
||||||
|
|
||||||
|
tmpDir, err := os.MkdirTemp("", "task-test-workdir-*")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("temp dir: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
taskSvc := service.NewTaskService(
|
||||||
|
taskStore, appStore,
|
||||||
|
nil, nil, nil,
|
||||||
|
jobSvc,
|
||||||
|
tmpDir,
|
||||||
|
log,
|
||||||
|
)
|
||||||
|
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
taskSvc.StartProcessor(ctx)
|
||||||
|
|
||||||
|
appSvc := service.NewApplicationService(appStore, jobSvc, tmpDir, log, taskSvc)
|
||||||
|
appH := handler.NewApplicationHandler(appSvc, log)
|
||||||
|
taskH := handler.NewTaskHandler(taskSvc, log)
|
||||||
|
|
||||||
|
router := server.NewRouter(
|
||||||
|
handler.NewJobHandler(jobSvc, log),
|
||||||
|
handler.NewClusterHandler(service.NewClusterService(client, log), log),
|
||||||
|
appH,
|
||||||
|
nil, nil, nil,
|
||||||
|
taskH,
|
||||||
|
log,
|
||||||
|
)
|
||||||
|
|
||||||
|
httpSrv := httptest.NewServer(router)
|
||||||
|
cleanup := func() {
|
||||||
|
taskSvc.StopProcessor()
|
||||||
|
cancel()
|
||||||
|
httpSrv.Close()
|
||||||
|
os.RemoveAll(tmpDir)
|
||||||
|
}
|
||||||
|
|
||||||
|
return httpSrv, taskStore, cleanup
|
||||||
|
}
|
||||||
|
|
||||||
|
func createTestApp(t *testing.T, srvURL string) int64 {
|
||||||
|
t.Helper()
|
||||||
|
body := `{"name":"test-app","script_template":"#!/bin/bash\necho {{.np}}","parameters":[{"name":"np","type":"string","default":"1"}]}`
|
||||||
|
resp, err := http.Post(srvURL+"/api/v1/applications", "application/json", bytes.NewReader([]byte(body)))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("create app: %v", err)
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
if resp.StatusCode != http.StatusCreated {
|
||||||
|
b, _ := io.ReadAll(resp.Body)
|
||||||
|
t.Fatalf("create app: status %d, body %s", resp.StatusCode, b)
|
||||||
|
}
|
||||||
|
var result struct {
|
||||||
|
Data struct {
|
||||||
|
ID int64 `json:"id"`
|
||||||
|
} `json:"data"`
|
||||||
|
}
|
||||||
|
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
|
||||||
|
t.Fatalf("decode app response: %v", err)
|
||||||
|
}
|
||||||
|
return result.Data.ID
|
||||||
|
}
|
||||||
|
|
||||||
|
func postTask(t *testing.T, srvURL string, body string) (*http.Response, map[string]interface{}) {
|
||||||
|
t.Helper()
|
||||||
|
resp, err := http.Post(srvURL+"/api/v1/tasks", "application/json", bytes.NewReader([]byte(body)))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("post task: %v", err)
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
b, _ := io.ReadAll(resp.Body)
|
||||||
|
var result map[string]interface{}
|
||||||
|
json.Unmarshal(b, &result)
|
||||||
|
return resp, result
|
||||||
|
}
|
||||||
|
|
||||||
|
func getTasks(t *testing.T, srvURL string, query string) (int, []interface{}) {
|
||||||
|
t.Helper()
|
||||||
|
url := srvURL + "/api/v1/tasks"
|
||||||
|
if query != "" {
|
||||||
|
url += "?" + query
|
||||||
|
}
|
||||||
|
resp, err := http.Get(url)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("get tasks: %v", err)
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
b, _ := io.ReadAll(resp.Body)
|
||||||
|
var result map[string]interface{}
|
||||||
|
json.Unmarshal(b, &result)
|
||||||
|
data, _ := result["data"].(map[string]interface{})
|
||||||
|
items, _ := data["items"].([]interface{})
|
||||||
|
total, _ := data["total"].(float64)
|
||||||
|
return int(total), items
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTask_FullLifecycle(t *testing.T) {
|
||||||
|
srv, _, cleanup := setupTaskTestServer(t)
|
||||||
|
defer cleanup()
|
||||||
|
|
||||||
|
appID := createTestApp(t, srv.URL)
|
||||||
|
|
||||||
|
resp, result := postTask(t, srv.URL, fmt.Sprintf(
|
||||||
|
`{"app_id":%d,"task_name":"lifecycle-test","values":{"np":"4"}}`, appID,
|
||||||
|
))
|
||||||
|
if resp.StatusCode != http.StatusCreated {
|
||||||
|
t.Fatalf("expected 201, got %d: %v", resp.StatusCode, result)
|
||||||
|
}
|
||||||
|
|
||||||
|
data, _ := result["data"].(map[string]interface{})
|
||||||
|
if data["id"] == nil {
|
||||||
|
t.Fatal("expected id in response")
|
||||||
|
}
|
||||||
|
|
||||||
|
time.Sleep(300 * time.Millisecond)
|
||||||
|
|
||||||
|
total, items := getTasks(t, srv.URL, "")
|
||||||
|
if total < 1 {
|
||||||
|
t.Fatalf("expected at least 1 task, got %d", total)
|
||||||
|
}
|
||||||
|
|
||||||
|
task := items[0].(map[string]interface{})
|
||||||
|
if task["status"] == nil || task["status"] == "" {
|
||||||
|
t.Error("expected non-empty status")
|
||||||
|
}
|
||||||
|
if task["task_name"] != "lifecycle-test" {
|
||||||
|
t.Errorf("expected task_name=lifecycle-test, got %v", task["task_name"])
|
||||||
|
}
|
||||||
|
if task["app_id"] != float64(appID) {
|
||||||
|
t.Errorf("expected app_id=%d, got %v", appID, task["app_id"])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTask_CreateWithMissingApp(t *testing.T) {
|
||||||
|
srv, _, cleanup := setupTaskTestServer(t)
|
||||||
|
defer cleanup()
|
||||||
|
|
||||||
|
resp, result := postTask(t, srv.URL, `{"app_id":9999,"task_name":"no-app"}`)
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusNotFound {
|
||||||
|
t.Fatalf("expected 404, got %d: %v", resp.StatusCode, result)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTask_CreateWithInvalidBody(t *testing.T) {
|
||||||
|
srv, _, cleanup := setupTaskTestServer(t)
|
||||||
|
defer cleanup()
|
||||||
|
|
||||||
|
resp, result := postTask(t, srv.URL, `{}`)
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusBadRequest {
|
||||||
|
t.Fatalf("expected 400, got %d: %v", resp.StatusCode, result)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTask_FileLimitExceeded(t *testing.T) {
|
||||||
|
srv, _, cleanup := setupTaskTestServer(t)
|
||||||
|
defer cleanup()
|
||||||
|
|
||||||
|
appID := createTestApp(t, srv.URL)
|
||||||
|
|
||||||
|
fileIDs := make([]int64, 101)
|
||||||
|
for i := range fileIDs {
|
||||||
|
fileIDs[i] = int64(i + 1)
|
||||||
|
}
|
||||||
|
idsJSON, _ := json.Marshal(fileIDs)
|
||||||
|
|
||||||
|
body := fmt.Sprintf(`{"app_id":%d,"file_ids":%s}`, appID, string(idsJSON))
|
||||||
|
resp, result := postTask(t, srv.URL, body)
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusBadRequest {
|
||||||
|
t.Fatalf("expected 400 for file limit, got %d: %v", resp.StatusCode, result)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTask_RetryScenario(t *testing.T) {
|
||||||
|
var failCount int32
|
||||||
|
slurmH := &mockSlurmHandler{
|
||||||
|
submitFn: func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if atomic.AddInt32(&failCount, 1) <= 2 {
|
||||||
|
w.WriteHeader(http.StatusInternalServerError)
|
||||||
|
fmt.Fprint(w, `{"errors":[{"error":"internal error"}]}`)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
fmt.Fprint(w, `{"job_id": 99, "step_id": "0", "result": {"job_id": 99, "step_id": "0"}}`)
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
srv, taskStore, cleanup := setupTaskTestServerWithSlurm(t, slurmH)
|
||||||
|
defer cleanup()
|
||||||
|
|
||||||
|
appID := createTestApp(t, srv.URL)
|
||||||
|
|
||||||
|
resp, result := postTask(t, srv.URL, fmt.Sprintf(
|
||||||
|
`{"app_id":%d,"task_name":"retry-test","values":{"np":"2"}}`, appID,
|
||||||
|
))
|
||||||
|
if resp.StatusCode != http.StatusCreated {
|
||||||
|
t.Fatalf("expected 201, got %d: %v", resp.StatusCode, result)
|
||||||
|
}
|
||||||
|
|
||||||
|
taskID := int64(result["data"].(map[string]interface{})["id"].(float64))
|
||||||
|
|
||||||
|
deadline := time.Now().Add(5 * time.Second)
|
||||||
|
for time.Now().Before(deadline) {
|
||||||
|
task, _ := taskStore.GetByID(context.Background(), taskID)
|
||||||
|
if task != nil && task.Status == model.TaskStatusQueued {
|
||||||
|
if task.SlurmJobID != nil && *task.SlurmJobID == 99 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
time.Sleep(100 * time.Millisecond)
|
||||||
|
}
|
||||||
|
|
||||||
|
task, _ := taskStore.GetByID(context.Background(), taskID)
|
||||||
|
if task == nil {
|
||||||
|
t.Fatalf("task %d not found after deadline", taskID)
|
||||||
|
}
|
||||||
|
t.Fatalf("task did not reach queued with slurm_job_id=99; status=%s retry_count=%d slurm_job_id=%v",
|
||||||
|
task.Status, task.RetryCount, task.SlurmJobID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// [已禁用] 前端已全部迁移到 POST /tasks 接口,旧 API 兼容性测试不再需要。
|
||||||
|
/*
|
||||||
|
func TestTask_OldAPICompatibility(t *testing.T) {
|
||||||
|
srv, _, cleanup := setupTaskTestServer(t)
|
||||||
|
defer cleanup()
|
||||||
|
|
||||||
|
appID := createTestApp(t, srv.URL)
|
||||||
|
|
||||||
|
body := `{"values":{"np":"8"}}`
|
||||||
|
url := fmt.Sprintf("%s/api/v1/applications/%d/submit", srv.URL, appID)
|
||||||
|
resp, err := http.Post(url, "application/json", bytes.NewReader([]byte(body)))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("post submit: %v", err)
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusCreated {
|
||||||
|
b, _ := io.ReadAll(resp.Body)
|
||||||
|
t.Fatalf("expected 201, got %d: %s", resp.StatusCode, b)
|
||||||
|
}
|
||||||
|
|
||||||
|
var result map[string]interface{}
|
||||||
|
json.NewDecoder(resp.Body).Decode(&result)
|
||||||
|
|
||||||
|
data, _ := result["data"].(map[string]interface{})
|
||||||
|
if data == nil {
|
||||||
|
t.Fatal("expected data in response")
|
||||||
|
}
|
||||||
|
if data["job_id"] == nil {
|
||||||
|
t.Errorf("expected job_id in old API response, got: %v", data)
|
||||||
|
}
|
||||||
|
jobID, ok := data["job_id"].(float64)
|
||||||
|
if !ok || jobID == 0 {
|
||||||
|
t.Errorf("expected non-zero job_id, got %v", data["job_id"])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
*/
|
||||||
|
|
||||||
|
func TestTask_ListWithFilters(t *testing.T) {
|
||||||
|
srv, taskStore, cleanup := setupTaskTestServer(t)
|
||||||
|
defer cleanup()
|
||||||
|
|
||||||
|
appID := createTestApp(t, srv.URL)
|
||||||
|
|
||||||
|
now := time.Now()
|
||||||
|
taskStore.Create(context.Background(), &model.Task{
|
||||||
|
TaskName: "task-completed", AppID: appID, AppName: "test-app",
|
||||||
|
Status: model.TaskStatusCompleted, SubmittedAt: now,
|
||||||
|
})
|
||||||
|
taskStore.Create(context.Background(), &model.Task{
|
||||||
|
TaskName: "task-failed", AppID: appID, AppName: "test-app",
|
||||||
|
Status: model.TaskStatusFailed, ErrorMessage: "boom", SubmittedAt: now,
|
||||||
|
})
|
||||||
|
taskStore.Create(context.Background(), &model.Task{
|
||||||
|
TaskName: "task-queued", AppID: appID, AppName: "test-app",
|
||||||
|
Status: model.TaskStatusQueued, SubmittedAt: now,
|
||||||
|
})
|
||||||
|
|
||||||
|
total, items := getTasks(t, srv.URL, "status=completed")
|
||||||
|
if total != 1 {
|
||||||
|
t.Fatalf("expected 1 completed, got %d", total)
|
||||||
|
}
|
||||||
|
task := items[0].(map[string]interface{})
|
||||||
|
if task["status"] != "completed" {
|
||||||
|
t.Errorf("expected status=completed, got %v", task["status"])
|
||||||
|
}
|
||||||
|
|
||||||
|
total, items = getTasks(t, srv.URL, "status=failed")
|
||||||
|
if total != 1 {
|
||||||
|
t.Fatalf("expected 1 failed, got %d", total)
|
||||||
|
}
|
||||||
|
|
||||||
|
total, _ = getTasks(t, srv.URL, "")
|
||||||
|
if total != 3 {
|
||||||
|
t.Fatalf("expected 3 total, got %d", total)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTask_WorkDirCreated(t *testing.T) {
|
||||||
|
srv, taskStore, cleanup := setupTaskTestServer(t)
|
||||||
|
defer cleanup()
|
||||||
|
|
||||||
|
appID := createTestApp(t, srv.URL)
|
||||||
|
|
||||||
|
resp, result := postTask(t, srv.URL, fmt.Sprintf(
|
||||||
|
`{"app_id":%d,"task_name":"workdir-test","values":{"np":"1"}}`, appID,
|
||||||
|
))
|
||||||
|
if resp.StatusCode != http.StatusCreated {
|
||||||
|
t.Fatalf("expected 201, got %d: %v", resp.StatusCode, result)
|
||||||
|
}
|
||||||
|
|
||||||
|
taskID := int64(result["data"].(map[string]interface{})["id"].(float64))
|
||||||
|
|
||||||
|
deadline := time.Now().Add(5 * time.Second)
|
||||||
|
for time.Now().Before(deadline) {
|
||||||
|
task, _ := taskStore.GetByID(context.Background(), taskID)
|
||||||
|
if task != nil && task.WorkDir != "" {
|
||||||
|
if _, err := os.Stat(task.WorkDir); os.IsNotExist(err) {
|
||||||
|
t.Fatalf("work dir %s not created", task.WorkDir)
|
||||||
|
}
|
||||||
|
if !filepath.IsAbs(task.WorkDir) {
|
||||||
|
t.Errorf("expected absolute work dir, got %s", task.WorkDir)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
time.Sleep(50 * time.Millisecond)
|
||||||
|
}
|
||||||
|
|
||||||
|
task, _ := taskStore.GetByID(context.Background(), taskID)
|
||||||
|
if task == nil {
|
||||||
|
t.Fatalf("task %d not found after deadline", taskID)
|
||||||
|
}
|
||||||
|
t.Fatalf("work dir never set; status=%s workdir=%s", task.Status, task.WorkDir)
|
||||||
|
}
|
||||||
12
go.mod
12
go.mod
@@ -19,31 +19,43 @@ require (
|
|||||||
github.com/bytedance/sonic/loader v0.5.0 // indirect
|
github.com/bytedance/sonic/loader v0.5.0 // indirect
|
||||||
github.com/cloudwego/base64x v0.1.6 // indirect
|
github.com/cloudwego/base64x v0.1.6 // indirect
|
||||||
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect
|
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect
|
||||||
|
github.com/dustin/go-humanize v1.0.1 // indirect
|
||||||
github.com/gabriel-vasile/mimetype v1.4.12 // indirect
|
github.com/gabriel-vasile/mimetype v1.4.12 // indirect
|
||||||
github.com/gin-contrib/sse v1.1.0 // indirect
|
github.com/gin-contrib/sse v1.1.0 // indirect
|
||||||
|
github.com/go-ini/ini v1.67.0 // indirect
|
||||||
github.com/go-playground/locales v0.14.1 // indirect
|
github.com/go-playground/locales v0.14.1 // indirect
|
||||||
github.com/go-playground/universal-translator v0.18.1 // indirect
|
github.com/go-playground/universal-translator v0.18.1 // indirect
|
||||||
github.com/go-playground/validator/v10 v10.30.1 // indirect
|
github.com/go-playground/validator/v10 v10.30.1 // indirect
|
||||||
github.com/go-sql-driver/mysql v1.9.3 // indirect
|
github.com/go-sql-driver/mysql v1.9.3 // indirect
|
||||||
github.com/goccy/go-json v0.10.5 // indirect
|
github.com/goccy/go-json v0.10.5 // indirect
|
||||||
github.com/goccy/go-yaml v1.19.2 // indirect
|
github.com/goccy/go-yaml v1.19.2 // indirect
|
||||||
|
github.com/google/uuid v1.6.0 // indirect
|
||||||
github.com/jinzhu/inflection v1.0.0 // indirect
|
github.com/jinzhu/inflection v1.0.0 // indirect
|
||||||
github.com/jinzhu/now v1.1.5 // indirect
|
github.com/jinzhu/now v1.1.5 // indirect
|
||||||
github.com/json-iterator/go v1.1.12 // indirect
|
github.com/json-iterator/go v1.1.12 // indirect
|
||||||
|
github.com/klauspost/compress v1.18.2 // indirect
|
||||||
github.com/klauspost/cpuid/v2 v2.3.0 // indirect
|
github.com/klauspost/cpuid/v2 v2.3.0 // indirect
|
||||||
|
github.com/klauspost/crc32 v1.3.0 // indirect
|
||||||
github.com/leodido/go-urn v1.4.0 // indirect
|
github.com/leodido/go-urn v1.4.0 // indirect
|
||||||
github.com/mattn/go-isatty v0.0.20 // indirect
|
github.com/mattn/go-isatty v0.0.20 // indirect
|
||||||
github.com/mattn/go-sqlite3 v1.14.22 // indirect
|
github.com/mattn/go-sqlite3 v1.14.22 // indirect
|
||||||
|
github.com/minio/crc64nvme v1.1.1 // indirect
|
||||||
|
github.com/minio/md5-simd v1.1.2 // indirect
|
||||||
|
github.com/minio/minio-go/v7 v7.0.100 // indirect
|
||||||
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
|
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
|
||||||
github.com/modern-go/reflect2 v1.0.2 // indirect
|
github.com/modern-go/reflect2 v1.0.2 // indirect
|
||||||
github.com/pelletier/go-toml/v2 v2.2.4 // indirect
|
github.com/pelletier/go-toml/v2 v2.2.4 // indirect
|
||||||
|
github.com/philhofer/fwd v1.2.0 // indirect
|
||||||
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect
|
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect
|
||||||
github.com/quic-go/qpack v0.6.0 // indirect
|
github.com/quic-go/qpack v0.6.0 // indirect
|
||||||
github.com/quic-go/quic-go v0.59.0 // indirect
|
github.com/quic-go/quic-go v0.59.0 // indirect
|
||||||
|
github.com/rs/xid v1.6.0 // indirect
|
||||||
|
github.com/tinylib/msgp v1.6.1 // indirect
|
||||||
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
|
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
|
||||||
github.com/ugorji/go/codec v1.3.1 // indirect
|
github.com/ugorji/go/codec v1.3.1 // indirect
|
||||||
go.mongodb.org/mongo-driver/v2 v2.5.0 // indirect
|
go.mongodb.org/mongo-driver/v2 v2.5.0 // indirect
|
||||||
go.uber.org/multierr v1.10.0 // indirect
|
go.uber.org/multierr v1.10.0 // indirect
|
||||||
|
go.yaml.in/yaml/v3 v3.0.4 // indirect
|
||||||
golang.org/x/arch v0.22.0 // indirect
|
golang.org/x/arch v0.22.0 // indirect
|
||||||
golang.org/x/crypto v0.48.0 // indirect
|
golang.org/x/crypto v0.48.0 // indirect
|
||||||
golang.org/x/net v0.51.0 // indirect
|
golang.org/x/net v0.51.0 // indirect
|
||||||
|
|||||||
25
go.sum
25
go.sum
@@ -12,12 +12,16 @@ github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSs
|
|||||||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||||
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM=
|
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM=
|
||||||
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||||
|
github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY=
|
||||||
|
github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto=
|
||||||
github.com/gabriel-vasile/mimetype v1.4.12 h1:e9hWvmLYvtp846tLHam2o++qitpguFiYCKbn0w9jyqw=
|
github.com/gabriel-vasile/mimetype v1.4.12 h1:e9hWvmLYvtp846tLHam2o++qitpguFiYCKbn0w9jyqw=
|
||||||
github.com/gabriel-vasile/mimetype v1.4.12/go.mod h1:d+9Oxyo1wTzWdyVUPMmXFvp4F9tea18J8ufA774AB3s=
|
github.com/gabriel-vasile/mimetype v1.4.12/go.mod h1:d+9Oxyo1wTzWdyVUPMmXFvp4F9tea18J8ufA774AB3s=
|
||||||
github.com/gin-contrib/sse v1.1.0 h1:n0w2GMuUpWDVp7qSpvze6fAu9iRxJY4Hmj6AmBOU05w=
|
github.com/gin-contrib/sse v1.1.0 h1:n0w2GMuUpWDVp7qSpvze6fAu9iRxJY4Hmj6AmBOU05w=
|
||||||
github.com/gin-contrib/sse v1.1.0/go.mod h1:hxRZ5gVpWMT7Z0B0gSNYqqsSCNIJMjzvm6fqCz9vjwM=
|
github.com/gin-contrib/sse v1.1.0/go.mod h1:hxRZ5gVpWMT7Z0B0gSNYqqsSCNIJMjzvm6fqCz9vjwM=
|
||||||
github.com/gin-gonic/gin v1.12.0 h1:b3YAbrZtnf8N//yjKeU2+MQsh2mY5htkZidOM7O0wG8=
|
github.com/gin-gonic/gin v1.12.0 h1:b3YAbrZtnf8N//yjKeU2+MQsh2mY5htkZidOM7O0wG8=
|
||||||
github.com/gin-gonic/gin v1.12.0/go.mod h1:VxccKfsSllpKshkBWgVgRniFFAzFb9csfngsqANjnLc=
|
github.com/gin-gonic/gin v1.12.0/go.mod h1:VxccKfsSllpKshkBWgVgRniFFAzFb9csfngsqANjnLc=
|
||||||
|
github.com/go-ini/ini v1.67.0 h1:z6ZrTEZqSWOTyH2FlglNbNgARyHG8oLW9gMELqKr06A=
|
||||||
|
github.com/go-ini/ini v1.67.0/go.mod h1:ByCAeIL28uOIIG0E3PJtZPDL8WnHpFKFOtgjp+3Ies8=
|
||||||
github.com/go-playground/assert/v2 v2.2.0 h1:JvknZsQTYeFEAhQwI4qEt9cyV5ONwRHC+lYKSsYSR8s=
|
github.com/go-playground/assert/v2 v2.2.0 h1:JvknZsQTYeFEAhQwI4qEt9cyV5ONwRHC+lYKSsYSR8s=
|
||||||
github.com/go-playground/assert/v2 v2.2.0/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4=
|
github.com/go-playground/assert/v2 v2.2.0/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4=
|
||||||
github.com/go-playground/locales v0.14.1 h1:EWaQ/wswjilfKLTECiXz7Rh+3BjFhfDFKv/oXslEjJA=
|
github.com/go-playground/locales v0.14.1 h1:EWaQ/wswjilfKLTECiXz7Rh+3BjFhfDFKv/oXslEjJA=
|
||||||
@@ -35,14 +39,21 @@ github.com/goccy/go-yaml v1.19.2/go.mod h1:XBurs7gK8ATbW4ZPGKgcbrY1Br56PdM69F7Lk
|
|||||||
github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
|
github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
|
||||||
github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU=
|
github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU=
|
||||||
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
|
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
|
||||||
|
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
|
||||||
|
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||||
github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E=
|
github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E=
|
||||||
github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc=
|
github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc=
|
||||||
github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ=
|
github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ=
|
||||||
github.com/jinzhu/now v1.1.5/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8=
|
github.com/jinzhu/now v1.1.5/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8=
|
||||||
github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM=
|
github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM=
|
||||||
github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo=
|
github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo=
|
||||||
|
github.com/klauspost/compress v1.18.2 h1:iiPHWW0YrcFgpBYhsA6D1+fqHssJscY/Tm/y2Uqnapk=
|
||||||
|
github.com/klauspost/compress v1.18.2/go.mod h1:R0h/fSBs8DE4ENlcrlib3PsXS61voFxhIs2DeRhCvJ4=
|
||||||
|
github.com/klauspost/cpuid/v2 v2.0.1/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg=
|
||||||
github.com/klauspost/cpuid/v2 v2.3.0 h1:S4CRMLnYUhGeDFDqkGriYKdfoFlDnMtqTiI/sFzhA9Y=
|
github.com/klauspost/cpuid/v2 v2.3.0 h1:S4CRMLnYUhGeDFDqkGriYKdfoFlDnMtqTiI/sFzhA9Y=
|
||||||
github.com/klauspost/cpuid/v2 v2.3.0/go.mod h1:hqwkgyIinND0mEev00jJYCxPNVRVXFQeu1XKlok6oO0=
|
github.com/klauspost/cpuid/v2 v2.3.0/go.mod h1:hqwkgyIinND0mEev00jJYCxPNVRVXFQeu1XKlok6oO0=
|
||||||
|
github.com/klauspost/crc32 v1.3.0 h1:sSmTt3gUt81RP655XGZPElI0PelVTZ6YwCRnPSupoFM=
|
||||||
|
github.com/klauspost/crc32 v1.3.0/go.mod h1:D7kQaZhnkX/Y0tstFGf8VUzv2UofNGqCjnC3zdHB0Hw=
|
||||||
github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE=
|
github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE=
|
||||||
github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk=
|
github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk=
|
||||||
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
|
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
|
||||||
@@ -53,6 +64,12 @@ github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWE
|
|||||||
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
|
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
|
||||||
github.com/mattn/go-sqlite3 v1.14.22 h1:2gZY6PC6kBnID23Tichd1K+Z0oS6nE/XwU+Vz/5o4kU=
|
github.com/mattn/go-sqlite3 v1.14.22 h1:2gZY6PC6kBnID23Tichd1K+Z0oS6nE/XwU+Vz/5o4kU=
|
||||||
github.com/mattn/go-sqlite3 v1.14.22/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y=
|
github.com/mattn/go-sqlite3 v1.14.22/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y=
|
||||||
|
github.com/minio/crc64nvme v1.1.1 h1:8dwx/Pz49suywbO+auHCBpCtlW1OfpcLN7wYgVR6wAI=
|
||||||
|
github.com/minio/crc64nvme v1.1.1/go.mod h1:eVfm2fAzLlxMdUGc0EEBGSMmPwmXD5XiNRpnu9J3bvg=
|
||||||
|
github.com/minio/md5-simd v1.1.2 h1:Gdi1DZK69+ZVMoNHRXJyNcxrMA4dSxoYHZSQbirFg34=
|
||||||
|
github.com/minio/md5-simd v1.1.2/go.mod h1:MzdKDxYpY2BT9XQFocsiZf/NKVtR7nkE4RoEpN+20RM=
|
||||||
|
github.com/minio/minio-go/v7 v7.0.100 h1:ShkWi8Tyj9RtU57OQB2HIXKz4bFgtVib0bbT1sbtLI8=
|
||||||
|
github.com/minio/minio-go/v7 v7.0.100/go.mod h1:EtGNKtlX20iL2yaYnxEigaIvj0G0GwSDnifnG8ClIdw=
|
||||||
github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
|
github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
|
||||||
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w8PVh93nsPXa1VrQ6jlwL5oN8l14QlcNfg=
|
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w8PVh93nsPXa1VrQ6jlwL5oN8l14QlcNfg=
|
||||||
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
|
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
|
||||||
@@ -60,6 +77,8 @@ github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9G
|
|||||||
github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk=
|
github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk=
|
||||||
github.com/pelletier/go-toml/v2 v2.2.4 h1:mye9XuhQ6gvn5h28+VilKrrPoQVanw5PMw/TB0t5Ec4=
|
github.com/pelletier/go-toml/v2 v2.2.4 h1:mye9XuhQ6gvn5h28+VilKrrPoQVanw5PMw/TB0t5Ec4=
|
||||||
github.com/pelletier/go-toml/v2 v2.2.4/go.mod h1:2gIqNv+qfxSVS7cM2xJQKtLSTLUE9V8t9Stt+h56mCY=
|
github.com/pelletier/go-toml/v2 v2.2.4/go.mod h1:2gIqNv+qfxSVS7cM2xJQKtLSTLUE9V8t9Stt+h56mCY=
|
||||||
|
github.com/philhofer/fwd v1.2.0 h1:e6DnBTl7vGY+Gz322/ASL4Gyp1FspeMvx1RNDoToZuM=
|
||||||
|
github.com/philhofer/fwd v1.2.0/go.mod h1:RqIHx9QI14HlwKwm98g9Re5prTQ6LdeRQn+gXJFxsJM=
|
||||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||||
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRIccs7FGNTlIRMkT8wgtp5eCXdBlqhYGL6U=
|
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRIccs7FGNTlIRMkT8wgtp5eCXdBlqhYGL6U=
|
||||||
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||||
@@ -69,6 +88,8 @@ github.com/quic-go/quic-go v0.59.0 h1:OLJkp1Mlm/aS7dpKgTc6cnpynnD2Xg7C1pwL6vy/SA
|
|||||||
github.com/quic-go/quic-go v0.59.0/go.mod h1:upnsH4Ju1YkqpLXC305eW3yDZ4NfnNbmQRCMWS58IKU=
|
github.com/quic-go/quic-go v0.59.0/go.mod h1:upnsH4Ju1YkqpLXC305eW3yDZ4NfnNbmQRCMWS58IKU=
|
||||||
github.com/rogpeppe/go-internal v1.10.0 h1:TMyTOH3F/DB16zRVcYyreMH6GnZZrwQVAoYjRBZyWFQ=
|
github.com/rogpeppe/go-internal v1.10.0 h1:TMyTOH3F/DB16zRVcYyreMH6GnZZrwQVAoYjRBZyWFQ=
|
||||||
github.com/rogpeppe/go-internal v1.10.0/go.mod h1:UQnix2H7Ngw/k4C5ijL5+65zddjncjaFoBhdsK/akog=
|
github.com/rogpeppe/go-internal v1.10.0/go.mod h1:UQnix2H7Ngw/k4C5ijL5+65zddjncjaFoBhdsK/akog=
|
||||||
|
github.com/rs/xid v1.6.0 h1:fV591PaemRlL6JfRxGDEPl69wICngIQ3shQtzfy2gxU=
|
||||||
|
github.com/rs/xid v1.6.0/go.mod h1:7XoLgs4eV+QndskICGsho+ADou8ySMSjJKDIan90Nz0=
|
||||||
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
|
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
|
||||||
github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
|
github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
|
||||||
github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
|
github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
|
||||||
@@ -80,6 +101,8 @@ github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXl
|
|||||||
github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
|
github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
|
||||||
github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
|
github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
|
||||||
github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
|
github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
|
||||||
|
github.com/tinylib/msgp v1.6.1 h1:ESRv8eL3u+DNHUoSAAQRE50Hm162zqAnBoGv9PzScPY=
|
||||||
|
github.com/tinylib/msgp v1.6.1/go.mod h1:RSp0LW9oSxFut3KzESt5Voq4GVWyS+PSulT77roAqEA=
|
||||||
github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS4MhqMhdFk5YI=
|
github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS4MhqMhdFk5YI=
|
||||||
github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08=
|
github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08=
|
||||||
github.com/ugorji/go/codec v1.3.1 h1:waO7eEiFDwidsBN6agj1vJQ4AG7lh2yqXyOXqhgQuyY=
|
github.com/ugorji/go/codec v1.3.1 h1:waO7eEiFDwidsBN6agj1vJQ4AG7lh2yqXyOXqhgQuyY=
|
||||||
@@ -94,6 +117,8 @@ go.uber.org/multierr v1.10.0 h1:S0h4aNzvfcFsC3dRF1jLoaov7oRaKqRGC/pUEJ2yvPQ=
|
|||||||
go.uber.org/multierr v1.10.0/go.mod h1:20+QtiLqy0Nd6FdQB9TLXag12DsQkrbs3htMFfDN80Y=
|
go.uber.org/multierr v1.10.0/go.mod h1:20+QtiLqy0Nd6FdQB9TLXag12DsQkrbs3htMFfDN80Y=
|
||||||
go.uber.org/zap v1.27.1 h1:08RqriUEv8+ArZRYSTXy1LeBScaMpVSTBhCeaZYfMYc=
|
go.uber.org/zap v1.27.1 h1:08RqriUEv8+ArZRYSTXy1LeBScaMpVSTBhCeaZYfMYc=
|
||||||
go.uber.org/zap v1.27.1/go.mod h1:GB2qFLM7cTU87MWRP2mPIjqfIDnGu+VIO4V/SdhGo2E=
|
go.uber.org/zap v1.27.1/go.mod h1:GB2qFLM7cTU87MWRP2mPIjqfIDnGu+VIO4V/SdhGo2E=
|
||||||
|
go.yaml.in/yaml/v3 v3.0.4 h1:tfq32ie2Jv2UxXFdLJdh3jXuOzWiL1fo0bu/FbuKpbc=
|
||||||
|
go.yaml.in/yaml/v3 v3.0.4/go.mod h1:DhzuOOF2ATzADvBadXxruRBLzYTpT36CKvDb3+aBEFg=
|
||||||
golang.org/x/arch v0.22.0 h1:c/Zle32i5ttqRXjdLyyHZESLD/bB90DCU1g9l/0YBDI=
|
golang.org/x/arch v0.22.0 h1:c/Zle32i5ttqRXjdLyyHZESLD/bB90DCU1g9l/0YBDI=
|
||||||
golang.org/x/arch v0.22.0/go.mod h1:dNHoOeKiyja7GTvF9NJS1l3Z2yntpQNzgrjh1cU103A=
|
golang.org/x/arch v0.22.0/go.mod h1:dNHoOeKiyja7GTvF9NJS1l3Z2yntpQNzgrjh1cU103A=
|
||||||
golang.org/x/crypto v0.48.0 h1:/VRzVqiRSggnhY7gNRxPauEQ5Drw9haKdM0jqfcCFts=
|
golang.org/x/crypto v0.48.0 h1:/VRzVqiRSggnhY7gNRxPauEQ5Drw9haKdM0jqfcCFts=
|
||||||
|
|||||||
2954
hpc_server_openapi.json
Normal file
2954
hpc_server_openapi.json
Normal file
File diff suppressed because it is too large
Load Diff
@@ -15,18 +15,21 @@ import (
|
|||||||
"gcy_hpc_server/internal/server"
|
"gcy_hpc_server/internal/server"
|
||||||
"gcy_hpc_server/internal/service"
|
"gcy_hpc_server/internal/service"
|
||||||
"gcy_hpc_server/internal/slurm"
|
"gcy_hpc_server/internal/slurm"
|
||||||
|
"gcy_hpc_server/internal/storage"
|
||||||
"gcy_hpc_server/internal/store"
|
"gcy_hpc_server/internal/store"
|
||||||
|
|
||||||
"go.uber.org/zap"
|
"go.uber.org/zap"
|
||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
)
|
)
|
||||||
|
|
||||||
// App encapsulates the entire application lifecycle.
|
|
||||||
type App struct {
|
type App struct {
|
||||||
cfg *config.Config
|
cfg *config.Config
|
||||||
logger *zap.Logger
|
logger *zap.Logger
|
||||||
db *gorm.DB
|
db *gorm.DB
|
||||||
server *http.Server
|
server *http.Server
|
||||||
|
cancelCleanup context.CancelFunc
|
||||||
|
taskSvc *service.TaskService
|
||||||
|
taskPoller *TaskPoller
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewApp initializes all application dependencies: DB, Slurm client, services, handlers, router.
|
// NewApp initializes all application dependencies: DB, Slurm client, services, handlers, router.
|
||||||
@@ -42,13 +45,16 @@ func NewApp(cfg *config.Config, logger *zap.Logger) (*App, error) {
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
srv := initHTTPServer(cfg, gormDB, slurmClient, logger)
|
srv, cancelCleanup, taskSvc, taskPoller := initHTTPServer(cfg, gormDB, slurmClient, logger)
|
||||||
|
|
||||||
return &App{
|
return &App{
|
||||||
cfg: cfg,
|
cfg: cfg,
|
||||||
logger: logger,
|
logger: logger,
|
||||||
db: gormDB,
|
db: gormDB,
|
||||||
server: srv,
|
server: srv,
|
||||||
|
cancelCleanup: cancelCleanup,
|
||||||
|
taskSvc: taskSvc,
|
||||||
|
taskPoller: taskPoller,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -84,6 +90,18 @@ func (a *App) Run() error {
|
|||||||
func (a *App) Close() error {
|
func (a *App) Close() error {
|
||||||
var errs []error
|
var errs []error
|
||||||
|
|
||||||
|
if a.taskSvc != nil {
|
||||||
|
a.taskSvc.StopProcessor()
|
||||||
|
}
|
||||||
|
|
||||||
|
if a.taskPoller != nil {
|
||||||
|
a.taskPoller.Stop()
|
||||||
|
}
|
||||||
|
|
||||||
|
if a.cancelCleanup != nil {
|
||||||
|
a.cancelCleanup()
|
||||||
|
}
|
||||||
|
|
||||||
if a.server != nil {
|
if a.server != nil {
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
@@ -139,21 +157,71 @@ func initSlurmClient(cfg *config.Config) (*slurm.Client, error) {
|
|||||||
return client, nil
|
return client, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func initHTTPServer(cfg *config.Config, db *gorm.DB, slurmClient *slurm.Client, logger *zap.Logger) *http.Server {
|
func initHTTPServer(cfg *config.Config, db *gorm.DB, slurmClient *slurm.Client, logger *zap.Logger) (*http.Server, context.CancelFunc, *service.TaskService, *TaskPoller) {
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
jobSvc := service.NewJobService(slurmClient, logger)
|
jobSvc := service.NewJobService(slurmClient, logger)
|
||||||
clusterSvc := service.NewClusterService(slurmClient, logger)
|
clusterSvc := service.NewClusterService(slurmClient, logger)
|
||||||
templateStore := store.NewTemplateStore(db)
|
|
||||||
|
|
||||||
jobH := handler.NewJobHandler(jobSvc, logger)
|
jobH := handler.NewJobHandler(jobSvc, logger)
|
||||||
clusterH := handler.NewClusterHandler(clusterSvc, logger)
|
clusterH := handler.NewClusterHandler(clusterSvc, logger)
|
||||||
templateH := handler.NewTemplateHandler(templateStore, logger)
|
|
||||||
|
|
||||||
router := server.NewRouter(jobH, clusterH, templateH, logger)
|
appStore := store.NewApplicationStore(db)
|
||||||
|
|
||||||
|
// File storage initialization
|
||||||
|
minioClient, err := storage.NewMinioClient(cfg.Minio)
|
||||||
|
if err != nil {
|
||||||
|
logger.Warn("failed to initialize MinIO client, file storage disabled", zap.Error(err))
|
||||||
|
}
|
||||||
|
|
||||||
|
var uploadH *handler.UploadHandler
|
||||||
|
var fileH *handler.FileHandler
|
||||||
|
var folderH *handler.FolderHandler
|
||||||
|
|
||||||
|
taskStore := store.NewTaskStore(db)
|
||||||
|
fileStore := store.NewFileStore(db)
|
||||||
|
blobStore := store.NewBlobStore(db)
|
||||||
|
|
||||||
|
var stagingSvc *service.FileStagingService
|
||||||
|
if minioClient != nil {
|
||||||
|
folderStore := store.NewFolderStore(db)
|
||||||
|
uploadStore := store.NewUploadStore(db)
|
||||||
|
|
||||||
|
uploadSvc := service.NewUploadService(minioClient, blobStore, fileStore, uploadStore, cfg.Minio, db, logger)
|
||||||
|
folderSvc := service.NewFolderService(folderStore, fileStore, logger)
|
||||||
|
fileSvc := service.NewFileService(minioClient, blobStore, fileStore, cfg.Minio.Bucket, db, logger)
|
||||||
|
|
||||||
|
uploadH = handler.NewUploadHandler(uploadSvc, logger)
|
||||||
|
fileH = handler.NewFileHandler(fileSvc, logger)
|
||||||
|
folderH = handler.NewFolderHandler(folderSvc, logger)
|
||||||
|
|
||||||
|
stagingSvc = service.NewFileStagingService(fileStore, blobStore, minioClient, cfg.Minio.Bucket, logger)
|
||||||
|
}
|
||||||
|
|
||||||
|
taskSvc := service.NewTaskService(taskStore, appStore, fileStore, blobStore, stagingSvc, jobSvc, cfg.WorkDirBase, logger)
|
||||||
|
taskSvc.StartProcessor(ctx)
|
||||||
|
|
||||||
|
appSvc := service.NewApplicationService(appStore, jobSvc, cfg.WorkDirBase, logger, taskSvc)
|
||||||
|
appH := handler.NewApplicationHandler(appSvc, logger)
|
||||||
|
|
||||||
|
poller := NewTaskPoller(taskSvc, 10*time.Second, logger)
|
||||||
|
poller.Start(ctx)
|
||||||
|
|
||||||
|
taskH := handler.NewTaskHandler(taskSvc, logger)
|
||||||
|
|
||||||
|
var cancelCleanup context.CancelFunc
|
||||||
|
|
||||||
|
if minioClient != nil {
|
||||||
|
cleanupCtx, cancel := context.WithCancel(context.Background())
|
||||||
|
cancelCleanup = cancel
|
||||||
|
go startCleanupWorker(cleanupCtx, store.NewUploadStore(db), minioClient, cfg.Minio.Bucket, logger)
|
||||||
|
}
|
||||||
|
|
||||||
|
router := server.NewRouter(jobH, clusterH, appH, uploadH, fileH, folderH, taskH, logger)
|
||||||
|
|
||||||
addr := ":" + cfg.ServerPort
|
addr := ":" + cfg.ServerPort
|
||||||
|
|
||||||
return &http.Server{
|
return &http.Server{
|
||||||
Addr: addr,
|
Addr: addr,
|
||||||
Handler: router,
|
Handler: router,
|
||||||
}
|
}, cancelCleanup, taskSvc, poller
|
||||||
}
|
}
|
||||||
|
|||||||
83
internal/app/cleanup.go
Normal file
83
internal/app/cleanup.go
Normal file
@@ -0,0 +1,83 @@
|
|||||||
|
package app
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"gcy_hpc_server/internal/storage"
|
||||||
|
"gcy_hpc_server/internal/store"
|
||||||
|
|
||||||
|
"go.uber.org/zap"
|
||||||
|
)
|
||||||
|
|
||||||
|
// startCleanupWorker runs a background goroutine that periodically cleans up:
|
||||||
|
// 1. Expired upload sessions (mark → delete MinIO chunks → delete DB records)
|
||||||
|
// 2. Leaked multipart uploads from failed ComposeObject calls
|
||||||
|
func startCleanupWorker(ctx context.Context, uploadStore *store.UploadStore, objStorage storage.ObjectStorage, bucket string, logger *zap.Logger) {
|
||||||
|
ticker := time.NewTicker(1 * time.Hour)
|
||||||
|
defer ticker.Stop()
|
||||||
|
|
||||||
|
cleanupExpiredSessions(ctx, uploadStore, objStorage, bucket, logger)
|
||||||
|
cleanupLeakedMultipartUploads(ctx, objStorage, bucket, logger)
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
logger.Info("cleanup worker stopped")
|
||||||
|
return
|
||||||
|
case <-ticker.C:
|
||||||
|
cleanupExpiredSessions(ctx, uploadStore, objStorage, bucket, logger)
|
||||||
|
cleanupLeakedMultipartUploads(ctx, objStorage, bucket, logger)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// cleanupExpiredSessions performs three-phase cleanup of expired upload sessions:
|
||||||
|
// Phase 1: Find and mark expired sessions
|
||||||
|
// Phase 2: Delete MinIO temp chunks for each session
|
||||||
|
// Phase 3: Delete DB records (session + chunks)
|
||||||
|
func cleanupExpiredSessions(ctx context.Context, uploadStore *store.UploadStore, objStorage storage.ObjectStorage, bucket string, logger *zap.Logger) {
|
||||||
|
sessions, err := uploadStore.ListExpiredSessions(ctx)
|
||||||
|
if err != nil {
|
||||||
|
logger.Error("failed to list expired sessions", zap.Error(err))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(sessions) == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.Info("cleaning up expired sessions", zap.Int("count", len(sessions)))
|
||||||
|
|
||||||
|
for i := range sessions {
|
||||||
|
session := &sessions[i]
|
||||||
|
|
||||||
|
if err := uploadStore.UpdateSessionStatus(ctx, session.ID, "expired"); err != nil {
|
||||||
|
logger.Error("failed to mark session expired", zap.Int64("session_id", session.ID), zap.Error(err))
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
objects, err := objStorage.ListObjects(ctx, bucket, session.MinioPrefix, true)
|
||||||
|
if err != nil {
|
||||||
|
logger.Error("failed to list session objects", zap.Int64("session_id", session.ID), zap.Error(err))
|
||||||
|
} else if len(objects) > 0 {
|
||||||
|
keys := make([]string, len(objects))
|
||||||
|
for j, obj := range objects {
|
||||||
|
keys[j] = obj.Key
|
||||||
|
}
|
||||||
|
if err := objStorage.RemoveObjects(ctx, bucket, keys, storage.RemoveObjectsOptions{}); err != nil {
|
||||||
|
logger.Error("failed to remove session objects", zap.Int64("session_id", session.ID), zap.Error(err))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := uploadStore.DeleteSession(ctx, session.ID); err != nil {
|
||||||
|
logger.Error("failed to delete session", zap.Int64("session_id", session.ID), zap.Error(err))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func cleanupLeakedMultipartUploads(ctx context.Context, objStorage storage.ObjectStorage, bucket string, logger *zap.Logger) {
|
||||||
|
if err := objStorage.RemoveIncompleteUpload(ctx, bucket, "uploads/"); err != nil {
|
||||||
|
logger.Error("failed to cleanup leaked multipart uploads", zap.Error(err))
|
||||||
|
}
|
||||||
|
}
|
||||||
266
internal/app/cleanup_test.go
Normal file
266
internal/app/cleanup_test.go
Normal file
@@ -0,0 +1,266 @@
|
|||||||
|
package app
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"io"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"gcy_hpc_server/internal/model"
|
||||||
|
"gcy_hpc_server/internal/storage"
|
||||||
|
"gcy_hpc_server/internal/store"
|
||||||
|
|
||||||
|
"go.uber.org/zap"
|
||||||
|
"gorm.io/driver/sqlite"
|
||||||
|
"gorm.io/gorm"
|
||||||
|
)
|
||||||
|
|
||||||
|
// mockCleanupStorage implements ObjectStorage for cleanup tests.
|
||||||
|
type mockCleanupStorage struct {
|
||||||
|
listObjectsFn func(ctx context.Context, bucket, prefix string, recursive bool) ([]storage.ObjectInfo, error)
|
||||||
|
removeObjectsFn func(ctx context.Context, bucket string, keys []string, opts storage.RemoveObjectsOptions) error
|
||||||
|
removeIncompleteFn func(ctx context.Context, bucket, object string) error
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockCleanupStorage) PutObject(_ context.Context, _ string, _ string, _ io.Reader, _ int64, _ storage.PutObjectOptions) (storage.UploadInfo, error) {
|
||||||
|
return storage.UploadInfo{}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockCleanupStorage) GetObject(_ context.Context, _ string, _ string, _ storage.GetOptions) (io.ReadCloser, storage.ObjectInfo, error) {
|
||||||
|
return nil, storage.ObjectInfo{}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockCleanupStorage) ComposeObject(_ context.Context, _ string, _ string, _ []string) (storage.UploadInfo, error) {
|
||||||
|
return storage.UploadInfo{}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockCleanupStorage) AbortMultipartUpload(_ context.Context, _ string, _ string, _ string) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockCleanupStorage) RemoveIncompleteUpload(_ context.Context, _ string, _ string) error {
|
||||||
|
if m.removeIncompleteFn != nil {
|
||||||
|
return m.removeIncompleteFn(context.Background(), "", "")
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockCleanupStorage) RemoveObject(_ context.Context, _ string, _ string, _ storage.RemoveObjectOptions) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockCleanupStorage) ListObjects(ctx context.Context, bucket string, prefix string, recursive bool) ([]storage.ObjectInfo, error) {
|
||||||
|
if m.listObjectsFn != nil {
|
||||||
|
return m.listObjectsFn(ctx, bucket, prefix, recursive)
|
||||||
|
}
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockCleanupStorage) RemoveObjects(ctx context.Context, bucket string, keys []string, opts storage.RemoveObjectsOptions) error {
|
||||||
|
if m.removeObjectsFn != nil {
|
||||||
|
return m.removeObjectsFn(ctx, bucket, keys, opts)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockCleanupStorage) BucketExists(_ context.Context, _ string) (bool, error) {
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockCleanupStorage) MakeBucket(_ context.Context, _ string, _ storage.MakeBucketOptions) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockCleanupStorage) StatObject(_ context.Context, _ string, _ string, _ storage.StatObjectOptions) (storage.ObjectInfo, error) {
|
||||||
|
return storage.ObjectInfo{}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func setupCleanupTestDB(t *testing.T) (*gorm.DB, *store.UploadStore) {
|
||||||
|
t.Helper()
|
||||||
|
db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("open sqlite: %v", err)
|
||||||
|
}
|
||||||
|
if err := db.AutoMigrate(&model.UploadSession{}, &model.UploadChunk{}); err != nil {
|
||||||
|
t.Fatalf("migrate: %v", err)
|
||||||
|
}
|
||||||
|
return db, store.NewUploadStore(db)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCleanupExpiredSessions(t *testing.T) {
|
||||||
|
_, uploadStore := setupCleanupTestDB(t)
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
past := time.Now().Add(-1 * time.Hour)
|
||||||
|
err := uploadStore.CreateSession(ctx, &model.UploadSession{
|
||||||
|
FileName: "expired.bin",
|
||||||
|
FileSize: 1024,
|
||||||
|
ChunkSize: 16 << 20,
|
||||||
|
TotalChunks: 1,
|
||||||
|
SHA256: "expired_hash",
|
||||||
|
Status: "pending",
|
||||||
|
MinioPrefix: "uploads/99/",
|
||||||
|
ExpiresAt: past,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("create session: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var listedPrefix string
|
||||||
|
var removedKeys []string
|
||||||
|
|
||||||
|
mockStore := &mockCleanupStorage{
|
||||||
|
listObjectsFn: func(_ context.Context, _, prefix string, _ bool) ([]storage.ObjectInfo, error) {
|
||||||
|
listedPrefix = prefix
|
||||||
|
return []storage.ObjectInfo{{Key: "uploads/99/chunk_00000", Size: 100}}, nil
|
||||||
|
},
|
||||||
|
removeObjectsFn: func(_ context.Context, _ string, keys []string, _ storage.RemoveObjectsOptions) error {
|
||||||
|
removedKeys = keys
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
cleanupExpiredSessions(ctx, uploadStore, mockStore, "test-bucket", zap.NewNop())
|
||||||
|
|
||||||
|
if listedPrefix != "uploads/99/" {
|
||||||
|
t.Errorf("listed prefix = %q, want %q", listedPrefix, "uploads/99/")
|
||||||
|
}
|
||||||
|
if len(removedKeys) != 1 || removedKeys[0] != "uploads/99/chunk_00000" {
|
||||||
|
t.Errorf("removed keys = %v, want [uploads/99/chunk_00000]", removedKeys)
|
||||||
|
}
|
||||||
|
|
||||||
|
session, err := uploadStore.GetSession(ctx, 1)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("get session: %v", err)
|
||||||
|
}
|
||||||
|
if session != nil {
|
||||||
|
t.Error("session should be deleted after cleanup")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCleanupExpiredSessions_Empty(t *testing.T) {
|
||||||
|
_, uploadStore := setupCleanupTestDB(t)
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
called := false
|
||||||
|
mockStore := &mockCleanupStorage{
|
||||||
|
listObjectsFn: func(_ context.Context, _, _ string, _ bool) ([]storage.ObjectInfo, error) {
|
||||||
|
called = true
|
||||||
|
return nil, nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
cleanupExpiredSessions(ctx, uploadStore, mockStore, "test-bucket", zap.NewNop())
|
||||||
|
|
||||||
|
if called {
|
||||||
|
t.Error("ListObjects should not be called when no expired sessions exist")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCleanupExpiredSessions_CompletedNotCleaned(t *testing.T) {
|
||||||
|
_, uploadStore := setupCleanupTestDB(t)
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
past := time.Now().Add(-1 * time.Hour)
|
||||||
|
err := uploadStore.CreateSession(ctx, &model.UploadSession{
|
||||||
|
FileName: "completed.bin",
|
||||||
|
FileSize: 1024,
|
||||||
|
ChunkSize: 16 << 20,
|
||||||
|
TotalChunks: 1,
|
||||||
|
SHA256: "completed_hash",
|
||||||
|
Status: "completed",
|
||||||
|
MinioPrefix: "uploads/100/",
|
||||||
|
ExpiresAt: past,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("create session: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
called := false
|
||||||
|
mockStore := &mockCleanupStorage{
|
||||||
|
listObjectsFn: func(_ context.Context, _, _ string, _ bool) ([]storage.ObjectInfo, error) {
|
||||||
|
called = true
|
||||||
|
return nil, nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
cleanupExpiredSessions(ctx, uploadStore, mockStore, "test-bucket", zap.NewNop())
|
||||||
|
|
||||||
|
if called {
|
||||||
|
t.Error("ListObjects should not be called for completed sessions")
|
||||||
|
}
|
||||||
|
|
||||||
|
session, _ := uploadStore.GetSession(ctx, 1)
|
||||||
|
if session == nil {
|
||||||
|
t.Error("completed session should not be deleted")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCleanupWorker_StopsOnContextCancel(t *testing.T) {
|
||||||
|
_, uploadStore := setupCleanupTestDB(t)
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
|
||||||
|
mockStore := &mockCleanupStorage{}
|
||||||
|
|
||||||
|
done := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
startCleanupWorker(ctx, uploadStore, mockStore, "test-bucket", zap.NewNop())
|
||||||
|
close(done)
|
||||||
|
}()
|
||||||
|
|
||||||
|
time.Sleep(100 * time.Millisecond)
|
||||||
|
cancel()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-done:
|
||||||
|
case <-time.After(2 * time.Second):
|
||||||
|
t.Error("worker did not stop after context cancel")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCleanupExpiredSessions_NoObjects(t *testing.T) {
|
||||||
|
_, uploadStore := setupCleanupTestDB(t)
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
past := time.Now().Add(-1 * time.Hour)
|
||||||
|
err := uploadStore.CreateSession(ctx, &model.UploadSession{
|
||||||
|
FileName: "empty.bin",
|
||||||
|
FileSize: 1024,
|
||||||
|
ChunkSize: 16 << 20,
|
||||||
|
TotalChunks: 1,
|
||||||
|
SHA256: "empty_hash",
|
||||||
|
Status: "pending",
|
||||||
|
MinioPrefix: "uploads/200/",
|
||||||
|
ExpiresAt: past,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("create session: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
listCalled := false
|
||||||
|
removeCalled := false
|
||||||
|
mockStore := &mockCleanupStorage{
|
||||||
|
listObjectsFn: func(_ context.Context, _, _ string, _ bool) ([]storage.ObjectInfo, error) {
|
||||||
|
listCalled = true
|
||||||
|
return nil, nil
|
||||||
|
},
|
||||||
|
removeObjectsFn: func(_ context.Context, _ string, _ []string, _ storage.RemoveObjectsOptions) error {
|
||||||
|
removeCalled = true
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
cleanupExpiredSessions(ctx, uploadStore, mockStore, "test-bucket", zap.NewNop())
|
||||||
|
|
||||||
|
if !listCalled {
|
||||||
|
t.Error("ListObjects should be called")
|
||||||
|
}
|
||||||
|
if removeCalled {
|
||||||
|
t.Error("RemoveObjects should not be called when no objects found")
|
||||||
|
}
|
||||||
|
|
||||||
|
session, _ := uploadStore.GetSession(ctx, 1)
|
||||||
|
if session != nil {
|
||||||
|
t.Error("session should be deleted even when no objects found")
|
||||||
|
}
|
||||||
|
}
|
||||||
61
internal/app/task_poller.go
Normal file
61
internal/app/task_poller.go
Normal file
@@ -0,0 +1,61 @@
|
|||||||
|
package app
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"go.uber.org/zap"
|
||||||
|
)
|
||||||
|
|
||||||
|
// TaskPollable defines the interface for refreshing stale task statuses.
|
||||||
|
type TaskPollable interface {
|
||||||
|
RefreshStaleTasks(ctx context.Context) error
|
||||||
|
}
|
||||||
|
|
||||||
|
// TaskPoller periodically polls Slurm for task status updates via TaskPollable.
|
||||||
|
type TaskPoller struct {
|
||||||
|
taskSvc TaskPollable
|
||||||
|
interval time.Duration
|
||||||
|
cancel context.CancelFunc
|
||||||
|
wg sync.WaitGroup
|
||||||
|
logger *zap.Logger
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewTaskPoller creates a new TaskPoller with the given service, interval, and logger.
|
||||||
|
func NewTaskPoller(taskSvc TaskPollable, interval time.Duration, logger *zap.Logger) *TaskPoller {
|
||||||
|
return &TaskPoller{
|
||||||
|
taskSvc: taskSvc,
|
||||||
|
interval: interval,
|
||||||
|
logger: logger,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Start launches the background goroutine that periodically refreshes stale tasks.
|
||||||
|
func (p *TaskPoller) Start(ctx context.Context) {
|
||||||
|
ctx, p.cancel = context.WithCancel(ctx)
|
||||||
|
p.wg.Add(1)
|
||||||
|
go func() {
|
||||||
|
defer p.wg.Done()
|
||||||
|
ticker := time.NewTicker(p.interval)
|
||||||
|
defer ticker.Stop()
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return
|
||||||
|
case <-ticker.C:
|
||||||
|
if err := p.taskSvc.RefreshStaleTasks(ctx); err != nil {
|
||||||
|
p.logger.Error("failed to refresh stale tasks", zap.Error(err))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Stop cancels the background goroutine and waits for it to finish.
|
||||||
|
func (p *TaskPoller) Stop() {
|
||||||
|
if p.cancel != nil {
|
||||||
|
p.cancel()
|
||||||
|
}
|
||||||
|
p.wg.Wait()
|
||||||
|
}
|
||||||
70
internal/app/task_poller_test.go
Normal file
70
internal/app/task_poller_test.go
Normal file
@@ -0,0 +1,70 @@
|
|||||||
|
package app
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"sync"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"go.uber.org/zap"
|
||||||
|
)
|
||||||
|
|
||||||
|
type mockTaskPollable struct {
|
||||||
|
refreshFunc func(ctx context.Context) error
|
||||||
|
callCount int
|
||||||
|
mu sync.Mutex
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockTaskPollable) RefreshStaleTasks(ctx context.Context) error {
|
||||||
|
m.mu.Lock()
|
||||||
|
defer m.mu.Unlock()
|
||||||
|
m.callCount++
|
||||||
|
if m.refreshFunc != nil {
|
||||||
|
return m.refreshFunc(ctx)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockTaskPollable) getCallCount() int {
|
||||||
|
m.mu.Lock()
|
||||||
|
defer m.mu.Unlock()
|
||||||
|
return m.callCount
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTaskPoller_StartStop(t *testing.T) {
|
||||||
|
mock := &mockTaskPollable{}
|
||||||
|
logger := zap.NewNop()
|
||||||
|
poller := NewTaskPoller(mock, 1*time.Second, logger)
|
||||||
|
|
||||||
|
poller.Start(context.Background())
|
||||||
|
time.Sleep(100 * time.Millisecond)
|
||||||
|
poller.Stop()
|
||||||
|
|
||||||
|
// No goroutine leak — Stop() returned means wg.Wait() completed.
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTaskPoller_RefreshesStaleTasks(t *testing.T) {
|
||||||
|
mock := &mockTaskPollable{}
|
||||||
|
logger := zap.NewNop()
|
||||||
|
poller := NewTaskPoller(mock, 50*time.Millisecond, logger)
|
||||||
|
|
||||||
|
poller.Start(context.Background())
|
||||||
|
defer poller.Stop()
|
||||||
|
|
||||||
|
time.Sleep(300 * time.Millisecond)
|
||||||
|
|
||||||
|
if count := mock.getCallCount(); count < 1 {
|
||||||
|
t.Errorf("expected RefreshStaleTasks to be called at least once, got %d", count)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTaskPoller_StopsCleanly(t *testing.T) {
|
||||||
|
mock := &mockTaskPollable{}
|
||||||
|
logger := zap.NewNop()
|
||||||
|
poller := NewTaskPoller(mock, 1*time.Second, logger)
|
||||||
|
|
||||||
|
poller.Start(context.Background())
|
||||||
|
poller.Stop()
|
||||||
|
|
||||||
|
// No panic and WaitGroup is done — Stop returned successfully.
|
||||||
|
}
|
||||||
@@ -3,6 +3,7 @@ slurm_api_url: "http://localhost:6820"
|
|||||||
slurm_user_name: "root"
|
slurm_user_name: "root"
|
||||||
slurm_jwt_key_path: "/etc/slurm/jwt_hs256.key"
|
slurm_jwt_key_path: "/etc/slurm/jwt_hs256.key"
|
||||||
mysql_dsn: "root:@tcp(127.0.0.1:3306)/hpc_platform?parseTime=true"
|
mysql_dsn: "root:@tcp(127.0.0.1:3306)/hpc_platform?parseTime=true"
|
||||||
|
work_dir_base: "/mnt/nfs_mount/platform" # 作业工作目录根路径,留空则不自动创建
|
||||||
|
|
||||||
log:
|
log:
|
||||||
level: "info" # debug, info, warn, error
|
level: "info" # debug, info, warn, error
|
||||||
@@ -14,3 +15,14 @@ log:
|
|||||||
max_age: 30 # days to retain old log files
|
max_age: 30 # days to retain old log files
|
||||||
compress: true # gzip rotated log files
|
compress: true # gzip rotated log files
|
||||||
gorm_level: "warn" # GORM SQL log level: silent, error, warn, info
|
gorm_level: "warn" # GORM SQL log level: silent, error, warn, info
|
||||||
|
|
||||||
|
minio:
|
||||||
|
endpoint: "http://fnos.dailz.cn:15001" # MinIO server address
|
||||||
|
access_key: "3dgDu9ncwflLoRQW2OeP" # access key
|
||||||
|
secret_key: "g2GLBNTPxJ9sdFwh37jtfilRSacEO5yQepMkDrnV" # secret key
|
||||||
|
bucket: "test" # bucket name
|
||||||
|
use_ssl: false # use TLS connection
|
||||||
|
chunk_size: 16777216 # upload chunk size in bytes (default: 16MB)
|
||||||
|
max_file_size: 53687091200 # max file size in bytes (default: 50GB)
|
||||||
|
min_chunk_size: 5242880 # minimum chunk size in bytes (default: 5MB)
|
||||||
|
session_ttl: 48 # session TTL in hours (default: 48)
|
||||||
|
|||||||
@@ -20,6 +20,19 @@ type LogConfig struct {
|
|||||||
GormLevel string `yaml:"gorm_level"` // GORM SQL log level (default: warn)
|
GormLevel string `yaml:"gorm_level"` // GORM SQL log level (default: warn)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// MinioConfig holds MinIO object storage configuration values.
|
||||||
|
type MinioConfig struct {
|
||||||
|
Endpoint string `yaml:"endpoint"` // MinIO server address
|
||||||
|
AccessKey string `yaml:"access_key"` // access key
|
||||||
|
SecretKey string `yaml:"secret_key"` // secret key
|
||||||
|
Bucket string `yaml:"bucket"` // bucket name
|
||||||
|
UseSSL bool `yaml:"use_ssl"` // use TLS connection
|
||||||
|
ChunkSize int64 `yaml:"chunk_size"` // upload chunk size in bytes (default: 16MB)
|
||||||
|
MaxFileSize int64 `yaml:"max_file_size"` // max file size in bytes (default: 50GB)
|
||||||
|
MinChunkSize int64 `yaml:"min_chunk_size"` // minimum chunk size in bytes (default: 5MB)
|
||||||
|
SessionTTL int `yaml:"session_ttl"` // session TTL in hours (default: 48)
|
||||||
|
}
|
||||||
|
|
||||||
// Config holds all application configuration values.
|
// Config holds all application configuration values.
|
||||||
type Config struct {
|
type Config struct {
|
||||||
ServerPort string `yaml:"server_port"`
|
ServerPort string `yaml:"server_port"`
|
||||||
@@ -27,7 +40,9 @@ type Config struct {
|
|||||||
SlurmUserName string `yaml:"slurm_user_name"`
|
SlurmUserName string `yaml:"slurm_user_name"`
|
||||||
SlurmJWTKeyPath string `yaml:"slurm_jwt_key_path"`
|
SlurmJWTKeyPath string `yaml:"slurm_jwt_key_path"`
|
||||||
MySQLDSN string `yaml:"mysql_dsn"`
|
MySQLDSN string `yaml:"mysql_dsn"`
|
||||||
|
WorkDirBase string `yaml:"work_dir_base"` // base directory for job work dirs
|
||||||
Log LogConfig `yaml:"log"`
|
Log LogConfig `yaml:"log"`
|
||||||
|
Minio MinioConfig `yaml:"minio"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// Load reads a YAML configuration file and returns a parsed Config.
|
// Load reads a YAML configuration file and returns a parsed Config.
|
||||||
@@ -47,5 +62,18 @@ func Load(path string) (*Config, error) {
|
|||||||
return nil, fmt.Errorf("parse config file %s: %w", path, err)
|
return nil, fmt.Errorf("parse config file %s: %w", path, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if cfg.Minio.ChunkSize == 0 {
|
||||||
|
cfg.Minio.ChunkSize = 16 << 20 // 16MB
|
||||||
|
}
|
||||||
|
if cfg.Minio.MaxFileSize == 0 {
|
||||||
|
cfg.Minio.MaxFileSize = 50 << 30 // 50GB
|
||||||
|
}
|
||||||
|
if cfg.Minio.MinChunkSize == 0 {
|
||||||
|
cfg.Minio.MinChunkSize = 5 << 20 // 5MB
|
||||||
|
}
|
||||||
|
if cfg.Minio.SessionTTL == 0 {
|
||||||
|
cfg.Minio.SessionTTL = 48
|
||||||
|
}
|
||||||
|
|
||||||
return &cfg, nil
|
return &cfg, nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -254,3 +254,135 @@ log:
|
|||||||
t.Errorf("Log.FilePath = %q, want %q", cfg.Log.FilePath, "/var/log/app.log")
|
t.Errorf("Log.FilePath = %q, want %q", cfg.Log.FilePath, "/var/log/app.log")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestLoadWithMinioConfig(t *testing.T) {
|
||||||
|
content := []byte(`server_port: "9090"
|
||||||
|
slurm_api_url: "http://slurm.example.com:6820"
|
||||||
|
slurm_user_name: "admin"
|
||||||
|
slurm_jwt_key_path: "/etc/slurm/jwt.key"
|
||||||
|
mysql_dsn: "user:pass@tcp(10.0.0.1:3306)/testdb?parseTime=true"
|
||||||
|
minio:
|
||||||
|
endpoint: "minio.example.com:9000"
|
||||||
|
access_key: "myaccesskey"
|
||||||
|
secret_key: "mysecretkey"
|
||||||
|
bucket: "test-bucket"
|
||||||
|
use_ssl: true
|
||||||
|
chunk_size: 33554432
|
||||||
|
max_file_size: 107374182400
|
||||||
|
min_chunk_size: 10485760
|
||||||
|
session_ttl: 24
|
||||||
|
`)
|
||||||
|
|
||||||
|
dir := t.TempDir()
|
||||||
|
path := filepath.Join(dir, "config.yaml")
|
||||||
|
if err := os.WriteFile(path, content, 0644); err != nil {
|
||||||
|
t.Fatalf("write temp config: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
cfg, err := Load(path)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Load() error = %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if cfg.Minio.Endpoint != "minio.example.com:9000" {
|
||||||
|
t.Errorf("Minio.Endpoint = %q, want %q", cfg.Minio.Endpoint, "minio.example.com:9000")
|
||||||
|
}
|
||||||
|
if cfg.Minio.AccessKey != "myaccesskey" {
|
||||||
|
t.Errorf("Minio.AccessKey = %q, want %q", cfg.Minio.AccessKey, "myaccesskey")
|
||||||
|
}
|
||||||
|
if cfg.Minio.SecretKey != "mysecretkey" {
|
||||||
|
t.Errorf("Minio.SecretKey = %q, want %q", cfg.Minio.SecretKey, "mysecretkey")
|
||||||
|
}
|
||||||
|
if cfg.Minio.Bucket != "test-bucket" {
|
||||||
|
t.Errorf("Minio.Bucket = %q, want %q", cfg.Minio.Bucket, "test-bucket")
|
||||||
|
}
|
||||||
|
if cfg.Minio.UseSSL != true {
|
||||||
|
t.Errorf("Minio.UseSSL = %v, want %v", cfg.Minio.UseSSL, true)
|
||||||
|
}
|
||||||
|
if cfg.Minio.ChunkSize != 33554432 {
|
||||||
|
t.Errorf("Minio.ChunkSize = %d, want %d", cfg.Minio.ChunkSize, 33554432)
|
||||||
|
}
|
||||||
|
if cfg.Minio.MaxFileSize != 107374182400 {
|
||||||
|
t.Errorf("Minio.MaxFileSize = %d, want %d", cfg.Minio.MaxFileSize, 107374182400)
|
||||||
|
}
|
||||||
|
if cfg.Minio.MinChunkSize != 10485760 {
|
||||||
|
t.Errorf("Minio.MinChunkSize = %d, want %d", cfg.Minio.MinChunkSize, 10485760)
|
||||||
|
}
|
||||||
|
if cfg.Minio.SessionTTL != 24 {
|
||||||
|
t.Errorf("Minio.SessionTTL = %d, want %d", cfg.Minio.SessionTTL, 24)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLoadWithoutMinioConfig(t *testing.T) {
|
||||||
|
content := []byte(`server_port: "8080"
|
||||||
|
slurm_api_url: "http://localhost:6820"
|
||||||
|
slurm_user_name: "root"
|
||||||
|
slurm_jwt_key_path: "/etc/slurm/jwt_hs256.key"
|
||||||
|
mysql_dsn: "root:@tcp(127.0.0.1:3306)/hpc_platform?parseTime=true"
|
||||||
|
`)
|
||||||
|
|
||||||
|
dir := t.TempDir()
|
||||||
|
path := filepath.Join(dir, "config.yaml")
|
||||||
|
if err := os.WriteFile(path, content, 0644); err != nil {
|
||||||
|
t.Fatalf("write temp config: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
cfg, err := Load(path)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Load() error = %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if cfg.Minio.Endpoint != "" {
|
||||||
|
t.Errorf("Minio.Endpoint = %q, want empty string", cfg.Minio.Endpoint)
|
||||||
|
}
|
||||||
|
if cfg.Minio.AccessKey != "" {
|
||||||
|
t.Errorf("Minio.AccessKey = %q, want empty string", cfg.Minio.AccessKey)
|
||||||
|
}
|
||||||
|
if cfg.Minio.SecretKey != "" {
|
||||||
|
t.Errorf("Minio.SecretKey = %q, want empty string", cfg.Minio.SecretKey)
|
||||||
|
}
|
||||||
|
if cfg.Minio.Bucket != "" {
|
||||||
|
t.Errorf("Minio.Bucket = %q, want empty string", cfg.Minio.Bucket)
|
||||||
|
}
|
||||||
|
if cfg.Minio.UseSSL != false {
|
||||||
|
t.Errorf("Minio.UseSSL = %v, want false", cfg.Minio.UseSSL)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLoadMinioDefaults(t *testing.T) {
|
||||||
|
content := []byte(`server_port: "8080"
|
||||||
|
slurm_api_url: "http://localhost:6820"
|
||||||
|
slurm_user_name: "root"
|
||||||
|
slurm_jwt_key_path: "/etc/slurm/jwt_hs256.key"
|
||||||
|
mysql_dsn: "root:@tcp(127.0.0.1:3306)/hpc_platform?parseTime=true"
|
||||||
|
minio:
|
||||||
|
endpoint: "localhost:9000"
|
||||||
|
access_key: "minioadmin"
|
||||||
|
secret_key: "minioadmin"
|
||||||
|
bucket: "uploads"
|
||||||
|
`)
|
||||||
|
|
||||||
|
dir := t.TempDir()
|
||||||
|
path := filepath.Join(dir, "config.yaml")
|
||||||
|
if err := os.WriteFile(path, content, 0644); err != nil {
|
||||||
|
t.Fatalf("write temp config: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
cfg, err := Load(path)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Load() error = %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if cfg.Minio.ChunkSize != 16<<20 {
|
||||||
|
t.Errorf("Minio.ChunkSize = %d, want %d", cfg.Minio.ChunkSize, 16<<20)
|
||||||
|
}
|
||||||
|
if cfg.Minio.MaxFileSize != 50<<30 {
|
||||||
|
t.Errorf("Minio.MaxFileSize = %d, want %d", cfg.Minio.MaxFileSize, 50<<30)
|
||||||
|
}
|
||||||
|
if cfg.Minio.MinChunkSize != 5<<20 {
|
||||||
|
t.Errorf("Minio.MinChunkSize = %d, want %d", cfg.Minio.MinChunkSize, 5<<20)
|
||||||
|
}
|
||||||
|
if cfg.Minio.SessionTTL != 48 {
|
||||||
|
t.Errorf("Minio.SessionTTL = %d, want %d", cfg.Minio.SessionTTL, 48)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
174
internal/handler/application.go
Normal file
174
internal/handler/application.go
Normal file
@@ -0,0 +1,174 @@
|
|||||||
|
package handler
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
|
"strconv"
|
||||||
|
|
||||||
|
"gcy_hpc_server/internal/model"
|
||||||
|
"gcy_hpc_server/internal/server"
|
||||||
|
"gcy_hpc_server/internal/service"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"go.uber.org/zap"
|
||||||
|
"gorm.io/gorm"
|
||||||
|
)
|
||||||
|
|
||||||
|
type ApplicationHandler struct {
|
||||||
|
appSvc *service.ApplicationService
|
||||||
|
logger *zap.Logger
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewApplicationHandler(appSvc *service.ApplicationService, logger *zap.Logger) *ApplicationHandler {
|
||||||
|
return &ApplicationHandler{appSvc: appSvc, logger: logger}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *ApplicationHandler) ListApplications(c *gin.Context) {
|
||||||
|
page, _ := strconv.Atoi(c.DefaultQuery("page", "1"))
|
||||||
|
pageSize, _ := strconv.Atoi(c.DefaultQuery("page_size", "20"))
|
||||||
|
if page < 1 {
|
||||||
|
page = 1
|
||||||
|
}
|
||||||
|
if pageSize < 1 {
|
||||||
|
pageSize = 20
|
||||||
|
}
|
||||||
|
|
||||||
|
apps, total, err := h.appSvc.ListApplications(c.Request.Context(), page, pageSize)
|
||||||
|
if err != nil {
|
||||||
|
h.logger.Error("failed to list applications", zap.Error(err))
|
||||||
|
server.InternalError(c, err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
server.OK(c, gin.H{
|
||||||
|
"applications": apps,
|
||||||
|
"total": total,
|
||||||
|
"page": page,
|
||||||
|
"page_size": pageSize,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *ApplicationHandler) CreateApplication(c *gin.Context) {
|
||||||
|
var req model.CreateApplicationRequest
|
||||||
|
if err := c.ShouldBindJSON(&req); err != nil {
|
||||||
|
h.logger.Warn("invalid request body for create application", zap.Error(err))
|
||||||
|
server.BadRequest(c, "invalid request body")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if req.Name == "" || req.ScriptTemplate == "" {
|
||||||
|
h.logger.Warn("missing required fields for create application")
|
||||||
|
server.BadRequest(c, "name and script_template are required")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if len(req.Parameters) == 0 {
|
||||||
|
req.Parameters = json.RawMessage(`[]`)
|
||||||
|
}
|
||||||
|
|
||||||
|
id, err := h.appSvc.CreateApplication(c.Request.Context(), &req)
|
||||||
|
if err != nil {
|
||||||
|
h.logger.Error("failed to create application", zap.Error(err))
|
||||||
|
server.InternalError(c, err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
h.logger.Info("application created", zap.Int64("id", id))
|
||||||
|
server.Created(c, gin.H{"id": id})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *ApplicationHandler) GetApplication(c *gin.Context) {
|
||||||
|
id, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||||
|
if err != nil {
|
||||||
|
h.logger.Warn("invalid application id", zap.String("id", c.Param("id")))
|
||||||
|
server.BadRequest(c, "invalid id")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
app, err := h.appSvc.GetApplication(c.Request.Context(), id)
|
||||||
|
if err != nil {
|
||||||
|
h.logger.Error("failed to get application", zap.Int64("id", id), zap.Error(err))
|
||||||
|
server.InternalError(c, err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if app == nil {
|
||||||
|
h.logger.Warn("application not found", zap.Int64("id", id))
|
||||||
|
server.NotFound(c, "application not found")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
server.OK(c, app)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *ApplicationHandler) UpdateApplication(c *gin.Context) {
|
||||||
|
id, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||||
|
if err != nil {
|
||||||
|
h.logger.Warn("invalid application id for update", zap.String("id", c.Param("id")))
|
||||||
|
server.BadRequest(c, "invalid id")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
var req model.UpdateApplicationRequest
|
||||||
|
if err := c.ShouldBindJSON(&req); err != nil {
|
||||||
|
h.logger.Warn("invalid request body for update application", zap.Int64("id", id), zap.Error(err))
|
||||||
|
server.BadRequest(c, "invalid request body")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if err := h.appSvc.UpdateApplication(c.Request.Context(), id, &req); err != nil {
|
||||||
|
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||||
|
h.logger.Warn("application not found for update", zap.Int64("id", id))
|
||||||
|
server.NotFound(c, "application not found")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
h.logger.Error("failed to update application", zap.Int64("id", id), zap.Error(err))
|
||||||
|
server.InternalError(c, "failed to update application")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
h.logger.Info("application updated", zap.Int64("id", id))
|
||||||
|
server.OK(c, gin.H{"message": "application updated"})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *ApplicationHandler) DeleteApplication(c *gin.Context) {
|
||||||
|
id, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||||
|
if err != nil {
|
||||||
|
h.logger.Warn("invalid application id for delete", zap.String("id", c.Param("id")))
|
||||||
|
server.BadRequest(c, "invalid id")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if err := h.appSvc.DeleteApplication(c.Request.Context(), id); err != nil {
|
||||||
|
h.logger.Error("failed to delete application", zap.Int64("id", id), zap.Error(err))
|
||||||
|
server.InternalError(c, err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
h.logger.Info("application deleted", zap.Int64("id", id))
|
||||||
|
server.OK(c, gin.H{"message": "application deleted"})
|
||||||
|
}
|
||||||
|
|
||||||
|
// [已禁用] 前端已全部迁移到 POST /tasks 接口,此端点不再使用。
|
||||||
|
/* func (h *ApplicationHandler) SubmitApplication(c *gin.Context) {
|
||||||
|
id, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||||
|
if err != nil {
|
||||||
|
h.logger.Warn("invalid application id for submit", zap.String("id", c.Param("id")))
|
||||||
|
server.BadRequest(c, "invalid id")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
var req model.ApplicationSubmitRequest
|
||||||
|
if err := c.ShouldBindJSON(&req); err != nil {
|
||||||
|
h.logger.Warn("invalid request body for submit application", zap.Int64("id", id), zap.Error(err))
|
||||||
|
server.BadRequest(c, "invalid request body")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
resp, err := h.appSvc.SubmitFromApplication(c.Request.Context(), id, req.Values)
|
||||||
|
if err != nil {
|
||||||
|
errStr := err.Error()
|
||||||
|
if strings.Contains(errStr, "not found") {
|
||||||
|
h.logger.Warn("application not found for submit", zap.Int64("id", id))
|
||||||
|
server.NotFound(c, errStr)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if strings.Contains(errStr, "validation") {
|
||||||
|
h.logger.Warn("application submit validation failed", zap.Int64("id", id), zap.Error(err))
|
||||||
|
server.BadRequest(c, errStr)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
h.logger.Error("failed to submit application", zap.Int64("id", id), zap.Error(err))
|
||||||
|
server.InternalError(c, errStr)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
h.logger.Info("application submitted", zap.Int64("id", id), zap.Int32("job_id", resp.JobID))
|
||||||
|
server.Created(c, resp)
|
||||||
|
} */
|
||||||
642
internal/handler/application_test.go
Normal file
642
internal/handler/application_test.go
Normal file
@@ -0,0 +1,642 @@
|
|||||||
|
package handler
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"gcy_hpc_server/internal/model"
|
||||||
|
"gcy_hpc_server/internal/service"
|
||||||
|
"gcy_hpc_server/internal/slurm"
|
||||||
|
"gcy_hpc_server/internal/store"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"go.uber.org/zap"
|
||||||
|
"go.uber.org/zap/zapcore"
|
||||||
|
"go.uber.org/zap/zaptest/observer"
|
||||||
|
"gorm.io/driver/sqlite"
|
||||||
|
"gorm.io/gorm"
|
||||||
|
gormlogger "gorm.io/gorm/logger"
|
||||||
|
)
|
||||||
|
|
||||||
|
func itoa(id int64) string {
|
||||||
|
return fmt.Sprintf("%d", id)
|
||||||
|
}
|
||||||
|
|
||||||
|
func setupApplicationHandler() (*ApplicationHandler, *gorm.DB) {
|
||||||
|
db, _ := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{Logger: gormlogger.Default.LogMode(gormlogger.Silent)})
|
||||||
|
db.AutoMigrate(&model.Application{})
|
||||||
|
appStore := store.NewApplicationStore(db)
|
||||||
|
appSvc := service.NewApplicationService(appStore, nil, "", zap.NewNop())
|
||||||
|
h := NewApplicationHandler(appSvc, zap.NewNop())
|
||||||
|
return h, db
|
||||||
|
}
|
||||||
|
|
||||||
|
func setupApplicationRouter(h *ApplicationHandler) *gin.Engine {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
r := gin.New()
|
||||||
|
v1 := r.Group("/api/v1")
|
||||||
|
apps := v1.Group("/applications")
|
||||||
|
apps.GET("", h.ListApplications)
|
||||||
|
apps.POST("", h.CreateApplication)
|
||||||
|
apps.GET("/:id", h.GetApplication)
|
||||||
|
apps.PUT("/:id", h.UpdateApplication)
|
||||||
|
apps.DELETE("/:id", h.DeleteApplication)
|
||||||
|
// apps.POST("/:id/submit", h.SubmitApplication) // [已禁用] 已被 POST /tasks 取代
|
||||||
|
return r
|
||||||
|
}
|
||||||
|
|
||||||
|
func setupApplicationHandlerWithSlurm(slurmHandler http.HandlerFunc) (*ApplicationHandler, func()) {
|
||||||
|
srv := httptest.NewServer(slurmHandler)
|
||||||
|
client, _ := slurm.NewClient(srv.URL, srv.Client())
|
||||||
|
|
||||||
|
db, _ := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{Logger: gormlogger.Default.LogMode(gormlogger.Silent)})
|
||||||
|
db.AutoMigrate(&model.Application{})
|
||||||
|
|
||||||
|
jobSvc := service.NewJobService(client, zap.NewNop())
|
||||||
|
appStore := store.NewApplicationStore(db)
|
||||||
|
appSvc := service.NewApplicationService(appStore, jobSvc, "", zap.NewNop())
|
||||||
|
h := NewApplicationHandler(appSvc, zap.NewNop())
|
||||||
|
return h, srv.Close
|
||||||
|
}
|
||||||
|
|
||||||
|
func setupApplicationHandlerWithObserver() (*ApplicationHandler, *gorm.DB, *observer.ObservedLogs) {
|
||||||
|
core, recorded := observer.New(zapcore.DebugLevel)
|
||||||
|
l := zap.New(core)
|
||||||
|
db, _ := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{Logger: gormlogger.Default.LogMode(gormlogger.Silent)})
|
||||||
|
db.AutoMigrate(&model.Application{})
|
||||||
|
appStore := store.NewApplicationStore(db)
|
||||||
|
appSvc := service.NewApplicationService(appStore, nil, "", l)
|
||||||
|
return NewApplicationHandler(appSvc, l), db, recorded
|
||||||
|
}
|
||||||
|
|
||||||
|
func createTestApplication(h *ApplicationHandler, r *gin.Engine) int64 {
|
||||||
|
body, _ := json.Marshal(model.CreateApplicationRequest{
|
||||||
|
Name: "test-app",
|
||||||
|
ScriptTemplate: "#!/bin/bash\necho hello",
|
||||||
|
})
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
req, _ := http.NewRequest(http.MethodPost, "/api/v1/applications", bytes.NewReader(body))
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
r.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
var resp map[string]interface{}
|
||||||
|
json.Unmarshal(w.Body.Bytes(), &resp)
|
||||||
|
data := resp["data"].(map[string]interface{})
|
||||||
|
return int64(data["id"].(float64))
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---- CRUD Tests ----
|
||||||
|
|
||||||
|
func TestCreateApplication_Success(t *testing.T) {
|
||||||
|
h, _ := setupApplicationHandler()
|
||||||
|
r := setupApplicationRouter(h)
|
||||||
|
|
||||||
|
body, _ := json.Marshal(model.CreateApplicationRequest{
|
||||||
|
Name: "my-app",
|
||||||
|
ScriptTemplate: "#!/bin/bash\necho hello",
|
||||||
|
})
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
req, _ := http.NewRequest(http.MethodPost, "/api/v1/applications", bytes.NewReader(body))
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
r.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
if w.Code != http.StatusCreated {
|
||||||
|
t.Fatalf("expected 201, got %d: %s", w.Code, w.Body.String())
|
||||||
|
}
|
||||||
|
var resp map[string]interface{}
|
||||||
|
json.Unmarshal(w.Body.Bytes(), &resp)
|
||||||
|
if !resp["success"].(bool) {
|
||||||
|
t.Fatal("expected success=true")
|
||||||
|
}
|
||||||
|
data := resp["data"].(map[string]interface{})
|
||||||
|
if _, ok := data["id"]; !ok {
|
||||||
|
t.Fatal("expected id in response data")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCreateApplication_MissingName(t *testing.T) {
|
||||||
|
h, _ := setupApplicationHandler()
|
||||||
|
r := setupApplicationRouter(h)
|
||||||
|
|
||||||
|
body, _ := json.Marshal(model.CreateApplicationRequest{
|
||||||
|
ScriptTemplate: "#!/bin/bash\necho hello",
|
||||||
|
})
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
req, _ := http.NewRequest(http.MethodPost, "/api/v1/applications", bytes.NewReader(body))
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
r.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
if w.Code != http.StatusBadRequest {
|
||||||
|
t.Fatalf("expected 400, got %d", w.Code)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCreateApplication_MissingScriptTemplate(t *testing.T) {
|
||||||
|
h, _ := setupApplicationHandler()
|
||||||
|
r := setupApplicationRouter(h)
|
||||||
|
|
||||||
|
body, _ := json.Marshal(model.CreateApplicationRequest{
|
||||||
|
Name: "my-app",
|
||||||
|
})
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
req, _ := http.NewRequest(http.MethodPost, "/api/v1/applications", bytes.NewReader(body))
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
r.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
if w.Code != http.StatusBadRequest {
|
||||||
|
t.Fatalf("expected 400, got %d", w.Code)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCreateApplication_EmptyParameters(t *testing.T) {
|
||||||
|
h, db := setupApplicationHandler()
|
||||||
|
r := setupApplicationRouter(h)
|
||||||
|
|
||||||
|
body := `{"name":"empty-params-app","script_template":"#!/bin/bash\necho hello"}`
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
req, _ := http.NewRequest(http.MethodPost, "/api/v1/applications", strings.NewReader(body))
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
r.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
if w.Code != http.StatusCreated {
|
||||||
|
t.Fatalf("expected 201, got %d: %s", w.Code, w.Body.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
var app model.Application
|
||||||
|
db.First(&app)
|
||||||
|
if string(app.Parameters) != "[]" {
|
||||||
|
t.Fatalf("expected parameters to default to [], got %s", string(app.Parameters))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestListApplications_Success(t *testing.T) {
|
||||||
|
h, _ := setupApplicationHandler()
|
||||||
|
r := setupApplicationRouter(h)
|
||||||
|
|
||||||
|
createTestApplication(h, r)
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
req, _ := http.NewRequest(http.MethodGet, "/api/v1/applications", nil)
|
||||||
|
r.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Fatalf("expected 200, got %d", w.Code)
|
||||||
|
}
|
||||||
|
var resp map[string]interface{}
|
||||||
|
json.Unmarshal(w.Body.Bytes(), &resp)
|
||||||
|
if !resp["success"].(bool) {
|
||||||
|
t.Fatal("expected success=true")
|
||||||
|
}
|
||||||
|
data := resp["data"].(map[string]interface{})
|
||||||
|
if data["total"].(float64) < 1 {
|
||||||
|
t.Fatal("expected at least 1 application")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestListApplications_Pagination(t *testing.T) {
|
||||||
|
h, _ := setupApplicationHandler()
|
||||||
|
r := setupApplicationRouter(h)
|
||||||
|
|
||||||
|
for i := 0; i < 5; i++ {
|
||||||
|
body, _ := json.Marshal(model.CreateApplicationRequest{
|
||||||
|
Name: fmt.Sprintf("app-%d", i),
|
||||||
|
ScriptTemplate: "#!/bin/bash\necho " + fmt.Sprintf("app-%d", i),
|
||||||
|
})
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
req, _ := http.NewRequest(http.MethodPost, "/api/v1/applications", bytes.NewReader(body))
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
r.ServeHTTP(w, req)
|
||||||
|
}
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
req, _ := http.NewRequest(http.MethodGet, "/api/v1/applications?page=1&page_size=2", nil)
|
||||||
|
r.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Fatalf("expected 200, got %d", w.Code)
|
||||||
|
}
|
||||||
|
var resp map[string]interface{}
|
||||||
|
json.Unmarshal(w.Body.Bytes(), &resp)
|
||||||
|
data := resp["data"].(map[string]interface{})
|
||||||
|
if data["total"].(float64) != 5 {
|
||||||
|
t.Fatalf("expected total=5, got %v", data["total"])
|
||||||
|
}
|
||||||
|
if data["page"].(float64) != 1 {
|
||||||
|
t.Fatalf("expected page=1, got %v", data["page"])
|
||||||
|
}
|
||||||
|
if data["page_size"].(float64) != 2 {
|
||||||
|
t.Fatalf("expected page_size=2, got %v", data["page_size"])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetApplication_Success(t *testing.T) {
|
||||||
|
h, _ := setupApplicationHandler()
|
||||||
|
r := setupApplicationRouter(h)
|
||||||
|
|
||||||
|
id := createTestApplication(h, r)
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
req, _ := http.NewRequest(http.MethodGet, "/api/v1/applications/"+itoa(id), nil)
|
||||||
|
r.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Fatalf("expected 200, got %d: %s", w.Code, w.Body.String())
|
||||||
|
}
|
||||||
|
var resp map[string]interface{}
|
||||||
|
json.Unmarshal(w.Body.Bytes(), &resp)
|
||||||
|
data := resp["data"].(map[string]interface{})
|
||||||
|
if data["name"] != "test-app" {
|
||||||
|
t.Fatalf("expected name=test-app, got %v", data["name"])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetApplication_NotFound(t *testing.T) {
|
||||||
|
h, _ := setupApplicationHandler()
|
||||||
|
r := setupApplicationRouter(h)
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
req, _ := http.NewRequest(http.MethodGet, "/api/v1/applications/999", nil)
|
||||||
|
r.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
if w.Code != http.StatusNotFound {
|
||||||
|
t.Fatalf("expected 404, got %d", w.Code)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetApplication_InvalidID(t *testing.T) {
|
||||||
|
h, _ := setupApplicationHandler()
|
||||||
|
r := setupApplicationRouter(h)
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
req, _ := http.NewRequest(http.MethodGet, "/api/v1/applications/abc", nil)
|
||||||
|
r.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
if w.Code != http.StatusBadRequest {
|
||||||
|
t.Fatalf("expected 400, got %d", w.Code)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUpdateApplication_Success(t *testing.T) {
|
||||||
|
h, _ := setupApplicationHandler()
|
||||||
|
r := setupApplicationRouter(h)
|
||||||
|
|
||||||
|
id := createTestApplication(h, r)
|
||||||
|
|
||||||
|
newName := "updated-app"
|
||||||
|
body, _ := json.Marshal(model.UpdateApplicationRequest{
|
||||||
|
Name: &newName,
|
||||||
|
})
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
req, _ := http.NewRequest(http.MethodPut, "/api/v1/applications/"+itoa(id), bytes.NewReader(body))
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
r.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Fatalf("expected 200, got %d: %s", w.Code, w.Body.String())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUpdateApplication_NotFound(t *testing.T) {
|
||||||
|
h, _ := setupApplicationHandler()
|
||||||
|
r := setupApplicationRouter(h)
|
||||||
|
|
||||||
|
newName := "updated-app"
|
||||||
|
body, _ := json.Marshal(model.UpdateApplicationRequest{
|
||||||
|
Name: &newName,
|
||||||
|
})
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
req, _ := http.NewRequest(http.MethodPut, "/api/v1/applications/999", bytes.NewReader(body))
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
r.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
if w.Code != http.StatusNotFound {
|
||||||
|
t.Fatalf("expected 404, got %d", w.Code)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDeleteApplication_Success(t *testing.T) {
|
||||||
|
h, _ := setupApplicationHandler()
|
||||||
|
r := setupApplicationRouter(h)
|
||||||
|
|
||||||
|
id := createTestApplication(h, r)
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
req, _ := http.NewRequest(http.MethodDelete, "/api/v1/applications/"+itoa(id), nil)
|
||||||
|
r.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Fatalf("expected 200, got %d: %s", w.Code, w.Body.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
w2 := httptest.NewRecorder()
|
||||||
|
req2, _ := http.NewRequest(http.MethodGet, "/api/v1/applications/"+itoa(id), nil)
|
||||||
|
r.ServeHTTP(w2, req2)
|
||||||
|
|
||||||
|
if w2.Code != http.StatusNotFound {
|
||||||
|
t.Fatalf("expected 404 after delete, got %d", w2.Code)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---- Submit Tests ----
|
||||||
|
|
||||||
|
// [已禁用] 测试的是旧的直接提交路径,该路径已被注释掉
|
||||||
|
/*
|
||||||
|
func TestSubmitApplication_Success(t *testing.T) {
|
||||||
|
slurmHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
w.WriteHeader(http.StatusCreated)
|
||||||
|
json.NewEncoder(w).Encode(map[string]interface{}{
|
||||||
|
"job_id": 12345,
|
||||||
|
})
|
||||||
|
})
|
||||||
|
h, cleanup := setupApplicationHandlerWithSlurm(slurmHandler)
|
||||||
|
defer cleanup()
|
||||||
|
r := setupApplicationRouter(h)
|
||||||
|
|
||||||
|
params := `[{"name":"COUNT","type":"integer","required":true}]`
|
||||||
|
body := `{"name":"submit-app","script_template":"#!/bin/bash\necho $COUNT","parameters":` + params + `}`
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
req, _ := http.NewRequest(http.MethodPost, "/api/v1/applications", strings.NewReader(body))
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
r.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
var createResp map[string]interface{}
|
||||||
|
json.Unmarshal(w.Body.Bytes(), &createResp)
|
||||||
|
data := createResp["data"].(map[string]interface{})
|
||||||
|
id := int64(data["id"].(float64))
|
||||||
|
|
||||||
|
submitBody := `{"values":{"COUNT":"5"}}`
|
||||||
|
w2 := httptest.NewRecorder()
|
||||||
|
req2, _ := http.NewRequest(http.MethodPost, "/api/v1/applications/"+itoa(id)+"/submit", strings.NewReader(submitBody))
|
||||||
|
req2.Header.Set("Content-Type", "application/json")
|
||||||
|
r.ServeHTTP(w2, req2)
|
||||||
|
|
||||||
|
if w2.Code != http.StatusCreated {
|
||||||
|
t.Fatalf("expected 201, got %d: %s", w2.Code, w2.Body.String())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
*/
|
||||||
|
|
||||||
|
// [已禁用] 测试的是旧的直接提交路径,该路径已被注释掉
|
||||||
|
/*
|
||||||
|
func TestSubmitApplication_AppNotFound(t *testing.T) {
|
||||||
|
slurmHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
json.NewEncoder(w).Encode(map[string]interface{}{})
|
||||||
|
})
|
||||||
|
h, cleanup := setupApplicationHandlerWithSlurm(slurmHandler)
|
||||||
|
defer cleanup()
|
||||||
|
r := setupApplicationRouter(h)
|
||||||
|
|
||||||
|
submitBody := `{"values":{"COUNT":"5"}}`
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
req, _ := http.NewRequest(http.MethodPost, "/api/v1/applications/999/submit", strings.NewReader(submitBody))
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
r.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
if w.Code != http.StatusNotFound {
|
||||||
|
t.Fatalf("expected 404, got %d: %s", w.Code, w.Body.String())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
*/
|
||||||
|
|
||||||
|
// [已禁用] 测试的是旧的直接提交路径,该路径已被注释掉
|
||||||
|
/*
|
||||||
|
func TestSubmitApplication_ValidationFail(t *testing.T) {
|
||||||
|
slurmHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
json.NewEncoder(w).Encode(map[string]interface{}{})
|
||||||
|
})
|
||||||
|
h, cleanup := setupApplicationHandlerWithSlurm(slurmHandler)
|
||||||
|
defer cleanup()
|
||||||
|
r := setupApplicationRouter(h)
|
||||||
|
|
||||||
|
params := `[{"name":"COUNT","type":"integer","required":true}]`
|
||||||
|
body := `{"name":"val-app","script_template":"#!/bin/bash\necho $COUNT","parameters":` + params + `}`
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
req, _ := http.NewRequest(http.MethodPost, "/api/v1/applications", strings.NewReader(body))
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
r.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
var createResp map[string]interface{}
|
||||||
|
json.Unmarshal(w.Body.Bytes(), &createResp)
|
||||||
|
data := createResp["data"].(map[string]interface{})
|
||||||
|
id := int64(data["id"].(float64))
|
||||||
|
|
||||||
|
submitBody := `{"values":{"COUNT":"not-a-number"}}`
|
||||||
|
w2 := httptest.NewRecorder()
|
||||||
|
req2, _ := http.NewRequest(http.MethodPost, "/api/v1/applications/"+itoa(id)+"/submit", strings.NewReader(submitBody))
|
||||||
|
req2.Header.Set("Content-Type", "application/json")
|
||||||
|
r.ServeHTTP(w2, req2)
|
||||||
|
|
||||||
|
if w2.Code != http.StatusBadRequest {
|
||||||
|
t.Fatalf("expected 400, got %d: %s", w2.Code, w2.Body.String())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
*/
|
||||||
|
|
||||||
|
// ---- Logging Tests ----
|
||||||
|
|
||||||
|
func TestApplicationLogging_CreateSuccess_LogsInfoWithID(t *testing.T) {
|
||||||
|
h, _, recorded := setupApplicationHandlerWithObserver()
|
||||||
|
r := setupApplicationRouter(h)
|
||||||
|
|
||||||
|
body, _ := json.Marshal(model.CreateApplicationRequest{
|
||||||
|
Name: "log-app",
|
||||||
|
ScriptTemplate: "#!/bin/bash\necho hello",
|
||||||
|
})
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
req, _ := http.NewRequest(http.MethodPost, "/api/v1/applications", bytes.NewReader(body))
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
r.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
found := false
|
||||||
|
for _, entry := range recorded.All() {
|
||||||
|
if entry.Message == "application created" {
|
||||||
|
found = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !found {
|
||||||
|
t.Fatal("expected 'application created' log message")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestApplicationLogging_GetNotFound_LogsWarnWithID(t *testing.T) {
|
||||||
|
h, _, recorded := setupApplicationHandlerWithObserver()
|
||||||
|
r := setupApplicationRouter(h)
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
req, _ := http.NewRequest(http.MethodGet, "/api/v1/applications/999", nil)
|
||||||
|
r.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
found := false
|
||||||
|
for _, entry := range recorded.All() {
|
||||||
|
if entry.Message == "application not found" {
|
||||||
|
found = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !found {
|
||||||
|
t.Fatal("expected 'application not found' log message")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestApplicationLogging_UpdateSuccess_LogsInfoWithID(t *testing.T) {
|
||||||
|
h, _, recorded := setupApplicationHandlerWithObserver()
|
||||||
|
r := setupApplicationRouter(h)
|
||||||
|
|
||||||
|
id := createTestApplication(h, r)
|
||||||
|
recorded.TakeAll()
|
||||||
|
|
||||||
|
newName := "updated-log-app"
|
||||||
|
body, _ := json.Marshal(model.UpdateApplicationRequest{
|
||||||
|
Name: &newName,
|
||||||
|
})
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
req, _ := http.NewRequest(http.MethodPut, "/api/v1/applications/"+itoa(id), bytes.NewReader(body))
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
r.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
found := false
|
||||||
|
for _, entry := range recorded.All() {
|
||||||
|
if entry.Message == "application updated" {
|
||||||
|
found = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !found {
|
||||||
|
t.Fatal("expected 'application updated' log message")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestApplicationLogging_DeleteSuccess_LogsInfoWithID(t *testing.T) {
|
||||||
|
h, _, recorded := setupApplicationHandlerWithObserver()
|
||||||
|
r := setupApplicationRouter(h)
|
||||||
|
|
||||||
|
id := createTestApplication(h, r)
|
||||||
|
recorded.TakeAll()
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
req, _ := http.NewRequest(http.MethodDelete, "/api/v1/applications/"+itoa(id), nil)
|
||||||
|
r.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
found := false
|
||||||
|
for _, entry := range recorded.All() {
|
||||||
|
if entry.Message == "application deleted" {
|
||||||
|
found = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !found {
|
||||||
|
t.Fatal("expected 'application deleted' log message")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// [已禁用] 测试的是旧的直接提交路径,该路径已被注释掉
|
||||||
|
/*
|
||||||
|
func TestApplicationLogging_SubmitSuccess_LogsInfoWithID(t *testing.T) {
|
||||||
|
core, recorded := observer.New(zapcore.DebugLevel)
|
||||||
|
l := zap.New(core)
|
||||||
|
|
||||||
|
slurmSrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
w.WriteHeader(http.StatusCreated)
|
||||||
|
json.NewEncoder(w).Encode(map[string]interface{}{
|
||||||
|
"job_id": 42,
|
||||||
|
})
|
||||||
|
}))
|
||||||
|
defer slurmSrv.Close()
|
||||||
|
|
||||||
|
client, _ := slurm.NewClient(slurmSrv.URL, slurmSrv.Client())
|
||||||
|
db, _ := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{Logger: gormlogger.Default.LogMode(gormlogger.Silent)})
|
||||||
|
db.AutoMigrate(&model.Application{})
|
||||||
|
jobSvc := service.NewJobService(client, l)
|
||||||
|
appStore := store.NewApplicationStore(db)
|
||||||
|
appSvc := service.NewApplicationService(appStore, jobSvc, "", l)
|
||||||
|
h := NewApplicationHandler(appSvc, l)
|
||||||
|
r := setupApplicationRouter(h)
|
||||||
|
|
||||||
|
params := `[{"name":"X","type":"string","required":false}]`
|
||||||
|
body := `{"name":"sub-log-app","script_template":"#!/bin/bash\necho $X","parameters":` + params + `}`
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
req, _ := http.NewRequest(http.MethodPost, "/api/v1/applications", strings.NewReader(body))
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
r.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
var createResp map[string]interface{}
|
||||||
|
json.Unmarshal(w.Body.Bytes(), &createResp)
|
||||||
|
data := createResp["data"].(map[string]interface{})
|
||||||
|
id := int64(data["id"].(float64))
|
||||||
|
recorded.TakeAll()
|
||||||
|
|
||||||
|
submitBody := `{"values":{"X":"val"}}`
|
||||||
|
w2 := httptest.NewRecorder()
|
||||||
|
req2, _ := http.NewRequest(http.MethodPost, "/api/v1/applications/"+itoa(id)+"/submit", strings.NewReader(submitBody))
|
||||||
|
req2.Header.Set("Content-Type", "application/json")
|
||||||
|
r.ServeHTTP(w2, req2)
|
||||||
|
|
||||||
|
found := false
|
||||||
|
for _, entry := range recorded.All() {
|
||||||
|
if entry.Message == "application submitted" {
|
||||||
|
found = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !found {
|
||||||
|
t.Fatal("expected 'application submitted' log message")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
*/
|
||||||
|
|
||||||
|
func TestApplicationLogging_CreateBadRequest_LogsWarn(t *testing.T) {
|
||||||
|
h, _, recorded := setupApplicationHandlerWithObserver()
|
||||||
|
r := setupApplicationRouter(h)
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
req, _ := http.NewRequest(http.MethodPost, "/api/v1/applications", strings.NewReader(`{"name":""}`))
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
r.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
found := false
|
||||||
|
for _, entry := range recorded.All() {
|
||||||
|
if entry.Message == "invalid request body for create application" {
|
||||||
|
found = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !found {
|
||||||
|
t.Fatal("expected 'invalid request body for create application' log message")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestApplicationLogging_LogsDoNotContainApplicationContent(t *testing.T) {
|
||||||
|
h, _, recorded := setupApplicationHandlerWithObserver()
|
||||||
|
r := setupApplicationRouter(h)
|
||||||
|
|
||||||
|
body, _ := json.Marshal(model.CreateApplicationRequest{
|
||||||
|
Name: "secret-app",
|
||||||
|
ScriptTemplate: "#!/bin/bash\necho secret_password_here",
|
||||||
|
})
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
req, _ := http.NewRequest(http.MethodPost, "/api/v1/applications", bytes.NewReader(body))
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
r.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
for _, entry := range recorded.All() {
|
||||||
|
msg := entry.Message
|
||||||
|
if strings.Contains(msg, "secret_password_here") {
|
||||||
|
t.Fatalf("log message contains application content: %s", msg)
|
||||||
|
}
|
||||||
|
for _, field := range entry.Context {
|
||||||
|
if strings.Contains(field.String, "secret_password_here") {
|
||||||
|
t.Fatalf("log field contains application content: %s", field.String)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -319,8 +319,8 @@ func TestClusterHandler_GetNodes_Success_NoLogs(t *testing.T) {
|
|||||||
t.Fatalf("expected 200, got %d", w.Code)
|
t.Fatalf("expected 200, got %d", w.Code)
|
||||||
}
|
}
|
||||||
|
|
||||||
if recorded.Len() != 0 {
|
if recorded.Len() != 2 {
|
||||||
t.Fatalf("expected 0 log entries on success, got %d", recorded.Len())
|
t.Fatalf("expected 2 log entries on success, got %d", recorded.Len())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -370,10 +370,10 @@ func TestClusterHandler_GetNode_NotFound_LogsWarn(t *testing.T) {
|
|||||||
t.Fatalf("expected 404, got %d", w.Code)
|
t.Fatalf("expected 404, got %d", w.Code)
|
||||||
}
|
}
|
||||||
|
|
||||||
if recorded.Len() != 1 {
|
if recorded.Len() != 3 {
|
||||||
t.Fatalf("expected 1 log entry, got %d", recorded.Len())
|
t.Fatalf("expected 3 log entries, got %d", recorded.Len())
|
||||||
}
|
}
|
||||||
entry := recorded.All()[0]
|
entry := recorded.All()[2]
|
||||||
if entry.Level != zapcore.WarnLevel {
|
if entry.Level != zapcore.WarnLevel {
|
||||||
t.Fatalf("expected Warn level, got %v", entry.Level)
|
t.Fatalf("expected Warn level, got %v", entry.Level)
|
||||||
}
|
}
|
||||||
@@ -405,8 +405,8 @@ func TestClusterHandler_GetNode_Success_NoLogs(t *testing.T) {
|
|||||||
t.Fatalf("expected 200, got %d", w.Code)
|
t.Fatalf("expected 200, got %d", w.Code)
|
||||||
}
|
}
|
||||||
|
|
||||||
if recorded.Len() != 0 {
|
if recorded.Len() != 2 {
|
||||||
t.Fatalf("expected 0 log entries on success, got %d", recorded.Len())
|
t.Fatalf("expected 2 log entries on success, got %d", recorded.Len())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -458,8 +458,8 @@ func TestClusterHandler_GetPartitions_Success_NoLogs(t *testing.T) {
|
|||||||
t.Fatalf("expected 200, got %d", w.Code)
|
t.Fatalf("expected 200, got %d", w.Code)
|
||||||
}
|
}
|
||||||
|
|
||||||
if recorded.Len() != 0 {
|
if recorded.Len() != 2 {
|
||||||
t.Fatalf("expected 0 log entries on success, got %d", recorded.Len())
|
t.Fatalf("expected 2 log entries on success, got %d", recorded.Len())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -509,10 +509,10 @@ func TestClusterHandler_GetPartition_NotFound_LogsWarn(t *testing.T) {
|
|||||||
t.Fatalf("expected 404, got %d", w.Code)
|
t.Fatalf("expected 404, got %d", w.Code)
|
||||||
}
|
}
|
||||||
|
|
||||||
if recorded.Len() != 1 {
|
if recorded.Len() != 3 {
|
||||||
t.Fatalf("expected 1 log entry, got %d", recorded.Len())
|
t.Fatalf("expected 3 log entries, got %d", recorded.Len())
|
||||||
}
|
}
|
||||||
entry := recorded.All()[0]
|
entry := recorded.All()[2]
|
||||||
if entry.Level != zapcore.WarnLevel {
|
if entry.Level != zapcore.WarnLevel {
|
||||||
t.Fatalf("expected Warn level, got %v", entry.Level)
|
t.Fatalf("expected Warn level, got %v", entry.Level)
|
||||||
}
|
}
|
||||||
@@ -559,8 +559,8 @@ func TestClusterHandler_GetPartition_Success_NoLogs(t *testing.T) {
|
|||||||
t.Fatalf("expected 200, got %d", w.Code)
|
t.Fatalf("expected 200, got %d", w.Code)
|
||||||
}
|
}
|
||||||
|
|
||||||
if recorded.Len() != 0 {
|
if recorded.Len() != 2 {
|
||||||
t.Fatalf("expected 0 log entries on success, got %d", recorded.Len())
|
t.Fatalf("expected 2 log entries on success, got %d", recorded.Len())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -617,8 +617,8 @@ func TestClusterHandler_GetDiag_Success_NoLogs(t *testing.T) {
|
|||||||
t.Fatalf("expected 200, got %d", w.Code)
|
t.Fatalf("expected 200, got %d", w.Code)
|
||||||
}
|
}
|
||||||
|
|
||||||
if recorded.Len() != 0 {
|
if recorded.Len() != 2 {
|
||||||
t.Fatalf("expected 0 log entries on success, got %d", recorded.Len())
|
t.Fatalf("expected 2 log entries on success, got %d", recorded.Len())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
138
internal/handler/file_handler.go
Normal file
138
internal/handler/file_handler.go
Normal file
@@ -0,0 +1,138 @@
|
|||||||
|
package handler
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"io"
|
||||||
|
"strconv"
|
||||||
|
|
||||||
|
"gcy_hpc_server/internal/model"
|
||||||
|
"gcy_hpc_server/internal/server"
|
||||||
|
"gcy_hpc_server/internal/service"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"go.uber.org/zap"
|
||||||
|
)
|
||||||
|
|
||||||
|
type fileServiceProvider interface {
|
||||||
|
ListFiles(ctx context.Context, folderID *int64, page, pageSize int, search string) ([]model.FileResponse, int64, error)
|
||||||
|
GetFileMetadata(ctx context.Context, fileID int64) (*model.File, *model.FileBlob, error)
|
||||||
|
DownloadFile(ctx context.Context, fileID int64, rangeHeader string) (io.ReadCloser, *model.File, *model.FileBlob, int64, int64, error)
|
||||||
|
DeleteFile(ctx context.Context, fileID int64) error
|
||||||
|
}
|
||||||
|
|
||||||
|
type FileHandler struct {
|
||||||
|
svc fileServiceProvider
|
||||||
|
logger *zap.Logger
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewFileHandler(svc *service.FileService, logger *zap.Logger) *FileHandler {
|
||||||
|
return &FileHandler{svc: svc, logger: logger}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *FileHandler) ListFiles(c *gin.Context) {
|
||||||
|
page, _ := strconv.Atoi(c.DefaultQuery("page", "1"))
|
||||||
|
pageSize, _ := strconv.Atoi(c.DefaultQuery("page_size", "10"))
|
||||||
|
if page < 1 {
|
||||||
|
page = 1
|
||||||
|
}
|
||||||
|
if pageSize < 1 {
|
||||||
|
pageSize = 10
|
||||||
|
}
|
||||||
|
|
||||||
|
var folderID *int64
|
||||||
|
if v := c.Query("folder_id"); v != "" {
|
||||||
|
id, err := strconv.ParseInt(v, 10, 64)
|
||||||
|
if err != nil {
|
||||||
|
h.logger.Warn("invalid folder_id", zap.String("folder_id", v))
|
||||||
|
server.BadRequest(c, "invalid folder_id")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
folderID = &id
|
||||||
|
}
|
||||||
|
|
||||||
|
search := c.Query("search")
|
||||||
|
|
||||||
|
files, total, err := h.svc.ListFiles(c.Request.Context(), folderID, page, pageSize, search)
|
||||||
|
if err != nil {
|
||||||
|
h.logger.Error("failed to list files", zap.Error(err))
|
||||||
|
server.InternalError(c, err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
server.OK(c, model.ListFilesResponse{
|
||||||
|
Files: files,
|
||||||
|
Total: total,
|
||||||
|
Page: page,
|
||||||
|
PageSize: pageSize,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *FileHandler) GetFile(c *gin.Context) {
|
||||||
|
id, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||||
|
if err != nil {
|
||||||
|
h.logger.Warn("invalid file id", zap.String("id", c.Param("id")))
|
||||||
|
server.BadRequest(c, "invalid id")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
file, blob, err := h.svc.GetFileMetadata(c.Request.Context(), id)
|
||||||
|
if err != nil {
|
||||||
|
h.logger.Error("failed to get file", zap.Int64("id", id), zap.Error(err))
|
||||||
|
server.InternalError(c, err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
resp := model.FileResponse{
|
||||||
|
ID: file.ID,
|
||||||
|
Name: file.Name,
|
||||||
|
FolderID: file.FolderID,
|
||||||
|
Size: blob.FileSize,
|
||||||
|
MimeType: blob.MimeType,
|
||||||
|
SHA256: file.BlobSHA256,
|
||||||
|
CreatedAt: file.CreatedAt,
|
||||||
|
UpdatedAt: file.UpdatedAt,
|
||||||
|
}
|
||||||
|
server.OK(c, resp)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *FileHandler) DownloadFile(c *gin.Context) {
|
||||||
|
id, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||||
|
if err != nil {
|
||||||
|
h.logger.Warn("invalid file id for download", zap.String("id", c.Param("id")))
|
||||||
|
server.BadRequest(c, "invalid id")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
rangeHeader := c.GetHeader("Range")
|
||||||
|
|
||||||
|
reader, file, blob, start, end, err := h.svc.DownloadFile(c.Request.Context(), id, rangeHeader)
|
||||||
|
if err != nil {
|
||||||
|
h.logger.Error("failed to download file", zap.Int64("id", id), zap.Error(err))
|
||||||
|
server.InternalError(c, err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if rangeHeader != "" {
|
||||||
|
server.StreamRange(c, reader, start, end, blob.FileSize, blob.MimeType)
|
||||||
|
} else {
|
||||||
|
server.StreamFile(c, reader, file.Name, blob.FileSize, blob.MimeType)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *FileHandler) DeleteFile(c *gin.Context) {
|
||||||
|
id, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||||
|
if err != nil {
|
||||||
|
h.logger.Warn("invalid file id for delete", zap.String("id", c.Param("id")))
|
||||||
|
server.BadRequest(c, "invalid id")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := h.svc.DeleteFile(c.Request.Context(), id); err != nil {
|
||||||
|
h.logger.Error("failed to delete file", zap.Int64("id", id), zap.Error(err))
|
||||||
|
server.InternalError(c, err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
h.logger.Info("file deleted", zap.Int64("id", id))
|
||||||
|
server.OK(c, gin.H{"message": "file deleted"})
|
||||||
|
}
|
||||||
369
internal/handler/file_handler_test.go
Normal file
369
internal/handler/file_handler_test.go
Normal file
@@ -0,0 +1,369 @@
|
|||||||
|
package handler
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"gcy_hpc_server/internal/model"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"go.uber.org/zap"
|
||||||
|
)
|
||||||
|
|
||||||
|
type mockFileService struct {
|
||||||
|
listFilesFn func(ctx context.Context, folderID *int64, page, pageSize int, search string) ([]model.FileResponse, int64, error)
|
||||||
|
getFileMetadataFn func(ctx context.Context, fileID int64) (*model.File, *model.FileBlob, error)
|
||||||
|
downloadFileFn func(ctx context.Context, fileID int64, rangeHeader string) (io.ReadCloser, *model.File, *model.FileBlob, int64, int64, error)
|
||||||
|
deleteFileFn func(ctx context.Context, fileID int64) error
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockFileService) ListFiles(ctx context.Context, folderID *int64, page, pageSize int, search string) ([]model.FileResponse, int64, error) {
|
||||||
|
return m.listFilesFn(ctx, folderID, page, pageSize, search)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockFileService) GetFileMetadata(ctx context.Context, fileID int64) (*model.File, *model.FileBlob, error) {
|
||||||
|
return m.getFileMetadataFn(ctx, fileID)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockFileService) DownloadFile(ctx context.Context, fileID int64, rangeHeader string) (io.ReadCloser, *model.File, *model.FileBlob, int64, int64, error) {
|
||||||
|
return m.downloadFileFn(ctx, fileID, rangeHeader)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockFileService) DeleteFile(ctx context.Context, fileID int64) error {
|
||||||
|
return m.deleteFileFn(ctx, fileID)
|
||||||
|
}
|
||||||
|
|
||||||
|
type fileHandlerSetup struct {
|
||||||
|
handler *FileHandler
|
||||||
|
mock *mockFileService
|
||||||
|
router *gin.Engine
|
||||||
|
}
|
||||||
|
|
||||||
|
func newFileHandlerSetup() *fileHandlerSetup {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
mock := &mockFileService{}
|
||||||
|
h := &FileHandler{
|
||||||
|
svc: mock,
|
||||||
|
logger: zap.NewNop(),
|
||||||
|
}
|
||||||
|
r := gin.New()
|
||||||
|
v1 := r.Group("/api/v1")
|
||||||
|
files := v1.Group("/files")
|
||||||
|
files.GET("", h.ListFiles)
|
||||||
|
files.GET("/:id", h.GetFile)
|
||||||
|
files.GET("/:id/download", h.DownloadFile)
|
||||||
|
files.DELETE("/:id", h.DeleteFile)
|
||||||
|
return &fileHandlerSetup{handler: h, mock: mock, router: r}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---- ListFiles Tests ----
|
||||||
|
|
||||||
|
func TestListFiles_Empty(t *testing.T) {
|
||||||
|
s := newFileHandlerSetup()
|
||||||
|
s.mock.listFilesFn = func(ctx context.Context, folderID *int64, page, pageSize int, search string) ([]model.FileResponse, int64, error) {
|
||||||
|
return []model.FileResponse{}, 0, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
req, _ := http.NewRequest(http.MethodGet, "/api/v1/files", nil)
|
||||||
|
s.router.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Fatalf("expected 200, got %d: %s", w.Code, w.Body.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
var resp map[string]interface{}
|
||||||
|
json.Unmarshal(w.Body.Bytes(), &resp)
|
||||||
|
if !resp["success"].(bool) {
|
||||||
|
t.Fatal("expected success=true")
|
||||||
|
}
|
||||||
|
data := resp["data"].(map[string]interface{})
|
||||||
|
files := data["files"].([]interface{})
|
||||||
|
if len(files) != 0 {
|
||||||
|
t.Fatalf("expected empty files list, got %d", len(files))
|
||||||
|
}
|
||||||
|
if data["total"].(float64) != 0 {
|
||||||
|
t.Fatalf("expected total=0, got %v", data["total"])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestListFiles_WithFiles(t *testing.T) {
|
||||||
|
s := newFileHandlerSetup()
|
||||||
|
now := time.Now()
|
||||||
|
s.mock.listFilesFn = func(ctx context.Context, folderID *int64, page, pageSize int, search string) ([]model.FileResponse, int64, error) {
|
||||||
|
return []model.FileResponse{
|
||||||
|
{ID: 1, Name: "a.txt", Size: 100, MimeType: "text/plain", SHA256: "abc123", CreatedAt: now, UpdatedAt: now},
|
||||||
|
{ID: 2, Name: "b.pdf", Size: 200, MimeType: "application/pdf", SHA256: "def456", CreatedAt: now, UpdatedAt: now},
|
||||||
|
}, 2, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
req, _ := http.NewRequest(http.MethodGet, "/api/v1/files?page=1&page_size=10", nil)
|
||||||
|
s.router.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Fatalf("expected 200, got %d: %s", w.Code, w.Body.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
var resp map[string]interface{}
|
||||||
|
json.Unmarshal(w.Body.Bytes(), &resp)
|
||||||
|
data := resp["data"].(map[string]interface{})
|
||||||
|
files := data["files"].([]interface{})
|
||||||
|
if len(files) != 2 {
|
||||||
|
t.Fatalf("expected 2 files, got %d", len(files))
|
||||||
|
}
|
||||||
|
if data["total"].(float64) != 2 {
|
||||||
|
t.Fatalf("expected total=2, got %v", data["total"])
|
||||||
|
}
|
||||||
|
if data["page"].(float64) != 1 {
|
||||||
|
t.Fatalf("expected page=1, got %v", data["page"])
|
||||||
|
}
|
||||||
|
if data["page_size"].(float64) != 10 {
|
||||||
|
t.Fatalf("expected page_size=10, got %v", data["page_size"])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestListFiles_WithFolderID(t *testing.T) {
|
||||||
|
s := newFileHandlerSetup()
|
||||||
|
var capturedFolderID *int64
|
||||||
|
s.mock.listFilesFn = func(ctx context.Context, folderID *int64, page, pageSize int, search string) ([]model.FileResponse, int64, error) {
|
||||||
|
capturedFolderID = folderID
|
||||||
|
return []model.FileResponse{}, 0, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
req, _ := http.NewRequest(http.MethodGet, "/api/v1/files?folder_id=5", nil)
|
||||||
|
s.router.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Fatalf("expected 200, got %d: %s", w.Code, w.Body.String())
|
||||||
|
}
|
||||||
|
if capturedFolderID == nil || *capturedFolderID != 5 {
|
||||||
|
t.Fatalf("expected folder_id=5, got %v", capturedFolderID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestListFiles_ServiceError(t *testing.T) {
|
||||||
|
s := newFileHandlerSetup()
|
||||||
|
s.mock.listFilesFn = func(ctx context.Context, folderID *int64, page, pageSize int, search string) ([]model.FileResponse, int64, error) {
|
||||||
|
return nil, 0, fmt.Errorf("db error")
|
||||||
|
}
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
req, _ := http.NewRequest(http.MethodGet, "/api/v1/files", nil)
|
||||||
|
s.router.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
if w.Code != http.StatusInternalServerError {
|
||||||
|
t.Fatalf("expected 500, got %d", w.Code)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---- GetFile Tests ----
|
||||||
|
|
||||||
|
func TestGetFile_Found(t *testing.T) {
|
||||||
|
s := newFileHandlerSetup()
|
||||||
|
now := time.Now()
|
||||||
|
s.mock.getFileMetadataFn = func(ctx context.Context, fileID int64) (*model.File, *model.FileBlob, error) {
|
||||||
|
return &model.File{
|
||||||
|
ID: 1, Name: "test.txt", BlobSHA256: "abc123", CreatedAt: now, UpdatedAt: now,
|
||||||
|
}, &model.FileBlob{
|
||||||
|
ID: 1, SHA256: "abc123", FileSize: 1024, MimeType: "text/plain", CreatedAt: now,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
req, _ := http.NewRequest(http.MethodGet, "/api/v1/files/1", nil)
|
||||||
|
s.router.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Fatalf("expected 200, got %d: %s", w.Code, w.Body.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
var resp map[string]interface{}
|
||||||
|
json.Unmarshal(w.Body.Bytes(), &resp)
|
||||||
|
if !resp["success"].(bool) {
|
||||||
|
t.Fatal("expected success=true")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetFile_NotFound(t *testing.T) {
|
||||||
|
s := newFileHandlerSetup()
|
||||||
|
s.mock.getFileMetadataFn = func(ctx context.Context, fileID int64) (*model.File, *model.FileBlob, error) {
|
||||||
|
return nil, nil, fmt.Errorf("file not found: 999")
|
||||||
|
}
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
req, _ := http.NewRequest(http.MethodGet, "/api/v1/files/999", nil)
|
||||||
|
s.router.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
if w.Code != http.StatusInternalServerError {
|
||||||
|
t.Fatalf("expected 500, got %d: %s", w.Code, w.Body.String())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetFile_InvalidID(t *testing.T) {
|
||||||
|
s := newFileHandlerSetup()
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
req, _ := http.NewRequest(http.MethodGet, "/api/v1/files/abc", nil)
|
||||||
|
s.router.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
if w.Code != http.StatusBadRequest {
|
||||||
|
t.Fatalf("expected 400, got %d", w.Code)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---- DownloadFile Tests ----
|
||||||
|
|
||||||
|
func TestDownloadFile_Full(t *testing.T) {
|
||||||
|
s := newFileHandlerSetup()
|
||||||
|
content := "hello world file content"
|
||||||
|
now := time.Now()
|
||||||
|
s.mock.downloadFileFn = func(ctx context.Context, fileID int64, rangeHeader string) (io.ReadCloser, *model.File, *model.FileBlob, int64, int64, error) {
|
||||||
|
reader := io.NopCloser(strings.NewReader(content))
|
||||||
|
return reader,
|
||||||
|
&model.File{ID: 1, Name: "test.txt", CreatedAt: now},
|
||||||
|
&model.FileBlob{FileSize: int64(len(content)), MimeType: "text/plain"},
|
||||||
|
0, int64(len(content)) - 1, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
req, _ := http.NewRequest(http.MethodGet, "/api/v1/files/1/download", nil)
|
||||||
|
s.router.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Fatalf("expected 200, got %d: %s", w.Code, w.Body.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check streaming headers
|
||||||
|
if got := w.Header().Get("Content-Disposition"); !strings.Contains(got, "test.txt") {
|
||||||
|
t.Fatalf("expected Content-Disposition to contain 'test.txt', got %s", got)
|
||||||
|
}
|
||||||
|
if got := w.Header().Get("Content-Type"); got != "text/plain" {
|
||||||
|
t.Fatalf("expected Content-Type=text/plain, got %s", got)
|
||||||
|
}
|
||||||
|
if got := w.Header().Get("Accept-Ranges"); got != "bytes" {
|
||||||
|
t.Fatalf("expected Accept-Ranges=bytes, got %s", got)
|
||||||
|
}
|
||||||
|
if w.Body.String() != content {
|
||||||
|
t.Fatalf("expected body %q, got %q", content, w.Body.String())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDownloadFile_WithRange(t *testing.T) {
|
||||||
|
s := newFileHandlerSetup()
|
||||||
|
content := "hello world"
|
||||||
|
now := time.Now()
|
||||||
|
s.mock.downloadFileFn = func(ctx context.Context, fileID int64, rangeHeader string) (io.ReadCloser, *model.File, *model.FileBlob, int64, int64, error) {
|
||||||
|
reader := io.NopCloser(strings.NewReader(content[0:5]))
|
||||||
|
return reader,
|
||||||
|
&model.File{ID: 1, Name: "test.txt", CreatedAt: now},
|
||||||
|
&model.FileBlob{FileSize: int64(len(content)), MimeType: "text/plain"},
|
||||||
|
0, 4, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
req, _ := http.NewRequest(http.MethodGet, "/api/v1/files/1/download", nil)
|
||||||
|
req.Header.Set("Range", "bytes=0-4")
|
||||||
|
s.router.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
if w.Code != http.StatusPartialContent {
|
||||||
|
t.Fatalf("expected 206, got %d: %s", w.Code, w.Body.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check Content-Range header
|
||||||
|
if got := w.Header().Get("Content-Range"); got != "bytes 0-4/11" {
|
||||||
|
t.Fatalf("expected Content-Range 'bytes 0-4/11', got %s", got)
|
||||||
|
}
|
||||||
|
if got := w.Header().Get("Content-Type"); got != "text/plain" {
|
||||||
|
t.Fatalf("expected Content-Type=text/plain, got %s", got)
|
||||||
|
}
|
||||||
|
if got := w.Header().Get("Content-Length"); got != "5" {
|
||||||
|
t.Fatalf("expected Content-Length=5, got %s", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDownloadFile_InvalidID(t *testing.T) {
|
||||||
|
s := newFileHandlerSetup()
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
req, _ := http.NewRequest(http.MethodGet, "/api/v1/files/abc/download", nil)
|
||||||
|
s.router.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
if w.Code != http.StatusBadRequest {
|
||||||
|
t.Fatalf("expected 400, got %d", w.Code)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDownloadFile_ServiceError(t *testing.T) {
|
||||||
|
s := newFileHandlerSetup()
|
||||||
|
s.mock.downloadFileFn = func(ctx context.Context, fileID int64, rangeHeader string) (io.ReadCloser, *model.File, *model.FileBlob, int64, int64, error) {
|
||||||
|
return nil, nil, nil, 0, 0, fmt.Errorf("file not found: 999")
|
||||||
|
}
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
req, _ := http.NewRequest(http.MethodGet, "/api/v1/files/999/download", nil)
|
||||||
|
s.router.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
if w.Code != http.StatusInternalServerError {
|
||||||
|
t.Fatalf("expected 500, got %d: %s", w.Code, w.Body.String())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---- DeleteFile Tests ----
|
||||||
|
|
||||||
|
func TestDeleteFile_Success(t *testing.T) {
|
||||||
|
s := newFileHandlerSetup()
|
||||||
|
s.mock.deleteFileFn = func(ctx context.Context, fileID int64) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
req, _ := http.NewRequest(http.MethodDelete, "/api/v1/files/1", nil)
|
||||||
|
s.router.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Fatalf("expected 200, got %d: %s", w.Code, w.Body.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
var resp map[string]interface{}
|
||||||
|
json.Unmarshal(w.Body.Bytes(), &resp)
|
||||||
|
if !resp["success"].(bool) {
|
||||||
|
t.Fatal("expected success=true")
|
||||||
|
}
|
||||||
|
data := resp["data"].(map[string]interface{})
|
||||||
|
if data["message"] != "file deleted" {
|
||||||
|
t.Fatalf("expected message 'file deleted', got %v", data["message"])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDeleteFile_InvalidID(t *testing.T) {
|
||||||
|
s := newFileHandlerSetup()
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
req, _ := http.NewRequest(http.MethodDelete, "/api/v1/files/abc", nil)
|
||||||
|
s.router.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
if w.Code != http.StatusBadRequest {
|
||||||
|
t.Fatalf("expected 400, got %d", w.Code)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDeleteFile_ServiceError(t *testing.T) {
|
||||||
|
s := newFileHandlerSetup()
|
||||||
|
s.mock.deleteFileFn = func(ctx context.Context, fileID int64) error {
|
||||||
|
return fmt.Errorf("file not found: 999")
|
||||||
|
}
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
req, _ := http.NewRequest(http.MethodDelete, "/api/v1/files/999", nil)
|
||||||
|
s.router.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
if w.Code != http.StatusInternalServerError {
|
||||||
|
t.Fatalf("expected 500, got %d: %s", w.Code, w.Body.String())
|
||||||
|
}
|
||||||
|
}
|
||||||
133
internal/handler/folder_handler.go
Normal file
133
internal/handler/folder_handler.go
Normal file
@@ -0,0 +1,133 @@
|
|||||||
|
package handler
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"gcy_hpc_server/internal/model"
|
||||||
|
"gcy_hpc_server/internal/server"
|
||||||
|
"gcy_hpc_server/internal/service"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"go.uber.org/zap"
|
||||||
|
)
|
||||||
|
|
||||||
|
type folderServiceProvider interface {
|
||||||
|
CreateFolder(ctx context.Context, name string, parentID *int64) (*model.FolderResponse, error)
|
||||||
|
GetFolder(ctx context.Context, id int64) (*model.FolderResponse, error)
|
||||||
|
ListFolders(ctx context.Context, parentID *int64) ([]model.FolderResponse, error)
|
||||||
|
DeleteFolder(ctx context.Context, id int64) error
|
||||||
|
}
|
||||||
|
|
||||||
|
type FolderHandler struct {
|
||||||
|
svc folderServiceProvider
|
||||||
|
logger *zap.Logger
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewFolderHandler(svc *service.FolderService, logger *zap.Logger) *FolderHandler {
|
||||||
|
return &FolderHandler{svc: svc, logger: logger}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *FolderHandler) CreateFolder(c *gin.Context) {
|
||||||
|
var req model.CreateFolderRequest
|
||||||
|
if err := c.ShouldBindJSON(&req); err != nil {
|
||||||
|
h.logger.Warn("invalid request body for create folder", zap.Error(err))
|
||||||
|
server.BadRequest(c, "invalid request body")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if req.Name == "" {
|
||||||
|
h.logger.Warn("missing folder name")
|
||||||
|
server.BadRequest(c, "name is required")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := h.svc.CreateFolder(c.Request.Context(), req.Name, req.ParentID)
|
||||||
|
if err != nil {
|
||||||
|
errStr := err.Error()
|
||||||
|
if strings.Contains(errStr, "invalid folder name") || strings.Contains(errStr, "cannot be") {
|
||||||
|
h.logger.Warn("invalid folder name", zap.String("name", req.Name), zap.Error(err))
|
||||||
|
server.BadRequest(c, errStr)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
h.logger.Error("failed to create folder", zap.Error(err))
|
||||||
|
server.InternalError(c, errStr)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
h.logger.Info("folder created", zap.Int64("id", resp.ID))
|
||||||
|
server.Created(c, resp)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *FolderHandler) GetFolder(c *gin.Context) {
|
||||||
|
id, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||||
|
if err != nil {
|
||||||
|
h.logger.Warn("invalid folder id", zap.String("id", c.Param("id")))
|
||||||
|
server.BadRequest(c, "invalid id")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := h.svc.GetFolder(c.Request.Context(), id)
|
||||||
|
if err != nil {
|
||||||
|
if strings.Contains(err.Error(), "not found") {
|
||||||
|
h.logger.Warn("folder not found", zap.Int64("id", id))
|
||||||
|
server.NotFound(c, "folder not found")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
h.logger.Error("failed to get folder", zap.Int64("id", id), zap.Error(err))
|
||||||
|
server.InternalError(c, err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
server.OK(c, resp)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *FolderHandler) ListFolders(c *gin.Context) {
|
||||||
|
var parentID *int64
|
||||||
|
if q := c.Query("parent_id"); q != "" {
|
||||||
|
pid, err := strconv.ParseInt(q, 10, 64)
|
||||||
|
if err != nil {
|
||||||
|
h.logger.Warn("invalid parent_id query param", zap.String("parent_id", q))
|
||||||
|
server.BadRequest(c, "invalid parent_id")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
parentID = &pid
|
||||||
|
}
|
||||||
|
|
||||||
|
folders, err := h.svc.ListFolders(c.Request.Context(), parentID)
|
||||||
|
if err != nil {
|
||||||
|
h.logger.Error("failed to list folders", zap.Error(err))
|
||||||
|
server.InternalError(c, err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if folders == nil {
|
||||||
|
folders = []model.FolderResponse{}
|
||||||
|
}
|
||||||
|
server.OK(c, folders)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *FolderHandler) DeleteFolder(c *gin.Context) {
|
||||||
|
id, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||||
|
if err != nil {
|
||||||
|
h.logger.Warn("invalid folder id for delete", zap.String("id", c.Param("id")))
|
||||||
|
server.BadRequest(c, "invalid id")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := h.svc.DeleteFolder(c.Request.Context(), id); err != nil {
|
||||||
|
errStr := err.Error()
|
||||||
|
if strings.Contains(errStr, "not empty") {
|
||||||
|
h.logger.Warn("cannot delete non-empty folder", zap.Int64("id", id))
|
||||||
|
server.BadRequest(c, "folder is not empty")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if strings.Contains(errStr, "not found") {
|
||||||
|
h.logger.Warn("folder not found for delete", zap.Int64("id", id))
|
||||||
|
server.NotFound(c, "folder not found")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
h.logger.Error("failed to delete folder", zap.Int64("id", id), zap.Error(err))
|
||||||
|
server.InternalError(c, errStr)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
h.logger.Info("folder deleted", zap.Int64("id", id))
|
||||||
|
server.OK(c, gin.H{"message": "folder deleted"})
|
||||||
|
}
|
||||||
206
internal/handler/folder_handler_test.go
Normal file
206
internal/handler/folder_handler_test.go
Normal file
@@ -0,0 +1,206 @@
|
|||||||
|
package handler
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"encoding/json"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"gcy_hpc_server/internal/model"
|
||||||
|
"gcy_hpc_server/internal/service"
|
||||||
|
"gcy_hpc_server/internal/store"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"go.uber.org/zap"
|
||||||
|
"gorm.io/driver/sqlite"
|
||||||
|
"gorm.io/gorm"
|
||||||
|
gormlogger "gorm.io/gorm/logger"
|
||||||
|
)
|
||||||
|
|
||||||
|
func setupFolderHandler() (*FolderHandler, *gorm.DB) {
|
||||||
|
db, _ := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{Logger: gormlogger.Default.LogMode(gormlogger.Silent)})
|
||||||
|
db.AutoMigrate(&model.Folder{}, &model.File{})
|
||||||
|
folderStore := store.NewFolderStore(db)
|
||||||
|
fileStore := store.NewFileStore(db)
|
||||||
|
svc := service.NewFolderService(folderStore, fileStore, zap.NewNop())
|
||||||
|
h := NewFolderHandler(svc, zap.NewNop())
|
||||||
|
return h, db
|
||||||
|
}
|
||||||
|
|
||||||
|
func setupFolderRouter(h *FolderHandler) *gin.Engine {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
r := gin.New()
|
||||||
|
v1 := r.Group("/api/v1")
|
||||||
|
folders := v1.Group("/files/folders")
|
||||||
|
folders.POST("", h.CreateFolder)
|
||||||
|
folders.GET("/:id", h.GetFolder)
|
||||||
|
folders.GET("", h.ListFolders)
|
||||||
|
folders.DELETE("/:id", h.DeleteFolder)
|
||||||
|
return r
|
||||||
|
}
|
||||||
|
|
||||||
|
func createTestFolder(h *FolderHandler, r *gin.Engine) int64 {
|
||||||
|
body, _ := json.Marshal(model.CreateFolderRequest{Name: "test-folder"})
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
req, _ := http.NewRequest(http.MethodPost, "/api/v1/files/folders", bytes.NewReader(body))
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
r.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
var resp map[string]interface{}
|
||||||
|
json.Unmarshal(w.Body.Bytes(), &resp)
|
||||||
|
data := resp["data"].(map[string]interface{})
|
||||||
|
return int64(data["id"].(float64))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCreateFolder_Success(t *testing.T) {
|
||||||
|
h, _ := setupFolderHandler()
|
||||||
|
r := setupFolderRouter(h)
|
||||||
|
|
||||||
|
body, _ := json.Marshal(model.CreateFolderRequest{Name: "my-folder"})
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
req, _ := http.NewRequest(http.MethodPost, "/api/v1/files/folders", bytes.NewReader(body))
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
r.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
if w.Code != http.StatusCreated {
|
||||||
|
t.Fatalf("expected 201, got %d: %s", w.Code, w.Body.String())
|
||||||
|
}
|
||||||
|
var resp map[string]interface{}
|
||||||
|
json.Unmarshal(w.Body.Bytes(), &resp)
|
||||||
|
if !resp["success"].(bool) {
|
||||||
|
t.Fatal("expected success=true")
|
||||||
|
}
|
||||||
|
data := resp["data"].(map[string]interface{})
|
||||||
|
if data["name"] != "my-folder" {
|
||||||
|
t.Fatalf("expected name=my-folder, got %v", data["name"])
|
||||||
|
}
|
||||||
|
if data["path"] != "/my-folder/" {
|
||||||
|
t.Fatalf("expected path=/my-folder/, got %v", data["path"])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCreateFolder_PathTraversal(t *testing.T) {
|
||||||
|
h, _ := setupFolderHandler()
|
||||||
|
r := setupFolderRouter(h)
|
||||||
|
|
||||||
|
body, _ := json.Marshal(model.CreateFolderRequest{Name: ".."})
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
req, _ := http.NewRequest(http.MethodPost, "/api/v1/files/folders", bytes.NewReader(body))
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
r.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
if w.Code != http.StatusBadRequest {
|
||||||
|
t.Fatalf("expected 400, got %d: %s", w.Code, w.Body.String())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCreateFolder_MissingName(t *testing.T) {
|
||||||
|
h, _ := setupFolderHandler()
|
||||||
|
r := setupFolderRouter(h)
|
||||||
|
|
||||||
|
body, _ := json.Marshal(map[string]interface{}{})
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
req, _ := http.NewRequest(http.MethodPost, "/api/v1/files/folders", bytes.NewReader(body))
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
r.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
if w.Code != http.StatusBadRequest {
|
||||||
|
t.Fatalf("expected 400, got %d: %s", w.Code, w.Body.String())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetFolder_Success(t *testing.T) {
|
||||||
|
h, _ := setupFolderHandler()
|
||||||
|
r := setupFolderRouter(h)
|
||||||
|
|
||||||
|
id := createTestFolder(h, r)
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
req, _ := http.NewRequest(http.MethodGet, "/api/v1/files/folders/"+itoa(id), nil)
|
||||||
|
r.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Fatalf("expected 200, got %d: %s", w.Code, w.Body.String())
|
||||||
|
}
|
||||||
|
var resp map[string]interface{}
|
||||||
|
json.Unmarshal(w.Body.Bytes(), &resp)
|
||||||
|
data := resp["data"].(map[string]interface{})
|
||||||
|
if data["name"] != "test-folder" {
|
||||||
|
t.Fatalf("expected name=test-folder, got %v", data["name"])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetFolder_NotFound(t *testing.T) {
|
||||||
|
h, _ := setupFolderHandler()
|
||||||
|
r := setupFolderRouter(h)
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
req, _ := http.NewRequest(http.MethodGet, "/api/v1/files/folders/999", nil)
|
||||||
|
r.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
if w.Code != http.StatusNotFound {
|
||||||
|
t.Fatalf("expected 404, got %d: %s", w.Code, w.Body.String())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestListFolders_Success(t *testing.T) {
|
||||||
|
h, _ := setupFolderHandler()
|
||||||
|
r := setupFolderRouter(h)
|
||||||
|
|
||||||
|
createTestFolder(h, r)
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
req, _ := http.NewRequest(http.MethodGet, "/api/v1/files/folders", nil)
|
||||||
|
r.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Fatalf("expected 200, got %d: %s", w.Code, w.Body.String())
|
||||||
|
}
|
||||||
|
var resp map[string]interface{}
|
||||||
|
json.Unmarshal(w.Body.Bytes(), &resp)
|
||||||
|
if !resp["success"].(bool) {
|
||||||
|
t.Fatal("expected success=true")
|
||||||
|
}
|
||||||
|
data := resp["data"].([]interface{})
|
||||||
|
if len(data) < 1 {
|
||||||
|
t.Fatal("expected at least 1 folder")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDeleteFolder_Success(t *testing.T) {
|
||||||
|
h, _ := setupFolderHandler()
|
||||||
|
r := setupFolderRouter(h)
|
||||||
|
|
||||||
|
id := createTestFolder(h, r)
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
req, _ := http.NewRequest(http.MethodDelete, "/api/v1/files/folders/"+itoa(id), nil)
|
||||||
|
r.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Fatalf("expected 200, got %d: %s", w.Code, w.Body.String())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDeleteFolder_NonEmpty(t *testing.T) {
|
||||||
|
h, db := setupFolderHandler()
|
||||||
|
r := setupFolderRouter(h)
|
||||||
|
|
||||||
|
parentID := createTestFolder(h, r)
|
||||||
|
|
||||||
|
child := &model.Folder{
|
||||||
|
Name: "child-folder",
|
||||||
|
ParentID: &parentID,
|
||||||
|
Path: "/test-folder/child-folder/",
|
||||||
|
}
|
||||||
|
db.Create(child)
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
req, _ := http.NewRequest(http.MethodDelete, "/api/v1/files/folders/"+itoa(parentID), nil)
|
||||||
|
r.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
if w.Code != http.StatusBadRequest {
|
||||||
|
t.Fatalf("expected 400, got %d: %s", w.Code, w.Body.String())
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -46,16 +46,30 @@ func (h *JobHandler) SubmitJob(c *gin.Context) {
|
|||||||
server.Created(c, resp)
|
server.Created(c, resp)
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetJobs handles GET /api/v1/jobs.
|
// GetJobs handles GET /api/v1/jobs with pagination.
|
||||||
func (h *JobHandler) GetJobs(c *gin.Context) {
|
func (h *JobHandler) GetJobs(c *gin.Context) {
|
||||||
jobs, err := h.jobSvc.GetJobs(c.Request.Context())
|
var query model.JobListQuery
|
||||||
|
if err := c.ShouldBindQuery(&query); err != nil {
|
||||||
|
h.logger.Warn("bad request", zap.String("method", "GetJobs"), zap.String("error", "invalid query params"))
|
||||||
|
server.BadRequest(c, "invalid query params")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if query.Page < 1 {
|
||||||
|
query.Page = 1
|
||||||
|
}
|
||||||
|
if query.PageSize < 1 {
|
||||||
|
query.PageSize = 20
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := h.jobSvc.GetJobs(c.Request.Context(), &query)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
h.logger.Error("handler error", zap.String("method", "GetJobs"), zap.Int("status", http.StatusInternalServerError), zap.Error(err))
|
h.logger.Error("handler error", zap.String("method", "GetJobs"), zap.Int("status", http.StatusInternalServerError), zap.Error(err))
|
||||||
server.InternalError(c, err.Error())
|
server.InternalError(c, err.Error())
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
server.OK(c, jobs)
|
server.OK(c, resp)
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetJob handles GET /api/v1/jobs/:id.
|
// GetJob handles GET /api/v1/jobs/:id.
|
||||||
|
|||||||
@@ -188,7 +188,7 @@ func TestGetJobs_Success(t *testing.T) {
|
|||||||
|
|
||||||
router := setupJobRouter(handler)
|
router := setupJobRouter(handler)
|
||||||
|
|
||||||
req := httptest.NewRequest(http.MethodGet, "/api/v1/jobs", nil)
|
req := httptest.NewRequest(http.MethodGet, "/api/v1/jobs?page=1&page_size=10", nil)
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
router.ServeHTTP(w, req)
|
router.ServeHTTP(w, req)
|
||||||
@@ -202,6 +202,93 @@ func TestGetJobs_Success(t *testing.T) {
|
|||||||
if !resp["success"].(bool) {
|
if !resp["success"].(bool) {
|
||||||
t.Fatal("expected success=true")
|
t.Fatal("expected success=true")
|
||||||
}
|
}
|
||||||
|
data := resp["data"].(map[string]interface{})
|
||||||
|
jobs := data["jobs"].([]interface{})
|
||||||
|
if len(jobs) != 2 {
|
||||||
|
t.Fatalf("expected 2 jobs, got %d", len(jobs))
|
||||||
|
}
|
||||||
|
if int(data["total"].(float64)) != 2 {
|
||||||
|
t.Errorf("expected total=2, got %v", data["total"])
|
||||||
|
}
|
||||||
|
if int(data["page"].(float64)) != 1 {
|
||||||
|
t.Errorf("expected page=1, got %v", data["page"])
|
||||||
|
}
|
||||||
|
if int(data["page_size"].(float64)) != 10 {
|
||||||
|
t.Errorf("expected page_size=10, got %v", data["page_size"])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetJobs_Pagination(t *testing.T) {
|
||||||
|
mux := http.NewServeMux()
|
||||||
|
mux.HandleFunc("/slurm/v0.0.40/jobs", func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
json.NewEncoder(w).Encode(slurm.OpenapiJobInfoResp{
|
||||||
|
Jobs: []slurm.JobInfo{
|
||||||
|
{JobID: slurm.Ptr(int32(1)), Name: slurm.Ptr("job1")},
|
||||||
|
{JobID: slurm.Ptr(int32(2)), Name: slurm.Ptr("job2")},
|
||||||
|
{JobID: slurm.Ptr(int32(3)), Name: slurm.Ptr("job3")},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
})
|
||||||
|
srv, handler := setupJobHandler(mux)
|
||||||
|
defer srv.Close()
|
||||||
|
|
||||||
|
router := setupJobRouter(handler)
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/api/v1/jobs?page=2&page_size=1", nil)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
|
router.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Fatalf("expected 200, got %d: %s", w.Code, w.Body.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
var resp map[string]interface{}
|
||||||
|
json.Unmarshal(w.Body.Bytes(), &resp)
|
||||||
|
data := resp["data"].(map[string]interface{})
|
||||||
|
jobs := data["jobs"].([]interface{})
|
||||||
|
if len(jobs) != 1 {
|
||||||
|
t.Fatalf("expected 1 job on page 2, got %d", len(jobs))
|
||||||
|
}
|
||||||
|
if int(data["total"].(float64)) != 3 {
|
||||||
|
t.Errorf("expected total=3, got %v", data["total"])
|
||||||
|
}
|
||||||
|
jobData := jobs[0].(map[string]interface{})
|
||||||
|
if int(jobData["job_id"].(float64)) != 2 {
|
||||||
|
t.Errorf("expected job_id=2 on page 2, got %v", jobData["job_id"])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetJobs_DefaultPagination(t *testing.T) {
|
||||||
|
mux := http.NewServeMux()
|
||||||
|
mux.HandleFunc("/slurm/v0.0.40/jobs", func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
json.NewEncoder(w).Encode(slurm.OpenapiJobInfoResp{Jobs: []slurm.JobInfo{}})
|
||||||
|
})
|
||||||
|
srv, handler := setupJobHandler(mux)
|
||||||
|
defer srv.Close()
|
||||||
|
|
||||||
|
router := setupJobRouter(handler)
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/api/v1/jobs", nil)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
|
router.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Fatalf("expected 200, got %d: %s", w.Code, w.Body.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
var resp map[string]interface{}
|
||||||
|
json.Unmarshal(w.Body.Bytes(), &resp)
|
||||||
|
data := resp["data"].(map[string]interface{})
|
||||||
|
if int(data["page"].(float64)) != 1 {
|
||||||
|
t.Errorf("expected default page=1, got %v", data["page"])
|
||||||
|
}
|
||||||
|
if int(data["page_size"].(float64)) != 20 {
|
||||||
|
t.Errorf("expected default page_size=20, got %v", data["page_size"])
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestGetJob_Success(t *testing.T) {
|
func TestGetJob_Success(t *testing.T) {
|
||||||
|
|||||||
98
internal/handler/task_handler.go
Normal file
98
internal/handler/task_handler.go
Normal file
@@ -0,0 +1,98 @@
|
|||||||
|
package handler
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"gcy_hpc_server/internal/model"
|
||||||
|
"gcy_hpc_server/internal/server"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"go.uber.org/zap"
|
||||||
|
)
|
||||||
|
|
||||||
|
type taskServiceProvider interface {
|
||||||
|
SubmitAsync(ctx context.Context, req *model.CreateTaskRequest) (int64, error)
|
||||||
|
ListTasks(ctx context.Context, query *model.TaskListQuery) ([]model.Task, int64, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
type TaskHandler struct {
|
||||||
|
svc taskServiceProvider
|
||||||
|
logger *zap.Logger
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewTaskHandler(svc taskServiceProvider, logger *zap.Logger) *TaskHandler {
|
||||||
|
return &TaskHandler{svc: svc, logger: logger}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *TaskHandler) CreateTask(c *gin.Context) {
|
||||||
|
var req model.CreateTaskRequest
|
||||||
|
if err := c.ShouldBindJSON(&req); err != nil {
|
||||||
|
h.logger.Warn("invalid request body for create task", zap.Error(err))
|
||||||
|
server.BadRequest(c, err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
taskID, err := h.svc.SubmitAsync(c.Request.Context(), &req)
|
||||||
|
if err != nil {
|
||||||
|
errStr := err.Error()
|
||||||
|
if strings.Contains(errStr, "not found") {
|
||||||
|
h.logger.Warn("task submit target not found", zap.Error(err))
|
||||||
|
server.NotFound(c, errStr)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if strings.Contains(errStr, "exceeds limit") || strings.Contains(errStr, "validation") {
|
||||||
|
h.logger.Warn("task submit validation failed", zap.Error(err))
|
||||||
|
server.BadRequest(c, errStr)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
h.logger.Error("failed to create task", zap.Error(err))
|
||||||
|
server.InternalError(c, errStr)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
h.logger.Info("task created", zap.Int64("id", taskID))
|
||||||
|
server.Created(c, gin.H{"id": taskID})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *TaskHandler) ListTasks(c *gin.Context) {
|
||||||
|
var query model.TaskListQuery
|
||||||
|
_ = c.ShouldBindQuery(&query)
|
||||||
|
|
||||||
|
if query.Page < 1 {
|
||||||
|
query.Page = 1
|
||||||
|
}
|
||||||
|
if query.PageSize < 1 || query.PageSize > 100 {
|
||||||
|
query.PageSize = 10
|
||||||
|
}
|
||||||
|
|
||||||
|
tasks, total, err := h.svc.ListTasks(c.Request.Context(), &query)
|
||||||
|
if err != nil {
|
||||||
|
h.logger.Error("failed to list tasks", zap.Error(err))
|
||||||
|
server.InternalError(c, err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
responses := make([]model.TaskResponse, 0, len(tasks))
|
||||||
|
for i := range tasks {
|
||||||
|
responses = append(responses, model.TaskResponse{
|
||||||
|
ID: tasks[i].ID,
|
||||||
|
TaskName: tasks[i].TaskName,
|
||||||
|
AppID: tasks[i].AppID,
|
||||||
|
AppName: tasks[i].AppName,
|
||||||
|
Status: tasks[i].Status,
|
||||||
|
CurrentStep: tasks[i].CurrentStep,
|
||||||
|
RetryCount: tasks[i].RetryCount,
|
||||||
|
SlurmJobID: tasks[i].SlurmJobID,
|
||||||
|
WorkDir: tasks[i].WorkDir,
|
||||||
|
ErrorMessage: tasks[i].ErrorMessage,
|
||||||
|
CreatedAt: tasks[i].CreatedAt,
|
||||||
|
UpdatedAt: tasks[i].UpdatedAt,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
server.OK(c, model.TaskListResponse{
|
||||||
|
Items: responses,
|
||||||
|
Total: total,
|
||||||
|
})
|
||||||
|
}
|
||||||
286
internal/handler/task_handler_test.go
Normal file
286
internal/handler/task_handler_test.go
Normal file
@@ -0,0 +1,286 @@
|
|||||||
|
package handler
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"sync/atomic"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"gcy_hpc_server/internal/model"
|
||||||
|
"gcy_hpc_server/internal/service"
|
||||||
|
"gcy_hpc_server/internal/slurm"
|
||||||
|
"gcy_hpc_server/internal/store"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"go.uber.org/zap"
|
||||||
|
"gorm.io/driver/sqlite"
|
||||||
|
"gorm.io/gorm"
|
||||||
|
gormlogger "gorm.io/gorm/logger"
|
||||||
|
)
|
||||||
|
|
||||||
|
var taskDBCounter atomic.Int64
|
||||||
|
|
||||||
|
func setupTaskHandler(t *testing.T, slurmSrv *httptest.Server) (*TaskHandler, *gorm.DB) {
|
||||||
|
t.Helper()
|
||||||
|
dbFile := filepath.Join(t.TempDir(), fmt.Sprintf("test-%d.db", taskDBCounter.Add(1)))
|
||||||
|
db, err := gorm.Open(sqlite.Open(dbFile), &gorm.Config{Logger: gormlogger.Default.LogMode(gormlogger.Silent)})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("open db: %v", err)
|
||||||
|
}
|
||||||
|
db.AutoMigrate(&model.Task{}, &model.Application{})
|
||||||
|
t.Cleanup(func() { os.Remove(dbFile) })
|
||||||
|
|
||||||
|
taskStore := store.NewTaskStore(db)
|
||||||
|
appStore := store.NewApplicationStore(db)
|
||||||
|
|
||||||
|
var jobSvc *service.JobService
|
||||||
|
if slurmSrv != nil {
|
||||||
|
client, _ := slurm.NewClient(slurmSrv.URL, slurmSrv.Client())
|
||||||
|
jobSvc = service.NewJobService(client, zap.NewNop())
|
||||||
|
}
|
||||||
|
|
||||||
|
workDir := filepath.Join(t.TempDir(), "work")
|
||||||
|
taskSvc := service.NewTaskService(taskStore, appStore, nil, nil, nil, jobSvc, workDir, zap.NewNop())
|
||||||
|
h := NewTaskHandler(taskSvc, zap.NewNop())
|
||||||
|
return h, db
|
||||||
|
}
|
||||||
|
|
||||||
|
func setupTaskRouter(h *TaskHandler) *gin.Engine {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
r := gin.New()
|
||||||
|
v1 := r.Group("/api/v1")
|
||||||
|
tasks := v1.Group("/tasks")
|
||||||
|
tasks.POST("", h.CreateTask)
|
||||||
|
tasks.GET("", h.ListTasks)
|
||||||
|
return r
|
||||||
|
}
|
||||||
|
|
||||||
|
func createTestAppForTask(db *gorm.DB) int64 {
|
||||||
|
app := &model.Application{
|
||||||
|
Name: "test-app",
|
||||||
|
ScriptTemplate: "#!/bin/bash\necho hello",
|
||||||
|
Parameters: json.RawMessage(`[]`),
|
||||||
|
}
|
||||||
|
db.Create(app)
|
||||||
|
return app.ID
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---- CreateTask Tests ----
|
||||||
|
|
||||||
|
func TestTaskHandler_CreateTask_Success(t *testing.T) {
|
||||||
|
slurmSrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
w.WriteHeader(http.StatusCreated)
|
||||||
|
json.NewEncoder(w).Encode(map[string]interface{}{"job_id": 12345})
|
||||||
|
}))
|
||||||
|
defer slurmSrv.Close()
|
||||||
|
|
||||||
|
h, db := setupTaskHandler(t, slurmSrv)
|
||||||
|
r := setupTaskRouter(h)
|
||||||
|
|
||||||
|
appID := createTestAppForTask(db)
|
||||||
|
|
||||||
|
taskSvc := h.svc.(*service.TaskService)
|
||||||
|
ctx := context.Background()
|
||||||
|
taskSvc.StartProcessor(ctx)
|
||||||
|
defer taskSvc.StopProcessor()
|
||||||
|
|
||||||
|
body, _ := json.Marshal(model.CreateTaskRequest{
|
||||||
|
AppID: appID,
|
||||||
|
TaskName: "my-task",
|
||||||
|
})
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
req, _ := http.NewRequest(http.MethodPost, "/api/v1/tasks", bytes.NewReader(body))
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
r.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
if w.Code != http.StatusCreated {
|
||||||
|
t.Fatalf("expected 201, got %d: %s", w.Code, w.Body.String())
|
||||||
|
}
|
||||||
|
var resp map[string]interface{}
|
||||||
|
json.Unmarshal(w.Body.Bytes(), &resp)
|
||||||
|
if !resp["success"].(bool) {
|
||||||
|
t.Fatal("expected success=true")
|
||||||
|
}
|
||||||
|
data := resp["data"].(map[string]interface{})
|
||||||
|
if _, ok := data["id"]; !ok {
|
||||||
|
t.Fatal("expected id in response data")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTaskHandler_CreateTask_MissingAppID(t *testing.T) {
|
||||||
|
h, _ := setupTaskHandler(t, nil)
|
||||||
|
r := setupTaskRouter(h)
|
||||||
|
|
||||||
|
body := `{"task_name":"no-app"}`
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
req, _ := http.NewRequest(http.MethodPost, "/api/v1/tasks", bytes.NewReader([]byte(body)))
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
r.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
if w.Code != http.StatusBadRequest {
|
||||||
|
t.Fatalf("expected 400, got %d: %s", w.Code, w.Body.String())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTaskHandler_CreateTask_InvalidJSON(t *testing.T) {
|
||||||
|
h, _ := setupTaskHandler(t, nil)
|
||||||
|
r := setupTaskRouter(h)
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
req, _ := http.NewRequest(http.MethodPost, "/api/v1/tasks", bytes.NewReader([]byte("not-json")))
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
r.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
if w.Code != http.StatusBadRequest {
|
||||||
|
t.Fatalf("expected 400, got %d: %s", w.Code, w.Body.String())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---- ListTasks Tests ----
|
||||||
|
|
||||||
|
func TestTaskHandler_ListTasks_Pagination(t *testing.T) {
|
||||||
|
slurmSrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
w.WriteHeader(http.StatusCreated)
|
||||||
|
json.NewEncoder(w).Encode(map[string]interface{}{"job_id": int32(100)})
|
||||||
|
}))
|
||||||
|
defer slurmSrv.Close()
|
||||||
|
|
||||||
|
h, db := setupTaskHandler(t, slurmSrv)
|
||||||
|
r := setupTaskRouter(h)
|
||||||
|
|
||||||
|
appID := createTestAppForTask(db)
|
||||||
|
|
||||||
|
taskSvc := h.svc.(*service.TaskService)
|
||||||
|
ctx := context.Background()
|
||||||
|
taskSvc.StartProcessor(ctx)
|
||||||
|
defer taskSvc.StopProcessor()
|
||||||
|
|
||||||
|
for i := 0; i < 5; i++ {
|
||||||
|
body, _ := json.Marshal(model.CreateTaskRequest{
|
||||||
|
AppID: appID,
|
||||||
|
TaskName: fmt.Sprintf("task-%d", i),
|
||||||
|
})
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
req, _ := http.NewRequest(http.MethodPost, "/api/v1/tasks", bytes.NewReader(body))
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
r.ServeHTTP(w, req)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Wait for async processing
|
||||||
|
time.Sleep(200 * time.Millisecond)
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
req, _ := http.NewRequest(http.MethodGet, "/api/v1/tasks?page=1&page_size=3", nil)
|
||||||
|
r.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Fatalf("expected 200, got %d: %s", w.Code, w.Body.String())
|
||||||
|
}
|
||||||
|
var resp map[string]interface{}
|
||||||
|
json.Unmarshal(w.Body.Bytes(), &resp)
|
||||||
|
data := resp["data"].(map[string]interface{})
|
||||||
|
if data["total"].(float64) != 5 {
|
||||||
|
t.Fatalf("expected total=5, got %v", data["total"])
|
||||||
|
}
|
||||||
|
items := data["items"].([]interface{})
|
||||||
|
if len(items) != 3 {
|
||||||
|
t.Fatalf("expected 3 items, got %d", len(items))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTaskHandler_ListTasks_StatusFilter(t *testing.T) {
|
||||||
|
slurmSrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
w.WriteHeader(http.StatusCreated)
|
||||||
|
json.NewEncoder(w).Encode(map[string]interface{}{"job_id": int32(200)})
|
||||||
|
}))
|
||||||
|
defer slurmSrv.Close()
|
||||||
|
|
||||||
|
h, db := setupTaskHandler(t, slurmSrv)
|
||||||
|
r := setupTaskRouter(h)
|
||||||
|
|
||||||
|
appID := createTestAppForTask(db)
|
||||||
|
|
||||||
|
taskSvc := h.svc.(*service.TaskService)
|
||||||
|
ctx := context.Background()
|
||||||
|
taskSvc.StartProcessor(ctx)
|
||||||
|
defer taskSvc.StopProcessor()
|
||||||
|
|
||||||
|
for i := 0; i < 3; i++ {
|
||||||
|
body, _ := json.Marshal(model.CreateTaskRequest{
|
||||||
|
AppID: appID,
|
||||||
|
TaskName: fmt.Sprintf("filter-task-%d", i),
|
||||||
|
})
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
req, _ := http.NewRequest(http.MethodPost, "/api/v1/tasks", bytes.NewReader(body))
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
r.ServeHTTP(w, req)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Wait for async processing
|
||||||
|
time.Sleep(200 * time.Millisecond)
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
req, _ := http.NewRequest(http.MethodGet, "/api/v1/tasks?status=queued", nil)
|
||||||
|
r.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Fatalf("expected 200, got %d: %s", w.Code, w.Body.String())
|
||||||
|
}
|
||||||
|
var resp map[string]interface{}
|
||||||
|
json.Unmarshal(w.Body.Bytes(), &resp)
|
||||||
|
data := resp["data"].(map[string]interface{})
|
||||||
|
items := data["items"].([]interface{})
|
||||||
|
for _, item := range items {
|
||||||
|
m := item.(map[string]interface{})
|
||||||
|
if m["status"] != "queued" {
|
||||||
|
t.Fatalf("expected status=queued, got %v", m["status"])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTaskHandler_ListTasks_DefaultPagination(t *testing.T) {
|
||||||
|
h, db := setupTaskHandler(t, nil)
|
||||||
|
r := setupTaskRouter(h)
|
||||||
|
|
||||||
|
_ = createTestAppForTask(db)
|
||||||
|
|
||||||
|
// Directly insert tasks via DB to avoid needing processor
|
||||||
|
for i := 0; i < 15; i++ {
|
||||||
|
task := &model.Task{
|
||||||
|
TaskName: fmt.Sprintf("default-task-%d", i),
|
||||||
|
AppID: 1,
|
||||||
|
AppName: "test-app",
|
||||||
|
Status: model.TaskStatusSubmitted,
|
||||||
|
SubmittedAt: time.Now(),
|
||||||
|
}
|
||||||
|
db.Create(task)
|
||||||
|
}
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
req, _ := http.NewRequest(http.MethodGet, "/api/v1/tasks", nil)
|
||||||
|
r.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Fatalf("expected 200, got %d: %s", w.Code, w.Body.String())
|
||||||
|
}
|
||||||
|
var resp map[string]interface{}
|
||||||
|
json.Unmarshal(w.Body.Bytes(), &resp)
|
||||||
|
data := resp["data"].(map[string]interface{})
|
||||||
|
if data["total"].(float64) != 15 {
|
||||||
|
t.Fatalf("expected total=15, got %v", data["total"])
|
||||||
|
}
|
||||||
|
items := data["items"].([]interface{})
|
||||||
|
if len(items) != 10 {
|
||||||
|
t.Fatalf("expected 10 items (default page_size), got %d", len(items))
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -1,139 +0,0 @@
|
|||||||
package handler
|
|
||||||
|
|
||||||
import (
|
|
||||||
"strconv"
|
|
||||||
|
|
||||||
"gcy_hpc_server/internal/model"
|
|
||||||
"gcy_hpc_server/internal/server"
|
|
||||||
"gcy_hpc_server/internal/store"
|
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
|
||||||
"go.uber.org/zap"
|
|
||||||
)
|
|
||||||
|
|
||||||
type TemplateHandler struct {
|
|
||||||
store *store.TemplateStore
|
|
||||||
logger *zap.Logger
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewTemplateHandler(s *store.TemplateStore, logger *zap.Logger) *TemplateHandler {
|
|
||||||
return &TemplateHandler{store: s, logger: logger}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (h *TemplateHandler) ListTemplates(c *gin.Context) {
|
|
||||||
page, _ := strconv.Atoi(c.DefaultQuery("page", "1"))
|
|
||||||
pageSize, _ := strconv.Atoi(c.DefaultQuery("page_size", "20"))
|
|
||||||
|
|
||||||
if page < 1 {
|
|
||||||
page = 1
|
|
||||||
}
|
|
||||||
if pageSize < 1 {
|
|
||||||
pageSize = 20
|
|
||||||
}
|
|
||||||
|
|
||||||
templates, total, err := h.store.List(c.Request.Context(), page, pageSize)
|
|
||||||
if err != nil {
|
|
||||||
h.logger.Error("failed to list templates", zap.Error(err))
|
|
||||||
server.InternalError(c, err.Error())
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
server.OK(c, gin.H{
|
|
||||||
"templates": templates,
|
|
||||||
"total": total,
|
|
||||||
"page": page,
|
|
||||||
"page_size": pageSize,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func (h *TemplateHandler) CreateTemplate(c *gin.Context) {
|
|
||||||
var req model.CreateTemplateRequest
|
|
||||||
if err := c.ShouldBindJSON(&req); err != nil {
|
|
||||||
h.logger.Warn("invalid request body for create template", zap.Error(err))
|
|
||||||
server.BadRequest(c, "invalid request body")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if req.Name == "" || req.Script == "" {
|
|
||||||
h.logger.Warn("missing required fields for create template")
|
|
||||||
server.BadRequest(c, "name and script are required")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
id, err := h.store.Create(c.Request.Context(), &req)
|
|
||||||
if err != nil {
|
|
||||||
h.logger.Error("failed to create template", zap.Error(err))
|
|
||||||
server.InternalError(c, err.Error())
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
h.logger.Info("template created", zap.Int64("id", id))
|
|
||||||
server.Created(c, gin.H{"id": id})
|
|
||||||
}
|
|
||||||
|
|
||||||
func (h *TemplateHandler) GetTemplate(c *gin.Context) {
|
|
||||||
id, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
|
||||||
if err != nil {
|
|
||||||
h.logger.Warn("invalid template id", zap.String("id", c.Param("id")))
|
|
||||||
server.BadRequest(c, "invalid id")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
tmpl, err := h.store.GetByID(c.Request.Context(), id)
|
|
||||||
if err != nil {
|
|
||||||
h.logger.Error("failed to get template", zap.Int64("id", id), zap.Error(err))
|
|
||||||
server.InternalError(c, err.Error())
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if tmpl == nil {
|
|
||||||
h.logger.Warn("template not found", zap.Int64("id", id))
|
|
||||||
server.NotFound(c, "template not found")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
server.OK(c, tmpl)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (h *TemplateHandler) UpdateTemplate(c *gin.Context) {
|
|
||||||
id, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
|
||||||
if err != nil {
|
|
||||||
h.logger.Warn("invalid template id for update", zap.String("id", c.Param("id")))
|
|
||||||
server.BadRequest(c, "invalid id")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
var req model.UpdateTemplateRequest
|
|
||||||
if err := c.ShouldBindJSON(&req); err != nil {
|
|
||||||
h.logger.Warn("invalid request body for update template", zap.Int64("id", id), zap.Error(err))
|
|
||||||
server.BadRequest(c, "invalid request body")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := h.store.Update(c.Request.Context(), id, &req); err != nil {
|
|
||||||
h.logger.Error("failed to update template", zap.Int64("id", id), zap.Error(err))
|
|
||||||
server.InternalError(c, err.Error())
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
h.logger.Info("template updated", zap.Int64("id", id))
|
|
||||||
server.OK(c, gin.H{"message": "template updated"})
|
|
||||||
}
|
|
||||||
|
|
||||||
func (h *TemplateHandler) DeleteTemplate(c *gin.Context) {
|
|
||||||
id, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
|
||||||
if err != nil {
|
|
||||||
h.logger.Warn("invalid template id for delete", zap.String("id", c.Param("id")))
|
|
||||||
server.BadRequest(c, "invalid id")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := h.store.Delete(c.Request.Context(), id); err != nil {
|
|
||||||
h.logger.Error("failed to delete template", zap.Int64("id", id), zap.Error(err))
|
|
||||||
server.InternalError(c, err.Error())
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
h.logger.Info("template deleted", zap.Int64("id", id))
|
|
||||||
server.OK(c, gin.H{"message": "template deleted"})
|
|
||||||
}
|
|
||||||
@@ -1,387 +0,0 @@
|
|||||||
package handler
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bytes"
|
|
||||||
"context"
|
|
||||||
"encoding/json"
|
|
||||||
"fmt"
|
|
||||||
"net/http"
|
|
||||||
"net/http/httptest"
|
|
||||||
"strings"
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"gcy_hpc_server/internal/model"
|
|
||||||
"gcy_hpc_server/internal/server"
|
|
||||||
"gcy_hpc_server/internal/store"
|
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
|
||||||
"go.uber.org/zap"
|
|
||||||
"go.uber.org/zap/zapcore"
|
|
||||||
"go.uber.org/zap/zaptest/observer"
|
|
||||||
"gorm.io/driver/sqlite"
|
|
||||||
"gorm.io/gorm"
|
|
||||||
"gorm.io/gorm/logger"
|
|
||||||
)
|
|
||||||
|
|
||||||
func setupTemplateHandler() (*TemplateHandler, *gorm.DB) {
|
|
||||||
db, _ := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{Logger: logger.Default.LogMode(logger.Silent)})
|
|
||||||
db.AutoMigrate(&model.JobTemplate{})
|
|
||||||
s := store.NewTemplateStore(db)
|
|
||||||
h := NewTemplateHandler(s, zap.NewNop())
|
|
||||||
return h, db
|
|
||||||
}
|
|
||||||
|
|
||||||
func setupTemplateRouter(h *TemplateHandler) *gin.Engine {
|
|
||||||
gin.SetMode(gin.TestMode)
|
|
||||||
r := gin.New()
|
|
||||||
v1 := r.Group("/api/v1")
|
|
||||||
templates := v1.Group("/templates")
|
|
||||||
templates.GET("", h.ListTemplates)
|
|
||||||
templates.POST("", h.CreateTemplate)
|
|
||||||
templates.GET("/:id", h.GetTemplate)
|
|
||||||
templates.PUT("/:id", h.UpdateTemplate)
|
|
||||||
templates.DELETE("/:id", h.DeleteTemplate)
|
|
||||||
return r
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestListTemplates_Success(t *testing.T) {
|
|
||||||
h, _ := setupTemplateHandler()
|
|
||||||
r := setupTemplateRouter(h)
|
|
||||||
|
|
||||||
// Seed data
|
|
||||||
h.store.Create(context.Background(), &model.CreateTemplateRequest{Name: "test-tpl", Script: "echo hi", Partition: "normal", QOS: "high", CPUs: 4, Memory: "4GB"})
|
|
||||||
|
|
||||||
w := httptest.NewRecorder()
|
|
||||||
req, _ := http.NewRequest("GET", "/api/v1/templates", nil)
|
|
||||||
r.ServeHTTP(w, req)
|
|
||||||
|
|
||||||
if w.Code != http.StatusOK {
|
|
||||||
t.Fatalf("expected 200, got %d: %s", w.Code, w.Body.String())
|
|
||||||
}
|
|
||||||
var resp server.APIResponse
|
|
||||||
json.Unmarshal(w.Body.Bytes(), &resp)
|
|
||||||
if !resp.Success {
|
|
||||||
t.Fatalf("expected success=true, got: %s", w.Body.String())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestCreateTemplate_Success(t *testing.T) {
|
|
||||||
h, _ := setupTemplateHandler()
|
|
||||||
r := setupTemplateRouter(h)
|
|
||||||
|
|
||||||
body := `{"name":"my-tpl","description":"desc","script":"echo hello","partition":"gpu"}`
|
|
||||||
w := httptest.NewRecorder()
|
|
||||||
req, _ := http.NewRequest("POST", "/api/v1/templates", bytes.NewBufferString(body))
|
|
||||||
req.Header.Set("Content-Type", "application/json")
|
|
||||||
r.ServeHTTP(w, req)
|
|
||||||
|
|
||||||
if w.Code != http.StatusCreated {
|
|
||||||
t.Fatalf("expected 201, got %d: %s", w.Code, w.Body.String())
|
|
||||||
}
|
|
||||||
var resp server.APIResponse
|
|
||||||
json.Unmarshal(w.Body.Bytes(), &resp)
|
|
||||||
if !resp.Success {
|
|
||||||
t.Fatalf("expected success=true, got: %s", w.Body.String())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestCreateTemplate_MissingFields(t *testing.T) {
|
|
||||||
h, _ := setupTemplateHandler()
|
|
||||||
r := setupTemplateRouter(h)
|
|
||||||
|
|
||||||
body := `{"name":"","script":""}`
|
|
||||||
w := httptest.NewRecorder()
|
|
||||||
req, _ := http.NewRequest("POST", "/api/v1/templates", bytes.NewBufferString(body))
|
|
||||||
req.Header.Set("Content-Type", "application/json")
|
|
||||||
r.ServeHTTP(w, req)
|
|
||||||
|
|
||||||
if w.Code != http.StatusBadRequest {
|
|
||||||
t.Fatalf("expected 400, got %d: %s", w.Code, w.Body.String())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestGetTemplate_Success(t *testing.T) {
|
|
||||||
h, _ := setupTemplateHandler()
|
|
||||||
r := setupTemplateRouter(h)
|
|
||||||
|
|
||||||
// Seed data
|
|
||||||
id, _ := h.store.Create(context.Background(), &model.CreateTemplateRequest{Name: "test-tpl", Script: "echo hi", Partition: "normal"})
|
|
||||||
|
|
||||||
w := httptest.NewRecorder()
|
|
||||||
req, _ := http.NewRequest("GET", "/api/v1/templates/"+itoa(id), nil)
|
|
||||||
r.ServeHTTP(w, req)
|
|
||||||
|
|
||||||
if w.Code != http.StatusOK {
|
|
||||||
t.Fatalf("expected 200, got %d: %s", w.Code, w.Body.String())
|
|
||||||
}
|
|
||||||
var resp server.APIResponse
|
|
||||||
json.Unmarshal(w.Body.Bytes(), &resp)
|
|
||||||
if !resp.Success {
|
|
||||||
t.Fatalf("expected success=true, got: %s", w.Body.String())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestGetTemplate_NotFound(t *testing.T) {
|
|
||||||
h, _ := setupTemplateHandler()
|
|
||||||
r := setupTemplateRouter(h)
|
|
||||||
|
|
||||||
w := httptest.NewRecorder()
|
|
||||||
req, _ := http.NewRequest("GET", "/api/v1/templates/999", nil)
|
|
||||||
r.ServeHTTP(w, req)
|
|
||||||
|
|
||||||
if w.Code != http.StatusNotFound {
|
|
||||||
t.Fatalf("expected 404, got %d: %s", w.Code, w.Body.String())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestGetTemplate_InvalidID(t *testing.T) {
|
|
||||||
h, _ := setupTemplateHandler()
|
|
||||||
r := setupTemplateRouter(h)
|
|
||||||
|
|
||||||
w := httptest.NewRecorder()
|
|
||||||
req, _ := http.NewRequest("GET", "/api/v1/templates/abc", nil)
|
|
||||||
r.ServeHTTP(w, req)
|
|
||||||
|
|
||||||
if w.Code != http.StatusBadRequest {
|
|
||||||
t.Fatalf("expected 400, got %d: %s", w.Code, w.Body.String())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestUpdateTemplate_Success(t *testing.T) {
|
|
||||||
h, _ := setupTemplateHandler()
|
|
||||||
r := setupTemplateRouter(h)
|
|
||||||
|
|
||||||
// Seed data
|
|
||||||
id, _ := h.store.Create(context.Background(), &model.CreateTemplateRequest{Name: "old", Script: "echo hi"})
|
|
||||||
|
|
||||||
body := `{"name":"updated"}`
|
|
||||||
w := httptest.NewRecorder()
|
|
||||||
req, _ := http.NewRequest("PUT", "/api/v1/templates/"+itoa(id), bytes.NewBufferString(body))
|
|
||||||
req.Header.Set("Content-Type", "application/json")
|
|
||||||
r.ServeHTTP(w, req)
|
|
||||||
|
|
||||||
if w.Code != http.StatusOK {
|
|
||||||
t.Fatalf("expected 200, got %d: %s", w.Code, w.Body.String())
|
|
||||||
}
|
|
||||||
var resp server.APIResponse
|
|
||||||
json.Unmarshal(w.Body.Bytes(), &resp)
|
|
||||||
if !resp.Success {
|
|
||||||
t.Fatalf("expected success=true, got: %s", w.Body.String())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestDeleteTemplate_Success(t *testing.T) {
|
|
||||||
h, _ := setupTemplateHandler()
|
|
||||||
r := setupTemplateRouter(h)
|
|
||||||
|
|
||||||
// Seed data
|
|
||||||
id, _ := h.store.Create(context.Background(), &model.CreateTemplateRequest{Name: "to-delete", Script: "echo hi"})
|
|
||||||
|
|
||||||
w := httptest.NewRecorder()
|
|
||||||
req, _ := http.NewRequest("DELETE", "/api/v1/templates/"+itoa(id), nil)
|
|
||||||
r.ServeHTTP(w, req)
|
|
||||||
|
|
||||||
if w.Code != http.StatusOK {
|
|
||||||
t.Fatalf("expected 200, got %d: %s", w.Code, w.Body.String())
|
|
||||||
}
|
|
||||||
var resp server.APIResponse
|
|
||||||
json.Unmarshal(w.Body.Bytes(), &resp)
|
|
||||||
if !resp.Success {
|
|
||||||
t.Fatalf("expected success=true, got: %s", w.Body.String())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// itoa converts int64 to string for URL path construction.
|
|
||||||
func itoa(id int64) string {
|
|
||||||
return fmt.Sprintf("%d", id)
|
|
||||||
}
|
|
||||||
|
|
||||||
func setupTemplateHandlerWithObserver() (*TemplateHandler, *gorm.DB, *observer.ObservedLogs) {
|
|
||||||
core, recorded := observer.New(zapcore.DebugLevel)
|
|
||||||
l := zap.New(core)
|
|
||||||
db, _ := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{Logger: logger.Default.LogMode(logger.Silent)})
|
|
||||||
db.AutoMigrate(&model.JobTemplate{})
|
|
||||||
s := store.NewTemplateStore(db)
|
|
||||||
return NewTemplateHandler(s, l), db, recorded
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestTemplateLogging_CreateSuccess_LogsInfoWithID(t *testing.T) {
|
|
||||||
h, _, recorded := setupTemplateHandlerWithObserver()
|
|
||||||
r := setupTemplateRouter(h)
|
|
||||||
|
|
||||||
body := `{"name":"log-tpl","script":"echo hi"}`
|
|
||||||
w := httptest.NewRecorder()
|
|
||||||
req, _ := http.NewRequest("POST", "/api/v1/templates", bytes.NewBufferString(body))
|
|
||||||
req.Header.Set("Content-Type", "application/json")
|
|
||||||
r.ServeHTTP(w, req)
|
|
||||||
|
|
||||||
if w.Code != http.StatusCreated {
|
|
||||||
t.Fatalf("expected 201, got %d: %s", w.Code, w.Body.String())
|
|
||||||
}
|
|
||||||
|
|
||||||
entries := recorded.FilterMessage("template created").FilterLevelExact(zapcore.InfoLevel).All()
|
|
||||||
if len(entries) != 1 {
|
|
||||||
t.Fatalf("expected 1 info log for 'template created', got %d", len(entries))
|
|
||||||
}
|
|
||||||
fields := entries[0].ContextMap()
|
|
||||||
if fields["id"] == nil {
|
|
||||||
t.Fatal("expected log entry to contain 'id' field")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestTemplateLogging_UpdateSuccess_LogsInfoWithID(t *testing.T) {
|
|
||||||
h, _, recorded := setupTemplateHandlerWithObserver()
|
|
||||||
r := setupTemplateRouter(h)
|
|
||||||
|
|
||||||
id, _ := h.store.Create(context.Background(), &model.CreateTemplateRequest{Name: "old", Script: "echo hi"})
|
|
||||||
|
|
||||||
body := `{"name":"updated"}`
|
|
||||||
w := httptest.NewRecorder()
|
|
||||||
req, _ := http.NewRequest("PUT", "/api/v1/templates/"+itoa(id), bytes.NewBufferString(body))
|
|
||||||
req.Header.Set("Content-Type", "application/json")
|
|
||||||
r.ServeHTTP(w, req)
|
|
||||||
|
|
||||||
if w.Code != http.StatusOK {
|
|
||||||
t.Fatalf("expected 200, got %d: %s", w.Code, w.Body.String())
|
|
||||||
}
|
|
||||||
|
|
||||||
entries := recorded.FilterMessage("template updated").FilterLevelExact(zapcore.InfoLevel).All()
|
|
||||||
if len(entries) != 1 {
|
|
||||||
t.Fatalf("expected 1 info log for 'template updated', got %d", len(entries))
|
|
||||||
}
|
|
||||||
fields := entries[0].ContextMap()
|
|
||||||
if fields["id"] == nil {
|
|
||||||
t.Fatal("expected log entry to contain 'id' field")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestTemplateLogging_DeleteSuccess_LogsInfoWithID(t *testing.T) {
|
|
||||||
h, _, recorded := setupTemplateHandlerWithObserver()
|
|
||||||
r := setupTemplateRouter(h)
|
|
||||||
|
|
||||||
id, _ := h.store.Create(context.Background(), &model.CreateTemplateRequest{Name: "to-delete", Script: "echo hi"})
|
|
||||||
|
|
||||||
w := httptest.NewRecorder()
|
|
||||||
req, _ := http.NewRequest("DELETE", "/api/v1/templates/"+itoa(id), nil)
|
|
||||||
r.ServeHTTP(w, req)
|
|
||||||
|
|
||||||
if w.Code != http.StatusOK {
|
|
||||||
t.Fatalf("expected 200, got %d: %s", w.Code, w.Body.String())
|
|
||||||
}
|
|
||||||
|
|
||||||
entries := recorded.FilterMessage("template deleted").FilterLevelExact(zapcore.InfoLevel).All()
|
|
||||||
if len(entries) != 1 {
|
|
||||||
t.Fatalf("expected 1 info log for 'template deleted', got %d", len(entries))
|
|
||||||
}
|
|
||||||
fields := entries[0].ContextMap()
|
|
||||||
if fields["id"] == nil {
|
|
||||||
t.Fatal("expected log entry to contain 'id' field")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestTemplateLogging_GetNotFound_LogsWarnWithID(t *testing.T) {
|
|
||||||
h, _, recorded := setupTemplateHandlerWithObserver()
|
|
||||||
r := setupTemplateRouter(h)
|
|
||||||
|
|
||||||
w := httptest.NewRecorder()
|
|
||||||
req, _ := http.NewRequest("GET", "/api/v1/templates/999", nil)
|
|
||||||
r.ServeHTTP(w, req)
|
|
||||||
|
|
||||||
if w.Code != http.StatusNotFound {
|
|
||||||
t.Fatalf("expected 404, got %d: %s", w.Code, w.Body.String())
|
|
||||||
}
|
|
||||||
|
|
||||||
entries := recorded.FilterMessage("template not found").FilterLevelExact(zapcore.WarnLevel).All()
|
|
||||||
if len(entries) != 1 {
|
|
||||||
t.Fatalf("expected 1 warn log for 'template not found', got %d", len(entries))
|
|
||||||
}
|
|
||||||
fields := entries[0].ContextMap()
|
|
||||||
if fields["id"] == nil {
|
|
||||||
t.Fatal("expected log entry to contain 'id' field")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestTemplateLogging_CreateBadRequest_LogsWarn(t *testing.T) {
|
|
||||||
h, _, recorded := setupTemplateHandlerWithObserver()
|
|
||||||
r := setupTemplateRouter(h)
|
|
||||||
|
|
||||||
body := `{"name":"","script":""}`
|
|
||||||
w := httptest.NewRecorder()
|
|
||||||
req, _ := http.NewRequest("POST", "/api/v1/templates", bytes.NewBufferString(body))
|
|
||||||
req.Header.Set("Content-Type", "application/json")
|
|
||||||
r.ServeHTTP(w, req)
|
|
||||||
|
|
||||||
if w.Code != http.StatusBadRequest {
|
|
||||||
t.Fatalf("expected 400, got %d: %s", w.Code, w.Body.String())
|
|
||||||
}
|
|
||||||
|
|
||||||
warnEntries := recorded.FilterLevelExact(zapcore.WarnLevel).All()
|
|
||||||
if len(warnEntries) == 0 {
|
|
||||||
t.Fatal("expected at least 1 warn log for bad request")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestTemplateLogging_InvalidID_LogsWarn(t *testing.T) {
|
|
||||||
h, _, recorded := setupTemplateHandlerWithObserver()
|
|
||||||
r := setupTemplateRouter(h)
|
|
||||||
|
|
||||||
w := httptest.NewRecorder()
|
|
||||||
req, _ := http.NewRequest("GET", "/api/v1/templates/abc", nil)
|
|
||||||
r.ServeHTTP(w, req)
|
|
||||||
|
|
||||||
if w.Code != http.StatusBadRequest {
|
|
||||||
t.Fatalf("expected 400, got %d: %s", w.Code, w.Body.String())
|
|
||||||
}
|
|
||||||
|
|
||||||
warnEntries := recorded.FilterLevelExact(zapcore.WarnLevel).All()
|
|
||||||
if len(warnEntries) == 0 {
|
|
||||||
t.Fatal("expected at least 1 warn log for invalid id")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestTemplateLogging_ListSuccess_NoInfoLog(t *testing.T) {
|
|
||||||
h, _, recorded := setupTemplateHandlerWithObserver()
|
|
||||||
r := setupTemplateRouter(h)
|
|
||||||
|
|
||||||
// Seed data
|
|
||||||
h.store.Create(context.Background(), &model.CreateTemplateRequest{Name: "test-tpl", Script: "echo hi"})
|
|
||||||
|
|
||||||
// Reset recorded logs so the create log doesn't interfere
|
|
||||||
recorded.TakeAll()
|
|
||||||
|
|
||||||
w := httptest.NewRecorder()
|
|
||||||
req, _ := http.NewRequest("GET", "/api/v1/templates", nil)
|
|
||||||
r.ServeHTTP(w, req)
|
|
||||||
|
|
||||||
if w.Code != http.StatusOK {
|
|
||||||
t.Fatalf("expected 200, got %d: %s", w.Code, w.Body.String())
|
|
||||||
}
|
|
||||||
|
|
||||||
infoEntries := recorded.FilterLevelExact(zapcore.InfoLevel).All()
|
|
||||||
if len(infoEntries) != 0 {
|
|
||||||
t.Fatalf("expected 0 info logs for list success, got %d: %+v", len(infoEntries), infoEntries)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestTemplateLogging_LogsDoNotContainTemplateContent(t *testing.T) {
|
|
||||||
h, _, recorded := setupTemplateHandlerWithObserver()
|
|
||||||
r := setupTemplateRouter(h)
|
|
||||||
|
|
||||||
body := `{"name":"secret-name","script":"secret-script","partition":"secret-partition"}`
|
|
||||||
w := httptest.NewRecorder()
|
|
||||||
req, _ := http.NewRequest("POST", "/api/v1/templates", bytes.NewBufferString(body))
|
|
||||||
req.Header.Set("Content-Type", "application/json")
|
|
||||||
r.ServeHTTP(w, req)
|
|
||||||
|
|
||||||
if w.Code != http.StatusCreated {
|
|
||||||
t.Fatalf("expected 201, got %d: %s", w.Code, w.Body.String())
|
|
||||||
}
|
|
||||||
|
|
||||||
entries := recorded.All()
|
|
||||||
for _, e := range entries {
|
|
||||||
logStr := e.Message + " " + fmt.Sprintf("%v", e.ContextMap())
|
|
||||||
if strings.Contains(logStr, "secret-name") || strings.Contains(logStr, "secret-script") || strings.Contains(logStr, "secret-partition") {
|
|
||||||
t.Fatalf("log entry contains sensitive template content: %s", logStr)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
154
internal/handler/upload_handler.go
Normal file
154
internal/handler/upload_handler.go
Normal file
@@ -0,0 +1,154 @@
|
|||||||
|
package handler
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"io"
|
||||||
|
"strconv"
|
||||||
|
|
||||||
|
"gcy_hpc_server/internal/model"
|
||||||
|
"gcy_hpc_server/internal/server"
|
||||||
|
"gcy_hpc_server/internal/service"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"go.uber.org/zap"
|
||||||
|
)
|
||||||
|
|
||||||
|
// uploadServiceProvider defines the interface for upload operations.
|
||||||
|
type uploadServiceProvider interface {
|
||||||
|
InitUpload(ctx context.Context, req model.InitUploadRequest) (interface{}, error)
|
||||||
|
UploadChunk(ctx context.Context, sessionID int64, chunkIndex int, reader io.Reader, size int64) error
|
||||||
|
CompleteUpload(ctx context.Context, sessionID int64) (*model.FileResponse, error)
|
||||||
|
GetUploadStatus(ctx context.Context, sessionID int64) (*model.UploadSessionResponse, error)
|
||||||
|
CancelUpload(ctx context.Context, sessionID int64) error
|
||||||
|
}
|
||||||
|
|
||||||
|
// UploadHandler handles HTTP requests for chunked file uploads.
|
||||||
|
type UploadHandler struct {
|
||||||
|
svc uploadServiceProvider
|
||||||
|
logger *zap.Logger
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewUploadHandler creates a new UploadHandler.
|
||||||
|
func NewUploadHandler(svc *service.UploadService, logger *zap.Logger) *UploadHandler {
|
||||||
|
return &UploadHandler{svc: svc, logger: logger}
|
||||||
|
}
|
||||||
|
|
||||||
|
// InitUpload initiates a new chunked upload session.
|
||||||
|
// POST /api/v1/files/uploads
|
||||||
|
func (h *UploadHandler) InitUpload(c *gin.Context) {
|
||||||
|
var req model.InitUploadRequest
|
||||||
|
if err := c.ShouldBindJSON(&req); err != nil {
|
||||||
|
h.logger.Warn("invalid request body for init upload", zap.Error(err))
|
||||||
|
server.BadRequest(c, err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
result, err := h.svc.InitUpload(c.Request.Context(), req)
|
||||||
|
if err != nil {
|
||||||
|
h.logger.Error("failed to init upload", zap.Error(err))
|
||||||
|
server.BadRequest(c, err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
switch resp := result.(type) {
|
||||||
|
case model.FileResponse:
|
||||||
|
server.OK(c, resp)
|
||||||
|
case model.UploadSessionResponse:
|
||||||
|
server.Created(c, resp)
|
||||||
|
default:
|
||||||
|
server.Created(c, resp)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// UploadChunk uploads a single chunk of an upload session.
|
||||||
|
// PUT /api/v1/files/uploads/:id/chunks/:index
|
||||||
|
func (h *UploadHandler) UploadChunk(c *gin.Context) {
|
||||||
|
sessionID, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||||
|
if err != nil {
|
||||||
|
h.logger.Warn("invalid session id", zap.String("id", c.Param("id")))
|
||||||
|
server.BadRequest(c, "invalid session id")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
chunkIndex, err := strconv.Atoi(c.Param("index"))
|
||||||
|
if err != nil {
|
||||||
|
h.logger.Warn("invalid chunk index", zap.String("index", c.Param("index")))
|
||||||
|
server.BadRequest(c, "invalid chunk index")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
file, header, err := c.Request.FormFile("chunk")
|
||||||
|
if err != nil {
|
||||||
|
h.logger.Warn("missing chunk file in request", zap.Error(err))
|
||||||
|
server.BadRequest(c, "missing chunk file")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer file.Close()
|
||||||
|
|
||||||
|
if err := h.svc.UploadChunk(c.Request.Context(), sessionID, chunkIndex, file, header.Size); err != nil {
|
||||||
|
h.logger.Error("failed to upload chunk", zap.Int64("session_id", sessionID), zap.Int("chunk_index", chunkIndex), zap.Error(err))
|
||||||
|
server.BadRequest(c, err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
server.OK(c, gin.H{"message": "chunk uploaded"})
|
||||||
|
}
|
||||||
|
|
||||||
|
// CompleteUpload finalizes an upload session and assembles the file.
|
||||||
|
// POST /api/v1/files/uploads/:id/complete
|
||||||
|
func (h *UploadHandler) CompleteUpload(c *gin.Context) {
|
||||||
|
sessionID, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||||
|
if err != nil {
|
||||||
|
h.logger.Warn("invalid session id for complete", zap.String("id", c.Param("id")))
|
||||||
|
server.BadRequest(c, "invalid session id")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := h.svc.CompleteUpload(c.Request.Context(), sessionID)
|
||||||
|
if err != nil {
|
||||||
|
h.logger.Error("failed to complete upload", zap.Int64("session_id", sessionID), zap.Error(err))
|
||||||
|
server.InternalError(c, err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
server.Created(c, resp)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetUploadStatus returns the current status of an upload session.
|
||||||
|
// GET /api/v1/files/uploads/:id
|
||||||
|
func (h *UploadHandler) GetUploadStatus(c *gin.Context) {
|
||||||
|
sessionID, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||||
|
if err != nil {
|
||||||
|
h.logger.Warn("invalid session id for status", zap.String("id", c.Param("id")))
|
||||||
|
server.BadRequest(c, "invalid session id")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := h.svc.GetUploadStatus(c.Request.Context(), sessionID)
|
||||||
|
if err != nil {
|
||||||
|
h.logger.Error("failed to get upload status", zap.Int64("session_id", sessionID), zap.Error(err))
|
||||||
|
server.InternalError(c, err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
server.OK(c, resp)
|
||||||
|
}
|
||||||
|
|
||||||
|
// CancelUpload cancels and cleans up an upload session.
|
||||||
|
// DELETE /api/v1/files/uploads/:id
|
||||||
|
func (h *UploadHandler) CancelUpload(c *gin.Context) {
|
||||||
|
sessionID, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||||
|
if err != nil {
|
||||||
|
h.logger.Warn("invalid session id for cancel", zap.String("id", c.Param("id")))
|
||||||
|
server.BadRequest(c, "invalid session id")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := h.svc.CancelUpload(c.Request.Context(), sessionID); err != nil {
|
||||||
|
h.logger.Error("failed to cancel upload", zap.Int64("session_id", sessionID), zap.Error(err))
|
||||||
|
server.InternalError(c, err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
server.OK(c, gin.H{"message": "upload cancelled"})
|
||||||
|
}
|
||||||
307
internal/handler/upload_handler_test.go
Normal file
307
internal/handler/upload_handler_test.go
Normal file
@@ -0,0 +1,307 @@
|
|||||||
|
package handler
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"mime/multipart"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"gcy_hpc_server/internal/model"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"go.uber.org/zap"
|
||||||
|
)
|
||||||
|
|
||||||
|
type mockUploadService struct {
|
||||||
|
initUploadFn func(ctx context.Context, req model.InitUploadRequest) (interface{}, error)
|
||||||
|
uploadChunkFn func(ctx context.Context, sessionID int64, chunkIndex int, reader io.Reader, size int64) error
|
||||||
|
completeUploadFn func(ctx context.Context, sessionID int64) (*model.FileResponse, error)
|
||||||
|
getUploadStatusFn func(ctx context.Context, sessionID int64) (*model.UploadSessionResponse, error)
|
||||||
|
cancelUploadFn func(ctx context.Context, sessionID int64) error
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockUploadService) InitUpload(ctx context.Context, req model.InitUploadRequest) (interface{}, error) {
|
||||||
|
return m.initUploadFn(ctx, req)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockUploadService) UploadChunk(ctx context.Context, sessionID int64, chunkIndex int, reader io.Reader, size int64) error {
|
||||||
|
return m.uploadChunkFn(ctx, sessionID, chunkIndex, reader, size)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockUploadService) CompleteUpload(ctx context.Context, sessionID int64) (*model.FileResponse, error) {
|
||||||
|
return m.completeUploadFn(ctx, sessionID)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockUploadService) GetUploadStatus(ctx context.Context, sessionID int64) (*model.UploadSessionResponse, error) {
|
||||||
|
return m.getUploadStatusFn(ctx, sessionID)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockUploadService) CancelUpload(ctx context.Context, sessionID int64) error {
|
||||||
|
return m.cancelUploadFn(ctx, sessionID)
|
||||||
|
}
|
||||||
|
|
||||||
|
func setupUploadHandler(mock uploadServiceProvider) (*UploadHandler, *gin.Engine) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
h := &UploadHandler{svc: mock, logger: zap.NewNop()}
|
||||||
|
r := gin.New()
|
||||||
|
v1 := r.Group("/api/v1")
|
||||||
|
uploads := v1.Group("/files/uploads")
|
||||||
|
uploads.POST("", h.InitUpload)
|
||||||
|
uploads.PUT("/:id/chunks/:index", h.UploadChunk)
|
||||||
|
uploads.POST("/:id/complete", h.CompleteUpload)
|
||||||
|
uploads.GET("/:id", h.GetUploadStatus)
|
||||||
|
uploads.DELETE("/:id", h.CancelUpload)
|
||||||
|
return h, r
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestInitUpload_NewSession(t *testing.T) {
|
||||||
|
mock := &mockUploadService{
|
||||||
|
initUploadFn: func(ctx context.Context, req model.InitUploadRequest) (interface{}, error) {
|
||||||
|
return model.UploadSessionResponse{
|
||||||
|
ID: 1,
|
||||||
|
FileName: req.FileName,
|
||||||
|
FileSize: req.FileSize,
|
||||||
|
ChunkSize: 5 * 1024 * 1024,
|
||||||
|
TotalChunks: 2,
|
||||||
|
SHA256: req.SHA256,
|
||||||
|
Status: "pending",
|
||||||
|
UploadedChunks: []int{},
|
||||||
|
ExpiresAt: time.Now().Add(24 * time.Hour),
|
||||||
|
CreatedAt: time.Now(),
|
||||||
|
}, nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
_, r := setupUploadHandler(mock)
|
||||||
|
|
||||||
|
body, _ := json.Marshal(model.InitUploadRequest{
|
||||||
|
FileName: "test.bin",
|
||||||
|
FileSize: 10 * 1024 * 1024,
|
||||||
|
SHA256: "abc123",
|
||||||
|
})
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
req, _ := http.NewRequest(http.MethodPost, "/api/v1/files/uploads", bytes.NewReader(body))
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
r.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
if w.Code != http.StatusCreated {
|
||||||
|
t.Fatalf("expected 201, got %d: %s", w.Code, w.Body.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
var resp map[string]interface{}
|
||||||
|
json.Unmarshal(w.Body.Bytes(), &resp)
|
||||||
|
if !resp["success"].(bool) {
|
||||||
|
t.Fatal("expected success=true")
|
||||||
|
}
|
||||||
|
data := resp["data"].(map[string]interface{})
|
||||||
|
if data["status"] != "pending" {
|
||||||
|
t.Fatalf("expected status=pending, got %v", data["status"])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestInitUpload_DedupHit(t *testing.T) {
|
||||||
|
mock := &mockUploadService{
|
||||||
|
initUploadFn: func(ctx context.Context, req model.InitUploadRequest) (interface{}, error) {
|
||||||
|
return model.FileResponse{
|
||||||
|
ID: 42,
|
||||||
|
Name: req.FileName,
|
||||||
|
Size: req.FileSize,
|
||||||
|
SHA256: req.SHA256,
|
||||||
|
MimeType: "application/octet-stream",
|
||||||
|
CreatedAt: time.Now(),
|
||||||
|
UpdatedAt: time.Now(),
|
||||||
|
}, nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
_, r := setupUploadHandler(mock)
|
||||||
|
|
||||||
|
body, _ := json.Marshal(model.InitUploadRequest{
|
||||||
|
FileName: "existing.bin",
|
||||||
|
FileSize: 1024,
|
||||||
|
SHA256: "existinghash",
|
||||||
|
})
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
req, _ := http.NewRequest(http.MethodPost, "/api/v1/files/uploads", bytes.NewReader(body))
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
r.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Fatalf("expected 200, got %d: %s", w.Code, w.Body.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
var resp map[string]interface{}
|
||||||
|
json.Unmarshal(w.Body.Bytes(), &resp)
|
||||||
|
if !resp["success"].(bool) {
|
||||||
|
t.Fatal("expected success=true")
|
||||||
|
}
|
||||||
|
data := resp["data"].(map[string]interface{})
|
||||||
|
if data["sha256"] != "existinghash" {
|
||||||
|
t.Fatalf("expected sha256=existinghash, got %v", data["sha256"])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestInitUpload_MissingFields(t *testing.T) {
|
||||||
|
mock := &mockUploadService{
|
||||||
|
initUploadFn: func(ctx context.Context, req model.InitUploadRequest) (interface{}, error) {
|
||||||
|
return nil, nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
_, r := setupUploadHandler(mock)
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
req, _ := http.NewRequest(http.MethodPost, "/api/v1/files/uploads", bytes.NewReader([]byte(`{}`)))
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
r.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
if w.Code != http.StatusBadRequest {
|
||||||
|
t.Fatalf("expected 400, got %d: %s", w.Code, w.Body.String())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUploadChunk_Success(t *testing.T) {
|
||||||
|
var receivedSessionID int64
|
||||||
|
var receivedChunkIndex int
|
||||||
|
|
||||||
|
mock := &mockUploadService{
|
||||||
|
uploadChunkFn: func(ctx context.Context, sessionID int64, chunkIndex int, reader io.Reader, size int64) error {
|
||||||
|
receivedSessionID = sessionID
|
||||||
|
receivedChunkIndex = chunkIndex
|
||||||
|
readBytes, _ := io.ReadAll(reader)
|
||||||
|
if len(readBytes) == 0 {
|
||||||
|
t.Error("expected non-empty chunk data")
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
_, r := setupUploadHandler(mock)
|
||||||
|
|
||||||
|
var buf bytes.Buffer
|
||||||
|
writer := multipart.NewWriter(&buf)
|
||||||
|
part, _ := writer.CreateFormFile("chunk", "test.bin")
|
||||||
|
fmt.Fprintf(part, "chunk data here")
|
||||||
|
writer.Close()
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
req, _ := http.NewRequest(http.MethodPut, "/api/v1/files/uploads/1/chunks/0", &buf)
|
||||||
|
req.Header.Set("Content-Type", writer.FormDataContentType())
|
||||||
|
r.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Fatalf("expected 200, got %d: %s", w.Code, w.Body.String())
|
||||||
|
}
|
||||||
|
if receivedSessionID != 1 {
|
||||||
|
t.Fatalf("expected session_id=1, got %d", receivedSessionID)
|
||||||
|
}
|
||||||
|
if receivedChunkIndex != 0 {
|
||||||
|
t.Fatalf("expected chunk_index=0, got %d", receivedChunkIndex)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCompleteUpload_Success(t *testing.T) {
|
||||||
|
mock := &mockUploadService{
|
||||||
|
completeUploadFn: func(ctx context.Context, sessionID int64) (*model.FileResponse, error) {
|
||||||
|
return &model.FileResponse{
|
||||||
|
ID: 10,
|
||||||
|
Name: "completed.bin",
|
||||||
|
Size: 2048,
|
||||||
|
SHA256: "completehash",
|
||||||
|
MimeType: "application/octet-stream",
|
||||||
|
CreatedAt: time.Now(),
|
||||||
|
UpdatedAt: time.Now(),
|
||||||
|
}, nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
_, r := setupUploadHandler(mock)
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
req, _ := http.NewRequest(http.MethodPost, "/api/v1/files/uploads/5/complete", nil)
|
||||||
|
r.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
if w.Code != http.StatusCreated {
|
||||||
|
t.Fatalf("expected 201, got %d: %s", w.Code, w.Body.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
var resp map[string]interface{}
|
||||||
|
json.Unmarshal(w.Body.Bytes(), &resp)
|
||||||
|
if !resp["success"].(bool) {
|
||||||
|
t.Fatal("expected success=true")
|
||||||
|
}
|
||||||
|
data := resp["data"].(map[string]interface{})
|
||||||
|
if data["name"] != "completed.bin" {
|
||||||
|
t.Fatalf("expected name=completed.bin, got %v", data["name"])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetUploadStatus_Success(t *testing.T) {
|
||||||
|
mock := &mockUploadService{
|
||||||
|
getUploadStatusFn: func(ctx context.Context, sessionID int64) (*model.UploadSessionResponse, error) {
|
||||||
|
return &model.UploadSessionResponse{
|
||||||
|
ID: sessionID,
|
||||||
|
FileName: "status.bin",
|
||||||
|
FileSize: 4096,
|
||||||
|
ChunkSize: 2048,
|
||||||
|
TotalChunks: 2,
|
||||||
|
SHA256: "statushash",
|
||||||
|
Status: "uploading",
|
||||||
|
UploadedChunks: []int{0},
|
||||||
|
ExpiresAt: time.Now().Add(24 * time.Hour),
|
||||||
|
CreatedAt: time.Now(),
|
||||||
|
}, nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
_, r := setupUploadHandler(mock)
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
req, _ := http.NewRequest(http.MethodGet, "/api/v1/files/uploads/3", nil)
|
||||||
|
r.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Fatalf("expected 200, got %d: %s", w.Code, w.Body.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
var resp map[string]interface{}
|
||||||
|
json.Unmarshal(w.Body.Bytes(), &resp)
|
||||||
|
if !resp["success"].(bool) {
|
||||||
|
t.Fatal("expected success=true")
|
||||||
|
}
|
||||||
|
data := resp["data"].(map[string]interface{})
|
||||||
|
if data["status"] != "uploading" {
|
||||||
|
t.Fatalf("expected status=uploading, got %v", data["status"])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCancelUpload_Success(t *testing.T) {
|
||||||
|
var receivedSessionID int64
|
||||||
|
|
||||||
|
mock := &mockUploadService{
|
||||||
|
cancelUploadFn: func(ctx context.Context, sessionID int64) error {
|
||||||
|
receivedSessionID = sessionID
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
_, r := setupUploadHandler(mock)
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
req, _ := http.NewRequest(http.MethodDelete, "/api/v1/files/uploads/7", nil)
|
||||||
|
r.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Fatalf("expected 200, got %d: %s", w.Code, w.Body.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
if receivedSessionID != 7 {
|
||||||
|
t.Fatalf("expected session_id=7, got %d", receivedSessionID)
|
||||||
|
}
|
||||||
|
|
||||||
|
var resp map[string]interface{}
|
||||||
|
json.Unmarshal(w.Body.Bytes(), &resp)
|
||||||
|
data := resp["data"].(map[string]interface{})
|
||||||
|
if data["message"] != "upload cancelled" {
|
||||||
|
t.Fatalf("expected message='upload cancelled', got %v", data["message"])
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -47,7 +47,7 @@ func NewLogger(cfg config.LogConfig) (*zap.Logger, error) {
|
|||||||
maxSize := applyDefaultInt(cfg.MaxSize, 100)
|
maxSize := applyDefaultInt(cfg.MaxSize, 100)
|
||||||
maxBackups := applyDefaultInt(cfg.MaxBackups, 5)
|
maxBackups := applyDefaultInt(cfg.MaxBackups, 5)
|
||||||
maxAge := applyDefaultInt(cfg.MaxAge, 30)
|
maxAge := applyDefaultInt(cfg.MaxAge, 30)
|
||||||
compress := cfg.Compress || cfg.MaxSize == 0 && cfg.MaxBackups == 0 && cfg.MaxAge == 0
|
compress := cfg.Compress || (cfg.MaxSize == 0 && cfg.MaxBackups == 0 && cfg.MaxAge == 0)
|
||||||
|
|
||||||
lj := &lumberjack.Logger{
|
lj := &lumberjack.Logger{
|
||||||
Filename: cfg.FilePath,
|
Filename: cfg.FilePath,
|
||||||
|
|||||||
76
internal/model/application.go
Normal file
76
internal/model/application.go
Normal file
@@ -0,0 +1,76 @@
|
|||||||
|
package model
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Parameter type constants for ParameterSchema.Type.
|
||||||
|
const (
|
||||||
|
ParamTypeString = "string"
|
||||||
|
ParamTypeInteger = "integer"
|
||||||
|
ParamTypeEnum = "enum"
|
||||||
|
ParamTypeFile = "file"
|
||||||
|
ParamTypeDirectory = "directory"
|
||||||
|
ParamTypeBoolean = "boolean"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Application represents a parameterized application definition for HPC job submission.
|
||||||
|
type Application struct {
|
||||||
|
ID int64 `gorm:"primaryKey;autoIncrement" json:"id"`
|
||||||
|
Name string `gorm:"uniqueIndex;size:255;not null" json:"name"`
|
||||||
|
Description string `gorm:"type:text" json:"description,omitempty"`
|
||||||
|
Icon string `gorm:"size:255" json:"icon,omitempty"`
|
||||||
|
Category string `gorm:"size:255" json:"category,omitempty"`
|
||||||
|
ScriptTemplate string `gorm:"type:text;not null" json:"script_template"`
|
||||||
|
Parameters json.RawMessage `gorm:"type:json" json:"parameters,omitempty"`
|
||||||
|
Scope string `gorm:"size:50;default:'system'" json:"scope,omitempty"`
|
||||||
|
CreatedBy int64 `json:"created_by,omitempty"`
|
||||||
|
CreatedAt time.Time `json:"created_at"`
|
||||||
|
UpdatedAt time.Time `json:"updated_at"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (Application) TableName() string {
|
||||||
|
return "hpc_applications"
|
||||||
|
}
|
||||||
|
|
||||||
|
// ParameterSchema defines a single parameter in an application's form schema.
|
||||||
|
type ParameterSchema struct {
|
||||||
|
Name string `json:"name"`
|
||||||
|
Label string `json:"label,omitempty"`
|
||||||
|
Type string `json:"type"`
|
||||||
|
Required bool `json:"required,omitempty"`
|
||||||
|
Default string `json:"default,omitempty"`
|
||||||
|
Options []string `json:"options,omitempty"`
|
||||||
|
Description string `json:"description,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// CreateApplicationRequest is the DTO for creating a new application.
|
||||||
|
type CreateApplicationRequest struct {
|
||||||
|
Name string `json:"name" binding:"required"`
|
||||||
|
Description string `json:"description,omitempty"`
|
||||||
|
Icon string `json:"icon,omitempty"`
|
||||||
|
Category string `json:"category,omitempty"`
|
||||||
|
ScriptTemplate string `json:"script_template" binding:"required"`
|
||||||
|
Parameters json.RawMessage `json:"parameters,omitempty"`
|
||||||
|
Scope string `json:"scope,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateApplicationRequest is the DTO for updating an existing application.
|
||||||
|
// All fields are optional. Parameters uses *json.RawMessage to distinguish
|
||||||
|
// between "not provided" (nil) and "set to empty" (non-nil).
|
||||||
|
type UpdateApplicationRequest struct {
|
||||||
|
Name *string `json:"name,omitempty"`
|
||||||
|
Description *string `json:"description,omitempty"`
|
||||||
|
Icon *string `json:"icon,omitempty"`
|
||||||
|
Category *string `json:"category,omitempty"`
|
||||||
|
ScriptTemplate *string `json:"script_template,omitempty"`
|
||||||
|
Parameters *json.RawMessage `json:"parameters,omitempty"`
|
||||||
|
Scope *string `json:"scope,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// ApplicationSubmitRequest is the DTO for submitting a job from an application.
|
||||||
|
// ApplicationID is parsed from the URL :id parameter, not included in the body.
|
||||||
|
type ApplicationSubmitRequest struct {
|
||||||
|
Values map[string]string `json:"values" binding:"required"`
|
||||||
|
}
|
||||||
@@ -1,23 +1,73 @@
|
|||||||
package model
|
package model
|
||||||
|
|
||||||
// NodeResponse is the simplified API response for a node.
|
// NodeResponse is the API response for a node.
|
||||||
type NodeResponse struct {
|
type NodeResponse struct {
|
||||||
Name string `json:"name"`
|
// Identity
|
||||||
State []string `json:"state"`
|
Name string `json:"name"` // 节点主机名
|
||||||
CPUs int32 `json:"cpus"`
|
State []string `json:"state"` // 节点状态 (e.g. ["IDLE"], ["ALLOCATED","COMPLETING"])
|
||||||
RealMemory int64 `json:"real_memory"`
|
Reason string `json:"reason,omitempty"` // 节点 DOWN/DRAIN 的原因
|
||||||
AllocMem int64 `json:"alloc_memory,omitempty"`
|
ReasonSetByUser string `json:"reason_set_by_user,omitempty"` // 设置原因的用户
|
||||||
Arch string `json:"architecture,omitempty"`
|
|
||||||
OS string `json:"operating_system,omitempty"`
|
// CPU Resources
|
||||||
|
CPUs int32 `json:"cpus"` // 总 CPU 核数
|
||||||
|
AllocCpus *int32 `json:"alloc_cpus,omitempty"` // 已分配 CPU 核数
|
||||||
|
Cores *int32 `json:"cores,omitempty"` // 物理核心数
|
||||||
|
Sockets *int32 `json:"sockets,omitempty"` // CPU 插槽数
|
||||||
|
Threads *int32 `json:"threads,omitempty"` // 每核线程数
|
||||||
|
CpuLoad *int32 `json:"cpu_load,omitempty"` // CPU 负载 (内核 nice 值乘以 100)
|
||||||
|
|
||||||
|
// Memory (MiB)
|
||||||
|
RealMemory int64 `json:"real_memory"` // 物理内存总量
|
||||||
|
AllocMemory int64 `json:"alloc_memory,omitempty"` // 已分配内存
|
||||||
|
FreeMem *int64 `json:"free_mem,omitempty"` // 空闲内存
|
||||||
|
|
||||||
|
// Hardware
|
||||||
|
Arch string `json:"architecture,omitempty"` // 系统架构 (e.g. x86_64)
|
||||||
|
OS string `json:"operating_system,omitempty"` // 操作系统版本
|
||||||
|
Gres string `json:"gres,omitempty"` // 可用通用资源 (e.g. "gpu:4")
|
||||||
|
GresUsed string `json:"gres_used,omitempty"` // 已使用的通用资源 (e.g. "gpu:2")
|
||||||
|
|
||||||
|
// Network
|
||||||
|
Address string `json:"address,omitempty"` // 节点地址 (IP)
|
||||||
|
Hostname string `json:"hostname,omitempty"` // 节点主机名 (可能与 Name 不同)
|
||||||
|
|
||||||
|
// Scheduling
|
||||||
|
Weight *int32 `json:"weight,omitempty"` // 调度权重
|
||||||
|
Features string `json:"features,omitempty"` // 节点特性标签 (可修改)
|
||||||
|
ActiveFeatures string `json:"active_features,omitempty"` // 当前生效的特性标签 (只读)
|
||||||
}
|
}
|
||||||
|
|
||||||
// PartitionResponse is the simplified API response for a partition.
|
// PartitionResponse is the API response for a partition.
|
||||||
type PartitionResponse struct {
|
type PartitionResponse struct {
|
||||||
Name string `json:"name"`
|
// Identity
|
||||||
State []string `json:"state"`
|
Name string `json:"name"` // 分区名称
|
||||||
Nodes string `json:"nodes,omitempty"`
|
State []string `json:"state"` // 分区状态 (e.g. ["UP"], ["DOWN","DRAIN"])
|
||||||
TotalCPUs int32 `json:"total_cpus,omitempty"`
|
Default bool `json:"default,omitempty"` // 是否为默认分区
|
||||||
TotalNodes int32 `json:"total_nodes,omitempty"`
|
|
||||||
MaxTime string `json:"max_time,omitempty"`
|
// Nodes
|
||||||
Default bool `json:"default,omitempty"`
|
Nodes string `json:"nodes,omitempty"` // 分区包含的节点列表
|
||||||
|
TotalNodes int32 `json:"total_nodes,omitempty"` // 节点总数
|
||||||
|
|
||||||
|
// CPUs
|
||||||
|
TotalCPUs int32 `json:"total_cpus,omitempty"` // CPU 总核数
|
||||||
|
MaxCPUsPerNode *int32 `json:"max_cpus_per_node,omitempty"` // 每节点最大 CPU 核数
|
||||||
|
|
||||||
|
// Limits
|
||||||
|
MaxTime string `json:"max_time,omitempty"` // 最大运行时间 (分钟,"UNLIMITED" 表示无限)
|
||||||
|
MaxNodes *int32 `json:"max_nodes,omitempty"` // 单作业最大节点数
|
||||||
|
MinNodes *int32 `json:"min_nodes,omitempty"` // 单作业最小节点数
|
||||||
|
DefaultTime string `json:"default_time,omitempty"` // 默认运行时间限制
|
||||||
|
GraceTime *int32 `json:"grace_time,omitempty"` // 作业抢占后的宽限时间 (秒)
|
||||||
|
|
||||||
|
// Priority
|
||||||
|
Priority *int32 `json:"priority,omitempty"` // 分区内作业优先级因子
|
||||||
|
|
||||||
|
// Access Control - QOS
|
||||||
|
QOSAllowed string `json:"qos_allowed,omitempty"` // 允许使用的 QOS 列表
|
||||||
|
QOSDeny string `json:"qos_deny,omitempty"` // 禁止使用的 QOS 列表
|
||||||
|
QOSAssigned string `json:"qos_assigned,omitempty"` // 分区默认分配的 QOS
|
||||||
|
|
||||||
|
// Access Control - Accounts
|
||||||
|
AccountsAllowed string `json:"accounts_allowed,omitempty"` // 允许使用的账户列表
|
||||||
|
AccountsDeny string `json:"accounts_deny,omitempty"` // 禁止使用的账户列表
|
||||||
}
|
}
|
||||||
|
|||||||
193
internal/model/file.go
Normal file
193
internal/model/file.go
Normal file
@@ -0,0 +1,193 @@
|
|||||||
|
package model
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
"unicode"
|
||||||
|
|
||||||
|
"gorm.io/gorm"
|
||||||
|
)
|
||||||
|
|
||||||
|
// FileBlob represents a physical file stored in MinIO, deduplicated by SHA256.
|
||||||
|
type FileBlob struct {
|
||||||
|
ID int64 `gorm:"primaryKey;autoIncrement" json:"id"`
|
||||||
|
SHA256 string `gorm:"uniqueIndex;size:64;not null" json:"sha256"`
|
||||||
|
MinioKey string `gorm:"size:255;not null" json:"minio_key"`
|
||||||
|
FileSize int64 `gorm:"not null" json:"file_size"`
|
||||||
|
MimeType string `gorm:"size:255" json:"mime_type"`
|
||||||
|
RefCount int `gorm:"not null;default:0" json:"ref_count"`
|
||||||
|
CreatedAt time.Time `json:"created_at"`
|
||||||
|
UpdatedAt time.Time `json:"updated_at"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (FileBlob) TableName() string {
|
||||||
|
return "hpc_file_blobs"
|
||||||
|
}
|
||||||
|
|
||||||
|
// File represents a logical file visible to users, backed by a FileBlob.
|
||||||
|
type File struct {
|
||||||
|
ID int64 `gorm:"primaryKey;autoIncrement" json:"id"`
|
||||||
|
Name string `gorm:"size:255;not null" json:"name"`
|
||||||
|
FolderID *int64 `gorm:"index" json:"folder_id,omitempty"`
|
||||||
|
BlobSHA256 string `gorm:"size:64;not null" json:"blob_sha256"`
|
||||||
|
UserID *int64 `gorm:"index" json:"user_id,omitempty"`
|
||||||
|
CreatedAt time.Time `json:"created_at"`
|
||||||
|
UpdatedAt time.Time `json:"updated_at"`
|
||||||
|
DeletedAt gorm.DeletedAt `gorm:"index" json:"deleted_at,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (File) TableName() string {
|
||||||
|
return "hpc_files"
|
||||||
|
}
|
||||||
|
|
||||||
|
// Folder represents a directory in the virtual file system.
|
||||||
|
type Folder struct {
|
||||||
|
ID int64 `gorm:"primaryKey;autoIncrement" json:"id"`
|
||||||
|
Name string `gorm:"size:255;not null" json:"name"`
|
||||||
|
ParentID *int64 `gorm:"index" json:"parent_id,omitempty"`
|
||||||
|
Path string `gorm:"uniqueIndex;size:768;not null" json:"path"`
|
||||||
|
UserID *int64 `gorm:"index" json:"user_id,omitempty"`
|
||||||
|
CreatedAt time.Time `json:"created_at"`
|
||||||
|
UpdatedAt time.Time `json:"updated_at"`
|
||||||
|
DeletedAt gorm.DeletedAt `gorm:"index" json:"deleted_at,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (Folder) TableName() string {
|
||||||
|
return "hpc_folders"
|
||||||
|
}
|
||||||
|
|
||||||
|
// UploadSession represents an in-progress chunked upload.
|
||||||
|
// State transitions: pending→uploading, pending→completed(zero-byte), uploading→merging,
|
||||||
|
// uploading→cancelled, merging→completed, merging→failed, any→expired
|
||||||
|
type UploadSession struct {
|
||||||
|
ID int64 `gorm:"primaryKey;autoIncrement" json:"id"`
|
||||||
|
FileName string `gorm:"size:255;not null" json:"file_name"`
|
||||||
|
FileSize int64 `gorm:"not null" json:"file_size"`
|
||||||
|
ChunkSize int64 `gorm:"not null" json:"chunk_size"`
|
||||||
|
TotalChunks int `gorm:"not null" json:"total_chunks"`
|
||||||
|
SHA256 string `gorm:"size:64;not null" json:"sha256"`
|
||||||
|
FolderID *int64 `gorm:"index" json:"folder_id,omitempty"`
|
||||||
|
Status string `gorm:"size:20;not null;default:pending" json:"status"`
|
||||||
|
MinioPrefix string `gorm:"size:255;not null" json:"minio_prefix"`
|
||||||
|
MimeType string `gorm:"size:255;default:'application/octet-stream'" json:"mime_type"`
|
||||||
|
UserID *int64 `gorm:"index" json:"user_id,omitempty"`
|
||||||
|
ExpiresAt time.Time `gorm:"not null" json:"expires_at"`
|
||||||
|
CreatedAt time.Time `json:"created_at"`
|
||||||
|
UpdatedAt time.Time `json:"updated_at"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (UploadSession) TableName() string {
|
||||||
|
return "hpc_upload_sessions"
|
||||||
|
}
|
||||||
|
|
||||||
|
// UploadChunk represents a single chunk of an upload session.
|
||||||
|
type UploadChunk struct {
|
||||||
|
ID int64 `gorm:"primaryKey;autoIncrement" json:"id"`
|
||||||
|
SessionID int64 `gorm:"not null;uniqueIndex:idx_session_chunk" json:"session_id"`
|
||||||
|
ChunkIndex int `gorm:"not null;uniqueIndex:idx_session_chunk" json:"chunk_index"`
|
||||||
|
MinioKey string `gorm:"size:255;not null" json:"minio_key"`
|
||||||
|
SHA256 string `gorm:"size:64" json:"sha256,omitempty"`
|
||||||
|
Size int64 `gorm:"not null" json:"size"`
|
||||||
|
Status string `gorm:"size:20;not null;default:pending" json:"status"`
|
||||||
|
CreatedAt time.Time `json:"created_at"`
|
||||||
|
UpdatedAt time.Time `json:"updated_at"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (UploadChunk) TableName() string {
|
||||||
|
return "hpc_upload_chunks"
|
||||||
|
}
|
||||||
|
|
||||||
|
// InitUploadRequest is the DTO for initiating a chunked upload.
|
||||||
|
type InitUploadRequest struct {
|
||||||
|
FileName string `json:"file_name" binding:"required"`
|
||||||
|
FileSize int64 `json:"file_size" binding:"required"`
|
||||||
|
SHA256 string `json:"sha256" binding:"required"`
|
||||||
|
FolderID *int64 `json:"folder_id,omitempty"`
|
||||||
|
ChunkSize *int64 `json:"chunk_size,omitempty"`
|
||||||
|
MimeType string `json:"mime_type,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// CreateFolderRequest is the DTO for creating a new folder.
|
||||||
|
type CreateFolderRequest struct {
|
||||||
|
Name string `json:"name" binding:"required"`
|
||||||
|
ParentID *int64 `json:"parent_id,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// UploadSessionResponse is the DTO returned when creating/querying an upload session.
|
||||||
|
type UploadSessionResponse struct {
|
||||||
|
ID int64 `json:"id"`
|
||||||
|
FileName string `json:"file_name"`
|
||||||
|
FileSize int64 `json:"file_size"`
|
||||||
|
ChunkSize int64 `json:"chunk_size"`
|
||||||
|
TotalChunks int `json:"total_chunks"`
|
||||||
|
SHA256 string `json:"sha256"`
|
||||||
|
Status string `json:"status"`
|
||||||
|
UploadedChunks []int `json:"uploaded_chunks"`
|
||||||
|
ExpiresAt time.Time `json:"expires_at"`
|
||||||
|
CreatedAt time.Time `json:"created_at"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// FileResponse is the DTO for a file in API responses.
|
||||||
|
type FileResponse struct {
|
||||||
|
ID int64 `json:"id"`
|
||||||
|
Name string `json:"name"`
|
||||||
|
FolderID *int64 `json:"folder_id,omitempty"`
|
||||||
|
Size int64 `json:"size"`
|
||||||
|
MimeType string `json:"mime_type"`
|
||||||
|
SHA256 string `json:"sha256"`
|
||||||
|
CreatedAt time.Time `json:"created_at"`
|
||||||
|
UpdatedAt time.Time `json:"updated_at"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// FolderResponse is the DTO for a folder in API responses.
|
||||||
|
type FolderResponse struct {
|
||||||
|
ID int64 `json:"id"`
|
||||||
|
Name string `json:"name"`
|
||||||
|
ParentID *int64 `json:"parent_id,omitempty"`
|
||||||
|
Path string `json:"path"`
|
||||||
|
FileCount int64 `json:"file_count"`
|
||||||
|
SubFolderCount int64 `json:"subfolder_count"`
|
||||||
|
CreatedAt time.Time `json:"created_at"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// ListFilesResponse is the paginated response for listing files.
|
||||||
|
type ListFilesResponse struct {
|
||||||
|
Files []FileResponse `json:"files"`
|
||||||
|
Total int64 `json:"total"`
|
||||||
|
Page int `json:"page"`
|
||||||
|
PageSize int `json:"page_size"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// ValidateFileName rejects empty, "..", "/", "\", null bytes, control chars, leading/trailing spaces.
|
||||||
|
func ValidateFileName(name string) error {
|
||||||
|
if name == "" {
|
||||||
|
return fmt.Errorf("file name cannot be empty")
|
||||||
|
}
|
||||||
|
if strings.TrimSpace(name) != name {
|
||||||
|
return fmt.Errorf("file name cannot have leading or trailing spaces")
|
||||||
|
}
|
||||||
|
if name == ".." {
|
||||||
|
return fmt.Errorf("file name cannot be '..'")
|
||||||
|
}
|
||||||
|
if strings.Contains(name, "/") || strings.Contains(name, "\\") {
|
||||||
|
return fmt.Errorf("file name cannot contain '/' or '\\'")
|
||||||
|
}
|
||||||
|
for _, r := range name {
|
||||||
|
if r == 0 {
|
||||||
|
return fmt.Errorf("file name cannot contain null bytes")
|
||||||
|
}
|
||||||
|
if unicode.IsControl(r) {
|
||||||
|
return fmt.Errorf("file name cannot contain control characters")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ValidateFolderName rejects same as ValidateFileName plus ".".
|
||||||
|
func ValidateFolderName(name string) error {
|
||||||
|
if name == "." {
|
||||||
|
return fmt.Errorf("folder name cannot be '.'")
|
||||||
|
}
|
||||||
|
return ValidateFileName(name)
|
||||||
|
}
|
||||||
@@ -2,46 +2,95 @@ package model
|
|||||||
|
|
||||||
// SubmitJobRequest is the API request for submitting a job.
|
// SubmitJobRequest is the API request for submitting a job.
|
||||||
type SubmitJobRequest struct {
|
type SubmitJobRequest struct {
|
||||||
Script string `json:"script" binding:"required"`
|
Script string `json:"script"` // 作业脚本内容
|
||||||
Partition string `json:"partition,omitempty"`
|
Partition string `json:"partition,omitempty"` // 提交到的分区
|
||||||
QOS string `json:"qos,omitempty"`
|
QOS string `json:"qos,omitempty"` // 使用的 QOS 策略
|
||||||
CPUs int32 `json:"cpus,omitempty"`
|
CPUs int32 `json:"cpus,omitempty"` // 请求的 CPU 核数
|
||||||
Memory string `json:"memory,omitempty"`
|
Memory string `json:"memory,omitempty"` // 请求的内存大小
|
||||||
TimeLimit string `json:"time_limit,omitempty"`
|
TimeLimit string `json:"time_limit,omitempty"` // 运行时间限制 (分钟)
|
||||||
JobName string `json:"job_name,omitempty"`
|
JobName string `json:"job_name,omitempty"` // 作业名称
|
||||||
Environment map[string]string `json:"environment,omitempty"`
|
Environment map[string]string `json:"environment,omitempty"` // 环境变量键值对
|
||||||
|
WorkDir string `json:"work_dir,omitempty"` // 作业工作目录
|
||||||
}
|
}
|
||||||
|
|
||||||
// JobResponse is the simplified API response for a job.
|
// JobResponse is the API response for a job.
|
||||||
type JobResponse struct {
|
type JobResponse struct {
|
||||||
JobID int32 `json:"job_id"`
|
// Identity
|
||||||
Name string `json:"name"`
|
JobID int32 `json:"job_id"` // Slurm 作业 ID
|
||||||
State []string `json:"job_state"`
|
Name string `json:"name"` // 作业名称
|
||||||
Partition string `json:"partition"`
|
State []string `json:"job_state"` // 作业当前状态 (e.g. ["RUNNING"], ["PENDING","REQUEUED"])
|
||||||
SubmitTime *int64 `json:"submit_time,omitempty"`
|
StateReason string `json:"state_reason,omitempty"` // 作业等待/失败的原因
|
||||||
StartTime *int64 `json:"start_time,omitempty"`
|
|
||||||
EndTime *int64 `json:"end_time,omitempty"`
|
// Scheduling
|
||||||
ExitCode *int32 `json:"exit_code,omitempty"`
|
Partition string `json:"partition"` // 所属分区
|
||||||
Nodes string `json:"nodes,omitempty"`
|
QOS string `json:"qos,omitempty"` // 使用的 QOS 策略
|
||||||
|
Priority *int32 `json:"priority,omitempty"` // 作业优先级
|
||||||
|
TimeLimit string `json:"time_limit,omitempty"` // 运行时间限制 (分钟,"UNLIMITED" 表示无限)
|
||||||
|
|
||||||
|
// Ownership
|
||||||
|
Account string `json:"account,omitempty"` // 计费账户
|
||||||
|
User string `json:"user,omitempty"` // 提交用户
|
||||||
|
Cluster string `json:"cluster,omitempty"` // 所属集群
|
||||||
|
|
||||||
|
// Resources
|
||||||
|
Cpus *int32 `json:"cpus,omitempty"` // 分配/请求的 CPU 核数
|
||||||
|
Tasks *int32 `json:"tasks,omitempty"` // 任务数
|
||||||
|
NodeCount *int32 `json:"node_count,omitempty"` // 节点数
|
||||||
|
Nodes string `json:"nodes,omitempty"` // 分配的节点列表
|
||||||
|
BatchHost string `json:"batch_host,omitempty"` // 批处理主节点
|
||||||
|
|
||||||
|
// Timing (Unix timestamp)
|
||||||
|
SubmitTime *int64 `json:"submit_time,omitempty"` // 提交时间
|
||||||
|
StartTime *int64 `json:"start_time,omitempty"` // 开始运行时间
|
||||||
|
EndTime *int64 `json:"end_time,omitempty"` // 结束/预计结束时间
|
||||||
|
|
||||||
|
// Result
|
||||||
|
ExitCode *int32 `json:"exit_code,omitempty"` // 退出码 (nil 表示未结束)
|
||||||
|
|
||||||
|
// IO Paths
|
||||||
|
StdOut string `json:"standard_output,omitempty"` // 标准输出文件路径
|
||||||
|
StdErr string `json:"standard_error,omitempty"` // 标准错误文件路径
|
||||||
|
StdIn string `json:"standard_input,omitempty"` // 标准输入文件路径
|
||||||
|
WorkDir string `json:"working_directory,omitempty"` // 工作目录
|
||||||
|
Command string `json:"command,omitempty"` // 执行的命令
|
||||||
|
|
||||||
|
// Array Job
|
||||||
|
ArrayJobID *int32 `json:"array_job_id,omitempty"` // 数组作业的父 Job ID
|
||||||
|
ArrayTaskID *int32 `json:"array_task_id,omitempty"` // 数组作业中的子任务 ID
|
||||||
}
|
}
|
||||||
|
|
||||||
// JobListResponse is the paginated response for job listings.
|
// JobListResponse is the paginated response for job listings.
|
||||||
type JobListResponse struct {
|
type JobListResponse struct {
|
||||||
Jobs []JobResponse `json:"jobs"`
|
Jobs []JobResponse `json:"jobs"` // 作业列表
|
||||||
Total int `json:"total"`
|
Total int `json:"total"` // 符合条件的作业总数
|
||||||
Page int `json:"page"`
|
Page int `json:"page"` // 当前页码 (从 1 开始)
|
||||||
PageSize int `json:"page_size"`
|
PageSize int `json:"page_size"` // 每页条数
|
||||||
|
}
|
||||||
|
|
||||||
|
// JobListQuery contains pagination parameters for active job listing.
|
||||||
|
type JobListQuery struct {
|
||||||
|
Page int `form:"page,default=1" json:"page,omitempty"` // 页码 (从 1 开始)
|
||||||
|
PageSize int `form:"page_size,default=20" json:"page_size,omitempty"` // 每页条数
|
||||||
}
|
}
|
||||||
|
|
||||||
// JobHistoryQuery contains query parameters for job history.
|
// JobHistoryQuery contains query parameters for job history.
|
||||||
type JobHistoryQuery struct {
|
type JobHistoryQuery struct {
|
||||||
Users string `form:"users" json:"users,omitempty"`
|
Users string `form:"users" json:"users,omitempty"` // 按用户名过滤 (逗号分隔)
|
||||||
StartTime string `form:"start_time" json:"start_time,omitempty"`
|
StartTime string `form:"start_time" json:"start_time,omitempty"` // 作业开始时间下限 (Unix 时间戳)
|
||||||
EndTime string `form:"end_time" json:"end_time,omitempty"`
|
EndTime string `form:"end_time" json:"end_time,omitempty"` // 作业结束时间上限 (Unix 时间戳)
|
||||||
Account string `form:"account" json:"account,omitempty"`
|
SubmitTime string `form:"submit_time" json:"submit_time,omitempty"` // 作业提交时间过滤 (Unix 时间戳)
|
||||||
Partition string `form:"partition" json:"partition,omitempty"`
|
Account string `form:"account" json:"account,omitempty"` // 按计费账户过滤
|
||||||
State string `form:"state" json:"state,omitempty"`
|
Partition string `form:"partition" json:"partition,omitempty"` // 按分区过滤
|
||||||
JobName string `form:"job_name" json:"job_name,omitempty"`
|
State string `form:"state" json:"state,omitempty"` // 按作业状态过滤 (e.g. "COMPLETED", "FAILED")
|
||||||
Page int `form:"page,default=1" json:"page,omitempty"`
|
JobName string `form:"job_name" json:"job_name,omitempty"` // 按作业名称过滤
|
||||||
PageSize int `form:"page_size,default=20" json:"page_size,omitempty"`
|
Cluster string `form:"cluster" json:"cluster,omitempty"` // 按集群名称过滤
|
||||||
|
Qos string `form:"qos" json:"qos,omitempty"` // 按 QOS 策略过滤
|
||||||
|
Constraints string `form:"constraints" json:"constraints,omitempty"` // 按节点约束过滤
|
||||||
|
ExitCode string `form:"exit_code" json:"exit_code,omitempty"` // 按退出码过滤
|
||||||
|
Node string `form:"node" json:"node,omitempty"` // 按分配节点过滤
|
||||||
|
Reservation string `form:"reservation" json:"reservation,omitempty"` // 按预约名称过滤
|
||||||
|
Groups string `form:"groups" json:"groups,omitempty"` // 按用户组过滤
|
||||||
|
Wckey string `form:"wckey" json:"wckey,omitempty"` // 按 WCKey (Workload Characterization Key) 过滤
|
||||||
|
Page int `form:"page,default=1" json:"page,omitempty"` // 页码 (从 1 开始)
|
||||||
|
PageSize int `form:"page_size,default=20" json:"page_size,omitempty"` // 每页条数
|
||||||
}
|
}
|
||||||
|
|||||||
93
internal/model/task.go
Normal file
93
internal/model/task.go
Normal file
@@ -0,0 +1,93 @@
|
|||||||
|
package model
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"gorm.io/gorm"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Task status constants.
|
||||||
|
const (
|
||||||
|
TaskStatusSubmitted = "submitted"
|
||||||
|
TaskStatusPreparing = "preparing"
|
||||||
|
TaskStatusDownloading = "downloading"
|
||||||
|
TaskStatusReady = "ready"
|
||||||
|
TaskStatusQueued = "queued"
|
||||||
|
TaskStatusRunning = "running"
|
||||||
|
TaskStatusCompleted = "completed"
|
||||||
|
TaskStatusFailed = "failed"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Task step constants for step-level retry tracking.
|
||||||
|
const (
|
||||||
|
TaskStepPreparing = "preparing"
|
||||||
|
TaskStepDownloading = "downloading"
|
||||||
|
TaskStepSubmitting = "submitting"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Task represents an HPC task submitted through the application framework.
|
||||||
|
type Task struct {
|
||||||
|
ID int64 `gorm:"primaryKey;autoIncrement" json:"id"`
|
||||||
|
TaskName string `gorm:"size:255" json:"task_name"`
|
||||||
|
AppID int64 `json:"app_id"`
|
||||||
|
AppName string `gorm:"size:255" json:"app_name"`
|
||||||
|
Status string `json:"status"`
|
||||||
|
CurrentStep string `json:"current_step"`
|
||||||
|
RetryCount int `json:"retry_count"`
|
||||||
|
Values json.RawMessage `gorm:"type:text" json:"values,omitempty"`
|
||||||
|
InputFileIDs json.RawMessage `json:"input_file_ids" gorm:"column:input_file_ids;type:text"`
|
||||||
|
Script string `json:"script,omitempty"`
|
||||||
|
SlurmJobID *int32 `json:"slurm_job_id,omitempty"`
|
||||||
|
WorkDir string `json:"work_dir,omitempty"`
|
||||||
|
Partition string `json:"partition,omitempty"`
|
||||||
|
ErrorMessage string `json:"error_message,omitempty"`
|
||||||
|
UserID string `json:"user_id"`
|
||||||
|
SubmittedAt time.Time `json:"submitted_at"`
|
||||||
|
StartedAt *time.Time `json:"started_at,omitempty"`
|
||||||
|
FinishedAt *time.Time `json:"finished_at,omitempty"`
|
||||||
|
CreatedAt time.Time `json:"created_at"`
|
||||||
|
UpdatedAt time.Time `json:"updated_at"`
|
||||||
|
DeletedAt gorm.DeletedAt `gorm:"index" json:"deleted_at,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (Task) TableName() string {
|
||||||
|
return "hpc_tasks"
|
||||||
|
}
|
||||||
|
|
||||||
|
// CreateTaskRequest is the DTO for creating a new task.
|
||||||
|
type CreateTaskRequest struct {
|
||||||
|
AppID int64 `json:"app_id" binding:"required"`
|
||||||
|
TaskName string `json:"task_name"`
|
||||||
|
Values map[string]string `json:"values"`
|
||||||
|
InputFileIDs []int64 `json:"file_ids"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// TaskResponse is the DTO returned in API responses.
|
||||||
|
type TaskResponse struct {
|
||||||
|
ID int64 `json:"id"`
|
||||||
|
TaskName string `json:"task_name"`
|
||||||
|
AppID int64 `json:"app_id"`
|
||||||
|
AppName string `json:"app_name"`
|
||||||
|
Status string `json:"status"`
|
||||||
|
CurrentStep string `json:"current_step"`
|
||||||
|
RetryCount int `json:"retry_count"`
|
||||||
|
SlurmJobID *int32 `json:"slurm_job_id"`
|
||||||
|
WorkDir string `json:"work_dir"`
|
||||||
|
ErrorMessage string `json:"error_message"`
|
||||||
|
CreatedAt time.Time `json:"created_at"`
|
||||||
|
UpdatedAt time.Time `json:"updated_at"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// TaskListResponse is the paginated response for listing tasks.
|
||||||
|
type TaskListResponse struct {
|
||||||
|
Items []TaskResponse `json:"items"`
|
||||||
|
Total int64 `json:"total"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// TaskListQuery contains query parameters for listing tasks.
|
||||||
|
type TaskListQuery struct {
|
||||||
|
Page int `form:"page" json:"page,omitempty"`
|
||||||
|
PageSize int `form:"page_size" json:"page_size,omitempty"`
|
||||||
|
Status string `form:"status" json:"status,omitempty"`
|
||||||
|
}
|
||||||
104
internal/model/task_test.go
Normal file
104
internal/model/task_test.go
Normal file
@@ -0,0 +1,104 @@
|
|||||||
|
package model
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestTask_TableName(t *testing.T) {
|
||||||
|
task := Task{}
|
||||||
|
if got := task.TableName(); got != "hpc_tasks" {
|
||||||
|
t.Errorf("Task.TableName() = %q, want %q", got, "hpc_tasks")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTask_JSONRoundTrip(t *testing.T) {
|
||||||
|
now := time.Now().UTC().Truncate(time.Second)
|
||||||
|
jobID := int32(42)
|
||||||
|
|
||||||
|
task := Task{
|
||||||
|
ID: 1,
|
||||||
|
TaskName: "test task",
|
||||||
|
AppID: 10,
|
||||||
|
AppName: "GROMACS",
|
||||||
|
Status: TaskStatusRunning,
|
||||||
|
CurrentStep: TaskStepSubmitting,
|
||||||
|
RetryCount: 1,
|
||||||
|
Values: json.RawMessage(`{"np":"4"}`),
|
||||||
|
InputFileIDs: json.RawMessage(`[1,2,3]`),
|
||||||
|
Script: "#!/bin/bash",
|
||||||
|
SlurmJobID: &jobID,
|
||||||
|
WorkDir: "/data/work",
|
||||||
|
Partition: "gpu",
|
||||||
|
ErrorMessage: "",
|
||||||
|
UserID: "user1",
|
||||||
|
SubmittedAt: now,
|
||||||
|
StartedAt: &now,
|
||||||
|
FinishedAt: nil,
|
||||||
|
CreatedAt: now,
|
||||||
|
UpdatedAt: now,
|
||||||
|
}
|
||||||
|
|
||||||
|
data, err := json.Marshal(task)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("marshal task: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var got Task
|
||||||
|
if err := json.Unmarshal(data, &got); err != nil {
|
||||||
|
t.Fatalf("unmarshal task: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if got.ID != task.ID {
|
||||||
|
t.Errorf("ID = %d, want %d", got.ID, task.ID)
|
||||||
|
}
|
||||||
|
if got.TaskName != task.TaskName {
|
||||||
|
t.Errorf("TaskName = %q, want %q", got.TaskName, task.TaskName)
|
||||||
|
}
|
||||||
|
if got.Status != task.Status {
|
||||||
|
t.Errorf("Status = %q, want %q", got.Status, task.Status)
|
||||||
|
}
|
||||||
|
if got.CurrentStep != task.CurrentStep {
|
||||||
|
t.Errorf("CurrentStep = %q, want %q", got.CurrentStep, task.CurrentStep)
|
||||||
|
}
|
||||||
|
if got.RetryCount != task.RetryCount {
|
||||||
|
t.Errorf("RetryCount = %d, want %d", got.RetryCount, task.RetryCount)
|
||||||
|
}
|
||||||
|
if got.SlurmJobID == nil || *got.SlurmJobID != jobID {
|
||||||
|
t.Errorf("SlurmJobID = %v, want %d", got.SlurmJobID, jobID)
|
||||||
|
}
|
||||||
|
if got.UserID != task.UserID {
|
||||||
|
t.Errorf("UserID = %q, want %q", got.UserID, task.UserID)
|
||||||
|
}
|
||||||
|
if string(got.Values) != string(task.Values) {
|
||||||
|
t.Errorf("Values = %s, want %s", got.Values, task.Values)
|
||||||
|
}
|
||||||
|
if string(got.InputFileIDs) != string(task.InputFileIDs) {
|
||||||
|
t.Errorf("InputFileIDs = %s, want %s", got.InputFileIDs, task.InputFileIDs)
|
||||||
|
}
|
||||||
|
if got.FinishedAt != nil {
|
||||||
|
t.Errorf("FinishedAt = %v, want nil", got.FinishedAt)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCreateTaskRequest_JSONBinding(t *testing.T) {
|
||||||
|
payload := `{"app_id":5,"task_name":"my task","values":{"np":"8"},"file_ids":[10,20]}`
|
||||||
|
var req CreateTaskRequest
|
||||||
|
if err := json.Unmarshal([]byte(payload), &req); err != nil {
|
||||||
|
t.Fatalf("unmarshal CreateTaskRequest: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if req.AppID != 5 {
|
||||||
|
t.Errorf("AppID = %d, want 5", req.AppID)
|
||||||
|
}
|
||||||
|
if req.TaskName != "my task" {
|
||||||
|
t.Errorf("TaskName = %q, want %q", req.TaskName, "my task")
|
||||||
|
}
|
||||||
|
if v, ok := req.Values["np"]; !ok || v != "8" {
|
||||||
|
t.Errorf("Values[\"np\"] = %q, want %q", v, "8")
|
||||||
|
}
|
||||||
|
if len(req.InputFileIDs) != 2 || req.InputFileIDs[0] != 10 || req.InputFileIDs[1] != 20 {
|
||||||
|
t.Errorf("InputFileIDs = %v, want [10 20]", req.InputFileIDs)
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -1,45 +0,0 @@
|
|||||||
package model
|
|
||||||
|
|
||||||
import "time"
|
|
||||||
|
|
||||||
// JobTemplate represents a saved job template.
|
|
||||||
type JobTemplate struct {
|
|
||||||
ID int64 `json:"id" gorm:"primaryKey;autoIncrement"`
|
|
||||||
Name string `json:"name" gorm:"uniqueIndex;size:255;not null"`
|
|
||||||
Description string `json:"description,omitempty" gorm:"type:text"`
|
|
||||||
Script string `json:"script" gorm:"type:text;not null"`
|
|
||||||
Partition string `json:"partition,omitempty" gorm:"size:255"`
|
|
||||||
QOS string `json:"qos,omitempty" gorm:"column:qos;size:255"`
|
|
||||||
CPUs int `json:"cpus,omitempty" gorm:"column:cpus"`
|
|
||||||
Memory string `json:"memory,omitempty" gorm:"size:50"`
|
|
||||||
TimeLimit string `json:"time_limit,omitempty" gorm:"column:time_limit;size:50"`
|
|
||||||
CreatedAt time.Time `json:"created_at"`
|
|
||||||
UpdatedAt time.Time `json:"updated_at"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// TableName specifies the database table name for GORM.
|
|
||||||
func (JobTemplate) TableName() string { return "job_templates" }
|
|
||||||
|
|
||||||
// CreateTemplateRequest is the API request for creating a template.
|
|
||||||
type CreateTemplateRequest struct {
|
|
||||||
Name string `json:"name" binding:"required"`
|
|
||||||
Description string `json:"description,omitempty"`
|
|
||||||
Script string `json:"script" binding:"required"`
|
|
||||||
Partition string `json:"partition,omitempty"`
|
|
||||||
QOS string `json:"qos,omitempty"`
|
|
||||||
CPUs int `json:"cpus,omitempty"`
|
|
||||||
Memory string `json:"memory,omitempty"`
|
|
||||||
TimeLimit string `json:"time_limit,omitempty"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// UpdateTemplateRequest is the API request for updating a template.
|
|
||||||
type UpdateTemplateRequest struct {
|
|
||||||
Name string `json:"name,omitempty"`
|
|
||||||
Description string `json:"description,omitempty"`
|
|
||||||
Script string `json:"script,omitempty"`
|
|
||||||
Partition string `json:"partition,omitempty"`
|
|
||||||
QOS string `json:"qos,omitempty"`
|
|
||||||
CPUs int `json:"cpus,omitempty"`
|
|
||||||
Memory string `json:"memory,omitempty"`
|
|
||||||
TimeLimit string `json:"time_limit,omitempty"`
|
|
||||||
}
|
|
||||||
@@ -1,7 +1,11 @@
|
|||||||
package server
|
package server
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
)
|
)
|
||||||
@@ -42,3 +46,97 @@ func InternalError(c *gin.Context, msg string) {
|
|||||||
func ErrorWithStatus(c *gin.Context, code int, msg string) {
|
func ErrorWithStatus(c *gin.Context, code int, msg string) {
|
||||||
c.JSON(code, APIResponse{Success: false, Error: msg})
|
c.JSON(code, APIResponse{Success: false, Error: msg})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ParseRange parses an HTTP Range header (RFC 7233).
|
||||||
|
// Only single-part ranges are supported: bytes=start-end, bytes=start-, bytes=-suffix.
|
||||||
|
// Multi-part ranges (bytes=0-100,200-300) return an error.
|
||||||
|
func ParseRange(rangeHeader string, fileSize int64) (start, end int64, err error) {
|
||||||
|
if rangeHeader == "" {
|
||||||
|
return 0, 0, fmt.Errorf("empty range header")
|
||||||
|
}
|
||||||
|
|
||||||
|
if !strings.HasPrefix(rangeHeader, "bytes=") {
|
||||||
|
return 0, 0, fmt.Errorf("invalid range unit: %s", rangeHeader)
|
||||||
|
}
|
||||||
|
|
||||||
|
rangeSpec := strings.TrimPrefix(rangeHeader, "bytes=")
|
||||||
|
|
||||||
|
if strings.Contains(rangeSpec, ",") {
|
||||||
|
return 0, 0, fmt.Errorf("multi-part ranges are not supported")
|
||||||
|
}
|
||||||
|
|
||||||
|
rangeSpec = strings.TrimSpace(rangeSpec)
|
||||||
|
parts := strings.Split(rangeSpec, "-")
|
||||||
|
if len(parts) != 2 {
|
||||||
|
return 0, 0, fmt.Errorf("invalid range format: %s", rangeSpec)
|
||||||
|
}
|
||||||
|
|
||||||
|
if parts[0] == "" {
|
||||||
|
suffix, parseErr := strconv.ParseInt(parts[1], 10, 64)
|
||||||
|
if parseErr != nil {
|
||||||
|
return 0, 0, fmt.Errorf("invalid suffix range: %s", parts[1])
|
||||||
|
}
|
||||||
|
if suffix <= 0 || suffix > fileSize {
|
||||||
|
return 0, 0, fmt.Errorf("suffix range %d exceeds file size %d", suffix, fileSize)
|
||||||
|
}
|
||||||
|
start = fileSize - suffix
|
||||||
|
end = fileSize - 1
|
||||||
|
} else if parts[1] == "" {
|
||||||
|
start, err = strconv.ParseInt(parts[0], 10, 64)
|
||||||
|
if err != nil {
|
||||||
|
return 0, 0, fmt.Errorf("invalid range start: %s", parts[0])
|
||||||
|
}
|
||||||
|
if start >= fileSize {
|
||||||
|
return 0, 0, fmt.Errorf("range start %d exceeds file size %d", start, fileSize)
|
||||||
|
}
|
||||||
|
end = fileSize - 1
|
||||||
|
} else {
|
||||||
|
start, err = strconv.ParseInt(parts[0], 10, 64)
|
||||||
|
if err != nil {
|
||||||
|
return 0, 0, fmt.Errorf("invalid range start: %s", parts[0])
|
||||||
|
}
|
||||||
|
end, err = strconv.ParseInt(parts[1], 10, 64)
|
||||||
|
if err != nil {
|
||||||
|
return 0, 0, fmt.Errorf("invalid range end: %s", parts[1])
|
||||||
|
}
|
||||||
|
if start > end {
|
||||||
|
return 0, 0, fmt.Errorf("range start %d > end %d", start, end)
|
||||||
|
}
|
||||||
|
if start >= fileSize {
|
||||||
|
return 0, 0, fmt.Errorf("range start %d exceeds file size %d", start, fileSize)
|
||||||
|
}
|
||||||
|
if end >= fileSize {
|
||||||
|
end = fileSize - 1
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return start, end, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// StreamFile sends a full file as an HTTP response with proper headers.
|
||||||
|
func StreamFile(c *gin.Context, reader io.ReadCloser, filename string, fileSize int64, contentType string) {
|
||||||
|
defer reader.Close()
|
||||||
|
|
||||||
|
c.Header("Content-Disposition", fmt.Sprintf(`attachment; filename="%s"`, filename))
|
||||||
|
c.Header("Content-Type", contentType)
|
||||||
|
c.Header("Content-Length", strconv.FormatInt(fileSize, 10))
|
||||||
|
c.Header("Accept-Ranges", "bytes")
|
||||||
|
|
||||||
|
c.Status(http.StatusOK)
|
||||||
|
io.Copy(c.Writer, reader)
|
||||||
|
}
|
||||||
|
|
||||||
|
// StreamRange sends a partial content response (206) for a byte range.
|
||||||
|
func StreamRange(c *gin.Context, reader io.ReadCloser, start, end, totalSize int64, contentType string) {
|
||||||
|
defer reader.Close()
|
||||||
|
|
||||||
|
contentLength := end - start + 1
|
||||||
|
|
||||||
|
c.Header("Content-Range", fmt.Sprintf("bytes %d-%d/%d", start, end, totalSize))
|
||||||
|
c.Header("Content-Type", contentType)
|
||||||
|
c.Header("Content-Length", strconv.FormatInt(contentLength, 10))
|
||||||
|
c.Header("Accept-Ranges", "bytes")
|
||||||
|
|
||||||
|
c.Status(http.StatusPartialContent)
|
||||||
|
io.Copy(c.Writer, reader)
|
||||||
|
}
|
||||||
|
|||||||
@@ -114,3 +114,65 @@ func TestErrorWithStatus(t *testing.T) {
|
|||||||
t.Fatalf("expected error 'already exists', got '%s'", resp.Error)
|
t.Fatalf("expected error 'already exists', got '%s'", resp.Error)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestParseRangeStandard(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
rangeHeader string
|
||||||
|
fileSize int64
|
||||||
|
wantStart int64
|
||||||
|
wantEnd int64
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
{"bytes=0-1023", 10000, 0, 1023, false},
|
||||||
|
{"bytes=1024-", 10000, 1024, 9999, false},
|
||||||
|
{"bytes=-1024", 10000, 8976, 9999, false},
|
||||||
|
{"bytes=0-0", 10000, 0, 0, false},
|
||||||
|
{"bytes=9999-", 10000, 9999, 9999, false},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
start, end, err := ParseRange(tt.rangeHeader, tt.fileSize)
|
||||||
|
if (err != nil) != tt.wantErr {
|
||||||
|
t.Errorf("ParseRange(%q, %d) error = %v, wantErr %v", tt.rangeHeader, tt.fileSize, err, tt.wantErr)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if !tt.wantErr {
|
||||||
|
if start != tt.wantStart || end != tt.wantEnd {
|
||||||
|
t.Errorf("ParseRange(%q, %d) = (%d, %d), want (%d, %d)", tt.rangeHeader, tt.fileSize, start, end, tt.wantStart, tt.wantEnd)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseRangeInvalidAndMultiPart(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
rangeHeader string
|
||||||
|
fileSize int64
|
||||||
|
}{
|
||||||
|
{"", 10000},
|
||||||
|
{"bytes=9999-0", 10000},
|
||||||
|
{"bytes=20000-", 10000},
|
||||||
|
{"bytes=0-100,200-300", 10000},
|
||||||
|
{"bytes=0-100, 400-500", 10000},
|
||||||
|
{"bytes=", 10000},
|
||||||
|
{"chars=0-100", 10000},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
_, _, err := ParseRange(tt.rangeHeader, tt.fileSize)
|
||||||
|
if err == nil {
|
||||||
|
t.Errorf("ParseRange(%q, %d) expected error, got nil", tt.rangeHeader, tt.fileSize)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseRangeEdgeCases(t *testing.T) {
|
||||||
|
start, end, err := ParseRange("bytes=0-99999", 10000)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
if end != 9999 {
|
||||||
|
t.Errorf("end = %d, want 9999 (clamped to fileSize-1)", end)
|
||||||
|
}
|
||||||
|
if start != 0 {
|
||||||
|
t.Errorf("start = %d, want 0", start)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -25,16 +25,44 @@ type ClusterHandler interface {
|
|||||||
GetDiag(c *gin.Context)
|
GetDiag(c *gin.Context)
|
||||||
}
|
}
|
||||||
|
|
||||||
type TemplateHandler interface {
|
type ApplicationHandler interface {
|
||||||
ListTemplates(c *gin.Context)
|
ListApplications(c *gin.Context)
|
||||||
CreateTemplate(c *gin.Context)
|
CreateApplication(c *gin.Context)
|
||||||
GetTemplate(c *gin.Context)
|
GetApplication(c *gin.Context)
|
||||||
UpdateTemplate(c *gin.Context)
|
UpdateApplication(c *gin.Context)
|
||||||
DeleteTemplate(c *gin.Context)
|
DeleteApplication(c *gin.Context)
|
||||||
|
// SubmitApplication(c *gin.Context) // [已禁用] 已被 POST /tasks 取代
|
||||||
|
}
|
||||||
|
|
||||||
|
type UploadHandler interface {
|
||||||
|
InitUpload(c *gin.Context)
|
||||||
|
GetUploadStatus(c *gin.Context)
|
||||||
|
UploadChunk(c *gin.Context)
|
||||||
|
CompleteUpload(c *gin.Context)
|
||||||
|
CancelUpload(c *gin.Context)
|
||||||
|
}
|
||||||
|
|
||||||
|
type FileHandler interface {
|
||||||
|
ListFiles(c *gin.Context)
|
||||||
|
GetFile(c *gin.Context)
|
||||||
|
DownloadFile(c *gin.Context)
|
||||||
|
DeleteFile(c *gin.Context)
|
||||||
|
}
|
||||||
|
|
||||||
|
type FolderHandler interface {
|
||||||
|
CreateFolder(c *gin.Context)
|
||||||
|
GetFolder(c *gin.Context)
|
||||||
|
ListFolders(c *gin.Context)
|
||||||
|
DeleteFolder(c *gin.Context)
|
||||||
|
}
|
||||||
|
|
||||||
|
type TaskHandler interface {
|
||||||
|
CreateTask(c *gin.Context)
|
||||||
|
ListTasks(c *gin.Context)
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewRouter creates a Gin engine with all API v1 routes registered with real handlers.
|
// NewRouter creates a Gin engine with all API v1 routes registered with real handlers.
|
||||||
func NewRouter(jobH JobHandler, clusterH ClusterHandler, templateH TemplateHandler, logger *zap.Logger) *gin.Engine {
|
func NewRouter(jobH JobHandler, clusterH ClusterHandler, appH ApplicationHandler, uploadH UploadHandler, fileH FileHandler, folderH FolderHandler, taskH TaskHandler, logger *zap.Logger) *gin.Engine {
|
||||||
gin.SetMode(gin.ReleaseMode)
|
gin.SetMode(gin.ReleaseMode)
|
||||||
r := gin.New()
|
r := gin.New()
|
||||||
r.Use(gin.Recovery())
|
r.Use(gin.Recovery())
|
||||||
@@ -59,12 +87,47 @@ func NewRouter(jobH JobHandler, clusterH ClusterHandler, templateH TemplateHandl
|
|||||||
|
|
||||||
v1.GET("/diag", clusterH.GetDiag)
|
v1.GET("/diag", clusterH.GetDiag)
|
||||||
|
|
||||||
templates := v1.Group("/templates")
|
apps := v1.Group("/applications")
|
||||||
templates.GET("", templateH.ListTemplates)
|
apps.GET("", appH.ListApplications)
|
||||||
templates.POST("", templateH.CreateTemplate)
|
apps.POST("", appH.CreateApplication)
|
||||||
templates.GET("/:id", templateH.GetTemplate)
|
apps.GET("/:id", appH.GetApplication)
|
||||||
templates.PUT("/:id", templateH.UpdateTemplate)
|
apps.PUT("/:id", appH.UpdateApplication)
|
||||||
templates.DELETE("/:id", templateH.DeleteTemplate)
|
apps.DELETE("/:id", appH.DeleteApplication)
|
||||||
|
// apps.POST("/:id/submit", appH.SubmitApplication) // [已禁用] 已被 POST /tasks 取代
|
||||||
|
|
||||||
|
files := v1.Group("/files")
|
||||||
|
|
||||||
|
if uploadH != nil {
|
||||||
|
uploads := files.Group("/uploads")
|
||||||
|
uploads.POST("", uploadH.InitUpload)
|
||||||
|
uploads.GET("/:id", uploadH.GetUploadStatus)
|
||||||
|
uploads.PUT("/:id/chunks/:index", uploadH.UploadChunk)
|
||||||
|
uploads.POST("/:id/complete", uploadH.CompleteUpload)
|
||||||
|
uploads.DELETE("/:id", uploadH.CancelUpload)
|
||||||
|
}
|
||||||
|
|
||||||
|
if fileH != nil {
|
||||||
|
files.GET("", fileH.ListFiles)
|
||||||
|
files.GET("/:id", fileH.GetFile)
|
||||||
|
files.GET("/:id/download", fileH.DownloadFile)
|
||||||
|
files.DELETE("/:id", fileH.DeleteFile)
|
||||||
|
}
|
||||||
|
|
||||||
|
if folderH != nil {
|
||||||
|
folders := files.Group("/folders")
|
||||||
|
folders.POST("", folderH.CreateFolder)
|
||||||
|
folders.GET("", folderH.ListFolders)
|
||||||
|
folders.GET("/:id", folderH.GetFolder)
|
||||||
|
folders.DELETE("/:id", folderH.DeleteFolder)
|
||||||
|
}
|
||||||
|
|
||||||
|
if taskH != nil {
|
||||||
|
tasks := v1.Group("/tasks")
|
||||||
|
{
|
||||||
|
tasks.POST("", taskH.CreateTask)
|
||||||
|
tasks.GET("", taskH.ListTasks)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
return r
|
return r
|
||||||
}
|
}
|
||||||
@@ -95,12 +158,36 @@ func registerPlaceholderRoutes(v1 *gin.RouterGroup) {
|
|||||||
|
|
||||||
v1.GET("/diag", notImplemented)
|
v1.GET("/diag", notImplemented)
|
||||||
|
|
||||||
templates := v1.Group("/templates")
|
apps := v1.Group("/applications")
|
||||||
templates.GET("", notImplemented)
|
apps.GET("", notImplemented)
|
||||||
templates.POST("", notImplemented)
|
apps.POST("", notImplemented)
|
||||||
templates.GET("/:id", notImplemented)
|
apps.GET("/:id", notImplemented)
|
||||||
templates.PUT("/:id", notImplemented)
|
apps.PUT("/:id", notImplemented)
|
||||||
templates.DELETE("/:id", notImplemented)
|
apps.DELETE("/:id", notImplemented)
|
||||||
|
// apps.POST("/:id/submit", notImplemented) // [已禁用] 已被 POST /tasks 取代
|
||||||
|
|
||||||
|
files := v1.Group("/files")
|
||||||
|
|
||||||
|
uploads := files.Group("/uploads")
|
||||||
|
uploads.POST("", notImplemented)
|
||||||
|
uploads.GET("/:id", notImplemented)
|
||||||
|
uploads.PUT("/:id/chunks/:index", notImplemented)
|
||||||
|
uploads.POST("/:id/complete", notImplemented)
|
||||||
|
uploads.DELETE("/:id", notImplemented)
|
||||||
|
|
||||||
|
files.GET("", notImplemented)
|
||||||
|
files.GET("/:id", notImplemented)
|
||||||
|
files.GET("/:id/download", notImplemented)
|
||||||
|
files.DELETE("/:id", notImplemented)
|
||||||
|
|
||||||
|
folders := files.Group("/folders")
|
||||||
|
folders.POST("", notImplemented)
|
||||||
|
folders.GET("", notImplemented)
|
||||||
|
folders.GET("/:id", notImplemented)
|
||||||
|
folders.DELETE("/:id", notImplemented)
|
||||||
|
|
||||||
|
v1.POST("/tasks", notImplemented)
|
||||||
|
v1.GET("/tasks", notImplemented)
|
||||||
}
|
}
|
||||||
|
|
||||||
func notImplemented(c *gin.Context) {
|
func notImplemented(c *gin.Context) {
|
||||||
|
|||||||
@@ -27,11 +27,12 @@ func TestAllRoutesRegistered(t *testing.T) {
|
|||||||
{"GET", "/api/v1/partitions"},
|
{"GET", "/api/v1/partitions"},
|
||||||
{"GET", "/api/v1/partitions/:name"},
|
{"GET", "/api/v1/partitions/:name"},
|
||||||
{"GET", "/api/v1/diag"},
|
{"GET", "/api/v1/diag"},
|
||||||
{"GET", "/api/v1/templates"},
|
{"GET", "/api/v1/applications"},
|
||||||
{"POST", "/api/v1/templates"},
|
{"POST", "/api/v1/applications"},
|
||||||
{"GET", "/api/v1/templates/:id"},
|
{"GET", "/api/v1/applications/:id"},
|
||||||
{"PUT", "/api/v1/templates/:id"},
|
{"PUT", "/api/v1/applications/:id"},
|
||||||
{"DELETE", "/api/v1/templates/:id"},
|
{"DELETE", "/api/v1/applications/:id"},
|
||||||
|
// {"POST", "/api/v1/applications/:id/submit"}, // [已禁用] 已被 POST /tasks 取代
|
||||||
}
|
}
|
||||||
|
|
||||||
routeMap := map[string]bool{}
|
routeMap := map[string]bool{}
|
||||||
@@ -74,7 +75,7 @@ func TestRegisteredPathReturns501(t *testing.T) {
|
|||||||
{"GET", "/api/v1/nodes"},
|
{"GET", "/api/v1/nodes"},
|
||||||
{"GET", "/api/v1/partitions"},
|
{"GET", "/api/v1/partitions"},
|
||||||
{"GET", "/api/v1/diag"},
|
{"GET", "/api/v1/diag"},
|
||||||
{"GET", "/api/v1/templates"},
|
{"GET", "/api/v1/applications"},
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, ep := range endpoints {
|
for _, ep := range endpoints {
|
||||||
|
|||||||
111
internal/service/application_service.go
Normal file
111
internal/service/application_service.go
Normal file
@@ -0,0 +1,111 @@
|
|||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
|
||||||
|
"gcy_hpc_server/internal/model"
|
||||||
|
"gcy_hpc_server/internal/store"
|
||||||
|
|
||||||
|
"go.uber.org/zap"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ApplicationService handles parameter validation, script rendering, and job
|
||||||
|
// submission for parameterized HPC applications.
|
||||||
|
type ApplicationService struct {
|
||||||
|
store *store.ApplicationStore
|
||||||
|
jobSvc *JobService
|
||||||
|
workDirBase string
|
||||||
|
logger *zap.Logger
|
||||||
|
taskSvc *TaskService
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewApplicationService(store *store.ApplicationStore, jobSvc *JobService, workDirBase string, logger *zap.Logger, taskSvc ...*TaskService) *ApplicationService {
|
||||||
|
var ts *TaskService
|
||||||
|
if len(taskSvc) > 0 {
|
||||||
|
ts = taskSvc[0]
|
||||||
|
}
|
||||||
|
return &ApplicationService{store: store, jobSvc: jobSvc, workDirBase: workDirBase, logger: logger, taskSvc: ts}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ListApplications delegates to the store.
|
||||||
|
func (s *ApplicationService) ListApplications(ctx context.Context, page, pageSize int) ([]model.Application, int, error) {
|
||||||
|
return s.store.List(ctx, page, pageSize)
|
||||||
|
}
|
||||||
|
|
||||||
|
// CreateApplication delegates to the store.
|
||||||
|
func (s *ApplicationService) CreateApplication(ctx context.Context, req *model.CreateApplicationRequest) (int64, error) {
|
||||||
|
return s.store.Create(ctx, req)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetApplication delegates to the store.
|
||||||
|
func (s *ApplicationService) GetApplication(ctx context.Context, id int64) (*model.Application, error) {
|
||||||
|
return s.store.GetByID(ctx, id)
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateApplication delegates to the store.
|
||||||
|
func (s *ApplicationService) UpdateApplication(ctx context.Context, id int64, req *model.UpdateApplicationRequest) error {
|
||||||
|
return s.store.Update(ctx, id, req)
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeleteApplication delegates to the store.
|
||||||
|
func (s *ApplicationService) DeleteApplication(ctx context.Context, id int64) error {
|
||||||
|
return s.store.Delete(ctx, id)
|
||||||
|
}
|
||||||
|
|
||||||
|
// [已禁用] 前端已全部迁移到 POST /tasks 接口,此方法不再被调用。
|
||||||
|
/* // SubmitFromApplication orchestrates the full submission flow.
|
||||||
|
// When TaskService is available, it delegates to ProcessTaskSync which creates
|
||||||
|
// an hpc_tasks record and runs the full pipeline. Otherwise falls back to the
|
||||||
|
// original direct implementation.
|
||||||
|
func (s *ApplicationService) SubmitFromApplication(ctx context.Context, applicationID int64, values map[string]string) (*model.JobResponse, error) {
|
||||||
|
// [已禁用] 旧的直接提交路径,已被 TaskService 管道取代。生产环境中 taskSvc 始终非 nil,此分支不会执行。
|
||||||
|
// if s.taskSvc != nil {
|
||||||
|
req := &model.CreateTaskRequest{
|
||||||
|
AppID: applicationID,
|
||||||
|
Values: values,
|
||||||
|
InputFileIDs: nil, // old API has no file_ids concept
|
||||||
|
TaskName: "",
|
||||||
|
}
|
||||||
|
return s.taskSvc.ProcessTaskSync(ctx, req)
|
||||||
|
// }
|
||||||
|
|
||||||
|
// // Fallback: original direct logic when TaskService not available
|
||||||
|
// app, err := s.store.GetByID(ctx, applicationID)
|
||||||
|
// if err != nil {
|
||||||
|
// return nil, fmt.Errorf("get application: %w", err)
|
||||||
|
// }
|
||||||
|
// if app == nil {
|
||||||
|
// return nil, fmt.Errorf("application %d not found", applicationID)
|
||||||
|
// }
|
||||||
|
//
|
||||||
|
// var params []model.ParameterSchema
|
||||||
|
// if len(app.Parameters) > 0 {
|
||||||
|
// if err := json.Unmarshal(app.Parameters, ¶ms); err != nil {
|
||||||
|
// return nil, fmt.Errorf("parse parameters: %w", err)
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
//
|
||||||
|
// if err := ValidateParams(params, values); err != nil {
|
||||||
|
// return nil, err
|
||||||
|
// }
|
||||||
|
//
|
||||||
|
// rendered := RenderScript(app.ScriptTemplate, params, values)
|
||||||
|
//
|
||||||
|
// workDir := ""
|
||||||
|
// if s.workDirBase != "" {
|
||||||
|
// safeName := SanitizeDirName(app.Name)
|
||||||
|
// subDir := time.Now().Format("20060102_150405") + "_" + RandomSuffix(4)
|
||||||
|
// workDir = filepath.Join(s.workDirBase, safeName, subDir)
|
||||||
|
// if err := os.MkdirAll(workDir, 0777); err != nil {
|
||||||
|
// return nil, fmt.Errorf("create work directory %s: %w", workDir, err)
|
||||||
|
// }
|
||||||
|
// // 绕过 umask,确保整条路径都有写权限
|
||||||
|
// for dir := workDir; dir != s.workDirBase; dir = filepath.Dir(dir) {
|
||||||
|
// os.Chmod(dir, 0777)
|
||||||
|
// }
|
||||||
|
// os.Chmod(s.workDirBase, 0777)
|
||||||
|
// }
|
||||||
|
//
|
||||||
|
// req := &model.SubmitJobRequest{Script: rendered, WorkDir: workDir}
|
||||||
|
// return s.jobSvc.SubmitJob(ctx, req)
|
||||||
|
} */
|
||||||
367
internal/service/application_service_test.go
Normal file
367
internal/service/application_service_test.go
Normal file
@@ -0,0 +1,367 @@
|
|||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"gcy_hpc_server/internal/model"
|
||||||
|
"gcy_hpc_server/internal/slurm"
|
||||||
|
"gcy_hpc_server/internal/store"
|
||||||
|
|
||||||
|
"go.uber.org/zap"
|
||||||
|
gormlogger "gorm.io/gorm/logger"
|
||||||
|
|
||||||
|
"gorm.io/driver/sqlite"
|
||||||
|
"gorm.io/gorm"
|
||||||
|
)
|
||||||
|
|
||||||
|
func setupApplicationService(t *testing.T, slurmHandler http.HandlerFunc) (*ApplicationService, func()) {
|
||||||
|
t.Helper()
|
||||||
|
srv := httptest.NewServer(slurmHandler)
|
||||||
|
client, _ := slurm.NewClient(srv.URL, srv.Client())
|
||||||
|
|
||||||
|
db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{
|
||||||
|
Logger: gormlogger.Default.LogMode(gormlogger.Silent),
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("open sqlite: %v", err)
|
||||||
|
}
|
||||||
|
if err := db.AutoMigrate(&model.Application{}); err != nil {
|
||||||
|
t.Fatalf("auto migrate: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
jobSvc := NewJobService(client, zap.NewNop())
|
||||||
|
appStore := store.NewApplicationStore(db)
|
||||||
|
appSvc := NewApplicationService(appStore, jobSvc, "", zap.NewNop())
|
||||||
|
|
||||||
|
return appSvc, srv.Close
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidateParams_AllRequired(t *testing.T) {
|
||||||
|
params := []model.ParameterSchema{
|
||||||
|
{Name: "NAME", Type: model.ParamTypeString, Required: true},
|
||||||
|
{Name: "COUNT", Type: model.ParamTypeInteger, Required: true},
|
||||||
|
}
|
||||||
|
values := map[string]string{"NAME": "hello", "COUNT": "5"}
|
||||||
|
if err := ValidateParams(params, values); err != nil {
|
||||||
|
t.Errorf("expected no error, got %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidateParams_MissingRequired(t *testing.T) {
|
||||||
|
params := []model.ParameterSchema{
|
||||||
|
{Name: "NAME", Type: model.ParamTypeString, Required: true},
|
||||||
|
}
|
||||||
|
values := map[string]string{}
|
||||||
|
err := ValidateParams(params, values)
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected error for missing required param")
|
||||||
|
}
|
||||||
|
if !strings.Contains(err.Error(), "NAME") {
|
||||||
|
t.Errorf("error should mention param name, got: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidateParams_InvalidInteger(t *testing.T) {
|
||||||
|
params := []model.ParameterSchema{
|
||||||
|
{Name: "COUNT", Type: model.ParamTypeInteger, Required: true},
|
||||||
|
}
|
||||||
|
values := map[string]string{"COUNT": "abc"}
|
||||||
|
err := ValidateParams(params, values)
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected error for invalid integer")
|
||||||
|
}
|
||||||
|
if !strings.Contains(err.Error(), "integer") {
|
||||||
|
t.Errorf("error should mention integer, got: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidateParams_InvalidEnum(t *testing.T) {
|
||||||
|
params := []model.ParameterSchema{
|
||||||
|
{Name: "MODE", Type: model.ParamTypeEnum, Required: true, Options: []string{"fast", "slow"}},
|
||||||
|
}
|
||||||
|
values := map[string]string{"MODE": "medium"}
|
||||||
|
err := ValidateParams(params, values)
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected error for invalid enum value")
|
||||||
|
}
|
||||||
|
if !strings.Contains(err.Error(), "MODE") {
|
||||||
|
t.Errorf("error should mention param name, got: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidateParams_BooleanValues(t *testing.T) {
|
||||||
|
params := []model.ParameterSchema{
|
||||||
|
{Name: "FLAG", Type: model.ParamTypeBoolean, Required: true},
|
||||||
|
}
|
||||||
|
for _, val := range []string{"true", "false", "1", "0"} {
|
||||||
|
err := ValidateParams(params, map[string]string{"FLAG": val})
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("boolean value %q should be valid, got error: %v", val, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRenderScript_SimpleReplacement(t *testing.T) {
|
||||||
|
params := []model.ParameterSchema{{Name: "INPUT", Type: model.ParamTypeString}}
|
||||||
|
values := map[string]string{"INPUT": "data.txt"}
|
||||||
|
result := RenderScript("echo $INPUT", params, values)
|
||||||
|
expected := "echo 'data.txt'"
|
||||||
|
if result != expected {
|
||||||
|
t.Errorf("got %q, want %q", result, expected)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRenderScript_DefaultValues(t *testing.T) {
|
||||||
|
params := []model.ParameterSchema{{Name: "OUTPUT", Type: model.ParamTypeString, Default: "out.log"}}
|
||||||
|
values := map[string]string{}
|
||||||
|
result := RenderScript("cat $OUTPUT", params, values)
|
||||||
|
expected := "cat 'out.log'"
|
||||||
|
if result != expected {
|
||||||
|
t.Errorf("got %q, want %q", result, expected)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRenderScript_PreservesUnknownVars(t *testing.T) {
|
||||||
|
params := []model.ParameterSchema{{Name: "INPUT", Type: model.ParamTypeString}}
|
||||||
|
values := map[string]string{"INPUT": "data.txt"}
|
||||||
|
result := RenderScript("export HOME=$HOME\necho $INPUT\necho $PATH", params, values)
|
||||||
|
if !strings.Contains(result, "$HOME") {
|
||||||
|
t.Error("$HOME should be preserved")
|
||||||
|
}
|
||||||
|
if !strings.Contains(result, "$PATH") {
|
||||||
|
t.Error("$PATH should be preserved")
|
||||||
|
}
|
||||||
|
if !strings.Contains(result, "'data.txt'") {
|
||||||
|
t.Error("$INPUT should be replaced")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRenderScript_ShellEscaping(t *testing.T) {
|
||||||
|
params := []model.ParameterSchema{{Name: "INPUT", Type: model.ParamTypeString}}
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
value string
|
||||||
|
expected string
|
||||||
|
}{
|
||||||
|
{"semicolon injection", "; rm -rf /", "'; rm -rf /'"},
|
||||||
|
{"command substitution", "$(cat /etc/passwd)", "'$(cat /etc/passwd)'"},
|
||||||
|
{"single quote", "hello'world", "'hello'\\''world'"},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
result := RenderScript("$INPUT", params, map[string]string{"INPUT": tt.value})
|
||||||
|
if result != tt.expected {
|
||||||
|
t.Errorf("got %q, want %q", result, tt.expected)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRenderScript_OverlappingParams(t *testing.T) {
|
||||||
|
template := "$JOB_NAME and $JOB"
|
||||||
|
params := []model.ParameterSchema{
|
||||||
|
{Name: "JOB", Type: model.ParamTypeString},
|
||||||
|
{Name: "JOB_NAME", Type: model.ParamTypeString},
|
||||||
|
}
|
||||||
|
values := map[string]string{"JOB": "myjob", "JOB_NAME": "my-test-job"}
|
||||||
|
result := RenderScript(template, params, values)
|
||||||
|
if strings.Contains(result, "$JOB_NAME") {
|
||||||
|
t.Error("$JOB_NAME was not replaced")
|
||||||
|
}
|
||||||
|
if strings.Contains(result, "$JOB") {
|
||||||
|
t.Error("$JOB was not replaced")
|
||||||
|
}
|
||||||
|
if !strings.Contains(result, "'my-test-job'") {
|
||||||
|
t.Errorf("expected 'my-test-job' in result, got: %s", result)
|
||||||
|
}
|
||||||
|
if !strings.Contains(result, "'myjob'") {
|
||||||
|
t.Errorf("expected 'myjob' in result, got: %s", result)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRenderScript_NewlineInValue(t *testing.T) {
|
||||||
|
params := []model.ParameterSchema{{Name: "CMD", Type: model.ParamTypeString}}
|
||||||
|
values := map[string]string{"CMD": "line1\nline2"}
|
||||||
|
result := RenderScript("echo $CMD", params, values)
|
||||||
|
expected := "echo 'line1\nline2'"
|
||||||
|
if result != expected {
|
||||||
|
t.Errorf("got %q, want %q", result, expected)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// [已禁用] 测试的是旧的直接提交路径,该路径已被注释掉
|
||||||
|
/*
|
||||||
|
func TestSubmitFromApplication_Success(t *testing.T) {
|
||||||
|
jobID := int32(42)
|
||||||
|
appSvc, cleanup := setupApplicationService(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
json.NewEncoder(w).Encode(slurm.OpenapiJobSubmitResponse{
|
||||||
|
Result: &slurm.JobSubmitResponseMsg{JobID: &jobID},
|
||||||
|
})
|
||||||
|
}))
|
||||||
|
defer cleanup()
|
||||||
|
|
||||||
|
id, err := appSvc.store.Create(context.Background(), &model.CreateApplicationRequest{
|
||||||
|
Name: "test-app",
|
||||||
|
ScriptTemplate: "#!/bin/bash\n#SBATCH --job-name=$JOB_NAME\necho $INPUT",
|
||||||
|
Parameters: json.RawMessage(`[{"name":"JOB_NAME","type":"string","required":true},{"name":"INPUT","type":"string","required":true}]`),
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("create app: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := appSvc.SubmitFromApplication(context.Background(), id, map[string]string{
|
||||||
|
"JOB_NAME": "my-job",
|
||||||
|
"INPUT": "hello",
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("SubmitFromApplication() error = %v", err)
|
||||||
|
}
|
||||||
|
if resp.JobID != 42 {
|
||||||
|
t.Errorf("JobID = %d, want 42", resp.JobID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
*/
|
||||||
|
|
||||||
|
// [已禁用] 测试的是旧的直接提交路径,该路径已被注释掉
|
||||||
|
/*
|
||||||
|
func TestSubmitFromApplication_AppNotFound(t *testing.T) {
|
||||||
|
appSvc, cleanup := setupApplicationService(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
}))
|
||||||
|
defer cleanup()
|
||||||
|
|
||||||
|
_, err := appSvc.SubmitFromApplication(context.Background(), 99999, map[string]string{})
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected error for non-existent app")
|
||||||
|
}
|
||||||
|
if !strings.Contains(err.Error(), "not found") {
|
||||||
|
t.Errorf("error should mention 'not found', got: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
*/
|
||||||
|
|
||||||
|
// [已禁用] 测试的是旧的直接提交路径,该路径已被注释掉
|
||||||
|
/*
|
||||||
|
func TestSubmitFromApplication_ValidationFail(t *testing.T) {
|
||||||
|
appSvc, cleanup := setupApplicationService(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
}))
|
||||||
|
defer cleanup()
|
||||||
|
|
||||||
|
_, err := appSvc.store.Create(context.Background(), &model.CreateApplicationRequest{
|
||||||
|
Name: "valid-app",
|
||||||
|
ScriptTemplate: "#!/bin/bash\necho $INPUT",
|
||||||
|
Parameters: json.RawMessage(`[{"name":"INPUT","type":"string","required":true}]`),
|
||||||
|
})
|
||||||
|
|
||||||
|
_, err = appSvc.SubmitFromApplication(context.Background(), 1, map[string]string{})
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected validation error for missing required param")
|
||||||
|
}
|
||||||
|
if !strings.Contains(err.Error(), "missing") {
|
||||||
|
t.Errorf("error should mention 'missing', got: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
*/
|
||||||
|
|
||||||
|
// [已禁用] 测试的是旧的直接提交路径,该路径已被注释掉
|
||||||
|
/*
|
||||||
|
func TestSubmitFromApplication_NoParameters(t *testing.T) {
|
||||||
|
jobID := int32(99)
|
||||||
|
appSvc, cleanup := setupApplicationService(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
json.NewEncoder(w).Encode(slurm.OpenapiJobSubmitResponse{
|
||||||
|
Result: &slurm.JobSubmitResponseMsg{JobID: &jobID},
|
||||||
|
})
|
||||||
|
}))
|
||||||
|
defer cleanup()
|
||||||
|
|
||||||
|
id, err := appSvc.store.Create(context.Background(), &model.CreateApplicationRequest{
|
||||||
|
Name: "simple-app",
|
||||||
|
ScriptTemplate: "#!/bin/bash\necho hello",
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("create app: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := appSvc.SubmitFromApplication(context.Background(), id, map[string]string{})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("SubmitFromApplication() error = %v", err)
|
||||||
|
}
|
||||||
|
if resp.JobID != 99 {
|
||||||
|
t.Errorf("JobID = %d, want 99", resp.JobID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
*/
|
||||||
|
|
||||||
|
// [已禁用] 前端已全部迁移到 POST /tasks 接口,此测试不再需要。
|
||||||
|
/*
|
||||||
|
func TestSubmitFromApplication_DelegatesToTaskService(t *testing.T) {
|
||||||
|
jobID := int32(77)
|
||||||
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
json.NewEncoder(w).Encode(slurm.OpenapiJobSubmitResponse{
|
||||||
|
Result: &slurm.JobSubmitResponseMsg{JobID: &jobID},
|
||||||
|
})
|
||||||
|
}))
|
||||||
|
defer srv.Close()
|
||||||
|
|
||||||
|
client, _ := slurm.NewClient(srv.URL, srv.Client())
|
||||||
|
jobSvc := NewJobService(client, zap.NewNop())
|
||||||
|
|
||||||
|
db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{
|
||||||
|
Logger: gormlogger.Default.LogMode(gormlogger.Silent),
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("open sqlite: %v", err)
|
||||||
|
}
|
||||||
|
if err := db.AutoMigrate(&model.Application{}, &model.Task{}, &model.File{}, &model.FileBlob{}); err != nil {
|
||||||
|
t.Fatalf("auto migrate: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
appStore := store.NewApplicationStore(db)
|
||||||
|
taskStore := store.NewTaskStore(db)
|
||||||
|
fileStore := store.NewFileStore(db)
|
||||||
|
blobStore := store.NewBlobStore(db)
|
||||||
|
|
||||||
|
workDirBase := filepath.Join(t.TempDir(), "workdir")
|
||||||
|
os.MkdirAll(workDirBase, 0777)
|
||||||
|
|
||||||
|
taskSvc := NewTaskService(taskStore, appStore, fileStore, blobStore, nil, jobSvc, workDirBase, zap.NewNop())
|
||||||
|
appSvc := NewApplicationService(appStore, jobSvc, workDirBase, zap.NewNop(), taskSvc)
|
||||||
|
|
||||||
|
id, err := appStore.Create(context.Background(), &model.CreateApplicationRequest{
|
||||||
|
Name: "delegated-app",
|
||||||
|
ScriptTemplate: "#!/bin/bash\n#SBATCH --job-name=$JOB_NAME\necho $INPUT",
|
||||||
|
Parameters: json.RawMessage(`[{"name":"JOB_NAME","type":"string","required":true},{"name":"INPUT","type":"string","required":true}]`),
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("create app: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := appSvc.SubmitFromApplication(context.Background(), id, map[string]string{
|
||||||
|
"JOB_NAME": "delegated-job",
|
||||||
|
"INPUT": "test-data",
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("SubmitFromApplication() error = %v", err)
|
||||||
|
}
|
||||||
|
if resp.JobID != 77 {
|
||||||
|
t.Errorf("JobID = %d, want 77", resp.JobID)
|
||||||
|
}
|
||||||
|
|
||||||
|
var task model.Task
|
||||||
|
if err := db.Where("app_id = ?", id).First(&task).Error; err != nil {
|
||||||
|
t.Fatalf("no hpc_tasks record found for app_id %d: %v", id, err)
|
||||||
|
}
|
||||||
|
if task.SlurmJobID == nil || *task.SlurmJobID != 77 {
|
||||||
|
t.Errorf("task SlurmJobID = %v, want 77", task.SlurmJobID)
|
||||||
|
}
|
||||||
|
if task.Status != model.TaskStatusQueued {
|
||||||
|
t.Errorf("task Status = %q, want %q", task.Status, model.TaskStatusQueued)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
*/
|
||||||
@@ -4,6 +4,7 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"strconv"
|
"strconv"
|
||||||
|
"time"
|
||||||
|
|
||||||
"gcy_hpc_server/internal/model"
|
"gcy_hpc_server/internal/model"
|
||||||
"gcy_hpc_server/internal/slurm"
|
"gcy_hpc_server/internal/slurm"
|
||||||
@@ -45,6 +46,27 @@ func uint32NoValString(v *slurm.Uint32NoVal) string {
|
|||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func derefUint64NoValInt64(v *slurm.Uint64NoVal) *int64 {
|
||||||
|
if v != nil && v.Number != nil {
|
||||||
|
return v.Number
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func derefCSVString(cs *slurm.CSVString) string {
|
||||||
|
if cs == nil || len(*cs) == 0 {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
result := ""
|
||||||
|
for i, s := range *cs {
|
||||||
|
if i > 0 {
|
||||||
|
result += ","
|
||||||
|
}
|
||||||
|
result += s
|
||||||
|
}
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
type ClusterService struct {
|
type ClusterService struct {
|
||||||
client *slurm.Client
|
client *slurm.Client
|
||||||
logger *zap.Logger
|
logger *zap.Logger
|
||||||
@@ -55,11 +77,30 @@ func NewClusterService(client *slurm.Client, logger *zap.Logger) *ClusterService
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *ClusterService) GetNodes(ctx context.Context) ([]model.NodeResponse, error) {
|
func (s *ClusterService) GetNodes(ctx context.Context) ([]model.NodeResponse, error) {
|
||||||
|
s.logger.Debug("slurm API request",
|
||||||
|
zap.String("operation", "GetNodes"),
|
||||||
|
)
|
||||||
|
|
||||||
|
start := time.Now()
|
||||||
resp, _, err := s.client.Nodes.GetNodes(ctx, nil)
|
resp, _, err := s.client.Nodes.GetNodes(ctx, nil)
|
||||||
|
took := time.Since(start)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
s.logger.Debug("slurm API error response",
|
||||||
|
zap.String("operation", "GetNodes"),
|
||||||
|
zap.Duration("took", took),
|
||||||
|
zap.Error(err),
|
||||||
|
)
|
||||||
s.logger.Error("failed to get nodes", zap.Error(err))
|
s.logger.Error("failed to get nodes", zap.Error(err))
|
||||||
return nil, fmt.Errorf("get nodes: %w", err)
|
return nil, fmt.Errorf("get nodes: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
s.logger.Debug("slurm API response",
|
||||||
|
zap.String("operation", "GetNodes"),
|
||||||
|
zap.Duration("took", took),
|
||||||
|
zap.Any("body", resp),
|
||||||
|
)
|
||||||
|
|
||||||
if resp.Nodes == nil {
|
if resp.Nodes == nil {
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
@@ -71,11 +112,33 @@ func (s *ClusterService) GetNodes(ctx context.Context) ([]model.NodeResponse, er
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *ClusterService) GetNode(ctx context.Context, name string) (*model.NodeResponse, error) {
|
func (s *ClusterService) GetNode(ctx context.Context, name string) (*model.NodeResponse, error) {
|
||||||
|
s.logger.Debug("slurm API request",
|
||||||
|
zap.String("operation", "GetNode"),
|
||||||
|
zap.String("node_name", name),
|
||||||
|
)
|
||||||
|
|
||||||
|
start := time.Now()
|
||||||
resp, _, err := s.client.Nodes.GetNode(ctx, name, nil)
|
resp, _, err := s.client.Nodes.GetNode(ctx, name, nil)
|
||||||
|
took := time.Since(start)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
s.logger.Debug("slurm API error response",
|
||||||
|
zap.String("operation", "GetNode"),
|
||||||
|
zap.String("node_name", name),
|
||||||
|
zap.Duration("took", took),
|
||||||
|
zap.Error(err),
|
||||||
|
)
|
||||||
s.logger.Error("failed to get node", zap.String("name", name), zap.Error(err))
|
s.logger.Error("failed to get node", zap.String("name", name), zap.Error(err))
|
||||||
return nil, fmt.Errorf("get node %s: %w", name, err)
|
return nil, fmt.Errorf("get node %s: %w", name, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
s.logger.Debug("slurm API response",
|
||||||
|
zap.String("operation", "GetNode"),
|
||||||
|
zap.String("node_name", name),
|
||||||
|
zap.Duration("took", took),
|
||||||
|
zap.Any("body", resp),
|
||||||
|
)
|
||||||
|
|
||||||
if resp.Nodes == nil || len(*resp.Nodes) == 0 {
|
if resp.Nodes == nil || len(*resp.Nodes) == 0 {
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
@@ -85,11 +148,30 @@ func (s *ClusterService) GetNode(ctx context.Context, name string) (*model.NodeR
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *ClusterService) GetPartitions(ctx context.Context) ([]model.PartitionResponse, error) {
|
func (s *ClusterService) GetPartitions(ctx context.Context) ([]model.PartitionResponse, error) {
|
||||||
|
s.logger.Debug("slurm API request",
|
||||||
|
zap.String("operation", "GetPartitions"),
|
||||||
|
)
|
||||||
|
|
||||||
|
start := time.Now()
|
||||||
resp, _, err := s.client.Partitions.GetPartitions(ctx, nil)
|
resp, _, err := s.client.Partitions.GetPartitions(ctx, nil)
|
||||||
|
took := time.Since(start)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
s.logger.Debug("slurm API error response",
|
||||||
|
zap.String("operation", "GetPartitions"),
|
||||||
|
zap.Duration("took", took),
|
||||||
|
zap.Error(err),
|
||||||
|
)
|
||||||
s.logger.Error("failed to get partitions", zap.Error(err))
|
s.logger.Error("failed to get partitions", zap.Error(err))
|
||||||
return nil, fmt.Errorf("get partitions: %w", err)
|
return nil, fmt.Errorf("get partitions: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
s.logger.Debug("slurm API response",
|
||||||
|
zap.String("operation", "GetPartitions"),
|
||||||
|
zap.Duration("took", took),
|
||||||
|
zap.Any("body", resp),
|
||||||
|
)
|
||||||
|
|
||||||
if resp.Partitions == nil {
|
if resp.Partitions == nil {
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
@@ -101,11 +183,33 @@ func (s *ClusterService) GetPartitions(ctx context.Context) ([]model.PartitionRe
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *ClusterService) GetPartition(ctx context.Context, name string) (*model.PartitionResponse, error) {
|
func (s *ClusterService) GetPartition(ctx context.Context, name string) (*model.PartitionResponse, error) {
|
||||||
|
s.logger.Debug("slurm API request",
|
||||||
|
zap.String("operation", "GetPartition"),
|
||||||
|
zap.String("partition_name", name),
|
||||||
|
)
|
||||||
|
|
||||||
|
start := time.Now()
|
||||||
resp, _, err := s.client.Partitions.GetPartition(ctx, name, nil)
|
resp, _, err := s.client.Partitions.GetPartition(ctx, name, nil)
|
||||||
|
took := time.Since(start)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
s.logger.Debug("slurm API error response",
|
||||||
|
zap.String("operation", "GetPartition"),
|
||||||
|
zap.String("partition_name", name),
|
||||||
|
zap.Duration("took", took),
|
||||||
|
zap.Error(err),
|
||||||
|
)
|
||||||
s.logger.Error("failed to get partition", zap.String("name", name), zap.Error(err))
|
s.logger.Error("failed to get partition", zap.String("name", name), zap.Error(err))
|
||||||
return nil, fmt.Errorf("get partition %s: %w", name, err)
|
return nil, fmt.Errorf("get partition %s: %w", name, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
s.logger.Debug("slurm API response",
|
||||||
|
zap.String("operation", "GetPartition"),
|
||||||
|
zap.String("partition_name", name),
|
||||||
|
zap.Duration("took", took),
|
||||||
|
zap.Any("body", resp),
|
||||||
|
)
|
||||||
|
|
||||||
if resp.Partitions == nil || len(*resp.Partitions) == 0 {
|
if resp.Partitions == nil || len(*resp.Partitions) == 0 {
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
@@ -115,11 +219,30 @@ func (s *ClusterService) GetPartition(ctx context.Context, name string) (*model.
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *ClusterService) GetDiag(ctx context.Context) (*slurm.OpenapiDiagResp, error) {
|
func (s *ClusterService) GetDiag(ctx context.Context) (*slurm.OpenapiDiagResp, error) {
|
||||||
|
s.logger.Debug("slurm API request",
|
||||||
|
zap.String("operation", "GetDiag"),
|
||||||
|
)
|
||||||
|
|
||||||
|
start := time.Now()
|
||||||
resp, _, err := s.client.Diag.GetDiag(ctx)
|
resp, _, err := s.client.Diag.GetDiag(ctx)
|
||||||
|
took := time.Since(start)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
s.logger.Debug("slurm API error response",
|
||||||
|
zap.String("operation", "GetDiag"),
|
||||||
|
zap.Duration("took", took),
|
||||||
|
zap.Error(err),
|
||||||
|
)
|
||||||
s.logger.Error("failed to get diag", zap.Error(err))
|
s.logger.Error("failed to get diag", zap.Error(err))
|
||||||
return nil, fmt.Errorf("get diag: %w", err)
|
return nil, fmt.Errorf("get diag: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
s.logger.Debug("slurm API response",
|
||||||
|
zap.String("operation", "GetDiag"),
|
||||||
|
zap.Duration("took", took),
|
||||||
|
zap.Any("body", resp),
|
||||||
|
)
|
||||||
|
|
||||||
return resp, nil
|
return resp, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -128,17 +251,39 @@ func mapNode(n slurm.Node) model.NodeResponse {
|
|||||||
Name: derefStr(n.Name),
|
Name: derefStr(n.Name),
|
||||||
State: n.State,
|
State: n.State,
|
||||||
CPUs: derefInt32(n.Cpus),
|
CPUs: derefInt32(n.Cpus),
|
||||||
|
AllocCpus: n.AllocCpus,
|
||||||
|
Cores: n.Cores,
|
||||||
|
Sockets: n.Sockets,
|
||||||
|
Threads: n.Threads,
|
||||||
RealMemory: derefInt64(n.RealMemory),
|
RealMemory: derefInt64(n.RealMemory),
|
||||||
AllocMem: derefInt64(n.AllocMemory),
|
AllocMemory: derefInt64(n.AllocMemory),
|
||||||
|
FreeMem: derefUint64NoValInt64(n.FreeMem),
|
||||||
|
CpuLoad: n.CpuLoad,
|
||||||
Arch: derefStr(n.Architecture),
|
Arch: derefStr(n.Architecture),
|
||||||
OS: derefStr(n.OperatingSystem),
|
OS: derefStr(n.OperatingSystem),
|
||||||
|
Gres: derefStr(n.Gres),
|
||||||
|
GresUsed: derefStr(n.GresUsed),
|
||||||
|
Reason: derefStr(n.Reason),
|
||||||
|
ReasonSetByUser: derefStr(n.ReasonSetByUser),
|
||||||
|
Address: derefStr(n.Address),
|
||||||
|
Hostname: derefStr(n.Hostname),
|
||||||
|
Weight: n.Weight,
|
||||||
|
Features: derefCSVString(n.Features),
|
||||||
|
ActiveFeatures: derefCSVString(n.ActiveFeatures),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func mapPartition(pi slurm.PartitionInfo) model.PartitionResponse {
|
func mapPartition(pi slurm.PartitionInfo) model.PartitionResponse {
|
||||||
var state []string
|
var state []string
|
||||||
|
var isDefault bool
|
||||||
if pi.Partition != nil {
|
if pi.Partition != nil {
|
||||||
state = pi.Partition.State
|
state = pi.Partition.State
|
||||||
|
for _, s := range state {
|
||||||
|
if s == "DEFAULT" {
|
||||||
|
isDefault = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
var nodes string
|
var nodes string
|
||||||
if pi.Nodes != nil {
|
if pi.Nodes != nil {
|
||||||
@@ -156,12 +301,56 @@ func mapPartition(pi slurm.PartitionInfo) model.PartitionResponse {
|
|||||||
if pi.Maximums != nil {
|
if pi.Maximums != nil {
|
||||||
maxTime = uint32NoValString(pi.Maximums.Time)
|
maxTime = uint32NoValString(pi.Maximums.Time)
|
||||||
}
|
}
|
||||||
|
var maxNodes *int32
|
||||||
|
if pi.Maximums != nil {
|
||||||
|
maxNodes = mapUint32NoValToInt32(pi.Maximums.Nodes)
|
||||||
|
}
|
||||||
|
var maxCPUsPerNode *int32
|
||||||
|
if pi.Maximums != nil {
|
||||||
|
maxCPUsPerNode = mapUint32NoValToInt32(pi.Maximums.CpusPerNode)
|
||||||
|
}
|
||||||
|
var minNodes *int32
|
||||||
|
if pi.Minimums != nil {
|
||||||
|
minNodes = pi.Minimums.Nodes
|
||||||
|
}
|
||||||
|
var defaultTime string
|
||||||
|
if pi.Defaults != nil {
|
||||||
|
defaultTime = uint32NoValString(pi.Defaults.Time)
|
||||||
|
}
|
||||||
|
var graceTime *int32 = pi.GraceTime
|
||||||
|
var priority *int32
|
||||||
|
if pi.Priority != nil {
|
||||||
|
priority = pi.Priority.JobFactor
|
||||||
|
}
|
||||||
|
var qosAllowed, qosDeny, qosAssigned string
|
||||||
|
if pi.QOS != nil {
|
||||||
|
qosAllowed = derefStr(pi.QOS.Allowed)
|
||||||
|
qosDeny = derefStr(pi.QOS.Deny)
|
||||||
|
qosAssigned = derefStr(pi.QOS.Assigned)
|
||||||
|
}
|
||||||
|
var accountsAllowed, accountsDeny string
|
||||||
|
if pi.Accounts != nil {
|
||||||
|
accountsAllowed = derefStr(pi.Accounts.Allowed)
|
||||||
|
accountsDeny = derefStr(pi.Accounts.Deny)
|
||||||
|
}
|
||||||
return model.PartitionResponse{
|
return model.PartitionResponse{
|
||||||
Name: derefStr(pi.Name),
|
Name: derefStr(pi.Name),
|
||||||
State: state,
|
State: state,
|
||||||
|
Default: isDefault,
|
||||||
Nodes: nodes,
|
Nodes: nodes,
|
||||||
TotalCPUs: totalCPUs,
|
|
||||||
TotalNodes: totalNodes,
|
TotalNodes: totalNodes,
|
||||||
|
TotalCPUs: totalCPUs,
|
||||||
MaxTime: maxTime,
|
MaxTime: maxTime,
|
||||||
|
MaxNodes: maxNodes,
|
||||||
|
MaxCPUsPerNode: maxCPUsPerNode,
|
||||||
|
MinNodes: minNodes,
|
||||||
|
DefaultTime: defaultTime,
|
||||||
|
GraceTime: graceTime,
|
||||||
|
Priority: priority,
|
||||||
|
QOSAllowed: qosAllowed,
|
||||||
|
QOSDeny: qosDeny,
|
||||||
|
QOSAssigned: qosAssigned,
|
||||||
|
AccountsAllowed: accountsAllowed,
|
||||||
|
AccountsDeny: accountsDeny,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -352,10 +352,10 @@ func TestClusterService_GetNodes_ErrorLogging(t *testing.T) {
|
|||||||
t.Fatal("expected error, got nil")
|
t.Fatal("expected error, got nil")
|
||||||
}
|
}
|
||||||
|
|
||||||
if logs.Len() != 1 {
|
if logs.Len() != 3 {
|
||||||
t.Fatalf("expected 1 log entry, got %d", logs.Len())
|
t.Fatalf("expected 3 log entries, got %d", logs.Len())
|
||||||
}
|
}
|
||||||
entry := logs.All()[0]
|
entry := logs.All()[2]
|
||||||
if entry.Level != zapcore.ErrorLevel {
|
if entry.Level != zapcore.ErrorLevel {
|
||||||
t.Errorf("expected ErrorLevel, got %v", entry.Level)
|
t.Errorf("expected ErrorLevel, got %v", entry.Level)
|
||||||
}
|
}
|
||||||
@@ -374,10 +374,10 @@ func TestClusterService_GetNode_ErrorLogging(t *testing.T) {
|
|||||||
t.Fatal("expected error, got nil")
|
t.Fatal("expected error, got nil")
|
||||||
}
|
}
|
||||||
|
|
||||||
if logs.Len() != 1 {
|
if logs.Len() != 3 {
|
||||||
t.Fatalf("expected 1 log entry, got %d", logs.Len())
|
t.Fatalf("expected 3 log entries, got %d", logs.Len())
|
||||||
}
|
}
|
||||||
entry := logs.All()[0]
|
entry := logs.All()[2]
|
||||||
if entry.Level != zapcore.ErrorLevel {
|
if entry.Level != zapcore.ErrorLevel {
|
||||||
t.Errorf("expected ErrorLevel, got %v", entry.Level)
|
t.Errorf("expected ErrorLevel, got %v", entry.Level)
|
||||||
}
|
}
|
||||||
@@ -403,10 +403,10 @@ func TestClusterService_GetPartitions_ErrorLogging(t *testing.T) {
|
|||||||
t.Fatal("expected error, got nil")
|
t.Fatal("expected error, got nil")
|
||||||
}
|
}
|
||||||
|
|
||||||
if logs.Len() != 1 {
|
if logs.Len() != 3 {
|
||||||
t.Fatalf("expected 1 log entry, got %d", logs.Len())
|
t.Fatalf("expected 3 log entries, got %d", logs.Len())
|
||||||
}
|
}
|
||||||
entry := logs.All()[0]
|
entry := logs.All()[2]
|
||||||
if entry.Level != zapcore.ErrorLevel {
|
if entry.Level != zapcore.ErrorLevel {
|
||||||
t.Errorf("expected ErrorLevel, got %v", entry.Level)
|
t.Errorf("expected ErrorLevel, got %v", entry.Level)
|
||||||
}
|
}
|
||||||
@@ -425,10 +425,10 @@ func TestClusterService_GetPartition_ErrorLogging(t *testing.T) {
|
|||||||
t.Fatal("expected error, got nil")
|
t.Fatal("expected error, got nil")
|
||||||
}
|
}
|
||||||
|
|
||||||
if logs.Len() != 1 {
|
if logs.Len() != 3 {
|
||||||
t.Fatalf("expected 1 log entry, got %d", logs.Len())
|
t.Fatalf("expected 3 log entries, got %d", logs.Len())
|
||||||
}
|
}
|
||||||
entry := logs.All()[0]
|
entry := logs.All()[2]
|
||||||
if entry.Level != zapcore.ErrorLevel {
|
if entry.Level != zapcore.ErrorLevel {
|
||||||
t.Errorf("expected ErrorLevel, got %v", entry.Level)
|
t.Errorf("expected ErrorLevel, got %v", entry.Level)
|
||||||
}
|
}
|
||||||
@@ -454,10 +454,10 @@ func TestClusterService_GetDiag_ErrorLogging(t *testing.T) {
|
|||||||
t.Fatal("expected error, got nil")
|
t.Fatal("expected error, got nil")
|
||||||
}
|
}
|
||||||
|
|
||||||
if logs.Len() != 1 {
|
if logs.Len() != 3 {
|
||||||
t.Fatalf("expected 1 log entry, got %d", logs.Len())
|
t.Fatalf("expected 3 log entries, got %d", logs.Len())
|
||||||
}
|
}
|
||||||
entry := logs.All()[0]
|
entry := logs.All()[2]
|
||||||
if entry.Level != zapcore.ErrorLevel {
|
if entry.Level != zapcore.ErrorLevel {
|
||||||
t.Errorf("expected ErrorLevel, got %v", entry.Level)
|
t.Errorf("expected ErrorLevel, got %v", entry.Level)
|
||||||
}
|
}
|
||||||
|
|||||||
98
internal/service/download_service.go
Normal file
98
internal/service/download_service.go
Normal file
@@ -0,0 +1,98 @@
|
|||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
|
||||||
|
"gcy_hpc_server/internal/model"
|
||||||
|
"gcy_hpc_server/internal/server"
|
||||||
|
"gcy_hpc_server/internal/storage"
|
||||||
|
"gcy_hpc_server/internal/store"
|
||||||
|
|
||||||
|
"go.uber.org/zap"
|
||||||
|
)
|
||||||
|
|
||||||
|
// DownloadService handles file downloads with streaming and Range support.
|
||||||
|
type DownloadService struct {
|
||||||
|
storage storage.ObjectStorage
|
||||||
|
blobStore *store.BlobStore
|
||||||
|
fileStore *store.FileStore
|
||||||
|
bucket string
|
||||||
|
logger *zap.Logger
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewDownloadService creates a new DownloadService.
|
||||||
|
func NewDownloadService(storage storage.ObjectStorage, blobStore *store.BlobStore, fileStore *store.FileStore, bucket string, logger *zap.Logger) *DownloadService {
|
||||||
|
return &DownloadService{
|
||||||
|
storage: storage,
|
||||||
|
blobStore: blobStore,
|
||||||
|
fileStore: fileStore,
|
||||||
|
bucket: bucket,
|
||||||
|
logger: logger,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Download returns a stream reader for the given file, optionally limited to a byte range.
|
||||||
|
// Returns (reader, file, blob, start, end, error).
|
||||||
|
func (s *DownloadService) Download(ctx context.Context, fileID int64, rangeHeader string) (io.ReadCloser, *model.File, *model.FileBlob, int64, int64, error) {
|
||||||
|
file, err := s.fileStore.GetByID(ctx, fileID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, nil, 0, 0, fmt.Errorf("get file: %w", err)
|
||||||
|
}
|
||||||
|
if file == nil {
|
||||||
|
return nil, nil, nil, 0, 0, fmt.Errorf("file not found")
|
||||||
|
}
|
||||||
|
|
||||||
|
blob, err := s.blobStore.GetBySHA256(ctx, file.BlobSHA256)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, nil, 0, 0, fmt.Errorf("get blob: %w", err)
|
||||||
|
}
|
||||||
|
if blob == nil {
|
||||||
|
return nil, nil, nil, 0, 0, fmt.Errorf("blob not found")
|
||||||
|
}
|
||||||
|
|
||||||
|
var start, end int64
|
||||||
|
if rangeHeader != "" {
|
||||||
|
start, end, err = server.ParseRange(rangeHeader, blob.FileSize)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, nil, 0, 0, fmt.Errorf("parse range: %w", err)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
start = 0
|
||||||
|
end = blob.FileSize - 1
|
||||||
|
}
|
||||||
|
|
||||||
|
opts := storage.GetOptions{
|
||||||
|
Start: &start,
|
||||||
|
End: &end,
|
||||||
|
}
|
||||||
|
|
||||||
|
reader, _, err := s.storage.GetObject(ctx, s.bucket, blob.MinioKey, opts)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, nil, 0, 0, fmt.Errorf("get object: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return reader, file, blob, start, end, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetFileMetadata returns the file and its associated blob metadata.
|
||||||
|
func (s *DownloadService) GetFileMetadata(ctx context.Context, fileID int64) (*model.File, *model.FileBlob, error) {
|
||||||
|
file, err := s.fileStore.GetByID(ctx, fileID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, fmt.Errorf("get file: %w", err)
|
||||||
|
}
|
||||||
|
if file == nil {
|
||||||
|
return nil, nil, fmt.Errorf("file not found")
|
||||||
|
}
|
||||||
|
|
||||||
|
blob, err := s.blobStore.GetBySHA256(ctx, file.BlobSHA256)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, fmt.Errorf("get blob: %w", err)
|
||||||
|
}
|
||||||
|
if blob == nil {
|
||||||
|
return nil, nil, fmt.Errorf("blob not found")
|
||||||
|
}
|
||||||
|
|
||||||
|
return file, blob, nil
|
||||||
|
}
|
||||||
260
internal/service/download_service_test.go
Normal file
260
internal/service/download_service_test.go
Normal file
@@ -0,0 +1,260 @@
|
|||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"io"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"gcy_hpc_server/internal/model"
|
||||||
|
"gcy_hpc_server/internal/storage"
|
||||||
|
"gcy_hpc_server/internal/store"
|
||||||
|
|
||||||
|
"go.uber.org/zap"
|
||||||
|
gormlogger "gorm.io/gorm/logger"
|
||||||
|
|
||||||
|
"gorm.io/driver/sqlite"
|
||||||
|
"gorm.io/gorm"
|
||||||
|
)
|
||||||
|
|
||||||
|
type mockDownloadStorage struct {
|
||||||
|
getObjectFn func(ctx context.Context, bucket, key string, opts storage.GetOptions) (io.ReadCloser, storage.ObjectInfo, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockDownloadStorage) PutObject(_ context.Context, _ string, _ string, _ io.Reader, _ int64, _ storage.PutObjectOptions) (storage.UploadInfo, error) {
|
||||||
|
return storage.UploadInfo{}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockDownloadStorage) GetObject(ctx context.Context, bucket, key string, opts storage.GetOptions) (io.ReadCloser, storage.ObjectInfo, error) {
|
||||||
|
if m.getObjectFn != nil {
|
||||||
|
return m.getObjectFn(ctx, bucket, key, opts)
|
||||||
|
}
|
||||||
|
return nil, storage.ObjectInfo{}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockDownloadStorage) ComposeObject(_ context.Context, _ string, _ string, _ []string) (storage.UploadInfo, error) {
|
||||||
|
return storage.UploadInfo{}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockDownloadStorage) AbortMultipartUpload(_ context.Context, _ string, _ string, _ string) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockDownloadStorage) RemoveIncompleteUpload(_ context.Context, _ string, _ string) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockDownloadStorage) RemoveObject(_ context.Context, _ string, _ string, _ storage.RemoveObjectOptions) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockDownloadStorage) ListObjects(_ context.Context, _ string, _ string, _ bool) ([]storage.ObjectInfo, error) {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockDownloadStorage) RemoveObjects(_ context.Context, _ string, _ []string, _ storage.RemoveObjectsOptions) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockDownloadStorage) BucketExists(_ context.Context, _ string) (bool, error) {
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockDownloadStorage) MakeBucket(_ context.Context, _ string, _ storage.MakeBucketOptions) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockDownloadStorage) StatObject(_ context.Context, _ string, _ string, _ storage.StatObjectOptions) (storage.ObjectInfo, error) {
|
||||||
|
return storage.ObjectInfo{}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func setupDownloadService(t *testing.T) (*DownloadService, *store.FileStore, *store.BlobStore, *mockDownloadStorage) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{
|
||||||
|
Logger: gormlogger.Default.LogMode(gormlogger.Silent),
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("open sqlite: %v", err)
|
||||||
|
}
|
||||||
|
if err := db.AutoMigrate(&model.File{}, &model.FileBlob{}); err != nil {
|
||||||
|
t.Fatalf("auto migrate: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
mockStorage := &mockDownloadStorage{}
|
||||||
|
blobStore := store.NewBlobStore(db)
|
||||||
|
fileStore := store.NewFileStore(db)
|
||||||
|
svc := NewDownloadService(mockStorage, blobStore, fileStore, "test-bucket", zap.NewNop())
|
||||||
|
|
||||||
|
return svc, fileStore, blobStore, mockStorage
|
||||||
|
}
|
||||||
|
|
||||||
|
func createTestFileAndBlob(t *testing.T, fileStore *store.FileStore, blobStore *store.BlobStore) (*model.File, *model.FileBlob) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
blob := &model.FileBlob{
|
||||||
|
SHA256: "abc123def456abc123def456abc123def456abc123def456abc123def456abcd",
|
||||||
|
MinioKey: "chunks/session1/part-0",
|
||||||
|
FileSize: 5000,
|
||||||
|
MimeType: "application/octet-stream",
|
||||||
|
RefCount: 1,
|
||||||
|
}
|
||||||
|
if err := blobStore.Create(context.Background(), blob); err != nil {
|
||||||
|
t.Fatalf("create blob: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
file := &model.File{
|
||||||
|
Name: "test.dat",
|
||||||
|
BlobSHA256: blob.SHA256,
|
||||||
|
CreatedAt: time.Now(),
|
||||||
|
UpdatedAt: time.Now(),
|
||||||
|
}
|
||||||
|
if err := fileStore.Create(context.Background(), file); err != nil {
|
||||||
|
t.Fatalf("create file: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return file, blob
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDownload_FullFile(t *testing.T) {
|
||||||
|
svc, fileStore, blobStore, mockStorage := setupDownloadService(t)
|
||||||
|
file, blob := createTestFileAndBlob(t, fileStore, blobStore)
|
||||||
|
|
||||||
|
content := make([]byte, blob.FileSize)
|
||||||
|
for i := range content {
|
||||||
|
content[i] = byte(i % 256)
|
||||||
|
}
|
||||||
|
|
||||||
|
mockStorage.getObjectFn = func(_ context.Context, _, _ string, opts storage.GetOptions) (io.ReadCloser, storage.ObjectInfo, error) {
|
||||||
|
if opts.Start == nil || opts.End == nil {
|
||||||
|
t.Fatal("expected Start and End to be set")
|
||||||
|
}
|
||||||
|
if *opts.Start != 0 {
|
||||||
|
t.Fatalf("expected start=0, got %d", *opts.Start)
|
||||||
|
}
|
||||||
|
if *opts.End != blob.FileSize-1 {
|
||||||
|
t.Fatalf("expected end=%d, got %d", blob.FileSize-1, *opts.End)
|
||||||
|
}
|
||||||
|
return io.NopCloser(bytes.NewReader(content)), storage.ObjectInfo{Size: blob.FileSize}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
reader, gotFile, gotBlob, start, end, err := svc.Download(context.Background(), file.ID, "")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Download: %v", err)
|
||||||
|
}
|
||||||
|
defer reader.Close()
|
||||||
|
|
||||||
|
if gotFile.ID != file.ID {
|
||||||
|
t.Fatalf("expected file ID %d, got %d", file.ID, gotFile.ID)
|
||||||
|
}
|
||||||
|
if gotBlob.SHA256 != blob.SHA256 {
|
||||||
|
t.Fatalf("expected blob SHA256 %s, got %s", blob.SHA256, gotBlob.SHA256)
|
||||||
|
}
|
||||||
|
if start != 0 {
|
||||||
|
t.Fatalf("expected start=0, got %d", start)
|
||||||
|
}
|
||||||
|
if end != blob.FileSize-1 {
|
||||||
|
t.Fatalf("expected end=%d, got %d", blob.FileSize-1, end)
|
||||||
|
}
|
||||||
|
|
||||||
|
read, _ := io.ReadAll(reader)
|
||||||
|
if int64(len(read)) != blob.FileSize {
|
||||||
|
t.Fatalf("expected %d bytes, got %d", blob.FileSize, len(read))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDownload_WithRange(t *testing.T) {
|
||||||
|
svc, fileStore, blobStore, mockStorage := setupDownloadService(t)
|
||||||
|
file, _ := createTestFileAndBlob(t, fileStore, blobStore)
|
||||||
|
|
||||||
|
content := make([]byte, 1024)
|
||||||
|
for i := range content {
|
||||||
|
content[i] = byte(i % 256)
|
||||||
|
}
|
||||||
|
|
||||||
|
mockStorage.getObjectFn = func(_ context.Context, _, _ string, opts storage.GetOptions) (io.ReadCloser, storage.ObjectInfo, error) {
|
||||||
|
if opts.Start == nil || opts.End == nil {
|
||||||
|
t.Fatal("expected Start and End to be set")
|
||||||
|
}
|
||||||
|
if *opts.Start != 0 {
|
||||||
|
t.Fatalf("expected start=0, got %d", *opts.Start)
|
||||||
|
}
|
||||||
|
if *opts.End != 1023 {
|
||||||
|
t.Fatalf("expected end=1023, got %d", *opts.End)
|
||||||
|
}
|
||||||
|
return io.NopCloser(bytes.NewReader(content[:1024])), storage.ObjectInfo{Size: 1024}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
reader, _, _, start, end, err := svc.Download(context.Background(), file.ID, "bytes=0-1023")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Download: %v", err)
|
||||||
|
}
|
||||||
|
defer reader.Close()
|
||||||
|
|
||||||
|
if start != 0 {
|
||||||
|
t.Fatalf("expected start=0, got %d", start)
|
||||||
|
}
|
||||||
|
if end != 1023 {
|
||||||
|
t.Fatalf("expected end=1023, got %d", end)
|
||||||
|
}
|
||||||
|
|
||||||
|
read, _ := io.ReadAll(reader)
|
||||||
|
if len(read) != 1024 {
|
||||||
|
t.Fatalf("expected 1024 bytes, got %d", len(read))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDownload_FileNotFound(t *testing.T) {
|
||||||
|
svc, _, _, _ := setupDownloadService(t)
|
||||||
|
|
||||||
|
_, _, _, _, _, err := svc.Download(context.Background(), 99999, "")
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected error for missing file")
|
||||||
|
}
|
||||||
|
if err.Error() != "file not found" {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDownload_BlobNotFound(t *testing.T) {
|
||||||
|
svc, fileStore, _, _ := setupDownloadService(t)
|
||||||
|
|
||||||
|
file := &model.File{
|
||||||
|
Name: "orphan.dat",
|
||||||
|
BlobSHA256: "nonexistent_hash_0000000000000000000000000000000000000000",
|
||||||
|
CreatedAt: time.Now(),
|
||||||
|
UpdatedAt: time.Now(),
|
||||||
|
}
|
||||||
|
if err := fileStore.Create(context.Background(), file); err != nil {
|
||||||
|
t.Fatalf("create file: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
_, _, _, _, _, err := svc.Download(context.Background(), file.ID, "")
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected error for missing blob")
|
||||||
|
}
|
||||||
|
if err.Error() != "blob not found" {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetFileMetadata(t *testing.T) {
|
||||||
|
svc, fileStore, blobStore, _ := setupDownloadService(t)
|
||||||
|
file, _ := createTestFileAndBlob(t, fileStore, blobStore)
|
||||||
|
|
||||||
|
gotFile, gotBlob, err := svc.GetFileMetadata(context.Background(), file.ID)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("GetFileMetadata: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if gotFile.ID != file.ID {
|
||||||
|
t.Fatalf("expected file ID %d, got %d", file.ID, gotFile.ID)
|
||||||
|
}
|
||||||
|
if gotBlob.FileSize != 5000 {
|
||||||
|
t.Fatalf("expected file size 5000, got %d", gotBlob.FileSize)
|
||||||
|
}
|
||||||
|
if gotBlob.MimeType != "application/octet-stream" {
|
||||||
|
t.Fatalf("expected mime type application/octet-stream, got %s", gotBlob.MimeType)
|
||||||
|
}
|
||||||
|
}
|
||||||
178
internal/service/file_service.go
Normal file
178
internal/service/file_service.go
Normal file
@@ -0,0 +1,178 @@
|
|||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
|
||||||
|
"gcy_hpc_server/internal/model"
|
||||||
|
"gcy_hpc_server/internal/server"
|
||||||
|
"gcy_hpc_server/internal/storage"
|
||||||
|
"gcy_hpc_server/internal/store"
|
||||||
|
|
||||||
|
"go.uber.org/zap"
|
||||||
|
"gorm.io/gorm"
|
||||||
|
)
|
||||||
|
|
||||||
|
// FileService handles file listing, metadata, download, and deletion operations.
|
||||||
|
type FileService struct {
|
||||||
|
storage storage.ObjectStorage
|
||||||
|
blobStore *store.BlobStore
|
||||||
|
fileStore *store.FileStore
|
||||||
|
bucket string
|
||||||
|
db *gorm.DB
|
||||||
|
logger *zap.Logger
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewFileService creates a new FileService.
|
||||||
|
func NewFileService(storage storage.ObjectStorage, blobStore *store.BlobStore, fileStore *store.FileStore, bucket string, db *gorm.DB, logger *zap.Logger) *FileService {
|
||||||
|
return &FileService{
|
||||||
|
storage: storage,
|
||||||
|
blobStore: blobStore,
|
||||||
|
fileStore: fileStore,
|
||||||
|
bucket: bucket,
|
||||||
|
db: db,
|
||||||
|
logger: logger,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ListFiles returns a paginated list of files, optionally filtered by folder or search query.
|
||||||
|
func (s *FileService) ListFiles(ctx context.Context, folderID *int64, page, pageSize int, search string) ([]model.FileResponse, int64, error) {
|
||||||
|
var files []model.File
|
||||||
|
var total int64
|
||||||
|
var err error
|
||||||
|
|
||||||
|
if search != "" {
|
||||||
|
files, total, err = s.fileStore.Search(ctx, search, page, pageSize)
|
||||||
|
} else {
|
||||||
|
files, total, err = s.fileStore.List(ctx, folderID, page, pageSize)
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
return nil, 0, fmt.Errorf("list files: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
responses := make([]model.FileResponse, 0, len(files))
|
||||||
|
for _, f := range files {
|
||||||
|
blob, err := s.blobStore.GetBySHA256(ctx, f.BlobSHA256)
|
||||||
|
if err != nil {
|
||||||
|
return nil, 0, fmt.Errorf("get blob for file %d: %w", f.ID, err)
|
||||||
|
}
|
||||||
|
if blob == nil {
|
||||||
|
return nil, 0, fmt.Errorf("blob not found for file %d", f.ID)
|
||||||
|
}
|
||||||
|
|
||||||
|
responses = append(responses, model.FileResponse{
|
||||||
|
ID: f.ID,
|
||||||
|
Name: f.Name,
|
||||||
|
FolderID: f.FolderID,
|
||||||
|
Size: blob.FileSize,
|
||||||
|
MimeType: blob.MimeType,
|
||||||
|
SHA256: f.BlobSHA256,
|
||||||
|
CreatedAt: f.CreatedAt,
|
||||||
|
UpdatedAt: f.UpdatedAt,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
return responses, total, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetFileMetadata returns the file and its associated blob metadata.
|
||||||
|
func (s *FileService) GetFileMetadata(ctx context.Context, fileID int64) (*model.File, *model.FileBlob, error) {
|
||||||
|
file, err := s.fileStore.GetByID(ctx, fileID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, fmt.Errorf("get file: %w", err)
|
||||||
|
}
|
||||||
|
if file == nil {
|
||||||
|
return nil, nil, fmt.Errorf("file not found: %d", fileID)
|
||||||
|
}
|
||||||
|
|
||||||
|
blob, err := s.blobStore.GetBySHA256(ctx, file.BlobSHA256)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, fmt.Errorf("get blob: %w", err)
|
||||||
|
}
|
||||||
|
if blob == nil {
|
||||||
|
return nil, nil, fmt.Errorf("blob not found for file %d", fileID)
|
||||||
|
}
|
||||||
|
|
||||||
|
return file, blob, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// DownloadFile returns a reader for the file content, along with file and blob metadata.
|
||||||
|
// If rangeHeader is non-empty, it parses the range and returns partial content.
|
||||||
|
func (s *FileService) DownloadFile(ctx context.Context, fileID int64, rangeHeader string) (io.ReadCloser, *model.File, *model.FileBlob, int64, int64, error) {
|
||||||
|
file, blob, err := s.GetFileMetadata(ctx, fileID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, nil, 0, 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
var start, end int64
|
||||||
|
if rangeHeader != "" {
|
||||||
|
start, end, err = server.ParseRange(rangeHeader, blob.FileSize)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, nil, 0, 0, fmt.Errorf("parse range: %w", err)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
start = 0
|
||||||
|
end = blob.FileSize - 1
|
||||||
|
}
|
||||||
|
|
||||||
|
reader, _, err := s.storage.GetObject(ctx, s.bucket, blob.MinioKey, storage.GetOptions{
|
||||||
|
Start: &start,
|
||||||
|
End: &end,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, nil, 0, 0, fmt.Errorf("get object: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return reader, file, blob, start, end, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeleteFile soft-deletes a file. If no other active files reference the same blob,
|
||||||
|
// it decrements the blob ref count and removes the object from storage when ref count reaches 0.
|
||||||
|
func (s *FileService) DeleteFile(ctx context.Context, fileID int64) error {
|
||||||
|
return s.db.Transaction(func(tx *gorm.DB) error {
|
||||||
|
txFileStore := store.NewFileStore(tx)
|
||||||
|
txBlobStore := store.NewBlobStore(tx)
|
||||||
|
|
||||||
|
blobSHA256, err := txFileStore.GetBlobSHA256ByID(ctx, fileID)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("get blob sha256: %w", err)
|
||||||
|
}
|
||||||
|
if blobSHA256 == "" {
|
||||||
|
return fmt.Errorf("file not found: %d", fileID)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := tx.Delete(&model.File{}, fileID).Error; err != nil {
|
||||||
|
return fmt.Errorf("soft delete file: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
activeCount, err := txFileStore.CountByBlobSHA256(ctx, blobSHA256)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("count active refs: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if activeCount == 0 {
|
||||||
|
newRefCount, err := txBlobStore.DecrementRef(ctx, blobSHA256)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("decrement ref: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if newRefCount == 0 {
|
||||||
|
blob, err := txBlobStore.GetBySHA256(ctx, blobSHA256)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("get blob for cleanup: %w", err)
|
||||||
|
}
|
||||||
|
if blob != nil {
|
||||||
|
if err := s.storage.RemoveObject(ctx, s.bucket, blob.MinioKey, storage.RemoveObjectOptions{}); err != nil {
|
||||||
|
return fmt.Errorf("remove object: %w", err)
|
||||||
|
}
|
||||||
|
if err := txBlobStore.Delete(ctx, blobSHA256); err != nil {
|
||||||
|
return fmt.Errorf("delete blob: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
}
|
||||||
484
internal/service/file_service_test.go
Normal file
484
internal/service/file_service_test.go
Normal file
@@ -0,0 +1,484 @@
|
|||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"io"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"gcy_hpc_server/internal/model"
|
||||||
|
"gcy_hpc_server/internal/storage"
|
||||||
|
"gcy_hpc_server/internal/store"
|
||||||
|
|
||||||
|
"go.uber.org/zap"
|
||||||
|
"gorm.io/driver/sqlite"
|
||||||
|
"gorm.io/gorm"
|
||||||
|
)
|
||||||
|
|
||||||
|
type mockFileStorage struct {
|
||||||
|
getObjectFn func(ctx context.Context, bucket, key string, opts storage.GetOptions) (io.ReadCloser, storage.ObjectInfo, error)
|
||||||
|
removeObjectFn func(ctx context.Context, bucket, key string, opts storage.RemoveObjectOptions) error
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockFileStorage) PutObject(_ context.Context, _, _ string, _ io.Reader, _ int64, _ storage.PutObjectOptions) (storage.UploadInfo, error) {
|
||||||
|
return storage.UploadInfo{}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockFileStorage) GetObject(ctx context.Context, bucket, key string, opts storage.GetOptions) (io.ReadCloser, storage.ObjectInfo, error) {
|
||||||
|
if m.getObjectFn != nil {
|
||||||
|
return m.getObjectFn(ctx, bucket, key, opts)
|
||||||
|
}
|
||||||
|
return io.NopCloser(strings.NewReader("data")), storage.ObjectInfo{}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockFileStorage) ComposeObject(_ context.Context, _ string, _ string, _ []string) (storage.UploadInfo, error) {
|
||||||
|
return storage.UploadInfo{}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockFileStorage) AbortMultipartUpload(_ context.Context, _, _, _ string) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockFileStorage) RemoveIncompleteUpload(_ context.Context, _, _ string) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockFileStorage) RemoveObject(ctx context.Context, bucket, key string, opts storage.RemoveObjectOptions) error {
|
||||||
|
if m.removeObjectFn != nil {
|
||||||
|
return m.removeObjectFn(ctx, bucket, key, opts)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockFileStorage) ListObjects(_ context.Context, _ string, _ string, _ bool) ([]storage.ObjectInfo, error) {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockFileStorage) RemoveObjects(_ context.Context, _ string, _ []string, _ storage.RemoveObjectsOptions) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockFileStorage) BucketExists(_ context.Context, _ string) (bool, error) {
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockFileStorage) MakeBucket(_ context.Context, _ string, _ storage.MakeBucketOptions) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockFileStorage) StatObject(_ context.Context, _, _ string, _ storage.StatObjectOptions) (storage.ObjectInfo, error) {
|
||||||
|
return storage.ObjectInfo{}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func setupFileTestDB(t *testing.T) *gorm.DB {
|
||||||
|
t.Helper()
|
||||||
|
db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("open sqlite: %v", err)
|
||||||
|
}
|
||||||
|
if err := db.AutoMigrate(&model.File{}, &model.FileBlob{}); err != nil {
|
||||||
|
t.Fatalf("migrate: %v", err)
|
||||||
|
}
|
||||||
|
return db
|
||||||
|
}
|
||||||
|
|
||||||
|
func setupFileService(t *testing.T) (*FileService, *mockFileStorage, *gorm.DB) {
|
||||||
|
t.Helper()
|
||||||
|
db := setupFileTestDB(t)
|
||||||
|
ms := &mockFileStorage{}
|
||||||
|
svc := NewFileService(ms, store.NewBlobStore(db), store.NewFileStore(db), "test-bucket", db, zap.NewNop())
|
||||||
|
return svc, ms, db
|
||||||
|
}
|
||||||
|
|
||||||
|
func createTestBlob(t *testing.T, db *gorm.DB, sha256, minioKey, mimeType string, fileSize int64, refCount int) *model.FileBlob {
|
||||||
|
t.Helper()
|
||||||
|
blob := &model.FileBlob{
|
||||||
|
SHA256: sha256,
|
||||||
|
MinioKey: minioKey,
|
||||||
|
FileSize: fileSize,
|
||||||
|
MimeType: mimeType,
|
||||||
|
RefCount: refCount,
|
||||||
|
}
|
||||||
|
if err := db.Create(blob).Error; err != nil {
|
||||||
|
t.Fatalf("create blob: %v", err)
|
||||||
|
}
|
||||||
|
return blob
|
||||||
|
}
|
||||||
|
|
||||||
|
func createTestFile(t *testing.T, db *gorm.DB, name, blobSHA256 string, folderID *int64) *model.File {
|
||||||
|
t.Helper()
|
||||||
|
file := &model.File{
|
||||||
|
Name: name,
|
||||||
|
FolderID: folderID,
|
||||||
|
BlobSHA256: blobSHA256,
|
||||||
|
}
|
||||||
|
if err := db.Create(file).Error; err != nil {
|
||||||
|
t.Fatalf("create file: %v", err)
|
||||||
|
}
|
||||||
|
return file
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestListFiles_Empty(t *testing.T) {
|
||||||
|
svc, _, _ := setupFileService(t)
|
||||||
|
|
||||||
|
files, total, err := svc.ListFiles(context.Background(), nil, 1, 10, "")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("ListFiles: %v", err)
|
||||||
|
}
|
||||||
|
if total != 0 {
|
||||||
|
t.Errorf("expected total 0, got %d", total)
|
||||||
|
}
|
||||||
|
if len(files) != 0 {
|
||||||
|
t.Errorf("expected empty files, got %d", len(files))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestListFiles_WithFiles(t *testing.T) {
|
||||||
|
svc, _, db := setupFileService(t)
|
||||||
|
|
||||||
|
blob := createTestBlob(t, db, "sha256abc", "blobs/abc", "text/plain", 1024, 2)
|
||||||
|
createTestFile(t, db, "file1.txt", blob.SHA256, nil)
|
||||||
|
createTestFile(t, db, "file2.txt", blob.SHA256, nil)
|
||||||
|
|
||||||
|
files, total, err := svc.ListFiles(context.Background(), nil, 1, 10, "")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("ListFiles: %v", err)
|
||||||
|
}
|
||||||
|
if total != 2 {
|
||||||
|
t.Errorf("expected total 2, got %d", total)
|
||||||
|
}
|
||||||
|
if len(files) != 2 {
|
||||||
|
t.Fatalf("expected 2 files, got %d", len(files))
|
||||||
|
}
|
||||||
|
for _, f := range files {
|
||||||
|
if f.Size != 1024 {
|
||||||
|
t.Errorf("expected size 1024, got %d", f.Size)
|
||||||
|
}
|
||||||
|
if f.MimeType != "text/plain" {
|
||||||
|
t.Errorf("expected mime text/plain, got %s", f.MimeType)
|
||||||
|
}
|
||||||
|
if f.SHA256 != "sha256abc" {
|
||||||
|
t.Errorf("expected sha256 sha256abc, got %s", f.SHA256)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestListFiles_Search(t *testing.T) {
|
||||||
|
svc, _, db := setupFileService(t)
|
||||||
|
|
||||||
|
blob := createTestBlob(t, db, "sha256search", "blobs/search", "image/png", 2048, 1)
|
||||||
|
createTestFile(t, db, "photo.png", blob.SHA256, nil)
|
||||||
|
createTestFile(t, db, "document.pdf", "sha256other", nil)
|
||||||
|
createTestBlob(t, db, "sha256other", "blobs/other", "application/pdf", 512, 1)
|
||||||
|
|
||||||
|
files, total, err := svc.ListFiles(context.Background(), nil, 1, 10, "photo")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("ListFiles: %v", err)
|
||||||
|
}
|
||||||
|
if total != 1 {
|
||||||
|
t.Errorf("expected total 1, got %d", total)
|
||||||
|
}
|
||||||
|
if len(files) != 1 {
|
||||||
|
t.Fatalf("expected 1 file, got %d", len(files))
|
||||||
|
}
|
||||||
|
if files[0].Name != "photo.png" {
|
||||||
|
t.Errorf("expected photo.png, got %s", files[0].Name)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetFileMetadata_Found(t *testing.T) {
|
||||||
|
svc, _, db := setupFileService(t)
|
||||||
|
|
||||||
|
blob := createTestBlob(t, db, "sha256meta", "blobs/meta", "application/json", 42, 1)
|
||||||
|
file := createTestFile(t, db, "data.json", blob.SHA256, nil)
|
||||||
|
|
||||||
|
gotFile, gotBlob, err := svc.GetFileMetadata(context.Background(), file.ID)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("GetFileMetadata: %v", err)
|
||||||
|
}
|
||||||
|
if gotFile.ID != file.ID {
|
||||||
|
t.Errorf("expected file id %d, got %d", file.ID, gotFile.ID)
|
||||||
|
}
|
||||||
|
if gotBlob.SHA256 != blob.SHA256 {
|
||||||
|
t.Errorf("expected blob sha256 %s, got %s", blob.SHA256, gotBlob.SHA256)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetFileMetadata_NotFound(t *testing.T) {
|
||||||
|
svc, _, _ := setupFileService(t)
|
||||||
|
|
||||||
|
_, _, err := svc.GetFileMetadata(context.Background(), 9999)
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected error for missing file")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDownloadFile_Full(t *testing.T) {
|
||||||
|
svc, ms, db := setupFileService(t)
|
||||||
|
|
||||||
|
blob := createTestBlob(t, db, "sha256dl", "blobs/dl", "text/plain", 100, 1)
|
||||||
|
file := createTestFile(t, db, "download.txt", blob.SHA256, nil)
|
||||||
|
|
||||||
|
content := []byte("hello world")
|
||||||
|
ms.getObjectFn = func(_ context.Context, _ string, _ string, opts storage.GetOptions) (io.ReadCloser, storage.ObjectInfo, error) {
|
||||||
|
if opts.Start == nil || opts.End == nil {
|
||||||
|
t.Error("expected start and end to be set")
|
||||||
|
} else if *opts.Start != 0 || *opts.End != 99 {
|
||||||
|
t.Errorf("expected range 0-99, got %d-%d", *opts.Start, *opts.End)
|
||||||
|
}
|
||||||
|
return io.NopCloser(bytes.NewReader(content)), storage.ObjectInfo{}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
reader, gotFile, gotBlob, start, end, err := svc.DownloadFile(context.Background(), file.ID, "")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("DownloadFile: %v", err)
|
||||||
|
}
|
||||||
|
defer reader.Close()
|
||||||
|
|
||||||
|
if gotFile.ID != file.ID {
|
||||||
|
t.Errorf("expected file id %d, got %d", file.ID, gotFile.ID)
|
||||||
|
}
|
||||||
|
if gotBlob.SHA256 != blob.SHA256 {
|
||||||
|
t.Errorf("expected blob sha256 %s, got %s", blob.SHA256, gotBlob.SHA256)
|
||||||
|
}
|
||||||
|
if start != 0 {
|
||||||
|
t.Errorf("expected start 0, got %d", start)
|
||||||
|
}
|
||||||
|
if end != 99 {
|
||||||
|
t.Errorf("expected end 99, got %d", end)
|
||||||
|
}
|
||||||
|
|
||||||
|
data, _ := io.ReadAll(reader)
|
||||||
|
if string(data) != "hello world" {
|
||||||
|
t.Errorf("expected 'hello world', got %q", string(data))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDownloadFile_WithRange(t *testing.T) {
|
||||||
|
svc, ms, db := setupFileService(t)
|
||||||
|
|
||||||
|
blob := createTestBlob(t, db, "sha256range", "blobs/range", "text/plain", 1000, 1)
|
||||||
|
file := createTestFile(t, db, "range.txt", blob.SHA256, nil)
|
||||||
|
|
||||||
|
ms.getObjectFn = func(_ context.Context, _ string, _ string, opts storage.GetOptions) (io.ReadCloser, storage.ObjectInfo, error) {
|
||||||
|
if opts.Start != nil && *opts.Start != 100 {
|
||||||
|
t.Errorf("expected start 100, got %d", *opts.Start)
|
||||||
|
}
|
||||||
|
if opts.End != nil && *opts.End != 199 {
|
||||||
|
t.Errorf("expected end 199, got %d", *opts.End)
|
||||||
|
}
|
||||||
|
return io.NopCloser(strings.NewReader("partial")), storage.ObjectInfo{}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
reader, _, _, start, end, err := svc.DownloadFile(context.Background(), file.ID, "bytes=100-199")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("DownloadFile: %v", err)
|
||||||
|
}
|
||||||
|
defer reader.Close()
|
||||||
|
|
||||||
|
if start != 100 {
|
||||||
|
t.Errorf("expected start 100, got %d", start)
|
||||||
|
}
|
||||||
|
if end != 199 {
|
||||||
|
t.Errorf("expected end 199, got %d", end)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDeleteFile_LastRef(t *testing.T) {
|
||||||
|
svc, ms, db := setupFileService(t)
|
||||||
|
|
||||||
|
blob := createTestBlob(t, db, "sha256del", "blobs/del", "text/plain", 50, 1)
|
||||||
|
file := createTestFile(t, db, "delete-me.txt", blob.SHA256, nil)
|
||||||
|
|
||||||
|
removed := false
|
||||||
|
ms.removeObjectFn = func(_ context.Context, bucket, key string, _ storage.RemoveObjectOptions) error {
|
||||||
|
if bucket != "test-bucket" {
|
||||||
|
t.Errorf("expected bucket 'test-bucket', got %q", bucket)
|
||||||
|
}
|
||||||
|
if key != "blobs/del" {
|
||||||
|
t.Errorf("expected key 'blobs/del', got %q", key)
|
||||||
|
}
|
||||||
|
removed = true
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := svc.DeleteFile(context.Background(), file.ID); err != nil {
|
||||||
|
t.Fatalf("DeleteFile: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !removed {
|
||||||
|
t.Error("expected RemoveObject to be called")
|
||||||
|
}
|
||||||
|
|
||||||
|
var count int64
|
||||||
|
db.Model(&model.FileBlob{}).Where("sha256 = ?", "sha256del").Count(&count)
|
||||||
|
if count != 0 {
|
||||||
|
t.Errorf("expected blob to be hard deleted, found %d records", count)
|
||||||
|
}
|
||||||
|
|
||||||
|
var fileCount int64
|
||||||
|
db.Unscoped().Model(&model.File{}).Where("id = ?", file.ID).Count(&fileCount)
|
||||||
|
if fileCount != 1 {
|
||||||
|
t.Errorf("expected file to still exist (soft deleted), found %d", fileCount)
|
||||||
|
}
|
||||||
|
|
||||||
|
var deletedFile model.File
|
||||||
|
db.Unscoped().First(&deletedFile, file.ID)
|
||||||
|
if deletedFile.DeletedAt.Time.IsZero() {
|
||||||
|
t.Error("expected deleted_at to be set")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDeleteFile_OtherRefsExist(t *testing.T) {
|
||||||
|
svc, ms, db := setupFileService(t)
|
||||||
|
|
||||||
|
blob := createTestBlob(t, db, "sha256multi", "blobs/multi", "text/plain", 50, 3)
|
||||||
|
file1 := createTestFile(t, db, "ref1.txt", blob.SHA256, nil)
|
||||||
|
createTestFile(t, db, "ref2.txt", blob.SHA256, nil)
|
||||||
|
createTestFile(t, db, "ref3.txt", blob.SHA256, nil)
|
||||||
|
|
||||||
|
removed := false
|
||||||
|
ms.removeObjectFn = func(_ context.Context, _, _ string, _ storage.RemoveObjectOptions) error {
|
||||||
|
removed = true
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := svc.DeleteFile(context.Background(), file1.ID); err != nil {
|
||||||
|
t.Fatalf("DeleteFile: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if removed {
|
||||||
|
t.Error("expected RemoveObject NOT to be called since other refs exist")
|
||||||
|
}
|
||||||
|
|
||||||
|
var updatedBlob model.FileBlob
|
||||||
|
db.Where("sha256 = ?", "sha256multi").First(&updatedBlob)
|
||||||
|
if updatedBlob.RefCount != 3 {
|
||||||
|
t.Errorf("expected ref_count to remain 3, got %d", updatedBlob.RefCount)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDeleteFile_SoftDeleteNotAffectRefcount(t *testing.T) {
|
||||||
|
svc, _, db := setupFileService(t)
|
||||||
|
|
||||||
|
blob := createTestBlob(t, db, "sha256soft", "blobs/soft", "text/plain", 50, 2)
|
||||||
|
file1 := createTestFile(t, db, "soft1.txt", blob.SHA256, nil)
|
||||||
|
file2 := createTestFile(t, db, "soft2.txt", blob.SHA256, nil)
|
||||||
|
|
||||||
|
if err := svc.DeleteFile(context.Background(), file1.ID); err != nil {
|
||||||
|
t.Fatalf("DeleteFile: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var updatedBlob model.FileBlob
|
||||||
|
db.Where("sha256 = ?", "sha256soft").First(&updatedBlob)
|
||||||
|
if updatedBlob.RefCount != 2 {
|
||||||
|
t.Errorf("expected ref_count to remain 2 (soft delete should not decrement), got %d", updatedBlob.RefCount)
|
||||||
|
}
|
||||||
|
|
||||||
|
activeCount, err := store.NewFileStore(db).CountByBlobSHA256(context.Background(), "sha256soft")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("CountByBlobSHA256: %v", err)
|
||||||
|
}
|
||||||
|
if activeCount != 1 {
|
||||||
|
t.Errorf("expected 1 active ref after soft delete, got %d", activeCount)
|
||||||
|
}
|
||||||
|
|
||||||
|
var allFiles []model.File
|
||||||
|
db.Unscoped().Where("blob_sha256 = ?", "sha256soft").Find(&allFiles)
|
||||||
|
if len(allFiles) != 2 {
|
||||||
|
t.Errorf("expected 2 total files (one soft deleted), got %d", len(allFiles))
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := svc.DeleteFile(context.Background(), file2.ID); err != nil {
|
||||||
|
t.Fatalf("DeleteFile second: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
activeCount2, err := store.NewFileStore(db).CountByBlobSHA256(context.Background(), "sha256soft")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("CountByBlobSHA256 after second delete: %v", err)
|
||||||
|
}
|
||||||
|
if activeCount2 != 0 {
|
||||||
|
t.Errorf("expected 0 active refs after both deleted, got %d", activeCount2)
|
||||||
|
}
|
||||||
|
|
||||||
|
var finalBlob model.FileBlob
|
||||||
|
db.Where("sha256 = ?", "sha256soft").First(&finalBlob)
|
||||||
|
if finalBlob.RefCount != 1 {
|
||||||
|
t.Errorf("expected ref_count=1 (decremented once from 2), got %d", finalBlob.RefCount)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDeleteFile_NotFound(t *testing.T) {
|
||||||
|
svc, _, _ := setupFileService(t)
|
||||||
|
|
||||||
|
err := svc.DeleteFile(context.Background(), 9999)
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected error for missing file")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDownloadFile_StorageError(t *testing.T) {
|
||||||
|
svc, ms, db := setupFileService(t)
|
||||||
|
|
||||||
|
blob := createTestBlob(t, db, "sha256err", "blobs/err", "text/plain", 100, 1)
|
||||||
|
file := createTestFile(t, db, "error.txt", blob.SHA256, nil)
|
||||||
|
|
||||||
|
ms.getObjectFn = func(_ context.Context, _ string, _ string, _ storage.GetOptions) (io.ReadCloser, storage.ObjectInfo, error) {
|
||||||
|
return nil, storage.ObjectInfo{}, errors.New("storage unavailable")
|
||||||
|
}
|
||||||
|
|
||||||
|
_, _, _, _, _, err := svc.DownloadFile(context.Background(), file.ID, "")
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected error from storage")
|
||||||
|
}
|
||||||
|
if !strings.Contains(err.Error(), "storage unavailable") {
|
||||||
|
t.Errorf("expected storage error, got: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestListFiles_WithFolderFilter(t *testing.T) {
|
||||||
|
svc, _, db := setupFileService(t)
|
||||||
|
|
||||||
|
blob := createTestBlob(t, db, "sha256folder", "blobs/folder", "text/plain", 100, 2)
|
||||||
|
folderID := int64(1)
|
||||||
|
createTestFile(t, db, "in_folder.txt", blob.SHA256, &folderID)
|
||||||
|
createTestFile(t, db, "root.txt", blob.SHA256, nil)
|
||||||
|
|
||||||
|
files, total, err := svc.ListFiles(context.Background(), &folderID, 1, 10, "")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("ListFiles: %v", err)
|
||||||
|
}
|
||||||
|
if total != 1 {
|
||||||
|
t.Errorf("expected total 1, got %d", total)
|
||||||
|
}
|
||||||
|
if len(files) != 1 {
|
||||||
|
t.Fatalf("expected 1 file, got %d", len(files))
|
||||||
|
}
|
||||||
|
if files[0].Name != "in_folder.txt" {
|
||||||
|
t.Errorf("expected in_folder.txt, got %s", files[0].Name)
|
||||||
|
}
|
||||||
|
|
||||||
|
rootFiles, rootTotal, err := svc.ListFiles(context.Background(), nil, 1, 10, "")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("ListFiles root: %v", err)
|
||||||
|
}
|
||||||
|
if rootTotal != 1 {
|
||||||
|
t.Errorf("expected root total 1, got %d", rootTotal)
|
||||||
|
}
|
||||||
|
if len(rootFiles) != 1 || rootFiles[0].Name != "root.txt" {
|
||||||
|
t.Errorf("expected root.txt in root listing")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetFileMetadata_BlobMissing(t *testing.T) {
|
||||||
|
svc, _, db := setupFileService(t)
|
||||||
|
|
||||||
|
file := createTestFile(t, db, "orphan.txt", "nonexistent_sha256", nil)
|
||||||
|
|
||||||
|
_, _, err := svc.GetFileMetadata(context.Background(), file.ID)
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected error when blob is missing")
|
||||||
|
}
|
||||||
|
}
|
||||||
145
internal/service/file_staging_service.go
Normal file
145
internal/service/file_staging_service.go
Normal file
@@ -0,0 +1,145 @@
|
|||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
|
||||||
|
"gcy_hpc_server/internal/model"
|
||||||
|
"gcy_hpc_server/internal/storage"
|
||||||
|
"gcy_hpc_server/internal/store"
|
||||||
|
|
||||||
|
"go.uber.org/zap"
|
||||||
|
)
|
||||||
|
|
||||||
|
// FileStagingService batch downloads files from MinIO to a local (NFS) directory,
|
||||||
|
// deduplicating by blob SHA256 so each unique blob is fetched only once.
|
||||||
|
type FileStagingService struct {
|
||||||
|
fileStore *store.FileStore
|
||||||
|
blobStore *store.BlobStore
|
||||||
|
storage storage.ObjectStorage
|
||||||
|
bucket string
|
||||||
|
logger *zap.Logger
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewFileStagingService(fileStore *store.FileStore, blobStore *store.BlobStore, st storage.ObjectStorage, bucket string, logger *zap.Logger) *FileStagingService {
|
||||||
|
return &FileStagingService{
|
||||||
|
fileStore: fileStore,
|
||||||
|
blobStore: blobStore,
|
||||||
|
storage: st,
|
||||||
|
bucket: bucket,
|
||||||
|
logger: logger,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// DownloadFilesToDir downloads the given files into destDir.
|
||||||
|
// Files sharing the same blob SHA256 are deduplicated: the blob is fetched once
|
||||||
|
// and then copied to each filename. Filenames are sanitized with filepath.Base
|
||||||
|
// to prevent path traversal.
|
||||||
|
func (s *FileStagingService) DownloadFilesToDir(ctx context.Context, fileIDs []int64, destDir string) error {
|
||||||
|
if len(fileIDs) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
files, err := s.fileStore.GetByIDs(ctx, fileIDs)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("fetch files: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
type group struct {
|
||||||
|
primary *model.File // first file — written via io.Copy from MinIO
|
||||||
|
others []*model.File // remaining files — local copy of primary
|
||||||
|
}
|
||||||
|
groups := make(map[string]*group)
|
||||||
|
for i := range files {
|
||||||
|
f := &files[i]
|
||||||
|
g, ok := groups[f.BlobSHA256]
|
||||||
|
if !ok {
|
||||||
|
groups[f.BlobSHA256] = &group{primary: f}
|
||||||
|
} else {
|
||||||
|
g.others = append(g.others, f)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
sha256s := make([]string, 0, len(groups))
|
||||||
|
for sh := range groups {
|
||||||
|
sha256s = append(sha256s, sh)
|
||||||
|
}
|
||||||
|
blobs, err := s.blobStore.GetBySHA256s(ctx, sha256s)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("fetch blobs: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
blobMap := make(map[string]*model.FileBlob, len(blobs))
|
||||||
|
for i := range blobs {
|
||||||
|
blobMap[blobs[i].SHA256] = &blobs[i]
|
||||||
|
}
|
||||||
|
|
||||||
|
for sha256, g := range groups {
|
||||||
|
blob, ok := blobMap[sha256]
|
||||||
|
if !ok {
|
||||||
|
return fmt.Errorf("blob %s not found", sha256)
|
||||||
|
}
|
||||||
|
|
||||||
|
reader, _, err := s.storage.GetObject(ctx, s.bucket, blob.MinioKey, storage.GetOptions{})
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("get object %s: %w", blob.MinioKey, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO: handle filename collisions when multiple files have the same Name (low risk without user auth, revisit when auth is added)
|
||||||
|
primaryName := filepath.Base(g.primary.Name)
|
||||||
|
primaryPath := filepath.Join(destDir, primaryName)
|
||||||
|
|
||||||
|
if err := writeFile(primaryPath, reader); err != nil {
|
||||||
|
reader.Close()
|
||||||
|
os.Remove(primaryPath)
|
||||||
|
return fmt.Errorf("write file %s: %w", primaryName, err)
|
||||||
|
}
|
||||||
|
reader.Close()
|
||||||
|
|
||||||
|
for _, other := range g.others {
|
||||||
|
otherName := filepath.Base(other.Name)
|
||||||
|
otherPath := filepath.Join(destDir, otherName)
|
||||||
|
|
||||||
|
if err := copyFile(primaryPath, otherPath); err != nil {
|
||||||
|
return fmt.Errorf("copy %s to %s: %w", primaryName, otherName, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func writeFile(path string, reader io.Reader) error {
|
||||||
|
f, err := os.Create(path)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer f.Close()
|
||||||
|
|
||||||
|
if _, err := io.Copy(f, reader); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func copyFile(src, dst string) error {
|
||||||
|
in, err := os.Open(src)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer in.Close()
|
||||||
|
|
||||||
|
out, err := os.Create(dst)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer out.Close()
|
||||||
|
|
||||||
|
if _, err := io.Copy(out, in); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
232
internal/service/file_staging_service_test.go
Normal file
232
internal/service/file_staging_service_test.go
Normal file
@@ -0,0 +1,232 @@
|
|||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"strings"
|
||||||
|
"sync/atomic"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"gcy_hpc_server/internal/model"
|
||||||
|
"gcy_hpc_server/internal/storage"
|
||||||
|
"gcy_hpc_server/internal/store"
|
||||||
|
|
||||||
|
"go.uber.org/zap"
|
||||||
|
"gorm.io/driver/sqlite"
|
||||||
|
"gorm.io/gorm"
|
||||||
|
)
|
||||||
|
|
||||||
|
type stagingMockStorage struct {
|
||||||
|
getObjectFn func(ctx context.Context, bucket, key string, opts storage.GetOptions) (io.ReadCloser, storage.ObjectInfo, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *stagingMockStorage) GetObject(ctx context.Context, bucket, key string, opts storage.GetOptions) (io.ReadCloser, storage.ObjectInfo, error) {
|
||||||
|
if m.getObjectFn != nil {
|
||||||
|
return m.getObjectFn(ctx, bucket, key, opts)
|
||||||
|
}
|
||||||
|
return nil, storage.ObjectInfo{}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *stagingMockStorage) PutObject(ctx context.Context, bucket, key string, reader io.Reader, size int64, opts storage.PutObjectOptions) (storage.UploadInfo, error) {
|
||||||
|
return storage.UploadInfo{}, nil
|
||||||
|
}
|
||||||
|
func (m *stagingMockStorage) ComposeObject(ctx context.Context, bucket, dst string, sources []string) (storage.UploadInfo, error) {
|
||||||
|
return storage.UploadInfo{}, nil
|
||||||
|
}
|
||||||
|
func (m *stagingMockStorage) AbortMultipartUpload(ctx context.Context, bucket, object, uploadID string) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
func (m *stagingMockStorage) RemoveIncompleteUpload(ctx context.Context, bucket, object string) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
func (m *stagingMockStorage) RemoveObject(ctx context.Context, bucket, key string, opts storage.RemoveObjectOptions) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
func (m *stagingMockStorage) ListObjects(ctx context.Context, bucket, prefix string, recursive bool) ([]storage.ObjectInfo, error) {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
func (m *stagingMockStorage) RemoveObjects(ctx context.Context, bucket string, keys []string, opts storage.RemoveObjectsOptions) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
func (m *stagingMockStorage) BucketExists(ctx context.Context, bucket string) (bool, error) {
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
func (m *stagingMockStorage) MakeBucket(ctx context.Context, bucket string, opts storage.MakeBucketOptions) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
func (m *stagingMockStorage) StatObject(ctx context.Context, bucket, key string, opts storage.StatObjectOptions) (storage.ObjectInfo, error) {
|
||||||
|
return storage.ObjectInfo{}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func setupStagingTestDB(t *testing.T) *gorm.DB {
|
||||||
|
t.Helper()
|
||||||
|
db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("open sqlite: %v", err)
|
||||||
|
}
|
||||||
|
if err := db.AutoMigrate(&model.FileBlob{}, &model.File{}); err != nil {
|
||||||
|
t.Fatalf("migrate: %v", err)
|
||||||
|
}
|
||||||
|
return db
|
||||||
|
}
|
||||||
|
|
||||||
|
func newStagingService(t *testing.T, st storage.ObjectStorage, db *gorm.DB) *FileStagingService {
|
||||||
|
t.Helper()
|
||||||
|
return NewFileStagingService(
|
||||||
|
store.NewFileStore(db),
|
||||||
|
store.NewBlobStore(db),
|
||||||
|
st,
|
||||||
|
"test-bucket",
|
||||||
|
zap.NewNop(),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFileStaging_DownloadWithDedup(t *testing.T) {
|
||||||
|
db := setupStagingTestDB(t)
|
||||||
|
|
||||||
|
sha1 := "aaa111"
|
||||||
|
sha2 := "bbb222"
|
||||||
|
|
||||||
|
db.Create(&model.FileBlob{SHA256: sha1, MinioKey: "blobs/aaa111", FileSize: 5, MimeType: "text/plain", RefCount: 2})
|
||||||
|
db.Create(&model.FileBlob{SHA256: sha2, MinioKey: "blobs/bbb222", FileSize: 3, MimeType: "text/plain", RefCount: 1})
|
||||||
|
|
||||||
|
db.Create(&model.File{Name: "file1.txt", BlobSHA256: sha1})
|
||||||
|
db.Create(&model.File{Name: "file2.txt", BlobSHA256: sha1})
|
||||||
|
db.Create(&model.File{Name: "file3.txt", BlobSHA256: sha2})
|
||||||
|
|
||||||
|
var files []model.File
|
||||||
|
db.Find(&files)
|
||||||
|
if len(files) < 3 {
|
||||||
|
t.Fatalf("need 3 files, got %d", len(files))
|
||||||
|
}
|
||||||
|
|
||||||
|
var getObjCalls int32
|
||||||
|
st := &stagingMockStorage{}
|
||||||
|
st.getObjectFn = func(ctx context.Context, bucket, key string, opts storage.GetOptions) (io.ReadCloser, storage.ObjectInfo, error) {
|
||||||
|
atomic.AddInt32(&getObjCalls, 1)
|
||||||
|
var content string
|
||||||
|
switch key {
|
||||||
|
case "blobs/aaa111":
|
||||||
|
content = "content-a"
|
||||||
|
case "blobs/bbb222":
|
||||||
|
content = "content-b"
|
||||||
|
default:
|
||||||
|
return nil, storage.ObjectInfo{}, fmt.Errorf("unexpected key %s", key)
|
||||||
|
}
|
||||||
|
return io.NopCloser(bytes.NewReader([]byte(content))), storage.ObjectInfo{Key: key}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
destDir := t.TempDir()
|
||||||
|
svc := newStagingService(t, st, db)
|
||||||
|
|
||||||
|
err := svc.DownloadFilesToDir(context.Background(), []int64{files[0].ID, files[1].ID, files[2].ID}, destDir)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("DownloadFilesToDir: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if calls := atomic.LoadInt32(&getObjCalls); calls != 2 {
|
||||||
|
t.Errorf("GetObject called %d times, want 2", calls)
|
||||||
|
}
|
||||||
|
|
||||||
|
expected := map[string]string{
|
||||||
|
"file1.txt": "content-a",
|
||||||
|
"file2.txt": "content-a",
|
||||||
|
"file3.txt": "content-b",
|
||||||
|
}
|
||||||
|
for name, want := range expected {
|
||||||
|
p := filepath.Join(destDir, name)
|
||||||
|
data, err := os.ReadFile(p)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("read %s: %v", name, err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if string(data) != want {
|
||||||
|
t.Errorf("%s content = %q, want %q", name, data, want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFileStaging_PathTraversal(t *testing.T) {
|
||||||
|
db := setupStagingTestDB(t)
|
||||||
|
|
||||||
|
sha := "traversal123"
|
||||||
|
db.Create(&model.FileBlob{SHA256: sha, MinioKey: "blobs/traversal", FileSize: 4, MimeType: "text/plain", RefCount: 1})
|
||||||
|
db.Create(&model.File{Name: "../../../etc/passwd", BlobSHA256: sha})
|
||||||
|
|
||||||
|
var file model.File
|
||||||
|
db.First(&file)
|
||||||
|
|
||||||
|
st := &stagingMockStorage{}
|
||||||
|
st.getObjectFn = func(ctx context.Context, bucket, key string, opts storage.GetOptions) (io.ReadCloser, storage.ObjectInfo, error) {
|
||||||
|
return io.NopCloser(bytes.NewReader([]byte("safe"))), storage.ObjectInfo{Key: key}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
destDir := t.TempDir()
|
||||||
|
svc := newStagingService(t, st, db)
|
||||||
|
|
||||||
|
err := svc.DownloadFilesToDir(context.Background(), []int64{file.ID}, destDir)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("DownloadFilesToDir: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
sanitized := filepath.Join(destDir, "passwd")
|
||||||
|
data, err := os.ReadFile(sanitized)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("read sanitized file: %v", err)
|
||||||
|
}
|
||||||
|
if string(data) != "safe" {
|
||||||
|
t.Errorf("content = %q, want %q", data, "safe")
|
||||||
|
}
|
||||||
|
|
||||||
|
entries, err := os.ReadDir(destDir)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("readdir: %v", err)
|
||||||
|
}
|
||||||
|
for _, e := range entries {
|
||||||
|
if e.Name() != "passwd" {
|
||||||
|
t.Errorf("unexpected file in destDir: %s", e.Name())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFileStaging_EmptyList(t *testing.T) {
|
||||||
|
db := setupStagingTestDB(t)
|
||||||
|
st := &stagingMockStorage{}
|
||||||
|
svc := newStagingService(t, st, db)
|
||||||
|
|
||||||
|
err := svc.DownloadFilesToDir(context.Background(), []int64{}, t.TempDir())
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("expected nil for empty list, got %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFileStaging_GetObjectFails(t *testing.T) {
|
||||||
|
db := setupStagingTestDB(t)
|
||||||
|
|
||||||
|
sha := "fail123"
|
||||||
|
db.Create(&model.FileBlob{SHA256: sha, MinioKey: "blobs/fail", FileSize: 5, MimeType: "text/plain", RefCount: 1})
|
||||||
|
db.Create(&model.File{Name: "willfail.txt", BlobSHA256: sha})
|
||||||
|
|
||||||
|
var file model.File
|
||||||
|
db.First(&file)
|
||||||
|
|
||||||
|
st := &stagingMockStorage{}
|
||||||
|
st.getObjectFn = func(ctx context.Context, bucket, key string, opts storage.GetOptions) (io.ReadCloser, storage.ObjectInfo, error) {
|
||||||
|
return nil, storage.ObjectInfo{}, fmt.Errorf("minio down")
|
||||||
|
}
|
||||||
|
|
||||||
|
destDir := t.TempDir()
|
||||||
|
svc := newStagingService(t, st, db)
|
||||||
|
|
||||||
|
err := svc.DownloadFilesToDir(context.Background(), []int64{file.ID}, destDir)
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected error when GetObject fails")
|
||||||
|
}
|
||||||
|
if !strings.Contains(err.Error(), "minio down") {
|
||||||
|
t.Errorf("error = %q, want 'minio down'", err.Error())
|
||||||
|
}
|
||||||
|
}
|
||||||
142
internal/service/folder_service.go
Normal file
142
internal/service/folder_service.go
Normal file
@@ -0,0 +1,142 @@
|
|||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
"gcy_hpc_server/internal/model"
|
||||||
|
"gcy_hpc_server/internal/store"
|
||||||
|
|
||||||
|
"go.uber.org/zap"
|
||||||
|
)
|
||||||
|
|
||||||
|
// FolderService provides CRUD operations for folders with path validation
|
||||||
|
// and directory tree management.
|
||||||
|
type FolderService struct {
|
||||||
|
folderStore *store.FolderStore
|
||||||
|
fileStore *store.FileStore
|
||||||
|
logger *zap.Logger
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewFolderService creates a new FolderService.
|
||||||
|
func NewFolderService(folderStore *store.FolderStore, fileStore *store.FileStore, logger *zap.Logger) *FolderService {
|
||||||
|
return &FolderService{
|
||||||
|
folderStore: folderStore,
|
||||||
|
fileStore: fileStore,
|
||||||
|
logger: logger,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// CreateFolder validates the name, computes a materialized path, checks for
|
||||||
|
// duplicates, and persists the folder.
|
||||||
|
func (s *FolderService) CreateFolder(ctx context.Context, name string, parentID *int64) (*model.FolderResponse, error) {
|
||||||
|
if err := model.ValidateFolderName(name); err != nil {
|
||||||
|
return nil, fmt.Errorf("invalid folder name: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var path string
|
||||||
|
if parentID == nil {
|
||||||
|
path = "/" + name + "/"
|
||||||
|
} else {
|
||||||
|
parent, err := s.folderStore.GetByID(ctx, *parentID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("get parent folder: %w", err)
|
||||||
|
}
|
||||||
|
if parent == nil {
|
||||||
|
return nil, fmt.Errorf("parent folder %d not found", *parentID)
|
||||||
|
}
|
||||||
|
path = parent.Path + name + "/"
|
||||||
|
}
|
||||||
|
|
||||||
|
existing, err := s.folderStore.GetByPath(ctx, path)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("check duplicate path: %w", err)
|
||||||
|
}
|
||||||
|
if existing != nil {
|
||||||
|
return nil, fmt.Errorf("folder with path %q already exists", path)
|
||||||
|
}
|
||||||
|
|
||||||
|
folder := &model.Folder{
|
||||||
|
Name: name,
|
||||||
|
ParentID: parentID,
|
||||||
|
Path: path,
|
||||||
|
}
|
||||||
|
if err := s.folderStore.Create(ctx, folder); err != nil {
|
||||||
|
return nil, fmt.Errorf("create folder: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return s.toFolderResponse(ctx, folder)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetFolder retrieves a folder by ID with file and subfolder counts.
|
||||||
|
func (s *FolderService) GetFolder(ctx context.Context, id int64) (*model.FolderResponse, error) {
|
||||||
|
folder, err := s.folderStore.GetByID(ctx, id)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("get folder: %w", err)
|
||||||
|
}
|
||||||
|
if folder == nil {
|
||||||
|
return nil, fmt.Errorf("folder %d not found", id)
|
||||||
|
}
|
||||||
|
|
||||||
|
return s.toFolderResponse(ctx, folder)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ListFolders returns all direct children of the given parent folder (or root
|
||||||
|
// if parentID is nil).
|
||||||
|
func (s *FolderService) ListFolders(ctx context.Context, parentID *int64) ([]model.FolderResponse, error) {
|
||||||
|
folders, err := s.folderStore.ListByParentID(ctx, parentID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("list folders: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
result := make([]model.FolderResponse, 0, len(folders))
|
||||||
|
for i := range folders {
|
||||||
|
resp, err := s.toFolderResponse(ctx, &folders[i])
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
result = append(result, *resp)
|
||||||
|
}
|
||||||
|
return result, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeleteFolder soft-deletes a folder only if it has no children (sub-folders
|
||||||
|
// or files).
|
||||||
|
func (s *FolderService) DeleteFolder(ctx context.Context, id int64) error {
|
||||||
|
hasChildren, err := s.folderStore.HasChildren(ctx, id)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("check children: %w", err)
|
||||||
|
}
|
||||||
|
if hasChildren {
|
||||||
|
return fmt.Errorf("folder is not empty")
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := s.folderStore.Delete(ctx, id); err != nil {
|
||||||
|
return fmt.Errorf("delete folder: %w", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// toFolderResponse converts a Folder model into a FolderResponse DTO with
|
||||||
|
// computed file and subfolder counts.
|
||||||
|
func (s *FolderService) toFolderResponse(ctx context.Context, f *model.Folder) (*model.FolderResponse, error) {
|
||||||
|
subFolders, err := s.folderStore.ListByParentID(ctx, &f.ID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("count subfolders: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
_, fileCount, err := s.fileStore.List(ctx, &f.ID, 1, 1)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("count files: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return &model.FolderResponse{
|
||||||
|
ID: f.ID,
|
||||||
|
Name: f.Name,
|
||||||
|
ParentID: f.ParentID,
|
||||||
|
Path: f.Path,
|
||||||
|
FileCount: fileCount,
|
||||||
|
SubFolderCount: int64(len(subFolders)),
|
||||||
|
CreatedAt: f.CreatedAt,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
230
internal/service/folder_service_test.go
Normal file
230
internal/service/folder_service_test.go
Normal file
@@ -0,0 +1,230 @@
|
|||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"gcy_hpc_server/internal/model"
|
||||||
|
"gcy_hpc_server/internal/store"
|
||||||
|
|
||||||
|
"go.uber.org/zap"
|
||||||
|
gormlogger "gorm.io/gorm/logger"
|
||||||
|
|
||||||
|
"gorm.io/driver/sqlite"
|
||||||
|
"gorm.io/gorm"
|
||||||
|
)
|
||||||
|
|
||||||
|
func setupFolderService(t *testing.T) *FolderService {
|
||||||
|
t.Helper()
|
||||||
|
db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{
|
||||||
|
Logger: gormlogger.Default.LogMode(gormlogger.Silent),
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("open sqlite: %v", err)
|
||||||
|
}
|
||||||
|
if err := db.AutoMigrate(&model.Folder{}, &model.File{}); err != nil {
|
||||||
|
t.Fatalf("auto migrate: %v", err)
|
||||||
|
}
|
||||||
|
return NewFolderService(
|
||||||
|
store.NewFolderStore(db),
|
||||||
|
store.NewFileStore(db),
|
||||||
|
zap.NewNop(),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCreateFolder_ValidName(t *testing.T) {
|
||||||
|
svc := setupFolderService(t)
|
||||||
|
resp, err := svc.CreateFolder(context.Background(), "datasets", nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("CreateFolder() error = %v", err)
|
||||||
|
}
|
||||||
|
if resp.Name != "datasets" {
|
||||||
|
t.Errorf("Name = %q, want %q", resp.Name, "datasets")
|
||||||
|
}
|
||||||
|
if resp.Path != "/datasets/" {
|
||||||
|
t.Errorf("Path = %q, want %q", resp.Path, "/datasets/")
|
||||||
|
}
|
||||||
|
if resp.ParentID != nil {
|
||||||
|
t.Errorf("ParentID should be nil for root folder, got %d", *resp.ParentID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCreateFolder_SubFolder(t *testing.T) {
|
||||||
|
svc := setupFolderService(t)
|
||||||
|
parent, err := svc.CreateFolder(context.Background(), "datasets", nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("create parent: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
child, err := svc.CreateFolder(context.Background(), "images", &parent.ID)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("CreateFolder() error = %v", err)
|
||||||
|
}
|
||||||
|
if child.Path != "/datasets/images/" {
|
||||||
|
t.Errorf("Path = %q, want %q", child.Path, "/datasets/images/")
|
||||||
|
}
|
||||||
|
if child.ParentID == nil || *child.ParentID != parent.ID {
|
||||||
|
t.Errorf("ParentID = %v, want %d", child.ParentID, parent.ID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCreateFolder_RejectPathTraversal(t *testing.T) {
|
||||||
|
svc := setupFolderService(t)
|
||||||
|
for _, name := range []string{"..", "../etc", "/absolute", "a/b"} {
|
||||||
|
_, err := svc.CreateFolder(context.Background(), name, nil)
|
||||||
|
if err == nil {
|
||||||
|
t.Errorf("expected error for folder name %q, got nil", name)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCreateFolder_DuplicatePath(t *testing.T) {
|
||||||
|
svc := setupFolderService(t)
|
||||||
|
_, err := svc.CreateFolder(context.Background(), "datasets", nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("first create: %v", err)
|
||||||
|
}
|
||||||
|
_, err = svc.CreateFolder(context.Background(), "datasets", nil)
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected error for duplicate folder name")
|
||||||
|
}
|
||||||
|
if !strings.Contains(err.Error(), "already exists") {
|
||||||
|
t.Errorf("error should mention 'already exists', got: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCreateFolder_ParentNotFound(t *testing.T) {
|
||||||
|
svc := setupFolderService(t)
|
||||||
|
badID := int64(99999)
|
||||||
|
_, err := svc.CreateFolder(context.Background(), "orphan", &badID)
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected error for non-existent parent")
|
||||||
|
}
|
||||||
|
if !strings.Contains(err.Error(), "not found") {
|
||||||
|
t.Errorf("error should mention 'not found', got: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetFolder(t *testing.T) {
|
||||||
|
svc := setupFolderService(t)
|
||||||
|
created, err := svc.CreateFolder(context.Background(), "datasets", nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("CreateFolder() error = %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := svc.GetFolder(context.Background(), created.ID)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("GetFolder() error = %v", err)
|
||||||
|
}
|
||||||
|
if resp.ID != created.ID {
|
||||||
|
t.Errorf("ID = %d, want %d", resp.ID, created.ID)
|
||||||
|
}
|
||||||
|
if resp.Name != "datasets" {
|
||||||
|
t.Errorf("Name = %q, want %q", resp.Name, "datasets")
|
||||||
|
}
|
||||||
|
if resp.FileCount != 0 {
|
||||||
|
t.Errorf("FileCount = %d, want 0", resp.FileCount)
|
||||||
|
}
|
||||||
|
if resp.SubFolderCount != 0 {
|
||||||
|
t.Errorf("SubFolderCount = %d, want 0", resp.SubFolderCount)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetFolder_NotFound(t *testing.T) {
|
||||||
|
svc := setupFolderService(t)
|
||||||
|
_, err := svc.GetFolder(context.Background(), 99999)
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected error for non-existent folder")
|
||||||
|
}
|
||||||
|
if !strings.Contains(err.Error(), "not found") {
|
||||||
|
t.Errorf("error should mention 'not found', got: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestListFolders(t *testing.T) {
|
||||||
|
svc := setupFolderService(t)
|
||||||
|
parent, err := svc.CreateFolder(context.Background(), "root", nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("create root: %v", err)
|
||||||
|
}
|
||||||
|
_, err = svc.CreateFolder(context.Background(), "child1", &parent.ID)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("create child1: %v", err)
|
||||||
|
}
|
||||||
|
_, err = svc.CreateFolder(context.Background(), "child2", &parent.ID)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("create child2: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
list, err := svc.ListFolders(context.Background(), &parent.ID)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("ListFolders() error = %v", err)
|
||||||
|
}
|
||||||
|
if len(list) != 2 {
|
||||||
|
t.Fatalf("len(list) = %d, want 2", len(list))
|
||||||
|
}
|
||||||
|
|
||||||
|
names := make(map[string]bool)
|
||||||
|
for _, f := range list {
|
||||||
|
names[f.Name] = true
|
||||||
|
}
|
||||||
|
if !names["child1"] || !names["child2"] {
|
||||||
|
t.Errorf("expected child1 and child2, got %v", names)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestListFolders_Root(t *testing.T) {
|
||||||
|
svc := setupFolderService(t)
|
||||||
|
_, err := svc.CreateFolder(context.Background(), "alpha", nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("create alpha: %v", err)
|
||||||
|
}
|
||||||
|
_, err = svc.CreateFolder(context.Background(), "beta", nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("create beta: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
list, err := svc.ListFolders(context.Background(), nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("ListFolders() error = %v", err)
|
||||||
|
}
|
||||||
|
if len(list) != 2 {
|
||||||
|
t.Errorf("len(list) = %d, want 2", len(list))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDeleteFolder_Success(t *testing.T) {
|
||||||
|
svc := setupFolderService(t)
|
||||||
|
created, err := svc.CreateFolder(context.Background(), "temp", nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("CreateFolder() error = %v", err)
|
||||||
|
}
|
||||||
|
if err := svc.DeleteFolder(context.Background(), created.ID); err != nil {
|
||||||
|
t.Fatalf("DeleteFolder() error = %v", err)
|
||||||
|
}
|
||||||
|
_, err = svc.GetFolder(context.Background(), created.ID)
|
||||||
|
if err == nil {
|
||||||
|
t.Error("expected error after deletion")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDeleteFolder_NonEmpty(t *testing.T) {
|
||||||
|
svc := setupFolderService(t)
|
||||||
|
parent, err := svc.CreateFolder(context.Background(), "haschild", nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("create parent: %v", err)
|
||||||
|
}
|
||||||
|
_, err = svc.CreateFolder(context.Background(), "child", &parent.ID)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("create child: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
err = svc.DeleteFolder(context.Background(), parent.ID)
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected error when deleting non-empty folder")
|
||||||
|
}
|
||||||
|
if !strings.Contains(err.Error(), "not empty") {
|
||||||
|
t.Errorf("error should mention 'not empty', got: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -4,6 +4,7 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"strconv"
|
"strconv"
|
||||||
|
"time"
|
||||||
|
|
||||||
"gcy_hpc_server/internal/model"
|
"gcy_hpc_server/internal/model"
|
||||||
"gcy_hpc_server/internal/slurm"
|
"gcy_hpc_server/internal/slurm"
|
||||||
@@ -31,6 +32,9 @@ func (s *JobService) SubmitJob(ctx context.Context, req *model.SubmitJobRequest)
|
|||||||
Qos: strToPtrOrNil(req.QOS),
|
Qos: strToPtrOrNil(req.QOS),
|
||||||
Name: strToPtrOrNil(req.JobName),
|
Name: strToPtrOrNil(req.JobName),
|
||||||
}
|
}
|
||||||
|
if req.WorkDir != "" {
|
||||||
|
jobDesc.CurrentWorkingDirectory = &req.WorkDir
|
||||||
|
}
|
||||||
if req.CPUs > 0 {
|
if req.CPUs > 0 {
|
||||||
jobDesc.MinimumCpus = slurm.Ptr(req.CPUs)
|
jobDesc.MinimumCpus = slurm.Ptr(req.CPUs)
|
||||||
}
|
}
|
||||||
@@ -40,17 +44,41 @@ func (s *JobService) SubmitJob(ctx context.Context, req *model.SubmitJobRequest)
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
jobDesc.Environment = slurm.StringArray{
|
||||||
|
"PATH=/usr/local/bin:/usr/bin:/bin",
|
||||||
|
"HOME=/root",
|
||||||
|
}
|
||||||
|
|
||||||
submitReq := &slurm.JobSubmitReq{
|
submitReq := &slurm.JobSubmitReq{
|
||||||
Script: &script,
|
Script: &script,
|
||||||
Job: jobDesc,
|
Job: jobDesc,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
s.logger.Debug("slurm API request",
|
||||||
|
zap.String("operation", "SubmitJob"),
|
||||||
|
zap.Any("body", submitReq),
|
||||||
|
)
|
||||||
|
|
||||||
|
start := time.Now()
|
||||||
result, _, err := s.client.Jobs.SubmitJob(ctx, submitReq)
|
result, _, err := s.client.Jobs.SubmitJob(ctx, submitReq)
|
||||||
|
took := time.Since(start)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
s.logger.Debug("slurm API error response",
|
||||||
|
zap.String("operation", "SubmitJob"),
|
||||||
|
zap.Duration("took", took),
|
||||||
|
zap.Error(err),
|
||||||
|
)
|
||||||
s.logger.Error("failed to submit job", zap.Error(err), zap.String("operation", "submit"))
|
s.logger.Error("failed to submit job", zap.Error(err), zap.String("operation", "submit"))
|
||||||
return nil, fmt.Errorf("submit job: %w", err)
|
return nil, fmt.Errorf("submit job: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
s.logger.Debug("slurm API response",
|
||||||
|
zap.String("operation", "SubmitJob"),
|
||||||
|
zap.Duration("took", took),
|
||||||
|
zap.Any("body", result),
|
||||||
|
)
|
||||||
|
|
||||||
resp := &model.JobResponse{}
|
resp := &model.JobResponse{}
|
||||||
if result.Result != nil && result.Result.JobID != nil {
|
if result.Result != nil && result.Result.JobID != nil {
|
||||||
resp.JobID = *result.Result.JobID
|
resp.JobID = *result.Result.JobID
|
||||||
@@ -62,44 +90,173 @@ func (s *JobService) SubmitJob(ctx context.Context, req *model.SubmitJobRequest)
|
|||||||
return resp, nil
|
return resp, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetJobs lists all current jobs from Slurm.
|
// GetJobs lists all current jobs from Slurm with in-memory pagination.
|
||||||
func (s *JobService) GetJobs(ctx context.Context) ([]model.JobResponse, error) {
|
func (s *JobService) GetJobs(ctx context.Context, query *model.JobListQuery) (*model.JobListResponse, error) {
|
||||||
|
s.logger.Debug("slurm API request",
|
||||||
|
zap.String("operation", "GetJobs"),
|
||||||
|
)
|
||||||
|
|
||||||
|
start := time.Now()
|
||||||
result, _, err := s.client.Jobs.GetJobs(ctx, nil)
|
result, _, err := s.client.Jobs.GetJobs(ctx, nil)
|
||||||
|
took := time.Since(start)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
s.logger.Debug("slurm API error response",
|
||||||
|
zap.String("operation", "GetJobs"),
|
||||||
|
zap.Duration("took", took),
|
||||||
|
zap.Error(err),
|
||||||
|
)
|
||||||
s.logger.Error("failed to get jobs", zap.Error(err), zap.String("operation", "get_jobs"))
|
s.logger.Error("failed to get jobs", zap.Error(err), zap.String("operation", "get_jobs"))
|
||||||
return nil, fmt.Errorf("get jobs: %w", err)
|
return nil, fmt.Errorf("get jobs: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
jobs := make([]model.JobResponse, 0, len(result.Jobs))
|
s.logger.Debug("slurm API response",
|
||||||
|
zap.String("operation", "GetJobs"),
|
||||||
|
zap.Duration("took", took),
|
||||||
|
zap.Int("job_count", len(result.Jobs)),
|
||||||
|
zap.Any("body", result),
|
||||||
|
)
|
||||||
|
|
||||||
|
allJobs := make([]model.JobResponse, 0, len(result.Jobs))
|
||||||
for i := range result.Jobs {
|
for i := range result.Jobs {
|
||||||
jobs = append(jobs, mapJobInfo(&result.Jobs[i]))
|
allJobs = append(allJobs, mapJobInfo(&result.Jobs[i]))
|
||||||
}
|
}
|
||||||
return jobs, nil
|
|
||||||
|
total := len(allJobs)
|
||||||
|
page := query.Page
|
||||||
|
pageSize := query.PageSize
|
||||||
|
if page < 1 {
|
||||||
|
page = 1
|
||||||
|
}
|
||||||
|
if pageSize < 1 {
|
||||||
|
pageSize = 20
|
||||||
|
}
|
||||||
|
|
||||||
|
startIdx := (page - 1) * pageSize
|
||||||
|
end := startIdx + pageSize
|
||||||
|
if startIdx > total {
|
||||||
|
startIdx = total
|
||||||
|
}
|
||||||
|
if end > total {
|
||||||
|
end = total
|
||||||
|
}
|
||||||
|
|
||||||
|
return &model.JobListResponse{
|
||||||
|
Jobs: allJobs[startIdx:end],
|
||||||
|
Total: total,
|
||||||
|
Page: page,
|
||||||
|
PageSize: pageSize,
|
||||||
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetJob retrieves a single job by ID.
|
// GetJob retrieves a single job by ID. If the job is not found in the active
|
||||||
|
// queue (404 or empty result), it falls back to querying SlurmDBD history.
|
||||||
func (s *JobService) GetJob(ctx context.Context, jobID string) (*model.JobResponse, error) {
|
func (s *JobService) GetJob(ctx context.Context, jobID string) (*model.JobResponse, error) {
|
||||||
|
s.logger.Debug("slurm API request",
|
||||||
|
zap.String("operation", "GetJob"),
|
||||||
|
zap.String("job_id", jobID),
|
||||||
|
)
|
||||||
|
|
||||||
|
start := time.Now()
|
||||||
result, _, err := s.client.Jobs.GetJob(ctx, jobID, nil)
|
result, _, err := s.client.Jobs.GetJob(ctx, jobID, nil)
|
||||||
|
took := time.Since(start)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
if slurm.IsNotFound(err) {
|
||||||
|
s.logger.Debug("job not in active queue, querying history",
|
||||||
|
zap.String("job_id", jobID),
|
||||||
|
)
|
||||||
|
return s.getJobFromHistory(ctx, jobID)
|
||||||
|
}
|
||||||
|
s.logger.Debug("slurm API error response",
|
||||||
|
zap.String("operation", "GetJob"),
|
||||||
|
zap.String("job_id", jobID),
|
||||||
|
zap.Duration("took", took),
|
||||||
|
zap.Error(err),
|
||||||
|
)
|
||||||
s.logger.Error("failed to get job", zap.Error(err), zap.String("job_id", jobID), zap.String("operation", "get_job"))
|
s.logger.Error("failed to get job", zap.Error(err), zap.String("job_id", jobID), zap.String("operation", "get_job"))
|
||||||
return nil, fmt.Errorf("get job %s: %w", jobID, err)
|
return nil, fmt.Errorf("get job %s: %w", jobID, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
s.logger.Debug("slurm API response",
|
||||||
|
zap.String("operation", "GetJob"),
|
||||||
|
zap.String("job_id", jobID),
|
||||||
|
zap.Duration("took", took),
|
||||||
|
zap.Any("body", result),
|
||||||
|
)
|
||||||
|
|
||||||
if len(result.Jobs) == 0 {
|
if len(result.Jobs) == 0 {
|
||||||
return nil, nil
|
s.logger.Debug("empty jobs response, querying history",
|
||||||
|
zap.String("job_id", jobID),
|
||||||
|
)
|
||||||
|
return s.getJobFromHistory(ctx, jobID)
|
||||||
}
|
}
|
||||||
|
|
||||||
resp := mapJobInfo(&result.Jobs[0])
|
resp := mapJobInfo(&result.Jobs[0])
|
||||||
return &resp, nil
|
return &resp, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *JobService) getJobFromHistory(ctx context.Context, jobID string) (*model.JobResponse, error) {
|
||||||
|
start := time.Now()
|
||||||
|
result, _, err := s.client.SlurmdbJobs.GetJob(ctx, jobID)
|
||||||
|
took := time.Since(start)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
s.logger.Debug("slurmdb API error response",
|
||||||
|
zap.String("operation", "getJobFromHistory"),
|
||||||
|
zap.String("job_id", jobID),
|
||||||
|
zap.Duration("took", took),
|
||||||
|
zap.Error(err),
|
||||||
|
)
|
||||||
|
if slurm.IsNotFound(err) {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
return nil, fmt.Errorf("get job history %s: %w", jobID, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
s.logger.Debug("slurmdb API response",
|
||||||
|
zap.String("operation", "getJobFromHistory"),
|
||||||
|
zap.String("job_id", jobID),
|
||||||
|
zap.Duration("took", took),
|
||||||
|
zap.Any("body", result),
|
||||||
|
)
|
||||||
|
|
||||||
|
if len(result.Jobs) == 0 {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
resp := mapSlurmdbJob(&result.Jobs[0])
|
||||||
|
return &resp, nil
|
||||||
|
}
|
||||||
|
|
||||||
// CancelJob cancels a job by ID.
|
// CancelJob cancels a job by ID.
|
||||||
func (s *JobService) CancelJob(ctx context.Context, jobID string) error {
|
func (s *JobService) CancelJob(ctx context.Context, jobID string) error {
|
||||||
_, _, err := s.client.Jobs.DeleteJob(ctx, jobID, nil)
|
s.logger.Debug("slurm API request",
|
||||||
|
zap.String("operation", "CancelJob"),
|
||||||
|
zap.String("job_id", jobID),
|
||||||
|
)
|
||||||
|
|
||||||
|
start := time.Now()
|
||||||
|
result, _, err := s.client.Jobs.DeleteJob(ctx, jobID, nil)
|
||||||
|
took := time.Since(start)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
s.logger.Debug("slurm API error response",
|
||||||
|
zap.String("operation", "CancelJob"),
|
||||||
|
zap.String("job_id", jobID),
|
||||||
|
zap.Duration("took", took),
|
||||||
|
zap.Error(err),
|
||||||
|
)
|
||||||
s.logger.Error("failed to cancel job", zap.Error(err), zap.String("job_id", jobID), zap.String("operation", "cancel"))
|
s.logger.Error("failed to cancel job", zap.Error(err), zap.String("job_id", jobID), zap.String("operation", "cancel"))
|
||||||
return fmt.Errorf("cancel job %s: %w", jobID, err)
|
return fmt.Errorf("cancel job %s: %w", jobID, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
s.logger.Debug("slurm API response",
|
||||||
|
zap.String("operation", "CancelJob"),
|
||||||
|
zap.String("job_id", jobID),
|
||||||
|
zap.Duration("took", took),
|
||||||
|
zap.Any("body", result),
|
||||||
|
)
|
||||||
s.logger.Info("job cancelled", zap.String("job_id", jobID))
|
s.logger.Info("job cancelled", zap.String("job_id", jobID))
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@@ -128,13 +285,60 @@ func (s *JobService) GetJobHistory(ctx context.Context, query *model.JobHistoryQ
|
|||||||
if query.EndTime != "" {
|
if query.EndTime != "" {
|
||||||
opts.EndTime = strToPtr(query.EndTime)
|
opts.EndTime = strToPtr(query.EndTime)
|
||||||
}
|
}
|
||||||
|
if query.SubmitTime != "" {
|
||||||
|
opts.SubmitTime = strToPtr(query.SubmitTime)
|
||||||
|
}
|
||||||
|
if query.Cluster != "" {
|
||||||
|
opts.Cluster = strToPtr(query.Cluster)
|
||||||
|
}
|
||||||
|
if query.Qos != "" {
|
||||||
|
opts.Qos = strToPtr(query.Qos)
|
||||||
|
}
|
||||||
|
if query.Constraints != "" {
|
||||||
|
opts.Constraints = strToPtr(query.Constraints)
|
||||||
|
}
|
||||||
|
if query.ExitCode != "" {
|
||||||
|
opts.ExitCode = strToPtr(query.ExitCode)
|
||||||
|
}
|
||||||
|
if query.Node != "" {
|
||||||
|
opts.Node = strToPtr(query.Node)
|
||||||
|
}
|
||||||
|
if query.Reservation != "" {
|
||||||
|
opts.Reservation = strToPtr(query.Reservation)
|
||||||
|
}
|
||||||
|
if query.Groups != "" {
|
||||||
|
opts.Groups = strToPtr(query.Groups)
|
||||||
|
}
|
||||||
|
if query.Wckey != "" {
|
||||||
|
opts.Wckey = strToPtr(query.Wckey)
|
||||||
|
}
|
||||||
|
|
||||||
|
s.logger.Debug("slurm API request",
|
||||||
|
zap.String("operation", "GetJobHistory"),
|
||||||
|
zap.Any("body", opts),
|
||||||
|
)
|
||||||
|
|
||||||
|
start := time.Now()
|
||||||
result, _, err := s.client.SlurmdbJobs.GetJobs(ctx, opts)
|
result, _, err := s.client.SlurmdbJobs.GetJobs(ctx, opts)
|
||||||
|
took := time.Since(start)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
s.logger.Debug("slurm API error response",
|
||||||
|
zap.String("operation", "GetJobHistory"),
|
||||||
|
zap.Duration("took", took),
|
||||||
|
zap.Error(err),
|
||||||
|
)
|
||||||
s.logger.Error("failed to get job history", zap.Error(err), zap.String("operation", "get_job_history"))
|
s.logger.Error("failed to get job history", zap.Error(err), zap.String("operation", "get_job_history"))
|
||||||
return nil, fmt.Errorf("get job history: %w", err)
|
return nil, fmt.Errorf("get job history: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
s.logger.Debug("slurm API response",
|
||||||
|
zap.String("operation", "GetJobHistory"),
|
||||||
|
zap.Duration("took", took),
|
||||||
|
zap.Int("job_count", len(result.Jobs)),
|
||||||
|
zap.Any("body", result),
|
||||||
|
)
|
||||||
|
|
||||||
allJobs := make([]model.JobResponse, 0, len(result.Jobs))
|
allJobs := make([]model.JobResponse, 0, len(result.Jobs))
|
||||||
for i := range result.Jobs {
|
for i := range result.Jobs {
|
||||||
allJobs = append(allJobs, mapSlurmdbJob(&result.Jobs[i]))
|
allJobs = append(allJobs, mapSlurmdbJob(&result.Jobs[i]))
|
||||||
@@ -150,17 +354,17 @@ func (s *JobService) GetJobHistory(ctx context.Context, query *model.JobHistoryQ
|
|||||||
pageSize = 20
|
pageSize = 20
|
||||||
}
|
}
|
||||||
|
|
||||||
start := (page - 1) * pageSize
|
startIdx := (page - 1) * pageSize
|
||||||
end := start + pageSize
|
end := startIdx + pageSize
|
||||||
if start > total {
|
if startIdx > total {
|
||||||
start = total
|
startIdx = total
|
||||||
}
|
}
|
||||||
if end > total {
|
if end > total {
|
||||||
end = total
|
end = total
|
||||||
}
|
}
|
||||||
|
|
||||||
return &model.JobListResponse{
|
return &model.JobListResponse{
|
||||||
Jobs: allJobs[start:end],
|
Jobs: allJobs[startIdx:end],
|
||||||
Total: total,
|
Total: total,
|
||||||
Page: page,
|
Page: page,
|
||||||
PageSize: pageSize,
|
PageSize: pageSize,
|
||||||
@@ -181,6 +385,14 @@ func strToPtrOrNil(s string) *string {
|
|||||||
return &s
|
return &s
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func mapUint32NoValToInt32(v *slurm.Uint32NoVal) *int32 {
|
||||||
|
if v != nil && v.Number != nil {
|
||||||
|
n := int32(*v.Number)
|
||||||
|
return &n
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// mapJobInfo maps SDK JobInfo to API JobResponse.
|
// mapJobInfo maps SDK JobInfo to API JobResponse.
|
||||||
func mapJobInfo(ji *slurm.JobInfo) model.JobResponse {
|
func mapJobInfo(ji *slurm.JobInfo) model.JobResponse {
|
||||||
resp := model.JobResponse{}
|
resp := model.JobResponse{}
|
||||||
@@ -194,6 +406,17 @@ func mapJobInfo(ji *slurm.JobInfo) model.JobResponse {
|
|||||||
if ji.Partition != nil {
|
if ji.Partition != nil {
|
||||||
resp.Partition = *ji.Partition
|
resp.Partition = *ji.Partition
|
||||||
}
|
}
|
||||||
|
resp.Account = derefStr(ji.Account)
|
||||||
|
resp.User = derefStr(ji.UserName)
|
||||||
|
resp.Cluster = derefStr(ji.Cluster)
|
||||||
|
resp.QOS = derefStr(ji.Qos)
|
||||||
|
resp.Priority = mapUint32NoValToInt32(ji.Priority)
|
||||||
|
resp.TimeLimit = uint32NoValString(ji.TimeLimit)
|
||||||
|
resp.StateReason = derefStr(ji.StateReason)
|
||||||
|
resp.Cpus = mapUint32NoValToInt32(ji.Cpus)
|
||||||
|
resp.Tasks = mapUint32NoValToInt32(ji.Tasks)
|
||||||
|
resp.NodeCount = mapUint32NoValToInt32(ji.NodeCount)
|
||||||
|
resp.BatchHost = derefStr(ji.BatchHost)
|
||||||
if ji.SubmitTime != nil && ji.SubmitTime.Number != nil {
|
if ji.SubmitTime != nil && ji.SubmitTime.Number != nil {
|
||||||
resp.SubmitTime = ji.SubmitTime.Number
|
resp.SubmitTime = ji.SubmitTime.Number
|
||||||
}
|
}
|
||||||
@@ -210,6 +433,13 @@ func mapJobInfo(ji *slurm.JobInfo) model.JobResponse {
|
|||||||
if ji.Nodes != nil {
|
if ji.Nodes != nil {
|
||||||
resp.Nodes = *ji.Nodes
|
resp.Nodes = *ji.Nodes
|
||||||
}
|
}
|
||||||
|
resp.StdOut = derefStr(ji.StandardOutput)
|
||||||
|
resp.StdErr = derefStr(ji.StandardError)
|
||||||
|
resp.StdIn = derefStr(ji.StandardInput)
|
||||||
|
resp.WorkDir = derefStr(ji.CurrentWorkingDirectory)
|
||||||
|
resp.Command = derefStr(ji.Command)
|
||||||
|
resp.ArrayJobID = mapUint32NoValToInt32(ji.ArrayJobID)
|
||||||
|
resp.ArrayTaskID = mapUint32NoValToInt32(ji.ArrayTaskID)
|
||||||
return resp
|
return resp
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -224,11 +454,20 @@ func mapSlurmdbJob(j *slurm.Job) model.JobResponse {
|
|||||||
}
|
}
|
||||||
if j.State != nil {
|
if j.State != nil {
|
||||||
resp.State = j.State.Current
|
resp.State = j.State.Current
|
||||||
|
resp.StateReason = derefStr(j.State.Reason)
|
||||||
}
|
}
|
||||||
if j.Partition != nil {
|
if j.Partition != nil {
|
||||||
resp.Partition = *j.Partition
|
resp.Partition = *j.Partition
|
||||||
}
|
}
|
||||||
|
resp.Account = derefStr(j.Account)
|
||||||
|
if j.User != nil {
|
||||||
|
resp.User = *j.User
|
||||||
|
}
|
||||||
|
resp.Cluster = derefStr(j.Cluster)
|
||||||
|
resp.QOS = derefStr(j.Qos)
|
||||||
|
resp.Priority = mapUint32NoValToInt32(j.Priority)
|
||||||
if j.Time != nil {
|
if j.Time != nil {
|
||||||
|
resp.TimeLimit = uint32NoValString(j.Time.Limit)
|
||||||
if j.Time.Submission != nil {
|
if j.Time.Submission != nil {
|
||||||
resp.SubmitTime = j.Time.Submission
|
resp.SubmitTime = j.Time.Submission
|
||||||
}
|
}
|
||||||
@@ -239,8 +478,19 @@ func mapSlurmdbJob(j *slurm.Job) model.JobResponse {
|
|||||||
resp.EndTime = j.Time.End
|
resp.EndTime = j.Time.End
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
if j.ExitCode != nil && j.ExitCode.ReturnCode != nil && j.ExitCode.ReturnCode.Number != nil {
|
||||||
|
code := int32(*j.ExitCode.ReturnCode.Number)
|
||||||
|
resp.ExitCode = &code
|
||||||
|
}
|
||||||
if j.Nodes != nil {
|
if j.Nodes != nil {
|
||||||
resp.Nodes = *j.Nodes
|
resp.Nodes = *j.Nodes
|
||||||
}
|
}
|
||||||
|
if j.Required != nil {
|
||||||
|
resp.Cpus = j.Required.CPUs
|
||||||
|
}
|
||||||
|
if j.AllocationNodes != nil {
|
||||||
|
resp.NodeCount = j.AllocationNodes
|
||||||
|
}
|
||||||
|
resp.WorkDir = derefStr(j.WorkingDirectory)
|
||||||
return resp
|
return resp
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -5,6 +5,7 @@ import (
|
|||||||
"encoding/json"
|
"encoding/json"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"gcy_hpc_server/internal/model"
|
"gcy_hpc_server/internal/model"
|
||||||
@@ -148,14 +149,17 @@ func TestGetJobs(t *testing.T) {
|
|||||||
defer cleanup()
|
defer cleanup()
|
||||||
|
|
||||||
svc := NewJobService(client, zap.NewNop())
|
svc := NewJobService(client, zap.NewNop())
|
||||||
jobs, err := svc.GetJobs(context.Background())
|
result, err := svc.GetJobs(context.Background(), &model.JobListQuery{Page: 1, PageSize: 20})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("GetJobs: %v", err)
|
t.Fatalf("GetJobs: %v", err)
|
||||||
}
|
}
|
||||||
if len(jobs) != 1 {
|
if result.Total != 1 {
|
||||||
t.Fatalf("expected 1 job, got %d", len(jobs))
|
t.Fatalf("expected total 1, got %d", result.Total)
|
||||||
}
|
}
|
||||||
j := jobs[0]
|
if len(result.Jobs) != 1 {
|
||||||
|
t.Fatalf("expected 1 job, got %d", len(result.Jobs))
|
||||||
|
}
|
||||||
|
j := result.Jobs[0]
|
||||||
if j.JobID != 100 {
|
if j.JobID != 100 {
|
||||||
t.Errorf("expected JobID 100, got %d", j.JobID)
|
t.Errorf("expected JobID 100, got %d", j.JobID)
|
||||||
}
|
}
|
||||||
@@ -174,6 +178,12 @@ func TestGetJobs(t *testing.T) {
|
|||||||
if j.Nodes != "node01" {
|
if j.Nodes != "node01" {
|
||||||
t.Errorf("expected Nodes node01, got %s", j.Nodes)
|
t.Errorf("expected Nodes node01, got %s", j.Nodes)
|
||||||
}
|
}
|
||||||
|
if result.Page != 1 {
|
||||||
|
t.Errorf("expected Page 1, got %d", result.Page)
|
||||||
|
}
|
||||||
|
if result.PageSize != 20 {
|
||||||
|
t.Errorf("expected PageSize 20, got %d", result.PageSize)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestGetJob(t *testing.T) {
|
func TestGetJob(t *testing.T) {
|
||||||
@@ -506,13 +516,13 @@ func TestJobService_SubmitJob_SuccessLog(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
entries := recorded.All()
|
entries := recorded.All()
|
||||||
if len(entries) != 1 {
|
if len(entries) != 3 {
|
||||||
t.Fatalf("expected 1 log entry, got %d", len(entries))
|
t.Fatalf("expected 3 log entries, got %d", len(entries))
|
||||||
}
|
}
|
||||||
if entries[0].Level != zapcore.InfoLevel {
|
if entries[2].Level != zapcore.InfoLevel {
|
||||||
t.Errorf("expected InfoLevel, got %v", entries[0].Level)
|
t.Errorf("expected InfoLevel, got %v", entries[2].Level)
|
||||||
}
|
}
|
||||||
fields := entries[0].ContextMap()
|
fields := entries[2].ContextMap()
|
||||||
if fields["job_name"] != "log-test-job" {
|
if fields["job_name"] != "log-test-job" {
|
||||||
t.Errorf("expected job_name=log-test-job, got %v", fields["job_name"])
|
t.Errorf("expected job_name=log-test-job, got %v", fields["job_name"])
|
||||||
}
|
}
|
||||||
@@ -539,13 +549,13 @@ func TestJobService_SubmitJob_ErrorLog(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
entries := recorded.All()
|
entries := recorded.All()
|
||||||
if len(entries) != 1 {
|
if len(entries) != 3 {
|
||||||
t.Fatalf("expected 1 log entry, got %d", len(entries))
|
t.Fatalf("expected 3 log entries, got %d", len(entries))
|
||||||
}
|
}
|
||||||
if entries[0].Level != zapcore.ErrorLevel {
|
if entries[2].Level != zapcore.ErrorLevel {
|
||||||
t.Errorf("expected ErrorLevel, got %v", entries[0].Level)
|
t.Errorf("expected ErrorLevel, got %v", entries[2].Level)
|
||||||
}
|
}
|
||||||
fields := entries[0].ContextMap()
|
fields := entries[2].ContextMap()
|
||||||
if fields["operation"] != "submit" {
|
if fields["operation"] != "submit" {
|
||||||
t.Errorf("expected operation=submit, got %v", fields["operation"])
|
t.Errorf("expected operation=submit, got %v", fields["operation"])
|
||||||
}
|
}
|
||||||
@@ -568,13 +578,13 @@ func TestJobService_CancelJob_SuccessLog(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
entries := recorded.All()
|
entries := recorded.All()
|
||||||
if len(entries) != 1 {
|
if len(entries) != 3 {
|
||||||
t.Fatalf("expected 1 log entry, got %d", len(entries))
|
t.Fatalf("expected 3 log entries, got %d", len(entries))
|
||||||
}
|
}
|
||||||
if entries[0].Level != zapcore.InfoLevel {
|
if entries[2].Level != zapcore.InfoLevel {
|
||||||
t.Errorf("expected InfoLevel, got %v", entries[0].Level)
|
t.Errorf("expected InfoLevel, got %v", entries[2].Level)
|
||||||
}
|
}
|
||||||
fields := entries[0].ContextMap()
|
fields := entries[2].ContextMap()
|
||||||
if fields["job_id"] != "555" {
|
if fields["job_id"] != "555" {
|
||||||
t.Errorf("expected job_id=555, got %v", fields["job_id"])
|
t.Errorf("expected job_id=555, got %v", fields["job_id"])
|
||||||
}
|
}
|
||||||
@@ -594,13 +604,13 @@ func TestJobService_CancelJob_ErrorLog(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
entries := recorded.All()
|
entries := recorded.All()
|
||||||
if len(entries) != 1 {
|
if len(entries) != 3 {
|
||||||
t.Fatalf("expected 1 log entry, got %d", len(entries))
|
t.Fatalf("expected 3 log entries, got %d", len(entries))
|
||||||
}
|
}
|
||||||
if entries[0].Level != zapcore.ErrorLevel {
|
if entries[2].Level != zapcore.ErrorLevel {
|
||||||
t.Errorf("expected ErrorLevel, got %v", entries[0].Level)
|
t.Errorf("expected ErrorLevel, got %v", entries[2].Level)
|
||||||
}
|
}
|
||||||
fields := entries[0].ContextMap()
|
fields := entries[2].ContextMap()
|
||||||
if fields["operation"] != "cancel" {
|
if fields["operation"] != "cancel" {
|
||||||
t.Errorf("expected operation=cancel, got %v", fields["operation"])
|
t.Errorf("expected operation=cancel, got %v", fields["operation"])
|
||||||
}
|
}
|
||||||
@@ -620,19 +630,19 @@ func TestJobService_GetJobs_ErrorLog(t *testing.T) {
|
|||||||
defer srv.Close()
|
defer srv.Close()
|
||||||
|
|
||||||
svc, recorded := newJobServiceWithObserver(srv)
|
svc, recorded := newJobServiceWithObserver(srv)
|
||||||
_, err := svc.GetJobs(context.Background())
|
_, err := svc.GetJobs(context.Background(), &model.JobListQuery{Page: 1, PageSize: 20})
|
||||||
if err == nil {
|
if err == nil {
|
||||||
t.Fatal("expected error, got nil")
|
t.Fatal("expected error, got nil")
|
||||||
}
|
}
|
||||||
|
|
||||||
entries := recorded.All()
|
entries := recorded.All()
|
||||||
if len(entries) != 1 {
|
if len(entries) != 3 {
|
||||||
t.Fatalf("expected 1 log entry, got %d", len(entries))
|
t.Fatalf("expected 3 log entries, got %d", len(entries))
|
||||||
}
|
}
|
||||||
if entries[0].Level != zapcore.ErrorLevel {
|
if entries[2].Level != zapcore.ErrorLevel {
|
||||||
t.Errorf("expected ErrorLevel, got %v", entries[0].Level)
|
t.Errorf("expected ErrorLevel, got %v", entries[2].Level)
|
||||||
}
|
}
|
||||||
fields := entries[0].ContextMap()
|
fields := entries[2].ContextMap()
|
||||||
if fields["operation"] != "get_jobs" {
|
if fields["operation"] != "get_jobs" {
|
||||||
t.Errorf("expected operation=get_jobs, got %v", fields["operation"])
|
t.Errorf("expected operation=get_jobs, got %v", fields["operation"])
|
||||||
}
|
}
|
||||||
@@ -655,13 +665,13 @@ func TestJobService_GetJob_ErrorLog(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
entries := recorded.All()
|
entries := recorded.All()
|
||||||
if len(entries) != 1 {
|
if len(entries) != 3 {
|
||||||
t.Fatalf("expected 1 log entry, got %d", len(entries))
|
t.Fatalf("expected 3 log entries, got %d", len(entries))
|
||||||
}
|
}
|
||||||
if entries[0].Level != zapcore.ErrorLevel {
|
if entries[2].Level != zapcore.ErrorLevel {
|
||||||
t.Errorf("expected ErrorLevel, got %v", entries[0].Level)
|
t.Errorf("expected ErrorLevel, got %v", entries[2].Level)
|
||||||
}
|
}
|
||||||
fields := entries[0].ContextMap()
|
fields := entries[2].ContextMap()
|
||||||
if fields["operation"] != "get_job" {
|
if fields["operation"] != "get_job" {
|
||||||
t.Errorf("expected operation=get_job, got %v", fields["operation"])
|
t.Errorf("expected operation=get_job, got %v", fields["operation"])
|
||||||
}
|
}
|
||||||
@@ -687,13 +697,13 @@ func TestJobService_GetJobHistory_ErrorLog(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
entries := recorded.All()
|
entries := recorded.All()
|
||||||
if len(entries) != 1 {
|
if len(entries) != 3 {
|
||||||
t.Fatalf("expected 1 log entry, got %d", len(entries))
|
t.Fatalf("expected 3 log entries, got %d", len(entries))
|
||||||
}
|
}
|
||||||
if entries[0].Level != zapcore.ErrorLevel {
|
if entries[2].Level != zapcore.ErrorLevel {
|
||||||
t.Errorf("expected ErrorLevel, got %v", entries[0].Level)
|
t.Errorf("expected ErrorLevel, got %v", entries[2].Level)
|
||||||
}
|
}
|
||||||
fields := entries[0].ContextMap()
|
fields := entries[2].ContextMap()
|
||||||
if fields["operation"] != "get_job_history" {
|
if fields["operation"] != "get_job_history" {
|
||||||
t.Errorf("expected operation=get_job_history, got %v", fields["operation"])
|
t.Errorf("expected operation=get_job_history, got %v", fields["operation"])
|
||||||
}
|
}
|
||||||
@@ -701,3 +711,157 @@ func TestJobService_GetJobHistory_ErrorLog(t *testing.T) {
|
|||||||
t.Error("expected error field in log entry")
|
t.Error("expected error field in log entry")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
// Fallback to SlurmDBD history tests
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
func TestGetJob_FallbackToHistory_Found(t *testing.T) {
|
||||||
|
jobID := int32(198)
|
||||||
|
name := "hist-job"
|
||||||
|
ts := int64(1700000000)
|
||||||
|
|
||||||
|
client, cleanup := mockJobServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
|
||||||
|
switch r.URL.Path {
|
||||||
|
case "/slurm/v0.0.40/job/198":
|
||||||
|
w.WriteHeader(http.StatusNotFound)
|
||||||
|
json.NewEncoder(w).Encode(map[string]interface{}{
|
||||||
|
"errors": []map[string]interface{}{
|
||||||
|
{
|
||||||
|
"description": "Unable to query JobId=198",
|
||||||
|
"error_number": float64(2017),
|
||||||
|
"error": "Invalid job id specified",
|
||||||
|
"source": "_handle_job_get",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"jobs": []interface{}{},
|
||||||
|
})
|
||||||
|
case "/slurmdb/v0.0.40/job/198":
|
||||||
|
resp := slurm.OpenapiSlurmdbdJobsResp{
|
||||||
|
Jobs: slurm.JobList{
|
||||||
|
{
|
||||||
|
JobID: &jobID,
|
||||||
|
Name: &name,
|
||||||
|
State: &slurm.JobState{Current: []string{"COMPLETED"}},
|
||||||
|
Time: &slurm.JobTime{Submission: &ts, Start: &ts, End: &ts},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
json.NewEncoder(w).Encode(resp)
|
||||||
|
default:
|
||||||
|
w.WriteHeader(http.StatusNotFound)
|
||||||
|
}
|
||||||
|
}))
|
||||||
|
defer cleanup()
|
||||||
|
|
||||||
|
svc := NewJobService(client, zap.NewNop())
|
||||||
|
job, err := svc.GetJob(context.Background(), "198")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("GetJob: %v", err)
|
||||||
|
}
|
||||||
|
if job == nil {
|
||||||
|
t.Fatal("expected job, got nil")
|
||||||
|
}
|
||||||
|
if job.JobID != 198 {
|
||||||
|
t.Errorf("expected JobID 198, got %d", job.JobID)
|
||||||
|
}
|
||||||
|
if job.Name != "hist-job" {
|
||||||
|
t.Errorf("expected Name hist-job, got %s", job.Name)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetJob_FallbackToHistory_NotFound(t *testing.T) {
|
||||||
|
client, cleanup := mockJobServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
w.WriteHeader(http.StatusNotFound)
|
||||||
|
}))
|
||||||
|
defer cleanup()
|
||||||
|
|
||||||
|
svc := NewJobService(client, zap.NewNop())
|
||||||
|
job, err := svc.GetJob(context.Background(), "999")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("GetJob: %v", err)
|
||||||
|
}
|
||||||
|
if job != nil {
|
||||||
|
t.Errorf("expected nil, got %+v", job)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetJob_FallbackToHistory_HistoryError(t *testing.T) {
|
||||||
|
client, cleanup := mockJobServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
|
||||||
|
switch r.URL.Path {
|
||||||
|
case "/slurm/v0.0.40/job/500":
|
||||||
|
w.WriteHeader(http.StatusNotFound)
|
||||||
|
json.NewEncoder(w).Encode(map[string]interface{}{
|
||||||
|
"errors": []map[string]interface{}{
|
||||||
|
{
|
||||||
|
"description": "Unable to query JobId=500",
|
||||||
|
"error_number": float64(2017),
|
||||||
|
"error": "Invalid job id specified",
|
||||||
|
"source": "_handle_job_get",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"jobs": []interface{}{},
|
||||||
|
})
|
||||||
|
case "/slurmdb/v0.0.40/job/500":
|
||||||
|
w.WriteHeader(http.StatusInternalServerError)
|
||||||
|
w.Write([]byte(`{"errors":[{"error":"db error"}]}`))
|
||||||
|
default:
|
||||||
|
w.WriteHeader(http.StatusNotFound)
|
||||||
|
}
|
||||||
|
}))
|
||||||
|
defer cleanup()
|
||||||
|
|
||||||
|
svc := NewJobService(client, zap.NewNop())
|
||||||
|
job, err := svc.GetJob(context.Background(), "500")
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected error, got nil")
|
||||||
|
}
|
||||||
|
if job != nil {
|
||||||
|
t.Errorf("expected nil job, got %+v", job)
|
||||||
|
}
|
||||||
|
if !strings.Contains(err.Error(), "get job history") {
|
||||||
|
t.Errorf("expected error to contain 'get job history', got %s", err.Error())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetJob_FallbackToHistory_EmptyHistory(t *testing.T) {
|
||||||
|
client, cleanup := mockJobServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
|
||||||
|
switch r.URL.Path {
|
||||||
|
case "/slurm/v0.0.40/job/777":
|
||||||
|
w.WriteHeader(http.StatusNotFound)
|
||||||
|
json.NewEncoder(w).Encode(map[string]interface{}{
|
||||||
|
"errors": []map[string]interface{}{
|
||||||
|
{
|
||||||
|
"description": "Unable to query JobId=777",
|
||||||
|
"error_number": float64(2017),
|
||||||
|
"error": "Invalid job id specified",
|
||||||
|
"source": "_handle_job_get",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"jobs": []interface{}{},
|
||||||
|
})
|
||||||
|
case "/slurmdb/v0.0.40/job/777":
|
||||||
|
resp := slurm.OpenapiSlurmdbdJobsResp{Jobs: slurm.JobList{}}
|
||||||
|
json.NewEncoder(w).Encode(resp)
|
||||||
|
default:
|
||||||
|
w.WriteHeader(http.StatusNotFound)
|
||||||
|
}
|
||||||
|
}))
|
||||||
|
defer cleanup()
|
||||||
|
|
||||||
|
svc := NewJobService(client, zap.NewNop())
|
||||||
|
job, err := svc.GetJob(context.Background(), "777")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("GetJob: %v", err)
|
||||||
|
}
|
||||||
|
if job != nil {
|
||||||
|
t.Errorf("expected nil, got %+v", job)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
112
internal/service/script_utils.go
Normal file
112
internal/service/script_utils.go
Normal file
@@ -0,0 +1,112 @@
|
|||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"math/rand"
|
||||||
|
"regexp"
|
||||||
|
"sort"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"gcy_hpc_server/internal/model"
|
||||||
|
)
|
||||||
|
|
||||||
|
var paramNameRegex = regexp.MustCompile(`^[A-Za-z_][A-Za-z0-9_]*$`)
|
||||||
|
|
||||||
|
// ValidateParams checks that all required parameters are present and values match their types.
|
||||||
|
// Parameters not in the schema are silently ignored.
|
||||||
|
func ValidateParams(params []model.ParameterSchema, values map[string]string) error {
|
||||||
|
var errs []string
|
||||||
|
|
||||||
|
for _, p := range params {
|
||||||
|
if !paramNameRegex.MatchString(p.Name) {
|
||||||
|
errs = append(errs, fmt.Sprintf("invalid parameter name %q: must match ^[A-Za-z_][A-Za-z0-9_]*$", p.Name))
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
val, ok := values[p.Name]
|
||||||
|
|
||||||
|
if p.Required && !ok {
|
||||||
|
errs = append(errs, fmt.Sprintf("required parameter %q is missing", p.Name))
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if !ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
switch p.Type {
|
||||||
|
case model.ParamTypeInteger:
|
||||||
|
if _, err := strconv.Atoi(val); err != nil {
|
||||||
|
errs = append(errs, fmt.Sprintf("parameter %q must be an integer, got %q", p.Name, val))
|
||||||
|
}
|
||||||
|
case model.ParamTypeBoolean:
|
||||||
|
if val != "true" && val != "false" && val != "1" && val != "0" {
|
||||||
|
errs = append(errs, fmt.Sprintf("parameter %q must be a boolean (true/false/1/0), got %q", p.Name, val))
|
||||||
|
}
|
||||||
|
case model.ParamTypeEnum:
|
||||||
|
if len(p.Options) > 0 {
|
||||||
|
found := false
|
||||||
|
for _, opt := range p.Options {
|
||||||
|
if val == opt {
|
||||||
|
found = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !found {
|
||||||
|
errs = append(errs, fmt.Sprintf("parameter %q must be one of %v, got %q", p.Name, p.Options, val))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
case model.ParamTypeFile, model.ParamTypeDirectory:
|
||||||
|
case model.ParamTypeString:
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(errs) > 0 {
|
||||||
|
return fmt.Errorf("parameter validation failed: %s", strings.Join(errs, "; "))
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// RenderScript replaces $PARAM tokens in the template with user-provided values.
|
||||||
|
// Only tokens defined in the schema are replaced. Replacement is done longest-name-first
|
||||||
|
// to avoid partial matches (e.g., $JOB_NAME before $JOB).
|
||||||
|
// All values are shell-escaped using single-quote wrapping.
|
||||||
|
func RenderScript(template string, params []model.ParameterSchema, values map[string]string) string {
|
||||||
|
sorted := make([]model.ParameterSchema, len(params))
|
||||||
|
copy(sorted, params)
|
||||||
|
sort.Slice(sorted, func(i, j int) bool {
|
||||||
|
return len(sorted[i].Name) > len(sorted[j].Name)
|
||||||
|
})
|
||||||
|
|
||||||
|
result := template
|
||||||
|
for _, p := range sorted {
|
||||||
|
val, ok := values[p.Name]
|
||||||
|
if !ok {
|
||||||
|
if p.Default != "" {
|
||||||
|
val = p.Default
|
||||||
|
} else {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
}
|
||||||
|
escaped := "'" + strings.ReplaceAll(val, "'", "'\\''") + "'"
|
||||||
|
result = strings.ReplaceAll(result, "$"+p.Name, escaped)
|
||||||
|
}
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
// SanitizeDirName sanitizes a directory name.
|
||||||
|
func SanitizeDirName(name string) string {
|
||||||
|
replacer := strings.NewReplacer(" ", "_", "/", "_", "\\", "_", ":", "_", "*", "_", "?", "_", "\"", "_", "<", "_", ">", "_", "|", "_")
|
||||||
|
return replacer.Replace(name)
|
||||||
|
}
|
||||||
|
|
||||||
|
// RandomSuffix generates a random suffix of length n.
|
||||||
|
func RandomSuffix(n int) string {
|
||||||
|
const charset = "abcdefghijklmnopqrstuvwxyz0123456789"
|
||||||
|
b := make([]byte, n)
|
||||||
|
for i := range b {
|
||||||
|
b[i] = charset[rand.Intn(len(charset))]
|
||||||
|
}
|
||||||
|
return string(b)
|
||||||
|
}
|
||||||
554
internal/service/task_service.go
Normal file
554
internal/service/task_service.go
Normal file
@@ -0,0 +1,554 @@
|
|||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"sort"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"gcy_hpc_server/internal/model"
|
||||||
|
"gcy_hpc_server/internal/store"
|
||||||
|
|
||||||
|
"go.uber.org/zap"
|
||||||
|
)
|
||||||
|
|
||||||
|
type TaskService struct {
|
||||||
|
taskStore *store.TaskStore
|
||||||
|
appStore *store.ApplicationStore
|
||||||
|
fileStore *store.FileStore // nil ok
|
||||||
|
blobStore *store.BlobStore // nil ok
|
||||||
|
stagingSvc *FileStagingService // nil ok — MinIO unavailable
|
||||||
|
jobSvc *JobService
|
||||||
|
workDirBase string
|
||||||
|
logger *zap.Logger
|
||||||
|
|
||||||
|
// async processing
|
||||||
|
taskCh chan int64 // buffered channel, cap=16
|
||||||
|
cancelFn context.CancelFunc
|
||||||
|
wg sync.WaitGroup
|
||||||
|
mu sync.Mutex // protects taskCh from send-on-closed
|
||||||
|
started bool // prevent double-start
|
||||||
|
stopped bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewTaskService(
|
||||||
|
taskStore *store.TaskStore,
|
||||||
|
appStore *store.ApplicationStore,
|
||||||
|
fileStore *store.FileStore,
|
||||||
|
blobStore *store.BlobStore,
|
||||||
|
stagingSvc *FileStagingService,
|
||||||
|
jobSvc *JobService,
|
||||||
|
workDirBase string,
|
||||||
|
logger *zap.Logger,
|
||||||
|
) *TaskService {
|
||||||
|
return &TaskService{
|
||||||
|
taskStore: taskStore,
|
||||||
|
appStore: appStore,
|
||||||
|
fileStore: fileStore,
|
||||||
|
blobStore: blobStore,
|
||||||
|
stagingSvc: stagingSvc,
|
||||||
|
jobSvc: jobSvc,
|
||||||
|
workDirBase: workDirBase,
|
||||||
|
logger: logger,
|
||||||
|
taskCh: make(chan int64, 16),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *TaskService) CreateTask(ctx context.Context, req *model.CreateTaskRequest) (*model.Task, error) {
|
||||||
|
app, err := s.appStore.GetByID(ctx, req.AppID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("get application: %w", err)
|
||||||
|
}
|
||||||
|
if app == nil {
|
||||||
|
return nil, fmt.Errorf("application %d not found", req.AppID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 2. Validate file limit
|
||||||
|
if len(req.InputFileIDs) > 100 {
|
||||||
|
return nil, fmt.Errorf("input file count %d exceeds limit of 100", len(req.InputFileIDs))
|
||||||
|
}
|
||||||
|
|
||||||
|
// 3. Deduplicate file IDs
|
||||||
|
fileIDs := uniqueInt64s(req.InputFileIDs)
|
||||||
|
|
||||||
|
// 4. Validate file IDs exist
|
||||||
|
if s.fileStore != nil && len(fileIDs) > 0 {
|
||||||
|
files, err := s.fileStore.GetByIDs(ctx, fileIDs)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("validate file ids: %w", err)
|
||||||
|
}
|
||||||
|
found := make(map[int64]bool, len(files))
|
||||||
|
for _, f := range files {
|
||||||
|
found[f.ID] = true
|
||||||
|
}
|
||||||
|
for _, id := range fileIDs {
|
||||||
|
if !found[id] {
|
||||||
|
return nil, fmt.Errorf("file %d not found", id)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 5. Auto-generate task name if empty
|
||||||
|
taskName := req.TaskName
|
||||||
|
if taskName == "" {
|
||||||
|
taskName = SanitizeDirName(app.Name) + "_" + time.Now().Format("20060102_150405")
|
||||||
|
}
|
||||||
|
|
||||||
|
// 6. Marshal values
|
||||||
|
valuesJSON := json.RawMessage(`{}`)
|
||||||
|
if len(req.Values) > 0 {
|
||||||
|
b, err := json.Marshal(req.Values)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("marshal values: %w", err)
|
||||||
|
}
|
||||||
|
valuesJSON = b
|
||||||
|
}
|
||||||
|
|
||||||
|
// 7. Marshal input_file_ids
|
||||||
|
fileIDsJSON := json.RawMessage(`[]`)
|
||||||
|
if len(fileIDs) > 0 {
|
||||||
|
b, err := json.Marshal(fileIDs)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("marshal file ids: %w", err)
|
||||||
|
}
|
||||||
|
fileIDsJSON = b
|
||||||
|
}
|
||||||
|
|
||||||
|
// 8. Create task record
|
||||||
|
task := &model.Task{
|
||||||
|
TaskName: taskName,
|
||||||
|
AppID: app.ID,
|
||||||
|
AppName: app.Name,
|
||||||
|
Status: model.TaskStatusSubmitted,
|
||||||
|
Values: valuesJSON,
|
||||||
|
InputFileIDs: fileIDsJSON,
|
||||||
|
SubmittedAt: time.Now(),
|
||||||
|
}
|
||||||
|
|
||||||
|
taskID, err := s.taskStore.Create(ctx, task)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("create task: %w", err)
|
||||||
|
}
|
||||||
|
task.ID = taskID
|
||||||
|
|
||||||
|
return task, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ProcessTask runs the full synchronous processing pipeline for a task.
|
||||||
|
func (s *TaskService) ProcessTask(ctx context.Context, taskID int64) error {
|
||||||
|
// 1. Fetch task
|
||||||
|
task, err := s.taskStore.GetByID(ctx, taskID)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("get task: %w", err)
|
||||||
|
}
|
||||||
|
if task == nil {
|
||||||
|
return fmt.Errorf("task %d not found", taskID)
|
||||||
|
}
|
||||||
|
|
||||||
|
fail := func(step, msg string) error {
|
||||||
|
_ = s.taskStore.UpdateStatus(ctx, taskID, model.TaskStatusFailed, msg)
|
||||||
|
_ = s.taskStore.UpdateRetryState(ctx, taskID, model.TaskStatusFailed, step, task.RetryCount)
|
||||||
|
return fmt.Errorf("%s", msg)
|
||||||
|
}
|
||||||
|
|
||||||
|
currentStep := task.CurrentStep
|
||||||
|
|
||||||
|
var workDir string
|
||||||
|
var app *model.Application
|
||||||
|
|
||||||
|
if currentStep == "" || currentStep == model.TaskStepPreparing {
|
||||||
|
// 2. Set preparing
|
||||||
|
if err := s.taskStore.UpdateRetryState(ctx, taskID, model.TaskStatusPreparing, model.TaskStepPreparing, 0); err != nil {
|
||||||
|
return fail(model.TaskStepPreparing, fmt.Sprintf("update status to preparing: %v", err))
|
||||||
|
}
|
||||||
|
|
||||||
|
// 3. Fetch app
|
||||||
|
app, err = s.appStore.GetByID(ctx, task.AppID)
|
||||||
|
if err != nil {
|
||||||
|
return fail(model.TaskStepPreparing, fmt.Sprintf("get application: %v", err))
|
||||||
|
}
|
||||||
|
if app == nil {
|
||||||
|
return fail(model.TaskStepPreparing, fmt.Sprintf("application %d not found", task.AppID))
|
||||||
|
}
|
||||||
|
|
||||||
|
// 4-5. Create work directory
|
||||||
|
workDir = filepath.Join(s.workDirBase, SanitizeDirName(app.Name), time.Now().Format("20060102_150405")+"_"+RandomSuffix(4))
|
||||||
|
if err := os.MkdirAll(workDir, 0777); err != nil {
|
||||||
|
return fail(model.TaskStepPreparing, fmt.Sprintf("create work directory %s: %v", workDir, err))
|
||||||
|
}
|
||||||
|
|
||||||
|
// 6. CHMOD traversal — critical for multi-user HPC
|
||||||
|
for dir := workDir; dir != s.workDirBase; dir = filepath.Dir(dir) {
|
||||||
|
os.Chmod(dir, 0777)
|
||||||
|
}
|
||||||
|
os.Chmod(s.workDirBase, 0777)
|
||||||
|
|
||||||
|
// 7. UpdateWorkDir
|
||||||
|
if err := s.taskStore.UpdateWorkDir(ctx, taskID, workDir); err != nil {
|
||||||
|
return fail(model.TaskStepPreparing, fmt.Sprintf("update work dir: %v", err))
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
app, err = s.appStore.GetByID(ctx, task.AppID)
|
||||||
|
if err != nil {
|
||||||
|
return fail(currentStep, fmt.Sprintf("get application: %v", err))
|
||||||
|
}
|
||||||
|
if app == nil {
|
||||||
|
return fail(currentStep, fmt.Sprintf("application %d not found", task.AppID))
|
||||||
|
}
|
||||||
|
workDir = task.WorkDir
|
||||||
|
}
|
||||||
|
|
||||||
|
if currentStep == "" || currentStep == model.TaskStepPreparing || currentStep == model.TaskStepDownloading {
|
||||||
|
if currentStep == model.TaskStepDownloading && workDir != "" {
|
||||||
|
matches, _ := filepath.Glob(filepath.Join(workDir, "*"))
|
||||||
|
for _, f := range matches {
|
||||||
|
os.Remove(f)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 8. Set downloading
|
||||||
|
if err := s.taskStore.UpdateRetryState(ctx, taskID, model.TaskStatusDownloading, model.TaskStepDownloading, 0); err != nil {
|
||||||
|
return fail(model.TaskStepDownloading, fmt.Sprintf("update status to downloading: %v", err))
|
||||||
|
}
|
||||||
|
|
||||||
|
// 9. Parse input_file_ids
|
||||||
|
var fileIDs []int64
|
||||||
|
if len(task.InputFileIDs) > 0 {
|
||||||
|
if err := json.Unmarshal(task.InputFileIDs, &fileIDs); err != nil {
|
||||||
|
return fail(model.TaskStepDownloading, fmt.Sprintf("parse input file ids: %v", err))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 10-12. Download files
|
||||||
|
if len(fileIDs) > 0 {
|
||||||
|
if s.stagingSvc == nil {
|
||||||
|
return fail(model.TaskStepDownloading, "MinIO unavailable, cannot stage files")
|
||||||
|
}
|
||||||
|
if err := s.stagingSvc.DownloadFilesToDir(ctx, fileIDs, workDir); err != nil {
|
||||||
|
return fail(model.TaskStepDownloading, fmt.Sprintf("download files: %v", err))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 13-14. Set ready + submitting
|
||||||
|
if err := s.taskStore.UpdateRetryState(ctx, taskID, model.TaskStatusReady, model.TaskStepSubmitting, 0); err != nil {
|
||||||
|
return fail(model.TaskStepSubmitting, fmt.Sprintf("update status to ready: %v", err))
|
||||||
|
}
|
||||||
|
|
||||||
|
// 15. Parse app parameters
|
||||||
|
var params []model.ParameterSchema
|
||||||
|
if len(app.Parameters) > 0 {
|
||||||
|
if err := json.Unmarshal(app.Parameters, ¶ms); err != nil {
|
||||||
|
return fail(model.TaskStepSubmitting, fmt.Sprintf("parse parameters: %v", err))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 16. Parse task values
|
||||||
|
values := make(map[string]string)
|
||||||
|
if len(task.Values) > 0 {
|
||||||
|
if err := json.Unmarshal(task.Values, &values); err != nil {
|
||||||
|
return fail(model.TaskStepSubmitting, fmt.Sprintf("parse values: %v", err))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := ValidateParams(params, values); err != nil {
|
||||||
|
return fail(model.TaskStepSubmitting, err.Error())
|
||||||
|
}
|
||||||
|
|
||||||
|
// 17. Render script
|
||||||
|
rendered := RenderScript(app.ScriptTemplate, params, values)
|
||||||
|
|
||||||
|
// 18. Submit to Slurm
|
||||||
|
jobResp, err := s.jobSvc.SubmitJob(ctx, &model.SubmitJobRequest{
|
||||||
|
Script: rendered,
|
||||||
|
WorkDir: workDir,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return fail(model.TaskStepSubmitting, fmt.Sprintf("submit job: %v", err))
|
||||||
|
}
|
||||||
|
|
||||||
|
// 19. Update slurm_job_id and status to queued
|
||||||
|
if err := s.taskStore.UpdateSlurmJobID(ctx, taskID, &jobResp.JobID); err != nil {
|
||||||
|
return fail(model.TaskStepSubmitting, fmt.Sprintf("update slurm job id: %v", err))
|
||||||
|
}
|
||||||
|
if err := s.taskStore.UpdateStatus(ctx, taskID, model.TaskStatusQueued, ""); err != nil {
|
||||||
|
return fail(model.TaskStepSubmitting, fmt.Sprintf("update status to queued: %v", err))
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ListTasks returns a paginated list of tasks.
|
||||||
|
func (s *TaskService) ListTasks(ctx context.Context, query *model.TaskListQuery) ([]model.Task, int64, error) {
|
||||||
|
return s.taskStore.List(ctx, query)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ProcessTaskSync creates and processes a task synchronously, returning a JobResponse
|
||||||
|
// for old API compatibility.
|
||||||
|
func (s *TaskService) ProcessTaskSync(ctx context.Context, req *model.CreateTaskRequest) (*model.JobResponse, error) {
|
||||||
|
// 1. Create task
|
||||||
|
task, err := s.CreateTask(ctx, req)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// 2. Process synchronously
|
||||||
|
if err := s.ProcessTask(ctx, task.ID); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// 3. Re-fetch to get updated slurm_job_id
|
||||||
|
task, err = s.taskStore.GetByID(ctx, task.ID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("re-fetch task: %w", err)
|
||||||
|
}
|
||||||
|
if task == nil || task.SlurmJobID == nil {
|
||||||
|
return nil, fmt.Errorf("task has no slurm job id after processing")
|
||||||
|
}
|
||||||
|
|
||||||
|
// 4. Return JobResponse
|
||||||
|
return &model.JobResponse{JobID: *task.SlurmJobID}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// uniqueInt64s deduplicates and sorts a slice of int64.
|
||||||
|
func uniqueInt64s(ids []int64) []int64 {
|
||||||
|
if len(ids) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
seen := make(map[int64]bool, len(ids))
|
||||||
|
result := make([]int64, 0, len(ids))
|
||||||
|
for _, id := range ids {
|
||||||
|
if !seen[id] {
|
||||||
|
seen[id] = true
|
||||||
|
result = append(result, id)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
sort.Slice(result, func(i, j int) bool { return result[i] < result[j] })
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *TaskService) mapSlurmStateToTaskStatus(slurmState []string) string {
|
||||||
|
if len(slurmState) == 0 {
|
||||||
|
return model.TaskStatusRunning
|
||||||
|
}
|
||||||
|
|
||||||
|
state := strings.ToUpper(slurmState[0])
|
||||||
|
switch state {
|
||||||
|
case "PENDING":
|
||||||
|
return model.TaskStatusQueued
|
||||||
|
case "RUNNING", "CONFIGURING", "COMPLETING", "SPECIAL_EXIT":
|
||||||
|
return model.TaskStatusRunning
|
||||||
|
case "COMPLETED":
|
||||||
|
return model.TaskStatusCompleted
|
||||||
|
case "FAILED", "CANCELLED", "TIMEOUT", "NODE_FAIL", "OUT_OF_MEMORY", "PREEMPTED":
|
||||||
|
return model.TaskStatusFailed
|
||||||
|
default:
|
||||||
|
return model.TaskStatusRunning
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *TaskService) refreshTaskStatus(ctx context.Context, taskID int64) error {
|
||||||
|
task, err := s.taskStore.GetByID(ctx, taskID)
|
||||||
|
if err != nil {
|
||||||
|
s.logger.Error("failed to fetch task for refresh",
|
||||||
|
zap.Int64("task_id", taskID),
|
||||||
|
zap.Error(err),
|
||||||
|
)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if task == nil || task.SlurmJobID == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
jobResp, err := s.jobSvc.GetJob(ctx, strconv.FormatInt(int64(*task.SlurmJobID), 10))
|
||||||
|
if err != nil {
|
||||||
|
s.logger.Warn("failed to query slurm job status during refresh",
|
||||||
|
zap.Int64("task_id", taskID),
|
||||||
|
zap.Int32("slurm_job_id", *task.SlurmJobID),
|
||||||
|
zap.Error(err),
|
||||||
|
)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if jobResp == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
newStatus := s.mapSlurmStateToTaskStatus(jobResp.State)
|
||||||
|
if newStatus != task.Status {
|
||||||
|
s.logger.Info("updating task status from slurm",
|
||||||
|
zap.Int64("task_id", taskID),
|
||||||
|
zap.String("old_status", task.Status),
|
||||||
|
zap.String("new_status", newStatus),
|
||||||
|
)
|
||||||
|
return s.taskStore.UpdateStatus(ctx, taskID, newStatus, "")
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *TaskService) RefreshStaleTasks(ctx context.Context) error {
|
||||||
|
staleThreshold := 30 * time.Second
|
||||||
|
nonTerminal := []string{model.TaskStatusQueued, model.TaskStatusRunning}
|
||||||
|
|
||||||
|
for _, status := range nonTerminal {
|
||||||
|
tasks, _, err := s.taskStore.List(ctx, &model.TaskListQuery{
|
||||||
|
Status: status,
|
||||||
|
Page: 1,
|
||||||
|
PageSize: 1000,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
s.logger.Warn("failed to list tasks for stale refresh",
|
||||||
|
zap.String("status", status),
|
||||||
|
zap.Error(err),
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
cutoff := time.Now().Add(-staleThreshold)
|
||||||
|
for i := range tasks {
|
||||||
|
if tasks[i].UpdatedAt.Before(cutoff) {
|
||||||
|
if err := s.refreshTaskStatus(ctx, tasks[i].ID); err != nil {
|
||||||
|
s.logger.Warn("failed to refresh stale task",
|
||||||
|
zap.Int64("task_id", tasks[i].ID),
|
||||||
|
zap.Error(err),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *TaskService) StartProcessor(ctx context.Context) {
|
||||||
|
s.mu.Lock()
|
||||||
|
if s.started {
|
||||||
|
s.mu.Unlock()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
s.started = true
|
||||||
|
s.mu.Unlock()
|
||||||
|
|
||||||
|
ctx, s.cancelFn = context.WithCancel(ctx)
|
||||||
|
|
||||||
|
s.wg.Add(1)
|
||||||
|
go func() {
|
||||||
|
defer s.wg.Done()
|
||||||
|
defer func() {
|
||||||
|
if r := recover(); r != nil {
|
||||||
|
s.logger.Error("processor panic", zap.Any("panic", r))
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return
|
||||||
|
case taskID, ok := <-s.taskCh:
|
||||||
|
if !ok {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
taskCtx, cancel := context.WithTimeout(ctx, 10*time.Minute)
|
||||||
|
s.processWithRetry(taskCtx, taskID)
|
||||||
|
cancel()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
s.RecoverStuckTasks(ctx)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *TaskService) SubmitAsync(ctx context.Context, req *model.CreateTaskRequest) (int64, error) {
|
||||||
|
task, err := s.CreateTask(ctx, req)
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
s.mu.Lock()
|
||||||
|
if s.stopped {
|
||||||
|
s.mu.Unlock()
|
||||||
|
return 0, fmt.Errorf("processor stopped, cannot submit task")
|
||||||
|
}
|
||||||
|
select {
|
||||||
|
case s.taskCh <- task.ID:
|
||||||
|
default:
|
||||||
|
s.logger.Warn("task channel full, submit dropped", zap.Int64("taskID", task.ID))
|
||||||
|
}
|
||||||
|
s.mu.Unlock()
|
||||||
|
|
||||||
|
return task.ID, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *TaskService) StopProcessor() {
|
||||||
|
s.mu.Lock()
|
||||||
|
if s.stopped {
|
||||||
|
s.mu.Unlock()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
s.stopped = true
|
||||||
|
close(s.taskCh)
|
||||||
|
s.mu.Unlock()
|
||||||
|
|
||||||
|
if s.cancelFn != nil {
|
||||||
|
s.cancelFn()
|
||||||
|
}
|
||||||
|
s.wg.Wait()
|
||||||
|
|
||||||
|
s.mu.Lock()
|
||||||
|
drainCh := s.taskCh
|
||||||
|
s.taskCh = make(chan int64, 16)
|
||||||
|
s.mu.Unlock()
|
||||||
|
|
||||||
|
for taskID := range drainCh {
|
||||||
|
_ = s.taskStore.UpdateStatus(context.Background(), taskID, model.TaskStatusSubmitted, "")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *TaskService) processWithRetry(ctx context.Context, taskID int64) {
|
||||||
|
err := s.ProcessTask(ctx, taskID)
|
||||||
|
if err == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
task, fetchErr := s.taskStore.GetByID(ctx, taskID)
|
||||||
|
if fetchErr != nil || task == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if task.RetryCount < 3 {
|
||||||
|
_ = s.taskStore.UpdateRetryState(ctx, taskID, model.TaskStatusSubmitted, task.CurrentStep, task.RetryCount+1)
|
||||||
|
s.mu.Lock()
|
||||||
|
if !s.stopped {
|
||||||
|
select {
|
||||||
|
case s.taskCh <- taskID:
|
||||||
|
default:
|
||||||
|
s.logger.Warn("task channel full, retry dropped", zap.Int64("taskID", taskID))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
s.mu.Unlock()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *TaskService) RecoverStuckTasks(ctx context.Context) {
|
||||||
|
tasks, err := s.taskStore.GetStuckTasks(ctx, 5*time.Minute)
|
||||||
|
if err != nil {
|
||||||
|
s.logger.Error("failed to get stuck tasks", zap.Error(err))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
for i := range tasks {
|
||||||
|
_ = s.taskStore.UpdateStatus(ctx, tasks[i].ID, model.TaskStatusSubmitted, "")
|
||||||
|
s.mu.Lock()
|
||||||
|
if !s.stopped {
|
||||||
|
select {
|
||||||
|
case s.taskCh <- tasks[i].ID:
|
||||||
|
default:
|
||||||
|
s.logger.Warn("task channel full, stuck task recovery dropped", zap.Int64("taskID", tasks[i].ID))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
s.mu.Unlock()
|
||||||
|
}
|
||||||
|
}
|
||||||
416
internal/service/task_service_async_test.go
Normal file
416
internal/service/task_service_async_test.go
Normal file
@@ -0,0 +1,416 @@
|
|||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"sync/atomic"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"gcy_hpc_server/internal/model"
|
||||||
|
"gcy_hpc_server/internal/slurm"
|
||||||
|
"gcy_hpc_server/internal/store"
|
||||||
|
|
||||||
|
"go.uber.org/zap"
|
||||||
|
"gorm.io/driver/sqlite"
|
||||||
|
gormlogger "gorm.io/gorm/logger"
|
||||||
|
|
||||||
|
"gorm.io/gorm"
|
||||||
|
)
|
||||||
|
|
||||||
|
func setupAsyncTestDB(t *testing.T) *gorm.DB {
|
||||||
|
t.Helper()
|
||||||
|
db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{
|
||||||
|
Logger: gormlogger.Default.LogMode(gormlogger.Silent),
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("open sqlite: %v", err)
|
||||||
|
}
|
||||||
|
if err := db.AutoMigrate(&model.Task{}, &model.Application{}, &model.File{}, &model.FileBlob{}); err != nil {
|
||||||
|
t.Fatalf("auto migrate: %v", err)
|
||||||
|
}
|
||||||
|
return db
|
||||||
|
}
|
||||||
|
|
||||||
|
type asyncTestEnv struct {
|
||||||
|
taskStore *store.TaskStore
|
||||||
|
appStore *store.ApplicationStore
|
||||||
|
svc *TaskService
|
||||||
|
srv *httptest.Server
|
||||||
|
db *gorm.DB
|
||||||
|
workDirBase string
|
||||||
|
}
|
||||||
|
|
||||||
|
func newAsyncTestEnv(t *testing.T, slurmHandler http.HandlerFunc) *asyncTestEnv {
|
||||||
|
t.Helper()
|
||||||
|
db := setupAsyncTestDB(t)
|
||||||
|
|
||||||
|
ts := store.NewTaskStore(db)
|
||||||
|
as := store.NewApplicationStore(db)
|
||||||
|
|
||||||
|
srv := httptest.NewServer(slurmHandler)
|
||||||
|
client, _ := slurm.NewClient(srv.URL, srv.Client())
|
||||||
|
jobSvc := NewJobService(client, zap.NewNop())
|
||||||
|
|
||||||
|
workDirBase := filepath.Join(t.TempDir(), "workdir")
|
||||||
|
os.MkdirAll(workDirBase, 0777)
|
||||||
|
|
||||||
|
svc := NewTaskService(ts, as, nil, nil, nil, jobSvc, workDirBase, zap.NewNop())
|
||||||
|
|
||||||
|
return &asyncTestEnv{
|
||||||
|
taskStore: ts,
|
||||||
|
appStore: as,
|
||||||
|
svc: svc,
|
||||||
|
srv: srv,
|
||||||
|
db: db,
|
||||||
|
workDirBase: workDirBase,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *asyncTestEnv) close() {
|
||||||
|
e.srv.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *asyncTestEnv) createApp(t *testing.T, name, script string) int64 {
|
||||||
|
t.Helper()
|
||||||
|
id, err := e.appStore.Create(context.Background(), &model.CreateApplicationRequest{
|
||||||
|
Name: name,
|
||||||
|
ScriptTemplate: script,
|
||||||
|
Parameters: json.RawMessage(`[]`),
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("create app: %v", err)
|
||||||
|
}
|
||||||
|
return id
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTaskService_Async_SubmitAndProcess(t *testing.T) {
|
||||||
|
jobID := int32(42)
|
||||||
|
env := newAsyncTestEnv(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
json.NewEncoder(w).Encode(slurm.OpenapiJobSubmitResponse{
|
||||||
|
Result: &slurm.JobSubmitResponseMsg{JobID: &jobID},
|
||||||
|
})
|
||||||
|
}))
|
||||||
|
defer env.close()
|
||||||
|
|
||||||
|
appID := env.createApp(t, "async-app", "#!/bin/bash\necho hello")
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
env.svc.StartProcessor(ctx)
|
||||||
|
|
||||||
|
taskID, err := env.svc.SubmitAsync(ctx, &model.CreateTaskRequest{
|
||||||
|
AppID: appID,
|
||||||
|
TaskName: "async-test",
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("SubmitAsync: %v", err)
|
||||||
|
}
|
||||||
|
if taskID == 0 {
|
||||||
|
t.Fatal("expected non-zero task ID")
|
||||||
|
}
|
||||||
|
|
||||||
|
time.Sleep(500 * time.Millisecond)
|
||||||
|
|
||||||
|
task, err := env.taskStore.GetByID(ctx, taskID)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("GetByID: %v", err)
|
||||||
|
}
|
||||||
|
if task.Status != model.TaskStatusQueued {
|
||||||
|
t.Errorf("Status = %q, want %q", task.Status, model.TaskStatusQueued)
|
||||||
|
}
|
||||||
|
|
||||||
|
env.svc.StopProcessor()
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTaskService_Retry_MaxExhaustion(t *testing.T) {
|
||||||
|
callCount := int32(0)
|
||||||
|
env := newAsyncTestEnv(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
atomic.AddInt32(&callCount, 1)
|
||||||
|
w.WriteHeader(http.StatusInternalServerError)
|
||||||
|
w.Write([]byte(`{"error":"slurm down"}`))
|
||||||
|
}))
|
||||||
|
defer env.close()
|
||||||
|
|
||||||
|
appID := env.createApp(t, "retry-app", "#!/bin/bash\necho hello")
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
env.svc.StartProcessor(ctx)
|
||||||
|
|
||||||
|
taskID, err := env.svc.SubmitAsync(ctx, &model.CreateTaskRequest{
|
||||||
|
AppID: appID,
|
||||||
|
TaskName: "retry-test",
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("SubmitAsync: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
time.Sleep(2 * time.Second)
|
||||||
|
|
||||||
|
task, _ := env.taskStore.GetByID(ctx, taskID)
|
||||||
|
if task.Status != model.TaskStatusFailed {
|
||||||
|
t.Errorf("Status = %q, want %q", task.Status, model.TaskStatusFailed)
|
||||||
|
}
|
||||||
|
if task.RetryCount < 3 {
|
||||||
|
t.Errorf("RetryCount = %d, want >= 3", task.RetryCount)
|
||||||
|
}
|
||||||
|
|
||||||
|
env.svc.StopProcessor()
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTaskService_Recover_StuckTasks(t *testing.T) {
|
||||||
|
jobID := int32(99)
|
||||||
|
env := newAsyncTestEnv(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
json.NewEncoder(w).Encode(slurm.OpenapiJobSubmitResponse{
|
||||||
|
Result: &slurm.JobSubmitResponseMsg{JobID: &jobID},
|
||||||
|
})
|
||||||
|
}))
|
||||||
|
defer env.close()
|
||||||
|
|
||||||
|
appID := env.createApp(t, "stuck-app", "#!/bin/bash\necho hello")
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
task := &model.Task{
|
||||||
|
TaskName: "stuck-task",
|
||||||
|
AppID: appID,
|
||||||
|
AppName: "stuck-app",
|
||||||
|
Status: model.TaskStatusPreparing,
|
||||||
|
CurrentStep: model.TaskStepPreparing,
|
||||||
|
RetryCount: 0,
|
||||||
|
SubmittedAt: time.Now(),
|
||||||
|
}
|
||||||
|
taskID, err := env.taskStore.Create(ctx, task)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Create stuck task: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
staleTime := time.Now().Add(-10 * time.Minute)
|
||||||
|
env.db.Exec("UPDATE hpc_tasks SET updated_at = ? WHERE id = ?", staleTime, taskID)
|
||||||
|
|
||||||
|
env.svc.StartProcessor(ctx)
|
||||||
|
|
||||||
|
time.Sleep(1 * time.Second)
|
||||||
|
|
||||||
|
updated, _ := env.taskStore.GetByID(ctx, taskID)
|
||||||
|
if updated.Status != model.TaskStatusQueued {
|
||||||
|
t.Errorf("Status = %q, want %q", updated.Status, model.TaskStatusQueued)
|
||||||
|
}
|
||||||
|
|
||||||
|
env.svc.StopProcessor()
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTaskService_Shutdown_InFlight(t *testing.T) {
|
||||||
|
env := newAsyncTestEnv(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
time.Sleep(200 * time.Millisecond)
|
||||||
|
jobID := int32(77)
|
||||||
|
json.NewEncoder(w).Encode(slurm.OpenapiJobSubmitResponse{
|
||||||
|
Result: &slurm.JobSubmitResponseMsg{JobID: &jobID},
|
||||||
|
})
|
||||||
|
}))
|
||||||
|
defer env.close()
|
||||||
|
|
||||||
|
appID := env.createApp(t, "shutdown-app", "#!/bin/bash\necho hello")
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
env.svc.StartProcessor(ctx)
|
||||||
|
|
||||||
|
taskID, _ := env.svc.SubmitAsync(ctx, &model.CreateTaskRequest{
|
||||||
|
AppID: appID,
|
||||||
|
TaskName: "shutdown-test",
|
||||||
|
})
|
||||||
|
|
||||||
|
done := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
env.svc.StopProcessor()
|
||||||
|
close(done)
|
||||||
|
}()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-done:
|
||||||
|
case <-time.After(5 * time.Second):
|
||||||
|
t.Fatal("StopProcessor did not complete within timeout")
|
||||||
|
}
|
||||||
|
|
||||||
|
task, _ := env.taskStore.GetByID(ctx, taskID)
|
||||||
|
if task.Status != model.TaskStatusQueued && task.Status != model.TaskStatusSubmitted {
|
||||||
|
t.Logf("task status after shutdown: %q (acceptable)", task.Status)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTaskService_PanicRecovery(t *testing.T) {
|
||||||
|
jobID := int32(55)
|
||||||
|
panicDone := int32(0)
|
||||||
|
env := newAsyncTestEnv(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if atomic.CompareAndSwapInt32(&panicDone, 0, 1) {
|
||||||
|
panic("intentional test panic")
|
||||||
|
}
|
||||||
|
json.NewEncoder(w).Encode(slurm.OpenapiJobSubmitResponse{
|
||||||
|
Result: &slurm.JobSubmitResponseMsg{JobID: &jobID},
|
||||||
|
})
|
||||||
|
}))
|
||||||
|
defer env.close()
|
||||||
|
|
||||||
|
appID := env.createApp(t, "panic-app", "#!/bin/bash\necho hello")
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
env.svc.StartProcessor(ctx)
|
||||||
|
|
||||||
|
taskID, _ := env.svc.SubmitAsync(ctx, &model.CreateTaskRequest{
|
||||||
|
AppID: appID,
|
||||||
|
TaskName: "panic-test",
|
||||||
|
})
|
||||||
|
|
||||||
|
time.Sleep(1 * time.Second)
|
||||||
|
|
||||||
|
atomic.StoreInt32(&panicDone, 1)
|
||||||
|
|
||||||
|
env.svc.StopProcessor()
|
||||||
|
_ = taskID
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTaskService_SubmitAsync_DuringShutdown(t *testing.T) {
|
||||||
|
env := newAsyncTestEnv(t, nil)
|
||||||
|
defer env.close()
|
||||||
|
|
||||||
|
appID := env.createApp(t, "shutdown-err-app", "#!/bin/bash\necho hello")
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
env.svc.StartProcessor(ctx)
|
||||||
|
env.svc.StopProcessor()
|
||||||
|
|
||||||
|
_, err := env.svc.SubmitAsync(ctx, &model.CreateTaskRequest{
|
||||||
|
AppID: appID,
|
||||||
|
TaskName: "after-shutdown",
|
||||||
|
})
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected error when submitting after shutdown")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestTaskService_SubmitAsync_ChannelFull_NonBlocking verifies SubmitAsync
|
||||||
|
// returns without blocking when the task channel buffer (cap=16) is full.
|
||||||
|
// Before fix: SubmitAsync holds s.mu while blocking on full channel → deadlock.
|
||||||
|
// After fix: non-blocking select returns immediately.
|
||||||
|
func TestTaskService_SubmitAsync_ChannelFull_NonBlocking(t *testing.T) {
|
||||||
|
jobID := int32(42)
|
||||||
|
env := newAsyncTestEnv(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
time.Sleep(5 * time.Second)
|
||||||
|
json.NewEncoder(w).Encode(slurm.OpenapiJobSubmitResponse{
|
||||||
|
Result: &slurm.JobSubmitResponseMsg{JobID: &jobID},
|
||||||
|
})
|
||||||
|
}))
|
||||||
|
defer env.close()
|
||||||
|
|
||||||
|
appID := env.createApp(t, "channel-full-app", "#!/bin/bash\necho hello")
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
// Create all tasks BEFORE starting processor to avoid SQLite concurrent-write contention
|
||||||
|
taskIDs := make([]int64, 17)
|
||||||
|
for i := range taskIDs {
|
||||||
|
id, err := env.taskStore.Create(ctx, &model.Task{
|
||||||
|
TaskName: fmt.Sprintf("fill-%d", i),
|
||||||
|
AppID: appID,
|
||||||
|
AppName: "channel-full-app",
|
||||||
|
Status: model.TaskStatusSubmitted,
|
||||||
|
CurrentStep: model.TaskStepSubmitting,
|
||||||
|
SubmittedAt: time.Now(),
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("create fill task %d: %v", i, err)
|
||||||
|
}
|
||||||
|
taskIDs[i] = id
|
||||||
|
}
|
||||||
|
|
||||||
|
env.svc.StartProcessor(ctx)
|
||||||
|
defer env.svc.StopProcessor()
|
||||||
|
|
||||||
|
// Consumer grabs first ID immediately; remaining 15 sit in channel.
|
||||||
|
// Push one more to fill buffer to 16 (full).
|
||||||
|
for _, id := range taskIDs {
|
||||||
|
env.svc.taskCh <- id
|
||||||
|
}
|
||||||
|
|
||||||
|
// Overflow submit: must return within 3s (non-blocking after fix)
|
||||||
|
done := make(chan error, 1)
|
||||||
|
go func() {
|
||||||
|
_, submitErr := env.svc.SubmitAsync(ctx, &model.CreateTaskRequest{
|
||||||
|
AppID: appID,
|
||||||
|
TaskName: "overflow-task",
|
||||||
|
})
|
||||||
|
done <- submitErr
|
||||||
|
}()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case err := <-done:
|
||||||
|
if err != nil {
|
||||||
|
t.Logf("SubmitAsync returned error (acceptable after fix): %v", err)
|
||||||
|
} else {
|
||||||
|
t.Log("SubmitAsync returned without blocking — channel send is non-blocking")
|
||||||
|
}
|
||||||
|
case <-time.After(3 * time.Second):
|
||||||
|
t.Fatal("SubmitAsync blocked for >3s — channel send is blocking, potential deadlock")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestTaskService_Retry_ChannelFull_NonBlocking verifies processWithRetry
|
||||||
|
// does not deadlock when re-enqueuing a failed task into a full channel.
|
||||||
|
// Before fix: processWithRetry holds s.mu while blocking on s.taskCh <- taskID → deadlock.
|
||||||
|
// After fix: non-blocking select drops the retry with a Warn log.
|
||||||
|
func TestTaskService_Retry_ChannelFull_NonBlocking(t *testing.T) {
|
||||||
|
env := newAsyncTestEnv(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
time.Sleep(1 * time.Second)
|
||||||
|
w.WriteHeader(http.StatusInternalServerError)
|
||||||
|
w.Write([]byte(`{"error":"slurm down"}`))
|
||||||
|
}))
|
||||||
|
defer env.close()
|
||||||
|
|
||||||
|
appID := env.createApp(t, "retry-full-app", "#!/bin/bash\necho hello")
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
// Create all tasks BEFORE starting processor to avoid SQLite concurrent-write contention
|
||||||
|
taskIDs := make([]int64, 17)
|
||||||
|
for i := range taskIDs {
|
||||||
|
id, err := env.taskStore.Create(ctx, &model.Task{
|
||||||
|
TaskName: fmt.Sprintf("retry-%d", i),
|
||||||
|
AppID: appID,
|
||||||
|
AppName: "retry-full-app",
|
||||||
|
Status: model.TaskStatusSubmitted,
|
||||||
|
CurrentStep: model.TaskStepSubmitting,
|
||||||
|
RetryCount: 0,
|
||||||
|
SubmittedAt: time.Now(),
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("create retry task %d: %v", i, err)
|
||||||
|
}
|
||||||
|
taskIDs[i] = id
|
||||||
|
}
|
||||||
|
|
||||||
|
env.svc.StartProcessor(ctx)
|
||||||
|
|
||||||
|
// Push all 17 IDs: consumer grabs one (processing ~1s), 16 fill the buffer
|
||||||
|
for _, id := range taskIDs {
|
||||||
|
env.svc.taskCh <- id
|
||||||
|
}
|
||||||
|
|
||||||
|
// Wait for consumer to finish first task and attempt retry into full channel
|
||||||
|
time.Sleep(2 * time.Second)
|
||||||
|
|
||||||
|
// If processWithRetry deadlocked holding s.mu, StopProcessor hangs on mutex acquisition
|
||||||
|
done := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
env.svc.StopProcessor()
|
||||||
|
close(done)
|
||||||
|
}()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-done:
|
||||||
|
t.Log("StopProcessor completed — retry channel send is non-blocking")
|
||||||
|
case <-time.After(5 * time.Second):
|
||||||
|
t.Fatal("StopProcessor did not complete within 5s — deadlock from retry channel send")
|
||||||
|
}
|
||||||
|
}
|
||||||
294
internal/service/task_service_status_test.go
Normal file
294
internal/service/task_service_status_test.go
Normal file
@@ -0,0 +1,294 @@
|
|||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"gcy_hpc_server/internal/model"
|
||||||
|
"gcy_hpc_server/internal/slurm"
|
||||||
|
"gcy_hpc_server/internal/store"
|
||||||
|
|
||||||
|
"go.uber.org/zap"
|
||||||
|
"gorm.io/driver/sqlite"
|
||||||
|
"gorm.io/gorm"
|
||||||
|
"gorm.io/gorm/logger"
|
||||||
|
)
|
||||||
|
|
||||||
|
func newTaskSvcTestDB(t *testing.T) *gorm.DB {
|
||||||
|
t.Helper()
|
||||||
|
db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{
|
||||||
|
Logger: logger.Default.LogMode(logger.Silent),
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("open sqlite: %v", err)
|
||||||
|
}
|
||||||
|
if err := db.AutoMigrate(&model.Task{}); err != nil {
|
||||||
|
t.Fatalf("auto migrate: %v", err)
|
||||||
|
}
|
||||||
|
return db
|
||||||
|
}
|
||||||
|
|
||||||
|
type taskSvcTestEnv struct {
|
||||||
|
taskStore *store.TaskStore
|
||||||
|
jobSvc *JobService
|
||||||
|
svc *TaskService
|
||||||
|
srv *httptest.Server
|
||||||
|
db *gorm.DB
|
||||||
|
}
|
||||||
|
|
||||||
|
func newTaskSvcTestEnv(t *testing.T, handler http.HandlerFunc) *taskSvcTestEnv {
|
||||||
|
t.Helper()
|
||||||
|
db := newTaskSvcTestDB(t)
|
||||||
|
ts := store.NewTaskStore(db)
|
||||||
|
|
||||||
|
srv := httptest.NewServer(handler)
|
||||||
|
client, _ := slurm.NewClient(srv.URL, srv.Client())
|
||||||
|
jobSvc := NewJobService(client, zap.NewNop())
|
||||||
|
svc := NewTaskService(ts, nil, nil, nil, nil, jobSvc, "/tmp", zap.NewNop())
|
||||||
|
|
||||||
|
return &taskSvcTestEnv{
|
||||||
|
taskStore: ts,
|
||||||
|
jobSvc: jobSvc,
|
||||||
|
svc: svc,
|
||||||
|
srv: srv,
|
||||||
|
db: db,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *taskSvcTestEnv) close() {
|
||||||
|
e.srv.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
func makeTaskForTest(name, status string, slurmJobID *int32) *model.Task {
|
||||||
|
return &model.Task{
|
||||||
|
TaskName: name,
|
||||||
|
AppID: 1,
|
||||||
|
AppName: "test-app",
|
||||||
|
Status: status,
|
||||||
|
CurrentStep: "",
|
||||||
|
RetryCount: 0,
|
||||||
|
UserID: "user1",
|
||||||
|
SubmittedAt: time.Now(),
|
||||||
|
SlurmJobID: slurmJobID,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTaskService_MapSlurmState_AllStates(t *testing.T) {
|
||||||
|
env := newTaskSvcTestEnv(t, nil)
|
||||||
|
defer env.close()
|
||||||
|
|
||||||
|
cases := []struct {
|
||||||
|
input []string
|
||||||
|
expected string
|
||||||
|
}{
|
||||||
|
{[]string{"PENDING"}, model.TaskStatusQueued},
|
||||||
|
{[]string{"RUNNING"}, model.TaskStatusRunning},
|
||||||
|
{[]string{"CONFIGURING"}, model.TaskStatusRunning},
|
||||||
|
{[]string{"COMPLETING"}, model.TaskStatusRunning},
|
||||||
|
{[]string{"COMPLETED"}, model.TaskStatusCompleted},
|
||||||
|
{[]string{"FAILED"}, model.TaskStatusFailed},
|
||||||
|
{[]string{"CANCELLED"}, model.TaskStatusFailed},
|
||||||
|
{[]string{"TIMEOUT"}, model.TaskStatusFailed},
|
||||||
|
{[]string{"NODE_FAIL"}, model.TaskStatusFailed},
|
||||||
|
{[]string{"OUT_OF_MEMORY"}, model.TaskStatusFailed},
|
||||||
|
{[]string{"PREEMPTED"}, model.TaskStatusFailed},
|
||||||
|
{[]string{"SPECIAL_EXIT"}, model.TaskStatusRunning},
|
||||||
|
{[]string{"unknown_state"}, model.TaskStatusRunning},
|
||||||
|
{[]string{"pending"}, model.TaskStatusQueued},
|
||||||
|
{[]string{"Running"}, model.TaskStatusRunning},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range cases {
|
||||||
|
got := env.svc.mapSlurmStateToTaskStatus(tc.input)
|
||||||
|
if got != tc.expected {
|
||||||
|
t.Errorf("mapSlurmStateToTaskStatus(%v) = %q, want %q", tc.input, got, tc.expected)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTaskService_MapSlurmState_Empty(t *testing.T) {
|
||||||
|
env := newTaskSvcTestEnv(t, nil)
|
||||||
|
defer env.close()
|
||||||
|
|
||||||
|
got := env.svc.mapSlurmStateToTaskStatus([]string{})
|
||||||
|
if got != model.TaskStatusRunning {
|
||||||
|
t.Errorf("mapSlurmStateToTaskStatus([]) = %q, want %q", got, model.TaskStatusRunning)
|
||||||
|
}
|
||||||
|
|
||||||
|
got = env.svc.mapSlurmStateToTaskStatus(nil)
|
||||||
|
if got != model.TaskStatusRunning {
|
||||||
|
t.Errorf("mapSlurmStateToTaskStatus(nil) = %q, want %q", got, model.TaskStatusRunning)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTaskService_RefreshTaskStatus_UpdatesDB(t *testing.T) {
|
||||||
|
jobID := int32(42)
|
||||||
|
env := newTaskSvcTestEnv(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
resp := slurm.OpenapiJobInfoResp{
|
||||||
|
Jobs: slurm.JobInfoMsg{
|
||||||
|
{
|
||||||
|
JobID: &jobID,
|
||||||
|
JobState: []string{"RUNNING"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
json.NewEncoder(w).Encode(resp)
|
||||||
|
}))
|
||||||
|
defer env.close()
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
task := makeTaskForTest("refresh-test", model.TaskStatusQueued, &jobID)
|
||||||
|
id, err := env.taskStore.Create(ctx, task)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Create: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
err = env.svc.refreshTaskStatus(ctx, id)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("refreshTaskStatus: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
updated, _ := env.taskStore.GetByID(ctx, id)
|
||||||
|
if updated.Status != model.TaskStatusRunning {
|
||||||
|
t.Errorf("status = %q, want %q", updated.Status, model.TaskStatusRunning)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTaskService_RefreshTaskStatus_NoSlurmJobID(t *testing.T) {
|
||||||
|
env := newTaskSvcTestEnv(t, nil)
|
||||||
|
defer env.close()
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
task := makeTaskForTest("no-slurm", model.TaskStatusQueued, nil)
|
||||||
|
id, _ := env.taskStore.Create(ctx, task)
|
||||||
|
|
||||||
|
err := env.svc.refreshTaskStatus(ctx, id)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("expected no error, got %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
got, _ := env.taskStore.GetByID(ctx, id)
|
||||||
|
if got.Status != model.TaskStatusQueued {
|
||||||
|
t.Errorf("status should remain unchanged, got %q", got.Status)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTaskService_RefreshTaskStatus_SlurmError(t *testing.T) {
|
||||||
|
jobID := int32(42)
|
||||||
|
env := newTaskSvcTestEnv(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(http.StatusInternalServerError)
|
||||||
|
w.Write([]byte(`{"error":"down"}`))
|
||||||
|
}))
|
||||||
|
defer env.close()
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
task := makeTaskForTest("slurm-err", model.TaskStatusQueued, &jobID)
|
||||||
|
id, _ := env.taskStore.Create(ctx, task)
|
||||||
|
|
||||||
|
err := env.svc.refreshTaskStatus(ctx, id)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("expected no error (soft fail), got %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
got, _ := env.taskStore.GetByID(ctx, id)
|
||||||
|
if got.Status != model.TaskStatusQueued {
|
||||||
|
t.Errorf("status should remain unchanged on slurm error, got %q", got.Status)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTaskService_RefreshTaskStatus_NoChange(t *testing.T) {
|
||||||
|
jobID := int32(42)
|
||||||
|
env := newTaskSvcTestEnv(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
resp := slurm.OpenapiJobInfoResp{
|
||||||
|
Jobs: slurm.JobInfoMsg{
|
||||||
|
{
|
||||||
|
JobID: &jobID,
|
||||||
|
JobState: []string{"RUNNING"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
json.NewEncoder(w).Encode(resp)
|
||||||
|
}))
|
||||||
|
defer env.close()
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
task := makeTaskForTest("no-change", model.TaskStatusRunning, &jobID)
|
||||||
|
id, _ := env.taskStore.Create(ctx, task)
|
||||||
|
|
||||||
|
err := env.svc.refreshTaskStatus(ctx, id)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("refreshTaskStatus: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
got, _ := env.taskStore.GetByID(ctx, id)
|
||||||
|
if got.Status != model.TaskStatusRunning {
|
||||||
|
t.Errorf("status = %q, want %q", got.Status, model.TaskStatusRunning)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTaskService_RefreshStaleTasks_SkipsFresh(t *testing.T) {
|
||||||
|
jobID := int32(42)
|
||||||
|
slurmQueried := false
|
||||||
|
env := newTaskSvcTestEnv(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
slurmQueried = true
|
||||||
|
w.WriteHeader(http.StatusInternalServerError)
|
||||||
|
}))
|
||||||
|
defer env.close()
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
task := makeTaskForTest("fresh-task", model.TaskStatusQueued, &jobID)
|
||||||
|
id, _ := env.taskStore.Create(ctx, task)
|
||||||
|
|
||||||
|
freshTask, _ := env.taskStore.GetByID(ctx, id)
|
||||||
|
if freshTask == nil {
|
||||||
|
t.Fatal("task not found")
|
||||||
|
}
|
||||||
|
|
||||||
|
env.db.Exec("UPDATE hpc_tasks SET updated_at = ? WHERE id = ?", time.Now(), id)
|
||||||
|
|
||||||
|
err := env.svc.RefreshStaleTasks(ctx)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("RefreshStaleTasks: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if slurmQueried {
|
||||||
|
t.Error("expected no Slurm query for fresh task")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTaskService_RefreshStaleTasks_RefreshesStale(t *testing.T) {
|
||||||
|
jobID := int32(42)
|
||||||
|
env := newTaskSvcTestEnv(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
resp := slurm.OpenapiJobInfoResp{
|
||||||
|
Jobs: slurm.JobInfoMsg{
|
||||||
|
{
|
||||||
|
JobID: &jobID,
|
||||||
|
JobState: []string{"COMPLETED"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
json.NewEncoder(w).Encode(resp)
|
||||||
|
}))
|
||||||
|
defer env.close()
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
task := makeTaskForTest("stale-task", model.TaskStatusRunning, &jobID)
|
||||||
|
id, _ := env.taskStore.Create(ctx, task)
|
||||||
|
|
||||||
|
staleTime := time.Now().Add(-60 * time.Second)
|
||||||
|
env.db.Exec("UPDATE hpc_tasks SET updated_at = ? WHERE id = ?", staleTime, id)
|
||||||
|
|
||||||
|
err := env.svc.RefreshStaleTasks(ctx)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("RefreshStaleTasks: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
got, _ := env.taskStore.GetByID(ctx, id)
|
||||||
|
if got.Status != model.TaskStatusCompleted {
|
||||||
|
t.Errorf("status = %q, want %q", got.Status, model.TaskStatusCompleted)
|
||||||
|
}
|
||||||
|
}
|
||||||
538
internal/service/task_service_test.go
Normal file
538
internal/service/task_service_test.go
Normal file
@@ -0,0 +1,538 @@
|
|||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"gcy_hpc_server/internal/model"
|
||||||
|
"gcy_hpc_server/internal/slurm"
|
||||||
|
"gcy_hpc_server/internal/store"
|
||||||
|
|
||||||
|
"go.uber.org/zap"
|
||||||
|
"gorm.io/driver/sqlite"
|
||||||
|
gormlogger "gorm.io/gorm/logger"
|
||||||
|
|
||||||
|
"gorm.io/gorm"
|
||||||
|
)
|
||||||
|
|
||||||
|
func setupTaskTestDB(t *testing.T) *gorm.DB {
|
||||||
|
t.Helper()
|
||||||
|
db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{
|
||||||
|
Logger: gormlogger.Default.LogMode(gormlogger.Silent),
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("open sqlite: %v", err)
|
||||||
|
}
|
||||||
|
if err := db.AutoMigrate(&model.Task{}, &model.Application{}, &model.File{}, &model.FileBlob{}); err != nil {
|
||||||
|
t.Fatalf("auto migrate: %v", err)
|
||||||
|
}
|
||||||
|
return db
|
||||||
|
}
|
||||||
|
|
||||||
|
type taskTestEnv struct {
|
||||||
|
taskStore *store.TaskStore
|
||||||
|
appStore *store.ApplicationStore
|
||||||
|
fileStore *store.FileStore
|
||||||
|
blobStore *store.BlobStore
|
||||||
|
svc *TaskService
|
||||||
|
srv *httptest.Server
|
||||||
|
db *gorm.DB
|
||||||
|
workDirBase string
|
||||||
|
}
|
||||||
|
|
||||||
|
func newTaskTestEnv(t *testing.T, slurmHandler http.HandlerFunc) *taskTestEnv {
|
||||||
|
t.Helper()
|
||||||
|
db := setupTaskTestDB(t)
|
||||||
|
|
||||||
|
ts := store.NewTaskStore(db)
|
||||||
|
as := store.NewApplicationStore(db)
|
||||||
|
fs := store.NewFileStore(db)
|
||||||
|
bs := store.NewBlobStore(db)
|
||||||
|
|
||||||
|
srv := httptest.NewServer(slurmHandler)
|
||||||
|
client, _ := slurm.NewClient(srv.URL, srv.Client())
|
||||||
|
jobSvc := NewJobService(client, zap.NewNop())
|
||||||
|
|
||||||
|
workDirBase := filepath.Join(t.TempDir(), "workdir")
|
||||||
|
os.MkdirAll(workDirBase, 0777)
|
||||||
|
|
||||||
|
svc := NewTaskService(ts, as, fs, bs, nil, jobSvc, workDirBase, zap.NewNop())
|
||||||
|
|
||||||
|
return &taskTestEnv{
|
||||||
|
taskStore: ts,
|
||||||
|
appStore: as,
|
||||||
|
fileStore: fs,
|
||||||
|
blobStore: bs,
|
||||||
|
svc: svc,
|
||||||
|
srv: srv,
|
||||||
|
db: db,
|
||||||
|
workDirBase: workDirBase,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *taskTestEnv) close() {
|
||||||
|
e.srv.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *taskTestEnv) createApp(t *testing.T, name, script string, params json.RawMessage) int64 {
|
||||||
|
t.Helper()
|
||||||
|
id, err := e.appStore.Create(context.Background(), &model.CreateApplicationRequest{
|
||||||
|
Name: name,
|
||||||
|
ScriptTemplate: script,
|
||||||
|
Parameters: params,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("create app: %v", err)
|
||||||
|
}
|
||||||
|
return id
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTaskService_CreateTask_Success(t *testing.T) {
|
||||||
|
env := newTaskTestEnv(t, nil)
|
||||||
|
defer env.close()
|
||||||
|
|
||||||
|
appID := env.createApp(t, "my-app", "#!/bin/bash\necho hello", json.RawMessage(`[]`))
|
||||||
|
|
||||||
|
task, err := env.svc.CreateTask(context.Background(), &model.CreateTaskRequest{
|
||||||
|
AppID: appID,
|
||||||
|
TaskName: "test-task",
|
||||||
|
Values: map[string]string{"KEY": "val"},
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("CreateTask: %v", err)
|
||||||
|
}
|
||||||
|
if task.ID == 0 {
|
||||||
|
t.Error("expected non-zero task ID")
|
||||||
|
}
|
||||||
|
if task.AppID != appID {
|
||||||
|
t.Errorf("AppID = %d, want %d", task.AppID, appID)
|
||||||
|
}
|
||||||
|
if task.AppName != "my-app" {
|
||||||
|
t.Errorf("AppName = %q, want %q", task.AppName, "my-app")
|
||||||
|
}
|
||||||
|
if task.Status != model.TaskStatusSubmitted {
|
||||||
|
t.Errorf("Status = %q, want %q", task.Status, model.TaskStatusSubmitted)
|
||||||
|
}
|
||||||
|
if task.TaskName != "test-task" {
|
||||||
|
t.Errorf("TaskName = %q, want %q", task.TaskName, "test-task")
|
||||||
|
}
|
||||||
|
|
||||||
|
var values map[string]string
|
||||||
|
if err := json.Unmarshal(task.Values, &values); err != nil {
|
||||||
|
t.Fatalf("unmarshal values: %v", err)
|
||||||
|
}
|
||||||
|
if values["KEY"] != "val" {
|
||||||
|
t.Errorf("values[KEY] = %q, want %q", values["KEY"], "val")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTaskService_CreateTask_InvalidAppID(t *testing.T) {
|
||||||
|
env := newTaskTestEnv(t, nil)
|
||||||
|
defer env.close()
|
||||||
|
|
||||||
|
_, err := env.svc.CreateTask(context.Background(), &model.CreateTaskRequest{
|
||||||
|
AppID: 999,
|
||||||
|
})
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected error for invalid app_id")
|
||||||
|
}
|
||||||
|
if !strings.Contains(err.Error(), "not found") {
|
||||||
|
t.Errorf("error should mention 'not found', got: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTaskService_CreateTask_ExceedsFileLimit(t *testing.T) {
|
||||||
|
env := newTaskTestEnv(t, nil)
|
||||||
|
defer env.close()
|
||||||
|
|
||||||
|
appID := env.createApp(t, "app", "#!/bin/bash\necho hi", nil)
|
||||||
|
|
||||||
|
fileIDs := make([]int64, 101)
|
||||||
|
for i := range fileIDs {
|
||||||
|
fileIDs[i] = int64(i + 1)
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err := env.svc.CreateTask(context.Background(), &model.CreateTaskRequest{
|
||||||
|
AppID: appID,
|
||||||
|
InputFileIDs: fileIDs,
|
||||||
|
})
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected error for exceeding file limit")
|
||||||
|
}
|
||||||
|
if !strings.Contains(err.Error(), "exceeds limit") {
|
||||||
|
t.Errorf("error should mention limit, got: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTaskService_CreateTask_DuplicateFileIDs(t *testing.T) {
|
||||||
|
env := newTaskTestEnv(t, nil)
|
||||||
|
defer env.close()
|
||||||
|
|
||||||
|
appID := env.createApp(t, "app", "#!/bin/bash\necho hi", nil)
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
for _, id := range []int64{1, 2} {
|
||||||
|
f := &model.File{
|
||||||
|
Name: "file.txt",
|
||||||
|
BlobSHA256: "abc123",
|
||||||
|
}
|
||||||
|
if err := env.fileStore.Create(ctx, f); err != nil {
|
||||||
|
t.Fatalf("create file: %v", err)
|
||||||
|
}
|
||||||
|
if f.ID != id {
|
||||||
|
t.Fatalf("expected file ID %d, got %d", id, f.ID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
task, err := env.svc.CreateTask(ctx, &model.CreateTaskRequest{
|
||||||
|
AppID: appID,
|
||||||
|
InputFileIDs: []int64{1, 1, 2, 2},
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("CreateTask: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var fileIDs []int64
|
||||||
|
if err := json.Unmarshal(task.InputFileIDs, &fileIDs); err != nil {
|
||||||
|
t.Fatalf("unmarshal file ids: %v", err)
|
||||||
|
}
|
||||||
|
if len(fileIDs) != 2 {
|
||||||
|
t.Fatalf("expected 2 deduplicated file IDs, got %d: %v", len(fileIDs), fileIDs)
|
||||||
|
}
|
||||||
|
if fileIDs[0] != 1 || fileIDs[1] != 2 {
|
||||||
|
t.Errorf("expected [1,2], got %v", fileIDs)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTaskService_CreateTask_AutoName(t *testing.T) {
|
||||||
|
env := newTaskTestEnv(t, nil)
|
||||||
|
defer env.close()
|
||||||
|
|
||||||
|
appID := env.createApp(t, "My Cool App", "#!/bin/bash\necho hi", nil)
|
||||||
|
|
||||||
|
task, err := env.svc.CreateTask(context.Background(), &model.CreateTaskRequest{
|
||||||
|
AppID: appID,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("CreateTask: %v", err)
|
||||||
|
}
|
||||||
|
if !strings.HasPrefix(task.TaskName, "My_Cool_App_") {
|
||||||
|
t.Errorf("auto-generated name should start with 'My_Cool_App_', got %q", task.TaskName)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTaskService_CreateTask_NilValues(t *testing.T) {
|
||||||
|
env := newTaskTestEnv(t, nil)
|
||||||
|
defer env.close()
|
||||||
|
|
||||||
|
appID := env.createApp(t, "app", "#!/bin/bash\necho hi", nil)
|
||||||
|
|
||||||
|
task, err := env.svc.CreateTask(context.Background(), &model.CreateTaskRequest{
|
||||||
|
AppID: appID,
|
||||||
|
Values: nil,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("CreateTask: %v", err)
|
||||||
|
}
|
||||||
|
if string(task.Values) != `{}` {
|
||||||
|
t.Errorf("Values = %q, want {}", string(task.Values))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTaskService_ProcessTask_Success(t *testing.T) {
|
||||||
|
jobID := int32(42)
|
||||||
|
env := newTaskTestEnv(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
json.NewEncoder(w).Encode(slurm.OpenapiJobSubmitResponse{
|
||||||
|
Result: &slurm.JobSubmitResponseMsg{JobID: &jobID},
|
||||||
|
})
|
||||||
|
}))
|
||||||
|
defer env.close()
|
||||||
|
|
||||||
|
appID := env.createApp(t, "test-app", "#!/bin/bash\necho $INPUT", json.RawMessage(`[{"name":"INPUT","type":"string","required":true}]`))
|
||||||
|
|
||||||
|
task, err := env.svc.CreateTask(context.Background(), &model.CreateTaskRequest{
|
||||||
|
AppID: appID,
|
||||||
|
Values: map[string]string{"INPUT": "hello"},
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("CreateTask: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
err = env.svc.ProcessTask(context.Background(), task.ID)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("ProcessTask: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
updated, _ := env.taskStore.GetByID(context.Background(), task.ID)
|
||||||
|
if updated.Status != model.TaskStatusQueued {
|
||||||
|
t.Errorf("Status = %q, want %q", updated.Status, model.TaskStatusQueued)
|
||||||
|
}
|
||||||
|
if updated.SlurmJobID == nil || *updated.SlurmJobID != 42 {
|
||||||
|
t.Errorf("SlurmJobID = %v, want 42", updated.SlurmJobID)
|
||||||
|
}
|
||||||
|
if updated.WorkDir == "" {
|
||||||
|
t.Error("WorkDir should not be empty")
|
||||||
|
}
|
||||||
|
if !strings.HasPrefix(updated.WorkDir, env.workDirBase) {
|
||||||
|
t.Errorf("WorkDir = %q, should start with %q", updated.WorkDir, env.workDirBase)
|
||||||
|
}
|
||||||
|
|
||||||
|
info, err := os.Stat(updated.WorkDir)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("stat workdir: %v", err)
|
||||||
|
}
|
||||||
|
if !info.IsDir() {
|
||||||
|
t.Error("WorkDir should be a directory")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTaskService_ProcessTask_TaskNotFound(t *testing.T) {
|
||||||
|
env := newTaskTestEnv(t, nil)
|
||||||
|
defer env.close()
|
||||||
|
|
||||||
|
err := env.svc.ProcessTask(context.Background(), 999)
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected error for non-existent task")
|
||||||
|
}
|
||||||
|
if !strings.Contains(err.Error(), "not found") {
|
||||||
|
t.Errorf("error should mention 'not found', got: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTaskService_ProcessTask_SlurmError(t *testing.T) {
|
||||||
|
env := newTaskTestEnv(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(http.StatusInternalServerError)
|
||||||
|
w.Write([]byte(`{"error":"slurm down"}`))
|
||||||
|
}))
|
||||||
|
defer env.close()
|
||||||
|
|
||||||
|
appID := env.createApp(t, "test-app", "#!/bin/bash\necho hello", nil)
|
||||||
|
|
||||||
|
task, err := env.svc.CreateTask(context.Background(), &model.CreateTaskRequest{
|
||||||
|
AppID: appID,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("CreateTask: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
err = env.svc.ProcessTask(context.Background(), task.ID)
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected error from Slurm")
|
||||||
|
}
|
||||||
|
|
||||||
|
updated, _ := env.taskStore.GetByID(context.Background(), task.ID)
|
||||||
|
if updated.Status != model.TaskStatusFailed {
|
||||||
|
t.Errorf("Status = %q, want %q", updated.Status, model.TaskStatusFailed)
|
||||||
|
}
|
||||||
|
if updated.CurrentStep != model.TaskStepSubmitting {
|
||||||
|
t.Errorf("CurrentStep = %q, want %q", updated.CurrentStep, model.TaskStepSubmitting)
|
||||||
|
}
|
||||||
|
if !strings.Contains(updated.ErrorMessage, "submit job") {
|
||||||
|
t.Errorf("ErrorMessage should mention 'submit job', got: %q", updated.ErrorMessage)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTaskService_ProcessTaskSync(t *testing.T) {
|
||||||
|
jobID := int32(42)
|
||||||
|
env := newTaskTestEnv(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
json.NewEncoder(w).Encode(slurm.OpenapiJobSubmitResponse{
|
||||||
|
Result: &slurm.JobSubmitResponseMsg{JobID: &jobID},
|
||||||
|
})
|
||||||
|
}))
|
||||||
|
defer env.close()
|
||||||
|
|
||||||
|
appID := env.createApp(t, "sync-app", "#!/bin/bash\necho hello", nil)
|
||||||
|
|
||||||
|
resp, err := env.svc.ProcessTaskSync(context.Background(), &model.CreateTaskRequest{
|
||||||
|
AppID: appID,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("ProcessTaskSync: %v", err)
|
||||||
|
}
|
||||||
|
if resp.JobID != 42 {
|
||||||
|
t.Errorf("JobID = %d, want 42", resp.JobID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTaskService_ProcessTaskSync_NoMinIO(t *testing.T) {
|
||||||
|
jobID := int32(42)
|
||||||
|
env := newTaskTestEnv(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
json.NewEncoder(w).Encode(slurm.OpenapiJobSubmitResponse{
|
||||||
|
Result: &slurm.JobSubmitResponseMsg{JobID: &jobID},
|
||||||
|
})
|
||||||
|
}))
|
||||||
|
defer env.close()
|
||||||
|
|
||||||
|
appID := env.createApp(t, "no-minio-app", "#!/bin/bash\necho hello", nil)
|
||||||
|
|
||||||
|
resp, err := env.svc.ProcessTaskSync(context.Background(), &model.CreateTaskRequest{
|
||||||
|
AppID: appID,
|
||||||
|
InputFileIDs: nil,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("ProcessTaskSync: %v", err)
|
||||||
|
}
|
||||||
|
if resp.JobID != 42 {
|
||||||
|
t.Errorf("JobID = %d, want 42", resp.JobID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTaskService_ProcessTask_NilValues(t *testing.T) {
|
||||||
|
jobID := int32(55)
|
||||||
|
env := newTaskTestEnv(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
json.NewEncoder(w).Encode(slurm.OpenapiJobSubmitResponse{
|
||||||
|
Result: &slurm.JobSubmitResponseMsg{JobID: &jobID},
|
||||||
|
})
|
||||||
|
}))
|
||||||
|
defer env.close()
|
||||||
|
|
||||||
|
appID := env.createApp(t, "nil-val-app", "#!/bin/bash\necho hello", nil)
|
||||||
|
|
||||||
|
task, err := env.svc.CreateTask(context.Background(), &model.CreateTaskRequest{
|
||||||
|
AppID: appID,
|
||||||
|
Values: nil,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("CreateTask: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
err = env.svc.ProcessTask(context.Background(), task.ID)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("ProcessTask: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
updated, _ := env.taskStore.GetByID(context.Background(), task.ID)
|
||||||
|
if updated.Status != model.TaskStatusQueued {
|
||||||
|
t.Errorf("Status = %q, want %q", updated.Status, model.TaskStatusQueued)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTaskService_ListTasks(t *testing.T) {
|
||||||
|
env := newTaskTestEnv(t, nil)
|
||||||
|
defer env.close()
|
||||||
|
|
||||||
|
appID := env.createApp(t, "list-app", "#!/bin/bash\necho hi", nil)
|
||||||
|
|
||||||
|
for i := 0; i < 3; i++ {
|
||||||
|
_, err := env.svc.CreateTask(context.Background(), &model.CreateTaskRequest{
|
||||||
|
AppID: appID,
|
||||||
|
TaskName: "task-" + string(rune('A'+i)),
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("CreateTask %d: %v", i, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
tasks, total, err := env.svc.ListTasks(context.Background(), &model.TaskListQuery{
|
||||||
|
Page: 1,
|
||||||
|
PageSize: 10,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("ListTasks: %v", err)
|
||||||
|
}
|
||||||
|
if total != 3 {
|
||||||
|
t.Errorf("total = %d, want 3", total)
|
||||||
|
}
|
||||||
|
if len(tasks) != 3 {
|
||||||
|
t.Errorf("len(tasks) = %d, want 3", len(tasks))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTaskService_ProcessTask_ValidateParams_MissingRequired(t *testing.T) {
|
||||||
|
jobID := int32(42)
|
||||||
|
env := newTaskTestEnv(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
json.NewEncoder(w).Encode(slurm.OpenapiJobSubmitResponse{
|
||||||
|
Result: &slurm.JobSubmitResponseMsg{JobID: &jobID},
|
||||||
|
})
|
||||||
|
}))
|
||||||
|
defer env.close()
|
||||||
|
|
||||||
|
// App requires INPUT param, but we submit without it
|
||||||
|
appID := env.createApp(t, "validation-app", "#!/bin/bash\necho $INPUT", json.RawMessage(`[{"name":"INPUT","type":"string","required":true}]`))
|
||||||
|
|
||||||
|
task, err := env.svc.CreateTask(context.Background(), &model.CreateTaskRequest{
|
||||||
|
AppID: appID,
|
||||||
|
Values: map[string]string{}, // missing required INPUT
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("CreateTask: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
err = env.svc.ProcessTask(context.Background(), task.ID)
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected error for missing required parameter, got nil — ValidateParams is not being called in ProcessTask pipeline")
|
||||||
|
}
|
||||||
|
errStr := err.Error()
|
||||||
|
if !strings.Contains(errStr, "validation") && !strings.Contains(errStr, "missing") && !strings.Contains(errStr, "INPUT") {
|
||||||
|
t.Errorf("error should mention 'validation', 'missing', or 'INPUT', got: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTaskService_ProcessTask_ValidateParams_InvalidInteger(t *testing.T) {
|
||||||
|
jobID := int32(42)
|
||||||
|
env := newTaskTestEnv(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
json.NewEncoder(w).Encode(slurm.OpenapiJobSubmitResponse{
|
||||||
|
Result: &slurm.JobSubmitResponseMsg{JobID: &jobID},
|
||||||
|
})
|
||||||
|
}))
|
||||||
|
defer env.close()
|
||||||
|
|
||||||
|
// App expects integer param NUM, but we submit "abc"
|
||||||
|
appID := env.createApp(t, "int-validation-app", "#!/bin/bash\necho $NUM", json.RawMessage(`[{"name":"NUM","type":"integer","required":true}]`))
|
||||||
|
|
||||||
|
task, err := env.svc.CreateTask(context.Background(), &model.CreateTaskRequest{
|
||||||
|
AppID: appID,
|
||||||
|
Values: map[string]string{"NUM": "abc"}, // invalid integer
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("CreateTask: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
err = env.svc.ProcessTask(context.Background(), task.ID)
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected error for invalid integer parameter, got nil — ValidateParams is not being called in ProcessTask pipeline")
|
||||||
|
}
|
||||||
|
errStr := err.Error()
|
||||||
|
if !strings.Contains(errStr, "integer") && !strings.Contains(errStr, "validation") && !strings.Contains(errStr, "NUM") {
|
||||||
|
t.Errorf("error should mention 'integer', 'validation', or 'NUM', got: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTaskService_ProcessTask_ValidateParams_ValidParamsSucceed(t *testing.T) {
|
||||||
|
jobID := int32(99)
|
||||||
|
env := newTaskTestEnv(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
json.NewEncoder(w).Encode(slurm.OpenapiJobSubmitResponse{
|
||||||
|
Result: &slurm.JobSubmitResponseMsg{JobID: &jobID},
|
||||||
|
})
|
||||||
|
}))
|
||||||
|
defer env.close()
|
||||||
|
|
||||||
|
appID := env.createApp(t, "valid-params-app", "#!/bin/bash\necho $INPUT", json.RawMessage(`[{"name":"INPUT","type":"string","required":true}]`))
|
||||||
|
|
||||||
|
task, err := env.svc.CreateTask(context.Background(), &model.CreateTaskRequest{
|
||||||
|
AppID: appID,
|
||||||
|
Values: map[string]string{"INPUT": "hello"},
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("CreateTask: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
err = env.svc.ProcessTask(context.Background(), task.ID)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("ProcessTask with valid params: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
updated, _ := env.taskStore.GetByID(context.Background(), task.ID)
|
||||||
|
if updated.Status != model.TaskStatusQueued {
|
||||||
|
t.Errorf("Status = %q, want %q", updated.Status, model.TaskStatusQueued)
|
||||||
|
}
|
||||||
|
if updated.SlurmJobID == nil || *updated.SlurmJobID != 99 {
|
||||||
|
t.Errorf("SlurmJobID = %v, want 99", updated.SlurmJobID)
|
||||||
|
}
|
||||||
|
}
|
||||||
443
internal/service/upload_service.go
Normal file
443
internal/service/upload_service.go
Normal file
@@ -0,0 +1,443 @@
|
|||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"crypto/sha256"
|
||||||
|
"encoding/hex"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"gcy_hpc_server/internal/config"
|
||||||
|
"gcy_hpc_server/internal/model"
|
||||||
|
"gcy_hpc_server/internal/storage"
|
||||||
|
"gcy_hpc_server/internal/store"
|
||||||
|
|
||||||
|
"go.uber.org/zap"
|
||||||
|
"gorm.io/gorm"
|
||||||
|
"gorm.io/gorm/clause"
|
||||||
|
)
|
||||||
|
|
||||||
|
type UploadService struct {
|
||||||
|
storage storage.ObjectStorage
|
||||||
|
blobStore *store.BlobStore
|
||||||
|
fileStore *store.FileStore
|
||||||
|
uploadStore *store.UploadStore
|
||||||
|
cfg config.MinioConfig
|
||||||
|
db *gorm.DB
|
||||||
|
logger *zap.Logger
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewUploadService(
|
||||||
|
st storage.ObjectStorage,
|
||||||
|
blobStore *store.BlobStore,
|
||||||
|
fileStore *store.FileStore,
|
||||||
|
uploadStore *store.UploadStore,
|
||||||
|
cfg config.MinioConfig,
|
||||||
|
db *gorm.DB,
|
||||||
|
logger *zap.Logger,
|
||||||
|
) *UploadService {
|
||||||
|
return &UploadService{
|
||||||
|
storage: st,
|
||||||
|
blobStore: blobStore,
|
||||||
|
fileStore: fileStore,
|
||||||
|
uploadStore: uploadStore,
|
||||||
|
cfg: cfg,
|
||||||
|
db: db,
|
||||||
|
logger: logger,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *UploadService) InitUpload(ctx context.Context, req model.InitUploadRequest) (interface{}, error) {
|
||||||
|
if err := model.ValidateFileName(req.FileName); err != nil {
|
||||||
|
return nil, fmt.Errorf("invalid file name: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if req.FileSize < 0 {
|
||||||
|
return nil, fmt.Errorf("file size cannot be negative")
|
||||||
|
}
|
||||||
|
|
||||||
|
if req.FileSize > s.cfg.MaxFileSize {
|
||||||
|
return nil, fmt.Errorf("file size %d exceeds maximum %d", req.FileSize, s.cfg.MaxFileSize)
|
||||||
|
}
|
||||||
|
|
||||||
|
chunkSize := s.cfg.ChunkSize
|
||||||
|
if req.ChunkSize != nil {
|
||||||
|
chunkSize = *req.ChunkSize
|
||||||
|
}
|
||||||
|
if chunkSize < s.cfg.MinChunkSize {
|
||||||
|
return nil, fmt.Errorf("chunk size %d is below minimum %d", chunkSize, s.cfg.MinChunkSize)
|
||||||
|
}
|
||||||
|
|
||||||
|
totalChunks := int((req.FileSize + chunkSize - 1) / chunkSize)
|
||||||
|
if req.FileSize == 0 {
|
||||||
|
totalChunks = 0
|
||||||
|
}
|
||||||
|
if totalChunks > 10000 {
|
||||||
|
return nil, fmt.Errorf("total chunks %d exceeds limit of 10000", totalChunks)
|
||||||
|
}
|
||||||
|
|
||||||
|
blob, err := s.blobStore.GetBySHA256(ctx, req.SHA256)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("check dedup: %w", err)
|
||||||
|
}
|
||||||
|
if blob != nil {
|
||||||
|
if err := s.blobStore.IncrementRef(ctx, req.SHA256); err != nil {
|
||||||
|
return nil, fmt.Errorf("increment ref: %w", err)
|
||||||
|
}
|
||||||
|
file := &model.File{
|
||||||
|
Name: req.FileName,
|
||||||
|
FolderID: req.FolderID,
|
||||||
|
BlobSHA256: req.SHA256,
|
||||||
|
}
|
||||||
|
if err := s.fileStore.Create(ctx, file); err != nil {
|
||||||
|
return nil, fmt.Errorf("create file: %w", err)
|
||||||
|
}
|
||||||
|
return model.FileResponse{
|
||||||
|
ID: file.ID,
|
||||||
|
Name: file.Name,
|
||||||
|
FolderID: file.FolderID,
|
||||||
|
Size: blob.FileSize,
|
||||||
|
MimeType: blob.MimeType,
|
||||||
|
SHA256: blob.SHA256,
|
||||||
|
CreatedAt: file.CreatedAt,
|
||||||
|
UpdatedAt: file.UpdatedAt,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
mimeType := req.MimeType
|
||||||
|
if mimeType == "" {
|
||||||
|
mimeType = "application/octet-stream"
|
||||||
|
}
|
||||||
|
|
||||||
|
session := &model.UploadSession{
|
||||||
|
FileName: req.FileName,
|
||||||
|
FileSize: req.FileSize,
|
||||||
|
ChunkSize: chunkSize,
|
||||||
|
TotalChunks: totalChunks,
|
||||||
|
SHA256: req.SHA256,
|
||||||
|
FolderID: req.FolderID,
|
||||||
|
Status: "pending",
|
||||||
|
MinioPrefix: fmt.Sprintf("uploads/%d/", time.Now().UnixNano()),
|
||||||
|
MimeType: mimeType,
|
||||||
|
ExpiresAt: time.Now().Add(time.Duration(s.cfg.SessionTTL) * time.Hour),
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := s.uploadStore.CreateSession(ctx, session); err != nil {
|
||||||
|
return nil, fmt.Errorf("create session: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return model.UploadSessionResponse{
|
||||||
|
ID: session.ID,
|
||||||
|
FileName: session.FileName,
|
||||||
|
FileSize: session.FileSize,
|
||||||
|
ChunkSize: session.ChunkSize,
|
||||||
|
TotalChunks: session.TotalChunks,
|
||||||
|
SHA256: session.SHA256,
|
||||||
|
Status: session.Status,
|
||||||
|
UploadedChunks: []int{},
|
||||||
|
ExpiresAt: session.ExpiresAt,
|
||||||
|
CreatedAt: session.CreatedAt,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *UploadService) UploadChunk(ctx context.Context, sessionID int64, chunkIndex int, reader io.Reader, size int64) error {
|
||||||
|
session, err := s.uploadStore.GetSession(ctx, sessionID)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("get session: %w", err)
|
||||||
|
}
|
||||||
|
if session == nil {
|
||||||
|
return fmt.Errorf("session not found")
|
||||||
|
}
|
||||||
|
|
||||||
|
switch session.Status {
|
||||||
|
case "pending", "uploading", "failed":
|
||||||
|
default:
|
||||||
|
return fmt.Errorf("cannot upload to session with status %q", session.Status)
|
||||||
|
}
|
||||||
|
|
||||||
|
if chunkIndex < 0 || chunkIndex >= session.TotalChunks {
|
||||||
|
return fmt.Errorf("chunk index %d out of range [0, %d)", chunkIndex, session.TotalChunks)
|
||||||
|
}
|
||||||
|
|
||||||
|
key := fmt.Sprintf("%schunk_%05d", session.MinioPrefix, chunkIndex)
|
||||||
|
|
||||||
|
hasher := sha256.New()
|
||||||
|
teeReader := io.TeeReader(reader, hasher)
|
||||||
|
|
||||||
|
_, err = s.storage.PutObject(ctx, s.cfg.Bucket, key, teeReader, size, storage.PutObjectOptions{
|
||||||
|
DisableMultipart: true,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("put object: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
chunkSHA256 := hex.EncodeToString(hasher.Sum(nil))
|
||||||
|
|
||||||
|
chunk := &model.UploadChunk{
|
||||||
|
SessionID: sessionID,
|
||||||
|
ChunkIndex: chunkIndex,
|
||||||
|
MinioKey: key,
|
||||||
|
SHA256: chunkSHA256,
|
||||||
|
Size: size,
|
||||||
|
Status: "uploaded",
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := s.uploadStore.UpsertChunk(ctx, chunk); err != nil {
|
||||||
|
return fmt.Errorf("upsert chunk: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if session.Status == "pending" {
|
||||||
|
if err := s.uploadStore.UpdateSessionStatus(ctx, sessionID, "uploading"); err != nil {
|
||||||
|
if !errors.Is(err, gorm.ErrRecordNotFound) {
|
||||||
|
return fmt.Errorf("update status to uploading: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *UploadService) GetUploadStatus(ctx context.Context, sessionID int64) (*model.UploadSessionResponse, error) {
|
||||||
|
session, chunks, err := s.uploadStore.GetSessionWithChunks(ctx, sessionID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("get session: %w", err)
|
||||||
|
}
|
||||||
|
if session == nil {
|
||||||
|
return nil, fmt.Errorf("session not found")
|
||||||
|
}
|
||||||
|
|
||||||
|
uploadedChunks := make([]int, 0, len(chunks))
|
||||||
|
for _, c := range chunks {
|
||||||
|
if c.Status == "uploaded" {
|
||||||
|
uploadedChunks = append(uploadedChunks, c.ChunkIndex)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return &model.UploadSessionResponse{
|
||||||
|
ID: session.ID,
|
||||||
|
FileName: session.FileName,
|
||||||
|
FileSize: session.FileSize,
|
||||||
|
ChunkSize: session.ChunkSize,
|
||||||
|
TotalChunks: session.TotalChunks,
|
||||||
|
SHA256: session.SHA256,
|
||||||
|
Status: session.Status,
|
||||||
|
UploadedChunks: uploadedChunks,
|
||||||
|
ExpiresAt: session.ExpiresAt,
|
||||||
|
CreatedAt: session.CreatedAt,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *UploadService) CompleteUpload(ctx context.Context, sessionID int64) (*model.FileResponse, error) {
|
||||||
|
var fileResp *model.FileResponse
|
||||||
|
var totalChunks int
|
||||||
|
var minioPrefix string
|
||||||
|
|
||||||
|
err := s.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
|
||||||
|
var session model.UploadSession
|
||||||
|
query := tx.WithContext(ctx)
|
||||||
|
if !isSQLite(tx) {
|
||||||
|
query = query.Clauses(clause.Locking{Strength: "UPDATE"})
|
||||||
|
}
|
||||||
|
if err := query.First(&session, sessionID).Error; err != nil {
|
||||||
|
return fmt.Errorf("get session: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
switch session.Status {
|
||||||
|
case "uploading", "failed":
|
||||||
|
default:
|
||||||
|
return fmt.Errorf("cannot complete session with status %q", session.Status)
|
||||||
|
}
|
||||||
|
|
||||||
|
if session.TotalChunks > 0 {
|
||||||
|
var chunkCount int64
|
||||||
|
if err := tx.WithContext(ctx).Model(&model.UploadChunk{}).
|
||||||
|
Where("session_id = ? AND status = ?", sessionID, "uploaded").
|
||||||
|
Count(&chunkCount).Error; err != nil {
|
||||||
|
return fmt.Errorf("count chunks: %w", err)
|
||||||
|
}
|
||||||
|
if int(chunkCount) != session.TotalChunks {
|
||||||
|
return fmt.Errorf("not all chunks uploaded: %d/%d", chunkCount, session.TotalChunks)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
totalChunks = session.TotalChunks
|
||||||
|
minioPrefix = session.MinioPrefix
|
||||||
|
|
||||||
|
if err := updateStatusTx(tx, ctx, sessionID, "merging"); err != nil {
|
||||||
|
return fmt.Errorf("update status to merging: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
blob, err := s.blobStore.GetBySHA256ForUpdate(ctx, tx, session.SHA256)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("get blob for update: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if blob != nil {
|
||||||
|
result := tx.WithContext(ctx).Model(&model.FileBlob{}).
|
||||||
|
Where("sha256 = ?", session.SHA256).
|
||||||
|
UpdateColumn("ref_count", gorm.Expr("ref_count + 1"))
|
||||||
|
if result.Error != nil {
|
||||||
|
return fmt.Errorf("increment ref: %w", result.Error)
|
||||||
|
}
|
||||||
|
if result.RowsAffected == 0 {
|
||||||
|
return fmt.Errorf("blob not found for ref increment")
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if session.TotalChunks > 0 {
|
||||||
|
chunkKeys := make([]string, session.TotalChunks)
|
||||||
|
for i := 0; i < session.TotalChunks; i++ {
|
||||||
|
chunkKeys[i] = fmt.Sprintf("%schunk_%05d", session.MinioPrefix, i)
|
||||||
|
}
|
||||||
|
|
||||||
|
dstKey := "files/" + session.SHA256
|
||||||
|
if _, err := s.storage.ComposeObject(ctx, s.cfg.Bucket, dstKey, chunkKeys); err != nil {
|
||||||
|
return errComposeFailed{err: err}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
blobRecord := &model.FileBlob{
|
||||||
|
SHA256: session.SHA256,
|
||||||
|
MinioKey: "files/" + session.SHA256,
|
||||||
|
FileSize: session.FileSize,
|
||||||
|
MimeType: session.MimeType,
|
||||||
|
RefCount: 1,
|
||||||
|
}
|
||||||
|
if session.TotalChunks == 0 {
|
||||||
|
blobRecord.MinioKey = "files/" + session.SHA256
|
||||||
|
blobRecord.FileSize = 0
|
||||||
|
}
|
||||||
|
if err := tx.WithContext(ctx).Create(blobRecord).Error; err != nil {
|
||||||
|
return fmt.Errorf("create blob: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
file := &model.File{
|
||||||
|
Name: session.FileName,
|
||||||
|
FolderID: session.FolderID,
|
||||||
|
BlobSHA256: session.SHA256,
|
||||||
|
}
|
||||||
|
if err := tx.WithContext(ctx).Create(file).Error; err != nil {
|
||||||
|
return fmt.Errorf("create file: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := updateStatusTx(tx, ctx, sessionID, "completed"); err != nil {
|
||||||
|
return fmt.Errorf("update status to completed: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
fileResp = &model.FileResponse{
|
||||||
|
ID: file.ID,
|
||||||
|
Name: file.Name,
|
||||||
|
FolderID: file.FolderID,
|
||||||
|
Size: session.FileSize,
|
||||||
|
MimeType: session.MimeType,
|
||||||
|
SHA256: session.SHA256,
|
||||||
|
CreatedAt: file.CreatedAt,
|
||||||
|
UpdatedAt: file.UpdatedAt,
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
var cfe errComposeFailed
|
||||||
|
if errors.As(err, &cfe) {
|
||||||
|
if statusErr := s.uploadStore.UpdateSessionStatus(ctx, sessionID, "failed"); statusErr != nil {
|
||||||
|
s.logger.Warn("failed to mark session as failed", zap.Int64("session_id", sessionID), zap.Error(statusErr))
|
||||||
|
}
|
||||||
|
return nil, fmt.Errorf("compose object: %w", cfe.err)
|
||||||
|
}
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if totalChunks > 0 {
|
||||||
|
keys := make([]string, totalChunks)
|
||||||
|
for i := 0; i < totalChunks; i++ {
|
||||||
|
keys[i] = fmt.Sprintf("%schunk_%05d", minioPrefix, i)
|
||||||
|
}
|
||||||
|
go func() {
|
||||||
|
bgCtx := context.Background()
|
||||||
|
if delErr := s.storage.RemoveObjects(bgCtx, s.cfg.Bucket, keys, storage.RemoveObjectsOptions{}); delErr != nil {
|
||||||
|
s.logger.Warn("delete temp chunks", zap.Error(delErr))
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
|
return fileResp, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *UploadService) CancelUpload(ctx context.Context, sessionID int64) error {
|
||||||
|
session, err := s.uploadStore.GetSession(ctx, sessionID)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("get session: %w", err)
|
||||||
|
}
|
||||||
|
if session == nil {
|
||||||
|
return fmt.Errorf("session not found")
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := s.uploadStore.UpdateSessionStatus(ctx, sessionID, "cancelled"); err != nil {
|
||||||
|
return fmt.Errorf("update status to cancelled: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if session.TotalChunks > 0 {
|
||||||
|
keys, listErr := s.listChunkKeys(ctx, sessionID)
|
||||||
|
if listErr != nil {
|
||||||
|
s.logger.Warn("list chunk keys for cancel", zap.Error(listErr))
|
||||||
|
} else if len(keys) > 0 {
|
||||||
|
if delErr := s.storage.RemoveObjects(ctx, s.cfg.Bucket, keys, storage.RemoveObjectsOptions{}); delErr != nil {
|
||||||
|
s.logger.Warn("remove chunk objects for cancel", zap.Error(delErr))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := s.uploadStore.DeleteSession(ctx, sessionID); err != nil {
|
||||||
|
return fmt.Errorf("delete session: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *UploadService) listChunkKeys(ctx context.Context, sessionID int64) ([]string, error) {
|
||||||
|
session, err := s.uploadStore.GetSession(ctx, sessionID)
|
||||||
|
if err != nil || session == nil {
|
||||||
|
return nil, fmt.Errorf("get session: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
keys := make([]string, session.TotalChunks)
|
||||||
|
for i := 0; i < session.TotalChunks; i++ {
|
||||||
|
keys[i] = fmt.Sprintf("%schunk_%05d", session.MinioPrefix, i)
|
||||||
|
}
|
||||||
|
return keys, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func updateStatusTx(tx *gorm.DB, ctx context.Context, id int64, status string) error {
|
||||||
|
result := tx.WithContext(ctx).Model(&model.UploadSession{}).Where("id = ?", id).Update("status", status)
|
||||||
|
if result.Error != nil {
|
||||||
|
return result.Error
|
||||||
|
}
|
||||||
|
if result.RowsAffected == 0 {
|
||||||
|
return gorm.ErrRecordNotFound
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func isSQLite(db *gorm.DB) bool {
|
||||||
|
return db.Dialector.Name() == "sqlite"
|
||||||
|
}
|
||||||
|
|
||||||
|
func objectKeys(objects []storage.ObjectInfo) []string {
|
||||||
|
keys := make([]string, len(objects))
|
||||||
|
for i, o := range objects {
|
||||||
|
keys[i] = o.Key
|
||||||
|
}
|
||||||
|
return keys
|
||||||
|
}
|
||||||
|
|
||||||
|
type errComposeFailed struct {
|
||||||
|
err error
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e errComposeFailed) Error() string {
|
||||||
|
return fmt.Sprintf("compose failed: %v", e.err)
|
||||||
|
}
|
||||||
678
internal/service/upload_service_test.go
Normal file
678
internal/service/upload_service_test.go
Normal file
@@ -0,0 +1,678 @@
|
|||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"crypto/sha256"
|
||||||
|
"encoding/hex"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"gcy_hpc_server/internal/config"
|
||||||
|
"gcy_hpc_server/internal/model"
|
||||||
|
"gcy_hpc_server/internal/storage"
|
||||||
|
"gcy_hpc_server/internal/store"
|
||||||
|
|
||||||
|
"go.uber.org/zap"
|
||||||
|
"gorm.io/driver/sqlite"
|
||||||
|
"gorm.io/gorm"
|
||||||
|
)
|
||||||
|
|
||||||
|
type uploadMockStorage struct {
|
||||||
|
putObjectFn func(ctx context.Context, bucket, key string, reader io.Reader, size int64, opts storage.PutObjectOptions) (storage.UploadInfo, error)
|
||||||
|
composeObjectFn func(ctx context.Context, bucket, dst string, sources []string) (storage.UploadInfo, error)
|
||||||
|
listObjectsFn func(ctx context.Context, bucket, prefix string, recursive bool) ([]storage.ObjectInfo, error)
|
||||||
|
removeObjectsFn func(ctx context.Context, bucket string, keys []string, opts storage.RemoveObjectsOptions) error
|
||||||
|
removeObjectFn func(ctx context.Context, bucket, key string, opts storage.RemoveObjectOptions) error
|
||||||
|
getObjectFn func(ctx context.Context, bucket, key string, opts storage.GetOptions) (io.ReadCloser, storage.ObjectInfo, error)
|
||||||
|
bucketExistsFn func(ctx context.Context, bucket string) (bool, error)
|
||||||
|
makeBucketFn func(ctx context.Context, bucket string, opts storage.MakeBucketOptions) error
|
||||||
|
statObjectFn func(ctx context.Context, bucket, key string, opts storage.StatObjectOptions) (storage.ObjectInfo, error)
|
||||||
|
abortMultipartFn func(ctx context.Context, bucket, object, uploadID string) error
|
||||||
|
removeIncompleteFn func(ctx context.Context, bucket, object string) error
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *uploadMockStorage) PutObject(ctx context.Context, bucket, key string, reader io.Reader, size int64, opts storage.PutObjectOptions) (storage.UploadInfo, error) {
|
||||||
|
if m.putObjectFn != nil {
|
||||||
|
return m.putObjectFn(ctx, bucket, key, reader, size, opts)
|
||||||
|
}
|
||||||
|
io.Copy(io.Discard, reader)
|
||||||
|
return storage.UploadInfo{ETag: "etag", Size: size}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *uploadMockStorage) GetObject(ctx context.Context, bucket, key string, opts storage.GetOptions) (io.ReadCloser, storage.ObjectInfo, error) {
|
||||||
|
if m.getObjectFn != nil {
|
||||||
|
return m.getObjectFn(ctx, bucket, key, opts)
|
||||||
|
}
|
||||||
|
return nil, storage.ObjectInfo{}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *uploadMockStorage) ComposeObject(ctx context.Context, bucket, dst string, sources []string) (storage.UploadInfo, error) {
|
||||||
|
if m.composeObjectFn != nil {
|
||||||
|
return m.composeObjectFn(ctx, bucket, dst, sources)
|
||||||
|
}
|
||||||
|
return storage.UploadInfo{ETag: "composed", Size: 0}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *uploadMockStorage) AbortMultipartUpload(ctx context.Context, bucket, object, uploadID string) error {
|
||||||
|
if m.abortMultipartFn != nil {
|
||||||
|
return m.abortMultipartFn(ctx, bucket, object, uploadID)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *uploadMockStorage) RemoveIncompleteUpload(ctx context.Context, bucket, object string) error {
|
||||||
|
if m.removeIncompleteFn != nil {
|
||||||
|
return m.removeIncompleteFn(ctx, bucket, object)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *uploadMockStorage) RemoveObject(ctx context.Context, bucket, key string, opts storage.RemoveObjectOptions) error {
|
||||||
|
if m.removeObjectFn != nil {
|
||||||
|
return m.removeObjectFn(ctx, bucket, key, opts)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *uploadMockStorage) ListObjects(ctx context.Context, bucket, prefix string, recursive bool) ([]storage.ObjectInfo, error) {
|
||||||
|
if m.listObjectsFn != nil {
|
||||||
|
return m.listObjectsFn(ctx, bucket, prefix, recursive)
|
||||||
|
}
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *uploadMockStorage) RemoveObjects(ctx context.Context, bucket string, keys []string, opts storage.RemoveObjectsOptions) error {
|
||||||
|
if m.removeObjectsFn != nil {
|
||||||
|
return m.removeObjectsFn(ctx, bucket, keys, opts)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *uploadMockStorage) BucketExists(ctx context.Context, bucket string) (bool, error) {
|
||||||
|
if m.bucketExistsFn != nil {
|
||||||
|
return m.bucketExistsFn(ctx, bucket)
|
||||||
|
}
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *uploadMockStorage) MakeBucket(ctx context.Context, bucket string, opts storage.MakeBucketOptions) error {
|
||||||
|
if m.makeBucketFn != nil {
|
||||||
|
return m.makeBucketFn(ctx, bucket, opts)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *uploadMockStorage) StatObject(ctx context.Context, bucket, key string, opts storage.StatObjectOptions) (storage.ObjectInfo, error) {
|
||||||
|
if m.statObjectFn != nil {
|
||||||
|
return m.statObjectFn(ctx, bucket, key, opts)
|
||||||
|
}
|
||||||
|
return storage.ObjectInfo{}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func setupUploadTestDB(t *testing.T) *gorm.DB {
|
||||||
|
t.Helper()
|
||||||
|
db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("open sqlite: %v", err)
|
||||||
|
}
|
||||||
|
if err := db.AutoMigrate(&model.FileBlob{}, &model.File{}, &model.UploadSession{}, &model.UploadChunk{}); err != nil {
|
||||||
|
t.Fatalf("migrate: %v", err)
|
||||||
|
}
|
||||||
|
return db
|
||||||
|
}
|
||||||
|
|
||||||
|
func uploadTestConfig() config.MinioConfig {
|
||||||
|
return config.MinioConfig{
|
||||||
|
Bucket: "test-bucket",
|
||||||
|
ChunkSize: 16 << 20,
|
||||||
|
MaxFileSize: 50 << 30,
|
||||||
|
MinChunkSize: 5 << 20,
|
||||||
|
SessionTTL: 48,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func newUploadTestService(t *testing.T, st storage.ObjectStorage, db *gorm.DB) *UploadService {
|
||||||
|
t.Helper()
|
||||||
|
return NewUploadService(
|
||||||
|
st,
|
||||||
|
store.NewBlobStore(db),
|
||||||
|
store.NewFileStore(db),
|
||||||
|
store.NewUploadStore(db),
|
||||||
|
uploadTestConfig(),
|
||||||
|
db,
|
||||||
|
zap.NewNop(),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
func sha256Of(data []byte) string {
|
||||||
|
h := sha256.Sum256(data)
|
||||||
|
return hex.EncodeToString(h[:])
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestInitUpload_CreatesSession(t *testing.T) {
|
||||||
|
db := setupUploadTestDB(t)
|
||||||
|
st := &uploadMockStorage{}
|
||||||
|
svc := newUploadTestService(t, st, db)
|
||||||
|
|
||||||
|
resp, err := svc.InitUpload(context.Background(), model.InitUploadRequest{
|
||||||
|
FileName: "test.txt",
|
||||||
|
FileSize: 32 << 20,
|
||||||
|
SHA256: "abc123",
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("InitUpload: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
sessResp, ok := resp.(model.UploadSessionResponse)
|
||||||
|
if !ok {
|
||||||
|
t.Fatalf("expected UploadSessionResponse, got %T", resp)
|
||||||
|
}
|
||||||
|
if sessResp.Status != "pending" {
|
||||||
|
t.Errorf("status = %q, want pending", sessResp.Status)
|
||||||
|
}
|
||||||
|
if sessResp.TotalChunks != 2 {
|
||||||
|
t.Errorf("TotalChunks = %d, want 2", sessResp.TotalChunks)
|
||||||
|
}
|
||||||
|
if sessResp.FileName != "test.txt" {
|
||||||
|
t.Errorf("FileName = %q, want test.txt", sessResp.FileName)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestInitUpload_DedupBlobExists(t *testing.T) {
|
||||||
|
db := setupUploadTestDB(t)
|
||||||
|
st := &uploadMockStorage{}
|
||||||
|
svc := newUploadTestService(t, st, db)
|
||||||
|
|
||||||
|
blobSHA := sha256Of([]byte("hello"))
|
||||||
|
blob := &model.FileBlob{
|
||||||
|
SHA256: blobSHA,
|
||||||
|
MinioKey: "files/" + blobSHA,
|
||||||
|
FileSize: 5,
|
||||||
|
MimeType: "text/plain",
|
||||||
|
RefCount: 1,
|
||||||
|
}
|
||||||
|
if err := db.Create(blob).Error; err != nil {
|
||||||
|
t.Fatalf("create blob: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := svc.InitUpload(context.Background(), model.InitUploadRequest{
|
||||||
|
FileName: "mydoc.txt",
|
||||||
|
FileSize: 5,
|
||||||
|
SHA256: blobSHA,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("InitUpload: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
fileResp, ok := resp.(model.FileResponse)
|
||||||
|
if !ok {
|
||||||
|
t.Fatalf("expected FileResponse, got %T", resp)
|
||||||
|
}
|
||||||
|
if fileResp.Name != "mydoc.txt" {
|
||||||
|
t.Errorf("Name = %q, want mydoc.txt", fileResp.Name)
|
||||||
|
}
|
||||||
|
if fileResp.SHA256 != blobSHA {
|
||||||
|
t.Errorf("SHA256 mismatch")
|
||||||
|
}
|
||||||
|
|
||||||
|
var after model.FileBlob
|
||||||
|
db.First(&after, blob.ID)
|
||||||
|
if after.RefCount != 2 {
|
||||||
|
t.Errorf("RefCount = %d, want 2", after.RefCount)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestInitUpload_ChunkSizeTooSmall(t *testing.T) {
|
||||||
|
db := setupUploadTestDB(t)
|
||||||
|
st := &uploadMockStorage{}
|
||||||
|
svc := newUploadTestService(t, st, db)
|
||||||
|
|
||||||
|
tiny := int64(1024)
|
||||||
|
_, err := svc.InitUpload(context.Background(), model.InitUploadRequest{
|
||||||
|
FileName: "test.txt",
|
||||||
|
FileSize: 10 << 20,
|
||||||
|
SHA256: "abc",
|
||||||
|
ChunkSize: &tiny,
|
||||||
|
})
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected error for chunk size too small")
|
||||||
|
}
|
||||||
|
if !strings.Contains(err.Error(), "below minimum") {
|
||||||
|
t.Errorf("error = %q, want 'below minimum'", err.Error())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestInitUpload_TooManyChunks(t *testing.T) {
|
||||||
|
db := setupUploadTestDB(t)
|
||||||
|
st := &uploadMockStorage{}
|
||||||
|
|
||||||
|
cfg := uploadTestConfig()
|
||||||
|
cfg.ChunkSize = 1
|
||||||
|
cfg.MinChunkSize = 1
|
||||||
|
svc2 := NewUploadService(st, store.NewBlobStore(db), store.NewFileStore(db), store.NewUploadStore(db), cfg, db, zap.NewNop())
|
||||||
|
|
||||||
|
_, err := svc2.InitUpload(context.Background(), model.InitUploadRequest{
|
||||||
|
FileName: "big.txt",
|
||||||
|
FileSize: 10002,
|
||||||
|
SHA256: "abc",
|
||||||
|
})
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected error for too many chunks")
|
||||||
|
}
|
||||||
|
if !strings.Contains(err.Error(), "exceeds limit") {
|
||||||
|
t.Errorf("error = %q, want 'exceeds limit'", err.Error())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestInitUpload_DangerousFilename(t *testing.T) {
|
||||||
|
db := setupUploadTestDB(t)
|
||||||
|
st := &uploadMockStorage{}
|
||||||
|
svc := newUploadTestService(t, st, db)
|
||||||
|
|
||||||
|
for _, name := range []string{"", "..", "foo/bar", "foo\\bar", " name"} {
|
||||||
|
_, err := svc.InitUpload(context.Background(), model.InitUploadRequest{
|
||||||
|
FileName: name,
|
||||||
|
FileSize: 100,
|
||||||
|
SHA256: "abc",
|
||||||
|
})
|
||||||
|
if err == nil {
|
||||||
|
t.Errorf("expected error for filename %q", name)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUploadChunk_UploadsAndStoresSHA256(t *testing.T) {
|
||||||
|
db := setupUploadTestDB(t)
|
||||||
|
st := &uploadMockStorage{}
|
||||||
|
svc := newUploadTestService(t, st, db)
|
||||||
|
|
||||||
|
resp, _ := svc.InitUpload(context.Background(), model.InitUploadRequest{
|
||||||
|
FileName: "test.bin",
|
||||||
|
FileSize: 10 << 20,
|
||||||
|
SHA256: "deadbeef",
|
||||||
|
})
|
||||||
|
sessResp := resp.(model.UploadSessionResponse)
|
||||||
|
|
||||||
|
data := []byte("chunk data here")
|
||||||
|
chunkSHA := sha256Of(data)
|
||||||
|
|
||||||
|
err := svc.UploadChunk(context.Background(), sessResp.ID, 0, bytes.NewReader(data), int64(len(data)))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("UploadChunk: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var chunk model.UploadChunk
|
||||||
|
db.Where("session_id = ? AND chunk_index = ?", sessResp.ID, 0).First(&chunk)
|
||||||
|
if chunk.SHA256 != chunkSHA {
|
||||||
|
t.Errorf("chunk SHA256 = %q, want %q", chunk.SHA256, chunkSHA)
|
||||||
|
}
|
||||||
|
if chunk.Status != "uploaded" {
|
||||||
|
t.Errorf("chunk status = %q, want uploaded", chunk.Status)
|
||||||
|
}
|
||||||
|
|
||||||
|
session, _ := svc.uploadStore.GetSession(context.Background(), sessResp.ID)
|
||||||
|
if session.Status != "uploading" {
|
||||||
|
t.Errorf("session status = %q, want uploading", session.Status)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUploadChunk_RejectsCompletedSession(t *testing.T) {
|
||||||
|
db := setupUploadTestDB(t)
|
||||||
|
st := &uploadMockStorage{}
|
||||||
|
svc := newUploadTestService(t, st, db)
|
||||||
|
|
||||||
|
resp, _ := svc.InitUpload(context.Background(), model.InitUploadRequest{
|
||||||
|
FileName: "test.bin",
|
||||||
|
FileSize: 10 << 20,
|
||||||
|
SHA256: "deadbeef",
|
||||||
|
})
|
||||||
|
sessResp := resp.(model.UploadSessionResponse)
|
||||||
|
|
||||||
|
svc.uploadStore.UpdateSessionStatus(context.Background(), sessResp.ID, "completed")
|
||||||
|
|
||||||
|
err := svc.UploadChunk(context.Background(), sessResp.ID, 0, bytes.NewReader([]byte("x")), 1)
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected error for completed session")
|
||||||
|
}
|
||||||
|
if !strings.Contains(err.Error(), "cannot upload") {
|
||||||
|
t.Errorf("error = %q", err.Error())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUploadChunk_RejectsExpiredSession(t *testing.T) {
|
||||||
|
db := setupUploadTestDB(t)
|
||||||
|
st := &uploadMockStorage{}
|
||||||
|
svc := newUploadTestService(t, st, db)
|
||||||
|
|
||||||
|
resp, _ := svc.InitUpload(context.Background(), model.InitUploadRequest{
|
||||||
|
FileName: "test.bin",
|
||||||
|
FileSize: 10 << 20,
|
||||||
|
SHA256: "deadbeef",
|
||||||
|
})
|
||||||
|
sessResp := resp.(model.UploadSessionResponse)
|
||||||
|
|
||||||
|
svc.uploadStore.UpdateSessionStatus(context.Background(), sessResp.ID, "expired")
|
||||||
|
|
||||||
|
err := svc.UploadChunk(context.Background(), sessResp.ID, 0, bytes.NewReader([]byte("x")), 1)
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected error for expired session")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetUploadStatus_ReturnsChunkIndices(t *testing.T) {
|
||||||
|
db := setupUploadTestDB(t)
|
||||||
|
st := &uploadMockStorage{}
|
||||||
|
svc := newUploadTestService(t, st, db)
|
||||||
|
|
||||||
|
resp, _ := svc.InitUpload(context.Background(), model.InitUploadRequest{
|
||||||
|
FileName: "test.bin",
|
||||||
|
FileSize: 32 << 20,
|
||||||
|
SHA256: "abc",
|
||||||
|
})
|
||||||
|
sessResp := resp.(model.UploadSessionResponse)
|
||||||
|
|
||||||
|
data := []byte("chunk0data")
|
||||||
|
svc.UploadChunk(context.Background(), sessResp.ID, 0, bytes.NewReader(data), int64(len(data)))
|
||||||
|
|
||||||
|
status, err := svc.GetUploadStatus(context.Background(), sessResp.ID)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("GetUploadStatus: %v", err)
|
||||||
|
}
|
||||||
|
if len(status.UploadedChunks) != 1 || status.UploadedChunks[0] != 0 {
|
||||||
|
t.Errorf("UploadedChunks = %v, want [0]", status.UploadedChunks)
|
||||||
|
}
|
||||||
|
if status.TotalChunks != 2 {
|
||||||
|
t.Errorf("TotalChunks = %d, want 2", status.TotalChunks)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCompleteUpload_CreatesBlobAndFile(t *testing.T) {
|
||||||
|
db := setupUploadTestDB(t)
|
||||||
|
cfg := uploadTestConfig()
|
||||||
|
cfg.ChunkSize = 5 << 20
|
||||||
|
cfg.MinChunkSize = 1
|
||||||
|
st := &uploadMockStorage{}
|
||||||
|
svc := NewUploadService(st, store.NewBlobStore(db), store.NewFileStore(db), store.NewUploadStore(db), cfg, db, zap.NewNop())
|
||||||
|
|
||||||
|
fileSHA := "aaa111"
|
||||||
|
|
||||||
|
sess := &model.UploadSession{
|
||||||
|
FileName: "test.bin",
|
||||||
|
FileSize: 10 << 20,
|
||||||
|
ChunkSize: 5 << 20,
|
||||||
|
TotalChunks: 2,
|
||||||
|
SHA256: fileSHA,
|
||||||
|
Status: "uploading",
|
||||||
|
MinioPrefix: "uploads/testcomplete/",
|
||||||
|
MimeType: "application/octet-stream",
|
||||||
|
ExpiresAt: time.Now().Add(48 * time.Hour),
|
||||||
|
}
|
||||||
|
db.Create(sess)
|
||||||
|
|
||||||
|
db.Create(&model.UploadChunk{SessionID: sess.ID, ChunkIndex: 0, MinioKey: "uploads/testcomplete/chunk_00000", SHA256: "c0", Size: 5 << 20, Status: "uploaded"})
|
||||||
|
db.Create(&model.UploadChunk{SessionID: sess.ID, ChunkIndex: 1, MinioKey: "uploads/testcomplete/chunk_00001", SHA256: "c1", Size: 5 << 20, Status: "uploaded"})
|
||||||
|
|
||||||
|
fileResp, err := svc.CompleteUpload(context.Background(), sess.ID)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("CompleteUpload: %v", err)
|
||||||
|
}
|
||||||
|
if fileResp.Name != "test.bin" {
|
||||||
|
t.Errorf("Name = %q, want test.bin", fileResp.Name)
|
||||||
|
}
|
||||||
|
if fileResp.SHA256 != fileSHA {
|
||||||
|
t.Errorf("SHA256 = %q, want %q", fileResp.SHA256, fileSHA)
|
||||||
|
}
|
||||||
|
|
||||||
|
var blob model.FileBlob
|
||||||
|
db.Where("sha256 = ?", fileSHA).First(&blob)
|
||||||
|
if blob.RefCount != 1 {
|
||||||
|
t.Errorf("blob RefCount = %d, want 1", blob.RefCount)
|
||||||
|
}
|
||||||
|
|
||||||
|
session, _ := svc.uploadStore.GetSession(context.Background(), sess.ID)
|
||||||
|
if session == nil {
|
||||||
|
t.Fatal("session should exist")
|
||||||
|
}
|
||||||
|
if session.Status != "completed" {
|
||||||
|
t.Errorf("session status = %q, want completed", session.Status)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCompleteUpload_ReusesExistingBlob(t *testing.T) {
|
||||||
|
db := setupUploadTestDB(t)
|
||||||
|
cfg := uploadTestConfig()
|
||||||
|
cfg.ChunkSize = 5 << 20
|
||||||
|
cfg.MinChunkSize = 1
|
||||||
|
st := &uploadMockStorage{}
|
||||||
|
|
||||||
|
composeCalled := false
|
||||||
|
st.composeObjectFn = func(ctx context.Context, bucket, dst string, sources []string) (storage.UploadInfo, error) {
|
||||||
|
composeCalled = true
|
||||||
|
return storage.UploadInfo{}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
svc := NewUploadService(st, store.NewBlobStore(db), store.NewFileStore(db), store.NewUploadStore(db), cfg, db, zap.NewNop())
|
||||||
|
|
||||||
|
fileSHA := "reuse123"
|
||||||
|
db.Create(&model.FileBlob{
|
||||||
|
SHA256: fileSHA,
|
||||||
|
MinioKey: "files/" + fileSHA,
|
||||||
|
FileSize: 10 << 20,
|
||||||
|
MimeType: "application/octet-stream",
|
||||||
|
RefCount: 1,
|
||||||
|
})
|
||||||
|
|
||||||
|
sess := &model.UploadSession{
|
||||||
|
FileName: "reuse.bin",
|
||||||
|
FileSize: 10 << 20,
|
||||||
|
ChunkSize: 5 << 20,
|
||||||
|
TotalChunks: 2,
|
||||||
|
SHA256: fileSHA,
|
||||||
|
Status: "uploading",
|
||||||
|
MinioPrefix: "uploads/reuse/",
|
||||||
|
MimeType: "application/octet-stream",
|
||||||
|
ExpiresAt: time.Now().Add(48 * time.Hour),
|
||||||
|
}
|
||||||
|
db.Create(sess)
|
||||||
|
db.Create(&model.UploadChunk{SessionID: sess.ID, ChunkIndex: 0, MinioKey: "uploads/reuse/chunk_00000", SHA256: "c0", Size: 5 << 20, Status: "uploaded"})
|
||||||
|
db.Create(&model.UploadChunk{SessionID: sess.ID, ChunkIndex: 1, MinioKey: "uploads/reuse/chunk_00001", SHA256: "c1", Size: 5 << 20, Status: "uploaded"})
|
||||||
|
|
||||||
|
fileResp, err := svc.CompleteUpload(context.Background(), sess.ID)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("CompleteUpload: %v", err)
|
||||||
|
}
|
||||||
|
if fileResp.SHA256 != fileSHA {
|
||||||
|
t.Errorf("SHA256 mismatch")
|
||||||
|
}
|
||||||
|
if composeCalled {
|
||||||
|
t.Error("ComposeObject should not be called when blob exists")
|
||||||
|
}
|
||||||
|
|
||||||
|
var blob model.FileBlob
|
||||||
|
db.Where("sha256 = ?", fileSHA).First(&blob)
|
||||||
|
if blob.RefCount != 2 {
|
||||||
|
t.Errorf("RefCount = %d, want 2", blob.RefCount)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCompleteUpload_NotAllChunks(t *testing.T) {
|
||||||
|
db := setupUploadTestDB(t)
|
||||||
|
cfg := uploadTestConfig()
|
||||||
|
cfg.ChunkSize = 5 << 20
|
||||||
|
cfg.MinChunkSize = 1
|
||||||
|
st := &uploadMockStorage{}
|
||||||
|
svc := NewUploadService(st, store.NewBlobStore(db), store.NewFileStore(db), store.NewUploadStore(db), cfg, db, zap.NewNop())
|
||||||
|
|
||||||
|
sess := &model.UploadSession{
|
||||||
|
FileName: "partial.bin",
|
||||||
|
FileSize: 10 << 20,
|
||||||
|
ChunkSize: 5 << 20,
|
||||||
|
TotalChunks: 2,
|
||||||
|
SHA256: "partial123",
|
||||||
|
Status: "uploading",
|
||||||
|
MinioPrefix: "uploads/partial/",
|
||||||
|
MimeType: "application/octet-stream",
|
||||||
|
ExpiresAt: time.Now().Add(48 * time.Hour),
|
||||||
|
}
|
||||||
|
db.Create(sess)
|
||||||
|
db.Create(&model.UploadChunk{SessionID: sess.ID, ChunkIndex: 0, MinioKey: "uploads/partial/chunk_00000", SHA256: "c0", Size: 5 << 20, Status: "uploaded"})
|
||||||
|
|
||||||
|
_, err := svc.CompleteUpload(context.Background(), sess.ID)
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected error for incomplete chunks")
|
||||||
|
}
|
||||||
|
if !strings.Contains(err.Error(), "not all chunks uploaded") {
|
||||||
|
t.Errorf("error = %q", err.Error())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCompleteUpload_ComposeObjectFails(t *testing.T) {
|
||||||
|
db := setupUploadTestDB(t)
|
||||||
|
cfg := uploadTestConfig()
|
||||||
|
cfg.ChunkSize = 5 << 20
|
||||||
|
cfg.MinChunkSize = 1
|
||||||
|
st := &uploadMockStorage{}
|
||||||
|
|
||||||
|
st.composeObjectFn = func(ctx context.Context, bucket, dst string, sources []string) (storage.UploadInfo, error) {
|
||||||
|
return storage.UploadInfo{}, fmt.Errorf("compose failed")
|
||||||
|
}
|
||||||
|
|
||||||
|
svc := NewUploadService(st, store.NewBlobStore(db), store.NewFileStore(db), store.NewUploadStore(db), cfg, db, zap.NewNop())
|
||||||
|
|
||||||
|
fileSHA := "fail123"
|
||||||
|
sess := &model.UploadSession{
|
||||||
|
FileName: "fail.bin",
|
||||||
|
FileSize: 10 << 20,
|
||||||
|
ChunkSize: 5 << 20,
|
||||||
|
TotalChunks: 2,
|
||||||
|
SHA256: fileSHA,
|
||||||
|
Status: "uploading",
|
||||||
|
MinioPrefix: "uploads/fail/",
|
||||||
|
MimeType: "application/octet-stream",
|
||||||
|
ExpiresAt: time.Now().Add(48 * time.Hour),
|
||||||
|
}
|
||||||
|
db.Create(sess)
|
||||||
|
db.Create(&model.UploadChunk{SessionID: sess.ID, ChunkIndex: 0, MinioKey: "uploads/fail/chunk_00000", SHA256: "c0", Size: 5 << 20, Status: "uploaded"})
|
||||||
|
db.Create(&model.UploadChunk{SessionID: sess.ID, ChunkIndex: 1, MinioKey: "uploads/fail/chunk_00001", SHA256: "c1", Size: 5 << 20, Status: "uploaded"})
|
||||||
|
|
||||||
|
_, err := svc.CompleteUpload(context.Background(), sess.ID)
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected error for compose failure")
|
||||||
|
}
|
||||||
|
|
||||||
|
session, _ := svc.uploadStore.GetSession(context.Background(), sess.ID)
|
||||||
|
if session.Status != "failed" {
|
||||||
|
t.Errorf("session status = %q, want failed", session.Status)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCompleteUpload_RetriesFailedSession(t *testing.T) {
|
||||||
|
db := setupUploadTestDB(t)
|
||||||
|
cfg := uploadTestConfig()
|
||||||
|
cfg.ChunkSize = 5 << 20
|
||||||
|
cfg.MinChunkSize = 1
|
||||||
|
st := &uploadMockStorage{}
|
||||||
|
svc := NewUploadService(st, store.NewBlobStore(db), store.NewFileStore(db), store.NewUploadStore(db), cfg, db, zap.NewNop())
|
||||||
|
|
||||||
|
fileSHA := "retry123"
|
||||||
|
sess := &model.UploadSession{
|
||||||
|
FileName: "retry.bin",
|
||||||
|
FileSize: 10 << 20,
|
||||||
|
ChunkSize: 5 << 20,
|
||||||
|
TotalChunks: 2,
|
||||||
|
SHA256: fileSHA,
|
||||||
|
Status: "failed",
|
||||||
|
MinioPrefix: "uploads/retry/",
|
||||||
|
MimeType: "application/octet-stream",
|
||||||
|
ExpiresAt: time.Now().Add(48 * time.Hour),
|
||||||
|
}
|
||||||
|
db.Create(sess)
|
||||||
|
db.Create(&model.UploadChunk{SessionID: sess.ID, ChunkIndex: 0, MinioKey: "uploads/retry/chunk_00000", SHA256: "c0", Size: 5 << 20, Status: "uploaded"})
|
||||||
|
db.Create(&model.UploadChunk{SessionID: sess.ID, ChunkIndex: 1, MinioKey: "uploads/retry/chunk_00001", SHA256: "c1", Size: 5 << 20, Status: "uploaded"})
|
||||||
|
|
||||||
|
fileResp, err := svc.CompleteUpload(context.Background(), sess.ID)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("CompleteUpload on retry: %v", err)
|
||||||
|
}
|
||||||
|
if fileResp.SHA256 != fileSHA {
|
||||||
|
t.Errorf("SHA256 mismatch")
|
||||||
|
}
|
||||||
|
|
||||||
|
session, _ := svc.uploadStore.GetSession(context.Background(), sess.ID)
|
||||||
|
if session.Status != "completed" {
|
||||||
|
t.Errorf("session status = %q, want completed", session.Status)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCancelUpload_CleansUp(t *testing.T) {
|
||||||
|
db := setupUploadTestDB(t)
|
||||||
|
st := &uploadMockStorage{}
|
||||||
|
svc := newUploadTestService(t, st, db)
|
||||||
|
|
||||||
|
resp, _ := svc.InitUpload(context.Background(), model.InitUploadRequest{
|
||||||
|
FileName: "cancel.bin",
|
||||||
|
FileSize: 10 << 20,
|
||||||
|
SHA256: "cancel123",
|
||||||
|
})
|
||||||
|
sessResp := resp.(model.UploadSessionResponse)
|
||||||
|
|
||||||
|
err := svc.CancelUpload(context.Background(), sessResp.ID)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("CancelUpload: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
session, _ := svc.uploadStore.GetSession(context.Background(), sessResp.ID)
|
||||||
|
if session != nil {
|
||||||
|
t.Error("session should be deleted")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestZeroByteFile_CompletesImmediately(t *testing.T) {
|
||||||
|
db := setupUploadTestDB(t)
|
||||||
|
st := &uploadMockStorage{}
|
||||||
|
svc := newUploadTestService(t, st, db)
|
||||||
|
|
||||||
|
fileSHA := "empty123"
|
||||||
|
sess := &model.UploadSession{
|
||||||
|
FileName: "empty.bin",
|
||||||
|
FileSize: 0,
|
||||||
|
ChunkSize: 16 << 20,
|
||||||
|
TotalChunks: 0,
|
||||||
|
SHA256: fileSHA,
|
||||||
|
Status: "uploading",
|
||||||
|
MinioPrefix: "uploads/empty/",
|
||||||
|
MimeType: "application/octet-stream",
|
||||||
|
ExpiresAt: time.Now().Add(48 * time.Hour),
|
||||||
|
}
|
||||||
|
db.Create(sess)
|
||||||
|
|
||||||
|
fileResp, err := svc.CompleteUpload(context.Background(), sess.ID)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("CompleteUpload zero byte: %v", err)
|
||||||
|
}
|
||||||
|
if fileResp.SHA256 != fileSHA {
|
||||||
|
t.Errorf("SHA256 mismatch")
|
||||||
|
}
|
||||||
|
if fileResp.Size != 0 {
|
||||||
|
t.Errorf("Size = %d, want 0", fileResp.Size)
|
||||||
|
}
|
||||||
|
|
||||||
|
var blob model.FileBlob
|
||||||
|
db.Where("sha256 = ?", fileSHA).First(&blob)
|
||||||
|
if blob.RefCount != 1 {
|
||||||
|
t.Errorf("RefCount = %d, want 1", blob.RefCount)
|
||||||
|
}
|
||||||
|
|
||||||
|
uploadStore := store.NewUploadStore(db)
|
||||||
|
session, _ := uploadStore.GetSession(context.Background(), sess.ID)
|
||||||
|
if session == nil {
|
||||||
|
t.Fatal("session should exist")
|
||||||
|
}
|
||||||
|
if session.Status != "completed" {
|
||||||
|
t.Errorf("session status = %q, want completed", session.Status)
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -9,6 +9,7 @@ import (
|
|||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
"strings"
|
"strings"
|
||||||
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@@ -16,6 +17,8 @@ const (
|
|||||||
DefaultBaseURL = "http://localhost:6820/"
|
DefaultBaseURL = "http://localhost:6820/"
|
||||||
// DefaultUserAgent is the default User-Agent header value.
|
// DefaultUserAgent is the default User-Agent header value.
|
||||||
DefaultUserAgent = "slurm-go-sdk"
|
DefaultUserAgent = "slurm-go-sdk"
|
||||||
|
// DefaultTimeout is the default HTTP request timeout.
|
||||||
|
DefaultTimeout = 30 * time.Second
|
||||||
)
|
)
|
||||||
|
|
||||||
// Client manages communication with the Slurm REST API.
|
// Client manages communication with the Slurm REST API.
|
||||||
@@ -85,7 +88,7 @@ type Response struct {
|
|||||||
// http.DefaultClient is used.
|
// http.DefaultClient is used.
|
||||||
func NewClient(baseURL string, httpClient *http.Client) (*Client, error) {
|
func NewClient(baseURL string, httpClient *http.Client) (*Client, error) {
|
||||||
if httpClient == nil {
|
if httpClient == nil {
|
||||||
httpClient = http.DefaultClient
|
httpClient = &http.Client{Timeout: DefaultTimeout}
|
||||||
}
|
}
|
||||||
|
|
||||||
parsedURL, err := url.Parse(baseURL)
|
parsedURL, err := url.Parse(baseURL)
|
||||||
|
|||||||
@@ -20,7 +20,7 @@ func TestNewClient(t *testing.T) {
|
|||||||
t.Errorf("expected UserAgent %q, got %q", DefaultUserAgent, client.UserAgent)
|
t.Errorf("expected UserAgent %q, got %q", DefaultUserAgent, client.UserAgent)
|
||||||
}
|
}
|
||||||
if client.client == nil {
|
if client.client == nil {
|
||||||
t.Error("expected http.Client to be initialized (nil httpClient should default to http.DefaultClient)")
|
t.Error("expected http.Client to be initialized (nil httpClient should create a client with default timeout)")
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err = NewClient("://invalid", nil)
|
_, err = NewClient("://invalid", nil)
|
||||||
@@ -137,11 +137,11 @@ func TestClient_ErrorHandling(t *testing.T) {
|
|||||||
t.Fatal("expected error for 500 response")
|
t.Fatal("expected error for 500 response")
|
||||||
}
|
}
|
||||||
|
|
||||||
errorResp, ok := err.(*ErrorResponse)
|
errorResp, ok := err.(*SlurmAPIError)
|
||||||
if !ok {
|
if !ok {
|
||||||
t.Fatalf("expected *ErrorResponse, got %T", err)
|
t.Fatalf("expected *SlurmAPIError, got %T", err)
|
||||||
}
|
}
|
||||||
if errorResp.Response.StatusCode != 500 {
|
if errorResp.StatusCode != 500 {
|
||||||
t.Errorf("expected status 500, got %d", errorResp.Response.StatusCode)
|
t.Errorf("expected status 500, got %d", errorResp.StatusCode)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,38 +1,85 @@
|
|||||||
package slurm
|
package slurm
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
)
|
)
|
||||||
|
|
||||||
// ErrorResponse represents an error returned by the Slurm REST API.
|
// errorResponseFields is used to parse errors/warnings from a Slurm API error body.
|
||||||
type ErrorResponse struct {
|
type errorResponseFields struct {
|
||||||
Response *http.Response
|
Errors OpenapiErrors `json:"errors,omitempty"`
|
||||||
Message string
|
Warnings OpenapiWarnings `json:"warnings,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *ErrorResponse) Error() string {
|
// SlurmAPIError represents a structured error returned by the Slurm REST API.
|
||||||
|
// It captures both the HTTP details and the parsed Slurm error array when available.
|
||||||
|
type SlurmAPIError struct {
|
||||||
|
Response *http.Response
|
||||||
|
StatusCode int
|
||||||
|
Errors OpenapiErrors
|
||||||
|
Warnings OpenapiWarnings
|
||||||
|
Message string // raw body fallback when JSON parsing fails
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *SlurmAPIError) Error() string {
|
||||||
|
if len(e.Errors) > 0 {
|
||||||
|
first := e.Errors[0]
|
||||||
|
detail := ""
|
||||||
|
if first.Error != nil {
|
||||||
|
detail = *first.Error
|
||||||
|
} else if first.Description != nil {
|
||||||
|
detail = *first.Description
|
||||||
|
}
|
||||||
|
if detail != "" {
|
||||||
return fmt.Sprintf("%v %v: %d %s",
|
return fmt.Sprintf("%v %v: %d %s",
|
||||||
r.Response.Request.Method, r.Response.Request.URL,
|
e.Response.Request.Method, e.Response.Request.URL,
|
||||||
r.Response.StatusCode, r.Message)
|
e.StatusCode, detail)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return fmt.Sprintf("%v %v: %d %s",
|
||||||
|
e.Response.Request.Method, e.Response.Request.URL,
|
||||||
|
e.StatusCode, e.Message)
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsNotFound reports whether err is a SlurmAPIError with HTTP 404 status.
|
||||||
|
func IsNotFound(err error) bool {
|
||||||
|
if apiErr, ok := err.(*SlurmAPIError); ok {
|
||||||
|
return apiErr.StatusCode == http.StatusNotFound
|
||||||
|
}
|
||||||
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
// CheckResponse checks the API response for errors. It returns nil if the
|
// CheckResponse checks the API response for errors. It returns nil if the
|
||||||
// response is a 2xx status code. For non-2xx codes, it reads the response
|
// response is a 2xx status code. For non-2xx codes, it reads the response
|
||||||
// body and returns an ErrorResponse.
|
// body, attempts to parse structured Slurm errors, and returns a SlurmAPIError.
|
||||||
func CheckResponse(r *http.Response) error {
|
func CheckResponse(r *http.Response) error {
|
||||||
if c := r.StatusCode; c >= 200 && c <= 299 {
|
if c := r.StatusCode; c >= 200 && c <= 299 {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
errorResponse := &ErrorResponse{Response: r}
|
|
||||||
data, err := io.ReadAll(r.Body)
|
data, err := io.ReadAll(r.Body)
|
||||||
if err != nil || len(data) == 0 {
|
if err != nil || len(data) == 0 {
|
||||||
errorResponse.Message = r.Status
|
return &SlurmAPIError{
|
||||||
return errorResponse
|
Response: r,
|
||||||
|
StatusCode: r.StatusCode,
|
||||||
|
Message: r.Status,
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
errorResponse.Message = string(data)
|
apiErr := &SlurmAPIError{
|
||||||
return errorResponse
|
Response: r,
|
||||||
|
StatusCode: r.StatusCode,
|
||||||
|
Message: string(data),
|
||||||
|
}
|
||||||
|
|
||||||
|
// Try to extract structured errors/warnings from JSON body.
|
||||||
|
var fields errorResponseFields
|
||||||
|
if json.Unmarshal(data, &fields) == nil {
|
||||||
|
apiErr.Errors = fields.Errors
|
||||||
|
apiErr.Warnings = fields.Warnings
|
||||||
|
}
|
||||||
|
|
||||||
|
return apiErr
|
||||||
}
|
}
|
||||||
|
|||||||
220
internal/slurm/errors_test.go
Normal file
220
internal/slurm/errors_test.go
Normal file
@@ -0,0 +1,220 @@
|
|||||||
|
package slurm
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestCheckResponse(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
statusCode int
|
||||||
|
body string
|
||||||
|
wantErr bool
|
||||||
|
wantStatusCode int
|
||||||
|
wantErrors int
|
||||||
|
wantWarnings int
|
||||||
|
wantMessageContains string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "2xx response returns nil",
|
||||||
|
statusCode: http.StatusOK,
|
||||||
|
body: `{"meta":{}}`,
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "404 with valid JSON error body",
|
||||||
|
statusCode: http.StatusNotFound,
|
||||||
|
body: `{"errors":[{"description":"Unable to query JobId=198","error_number":2017,"error":"Invalid job id specified","source":"_handle_job_get"}],"warnings":[]}`,
|
||||||
|
wantErr: true,
|
||||||
|
wantStatusCode: 404,
|
||||||
|
wantErrors: 1,
|
||||||
|
wantWarnings: 0,
|
||||||
|
wantMessageContains: "Invalid job id specified",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "500 with non-JSON body",
|
||||||
|
statusCode: http.StatusInternalServerError,
|
||||||
|
body: "internal server error",
|
||||||
|
wantErr: true,
|
||||||
|
wantStatusCode: 500,
|
||||||
|
wantErrors: 0,
|
||||||
|
wantWarnings: 0,
|
||||||
|
wantMessageContains: "internal server error",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "503 with empty body returns http.Status text",
|
||||||
|
statusCode: http.StatusServiceUnavailable,
|
||||||
|
body: "",
|
||||||
|
wantErr: true,
|
||||||
|
wantStatusCode: 503,
|
||||||
|
wantErrors: 0,
|
||||||
|
wantWarnings: 0,
|
||||||
|
wantMessageContains: "503 Service Unavailable",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "400 with valid JSON but empty errors array",
|
||||||
|
statusCode: http.StatusBadRequest,
|
||||||
|
body: `{"errors":[],"warnings":[]}`,
|
||||||
|
wantErr: true,
|
||||||
|
wantStatusCode: 400,
|
||||||
|
wantErrors: 0,
|
||||||
|
wantWarnings: 0,
|
||||||
|
wantMessageContains: `{"errors":[],"warnings":[]}`,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
rec.WriteHeader(tt.statusCode)
|
||||||
|
if tt.body != "" {
|
||||||
|
rec.Body.WriteString(tt.body)
|
||||||
|
}
|
||||||
|
|
||||||
|
err := CheckResponse(rec.Result())
|
||||||
|
|
||||||
|
if (err != nil) != tt.wantErr {
|
||||||
|
t.Fatalf("CheckResponse() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
|
}
|
||||||
|
if !tt.wantErr {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
apiErr, ok := err.(*SlurmAPIError)
|
||||||
|
if !ok {
|
||||||
|
t.Fatalf("expected *SlurmAPIError, got %T", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if apiErr.StatusCode != tt.wantStatusCode {
|
||||||
|
t.Errorf("StatusCode = %d, want %d", apiErr.StatusCode, tt.wantStatusCode)
|
||||||
|
}
|
||||||
|
if len(apiErr.Errors) != tt.wantErrors {
|
||||||
|
t.Errorf("len(Errors) = %d, want %d", len(apiErr.Errors), tt.wantErrors)
|
||||||
|
}
|
||||||
|
if len(apiErr.Warnings) != tt.wantWarnings {
|
||||||
|
t.Errorf("len(Warnings) = %d, want %d", len(apiErr.Warnings), tt.wantWarnings)
|
||||||
|
}
|
||||||
|
if !strings.Contains(apiErr.Message, tt.wantMessageContains) {
|
||||||
|
t.Errorf("Message = %q, want to contain %q", apiErr.Message, tt.wantMessageContains)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestIsNotFound(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
err error
|
||||||
|
want bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "404 SlurmAPIError returns true",
|
||||||
|
err: &SlurmAPIError{StatusCode: http.StatusNotFound},
|
||||||
|
want: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "500 SlurmAPIError returns false",
|
||||||
|
err: &SlurmAPIError{StatusCode: http.StatusInternalServerError},
|
||||||
|
want: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "200 SlurmAPIError returns false",
|
||||||
|
err: &SlurmAPIError{StatusCode: http.StatusOK},
|
||||||
|
want: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "plain error returns false",
|
||||||
|
err: fmt.Errorf("some error"),
|
||||||
|
want: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "nil returns false",
|
||||||
|
err: nil,
|
||||||
|
want: false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
if got := IsNotFound(tt.err); got != tt.want {
|
||||||
|
t.Errorf("IsNotFound() = %v, want %v", got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSlurmAPIError_Error(t *testing.T) {
|
||||||
|
fakeReq := httptest.NewRequest("GET", "http://localhost/slurm/v0.0.40/job/123", nil)
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
err *SlurmAPIError
|
||||||
|
wantContains []string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "with Error field set",
|
||||||
|
err: &SlurmAPIError{
|
||||||
|
Response: &http.Response{Request: fakeReq},
|
||||||
|
StatusCode: http.StatusNotFound,
|
||||||
|
Errors: OpenapiErrors{{Error: Ptr("Job not found")}},
|
||||||
|
Message: "raw body",
|
||||||
|
},
|
||||||
|
wantContains: []string{"404", "Job not found"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "with Description field set when Error is nil",
|
||||||
|
err: &SlurmAPIError{
|
||||||
|
Response: &http.Response{Request: fakeReq},
|
||||||
|
StatusCode: http.StatusBadRequest,
|
||||||
|
Errors: OpenapiErrors{{Description: Ptr("Unable to query")}},
|
||||||
|
Message: "raw body",
|
||||||
|
},
|
||||||
|
wantContains: []string{"400", "Unable to query"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "with both Error and Description nil falls through to Message",
|
||||||
|
err: &SlurmAPIError{
|
||||||
|
Response: &http.Response{Request: fakeReq},
|
||||||
|
StatusCode: http.StatusInternalServerError,
|
||||||
|
Errors: OpenapiErrors{{}},
|
||||||
|
Message: "something went wrong",
|
||||||
|
},
|
||||||
|
wantContains: []string{"500", "something went wrong"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "with empty Errors slice falls through to Message",
|
||||||
|
err: &SlurmAPIError{
|
||||||
|
Response: &http.Response{Request: fakeReq},
|
||||||
|
StatusCode: http.StatusServiceUnavailable,
|
||||||
|
Errors: OpenapiErrors{},
|
||||||
|
Message: "service unavailable fallback",
|
||||||
|
},
|
||||||
|
wantContains: []string{"503", "service unavailable fallback"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "with non-empty Errors but empty detail string falls through to Message",
|
||||||
|
err: &SlurmAPIError{
|
||||||
|
Response: &http.Response{Request: fakeReq},
|
||||||
|
StatusCode: http.StatusBadGateway,
|
||||||
|
Errors: OpenapiErrors{{ErrorNumber: Ptr(int32(42))}},
|
||||||
|
Message: "gateway error detail",
|
||||||
|
},
|
||||||
|
wantContains: []string{"502", "gateway error detail"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
got := tt.err.Error()
|
||||||
|
for _, substr := range tt.wantContains {
|
||||||
|
if !strings.Contains(got, substr) {
|
||||||
|
t.Errorf("Error() = %q, want to contain %q", got, substr)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -24,6 +24,8 @@ func defaultClientConfig() *clientConfig {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const defaultHTTPTimeout = 30 * time.Second
|
||||||
|
|
||||||
// WithJWTKey specifies the path to the JWT key file.
|
// WithJWTKey specifies the path to the JWT key file.
|
||||||
func WithJWTKey(path string) ClientOption {
|
func WithJWTKey(path string) ClientOption {
|
||||||
return func(c *clientConfig) error {
|
return func(c *clientConfig) error {
|
||||||
@@ -89,11 +91,12 @@ func NewClientWithOpts(baseURL string, opts ...ClientOption) (*Client, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
tr := NewJWTAuthTransport(cfg.username, key, transportOpts...)
|
tr := NewJWTAuthTransport(cfg.username, key, transportOpts...)
|
||||||
httpClient = tr.Client()
|
httpClient = &http.Client{
|
||||||
|
Transport: tr,
|
||||||
|
Timeout: defaultHTTPTimeout,
|
||||||
|
}
|
||||||
} else if cfg.httpClient != nil {
|
} else if cfg.httpClient != nil {
|
||||||
httpClient = cfg.httpClient
|
httpClient = cfg.httpClient
|
||||||
} else {
|
|
||||||
httpClient = http.DefaultClient
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return NewClient(baseURL, httpClient)
|
return NewClient(baseURL, httpClient)
|
||||||
|
|||||||
@@ -2,7 +2,6 @@ package slurm
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"crypto/rand"
|
"crypto/rand"
|
||||||
"net/http"
|
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"testing"
|
"testing"
|
||||||
@@ -61,8 +60,8 @@ func TestNewClientWithOpts_BackwardCompatible(t *testing.T) {
|
|||||||
if client == nil {
|
if client == nil {
|
||||||
t.Fatal("expected non-nil client")
|
t.Fatal("expected non-nil client")
|
||||||
}
|
}
|
||||||
if client.client != http.DefaultClient {
|
if client.client.Timeout != DefaultTimeout {
|
||||||
t.Error("expected http.DefaultClient when no options provided")
|
t.Errorf("expected Timeout=%v, got %v", DefaultTimeout, client.client.Timeout)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -54,8 +54,8 @@ type PartitionInfoMaximumsOversubscribe struct {
|
|||||||
|
|
||||||
// PartitionInfoMaximums represents maximum resource limits for a partition (v0.0.40_partition_info.maximums).
|
// PartitionInfoMaximums represents maximum resource limits for a partition (v0.0.40_partition_info.maximums).
|
||||||
type PartitionInfoMaximums struct {
|
type PartitionInfoMaximums struct {
|
||||||
CpusPerNode *int32 `json:"cpus_per_node,omitempty"`
|
CpusPerNode *Uint32NoVal `json:"cpus_per_node,omitempty"`
|
||||||
CpusPerSocket *int32 `json:"cpus_per_socket,omitempty"`
|
CpusPerSocket *Uint32NoVal `json:"cpus_per_socket,omitempty"`
|
||||||
MemoryPerCPU *int64 `json:"memory_per_cpu,omitempty"`
|
MemoryPerCPU *int64 `json:"memory_per_cpu,omitempty"`
|
||||||
PartitionMemoryPerCPU *Uint64NoVal `json:"partition_memory_per_cpu,omitempty"`
|
PartitionMemoryPerCPU *Uint64NoVal `json:"partition_memory_per_cpu,omitempty"`
|
||||||
PartitionMemoryPerNode *Uint64NoVal `json:"partition_memory_per_node,omitempty"`
|
PartitionMemoryPerNode *Uint64NoVal `json:"partition_memory_per_node,omitempty"`
|
||||||
|
|||||||
@@ -43,8 +43,8 @@ func TestPartitionInfoRoundTrip(t *testing.T) {
|
|||||||
},
|
},
|
||||||
GraceTime: Ptr(int32(300)),
|
GraceTime: Ptr(int32(300)),
|
||||||
Maximums: &PartitionInfoMaximums{
|
Maximums: &PartitionInfoMaximums{
|
||||||
CpusPerNode: Ptr(int32(128)),
|
CpusPerNode: &Uint32NoVal{Set: Ptr(true), Number: Ptr(int64(128))},
|
||||||
CpusPerSocket: Ptr(int32(64)),
|
CpusPerSocket: &Uint32NoVal{Set: Ptr(true), Number: Ptr(int64(64))},
|
||||||
MemoryPerCPU: Ptr(int64(8192)),
|
MemoryPerCPU: Ptr(int64(8192)),
|
||||||
PartitionMemoryPerCPU: &Uint64NoVal{Set: Ptr(true), Number: Ptr(int64(8192))},
|
PartitionMemoryPerCPU: &Uint64NoVal{Set: Ptr(true), Number: Ptr(int64(8192))},
|
||||||
PartitionMemoryPerNode: &Uint64NoVal{Set: Ptr(true), Number: Ptr(int64(262144))},
|
PartitionMemoryPerNode: &Uint64NoVal{Set: Ptr(true), Number: Ptr(int64(262144))},
|
||||||
|
|||||||
286
internal/storage/minio.go
Normal file
286
internal/storage/minio.go
Normal file
@@ -0,0 +1,286 @@
|
|||||||
|
package storage
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"gcy_hpc_server/internal/config"
|
||||||
|
"github.com/minio/minio-go/v7"
|
||||||
|
"github.com/minio/minio-go/v7/pkg/credentials"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ObjectInfo contains metadata about a stored object.
|
||||||
|
type ObjectInfo struct {
|
||||||
|
Key string
|
||||||
|
Size int64
|
||||||
|
LastModified time.Time
|
||||||
|
ETag string
|
||||||
|
ContentType string
|
||||||
|
}
|
||||||
|
|
||||||
|
// UploadInfo contains metadata about an uploaded object.
|
||||||
|
type UploadInfo struct {
|
||||||
|
ETag string
|
||||||
|
Size int64
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetOptions specifies parameters for GetObject, including optional Range.
|
||||||
|
type GetOptions struct {
|
||||||
|
Start *int64 // Range start byte offset (nil = no range)
|
||||||
|
End *int64 // Range end byte offset (nil = no range)
|
||||||
|
}
|
||||||
|
|
||||||
|
// MultipartUpload represents an incomplete multipart upload.
|
||||||
|
type MultipartUpload struct {
|
||||||
|
ObjectName string
|
||||||
|
UploadID string
|
||||||
|
Initiated time.Time
|
||||||
|
}
|
||||||
|
|
||||||
|
// RemoveObjectsOptions specifies options for removing multiple objects.
|
||||||
|
type RemoveObjectsOptions struct {
|
||||||
|
ForceDelete bool
|
||||||
|
}
|
||||||
|
|
||||||
|
// PutObjectOptions for PutObject.
|
||||||
|
type PutObjectOptions struct {
|
||||||
|
ContentType string
|
||||||
|
DisableMultipart bool // true for small chunks (already pre-split)
|
||||||
|
}
|
||||||
|
|
||||||
|
// RemoveObjectOptions for RemoveObject.
|
||||||
|
type RemoveObjectOptions struct {
|
||||||
|
ForceDelete bool
|
||||||
|
}
|
||||||
|
|
||||||
|
// MakeBucketOptions for MakeBucket.
|
||||||
|
type MakeBucketOptions struct {
|
||||||
|
Region string
|
||||||
|
}
|
||||||
|
|
||||||
|
// StatObjectOptions for StatObject.
|
||||||
|
type StatObjectOptions struct{}
|
||||||
|
|
||||||
|
// ObjectStorage defines the interface for object storage operations.
|
||||||
|
// Implementations should wrap MinIO SDK calls with custom transfer types.
|
||||||
|
type ObjectStorage interface {
|
||||||
|
// PutObject uploads an object from a reader.
|
||||||
|
PutObject(ctx context.Context, bucket, key string, reader io.Reader, size int64, opts PutObjectOptions) (UploadInfo, error)
|
||||||
|
|
||||||
|
// GetObject retrieves an object. opts may contain Range parameters.
|
||||||
|
GetObject(ctx context.Context, bucket, key string, opts GetOptions) (io.ReadCloser, ObjectInfo, error)
|
||||||
|
|
||||||
|
// ComposeObject merges multiple source objects into a single destination.
|
||||||
|
ComposeObject(ctx context.Context, bucket, dst string, sources []string) (UploadInfo, error)
|
||||||
|
|
||||||
|
// AbortMultipartUpload aborts an incomplete multipart upload by upload ID.
|
||||||
|
AbortMultipartUpload(ctx context.Context, bucket, object, uploadID string) error
|
||||||
|
|
||||||
|
// RemoveIncompleteUpload removes all incomplete uploads for an object.
|
||||||
|
// This is the preferred cleanup method — it encapsulates list + abort logic.
|
||||||
|
RemoveIncompleteUpload(ctx context.Context, bucket, object string) error
|
||||||
|
|
||||||
|
// RemoveObject deletes a single object.
|
||||||
|
RemoveObject(ctx context.Context, bucket, key string, opts RemoveObjectOptions) error
|
||||||
|
|
||||||
|
// ListObjects lists objects with a given prefix.
|
||||||
|
ListObjects(ctx context.Context, bucket, prefix string, recursive bool) ([]ObjectInfo, error)
|
||||||
|
|
||||||
|
// RemoveObjects deletes multiple objects.
|
||||||
|
RemoveObjects(ctx context.Context, bucket string, keys []string, opts RemoveObjectsOptions) error
|
||||||
|
|
||||||
|
// BucketExists checks if a bucket exists.
|
||||||
|
BucketExists(ctx context.Context, bucket string) (bool, error)
|
||||||
|
|
||||||
|
// MakeBucket creates a new bucket.
|
||||||
|
MakeBucket(ctx context.Context, bucket string, opts MakeBucketOptions) error
|
||||||
|
|
||||||
|
// StatObject gets metadata about an object without downloading it.
|
||||||
|
StatObject(ctx context.Context, bucket, key string, opts StatObjectOptions) (ObjectInfo, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
var _ ObjectStorage = (*MinioClient)(nil)
|
||||||
|
|
||||||
|
type MinioClient struct {
|
||||||
|
core *minio.Core
|
||||||
|
bucket string
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewMinioClient(cfg config.MinioConfig) (*MinioClient, error) {
|
||||||
|
transport, err := minio.DefaultTransport(cfg.UseSSL)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("create default transport: %w", err)
|
||||||
|
}
|
||||||
|
transport.MaxIdleConnsPerHost = 100
|
||||||
|
transport.IdleConnTimeout = 90 * time.Second
|
||||||
|
|
||||||
|
core, err := minio.NewCore(cfg.Endpoint, &minio.Options{
|
||||||
|
Creds: credentials.NewStaticV4(cfg.AccessKey, cfg.SecretKey, ""),
|
||||||
|
Secure: cfg.UseSSL,
|
||||||
|
Transport: transport,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("create minio client: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
mc := &MinioClient{core: core, bucket: cfg.Bucket}
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
exists, err := core.BucketExists(ctx, cfg.Bucket)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("check bucket: %w", err)
|
||||||
|
}
|
||||||
|
if !exists {
|
||||||
|
if err := core.MakeBucket(ctx, cfg.Bucket, minio.MakeBucketOptions{}); err != nil {
|
||||||
|
return nil, fmt.Errorf("create bucket: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return mc, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MinioClient) PutObject(ctx context.Context, bucket, key string, reader io.Reader, size int64, opts PutObjectOptions) (UploadInfo, error) {
|
||||||
|
info, err := m.core.PutObject(ctx, bucket, key, reader, size, "", "", minio.PutObjectOptions{
|
||||||
|
ContentType: opts.ContentType,
|
||||||
|
DisableMultipart: opts.DisableMultipart,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return UploadInfo{}, fmt.Errorf("put object %s/%s: %w", bucket, key, err)
|
||||||
|
}
|
||||||
|
return UploadInfo{ETag: info.ETag, Size: info.Size}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MinioClient) GetObject(ctx context.Context, bucket, key string, opts GetOptions) (io.ReadCloser, ObjectInfo, error) {
|
||||||
|
var gopts minio.GetObjectOptions
|
||||||
|
if opts.Start != nil || opts.End != nil {
|
||||||
|
start := int64(0)
|
||||||
|
end := int64(0)
|
||||||
|
if opts.Start != nil {
|
||||||
|
start = *opts.Start
|
||||||
|
}
|
||||||
|
if opts.End != nil {
|
||||||
|
end = *opts.End
|
||||||
|
}
|
||||||
|
if err := gopts.SetRange(start, end); err != nil {
|
||||||
|
return nil, ObjectInfo{}, fmt.Errorf("set range: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
body, info, _, err := m.core.GetObject(ctx, bucket, key, gopts)
|
||||||
|
if err != nil {
|
||||||
|
return nil, ObjectInfo{}, fmt.Errorf("get object %s/%s: %w", bucket, key, err)
|
||||||
|
}
|
||||||
|
return body, toObjectInfo(info), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MinioClient) ComposeObject(ctx context.Context, bucket, dst string, sources []string) (UploadInfo, error) {
|
||||||
|
srcs := make([]minio.CopySrcOptions, len(sources))
|
||||||
|
for i, src := range sources {
|
||||||
|
srcs[i] = minio.CopySrcOptions{Bucket: bucket, Object: src}
|
||||||
|
}
|
||||||
|
|
||||||
|
do := minio.CopyDestOptions{
|
||||||
|
Bucket: bucket,
|
||||||
|
Object: dst,
|
||||||
|
ReplaceMetadata: true,
|
||||||
|
}
|
||||||
|
|
||||||
|
info, err := m.core.ComposeObject(ctx, do, srcs...)
|
||||||
|
if err != nil {
|
||||||
|
return UploadInfo{}, fmt.Errorf("compose object %s/%s: %w", bucket, dst, err)
|
||||||
|
}
|
||||||
|
return UploadInfo{ETag: info.ETag, Size: info.Size}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MinioClient) AbortMultipartUpload(ctx context.Context, bucket, object, uploadID string) error {
|
||||||
|
if err := m.core.AbortMultipartUpload(ctx, bucket, object, uploadID); err != nil {
|
||||||
|
return fmt.Errorf("abort multipart upload %s/%s %s: %w", bucket, object, uploadID, err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MinioClient) RemoveIncompleteUpload(ctx context.Context, bucket, object string) error {
|
||||||
|
if err := m.core.RemoveIncompleteUpload(ctx, bucket, object); err != nil {
|
||||||
|
return fmt.Errorf("remove incomplete upload %s/%s: %w", bucket, object, err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MinioClient) RemoveObject(ctx context.Context, bucket, key string, opts RemoveObjectOptions) error {
|
||||||
|
if err := m.core.RemoveObject(ctx, bucket, key, minio.RemoveObjectOptions{
|
||||||
|
ForceDelete: opts.ForceDelete,
|
||||||
|
}); err != nil {
|
||||||
|
return fmt.Errorf("remove object %s/%s: %w", bucket, key, err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MinioClient) ListObjects(ctx context.Context, bucket, prefix string, recursive bool) ([]ObjectInfo, error) {
|
||||||
|
ch := m.core.Client.ListObjects(ctx, bucket, minio.ListObjectsOptions{
|
||||||
|
Prefix: prefix,
|
||||||
|
Recursive: recursive,
|
||||||
|
})
|
||||||
|
|
||||||
|
var result []ObjectInfo
|
||||||
|
for obj := range ch {
|
||||||
|
if obj.Err != nil {
|
||||||
|
return result, fmt.Errorf("list objects %s/%s: %w", bucket, prefix, obj.Err)
|
||||||
|
}
|
||||||
|
result = append(result, toObjectInfo(obj))
|
||||||
|
}
|
||||||
|
return result, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MinioClient) RemoveObjects(ctx context.Context, bucket string, keys []string, opts RemoveObjectsOptions) error {
|
||||||
|
objectsCh := make(chan minio.ObjectInfo, len(keys))
|
||||||
|
for _, key := range keys {
|
||||||
|
objectsCh <- minio.ObjectInfo{Key: key}
|
||||||
|
}
|
||||||
|
close(objectsCh)
|
||||||
|
|
||||||
|
errCh := m.core.RemoveObjects(ctx, bucket, objectsCh, minio.RemoveObjectsOptions{})
|
||||||
|
|
||||||
|
for err := range errCh {
|
||||||
|
if err.Err != nil {
|
||||||
|
return fmt.Errorf("remove object %s: %w", err.ObjectName, err.Err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MinioClient) BucketExists(ctx context.Context, bucket string) (bool, error) {
|
||||||
|
ok, err := m.core.BucketExists(ctx, bucket)
|
||||||
|
if err != nil {
|
||||||
|
return false, fmt.Errorf("bucket exists %s: %w", bucket, err)
|
||||||
|
}
|
||||||
|
return ok, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MinioClient) MakeBucket(ctx context.Context, bucket string, opts MakeBucketOptions) error {
|
||||||
|
if err := m.core.MakeBucket(ctx, bucket, minio.MakeBucketOptions{
|
||||||
|
Region: opts.Region,
|
||||||
|
}); err != nil {
|
||||||
|
return fmt.Errorf("make bucket %s: %w", bucket, err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MinioClient) StatObject(ctx context.Context, bucket, key string, _ StatObjectOptions) (ObjectInfo, error) {
|
||||||
|
info, err := m.core.StatObject(ctx, bucket, key, minio.StatObjectOptions{})
|
||||||
|
if err != nil {
|
||||||
|
return ObjectInfo{}, fmt.Errorf("stat object %s/%s: %w", bucket, key, err)
|
||||||
|
}
|
||||||
|
return toObjectInfo(info), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func toObjectInfo(info minio.ObjectInfo) ObjectInfo {
|
||||||
|
return ObjectInfo{
|
||||||
|
Key: info.Key,
|
||||||
|
Size: info.Size,
|
||||||
|
LastModified: info.LastModified,
|
||||||
|
ETag: info.ETag,
|
||||||
|
ContentType: info.ContentType,
|
||||||
|
}
|
||||||
|
}
|
||||||
7
internal/storage/minio_test.go
Normal file
7
internal/storage/minio_test.go
Normal file
@@ -0,0 +1,7 @@
|
|||||||
|
package storage
|
||||||
|
|
||||||
|
import "testing"
|
||||||
|
|
||||||
|
func TestMinioClientImplementsObjectStorage(t *testing.T) {
|
||||||
|
var _ ObjectStorage = (*MinioClient)(nil)
|
||||||
|
}
|
||||||
114
internal/store/application_store.go
Normal file
114
internal/store/application_store.go
Normal file
@@ -0,0 +1,114 @@
|
|||||||
|
package store
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
|
|
||||||
|
"gcy_hpc_server/internal/model"
|
||||||
|
|
||||||
|
"gorm.io/gorm"
|
||||||
|
)
|
||||||
|
|
||||||
|
type ApplicationStore struct {
|
||||||
|
db *gorm.DB
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewApplicationStore(db *gorm.DB) *ApplicationStore {
|
||||||
|
return &ApplicationStore{db: db}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *ApplicationStore) List(ctx context.Context, page, pageSize int) ([]model.Application, int, error) {
|
||||||
|
var apps []model.Application
|
||||||
|
var total int64
|
||||||
|
|
||||||
|
if err := s.db.WithContext(ctx).Model(&model.Application{}).Count(&total).Error; err != nil {
|
||||||
|
return nil, 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
offset := (page - 1) * pageSize
|
||||||
|
if err := s.db.WithContext(ctx).Order("id DESC").Limit(pageSize).Offset(offset).Find(&apps).Error; err != nil {
|
||||||
|
return nil, 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return apps, int(total), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *ApplicationStore) GetByID(ctx context.Context, id int64) (*model.Application, error) {
|
||||||
|
var app model.Application
|
||||||
|
err := s.db.WithContext(ctx).First(&app, id).Error
|
||||||
|
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &app, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *ApplicationStore) Create(ctx context.Context, req *model.CreateApplicationRequest) (int64, error) {
|
||||||
|
params := req.Parameters
|
||||||
|
if len(params) == 0 {
|
||||||
|
params = json.RawMessage(`[]`)
|
||||||
|
}
|
||||||
|
|
||||||
|
app := &model.Application{
|
||||||
|
Name: req.Name,
|
||||||
|
Description: req.Description,
|
||||||
|
Icon: req.Icon,
|
||||||
|
Category: req.Category,
|
||||||
|
ScriptTemplate: req.ScriptTemplate,
|
||||||
|
Parameters: params,
|
||||||
|
Scope: req.Scope,
|
||||||
|
}
|
||||||
|
if err := s.db.WithContext(ctx).Create(app).Error; err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
return app.ID, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *ApplicationStore) Update(ctx context.Context, id int64, req *model.UpdateApplicationRequest) error {
|
||||||
|
updates := map[string]interface{}{}
|
||||||
|
if req.Name != nil {
|
||||||
|
updates["name"] = *req.Name
|
||||||
|
}
|
||||||
|
if req.Description != nil {
|
||||||
|
updates["description"] = *req.Description
|
||||||
|
}
|
||||||
|
if req.Icon != nil {
|
||||||
|
updates["icon"] = *req.Icon
|
||||||
|
}
|
||||||
|
if req.Category != nil {
|
||||||
|
updates["category"] = *req.Category
|
||||||
|
}
|
||||||
|
if req.ScriptTemplate != nil {
|
||||||
|
updates["script_template"] = *req.ScriptTemplate
|
||||||
|
}
|
||||||
|
if req.Parameters != nil {
|
||||||
|
updates["parameters"] = *req.Parameters
|
||||||
|
}
|
||||||
|
if req.Scope != nil {
|
||||||
|
updates["scope"] = *req.Scope
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(updates) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
result := s.db.WithContext(ctx).Model(&model.Application{}).Where("id = ?", id).Updates(updates)
|
||||||
|
if result.Error != nil {
|
||||||
|
return result.Error
|
||||||
|
}
|
||||||
|
if result.RowsAffected == 0 {
|
||||||
|
return gorm.ErrRecordNotFound
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *ApplicationStore) Delete(ctx context.Context, id int64) error {
|
||||||
|
result := s.db.WithContext(ctx).Delete(&model.Application{}, id)
|
||||||
|
if result.Error != nil {
|
||||||
|
return result.Error
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
227
internal/store/application_store_test.go
Normal file
227
internal/store/application_store_test.go
Normal file
@@ -0,0 +1,227 @@
|
|||||||
|
package store
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"gcy_hpc_server/internal/model"
|
||||||
|
|
||||||
|
"gorm.io/driver/sqlite"
|
||||||
|
"gorm.io/gorm"
|
||||||
|
"gorm.io/gorm/logger"
|
||||||
|
)
|
||||||
|
|
||||||
|
func newAppTestDB(t *testing.T) *gorm.DB {
|
||||||
|
t.Helper()
|
||||||
|
db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{
|
||||||
|
Logger: logger.Default.LogMode(logger.Silent),
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("open sqlite: %v", err)
|
||||||
|
}
|
||||||
|
if err := db.AutoMigrate(&model.Application{}); err != nil {
|
||||||
|
t.Fatalf("auto migrate: %v", err)
|
||||||
|
}
|
||||||
|
return db
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestApplicationStore_Create_Success(t *testing.T) {
|
||||||
|
db := newAppTestDB(t)
|
||||||
|
s := NewApplicationStore(db)
|
||||||
|
|
||||||
|
id, err := s.Create(context.Background(), &model.CreateApplicationRequest{
|
||||||
|
Name: "gromacs",
|
||||||
|
Description: "Molecular dynamics simulator",
|
||||||
|
Category: "simulation",
|
||||||
|
ScriptTemplate: "#!/bin/bash\nmodule load gromacs",
|
||||||
|
Parameters: []byte(`[{"name":"ntasks","type":"number","required":true}]`),
|
||||||
|
Scope: "system",
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Create() error = %v", err)
|
||||||
|
}
|
||||||
|
if id <= 0 {
|
||||||
|
t.Errorf("Create() id = %d, want positive", id)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestApplicationStore_GetByID_Success(t *testing.T) {
|
||||||
|
db := newAppTestDB(t)
|
||||||
|
s := NewApplicationStore(db)
|
||||||
|
|
||||||
|
id, _ := s.Create(context.Background(), &model.CreateApplicationRequest{
|
||||||
|
Name: "lammps",
|
||||||
|
ScriptTemplate: "#!/bin/bash\nmodule load lammps",
|
||||||
|
})
|
||||||
|
|
||||||
|
app, err := s.GetByID(context.Background(), id)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("GetByID() error = %v", err)
|
||||||
|
}
|
||||||
|
if app == nil {
|
||||||
|
t.Fatal("GetByID() returned nil, expected application")
|
||||||
|
}
|
||||||
|
if app.Name != "lammps" {
|
||||||
|
t.Errorf("Name = %q, want %q", app.Name, "lammps")
|
||||||
|
}
|
||||||
|
if app.ID != id {
|
||||||
|
t.Errorf("ID = %d, want %d", app.ID, id)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestApplicationStore_GetByID_NotFound(t *testing.T) {
|
||||||
|
db := newAppTestDB(t)
|
||||||
|
s := NewApplicationStore(db)
|
||||||
|
|
||||||
|
app, err := s.GetByID(context.Background(), 99999)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("GetByID() error = %v", err)
|
||||||
|
}
|
||||||
|
if app != nil {
|
||||||
|
t.Error("GetByID() expected nil for not-found, got non-nil")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestApplicationStore_List_Pagination(t *testing.T) {
|
||||||
|
db := newAppTestDB(t)
|
||||||
|
s := NewApplicationStore(db)
|
||||||
|
|
||||||
|
for i := 0; i < 5; i++ {
|
||||||
|
s.Create(context.Background(), &model.CreateApplicationRequest{
|
||||||
|
Name: "app-" + string(rune('A'+i)),
|
||||||
|
ScriptTemplate: "#!/bin/bash\necho " + string(rune('A'+i)),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
apps, total, err := s.List(context.Background(), 1, 3)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("List() error = %v", err)
|
||||||
|
}
|
||||||
|
if total != 5 {
|
||||||
|
t.Errorf("total = %d, want 5", total)
|
||||||
|
}
|
||||||
|
if len(apps) != 3 {
|
||||||
|
t.Errorf("len(apps) = %d, want 3", len(apps))
|
||||||
|
}
|
||||||
|
|
||||||
|
apps2, total2, err := s.List(context.Background(), 2, 3)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("List() page 2 error = %v", err)
|
||||||
|
}
|
||||||
|
if total2 != 5 {
|
||||||
|
t.Errorf("total2 = %d, want 5", total2)
|
||||||
|
}
|
||||||
|
if len(apps2) != 2 {
|
||||||
|
t.Errorf("len(apps2) = %d, want 2", len(apps2))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestApplicationStore_Update_Success(t *testing.T) {
|
||||||
|
db := newAppTestDB(t)
|
||||||
|
s := NewApplicationStore(db)
|
||||||
|
|
||||||
|
id, _ := s.Create(context.Background(), &model.CreateApplicationRequest{
|
||||||
|
Name: "orig",
|
||||||
|
ScriptTemplate: "#!/bin/bash\necho original",
|
||||||
|
})
|
||||||
|
|
||||||
|
newName := "updated"
|
||||||
|
newDesc := "updated description"
|
||||||
|
err := s.Update(context.Background(), id, &model.UpdateApplicationRequest{
|
||||||
|
Name: &newName,
|
||||||
|
Description: &newDesc,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Update() error = %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
app, _ := s.GetByID(context.Background(), id)
|
||||||
|
if app.Name != "updated" {
|
||||||
|
t.Errorf("Name = %q, want %q", app.Name, "updated")
|
||||||
|
}
|
||||||
|
if app.Description != "updated description" {
|
||||||
|
t.Errorf("Description = %q, want %q", app.Description, "updated description")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestApplicationStore_Update_NotFound(t *testing.T) {
|
||||||
|
db := newAppTestDB(t)
|
||||||
|
s := NewApplicationStore(db)
|
||||||
|
|
||||||
|
name := "nope"
|
||||||
|
err := s.Update(context.Background(), 99999, &model.UpdateApplicationRequest{
|
||||||
|
Name: &name,
|
||||||
|
})
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("Update() expected error for not-found, got nil")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestApplicationStore_Delete_Success(t *testing.T) {
|
||||||
|
db := newAppTestDB(t)
|
||||||
|
s := NewApplicationStore(db)
|
||||||
|
|
||||||
|
id, _ := s.Create(context.Background(), &model.CreateApplicationRequest{
|
||||||
|
Name: "to-delete",
|
||||||
|
ScriptTemplate: "#!/bin/bash\necho bye",
|
||||||
|
})
|
||||||
|
|
||||||
|
err := s.Delete(context.Background(), id)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Delete() error = %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
app, _ := s.GetByID(context.Background(), id)
|
||||||
|
if app != nil {
|
||||||
|
t.Error("GetByID() after delete returned non-nil")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestApplicationStore_Delete_Idempotent(t *testing.T) {
|
||||||
|
db := newAppTestDB(t)
|
||||||
|
s := NewApplicationStore(db)
|
||||||
|
|
||||||
|
err := s.Delete(context.Background(), 99999)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Delete() non-existent error = %v, want nil (idempotent)", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestApplicationStore_Create_DuplicateName(t *testing.T) {
|
||||||
|
db := newAppTestDB(t)
|
||||||
|
s := NewApplicationStore(db)
|
||||||
|
|
||||||
|
_, err := s.Create(context.Background(), &model.CreateApplicationRequest{
|
||||||
|
Name: "dup-app",
|
||||||
|
ScriptTemplate: "#!/bin/bash\necho 1",
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("first Create() error = %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = s.Create(context.Background(), &model.CreateApplicationRequest{
|
||||||
|
Name: "dup-app",
|
||||||
|
ScriptTemplate: "#!/bin/bash\necho 2",
|
||||||
|
})
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected error for duplicate name, got nil")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestApplicationStore_Create_EmptyParameters(t *testing.T) {
|
||||||
|
db := newAppTestDB(t)
|
||||||
|
s := NewApplicationStore(db)
|
||||||
|
|
||||||
|
id, err := s.Create(context.Background(), &model.CreateApplicationRequest{
|
||||||
|
Name: "no-params",
|
||||||
|
ScriptTemplate: "#!/bin/bash\necho hello",
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Create() error = %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
app, _ := s.GetByID(context.Background(), id)
|
||||||
|
if string(app.Parameters) != "[]" {
|
||||||
|
t.Errorf("Parameters = %q, want []", string(app.Parameters))
|
||||||
|
}
|
||||||
|
}
|
||||||
108
internal/store/blob_store.go
Normal file
108
internal/store/blob_store.go
Normal file
@@ -0,0 +1,108 @@
|
|||||||
|
package store
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
|
||||||
|
"gcy_hpc_server/internal/model"
|
||||||
|
|
||||||
|
"gorm.io/gorm"
|
||||||
|
"gorm.io/gorm/clause"
|
||||||
|
)
|
||||||
|
|
||||||
|
// BlobStore manages physical file blobs with reference counting.
|
||||||
|
type BlobStore struct {
|
||||||
|
db *gorm.DB
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewBlobStore creates a new BlobStore.
|
||||||
|
func NewBlobStore(db *gorm.DB) *BlobStore {
|
||||||
|
return &BlobStore{db: db}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create inserts a new FileBlob record.
|
||||||
|
func (s *BlobStore) Create(ctx context.Context, blob *model.FileBlob) error {
|
||||||
|
return s.db.WithContext(ctx).Create(blob).Error
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetBySHA256 returns the FileBlob with the given SHA256 hash.
|
||||||
|
// Returns (nil, nil) if not found.
|
||||||
|
func (s *BlobStore) GetBySHA256(ctx context.Context, sha256 string) (*model.FileBlob, error) {
|
||||||
|
var blob model.FileBlob
|
||||||
|
err := s.db.WithContext(ctx).Where("sha256 = ?", sha256).First(&blob).Error
|
||||||
|
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &blob, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// IncrementRef atomically increments the ref_count for the blob with the given SHA256.
|
||||||
|
func (s *BlobStore) IncrementRef(ctx context.Context, sha256 string) error {
|
||||||
|
result := s.db.WithContext(ctx).Model(&model.FileBlob{}).
|
||||||
|
Where("sha256 = ?", sha256).
|
||||||
|
UpdateColumn("ref_count", gorm.Expr("ref_count + 1"))
|
||||||
|
if result.Error != nil {
|
||||||
|
return result.Error
|
||||||
|
}
|
||||||
|
if result.RowsAffected == 0 {
|
||||||
|
return gorm.ErrRecordNotFound
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// DecrementRef atomically decrements the ref_count for the blob with the given SHA256.
|
||||||
|
// Returns the new ref_count after decrementing.
|
||||||
|
func (s *BlobStore) DecrementRef(ctx context.Context, sha256 string) (int64, error) {
|
||||||
|
result := s.db.WithContext(ctx).Model(&model.FileBlob{}).
|
||||||
|
Where("sha256 = ? AND ref_count > 0", sha256).
|
||||||
|
UpdateColumn("ref_count", gorm.Expr("ref_count - 1"))
|
||||||
|
if result.Error != nil {
|
||||||
|
return 0, result.Error
|
||||||
|
}
|
||||||
|
if result.RowsAffected == 0 {
|
||||||
|
return 0, gorm.ErrRecordNotFound
|
||||||
|
}
|
||||||
|
|
||||||
|
var blob model.FileBlob
|
||||||
|
if err := s.db.WithContext(ctx).Where("sha256 = ?", sha256).First(&blob).Error; err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
return int64(blob.RefCount), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Delete removes a FileBlob record by SHA256 (hard delete).
|
||||||
|
func (s *BlobStore) Delete(ctx context.Context, sha256 string) error {
|
||||||
|
result := s.db.WithContext(ctx).Where("sha256 = ?", sha256).Delete(&model.FileBlob{})
|
||||||
|
if result.Error != nil {
|
||||||
|
return result.Error
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *BlobStore) GetBySHA256s(ctx context.Context, sha256s []string) ([]model.FileBlob, error) {
|
||||||
|
var blobs []model.FileBlob
|
||||||
|
if len(sha256s) == 0 {
|
||||||
|
return blobs, nil
|
||||||
|
}
|
||||||
|
err := s.db.WithContext(ctx).Where("sha256 IN ?", sha256s).Find(&blobs).Error
|
||||||
|
return blobs, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetBySHA256ForUpdate returns the FileBlob with a SELECT ... FOR UPDATE lock.
|
||||||
|
// Returns (nil, nil) if not found.
|
||||||
|
func (s *BlobStore) GetBySHA256ForUpdate(ctx context.Context, tx *gorm.DB, sha256 string) (*model.FileBlob, error) {
|
||||||
|
var blob model.FileBlob
|
||||||
|
err := tx.WithContext(ctx).
|
||||||
|
Clauses(clause.Locking{Strength: "UPDATE"}).
|
||||||
|
Where("sha256 = ?", sha256).First(&blob).Error
|
||||||
|
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &blob, nil
|
||||||
|
}
|
||||||
199
internal/store/blob_store_test.go
Normal file
199
internal/store/blob_store_test.go
Normal file
@@ -0,0 +1,199 @@
|
|||||||
|
package store
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"gcy_hpc_server/internal/model"
|
||||||
|
|
||||||
|
"gorm.io/driver/sqlite"
|
||||||
|
"gorm.io/gorm"
|
||||||
|
"gorm.io/gorm/logger"
|
||||||
|
)
|
||||||
|
|
||||||
|
func setupBlobTestDB(t *testing.T) *gorm.DB {
|
||||||
|
t.Helper()
|
||||||
|
db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{
|
||||||
|
Logger: logger.Default.LogMode(logger.Silent),
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("open sqlite: %v", err)
|
||||||
|
}
|
||||||
|
if err := db.AutoMigrate(&model.FileBlob{}); err != nil {
|
||||||
|
t.Fatalf("migrate: %v", err)
|
||||||
|
}
|
||||||
|
return db
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBlobStore_Create(t *testing.T) {
|
||||||
|
db := setupBlobTestDB(t)
|
||||||
|
store := NewBlobStore(db)
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
blob := &model.FileBlob{
|
||||||
|
SHA256: "abc123",
|
||||||
|
MinioKey: "files/abc123",
|
||||||
|
FileSize: 1024,
|
||||||
|
MimeType: "application/octet-stream",
|
||||||
|
RefCount: 0,
|
||||||
|
}
|
||||||
|
if err := store.Create(ctx, blob); err != nil {
|
||||||
|
t.Fatalf("Create() error = %v", err)
|
||||||
|
}
|
||||||
|
if blob.ID == 0 {
|
||||||
|
t.Error("Create() did not set ID")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBlobStore_GetBySHA256(t *testing.T) {
|
||||||
|
db := setupBlobTestDB(t)
|
||||||
|
store := NewBlobStore(db)
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
store.Create(ctx, &model.FileBlob{SHA256: "abc", MinioKey: "files/abc", FileSize: 100, RefCount: 0})
|
||||||
|
|
||||||
|
blob, err := store.GetBySHA256(ctx, "abc")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("GetBySHA256() error = %v", err)
|
||||||
|
}
|
||||||
|
if blob == nil {
|
||||||
|
t.Fatal("GetBySHA256() returned nil")
|
||||||
|
}
|
||||||
|
if blob.SHA256 != "abc" {
|
||||||
|
t.Errorf("SHA256 = %q, want %q", blob.SHA256, "abc")
|
||||||
|
}
|
||||||
|
if blob.RefCount != 0 {
|
||||||
|
t.Errorf("RefCount = %d, want 0", blob.RefCount)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBlobStore_GetBySHA256_NotFound(t *testing.T) {
|
||||||
|
db := setupBlobTestDB(t)
|
||||||
|
store := NewBlobStore(db)
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
blob, err := store.GetBySHA256(ctx, "nonexistent")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("GetBySHA256() error = %v", err)
|
||||||
|
}
|
||||||
|
if blob != nil {
|
||||||
|
t.Error("GetBySHA256() should return nil for not found")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBlobStore_IncrementDecrementRef(t *testing.T) {
|
||||||
|
db := setupBlobTestDB(t)
|
||||||
|
store := NewBlobStore(db)
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
store.Create(ctx, &model.FileBlob{SHA256: "abc", MinioKey: "files/abc", FileSize: 100, RefCount: 0})
|
||||||
|
|
||||||
|
if err := store.IncrementRef(ctx, "abc"); err != nil {
|
||||||
|
t.Fatalf("IncrementRef() error = %v", err)
|
||||||
|
}
|
||||||
|
blob, _ := store.GetBySHA256(ctx, "abc")
|
||||||
|
if blob.RefCount != 1 {
|
||||||
|
t.Errorf("RefCount after 1st increment = %d, want 1", blob.RefCount)
|
||||||
|
}
|
||||||
|
|
||||||
|
store.IncrementRef(ctx, "abc")
|
||||||
|
blob, _ = store.GetBySHA256(ctx, "abc")
|
||||||
|
if blob.RefCount != 2 {
|
||||||
|
t.Errorf("RefCount after 2nd increment = %d, want 2", blob.RefCount)
|
||||||
|
}
|
||||||
|
|
||||||
|
refCount, err := store.DecrementRef(ctx, "abc")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("DecrementRef() error = %v", err)
|
||||||
|
}
|
||||||
|
if refCount != 1 {
|
||||||
|
t.Errorf("DecrementRef() returned %d, want 1", refCount)
|
||||||
|
}
|
||||||
|
|
||||||
|
refCount, _ = store.DecrementRef(ctx, "abc")
|
||||||
|
if refCount != 0 {
|
||||||
|
t.Errorf("DecrementRef() returned %d, want 0", refCount)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBlobStore_Delete(t *testing.T) {
|
||||||
|
db := setupBlobTestDB(t)
|
||||||
|
store := NewBlobStore(db)
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
store.Create(ctx, &model.FileBlob{SHA256: "abc", MinioKey: "files/abc", FileSize: 100, RefCount: 0})
|
||||||
|
|
||||||
|
if err := store.Delete(ctx, "abc"); err != nil {
|
||||||
|
t.Fatalf("Delete() error = %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
blob, _ := store.GetBySHA256(ctx, "abc")
|
||||||
|
if blob != nil {
|
||||||
|
t.Error("Delete() did not remove blob")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBlobStore_GetBySHA256s(t *testing.T) {
|
||||||
|
db := setupBlobTestDB(t)
|
||||||
|
store := NewBlobStore(db)
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
store.Create(ctx, &model.FileBlob{SHA256: "h1", MinioKey: "files/h1", FileSize: 100})
|
||||||
|
store.Create(ctx, &model.FileBlob{SHA256: "h2", MinioKey: "files/h2", FileSize: 200})
|
||||||
|
store.Create(ctx, &model.FileBlob{SHA256: "h3", MinioKey: "files/h3", FileSize: 300})
|
||||||
|
|
||||||
|
blobs, err := store.GetBySHA256s(ctx, []string{"h1", "h3"})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("GetBySHA256s() error = %v", err)
|
||||||
|
}
|
||||||
|
if len(blobs) != 2 {
|
||||||
|
t.Fatalf("len(blobs) = %d, want 2", len(blobs))
|
||||||
|
}
|
||||||
|
keys := map[string]bool{}
|
||||||
|
for _, b := range blobs {
|
||||||
|
keys[b.SHA256] = true
|
||||||
|
}
|
||||||
|
if !keys["h1"] || !keys["h3"] {
|
||||||
|
t.Errorf("expected h1 and h3, got %v", blobs)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBlobStore_GetBySHA256s_Empty(t *testing.T) {
|
||||||
|
db := setupBlobTestDB(t)
|
||||||
|
store := NewBlobStore(db)
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
blobs, err := store.GetBySHA256s(ctx, []string{})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("GetBySHA256s() error = %v", err)
|
||||||
|
}
|
||||||
|
if len(blobs) != 0 {
|
||||||
|
t.Errorf("len(blobs) = %d, want 0", len(blobs))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBlobStore_GetBySHA256s_NotFound(t *testing.T) {
|
||||||
|
db := setupBlobTestDB(t)
|
||||||
|
store := NewBlobStore(db)
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
blobs, err := store.GetBySHA256s(ctx, []string{"nonexistent"})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("GetBySHA256s() error = %v", err)
|
||||||
|
}
|
||||||
|
if len(blobs) != 0 {
|
||||||
|
t.Errorf("len(blobs) = %d, want 0 for non-existent SHA256s", len(blobs))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBlobStore_SHA256_UniqueConstraint(t *testing.T) {
|
||||||
|
db := setupBlobTestDB(t)
|
||||||
|
store := NewBlobStore(db)
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
store.Create(ctx, &model.FileBlob{SHA256: "dup", MinioKey: "files/dup1", FileSize: 100})
|
||||||
|
err := store.Create(ctx, &model.FileBlob{SHA256: "dup", MinioKey: "files/dup2", FileSize: 200})
|
||||||
|
if err == nil {
|
||||||
|
t.Error("expected error for duplicate SHA256, got nil")
|
||||||
|
}
|
||||||
|
}
|
||||||
108
internal/store/file_store.go
Normal file
108
internal/store/file_store.go
Normal file
@@ -0,0 +1,108 @@
|
|||||||
|
package store
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
|
||||||
|
"gcy_hpc_server/internal/model"
|
||||||
|
|
||||||
|
"gorm.io/gorm"
|
||||||
|
)
|
||||||
|
|
||||||
|
type FileStore struct {
|
||||||
|
db *gorm.DB
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewFileStore(db *gorm.DB) *FileStore {
|
||||||
|
return &FileStore{db: db}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *FileStore) Create(ctx context.Context, file *model.File) error {
|
||||||
|
return s.db.WithContext(ctx).Create(file).Error
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *FileStore) GetByID(ctx context.Context, id int64) (*model.File, error) {
|
||||||
|
var file model.File
|
||||||
|
err := s.db.WithContext(ctx).First(&file, id).Error
|
||||||
|
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &file, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *FileStore) List(ctx context.Context, folderID *int64, page, pageSize int) ([]model.File, int64, error) {
|
||||||
|
query := s.db.WithContext(ctx).Model(&model.File{})
|
||||||
|
if folderID == nil {
|
||||||
|
query = query.Where("folder_id IS NULL")
|
||||||
|
} else {
|
||||||
|
query = query.Where("folder_id = ?", *folderID)
|
||||||
|
}
|
||||||
|
|
||||||
|
var total int64
|
||||||
|
if err := query.Count(&total).Error; err != nil {
|
||||||
|
return nil, 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
var files []model.File
|
||||||
|
offset := (page - 1) * pageSize
|
||||||
|
if err := query.Order("id DESC").Limit(pageSize).Offset(offset).Find(&files).Error; err != nil {
|
||||||
|
return nil, 0, err
|
||||||
|
}
|
||||||
|
return files, total, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *FileStore) Search(ctx context.Context, queryStr string, page, pageSize int) ([]model.File, int64, error) {
|
||||||
|
query := s.db.WithContext(ctx).Model(&model.File{}).Where("name LIKE ?", "%"+queryStr+"%")
|
||||||
|
|
||||||
|
var total int64
|
||||||
|
if err := query.Count(&total).Error; err != nil {
|
||||||
|
return nil, 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
var files []model.File
|
||||||
|
offset := (page - 1) * pageSize
|
||||||
|
if err := query.Order("id DESC").Limit(pageSize).Offset(offset).Find(&files).Error; err != nil {
|
||||||
|
return nil, 0, err
|
||||||
|
}
|
||||||
|
return files, total, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *FileStore) Delete(ctx context.Context, id int64) error {
|
||||||
|
result := s.db.WithContext(ctx).Delete(&model.File{}, id)
|
||||||
|
if result.Error != nil {
|
||||||
|
return result.Error
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *FileStore) CountByBlobSHA256(ctx context.Context, blobSHA256 string) (int64, error) {
|
||||||
|
var count int64
|
||||||
|
err := s.db.WithContext(ctx).Model(&model.File{}).
|
||||||
|
Where("blob_sha256 = ?", blobSHA256).
|
||||||
|
Count(&count).Error
|
||||||
|
return count, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *FileStore) GetByIDs(ctx context.Context, ids []int64) ([]model.File, error) {
|
||||||
|
var files []model.File
|
||||||
|
if len(ids) == 0 {
|
||||||
|
return files, nil
|
||||||
|
}
|
||||||
|
err := s.db.WithContext(ctx).Where("id IN ?", ids).Find(&files).Error
|
||||||
|
return files, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *FileStore) GetBlobSHA256ByID(ctx context.Context, id int64) (string, error) {
|
||||||
|
var file model.File
|
||||||
|
err := s.db.WithContext(ctx).Select("blob_sha256").First(&file, id).Error
|
||||||
|
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||||
|
return "", nil
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
return file.BlobSHA256, nil
|
||||||
|
}
|
||||||
323
internal/store/file_store_test.go
Normal file
323
internal/store/file_store_test.go
Normal file
@@ -0,0 +1,323 @@
|
|||||||
|
package store
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"gcy_hpc_server/internal/model"
|
||||||
|
|
||||||
|
"gorm.io/driver/sqlite"
|
||||||
|
"gorm.io/gorm"
|
||||||
|
)
|
||||||
|
|
||||||
|
func setupFileTestDB(t *testing.T) *gorm.DB {
|
||||||
|
t.Helper()
|
||||||
|
db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("open sqlite: %v", err)
|
||||||
|
}
|
||||||
|
if err := db.AutoMigrate(&model.File{}, &model.FileBlob{}); err != nil {
|
||||||
|
t.Fatalf("migrate: %v", err)
|
||||||
|
}
|
||||||
|
return db
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFileStore_CreateAndGetByID(t *testing.T) {
|
||||||
|
db := setupFileTestDB(t)
|
||||||
|
store := NewFileStore(db)
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
file := &model.File{
|
||||||
|
Name: "test.bin",
|
||||||
|
BlobSHA256: "abc123",
|
||||||
|
}
|
||||||
|
if err := store.Create(ctx, file); err != nil {
|
||||||
|
t.Fatalf("Create() error = %v", err)
|
||||||
|
}
|
||||||
|
if file.ID == 0 {
|
||||||
|
t.Fatal("Create() did not set ID")
|
||||||
|
}
|
||||||
|
|
||||||
|
got, err := store.GetByID(ctx, file.ID)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("GetByID() error = %v", err)
|
||||||
|
}
|
||||||
|
if got == nil {
|
||||||
|
t.Fatal("GetByID() returned nil")
|
||||||
|
}
|
||||||
|
if got.Name != "test.bin" {
|
||||||
|
t.Errorf("Name = %q, want %q", got.Name, "test.bin")
|
||||||
|
}
|
||||||
|
if got.BlobSHA256 != "abc123" {
|
||||||
|
t.Errorf("BlobSHA256 = %q, want %q", got.BlobSHA256, "abc123")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFileStore_GetByID_NotFound(t *testing.T) {
|
||||||
|
db := setupFileTestDB(t)
|
||||||
|
store := NewFileStore(db)
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
got, err := store.GetByID(ctx, 999)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("GetByID() error = %v", err)
|
||||||
|
}
|
||||||
|
if got != nil {
|
||||||
|
t.Error("GetByID() should return nil for not found")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFileStore_ListByFolder(t *testing.T) {
|
||||||
|
db := setupFileTestDB(t)
|
||||||
|
store := NewFileStore(db)
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
folderID := int64(1)
|
||||||
|
store.Create(ctx, &model.File{Name: "f1.bin", BlobSHA256: "a1", FolderID: &folderID})
|
||||||
|
store.Create(ctx, &model.File{Name: "f2.bin", BlobSHA256: "a2", FolderID: &folderID})
|
||||||
|
store.Create(ctx, &model.File{Name: "root.bin", BlobSHA256: "a3"}) // root (folder_id=nil)
|
||||||
|
|
||||||
|
files, total, err := store.List(ctx, &folderID, 1, 10)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("List() error = %v", err)
|
||||||
|
}
|
||||||
|
if total != 2 {
|
||||||
|
t.Errorf("total = %d, want 2", total)
|
||||||
|
}
|
||||||
|
if len(files) != 2 {
|
||||||
|
t.Errorf("len(files) = %d, want 2", len(files))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFileStore_ListRootFolder(t *testing.T) {
|
||||||
|
db := setupFileTestDB(t)
|
||||||
|
store := NewFileStore(db)
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
store.Create(ctx, &model.File{Name: "root.bin", BlobSHA256: "a1"})
|
||||||
|
folderID := int64(1)
|
||||||
|
store.Create(ctx, &model.File{Name: "sub.bin", BlobSHA256: "a2", FolderID: &folderID})
|
||||||
|
|
||||||
|
files, total, err := store.List(ctx, nil, 1, 10)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("List() error = %v", err)
|
||||||
|
}
|
||||||
|
if total != 1 {
|
||||||
|
t.Errorf("total = %d, want 1", total)
|
||||||
|
}
|
||||||
|
if len(files) != 1 {
|
||||||
|
t.Errorf("len(files) = %d, want 1", len(files))
|
||||||
|
}
|
||||||
|
if files[0].Name != "root.bin" {
|
||||||
|
t.Errorf("files[0].Name = %q, want %q", files[0].Name, "root.bin")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFileStore_Pagination(t *testing.T) {
|
||||||
|
db := setupFileTestDB(t)
|
||||||
|
store := NewFileStore(db)
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
for i := 0; i < 25; i++ {
|
||||||
|
store.Create(ctx, &model.File{Name: "file.bin", BlobSHA256: "hash"})
|
||||||
|
}
|
||||||
|
|
||||||
|
files, total, err := store.List(ctx, nil, 1, 10)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("List() error = %v", err)
|
||||||
|
}
|
||||||
|
if total != 25 {
|
||||||
|
t.Errorf("total = %d, want 25", total)
|
||||||
|
}
|
||||||
|
if len(files) != 10 {
|
||||||
|
t.Errorf("page 1 len = %d, want 10", len(files))
|
||||||
|
}
|
||||||
|
|
||||||
|
files, _, _ = store.List(ctx, nil, 3, 10)
|
||||||
|
if len(files) != 5 {
|
||||||
|
t.Errorf("page 3 len = %d, want 5", len(files))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFileStore_Search(t *testing.T) {
|
||||||
|
db := setupFileTestDB(t)
|
||||||
|
store := NewFileStore(db)
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
store.Create(ctx, &model.File{Name: "experiment_results.csv", BlobSHA256: "a1"})
|
||||||
|
store.Create(ctx, &model.File{Name: "training_log.txt", BlobSHA256: "a2"})
|
||||||
|
store.Create(ctx, &model.File{Name: "model_weights.bin", BlobSHA256: "a3"})
|
||||||
|
|
||||||
|
files, total, err := store.Search(ctx, "results", 1, 10)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Search() error = %v", err)
|
||||||
|
}
|
||||||
|
if total != 1 {
|
||||||
|
t.Errorf("total = %d, want 1", total)
|
||||||
|
}
|
||||||
|
if len(files) != 1 || files[0].Name != "experiment_results.csv" {
|
||||||
|
t.Errorf("expected experiment_results.csv, got %v", files)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFileStore_Delete_SoftDelete(t *testing.T) {
|
||||||
|
db := setupFileTestDB(t)
|
||||||
|
store := NewFileStore(db)
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
file := &model.File{Name: "deleteme.bin", BlobSHA256: "abc"}
|
||||||
|
store.Create(ctx, file)
|
||||||
|
|
||||||
|
if err := store.Delete(ctx, file.ID); err != nil {
|
||||||
|
t.Fatalf("Delete() error = %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetByID should return nil (soft deleted)
|
||||||
|
got, err := store.GetByID(ctx, file.ID)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("GetByID() error = %v", err)
|
||||||
|
}
|
||||||
|
if got != nil {
|
||||||
|
t.Error("GetByID() should return nil after soft delete")
|
||||||
|
}
|
||||||
|
|
||||||
|
// List should not include soft deleted
|
||||||
|
_, total, _ := store.List(ctx, nil, 1, 10)
|
||||||
|
if total != 0 {
|
||||||
|
t.Errorf("total after delete = %d, want 0", total)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFileStore_CountByBlobSHA256(t *testing.T) {
|
||||||
|
db := setupFileTestDB(t)
|
||||||
|
store := NewFileStore(db)
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
store.Create(ctx, &model.File{Name: "a.bin", BlobSHA256: "shared_hash"})
|
||||||
|
store.Create(ctx, &model.File{Name: "b.bin", BlobSHA256: "shared_hash"})
|
||||||
|
store.Create(ctx, &model.File{Name: "c.bin", BlobSHA256: "shared_hash"})
|
||||||
|
|
||||||
|
count, err := store.CountByBlobSHA256(ctx, "shared_hash")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("CountByBlobSHA256() error = %v", err)
|
||||||
|
}
|
||||||
|
if count != 3 {
|
||||||
|
t.Errorf("count = %d, want 3", count)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Soft delete one
|
||||||
|
store.Delete(ctx, 1)
|
||||||
|
count, _ = store.CountByBlobSHA256(ctx, "shared_hash")
|
||||||
|
if count != 2 {
|
||||||
|
t.Errorf("count after soft delete = %d, want 2", count)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFileStore_GetBlobSHA256ByID(t *testing.T) {
|
||||||
|
db := setupFileTestDB(t)
|
||||||
|
store := NewFileStore(db)
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
file := &model.File{Name: "test.bin", BlobSHA256: "my_hash"}
|
||||||
|
store.Create(ctx, file)
|
||||||
|
|
||||||
|
sha256, err := store.GetBlobSHA256ByID(ctx, file.ID)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("GetBlobSHA256ByID() error = %v", err)
|
||||||
|
}
|
||||||
|
if sha256 != "my_hash" {
|
||||||
|
t.Errorf("sha256 = %q, want %q", sha256, "my_hash")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFileStore_GetByIDs(t *testing.T) {
|
||||||
|
db := setupFileTestDB(t)
|
||||||
|
store := NewFileStore(db)
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
store.Create(ctx, &model.File{Name: "a.bin", BlobSHA256: "h1"})
|
||||||
|
store.Create(ctx, &model.File{Name: "b.bin", BlobSHA256: "h2"})
|
||||||
|
store.Create(ctx, &model.File{Name: "c.bin", BlobSHA256: "h3"})
|
||||||
|
|
||||||
|
files, err := store.GetByIDs(ctx, []int64{1, 3})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("GetByIDs() error = %v", err)
|
||||||
|
}
|
||||||
|
if len(files) != 2 {
|
||||||
|
t.Fatalf("len(files) = %d, want 2", len(files))
|
||||||
|
}
|
||||||
|
names := map[string]bool{}
|
||||||
|
for _, f := range files {
|
||||||
|
names[f.Name] = true
|
||||||
|
}
|
||||||
|
if !names["a.bin"] || !names["c.bin"] {
|
||||||
|
t.Errorf("expected a.bin and c.bin, got %v", files)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFileStore_GetByIDs_Empty(t *testing.T) {
|
||||||
|
db := setupFileTestDB(t)
|
||||||
|
store := NewFileStore(db)
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
files, err := store.GetByIDs(ctx, []int64{})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("GetByIDs() error = %v", err)
|
||||||
|
}
|
||||||
|
if len(files) != 0 {
|
||||||
|
t.Errorf("len(files) = %d, want 0", len(files))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFileStore_GetByIDs_NotFound(t *testing.T) {
|
||||||
|
db := setupFileTestDB(t)
|
||||||
|
store := NewFileStore(db)
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
files, err := store.GetByIDs(ctx, []int64{999})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("GetByIDs() error = %v", err)
|
||||||
|
}
|
||||||
|
if len(files) != 0 {
|
||||||
|
t.Errorf("len(files) = %d, want 0 for non-existent IDs", len(files))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFileStore_GetByIDs_SoftDeleteExcluded(t *testing.T) {
|
||||||
|
db := setupFileTestDB(t)
|
||||||
|
store := NewFileStore(db)
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
store.Create(ctx, &model.File{Name: "a.bin", BlobSHA256: "h1"})
|
||||||
|
store.Create(ctx, &model.File{Name: "b.bin", BlobSHA256: "h2"})
|
||||||
|
store.Create(ctx, &model.File{Name: "c.bin", BlobSHA256: "h3"})
|
||||||
|
|
||||||
|
store.Delete(ctx, 2)
|
||||||
|
|
||||||
|
files, err := store.GetByIDs(ctx, []int64{1, 2, 3})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("GetByIDs() error = %v", err)
|
||||||
|
}
|
||||||
|
if len(files) != 2 {
|
||||||
|
t.Fatalf("len(files) = %d, want 2 (soft-deleted excluded)", len(files))
|
||||||
|
}
|
||||||
|
for _, f := range files {
|
||||||
|
if f.ID == 2 {
|
||||||
|
t.Error("soft-deleted file ID 2 should not appear")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFileStore_GetBlobSHA256ByID_NotFound(t *testing.T) {
|
||||||
|
db := setupFileTestDB(t)
|
||||||
|
store := NewFileStore(db)
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
sha256, err := store.GetBlobSHA256ByID(ctx, 999)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("GetBlobSHA256ByID() error = %v", err)
|
||||||
|
}
|
||||||
|
if sha256 != "" {
|
||||||
|
t.Errorf("sha256 = %q, want empty for not found", sha256)
|
||||||
|
}
|
||||||
|
}
|
||||||
105
internal/store/folder_store.go
Normal file
105
internal/store/folder_store.go
Normal file
@@ -0,0 +1,105 @@
|
|||||||
|
package store
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
"gcy_hpc_server/internal/model"
|
||||||
|
|
||||||
|
"gorm.io/gorm"
|
||||||
|
)
|
||||||
|
|
||||||
|
type FolderStore struct {
|
||||||
|
db *gorm.DB
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewFolderStore(db *gorm.DB) *FolderStore {
|
||||||
|
return &FolderStore{db: db}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *FolderStore) Create(ctx context.Context, folder *model.Folder) error {
|
||||||
|
return s.db.WithContext(ctx).Create(folder).Error
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *FolderStore) GetByID(ctx context.Context, id int64) (*model.Folder, error) {
|
||||||
|
var folder model.Folder
|
||||||
|
err := s.db.WithContext(ctx).First(&folder, id).Error
|
||||||
|
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &folder, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *FolderStore) GetByPath(ctx context.Context, path string) (*model.Folder, error) {
|
||||||
|
var folder model.Folder
|
||||||
|
err := s.db.WithContext(ctx).Where("path = ?", path).First(&folder).Error
|
||||||
|
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &folder, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *FolderStore) ListByParentID(ctx context.Context, parentID *int64) ([]model.Folder, error) {
|
||||||
|
var folders []model.Folder
|
||||||
|
query := s.db.WithContext(ctx)
|
||||||
|
if parentID == nil {
|
||||||
|
query = query.Where("parent_id IS NULL")
|
||||||
|
} else {
|
||||||
|
query = query.Where("parent_id = ?", *parentID)
|
||||||
|
}
|
||||||
|
if err := query.Order("name ASC").Find(&folders).Error; err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return folders, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetSubTree returns all folders whose path starts with the given prefix.
|
||||||
|
func (s *FolderStore) GetSubTree(ctx context.Context, path string) ([]model.Folder, error) {
|
||||||
|
var folders []model.Folder
|
||||||
|
if err := s.db.WithContext(ctx).Where("path LIKE ?", path+"%").Find(&folders).Error; err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return folders, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// HasChildren checks if a folder has sub-folders or files.
|
||||||
|
func (s *FolderStore) HasChildren(ctx context.Context, id int64) (bool, error) {
|
||||||
|
folder, err := s.GetByID(ctx, id)
|
||||||
|
if err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
if folder == nil {
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check for sub-folders
|
||||||
|
var subFolderCount int64
|
||||||
|
if err := s.db.WithContext(ctx).Model(&model.Folder{}).Where("parent_id = ?", id).Count(&subFolderCount).Error; err != nil {
|
||||||
|
return false, fmt.Errorf("count sub-folders: %w", err)
|
||||||
|
}
|
||||||
|
if subFolderCount > 0 {
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check for files
|
||||||
|
var fileCount int64
|
||||||
|
if err := s.db.WithContext(ctx).Model(&model.File{}).Where("folder_id = ?", id).Count(&fileCount).Error; err != nil {
|
||||||
|
return false, fmt.Errorf("count files: %w", err)
|
||||||
|
}
|
||||||
|
return fileCount > 0, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *FolderStore) Delete(ctx context.Context, id int64) error {
|
||||||
|
result := s.db.WithContext(ctx).Delete(&model.Folder{}, id)
|
||||||
|
if result.Error != nil {
|
||||||
|
return result.Error
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
294
internal/store/folder_store_test.go
Normal file
294
internal/store/folder_store_test.go
Normal file
@@ -0,0 +1,294 @@
|
|||||||
|
package store
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"gcy_hpc_server/internal/model"
|
||||||
|
|
||||||
|
"gorm.io/driver/sqlite"
|
||||||
|
"gorm.io/gorm"
|
||||||
|
"gorm.io/gorm/logger"
|
||||||
|
)
|
||||||
|
|
||||||
|
func setupFolderTestDB(t *testing.T) *gorm.DB {
|
||||||
|
t.Helper()
|
||||||
|
db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{
|
||||||
|
Logger: logger.Default.LogMode(logger.Silent),
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("open sqlite: %v", err)
|
||||||
|
}
|
||||||
|
if err := db.AutoMigrate(&model.Folder{}, &model.File{}); err != nil {
|
||||||
|
t.Fatalf("migrate: %v", err)
|
||||||
|
}
|
||||||
|
return db
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFolderStore_CreateAndGetByID(t *testing.T) {
|
||||||
|
db := setupFolderTestDB(t)
|
||||||
|
s := NewFolderStore(db)
|
||||||
|
|
||||||
|
folder := &model.Folder{
|
||||||
|
Name: "data",
|
||||||
|
Path: "/data/",
|
||||||
|
}
|
||||||
|
if err := s.Create(context.Background(), folder); err != nil {
|
||||||
|
t.Fatalf("Create() error = %v", err)
|
||||||
|
}
|
||||||
|
if folder.ID <= 0 {
|
||||||
|
t.Fatalf("Create() id = %d, want positive", folder.ID)
|
||||||
|
}
|
||||||
|
|
||||||
|
got, err := s.GetByID(context.Background(), folder.ID)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("GetByID() error = %v", err)
|
||||||
|
}
|
||||||
|
if got == nil {
|
||||||
|
t.Fatal("GetByID() returned nil")
|
||||||
|
}
|
||||||
|
if got.Name != "data" {
|
||||||
|
t.Errorf("Name = %q, want %q", got.Name, "data")
|
||||||
|
}
|
||||||
|
if got.Path != "/data/" {
|
||||||
|
t.Errorf("Path = %q, want %q", got.Path, "/data/")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFolderStore_GetByID_NotFound(t *testing.T) {
|
||||||
|
db := setupFolderTestDB(t)
|
||||||
|
s := NewFolderStore(db)
|
||||||
|
|
||||||
|
got, err := s.GetByID(context.Background(), 99999)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("GetByID() error = %v", err)
|
||||||
|
}
|
||||||
|
if got != nil {
|
||||||
|
t.Error("GetByID() expected nil for not-found")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFolderStore_GetByPath(t *testing.T) {
|
||||||
|
db := setupFolderTestDB(t)
|
||||||
|
s := NewFolderStore(db)
|
||||||
|
|
||||||
|
folder := &model.Folder{
|
||||||
|
Name: "data",
|
||||||
|
Path: "/data/",
|
||||||
|
}
|
||||||
|
if err := s.Create(context.Background(), folder); err != nil {
|
||||||
|
t.Fatalf("Create() error = %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
got, err := s.GetByPath(context.Background(), "/data/")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("GetByPath() error = %v", err)
|
||||||
|
}
|
||||||
|
if got == nil {
|
||||||
|
t.Fatal("GetByPath() returned nil")
|
||||||
|
}
|
||||||
|
if got.ID != folder.ID {
|
||||||
|
t.Errorf("ID = %d, want %d", got.ID, folder.ID)
|
||||||
|
}
|
||||||
|
|
||||||
|
got, err = s.GetByPath(context.Background(), "/nonexistent/")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("GetByPath() nonexistent error = %v", err)
|
||||||
|
}
|
||||||
|
if got != nil {
|
||||||
|
t.Error("GetByPath() expected nil for nonexistent path")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFolderStore_ListByParentID(t *testing.T) {
|
||||||
|
db := setupFolderTestDB(t)
|
||||||
|
s := NewFolderStore(db)
|
||||||
|
|
||||||
|
root1 := &model.Folder{Name: "alpha", Path: "/alpha/"}
|
||||||
|
root2 := &model.Folder{Name: "beta", Path: "/beta/"}
|
||||||
|
if err := s.Create(context.Background(), root1); err != nil {
|
||||||
|
t.Fatalf("Create root1: %v", err)
|
||||||
|
}
|
||||||
|
if err := s.Create(context.Background(), root2); err != nil {
|
||||||
|
t.Fatalf("Create root2: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
sub := &model.Folder{Name: "sub", Path: "/alpha/sub/", ParentID: &root1.ID}
|
||||||
|
if err := s.Create(context.Background(), sub); err != nil {
|
||||||
|
t.Fatalf("Create sub: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
roots, err := s.ListByParentID(context.Background(), nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("ListByParentID(nil) error = %v", err)
|
||||||
|
}
|
||||||
|
if len(roots) != 2 {
|
||||||
|
t.Fatalf("root folders = %d, want 2", len(roots))
|
||||||
|
}
|
||||||
|
if roots[0].Name != "alpha" {
|
||||||
|
t.Errorf("roots[0].Name = %q, want %q (alphabetical)", roots[0].Name, "alpha")
|
||||||
|
}
|
||||||
|
|
||||||
|
children, err := s.ListByParentID(context.Background(), &root1.ID)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("ListByParentID(root1) error = %v", err)
|
||||||
|
}
|
||||||
|
if len(children) != 1 {
|
||||||
|
t.Fatalf("children = %d, want 1", len(children))
|
||||||
|
}
|
||||||
|
if children[0].Name != "sub" {
|
||||||
|
t.Errorf("children[0].Name = %q, want %q", children[0].Name, "sub")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFolderStore_GetSubTree(t *testing.T) {
|
||||||
|
db := setupFolderTestDB(t)
|
||||||
|
s := NewFolderStore(db)
|
||||||
|
|
||||||
|
data := &model.Folder{Name: "data", Path: "/data/"}
|
||||||
|
if err := s.Create(context.Background(), data); err != nil {
|
||||||
|
t.Fatalf("Create data: %v", err)
|
||||||
|
}
|
||||||
|
results := &model.Folder{Name: "results", Path: "/data/results/", ParentID: &data.ID}
|
||||||
|
if err := s.Create(context.Background(), results); err != nil {
|
||||||
|
t.Fatalf("Create results: %v", err)
|
||||||
|
}
|
||||||
|
other := &model.Folder{Name: "other", Path: "/other/"}
|
||||||
|
if err := s.Create(context.Background(), other); err != nil {
|
||||||
|
t.Fatalf("Create other: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
subtree, err := s.GetSubTree(context.Background(), "/data/")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("GetSubTree() error = %v", err)
|
||||||
|
}
|
||||||
|
if len(subtree) != 2 {
|
||||||
|
t.Fatalf("subtree = %d, want 2", len(subtree))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFolderStore_HasChildren_WithSubFolders(t *testing.T) {
|
||||||
|
db := setupFolderTestDB(t)
|
||||||
|
s := NewFolderStore(db)
|
||||||
|
|
||||||
|
parent := &model.Folder{Name: "parent", Path: "/parent/"}
|
||||||
|
if err := s.Create(context.Background(), parent); err != nil {
|
||||||
|
t.Fatalf("Create parent: %v", err)
|
||||||
|
}
|
||||||
|
child := &model.Folder{Name: "child", Path: "/parent/child/", ParentID: &parent.ID}
|
||||||
|
if err := s.Create(context.Background(), child); err != nil {
|
||||||
|
t.Fatalf("Create child: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
has, err := s.HasChildren(context.Background(), parent.ID)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("HasChildren() error = %v", err)
|
||||||
|
}
|
||||||
|
if !has {
|
||||||
|
t.Error("HasChildren() = false, want true (has sub-folders)")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFolderStore_HasChildren_WithFiles(t *testing.T) {
|
||||||
|
db := setupFolderTestDB(t)
|
||||||
|
s := NewFolderStore(db)
|
||||||
|
|
||||||
|
folder := &model.Folder{Name: "docs", Path: "/docs/"}
|
||||||
|
if err := s.Create(context.Background(), folder); err != nil {
|
||||||
|
t.Fatalf("Create folder: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
file := &model.File{
|
||||||
|
Name: "readme.txt",
|
||||||
|
FolderID: &folder.ID,
|
||||||
|
BlobSHA256: "abc123",
|
||||||
|
}
|
||||||
|
if err := db.Create(file).Error; err != nil {
|
||||||
|
t.Fatalf("Create file: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
has, err := s.HasChildren(context.Background(), folder.ID)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("HasChildren() error = %v", err)
|
||||||
|
}
|
||||||
|
if !has {
|
||||||
|
t.Error("HasChildren() = false, want true (has files)")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFolderStore_HasChildren_Empty(t *testing.T) {
|
||||||
|
db := setupFolderTestDB(t)
|
||||||
|
s := NewFolderStore(db)
|
||||||
|
|
||||||
|
folder := &model.Folder{Name: "empty", Path: "/empty/"}
|
||||||
|
if err := s.Create(context.Background(), folder); err != nil {
|
||||||
|
t.Fatalf("Create folder: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
has, err := s.HasChildren(context.Background(), folder.ID)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("HasChildren() error = %v", err)
|
||||||
|
}
|
||||||
|
if has {
|
||||||
|
t.Error("HasChildren() = true, want false (empty folder)")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFolderStore_HasChildren_NotFound(t *testing.T) {
|
||||||
|
db := setupFolderTestDB(t)
|
||||||
|
s := NewFolderStore(db)
|
||||||
|
|
||||||
|
has, err := s.HasChildren(context.Background(), 99999)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("HasChildren() error = %v", err)
|
||||||
|
}
|
||||||
|
if has {
|
||||||
|
t.Error("HasChildren() = true for nonexistent, want false")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFolderStore_Delete(t *testing.T) {
|
||||||
|
db := setupFolderTestDB(t)
|
||||||
|
s := NewFolderStore(db)
|
||||||
|
|
||||||
|
folder := &model.Folder{Name: "temp", Path: "/temp/"}
|
||||||
|
if err := s.Create(context.Background(), folder); err != nil {
|
||||||
|
t.Fatalf("Create() error = %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := s.Delete(context.Background(), folder.ID); err != nil {
|
||||||
|
t.Fatalf("Delete() error = %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
got, err := s.GetByID(context.Background(), folder.ID)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("GetByID() after delete error = %v", err)
|
||||||
|
}
|
||||||
|
if got != nil {
|
||||||
|
t.Error("GetByID() after delete returned non-nil, expected soft-deleted")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFolderStore_Delete_Idempotent(t *testing.T) {
|
||||||
|
db := setupFolderTestDB(t)
|
||||||
|
s := NewFolderStore(db)
|
||||||
|
|
||||||
|
if err := s.Delete(context.Background(), 99999); err != nil {
|
||||||
|
t.Fatalf("Delete() non-existent error = %v, want nil (idempotent)", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFolderStore_Path_UniqueConstraint(t *testing.T) {
|
||||||
|
db := setupFolderTestDB(t)
|
||||||
|
s := NewFolderStore(db)
|
||||||
|
|
||||||
|
f1 := &model.Folder{Name: "data", Path: "/data/"}
|
||||||
|
if err := s.Create(context.Background(), f1); err != nil {
|
||||||
|
t.Fatalf("Create first: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
f2 := &model.Folder{Name: "data2", Path: "/data/"}
|
||||||
|
if err := s.Create(context.Background(), f2); err == nil {
|
||||||
|
t.Fatal("expected error for duplicate path, got nil")
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -1 +0,0 @@
|
|||||||
DROP TABLE IF EXISTS job_templates;
|
|
||||||
@@ -1,14 +0,0 @@
|
|||||||
CREATE TABLE IF NOT EXISTS job_templates (
|
|
||||||
id BIGINT UNSIGNED AUTO_INCREMENT PRIMARY KEY,
|
|
||||||
name VARCHAR(255) NOT NULL,
|
|
||||||
description TEXT,
|
|
||||||
script TEXT NOT NULL,
|
|
||||||
partition VARCHAR(255),
|
|
||||||
qos VARCHAR(255),
|
|
||||||
cpus INT UNSIGNED,
|
|
||||||
memory VARCHAR(50),
|
|
||||||
time_limit VARCHAR(50),
|
|
||||||
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
|
||||||
updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
|
|
||||||
UNIQUE KEY idx_name (name)
|
|
||||||
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;
|
|
||||||
@@ -40,5 +40,13 @@ func NewGormDB(dsn string, zapLogger *zap.Logger, gormLevel string) (*gorm.DB, e
|
|||||||
|
|
||||||
// AutoMigrate runs GORM auto-migration for all models.
|
// AutoMigrate runs GORM auto-migration for all models.
|
||||||
func AutoMigrate(db *gorm.DB) error {
|
func AutoMigrate(db *gorm.DB) error {
|
||||||
return db.AutoMigrate(&model.JobTemplate{})
|
return db.AutoMigrate(
|
||||||
|
&model.Application{},
|
||||||
|
&model.FileBlob{},
|
||||||
|
&model.File{},
|
||||||
|
&model.Folder{},
|
||||||
|
&model.UploadSession{},
|
||||||
|
&model.UploadChunk{},
|
||||||
|
&model.Task{},
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|||||||
141
internal/store/task_store.go
Normal file
141
internal/store/task_store.go
Normal file
@@ -0,0 +1,141 @@
|
|||||||
|
package store
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"gcy_hpc_server/internal/model"
|
||||||
|
|
||||||
|
"gorm.io/gorm"
|
||||||
|
)
|
||||||
|
|
||||||
|
type TaskStore struct {
|
||||||
|
db *gorm.DB
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewTaskStore(db *gorm.DB) *TaskStore {
|
||||||
|
return &TaskStore{db: db}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *TaskStore) Create(ctx context.Context, task *model.Task) (int64, error) {
|
||||||
|
if err := s.db.WithContext(ctx).Create(task).Error; err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
return task.ID, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *TaskStore) GetByID(ctx context.Context, id int64) (*model.Task, error) {
|
||||||
|
var task model.Task
|
||||||
|
err := s.db.WithContext(ctx).First(&task, id).Error
|
||||||
|
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &task, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *TaskStore) List(ctx context.Context, query *model.TaskListQuery) ([]model.Task, int64, error) {
|
||||||
|
page := query.Page
|
||||||
|
pageSize := query.PageSize
|
||||||
|
if page <= 0 {
|
||||||
|
page = 1
|
||||||
|
}
|
||||||
|
if pageSize <= 0 {
|
||||||
|
pageSize = 10
|
||||||
|
}
|
||||||
|
|
||||||
|
q := s.db.WithContext(ctx).Model(&model.Task{})
|
||||||
|
if query.Status != "" {
|
||||||
|
q = q.Where("status = ?", query.Status)
|
||||||
|
}
|
||||||
|
|
||||||
|
var total int64
|
||||||
|
if err := q.Count(&total).Error; err != nil {
|
||||||
|
return nil, 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
var tasks []model.Task
|
||||||
|
offset := (page - 1) * pageSize
|
||||||
|
if err := q.Order("id DESC").Limit(pageSize).Offset(offset).Find(&tasks).Error; err != nil {
|
||||||
|
return nil, 0, err
|
||||||
|
}
|
||||||
|
return tasks, total, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *TaskStore) UpdateStatus(ctx context.Context, id int64, status, errorMsg string) error {
|
||||||
|
updates := map[string]interface{}{
|
||||||
|
"status": status,
|
||||||
|
"error_message": errorMsg,
|
||||||
|
}
|
||||||
|
|
||||||
|
now := time.Now()
|
||||||
|
switch status {
|
||||||
|
case model.TaskStatusPreparing, model.TaskStatusDownloading, model.TaskStatusReady,
|
||||||
|
model.TaskStatusQueued, model.TaskStatusRunning:
|
||||||
|
updates["started_at"] = &now
|
||||||
|
case model.TaskStatusCompleted, model.TaskStatusFailed:
|
||||||
|
updates["finished_at"] = &now
|
||||||
|
}
|
||||||
|
|
||||||
|
result := s.db.WithContext(ctx).Model(&model.Task{}).Where("id = ?", id).Updates(updates)
|
||||||
|
if result.Error != nil {
|
||||||
|
return result.Error
|
||||||
|
}
|
||||||
|
if result.RowsAffected == 0 {
|
||||||
|
return fmt.Errorf("task %d not found", id)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *TaskStore) UpdateSlurmJobID(ctx context.Context, id int64, slurmJobID *int32) error {
|
||||||
|
result := s.db.WithContext(ctx).Model(&model.Task{}).Where("id = ?", id).
|
||||||
|
Update("slurm_job_id", slurmJobID)
|
||||||
|
if result.Error != nil {
|
||||||
|
return result.Error
|
||||||
|
}
|
||||||
|
if result.RowsAffected == 0 {
|
||||||
|
return fmt.Errorf("task %d not found", id)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *TaskStore) UpdateWorkDir(ctx context.Context, id int64, workDir string) error {
|
||||||
|
result := s.db.WithContext(ctx).Model(&model.Task{}).Where("id = ?", id).
|
||||||
|
Update("work_dir", workDir)
|
||||||
|
if result.Error != nil {
|
||||||
|
return result.Error
|
||||||
|
}
|
||||||
|
if result.RowsAffected == 0 {
|
||||||
|
return fmt.Errorf("task %d not found", id)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *TaskStore) UpdateRetryState(ctx context.Context, id int64, status, currentStep string, retryCount int) error {
|
||||||
|
result := s.db.WithContext(ctx).Model(&model.Task{}).Where("id = ?", id).Updates(map[string]interface{}{
|
||||||
|
"status": status,
|
||||||
|
"current_step": currentStep,
|
||||||
|
"retry_count": retryCount,
|
||||||
|
})
|
||||||
|
if result.Error != nil {
|
||||||
|
return result.Error
|
||||||
|
}
|
||||||
|
if result.RowsAffected == 0 {
|
||||||
|
return fmt.Errorf("task %d not found", id)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *TaskStore) GetStuckTasks(ctx context.Context, maxAge time.Duration) ([]model.Task, error) {
|
||||||
|
cutoff := time.Now().Add(-maxAge)
|
||||||
|
var tasks []model.Task
|
||||||
|
err := s.db.WithContext(ctx).
|
||||||
|
Where("status NOT IN ?", []string{model.TaskStatusCompleted, model.TaskStatusFailed}).
|
||||||
|
Where("updated_at < ?", cutoff).
|
||||||
|
Find(&tasks).Error
|
||||||
|
return tasks, err
|
||||||
|
}
|
||||||
229
internal/store/task_store_test.go
Normal file
229
internal/store/task_store_test.go
Normal file
@@ -0,0 +1,229 @@
|
|||||||
|
package store
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"gcy_hpc_server/internal/model"
|
||||||
|
|
||||||
|
"gorm.io/driver/sqlite"
|
||||||
|
"gorm.io/gorm"
|
||||||
|
"gorm.io/gorm/logger"
|
||||||
|
)
|
||||||
|
|
||||||
|
func newTaskTestDB(t *testing.T) *gorm.DB {
|
||||||
|
t.Helper()
|
||||||
|
db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{
|
||||||
|
Logger: logger.Default.LogMode(logger.Silent),
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("open sqlite: %v", err)
|
||||||
|
}
|
||||||
|
if err := db.AutoMigrate(&model.Task{}); err != nil {
|
||||||
|
t.Fatalf("auto migrate: %v", err)
|
||||||
|
}
|
||||||
|
return db
|
||||||
|
}
|
||||||
|
|
||||||
|
func makeTestTask(name, status string) *model.Task {
|
||||||
|
return &model.Task{
|
||||||
|
TaskName: name,
|
||||||
|
AppID: 1,
|
||||||
|
AppName: "test-app",
|
||||||
|
Status: status,
|
||||||
|
CurrentStep: "",
|
||||||
|
RetryCount: 0,
|
||||||
|
UserID: "user1",
|
||||||
|
SubmittedAt: time.Now(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTaskStore_CreateAndGetByID(t *testing.T) {
|
||||||
|
db := newTaskTestDB(t)
|
||||||
|
s := NewTaskStore(db)
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
task := makeTestTask("test-task", model.TaskStatusSubmitted)
|
||||||
|
id, err := s.Create(ctx, task)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Create() error = %v", err)
|
||||||
|
}
|
||||||
|
if id <= 0 {
|
||||||
|
t.Errorf("Create() id = %d, want positive", id)
|
||||||
|
}
|
||||||
|
|
||||||
|
got, err := s.GetByID(ctx, id)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("GetByID() error = %v", err)
|
||||||
|
}
|
||||||
|
if got == nil {
|
||||||
|
t.Fatal("GetByID() returned nil")
|
||||||
|
}
|
||||||
|
if got.TaskName != "test-task" {
|
||||||
|
t.Errorf("TaskName = %q, want %q", got.TaskName, "test-task")
|
||||||
|
}
|
||||||
|
if got.Status != model.TaskStatusSubmitted {
|
||||||
|
t.Errorf("Status = %q, want %q", got.Status, model.TaskStatusSubmitted)
|
||||||
|
}
|
||||||
|
if got.ID != id {
|
||||||
|
t.Errorf("ID = %d, want %d", got.ID, id)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTaskStore_GetByID_NotFound(t *testing.T) {
|
||||||
|
db := newTaskTestDB(t)
|
||||||
|
s := NewTaskStore(db)
|
||||||
|
|
||||||
|
got, err := s.GetByID(context.Background(), 999)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("GetByID() error = %v", err)
|
||||||
|
}
|
||||||
|
if got != nil {
|
||||||
|
t.Error("GetByID() expected nil for not-found, got non-nil")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTaskStore_ListPagination(t *testing.T) {
|
||||||
|
db := newTaskTestDB(t)
|
||||||
|
s := NewTaskStore(db)
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
for i := 0; i < 5; i++ {
|
||||||
|
s.Create(ctx, makeTestTask("task-"+string(rune('A'+i)), model.TaskStatusSubmitted))
|
||||||
|
}
|
||||||
|
|
||||||
|
tasks, total, err := s.List(ctx, &model.TaskListQuery{Page: 1, PageSize: 3})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("List() error = %v", err)
|
||||||
|
}
|
||||||
|
if total != 5 {
|
||||||
|
t.Errorf("total = %d, want 5", total)
|
||||||
|
}
|
||||||
|
if len(tasks) != 3 {
|
||||||
|
t.Errorf("len(tasks) = %d, want 3", len(tasks))
|
||||||
|
}
|
||||||
|
|
||||||
|
tasks2, total2, err := s.List(ctx, &model.TaskListQuery{Page: 2, PageSize: 3})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("List() page 2 error = %v", err)
|
||||||
|
}
|
||||||
|
if total2 != 5 {
|
||||||
|
t.Errorf("total2 = %d, want 5", total2)
|
||||||
|
}
|
||||||
|
if len(tasks2) != 2 {
|
||||||
|
t.Errorf("len(tasks2) = %d, want 2", len(tasks2))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTaskStore_ListStatusFilter(t *testing.T) {
|
||||||
|
db := newTaskTestDB(t)
|
||||||
|
s := NewTaskStore(db)
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
s.Create(ctx, makeTestTask("running-1", model.TaskStatusRunning))
|
||||||
|
s.Create(ctx, makeTestTask("running-2", model.TaskStatusRunning))
|
||||||
|
s.Create(ctx, makeTestTask("completed-1", model.TaskStatusCompleted))
|
||||||
|
|
||||||
|
tasks, total, err := s.List(ctx, &model.TaskListQuery{Page: 1, PageSize: 10, Status: model.TaskStatusRunning})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("List() error = %v", err)
|
||||||
|
}
|
||||||
|
if total != 2 {
|
||||||
|
t.Errorf("total = %d, want 2", total)
|
||||||
|
}
|
||||||
|
for _, t2 := range tasks {
|
||||||
|
if t2.Status != model.TaskStatusRunning {
|
||||||
|
t.Errorf("Status = %q, want %q", t2.Status, model.TaskStatusRunning)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTaskStore_UpdateStatus(t *testing.T) {
|
||||||
|
db := newTaskTestDB(t)
|
||||||
|
s := NewTaskStore(db)
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
id, _ := s.Create(ctx, makeTestTask("status-test", model.TaskStatusSubmitted))
|
||||||
|
|
||||||
|
err := s.UpdateStatus(ctx, id, model.TaskStatusPreparing, "")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("UpdateStatus() error = %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
got, _ := s.GetByID(ctx, id)
|
||||||
|
if got.Status != model.TaskStatusPreparing {
|
||||||
|
t.Errorf("Status = %q, want %q", got.Status, model.TaskStatusPreparing)
|
||||||
|
}
|
||||||
|
if got.StartedAt == nil {
|
||||||
|
t.Error("StartedAt expected non-nil after preparing status")
|
||||||
|
}
|
||||||
|
|
||||||
|
err = s.UpdateStatus(ctx, id, model.TaskStatusFailed, "something broke")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("UpdateStatus(failed) error = %v", err)
|
||||||
|
}
|
||||||
|
got, _ = s.GetByID(ctx, id)
|
||||||
|
if got.ErrorMessage != "something broke" {
|
||||||
|
t.Errorf("ErrorMessage = %q, want %q", got.ErrorMessage, "something broke")
|
||||||
|
}
|
||||||
|
if got.FinishedAt == nil {
|
||||||
|
t.Error("FinishedAt expected non-nil after failed status")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTaskStore_UpdateRetryState(t *testing.T) {
|
||||||
|
db := newTaskTestDB(t)
|
||||||
|
s := NewTaskStore(db)
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
task := makeTestTask("retry-test", model.TaskStatusFailed)
|
||||||
|
task.CurrentStep = model.TaskStepDownloading
|
||||||
|
task.RetryCount = 1
|
||||||
|
id, _ := s.Create(ctx, task)
|
||||||
|
|
||||||
|
err := s.UpdateRetryState(ctx, id, model.TaskStatusSubmitted, model.TaskStepDownloading, 2)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("UpdateRetryState() error = %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
got, _ := s.GetByID(ctx, id)
|
||||||
|
if got.Status != model.TaskStatusSubmitted {
|
||||||
|
t.Errorf("Status = %q, want %q", got.Status, model.TaskStatusSubmitted)
|
||||||
|
}
|
||||||
|
if got.CurrentStep != model.TaskStepDownloading {
|
||||||
|
t.Errorf("CurrentStep = %q, want %q", got.CurrentStep, model.TaskStepDownloading)
|
||||||
|
}
|
||||||
|
if got.RetryCount != 2 {
|
||||||
|
t.Errorf("RetryCount = %d, want 2", got.RetryCount)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTaskStore_GetStuckTasks(t *testing.T) {
|
||||||
|
db := newTaskTestDB(t)
|
||||||
|
s := NewTaskStore(db)
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
stuck := makeTestTask("stuck-1", model.TaskStatusDownloading)
|
||||||
|
stuck.UpdatedAt = time.Now().Add(-1 * time.Hour)
|
||||||
|
s.Create(ctx, stuck)
|
||||||
|
|
||||||
|
recent := makeTestTask("recent-1", model.TaskStatusDownloading)
|
||||||
|
s.Create(ctx, recent)
|
||||||
|
|
||||||
|
done := makeTestTask("done-1", model.TaskStatusCompleted)
|
||||||
|
done.UpdatedAt = time.Now().Add(-2 * time.Hour)
|
||||||
|
s.Create(ctx, done)
|
||||||
|
|
||||||
|
tasks, err := s.GetStuckTasks(ctx, 30*time.Minute)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("GetStuckTasks() error = %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(tasks) != 1 {
|
||||||
|
t.Fatalf("len(tasks) = %d, want 1", len(tasks))
|
||||||
|
}
|
||||||
|
if tasks[0].TaskName != "stuck-1" {
|
||||||
|
t.Errorf("TaskName = %q, want %q", tasks[0].TaskName, "stuck-1")
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -1,113 +0,0 @@
|
|||||||
package store
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"errors"
|
|
||||||
|
|
||||||
"gorm.io/gorm"
|
|
||||||
|
|
||||||
"gcy_hpc_server/internal/model"
|
|
||||||
)
|
|
||||||
|
|
||||||
// TemplateStore provides CRUD operations for job templates via GORM.
|
|
||||||
type TemplateStore struct {
|
|
||||||
db *gorm.DB
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewTemplateStore creates a new TemplateStore.
|
|
||||||
func NewTemplateStore(db *gorm.DB) *TemplateStore {
|
|
||||||
return &TemplateStore{db: db}
|
|
||||||
}
|
|
||||||
|
|
||||||
// List returns a paginated list of job templates and the total count.
|
|
||||||
func (s *TemplateStore) List(ctx context.Context, page, pageSize int) ([]model.JobTemplate, int, error) {
|
|
||||||
var templates []model.JobTemplate
|
|
||||||
var total int64
|
|
||||||
|
|
||||||
if err := s.db.WithContext(ctx).Model(&model.JobTemplate{}).Count(&total).Error; err != nil {
|
|
||||||
return nil, 0, err
|
|
||||||
}
|
|
||||||
|
|
||||||
offset := (page - 1) * pageSize
|
|
||||||
if err := s.db.WithContext(ctx).Order("id DESC").Limit(pageSize).Offset(offset).Find(&templates).Error; err != nil {
|
|
||||||
return nil, 0, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return templates, int(total), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetByID returns a single job template by ID. Returns nil, nil when not found.
|
|
||||||
func (s *TemplateStore) GetByID(ctx context.Context, id int64) (*model.JobTemplate, error) {
|
|
||||||
var t model.JobTemplate
|
|
||||||
err := s.db.WithContext(ctx).First(&t, id).Error
|
|
||||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
|
||||||
return nil, nil
|
|
||||||
}
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return &t, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Create inserts a new job template and returns the generated ID.
|
|
||||||
func (s *TemplateStore) Create(ctx context.Context, req *model.CreateTemplateRequest) (int64, error) {
|
|
||||||
t := &model.JobTemplate{
|
|
||||||
Name: req.Name,
|
|
||||||
Description: req.Description,
|
|
||||||
Script: req.Script,
|
|
||||||
Partition: req.Partition,
|
|
||||||
QOS: req.QOS,
|
|
||||||
CPUs: req.CPUs,
|
|
||||||
Memory: req.Memory,
|
|
||||||
TimeLimit: req.TimeLimit,
|
|
||||||
}
|
|
||||||
if err := s.db.WithContext(ctx).Create(t).Error; err != nil {
|
|
||||||
return 0, err
|
|
||||||
}
|
|
||||||
return t.ID, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Update modifies an existing job template. Only non-empty/non-zero fields are updated.
|
|
||||||
func (s *TemplateStore) Update(ctx context.Context, id int64, req *model.UpdateTemplateRequest) error {
|
|
||||||
updates := map[string]interface{}{}
|
|
||||||
if req.Name != "" {
|
|
||||||
updates["name"] = req.Name
|
|
||||||
}
|
|
||||||
if req.Description != "" {
|
|
||||||
updates["description"] = req.Description
|
|
||||||
}
|
|
||||||
if req.Script != "" {
|
|
||||||
updates["script"] = req.Script
|
|
||||||
}
|
|
||||||
if req.Partition != "" {
|
|
||||||
updates["partition"] = req.Partition
|
|
||||||
}
|
|
||||||
if req.QOS != "" {
|
|
||||||
updates["qos"] = req.QOS
|
|
||||||
}
|
|
||||||
if req.CPUs > 0 {
|
|
||||||
updates["cpus"] = req.CPUs
|
|
||||||
}
|
|
||||||
if req.Memory != "" {
|
|
||||||
updates["memory"] = req.Memory
|
|
||||||
}
|
|
||||||
if req.TimeLimit != "" {
|
|
||||||
updates["time_limit"] = req.TimeLimit
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(updates) == 0 {
|
|
||||||
return nil // nothing to update
|
|
||||||
}
|
|
||||||
|
|
||||||
result := s.db.WithContext(ctx).Model(&model.JobTemplate{}).Where("id = ?", id).Updates(updates)
|
|
||||||
return result.Error
|
|
||||||
}
|
|
||||||
|
|
||||||
// Delete removes a job template by ID. Idempotent — returns nil even if the row doesn't exist.
|
|
||||||
func (s *TemplateStore) Delete(ctx context.Context, id int64) error {
|
|
||||||
result := s.db.WithContext(ctx).Delete(&model.JobTemplate{}, id)
|
|
||||||
if result.Error != nil {
|
|
||||||
return result.Error
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
@@ -1,205 +0,0 @@
|
|||||||
package store
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"gorm.io/driver/sqlite"
|
|
||||||
"gorm.io/gorm"
|
|
||||||
"gorm.io/gorm/logger"
|
|
||||||
|
|
||||||
"gcy_hpc_server/internal/model"
|
|
||||||
)
|
|
||||||
|
|
||||||
func newTestDB(t *testing.T) *gorm.DB {
|
|
||||||
t.Helper()
|
|
||||||
db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{
|
|
||||||
Logger: logger.Default.LogMode(logger.Silent),
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("open sqlite: %v", err)
|
|
||||||
}
|
|
||||||
if err := db.AutoMigrate(&model.JobTemplate{}); err != nil {
|
|
||||||
t.Fatalf("auto migrate: %v", err)
|
|
||||||
}
|
|
||||||
return db
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestTemplateStore_List(t *testing.T) {
|
|
||||||
db := newTestDB(t)
|
|
||||||
s := NewTemplateStore(db)
|
|
||||||
|
|
||||||
s.Create(context.Background(), &model.CreateTemplateRequest{Name: "job-1", Script: "echo 1"})
|
|
||||||
s.Create(context.Background(), &model.CreateTemplateRequest{Name: "job-2", Script: "echo 2"})
|
|
||||||
|
|
||||||
templates, total, err := s.List(context.Background(), 1, 10)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("List() error = %v", err)
|
|
||||||
}
|
|
||||||
if total != 2 {
|
|
||||||
t.Errorf("total = %d, want 2", total)
|
|
||||||
}
|
|
||||||
if len(templates) != 2 {
|
|
||||||
t.Fatalf("len(templates) = %d, want 2", len(templates))
|
|
||||||
}
|
|
||||||
// DESC order, so job-2 is first
|
|
||||||
if templates[0].Name != "job-2" {
|
|
||||||
t.Errorf("templates[0].Name = %q, want %q", templates[0].Name, "job-2")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestTemplateStore_List_Page2(t *testing.T) {
|
|
||||||
db := newTestDB(t)
|
|
||||||
s := NewTemplateStore(db)
|
|
||||||
|
|
||||||
for i := 0; i < 15; i++ {
|
|
||||||
s.Create(context.Background(), &model.CreateTemplateRequest{
|
|
||||||
Name: "job-" + string(rune('A'+i)), Script: "echo",
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
templates, total, err := s.List(context.Background(), 2, 10)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("List() error = %v", err)
|
|
||||||
}
|
|
||||||
if total != 15 {
|
|
||||||
t.Errorf("total = %d, want 15", total)
|
|
||||||
}
|
|
||||||
if len(templates) != 5 {
|
|
||||||
t.Fatalf("len(templates) = %d, want 5", len(templates))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestTemplateStore_GetByID(t *testing.T) {
|
|
||||||
db := newTestDB(t)
|
|
||||||
s := NewTemplateStore(db)
|
|
||||||
|
|
||||||
id, _ := s.Create(context.Background(), &model.CreateTemplateRequest{
|
|
||||||
Name: "test-job", Script: "echo hi", Partition: "batch", QOS: "normal", CPUs: 2, Memory: "4G",
|
|
||||||
})
|
|
||||||
|
|
||||||
tpl, err := s.GetByID(context.Background(), id)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("GetByID() error = %v", err)
|
|
||||||
}
|
|
||||||
if tpl == nil {
|
|
||||||
t.Fatal("GetByID() returned nil")
|
|
||||||
}
|
|
||||||
if tpl.Name != "test-job" {
|
|
||||||
t.Errorf("Name = %q, want %q", tpl.Name, "test-job")
|
|
||||||
}
|
|
||||||
if tpl.CPUs != 2 {
|
|
||||||
t.Errorf("CPUs = %d, want 2", tpl.CPUs)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestTemplateStore_GetByID_NotFound(t *testing.T) {
|
|
||||||
db := newTestDB(t)
|
|
||||||
s := NewTemplateStore(db)
|
|
||||||
|
|
||||||
tpl, err := s.GetByID(context.Background(), 999)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("GetByID() error = %v, want nil", err)
|
|
||||||
}
|
|
||||||
if tpl != nil {
|
|
||||||
t.Fatal("GetByID() should return nil for not found")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestTemplateStore_Create(t *testing.T) {
|
|
||||||
db := newTestDB(t)
|
|
||||||
s := NewTemplateStore(db)
|
|
||||||
|
|
||||||
id, err := s.Create(context.Background(), &model.CreateTemplateRequest{
|
|
||||||
Name: "new-job", Script: "echo", Partition: "gpu",
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Create() error = %v", err)
|
|
||||||
}
|
|
||||||
if id == 0 {
|
|
||||||
t.Fatal("Create() returned id=0")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestTemplateStore_Update(t *testing.T) {
|
|
||||||
db := newTestDB(t)
|
|
||||||
s := NewTemplateStore(db)
|
|
||||||
|
|
||||||
id, _ := s.Create(context.Background(), &model.CreateTemplateRequest{
|
|
||||||
Name: "old", Script: "echo",
|
|
||||||
})
|
|
||||||
|
|
||||||
err := s.Update(context.Background(), id, &model.UpdateTemplateRequest{
|
|
||||||
Name: "updated",
|
|
||||||
Script: "echo new",
|
|
||||||
CPUs: 8,
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Update() error = %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
tpl, _ := s.GetByID(context.Background(), id)
|
|
||||||
if tpl.Name != "updated" {
|
|
||||||
t.Errorf("Name = %q, want %q", tpl.Name, "updated")
|
|
||||||
}
|
|
||||||
if tpl.CPUs != 8 {
|
|
||||||
t.Errorf("CPUs = %d, want 8", tpl.CPUs)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestTemplateStore_Update_Partial(t *testing.T) {
|
|
||||||
db := newTestDB(t)
|
|
||||||
s := NewTemplateStore(db)
|
|
||||||
|
|
||||||
id, _ := s.Create(context.Background(), &model.CreateTemplateRequest{
|
|
||||||
Name: "original", Script: "echo orig", Partition: "batch",
|
|
||||||
})
|
|
||||||
|
|
||||||
err := s.Update(context.Background(), id, &model.UpdateTemplateRequest{
|
|
||||||
Name: "renamed",
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Update() error = %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
tpl, _ := s.GetByID(context.Background(), id)
|
|
||||||
if tpl.Name != "renamed" {
|
|
||||||
t.Errorf("Name = %q, want %q", tpl.Name, "renamed")
|
|
||||||
}
|
|
||||||
// Script and Partition should be unchanged
|
|
||||||
if tpl.Script != "echo orig" {
|
|
||||||
t.Errorf("Script = %q, want %q", tpl.Script, "echo orig")
|
|
||||||
}
|
|
||||||
if tpl.Partition != "batch" {
|
|
||||||
t.Errorf("Partition = %q, want %q", tpl.Partition, "batch")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestTemplateStore_Delete(t *testing.T) {
|
|
||||||
db := newTestDB(t)
|
|
||||||
s := NewTemplateStore(db)
|
|
||||||
|
|
||||||
id, _ := s.Create(context.Background(), &model.CreateTemplateRequest{
|
|
||||||
Name: "to-delete", Script: "echo",
|
|
||||||
})
|
|
||||||
|
|
||||||
err := s.Delete(context.Background(), id)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Delete() error = %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
tpl, _ := s.GetByID(context.Background(), id)
|
|
||||||
if tpl != nil {
|
|
||||||
t.Fatal("Delete() did not remove the record")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestTemplateStore_Delete_NotFound(t *testing.T) {
|
|
||||||
db := newTestDB(t)
|
|
||||||
s := NewTemplateStore(db)
|
|
||||||
|
|
||||||
err := s.Delete(context.Background(), 999)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Delete() should not error for non-existent record, got: %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
117
internal/store/upload_store.go
Normal file
117
internal/store/upload_store.go
Normal file
@@ -0,0 +1,117 @@
|
|||||||
|
package store
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"gcy_hpc_server/internal/model"
|
||||||
|
|
||||||
|
"gorm.io/gorm"
|
||||||
|
"gorm.io/gorm/clause"
|
||||||
|
)
|
||||||
|
|
||||||
|
// UploadStore manages upload sessions and chunks with idempotent upsert support.
|
||||||
|
type UploadStore struct {
|
||||||
|
db *gorm.DB
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewUploadStore creates a new UploadStore.
|
||||||
|
func NewUploadStore(db *gorm.DB) *UploadStore {
|
||||||
|
return &UploadStore{db: db}
|
||||||
|
}
|
||||||
|
|
||||||
|
// CreateSession inserts a new upload session.
|
||||||
|
func (s *UploadStore) CreateSession(ctx context.Context, session *model.UploadSession) error {
|
||||||
|
return s.db.WithContext(ctx).Create(session).Error
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetSession returns an upload session by ID. Returns (nil, nil) if not found.
|
||||||
|
func (s *UploadStore) GetSession(ctx context.Context, id int64) (*model.UploadSession, error) {
|
||||||
|
var session model.UploadSession
|
||||||
|
err := s.db.WithContext(ctx).First(&session, id).Error
|
||||||
|
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &session, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetSessionWithChunks returns an upload session along with its chunks ordered by chunk_index.
|
||||||
|
func (s *UploadStore) GetSessionWithChunks(ctx context.Context, id int64) (*model.UploadSession, []model.UploadChunk, error) {
|
||||||
|
session, err := s.GetSession(ctx, id)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, err
|
||||||
|
}
|
||||||
|
if session == nil {
|
||||||
|
return nil, nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var chunks []model.UploadChunk
|
||||||
|
if err := s.db.WithContext(ctx).Where("session_id = ?", id).Order("chunk_index ASC").Find(&chunks).Error; err != nil {
|
||||||
|
return nil, nil, err
|
||||||
|
}
|
||||||
|
return session, chunks, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateSessionStatus updates the status field of an upload session.
|
||||||
|
// Returns gorm.ErrRecordNotFound if no row was affected.
|
||||||
|
func (s *UploadStore) UpdateSessionStatus(ctx context.Context, id int64, status string) error {
|
||||||
|
result := s.db.WithContext(ctx).Model(&model.UploadSession{}).Where("id = ?", id).Update("status", status)
|
||||||
|
if result.Error != nil {
|
||||||
|
return result.Error
|
||||||
|
}
|
||||||
|
if result.RowsAffected == 0 {
|
||||||
|
return gorm.ErrRecordNotFound
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ListExpiredSessions returns sessions that are not in a terminal state and have expired.
|
||||||
|
func (s *UploadStore) ListExpiredSessions(ctx context.Context) ([]model.UploadSession, error) {
|
||||||
|
var sessions []model.UploadSession
|
||||||
|
err := s.db.WithContext(ctx).
|
||||||
|
Where("status NOT IN ?", []string{"completed", "cancelled", "expired"}).
|
||||||
|
Where("expires_at < ?", time.Now()).
|
||||||
|
Find(&sessions).Error
|
||||||
|
return sessions, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeleteSession removes all chunks for a session, then the session itself.
|
||||||
|
func (s *UploadStore) DeleteSession(ctx context.Context, id int64) error {
|
||||||
|
if err := s.db.WithContext(ctx).Where("session_id = ?", id).Delete(&model.UploadChunk{}).Error; err != nil {
|
||||||
|
return fmt.Errorf("delete chunks: %w", err)
|
||||||
|
}
|
||||||
|
result := s.db.WithContext(ctx).Delete(&model.UploadSession{}, id)
|
||||||
|
return result.Error
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpsertChunk inserts a chunk or updates it if the (session_id, chunk_index) pair already exists.
|
||||||
|
// Uses GORM clause.OnConflict for dialect-neutral upsert (works with both SQLite and MySQL).
|
||||||
|
func (s *UploadStore) UpsertChunk(ctx context.Context, chunk *model.UploadChunk) error {
|
||||||
|
return s.db.WithContext(ctx).Clauses(clause.OnConflict{
|
||||||
|
Columns: []clause.Column{{Name: "session_id"}, {Name: "chunk_index"}},
|
||||||
|
DoUpdates: clause.AssignmentColumns([]string{"minio_key", "sha256", "size", "status", "updated_at"}),
|
||||||
|
}).Create(chunk).Error
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetUploadedChunkIndices returns the chunk indices that have been successfully uploaded.
|
||||||
|
func (s *UploadStore) GetUploadedChunkIndices(ctx context.Context, sessionID int64) ([]int, error) {
|
||||||
|
var indices []int
|
||||||
|
err := s.db.WithContext(ctx).Model(&model.UploadChunk{}).
|
||||||
|
Where("session_id = ? AND status = ?", sessionID, "uploaded").
|
||||||
|
Pluck("chunk_index", &indices).Error
|
||||||
|
return indices, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// CountUploadedChunks returns the number of chunks with status "uploaded" for a session.
|
||||||
|
func (s *UploadStore) CountUploadedChunks(ctx context.Context, sessionID int64) (int, error) {
|
||||||
|
var count int64
|
||||||
|
err := s.db.WithContext(ctx).Model(&model.UploadChunk{}).
|
||||||
|
Where("session_id = ? AND status = ?", sessionID, "uploaded").
|
||||||
|
Count(&count).Error
|
||||||
|
return int(count), err
|
||||||
|
}
|
||||||
329
internal/store/upload_store_test.go
Normal file
329
internal/store/upload_store_test.go
Normal file
@@ -0,0 +1,329 @@
|
|||||||
|
package store
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"gcy_hpc_server/internal/model"
|
||||||
|
|
||||||
|
"gorm.io/driver/sqlite"
|
||||||
|
"gorm.io/gorm"
|
||||||
|
"gorm.io/gorm/logger"
|
||||||
|
)
|
||||||
|
|
||||||
|
func setupUploadTestDB(t *testing.T) *gorm.DB {
|
||||||
|
t.Helper()
|
||||||
|
db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{
|
||||||
|
Logger: logger.Default.LogMode(logger.Silent),
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("open sqlite: %v", err)
|
||||||
|
}
|
||||||
|
if err := db.AutoMigrate(&model.UploadSession{}, &model.UploadChunk{}); err != nil {
|
||||||
|
t.Fatalf("migrate: %v", err)
|
||||||
|
}
|
||||||
|
return db
|
||||||
|
}
|
||||||
|
|
||||||
|
func newTestSession(status string, expiresAt time.Time) *model.UploadSession {
|
||||||
|
return &model.UploadSession{
|
||||||
|
FileName: "test.bin",
|
||||||
|
FileSize: 100 * 1024 * 1024,
|
||||||
|
ChunkSize: 16 << 20,
|
||||||
|
TotalChunks: 7,
|
||||||
|
SHA256: "abc123",
|
||||||
|
Status: status,
|
||||||
|
MinioPrefix: "uploads/1/",
|
||||||
|
ExpiresAt: expiresAt,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func newTestChunk(sessionID int64, index int, status string) *model.UploadChunk {
|
||||||
|
return &model.UploadChunk{
|
||||||
|
SessionID: sessionID,
|
||||||
|
ChunkIndex: index,
|
||||||
|
MinioKey: "uploads/1/chunk_00000",
|
||||||
|
SHA256: "chunk_hash_0",
|
||||||
|
Size: 16 << 20,
|
||||||
|
Status: status,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUploadStore_CreateSession(t *testing.T) {
|
||||||
|
db := setupUploadTestDB(t)
|
||||||
|
s := NewUploadStore(db)
|
||||||
|
|
||||||
|
session := newTestSession("pending", time.Now().Add(48*time.Hour))
|
||||||
|
if err := s.CreateSession(context.Background(), session); err != nil {
|
||||||
|
t.Fatalf("CreateSession() error = %v", err)
|
||||||
|
}
|
||||||
|
if session.ID <= 0 {
|
||||||
|
t.Errorf("ID = %d, want positive", session.ID)
|
||||||
|
}
|
||||||
|
if session.FileName != "test.bin" {
|
||||||
|
t.Errorf("FileName = %q, want %q", session.FileName, "test.bin")
|
||||||
|
}
|
||||||
|
if session.ExpiresAt.IsZero() {
|
||||||
|
t.Error("ExpiresAt is zero, want a real timestamp")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUploadStore_GetSession(t *testing.T) {
|
||||||
|
db := setupUploadTestDB(t)
|
||||||
|
s := NewUploadStore(db)
|
||||||
|
|
||||||
|
session := newTestSession("pending", time.Now().Add(48*time.Hour))
|
||||||
|
s.CreateSession(context.Background(), session)
|
||||||
|
|
||||||
|
got, err := s.GetSession(context.Background(), session.ID)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("GetSession() error = %v", err)
|
||||||
|
}
|
||||||
|
if got == nil {
|
||||||
|
t.Fatal("GetSession() returned nil")
|
||||||
|
}
|
||||||
|
if got.FileName != session.FileName {
|
||||||
|
t.Errorf("FileName = %q, want %q", got.FileName, session.FileName)
|
||||||
|
}
|
||||||
|
if got.Status != "pending" {
|
||||||
|
t.Errorf("Status = %q, want %q", got.Status, "pending")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUploadStore_GetSession_NotFound(t *testing.T) {
|
||||||
|
db := setupUploadTestDB(t)
|
||||||
|
s := NewUploadStore(db)
|
||||||
|
|
||||||
|
got, err := s.GetSession(context.Background(), 99999)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("GetSession() error = %v", err)
|
||||||
|
}
|
||||||
|
if got != nil {
|
||||||
|
t.Error("GetSession() expected nil for not-found")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUploadStore_UpdateSessionStatus(t *testing.T) {
|
||||||
|
db := setupUploadTestDB(t)
|
||||||
|
s := NewUploadStore(db)
|
||||||
|
|
||||||
|
session := newTestSession("pending", time.Now().Add(48*time.Hour))
|
||||||
|
s.CreateSession(context.Background(), session)
|
||||||
|
|
||||||
|
if err := s.UpdateSessionStatus(context.Background(), session.ID, "uploading"); err != nil {
|
||||||
|
t.Fatalf("UpdateSessionStatus() error = %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
got, _ := s.GetSession(context.Background(), session.ID)
|
||||||
|
if got.Status != "uploading" {
|
||||||
|
t.Errorf("Status = %q, want %q", got.Status, "uploading")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUploadStore_UpdateSessionStatus_NotFound(t *testing.T) {
|
||||||
|
db := setupUploadTestDB(t)
|
||||||
|
s := NewUploadStore(db)
|
||||||
|
|
||||||
|
err := s.UpdateSessionStatus(context.Background(), 99999, "uploading")
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("UpdateSessionStatus() expected error for not-found, got nil")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUploadStore_GetSessionWithChunks(t *testing.T) {
|
||||||
|
db := setupUploadTestDB(t)
|
||||||
|
s := NewUploadStore(db)
|
||||||
|
|
||||||
|
session := newTestSession("uploading", time.Now().Add(48*time.Hour))
|
||||||
|
s.CreateSession(context.Background(), session)
|
||||||
|
|
||||||
|
s.UpsertChunk(context.Background(), &model.UploadChunk{
|
||||||
|
SessionID: session.ID, ChunkIndex: 0, MinioKey: "uploads/1/chunk_0", SHA256: "h0", Size: 16 << 20, Status: "uploaded",
|
||||||
|
})
|
||||||
|
s.UpsertChunk(context.Background(), &model.UploadChunk{
|
||||||
|
SessionID: session.ID, ChunkIndex: 2, MinioKey: "uploads/1/chunk_2", SHA256: "h2", Size: 16 << 20, Status: "uploaded",
|
||||||
|
})
|
||||||
|
|
||||||
|
gotSession, chunks, err := s.GetSessionWithChunks(context.Background(), session.ID)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("GetSessionWithChunks() error = %v", err)
|
||||||
|
}
|
||||||
|
if gotSession == nil {
|
||||||
|
t.Fatal("GetSessionWithChunks() session is nil")
|
||||||
|
}
|
||||||
|
if len(chunks) != 2 {
|
||||||
|
t.Fatalf("len(chunks) = %d, want 2", len(chunks))
|
||||||
|
}
|
||||||
|
if chunks[0].ChunkIndex != 0 {
|
||||||
|
t.Errorf("chunks[0].ChunkIndex = %d, want 0", chunks[0].ChunkIndex)
|
||||||
|
}
|
||||||
|
if chunks[1].ChunkIndex != 2 {
|
||||||
|
t.Errorf("chunks[1].ChunkIndex = %d, want 2", chunks[1].ChunkIndex)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUploadStore_GetSessionWithChunks_NotFound(t *testing.T) {
|
||||||
|
db := setupUploadTestDB(t)
|
||||||
|
s := NewUploadStore(db)
|
||||||
|
|
||||||
|
gotSession, chunks, err := s.GetSessionWithChunks(context.Background(), 99999)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("GetSessionWithChunks() error = %v", err)
|
||||||
|
}
|
||||||
|
if gotSession != nil {
|
||||||
|
t.Error("expected nil session for not-found")
|
||||||
|
}
|
||||||
|
if chunks != nil {
|
||||||
|
t.Error("expected nil chunks for not-found")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUploadStore_UpsertChunk_Idempotent(t *testing.T) {
|
||||||
|
db := setupUploadTestDB(t)
|
||||||
|
s := NewUploadStore(db)
|
||||||
|
|
||||||
|
session := newTestSession("uploading", time.Now().Add(48*time.Hour))
|
||||||
|
s.CreateSession(context.Background(), session)
|
||||||
|
|
||||||
|
chunk := &model.UploadChunk{
|
||||||
|
SessionID: session.ID, ChunkIndex: 0, MinioKey: "uploads/1/chunk_0", SHA256: "hash_v1", Size: 1024, Status: "uploaded",
|
||||||
|
}
|
||||||
|
if err := s.UpsertChunk(context.Background(), chunk); err != nil {
|
||||||
|
t.Fatalf("first UpsertChunk() error = %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
chunk2 := &model.UploadChunk{
|
||||||
|
SessionID: session.ID, ChunkIndex: 0, MinioKey: "uploads/1/chunk_0", SHA256: "hash_v2", Size: 2048, Status: "uploaded",
|
||||||
|
}
|
||||||
|
if err := s.UpsertChunk(context.Background(), chunk2); err != nil {
|
||||||
|
t.Fatalf("second UpsertChunk() error = %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
indices, _ := s.GetUploadedChunkIndices(context.Background(), session.ID)
|
||||||
|
if len(indices) != 1 {
|
||||||
|
t.Errorf("len(indices) = %d, want 1 (idempotent)", len(indices))
|
||||||
|
}
|
||||||
|
|
||||||
|
var got model.UploadChunk
|
||||||
|
db.Where("session_id = ? AND chunk_index = ?", session.ID, 0).First(&got)
|
||||||
|
if got.SHA256 != "hash_v2" {
|
||||||
|
t.Errorf("SHA256 = %q, want %q (updated on conflict)", got.SHA256, "hash_v2")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUploadStore_GetUploadedChunkIndices(t *testing.T) {
|
||||||
|
db := setupUploadTestDB(t)
|
||||||
|
s := NewUploadStore(db)
|
||||||
|
|
||||||
|
session := newTestSession("uploading", time.Now().Add(48*time.Hour))
|
||||||
|
s.CreateSession(context.Background(), session)
|
||||||
|
|
||||||
|
s.UpsertChunk(context.Background(), &model.UploadChunk{
|
||||||
|
SessionID: session.ID, ChunkIndex: 0, MinioKey: "k0", Size: 100, Status: "uploaded",
|
||||||
|
})
|
||||||
|
s.UpsertChunk(context.Background(), &model.UploadChunk{
|
||||||
|
SessionID: session.ID, ChunkIndex: 1, MinioKey: "k1", Size: 100, Status: "pending",
|
||||||
|
})
|
||||||
|
s.UpsertChunk(context.Background(), &model.UploadChunk{
|
||||||
|
SessionID: session.ID, ChunkIndex: 2, MinioKey: "k2", Size: 100, Status: "uploaded",
|
||||||
|
})
|
||||||
|
|
||||||
|
indices, err := s.GetUploadedChunkIndices(context.Background(), session.ID)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("GetUploadedChunkIndices() error = %v", err)
|
||||||
|
}
|
||||||
|
if len(indices) != 2 {
|
||||||
|
t.Fatalf("len(indices) = %d, want 2", len(indices))
|
||||||
|
}
|
||||||
|
if indices[0] != 0 || indices[1] != 2 {
|
||||||
|
t.Errorf("indices = %v, want [0 2]", indices)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUploadStore_CountUploadedChunks(t *testing.T) {
|
||||||
|
db := setupUploadTestDB(t)
|
||||||
|
s := NewUploadStore(db)
|
||||||
|
|
||||||
|
session := newTestSession("uploading", time.Now().Add(48*time.Hour))
|
||||||
|
s.CreateSession(context.Background(), session)
|
||||||
|
|
||||||
|
s.UpsertChunk(context.Background(), &model.UploadChunk{
|
||||||
|
SessionID: session.ID, ChunkIndex: 0, MinioKey: "k0", Size: 100, Status: "uploaded",
|
||||||
|
})
|
||||||
|
s.UpsertChunk(context.Background(), &model.UploadChunk{
|
||||||
|
SessionID: session.ID, ChunkIndex: 1, MinioKey: "k1", Size: 100, Status: "uploaded",
|
||||||
|
})
|
||||||
|
s.UpsertChunk(context.Background(), &model.UploadChunk{
|
||||||
|
SessionID: session.ID, ChunkIndex: 2, MinioKey: "k2", Size: 100, Status: "pending",
|
||||||
|
})
|
||||||
|
|
||||||
|
count, err := s.CountUploadedChunks(context.Background(), session.ID)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("CountUploadedChunks() error = %v", err)
|
||||||
|
}
|
||||||
|
if count != 2 {
|
||||||
|
t.Errorf("count = %d, want 2", count)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUploadStore_ListExpiredSessions(t *testing.T) {
|
||||||
|
db := setupUploadTestDB(t)
|
||||||
|
s := NewUploadStore(db)
|
||||||
|
|
||||||
|
past := time.Now().Add(-1 * time.Hour)
|
||||||
|
future := time.Now().Add(48 * time.Hour)
|
||||||
|
|
||||||
|
expired := newTestSession("pending", past)
|
||||||
|
expired.FileName = "expired.bin"
|
||||||
|
s.CreateSession(context.Background(), expired)
|
||||||
|
|
||||||
|
active := newTestSession("uploading", future)
|
||||||
|
active.FileName = "active.bin"
|
||||||
|
s.CreateSession(context.Background(), active)
|
||||||
|
|
||||||
|
completed := newTestSession("completed", past)
|
||||||
|
completed.FileName = "completed.bin"
|
||||||
|
s.CreateSession(context.Background(), completed)
|
||||||
|
|
||||||
|
sessions, err := s.ListExpiredSessions(context.Background())
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("ListExpiredSessions() error = %v", err)
|
||||||
|
}
|
||||||
|
if len(sessions) != 1 {
|
||||||
|
t.Fatalf("len(sessions) = %d, want 1", len(sessions))
|
||||||
|
}
|
||||||
|
if sessions[0].FileName != "expired.bin" {
|
||||||
|
t.Errorf("FileName = %q, want %q", sessions[0].FileName, "expired.bin")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUploadStore_DeleteSession(t *testing.T) {
|
||||||
|
db := setupUploadTestDB(t)
|
||||||
|
s := NewUploadStore(db)
|
||||||
|
|
||||||
|
session := newTestSession("uploading", time.Now().Add(48*time.Hour))
|
||||||
|
s.CreateSession(context.Background(), session)
|
||||||
|
|
||||||
|
s.UpsertChunk(context.Background(), &model.UploadChunk{
|
||||||
|
SessionID: session.ID, ChunkIndex: 0, MinioKey: "k0", Size: 100, Status: "uploaded",
|
||||||
|
})
|
||||||
|
s.UpsertChunk(context.Background(), &model.UploadChunk{
|
||||||
|
SessionID: session.ID, ChunkIndex: 1, MinioKey: "k1", Size: 100, Status: "uploaded",
|
||||||
|
})
|
||||||
|
|
||||||
|
if err := s.DeleteSession(context.Background(), session.ID); err != nil {
|
||||||
|
t.Fatalf("DeleteSession() error = %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
got, _ := s.GetSession(context.Background(), session.ID)
|
||||||
|
if got != nil {
|
||||||
|
t.Error("session still exists after delete")
|
||||||
|
}
|
||||||
|
|
||||||
|
var chunkCount int64
|
||||||
|
db.Model(&model.UploadChunk{}).Where("session_id = ?", session.ID).Count(&chunkCount)
|
||||||
|
if chunkCount != 0 {
|
||||||
|
t.Errorf("chunkCount = %d, want 0 after delete", chunkCount)
|
||||||
|
}
|
||||||
|
}
|
||||||
239
internal/testutil/mockminio/storage.go
Normal file
239
internal/testutil/mockminio/storage.go
Normal file
@@ -0,0 +1,239 @@
|
|||||||
|
// Package mockminio provides an in-memory implementation of storage.ObjectStorage
|
||||||
|
// for use in tests. It is thread-safe and supports Range reads.
|
||||||
|
package mockminio
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"crypto/sha256"
|
||||||
|
"encoding/hex"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"sort"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"gcy_hpc_server/internal/storage"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Compile-time interface check.
|
||||||
|
var _ storage.ObjectStorage = (*InMemoryStorage)(nil)
|
||||||
|
|
||||||
|
// objectMeta holds metadata for a stored object.
|
||||||
|
type objectMeta struct {
|
||||||
|
size int64
|
||||||
|
etag string
|
||||||
|
lastModified time.Time
|
||||||
|
contentType string
|
||||||
|
}
|
||||||
|
|
||||||
|
// InMemoryStorage is a thread-safe, in-memory implementation of
|
||||||
|
// storage.ObjectStorage. All data is kept in memory; no network or disk I/O
|
||||||
|
// is performed.
|
||||||
|
type InMemoryStorage struct {
|
||||||
|
mu sync.RWMutex
|
||||||
|
objects map[string][]byte
|
||||||
|
meta map[string]objectMeta
|
||||||
|
buckets map[string]bool
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewInMemoryStorage returns a ready-to-use InMemoryStorage.
|
||||||
|
func NewInMemoryStorage() *InMemoryStorage {
|
||||||
|
return &InMemoryStorage{
|
||||||
|
objects: make(map[string][]byte),
|
||||||
|
meta: make(map[string]objectMeta),
|
||||||
|
buckets: make(map[string]bool),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// PutObject reads all bytes from reader and stores them under key.
|
||||||
|
// The ETag is the SHA-256 hash of the data, formatted as hex.
|
||||||
|
func (s *InMemoryStorage) PutObject(_ context.Context, _, key string, reader io.Reader, _ int64, opts storage.PutObjectOptions) (storage.UploadInfo, error) {
|
||||||
|
data, err := io.ReadAll(reader)
|
||||||
|
if err != nil {
|
||||||
|
return storage.UploadInfo{}, fmt.Errorf("read all: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
h := sha256.Sum256(data)
|
||||||
|
etag := hex.EncodeToString(h[:])
|
||||||
|
|
||||||
|
s.mu.Lock()
|
||||||
|
s.objects[key] = data
|
||||||
|
s.meta[key] = objectMeta{
|
||||||
|
size: int64(len(data)),
|
||||||
|
etag: etag,
|
||||||
|
lastModified: time.Now(),
|
||||||
|
contentType: opts.ContentType,
|
||||||
|
}
|
||||||
|
s.mu.Unlock()
|
||||||
|
|
||||||
|
return storage.UploadInfo{ETag: etag, Size: int64(len(data))}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetObject retrieves an object. opts.Start and opts.End control byte-range
|
||||||
|
// reads. Four cases are supported:
|
||||||
|
// 1. No range (both nil) → return entire object
|
||||||
|
// 2. Start only (End nil) → from start to end of object
|
||||||
|
// 3. End only (Start nil) → from byte 0 to end
|
||||||
|
// 4. Start + End → standard byte range
|
||||||
|
func (s *InMemoryStorage) GetObject(_ context.Context, _, key string, opts storage.GetOptions) (io.ReadCloser, storage.ObjectInfo, error) {
|
||||||
|
s.mu.RLock()
|
||||||
|
data, ok := s.objects[key]
|
||||||
|
meta := s.meta[key]
|
||||||
|
s.mu.RUnlock()
|
||||||
|
|
||||||
|
if !ok {
|
||||||
|
return nil, storage.ObjectInfo{}, fmt.Errorf("object %s not found", key)
|
||||||
|
}
|
||||||
|
|
||||||
|
size := int64(len(data))
|
||||||
|
|
||||||
|
// Full object info (Size is always the total object size).
|
||||||
|
info := storage.ObjectInfo{
|
||||||
|
Key: key,
|
||||||
|
Size: size,
|
||||||
|
ETag: meta.etag,
|
||||||
|
LastModified: meta.lastModified,
|
||||||
|
ContentType: meta.contentType,
|
||||||
|
}
|
||||||
|
|
||||||
|
// No range requested → return everything.
|
||||||
|
if opts.Start == nil && opts.End == nil {
|
||||||
|
return io.NopCloser(bytes.NewReader(data)), info, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Build range. Check each pointer individually to avoid nil dereference.
|
||||||
|
start := int64(0)
|
||||||
|
if opts.Start != nil {
|
||||||
|
start = *opts.Start
|
||||||
|
}
|
||||||
|
|
||||||
|
end := size - 1
|
||||||
|
if opts.End != nil {
|
||||||
|
end = *opts.End
|
||||||
|
}
|
||||||
|
|
||||||
|
// Clamp end to last byte.
|
||||||
|
if end >= size {
|
||||||
|
end = size - 1
|
||||||
|
}
|
||||||
|
|
||||||
|
if start > end || start < 0 {
|
||||||
|
return nil, storage.ObjectInfo{}, fmt.Errorf("invalid range: start=%d, end=%d, size=%d", start, end, size)
|
||||||
|
}
|
||||||
|
|
||||||
|
section := io.NewSectionReader(bytes.NewReader(data), start, end-start+1)
|
||||||
|
return io.NopCloser(section), info, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ComposeObject concatenates source objects (in order) into dst.
|
||||||
|
func (s *InMemoryStorage) ComposeObject(_ context.Context, _, dst string, sources []string) (storage.UploadInfo, error) {
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
|
||||||
|
var buf bytes.Buffer
|
||||||
|
for _, src := range sources {
|
||||||
|
data, ok := s.objects[src]
|
||||||
|
if !ok {
|
||||||
|
return storage.UploadInfo{}, fmt.Errorf("source object %s not found", src)
|
||||||
|
}
|
||||||
|
buf.Write(data)
|
||||||
|
}
|
||||||
|
|
||||||
|
combined := buf.Bytes()
|
||||||
|
h := sha256.Sum256(combined)
|
||||||
|
etag := hex.EncodeToString(h[:])
|
||||||
|
|
||||||
|
s.objects[dst] = combined
|
||||||
|
s.meta[dst] = objectMeta{
|
||||||
|
size: int64(len(combined)),
|
||||||
|
etag: etag,
|
||||||
|
lastModified: time.Now(),
|
||||||
|
}
|
||||||
|
|
||||||
|
return storage.UploadInfo{ETag: etag, Size: int64(len(combined))}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// RemoveObject deletes a single object.
|
||||||
|
func (s *InMemoryStorage) RemoveObject(_ context.Context, _, key string, _ storage.RemoveObjectOptions) error {
|
||||||
|
s.mu.Lock()
|
||||||
|
delete(s.objects, key)
|
||||||
|
delete(s.meta, key)
|
||||||
|
s.mu.Unlock()
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// RemoveObjects deletes multiple objects by key.
|
||||||
|
func (s *InMemoryStorage) RemoveObjects(_ context.Context, _ string, keys []string, _ storage.RemoveObjectsOptions) error {
|
||||||
|
s.mu.Lock()
|
||||||
|
for _, k := range keys {
|
||||||
|
delete(s.objects, k)
|
||||||
|
delete(s.meta, k)
|
||||||
|
}
|
||||||
|
s.mu.Unlock()
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ListObjects returns object info for all objects matching prefix, sorted by key.
|
||||||
|
func (s *InMemoryStorage) ListObjects(_ context.Context, _, prefix string, _ bool) ([]storage.ObjectInfo, error) {
|
||||||
|
s.mu.RLock()
|
||||||
|
defer s.mu.RUnlock()
|
||||||
|
|
||||||
|
var result []storage.ObjectInfo
|
||||||
|
for k, m := range s.meta {
|
||||||
|
if strings.HasPrefix(k, prefix) {
|
||||||
|
result = append(result, storage.ObjectInfo{
|
||||||
|
Key: k,
|
||||||
|
Size: m.size,
|
||||||
|
ETag: m.etag,
|
||||||
|
LastModified: m.lastModified,
|
||||||
|
ContentType: m.contentType,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
sort.Slice(result, func(i, j int) bool { return result[i].Key < result[j].Key })
|
||||||
|
return result, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// BucketExists reports whether the named bucket exists.
|
||||||
|
func (s *InMemoryStorage) BucketExists(_ context.Context, bucket string) (bool, error) {
|
||||||
|
s.mu.RLock()
|
||||||
|
defer s.mu.RUnlock()
|
||||||
|
return s.buckets[bucket], nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// MakeBucket creates a bucket.
|
||||||
|
func (s *InMemoryStorage) MakeBucket(_ context.Context, bucket string, _ storage.MakeBucketOptions) error {
|
||||||
|
s.mu.Lock()
|
||||||
|
s.buckets[bucket] = true
|
||||||
|
s.mu.Unlock()
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// StatObject returns metadata about an object without downloading it.
|
||||||
|
func (s *InMemoryStorage) StatObject(_ context.Context, _, key string, _ storage.StatObjectOptions) (storage.ObjectInfo, error) {
|
||||||
|
s.mu.RLock()
|
||||||
|
m, ok := s.meta[key]
|
||||||
|
s.mu.RUnlock()
|
||||||
|
if !ok {
|
||||||
|
return storage.ObjectInfo{}, fmt.Errorf("object %s not found", key)
|
||||||
|
}
|
||||||
|
return storage.ObjectInfo{
|
||||||
|
Key: key,
|
||||||
|
Size: m.size,
|
||||||
|
ETag: m.etag,
|
||||||
|
LastModified: m.lastModified,
|
||||||
|
ContentType: m.contentType,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// AbortMultipartUpload is a no-op for the in-memory implementation.
|
||||||
|
func (s *InMemoryStorage) AbortMultipartUpload(_ context.Context, _, _, _ string) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// RemoveIncompleteUpload is a no-op for the in-memory implementation.
|
||||||
|
func (s *InMemoryStorage) RemoveIncompleteUpload(_ context.Context, _, _ string) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user