1use std::fs;
8use std::path::Path;
9
10pub fn estimate_bpm(file_path: &str) -> Option<f64> {
13 let path = Path::new(file_path);
14 let ext = path
15 .extension()
16 .and_then(|e| e.to_str())
17 .unwrap_or("")
18 .to_lowercase();
19
20 let (samples, sample_rate) = match ext.as_str() {
21 "wav" => read_wav_pcm(path)?,
22 "aiff" | "aif" => read_aiff_pcm(path)?,
23 "mp3" | "flac" | "ogg" | "m4a" | "aac" | "opus" => decode_with_symphonia(path)?,
24 _ => return None,
25 };
26
27 if samples.len() < 1024 || sample_rate == 0 {
28 return None;
29 }
30
31 detect_tempo(&samples, sample_rate)
32}
33
34pub fn read_wav_pcm_pub(path: &Path) -> Option<(Vec<f32>, u32)> {
36 read_wav_pcm(path)
37}
38pub fn read_aiff_pcm_pub(path: &Path) -> Option<(Vec<f32>, u32)> {
39 read_aiff_pcm(path)
40}
41pub fn decode_with_symphonia_pub(path: &Path) -> Option<(Vec<f32>, u32)> {
42 decode_with_symphonia(path)
43}
44
45fn read_wav_pcm(path: &Path) -> Option<(Vec<f32>, u32)> {
47 let data = fs::read(path).ok()?;
48 if data.len() < 44 || &data[0..4] != b"RIFF" || &data[8..12] != b"WAVE" {
49 return None;
50 }
51
52 let channels = u16::from_le_bytes([data[22], data[23]]) as usize;
53 let sample_rate = u32::from_le_bytes([data[24], data[25], data[26], data[27]]);
54 let bits = u16::from_le_bytes([data[34], data[35]]);
55
56 let mut offset = 12;
58 while offset + 8 < data.len() {
59 let chunk_id = &data[offset..offset + 4];
60 let chunk_size = u32::from_le_bytes([
61 data[offset + 4],
62 data[offset + 5],
63 data[offset + 6],
64 data[offset + 7],
65 ]) as usize;
66 if chunk_id == b"data" {
67 let start = offset + 8;
68 let end = (start + chunk_size).min(data.len());
69 let pcm = &data[start..end];
70 let samples = decode_pcm(pcm, bits, channels, true);
71 return Some((samples, sample_rate));
72 }
73 offset += 8 + chunk_size;
74 if !chunk_size.is_multiple_of(2) {
75 offset += 1;
76 }
77 }
78 None
79}
80
81fn read_aiff_pcm(path: &Path) -> Option<(Vec<f32>, u32)> {
83 let data = fs::read(path).ok()?;
84 if data.len() < 12 || &data[0..4] != b"FORM" || &data[8..12] != b"AIFF" {
85 return None;
86 }
87
88 let mut channels = 0u16;
89 let mut bits = 0u16;
90 let mut sample_rate = 0u32;
91 let mut ssnd_data: Option<&[u8]> = None;
92
93 let mut offset = 12;
94 while offset + 8 < data.len() {
95 let chunk_id = &data[offset..offset + 4];
96 let chunk_size = u32::from_be_bytes([
97 data[offset + 4],
98 data[offset + 5],
99 data[offset + 6],
100 data[offset + 7],
101 ]) as usize;
102
103 if chunk_id == b"COMM" && offset + 26 <= data.len() {
104 channels = u16::from_be_bytes([data[offset + 8], data[offset + 9]]);
105 bits = u16::from_be_bytes([data[offset + 14], data[offset + 15]]);
106 let exp = u16::from_be_bytes([data[offset + 16], data[offset + 17]]) as i32;
108 let mantissa = u32::from_be_bytes([
109 data[offset + 18],
110 data[offset + 19],
111 data[offset + 20],
112 data[offset + 21],
113 ]);
114 sample_rate = (mantissa as f64 * 2f64.powi(exp - 16383 - 31)).round() as u32;
115 } else if chunk_id == b"SSND" {
116 let start = offset + 8 + 8;
118 let end = (offset + 8 + chunk_size).min(data.len());
119 if start < end {
120 ssnd_data = Some(&data[start..end]);
121 }
122 }
123
124 offset += 8 + chunk_size;
125 if !chunk_size.is_multiple_of(2) {
126 offset += 1;
127 }
128 }
129
130 let pcm = ssnd_data?;
131 if channels == 0 || sample_rate == 0 {
132 return None;
133 }
134 let samples = decode_pcm(pcm, bits, channels as usize, false);
135 Some((samples, sample_rate))
136}
137
/// Converts interleaved integer PCM bytes into normalised `f32` samples,
/// keeping only the first channel of each frame.
///
/// * `bits` — sample width. 8-bit data is treated as unsigned (WAV
///   convention); 16-, 24- and 32-bit data as signed two's complement. Any
///   other width still advances `bits / 8` bytes per sample but yields `0.0`.
/// * `channels` — number of interleaved channels per frame.
/// * `little_endian` — byte order of multi-byte samples.
///
/// Bytes that do not form a whole frame are ignored, and at most
/// 44 100 × 30 frames are decoded. An empty vector is returned when the frame
/// size works out to zero bytes.
fn decode_pcm(data: &[u8], bits: u16, channels: usize, little_endian: bool) -> Vec<f32> {
    const MAX_FRAMES: usize = 44_100 * 30;

    let sample_width = usize::from(bits / 8);
    let stride = sample_width * channels;
    if stride == 0 {
        return Vec::new();
    }

    let first_channel = |frame: &[u8]| -> f32 {
        match bits {
            8 => (f32::from(frame[0]) - 128.0) / 128.0,
            16 => {
                let pair = [frame[0], frame[1]];
                let value = if little_endian {
                    i16::from_le_bytes(pair)
                } else {
                    i16::from_be_bytes(pair)
                };
                f32::from(value) / 32768.0
            }
            24 => {
                let (lo, mid, hi) = if little_endian {
                    (frame[0], frame[1], frame[2])
                } else {
                    (frame[2], frame[1], frame[0])
                };
                // Park the 24 bits at the top of an i32, then arithmetic-shift
                // back down so the sign bit is extended.
                let value =
                    ((i32::from(hi) << 24) | (i32::from(mid) << 16) | (i32::from(lo) << 8)) >> 8;
                value as f32 / 8388608.0
            }
            32 => {
                let quad = [frame[0], frame[1], frame[2], frame[3]];
                let value = if little_endian {
                    i32::from_le_bytes(quad)
                } else {
                    i32::from_be_bytes(quad)
                };
                value as f32 / 2147483648.0
            }
            _ => 0.0,
        }
    };

    data.chunks_exact(stride)
        .take(MAX_FRAMES)
        .map(first_channel)
        .collect()
}
199
200fn decode_with_symphonia(path: &Path) -> Option<(Vec<f32>, u32)> {
203 use symphonia::core::audio::SampleBuffer;
204 use symphonia::core::codecs::DecoderOptions;
205 use symphonia::core::formats::FormatOptions;
206 use symphonia::core::io::MediaSourceStream;
207 use symphonia::core::meta::MetadataOptions;
208 use symphonia::core::probe::Hint;
209
210 let file = std::fs::File::open(path).ok()?;
211 let mss = MediaSourceStream::new(Box::new(file), Default::default());
212
213 let mut hint = Hint::new();
214 if let Some(ext) = path.extension().and_then(|e| e.to_str()) {
215 hint.with_extension(ext);
216 }
217
218 let probed = symphonia::default::get_probe()
219 .format(
220 &hint,
221 mss,
222 &FormatOptions::default(),
223 &MetadataOptions::default(),
224 )
225 .ok()?;
226
227 let mut format = probed.format;
228 let track = format.default_track()?;
229 let sample_rate = track.codec_params.sample_rate?;
230 let channels = track.codec_params.channels.map(|c| c.count()).unwrap_or(1);
231 let track_id = track.id;
232
233 let mut decoder = symphonia::default::get_codecs()
234 .make(&track.codec_params, &DecoderOptions::default())
235 .ok()?;
236
237 let mut all_samples: Vec<f32> = Vec::new();
238 let max_samples = sample_rate as usize * 30 * channels;
240
241 while let Ok(packet) = format.next_packet() {
242 if packet.track_id() != track_id {
243 continue;
244 }
245 let decoded = match decoder.decode(&packet) {
246 Ok(d) => d,
247 Err(_) => continue,
248 };
249
250 let spec = *decoded.spec();
251 let duration = decoded.capacity();
252 let mut sample_buf = SampleBuffer::<f32>::new(duration as u64, spec);
253 sample_buf.copy_interleaved_ref(decoded);
254
255 let buf = sample_buf.samples();
256 if channels > 1 {
258 for chunk in buf.chunks_exact(channels) {
259 let mono: f32 = chunk.iter().sum::<f32>() / channels as f32;
260 all_samples.push(mono);
261 }
262 } else {
263 all_samples.extend_from_slice(buf);
264 }
265
266 if all_samples.len() >= max_samples {
267 break;
268 }
269 }
270
271 if all_samples.is_empty() {
272 return None;
273 }
274
275 Some((all_samples, sample_rate))
276}
277
/// Estimates the tempo (BPM) of a mono signal by autocorrelating its onset
/// envelope.
///
/// Steps:
/// 1. Split the signal into frames of `sample_rate / 43` samples (about 23 ms)
///    and compute each whole frame's RMS energy.
/// 2. Build a half-wave-rectified energy difference (the onset envelope) and
///    normalise it to a peak of 1.
/// 3. Autocorrelate the envelope, normalised by overlap length, over lags
///    covering 50–220 BPM, and take the strongest lag.
/// 4. Among that lag and its 1/2 and 1/3 harmonics, prefer the strongest
///    candidate in 80–170 BPM whose correlation exceeds a quarter of the best.
///    Otherwise choose by correlation, weighting 60–220 BPM by 1.2.
/// 5. Refine the chosen lag to sub-frame precision with parabolic
///    interpolation around the nearest autocorrelation peak.
///
/// The result is rounded to 0.1 BPM and snapped to the nearest integer when
/// within 0.15 of it. Returns `None` for input shorter than four frames, a
/// silent or constant signal, an empty lag range, or weak periodicity (best
/// correlation < 0.01).
fn detect_tempo(samples: &[f32], sample_rate: u32) -> Option<f64> {
    let hop = (sample_rate as usize) / 43;
    if hop == 0 {
        return None;
    }
    let num_frames = samples.len() / hop;
    if num_frames < 4 {
        return None;
    }

    // RMS energy of every whole frame; a trailing partial frame is ignored.
    let energy: Vec<f32> = samples
        .chunks_exact(hop)
        .map(|frame| (frame.iter().map(|s| s * s).sum::<f32>() / hop as f32).sqrt())
        .collect();

    // Only rises in energy count as onsets.
    let mut onset: Vec<f32> = std::iter::once(0.0)
        .chain(energy.windows(2).map(|pair| (pair[1] - pair[0]).max(0.0)))
        .collect();

    let max_onset = onset.iter().copied().fold(0.0f32, f32::max);
    if max_onset < 1e-8 {
        return None; // silence or a constant level
    }
    for v in &mut onset {
        *v /= max_onset;
    }

    // A lag of `lag` frames corresponds to `frame_rate * 60 / lag` BPM.
    let frame_rate = sample_rate as f64 / hop as f64;
    let bpm_of = |lag: f64| frame_rate * 60.0 / lag;
    let min_lag = ((frame_rate * 60.0 / 220.0).floor() as usize).max(1);
    let max_lag = ((frame_rate * 60.0 / 50.0).ceil() as usize).min(onset.len() - 1);
    if min_lag >= max_lag {
        return None;
    }

    // Entries below `min_lag` are never computed and stay 0.
    let mut corr_values = vec![0.0f64; max_lag + 1];
    for lag in min_lag..=max_lag {
        let n = onset.len() - lag;
        let sum: f64 = onset[..n]
            .iter()
            .zip(&onset[lag..])
            .map(|(&a, &b)| f64::from(a) * f64::from(b))
            .sum();
        corr_values[lag] = sum / n as f64;
    }

    // Strongest periodicity; on ties the shortest lag wins.
    let (best_lag, best_corr) =
        (min_lag..=max_lag).fold((min_lag, f64::NEG_INFINITY), |best, lag| {
            if corr_values[lag] > best.1 {
                (lag, corr_values[lag])
            } else {
                best
            }
        });
    if best_corr < 0.01 {
        return None;
    }

    // The best lag and its 1/2 and 1/3 harmonics (2x and 3x the tempo).
    let raw_bpm = bpm_of(best_lag as f64);
    let mut candidates: Vec<(f64, f64)> = (1..=3)
        .map(|divisor| best_lag / divisor)
        .filter(|lag| (min_lag..=max_lag).contains(lag))
        .map(|lag| (bpm_of(lag as f64), corr_values[lag]))
        .collect();
    candidates.push((raw_bpm, best_corr));

    let preferred = candidates
        .iter()
        .filter(|&&(bpm, corr)| (80.0..=170.0).contains(&bpm) && corr > best_corr * 0.25)
        .fold(None::<(f64, f64)>, |best, &cand| match best {
            Some(kept) if kept.1 >= cand.1 => Some(kept),
            _ => Some(cand),
        });

    let final_bpm = match preferred {
        Some((bpm, _)) => bpm,
        None => {
            let mut best_score = f64::NEG_INFINITY;
            let mut chosen = raw_bpm;
            for &(bpm, corr) in &candidates {
                let weight = if (60.0..=220.0).contains(&bpm) { 1.2 } else { 1.0 };
                if corr * weight > best_score {
                    best_score = corr * weight;
                    chosen = bpm;
                }
            }
            chosen
        }
    };

    let final_lag = (frame_rate * 60.0 / final_bpm).round() as usize;
    let refined_bpm =
        refine_autocorr_peak(&corr_values, final_lag, min_lag, max_lag).map_or(final_bpm, bpm_of);

    let rounded = (refined_bpm * 10.0).round() / 10.0;
    let nearest_int = rounded.round();
    if (rounded - nearest_int).abs() < 0.15 {
        Some(nearest_int)
    } else {
        Some(rounded)
    }
}

