Compare commits

...

27 Commits

Author SHA1 Message Date
dailz
503e4dbc22 feat(portal): independent WebRTC thread + channel tuning for 60fps mouse latency
- Move WebRTC send to dedicated wl-webrtc-webrtc thread (was inline in main loop)
- Reduce frame_rx 16→1, input_tx 2→1 (drop-on-full), webrtc_tx 32→2
- Recv_timeout 10ms→2ms to reduce pipeline latency
- Fix sent_gap_p95 stats bug: compute gap at actual send time in WebRTC
  thread instead of batch-draining at snapshot time (was always 0.0ms)
- High profile via AVCodecContext.profile, veryfast preset, 5x bitrate
- Stats drain via sent_gap channel with record_send_from_thread()
- Shutdown: drop input_tx → join encode → drop webrtc_tx → join webrtc
2026-06-07 18:30:09 +08:00
dailz
caccfec44e fix(portal): compositor stall detection + filler frames + PipeWire state logging
P0: Detect compositor frame delivery stalls (>100ms no frames) and log
    stall/resume events with duration. Rate-limited to 1 warn/sec.

P1: Insert duplicate raw CpuNv12Frame filler during stalls at target fps.
    Keeps WebRTC stream smooth (sent_fps 20-40 instead of 3-5 during
    compositor pauses). Stops after 2s max stale. WebRTC mode only.

P2: Replace silent _ => {} in PipeWire state_changed callback with
    explicit Paused/Streaming/Connecting log messages.

P4: Add PwCtrlEvent::FormatChanged for mid-stream dimension changes.
    param_changed detects resolution renegotiation (skips first call).
    Logs warning in poll_and_encode; full encoder reinit deferred.

Verified: cargo check 0 errors, 70/70 tests, release build, --stats live.
2026-06-07 17:20:54 +08:00
dailz
826f544569 feat(portal): async encode pipeline - decouple capture from encoding
Split synchronous encode pipeline so sws_scale + libx264 runs on a
dedicated thread, leaving only VAAPI import + GPU scale + GPU→CPU
transfer on the main capture thread.

Problem: encode_p95 occasionally hit 74ms, blocking the entire capture
pipeline and causing capture_gap_max=356ms stutter.

Solution:
- avhw.rs: Split SwEncState into SwEncImport (main thread: VAAPI import,
  filter_graph scale, GPU→CPU transfer) and SwEncEncode (encode thread:
  sws_scale NV12→YUV420P, libx264 encode). New CpuNv12Frame struct
  carries owned pixel data across threads via crossbeam channel.
  SwEncState wraps both for backward compat (MP4/sync path untouched).
- state_portal.rs: WebRTC portal path spawns 'wl-webrtc-encode' thread
  with bounded(2) input channel (drop-newest backpressure) and separate
  timing channel. Graceful shutdown: drop webrtc_rx → drop input_tx →
  join encode thread → flush sync encoder.
- stats.rs: Add record_import() + record_encode_thread() for async timing.

Results: encode_p95 stable at 2.9-4.2ms (was 11-74ms), capture_fps
stable 59-60fps, cap_gap_p95 17-19ms. Remaining capture stalls traced
to PipeWire compositor frame delivery (external, not our code).
2026-06-07 16:55:28 +08:00
dailz
aae030f309 fix(webrtc): SO_SNDBUF 2MB + VBV rate limiting + stats integration
P0 - UDP send buffer: set SO_SNDBUF=2MB to prevent EAGAIN on large IDR
frames (218KB/256KB keyframes caused 18+ EAGAIN bursts). Actual Linux
buffer 4096KB confirmed.

P1 - VBV rate limiting: cap rc_max_rate=bitrate and rc_buffer_size=
bitrate/4 for WebRTC encode path, preventing oversized IDR frames.

