From 226768c3e340589e67b15cfbbea0a003b14f7c06 Mon Sep 17 00:00:00 2001 From: dailz Date: Sat, 6 Jun 2026 15:12:49 +0800 Subject: [PATCH] fix(avhw): handle tx.send() failure and pause encoding on WebRTC disconnect (closes #6) - Replace 'let _ = tx.send()' with proper error handling: log warning, set webrtc_disconnected flag, and break drain loop on SendError - Add Arc webrtc_paused shared between State/StatePortal and SwEncState, synced from wrtc.is_connected() in poll_webrtc() - Skip encoding in encode_filtered_frame() when paused or disconnected - Drain and discard stale channel frames on disconnect - Resume encoding automatically on WebRTC reconnection --- src/avhw.rs | 26 +++++++- src/state.rs | 29 ++++++++- src/state_portal.rs | 155 ++++++++++++++++++++++++++++++++------------ 3 files changed, 163 insertions(+), 47 deletions(-) diff --git a/src/avhw.rs b/src/avhw.rs index 0e62690..a679d3b 100644 --- a/src/avhw.rs +++ b/src/avhw.rs @@ -3,6 +3,8 @@ use std::mem; use std::os::fd::{AsRawFd, RawFd}; use std::os::raw::c_void; use std::path::Path; +use std::sync::atomic::{AtomicBool, Ordering}; +use std::sync::Arc; use std::ptr; use anyhow::{bail, Result}; @@ -611,6 +613,8 @@ pub struct SwEncState { yuv_frame: *mut ffi::AVFrame, starting_timestamp: Option, frames_written: bool, + webrtc_disconnected: bool, + webrtc_paused: Option>, } unsafe impl Send for SwEncState {} @@ -660,6 +664,8 @@ impl SwEncState { yuv_frame, starting_timestamp: None, frames_written: false, + webrtc_disconnected: false, + webrtc_paused: None, }) } @@ -674,6 +680,7 @@ impl SwEncState { bitrate: u64, gop_size: u32, tx: crossbeam_channel::Sender>, + webrtc_paused: Arc, ) -> Result { tracing::info!( "SwEncState::new_webrtc: GPU downscale {width}x{height} BGRA -> {enc_width}x{enc_height} NV12, software H.264 -> WebRTC" @@ -705,6 +712,8 @@ impl SwEncState { yuv_frame, starting_timestamp: None, frames_written: false, + webrtc_disconnected: false, + webrtc_paused: Some(webrtc_paused), }) } @@ -774,6 +783,14 @@ impl SwEncState { } fn encode_filtered_frame(&mut self, filtered: &ff::frame::Video) -> Result<()> { + if self.webrtc_disconnected { + return Ok(()); + } + if let Some(ref paused) = self.webrtc_paused { + if paused.load(Ordering::Relaxed) { + return Ok(()); + } + } let mut sw_nv12 = unsafe { ffi::av_frame_alloc() }; if sw_nv12.is_null() { bail!("av_frame_alloc failed for NV12 transfer frame"); @@ -876,7 +893,14 @@ impl SwEncState { let data: &[u8] = unsafe { std::slice::from_raw_parts(raw.data, raw.size as usize) }; - let _ = tx.send(data.to_vec()); + if let Err(e) = tx.send(data.to_vec()) { + tracing::warn!( + "WebRTC channel send failed (receiver dropped): {} bytes lost", + e.0.len() + ); + self.webrtc_disconnected = true; + break; + } } } None => {} diff --git a/src/state.rs b/src/state.rs index 0f5f244..59ddd2d 100644 --- a/src/state.rs +++ b/src/state.rs @@ -3,6 +3,8 @@ use std::mem; use std::os::fd::{AsFd, OwnedFd}; use std::os::unix::io::FromRawFd; use std::path::{Path, PathBuf}; +use std::sync::atomic::{AtomicBool, Ordering}; +use std::sync::Arc; use std::time::Instant; use anyhow::Result; @@ -223,6 +225,7 @@ pub struct State { pub webrtc_tx: Option>>, webrtc_rx: Option>>, webrtc_frames_sent: u64, + webrtc_paused: Option>, } // --------------------------------------------------------------------------- @@ -273,12 +276,14 @@ impl State { let fps = args.fps; let drm_device = args.drm_device.as_ref().map(PathBuf::from); - let (webrtc, webrtc_tx, webrtc_rx) = if args.port > 0 { + let (webrtc, webrtc_tx, webrtc_rx, webrtc_paused) = if args.port > 0 { let (tx, rx) = crossbeam_channel::bounded(32); let wrtc = WebRtcState::new(args.port, args.fps)?; - (Some(wrtc), Some(tx), Some(rx)) + // paused=true until first WebRTC client connects + let paused = Arc::new(AtomicBool::new(true)); + (Some(wrtc), Some(tx), Some(rx), Some(paused)) } else { - (None, None, None) + (None, None, None, None) }; let mut state = Self { @@ -309,6 +314,7 @@ impl State { webrtc_tx, webrtc_rx, webrtc_frames_sent: 0, + webrtc_paused, }; // registry_queue_init consumes registry events internally during its @@ -641,9 +647,25 @@ impl State { wrtc.handle_signaling()?; wrtc.poll_and_feed()?; + let connected = wrtc.is_connected(); + + if let Some(ref paused) = self.webrtc_paused { + let was_paused = paused.load(Ordering::Relaxed); + let now_paused = !connected; + if was_paused && !now_paused { + tracing::info!("WebRTC client connected, resuming encoding"); + } else if !was_paused && now_paused { + tracing::warn!("WebRTC client disconnected, pausing encoding"); + } + paused.store(now_paused, Ordering::Relaxed); + } + if let Some(ref rx) = self.webrtc_rx { let mut count = 0u32; while let Ok(data) = rx.try_recv() { + if !connected { + continue; + } count += 1; if let Err(e) = wrtc.write_h264_frame(&data, self.webrtc_frames_sent, self.args.fps) { tracing::debug!("WebRTC write frame error: {e}"); @@ -703,6 +725,7 @@ impl State { bitrate, actual_gop_size, tx.clone(), + self.webrtc_paused.as_ref().expect("webrtc_paused must exist when webrtc_tx exists").clone(), ) { Ok(enc) => StreamingEncoder::WebRtc(enc), Err(e) => { diff --git a/src/state_portal.rs b/src/state_portal.rs index 871ea02..3fba65a 100644 --- a/src/state_portal.rs +++ b/src/state_portal.rs @@ -1,14 +1,16 @@ // 采集门户状态模块 —— 通过 PipeWire/DMA-BUF 进行屏幕采集并编码 use std::os::fd::AsRawFd; use std::path::PathBuf; +use std::sync::atomic::{AtomicBool, Ordering}; +use std::sync::Arc; use std::time::{Duration, Instant}; -use anyhow::{bail, Result}; +use anyhow::{bail, Result}; // 错误处理工具 -use crate::args::Args; -use crate::avhw::{self, SwEncState}; -use crate::cap_portal::{CapPortal, PwCtrlEvent, PwDmaBufFrame}; -use crate::webrtc::WebRtcState; +use crate::args::Args; // 命令行参数 +use crate::avhw::{self, SwEncState}; // 软件编码器状态(VAAPI 导入 + H.264 编码) +use crate::cap_portal::{CapPortal, PwCtrlEvent, PwDmaBufFrame}; // PipeWire 屏幕采集端点 +use crate::webrtc::WebRtcState; // WebRTC 信令与媒体传输 /// 门户采集的阶段状态 /// - WaitingForFormat: 等待接收到第一帧 DMA-BUF 以确定视频格式参数 @@ -23,20 +25,21 @@ enum PortalStage { /// 负责管理从 PipeWire 采集屏幕帧、通过 VAAPI 硬件编码的完整生命周期。 /// 工作流程:等待第一帧 → 创建编码器 → 持续编码帧数据。 pub struct StatePortal { - stage: PortalStage, - enc: Option, - cap: CapPortal, - args: Args, - errored: bool, - drm_device: Option, - frames_encoded: u64, - start_time: Option, - last_stats_time: Option, - last_stats_frames: u64, - webrtc: Option, - webrtc_tx: Option>>, + stage: PortalStage, // 当前采集阶段(等待首帧 / 流式编码中) + enc: Option, // 软件编码器,首帧到达后初始化 + cap: CapPortal, // PipeWire 屏幕采集端点 + args: Args, // 用户命令行参数 + errored: bool, // 是否遇到不可恢复的错误 + drm_device: Option, // DRM 渲染设备路径(可自动检测) + frames_encoded: u64, // 已编码帧数 + start_time: Option, // 编码开始时间 + last_stats_time: Option, // 上一次统计日志时间 + last_stats_frames: u64, // 上一次统计时的已编码帧数 + webrtc: Option, // WebRTC 状态(仅 WebRTC 模式启用) + webrtc_tx: Option>>, // 编码帧发送通道 webrtc_rx: Option>>, webrtc_frames_sent: u64, + webrtc_paused: Option>, } impl StatePortal { @@ -53,12 +56,13 @@ impl StatePortal { let cap = CapPortal::new(&args)?; - let (webrtc, webrtc_tx, webrtc_rx) = if args.port > 0 { + let (webrtc, webrtc_tx, webrtc_rx, webrtc_paused) = if args.port > 0 { let (tx, rx) = crossbeam_channel::bounded(32); let wrtc = WebRtcState::new(args.port, args.fps)?; - (Some(wrtc), Some(tx), Some(rx)) + let paused = Arc::new(AtomicBool::new(true)); + (Some(wrtc), Some(tx), Some(rx), Some(paused)) } else { - (None, None, None) + (None, None, None, None) }; Ok(Self { @@ -76,6 +80,7 @@ impl StatePortal { webrtc_tx, webrtc_rx, webrtc_frames_sent: 0, + webrtc_paused, }) } @@ -85,9 +90,11 @@ impl StatePortal { /// `block=false` 时使用 try_recv 非阻塞检查。 /// 返回 `Ok(true)` 表示已处理事件,`Ok(false)` 表示暂无数据。 pub fn poll_and_encode(&mut self, block: bool) -> Result { + // 先处理 WebRTC 信令、网络轮询,并转发已编码帧 // WebRTC: process signaling, network, and forward encoded frames self.poll_webrtc()?; + // 检查 PipeWire 控制事件(流结束 / 错误) if let Ok(ctrl) = self.cap.event_receiver().try_recv() { match ctrl { PwCtrlEvent::StreamEnded => { @@ -103,12 +110,15 @@ impl StatePortal { } } + // 根据阻塞模式选择不同的帧接收策略 let frame = if block { + // 阻塞模式:最多等待 10ms 接收帧 match self.cap.frame_receiver().recv_timeout(std::time::Duration::from_millis(10)) { Ok(frame) => frame, Err(_) => return Ok(false), } } else { + // 非阻塞模式:立即尝试接收,无数据则返回 match self.cap.frame_receiver().try_recv() { Ok(frame) => frame, Err(_) => return Ok(false), @@ -117,6 +127,7 @@ impl StatePortal { match self.stage { PortalStage::WaitingForFormat => { + // 首帧到达,记录 DMA-BUF 格式信息 tracing::info!( "First DMA-BUF frame: {}x{} format=0x{:08X} stride={} modifier=0x{:X}", frame.width, @@ -126,7 +137,9 @@ impl StatePortal { frame.modifier ); + // 自动检测或确认 DRM 设备是否支持导入该帧 let drm_path = self.resolve_drm_device_for_frame(&frame)?; + // 计算编码目标分辨率(不超过 2560x1440) let (enc_width, enc_height) = portal_encode_dimensions(frame.width, frame.height); tracing::info!( "Portal software encode target: {}x{} -> {}x{} @ {} fps", @@ -136,9 +149,11 @@ impl StatePortal { enc_height, self.args.fps, ); + // 码率:未指定时按分辨率 × 帧率动态计算 let actual_bitrate = self.args.bitrate.unwrap_or_else(|| { 2 * (enc_width as u64) * (enc_height as u64) * (self.args.fps as u64) / 100 }); + // GOP 大小:WebRTC 模式使用更小的 GOP(fps/2,最低10),MP4 模式使用 fps let actual_gop_size = self.args.gop_size.unwrap_or_else(|| { if self.webrtc_tx.is_some() { (self.args.fps / 2).max(10) @@ -147,7 +162,9 @@ impl StatePortal { } }); + // 根据是否启用 WebRTC 选择不同的编码器构造方式 let enc = if let Some(ref tx) = self.webrtc_tx { + let paused = self.webrtc_paused.as_ref().expect("webrtc_paused must exist when webrtc_tx exists"); avhw::SwEncState::new_webrtc( &drm_path, frame.width, @@ -158,8 +175,10 @@ impl StatePortal { actual_bitrate, actual_gop_size, tx.clone(), + paused.clone(), )? } else { + // MP4 模式:编码输出写入文件 avhw::SwEncState::new( &drm_path, std::path::Path::new(self.args.output.as_deref().expect("output required for MP4 mode")), @@ -174,37 +193,47 @@ impl StatePortal { }; self.enc = Some(enc); - self.stage = PortalStage::Streaming; + self.stage = PortalStage::Streaming; // 切换到流式编码阶段 self.start_time = Some(Instant::now()); self.last_stats_time = Some(Instant::now()); tracing::info!("First frame processed, encoder initialized, transitioning to Streaming"); - drop(frame); + drop(frame); // 首帧仅用于初始化,不参与编码 } PortalStage::Streaming => { + // 流式编码阶段:直接处理帧 self.handle_pw_frame(frame)?; } } + // 在返回前再次轮询 WebRTC,确保本帧编码后的数据及时转发 // WebRTC: drain encoded frames produced by this poll before returning. self.poll_webrtc()?; Ok(true) } + /// 为当前帧解析可用的 DRM 渲染设备 + /// + /// 如果用户已通过 `--drm-device` 指定设备,直接返回; + /// 否则遍历系统中所有 DRM render node,逐个尝试导入 DMA-BUF 帧来找到兼容设备。 fn resolve_drm_device_for_frame(&mut self, frame: &PwDmaBufFrame) -> Result { + // 用户已显式指定 DRM 设备,直接使用 if let Some(ref drm) = self.drm_device { return Ok(drm.clone()); } + // 查找系统中所有 DRM render node(如 /dev/dri/renderD128) let candidates = crate::state::find_drm_render_nodes(); if candidates.is_empty() { bail!("No DRM render device found. Specify --drm-device."); } + // 逐个尝试导入 DMA-BUF 帧,找到第一个兼容的设备 let mut failures = Vec::new(); for candidate in &candidates { match crate::avhw::test_dma_buf_import(candidate, frame) { Ok(()) => { + // 成功导入,缓存检测结果并返回 tracing::info!( "Auto-detected DRM device: {} (tested {} candidates)", candidate.display(), @@ -214,6 +243,7 @@ impl StatePortal { return Ok(candidate.clone()); } Err(e) => { + // 导入失败,记录原因,继续尝试下一个设备 tracing::debug!( "DRM device {} cannot import DMA-BUF: {e}", candidate.display(), @@ -223,8 +253,8 @@ impl StatePortal { } } + // 所有候选设备均失败,返回详细错误信息 bail!( - "No DRM render device can import the DMA-BUF frame. Tried: {}", failures .into_iter() .map(|(p, e)| format!("{} ({e})", p.display())) @@ -238,11 +268,13 @@ impl StatePortal { /// 通过 `av_hwframe_map` 零拷贝导入 VAAPI,然后交给 SwEncState 完成: /// scale_vaapi GPU 缩放、2K NV12 回读、YUV420P 格式转换、软件 H.264 编码。 fn handle_pw_frame(&mut self, frame: PwDmaBufFrame) -> Result<()> { + // 获取已初始化的编码器引用 let enc = match self.enc.as_mut() { Some(enc) => enc, None => bail!("encoder not initialized"), }; + // 将 DMA-BUF 帧零拷贝导入 VAAPI 硬件帧池 let mut vaapi_frame = unsafe { avhw::import_dma_buf_to_vaapi( enc.frames_rgb().as_ptr(), @@ -256,14 +288,17 @@ impl StatePortal { ) }?; + // 设置帧的显示时间戳(PTS),基于已编码帧序号 let pts = self.frames_encoded as i64; unsafe { (*vaapi_frame.as_mut_ptr()).pts = pts; } + // 送入编码器完成:缩放 → 回读 → 格式转换 → H.264 编码 enc.encode_frame(&vaapi_frame)?; self.frames_encoded += 1; + // 每 10 秒输出一次编码统计(已编码帧数、实时帧率) if let Some(last) = self.last_stats_time { if last.elapsed() >= Duration::from_secs(10) { let delta_frames = self.frames_encoded - self.last_stats_frames; @@ -310,15 +345,35 @@ impl StatePortal { self.errored } + /// 轮询 WebRTC 信令通道并转发编码帧 + /// + /// 处理信令交换、网络轮询,以及从编码通道中取出已编码的 H.264 数据 + /// 并通过 WebRTC 发送。 fn poll_webrtc(&mut self) -> Result<()> { - let Some(ref mut wrtc) = self.webrtc else { return Ok(()); }; + let Some(ref mut wrtc) = self.webrtc else { return Ok(()) }; wrtc.handle_signaling()?; wrtc.poll_and_feed()?; + let connected = wrtc.is_connected(); + + if let Some(ref paused) = self.webrtc_paused { + let was_paused = paused.load(Ordering::Relaxed); + let now_paused = !connected; + if was_paused && !now_paused { + tracing::info!("WebRTC client connected, resuming encoding"); + } else if !was_paused && now_paused { + tracing::warn!("WebRTC client disconnected, pausing encoding"); + } + paused.store(now_paused, Ordering::Relaxed); + } + if let Some(ref rx) = self.webrtc_rx { let mut count = 0u32; while let Ok(data) = rx.try_recv() { + if !connected { + continue; + } count += 1; if let Err(e) = wrtc.write_h264_frame(&data, self.webrtc_frames_sent, self.args.fps) { tracing::debug!("WebRTC write frame error: {e}"); @@ -335,23 +390,31 @@ impl StatePortal { } impl Drop for StatePortal { + // 析构时自动调用 shutdown,确保编码器被刷新、资源被释放 fn drop(&mut self) { self.shutdown(); } } +/// 计算编码目标分辨率 +/// +/// 将原始分辨率等比缩放至不超过 2560×1440(2K),并确保宽高为偶数 +/// (H.264 编码要求偶数尺寸)。 fn portal_encode_dimensions(width: u32, height: u32) -> (u32, u32) { - const TARGET_W: u32 = 2560; - const TARGET_H: u32 = 1440; + const TARGET_W: u32 = 2560; // 目标最大宽度 + const TARGET_H: u32 = 1440; // 目标最大高度 + // 原始分辨率已在 2K 以内,直接对齐偶数 if width <= TARGET_W && height <= TARGET_H { - return (width & !1, height & !1); + return (width & !1, height & !1); // & !1 确保为偶数 } + // 按宽度限制等比缩放 let width_limited_h = ((height as u64) * (TARGET_W as u64) / (width as u64)) as u32; if width_limited_h <= TARGET_H { (TARGET_W & !1, width_limited_h & !1) } else { + // 按高度限制等比缩放 let height_limited_w = ((width as u64) * (TARGET_H as u64) / (height as u64)) as u32; (height_limited_w & !1, TARGET_H & !1) } @@ -367,19 +430,23 @@ fn resolve_drm_device(args: &Args) -> Result> { Ok(None) } +/// 构建测试用的 AVDRMFrameDescriptor(仅测试用途) +/// +/// 将 PwDmaBufFrame 转换为 FFmpeg 的 DRM 帧描述符结构体, +/// 用于验证 DMA-BUF 元数据映射的正确性。 #[cfg(test)] fn build_drm_descriptor(frame: &PwDmaBufFrame) -> ffmpeg_next::ffi::AVDRMFrameDescriptor { let mut desc: ffmpeg_next::ffi::AVDRMFrameDescriptor = unsafe { std::mem::zeroed() }; - desc.nb_objects = 1; - desc.objects[0].fd = frame.fd.as_raw_fd(); - desc.objects[0].size = 0; - desc.objects[0].format_modifier = frame.modifier; - desc.nb_layers = 1; - desc.layers[0].format = frame.format; - desc.layers[0].nb_planes = 1; - desc.layers[0].planes[0].object_index = 0; - desc.layers[0].planes[0].offset = frame.offset as isize; - desc.layers[0].planes[0].pitch = frame.stride as isize; + desc.nb_objects = 1; // 单个 DMA-BUF 对象 + desc.objects[0].fd = frame.fd.as_raw_fd(); // DMA-BUF 文件描述符 + desc.objects[0].size = 0; // 大小设为 0(内核自动确定) + desc.objects[0].format_modifier = frame.modifier; // DRM 格式修饰符(如线性、tiled) + desc.nb_layers = 1; // 单层 + desc.layers[0].format = frame.format; // 像素格式(如 XR24) + desc.layers[0].nb_planes = 1; // 单平面 + desc.layers[0].planes[0].object_index = 0; // 指向第 0 个对象 + desc.layers[0].planes[0].offset = frame.offset as isize; // 帧数据偏移 + desc.layers[0].planes[0].pitch = frame.stride as isize; // 行跨度(stride) desc } @@ -391,15 +458,16 @@ mod tests { /// 创建测试用的 DMA-BUF 帧数据(使用 stderr fd 的副本作为占位) fn make_test_frame() -> PwDmaBufFrame { // Create a dummy fd from stderr (always valid fd 2) + // 使用 stderr(fd 2)的副本作为虚拟文件描述符 let fd = unsafe { OwnedFd::from_raw_fd(libc::dup(2)) }; PwDmaBufFrame { fd, offset: 0, - stride: 1920 * 4, - modifier: 0, // DRM_FORMAT_MOD_LINEAR + stride: 1920 * 4, // 每行 1920 像素 × 4 字节(XRGB) + modifier: 0, // DRM_FORMAT_MOD_LINEAR(线性布局) width: 1920, height: 1080, - format: 0x34325258, // XR24 little-endian + format: 0x34325258, // XR24 little-endian(XRGB8888) pts: 12345, } } @@ -464,13 +532,14 @@ mod tests { assert_eq!(result, None); } + /// 测试:使用自定义偏移量和 stride 构建 DRM 描述符 #[test] fn build_drm_descriptor_custom_offset_and_stride() { let frame = PwDmaBufFrame { fd: unsafe { OwnedFd::from_raw_fd(libc::dup(2)) }, - offset: 4096, - stride: 3840 * 4, - modifier: 0x0100000000000001, // AMD modifiers + offset: 4096, // 4KB 对齐偏移 + stride: 3840 * 4, // 4K 宽度 × 4 字节 + modifier: 0x0100000000000001, // AMD modifiers width: 3840, height: 2160, format: 0x34325258,