mtmd-audio.cpp 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537
  1. #include "mtmd-audio.h"
  2. #define _USE_MATH_DEFINES // for M_PI
  3. #include <cmath>
  4. #include <cstdint>
  5. #include <cstring>
  6. #include <thread>
  7. #include <vector>
  8. #include <fstream>
  9. #include <algorithm>
  10. // most of the code here is copied from whisper.cpp
  11. constexpr bool DEBUG = false;
  12. struct mtmd_audio_mel_filters {
  13. int32_t n_mel;
  14. int32_t n_fft;
  15. std::vector<float> data;
  16. };
  17. // note: this global cache is shared among all preprocessors
  18. // if we want to use multiple preprocessors at the same time,
  19. // we will need to enclose it in the preprocessor class in the future
  20. static struct mtmd_audio_global_cache {
  21. // precomputed sin/cos table for FFT
  22. std::vector<float> sin_vals;
  23. std::vector<float> cos_vals;
  24. // hann window
  25. std::vector<float> hann_window;
  26. // mel filter bank
  27. mtmd_audio_mel_filters filters;
  28. void fill_sin_cos_table(int n) {
  29. sin_vals.resize(n);
  30. cos_vals.resize(n);
  31. for (int i = 0; i < n; i++) {
  32. double theta = (2 * M_PI * i) / n;
  33. sin_vals[i] = sinf(theta);
  34. cos_vals[i] = cosf(theta);
  35. }
  36. }
  37. void fill_hann_window(int length, bool periodic) {
  38. hann_window.resize(length);
  39. int offset = -1;
  40. if (periodic) {
  41. offset = 0;
  42. }
  43. for (int i = 0; i < length; i++) {
  44. hann_window[i] = 0.5 * (1.0 - cosf((2.0 * M_PI * i) / (length + offset)));
  45. }
  46. }
  47. // Build mel filterbank matrix [n_mel × n_fft_bins] at runtime.
  48. // n_fft_bins must be (N_fft / 2 + 1). Example: if N_fft=512 -> n_fft_bins=257.
  49. void fill_mel_filterbank_matrix(
  50. int n_mel,
  51. int n_fft,
  52. int sample_rate, // e.g. 16000
  53. float fmin = 0.0f, // e.g. 0.0
  54. float fmax = -1.0f, // e.g. sr/2; pass -1 for auto
  55. bool slaney_area_norm = true,
  56. float scale = 1.0f // optional extra scaling; use 1.0f/1000.0f to mimic your code
  57. ) {
  58. GGML_ASSERT(n_mel > 0 && n_fft > 1);
  59. if (fmax <= 0.0f) {
  60. fmax = 0.5f * sample_rate;
  61. }
  62. // Slaney scale (matches librosa default)
  63. const double min_log_hz = 1000.0;
  64. const double lin_slope = 3 / 200.;
  65. const double min_log_mel = min_log_hz * lin_slope;
  66. const double log_step = log(6.4) / 27.0;
  67. auto hz_to_mel = [min_log_hz, lin_slope, log_step, min_log_mel](const double f_hz) -> double {
  68. return (f_hz < min_log_hz) ? f_hz * lin_slope : min_log_mel + log(f_hz / min_log_hz) / log_step;
  69. };
  70. auto mel_to_hz = [min_log_hz, lin_slope, log_step, min_log_mel](const double m) -> double {
  71. return (m < min_log_mel) ? m / lin_slope : min_log_hz * exp((m - min_log_mel) * log_step);
  72. };
  73. // infer N_fft from n_fft_bins
  74. const double bin_hz_step = double(sample_rate) / double(n_fft);
  75. // mel grid: n_mel + 2 edges
  76. const double m_lo = hz_to_mel(fmin);
  77. const double m_hi = hz_to_mel(fmax);
  78. std::vector<double> mel_pts(n_mel + 2);
  79. for (int i = 0; i < n_mel + 2; ++i) {
  80. mel_pts[i] = m_lo + (m_hi - m_lo) * (double(i) / (n_mel + 1));
  81. }
  82. // convert to Hz
  83. std::vector<double> hz_pts(n_mel + 2);
  84. for (int i = 0; i < n_mel + 2; ++i) {
  85. hz_pts[i] = mel_to_hz(mel_pts[i]);
  86. }
  87. const int n_fft_bins = n_fft / 2 + 1;
  88. // filterbank
  89. std::vector<float> out(n_mel * n_fft_bins, 0);
  90. for (int m = 0; m < n_mel; ++m) {
  91. const double f_left = hz_pts[m];
  92. const double f_center = hz_pts[m + 1];
  93. const double f_right = hz_pts[m + 2];
  94. const double denom_l = std::max(1e-30, f_center - f_left);
  95. const double denom_r = std::max(1e-30, f_right - f_center);
  96. const double enorm = slaney_area_norm ? (2.0 / std::max(1e-30, f_right - f_left)) : 1.0;
  97. for (int k = 0; k < n_fft_bins; ++k) {
  98. const double f = k * bin_hz_step;
  99. double w = 0.0;
  100. if (f >= f_left && f <= f_center) {
  101. w = (f - f_left) / denom_l;
  102. } else if (f > f_center && f <= f_right) {
  103. w = (f_right - f) / denom_r;
  104. }
  105. out[size_t(m) * size_t(n_fft_bins) + size_t(k)] = float(w * enorm * scale);
  106. }
  107. }
  108. filters.n_mel = n_mel;
  109. filters.n_fft = n_fft;
  110. filters.data = std::move(out);
  111. if (DEBUG) { // debug
  112. for (size_t i = 0; i < filters.data.size(); ++i) {
  113. if (filters.data[i] != 0.0f) {
  114. printf("filters[%zu] = %f\n", i, filters.data[i] * 1000.0f);
  115. }
  116. }
  117. }
  118. }
  119. } g_cache;
  120. // naive Discrete Fourier Transform
  121. // input is real-valued
  122. // output is complex-valued
  123. static void dft(const float * in, int N, float * out) {
  124. const int n_sin_cos_vals = g_cache.sin_vals.size();
  125. const int sin_cos_step = n_sin_cos_vals / N;
  126. for (int k = 0; k < N; k++) {
  127. float re = 0;
  128. float im = 0;
  129. for (int n = 0; n < N; n++) {
  130. int idx = (k * n * sin_cos_step) % (n_sin_cos_vals); // t = 2*M_PI*k*n/N
  131. re += in[n] * g_cache.cos_vals[idx]; // cos(t)
  132. im -= in[n] * g_cache.sin_vals[idx]; // sin(t)
  133. }
  134. out[k*2 + 0] = re;
  135. out[k*2 + 1] = im;
  136. }
  137. }
