op_mul_mat_q8_0.comp 2.2 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273
  // q8_0 matrix x float vector multiply (see main() below).
#version 450

#include "common.comp"
#include "op_mul_mv_q_n_pre.comp"

// Byte size of the fp16 scale `d` stored at the start of every q8_0 block;
// main() reads the int8 quants starting right after it.
#define SIZE_OF_D   2

// Tiling of the work across subgroups.
#define N_DST       4  // consecutive rows of A accumulated by one subgroup
#define N_SIMDGROUP 2  // subgroups cooperating inside one workgroup
#define N_SIMDWIDTH 32 // assumed subgroup width; main() drops lanes above 31

// Quants each lane consumes per block step. Four lanes share one block
// (4 x 8 = 32 quants, presumably QK8_0 -- TODO confirm against common.comp).
#define NB_Q8_0     8
  // Matrix-vector product of a q8_0-quantized matrix A (inA) with float
// vectors B (inB), writing float results to out_.
//
// Work split:
//   * gl_WorkGroupID.x (r0) selects a group of N_SIMDGROUP * N_DST rows of A,
//   * gl_WorkGroupID.y (r1) selects the column of B / output,
//   * gl_WorkGroupID.z (im) selects the batch, split into i12 / i13.
// Each subgroup accumulates N_DST consecutive rows starting at first_row.
// Inside a subgroup, lane/4 picks a q8_0 block and lane%4 picks an
// NB_Q8_0-wide slice of it, so 4 lanes cover one block per step.
// The per-lane partial sums are combined with subgroupAdd and one elected
// lane stores each row result.
void main() {
    // NB: hack to make compatible with AMD GPUs that have a subgroup size of 64
    if (gl_SubgroupInvocationID > 31)
        return;

    const int nr  = N_DST;
    const int nsg = N_SIMDGROUP;
    const int nw  = N_SIMDWIDTH;

    const int nb = pcs.ne00/QK8_0; // q8_0 blocks per row of A
    const uint r0 = gl_WorkGroupID.x;
    const uint r1 = gl_WorkGroupID.y;
    const uint im = gl_WorkGroupID.z;

    // first_row depends only on the workgroup and subgroup, so it is uniform
    // across all lanes of this subgroup.
    const uint first_row = (r0 * nsg + gl_SubgroupID) * nr;

    const uint i12 = im%pcs.ne12;
    const uint i13 = im/pcs.ne12;

    // Block index of first_row inside A. Dividing i12/i13 by r2/r3 lets
    // several B batches map onto the same A matrix.
    const uint offset0 = first_row * nb + (i12/pcs.r2)*(nb*pcs.ne01) + (i13/pcs.r3)*(nb*pcs.ne01*pcs.ne02);

    const uint x = offset0*sizeof_block_q8_0 + pcs.inAOff; // Based from inA
    const uint y = r1*pcs.ne10 + im*pcs.ne00*pcs.ne1 + pcs.inBOff; // based from inB

    float yl[NB_Q8_0];
    float sumf[N_DST]={0.f, 0.f, 0.f, 0.f};

    const uint ix = gl_SubgroupInvocationID.x/4; // block handled by this lane
    const uint il = gl_SubgroupInvocationID.x%4; // NB_Q8_0-wide slice within it

    uint yb = y + ix * QK8_0 + NB_Q8_0*il;

    // each thread in a SIMD group deals with NB_Q8_0 quants at a time
    for (uint ib = ix; ib < nb; ib += nw/4) {
        // Cache this lane's slice of B once; it is reused for every row.
        for (int i = 0; i < NB_Q8_0; ++i) {
            yl[i] = inB[yb + i];
        }

        for (int row = 0; row < nr; row++) {
            // The last workgroup can extend past ne01 (the store below guards
            // the same condition). Stop before reading blocks of rows that do
            // not exist; the condition is uniform per subgroup, so no divergence.
            if (first_row + row >= pcs.ne01) {
                break;
            }
            const uint block_offset = (ib+row*nb) * sizeof_block_q8_0;
            float sumq = 0.f;
            for (int iq = 0; iq < NB_Q8_0; ++iq) {
                // int8 quants follow the SIZE_OF_D-byte scale inside the block.
                const int8_t qs_iq = int8_t(inA[x + block_offset + SIZE_OF_D + NB_Q8_0*il + iq]);
                sumq += qs_iq * yl[iq];
            }
            // Per-block fp16 scale stored at the start of the block.
            const float16_t d = u8BufToFloat16(inA, x + block_offset);
            sumf[row] += sumq*d;
        }

        // Advance by the (nw/4) blocks skipped by ib; this equals
        // (nw/4) * QK8_0 elements assuming QK8_0 == 4 * NB_Q8_0.
        yb += NB_Q8_0 * nw;
    }

    for (int row = 0; row < nr; ++row) {
        const float tot = subgroupAdd(sumf[row]);
        if (subgroupElect() && first_row + row < pcs.ne01) {
            // Honor the output offset like inAOff/inBOff are honored above.
            // NOTE(review): assumes outOff is in float elements, as inBOff is
            // added directly to a float index -- TODO confirm on host side.
            out_[r1*pcs.ne0 + im*pcs.ne0*pcs.ne1 + first_row + row + pcs.outOff] = tot;
        }
    }
}