op_rmsnorm.comp 1.4 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253
  #version 450

  #include "common.comp"

  // Workgroup width along x. main() repeatedly halves this value during its
  // shared-memory reduction, so it has to stay a power of two.
  layout(local_size_x = 512) in;

  // Binding 0: source floats (read-only).
  layout(binding = 0) buffer restrict readonly tensorIn {
      float in_[];
  };

  // Binding 1: destination floats.
  layout(binding = 1) buffer restrict tensorOut {
      float out_[];
  };

  // Per-dispatch parameters. Push constants are laid out by declaration order,
  // so the order and types of these members must not change.
  layout(push_constant) uniform PushConstants {
      uint  inOff;   // added to every index into in_
      uint  outOff;  // added to every index into out_
      uint  ne00;    // elements per row; also used as the output row pitch
      uint  nb01;    // input row stride; main() divides it by 4 (presumably bytes -> floats, confirm with host code)
      float eps;     // added to the mean of squares before the square root
  } pcs;

  // One partial-sum slot per invocation of the workgroup.
  shared float sum[gl_WorkGroupSize.x];
  // RMS-normalizes one row per workgroup:
  //     out_[j] = in_[j] / sqrt(mean(in_[j]^2) + eps)
  // The row is selected by gl_WorkGroupID.x. Invocations stride across the row,
  // combine their partial sums of squares through the shared array `sum`, and
  // then each scales the elements it owns.
  void main() {
      const uint lid = gl_LocalInvocationID.x;

      // First element of this workgroup's row in in_. nb01 is divided by 4 here;
      // presumably it is a byte stride over 4-byte floats (confirm with the host code).
      const uint srcBase = (gl_WorkGroupID.x*pcs.nb01/4) + pcs.inOff;
      // First element of this workgroup's row in out_; rows are ne00 elements apart.
      const uint dstBase = (gl_WorkGroupID.x*pcs.ne00) + pcs.outOff;

      // Per-invocation partial sum of squares, accumulated in the same order as
      // a direct accumulation into sum[lid] would be.
      float partial = 0.0;
      for (uint col = lid; col < pcs.ne00; col += gl_WorkGroupSize.x) {
          const float v = in_[srcBase + col];
          partial += v * v;
      }
      sum[lid] = partial;

      // Every partial must be visible before the tree reduction starts.
      barrier();
      memoryBarrierShared();

      // Pairwise tree reduction: the lower half of the active range absorbs the
      // upper half each step, leaving the row total in sum[0].
      [[unroll]] for (uint span = gl_WorkGroupSize.x >> 1; span > 0; span >>= 1) {
          if (lid < span) {
              sum[lid] += sum[lid + span];
          }
          // Kept outside the branch so all invocations reach the barrier.
          barrier();
          memoryBarrierShared();
      }

      // Invocation 0 turns the total into the mean; the barrier publishes it.
      if (lid == 0) {
          sum[0] = sum[0] / float(pcs.ne00);
      }
      barrier();
      memoryBarrierShared();

      const float scale = 1.0f/sqrt(sum[0] + pcs.eps);

      // Write the scaled row, using the same striding as the accumulation pass.
      for (uint col = lid; col < pcs.ne00; col += gl_WorkGroupSize.x) {
          out_[dstBase + col] = in_[srcBase + col] * scale;
      }
  }