1use std::path::Path;
8
/// Display names of the twelve pitch classes, indexed by chroma bin
/// (0 = C, ascending by semitone up to 11 = B). Sharps are used for the
/// accidentals.
const NOTE_NAMES: [&str; 12] = [
    "C", "C#", "D", "D#",
    "E", "F", "F#", "G",
    "G#", "A", "A#", "B",
];
13
/// Weight of each scale degree (in semitones above the tonic, index 0 = tonic)
/// used when correlating a chroma vector against a major key.
// NOTE(review): the values look like the Krumhansl-Kessler major-key profile;
// confirm provenance before citing it in user-facing docs.
const MAJOR_PROFILE: [f64; 12] = [
    6.35, 2.23, 3.48, 2.33, 4.38, 4.09, // tonic .. +5 semitones
    2.52, 5.19, 2.39, 3.66, 2.29, 2.88, // +6 .. +11 semitones
];
18
/// Weight of each scale degree (in semitones above the tonic, index 0 = tonic)
/// used when correlating a chroma vector against a minor key.
// NOTE(review): the values look like the Krumhansl-Kessler minor-key profile;
// confirm provenance before citing it in user-facing docs.
const MINOR_PROFILE: [f64; 12] = [
    6.33, 2.68, 3.52, 5.38, 2.60, 3.53, // tonic .. +5 semitones
    2.54, 4.75, 3.98, 2.69, 3.34, 3.17, // +6 .. +11 semitones
];
23
24pub fn detect_key(file_path: &str) -> Option<String> {
27 let path = Path::new(file_path);
28 let ext = path
29 .extension()
30 .and_then(|e| e.to_str())
31 .unwrap_or("")
32 .to_lowercase();
33
34 let (samples, sample_rate) = match ext.as_str() {
35 "wav" => crate::bpm::read_wav_pcm_pub(path)?,
36 "aiff" | "aif" => crate::bpm::read_aiff_pcm_pub(path)?,
37 "mp3" | "flac" | "ogg" | "m4a" | "aac" | "opus" => {
38 crate::bpm::decode_with_symphonia_pub(path)?
39 }
40 _ => return None,
41 };
42
43 if samples.len() < 4096 || sample_rate == 0 {
44 return None;
45 }
46
47 let max_samples = (sample_rate as usize) * 30;
49 let s = if samples.len() > max_samples {
50 &samples[..max_samples]
51 } else {
52 &samples
53 };
54
55 let chroma = compute_chromagram(s, sample_rate);
56
57 let total_energy: f64 = chroma.iter().sum();
59 if total_energy < 1e-10 {
60 return None;
61 }
62
63 match_key_profile(&chroma)
64}
65
66fn compute_chromagram(samples: &[f32], sample_rate: u32) -> [f64; 12] {
72 let sr = sample_rate as f64;
73 let mut chroma = [0.0f64; 12];
74
75 let frame_size = 4096usize;
78 let hop = frame_size / 2;
79 let num_frames = (samples.len().saturating_sub(frame_size)) / hop;
80
81 if num_frames == 0 {
82 return chroma;
83 }
84
85 let base_freq = 32.7032; let mut targets: Vec<(usize, f64)> = Vec::new(); for octave in 0..7 {
90 for note in 0..12 {
91 let freq = base_freq * 2.0f64.powi(octave) * 2.0f64.powf(note as f64 / 12.0);
92 if freq < sr / 2.0 && freq > 20.0 {
93 targets.push((note, freq));
94 }
95 }
96 }
97
98 let hann: Vec<f64> = (0..frame_size)
100 .map(|i| {
101 0.5 * (1.0 - (2.0 * std::f64::consts::PI * i as f64 / (frame_size - 1) as f64).cos())
102 })
103 .collect();
104 let mut windowed = vec![0.0f64; frame_size];
105
106 for frame_idx in 0..num_frames {
108 let start = frame_idx * hop;
109 let end = (start + frame_size).min(samples.len());
110 let n = end - start;
111
112 for i in 0..n {
114 windowed[i] = samples[start + i] as f64 * hann[i];
115 }
116
117 for &(chroma_bin, freq) in &targets {
118 let power = goertzel(&windowed[..n], freq, sr);
119 chroma[chroma_bin] += power;
120 }
121 }
122
123 let max_val = chroma.iter().cloned().fold(0.0f64, f64::max);
125 if max_val > 1e-10 {
126 for c in chroma.iter_mut() {
127 *c /= max_val;
128 }
129 }
130
131 chroma
132}
133
/// Returns the signal power of `samples` at the DFT bin closest to
/// `target_freq`, computed with the Goertzel recurrence.
///
/// `sample_rate` and `target_freq` are in the same unit (Hz). The target is
/// rounded to the nearest integer bin of an `n`-point DFT, where `n` is
/// `samples.len()`.
///
/// Degenerate input yields `0.0` instead of NaN: an empty slice, or a bin
/// index that is not finite (e.g. `sample_rate` of zero or a NaN frequency).
fn goertzel(samples: &[f64], target_freq: f64, sample_rate: f64) -> f64 {
    let n = samples.len();
    if n == 0 {
        // Without this guard the bin angle below is 0/0 and the result is NaN.
        return 0.0;
    }

    let k = (target_freq * n as f64 / sample_rate).round();
    if !k.is_finite() {
        return 0.0;
    }

    let w = 2.0 * std::f64::consts::PI * k / n as f64;
    let coeff = 2.0 * w.cos();

    // Second-order recurrence: s[t] = x[t] + coeff * s[t-1] - s[t-2].
    let mut s1 = 0.0f64;
    let mut s2 = 0.0f64;
    for &sample in samples {
        let s0 = sample + coeff * s1 - s2;
        s2 = s1;
        s1 = s0;
    }

    // Squared magnitude of the bin, from the final two state values.
    s1 * s1 + s2 * s2 - coeff * s1 * s2
}
154
155fn match_key_profile(chroma: &[f64; 12]) -> Option<String> {
158 let mut best_key = String::new();
159 let mut best_corr = f64::NEG_INFINITY;
160
161 for (root, note) in NOTE_NAMES.iter().enumerate() {
162 let major_corr = profile_correlation(chroma, &MAJOR_PROFILE, root);
164 if major_corr > best_corr {
165 best_corr = major_corr;
166 best_key = format!("{note} Major");
167 }
168
169 let minor_corr = profile_correlation(chroma, &MINOR_PROFILE, root);
170 if minor_corr > best_corr {
171 best_corr = minor_corr;
172 best_key = format!("{note} Minor");
173 }
174 }
175
176 if best_key.is_empty() || best_corr < 0.0 {
177 return None;
178 }
179
180 Some(best_key)
181}
182
/// Pearson correlation between `chroma`, rotated so that pitch class `root`
/// lines up with index 0, and a 12-element key `profile`.
///
/// Returns a value in roughly `[-1.0, 1.0]`, or `0.0` when either input has
/// (numerically) no variance. Non-finite input propagates as NaN.
///
/// The sums are taken over mean-centred values. The single-pass form
/// `n·Σx² − (Σx)²` cancels catastrophically for near-constant input and can
/// turn slightly negative, which makes the square root NaN.
fn profile_correlation(chroma: &[f64; 12], profile: &[f64; 12], root: usize) -> f64 {
    const N: f64 = 12.0;

    // Reduce first so that `i + shift` cannot overflow for huge `root`.
    let shift = root % 12;

    let mean_x = chroma.iter().sum::<f64>() / N;
    let mean_y = profile.iter().sum::<f64>() / N;

    let mut sxy = 0.0;
    let mut sxx = 0.0;
    let mut syy = 0.0;
    for (i, &y) in profile.iter().enumerate() {
        let dx = chroma[(i + shift) % 12] - mean_x;
        let dy = y - mean_y;
        sxy += dx * dy;
        sxx += dx * dx;
        syy += dy * dy;
    }

    // Scaled by N so the zero-variance threshold keeps the same magnitude as
    // the raw-sum formulation `sqrt((nΣx² − (Σx)²)(nΣy² − (Σy)²))`.
    let denominator = N * (sxx * syy).sqrt();
    if denominator < 1e-10 {
        return 0.0;
    }

    N * sxy / denominator
}
211
#[cfg(test)]
mod tests {
    //! Unit tests for the key detector: the Goertzel filter, the chromagram,
    //! the profile correlation and the file-level `detect_key` entry point.

