// mul_mat_vec_base.comp
  1. #extension GL_EXT_control_flow_attributes : enable
  2. #extension GL_EXT_shader_16bit_storage : require
  3. #extension GL_EXT_shader_8bit_storage : require
  4. #define K_QUANTS_PER_ITERATION 2
  5. #ifdef MUL_MAT_ID
  6. #define EXPERT_COUNT 8
  7. #endif
  8. #include "types.comp"
  9. layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
  10. layout (binding = 1) readonly buffer B {B_TYPE data_b[];};
  11. layout (binding = 2) writeonly buffer D {D_TYPE data_d[];};
  12. #ifdef MUL_MAT_ID
  13. layout (binding = 3) readonly buffer IDS {int data_ids[];};
  14. #endif
  15. #include "dequant_funcs.comp"
  16. layout (push_constant) uniform parameter
  17. {
  18. uint ncols;
  19. uint stride_a;
  20. uint stride_b;
  21. uint stride_d;
  22. uint batch_stride_a;
  23. uint batch_stride_b;
  24. uint batch_stride_d;
  25. #ifdef MUL_MAT_ID
  26. uint nei0;
  27. uint ne11;
  28. #else
  29. uint ne02;
  30. uint ne12;
  31. uint broadcast2;
  32. uint broadcast3;
  33. #endif
  34. } p;
  35. void get_offsets(out uint a_offset, out uint b_offset, out uint d_offset) {
  36. #ifdef MUL_MAT_ID
  37. const uint expert_idx = gl_GlobalInvocationID.y;
  38. #else
  39. const uint batch_idx = gl_GlobalInvocationID.y;
  40. #endif
  41. #ifndef MUL_MAT_ID
  42. const uint i13 = batch_idx / p.ne12;
  43. const uint i12 = batch_idx % p.ne12;
  44. const uint i03 = i13 / p.broadcast3;
  45. const uint i02 = i12 / p.broadcast2;
  46. const uint batch_idx_a = i03 * p.ne02 + i02;
  47. #else
  48. const uint expert_id = data_ids[expert_idx];
  49. #endif
  50. a_offset =
  51. #ifdef MUL_MAT_ID
  52. expert_id * p.batch_stride_a;
  53. #else
  54. batch_idx_a * p.batch_stride_a;
  55. #endif
  56. b_offset =
  57. #ifdef MUL_MAT_ID
  58. (expert_idx % p.ne11) * p.stride_b;
  59. #else
  60. batch_idx * p.batch_stride_b;
  61. #endif
  62. d_offset =
  63. #ifdef MUL_MAT_ID
  64. expert_idx * p.stride_d;
  65. #else
  66. batch_idx * p.batch_stride_d;
  67. #endif
  68. }