@@ -2,6 +2,7 @@
 #include "ggml.h"
 #include "softmax.cuh"
 #include <cstdint>
+#include <utility>
 
 template <typename T>
 static __device__ __forceinline__ float t2f32(T val) {
@@ -181,6 +182,37 @@ static __global__ void soft_max_back_f32(
     }
 }
 
+template<int... Ns, typename T>
+static void launch_soft_max_kernels(const float * x, const T * mask, float * dst,
+    const soft_max_params & p, cudaStream_t stream, dim3 block_dims, dim3 block_nums, size_t nbytes_shared)
+{
+    const int id = ggml_cuda_get_device();
+    const size_t smpbo = ggml_cuda_info().devices[id].smpbo;
+
+    auto launch_kernel = [=](auto I) -> bool {
+        constexpr int ncols = decltype(I)::value;
+        constexpr int block = (ncols > 1024 ? 1024 : ncols);
+
+        if (p.ncols == ncols) {
+            CUDA_SET_SHARED_MEMORY_LIMIT((soft_max_f32<true, ncols, block, T>), smpbo);
+            soft_max_f32<true, ncols, block><<<block_nums, block_dims, nbytes_shared, stream>>>
+                (x, mask, dst, p);
+            return true;
+        }
+        return false;
+    };
+
+    // unary fold over launch_kernel
+    if ((launch_kernel(std::integral_constant<int, Ns>{}) || ...)) {
+        return;
+    }
+
+    //default case
+    CUDA_SET_SHARED_MEMORY_LIMIT((soft_max_f32<true, 0, 0, T>), smpbo);
+    soft_max_f32<true, 0, 0><<<block_nums, block_dims, nbytes_shared, stream>>>(x, mask, dst, p);
+}
+
+
 template<typename T>
 static void soft_max_f32_cuda(const float * x, const T * mask, float * dst, const soft_max_params & params, cudaStream_t stream) {
     int nth = WARP_SIZE;
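The interesting piece in the hunk above is the unary fold over `||`: `launch_kernel` is tried once per `Ns`, short-circuiting on the first compile-time column count that matches `p.ncols`, so the nine-way `switch` removed below collapses into a single expression while still producing one kernel instantiation per size. A minimal standalone sketch of the same C++17 pattern, with hypothetical names (`dispatch`, `run_kernel`, `run_generic` are illustrative, not from the patch):

#include <cstdio>
#include <type_traits> // std::integral_constant

// Stand-in for soft_max_f32: one instantiation per compile-time size.
template <int ncols>
static void run_kernel(int n) {
    std::printf("specialized path: ncols = %d (runtime n = %d)\n", ncols, n);
}

static void run_generic(int n) {
    std::printf("generic fallback: n = %d\n", n);
}

// Try each Ns in order; the || fold short-circuits on the first match,
// mirroring the early `return true` in launch_soft_max_kernels.
template <int... Ns>
static void dispatch(int n) {
    auto try_one = [&](auto I) -> bool {
        constexpr int ncols = decltype(I)::value;
        if (n == ncols) {
            run_kernel<ncols>(n);
            return true;
        }
        return false;
    };
    if ((try_one(std::integral_constant<int, Ns>{}) || ...)) {
        return;
    }
    run_generic(n); // default case, like soft_max_f32<true, 0, 0>
}

int main() {
    dispatch<32, 64, 128, 256>(128); // matches the ncols == 128 instantiation
    dispatch<32, 64, 128, 256>(100); // no match: takes the generic fallback
    return 0;
}

Adding a new specialized size is then one more integer in the template argument list rather than another four-line `case`.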
@@ -193,46 +225,12 @@ static void soft_max_f32_cuda(const float * x, const T * mask, float * dst, cons
 
     static_assert(CUDA_SOFT_MAX_BLOCK_SIZE == 1024, "These values need to be adjusted.");
 
-    // FIXME: this limit could be raised by ~2-4x on Ampere or newer
-    if (nbytes_shared < ggml_cuda_info().devices[ggml_cuda_get_device()].smpb) {
-        switch (ncols_x) {
-            case 32:
-                soft_max_f32<true, 32, 32><<<block_nums, block_dims, nbytes_shared, stream>>>
-                    (x, mask, dst, params);
-                break;
-            case 64:
-                soft_max_f32<true, 64, 64><<<block_nums, block_dims, nbytes_shared, stream>>>
-                    (x, mask, dst, params);
-                break;
-            case 128:
-                soft_max_f32<true, 128, 128><<<block_nums, block_dims, nbytes_shared, stream>>>
-                    (x, mask, dst, params);
-                break;
-            case 256:
-                soft_max_f32<true, 256, 256><<<block_nums, block_dims, nbytes_shared, stream>>>
-                    (x, mask, dst, params);
-                break;
-            case 512:
-                soft_max_f32<true, 512, 512><<<block_nums, block_dims, nbytes_shared, stream>>>
-                    (x, mask, dst, params);
-                break;
-            case 1024:
-                soft_max_f32<true, 1024, 1024><<<block_nums, block_dims, nbytes_shared, stream>>>
-                    (x, mask, dst, params);
-                break;
-            case 2048:
-                soft_max_f32<true, 2048, 1024><<<block_nums, block_dims, nbytes_shared, stream>>>
-                    (x, mask, dst, params);
-                break;
-            case 4096:
-                soft_max_f32<true, 4096, 1024><<<block_nums, block_dims, nbytes_shared, stream>>>
-                    (x, mask, dst, params);
-                break;
-            default:
-                soft_max_f32<true, 0, 0><<<block_nums, block_dims, nbytes_shared, stream>>>
-                    (x, mask, dst, params);
-                break;
-        }
+    const int id = ggml_cuda_get_device();
+    const size_t smpbo = ggml_cuda_info().devices[id].smpbo;
+
+
+    if (nbytes_shared <= smpbo) {
+        launch_soft_max_kernels<32, 64, 128, 256, 512, 1024, 2048, 4096>(x, mask, dst, params, stream, block_dims, block_nums, nbytes_shared);
     } else {
         const size_t nbytes_shared_low = WARP_SIZE*sizeof(float);
         soft_max_f32<false, 0, 0><<<block_nums, block_dims, nbytes_shared_low, stream>>>(x, mask, dst, params);