/// Refines an autocorrelation peak near `lag` to a fractional lag.
///
/// First steps to whichever of `lag - 1`, `lag`, `lag + 1` holds the largest
/// value, since a harmonic candidate obtained by integer division can land one
/// frame off the true peak. It then returns the vertex of the parabola through
/// that point and its two neighbours.
///
/// Returns `None` in these cases:
/// * `lag` is not strictly inside `(min_lag, max_lag)`;
/// * the chosen point is not a local maximum with both neighbours in range;
/// * the three points are flat.
///
/// `corr` must hold at least `max_lag + 1` values.
fn refine_autocorr_peak(corr: &[f64], lag: usize, min_lag: usize, max_lag: usize) -> Option<f64> {
    if lag <= min_lag || lag >= max_lag {
        return None;
    }
    let mut peak = lag;
    if corr[lag - 1] > corr[peak] {
        peak = lag - 1;
    }
    if corr[lag + 1] > corr[peak] {
        peak = lag + 1;
    }
    if peak <= min_lag || peak >= max_lag {
        return None;
    }

    let (prev, curr, next) = (corr[peak - 1], corr[peak], corr[peak + 1]);
    if prev > curr || next > curr {
        return None;
    }
    // Vertex of y = a*x^2 + b*x + c through (-1, prev), (0, curr), (1, next):
    // x = (prev - next) / (2 * (prev + next - 2 * curr)). It lies within ±0.5
    // for a local maximum and moves toward the larger neighbour.
    let curvature = prev + next - 2.0 * curr;
    if curvature.abs() < 1e-12 {
        return None;
    }
    Some(peak as f64 + 0.5 * (prev - next) / curvature)
}
431
#[cfg(test)]
mod tests {
    use super::*;
    use std::f32::consts::PI;

