diff --git a/internal/slurm/jwt_transport.go b/internal/slurm/jwt_transport.go new file mode 100644 index 0000000..54fb859 --- /dev/null +++ b/internal/slurm/jwt_transport.go @@ -0,0 +1,90 @@ +package slurm + +import ( + "context" + "fmt" + "net/http" + "time" +) + +type JWTTransportOption func(*JWTAuthTransport) + +type JWTAuthTransport struct { + UserName string + key []byte + + tokenCache *TokenCache + ttl time.Duration + leeway time.Duration + + Base http.RoundTripper +} + +func NewJWTAuthTransport(username string, key []byte, opts ...JWTTransportOption) *JWTAuthTransport { + const ( + defaultTTL = 30 * time.Minute + defaultLeeway = 30 * time.Second + ) + + t := &JWTAuthTransport{ + UserName: username, + key: key, + ttl: defaultTTL, + leeway: defaultLeeway, + } + + for _, opt := range opts { + opt(t) + } + + t.tokenCache = NewTokenCache( + func(ctx context.Context) (string, error) { + return SignJWT(t.key, t.UserName, t.ttl) + }, + t.ttl, + t.leeway, + ) + + return t +} + +func (t *JWTAuthTransport) RoundTrip(req *http.Request) (*http.Response, error) { + token, err := t.tokenCache.Token(req.Context()) + if err != nil { + return nil, fmt.Errorf("failed to get JWT token: %w", err) + } + + req2 := cloneRequest(req) + req2.Header.Set("X-SLURM-USER-NAME", t.UserName) + req2.Header.Set("X-SLURM-USER-TOKEN", token) + return t.transport().RoundTrip(req2) +} + +func (t *JWTAuthTransport) Client() *http.Client { + return &http.Client{Transport: t} +} + +func (t *JWTAuthTransport) transport() http.RoundTripper { + if t.Base != nil { + return t.Base + } + return http.DefaultTransport +} + +func WithTTL(ttl time.Duration) JWTTransportOption { + return func(t *JWTAuthTransport) { + t.ttl = ttl + } +} + +func WithLeeway(leeway time.Duration) JWTTransportOption { + return func(t *JWTAuthTransport) { + t.leeway = leeway + } +} + +func WithBaseTransport(base http.RoundTripper) JWTTransportOption { + return func(t *JWTAuthTransport) { + t.Base = base + } +} diff --git a/internal/slurm/jwt_transport_test.go b/internal/slurm/jwt_transport_test.go new file mode 100644 index 0000000..4750394 --- /dev/null +++ b/internal/slurm/jwt_transport_test.go @@ -0,0 +1,214 @@ +package slurm + +import ( + "context" + "encoding/base64" + "encoding/json" + "errors" + "fmt" + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" +) + +func TestJWTAuthTransport_SetsCorrectHeaders(t *testing.T) { + var gotHeaders http.Header + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + gotHeaders = r.Header.Clone() + w.WriteHeader(http.StatusOK) + })) + defer srv.Close() + + key := make([]byte, 32) + for i := range key { + key[i] = byte(i) + } + + tr := NewJWTAuthTransport("testuser", key) + client := tr.Client() + + req, _ := http.NewRequestWithContext(context.Background(), http.MethodGet, srv.URL, nil) + resp, err := client.Do(req) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + defer resp.Body.Close() + + if gotHeaders.Get("X-SLURM-USER-NAME") != "testuser" { + t.Errorf("X-SLURM-USER-NAME = %q, want %q", gotHeaders.Get("X-SLURM-USER-NAME"), "testuser") + } + + tokenStr := gotHeaders.Get("X-SLURM-USER-TOKEN") + if tokenStr == "" { + t.Fatal("X-SLURM-USER-TOKEN is empty") + } + + parts := strings.Split(tokenStr, ".") + if len(parts) != 3 { + t.Fatalf("JWT should have 3 parts, got %d", len(parts)) + } + + claimsJSON, err := base64.RawURLEncoding.DecodeString(parts[1]) + if err != nil { + t.Fatalf("decode claims: %v", err) + } + + var claims JWTClaims + if err := json.Unmarshal(claimsJSON, &claims); err != nil { + t.Fatalf("unmarshal claims: %v", err) + } + + if claims.Sun != "testuser" { + t.Errorf("sun claim = %q, want %q", claims.Sun, "testuser") + } +} + +func TestJWTAuthTransport_AutoRefreshOnExpiry(t *testing.T) { + var tokens []string + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + tokens = append(tokens, r.Header.Get("X-SLURM-USER-TOKEN")) + w.WriteHeader(http.StatusOK) + })) + defer srv.Close() + + key := make([]byte, 32) + for i := range key { + key[i] = byte(i) + } + + tr := NewJWTAuthTransport("testuser", key, + WithTTL(1*time.Millisecond), + WithLeeway(0), + ) + + var callCount int + tr.tokenCache = NewTokenCache( + func(ctx context.Context) (string, error) { + callCount++ + return SignJWT(key, fmt.Sprintf("testuser-%d", callCount), 5*time.Minute) + }, + 1*time.Millisecond, + 0, + ) + + client := tr.Client() + + req, _ := http.NewRequestWithContext(context.Background(), http.MethodGet, srv.URL, nil) + resp, err := client.Do(req) + if err != nil { + t.Fatalf("first request: %v", err) + } + resp.Body.Close() + + time.Sleep(10 * time.Millisecond) + + req2, _ := http.NewRequestWithContext(context.Background(), http.MethodGet, srv.URL, nil) + resp2, err := client.Do(req2) + if err != nil { + t.Fatalf("second request: %v", err) + } + resp2.Body.Close() + + if len(tokens) != 2 { + t.Fatalf("expected 2 tokens, got %d", len(tokens)) + } + + if tokens[0] == tokens[1] { + t.Error("token should have been refreshed after expiry, but got same token") + } + + for i, tok := range tokens { + parts := strings.Split(tok, ".") + if len(parts) != 3 { + t.Errorf("token[%d] should have 3 parts, got %d", i, len(parts)) + } + } +} + +func TestJWTAuthTransport_StaticTokenStillWorks(t *testing.T) { + var gotHeaders http.Header + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + gotHeaders = r.Header.Clone() + w.WriteHeader(http.StatusOK) + })) + defer srv.Close() + + tr := &TokenAuthTransport{ + UserName: "staticuser", + Token: "static-token-123", + } + client := tr.Client() + + req, _ := http.NewRequestWithContext(context.Background(), http.MethodGet, srv.URL, nil) + resp, err := client.Do(req) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + defer resp.Body.Close() + + if gotHeaders.Get("X-SLURM-USER-NAME") != "staticuser" { + t.Errorf("X-SLURM-USER-NAME = %q, want %q", gotHeaders.Get("X-SLURM-USER-NAME"), "staticuser") + } + if gotHeaders.Get("X-SLURM-USER-TOKEN") != "static-token-123" { + t.Errorf("X-SLURM-USER-TOKEN = %q, want %q", gotHeaders.Get("X-SLURM-USER-TOKEN"), "static-token-123") + } +} + +func TestJWTAuthTransport_SigningError_ReturnsError(t *testing.T) { + expectedErr := errors.New("signing failed") + tr := NewJWTAuthTransport("testuser", nil) + tr.tokenCache = NewTokenCache( + func(ctx context.Context) (string, error) { return "", expectedErr }, + 30*time.Minute, + 30*time.Second, + ) + + req, _ := http.NewRequestWithContext(context.Background(), http.MethodGet, "http://localhost", nil) + _, err := tr.RoundTrip(req) + if err == nil { + t.Fatal("expected error from RoundTrip, got nil") + } + if !strings.Contains(err.Error(), "failed to get JWT token") { + t.Errorf("error = %q, want containing %q", err.Error(), "failed to get JWT token") + } +} + +func TestJWTAuthTransport_CustomBaseTransport(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("X-Custom-Response", "from-base") + w.WriteHeader(http.StatusOK) + })) + defer srv.Close() + + key := make([]byte, 32) + for i := range key { + key[i] = byte(i) + } + + var custom http.RoundTripper = roundTripperFunc(func(req *http.Request) (*http.Response, error) { + req.Header.Set("X-Custom-Request", "injected") + return http.DefaultTransport.RoundTrip(req) + }) + + tr := NewJWTAuthTransport("testuser", key, WithBaseTransport(custom)) + client := tr.Client() + + req, _ := http.NewRequestWithContext(context.Background(), http.MethodGet, srv.URL, nil) + resp, err := client.Do(req) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + defer resp.Body.Close() + + if resp.Header.Get("X-Custom-Response") != "from-base" { + t.Error("custom base transport was not used") + } +} + +type roundTripperFunc func(*http.Request) (*http.Response, error) + +func (f roundTripperFunc) RoundTrip(req *http.Request) (*http.Response, error) { + return f(req) +}