// Cooley-Tukey FFT
// poor man's implementation - use something better
// input is real-valued
// output is complex-valued
//
// Recursive radix-2 decimation-in-time FFT of N real samples.
//   in  : N real samples in in[0..N). in[N..) is used as scratch for the even/odd
//         sub-sequences of every recursion level (in[0..N) itself is only read), so the
//         caller must provide extra room; the worker below allocates 2 * frame_size floats.
//   out : receives N complex bins interleaved as [re0, im0, re1, im1, ...] in out[0..2N).
//         out[2N..) is used as scratch for the sub-FFT results; the worker below allocates
//         8 * frame_size floats.
// Odd lengths other than 1 fall back to the O(N^2) dft().
// Twiddles are read from g_cache's sin/cos tables with stride table_size / N; this presumably
// relies on table_size being a multiple of every N reached by the recursion (the table is
// built for the top-level FFT size and N only halves from there) — TODO confirm for new callers.
static void fft(float * in, int N, float * out) {
    const int n_sin_cos_vals = g_cache.sin_vals.size();
    // base case: the DFT of a single real sample is that sample (imaginary part 0)
    if (N == 1) {
        out[0] = in[0];
        out[1] = 0;
        return;
    }
    const int half_N = N / 2;
    // odd length: cannot split in two equal halves, use the naive DFT
    if (N - half_N*2 == 1) {
        dft(in, N, out);
        return;
    }
    // gather the even-indexed samples into the scratch area right after the input
    float* even = in + N;
    for (int i = 0; i < half_N; ++i) {
        even[i]= in[2*i];
    }
    // FFT of the even half is written to out[2N .. 3N)
    float* even_fft = out + 2 * N;
    fft(even, half_N, even_fft);
    // the even samples are no longer needed: reuse the same scratch for the odd-indexed samples
    float* odd = even;
    for (int i = 0; i < half_N; ++i) {
        odd[i] = in[2*i + 1];
    }
    // FFT of the odd half is written to out[3N .. 4N), right after even_fft
    float* odd_fft = even_fft + N;
    fft(odd, half_N, odd_fft);
    const int sin_cos_step = n_sin_cos_vals / N;
    // butterflies: X[k] = E[k] + W^k * O[k], X[k + N/2] = E[k] - W^k * O[k], W^k = exp(-2*pi*i*k/N)
    for (int k = 0; k < half_N; k++) {
        int idx = k * sin_cos_step; // t = 2*M_PI*k/N
        float re = g_cache.cos_vals[idx]; // cos(t)
        float im = -g_cache.sin_vals[idx]; // -sin(t)
        float re_odd = odd_fft[2*k + 0];
        float im_odd = odd_fft[2*k + 1];
        out[2*k + 0] = even_fft[2*k + 0] + re*re_odd - im*im_odd;
        out[2*k + 1] = even_fft[2*k + 1] + re*im_odd + im*re_odd;
        out[2*(k + half_N) + 0] = even_fft[2*k + 0] - re*re_odd + im*im_odd;
        out[2*(k + half_N) + 1] = even_fft[2*k + 1] - re*im_odd - im*re_odd;
    }
}
  179. struct filter_params {
  180. int32_t n_mel;
  181. int32_t n_fft_bins;
  182. int32_t hann_window_size;
  183. int32_t hop_length;
  184. int32_t sample_rate;
  185. bool center_padding = false;
  186. float preemph = 0.f;
  187. bool use_natural_log = false;
  188. bool norm_per_feature = false;
  189. };
  190. static void log_mel_spectrogram_worker_thread(int ith, const float * hann, const std::vector<float> & samples,
  191. int n_samples, int frame_size, int frame_step, int n_threads,
  192. const filter_params & params, mtmd_audio_mel & out) {
  193. std::vector<float> fft_in(frame_size * 2, 0.0);
  194. std::vector<float> fft_out(frame_size * 2 * 2 * 2);
  195. int n_fft_bins = params.n_fft_bins;
  196. int i = ith;
  197. const auto & filters = g_cache.filters;
  198. // make sure n_fft == 1 + (WHISPER_N_FFT / 2), bin_0 to bin_nyquist
  199. GGML_ASSERT(n_fft_bins == 1 + (frame_size / 2));
  200. GGML_ASSERT(g_cache.sin_vals.size() == g_cache.cos_vals.size());
  201. // calculate FFT only when fft_in are not all zero
  202. for (; i < std::min(n_samples / frame_step + 1, out.n_len); i += n_threads) {
  203. const int offset = i * frame_step;
  204. // apply Hann window (~10% faster)
  205. for (int j = 0; j < std::min(frame_size, n_samples - offset); j++) {
  206. fft_in[j] = hann[j] * samples[offset + j];
  207. }
  208. // fill the rest with zeros
  209. if (n_samples - offset < frame_size) {
  210. std::fill(fft_in.begin() + (n_samples - offset), fft_in.end(), 0.0);
  211. }
  212. // FFT
  213. fft(fft_in.data(), frame_size, fft_out.data());
  214. // Calculate modulus^2 of complex numbers
  215. // Use pow(fft_out[2 * j + 0], 2) + pow(fft_out[2 * j + 1], 2) causes inference quality problem? Interesting.
  216. for (int j = 0; j < n_fft_bins; j++) {
  217. fft_out[j] = (fft_out[2 * j + 0] * fft_out[2 * j + 0] + fft_out[2 * j + 1] * fft_out[2 * j + 1]);
  218. }
  219. // mel spectrogram
  220. for (int j = 0; j < out.n_mel; j++) {
  221. double sum = 0.0;
  222. // unroll loop (suggested by GH user @lunixbochs)
  223. int k = 0;
  224. for (k = 0; k < n_fft_bins - 3; k += 4) {
  225. size_t idx = size_t(j) * size_t(n_fft_bins) + size_t(k);
  226. sum +=
  227. fft_out[k + 0] * filters.data[idx + 0] +
  228. fft_out[k + 1] * filters.data[idx + 1] +
  229. fft_out[k + 2] * filters.data[idx + 2] +
  230. fft_out[k + 3] * filters.data[idx + 3];
  231. }
  232. // handle n_fft remainder
  233. for (; k < n_fft_bins; k++) {
  234. sum += fft_out[k] * filters.data[j * n_fft_bins + k];
  235. }
  236. sum = params.use_natural_log
  237. ? log(sum + 5.960464477539063e-08)
  238. : log10(std::max(sum, 1e-10));
  239. out.data[j * out.n_len + i] = sum;
  240. }
  241. }
  242. // Otherwise fft_out are all zero
  243. double sum = params.use_natural_log ? log(1e-10) : log10(1e-10);
  244. for (; i < out.n_len; i += n_threads) {
  245. for (int j = 0; j < out.n_mel; j++) {
  246. out.data[j * out.n_len + i] = sum;
  247. }
  248. }
  249. }
  250. // ref: https://github.com/openai/whisper/blob/main/whisper/audio.py#L110-L157
  251. static bool log_mel_spectrogram(
  252. const float * samples,
  253. const int n_samples_in,
  254. const int n_threads,
  255. const filter_params & params,
  256. mtmd_audio_mel & out) {
  257. //const int64_t t_start_us = ggml_time_us();
  258. out.n_len_org = n_samples_in;
  259. int n_samples = n_samples_in;
  260. // Hann window
  261. const float * hann = g_cache.hann_window.data();
  262. const int frame_size = (params.n_fft_bins - 1) * 2;
  263. const int frame_step = params.hop_length;
  264. // Padding
  265. std::vector<float> samples_padded;
  266. if (params.center_padding) {
  267. const auto pad_amount = frame_size / 2;
  268. samples_padded = std::vector<float>(n_samples + 2 * pad_amount, 0);
  269. std::copy(samples, samples + n_samples, samples_padded.data() + pad_amount);
  270. samples = samples_padded.data();
  271. n_samples = samples_padded.size();
  272. } else {
  273. // existing padding logic
  274. int64_t stage_1_pad = params.sample_rate * 30;
  275. int64_t stage_2_pad = frame_size / 2;
  276. samples_padded.resize(n_samples + stage_1_pad + stage_2_pad * 2);
  277. std::copy(samples, samples + n_samples, samples_padded.begin() + stage_2_pad);
  278. // pad 30 seconds of zeros at the end of audio (480,000 samples) + reflective pad 200 samples at the end of audio
  279. std::fill(samples_padded.begin() + n_samples + stage_2_pad, samples_padded.begin() + n_samples + stage_1_pad + 2 * stage_2_pad, 0);
  280. // reflective pad 200 samples at the beginning of audio
  281. if (n_samples < stage_2_pad + 1) {
  282. // TODO: Handle short audio differently or return error
  283. return false;
  284. }
  285. std::reverse_copy(samples + 1, samples + 1 + stage_2_pad, samples_padded.begin());
  286. }
  287. // preemphasis
  288. if (params.preemph) {
  289. const int pad_amount = frame_size / 2;
  290. const float preemph = 0.97f;
  291. float prev = samples_padded[pad_amount];
  292. for (int i = pad_amount + 1; i + pad_amount < n_samples; ++i) {
  293. float cur = samples_padded[i];
  294. samples_padded[i] = cur - preemph * prev;
  295. prev = cur;
  296. }
  297. }
  298. // pad hann window if it's smaller than frame_size
  299. // TODO: probably unnecessary here? (or better doing it in g_cache?)
  300. std::vector<float> hann_window_padded;
  301. if (params.hann_window_size < frame_size) {
  302. hann_window_padded.resize(frame_size);
  303. const int padding = (frame_size - params.hann_window_size) / 2;
  304. std::copy(hann, hann + params.hann_window_size, &hann_window_padded[padding]);
  305. hann = hann_window_padded.data();
  306. }
  307. out.n_mel = params.n_mel;
  308. out.n_len = (n_samples - frame_size) / frame_step + 1;
  309. // TODO: handle these checks better
  310. if (out.n_mel > 0 && (unsigned long)out.n_len > SIZE_MAX / out.n_mel) {
  311. LOG_ERR("%s: size overflow\n", __func__);
  312. return false;
  313. }
  314. if (n_samples < frame_size) {
  315. LOG_ERR("%s: not enough samples after padding\n", __func__);
  316. return false;
  317. }
  318. out.data.resize(out.n_mel * out.n_len);
  319. {
  320. std::vector<std::thread> workers(n_threads - 1);
  321. for (int iw = 0; iw < n_threads - 1; ++iw) {
  322. workers[iw] = std::thread(
  323. log_mel_spectrogram_worker_thread, iw + 1, hann, std::cref(samples_padded),
  324. n_samples, frame_size, frame_step, n_threads,
  325. std::cref(params), std::ref(out));
  326. }
  327. // main thread
  328. log_mel_spectrogram_worker_thread(0, hann, samples_padded, n_samples, frame_size, frame_step, n_threads, params, out);
  329. for (int iw = 0; iw < n_threads - 1; ++iw) {
  330. workers[iw].join();
  331. }
  332. }
  333. const int effective_n_len = n_samples_in / frame_step;
  334. if (params.norm_per_feature) {
  335. for (int i = 0; i < out.n_mel; i++) {
  336. double mean = 0;
  337. for (int j = 0; j < effective_n_len; ++j) {
  338. mean += out.data[i * out.n_len + j];
  339. }
  340. mean /= effective_n_len;
  341. double var = 0.0;
  342. for (int j = 0; j < effective_n_len; ++j) {
  343. const double value = out.data[i * out.n_len + j] - mean;
  344. var += value * value;
  345. }
  346. var /= effective_n_len - 1; // unbiased
  347. const double mstd = std::sqrt(var + 1e-5);
  348. for (int j = 0; j < effective_n_len; ++j) {
  349. auto &value = out.data[i * out.n_len + j];
  350. value = (value - mean) / mstd;
  351. }
  352. // pad the rest with zeros
  353. for (int j = effective_n_len; j < out.n_len; ++j) {
  354. out.data[i * out.n_len + j] = 0.0;
  355. }
  356. }
  357. } else {
  358. // clamping and normalization
  359. double mmax = -1e20;
  360. for (int i = 0; i < out.n_mel*out.n_len; i++) {
  361. if (out.data[i] > mmax) {
  362. mmax = out.data[i];
  363. }
  364. }
  365. mmax -= 8.0;
  366. for (int i = 0; i < out.n_mel*out.n_len; i++) {
  367. if (out.data[i] < mmax) {
  368. out.data[i] = mmax;
  369. }
  370. out.data[i] = (out.data[i] + 4.0)/4.0;
  371. }
  372. }
  373. // Dump log_mel_spectrogram
  374. if (DEBUG) {
  375. std::ofstream outFile("log_mel_spectrogram.json");
  376. outFile << "[";
  377. for (uint64_t i = 0; i < out.data.size() - 1; i++) {
  378. outFile << out.data[i] << ", ";
  379. }
  380. outFile << out.data[out.data.size() - 1] << "]";
  381. outFile.close();
  382. }
  383. return true;
  384. }
  385. //
  386. // mtmd_audio_preprocessor_whisper
  387. //
  388. void mtmd_audio_preprocessor_whisper::initialize() {
  389. g_cache.fill_sin_cos_table(hparams.audio_n_fft);
  390. g_cache.fill_hann_window(hparams.audio_window_len, true);
  391. g_cache.fill_mel_filterbank_matrix(
  392. hparams.n_mel_bins,
  393. hparams.audio_n_fft,
  394. hparams.audio_sample_rate);
  395. }
  396. bool mtmd_audio_preprocessor_whisper::preprocess(
  397. const float * samples,
  398. size_t n_samples,
  399. std::vector<mtmd_audio_mel> & output) {
  400. if (n_samples == 0) {
  401. // empty audio
  402. return false;
  403. }
  404. std::vector<float> smpl;
  405. // if input is too short, pad with zeros
  406. // this is to avoid potential issues with stage1/2 padding in log_mel_spectrogram
  407. // TODO: maybe handle this better
  408. size_t min_samples = (size_t)hparams.audio_sample_rate * (hparams.audio_chunk_len + 1); // +1 second margin
  409. if (n_samples < min_samples) {
  410. smpl.resize(min_samples, 0.0f);
  411. std::memcpy(smpl.data(), samples, n_samples * sizeof(float));
  412. samples = smpl.data();
  413. n_samples = smpl.size();
  414. }
  415. filter_params params;
  416. params.n_mel = hparams.n_mel_bins;
  417. params.n_fft_bins = 1 + (hparams.audio_n_fft / 2);
  418. params.hann_window_size = hparams.audio_window_len;
  419. params.hop_length = hparams.audio_hop_len;
  420. params.sample_rate = hparams.audio_sample_rate;
  421. params.center_padding = false;
  422. params.preemph = 0.0f; // disabled
  423. params.use_natural_log = false;
  424. params.norm_per_feature = false;
  425. // make sure the global cache is initialized
  426. GGML_ASSERT(!g_cache.sin_vals.empty());
  427. GGML_ASSERT(!g_cache.cos_vals.empty());
  428. GGML_ASSERT(!g_cache.filters.data.empty());
  429. mtmd_audio_mel out_full;
  430. bool ok = log_mel_spectrogram(
  431. samples,
  432. n_samples,
  433. 4, // n_threads
  434. params,
  435. out_full);
  436. if (!ok) {
  437. return false;
  438. }
  439. // because the cgraph in clip.cpp only accepts 3000 frames each, we need to split the mel
  440. // we always expect the mel to have 3000 silent frames at the end
  441. if (DEBUG) {
  442. printf("output: n_mel = %d, n_len = %d\n", out_full.n_mel, out_full.n_len);
  443. }
  444. const size_t frames_per_chunk = 3000;
  445. GGML_ASSERT((size_t)out_full.n_len > frames_per_chunk);
  446. for (size_t off = 0; off < (size_t)out_full.n_len; off += frames_per_chunk) {
  447. int n_len = std::min(frames_per_chunk, (size_t)out_full.n_len - off);
  448. if ((size_t)n_len < frames_per_chunk) {
  449. break; // last uncomplete chunk will always be a padded chunk, safe to ignore
  450. }
  451. mtmd_audio_mel out_chunk;
  452. out_chunk.n_len = n_len;
  453. out_chunk.n_mel = out_full.n_mel;
  454. out_chunk.n_len_org = out_full.n_mel; // unused
  455. out_chunk.data.reserve(out_chunk.n_mel * out_chunk.n_len);
  456. for (int i = 0; i < out_full.n_mel; i++) {
  457. auto src = out_full.data.begin() + i*out_full.n_len + off;
  458. out_chunk.data.insert(out_chunk.data.end(), src, src + frames_per_chunk);
  459. }
  460. output.push_back(std::move(out_chunk));
  461. }
  462. return true;
  463. }