    /// Writes `samples` to `path` as a minimal 16-bit mono integer-PCM WAV:
    /// `RIFF`/`WAVE`, a 16-byte `fmt ` chunk, then `data`. Each sample is
    /// clamped to [-1, 1] and scaled by 32767. Panics if the write fails.
    fn write_wav(path: &Path, samples: &[f32], sample_rate: u32) {
        let num_samples = samples.len() as u32;
        let bits: u16 = 16;
        let channels: u16 = 1;
        let byte_rate = sample_rate * (bits as u32 / 8) * channels as u32;
        let block_align = channels * (bits / 8);
        let data_size = num_samples * (bits as u32 / 8);
        // RIFF size = everything after the 8-byte RIFF header.
        let file_size = 36 + data_size;

        let mut buf = Vec::with_capacity(44 + data_size as usize);
        buf.extend_from_slice(b"RIFF");
        buf.extend_from_slice(&file_size.to_le_bytes());
        buf.extend_from_slice(b"WAVE");
        buf.extend_from_slice(b"fmt ");
        buf.extend_from_slice(&16u32.to_le_bytes()); // fmt chunk size
        buf.extend_from_slice(&1u16.to_le_bytes()); // format tag 1 = integer PCM
        buf.extend_from_slice(&channels.to_le_bytes());
        buf.extend_from_slice(&sample_rate.to_le_bytes());
        buf.extend_from_slice(&byte_rate.to_le_bytes());
        buf.extend_from_slice(&block_align.to_le_bytes());
        buf.extend_from_slice(&bits.to_le_bytes());
        buf.extend_from_slice(b"data");
        buf.extend_from_slice(&data_size.to_le_bytes());

        for &s in samples {
            let i = (s.clamp(-1.0, 1.0) * 32767.0) as i16;
            buf.extend_from_slice(&i.to_le_bytes());
        }

        fs::write(path, buf).unwrap();
    }

