rakata_formats/wav/
reader.rs

1//! WAV binary reader.
2
3use std::io::Read;
4
5use super::{
6    binary, Wav, WavEncodingCode, WavError, WavType, WavWaveMetadata, WavWrapperKind,
7    DATA_CHUNK_ID, FMT_CHUNK_ID, MP3_IN_WAV_HEADER_SIZE, MP3_IN_WAV_RIFF_SIZE, RIFF_MAGIC,
8    SFX_HEADER_SIZE, SFX_MAGIC, VO_HEADER_SIZE, WAVE_MAGIC,
9};
10
11/// Reads WAV data from a reader.
12#[cfg_attr(
13    feature = "tracing",
14    tracing::instrument(level = "debug", skip(reader))
15)]
16pub fn read_wav<R: Read>(reader: &mut R) -> Result<Wav, WavError> {
17    let mut bytes = Vec::new();
18    reader.read_to_end(&mut bytes)?;
19    read_wav_from_bytes(&bytes)
20}
21
22/// Reads WAV data from bytes.
23#[cfg_attr(
24    feature = "tracing",
25    tracing::instrument(level = "debug", skip(bytes), fields(bytes_len = bytes.len()))
26)]
27pub fn read_wav_from_bytes(bytes: &[u8]) -> Result<Wav, WavError> {
28    let (kind, skip_size) = detect_wrapper(bytes);
29    if skip_size > bytes.len() {
30        return Err(WavError::InvalidHeader(format!(
31            "wrapper skip offset {skip_size} exceeds file length {}",
32            bytes.len()
33        )));
34    }
35
36    let payload = &bytes[skip_size..];
37    let wav_type = match kind {
38        WavWrapperKind::SfxHeader => WavType::Sfx,
39        WavWrapperKind::Standard => WavType::Standard,
40        _ => WavType::Vo, // VoHeader and Mp3InWav both use the Vo wrapper on write
41    };
42
43    if kind == WavWrapperKind::Mp3InWav {
44        return Ok(Wav::new_mp3(wav_type, payload.to_vec()));
45    }
46
47    parse_wave_payload(payload, wav_type)
48}
49
50fn detect_wrapper(bytes: &[u8]) -> (WavWrapperKind, usize) {
51    if bytes.len() < 12 {
52        return (WavWrapperKind::Standard, 0);
53    }
54
55    if bytes.starts_with(&SFX_MAGIC) {
56        return (WavWrapperKind::SfxHeader, SFX_HEADER_SIZE);
57    }
58
59    if bytes.starts_with(&RIFF_MAGIC) {
60        if bytes.len() >= VO_HEADER_SIZE + 4
61            && bytes[VO_HEADER_SIZE..VO_HEADER_SIZE + 4] == RIFF_MAGIC
62        {
63            return (WavWrapperKind::VoHeader, VO_HEADER_SIZE);
64        }
65
66        if let Ok(riff_size) = binary::read_u32(bytes, 4) {
67            if riff_size == MP3_IN_WAV_RIFF_SIZE {
68                return (WavWrapperKind::Mp3InWav, MP3_IN_WAV_HEADER_SIZE);
69            }
70        }
71    }
72
73    (WavWrapperKind::Standard, 0)
74}
75
76fn parse_wave_payload(bytes: &[u8], wav_type: WavType) -> Result<Wav, WavError> {
77    if bytes.len() < 12 {
78        return Err(WavError::InvalidHeader(
79            "WAVE payload shorter than 12-byte RIFF header".into(),
80        ));
81    }
82    if !bytes.starts_with(&RIFF_MAGIC) {
83        return Err(WavError::InvalidHeader(
84            "missing RIFF magic after wrapper normalization".into(),
85        ));
86    }
87    if bytes[8..12] != WAVE_MAGIC {
88        return Err(WavError::InvalidHeader(
89            "missing WAVE magic in RIFF header".into(),
90        ));
91    }
92
93    let mut cursor = 12usize;
94    let mut fmt_fields: Option<(WavEncodingCode, u16, u32, u32, u16, u16)> = None;
95    let mut data_chunk: Option<Vec<u8>> = None;
96
97    while cursor + 8 <= bytes.len() {
98        let chunk_id = [
99            bytes[cursor],
100            bytes[cursor + 1],
101            bytes[cursor + 2],
102            bytes[cursor + 3],
103        ];
104        let chunk_size_u32 = binary::read_u32(bytes, cursor + 4).map_err(|_| {
105            WavError::InvalidChunk(format!(
106                "unable to read chunk size at offset {}",
107                cursor + 4
108            ))
109        })?;
110        let chunk_size = usize::try_from(chunk_size_u32).map_err(|_| {
111            WavError::InvalidChunk(format!("chunk size {chunk_size_u32} does not fit in usize"))
112        })?;
113        cursor += 8;
114
115        let chunk_end = cursor.checked_add(chunk_size).ok_or_else(|| {
116            WavError::InvalidChunk(format!("chunk size overflow at chunk offset {cursor}"))
117        })?;
118        if chunk_end > bytes.len() {
119            return Err(WavError::InvalidChunk(format!(
120                "chunk at offset {} exceeds file bounds",
121                cursor - 8
122            )));
123        }
124        let chunk_data = &bytes[cursor..chunk_end];
125
126        if chunk_id == FMT_CHUNK_ID {
127            if chunk_data.len() < 16 {
128                return Err(WavError::InvalidChunk(
129                    "`fmt ` chunk shorter than 16 bytes".into(),
130                ));
131            }
132
133            let encoding =
134                WavEncodingCode::from_raw(u16::from_le_bytes([chunk_data[0], chunk_data[1]]));
135            let channels = u16::from_le_bytes([chunk_data[2], chunk_data[3]]);
136            let sample_rate =
137                u32::from_le_bytes([chunk_data[4], chunk_data[5], chunk_data[6], chunk_data[7]]);
138            let bytes_per_sec =
139                u32::from_le_bytes([chunk_data[8], chunk_data[9], chunk_data[10], chunk_data[11]]);
140            let block_align = u16::from_le_bytes([chunk_data[12], chunk_data[13]]);
141            let bits_per_sample = u16::from_le_bytes([chunk_data[14], chunk_data[15]]);
142            fmt_fields = Some((
143                encoding,
144                channels,
145                sample_rate,
146                bytes_per_sec,
147                block_align,
148                bits_per_sample,
149            ));
150        } else if chunk_id == DATA_CHUNK_ID {
151            data_chunk = Some(chunk_data.to_vec());
152            break;
153        }
154
155        cursor = chunk_end;
156        if chunk_size % 2 == 1 && cursor < bytes.len() {
157            cursor += 1;
158        }
159    }
160
161    let (encoding, channels, sample_rate, bytes_per_sec, block_align, bits_per_sample) =
162        fmt_fields.ok_or_else(|| WavError::InvalidChunk("missing required `fmt ` chunk".into()))?;
163    let data =
164        data_chunk.ok_or_else(|| WavError::InvalidChunk("missing required `data` chunk".into()))?;
165
166    Ok(Wav::new_wave(
167        wav_type,
168        WavWaveMetadata {
169            encoding,
170            channels,
171            sample_rate,
172            bytes_per_sec,
173            block_align,
174            bits_per_sample,
175        },
176        data,
177    ))
178}
179
#[cfg(test)]
mod tests {
    use super::*;
    use crate::binary::{DecodeBinary, EncodeBinary};
    use crate::wav::{
        write_wav_to_vec, write_wav_to_vec_with_options, WavAudioFormat, WavEncoding, WavWriteMode,
        WavWriteOptions,
    };

