softmax.cu 8.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227
  1. #include "softmax.cuh"
  // Widens a mask / position element to float for the softmax arithmetic.
  // Generic case: an ordinary numeric conversion (e.g. identity for float).
  template <typename T>
  static __device__ __forceinline__ float t2f32(T val) {
      return static_cast<float>(val);
  }
  // half elements are converted with the CUDA fp16 intrinsic rather than a plain cast.
  template <>
  __device__ __forceinline__ float t2f32<half>(half val) {
      const float widened = __half2float(val);
      return widened;
  }
  // Row-wise softmax kernel.
  //   For row r and column c:  v[c] = x[r][c]*scale + mask[r % nrows_y][c] + slope*pos[c]
  //                            dst[r][c] = exp(v[c] - max(v)) / sum_c exp(v[c] - max(v))
  //   mask and pos are optional (nullptr); slope is the ALiBi slope of head r / nrows_y and is
  //   only non-zero when max_bias > 0.
  // Launch layout: one block per row (row = blockIdx.x). When block_size_template != 0 it is
  // expected to equal blockDim.x, and ncols_template != 0 is expected to be a multiple of it
  // (those loops have no per-column bounds check). A value of 0 for either template parameter
  // selects the runtime value (ncols_par / blockDim.x).
  // Dynamic shared memory: WARP_SIZE floats of cross-warp reduction scratch, followed by the row
  // cache when vals_smem is true. When vals_smem is false, dst itself is used as the row cache.
  template <bool vals_smem, int ncols_template, int block_size_template, typename T>
  static __global__ void soft_max_f32(const float * x, const T * mask, const T * pos, float * dst, const int ncols_par, const int nrows_y, const float scale, const float max_bias, const float m0, const float m1, uint32_t n_head_log2) {
      const int ncols      = ncols_template == 0 ? ncols_par : ncols_template;
      const int block_size = block_size_template == 0 ? blockDim.x : block_size_template;

      const int tid     = threadIdx.x;
      const int warp_id = tid / WARP_SIZE;
      const int lane_id = tid % WARP_SIZE;

      const int row      = blockIdx.x;     // row of x / dst owned by this block
      const int mask_row = row % nrows_y;  // mask rows are broadcast across the row dimension

      const int64_t row_offset  = (int64_t) row*ncols;
      const int64_t mask_offset = (int64_t) mask_row*ncols;

      // ALiBi slope for this row's head; remains 0 when ALiBi is disabled.
      float slope = 0.0f;
      if (max_bias > 0.0f) {
          const int head = row/nrows_y;
          float alibi_base;
          int   alibi_exp;
          if (head < n_head_log2) {
              alibi_base = m0;
              alibi_exp  = head + 1;
          } else {
              alibi_base = m1;
              alibi_exp  = 2*(head - n_head_log2) + 1;
          }
          slope = powf(alibi_base, alibi_exp);
      }

      extern __shared__ float data_soft_max_f32[];
      float * warp_buf  = data_soft_max_f32;  // WARP_SIZE floats for inter-warp reductions
      float * row_cache = vals_smem ? warp_buf + WARP_SIZE : dst + row_offset;

      // Pass 1: build the biased logits, cache them, and track this thread's maximum.
      float row_max = -INFINITY;
      #pragma unroll
      for (int col0 = 0; col0 < ncols; col0 += block_size) {
          const int col = col0 + tid;
          if (ncols_template == 0 && col >= ncols) {
              break;
          }
          const float mask_term = mask ? t2f32(mask[mask_offset + col]) : 0.0f;
          const float pos_term  = pos  ? slope*t2f32(pos[col])          : 0.0f;
          const float logit     = x[row_offset + col]*scale + mask_term + pos_term;
          row_cache[col] = logit;
          row_max = max(row_max, logit);
      }

      // Block-wide maximum: reduce inside each warp, then across warps through warp_buf.
      row_max = warp_reduce_max(row_max);
      if (block_size > WARP_SIZE) {
          if (warp_id == 0) {
              warp_buf[lane_id] = -INFINITY;  // slots of non-existent warps must not win
          }
          __syncthreads();
          if (lane_id == 0) {
              warp_buf[warp_id] = row_max;
          }
          __syncthreads();
          row_max = warp_reduce_max(warp_buf[lane_id]);
      }

      // Pass 2: exponentiate relative to the maximum and accumulate this thread's partial sum.
      float row_sum = 0.0f;
      #pragma unroll
      for (int col0 = 0; col0 < ncols; col0 += block_size) {
          const int col = col0 + tid;
          if (ncols_template == 0 && col >= ncols) {
              break;
          }
          const float e = expf(row_cache[col] - row_max);
          row_sum += e;
          row_cache[col] = e;
      }

      // Block-wide sum, reusing warp_buf.
      row_sum = warp_reduce_sum(row_sum);
      if (block_size > WARP_SIZE) {
          __syncthreads();  // every warp must finish reading warp_buf from the max reduction
          if (warp_id == 0) {
              warp_buf[lane_id] = 0.0f;
          }
          __syncthreads();
          if (lane_id == 0) {
              warp_buf[warp_id] = row_sum;
          }
          __syncthreads();
          row_sum = warp_reduce_sum(warp_buf[lane_id]);
      }

      // Pass 3: normalize and write out.
      const float inv_sum = 1.0f / row_sum;
      #pragma unroll
      for (int col0 = 0; col0 < ncols; col0 += block_size) {
          const int col = col0 + tid;
          if (ncols_template == 0 && col >= ncols) {
              break;
          }
          dst[row_offset + col] = row_cache[col] * inv_sum;
      }
  }
  // Host launcher for soft_max_f32 on `stream`: dst = softmax(x*scale + mask + slope*pos), row by row,
  // with one CUDA block per row of x.
  //   x, dst   : device pointers to nrows_x rows of ncols_x floats (rows assumed contiguous)
  //   mask     : optional (nullptr) device pointer to nrows_y rows of ncols_x elements; row r of x
  //              uses mask row r % nrows_y
  //   pos      : optional (nullptr) device pointer to ncols_x elements, scaled by the ALiBi slope
  //   max_bias : ALiBi maximum bias; the slope is only applied when max_bias > 0
  // Empty input (no rows, no columns, or nrows_y <= 0) launches nothing.
  // If the padded row plus reduction scratch fits in shared memory, the row is cached there
  // (with kernels specialised for common power-of-two widths); otherwise dst is the row cache.
  template<typename T>
  static void soft_max_f32_cuda(const float * x, const T * mask, const T * pos, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const float scale, const float max_bias, cudaStream_t stream) {
      // Nothing to do. This also prevents the integer division nrows_x/nrows_y by zero below and
      // an invalid zero-sized grid launch.
      if (ncols_x <= 0 || nrows_x <= 0 || nrows_y <= 0) {
          return;
      }

      // Smallest power-of-two multiple of WARP_SIZE covering the row, capped at the block size limit.
      int nth = WARP_SIZE;
      while (nth < ncols_x && nth < CUDA_SOFT_MAX_BLOCK_SIZE) {
          nth *= 2;
      }
      const dim3 block_dims(nth, 1, 1);
      const dim3 block_nums(nrows_x, 1, 1);

      // WARP_SIZE floats of reduction scratch plus the row padded to a multiple of WARP_SIZE.
      const size_t shmem = (GGML_PAD(ncols_x, WARP_SIZE) + WARP_SIZE)*sizeof(float);
      static_assert(CUDA_SOFT_MAX_BLOCK_SIZE == 1024, "These values need to be adjusted.");

      // ALiBi parameters. nrows_x/nrows_y is treated as the head count; n_head_log2 is the
      // largest power of two not exceeding it.
      const uint32_t n_head_kv   = nrows_x/nrows_y;
      const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head_kv));

      const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2);
      const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);

      // Requesting exactly the per-block limit is a valid launch, so the fit test is inclusive.
      if (shmem <= ggml_cuda_info().devices[ggml_cuda_get_device()].smpb) {
          switch (ncols_x) {
              case 32:
                  soft_max_f32<true, 32, 32><<<block_nums, block_dims, shmem, stream>>>(x, mask, pos, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
                  break;
              case 64:
                  soft_max_f32<true, 64, 64><<<block_nums, block_dims, shmem, stream>>>(x, mask, pos, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
                  break;
              case 128:
                  soft_max_f32<true, 128, 128><<<block_nums, block_dims, shmem, stream>>>(x, mask, pos, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
                  break;
              case 256:
                  soft_max_f32<true, 256, 256><<<block_nums, block_dims, shmem, stream>>>(x, mask, pos, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
                  break;
              case 512:
                  soft_max_f32<true, 512, 512><<<block_nums, block_dims, shmem, stream>>>(x, mask, pos, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
                  break;
              case 1024:
                  soft_max_f32<true, 1024, 1024><<<block_nums, block_dims, shmem, stream>>>(x, mask, pos, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
                  break;
              case 2048:
                  soft_max_f32<true, 2048, 1024><<<block_nums, block_dims, shmem, stream>>>(x, mask, pos, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
                  break;
              case 4096:
                  soft_max_f32<true, 4096, 1024><<<block_nums, block_dims, shmem, stream>>>(x, mask, pos, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
                  break;
              default:
                  soft_max_f32<true, 0, 0><<<block_nums, block_dims, shmem, stream>>>(x, mask, pos, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
                  break;
          }
      } else {
          // Row does not fit: only the reduction scratch lives in shared memory, dst caches the row.
          const size_t shmem_low = WARP_SIZE*sizeof(float);
          soft_max_f32<false, 0, 0><<<block_nums, block_dims, shmem_low, stream>>>(x, mask, pos, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
      }
  }
  // CUDA backend implementation of the GGML soft_max op.
  //   dst->src[0] : F32 input
  //   dst->src[1] : optional mask (F16 or F32)
  //   dst->src[2] : optional ALiBi positions (F16 or F32)
  //   dst->op_params[0], [1] : scale and max_bias (floats)
  // The mask and positions are passed to the kernel through one element type, so when both are
  // present they must have the same type.
  void ggml_cuda_op_soft_max(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
      const ggml_tensor * src0 = dst->src[0];
      const ggml_tensor * src1 = dst->src[1]; // mask, optional
      const ggml_tensor * src2 = dst->src[2]; // positions, optional

      const float * src0_d = (const float *) src0->data;
      const void  * src1_d = src1 ? (const void *) src1->data : nullptr;
      const void  * src2_d = src2 ? (const void *) src2->data : nullptr;
      float       * dst_d  = (float *) dst->data;

      cudaStream_t stream = ctx.stream();

      GGML_ASSERT(src0->type == GGML_TYPE_F32);
      GGML_ASSERT( dst->type == GGML_TYPE_F32);

      GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F16 || src1->type == GGML_TYPE_F32); // src1 contains mask and it is optional
      GGML_ASSERT(!src2 || src2->type == GGML_TYPE_F16 || src2->type == GGML_TYPE_F32); // src2 contains positions and it is optional
      // Both pointers are cast to the same element type below; a mixed F16/F32 pair would make the
      // kernel reinterpret one tensor's bytes as the wrong type.
      GGML_ASSERT(!src1 || !src2 || src1->type == src2->type);

      const int64_t ne00    = src0->ne[0];
      const int64_t nrows_x = ggml_nrows(src0);
      const int64_t nrows_y = src0->ne[1];

      float scale    = 1.0f;
      float max_bias = 0.0f;

      memcpy(&scale,    (float *) dst->op_params + 0, sizeof(float));
      memcpy(&max_bias, (float *) dst->op_params + 1, sizeof(float));

      const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16) || (src2 && src2->type == GGML_TYPE_F16);

      if (use_f16) {
          const half * src1_dd = (const half *) src1_d;
          const half * src2_dd = (const half *) src2_d;

          soft_max_f32_cuda(src0_d, src1_dd, src2_dd, dst_d, ne00, nrows_x, nrows_y, scale, max_bias, stream);
      } else {
          const float * src1_dd = (const float *) src1_d;
          const float * src2_dd = (const float *) src2_d;

          soft_max_f32_cuda(src0_d, src1_dd, src2_dd, dst_d, ne00, nrows_x, nrows_y, scale, max_bias, stream);
      }
  }