| 12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273 |
- #version 450
- #include "common.comp"
- #include "op_mul_mv_q_n_pre.comp"
/* Q8_0 block layout as read by main(): a 2-byte float16 scale `d` at the start
 * of each block, immediately followed by the int8 quants. */
#define SIZE_OF_D 2
/* Quants one lane multiplies per block visit; main() assigns four lanes
 * (il = lane % 4) to the slices of each block. */
#define NB_Q8_0 8

/* Launch geometry assumed by main(). */
#define N_SIMDWIDTH 32 /* lanes per SIMD group (subgroup); assumed to be 32 */
#define N_SIMDGROUP 2  /* SIMD groups per thread group */
#define N_DST 4        /* output rows produced by each SIMD group */
// Matrix-vector product kernel: Q8_0-quantized matrix A (raw bytes in inA)
// times B (read from inB as float), results written to out_.
//
// Work decomposition, as implied by the index math below:
//   - gl_WorkGroupID.x picks a group of N_SIMDGROUP * N_DST rows of A; each
//     subgroup (SIMD group) produces N_DST consecutive rows from first_row.
//   - gl_WorkGroupID.y picks the column r1 of B / of the output.
//   - gl_WorkGroupID.z picks the batch index im = i12 + i13 * ne12; the A
//     slice used is (i12 / r2, i13 / r3), presumably broadcast ratios.
// Inside a subgroup, lane group ix = lane / 4 walks the Q8_0 blocks of a row
// with a stride of N_SIMDWIDTH / 4 blocks, and il = lane % 4 selects which
// NB_Q8_0-wide slice of each block the lane multiplies. Per-lane partial sums
// are reduced with subgroupAdd and one elected lane stores each row's result.
void main() {
    // NB: hack to make compatible with AMD GPUs that have a subgroup size of 64:
    // only the first 32 lanes take part, matching N_SIMDWIDTH.
    if (gl_SubgroupInvocationID > 31)
        return;

    const int nr  = N_DST;
    const int nsg = N_SIMDGROUP;
    const int nw  = N_SIMDWIDTH;

    const int nb = pcs.ne00/QK8_0; // Q8_0 blocks per row of A
    const uint r0 = gl_WorkGroupID.x;
    const uint r1 = gl_WorkGroupID.y;
    const uint im = gl_WorkGroupID.z;

    // First of the nr rows of A handled by this subgroup.
    const uint first_row = (r0 * nsg + gl_SubgroupID) * nr;

    const uint i12 = im%pcs.ne12;
    const uint i13 = im/pcs.ne12;

    // Offset of first_row in A, counted in Q8_0 blocks.
    const uint offset0 = first_row * nb + (i12/pcs.r2)*(nb*pcs.ne01) + (i13/pcs.r3)*(nb*pcs.ne01*pcs.ne02);

    const uint x = offset0*sizeof_block_q8_0 + pcs.inAOff; // Based from inA (bytes)
    const uint y = r1*pcs.ne10 + im*pcs.ne00*pcs.ne1 + pcs.inBOff; // based from inB (elements)

    float yl[NB_Q8_0];
    float sumf[N_DST]={0.f, 0.f, 0.f, 0.f};

    const uint ix = gl_SubgroupInvocationID.x/4;
    const uint il = gl_SubgroupInvocationID.x%4;

    uint yb = y + ix * QK8_0 + NB_Q8_0*il;

    // each thread in a SIMD group deals with NB_Q8_0 quants at a time
    for (uint ib = ix; ib < nb; ib += nw/4) {
        // Cache this lane's slice of B; it is reused for all nr rows.
        for (int i = 0; i < NB_Q8_0; ++i) {
            yl[i] = inB[yb + i];
        }

        for (int row = 0; row < nr; row++) {
            // The last workgroup can overhang ne01 (the store below already
            // guards for it). Rows are visited in ascending order, so stop at
            // the first row past the end instead of reading A beyond its last
            // row; sumf stays 0 for the skipped rows, which are never stored.
            if (first_row + row >= pcs.ne01)
                break;

            const uint block_offset = (ib+row*nb) * sizeof_block_q8_0;
            float sumq = 0.f;
            for (int iq = 0; iq < NB_Q8_0; ++iq) {
                // Quants follow the SIZE_OF_D-byte scale at the block start.
                const int8_t qs_iq = int8_t(inA[x + block_offset + SIZE_OF_D + NB_Q8_0*il + iq]);
                sumq += qs_iq * yl[iq];
            }
            // Per-block scale, applied once to the slice's integer dot product.
            const float16_t d = u8BufToFloat16(inA, x + block_offset);
            sumf[row] += sumq*d;
        }

        // Advance B by the same number of elements the block stride (nw/4 blocks) covers.
        yb += NB_Q8_0 * nw;
    }

    for (int row = 0; row < nr; ++row) {
        // Every active lane must reach subgroupAdd, so the bounds check is
        // applied only to the store.
        const float tot = subgroupAdd(sumf[row]);
        if (subgroupElect() && first_row + row < pcs.ne01) {
            // NOTE(review): unlike inA/inB (pcs.inAOff / pcs.inBOff), no output
            // base offset is added here. If the push constants carry one
            // (e.g. an outOff field), confirm whether it belongs in this index.
            out_[r1*pcs.ne0 + im*pcs.ne0*pcs.ne1 + first_row + row] = tot;
        }
    }
}
|