|
@@ -19,6 +19,7 @@
|
|
|
#endif
|
|
#endif
|
|
|
|
|
|
|
|
#include "types.comp"
|
|
#include "types.comp"
|
|
|
|
|
+#include "utils.comp"
|
|
|
|
|
|
|
|
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
|
|
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
|
|
|
|
|
|
|
@@ -99,7 +100,8 @@ layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufB {
|
|
|
};
|
|
};
|
|
|
|
|
|
|
|
uint _ne1;
|
|
uint _ne1;
|
|
|
-shared uint _ne1_sh;
|
|
|
|
|
|
|
+layout (constant_id = 5) const uint subgroup_size = 32;
|
|
|
|
|
+shared uvec4 ballots_sh[BLOCK_SIZE / subgroup_size];
|
|
|
|
|
|
|
|
B_TYPE decodeFuncB(const in decodeBufB bl, const in uint blockCoords[2], const in uint coordInBlock[2])
|
|
B_TYPE decodeFuncB(const in decodeBufB bl, const in uint blockCoords[2], const in uint coordInBlock[2])
|
|
|
{
|
|
{
|
|
@@ -128,6 +130,64 @@ D_TYPE perElemOpD(const in uint32_t r, const in uint32_t c, const in D_TYPE elem
|
|
|
return elem;
|
|
return elem;
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
|
|
// load_row_ids: scans the p.nei1 * p.nei0 entries addressed through data_ids and, for every
// entry whose value equals expert_idx, writes one u16vec4 into row_ids. The running number of
// matches is accumulated in _ne1.
//
// expert_idx    value each id read from data_ids is compared against.
// nei0_is_pow2  when true, i / p.nei0 is computed as i >> findLSB(p.nei0); this is only
//               equivalent to the division when p.nei0 is a power of two.
//
// Each outer-loop iteration handles BLOCK_SIZE consecutive indices, one per invocation
// (selected by gl_LocalInvocationIndex). Matches are compacted with a subgroup ballot per
// subgroup plus a scan over the per-subgroup ballots stored in ballots_sh.
+void load_row_ids(uint expert_idx, bool nei0_is_pow2) {
|
|
|
|
|
+ _ne1 = 0; // number of matching entries found so far
|
|
|
|
|
+ uint num_elements = p.nei1 * p.nei0;
|
|
|
|
|
+ uint nei0shift = findLSB(p.nei0); // index of lowest set bit; only used when nei0_is_pow2
|
|
|
|
|
+
|
|
|
|
|
+ uint ids[16]; // per-invocation cache of ids for the next 16 loop iterations
|
|
|
|
|
+ uint iter = 0; // position in ids[]; wraps to 0 every 16 iterations
|
|
|
|
|
+
|
|
|
|
|
+ for (uint j = 0; j < num_elements; j += BLOCK_SIZE) {
|
|
|
|
|
+ // prefetch up to 16 elements
|
|
|
|
|
+ if (iter == 0) {
|
|
|
|
|
+ [[unroll]] for (uint k = 0; k < 16; ++k) {
|
|
|
|
|
+ uint i = j + gl_LocalInvocationIndex + k*BLOCK_SIZE; // index this invocation handles k iterations from now
|
|
|
|
|
+ bool in_range = i < num_elements;
|
|
|
|
|
+ uint ii1;
|
|
|
|
|
+ if (nei0_is_pow2) {
|
|
|
|
|
+ ii1 = i >> nei0shift;
|
|
|
|
|
+ } else {
|
|
|
|
|
+ ii1 = i / p.nei0;
|
|
|
|
|
+ }
|
|
|
|
|
+ uint ii0 = i - ii1 * p.nei0; // remainder of i / p.nei0, derived from the quotient
|
|
|
|
|
+ ids[k] = in_range ? data_ids[ii1*p.nbi1 + ii0] : 0; // out-of-range slots are not read; the 0 is filtered by in_range below
|
|
|
|
|
+ }
|
|
|
|
|
+ }
|
|
|
|
|
+ uint i = j + gl_LocalInvocationIndex; // flat index handled by this invocation in this iteration
|
|
|
|
|
+ bool in_range = i < num_elements;
|
|
|
|
|
+ uint ii1;
|
|
|
|
|
+ if (nei0_is_pow2) {
|
|
|
|
|
+ ii1 = i >> nei0shift;
|
|
|
|
|
+ } else {
|
|
|
|
|
+ ii1 = i / p.nei0;
|
|
|
|
|
+ }
|
|
|
|
|
+ uint ii0 = i - ii1 * p.nei0;
|
|
|
|
|
+ uint id = ids[iter++]; // id prefetched for index i
|
|
|
|
|
+ uvec4 ballot = subgroupBallot(in_range && id == expert_idx); // one bit per invocation of this subgroup that matched
|
|
|
|
|
+
|
|
|
|
|
+ ballots_sh[gl_SubgroupID] = ballot; // publish this subgroup's ballot to the other subgroups
|
|
|
|
|
+ barrier(); // all ballots are written before any is read
|
|
|
|
|
+
|
|
|
|
|
+ uint subgroup_base = 0; // matches in subgroups with a lower gl_SubgroupID
|
|
|
|
|
+ uint total = 0; // matches across all subgroups in this iteration
|
|
|
|
|
+ for (uint k = 0; k < gl_NumSubgroups; ++k) {
|
|
|
|
|
+ if (k == gl_SubgroupID) {
|
|
|
|
|
+ subgroup_base = total;
|
|
|
|
|
+ }
|
|
|
|
|
+ total += subgroupBallotBitCount(ballots_sh[k]);
|
|
|
|
|
+ }
|
|
|
|
|
+ barrier(); // all reads of ballots_sh finish before the next iteration overwrites it
|
|
|
|
|
+
|
|
|
|
|
+ uint idx = subgroup_base + subgroupBallotExclusiveBitCount(ballot); // slot within this iteration: earlier subgroups + lower invocations of this subgroup
|
|
|
|
|
+ if (in_range && id == expert_idx) {
|
|
|
|
|
+ row_ids[_ne1 + idx] = u16vec4(fastmod(ii0, p.ne11), ii1, ii0, 0);
|
|
|
|
|
+ }
|
|
|
|
|
+ _ne1 += total; // total is derived from the shared ballots, so every invocation adds the same value
|
|
|
|
|
+ iter &= 15; // wrap after 16 iterations so the next pass refills ids[]
|
|
|
|
|
+ }
|
|
|
|
|
+ barrier(); // NOTE(review): presumably orders the row_ids writes before later reads by other invocations - confirm
|
|
|
|
|
+}
|
|
|
#endif
|
|
#endif
|
|
|
|
|
|
|
|
void main() {
|
|
void main() {
|
|
@@ -157,45 +217,12 @@ void main() {
|
|
|
const uint ic = gl_WorkGroupID.y;
|
|
const uint ic = gl_WorkGroupID.y;
|
|
|
|
|
|
|
|
#ifdef MUL_MAT_ID
|
|
#ifdef MUL_MAT_ID
|
|
|
- // Spread the search across all elements in the first subgroup
|
|
|
|
|
- if (gl_SubgroupID == 0) {
|
|
|
|
|
- _ne1 = 0;
|
|
|
|
|
- uint num_elements = p.nei1 * p.nei0;
|
|
|
|
|
-
|
|
|
|
|
- uint ids[16];
|
|
|
|
|
- uint iter = 0;
|
|
|
|
|
-
|
|
|
|
|
- for (uint j = 0; j < num_elements; j += gl_SubgroupSize) {
|
|
|
|
|
- // prefetch up to 16 elements
|
|
|
|
|
- if (iter == 0) {
|
|
|
|
|
- [[unroll]] for (uint k = 0; k < 16; ++k) {
|
|
|
|
|
- uint i = j + gl_SubgroupInvocationID + k*gl_SubgroupSize;
|
|
|
|
|
- bool in_range = i < num_elements;
|
|
|
|
|
- uint ii1 = i / p.nei0;
|
|
|
|
|
- uint ii0 = i % p.nei0;
|
|
|
|
|
- ids[k] = in_range ? data_ids[ii1*p.nbi1 + ii0] : 0;
|
|
|
|
|
- }
|
|
|
|
|
- }
|
|
|
|
|
- uint i = j + gl_SubgroupInvocationID;
|
|
|
|
|
- bool in_range = i < num_elements;
|
|
|
|
|
- uint ii1 = i / p.nei0;
|
|
|
|
|
- uint ii0 = i % p.nei0;
|
|
|
|
|
- uint id = ids[iter++];
|
|
|
|
|
- uvec4 ballot = subgroupBallot(in_range && id == expert_idx);
|
|
|
|
|
- uint idx = subgroupBallotExclusiveBitCount(ballot);
|
|
|
|
|
- if (in_range && id == expert_idx) {
|
|
|
|
|
- row_ids[_ne1 + idx] = u16vec4(ii0 % p.ne11, ii1, ii0, 0);
|
|
|
|
|
- }
|
|
|
|
|
- _ne1 += subgroupBallotBitCount(ballot);
|
|
|
|
|
- iter &= 15;
|
|
|
|
|
- }
|
|
|
|
|
- _ne1_sh = _ne1;
|
|
|
|
|
|
|
+ if (bitCount(p.nei0) == 1) {
|
|
|
|
|
+ load_row_ids(expert_idx, true);
|
|
|
|
|
+ } else {
|
|
|
|
|
+ load_row_ids(expert_idx, false);
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
- barrier();
|
|
|
|
|
-
|
|
|
|
|
- _ne1 = _ne1_sh;
|
|
|
|
|
-
|
|
|
|
|
// Workgroup has no work
|
|
// Workgroup has no work
|
|
|
if (ic * BN >= _ne1) return;
|
|
if (ic * BN >= _ne1) return;
|
|
|
#endif
|
|
#endif
|