diff --git a/src/state_portal.rs b/src/state_portal.rs index 0ab9655..8cdc686 100644 --- a/src/state_portal.rs +++ b/src/state_portal.rs @@ -9,7 +9,7 @@ use ffmpeg_next::ffi; use crate::args::Args; use crate::avhw::{self, EncState}; -use crate::cap_portal::{CapPortal, PwDmaBufFrame, PwEvent}; +use crate::cap_portal::{CapPortal, PwCtrlEvent, PwDmaBufFrame}; use crate::fps_limit::FpsLimit; use crate::transform::Transform; @@ -78,57 +78,56 @@ impl StatePortal { /// 尝试从采集端点接收一帧事件。返回 `Ok(true)` 表示已处理事件, /// `Ok(false)` 表示暂无数据。内部根据当前阶段(等待格式/流式)分发处理。 pub fn poll_and_encode(&mut self) -> Result { - let event = match self.cap.frame_receiver().try_recv() { - Ok(event) => event, + if let Ok(ctrl) = self.cap.event_receiver().try_recv() { + match ctrl { + PwCtrlEvent::StreamEnded => { + tracing::warn!("PipeWire stream ended"); + self.errored = true; + return Ok(true); + } + PwCtrlEvent::Error(e) => { + tracing::error!("PipeWire error: {e}"); + self.errored = true; + return Ok(true); + } + } + } + + let frame = match self.cap.frame_receiver().try_recv() { + Ok(frame) => frame, Err(_) => return Ok(false), }; - match event { - PwEvent::Frame(frame) => { - match self.stage { - PortalStage::WaitingForFormat => { - // 第一帧到达:记录格式信息并用该分辨率创建编码器 - tracing::info!( - "First DMA-BUF frame: {}x{} format=0x{:08X} stride={} modifier=0x{:X}", - frame.width, - frame.height, - frame.format, - frame.stride, - frame.modifier - ); + match self.stage { + PortalStage::WaitingForFormat => { + tracing::info!( + "First DMA-BUF frame: {}x{} format=0x{:08X} stride={} modifier=0x{:X}", + frame.width, + frame.height, + frame.format, + frame.stride, + frame.modifier + ); - let drm_path = self.resolve_drm_device_for_frame(&frame)?; - let enc = avhw::create_encoder( - &drm_path, - self.args.output.as_ref(), - frame.width, - frame.height, - self.args.fps, - Transform::Normal, - self.args.bitrate, - self.args.gop_size, - None, - )?; + let drm_path = self.resolve_drm_device_for_frame(&frame)?; + let enc = avhw::create_encoder( + &drm_path, + self.args.output.as_ref(), + frame.width, + frame.height, + self.args.fps, + Transform::Normal, + self.args.bitrate, + self.args.gop_size, + None, + )?; - self.enc = Some(enc); - self.stage = PortalStage::Streaming; - drop(frame); - } - PortalStage::Streaming => { - // 流式阶段:处理每一帧 DMA-BUF 数据 - self.handle_pw_frame(frame)?; - } - } + self.enc = Some(enc); + self.stage = PortalStage::Streaming; + drop(frame); } - PwEvent::StreamEnded => { - // PipeWire 流结束(如用户停止了屏幕共享) - tracing::warn!("PipeWire stream ended"); - self.errored = true; - } - PwEvent::Error(e) => { - // PipeWire 返回错误 - tracing::error!("PipeWire error: {e}"); - self.errored = true; + PortalStage::Streaming => { + self.handle_pw_frame(frame)?; } } @@ -224,6 +223,7 @@ impl StatePortal { if !desc_ptr.is_null() { let _ = Box::from_raw(desc_ptr); } + (*raw_frame.as_mut_ptr()).data[0] = std::ptr::null_mut(); } bail!("encoder not initialized"); } @@ -243,6 +243,7 @@ impl StatePortal { if !desc_ptr.is_null() { let _ = Box::from_raw(desc_ptr); } + (*raw_frame.as_mut_ptr()).data[0] = std::ptr::null_mut(); } bail!("av_hwframe_get_buffer failed: error {ret}"); } @@ -259,6 +260,7 @@ impl StatePortal { if !desc_ptr.is_null() { let _ = Box::from_raw(desc_ptr); } + (*raw_frame.as_mut_ptr()).data[0] = std::ptr::null_mut(); } if ret == -(ffi::EINVAL as i32) { bail!( @@ -270,17 +272,7 @@ impl StatePortal { } // 7. Set PTS — convert PipeWire nanoseconds to encoder frame-number units - // PipeWire PTS is CLOCK_MONOTONIC in nanoseconds. - // Encoder time_base = 1/fps, so PTS must be in frame numbers. - // Use elapsed time since first frame to avoid i64 overflow on absolute timestamps. - // - // PTS 计算:将 PipeWire 的纳秒时间戳转换为编码器的帧号单位 - // PipeWire 使用 CLOCK_MONOTONIC 纳秒时间戳,编码器 time_base = 1/fps - // 使用相对时间避免绝对时间戳导致的 i64 溢出 - let fps_i64 = self.args.fps as i64; - let base_ns = *self.first_pts_ns.get_or_insert(frame.pts.max(0)); - let elapsed_ns = (frame.pts.max(0) - base_ns).max(0); - let pts = elapsed_ns * fps_i64 / 1_000_000_000; + let pts = compute_pts(&mut self.first_pts_ns, frame.pts, self.args.fps); unsafe { (*hw_frame.as_mut_ptr()).pts = pts; } @@ -299,6 +291,7 @@ impl StatePortal { if !desc_ptr.is_null() { let _ = Box::from_raw(desc_ptr); } + (*raw_frame.as_mut_ptr()).data[0] = std::ptr::null_mut(); } // 9. Encode — safe to early-return via `?` now that descriptor is recovered. @@ -310,18 +303,14 @@ impl StatePortal { Ok(()) } - /// 刷新编码器缓冲区,输出所有剩余帧 - pub fn flush(&mut self) -> Result<()> { - if let Some(enc) = &mut self.enc { - enc.flush()?; - } - Ok(()) - } - /// 关闭状态:刷新编码器并清理资源 + /// + /// 使用 `enc.take()` 确保编码器只被 flush 一次,即使多次调用也安全(幂等)。 pub fn shutdown(&mut self) { - if let Err(e) = self.flush() { - tracing::error!("Flush error during shutdown: {e}"); + if let Some(mut enc) = self.enc.take() { + if let Err(e) = enc.flush() { + tracing::error!("Flush error during shutdown: {e}"); + } } tracing::info!("StatePortal shutdown complete"); } @@ -332,6 +321,12 @@ impl StatePortal { } } +impl Drop for StatePortal { + fn drop(&mut self) { + self.shutdown(); + } +} + /// 根据 DMA-BUF 帧信息构建 FFmpeg DRM 帧描述符 /// /// 将 PipeWire 提供的 DMA-BUF 参数(fd、偏移量、步长、修饰符等) @@ -356,6 +351,17 @@ fn build_drm_descriptor(frame: &PwDmaBufFrame) -> ffi::AVDRMFrameDescriptor { desc } +/// Convert PipeWire nanosecond PTS to encoder frame-number units. +/// +/// Uses elapsed time since the first frame to avoid i64 overflow on absolute timestamps. +/// PipeWire PTS is CLOCK_MONOTONIC in nanoseconds; encoder time_base = 1/fps. +fn compute_pts(first_pts_ns: &mut Option, frame_pts: i64, fps: u32) -> i64 { + let fps_i64 = fps as i64; + let base_ns = *first_pts_ns.get_or_insert(frame_pts.max(0)); + let elapsed_ns = (frame_pts.max(0) - base_ns).max(0); + elapsed_ns * fps_i64 / 1_000_000_000 +} + /// 解析 DRM 渲染设备路径 /// /// 仅使用命令行指定的设备路径;未指定则在首帧到达时自动检测。 @@ -419,8 +425,111 @@ mod tests { backend: None, port: 0, }; - let result = resolve_drm_device(&args); - assert!(result.is_ok()); - assert_eq!(result.unwrap(), std::path::PathBuf::from("/dev/dri/renderD128")); + let result = resolve_drm_device(&args).unwrap(); + assert_eq!(result, Some(std::path::PathBuf::from("/dev/dri/renderD128"))); + } + + #[test] + fn resolve_drm_device_none_when_not_specified() { + let args = Args { + output: "test.mp4".to_string(), + output_name: None, + fps: 30, + codec: "h264".to_string(), + hw_accel: "vaapi".to_string(), + drm_device: None, + bitrate: None, + gop_size: None, + verbose: false, + backend: None, + port: 0, + }; + let result = resolve_drm_device(&args).unwrap(); + assert_eq!(result, None); + } + + #[test] + fn build_drm_descriptor_custom_offset_and_stride() { + let frame = PwDmaBufFrame { + fd: unsafe { OwnedFd::from_raw_fd(libc::dup(2)) }, + offset: 4096, + stride: 3840 * 4, + modifier: 0x0100000000000001, // AMD modifiers + width: 3840, + height: 2160, + format: 0x34325258, + pts: 0, + }; + let desc = build_drm_descriptor(&frame); + + assert_eq!(desc.nb_objects, 1); + assert_eq!(desc.objects[0].format_modifier, 0x0100000000000001); + assert_eq!(desc.layers[0].planes[0].offset, 4096); + assert_eq!(desc.layers[0].planes[0].pitch, 3840 * 4); + } + + // --- compute_pts tests --- + + #[test] + fn compute_pts_first_frame_is_zero() { + let mut base = None; + let pts = compute_pts(&mut base, 1_000_000_000, 30); + assert_eq!(pts, 0); + assert_eq!(base, Some(1_000_000_000)); + } + + #[test] + fn compute_pts_second_frame_at_30fps() { + let mut base = Some(1_000_000_000); + // 33_333_333 * 30 / 1_000_000_000 = 0 (integer division) + let pts = compute_pts(&mut base, 1_000_000_000 + 33_333_333, 30); + assert_eq!(pts, 0); + + // 100ms later = frame 3 + let pts = compute_pts(&mut base, 1_000_000_000 + 100_000_000, 30); + assert_eq!(pts, 3); + } + + #[test] + fn compute_pts_multiple_frames_accumulate() { + let mut base = None; + let fps = 60; + + let pts0 = compute_pts(&mut base, 0, fps); + assert_eq!(pts0, 0); + + let pts1 = compute_pts(&mut base, 16_666_666, fps); + assert_eq!(pts1, 0); // 16_666_666 * 60 / 1_000_000_000 = 0 + + let pts2 = compute_pts(&mut base, 33_333_333, fps); + assert_eq!(pts2, 1); // 33_333_333 * 60 / 1_000_000_000 = 1 + + let pts3 = compute_pts(&mut base, 50_000_000, fps); + assert_eq!(pts3, 3); // 50ms * 60 / 1000 = 3 + } + + #[test] + fn compute_pts_negative_pts_clamped_to_zero() { + let mut base = None; + let pts = compute_pts(&mut base, -999_999, 30); + assert_eq!(pts, 0); + assert_eq!(base, Some(0)); // max(0) clamps negative + } + + #[test] + fn compute_pts_late_frame_after_negative() { + let mut base = Some(0); + let pts = compute_pts(&mut base, 1_000_000_000, 30); + assert_eq!(pts, 30); + } + + #[test] + fn compute_pts_base_not_overwritten_after_first_call() { + let mut base = None; + let _ = compute_pts(&mut base, 5_000_000_000, 30); + assert_eq!(base, Some(5_000_000_000)); + + let _ = compute_pts(&mut base, 10_000_000_000, 30); + assert_eq!(base, Some(5_000_000_000)); // base stays at first frame } }