    use super::*;

    /// WAV fixture in the system temp directory.
    ///
    /// The file name includes the process id so concurrent test runs do not
    /// overwrite each other's fixtures, and the file is removed on drop so a
    /// failing assertion does not leave it behind.
    struct TempWav {
        path: std::path::PathBuf,
    }

    impl TempWav {
        /// Writes `samples` as a WAV file tagged with `tag` and returns its guard.
        fn create(tag: &str, samples: &[f32], sample_rate: u32) -> Self {
            let file_name = format!("key_test_{}_{}.wav", tag, std::process::id());
            let path = std::env::temp_dir().join(file_name);
            write_test_wav(&path, samples, sample_rate);
            Self { path }
        }

        /// Path of the fixture as a string slice, as `detect_key` expects.
        fn as_str(&self) -> &str {
            self.path.to_str().unwrap()
        }
    }

    impl Drop for TempWav {
        fn drop(&mut self) {
            let _ = std::fs::remove_file(&self.path);
        }
    }

    #[test]
    fn test_goertzel_440hz() {
        let sr = 44100.0;
        let n = 4096;
        let samples: Vec<f64> = (0..n)
            .map(|i| (2.0 * std::f64::consts::PI * 440.0 * i as f64 / sr).sin())
            .collect();

        let power_440 = goertzel(&samples, 440.0, sr);
        let power_300 = goertzel(&samples, 300.0, sr);
        assert!(
            power_440 > power_300 * 10.0,
            "440Hz should have much more energy than 300Hz: 440={}, 300={}",
            power_440,
            power_300
        );
    }

