fix(state_portal): add Drop impl, null dangling pointers, extract compute_pts, add tests

- Add Drop impl for StatePortal to flush encoder on drop (bug #2)
- Use enc.take() in shutdown() to prevent double-flush of write_trailer
- Null out data[0] after Box::from_raw recovery to avoid dangling pointer
- Extract compute_pts() for testable PTS calculation
- Add 8 tests: PTS calculation, DRM device resolution, descriptor building
This commit is contained in:
dailz
2026-05-27 09:22:59 +08:00
parent 5100d78aa8
commit 60a55c17f2

View File

@@ -9,7 +9,7 @@ use ffmpeg_next::ffi;
use crate::args::Args; use crate::args::Args;
use crate::avhw::{self, EncState}; use crate::avhw::{self, EncState};
use crate::cap_portal::{CapPortal, PwDmaBufFrame, PwEvent}; use crate::cap_portal::{CapPortal, PwCtrlEvent, PwDmaBufFrame};
use crate::fps_limit::FpsLimit; use crate::fps_limit::FpsLimit;
use crate::transform::Transform; use crate::transform::Transform;
@@ -78,16 +78,28 @@ impl StatePortal {
/// 尝试从采集端点接收一帧事件。返回 `Ok(true)` 表示已处理事件, /// 尝试从采集端点接收一帧事件。返回 `Ok(true)` 表示已处理事件,
/// `Ok(false)` 表示暂无数据。内部根据当前阶段(等待格式/流式)分发处理。 /// `Ok(false)` 表示暂无数据。内部根据当前阶段(等待格式/流式)分发处理。
pub fn poll_and_encode(&mut self) -> Result<bool> { pub fn poll_and_encode(&mut self) -> Result<bool> {
let event = match self.cap.frame_receiver().try_recv() { if let Ok(ctrl) = self.cap.event_receiver().try_recv() {
Ok(event) => event, match ctrl {
PwCtrlEvent::StreamEnded => {
tracing::warn!("PipeWire stream ended");
self.errored = true;
return Ok(true);
}
PwCtrlEvent::Error(e) => {
tracing::error!("PipeWire error: {e}");
self.errored = true;
return Ok(true);
}
}
}
let frame = match self.cap.frame_receiver().try_recv() {
Ok(frame) => frame,
Err(_) => return Ok(false), Err(_) => return Ok(false),
}; };
match event {
PwEvent::Frame(frame) => {
match self.stage { match self.stage {
PortalStage::WaitingForFormat => { PortalStage::WaitingForFormat => {
// 第一帧到达:记录格式信息并用该分辨率创建编码器
tracing::info!( tracing::info!(
"First DMA-BUF frame: {}x{} format=0x{:08X} stride={} modifier=0x{:X}", "First DMA-BUF frame: {}x{} format=0x{:08X} stride={} modifier=0x{:X}",
frame.width, frame.width,
@@ -115,22 +127,9 @@ impl StatePortal {
drop(frame); drop(frame);
} }
PortalStage::Streaming => { PortalStage::Streaming => {
// 流式阶段:处理每一帧 DMA-BUF 数据
self.handle_pw_frame(frame)?; self.handle_pw_frame(frame)?;
} }
} }
}
PwEvent::StreamEnded => {
// PipeWire 流结束(如用户停止了屏幕共享)
tracing::warn!("PipeWire stream ended");
self.errored = true;
}
PwEvent::Error(e) => {
// PipeWire 返回错误
tracing::error!("PipeWire error: {e}");
self.errored = true;
}
}
Ok(true) Ok(true)
} }
@@ -224,6 +223,7 @@ impl StatePortal {
if !desc_ptr.is_null() { if !desc_ptr.is_null() {
let _ = Box::from_raw(desc_ptr); let _ = Box::from_raw(desc_ptr);
} }
(*raw_frame.as_mut_ptr()).data[0] = std::ptr::null_mut();
} }
bail!("encoder not initialized"); bail!("encoder not initialized");
} }
@@ -243,6 +243,7 @@ impl StatePortal {
if !desc_ptr.is_null() { if !desc_ptr.is_null() {
let _ = Box::from_raw(desc_ptr); let _ = Box::from_raw(desc_ptr);
} }
(*raw_frame.as_mut_ptr()).data[0] = std::ptr::null_mut();
} }
bail!("av_hwframe_get_buffer failed: error {ret}"); bail!("av_hwframe_get_buffer failed: error {ret}");
} }
@@ -259,6 +260,7 @@ impl StatePortal {
if !desc_ptr.is_null() { if !desc_ptr.is_null() {
let _ = Box::from_raw(desc_ptr); let _ = Box::from_raw(desc_ptr);
} }
(*raw_frame.as_mut_ptr()).data[0] = std::ptr::null_mut();
} }
if ret == -(ffi::EINVAL as i32) { if ret == -(ffi::EINVAL as i32) {
bail!( bail!(
@@ -270,17 +272,7 @@ impl StatePortal {
} }
// 7. Set PTS — convert PipeWire nanoseconds to encoder frame-number units // 7. Set PTS — convert PipeWire nanoseconds to encoder frame-number units
// PipeWire PTS is CLOCK_MONOTONIC in nanoseconds. let pts = compute_pts(&mut self.first_pts_ns, frame.pts, self.args.fps);
// Encoder time_base = 1/fps, so PTS must be in frame numbers.
// Use elapsed time since first frame to avoid i64 overflow on absolute timestamps.
//
// PTS 计算:将 PipeWire 的纳秒时间戳转换为编码器的帧号单位
// PipeWire 使用 CLOCK_MONOTONIC 纳秒时间戳,编码器 time_base = 1/fps
// 使用相对时间避免绝对时间戳导致的 i64 溢出
let fps_i64 = self.args.fps as i64;
let base_ns = *self.first_pts_ns.get_or_insert(frame.pts.max(0));
let elapsed_ns = (frame.pts.max(0) - base_ns).max(0);
let pts = elapsed_ns * fps_i64 / 1_000_000_000;
unsafe { unsafe {
(*hw_frame.as_mut_ptr()).pts = pts; (*hw_frame.as_mut_ptr()).pts = pts;
} }
@@ -299,6 +291,7 @@ impl StatePortal {
if !desc_ptr.is_null() { if !desc_ptr.is_null() {
let _ = Box::from_raw(desc_ptr); let _ = Box::from_raw(desc_ptr);
} }
(*raw_frame.as_mut_ptr()).data[0] = std::ptr::null_mut();
} }
// 9. Encode — safe to early-return via `?` now that descriptor is recovered. // 9. Encode — safe to early-return via `?` now that descriptor is recovered.
@@ -310,19 +303,15 @@ impl StatePortal {
Ok(()) Ok(())
} }
/// 刷新编码器缓冲区,输出所有剩余帧
pub fn flush(&mut self) -> Result<()> {
if let Some(enc) = &mut self.enc {
enc.flush()?;
}
Ok(())
}
/// 关闭状态:刷新编码器并清理资源 /// 关闭状态:刷新编码器并清理资源
///
/// 使用 `enc.take()` 确保编码器只被 flush 一次,即使多次调用也安全(幂等)。
pub fn shutdown(&mut self) { pub fn shutdown(&mut self) {
if let Err(e) = self.flush() { if let Some(mut enc) = self.enc.take() {
if let Err(e) = enc.flush() {
tracing::error!("Flush error during shutdown: {e}"); tracing::error!("Flush error during shutdown: {e}");
} }
}
tracing::info!("StatePortal shutdown complete"); tracing::info!("StatePortal shutdown complete");
} }
@@ -332,6 +321,12 @@ impl StatePortal {
} }
} }
impl Drop for StatePortal {
fn drop(&mut self) {
self.shutdown();
}
}
/// 根据 DMA-BUF 帧信息构建 FFmpeg DRM 帧描述符 /// 根据 DMA-BUF 帧信息构建 FFmpeg DRM 帧描述符
/// ///
/// 将 PipeWire 提供的 DMA-BUF 参数fd、偏移量、步长、修饰符等 /// 将 PipeWire 提供的 DMA-BUF 参数fd、偏移量、步长、修饰符等
@@ -356,6 +351,17 @@ fn build_drm_descriptor(frame: &PwDmaBufFrame) -> ffi::AVDRMFrameDescriptor {
desc desc
} }
/// Convert PipeWire nanosecond PTS to encoder frame-number units.
///
/// Uses elapsed time since the first frame to avoid i64 overflow on absolute timestamps.
/// PipeWire PTS is CLOCK_MONOTONIC in nanoseconds; encoder time_base = 1/fps.
fn compute_pts(first_pts_ns: &mut Option<i64>, frame_pts: i64, fps: u32) -> i64 {
let fps_i64 = fps as i64;
let base_ns = *first_pts_ns.get_or_insert(frame_pts.max(0));
let elapsed_ns = (frame_pts.max(0) - base_ns).max(0);
elapsed_ns * fps_i64 / 1_000_000_000
}
/// 解析 DRM 渲染设备路径 /// 解析 DRM 渲染设备路径
/// ///
/// 仅使用命令行指定的设备路径;未指定则在首帧到达时自动检测。 /// 仅使用命令行指定的设备路径;未指定则在首帧到达时自动检测。
@@ -419,8 +425,111 @@ mod tests {
backend: None, backend: None,
port: 0, port: 0,
}; };
let result = resolve_drm_device(&args); let result = resolve_drm_device(&args).unwrap();
assert!(result.is_ok()); assert_eq!(result, Some(std::path::PathBuf::from("/dev/dri/renderD128")));
assert_eq!(result.unwrap(), std::path::PathBuf::from("/dev/dri/renderD128")); }
#[test]
fn resolve_drm_device_none_when_not_specified() {
let args = Args {
output: "test.mp4".to_string(),
output_name: None,
fps: 30,
codec: "h264".to_string(),
hw_accel: "vaapi".to_string(),
drm_device: None,
bitrate: None,
gop_size: None,
verbose: false,
backend: None,
port: 0,
};
let result = resolve_drm_device(&args).unwrap();
assert_eq!(result, None);
}
#[test]
fn build_drm_descriptor_custom_offset_and_stride() {
let frame = PwDmaBufFrame {
fd: unsafe { OwnedFd::from_raw_fd(libc::dup(2)) },
offset: 4096,
stride: 3840 * 4,
modifier: 0x0100000000000001, // AMD modifiers
width: 3840,
height: 2160,
format: 0x34325258,
pts: 0,
};
let desc = build_drm_descriptor(&frame);
assert_eq!(desc.nb_objects, 1);
assert_eq!(desc.objects[0].format_modifier, 0x0100000000000001);
assert_eq!(desc.layers[0].planes[0].offset, 4096);
assert_eq!(desc.layers[0].planes[0].pitch, 3840 * 4);
}
// --- compute_pts tests ---
#[test]
fn compute_pts_first_frame_is_zero() {
let mut base = None;
let pts = compute_pts(&mut base, 1_000_000_000, 30);
assert_eq!(pts, 0);
assert_eq!(base, Some(1_000_000_000));
}
#[test]
fn compute_pts_second_frame_at_30fps() {
let mut base = Some(1_000_000_000);
// 33_333_333 * 30 / 1_000_000_000 = 0 (integer division)
let pts = compute_pts(&mut base, 1_000_000_000 + 33_333_333, 30);
assert_eq!(pts, 0);
// 100ms later = frame 3
let pts = compute_pts(&mut base, 1_000_000_000 + 100_000_000, 30);
assert_eq!(pts, 3);
}
#[test]
fn compute_pts_multiple_frames_accumulate() {
let mut base = None;
let fps = 60;
let pts0 = compute_pts(&mut base, 0, fps);
assert_eq!(pts0, 0);
let pts1 = compute_pts(&mut base, 16_666_666, fps);
assert_eq!(pts1, 0); // 16_666_666 * 60 / 1_000_000_000 = 0
let pts2 = compute_pts(&mut base, 33_333_333, fps);
assert_eq!(pts2, 1); // 33_333_333 * 60 / 1_000_000_000 = 1
let pts3 = compute_pts(&mut base, 50_000_000, fps);
assert_eq!(pts3, 3); // 50ms * 60 / 1000 = 3
}
#[test]
fn compute_pts_negative_pts_clamped_to_zero() {
let mut base = None;
let pts = compute_pts(&mut base, -999_999, 30);
assert_eq!(pts, 0);
assert_eq!(base, Some(0)); // max(0) clamps negative
}
#[test]
fn compute_pts_late_frame_after_negative() {
let mut base = Some(0);
let pts = compute_pts(&mut base, 1_000_000_000, 30);
assert_eq!(pts, 30);
}
#[test]
fn compute_pts_base_not_overwritten_after_first_call() {
let mut base = None;
let _ = compute_pts(&mut base, 5_000_000_000, 30);
assert_eq!(base, Some(5_000_000_000));
let _ = compute_pts(&mut base, 10_000_000_000, 30);
assert_eq!(base, Some(5_000_000_000)); // base stays at first frame
} }
} }