1use std::io::Read;
4
5use super::{
6 binary, Wav, WavEncodingCode, WavError, WavType, WavWaveMetadata, WavWrapperKind,
7 DATA_CHUNK_ID, FMT_CHUNK_ID, MP3_IN_WAV_HEADER_SIZE, MP3_IN_WAV_RIFF_SIZE, RIFF_MAGIC,
8 SFX_HEADER_SIZE, SFX_MAGIC, VO_HEADER_SIZE, WAVE_MAGIC,
9};
10
11#[cfg_attr(
13 feature = "tracing",
14 tracing::instrument(level = "debug", skip(reader))
15)]
16pub fn read_wav<R: Read>(reader: &mut R) -> Result<Wav, WavError> {
17 let mut bytes = Vec::new();
18 reader.read_to_end(&mut bytes)?;
19 read_wav_from_bytes(&bytes)
20}
21
22#[cfg_attr(
24 feature = "tracing",
25 tracing::instrument(level = "debug", skip(bytes), fields(bytes_len = bytes.len()))
26)]
27pub fn read_wav_from_bytes(bytes: &[u8]) -> Result<Wav, WavError> {
28 let (kind, skip_size) = detect_wrapper(bytes);
29 if skip_size > bytes.len() {
30 return Err(WavError::InvalidHeader(format!(
31 "wrapper skip offset {skip_size} exceeds file length {}",
32 bytes.len()
33 )));
34 }
35
36 let payload = &bytes[skip_size..];
37 let wav_type = match kind {
38 WavWrapperKind::SfxHeader => WavType::Sfx,
39 WavWrapperKind::Standard => WavType::Standard,
40 _ => WavType::Vo, };
42
43 if kind == WavWrapperKind::Mp3InWav {
44 return Ok(Wav::new_mp3(wav_type, payload.to_vec()));
45 }
46
47 parse_wave_payload(payload, wav_type)
48}
49
50fn detect_wrapper(bytes: &[u8]) -> (WavWrapperKind, usize) {
51 if bytes.len() < 12 {
52 return (WavWrapperKind::Standard, 0);
53 }
54
55 if bytes.starts_with(&SFX_MAGIC) {
56 return (WavWrapperKind::SfxHeader, SFX_HEADER_SIZE);
57 }
58
59 if bytes.starts_with(&RIFF_MAGIC) {
60 if bytes.len() >= VO_HEADER_SIZE + 4
61 && bytes[VO_HEADER_SIZE..VO_HEADER_SIZE + 4] == RIFF_MAGIC
62 {
63 return (WavWrapperKind::VoHeader, VO_HEADER_SIZE);
64 }
65
66 if let Ok(riff_size) = binary::read_u32(bytes, 4) {
67 if riff_size == MP3_IN_WAV_RIFF_SIZE {
68 return (WavWrapperKind::Mp3InWav, MP3_IN_WAV_HEADER_SIZE);
69 }
70 }
71 }
72
73 (WavWrapperKind::Standard, 0)
74}
75
76fn parse_wave_payload(bytes: &[u8], wav_type: WavType) -> Result<Wav, WavError> {
77 if bytes.len() < 12 {
78 return Err(WavError::InvalidHeader(
79 "WAVE payload shorter than 12-byte RIFF header".into(),
80 ));
81 }
82 if !bytes.starts_with(&RIFF_MAGIC) {
83 return Err(WavError::InvalidHeader(
84 "missing RIFF magic after wrapper normalization".into(),
85 ));
86 }
87 if bytes[8..12] != WAVE_MAGIC {
88 return Err(WavError::InvalidHeader(
89 "missing WAVE magic in RIFF header".into(),
90 ));
91 }
92
93 let mut cursor = 12usize;
94 let mut fmt_fields: Option<(WavEncodingCode, u16, u32, u32, u16, u16)> = None;
95 let mut data_chunk: Option<Vec<u8>> = None;
96
97 while cursor + 8 <= bytes.len() {
98 let chunk_id = [
99 bytes[cursor],
100 bytes[cursor + 1],
101 bytes[cursor + 2],
102 bytes[cursor + 3],
103 ];
104 let chunk_size_u32 = binary::read_u32(bytes, cursor + 4).map_err(|_| {
105 WavError::InvalidChunk(format!(
106 "unable to read chunk size at offset {}",
107 cursor + 4
108 ))
109 })?;
110 let chunk_size = usize::try_from(chunk_size_u32).map_err(|_| {
111 WavError::InvalidChunk(format!("chunk size {chunk_size_u32} does not fit in usize"))
112 })?;
113 cursor += 8;
114
115 let chunk_end = cursor.checked_add(chunk_size).ok_or_else(|| {
116 WavError::InvalidChunk(format!("chunk size overflow at chunk offset {cursor}"))
117 })?;
118 if chunk_end > bytes.len() {
119 return Err(WavError::InvalidChunk(format!(
120 "chunk at offset {} exceeds file bounds",
121 cursor - 8
122 )));
123 }
124 let chunk_data = &bytes[cursor..chunk_end];
125
126 if chunk_id == FMT_CHUNK_ID {
127 if chunk_data.len() < 16 {
128 return Err(WavError::InvalidChunk(
129 "`fmt ` chunk shorter than 16 bytes".into(),
130 ));
131 }
132
133 let encoding =
134 WavEncodingCode::from_raw(u16::from_le_bytes([chunk_data[0], chunk_data[1]]));
135 let channels = u16::from_le_bytes([chunk_data[2], chunk_data[3]]);
136 let sample_rate =
137 u32::from_le_bytes([chunk_data[4], chunk_data[5], chunk_data[6], chunk_data[7]]);
138 let bytes_per_sec =
139 u32::from_le_bytes([chunk_data[8], chunk_data[9], chunk_data[10], chunk_data[11]]);
140 let block_align = u16::from_le_bytes([chunk_data[12], chunk_data[13]]);
141 let bits_per_sample = u16::from_le_bytes([chunk_data[14], chunk_data[15]]);
142 fmt_fields = Some((
143 encoding,
144 channels,
145 sample_rate,
146 bytes_per_sec,
147 block_align,
148 bits_per_sample,
149 ));
150 } else if chunk_id == DATA_CHUNK_ID {
151 data_chunk = Some(chunk_data.to_vec());
152 break;
153 }
154
155 cursor = chunk_end;
156 if chunk_size % 2 == 1 && cursor < bytes.len() {
157 cursor += 1;
158 }
159 }
160
161 let (encoding, channels, sample_rate, bytes_per_sec, block_align, bits_per_sample) =
162 fmt_fields.ok_or_else(|| WavError::InvalidChunk("missing required `fmt ` chunk".into()))?;
163 let data =
164 data_chunk.ok_or_else(|| WavError::InvalidChunk("missing required `data` chunk".into()))?;
165
166 Ok(Wav::new_wave(
167 wav_type,
168 WavWaveMetadata {
169 encoding,
170 channels,
171 sample_rate,
172 bytes_per_sec,
173 block_align,
174 bits_per_sample,
175 },
176 data,
177 ))
178}
179
#[cfg(test)]
mod tests {
    use super::*;
    use crate::binary::{DecodeBinary, EncodeBinary};
    use crate::wav::{
        write_wav_to_vec, write_wav_to_vec_with_options, WavAudioFormat, WavEncoding, WavWriteMode,
        WavWriteOptions,
    };

