diff --git a/internal/slurm/token_cache.go b/internal/slurm/token_cache.go new file mode 100644 index 0000000..da235cb --- /dev/null +++ b/internal/slurm/token_cache.go @@ -0,0 +1,51 @@ +package slurm + +import ( + "context" + "sync" + "time" +) + +type TokenCache struct { + mu sync.RWMutex + token string + expireAt time.Time + + refresh func(ctx context.Context) (string, error) + ttl time.Duration + leeway time.Duration +} + +func NewTokenCache(refresh func(ctx context.Context) (string, error), ttl time.Duration, leeway time.Duration) *TokenCache { + return &TokenCache{ + refresh: refresh, + ttl: ttl, + leeway: leeway, + } +} + +func (c *TokenCache) Token(ctx context.Context) (string, error) { + c.mu.RLock() + if c.token != "" && time.Now().Before(c.expireAt.Add(-c.leeway)) { + token := c.token + c.mu.RUnlock() + return token, nil + } + c.mu.RUnlock() + + c.mu.Lock() + defer c.mu.Unlock() + + if c.token != "" && time.Now().Before(c.expireAt.Add(-c.leeway)) { + return c.token, nil + } + + token, err := c.refresh(ctx) + if err != nil { + return "", err + } + + c.token = token + c.expireAt = time.Now().Add(c.ttl) + return token, nil +} diff --git a/internal/slurm/token_cache_test.go b/internal/slurm/token_cache_test.go new file mode 100644 index 0000000..885cb65 --- /dev/null +++ b/internal/slurm/token_cache_test.go @@ -0,0 +1,155 @@ +package slurm + +import ( + "context" + "sync" + "sync/atomic" + "testing" + "time" +) + +func TestTokenCache_NewToken_ReturnsCachedToken(t *testing.T) { + var calls atomic.Int32 + refresh := func(ctx context.Context) (string, error) { + calls.Add(1) + return "token-A", nil + } + + cache := NewTokenCache(refresh, 30*time.Minute, 0) + + token1, err := cache.Token(context.Background()) + if err != nil { + t.Fatalf("first Token() error: %v", err) + } + if token1 != "token-A" { + t.Errorf("first Token() = %q, want %q", token1, "token-A") + } + + token2, err := cache.Token(context.Background()) + if err != nil { + t.Fatalf("second Token() error: %v", err) + } + if token2 != "token-A" { + t.Errorf("second Token() = %q, want %q", token2, "token-A") + } + + if got := calls.Load(); got != 1 { + t.Errorf("refresh called %d times, want 1", got) + } +} + +func TestTokenCache_ExpiredToken_TriggersRefresh(t *testing.T) { + var calls atomic.Int32 + tokens := []string{"token-A", "token-B"} + refresh := func(ctx context.Context) (string, error) { + idx := calls.Add(1) - 1 + if int(idx) >= len(tokens) { + return tokens[len(tokens)-1], nil + } + return tokens[idx], nil + } + + cache := NewTokenCache(refresh, 1*time.Millisecond, 0) + + token1, err := cache.Token(context.Background()) + if err != nil { + t.Fatalf("first Token() error: %v", err) + } + if token1 != "token-A" { + t.Errorf("first Token() = %q, want %q", token1, "token-A") + } + + time.Sleep(5 * time.Millisecond) + + token2, err := cache.Token(context.Background()) + if err != nil { + t.Fatalf("second Token() error: %v", err) + } + if token2 != "token-B" { + t.Errorf("second Token() = %q, want %q", token2, "token-B") + } + + if got := calls.Load(); got != 2 { + t.Errorf("refresh called %d times, want 2", got) + } +} + +func TestTokenCache_ConcurrentAccess(t *testing.T) { + var calls atomic.Int32 + refresh := func(ctx context.Context) (string, error) { + calls.Add(1) + time.Sleep(10 * time.Millisecond) + return "concurrent-token", nil + } + + cache := NewTokenCache(refresh, 30*time.Minute, 0) + + const goroutines = 100 + var wg sync.WaitGroup + results := make([]string, goroutines) + errs := make([]error, goroutines) + + wg.Add(goroutines) + for i := 0; i < goroutines; i++ { + go func(idx int) { + defer wg.Done() + results[idx], errs[idx] = cache.Token(context.Background()) + }(i) + } + wg.Wait() + + for i, err := range errs { + if err != nil { + t.Errorf("goroutine %d error: %v", i, err) + } + } + + for i, tok := range results { + if tok != "concurrent-token" { + t.Errorf("goroutine %d got %q, want %q", i, tok, "concurrent-token") + } + } + + if got := calls.Load(); got != 1 { + t.Errorf("refresh called %d times, want 1", got) + } +} + +func TestTokenCache_Leeway_EarlyRefresh(t *testing.T) { + var calls atomic.Int32 + tokens := []string{"token-A", "token-B"} + refresh := func(ctx context.Context) (string, error) { + idx := calls.Add(1) - 1 + if int(idx) >= len(tokens) { + return tokens[len(tokens)-1], nil + } + return tokens[idx], nil + } + + ttl := 100 * time.Millisecond + leeway := 90 * time.Millisecond + cache := NewTokenCache(refresh, ttl, leeway) + + token1, err := cache.Token(context.Background()) + if err != nil { + t.Fatalf("first Token() error: %v", err) + } + if token1 != "token-A" { + t.Errorf("first Token() = %q, want %q", token1, "token-A") + } + + // Token is stale after ttl - leeway = 10ms + time.Sleep(20 * time.Millisecond) + + token2, err := cache.Token(context.Background()) + if err != nil { + t.Fatalf("second Token() error: %v", err) + } + if token2 != "token-B" { + t.Errorf("second Token() = %q, want %q", token2, "token-B") + } + + if got := calls.Load(); got != 2 { + t.Errorf("refresh called %d times, want 2 (early refresh via leeway)", got) + } +} diff --git a/internal/slurm/token_source.go b/internal/slurm/token_source.go new file mode 100644 index 0000000..aba981c --- /dev/null +++ b/internal/slurm/token_source.go @@ -0,0 +1,8 @@ +package slurm + +import "context" + +// TokenSource provides tokens for authentication. +type TokenSource interface { + Token(ctx context.Context) (string, error) +}