mtmd-audio.cpp 26 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730
  1. #include "mtmd-audio.h"
  2. #define _USE_MATH_DEFINES // for M_PI
  3. #include <cmath>
  4. #include <cstdint>
  5. #include <cstring>
  6. #include <thread>
  7. #include <vector>
  8. #include <fstream>
  9. #include <algorithm>
  10. // some of the code here is copied from whisper.cpp
  11. constexpr bool DEBUG = false;
  12. void mtmd_audio_cache::fill_sin_cos_table(int n) {
  13. sin_vals.resize(n);
  14. cos_vals.resize(n);
  15. for (int i = 0; i < n; i++) {
  16. double theta = (2 * M_PI * i) / n;
  17. sin_vals[i] = sinf(theta);
  18. cos_vals[i] = cosf(theta);
  19. }
  20. }
  21. void mtmd_audio_cache::fill_hann_window(int length, bool periodic) {
  22. hann_window.resize(length);
  23. int offset = -1;
  24. if (periodic) {
  25. offset = 0;
  26. }
  27. for (int i = 0; i < length; i++) {
  28. hann_window[i] = 0.5 * (1.0 - cosf((2.0 * M_PI * i) / (length + offset)));
  29. }
  30. }
  31. void mtmd_audio_cache::fill_mel_filterbank_matrix(int n_mel,
  32. int n_fft,
  33. int sample_rate,
  34. float fmin,
  35. float fmax,
  36. bool slaney_area_norm,
  37. float scale) {
  38. GGML_ASSERT(n_mel > 0 && n_fft > 1);
  39. if (fmax <= 0.0f) {
  40. fmax = 0.5f * sample_rate;
  41. }
  42. // Slaney scale (matches librosa default)
  43. const double min_log_hz = 1000.0;
  44. const double lin_slope = 3 / 200.;
  45. const double min_log_mel = min_log_hz * lin_slope;
  46. const double log_step = log(6.4) / 27.0;
  47. auto hz_to_mel = [min_log_hz, lin_slope, log_step, min_log_mel](const double f_hz) -> double {
  48. return (f_hz < min_log_hz) ? f_hz * lin_slope : min_log_mel + log(f_hz / min_log_hz) / log_step;
  49. };
  50. auto mel_to_hz = [min_log_hz, lin_slope, log_step, min_log_mel](const double m) -> double {
  51. return (m < min_log_mel) ? m / lin_slope : min_log_hz * exp((m - min_log_mel) * log_step);
  52. };
  53. // infer N_fft from n_fft_bins
  54. const double bin_hz_step = double(sample_rate) / double(n_fft);
  55. // mel grid: n_mel + 2 edges
  56. const double m_lo = hz_to_mel(fmin);
  57. const double m_hi = hz_to_mel(fmax);
  58. std::vector<double> mel_pts(n_mel + 2);
  59. for (int i = 0; i < n_mel + 2; ++i) {
  60. mel_pts[i] = m_lo + (m_hi - m_lo) * (double(i) / (n_mel + 1));
  61. }
  62. // convert to Hz
  63. std::vector<double> hz_pts(n_mel + 2);
  64. for (int i = 0; i < n_mel + 2; ++i) {
  65. hz_pts[i] = mel_to_hz(mel_pts[i]);
  66. }
  67. const int n_fft_bins = n_fft / 2 + 1;
  68. // filterbank
  69. std::vector<float> out(n_mel * n_fft_bins, 0);
  70. for (int m = 0; m < n_mel; ++m) {
  71. const double f_left = hz_pts[m];
  72. const double f_center = hz_pts[m + 1];
  73. const double f_right = hz_pts[m + 2];
  74. const double denom_l = std::max(1e-30, f_center - f_left);
  75. const double denom_r = std::max(1e-30, f_right - f_center);
  76. const double enorm = slaney_area_norm ? (2.0 / std::max(1e-30, f_right - f_left)) : 1.0;
  77. for (int k = 0; k < n_fft_bins; ++k) {
  78. const double f = k * bin_hz_step;
  79. double w = 0.0;
  80. if (f >= f_left && f <= f_center) {
  81. w = (f - f_left) / denom_l;
  82. } else if (f > f_center && f <= f_right) {
  83. w = (f_right - f) / denom_r;
  84. }
  85. out[size_t(m) * size_t(n_fft_bins) + size_t(k)] = float(w * enorm * scale);
  86. }
  87. }
  88. filters.n_mel = n_mel;
  89. filters.n_fft = n_fft;
  90. filters.data = std::move(out);
  91. if (DEBUG) { // debug
  92. for (size_t i = 0; i < filters.data.size(); ++i) {
  93. if (filters.data[i] != 0.0f) {
  94. printf("filters[%zu] = %f\n", i, filters.data[i] * 1000.0f);
  95. }
  96. }
  97. }
  98. }
  99. // Unified DFT implementation for both forward and inverse transforms
  100. // Template parameters:
  101. // Inverse: false = DFT with exp(-2πi·k·n/N), no scaling
  102. // true = IDFT with exp(+2πi·k·n/N), scales by 1/N
  103. // RealInput: true = input is real-valued (stride 1), avoids imaginary computations
  104. // false = input is complex-valued (interleaved real/imag, stride 2)
  105. template <bool Inverse, bool RealInput>
  106. static void dft_impl(const mtmd_audio_cache & cache, const float * in, int N, float * out) {
  107. const int n_sin_cos_vals = cache.sin_vals.size();
  108. const int sin_cos_step = n_sin_cos_vals / N;
  109. constexpr float sign = Inverse ? 1.0f : -1.0f;
  110. const float scale = Inverse ? (1.0f / N) : 1.0f;
  111. for (int k = 0; k < N; k++) {
  112. float re = 0;
  113. float im = 0;
  114. for (int n = 0; n < N; n++) {
  115. int idx = (k * n * sin_cos_step) % n_sin_cos_vals;
  116. float cos_val = cache.cos_vals[idx];
  117. float sin_val = cache.sin_vals[idx];
  118. if constexpr (RealInput) {
  119. // Real input: in_im = 0, simplifies to:
  120. // re += in_re * cos_val
  121. // im += sign * in_re * sin_val
  122. float in_re = in[n];
  123. re += in_re * cos_val;
  124. im += sign * in_re * sin_val;
  125. } else {
  126. float in_re = in[n * 2 + 0];
  127. float in_im = in[n * 2 + 1];
  128. // (a + bi) * (cos + sign*i*sin) = (a*cos - sign*b*sin) + (sign*a*sin + b*cos)i
  129. re += in_re * cos_val - sign * in_im * sin_val;
  130. im += sign * in_re * sin_val + in_im * cos_val;
  131. }
  132. }
  133. out[k * 2 + 0] = re * scale;
  134. out[k * 2 + 1] = im * scale;
  135. }
  136. }
