| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081 |
- #extension GL_EXT_control_flow_attributes : enable // optional extension: turned on when the compiler supports it
- #extension GL_EXT_shader_16bit_storage : require // mandatory: compilation fails without 16-bit storage support
- #extension GL_EXT_shader_8bit_storage : require // mandatory: compilation fails without 8-bit storage support
- #define K_QUANTS_PER_ITERATION 2 // not referenced in the visible code; presumably read by the includes or by later code (TODO confirm)
- #ifdef MUL_MAT_ID
- #define EXPERT_COUNT 8 // only defined for the MUL_MAT_ID variant; not referenced in the visible code
- #endif
- #include "types.comp" // NOTE(review): presumably supplies the A_TYPE / B_TYPE / D_TYPE definitions used below; confirm
- layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; // binding 0: read-only input buffer A
- layout (binding = 1) readonly buffer B {B_TYPE data_b[];}; // binding 1: read-only input buffer B
- layout (binding = 2) writeonly buffer D {D_TYPE data_d[];}; // binding 2: write-only output buffer D
- #ifdef MUL_MAT_ID
- layout (binding = 3) readonly buffer IDS {int data_ids[];}; // binding 3: read-only int ids, indexed by gl_GlobalInvocationID.y in get_offsets
- #endif
- #include "dequant_funcs.comp" // included after the buffer declarations; presumably its helpers read data_a (TODO confirm)
- layout (push_constant) uniform parameter // push-constant block, accessed as `p`; member order defines its layout
- {
- uint ncols; // not used in the visible code
- uint stride_a; // not used in the visible code
- uint stride_b; // MUL_MAT_ID: multiplied by (invocation y % ne11) to form b_offset
- uint stride_d; // MUL_MAT_ID: multiplied by invocation y to form d_offset
- uint batch_stride_a; // multiplied by the expert id (MUL_MAT_ID) or the A batch index to form a_offset
- uint batch_stride_b; // non-MUL_MAT_ID: multiplied by the batch index to form b_offset
- uint batch_stride_d; // non-MUL_MAT_ID: multiplied by the batch index to form d_offset
- #ifdef MUL_MAT_ID
- uint nei0; // not used in the visible code
- uint ne11; // modulus applied to invocation y when computing b_offset
- #else
- uint ne02; // multiplier for the outer A index (i03) when forming the A batch index
- uint ne12; // divisor/modulus that splits the batch index into (i13, i12)
- uint broadcast2; // divisor mapping i12 to i02; must be non-zero (used as a divisor)
- uint broadcast3; // divisor mapping i13 to i03; must be non-zero (used as a divisor)
- #endif
- } p;
- // Computes the starting offsets into buffers A, B and D for the current
- // invocation, selected by gl_GlobalInvocationID.y.
- //
- //   a_offset  out: offset applied to A
- //   b_offset  out: offset applied to B
- //   d_offset  out: offset applied to D
- //
- // The MUL_MAT_ID build looks the A offset up through data_ids; the regular
- // build derives it from the batch index and the broadcast factors in `p`.
- void get_offsets(out uint a_offset, out uint b_offset, out uint d_offset) {
- #ifdef MUL_MAT_ID
-     // Invocation y picks the entry of data_ids to read.
-     const uint id_slot = gl_GlobalInvocationID.y;
-     // data_ids holds ints; the value is converted to uint on assignment.
-     const uint selected_expert = data_ids[id_slot];
-
-     a_offset = selected_expert * p.batch_stride_a;
-     b_offset = (id_slot % p.ne11) * p.stride_b;
-     d_offset = id_slot * p.stride_d;
- #else
-     // Invocation y is the flattened batch index; split it with ne12.
-     const uint batch = gl_GlobalInvocationID.y;
-     const uint outer_b = batch / p.ne12;
-     const uint inner_b = batch % p.ne12;
-
-     // Map each B-side index to its A-side index via the broadcast factors,
-     // then flatten again with ne02.
-     const uint outer_a = outer_b / p.broadcast3;
-     const uint inner_a = inner_b / p.broadcast2;
-     const uint a_batch = outer_a * p.ne02 + inner_a;
-
-     a_offset = a_batch * p.batch_stride_a;
-     b_offset = batch * p.batch_stride_b;
-     d_offset = batch * p.batch_stride_d;
- #endif
- }
|