Compare commits

..

4 Commits

Author SHA1 Message Date
dailz
246c19c052 feat(client): add functional option pattern for JWT auth config
Ultraworked with [Sisyphus](https://github.com/code-yeongyu/oh-my-openagent)

Co-authored-by: Sisyphus <clio-agent@sisyphuslabs.ai>
2026-04-09 10:33:31 +08:00
dailz
f8119ff9e5 feat(auth): add JWTAuthTransport with auto-refresh RoundTrip
Ultraworked with [Sisyphus](https://github.com/code-yeongyu/oh-my-openagent)

Co-authored-by: Sisyphus <clio-agent@sisyphuslabs.ai>
2026-04-09 10:31:39 +08:00
dailz
2dcbfb95b0 feat(auth): add token cache with thread-safe auto-refresh
Ultraworked with [Sisyphus](https://github.com/code-yeongyu/oh-my-openagent)

Co-authored-by: Sisyphus <clio-agent@sisyphuslabs.ai>
2026-04-09 10:31:27 +08:00
dailz
49cbea948a feat(jwt): add HS256 JWT signing with stdlib-only implementation
Ultraworked with [Sisyphus](https://github.com/code-yeongyu/oh-my-openagent)

Co-authored-by: Sisyphus <clio-agent@sisyphuslabs.ai>
2026-04-09 10:31:16 +08:00
10 changed files with 1014 additions and 1 deletions

62
internal/slurm/jwt.go Normal file
View File

@@ -0,0 +1,62 @@
package slurm
import (
"crypto/hmac"
"crypto/sha256"
"encoding/base64"
"encoding/json"
"fmt"
"os"
"time"
)
type JWTClaims struct {
Sun string `json:"sun"`
IAT int64 `json:"iat"`
EXP int64 `json:"exp"`
}
func SignJWT(key []byte, username string, lifespan time.Duration) (string, error) {
if username == "" {
return "", fmt.Errorf("username must not be empty")
}
now := time.Now()
header := map[string]string{"alg": "HS256", "typ": "JWT"}
claims := JWTClaims{
Sun: username,
IAT: now.Unix(),
EXP: now.Add(lifespan).Unix(),
}
headerJSON, err := json.Marshal(header)
if err != nil {
return "", fmt.Errorf("marshal header: %w", err)
}
claimsJSON, err := json.Marshal(claims)
if err != nil {
return "", fmt.Errorf("marshal claims: %w", err)
}
enc := base64.RawURLEncoding
headerEnc := enc.EncodeToString(headerJSON)
claimsEnc := enc.EncodeToString(claimsJSON)
signingInput := headerEnc + "." + claimsEnc
mac := hmac.New(sha256.New, key)
mac.Write([]byte(signingInput))
sig := enc.EncodeToString(mac.Sum(nil))
return signingInput + "." + sig, nil
}
func ReadJWTKey(path string) ([]byte, error) {
data, err := os.ReadFile(path)
if err != nil {
return nil, err
}
if len(data) < 16 {
return nil, fmt.Errorf("key must be at least 16 bytes, got %d", len(data))
}
return data, nil
}

169
internal/slurm/jwt_test.go Normal file
View File

@@ -0,0 +1,169 @@
package slurm
import (
"crypto/hmac"
"crypto/sha256"
"encoding/base64"
"encoding/json"
"errors"
"os"
"strings"
"testing"
"time"
)
func TestSignJWT_ValidKey_ProducesValidToken(t *testing.T) {
key := []byte("0123456789abcdef0123456789abcdef") // 32 bytes
username := "testuser"
lifespan := 1 * time.Hour
token, err := SignJWT(key, username, lifespan)
if err != nil {
t.Fatalf("SignJWT returned error: %v", err)
}
parts := strings.Split(token, ".")
if len(parts) != 3 {
t.Fatalf("expected 3 parts, got %d", len(parts))
}
// Decode header
headerBytes, err := base64.RawURLEncoding.DecodeString(parts[0])
if err != nil {
t.Fatalf("decode header: %v", err)
}
var header map[string]string
if err := json.Unmarshal(headerBytes, &header); err != nil {
t.Fatalf("unmarshal header: %v", err)
}
if header["alg"] != "HS256" || header["typ"] != "JWT" {
t.Errorf("header = %v, want alg=HS256 typ=JWT", header)
}
// Decode payload
payloadBytes, err := base64.RawURLEncoding.DecodeString(parts[1])
if err != nil {
t.Fatalf("decode payload: %v", err)
}
var claims JWTClaims
if err := json.Unmarshal(payloadBytes, &claims); err != nil {
t.Fatalf("unmarshal payload: %v", err)
}
if claims.Sun != username {
t.Errorf("sun = %q, want %q", claims.Sun, username)
}
if claims.EXP-claims.IAT != 3600 {
t.Errorf("exp - iat = %d, want 3600", claims.EXP-claims.IAT)
}
// Verify signature
signingInput := parts[0] + "." + parts[1]
mac := hmac.New(sha256.New, key)
mac.Write([]byte(signingInput))
expectedSig := base64.RawURLEncoding.EncodeToString(mac.Sum(nil))
if parts[2] != expectedSig {
t.Errorf("signature mismatch: got %q, want %q", parts[2], expectedSig)
}
}
func TestSignJWT_UsesRawURLEncoding(t *testing.T) {
key := []byte("0123456789abcdef0123456789abcdef")
token, err := SignJWT(key, "testuser", 1*time.Hour)
if err != nil {
t.Fatalf("SignJWT returned error: %v", err)
}
for i, part := range strings.Split(token, ".") {
if strings.Contains(part, "=") {
t.Errorf("part %d contains padding '=': %q", i, part)
}
}
}
func TestSignJWT_ExpiredToken(t *testing.T) {
key := []byte("0123456789abcdef0123456789abcdef")
token, err := SignJWT(key, "testuser", -1*time.Hour)
if err != nil {
t.Fatalf("SignJWT returned error: %v", err)
}
parts := strings.Split(token, ".")
payloadBytes, err := base64.RawURLEncoding.DecodeString(parts[1])
if err != nil {
t.Fatalf("decode payload: %v", err)
}
var claims JWTClaims
if err := json.Unmarshal(payloadBytes, &claims); err != nil {
t.Fatalf("unmarshal payload: %v", err)
}
now := time.Now().Unix()
if claims.EXP >= now {
t.Errorf("exp = %d, expected < now (%d)", claims.EXP, now)
}
}
func TestSignJWT_EmptyUsername_Error(t *testing.T) {
key := []byte("0123456789abcdef0123456789abcdef")
_, err := SignJWT(key, "", 1*time.Hour)
if err == nil {
t.Fatal("expected error for empty username, got nil")
}
}
func TestReadJWTKey_ValidFile(t *testing.T) {
keyData := make([]byte, 32)
for i := range keyData {
keyData[i] = byte(i)
}
f, err := os.CreateTemp("", "jwtkey-*")
if err != nil {
t.Fatal(err)
}
defer os.Remove(f.Name())
if _, err := f.Write(keyData); err != nil {
t.Fatal(err)
}
f.Close()
got, err := ReadJWTKey(f.Name())
if err != nil {
t.Fatalf("ReadJWTKey returned error: %v", err)
}
if string(got) != string(keyData) {
t.Errorf("key mismatch: got %d bytes, want %d bytes", len(got), len(keyData))
}
}
func TestReadJWTKey_FileNotFound(t *testing.T) {
_, err := ReadJWTKey("/nonexistent/path/keyfile")
if err == nil {
t.Fatal("expected error for nonexistent file, got nil")
}
if !errors.Is(err, os.ErrNotExist) {
t.Errorf("error = %v, want os.ErrNotExist", err)
}
}
func TestReadJWTKey_InvalidSize(t *testing.T) {
f, err := os.CreateTemp("", "jwtkey-*")
if err != nil {
t.Fatal(err)
}
defer os.Remove(f.Name())
if _, err := f.Write([]byte("short")); err != nil {
t.Fatal(err)
}
f.Close()
_, err = ReadJWTKey(f.Name())
if err == nil {
t.Fatal("expected error for short key, got nil")
}
if !strings.Contains(err.Error(), "key must be") {
t.Errorf("error = %q, want message containing 'key must be'", err.Error())
}
}

View File

@@ -0,0 +1,90 @@
package slurm
import (
"context"
"fmt"
"net/http"
"time"
)
type JWTTransportOption func(*JWTAuthTransport)
type JWTAuthTransport struct {
UserName string
key []byte
tokenCache *TokenCache
ttl time.Duration
leeway time.Duration
Base http.RoundTripper
}
func NewJWTAuthTransport(username string, key []byte, opts ...JWTTransportOption) *JWTAuthTransport {
const (
defaultTTL = 30 * time.Minute
defaultLeeway = 30 * time.Second
)
t := &JWTAuthTransport{
UserName: username,
key: key,
ttl: defaultTTL,
leeway: defaultLeeway,
}
for _, opt := range opts {
opt(t)
}
t.tokenCache = NewTokenCache(
func(ctx context.Context) (string, error) {
return SignJWT(t.key, t.UserName, t.ttl)
},
t.ttl,
t.leeway,
)
return t
}
func (t *JWTAuthTransport) RoundTrip(req *http.Request) (*http.Response, error) {
token, err := t.tokenCache.Token(req.Context())
if err != nil {
return nil, fmt.Errorf("failed to get JWT token: %w", err)
}
req2 := cloneRequest(req)
req2.Header.Set("X-SLURM-USER-NAME", t.UserName)
req2.Header.Set("X-SLURM-USER-TOKEN", token)
return t.transport().RoundTrip(req2)
}
func (t *JWTAuthTransport) Client() *http.Client {
return &http.Client{Transport: t}
}
func (t *JWTAuthTransport) transport() http.RoundTripper {
if t.Base != nil {
return t.Base
}
return http.DefaultTransport
}
func WithTTL(ttl time.Duration) JWTTransportOption {
return func(t *JWTAuthTransport) {
t.ttl = ttl
}
}
func WithLeeway(leeway time.Duration) JWTTransportOption {
return func(t *JWTAuthTransport) {
t.leeway = leeway
}
}
func WithBaseTransport(base http.RoundTripper) JWTTransportOption {
return func(t *JWTAuthTransport) {
t.Base = base
}
}

View File

@@ -0,0 +1,214 @@
package slurm
import (
"context"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"
)
func TestJWTAuthTransport_SetsCorrectHeaders(t *testing.T) {
var gotHeaders http.Header
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
gotHeaders = r.Header.Clone()
w.WriteHeader(http.StatusOK)
}))
defer srv.Close()
key := make([]byte, 32)
for i := range key {
key[i] = byte(i)
}
tr := NewJWTAuthTransport("testuser", key)
client := tr.Client()
req, _ := http.NewRequestWithContext(context.Background(), http.MethodGet, srv.URL, nil)
resp, err := client.Do(req)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
defer resp.Body.Close()
if gotHeaders.Get("X-SLURM-USER-NAME") != "testuser" {
t.Errorf("X-SLURM-USER-NAME = %q, want %q", gotHeaders.Get("X-SLURM-USER-NAME"), "testuser")
}
tokenStr := gotHeaders.Get("X-SLURM-USER-TOKEN")
if tokenStr == "" {
t.Fatal("X-SLURM-USER-TOKEN is empty")
}
parts := strings.Split(tokenStr, ".")
if len(parts) != 3 {
t.Fatalf("JWT should have 3 parts, got %d", len(parts))
}
claimsJSON, err := base64.RawURLEncoding.DecodeString(parts[1])
if err != nil {
t.Fatalf("decode claims: %v", err)
}
var claims JWTClaims
if err := json.Unmarshal(claimsJSON, &claims); err != nil {
t.Fatalf("unmarshal claims: %v", err)
}
if claims.Sun != "testuser" {
t.Errorf("sun claim = %q, want %q", claims.Sun, "testuser")
}
}
func TestJWTAuthTransport_AutoRefreshOnExpiry(t *testing.T) {
var tokens []string
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
tokens = append(tokens, r.Header.Get("X-SLURM-USER-TOKEN"))
w.WriteHeader(http.StatusOK)
}))
defer srv.Close()
key := make([]byte, 32)
for i := range key {
key[i] = byte(i)
}
tr := NewJWTAuthTransport("testuser", key,
WithTTL(1*time.Millisecond),
WithLeeway(0),
)
var callCount int
tr.tokenCache = NewTokenCache(
func(ctx context.Context) (string, error) {
callCount++
return SignJWT(key, fmt.Sprintf("testuser-%d", callCount), 5*time.Minute)
},
1*time.Millisecond,
0,
)
client := tr.Client()
req, _ := http.NewRequestWithContext(context.Background(), http.MethodGet, srv.URL, nil)
resp, err := client.Do(req)
if err != nil {
t.Fatalf("first request: %v", err)
}
resp.Body.Close()
time.Sleep(10 * time.Millisecond)
req2, _ := http.NewRequestWithContext(context.Background(), http.MethodGet, srv.URL, nil)
resp2, err := client.Do(req2)
if err != nil {
t.Fatalf("second request: %v", err)
}
resp2.Body.Close()
if len(tokens) != 2 {
t.Fatalf("expected 2 tokens, got %d", len(tokens))
}
if tokens[0] == tokens[1] {
t.Error("token should have been refreshed after expiry, but got same token")
}
for i, tok := range tokens {
parts := strings.Split(tok, ".")
if len(parts) != 3 {
t.Errorf("token[%d] should have 3 parts, got %d", i, len(parts))
}
}
}
func TestJWTAuthTransport_StaticTokenStillWorks(t *testing.T) {
var gotHeaders http.Header
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
gotHeaders = r.Header.Clone()
w.WriteHeader(http.StatusOK)
}))
defer srv.Close()
tr := &TokenAuthTransport{
UserName: "staticuser",
Token: "static-token-123",
}
client := tr.Client()
req, _ := http.NewRequestWithContext(context.Background(), http.MethodGet, srv.URL, nil)
resp, err := client.Do(req)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
defer resp.Body.Close()
if gotHeaders.Get("X-SLURM-USER-NAME") != "staticuser" {
t.Errorf("X-SLURM-USER-NAME = %q, want %q", gotHeaders.Get("X-SLURM-USER-NAME"), "staticuser")
}
if gotHeaders.Get("X-SLURM-USER-TOKEN") != "static-token-123" {
t.Errorf("X-SLURM-USER-TOKEN = %q, want %q", gotHeaders.Get("X-SLURM-USER-TOKEN"), "static-token-123")
}
}
func TestJWTAuthTransport_SigningError_ReturnsError(t *testing.T) {
expectedErr := errors.New("signing failed")
tr := NewJWTAuthTransport("testuser", nil)
tr.tokenCache = NewTokenCache(
func(ctx context.Context) (string, error) { return "", expectedErr },
30*time.Minute,
30*time.Second,
)
req, _ := http.NewRequestWithContext(context.Background(), http.MethodGet, "http://localhost", nil)
_, err := tr.RoundTrip(req)
if err == nil {
t.Fatal("expected error from RoundTrip, got nil")
}
if !strings.Contains(err.Error(), "failed to get JWT token") {
t.Errorf("error = %q, want containing %q", err.Error(), "failed to get JWT token")
}
}
func TestJWTAuthTransport_CustomBaseTransport(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("X-Custom-Response", "from-base")
w.WriteHeader(http.StatusOK)
}))
defer srv.Close()
key := make([]byte, 32)
for i := range key {
key[i] = byte(i)
}
var custom http.RoundTripper = roundTripperFunc(func(req *http.Request) (*http.Response, error) {
req.Header.Set("X-Custom-Request", "injected")
return http.DefaultTransport.RoundTrip(req)
})
tr := NewJWTAuthTransport("testuser", key, WithBaseTransport(custom))
client := tr.Client()
req, _ := http.NewRequestWithContext(context.Background(), http.MethodGet, srv.URL, nil)
resp, err := client.Do(req)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
defer resp.Body.Close()
if resp.Header.Get("X-Custom-Response") != "from-base" {
t.Error("custom base transport was not used")
}
}
type roundTripperFunc func(*http.Request) (*http.Response, error)
func (f roundTripperFunc) RoundTrip(req *http.Request) (*http.Response, error) {
return f(req)
}