    #[test]
    fn test_goertzel_261hz_c4() {
        let sr = 44100.0;
        let n = 4096;
        let freq = 261.63;
        let samples: Vec<f64> = (0..n)
            .map(|i| (2.0 * std::f64::consts::PI * freq * i as f64 / sr).sin())
            .collect();

        let power_c = goertzel(&samples, freq, sr);
        let power_e = goertzel(&samples, 329.63, sr);
        assert!(power_c > power_e * 5.0, "C4 should dominate over E4");
    }

    #[test]
    fn test_chromagram_pure_a() {
        let sr = 44100u32;
        let n = sr as usize * 2;
        let samples: Vec<f32> = (0..n)
            .map(|i| (2.0 * std::f32::consts::PI * 440.0 * i as f32 / sr as f32).sin())
            .collect();

        let chroma = compute_chromagram(&samples, sr);
        let a_energy = chroma[9];
        let max_other = chroma
            .iter()
            .enumerate()
            .filter(|&(i, _)| i != 9)
            .map(|(_, &v)| v)
            .fold(0.0f64, f64::max);
        assert!(
            a_energy > max_other,
            "A (440Hz) should have strongest chroma bin. A={}, max_other={}",
            a_energy,
            max_other
        );
    }

    #[test]
    fn test_chromagram_pure_c() {
        let sr = 44100u32;
        let n = sr as usize * 2;
        let samples: Vec<f32> = (0..n)
            .map(|i| (2.0 * std::f32::consts::PI * 261.63 * i as f32 / sr as f32).sin())
            .collect();

        let chroma = compute_chromagram(&samples, sr);
        let c_energy = chroma[0];
        let max_other = chroma
            .iter()
            .enumerate()
            .filter(|&(i, _)| i != 0)
            .map(|(_, &v)| v)
            .fold(0.0f64, f64::max);
        assert!(
            c_energy > max_other,
            "C should have strongest chroma bin. C={}, max_other={}",
            c_energy,
            max_other
        );
    }