    /// Mono 16-bit PCM at 22.05 kHz with a six-byte sample payload, tagged as VO.
    fn sample_wave() -> Wav {
        Wav::new_wave(
            WavType::Vo,
            WavWaveMetadata {
                encoding: WavEncodingCode::from(WavEncoding::Pcm),
                channels: 1,
                sample_rate: 22_050,
                bytes_per_sec: 44_100,
                block_align: 2,
                bits_per_sample: 16,
            },
            vec![1, 2, 3, 4, 5, 6],
        )
    }

    /// Write options requesting a clean (unwrapped) output file.
    fn clean_options() -> WavWriteOptions {
        WavWriteOptions {
            mode: WavWriteMode::Clean,
        }
    }

    /// A 16-byte `fmt ` body with the following fields:
    /// format tag 1, mono, 22 050 Hz, 44 100 B/s, block align 2, 16 bits per sample.
    fn pcm_fmt_payload() -> Vec<u8> {
        let mut fmt = Vec::with_capacity(16);
        fmt.extend_from_slice(&1_u16.to_le_bytes()); // format tag
        fmt.extend_from_slice(&1_u16.to_le_bytes()); // channels
        fmt.extend_from_slice(&22_050_u32.to_le_bytes()); // sample rate
        fmt.extend_from_slice(&44_100_u32.to_le_bytes()); // bytes per second
        fmt.extend_from_slice(&2_u16.to_le_bytes()); // block align
        fmt.extend_from_slice(&16_u16.to_le_bytes()); // bits per sample
        fmt
    }

