From 46367ef6b58ba98219dca222d12dea1334594010 Mon Sep 17 00:00:00 2001 From: dailz Date: Thu, 4 Jun 2026 22:10:46 +0800 Subject: [PATCH] fix(state): add WebRTC support to wlr-screencopy backend Fixes #1 -- --port mode with wlr-screencopy backend caused panic at negotiate_format() because self.args.output is None and .expect() was called unconditionally. Changes: - Introduce StreamingEncoder enum wrapping EncState (MP4) and SwEncState (WebRTC) with unified frames_rgb/encode_frame/flush API - Add WebRTC fields to State (webrtc, webrtc_tx, webrtc_rx, webrtc_frames_sent) matching Portal backend pattern - State::new() returns Result for clean WebRtcState init failure - negotiate_format() branches on webrtc_tx: WebRTC path uses SwEncState::new_webrtc(), MP4 path unchanged (hardware VAAPI) - Add poll_webrtc() method to drive signaling + channel drain - Event loop calls poll_webrtc() each iteration - Fix pre-existing test/bench Args construction (Option output, missing no_persist field) --- src/backend_detect.rs | 3 +- src/bin/sw_encode_bench.rs | 3 +- src/bin/vaapi_import_bench.rs | 3 +- src/main.rs | 4 +- src/state.rs | 143 +++++++++++++++++++++++++++++----- src/state_portal.rs | 6 +- 6 files changed, 136 insertions(+), 26 deletions(-) diff --git a/src/backend_detect.rs b/src/backend_detect.rs index 609c373..38ea435 100644 --- a/src/backend_detect.rs +++ b/src/backend_detect.rs @@ -178,7 +178,7 @@ mod tests { // 测试辅助函数:构造指定后端参数的 Args 实例 fn make_args(backend: Option<&str>) -> Args { Args { - output: "test.mp4".to_string(), + output: Some("test.mp4".to_string()), output_name: None, fps: 30, codec: "h264".to_string(), @@ -189,6 +189,7 @@ mod tests { verbose: false, backend: backend.map(String::from), port: 0, + no_persist: false, } } diff --git a/src/bin/sw_encode_bench.rs b/src/bin/sw_encode_bench.rs index d421224..a1bcb5d 100644 --- a/src/bin/sw_encode_bench.rs +++ b/src/bin/sw_encode_bench.rs @@ -102,7 +102,7 @@ fn main() -> Result<()> { println!(" (Select a screen to share in the portal dialog)"); let portal_args = Args { - output: bench_args.output.clone(), + output: Some(bench_args.output.clone()), output_name: None, fps: 60, codec: "h264".to_string(), @@ -113,6 +113,7 @@ fn main() -> Result<()> { verbose: false, backend: Some("portal".to_string()), port: 0, + no_persist: false, }; let cap = CapPortal::new(&portal_args)?; diff --git a/src/bin/vaapi_import_bench.rs b/src/bin/vaapi_import_bench.rs index 6af1b94..0e2e025 100644 --- a/src/bin/vaapi_import_bench.rs +++ b/src/bin/vaapi_import_bench.rs @@ -871,7 +871,7 @@ fn main() -> Result<()> { println!(" (Select a screen to share in the portal dialog)"); let portal_args = Args { - output: bench_args.output.clone(), + output: Some(bench_args.output.clone()), output_name: None, fps: 60, codec: "h264".to_string(), @@ -882,6 +882,7 @@ fn main() -> Result<()> { verbose: false, backend: Some("portal".to_string()), port: 0, + no_persist: false, }; let cap = CapPortal::new(&portal_args)?; diff --git a/src/main.rs b/src/main.rs index 9df2698..8ea5840 100644 --- a/src/main.rs +++ b/src/main.rs @@ -100,7 +100,7 @@ fn run_wlr_screencopy(args: Args) -> Result<()> { let qhandle = queue.handle(); // State 是 wlr-screencopy 后端的核心状态机, // 内部管理输出探测、截屏请求、编码器构建、帧采集等阶段 - let mut state = State::new(gm, args, qhandle); + let mut state = State::new(gm, args, qhandle)?; // Extract the Wayland fd and consume any immediately-available events. // prepare_read() flushes outgoing requests; read() pulls whatever the @@ -246,6 +246,8 @@ fn run_wlr_screencopy(args: Args) -> Result<()> { // - Streaming: 正常采集中,请求下一帧 state.queue_alloc_frame(); + state.poll_webrtc()?; + // 状态机遇到致命错误时退出 if state.errored { tracing::error!("Fatal error in state machine, exiting"); diff --git a/src/state.rs b/src/state.rs index a82a3bd..0f5f244 100644 --- a/src/state.rs +++ b/src/state.rs @@ -41,10 +41,11 @@ use ffmpeg_next as ff; use ffmpeg_next::ffi; use crate::args::Args; -use crate::avhw::{AvHwDevCtx, EncState}; +use crate::avhw::{AvHwDevCtx, EncState, SwEncState}; use crate::cap_wlr_screencopy::CapWlrScreencopy; use crate::fps_limit::FpsLimit; use crate::transform::{transpose_if_transform_transposed, Transform}; +use crate::webrtc::WebRtcState; // --------------------------------------------------------------------------- // CaptureSource trait @@ -113,6 +114,42 @@ struct WlrHeadInfo { /// User data for XdgOutput dispatch to identify which WlOutput it belongs to. pub struct OutputId(pub u32); +// --------------------------------------------------------------------------- +// StreamingEncoder +// --------------------------------------------------------------------------- + +/// Wraps the two possible encoder backends for the streaming stage. +/// +/// - `Mp4(EncState)` — hardware VAAPI encoder writing to an MP4 file +/// - `WebRtc(SwEncState)` — software encoder feeding H.264 NALUs into a WebRTC channel +pub enum StreamingEncoder { + Mp4(EncState), + WebRtc(SwEncState), +} + +impl StreamingEncoder { + fn frames_rgb(&self) -> &crate::avhw::AvHwFrameCtx { + match self { + StreamingEncoder::Mp4(enc) => enc.frames_rgb(), + StreamingEncoder::WebRtc(enc) => enc.frames_rgb(), + } + } + + fn encode_frame(&mut self, hw_frame: &ffmpeg_next::frame::Video) -> anyhow::Result<()> { + match self { + StreamingEncoder::Mp4(enc) => enc.encode_frame(hw_frame), + StreamingEncoder::WebRtc(enc) => enc.encode_frame(hw_frame), + } + } + + pub fn flush(&mut self) -> anyhow::Result<()> { + match self { + StreamingEncoder::Mp4(enc) => enc.flush(), + StreamingEncoder::WebRtc(enc) => enc.flush(), + } + } +} + // --------------------------------------------------------------------------- // EncConstructionStage // --------------------------------------------------------------------------- @@ -142,7 +179,7 @@ pub enum EncConstructionStage { Streaming { output_info: OutputInfo, output: WlOutput, - enc: EncState, + enc: StreamingEncoder, cap: S, screencopy_manager: ZwlrScreencopyManagerV1, dmabuf: ZwpLinuxDmabufV1, @@ -182,6 +219,10 @@ pub struct State { pub qhandle: QueueHandle>, pub drm_device: Option, pub drm_device_from_compositor: Option, + pub webrtc: Option, + pub webrtc_tx: Option>>, + webrtc_rx: Option>>, + webrtc_frames_sent: u64, } // --------------------------------------------------------------------------- @@ -228,9 +269,18 @@ impl State { // --------------------------------------------------------------------------- impl State { - pub fn new(gm: GlobalList, args: Args, qhandle: QueueHandle>) -> Self { + pub fn new(gm: GlobalList, args: Args, qhandle: QueueHandle>) -> Result { let fps = args.fps; let drm_device = args.drm_device.as_ref().map(PathBuf::from); + + let (webrtc, webrtc_tx, webrtc_rx) = if args.port > 0 { + let (tx, rx) = crossbeam_channel::bounded(32); + let wrtc = WebRtcState::new(args.port, args.fps)?; + (Some(wrtc), Some(tx), Some(rx)) + } else { + (None, None, None) + }; + let mut state = Self { stage: EncConstructionStage::ProbingOutputs { outputs: Vec::new(), @@ -255,6 +305,10 @@ impl State { qhandle, drm_device, drm_device_from_compositor: None, + webrtc, + webrtc_tx, + webrtc_rx, + webrtc_frames_sent: 0, }; // registry_queue_init consumes registry events internally during its @@ -262,7 +316,7 @@ impl State { // We must manually bind the initial globals here. state.bind_initial_globals(); - state + Ok(state) } /// Iterate over the GlobalList from registry_queue_init and bind all @@ -581,6 +635,29 @@ impl State { self.errored = true; } + pub fn poll_webrtc(&mut self) -> Result<()> { + let Some(ref mut wrtc) = self.webrtc else { return Ok(()) }; + + wrtc.handle_signaling()?; + wrtc.poll_and_feed()?; + + if let Some(ref rx) = self.webrtc_rx { + let mut count = 0u32; + while let Ok(data) = rx.try_recv() { + count += 1; + if let Err(e) = wrtc.write_h264_frame(&data, self.webrtc_frames_sent, self.args.fps) { + tracing::debug!("WebRTC write frame error: {e}"); + } + self.webrtc_frames_sent = self.webrtc_frames_sent.saturating_add(1); + } + if count > 0 { + tracing::info!("WebRTC forwarded {count} frames from channel"); + } + } + + Ok(()) + } + pub fn negotiate_format(&mut self, format: u32, width: u32, height: u32) { let stage_data = match mem::replace(&mut self.stage, EncConstructionStage::Intermediate) { EncConstructionStage::EverythingButFmt { @@ -611,22 +688,48 @@ impl State { .args .bitrate .unwrap_or_else(|| 2 * (width as u64) * (height as u64) * (fps as u64) / 100); - let enc = match crate::avhw::create_encoder( - &drm_path, - Path::new(self.args.output.as_deref().expect("output required for MP4 mode")), - width, - height, - fps, - output_info.transform, - self.args.bitrate, - self.args.gop_size, - Some(hw_device_ctx), - ) { - Ok(enc) => enc, - Err(e) => { - tracing::error!("EncState::new failed: {}", e); - self.errored = true; - return; + + let enc = if let Some(ref tx) = self.webrtc_tx { + let (enc_w, enc_h) = + transpose_if_transform_transposed(output_info.transform, width as i32, height as i32); + let actual_gop_size = self.args.gop_size.unwrap_or((fps / 2).max(10)); + match SwEncState::new_webrtc( + &drm_path, + width, + height, + enc_w as u32, + enc_h as u32, + fps, + bitrate, + actual_gop_size, + tx.clone(), + ) { + Ok(enc) => StreamingEncoder::WebRtc(enc), + Err(e) => { + tracing::error!("SwEncState::new_webrtc failed: {}", e); + self.errored = true; + return; + } + } + } else { + let output_path = self.args.output.as_deref().expect("output required for MP4 mode"); + match crate::avhw::create_encoder( + &drm_path, + Path::new(output_path), + width, + height, + fps, + output_info.transform, + self.args.bitrate, + self.args.gop_size, + Some(hw_device_ctx), + ) { + Ok(enc) => StreamingEncoder::Mp4(enc), + Err(e) => { + tracing::error!("EncState::new failed: {}", e); + self.errored = true; + return; + } } }; tracing::info!( diff --git a/src/state_portal.rs b/src/state_portal.rs index 176e466..871ea02 100644 --- a/src/state_portal.rs +++ b/src/state_portal.rs @@ -424,7 +424,7 @@ mod tests { #[test] fn resolve_drm_device_explicit() { let args = Args { - output: "test.mp4".to_string(), + output: Some("test.mp4".to_string()), output_name: None, fps: 30, codec: "h264".to_string(), @@ -435,6 +435,7 @@ mod tests { verbose: false, backend: None, port: 0, + no_persist: false, }; let result = resolve_drm_device(&args).unwrap(); assert_eq!( @@ -446,7 +447,7 @@ mod tests { #[test] fn resolve_drm_device_none_when_not_specified() { let args = Args { - output: "test.mp4".to_string(), + output: Some("test.mp4".to_string()), output_name: None, fps: 30, codec: "h264".to_string(), @@ -457,6 +458,7 @@ mod tests { verbose: false, backend: None, port: 0, + no_persist: false, }; let result = resolve_drm_device(&args).unwrap(); assert_eq!(result, None);