package slurm

import (
	"context"
	"fmt"
	"net/http"
	"time"
)

// Compile-time check that JWTAuthTransport satisfies http.RoundTripper.
var _ http.RoundTripper = (*JWTAuthTransport)(nil)

// JWTTransportOption customizes a JWTAuthTransport while it is being built by
// NewJWTAuthTransport.
type JWTTransportOption func(*JWTAuthTransport)

// JWTAuthTransport is an http.RoundTripper that attaches the
// X-SLURM-USER-NAME and X-SLURM-USER-TOKEN headers to every outgoing request
// and then delegates to an underlying transport.
//
// The zero value has no token cache configured; create instances with
// NewJWTAuthTransport.
type JWTAuthTransport struct {
	// UserName is sent in the X-SLURM-USER-NAME header and is passed to
	// SignJWT whenever a token is produced.
	UserName string

	// key is handed to SignJWT (presumably the JWT signing key — confirm
	// against SignJWT).
	key []byte

	// tokenCache supplies the token used for each request.
	tokenCache *TokenCache

	// ttl is passed to both SignJWT and NewTokenCache.
	ttl time.Duration

	// leeway is passed to NewTokenCache (presumably how long before expiry a
	// token is considered stale — confirm against TokenCache).
	leeway time.Duration

	// Base is the transport that actually performs the request. When nil,
	// http.DefaultTransport is used.
	Base http.RoundTripper
}

// NewJWTAuthTransport returns a JWTAuthTransport for username that produces
// tokens via SignJWT using key.
//
// Unless overridden by opts, the TTL is 30 minutes and the leeway is 30
// seconds. Options are applied before the token cache is created, so WithTTL
// and WithLeeway affect the cache configuration.
func NewJWTAuthTransport(username string, key []byte, opts ...JWTTransportOption) *JWTAuthTransport {
	const (
		defaultTTL    = 30 * time.Minute
		defaultLeeway = 30 * time.Second
	)

	t := &JWTAuthTransport{
		UserName: username,
		key:      key,
		ttl:      defaultTTL,
		leeway:   defaultLeeway,
	}
	for _, opt := range opts {
		opt(t)
	}

	// The cache is built last so that it sees the final ttl/leeway values.
	// The signing callback reads fields through t at call time.
	t.tokenCache = NewTokenCache(
		func(ctx context.Context) (string, error) {
			return SignJWT(t.key, t.UserName, t.ttl)
		},
		t.ttl,
		t.leeway,
	)
	return t
}

// RoundTrip implements http.RoundTripper. It obtains a token from the token
// cache, sets the Slurm authentication headers on a clone of req, and forwards
// the clone to the underlying transport.
//
// If no token can be obtained, the request body (if any) is closed and a
// wrapped error is returned without contacting the underlying transport.
func (t *JWTAuthTransport) RoundTrip(req *http.Request) (*http.Response, error) {
	token, err := t.tokenCache.Token(req.Context())
	if err != nil {
		// http.RoundTripper requires that the request body is always
		// closed, including on errors. The base transport never sees this
		// request, so the body must be closed here.
		if req.Body != nil {
			// The close error is intentionally dropped: the token error is
			// the one the caller needs to see.
			_ = req.Body.Close()
		}
		return nil, fmt.Errorf("failed to get JWT token: %w", err)
	}

	// Headers are set on the clone rather than on req itself.
	req2 := cloneRequest(req)
	req2.Header.Set("X-SLURM-USER-NAME", t.UserName)
	req2.Header.Set("X-SLURM-USER-TOKEN", token)
	return t.transport().RoundTrip(req2)
}

// Client returns a new http.Client that uses t as its Transport.
func (t *JWTAuthTransport) Client() *http.Client {
	return &http.Client{Transport: t}
}

// transport returns Base when it is set and http.DefaultTransport otherwise.
func (t *JWTAuthTransport) transport() http.RoundTripper {
	if t.Base != nil {
		return t.Base
	}
	return http.DefaultTransport
}

// WithTTL overrides the default TTL that is passed to SignJWT and to the
// token cache.
func WithTTL(ttl time.Duration) JWTTransportOption {
	return func(t *JWTAuthTransport) {
		t.ttl = ttl
	}
}

// WithLeeway overrides the default leeway that is passed to the token cache.
func WithLeeway(leeway time.Duration) JWTTransportOption {
	return func(t *JWTAuthTransport) {
		t.leeway = leeway
	}
}

// WithBaseTransport sets the underlying http.RoundTripper used to perform
// requests. A nil base leaves http.DefaultTransport in effect.
func WithBaseTransport(base http.RoundTripper) JWTTransportOption {
	return func(t *JWTAuthTransport) {
		t.Base = base
	}
}