op_mul_mv_q_n.comp 1.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748
  // Quantized matrix-vector product: body of the compute entry point.
  //
  // Each subgroup produces N_ROWS consecutive output rows starting at
  // first_row, for src1 column r1 = gl_WorkGroupID.y and batch im =
  // gl_WorkGroupID.z. Lanes are used in pairs: the two lanes of a pair visit
  // the same quant block indices and differ only in the `il` offset
  // (0 or BLOCKS_IN_QUANT/4) passed to block_q_n_dot_y. Per-lane partial
  // sums are reduced with subgroupAdd and one elected lane writes each row.
  //
  // Relies on names supplied by the including shader: pcs (push constants),
  // out_, N_ROWS, BLOCKS_IN_QUANT (src1 elements consumed per quant block,
  // since nb = ne00 / BLOCKS_IN_QUANT) and block_q_n_dot_y().
  void main() {
      // Lanes used per subgroup.
      // NB: hack to make compatible with AMD GPUs that have a subgroup size
      // of 64 -- lanes past the first 32 exit here and are therefore inactive
      // for the subgroup reductions below.
      const uint LANES = 32u;
      // Two lanes share each block, so a subgroup advances this many blocks per step.
      const uint BLOCKS_PER_STEP = LANES / 2u;

      if (gl_SubgroupInvocationID >= LANES)
          return;

      const uint nb = uint(pcs.ne00/BLOCKS_IN_QUANT); // quant blocks per src0 row

      const uint r0 = gl_WorkGroupID.x;
      const uint r1 = gl_WorkGroupID.y;
      const uint im = gl_WorkGroupID.z;

      const uint first_row = (r0 * gl_NumSubgroups + gl_SubgroupID) * N_ROWS;

      // The trailing tile may extend past ne01. Only rows that exist are read
      // and written; reading the missing ones would index past the current
      // src0 matrix (and past the buffer for the last batch). nrows depends
      // only on workgroup/subgroup IDs, so it is uniform across the subgroup
      // and safe to use as the bound of the subgroupAdd loop.
      const uint ne01 = uint(pcs.ne01);
      const uint nrows = first_row < ne01 ? min(uint(N_ROWS), ne01 - first_row) : 0u;

      const uint i12 = im%pcs.ne12;
      const uint i13 = im/pcs.ne12;

      // i12/r2 and i13/r3 select the src0 matrix for this batch
      // (presumably r2/r3 are the src1:src0 broadcast ratios -- confirm).
      const uint offset0 = first_row * nb + (i12/pcs.r2)*(nb*pcs.ne01) + (i13/pcs.r3)*(nb*pcs.ne01*pcs.ne02);

      const uint x = offset0; // Based from inA without base offset
      const uint y = r1*uint(pcs.ne10)+im*pcs.ne00*pcs.ne1+pcs.inBOff; // Based from inB

      // One accumulator per row. Zeroed in a loop so the shader compiles for
      // any N_ROWS (a brace initializer must list exactly N_ROWS values).
      float sumf[N_ROWS];
      for (uint row = 0u; row < uint(N_ROWS); ++row) {
          sumf[row] = 0.0f;
      }

      const uint ix = gl_SubgroupInvocationID/2;                          // block slot of this lane pair
      const uint il = (BLOCKS_IN_QUANT/4)*(gl_SubgroupInvocationID%2);    // in-block offset for this lane

      uint yb = y + ix * BLOCKS_IN_QUANT + il;

      for (uint ib = ix; ib < nb; ib += BLOCKS_PER_STEP) {
          for (uint row = 0u; row < nrows; ++row) {
              const uint block_index = x + ib + row * nb;
              sumf[row] += block_q_n_dot_y(block_index, yb, il);
          }
          yb += BLOCKS_IN_QUANT * BLOCKS_PER_STEP;
      }

      for (uint row = 0u; row < nrows; ++row) {
          const float tot = subgroupAdd(sumf[row]);
          if (subgroupElect()) {
              out_[r1*pcs.ne0 + im*pcs.ne0*pcs.ne1 + first_row + row + pcs.outOff] = tot;
          }
      }
  }