100
internal/slurm/options.go Normal file
View File

@@ -0,0 +1,100 @@
package slurm
import (
"fmt"
"net/http"
"time"
)
// ClientOption configures a Client via functional options.
type ClientOption func(*clientConfig) error
type clientConfig struct {
jwtKeyPath string
username string
ttl time.Duration
leeway time.Duration
httpClient *http.Client
}
func defaultClientConfig() *clientConfig {
return &clientConfig{
ttl: 30 * time.Minute,
leeway: 30 * time.Second,
}
}
// WithJWTKey specifies the path to the JWT key file.
func WithJWTKey(path string) ClientOption {
return func(c *clientConfig) error {
c.jwtKeyPath = path
return nil
}
}
// WithUsername specifies the Slurm username for JWT authentication.
func WithUsername(username string) ClientOption {
return func(c *clientConfig) error {
c.username = username
return nil
}
}
// WithTokenTTL sets the JWT token time-to-live (default: 30 minutes).
func WithTokenTTL(ttl time.Duration) ClientOption {
return func(c *clientConfig) error {
c.ttl = ttl
return nil
}
}
// WithTokenLeeway sets the JWT token refresh leeway (default: 30 seconds).
func WithTokenLeeway(leeway time.Duration) ClientOption {
return func(c *clientConfig) error {
c.leeway = leeway
return nil
}
}
// WithHTTPClient specifies a custom HTTP client.
func WithHTTPClient(client *http.Client) ClientOption {
return func(c *clientConfig) error {
c.httpClient = client
return nil
}
}
// NewClientWithOpts creates a new Slurm API client using functional options.
// If WithJWTKey and WithUsername are provided, JWT authentication is configured
// automatically. If no JWT options are provided, http.DefaultClient is used.
func NewClientWithOpts(baseURL string, opts ...ClientOption) (*Client, error) {
cfg := defaultClientConfig()
for _, opt := range opts {
if err := opt(cfg); err != nil {
return nil, err
}
}
var httpClient *http.Client
if cfg.jwtKeyPath != "" && cfg.username != "" {
key, err := ReadJWTKey(cfg.jwtKeyPath)
if err != nil {
return nil, fmt.Errorf("read JWT key: %w", err)
}
transportOpts := []JWTTransportOption{
WithTTL(cfg.ttl),
WithLeeway(cfg.leeway),
}
tr := NewJWTAuthTransport(cfg.username, key, transportOpts...)
httpClient = tr.Client()
} else if cfg.httpClient != nil {
httpClient = cfg.httpClient
} else {
httpClient = http.DefaultClient
}
return NewClient(baseURL, httpClient)
}