Stats: integrate PipelineStats into cap_portal (dropped_count), state.rs
(wlroots path), webrtc.rs (browser getStats enhancement + stats panel).
2026-06-07 16:55:07 +08:00
dailz
029fe13e37 feat(stats): add --stats flag and PipelineStats windowed diagnostics
Add lightweight per-second pipeline statistics for stutter diagnosis:
- --stats CLI flag enables structured stats logging
- PipelineStats tracks capture/encode/send timing with p95/pmax
- FrameTimings records import/scale/transfer/sws/encode per-frame
- StatsSnapshot produces one structured log line per second
2026-06-07 16:54:45 +08:00
dailz
f3da1e4e6c fix(webrtc): propagate poll_output error as cleanup signal to prevent zombie state (closes #14) 2026-06-06 21:48:38 +08:00
dailz
e6e05fb44a fix(webrtc): fix is_idr_nalu boundary bug missing tail NAL units (closes #13) 2026-06-06 21:34:22 +08:00
dailz
8b04893ceb fix(security): remove error details from HTTP 500 response (#12)
The 500 error response previously included the raw error message {e}
in the body, potentially leaking internal implementation details (SDP
parse errors, ICE candidate info) to clients.

The detailed error is already logged server-side via tracing::error!,
so the response body is now a fixed generic string with a proper
HTTP/1.1 status line.
2026-06-06 21:22:57 +08:00
dailz
1beaea8088 fix(webrtc): use MediaAdded event to discover video mid instead of hardcoded iteration (closes #11) 2026-06-06 21:16:55 +08:00
dailz
fc4733ffe8 fix: return Ok(true) on ICE Disconnected to prevent resource leak
poll_rtc() always returned Ok(false), preventing WebRtcState from
clearing self.inner on disconnect. This leaked the UDP socket, Rtc
instance, and 65KB buffer permanently if the client never reconnected.

Closes #10
2026-06-06 20:57:25 +08:00
dailz
d5679be3a4 fix(state_portal): replace expect() with bail-style error propagation (closes #9) 2026-06-06 20:19:51 +08:00
dailz
36f07c92e9 fix(state_portal): prevent shutdown deadlock on full bounded channel (closes #8)
shutdown() calls enc.flush() → drain_encoder() → tx.send() on a
crossbeam bounded(32) channel.  If the channel is full and the
receiver (webrtc_rx) is alive but not being drained, send() blocks
forever — a self-deadlock since both ends belong to the same struct.

Two-layer fix:
- avhw.rs: replace tx.send() with tx.try_send(); handle Full (drop
  frame) and Disconnected (set flag) separately.
- state_portal.rs: drop webrtc_rx before flushing in shutdown() so
  try_send returns Disconnected immediately.

Regression tests added for the channel semantics.
2026-06-06 20:02:09 +08:00
dailz
7c1c9b2e19 fix(avhw): add SAFETY comments to all undocumented unsafe blocks
Close #7

- Add // SAFETY: comments to 19 undocumented unsafe blocks and impls
- Add nb_streams/null guard on stream array dereference (drain_encoder)
- Add clippy undocumented_unsafe_blocks = warn lint to prevent regression

avhw.rs now has 0 clippy unsafe documentation warnings.
2026-06-06 15:54:09 +08:00
dailz
226768c3e3 fix(avhw): handle tx.send() failure and pause encoding on WebRTC disconnect (closes #6)
- Replace 'let _ = tx.send()' with proper error handling: log warning,
  set webrtc_disconnected flag, and break drain loop on SendError
- Add Arc<AtomicBool> webrtc_paused shared between State/StatePortal
  and SwEncState, synced from wrtc.is_connected() in poll_webrtc()
- Skip encoding in encode_filtered_frame() when paused or disconnected
- Drain and discard stale channel frames on disconnect
- Resume encoding automatically on WebRTC reconnection
2026-06-06 15:12:49 +08:00
dailz
fd170b66d9 fix(unsafe): add SAFETY comment and runtime guards for from_raw_parts in drain_encoder
Issue: #5

- Read AVPacket fields into local variable to avoid repeated pointer deref
- Guard against size <= 0 (prevents c_int negative wrap to huge usize)
- Guard against null data pointer (from_raw_parts(null, 0) is UB in Rust)
- Add SAFETY comment matching existing codebase convention (30+ instances)
2026-06-06 11:56:47 +08:00
dailz
9a5b09cd7f fix(security): harden token file permissions (closes #2)
- save_restore_token: use create_new(true) + mode(0o600) for exclusive
  atomic file creation, preventing symlink attacks and predictable
  temp file exploitation
- token_path: return Option, eliminate insecure /tmp fallback
- load_restore_token: reject insecure files (symlinks, wrong owner,
  group/world-readable permissions)
- Directory creation uses DirBuilderExt::mode(0o700) bypassing umask
- Added verify_secure_dir and ensure_secure_parent with full metadata
  validation (owner, permissions, symlink rejection)
- Added 11 regression tests covering all security scenarios
2026-06-06 11:05:00 +08:00
dailz
46367ef6b5 fix(state): add WebRTC support to wlr-screencopy backend
Fixes #1 -- --port mode with wlr-screencopy backend caused panic at
negotiate_format() because self.args.output is None and .expect() was
called unconditionally.

Changes:
- Introduce StreamingEncoder enum wrapping EncState (MP4) and
  SwEncState (WebRTC) with unified frames_rgb/encode_frame/flush API
- Add WebRTC fields to State<S> (webrtc, webrtc_tx, webrtc_rx,
  webrtc_frames_sent) matching Portal backend pattern
- State::new() returns Result<Self> for clean WebRtcState init failure
- negotiate_format() branches on webrtc_tx: WebRTC path uses
  SwEncState::new_webrtc(), MP4 path unchanged (hardware VAAPI)
- Add poll_webrtc() method to drive signaling + channel drain
- Event loop calls poll_webrtc() each iteration
- Fix pre-existing test/bench Args construction (Option<String> output,
  missing no_persist field)
2026-06-04 22:10:46 +08:00
dailz
b0ed6548a6 feat: add WebRTC streaming via str0m + portal session persistence
- Add src/webrtc.rs: HTTP signaling server + str0m Sans-IO WebRTC transport
  with H.264 Annex-B → RTP packetization and key-frame request handling
- avhw: introduce FrameOutput enum (Muxer | Channel) so SwEncState can
  output to either MP4 muxer or crossbeam channel for WebRTC
- cap_portal: support portal session restore tokens (PersistMode::ExplicitlyRevoked)
  to skip re-authorization dialog; add --no-persist flag to force fresh dialog
- args: make --output optional when --port is used for WebRTC mode
- state_portal: integrate WebRTC pipeline (encoder channel → RTP forwarding)
  with shorter GOP for WebRTC (fps/2, min 10)
- main: redirect tracing to stderr; validate --output or --port required
- Add dependencies: str0m 0.20, serde_json 1, dirs 6
2026-06-04 20:54:16 +08:00
dailz
74f4dc826d perf(portal): achieve 58-60fps PipeWire screen capture
- Force PipeWire quantum=512 via NODE_FORCE_QUANTUM (48000/512=93Hz scheduling)
- Switch to libx264 ultrafast/zerolatency with 6 threads
- Use two-phase poll_and_encode: blocking recv_timeout for first frame,
  non-blocking try_recv drain for subsequent frames
- Remove fps_limit from portal path (PW already rate-limits via quantum/KWin;
  fps_limit's min_interval was silently dropping ~10% of valid frames)
- Remove diagnostic instrumentation (TIMING/PIPEWIRE logs, timing fields,
  pw_stats counters)
- Add lightweight production stats: per-10s fps log + shutdown summary
- Prefer libx264 over libopenh264 (better quality at same speed)
2026-05-30 08:44:15 +08:00
dailz
a83d146ed3 fix: FPS limiter never passes frames when input > target rate
The old FpsLimit compared timestamps between CONSECUTIVE frames.
When PipeWire delivers at 60fps (16ms intervals) and target is 30fps
(33ms min_interval), the gap between consecutive frames is always
16ms < 33ms, so EVERY frame was rejected after the first.

Fix: track last_output_time and compare against that instead of the
previous frame's timestamp. Now frames pass when enough time has
elapsed since the last OUTPUT, not since the last INPUT.

Also adds PipeWire process callback counter logging and frame
diagnostic STATS in state_portal.rs for debugging.
2026-05-29 22:09:35 +08:00
dailz
d80b34f44f feat: GPU-downscale + software H.264 encode pipeline (WIP)
Add SwEncState in avhw.rs: GPU pipeline using scale_vaapi to downscale
4K BGRA -> 2K NV12 on AMD iGPU, then software encode with libopenh264.

- import_dma_buf_to_vaapi: av_hwframe_map based DMA-BUF import
- SwEncState: GPU filter graph (scale_vaapi) + NV12->YUV420P + libopenh264
- state_portal.rs: integrated SwEncState, auto DRM device detection
- vaapi_import_bench.rs: CPU vs GPU pipeline benchmark
- sw_encode_bench.rs: software encode benchmark

Benchmark results: GPU pipeline ~91 FPS theoretical (10.95ms/frame)
vs CPU pipeline ~33 FPS (30.21ms/frame).

Known issue: only 1 frame encoded in production recording,
diagnostic STATS logging added to debug frame flow.
2026-05-29 22:04:12 +08:00
dailz
55abb5e56d fix(backend_detect): use raw zbus for portal check to avoid OnceLock connection poisoning
ashpd caches zbus::Connection in a global OnceLock. When check_portal_available()
created a Screencast proxy, the connection was cached there. When the function
returned and its tokio Runtime dropped, the cached connection became dead.
Subsequent setup_portal() calls reused this dead connection and hung forever.

Fix: replace ashpd Screencast proxy with direct zbus D-Bus interface check,
which does not touch the ashpd global connection cache.

Add examples/test_portal.rs for minimal Portal ScreenCast testing.
2026-05-27 22:07:11 +08:00
dailz
715a9c0bab refactor(cap_portal): split PwEvent into separate ctrl/frame channels
- Rename PwEvent to PwCtrlEvent, separate frame data into its own channel
- Add null chunk check to prevent crash on malformed PipeWire buffer
- Remove redundant inline comments and signal handlers
- Use try_send for error events to avoid blocking on full channel
2026-05-27 09:25:00 +08:00
dailz
60a55c17f2 fix(state_portal): add Drop impl, null dangling pointers, extract compute_pts, add tests
- Add Drop impl for StatePortal to flush encoder on drop (bug #2)
- Use enc.take() in shutdown() to prevent double-flush of write_trailer
- Null out data[0] after Box::from_raw recovery to avoid dangling pointer
- Extract compute_pts() for testable PTS calculation
- Add 8 tests: PTS calculation, DRM device resolution, descriptor building
2026-05-27 09:22:59 +08:00
dailz
5100d78aa8 fix: resolve SHM hang, DRM device mismatch, and duplicate VAAPI context
BUG-2 (HIGH): SHM Buffer event caused permanent hang
  In the ZwlrScreencopyFrameV1 dispatcher, receiving a SHM Buffer event
  left in_flight_surface stuck at AllocQueued forever, preventing
  queue_alloc_frame() from requesting new frames.
  Fix: treat Buffer as a metadata offer (v3 protocol), wait for
  BufferDone to decide failure, and add AllocQueued state guard to
  LinuxDmabuf handler.

BUG-3 (MEDIUM): Portal backend picked wrong GPU on multi-GPU systems
  state_portal.rs hardcoded /dev/dri/renderD128 then renderD129, which
  selects the wrong GPU when PipeWire uses a different device.
  Fix: extract find_drm_render_nodes() as shared utility; defer DRM
  device selection to first PipeWire frame; test each candidate with
  av_hwframe_transfer_data to find the GPU that can actually import
  the DMA-BUF frame.

BUG-4 (LOW): VAAPI device context created twice unnecessarily
  try_finalize_output() created an AvHwDevCtx stored in EverythingButFmt,
  but negotiate_format() discarded it (_hw_device_ctx) and EncState::new
  created a new one.
  Fix: thread the existing hw_device_ctx through negotiate_format() and
  create_encoder() to EncState::new() which reuses it when provided.
2026-05-25 14:32:58 +08:00
dailz
460a3ee711 fix(cap_portal): remove unsafe pw::deinit() to prevent global state corruption
pw::init() is guarded by an internal OnceCell (process-global one-shot).
pw::deinit() is unsafe and requires 'only called once per process lifetime
after all PipeWire use has permanently stopped'. Since CapPortal can be
created/destroyed multiple times, calling deinit() from a function-local
scope would prevent re-initialization (OnceCell already consumed) and
violate the unsafe contract.

The 5 early-return error paths in pipewire_thread() that previously
leaked global state are now consistent with the success path — neither
calls pw::deinit(). Process exit reclaims global PipeWire state.
2026-05-25 14:32:19 +08:00
dailz
b8026981d2 feat(examples): add Wayland globals lister utility
Minimal example that connects to the Wayland compositor and prints
all advertised globals (interface name, ID, version).
2026-05-25 08:56:55 +08:00
18 changed files with 6471 additions and 610 deletions

4
.gitignore vendored
View File

@@ -17,3 +17,7 @@ Thumbs.db
# Sisyphus orchestration artifacts
.sisyphus/
.omo/
.playwright-mcp/
wl-webrtc.log
webrtc-p0-success.png

809
Cargo.lock generated

File diff suppressed because it is too large Load Diff

View File

@@ -14,13 +14,23 @@ signal-hook = "0.3"
signal-hook-mio = { version = "0.2", features = ["support-v1_0"] }
clap = { version = "4", features = ["derive"] }
tracing = "0.1"
tracing-subscriber = "0.3"
tracing-subscriber = { version = "0.3", features = ["env-filter"] }
anyhow = "1"
drm = "0.12"
drm-fourcc = "2"
libc = "0.2"
ashpd = { version = "0.13", features = ["tokio", "screencast"] }
zbus = { version = "5", default-features = false, features = ["tokio"] }
tokio = { version = "1", features = ["rt"] }
pipewire = "0.9"
pipewire = { version = "0.9", features = ["v0_3_45"] }
libspa = "0.9"
crossbeam-channel = "0.5"
str0m = "0.20"
serde_json = "1"
dirs = "6"
[dev-dependencies]
tempfile = "3.27.0"
[lints.clippy]
undocumented_unsafe_blocks = "warn"

26
examples/list_globals.rs Normal file
View File

@@ -0,0 +1,26 @@
use wayland_client::globals::registry_queue_init;
use wayland_client::globals::GlobalListContents;
use wayland_client::protocol::wl_registry::{Event, WlRegistry};
use wayland_client::{Connection, Dispatch, QueueHandle};
struct Ls;
impl Dispatch<WlRegistry, GlobalListContents> for Ls {
fn event(
_state: &mut Self,
_registry: &WlRegistry,
_event: Event,
_data: &GlobalListContents,
_conn: &Connection,
_qhandle: &QueueHandle<Self>,
) {
}
}
fn main() {
let conn = Connection::connect_to_env().unwrap();
let (globals, _queue) = registry_queue_init::<Ls>(&conn).unwrap();
for g in globals.contents().clone_list() {
println!("{}: {} v{}", g.name, g.interface, g.version);
}
}

68
examples/test_portal.rs Normal file
View File

@@ -0,0 +1,68 @@
use ashpd::desktop::screencast::{CursorMode, Screencast, SelectSourcesOptions, SourceType};
use ashpd::desktop::PersistMode;
use ashpd::enumflags2::BitFlags;
fn main() {
let rt = tokio::runtime::Runtime::new().unwrap();
rt.block_on(async {
eprintln!("1. Creating Screencast proxy...");
let proxy = match Screencast::new().await {
Ok(p) => {
eprintln!(" OK");
p
}
Err(e) => {
eprintln!(" FAIL: {e}");
return;
}
};
eprintln!("2. Creating session...");
let session = match proxy.create_session(Default::default()).await {
Ok(s) => {
eprintln!(" OK");
s
}
Err(e) => {
eprintln!(" FAIL: {e}");
return;
}
};
eprintln!("3. Selecting sources...");
let sources: BitFlags<SourceType> = SourceType::Monitor.into();
let result = proxy
.select_sources(
&session,
SelectSourcesOptions::default()
.set_cursor_mode(CursorMode::Embedded)
.set_sources(sources)
.set_multiple(false)
.set_persist_mode(PersistMode::DoNot),
)
.await;
match result {
Ok(_) => eprintln!(" OK"),
Err(e) => {
eprintln!(" FAIL: {e}");
return;
}
}
eprintln!("4. Starting (should show dialog)...");
let response = match proxy.start(&session, None, Default::default()).await {
Ok(r) => {
eprintln!(" OK");
r
}
Err(e) => {
eprintln!(" FAIL: {e}");
return;
}
};
match response.response() {
Ok(r) => eprintln!(" Got {} stream(s)", r.streams().len()),
Err(e) => eprintln!(" Response error: {e}"),
}
});
}

View File

@@ -3,9 +3,9 @@ use clap::Parser;
#[derive(Parser, Debug, Clone)]
#[command(name = "wl-webrtc", about = "Wayland screen capture and encoding tool")]
pub struct Args {
/// Output file path (e.g., output.mp4, output.mkv)
/// Output file path (e.g., output.mp4, output.mkv). Optional when using --port for WebRTC mode
#[arg(short, long)]
pub output: String,
pub output: Option<String>,
/// Wayland output name to capture
#[arg(long)]
@@ -43,7 +43,15 @@ pub struct Args {
#[arg(long)]
pub backend: Option<String>,
/// Port for WebTransport server (Phase 2, unused in MVP)
/// Port for WebRTC HTTP signaling server; 0 keeps MP4 file output mode
#[arg(long, default_value_t = 0)]
pub port: u16,
/// Force re-authorization dialog (ignore saved portal restore token)
#[arg(long)]
pub no_persist: bool,
/// Enable per-second pipeline statistics output for stutter diagnosis
#[arg(long)]
pub stats: bool,
}

File diff suppressed because it is too large Load Diff

View File

@@ -37,11 +37,10 @@ impl Dispatch<WlRegistry, GlobalListContents> for RegistryLs {
}
}
// 通过 D-Bus 检测 XDG Desktop Portal 的 ScreenCast 接口是否可用
// 尝试创建 Screencast proxy如果 Portal 服务未运行则返回 false
// CAUTION: must NOT use ashpd here — ashpd caches zbus::Connection in a global
// OnceLock; if the tokio runtime owning that connection is dropped before
// setup_portal() runs, the cached connection becomes dead and hangs forever.
fn check_portal_available() -> bool {
use ashpd::desktop::screencast::Screencast;
let rt = match tokio::runtime::Runtime::new() {
Ok(rt) => rt,
Err(e) => {
@@ -51,30 +50,43 @@ fn check_portal_available() -> bool {
};
rt.block_on(async {
let proxy = match Screencast::new().await {
Ok(p) => p,
let conn = match zbus::Connection::session().await {
Ok(c) => c,
Err(e) => {
tracing::info!("Portal not available: {e}");
tracing::info!("D-Bus session bus unavailable: {e}");
return false;
}
};
// Verify the portal actually exposes ScreenCast capabilities,
// not just that the D-Bus service is running.
match proxy.available_source_types().await {
Ok(types) if !types.is_empty() => {
tracing::info!("Portal ScreenCast available (source types: {types:?})");
let inner: zbus::Proxy = match zbus::proxy::Builder::new(&conn)
.destination("org.freedesktop.portal.Desktop")
.and_then(|b| b.path("/org/freedesktop/portal/desktop"))
.and_then(|b| b.interface("org.freedesktop.portal.ScreenCast"))
{
Ok(b) => match b.build().await {
Ok(p) => p,
Err(e) => {
tracing::info!("Portal ScreenCast interface not available: {e}");
return false;
}
},
Err(e) => {
tracing::info!("Portal ScreenCast proxy build failed: {e}");
return false;
}
};
let version = match inner.get_property::<u32>("version").await {
Ok(version) => {
tracing::info!("Portal ScreenCast available (version: {version})");
true
}
Ok(types) => {
tracing::info!("Portal ScreenCast proxy exists but no source types available ({types:?})");
false
}
Err(e) => {
tracing::info!("Portal ScreenCast available_source_types query failed: {e}");
tracing::info!("Portal ScreenCast version query failed: {e}");
false
}
}
};
version
})
}
@@ -125,10 +137,7 @@ pub fn detect_backend(args: &Args) -> Result<CaptureBackend> {
}
other => {
// 未知后端名称,返回错误
anyhow::bail!(
"Unknown backend '{}'. Use 'screencopy' or 'portal'.",
other
);
anyhow::bail!("Unknown backend '{}'. Use 'screencopy' or 'portal'.", other);
}
};
}
@@ -169,7 +178,7 @@ mod tests {
// 测试辅助函数:构造指定后端参数的 Args 实例
fn make_args(backend: Option<&str>) -> Args {
Args {
output: "test.mp4".to_string(),
output: Some("test.mp4".to_string()),
output_name: None,
fps: 30,
codec: "h264".to_string(),
@@ -180,6 +189,8 @@ mod tests {
verbose: false,
backend: backend.map(String::from),
port: 0,
no_persist: false,
stats: false,
}
}

549
src/bin/sw_encode_bench.rs Normal file
View File

@@ -0,0 +1,549 @@
// sw_encode_bench.rs — Software encoding pipeline benchmark for screen capture
//
// Benchmarks: Portal capture -> mmap DMA-BUF -> sws_scale BGR0->YUV420P -> libx264 encode
//
// Usage: cargo run --bin sw_encode_bench -- --output /tmp/bench_test.mp4
use std::ffi::CString;
use std::os::fd::AsRawFd;
use std::path::Path;
use std::ptr;
use std::time::Instant;
use anyhow::{bail, Result};
use clap::Parser;
use ffmpeg_next as ff;
use ffmpeg_next::ffi;
use ffmpeg_next::packet::Mut;
use wl_webrtc::args::Args;
use wl_webrtc::cap_portal::{CapPortal, PwCtrlEvent};
#[derive(Parser, Debug)]
#[command(
name = "sw_encode_bench",
about = "Software encoding pipeline benchmark"
)]
struct BenchArgs {
#[arg(short, long)]
output: String,
#[arg(long, default_value_t = 120)]
frames: u32,
#[arg(long, default_value_t = 2560)]
enc_width: u32,
#[arg(long, default_value_t = 1440)]
enc_height: u32,
}
#[derive(Default)]
struct FrameStats {
mmap_us: Vec<u64>,
scale_us: Vec<u64>,
encode_us: Vec<u64>,
total_us: Vec<u64>,
mmap_failures: u32,
}
impl FrameStats {
fn avg_ms(data: &[u64]) -> f64 {
if data.is_empty() {
return 0.0;
}
data.iter().sum::<u64>() as f64 / data.len() as f64 / 1000.0
}
}
fn pix_fmt(p: ff::format::Pixel) -> ffi::AVPixelFormat {
Into::<ffi::AVPixelFormat>::into(p)
}
fn receive_first_frame(cap: &CapPortal) -> Result<wl_webrtc::cap_portal::PwDmaBufFrame> {
loop {
if let Ok(ctrl) = cap.event_receiver().try_recv() {
match ctrl {
PwCtrlEvent::StreamEnded => bail!("PipeWire stream ended before first frame"),
PwCtrlEvent::FormatChanged { .. } => {}
PwCtrlEvent::Error(e) => bail!("PipeWire error: {e}"),
}
}
match cap
.frame_receiver()
.recv_timeout(std::time::Duration::from_secs(10))
{
Ok(frame) => return Ok(frame),
Err(crossbeam_channel::RecvTimeoutError::Timeout) => {
bail!("Timeout waiting for first frame (10s)");
}
Err(crossbeam_channel::RecvTimeoutError::Disconnected) => {
bail!("PipeWire frame channel disconnected");
}
}
}
}
fn main() -> Result<()> {
let bench_args = BenchArgs::parse();
println!("=== Software Encode Benchmark ===");
println!("Output: {}", bench_args.output);
println!("Target frames: {}", bench_args.frames);
println!(
"Encode resolution: {}x{}",
bench_args.enc_width, bench_args.enc_height
);
println!();
ff::init()?;
println!("[1/4] Requesting screen capture via XDG Portal...");
println!(" (Select a screen to share in the portal dialog)");
let portal_args = Args {
output: Some(bench_args.output.clone()),
output_name: None,
fps: 60,
codec: "h264".to_string(),
hw_accel: "vaapi".to_string(),
drm_device: None,
bitrate: None,
gop_size: None,
verbose: false,
backend: Some("portal".to_string()),
port: 0,
no_persist: false,
stats: false,
};
let cap = CapPortal::new(&portal_args)?;
println!("[1/4] Portal connected, PipeWire stream active\n");
println!("[2/4] Waiting for first frame from PipeWire...");
let first_frame = receive_first_frame(&cap)?;
let src_width = first_frame.width;
let src_height = first_frame.height;
let src_stride = first_frame.stride;
let enc_width = bench_args.enc_width;
let enc_height = bench_args.enc_height;
println!(
"[2/4] First frame: {}x{}, stride={}, format=0x{:08X}",
src_width, src_height, src_stride, first_frame.format
);
println!(
" Capture: {}x{} Encode: {}x{}\n",
src_width, src_height, enc_width, enc_height
);
println!("[3/4] Testing mmap on DMA-BUF...");
let mmap_size = (src_stride as usize) * (src_height as usize);
let mmap_ptr = unsafe {
libc::mmap(
ptr::null_mut(),
mmap_size,
libc::PROT_READ,
libc::MAP_SHARED,
first_frame.fd.as_raw_fd(),
first_frame.offset as i64,
)
};
if mmap_ptr == libc::MAP_FAILED {
let errno = std::io::Error::last_os_error();
bail!(
"mmap on DMA-BUF fd FAILED — AMD driver may not support \
CPU read of screen capture DMA-BUF buffers.\n\
Error: {} (errno={})\n\
\n\
Workarounds:\n\
1. Use VAAPI hardware import (av_hwframe_map) instead of mmap\n\
2. Use wlroots compositor with wlr-screencopy (SHM-based)\n\
3. Use a virtual display or software renderer",
errno,
errno.raw_os_error().unwrap_or(-1)
);
}
println!(
"[3/4] mmap SUCCESS — CPU can read DMA-BUF ({:.1} MB)\n",
mmap_size as f64 / 1024.0 / 1024.0
);
unsafe {
libc::munmap(mmap_ptr, mmap_size);
}
drop(first_frame);
// Set up libx264 encoder via FFI (same pattern as avhw.rs)
println!("[4/4] Setting up libx264 encoder...");
let output_path = Path::new(&bench_args.output);
let output_cstr = CString::new(output_path.to_str().unwrap())?;
// Try libx264 first (best quality/speed), fall back to openh264
let codec = ff::encoder::find_by_name("libx264")
.or_else(|| ff::encoder::find_by_name("libopenh264"))
.ok_or_else(|| {
anyhow::anyhow!("No H.264 software encoder found (tried libx264, libopenh264)")
})?;
println!("[4/4] Using encoder: {}\n", codec.name());
let mut enc = {
let ctx = ff::codec::Context::new_with_codec(codec);
ctx.encoder().video()?
};
enc.set_width(enc_width);
enc.set_height(enc_height);
enc.set_format(ff::format::Pixel::YUV420P);
enc.set_time_base(ff::Rational::new(1, 60));
enc.set_max_b_frames(0);
enc.set_gop(60);
let codec_name = codec.name();
if codec_name == "libx264" {
unsafe {
let key = CString::new("preset").unwrap();
let val = CString::new("veryfast").unwrap();
ffi::av_opt_set((*enc.as_mut_ptr()).priv_data, key.as_ptr(), val.as_ptr(), 0);
let key = CString::new("tune").unwrap();
let val = CString::new("zerolatency").unwrap();
ffi::av_opt_set((*enc.as_mut_ptr()).priv_data, key.as_ptr(), val.as_ptr(), 0);
}
}
let opened = enc.open()?;
let mut enc_video = opened.0;
// Create output format context via FFI
let mut fmt_ctx_ptr: *mut ffi::AVFormatContext = ptr::null_mut();
let ret = unsafe {
ffi::avformat_alloc_output_context2(
&mut fmt_ctx_ptr,
ptr::null_mut(),
ptr::null(),
output_cstr.as_ptr(),
)
};
if ret < 0 || fmt_ctx_ptr.is_null() {
bail!("Failed to allocate output format context: error {ret}");
}
let stream_ptr = unsafe { ffi::avformat_new_stream(fmt_ctx_ptr, ptr::null()) };
if stream_ptr.is_null() {
bail!("Failed to create new stream");
}
let ret =
unsafe { ffi::avcodec_parameters_from_context((*stream_ptr).codecpar, enc_video.as_ptr()) };
if ret < 0 {
bail!("Failed to copy encoder parameters: error {ret}");
}
unsafe {
(*stream_ptr).time_base = (*enc_video.as_ptr()).time_base;
}
let ret = unsafe {
ffi::avio_open(
&mut (*fmt_ctx_ptr).pb,
output_cstr.as_ptr(),
ffi::AVIO_FLAG_WRITE,
)
};
if ret < 0 {
bail!(
"Failed to open output file '{}': error {ret}",
output_path.display()
);
}
let ret = unsafe { ffi::avformat_write_header(fmt_ctx_ptr, ptr::null_mut()) };
if ret < 0 {
bail!("Failed to write header: error {ret}");
}
let mut octx = unsafe { ff::format::context::Output::wrap(fmt_ctx_ptr) };
// Create sws_scale context: BGRZ (BGR0) -> YUV420P
let bgr0_fmt = pix_fmt(ff::format::Pixel::BGRZ);
let yuv420p_fmt = pix_fmt(ff::format::Pixel::YUV420P);
let sws_ctx = unsafe {
ffi::sws_getContext(
src_width as i32,
src_height as i32,
bgr0_fmt,
enc_width as i32,
enc_height as i32,
yuv420p_fmt,
2,
ptr::null_mut(),
ptr::null_mut(),
ptr::null_mut(),
)
};
if sws_ctx.is_null() {
bail!("Failed to create sws_scale context");
}
// Allocate reusable YUV frame
let mut yuv_frame = unsafe {
let mut f = ffi::av_frame_alloc();
if f.is_null() {
bail!("av_frame_alloc failed");
}
(*f).width = enc_width as i32;
(*f).height = enc_height as i32;
(*f).format = yuv420p_fmt as i32;
let ret = ffi::av_frame_get_buffer(f, 0);
if ret < 0 {
ffi::av_frame_free(&mut f);
bail!("av_frame_get_buffer failed: {ret}");
}
f
};
println!(
"[4/4] Encoder ready: {}, {}x{}\n",
codec_name, enc_width, enc_height
);
println!("=== Encoding {} frames ===\n", bench_args.frames);
let mut stats = FrameStats::default();
let total_start = Instant::now();
let mut frames_encoded: u32 = 0;
let mut pts: i64 = 0;
while frames_encoded < bench_args.frames {
if let Ok(ctrl) = cap.event_receiver().try_recv() {
match ctrl {
PwCtrlEvent::StreamEnded => {
eprintln!("PipeWire stream ended after {} frames", frames_encoded);
break;
}
PwCtrlEvent::Error(e) => {
eprintln!("PipeWire error after {} frames: {}", frames_encoded, e);
break;
}
PwCtrlEvent::FormatChanged { .. } => {}
}
}
let frame = match cap
.frame_receiver()
.recv_timeout(std::time::Duration::from_secs(5))
{
Ok(f) => f,
Err(_) => {
eprintln!("Frame timeout/disconnect after {} frames", frames_encoded);
break;
}
};
let frame_start = Instant::now();
let mmap_start = Instant::now();
let frame_size = (frame.stride as usize) * (frame.height as usize);
let mmap_ptr = unsafe {
libc::mmap(
ptr::null_mut(),
frame_size,
libc::PROT_READ,
libc::MAP_SHARED,
frame.fd.as_raw_fd(),
frame.offset as i64,
)
};
if mmap_ptr == libc::MAP_FAILED {
stats.mmap_failures += 1;
eprintln!("mmap failed on frame {}", frames_encoded);
drop(frame);
continue;
}
stats.mmap_us.push(mmap_start.elapsed().as_micros() as u64);
let scale_start = Instant::now();
let src_data = unsafe { std::slice::from_raw_parts(mmap_ptr as *const u8, frame_size) };
unsafe {
ffi::av_frame_make_writable(yuv_frame);
let src_ptr = src_data.as_ptr();
let src_linesize = frame.stride as i32;
ffi::sws_scale(
sws_ctx,
&src_ptr as *const *const u8,
&src_linesize as *const i32,
0,
frame.height as i32,
(*yuv_frame).data.as_ptr() as *mut *mut u8,
(*yuv_frame).linesize.as_ptr() as *mut i32,
);
}
stats
.scale_us
.push(scale_start.elapsed().as_micros() as u64);
unsafe {
libc::munmap(mmap_ptr, frame_size);
}
drop(frame);
let encode_start = Instant::now();
unsafe {
(*yuv_frame).pts = pts;
pts += 1;
let ret = ffi::avcodec_send_frame(enc_video.as_mut_ptr(), yuv_frame);
if ret < 0 {
eprintln!("avcodec_send_frame failed: {ret}");
continue;
}
}
drain_encoder(&mut enc_video, &mut octx)?;
stats
.encode_us
.push(encode_start.elapsed().as_micros() as u64);
stats
.total_us
.push(frame_start.elapsed().as_micros() as u64);
frames_encoded += 1;
if frames_encoded % 30 == 0 {
let fps = frames_encoded as f64 / total_start.elapsed().as_secs_f64();
println!(
" [{}/{}] {:.1} FPS",
frames_encoded, bench_args.frames, fps
);
}
}
let total_elapsed = total_start.elapsed();
println!("\nFlushing encoder...");
unsafe {
ffi::avcodec_send_frame(enc_video.as_mut_ptr(), ptr::null());
}
drain_encoder(&mut enc_video, &mut octx)?;
octx.write_trailer()
.map_err(|e| anyhow::anyhow!("Failed to write trailer: {e}"))?;
// Cleanup
unsafe {
ffi::av_frame_free(&mut yuv_frame as *mut _);
ffi::sws_freeContext(sws_ctx);
}
drop(cap);
// Print results
let mmap_count = stats.mmap_us.len() as u32;
let mmap_success_rate = if mmap_count + stats.mmap_failures > 0 {
mmap_count as f64 / (mmap_count + stats.mmap_failures) as f64 * 100.0
} else {
0.0
};
let total_fps = frames_encoded as f64 / total_elapsed.as_secs_f64();
let avg_total_ms = FrameStats::avg_ms(&stats.total_us);
let max_fps = if avg_total_ms > 0.0 {
1000.0 / avg_total_ms
} else {
0.0
};
println!();
println!("╔══════════════════════════════════════════════════════════════╗");
println!("║ Software Encode Benchmark Results ║");
println!("╚══════════════════════════════════════════════════════════════╝");
println!();
println!("Capture resolution: {}x{}", src_width, src_height);
println!("Encode resolution: {}x{}", enc_width, enc_height);
println!("Frames encoded: {}", frames_encoded);
println!("Total time: {:.2}s", total_elapsed.as_secs_f64());
println!();
println!("mmap (DMA-BUF -> CPU):");
println!(
" avg: {:.2} ms/frame",
FrameStats::avg_ms(&stats.mmap_us)
);
println!(
" success rate: {:.1}% ({}/{})",
mmap_success_rate,
mmap_count,
mmap_count + stats.mmap_failures
);
println!();
println!("scale (BGR0 -> YUV420P via sws_scale):");
println!(
" avg: {:.2} ms/frame",
FrameStats::avg_ms(&stats.scale_us)
);
println!();
println!("encode ({}):", codec_name);
println!(
" avg: {:.2} ms/frame",
FrameStats::avg_ms(&stats.encode_us)
);
println!();
println!("total pipeline:");
println!(" avg: {:.2} ms/frame", avg_total_ms);
println!(" achieved FPS: {:.1}", total_fps);
println!(" max theoretical: {:.1} FPS", max_fps);
println!();
if mmap_success_rate < 100.0 {
println!(
"WARNING: Some mmap operations failed ({}/{})",
stats.mmap_failures,
stats.mmap_failures + mmap_count
);
}
if total_fps < 30.0 {
println!(
"NOTE: Achieved FPS ({:.1}) is below 30 FPS target.",
total_fps
);
}
println!("Output written to: {}", bench_args.output);
Ok(())
}
fn drain_encoder(
enc_video: &mut ff::encoder::video::Video,
octx: &mut ff::format::context::Output,
) -> Result<()> {
loop {
let mut pkt = ff::Packet::empty();
let ret = unsafe { ffi::avcodec_receive_packet(enc_video.as_mut_ptr(), pkt.as_mut_ptr()) };
if ret < 0 {
if ret == ffi::AVERROR(ffi::EAGAIN) || ret == ffi::AVERROR_EOF {
break;
}
eprintln!("avcodec_receive_packet failed: {ret}");
break;
}
let enc_tb = enc_video.time_base();
let stream_tb = unsafe {
let streams = (*octx.as_ptr()).streams;
let st = *streams.add(0);
ff::Rational::from((*st).time_base)
};
pkt.rescale_ts(enc_tb, stream_tb);
pkt.set_stream(0);
pkt.write_interleaved(octx)
.map_err(|e| anyhow::anyhow!("write packet failed: {e}"))?;
}
Ok(())
}

File diff suppressed because it is too large Load Diff

View File

@@ -12,11 +12,13 @@
// - crossbeam-channel: 高性能有界通道,用于线程间帧传递
use std::os::fd::{AsRawFd, FromRawFd, OwnedFd};
use std::path::PathBuf;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Arc;
use std::thread::{self, JoinHandle};
use anyhow::Result;
use crossbeam_channel::{Receiver, Sender, bounded};
use crossbeam_channel::{bounded, Receiver, Sender};
use tokio::runtime::Runtime;
use crate::args::Args;
@@ -45,15 +47,15 @@ pub struct PwDmaBufFrame {
pub pts: i64,
}
/// PipeWire 事件枚举
/// PipeWire 控制事件枚举
///
/// 从 PipeWire 捕获线程发送给消费者的事件类型
/// 消费者通过 frame_receiver() 获取的 Receiver 接收这些事件
pub enum PwEvent {
/// 收到一帧新的 DMA-BUF 视频帧
Frame(PwDmaBufFrame),
/// 从 PipeWire 捕获线程发送给消费者的控制事件。
/// 与帧数据分离,通过独立的 channel 传输,确保控制事件不被帧数据淹没
pub enum PwCtrlEvent {
/// 流已结束PipeWire 流断开连接或进入错误状态)
StreamEnded,
/// Format/dimensions changed mid-stream
FormatChanged { width: u32, height: u32 },
/// 发生错误,包含错误描述信息
Error(String),
}
@@ -68,14 +70,12 @@ pub enum PwEvent {
/// 2. frame_receiver() — 获取帧接收端,供消费者轮询
/// 3. Drop — 通过 eventfd 通知 PipeWire 线程安全退出
pub struct CapPortal {
/// eventfd 的写入端,用于在 drop 时通知 PipeWire 线程退出
shutdown_fd: OwnedFd,
/// 帧事件接收端,消费者通过此 Receiver 获取帧数据
frame_rx: Receiver<PwEvent>,
/// PipeWire 捕获线程的 JoinHandledrop 时等待线程退出
frame_rx: Receiver<PwDmaBufFrame>,
event_rx: Receiver<PwCtrlEvent>,
pw_thread: Option<JoinHandle<()>>,
/// Tokio 运行时,仅用于 setup_portal() 中的异步 Portal 调用
rt: Runtime,
pw_dropped: Arc<AtomicU64>,
}
/// PipeWire 捕获线程的上下文数据
@@ -83,17 +83,12 @@ pub struct CapPortal {
/// 从主线程传递给 PipeWire 捕获线程的所有必要资源。
/// 该结构体在线程创建时一次性 move 到线程中使用。
struct PwThreadCtx {
/// 帧事件发送端,用于向消费者线程发送帧数据或错误/结束事件
frame_tx: Sender<PwEvent>,
/// 已丢弃帧的计数器(原子操作),用于统计因通道满而丢弃的帧数
dropped: AtomicU64,
/// eventfd 的读取端,注册到 PipeWire 事件循环中,用于接收关闭信号
frame_tx: Sender<PwDmaBufFrame>,
event_tx: Sender<PwCtrlEvent>,
dropped: Arc<AtomicU64>,
shutdown_read: OwnedFd,
/// Portal 返回的 PipeWire 远程连接文件描述符
pw_fd: OwnedFd,
/// Portal 返回的 PipeWire 节点 ID标识要捕获的屏幕流
node_id: u32,
/// 目标帧率(当前保留,未直接用于 PipeWire 协商)
fps: u32,
}
@@ -103,27 +98,18 @@ impl CapPortal {
/// 执行流程:
/// 1. 创建 Tokio 运行时(用于异步 Portal 调用)
/// 2. 通过 XDG Desktop Portal 请求屏幕录制权限,获取 PipeWire fd 和 node_id
/// 3. 创建有界通道(容量 3)用于帧传递
/// 3. 创建有界通道(容量 1)用于帧传递(最新帧优先,避免队列积压延迟)
/// 4. 创建 eventfd 对,用于线程安全的关闭信号传递
/// 5. 启动 PipeWire 捕获线程
pub fn new(args: &Args) -> Result<Self> {
// 创建独立的 Tokio 运行时,仅用于 setup_portal 中的异步 Portal D-Bus 调用
let rt = Runtime::new()?;
// 通过 Portal 获取 PipeWire 连接 fd 和节点 ID
// block_on 在此处同步等待异步 Portal 调用完成
let (pw_fd, node_id) = rt.block_on(async {
Self::setup_portal().await
})?;
let no_persist = args.no_persist;
let (pw_fd, node_id) = rt.block_on(async { Self::setup_portal(no_persist).await })?;
// 创建有界通道,容量为 3 帧
// 使用有界通道实现背压:当消费者处理不过来时,生产者会丢弃帧而非无限堆积
let (frame_tx, frame_rx) = bounded(3);
let (frame_tx, frame_rx) = bounded(1);
let (event_tx, event_rx) = bounded(8);
// 创建 eventfd 对,用于线程安全的关闭信号传递
// eventfd 是 Linux 内核提供的轻量级进程/线程间通知机制
// 写入端保存在 CapPortal主线程读取端注册到 PipeWire 事件循环中
// 这样 CapPortal drop 时可以安全地通知 PipeWire 线程退出
let efd = unsafe { libc::eventfd(0, libc::EFD_CLOEXEC | libc::EFD_NONBLOCK) };
if efd < 0 {
return Err(anyhow::anyhow!(
@@ -131,8 +117,6 @@ impl CapPortal {
std::io::Error::last_os_error()
));
}
// 复制 eventfd 得到写入端,原始 fd 作为读取端
// 需要 dup 是因为读取端和写入端需要各自独立的 OwnedFd 所有权
let write_fd = unsafe { libc::dup(efd) };
if write_fd < 0 {
let err = std::io::Error::last_os_error();
@@ -140,38 +124,56 @@ impl CapPortal {
return Err(anyhow::anyhow!("dup eventfd failed: {err}"));
}
// 构建 PipeWire 线程上下文,将所有必要资源 move 进去
let pw_dropped = Arc::new(AtomicU64::new(0));
let ctx = PwThreadCtx {
frame_tx,
dropped: AtomicU64::new(0),
event_tx,
dropped: pw_dropped.clone(),
shutdown_read: unsafe { OwnedFd::from_raw_fd(efd) },
pw_fd,
node_id,
fps: args.fps,
};
// 启动 PipeWire 捕获线程,命名便于调试和性能分析
let pw_thread = thread::Builder::new()
.name("pipewire-capture".into())
.spawn(move || {
pipewire_thread(ctx);
})
.map_err(|e| {
unsafe { libc::close(write_fd) };
anyhow::anyhow!("thread spawn failed: {e}")
})?;
Ok(Self {
shutdown_fd: unsafe { OwnedFd::from_raw_fd(write_fd) },
frame_rx,
event_rx,
pw_thread: Some(pw_thread),
rt,
pw_dropped,
})
}
/// 获取帧事件接收端的引用
///
/// 消费者通过此方法获取 Receiver然后不断接收 PwEvent 事件来获取帧数据。
pub fn frame_receiver(&self) -> &Receiver<PwEvent> {
pub fn frame_receiver(&self) -> &Receiver<PwDmaBufFrame> {
&self.frame_rx
}
pub fn event_receiver(&self) -> &Receiver<PwCtrlEvent> {
&self.event_rx
}
/// Returns the total number of PipeWire frames dropped due to channel backlog.
pub fn dropped_count(&self) -> u64 {
self.pw_dropped.load(Ordering::Relaxed)
}
/// Returns the number of frames currently waiting in the capture channel.
pub fn capture_queue_depth(&self) -> usize {
self.frame_rx.len()
}
/// 通过 XDG Desktop Portal 建立屏幕录制会话
///
/// 与桌面环境的 D-Bus 服务交互,请求用户授权屏幕录制。
@@ -183,44 +185,50 @@ impl CapPortal {
/// 5. 打开 PipeWire 远程连接,获取文件描述符
///
/// 返回 (PipeWire fd, node_id),供 PipeWire 线程连接使用
async fn setup_portal() -> Result<(OwnedFd, u32)> {
async fn setup_portal(no_persist: bool) -> Result<(OwnedFd, u32)> {
use ashpd::desktop::screencast::{
CursorMode, Screencast, SelectSourcesOptions, SourceType,
};
use ashpd::desktop::PersistMode;
// 创建 Screencast D-Bus 代理,与桌面环境的 Portal 服务通信
let proxy = Screencast::new().await.map_err(|e| {
anyhow::anyhow!("Failed to create Screencast proxy: {e}")
})?;
let proxy = Screencast::new()
.await
.map_err(|e| anyhow::anyhow!("Failed to create Screencast proxy: {e}"))?;
// 创建 ScreenCast 会话(每个会话对应一次屏幕录制请求)
let session = proxy
.create_session(Default::default())
.await
.map_err(|e| anyhow::anyhow!("Failed to create ScreenCast session: {e}"))?;
// 配置录制源选择参数:
// - CursorMode::Embedded: 光标嵌入到帧数据中(而非单独的元数据)
// - SourceType::Monitor: 仅捕获显示器(不捕获窗口)
// - multiple: false: 不允许多源选择
// - PersistMode::DoNot: 不持久化会话(每次需要重新授权)
proxy
.select_sources(
&session,
SelectSourcesOptions::default()
let version_supported = proxy.version() >= 4;
let (persist_mode, saved_token) = if !no_persist && version_supported {
let token = load_restore_token();
if token.is_some() {
tracing::info!("Attempting to restore portal session with saved token");
}
(PersistMode::ExplicitlyRevoked, token)
} else {
(PersistMode::DoNot, None)
};
let mut options = SelectSourcesOptions::default()
.set_cursor_mode(CursorMode::Embedded)
.set_sources(ashpd::enumflags2::BitFlags::from(SourceType::Monitor))
.set_multiple(false)
.set_persist_mode(PersistMode::DoNot),
)
.set_persist_mode(persist_mode);
if let Some(ref token) = saved_token {
options = options.set_restore_token(token.as_str());
}
proxy
.select_sources(&session, options)
.await
.map_err(|e| {
anyhow::anyhow!("屏幕共享权限被拒绝 / Screen sharing permission denied: {e}")
anyhow::anyhow!("Screen sharing permission denied: {e}")
})?;
// 启动录制会话,此时桌面环境会弹出权限确认对话框
// 用户确认后返回包含 PipeWire 流信息的响应
let response = proxy
.start(&session, None, Default::default())
.await
@@ -228,18 +236,19 @@ impl CapPortal {
.response()
.map_err(|e| anyhow::anyhow!("ScreenCast response error: {e}"))?;
// 获取返回的第一个(也是唯一的)视频流
// 每个流对应一个 PipeWire 节点
if !no_persist && version_supported {
if let Some(new_token) = response.restore_token() {
save_restore_token(new_token);
}
}
let stream = response
.streams()
.first()
.ok_or_else(|| anyhow::anyhow!("No streams returned from ScreenCast"))?;
// 提取 PipeWire 节点 ID用于后续连接到该节点的视频流
let node_id = stream.pipe_wire_node_id();
// 打开 PipeWire 远程连接,获取文件描述符
// 这个 fd 允许直接与 PipeWire 守护进程通信
let fd = proxy
.open_pipe_wire_remote(&session, Default::default())
.await
@@ -251,6 +260,165 @@ impl CapPortal {
}
}
fn token_path() -> Option<PathBuf> {
dirs::cache_dir().map(|base| base.join("wl-webrtc").join("portal-restore-token"))
}
/// Verify that `path` is a directory owned by the current user with no group/other permissions.
/// Rejects symlinks at the path itself (but allows the resolved target to be a real dir).
fn verify_secure_dir(path: &std::path::Path) -> bool {
use std::os::unix::fs::{MetadataExt, PermissionsExt};
match std::fs::symlink_metadata(path) {
Ok(meta) => {
if meta.file_type().is_symlink() {
tracing::warn!("Token parent dir is a symlink, rejecting: {}", path.display());
return false;
}
// Must be a directory
if !meta.is_dir() {
tracing::warn!("Token parent path is not a directory: {}", path.display());
return false;
}
// Must be owned by current user
if meta.uid() != unsafe { libc::getuid() } {
tracing::warn!("Token parent dir not owned by current user: {}", path.display());
return false;
}
// No group or other permissions (mode must be 0o700 exactly within the 0o777 mask)
let mode = meta.permissions().mode() & 0o777;
if mode != 0o700 {
tracing::warn!(
"Token parent dir has insecure permissions {:o}, expected 0700: {}",
mode,
path.display()
);
return false;
}
true
}
Err(e) => {
tracing::warn!("Failed to stat token parent dir: {e}");
false
}
}
}
/// Ensure the parent directory exists with restrictive permissions (0o700).
/// Returns false if the directory could not be created or is insecure.
fn ensure_secure_parent(parent: &std::path::Path) -> bool {
use std::os::unix::fs::{DirBuilderExt, PermissionsExt};
if parent.exists() {
// Directory exists — try to tighten permissions, then verify.
// set_permissions follows symlinks, which is fine here since
// we verify with symlink_metadata in verify_secure_dir.
if let Err(e) = std::fs::set_permissions(parent, std::fs::Permissions::from_mode(0o700)) {
tracing::warn!("Failed to set directory permissions: {e}");
return false;
}
return verify_secure_dir(parent);
}
// Create with restrictive mode — DirBuilderExt::mode bypasses umask.
let mut builder = std::fs::DirBuilder::new();
builder.recursive(true);
builder.mode(0o700);
if let Err(e) = builder.create(parent) {
tracing::warn!("Failed to create token directory: {e}");
return false;
}
// Verify after creation (belt-and-suspenders)
verify_secure_dir(parent)
}
fn load_restore_token() -> Option<String> {
load_restore_token_from(token_path()?)
}
fn load_restore_token_from(path: PathBuf) -> Option<String> {
use std::os::unix::fs::{MetadataExt, PermissionsExt};
let meta = match std::fs::symlink_metadata(&path) {
Ok(m) => m,
Err(_) => return None,
};
if meta.file_type().is_symlink() {
tracing::warn!("Token file is a symlink, refusing to read: {}", path.display());
return None;
}
if !meta.is_file() {
tracing::warn!("Token path is not a regular file: {}", path.display());
return None;
}
if meta.uid() != unsafe { libc::getuid() } {
tracing::warn!("Token file not owned by current user: {}", path.display());
return None;
}
let mode = meta.permissions().mode() & 0o777;
if mode & 0o077 != 0 {
tracing::warn!(
"Token file has insecure permissions {:o}, refusing to read: {}",
mode,
path.display()
);
return None;
}
let token = std::fs::read_to_string(&path).ok()?;
let trimmed = token.trim().to_string();
if trimmed.is_empty() { None } else { Some(trimmed) }
}
fn save_restore_token(token: &str) {
let Some(path) = token_path() else {
tracing::warn!("No secure cache directory available, skipping token save");
return;
};
save_restore_token_to(token, &path);
}
fn save_restore_token_to(token: &str, path: &std::path::Path) {
use std::fs::OpenOptions;
use std::io::Write;
use std::os::unix::fs::OpenOptionsExt;
let Some(parent) = path.parent() else {
tracing::warn!("Token path has no parent directory");
return;
};
if !ensure_secure_parent(parent) {
tracing::warn!("Parent directory is insecure, refusing to save token");
return;
}
// Use a unique temp file to prevent symlink attacks.
// create_new(true) guarantees exclusive creation — fails if file already exists,
// and does NOT follow existing symlinks.
let tmp_path = path.with_extension(format!("{}.tmp", std::process::id()));
let result = (|| -> std::io::Result<()> {
let mut f = OpenOptions::new()
.write(true)
.create_new(true)
.mode(0o600)
.open(&tmp_path)?;
f.write_all(token.as_bytes())?;
f.sync_all()?;
std::fs::rename(&tmp_path, path)?;
Ok(())
})();
match result {
Ok(()) => tracing::info!("Saved portal restore token"),
Err(e) => {
let _ = std::fs::remove_file(&tmp_path);
tracing::warn!("Failed to save restore token: {e}");
}
}
}
impl Drop for CapPortal {
/// 析构时安全关闭 PipeWire 线程
///
@@ -296,52 +464,55 @@ impl Drop for CapPortal {
fn pipewire_thread(ctx: PwThreadCtx) {
use pipewire as pw;
use pw::properties::properties;
use pw::spa::param::video::VideoInfoRaw;
use pw::stream::{StreamBox, StreamFlags};
use std::cell::Cell;
use std::rc::Rc;
use pw::spa::param::video::VideoInfoRaw;
// 初始化 PipeWire 库,必须在任何 PipeWire 操作之前调用
// 初始化 PipeWire 进程全局库。
//
// pipewire-rs 内部使用 OnceCell 保护 pw::init(),确保只调用一次。
// pw::deinit() 是 unsafe 且要求"进程生命周期内仅调用一次,且所有
// PipeWire 使用已停止"。由于 CapPortal 可被多次创建销毁,此函数
// 不调用 pw::deinit()——进程退出时全局状态由 OS 回收。
pw::init();
// 解构上下文,取出所有必要资源
// fps 重命名为 _fps 表示当前未使用(保留供将来帧率控制使用)
let PwThreadCtx {
frame_tx,
event_tx,
dropped,
shutdown_read,
pw_fd,
node_id,
fps: _fps,
fps,
} = ctx;
// 创建 PipeWire MainLoop主事件循环
// MainLoopBox 是栈分配的 PipeWire 主循环封装
let mainloop = match pw::main_loop::MainLoopBox::new(None) {
Ok(ml) => ml,
Err(e) => {
let _ = frame_tx.send(PwEvent::Error(format!("MainLoop::new failed: {e}")));
if let Err(e) = event_tx.try_send(PwCtrlEvent::Error(format!("MainLoop::new failed: {e}"))) {
tracing::error!("MainLoop::new failed and error channel also failed: {e}");
}
return;
}
};
// 创建 PipeWire Context用于管理核心对象和协议处理
let context = match pw::context::ContextBox::new(mainloop.loop_(), None) {
Ok(c) => c,
Err(e) => {
let _ = frame_tx.send(PwEvent::Error(format!("Context::new failed: {e}")));
if let Err(e) = event_tx.try_send(PwCtrlEvent::Error(format!("Context::new failed: {e}"))) {
tracing::error!("Context::new failed and error channel also failed: {e}");
}
return;
}
};
// 使用 Portal 提供的 fd 连接到 PipeWire 核心守护进程
// connect_fd 接管该 fd 的所有权(通过 dup不关闭原始 fd
let core = match context.connect_fd(pw_fd, None) {
Ok(c) => c,
Err(e) => {
let _ = frame_tx.send(PwEvent::Error(format!(
"connect_fd failed: {e}"
)));
if let Err(e) = event_tx.try_send(PwCtrlEvent::Error(format!("connect_fd failed: {e}"))) {
tracing::error!("connect_fd failed and error channel also failed: {e}");
}
return;
}
};
@@ -358,39 +529,40 @@ fn pipewire_thread(ctx: PwThreadCtx) {
*pw::keys::MEDIA_TYPE => "Video",
*pw::keys::MEDIA_CATEGORY => "Capture",
*pw::keys::MEDIA_ROLE => "Screen",
*pw::keys::NODE_FORCE_QUANTUM => "512",
},
) {
Ok(s) => s,
Err(e) => {
let _ = frame_tx.send(PwEvent::Error(format!("Stream::new failed: {e}")));
if let Err(e) = event_tx.try_send(PwCtrlEvent::Error(format!("Stream::new failed: {e}"))) {
tracing::error!("Stream::new failed and error channel also failed: {e}");
}
return;
}
};
// 共享的格式状态: (宽度, 高度, DRM FourCC 格式, 修饰符)
// 使用 Rc<Cell<>> 因为 PipeWire 回调在同一个线程内执行,无需跨线程同步
// Cell<Option<...>> 允许在不可变引用中修改值(内部可变性)
// format_info 在 param_changed 回调中设置,在 process 回调中读取
let format_info: Rc<Cell<Option<(u32, u32, u32, u64)>>> =
Rc::new(Cell::new(None));
let format_info: Rc<Cell<Option<(u32, u32, u32, u64)>>> = Rc::new(Cell::new(None));
let frame_tx_clone = frame_tx.clone();
// 注册流事件监听器,包含三个回调:
// - state_changed: 流状态变化通知
// - param_changed: 格式协商完成通知
// - process: 每帧数据处理
let event_tx_state = event_tx.clone();
let _listener = stream
.add_local_listener::<()>()
// 流状态变化回调
// 当流进入 Error 或 Unconnected 状态时,通知消费者流已结束
.state_changed(move |_, _, old, new| {
tracing::debug!("PipeWire stream state: {old:?} -> {new:?}");
tracing::info!("PipeWire stream state: {old:?} -> {new:?}");
match new {
pw::stream::StreamState::Error(_)
| pw::stream::StreamState::Unconnected => {
let _ = frame_tx_clone.send(PwEvent::StreamEnded);
pw::stream::StreamState::Error(e) => {
tracing::error!("PipeWire stream error: {e}");
let _ = event_tx_state.try_send(PwCtrlEvent::StreamEnded);
}
_ => {}
pw::stream::StreamState::Unconnected => {
let _ = event_tx_state.try_send(PwCtrlEvent::StreamEnded);
}
pw::stream::StreamState::Paused => {
tracing::warn!("PipeWire stream paused (compositor may be switching content)");
}
pw::stream::StreamState::Streaming => {
tracing::info!("PipeWire stream (re)started");
}
pw::stream::StreamState::Connecting => {}
}
})
// 参数变化回调(格式协商)
@@ -398,6 +570,7 @@ fn pipewire_thread(ctx: PwThreadCtx) {
// id 为参数类型param 包含具体的格式参数(分辨率、像素格式等)
.param_changed({
let format_info = format_info.clone();
let event_tx = event_tx.clone();
move |_, _, id, param| {
// 仅处理 Format 类型的参数变化
let Some(param) = param else { return };
@@ -416,11 +589,27 @@ fn pipewire_thread(ctx: PwThreadCtx) {
let drm_format = spa_to_drm_fourcc(info.format());
// 获取 DRM 修饰符,描述 GPU buffer 的内存布局(如 tiling 模式)
let modifier = info.modifier();
let framerate = info.framerate();
let max_framerate = info.max_framerate();
// 保存协商后的格式信息,供 process 回调读取
let previous_format = format_info.get();
format_info.set(Some((width, height, drm_format, modifier)));
if let Some((previous_width, previous_height, _, _)) = previous_format {
if width != previous_width || height != previous_height {
tracing::warn!(
"PipeWire dimensions changed: {}x{} (format renegotiation)",
width,
height
);
let _ = event_tx.try_send(PwCtrlEvent::FormatChanged { width, height });
}
}
tracing::info!(
"PipeWire format negotiated: {width}x{height}, \
drm_format={drm_format:#010x}, modifier={modifier:#x}"
drm_format={drm_format:#010x}, modifier={modifier:#x}, \
framerate={}/{}, max_framerate={}/{}",
framerate.num, framerate.denom,
max_framerate.num, max_framerate.denom,
);
}
})
@@ -432,15 +621,17 @@ fn pipewire_thread(ctx: PwThreadCtx) {
let frame_tx = frame_tx.clone();
let dropped = dropped;
move |stream, _| {
// 从流中出队原始 buffer包含帧数据的元信息
let raw_buf = unsafe { stream.dequeue_raw_buffer() };
if raw_buf.is_null() {
tracing::trace!("process: null raw_buf");
return;
}
// 获取 SPA buffer 结构体,包含数据数组、元数据等
let spa_buf = unsafe { (*raw_buf).buffer };
if spa_buf.is_null() {
tracing::trace!("process: null spa_buf");
unsafe { stream.queue_raw_buffer(raw_buf) };
return;
}
@@ -450,20 +641,27 @@ fn pipewire_thread(ctx: PwThreadCtx) {
let n_datas = unsafe { (*spa_buf).n_datas };
let datas_ptr = unsafe { (*spa_buf).datas };
if n_datas == 0 || datas_ptr.is_null() {
tracing::trace!("process: no data (n_datas={n_datas})");
unsafe { stream.queue_raw_buffer(raw_buf) };
return;
}
// 从第一个数据项中获取 DMA-BUF 文件描述符
// 通过 libspa 的 Data 包装类型安全地访问 SPA 数据结构
let data_ref: &pw::spa::buffer::Data = unsafe { &*(datas_ptr as *const pw::spa::buffer::Data) };
let data_ref: &pw::spa::buffer::Data =
unsafe { &*(datas_ptr as *const pw::spa::buffer::Data) };
let fd = data_ref.fd();
if fd < 0 {
tracing::trace!("process: invalid fd={fd}");
unsafe { stream.queue_raw_buffer(raw_buf) };
return;
}
// 获取 chunk 信息,包含帧数据在 DMA-BUF 中的偏移量和行跨度
if data_ref.as_raw().chunk.is_null() {
tracing::trace!("process: null chunk");
unsafe { stream.queue_raw_buffer(raw_buf) };
return;
}
let chunk = data_ref.chunk();
let offset = chunk.offset() as u64;
let stride = chunk.stride() as u32;
@@ -479,7 +677,8 @@ fn pipewire_thread(ctx: PwThreadCtx) {
for i in 0..n_metas {
let meta = &*metas.add(i as usize);
if meta.type_ == libspa::sys::SPA_META_Header
&& meta.size as usize >= std::mem::size_of::<libspa::sys::spa_meta_header>()
&& meta.size as usize
>= std::mem::size_of::<libspa::sys::spa_meta_header>()
&& !meta.data.is_null()
{
let header = &*(meta.data as *const libspa::sys::spa_meta_header);
@@ -497,6 +696,7 @@ fn pipewire_thread(ctx: PwThreadCtx) {
return;
};
if width == 0 || height == 0 || format == 0 {
tracing::trace!("process: invalid dimensions {width}x{height} format={format}");
unsafe { stream.queue_raw_buffer(raw_buf) };
return;
}
@@ -522,52 +722,33 @@ fn pipewire_thread(ctx: PwThreadCtx) {
pts,
};
// 尝试非阻塞发送帧到通道
// 如果通道已满(消费者处理不过来),丢弃该帧并增加丢弃计数
// 每 30 帧丢弃时输出一条警告日志,避免日志洪泛
if let Err(crossbeam_channel::TrySendError::Full(_)) =
frame_tx.try_send(PwEvent::Frame(frame))
{
let prev = dropped.fetch_add(1, Ordering::Relaxed);
if prev > 0 && prev % 30 == 0 {
tracing::warn!("dropped {prev} frames total: encoder backlog");
match frame_tx.try_send(frame) {
Ok(()) => {}
Err(crossbeam_channel::TrySendError::Full(_)) => {
dropped.fetch_add(1, Ordering::Relaxed);
}
Err(crossbeam_channel::TrySendError::Disconnected(_)) => {}
}
// 无论是否成功发送帧,都必须将 buffer 重新入队
// PipeWire 会复用这些 buffer不入队会导致 buffer 泄漏
unsafe { stream.queue_raw_buffer(raw_buf) };
}
})
.register();
// 空的参数数组,不主动请求特定格式(由 PipeWire 和源端协商决定)
let mut params: [&pw::spa::pod::Pod; 0] = [];
// 连接到指定的 PipeWire 节点
// Direction::Input: 作为消费者(输入方向接收数据)
// AUTOCONNECT: 允许 PipeWire 自动连接源和消费者
// MAP_BUFFERS: 映射 buffer 到用户空间DMA-BUF 模式下必须设置)
if let Err(e) = stream.connect(
pw::spa::utils::Direction::Input,
Some(node_id),
StreamFlags::AUTOCONNECT | StreamFlags::MAP_BUFFERS,
&mut params,
) {
let _ = frame_tx.send(PwEvent::Error(format!("stream.connect failed: {e}")));
if let Err(e) = event_tx.try_send(PwCtrlEvent::Error(format!("stream.connect failed: {e}"))) {
tracing::error!("stream.connect failed and error channel also failed: {e}");
}
return;
}
let loop_ = mainloop.loop_();
// 注册信号处理(空回调),阻止 SIGINT/SIGTERM 默认行为终止线程
// 真正的退出通过 shutdown eventfd 控制
loop_.add_signal_local(
pw::loop_::Signal::SIGINT,
Box::new(|| {}),
);
loop_.add_signal_local(
pw::loop_::Signal::SIGTERM,
Box::new(|| {}),
);
// Register the shutdown eventfd on the PipeWire loop.
//
@@ -607,10 +788,7 @@ fn pipewire_thread(ctx: PwThreadCtx) {
// run() returned — _shutdown_source drops first (reverse declaration order),
// which unregisters the callback from the loop. Then mainloop drops.
// No dangling raw pointers are possible.
// SAFETY: pipewire has been initialized with pw::init() above and all
// PipeWire resources (mainloop, stream) have been dropped.
unsafe { pw::deinit() };
// PipeWire global state is intentionally not deinitialized here — see pw::init() comment above.
}
/// 将四个 ASCII 字符编码为 32 位 FourCC (Four Character Code) 标识符
@@ -628,35 +806,66 @@ const fn fourcc(a: u8, b: u8, c: u8, d: u8) -> u32 {
/// 此函数建立了两者之间的映射关系。
///
/// 支持的格式:
/// - BGRA/BGRx: 蓝绿红(Alpha/X) 32位格式
/// - RGBA/RGBx: 红绿蓝(Alpha/X) 32位格式
/// - ARGB/xRGB: Alpha/X-红绿蓝 32位格式 (映射为 AR24/XR24)
/// - ABGR/xBGR: Alpha/X-蓝绿红 32位格式 (映射为 AB24/XB24)
///
/// 不支持的格式返回 0
/// DRM 格式名描述像素值位布局(大端序),而非内存字节序。
/// 例如 DRM_FORMAT_ARGB8888 在小端 x86 上内存为 [B,G,R,A] = PipeWire BGRA。
fn spa_to_drm_fourcc(format: libspa::param::video::VideoFormat) -> u32 {
use drm_fourcc::DrmFourcc;
use libspa::param::video::VideoFormat;
match format {
VideoFormat::BGRA => fourcc(b'B', b'G', b'R', b'A'),
VideoFormat::BGRx => fourcc(b'B', b'G', b'R', b'X'),
VideoFormat::RGBA => fourcc(b'R', b'G', b'B', b'A'),
VideoFormat::RGBx => fourcc(b'R', b'G', b'B', b'X'),
VideoFormat::ARGB => fourcc(b'A', b'R', b'2', b'4'),
VideoFormat::xRGB => fourcc(b'X', b'R', b'2', b'4'),
VideoFormat::ABGR => fourcc(b'A', b'B', b'2', b'4'),
VideoFormat::xBGR => fourcc(b'X', b'B', b'2', b'4'),
// 不支持的格式返回 0调用者应检查此值
_ => 0, }
VideoFormat::BGRA => DrmFourcc::Argb8888 as u32,
VideoFormat::BGRx => DrmFourcc::Xrgb8888 as u32,
VideoFormat::RGBA => DrmFourcc::Abgr8888 as u32,
VideoFormat::RGBx => DrmFourcc::Xbgr8888 as u32,
VideoFormat::ARGB => DrmFourcc::Bgra8888 as u32,
VideoFormat::xRGB => DrmFourcc::Bgrx8888 as u32,
VideoFormat::ABGR => DrmFourcc::Rgba8888 as u32,
VideoFormat::xBGR => DrmFourcc::Rgbx8888 as u32,
_ => 0,
}
}
#[cfg(test)]
mod tests {
use super::*;
use drm_fourcc::DrmFourcc;
use std::os::unix::fs::PermissionsExt;
#[test]
fn spa_to_drm_fourcc_bgra() {
fn spa_to_drm_fourcc_all_32bit() {
use libspa::param::video::VideoFormat;
assert_eq!(spa_to_drm_fourcc(VideoFormat::BGRA), fourcc(b'B', b'G', b'R', b'A'));
assert_eq!(
spa_to_drm_fourcc(VideoFormat::BGRA),
DrmFourcc::Argb8888 as u32
);
assert_eq!(
spa_to_drm_fourcc(VideoFormat::BGRx),
DrmFourcc::Xrgb8888 as u32
);
assert_eq!(
spa_to_drm_fourcc(VideoFormat::RGBA),
DrmFourcc::Abgr8888 as u32
);
assert_eq!(
spa_to_drm_fourcc(VideoFormat::RGBx),
DrmFourcc::Xbgr8888 as u32
);
assert_eq!(
spa_to_drm_fourcc(VideoFormat::ARGB),
DrmFourcc::Bgra8888 as u32
);
assert_eq!(
spa_to_drm_fourcc(VideoFormat::xRGB),
DrmFourcc::Bgrx8888 as u32
);
assert_eq!(
spa_to_drm_fourcc(VideoFormat::ABGR),
DrmFourcc::Rgba8888 as u32
);
assert_eq!(
spa_to_drm_fourcc(VideoFormat::xBGR),
DrmFourcc::Rgbx8888 as u32
);
}
#[test]
@@ -666,8 +875,158 @@ mod tests {
}
#[test]
fn fourcc_values() {
assert_eq!(fourcc(b'B', b'G', b'R', b'A'), 0x41524742);
assert_eq!(fourcc(b'R', b'G', b'B', b'A'), 0x41424752);
fn token_path_never_uses_tmp() {
assert!(token_path().is_some(), "token_path should resolve on Linux");
let path = token_path().unwrap();
assert!(!path.starts_with("/tmp"), "must not fallback to /tmp");
}
#[test]
fn verify_secure_dir_rejects_wrong_permissions() {
let dir = tempfile::tempdir().unwrap();
let path = dir.path();
// 0o700 should pass
std::fs::set_permissions(path, std::fs::Permissions::from_mode(0o700)).unwrap();
assert!(verify_secure_dir(path));
// 0o755 should fail
std::fs::set_permissions(path, std::fs::Permissions::from_mode(0o755)).unwrap();
assert!(!verify_secure_dir(path));
// 0o777 should fail
std::fs::set_permissions(path, std::fs::Permissions::from_mode(0o777)).unwrap();
assert!(!verify_secure_dir(path));
}
#[test]
fn verify_secure_dir_rejects_non_directory() {
let dir = tempfile::tempdir().unwrap();
let file_path = dir.path().join("not-a-dir");
std::fs::write(&file_path, b"test").unwrap();
assert!(!verify_secure_dir(&file_path));
}
#[test]
fn ensure_secure_parent_creates_with_0700() {
let base = tempfile::tempdir().unwrap();
let new_dir = base.path().join("wl-test-new-dir");
assert!(!new_dir.exists());
assert!(ensure_secure_parent(&new_dir));
assert!(new_dir.is_dir());
let meta = std::fs::symlink_metadata(&new_dir).unwrap();
let mode = meta.permissions().mode() & 0o777;
assert_eq!(mode, 0o700, "created directory should be 0700, got {mode:o}");
}
#[test]
fn ensure_secure_parent_tightens_existing_dir() {
let dir = tempfile::tempdir().unwrap();
let path = dir.path();
// Simulate an existing directory with loose permissions
std::fs::set_permissions(path, std::fs::Permissions::from_mode(0o755)).unwrap();
assert!(ensure_secure_parent(path));
let meta = std::fs::symlink_metadata(path).unwrap();
let mode = meta.permissions().mode() & 0o777;
assert_eq!(mode, 0o700, "tightened directory should be 0700, got {mode:o}");
}
#[test]
fn save_creates_file_with_0600() {
let dir = tempfile::tempdir().unwrap();
let token_path = dir.path().join("portal-restore-token");
save_restore_token_to("secret-token-123", &token_path);
assert!(token_path.exists());
let meta = std::fs::symlink_metadata(&token_path).unwrap();
let mode = meta.permissions().mode() & 0o777;
assert_eq!(mode, 0o600, "token file should be 0600, got {mode:o}");
assert_eq!(std::fs::read_to_string(&token_path).unwrap(), "secret-token-123");
}
#[test]
fn load_reads_secure_file() {
let dir = tempfile::tempdir().unwrap();
let token_path = dir.path().join("portal-restore-token");
// Write a valid 0o600 token file
use std::os::unix::fs::OpenOptionsExt;
let mut f = std::fs::OpenOptions::new()
.write(true)
.create_new(true)
.mode(0o600)
.open(&token_path)
.unwrap();
std::io::Write::write_all(&mut f, b"my-secret\n").unwrap();
let result = load_restore_token_from(token_path);
assert_eq!(result, Some("my-secret".to_string()));
}
#[test]
fn load_rejects_group_readable_file() {
let dir = tempfile::tempdir().unwrap();
let token_path = dir.path().join("portal-restore-token");
// Write with 0o640 (group readable) — should be rejected
use std::os::unix::fs::OpenOptionsExt;
let mut f = std::fs::OpenOptions::new()
.write(true)
.create_new(true)
.mode(0o640)
.open(&token_path)
.unwrap();
std::io::Write::write_all(&mut f, b"leaked-token\n").unwrap();
let result = load_restore_token_from(token_path);
assert!(result.is_none(), "should reject group-readable token file");
}
#[test]
fn load_rejects_world_readable_file() {
let dir = tempfile::tempdir().unwrap();
let token_path = dir.path().join("portal-restore-token");
use std::os::unix::fs::OpenOptionsExt;
let mut f = std::fs::OpenOptions::new()
.write(true)
.create_new(true)
.mode(0o604)
.open(&token_path)
.unwrap();
std::io::Write::write_all(&mut f, b"leaked-token\n").unwrap();
let result = load_restore_token_from(token_path);
assert!(result.is_none(), "should reject world-readable token file");
}
#[test]
fn load_rejects_symlink() {
let dir = tempfile::tempdir().unwrap();
let real_path = dir.path().join("real-file");
let link_path = dir.path().join("portal-restore-token");
std::fs::write(&real_path, b"target-content\n").unwrap();
std::os::unix::fs::symlink(&real_path, &link_path).unwrap();
let result = load_restore_token_from(link_path);
assert!(result.is_none(), "should reject symlinked token file");
}
#[test]
fn save_then_load_roundtrip() {
let dir = tempfile::tempdir().unwrap();
let token_path = dir.path().join("portal-restore-token");
save_restore_token_to("roundtrip-token", &token_path);
let loaded = load_restore_token_from(token_path);
assert_eq!(loaded, Some("roundtrip-token".to_string()));
}
}

View File

@@ -1,7 +1,8 @@
use std::time::{Duration, Instant};
pub struct FpsLimit<T> {
on_deck: Option<(T, Instant)>,
on_deck: Option<T>,
last_output_time: Option<Instant>,
min_interval: Duration,
}
@@ -9,30 +10,32 @@ impl<T> FpsLimit<T> {
pub fn new(fps: u32) -> Self {
Self {
on_deck: None,
last_output_time: None,
min_interval: Duration::from_secs_f64(1.0 / fps as f64),
}
}
/// Feed a new frame. Returns:
/// - Some(previous_frame) if enough time elapsed since previous frame
/// - None if frame is buffered (first frame) or previous is dropped (too close)
/// - Some(()) if enough time elapsed since the last output — proceed to encode current frame
/// - None if too close to the last output — drop current frame
pub fn on_new_frame(&mut self, frame: T, timestamp: Instant) -> Option<T> {
let old = self.on_deck.replace((frame, timestamp));
match old {
None => None, // First frame — buffer it
Some((old_frame, old_ts)) => {
if timestamp.duration_since(old_ts) >= self.min_interval {
Some(old_frame) // Enough time — output previous
let ready = match self.last_output_time {
None => true,
Some(last) => timestamp.duration_since(last) >= self.min_interval,
};
if ready {
self.last_output_time = Some(timestamp);
self.on_deck = Some(frame);
self.on_deck.take()
} else {
None // Too close — discard previous, keep new
}
}
let _ = self.on_deck.replace(frame);
None
}
}
/// Flush the last buffered frame at end of stream
pub fn flush(&mut self) -> Option<T> {
self.on_deck.take().map(|(frame, _ts)| frame)
self.on_deck.take()
}
}
@@ -41,15 +44,15 @@ mod tests {
use super::*;
#[test]
fn first_frame_is_buffered() {
fn first_frame_passes_immediately() {
let mut limiter: FpsLimit<u32> = FpsLimit::new(30);
let now = Instant::now();
let result = limiter.on_new_frame(1u32, now);
assert!(result.is_none());
assert_eq!(result, Some(1));
}
#[test]
fn frames_too_close_drops_old() {
fn frames_too_close_are_dropped() {
let mut limiter: FpsLimit<u32> = FpsLimit::new(30);
let now = Instant::now();
limiter.on_new_frame(1, now);
@@ -58,12 +61,29 @@ mod tests {
}
#[test]
fn frames_far_enough_output_old() {
fn frames_far_enough_pass() {
let mut limiter: FpsLimit<u32> = FpsLimit::new(30);
let now = Instant::now();
limiter.on_new_frame(1, now);
let result = limiter.on_new_frame(2, now + Duration::from_millis(40));
assert_eq!(result, Some(1));
let result = limiter.on_new_frame(2, now + Duration::from_millis(34));
assert_eq!(result, Some(2));
}
#[test]
fn high_fps_input_downsampled_correctly() {
let mut limiter: FpsLimit<u32> = FpsLimit::new(30);
let base = Instant::now();
let mut outputs = Vec::new();
for i in 0..10u32 {
let t = base + Duration::from_millis(i as u64 * 16);
if let Some(f) = limiter.on_new_frame(i, t) {
outputs.push(f);
}
}
assert!(outputs.len() >= 3, "expected at least 3 outputs, got {} ({:?})", outputs.len(), outputs);
assert_eq!(outputs[0], 0);
}
#[test]
@@ -71,7 +91,8 @@ mod tests {
let mut limiter: FpsLimit<u32> = FpsLimit::new(30);
let now = Instant::now();
limiter.on_new_frame(1, now);
assert_eq!(limiter.flush(), Some(1));
limiter.on_new_frame(2, now + Duration::from_millis(1));
assert_eq!(limiter.flush(), Some(2));
assert_eq!(limiter.flush(), None);
}
}

11
src/lib.rs Normal file
View File

@@ -0,0 +1,11 @@
pub mod args;
pub mod avhw;
pub mod backend_detect;
pub mod cap_portal;
pub mod cap_wlr_screencopy;
pub mod fps_limit;
pub mod stats;
pub mod state;
pub mod state_portal;
pub mod transform;
pub mod webrtc;

View File

@@ -15,9 +15,11 @@ mod backend_detect; // 截屏后端自动检测wlroots vs Portal/PipeWire
mod cap_portal; // XDG Portal 屏幕捕获
mod cap_wlr_screencopy; // wlroots wlr-screencopy 截屏协议
mod fps_limit; // 帧率限制器
mod stats; // 管道性能统计(卡顿诊断)
mod state; // wlr-screencopy 后端的主状态机
mod state_portal; // Portal/PipeWire 后端的主状态机
mod transform; // 图像变换(旋转/翻转)
mod webrtc; // WebRTC 传输str0m Sans-IO
use crate::args::Args;
use crate::cap_wlr_screencopy::CapWlrScreencopy;
@@ -42,35 +44,41 @@ fn main() -> Result<()> {
// 解析命令行参数
let args = Args::parse();
// 根据是否启用 verbose 模式设置日志级别
tracing_subscriber::fmt()
.with_max_level(if args.verbose {
tracing::Level::DEBUG
// 根据 verbose 模式或 RUST_LOG 环境变量设置日志级别
// 支持 RUST_LOG 粒度控制(如 RUST_LOG=wl_webrtc::webrtc=trace
let env_filter = tracing_subscriber::EnvFilter::try_from_default_env()
.unwrap_or_else(|_| {
if args.verbose {
tracing_subscriber::EnvFilter::new("debug")
} else {
tracing::Level::INFO
})
tracing_subscriber::EnvFilter::new("info")
}
});
tracing_subscriber::fmt()
.with_env_filter(env_filter)
.with_writer(std::io::stderr)
.init();
tracing::info!("wl-webrtc starting");
tracing::debug!("Args: {:?}", args);
tracing::debug!("Args: output={:?} fps={} codec={} port={} verbose={}", args.output, args.fps, args.codec, args.port, args.verbose);
// MVP 阶段仅支持 H.264 编码,不支持 HEVC
if args.codec != "h264" {
anyhow::bail!("HEVC not supported in MVP. Use --codec h264");
}
if args.output.is_none() && args.port == 0 {
anyhow::bail!("Either --output or --port is required");
}
// 自动检测当前桌面环境可用的截屏后端
// 会尝试列举 Wayland 全局对象,判断合成器是否支持 wlr-screencopy 协议
let backend = crate::backend_detect::detect_backend(&args)?;
// 根据检测结果进入对应的事件循环
match backend {
crate::backend_detect::CaptureBackend::WlrScreencopy => {
run_wlr_screencopy(args)
}
crate::backend_detect::CaptureBackend::PortalPipeWire => {
run_portal_pipewire(args)
}
crate::backend_detect::CaptureBackend::WlrScreencopy => run_wlr_screencopy(args),
crate::backend_detect::CaptureBackend::PortalPipeWire => run_portal_pipewire(args),
}
}
@@ -98,7 +106,7 @@ fn run_wlr_screencopy(args: Args) -> Result<()> {
let qhandle = queue.handle();
// State 是 wlr-screencopy 后端的核心状态机,
// 内部管理输出探测、截屏请求、编码器构建、帧采集等阶段
let mut state = State::new(gm, args, qhandle);
let mut state = State::new(gm, args, qhandle)?;
// Extract the Wayland fd and consume any immediately-available events.
// prepare_read() flushes outgoing requests; read() pulls whatever the
@@ -244,9 +252,11 @@ fn run_wlr_screencopy(args: Args) -> Result<()> {
// - Streaming: 正常采集中,请求下一帧
state.queue_alloc_frame();
state.poll_webrtc()?;
// 状态机遇到致命错误时退出
if state.errored {
tracing::error!("Fatal error in state machine, exiting");
tracing::error!("Fatal error in state machine (check preceding error logs), exiting");
running = false;
}
@@ -305,21 +315,17 @@ fn run_portal_pipewire(args: Args) -> Result<()> {
// 只注册信号 fd没有 Wayland fd
// 所以 poll.poll 在这里只负责检测 SIGINT/SIGTERM
// 实际的帧采集完全依赖 poll_and_encode 的轮询
poll.registry().register(
&mut signals,
mio::Token(1),
mio::Interest::READABLE,
)?;
poll.registry()
.register(&mut signals, mio::Token(1), mio::Interest::READABLE)?;
// 主事件循环(超时 10ms比 wlr-screencopy 更短,因为不依赖 Wayland fd 唤醒
// 10ms 超时的作用是让循环高频转动,以便及时处理 PipeWire 投递的帧
// 如果没有信号poll 最多阻塞 10ms 就会超时返回
// 主事件循环(非阻塞信号检测 + recv_timeout 等待帧
// poll 超时为 0ms非阻塞实际等待由 poll_and_encode 的 recv_timeout 实现
let mut running = true;
while running {
// poll 在此循环中只监听信号 fd,所以
// poll 在此循环中只监听信号 fd(非阻塞)
// - 收到 SIGINT/SIGTERM → 事件触发,设置 running=false
// - 超时 10ms → 事件为空,继续执行 poll_and_encode
poll.poll(&mut events, Some(std::time::Duration::from_millis(10)))
// - 无事件 → 立即返回,继续执行 poll_and_encode(内部 recv_timeout 等待帧)
poll.poll(&mut events, Some(std::time::Duration::from_millis(0)))
.unwrap_or_else(|e| {
if e.kind() == std::io::ErrorKind::Interrupted {
return;
@@ -328,7 +334,6 @@ fn run_portal_pipewire(args: Args) -> Result<()> {
running = false;
});
// 遍历事件,检查是否收到退出信号
for event in &events {
if event.token() == mio::Token(1) {
tracing::info!("Received quit signal");
@@ -341,11 +346,13 @@ fn run_portal_pipewire(args: Args) -> Result<()> {
// poll_and_encode 会从 PipeWire 缓冲区取出帧,
// 编码为 H.264 并推送。返回 true 表示还有更多帧待处理,
// 返回 false 表示当前没有帧了while 循环退出等待下一轮 poll
while state.poll_and_encode()? {}
if state.poll_and_encode(true)? {
while state.poll_and_encode(false)? {}
}
// Portal 状态机遇到致命错误时退出
if state.is_errored() {
tracing::error!("Fatal error in portal state machine, exiting");
tracing::error!("Fatal error in portal state machine (check preceding error logs), exiting");
running = false;
}
}

View File

@@ -3,6 +3,8 @@ use std::mem;
use std::os::fd::{AsFd, OwnedFd};
use std::os::unix::io::FromRawFd;
use std::path::{Path, PathBuf};
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use std::time::Instant;
use anyhow::Result;
@@ -41,10 +43,12 @@ use ffmpeg_next as ff;
use ffmpeg_next::ffi;
use crate::args::Args;
use crate::avhw::{AvHwDevCtx, EncState};
use crate::avhw::{AvHwDevCtx, EncState, SwEncState};
use crate::cap_wlr_screencopy::CapWlrScreencopy;
use crate::fps_limit::FpsLimit;
use crate::stats::{FrameTimings, PipelineStats};
use crate::transform::{transpose_if_transform_transposed, Transform};
use crate::webrtc::WebRtcState;
// ---------------------------------------------------------------------------
// CaptureSource trait
@@ -113,6 +117,42 @@ struct WlrHeadInfo {
/// User data for XdgOutput dispatch to identify which WlOutput it belongs to.
pub struct OutputId(pub u32);
// ---------------------------------------------------------------------------
// StreamingEncoder
// ---------------------------------------------------------------------------
/// Wraps the two possible encoder backends for the streaming stage.
///
/// - `Mp4(EncState)` — hardware VAAPI encoder writing to an MP4 file
/// - `WebRtc(SwEncState)` — software encoder feeding H.264 NALUs into a WebRTC channel
pub enum StreamingEncoder {
Mp4(EncState),
WebRtc(SwEncState),
}
impl StreamingEncoder {
fn frames_rgb(&self) -> &crate::avhw::AvHwFrameCtx {
match self {
StreamingEncoder::Mp4(enc) => enc.frames_rgb(),
StreamingEncoder::WebRtc(enc) => enc.frames_rgb(),
}
}
fn encode_frame(&mut self, hw_frame: &ffmpeg_next::frame::Video) -> anyhow::Result<()> {
match self {
StreamingEncoder::Mp4(enc) => enc.encode_frame(hw_frame),
StreamingEncoder::WebRtc(enc) => enc.encode_frame(hw_frame),
}
}
pub fn flush(&mut self) -> anyhow::Result<()> {
match self {
StreamingEncoder::Mp4(enc) => enc.flush(),
StreamingEncoder::WebRtc(enc) => enc.flush(),
}
}
}
// ---------------------------------------------------------------------------
// EncConstructionStage
// ---------------------------------------------------------------------------
@@ -142,7 +182,7 @@ pub enum EncConstructionStage<S: CaptureSource> {
Streaming {
output_info: OutputInfo,
output: WlOutput,
enc: EncState,
enc: StreamingEncoder,
cap: S,
screencopy_manager: ZwlrScreencopyManagerV1,
dmabuf: ZwpLinuxDmabufV1,
@@ -174,6 +214,9 @@ pub struct State<S: CaptureSource> {
pub stage: EncConstructionStage<S>,
pub in_flight_surface: InFlightSurface<S>,
pub starting_timestamp: Option<i64>,
pub stats_start_time: Option<Instant>,
pub stats_last_time: Option<Instant>,
pub stats_frames: u64,
pub first_frame: bool,
pub args: Args,
pub errored: bool,
@@ -182,29 +225,41 @@ pub struct State<S: CaptureSource> {
pub qhandle: QueueHandle<State<S>>,
pub drm_device: Option<PathBuf>,
pub drm_device_from_compositor: Option<PathBuf>,
pub webrtc: Option<WebRtcState>,
pub webrtc_tx: Option<crossbeam_channel::Sender<Vec<u8>>>,
webrtc_rx: Option<crossbeam_channel::Receiver<Vec<u8>>>,
webrtc_frames_sent: u64,
webrtc_paused: Option<Arc<AtomicBool>>,
stats: PipelineStats,
}
// ---------------------------------------------------------------------------
// Helpers
// ---------------------------------------------------------------------------
/// Scan /dev/dri for all available DRM render nodes (renderD*), sorted by node number.
pub(crate) fn find_drm_render_nodes() -> Vec<PathBuf> {
let Ok(entries) = std::fs::read_dir("/dev/dri") else {
return Vec::new();
};
let mut nodes: Vec<(u32, PathBuf)> = entries
.filter_map(Result::ok)
.filter_map(|entry| {
let path = entry.path();
let name = path.file_name()?.to_str()?;
let number = name.strip_prefix("renderD")?.parse::<u32>().ok()?;
std::fs::metadata(&path).ok()?;
Some((number, path))
})
.collect();
nodes.sort_by_key(|(number, _)| *number);
nodes.into_iter().map(|(_, path)| path).collect()
}
/// Scan /dev/dri for the first available DRM render node (renderD*).
fn find_drm_render_node() -> Option<PathBuf> {
std::fs::read_dir("/dev/dri")
.ok()?
.filter_map(|e| e.ok())
.filter(|e| {
e.file_name()
.to_str()
.map(|s| s.starts_with("renderD"))
.unwrap_or(false)
})
.filter_map(|e| {
let path = e.path();
std::fs::metadata(&path).ok()?;
Some(path)
})
.min_by_key(|e| e.to_path_buf())
find_drm_render_nodes().into_iter().next()
}
impl<S: CaptureSource> State<S> {
@@ -222,9 +277,20 @@ impl<S: CaptureSource> State<S> {
// ---------------------------------------------------------------------------
impl<S: CaptureSource> State<S> {
pub fn new(gm: GlobalList, args: Args, qhandle: QueueHandle<State<S>>) -> Self {
pub fn new(gm: GlobalList, args: Args, qhandle: QueueHandle<State<S>>) -> Result<Self> {
let fps = args.fps;
let drm_device = args.drm_device.as_ref().map(PathBuf::from);
let (webrtc, webrtc_tx, webrtc_rx, webrtc_paused) = if args.port > 0 {
let (tx, rx) = crossbeam_channel::bounded(32);
let wrtc = WebRtcState::new(args.port, args.fps)?;
// paused=true until first WebRTC client connects
let paused = Arc::new(AtomicBool::new(true));
(Some(wrtc), Some(tx), Some(rx), Some(paused))
} else {
(None, None, None, None)
};
let mut state = Self {
stage: EncConstructionStage::ProbingOutputs {
outputs: Vec::new(),
@@ -241,6 +307,9 @@ impl<S: CaptureSource> State<S> {
},
in_flight_surface: InFlightSurface::None,
starting_timestamp: None,
stats_start_time: None,
stats_last_time: None,
stats_frames: 0,
first_frame: true,
fps_limit: FpsLimit::new(fps),
args,
@@ -249,6 +318,12 @@ impl<S: CaptureSource> State<S> {
qhandle,
drm_device,
drm_device_from_compositor: None,
webrtc,
webrtc_tx,
webrtc_rx,
webrtc_frames_sent: 0,
webrtc_paused,
stats: PipelineStats::new(),
};
// registry_queue_init consumes registry events internally during its
@@ -256,7 +331,7 @@ impl<S: CaptureSource> State<S> {
// We must manually bind the initial globals here.
state.bind_initial_globals();
state
Ok(state)
}
/// Iterate over the GlobalList from registry_queue_init and bind all
@@ -426,7 +501,7 @@ impl<S: CaptureSource> State<S> {
// is a freshly allocated empty Video frame.
let ret = unsafe { ffi::av_hwframe_get_buffer(frames_rgb_ctx, surface.as_mut_ptr(), 0) };
if ret < 0 {
tracing::error!("av_hwframe_get_buffer failed: error {}", ret);
tracing::error!("av_hwframe_get_buffer failed: {}", crate::avhw::ff_err(ret));
self.errored = true;
return;
}
@@ -439,7 +514,7 @@ impl<S: CaptureSource> State<S> {
}
let ret = unsafe { ffi::av_hwframe_map(map_frame.as_mut_ptr(), surface.as_ptr(), 0) };
if ret < 0 {
tracing::error!("av_hwframe_map failed: error {}", ret);
tracing::error!("av_hwframe_map failed: {}", crate::avhw::ff_err(ret));
self.errored = true;
return;
}
@@ -464,7 +539,7 @@ impl<S: CaptureSource> State<S> {
// takes ownership of the fd, and the original fd is owned by map_frame.
let fd_dup = unsafe { libc::dup(obj.fd) };
if fd_dup < 0 {
tracing::error!("failed to dup dma-buf fd");
tracing::error!("failed to dup dma-buf fd: {}", std::io::Error::last_os_error());
// wayland-client does not auto-destroy params on Drop.
params.destroy();
self.errored = true;
@@ -508,6 +583,8 @@ impl<S: CaptureSource> State<S> {
where
S::Frame: Default,
{
self.stats.record_capture();
let (mut surface, _drm_map, frame, buffer) =
match mem::replace(&mut self.in_flight_surface, InFlightSurface::None) {
InFlightSurface::CopyQueued {
@@ -548,10 +625,29 @@ impl<S: CaptureSource> State<S> {
.is_some()
};
if should_encode {
let encode_start = Instant::now();
if let Err(e) = enc.encode_frame(&surface) {
tracing::error!("encode_frame failed: {}", e);
self.errored = true;
}
let encode_elapsed = encode_start.elapsed().as_micros() as u64;
self.stats.record_encode(&FrameTimings {
total_us: encode_elapsed,
..Default::default()
});
}
self.stats_frames += 1;
if let Some(last) = self.stats_last_time {
if last.elapsed() >= std::time::Duration::from_secs(10) {
let delta = self.stats_frames;
let fps = delta as f64 / last.elapsed().as_secs_f64();
tracing::info!(frames = self.stats_frames, fps = format!("{fps:.1}"), "encoding stats");
self.stats_last_time = Some(std::time::Instant::now());
self.stats_frames = 0;
}
} else {
self.stats_start_time = Some(std::time::Instant::now());
self.stats_last_time = Some(std::time::Instant::now());
}
}
@@ -562,11 +658,7 @@ impl<S: CaptureSource> State<S> {
tracing::error!("compositor copy failed");
let taken = mem::replace(&mut self.in_flight_surface, InFlightSurface::None);
match taken {
InFlightSurface::CopyQueued {
buffer,
frame,
..
} => {
InFlightSurface::CopyQueued { buffer, frame, .. } => {
drop(buffer);
if let EncConstructionStage::Streaming { cap, .. } = &mut self.stage {
cap.on_done_with_frame(frame);
@@ -579,44 +671,129 @@ impl<S: CaptureSource> State<S> {
self.errored = true;
}
pub fn poll_webrtc(&mut self) -> Result<()> {
let Some(ref mut wrtc) = self.webrtc else { return Ok(()) };
wrtc.handle_signaling()?;
wrtc.poll_and_feed()?;
let connected = wrtc.is_connected();
if let Some(ref paused) = self.webrtc_paused {
let was_paused = paused.load(Ordering::Relaxed);
let now_paused = !connected;
if was_paused && !now_paused {
tracing::info!("WebRTC client connected, resuming encoding");
} else if !was_paused && now_paused {
tracing::warn!("WebRTC client disconnected, pausing encoding");
}
paused.store(now_paused, Ordering::Relaxed);
}
if let Some(ref rx) = self.webrtc_rx {
let mut count = 0u32;
while let Ok(data) = rx.try_recv() {
if !connected {
continue;
}
count += 1;
if let Err(e) = wrtc.write_h264_frame(&data, self.webrtc_frames_sent, self.args.fps) {
tracing::debug!("WebRTC write frame error: {e}");
}
self.stats.record_send(0.0, None);
self.webrtc_frames_sent = self.webrtc_frames_sent.saturating_add(1);
}
if count > 0 {
tracing::debug!("WebRTC forwarded {count} frames from channel");
}
}
if self.args.stats && self.stats.should_snapshot() {
self.stats.set_queue_depths(
0,
self.webrtc_rx.as_ref().map(|r| r.len()).unwrap_or(0),
);
let snap = self.stats.snapshot_and_reset();
tracing::info!("stats: {snap}");
}
Ok(())
}
pub fn negotiate_format(&mut self, format: u32, width: u32, height: u32) {
let stage_data = match mem::replace(&mut self.stage, EncConstructionStage::Intermediate) {
EncConstructionStage::EverythingButFmt {
output_info,
output,
hw_device_ctx: _hw_device_ctx,
hw_device_ctx,
cap,
screencopy_manager,
dmabuf,
} => (output_info, output, cap, screencopy_manager, dmabuf),
} => (
output_info,
output,
hw_device_ctx,
cap,
screencopy_manager,
dmabuf,
),
other => {
tracing::warn!("negotiate_format: not in EverythingButFmt stage");
self.stage = other;
return;
}
};
let (output_info, output, cap, screencopy_manager, dmabuf) = stage_data;
let (output_info, output, hw_device_ctx, cap, screencopy_manager, dmabuf) = stage_data;
let drm_path = self.resolve_drm_path();
let fps = self.args.fps;
let bitrate = self.args.bitrate.unwrap_or_else(|| {
2 * (width as u64) * (height as u64) * (fps as u64) / 100
});
let enc = match crate::avhw::create_encoder(
let bitrate = self
.args
.bitrate
.unwrap_or_else(|| 2 * (width as u64) * (height as u64) * (fps as u64) / 100);
let enc = if let Some(ref tx) = self.webrtc_tx {
let (enc_w, enc_h) =
transpose_if_transform_transposed(output_info.transform, width as i32, height as i32);
let actual_gop_size = self.args.gop_size.unwrap_or((fps / 2).max(10));
match SwEncState::new_webrtc(
&drm_path,
Path::new(&self.args.output),
width,
height,
enc_w as u32,
enc_h as u32,
fps,
bitrate,
actual_gop_size,
tx.clone(),
self.webrtc_paused.as_ref().expect("webrtc_paused must exist when webrtc_tx exists").clone(),
) {
Ok(enc) => StreamingEncoder::WebRtc(enc),
Err(e) => {
tracing::error!("SwEncState::new_webrtc failed: {}", e);
self.errored = true;
return;
}
}
} else {
let output_path = self.args.output.as_deref().expect("output required for MP4 mode");
match crate::avhw::create_encoder(
&drm_path,
Path::new(output_path),
width,
height,
fps,
output_info.transform,
self.args.bitrate,
self.args.gop_size,
Some(hw_device_ctx),
) {
Ok(enc) => enc,
Ok(enc) => StreamingEncoder::Mp4(enc),
Err(e) => {
tracing::error!("EncState::new failed: {}", e);
self.errored = true;
return;
}
}
};
tracing::info!(
"Encoder initialized: {}x{} format={} bitrate={}",
@@ -858,7 +1035,6 @@ impl<S: CaptureSource> Dispatch<WlRegistry, GlobalListContents> for State<S> {
qhandle: &QueueHandle<State<S>>,
) {
use wayland_client::protocol::wl_registry::Event as RegistryEvent;
tracing::debug!("Dispatch<WlRegistry>::event fired: {:?}", event);
match event {
RegistryEvent::Global {
@@ -1192,11 +1368,7 @@ impl<S: CaptureSource> Dispatch<ZwpLinuxBufferParamsV1, ()> for State<S> {
tracing::error!("DMA-BUF buffer creation failed");
let taken = mem::replace(&mut state.in_flight_surface, InFlightSurface::None);
match taken {
InFlightSurface::CopyQueued {
buffer,
frame,
..
} => {
InFlightSurface::CopyQueued { buffer, frame, .. } => {
drop(buffer);
if let EncConstructionStage::Streaming { cap, .. } = &mut state.stage {
cap.on_done_with_frame(frame);
@@ -1228,11 +1400,11 @@ impl Dispatch<ZwlrScreencopyFrameV1, ()> for State<CapWlrScreencopy> {
_qhandle: &QueueHandle<State<CapWlrScreencopy>>,
) {
match event {
// SHM buffer offer — in v3 the compositor enumerates supported buffer
// types (buffer and/or linux_dmabuf) before buffer_done. We only
// support DMA-BUF, so just log and wait for linux_dmabuf / buffer_done.
ScreencopyFrameEvent::Buffer { .. } => {
tracing::warn!(
"Received SHM Buffer event — only DMA-BUF capture is supported. Ignoring."
);
return;
tracing::debug!("Received SHM Buffer offer — only DMA-BUF capture is supported");
}
ScreencopyFrameEvent::LinuxDmabuf {
format,
@@ -1240,6 +1412,12 @@ impl Dispatch<ZwlrScreencopyFrameV1, ()> for State<CapWlrScreencopy> {
height,
} => {
tracing::debug!("Screencopy LinuxDmabuf: format={format}, {width}x{height}");
if !matches!(state.in_flight_surface, InFlightSurface::AllocQueued) {
tracing::warn!("Received LinuxDmabuf while no frame allocation was queued");
return;
}
if matches!(state.stage, EncConstructionStage::EverythingButFmt { .. }) {
state.negotiate_format(format, width, height);
if state.errored {
@@ -1251,6 +1429,20 @@ impl Dispatch<ZwlrScreencopyFrameV1, ()> for State<CapWlrScreencopy> {
}
state.on_frame_allocd((), format, width, height);
}
// v3 terminal event: all buffer offers have been enumerated.
// If still AllocQueued, the compositor never sent linux_dmabuf —
// DMA-BUF screencopy is unsupported, so we must error out.
ScreencopyFrameEvent::BufferDone => {
if matches!(state.in_flight_surface, InFlightSurface::AllocQueued) {
tracing::error!(
"Compositor did not offer DMA-BUF screencopy (only SHM); \
DMA-BUF capture is required"
);
state.in_flight_surface = InFlightSurface::None;
proxy.destroy();
state.errored = true;
}
}
ScreencopyFrameEvent::Ready {
tv_sec_hi,
tv_sec_lo,

File diff suppressed because it is too large Load Diff

500
src/stats.rs Normal file
View File

@@ -0,0 +1,500 @@
// stats.rs — Lightweight windowed pipeline statistics for stutter diagnosis
//
// Tracks per-second snapshots of capture/encode/send pipeline metrics.
// Designed for low overhead: only counters and timing samples are collected,
// with one structured log line emitted per second when `--stats` is enabled.
use std::time::Instant;
/// Per-stage timing for a single encode pipeline frame.
///
/// All values are in microseconds. The caller records timestamps around
/// each stage and passes the deltas to [`PipelineStats::record_frame`].
#[derive(Debug, Default)]
pub struct FrameTimings {
/// DMA-BUF import (av_hwframe_map)
pub import_us: u64,
/// GPU scale (scale_vaapi filter)
pub scale_us: u64,
/// GPU→CPU transfer (av_hwframe_transfer_data)
pub transfer_us: u64,
/// sws_scale NV12→YUV420P
pub sws_us: u64,
/// H.264 encode (avcodec_send_frame + receive_packet)
pub encode_us: u64,
/// Wall-clock total for this frame (import through encode output)
pub total_us: u64,
/// Encoded output size in bytes
pub output_bytes: usize,
}
/// Windowed statistics aggregator for the encode/send pipeline.
///
/// Collects counters and timing samples within a one-second window,
/// then computes avg/p95/max when the snapshot is taken.
pub struct PipelineStats {
// --- counters (reset each window) ---
capture_frames: u64,
encoded_frames: u64,
sent_frames: u64,
pipewire_dropped: u64,
over_budget_count: u64,
// --- queue depth at last observation ---
capture_queue_depth: usize,
encoded_queue_depth: usize,
// --- timing samples ---
capture_gaps_ms: Vec<f64>,
encoded_gaps_ms: Vec<f64>,
sent_gaps_ms: Vec<f64>,
frame_age_ms: Vec<f64>,
send_wait_ms: Vec<f64>,
// --- per-stage timing (microseconds) ---
import_us: Vec<u64>,
scale_us: Vec<u64>,
transfer_us: Vec<u64>,
sws_us: Vec<u64>,
encode_us: Vec<u64>,
total_us: Vec<u64>,
output_bytes: Vec<usize>,
// --- timing state ---
last_capture_time: Option<Instant>,
last_encode_time: Option<Instant>,
last_send_time: Option<Instant>,
window_start: Instant,
}
impl PipelineStats {
pub fn new() -> Self {
Self {
capture_frames: 0,
encoded_frames: 0,
sent_frames: 0,
pipewire_dropped: 0,
over_budget_count: 0,
capture_queue_depth: 0,
encoded_queue_depth: 0,
capture_gaps_ms: Vec::new(),
encoded_gaps_ms: Vec::new(),
sent_gaps_ms: Vec::new(),
frame_age_ms: Vec::new(),
send_wait_ms: Vec::new(),
import_us: Vec::new(),
scale_us: Vec::new(),
transfer_us: Vec::new(),
sws_us: Vec::new(),
encode_us: Vec::new(),
total_us: Vec::new(),
output_bytes: Vec::new(),
last_capture_time: None,
last_encode_time: None,
last_send_time: None,
window_start: Instant::now(),
}
}
/// Record that a capture frame was received from PipeWire.
pub fn record_capture(&mut self) {
let now = Instant::now();
if let Some(last) = self.last_capture_time {
let gap_ms = last.elapsed().as_secs_f64() * 1000.0;
self.capture_gaps_ms.push(gap_ms);
}
self.last_capture_time = Some(now);
self.capture_frames += 1;
}
/// Record that a frame completed encoding with the given timings.
pub fn record_encode(&mut self, timings: &FrameTimings) {
let now = Instant::now();
if let Some(last) = self.last_encode_time {
let gap_ms = last.elapsed().as_secs_f64() * 1000.0;
self.encoded_gaps_ms.push(gap_ms);
}
self.last_encode_time = Some(now);
self.encoded_frames += 1;
self.import_us.push(timings.import_us);
self.scale_us.push(timings.scale_us);
self.transfer_us.push(timings.transfer_us);
self.sws_us.push(timings.sws_us);
self.encode_us.push(timings.encode_us);
self.total_us.push(timings.total_us);
self.output_bytes.push(timings.output_bytes);
}
pub fn record_import(&mut self, import_us: u64) {
self.import_us.push(import_us);
}
pub fn record_encode_thread(&mut self, sws_us: u64, encode_us: u64, output_bytes: usize) {
let now = Instant::now();
if let Some(last) = self.last_encode_time {
let gap_ms = last.elapsed().as_secs_f64() * 1000.0;
self.encoded_gaps_ms.push(gap_ms);
}
self.last_encode_time = Some(now);
self.encoded_frames += 1;
self.sws_us.push(sws_us);
self.encode_us.push(encode_us);
self.total_us.push(sws_us.saturating_add(encode_us));
self.output_bytes.push(output_bytes);
}
/// Record that a frame was sent via WebRTC.
/// `wait_ms` is time spent blocked waiting to send into the channel.
/// `capture_time` is when the frame was originally captured (for frame age).
pub fn record_send(&mut self, wait_ms: f64, capture_time: Option<Instant>) {
let now = Instant::now();
if let Some(last) = self.last_send_time {
let gap_ms = last.elapsed().as_secs_f64() * 1000.0;
self.sent_gaps_ms.push(gap_ms);
}
self.last_send_time = Some(now);
self.sent_frames += 1;
if wait_ms > 0.0 {
self.send_wait_ms.push(wait_ms);
}
if let Some(ct) = capture_time {
let age_ms = ct.elapsed().as_secs_f64() * 1000.0;
self.frame_age_ms.push(age_ms);
}
}
/// Record a frame sent from a background WebRTC thread.
/// `gap_ms` is the pre-computed time since the previous send (0.0 = first frame).
/// Unlike `record_send`, this does not sample `Instant::now()`, so it remains
/// accurate even when batch-drained at stats snapshot time.
pub fn record_send_from_thread(&mut self, gap_ms: f64) {
if gap_ms > 0.0 {
self.sent_gaps_ms.push(gap_ms);
}
self.sent_frames += 1;
}
/// Update PipeWire dropped counter (absolute value from AtomicU64).
pub fn set_pipewire_dropped(&mut self, total_dropped: u64, prev_dropped: u64) {
self.pipewire_dropped = total_dropped.saturating_sub(prev_dropped);
}
/// Update queue depth observations.
pub fn set_queue_depths(&mut self, capture: usize, encoded: usize) {
self.capture_queue_depth = capture;
self.encoded_queue_depth = encoded;
}
/// Record that a frame exceeded its time budget.
pub fn record_over_budget(&mut self) {
self.over_budget_count += 1;
}
/// Returns true if at least 1 second has elapsed since the last snapshot
/// (or since creation). If true, call `snapshot_and_reset` to get the stats.
pub fn should_snapshot(&self) -> bool {
self.window_start.elapsed().as_secs() >= 1
}
/// Compute a snapshot of the current window and reset all counters.
pub fn snapshot_and_reset(&mut self) -> StatsSnapshot {
let elapsed = self.window_start.elapsed().as_secs_f64();
let snap = StatsSnapshot {
elapsed_secs: elapsed,
capture_fps: self.capture_frames as f64 / elapsed,
encoded_fps: self.encoded_frames as f64 / elapsed,
sent_fps: self.sent_frames as f64 / elapsed,
capture_frames: self.capture_frames,
encoded_frames: self.encoded_frames,
sent_frames: self.sent_frames,
pipewire_dropped: self.pipewire_dropped,
over_budget_count: self.over_budget_count,
capture_queue_depth: self.capture_queue_depth,
encoded_queue_depth: self.encoded_queue_depth,
capture_gap_avg_ms: avg_f64(&self.capture_gaps_ms),
capture_gap_p95_ms: p95_f64(&self.capture_gaps_ms),
capture_gap_max_ms: max_f64(&self.capture_gaps_ms),
encoded_gap_avg_ms: avg_f64(&self.encoded_gaps_ms),
encoded_gap_p95_ms: p95_f64(&self.encoded_gaps_ms),
encoded_gap_max_ms: max_f64(&self.encoded_gaps_ms),
sent_gap_avg_ms: avg_f64(&self.sent_gaps_ms),
sent_gap_p95_ms: p95_f64(&self.sent_gaps_ms),
sent_gap_max_ms: max_f64(&self.sent_gaps_ms),
frame_age_avg_ms: avg_f64(&self.frame_age_ms),
frame_age_p95_ms: p95_f64(&self.frame_age_ms),
frame_age_max_ms: max_f64(&self.frame_age_ms),
send_wait_p95_ms: p95_f64(&self.send_wait_ms),
import_avg_ms: avg_ms(&self.import_us),
import_p95_ms: p95_ms(&self.import_us),
scale_avg_ms: avg_ms(&self.scale_us),
scale_p95_ms: p95_ms(&self.scale_us),
transfer_avg_ms: avg_ms(&self.transfer_us),
transfer_p95_ms: p95_ms(&self.transfer_us),
sws_avg_ms: avg_ms(&self.sws_us),
sws_p95_ms: p95_ms(&self.sws_us),
encode_avg_ms: avg_ms(&self.encode_us),
encode_p95_ms: p95_ms(&self.encode_us),
total_avg_ms: avg_ms(&self.total_us),
total_p95_ms: p95_ms(&self.total_us),
output_bytes_per_sec: sum_usize(&self.output_bytes) as f64 / elapsed,
output_frame_bytes_p95: p95_usize(&self.output_bytes),
output_frame_bytes_max: max_usize(&self.output_bytes),
};
// Reset all counters and sample buffers
self.capture_frames = 0;
self.encoded_frames = 0;
self.sent_frames = 0;
self.pipewire_dropped = 0;
self.over_budget_count = 0;
self.capture_queue_depth = 0;
self.encoded_queue_depth = 0;
self.capture_gaps_ms.clear();
self.encoded_gaps_ms.clear();
self.sent_gaps_ms.clear();
self.frame_age_ms.clear();
self.send_wait_ms.clear();
self.import_us.clear();
self.scale_us.clear();
self.transfer_us.clear();
self.sws_us.clear();
self.encode_us.clear();
self.total_us.clear();
self.output_bytes.clear();
self.window_start = Instant::now();
snap
}
}
/// A one-second snapshot of pipeline statistics.
#[derive(Debug)]
pub struct StatsSnapshot {
pub elapsed_secs: f64,
// FPS
pub capture_fps: f64,
pub encoded_fps: f64,
pub sent_fps: f64,
// Counters
pub capture_frames: u64,
pub encoded_frames: u64,
pub sent_frames: u64,
pub pipewire_dropped: u64,
pub over_budget_count: u64,
// Queue depths
pub capture_queue_depth: usize,
pub encoded_queue_depth: usize,
// Gap timing (ms)
pub capture_gap_avg_ms: f64,
pub capture_gap_p95_ms: f64,
pub capture_gap_max_ms: f64,
pub encoded_gap_avg_ms: f64,
pub encoded_gap_p95_ms: f64,
pub encoded_gap_max_ms: f64,
pub sent_gap_avg_ms: f64,
pub sent_gap_p95_ms: f64,
pub sent_gap_max_ms: f64,
// Frame age (capture → send)
pub frame_age_avg_ms: f64,
pub frame_age_p95_ms: f64,
pub frame_age_max_ms: f64,
// Send wait
pub send_wait_p95_ms: f64,
// Per-stage encode timing (ms)
pub import_avg_ms: f64,
pub import_p95_ms: f64,
pub scale_avg_ms: f64,
pub scale_p95_ms: f64,
pub transfer_avg_ms: f64,
pub transfer_p95_ms: f64,
pub sws_avg_ms: f64,
pub sws_p95_ms: f64,
pub encode_avg_ms: f64,
pub encode_p95_ms: f64,
pub total_avg_ms: f64,
pub total_p95_ms: f64,
// Output size
pub output_bytes_per_sec: f64,
pub output_frame_bytes_p95: usize,
pub output_frame_bytes_max: usize,
}
impl std::fmt::Display for StatsSnapshot {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(
f,
"capture_fps={:.1} encoded_fps={:.1} sent_fps={:.1} \
pw_dropped={} over_budget={} \
cap_q={} enc_q={} \
cap_gap_p95={:.1}ms cap_gap_max={:.1}ms \
enc_gap_p95={:.1}ms enc_gap_max={:.1}ms \
sent_gap_p95={:.1}ms sent_gap_max={:.1}ms \
frame_age_p95={:.1}ms frame_age_max={:.1}ms \
send_wait_p95={:.1}ms \
import_p95={:.1}ms scale_p95={:.1}ms transfer_p95={:.1}ms \
sws_p95={:.1}ms encode_p95={:.1}ms total_p95={:.1}ms \
output_bps={:.0} frame_bytes_max={}",
self.capture_fps,
self.encoded_fps,
self.sent_fps,
self.pipewire_dropped,
self.over_budget_count,
self.capture_queue_depth,
self.encoded_queue_depth,
self.capture_gap_p95_ms,
self.capture_gap_max_ms,
self.encoded_gap_p95_ms,
self.encoded_gap_max_ms,
self.sent_gap_p95_ms,
self.sent_gap_max_ms,
self.frame_age_p95_ms,
self.frame_age_max_ms,
self.send_wait_p95_ms,
self.import_p95_ms,
self.scale_p95_ms,
self.transfer_p95_ms,
self.sws_p95_ms,
self.encode_p95_ms,
self.total_p95_ms,
self.output_bytes_per_sec,
self.output_frame_bytes_max,
)
}
}
// ---------------------------------------------------------------------------
// Statistics helpers
// ---------------------------------------------------------------------------
fn avg_f64(data: &[f64]) -> f64 {
if data.is_empty() {
return 0.0;
}
data.iter().sum::<f64>() / data.len() as f64
}
fn p95_f64(data: &[f64]) -> f64 {
if data.is_empty() {
return 0.0;
}
let mut sorted: Vec<f64> = data.to_vec();
sorted.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
let idx = ((sorted.len() as f64) * 0.95).floor() as usize;
sorted[idx.min(sorted.len() - 1)]
}
fn max_f64(data: &[f64]) -> f64 {
data.iter().copied().fold(0.0_f64, f64::max)
}
fn avg_ms(data: &[u64]) -> f64 {
if data.is_empty() {
return 0.0;
}
data.iter().sum::<u64>() as f64 / data.len() as f64 / 1000.0
}
fn p95_ms(data: &[u64]) -> f64 {
if data.is_empty() {
return 0.0;
}
let mut sorted = data.to_vec();
sorted.sort_unstable();
let idx = ((sorted.len() as f64) * 0.95).floor() as usize;
sorted[idx.min(sorted.len() - 1)] as f64 / 1000.0
}
fn sum_usize(data: &[usize]) -> usize {
data.iter().sum()
}
fn p95_usize(data: &[usize]) -> usize {
if data.is_empty() {
return 0;
}
let mut sorted = data.to_vec();
sorted.sort_unstable();
let idx = ((sorted.len() as f64) * 0.95).floor() as usize;
sorted[idx.min(sorted.len() - 1)]
}
fn max_usize(data: &[usize]) -> usize {
data.iter().copied().max().unwrap_or(0)
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn empty_stats_snapshot() {
let mut stats = PipelineStats::new();
let snap = stats.snapshot_and_reset();
assert_eq!(snap.capture_frames, 0);
assert_eq!(snap.encoded_frames, 0);
assert_eq!(snap.sent_frames, 0);
}
#[test]
fn record_and_snapshot_counts() {
let mut stats = PipelineStats::new();
stats.record_capture();
stats.record_capture();
stats.record_encode(&FrameTimings {
total_us: 5000,
output_bytes: 1000,
..Default::default()
});
stats.record_send(0.1, None);
let snap = stats.snapshot_and_reset();
assert_eq!(snap.capture_frames, 2);
assert_eq!(snap.encoded_frames, 1);
assert_eq!(snap.sent_frames, 1);
}
#[test]
fn p95_computation() {
// 100 values: 0.0 through 99.0
let data: Vec<f64> = (0..100).map(|i| i as f64).collect();
let result = p95_f64(&data);
assert!((result - 95.0).abs() < 1.0, "p95 of 0..100 should be ~95, got {result}");
}
#[test]
fn p95_ms_microseconds() {
let data: Vec<u64> = (0..100).map(|i| i * 1000).collect(); // 0ms..99ms
let result = p95_ms(&data);
assert!((result - 95.0).abs() < 1.0, "p95_ms should be ~95ms, got {result}");
}
#[test]
fn snapshot_resets_counters() {
let mut stats = PipelineStats::new();
stats.record_capture();
let _ = stats.snapshot_and_reset();
let snap = stats.snapshot_and_reset();
assert_eq!(snap.capture_frames, 0);
}
#[test]
fn display_format_contains_key_fields() {
let mut stats = PipelineStats::new();
stats.record_capture();
stats.record_encode(&FrameTimings {
total_us: 10000,
output_bytes: 5000,
..Default::default()
});
stats.record_send(0.5, None);
let snap = stats.snapshot_and_reset();
let text = format!("{snap}");
assert!(text.contains("capture_fps="));
assert!(text.contains("encoded_fps="));
assert!(text.contains("sent_fps="));
assert!(text.contains("total_p95="));
}
}

749
src/webrtc.rs Normal file
View File

@@ -0,0 +1,749 @@
// WebRTC 传输模块 — 使用 str0m (Sans-IO) 将 H.264 编码帧推送到浏览器
use std::io::{Read, Write};
use std::net::{SocketAddr, TcpListener, UdpSocket};
use std::time::Instant;
use anyhow::{bail, Result};
use str0m::change::SdpOffer;
use str0m::format::Codec;
use str0m::media::{Frequency, MediaKind, MediaTime, Mid, Pt};
use str0m::net::{Protocol, Receive};
use str0m::{Candidate, Event, IceConnectionState, Input, Output, Rtc, RtcConfig};
// ── 嵌入式 HTML 测试页面 ──────────────────────────────────────────────────
const HTML_PAGE: &str = r#"<!DOCTYPE html>
<html>
<head><title>wl-webrtc P0</title>
<style>body{background:#000;color:#fff;font-family:monospace;display:flex;flex-direction:column;align-items:center;justify-content:center;height:100vh;margin:0}
video{max-width:90vw;max-height:80vh;border:1px solid #333}
#status{margin:12px;font-size:14px;color:#aaa}
#debug{position:fixed;bottom:8px;left:8px;font-size:11px;color:#666;max-width:90vw;white-space:pre-wrap}
#stats-panel{position:fixed;top:8px;right:8px;background:rgba(0,0,0,0.7);color:#0f0;font:11px monospace;padding:6px 10px;border-radius:4px;z-index:100;pointer-events:none;max-width:90vw;white-space:pre;line-height:1.5}
</style></head>
<body>
<div id="status">Connecting...</div>
<video id="video" autoplay playsinline muted></video>
<pre id="debug"></pre>
<div id="stats-panel"></div>
<script>
const status = document.getElementById('status');
const video = document.getElementById('video');
const debug = document.getElementById('debug');
let pc = null;
const log = msg => { debug.textContent += msg + '\n'; console.log(msg); };
function preferH264(sdp) {
const lines = sdp.split('\r\n');
const h264Pts = lines
.filter(line => line.startsWith('a=rtpmap:') && line.toUpperCase().includes('H264/90000'))
.map(line => line.match(/^a=rtpmap:(\d+)/)?.[1])
.filter(Boolean);
if (h264Pts.length === 0) return sdp;
return lines.map(line => {
if (!line.startsWith('m=video ')) return line;
const parts = line.split(' ');
const header = parts.slice(0, 3);
const pts = parts.slice(3);
const preferred = h264Pts.filter(pt => pts.includes(pt));
const rest = pts.filter(pt => !preferred.includes(pt));
return [...header, ...preferred, ...rest].join(' ');
}).join('\r\n');
}
function installStatsLogger(peer) {
const panel = document.getElementById('stats-panel');
let prev = null;
const intervalSecs = 1;
setInterval(() => {
if (peer !== pc) return;
peer.getStats().then(stats => {
let rtp = null, rtt = null, codecStr = '';
let freezeCount = null, totalFreezesDuration = null;
stats.forEach(report => {
if (report.type === 'inbound-rtp' && report.kind === 'video') rtp = report;
if (report.type === 'codec' && report.mimeType && report.mimeType.includes('H264'))
codecStr = report.mimeType + ' ' + (report.payloadType || '');
// candidate-pair: feature-detect 'selected' property
if (report.type === 'candidate-pair') {
const isSel = ('selected' in report) ? report.selected : report.state === 'succeeded';
if (isSel && typeof report.currentRoundTripTime === 'number') rtt = report.currentRoundTripTime;
}
});
// Freeze stats (feature-detect)
if (rtp && typeof rtp.freezeCount !== 'undefined') {
freezeCount = rtp.freezeCount;
totalFreezesDuration = rtp.totalFreezesDuration;
}
if (!rtp) return;
const cur = {
framesDecoded: rtp.framesDecoded || 0,
framesDropped: rtp.framesDropped || 0,
framesPerSecond: rtp.framesPerSecond || 0,
packetsLost: rtp.packetsLost || 0,
jitter: rtp.jitter || 0,
bytesReceived: rtp.bytesReceived || 0,
totalDecodeTime: rtp.totalDecodeTime || 0,
jitterBufferDelay: rtp.jitterBufferDelay || 0,
jitterBufferEmittedCount: rtp.jitterBufferEmittedCount || 0,
freezeCount: freezeCount,
totalFreezesDuration: totalFreezesDuration,
rtt: rtt,
};
// Raw log to debug element (backward compat)
log('RTP-in: decoded=' + cur.framesDecoded + ' lost=' + cur.packetsLost +
' bytes=' + cur.bytesReceived + ' fps=' + cur.framesPerSecond +
(codecStr ? ' codec=' + codecStr : ''));
if (!prev) { prev = cur; return; }
// Compute deltas
const dFrames = cur.framesDecoded - prev.framesDecoded;
const dDropped = cur.framesDropped - prev.framesDropped;
const dLost = cur.packetsLost - prev.packetsLost;
const dBytes = cur.bytesReceived - prev.bytesReceived;
const dDecodeTime = cur.totalDecodeTime - prev.totalDecodeTime;
const dJitterBufDelay = cur.jitterBufferDelay - prev.jitterBufferDelay;
const dJitterBufCount = cur.jitterBufferEmittedCount - prev.jitterBufferEmittedCount;
const kbps = Math.round(dBytes * 8 / intervalSecs / 1000);
const decodeMs = dFrames > 0 ? (dDecodeTime / dFrames * 1000).toFixed(1) : '—';
const jitterBufMs = dJitterBufCount > 0 ? (dJitterBufDelay / dJitterBufCount * 1000).toFixed(1) : '—';
const jitterMs = (cur.jitter * 1000).toFixed(1);
const rttMs = cur.rtt !== null ? (cur.rtt * 1000).toFixed(1) : null;
let line = 'FPS:' + cur.framesPerSecond +
' Decoded:' + cur.framesDecoded + '(+' + dFrames + ')' +
' Dropped:' + cur.framesDropped + (dDropped > 0 ? '(+' + dDropped + ')' : '') +
' Lost:' + dLost +
' Jitter:' + jitterMs + 'ms' +
(rttMs !== null ? ' RTT:' + rttMs + 'ms' : '') +
' Decode:' + decodeMs + 'ms' +
' JBuf:' + jitterBufMs + 'ms';
if (freezeCount !== null) {
const dFreeze = cur.freezeCount - (prev.freezeCount || 0);
if (cur.freezeCount > 0 || dFreeze > 0)
line += ' Freeze:' + cur.freezeCount + '(+' + dFreeze + ')';
}
line += ' ' + kbps + 'kbps';
panel.textContent = line;
prev = cur;
}).catch(() => {});
}, intervalSecs * 1000);
}
function connect() {
if (pc) pc.close();
pc = new RTCPeerConnection();
const peer = pc;
peer.ontrack = e => {
log('ontrack: streams=' + e.streams.length + ' kind=' + e.track.kind);
video.srcObject = e.streams[0];
status.textContent = 'Track received';
};
peer.oniceconnectionstatechange = () => {
log('ICE: ' + peer.iceConnectionState);
status.textContent = 'ICE: ' + peer.iceConnectionState;
};
peer.addTransceiver('video', { direction: 'recvonly' });
installStatsLogger(peer);
peer.createOffer().then(offer => {
offer.sdp = preferH264(offer.sdp);
return peer.setLocalDescription(offer);
})
.then(() => new Promise(resolve => {
if (peer.iceGatheringState === 'complete') resolve();
else peer.onicegatheringstatechange = () => { if (peer.iceGatheringState === 'complete') resolve(); };
}))
.then(() => fetch('/sdp', { method: 'POST', body: JSON.stringify(peer.localDescription) }))
.then(r => { if (!r.ok) throw new Error('SDP exchange failed: ' + r.status); return r.json(); })
.then(answer => { if (answer.error) throw new Error(answer.error); return peer.setRemoteDescription(answer); })
.then(() => log('SDP answer set'))
.catch(e => {
status.textContent = 'Error: ' + e.message;
log('ERROR: ' + e.message + ' — retrying in 2s...');
console.error(e);
setTimeout(connect, 2000);
});
}
connect();
</script>
</body></html>"#;
// ── WebRTC 状态 ───────────────────────────────────────────────────────────
pub struct WebRtcState {
signal_listener: TcpListener,
inner: Option<WebRtcInner>,
fps: u32,
}
struct WebRtcInner {
rtc: Rtc,
socket: UdpSocket,
udp_addr: SocketAddr,
video_mid: Option<Mid>,
video_pt: Option<Pt>,
connected: bool,
need_keyframe: bool,
rtp_clock: u32,
buf: Vec<u8>,
}
impl WebRtcState {
pub fn new(port: u16, fps: u32) -> Result<Self> {
let signal_listener = TcpListener::bind(format!("0.0.0.0:{port}"))?;
signal_listener.set_nonblocking(true)?;
tracing::info!("WebRTC signaling on http://0.0.0.0:{port}/");
Ok(Self {
signal_listener,
inner: None,
fps,
})
}
pub fn handle_signaling(&mut self) -> Result<bool> {
let mut handled = false;
loop {
let (mut stream, _addr) = match self.signal_listener.accept() {
Ok(s) => s,
Err(ref e) if e.kind() == std::io::ErrorKind::WouldBlock => break,
Err(e) => bail!("TCP accept error: {e}"),
};
handled = true;
stream.set_nonblocking(true)?;
let mut req = vec![0u8; 65536];
let n = match stream.read(&mut req) {
Ok(n) => n,
Err(ref e) if e.kind() == std::io::ErrorKind::WouldBlock => continue,
Err(e) => {
tracing::warn!("TCP read error: {e}");
continue;
}
};
let req_str = String::from_utf8_lossy(&req[..n]);
if req_str.starts_with("GET / ")
|| req_str.starts_with("GET /sdp ")
&& !req_str.contains("Content-Type: application/json")
{
let resp = format!(
"HTTP/1.1 200 OK\r\nContent-Type: text/html\r\nContent-Length: {}\r\nConnection: close\r\n\r\n{}",
HTML_PAGE.len(),
HTML_PAGE
);
if let Err(e) = stream.write_all(resp.as_bytes()) {
tracing::debug!("HTTP write error: {e}");
}
} else if req_str.starts_with("POST /sdp") {
let body = extract_body(&req_str);
if body.is_empty() {
let resp = "HTTP/1.1 400 Bad Request\r\nConnection: close\r\n\r\nempty body";
if let Err(e) = stream.write_all(resp.as_bytes()) {
tracing::debug!("HTTP write error: {e}");
}
continue;
}
match WebRtcInner::new(self.fps)
.and_then(|mut new_inner| {
let answer_json = new_inner.handle_sdp_offer(body.as_bytes())?;
Ok((new_inner, answer_json))
}) {
Ok((new_inner, answer_json)) => {
let replacing = self.inner.is_some();
self.inner = Some(new_inner);
if replacing {
tracing::info!("Replaced WebRTC connection (old dropped)");
} else {
tracing::info!("New WebRTC connection");
}
let resp = format!(
"HTTP/1.1 200 OK\r\nContent-Type: application/json\r\nContent-Length: {}\r\nConnection: close\r\n\r\n{}",
answer_json.len(),
answer_json
);
if let Err(e) = stream.write_all(resp.as_bytes()) {
tracing::debug!("HTTP write error: {e}");
}
}
Err(e) => {
tracing::error!("SDP offer handling failed: {e}");
let resp = "HTTP/1.1 500 Internal Server Error\r\nConnection: close\r\n\r\n";
if let Err(e) = stream.write_all(resp.as_bytes()) {
tracing::debug!("HTTP write error: {e}");
}
}
}
} else {
let resp = "HTTP/1.1 404 Not Found\r\nConnection: close\r\n\r\n";
if let Err(e) = stream.write_all(resp.as_bytes()) {
tracing::debug!("HTTP write error: {e}");
}
}
}
Ok(handled)
}
pub fn poll_rtc(&mut self) -> Result<()> {
if let Some(inner) = self.inner.as_mut() {
if inner.poll_rtc()? {
tracing::info!("WebRTC connection closed; clearing connection state");
self.inner = None;
}
}
Ok(())
}
pub fn feed_network(&mut self) -> Result<()> {
if let Some(inner) = self.inner.as_mut() {
inner.feed_network()?;
}
Ok(())
}
pub fn poll_and_feed(&mut self) -> Result<()> {
self.poll_rtc()?;
self.feed_network()?;
self.poll_rtc()
}
pub fn write_h264_frame(&mut self, data: &[u8], frame_number: u64, fps: u32) -> Result<()> {
let should_destroy = if let Some(inner) = self.inner.as_mut() {
inner.write_h264_frame(data, frame_number, fps)?
} else {
false
};
if should_destroy {
tracing::info!("WebRTC connection failed during write; clearing connection state");
self.inner = None;
}
Ok(())
}
pub fn is_connected(&self) -> bool {
self.inner.as_ref().is_some_and(WebRtcInner::is_connected)
}
}
impl WebRtcInner {
fn new(fps: u32) -> Result<Self> {
let _ = fps;
let mut rtc = RtcConfig::new().build(Instant::now());
let socket = UdpSocket::bind("0.0.0.0:0")?;
socket.set_nonblocking(true)?;
// Increase UDP send buffer to absorb IDR frame bursts (256KB IDR → ~145 RTP
// packets in a single poll_rtc loop). Default Linux wmem is ~208KB which
// causes EAGAIN on large keyframes. 2MB comfortably buffers several IDRs.
const SND_BUF_REQ: usize = 2 * 1024 * 1024;
// SAFETY: fd is a valid UDP socket; setsockopt/getsockopt with SOL_SOCKET +
// SO_SNDBUF are safe on Linux. We check the return value and log the actual
// kernel-assigned buffer (Linux may cap at wmem_max and/or double the value).
unsafe {
let fd = std::os::unix::io::AsRawFd::as_raw_fd(&socket);
let val: libc::c_int = SND_BUF_REQ as libc::c_int;
let ret = libc::setsockopt(
fd,
libc::SOL_SOCKET,
libc::SO_SNDBUF,
&val as *const libc::c_int as *const libc::c_void,
std::mem::size_of::<libc::c_int>() as libc::socklen_t,
);
if ret < 0 {
tracing::warn!("setsockopt SO_SNDBUF failed (errno {})", std::io::Error::last_os_error());
}
let mut actual: libc::c_int = 0;
let mut actual_len: libc::socklen_t = std::mem::size_of::<libc::c_int>() as libc::socklen_t;
let gret = libc::getsockopt(
fd,
libc::SOL_SOCKET,
libc::SO_SNDBUF,
&mut actual as *mut libc::c_int as *mut libc::c_void,
&mut actual_len,
);
if gret == 0 {
tracing::info!(
"UDP send buffer: requested {}KB, actual {}KB",
SND_BUF_REQ / 1024,
actual / 1024,
);
}
}
let local_addr = socket.local_addr()?;
let lan_ip = local_ip().unwrap_or_else(|| {
tracing::debug!("Failed to detect LAN IP, falling back to 127.0.0.1");
"127.0.0.1".to_string()
});
let candidate_addr: SocketAddr = format!("{lan_ip}:{}", local_addr.port()).parse()?;
let candidate = Candidate::host(candidate_addr, "udp")
.map_err(|e| anyhow::anyhow!("candidate: {e}"))?;
rtc.add_local_candidate(candidate);
tracing::info!("WebRTC UDP: {candidate_addr} (bound 0.0.0.0)");
Ok(Self {
rtc,
socket,
udp_addr: candidate_addr,
video_mid: None,
video_pt: None,
connected: false,
need_keyframe: false,
rtp_clock: 0,
buf: vec![0u8; 65535],
})
}
fn handle_sdp_offer(&mut self, body: &[u8]) -> Result<String> {
let offer: SdpOffer = serde_json::from_slice(body)
.map_err(|e| anyhow::anyhow!("parse SDP offer: {e}"))?;
let answer = self
.rtc
.sdp_api()
.accept_offer(offer)
.map_err(|e| anyhow::anyhow!("accept_offer: {e}"))?;
self.need_keyframe = true;
tracing::info!("SDP exchange complete, waiting for ICE/DTLS...");
self.discover_video_params();
let answer_json =
serde_json::to_vec(&answer).map_err(|e| anyhow::anyhow!("serialize answer: {e}"))?;
String::from_utf8(answer_json).map_err(|e| anyhow::anyhow!("answer utf8: {e}"))
}
fn discover_video_params(&mut self) {
let mid = match self.video_mid {
Some(m) => m,
None => {
tracing::debug!("discover_video_params: no video_mid yet");
return;
}
};
self.video_pt = None;
if let Some(writer) = self.rtc.writer(mid) {
for pp in writer.payload_params() {
tracing::debug!("Codec: pt={:?} spec={:?}", pp.pt(), pp.spec());
if pp.spec().codec.is_video() && pp.spec().codec == Codec::H264 {
self.video_pt = Some(pp.pt());
tracing::info!("H.264 payload type: {:?}", pp.pt());
break;
}
}
}
if self.video_pt.is_none() {
tracing::warn!("discover_video_params: no H.264 codec found for mid={mid}");
}
}
fn poll_rtc(&mut self) -> Result<bool> {
loop {
match self.rtc.poll_output() {
Ok(Output::Transmit(t)) => {
tracing::trace!("TX {} bytes -> {}", t.contents.len(), t.destination);
if let Err(e) = self.socket.send_to(&t.contents, t.destination) {
if e.kind() == std::io::ErrorKind::WouldBlock {
tracing::debug!(
"UDP send WouldBlock ({} bytes) — send buffer full",
t.contents.len(),
);
} else {
tracing::warn!("UDP send error to {}: {e}", t.destination);
}
}
}
Ok(Output::Event(e)) => {
tracing::debug!("RTC event: {e:?}");
match &e {
Event::Connected => {
tracing::info!("WebRTC connected!");
self.connected = true;
self.need_keyframe = true;
self.discover_video_params();
}
Event::IceConnectionStateChange(IceConnectionState::Disconnected) => {
tracing::warn!("WebRTC disconnected");
self.connected = false;
return Ok(true);
}
Event::MediaAdded(ma) => {
tracing::info!("Media added: mid={} kind={:?}", ma.mid, ma.kind);
if ma.kind == MediaKind::Video {
if let Some(media) = self.rtc.media(ma.mid) {
if media.direction().is_sending()
&& self.video_mid.is_none()
{
self.video_mid = Some(ma.mid);
tracing::info!("Captured video mid: {}", ma.mid);
self.discover_video_params();
}
}
}
}
_ => {
tracing::debug!("WebRTC event: {:?}", e);
}
}
}
Ok(Output::Timeout(_t)) => break,
Err(e) => {
tracing::error!("rtc.poll_output error: {e}");
self.connected = false;
return Ok(true);
}
}
}
Ok(false)
}
fn feed_network(&mut self) -> Result<()> {
let mut recv_count = 0u32;
loop {
match self.socket.recv_from(&mut self.buf) {
Ok((n, source)) => {
recv_count += 1;
if recv_count <= 5 {
tracing::trace!("UDP recv {} bytes from {}", n, source);
}
let input = Input::Receive(
Instant::now(),
Receive {
proto: Protocol::Udp,
source,
destination: self.udp_addr,
contents: self.buf[..n]
.try_into()
.map_err(|e| anyhow::anyhow!("receive contents: {e}"))?,
},
);
self.rtc
.handle_input(input)
.map_err(|e| anyhow::anyhow!("handle_input({n} bytes from {source}): {e}"))?;
}
Err(ref e) if e.kind() == std::io::ErrorKind::WouldBlock => break,
Err(ref e) if e.kind() == std::io::ErrorKind::Interrupted => continue,
Err(e) => bail!("UDP recv error: {e}"),
}
}
self.rtc
.handle_input(Input::Timeout(Instant::now()))
.map_err(|e| anyhow::anyhow!("handle timeout: {e}"))?;
Ok(())
}
fn write_h264_frame(&mut self, data: &[u8], frame_number: u64, fps: u32) -> Result<bool> {
if !self.connected {
return Ok(false);
}
let mid = match self.video_mid {
Some(m) => m,
None => {
tracing::debug!("write_h264: no video_mid");
return Ok(false);
}
};
let pt = match self.video_pt {
Some(p) => p,
None => {
tracing::debug!("write_h264: no video_pt");
return Ok(false);
}
};
if self.need_keyframe {
if !is_idr_nalu(data) {
tracing::debug!(
"write_h264: skipping non-IDR frame ({} bytes), waiting for keyframe",
data.len()
);
return Ok(false);
}
tracing::info!(
"write_h264: got IDR keyframe ({} bytes), starting playback",
data.len()
);
self.need_keyframe = false;
}
let ticks_per_second = 90_000u64;
let fps = fps.max(1) as u64;
let rtp_timestamp = frame_number.saturating_mul(ticks_per_second) / fps;
self.rtp_clock = rtp_timestamp as u32;
let rtp_time = MediaTime::new(rtp_timestamp, Frequency::NINETY_KHZ);
let writer = match self.rtc.writer(mid) {
Some(w) => w,
None => {
tracing::debug!("write_h264: no writer for mid={mid}");
return Ok(false);
}
};
tracing::debug!(
"write_h264: {} bytes, pt={:?}, rtp={}",
data.len(),
pt,
self.rtp_clock
);
writer
.write(pt, Instant::now(), rtp_time, data)
.map_err(|e| anyhow::anyhow!("writer.write: {e}"))?;
let should_destroy = self.poll_rtc()?;
Ok(should_destroy)
}
fn is_connected(&self) -> bool {
self.connected
}
}
// ── 工具函数 ──────────────────────────────────────────────────────────────
/// 从 HTTP 请求中提取 body在 \r\n\r\n 之后)
fn extract_body(req: &str) -> &str {
if let Some(idx) = req.find("\r\n\r\n") {
req.get(idx + 4..).unwrap_or("")
} else {
""
}
}
fn local_ip() -> Option<String> {
std::net::UdpSocket::bind("0.0.0.0:0")
.ok()
.and_then(|s| {
s.connect("1.1.1.1:80").ok()?;
let addr = s.local_addr().ok()?;
drop(s);
let ip = addr.ip().to_string();
if ip == "0.0.0.0" || ip.starts_with("127.") {
return None;
}
Some(ip)
})
}
fn is_idr_nalu(data: &[u8]) -> bool {
let mut i = 0;
while i < data.len() {
let tail = &data[i..];
if tail.starts_with(&[0, 0, 0, 1]) {
let Some(&header) = tail.get(4) else { break };
if header & 0x1F == 5 {
return true;
}
i += 5;
} else if tail.starts_with(&[0, 0, 1]) {
let Some(&header) = tail.get(3) else { break };
if header & 0x1F == 5 {
return true;
}
i += 4;
} else {
i += 1;
}
}
false
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn empty_data() {
assert!(!is_idr_nalu(&[]));
}
#[test]
fn short_data_no_start_code() {
assert!(!is_idr_nalu(&[0]));
assert!(!is_idr_nalu(&[0, 0]));
assert!(!is_idr_nalu(&[1, 2, 3]));
}
#[test]
fn three_byte_start_code_no_nal_header() {
assert!(!is_idr_nalu(&[0, 0, 1]));
}
#[test]
fn four_byte_start_code_no_nal_header() {
assert!(!is_idr_nalu(&[0, 0, 0, 1]));
}
#[test]
fn three_byte_start_code_idr_at_tail() {
assert!(is_idr_nalu(&[0, 0, 1, 0x65]));
assert!(!is_idr_nalu(&[0, 0, 1, 0x01]));
}
#[test]
fn four_byte_start_code_idr_at_tail() {
assert!(is_idr_nalu(&[0, 0, 0, 1, 0x65]));
assert!(!is_idr_nalu(&[0, 0, 0, 1, 0x01]));
}
#[test]
fn idr_in_middle_of_frame() {
let data: Vec<u8> = [
&[0, 0, 0, 1, 0x67][..], // SPS
&[0, 0, 0, 1, 0x68][..], // PPS
&[0, 0, 0, 1, 0x65][..], // IDR
]
.concat();
assert!(is_idr_nalu(&data));
}
#[test]
fn no_idr_in_frame() {
let data: Vec<u8> = [
&[0, 0, 0, 1, 0x67][..], // SPS
&[0, 0, 0, 1, 0x68][..], // PPS
]
.concat();
assert!(!is_idr_nalu(&data));
}
#[test]
fn mixed_start_code_lengths() {
let data: Vec<u8> = [
&[0, 0, 0, 1, 0x67][..], // SPS (4-byte start code)
&[0, 0, 1, 0x65][..], // IDR (3-byte start code)
]
.concat();
assert!(is_idr_nalu(&data));
}
#[test]
fn all_zeros() {
assert!(!is_idr_nalu(&[0, 0, 0, 0, 0, 0, 0, 0]));
}
}