|
|
@@ -101,6 +101,10 @@ void main() {
|
|
|
const uint lane = gl_SubgroupInvocationID;
|
|
|
|
|
|
float probs[experts_per_thread];
|
|
|
+ [[unroll]]
|
|
|
+ for (int i = 0; i < experts_per_thread; i++) {
|
|
|
+ probs[i] = -INFINITY;
|
|
|
+ }
|
|
|
|
|
|
[[unroll]]
|
|
|
for (uint i = 0; i < n_experts; i += WARP_SIZE) {
|
|
|
@@ -112,8 +116,9 @@ void main() {
|
|
|
softmax_warp_inplace(probs, n_experts, lane, nexperts_use_push);
|
|
|
} else if (gating_func == GATING_FUNC_SIGMOID) {
|
|
|
[[unroll]]
|
|
|
- for (int i = 0; i < experts_per_thread; i++) {
|
|
|
- probs[i] = 1.f / (1.f + exp(-probs[i]));
|
|
|
+ for (uint i = 0; i < n_experts; i += WARP_SIZE) {
|
|
|
+ const uint expert = i + lane;
|
|
|
+ probs[i / WARP_SIZE] = (n_experts % WARP_SIZE == 0 || expert < n_experts) ? 1.f / (1.f + exp(-probs[i / WARP_SIZE])) : -INFINITY;
|
|
|
}
|
|
|
}
|
|
|
|
|
|
@@ -150,11 +155,11 @@ void main() {
|
|
|
uint max_expert = lane;
|
|
|
|
|
|
[[unroll]]
|
|
|
- for (int i = 1; i < experts_per_thread; i++) {
|
|
|
- const uint expert = lane + i * WARP_SIZE;
|
|
|
- if ((n_experts % WARP_SIZE == 0 || expert < n_experts) && selection_probs[i] > max_val_s) {
|
|
|
- max_val = probs[i];
|
|
|
- max_val_s = selection_probs[i];
|
|
|
+ for (uint i = WARP_SIZE; i < n_experts; i += WARP_SIZE) {
|
|
|
+ const uint expert = i + lane;
|
|
|
+ if ((n_experts % WARP_SIZE == 0 || expert < n_experts) && selection_probs[i / WARP_SIZE] > max_val_s) {
|
|
|
+ max_val = probs[i / WARP_SIZE];
|
|
|
+ max_val_s = selection_probs[i / WARP_SIZE];
|
|
|
max_expert = expert;
|
|
|
}
|
|
|
}
|