    /// Synthesises a metronome: at the start of every beat, a burst lasting
    /// `sample_rate / 100` samples (10 ms) of a 1 kHz sine under a linearly
    /// decaying envelope with peak 0.8; silence elsewhere.
    fn click_track(bpm: f64, duration_secs: f64, sample_rate: u32) -> Vec<f32> {
        let num_samples = (duration_secs * sample_rate as f64) as usize;
        let samples_per_beat = (60.0 / bpm * sample_rate as f64) as usize;
        let click_len = (sample_rate as usize) / 100;
        let mut samples = vec![0.0f32; num_samples];
        let mut pos = 0;
        while pos < num_samples {
            for i in 0..click_len.min(num_samples - pos) {
                let t = i as f32 / sample_rate as f32;
                // Linear fade from 1.0 to 0 over the click.
                let envelope = 1.0 - (i as f32 / click_len as f32);
                samples[pos + i] = (2.0 * PI * 1000.0 * t).sin() * envelope * 0.8;
            }
            pos += samples_per_beat;
        }
        samples
    }

    // NOTE(review): despite its name, "mp3" is a supported extension; this
    // really exercises a missing mp3 file, which must also yield None.
    #[test]
    fn test_estimate_bpm_unsupported_format() {
        assert!(estimate_bpm("/some/file.mp3").is_none());
    }

