// rope.cu

#include "rope.cuh"

struct rope_corr_dims {
    float v[4];
};
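// YaRN ramp used to blend interpolated and extrapolated rotation angles per dimension:
// returns 1.0f for i0/2 <= low, 0.0f for i0/2 >= high, and a linear transition in between.
// rope_yarn() uses this value (scaled by ext_factor) as the weight on the extrapolated angle.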
static __device__ float rope_yarn_ramp(const float low, const float high, const int i0) {
    const float y = (i0 / 2 - low) / max(0.001f, high - low);
    return 1.0f - min(1.0f, max(0.0f, y));
}
// YaRN algorithm based on LlamaYaRNScaledRotaryEmbedding.py from https://github.com/jquesnelle/yarn
// MIT licensed. Copyright (c) 2023 Jeffrey Quesnelle and Bowen Peng.
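// Computes the cos/sin of the rotation angle for one dimension pair. The angle blends the
// position-interpolated angle (theta_extrap scaled by freq_scale) with the extrapolated angle,
// weighted by the ramp above, and the magnitude is rescaled by mscale (with the YaRN log
// correction applied when ext_factor is non-zero).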
static __device__ void rope_yarn(
    float theta_extrap, float freq_scale, rope_corr_dims corr_dims, int64_t i0, float ext_factor, float mscale,
    float * cos_theta, float * sin_theta
) {
    // Get n-d rotational scaling corrected for extrapolation
    float theta_interp = freq_scale * theta_extrap;
    float theta = theta_interp;
    if (ext_factor != 0.0f) {
        float ramp_mix = rope_yarn_ramp(corr_dims.v[0], corr_dims.v[1], i0) * ext_factor;
        theta = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix;

        // Get n-d magnitude scaling corrected for interpolation
        mscale *= 1.0f + 0.1f * logf(1.0f / freq_scale);
    }
    *cos_theta = cosf(theta) * mscale;
    *sin_theta = sinf(theta) * mscale;
}

// rope == RoPE == rotary positional embedding
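// Basic (adjacent-pair) RoPE kernel: each thread rotates the pair (x[i], x[i+1]) of one row by an
// angle that depends on the token position pos[row / p_delta_rows] and the column index.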
template<typename T, bool has_pos>
static __global__ void rope(
    const T * x, T * dst, int ncols, const int32_t * pos, float freq_scale, int p_delta_rows, float freq_base,
    float ext_factor, float attn_factor, rope_corr_dims corr_dims
) {
    const int col = 2*(blockDim.y*blockIdx.y + threadIdx.y);

    if (col >= ncols) {
        return;
    }

    const int row = blockDim.x*blockIdx.x + threadIdx.x;
    const int i  = row*ncols + col;
    const int i2 = row/p_delta_rows;

    const int p = has_pos ? pos[i2] : 0;
    const float theta_base = p*powf(freq_base, -float(col)/ncols);

    float cos_theta, sin_theta;
    rope_yarn(theta_base, freq_scale, corr_dims, col, ext_factor, attn_factor, &cos_theta, &sin_theta);

    const float x0 = x[i + 0];
    const float x1 = x[i + 1];

    dst[i + 0] = x0*cos_theta - x1*sin_theta;
    dst[i + 1] = x0*sin_theta + x1*cos_theta;
}
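// GPT-NeoX style RoPE: the rotation pairs element ic/2 with element ic/2 + n_dims/2 instead of
// adjacent elements. Columns beyond n_dims (ib > 0) are copied through unchanged, and the optional
// per-frequency factors (freq_factors) divide the base angle.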
template<typename T, bool has_pos, bool has_freq_facs>
static __global__ void rope_neox(
    const T * x, T * dst, int ncols, int n_dims, const int32_t * pos, float freq_scale, int p_delta_rows,
    float ext_factor, float attn_factor, rope_corr_dims corr_dims, float theta_scale, float inv_ndims, const float * freq_factors
) {
    const int col = 2*(blockDim.y*blockIdx.y + threadIdx.y);

    if (col >= ncols) {
        return;
    }

    const int row = blockDim.x*blockIdx.x + threadIdx.x;
    const int ib = col / n_dims;
    const int ic = col % n_dims;

    if (ib > 0) {
        const int i = row*ncols + ib*n_dims + ic;

        dst[i + 0] = x[i + 0];
        dst[i + 1] = x[i + 1];

        return;
    }

    const int i  = row*ncols + ib*n_dims + ic/2;
    const int i2 = row/p_delta_rows;

    float cur_rot = inv_ndims * ic - ib;

    const int p = has_pos ? pos[i2] : 0;
    const float freq_factor = has_freq_facs ? freq_factors[ic/2] : 1.0f;

    const float theta_base = p*freq_scale*powf(theta_scale, col/2.0f)/freq_factor;

    float cos_theta, sin_theta;
    rope_yarn(theta_base, freq_scale, corr_dims, cur_rot, ext_factor, attn_factor, &cos_theta, &sin_theta);

    const float x0 = x[i + 0];
    const float x1 = x[i + n_dims/2];

    dst[i + 0]        = x0*cos_theta - x1*sin_theta;
    dst[i + n_dims/2] = x0*sin_theta + x1*cos_theta;
}
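// ChatGLM style RoPE (f32 only): each row is split into two halves that are rotated independently,
// the first using the position clamped to n_ctx - 2 and the second using the block position
// max(p - n_ctx - 2, 0). Note the FIXME inside about the position lookup.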
static __global__ void rope_glm_f32(
    const float * x, float * dst, int ncols, const int32_t * pos, float freq_scale, int p_delta_rows, float freq_base,
    int n_ctx
) {
    const int col = blockDim.x*blockIdx.x + threadIdx.x;
    const int half_n_dims = ncols/4;

    if (col >= half_n_dims) {
        return;
    }

    const int row = blockDim.y*blockIdx.y + threadIdx.y;
    const int i  = row*ncols + col;
    const int i2 = row/p_delta_rows;

    const float col_theta_scale = powf(freq_base, -2.0f*col/ncols);

    // FIXME: this is likely wrong
    const int p = pos != nullptr ? pos[i2] : 0;

    const float theta = min(p, n_ctx - 2)*freq_scale*col_theta_scale;
    const float sin_theta = sinf(theta);
    const float cos_theta = cosf(theta);

    const float x0 = x[i + 0];
    const float x1 = x[i + half_n_dims];

    dst[i + 0]           = x0*cos_theta - x1*sin_theta;
    dst[i + half_n_dims] = x0*sin_theta + x1*cos_theta;

    const float block_theta = ((float)max(p - n_ctx - 2, 0))*col_theta_scale;
    const float sin_block_theta = sinf(block_theta);
    const float cos_block_theta = cosf(block_theta);

    const float x2 = x[i + half_n_dims * 2];
    const float x3 = x[i + half_n_dims * 3];

    dst[i + half_n_dims * 2] = x2*cos_block_theta - x3*sin_block_theta;
    dst[i + half_n_dims * 3] = x2*sin_block_theta + x3*cos_block_theta;
}
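// Launches the basic RoPE kernel: one block per tensor row along x and CUDA_ROPE_BLOCK_SIZE threads
// along y, each thread handling one pair of columns. The has_pos template flag is chosen at launch
// time depending on whether a positions tensor was provided.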
template<typename T>
static void rope_cuda(
    const T * x, T * dst, int ncols, int nrows, const int32_t * pos, float freq_scale, int p_delta_rows,
    float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, cudaStream_t stream
) {
    GGML_ASSERT(ncols % 2 == 0);
    const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1);
    const int num_blocks_x = (ncols + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE);
    const dim3 block_nums(nrows, num_blocks_x, 1);

    if (pos == nullptr) {
        rope<T, false><<<block_nums, block_dims, 0, stream>>>(
            x, dst, ncols, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims
        );
    } else {
        rope<T, true><<<block_nums, block_dims, 0, stream>>>(
            x, dst, ncols, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims
        );
    }
}
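// Launches the NeoX RoPE kernel with the same grid layout as rope_cuda. theta_scale = freq_base^(-2/n_dims)
// and inv_ndims = -1/n_dims are precomputed on the host; the template flags select whether positions
// and frequency factors are used.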
template<typename T>
static void rope_neox_cuda(
    const T * x, T * dst, int ncols, int n_dims, int nrows, const int32_t * pos, float freq_scale, int p_delta_rows,
    float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream
) {
    GGML_ASSERT(ncols % 2 == 0);
    const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1);
    const int num_blocks_x = (ncols + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE);
    const dim3 block_nums(nrows, num_blocks_x, 1);

    const float theta_scale = powf(freq_base, -2.0f/n_dims);
    const float inv_ndims   = -1.0f / n_dims;

    if (pos == nullptr) {
        if (freq_factors == nullptr) {
            rope_neox<T, false, false><<<block_nums, block_dims, 0, stream>>>(
                x, dst, ncols, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
                theta_scale, inv_ndims, freq_factors
            );
        } else {
            rope_neox<T, false, true><<<block_nums, block_dims, 0, stream>>>(
                x, dst, ncols, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
                theta_scale, inv_ndims, freq_factors
            );
        }
    } else {
        if (freq_factors == nullptr) {
            rope_neox<T, true, false><<<block_nums, block_dims, 0, stream>>>(
                x, dst, ncols, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
                theta_scale, inv_ndims, freq_factors
            );
        } else {
            rope_neox<T, true, true><<<block_nums, block_dims, 0, stream>>>(
                x, dst, ncols, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
                theta_scale, inv_ndims, freq_factors
            );
        }
    }
}
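// Launches the ChatGLM RoPE kernel: threads along x cover a quarter of the columns (each thread
// writes four output elements), with one block per tensor row along y.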
static void rope_glm_f32_cuda(
    const float * x, float * dst, int ncols, int nrows, const int32_t * pos, float freq_scale, int p_delta_rows,
    float freq_base, int n_ctx, cudaStream_t stream
) {
    GGML_ASSERT(ncols % 4 == 0);
    const dim3 block_dims(CUDA_ROPE_BLOCK_SIZE/4, 1, 1);
    const int num_blocks_x = (ncols + CUDA_ROPE_BLOCK_SIZE - 1) / CUDA_ROPE_BLOCK_SIZE;
    const dim3 block_nums(num_blocks_x, nrows, 1);
    rope_glm_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols, pos, freq_scale, p_delta_rows, freq_base, n_ctx);
}
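// Thin non-template wrappers that instantiate the launchers for the two supported element types.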
static void rope_cuda_f16(
    const half * x, half * dst, int ncols, int nrows, const int32_t * pos, float freq_scale, int p_delta_rows,
    float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, cudaStream_t stream) {
    rope_cuda<half>(x, dst, ncols, nrows, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, stream);
}

static void rope_cuda_f32(
    const float * x, float * dst, int ncols, int nrows, const int32_t * pos, float freq_scale, int p_delta_rows,
    float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, cudaStream_t stream) {
    rope_cuda<float>(x, dst, ncols, nrows, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, stream);
}

static void rope_neox_cuda_f16(
    const half * x, half * dst, int ncols, int n_dims, int nrows, const int32_t * pos, float freq_scale, int p_delta_rows,
    float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream) {
    rope_neox_cuda<half>(x, dst, ncols, n_dims, nrows, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, stream);
}

static void rope_neox_cuda_f32(
    const float * x, float * dst, int ncols, int n_dims, int nrows, const int32_t * pos, float freq_scale, int p_delta_rows,
    float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream
) {
    rope_neox_cuda<float>(x, dst, ncols, n_dims, nrows, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, stream);
}
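// ggml backend entry point: reads the RoPE parameters packed into dst->op_params, decodes the mode
// bits (2 = NeoX, 4 = GLM), computes the YaRN correction dims, and dispatches to the matching kernel
// for the tensor type (F32 or F16).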
void ggml_cuda_op_rope(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    const ggml_tensor * src0 = dst->src[0];
    const ggml_tensor * src1 = dst->src[1];
    const ggml_tensor * src2 = dst->src[2];

    const float * src0_d = (const float *)src0->data;
    const float * src1_d = (const float *)src1->data;
    float * dst_d = (float *)dst->data;
    cudaStream_t stream = ctx.stream();

    GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
    GGML_ASSERT( dst->type == GGML_TYPE_F32 ||  dst->type == GGML_TYPE_F16);
    GGML_ASSERT(src0->type == dst->type);

    const int64_t ne00 = src0->ne[0];
    const int64_t ne01 = src0->ne[1];
    const int64_t nrows = ggml_nrows(src0);

    //const int n_past    = ((int32_t *) dst->op_params)[0];
    const int n_dims     = ((int32_t *) dst->op_params)[1];
    const int mode       = ((int32_t *) dst->op_params)[2];
    const int n_ctx      = ((int32_t *) dst->op_params)[3];
    const int n_orig_ctx = ((int32_t *) dst->op_params)[4];

    // RoPE alteration for extended context
    float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
    memcpy(&freq_base,   (int32_t *) dst->op_params +  5, sizeof(float));
    memcpy(&freq_scale,  (int32_t *) dst->op_params +  6, sizeof(float));
    memcpy(&ext_factor,  (int32_t *) dst->op_params +  7, sizeof(float));
    memcpy(&attn_factor, (int32_t *) dst->op_params +  8, sizeof(float));
    memcpy(&beta_fast,   (int32_t *) dst->op_params +  9, sizeof(float));
    memcpy(&beta_slow,   (int32_t *) dst->op_params + 10, sizeof(float));

    const float * freq_factors = nullptr;
    const int32_t * pos = nullptr;

    const bool is_neox = mode & 2;
    const bool is_glm  = mode & 4;

    pos = (const int32_t *) src1_d;

    if (is_neox) {
        if (src2 != nullptr) {
            freq_factors = (const float *) src2->data;
        }
    } else {
        GGML_ASSERT(src2 == nullptr && "TODO: freq_factors not implemented for !is_neox");
    }

    rope_corr_dims corr_dims;
    ggml_rope_yarn_corr_dims(n_dims, n_orig_ctx, freq_base, beta_fast, beta_slow, corr_dims.v);

    // compute
    if (is_glm) {
        GGML_ASSERT(false);
        rope_glm_f32_cuda(src0_d, dst_d, ne00, nrows, pos, freq_scale, ne01, freq_base, n_ctx, stream);
    } else if (is_neox) {
        if (src0->type == GGML_TYPE_F32) {
            rope_neox_cuda_f32(
                (const float *)src0_d, (float *)dst_d, ne00, n_dims, nrows, pos, freq_scale, ne01, freq_base, ext_factor,
                attn_factor, corr_dims, freq_factors, stream
            );
        } else if (src0->type == GGML_TYPE_F16) {
            rope_neox_cuda_f16(
                (const half *)src0_d, (half *)dst_d, ne00, n_dims, nrows, pos, freq_scale, ne01, freq_base, ext_factor,
                attn_factor, corr_dims, freq_factors, stream
            );
        } else {
            GGML_ASSERT(false);
        }
    } else {
        if (src0->type == GGML_TYPE_F32) {
            rope_cuda_f32(
                (const float *)src0_d, (float *)dst_d, ne00, nrows, pos, freq_scale, ne01, freq_base, ext_factor,
                attn_factor, corr_dims, stream
            );
        } else if (src0->type == GGML_TYPE_F16) {
            rope_cuda_f16(
                (const half *)src0_d, (half *)dst_d, ne00, nrows, pos, freq_scale, ne01, freq_base, ext_factor,
                attn_factor, corr_dims, stream
            );
        } else {
            GGML_ASSERT(false);
        }
    }
}