softmax.cu 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283
  1. #include "common.cuh"
  2. #include "ggml.h"
  3. #include "softmax.cuh"
  4. #include <cstdint>
  // Converts one mask element to float. The generic version is a plain numeric conversion.
  template <typename T>
  static __device__ __forceinline__ float t2f32(T val) {
      return static_cast<float>(val);
  }
  // half elements are converted explicitly with the __half2float intrinsic.
  template <>
  __device__ float __forceinline__ t2f32<half>(half val) {
      const float as_f32 = __half2float(val);
      return as_f32;
  }
  // When ncols_template == 0 the loop bounds in this kernel are not known at compile time, so the loops can't be unrolled.
  // We want to keep #pragma unroll for all other instantiations, so we suppress the resulting clang transformation warning here.
  15. #ifdef __clang__
  16. #pragma clang diagnostic push
  17. #pragma clang diagnostic ignored "-Wpass-failed"
  18. #endif // __clang__
  // Row-wise softmax forward kernel:
  //   dst[row] = softmax(x[row]*scale + slope*mask[row % nrows_y])   (mask term is 0 when mask == nullptr)
  //
  // Launch layout (as used by soft_max_f32_cuda in this file): one block per row of x (rowx = blockIdx.x),
  // with block_size threads striding over the row in steps of block_size.
  // Template parameters:
  //   use_shared          - cache the row in dynamic shared memory right after the WARP_SIZE-float
  //                         inter-warp buffer; when false, dst itself is used as the scratch buffer.
  //   ncols_template      - compile-time row length, or 0 to use ncols_par at run time (loops then bounds-check).
  //   block_size_template - compile-time block size, or 0 to use blockDim.x.
  //   T                   - mask element type (converted with t2f32).
  // Dynamic shared memory: WARP_SIZE floats for buf_iw, plus ncols floats for vals when use_shared.
  // NOTE(review): when ncols_template != 0 the loops do no bounds check, so this assumes ncols is a multiple of
  // block_size and blockDim.x == block_size_template; the specializations launched in this file satisfy that.
  // NOTE(review): buf_iw is indexed by warp_id, so this assumes blockDim.x / WARP_SIZE <= WARP_SIZE.
  // NOTE(review): if every value of a row is -INFINITY, max_val stays -INFINITY and expf(-inf - -inf) is NaN.
  template <bool use_shared, int ncols_template, int block_size_template, typename T>
  static __global__ void soft_max_f32(
          const float * x, const T * mask, float * dst, const int ncols_par, const int nrows_y,
          const float scale, const float max_bias, const float m0, const float m1, uint32_t n_head_log2) {
      // compile-time row length when specialized, otherwise the runtime value
      const int ncols = ncols_template == 0 ? ncols_par : ncols_template;
      const int tid = threadIdx.x;
      const int rowx = blockIdx.x;
      const int rowy = rowx % nrows_y; // broadcast the mask in the row dimension
      x += int64_t(rowx)*ncols;
      // multiplying by (mask != nullptr) keeps a null mask pointer null
      mask += int64_t(rowy)*ncols * (mask != nullptr);
      dst += int64_t(rowx)*ncols;
      const int block_size = block_size_template == 0 ? blockDim.x : block_size_template;
      const int warp_id = threadIdx.x / WARP_SIZE;
      const int lane_id = threadIdx.x % WARP_SIZE;
      // ALiBi slope for head index rowx/nrows_y (get_alibi_slope is defined outside this file section)
      const float slope = get_alibi_slope(max_bias, rowx/nrows_y, n_head_log2, m0, m1);
      extern __shared__ float data_soft_max_f32[];
      float * buf_iw = data_soft_max_f32; // shared memory buffer for inter-warp communication
      // per-column values kept between the three passes: shared memory after buf_iw when use_shared,
      // otherwise dst itself (the final pass then normalizes dst in place)
      float * vals = use_shared ? buf_iw + WARP_SIZE : dst;
      // Pass 1: scaled + masked logits into vals, tracking this thread's running maximum.
      float max_val = -INFINITY;
  #pragma unroll
      for (int col0 = 0; col0 < ncols; col0 += block_size) {
          const int col = col0 + tid;
          // only the runtime-length instantiation needs a bounds check
          if (ncols_template == 0 && col >= ncols) {
              break;
          }
          const float val = x[col]*scale + (mask ? slope*t2f32(mask[col]) : 0.0f);
          vals[col] = val;
          max_val = max(max_val, val);
      }
      // find the max value in the block
      max_val = warp_reduce_max(max_val);
      if (block_size > WARP_SIZE) {
          // pre-fill with -INFINITY so slots beyond the number of warps do not affect the max
          if (warp_id == 0) {
              buf_iw[lane_id] = -INFINITY;
          }
          __syncthreads();
          // lane 0 of each warp publishes its warp's max
          if (lane_id == 0) {
              buf_iw[warp_id] = max_val;
          }
          __syncthreads();
          // every warp reduces the per-warp maxima, so all threads end with the block max
          max_val = buf_iw[lane_id];
          max_val = warp_reduce_max(max_val);
      }
      // Pass 2: exponentiate relative to the block max and accumulate this thread's partial sum.
      float tmp = 0.0f; // partial sum
  #pragma unroll
      for (int col0 = 0; col0 < ncols; col0 += block_size) {
          const int col = col0 + tid;
          if (ncols_template == 0 && col >= ncols) {
              break;
          }
          const float val = expf(vals[col] - max_val);
          tmp += val;
          vals[col] = val;
      }
      // find the sum of exps in the block
      tmp = warp_reduce_sum(tmp);
      if (block_size > WARP_SIZE) {
          // all warps must finish reading the max from buf_iw before it is overwritten below
          __syncthreads();
          // pre-fill with 0 so slots beyond the number of warps do not affect the sum
          if (warp_id == 0) {
              buf_iw[lane_id] = 0.0f;
          }
          __syncthreads();
          if (lane_id == 0) {
              buf_iw[warp_id] = tmp;
          }
          __syncthreads();
          tmp = buf_iw[lane_id];
          tmp = warp_reduce_sum(tmp);
      }
      // Pass 3: normalize and write the result row.
      const float inv_sum = 1.0f / tmp;
  #pragma unroll
      for (int col0 = 0; col0 < ncols; col0 += block_size) {
          const int col = col0 + tid;
          // nothing follows this loop, so return is equivalent to break here
          if (ncols_template == 0 && col >= ncols) {
              return;
          }
          dst[col] = vals[col] * inv_sum;
      }
  }
  99. #ifdef __clang__
  100. #pragma clang diagnostic pop
  101. #endif // __clang__
  // Softmax backward kernel.
  // Launch layout (see soft_max_back_f32_cuda): one block per row (blockIdx.x), exactly WARP_SIZE threads,
  // because the loops stride by WARP_SIZE and the dot product is reduced with warp_reduce_sum only.
  // grad: incoming gradient rows; dstf: forward softmax output rows; dst: output rows, all ncols floats long.
  // Writes dst = scale * (grad - dot(dstf, grad)) * dstf element-wise for each row.
  static __global__ void soft_max_back_f32(
          const float * grad, const float * dstf, float * dst, const int ncols, const float scale) {
      const int lane = threadIdx.x;
      const int row  = blockIdx.x;

      const int64_t row_offset = int64_t(row)*ncols;
      const float * g_row   = grad + row_offset;
      const float * y_row   = dstf + row_offset;
      float       * out_row = dst  + row_offset;

      // Per-lane partial of sum_i y_i * g_i, then combined across the warp.
      float dot = 0.0f;
      for (int i = lane; i < ncols; i += WARP_SIZE) {
          dot += y_row[i]*g_row[i];
      }
      dot = warp_reduce_sum(dot);

      for (int i = lane; i < ncols; i += WARP_SIZE) {
          out_row[i] = scale * (g_row[i] - dot) * y_row[i];
      }
  }
  // Host launcher for soft_max_f32: one block per row of x on the given stream.
  // x/dst hold nrows_x rows of ncols_x floats. mask (may be nullptr) is read with row % nrows_y inside the kernel.
  // nrows_x/nrows_y is used as the head count for the ALiBi parameters m0/m1.
  // Uses the shared-memory kernel, with compile-time specializations for common row lengths, when the
  // row fits under the device limit. Otherwise it falls back to a kernel that uses dst as scratch.
  template<typename T>
  static void soft_max_f32_cuda(const float * x, const T * mask, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const float scale, const float max_bias, cudaStream_t stream) {
      static_assert(CUDA_SOFT_MAX_BLOCK_SIZE == 1024, "These values need to be adjusted.");

      // ALiBi parameters: n_head_log2 = 1 << floor(log2(n_head)).
      const uint32_t n_head      = nrows_x/nrows_y;
      const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));
      const float    alibi_m0    = powf(2.0f, -(max_bias       ) / n_head_log2);
      const float    alibi_m1    = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);

      // Threads per block: WARP_SIZE, doubled until it covers the row or reaches CUDA_SOFT_MAX_BLOCK_SIZE.
      int n_threads = WARP_SIZE;
      while (n_threads < ncols_x && n_threads < CUDA_SOFT_MAX_BLOCK_SIZE) {
          n_threads <<= 1;
      }
      const dim3 block_dim(n_threads, 1, 1);
      const dim3 grid_dim(nrows_x, 1, 1);

      // WARP_SIZE floats for the inter-warp reduction plus the row padded to a multiple of WARP_SIZE.
      const size_t smem_bytes = (GGML_PAD(ncols_x, WARP_SIZE) + WARP_SIZE)*sizeof(float);

      // FIXME: this limit could be raised by ~2-4x on Ampere or newer
      // (smpb is presumably the device's shared-memory-per-block limit)
      if (smem_bytes >= ggml_cuda_info().devices[ggml_cuda_get_device()].smpb) {
          // The row does not fit: only the reduction buffer lives in shared memory, and dst is used as scratch.
          const size_t smem_bytes_low = WARP_SIZE*sizeof(float);
          soft_max_f32<false, 0, 0><<<grid_dim, block_dim, smem_bytes_low, stream>>>
              (x, mask, dst, ncols_x, nrows_y, scale, max_bias, alibi_m0, alibi_m1, n_head_log2);
          return;
      }

      // The row fits in shared memory: pick a compile-time specialization when one exists.
      switch (ncols_x) {
          case 32:
              soft_max_f32<true, 32, 32><<<grid_dim, block_dim, smem_bytes, stream>>>
                  (x, mask, dst, ncols_x, nrows_y, scale, max_bias, alibi_m0, alibi_m1, n_head_log2);
              break;
          case 64:
              soft_max_f32<true, 64, 64><<<grid_dim, block_dim, smem_bytes, stream>>>
                  (x, mask, dst, ncols_x, nrows_y, scale, max_bias, alibi_m0, alibi_m1, n_head_log2);
              break;
          case 128:
              soft_max_f32<true, 128, 128><<<grid_dim, block_dim, smem_bytes, stream>>>
                  (x, mask, dst, ncols_x, nrows_y, scale, max_bias, alibi_m0, alibi_m1, n_head_log2);
              break;
          case 256:
              soft_max_f32<true, 256, 256><<<grid_dim, block_dim, smem_bytes, stream>>>
                  (x, mask, dst, ncols_x, nrows_y, scale, max_bias, alibi_m0, alibi_m1, n_head_log2);
              break;
          case 512:
              soft_max_f32<true, 512, 512><<<grid_dim, block_dim, smem_bytes, stream>>>
                  (x, mask, dst, ncols_x, nrows_y, scale, max_bias, alibi_m0, alibi_m1, n_head_log2);
              break;
          case 1024:
              soft_max_f32<true, 1024, 1024><<<grid_dim, block_dim, smem_bytes, stream>>>
                  (x, mask, dst, ncols_x, nrows_y, scale, max_bias, alibi_m0, alibi_m1, n_head_log2);
              break;
          case 2048:
              soft_max_f32<true, 2048, 1024><<<grid_dim, block_dim, smem_bytes, stream>>>
                  (x, mask, dst, ncols_x, nrows_y, scale, max_bias, alibi_m0, alibi_m1, n_head_log2);
              break;
          case 4096:
              soft_max_f32<true, 4096, 1024><<<grid_dim, block_dim, smem_bytes, stream>>>
                  (x, mask, dst, ncols_x, nrows_y, scale, max_bias, alibi_m0, alibi_m1, n_head_log2);
              break;
          default:
              soft_max_f32<true, 0, 0><<<grid_dim, block_dim, smem_bytes, stream>>>
                  (x, mask, dst, ncols_x, nrows_y, scale, max_bias, alibi_m0, alibi_m1, n_head_log2);
              break;
      }
  }
  // Launches soft_max_back_f32 on the given stream. It uses one block per row, each block a single warp
  // (WARP_SIZE threads), and no dynamic shared memory.
  static void soft_max_back_f32_cuda(
          const float * grad, const float * dstf, float * dst,
          const int ncols, const int nrows, const float scale, cudaStream_t stream) {
      soft_max_back_f32<<<dim3(nrows, 1, 1), dim3(WARP_SIZE, 1, 1), 0, stream>>>(grad, dstf, dst, ncols, scale);
  }
  // Computes a row-wise softmax of dst->src[0] into dst on the context's stream.
  // dst->src[1] is an optional mask (F16 or F32). op_params[0] holds scale and op_params[1] holds max_bias.
  void ggml_cuda_op_soft_max(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
      const ggml_tensor * src0 = dst->src[0];
      const ggml_tensor * src1 = dst->src[1]; // optional mask

      const float * x_d    = (const float *) src0->data;
      const void  * mask_d = src1 == nullptr ? nullptr : (const void *) src1->data;
      float       * y_d    = (float *) dst->data;

      cudaStream_t stream = ctx.stream();

      GGML_ASSERT(src0->type == GGML_TYPE_F32);
      GGML_ASSERT( dst->type == GGML_TYPE_F32);
      GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F16 || src1->type == GGML_TYPE_F32); // src1 contains mask and it is optional

      const int64_t ne00    = src0->ne[0];
      const int64_t nrows_x = ggml_nrows(src0);
      const int64_t nrows_y = src0->ne[1];

      // The first two op params are stored as raw float bits: { scale, max_bias }.
      float op_params_f32[2] = { 1.0f, 0.0f };
      memcpy(op_params_f32, dst->op_params, sizeof(op_params_f32));
      const float scale    = op_params_f32[0];
      const float max_bias = op_params_f32[1];

      // The mask element type selects the kernel instantiation.
      if (src1 && src1->type == GGML_TYPE_F16) {
          soft_max_f32_cuda(x_d, (const half *) mask_d, y_d, ne00, nrows_x, nrows_y, scale, max_bias, stream);
          return;
      }
      soft_max_f32_cuda(x_d, (const float *) mask_d, y_d, ne00, nrows_x, nrows_y, scale, max_bias, stream);
  }
  // Softmax backward op: dst->src[0] is the incoming gradient and dst->src[1] the forward softmax output.
  // All three tensors must be F32. op_params[0] is scale. op_params[1] (max_bias) must be 0 because
  // soft_max_back_f32 takes no ALiBi/mask inputs.
  void ggml_cuda_op_soft_max_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
      const ggml_tensor * src0 = dst->src[0]; // grad
      const ggml_tensor * src1 = dst->src[1]; // forward pass output

      const float * grad_d = (const float *) src0->data;
      const float * fwd_d  = (const float *) src1->data;
      float       * out_d  = (float       *) dst->data;

      cudaStream_t stream = ctx.stream();

      GGML_ASSERT(src0->type == GGML_TYPE_F32);
      GGML_ASSERT(src1->type == GGML_TYPE_F32);
      GGML_ASSERT( dst->type == GGML_TYPE_F32);

      const int64_t ncols = src0->ne[0];
      const int64_t nrows = ggml_nrows(src0);

      // The first two op params are stored as raw float bits: { scale, max_bias }.
      float op_params_f32[2] = { 1.0f, 0.0f };
      memcpy(op_params_f32, dst->op_params, sizeof(op_params_f32));
      const float scale    = op_params_f32[0];
      const float max_bias = op_params_f32[1];

      GGML_ASSERT(max_bias == 0.0f);

      soft_max_back_f32_cuda(grad_d, fwd_d, out_d, ncols, nrows, scale, stream);
  }