357 lines
11 KiB
Go
357 lines
11 KiB
Go
package service
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"os"
|
|
"path/filepath"
|
|
"strings"
|
|
"testing"
|
|
|
|
"gcy_hpc_server/internal/model"
|
|
"gcy_hpc_server/internal/slurm"
|
|
"gcy_hpc_server/internal/store"
|
|
|
|
"go.uber.org/zap"
|
|
gormlogger "gorm.io/gorm/logger"
|
|
|
|
"gorm.io/driver/sqlite"
|
|
"gorm.io/gorm"
|
|
)
|
|
|
|
func setupApplicationService(t *testing.T, slurmHandler http.HandlerFunc) (*ApplicationService, func()) {
|
|
t.Helper()
|
|
srv := httptest.NewServer(slurmHandler)
|
|
client, _ := slurm.NewClient(srv.URL, srv.Client())
|
|
|
|
db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{
|
|
Logger: gormlogger.Default.LogMode(gormlogger.Silent),
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("open sqlite: %v", err)
|
|
}
|
|
if err := db.AutoMigrate(&model.Application{}); err != nil {
|
|
t.Fatalf("auto migrate: %v", err)
|
|
}
|
|
|
|
jobSvc := NewJobService(client, zap.NewNop())
|
|
appStore := store.NewApplicationStore(db)
|
|
appSvc := NewApplicationService(appStore, jobSvc, "", zap.NewNop())
|
|
|
|
return appSvc, srv.Close
|
|
}
|
|
|
|
func TestValidateParams_AllRequired(t *testing.T) {
|
|
params := []model.ParameterSchema{
|
|
{Name: "NAME", Type: model.ParamTypeString, Required: true},
|
|
{Name: "COUNT", Type: model.ParamTypeInteger, Required: true},
|
|
}
|
|
values := map[string]string{"NAME": "hello", "COUNT": "5"}
|
|
if err := ValidateParams(params, values); err != nil {
|
|
t.Errorf("expected no error, got %v", err)
|
|
}
|
|
}
|
|
|
|
func TestValidateParams_MissingRequired(t *testing.T) {
|
|
params := []model.ParameterSchema{
|
|
{Name: "NAME", Type: model.ParamTypeString, Required: true},
|
|
}
|
|
values := map[string]string{}
|
|
err := ValidateParams(params, values)
|
|
if err == nil {
|
|
t.Fatal("expected error for missing required param")
|
|
}
|
|
if !strings.Contains(err.Error(), "NAME") {
|
|
t.Errorf("error should mention param name, got: %v", err)
|
|
}
|
|
}
|
|
|
|
func TestValidateParams_InvalidInteger(t *testing.T) {
|
|
params := []model.ParameterSchema{
|
|
{Name: "COUNT", Type: model.ParamTypeInteger, Required: true},
|
|
}
|
|
values := map[string]string{"COUNT": "abc"}
|
|
err := ValidateParams(params, values)
|
|
if err == nil {
|
|
t.Fatal("expected error for invalid integer")
|
|
}
|
|
if !strings.Contains(err.Error(), "integer") {
|
|
t.Errorf("error should mention integer, got: %v", err)
|
|
}
|
|
}
|
|
|
|
func TestValidateParams_InvalidEnum(t *testing.T) {
|
|
params := []model.ParameterSchema{
|
|
{Name: "MODE", Type: model.ParamTypeEnum, Required: true, Options: []string{"fast", "slow"}},
|
|
}
|
|
values := map[string]string{"MODE": "medium"}
|
|
err := ValidateParams(params, values)
|
|
if err == nil {
|
|
t.Fatal("expected error for invalid enum value")
|
|
}
|
|
if !strings.Contains(err.Error(), "MODE") {
|
|
t.Errorf("error should mention param name, got: %v", err)
|
|
}
|
|
}
|
|
|
|
func TestValidateParams_BooleanValues(t *testing.T) {
|
|
params := []model.ParameterSchema{
|
|
{Name: "FLAG", Type: model.ParamTypeBoolean, Required: true},
|
|
}
|
|
for _, val := range []string{"true", "false", "1", "0"} {
|
|
err := ValidateParams(params, map[string]string{"FLAG": val})
|
|
if err != nil {
|
|
t.Errorf("boolean value %q should be valid, got error: %v", val, err)
|
|
}
|
|
}
|
|
}
|
|
|
|
func TestRenderScript_SimpleReplacement(t *testing.T) {
|
|
params := []model.ParameterSchema{{Name: "INPUT", Type: model.ParamTypeString}}
|
|
values := map[string]string{"INPUT": "data.txt"}
|
|
result := RenderScript("echo $INPUT", params, values)
|
|
expected := "echo 'data.txt'"
|
|
if result != expected {
|
|
t.Errorf("got %q, want %q", result, expected)
|
|
}
|
|
}
|
|
|
|
func TestRenderScript_DefaultValues(t *testing.T) {
|
|
params := []model.ParameterSchema{{Name: "OUTPUT", Type: model.ParamTypeString, Default: "out.log"}}
|
|
values := map[string]string{}
|
|
result := RenderScript("cat $OUTPUT", params, values)
|
|
expected := "cat 'out.log'"
|
|
if result != expected {
|
|
t.Errorf("got %q, want %q", result, expected)
|
|
}
|
|
}
|
|
|
|
func TestRenderScript_PreservesUnknownVars(t *testing.T) {
|
|
params := []model.ParameterSchema{{Name: "INPUT", Type: model.ParamTypeString}}
|
|
values := map[string]string{"INPUT": "data.txt"}
|
|
result := RenderScript("export HOME=$HOME\necho $INPUT\necho $PATH", params, values)
|
|
if !strings.Contains(result, "$HOME") {
|
|
t.Error("$HOME should be preserved")
|
|
}
|
|
if !strings.Contains(result, "$PATH") {
|
|
t.Error("$PATH should be preserved")
|
|
}
|
|
if !strings.Contains(result, "'data.txt'") {
|
|
t.Error("$INPUT should be replaced")
|
|
}
|
|
}
|
|
|
|
func TestRenderScript_ShellEscaping(t *testing.T) {
|
|
params := []model.ParameterSchema{{Name: "INPUT", Type: model.ParamTypeString}}
|
|
|
|
tests := []struct {
|
|
name string
|
|
value string
|
|
expected string
|
|
}{
|
|
{"semicolon injection", "; rm -rf /", "'; rm -rf /'"},
|
|
{"command substitution", "$(cat /etc/passwd)", "'$(cat /etc/passwd)'"},
|
|
{"single quote", "hello'world", "'hello'\\''world'"},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
result := RenderScript("$INPUT", params, map[string]string{"INPUT": tt.value})
|
|
if result != tt.expected {
|
|
t.Errorf("got %q, want %q", result, tt.expected)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestRenderScript_OverlappingParams(t *testing.T) {
|
|
template := "$JOB_NAME and $JOB"
|
|
params := []model.ParameterSchema{
|
|
{Name: "JOB", Type: model.ParamTypeString},
|
|
{Name: "JOB_NAME", Type: model.ParamTypeString},
|
|
}
|
|
values := map[string]string{"JOB": "myjob", "JOB_NAME": "my-test-job"}
|
|
result := RenderScript(template, params, values)
|
|
if strings.Contains(result, "$JOB_NAME") {
|
|
t.Error("$JOB_NAME was not replaced")
|
|
}
|
|
if strings.Contains(result, "$JOB") {
|
|
t.Error("$JOB was not replaced")
|
|
}
|
|
if !strings.Contains(result, "'my-test-job'") {
|
|
t.Errorf("expected 'my-test-job' in result, got: %s", result)
|
|
}
|
|
if !strings.Contains(result, "'myjob'") {
|
|
t.Errorf("expected 'myjob' in result, got: %s", result)
|
|
}
|
|
}
|
|
|
|
func TestRenderScript_NewlineInValue(t *testing.T) {
|
|
params := []model.ParameterSchema{{Name: "CMD", Type: model.ParamTypeString}}
|
|
values := map[string]string{"CMD": "line1\nline2"}
|
|
result := RenderScript("echo $CMD", params, values)
|
|
expected := "echo 'line1\nline2'"
|
|
if result != expected {
|
|
t.Errorf("got %q, want %q", result, expected)
|
|
}
|
|
}
|
|
|
|
func TestSubmitFromApplication_Success(t *testing.T) {
|
|
jobID := int32(42)
|
|
appSvc, cleanup := setupApplicationService(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
json.NewEncoder(w).Encode(slurm.OpenapiJobSubmitResponse{
|
|
Result: &slurm.JobSubmitResponseMsg{JobID: &jobID},
|
|
})
|
|
}))
|
|
defer cleanup()
|
|
|
|
id, err := appSvc.store.Create(context.Background(), &model.CreateApplicationRequest{
|
|
Name: "test-app",
|
|
ScriptTemplate: "#!/bin/bash\n#SBATCH --job-name=$JOB_NAME\necho $INPUT",
|
|
Parameters: json.RawMessage(`[{"name":"JOB_NAME","type":"string","required":true},{"name":"INPUT","type":"string","required":true}]`),
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("create app: %v", err)
|
|
}
|
|
|
|
resp, err := appSvc.SubmitFromApplication(context.Background(), id, map[string]string{
|
|
"JOB_NAME": "my-job",
|
|
"INPUT": "hello",
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("SubmitFromApplication() error = %v", err)
|
|
}
|
|
if resp.JobID != 42 {
|
|
t.Errorf("JobID = %d, want 42", resp.JobID)
|
|
}
|
|
}
|
|
|
|
func TestSubmitFromApplication_AppNotFound(t *testing.T) {
|
|
appSvc, cleanup := setupApplicationService(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.WriteHeader(http.StatusOK)
|
|
}))
|
|
defer cleanup()
|
|
|
|
_, err := appSvc.SubmitFromApplication(context.Background(), 99999, map[string]string{})
|
|
if err == nil {
|
|
t.Fatal("expected error for non-existent app")
|
|
}
|
|
if !strings.Contains(err.Error(), "not found") {
|
|
t.Errorf("error should mention 'not found', got: %v", err)
|
|
}
|
|
}
|
|
|
|
func TestSubmitFromApplication_ValidationFail(t *testing.T) {
|
|
appSvc, cleanup := setupApplicationService(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.WriteHeader(http.StatusOK)
|
|
}))
|
|
defer cleanup()
|
|
|
|
_, err := appSvc.store.Create(context.Background(), &model.CreateApplicationRequest{
|
|
Name: "valid-app",
|
|
ScriptTemplate: "#!/bin/bash\necho $INPUT",
|
|
Parameters: json.RawMessage(`[{"name":"INPUT","type":"string","required":true}]`),
|
|
})
|
|
|
|
_, err = appSvc.SubmitFromApplication(context.Background(), 1, map[string]string{})
|
|
if err == nil {
|
|
t.Fatal("expected validation error for missing required param")
|
|
}
|
|
if !strings.Contains(err.Error(), "missing") {
|
|
t.Errorf("error should mention 'missing', got: %v", err)
|
|
}
|
|
}
|
|
|
|
func TestSubmitFromApplication_NoParameters(t *testing.T) {
|
|
jobID := int32(99)
|
|
appSvc, cleanup := setupApplicationService(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
json.NewEncoder(w).Encode(slurm.OpenapiJobSubmitResponse{
|
|
Result: &slurm.JobSubmitResponseMsg{JobID: &jobID},
|
|
})
|
|
}))
|
|
defer cleanup()
|
|
|
|
id, err := appSvc.store.Create(context.Background(), &model.CreateApplicationRequest{
|
|
Name: "simple-app",
|
|
ScriptTemplate: "#!/bin/bash\necho hello",
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("create app: %v", err)
|
|
}
|
|
|
|
resp, err := appSvc.SubmitFromApplication(context.Background(), id, map[string]string{})
|
|
if err != nil {
|
|
t.Fatalf("SubmitFromApplication() error = %v", err)
|
|
}
|
|
if resp.JobID != 99 {
|
|
t.Errorf("JobID = %d, want 99", resp.JobID)
|
|
}
|
|
}
|
|
|
|
func TestSubmitFromApplication_DelegatesToTaskService(t *testing.T) {
|
|
jobID := int32(77)
|
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
json.NewEncoder(w).Encode(slurm.OpenapiJobSubmitResponse{
|
|
Result: &slurm.JobSubmitResponseMsg{JobID: &jobID},
|
|
})
|
|
}))
|
|
defer srv.Close()
|
|
|
|
client, _ := slurm.NewClient(srv.URL, srv.Client())
|
|
jobSvc := NewJobService(client, zap.NewNop())
|
|
|
|
db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{
|
|
Logger: gormlogger.Default.LogMode(gormlogger.Silent),
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("open sqlite: %v", err)
|
|
}
|
|
if err := db.AutoMigrate(&model.Application{}, &model.Task{}, &model.File{}, &model.FileBlob{}); err != nil {
|
|
t.Fatalf("auto migrate: %v", err)
|
|
}
|
|
|
|
appStore := store.NewApplicationStore(db)
|
|
taskStore := store.NewTaskStore(db)
|
|
fileStore := store.NewFileStore(db)
|
|
blobStore := store.NewBlobStore(db)
|
|
|
|
workDirBase := filepath.Join(t.TempDir(), "workdir")
|
|
os.MkdirAll(workDirBase, 0777)
|
|
|
|
taskSvc := NewTaskService(taskStore, appStore, fileStore, blobStore, nil, jobSvc, workDirBase, zap.NewNop())
|
|
appSvc := NewApplicationService(appStore, jobSvc, workDirBase, zap.NewNop(), taskSvc)
|
|
|
|
id, err := appStore.Create(context.Background(), &model.CreateApplicationRequest{
|
|
Name: "delegated-app",
|
|
ScriptTemplate: "#!/bin/bash\n#SBATCH --job-name=$JOB_NAME\necho $INPUT",
|
|
Parameters: json.RawMessage(`[{"name":"JOB_NAME","type":"string","required":true},{"name":"INPUT","type":"string","required":true}]`),
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("create app: %v", err)
|
|
}
|
|
|
|
resp, err := appSvc.SubmitFromApplication(context.Background(), id, map[string]string{
|
|
"JOB_NAME": "delegated-job",
|
|
"INPUT": "test-data",
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("SubmitFromApplication() error = %v", err)
|
|
}
|
|
if resp.JobID != 77 {
|
|
t.Errorf("JobID = %d, want 77", resp.JobID)
|
|
}
|
|
|
|
var task model.Task
|
|
if err := db.Where("app_id = ?", id).First(&task).Error; err != nil {
|
|
t.Fatalf("no hpc_tasks record found for app_id %d: %v", id, err)
|
|
}
|
|
if task.SlurmJobID == nil || *task.SlurmJobID != 77 {
|
|
t.Errorf("task SlurmJobID = %v, want 77", task.SlurmJobID)
|
|
}
|
|
if task.Status != model.TaskStatusQueued {
|
|
t.Errorf("task Status = %q, want %q", task.Status, model.TaskStatusQueued)
|
|
}
|
|
}
|