View File

@@ -0,0 +1,154 @@
package slurm
import (
"crypto/rand"
"net/http"
"os"
"path/filepath"
"testing"
"time"
)
func TestNewClientWithOpts_JWTKey_Success(t *testing.T) {
key := make([]byte, 32)
if _, err := rand.Read(key); err != nil {
t.Fatalf("generate key: %v", err)
}
dir := t.TempDir()
keyPath := filepath.Join(dir, "jwt.key")
if err := os.WriteFile(keyPath, key, 0600); err != nil {
t.Fatalf("write key file: %v", err)
}
client, err := NewClientWithOpts("http://localhost:6820/",
WithJWTKey(keyPath),
WithUsername("testuser"),
)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if client == nil {
t.Fatal("expected non-nil client")
}
transport, ok := client.client.Transport.(*JWTAuthTransport)
if !ok {
t.Fatalf("expected *JWTAuthTransport, got %T", client.client.Transport)
}
if transport.UserName != "testuser" {
t.Errorf("expected username %q, got %q", "testuser", transport.UserName)
}
}
func TestNewClientWithOpts_InvalidKeyPath_Error(t *testing.T) {
client, err := NewClientWithOpts("http://localhost:6820/",
WithJWTKey("/nonexistent/key"),
WithUsername("testuser"),
)
if err == nil {
t.Fatal("expected error for invalid key path, got nil")
}
if client != nil {
t.Fatal("expected nil client on error")
}
}
func TestNewClientWithOpts_BackwardCompatible(t *testing.T) {
client, err := NewClientWithOpts("http://localhost:6820/")
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if client == nil {
t.Fatal("expected non-nil client")
}
if client.client != http.DefaultClient {
t.Error("expected http.DefaultClient when no options provided")
}
}
func TestNewClientWithOpts_AllServicesInitialized(t *testing.T) {
client, err := NewClientWithOpts("http://localhost:6820/")
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
services := []struct {
name string
svc interface{}
}{
{"Jobs", client.Jobs},
{"Nodes", client.Nodes},
{"Partitions", client.Partitions},
{"Reservations", client.Reservations},
{"Diag", client.Diag},
{"Ping", client.Ping},
{"Licenses", client.Licenses},
{"Reconfigure", client.Reconfigure},
{"Shares", client.Shares},
{"SlurmdbDiag", client.SlurmdbDiag},
{"SlurmdbConfig", client.SlurmdbConfig},
{"SlurmdbTres", client.SlurmdbTres},
{"SlurmdbQos", client.SlurmdbQos},
{"SlurmdbAssocs", client.SlurmdbAssocs},
{"SlurmdbInstances", client.SlurmdbInstances},
{"SlurmdbUsers", client.SlurmdbUsers},
{"SlurmdbClusters", client.SlurmdbClusters},
{"SlurmdbWckeys", client.SlurmdbWckeys},
{"SlurmdbAccounts", client.SlurmdbAccounts},
{"SlurmdbJobs", client.SlurmdbJobs},
}
for _, s := range services {
if s.svc == nil {
t.Errorf("%s service is nil", s.name)
}
}
}
func TestWithTokenTTL_Custom(t *testing.T) {
key := make([]byte, 32)
if _, err := rand.Read(key); err != nil {
t.Fatalf("generate key: %v", err)
}
dir := t.TempDir()
keyPath := filepath.Join(dir, "jwt.key")
if err := os.WriteFile(keyPath, key, 0600); err != nil {
t.Fatalf("write key file: %v", err)
}
client, err := NewClientWithOpts("http://localhost:6820/",
WithJWTKey(keyPath),
WithUsername("testuser"),
WithTokenTTL(1*time.Hour),
)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if client == nil {
t.Fatal("expected non-nil client")
}
}
func TestWithTokenLeeway_Custom(t *testing.T) {
key := make([]byte, 32)
if _, err := rand.Read(key); err != nil {
t.Fatalf("generate key: %v", err)
}
dir := t.TempDir()
keyPath := filepath.Join(dir, "jwt.key")
if err := os.WriteFile(keyPath, key, 0600); err != nil {
t.Fatalf("write key file: %v", err)
}
client, err := NewClientWithOpts("http://localhost:6820/",
WithJWTKey(keyPath),
WithUsername("testuser"),
WithTokenLeeway(1*time.Minute),
)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if client == nil {
t.Fatal("expected non-nil client")
}
}

