feat(avhw): integrate transform into VA-API filter graph

Ultraworked with [Sisyphus](https://github.com/code-yeongyu/oh-my-openagent)

Co-authored-by: Sisyphus <clio-agent@sisyphuslabs.ai>
This commit is contained in:
dailz
2026-04-14 17:02:54 +08:00
parent e89689634d
commit ecd78492ee
2 changed files with 96 additions and 51 deletions

View File

@@ -4,9 +4,11 @@ use std::ptr;
use anyhow::{bail, Result}; use anyhow::{bail, Result};
use ffmpeg_next as ff; use ffmpeg_next as ff;
use ffmpeg_next::ffi as ffi; use ffmpeg_next::ffi;
use ffmpeg_next::packet::Mut as _; use ffmpeg_next::packet::Mut as _;
use crate::transform::Transform;
// --------------------------------------------------------------------------- // ---------------------------------------------------------------------------
// AvHwDevCtx // AvHwDevCtx
// --------------------------------------------------------------------------- // ---------------------------------------------------------------------------
@@ -69,12 +71,7 @@ pub struct AvHwFrameCtx {
unsafe impl Send for AvHwFrameCtx {} unsafe impl Send for AvHwFrameCtx {}
impl AvHwFrameCtx { impl AvHwFrameCtx {
fn new_inner( fn new_inner(hw_dev: &AvHwDevCtx, w: u32, h: u32, sw_fmt: ff::format::Pixel) -> Result<Self> {
hw_dev: &AvHwDevCtx,
w: u32,
h: u32,
sw_fmt: ff::format::Pixel,
) -> Result<Self> {
let mut p = unsafe { ffi::av_hwframe_ctx_alloc(hw_dev.as_ptr()) }; let mut p = unsafe { ffi::av_hwframe_ctx_alloc(hw_dev.as_ptr()) };
if p.is_null() { if p.is_null() {
bail!("av_hwframe_ctx_alloc returned null"); bail!("av_hwframe_ctx_alloc returned null");
@@ -159,24 +156,25 @@ impl EncState {
output_path: &Path, output_path: &Path,
width: u32, width: u32,
height: u32, height: u32,
enc_width: u32,
enc_height: u32,
bitrate: u64, bitrate: u64,
gop_size: u32, gop_size: u32,
fps: u32, fps: u32,
transform: Transform,
) -> Result<Self> { ) -> Result<Self> {
// 1. VAAPI device // 1. VAAPI device
let hw_device_ctx = AvHwDevCtx::new_vaapi(drm_device)?; let hw_device_ctx = AvHwDevCtx::new_vaapi(drm_device)?;
// 2. Frame contexts (capture=XRGB/RGBZ, encode=NV12) // 2. Frame contexts (capture=XRGB/RGBZ, encode=NV12)
let frames_rgb = AvHwFrameCtx::for_capture( // frames_rgb uses original capture dimensions (matches raw framebuffer)
&hw_device_ctx, // frames_yuv uses encoder dimensions (transposed for 90°/270° rotations)
width, let frames_rgb =
height, AvHwFrameCtx::for_capture(&hw_device_ctx, width, height, ff::format::Pixel::RGBZ)?;
ff::format::Pixel::RGBZ,
)?;
let frames_yuv = AvHwFrameCtx::for_encode( let frames_yuv = AvHwFrameCtx::for_encode(
&hw_device_ctx, &hw_device_ctx,
width, enc_width,
height, enc_height,
ff::format::Pixel::NV12, ff::format::Pixel::NV12,
)?; )?;
@@ -189,8 +187,8 @@ impl EncState {
ctx.encoder().video()? ctx.encoder().video()?
}; };
enc.set_width(width); enc.set_width(enc_width);
enc.set_height(height); enc.set_height(enc_height);
enc.set_format(ff::format::Pixel::VAAPI); enc.set_format(ff::format::Pixel::VAAPI);
enc.set_bit_rate(bitrate as usize); enc.set_bit_rate(bitrate as usize);
enc.set_gop(gop_size); enc.set_gop(gop_size);
@@ -226,14 +224,22 @@ impl EncState {
} }
// 4. Open encoder. Video::open() returns Encoder(Video); .0 extracts the Video. // 4. Open encoder. Video::open() returns Encoder(Video); .0 extracts the Video.
let opened = enc.open().map_err(|e| { let opened = enc
anyhow::anyhow!("Failed to open h264_vaapi encoder: {e}") .open()
})?; .map_err(|e| anyhow::anyhow!("Failed to open h264_vaapi encoder: {e}"))?;
let enc_video = opened.0; let enc_video = opened.0;
// 5. Filter graph (inline) // 5. Filter graph (inline)
let video_filter = let video_filter = build_filter_graph(
build_filter_graph(&hw_device_ctx, &frames_rgb, width, height, fps)?; &hw_device_ctx,
&frames_rgb,
width,
height,
enc_width,
enc_height,
fps,
transform,
)?;
// 6. Muxer setup (strict order) // 6. Muxer setup (strict order)
let output_cstr = CString::new(output_path.to_str().unwrap())?; let output_cstr = CString::new(output_path.to_str().unwrap())?;
@@ -329,9 +335,9 @@ impl EncState {
let mut filter_sink = filter_sink_ctx.sink(); let mut filter_sink = filter_sink_ctx.sink();
// SAFETY: hw_frame is a valid VAAPI hardware frame from capture. // SAFETY: hw_frame is a valid VAAPI hardware frame from capture.
filter_src.add(hw_frame).map_err(|e| { filter_src
anyhow::anyhow!("Filter source add failed: {e}") .add(hw_frame)
})?; .map_err(|e| anyhow::anyhow!("Filter source add failed: {e}"))?;
loop { loop {
let mut filtered = ff::frame::Video::empty(); let mut filtered = ff::frame::Video::empty();
@@ -352,9 +358,8 @@ impl EncState {
let start_ts = self.starting_timestamp.unwrap(); let start_ts = self.starting_timestamp.unwrap();
// SAFETY: avcodec_send_frame sends a valid NV12 VAAPI surface to the encoder. // SAFETY: avcodec_send_frame sends a valid NV12 VAAPI surface to the encoder.
let ret = unsafe { let ret =
ffi::avcodec_send_frame(self.enc_video.as_mut_ptr(), filtered.as_ptr()) unsafe { ffi::avcodec_send_frame(self.enc_video.as_mut_ptr(), filtered.as_ptr()) };
};
if ret < 0 { if ret < 0 {
bail!("avcodec_send_frame failed: error {ret}"); bail!("avcodec_send_frame failed: error {ret}");
} }
@@ -400,9 +405,9 @@ impl EncState {
// Write trailer only if at least one frame was encoded. // Write trailer only if at least one frame was encoded.
if self.frames_written { if self.frames_written {
self.octx.write_trailer().map_err(|e| { self.octx
anyhow::anyhow!("Failed to write trailer: {e}") .write_trailer()
})?; .map_err(|e| anyhow::anyhow!("Failed to write trailer: {e}"))?;
} }
Ok(()) Ok(())
@@ -440,9 +445,8 @@ impl EncState {
} }
pkt.set_stream(0); pkt.set_stream(0);
pkt.write_interleaved(&mut self.octx).map_err(|e| { pkt.write_interleaved(&mut self.octx)
anyhow::anyhow!("Failed to write packet: {e}") .map_err(|e| anyhow::anyhow!("Failed to write packet: {e}"))?;
})?;
self.frames_written = true; self.frames_written = true;
} }
@@ -459,16 +463,19 @@ fn build_filter_graph(
frames_rgb: &AvHwFrameCtx, frames_rgb: &AvHwFrameCtx,
width: u32, width: u32,
height: u32, height: u32,
_enc_width: u32,
_enc_height: u32,
fps: u32, fps: u32,
transform: Transform,
) -> Result<ff::filter::Graph> { ) -> Result<ff::filter::Graph> {
let mut graph = ff::filter::Graph::new(); let mut graph = ff::filter::Graph::new();
let buffersrc = ff::filter::find("buffer") let buffersrc =
.ok_or_else(|| anyhow::anyhow!("filter 'buffer' not found"))?; ff::filter::find("buffer").ok_or_else(|| anyhow::anyhow!("filter 'buffer' not found"))?;
let buffersink = ff::filter::find("buffersink") let buffersink = ff::filter::find("buffersink")
.ok_or_else(|| anyhow::anyhow!("filter 'buffersink' not found"))?; .ok_or_else(|| anyhow::anyhow!("filter 'buffersink' not found"))?;
let format_filter = ff::filter::find("format") let format_filter =
.ok_or_else(|| anyhow::anyhow!("filter 'format' not found"))?; ff::filter::find("format").ok_or_else(|| anyhow::anyhow!("filter 'format' not found"))?;
let scale_vaapi = ff::filter::find("scale_vaapi") let scale_vaapi = ff::filter::find("scale_vaapi")
.ok_or_else(|| anyhow::anyhow!("filter 'scale_vaapi' not found"))?; .ok_or_else(|| anyhow::anyhow!("filter 'scale_vaapi' not found"))?;
@@ -491,7 +498,10 @@ fn build_filter_graph(
(*par).format = Into::<ffi::AVPixelFormat>::into(ff::format::Pixel::VAAPI) as i32; (*par).format = Into::<ffi::AVPixelFormat>::into(ff::format::Pixel::VAAPI) as i32;
(*par).width = width as i32; (*par).width = width as i32;
(*par).height = height as i32; (*par).height = height as i32;
(*par).time_base = ffi::AVRational { num: 1, den: fps as i32 }; (*par).time_base = ffi::AVRational {
num: 1,
den: fps as i32,
};
(*par).hw_frames_ctx = frames_rgb.ref_clone(); (*par).hw_frames_ctx = frames_rgb.ref_clone();
let ret = ffi::av_buffersrc_parameters_set(src_ctx.as_mut_ptr(), par); let ret = ffi::av_buffersrc_parameters_set(src_ctx.as_mut_ptr(), par);
ffi::av_freep(par as *mut _ as *mut _); ffi::av_freep(par as *mut _ as *mut _);
@@ -503,7 +513,7 @@ fn build_filter_graph(
// format filter: negotiate pixel format to NV12 // format filter: negotiate pixel format to NV12
let mut fmt_ctx = graph.add(&format_filter, "fmt", "pix_fmts=nv12")?; let mut fmt_ctx = graph.add(&format_filter, "fmt", "pix_fmts=nv12")?;
// scale_vaapi: hardware scaling and colourspace conversion // scale_vaapi: hardware scaling and colourspace conversion (keeps original dimensions)
let mut scale_ctx = graph.add(&scale_vaapi, "scale", &format!("{width}:{height}"))?; let mut scale_ctx = graph.add(&scale_vaapi, "scale", &format!("{width}:{height}"))?;
// SAFETY: scale_vaapi needs hw_device_ctx for VAAPI device access. // SAFETY: scale_vaapi needs hw_device_ctx for VAAPI device access.
unsafe { unsafe {
@@ -513,14 +523,43 @@ fn build_filter_graph(
// buffersink // buffersink
let mut sink_ctx = graph.add(&buffersink, "out", "")?; let mut sink_ctx = graph.add(&buffersink, "out", "")?;
// Link: src -> format -> scale -> sink // Build filter chain: src -> format -> scale -> [transpose] -> sink
src_ctx.link(0, &mut fmt_ctx, 0); src_ctx.link(0, &mut fmt_ctx, 0);
fmt_ctx.link(0, &mut scale_ctx, 0);
scale_ctx.link(0, &mut sink_ctx, 0);
graph.validate().map_err(|e| { match transform {
anyhow::anyhow!("Filter graph validation failed: {e}") Transform::Normal90 | Transform::Normal270 => {
})?; let transpose = ff::filter::find("transpose_vaapi")
.ok_or_else(|| anyhow::anyhow!("filter 'transpose_vaapi' not found"))?;
let dir_val = match transform {
Transform::Normal90 => "1",
Transform::Normal270 => "2",
_ => unreachable!(),
};
let mut trans_ctx = graph.add(&transpose, "transpose", &format!("dir={dir_val}"))?;
// SAFETY: transpose_vaapi needs hw_device_ctx for VAAPI device access.
unsafe {
(*trans_ctx.as_mut_ptr()).hw_device_ctx = hw_dev.ref_clone();
}
fmt_ctx.link(0, &mut scale_ctx, 0);
scale_ctx.link(0, &mut trans_ctx, 0);
trans_ctx.link(0, &mut sink_ctx, 0);
}
Transform::Normal180 => {
tracing::warn!(
"Normal180 transform detected; rotation correction deferred to follow-up"
);
fmt_ctx.link(0, &mut scale_ctx, 0);
scale_ctx.link(0, &mut sink_ctx, 0);
}
_ => {
fmt_ctx.link(0, &mut scale_ctx, 0);
scale_ctx.link(0, &mut sink_ctx, 0);
}
}
graph
.validate()
.map_err(|e| anyhow::anyhow!("Filter graph validation failed: {e}"))?;
Ok(graph) Ok(graph)
} }

View File

@@ -35,7 +35,7 @@ use crate::args::Args;
use crate::avhw::{AvHwDevCtx, EncState}; use crate::avhw::{AvHwDevCtx, EncState};
use crate::cap_wlr_screencopy::CapWlrScreencopy; use crate::cap_wlr_screencopy::CapWlrScreencopy;
use crate::fps_limit::FpsLimit; use crate::fps_limit::FpsLimit;
use crate::transform::Transform; use crate::transform::{transpose_if_transform_transposed, Transform};
// --------------------------------------------------------------------------- // ---------------------------------------------------------------------------
// CaptureSource trait // CaptureSource trait
@@ -429,14 +429,19 @@ impl<S: CaptureSource> State<S> {
}); });
let gop_size = self.args.gop_size.unwrap_or(self.args.fps); let gop_size = self.args.gop_size.unwrap_or(self.args.fps);
let fps = self.args.fps; let fps = self.args.fps;
let (enc_w, enc_h) =
transpose_if_transform_transposed(output_info.transform, width as i32, height as i32);
let enc = match EncState::new( let enc = match EncState::new(
&drm_path, &drm_path,
Path::new(&self.args.output), Path::new(&self.args.output),
width, width,
height, height,
enc_w as u32,
enc_h as u32,
bitrate, bitrate,
gop_size, gop_size,
fps, fps,
output_info.transform,
) { ) {
Ok(enc) => enc, Ok(enc) => enc,
Err(e) => { Err(e) => {
@@ -473,7 +478,7 @@ impl<S: CaptureSource> State<S> {
match pos { match pos {
Some(i) => Some(i), Some(i) => Some(i),
None => { None => {
let all_probed = outputs.iter().all(|o| o.done_count >= 2); let all_probed = outputs.iter().all(|o| o.done_count >= 1);
if all_probed { if all_probed {
let available: Vec<&str> = let available: Vec<&str> =
outputs.iter().filter_map(|o| o.name.as_deref()).collect(); outputs.iter().filter_map(|o| o.name.as_deref()).collect();
@@ -487,7 +492,7 @@ impl<S: CaptureSource> State<S> {
None None
} }
} }
} else if outputs.iter().all(|o| o.done_count >= 2) { } else if outputs.iter().all(|o| o.done_count >= 1) {
if outputs.is_empty() { if outputs.is_empty() {
return false; return false;
} }
@@ -635,6 +640,7 @@ impl<S: CaptureSource> Dispatch<WlRegistry, GlobalListContents> for State<S> {
qhandle: &QueueHandle<State<S>>, qhandle: &QueueHandle<State<S>>,
) { ) {
use wayland_client::protocol::wl_registry::Event as RegistryEvent; use wayland_client::protocol::wl_registry::Event as RegistryEvent;
tracing::debug!("Dispatch<WlRegistry>::event fired: {:?}", event);
match event { match event {
RegistryEvent::Global { RegistryEvent::Global {
@@ -793,7 +799,7 @@ impl<S: CaptureSource> Dispatch<WlOutput, OutputId> for State<S> {
if let EncConstructionStage::ProbingOutputs { outputs, .. } = &mut state.stage { if let EncConstructionStage::ProbingOutputs { outputs, .. } = &mut state.stage {
if let Some(info) = outputs.get_mut(idx) { if let Some(info) = outputs.get_mut(idx) {
info.done_count += 1; info.done_count += 1;
if info.done_count >= 2 { if info.done_count >= 1 {
state.try_finalize_output(idx); state.try_finalize_output(idx);
} }
} }
@@ -849,7 +855,7 @@ impl<S: CaptureSource> Dispatch<ZxdgOutputV1, OutputId> for State<S> {
if let EncConstructionStage::ProbingOutputs { outputs, .. } = &mut state.stage { if let EncConstructionStage::ProbingOutputs { outputs, .. } = &mut state.stage {
if let Some(info) = outputs.get_mut(idx) { if let Some(info) = outputs.get_mut(idx) {
info.done_count += 1; info.done_count += 1;
if info.done_count >= 2 { if info.done_count >= 1 {
state.try_finalize_output(idx); state.try_finalize_output(idx);
} }
} }