    /// Appends one RIFF chunk to `out`: the id, a little-endian `u32` size, and the
    /// body. After an odd-sized body it also appends the pad byte RIFF requires.
    fn push_chunk(out: &mut Vec<u8>, id: &[u8], body: &[u8]) {
        let size = u32::try_from(body.len()).expect("test chunk fits in u32");
        out.extend_from_slice(id);
        out.extend_from_slice(&size.to_le_bytes());
        out.extend_from_slice(body);
        if body.len() % 2 == 1 {
            out.push(0);
        }
    }

    /// Wraps serialized chunks in a `RIFF` header whose size field counts every
    /// byte that follows it (the `WAVE` tag plus `chunks`).
    fn riff_wave(chunks: &[u8]) -> Vec<u8> {
        let riff_size =
            u32::try_from(WAVE_MAGIC.len() + chunks.len()).expect("test RIFF fits in u32");
        let mut bytes = Vec::with_capacity(8 + WAVE_MAGIC.len() + chunks.len());
        bytes.extend_from_slice(&RIFF_MAGIC);
        bytes.extend_from_slice(&riff_size.to_le_bytes());
        bytes.extend_from_slice(&WAVE_MAGIC);
        bytes.extend_from_slice(chunks);
        bytes
    }

    #[test]
    fn roundtrip_clean_wave_payload() {
        let wav = sample_wave();
        let bytes = write_wav_to_vec_with_options(&wav, clean_options()).expect("write clean WAV");
        let parsed = read_wav_from_bytes(&bytes).expect("read clean WAV");

        assert_eq!(parsed.audio_format, WavAudioFormat::Wave);
        assert_eq!(parsed.encoding.known(), Some(WavEncoding::Pcm));
        assert_eq!(parsed.channels, 1);
        assert_eq!(parsed.sample_rate, 22_050);
        assert_eq!(parsed.data, wav.data);
        // Clean-mode write produces a plain RIFF file; reading it back gives WavType::Standard.
        assert_eq!(parsed.wav_type, WavType::Standard);
    }

    #[test]
    fn writer_is_deterministic_for_synthetic_wave() {
        let wav = sample_wave();
        let first = write_wav_to_vec(&wav).expect("first write should succeed");
        let second = write_wav_to_vec(&wav).expect("second write should succeed");
        assert_eq!(first, second);
    }

    #[test]
    fn roundtrip_game_sfx_wrapper() {
        let mut wav = sample_wave();
        wav.wav_type = WavType::Sfx;

        let bytes = write_wav_to_vec(&wav).expect("write game WAV");
        assert!(bytes.starts_with(&SFX_MAGIC));

        let parsed = read_wav_from_bytes(&bytes).expect("read SFX WAV");
        assert_eq!(parsed.wav_type, WavType::Sfx);
        assert_eq!(parsed.audio_format, WavAudioFormat::Wave);
        assert_eq!(parsed.data, wav.data);
    }

    #[test]
    fn roundtrip_game_vo_wrapper() {
        let wav = sample_wave();
        let bytes = write_wav_to_vec(&wav).expect("write game WAV");
        assert!(bytes.starts_with(&RIFF_MAGIC));
        assert_eq!(&bytes[VO_HEADER_SIZE..VO_HEADER_SIZE + 4], &RIFF_MAGIC);

        let parsed = read_wav_from_bytes(&bytes).expect("read VO WAV");
        assert_eq!(parsed.wav_type, WavType::Vo);
        assert_eq!(parsed.audio_format, WavAudioFormat::Wave);
        assert_eq!(parsed.data, wav.data);
    }

    #[test]
    fn detects_and_unwraps_mp3_in_wav() {
        let mp3_payload = vec![0x49, 0x44, 0x33, 0x04, 0x00];
        let mut wrapped = vec![0_u8; MP3_IN_WAV_HEADER_SIZE];
        wrapped[0..4].copy_from_slice(&RIFF_MAGIC);
        wrapped[4..8].copy_from_slice(&MP3_IN_WAV_RIFF_SIZE.to_le_bytes());
        wrapped[8..12].copy_from_slice(&WAVE_MAGIC);
        wrapped.extend_from_slice(&mp3_payload);

        let parsed = read_wav_from_bytes(&wrapped).expect("read MP3-in-WAV");
        assert_eq!(parsed.wav_type, WavType::Vo);
        assert_eq!(parsed.audio_format, WavAudioFormat::Mp3);
        assert_eq!(parsed.encoding.known(), Some(WavEncoding::Mp3));
        assert_eq!(parsed.data, mp3_payload);
    }

