softmax.cu 7.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201
  1. #include "softmax.cuh"
  // soft_max_f32: numerically stable softmax over each row of x, written to dst.
  //
  //   v[c]      = x[r][c]*scale + mask[r % nrows_y][c] + slope(r)*pos[c]
  //   dst[r][c] = expf(v[c] - max(v)) / sum_c expf(v[c] - max(v))
  //
  // The mask / pos terms are skipped when the respective pointer is null. The mask row
  // stride is ncols. slope(r) is non-zero only when max_bias > 0 (ALiBi); in that case
  // h = r / nrows_y is used as the head index to pick m0 or m1 and the exponent.
  //
  // Launch: one block per row (blockIdx.x = row of x/dst), blockDim.x threads striding
  //   over the columns. blockDim.x is assumed to be a multiple of WARP_SIZE and at most
  //   WARP_SIZE*WARP_SIZE, because buf_iw stores one partial result per warp in WARP_SIZE slots.
  // Dynamic shared memory: WARP_SIZE floats of inter-warp scratch, plus room for ncols
  //   floats when vals_smem is true. When vals_smem is false the dst row is used as the cache.
  // Template parameters: a non-zero ncols_template / block_size_template fixes ncols /
  //   the block size at compile time (0 = use ncols_par / blockDim.x at run time).
  //   A non-zero ncols_template must be a multiple of the block size, because the column
  //   bounds check is compiled out in that case.
  template <bool vals_smem, int ncols_template, int block_size_template>
  static __global__ void soft_max_f32(const float * x, const float * mask, const float * pos, float * dst, const int ncols_par, const int nrows_y, const float scale, const float max_bias, const float m0, const float m1, uint32_t n_head_log2) {
      const int ncols = ncols_template == 0 ? ncols_par : ncols_template;

      const int tid  = threadIdx.x;
      const int rowx = blockIdx.x;
      const int rowy = rowx % nrows_y; // broadcast the mask in the row dimension

      // Row offsets in 64 bits: rowx*ncols overflows int once the tensor holds more than
      // INT_MAX elements, which would turn into out-of-bounds global memory accesses.
      const int64_t row_off_x = (int64_t) rowx*ncols;
      const int64_t row_off_y = (int64_t) rowy*ncols;

      const int block_size = block_size_template == 0 ? blockDim.x : block_size_template;

      const int warp_id = threadIdx.x / WARP_SIZE;
      const int lane_id = threadIdx.x % WARP_SIZE;

      float slope = 0.0f;

      // ALiBi
      if (max_bias > 0.0f) {
          const int  h           = rowx/nrows_y; // head index
          const bool first_group = (uint32_t) h < n_head_log2;

          const float base     = first_group ? m0 : m1;
          const int   exponent = first_group ? h + 1 : 2*(h - (int) n_head_log2) + 1;

          slope = powf(base, exponent);
      }

      extern __shared__ float data_soft_max_f32[];
      float * buf_iw = data_soft_max_f32; // shared memory buffer for inter-warp communication
      // buffer to cache values between iterations: shared memory, or the dst row itself
      float * vals = vals_smem ? buf_iw + WARP_SIZE : dst + row_off_x;

      // pass 1: compute the biased logits, cache them and track the running max
      float max_val = -INFINITY;

  #pragma unroll
      for (int col0 = 0; col0 < ncols; col0 += block_size) {
          const int col = col0 + tid;

          if (ncols_template == 0 && col >= ncols) {
              break;
          }

          const int64_t ix = row_off_x + col;
          const int64_t iy = row_off_y + col;

          const float val = x[ix]*scale + (mask ? mask[iy] : 0.0f) + (pos ? slope*pos[col] : 0.0f);

          vals[col] = val;
          max_val = max(max_val, val);
      }

      // find the max value in the block
      // (block_size is uniform across the block, so the barriers below are never divergent)
      max_val = warp_reduce_max(max_val);
      if (block_size > WARP_SIZE) {
          if (warp_id == 0) {
              buf_iw[lane_id] = -INFINITY; // slots of warps that do not exist stay neutral
          }
          __syncthreads();

          if (lane_id == 0) {
              buf_iw[warp_id] = max_val;
          }
          __syncthreads();

          max_val = buf_iw[lane_id];
          max_val = warp_reduce_max(max_val);
      }

      // pass 2: exponentiate relative to the max and accumulate the partial sum
      float tmp = 0.0f; // partial sum

  #pragma unroll
      for (int col0 = 0; col0 < ncols; col0 += block_size) {
          const int col = col0 + tid;

          if (ncols_template == 0 && col >= ncols) {
              break;
          }

          const float val = expf(vals[col] - max_val);
          tmp += val;
          vals[col] = val;
      }

      // find the sum of exps in the block
      tmp = warp_reduce_sum(tmp);
      if (block_size > WARP_SIZE) {
          __syncthreads(); // all threads must have read the max from buf_iw before it is reused
          if (warp_id == 0) {
              buf_iw[lane_id] = 0.0f;
          }
          __syncthreads();

          if (lane_id == 0) {
              buf_iw[warp_id] = tmp;
          }
          __syncthreads();

          tmp = buf_iw[lane_id];
          tmp = warp_reduce_sum(tmp);
      }

      const float inv_sum = 1.0f / tmp;

      // pass 3: normalize and write the result row
  #pragma unroll
      for (int col0 = 0; col0 < ncols; col0 += block_size) {
          const int col = col0 + tid;

          if (ncols_template == 0 && col >= ncols) {
              return;
          }

          dst[row_off_x + col] = vals[col] * inv_sum;
      }
  }
  // soft_max_f32_cuda: enqueue soft_max_f32 on `stream` for nrows_x rows of ncols_x floats.
  //   mask (optional) is broadcast over rows with period nrows_y; pos (optional) holds one
  //   float per column. x is multiplied by scale before the mask / ALiBi terms are added.
  //   The ALiBi slope bases m0 / m1 are derived from max_bias, with nrows_x/nrows_y used
  //   as the number of heads.
  // Block size: WARP_SIZE doubled until it reaches ncols_x or CUDA_SOFT_MAX_BLOCK_SIZE.
  // The row cache lives in dynamic shared memory when it fits within the device's
  //   per-block limit (with compile-time specialisations for common power-of-two widths);
  //   otherwise the kernel uses dst itself as the cache.
  static void soft_max_f32_cuda(const float * x, const float * mask, const float * pos, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const float scale, const float max_bias, cudaStream_t stream) {
      // Nothing to compute for an empty tensor. Returning early also avoids a zero-block
      // grid (an invalid launch configuration) and the integer division by nrows_y below,
      // which is zero whenever nrows_x is.
      if (nrows_x == 0 || ncols_x == 0) {
          return;
      }

      int nth = WARP_SIZE;
      while (nth < ncols_x && nth < CUDA_SOFT_MAX_BLOCK_SIZE) nth *= 2;
      const dim3 block_dims(nth,     1, 1);
      const dim3 block_nums(nrows_x, 1, 1);
      // cached row (padded to a multiple of WARP_SIZE) + WARP_SIZE floats of inter-warp scratch
      const size_t shmem = (GGML_PAD(ncols_x, WARP_SIZE) + WARP_SIZE)*sizeof(float);
      static_assert(CUDA_SOFT_MAX_BLOCK_SIZE == 1024, "These values need to be adjusted.");

      const uint32_t n_head_kv   = nrows_x/nrows_y;
      const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head_kv));

      const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2);
      const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);

      // A request exactly equal to the per-block shared memory limit is still launchable.
      if (shmem <= ggml_cuda_info().devices[ggml_cuda_get_device()].smpb) {
          switch (ncols_x) {
              case 32:
                  soft_max_f32<true, 32, 32><<<block_nums, block_dims, shmem, stream>>>(x, mask, pos, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
                  break;
              case 64:
                  soft_max_f32<true, 64, 64><<<block_nums, block_dims, shmem, stream>>>(x, mask, pos, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
                  break;
              case 128:
                  soft_max_f32<true, 128, 128><<<block_nums, block_dims, shmem, stream>>>(x, mask, pos, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
                  break;
              case 256:
                  soft_max_f32<true, 256, 256><<<block_nums, block_dims, shmem, stream>>>(x, mask, pos, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
                  break;
              case 512:
                  soft_max_f32<true, 512, 512><<<block_nums, block_dims, shmem, stream>>>(x, mask, pos, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
                  break;
              case 1024:
                  soft_max_f32<true, 1024, 1024><<<block_nums, block_dims, shmem, stream>>>(x, mask, pos, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
                  break;
              case 2048:
                  soft_max_f32<true, 2048, 1024><<<block_nums, block_dims, shmem, stream>>>(x, mask, pos, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
                  break;
              case 4096:
                  soft_max_f32<true, 4096, 1024><<<block_nums, block_dims, shmem, stream>>>(x, mask, pos, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
                  break;
              default:
                  soft_max_f32<true, 0, 0><<<block_nums, block_dims, shmem, stream>>>(x, mask, pos, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
                  break;
          }
      } else {
          // row does not fit in shared memory: only the inter-warp scratch is allocated
          const size_t shmem_low = WARP_SIZE*sizeof(float);
          soft_max_f32<false, 0, 0><<<block_nums, block_dims, shmem_low, stream>>>(x, mask, pos, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
      }
  }
  // ggml_cuda_op_soft_max: CUDA implementation of the softmax op for `dst`.
  //   dst->src[0]: input (F32); the softmax runs over ne[0] for every row.
  //   dst->src[1]: optional mask (F32). The kernel reads it as mask[(row % ne01)*ne00 + col],
  //                so it presumably must be contiguous with ne00 columns - verify at graph build.
  //   dst->src[2]: optional per-column positions (F32), scaled by the ALiBi slope.
  //   dst->op_params[0]: scale applied to the input; op_params[1]: max_bias (ALiBi when > 0).
  // The work is enqueued on ctx.stream(); this function does not synchronize.
  void ggml_cuda_op_soft_max(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
      const ggml_tensor * src0 = dst->src[0];
      const ggml_tensor * src1 = dst->src[1]; // mask, optional
      const ggml_tensor * src2 = dst->src[2]; // positions tensor, optional

      GGML_ASSERT(src0->type == GGML_TYPE_F32);
      GGML_ASSERT( dst->type == GGML_TYPE_F32);

      GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F32); // src1 contains mask and it is optional
      // The kernel reads the positions as float; fail loudly instead of reinterpreting
      // another element type as F32.
      GGML_ASSERT(!src2 || src2->type == GGML_TYPE_F32);

      const float * src0_d = (const float *) src0->data;
      const float * src1_d = src1 ? (const float *) src1->data : nullptr;
      const float * src2_d = src2 ? (const float *) src2->data : nullptr;
      float       * dst_d  = (float *) dst->data;

      cudaStream_t stream = ctx.stream();

      const int64_t ne00    = src0->ne[0];
      const int64_t nrows_x = ggml_nrows(src0);
      const int64_t nrows_y = src0->ne[1];

      float scale    = 1.0f;
      float max_bias = 0.0f;

      memcpy(&scale,    (float *) dst->op_params + 0, sizeof(float));
      memcpy(&max_bias, (float *) dst->op_params + 1, sizeof(float));

      soft_max_f32_cuda(src0_d, src1_d, src2_d, dst_d, ne00, nrows_x, nrows_y, scale, max_bias, stream);
  }