    #[test]
    fn test_estimate_bpm_nonexistent() {
        assert!(estimate_bpm("/nonexistent/file.wav").is_none());
    }

    // A bare "RIFF" magic with no further header must be rejected.
    #[test]
    fn test_read_wav_pcm_pub_truncated_file_returns_none() {
        let tmp = std::env::temp_dir().join("test_bpm_trunc.wav");
        fs::write(&tmp, b"RIFF").unwrap();
        assert!(read_wav_pcm_pub(&tmp).is_none());
        let _ = fs::remove_file(&tmp);
    }

    // Correct "WAVE" form type but the wrong container magic.
    #[test]
    fn test_read_wav_pcm_pub_not_riff_returns_none() {
        let tmp = std::env::temp_dir().join("test_bpm_not_riff.wav");
        fs::write(&tmp, b"XXXX0000WAVE").unwrap();
        assert!(read_wav_pcm_pub(&tmp).is_none());
        let _ = fs::remove_file(&tmp);
    }

    // 4 s of digital silence has no onsets, so no tempo.
    #[test]
    fn test_estimate_bpm_silence() {
        let tmp = std::env::temp_dir().join("test_bpm_silence.wav");
        let silence = vec![0.0f32; 44100 * 4];
        write_wav(&tmp, &silence, 44100);

        let result = estimate_bpm(tmp.to_str().unwrap());
        assert!(result.is_none(), "Silence should not produce a BPM");

        let _ = fs::remove_file(&tmp);
    }

    // End-to-end tempo checks on synthetic click tracks (±8 BPM tolerance).
    #[test]
    fn test_estimate_bpm_120() {
        let tmp = std::env::temp_dir().join("test_bpm_120.wav");
        let samples = click_track(120.0, 8.0, 44100);
        write_wav(&tmp, &samples, 44100);

        let bpm = estimate_bpm(tmp.to_str().unwrap());
        assert!(bpm.is_some(), "Should detect BPM");
        let bpm = bpm.unwrap();
        assert!((bpm - 120.0).abs() < 8.0, "Expected ~120 BPM, got {bpm}");

        let _ = fs::remove_file(&tmp);
    }

    #[test]
    fn test_estimate_bpm_140() {
        let tmp = std::env::temp_dir().join("test_bpm_140.wav");
        let samples = click_track(140.0, 8.0, 44100);
        write_wav(&tmp, &samples, 44100);

        let bpm = estimate_bpm(tmp.to_str().unwrap());
        assert!(bpm.is_some());
        let bpm = bpm.unwrap();
        assert!((bpm - 140.0).abs() < 8.0, "Expected ~140 BPM, got {bpm}");

        let _ = fs::remove_file(&tmp);
    }

    #[test]
    fn test_estimate_bpm_90() {
        let tmp = std::env::temp_dir().join("test_bpm_90.wav");
        let samples = click_track(90.0, 8.0, 44100);
        write_wav(&tmp, &samples, 44100);

        let bpm = estimate_bpm(tmp.to_str().unwrap());
        assert!(bpm.is_some());
        let bpm = bpm.unwrap();
        assert!((bpm - 90.0).abs() < 8.0, "Expected ~90 BPM, got {bpm}");

        let _ = fs::remove_file(&tmp);
    }

    // 0.1 s of a constant level: enough samples to decode, but the onset
    // envelope is flat, so no tempo is reported.
    #[test]
    fn test_estimate_bpm_short_file() {
        let tmp = std::env::temp_dir().join("test_bpm_short.wav");
        let samples = vec![0.5f32; 4410];
        write_wav(&tmp, &samples, 44100);

        let result = estimate_bpm(tmp.to_str().unwrap());
        assert!(result.is_none());

        let _ = fs::remove_file(&tmp);
    }

    // 0x4000 = +0.5 and 0xC000 = -0.5 in little-endian 16-bit.
    #[test]
    fn test_decode_pcm_16bit() {
        let data = [0x00u8, 0x40, 0x00, 0xC0];
        let samples = decode_pcm(&data, 16, 1, true);
        assert_eq!(samples.len(), 2);
        assert!((samples[0] - 0.5).abs() < 0.001);
        assert!((samples[1] + 0.5).abs() < 0.001);
    }

    // Big-endian 0x4000 = +0.5.
    #[test]
    fn test_decode_pcm_16bit_be() {
        let data = [0x40u8, 0x00];
        let samples = decode_pcm(&data, 16, 1, false);
        assert_eq!(samples.len(), 1);
        assert!((samples[0] - 0.5).abs() < 0.001);
    }

