feat(jwt): add HS256 JWT signing with stdlib-only implementation
Ultraworked with [Sisyphus](https://github.com/code-yeongyu/oh-my-openagent) Co-authored-by: Sisyphus <clio-agent@sisyphuslabs.ai>
This commit is contained in:
62
internal/slurm/jwt.go
Normal file
62
internal/slurm/jwt.go
Normal file
@@ -0,0 +1,62 @@
|
|||||||
|
package slurm
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/hmac"
|
||||||
|
"crypto/sha256"
|
||||||
|
"encoding/base64"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
type JWTClaims struct {
|
||||||
|
Sun string `json:"sun"`
|
||||||
|
IAT int64 `json:"iat"`
|
||||||
|
EXP int64 `json:"exp"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func SignJWT(key []byte, username string, lifespan time.Duration) (string, error) {
|
||||||
|
if username == "" {
|
||||||
|
return "", fmt.Errorf("username must not be empty")
|
||||||
|
}
|
||||||
|
|
||||||
|
now := time.Now()
|
||||||
|
header := map[string]string{"alg": "HS256", "typ": "JWT"}
|
||||||
|
claims := JWTClaims{
|
||||||
|
Sun: username,
|
||||||
|
IAT: now.Unix(),
|
||||||
|
EXP: now.Add(lifespan).Unix(),
|
||||||
|
}
|
||||||
|
|
||||||
|
headerJSON, err := json.Marshal(header)
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("marshal header: %w", err)
|
||||||
|
}
|
||||||
|
claimsJSON, err := json.Marshal(claims)
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("marshal claims: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
enc := base64.RawURLEncoding
|
||||||
|
headerEnc := enc.EncodeToString(headerJSON)
|
||||||
|
claimsEnc := enc.EncodeToString(claimsJSON)
|
||||||
|
|
||||||
|
signingInput := headerEnc + "." + claimsEnc
|
||||||
|
mac := hmac.New(sha256.New, key)
|
||||||
|
mac.Write([]byte(signingInput))
|
||||||
|
sig := enc.EncodeToString(mac.Sum(nil))
|
||||||
|
|
||||||
|
return signingInput + "." + sig, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func ReadJWTKey(path string) ([]byte, error) {
|
||||||
|
data, err := os.ReadFile(path)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if len(data) < 16 {
|
||||||
|
return nil, fmt.Errorf("key must be at least 16 bytes, got %d", len(data))
|
||||||
|
}
|
||||||
|
return data, nil
|
||||||
|
}
|
||||||
169
internal/slurm/jwt_test.go
Normal file
169
internal/slurm/jwt_test.go
Normal file
@@ -0,0 +1,169 @@
|
|||||||
|
package slurm
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/hmac"
|
||||||
|
"crypto/sha256"
|
||||||
|
"encoding/base64"
|
||||||
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
|
"os"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestSignJWT_ValidKey_ProducesValidToken(t *testing.T) {
|
||||||
|
key := []byte("0123456789abcdef0123456789abcdef") // 32 bytes
|
||||||
|
username := "testuser"
|
||||||
|
lifespan := 1 * time.Hour
|
||||||
|
|
||||||
|
token, err := SignJWT(key, username, lifespan)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("SignJWT returned error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
parts := strings.Split(token, ".")
|
||||||
|
if len(parts) != 3 {
|
||||||
|
t.Fatalf("expected 3 parts, got %d", len(parts))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Decode header
|
||||||
|
headerBytes, err := base64.RawURLEncoding.DecodeString(parts[0])
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("decode header: %v", err)
|
||||||
|
}
|
||||||
|
var header map[string]string
|
||||||
|
if err := json.Unmarshal(headerBytes, &header); err != nil {
|
||||||
|
t.Fatalf("unmarshal header: %v", err)
|
||||||
|
}
|
||||||
|
if header["alg"] != "HS256" || header["typ"] != "JWT" {
|
||||||
|
t.Errorf("header = %v, want alg=HS256 typ=JWT", header)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Decode payload
|
||||||
|
payloadBytes, err := base64.RawURLEncoding.DecodeString(parts[1])
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("decode payload: %v", err)
|
||||||
|
}
|
||||||
|
var claims JWTClaims
|
||||||
|
if err := json.Unmarshal(payloadBytes, &claims); err != nil {
|
||||||
|
t.Fatalf("unmarshal payload: %v", err)
|
||||||
|
}
|
||||||
|
if claims.Sun != username {
|
||||||
|
t.Errorf("sun = %q, want %q", claims.Sun, username)
|
||||||
|
}
|
||||||
|
if claims.EXP-claims.IAT != 3600 {
|
||||||
|
t.Errorf("exp - iat = %d, want 3600", claims.EXP-claims.IAT)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify signature
|
||||||
|
signingInput := parts[0] + "." + parts[1]
|
||||||
|
mac := hmac.New(sha256.New, key)
|
||||||
|
mac.Write([]byte(signingInput))
|
||||||
|
expectedSig := base64.RawURLEncoding.EncodeToString(mac.Sum(nil))
|
||||||
|
if parts[2] != expectedSig {
|
||||||
|
t.Errorf("signature mismatch: got %q, want %q", parts[2], expectedSig)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSignJWT_UsesRawURLEncoding(t *testing.T) {
|
||||||
|
key := []byte("0123456789abcdef0123456789abcdef")
|
||||||
|
token, err := SignJWT(key, "testuser", 1*time.Hour)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("SignJWT returned error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
for i, part := range strings.Split(token, ".") {
|
||||||
|
if strings.Contains(part, "=") {
|
||||||
|
t.Errorf("part %d contains padding '=': %q", i, part)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSignJWT_ExpiredToken(t *testing.T) {
|
||||||
|
key := []byte("0123456789abcdef0123456789abcdef")
|
||||||
|
token, err := SignJWT(key, "testuser", -1*time.Hour)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("SignJWT returned error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
parts := strings.Split(token, ".")
|
||||||
|
payloadBytes, err := base64.RawURLEncoding.DecodeString(parts[1])
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("decode payload: %v", err)
|
||||||
|
}
|
||||||
|
var claims JWTClaims
|
||||||
|
if err := json.Unmarshal(payloadBytes, &claims); err != nil {
|
||||||
|
t.Fatalf("unmarshal payload: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
now := time.Now().Unix()
|
||||||
|
if claims.EXP >= now {
|
||||||
|
t.Errorf("exp = %d, expected < now (%d)", claims.EXP, now)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSignJWT_EmptyUsername_Error(t *testing.T) {
|
||||||
|
key := []byte("0123456789abcdef0123456789abcdef")
|
||||||
|
_, err := SignJWT(key, "", 1*time.Hour)
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected error for empty username, got nil")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestReadJWTKey_ValidFile(t *testing.T) {
|
||||||
|
keyData := make([]byte, 32)
|
||||||
|
for i := range keyData {
|
||||||
|
keyData[i] = byte(i)
|
||||||
|
}
|
||||||
|
|
||||||
|
f, err := os.CreateTemp("", "jwtkey-*")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
defer os.Remove(f.Name())
|
||||||
|
|
||||||
|
if _, err := f.Write(keyData); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
f.Close()
|
||||||
|
|
||||||
|
got, err := ReadJWTKey(f.Name())
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("ReadJWTKey returned error: %v", err)
|
||||||
|
}
|
||||||
|
if string(got) != string(keyData) {
|
||||||
|
t.Errorf("key mismatch: got %d bytes, want %d bytes", len(got), len(keyData))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestReadJWTKey_FileNotFound(t *testing.T) {
|
||||||
|
_, err := ReadJWTKey("/nonexistent/path/keyfile")
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected error for nonexistent file, got nil")
|
||||||
|
}
|
||||||
|
if !errors.Is(err, os.ErrNotExist) {
|
||||||
|
t.Errorf("error = %v, want os.ErrNotExist", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestReadJWTKey_InvalidSize(t *testing.T) {
|
||||||
|
f, err := os.CreateTemp("", "jwtkey-*")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
defer os.Remove(f.Name())
|
||||||
|
|
||||||
|
if _, err := f.Write([]byte("short")); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
f.Close()
|
||||||
|
|
||||||
|
_, err = ReadJWTKey(f.Name())
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected error for short key, got nil")
|
||||||
|
}
|
||||||
|
if !strings.Contains(err.Error(), "key must be") {
|
||||||
|
t.Errorf("error = %q, want message containing 'key must be'", err.Error())
|
||||||
|
}
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user