app_lib/
bpm.rs

//! BPM estimation via onset-strength autocorrelation.
//!
//! Reads raw PCM data from WAV and AIFF files (compressed formats are decoded
//! through symphonia), computes an energy envelope, derives onset strength,
//! then uses autocorrelation to find the dominant tempo in the 50–220 BPM range.
6
7use std::fs;
8use std::path::Path;
9
10/// Estimate BPM for an audio file. Returns None for unsupported formats,
11/// unreadable files, or when no clear tempo is detected.
12pub fn estimate_bpm(file_path: &str) -> Option<f64> {
13    let path = Path::new(file_path);
14    let ext = path
15        .extension()
16        .and_then(|e| e.to_str())
17        .unwrap_or("")
18        .to_lowercase();
19
20    let (samples, sample_rate) = match ext.as_str() {
21        "wav" => read_wav_pcm(path)?,
22        "aiff" | "aif" => read_aiff_pcm(path)?,
23        "mp3" | "flac" | "ogg" | "m4a" | "aac" | "opus" => decode_with_symphonia(path)?,
24        _ => return None,
25    };
26
27    if samples.len() < 1024 || sample_rate == 0 {
28        return None;
29    }
30
31    detect_tempo(&samples, sample_rate)
32}
33
34// Public wrappers for use by similarity module
/// Public entry point to the private WAV reader `read_wav_pcm`.
///
/// Forwards `path` unchanged and returns exactly what `read_wav_pcm` returns:
/// `Some((samples, sample_rate))` on success, `None` otherwise.
pub fn read_wav_pcm_pub(path: &Path) -> Option<(Vec<f32>, u32)> {
    read_wav_pcm(path)
}
/// Public entry point to the private AIFF reader `read_aiff_pcm`.
///
/// Forwards `path` unchanged and returns exactly what `read_aiff_pcm` returns:
/// `Some((samples, sample_rate))` on success, `None` otherwise.
pub fn read_aiff_pcm_pub(path: &Path) -> Option<(Vec<f32>, u32)> {
    read_aiff_pcm(path)
}
/// Public entry point to the private symphonia decoder `decode_with_symphonia`.
///
/// Forwards `path` unchanged and returns exactly what `decode_with_symphonia`
/// returns: `Some((samples, sample_rate))` on success, `None` otherwise.
pub fn decode_with_symphonia_pub(path: &Path) -> Option<(Vec<f32>, u32)> {
    decode_with_symphonia(path)
}
44
45/// Read WAV file and return mono f32 samples + sample rate.
46fn read_wav_pcm(path: &Path) -> Option<(Vec<f32>, u32)> {
47    let data = fs::read(path).ok()?;
48    if data.len() < 44 || &data[0..4] != b"RIFF" || &data[8..12] != b"WAVE" {
49        return None;
50    }
51
52    let channels = u16::from_le_bytes([data[22], data[23]]) as usize;
53    let sample_rate = u32::from_le_bytes([data[24], data[25], data[26], data[27]]);
54    let bits = u16::from_le_bytes([data[34], data[35]]);
55
56    // Find the data chunk (don't assume it starts at byte 44)
57    let mut offset = 12;
58    while offset + 8 < data.len() {
59        let chunk_id = &data[offset..offset + 4];
60        let chunk_size = u32::from_le_bytes([
61            data[offset + 4],
62            data[offset + 5],
63            data[offset + 6],
64            data[offset + 7],
65        ]) as usize;
66        if chunk_id == b"data" {
67            let start = offset + 8;
68            let end = (start + chunk_size).min(data.len());
69            let pcm = &data[start..end];
70            let samples = decode_pcm(pcm, bits, channels, true);
71            return Some((samples, sample_rate));
72        }
73        offset += 8 + chunk_size;
74        if !chunk_size.is_multiple_of(2) {
75            offset += 1;
76        }
77    }
78    None
79}
80
81/// Read AIFF file and return mono f32 samples + sample rate.
82fn read_aiff_pcm(path: &Path) -> Option<(Vec<f32>, u32)> {
83    let data = fs::read(path).ok()?;
84    if data.len() < 12 || &data[0..4] != b"FORM" || &data[8..12] != b"AIFF" {
85        return None;
86    }
87
88    let mut channels = 0u16;
89    let mut bits = 0u16;
90    let mut sample_rate = 0u32;
91    let mut ssnd_data: Option<&[u8]> = None;
92
93    let mut offset = 12;
94    while offset + 8 < data.len() {
95        let chunk_id = &data[offset..offset + 4];
96        let chunk_size = u32::from_be_bytes([
97            data[offset + 4],
98            data[offset + 5],
99            data[offset + 6],
100            data[offset + 7],
101        ]) as usize;
102
103        if chunk_id == b"COMM" && offset + 26 <= data.len() {
104            channels = u16::from_be_bytes([data[offset + 8], data[offset + 9]]);
105            bits = u16::from_be_bytes([data[offset + 14], data[offset + 15]]);
106            // 80-bit extended float sample rate
107            let exp = u16::from_be_bytes([data[offset + 16], data[offset + 17]]) as i32;
108            let mantissa = u32::from_be_bytes([
109                data[offset + 18],
110                data[offset + 19],
111                data[offset + 20],
112                data[offset + 21],
113            ]);
114            sample_rate = (mantissa as f64 * 2f64.powi(exp - 16383 - 31)).round() as u32;
115        } else if chunk_id == b"SSND" {
116            // SSND has 8 bytes of offset/blockSize before sample data
117            let start = offset + 8 + 8;
118            let end = (offset + 8 + chunk_size).min(data.len());
119            if start < end {
120                ssnd_data = Some(&data[start..end]);
121            }
122        }
123
124        offset += 8 + chunk_size;
125        if !chunk_size.is_multiple_of(2) {
126            offset += 1;
127        }
128    }
129
130    let pcm = ssnd_data?;
131    if channels == 0 || sample_rate == 0 {
132        return None;
133    }
134    let samples = decode_pcm(pcm, bits, channels as usize, false);
135    Some((samples, sample_rate))
136}
137
/// Decode raw interleaved integer PCM bytes into f32 samples in [-1, 1].
///
/// Only the first channel of every frame is decoded; the remaining channels
/// are skipped (no averaging takes place). Supported depths are 8-bit unsigned
/// and 16/24/32-bit signed, in the byte order selected by `little_endian`.
/// Any other depth of at least 8 bits produces one `0.0` per frame, and a
/// frame size of zero (`bits < 8` or `channels == 0`) produces an empty vector.
/// Trailing bytes that do not fill a whole frame are ignored, and at most
/// `44100 * 30` frames are decoded.
fn decode_pcm(data: &[u8], bits: u16, channels: usize, little_endian: bool) -> Vec<f32> {
    // Upper bound on decoded frames (~30 seconds at 44.1 kHz), for performance.
    const MAX_FRAMES: usize = 44100 * 30;

    let frame_size = (bits / 8) as usize * channels;
    if frame_size == 0 {
        return Vec::new();
    }

    // Converts the first channel's bytes at the start of `frame` to a sample.
    let first_channel = |frame: &[u8]| -> f32 {
        match bits {
            8 => (frame[0] as f32 - 128.0) / 128.0,
            16 => {
                let pair = [frame[0], frame[1]];
                let value = if little_endian {
                    i16::from_le_bytes(pair)
                } else {
                    i16::from_be_bytes(pair)
                };
                value as f32 / 32768.0
            }
            24 => {
                let (low, mid, high) = if little_endian {
                    (frame[0], frame[1], frame[2])
                } else {
                    (frame[2], frame[1], frame[0])
                };
                // Park the 24 bits in the top of an i32 so the arithmetic
                // right shift sign-extends them.
                let value = i32::from_le_bytes([0, low, mid, high]) >> 8;
                value as f32 / 8388608.0
            }
            32 => {
                let quad = [frame[0], frame[1], frame[2], frame[3]];
                let value = if little_endian {
                    i32::from_le_bytes(quad)
                } else {
                    i32::from_be_bytes(quad)
                };
                value as f32 / 2147483648.0
            }
            _ => 0.0,
        }
    };

    data.chunks_exact(frame_size)
        .take(MAX_FRAMES)
        .map(first_channel)
        .collect()
}
199
200/// Decode compressed audio (MP3, FLAC, OGG, M4A, AAC) via symphonia.
201/// Returns mono f32 samples + sample rate, or None on failure.
202fn decode_with_symphonia(path: &Path) -> Option<(Vec<f32>, u32)> {
203    use symphonia::core::audio::SampleBuffer;
204    use symphonia::core::codecs::DecoderOptions;
205    use symphonia::core::formats::FormatOptions;
206    use symphonia::core::io::MediaSourceStream;
207    use symphonia::core::meta::MetadataOptions;
208    use symphonia::core::probe::Hint;
209
210    let file = std::fs::File::open(path).ok()?;
211    let mss = MediaSourceStream::new(Box::new(file), Default::default());
212
213    let mut hint = Hint::new();
214    if let Some(ext) = path.extension().and_then(|e| e.to_str()) {
215        hint.with_extension(ext);
216    }
217
218    let probed = symphonia::default::get_probe()
219        .format(
220            &hint,
221            mss,
222            &FormatOptions::default(),
223            &MetadataOptions::default(),
224        )
225        .ok()?;
226
227    let mut format = probed.format;
228    let track = format.default_track()?;
229    let sample_rate = track.codec_params.sample_rate?;
230    let channels = track.codec_params.channels.map(|c| c.count()).unwrap_or(1);
231    let track_id = track.id;
232
233    let mut decoder = symphonia::default::get_codecs()
234        .make(&track.codec_params, &DecoderOptions::default())
235        .ok()?;
236
237    let mut all_samples: Vec<f32> = Vec::new();
238    // Limit to ~30 seconds for BPM detection (avoid decoding huge files)
239    let max_samples = sample_rate as usize * 30 * channels;
240
241    while let Ok(packet) = format.next_packet() {
242        if packet.track_id() != track_id {
243            continue;
244        }
245        let decoded = match decoder.decode(&packet) {
246            Ok(d) => d,
247            Err(_) => continue,
248        };
249
250        let spec = *decoded.spec();
251        let duration = decoded.capacity();
252        let mut sample_buf = SampleBuffer::<f32>::new(duration as u64, spec);
253        sample_buf.copy_interleaved_ref(decoded);
254
255        let buf = sample_buf.samples();
256        // Mix to mono
257        if channels > 1 {
258            for chunk in buf.chunks_exact(channels) {
259                let mono: f32 = chunk.iter().sum::<f32>() / channels as f32;
260                all_samples.push(mono);
261            }
262        } else {
263            all_samples.extend_from_slice(buf);
264        }
265
266        if all_samples.len() >= max_samples {
267            break;
268        }
269    }
270
271    if all_samples.is_empty() {
272        return None;
273    }
274
275    Some((all_samples, sample_rate))
276}
277
/// Detect tempo using onset-strength autocorrelation.
///
/// Steps:
/// 1. RMS energy over non-overlapping frames of `sample_rate / 43` samples
///    (about 23 ms; trailing samples that do not fill a frame are ignored).
/// 2. Onset strength = half-wave rectified frame-to-frame energy increase,
///    normalized to a maximum of 1.
/// 3. Autocorrelation of the onset curve over the lags matching 50–220 BPM.
/// 4. Candidates are the strongest lag and that lag divided by 2 and 3 (the
///    double- and triple-tempo readings). A candidate inside 80–170 BPM whose
///    correlation exceeds 25% of the peak is preferred; otherwise candidates
///    are scored with a 1.2× bonus for 60–220 BPM.
/// 5. The chosen lag is refined by parabolic interpolation when it is a local
///    maximum of the correlation curve.
///
/// Returns `None` when `sample_rate < 43`, when there are fewer than four
/// frames, when the onset curve is silent, when the input is too short to cover
/// the lag range, or when the peak correlation is below 0.01. Otherwise returns
/// the BPM rounded to one decimal, snapped to a whole number when within 0.15.
fn detect_tempo(samples: &[f32], sample_rate: u32) -> Option<f64> {
    // Frame length for the energy envelope (~1024 samples at 44.1 kHz).
    let hop = (sample_rate as usize) / 43;
    if hop == 0 {
        return None;
    }
    if samples.len() / hop < 4 {
        return None;
    }

    // RMS energy per frame.
    let energy: Vec<f32> = samples
        .chunks_exact(hop)
        .map(|frame| (frame.iter().map(|s| s * s).sum::<f32>() / hop as f32).sqrt())
        .collect();

    // Onset strength: only energy increases count; the first frame has no
    // predecessor and is defined as 0.
    let mut onset: Vec<f32> = Vec::with_capacity(energy.len());
    onset.push(0.0);
    onset.extend(energy.windows(2).map(|pair| (pair[1] - pair[0]).max(0.0)));

    let max_onset = onset.iter().copied().fold(0.0f32, f32::max);
    if max_onset < 1e-8 {
        return None; // silence
    }
    for v in &mut onset {
        *v /= max_onset;
    }

    // Lag range for 50–220 BPM, in envelope frames.
    let frame_rate = sample_rate as f64 / hop as f64;
    let min_lag = (frame_rate * 60.0 / 220.0).floor() as usize; // 220 BPM
    let max_lag = ((frame_rate * 60.0 / 50.0).ceil() as usize).min(onset.len() - 1); // 50 BPM
    if min_lag >= max_lag {
        return None;
    }

    // Mean product of the onset curve with itself shifted by `lag`; entries
    // below `min_lag` stay 0 and are never read.
    let mut corr_values = vec![0.0f64; max_lag + 1];
    for (lag, slot) in corr_values.iter_mut().enumerate().skip(min_lag) {
        let n = onset.len() - lag;
        let sum: f64 = onset[..n]
            .iter()
            .zip(&onset[lag..])
            .map(|(&a, &b)| a as f64 * b as f64)
            .sum();
        *slot = sum / n as f64;
    }

    // Strongest lag; the earliest lag wins ties.
    let (best_lag, best_corr) = corr_values
        .iter()
        .copied()
        .enumerate()
        .skip(min_lag)
        .fold((min_lag, f64::NEG_INFINITY), |best, (lag, corr)| {
            if corr > best.1 {
                (lag, corr)
            } else {
                best
            }
        });
    if best_corr < 0.01 {
        return None;
    }

    // Candidate tempos as (bpm, correlation): the raw peak first, then the
    // peak lag divided by 2 and 3 where that still lies in the lag range.
    let candidates: Vec<(f64, f64)> = (1..=3usize)
        .map(|divisor| best_lag / divisor)
        .filter(|lag| (min_lag..=max_lag).contains(lag))
        .map(|lag| (frame_rate * 60.0 / lag as f64, corr_values[lag]))
        .collect();
    let raw_bpm = frame_rate * 60.0 / best_lag as f64;

    // Prefer the strongest candidate in the 80–170 BPM range, provided its
    // correlation is more than 25% of the peak.
    let best_in_range = candidates
        .iter()
        .copied()
        .filter(|&(bpm, corr)| (80.0..=170.0).contains(&bpm) && corr > best_corr * 0.25)
        .fold(None::<(f64, f64)>, |best, cand| match best {
            Some(kept) if kept.1 >= cand.1 => Some(kept),
            _ => Some(cand),
        });

    let final_bpm = match best_in_range {
        Some((bpm, _)) => bpm,
        // Fallback: pick the highest-weighted candidate.
        None => {
            candidates
                .iter()
                .fold((raw_bpm, f64::NEG_INFINITY), |best, &(bpm, corr)| {
                    let weight = if (60.0..=220.0).contains(&bpm) {
                        1.2
                    } else {
                        1.0
                    };
                    let score = corr * weight;
                    if score > best.1 {
                        (bpm, score)
                    } else {
                        best
                    }
                })
                .0
        }
    };

    // Parabolic interpolation for sub-frame accuracy. The vertex of the
    // parabola through (-1, prev), (0, curr), (1, next) lies at
    // (next - prev) / (2 * (2 * curr - prev - next)), i.e. it moves towards
    // the larger neighbour. It is only meaningful at a local maximum, where
    // the shift is bounded by half a frame.
    let final_lag = (frame_rate * 60.0 / final_bpm).round() as usize;
    let refined_bpm = if final_lag > min_lag && final_lag < max_lag {
        let prev = corr_values[final_lag - 1];
        let curr = corr_values[final_lag];
        let next = corr_values[final_lag + 1];
        let denom = 2.0 * (2.0 * curr - prev - next);
        if curr >= prev && curr >= next && denom > 1e-12 {
            let refined_lag = final_lag as f64 + (next - prev) / denom;
            frame_rate * 60.0 / refined_lag
        } else {
            final_bpm
        }
    } else {
        final_bpm
    };

    // Round to nearest whole number if within 0.15, otherwise keep 1 decimal.
    let rounded = (refined_bpm * 10.0).round() / 10.0;
    let nearest_int = rounded.round();
    if (rounded - nearest_int).abs() < 0.15 {
        Some(nearest_int)
    } else {
        Some(rounded)
    }
}
431
432#[cfg(test)]
433mod tests {
434    use super::*;
435    use std::f32::consts::PI;
436
437    fn write_wav(path: &Path, samples: &[f32], sample_rate: u32) {
438        let num_samples = samples.len() as u32;
439        let bits: u16 = 16;
440        let channels: u16 = 1;
441        let byte_rate = sample_rate * (bits as u32 / 8) * channels as u32;
442        let block_align = channels * (bits / 8);
443        let data_size = num_samples * (bits as u32 / 8);
444        let file_size = 36 + data_size;
445
446        let mut buf = Vec::with_capacity(44 + data_size as usize);
447        buf.extend_from_slice(b"RIFF");
448        buf.extend_from_slice(&file_size.to_le_bytes());
449        buf.extend_from_slice(b"WAVE");
450        buf.extend_from_slice(b"fmt ");
451        buf.extend_from_slice(&16u32.to_le_bytes()); // chunk size
452        buf.extend_from_slice(&1u16.to_le_bytes()); // PCM
453        buf.extend_from_slice(&channels.to_le_bytes());
454        buf.extend_from_slice(&sample_rate.to_le_bytes());
455        buf.extend_from_slice(&byte_rate.to_le_bytes());
456        buf.extend_from_slice(&block_align.to_le_bytes());
457        buf.extend_from_slice(&bits.to_le_bytes());
458        buf.extend_from_slice(b"data");
459        buf.extend_from_slice(&data_size.to_le_bytes());
460
461        for &s in samples {
462            let i = (s.clamp(-1.0, 1.0) * 32767.0) as i16;
463            buf.extend_from_slice(&i.to_le_bytes());
464        }
465
466        fs::write(path, buf).unwrap();
467    }
468
469    /// Generate a click track at a specific BPM.
470    fn click_track(bpm: f64, duration_secs: f64, sample_rate: u32) -> Vec<f32> {
471        let num_samples = (duration_secs * sample_rate as f64) as usize;
472        let samples_per_beat = (60.0 / bpm * sample_rate as f64) as usize;
473        let click_len = (sample_rate as usize) / 100; // 10ms click
474
475        let mut samples = vec![0.0f32; num_samples];
476        let mut pos = 0;
477        while pos < num_samples {
478            for i in 0..click_len.min(num_samples - pos) {
479                // Short sine burst
480                let t = i as f32 / sample_rate as f32;
481                let envelope = 1.0 - (i as f32 / click_len as f32);
482                samples[pos + i] = (2.0 * PI * 1000.0 * t).sin() * envelope * 0.8;
483            }
484            pos += samples_per_beat;
485        }
486        samples
487    }
488
489    #[test]
490    fn test_estimate_bpm_unsupported_format() {
491        assert!(estimate_bpm("/some/file.mp3").is_none());
492    }
493
494    #[test]
495    fn test_estimate_bpm_nonexistent() {
496        assert!(estimate_bpm("/nonexistent/file.wav").is_none());
497    }
498
499    #[test]
500    fn test_read_wav_pcm_pub_truncated_file_returns_none() {
501        let tmp = std::env::temp_dir().join("test_bpm_trunc.wav");
502        fs::write(&tmp, b"RIFF").unwrap();
503        assert!(read_wav_pcm_pub(&tmp).is_none());
504        let _ = fs::remove_file(&tmp);
505    }
506
507    #[test]
508    fn test_read_wav_pcm_pub_not_riff_returns_none() {
509        let tmp = std::env::temp_dir().join("test_bpm_not_riff.wav");
510        fs::write(&tmp, b"XXXX0000WAVE").unwrap();
511        assert!(read_wav_pcm_pub(&tmp).is_none());
512        let _ = fs::remove_file(&tmp);
513    }
514
515    #[test]
516    fn test_estimate_bpm_silence() {
517        let tmp = std::env::temp_dir().join("test_bpm_silence.wav");
518        let silence = vec![0.0f32; 44100 * 4];
519        write_wav(&tmp, &silence, 44100);
520
521        let result = estimate_bpm(tmp.to_str().unwrap());
522        assert!(result.is_none(), "Silence should not produce a BPM");
523
524        let _ = fs::remove_file(&tmp);
525    }
526
527    #[test]
528    fn test_estimate_bpm_120() {
529        let tmp = std::env::temp_dir().join("test_bpm_120.wav");
530        let samples = click_track(120.0, 8.0, 44100);
531        write_wav(&tmp, &samples, 44100);
532
533        let bpm = estimate_bpm(tmp.to_str().unwrap());
534        assert!(bpm.is_some(), "Should detect BPM");
535        let bpm = bpm.unwrap();
536        assert!((bpm - 120.0).abs() < 8.0, "Expected ~120 BPM, got {bpm}");
537
538        let _ = fs::remove_file(&tmp);
539    }
540
541    #[test]
542    fn test_estimate_bpm_140() {
543        let tmp = std::env::temp_dir().join("test_bpm_140.wav");
544        let samples = click_track(140.0, 8.0, 44100);
545        write_wav(&tmp, &samples, 44100);
546
547        let bpm = estimate_bpm(tmp.to_str().unwrap());
548        assert!(bpm.is_some());
549        let bpm = bpm.unwrap();
550        assert!((bpm - 140.0).abs() < 8.0, "Expected ~140 BPM, got {bpm}");
551
552        let _ = fs::remove_file(&tmp);
553    }
554
555    #[test]
556    fn test_estimate_bpm_90() {
557        let tmp = std::env::temp_dir().join("test_bpm_90.wav");
558        let samples = click_track(90.0, 8.0, 44100);
559        write_wav(&tmp, &samples, 44100);
560
561        let bpm = estimate_bpm(tmp.to_str().unwrap());
562        assert!(bpm.is_some());
563        let bpm = bpm.unwrap();
564        assert!((bpm - 90.0).abs() < 8.0, "Expected ~90 BPM, got {bpm}");
565
566        let _ = fs::remove_file(&tmp);
567    }
568
569    #[test]
570    fn test_estimate_bpm_short_file() {
571        let tmp = std::env::temp_dir().join("test_bpm_short.wav");
572        // Very short file — 0.1 seconds
573        let samples = vec![0.5f32; 4410];
574        write_wav(&tmp, &samples, 44100);
575
576        // Should return None — too short to detect
577        let result = estimate_bpm(tmp.to_str().unwrap());
578        assert!(result.is_none());
579
580        let _ = fs::remove_file(&tmp);
581    }
582
583    #[test]
584    fn test_decode_pcm_16bit() {
585        // Two 16-bit LE samples: 16384 (0.5) and -16384 (-0.5)
586        let data = [0x00u8, 0x40, 0x00, 0xC0];
587        let samples = decode_pcm(&data, 16, 1, true);
588        assert_eq!(samples.len(), 2);
589        assert!((samples[0] - 0.5).abs() < 0.001);
590        assert!((samples[1] + 0.5).abs() < 0.001);
591    }
592
593    #[test]
594    fn test_decode_pcm_16bit_be() {
595        // Big-endian 0x4000 = 16384 → 0.5 normalized
596        let data = [0x40u8, 0x00];
597        let samples = decode_pcm(&data, 16, 1, false);
598        assert_eq!(samples.len(), 1);
599        assert!((samples[0] - 0.5).abs() < 0.001);
600    }
601
602    #[test]
603    fn test_decode_pcm_8bit() {
604        let data = [128u8, 255, 0]; // 0.0, ~1.0, ~-1.0
605        let samples = decode_pcm(&data, 8, 1, true);
606        assert_eq!(samples.len(), 3);
607        assert!((samples[0]).abs() < 0.01);
608        assert!(samples[1] > 0.9);
609        assert!(samples[2] < -0.9);
610    }
611
612    #[test]
613    fn test_decode_pcm_stereo_takes_left() {
614        // Two stereo frames, 16-bit LE: (L=0.5, R=-0.5), (L=-0.25, R=0.25)
615        let data = [
616            0x00u8, 0x40, 0x00, 0xC0, // frame 1
617            0x00, 0xE0, 0x00, 0x20, // frame 2
618        ];
619        let samples = decode_pcm(&data, 16, 2, true);
620        assert_eq!(samples.len(), 2);
621        assert!((samples[0] - 0.5).abs() < 0.001);
622    }
623
624    #[test]
625    fn test_estimate_bpm_174() {
626        let tmp = std::env::temp_dir().join("test_bpm_174.wav");
627        let samples = click_track(174.0, 8.0, 44100);
628        write_wav(&tmp, &samples, 44100);
629
630        let bpm = estimate_bpm(tmp.to_str().unwrap());
631        assert!(
632            bpm.is_some(),
633            "Should detect BPM for 174 drum-and-bass tempo"
634        );
635        let bpm = bpm.unwrap();
636        assert!((bpm - 174.0).abs() < 8.0, "Expected ~174 BPM, got {bpm}");
637
638        let _ = fs::remove_file(&tmp);
639    }
640
641    #[test]
642    fn test_decode_pcm_32bit_le() {
643        let mut data = vec![0u8; 4];
644        let raw = 0x4000_0000i32;
645        data.copy_from_slice(&raw.to_le_bytes());
646        let samples = decode_pcm(&data, 32, 1, true);
647        assert_eq!(samples.len(), 1);
648        assert!(
649            (samples[0] - 0.5).abs() < 1e-5,
650            "32-bit LE 0x40000000 should normalize ~0.5, got {}",
651            samples[0]
652        );
653    }
654
655    #[test]
656    fn test_decode_pcm_24bit() {
657        // 24-bit LE sample: bytes [0x00, 0x00, 0x40] = 0x400000 as signed = 4194304
658        // normalized: 4194304 / 8388608 = 0.5
659        let data = [0x00u8, 0x00, 0x40];
660        let samples = decode_pcm(&data, 24, 1, true);
661        assert_eq!(samples.len(), 1);
662        assert!(
663            (samples[0] - 0.5).abs() < 0.001,
664            "Expected ~0.5, got {}",
665            samples[0]
666        );
667    }
668
669    #[test]
670    fn test_decode_pcm_empty() {
671        let data: [u8; 0] = [];
672        let samples = decode_pcm(&data, 16, 1, true);
673        assert!(
674            samples.is_empty(),
675            "Empty input should produce empty output"
676        );
677    }
678
679    #[test]
680    fn test_decode_pcm_16bit_truncates_trailing_odd_byte() {
681        let data = [0x00u8, 0x40, 0xFF];
682        let samples = decode_pcm(&data, 16, 1, true);
683        assert_eq!(samples.len(), 1);
684        assert!((samples[0] - 0.5).abs() < 0.001);
685    }
686
687    #[test]
688    fn test_decode_pcm_32bit_big_endian() {
689        let mut data = vec![0u8; 4];
690        let raw = 0x4000_0000i32;
691        data.copy_from_slice(&raw.to_be_bytes());
692        let samples = decode_pcm(&data, 32, 1, false);
693        assert_eq!(samples.len(), 1);
694        assert!(
695            (samples[0] - 0.5).abs() < 1e-5,
696            "32-bit BE 0x40000000 → ~0.5, got {}",
697            samples[0]
698        );
699    }
700
701    #[test]
702    fn test_decode_pcm_unsupported_bit_depth_yields_zero_sample() {
703        let data = [0xABu8, 0xCD];
704        let samples = decode_pcm(&data, 12, 1, true);
705        assert_eq!(samples.len(), 2);
706        assert_eq!(samples[0], 0.0);
707        assert_eq!(samples[1], 0.0);
708    }
709
710    #[test]
711    fn test_read_wav_invalid_header() {
712        let tmp = std::env::temp_dir().join("test_bpm_invalid_header.wav");
713        // Write data that is NOT a valid RIFF header
714        fs::write(
715            &tmp,
716            b"NOT_RIFF_DATA_HERE_1234567890abcdefghijklmnopqrstuvwx",
717        )
718        .unwrap();
719
720        let result = read_wav_pcm(&tmp);
721        assert!(result.is_none(), "Non-RIFF data should return None");
722
723        let _ = fs::remove_file(&tmp);
724    }
725
726    #[test]
727    fn test_read_aiff_basic() {
728        let tmp = std::env::temp_dir().join("test_bpm_aiff_basic.aiff");
729
730        let mut data = Vec::new();
731        data.extend_from_slice(b"FORM");
732        // total size placeholder — filled after building
733        data.extend_from_slice(&[0u8; 4]);
734        data.extend_from_slice(b"AIFF");
735
736        // COMM chunk: 18 bytes
737        data.extend_from_slice(b"COMM");
738        data.extend_from_slice(&18u32.to_be_bytes());
739        data.extend_from_slice(&1u16.to_be_bytes()); // channels = 1
740        data.extend_from_slice(&1000u32.to_be_bytes()); // num sample frames
741        data.extend_from_slice(&16u16.to_be_bytes()); // bits per sample
742                                                      // 80-bit extended for 44100 Hz:
743                                                      // exponent = 16383 + 15 = 16398 = 0x400E
744                                                      // mantissa high 32 bits = 44100 << 16 = 0xAC44_0000
745        data.extend_from_slice(&[0x40, 0x0E, 0xAC, 0x44, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00]);
746
747        // SSND chunk: 8 bytes header (offset + blockSize) + PCM data
748        let pcm_bytes = 1000 * 2; // 1000 frames, 16-bit mono
749        let ssnd_size = 8 + pcm_bytes;
750        data.extend_from_slice(b"SSND");
751        data.extend_from_slice(&(ssnd_size as u32).to_be_bytes());
752        data.extend_from_slice(&0u32.to_be_bytes()); // offset
753        data.extend_from_slice(&0u32.to_be_bytes()); // blockSize
754                                                     // 1000 frames of silence (big-endian 16-bit zeros)
755        data.extend_from_slice(&vec![0u8; pcm_bytes]);
756
757        // Fix FORM size
758        let form_size = (data.len() - 8) as u32;
759        data[4..8].copy_from_slice(&form_size.to_be_bytes());
760
761        fs::write(&tmp, &data).unwrap();
762
763        let result = read_aiff_pcm(&tmp);
764        assert!(result.is_some(), "Valid AIFF should parse successfully");
765        let (samples, sr) = result.unwrap();
766        assert_eq!(sr, 44100);
767        assert_eq!(samples.len(), 1000);
768        // All samples should be zero (silence)
769        assert!(samples.iter().all(|&s| s.abs() < 0.001));
770
771        let _ = fs::remove_file(&tmp);
772    }
773
774    #[test]
775    fn test_read_wav_with_extra_chunks() {
776        // WAV with a LIST chunk before data
777        let tmp = std::env::temp_dir().join("test_bpm_extrachunk.wav");
778        let pcm_data: Vec<u8> = (0..4410).flat_map(|_| 0i16.to_le_bytes()).collect();
779        let list_chunk = b"LIST\x04\x00\x00\x00INFO";
780        let data_size = pcm_data.len() as u32;
781        let file_size = 4 + 24 + 8 + list_chunk.len() as u32 + 8 + data_size;
782
783        let mut buf = Vec::new();
784        buf.extend_from_slice(b"RIFF");
785        buf.extend_from_slice(&file_size.to_le_bytes());
786        buf.extend_from_slice(b"WAVE");
787        buf.extend_from_slice(b"fmt ");
788        buf.extend_from_slice(&16u32.to_le_bytes());
789        buf.extend_from_slice(&1u16.to_le_bytes()); // PCM
790        buf.extend_from_slice(&1u16.to_le_bytes()); // mono
791        buf.extend_from_slice(&44100u32.to_le_bytes());
792        buf.extend_from_slice(&88200u32.to_le_bytes()); // byte rate
793        buf.extend_from_slice(&2u16.to_le_bytes()); // block align
794        buf.extend_from_slice(&16u16.to_le_bytes()); // bits
795        buf.extend_from_slice(list_chunk);
796        buf.extend_from_slice(b"data");
797        buf.extend_from_slice(&data_size.to_le_bytes());
798        buf.extend_from_slice(&pcm_data);
799
800        fs::write(&tmp, buf).unwrap();
801
802        let result = read_wav_pcm(&tmp);
803        assert!(result.is_some());
804        let (samples, sr) = result.unwrap();
805        assert_eq!(sr, 44100);
806        assert_eq!(samples.len(), 4410);
807
808        let _ = fs::remove_file(&tmp);
809    }
810
811    #[test]
812    fn test_decode_symphonia_nonexistent() {
813        let result = decode_with_symphonia(Path::new("/nonexistent/file.mp3"));
814        assert!(result.is_none());
815    }
816
817    #[test]
818    fn test_decode_symphonia_invalid_data() {
819        let tmp = std::env::temp_dir().join("bpm_test_invalid.mp3");
820        fs::write(&tmp, b"this is not an mp3 file").unwrap();
821        let result = decode_with_symphonia(&tmp);
822        assert!(result.is_none(), "garbage data should return None");
823        let _ = fs::remove_file(&tmp);
824    }
825
826    #[test]
827    fn test_decode_symphonia_wav_fallback() {
828        // symphonia should be able to decode a valid WAV too
829        let tmp = std::env::temp_dir().join("bpm_test_sym_wav.wav");
830        let sr = 44100u32;
831        let n = 4410usize; // 0.1s
832        let mut buf = Vec::new();
833        let data_size = (n * 2) as u32;
834        buf.extend_from_slice(b"RIFF");
835        buf.extend_from_slice(&(36 + data_size).to_le_bytes());
836        buf.extend_from_slice(b"WAVE");
837        buf.extend_from_slice(b"fmt ");
838        buf.extend_from_slice(&16u32.to_le_bytes());
839        buf.extend_from_slice(&1u16.to_le_bytes()); // PCM
840        buf.extend_from_slice(&1u16.to_le_bytes()); // mono
841        buf.extend_from_slice(&sr.to_le_bytes());
842        buf.extend_from_slice(&(sr * 2).to_le_bytes());
843        buf.extend_from_slice(&2u16.to_le_bytes());
844        buf.extend_from_slice(&16u16.to_le_bytes());
845        buf.extend_from_slice(b"data");
846        buf.extend_from_slice(&data_size.to_le_bytes());
847        for i in 0..n {
848            let t = i as f64 / sr as f64;
849            let s = ((t * 440.0 * 2.0 * std::f64::consts::PI).sin() * 16000.0) as i16;
850            buf.extend_from_slice(&s.to_le_bytes());
851        }
852        fs::write(&tmp, &buf).unwrap();
853
854        let result = decode_with_symphonia(&tmp);
855        assert!(result.is_some(), "symphonia should decode valid WAV");
856        let (samples, rate) = result.unwrap();
857        assert_eq!(rate, 44100);
858        assert!(!samples.is_empty());
859
860        let _ = fs::remove_file(&tmp);
861    }
862
863    #[test]
864    fn test_estimate_bpm_zero_length() {
865        // WAV with zero data samples
866        let tmp = std::env::temp_dir().join("bpm_test_zero.wav");
867        let mut buf = Vec::new();
868        buf.extend_from_slice(b"RIFF");
869        buf.extend_from_slice(&36u32.to_le_bytes());
870        buf.extend_from_slice(b"WAVE");
871        buf.extend_from_slice(b"fmt ");
872        buf.extend_from_slice(&16u32.to_le_bytes());
873        buf.extend_from_slice(&1u16.to_le_bytes());
874        buf.extend_from_slice(&1u16.to_le_bytes());
875        buf.extend_from_slice(&44100u32.to_le_bytes());
876        buf.extend_from_slice(&(44100u32 * 2).to_le_bytes());
877        buf.extend_from_slice(&2u16.to_le_bytes());
878        buf.extend_from_slice(&16u16.to_le_bytes());
879        buf.extend_from_slice(b"data");
880        buf.extend_from_slice(&0u32.to_le_bytes());
881        fs::write(&tmp, &buf).unwrap();
882
883        let bpm = estimate_bpm(tmp.to_str().unwrap());
884        assert!(bpm.is_none(), "zero-length audio should not estimate BPM");
885        let _ = fs::remove_file(&tmp);
886    }
887
888    #[test]
889    fn test_read_aiff_nonexistent() {
890        let result = read_aiff_pcm(Path::new("/nonexistent/file.aiff"));
891        assert!(result.is_none());
892    }
893
894    #[test]
895    fn test_read_aiff_invalid_header() {
896        let tmp = std::env::temp_dir().join("bpm_test_bad_aiff.aiff");
897        fs::write(&tmp, b"not an aiff file at all").unwrap();
898        let result = read_aiff_pcm(&tmp);
899        assert!(result.is_none());
900        let _ = fs::remove_file(&tmp);
901    }
902
903    #[test]
904    fn test_bpm_rounding_snaps_to_integer() {
905        // Values within 0.15 of integer should snap
906        let round = |v: f64| -> f64 {
907            let rounded = (v * 10.0).round() / 10.0;
908            let nearest = rounded.round();
909            if (rounded - nearest).abs() < 0.15 {
910                nearest
911            } else {
912                rounded
913            }
914        };
915        assert_eq!(round(120.08), 120.0);
916        assert_eq!(round(119.92), 120.0);
917        assert_eq!(round(128.14), 128.0);
918        assert_eq!(round(128.0), 128.0);
919        // Values beyond 0.15 keep decimal
920        assert_eq!(round(127.5), 127.5);
921        assert_eq!(round(135.3), 135.3);
922        assert_eq!(round(99.8), 99.8);
923    }
924}