    // Unsigned 8-bit: 128 is zero, 255 near +1, 0 is -1.
    #[test]
    fn test_decode_pcm_8bit() {
        let data = [128u8, 255, 0];
        let samples = decode_pcm(&data, 8, 1, true);
        assert_eq!(samples.len(), 3);
        assert!((samples[0]).abs() < 0.001);
        assert!(samples[1] > 0.9);
        assert!(samples[2] < -0.9);
    }

    // Two stereo frames; only the left channel of each is kept.
    #[test]
    fn test_decode_pcm_stereo_takes_left() {
        let data = [
            0x00u8, 0x40, 0x00, 0xC0, // frame 0: L = +0.5, R = -0.5
            0x00, 0xE0, 0x00, 0x20, // frame 1: L = -0.25, R = +0.25
        ];
        let samples = decode_pcm(&data, 16, 2, true);
        assert_eq!(samples.len(), 2);
        assert!((samples[0] - 0.5).abs() < 0.001);
    }

    // 174 BPM lies above the 80–170 "preferred" window, so this exercises the
    // weighted fallback path of the tempo picker.
    #[test]
    fn test_estimate_bpm_174() {
        let tmp = std::env::temp_dir().join("test_bpm_174.wav");
        let samples = click_track(174.0, 8.0, 44100);
        write_wav(&tmp, &samples, 44100);

        let bpm = estimate_bpm(tmp.to_str().unwrap());
        assert!(
            bpm.is_some(),
            "Should detect BPM for 174 drum-and-bass tempo"
        );
        let bpm = bpm.unwrap();
        assert!((bpm - 174.0).abs() < 8.0, "Expected ~174 BPM, got {bpm}");

        let _ = fs::remove_file(&tmp);
    }

    #[test]
    fn test_decode_pcm_32bit_le() {
        let mut data = vec![0u8; 4];
        let raw = 0x4000_0000i32;
        data.copy_from_slice(&raw.to_le_bytes());
        let samples = decode_pcm(&data, 32, 1, true);
        assert_eq!(samples.len(), 1);
        assert!(
            (samples[0] - 0.5).abs() < 1e-5,
            "32-bit LE 0x40000000 should normalize ~0.5, got {}",
            samples[0]
        );
    }

    // Little-endian 24-bit 0x400000 = +0.5.
    #[test]
    fn test_decode_pcm_24bit() {
        let data = [0x00u8, 0x00, 0x40];
        let samples = decode_pcm(&data, 24, 1, true);
        assert_eq!(samples.len(), 1);
        assert!(
            (samples[0] - 0.5).abs() < 0.001,
            "Expected ~0.5, got {}",
            samples[0]
        );
    }

    #[test]
    fn test_decode_pcm_empty() {
        let data: [u8; 0] = [];
        let samples = decode_pcm(&data, 16, 1, true);
        assert!(
            samples.is_empty(),
            "Empty input should produce empty output"
        );
    }

    // A dangling byte that cannot form a full frame is ignored.
    #[test]
    fn test_decode_pcm_16bit_truncates_trailing_odd_byte() {
        let data = [0x00u8, 0x40, 0xFF];
        let samples = decode_pcm(&data, 16, 1, true);
        assert_eq!(samples.len(), 1);
        assert!((samples[0] - 0.5).abs() < 0.001);
    }

    #[test]
    fn test_decode_pcm_32bit_big_endian() {
        let mut data = vec![0u8; 4];
        let raw = 0x4000_0000i32;
        data.copy_from_slice(&raw.to_be_bytes());
        let samples = decode_pcm(&data, 32, 1, false);
        assert_eq!(samples.len(), 1);
        assert!(
            (samples[0] - 0.5).abs() < 1e-5,
            "32-bit BE 0x40000000 → ~0.5, got {}",
            samples[0]
        );
    }

    // 12 bits → 12 / 8 = 1 byte per sample, and the width is unsupported, so
    // each byte yields one 0.0 sample.
    #[test]
    fn test_decode_pcm_unsupported_bit_depth_yields_zero_sample() {
        let data = [0xABu8, 0xCD];
        let samples = decode_pcm(&data, 12, 1, true);
        assert_eq!(samples.len(), 2);
        assert_eq!(samples[0], 0.0);
        assert_eq!(samples[1], 0.0);
    }

    #[test]
    fn test_read_wav_invalid_header() {
        let tmp = std::env::temp_dir().join("test_bpm_invalid_header.wav");
        fs::write(
            &tmp,
            b"NOT_RIFF_DATA_HERE_1234567890abcdefghijklmnopqrstuvwx",
        )
        .unwrap();

        let result = read_wav_pcm(&tmp);
        assert!(result.is_none(), "Non-RIFF data should return None");

        let _ = fs::remove_file(&tmp);
    }

