rope_common.comp 2.2 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667
  1. #include "common.comp"
// TODO: use a local size of 32 or more (Metal uses 1024)
// Workgroup size: local_size_x is 1 and y/z are left at their default of 1,
// so each workgroup currently runs a single invocation.
layout(local_size_x = 1) in;
// Push-constant parameter block, accessed through the instance name `pcs`.
// The member order and types define the layout the host must fill in.
// NOTE(review): none of these members are read in the visible part of this
// file; the per-field notes below are inferred from names only and should be
// confirmed against the shaders that include this file and the host code.
layout (push_constant) uniform parameter {
// presumably offsets into the two input buffers and the output buffer — confirm units
uint inAOff;
uint inBOff;
uint outOff;
// presumably the number of rotated dimensions and a RoPE variant selector — confirm
int n_dims;
int mode;
// presumably the original context length, matching the n_orig_ctx arguments below — confirm
int n_orig_ctx;
// presumably forwarded to the rope_yarn* helpers via their same-named parameters — confirm
float freq_base;
float freq_scale;
float ext_factor;
float attn_factor;
float beta_fast;
float beta_slow;
// NOTE(review): look like per-dimension strides of the source tensor — confirm
uint nb00;
uint nb01;
uint nb02;
uint nb03;
// NOTE(review): looks like the size of the first destination dimension — confirm
int ne0;
// NOTE(review): look like per-dimension strides of the destination tensor — confirm
uint nb0;
uint nb1;
uint nb2;
uint nb3;
} pcs;
  27. float rope_yarn_ramp(const float low, const float high, const float i0) {
  28. const float y = (i0 / 2 - low) / max(0.001f, high - low);
  29. return 1.0f - min(1.0f, max(0.0f, y));
  30. }
  31. // YaRN algorithm based on LlamaYaRNScaledRotaryEmbedding.py from https://github.com/jquesnelle/yarn
  32. // MIT licensed. Copyright (c) 2023 Jeffrey Quesnelle and Bowen Peng.
  33. void rope_yarn(
  34. float theta_extrap, float freq_scale, float corr_dims[2], float i0, float ext_factor, float mscale,
  35. out float cos_theta, out float sin_theta
  36. ) {
  37. // Get n-d rotational scaling corrected for extrapolation
  38. float theta_interp = freq_scale * theta_extrap;
  39. float theta = theta_interp;
  40. if (ext_factor != 0.0f) {
  41. float ramp_mix = rope_yarn_ramp(corr_dims[0], corr_dims[1], i0) * ext_factor;
  42. theta = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix;
  43. // Get n-d magnitude scaling corrected for interpolation
  44. mscale *= 1.0f + 0.1f * log(1.0f / freq_scale);
  45. }
  46. cos_theta = cos(theta) * mscale;
  47. sin_theta = sin(theta) * mscale;
  48. }
  49. // Apparently solving `n_rot = 2pi * x * base^((2 * max_pos_emb) / n_dims)` for x, we get
  50. // `corr_fac(n_rot) = n_dims * log(max_pos_emb / (n_rot * 2pi)) / (2 * log(base))`
  51. float rope_yarn_corr_factor(int n_dims, int n_orig_ctx, float n_rot, float base) {
  52. return n_dims * log(n_orig_ctx / (n_rot * TWOPI_F)) / (2 * log(base));
  53. }
  54. void rope_yarn_corr_dims(
  55. int n_dims, int n_orig_ctx, float freq_base, float beta_fast, float beta_slow, out float dims[2]
  56. ) {
  57. // start and end correction dims
  58. dims[0] = max(0.0f, floor(rope_yarn_corr_factor(n_dims, n_orig_ctx, beta_fast, freq_base)));
  59. dims[1] = min(n_dims - 1.0f, ceil(rope_yarn_corr_factor(n_dims, n_orig_ctx, beta_slow, freq_base)));
  60. }