fix(state_portal): add Drop impl, null dangling pointers, extract compute_pts, add tests

- Add Drop impl for StatePortal to flush encoder on drop (bug #2)
- Use enc.take() in shutdown() to prevent double-flush of write_trailer
- Null out data[0] after Box::from_raw recovery to avoid dangling pointer
- Extract compute_pts() for testable PTS calculation
- Add 8 tests: PTS calculation, DRM device resolution, descriptor building
This commit is contained in:
dailz
2026-05-27 09:22:59 +08:00
parent 5100d78aa8
commit 60a55c17f2

View File

@@ -9,7 +9,7 @@ use ffmpeg_next::ffi;
use crate::args::Args;
use crate::avhw::{self, EncState};
use crate::cap_portal::{CapPortal, PwDmaBufFrame, PwEvent};
use crate::cap_portal::{CapPortal, PwCtrlEvent, PwDmaBufFrame};
use crate::fps_limit::FpsLimit;
use crate::transform::Transform;
@@ -78,57 +78,56 @@ impl StatePortal {
/// 尝试从采集端点接收一帧事件。返回 `Ok(true)` 表示已处理事件,
/// `Ok(false)` 表示暂无数据。内部根据当前阶段(等待格式/流式)分发处理。
pub fn poll_and_encode(&mut self) -> Result<bool> {
let event = match self.cap.frame_receiver().try_recv() {
Ok(event) => event,
if let Ok(ctrl) = self.cap.event_receiver().try_recv() {
match ctrl {
PwCtrlEvent::StreamEnded => {
tracing::warn!("PipeWire stream ended");
self.errored = true;
return Ok(true);
}
PwCtrlEvent::Error(e) => {
tracing::error!("PipeWire error: {e}");
self.errored = true;
return Ok(true);
}
}
}
let frame = match self.cap.frame_receiver().try_recv() {
Ok(frame) => frame,
Err(_) => return Ok(false),
};
match event {
PwEvent::Frame(frame) => {
match self.stage {
PortalStage::WaitingForFormat => {
// 第一帧到达:记录格式信息并用该分辨率创建编码器
tracing::info!(
"First DMA-BUF frame: {}x{} format=0x{:08X} stride={} modifier=0x{:X}",
frame.width,
frame.height,
frame.format,
frame.stride,
frame.modifier
);
match self.stage {
PortalStage::WaitingForFormat => {
tracing::info!(
"First DMA-BUF frame: {}x{} format=0x{:08X} stride={} modifier=0x{:X}",
frame.width,
frame.height,
frame.format,
frame.stride,
frame.modifier
);
let drm_path = self.resolve_drm_device_for_frame(&frame)?;
let enc = avhw::create_encoder(
&drm_path,
self.args.output.as_ref(),
frame.width,
frame.height,
self.args.fps,
Transform::Normal,
self.args.bitrate,
self.args.gop_size,
None,
)?;
let drm_path = self.resolve_drm_device_for_frame(&frame)?;
let enc = avhw::create_encoder(
&drm_path,
self.args.output.as_ref(),
frame.width,
frame.height,
self.args.fps,
Transform::Normal,
self.args.bitrate,
self.args.gop_size,
None,
)?;
self.enc = Some(enc);
self.stage = PortalStage::Streaming;
drop(frame);
}
PortalStage::Streaming => {
// 流式阶段:处理每一帧 DMA-BUF 数据
self.handle_pw_frame(frame)?;
}
}
self.enc = Some(enc);
self.stage = PortalStage::Streaming;
drop(frame);
}
PwEvent::StreamEnded => {
// PipeWire 流结束(如用户停止了屏幕共享)
tracing::warn!("PipeWire stream ended");
self.errored = true;
}
PwEvent::Error(e) => {
// PipeWire 返回错误
tracing::error!("PipeWire error: {e}");
self.errored = true;
PortalStage::Streaming => {
self.handle_pw_frame(frame)?;
}
}
@@ -224,6 +223,7 @@ impl StatePortal {
if !desc_ptr.is_null() {
let _ = Box::from_raw(desc_ptr);
}
(*raw_frame.as_mut_ptr()).data[0] = std::ptr::null_mut();
}
bail!("encoder not initialized");
}
@@ -243,6 +243,7 @@ impl StatePortal {
if !desc_ptr.is_null() {
let _ = Box::from_raw(desc_ptr);
}
(*raw_frame.as_mut_ptr()).data[0] = std::ptr::null_mut();
}
bail!("av_hwframe_get_buffer failed: error {ret}");
}
@@ -259,6 +260,7 @@ impl StatePortal {
if !desc_ptr.is_null() {
let _ = Box::from_raw(desc_ptr);
}
(*raw_frame.as_mut_ptr()).data[0] = std::ptr::null_mut();
}
if ret == -(ffi::EINVAL as i32) {
bail!(
@@ -270,17 +272,7 @@ impl StatePortal {
}
// 7. Set PTS — convert PipeWire nanoseconds to encoder frame-number units
// PipeWire PTS is CLOCK_MONOTONIC in nanoseconds.
// Encoder time_base = 1/fps, so PTS must be in frame numbers.
// Use elapsed time since first frame to avoid i64 overflow on absolute timestamps.
//
// PTS 计算:将 PipeWire 的纳秒时间戳转换为编码器的帧号单位
// PipeWire 使用 CLOCK_MONOTONIC 纳秒时间戳,编码器 time_base = 1/fps
// 使用相对时间避免绝对时间戳导致的 i64 溢出
let fps_i64 = self.args.fps as i64;
let base_ns = *self.first_pts_ns.get_or_insert(frame.pts.max(0));
let elapsed_ns = (frame.pts.max(0) - base_ns).max(0);
let pts = elapsed_ns * fps_i64 / 1_000_000_000;
let pts = compute_pts(&mut self.first_pts_ns, frame.pts, self.args.fps);
unsafe {
(*hw_frame.as_mut_ptr()).pts = pts;
}
@@ -299,6 +291,7 @@ impl StatePortal {
if !desc_ptr.is_null() {
let _ = Box::from_raw(desc_ptr);
}
(*raw_frame.as_mut_ptr()).data[0] = std::ptr::null_mut();
}
// 9. Encode — safe to early-return via `?` now that descriptor is recovered.
@@ -310,18 +303,14 @@ impl StatePortal {
Ok(())
}
/// 刷新编码器缓冲区,输出所有剩余帧
pub fn flush(&mut self) -> Result<()> {
if let Some(enc) = &mut self.enc {
enc.flush()?;
}
Ok(())
}
/// 关闭状态:刷新编码器并清理资源
///
/// 使用 `enc.take()` 确保编码器只被 flush 一次,即使多次调用也安全(幂等)。
pub fn shutdown(&mut self) {
if let Err(e) = self.flush() {
tracing::error!("Flush error during shutdown: {e}");
if let Some(mut enc) = self.enc.take() {
if let Err(e) = enc.flush() {
tracing::error!("Flush error during shutdown: {e}");
}
}
tracing::info!("StatePortal shutdown complete");
}
@@ -332,6 +321,12 @@ impl StatePortal {
}
}
impl Drop for StatePortal {
fn drop(&mut self) {
self.shutdown();
}
}
/// 根据 DMA-BUF 帧信息构建 FFmpeg DRM 帧描述符
///
/// 将 PipeWire 提供的 DMA-BUF 参数fd、偏移量、步长、修饰符等
@@ -356,6 +351,17 @@ fn build_drm_descriptor(frame: &PwDmaBufFrame) -> ffi::AVDRMFrameDescriptor {
desc
}
/// Convert PipeWire nanosecond PTS to encoder frame-number units.
///
/// Uses elapsed time since the first frame to avoid i64 overflow on absolute timestamps.
/// PipeWire PTS is CLOCK_MONOTONIC in nanoseconds; encoder time_base = 1/fps.
fn compute_pts(first_pts_ns: &mut Option<i64>, frame_pts: i64, fps: u32) -> i64 {
let fps_i64 = fps as i64;
let base_ns = *first_pts_ns.get_or_insert(frame_pts.max(0));
let elapsed_ns = (frame_pts.max(0) - base_ns).max(0);
elapsed_ns * fps_i64 / 1_000_000_000
}
/// 解析 DRM 渲染设备路径
///
/// 仅使用命令行指定的设备路径;未指定则在首帧到达时自动检测。
@@ -419,8 +425,111 @@ mod tests {
backend: None,
port: 0,
};
let result = resolve_drm_device(&args);
assert!(result.is_ok());
assert_eq!(result.unwrap(), std::path::PathBuf::from("/dev/dri/renderD128"));
let result = resolve_drm_device(&args).unwrap();
assert_eq!(result, Some(std::path::PathBuf::from("/dev/dri/renderD128")));
}
#[test]
fn resolve_drm_device_none_when_not_specified() {
let args = Args {
output: "test.mp4".to_string(),
output_name: None,
fps: 30,
codec: "h264".to_string(),
hw_accel: "vaapi".to_string(),
drm_device: None,
bitrate: None,
gop_size: None,
verbose: false,
backend: None,
port: 0,
};
let result = resolve_drm_device(&args).unwrap();
assert_eq!(result, None);
}
#[test]
fn build_drm_descriptor_custom_offset_and_stride() {
let frame = PwDmaBufFrame {
fd: unsafe { OwnedFd::from_raw_fd(libc::dup(2)) },
offset: 4096,
stride: 3840 * 4,
modifier: 0x0100000000000001, // AMD modifiers
width: 3840,
height: 2160,
format: 0x34325258,
pts: 0,
};
let desc = build_drm_descriptor(&frame);
assert_eq!(desc.nb_objects, 1);
assert_eq!(desc.objects[0].format_modifier, 0x0100000000000001);
assert_eq!(desc.layers[0].planes[0].offset, 4096);
assert_eq!(desc.layers[0].planes[0].pitch, 3840 * 4);
}
// --- compute_pts tests ---
#[test]
fn compute_pts_first_frame_is_zero() {
let mut base = None;
let pts = compute_pts(&mut base, 1_000_000_000, 30);
assert_eq!(pts, 0);
assert_eq!(base, Some(1_000_000_000));
}
#[test]
fn compute_pts_second_frame_at_30fps() {
let mut base = Some(1_000_000_000);
// 33_333_333 * 30 / 1_000_000_000 = 0 (integer division)
let pts = compute_pts(&mut base, 1_000_000_000 + 33_333_333, 30);
assert_eq!(pts, 0);
// 100ms later = frame 3
let pts = compute_pts(&mut base, 1_000_000_000 + 100_000_000, 30);
assert_eq!(pts, 3);
}
#[test]
fn compute_pts_multiple_frames_accumulate() {
let mut base = None;
let fps = 60;
let pts0 = compute_pts(&mut base, 0, fps);
assert_eq!(pts0, 0);
let pts1 = compute_pts(&mut base, 16_666_666, fps);
assert_eq!(pts1, 0); // 16_666_666 * 60 / 1_000_000_000 = 0
let pts2 = compute_pts(&mut base, 33_333_333, fps);
assert_eq!(pts2, 1); // 33_333_333 * 60 / 1_000_000_000 = 1
let pts3 = compute_pts(&mut base, 50_000_000, fps);
assert_eq!(pts3, 3); // 50ms * 60 / 1000 = 3
}
#[test]
fn compute_pts_negative_pts_clamped_to_zero() {
let mut base = None;
let pts = compute_pts(&mut base, -999_999, 30);
assert_eq!(pts, 0);
assert_eq!(base, Some(0)); // max(0) clamps negative
}
#[test]
fn compute_pts_late_frame_after_negative() {
let mut base = Some(0);
let pts = compute_pts(&mut base, 1_000_000_000, 30);
assert_eq!(pts, 30);
}
#[test]
fn compute_pts_base_not_overwritten_after_first_call() {
let mut base = None;
let _ = compute_pts(&mut base, 5_000_000_000, 30);
assert_eq!(base, Some(5_000_000_000));
let _ = compute_pts(&mut base, 10_000_000_000, 30);
assert_eq!(base, Some(5_000_000_000)); // base stays at first frame
}
}