feat(auth): add JWTAuthTransport with auto-refresh RoundTrip
Ultraworked with [Sisyphus](https://github.com/code-yeongyu/oh-my-openagent) Co-authored-by: Sisyphus <clio-agent@sisyphuslabs.ai>
This commit is contained in:
90
internal/slurm/jwt_transport.go
Normal file
90
internal/slurm/jwt_transport.go
Normal file
@@ -0,0 +1,90 @@
|
||||
package slurm
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"time"
|
||||
)
|
||||
|
||||
type JWTTransportOption func(*JWTAuthTransport)
|
||||
|
||||
type JWTAuthTransport struct {
|
||||
UserName string
|
||||
key []byte
|
||||
|
||||
tokenCache *TokenCache
|
||||
ttl time.Duration
|
||||
leeway time.Duration
|
||||
|
||||
Base http.RoundTripper
|
||||
}
|
||||
|
||||
func NewJWTAuthTransport(username string, key []byte, opts ...JWTTransportOption) *JWTAuthTransport {
|
||||
const (
|
||||
defaultTTL = 30 * time.Minute
|
||||
defaultLeeway = 30 * time.Second
|
||||
)
|
||||
|
||||
t := &JWTAuthTransport{
|
||||
UserName: username,
|
||||
key: key,
|
||||
ttl: defaultTTL,
|
||||
leeway: defaultLeeway,
|
||||
}
|
||||
|
||||
for _, opt := range opts {
|
||||
opt(t)
|
||||
}
|
||||
|
||||
t.tokenCache = NewTokenCache(
|
||||
func(ctx context.Context) (string, error) {
|
||||
return SignJWT(t.key, t.UserName, t.ttl)
|
||||
},
|
||||
t.ttl,
|
||||
t.leeway,
|
||||
)
|
||||
|
||||
return t
|
||||
}
|
||||
|
||||
func (t *JWTAuthTransport) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
token, err := t.tokenCache.Token(req.Context())
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get JWT token: %w", err)
|
||||
}
|
||||
|
||||
req2 := cloneRequest(req)
|
||||
req2.Header.Set("X-SLURM-USER-NAME", t.UserName)
|
||||
req2.Header.Set("X-SLURM-USER-TOKEN", token)
|
||||
return t.transport().RoundTrip(req2)
|
||||
}
|
||||
|
||||
func (t *JWTAuthTransport) Client() *http.Client {
|
||||
return &http.Client{Transport: t}
|
||||
}
|
||||
|
||||
func (t *JWTAuthTransport) transport() http.RoundTripper {
|
||||
if t.Base != nil {
|
||||
return t.Base
|
||||
}
|
||||
return http.DefaultTransport
|
||||
}
|
||||
|
||||
func WithTTL(ttl time.Duration) JWTTransportOption {
|
||||
return func(t *JWTAuthTransport) {
|
||||
t.ttl = ttl
|
||||
}
|
||||
}
|
||||
|
||||
func WithLeeway(leeway time.Duration) JWTTransportOption {
|
||||
return func(t *JWTAuthTransport) {
|
||||
t.leeway = leeway
|
||||
}
|
||||
}
|
||||
|
||||
func WithBaseTransport(base http.RoundTripper) JWTTransportOption {
|
||||
return func(t *JWTAuthTransport) {
|
||||
t.Base = base
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user