// Cooley-Tukey FFT/IFFT unified implementation (recursive radix-2 split; odd sizes fall
// back to dft_impl).
// Template parameters:
//   Inverse:   false = FFT with exp(-2πi·k/N), no scaling
//              true  = IFFT with exp(+2πi·k/N), scales by 0.5 at each level
//   RealInput: true  = input is real-valued (stride 1)
//              false = input is complex-valued (interleaved real/imag, stride 2)
//
// Output: N complex values, interleaved (re, im), in out[0 .. 2N).
//
// Scratch-space contract (from the pointer arithmetic below):
//   - `in` is written beyond its input data: the even/odd halves are staged at in + N
//     (real) or in + 2N (complex), and the recursive calls stage their halves further on.
//   - `out` is written beyond its 2N output floats: the half-size sub-transforms live at
//     out + 2N (even) and out + 3N (odd), and their own recursion uses space past that.
//   Both buffers therefore need extra room. NOTE(review): the exact required size is not
//   derived here; confirm against the callers' allocations.
//
// Inverse scaling: 0.5 per radix-2 level times dft_impl's 1/N_leaf at an odd-sized leaf
// (or 1 at N == 1) multiplies out to an overall 1/N.
template <bool Inverse, bool RealInput>
static void fft_impl(const mtmd_audio_cache & cache, float * in, int N, float * out) {
    const int n_sin_cos_vals = cache.sin_vals.size();
    // base case: a single point is its own transform
    if (N == 1) {
        out[0] = in[0];
        if constexpr (RealInput) {
            out[1] = 0.0f;
        } else {
            out[1] = in[1];
        }
        return;
    }
    const int half_N = N / 2;
    if (N - half_N * 2 == 1) {
        // Odd N: fall back to DFT
        dft_impl<Inverse, RealInput>(cache, in, N, out);
        return;
    }
    // Split into even and odd
    if constexpr (RealInput) {
        // Real input: stride is 1, copy only real values
        float * even = in + N;
        for (int i = 0; i < half_N; ++i) {
            even[i] = in[2 * i];
        }
        float * even_fft = out + 2 * N;
        // this call's scratch starts at even_fft + N, i.e. the odd_fft region; that is safe
        // because odd_fft is only filled afterwards
        fft_impl<Inverse, true>(cache, even, half_N, even_fft);
        // the staging buffer is reused for the odd samples once the even transform is done
        float * odd = even;
        for (int i = 0; i < half_N; ++i) {
            odd[i] = in[2 * i + 1];
        }
        float * odd_fft = even_fft + N;
        fft_impl<Inverse, true>(cache, odd, half_N, odd_fft);
    } else {
        // Complex input: stride is 2, copy complex pairs
        float * even = in + N * 2;
        for (int i = 0; i < half_N; ++i) {
            even[i * 2 + 0] = in[2 * i * 2 + 0];
            even[i * 2 + 1] = in[2 * i * 2 + 1];
        }
        float * even_fft = out + 2 * N;
        // as above: this recursion may scribble over the (not yet used) odd_fft region
        fft_impl<Inverse, false>(cache, even, half_N, even_fft);
        // reuse the staging buffer for the odd-indexed complex samples
        float * odd = even;
        for (int i = 0; i < half_N; ++i) {
            odd[i * 2 + 0] = in[(2 * i + 1) * 2 + 0];
            odd[i * 2 + 1] = in[(2 * i + 1) * 2 + 1];
        }
        float * odd_fft = even_fft + N;
        fft_impl<Inverse, false>(cache, odd, half_N, odd_fft);
    }
    // same locations as inside the branches above (out + 2N and out + 3N)
    float * even_fft = out + 2 * N;
    float * odd_fft = even_fft + N;
    // assumes n_sin_cos_vals is a multiple of N so that idx below hits the exact twiddle
    const int sin_cos_step = n_sin_cos_vals / N;
    constexpr float sign = Inverse ? 1.0f : -1.0f;
    constexpr float scale = Inverse ? 0.5f : 1.0f;
    // butterflies: X[k] = E[k] + W^k O[k], X[k + N/2] = E[k] - W^k O[k]
    for (int k = 0; k < half_N; k++) {
        int idx = k * sin_cos_step; // t = 2*M_PI*k/N
        float re = cache.cos_vals[idx];
        float im = sign * cache.sin_vals[idx];
        float re_odd = odd_fft[2 * k + 0];
        float im_odd = odd_fft[2 * k + 1];
        out[2 * k + 0] = scale * (even_fft[2 * k + 0] + re * re_odd - im * im_odd);
        out[2 * k + 1] = scale * (even_fft[2 * k + 1] + re * im_odd + im * re_odd);
        out[2 * (k + half_N) + 0] = scale * (even_fft[2 * k + 0] - re * re_odd + im * im_odd);
        out[2 * (k + half_N) + 1] = scale * (even_fft[2 * k + 1] - re * im_odd - im * re_odd);
    }
}
  210. // Forward FFT for real input (used by mel spectrogram)
  211. static void fft(const mtmd_audio_cache & cache, float * in, int N, float * out) {
  212. fft_impl<false, true>(cache, in, N, out);
  213. }
  214. // Inverse FFT for complex input
  215. static void ifft(const mtmd_audio_cache & cache, float * in, int N, float * out) {
  216. fft_impl<true, false>(cache, in, N, out);
  217. }
  218. struct filter_params {
  219. int32_t n_mel;
  220. int32_t n_fft_bins;
  221. int32_t hann_window_size;
  222. int32_t hop_length;
  223. int32_t sample_rate;
  224. bool center_padding = false;
  225. float preemph = 0.f;
  226. bool use_natural_log = false;
  227. bool norm_per_feature = false;
  228. };
  229. static void log_mel_spectrogram_worker_thread(int ith,
  230. const float * hann,
  231. const std::vector<float> & samples,
  232. int n_samples,
  233. int frame_size,
  234. int frame_step,
  235. int n_threads,
  236. const filter_params & params,
  237. const mtmd_audio_cache & cache,
  238. mtmd_audio_mel & out) {
  239. std::vector<float> fft_in(frame_size * 2, 0.0);
  240. std::vector<float> fft_out(frame_size * 2 * 2 * 2);
  241. int n_fft_bins = params.n_fft_bins;
  242. int i = ith;
  243. const auto & filters = cache.filters;
  244. // make sure n_fft == 1 + (WHISPER_N_FFT / 2), bin_0 to bin_nyquist
  245. GGML_ASSERT(n_fft_bins == 1 + (frame_size / 2));
  246. GGML_ASSERT(cache.sin_vals.size() == cache.cos_vals.size());
  247. // calculate FFT only when fft_in are not all zero
  248. for (; i < std::min(n_samples / frame_step + 1, out.n_len); i += n_threads) {
  249. const int offset = i * frame_step;
  250. // apply Hann window (~10% faster)
  251. for (int j = 0; j < std::min(frame_size, n_samples - offset); j++) {
  252. fft_in[j] = hann[j] * samples[offset + j];
  253. }
  254. // fill the rest with zeros
  255. if (n_samples - offset < frame_size) {
  256. std::fill(fft_in.begin() + (n_samples - offset), fft_in.end(), 0.0);
  257. }
  258. // FFT
  259. fft(cache, fft_in.data(), frame_size, fft_out.data());
  260. // Calculate modulus^2 of complex numbers
  261. // Use pow(fft_out[2 * j + 0], 2) + pow(fft_out[2 * j + 1], 2) causes inference quality problem? Interesting.
  262. for (int j = 0; j < n_fft_bins; j++) {
  263. fft_out[j] = (fft_out[2 * j + 0] * fft_out[2 * j + 0] + fft_out[2 * j + 1] * fft_out[2 * j + 1]);
  264. }
  265. // mel spectrogram
  266. for (int j = 0; j < out.n_mel; j++) {
  267. double sum = 0.0;
  268. // unroll loop (suggested by GH user @lunixbochs)
  269. int k = 0;
  270. for (k = 0; k < n_fft_bins - 3; k += 4) {
  271. size_t idx = size_t(j) * size_t(n_fft_bins) + size_t(k);
  272. sum +=
  273. fft_out[k + 0] * filters.data[idx + 0] +
  274. fft_out[k + 1] * filters.data[idx + 1] +
  275. fft_out[k + 2] * filters.data[idx + 2] +
  276. fft_out[k + 3] * filters.data[idx + 3];
  277. }
  278. // handle n_fft remainder
  279. for (; k < n_fft_bins; k++) {
  280. sum += fft_out[k] * filters.data[j * n_fft_bins + k];
  281. }
  282. sum = params.use_natural_log
  283. ? log(sum + 5.960464477539063e-08)
  284. : log10(std::max(sum, 1e-10));
  285. out.data[j * out.n_len + i] = sum;
  286. }
  287. }
  288. // Otherwise fft_out are all zero
  289. double sum = params.use_natural_log ? log(1e-10) : log10(1e-10);
  290. for (; i < out.n_len; i += n_threads) {
  291. for (int j = 0; j < out.n_mel; j++) {
  292. out.data[j * out.n_len + i] = sum;
  293. }
  294. }
  295. }
  296. // ref: https://github.com/openai/whisper/blob/main/whisper/audio.py#L110-L157
  297. static bool log_mel_spectrogram(
  298. const float * samples,
  299. const int n_samples_in,
  300. const int n_threads,
  301. const filter_params & params,
  302. const mtmd_audio_cache & cache,
  303. mtmd_audio_mel & out) {
  304. //const int64_t t_start_us = ggml_time_us();
  305. out.n_len_org = n_samples_in;
  306. int n_samples = n_samples_in;
  307. // Hann window
  308. const float * hann = cache.hann_window.data();
  309. const int frame_size = (params.n_fft_bins - 1) * 2;
  310. const int frame_step = params.hop_length;
  311. // Padding
  312. std::vector<float> samples_padded;
  313. if (params.center_padding) {
  314. const auto pad_amount = frame_size / 2;
  315. samples_padded = std::vector<float>(n_samples + 2 * pad_amount, 0);
  316. std::copy(samples, samples + n_samples, samples_padded.data() + pad_amount);
  317. samples = samples_padded.data();
  318. n_samples = samples_padded.size();
  319. } else {
  320. // existing padding logic
  321. int64_t stage_1_pad = params.sample_rate * 30;
  322. int64_t stage_2_pad = frame_size / 2;
  323. samples_padded.resize(n_samples + stage_1_pad + stage_2_pad * 2);
  324. std::copy(samples, samples + n_samples, samples_padded.begin() + stage_2_pad);
  325. // pad 30 seconds of zeros at the end of audio (480,000 samples) + reflective pad 200 samples at the end of audio
  326. std::fill(samples_padded.begin() + n_samples + stage_2_pad, samples_padded.begin() + n_samples + stage_1_pad + 2 * stage_2_pad, 0);
  327. // reflective pad 200 samples at the beginning of audio
  328. if (n_samples < stage_2_pad + 1) {
  329. // TODO: Handle short audio differently or return error
  330. return false;
  331. }
  332. std::reverse_copy(samples + 1, samples + 1 + stage_2_pad, samples_padded.begin());
  333. }
  334. // preemphasis
  335. if (params.preemph) {
  336. const int pad_amount = frame_size / 2;
  337. const float preemph = 0.97f;
  338. float prev = samples_padded[pad_amount];
  339. for (int i = pad_amount + 1; i + pad_amount < n_samples; ++i) {
  340. float cur = samples_padded[i];
  341. samples_padded[i] = cur - preemph * prev;
  342. prev = cur;
  343. }
  344. }
  345. // pad hann window if it's smaller than frame_size
  346. // TODO: probably unnecessary here? (or better doing it in g_cache?)
  347. std::vector<float> hann_window_padded;
  348. if (params.hann_window_size < frame_size) {
  349. hann_window_padded.resize(frame_size);
  350. const int padding = (frame_size - params.hann_window_size) / 2;
  351. std::copy(hann, hann + params.hann_window_size, &hann_window_padded[padding]);
  352. hann = hann_window_padded.data();
  353. }
  354. out.n_mel = params.n_mel;
  355. out.n_len = (n_samples - frame_size) / frame_step + 1;
  356. // TODO: handle these checks better
  357. if (out.n_mel > 0 && (unsigned long)out.n_len > SIZE_MAX / out.n_mel) {
  358. LOG_ERR("%s: size overflow\n", __func__);
  359. return false;
  360. }
  361. if (n_samples < frame_size) {
  362. LOG_ERR("%s: not enough samples after padding\n", __func__);
  363. return false;
  364. }
  365. out.data.resize(out.n_mel * out.n_len);
  366. {
  367. std::vector<std::thread> workers(n_threads - 1);
  368. for (int iw = 0; iw < n_threads - 1; ++iw) {
  369. workers[iw] =
  370. std::thread(log_mel_spectrogram_worker_thread, iw + 1, hann, std::cref(samples_padded), n_samples,
  371. frame_size, frame_step, n_threads, std::cref(params), std::cref(cache), std::ref(out));
  372. }
  373. // main thread
  374. log_mel_spectrogram_worker_thread(0, hann, samples_padded, n_samples, frame_size, frame_step, n_threads, params,
  375. cache, out);
  376. for (int iw = 0; iw < n_threads - 1; ++iw) {
  377. workers[iw].join();
  378. }
  379. }
  380. const int effective_n_len = n_samples_in / frame_step;
  381. if (params.norm_per_feature) {
  382. for (int i = 0; i < out.n_mel; i++) {
  383. double mean = 0;
  384. for (int j = 0; j < effective_n_len; ++j) {
  385. mean += out.data[i * out.n_len + j];
  386. }
  387. mean /= effective_n_len;
  388. double var = 0.0;
  389. for (int j = 0; j < effective_n_len; ++j) {
  390. const double value = out.data[i * out.n_len + j] - mean;
  391. var += value * value;
  392. }
  393. var /= effective_n_len - 1; // unbiased
  394. const double mstd = std::sqrt(var + 1e-5);
  395. for (int j = 0; j < effective_n_len; ++j) {
  396. auto &value = out.data[i * out.n_len + j];
  397. value = (value - mean) / mstd;
  398. }
  399. // pad the rest with zeros
  400. for (int j = effective_n_len; j < out.n_len; ++j) {
  401. out.data[i * out.n_len + j] = 0.0;
  402. }
  403. }
  404. } else {
  405. // clamping and normalization
  406. double mmax = -1e20;
  407. for (int i = 0; i < out.n_mel*out.n_len; i++) {
  408. if (out.data[i] > mmax) {
  409. mmax = out.data[i];
  410. }
  411. }
  412. mmax -= 8.0;
  413. for (int i = 0; i < out.n_mel*out.n_len; i++) {
  414. if (out.data[i] < mmax) {
  415. out.data[i] = mmax;
  416. }
  417. out.data[i] = (out.data[i] + 4.0)/4.0;
  418. }
  419. }
  420. // Dump log_mel_spectrogram
  421. if (DEBUG) {
  422. std::ofstream outFile("log_mel_spectrogram.json");
  423. outFile << "[";
  424. for (uint64_t i = 0; i < out.data.size() - 1; i++) {
  425. outFile << out.data[i] << ", ";
  426. }
  427. outFile << out.data[out.data.size() - 1] << "]";
  428. outFile.close();
  429. }
  430. return true;
  431. }
  432. //
  433. // mtmd_audio_preprocessor_whisper
  434. //
  435. void mtmd_audio_preprocessor_whisper::initialize() {
  436. cache.fill_sin_cos_table(hparams.audio_n_fft);
  437. cache.fill_hann_window(hparams.audio_window_len, true);
  438. cache.fill_mel_filterbank_matrix(hparams.n_mel_bins, hparams.audio_n_fft, hparams.audio_sample_rate);
  439. }
  440. bool mtmd_audio_preprocessor_whisper::preprocess(const float * samples,
  441. size_t n_samples,
  442. std::vector<mtmd_audio_mel> & output) {
  443. if (n_samples == 0) {
  444. // empty audio
  445. return false;
  446. }
  447. std::vector<float> smpl;
  448. // if input is too short, pad with zeros
  449. // this is to avoid potential issues with stage1/2 padding in log_mel_spectrogram
  450. // TODO: maybe handle this better
  451. size_t min_samples = (size_t) hparams.audio_sample_rate * (hparams.audio_chunk_len + 1); // +1 second margin
  452. if (n_samples < min_samples) {
  453. smpl.resize(min_samples, 0.0f);
  454. std::memcpy(smpl.data(), samples, n_samples * sizeof(float));
  455. samples = smpl.data();
  456. n_samples = smpl.size();
  457. }
  458. filter_params params;
  459. params.n_mel = hparams.n_mel_bins;
  460. params.n_fft_bins = 1 + (hparams.audio_n_fft / 2);
  461. params.hann_window_size = hparams.audio_window_len;
  462. params.hop_length = hparams.audio_hop_len;
  463. params.sample_rate = hparams.audio_sample_rate;
  464. params.center_padding = false;
  465. params.preemph = 0.0f; // disabled
  466. params.use_natural_log = false;
  467. params.norm_per_feature = false;
  468. // make sure the cache is initialized
  469. GGML_ASSERT(!cache.sin_vals.empty());
  470. GGML_ASSERT(!cache.cos_vals.empty());
  471. GGML_ASSERT(!cache.filters.data.empty());
  472. mtmd_audio_mel out_full;
  473. bool ok = log_mel_spectrogram(samples, n_samples,
  474. 4, // n_threads
  475. params, cache, out_full);
  476. if (!ok) {
  477. return false;
  478. }
  479. // because the cgraph in clip.cpp only accepts 3000 frames each, we need to split the mel
  480. // we always expect the mel to have 3000 silent frames at the end
  481. if (DEBUG) {
  482. printf("output: n_mel = %d, n_len = %d\n", out_full.n_mel, out_full.n_len);
  483. }
  484. const size_t frames_per_chunk = 3000;
  485. GGML_ASSERT((size_t) out_full.n_len > frames_per_chunk);
  486. for (size_t off = 0; off < (size_t) out_full.n_len; off += frames_per_chunk) {
  487. int n_len = std::min(frames_per_chunk, (size_t) out_full.n_len - off);
  488. if ((size_t) n_len < frames_per_chunk) {
  489. break; // last uncomplete chunk will always be a padded chunk, safe to ignore
  490. }
  491. mtmd_audio_mel out_chunk;
  492. out_chunk.n_len = n_len;
  493. out_chunk.n_mel = out_full.n_mel;
  494. out_chunk.n_len_org = out_full.n_mel; // unused
  495. out_chunk.data.reserve(out_chunk.n_mel * out_chunk.n_len);
  496. for (int i = 0; i < out_full.n_mel; i++) {
  497. auto src = out_full.data.begin() + i * out_full.n_len + off;
  498. out_chunk.data.insert(out_chunk.data.end(), src, src + frames_per_chunk);
  499. }
  500. output.push_back(std::move(out_chunk));
  501. }
  502. return true;
  503. }
  504. //
  505. // mtmd_audio_preprocessor_conformer
  506. //
  507. void mtmd_audio_preprocessor_conformer::initialize() {
  508. cache.fill_sin_cos_table(hparams.audio_n_fft);
  509. cache.fill_hann_window(hparams.audio_window_len, true);
  510. cache.fill_mel_filterbank_matrix(hparams.n_mel_bins, hparams.audio_n_fft, hparams.audio_sample_rate);
  511. }
  512. bool mtmd_audio_preprocessor_conformer::preprocess(const float * samples,
  513. size_t n_samples,
  514. std::vector<mtmd_audio_mel> & output) {
  515. // empty audio
  516. if (n_samples == 0) {
  517. return false;
  518. }
  519. filter_params params;
  520. params.n_mel = hparams.n_mel_bins;
  521. params.n_fft_bins = 1 + (hparams.audio_n_fft / 2);
  522. params.hann_window_size = hparams.audio_window_len;
  523. params.hop_length = hparams.audio_hop_len;
  524. params.sample_rate = hparams.audio_sample_rate;
  525. params.center_padding = true;
  526. params.preemph = 0.97f;
  527. params.use_natural_log = true;
  528. params.norm_per_feature = true;
  529. // make sure the cache is initialized
  530. GGML_ASSERT(!cache.sin_vals.empty());
  531. GGML_ASSERT(!cache.cos_vals.empty());
  532. GGML_ASSERT(!cache.filters.data.empty());
  533. mtmd_audio_mel out_full;
  534. bool ok = log_mel_spectrogram(samples, n_samples,
  535. 4, // n_threads
  536. params, cache, out_full);
  537. if (!ok) {
  538. return false;
  539. }
  540. output.push_back(std::move(out_full));
  541. return true;
  542. }
  543. //
  544. // mtmd_audio_streaming_istft implementation
  545. //
  546. mtmd_audio_streaming_istft::mtmd_audio_streaming_istft(int n_fft, int hop_length) :
  547. n_fft(n_fft),
  548. hop_length(hop_length),
  549. n_fft_bins(n_fft / 2 + 1),
  550. overlap_buffer(n_fft, 0.0f),
  551. window_sum_buffer(n_fft, 0.0f),
  552. padding_to_remove((n_fft - hop_length) / 2),
  553. ifft_in(n_fft * 2 * 4, 0.0f), // extra space for recursive IFFT
  554. ifft_out(n_fft * 2 * 4, 0.0f) {
  555. cache.fill_sin_cos_table(n_fft);
  556. cache.fill_hann_window(n_fft, true);
  557. }
  558. void mtmd_audio_streaming_istft::reset() {
  559. std::fill(overlap_buffer.begin(), overlap_buffer.end(), 0.0f);
  560. std::fill(window_sum_buffer.begin(), window_sum_buffer.end(), 0.0f);
  561. padding_to_remove = (n_fft - hop_length) / 2;
  562. }
  563. std::vector<float> mtmd_audio_streaming_istft::process_frame(const float * frame_spectrum) {
  564. std::vector<float> output(hop_length);
  565. // copy frequencies
  566. for (int j = 0; j < n_fft_bins; j++) {
  567. ifft_in[j * 2 + 0] = frame_spectrum[j * 2 + 0];
  568. ifft_in[j * 2 + 1] = frame_spectrum[j * 2 + 1];
  569. }
  570. // mirror negative frequencies
  571. for (int j = 1; j < n_fft_bins - 1; j++) {
  572. int mirror_idx = n_fft - j;
  573. ifft_in[mirror_idx * 2 + 0] = ifft_in[j * 2 + 0];
  574. ifft_in[mirror_idx * 2 + 1] = -ifft_in[j * 2 + 1]; // conjugate
  575. }
  576. ifft(cache, ifft_in.data(), n_fft, ifft_out.data());
  577. // update window sum and overlap buffer
  578. for (int j = 0; j < n_fft; j++) {
  579. window_sum_buffer[j] += cache.hann_window[j] * cache.hann_window[j];
  580. overlap_buffer[j] += ifft_out[j * 2] * cache.hann_window[j];
  581. }
  582. // extract hop_length samples with normalization
  583. for (int i = 0; i < hop_length; i++) {
  584. if (window_sum_buffer[i] > 1e-8f) {
  585. output[i] = overlap_buffer[i] / window_sum_buffer[i];
  586. } else {
  587. output[i] = overlap_buffer[i];
  588. }
  589. }
  590. // shift buffers left by hop_length
  591. std::copy(overlap_buffer.begin() + hop_length, overlap_buffer.end(), overlap_buffer.begin());
  592. std::fill(overlap_buffer.end() - hop_length, overlap_buffer.end(), 0.0f);
  593. std::copy(window_sum_buffer.begin() + hop_length, window_sum_buffer.end(), window_sum_buffer.begin());
  594. std::fill(window_sum_buffer.end() - hop_length, window_sum_buffer.end(), 0.0f);
  595. // Remove padding if needed
  596. int to_remove = std::min(padding_to_remove, (int) output.size());
  597. padding_to_remove -= to_remove;
  598. output.erase(output.begin(), output.begin() + to_remove);
  599. return output;
  600. }
  601. std::vector<float> mtmd_audio_streaming_istft::flush() {
  602. std::vector<float> output;
  603. // Extract remaining samples from overlap buffer
  604. // Continue until we've extracted all meaningful samples
  605. int remaining = n_fft - hop_length;
  606. while (remaining > 0) {
  607. int chunk_size = std::min(remaining, hop_length);
  608. for (int i = 0; i < chunk_size; i++) {
  609. float sample;
  610. if (window_sum_buffer[i] > 1e-8f) {
  611. sample = overlap_buffer[i] / window_sum_buffer[i];
  612. } else {
  613. sample = overlap_buffer[i];
  614. }
  615. output.push_back(sample);
  616. }
  617. // Shift buffers
  618. std::copy(overlap_buffer.begin() + chunk_size, overlap_buffer.end(), overlap_buffer.begin());
  619. std::fill(overlap_buffer.end() - chunk_size, overlap_buffer.end(), 0.0f);
  620. std::copy(window_sum_buffer.begin() + chunk_size, window_sum_buffer.end(), window_sum_buffer.begin());
  621. std::fill(window_sum_buffer.end() - chunk_size, window_sum_buffer.end(), 0.0f);
  622. remaining -= chunk_size;
  623. }
  624. return output;
  625. }