diff --git a/internal/slurm/options.go b/internal/slurm/options.go new file mode 100644 index 0000000..e926d2c --- /dev/null +++ b/internal/slurm/options.go @@ -0,0 +1,100 @@ +package slurm + +import ( + "fmt" + "net/http" + "time" +) + +// ClientOption configures a Client via functional options. +type ClientOption func(*clientConfig) error + +type clientConfig struct { + jwtKeyPath string + username string + ttl time.Duration + leeway time.Duration + httpClient *http.Client +} + +func defaultClientConfig() *clientConfig { + return &clientConfig{ + ttl: 30 * time.Minute, + leeway: 30 * time.Second, + } +} + +// WithJWTKey specifies the path to the JWT key file. +func WithJWTKey(path string) ClientOption { + return func(c *clientConfig) error { + c.jwtKeyPath = path + return nil + } +} + +// WithUsername specifies the Slurm username for JWT authentication. +func WithUsername(username string) ClientOption { + return func(c *clientConfig) error { + c.username = username + return nil + } +} + +// WithTokenTTL sets the JWT token time-to-live (default: 30 minutes). +func WithTokenTTL(ttl time.Duration) ClientOption { + return func(c *clientConfig) error { + c.ttl = ttl + return nil + } +} + +// WithTokenLeeway sets the JWT token refresh leeway (default: 30 seconds). +func WithTokenLeeway(leeway time.Duration) ClientOption { + return func(c *clientConfig) error { + c.leeway = leeway + return nil + } +} + +// WithHTTPClient specifies a custom HTTP client. +func WithHTTPClient(client *http.Client) ClientOption { + return func(c *clientConfig) error { + c.httpClient = client + return nil + } +} + +// NewClientWithOpts creates a new Slurm API client using functional options. +// If WithJWTKey and WithUsername are provided, JWT authentication is configured +// automatically. If no JWT options are provided, http.DefaultClient is used. +func NewClientWithOpts(baseURL string, opts ...ClientOption) (*Client, error) { + cfg := defaultClientConfig() + for _, opt := range opts { + if err := opt(cfg); err != nil { + return nil, err + } + } + + var httpClient *http.Client + + if cfg.jwtKeyPath != "" && cfg.username != "" { + key, err := ReadJWTKey(cfg.jwtKeyPath) + if err != nil { + return nil, fmt.Errorf("read JWT key: %w", err) + } + + transportOpts := []JWTTransportOption{ + WithTTL(cfg.ttl), + WithLeeway(cfg.leeway), + } + + tr := NewJWTAuthTransport(cfg.username, key, transportOpts...) + httpClient = tr.Client() + } else if cfg.httpClient != nil { + httpClient = cfg.httpClient + } else { + httpClient = http.DefaultClient + } + + return NewClient(baseURL, httpClient) +} diff --git a/internal/slurm/options_test.go b/internal/slurm/options_test.go new file mode 100644 index 0000000..e9d5cc7 --- /dev/null +++ b/internal/slurm/options_test.go @@ -0,0 +1,154 @@ +package slurm + +import ( + "crypto/rand" + "net/http" + "os" + "path/filepath" + "testing" + "time" +) + +func TestNewClientWithOpts_JWTKey_Success(t *testing.T) { + key := make([]byte, 32) + if _, err := rand.Read(key); err != nil { + t.Fatalf("generate key: %v", err) + } + dir := t.TempDir() + keyPath := filepath.Join(dir, "jwt.key") + if err := os.WriteFile(keyPath, key, 0600); err != nil { + t.Fatalf("write key file: %v", err) + } + + client, err := NewClientWithOpts("http://localhost:6820/", + WithJWTKey(keyPath), + WithUsername("testuser"), + ) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if client == nil { + t.Fatal("expected non-nil client") + } + + transport, ok := client.client.Transport.(*JWTAuthTransport) + if !ok { + t.Fatalf("expected *JWTAuthTransport, got %T", client.client.Transport) + } + if transport.UserName != "testuser" { + t.Errorf("expected username %q, got %q", "testuser", transport.UserName) + } +} + +func TestNewClientWithOpts_InvalidKeyPath_Error(t *testing.T) { + client, err := NewClientWithOpts("http://localhost:6820/", + WithJWTKey("/nonexistent/key"), + WithUsername("testuser"), + ) + if err == nil { + t.Fatal("expected error for invalid key path, got nil") + } + if client != nil { + t.Fatal("expected nil client on error") + } +} + +func TestNewClientWithOpts_BackwardCompatible(t *testing.T) { + client, err := NewClientWithOpts("http://localhost:6820/") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if client == nil { + t.Fatal("expected non-nil client") + } + if client.client != http.DefaultClient { + t.Error("expected http.DefaultClient when no options provided") + } +} + +func TestNewClientWithOpts_AllServicesInitialized(t *testing.T) { + client, err := NewClientWithOpts("http://localhost:6820/") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + services := []struct { + name string + svc interface{} + }{ + {"Jobs", client.Jobs}, + {"Nodes", client.Nodes}, + {"Partitions", client.Partitions}, + {"Reservations", client.Reservations}, + {"Diag", client.Diag}, + {"Ping", client.Ping}, + {"Licenses", client.Licenses}, + {"Reconfigure", client.Reconfigure}, + {"Shares", client.Shares}, + {"SlurmdbDiag", client.SlurmdbDiag}, + {"SlurmdbConfig", client.SlurmdbConfig}, + {"SlurmdbTres", client.SlurmdbTres}, + {"SlurmdbQos", client.SlurmdbQos}, + {"SlurmdbAssocs", client.SlurmdbAssocs}, + {"SlurmdbInstances", client.SlurmdbInstances}, + {"SlurmdbUsers", client.SlurmdbUsers}, + {"SlurmdbClusters", client.SlurmdbClusters}, + {"SlurmdbWckeys", client.SlurmdbWckeys}, + {"SlurmdbAccounts", client.SlurmdbAccounts}, + {"SlurmdbJobs", client.SlurmdbJobs}, + } + + for _, s := range services { + if s.svc == nil { + t.Errorf("%s service is nil", s.name) + } + } +} + +func TestWithTokenTTL_Custom(t *testing.T) { + key := make([]byte, 32) + if _, err := rand.Read(key); err != nil { + t.Fatalf("generate key: %v", err) + } + dir := t.TempDir() + keyPath := filepath.Join(dir, "jwt.key") + if err := os.WriteFile(keyPath, key, 0600); err != nil { + t.Fatalf("write key file: %v", err) + } + + client, err := NewClientWithOpts("http://localhost:6820/", + WithJWTKey(keyPath), + WithUsername("testuser"), + WithTokenTTL(1*time.Hour), + ) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if client == nil { + t.Fatal("expected non-nil client") + } +} + +func TestWithTokenLeeway_Custom(t *testing.T) { + key := make([]byte, 32) + if _, err := rand.Read(key); err != nil { + t.Fatalf("generate key: %v", err) + } + dir := t.TempDir() + keyPath := filepath.Join(dir, "jwt.key") + if err := os.WriteFile(keyPath, key, 0600); err != nil { + t.Fatalf("write key file: %v", err) + } + + client, err := NewClientWithOpts("http://localhost:6820/", + WithJWTKey(keyPath), + WithUsername("testuser"), + WithTokenLeeway(1*time.Minute), + ) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if client == nil { + t.Fatal("expected non-nil client") + } +} diff --git a/internal/slurm/slurm.go b/internal/slurm/slurm.go index 34d2522..df940d9 100644 --- a/internal/slurm/slurm.go +++ b/internal/slurm/slurm.go @@ -3,7 +3,7 @@ // The client handles authentication via X-SLURM-USER-NAME and X-SLURM-USER-TOKEN // headers, request/response marshaling, and error handling. // -// Basic usage: +// Static token authentication: // // httpClient := &http.Client{ // Transport: &slurm.TokenAuthTransport{ @@ -15,4 +15,14 @@ // if err != nil { // log.Fatal(err) // } +// +// JWT authentication (auto-signed from local key): +// +// client, err := slurm.NewClientWithOpts("http://localhost:6820", +// slurm.WithJWTKey("/etc/slurm/jwt/slurm_jwt.key"), +// slurm.WithUsername("slurmapi"), +// ) +// if err != nil { +// log.Fatal(err) +// } package slurm