feat: add WebRTC streaming via str0m + portal session persistence

- Add src/webrtc.rs: HTTP signaling server + str0m Sans-IO WebRTC transport
  with H.264 Annex-B → RTP packetization and key-frame request handling
- avhw: introduce FrameOutput enum (Muxer | Channel) so SwEncState can
  output to either MP4 muxer or crossbeam channel for WebRTC
- cap_portal: support portal session restore tokens (PersistMode::ExplicitlyRevoked)
  to skip re-authorization dialog; add --no-persist flag to force fresh dialog
- args: make --output optional when --port is used for WebRTC mode
- state_portal: integrate WebRTC pipeline (encoder channel → RTP forwarding)
  with shorter GOP for WebRTC (fps/2, min 10)
- main: redirect tracing to stderr; validate --output or --port required
- Add dependencies: str0m 0.20, serde_json 1, dirs 6
This commit is contained in:
dailz
2026-06-04 20:54:16 +08:00
parent 74f4dc826d
commit b0ed6548a6
10 changed files with 1611 additions and 70 deletions

794
Cargo.lock generated

File diff suppressed because it is too large Load Diff

View File

@@ -25,3 +25,6 @@ tokio = { version = "1", features = ["rt"] }
pipewire = { version = "0.9", features = ["v0_3_45"] }
libspa = "0.9"
crossbeam-channel = "0.5"
str0m = "0.20"
serde_json = "1"
dirs = "6"

View File

