diff --git a/src/webrtc.rs b/src/webrtc.rs index 380d315..eaa012e 100644 --- a/src/webrtc.rs +++ b/src/webrtc.rs @@ -520,16 +520,17 @@ fn local_ip() -> Option { fn is_idr_nalu(data: &[u8]) -> bool { let mut i = 0; - while i + 4 < data.len() { - if data[i..i + 4] == [0, 0, 0, 1] { - let nal_type = data[i + 4] & 0x1F; - if nal_type == 5 { + while i < data.len() { + let tail = &data[i..]; + if tail.starts_with(&[0, 0, 0, 1]) { + let Some(&header) = tail.get(4) else { break }; + if header & 0x1F == 5 { return true; } i += 5; - } else if i + 3 < data.len() && data[i..i + 3] == [0, 0, 1] { - let nal_type = data[i + 3] & 0x1F; - if nal_type == 5 { + } else if tail.starts_with(&[0, 0, 1]) { + let Some(&header) = tail.get(3) else { break }; + if header & 0x1F == 5 { return true; } i += 4; @@ -539,3 +540,78 @@ fn is_idr_nalu(data: &[u8]) -> bool { } false } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn empty_data() { + assert!(!is_idr_nalu(&[])); + } + + #[test] + fn short_data_no_start_code() { + assert!(!is_idr_nalu(&[0])); + assert!(!is_idr_nalu(&[0, 0])); + assert!(!is_idr_nalu(&[1, 2, 3])); + } + + #[test] + fn three_byte_start_code_no_nal_header() { + assert!(!is_idr_nalu(&[0, 0, 1])); + } + + #[test] + fn four_byte_start_code_no_nal_header() { + assert!(!is_idr_nalu(&[0, 0, 0, 1])); + } + + #[test] + fn three_byte_start_code_idr_at_tail() { + assert!(is_idr_nalu(&[0, 0, 1, 0x65])); + assert!(!is_idr_nalu(&[0, 0, 1, 0x01])); + } + + #[test] + fn four_byte_start_code_idr_at_tail() { + assert!(is_idr_nalu(&[0, 0, 0, 1, 0x65])); + assert!(!is_idr_nalu(&[0, 0, 0, 1, 0x01])); + } + + #[test] + fn idr_in_middle_of_frame() { + let data: Vec = [ + &[0, 0, 0, 1, 0x67][..], // SPS + &[0, 0, 0, 1, 0x68][..], // PPS + &[0, 0, 0, 1, 0x65][..], // IDR + ] + .concat(); + assert!(is_idr_nalu(&data)); + } + + #[test] + fn no_idr_in_frame() { + let data: Vec = [ + &[0, 0, 0, 1, 0x67][..], // SPS + &[0, 0, 0, 1, 0x68][..], // PPS + ] + .concat(); + assert!(!is_idr_nalu(&data)); + } + + #[test] + fn mixed_start_code_lengths() { + let data: Vec = [ + &[0, 0, 0, 1, 0x67][..], // SPS (4-byte start code) + &[0, 0, 1, 0x65][..], // IDR (3-byte start code) + ] + .concat(); + assert!(is_idr_nalu(&data)); + } + + #[test] + fn all_zeros() { + assert!(!is_idr_nalu(&[0, 0, 0, 0, 0, 0, 0, 0])); + } +}