feat(auth): add JWTAuthTransport with auto-refresh RoundTrip

Ultraworked with [Sisyphus](https://github.com/code-yeongyu/oh-my-openagent)

Co-authored-by: Sisyphus <clio-agent@sisyphuslabs.ai>
This commit is contained in:
dailz
2026-04-09 10:31:39 +08:00
parent 2dcbfb95b0
commit f8119ff9e5
2 changed files with 304 additions and 0 deletions

View File

@@ -0,0 +1,214 @@
package slurm
import (
"context"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"
)
func TestJWTAuthTransport_SetsCorrectHeaders(t *testing.T) {
var gotHeaders http.Header
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
gotHeaders = r.Header.Clone()
w.WriteHeader(http.StatusOK)
}))
defer srv.Close()
key := make([]byte, 32)
for i := range key {
key[i] = byte(i)
}
tr := NewJWTAuthTransport("testuser", key)
client := tr.Client()
req, _ := http.NewRequestWithContext(context.Background(), http.MethodGet, srv.URL, nil)
resp, err := client.Do(req)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
defer resp.Body.Close()
if gotHeaders.Get("X-SLURM-USER-NAME") != "testuser" {
t.Errorf("X-SLURM-USER-NAME = %q, want %q", gotHeaders.Get("X-SLURM-USER-NAME"), "testuser")
}
tokenStr := gotHeaders.Get("X-SLURM-USER-TOKEN")
if tokenStr == "" {
t.Fatal("X-SLURM-USER-TOKEN is empty")
}
parts := strings.Split(tokenStr, ".")
if len(parts) != 3 {
t.Fatalf("JWT should have 3 parts, got %d", len(parts))
}
claimsJSON, err := base64.RawURLEncoding.DecodeString(parts[1])
if err != nil {
t.Fatalf("decode claims: %v", err)
}
var claims JWTClaims
if err := json.Unmarshal(claimsJSON, &claims); err != nil {
t.Fatalf("unmarshal claims: %v", err)
}
if claims.Sun != "testuser" {
t.Errorf("sun claim = %q, want %q", claims.Sun, "testuser")
}
}
func TestJWTAuthTransport_AutoRefreshOnExpiry(t *testing.T) {
var tokens []string
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
tokens = append(tokens, r.Header.Get("X-SLURM-USER-TOKEN"))
w.WriteHeader(http.StatusOK)
}))
defer srv.Close()
key := make([]byte, 32)
for i := range key {
key[i] = byte(i)
}
tr := NewJWTAuthTransport("testuser", key,
WithTTL(1*time.Millisecond),
WithLeeway(0),
)
var callCount int
tr.tokenCache = NewTokenCache(
func(ctx context.Context) (string, error) {
callCount++
return SignJWT(key, fmt.Sprintf("testuser-%d", callCount), 5*time.Minute)
},
1*time.Millisecond,
0,
)
client := tr.Client()
req, _ := http.NewRequestWithContext(context.Background(), http.MethodGet, srv.URL, nil)
resp, err := client.Do(req)
if err != nil {
t.Fatalf("first request: %v", err)
}
resp.Body.Close()
time.Sleep(10 * time.Millisecond)
req2, _ := http.NewRequestWithContext(context.Background(), http.MethodGet, srv.URL, nil)
resp2, err := client.Do(req2)
if err != nil {
t.Fatalf("second request: %v", err)
}
resp2.Body.Close()
if len(tokens) != 2 {
t.Fatalf("expected 2 tokens, got %d", len(tokens))
}
if tokens[0] == tokens[1] {
t.Error("token should have been refreshed after expiry, but got same token")
}
for i, tok := range tokens {
parts := strings.Split(tok, ".")
if len(parts) != 3 {
t.Errorf("token[%d] should have 3 parts, got %d", i, len(parts))
}
}
}
func TestJWTAuthTransport_StaticTokenStillWorks(t *testing.T) {
var gotHeaders http.Header
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
gotHeaders = r.Header.Clone()
w.WriteHeader(http.StatusOK)
}))
defer srv.Close()
tr := &TokenAuthTransport{
UserName: "staticuser",
Token: "static-token-123",
}
client := tr.Client()
req, _ := http.NewRequestWithContext(context.Background(), http.MethodGet, srv.URL, nil)
resp, err := client.Do(req)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
defer resp.Body.Close()
if gotHeaders.Get("X-SLURM-USER-NAME") != "staticuser" {
t.Errorf("X-SLURM-USER-NAME = %q, want %q", gotHeaders.Get("X-SLURM-USER-NAME"), "staticuser")
}
if gotHeaders.Get("X-SLURM-USER-TOKEN") != "static-token-123" {
t.Errorf("X-SLURM-USER-TOKEN = %q, want %q", gotHeaders.Get("X-SLURM-USER-TOKEN"), "static-token-123")
}
}
func TestJWTAuthTransport_SigningError_ReturnsError(t *testing.T) {
expectedErr := errors.New("signing failed")
tr := NewJWTAuthTransport("testuser", nil)
tr.tokenCache = NewTokenCache(
func(ctx context.Context) (string, error) { return "", expectedErr },
30*time.Minute,
30*time.Second,
)
req, _ := http.NewRequestWithContext(context.Background(), http.MethodGet, "http://localhost", nil)
_, err := tr.RoundTrip(req)
if err == nil {
t.Fatal("expected error from RoundTrip, got nil")
}
if !strings.Contains(err.Error(), "failed to get JWT token") {
t.Errorf("error = %q, want containing %q", err.Error(), "failed to get JWT token")
}
}
func TestJWTAuthTransport_CustomBaseTransport(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("X-Custom-Response", "from-base")
w.WriteHeader(http.StatusOK)
}))
defer srv.Close()
key := make([]byte, 32)
for i := range key {
key[i] = byte(i)
}
var custom http.RoundTripper = roundTripperFunc(func(req *http.Request) (*http.Response, error) {
req.Header.Set("X-Custom-Request", "injected")
return http.DefaultTransport.RoundTrip(req)
})
tr := NewJWTAuthTransport("testuser", key, WithBaseTransport(custom))
client := tr.Client()
req, _ := http.NewRequestWithContext(context.Background(), http.MethodGet, srv.URL, nil)
resp, err := client.Do(req)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
defer resp.Body.Close()
if resp.Header.Get("X-Custom-Response") != "from-base" {
t.Error("custom base transport was not used")
}
}
type roundTripperFunc func(*http.Request) (*http.Response, error)
func (f roundTripperFunc) RoundTrip(req *http.Request) (*http.Response, error) {
return f(req)
}