    #[test]
    fn test_match_c_major_triad() {
        let mut chroma = [0.1f64; 12];
        chroma[0] = 1.0;
        chroma[4] = 0.8;
        chroma[7] = 0.7;
        let key = match_key_profile(&chroma);
        assert!(key.is_some());
        let key = key.unwrap();
        assert!(
            key.contains("C") && key.contains("Major"),
            "C-E-G triad should match C Major, got {}",
            key
        );
    }

    #[test]
    fn test_match_a_minor_triad() {
        let mut chroma = [0.1f64; 12];
        chroma[9] = 1.0;
        chroma[0] = 0.8;
        chroma[4] = 0.7;
        let key = match_key_profile(&chroma);
        assert!(key.is_some());
        let key = key.unwrap();
        // Accept the relative major as well, but nothing else: checking only
        // for "Minor" or "Major" would pass for every possible result.
        assert!(
            key == "A Minor" || key == "C Major",
            "A-C-E should match A Minor or C Major, got {}",
            key
        );
    }

    #[test]
    fn test_detect_key_nonexistent() {
        let key = detect_key("/nonexistent/file.wav");
        assert!(key.is_none());
    }

    #[test]
    fn test_detect_key_unsupported() {
        let key = detect_key("/some/file.txt");
        assert!(key.is_none());
    }

    #[test]
    fn test_detect_key_silence() {
        let sr = 44100u32;
        let samples = vec![0.0f32; sr as usize * 2];
        let wav = TempWav::create("silence", &samples, sr);
        let key = detect_key(wav.as_str());
        assert!(key.is_none(), "silence should not detect a key");
    }

    #[test]
    fn test_detect_key_a440_wav() {
        let sr = 44100u32;
        let n = sr as usize * 3;
        let samples: Vec<f32> = (0..n)
            .map(|i| (2.0 * std::f32::consts::PI * 440.0 * i as f32 / sr as f32).sin() * 0.8)
            .collect();
        let wav = TempWav::create("a440", &samples, sr);
        let key = detect_key(wav.as_str());
        assert!(key.is_some(), "should detect key for 440Hz sine");
        let key = key.unwrap();
        assert!(
            key.contains('A'),
            "440Hz should detect A-related key, got {}",
            key
        );
    }

    #[test]
    fn test_detect_key_c_major_chord() {
        let sr = 44100u32;
        let n = sr as usize * 3;
        let samples: Vec<f32> = (0..n)
            .map(|i| {
                let t = i as f32 / sr as f32;
                let c = (2.0 * std::f32::consts::PI * 261.63 * t).sin();
                let e = (2.0 * std::f32::consts::PI * 329.63 * t).sin();
                let g = (2.0 * std::f32::consts::PI * 392.00 * t).sin();
                (c + e + g) * 0.3
            })
            .collect();
        let wav = TempWav::create("cmaj", &samples, sr);
        let key = detect_key(wav.as_str());
        assert!(key.is_some(), "should detect key for C major chord");
        let key = key.unwrap();
        assert!(
            key.contains('C') || key.contains('A'),
            "C major chord should detect C Major or A Minor, got {}",
            key
        );
    }

    #[test]
    fn test_profile_correlation_perfect_match() {
        let chroma: [f64; 12] = MAJOR_PROFILE;
        let corr = profile_correlation(&chroma, &MAJOR_PROFILE, 0);
        assert!(
            (corr - 1.0).abs() < 0.001,
            "perfect match should be ~1.0, got {}",
            corr
        );
    }

    #[test]
    fn test_profile_correlation_shifted() {
        let mut chroma = [0.0f64; 12];
        for i in 0..12 {
            chroma[(i + 7) % 12] = MAJOR_PROFILE[i];
        }
        let corr = profile_correlation(&chroma, &MAJOR_PROFILE, 7);
        assert!(
            (corr - 1.0).abs() < 0.001,
            "G major should perfectly match shifted profile, got {}",
            corr
        );
    }

    #[test]
    fn test_goertzel_single_sample_is_finite() {
        let samples = vec![1.0f64];
        let p = goertzel(&samples, 440.0, 44100.0);
        assert!(
            p.is_finite(),
            "goertzel power should be finite for n=1, got {}",
            p
        );
    }

