feat(auth): add JWTAuthTransport with auto-refresh RoundTrip

Ultraworked with [Sisyphus](https://github.com/code-yeongyu/oh-my-openagent)

Co-authored-by: Sisyphus <clio-agent@sisyphuslabs.ai>
This commit is contained in:
dailz
2026-04-09 10:31:39 +08:00
parent 2dcbfb95b0
commit f8119ff9e5
2 changed files with 304 additions and 0 deletions

View File

@@ -0,0 +1,90 @@
package slurm
import (
"context"
"fmt"
"net/http"
"time"
)
type JWTTransportOption func(*JWTAuthTransport)
type JWTAuthTransport struct {
UserName string
key []byte
tokenCache *TokenCache
ttl time.Duration
leeway time.Duration
Base http.RoundTripper
}
func NewJWTAuthTransport(username string, key []byte, opts ...JWTTransportOption) *JWTAuthTransport {
const (
defaultTTL = 30 * time.Minute
defaultLeeway = 30 * time.Second
)
t := &JWTAuthTransport{
UserName: username,
key: key,
ttl: defaultTTL,
leeway: defaultLeeway,
}
for _, opt := range opts {
opt(t)
}
t.tokenCache = NewTokenCache(
func(ctx context.Context) (string, error) {
return SignJWT(t.key, t.UserName, t.ttl)
},
t.ttl,
t.leeway,
)
return t
}
func (t *JWTAuthTransport) RoundTrip(req *http.Request) (*http.Response, error) {
token, err := t.tokenCache.Token(req.Context())
if err != nil {
return nil, fmt.Errorf("failed to get JWT token: %w", err)
}
req2 := cloneRequest(req)
req2.Header.Set("X-SLURM-USER-NAME", t.UserName)
req2.Header.Set("X-SLURM-USER-TOKEN", token)
return t.transport().RoundTrip(req2)
}
func (t *JWTAuthTransport) Client() *http.Client {
return &http.Client{Transport: t}
}
func (t *JWTAuthTransport) transport() http.RoundTripper {
if t.Base != nil {
return t.Base
}
return http.DefaultTransport
}
func WithTTL(ttl time.Duration) JWTTransportOption {
return func(t *JWTAuthTransport) {
t.ttl = ttl
}
}
func WithLeeway(leeway time.Duration) JWTTransportOption {
return func(t *JWTAuthTransport) {
t.leeway = leeway
}
}
func WithBaseTransport(base http.RoundTripper) JWTTransportOption {
return func(t *JWTAuthTransport) {
t.Base = base
}
}

View File

@@ -0,0 +1,214 @@
package slurm
import (
"context"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"
)
func TestJWTAuthTransport_SetsCorrectHeaders(t *testing.T) {
var gotHeaders http.Header
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
gotHeaders = r.Header.Clone()
w.WriteHeader(http.StatusOK)
}))
defer srv.Close()
key := make([]byte, 32)
for i := range key {
key[i] = byte(i)
}
tr := NewJWTAuthTransport("testuser", key)
client := tr.Client()
req, _ := http.NewRequestWithContext(context.Background(), http.MethodGet, srv.URL, nil)
resp, err := client.Do(req)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
defer resp.Body.Close()
if gotHeaders.Get("X-SLURM-USER-NAME") != "testuser" {
t.Errorf("X-SLURM-USER-NAME = %q, want %q", gotHeaders.Get("X-SLURM-USER-NAME"), "testuser")
}
tokenStr := gotHeaders.Get("X-SLURM-USER-TOKEN")
if tokenStr == "" {
t.Fatal("X-SLURM-USER-TOKEN is empty")
}
parts := strings.Split(tokenStr, ".")
if len(parts) != 3 {
t.Fatalf("JWT should have 3 parts, got %d", len(parts))
}
claimsJSON, err := base64.RawURLEncoding.DecodeString(parts[1])
if err != nil {
t.Fatalf("decode claims: %v", err)
}
var claims JWTClaims
if err := json.Unmarshal(claimsJSON, &claims); err != nil {
t.Fatalf("unmarshal claims: %v", err)
}
if claims.Sun != "testuser" {
t.Errorf("sun claim = %q, want %q", claims.Sun, "testuser")
}
}
func TestJWTAuthTransport_AutoRefreshOnExpiry(t *testing.T) {
var tokens []string
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
tokens = append(tokens, r.Header.Get("X-SLURM-USER-TOKEN"))
w.WriteHeader(http.StatusOK)
}))
defer srv.Close()
key := make([]byte, 32)
for i := range key {
key[i] = byte(i)
}
tr := NewJWTAuthTransport("testuser", key,
WithTTL(1*time.Millisecond),
WithLeeway(0),
)
var callCount int
tr.tokenCache = NewTokenCache(
func(ctx context.Context) (string, error) {
callCount++
return SignJWT(key, fmt.Sprintf("testuser-%d", callCount), 5*time.Minute)
},
1*time.Millisecond,
0,
)
client := tr.Client()
req, _ := http.NewRequestWithContext(context.Background(), http.MethodGet, srv.URL, nil)
resp, err := client.Do(req)
if err != nil {
t.Fatalf("first request: %v", err)
}
resp.Body.Close()
time.Sleep(10 * time.Millisecond)
req2, _ := http.NewRequestWithContext(context.Background(), http.MethodGet, srv.URL, nil)
resp2, err := client.Do(req2)
if err != nil {
t.Fatalf("second request: %v", err)
}
resp2.Body.Close()
if len(tokens) != 2 {
t.Fatalf("expected 2 tokens, got %d", len(tokens))
}
if tokens[0] == tokens[1] {
t.Error("token should have been refreshed after expiry, but got same token")
}
for i, tok := range tokens {
parts := strings.Split(tok, ".")
if len(parts) != 3 {
t.Errorf("token[%d] should have 3 parts, got %d", i, len(parts))
}
}
}
func TestJWTAuthTransport_StaticTokenStillWorks(t *testing.T) {
var gotHeaders http.Header
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
gotHeaders = r.Header.Clone()
w.WriteHeader(http.StatusOK)
}))
defer srv.Close()
tr := &TokenAuthTransport{
UserName: "staticuser",
Token: "static-token-123",
}
client := tr.Client()
req, _ := http.NewRequestWithContext(context.Background(), http.MethodGet, srv.URL, nil)
resp, err := client.Do(req)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
defer resp.Body.Close()
if gotHeaders.Get("X-SLURM-USER-NAME") != "staticuser" {
t.Errorf("X-SLURM-USER-NAME = %q, want %q", gotHeaders.Get("X-SLURM-USER-NAME"), "staticuser")
}
if gotHeaders.Get("X-SLURM-USER-TOKEN") != "static-token-123" {
t.Errorf("X-SLURM-USER-TOKEN = %q, want %q", gotHeaders.Get("X-SLURM-USER-TOKEN"), "static-token-123")
}
}
func TestJWTAuthTransport_SigningError_ReturnsError(t *testing.T) {
expectedErr := errors.New("signing failed")
tr := NewJWTAuthTransport("testuser", nil)
tr.tokenCache = NewTokenCache(
func(ctx context.Context) (string, error) { return "", expectedErr },
30*time.Minute,
30*time.Second,
)
req, _ := http.NewRequestWithContext(context.Background(), http.MethodGet, "http://localhost", nil)
_, err := tr.RoundTrip(req)
if err == nil {
t.Fatal("expected error from RoundTrip, got nil")
}
if !strings.Contains(err.Error(), "failed to get JWT token") {
t.Errorf("error = %q, want containing %q", err.Error(), "failed to get JWT token")
}
}
func TestJWTAuthTransport_CustomBaseTransport(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("X-Custom-Response", "from-base")
w.WriteHeader(http.StatusOK)
}))
defer srv.Close()
key := make([]byte, 32)
for i := range key {
key[i] = byte(i)
}
var custom http.RoundTripper = roundTripperFunc(func(req *http.Request) (*http.Response, error) {
req.Header.Set("X-Custom-Request", "injected")
return http.DefaultTransport.RoundTrip(req)
})
tr := NewJWTAuthTransport("testuser", key, WithBaseTransport(custom))
client := tr.Client()
req, _ := http.NewRequestWithContext(context.Background(), http.MethodGet, srv.URL, nil)
resp, err := client.Do(req)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
defer resp.Body.Close()
if resp.Header.Get("X-Custom-Response") != "from-base" {
t.Error("custom base transport was not used")
}
}
type roundTripperFunc func(*http.Request) (*http.Response, error)
func (f roundTripperFunc) RoundTrip(req *http.Request) (*http.Response, error) {
return f(req)
}