
op_softmax.comp

// TODO: implement multi-simd softmax (llama.cpp commit e16b9fa4)

#version 450

#include "common.comp"

layout(local_size_x_id = 0) in;

layout(binding = 0) buffer restrict readonly  tensorInA { float inA[];  };
layout(binding = 1) buffer restrict readonly  tensorInB { float inB[];  };
layout(binding = 2) buffer restrict writeonly tensorOut { float out_[]; };
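
// One workgroup handles one row of the input. Judging from the indexing in
// main() below, the host side is expected to dispatch an (ne01, ne02, ne03)
// grid so that gl_WorkGroupID maps onto the three outer tensor dimensions
// (an assumption about the Kompute backend's host code, which is not shown here).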

layout(push_constant) uniform PushConstants {
    uint  inAOff;
    uint  inBOff;
    uint  outOff;
    int   ne00;
    int   ne01;
    int   ne02;
    float scale;
    int   mask;
} pcs;
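
// Push-constant semantics (ggml's neXY dimension naming):
//   ne00       - row length (innermost dimension); the softmax axis
//   ne01, ne02 - the next two outer dimensions
//   scale      - multiplier applied to every logit before the softmax
//   mask       - non-zero when inB holds an additive mask for the logits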

void main() {
    if (gl_SubgroupInvocationID > 31)
        return;
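
    // From here on only the first 32 invocations (one 32-wide subgroup) are
    // alive, matching the fixed stride of 32 in the loops below; the
    // multi-simd TODO at the top of the file is about lifting this limit.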

    const uint i03 = gl_WorkGroupID.z;
    const uint i02 = gl_WorkGroupID.y;
    const uint i01 = gl_WorkGroupID.x;

    const uint extra_off = i03*pcs.ne02*pcs.ne01*pcs.ne00 + i02*pcs.ne01*pcs.ne00 + i01*pcs.ne00;
    const uint psrc0 = extra_off + pcs.inAOff;    // Based from inA
    const uint pmask = i01*pcs.ne00 + pcs.inBOff; // Based from inB
    const uint pdst  = extra_off + pcs.outOff;    // Based from out_
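
    // Note that pmask depends only on i01: the mask is addressed as a 2D
    // (ne01 x ne00) tensor, so it is reused across the outer i02/i03 dims.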

    // parallel max: each invocation scans its strided slice of the row
    float localMax = uintBitsToFloat(0xFF800000); // bit pattern of -inf
    for (uint i00 = gl_SubgroupInvocationID.x; i00 < pcs.ne00; i00 += 32) {
        localMax = max(localMax, inA[psrc0 + i00]*pcs.scale + (pcs.mask != 0 ? inB[pmask + i00] : 0.0f));
    }
    float max_ = subgroupMax(localMax);
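
    // subgroupMax() reduces across the subgroup, so every invocation now
    // holds the row maximum; subtracting it inside exp() below is the usual
    // numerically stable softmax trick.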

    // parallel sum of exp(logit - max), writing the exponentials out as we go
    float localSum = 0.0f;
    for (uint i00 = gl_SubgroupInvocationID.x; i00 < pcs.ne00; i00 += 32) {
        const float exp_psrc0 = exp(inA[psrc0 + i00]*pcs.scale + (pcs.mask != 0 ? inB[pmask + i00] : 0.0f) - max_);
        localSum += exp_psrc0;
        out_[pdst + i00] = exp_psrc0; // stash exp() so the final pass only divides
    }

    const float sum = subgroupAdd(localSum);
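
    // subgroupAdd() hands every invocation the same total, so each one can
    // normalize its own strided slice of the row independently.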

    for (uint i00 = gl_SubgroupInvocationID.x; i00 < pcs.ne00; i00 += 32) {
        out_[pdst + i00] /= sum;
    }
}