// rope.cu — RoPE (rotary position embedding) CUDA kernels
  1. #include "rope.cuh"
// YaRN correction dimensions: v[0] = start (low) and v[1] = end (high) of the
// interpolation/extrapolation ramp, in units of rotated pair indices.
// Small POD so it can be passed by value as a kernel argument.
struct rope_corr_dims {
float v[2];
};
  5. static __device__ float rope_yarn_ramp(const float low, const float high, const int i0) {
  6. const float y = (i0 / 2 - low) / max(0.001f, high - low);
  7. return 1.0f - min(1.0f, max(0.0f, y));
  8. }
  9. // YaRN algorithm based on LlamaYaRNScaledRotaryEmbedding.py from https://github.com/jquesnelle/yarn
  10. // MIT licensed. Copyright (c) 2023 Jeffrey Quesnelle and Bowen Peng.
// Compute the YaRN-corrected cos/sin of the rotation angle for element i0.
// theta_extrap: uncorrected (extrapolated) angle; theta_interp = freq_scale *
// theta_extrap is the position-interpolated angle. When ext_factor != 0 the
// two are blended per-dimension using the [corr_dims.v[0], corr_dims.v[1]]
// ramp, and the magnitude scale (mscale, the attention factor) is corrected
// for interpolation. Results are written through cos_theta/sin_theta.
static __device__ void rope_yarn(
float theta_extrap, float freq_scale, rope_corr_dims corr_dims, int64_t i0, float ext_factor, float mscale,
float * cos_theta, float * sin_theta) {
// Get n-d rotational scaling corrected for extrapolation
float theta_interp = freq_scale * theta_extrap;
float theta = theta_interp;
if (ext_factor != 0.0f) {
// ramp_mix in [0, ext_factor]: how much extrapolation to mix in for this dim
float ramp_mix = rope_yarn_ramp(corr_dims.v[0], corr_dims.v[1], i0) * ext_factor;
theta = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix;
// Get n-d magnitude scaling corrected for interpolation
mscale *= 1.0f + 0.1f * logf(1.0f / freq_scale);
}
*cos_theta = cosf(theta) * mscale;
*sin_theta = sinf(theta) * mscale;
}
// Normal-mode RoPE kernel: each thread rotates one adjacent element pair
// (x[i], x[i+1]) within a row; elements at or past n_dims are copied through
// unchanged. Expected launch layout: x dimension indexes rows, y dimension
// indexes element pairs along ne0 (each thread covers 2 consecutive elements).
// has_ff: compile-time flag — freq_factors is only dereferenced when true.
template<typename T, bool has_ff>
static __global__ void rope_norm(
const T * x, T * dst, int ne0, int n_dims, const int32_t * pos, float freq_scale, int p_delta_rows,
float ext_factor, float attn_factor, rope_corr_dims corr_dims, float theta_scale, const float * freq_factors) {
// i0: index of the first element of this thread's pair
const int i0 = 2*(blockDim.y*blockIdx.y + threadIdx.y);
if (i0 >= ne0) {
return;
}
const int row = blockDim.x*blockIdx.x + threadIdx.x;
// tail of the row beyond n_dims is passed through unrotated
if (i0 >= n_dims) {
const int i = row*ne0 + i0;
dst[i + 0] = x[i + 0];
dst[i + 1] = x[i + 1];
return;
}
const int i = row*ne0 + i0;
// one position entry is shared by every p_delta_rows consecutive rows
const int i2 = row/p_delta_rows;
const float theta_base = pos[i2]*powf(theta_scale, i0/2.0f);
// optional per-pair frequency correction factor
const float freq_factor = has_ff ? freq_factors[i0/2] : 1.0f;
float cos_theta;
float sin_theta;
rope_yarn(theta_base/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta);
// 2D rotation of the pair by theta
const float x0 = x[i + 0];
const float x1 = x[i + 1];
dst[i + 0] = x0*cos_theta - x1*sin_theta;
dst[i + 1] = x0*sin_theta + x1*cos_theta;
}
// NeoX-mode RoPE kernel: rotated pairs are split across the two halves of the
// rotated region — (x[i], x[i + n_dims/2]) with i = row*ne0 + i0/2 — instead
// of being adjacent. Elements at or past n_dims are copied through unchanged.
// Same launch layout as rope_norm: x dimension indexes rows, y dimension
// indexes pairs along ne0.
// has_ff: compile-time flag — freq_factors is only dereferenced when true.
template<typename T, bool has_ff>
static __global__ void rope_neox(
const T * x, T * dst, int ne0, int n_dims, const int32_t * pos, float freq_scale, int p_delta_rows,
float ext_factor, float attn_factor, rope_corr_dims corr_dims, float theta_scale, const float * freq_factors) {
const int i0 = 2*(blockDim.y*blockIdx.y + threadIdx.y);
if (i0 >= ne0) {
return;
}
const int row = blockDim.x*blockIdx.x + threadIdx.x;
// tail beyond n_dims is copied as adjacent pairs (no split-half layout there)
if (i0 >= n_dims) {
const int i = row*ne0 + i0;
dst[i + 0] = x[i + 0];
dst[i + 1] = x[i + 1];
return;
}
// i0/2: first element of the pair lives in the first half of the rotated region
const int i = row*ne0 + i0/2;
// one position entry is shared by every p_delta_rows consecutive rows
const int i2 = row/p_delta_rows;
const float theta_base = pos[i2]*powf(theta_scale, i0/2.0f);
// optional per-pair frequency correction factor
const float freq_factor = has_ff ? freq_factors[i0/2] : 1.0f;
float cos_theta;
float sin_theta;
rope_yarn(theta_base/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta);
const float x0 = x[i + 0];
const float x1 = x[i + n_dims/2];
dst[i + 0] = x0*cos_theta - x1*sin_theta;
dst[i + n_dims/2] = x0*sin_theta + x1*cos_theta;
}
  80. template<typename T>
  81. static void rope_norm_cuda(
  82. const T * x, T * dst, int ne0, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows,
  83. float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream) {
  84. GGML_ASSERT(ne0 % 2 == 0);
  85. const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1);
  86. const int n_blocks_x = (ne0 + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE);
  87. const dim3 block_nums(nr, n_blocks_x, 1);
  88. const float theta_scale = powf(freq_base, -2.0f/n_dims);
  89. if (freq_factors == nullptr) {
  90. rope_norm<T, false><<<block_nums, block_dims, 0, stream>>>(
  91. x, dst, ne0, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
  92. theta_scale, freq_factors
  93. );
  94. } else {
  95. rope_norm<T, true><<<block_nums, block_dims, 0, stream>>>(
  96. x, dst, ne0, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
  97. theta_scale, freq_factors
  98. );
  99. }
  100. }
  101. template<typename T>
  102. static void rope_neox_cuda(
  103. const T * x, T * dst, int ne0, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows,
  104. float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream) {
  105. GGML_ASSERT(ne0 % 2 == 0);
  106. const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1);
  107. const int n_blocks_x = (ne0 + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE);
  108. const dim3 block_nums(nr, n_blocks_x, 1);
  109. const float theta_scale = powf(freq_base, -2.0f/n_dims);
  110. if (freq_factors == nullptr) {
  111. rope_neox<T, false><<<block_nums, block_dims, 0, stream>>>(
  112. x, dst, ne0, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
  113. theta_scale, freq_factors
  114. );
  115. } else {
  116. rope_neox<T, true><<<block_nums, block_dims, 0, stream>>>(
  117. x, dst, ne0, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
  118. theta_scale, freq_factors
  119. );
  120. }
  121. }
  122. static void rope_norm_cuda_f16(
  123. const half * x, half * dst, int ne0, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows,
  124. float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream) {
  125. rope_norm_cuda<half>(x, dst, ne0, n_dims, nr, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, stream);
  126. }
  127. static void rope_norm_cuda_f32(
  128. const float * x, float * dst, int ne0, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows,
  129. float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream) {
  130. rope_norm_cuda<float>(x, dst, ne0, n_dims, nr, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, stream);
  131. }
  132. static void rope_neox_cuda_f16(
  133. const half * x, half * dst, int ne0, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows,
  134. float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream) {
  135. rope_neox_cuda<half>(x, dst, ne0, n_dims, nr, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, stream);
  136. }
  137. static void rope_neox_cuda_f32(
  138. const float * x, float * dst, int ne0, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows,
  139. float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream
  140. ) {
  141. rope_neox_cuda<float>(x, dst, ne0, n_dims, nr, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, stream);
  142. }