View File

@@ -3,7 +3,7 @@
// The client handles authentication via X-SLURM-USER-NAME and X-SLURM-USER-TOKEN // The client handles authentication via X-SLURM-USER-NAME and X-SLURM-USER-TOKEN
// headers, request/response marshaling, and error handling. // headers, request/response marshaling, and error handling.
// //
// Basic usage: // Static token authentication:
// //
// httpClient := &http.Client{ // httpClient := &http.Client{
// Transport: &slurm.TokenAuthTransport{ // Transport: &slurm.TokenAuthTransport{
@@ -15,4 +15,14 @@
// if err != nil { // if err != nil {
// log.Fatal(err) // log.Fatal(err)
// } // }
//
// JWT authentication (auto-signed from local key):
//
// client, err := slurm.NewClientWithOpts("http://localhost:6820",
// slurm.WithJWTKey("/etc/slurm/jwt/slurm_jwt.key"),
// slurm.WithUsername("slurmapi"),
// )
// if err != nil {
// log.Fatal(err)
// }
package slurm package slurm

View File

@@ -0,0 +1,51 @@
package slurm
import (
"context"
"sync"
"time"
)
type TokenCache struct {
mu sync.RWMutex
token string
expireAt time.Time
refresh func(ctx context.Context) (string, error)
ttl time.Duration
leeway time.Duration
}
func NewTokenCache(refresh func(ctx context.Context) (string, error), ttl time.Duration, leeway time.Duration) *TokenCache {
return &TokenCache{
refresh: refresh,
ttl: ttl,
leeway: leeway,
}
}
func (c *TokenCache) Token(ctx context.Context) (string, error) {
c.mu.RLock()
if c.token != "" && time.Now().Before(c.expireAt.Add(-c.leeway)) {
token := c.token
c.mu.RUnlock()
return token, nil
}
c.mu.RUnlock()
c.mu.Lock()
defer c.mu.Unlock()
if c.token != "" && time.Now().Before(c.expireAt.Add(-c.leeway)) {
return c.token, nil
}
token, err := c.refresh(ctx)
if err != nil {
return "", err
}
c.token = token
c.expireAt = time.Now().Add(c.ttl)
return token, nil
}

