Просмотр исходного кода

opencl: add kernel to handle mat mul in attention to improve encoding speed (#17181)

* Add mul_mm_f16_f32_kq_kqv kernel

* Add ggml_cl_mul_mat_kq_kqv_adreno func

* fix whitespace

* remove unused variable

* remove redundant

* refactor and clean up

* remove trailing whitespace
shaofeiqi 2 месяцев назад
Родитель
Сommit
4db5641210

+ 1 - 0
ggml/src/ggml-opencl/CMakeLists.txt

@@ -119,6 +119,7 @@ set(GGML_OPENCL_KERNELS
     pad
     repeat
     mul_mat_f16_f32
+    mul_mm_f16_f32_kq_kqv
     conv2d
     conv2d_f16_f32
     flash_attn_f32_f16

+ 170 - 0
ggml/src/ggml-opencl/ggml-opencl.cpp

@@ -407,6 +407,8 @@ struct ggml_backend_opencl_context {
     cl_program program_mul_mv_f32_f32;
     cl_program program_mul;
     cl_program program_mul_mat_f16_f32_tiled;
+    cl_program program_mul_mm_f16_f32_kqv;
+    cl_program program_mul_mm_f16_f32_kq;
     cl_program program_div;
     cl_program program_sub;
     cl_program program_norm;
@@ -481,6 +483,8 @@ struct ggml_backend_opencl_context {
     cl_kernel kernel_mul_mat_f16_f32;
     cl_kernel kernel_mul_mat_f16_f32_l4;
     cl_kernel kernel_mul_mat_f16_f32_tiled;
+    cl_kernel kernel_mul_mm_f16_f32_kqv;
+    cl_kernel kernel_mul_mm_f16_f32_kq;
     cl_kernel kernel_mul_mat_q4_0_f32, kernel_mul_mat_q4_0_f32_v;
     cl_kernel kernel_convert_block_q4_0, kernel_restore_block_q4_0;
     cl_kernel kernel_convert_block_mxfp4, kernel_convert_block_mxfp4_trans, kernel_restore_block_mxfp4, kernel_restore_block_mxfp4_trans;
@@ -1235,6 +1239,25 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve
         GGML_LOG_CONT(".");
     }
 
+    // mul_mm_f16_f32_kq_kqv
+    {
+#ifdef GGML_OPENCL_EMBED_KERNELS
+        const std::string kernel_src {
+            #include "mul_mm_f16_f32_kq_kqv.cl.h"
+        };
+#else
+        const std::string kernel_src = read_file("mul_mm_f16_f32_kq_kqv.cl");
+#endif
+        backend_ctx->program_mul_mm_f16_f32_kqv =
+            build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts+" -DKQV ");
+        backend_ctx->program_mul_mm_f16_f32_kq =
+            build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);
+
+        CL_CHECK((backend_ctx->kernel_mul_mm_f16_f32_kqv = clCreateKernel(backend_ctx->program_mul_mm_f16_f32_kqv, "mul_mm_f16_f32_kqv", &err), err));
+        CL_CHECK((backend_ctx->kernel_mul_mm_f16_f32_kq = clCreateKernel(backend_ctx->program_mul_mm_f16_f32_kq, "mul_mm_f16_f32_kq", &err), err));
+        GGML_LOG_CONT(".");
+    }
+
     // mul
     {
 #ifdef GGML_OPENCL_EMBED_KERNELS
@@ -6665,6 +6688,146 @@ static void ggml_cl_conv_2d(ggml_backend_t backend, const ggml_tensor * src0, co
     backend_ctx->enqueue_ndrange_kernel(kernel, 2, global_work_size, local_work_size, dst);
 }
 
+static void ggml_cl_mul_mat_kq_kqv_adreno(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+    ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;
+
+    ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra;
+    ggml_tensor_extra_cl * extra1 = (ggml_tensor_extra_cl *)src1->extra;
+    ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra;
+
+    const int  ne00 = src0->ne[0];
+    const int  ne01 = src0->ne[1];
+    const int  ne02 = src0->ne[2];
+
+    const cl_ulong nb01 = src0->nb[1];
+    const cl_ulong nb02 = src0->nb[2];
+
+    const int  ne10 = src1->ne[0];
+    const int  ne11 = src1->ne[1];
+    const int  ne12 = src1->ne[2];
+
+    const cl_ulong nb10 = src1->nb[0];
+
+    const int  ne0 = dst->ne[0];
+    const int  ne1 = dst->ne[1];
+
+    GGML_ASSERT(ne00 == ne10);
+
+    cl_kernel kernel;
+    cl_context context = backend_ctx->context;
+
+    cl_int              status;
+    cl_image_format     img_fmt_1d;
+    cl_image_desc       img_desc_1d;
+    cl_buffer_region    region;
+    cl_mem              A_image1d;
+    cl_mem              A_sub_buffer;
+    cl_mem              B_sub_buffer;
+    cl_mem              D_image1d;
+    cl_mem              D_sub_buffer;
+
+    int M = ne01;
+    int N = ne1;
+    int K = ne00;
+
+    if (nb01 > nb02) {
+        // KQ
+        kernel = backend_ctx->kernel_mul_mm_f16_f32_kq;
+    } else {
+        // KQV
+        kernel = backend_ctx->kernel_mul_mm_f16_f32_kqv;
+    }
+    // create sub-buffer for A
+    // <--------------------------------------------> //
+    extra0 = src0->view_src ? (ggml_tensor_extra_cl *)src0->view_src->extra : (ggml_tensor_extra_cl *)src0->extra;
+
+    region.origin = (extra0->offset);
+    if (nb01 > nb02) {
+        // KQ
+        region.size = nb01 * ne01;
+    } else {
+        // KQV
+        region.size = nb02 * ne02;
+    }
+
+    A_sub_buffer = clCreateSubBuffer((extra0->data_device), 0, CL_BUFFER_CREATE_TYPE_REGION, &region, &status);
+    CL_CHECK(status);
+
+    // <--------------------------------------------> //
+
+    // create sub-buffer for B
+    // <--------------------------------------------> //
+    region.origin = (extra1->offset);
+    region.size = nb10 * ne10 * ne11 * ne12;
+    B_sub_buffer = clCreateSubBuffer((extra1->data_device), 0, CL_BUFFER_CREATE_TYPE_REGION, &region, &status);
+    CL_CHECK(status);
+    // <--------------------------------------------> //
+
+    img_fmt_1d = {CL_RGBA, CL_FLOAT};
+    memset(&img_desc_1d, 0, sizeof(img_desc_1d));
+    img_desc_1d.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER;
+    if (nb01 > nb02) {
+        img_desc_1d.image_width = (nb01 * ne01 / 4)/4;
+    }
+    else {
+        img_desc_1d.image_width = (nb02 * ne02 / 4)/4;
+    }
+    img_desc_1d.buffer = A_sub_buffer;
+    A_image1d = clCreateImage(context, CL_MEM_READ_ONLY, &img_fmt_1d, &img_desc_1d, NULL, &status);
+    CL_CHECK(status);
+
+    // create sub-buffer for output C
+    // <--------------------------------------------> //
+    region.origin = (extrad->offset);
+    region.size = ne0 * ne1 * dst->ne[2] * dst->nb[0]; // size of C in bytes
+    D_sub_buffer = clCreateSubBuffer((extrad->data_device), 0, CL_BUFFER_CREATE_TYPE_REGION, &region, &status);
+    CL_CHECK(status);
+    // <--------------------------------------------> //
+
+    // create image for C output
+    // <--------------------------------------------> //
+    img_fmt_1d = {CL_R, CL_FLOAT};
+    memset(&img_desc_1d, 0, sizeof(img_desc_1d));
+    img_desc_1d.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER;
+    img_desc_1d.image_width = ne0 * ne1 * dst->ne[2] * dst->nb[0] / 4;
+    img_desc_1d.buffer = D_sub_buffer;
+    D_image1d = clCreateImage(context, CL_MEM_WRITE_ONLY, &img_fmt_1d, &img_desc_1d, NULL, &status);
+    CL_CHECK(status);
+    // <--------------------------------------------> //
+
+    int offset_src0 = 0;
+    int offset_src1 = 0;
+
+    // set kernel args
+    // <--------------------------------------------> //
+    cl_uint k_arg = 0;
+    CL_CHECK(clSetKernelArg(kernel,  k_arg++, sizeof(cl_mem), &A_image1d));
+    CL_CHECK(clSetKernelArg(kernel,  k_arg++, sizeof(int),    &offset_src0));
+    CL_CHECK(clSetKernelArg(kernel,  k_arg++, sizeof(cl_mem), &B_sub_buffer));
+    CL_CHECK(clSetKernelArg(kernel,  k_arg++, sizeof(int),    &offset_src1));
+    CL_CHECK(clSetKernelArg(kernel,  k_arg++, sizeof(cl_mem), &D_image1d));
+    CL_CHECK(clSetKernelArg(kernel,  k_arg++, sizeof(int),    &extrad->offset));
+    CL_CHECK(clSetKernelArg(kernel,  k_arg++, sizeof(int),    &M));
+    CL_CHECK(clSetKernelArg(kernel,  k_arg++, sizeof(int),    &K));
+    CL_CHECK(clSetKernelArg(kernel,  k_arg++, sizeof(int),    &N));
+    CL_CHECK(clSetKernelArg(kernel,  k_arg++, sizeof(int),    &ne02));
+    CL_CHECK(clSetKernelArg(kernel,  k_arg++, sizeof(int),    &ne12));
+    CL_CHECK(clSetKernelArg(kernel,  k_arg++, sizeof(int),    &nb01));
+
+    size_t global_work_size[3] = {64, static_cast<size_t>(((M+63)/64)), static_cast<size_t>(((N+31)/32)*ne12)};
+    size_t local_work_size[3] = {64, 1, 2};
+
+    backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst);
+
+    // deallocate sub buffers and images
+    // <--------------------------------------------> //
+    CL_CHECK(clReleaseMemObject(A_image1d));
+    CL_CHECK(clReleaseMemObject(D_image1d));
+    CL_CHECK(clReleaseMemObject(A_sub_buffer));
+    CL_CHECK(clReleaseMemObject(B_sub_buffer));
+    CL_CHECK(clReleaseMemObject(D_sub_buffer));
+}
+
 static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     GGML_ASSERT(src0);
     GGML_ASSERT(src0->extra);
@@ -6731,6 +6894,13 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co
 #ifdef GGML_OPENCL_USE_ADRENO_KERNELS
     cl_context context = backend_ctx->context;
 
+    if(src0t == GGML_TYPE_F16 && src1t == GGML_TYPE_F32){
+        if (ne01 >= 64 && ne1 >= 32 && ne00 >= 16 && (ne12 % ne02) == 0){
+            ggml_cl_mul_mat_kq_kqv_adreno(backend, src0, src1, dst);
+            return;
+        }
+    }
+
     if (ne01 && ne1 && use_adreno_kernels(backend_ctx, src0)) {
 
     // init CL objects

+ 273 - 0
ggml/src/ggml-opencl/kernels/mul_mm_f16_f32_kq_kqv.cl

@@ -0,0 +1,273 @@
+#pragma OPENCL EXTENSION cl_khr_fp16 : enable
+#pragma OPENCL EXTENSION cl_khr_subgroups : enable
+
+#define LM_FIRST_256B   0
+#define LM_SECOND_256B  64
+#define LM_THIRD_256B   128
+#define LM_FOURTH_256B  192
+
+
+inline float16 mm_load_a(
+    image1d_buffer_t matrix_A,
+    uint subMatrixAStartInElements,
+    int nb01,
+    int line_stride_matrix_A_in_bytes
+) {
+    __private float8 regA;
+    size_t sub_block_id_m = get_local_id(0);
+
+#ifdef KQV
+    uint a_texCoord = subMatrixAStartInElements/2 + (sub_block_id_m * nb01/4);
+#else // KQ
+    uint a_texCoord = subMatrixAStartInElements/2 + (sub_block_id_m * line_stride_matrix_A_in_bytes/4);
+#endif
+
+    regA.s0123  = read_imagef(matrix_A, a_texCoord/4);
+    regA.s4567  = read_imagef(matrix_A, (a_texCoord+4)/4);
+
+    return convert_float16(as_half16(regA));
+}
+
+inline float4 alu_32(
+    float16 regA,
+    __local float4* matrix_B_vec
+) {
+
+    __private float4 rC = 0;
+    int i = get_sub_group_id() * 64;
+
+    rC += regA.s0  * matrix_B_vec[i];
+    rC += regA.s1  * matrix_B_vec[i + 16];
+    rC += regA.s4  * matrix_B_vec[i + 1];
+    rC += regA.s5  * matrix_B_vec[i + 17];
+    rC += regA.s8  * matrix_B_vec[i + 2];
+    rC += regA.s9  * matrix_B_vec[i + 18];
+    rC += regA.sc  * matrix_B_vec[i + 3];
+    rC += regA.sd  * matrix_B_vec[i + 19];
+
+    i += 32;
+
+    rC += regA.s2  * matrix_B_vec[i];
+     rC += regA.s3  * matrix_B_vec[i + 16];
+    rC += regA.s6  * matrix_B_vec[i + 1];
+    rC += regA.s7  * matrix_B_vec[i + 17];
+    rC += regA.sa  * matrix_B_vec[i + 2];
+    rC += regA.sb  * matrix_B_vec[i + 18];
+    rC += regA.se  * matrix_B_vec[i + 3];
+    rC += regA.sf  * matrix_B_vec[i + 19];
+
+    return rC;
+}
+
+inline float16 alu_16(
+    float16 regA,
+    __local float* matrix_B_local
+) {
+    float16 out;
+    __local float4* matrix_B_vec = (__local float4*)matrix_B_local;
+
+    out.s0123 = alu_32(regA, matrix_B_vec);
+    out.s4567 = alu_32(regA, matrix_B_vec + 4);
+    out.s89ab = alu_32(regA, matrix_B_vec + 8);
+    out.scdef = alu_32(regA, matrix_B_vec + 12);
+
+    return out;
+}
+
+inline void mm_mad(
+    __local float* matrix_B_local,
+    float16 regA,
+    float8 regB,
+    uint b_localOffsetInWords,
+    float16* regC0_ptr,
+    float16* regC1_ptr
+) {
+    int offset = b_localOffsetInWords + get_sub_group_id() * 256;
+
+    matrix_B_local[offset + LM_FIRST_256B] = regB.s0;
+    matrix_B_local[offset + LM_SECOND_256B] = regB.s1;
+    matrix_B_local[offset + LM_THIRD_256B] = regB.s2;
+    matrix_B_local[offset + LM_FOURTH_256B] = regB.s3;
+
+    float16 add0 = alu_16(regA, matrix_B_local);
+    *regC0_ptr += add0;
+
+    matrix_B_local[offset + LM_FIRST_256B] = regB.s4;
+    matrix_B_local[offset + LM_SECOND_256B] = regB.s5;
+    matrix_B_local[offset + LM_THIRD_256B] = regB.s6;
+    matrix_B_local[offset + LM_FOURTH_256B] = regB.s7;
+
+    float16 add1 = alu_16(regA, matrix_B_local);
+    *regC1_ptr += add1;
+}
+
+inline void mm_store_c_N(
+    __write_only image1d_buffer_t matrix_C,
+    float16 regC0,
+    float16 regC1,
+    uint subMatrixCStartInElements,
+    int line_stride_matrix_C_in_bytes,
+    int mask
+) {
+    size_t sub_block_id_m = get_local_id(0);
+
+    uint strideInWords     = line_stride_matrix_C_in_bytes/4;
+    uint c_coordInWords_0  = (subMatrixCStartInElements + sub_block_id_m);
+
+    uint c_coordInWords_1  = c_coordInWords_0 + 1  * strideInWords;
+    uint c_coordInWords_2  = c_coordInWords_0 + 2  * strideInWords;
+    uint c_coordInWords_3  = c_coordInWords_0 + 3  * strideInWords;
+    uint c_coordInWords_4  = c_coordInWords_0 + 4  * strideInWords;
+    uint c_coordInWords_5  = c_coordInWords_0 + 5  * strideInWords;
+    uint c_coordInWords_6  = c_coordInWords_0 + 6  * strideInWords;
+    uint c_coordInWords_7  = c_coordInWords_0 + 7  * strideInWords;
+    uint c_coordInWords_8  = c_coordInWords_0 + 8  * strideInWords;
+    uint c_coordInWords_9  = c_coordInWords_0 + 9  * strideInWords;
+    uint c_coordInWords_10 = c_coordInWords_0 + 10 * strideInWords;
+    uint c_coordInWords_11 = c_coordInWords_0 + 11 * strideInWords;
+    uint c_coordInWords_12 = c_coordInWords_0 + 12 * strideInWords;
+    uint c_coordInWords_13 = c_coordInWords_0 + 13 * strideInWords;
+    uint c_coordInWords_14 = c_coordInWords_0 + 14 * strideInWords;
+    uint c_coordInWords_15 = c_coordInWords_0 + 15 * strideInWords;
+    uint c_coordInWords_16 = c_coordInWords_0 + 16 * strideInWords;
+    uint c_coordInWords_17 = c_coordInWords_0 + 17 * strideInWords;
+    uint c_coordInWords_18 = c_coordInWords_0 + 18 * strideInWords;
+    uint c_coordInWords_19 = c_coordInWords_0 + 19 * strideInWords;
+    uint c_coordInWords_20 = c_coordInWords_0 + 20 * strideInWords;
+    uint c_coordInWords_21 = c_coordInWords_0 + 21 * strideInWords;
+    uint c_coordInWords_22 = c_coordInWords_0 + 22 * strideInWords;
+    uint c_coordInWords_23 = c_coordInWords_0 + 23 * strideInWords;
+    uint c_coordInWords_24 = c_coordInWords_0 + 24 * strideInWords;
+    uint c_coordInWords_25 = c_coordInWords_0 + 25 * strideInWords;
+    uint c_coordInWords_26 = c_coordInWords_0 + 26 * strideInWords;
+    uint c_coordInWords_27 = c_coordInWords_0 + 27 * strideInWords;
+    uint c_coordInWords_28 = c_coordInWords_0 + 28 * strideInWords;
+    uint c_coordInWords_29 = c_coordInWords_0 + 29 * strideInWords;
+    uint c_coordInWords_30 = c_coordInWords_0 + 30 * strideInWords;
+    uint c_coordInWords_31 = c_coordInWords_0 + 31 * strideInWords;
+
+    if (mask > 0)  { write_imagef(matrix_C, c_coordInWords_0, regC0.s0);  }
+    if (mask > 1)  { write_imagef(matrix_C, c_coordInWords_1, regC0.s1);  }
+    if (mask > 2)  { write_imagef(matrix_C, c_coordInWords_2, regC0.s2);  }
+    if (mask > 3)  { write_imagef(matrix_C, c_coordInWords_3, regC0.s3);  }
+    if (mask > 4)  { write_imagef(matrix_C, c_coordInWords_4, regC0.s4);  }
+    if (mask > 5)  { write_imagef(matrix_C, c_coordInWords_5, regC0.s5);  }
+    if (mask > 6)  { write_imagef(matrix_C, c_coordInWords_6, regC0.s6);  }
+    if (mask > 7)  { write_imagef(matrix_C, c_coordInWords_7, regC0.s7);  }
+    if (mask > 8)  { write_imagef(matrix_C, c_coordInWords_8, regC0.s8);  }
+    if (mask > 9)  { write_imagef(matrix_C, c_coordInWords_9, regC0.s9);  }
+    if (mask > 10) { write_imagef(matrix_C, c_coordInWords_10, regC0.sa); }
+    if (mask > 11) { write_imagef(matrix_C, c_coordInWords_11, regC0.sb); }
+    if (mask > 12) { write_imagef(matrix_C, c_coordInWords_12, regC0.sc); }
+    if (mask > 13) { write_imagef(matrix_C, c_coordInWords_13, regC0.sd); }
+    if (mask > 14) { write_imagef(matrix_C, c_coordInWords_14, regC0.se); }
+    if (mask > 15) { write_imagef(matrix_C, c_coordInWords_15, regC0.sf); }
+    if (mask > 16) { write_imagef(matrix_C, c_coordInWords_16, regC1.s0); }
+    if (mask > 17) { write_imagef(matrix_C, c_coordInWords_17, regC1.s1); }
+    if (mask > 18) { write_imagef(matrix_C, c_coordInWords_18, regC1.s2); }
+    if (mask > 19) { write_imagef(matrix_C, c_coordInWords_19, regC1.s3); }
+    if (mask > 20) { write_imagef(matrix_C, c_coordInWords_20, regC1.s4); }
+    if (mask > 21) { write_imagef(matrix_C, c_coordInWords_21, regC1.s5); }
+    if (mask > 22) { write_imagef(matrix_C, c_coordInWords_22, regC1.s6); }
+    if (mask > 23) { write_imagef(matrix_C, c_coordInWords_23, regC1.s7); }
+    if (mask > 24) { write_imagef(matrix_C, c_coordInWords_24, regC1.s8); }
+    if (mask > 25) { write_imagef(matrix_C, c_coordInWords_25, regC1.s9); }
+    if (mask > 26) { write_imagef(matrix_C, c_coordInWords_26, regC1.sa); }
+    if (mask > 27) { write_imagef(matrix_C, c_coordInWords_27, regC1.sb); }
+    if (mask > 28) { write_imagef(matrix_C, c_coordInWords_28, regC1.sc); }
+    if (mask > 29) { write_imagef(matrix_C, c_coordInWords_29, regC1.sd); }
+    if (mask > 30) { write_imagef(matrix_C, c_coordInWords_30, regC1.se); }
+    if (mask > 31) { write_imagef(matrix_C, c_coordInWords_31, regC1.sf); }
+}
+
+#define TILESIZE_K 16
+#define TILESIZE_M 64
+#define TILESIZE_N 32
+#ifdef KQV
+__kernel void mul_mm_f16_f32_kqv(
+#else
+__kernel void mul_mm_f16_f32_kq(
+#endif
+        __read_only  image1d_buffer_t matrix_A,
+        int offset0,
+        __global float* matrix_B,
+        int offset1,
+        __write_only image1d_buffer_t matrix_C,
+        int offsetd,
+        int M, int K, int N,
+        int D_A,
+        int D_B,
+        int nb01
+) {
+
+    uint block_id_m = get_global_id(1);
+    uint block_id_n = get_global_id(2) % ((N+TILESIZE_N-1)/TILESIZE_N);
+    uint block_id_d = get_global_id(2) / ((N+TILESIZE_N-1)/TILESIZE_N);
+
+    __private float16  regA;
+    __private float8   regB;
+    __private float16 regC0;
+    __private float16 regC1;
+
+    const uint col   = block_id_m * TILESIZE_M;
+    const uint row   = block_id_n * TILESIZE_N;
+    const uint depth_A = block_id_d / (D_B/D_A);
+    const uint depth_B = block_id_d;
+
+#ifdef KQV
+    int line_stride_matrix_A_in_bytes = nb01 * M;
+    int line_stride_matrix_B_in_bytes = K * N * 4;
+#else
+    int line_stride_matrix_A_in_bytes = K * D_A * 2;
+    int line_stride_matrix_B_in_bytes = K * D_B * 4;
+#endif
+
+    int line_stride_matrix_C_in_bytes = M * 4;
+
+    const uint strideAinElements = line_stride_matrix_A_in_bytes / 2;
+    const uint strideBinElements = line_stride_matrix_B_in_bytes / 4;
+
+    size_t sub_block_id_m = get_local_id(0);
+
+    uint b_localOffsetInWords = (sub_block_id_m/16)*16
+                           + ((((sub_block_id_m)>>0)&1)<<2)
+                           + ((((sub_block_id_m)>>1)&1)<<3)
+                           + ((((sub_block_id_m)>>2)&1)<<0)
+                           + ((((sub_block_id_m)>>3)&1)<<1);
+
+    uint2 b_globalOffsetInWords_xy = {((sub_block_id_m%4)*4), (sub_block_id_m>>2)};
+    uint b_globalOffsetInWords00, b_globalOffsetInWords16;
+#ifdef KQV
+    b_globalOffsetInWords00 = b_globalOffsetInWords_xy.x + b_globalOffsetInWords_xy.y*K;
+    b_globalOffsetInWords16 = b_globalOffsetInWords00 + (16 * K);
+    uint subMatrixAStartInElements = depth_A * strideAinElements + col * nb01 / 2;
+    uint subMatrixBStartInElements = depth_B * strideBinElements + row * K;
+#else
+    b_globalOffsetInWords00 = b_globalOffsetInWords_xy.x + b_globalOffsetInWords_xy.y*line_stride_matrix_B_in_bytes/4;
+    b_globalOffsetInWords16 = b_globalOffsetInWords00 + (16 * line_stride_matrix_B_in_bytes/4);
+    uint subMatrixAStartInElements = col * strideAinElements + depth_A * K;
+    uint subMatrixBStartInElements = row * strideBinElements + depth_B * K;
+#endif
+
+    __local float matrix_B_local[1024];
+
+    for (uint step=0; step < K; step+=TILESIZE_K) {
+        size_t sub_block_id_m = get_local_id(0);
+        regA = mm_load_a(matrix_A, subMatrixAStartInElements, nb01, line_stride_matrix_A_in_bytes);
+
+        uint b_coordInWords00 = subMatrixBStartInElements + b_globalOffsetInWords00;
+        uint b_coordInWords16 = subMatrixBStartInElements + b_globalOffsetInWords16;
+
+        regB.s0123 = vload4(b_coordInWords00/4, matrix_B);
+        regB.s4567 = vload4(b_coordInWords16/4, matrix_B);
+
+        mm_mad(matrix_B_local, regA, regB, b_localOffsetInWords, &regC0, &regC1);
+
+        subMatrixAStartInElements += TILESIZE_K;
+        subMatrixBStartInElements += TILESIZE_K;
+    }
+
+    uint subMatrixCStartInElements = depth_B * N * M + row * M + col;
+    mm_store_c_N(matrix_C, regC0, regC1, subMatrixCStartInElements, line_stride_matrix_C_in_bytes, (N-block_id_n*32));
+}
+