diff --git a/internal/slurm/jwt.go b/internal/slurm/jwt.go new file mode 100644 index 0000000..6ad9147 --- /dev/null +++ b/internal/slurm/jwt.go @@ -0,0 +1,62 @@ +package slurm + +import ( + "crypto/hmac" + "crypto/sha256" + "encoding/base64" + "encoding/json" + "fmt" + "os" + "time" +) + +type JWTClaims struct { + Sun string `json:"sun"` + IAT int64 `json:"iat"` + EXP int64 `json:"exp"` +} + +func SignJWT(key []byte, username string, lifespan time.Duration) (string, error) { + if username == "" { + return "", fmt.Errorf("username must not be empty") + } + + now := time.Now() + header := map[string]string{"alg": "HS256", "typ": "JWT"} + claims := JWTClaims{ + Sun: username, + IAT: now.Unix(), + EXP: now.Add(lifespan).Unix(), + } + + headerJSON, err := json.Marshal(header) + if err != nil { + return "", fmt.Errorf("marshal header: %w", err) + } + claimsJSON, err := json.Marshal(claims) + if err != nil { + return "", fmt.Errorf("marshal claims: %w", err) + } + + enc := base64.RawURLEncoding + headerEnc := enc.EncodeToString(headerJSON) + claimsEnc := enc.EncodeToString(claimsJSON) + + signingInput := headerEnc + "." + claimsEnc + mac := hmac.New(sha256.New, key) + mac.Write([]byte(signingInput)) + sig := enc.EncodeToString(mac.Sum(nil)) + + return signingInput + "." + sig, nil +} + +func ReadJWTKey(path string) ([]byte, error) { + data, err := os.ReadFile(path) + if err != nil { + return nil, err + } + if len(data) < 16 { + return nil, fmt.Errorf("key must be at least 16 bytes, got %d", len(data)) + } + return data, nil +} diff --git a/internal/slurm/jwt_test.go b/internal/slurm/jwt_test.go new file mode 100644 index 0000000..4b72323 --- /dev/null +++ b/internal/slurm/jwt_test.go @@ -0,0 +1,169 @@ +package slurm + +import ( + "crypto/hmac" + "crypto/sha256" + "encoding/base64" + "encoding/json" + "errors" + "os" + "strings" + "testing" + "time" +) + +func TestSignJWT_ValidKey_ProducesValidToken(t *testing.T) { + key := []byte("0123456789abcdef0123456789abcdef") // 32 bytes + username := "testuser" + lifespan := 1 * time.Hour + + token, err := SignJWT(key, username, lifespan) + if err != nil { + t.Fatalf("SignJWT returned error: %v", err) + } + + parts := strings.Split(token, ".") + if len(parts) != 3 { + t.Fatalf("expected 3 parts, got %d", len(parts)) + } + + // Decode header + headerBytes, err := base64.RawURLEncoding.DecodeString(parts[0]) + if err != nil { + t.Fatalf("decode header: %v", err) + } + var header map[string]string + if err := json.Unmarshal(headerBytes, &header); err != nil { + t.Fatalf("unmarshal header: %v", err) + } + if header["alg"] != "HS256" || header["typ"] != "JWT" { + t.Errorf("header = %v, want alg=HS256 typ=JWT", header) + } + + // Decode payload + payloadBytes, err := base64.RawURLEncoding.DecodeString(parts[1]) + if err != nil { + t.Fatalf("decode payload: %v", err) + } + var claims JWTClaims + if err := json.Unmarshal(payloadBytes, &claims); err != nil { + t.Fatalf("unmarshal payload: %v", err) + } + if claims.Sun != username { + t.Errorf("sun = %q, want %q", claims.Sun, username) + } + if claims.EXP-claims.IAT != 3600 { + t.Errorf("exp - iat = %d, want 3600", claims.EXP-claims.IAT) + } + + // Verify signature + signingInput := parts[0] + "." + parts[1] + mac := hmac.New(sha256.New, key) + mac.Write([]byte(signingInput)) + expectedSig := base64.RawURLEncoding.EncodeToString(mac.Sum(nil)) + if parts[2] != expectedSig { + t.Errorf("signature mismatch: got %q, want %q", parts[2], expectedSig) + } +} + +func TestSignJWT_UsesRawURLEncoding(t *testing.T) { + key := []byte("0123456789abcdef0123456789abcdef") + token, err := SignJWT(key, "testuser", 1*time.Hour) + if err != nil { + t.Fatalf("SignJWT returned error: %v", err) + } + + for i, part := range strings.Split(token, ".") { + if strings.Contains(part, "=") { + t.Errorf("part %d contains padding '=': %q", i, part) + } + } +} + +func TestSignJWT_ExpiredToken(t *testing.T) { + key := []byte("0123456789abcdef0123456789abcdef") + token, err := SignJWT(key, "testuser", -1*time.Hour) + if err != nil { + t.Fatalf("SignJWT returned error: %v", err) + } + + parts := strings.Split(token, ".") + payloadBytes, err := base64.RawURLEncoding.DecodeString(parts[1]) + if err != nil { + t.Fatalf("decode payload: %v", err) + } + var claims JWTClaims + if err := json.Unmarshal(payloadBytes, &claims); err != nil { + t.Fatalf("unmarshal payload: %v", err) + } + + now := time.Now().Unix() + if claims.EXP >= now { + t.Errorf("exp = %d, expected < now (%d)", claims.EXP, now) + } +} + +func TestSignJWT_EmptyUsername_Error(t *testing.T) { + key := []byte("0123456789abcdef0123456789abcdef") + _, err := SignJWT(key, "", 1*time.Hour) + if err == nil { + t.Fatal("expected error for empty username, got nil") + } +} + +func TestReadJWTKey_ValidFile(t *testing.T) { + keyData := make([]byte, 32) + for i := range keyData { + keyData[i] = byte(i) + } + + f, err := os.CreateTemp("", "jwtkey-*") + if err != nil { + t.Fatal(err) + } + defer os.Remove(f.Name()) + + if _, err := f.Write(keyData); err != nil { + t.Fatal(err) + } + f.Close() + + got, err := ReadJWTKey(f.Name()) + if err != nil { + t.Fatalf("ReadJWTKey returned error: %v", err) + } + if string(got) != string(keyData) { + t.Errorf("key mismatch: got %d bytes, want %d bytes", len(got), len(keyData)) + } +} + +func TestReadJWTKey_FileNotFound(t *testing.T) { + _, err := ReadJWTKey("/nonexistent/path/keyfile") + if err == nil { + t.Fatal("expected error for nonexistent file, got nil") + } + if !errors.Is(err, os.ErrNotExist) { + t.Errorf("error = %v, want os.ErrNotExist", err) + } +} + +func TestReadJWTKey_InvalidSize(t *testing.T) { + f, err := os.CreateTemp("", "jwtkey-*") + if err != nil { + t.Fatal(err) + } + defer os.Remove(f.Name()) + + if _, err := f.Write([]byte("short")); err != nil { + t.Fatal(err) + } + f.Close() + + _, err = ReadJWTKey(f.Name()) + if err == nil { + t.Fatal("expected error for short key, got nil") + } + if !strings.Contains(err.Error(), "key must be") { + t.Errorf("error = %q, want message containing 'key must be'", err.Error()) + } +}