@@ -2,6 +2,7 @@
 #include "ggml.h"
 #include "topk-moe.cuh"
 
+#include <cmath>
 #include <initializer_list>
 
 // Warp-local softmax used for both the pre-top-k logits and the post-top-k delayed path.
@@ -63,7 +64,8 @@ __launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float *
                                                                    float * weights,
                                                                    int32_t * ids,
                                                                    const int n_rows,
-                                                                   const int n_expert_used) {
+                                                                   const int n_expert_used,
+                                                                   const float clamp_val) {
     const int row = blockIdx.x * blockDim.y + threadIdx.y;
     if (row >= n_rows) {
         return;
@@ -139,6 +141,7 @@ __launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float *
 
     if constexpr (with_norm) {
         wt_sum = warp_reduce_sum(wt_sum);
+        wt_sum = max(wt_sum, clamp_val);
         const float inv_sum = 1.0f / wt_sum;
 
         for (int i = 0; i < experts_per_thread; i++) {
@@ -157,6 +160,10 @@ __launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float *
             weights[idx] = output_weights[i];
         }
     }
+
+    if (!with_norm) {
+        GGML_UNUSED(clamp_val);
+    }
 }
 
 template <bool with_norm, bool delayed_softmax = false>
@@ -166,9 +173,9 @@ static void launch_topk_moe_cuda(ggml_backend_cuda_context & ctx,
                                   int32_t * ids,
                                   const int n_rows,
                                   const int n_expert,
-                                  const int n_expert_used) {
+                                  const int n_expert_used,
+                                  const float clamp_val) {
     static_assert(!(with_norm && delayed_softmax), "delayed softmax is not supported with weight normalization");
-
     const int rows_per_block = 4;
     dim3 grid_dims((n_rows + rows_per_block - 1) / rows_per_block, 1, 1);
     dim3 block_dims(WARP_SIZE, rows_per_block, 1);
@@ -177,43 +184,43 @@ static void launch_topk_moe_cuda(ggml_backend_cuda_context & ctx,
     switch (n_expert) {
         case 1:
             topk_moe_cuda<1, with_norm, delayed_softmax>
-                <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
+                <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used, clamp_val);
             break;
         case 2:
             topk_moe_cuda<2, with_norm, delayed_softmax>
-                <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
+                <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used, clamp_val);
             break;
         case 4:
             topk_moe_cuda<4, with_norm, delayed_softmax>
-                <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
+                <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used, clamp_val);
             break;
         case 8:
             topk_moe_cuda<8, with_norm, delayed_softmax>
-                <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
+                <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used, clamp_val);
            break;
         case 16:
             topk_moe_cuda<16, with_norm, delayed_softmax>
-                <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
+                <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used, clamp_val);
             break;
         case 32:
             topk_moe_cuda<32, with_norm, delayed_softmax>
-                <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
+                <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used, clamp_val);
             break;
         case 64:
             topk_moe_cuda<64, with_norm, delayed_softmax>
-                <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
+                <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used, clamp_val);
             break;
         case 128:
             topk_moe_cuda<128, with_norm, delayed_softmax>
-                <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
+                <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used, clamp_val);
             break;
         case 256:
             topk_moe_cuda<256, with_norm, delayed_softmax>
-                <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
+                <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used, clamp_val);
             break;
         case 512:
             topk_moe_cuda<512, with_norm, delayed_softmax>
-                <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
+                <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used, clamp_val);
             break;
         default:
             GGML_ASSERT(false && "fatal error");
@@ -226,7 +233,8 @@ void ggml_cuda_op_topk_moe(ggml_backend_cuda_context & ctx,
                            ggml_tensor * weights,
                            ggml_tensor * ids,
                            const bool with_norm,
-                           const bool delayed_softmax) {
+                           const bool delayed_softmax,
+                           ggml_tensor * clamp) {
     GGML_ASSERT(logits->type == GGML_TYPE_F32);
     GGML_ASSERT(weights->type == GGML_TYPE_F32);
     GGML_ASSERT(ids->type == GGML_TYPE_I32);
@@ -242,18 +250,25 @@ void ggml_cuda_op_topk_moe(ggml_backend_cuda_context & ctx,
 
     const int n_expert_used = weights->ne[1];
 
+    float clamp_val = -INFINITY;
     if (with_norm) {
-        launch_topk_moe_cuda<true>(ctx, logits_d, weights_d, ids_d, n_rows, n_experts, n_expert_used);
+        if (clamp) {
+            clamp_val = ggml_get_op_params_f32(clamp, 0);
+        }
+        launch_topk_moe_cuda<true>(ctx, logits_d, weights_d, ids_d, n_rows, n_experts, n_expert_used, clamp_val);
     } else {
+        GGML_ASSERT(clamp == nullptr);
         if (delayed_softmax) {
-            launch_topk_moe_cuda<false, true>(ctx, logits_d, weights_d, ids_d, n_rows, n_experts, n_expert_used);
+            launch_topk_moe_cuda<false, true>(ctx, logits_d, weights_d, ids_d, n_rows, n_experts, n_expert_used,
+                                              clamp_val);
         } else {
-            launch_topk_moe_cuda<false, false>(ctx, logits_d, weights_d, ids_d, n_rows, n_experts, n_expert_used);
+            launch_topk_moe_cuda<false, false>(ctx, logits_d, weights_d, ids_d, n_rows, n_experts, n_expert_used,
+                                               clamp_val);
         }
     }
 }
 
-bool ggml_cuda_should_use_topk_moe(const ggml_tensor * softmax, const ggml_tensor * weights) {
+bool ggml_cuda_should_use_topk_moe(const ggml_tensor * softmax, const ggml_tensor * weights, const ggml_tensor * clamp) {
     float scale = 1.0f;
     float max_bias = 0.0f;
 
@@ -279,13 +294,26 @@ bool ggml_cuda_should_use_topk_moe(const ggml_tensor * softmax, const ggml_tenso
         return false;
     }
 
+    if (clamp) {
+        if (clamp->op != GGML_OP_CLAMP) {
+            return false;
+        }
+        float max_val = ggml_get_op_params_f32(clamp, 1);
+
+        if (max_val != INFINITY) {
+            return false;
+        }
+    }
+
+
     return true;
 }
 
 std::initializer_list<enum ggml_op> ggml_cuda_topk_moe_ops(bool norm, bool delayed_softmax) {
     static std::initializer_list<enum ggml_op> norm_ops = { GGML_OP_SOFT_MAX, GGML_OP_RESHAPE, GGML_OP_ARGSORT,
                                                             GGML_OP_VIEW, GGML_OP_GET_ROWS, GGML_OP_RESHAPE,
-                                                            GGML_OP_SUM_ROWS, GGML_OP_DIV, GGML_OP_RESHAPE };
+                                                            GGML_OP_SUM_ROWS, GGML_OP_CLAMP, GGML_OP_DIV,
+                                                            GGML_OP_RESHAPE };
 
     static std::initializer_list<enum ggml_op> no_norm_ops = { GGML_OP_SOFT_MAX, GGML_OP_RESHAPE, GGML_OP_ARGSORT,
                                                                 GGML_OP_VIEW, GGML_OP_GET_ROWS };