    #[test]
    fn test_goertzel_near_zero_for_absent_frequency() {
        let sr = 44100.0;
        let n = 4096;
        let samples: Vec<f64> = (0..n)
            .map(|i| (2.0 * std::f64::consts::PI * 440.0 * i as f64 / sr).sin())
            .collect();

        let power_440 = goertzel(&samples, 440.0, sr);
        let power_261 = goertzel(&samples, 261.63, sr);
        assert!(
            power_261 < power_440 * 0.01,
            "261Hz should have <1% of 440Hz energy: 261={}, 440={}",
            power_261,
            power_440
        );
    }

    #[test]
    fn test_chromagram_chord_c_major() {
        let sr = 44100u32;
        let n = sr as usize * 2;
        let samples: Vec<f32> = (0..n)
            .map(|i| {
                let t = i as f32 / sr as f32;
                let c = (2.0 * std::f32::consts::PI * 261.63 * t).sin();
                let e = (2.0 * std::f32::consts::PI * 329.63 * t).sin();
                let g = (2.0 * std::f32::consts::PI * 392.00 * t).sin();
                (c + e + g) * 0.3
            })
            .collect();

        let chroma = compute_chromagram(&samples, sr);
        // Rank the bins by energy, strongest first.
        let mut indexed: Vec<(usize, f64)> =
            chroma.iter().enumerate().map(|(i, &v)| (i, v)).collect();
        indexed.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
        let top3: Vec<usize> = indexed[..3].iter().map(|&(i, _)| i).collect();
        assert!(top3.contains(&0), "C should be in top 3, got {:?}", top3);
        assert!(top3.contains(&4), "E should be in top 3, got {:?}", top3);
        assert!(top3.contains(&7), "G should be in top 3, got {:?}", top3);
    }

    #[test]
    fn test_detect_key_high_sample_rate() {
        let sr = 96000u32;
        let n = sr as usize * 2;
        let samples: Vec<f32> = (0..n)
            .map(|i| (2.0 * std::f32::consts::PI * 440.0 * i as f32 / sr as f32).sin() * 0.8)
            .collect();
        let wav = TempWav::create("96k", &samples, sr);
        let key = detect_key(wav.as_str());
        assert!(key.is_some(), "should detect key at 96kHz");
    }

    #[test]
    fn test_detect_key_low_sample_rate() {
        let sr = 8000u32;
        let n = sr as usize * 2;
        let samples: Vec<f32> = (0..n)
            .map(|i| (2.0 * std::f32::consts::PI * 440.0 * i as f32 / sr as f32).sin() * 0.8)
            .collect();
        let wav = TempWav::create("8k", &samples, sr);
        // Smoke test: only checks that low sample rates do not panic.
        let _ = detect_key(wav.as_str());
    }

    #[test]
    fn test_detect_key_multi_octave_a() {
        let sr = 44100u32;
        let n = sr as usize * 3;
        let samples: Vec<f32> = (0..n)
            .map(|i| {
                let t = i as f32 / sr as f32;
                let a3 = (2.0 * std::f32::consts::PI * 220.0 * t).sin();
                let a4 = (2.0 * std::f32::consts::PI * 440.0 * t).sin();
                let a5 = (2.0 * std::f32::consts::PI * 880.0 * t).sin();
                (a3 + a4 + a5) * 0.25
            })
            .collect();
        let wav = TempWav::create("multi_oct", &samples, sr);
        let key = detect_key(wav.as_str());
        assert!(key.is_some(), "should detect key for multi-octave A");
        let key = key.unwrap();
        assert!(
            key.contains('A'),
            "multi-octave A should detect A-related key, got {}",
            key
        );
    }

