feat(auth): add JWTAuthTransport with auto-refresh RoundTrip
Ultraworked with [Sisyphus](https://github.com/code-yeongyu/oh-my-openagent) Co-authored-by: Sisyphus <clio-agent@sisyphuslabs.ai>
This commit is contained in:
214
internal/slurm/jwt_transport_test.go
Normal file
214
internal/slurm/jwt_transport_test.go
Normal file
@@ -0,0 +1,214 @@
|
||||
package slurm
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestJWTAuthTransport_SetsCorrectHeaders(t *testing.T) {
|
||||
var gotHeaders http.Header
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
gotHeaders = r.Header.Clone()
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
key := make([]byte, 32)
|
||||
for i := range key {
|
||||
key[i] = byte(i)
|
||||
}
|
||||
|
||||
tr := NewJWTAuthTransport("testuser", key)
|
||||
client := tr.Client()
|
||||
|
||||
req, _ := http.NewRequestWithContext(context.Background(), http.MethodGet, srv.URL, nil)
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if gotHeaders.Get("X-SLURM-USER-NAME") != "testuser" {
|
||||
t.Errorf("X-SLURM-USER-NAME = %q, want %q", gotHeaders.Get("X-SLURM-USER-NAME"), "testuser")
|
||||
}
|
||||
|
||||
tokenStr := gotHeaders.Get("X-SLURM-USER-TOKEN")
|
||||
if tokenStr == "" {
|
||||
t.Fatal("X-SLURM-USER-TOKEN is empty")
|
||||
}
|
||||
|
||||
parts := strings.Split(tokenStr, ".")
|
||||
if len(parts) != 3 {
|
||||
t.Fatalf("JWT should have 3 parts, got %d", len(parts))
|
||||
}
|
||||
|
||||
claimsJSON, err := base64.RawURLEncoding.DecodeString(parts[1])
|
||||
if err != nil {
|
||||
t.Fatalf("decode claims: %v", err)
|
||||
}
|
||||
|
||||
var claims JWTClaims
|
||||
if err := json.Unmarshal(claimsJSON, &claims); err != nil {
|
||||
t.Fatalf("unmarshal claims: %v", err)
|
||||
}
|
||||
|
||||
if claims.Sun != "testuser" {
|
||||
t.Errorf("sun claim = %q, want %q", claims.Sun, "testuser")
|
||||
}
|
||||
}
|
||||
|
||||
func TestJWTAuthTransport_AutoRefreshOnExpiry(t *testing.T) {
|
||||
var tokens []string
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
tokens = append(tokens, r.Header.Get("X-SLURM-USER-TOKEN"))
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
key := make([]byte, 32)
|
||||
for i := range key {
|
||||
key[i] = byte(i)
|
||||
}
|
||||
|
||||
tr := NewJWTAuthTransport("testuser", key,
|
||||
WithTTL(1*time.Millisecond),
|
||||
WithLeeway(0),
|
||||
)
|
||||
|
||||
var callCount int
|
||||
tr.tokenCache = NewTokenCache(
|
||||
func(ctx context.Context) (string, error) {
|
||||
callCount++
|
||||
return SignJWT(key, fmt.Sprintf("testuser-%d", callCount), 5*time.Minute)
|
||||
},
|
||||
1*time.Millisecond,
|
||||
0,
|
||||
)
|
||||
|
||||
client := tr.Client()
|
||||
|
||||
req, _ := http.NewRequestWithContext(context.Background(), http.MethodGet, srv.URL, nil)
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
t.Fatalf("first request: %v", err)
|
||||
}
|
||||
resp.Body.Close()
|
||||
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
|
||||
req2, _ := http.NewRequestWithContext(context.Background(), http.MethodGet, srv.URL, nil)
|
||||
resp2, err := client.Do(req2)
|
||||
if err != nil {
|
||||
t.Fatalf("second request: %v", err)
|
||||
}
|
||||
resp2.Body.Close()
|
||||
|
||||
if len(tokens) != 2 {
|
||||
t.Fatalf("expected 2 tokens, got %d", len(tokens))
|
||||
}
|
||||
|
||||
if tokens[0] == tokens[1] {
|
||||
t.Error("token should have been refreshed after expiry, but got same token")
|
||||
}
|
||||
|
||||
for i, tok := range tokens {
|
||||
parts := strings.Split(tok, ".")
|
||||
if len(parts) != 3 {
|
||||
t.Errorf("token[%d] should have 3 parts, got %d", i, len(parts))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestJWTAuthTransport_StaticTokenStillWorks(t *testing.T) {
|
||||
var gotHeaders http.Header
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
gotHeaders = r.Header.Clone()
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
tr := &TokenAuthTransport{
|
||||
UserName: "staticuser",
|
||||
Token: "static-token-123",
|
||||
}
|
||||
client := tr.Client()
|
||||
|
||||
req, _ := http.NewRequestWithContext(context.Background(), http.MethodGet, srv.URL, nil)
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if gotHeaders.Get("X-SLURM-USER-NAME") != "staticuser" {
|
||||
t.Errorf("X-SLURM-USER-NAME = %q, want %q", gotHeaders.Get("X-SLURM-USER-NAME"), "staticuser")
|
||||
}
|
||||
if gotHeaders.Get("X-SLURM-USER-TOKEN") != "static-token-123" {
|
||||
t.Errorf("X-SLURM-USER-TOKEN = %q, want %q", gotHeaders.Get("X-SLURM-USER-TOKEN"), "static-token-123")
|
||||
}
|
||||
}
|
||||
|
||||
func TestJWTAuthTransport_SigningError_ReturnsError(t *testing.T) {
|
||||
expectedErr := errors.New("signing failed")
|
||||
tr := NewJWTAuthTransport("testuser", nil)
|
||||
tr.tokenCache = NewTokenCache(
|
||||
func(ctx context.Context) (string, error) { return "", expectedErr },
|
||||
30*time.Minute,
|
||||
30*time.Second,
|
||||
)
|
||||
|
||||
req, _ := http.NewRequestWithContext(context.Background(), http.MethodGet, "http://localhost", nil)
|
||||
_, err := tr.RoundTrip(req)
|
||||
if err == nil {
|
||||
t.Fatal("expected error from RoundTrip, got nil")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "failed to get JWT token") {
|
||||
t.Errorf("error = %q, want containing %q", err.Error(), "failed to get JWT token")
|
||||
}
|
||||
}
|
||||
|
||||
func TestJWTAuthTransport_CustomBaseTransport(t *testing.T) {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("X-Custom-Response", "from-base")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
key := make([]byte, 32)
|
||||
for i := range key {
|
||||
key[i] = byte(i)
|
||||
}
|
||||
|
||||
var custom http.RoundTripper = roundTripperFunc(func(req *http.Request) (*http.Response, error) {
|
||||
req.Header.Set("X-Custom-Request", "injected")
|
||||
return http.DefaultTransport.RoundTrip(req)
|
||||
})
|
||||
|
||||
tr := NewJWTAuthTransport("testuser", key, WithBaseTransport(custom))
|
||||
client := tr.Client()
|
||||
|
||||
req, _ := http.NewRequestWithContext(context.Background(), http.MethodGet, srv.URL, nil)
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.Header.Get("X-Custom-Response") != "from-base" {
|
||||
t.Error("custom base transport was not used")
|
||||
}
|
||||
}
|
||||
|
||||
type roundTripperFunc func(*http.Request) (*http.Response, error)
|
||||
|
||||
func (f roundTripperFunc) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
return f(req)
|
||||
}
|
||||
Reference in New Issue
Block a user