    // Hand-built mono 16-bit AIFF at 44.1 kHz with 1000 silent frames.
    #[test]
    fn test_read_aiff_basic() {
        let tmp = std::env::temp_dir().join("test_bpm_aiff_basic.aiff");

        let mut data = Vec::new();
        data.extend_from_slice(b"FORM");
        // FORM size placeholder, patched once the length is known.
        data.extend_from_slice(&[0u8; 4]);
        data.extend_from_slice(b"AIFF");

        data.extend_from_slice(b"COMM");
        // COMM body: channels(2) frames(4) sampleSize(2) sampleRate(10) = 18.
        data.extend_from_slice(&18u32.to_be_bytes());
        data.extend_from_slice(&1u16.to_be_bytes()); // numChannels = 1
        data.extend_from_slice(&1000u32.to_be_bytes()); // numSampleFrames = 1000
        data.extend_from_slice(&16u16.to_be_bytes()); // sampleSize = 16
        // 44100 Hz as an 80-bit extended float (exponent 0x400E, mantissa 0xAC44...).
        data.extend_from_slice(&[0x40, 0x0E, 0xAC, 0x44, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00]);

        let pcm_bytes = 1000 * 2;
        // SSND body = offset(4) + blockSize(4) + sample data.
        let ssnd_size = 8 + pcm_bytes;
        data.extend_from_slice(b"SSND");
        data.extend_from_slice(&(ssnd_size as u32).to_be_bytes());
        data.extend_from_slice(&0u32.to_be_bytes()); // offset
        data.extend_from_slice(&0u32.to_be_bytes()); // blockSize
        data.extend_from_slice(&vec![0u8; pcm_bytes]);

        // FORM size excludes the 8-byte FORM header itself.
        let form_size = (data.len() - 8) as u32;
        data[4..8].copy_from_slice(&form_size.to_be_bytes());

        fs::write(&tmp, &data).unwrap();

        let result = read_aiff_pcm(&tmp);
        assert!(result.is_some(), "Valid AIFF should parse successfully");
        let (samples, sr) = result.unwrap();
        assert_eq!(sr, 44100);
        assert_eq!(samples.len(), 1000);
        assert!(samples.iter().all(|&s| s.abs() < 0.001));

        let _ = fs::remove_file(&tmp);
    }

    // A LIST chunk between `fmt ` and `data` must be skipped.
    #[test]
    fn test_read_wav_with_extra_chunks() {
        let tmp = std::env::temp_dir().join("test_bpm_extrachunk.wav");
        let pcm_data: Vec<u8> = (0..4410).flat_map(|_| 0i16.to_le_bytes()).collect();
        // Complete LIST chunk: header with declared size 4, then the "INFO" tag.
        let list_chunk = b"LIST\x04\x00\x00\x00INFO";
        let data_size = pcm_data.len() as u32;
        // NOTE(review): list_chunk.len() already includes its 8-byte header, so
        // the extra `+ 8` after it overstates the RIFF size by 8 bytes. This does
        // not affect read_wav_pcm, which never reads the RIFF size field.
        let file_size = 4 + 24 + 8 + list_chunk.len() as u32 + 8 + data_size;

        let mut buf = Vec::new();
        buf.extend_from_slice(b"RIFF");
        buf.extend_from_slice(&file_size.to_le_bytes());
        buf.extend_from_slice(b"WAVE");
        buf.extend_from_slice(b"fmt ");
        buf.extend_from_slice(&16u32.to_le_bytes());
        buf.extend_from_slice(&1u16.to_le_bytes()); // format tag 1 = PCM
        buf.extend_from_slice(&1u16.to_le_bytes()); // channels = 1
        buf.extend_from_slice(&44100u32.to_le_bytes()); // sample rate
        buf.extend_from_slice(&88200u32.to_le_bytes()); // byte rate
        buf.extend_from_slice(&2u16.to_le_bytes()); // block align
        buf.extend_from_slice(&16u16.to_le_bytes()); // bits per sample
        buf.extend_from_slice(list_chunk);
        buf.extend_from_slice(b"data");
        buf.extend_from_slice(&data_size.to_le_bytes());
        buf.extend_from_slice(&pcm_data);

        fs::write(&tmp, buf).unwrap();

        let result = read_wav_pcm(&tmp);
        assert!(result.is_some());
        let (samples, sr) = result.unwrap();
        assert_eq!(sr, 44100);
        assert_eq!(samples.len(), 4410);

        let _ = fs::remove_file(&tmp);
    }

    #[test]
    fn test_decode_symphonia_nonexistent() {
        let result = decode_with_symphonia(Path::new("/nonexistent/file.mp3"));
        assert!(result.is_none());
    }