    fn sample_wave() -> Wav {
        Wav::new_wave(
            WavType::Vo,
            WavWaveMetadata {
                encoding: WavEncodingCode::from(WavEncoding::Pcm),
                channels: 1,
                sample_rate: 22_050,
                bytes_per_sec: 44_100,
                block_align: 2,
                bits_per_sample: 16,
            },
            vec![1, 2, 3, 4, 5, 6],
        )
    }

    /// Builds an unwrapped RIFF/WAVE file from `(id, body)` chunks.
    ///
    /// Inserts a pad byte after odd-sized bodies. Fills in the RIFF size field
    /// as the total length minus the 8-byte `RIFF <size>` prefix.
    fn riff_wave(chunks: &[(&[u8], &[u8])]) -> Vec<u8> {
        let mut bytes = Vec::new();
        bytes.extend_from_slice(&RIFF_MAGIC);
        bytes.extend_from_slice(&[0; 4]); // RIFF size, patched below
        bytes.extend_from_slice(&WAVE_MAGIC);
        for (id, body) in chunks {
            bytes.extend_from_slice(id);
            let size = u32::try_from(body.len()).expect("test chunk fits in u32");
            bytes.extend_from_slice(&size.to_le_bytes());
            bytes.extend_from_slice(body);
            if body.len() % 2 == 1 {
                bytes.push(0);
            }
        }
        let riff_size = u32::try_from(bytes.len() - 8).expect("test file fits in u32");
        bytes[4..8].copy_from_slice(&riff_size.to_le_bytes());
        bytes
    }