    #[test]
    fn test_chromagram_bins_bounded() {
        let sr = 44100u32;
        let n = sr as usize * 2;
        let samples: Vec<f32> = (0..n)
            .map(|i| {
                let t = i as f32 / sr as f32;
                (t * 261.63 * 2.0 * std::f32::consts::PI).sin() * 0.3
                    + (t * 440.0 * 2.0 * std::f32::consts::PI).sin() * 0.3
                    + (t * 783.99 * 2.0 * std::f32::consts::PI).sin() * 0.2
            })
            .collect();
        let chroma = compute_chromagram(&samples, sr);
        for (i, &v) in chroma.iter().enumerate() {
            assert!(
                (0.0..=1.0).contains(&v),
                "chroma bin {} should be [0,1], got {}",
                i,
                v
            );
        }
    }

    #[test]
    fn test_profile_correlation_identical_chroma_and_profile_is_one() {
        let chroma = MAJOR_PROFILE;
        let r = profile_correlation(&chroma, &MAJOR_PROFILE, 0);
        assert!((r - 1.0).abs() < 1e-9, "expected 1.0, got {}", r);
    }

    #[test]
    fn test_profile_correlation_rotated_chroma_matches_at_correct_root() {
        let mut chroma = [0.0f64; 12];
        for k in 0..12 {
            chroma[k] = MAJOR_PROFILE[(k + 5) % 12];
        }
        let r = profile_correlation(&chroma, &MAJOR_PROFILE, 7);
        assert!(
            (r - 1.0).abs() < 1e-9,
            "rotated major should match at root 7, got {}",
            r
        );
    }

    #[test]
    fn test_profile_correlation_constant_chroma_zero_variance_returns_zero() {
        let c = [0.25f64; 12];
        let r = profile_correlation(&c, &MAJOR_PROFILE, 0);
        assert_eq!(r, 0.0);
    }

    #[test]
    fn test_profile_correlation_zero_chroma_returns_zero() {
        let z = [0.0f64; 12];
        let r = profile_correlation(&z, &MAJOR_PROFILE, 0);
        assert_eq!(r, 0.0, "zero variance in x → Pearson denom 0 → 0.0");
    }

    #[test]
    fn test_compute_chromagram_too_short_for_one_frame_returns_zeros() {
        let sr = 44100u32;
        let samples: Vec<f32> = vec![0.5; 3000];
        let chroma = compute_chromagram(&samples, sr);
        assert!(
            chroma.iter().all(|&v| v == 0.0),
            "len < frame_size gives num_frames=0 → zero chroma"
        );
    }

    #[test]
    fn test_match_key_profile_all_zero_chroma_picks_first_tie_at_zero_correlation() {
        let z = [0.0f64; 12];
        let key = match_key_profile(&z);
        assert_eq!(key.as_deref(), Some("C Major"));
    }

    /// Writes `samples` to `path` as a mono, 16-bit PCM WAV file with a
    /// 44-byte RIFF header. Samples are clamped to `[-1.0, 1.0]` before
    /// being scaled to `i16`.
    fn write_test_wav(path: &Path, samples: &[f32], sample_rate: u32) {
        let n = samples.len() as u32;
        let data_size = n * 2;
        let mut buf = Vec::with_capacity(44 + data_size as usize);
        buf.extend_from_slice(b"RIFF");
        buf.extend_from_slice(&(36 + data_size).to_le_bytes());
        buf.extend_from_slice(b"WAVE");
        buf.extend_from_slice(b"fmt ");
        buf.extend_from_slice(&16u32.to_le_bytes()); // fmt chunk size
        buf.extend_from_slice(&1u16.to_le_bytes()); // format tag: PCM
        buf.extend_from_slice(&1u16.to_le_bytes()); // channels
        buf.extend_from_slice(&sample_rate.to_le_bytes());
        buf.extend_from_slice(&(sample_rate * 2).to_le_bytes()); // byte rate
        buf.extend_from_slice(&2u16.to_le_bytes()); // block align
        buf.extend_from_slice(&16u16.to_le_bytes()); // bits per sample
        buf.extend_from_slice(b"data");
        buf.extend_from_slice(&data_size.to_le_bytes());
        for &s in samples {
            let i = (s.clamp(-1.0, 1.0) * 32767.0) as i16;
            buf.extend_from_slice(&i.to_le_bytes());
        }
        std::fs::write(path, buf).unwrap();
    }
}