|
|
@@ -801,7 +801,7 @@ void main() {
|
|
|
}
|
|
|
#else
|
|
|
const uint row_i = ic * BN + loadc_b + l;
|
|
|
- if (row_i < _ne1) {
|
|
|
+ if (row_i < _ne1 && block + loadr_b < end_k) {
|
|
|
const u16vec2 row_idx = row_ids[row_i];
|
|
|
buf_b[(loadc_b + l) * SHMEM_STRIDE + loadr_b] = TO_FLOAT_TYPE(data_b[pos_b + row_idx.y * p.batch_stride_b + (row_idx.x % p.ne11) * p.stride_b + loadr_b]);
|
|
|
} else {
|
|
|
@@ -875,7 +875,9 @@ void main() {
|
|
|
|
|
|
const u16vec2 row_idx = row_ids[row_i];
|
|
|
|
|
|
- data_d[row_idx.y * p.batch_stride_d + row_idx.x * p.stride_d + dr + cm_row * TM + store_r] = D_TYPE(coopmat_stage[warp_i * TM * TN + (col + store_c) * TM + store_r]);
|
|
|
+ if (dr + cm_row * TM + store_r < p.M) {
|
|
|
+ data_d[row_idx.y * p.batch_stride_d + row_idx.x * p.stride_d + dr + cm_row * TM + store_r] = D_TYPE(coopmat_stage[warp_i * TM * TN + (col + store_c) * TM + store_r]);
|
|
|
+ }
|
|
|
}
|
|
|
}
|
|
|
}
|
|
|
@@ -925,7 +927,9 @@ void main() {
|
|
|
#endif // MUL_MAT_ID
|
|
|
[[unroll]] for (uint cr = 0; cr < TM; cr++) {
|
|
|
#ifdef MUL_MAT_ID
|
|
|
- data_d[row_idx.y * p.batch_stride_d + row_idx.x * p.stride_d + dr_warp + cr] = D_TYPE(sums[(wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr]);
|
|
|
+ if (dr_warp + cr < p.M) {
|
|
|
+ data_d[row_idx.y * p.batch_stride_d + row_idx.x * p.stride_d + dr_warp + cr] = D_TYPE(sums[(wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr]);
|
|
|
+ }
|
|
|
#else
|
|
|
if (dr_warp + cr < p.M && dc_warp + cc < p.N) {
|
|
|
data_d[offsets + (dc_warp + cc) * p.stride_d + dr_warp + cr] = D_TYPE(sums[(wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr]);
|