1
0

op_norm.comp 2.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384
  1. #version 450
  2. #include "common.comp"
  // 256 invocations per workgroup cooperate on a single row (see main()).
  layout(local_size_x = 256) in;

  // Input tensor, read-only, addressed as a flat array of floats.
  layout(binding = 0) buffer restrict readonly tensorIn { float in_[]; };
  // Output tensor, addressed as a flat array of floats.
  layout(binding = 1) buffer restrict tensorOut { float out_[]; };

  // Per-dispatch parameters. Member order and types must match what the host
  // pushes, so do not reorder them.
  layout(push_constant) uniform PushConstants {
      uint inOff;   // element offset added to every in_ index
      uint outOff;  // element offset added to every out_ index
      uint ne00;    // number of elements in each row being normalized
      uint nb01;    // stride between input rows; main() divides it by 4 (bytes -> floats)
      float eps;    // added to the variance before taking the square root
  } pcs;

  // Workgroup scratch for the tree reductions: one slot per invocation.
  shared float sum[gl_WorkGroupSize.x];
  // Normalizes one row per workgroup:
  //     out = (in - mean) / sqrt(variance + eps)
  // where mean and variance (divided by ne00, i.e. the biased variance) are
  // taken over the ne00 elements of the row. No scale/shift is applied.
  //
  // Row r = gl_WorkGroupID.x is read from in_ starting at element
  // r*nb01/4 + inOff and written contiguously to out_ starting at element
  // r*ne00 + outOff.
  //
  // Both reductions are a shared-memory tree over `sum`, which assumes
  // gl_WorkGroupSize.x is a power of two (256 as declared in this file).
  void main() {
      const uint lid = gl_LocalInvocationID.x;
      const uint x = (gl_WorkGroupID.x*pcs.nb01/4) + pcs.inOff; // Based from in_
      const uint y = (gl_WorkGroupID.x*pcs.ne00) + pcs.outOff;  // Based from out_

      // MEAN
      // Partial sum over this invocation's strided slice of the row. It is kept
      // in a register rather than updating shared memory every iteration.
      float partial = 0.0;
      for (uint i00 = lid; i00 < pcs.ne00; i00 += gl_WorkGroupSize.x) {
          partial += in_[x+i00];
      }
      sum[lid] = partial;
      memoryBarrierShared();
      barrier();
      // Tree reduction: afterwards sum[0] holds the row total.
      [[unroll]] for (uint i = gl_WorkGroupSize.x/2; i > 0; i /= 2) {
          if (lid < i) {
              sum[lid] += sum[lid + i];
          }
          memoryBarrierShared();
          barrier();
      }
      // Every invocation divides locally, so no separate broadcast step is needed.
      const float mean = sum[0] / float(pcs.ne00);

      // RECENTER + VARIANCE partials
      // Write the centered values and accumulate their squares in the same pass,
      // so out_ does not have to be read back from global memory.
      partial = 0.0;
      for (uint i00 = lid; i00 < pcs.ne00; i00 += gl_WorkGroupSize.x) {
          const float centered = in_[x+i00] - mean;
          out_[y+i00] = centered;
          partial += centered*centered;
      }
      // Write-after-read guard: all invocations must have finished reading
      // sum[0] (for `mean`) before sum[] is reused. Without this barrier,
      // invocation 0 could overwrite sum[0] while others are still reading it.
      barrier();
      sum[lid] = partial;
      memoryBarrierShared();
      barrier();
      [[unroll]] for (uint i = gl_WorkGroupSize.x/2; i > 0; i /= 2) {
          if (lid < i) {
              sum[lid] += sum[lid + i];
          }
          memoryBarrierShared();
          barrier();
      }
      const float variance = sum[0] / float(pcs.ne00);

      // NORMALIZE
      const float scale = 1.0f/sqrt(variance + pcs.eps);
      for (uint i00 = lid; i00 < pcs.ne00; i00 += gl_WorkGroupSize.x) {
          out_[y+i00] *= scale;
      }
  }