|
|
@@ -52,13 +52,16 @@ void get_offsets(out uint a_offset, out uint b_offset, out uint d_offset) {
|
|
|
#endif
|
|
|
|
|
|
#ifndef MUL_MAT_ID
|
|
|
- const uint i13 = batch_idx / p.ne12;
|
|
|
- const uint i12 = batch_idx % p.ne12;
|
|
|
+ uint batch_idx_a = 0;
|
|
|
+ if (batch_idx != 0) {
|
|
|
+ const uint i13 = batch_idx / p.ne12;
|
|
|
+ const uint i12 = batch_idx % p.ne12;
|
|
|
|
|
|
- const uint i03 = i13 / p.broadcast3;
|
|
|
- const uint i02 = i12 / p.broadcast2;
|
|
|
+ const uint i03 = i13 / p.broadcast3;
|
|
|
+ const uint i02 = i12 / p.broadcast2;
|
|
|
|
|
|
- const uint batch_idx_a = i03 * p.ne02 + i02;
|
|
|
+ batch_idx_a = i03 * p.ne02 + i02;
|
|
|
+ }
|
|
|
#else
|
|
|
const uint expert_id = data_ids[expert_idx];
|
|
|
#endif
|