|
@@ -1,5 +1,5 @@
|
|
|
#version 450
|
|
#version 450
|
|
|
-#extension GL_EXT_shader_explicit_arithmetic_types : require
|
|
|
|
|
|
|
+#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require
|
|
|
|
|
|
|
|
#include "mul_mat_vec_base.comp"
|
|
#include "mul_mat_vec_base.comp"
|
|
|
|
|
|
|
@@ -40,9 +40,9 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
|
|
|
|
|
|
|
|
[[unroll]] for (uint n = 0; n < num_rows; ++n) {
|
|
[[unroll]] for (uint n = 0; n < num_rows; ++n) {
|
|
|
const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row;
|
|
const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row;
|
|
|
- f16vec2 d = data_a[ib0 + i].d;
|
|
|
|
|
- const FLOAT_TYPE dall = d.x;
|
|
|
|
|
- const FLOAT_TYPE dmin = d.y;
|
|
|
|
|
|
|
+ vec2 d = vec2(data_a[ib0 + i].d);
|
|
|
|
|
+ const FLOAT_TYPE dall = FLOAT_TYPE(d.x);
|
|
|
|
|
+ const FLOAT_TYPE dmin = FLOAT_TYPE(d.y);
|
|
|
|
|
|
|
|
uint32_t s0_u32 = data_a_packed32[ib0 + i].scales[s_offset / 4 + 0];
|
|
uint32_t s0_u32 = data_a_packed32[ib0 + i].scales[s_offset / 4 + 0];
|
|
|
uint32_t s4_u32 = data_a_packed32[ib0 + i].scales[s_offset / 4 + 1];
|
|
uint32_t s4_u32 = data_a_packed32[ib0 + i].scales[s_offset / 4 + 1];
|
|
@@ -63,14 +63,14 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
|
|
|
uvec2 qs16 = uvec2(unpack8(qs16_u16));
|
|
uvec2 qs16 = uvec2(unpack8(qs16_u16));
|
|
|
|
|
|
|
|
[[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
|
|
[[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
|
|
|
- B_TYPE_VEC2 b0 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 0];
|
|
|
|
|
- B_TYPE_VEC2 b16 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 8];
|
|
|
|
|
- B_TYPE_VEC2 b32 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 16];
|
|
|
|
|
- B_TYPE_VEC2 b48 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 24];
|
|
|
|
|
- B_TYPE_VEC2 b64 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 32];
|
|
|
|
|
- B_TYPE_VEC2 b80 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 40];
|
|
|
|
|
- B_TYPE_VEC2 b96 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 48];
|
|
|
|
|
- B_TYPE_VEC2 b112 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 56];
|
|
|
|
|
|
|
+ vec2 b0 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 0]);
|
|
|
|
|
+ vec2 b16 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 8]);
|
|
|
|
|
+ vec2 b32 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 16]);
|
|
|
|
|
+ vec2 b48 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 24]);
|
|
|
|
|
+ vec2 b64 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 32]);
|
|
|
|
|
+ vec2 b80 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 40]);
|
|
|
|
|
+ vec2 b96 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 48]);
|
|
|
|
|
+ vec2 b112 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 56]);
|
|
|
|
|
|
|
|
FLOAT_TYPE sum1 = FLOAT_TYPE(0.0);
|
|
FLOAT_TYPE sum1 = FLOAT_TYPE(0.0);
|
|
|
FLOAT_TYPE sum2 = FLOAT_TYPE(0.0);
|
|
FLOAT_TYPE sum2 = FLOAT_TYPE(0.0);
|