@@ -4,16 +4,61 @@

 #include <initializer_list>

+// Warp-local softmax used for both the pre-top-k logits and the post-top-k delayed path.
+// Each lane owns experts_per_thread values strided by WARP_SIZE across the warp; when
+// use_limit is set, entries with index >= limit are skipped for the max/sum and zeroed.
+template <int experts_per_thread, bool use_limit>
+__device__ void softmax_warp_inplace(float (&vals)[experts_per_thread], const int limit, const int lane) {
+    float max_val = -INFINITY;
+
+#pragma unroll
+    for (int i = 0; i < experts_per_thread; i++) {
+        const int idx = lane + i * WARP_SIZE;
+        const bool active = !use_limit || (idx < limit);
+        if (active) {
+            max_val = max(max_val, vals[i]);
+        }
+    }
+
+    max_val = warp_reduce_max(max_val);
+
+    float sum = 0.f;
+
+#pragma unroll
+    for (int i = 0; i < experts_per_thread; i++) {
+        const int idx = lane + i * WARP_SIZE;
+        const bool active = !use_limit || (idx < limit);
+        if (active) {
+            const float val = expf(vals[i] - max_val);
+            vals[i] = val;
+            sum += val;
+        } else {
+            vals[i] = 0.f;
+        }
+    }
+
+    sum = warp_reduce_sum(sum);
+
+    const float inv_sum = 1.0f / sum;
+
+#pragma unroll
+    for (int i = 0; i < experts_per_thread; i++) {
+        const int idx = lane + i * WARP_SIZE;
+        const bool active = !use_limit || (idx < limit);
+        if (active) {
+            vals[i] *= inv_sum;
+        }
+    }
+}
+
/*
    This kernel does the following:
-    1. softmax over the logits per token [n_experts, n_tokens]
+    1. optionally softmax over the logits per token [n_experts, n_tokens]
    2. argmax reduce over the top-k (n_experts_used) logits
    3. write weights + ids to global memory
-    4. optionally normalize the weights
+    4. optionally normalize the weights or apply softmax over the selected logits

    It is intended as fusion of softmax->top-k->get_rows pipeline for MoE models
*/
-template <int n_experts, bool with_norm>
+template <int n_experts, bool with_norm, bool delayed_softmax = false>
__launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float * logits,
                                                                  float * weights,
                                                                  int32_t * ids,
@@ -30,51 +75,31 @@ __launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float *

    constexpr int experts_per_thread = (n_experts > WARP_SIZE) ? n_experts / WARP_SIZE : 1;

-    float logits_r[experts_per_thread];
+    float wt[experts_per_thread];

#pragma unroll
    for (int i = 0; i < n_experts; i += WARP_SIZE) {
-        const int expert = i + threadIdx.x;
-        logits_r[i / WARP_SIZE] = n_experts % WARP_SIZE == 0 || expert < n_experts ? logits[expert] : -INFINITY;
+        const int expert = i + threadIdx.x;
+        wt[i / WARP_SIZE] = (n_experts % WARP_SIZE == 0 || expert < n_experts) ? logits[expert] : -INFINITY;
    }

-    float max_val = logits_r[0];
-
-#pragma unroll
-    for (int i = 1; i < experts_per_thread; i++) {
-        const float val = logits_r[i];
-        max_val = max(val, max_val);
+    if constexpr (!delayed_softmax) {
+        softmax_warp_inplace<experts_per_thread, false>(wt, n_experts, threadIdx.x);
    }

-    max_val = warp_reduce_max(max_val);
-
-    float wt[experts_per_thread];
-    float tmp = 0.f;
-
-#pragma unroll
-    for (int i = 0; i < experts_per_thread; i++) {
-        const float val = logits_r[i];
-        wt[i] = expf(val - max_val);
-        tmp += wt[i];
-    }
+    // At this point each thread holds either a portion of the softmax distribution or the
+    // raw logits. We do the argmax reduce over n_expert_used iterations, each time marking
+    // the selected expert's weight as -inf to exclude it from the next iteration.

-    tmp = warp_reduce_sum(tmp);
+    float wt_sum = 0.f;

-    const float inv_sum = 1.0f / tmp;
+    float output_weights[experts_per_thread];

#pragma unroll
    for (int i = 0; i < experts_per_thread; i++) {
-        wt[i] = wt[i] * inv_sum;
+        output_weights[i] = 0.f;
    }

-    //at this point, each thread holds a portion of softmax,
-    //we do the argmax reduce over n_expert_used, each time marking
-    //the expert weight as -inf to exclude from the next iteration
-
-    float wt_sum = 0.f;
-
-    float output_weights[experts_per_thread];
-
    for (int k = 0; k < n_expert_used; k++) {
        float max_val = wt[0];
        int max_expert = threadIdx.x;
@@ -121,6 +146,10 @@ __launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float *
        }
    }

+    if constexpr (delayed_softmax) {
+        softmax_warp_inplace<experts_per_thread, true>(output_weights, n_expert_used, threadIdx.x);
+    }
+
#pragma unroll
    for (int i = 0; i < experts_per_thread; i++) {
        const int idx = i * WARP_SIZE + threadIdx.x;
@@ -130,7 +159,7 @@ __launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float *
    }
}

-template <bool with_norm>
+template <bool with_norm, bool delayed_softmax = false>
static void launch_topk_moe_cuda(ggml_backend_cuda_context & ctx,
                                 const float * logits,
                                 float * weights,
@@ -138,6 +167,8 @@ static void launch_topk_moe_cuda(ggml_backend_cuda_context & ctx,
                                 const int n_rows,
                                 const int n_expert,
                                 const int n_expert_used) {
+    static_assert(!(with_norm && delayed_softmax), "delayed softmax is not supported with weight normalization");
+
    const int rows_per_block = 4;
    dim3 grid_dims((n_rows + rows_per_block - 1) / rows_per_block, 1, 1);
    dim3 block_dims(WARP_SIZE, rows_per_block, 1);
@@ -145,43 +176,43 @@ static void launch_topk_moe_cuda(ggml_backend_cuda_context & ctx,

    switch (n_expert) {
        case 1:
-            topk_moe_cuda<1, with_norm>
+            topk_moe_cuda<1, with_norm, delayed_softmax>
                <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
            break;
        case 2:
-            topk_moe_cuda<2, with_norm>
+            topk_moe_cuda<2, with_norm, delayed_softmax>
                <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
            break;
        case 4:
-            topk_moe_cuda<4, with_norm>
+            topk_moe_cuda<4, with_norm, delayed_softmax>
                <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
            break;
        case 8:
-            topk_moe_cuda<8, with_norm>
+            topk_moe_cuda<8, with_norm, delayed_softmax>
                <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
            break;
        case 16:
-            topk_moe_cuda<16, with_norm>
+            topk_moe_cuda<16, with_norm, delayed_softmax>
                <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
            break;
        case 32:
-            topk_moe_cuda<32, with_norm>
+            topk_moe_cuda<32, with_norm, delayed_softmax>
                <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
            break;
        case 64:
-            topk_moe_cuda<64, with_norm>
+            topk_moe_cuda<64, with_norm, delayed_softmax>
                <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
            break;
        case 128:
-            topk_moe_cuda<128, with_norm>
+            topk_moe_cuda<128, with_norm, delayed_softmax>
                <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
            break;
        case 256:
-            topk_moe_cuda<256, with_norm>
+            topk_moe_cuda<256, with_norm, delayed_softmax>
                <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
            break;
        case 512:
-            topk_moe_cuda<512, with_norm>
+            topk_moe_cuda<512, with_norm, delayed_softmax>
                <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
            break;
        default:
@@ -194,7 +225,8 @@ void ggml_cuda_op_topk_moe(ggml_backend_cuda_context & ctx,
                           const ggml_tensor * logits,
                           ggml_tensor * weights,
                           ggml_tensor * ids,
-                           const bool with_norm) {
+                           const bool with_norm,
+                           const bool delayed_softmax) {
    GGML_ASSERT(logits->type == GGML_TYPE_F32);
    GGML_ASSERT(weights->type == GGML_TYPE_F32);
    GGML_ASSERT(ids->type == GGML_TYPE_I32);
@@ -202,7 +234,7 @@ void ggml_cuda_op_topk_moe(ggml_backend_cuda_context & ctx,
    const int n_experts = logits->ne[0];
    const int n_rows = logits->ne[1];

-    const float * logits_d = (const float *) logits->src[0]->data;
+    const float * logits_d = (const float *) logits->data;
    float * weights_d = (float *) weights->data;
    int32_t * ids_d = (int32_t *) ids->data;

@@ -213,7 +245,11 @@ void ggml_cuda_op_topk_moe(ggml_backend_cuda_context & ctx,
    if (with_norm) {
        launch_topk_moe_cuda<true>(ctx, logits_d, weights_d, ids_d, n_rows, n_experts, n_expert_used);
    } else {
-        launch_topk_moe_cuda<false>(ctx, logits_d, weights_d, ids_d, n_rows, n_experts, n_expert_used);
+        if (delayed_softmax) {
+            launch_topk_moe_cuda<false, true>(ctx, logits_d, weights_d, ids_d, n_rows, n_experts, n_expert_used);
+        } else {
+            launch_topk_moe_cuda<false, false>(ctx, logits_d, weights_d, ids_d, n_rows, n_experts, n_expert_used);
+        }
    }
}

@@ -246,7 +282,7 @@ bool ggml_cuda_should_use_topk_moe(const ggml_tensor * softmax, const ggml_tenso
    return true;
}

-std::initializer_list<enum ggml_op> ggml_cuda_topk_moe_ops(bool norm) {
+std::initializer_list<enum ggml_op> ggml_cuda_topk_moe_ops(bool norm, bool delayed_softmax) {
    static std::initializer_list<enum ggml_op> norm_ops = { GGML_OP_SOFT_MAX, GGML_OP_RESHAPE, GGML_OP_ARGSORT,
                                                            GGML_OP_VIEW, GGML_OP_GET_ROWS, GGML_OP_RESHAPE,
                                                            GGML_OP_SUM_ROWS, GGML_OP_DIV, GGML_OP_RESHAPE };
@@ -254,8 +290,19 @@ std::initializer_list<enum ggml_op> ggml_cuda_topk_moe_ops(bool norm) {
    static std::initializer_list<enum ggml_op> no_norm_ops = { GGML_OP_SOFT_MAX, GGML_OP_RESHAPE, GGML_OP_ARGSORT,
                                                               GGML_OP_VIEW, GGML_OP_GET_ROWS };

+    static std::initializer_list<enum ggml_op> delayed_softmax_ops = { GGML_OP_ARGSORT, GGML_OP_VIEW,
+                                                                       GGML_OP_GET_ROWS, GGML_OP_RESHAPE,
+                                                                       GGML_OP_SOFT_MAX, GGML_OP_RESHAPE };
+
+    GGML_ASSERT(!norm || !delayed_softmax);
+
+    if (delayed_softmax) {
+        return delayed_softmax_ops;
+    }
+
    if (norm) {
        return norm_ops;
    }
+
    return no_norm_ops;
}
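
For reference, the routing that the fused kernel computes can be reproduced on the host in a few lines of C++. The sketch below is illustrative only and is not part of the patch or of ggml's API; the helper names (softmax_inplace, topk_moe_ref) and the toy logits in main are invented for the example. It covers the three configurations the kernel supports: softmax then top-k, softmax then top-k with weight normalization, and the delayed path where the softmax is applied only to the selected logits.

// Host-side reference for the fused softmax -> top-k -> (normalize | delayed softmax) routing.
// Assumes n_expert_used <= logits.size() and that with_norm and delayed_softmax are not both
// set, mirroring the static_assert in launch_topk_moe_cuda.
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <utility>
#include <vector>

static void softmax_inplace(std::vector<float> & v) {
    const float max_val = *std::max_element(v.begin(), v.end());
    float sum = 0.f;
    for (float & x : v) {
        x = std::exp(x - max_val);
        sum += x;
    }
    for (float & x : v) {
        x /= sum;
    }
}

// Returns (weight, expert id) pairs for the selected experts of a single token.
static std::vector<std::pair<float, int>> topk_moe_ref(const std::vector<float> & logits,
                                                       int n_expert_used, bool with_norm, bool delayed_softmax) {
    std::vector<float> wt = logits;
    if (!delayed_softmax) {
        softmax_inplace(wt); // step 1: softmax over all expert logits
    }

    // step 2: indices of the n_expert_used largest entries (argsort + view in the unfused graph)
    std::vector<int> idx(wt.size());
    for (size_t i = 0; i < idx.size(); i++) {
        idx[i] = (int) i;
    }
    std::partial_sort(idx.begin(), idx.begin() + n_expert_used, idx.end(),
                      [&](int a, int b) { return wt[a] > wt[b]; });

    // step 3: gather the selected weights and ids
    std::vector<std::pair<float, int>> out;
    float sum = 0.f;
    for (int k = 0; k < n_expert_used; k++) {
        out.emplace_back(wt[idx[k]], idx[k]);
        sum += wt[idx[k]];
    }

    if (with_norm) { // step 4a: renormalize the selected softmax weights
        for (auto & p : out) {
            p.first /= sum;
        }
    }
    if (delayed_softmax) { // step 4b: softmax over the selected raw logits only
        std::vector<float> sel;
        for (const auto & p : out) {
            sel.push_back(p.first);
        }
        softmax_inplace(sel);
        for (int k = 0; k < n_expert_used; k++) {
            out[k].first = sel[k];
        }
    }
    return out;
}

int main() {
    const std::vector<float> logits = { 0.1f, 2.0f, -1.0f, 0.5f };
    for (const auto & [w, id] : topk_moe_ref(logits, 2, /*with_norm=*/false, /*delayed_softmax=*/true)) {
        std::printf("expert %d weight %.4f\n", id, w);
    }
    return 0;
}

With delayed_softmax set, the reference selects experts by raw logit and then softmaxes only the selected values, which is what the delayed path of topk_moe_cuda computes per warp.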