|
@@ -347,9 +347,9 @@ kernel void kernel_soft_max(
|
|
|
const int64_t i02 = (tgpig - i03*ne02*ne01) / ne01;
|
|
const int64_t i02 = (tgpig - i03*ne02*ne01) / ne01;
|
|
|
const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01);
|
|
const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01);
|
|
|
|
|
|
|
|
- device const float * psrc0 = src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
|
|
|
|
|
- device const float * pmask = src1 ? src1 + i01*ne00 : nullptr;
|
|
|
|
|
- device float * pdst = dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
|
|
|
|
|
|
|
+ device const float * psrc0 = src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
|
|
|
|
|
+ device const float * pmask = src1 != src0 ? src1 + i01*ne00 : nullptr;
|
|
|
|
|
+ device float * pdst = dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
|
|
|
|
|
|
|
|
// parallel max
|
|
// parallel max
|
|
|
float lmax = -INFINITY;
|
|
float lmax = -INFINITY;
|
|
@@ -385,7 +385,12 @@ kernel void kernel_soft_max(
|
|
|
pdst[i00] = exp_psrc0;
|
|
pdst[i00] = exp_psrc0;
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
|
|
+ // This barrier fixes a failing test
|
|
|
|
|
+ // ref: https://github.com/ggerganov/ggml/pull/621#discussion_r1425156335
|
|
|
|
|
+ threadgroup_barrier(mem_flags::mem_none);
|
|
|
|
|
+
|
|
|
float sum = simd_sum(lsum);
|
|
float sum = simd_sum(lsum);
|
|
|
|
|
+
|
|
|
if (ntg > N_SIMDWIDTH) {
|
|
if (ntg > N_SIMDWIDTH) {
|
|
|
if (sgitg == 0) {
|
|
if (sgitg == 0) {
|
|
|
buf[tiisg] = 0.0f;
|
|
buf[tiisg] = 0.0f;
|
|
@@ -428,9 +433,9 @@ kernel void kernel_soft_max_4(
|
|
|
const int64_t i02 = (tgpig - i03*ne02*ne01) / ne01;
|
|
const int64_t i02 = (tgpig - i03*ne02*ne01) / ne01;
|
|
|
const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01);
|
|
const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01);
|
|
|
|
|
|
|
|
- device const float4 * psrc4 = (device const float4 *)(src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
|
|
|
|
|
- device const float4 * pmask = src1 ? (device const float4 *)(src1 + i01*ne00) : nullptr;
|
|
|
|
|
- device float4 * pdst4 = (device float4 *)(dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
|
|
|
|
|
|
|
+ device const float4 * psrc4 = (device const float4 *)(src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
|
|
|
|
|
+ device const float4 * pmask = src1 != src0 ? (device const float4 *)(src1 + i01*ne00) : nullptr;
|
|
|
|
|
+ device float4 * pdst4 = (device float4 *)(dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
|
|
|
|
|
|
|
|
// parallel max
|
|
// parallel max
|
|
|
float4 lmax4 = -INFINITY;
|
|
float4 lmax4 = -INFINITY;
|
|
@@ -468,7 +473,13 @@ kernel void kernel_soft_max_4(
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
const float lsum = lsum4[0] + lsum4[1] + lsum4[2] + lsum4[3];
|
|
const float lsum = lsum4[0] + lsum4[1] + lsum4[2] + lsum4[3];
|
|
|
|
|
+
|
|
|
|
|
+ // This barrier fixes a failing test
|
|
|
|
|
+ // ref: https://github.com/ggerganov/ggml/pull/621#discussion_r1425156335
|
|
|
|
|
+ threadgroup_barrier(mem_flags::mem_none);
|
|
|
|
|
+
|
|
|
float sum = simd_sum(lsum);
|
|
float sum = simd_sum(lsum);
|
|
|
|
|
+
|
|
|
if (ntg > N_SIMDWIDTH) {
|
|
if (ntg > N_SIMDWIDTH) {
|
|
|
if (sgitg == 0) {
|
|
if (sgitg == 0) {
|
|
|
buf[tiisg] = 0.0f;
|
|
buf[tiisg] = 0.0f;
|
|
@@ -731,7 +742,7 @@ inline float block_q_n_dot_y(device const block_q5_1 * qb_curr, float sumy, thre
|
|
|
// guard against the number of rows not being divisible by
|
|
// guard against the number of rows not being divisible by
|
|
|
// N_DST, so this is another explicit assumption of the implementation.
|
|
// N_DST, so this is another explicit assumption of the implementation.
|
|
|
template<typename block_q_type, int nr, int nsg, int nw>
|
|
template<typename block_q_type, int nr, int nsg, int nw>
|
|
|
-void mul_vec_q_n_f32(
|
|
|
|
|
|
|
+void mul_vec_q_n_f32_impl(
|
|
|
device const void * src0,
|
|
device const void * src0,
|
|
|
device const float * src1,
|
|
device const float * src1,
|
|
|
device float * dst,
|
|
device float * dst,
|
|
@@ -813,7 +824,7 @@ kernel void kernel_mul_mv_q4_0_f32(
|
|
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
|
uint tiisg[[thread_index_in_simdgroup]],
|
|
uint tiisg[[thread_index_in_simdgroup]],
|
|
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
|
|
- mul_vec_q_n_f32<block_q4_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg);
|
|
|
|
|
|
|
+ mul_vec_q_n_f32_impl<block_q4_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg);
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
kernel void kernel_mul_mv_q4_1_f32(
|
|
kernel void kernel_mul_mv_q4_1_f32(
|
|
@@ -832,7 +843,7 @@ kernel void kernel_mul_mv_q4_1_f32(
|
|
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
|
uint tiisg[[thread_index_in_simdgroup]],
|
|
uint tiisg[[thread_index_in_simdgroup]],
|
|
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
|
|
- mul_vec_q_n_f32<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg);
|
|
|
|
|
|
|
+ mul_vec_q_n_f32_impl<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg);
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
kernel void kernel_mul_mv_q5_0_f32(
|
|
kernel void kernel_mul_mv_q5_0_f32(
|
|
@@ -851,7 +862,7 @@ kernel void kernel_mul_mv_q5_0_f32(
|
|
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
|
uint tiisg[[thread_index_in_simdgroup]],
|
|
uint tiisg[[thread_index_in_simdgroup]],
|
|
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
|
|
- mul_vec_q_n_f32<block_q5_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg);
|
|
|
|
|
|
|
+ mul_vec_q_n_f32_impl<block_q5_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg);
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
kernel void kernel_mul_mv_q5_1_f32(
|
|
kernel void kernel_mul_mv_q5_1_f32(
|
|
@@ -870,28 +881,28 @@ kernel void kernel_mul_mv_q5_1_f32(
|
|
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
|
uint tiisg[[thread_index_in_simdgroup]],
|
|
uint tiisg[[thread_index_in_simdgroup]],
|
|
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
|
|
- mul_vec_q_n_f32<block_q5_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg);
|
|
|
|
|
|
|
+ mul_vec_q_n_f32_impl<block_q5_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg);
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
#define NB_Q8_0 8
|
|
#define NB_Q8_0 8
|
|
|
|
|
|
|
|
-kernel void kernel_mul_mv_q8_0_f32(
|
|
|
|
|
|
|
+void kernel_mul_mv_q8_0_f32_impl(
|
|
|
device const void * src0,
|
|
device const void * src0,
|
|
|
device const float * src1,
|
|
device const float * src1,
|
|
|
device float * dst,
|
|
device float * dst,
|
|
|
constant int64_t & ne00,
|
|
constant int64_t & ne00,
|
|
|
- constant int64_t & ne01[[buffer(4)]],
|
|
|
|
|
- constant int64_t & ne02[[buffer(5)]],
|
|
|
|
|
- constant int64_t & ne10[[buffer(9)]],
|
|
|
|
|
- constant int64_t & ne12[[buffer(11)]],
|
|
|
|
|
- constant int64_t & ne0 [[buffer(15)]],
|
|
|
|
|
- constant int64_t & ne1 [[buffer(16)]],
|
|
|
|
|
- constant uint & r2 [[buffer(17)]],
|
|
|
|
|
- constant uint & r3 [[buffer(18)]],
|
|
|
|
|
|
|
+ constant int64_t & ne01,
|
|
|
|
|
+ constant int64_t & ne02,
|
|
|
|
|
+ constant int64_t & ne10,
|
|
|
|
|
+ constant int64_t & ne12,
|
|
|
|
|
+ constant int64_t & ne0,
|
|
|
|
|
+ constant int64_t & ne1,
|
|
|
|
|
+ constant uint & r2,
|
|
|
|
|
+ constant uint & r3,
|
|
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
|
- uint tiisg[[thread_index_in_simdgroup]],
|
|
|
|
|
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
|
|
|
|
|
|
+ uint tiisg[[thread_index_in_simdgroup]],
|
|
|
|
|
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
|
|
const int nr = N_DST;
|
|
const int nr = N_DST;
|
|
|
const int nsg = N_SIMDGROUP;
|
|
const int nsg = N_SIMDGROUP;
|
|
|
const int nw = N_SIMDWIDTH;
|
|
const int nw = N_SIMDWIDTH;
|
|
@@ -945,9 +956,29 @@ kernel void kernel_mul_mv_q8_0_f32(
|
|
|
}
|
|
}
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
|
|
+// Kernel entry point for the Q8_0 x F32 matrix-vector product.
+// Thin wrapper: forwards every argument unchanged to kernel_mul_mv_q8_0_f32_impl.
+//
+// The scalar arguments carry explicit [[buffer(N)]] slots. This parameter list
+// omits the nb0x / ne11 / nb1x arguments that the f32/f16 mul_mv entry points
+// declare between ne02 and ne0, so an argument's position in this list does not
+// match its slot number. The slots used here are the same ones declared by the
+// kernel_mul_mv_q3_K_f32 and kernel_mul_mv_q6_K_f32 entry points.
+// NOTE(review): assumes the host binds one slot layout for all mul_mv kernels - confirm.
+[[host_name("kernel_mul_mv_q8_0_f32")]]
+kernel void kernel_mul_mv_q8_0_f32(
+        device const  void * src0,
+        device const float * src1,
+        device       float * dst,
+        constant   int64_t & ne00,
+        constant   int64_t & ne01 [[buffer(4)]],
+        constant   int64_t & ne02 [[buffer(5)]],
+        constant   int64_t & ne10 [[buffer(9)]],
+        constant   int64_t & ne12 [[buffer(11)]],
+        constant   int64_t & ne0  [[buffer(15)]],
+        constant   int64_t & ne1  [[buffer(16)]],
+        constant   uint    & r2   [[buffer(17)]],
+        constant   uint    & r3   [[buffer(18)]],
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint  tiisg[[thread_index_in_simdgroup]],
+        uint  sgitg[[simdgroup_index_in_threadgroup]]) {
+    kernel_mul_mv_q8_0_f32_impl(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg);
+}
|
|
|
|
|
+
|
|
|
#define N_F32_F32 4
|
|
#define N_F32_F32 4
|
|
|
|
|
|
|
|
-kernel void kernel_mul_mv_f32_f32(
|
|
|
|
|
|
|
+void kernel_mul_mv_f32_f32_impl(
|
|
|
device const char * src0,
|
|
device const char * src0,
|
|
|
device const char * src1,
|
|
device const char * src1,
|
|
|
device float * dst,
|
|
device float * dst,
|
|
@@ -965,8 +996,8 @@ kernel void kernel_mul_mv_f32_f32(
|
|
|
constant uint64_t & nb12,
|
|
constant uint64_t & nb12,
|
|
|
constant int64_t & ne0,
|
|
constant int64_t & ne0,
|
|
|
constant int64_t & ne1,
|
|
constant int64_t & ne1,
|
|
|
- constant uint & r2 [[buffer(17)]],
|
|
|
|
|
- constant uint & r3 [[buffer(18)]],
|
|
|
|
|
|
|
+ constant uint & r2,
|
|
|
|
|
+ constant uint & r3,
|
|
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
|
uint tiisg[[thread_index_in_simdgroup]]) {
|
|
uint tiisg[[thread_index_in_simdgroup]]) {
|
|
|
|
|
|
|
@@ -1025,6 +1056,32 @@ kernel void kernel_mul_mv_f32_f32(
|
|
|
}
|
|
}
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
|
|
+// Kernel entry point for the F32 x F32 matrix-vector product.
+// Thin wrapper: hands all arguments, in declaration order, to
+// kernel_mul_mv_f32_f32_impl. Only r2 and r3 carry explicit buffer
+// slots (17 and 18); the remaining arguments are declared without one.
+[[host_name("kernel_mul_mv_f32_f32")]]
+kernel void kernel_mul_mv_f32_f32(
+        device const  char * src0,
+        device const  char * src1,
+        device       float * dst,
+        constant   int64_t & ne00,
+        constant   int64_t & ne01,
+        constant   int64_t & ne02,
+        constant  uint64_t & nb00,
+        constant  uint64_t & nb01,
+        constant  uint64_t & nb02,
+        constant   int64_t & ne10,
+        constant   int64_t & ne11,
+        constant   int64_t & ne12,
+        constant  uint64_t & nb10,
+        constant  uint64_t & nb11,
+        constant  uint64_t & nb12,
+        constant   int64_t & ne0,
+        constant   int64_t & ne1,
+        constant   uint    & r2 [[buffer(17)]],
+        constant   uint    & r3 [[buffer(18)]],
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint  tiisg[[thread_index_in_simdgroup]]) {
+    kernel_mul_mv_f32_f32_impl(
+        src0, src1, dst,
+        ne00, ne01, ne02, nb00, nb01, nb02,
+        ne10, ne11, ne12, nb10, nb11, nb12,
+        ne0, ne1, r2, r3,
+        tgpig, tiisg);
+}
|
|
|
|
|
+
|
|
|
#define N_F16_F16 4
|
|
#define N_F16_F16 4
|
|
|
|
|
|
|
|
kernel void kernel_mul_mv_f16_f16(
|
|
kernel void kernel_mul_mv_f16_f16(
|
|
@@ -1105,7 +1162,7 @@ kernel void kernel_mul_mv_f16_f16(
|
|
|
}
|
|
}
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
-kernel void kernel_mul_mv_f16_f32_1row(
|
|
|
|
|
|
|
+void kernel_mul_mv_f16_f32_1row_impl(
|
|
|
device const char * src0,
|
|
device const char * src0,
|
|
|
device const char * src1,
|
|
device const char * src1,
|
|
|
device float * dst,
|
|
device float * dst,
|
|
@@ -1123,8 +1180,8 @@ kernel void kernel_mul_mv_f16_f32_1row(
|
|
|
constant uint64_t & nb12,
|
|
constant uint64_t & nb12,
|
|
|
constant int64_t & ne0,
|
|
constant int64_t & ne0,
|
|
|
constant int64_t & ne1,
|
|
constant int64_t & ne1,
|
|
|
- constant uint & r2 [[buffer(17)]],
|
|
|
|
|
- constant uint & r3 [[buffer(18)]],
|
|
|
|
|
|
|
+ constant uint & r2,
|
|
|
|
|
+ constant uint & r3,
|
|
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
|
uint tiisg[[thread_index_in_simdgroup]]) {
|
|
uint tiisg[[thread_index_in_simdgroup]]) {
|
|
|
|
|
|
|
@@ -1161,12 +1218,37 @@ kernel void kernel_mul_mv_f16_f32_1row(
|
|
|
dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
|
|
dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
|
|
|
}
|
|
}
|
|
|
}
|
|
}
|
|
|
|
|
+}
|
|
|
|
|
|
|
|
|
|
+// Kernel entry point for the single-row F16 x F32 matrix-vector product.
+// Thin wrapper: hands all arguments, in declaration order, to
+// kernel_mul_mv_f16_f32_1row_impl. Only r2 and r3 carry explicit buffer
+// slots (17 and 18); the remaining arguments are declared without one.
+[[host_name("kernel_mul_mv_f16_f32_1row")]]
+kernel void kernel_mul_mv_f16_f32_1row(
+        device const  char * src0,
+        device const  char * src1,
+        device       float * dst,
+        constant   int64_t & ne00,
+        constant   int64_t & ne01,
+        constant   int64_t & ne02,
+        constant  uint64_t & nb00,
+        constant  uint64_t & nb01,
+        constant  uint64_t & nb02,
+        constant   int64_t & ne10,
+        constant   int64_t & ne11,
+        constant   int64_t & ne12,
+        constant  uint64_t & nb10,
+        constant  uint64_t & nb11,
+        constant  uint64_t & nb12,
+        constant   int64_t & ne0,
+        constant   int64_t & ne1,
+        constant   uint    & r2 [[buffer(17)]],
+        constant   uint    & r3 [[buffer(18)]],
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint  tiisg[[thread_index_in_simdgroup]]) {
+    kernel_mul_mv_f16_f32_1row_impl(
+        src0, src1, dst,
+        ne00, ne01, ne02, nb00, nb01, nb02,
+        ne10, ne11, ne12, nb10, nb11, nb12,
+        ne0, ne1, r2, r3,
+        tgpig, tiisg);
}
|
|
|
|
|
|
|
|
#define N_F16_F32 4
|
|
#define N_F16_F32 4
|
|
|
|
|
|
|
|
-kernel void kernel_mul_mv_f16_f32(
|
|
|
|
|
|
|
+void kernel_mul_mv_f16_f32_impl(
|
|
|
device const char * src0,
|
|
device const char * src0,
|
|
|
device const char * src1,
|
|
device const char * src1,
|
|
|
device float * dst,
|
|
device float * dst,
|
|
@@ -1184,8 +1266,8 @@ kernel void kernel_mul_mv_f16_f32(
|
|
|
constant uint64_t & nb12,
|
|
constant uint64_t & nb12,
|
|
|
constant int64_t & ne0,
|
|
constant int64_t & ne0,
|
|
|
constant int64_t & ne1,
|
|
constant int64_t & ne1,
|
|
|
- constant uint & r2 [[buffer(17)]],
|
|
|
|
|
- constant uint & r3 [[buffer(18)]],
|
|
|
|
|
|
|
+ constant uint & r2,
|
|
|
|
|
+ constant uint & r3,
|
|
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
|
uint tiisg[[thread_index_in_simdgroup]]) {
|
|
uint tiisg[[thread_index_in_simdgroup]]) {
|
|
|
|
|
|
|
@@ -1244,6 +1326,32 @@ kernel void kernel_mul_mv_f16_f32(
|
|
|
}
|
|
}
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
|
|
+// Kernel entry point for the F16 x F32 matrix-vector product.
+// Thin wrapper: hands all arguments, in declaration order, to
+// kernel_mul_mv_f16_f32_impl. Only r2 and r3 carry explicit buffer
+// slots (17 and 18); the remaining arguments are declared without one.
+[[host_name("kernel_mul_mv_f16_f32")]]
+kernel void kernel_mul_mv_f16_f32(
+        device const  char * src0,
+        device const  char * src1,
+        device       float * dst,
+        constant   int64_t & ne00,
+        constant   int64_t & ne01,
+        constant   int64_t & ne02,
+        constant  uint64_t & nb00,
+        constant  uint64_t & nb01,
+        constant  uint64_t & nb02,
+        constant   int64_t & ne10,
+        constant   int64_t & ne11,
+        constant   int64_t & ne12,
+        constant  uint64_t & nb10,
+        constant  uint64_t & nb11,
+        constant  uint64_t & nb12,
+        constant   int64_t & ne0,
+        constant   int64_t & ne1,
+        constant   uint    & r2 [[buffer(17)]],
+        constant   uint    & r3 [[buffer(18)]],
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint  tiisg[[thread_index_in_simdgroup]]) {
+    kernel_mul_mv_f16_f32_impl(
+        src0, src1, dst,
+        ne00, ne01, ne02, nb00, nb01, nb02,
+        ne10, ne11, ne12, nb10, nb11, nb12,
+        ne0, ne1, r2, r3,
+        tgpig, tiisg);
+}
|
|
|
|
|
+
|
|
|
// Assumes row size (ne00) is a multiple of 4
|
|
// Assumes row size (ne00) is a multiple of 4
|
|
|
kernel void kernel_mul_mv_f16_f32_l4(
|
|
kernel void kernel_mul_mv_f16_f32_l4(
|
|
|
device const char * src0,
|
|
device const char * src0,
|
|
@@ -1601,8 +1709,8 @@ template [[host_name("kernel_argsort_f32_i32_asc")]] kernel argsort_t kernel_ar
|
|
|
template [[host_name("kernel_argsort_f32_i32_desc")]] kernel argsort_t kernel_argsort_f32_i32<GGML_SORT_DESC>;
|
|
template [[host_name("kernel_argsort_f32_i32_desc")]] kernel argsort_t kernel_argsort_f32_i32<GGML_SORT_DESC>;
|
|
|
|
|
|
|
|
kernel void kernel_cpy_f16_f16(
|
|
kernel void kernel_cpy_f16_f16(
|
|
|
- device const half * src0,
|
|
|
|
|
- device half * dst,
|
|
|
|
|
|
|
+ device const half * src0,
|
|
|
|
|
+ device half * dst,
|
|
|
constant int64_t & ne00,
|
|
constant int64_t & ne00,
|
|
|
constant int64_t & ne01,
|
|
constant int64_t & ne01,
|
|
|
constant int64_t & ne02,
|
|
constant int64_t & ne02,
|
|
@@ -1641,6 +1749,47 @@ kernel void kernel_cpy_f16_f16(
|
|
|
}
|
|
}
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
|
|
+// Copies one row of half-precision values into a float destination,
+// converting each element on assignment.
+//
+// The threadgroup position selects the source row: tgpig[0] -> i01,
+// tgpig[1] -> i02, tgpig[2] -> i03. Threads of the group walk the row's
+// ne00 elements starting at tpitg.x and stepping by ntg.x.
+//
+// The flat element index of the row start (computed from the source extents)
+// is re-expressed as destination coordinates (i3, i2, i1, i0) using the
+// destination extents ne0..ne2, and the destination start address is built
+// from the byte strides nb0..nb3. From that start the row is written to
+// consecutive float positions (dst_data[i00]); nb0 only enters the start offset.
+//
+// ne03 and ne3 are part of the argument list but are not read here.
+kernel void kernel_cpy_f16_f32(
+        device  const half * src0,
+        device       float * dst,
+        constant   int64_t & ne00,
+        constant   int64_t & ne01,
+        constant   int64_t & ne02,
+        constant   int64_t & ne03,
+        constant  uint64_t & nb00,
+        constant  uint64_t & nb01,
+        constant  uint64_t & nb02,
+        constant  uint64_t & nb03,
+        constant   int64_t & ne0,
+        constant   int64_t & ne1,
+        constant   int64_t & ne2,
+        constant   int64_t & ne3,
+        constant  uint64_t & nb0,
+        constant  uint64_t & nb1,
+        constant  uint64_t & nb2,
+        constant  uint64_t & nb3,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint3 tpitg[[thread_position_in_threadgroup]],
+        uint3   ntg[[threads_per_threadgroup]]) {
+    const int64_t i03 = tgpig[2];
+    const int64_t i02 = tgpig[1];
+    const int64_t i01 = tgpig[0];
+
+    // flat index of the first element of this source row
+    const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
+
+    // the same flat index expressed in destination coordinates
+    const int64_t i3 = n / (ne2*ne1*ne0);
+    const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
+    const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
+    const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0);
+
+    device float * dst_data = (device float *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
+
+    // byte address of the source row; invariant across the element loop, and
+    // kept const-qualified so the read-only source is never exposed as writable
+    device const char * src_row = (device const char *) src0 + i03*nb03 + i02*nb02 + i01*nb01;
+
+    for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) {
+        device const half * src = (device const half *)(src_row + i00*nb00);
+        dst_data[i00] = src[0]; // implicit half -> float conversion
+    }
+}
|
|
|
|
|
+
|
|
|
kernel void kernel_cpy_f32_f16(
|
|
kernel void kernel_cpy_f32_f16(
|
|
|
device const float * src0,
|
|
device const float * src0,
|
|
|
device half * dst,
|
|
device half * dst,
|
|
@@ -2064,19 +2213,19 @@ static inline uchar4 get_scale_min_k4(int j, device const uint8_t * q) {
|
|
|
|
|
|
|
|
//====================================== dot products =========================
|
|
//====================================== dot products =========================
|
|
|
|
|
|
|
|
-kernel void kernel_mul_mv_q2_K_f32(
|
|
|
|
|
|
|
+void kernel_mul_mv_q2_K_f32_impl(
|
|
|
device const void * src0,
|
|
device const void * src0,
|
|
|
device const float * src1,
|
|
device const float * src1,
|
|
|
device float * dst,
|
|
device float * dst,
|
|
|
constant int64_t & ne00,
|
|
constant int64_t & ne00,
|
|
|
- constant int64_t & ne01[[buffer(4)]],
|
|
|
|
|
- constant int64_t & ne02[[buffer(5)]],
|
|
|
|
|
- constant int64_t & ne10[[buffer(9)]],
|
|
|
|
|
- constant int64_t & ne12[[buffer(11)]],
|
|
|
|
|
- constant int64_t & ne0 [[buffer(15)]],
|
|
|
|
|
- constant int64_t & ne1 [[buffer(16)]],
|
|
|
|
|
- constant uint & r2 [[buffer(17)]],
|
|
|
|
|
- constant uint & r3 [[buffer(18)]],
|
|
|
|
|
|
|
+ constant int64_t & ne01,
|
|
|
|
|
+ constant int64_t & ne02,
|
|
|
|
|
+ constant int64_t & ne10,
|
|
|
|
|
+ constant int64_t & ne12,
|
|
|
|
|
+ constant int64_t & ne0,
|
|
|
|
|
+ constant int64_t & ne1,
|
|
|
|
|
+ constant uint & r2,
|
|
|
|
|
+ constant uint & r3,
|
|
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
|
uint tiisg[[thread_index_in_simdgroup]],
|
|
uint tiisg[[thread_index_in_simdgroup]],
|
|
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
|
@@ -2214,8 +2363,8 @@ kernel void kernel_mul_mv_q2_K_f32(
|
|
|
}
|
|
}
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
-#if QK_K == 256
|
|
|
|
|
-kernel void kernel_mul_mv_q3_K_f32(
|
|
|
|
|
|
|
+[[host_name("kernel_mul_mv_q2_K_f32")]]
|
|
|
|
|
+kernel void kernel_mul_mv_q2_K_f32(
|
|
|
device const void * src0,
|
|
device const void * src0,
|
|
|
device const float * src1,
|
|
device const float * src1,
|
|
|
device float * dst,
|
|
device float * dst,
|
|
@@ -2229,8 +2378,29 @@ kernel void kernel_mul_mv_q3_K_f32(
|
|
|
constant uint & r2 [[buffer(17)]],
|
|
constant uint & r2 [[buffer(17)]],
|
|
|
constant uint & r3 [[buffer(18)]],
|
|
constant uint & r3 [[buffer(18)]],
|
|
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
|
- uint tiisg[[thread_index_in_simdgroup]],
|
|
|
|
|
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
|
|
|
|
|
|
+ uint tiisg[[thread_index_in_simdgroup]],
|
|
|
|
|
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
|
|
|
|
+
|
|
|
|
|
+ kernel_mul_mv_q2_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, tgpig, tiisg, sgitg);
|
|
|
|
|
+}
|
|
|
|
|
+
|
|
|
|
|
+#if QK_K == 256
|
|
|
|
|
+void kernel_mul_mv_q3_K_f32_impl(
|
|
|
|
|
+ device const void * src0,
|
|
|
|
|
+ device const float * src1,
|
|
|
|
|
+ device float * dst,
|
|
|
|
|
+ constant int64_t & ne00,
|
|
|
|
|
+ constant int64_t & ne01,
|
|
|
|
|
+ constant int64_t & ne02,
|
|
|
|
|
+ constant int64_t & ne10,
|
|
|
|
|
+ constant int64_t & ne12,
|
|
|
|
|
+ constant int64_t & ne0,
|
|
|
|
|
+ constant int64_t & ne1,
|
|
|
|
|
+ constant uint & r2,
|
|
|
|
|
+ constant uint & r3,
|
|
|
|
|
+ uint3 tgpig[[threadgroup_position_in_grid]],
|
|
|
|
|
+ uint tiisg[[thread_index_in_simdgroup]],
|
|
|
|
|
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
|
|
|
|
|
|
|
const int nb = ne00/QK_K;
|
|
const int nb = ne00/QK_K;
|
|
|
|
|
|
|
@@ -2373,19 +2543,19 @@ kernel void kernel_mul_mv_q3_K_f32(
|
|
|
}
|
|
}
|
|
|
}
|
|
}
|
|
|
#else
|
|
#else
|
|
|
-kernel void kernel_mul_mv_q3_K_f32(
|
|
|
|
|
|
|
+void kernel_mul_mv_q3_K_f32_impl(
|
|
|
device const void * src0,
|
|
device const void * src0,
|
|
|
device const float * src1,
|
|
device const float * src1,
|
|
|
device float * dst,
|
|
device float * dst,
|
|
|
constant int64_t & ne00,
|
|
constant int64_t & ne00,
|
|
|
- constant int64_t & ne01[[buffer(4)]],
|
|
|
|
|
- constant int64_t & ne02[[buffer(5)]],
|
|
|
|
|
- constant int64_t & ne10[[buffer(9)]],
|
|
|
|
|
- constant int64_t & ne12[[buffer(11)]],
|
|
|
|
|
- constant int64_t & ne0 [[buffer(15)]],
|
|
|
|
|
- constant int64_t & ne1 [[buffer(16)]],
|
|
|
|
|
- constant uint & r2 [[buffer(17)]],
|
|
|
|
|
- constant uint & r3 [[buffer(18)]],
|
|
|
|
|
|
|
+ constant int64_t & ne01,
|
|
|
|
|
+ constant int64_t & ne02,
|
|
|
|
|
+ constant int64_t & ne10,
|
|
|
|
|
+ constant int64_t & ne12,
|
|
|
|
|
+ constant int64_t & ne0,
|
|
|
|
|
+ constant int64_t & ne1,
|
|
|
|
|
+ constant uint & r2,
|
|
|
|
|
+ constant uint & r3,
|
|
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
|
uint tiisg[[thread_index_in_simdgroup]],
|
|
uint tiisg[[thread_index_in_simdgroup]],
|
|
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
|
@@ -2450,20 +2620,41 @@ kernel void kernel_mul_mv_q3_K_f32(
|
|
|
}
|
|
}
|
|
|
#endif
|
|
#endif
|
|
|
|
|
|
|
|
|
|
+// Kernel entry point for the Q3_K x F32 matrix-vector product.
+// Thin wrapper: passes every argument straight through to
+// kernel_mul_mv_q3_K_f32_impl. The scalar arguments after ne00 are bound
+// to explicit buffer slots (4, 5, 9, 11, 15, 16, 17, 18).
+[[host_name("kernel_mul_mv_q3_K_f32")]]
+kernel void kernel_mul_mv_q3_K_f32(
+        device const  void * src0,
+        device const float * src1,
+        device       float * dst,
+        constant   int64_t & ne00,
+        constant   int64_t & ne01[[buffer(4)]],
+        constant   int64_t & ne02[[buffer(5)]],
+        constant   int64_t & ne10[[buffer(9)]],
+        constant   int64_t & ne12[[buffer(11)]],
+        constant   int64_t & ne0 [[buffer(15)]],
+        constant   int64_t & ne1 [[buffer(16)]],
+        constant   uint    & r2  [[buffer(17)]],
+        constant   uint    & r3  [[buffer(18)]],
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint  tiisg[[thread_index_in_simdgroup]],
+        uint  sgitg[[simdgroup_index_in_threadgroup]]) {
+    kernel_mul_mv_q3_K_f32_impl(
+        src0, src1, dst,
+        ne00, ne01, ne02, ne10, ne12, ne0, ne1,
+        r2, r3,
+        tgpig, tiisg, sgitg);
+}
|
|
|
|
|
+
|
|
|
#if QK_K == 256
|
|
#if QK_K == 256
|
|
|
-kernel void kernel_mul_mv_q4_K_f32(
|
|
|
|
|
|
|
+void kernel_mul_mv_q4_K_f32_impl(
|
|
|
device const void * src0,
|
|
device const void * src0,
|
|
|
device const float * src1,
|
|
device const float * src1,
|
|
|
device float * dst,
|
|
device float * dst,
|
|
|
constant int64_t & ne00,
|
|
constant int64_t & ne00,
|
|
|
- constant int64_t & ne01 [[buffer(4)]],
|
|
|
|
|
- constant int64_t & ne02 [[buffer(5)]],
|
|
|
|
|
- constant int64_t & ne10 [[buffer(9)]],
|
|
|
|
|
- constant int64_t & ne12 [[buffer(11)]],
|
|
|
|
|
- constant int64_t & ne0 [[buffer(15)]],
|
|
|
|
|
- constant int64_t & ne1 [[buffer(16)]],
|
|
|
|
|
- constant uint & r2 [[buffer(17)]],
|
|
|
|
|
- constant uint & r3 [[buffer(18)]],
|
|
|
|
|
|
|
+ constant int64_t & ne01,
|
|
|
|
|
+ constant int64_t & ne02,
|
|
|
|
|
+ constant int64_t & ne10,
|
|
|
|
|
+ constant int64_t & ne12,
|
|
|
|
|
+ constant int64_t & ne0,
|
|
|
|
|
+ constant int64_t & ne1,
|
|
|
|
|
+ constant uint & r2,
|
|
|
|
|
+ constant uint & r3,
|
|
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
|
uint tiisg[[thread_index_in_simdgroup]],
|
|
uint tiisg[[thread_index_in_simdgroup]],
|
|
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
|
@@ -2564,21 +2755,21 @@ kernel void kernel_mul_mv_q4_K_f32(
|
|
|
}
|
|
}
|
|
|
}
|
|
}
|
|
|
#else
|
|
#else
|
|
|
-kernel void kernel_mul_mv_q4_K_f32(
|
|
|
|
|
|
|
+void kernel_mul_mv_q4_K_f32_impl(
|
|
|
device const void * src0,
|
|
device const void * src0,
|
|
|
device const float * src1,
|
|
device const float * src1,
|
|
|
device float * dst,
|
|
device float * dst,
|
|
|
constant int64_t & ne00,
|
|
constant int64_t & ne00,
|
|
|
- constant int64_t & ne01[[buffer(4)]],
|
|
|
|
|
- constant int64_t & ne02[[buffer(5)]],
|
|
|
|
|
- constant int64_t & ne10[[buffer(9)]],
|
|
|
|
|
- constant int64_t & ne12[[buffer(11)]],
|
|
|
|
|
- constant int64_t & ne0 [[buffer(15)]],
|
|
|
|
|
- constant int64_t & ne1 [[buffer(16)]],
|
|
|
|
|
- constant uint & r2 [[buffer(17)]],
|
|
|
|
|
- constant uint & r3 [[buffer(18)]],
|
|
|
|
|
- uint3 tgpig[[threadgroup_position_in_grid]],
|
|
|
|
|
- uint tiisg[[thread_index_in_simdgroup]],
|
|
|
|
|
|
|
+ constant int64_t & ne01,
|
|
|
|
|
+ constant int64_t & ne02,
|
|
|
|
|
+ constant int64_t & ne10,
|
|
|
|
|
+ constant int64_t & ne12,
|
|
|
|
|
+ constant int64_t & ne0,
|
|
|
|
|
+ constant int64_t & ne1,
|
|
|
|
|
+ constant uint & r2,
|
|
|
|
|
+ constant uint & r3,
|
|
|
|
|
+ uint3 tgpig[[threadgroup_position_in_grid]],
|
|
|
|
|
+ uint tiisg[[thread_index_in_simdgroup]],
|
|
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
|
|
|
|
|
|
|
const int ix = tiisg/4; // 0...7
|
|
const int ix = tiisg/4; // 0...7
|
|
@@ -2660,7 +2851,8 @@ kernel void kernel_mul_mv_q4_K_f32(
|
|
|
}
|
|
}
|
|
|
#endif
|
|
#endif
|
|
|
|
|
|
|
|
-kernel void kernel_mul_mv_q5_K_f32(
|
|
|
|
|
|
|
+[[host_name("kernel_mul_mv_q4_K_f32")]]
|
|
|
|
|
+kernel void kernel_mul_mv_q4_K_f32(
|
|
|
device const void * src0,
|
|
device const void * src0,
|
|
|
device const float * src1,
|
|
device const float * src1,
|
|
|
device float * dst,
|
|
device float * dst,
|
|
@@ -2677,6 +2869,26 @@ kernel void kernel_mul_mv_q5_K_f32(
|
|
|
uint tiisg[[thread_index_in_simdgroup]],
|
|
uint tiisg[[thread_index_in_simdgroup]],
|
|
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
|
|
|
|
|
|
|
|
|
+ kernel_mul_mv_q4_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, tgpig, tiisg, sgitg);
|
|
|
|
|
+}
|
|
|
|
|
+
|
|
|
|
|
+void kernel_mul_mv_q5_K_f32_impl(
|
|
|
|
|
+ device const void * src0,
|
|
|
|
|
+ device const float * src1,
|
|
|
|
|
+ device float * dst,
|
|
|
|
|
+ constant int64_t & ne00,
|
|
|
|
|
+ constant int64_t & ne01,
|
|
|
|
|
+ constant int64_t & ne02,
|
|
|
|
|
+ constant int64_t & ne10,
|
|
|
|
|
+ constant int64_t & ne12,
|
|
|
|
|
+ constant int64_t & ne0,
|
|
|
|
|
+ constant int64_t & ne1,
|
|
|
|
|
+ constant uint & r2,
|
|
|
|
|
+ constant uint & r3,
|
|
|
|
|
+ uint3 tgpig[[threadgroup_position_in_grid]],
|
|
|
|
|
+ uint tiisg[[thread_index_in_simdgroup]],
|
|
|
|
|
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
|
|
|
|
+
|
|
|
const int nb = ne00/QK_K;
|
|
const int nb = ne00/QK_K;
|
|
|
|
|
|
|
|
const int64_t r0 = tgpig.x;
|
|
const int64_t r0 = tgpig.x;
|
|
@@ -2836,10 +3048,10 @@ kernel void kernel_mul_mv_q5_K_f32(
|
|
|
dst[r1*ne0 + im*ne0*ne1 + first_row + row] = tot;
|
|
dst[r1*ne0 + im*ne0*ne1 + first_row + row] = tot;
|
|
|
}
|
|
}
|
|
|
}
|
|
}
|
|
|
-
|
|
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
-kernel void kernel_mul_mv_q6_K_f32(
|
|
|
|
|
|
|
+[[host_name("kernel_mul_mv_q5_K_f32")]]
|
|
|
|
|
+kernel void kernel_mul_mv_q5_K_f32(
|
|
|
device const void * src0,
|
|
device const void * src0,
|
|
|
device const float * src1,
|
|
device const float * src1,
|
|
|
device float * dst,
|
|
device float * dst,
|
|
@@ -2853,8 +3065,28 @@ kernel void kernel_mul_mv_q6_K_f32(
|
|
|
constant uint & r2 [[buffer(17)]],
|
|
constant uint & r2 [[buffer(17)]],
|
|
|
constant uint & r3 [[buffer(18)]],
|
|
constant uint & r3 [[buffer(18)]],
|
|
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
|
- uint tiisg[[thread_index_in_simdgroup]],
|
|
|
|
|
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
|
|
|
|
|
|
+ uint tiisg[[thread_index_in_simdgroup]],
|
|
|
|
|
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
|
|
|
|
+
|
|
|
|
|
+ kernel_mul_mv_q5_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, tgpig, tiisg, sgitg);
|
|
|
|
|
+}
|
|
|
|
|
+
|
|
|
|
|
+void kernel_mul_mv_q6_K_f32_impl(
|
|
|
|
|
+ device const void * src0,
|
|
|
|
|
+ device const float * src1,
|
|
|
|
|
+ device float * dst,
|
|
|
|
|
+ constant int64_t & ne00,
|
|
|
|
|
+ constant int64_t & ne01,
|
|
|
|
|
+ constant int64_t & ne02,
|
|
|
|
|
+ constant int64_t & ne10,
|
|
|
|
|
+ constant int64_t & ne12,
|
|
|
|
|
+ constant int64_t & ne0,
|
|
|
|
|
+ constant int64_t & ne1,
|
|
|
|
|
+ constant uint & r2,
|
|
|
|
|
+ constant uint & r3,
|
|
|
|
|
+ uint3 tgpig[[threadgroup_position_in_grid]],
|
|
|
|
|
+ uint tiisg[[thread_index_in_simdgroup]],
|
|
|
|
|
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
|
|
|
|
|
|
|
const uint8_t kmask1 = 0x03;
|
|
const uint8_t kmask1 = 0x03;
|
|
|
const uint8_t kmask2 = 0x0C;
|
|
const uint8_t kmask2 = 0x0C;
|
|
@@ -2945,6 +3177,27 @@ kernel void kernel_mul_mv_q6_K_f32(
|
|
|
}
|
|
}
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
|
|
+[[host_name("kernel_mul_mv_q6_K_f32")]]
|
|
|
|
|
+kernel void kernel_mul_mv_q6_K_f32(
|
|
|
|
|
+ device const void * src0,
|
|
|
|
|
+ device const float * src1,
|
|
|
|
|
+ device float * dst,
|
|
|
|
|
+ constant int64_t & ne00,
|
|
|
|
|
+ constant int64_t & ne01[[buffer(4)]],
|
|
|
|
|
+ constant int64_t & ne02[[buffer(5)]],
|
|
|
|
|
+ constant int64_t & ne10[[buffer(9)]],
|
|
|
|
|
+ constant int64_t & ne12[[buffer(11)]],
|
|
|
|
|
+ constant int64_t & ne0 [[buffer(15)]],
|
|
|
|
|
+ constant int64_t & ne1 [[buffer(16)]],
|
|
|
|
|
+ constant uint & r2 [[buffer(17)]],
|
|
|
|
|
+ constant uint & r3 [[buffer(18)]],
|
|
|
|
|
+ uint3 tgpig[[threadgroup_position_in_grid]],
|
|
|
|
|
+ uint tiisg[[thread_index_in_simdgroup]],
|
|
|
|
|
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
|
|
|
|
+
|
|
|
|
|
+ kernel_mul_mv_q6_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, tgpig, tiisg, sgitg);
|
|
|
|
|
+}
|
|
|
|
|
+
|
|
|
//============================= templates and their specializations =============================
|
|
//============================= templates and their specializations =============================
|
|
|
|
|
|
|
|
// NOTE: this is not dequantizing - we are simply fitting the template
|
|
// NOTE: this is not dequantizing - we are simply fitting the template
|
|
@@ -3219,22 +3472,90 @@ void dequantize_q6_K(device const block_q6_K *xb, short il, thread type4x4 & reg
|
|
|
template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread float4x4 &)>
|
|
template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread float4x4 &)>
|
|
|
kernel void kernel_get_rows(
|
|
kernel void kernel_get_rows(
|
|
|
device const void * src0,
|
|
device const void * src0,
|
|
|
- device const int * src1,
|
|
|
|
|
|
|
+ device const char * src1,
|
|
|
device float * dst,
|
|
device float * dst,
|
|
|
constant int64_t & ne00,
|
|
constant int64_t & ne00,
|
|
|
constant uint64_t & nb01,
|
|
constant uint64_t & nb01,
|
|
|
|
|
+ constant uint64_t & nb02,
|
|
|
|
|
+ constant int64_t & ne10,
|
|
|
|
|
+ constant uint64_t & nb10,
|
|
|
|
|
+ constant uint64_t & nb11,
|
|
|
constant uint64_t & nb1,
|
|
constant uint64_t & nb1,
|
|
|
- uint tgpig[[threadgroup_position_in_grid]],
|
|
|
|
|
|
|
+ constant uint64_t & nb2,
|
|
|
|
|
+ uint3 tgpig[[threadgroup_position_in_grid]],
|
|
|
uint tiitg[[thread_index_in_threadgroup]],
|
|
uint tiitg[[thread_index_in_threadgroup]],
|
|
|
- uint tptg[[threads_per_threadgroup]]) {
|
|
|
|
|
- const int i = tgpig;
|
|
|
|
|
- const int r = ((device int32_t *) src1)[i];
|
|
|
|
|
|
|
+ uint3 tptg [[threads_per_threadgroup]]) {
|
|
|
|
|
+ //const int64_t i = tgpig;
|
|
|
|
|
+ //const int64_t r = ((device int32_t *) src1)[i];
|
|
|
|
|
|
|
|
- for (int ind = tiitg; ind < ne00/16; ind += tptg) {
|
|
|
|
|
|
|
+ const int64_t i10 = tgpig.x;
|
|
|
|
|
+ const int64_t i11 = tgpig.y;
|
|
|
|
|
+
|
|
|
|
|
+ const int64_t r = ((device int32_t *) ((device char *) src1 + i11*nb11 + i10*nb10))[0];
|
|
|
|
|
+
|
|
|
|
|
+ const int64_t i02 = i11;
|
|
|
|
|
+
|
|
|
|
|
+ for (int64_t ind = tiitg; ind < ne00/16; ind += tptg.x) {
|
|
|
float4x4 temp;
|
|
float4x4 temp;
|
|
|
dequantize_func(
|
|
dequantize_func(
|
|
|
- ((device const block_q *) ((device char *) src0 + r*nb01)) + ind/nl, ind%nl, temp);
|
|
|
|
|
- *(((device float4x4 *) ((device char *) dst + i*nb1)) + ind) = temp;
|
|
|
|
|
|
|
+ ((device const block_q *) ((device char *) src0 + r*nb01 + i02*nb02)) + ind/nl, ind%nl, temp);
|
|
|
|
|
+ *(((device float4x4 *) ((device char *) dst + i11*nb2 + i10*nb1)) + ind) = temp;
|
|
|
|
|
+ }
|
|
|
|
|
+}
|
|
|
|
|
+
|
|
|
|
|
+kernel void kernel_get_rows_f32(
|
|
|
|
|
+ device const void * src0,
|
|
|
|
|
+ device const char * src1,
|
|
|
|
|
+ device float * dst,
|
|
|
|
|
+ constant int64_t & ne00,
|
|
|
|
|
+ constant uint64_t & nb01,
|
|
|
|
|
+ constant uint64_t & nb02,
|
|
|
|
|
+ constant int64_t & ne10,
|
|
|
|
|
+ constant uint64_t & nb10,
|
|
|
|
|
+ constant uint64_t & nb11,
|
|
|
|
|
+ constant uint64_t & nb1,
|
|
|
|
|
+ constant uint64_t & nb2,
|
|
|
|
|
+ uint3 tgpig[[threadgroup_position_in_grid]],
|
|
|
|
|
+ uint tiitg[[thread_index_in_threadgroup]],
|
|
|
|
|
+ uint3 tptg [[threads_per_threadgroup]]) {
|
|
|
|
|
+ const int64_t i10 = tgpig.x;
|
|
|
|
|
+ const int64_t i11 = tgpig.y;
|
|
|
|
|
+
|
|
|
|
|
+ const int64_t r = ((device int32_t *) ((device char *) src1 + i11*nb11 + i10*nb10))[0];
|
|
|
|
|
+
|
|
|
|
|
+ const int64_t i02 = i11;
|
|
|
|
|
+
|
|
|
|
|
+ for (int ind = tiitg; ind < ne00; ind += tptg.x) {
|
|
|
|
|
+ ((device float *) ((device char *) dst + i11*nb2 + i10*nb1))[ind] =
|
|
|
|
|
+ ((device float *) ((device char *) src0 + r*nb01 + i02*nb02))[ind];
|
|
|
|
|
+ }
|
|
|
|
|
+}
|
|
|
|
|
+
|
|
|
|
|
+kernel void kernel_get_rows_f16(
|
|
|
|
|
+ device const void * src0,
|
|
|
|
|
+ device const char * src1,
|
|
|
|
|
+ device float * dst,
|
|
|
|
|
+ constant int64_t & ne00,
|
|
|
|
|
+ constant uint64_t & nb01,
|
|
|
|
|
+ constant uint64_t & nb02,
|
|
|
|
|
+ constant int64_t & ne10,
|
|
|
|
|
+ constant uint64_t & nb10,
|
|
|
|
|
+ constant uint64_t & nb11,
|
|
|
|
|
+ constant uint64_t & nb1,
|
|
|
|
|
+ constant uint64_t & nb2,
|
|
|
|
|
+ uint3 tgpig[[threadgroup_position_in_grid]],
|
|
|
|
|
+ uint tiitg[[thread_index_in_threadgroup]],
|
|
|
|
|
+ uint3 tptg [[threads_per_threadgroup]]) {
|
|
|
|
|
+ const int64_t i10 = tgpig.x;
|
|
|
|
|
+ const int64_t i11 = tgpig.y;
|
|
|
|
|
+
|
|
|
|
|
+ const int64_t r = ((device int32_t *) ((device char *) src1 + i11*nb11 + i10*nb10))[0];
|
|
|
|
|
+
|
|
|
|
|
+ const int64_t i02 = i11;
|
|
|
|
|
+
|
|
|
|
|
+ for (int ind = tiitg; ind < ne00; ind += tptg.x) {
|
|
|
|
|
+ ((device float *) ((device char *) dst + i11*nb2 + i10*nb1))[ind] =
|
|
|
|
|
+ ((device half *) ((device char *) src0 + r*nb01 + i02*nb02))[ind];
|
|
|
}
|
|
}
|
|
|
}
|
|
}
|
|
|
|
|
|
|
@@ -3426,19 +3747,22 @@ kernel void kernel_mul_mm(device const uchar * src0,
|
|
|
|
|
|
|
|
template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread half4x4 &)>
|
|
template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread half4x4 &)>
|
|
|
kernel void kernel_mul_mm_id(
|
|
kernel void kernel_mul_mm_id(
|
|
|
- device const int32_t * ids,
|
|
|
|
|
|
|
+ device const uchar * ids,
|
|
|
device const uchar * src1,
|
|
device const uchar * src1,
|
|
|
- device float * dst,
|
|
|
|
|
|
|
+ device uchar * dst,
|
|
|
|
|
+ constant int64_t & nbi1,
|
|
|
constant int64_t & ne00,
|
|
constant int64_t & ne00,
|
|
|
constant int64_t & ne02,
|
|
constant int64_t & ne02,
|
|
|
constant int64_t & nb01,
|
|
constant int64_t & nb01,
|
|
|
constant int64_t & nb02,
|
|
constant int64_t & nb02,
|
|
|
constant int64_t & ne12,
|
|
constant int64_t & ne12,
|
|
|
|
|
+ constant int64_t & ne13,
|
|
|
constant int64_t & nb10,
|
|
constant int64_t & nb10,
|
|
|
constant int64_t & nb11,
|
|
constant int64_t & nb11,
|
|
|
constant int64_t & nb12,
|
|
constant int64_t & nb12,
|
|
|
constant int64_t & ne0,
|
|
constant int64_t & ne0,
|
|
|
constant int64_t & ne1,
|
|
constant int64_t & ne1,
|
|
|
|
|
+ constant int64_t & nb1,
|
|
|
constant uint & r2,
|
|
constant uint & r2,
|
|
|
constant uint & r3,
|
|
constant uint & r3,
|
|
|
constant int & idx,
|
|
constant int & idx,
|
|
@@ -3456,10 +3780,16 @@ kernel void kernel_mul_mm_id(
|
|
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
|
|
device const uchar * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
|
|
device const uchar * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
|
|
|
|
|
|
|
|
|
|
+ const int64_t bid = tgpig.z/(ne12*ne13);
|
|
|
|
|
+
|
|
|
|
|
+ tgpig.z = tgpig.z%(ne12*ne13);
|
|
|
|
|
+
|
|
|
|
|
+ const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
|
|
|
|
|
+
|
|
|
kernel_mul_mm_impl<block_q, nl, dequantize_func>(
|
|
kernel_mul_mm_impl<block_q, nl, dequantize_func>(
|
|
|
- src0[ids[idx]],
|
|
|
|
|
- src1,
|
|
|
|
|
- dst,
|
|
|
|
|
|
|
+ src0[id],
|
|
|
|
|
+ src1 + bid*nb11,
|
|
|
|
|
+ (device float *) (dst + bid*nb1),
|
|
|
ne00,
|
|
ne00,
|
|
|
ne02,
|
|
ne02,
|
|
|
nb01,
|
|
nb01,
|
|
@@ -3484,17 +3814,26 @@ kernel void kernel_mul_mm_id(
|
|
|
#define QK_NL 4
|
|
#define QK_NL 4
|
|
|
#endif
|
|
#endif
|
|
|
|
|
|
|
|
|
|
+//
|
|
|
|
|
+// get rows
|
|
|
|
|
+//
|
|
|
|
|
+
|
|
|
typedef void (get_rows_t)(
|
|
typedef void (get_rows_t)(
|
|
|
device const void * src0,
|
|
device const void * src0,
|
|
|
- device const int * src1,
|
|
|
|
|
|
|
+ device const char * src1,
|
|
|
device float * dst,
|
|
device float * dst,
|
|
|
constant int64_t & ne00,
|
|
constant int64_t & ne00,
|
|
|
constant uint64_t & nb01,
|
|
constant uint64_t & nb01,
|
|
|
|
|
+ constant uint64_t & nb02,
|
|
|
|
|
+ constant int64_t & ne10,
|
|
|
|
|
+ constant uint64_t & nb10,
|
|
|
|
|
+ constant uint64_t & nb11,
|
|
|
constant uint64_t & nb1,
|
|
constant uint64_t & nb1,
|
|
|
- uint, uint, uint);
|
|
|
|
|
|
|
+ constant uint64_t & nb2,
|
|
|
|
|
+ uint3, uint, uint3);
|
|
|
|
|
|
|
|
-template [[host_name("kernel_get_rows_f32")]] kernel get_rows_t kernel_get_rows<float4x4, 1, dequantize_f32>;
|
|
|
|
|
-template [[host_name("kernel_get_rows_f16")]] kernel get_rows_t kernel_get_rows<half4x4, 1, dequantize_f16>;
|
|
|
|
|
|
|
+//template [[host_name("kernel_get_rows_f32")]] kernel get_rows_t kernel_get_rows<float4x4, 1, dequantize_f32>;
|
|
|
|
|
+//template [[host_name("kernel_get_rows_f16")]] kernel get_rows_t kernel_get_rows<half4x4, 1, dequantize_f16>;
|
|
|
template [[host_name("kernel_get_rows_q4_0")]] kernel get_rows_t kernel_get_rows<block_q4_0, 2, dequantize_q4_0>;
|
|
template [[host_name("kernel_get_rows_q4_0")]] kernel get_rows_t kernel_get_rows<block_q4_0, 2, dequantize_q4_0>;
|
|
|
template [[host_name("kernel_get_rows_q4_1")]] kernel get_rows_t kernel_get_rows<block_q4_1, 2, dequantize_q4_1>;
|
|
template [[host_name("kernel_get_rows_q4_1")]] kernel get_rows_t kernel_get_rows<block_q4_1, 2, dequantize_q4_1>;
|
|
|
template [[host_name("kernel_get_rows_q5_0")]] kernel get_rows_t kernel_get_rows<block_q5_0, 2, dequantize_q5_0>;
|
|
template [[host_name("kernel_get_rows_q5_0")]] kernel get_rows_t kernel_get_rows<block_q5_0, 2, dequantize_q5_0>;
|
|
@@ -3506,6 +3845,10 @@ template [[host_name("kernel_get_rows_q4_K")]] kernel get_rows_t kernel_get_rows
|
|
|
template [[host_name("kernel_get_rows_q5_K")]] kernel get_rows_t kernel_get_rows<block_q5_K, QK_NL, dequantize_q5_K>;
|
|
template [[host_name("kernel_get_rows_q5_K")]] kernel get_rows_t kernel_get_rows<block_q5_K, QK_NL, dequantize_q5_K>;
|
|
|
template [[host_name("kernel_get_rows_q6_K")]] kernel get_rows_t kernel_get_rows<block_q6_K, QK_NL, dequantize_q6_K>;
|
|
template [[host_name("kernel_get_rows_q6_K")]] kernel get_rows_t kernel_get_rows<block_q6_K, QK_NL, dequantize_q6_K>;
|
|
|
|
|
|
|
|
|
|
+//
|
|
|
|
|
+// matrix-matrix multiplication
|
|
|
|
|
+//
|
|
|
|
|
+
|
|
|
typedef void (mat_mm_t)(
|
|
typedef void (mat_mm_t)(
|
|
|
device const uchar * src0,
|
|
device const uchar * src0,
|
|
|
device const uchar * src1,
|
|
device const uchar * src1,
|
|
@@ -3538,20 +3881,27 @@ template [[host_name("kernel_mul_mm_q4_K_f32")]] kernel mat_mm_t kernel_mul_mm<b
|
|
|
template [[host_name("kernel_mul_mm_q5_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q5_K, QK_NL, dequantize_q5_K>;
|
|
template [[host_name("kernel_mul_mm_q5_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q5_K, QK_NL, dequantize_q5_K>;
|
|
|
template [[host_name("kernel_mul_mm_q6_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q6_K, QK_NL, dequantize_q6_K>;
|
|
template [[host_name("kernel_mul_mm_q6_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q6_K, QK_NL, dequantize_q6_K>;
|
|
|
|
|
|
|
|
|
|
+//
|
|
|
|
|
+// indirect matrix-matrix multiplication
|
|
|
|
|
+//
|
|
|
|
|
+
|
|
|
typedef void (mat_mm_id_t)(
|
|
typedef void (mat_mm_id_t)(
|
|
|
- device const int32_t * ids,
|
|
|
|
|
|
|
+ device const uchar * ids,
|
|
|
device const uchar * src1,
|
|
device const uchar * src1,
|
|
|
- device float * dst,
|
|
|
|
|
|
|
+ device uchar * dst,
|
|
|
|
|
+ constant int64_t & nbi1,
|
|
|
constant int64_t & ne00,
|
|
constant int64_t & ne00,
|
|
|
constant int64_t & ne02,
|
|
constant int64_t & ne02,
|
|
|
constant int64_t & nb01,
|
|
constant int64_t & nb01,
|
|
|
constant int64_t & nb02,
|
|
constant int64_t & nb02,
|
|
|
constant int64_t & ne12,
|
|
constant int64_t & ne12,
|
|
|
|
|
+ constant int64_t & ne13,
|
|
|
constant int64_t & nb10,
|
|
constant int64_t & nb10,
|
|
|
constant int64_t & nb11,
|
|
constant int64_t & nb11,
|
|
|
constant int64_t & nb12,
|
|
constant int64_t & nb12,
|
|
|
constant int64_t & ne0,
|
|
constant int64_t & ne0,
|
|
|
constant int64_t & ne1,
|
|
constant int64_t & ne1,
|
|
|
|
|
+ constant int64_t & nb1,
|
|
|
constant uint & r2,
|
|
constant uint & r2,
|
|
|
constant uint & r3,
|
|
constant uint & r3,
|
|
|
constant int & idx,
|
|
constant int & idx,
|
|
@@ -3578,3 +3928,775 @@ template [[host_name("kernel_mul_mm_id_q3_K_f32")]] kernel mat_mm_id_t kernel_mu
|
|
|
template [[host_name("kernel_mul_mm_id_q4_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q4_K, QK_NL, dequantize_q4_K>;
|
|
template [[host_name("kernel_mul_mm_id_q4_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q4_K, QK_NL, dequantize_q4_K>;
|
|
|
template [[host_name("kernel_mul_mm_id_q5_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q5_K, QK_NL, dequantize_q5_K>;
|
|
template [[host_name("kernel_mul_mm_id_q5_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q5_K, QK_NL, dequantize_q5_K>;
|
|
|
template [[host_name("kernel_mul_mm_id_q6_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q6_K, QK_NL, dequantize_q6_K>;
|
|
template [[host_name("kernel_mul_mm_id_q6_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q6_K, QK_NL, dequantize_q6_K>;
|
|
|
|
|
+
|
|
|
|
|
+//
|
|
|
|
|
+// matrix-vector multiplication
|
|
|
|
|
+//
|
|
|
|
|
+
|
|
|
|
|
+[[host_name("kernel_mul_mv_id_f32_f32")]]
|
|
|
|
|
+kernel void kernel_mul_mv_id_f32_f32(
|
|
|
|
|
+ device const char * ids,
|
|
|
|
|
+ device const char * src1,
|
|
|
|
|
+ device uchar * dst,
|
|
|
|
|
+ constant int64_t & nbi1,
|
|
|
|
|
+ constant int64_t & ne00,
|
|
|
|
|
+ constant int64_t & ne01,
|
|
|
|
|
+ constant int64_t & ne02,
|
|
|
|
|
+ constant uint64_t & nb00,
|
|
|
|
|
+ constant uint64_t & nb01,
|
|
|
|
|
+ constant uint64_t & nb02,
|
|
|
|
|
+ constant int64_t & ne10,
|
|
|
|
|
+ constant int64_t & ne11,
|
|
|
|
|
+ constant int64_t & ne12,
|
|
|
|
|
+ constant int64_t & ne13,
|
|
|
|
|
+ constant uint64_t & nb10,
|
|
|
|
|
+ constant uint64_t & nb11,
|
|
|
|
|
+ constant uint64_t & nb12,
|
|
|
|
|
+ constant int64_t & ne0,
|
|
|
|
|
+ constant int64_t & ne1,
|
|
|
|
|
+ constant int64_t & nb1,
|
|
|
|
|
+ constant uint & r2,
|
|
|
|
|
+ constant uint & r3,
|
|
|
|
|
+ constant int & idx,
|
|
|
|
|
+ device const char * src00,
|
|
|
|
|
+ device const char * src01,
|
|
|
|
|
+ device const char * src02,
|
|
|
|
|
+ device const char * src03,
|
|
|
|
|
+ device const char * src04,
|
|
|
|
|
+ device const char * src05,
|
|
|
|
|
+ device const char * src06,
|
|
|
|
|
+ device const char * src07,
|
|
|
|
|
+ uint3 tgpig[[threadgroup_position_in_grid]],
|
|
|
|
|
+ uint tiitg[[thread_index_in_threadgroup]],
|
|
|
|
|
+ uint tiisg[[thread_index_in_simdgroup]],
|
|
|
|
|
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
|
|
|
|
+ device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
|
|
|
|
|
+
|
|
|
|
|
+ const int64_t bid = tgpig.z/(ne12*ne13);
|
|
|
|
|
+
|
|
|
|
|
+ tgpig.z = tgpig.z%(ne12*ne13);
|
|
|
|
|
+
|
|
|
|
|
+ const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
|
|
|
|
|
+
|
|
|
|
|
+ kernel_mul_mv_f32_f32_impl(
|
|
|
|
|
+ src0[id],
|
|
|
|
|
+ src1 + bid*nb11,
|
|
|
|
|
+ (device float *) (dst + bid*nb1),
|
|
|
|
|
+ ne00,
|
|
|
|
|
+ ne01,
|
|
|
|
|
+ ne02,
|
|
|
|
|
+ nb00,
|
|
|
|
|
+ nb01,
|
|
|
|
|
+ nb02,
|
|
|
|
|
+ ne10,
|
|
|
|
|
+ ne11,
|
|
|
|
|
+ ne12,
|
|
|
|
|
+ nb10,
|
|
|
|
|
+ nb11,
|
|
|
|
|
+ nb12,
|
|
|
|
|
+ ne0,
|
|
|
|
|
+ ne1,
|
|
|
|
|
+ r2,
|
|
|
|
|
+ r3,
|
|
|
|
|
+ tgpig,
|
|
|
|
|
+ tiisg);
|
|
|
|
|
+}
|
|
|
|
|
+
|
|
|
|
|
+[[host_name("kernel_mul_mv_id_f16_f32")]]
|
|
|
|
|
+kernel void kernel_mul_mv_id_f16_f32(
|
|
|
|
|
+ device const char * ids,
|
|
|
|
|
+ device const char * src1,
|
|
|
|
|
+ device uchar * dst,
|
|
|
|
|
+ constant int64_t & nbi1,
|
|
|
|
|
+ constant int64_t & ne00,
|
|
|
|
|
+ constant int64_t & ne01,
|
|
|
|
|
+ constant int64_t & ne02,
|
|
|
|
|
+ constant uint64_t & nb00,
|
|
|
|
|
+ constant uint64_t & nb01,
|
|
|
|
|
+ constant uint64_t & nb02,
|
|
|
|
|
+ constant int64_t & ne10,
|
|
|
|
|
+ constant int64_t & ne11,
|
|
|
|
|
+ constant int64_t & ne12,
|
|
|
|
|
+ constant int64_t & ne13,
|
|
|
|
|
+ constant uint64_t & nb10,
|
|
|
|
|
+ constant uint64_t & nb11,
|
|
|
|
|
+ constant uint64_t & nb12,
|
|
|
|
|
+ constant int64_t & ne0,
|
|
|
|
|
+ constant int64_t & ne1,
|
|
|
|
|
+ constant int64_t & nb1,
|
|
|
|
|
+ constant uint & r2,
|
|
|
|
|
+ constant uint & r3,
|
|
|
|
|
+ constant int & idx,
|
|
|
|
|
+ device const char * src00,
|
|
|
|
|
+ device const char * src01,
|
|
|
|
|
+ device const char * src02,
|
|
|
|
|
+ device const char * src03,
|
|
|
|
|
+ device const char * src04,
|
|
|
|
|
+ device const char * src05,
|
|
|
|
|
+ device const char * src06,
|
|
|
|
|
+ device const char * src07,
|
|
|
|
|
+ uint3 tgpig[[threadgroup_position_in_grid]],
|
|
|
|
|
+ uint tiitg[[thread_index_in_threadgroup]],
|
|
|
|
|
+ uint tiisg[[thread_index_in_simdgroup]],
|
|
|
|
|
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
|
|
|
|
+ device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
|
|
|
|
|
+
|
|
|
|
|
+ const int64_t bid = tgpig.z/(ne12*ne13);
|
|
|
|
|
+
|
|
|
|
|
+ tgpig.z = tgpig.z%(ne12*ne13);
|
|
|
|
|
+
|
|
|
|
|
+ const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
|
|
|
|
|
+
|
|
|
|
|
+ kernel_mul_mv_f16_f32_impl(
|
|
|
|
|
+ src0[id],
|
|
|
|
|
+ src1 + bid*nb11,
|
|
|
|
|
+ (device float *) (dst + bid*nb1),
|
|
|
|
|
+ ne00,
|
|
|
|
|
+ ne01,
|
|
|
|
|
+ ne02,
|
|
|
|
|
+ nb00,
|
|
|
|
|
+ nb01,
|
|
|
|
|
+ nb02,
|
|
|
|
|
+ ne10,
|
|
|
|
|
+ ne11,
|
|
|
|
|
+ ne12,
|
|
|
|
|
+ nb10,
|
|
|
|
|
+ nb11,
|
|
|
|
|
+ nb12,
|
|
|
|
|
+ ne0,
|
|
|
|
|
+ ne1,
|
|
|
|
|
+ r2,
|
|
|
|
|
+ r3,
|
|
|
|
|
+ tgpig,
|
|
|
|
|
+ tiisg);
|
|
|
|
|
+}
|
|
|
|
|
+
|
|
|
|
|
+[[host_name("kernel_mul_mv_id_q8_0_f32")]]
|
|
|
|
|
+kernel void kernel_mul_mv_id_q8_0_f32(
|
|
|
|
|
+ device const char * ids,
|
|
|
|
|
+ device const char * src1,
|
|
|
|
|
+ device uchar * dst,
|
|
|
|
|
+ constant int64_t & nbi1,
|
|
|
|
|
+ constant int64_t & ne00,
|
|
|
|
|
+ constant int64_t & ne01,
|
|
|
|
|
+ constant int64_t & ne02,
|
|
|
|
|
+ constant uint64_t & nb00,
|
|
|
|
|
+ constant uint64_t & nb01,
|
|
|
|
|
+ constant uint64_t & nb02,
|
|
|
|
|
+ constant int64_t & ne10,
|
|
|
|
|
+ constant int64_t & ne11,
|
|
|
|
|
+ constant int64_t & ne12,
|
|
|
|
|
+ constant int64_t & ne13,
|
|
|
|
|
+ constant uint64_t & nb10,
|
|
|
|
|
+ constant uint64_t & nb11,
|
|
|
|
|
+ constant uint64_t & nb12,
|
|
|
|
|
+ constant int64_t & ne0,
|
|
|
|
|
+ constant int64_t & ne1,
|
|
|
|
|
+ constant int64_t & nb1,
|
|
|
|
|
+ constant uint & r2,
|
|
|
|
|
+ constant uint & r3,
|
|
|
|
|
+ constant int & idx,
|
|
|
|
|
+ device const char * src00,
|
|
|
|
|
+ device const char * src01,
|
|
|
|
|
+ device const char * src02,
|
|
|
|
|
+ device const char * src03,
|
|
|
|
|
+ device const char * src04,
|
|
|
|
|
+ device const char * src05,
|
|
|
|
|
+ device const char * src06,
|
|
|
|
|
+ device const char * src07,
|
|
|
|
|
+ uint3 tgpig[[threadgroup_position_in_grid]],
|
|
|
|
|
+ uint tiitg[[thread_index_in_threadgroup]],
|
|
|
|
|
+ uint tiisg[[thread_index_in_simdgroup]],
|
|
|
|
|
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
|
|
|
|
+ device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
|
|
|
|
|
+
|
|
|
|
|
+ const int64_t bid = tgpig.z/(ne12*ne13);
|
|
|
|
|
+
|
|
|
|
|
+ tgpig.z = tgpig.z%(ne12*ne13);
|
|
|
|
|
+
|
|
|
|
|
+ const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
|
|
|
|
|
+
|
|
|
|
|
+ kernel_mul_mv_q8_0_f32_impl(
|
|
|
|
|
+ src0[id],
|
|
|
|
|
+ (device const float *) (src1 + bid*nb11),
|
|
|
|
|
+ (device float *) ( dst + bid*nb1),
|
|
|
|
|
+ ne00,
|
|
|
|
|
+ ne01,
|
|
|
|
|
+ ne02,
|
|
|
|
|
+ ne10,
|
|
|
|
|
+ ne12,
|
|
|
|
|
+ ne0,
|
|
|
|
|
+ ne1,
|
|
|
|
|
+ r2,
|
|
|
|
|
+ r3,
|
|
|
|
|
+ tgpig,
|
|
|
|
|
+ tiisg,
|
|
|
|
|
+ sgitg);
|
|
|
|
|
+}
|
|
|
|
|
+
|
|
|
|
|
+[[host_name("kernel_mul_mv_id_q4_0_f32")]]
|
|
|
|
|
+kernel void kernel_mul_mv_id_q4_0_f32(
|
|
|
|
|
+ device const char * ids,
|
|
|
|
|
+ device const char * src1,
|
|
|
|
|
+ device uchar * dst,
|
|
|
|
|
+ constant int64_t & nbi1,
|
|
|
|
|
+ constant int64_t & ne00,
|
|
|
|
|
+ constant int64_t & ne01,
|
|
|
|
|
+ constant int64_t & ne02,
|
|
|
|
|
+ constant uint64_t & nb00,
|
|
|
|
|
+ constant uint64_t & nb01,
|
|
|
|
|
+ constant uint64_t & nb02,
|
|
|
|
|
+ constant int64_t & ne10,
|
|
|
|
|
+ constant int64_t & ne11,
|
|
|
|
|
+ constant int64_t & ne12,
|
|
|
|
|
+ constant int64_t & ne13,
|
|
|
|
|
+ constant uint64_t & nb10,
|
|
|
|
|
+ constant uint64_t & nb11,
|
|
|
|
|
+ constant uint64_t & nb12,
|
|
|
|
|
+ constant int64_t & ne0,
|
|
|
|
|
+ constant int64_t & ne1,
|
|
|
|
|
+ constant int64_t & nb1,
|
|
|
|
|
+ constant uint & r2,
|
|
|
|
|
+ constant uint & r3,
|
|
|
|
|
+ constant int & idx,
|
|
|
|
|
+ device const char * src00,
|
|
|
|
|
+ device const char * src01,
|
|
|
|
|
+ device const char * src02,
|
|
|
|
|
+ device const char * src03,
|
|
|
|
|
+ device const char * src04,
|
|
|
|
|
+ device const char * src05,
|
|
|
|
|
+ device const char * src06,
|
|
|
|
|
+ device const char * src07,
|
|
|
|
|
+ uint3 tgpig[[threadgroup_position_in_grid]],
|
|
|
|
|
+ uint tiitg[[thread_index_in_threadgroup]],
|
|
|
|
|
+ uint tiisg[[thread_index_in_simdgroup]],
|
|
|
|
|
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
|
|
|
|
+ device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
|
|
|
|
|
+
|
|
|
|
|
+ const int64_t bid = tgpig.z/(ne12*ne13);
|
|
|
|
|
+
|
|
|
|
|
+ tgpig.z = tgpig.z%(ne12*ne13);
|
|
|
|
|
+
|
|
|
|
|
+ const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
|
|
|
|
|
+
|
|
|
|
|
+ mul_vec_q_n_f32_impl<block_q4_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(
|
|
|
|
|
+ src0[id],
|
|
|
|
|
+ (device const float *) (src1 + bid*nb11),
|
|
|
|
|
+ (device float *) ( dst + bid*nb1),
|
|
|
|
|
+ ne00,
|
|
|
|
|
+ ne01,
|
|
|
|
|
+ ne02,
|
|
|
|
|
+ ne10,
|
|
|
|
|
+ ne12,
|
|
|
|
|
+ ne0,
|
|
|
|
|
+ ne1,
|
|
|
|
|
+ r2,
|
|
|
|
|
+ r3,
|
|
|
|
|
+ tgpig,
|
|
|
|
|
+ tiisg,
|
|
|
|
|
+ sgitg);
|
|
|
|
|
+}
|
|
|
|
|
+
|
|
|
|
|
+[[host_name("kernel_mul_mv_id_q4_1_f32")]]
|
|
|
|
|
+kernel void kernel_mul_mv_id_q4_1_f32(
|
|
|
|
|
+ device const char * ids,
|
|
|
|
|
+ device const char * src1,
|
|
|
|
|
+ device uchar * dst,
|
|
|
|
|
+ constant int64_t & nbi1,
|
|
|
|
|
+ constant int64_t & ne00,
|
|
|
|
|
+ constant int64_t & ne01,
|
|
|
|
|
+ constant int64_t & ne02,
|
|
|
|
|
+ constant uint64_t & nb00,
|
|
|
|
|
+ constant uint64_t & nb01,
|
|
|
|
|
+ constant uint64_t & nb02,
|
|
|
|
|
+ constant int64_t & ne10,
|
|
|
|
|
+ constant int64_t & ne11,
|
|
|
|
|
+ constant int64_t & ne12,
|
|
|
|
|
+ constant int64_t & ne13,
|
|
|
|
|
+ constant uint64_t & nb10,
|
|
|
|
|
+ constant uint64_t & nb11,
|
|
|
|
|
+ constant uint64_t & nb12,
|
|
|
|
|
+ constant int64_t & ne0,
|
|
|
|
|
+ constant int64_t & ne1,
|
|
|
|
|
+ constant int64_t & nb1,
|
|
|
|
|
+ constant uint & r2,
|
|
|
|
|
+ constant uint & r3,
|
|
|
|
|
+ constant int & idx,
|
|
|
|
|
+ device const char * src00,
|
|
|
|
|
+ device const char * src01,
|
|
|
|
|
+ device const char * src02,
|
|
|
|
|
+ device const char * src03,
|
|
|
|
|
+ device const char * src04,
|
|
|
|
|
+ device const char * src05,
|
|
|
|
|
+ device const char * src06,
|
|
|
|
|
+ device const char * src07,
|
|
|
|
|
+ uint3 tgpig[[threadgroup_position_in_grid]],
|
|
|
|
|
+ uint tiitg[[thread_index_in_threadgroup]],
|
|
|
|
|
+ uint tiisg[[thread_index_in_simdgroup]],
|
|
|
|
|
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
|
|
|
|
+ device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
|
|
|
|
|
+
|
|
|
|
|
+ const int64_t bid = tgpig.z/(ne12*ne13);
|
|
|
|
|
+
|
|
|
|
|
+ tgpig.z = tgpig.z%(ne12*ne13);
|
|
|
|
|
+
|
|
|
|
|
+ const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
|
|
|
|
|
+
|
|
|
|
|
+ mul_vec_q_n_f32_impl<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(
|
|
|
|
|
+ src0[id],
|
|
|
|
|
+ (device const float *) (src1 + bid*nb11),
|
|
|
|
|
+ (device float *) ( dst + bid*nb1),
|
|
|
|
|
+ ne00,
|
|
|
|
|
+ ne01,
|
|
|
|
|
+ ne02,
|
|
|
|
|
+ ne10,
|
|
|
|
|
+ ne12,
|
|
|
|
|
+ ne0,
|
|
|
|
|
+ ne1,
|
|
|
|
|
+ r2,
|
|
|
|
|
+ r3,
|
|
|
|
|
+ tgpig,
|
|
|
|
|
+ tiisg,
|
|
|
|
|
+ sgitg);
|
|
|
|
|
+}
|
|
|
|
|
+
|
|
|
|
|
+[[host_name("kernel_mul_mv_id_q5_0_f32")]]
|
|
|
|
|
+kernel void kernel_mul_mv_id_q5_0_f32(
|
|
|
|
|
+ device const char * ids,
|
|
|
|
|
+ device const char * src1,
|
|
|
|
|
+ device uchar * dst,
|
|
|
|
|
+ constant int64_t & nbi1,
|
|
|
|
|
+ constant int64_t & ne00,
|
|
|
|
|
+ constant int64_t & ne01,
|
|
|
|
|
+ constant int64_t & ne02,
|
|
|
|
|
+ constant uint64_t & nb00,
|
|
|
|
|
+ constant uint64_t & nb01,
|
|
|
|
|
+ constant uint64_t & nb02,
|
|
|
|
|
+ constant int64_t & ne10,
|
|
|
|
|
+ constant int64_t & ne11,
|
|
|
|
|
+ constant int64_t & ne12,
|
|
|
|
|
+ constant int64_t & ne13,
|
|
|
|
|
+ constant uint64_t & nb10,
|
|
|
|
|
+ constant uint64_t & nb11,
|
|
|
|
|
+ constant uint64_t & nb12,
|
|
|
|
|
+ constant int64_t & ne0,
|
|
|
|
|
+ constant int64_t & ne1,
|
|
|
|
|
+ constant int64_t & nb1,
|
|
|
|
|
+ constant uint & r2,
|
|
|
|
|
+ constant uint & r3,
|
|
|
|
|
+ constant int & idx,
|
|
|
|
|
+ device const char * src00,
|
|
|
|
|
+ device const char * src01,
|
|
|
|
|
+ device const char * src02,
|
|
|
|
|
+ device const char * src03,
|
|
|
|
|
+ device const char * src04,
|
|
|
|
|
+ device const char * src05,
|
|
|
|
|
+ device const char * src06,
|
|
|
|
|
+ device const char * src07,
|
|
|
|
|
+ uint3 tgpig[[threadgroup_position_in_grid]],
|
|
|
|
|
+ uint tiitg[[thread_index_in_threadgroup]],
|
|
|
|
|
+ uint tiisg[[thread_index_in_simdgroup]],
|
|
|
|
|
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
|
|
|
|
+ device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
|
|
|
|
|
+
|
|
|
|
|
+ const int64_t bid = tgpig.z/(ne12*ne13);
|
|
|
|
|
+
|
|
|
|
|
+ tgpig.z = tgpig.z%(ne12*ne13);
|
|
|
|
|
+
|
|
|
|
|
+ const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
|
|
|
|
|
+
|
|
|
|
|
+ mul_vec_q_n_f32_impl<block_q5_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(
|
|
|
|
|
+ src0[id],
|
|
|
|
|
+ (device const float *) (src1 + bid*nb11),
|
|
|
|
|
+ (device float *) ( dst + bid*nb1),
|
|
|
|
|
+ ne00,
|
|
|
|
|
+ ne01,
|
|
|
|
|
+ ne02,
|
|
|
|
|
+ ne10,
|
|
|
|
|
+ ne12,
|
|
|
|
|
+ ne0,
|
|
|
|
|
+ ne1,
|
|
|
|
|
+ r2,
|
|
|
|
|
+ r3,
|
|
|
|
|
+ tgpig,
|
|
|
|
|
+ tiisg,
|
|
|
|
|
+ sgitg);
|
|
|
|
|
+}
|
|
|
|
|
+
|
|
|
|
|
+[[host_name("kernel_mul_mv_id_q5_1_f32")]]
|
|
|
|
|
+kernel void kernel_mul_mv_id_q5_1_f32(
|
|
|
|
|
+ device const char * ids,
|
|
|
|
|
+ device const char * src1,
|
|
|
|
|
+ device uchar * dst,
|
|
|
|
|
+ constant int64_t & nbi1,
|
|
|
|
|
+ constant int64_t & ne00,
|
|
|
|
|
+ constant int64_t & ne01,
|
|
|
|
|
+ constant int64_t & ne02,
|
|
|
|
|
+ constant uint64_t & nb00,
|
|
|
|
|
+ constant uint64_t & nb01,
|
|
|
|
|
+ constant uint64_t & nb02,
|
|
|
|
|
+ constant int64_t & ne10,
|
|
|
|
|
+ constant int64_t & ne11,
|
|
|
|
|
+ constant int64_t & ne12,
|
|
|
|
|
+ constant int64_t & ne13,
|
|
|
|
|
+ constant uint64_t & nb10,
|
|
|
|
|
+ constant uint64_t & nb11,
|
|
|
|
|
+ constant uint64_t & nb12,
|
|
|
|
|
+ constant int64_t & ne0,
|
|
|
|
|
+ constant int64_t & ne1,
|
|
|
|
|
+ constant int64_t & nb1,
|
|
|
|
|
+ constant uint & r2,
|
|
|
|
|
+ constant uint & r3,
|
|
|
|
|
+ constant int & idx,
|
|
|
|
|
+ device const char * src00,
|
|
|
|
|
+ device const char * src01,
|
|
|
|
|
+ device const char * src02,
|
|
|
|
|
+ device const char * src03,
|
|
|
|
|
+ device const char * src04,
|
|
|
|
|
+ device const char * src05,
|
|
|
|
|
+ device const char * src06,
|
|
|
|
|
+ device const char * src07,
|
|
|
|
|
+ uint3 tgpig[[threadgroup_position_in_grid]],
|
|
|
|
|
+ uint tiitg[[thread_index_in_threadgroup]],
|
|
|
|
|
+ uint tiisg[[thread_index_in_simdgroup]],
|
|
|
|
|
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
|
|
|
|
+ device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
|
|
|
|
|
+
|
|
|
|
|
+ const int64_t bid = tgpig.z/(ne12*ne13);
|
|
|
|
|
+
|
|
|
|
|
+ tgpig.z = tgpig.z%(ne12*ne13);
|
|
|
|
|
+
|
|
|
|
|
+ const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
|
|
|
|
|
+
|
|
|
|
|
+ mul_vec_q_n_f32_impl<block_q5_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(
|
|
|
|
|
+ src0[id],
|
|
|
|
|
+ (device const float *) (src1 + bid*nb11),
|
|
|
|
|
+ (device float *) ( dst + bid*nb1),
|
|
|
|
|
+ ne00,
|
|
|
|
|
+ ne01,
|
|
|
|
|
+ ne02,
|
|
|
|
|
+ ne10,
|
|
|
|
|
+ ne12,
|
|
|
|
|
+ ne0,
|
|
|
|
|
+ ne1,
|
|
|
|
|
+ r2,
|
|
|
|
|
+ r3,
|
|
|
|
|
+ tgpig,
|
|
|
|
|
+ tiisg,
|
|
|
|
|
+ sgitg);
|
|
|
|
|
+}
|
|
|
|
|
+
|
|
|
|
|
+[[host_name("kernel_mul_mv_id_q2_K_f32")]]
|
|
|
|
|
// Id-selected matrix-vector multiply wrapper for the Q2_K implementation.
//
// Reads an int32 id from `ids`, uses it to pick one of the eight src0 buffers
// (src00..src07) and forwards the work to kernel_mul_mv_q2_K_f32_impl.
//
//   ids  - raw bytes; the id is the int32 at index `idx` of the array that
//          starts bid*nbi1 bytes into this buffer
//   src1 - raw bytes; advanced by bid*nb11 bytes and passed on as float data
//   dst  - raw bytes; advanced by bid*nb1 bytes and passed on as float data
//
// where bid = tgpig.z / (ne12*ne13). The implementation receives the
// threadgroup position with z replaced by tgpig.z % (ne12*ne13).
//
// nb00, nb01, nb02, ne11, nb10, nb12 and tiitg are not referenced in this body;
// they stay in the signature because they are part of the kernel's argument list.
kernel void kernel_mul_mv_id_q2_K_f32(
        device const    char * ids,
        device const    char * src1,
        device         uchar * dst,
        constant     int64_t & nbi1,
        constant     int64_t & ne00,
        constant     int64_t & ne01,
        constant     int64_t & ne02,
        constant    uint64_t & nb00,
        constant    uint64_t & nb01,
        constant    uint64_t & nb02,
        constant     int64_t & ne10,
        constant     int64_t & ne11,
        constant     int64_t & ne12,
        constant     int64_t & ne13,
        constant    uint64_t & nb10,
        constant    uint64_t & nb11,
        constant    uint64_t & nb12,
        constant     int64_t & ne0,
        constant     int64_t & ne1,
        constant     int64_t & nb1,
        constant        uint & r2,
        constant        uint & r3,
        constant         int & idx,
        device const    char * src00,
        device const    char * src01,
        device const    char * src02,
        device const    char * src03,
        device const    char * src04,
        device const    char * src05,
        device const    char * src06,
        device const    char * src07,
        uint3                  tgpig[[threadgroup_position_in_grid]],
        uint                   tiitg[[thread_index_in_threadgroup]],
        uint                   tiisg[[thread_index_in_simdgroup]],
        uint                   sgitg[[simdgroup_index_in_threadgroup]]) {
    // Lookup table of the eight candidate src0 buffers.
    device const char * src0_by_id[8] = {
        src00, src01, src02, src03,
        src04, src05, src06, src07
    };

    // Split the z coordinate: quotient -> bid, remainder -> z seen by the impl.
    const int64_t z_span = ne12*ne13;
    const int64_t bid    = tgpig.z/z_span;

    uint3 tg_pos = tgpig;
    tg_pos.z = tgpig.z%z_span;

    // NOTE(review): the cast drops the const qualifier of `ids`; the data is only read.
    device int32_t * id_list = (device int32_t *) (ids + bid*nbi1);
    // NOTE(review): `id` is used unchecked as an index into the 8-entry table.
    const int32_t id = id_list[idx];

    // Byte offsets are applied before reinterpreting the buffers as float.
    device const float * src1_f32 = (device const float *) (src1 + bid*nb11);
    device       float * dst_f32  = (device       float *) ( dst + bid*nb1);

    kernel_mul_mv_q2_K_f32_impl(
        src0_by_id[id],
        src1_f32,
        dst_f32,
        ne00, ne01, ne02,
        ne10, ne12,
        ne0, ne1,
        r2, r3,
        tg_pos, tiisg, sgitg);
}
|
|
|
|
|
+
|
|
|
|
|
// Id-selected matrix-vector multiply wrapper for the Q3_K implementation.
//
// Reads an int32 id from `ids`, uses it to pick one of the eight src0 buffers
// (src00..src07) and forwards the work to kernel_mul_mv_q3_K_f32_impl.
//
//   ids  - raw bytes; the id is the int32 at index `idx` of the array that
//          starts bid*nbi1 bytes into this buffer
//   src1 - raw bytes; advanced by bid*nb11 bytes and passed on as float data
//   dst  - raw bytes; advanced by bid*nb1 bytes and passed on as float data
//
// where bid = tgpig.z / (ne12*ne13). The implementation receives the
// threadgroup position with z replaced by tgpig.z % (ne12*ne13).
//
// nb00, nb01, nb02, ne11, nb10, nb12 and tiitg are not referenced in this body;
// they stay in the signature because they are part of the kernel's argument list.
[[host_name("kernel_mul_mv_id_q3_K_f32")]]
kernel void kernel_mul_mv_id_q3_K_f32(
        device const    char * ids,
        device const    char * src1,
        device         uchar * dst,
        constant     int64_t & nbi1,
        constant     int64_t & ne00,
        constant     int64_t & ne01,
        constant     int64_t & ne02,
        constant    uint64_t & nb00,
        constant    uint64_t & nb01,
        constant    uint64_t & nb02,
        constant     int64_t & ne10,
        constant     int64_t & ne11,
        constant     int64_t & ne12,
        constant     int64_t & ne13,
        constant    uint64_t & nb10,
        constant    uint64_t & nb11,
        constant    uint64_t & nb12,
        constant     int64_t & ne0,
        constant     int64_t & ne1,
        constant     int64_t & nb1,
        constant        uint & r2,
        constant        uint & r3,
        constant         int & idx,
        device const    char * src00,
        device const    char * src01,
        device const    char * src02,
        device const    char * src03,
        device const    char * src04,
        device const    char * src05,
        device const    char * src06,
        device const    char * src07,
        uint3                  tgpig[[threadgroup_position_in_grid]],
        uint                   tiitg[[thread_index_in_threadgroup]],
        uint                   tiisg[[thread_index_in_simdgroup]],
        uint                   sgitg[[simdgroup_index_in_threadgroup]]) {
    // Lookup table of the eight candidate src0 buffers.
    device const char * src0_by_id[8] = {
        src00, src01, src02, src03,
        src04, src05, src06, src07
    };

    // Split the z coordinate: quotient -> bid, remainder -> z seen by the impl.
    const int64_t z_span = ne12*ne13;
    const int64_t bid    = tgpig.z/z_span;

    uint3 tg_pos = tgpig;
    tg_pos.z = tgpig.z%z_span;

    // NOTE(review): the cast drops the const qualifier of `ids`; the data is only read.
    device int32_t * id_list = (device int32_t *) (ids + bid*nbi1);
    // NOTE(review): `id` is used unchecked as an index into the 8-entry table.
    const int32_t id = id_list[idx];

    // Byte offsets are applied before reinterpreting the buffers as float.
    device const float * src1_f32 = (device const float *) (src1 + bid*nb11);
    device       float * dst_f32  = (device       float *) ( dst + bid*nb1);

    kernel_mul_mv_q3_K_f32_impl(
        src0_by_id[id],
        src1_f32,
        dst_f32,
        ne00, ne01, ne02,
        ne10, ne12,
        ne0, ne1,
        r2, r3,
        tg_pos, tiisg, sgitg);
}
|
|
|
|
|
+
|
|
|
|
|
// Id-selected matrix-vector multiply wrapper for the Q4_K implementation.
//
// Reads an int32 id from `ids`, uses it to pick one of the eight src0 buffers
// (src00..src07) and forwards the work to kernel_mul_mv_q4_K_f32_impl.
//
//   ids  - raw bytes; the id is the int32 at index `idx` of the array that
//          starts bid*nbi1 bytes into this buffer
//   src1 - raw bytes; advanced by bid*nb11 bytes and passed on as float data
//   dst  - raw bytes; advanced by bid*nb1 bytes and passed on as float data
//
// where bid = tgpig.z / (ne12*ne13). The implementation receives the
// threadgroup position with z replaced by tgpig.z % (ne12*ne13).
//
// nb00, nb01, nb02, ne11, nb10, nb12 and tiitg are not referenced in this body;
// they stay in the signature because they are part of the kernel's argument list.
[[host_name("kernel_mul_mv_id_q4_K_f32")]]
kernel void kernel_mul_mv_id_q4_K_f32(
        device const    char * ids,
        device const    char * src1,
        device         uchar * dst,
        constant     int64_t & nbi1,
        constant     int64_t & ne00,
        constant     int64_t & ne01,
        constant     int64_t & ne02,
        constant    uint64_t & nb00,
        constant    uint64_t & nb01,
        constant    uint64_t & nb02,
        constant     int64_t & ne10,
        constant     int64_t & ne11,
        constant     int64_t & ne12,
        constant     int64_t & ne13,
        constant    uint64_t & nb10,
        constant    uint64_t & nb11,
        constant    uint64_t & nb12,
        constant     int64_t & ne0,
        constant     int64_t & ne1,
        constant     int64_t & nb1,
        constant        uint & r2,
        constant        uint & r3,
        constant         int & idx,
        device const    char * src00,
        device const    char * src01,
        device const    char * src02,
        device const    char * src03,
        device const    char * src04,
        device const    char * src05,
        device const    char * src06,
        device const    char * src07,
        uint3                  tgpig[[threadgroup_position_in_grid]],
        uint                   tiitg[[thread_index_in_threadgroup]],
        uint                   tiisg[[thread_index_in_simdgroup]],
        uint                   sgitg[[simdgroup_index_in_threadgroup]]) {
    // Lookup table of the eight candidate src0 buffers.
    device const char * src0_by_id[8] = {
        src00, src01, src02, src03,
        src04, src05, src06, src07
    };

    // Split the z coordinate: quotient -> bid, remainder -> z seen by the impl.
    const int64_t z_span = ne12*ne13;
    const int64_t bid    = tgpig.z/z_span;

    uint3 tg_pos = tgpig;
    tg_pos.z = tgpig.z%z_span;

    // NOTE(review): the cast drops the const qualifier of `ids`; the data is only read.
    device int32_t * id_list = (device int32_t *) (ids + bid*nbi1);
    // NOTE(review): `id` is used unchecked as an index into the 8-entry table.
    const int32_t id = id_list[idx];

    // Byte offsets are applied before reinterpreting the buffers as float.
    device const float * src1_f32 = (device const float *) (src1 + bid*nb11);
    device       float * dst_f32  = (device       float *) ( dst + bid*nb1);

    kernel_mul_mv_q4_K_f32_impl(
        src0_by_id[id],
        src1_f32,
        dst_f32,
        ne00, ne01, ne02,
        ne10, ne12,
        ne0, ne1,
        r2, r3,
        tg_pos, tiisg, sgitg);
}
|
|
|
|
|
+
|
|
|
|
|
// Id-selected matrix-vector multiply wrapper for the Q5_K implementation.
//
// Reads an int32 id from `ids`, uses it to pick one of the eight src0 buffers
// (src00..src07) and forwards the work to kernel_mul_mv_q5_K_f32_impl.
//
//   ids  - raw bytes; the id is the int32 at index `idx` of the array that
//          starts bid*nbi1 bytes into this buffer
//   src1 - raw bytes; advanced by bid*nb11 bytes and passed on as float data
//   dst  - raw bytes; advanced by bid*nb1 bytes and passed on as float data
//
// where bid = tgpig.z / (ne12*ne13). The implementation receives the
// threadgroup position with z replaced by tgpig.z % (ne12*ne13).
//
// nb00, nb01, nb02, ne11, nb10, nb12 and tiitg are not referenced in this body;
// they stay in the signature because they are part of the kernel's argument list.
[[host_name("kernel_mul_mv_id_q5_K_f32")]]
kernel void kernel_mul_mv_id_q5_K_f32(
        device const    char * ids,
        device const    char * src1,
        device         uchar * dst,
        constant     int64_t & nbi1,
        constant     int64_t & ne00,
        constant     int64_t & ne01,
        constant     int64_t & ne02,
        constant    uint64_t & nb00,
        constant    uint64_t & nb01,
        constant    uint64_t & nb02,
        constant     int64_t & ne10,
        constant     int64_t & ne11,
        constant     int64_t & ne12,
        constant     int64_t & ne13,
        constant    uint64_t & nb10,
        constant    uint64_t & nb11,
        constant    uint64_t & nb12,
        constant     int64_t & ne0,
        constant     int64_t & ne1,
        constant     int64_t & nb1,
        constant        uint & r2,
        constant        uint & r3,
        constant         int & idx,
        device const    char * src00,
        device const    char * src01,
        device const    char * src02,
        device const    char * src03,
        device const    char * src04,
        device const    char * src05,
        device const    char * src06,
        device const    char * src07,
        uint3                  tgpig[[threadgroup_position_in_grid]],
        uint                   tiitg[[thread_index_in_threadgroup]],
        uint                   tiisg[[thread_index_in_simdgroup]],
        uint                   sgitg[[simdgroup_index_in_threadgroup]]) {
    // Lookup table of the eight candidate src0 buffers.
    device const char * src0_by_id[8] = {
        src00, src01, src02, src03,
        src04, src05, src06, src07
    };

    // Split the z coordinate: quotient -> bid, remainder -> z seen by the impl.
    const int64_t z_span = ne12*ne13;
    const int64_t bid    = tgpig.z/z_span;

    uint3 tg_pos = tgpig;
    tg_pos.z = tgpig.z%z_span;

    // NOTE(review): the cast drops the const qualifier of `ids`; the data is only read.
    device int32_t * id_list = (device int32_t *) (ids + bid*nbi1);
    // NOTE(review): `id` is used unchecked as an index into the 8-entry table.
    const int32_t id = id_list[idx];

    // Byte offsets are applied before reinterpreting the buffers as float.
    device const float * src1_f32 = (device const float *) (src1 + bid*nb11);
    device       float * dst_f32  = (device       float *) ( dst + bid*nb1);

    kernel_mul_mv_q5_K_f32_impl(
        src0_by_id[id],
        src1_f32,
        dst_f32,
        ne00, ne01, ne02,
        ne10, ne12,
        ne0, ne1,
        r2, r3,
        tg_pos, tiisg, sgitg);
}
|
|
|
|
|
+
|
|
|
|
|
// Id-selected matrix-vector multiply wrapper for the Q6_K implementation.
//
// Reads an int32 id from `ids`, uses it to pick one of the eight src0 buffers
// (src00..src07) and forwards the work to kernel_mul_mv_q6_K_f32_impl.
//
//   ids  - raw bytes; the id is the int32 at index `idx` of the array that
//          starts bid*nbi1 bytes into this buffer
//   src1 - raw bytes; advanced by bid*nb11 bytes and passed on as float data
//   dst  - raw bytes; advanced by bid*nb1 bytes and passed on as float data
//
// where bid = tgpig.z / (ne12*ne13). The implementation receives the
// threadgroup position with z replaced by tgpig.z % (ne12*ne13).
//
// nb00, nb01, nb02, ne11, nb10, nb12 and tiitg are not referenced in this body;
// they stay in the signature because they are part of the kernel's argument list.
[[host_name("kernel_mul_mv_id_q6_K_f32")]]
kernel void kernel_mul_mv_id_q6_K_f32(
        device const    char * ids,
        device const    char * src1,
        device         uchar * dst,
        constant     int64_t & nbi1,
        constant     int64_t & ne00,
        constant     int64_t & ne01,
        constant     int64_t & ne02,
        constant    uint64_t & nb00,
        constant    uint64_t & nb01,
        constant    uint64_t & nb02,
        constant     int64_t & ne10,
        constant     int64_t & ne11,
        constant     int64_t & ne12,
        constant     int64_t & ne13,
        constant    uint64_t & nb10,
        constant    uint64_t & nb11,
        constant    uint64_t & nb12,
        constant     int64_t & ne0,
        constant     int64_t & ne1,
        constant     int64_t & nb1,
        constant        uint & r2,
        constant        uint & r3,
        constant         int & idx,
        device const    char * src00,
        device const    char * src01,
        device const    char * src02,
        device const    char * src03,
        device const    char * src04,
        device const    char * src05,
        device const    char * src06,
        device const    char * src07,
        uint3                  tgpig[[threadgroup_position_in_grid]],
        uint                   tiitg[[thread_index_in_threadgroup]],
        uint                   tiisg[[thread_index_in_simdgroup]],
        uint                   sgitg[[simdgroup_index_in_threadgroup]]) {
    // Lookup table of the eight candidate src0 buffers.
    device const char * src0_by_id[8] = {
        src00, src01, src02, src03,
        src04, src05, src06, src07
    };

    // Split the z coordinate: quotient -> bid, remainder -> z seen by the impl.
    const int64_t z_span = ne12*ne13;
    const int64_t bid    = tgpig.z/z_span;

    uint3 tg_pos = tgpig;
    tg_pos.z = tgpig.z%z_span;

    // NOTE(review): the cast drops the const qualifier of `ids`; the data is only read.
    device int32_t * id_list = (device int32_t *) (ids + bid*nbi1);
    // NOTE(review): `id` is used unchecked as an index into the 8-entry table.
    const int32_t id = id_list[idx];

    // Byte offsets are applied before reinterpreting the buffers as float.
    device const float * src1_f32 = (device const float *) (src1 + bid*nb11);
    device       float * dst_f32  = (device       float *) ( dst + bid*nb1);

    kernel_mul_mv_q6_K_f32_impl(
        src0_by_id[id],
        src1_f32,
        dst_f32,
        ne00, ne01, ne02,
        ne10, ne12,
        ne0, ne1,
        r2, r3,
        tg_pos, tiisg, sgitg);
}
|