    #[test]
    fn clean_mode_mp3_writes_raw_payload() {
        let wav = Wav::new_mp3(WavType::Vo, vec![0x01, 0x02, 0x03]);
        let bytes = write_wav_to_vec_with_options(&wav, clean_options()).expect("write clean MP3");
        assert_eq!(bytes, vec![0x01, 0x02, 0x03]);
    }

    #[test]
    fn skips_unknown_chunks_and_honours_odd_size_padding() {
        let mut chunks = Vec::new();
        push_chunk(&mut chunks, &FMT_CHUNK_ID, &pcm_fmt_payload());
        // Odd-sized chunk: the reader must step over the pad byte to find `data`.
        push_chunk(&mut chunks, b"JUNK", &[7, 7, 7]);
        push_chunk(&mut chunks, &DATA_CHUNK_ID, &[1, 2, 3, 4]);

        let parsed = read_wav_from_bytes(&riff_wave(&chunks)).expect("read padded WAV");
        assert_eq!(parsed.wav_type, WavType::Standard);
        assert_eq!(parsed.audio_format, WavAudioFormat::Wave);
        assert_eq!(parsed.channels, 1);
        assert_eq!(parsed.sample_rate, 22_050);
        assert_eq!(parsed.data, vec![1, 2, 3, 4]);
    }

    #[test]
    fn rejects_non_riff_wave_payload() {
        let err = read_wav_from_bytes(b"not-riff").expect_err("must fail");
        assert!(matches!(err, WavError::InvalidHeader(_)));
    }

    #[test]
    fn rejects_truncated_header() {
        let err = read_wav_from_bytes(&RIFF_MAGIC).expect_err("must fail");
        assert!(matches!(err, WavError::InvalidHeader(_)));
    }

    #[test]
    fn rejects_sfx_wrapper_longer_than_file() {
        let mut bytes = SFX_MAGIC.to_vec();
        bytes.resize(bytes.len().max(12), 0);
        assert!(
            bytes.len() < SFX_HEADER_SIZE,
            "fixture must be shorter than the SFX header"
        );

        let err = read_wav_from_bytes(&bytes).expect_err("must fail");
        assert!(matches!(err, WavError::InvalidHeader(_)));
    }

    #[test]
    fn rejects_missing_fmt_chunk() {
        let mut chunks = Vec::new();
        push_chunk(&mut chunks, &DATA_CHUNK_ID, &[0, 1, 2, 3]);

        let err = read_wav_from_bytes(&riff_wave(&chunks)).expect_err("must fail");
        assert!(matches!(err, WavError::InvalidChunk(_)));
    }

    #[test]
    fn rejects_missing_data_chunk() {
        let mut chunks = Vec::new();
        push_chunk(&mut chunks, &FMT_CHUNK_ID, &pcm_fmt_payload());

        let err = read_wav_from_bytes(&riff_wave(&chunks)).expect_err("must fail");
        assert!(matches!(err, WavError::InvalidChunk(_)));
    }

    #[test]
    fn rejects_short_fmt_chunk() {
        let mut chunks = Vec::new();
        push_chunk(&mut chunks, &FMT_CHUNK_ID, &pcm_fmt_payload()[..14]);
        push_chunk(&mut chunks, &DATA_CHUNK_ID, &[0, 1, 2, 3]);

        let err = read_wav_from_bytes(&riff_wave(&chunks)).expect_err("must fail");
        assert!(matches!(err, WavError::InvalidChunk(_)));
    }

    #[test]
    fn rejects_truncated_chunk_payload() {
        // The RIFF size describes the complete file (WAVE tag + 8-byte header +
        // 16-byte fmt body), but the bytes actually present stop short of it.
        let mut bytes = Vec::new();
        bytes.extend_from_slice(&RIFF_MAGIC);
        bytes.extend_from_slice(&(4_u32 + 8 + 16).to_le_bytes());
        bytes.extend_from_slice(&WAVE_MAGIC);
        bytes.extend_from_slice(&FMT_CHUNK_ID);
        bytes.extend_from_slice(&(16_u32).to_le_bytes());
        bytes.extend_from_slice(&[1, 0, 1, 0]); // truncated fmt chunk

        let err = read_wav_from_bytes(&bytes).expect_err("must fail");
        assert!(matches!(err, WavError::InvalidChunk(_)));
    }

    #[test]
    fn decode_encode_traits_roundtrip() {
        let mut wav = sample_wave();
        wav.wav_type = WavType::Sfx;

        let bytes = wav.encode_binary().expect("encode");
        let decoded = Wav::decode_binary(&bytes).expect("decode");

        assert_eq!(decoded.wav_type, WavType::Sfx);
        assert_eq!(decoded.audio_format, WavAudioFormat::Wave);
        assert_eq!(decoded.data, wav.data);
    }
}