1use serde::{Deserialize, Serialize};
9use std::path::Path;
10
/// Summary audio features for one file, produced by [`compute_fingerprint`] and
/// compared with [`fingerprint_distance`] / [`find_similar`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AudioFingerprint {
    /// Path string the fingerprint was computed from; [`find_similar`] uses it to
    /// recognise (and skip) the reference itself.
    pub path: String,
    /// Root-mean-square sample amplitude over the analysed window (at most the first
    /// 10 seconds of decoded audio).
    pub rms: f64,
    /// Brightness proxy. [`compute_fingerprint`] currently sets this equal to
    /// `zero_crossing_rate`; no spectral (FFT) analysis is performed.
    pub spectral_centroid: f64,
    /// Number of sign changes between neighbouring samples divided by the number of
    /// analysed samples, so it lies in `[0, 1)`.
    pub zero_crossing_rate: f64,
    /// Share of the filter-bank energy in the low band, in `[0, 1]`. The three band
    /// shares sum to about 1 for non-silent audio and are all 0 for silence.
    pub low_band_energy: f64,
    /// Share of the filter-bank energy in the mid band, in `[0, 1]`.
    pub mid_band_energy: f64,
    /// Share of the filter-bank energy in the high band, in `[0, 1]`.
    pub high_band_energy: f64,
    /// Fraction of 1024-sample frames whose mean energy is below the average frame
    /// energy.
    pub low_energy_ratio: f64,
    /// Seconds until the 256-sample peak envelope first reaches 90% of its maximum;
    /// `1.0` when no envelope value reaches that threshold (e.g. silence).
    pub attack_time: f64,
}
24
25pub fn compute_fingerprint(file_path: &str) -> Option<AudioFingerprint> {
27 let path = Path::new(file_path);
28 let ext = path
29 .extension()
30 .and_then(|e| e.to_str())
31 .unwrap_or("")
32 .to_lowercase();
33
34 let (samples, sample_rate) = match ext.as_str() {
35 "wav" => crate::bpm::read_wav_pcm_pub(path)?,
36 "aiff" | "aif" => crate::bpm::read_aiff_pcm_pub(path)?,
37 "mp3" | "flac" | "ogg" | "m4a" | "aac" | "opus" => {
38 crate::bpm::decode_with_symphonia_pub(path)?
39 }
40 _ => return None,
41 };
42
43 if samples.len() < 1024 || sample_rate == 0 {
44 return None;
45 }
46
47 let max_samples = (sample_rate as usize) * 10;
49 let s = if samples.len() > max_samples {
50 &samples[..max_samples]
51 } else {
52 &samples
53 };
54 let n = s.len() as f64;
55 let sr = sample_rate as f64;
56
57 let rms = (s.iter().map(|&x| (x as f64) * (x as f64)).sum::<f64>() / n).sqrt();
59
60 let mut zc = 0usize;
62 for i in 1..s.len() {
63 if (s[i] >= 0.0) != (s[i - 1] >= 0.0) {
64 zc += 1;
65 }
66 }
67 let zero_crossing_rate = zc as f64 / n;
68
69 let spectral_centroid = zero_crossing_rate;
72
73 let frame_size = (sr / 50.0) as usize; let mut low_e = 0.0f64;
78 let mut mid_e = 0.0f64;
79 let mut high_e = 0.0f64;
80 let mut frame_count = 0usize;
81
82 for chunk in s.chunks(frame_size) {
83 if chunk.len() < 4 {
84 continue;
85 }
86 frame_count += 1;
87
88 let mut lp = 0.0f32;
92 let alpha_low = 0.05f32; let alpha_high = 0.7f32; let mut low_sum = 0.0f64;
95 let mut mid_sum = 0.0f64;
96 let mut high_sum = 0.0f64;
97
98 let mut lp_slow = 0.0f32;
99 for &x in chunk {
100 lp = lp + alpha_low * (x - lp); lp_slow = lp_slow + alpha_high * (x - lp_slow); let low = lp as f64;
103 let mid = (lp_slow - lp) as f64;
104 let high = (x - lp_slow) as f64;
105 low_sum += low * low;
106 mid_sum += mid * mid;
107 high_sum += high * high;
108 }
109 low_e += low_sum / chunk.len() as f64;
110 mid_e += mid_sum / chunk.len() as f64;
111 high_e += high_sum / chunk.len() as f64;
112 }
113
114 let _fc = frame_count.max(1) as f64;
115 let total_e = (low_e + mid_e + high_e).max(1e-10);
116 let low_band_energy = low_e / total_e;
117 let mid_band_energy = mid_e / total_e;
118 let high_band_energy = high_e / total_e;
119
120 let mut frame_energies = Vec::new();
122 for chunk in s.chunks(1024) {
123 let e: f64 =
124 chunk.iter().map(|&x| (x as f64) * (x as f64)).sum::<f64>() / chunk.len() as f64;
125 frame_energies.push(e);
126 }
127 let avg_energy = frame_energies.iter().sum::<f64>() / frame_energies.len().max(1) as f64;
128 let low_energy_ratio = frame_energies.iter().filter(|&&e| e < avg_energy).count() as f64
129 / frame_energies.len().max(1) as f64;
130
131 let env_size = 256;
133 let mut envelope = Vec::new();
134 for chunk in s.chunks(env_size) {
135 let peak = chunk.iter().map(|&x| x.abs()).fold(0.0f32, f32::max);
136 envelope.push(peak as f64);
137 }
138 let peak_val = envelope.iter().cloned().fold(0.0f64, f64::max).max(1e-10);
139 let attack_threshold = peak_val * 0.9;
140 let attack_time = envelope
141 .iter()
142 .position(|&e| e >= attack_threshold)
143 .map(|i| i as f64 * env_size as f64 / sr)
144 .unwrap_or(1.0);
145
146 Some(AudioFingerprint {
147 path: file_path.to_string(),
148 rms,
149 spectral_centroid,
150 zero_crossing_rate,
151 low_band_energy,
152 mid_band_energy,
153 high_band_energy,
154 low_energy_ratio,
155 attack_time,
156 })
157}
158
159pub fn fingerprint_distance(a: &AudioFingerprint, b: &AudioFingerprint) -> f64 {
161 let norm = |va: f64, vb: f64, max: f64| -> f64 {
162 let da = va / max.max(1e-10);
163 let db = vb / max.max(1e-10);
164 (da - db) * (da - db)
165 };
166
167 let d = norm(a.rms, b.rms, 1.0)
168 + norm(a.spectral_centroid, b.spectral_centroid, 0.5)
169 + norm(a.zero_crossing_rate, b.zero_crossing_rate, 0.5)
170 + norm(a.low_band_energy, b.low_band_energy, 1.0)
171 + norm(a.mid_band_energy, b.mid_band_energy, 1.0)
172 + norm(a.high_band_energy, b.high_band_energy, 1.0)
173 + norm(a.low_energy_ratio, b.low_energy_ratio, 1.0)
174 + norm(a.attack_time, b.attack_time, 2.0);
175
176 d.sqrt()
177}
178
179pub fn find_similar(
181 reference: &AudioFingerprint,
182 candidates: &[AudioFingerprint],
183 max_results: usize,
184) -> Vec<(String, f64)> {
185 let mut scored: Vec<(String, f64)> = candidates
186 .iter()
187 .filter(|c| c.path != reference.path)
188 .map(|c| (c.path.clone(), fingerprint_distance(reference, c)))
189 .collect();
190
191 scored.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
192 scored.truncate(max_results);
193 scored
194}
195
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a fingerprint from the given core features; `low_energy_ratio` and
    /// `attack_time` are fixed at 0.5 and 0.1.
    fn make_fp(
        path: &str,
        rms: f64,
        centroid: f64,
        zcr: f64,
        low: f64,
        mid: f64,
        high: f64,
    ) -> AudioFingerprint {
        AudioFingerprint {
            path: path.to_string(),
            rms,
            spectral_centroid: centroid,
            zero_crossing_rate: zcr,
            low_band_energy: low,
            mid_band_energy: mid,
            high_band_energy: high,
            low_energy_ratio: 0.5,
            attack_time: 0.1,
        }
    }

    /// Writes `samples` to `path` as a mono 16-bit PCM WAV with a canonical 44-byte header.
    fn write_pcm16_mono_wav(path: &std::path::Path, sample_rate: u32, samples: &[i16]) {
        let data_len = (samples.len() * 2) as u32;
        let mut data = Vec::with_capacity(44 + samples.len() * 2);
        data.extend_from_slice(b"RIFF");
        data.extend_from_slice(&(36 + data_len).to_le_bytes()); // RIFF chunk size
        data.extend_from_slice(b"WAVE");
        data.extend_from_slice(b"fmt ");
        data.extend_from_slice(&16u32.to_le_bytes()); // fmt chunk size
        data.extend_from_slice(&1u16.to_le_bytes()); // format tag: PCM
        data.extend_from_slice(&1u16.to_le_bytes()); // channels: mono
        data.extend_from_slice(&sample_rate.to_le_bytes());
        data.extend_from_slice(&(sample_rate * 2).to_le_bytes()); // byte rate
        data.extend_from_slice(&2u16.to_le_bytes()); // block align
        data.extend_from_slice(&16u16.to_le_bytes()); // bits per sample
        data.extend_from_slice(b"data");
        data.extend_from_slice(&data_len.to_le_bytes());
        for sample in samples {
            data.extend_from_slice(&sample.to_le_bytes());
        }
        std::fs::write(path, &data).unwrap();
    }

    /// Generates `num_samples` of a sine tone at `freq` Hz with peak `amplitude`,
    /// converted to `i16` with `as` (truncating toward zero).
    fn sine_i16(sample_rate: u32, num_samples: usize, freq: f64, amplitude: f64) -> Vec<i16> {
        (0..num_samples)
            .map(|i| {
                let t = i as f64 / sample_rate as f64;
                ((t * freq * 2.0 * std::f64::consts::PI).sin() * amplitude) as i16
            })
            .collect()
    }

    #[test]
    fn test_identical_fingerprints_zero_distance() {
        let a = make_fp("a.wav", 0.5, 0.1, 0.1, 0.6, 0.3, 0.1);
        let b = make_fp("b.wav", 0.5, 0.1, 0.1, 0.6, 0.3, 0.1);
        let d = fingerprint_distance(&a, &b);
        assert!(
            d < 0.001,
            "identical fingerprints should have ~0 distance, got {}",
            d
        );
    }

    #[test]
    fn test_fingerprint_distance_zero_rms_identical() {
        let a = make_fp("a.wav", 0.0, 0.1, 0.1, 0.5, 0.3, 0.2);
        let b = make_fp("b.wav", 0.0, 0.1, 0.1, 0.5, 0.3, 0.2);
        let d = fingerprint_distance(&a, &b);
        assert!(
            d < 1e-9,
            "RMS norm uses max(1e-10); both zero RMS should match, got {}",
            d
        );
    }

    #[test]
    fn test_different_fingerprints_nonzero_distance() {
        let kick = make_fp("kick.wav", 0.8, 0.02, 0.05, 0.9, 0.08, 0.02);
        let hihat = make_fp("hihat.wav", 0.3, 0.4, 0.4, 0.05, 0.15, 0.8);
        let d = fingerprint_distance(&kick, &hihat);
        assert!(
            d > 0.5,
            "kick and hihat should be very different, got {}",
            d
        );
    }

    #[test]
    fn test_similar_kicks_closer_than_kick_hihat() {
        let kick1 = make_fp("kick1.wav", 0.8, 0.02, 0.05, 0.9, 0.08, 0.02);
        let kick2 = make_fp("kick2.wav", 0.75, 0.03, 0.06, 0.85, 0.1, 0.05);
        let hihat = make_fp("hihat.wav", 0.3, 0.4, 0.4, 0.05, 0.15, 0.8);

        let d_kicks = fingerprint_distance(&kick1, &kick2);
        let d_kick_hihat = fingerprint_distance(&kick1, &hihat);
        assert!(
            d_kicks < d_kick_hihat,
            "similar kicks ({}) should be closer than kick-hihat ({})",
            d_kicks,
            d_kick_hihat
        );
    }

    #[test]
    fn test_find_similar_returns_sorted() {
        let reference = make_fp("ref.wav", 0.5, 0.1, 0.1, 0.5, 0.3, 0.2);
        let close = make_fp("close.wav", 0.48, 0.11, 0.11, 0.48, 0.32, 0.2);
        let far = make_fp("far.wav", 0.1, 0.4, 0.4, 0.05, 0.15, 0.8);
        let medium = make_fp("medium.wav", 0.6, 0.15, 0.15, 0.4, 0.35, 0.25);

        let results = find_similar(&reference, &[close, far, medium], 10);
        assert_eq!(results.len(), 3);
        assert_eq!(results[0].0, "close.wav");
        assert_eq!(results[2].0, "far.wav");
    }

    #[test]
    fn test_find_similar_excludes_self() {
        let a = make_fp("a.wav", 0.5, 0.1, 0.1, 0.5, 0.3, 0.2);
        let same = make_fp("a.wav", 0.5, 0.1, 0.1, 0.5, 0.3, 0.2);
        let results = find_similar(&a, &[same], 10);
        assert_eq!(results.len(), 0, "should exclude self from results");
    }

    #[test]
    fn test_find_similar_max_results() {
        let reference = make_fp("ref.wav", 0.5, 0.1, 0.1, 0.5, 0.3, 0.2);
        let candidates: Vec<_> = (0..50)
            .map(|i| {
                make_fp(
                    &format!("s{}.wav", i),
                    0.5 + i as f64 * 0.01,
                    0.1,
                    0.1,
                    0.5,
                    0.3,
                    0.2,
                )
            })
            .collect();
        let results = find_similar(&reference, &candidates, 5);
        assert_eq!(results.len(), 5);
    }

    #[test]
    fn test_compute_fingerprint_nonexistent_file() {
        let fp = compute_fingerprint("/nonexistent/file.wav");
        assert!(fp.is_none());
    }

    #[test]
    fn test_compute_fingerprint_unsupported_format() {
        let fp = compute_fingerprint("/some/file.txt");
        assert!(fp.is_none());
    }

    #[test]
    fn test_compute_fingerprint_wav() {
        let tmp = std::env::temp_dir().join("sim_test.wav");
        let sample_rate = 44100u32;
        let samples = sine_i16(sample_rate, sample_rate as usize, 440.0, 16000.0);
        write_pcm16_mono_wav(&tmp, sample_rate, &samples);

        let fp = compute_fingerprint(tmp.to_str().unwrap());
        // Clean up before asserting so a failing assertion does not leave the file behind.
        let _ = std::fs::remove_file(&tmp);

        assert!(fp.is_some(), "should compute fingerprint for valid WAV");
        let fp = fp.unwrap();
        assert!(fp.rms > 0.0, "RMS should be positive");
        assert!(fp.spectral_centroid > 0.0, "centroid should be positive");
        assert!(
            fp.spectral_centroid <= 1.0,
            "centroid should be normalized to [0,1], got {}",
            fp.spectral_centroid
        );
        assert!(fp.zero_crossing_rate <= 1.0, "ZCR should be <= 1.0");
        assert!(
            fp.low_band_energy >= 0.0 && fp.low_band_energy <= 1.0,
            "band energy should be [0,1]"
        );
    }

    #[test]
    fn test_compute_fingerprint_wav_deterministic_repeat_reads() {
        let tmp = std::env::temp_dir().join("sim_test_deterministic.wav");
        let sample_rate = 44100u32;
        let samples = sine_i16(sample_rate, sample_rate as usize, 220.0, 12000.0);
        write_pcm16_mono_wav(&tmp, sample_rate, &samples);

        let path = tmp.to_str().unwrap();
        let a = compute_fingerprint(path);
        let b = compute_fingerprint(path);
        let _ = std::fs::remove_file(&tmp);

        let a = a.expect("first read");
        let b = b.expect("second read");
        let d = fingerprint_distance(&a, &b);
        assert!(
            d < 1e-9,
            "same WAV twice should give identical features, distance={}",
            d
        );
    }

    #[test]
    fn test_fingerprint_all_zeros() {
        let a = make_fp("z.wav", 0.0, 0.0, 0.0, 0.0, 0.0, 0.0);
        let b = make_fp("z2.wav", 0.0, 0.0, 0.0, 0.0, 0.0, 0.0);
        let d = fingerprint_distance(&a, &b);
        assert!(
            d < 0.001,
            "all-zero fingerprints should have ~0 distance, got {}",
            d
        );
    }

    #[test]
    fn test_fingerprint_distance_symmetric() {
        let a = make_fp("a.wav", 0.8, 0.02, 0.05, 0.9, 0.08, 0.02);
        let b = make_fp("b.wav", 0.3, 0.4, 0.4, 0.05, 0.15, 0.8);
        let d_ab = fingerprint_distance(&a, &b);
        let d_ba = fingerprint_distance(&b, &a);
        assert!((d_ab - d_ba).abs() < 1e-10, "distance should be symmetric");
    }

    #[test]
    fn test_fingerprint_distance_finite_and_nonnegative() {
        let a = make_fp("a.wav", 0.5, 0.1, 0.1, 0.5, 0.3, 0.2);
        let b = make_fp("b.wav", 0.3, 0.4, 0.4, 0.05, 0.15, 0.8);
        let d = fingerprint_distance(&a, &b);
        assert!(
            d >= 0.0 && d.is_finite(),
            "distance must be finite and ≥0, got {}",
            d
        );
    }

    #[test]
    fn test_fingerprint_distance_attack_time_contributes_when_other_features_match() {
        let mut a = make_fp("a.wav", 0.5, 0.1, 0.1, 0.5, 0.3, 0.2);
        let mut b = make_fp("b.wav", 0.5, 0.1, 0.1, 0.5, 0.3, 0.2);
        a.attack_time = 0.05;
        b.attack_time = 1.5;
        let d = fingerprint_distance(&a, &b);
        assert!(
            d > 0.01,
            "attack_time norm uses divisor 2.0; large gap should move distance, got {}",
            d
        );
    }

    #[test]
    fn test_audio_fingerprint_json_roundtrip() {
        let fp = make_fp("x.wav", 0.4, 0.2, 0.15, 0.5, 0.25, 0.1);
        let json = serde_json::to_string(&fp).unwrap();
        let back: AudioFingerprint = serde_json::from_str(&json).unwrap();
        assert_eq!(back.path, fp.path);
        assert!((back.rms - fp.rms).abs() < 1e-9);
        assert!((back.low_band_energy - fp.low_band_energy).abs() < 1e-9);
    }

    #[test]
    fn test_find_similar_empty_candidates() {
        let reference = make_fp("ref.wav", 0.5, 0.1, 0.1, 0.5, 0.3, 0.2);
        let results = find_similar(&reference, &[], 10);
        assert!(results.is_empty());
    }

    #[test]
    fn test_find_similar_max_results_zero() {
        let reference = make_fp("ref.wav", 0.5, 0.1, 0.1, 0.5, 0.3, 0.2);
        let close = make_fp("close.wav", 0.48, 0.11, 0.11, 0.48, 0.32, 0.2);
        let results = find_similar(&reference, &[close], 0);
        assert!(results.is_empty());
    }

    #[test]
    fn test_find_similar_single_candidate() {
        let reference = make_fp("ref.wav", 0.5, 0.1, 0.1, 0.5, 0.3, 0.2);
        let candidate = make_fp("c.wav", 0.6, 0.12, 0.12, 0.45, 0.35, 0.2);
        let results = find_similar(&reference, &[candidate], 10);
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].0, "c.wav");
    }

    #[test]
    fn test_find_similar_keeps_duplicate_candidate_paths() {
        let reference = make_fp("ref.wav", 0.5, 0.1, 0.1, 0.5, 0.3, 0.2);
        let c = make_fp("dup.wav", 0.55, 0.11, 0.11, 0.48, 0.32, 0.2);
        let results = find_similar(&reference, &[c.clone(), c], 10);
        assert_eq!(results.len(), 2);
        assert_eq!(results[0].0, "dup.wav");
        assert_eq!(results[1].0, "dup.wav");
    }

    #[test]
    fn test_compute_fingerprint_silence_wav() {
        let tmp = std::env::temp_dir().join("sim_test_silence.wav");
        let sample_rate = 44100u32;
        write_pcm16_mono_wav(&tmp, sample_rate, &vec![0i16; sample_rate as usize]);

        let fp = compute_fingerprint(tmp.to_str().unwrap());
        let _ = std::fs::remove_file(&tmp);

        assert!(fp.is_some(), "should handle silent WAV");
        let fp = fp.unwrap();
        assert!(fp.rms < 0.001, "silent file should have near-zero RMS");
    }

    #[test]
    fn test_compute_fingerprint_very_short_wav() {
        let tmp = std::env::temp_dir().join("sim_test_short.wav");
        let sample_rate = 44100u32;
        let samples: Vec<i16> = (0..100usize)
            .map(|i| (i as i16).wrapping_mul(100))
            .collect();
        write_pcm16_mono_wav(&tmp, sample_rate, &samples);

        let fp = compute_fingerprint(tmp.to_str().unwrap());
        let _ = std::fs::remove_file(&tmp);

        // 100 samples is below the 1024-sample minimum, so no fingerprint is produced
        // (and the call must not panic).
        assert!(fp.is_none(), "too-short audio should yield no fingerprint");
    }
}