    /// A 16-byte mono, 8 kHz, 16-bit `fmt ` body with format tag 1.
    fn fmt_body() -> Vec<u8> {
        let mut fmt = Vec::new();
        fmt.extend_from_slice(&1_u16.to_le_bytes()); // format tag
        fmt.extend_from_slice(&1_u16.to_le_bytes()); // channels
        fmt.extend_from_slice(&8_000_u32.to_le_bytes()); // sample rate
        fmt.extend_from_slice(&16_000_u32.to_le_bytes()); // bytes per second
        fmt.extend_from_slice(&2_u16.to_le_bytes()); // block align
        fmt.extend_from_slice(&16_u16.to_le_bytes()); // bits per sample
        fmt
    }

    #[test]
    fn roundtrip_clean_wave_payload() {
        let wav = sample_wave();
        let bytes = write_wav_to_vec_with_options(
            &wav,
            WavWriteOptions {
                mode: WavWriteMode::Clean,
            },
        )
        .expect("write clean WAV");
        let parsed = read_wav_from_bytes(&bytes).expect("read clean WAV");

        assert_eq!(parsed.audio_format, WavAudioFormat::Wave);
        assert_eq!(parsed.encoding.known(), Some(WavEncoding::Pcm));
        assert_eq!(parsed.channels, 1);
        assert_eq!(parsed.sample_rate, 22_050);
        assert_eq!(parsed.data, wav.data);
        assert_eq!(parsed.wav_type, WavType::Standard);
    }

    #[test]
    fn writer_is_deterministic_for_synthetic_wave() {
        let wav = sample_wave();
        let first = write_wav_to_vec(&wav).expect("first write should succeed");
        let second = write_wav_to_vec(&wav).expect("second write should succeed");
        assert_eq!(first, second);
    }

    #[test]
    fn roundtrip_game_sfx_wrapper() {
        let mut wav = sample_wave();
        wav.wav_type = WavType::Sfx;

        let bytes = write_wav_to_vec(&wav).expect("write game WAV");
        assert!(bytes.starts_with(&SFX_MAGIC));

        let parsed = read_wav_from_bytes(&bytes).expect("read SFX WAV");
        assert_eq!(parsed.wav_type, WavType::Sfx);
        assert_eq!(parsed.audio_format, WavAudioFormat::Wave);
        assert_eq!(parsed.data, wav.data);
    }

    #[test]
    fn roundtrip_game_vo_wrapper() {
        let wav = sample_wave();
        let bytes = write_wav_to_vec(&wav).expect("write game WAV");
        assert!(bytes.starts_with(&RIFF_MAGIC));
        assert_eq!(&bytes[VO_HEADER_SIZE..VO_HEADER_SIZE + 4], &RIFF_MAGIC);

        let parsed = read_wav_from_bytes(&bytes).expect("read VO WAV");
        assert_eq!(parsed.wav_type, WavType::Vo);
        assert_eq!(parsed.audio_format, WavAudioFormat::Wave);
        assert_eq!(parsed.data, wav.data);
    }

    #[test]
    fn detects_and_unwraps_mp3_in_wav() {
        let mp3_payload = vec![0x49, 0x44, 0x33, 0x04, 0x00];
        let mut wrapped = vec![0_u8; MP3_IN_WAV_HEADER_SIZE];
        wrapped[0..4].copy_from_slice(&RIFF_MAGIC);
        wrapped[4..8].copy_from_slice(&MP3_IN_WAV_RIFF_SIZE.to_le_bytes());
        wrapped[8..12].copy_from_slice(&WAVE_MAGIC);
        wrapped.extend_from_slice(&mp3_payload);

        let parsed = read_wav_from_bytes(&wrapped).expect("read MP3-in-WAV");
        assert_eq!(parsed.wav_type, WavType::Vo);
        assert_eq!(parsed.audio_format, WavAudioFormat::Mp3);
        assert_eq!(parsed.encoding.known(), Some(WavEncoding::Mp3));
        assert_eq!(parsed.data, mp3_payload);
    }