View File

@@ -0,0 +1,155 @@
package slurm
import (
"context"
"sync"
"sync/atomic"
"testing"
"time"
)
func TestTokenCache_NewToken_ReturnsCachedToken(t *testing.T) {
var calls atomic.Int32
refresh := func(ctx context.Context) (string, error) {
calls.Add(1)
return "token-A", nil
}
cache := NewTokenCache(refresh, 30*time.Minute, 0)
token1, err := cache.Token(context.Background())
if err != nil {
t.Fatalf("first Token() error: %v", err)
}
if token1 != "token-A" {
t.Errorf("first Token() = %q, want %q", token1, "token-A")
}
token2, err := cache.Token(context.Background())
if err != nil {
t.Fatalf("second Token() error: %v", err)
}
if token2 != "token-A" {
t.Errorf("second Token() = %q, want %q", token2, "token-A")
}
if got := calls.Load(); got != 1 {
t.Errorf("refresh called %d times, want 1", got)
}
}
func TestTokenCache_ExpiredToken_TriggersRefresh(t *testing.T) {
var calls atomic.Int32
tokens := []string{"token-A", "token-B"}
refresh := func(ctx context.Context) (string, error) {
idx := calls.Add(1) - 1
if int(idx) >= len(tokens) {
return tokens[len(tokens)-1], nil
}
return tokens[idx], nil
}
cache := NewTokenCache(refresh, 1*time.Millisecond, 0)
token1, err := cache.Token(context.Background())
if err != nil {
t.Fatalf("first Token() error: %v", err)
}
if token1 != "token-A" {
t.Errorf("first Token() = %q, want %q", token1, "token-A")
}
time.Sleep(5 * time.Millisecond)
token2, err := cache.Token(context.Background())
if err != nil {
t.Fatalf("second Token() error: %v", err)
}
if token2 != "token-B" {
t.Errorf("second Token() = %q, want %q", token2, "token-B")
}
if got := calls.Load(); got != 2 {
t.Errorf("refresh called %d times, want 2", got)
}
}
func TestTokenCache_ConcurrentAccess(t *testing.T) {
var calls atomic.Int32
refresh := func(ctx context.Context) (string, error) {
calls.Add(1)
time.Sleep(10 * time.Millisecond)
return "concurrent-token", nil
}
cache := NewTokenCache(refresh, 30*time.Minute, 0)
const goroutines = 100
var wg sync.WaitGroup
results := make([]string, goroutines)
errs := make([]error, goroutines)
wg.Add(goroutines)
for i := 0; i < goroutines; i++ {
go func(idx int) {
defer wg.Done()
results[idx], errs[idx] = cache.Token(context.Background())
}(i)
}
wg.Wait()
for i, err := range errs {
if err != nil {
t.Errorf("goroutine %d error: %v", i, err)
}
}
for i, tok := range results {
if tok != "concurrent-token" {
t.Errorf("goroutine %d got %q, want %q", i, tok, "concurrent-token")
}
}
if got := calls.Load(); got != 1 {
t.Errorf("refresh called %d times, want 1", got)
}
}
func TestTokenCache_Leeway_EarlyRefresh(t *testing.T) {
var calls atomic.Int32
tokens := []string{"token-A", "token-B"}
refresh := func(ctx context.Context) (string, error) {
idx := calls.Add(1) - 1
if int(idx) >= len(tokens) {
return tokens[len(tokens)-1], nil
}
return tokens[idx], nil
}
ttl := 100 * time.Millisecond
leeway := 90 * time.Millisecond
cache := NewTokenCache(refresh, ttl, leeway)
token1, err := cache.Token(context.Background())
if err != nil {
t.Fatalf("first Token() error: %v", err)
}
if token1 != "token-A" {
t.Errorf("first Token() = %q, want %q", token1, "token-A")
}
// Token is stale after ttl - leeway = 10ms
time.Sleep(20 * time.Millisecond)
token2, err := cache.Token(context.Background())
if err != nil {
t.Fatalf("second Token() error: %v", err)
}
if token2 != "token-B" {
t.Errorf("second Token() = %q, want %q", token2, "token-B")
}
if got := calls.Load(); got != 2 {
t.Errorf("refresh called %d times, want 2 (early refresh via leeway)", got)
}
}

View File

@@ -0,0 +1,8 @@
package slurm
import "context"
// TokenSource provides tokens for authentication.
type TokenSource interface {
Token(ctx context.Context) (string, error)
}