|
|
@@ -73,8 +73,7 @@ __launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float *
|
|
|
|
|
|
float wt_sum = 0.f;
|
|
|
|
|
|
- extern __shared__ float data_topk_shared[];
|
|
|
- float * wt_shared_ptr = data_topk_shared + threadIdx.y * n_expert_used;
|
|
|
+ float output_weights[experts_per_thread];
|
|
|
|
|
|
for (int k = 0; k < n_expert_used; k++) {
|
|
|
float max_val = wt[0];
|
|
|
@@ -99,11 +98,14 @@ __launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float *
|
|
|
}
|
|
|
}
|
|
|
|
|
|
+ if ((k & (WARP_SIZE - 1)) == threadIdx.x) {
|
|
|
+ output_weights[k / WARP_SIZE] = max_val;
|
|
|
+ }
|
|
|
+
|
|
|
if ((max_expert & (WARP_SIZE - 1)) == threadIdx.x) {
|
|
|
wt[max_expert / WARP_SIZE] = -INFINITY;
|
|
|
|
|
|
- wt_shared_ptr[k] = max_val;
|
|
|
- ids[k] = max_expert;
|
|
|
+ ids[k] = max_expert;
|
|
|
if constexpr (with_norm) {
|
|
|
wt_sum += max_val;
|
|
|
}
|
|
|
@@ -115,12 +117,16 @@ __launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float *
|
|
|
const float inv_sum = 1.0f / wt_sum;
|
|
|
|
|
|
for (int i = threadIdx.x; i < n_expert_used; i += WARP_SIZE) {
|
|
|
- wt_shared_ptr[i] = wt_shared_ptr[i] * inv_sum;
|
|
|
+ output_weights[i] *= inv_sum;
|
|
|
}
|
|
|
}
|
|
|
|
|
|
- for (int i = threadIdx.x; i < n_expert_used; i += WARP_SIZE) {
|
|
|
- weights[i] = wt_shared_ptr[i];
|
|
|
+#pragma unroll
|
|
|
+ for (int i = 0; i < experts_per_thread; i++) {
|
|
|
+ const int idx = i * WARP_SIZE + threadIdx.x;
|
|
|
+ if (idx < n_expert_used) {
|
|
|
+ weights[idx] = output_weights[i];
|
|
|
+ }
|
|
|
}
|
|
|
}
|
|
|
|
|
|
@@ -137,48 +143,46 @@ static void launch_topk_moe_cuda(ggml_backend_cuda_context & ctx,
|
|
|
dim3 block_dims(WARP_SIZE, rows_per_block, 1);
|
|
|
cudaStream_t stream = ctx.stream();
|
|
|
|
|
|
- const int nbytes_shared = n_expert_used * rows_per_block * sizeof(float);
|
|
|
-
|
|
|
switch (n_expert) {
|
|
|
case 1:
|
|
|
topk_moe_cuda<1, with_norm>
|
|
|
- <<<grid_dims, block_dims, nbytes_shared, stream>>>(logits, weights, ids, n_rows, n_expert_used);
|
|
|
+ <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
|
|
|
break;
|
|
|
case 2:
|
|
|
topk_moe_cuda<2, with_norm>
|
|
|
- <<<grid_dims, block_dims, nbytes_shared, stream>>>(logits, weights, ids, n_rows, n_expert_used);
|
|
|
+ <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
|
|
|
break;
|
|
|
case 4:
|
|
|
topk_moe_cuda<4, with_norm>
|
|
|
- <<<grid_dims, block_dims, nbytes_shared, stream>>>(logits, weights, ids, n_rows, n_expert_used);
|
|
|
+ <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
|
|
|
break;
|
|
|
case 8:
|
|
|
topk_moe_cuda<8, with_norm>
|
|
|
- <<<grid_dims, block_dims, nbytes_shared, stream>>>(logits, weights, ids, n_rows, n_expert_used);
|
|
|
+ <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
|
|
|
break;
|
|
|
case 16:
|
|
|
topk_moe_cuda<16, with_norm>
|
|
|
- <<<grid_dims, block_dims, nbytes_shared, stream>>>(logits, weights, ids, n_rows, n_expert_used);
|
|
|
+ <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
|
|
|
break;
|
|
|
case 32:
|
|
|
topk_moe_cuda<32, with_norm>
|
|
|
- <<<grid_dims, block_dims, nbytes_shared, stream>>>(logits, weights, ids, n_rows, n_expert_used);
|
|
|
+ <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
|
|
|
break;
|
|
|
case 64:
|
|
|
topk_moe_cuda<64, with_norm>
|
|
|
- <<<grid_dims, block_dims, nbytes_shared, stream>>>(logits, weights, ids, n_rows, n_expert_used);
|
|
|
+ <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
|
|
|
break;
|
|
|
case 128:
|
|
|
topk_moe_cuda<128, with_norm>
|
|
|
- <<<grid_dims, block_dims, nbytes_shared, stream>>>(logits, weights, ids, n_rows, n_expert_used);
|
|
|
+ <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
|
|
|
break;
|
|
|
case 256:
|
|
|
topk_moe_cuda<256, with_norm>
|
|
|
- <<<grid_dims, block_dims, nbytes_shared, stream>>>(logits, weights, ids, n_rows, n_expert_used);
|
|
|
+ <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
|
|
|
break;
|
|
|
case 512:
|
|
|
topk_moe_cuda<512, with_norm>
|
|
|
- <<<grid_dims, block_dims, nbytes_shared, stream>>>(logits, weights, ids, n_rows, n_expert_used);
|
|
|
+ <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
|
|
|
break;
|
|
|
default:
|
|
|
GGML_ASSERT(false && "fatal error");
|