// Backend entry point: apply RoPE to dst->src[0] on the context's stream.
// Positions come from dst->src[1] (read as int32); optional per-frequency
// correction factors come from dst->src[2] when present. All scalar RoPE
// parameters are unpacked from dst->op_params (see offsets below).
void ggml_cuda_op_rope(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0]; // input activations
const ggml_tensor * src1 = dst->src[1]; // positions
const ggml_tensor * src2 = dst->src[2]; // optional frequency factors (may be null)
const float * src0_d = (const float *)src0->data;
const float * src1_d = (const float *)src1->data;
float * dst_d = (float *)dst->data;
cudaStream_t stream = ctx.stream();
GGML_ASSERT(ggml_is_contiguous(src0));
GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
GGML_ASSERT( dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
GGML_ASSERT(src0->type == dst->type);
const int64_t ne00 = src0->ne[0]; // row width
const int64_t ne01 = src0->ne[1]; // rows per position entry (passed as p_delta_rows)
const int64_t nr = ggml_nrows(src0);
// op_params layout: 32-bit words, ints first, then floats stored bit-for-bit
//const int n_past = ((int32_t *) dst->op_params)[0];
const int n_dims = ((int32_t *) dst->op_params)[1];
const int mode = ((int32_t *) dst->op_params)[2];
//const int n_ctx = ((int32_t *) dst->op_params)[3];
const int n_ctx_orig = ((int32_t *) dst->op_params)[4];
// RoPE alteration for extended context
float freq_base;
float freq_scale;
float ext_factor;
float attn_factor;
float beta_fast;
float beta_slow;
// memcpy (rather than a pointer cast) avoids type-punning UB when reading
// float values out of the int32 op_params array
memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float));
memcpy(&freq_scale, (int32_t *) dst->op_params + 6, sizeof(float));
memcpy(&ext_factor, (int32_t *) dst->op_params + 7, sizeof(float));
memcpy(&attn_factor, (int32_t *) dst->op_params + 8, sizeof(float));
memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float));
memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float));
// bit 1 of mode selects NeoX-style (split-half) rotation
const bool is_neox = mode & 2;
const int32_t * pos = (const int32_t *) src1_d;
const float * freq_factors = nullptr;
if (src2 != nullptr) {
freq_factors = (const float *) src2->data;
}
// derive the YaRN ramp bounds from the extended-context parameters
rope_corr_dims corr_dims;
ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims.v);
// compute
if (is_neox) {
if (src0->type == GGML_TYPE_F32) {
rope_neox_cuda_f32(
(const float *)src0_d, (float *)dst_d, ne00, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor,
attn_factor, corr_dims, freq_factors, stream
);
} else if (src0->type == GGML_TYPE_F16) {
rope_neox_cuda_f16(
(const half *)src0_d, (half *)dst_d, ne00, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor,
attn_factor, corr_dims, freq_factors, stream
);
} else {
GGML_ASSERT(false); // unreachable: type asserted above
}
} else {
if (src0->type == GGML_TYPE_F32) {
rope_norm_cuda_f32(
(const float *)src0_d, (float *)dst_d, ne00, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor,
attn_factor, corr_dims, freq_factors, stream
);
} else if (src0->type == GGML_TYPE_F16) {
rope_norm_cuda_f16(
(const half *)src0_d, (half *)dst_d, ne00, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor,
attn_factor, corr_dims, freq_factors, stream
);
} else {
GGML_ASSERT(false); // unreachable: type asserted above
}
}
}