    #[test]
    fn test_decode_symphonia_invalid_data() {
        let tmp = std::env::temp_dir().join("bpm_test_invalid.mp3");
        fs::write(&tmp, b"this is not an mp3 file").unwrap();
        let result = decode_with_symphonia(&tmp);
        assert!(result.is_none(), "garbage data should return None");
        let _ = fs::remove_file(&tmp);
    }

    // Symphonia must also decode a plain 16-bit mono WAV (0.1 s of 440 Hz).
    #[test]
    fn test_decode_symphonia_wav_fallback() {
        let tmp = std::env::temp_dir().join("bpm_test_sym_wav.wav");
        let sr = 44100u32;
        let n = 4410usize;
        let mut buf = Vec::new();
        let data_size = (n * 2) as u32;
        buf.extend_from_slice(b"RIFF");
        buf.extend_from_slice(&(36 + data_size).to_le_bytes());
        buf.extend_from_slice(b"WAVE");
        buf.extend_from_slice(b"fmt ");
        buf.extend_from_slice(&16u32.to_le_bytes());
        buf.extend_from_slice(&1u16.to_le_bytes()); // format tag 1 = PCM
        buf.extend_from_slice(&1u16.to_le_bytes()); // channels = 1
        buf.extend_from_slice(&sr.to_le_bytes());
        buf.extend_from_slice(&(sr * 2).to_le_bytes());
        buf.extend_from_slice(&2u16.to_le_bytes());
        buf.extend_from_slice(&16u16.to_le_bytes());
        buf.extend_from_slice(b"data");
        buf.extend_from_slice(&data_size.to_le_bytes());
        for i in 0..n {
            let t = i as f64 / sr as f64;
            let s = ((t * 440.0 * 2.0 * std::f64::consts::PI).sin() * 16000.0) as i16;
            buf.extend_from_slice(&s.to_le_bytes());
        }
        fs::write(&tmp, &buf).unwrap();

        let result = decode_with_symphonia(&tmp);
        assert!(result.is_some(), "symphonia should decode valid WAV");
        let (samples, rate) = result.unwrap();
        assert_eq!(rate, 44100);
        assert!(!samples.is_empty());

        let _ = fs::remove_file(&tmp);
    }

    // Header-only WAV whose `data` chunk declares zero bytes.
    #[test]
    fn test_estimate_bpm_zero_length() {
        let tmp = std::env::temp_dir().join("bpm_test_zero.wav");
        let mut buf = Vec::new();
        buf.extend_from_slice(b"RIFF");
        buf.extend_from_slice(&36u32.to_le_bytes());
        buf.extend_from_slice(b"WAVE");
        buf.extend_from_slice(b"fmt ");
        buf.extend_from_slice(&16u32.to_le_bytes());
        buf.extend_from_slice(&1u16.to_le_bytes());
        buf.extend_from_slice(&1u16.to_le_bytes());
        buf.extend_from_slice(&44100u32.to_le_bytes());
        buf.extend_from_slice(&(44100u32 * 2).to_le_bytes());
        buf.extend_from_slice(&2u16.to_le_bytes());
        buf.extend_from_slice(&16u16.to_le_bytes());
        buf.extend_from_slice(b"data");
        buf.extend_from_slice(&0u32.to_le_bytes());
        fs::write(&tmp, &buf).unwrap();

        let bpm = estimate_bpm(tmp.to_str().unwrap());
        assert!(bpm.is_none(), "zero-length audio should not estimate BPM");
        let _ = fs::remove_file(&tmp);
    }

    #[test]
    fn test_read_aiff_nonexistent() {
        let result = read_aiff_pcm(Path::new("/nonexistent/file.aiff"));
        assert!(result.is_none());
    }

    #[test]
    fn test_read_aiff_invalid_header() {
        let tmp = std::env::temp_dir().join("bpm_test_bad_aiff.aiff");
        fs::write(&tmp, b"not an aiff file at all").unwrap();
        let result = read_aiff_pcm(&tmp);
        assert!(result.is_none());
        let _ = fs::remove_file(&tmp);
    }

    // NOTE(review): this re-implements detect_tempo's final rounding step in a
    // closure rather than calling it, so it only documents the intended rule
    // (round to 0.1, snap to an integer when within 0.15). It cannot catch
    // drift in the real function.
    #[test]
    fn test_bpm_rounding_snaps_to_integer() {
        let round = |v: f64| -> f64 {
            let rounded = (v * 10.0).round() / 10.0;
            let nearest = rounded.round();
            if (rounded - nearest).abs() < 0.15 {
                nearest
            } else {
                rounded
            }
        };
        assert_eq!(round(120.08), 120.0);
        assert_eq!(round(119.92), 120.0);
        assert_eq!(round(128.14), 128.0);
        assert_eq!(round(128.0), 128.0);
        assert_eq!(round(127.5), 127.5);
        assert_eq!(round(135.3), 135.3);
        assert_eq!(round(99.8), 99.8);
    }
}