@@ -3,9 +3,9 @@ use clap::Parser;
#[derive(Parser, Debug, Clone)]
#[command(name = "wl-webrtc", about = "Wayland screen capture and encoding tool")]
pub struct Args {
/// Output file path (e.g., output.mp4, output.mkv)
/// Output file path (e.g., output.mp4, output.mkv). Optional when using --port for WebRTC mode
#[arg(short, long)]
pub output: String,
pub output: Option<String>,
/// Wayland output name to capture
#[arg(long)]
@@ -43,7 +43,11 @@ pub struct Args {
#[arg(long)]
pub backend: Option<String>,
/// Port for WebTransport server (Phase 2, unused in MVP)
/// Port for WebRTC HTTP signaling server; 0 keeps MP4 file output mode
#[arg(long, default_value_t = 0)]
pub port: u16,
/// Force re-authorization dialog (ignore saved portal restore token)
#[arg(long)]
pub no_persist: bool,
}

View File

@@ -596,13 +596,18 @@ impl EncState {
// SwEncState - VAAPI GPU downscale + software H.264 encode
// ---------------------------------------------------------------------------
/// Destination for the encoded packets produced by `SwEncState`.
pub enum FrameOutput {
/// Packets are rescaled to the stream time base and written to this
/// FFmpeg output context with `write_interleaved`.
Muxer(ff::format::context::Output),
/// Each packet's bytes are copied into a `Vec<u8>` and sent over this
/// crossbeam channel (used for the WebRTC path).
Channel(crossbeam_channel::Sender<Vec<u8>>),
}
pub struct SwEncState {
hw_dev: AvHwDevCtx,
frames_rgb: AvHwFrameCtx,
filter_graph: ff::filter::Graph,
sws_ctx: *mut ffi::SwsContext,
enc_video: ff::codec::encoder::video::Video,
octx: ff::format::context::Output,
output: Option<FrameOutput>,
yuv_frame: *mut ffi::AVFrame,
starting_timestamp: Option<i64>,
frames_written: bool,
@@ -651,7 +656,52 @@ impl SwEncState {
filter_graph,
sws_ctx,
enc_video,
octx,
output: Some(FrameOutput::Muxer(octx)),
yuv_frame,
starting_timestamp: None,
frames_written: false,
})
}
#[allow(clippy::too_many_arguments)]
pub fn new_webrtc(
drm_device: &Path,
width: u32,
height: u32,
enc_width: u32,
enc_height: u32,
fps: u32,
bitrate: u64,
gop_size: u32,
tx: crossbeam_channel::Sender<Vec<u8>>,
) -> Result<Self> {
tracing::info!(
"SwEncState::new_webrtc: GPU downscale {width}x{height} BGRA -> {enc_width}x{enc_height} NV12, software H.264 -> WebRTC"
);
let hw_dev = AvHwDevCtx::new_vaapi(drm_device)?;
let frames_rgb =
AvHwFrameCtx::for_capture(&hw_dev, width, height, ff::format::Pixel::BGRA)?;
let filter_graph = build_swenc_filter_graph(
&hw_dev,
&frames_rgb,
width,
height,
enc_width,
enc_height,
fps,
)?;
let sws_ctx = create_nv12_to_yuv420p_sws(enc_width, enc_height)?;
let enc_video = create_software_h264_encoder(enc_width, enc_height, fps, bitrate, gop_size)?;
let yuv_frame = alloc_yuv420p_frame(enc_width, enc_height)?;
Ok(Self {
hw_dev,
frames_rgb,
filter_graph,
sws_ctx,
enc_video,
output: Some(FrameOutput::Channel(tx)),
yuv_frame,
starting_timestamp: None,
frames_written: false,
@@ -704,7 +754,6 @@ impl SwEncState {
}
}
// SAFETY: Sending a null frame flushes the encoder without transferring ownership.
unsafe {
let ret = ffi::avcodec_send_frame(self.enc_video.as_mut_ptr(), ptr::null());
if ret < 0 && ret != ffi::AVERROR_EOF {
@@ -715,9 +764,10 @@ impl SwEncState {
self.drain_encoder(start_ts)?;
if self.frames_written {
self.octx
.write_trailer()
.map_err(|e| anyhow::anyhow!("Failed to write trailer: {e}"))?;
if let Some(FrameOutput::Muxer(ref mut octx)) = self.output {
octx.write_trailer()
.map_err(|e| anyhow::anyhow!("Failed to write trailer: {e}"))?;
}
}
Ok(())
@@ -793,25 +843,39 @@ impl SwEncState {
bail!("avcodec_receive_packet failed: error {ret}");
}
let enc_tb = self.enc_video.time_base();
let stream_tb = unsafe {
let streams = (*self.octx.as_ptr()).streams;
let st = *streams.add(0);
ff::Rational::from((*st).time_base)
};
pkt.rescale_ts(enc_tb, stream_tb);
match self.output {
Some(FrameOutput::Muxer(ref mut octx)) => {
let enc_tb = self.enc_video.time_base();
let stream_tb = unsafe {
let streams = (*octx.as_ptr()).streams;
let st = *streams.add(0);
ff::Rational::from((*st).time_base)
};
pkt.rescale_ts(enc_tb, stream_tb);
if let Some(pts) = pkt.pts() {
pkt.set_pts(Some(pts - start_ts));
}
if let Some(dts) = pkt.dts() {
pkt.set_dts(Some(dts - start_ts));
}
if let Some(pts) = pkt.pts() {
pkt.set_pts(Some(pts - start_ts));
}
if let Some(dts) = pkt.dts() {
pkt.set_dts(Some(dts - start_ts));
}
pkt.set_stream(0);
pkt.write_interleaved(&mut self.octx)
.map_err(|e| anyhow::anyhow!("Failed to write packet: {e}"))?;
self.frames_written = true;
pkt.set_stream(0);
pkt.write_interleaved(octx)
.map_err(|e| anyhow::anyhow!("Failed to write packet: {e}"))?;
self.frames_written = true;
}
Some(FrameOutput::Channel(ref tx)) => {
let data: &[u8] = unsafe {
std::slice::from_raw_parts(
(*pkt.as_mut_ptr()).data,
(*pkt.as_mut_ptr()).size as usize,
)
};
let _ = tx.send(data.to_vec());
}
None => {}
}
}
Ok(())
}
@@ -1115,6 +1179,54 @@ fn create_software_h264_muxer(
Ok((enc_video, octx))
}
fn create_software_h264_encoder(
width: u32,
height: u32,
fps: u32,
bitrate: u64,
gop_size: u32,
) -> Result<ff::codec::encoder::video::Video> {
let codec = ff::encoder::find_by_name("libx264")
.or_else(|| ff::encoder::find_by_name("libopenh264"))
.ok_or_else(|| anyhow::anyhow!("No H.264 software encoder found"))?;
let codec_name = codec.name().to_string();
let mut enc = {
let ctx = ff::codec::Context::new_with_codec(codec);
ctx.encoder().video()?
};
enc.set_width(width);
enc.set_height(height);
enc.set_format(ff::format::Pixel::YUV420P);
enc.set_bit_rate(bitrate as usize);
enc.set_gop(gop_size);
enc.set_time_base(ff::Rational::new(1, fps as i32));
enc.set_max_b_frames(0);
if codec_name == "libx264" {
unsafe {
let key = CString::new("preset").unwrap();
let val = CString::new("ultrafast").unwrap();
ffi::av_opt_set((*enc.as_mut_ptr()).priv_data, key.as_ptr(), val.as_ptr(), 0);
let key = CString::new("tune").unwrap();
let val = CString::new("zerolatency").unwrap();
ffi::av_opt_set((*enc.as_mut_ptr()).priv_data, key.as_ptr(), val.as_ptr(), 0);
let key = CString::new("threads").unwrap();
let val = CString::new("6").unwrap();
ffi::av_opt_set((*enc.as_mut_ptr()).priv_data, key.as_ptr(), val.as_ptr(), 0);
let key = CString::new("x264opts").unwrap();
let val = CString::new("repeat_headers=1").unwrap();
ffi::av_opt_set((*enc.as_mut_ptr()).priv_data, key.as_ptr(), val.as_ptr(), 0);
}
}
let opened = enc
.open()
.map_err(|e| anyhow::anyhow!("Failed to open {codec_name} encoder: {e}"))?;
tracing::info!("WebRTC encoder: {codec_name} {width}x{height} @ {fps}fps {bitrate}bps");
Ok(opened.0)
}
// ---------------------------------------------------------------------------
// Filter graph (inline)
// ---------------------------------------------------------------------------

View File

@@ -12,6 +12,7 @@
// - crossbeam-channel: 高性能有界通道,用于线程间帧传递
use std::os::fd::{AsRawFd, FromRawFd, OwnedFd};
use std::path::PathBuf;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Arc;
use std::thread::{self, JoinHandle};
@@ -98,12 +99,10 @@ impl CapPortal {
/// 4. 创建 eventfd 对,用于线程安全的关闭信号传递
/// 5. 启动 PipeWire 捕获线程
pub fn new(args: &Args) -> Result<Self> {
// 创建独立的 Tokio 运行时,仅用于 setup_portal 中的异步 Portal D-Bus 调用
let rt = Runtime::new()?;
// 通过 Portal 获取 PipeWire 连接 fd 和节点 ID
// block_on 在此处同步等待异步 Portal 调用完成
let (pw_fd, node_id) = rt.block_on(async { Self::setup_portal().await })?;
let no_persist = args.no_persist;
let (pw_fd, node_id) = rt.block_on(async { Self::setup_portal(no_persist).await })?;
let (frame_tx, frame_rx) = bounded(16);
let (event_tx, event_rx) = bounded(8);
@@ -172,44 +171,50 @@ impl CapPortal {
/// 5. 打开 PipeWire 远程连接,获取文件描述符
///
/// 返回 (PipeWire fd, node_id),供 PipeWire 线程连接使用
async fn setup_portal() -> Result<(OwnedFd, u32)> {
async fn setup_portal(no_persist: bool) -> Result<(OwnedFd, u32)> {
use ashpd::desktop::screencast::{
CursorMode, Screencast, SelectSourcesOptions, SourceType,
};
use ashpd::desktop::PersistMode;
// 创建 Screencast D-Bus 代理,与桌面环境的 Portal 服务通信
let proxy = Screencast::new()
.await
.map_err(|e| anyhow::anyhow!("Failed to create Screencast proxy: {e}"))?;
// 创建 ScreenCast 会话(每个会话对应一次屏幕录制请求)
let session = proxy
.create_session(Default::default())
.await
.map_err(|e| anyhow::anyhow!("Failed to create ScreenCast session: {e}"))?;
// 配置录制源选择参数:
// - CursorMode::Embedded: 光标嵌入到帧数据中(而非单独的元数据)
// - SourceType::Monitor: 仅捕获显示器(不捕获窗口)
// - multiple: false: 不允许多源选择
// - PersistMode::DoNot: 不持久化会话(每次需要重新授权)
let version_supported = proxy.version() >= 4;
let (persist_mode, saved_token) = if !no_persist && version_supported {
let token = load_restore_token();
if token.is_some() {
tracing::info!("Attempting to restore portal session with saved token");
}
(PersistMode::ExplicitlyRevoked, token)
} else {
(PersistMode::DoNot, None)
};
let mut options = SelectSourcesOptions::default()
.set_cursor_mode(CursorMode::Embedded)
.set_sources(ashpd::enumflags2::BitFlags::from(SourceType::Monitor))
.set_multiple(false)
.set_persist_mode(persist_mode);
if let Some(ref token) = saved_token {
options = options.set_restore_token(token.as_str());
}
proxy
.select_sources(
&session,
SelectSourcesOptions::default()
.set_cursor_mode(CursorMode::Embedded)
.set_sources(ashpd::enumflags2::BitFlags::from(SourceType::Monitor))
.set_multiple(false)
.set_persist_mode(PersistMode::DoNot),
)
.select_sources(&session, options)
.await
.map_err(|e| {
anyhow::anyhow!("屏幕共享权限被拒绝 / Screen sharing permission denied: {e}")
anyhow::anyhow!("Screen sharing permission denied: {e}")
})?;
// 启动录制会话,此时桌面环境会弹出权限确认对话框
// 用户确认后返回包含 PipeWire 流信息的响应
let response = proxy
.start(&session, None, Default::default())
.await
@@ -217,18 +222,19 @@ impl CapPortal {
.response()
.map_err(|e| anyhow::anyhow!("ScreenCast response error: {e}"))?;
// 获取返回的第一个(也是唯一的)视频流
// 每个流对应一个 PipeWire 节点
if !no_persist && version_supported {
if let Some(new_token) = response.restore_token() {
save_restore_token(new_token);
}
}
let stream = response
.streams()
.first()
.ok_or_else(|| anyhow::anyhow!("No streams returned from ScreenCast"))?;
// 提取 PipeWire 节点 ID用于后续连接到该节点的视频流
let node_id = stream.pipe_wire_node_id();
// 打开 PipeWire 远程连接,获取文件描述符
// 这个 fd 允许直接与 PipeWire 守护进程通信
let fd = proxy
.open_pipe_wire_remote(&session, Default::default())
.await
@@ -240,6 +246,30 @@ impl CapPortal {
}
}
/// Returns the location of the saved portal restore token:
/// `<cache dir>/wl-webrtc/portal-restore-token`, with `/tmp` used as the base
/// when `dirs::cache_dir()` yields nothing.
fn token_path() -> PathBuf {
    dirs::cache_dir()
        .unwrap_or_else(|| PathBuf::from("/tmp"))
        .join("wl-webrtc")
        .join("portal-restore-token")
}
/// Reads the saved portal restore token from disk.
///
/// Returns `None` when the file cannot be read or contains only whitespace;
/// otherwise returns the contents with surrounding whitespace removed.
fn load_restore_token() -> Option<String> {
    std::fs::read_to_string(token_path())
        .ok()
        .map(|raw| raw.trim().to_string())
        .filter(|token| !token.is_empty())
}
/// Writes `token` to the restore-token file, creating the parent directory
/// first. Failures are logged and never propagated.
fn save_restore_token(token: &str) {
    let path = token_path();
    if let Some(dir) = path.parent() {
        // Ignored on purpose: if the directory is missing, the write below
        // fails and is reported there.
        let _ = std::fs::create_dir_all(dir);
    }
    if let Err(e) = std::fs::write(&path, token) {
        tracing::warn!("Failed to save restore token: {e}");
    } else {
        tracing::info!("Saved portal restore token");
    }
}
impl Drop for CapPortal {
/// 析构时安全关闭 PipeWire 线程
///

View File

@@ -7,3 +7,4 @@ pub mod fps_limit;
pub mod state;
pub mod state_portal;
pub mod transform;
pub mod webrtc;

View File

@@ -18,6 +18,7 @@ mod fps_limit; // 帧率限制器
mod state; // wlr-screencopy 后端的主状态机
mod state_portal; // Portal/PipeWire 后端的主状态机
mod transform; // 图像变换(旋转/翻转)
mod webrtc; // WebRTC 传输str0m Sans-IO
use crate::args::Args;
use crate::cap_wlr_screencopy::CapWlrScreencopy;
@@ -49,6 +50,7 @@ fn main() -> Result<()> {
} else {
tracing::Level::INFO
})
.with_writer(std::io::stderr)
.init();
tracing::info!("wl-webrtc starting");
@@ -59,6 +61,10 @@ fn main() -> Result<()> {
anyhow::bail!("HEVC not supported in MVP. Use --codec h264");
}
if args.output.is_none() && args.port == 0 {
anyhow::bail!("Either --output or --port is required");
}
// 自动检测当前桌面环境可用的截屏后端
// 会尝试列举 Wayland 全局对象,判断合成器是否支持 wlr-screencopy 协议
let backend = crate::backend_detect::detect_backend(&args)?;

View File

@@ -613,7 +613,7 @@ impl<S: CaptureSource> State<S> {
.unwrap_or_else(|| 2 * (width as u64) * (height as u64) * (fps as u64) / 100);
let enc = match crate::avhw::create_encoder(
&drm_path,
Path::new(&self.args.output),
Path::new(self.args.output.as_deref().expect("output required for MP4 mode")),
width,
height,
fps,

View File

@@ -8,6 +8,7 @@ use anyhow::{bail, Result};
use crate::args::Args;
use crate::avhw::{self, SwEncState};
use crate::cap_portal::{CapPortal, PwCtrlEvent, PwDmaBufFrame};
use crate::webrtc::WebRtcState;
/// 门户采集的阶段状态
/// - WaitingForFormat: 等待接收到第一帧 DMA-BUF 以确定视频格式参数
@@ -32,6 +33,10 @@ pub struct StatePortal {
start_time: Option<Instant>,
last_stats_time: Option<Instant>,
last_stats_frames: u64,
webrtc: Option<WebRtcState>,
webrtc_tx: Option<crossbeam_channel::Sender<Vec<u8>>>,
webrtc_rx: Option<crossbeam_channel::Receiver<Vec<u8>>>,
webrtc_frames_sent: u64,
}
impl StatePortal {
@@ -48,6 +53,14 @@ impl StatePortal {
let cap = CapPortal::new(&args)?;
let (webrtc, webrtc_tx, webrtc_rx) = if args.port > 0 {
let (tx, rx) = crossbeam_channel::bounded(32);
let wrtc = WebRtcState::new(args.port, args.fps)?;
(Some(wrtc), Some(tx), Some(rx))
} else {
(None, None, None)
};
Ok(Self {
stage: PortalStage::WaitingForFormat,
enc: None,
@@ -59,6 +72,10 @@ impl StatePortal {
start_time: None,
last_stats_time: None,
last_stats_frames: 0,
webrtc,
webrtc_tx,
webrtc_rx,
webrtc_frames_sent: 0,
})
}
@@ -68,6 +85,9 @@ impl StatePortal {
/// `block=false` 时使用 try_recv 非阻塞检查。
/// 返回 `Ok(true)` 表示已处理事件,`Ok(false)` 表示暂无数据。
pub fn poll_and_encode(&mut self, block: bool) -> Result<bool> {
// WebRTC: process signaling, network, and forward encoded frames
self.poll_webrtc()?;
if let Ok(ctrl) = self.cap.event_receiver().try_recv() {
match ctrl {
PwCtrlEvent::StreamEnded => {
@@ -119,19 +139,39 @@ impl StatePortal {
let actual_bitrate = self.args.bitrate.unwrap_or_else(|| {
2 * (enc_width as u64) * (enc_height as u64) * (self.args.fps as u64) / 100
});
let actual_gop_size = self.args.gop_size.unwrap_or(self.args.fps);
let actual_gop_size = self.args.gop_size.unwrap_or_else(|| {
if self.webrtc_tx.is_some() {
(self.args.fps / 2).max(10)
} else {
self.args.fps
}
});
let enc = avhw::SwEncState::new(
&drm_path,
self.args.output.as_ref(),
frame.width,
frame.height,
enc_width,
enc_height,
self.args.fps,
actual_bitrate,
actual_gop_size,
)?;
let enc = if let Some(ref tx) = self.webrtc_tx {
avhw::SwEncState::new_webrtc(
&drm_path,
frame.width,
frame.height,
enc_width,
enc_height,
self.args.fps,
actual_bitrate,
actual_gop_size,
tx.clone(),
)?
} else {
avhw::SwEncState::new(
&drm_path,
std::path::Path::new(self.args.output.as_deref().expect("output required for MP4 mode")),
frame.width,
frame.height,
enc_width,
enc_height,
self.args.fps,
actual_bitrate,
actual_gop_size,
)?
};
self.enc = Some(enc);
self.stage = PortalStage::Streaming;
@@ -145,6 +185,9 @@ impl StatePortal {
}
}
// WebRTC: drain encoded frames produced by this poll before returning.
self.poll_webrtc()?;
Ok(true)
}
@@ -266,6 +309,29 @@ impl StatePortal {
pub fn is_errored(&self) -> bool {
self.errored
}
/// Runs one WebRTC iteration: handles pending HTTP signaling requests, pumps
/// the peer connection's network I/O, then forwards every encoded frame
/// currently queued in the encoder channel.
///
/// Does nothing when WebRTC mode is not enabled. Errors from signaling or
/// network polling are propagated; a failure to write an individual frame is
/// only logged, and the frame counter still advances.
fn poll_webrtc(&mut self) -> Result<()> {
    let Some(wrtc) = self.webrtc.as_mut() else {
        return Ok(());
    };
    wrtc.handle_signaling()?;
    wrtc.poll_and_feed()?;

    let Some(rx) = self.webrtc_rx.as_ref() else {
        return Ok(());
    };
    let mut forwarded = 0u32;
    while let Ok(data) = rx.try_recv() {
        forwarded += 1;
        if let Err(e) = wrtc.write_h264_frame(&data, self.webrtc_frames_sent, self.args.fps) {
            tracing::debug!("WebRTC write frame error: {e}");
        }
        self.webrtc_frames_sent = self.webrtc_frames_sent.saturating_add(1);
    }
    // Debug level: this fires on almost every poll while streaming, so an
    // info-level message would flood the log.
    if forwarded > 0 {
        tracing::debug!("WebRTC forwarded {forwarded} frames from channel");
    }
    Ok(())
}
}
impl Drop for StatePortal {

531
src/webrtc.rs Normal file
View File

@@ -0,0 +1,531 @@
// WebRTC 传输模块 — 使用 str0m (Sans-IO) 将 H.264 编码帧推送到浏览器
use std::io::{Read, Write};
use std::net::{SocketAddr, TcpListener, UdpSocket};
use std::time::Instant;
use anyhow::{bail, Result};
use str0m::change::SdpOffer;
use str0m::format::Codec;
use str0m::media::{Frequency, MediaKind, MediaTime, Mid, Pt};
use str0m::net::{Protocol, Receive};
use str0m::{Candidate, Event, IceConnectionState, Input, Output, Rtc, RtcConfig};
// ── Embedded HTML test page ──────────────────────────────────────────────
/// Self-contained viewer page served by the signaling server.
///
/// The script adds a receive-only video transceiver, moves H264 payload types
/// to the front of the `m=video` line of the offer, waits for ICE gathering to
/// complete, POSTs the local description as JSON to `/sdp`, and applies the
/// JSON answer. If the exchange fails it retries after 2 seconds. A stats
/// logger prints video element state and inbound RTP stats every 2 seconds.
const HTML_PAGE: &str = r#"<!DOCTYPE html>
<html>
<head><title>wl-webrtc P0</title>
<style>body{background:#000;color:#fff;font-family:monospace;display:flex;flex-direction:column;align-items:center;justify-content:center;height:100vh;margin:0}
video{max-width:90vw;max-height:80vh;border:1px solid #333}
#status{margin:12px;font-size:14px;color:#aaa}
#debug{position:fixed;bottom:8px;left:8px;font-size:11px;color:#666;max-width:90vw;white-space:pre-wrap}
</style></head>
<body>
<div id="status">Connecting...</div>
<video id="video" autoplay playsinline muted></video>
<pre id="debug"></pre>
<script>
const status = document.getElementById('status');
const video = document.getElementById('video');
const debug = document.getElementById('debug');
let pc = null;
const log = msg => { debug.textContent += msg + '\n'; console.log(msg); };
function preferH264(sdp) {
const lines = sdp.split('\r\n');
const h264Pts = lines
.filter(line => line.startsWith('a=rtpmap:') && line.toUpperCase().includes('H264/90000'))
.map(line => line.match(/^a=rtpmap:(\d+)/)?.[1])
.filter(Boolean);
if (h264Pts.length === 0) return sdp;
return lines.map(line => {
if (!line.startsWith('m=video ')) return line;
const parts = line.split(' ');
const header = parts.slice(0, 3);
const pts = parts.slice(3);
const preferred = h264Pts.filter(pt => pts.includes(pt));
const rest = pts.filter(pt => !preferred.includes(pt));
return [...header, ...preferred, ...rest].join(' ');
}).join('\r\n');
}
function installStatsLogger(peer) {
setInterval(() => {
if (peer !== pc) return;
const v = video;
log(`video: readyState=${v.readyState} currentTime=${v.currentTime.toFixed(2)} ` +
`paused=${v.paused} width=${v.videoWidth} height=${v.videoHeight} ` +
`srcObject=${v.srcObject ? 'yes' : 'no'}`);
peer.getStats().then(stats => {
stats.forEach(report => {
if (report.type === 'inbound-rtp' && report.kind === 'video') {
log(`RTP-in: packetsReceived=${report.packetsReceived} packetsLost=${report.packetsLost} ` +
`bytesReceived=${report.bytesReceived} framesDecoded=${report.framesDecoded} ` +
`framesDropped=${report.framesDropped} codecId=${report.codecId}`);
}
if (report.type === 'codec' && report.mimeType && report.mimeType.includes('H264')) {
log(`Codec: ${report.mimeType} ${report.payloadType} sdpFmtpLine=${report.sdpFmtpLine}`);
}
});
}).catch(() => {});
}, 2000);
}
function connect() {
if (pc) pc.close();
pc = new RTCPeerConnection();
const peer = pc;
peer.ontrack = e => {
log('ontrack: streams=' + e.streams.length + ' kind=' + e.track.kind);
video.srcObject = e.streams[0];
status.textContent = 'Track received';
};
peer.oniceconnectionstatechange = () => {
log('ICE: ' + peer.iceConnectionState);
status.textContent = 'ICE: ' + peer.iceConnectionState;
};
peer.addTransceiver('video', { direction: 'recvonly' });
installStatsLogger(peer);
peer.createOffer().then(offer => {
offer.sdp = preferH264(offer.sdp);
return peer.setLocalDescription(offer);
})
.then(() => new Promise(resolve => {
if (peer.iceGatheringState === 'complete') resolve();
else peer.onicegatheringstatechange = () => { if (peer.iceGatheringState === 'complete') resolve(); };
}))
.then(() => fetch('/sdp', { method: 'POST', body: JSON.stringify(peer.localDescription) }))
.then(r => { if (!r.ok) throw new Error('SDP exchange failed: ' + r.status); return r.json(); })
.then(answer => { if (answer.error) throw new Error(answer.error); return peer.setRemoteDescription(answer); })
.then(() => log('SDP answer set'))
.catch(e => {
status.textContent = 'Error: ' + e.message;
log('ERROR: ' + e.message + ' — retrying in 2s...');
console.error(e);
setTimeout(connect, 2000);
});
}
connect();
</script>
</body></html>"#;
// ── WebRTC state ─────────────────────────────────────────────────────────
/// HTTP signaling endpoint together with at most one active peer connection.
pub struct WebRtcState {
/// Non-blocking TCP listener that serves the test page and `POST /sdp`.
signal_listener: TcpListener,
/// Current peer connection; replaced whenever a new SDP offer is accepted
/// and cleared when polling reports that the connection has closed.
inner: Option<WebRtcInner>,
/// Frame rate handed to `WebRtcInner::new` for each new connection.
fps: u32,
}
/// State of a single WebRTC peer connection driven through str0m.
struct WebRtcInner {
/// The str0m connection state machine.
rtc: Rtc,
/// Non-blocking UDP socket carrying all traffic for this connection.
socket: UdpSocket,
/// Address advertised as the local host candidate; also passed as the
/// `destination` of every received datagram.
udp_addr: SocketAddr,
/// Mid of the video media section, once discovered.
video_mid: Option<Mid>,
/// Negotiated H.264 payload type, once discovered.
video_pt: Option<Pt>,
/// Set on `Event::Connected`, cleared when ICE reports `Disconnected`.
connected: bool,
/// While true, frames are dropped until one containing an IDR NAL arrives.
need_keyframe: bool,
/// RTP timestamp of the most recent frame truncated to 32 bits; in the
/// visible code it is only read for debug logging.
rtp_clock: u32,
/// Reusable receive buffer for incoming UDP datagrams.
buf: Vec<u8>,
}
impl WebRtcState {
pub fn new(port: u16, fps: u32) -> Result<Self> {
let signal_listener = TcpListener::bind(format!("0.0.0.0:{port}"))?;
signal_listener.set_nonblocking(true)?;
tracing::info!("WebRTC signaling on http://0.0.0.0:{port}/");
Ok(Self {
signal_listener,
inner: None,
fps,
})
}
pub fn handle_signaling(&mut self) -> Result<bool> {
let mut handled = false;
loop {
let (mut stream, _addr) = match self.signal_listener.accept() {
Ok(s) => s,
Err(ref e) if e.kind() == std::io::ErrorKind::WouldBlock => break,
Err(e) => bail!("TCP accept error: {e}"),
};
handled = true;
stream.set_nonblocking(true)?;
let mut req = vec![0u8; 65536];
let n = match stream.read(&mut req) {
Ok(n) => n,
Err(ref e) if e.kind() == std::io::ErrorKind::WouldBlock => continue,
Err(e) => {
tracing::warn!("TCP read error: {e}");
continue;
}
};
let req_str = String::from_utf8_lossy(&req[..n]);
if req_str.starts_with("GET / ")
|| req_str.starts_with("GET /sdp ")
&& !req_str.contains("Content-Type: application/json")
{
let resp = format!(
"HTTP/1.1 200 OK\r\nContent-Type: text/html\r\nContent-Length: {}\r\nConnection: close\r\n\r\n{}",
HTML_PAGE.len(),
HTML_PAGE
);
let _ = stream.write_all(resp.as_bytes());
} else if req_str.starts_with("POST /sdp") {
let body = extract_body(&req_str);
if body.is_empty() {
let resp = "HTTP/1.1 400 Bad Request\r\nConnection: close\r\n\r\nempty body";
let _ = stream.write_all(resp.as_bytes());
continue;
}
match WebRtcInner::new(self.fps)
.and_then(|mut new_inner| {
let answer_json = new_inner.handle_sdp_offer(body.as_bytes())?;
Ok((new_inner, answer_json))
}) {
Ok((new_inner, answer_json)) => {
let replacing = self.inner.is_some();
self.inner = Some(new_inner);
if replacing {
tracing::info!("Replaced WebRTC connection (old dropped)");
} else {
tracing::info!("New WebRTC connection");
}
let resp = format!(
"HTTP/1.1 200 OK\r\nContent-Type: application/json\r\nContent-Length: {}\r\nConnection: close\r\n\r\n{}",
answer_json.len(),
answer_json
);
let _ = stream.write_all(resp.as_bytes());
}
Err(e) => {
tracing::error!("SDP offer handling failed: {e}");
let resp = format!("HTTP/1.1 500 Error\r\nConnection: close\r\n\r\n{e}");
let _ = stream.write_all(resp.as_bytes());
}
}
} else {
let resp = "HTTP/1.1 404 Not Found\r\nConnection: close\r\n\r\n";
let _ = stream.write_all(resp.as_bytes());
}
}
Ok(handled)
}
pub fn poll_rtc(&mut self) -> Result<()> {
if let Some(inner) = self.inner.as_mut() {
if inner.poll_rtc()? {
tracing::warn!("WebRTC connection closed/failed; clearing connection state");
self.inner = None;
}
}
Ok(())
}
pub fn feed_network(&mut self) -> Result<()> {
if let Some(inner) = self.inner.as_mut() {
inner.feed_network()?;
}
Ok(())
}
pub fn poll_and_feed(&mut self) -> Result<()> {
self.poll_rtc()?;
self.feed_network()?;
self.poll_rtc()
}
pub fn write_h264_frame(&mut self, data: &[u8], frame_number: u64, fps: u32) -> Result<()> {
if let Some(inner) = self.inner.as_mut() {
inner.write_h264_frame(data, frame_number, fps)?;
}
Ok(())
}
pub fn is_connected(&self) -> bool {
self.inner.as_ref().is_some_and(WebRtcInner::is_connected)
}
}
impl WebRtcInner {
    /// Creates a new peer connection with its own non-blocking UDP socket and
    /// registers a single host candidate built from the detected LAN address
    /// (falling back to `127.0.0.1`) and the socket's bound port.
    ///
    /// `fps` is accepted for interface stability but is not used here.
    fn new(fps: u32) -> Result<Self> {
        let _ = fps;
        let mut rtc = RtcConfig::new().build(Instant::now());

        let socket = UdpSocket::bind("0.0.0.0:0")?;
        socket.set_nonblocking(true)?;
        let bound_port = socket.local_addr()?.port();

        let lan_ip = local_ip().unwrap_or_else(|| {
            tracing::warn!("Failed to detect LAN IP, falling back to 127.0.0.1");
            "127.0.0.1".to_string()
        });
        let candidate_addr: SocketAddr = format!("{lan_ip}:{bound_port}").parse()?;
        let candidate = Candidate::host(candidate_addr, "udp")
            .map_err(|e| anyhow::anyhow!("candidate: {e}"))?;
        rtc.add_local_candidate(candidate);
        tracing::info!("WebRTC UDP: {candidate_addr} (bound 0.0.0.0)");

        Ok(Self {
            rtc,
            socket,
            udp_addr: candidate_addr,
            video_mid: None,
            video_pt: None,
            connected: false,
            need_keyframe: false,
            rtp_clock: 0,
            buf: vec![0u8; 65535],
        })
    }

    /// Parses a JSON-encoded SDP offer, accepts it, and returns the
    /// JSON-encoded SDP answer. Also arms the wait-for-keyframe gate and
    /// tries to discover the video mid and H.264 payload type.
    fn handle_sdp_offer(&mut self, body: &[u8]) -> Result<String> {
        let offer: SdpOffer = serde_json::from_slice(body)
            .map_err(|e| anyhow::anyhow!("parse SDP offer: {e}"))?;
        let answer = self
            .rtc
            .sdp_api()
            .accept_offer(offer)
            .map_err(|e| anyhow::anyhow!("accept_offer: {e}"))?;
        self.need_keyframe = true;
        tracing::info!("SDP exchange complete, waiting for ICE/DTLS...");
        self.discover_video_params();

        let answer_json =
            serde_json::to_vec(&answer).map_err(|e| anyhow::anyhow!("serialize answer: {e}"))?;
        String::from_utf8(answer_json).map_err(|e| anyhow::anyhow!("answer utf8: {e}"))
    }

    /// Looks for a video media section among mids "0".."3" and, if one is
    /// known, records the first H.264 payload type offered by its writer.
    fn discover_video_params(&mut self) {
        // NOTE(review): only mids "0" through "3" are probed; an offer using
        // other mid values would not be discovered here.
        for candidate in ["0", "1", "2", "3"] {
            let mid: Mid = candidate.into();
            let is_video = self
                .rtc
                .media(mid)
                .is_some_and(|media| media.kind() == MediaKind::Video);
            if is_video {
                tracing::info!("Found video media: mid={mid}");
                self.video_mid = Some(mid);
                break;
            }
        }

        let Some(mid) = self.video_mid else {
            return;
        };
        let Some(writer) = self.rtc.writer(mid) else {
            return;
        };
        for pp in writer.payload_params() {
            tracing::debug!("Codec: pt={:?} spec={:?}", pp.pt(), pp.spec());
            if pp.spec().codec.is_video() && pp.spec().codec == Codec::H264 {
                self.video_pt = Some(pp.pt());
                tracing::info!("H.264 payload type: {:?}", pp.pt());
                break;
            }
        }
    }

    /// Drains everything str0m currently wants to output: sends queued
    /// datagrams on the UDP socket and reacts to connection events.
    ///
    /// Returns `Ok(true)` when the connection is no longer usable (polling
    /// failed, or str0m reports it is not alive), so the owner can drop it.
    fn poll_rtc(&mut self) -> Result<bool> {
        let mut failed = false;
        loop {
            match self.rtc.poll_output() {
                Ok(Output::Transmit(t)) => {
                    // Trace level: this runs for every outgoing packet.
                    tracing::trace!("TX {} bytes -> {}", t.contents.len(), t.destination);
                    if let Err(e) = self.socket.send_to(&t.contents, t.destination) {
                        tracing::warn!("UDP send error: {e}");
                    }
                }
                Ok(Output::Event(e)) => {
                    tracing::debug!("RTC event: {e:?}");
                    match &e {
                        Event::Connected => {
                            tracing::info!("WebRTC connected!");
                            self.connected = true;
                            self.need_keyframe = true;
                            self.discover_video_params();
                        }
                        Event::IceConnectionStateChange(IceConnectionState::Disconnected) => {
                            // NOTE(review): `connected` is only set again by
                            // `Event::Connected`; confirm whether str0m
                            // re-emits it after an ICE recovery.
                            tracing::warn!("WebRTC disconnected");
                            self.connected = false;
                        }
                        Event::MediaAdded(ma) => {
                            tracing::info!("Media added: mid={:?}", ma.mid);
                        }
                        _ => {}
                    }
                }
                Ok(Output::Timeout(_)) => break,
                Err(e) => {
                    tracing::error!("rtc.poll_output error: {e}");
                    failed = true;
                    break;
                }
            }
        }
        Ok(failed || !self.rtc.is_alive())
    }

    /// Reads every datagram currently queued on the UDP socket and hands it
    /// to str0m, then delivers a timeout tick.
    ///
    /// Datagrams that cannot be parsed, or that str0m rejects, are logged and
    /// skipped: they come from the network and must not stop the pipeline.
    ///
    /// # Errors
    ///
    /// Returns an error on a socket receive failure or when handling the
    /// timeout tick fails.
    fn feed_network(&mut self) -> Result<()> {
        loop {
            let (n, source) = match self.socket.recv_from(&mut self.buf) {
                Ok(received) => received,
                Err(ref e) if e.kind() == std::io::ErrorKind::WouldBlock => break,
                Err(ref e) if e.kind() == std::io::ErrorKind::Interrupted => continue,
                Err(e) => bail!("UDP recv error: {e}"),
            };
            tracing::trace!("UDP recv {} bytes from {}", n, source);

            let input = match self.buf[..n].try_into() {
                Ok(contents) => Input::Receive(
                    Instant::now(),
                    Receive {
                        proto: Protocol::Udp,
                        source,
                        destination: self.udp_addr,
                        contents,
                    },
                ),
                Err(e) => {
                    tracing::debug!("Ignoring unparseable UDP datagram ({n} bytes from {source}): {e}");
                    continue;
                }
            };
            if let Err(e) = self.rtc.handle_input(input) {
                tracing::warn!("handle_input({n} bytes from {source}): {e}");
            }
        }

        self.rtc
            .handle_input(Input::Timeout(Instant::now()))
            .map_err(|e| anyhow::anyhow!("handle timeout: {e}"))?;
        Ok(())
    }

    /// Sends one encoded H.264 frame (Annex-B byte stream) to the peer.
    ///
    /// Frames are silently dropped while not connected, while the video mid
    /// or payload type is unknown, and — after a (re)connect — until a frame
    /// containing an IDR NAL arrives. The RTP timestamp is derived from
    /// `frame_number` at a 90 kHz clock: `frame_number * 90000 / fps`.
    fn write_h264_frame(&mut self, data: &[u8], frame_number: u64, fps: u32) -> Result<()> {
        if !self.connected {
            return Ok(());
        }
        let Some(mid) = self.video_mid else {
            tracing::warn!("write_h264: no video_mid");
            return Ok(());
        };
        let Some(pt) = self.video_pt else {
            tracing::warn!("write_h264: no video_pt");
            return Ok(());
        };

        if self.need_keyframe {
            if !is_idr_nalu(data) {
                tracing::debug!(
                    "write_h264: skipping non-IDR frame ({} bytes), waiting for keyframe",
                    data.len()
                );
                return Ok(());
            }
            tracing::info!(
                "write_h264: got IDR keyframe ({} bytes), starting playback",
                data.len()
            );
            self.need_keyframe = false;
        }

        let ticks_per_second = 90_000u64;
        let fps = u64::from(fps.max(1));
        let rtp_timestamp = frame_number.saturating_mul(ticks_per_second) / fps;
        // Truncation is intentional: this copy is only used for logging.
        self.rtp_clock = rtp_timestamp as u32;
        let rtp_time = MediaTime::new(rtp_timestamp, Frequency::NINETY_KHZ);

        let Some(writer) = self.rtc.writer(mid) else {
            tracing::warn!("write_h264: no writer for mid={mid}");
            return Ok(());
        };
        tracing::debug!(
            "write_h264: {} bytes, pt={:?}, rtp={}",
            data.len(),
            pt,
            self.rtp_clock
        );
        writer
            .write(pt, Instant::now(), rtp_time, data)
            .map_err(|e| anyhow::anyhow!("writer.write: {e}"))?;

        // Flush the packets just queued. A closed connection is reported
        // again by the owner's next `poll_rtc`, so the flag is not needed here.
        let _closed = self.poll_rtc()?;
        Ok(())
    }

    /// Returns whether `Event::Connected` has been seen and no ICE
    /// disconnect has been reported since.
    fn is_connected(&self) -> bool {
        self.connected
    }
}
// ── Utility functions ────────────────────────────────────────────────────
/// Returns the body of an HTTP request: everything after the first
/// `\r\n\r\n`. Yields an empty string when no header terminator is present.
fn extract_body(req: &str) -> &str {
    req.split_once("\r\n\r\n").map_or("", |(_, body)| body)
}
/// Determines the local address the OS would use for outbound traffic by
/// connecting a throwaway UDP socket to `1.1.1.1:80` and reading back its
/// local address.
///
/// Returns `None` if any step fails, or if the resulting address is
/// `0.0.0.0` or in `127.0.0.0/8`.
fn local_ip() -> Option<String> {
    let probe = std::net::UdpSocket::bind("0.0.0.0:0").ok()?;
    probe.connect("1.1.1.1:80").ok()?;
    let ip = probe.local_addr().ok()?.ip().to_string();
    if ip == "0.0.0.0" || ip.starts_with("127.") {
        None
    } else {
        Some(ip)
    }
}
/// Returns `true` when the Annex-B byte stream in `data` contains at least
/// one IDR slice NAL unit (`nal_unit_type == 5`).
///
/// Every NAL unit is preceded by a start code, either `00 00 01` or
/// `00 00 00 01`. Because the 4-byte form ends with the 3-byte form, scanning
/// for `00 00 01` followed by a header byte covers both, including a NAL
/// header that is the very last byte of the buffer. Inputs shorter than four
/// bytes never match.
fn is_idr_nalu(data: &[u8]) -> bool {
    const NAL_TYPE_MASK: u8 = 0x1F;
    const NAL_TYPE_IDR: u8 = 5;
    data.windows(4)
        .any(|w| w[..3] == [0, 0, 1] && w[3] & NAL_TYPE_MASK == NAL_TYPE_IDR)
}