package slurm import ( "context" "encoding/base64" "encoding/json" "errors" "fmt" "net/http" "net/http/httptest" "strings" "testing" "time" ) func TestJWTAuthTransport_SetsCorrectHeaders(t *testing.T) { var gotHeaders http.Header srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { gotHeaders = r.Header.Clone() w.WriteHeader(http.StatusOK) })) defer srv.Close() key := make([]byte, 32) for i := range key { key[i] = byte(i) } tr := NewJWTAuthTransport("testuser", key) client := tr.Client() req, _ := http.NewRequestWithContext(context.Background(), http.MethodGet, srv.URL, nil) resp, err := client.Do(req) if err != nil { t.Fatalf("unexpected error: %v", err) } defer resp.Body.Close() if gotHeaders.Get("X-SLURM-USER-NAME") != "testuser" { t.Errorf("X-SLURM-USER-NAME = %q, want %q", gotHeaders.Get("X-SLURM-USER-NAME"), "testuser") } tokenStr := gotHeaders.Get("X-SLURM-USER-TOKEN") if tokenStr == "" { t.Fatal("X-SLURM-USER-TOKEN is empty") } parts := strings.Split(tokenStr, ".") if len(parts) != 3 { t.Fatalf("JWT should have 3 parts, got %d", len(parts)) } claimsJSON, err := base64.RawURLEncoding.DecodeString(parts[1]) if err != nil { t.Fatalf("decode claims: %v", err) } var claims JWTClaims if err := json.Unmarshal(claimsJSON, &claims); err != nil { t.Fatalf("unmarshal claims: %v", err) } if claims.Sun != "testuser" { t.Errorf("sun claim = %q, want %q", claims.Sun, "testuser") } } func TestJWTAuthTransport_AutoRefreshOnExpiry(t *testing.T) { var tokens []string srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { tokens = append(tokens, r.Header.Get("X-SLURM-USER-TOKEN")) w.WriteHeader(http.StatusOK) })) defer srv.Close() key := make([]byte, 32) for i := range key { key[i] = byte(i) } tr := NewJWTAuthTransport("testuser", key, WithTTL(1*time.Millisecond), WithLeeway(0), ) var callCount int tr.tokenCache = NewTokenCache( func(ctx context.Context) (string, error) { callCount++ return SignJWT(key, fmt.Sprintf("testuser-%d", callCount), 5*time.Minute) }, 1*time.Millisecond, 0, ) client := tr.Client() req, _ := http.NewRequestWithContext(context.Background(), http.MethodGet, srv.URL, nil) resp, err := client.Do(req) if err != nil { t.Fatalf("first request: %v", err) } resp.Body.Close() time.Sleep(10 * time.Millisecond) req2, _ := http.NewRequestWithContext(context.Background(), http.MethodGet, srv.URL, nil) resp2, err := client.Do(req2) if err != nil { t.Fatalf("second request: %v", err) } resp2.Body.Close() if len(tokens) != 2 { t.Fatalf("expected 2 tokens, got %d", len(tokens)) } if tokens[0] == tokens[1] { t.Error("token should have been refreshed after expiry, but got same token") } for i, tok := range tokens { parts := strings.Split(tok, ".") if len(parts) != 3 { t.Errorf("token[%d] should have 3 parts, got %d", i, len(parts)) } } } func TestJWTAuthTransport_StaticTokenStillWorks(t *testing.T) { var gotHeaders http.Header srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { gotHeaders = r.Header.Clone() w.WriteHeader(http.StatusOK) })) defer srv.Close() tr := &TokenAuthTransport{ UserName: "staticuser", Token: "static-token-123", } client := tr.Client() req, _ := http.NewRequestWithContext(context.Background(), http.MethodGet, srv.URL, nil) resp, err := client.Do(req) if err != nil { t.Fatalf("unexpected error: %v", err) } defer resp.Body.Close() if gotHeaders.Get("X-SLURM-USER-NAME") != "staticuser" { t.Errorf("X-SLURM-USER-NAME = %q, want %q", gotHeaders.Get("X-SLURM-USER-NAME"), "staticuser") } if gotHeaders.Get("X-SLURM-USER-TOKEN") != "static-token-123" { t.Errorf("X-SLURM-USER-TOKEN = %q, want %q", gotHeaders.Get("X-SLURM-USER-TOKEN"), "static-token-123") } } func TestJWTAuthTransport_SigningError_ReturnsError(t *testing.T) { expectedErr := errors.New("signing failed") tr := NewJWTAuthTransport("testuser", nil) tr.tokenCache = NewTokenCache( func(ctx context.Context) (string, error) { return "", expectedErr }, 30*time.Minute, 30*time.Second, ) req, _ := http.NewRequestWithContext(context.Background(), http.MethodGet, "http://localhost", nil) _, err := tr.RoundTrip(req) if err == nil { t.Fatal("expected error from RoundTrip, got nil") } if !strings.Contains(err.Error(), "failed to get JWT token") { t.Errorf("error = %q, want containing %q", err.Error(), "failed to get JWT token") } } func TestJWTAuthTransport_CustomBaseTransport(t *testing.T) { srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Header().Set("X-Custom-Response", "from-base") w.WriteHeader(http.StatusOK) })) defer srv.Close() key := make([]byte, 32) for i := range key { key[i] = byte(i) } var custom http.RoundTripper = roundTripperFunc(func(req *http.Request) (*http.Response, error) { req.Header.Set("X-Custom-Request", "injected") return http.DefaultTransport.RoundTrip(req) }) tr := NewJWTAuthTransport("testuser", key, WithBaseTransport(custom)) client := tr.Client() req, _ := http.NewRequestWithContext(context.Background(), http.MethodGet, srv.URL, nil) resp, err := client.Do(req) if err != nil { t.Fatalf("unexpected error: %v", err) } defer resp.Body.Close() if resp.Header.Get("X-Custom-Response") != "from-base" { t.Error("custom base transport was not used") } } type roundTripperFunc func(*http.Request) (*http.Response, error) func (f roundTripperFunc) RoundTrip(req *http.Request) (*http.Response, error) { return f(req) }