package slurm import ( "crypto/hmac" "crypto/sha256" "encoding/base64" "encoding/json" "errors" "os" "strings" "testing" "time" ) func TestSignJWT_ValidKey_ProducesValidToken(t *testing.T) { key := []byte("0123456789abcdef0123456789abcdef") // 32 bytes username := "testuser" lifespan := 1 * time.Hour token, err := SignJWT(key, username, lifespan) if err != nil { t.Fatalf("SignJWT returned error: %v", err) } parts := strings.Split(token, ".") if len(parts) != 3 { t.Fatalf("expected 3 parts, got %d", len(parts)) } // Decode header headerBytes, err := base64.RawURLEncoding.DecodeString(parts[0]) if err != nil { t.Fatalf("decode header: %v", err) } var header map[string]string if err := json.Unmarshal(headerBytes, &header); err != nil { t.Fatalf("unmarshal header: %v", err) } if header["alg"] != "HS256" || header["typ"] != "JWT" { t.Errorf("header = %v, want alg=HS256 typ=JWT", header) } // Decode payload payloadBytes, err := base64.RawURLEncoding.DecodeString(parts[1]) if err != nil { t.Fatalf("decode payload: %v", err) } var claims JWTClaims if err := json.Unmarshal(payloadBytes, &claims); err != nil { t.Fatalf("unmarshal payload: %v", err) } if claims.Sun != username { t.Errorf("sun = %q, want %q", claims.Sun, username) } if claims.EXP-claims.IAT != 3600 { t.Errorf("exp - iat = %d, want 3600", claims.EXP-claims.IAT) } // Verify signature signingInput := parts[0] + "." + parts[1] mac := hmac.New(sha256.New, key) mac.Write([]byte(signingInput)) expectedSig := base64.RawURLEncoding.EncodeToString(mac.Sum(nil)) if parts[2] != expectedSig { t.Errorf("signature mismatch: got %q, want %q", parts[2], expectedSig) } } func TestSignJWT_UsesRawURLEncoding(t *testing.T) { key := []byte("0123456789abcdef0123456789abcdef") token, err := SignJWT(key, "testuser", 1*time.Hour) if err != nil { t.Fatalf("SignJWT returned error: %v", err) } for i, part := range strings.Split(token, ".") { if strings.Contains(part, "=") { t.Errorf("part %d contains padding '=': %q", i, part) } } } func TestSignJWT_ExpiredToken(t *testing.T) { key := []byte("0123456789abcdef0123456789abcdef") token, err := SignJWT(key, "testuser", -1*time.Hour) if err != nil { t.Fatalf("SignJWT returned error: %v", err) } parts := strings.Split(token, ".") payloadBytes, err := base64.RawURLEncoding.DecodeString(parts[1]) if err != nil { t.Fatalf("decode payload: %v", err) } var claims JWTClaims if err := json.Unmarshal(payloadBytes, &claims); err != nil { t.Fatalf("unmarshal payload: %v", err) } now := time.Now().Unix() if claims.EXP >= now { t.Errorf("exp = %d, expected < now (%d)", claims.EXP, now) } } func TestSignJWT_EmptyUsername_Error(t *testing.T) { key := []byte("0123456789abcdef0123456789abcdef") _, err := SignJWT(key, "", 1*time.Hour) if err == nil { t.Fatal("expected error for empty username, got nil") } } func TestReadJWTKey_ValidFile(t *testing.T) { keyData := make([]byte, 32) for i := range keyData { keyData[i] = byte(i) } f, err := os.CreateTemp("", "jwtkey-*") if err != nil { t.Fatal(err) } defer os.Remove(f.Name()) if _, err := f.Write(keyData); err != nil { t.Fatal(err) } f.Close() got, err := ReadJWTKey(f.Name()) if err != nil { t.Fatalf("ReadJWTKey returned error: %v", err) } if string(got) != string(keyData) { t.Errorf("key mismatch: got %d bytes, want %d bytes", len(got), len(keyData)) } } func TestReadJWTKey_FileNotFound(t *testing.T) { _, err := ReadJWTKey("/nonexistent/path/keyfile") if err == nil { t.Fatal("expected error for nonexistent file, got nil") } if !errors.Is(err, os.ErrNotExist) { t.Errorf("error = %v, want os.ErrNotExist", err) } } func TestReadJWTKey_InvalidSize(t *testing.T) { f, err := os.CreateTemp("", "jwtkey-*") if err != nil { t.Fatal(err) } defer os.Remove(f.Name()) if _, err := f.Write([]byte("short")); err != nil { t.Fatal(err) } f.Close() _, err = ReadJWTKey(f.Name()) if err == nil { t.Fatal("expected error for short key, got nil") } if !strings.Contains(err.Error(), "key must be") { t.Errorf("error = %q, want message containing 'key must be'", err.Error()) } }