    #[test]
    fn clean_mode_mp3_writes_raw_payload() {
        let wav = Wav::new_mp3(WavType::Vo, vec![0x01, 0x02, 0x03]);
        let bytes = write_wav_to_vec_with_options(
            &wav,
            WavWriteOptions {
                mode: WavWriteMode::Clean,
            },
        )
        .expect("write clean MP3");
        assert_eq!(bytes, vec![0x01, 0x02, 0x03]);
    }

    #[test]
    fn rejects_non_riff_wave_payload() {
        let err = read_wav_from_bytes(b"not-riff").expect_err("must fail");
        assert!(matches!(err, WavError::InvalidHeader(_)));
    }

    #[test]
    fn rejects_truncated_header() {
        let err = read_wav_from_bytes(&RIFF_MAGIC).expect_err("must fail");
        assert!(matches!(err, WavError::InvalidHeader(_)));
    }

    #[test]
    fn rejects_missing_fmt_chunk() {
        let mut bytes = Vec::new();
        bytes.extend_from_slice(&RIFF_MAGIC);
        // RIFF size = "WAVE" + data chunk header + 4-byte body.
        bytes.extend_from_slice(&(4_u32 + 8 + 4).to_le_bytes());
        bytes.extend_from_slice(&WAVE_MAGIC);
        bytes.extend_from_slice(&DATA_CHUNK_ID);
        bytes.extend_from_slice(&(4_u32).to_le_bytes());
        bytes.extend_from_slice(&[0, 1, 2, 3]);

        let err = read_wav_from_bytes(&bytes).expect_err("must fail");
        assert!(matches!(err, WavError::InvalidChunk(_)));
    }

    #[test]
    fn rejects_truncated_chunk_payload() {
        let mut bytes = Vec::new();
        bytes.extend_from_slice(&RIFF_MAGIC);
        // RIFF size claims a full 16-byte `fmt ` body: "WAVE" + chunk header + 16.
        bytes.extend_from_slice(&(4_u32 + 8 + 16).to_le_bytes());
        bytes.extend_from_slice(&WAVE_MAGIC);
        bytes.extend_from_slice(&FMT_CHUNK_ID);
        bytes.extend_from_slice(&(16_u32).to_le_bytes());
        // Only 4 of the declared 16 body bytes are present.
        bytes.extend_from_slice(&[1, 0, 1, 0]);

        let err = read_wav_from_bytes(&bytes).expect_err("must fail");
        assert!(matches!(err, WavError::InvalidChunk(_)));
    }

    #[test]
    fn skips_odd_sized_chunk_and_its_pad_byte() {
        let fmt = fmt_body();
        let bytes = riff_wave(&[
            (&b"junk"[..], &[0xAA, 0xBB, 0xCC][..]),
            (&FMT_CHUNK_ID[..], &fmt[..]),
            (&DATA_CHUNK_ID[..], &[7, 8][..]),
        ]);

        let parsed = read_wav_from_bytes(&bytes).expect("read padded WAV");
        assert_eq!(parsed.wav_type, WavType::Standard);
        assert_eq!(parsed.audio_format, WavAudioFormat::Wave);
        assert_eq!(parsed.channels, 1);
        assert_eq!(parsed.sample_rate, 8_000);
        assert_eq!(parsed.data, vec![7, 8]);
    }

    #[test]
    fn rejects_fmt_chunk_shorter_than_16_bytes() {
        let bytes = riff_wave(&[
            (&FMT_CHUNK_ID[..], &[1, 0, 1, 0, 0x40, 0x1F][..]),
            (&DATA_CHUNK_ID[..], &[0, 0][..]),
        ]);

        let err = read_wav_from_bytes(&bytes).expect_err("must fail");
        assert!(matches!(err, WavError::InvalidChunk(_)));
    }

    #[test]
    fn decode_encode_traits_roundtrip() {
        let mut wav = sample_wave();
        wav.wav_type = WavType::Sfx;

        let bytes = wav.encode_binary().expect("encode");
        let decoded = Wav::decode_binary(&bytes).expect("decode");

        assert_eq!(decoded.wav_type, WavType::Sfx);
        assert_eq!(decoded.audio_format, WavAudioFormat::Wave);
        assert_eq!(decoded.data, wav.data);
    }
}