|
|
@@ -109,13 +109,13 @@ shared FLOAT_TYPE buf_b[BN * SHMEM_STRIDE];
|
|
|
#define NUM_WARPS (BLOCK_SIZE / WARP)
|
|
|
|
|
|
#ifdef MUL_MAT_ID
|
|
|
-shared u16vec2 row_ids[4096];
|
|
|
+shared u16vec2 row_ids[BN];
|
|
|
uint _ne1;
|
|
|
|
|
|
#ifdef MUL_MAT_ID_USE_SUBGROUPS
|
|
|
shared uvec4 ballots_sh[NUM_WARPS];
|
|
|
|
|
|
-void load_row_ids(uint expert_idx, bool nei0_is_pow2) {
|
|
|
+void load_row_ids(uint expert_idx, bool nei0_is_pow2, uint ic) {
|
|
|
_ne1 = 0;
|
|
|
uint num_elements = p.nei1 * p.nei0;
|
|
|
uint nei0shift = findLSB(p.nei0);
|
|
|
@@ -165,11 +165,14 @@ void load_row_ids(uint expert_idx, bool nei0_is_pow2) {
|
|
|
barrier();
|
|
|
|
|
|
uint idx = subgroup_base + subgroupBallotExclusiveBitCount(ballot);
|
|
|
- if (in_range && id == expert_idx) {
|
|
|
- row_ids[_ne1 + idx] = u16vec2(ii0, ii1);
|
|
|
+ if (in_range && id == expert_idx && _ne1 + idx >= ic * BN && _ne1 + idx < (ic + 1) * BN) {
|
|
|
+ row_ids[_ne1 + idx - ic * BN] = u16vec2(ii0, ii1);
|
|
|
}
|
|
|
_ne1 += total;
|
|
|
iter &= 15;
|
|
|
+ if (_ne1 >= (ic + 1) * BN) {
|
|
|
+ break;
|
|
|
+ }
|
|
|
}
|
|
|
barrier();
|
|
|
}
|
|
|
@@ -242,16 +245,18 @@ void main() {
|
|
|
#ifdef MUL_MAT_ID
|
|
|
#ifdef MUL_MAT_ID_USE_SUBGROUPS
|
|
|
if (bitCount(p.nei0) == 1) {
|
|
|
- load_row_ids(expert_idx, true);
|
|
|
+ load_row_ids(expert_idx, true, ic);
|
|
|
} else {
|
|
|
- load_row_ids(expert_idx, false);
|
|
|
+ load_row_ids(expert_idx, false, ic);
|
|
|
}
|
|
|
#else
|
|
|
_ne1 = 0;
|
|
|
- for (uint ii1 = 0; ii1 < p.nei1; ii1++) {
|
|
|
- for (uint ii0 = 0; ii0 < p.nei0; ii0++) {
|
|
|
+ for (uint ii1 = 0; ii1 < p.nei1 && _ne1 < (ic + 1) * BN; ii1++) {
|
|
|
+ for (uint ii0 = 0; ii0 < p.nei0 && _ne1 < (ic + 1) * BN; ii0++) {
|
|
|
if (data_ids[ii1*p.nbi1 + ii0] == expert_idx) {
|
|
|
- row_ids[_ne1] = u16vec2(ii0, ii1);
|
|
|
+ if (_ne1 >= ic * BN) {
|
|
|
+ row_ids[_ne1 - ic * BN] = u16vec2(ii0, ii1);
|
|
|
+ }
|
|
|
_ne1++;
|
|
|
}
|
|
|
}
|
|
|
@@ -797,7 +802,7 @@ void main() {
|
|
|
[[unroll]] for (uint l = 0; l < BN; l += loadstride_b) {
|
|
|
#if LOAD_VEC_B == 8
|
|
|
#ifdef MUL_MAT_ID
|
|
|
- const u16vec2 row_idx = row_ids[ic * BN + loadc_b + l];
|
|
|
+ const u16vec2 row_idx = row_ids[loadc_b + l];
|
|
|
const uint idx = pos_b + row_idx.y * p.batch_stride_b / LOAD_VEC_B + (row_idx.x % p.ne11) * p.stride_b / LOAD_VEC_B + loadr_b;
|
|
|
#else
|
|
|
const uint idx = pos_b + (loadc_b + l) * p.stride_b / LOAD_VEC_B + loadr_b;
|
|
|
@@ -813,7 +818,7 @@ void main() {
|
|
|
buf_b[buf_idx + 7] = FLOAT_TYPE(data_b[idx][1].w);
|
|
|
#elif LOAD_VEC_B == 4
|
|
|
#ifdef MUL_MAT_ID
|
|
|
- const u16vec2 row_idx = row_ids[ic * BN + loadc_b + l];
|
|
|
+ const u16vec2 row_idx = row_ids[loadc_b + l];
|
|
|
const uint idx = pos_b + row_idx.y * p.batch_stride_b / LOAD_VEC_B + (row_idx.x % p.ne11) * p.stride_b / LOAD_VEC_B + loadr_b;
|
|
|
#else
|
|
|
const uint idx = pos_b + (loadc_b + l) * p.stride_b / LOAD_VEC_B + loadr_b;
|
|
|
@@ -832,7 +837,7 @@ void main() {
|
|
|
#else
|
|
|
const uint row_i = ic * BN + loadc_b + l;
|
|
|
if (row_i < _ne1 && block + loadr_b < end_k) {
|
|
|
- const u16vec2 row_idx = row_ids[row_i];
|
|
|
+ const u16vec2 row_idx = row_ids[loadc_b + l];
|
|
|
buf_b[(loadc_b + l) * SHMEM_STRIDE + loadr_b] = TO_FLOAT_TYPE(data_b[pos_b + row_idx.y * p.batch_stride_b + (row_idx.x % p.ne11) * p.stride_b + loadr_b]);
|
|
|
} else {
|
|
|
buf_b[(loadc_b + l) * SHMEM_STRIDE + loadr_b] = FLOAT_TYPE(0.0f);
|
|
|
@@ -903,7 +908,7 @@ void main() {
|
|
|
const uint row_i = dc + cm_col * TN + col + store_c;
|
|
|
if (row_i >= _ne1) break;
|
|
|
|
|
|
- const u16vec2 row_idx = row_ids[row_i];
|
|
|
+ const u16vec2 row_idx = row_ids[row_i - ic * BN];
|
|
|
|
|
|
if (dr + cm_row * TM + store_r < p.M) {
|
|
|
data_d[row_idx.y * p.batch_stride_d + row_idx.x * p.stride_d + dr + cm_row * TM + store_r] = D_TYPE(coopmat_stage[warp_i * TM * TN + (col + store_c) * TM + store_r]);
|
|
|
@@ -953,7 +958,7 @@ void main() {
|
|
|
const uint row_i = dc_warp + cc;
|
|
|
if (row_i >= _ne1) break;
|
|
|
|
|
|
- const u16vec2 row_idx = row_ids[row_i];
|
|
|
+ const u16vec2 row_idx = row_ids[row_i - ic * BN];
|
|
|
#endif // MUL_MAT_ID
|
|
|
[[unroll]] for (uint cr = 0; cr < TM; cr++) {
|
|
|
#ifdef MUL_MAT_ID
|