Files
hpc/internal/slurm/jwt_transport.go
2026-04-09 10:31:39 +08:00

91 lines
1.7 KiB
Go

package slurm
import (
"context"
"fmt"
"net/http"
"time"
)
// JWTTransportOption configures a JWTAuthTransport. NewJWTAuthTransport
// applies options in the order given, after filling in defaults and before
// the token cache is created, so options see and may override the defaults.
type JWTTransportOption func(*JWTAuthTransport)
// JWTAuthTransport is an http.RoundTripper that adds the Slurm user-name and
// JWT headers (X-SLURM-USER-NAME / X-SLURM-USER-TOKEN) to each request before
// delegating to Base.
//
// Construct it with NewJWTAuthTransport, which initializes the unexported
// fields (in particular the token cache).
type JWTAuthTransport struct {
	// UserName is sent as X-SLURM-USER-NAME and passed to SignJWT whenever a
	// new token is signed.
	UserName string
	// key is the signing key handed to SignJWT.
	key []byte
	// tokenCache supplies signed tokens to RoundTrip; built by
	// NewJWTAuthTransport.
	tokenCache *TokenCache
	// ttl is the token lifetime given to both SignJWT and NewTokenCache.
	ttl time.Duration
	// leeway is passed to NewTokenCache; presumably how long before expiry a
	// cached token is refreshed. NOTE(review): confirm against TokenCache.
	leeway time.Duration
	// Base is the underlying transport; nil means http.DefaultTransport.
	Base http.RoundTripper
}
// NewJWTAuthTransport returns a transport that authenticates every request as
// username, signing tokens with key.
//
// Defaults: tokens live 30 minutes and the token cache is given a 30-second
// leeway. Adjust them with WithTTL / WithLeeway, and choose the underlying
// transport with WithBaseTransport. Options are applied in order, and nil
// options are skipped. A non-positive TTL or a negative leeway falls back to
// the corresponding default, since neither describes a usable token lifetime.
//
// key is copied, so the caller may reuse or wipe its slice afterwards.
func NewJWTAuthTransport(username string, key []byte, opts ...JWTTransportOption) *JWTAuthTransport {
	const (
		defaultTTL    = 30 * time.Minute
		defaultLeeway = 30 * time.Second
	)
	// Defensive copy of key: the signer closure below reads t.key every time
	// the cache invokes it, long after this call has returned.
	t := &JWTAuthTransport{
		UserName: username,
		key:      append([]byte(nil), key...),
		ttl:      defaultTTL,
		leeway:   defaultLeeway,
	}
	for _, opt := range opts {
		if opt != nil {
			opt(t)
		}
	}
	if t.ttl <= 0 {
		t.ttl = defaultTTL
	}
	if t.leeway < 0 {
		t.leeway = defaultLeeway
	}
	t.tokenCache = NewTokenCache(
		// ctx is unused: SignJWT takes no context and signs locally.
		func(ctx context.Context) (string, error) {
			return SignJWT(t.key, t.UserName, t.ttl)
		},
		t.ttl,
		t.leeway,
	)
	return t
}
// RoundTrip implements http.RoundTripper. It obtains a signed token from the
// token cache (using the request's context). It then sends a clone of req,
// with X-SLURM-USER-NAME and X-SLURM-USER-TOKEN set, through the base
// transport. req itself is never modified.
//
// As the http.RoundTripper contract requires, req.Body is closed on every
// error path that occurs before the request is handed to the base transport.
// After that hand-off, the base transport is responsible for it.
func (t *JWTAuthTransport) RoundTrip(req *http.Request) (*http.Response, error) {
	if t.tokenCache == nil {
		// Zero-value transport (not built via NewJWTAuthTransport).
		if req.Body != nil {
			_ = req.Body.Close()
		}
		return nil, fmt.Errorf("failed to get JWT token: JWTAuthTransport not initialized (use NewJWTAuthTransport)")
	}
	token, err := t.tokenCache.Token(req.Context())
	if err != nil {
		if req.Body != nil {
			// A close error is secondary to err, so it is dropped.
			_ = req.Body.Close()
		}
		return nil, fmt.Errorf("failed to get JWT token: %w", err)
	}
	// Clone deep-copies the header map, so the Sets below cannot leak into the
	// caller's request. Clone keeps a nil Header nil, hence the guard.
	out := req.Clone(req.Context())
	if out.Header == nil {
		out.Header = make(http.Header)
	}
	out.Header.Set("X-SLURM-USER-NAME", t.UserName)
	out.Header.Set("X-SLURM-USER-TOKEN", token)
	return t.transport().RoundTrip(out)
}
// Client builds a fresh *http.Client that routes every request through t, so
// each request carries the Slurm JWT headers. The returned client has no
// Timeout set; bound requests with a context deadline or set Timeout on it.
func (t *JWTAuthTransport) Client() *http.Client {
	c := new(http.Client)
	c.Transport = t
	return c
}
// transport picks the RoundTripper that actually sends requests: Base when it
// is set, otherwise http.DefaultTransport.
func (t *JWTAuthTransport) transport() http.RoundTripper {
	rt := t.Base
	if rt == nil {
		rt = http.DefaultTransport
	}
	return rt
}
// WithTTL overrides the token lifetime (default 30 minutes). The value is
// given both to SignJWT and to the token cache.
func WithTTL(ttl time.Duration) JWTTransportOption {
	setTTL := func(tr *JWTAuthTransport) {
		tr.ttl = ttl
	}
	return setTTL
}
// WithLeeway overrides the leeway handed to the token cache (default 30
// seconds). NOTE(review): presumably the refresh margin before token expiry;
// confirm against TokenCache.
func WithLeeway(leeway time.Duration) JWTTransportOption {
	return JWTTransportOption(func(tr *JWTAuthTransport) { tr.leeway = leeway })
}
// WithBaseTransport sets the RoundTripper used to send authenticated
// requests. Passing nil leaves the transport falling back to
// http.DefaultTransport.
func WithBaseTransport(base http.RoundTripper) JWTTransportOption {
	setBase := func(tr *JWTAuthTransport) {
		tr.Base = base
	}
	return setBase
}