rms_norm.comp 1.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142
  1. #version 450
  2. #include "generic_head.comp"
  3. #include "types.comp"
  4. #extension GL_EXT_control_flow_attributes : enable
  5. #define BLOCK_SIZE 512
  6. layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in;
  7. layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
  8. layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
  9. shared FLOAT_TYPE sum[BLOCK_SIZE];
  10. void main() {
  11. const uint row = gl_WorkGroupID.x;
  12. const uint tid = gl_LocalInvocationID.x;
  13. sum[tid] = FLOAT_TYPE(0.0f); // partial sum for thread in warp
  14. [[unroll]] for (uint col = tid; col < p.KX; col += BLOCK_SIZE) {
  15. const FLOAT_TYPE xi = FLOAT_TYPE(data_a[row*p.KX + col]);
  16. sum[tid] += xi * xi;
  17. }
  18. // sum up partial sums and write back result
  19. barrier();
  20. [[unroll]] for (int s = BLOCK_SIZE / 2; s > 0; s >>= 1) {
  21. if (tid < s) {
  22. sum[tid] += sum[tid + s];
  23. }
  24. barrier();
  25. }
  26. const FLOAT_TYPE mean = sum[0] / FLOAT_TYPE(p.KX);
  27. const FLOAT_TYPE scale = inversesqrt(mean + FLOAT_TYPE(p.param1));
  28. [[unroll]] for (uint col = tid; col < p.KX; col += BLOCK_SIZE) {
  29. data_d[row*p.KX + col] = D_TYPE(scale * FLOAT_TYPE(data_a[row*p.KX + col]));
  30. }
  31. }