fix(avhw): handle tx.send() failure and pause encoding on WebRTC disconnect (closes #6)

- Replace 'let _ = tx.send()' with proper error handling: log warning,
  set webrtc_disconnected flag, and break drain loop on SendError
- Add Arc<AtomicBool> webrtc_paused shared between State/StatePortal
  and SwEncState, synced from wrtc.is_connected() in poll_webrtc()
- Skip encoding in encode_filtered_frame() when paused or disconnected
- Drain and discard stale channel frames on disconnect
- Resume encoding automatically on WebRTC reconnection
This commit is contained in:
dailz
2026-06-06 15:12:49 +08:00
parent fd170b66d9
commit 226768c3e3
3 changed files with 163 additions and 47 deletions

View File

@@ -3,6 +3,8 @@ use std::mem;
use std::os::fd::{AsRawFd, RawFd};
use std::os::raw::c_void;
use std::path::Path;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use std::ptr;
use anyhow::{bail, Result};
@@ -611,6 +613,8 @@ pub struct SwEncState {
yuv_frame: *mut ffi::AVFrame,
starting_timestamp: Option<i64>,
frames_written: bool,
webrtc_disconnected: bool,
webrtc_paused: Option<Arc<AtomicBool>>,
}
unsafe impl Send for SwEncState {}
@@ -660,6 +664,8 @@ impl SwEncState {
yuv_frame,
starting_timestamp: None,
frames_written: false,
webrtc_disconnected: false,
webrtc_paused: None,
})
}
@@ -674,6 +680,7 @@ impl SwEncState {
bitrate: u64,
gop_size: u32,
tx: crossbeam_channel::Sender<Vec<u8>>,
webrtc_paused: Arc<AtomicBool>,
) -> Result<Self> {
tracing::info!(
"SwEncState::new_webrtc: GPU downscale {width}x{height} BGRA -> {enc_width}x{enc_height} NV12, software H.264 -> WebRTC"
@@ -705,6 +712,8 @@ impl SwEncState {
yuv_frame,
starting_timestamp: None,
frames_written: false,
webrtc_disconnected: false,
webrtc_paused: Some(webrtc_paused),
})
}
@@ -774,6 +783,14 @@ impl SwEncState {
}
fn encode_filtered_frame(&mut self, filtered: &ff::frame::Video) -> Result<()> {
if self.webrtc_disconnected {
return Ok(());
}
if let Some(ref paused) = self.webrtc_paused {
if paused.load(Ordering::Relaxed) {
return Ok(());
}
}
let mut sw_nv12 = unsafe { ffi::av_frame_alloc() };
if sw_nv12.is_null() {
bail!("av_frame_alloc failed for NV12 transfer frame");
@@ -876,7 +893,14 @@ impl SwEncState {
let data: &[u8] = unsafe {
std::slice::from_raw_parts(raw.data, raw.size as usize)
};
let _ = tx.send(data.to_vec());
if let Err(e) = tx.send(data.to_vec()) {
tracing::warn!(
"WebRTC channel send failed (receiver dropped): {} bytes lost",
e.0.len()
);
self.webrtc_disconnected = true;
break;
}
}
}
None => {}

View File

@@ -3,6 +3,8 @@ use std::mem;
use std::os::fd::{AsFd, OwnedFd};
use std::os::unix::io::FromRawFd;
use std::path::{Path, PathBuf};
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use std::time::Instant;
use anyhow::Result;
@@ -223,6 +225,7 @@ pub struct State<S: CaptureSource> {
pub webrtc_tx: Option<crossbeam_channel::Sender<Vec<u8>>>,
webrtc_rx: Option<crossbeam_channel::Receiver<Vec<u8>>>,
webrtc_frames_sent: u64,
webrtc_paused: Option<Arc<AtomicBool>>,
}
// ---------------------------------------------------------------------------
@@ -273,12 +276,14 @@ impl<S: CaptureSource> State<S> {
let fps = args.fps;
let drm_device = args.drm_device.as_ref().map(PathBuf::from);
let (webrtc, webrtc_tx, webrtc_rx) = if args.port > 0 {
let (webrtc, webrtc_tx, webrtc_rx, webrtc_paused) = if args.port > 0 {
let (tx, rx) = crossbeam_channel::bounded(32);
let wrtc = WebRtcState::new(args.port, args.fps)?;
(Some(wrtc), Some(tx), Some(rx))
// paused=true until first WebRTC client connects
let paused = Arc::new(AtomicBool::new(true));
(Some(wrtc), Some(tx), Some(rx), Some(paused))
} else {
(None, None, None)
(None, None, None, None)
};
let mut state = Self {
@@ -309,6 +314,7 @@ impl<S: CaptureSource> State<S> {
webrtc_tx,
webrtc_rx,
webrtc_frames_sent: 0,
webrtc_paused,
};
// registry_queue_init consumes registry events internally during its
@@ -641,9 +647,25 @@ impl<S: CaptureSource> State<S> {
wrtc.handle_signaling()?;
wrtc.poll_and_feed()?;
let connected = wrtc.is_connected();
if let Some(ref paused) = self.webrtc_paused {
let was_paused = paused.load(Ordering::Relaxed);
let now_paused = !connected;
if was_paused && !now_paused {
tracing::info!("WebRTC client connected, resuming encoding");
} else if !was_paused && now_paused {
tracing::warn!("WebRTC client disconnected, pausing encoding");
}
paused.store(now_paused, Ordering::Relaxed);
}
if let Some(ref rx) = self.webrtc_rx {
let mut count = 0u32;
while let Ok(data) = rx.try_recv() {
if !connected {
continue;
}
count += 1;
if let Err(e) = wrtc.write_h264_frame(&data, self.webrtc_frames_sent, self.args.fps) {
tracing::debug!("WebRTC write frame error: {e}");
@@ -703,6 +725,7 @@ impl<S: CaptureSource> State<S> {
bitrate,
actual_gop_size,
tx.clone(),
self.webrtc_paused.as_ref().expect("webrtc_paused must exist when webrtc_tx exists").clone(),
) {
Ok(enc) => StreamingEncoder::WebRtc(enc),
Err(e) => {

View File

@@ -1,14 +1,16 @@
// 采集门户状态模块 —— 通过 PipeWire/DMA-BUF 进行屏幕采集并编码
use std::os::fd::AsRawFd;
use std::path::PathBuf;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use std::time::{Duration, Instant};
use anyhow::{bail, Result};
use anyhow::{bail, Result}; // 错误处理工具
use crate::args::Args;
use crate::avhw::{self, SwEncState};
use crate::cap_portal::{CapPortal, PwCtrlEvent, PwDmaBufFrame};
use crate::webrtc::WebRtcState;
use crate::args::Args; // 命令行参数
use crate::avhw::{self, SwEncState}; // 软件编码器状态VAAPI 导入 + H.264 编码)
use crate::cap_portal::{CapPortal, PwCtrlEvent, PwDmaBufFrame}; // PipeWire 屏幕采集端点
use crate::webrtc::WebRtcState; // WebRTC 信令与媒体传输
/// 门户采集的阶段状态
/// - WaitingForFormat: 等待接收到第一帧 DMA-BUF 以确定视频格式参数
@@ -23,20 +25,21 @@ enum PortalStage {
/// 负责管理从 PipeWire 采集屏幕帧、通过 VAAPI 硬件编码的完整生命周期。
/// 工作流程:等待第一帧 → 创建编码器 → 持续编码帧数据。
pub struct StatePortal {
stage: PortalStage,
enc: Option<SwEncState>,
cap: CapPortal,
args: Args,
errored: bool,
drm_device: Option<PathBuf>,
frames_encoded: u64,
start_time: Option<Instant>,
last_stats_time: Option<Instant>,
last_stats_frames: u64,
webrtc: Option<WebRtcState>,
webrtc_tx: Option<crossbeam_channel::Sender<Vec<u8>>>,
stage: PortalStage, // 当前采集阶段(等待首帧 / 流式编码中)
enc: Option<SwEncState>, // 软件编码器,首帧到达后初始化
cap: CapPortal, // PipeWire 屏幕采集端点
args: Args, // 用户命令行参数
errored: bool, // 是否遇到不可恢复的错误
drm_device: Option<PathBuf>, // DRM 渲染设备路径(可自动检测)
frames_encoded: u64, // 已编码帧数
start_time: Option<Instant>, // 编码开始时间
last_stats_time: Option<Instant>, // 上一次统计日志时间
last_stats_frames: u64, // 上一次统计时的已编码帧数
webrtc: Option<WebRtcState>, // WebRTC 状态(仅 WebRTC 模式启用)
webrtc_tx: Option<crossbeam_channel::Sender<Vec<u8>>>, // 编码帧发送通道
webrtc_rx: Option<crossbeam_channel::Receiver<Vec<u8>>>,
webrtc_frames_sent: u64,
webrtc_paused: Option<Arc<AtomicBool>>,
}
impl StatePortal {
@@ -53,12 +56,13 @@ impl StatePortal {
let cap = CapPortal::new(&args)?;
let (webrtc, webrtc_tx, webrtc_rx) = if args.port > 0 {
let (webrtc, webrtc_tx, webrtc_rx, webrtc_paused) = if args.port > 0 {
let (tx, rx) = crossbeam_channel::bounded(32);
let wrtc = WebRtcState::new(args.port, args.fps)?;
(Some(wrtc), Some(tx), Some(rx))
let paused = Arc::new(AtomicBool::new(true));
(Some(wrtc), Some(tx), Some(rx), Some(paused))
} else {
(None, None, None)
(None, None, None, None)
};
Ok(Self {
@@ -76,6 +80,7 @@ impl StatePortal {
webrtc_tx,
webrtc_rx,
webrtc_frames_sent: 0,
webrtc_paused,
})
}
@@ -85,9 +90,11 @@ impl StatePortal {
/// `block=false` 时使用 try_recv 非阻塞检查。
/// 返回 `Ok(true)` 表示已处理事件,`Ok(false)` 表示暂无数据。
pub fn poll_and_encode(&mut self, block: bool) -> Result<bool> {
// 先处理 WebRTC 信令、网络轮询,并转发已编码帧
// WebRTC: process signaling, network, and forward encoded frames
self.poll_webrtc()?;
// 检查 PipeWire 控制事件(流结束 / 错误)
if let Ok(ctrl) = self.cap.event_receiver().try_recv() {
match ctrl {
PwCtrlEvent::StreamEnded => {
@@ -103,12 +110,15 @@ impl StatePortal {
}
}
// 根据阻塞模式选择不同的帧接收策略
let frame = if block {
// 阻塞模式:最多等待 10ms 接收帧
match self.cap.frame_receiver().recv_timeout(std::time::Duration::from_millis(10)) {
Ok(frame) => frame,
Err(_) => return Ok(false),
}
} else {
// 非阻塞模式:立即尝试接收,无数据则返回
match self.cap.frame_receiver().try_recv() {
Ok(frame) => frame,
Err(_) => return Ok(false),
@@ -117,6 +127,7 @@ impl StatePortal {
match self.stage {
PortalStage::WaitingForFormat => {
// 首帧到达,记录 DMA-BUF 格式信息
tracing::info!(
"First DMA-BUF frame: {}x{} format=0x{:08X} stride={} modifier=0x{:X}",
frame.width,
@@ -126,7 +137,9 @@ impl StatePortal {
frame.modifier
);
// 自动检测或确认 DRM 设备是否支持导入该帧
let drm_path = self.resolve_drm_device_for_frame(&frame)?;
// 计算编码目标分辨率(不超过 2560x1440
let (enc_width, enc_height) = portal_encode_dimensions(frame.width, frame.height);
tracing::info!(
"Portal software encode target: {}x{} -> {}x{} @ {} fps",
@@ -136,9 +149,11 @@ impl StatePortal {
enc_height,
self.args.fps,
);
// 码率:未指定时按分辨率 × 帧率动态计算
let actual_bitrate = self.args.bitrate.unwrap_or_else(|| {
2 * (enc_width as u64) * (enc_height as u64) * (self.args.fps as u64) / 100
});
// GOP 大小WebRTC 模式使用更小的 GOPfps/2最低10MP4 模式使用 fps
let actual_gop_size = self.args.gop_size.unwrap_or_else(|| {
if self.webrtc_tx.is_some() {
(self.args.fps / 2).max(10)
@@ -147,7 +162,9 @@ impl StatePortal {
}
});
// 根据是否启用 WebRTC 选择不同的编码器构造方式
let enc = if let Some(ref tx) = self.webrtc_tx {
let paused = self.webrtc_paused.as_ref().expect("webrtc_paused must exist when webrtc_tx exists");
avhw::SwEncState::new_webrtc(
&drm_path,
frame.width,
@@ -158,8 +175,10 @@ impl StatePortal {
actual_bitrate,
actual_gop_size,
tx.clone(),
paused.clone(),
)?
} else {
// MP4 模式:编码输出写入文件
avhw::SwEncState::new(
&drm_path,
std::path::Path::new(self.args.output.as_deref().expect("output required for MP4 mode")),
@@ -174,37 +193,47 @@ impl StatePortal {
};
self.enc = Some(enc);
self.stage = PortalStage::Streaming;
self.stage = PortalStage::Streaming; // 切换到流式编码阶段
self.start_time = Some(Instant::now());
self.last_stats_time = Some(Instant::now());
tracing::info!("First frame processed, encoder initialized, transitioning to Streaming");
drop(frame);
drop(frame); // 首帧仅用于初始化,不参与编码
}
PortalStage::Streaming => {
// 流式编码阶段:直接处理帧
self.handle_pw_frame(frame)?;
}
}
// 在返回前再次轮询 WebRTC确保本帧编码后的数据及时转发
// WebRTC: drain encoded frames produced by this poll before returning.
self.poll_webrtc()?;
Ok(true)
}
/// 为当前帧解析可用的 DRM 渲染设备
///
/// 如果用户已通过 `--drm-device` 指定设备,直接返回;
/// 否则遍历系统中所有 DRM render node逐个尝试导入 DMA-BUF 帧来找到兼容设备。
fn resolve_drm_device_for_frame(&mut self, frame: &PwDmaBufFrame) -> Result<PathBuf> {
// 用户已显式指定 DRM 设备,直接使用
if let Some(ref drm) = self.drm_device {
return Ok(drm.clone());
}
// 查找系统中所有 DRM render node如 /dev/dri/renderD128
let candidates = crate::state::find_drm_render_nodes();
if candidates.is_empty() {
bail!("No DRM render device found. Specify --drm-device.");
}
// 逐个尝试导入 DMA-BUF 帧,找到第一个兼容的设备
let mut failures = Vec::new();
for candidate in &candidates {
match crate::avhw::test_dma_buf_import(candidate, frame) {
Ok(()) => {
// 成功导入,缓存检测结果并返回
tracing::info!(
"Auto-detected DRM device: {} (tested {} candidates)",
candidate.display(),
@@ -214,6 +243,7 @@ impl StatePortal {
return Ok(candidate.clone());
}
Err(e) => {
// 导入失败,记录原因,继续尝试下一个设备
tracing::debug!(
"DRM device {} cannot import DMA-BUF: {e}",
candidate.display(),
@@ -223,8 +253,8 @@ impl StatePortal {
}
}
// 所有候选设备均失败,返回详细错误信息
bail!(
"No DRM render device can import the DMA-BUF frame. Tried: {}",
failures
.into_iter()
.map(|(p, e)| format!("{} ({e})", p.display()))
@@ -238,11 +268,13 @@ impl StatePortal {
/// 通过 `av_hwframe_map` 零拷贝导入 VAAPI然后交给 SwEncState 完成:
/// scale_vaapi GPU 缩放、2K NV12 回读、YUV420P 格式转换、软件 H.264 编码。
fn handle_pw_frame(&mut self, frame: PwDmaBufFrame) -> Result<()> {
// 获取已初始化的编码器引用
let enc = match self.enc.as_mut() {
Some(enc) => enc,
None => bail!("encoder not initialized"),
};
// 将 DMA-BUF 帧零拷贝导入 VAAPI 硬件帧池
let mut vaapi_frame = unsafe {
avhw::import_dma_buf_to_vaapi(
enc.frames_rgb().as_ptr(),
@@ -256,14 +288,17 @@ impl StatePortal {
)
}?;
// 设置帧的显示时间戳PTS基于已编码帧序号
let pts = self.frames_encoded as i64;
unsafe {
(*vaapi_frame.as_mut_ptr()).pts = pts;
}
// 送入编码器完成:缩放 → 回读 → 格式转换 → H.264 编码
enc.encode_frame(&vaapi_frame)?;
self.frames_encoded += 1;
// 每 10 秒输出一次编码统计(已编码帧数、实时帧率)
if let Some(last) = self.last_stats_time {
if last.elapsed() >= Duration::from_secs(10) {
let delta_frames = self.frames_encoded - self.last_stats_frames;
@@ -310,15 +345,35 @@ impl StatePortal {
self.errored
}
/// 轮询 WebRTC 信令通道并转发编码帧
///
/// 处理信令交换、网络轮询,以及从编码通道中取出已编码的 H.264 数据
/// 并通过 WebRTC 发送。
fn poll_webrtc(&mut self) -> Result<()> {
let Some(ref mut wrtc) = self.webrtc else { return Ok(()); };
let Some(ref mut wrtc) = self.webrtc else { return Ok(()) };
wrtc.handle_signaling()?;
wrtc.poll_and_feed()?;
let connected = wrtc.is_connected();
if let Some(ref paused) = self.webrtc_paused {
let was_paused = paused.load(Ordering::Relaxed);
let now_paused = !connected;
if was_paused && !now_paused {
tracing::info!("WebRTC client connected, resuming encoding");
} else if !was_paused && now_paused {
tracing::warn!("WebRTC client disconnected, pausing encoding");
}
paused.store(now_paused, Ordering::Relaxed);
}
if let Some(ref rx) = self.webrtc_rx {
let mut count = 0u32;
while let Ok(data) = rx.try_recv() {
if !connected {
continue;
}
count += 1;
if let Err(e) = wrtc.write_h264_frame(&data, self.webrtc_frames_sent, self.args.fps) {
tracing::debug!("WebRTC write frame error: {e}");
@@ -335,23 +390,31 @@ impl StatePortal {
}
impl Drop for StatePortal {
// 析构时自动调用 shutdown确保编码器被刷新、资源被释放
fn drop(&mut self) {
self.shutdown();
}
}
/// 计算编码目标分辨率
///
/// 将原始分辨率等比缩放至不超过 2560×14402K并确保宽高为偶数
/// H.264 编码要求偶数尺寸)。
fn portal_encode_dimensions(width: u32, height: u32) -> (u32, u32) {
const TARGET_W: u32 = 2560;
const TARGET_H: u32 = 1440;
const TARGET_W: u32 = 2560; // 目标最大宽度
const TARGET_H: u32 = 1440; // 目标最大高度
// 原始分辨率已在 2K 以内,直接对齐偶数
if width <= TARGET_W && height <= TARGET_H {
return (width & !1, height & !1);
return (width & !1, height & !1); // & !1 确保为偶数
}
// 按宽度限制等比缩放
let width_limited_h = ((height as u64) * (TARGET_W as u64) / (width as u64)) as u32;
if width_limited_h <= TARGET_H {
(TARGET_W & !1, width_limited_h & !1)
} else {
// 按高度限制等比缩放
let height_limited_w = ((width as u64) * (TARGET_H as u64) / (height as u64)) as u32;
(height_limited_w & !1, TARGET_H & !1)
}
@@ -367,19 +430,23 @@ fn resolve_drm_device(args: &Args) -> Result<Option<PathBuf>> {
Ok(None)
}
/// 构建测试用的 AVDRMFrameDescriptor仅测试用途
///
/// 将 PwDmaBufFrame 转换为 FFmpeg 的 DRM 帧描述符结构体,
/// 用于验证 DMA-BUF 元数据映射的正确性。
#[cfg(test)]
fn build_drm_descriptor(frame: &PwDmaBufFrame) -> ffmpeg_next::ffi::AVDRMFrameDescriptor {
let mut desc: ffmpeg_next::ffi::AVDRMFrameDescriptor = unsafe { std::mem::zeroed() };
desc.nb_objects = 1;
desc.objects[0].fd = frame.fd.as_raw_fd();
desc.objects[0].size = 0;
desc.objects[0].format_modifier = frame.modifier;
desc.nb_layers = 1;
desc.layers[0].format = frame.format;
desc.layers[0].nb_planes = 1;
desc.layers[0].planes[0].object_index = 0;
desc.layers[0].planes[0].offset = frame.offset as isize;
desc.layers[0].planes[0].pitch = frame.stride as isize;
desc.nb_objects = 1; // 单个 DMA-BUF 对象
desc.objects[0].fd = frame.fd.as_raw_fd(); // DMA-BUF 文件描述符
desc.objects[0].size = 0; // 大小设为 0内核自动确定
desc.objects[0].format_modifier = frame.modifier; // DRM 格式修饰符如线性、tiled
desc.nb_layers = 1; // 单层
desc.layers[0].format = frame.format; // 像素格式(如 XR24
desc.layers[0].nb_planes = 1; // 单平面
desc.layers[0].planes[0].object_index = 0; // 指向第 0 个对象
desc.layers[0].planes[0].offset = frame.offset as isize; // 帧数据偏移
desc.layers[0].planes[0].pitch = frame.stride as isize; // 行跨度stride
desc
}
@@ -391,15 +458,16 @@ mod tests {
/// 创建测试用的 DMA-BUF 帧数据(使用 stderr fd 的副本作为占位)
fn make_test_frame() -> PwDmaBufFrame {
// Create a dummy fd from stderr (always valid fd 2)
// 使用 stderrfd 2的副本作为虚拟文件描述符
let fd = unsafe { OwnedFd::from_raw_fd(libc::dup(2)) };
PwDmaBufFrame {
fd,
offset: 0,
stride: 1920 * 4,
modifier: 0, // DRM_FORMAT_MOD_LINEAR
stride: 1920 * 4, // 每行 1920 像素 × 4 字节XRGB
modifier: 0, // DRM_FORMAT_MOD_LINEAR(线性布局)
width: 1920,
height: 1080,
format: 0x34325258, // XR24 little-endian
format: 0x34325258, // XR24 little-endianXRGB8888
pts: 12345,
}
}
@@ -464,13 +532,14 @@ mod tests {
assert_eq!(result, None);
}
/// 测试:使用自定义偏移量和 stride 构建 DRM 描述符
#[test]
fn build_drm_descriptor_custom_offset_and_stride() {
let frame = PwDmaBufFrame {
fd: unsafe { OwnedFd::from_raw_fd(libc::dup(2)) },
offset: 4096,
stride: 3840 * 4,
modifier: 0x0100000000000001, // AMD modifiers
offset: 4096, // 4KB 对齐偏移
stride: 3840 * 4, // 4K 宽度 × 4 字节
modifier: 0x0100000000000001, // AMD modifiers
width: 3840,
height: 2160,
format: 0x34325258,