|
|
@@ -857,15 +857,16 @@ void mul_vec_q_n_f32_impl(
|
|
|
device const void * src0,
|
|
|
device const float * src1,
|
|
|
device float * dst,
|
|
|
- int64_t ne00,
|
|
|
- int64_t ne01,
|
|
|
- int64_t ne02,
|
|
|
- int64_t ne10,
|
|
|
- int64_t ne12,
|
|
|
- int64_t ne0,
|
|
|
- int64_t ne1,
|
|
|
- uint r2,
|
|
|
- uint r3,
|
|
|
+ constant int64_t & ne00,
|
|
|
+ constant int64_t & ne01,
|
|
|
+ constant int64_t & ne02,
|
|
|
+ constant int64_t & ne10,
|
|
|
+ constant int64_t & ne12,
|
|
|
+ constant int64_t & ne0,
|
|
|
+ constant int64_t & ne1,
|
|
|
+ constant uint & r2,
|
|
|
+ constant uint & r3,
|
|
|
+ threadgroup int8_t * shared_values,
|
|
|
uint3 tgpig, uint tiisg, uint sgitg) {
|
|
|
const int nb = ne00/QK4_0;
|
|
|
|
|
|
@@ -942,7 +943,7 @@ kernel void kernel_mul_mv_q4_0_f32(
|
|
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
|
uint tiisg[[thread_index_in_simdgroup]],
|
|
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
|
|
- mul_vec_q_n_f32_impl<block_q4_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg);
|
|
|
+ mul_vec_q_n_f32_impl<block_q4_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,nullptr,tgpig,tiisg,sgitg);
|
|
|
}
|
|
|
|
|
|
kernel void kernel_mul_mv_q4_1_f32(
|
|
|
@@ -968,7 +969,7 @@ kernel void kernel_mul_mv_q4_1_f32(
|
|
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
|
uint tiisg[[thread_index_in_simdgroup]],
|
|
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
|
|
- mul_vec_q_n_f32_impl<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg);
|
|
|
+ mul_vec_q_n_f32_impl<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,nullptr,tgpig,tiisg,sgitg);
|
|
|
}
|
|
|
|
|
|
kernel void kernel_mul_mv_q5_0_f32(
|
|
|
@@ -994,7 +995,7 @@ kernel void kernel_mul_mv_q5_0_f32(
|
|
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
|
uint tiisg[[thread_index_in_simdgroup]],
|
|
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
|
|
- mul_vec_q_n_f32_impl<block_q5_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg);
|
|
|
+ mul_vec_q_n_f32_impl<block_q5_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,nullptr,tgpig,tiisg,sgitg);
|
|
|
}
|
|
|
|
|
|
kernel void kernel_mul_mv_q5_1_f32(
|
|
|
@@ -1020,7 +1021,7 @@ kernel void kernel_mul_mv_q5_1_f32(
|
|
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
|
uint tiisg[[thread_index_in_simdgroup]],
|
|
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
|
|
- mul_vec_q_n_f32_impl<block_q5_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg);
|
|
|
+ mul_vec_q_n_f32_impl<block_q5_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,nullptr,tgpig,tiisg,sgitg);
|
|
|
}
|
|
|
|
|
|
|
|
|
@@ -1039,6 +1040,7 @@ void kernel_mul_mv_q8_0_f32_impl(
|
|
|
constant int64_t & ne1,
|
|
|
constant uint & r2,
|
|
|
constant uint & r3,
|
|
|
+ threadgroup int8_t * shared_values [[threadgroup(0)]],
|
|
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
|
uint tiisg[[thread_index_in_simdgroup]],
|
|
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
|
|
@@ -1119,7 +1121,7 @@ kernel void kernel_mul_mv_q8_0_f32(
|
|
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
|
uint tiisg[[thread_index_in_simdgroup]],
|
|
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
|
|
- kernel_mul_mv_q8_0_f32_impl(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg);
|
|
|
+ kernel_mul_mv_q8_0_f32_impl(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,nullptr,tgpig,tiisg,sgitg);
|
|
|
}
|
|
|
|
|
|
#define N_F32_F32 4
|
|
|
@@ -2709,6 +2711,7 @@ void kernel_mul_mv_q2_K_f32_impl(
|
|
|
constant int64_t & ne1,
|
|
|
constant uint & r2,
|
|
|
constant uint & r3,
|
|
|
+ threadgroup int8_t * shared_values [[threadgroup(0)]],
|
|
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
|
uint tiisg[[thread_index_in_simdgroup]],
|
|
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
|
|
@@ -2871,7 +2874,7 @@ kernel void kernel_mul_mv_q2_K_f32(
|
|
|
uint tiisg[[thread_index_in_simdgroup]],
|
|
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
|
|
|
|
|
- kernel_mul_mv_q2_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, tgpig, tiisg, sgitg);
|
|
|
+ kernel_mul_mv_q2_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg);
|
|
|
}
|
|
|
|
|
|
#if QK_K == 256
|
|
|
@@ -2888,6 +2891,7 @@ void kernel_mul_mv_q3_K_f32_impl(
|
|
|
constant int64_t & ne1,
|
|
|
constant uint & r2,
|
|
|
constant uint & r3,
|
|
|
+ threadgroup int8_t * shared_values [[threadgroup(0)]],
|
|
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
|
uint tiisg[[thread_index_in_simdgroup]],
|
|
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
|
|
@@ -3046,6 +3050,7 @@ void kernel_mul_mv_q3_K_f32_impl(
|
|
|
constant int64_t & ne1,
|
|
|
constant uint & r2,
|
|
|
constant uint & r3,
|
|
|
+ threadgroup int8_t * shared_values [[threadgroup(0)]],
|
|
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
|
uint tiisg[[thread_index_in_simdgroup]],
|
|
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
|
|
@@ -3135,7 +3140,7 @@ kernel void kernel_mul_mv_q3_K_f32(
|
|
|
uint tiisg[[thread_index_in_simdgroup]],
|
|
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
|
|
|
|
|
- kernel_mul_mv_q3_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, tgpig, tiisg, sgitg);
|
|
|
+ kernel_mul_mv_q3_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg);
|
|
|
}
|
|
|
|
|
|
#if QK_K == 256
|
|
|
@@ -3152,6 +3157,7 @@ void kernel_mul_mv_q4_K_f32_impl(
|
|
|
constant int64_t & ne1,
|
|
|
constant uint & r2,
|
|
|
constant uint & r3,
|
|
|
+ threadgroup int8_t * shared_values [[threadgroup(0)]],
|
|
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
|
uint tiisg[[thread_index_in_simdgroup]],
|
|
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
|
|
@@ -3265,6 +3271,7 @@ void kernel_mul_mv_q4_K_f32_impl(
|
|
|
constant int64_t & ne1,
|
|
|
constant uint & r2,
|
|
|
constant uint & r3,
|
|
|
+ threadgroup int8_t * shared_values [[threadgroup(0)]],
|
|
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
|
uint tiisg[[thread_index_in_simdgroup]],
|
|
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
|
|
@@ -3373,7 +3380,7 @@ kernel void kernel_mul_mv_q4_K_f32(
|
|
|
uint tiisg[[thread_index_in_simdgroup]],
|
|
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
|
|
|
|
|
- kernel_mul_mv_q4_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, tgpig, tiisg, sgitg);
|
|
|
+ kernel_mul_mv_q4_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg);
|
|
|
}
|
|
|
|
|
|
void kernel_mul_mv_q5_K_f32_impl(
|
|
|
@@ -3389,6 +3396,7 @@ void kernel_mul_mv_q5_K_f32_impl(
|
|
|
constant int64_t & ne1,
|
|
|
constant uint & r2,
|
|
|
constant uint & r3,
|
|
|
+ threadgroup int8_t * shared_values [[threadgroup(0)]],
|
|
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
|
uint tiisg[[thread_index_in_simdgroup]],
|
|
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
|
|
@@ -3579,7 +3587,7 @@ kernel void kernel_mul_mv_q5_K_f32(
|
|
|
uint tiisg[[thread_index_in_simdgroup]],
|
|
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
|
|
|
|
|
- kernel_mul_mv_q5_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, tgpig, tiisg, sgitg);
|
|
|
+ kernel_mul_mv_q5_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg);
|
|
|
}
|
|
|
|
|
|
void kernel_mul_mv_q6_K_f32_impl(
|
|
|
@@ -3595,6 +3603,7 @@ void kernel_mul_mv_q6_K_f32_impl(
|
|
|
constant int64_t & ne1,
|
|
|
constant uint & r2,
|
|
|
constant uint & r3,
|
|
|
+ threadgroup int8_t * shared_values [[threadgroup(0)]],
|
|
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
|
uint tiisg[[thread_index_in_simdgroup]],
|
|
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
|
|
@@ -3713,7 +3722,7 @@ kernel void kernel_mul_mv_q6_K_f32(
|
|
|
uint tiisg[[thread_index_in_simdgroup]],
|
|
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
|
|
|
|
|
- kernel_mul_mv_q6_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, tgpig, tiisg, sgitg);
|
|
|
+ kernel_mul_mv_q6_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg);
|
|
|
}
|
|
|
|
|
|
// ======================= "True" 2-bit
|
|
|
@@ -4396,6 +4405,7 @@ void kernel_mul_mv_iq1_s_f32_impl(
|
|
|
constant int64_t & ne1,
|
|
|
constant uint & r2,
|
|
|
constant uint & r3,
|
|
|
+ threadgroup int8_t * shared_values [[threadgroup(0)]],
|
|
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
|
uint tiisg[[thread_index_in_simdgroup]],
|
|
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
|
|
@@ -4485,6 +4495,7 @@ void kernel_mul_mv_iq1_m_f32_impl(
|
|
|
constant int64_t & ne1,
|
|
|
constant uint & r2,
|
|
|
constant uint & r3,
|
|
|
+ threadgroup int8_t * shared_values [[threadgroup(0)]],
|
|
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
|
uint tiisg[[thread_index_in_simdgroup]],
|
|
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
|
|
@@ -4593,11 +4604,12 @@ void kernel_mul_mv_iq4_nl_f32_impl(
|
|
|
constant int64_t & ne1,
|
|
|
constant uint & r2,
|
|
|
constant uint & r3,
|
|
|
- threadgroup float * shared_values [[threadgroup(0)]],
|
|
|
+ threadgroup int8_t * shared_values_i8 [[threadgroup(0)]],
|
|
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
|
uint tiisg[[thread_index_in_simdgroup]],
|
|
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
|
|
|
|
|
+ threadgroup float * shared_values = (threadgroup float *)shared_values_i8;
|
|
|
const int nb = ne00/QK4_NL;
|
|
|
const int r0 = tgpig.x;
|
|
|
const int r1 = tgpig.y;
|
|
|
@@ -4687,11 +4699,11 @@ void kernel_mul_mv_iq4_xs_f32_impl(
|
|
|
constant int64_t & ne1,
|
|
|
constant uint & r2,
|
|
|
constant uint & r3,
|
|
|
- threadgroup float * shared_values [[threadgroup(0)]],
|
|
|
+ threadgroup int8_t * shared_values_i8 [[threadgroup(0)]],
|
|
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
|
uint tiisg[[thread_index_in_simdgroup]],
|
|
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
|
|
-
|
|
|
+ threadgroup float * shared_values = (threadgroup float *)shared_values_i8;
|
|
|
const int nb = ne00/QK_K;
|
|
|
const int r0 = tgpig.x;
|
|
|
const int r1 = tgpig.y;
|
|
|
@@ -4794,7 +4806,7 @@ kernel void kernel_mul_mv_iq1_s_f32(
|
|
|
uint tiisg[[thread_index_in_simdgroup]],
|
|
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
|
|
|
|
|
- kernel_mul_mv_iq1_s_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, tgpig, tiisg, sgitg);
|
|
|
+ kernel_mul_mv_iq1_s_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg);
|
|
|
}
|
|
|
|
|
|
[[host_name("kernel_mul_mv_iq1_m_f32")]]
|
|
|
@@ -4822,7 +4834,7 @@ kernel void kernel_mul_mv_iq1_m_f32(
|
|
|
uint tiisg[[thread_index_in_simdgroup]],
|
|
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
|
|
|
|
|
- kernel_mul_mv_iq1_m_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, tgpig, tiisg, sgitg);
|
|
|
+ kernel_mul_mv_iq1_m_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg);
|
|
|
}
|
|
|
|
|
|
[[host_name("kernel_mul_mv_iq4_nl_f32")]]
|
|
|
@@ -4846,7 +4858,7 @@ kernel void kernel_mul_mv_iq4_nl_f32(
|
|
|
constant int64_t & ne1,
|
|
|
constant uint & r2,
|
|
|
constant uint & r3,
|
|
|
- threadgroup float * shared_values [[threadgroup(0)]],
|
|
|
+ threadgroup int8_t * shared_values [[threadgroup(0)]],
|
|
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
|
uint tiisg[[thread_index_in_simdgroup]],
|
|
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
|
|
@@ -4875,7 +4887,7 @@ kernel void kernel_mul_mv_iq4_xs_f32(
|
|
|
constant int64_t & ne1,
|
|
|
constant uint & r2,
|
|
|
constant uint & r3,
|
|
|
- threadgroup float * shared_values [[threadgroup(0)]],
|
|
|
+ threadgroup int8_t * shared_values [[threadgroup(0)]],
|
|
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
|
uint tiisg[[thread_index_in_simdgroup]],
|
|
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
|
|
@@ -6022,135 +6034,52 @@ template [[host_name("kernel_mul_mm_id_iq4_xs_f32")]] kernel mat_mm_id_t kernel
|
|
|
// matrix-vector multiplication
|
|
|
//
|
|
|
|
|
|
-[[host_name("kernel_mul_mv_id_f32_f32")]]
|
|
|
-kernel void kernel_mul_mv_id_f32_f32(
|
|
|
- device const char * src0s,
|
|
|
- device const char * src1,
|
|
|
- device float * dst,
|
|
|
- device const char * ids,
|
|
|
- constant uint64_t & nbi1,
|
|
|
- constant int64_t & ne00,
|
|
|
- constant int64_t & ne01,
|
|
|
- constant int64_t & ne02,
|
|
|
- constant uint64_t & nb00,
|
|
|
- constant uint64_t & nb01,
|
|
|
- constant uint64_t & nb02,
|
|
|
- constant int64_t & ne10,
|
|
|
- constant int64_t & ne11,
|
|
|
- constant int64_t & ne12,
|
|
|
- constant int64_t & ne13,
|
|
|
- constant uint64_t & nb10,
|
|
|
- constant uint64_t & nb11,
|
|
|
- constant uint64_t & nb12,
|
|
|
- constant int64_t & ne0,
|
|
|
- constant int64_t & ne1,
|
|
|
- constant uint64_t & nb1,
|
|
|
- constant uint & r2,
|
|
|
- constant uint & r3,
|
|
|
- constant int & idx,
|
|
|
- uint3 tgpig[[threadgroup_position_in_grid]],
|
|
|
- uint tiitg[[thread_index_in_threadgroup]],
|
|
|
- uint tiisg[[thread_index_in_simdgroup]],
|
|
|
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
|
|
- const int64_t bid = tgpig.z/(ne12*ne13);
|
|
|
-
|
|
|
- tgpig.z = tgpig.z%(ne12*ne13);
|
|
|
-
|
|
|
- const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
|
|
|
- device const char * src0 = src0s + id*nb02;
|
|
|
-
|
|
|
- kernel_mul_mv_f32_f32_impl(
|
|
|
- src0,
|
|
|
- src1 + bid*nb11,
|
|
|
- dst + bid*ne0,
|
|
|
- ne00,
|
|
|
- ne01,
|
|
|
- ne02,
|
|
|
- nb00,
|
|
|
- nb01,
|
|
|
- nb02,
|
|
|
- ne10,
|
|
|
- ne11,
|
|
|
- ne12,
|
|
|
- nb10,
|
|
|
- nb11,
|
|
|
- nb12,
|
|
|
- ne0,
|
|
|
- ne1,
|
|
|
- r2,
|
|
|
- r3,
|
|
|
- tgpig,
|
|
|
- tiisg);
|
|
|
-}
|
|
|
-
|
|
|
-[[host_name("kernel_mul_mv_id_f16_f32")]]
|
|
|
-kernel void kernel_mul_mv_id_f16_f32(
|
|
|
- device const char * src0s,
|
|
|
- device const char * src1,
|
|
|
- device float * dst,
|
|
|
- device const char * ids,
|
|
|
- constant uint64_t & nbi1,
|
|
|
- constant int64_t & ne00,
|
|
|
- constant int64_t & ne01,
|
|
|
- constant int64_t & ne02,
|
|
|
- constant uint64_t & nb00,
|
|
|
- constant uint64_t & nb01,
|
|
|
- constant uint64_t & nb02,
|
|
|
- constant int64_t & ne10,
|
|
|
- constant int64_t & ne11,
|
|
|
- constant int64_t & ne12,
|
|
|
- constant int64_t & ne13,
|
|
|
- constant uint64_t & nb10,
|
|
|
- constant uint64_t & nb11,
|
|
|
- constant uint64_t & nb12,
|
|
|
- constant int64_t & ne0,
|
|
|
- constant int64_t & ne1,
|
|
|
- constant uint64_t & nb1,
|
|
|
- constant uint & r2,
|
|
|
- constant uint & r3,
|
|
|
- constant int & idx,
|
|
|
- uint3 tgpig[[threadgroup_position_in_grid]],
|
|
|
- uint tiitg[[thread_index_in_threadgroup]],
|
|
|
- uint tiisg[[thread_index_in_simdgroup]],
|
|
|
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
|
|
- const int64_t bid = tgpig.z/(ne12*ne13);
|
|
|
-
|
|
|
- tgpig.z = tgpig.z%(ne12*ne13);
|
|
|
-
|
|
|
- const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
|
|
|
- device const char * src0 = src0s + id*nb02;
|
|
|
+typedef void (kernel_mul_mv_impl_t)(
|
|
|
+ device const char * src0,
|
|
|
+ device const char * src1,
|
|
|
+ device float * dst,
|
|
|
+ constant int64_t & ne00,
|
|
|
+ constant int64_t & ne01,
|
|
|
+ constant int64_t & ne02,
|
|
|
+ constant uint64_t & nb00,
|
|
|
+ constant uint64_t & nb01,
|
|
|
+ constant uint64_t & nb02,
|
|
|
+ constant int64_t & ne10,
|
|
|
+ constant int64_t & ne11,
|
|
|
+ constant int64_t & ne12,
|
|
|
+ constant uint64_t & nb10,
|
|
|
+ constant uint64_t & nb11,
|
|
|
+ constant uint64_t & nb12,
|
|
|
+ constant int64_t & ne0,
|
|
|
+ constant int64_t & ne1,
|
|
|
+ constant uint & r2,
|
|
|
+ constant uint & r3,
|
|
|
+ uint3 tgpig[[threadgroup_position_in_grid]],
|
|
|
+ uint tiisg[[thread_index_in_simdgroup]]);
|
|
|
|
|
|
- kernel_mul_mv_f16_f32_impl(
|
|
|
- src0,
|
|
|
- src1 + bid*nb11,
|
|
|
- dst + bid*ne0,
|
|
|
- ne00,
|
|
|
- ne01,
|
|
|
- ne02,
|
|
|
- nb00,
|
|
|
- nb01,
|
|
|
- nb02,
|
|
|
- ne10,
|
|
|
- ne11,
|
|
|
- ne12,
|
|
|
- nb10,
|
|
|
- nb11,
|
|
|
- nb12,
|
|
|
- ne0,
|
|
|
- ne1,
|
|
|
- r2,
|
|
|
- r3,
|
|
|
- tgpig,
|
|
|
- tiisg);
|
|
|
-}
|
|
|
+typedef void (kernel_mul_mv2_impl_t)(
|
|
|
+ device const void * src0,
|
|
|
+ device const float * src1,
|
|
|
+ device float * dst,
|
|
|
+ constant int64_t & ne00,
|
|
|
+ constant int64_t & ne01,
|
|
|
+ constant int64_t & ne02,
|
|
|
+ constant int64_t & ne10,
|
|
|
+ constant int64_t & ne12,
|
|
|
+ constant int64_t & ne0,
|
|
|
+ constant int64_t & ne1,
|
|
|
+ constant uint & r2,
|
|
|
+ constant uint & r3,
|
|
|
+ threadgroup int8_t * shared_values [[threadgroup(0)]],
|
|
|
+ uint3 tgpig[[threadgroup_position_in_grid]],
|
|
|
+ uint tiisg[[thread_index_in_simdgroup]],
|
|
|
+ uint sgitg[[simdgroup_index_in_threadgroup]]);
|
|
|
|
|
|
-[[host_name("kernel_mul_mv_id_q8_0_f32")]]
|
|
|
-kernel void kernel_mul_mv_id_q8_0_f32(
|
|
|
- device const char * src0s,
|
|
|
+template<kernel_mul_mv_impl_t impl_fn>
|
|
|
+void mmv_fn(
|
|
|
+ device const char * src0,
|
|
|
device const char * src1,
|
|
|
device float * dst,
|
|
|
- device const char * ids,
|
|
|
- constant uint64_t & nbi1,
|
|
|
constant int64_t & ne00,
|
|
|
constant int64_t & ne01,
|
|
|
constant int64_t & ne02,
|
|
|
@@ -6169,43 +6098,19 @@ kernel void kernel_mul_mv_id_q8_0_f32(
|
|
|
constant uint64_t & nb1,
|
|
|
constant uint & r2,
|
|
|
constant uint & r3,
|
|
|
- constant int & idx,
|
|
|
+ threadgroup int8_t * shared_values [[threadgroup(0)]],
|
|
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
|
uint tiitg[[thread_index_in_threadgroup]],
|
|
|
uint tiisg[[thread_index_in_simdgroup]],
|
|
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
|
|
- const int64_t bid = tgpig.z/(ne12*ne13);
|
|
|
-
|
|
|
- tgpig.z = tgpig.z%(ne12*ne13);
|
|
|
-
|
|
|
- const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
|
|
|
- device const char * src0 = src0s + id*nb02;
|
|
|
-
|
|
|
- kernel_mul_mv_q8_0_f32_impl(
|
|
|
- src0,
|
|
|
- (device const float *) (src1 + bid*nb11),
|
|
|
- dst + bid*ne0,
|
|
|
- ne00,
|
|
|
- ne01,
|
|
|
- ne02,
|
|
|
- ne10,
|
|
|
- ne12,
|
|
|
- ne0,
|
|
|
- ne1,
|
|
|
- r2,
|
|
|
- r3,
|
|
|
- tgpig,
|
|
|
- tiisg,
|
|
|
- sgitg);
|
|
|
+ impl_fn(src0,src1,dst,ne00,ne01,ne02,nb00,nb01,nb02,ne10,ne11,ne12,nb10,nb11,nb12,ne0,ne1,r2,r3,tgpig,tiisg);
|
|
|
}
|
|
|
|
|
|
-[[host_name("kernel_mul_mv_id_q4_0_f32")]]
|
|
|
-kernel void kernel_mul_mv_id_q4_0_f32(
|
|
|
- device const char * src0s,
|
|
|
+template<kernel_mul_mv2_impl_t impl_fn>
|
|
|
+void mmv_fn(
|
|
|
+ device const char * src0,
|
|
|
device const char * src1,
|
|
|
device float * dst,
|
|
|
- device const char * ids,
|
|
|
- constant uint64_t & nbi1,
|
|
|
constant int64_t & ne00,
|
|
|
constant int64_t & ne01,
|
|
|
constant int64_t & ne02,
|
|
|
@@ -6224,43 +6129,18 @@ kernel void kernel_mul_mv_id_q4_0_f32(
|
|
|
constant uint64_t & nb1,
|
|
|
constant uint & r2,
|
|
|
constant uint & r3,
|
|
|
- constant int & idx,
|
|
|
+ threadgroup int8_t * shared_values [[threadgroup(0)]],
|
|
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
|
uint tiitg[[thread_index_in_threadgroup]],
|
|
|
uint tiisg[[thread_index_in_simdgroup]],
|
|
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
|
|
- const int64_t bid = tgpig.z/(ne12*ne13);
|
|
|
-
|
|
|
- tgpig.z = tgpig.z%(ne12*ne13);
|
|
|
-
|
|
|
- const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
|
|
|
- device const char * src0 = src0s + id*nb02;
|
|
|
-
|
|
|
- mul_vec_q_n_f32_impl<block_q4_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(
|
|
|
- src0,
|
|
|
- (device const float *) (src1 + bid*nb11),
|
|
|
- dst + bid*ne0,
|
|
|
- ne00,
|
|
|
- ne01,
|
|
|
- ne02,
|
|
|
- ne10,
|
|
|
- ne12,
|
|
|
- ne0,
|
|
|
- ne1,
|
|
|
- r2,
|
|
|
- r3,
|
|
|
- tgpig,
|
|
|
- tiisg,
|
|
|
- sgitg);
|
|
|
+ impl_fn(src0,(const device float *)src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,shared_values,tgpig,tiisg,sgitg);
|
|
|
}
|
|
|
|
|
|
-[[host_name("kernel_mul_mv_id_q4_1_f32")]]
|
|
|
-kernel void kernel_mul_mv_id_q4_1_f32(
|
|
|
- device const char * src0s,
|
|
|
+typedef void (mul_mv_impl_fn_t)(
|
|
|
+ device const char * src0,
|
|
|
device const char * src1,
|
|
|
device float * dst,
|
|
|
- device const char * ids,
|
|
|
- constant uint64_t & nbi1,
|
|
|
constant int64_t & ne00,
|
|
|
constant int64_t & ne01,
|
|
|
constant int64_t & ne02,
|
|
|
@@ -6279,38 +6159,14 @@ kernel void kernel_mul_mv_id_q4_1_f32(
|
|
|
constant uint64_t & nb1,
|
|
|
constant uint & r2,
|
|
|
constant uint & r3,
|
|
|
- constant int & idx,
|
|
|
+ threadgroup int8_t * shared_values [[threadgroup(0)]],
|
|
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
|
uint tiitg[[thread_index_in_threadgroup]],
|
|
|
uint tiisg[[thread_index_in_simdgroup]],
|
|
|
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
|
|
- const int64_t bid = tgpig.z/(ne12*ne13);
|
|
|
-
|
|
|
- tgpig.z = tgpig.z%(ne12*ne13);
|
|
|
-
|
|
|
- const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
|
|
|
- device const char * src0 = src0s + id*nb02;
|
|
|
-
|
|
|
- mul_vec_q_n_f32_impl<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(
|
|
|
- src0,
|
|
|
- (device const float *) (src1 + bid*nb11),
|
|
|
- dst + bid*ne0,
|
|
|
- ne00,
|
|
|
- ne01,
|
|
|
- ne02,
|
|
|
- ne10,
|
|
|
- ne12,
|
|
|
- ne0,
|
|
|
- ne1,
|
|
|
- r2,
|
|
|
- r3,
|
|
|
- tgpig,
|
|
|
- tiisg,
|
|
|
- sgitg);
|
|
|
-}
|
|
|
+ uint sgitg[[simdgroup_index_in_threadgroup]]);
|
|
|
|
|
|
-[[host_name("kernel_mul_mv_id_q5_0_f32")]]
|
|
|
-kernel void kernel_mul_mv_id_q5_0_f32(
|
|
|
+template<mul_mv_impl_fn_t impl_fn>
|
|
|
+kernel void kernel_mul_mv_id(
|
|
|
device const char * src0s,
|
|
|
device const char * src1,
|
|
|
device float * dst,
|
|
|
@@ -6335,6 +6191,7 @@ kernel void kernel_mul_mv_id_q5_0_f32(
|
|
|
constant uint & r2,
|
|
|
constant uint & r3,
|
|
|
constant int & idx,
|
|
|
+ threadgroup int8_t * shared_values [[threadgroup(0)]],
|
|
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
|
uint tiitg[[thread_index_in_threadgroup]],
|
|
|
uint tiisg[[thread_index_in_simdgroup]],
|
|
|
@@ -6346,26 +6203,36 @@ kernel void kernel_mul_mv_id_q5_0_f32(
|
|
|
const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
|
|
|
device const char * src0 = src0s + id*nb02;
|
|
|
|
|
|
- mul_vec_q_n_f32_impl<block_q5_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(
|
|
|
+ impl_fn(
|
|
|
src0,
|
|
|
- (device const float *) (src1 + bid*nb11),
|
|
|
- dst + bid*ne0,
|
|
|
+ src1 + bid*nb11,
|
|
|
+ dst + bid*ne0,
|
|
|
ne00,
|
|
|
ne01,
|
|
|
ne02,
|
|
|
+ nb00,
|
|
|
+ nb01,
|
|
|
+ nb02,
|
|
|
ne10,
|
|
|
+ ne11,
|
|
|
ne12,
|
|
|
+ ne13,
|
|
|
+ nb10,
|
|
|
+ nb11,
|
|
|
+ nb12,
|
|
|
ne0,
|
|
|
ne1,
|
|
|
+ nb1,
|
|
|
r2,
|
|
|
r3,
|
|
|
+ shared_values,
|
|
|
tgpig,
|
|
|
+ tiitg,
|
|
|
tiisg,
|
|
|
sgitg);
|
|
|
}
|
|
|
|
|
|
-[[host_name("kernel_mul_mv_id_q5_1_f32")]]
|
|
|
-kernel void kernel_mul_mv_id_q5_1_f32(
|
|
|
+typedef void (kernel_mul_mv_id_t)(
|
|
|
device const char * src0s,
|
|
|
device const char * src1,
|
|
|
device float * dst,
|
|
|
@@ -6390,819 +6257,33 @@ kernel void kernel_mul_mv_id_q5_1_f32(
|
|
|
constant uint & r2,
|
|
|
constant uint & r3,
|
|
|
constant int & idx,
|
|
|
+ threadgroup int8_t * shared_values [[threadgroup(0)]],
|
|
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
|
uint tiitg[[thread_index_in_threadgroup]],
|
|
|
uint tiisg[[thread_index_in_simdgroup]],
|
|
|
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
|
|
- const int64_t bid = tgpig.z/(ne12*ne13);
|
|
|
-
|
|
|
- tgpig.z = tgpig.z%(ne12*ne13);
|
|
|
+ uint sgitg[[simdgroup_index_in_threadgroup]]);
|
|
|
+
|
|
|
+template [[host_name("kernel_mul_mv_id_f32_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_f32_f32_impl>>;
|
|
|
+template [[host_name("kernel_mul_mv_id_f16_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_f16_f32_impl>>;
|
|
|
+template [[host_name("kernel_mul_mv_id_q8_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q8_0_f32_impl>>;
|
|
|
+template [[host_name("kernel_mul_mv_id_q4_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q4_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>>>;
|
|
|
+template [[host_name("kernel_mul_mv_id_q4_1_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>>>;
|
|
|
+template [[host_name("kernel_mul_mv_id_q5_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q5_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>>>;
|
|
|
+template [[host_name("kernel_mul_mv_id_q5_1_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q5_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>>>;
|
|
|
+template [[host_name("kernel_mul_mv_id_q2_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q2_K_f32_impl>>;
|
|
|
+template [[host_name("kernel_mul_mv_id_q3_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q3_K_f32_impl>>;
|
|
|
+template [[host_name("kernel_mul_mv_id_q4_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q4_K_f32_impl>>;
|
|
|
+template [[host_name("kernel_mul_mv_id_q5_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q5_K_f32_impl>>;
|
|
|
+template [[host_name("kernel_mul_mv_id_q6_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q6_K_f32_impl>>;
|
|
|
+template [[host_name("kernel_mul_mv_id_iq1_s_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq1_s_f32_impl>>;
|
|
|
+template [[host_name("kernel_mul_mv_id_iq1_m_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq1_m_f32_impl>>;
|
|
|
+template [[host_name("kernel_mul_mv_id_iq2_xxs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq2_xxs_f32_impl>>;
|
|
|
+template [[host_name("kernel_mul_mv_id_iq2_xs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq2_xs_f32_impl>>;
|
|
|
+template [[host_name("kernel_mul_mv_id_iq3_xxs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq3_xxs_f32_impl>>;
|
|
|
+template [[host_name("kernel_mul_mv_id_iq3_s_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq3_s_f32_impl>>;
|
|
|
+template [[host_name("kernel_mul_mv_id_iq2_s_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq2_s_f32_impl>>;
|
|
|
+template [[host_name("kernel_mul_mv_id_iq4_nl_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq4_nl_f32_impl>>;
|
|
|
+#if QK_K != 64
|
|
|
+template [[host_name("kernel_mul_mv_id_iq4_xs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq4_xs_f32_impl>>;
|
|
|
+#endif
|
|
|
|
|
|
- const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
|
|
|
- device const char * src0 = src0s + id*nb02;
|
|
|
-
|
|
|
- mul_vec_q_n_f32_impl<block_q5_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(
|
|
|
- src0,
|
|
|
- (device const float *) (src1 + bid*nb11),
|
|
|
- dst + bid*ne0,
|
|
|
- ne00,
|
|
|
- ne01,
|
|
|
- ne02,
|
|
|
- ne10,
|
|
|
- ne12,
|
|
|
- ne0,
|
|
|
- ne1,
|
|
|
- r2,
|
|
|
- r3,
|
|
|
- tgpig,
|
|
|
- tiisg,
|
|
|
- sgitg);
|
|
|
-}
|
|
|
-
|
|
|
-[[host_name("kernel_mul_mv_id_q2_K_f32")]]
|
|
|
-kernel void kernel_mul_mv_id_q2_K_f32(
|
|
|
- device const char * src0s,
|
|
|
- device const char * src1,
|
|
|
- device float * dst,
|
|
|
- device const char * ids,
|
|
|
- constant uint64_t & nbi1,
|
|
|
- constant int64_t & ne00,
|
|
|
- constant int64_t & ne01,
|
|
|
- constant int64_t & ne02,
|
|
|
- constant uint64_t & nb00,
|
|
|
- constant uint64_t & nb01,
|
|
|
- constant uint64_t & nb02,
|
|
|
- constant int64_t & ne10,
|
|
|
- constant int64_t & ne11,
|
|
|
- constant int64_t & ne12,
|
|
|
- constant int64_t & ne13,
|
|
|
- constant uint64_t & nb10,
|
|
|
- constant uint64_t & nb11,
|
|
|
- constant uint64_t & nb12,
|
|
|
- constant int64_t & ne0,
|
|
|
- constant int64_t & ne1,
|
|
|
- constant uint64_t & nb1,
|
|
|
- constant uint & r2,
|
|
|
- constant uint & r3,
|
|
|
- constant int & idx,
|
|
|
- uint3 tgpig[[threadgroup_position_in_grid]],
|
|
|
- uint tiitg[[thread_index_in_threadgroup]],
|
|
|
- uint tiisg[[thread_index_in_simdgroup]],
|
|
|
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
|
|
- const int64_t bid = tgpig.z/(ne12*ne13);
|
|
|
-
|
|
|
- tgpig.z = tgpig.z%(ne12*ne13);
|
|
|
-
|
|
|
- const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
|
|
|
- device const char * src0 = src0s + id*nb02;
|
|
|
-
|
|
|
- kernel_mul_mv_q2_K_f32_impl(
|
|
|
- src0,
|
|
|
- (device const float *) (src1 + bid*nb11),
|
|
|
- dst + bid*ne0,
|
|
|
- ne00,
|
|
|
- ne01,
|
|
|
- ne02,
|
|
|
- ne10,
|
|
|
- ne12,
|
|
|
- ne0,
|
|
|
- ne1,
|
|
|
- r2,
|
|
|
- r3,
|
|
|
- tgpig,
|
|
|
- tiisg,
|
|
|
- sgitg);
|
|
|
-}
|
|
|
-
|
|
|
-[[host_name("kernel_mul_mv_id_q3_K_f32")]]
|
|
|
-kernel void kernel_mul_mv_id_q3_K_f32(
|
|
|
- device const char * src0s,
|
|
|
- device const char * src1,
|
|
|
- device float * dst,
|
|
|
- device const char * ids,
|
|
|
- constant uint64_t & nbi1,
|
|
|
- constant int64_t & ne00,
|
|
|
- constant int64_t & ne01,
|
|
|
- constant int64_t & ne02,
|
|
|
- constant uint64_t & nb00,
|
|
|
- constant uint64_t & nb01,
|
|
|
- constant uint64_t & nb02,
|
|
|
- constant int64_t & ne10,
|
|
|
- constant int64_t & ne11,
|
|
|
- constant int64_t & ne12,
|
|
|
- constant int64_t & ne13,
|
|
|
- constant uint64_t & nb10,
|
|
|
- constant uint64_t & nb11,
|
|
|
- constant uint64_t & nb12,
|
|
|
- constant int64_t & ne0,
|
|
|
- constant int64_t & ne1,
|
|
|
- constant uint64_t & nb1,
|
|
|
- constant uint & r2,
|
|
|
- constant uint & r3,
|
|
|
- constant int & idx,
|
|
|
- uint3 tgpig[[threadgroup_position_in_grid]],
|
|
|
- uint tiitg[[thread_index_in_threadgroup]],
|
|
|
- uint tiisg[[thread_index_in_simdgroup]],
|
|
|
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
|
|
- const int64_t bid = tgpig.z/(ne12*ne13);
|
|
|
-
|
|
|
- tgpig.z = tgpig.z%(ne12*ne13);
|
|
|
-
|
|
|
- const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
|
|
|
- device const char * src0 = src0s + id*nb02;
|
|
|
-
|
|
|
- kernel_mul_mv_q3_K_f32_impl(
|
|
|
- src0,
|
|
|
- (device const float *) (src1 + bid*nb11),
|
|
|
- dst + bid*ne0,
|
|
|
- ne00,
|
|
|
- ne01,
|
|
|
- ne02,
|
|
|
- ne10,
|
|
|
- ne12,
|
|
|
- ne0,
|
|
|
- ne1,
|
|
|
- r2,
|
|
|
- r3,
|
|
|
- tgpig,
|
|
|
- tiisg,
|
|
|
- sgitg);
|
|
|
-}
|
|
|
-
|
|
|
-[[host_name("kernel_mul_mv_id_q4_K_f32")]]
|
|
|
-kernel void kernel_mul_mv_id_q4_K_f32(
|
|
|
- device const char * src0s,
|
|
|
- device const char * src1,
|
|
|
- device float * dst,
|
|
|
- device const char * ids,
|
|
|
- constant uint64_t & nbi1,
|
|
|
- constant int64_t & ne00,
|
|
|
- constant int64_t & ne01,
|
|
|
- constant int64_t & ne02,
|
|
|
- constant uint64_t & nb00,
|
|
|
- constant uint64_t & nb01,
|
|
|
- constant uint64_t & nb02,
|
|
|
- constant int64_t & ne10,
|
|
|
- constant int64_t & ne11,
|
|
|
- constant int64_t & ne12,
|
|
|
- constant int64_t & ne13,
|
|
|
- constant uint64_t & nb10,
|
|
|
- constant uint64_t & nb11,
|
|
|
- constant uint64_t & nb12,
|
|
|
- constant int64_t & ne0,
|
|
|
- constant int64_t & ne1,
|
|
|
- constant uint64_t & nb1,
|
|
|
- constant uint & r2,
|
|
|
- constant uint & r3,
|
|
|
- constant int & idx,
|
|
|
- uint3 tgpig[[threadgroup_position_in_grid]],
|
|
|
- uint tiitg[[thread_index_in_threadgroup]],
|
|
|
- uint tiisg[[thread_index_in_simdgroup]],
|
|
|
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
|
|
- const int64_t bid = tgpig.z/(ne12*ne13);
|
|
|
-
|
|
|
- tgpig.z = tgpig.z%(ne12*ne13);
|
|
|
-
|
|
|
- const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
|
|
|
- device const char * src0 = src0s + id*nb02;
|
|
|
-
|
|
|
- kernel_mul_mv_q4_K_f32_impl(
|
|
|
- src0,
|
|
|
- (device const float *) (src1 + bid*nb11),
|
|
|
- dst + bid*ne0,
|
|
|
- ne00,
|
|
|
- ne01,
|
|
|
- ne02,
|
|
|
- ne10,
|
|
|
- ne12,
|
|
|
- ne0,
|
|
|
- ne1,
|
|
|
- r2,
|
|
|
- r3,
|
|
|
- tgpig,
|
|
|
- tiisg,
|
|
|
- sgitg);
|
|
|
-}
|
|
|
-
|
|
|
-[[host_name("kernel_mul_mv_id_q5_K_f32")]]
|
|
|
-kernel void kernel_mul_mv_id_q5_K_f32(
|
|
|
- device const char * src0s,
|
|
|
- device const char * src1,
|
|
|
- device float * dst,
|
|
|
- device const char * ids,
|
|
|
- constant uint64_t & nbi1,
|
|
|
- constant int64_t & ne00,
|
|
|
- constant int64_t & ne01,
|
|
|
- constant int64_t & ne02,
|
|
|
- constant uint64_t & nb00,
|
|
|
- constant uint64_t & nb01,
|
|
|
- constant uint64_t & nb02,
|
|
|
- constant int64_t & ne10,
|
|
|
- constant int64_t & ne11,
|
|
|
- constant int64_t & ne12,
|
|
|
- constant int64_t & ne13,
|
|
|
- constant uint64_t & nb10,
|
|
|
- constant uint64_t & nb11,
|
|
|
- constant uint64_t & nb12,
|
|
|
- constant int64_t & ne0,
|
|
|
- constant int64_t & ne1,
|
|
|
- constant uint64_t & nb1,
|
|
|
- constant uint & r2,
|
|
|
- constant uint & r3,
|
|
|
- constant int & idx,
|
|
|
- uint3 tgpig[[threadgroup_position_in_grid]],
|
|
|
- uint tiitg[[thread_index_in_threadgroup]],
|
|
|
- uint tiisg[[thread_index_in_simdgroup]],
|
|
|
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
|
|
- const int64_t bid = tgpig.z/(ne12*ne13);
|
|
|
-
|
|
|
- tgpig.z = tgpig.z%(ne12*ne13);
|
|
|
-
|
|
|
- const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
|
|
|
- device const char * src0 = src0s + id*nb02;
|
|
|
-
|
|
|
- kernel_mul_mv_q5_K_f32_impl(
|
|
|
- src0,
|
|
|
- (device const float *) (src1 + bid*nb11),
|
|
|
- dst + bid*ne0,
|
|
|
- ne00,
|
|
|
- ne01,
|
|
|
- ne02,
|
|
|
- ne10,
|
|
|
- ne12,
|
|
|
- ne0,
|
|
|
- ne1,
|
|
|
- r2,
|
|
|
- r3,
|
|
|
- tgpig,
|
|
|
- tiisg,
|
|
|
- sgitg);
|
|
|
-}
|
|
|
-
|
|
|
-[[host_name("kernel_mul_mv_id_q6_K_f32")]]
|
|
|
-kernel void kernel_mul_mv_id_q6_K_f32(
|
|
|
- device const char * src0s,
|
|
|
- device const char * src1,
|
|
|
- device float * dst,
|
|
|
- device const char * ids,
|
|
|
- constant uint64_t & nbi1,
|
|
|
- constant int64_t & ne00,
|
|
|
- constant int64_t & ne01,
|
|
|
- constant int64_t & ne02,
|
|
|
- constant uint64_t & nb00,
|
|
|
- constant uint64_t & nb01,
|
|
|
- constant uint64_t & nb02,
|
|
|
- constant int64_t & ne10,
|
|
|
- constant int64_t & ne11,
|
|
|
- constant int64_t & ne12,
|
|
|
- constant int64_t & ne13,
|
|
|
- constant uint64_t & nb10,
|
|
|
- constant uint64_t & nb11,
|
|
|
- constant uint64_t & nb12,
|
|
|
- constant int64_t & ne0,
|
|
|
- constant int64_t & ne1,
|
|
|
- constant uint64_t & nb1,
|
|
|
- constant uint & r2,
|
|
|
- constant uint & r3,
|
|
|
- constant int & idx,
|
|
|
- uint3 tgpig[[threadgroup_position_in_grid]],
|
|
|
- uint tiitg[[thread_index_in_threadgroup]],
|
|
|
- uint tiisg[[thread_index_in_simdgroup]],
|
|
|
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
|
|
- const int64_t bid = tgpig.z/(ne12*ne13);
|
|
|
-
|
|
|
- tgpig.z = tgpig.z%(ne12*ne13);
|
|
|
-
|
|
|
- const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
|
|
|
- device const char * src0 = src0s + id*nb02;
|
|
|
-
|
|
|
- kernel_mul_mv_q6_K_f32_impl(
|
|
|
- src0,
|
|
|
- (device const float *) (src1 + bid*nb11),
|
|
|
- dst + bid*ne0,
|
|
|
- ne00,
|
|
|
- ne01,
|
|
|
- ne02,
|
|
|
- ne10,
|
|
|
- ne12,
|
|
|
- ne0,
|
|
|
- ne1,
|
|
|
- r2,
|
|
|
- r3,
|
|
|
- tgpig,
|
|
|
- tiisg,
|
|
|
- sgitg);
|
|
|
-}
|
|
|
-
|
|
|
-[[host_name("kernel_mul_mv_id_iq2_xxs_f32")]]
|
|
|
-kernel void kernel_mul_mv_id_iq2_xxs_f32(
|
|
|
- device const char * src0s,
|
|
|
- device const char * src1,
|
|
|
- device float * dst,
|
|
|
- device const char * ids,
|
|
|
- constant uint64_t & nbi1,
|
|
|
- constant int64_t & ne00,
|
|
|
- constant int64_t & ne01,
|
|
|
- constant int64_t & ne02,
|
|
|
- constant uint64_t & nb00,
|
|
|
- constant uint64_t & nb01,
|
|
|
- constant uint64_t & nb02,
|
|
|
- constant int64_t & ne10,
|
|
|
- constant int64_t & ne11,
|
|
|
- constant int64_t & ne12,
|
|
|
- constant int64_t & ne13,
|
|
|
- constant uint64_t & nb10,
|
|
|
- constant uint64_t & nb11,
|
|
|
- constant uint64_t & nb12,
|
|
|
- constant int64_t & ne0,
|
|
|
- constant int64_t & ne1,
|
|
|
- constant uint64_t & nb1,
|
|
|
- constant uint & r2,
|
|
|
- constant uint & r3,
|
|
|
- constant int & idx,
|
|
|
- threadgroup int8_t * shared_values [[threadgroup(0)]],
|
|
|
- uint3 tgpig[[threadgroup_position_in_grid]],
|
|
|
- uint tiitg[[thread_index_in_threadgroup]],
|
|
|
- uint tiisg[[thread_index_in_simdgroup]],
|
|
|
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
|
|
- const int64_t bid = tgpig.z/(ne12*ne13);
|
|
|
-
|
|
|
- tgpig.z = tgpig.z%(ne12*ne13);
|
|
|
-
|
|
|
- const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
|
|
|
- device const char * src0 = src0s + id*nb02;
|
|
|
-
|
|
|
- kernel_mul_mv_iq2_xxs_f32_impl(
|
|
|
- src0,
|
|
|
- (device const float *) (src1 + bid*nb11),
|
|
|
- dst + bid*ne0,
|
|
|
- ne00,
|
|
|
- ne01,
|
|
|
- ne02,
|
|
|
- ne10,
|
|
|
- ne12,
|
|
|
- ne0,
|
|
|
- ne1,
|
|
|
- r2,
|
|
|
- r3,
|
|
|
- shared_values,
|
|
|
- tgpig,
|
|
|
- tiisg,
|
|
|
- sgitg);
|
|
|
-}
|
|
|
-
|
|
|
-[[host_name("kernel_mul_mv_id_iq2_xs_f32")]]
|
|
|
-kernel void kernel_mul_mv_id_iq2_xs_f32(
|
|
|
- device const char * src0s,
|
|
|
- device const char * src1,
|
|
|
- device float * dst,
|
|
|
- device const char * ids,
|
|
|
- constant uint64_t & nbi1,
|
|
|
- constant int64_t & ne00,
|
|
|
- constant int64_t & ne01,
|
|
|
- constant int64_t & ne02,
|
|
|
- constant uint64_t & nb00,
|
|
|
- constant uint64_t & nb01,
|
|
|
- constant uint64_t & nb02,
|
|
|
- constant int64_t & ne10,
|
|
|
- constant int64_t & ne11,
|
|
|
- constant int64_t & ne12,
|
|
|
- constant int64_t & ne13,
|
|
|
- constant uint64_t & nb10,
|
|
|
- constant uint64_t & nb11,
|
|
|
- constant uint64_t & nb12,
|
|
|
- constant int64_t & ne0,
|
|
|
- constant int64_t & ne1,
|
|
|
- constant uint64_t & nb1,
|
|
|
- constant uint & r2,
|
|
|
- constant uint & r3,
|
|
|
- constant int & idx,
|
|
|
- threadgroup int8_t * shared_values [[threadgroup(0)]],
|
|
|
- uint3 tgpig[[threadgroup_position_in_grid]],
|
|
|
- uint tiitg[[thread_index_in_threadgroup]],
|
|
|
- uint tiisg[[thread_index_in_simdgroup]],
|
|
|
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
|
|
- const int64_t bid = tgpig.z/(ne12*ne13);
|
|
|
-
|
|
|
- tgpig.z = tgpig.z%(ne12*ne13);
|
|
|
-
|
|
|
- const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
|
|
|
- device const char * src0 = src0s + id*nb02;
|
|
|
-
|
|
|
- kernel_mul_mv_iq2_xs_f32_impl(
|
|
|
- src0,
|
|
|
- (device const float *) (src1 + bid*nb11),
|
|
|
- dst + bid*ne0,
|
|
|
- ne00,
|
|
|
- ne01,
|
|
|
- ne02,
|
|
|
- ne10,
|
|
|
- ne12,
|
|
|
- ne0,
|
|
|
- ne1,
|
|
|
- r2,
|
|
|
- r3,
|
|
|
- shared_values,
|
|
|
- tgpig,
|
|
|
- tiisg,
|
|
|
- sgitg);
|
|
|
-}
|
|
|
-
|
|
|
-[[host_name("kernel_mul_mv_id_iq3_xxs_f32")]]
|
|
|
-kernel void kernel_mul_mv_id_iq3_xxs_f32(
|
|
|
- device const char * src0s,
|
|
|
- device const char * src1,
|
|
|
- device float * dst,
|
|
|
- device const char * ids,
|
|
|
- constant uint64_t & nbi1,
|
|
|
- constant int64_t & ne00,
|
|
|
- constant int64_t & ne01,
|
|
|
- constant int64_t & ne02,
|
|
|
- constant uint64_t & nb00,
|
|
|
- constant uint64_t & nb01,
|
|
|
- constant uint64_t & nb02,
|
|
|
- constant int64_t & ne10,
|
|
|
- constant int64_t & ne11,
|
|
|
- constant int64_t & ne12,
|
|
|
- constant int64_t & ne13,
|
|
|
- constant uint64_t & nb10,
|
|
|
- constant uint64_t & nb11,
|
|
|
- constant uint64_t & nb12,
|
|
|
- constant int64_t & ne0,
|
|
|
- constant int64_t & ne1,
|
|
|
- constant uint64_t & nb1,
|
|
|
- constant uint & r2,
|
|
|
- constant uint & r3,
|
|
|
- constant int & idx,
|
|
|
- threadgroup int8_t * shared_values [[threadgroup(0)]],
|
|
|
- uint3 tgpig[[threadgroup_position_in_grid]],
|
|
|
- uint tiitg[[thread_index_in_threadgroup]],
|
|
|
- uint tiisg[[thread_index_in_simdgroup]],
|
|
|
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
|
|
- const int64_t bid = tgpig.z/(ne12*ne13);
|
|
|
-
|
|
|
- tgpig.z = tgpig.z%(ne12*ne13);
|
|
|
-
|
|
|
- const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
|
|
|
- device const char * src0 = src0s + id*nb02;
|
|
|
-
|
|
|
- kernel_mul_mv_iq3_xxs_f32_impl(
|
|
|
- src0,
|
|
|
- (device const float *) (src1 + bid*nb11),
|
|
|
- dst + bid*ne0,
|
|
|
- ne00,
|
|
|
- ne01,
|
|
|
- ne02,
|
|
|
- ne10,
|
|
|
- ne12,
|
|
|
- ne0,
|
|
|
- ne1,
|
|
|
- r2,
|
|
|
- r3,
|
|
|
- shared_values,
|
|
|
- tgpig,
|
|
|
- tiisg,
|
|
|
- sgitg);
|
|
|
-}
|
|
|
-
|
|
|
-[[host_name("kernel_mul_mv_id_iq3_s_f32")]]
|
|
|
-kernel void kernel_mul_mv_id_iq3_s_f32(
|
|
|
- device const char * src0s,
|
|
|
- device const char * src1,
|
|
|
- device float * dst,
|
|
|
- device const char * ids,
|
|
|
- constant uint64_t & nbi1,
|
|
|
- constant int64_t & ne00,
|
|
|
- constant int64_t & ne01,
|
|
|
- constant int64_t & ne02,
|
|
|
- constant uint64_t & nb00,
|
|
|
- constant uint64_t & nb01,
|
|
|
- constant uint64_t & nb02,
|
|
|
- constant int64_t & ne10,
|
|
|
- constant int64_t & ne11,
|
|
|
- constant int64_t & ne12,
|
|
|
- constant int64_t & ne13,
|
|
|
- constant uint64_t & nb10,
|
|
|
- constant uint64_t & nb11,
|
|
|
- constant uint64_t & nb12,
|
|
|
- constant int64_t & ne0,
|
|
|
- constant int64_t & ne1,
|
|
|
- constant uint64_t & nb1,
|
|
|
- constant uint & r2,
|
|
|
- constant uint & r3,
|
|
|
- constant int & idx,
|
|
|
- threadgroup int8_t * shared_values [[threadgroup(0)]],
|
|
|
- uint3 tgpig[[threadgroup_position_in_grid]],
|
|
|
- uint tiitg[[thread_index_in_threadgroup]],
|
|
|
- uint tiisg[[thread_index_in_simdgroup]],
|
|
|
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
|
|
- const int64_t bid = tgpig.z/(ne12*ne13);
|
|
|
-
|
|
|
- tgpig.z = tgpig.z%(ne12*ne13);
|
|
|
-
|
|
|
- const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
|
|
|
- device const char * src0 = src0s + id*nb02;
|
|
|
-
|
|
|
- kernel_mul_mv_iq3_s_f32_impl(
|
|
|
- src0,
|
|
|
- (device const float *) (src1 + bid*nb11),
|
|
|
- dst + bid*ne0,
|
|
|
- ne00,
|
|
|
- ne01,
|
|
|
- ne02,
|
|
|
- ne10,
|
|
|
- ne12,
|
|
|
- ne0,
|
|
|
- ne1,
|
|
|
- r2,
|
|
|
- r3,
|
|
|
- shared_values,
|
|
|
- tgpig,
|
|
|
- tiisg,
|
|
|
- sgitg);
|
|
|
-}
|
|
|
-
|
|
|
-[[host_name("kernel_mul_mv_id_iq2_s_f32")]]
|
|
|
-kernel void kernel_mul_mv_id_iq2_s_f32(
|
|
|
- device const char * src0s,
|
|
|
- device const char * src1,
|
|
|
- device float * dst,
|
|
|
- device const char * ids,
|
|
|
- constant uint64_t & nbi1,
|
|
|
- constant int64_t & ne00,
|
|
|
- constant int64_t & ne01,
|
|
|
- constant int64_t & ne02,
|
|
|
- constant uint64_t & nb00,
|
|
|
- constant uint64_t & nb01,
|
|
|
- constant uint64_t & nb02,
|
|
|
- constant int64_t & ne10,
|
|
|
- constant int64_t & ne11,
|
|
|
- constant int64_t & ne12,
|
|
|
- constant int64_t & ne13,
|
|
|
- constant uint64_t & nb10,
|
|
|
- constant uint64_t & nb11,
|
|
|
- constant uint64_t & nb12,
|
|
|
- constant int64_t & ne0,
|
|
|
- constant int64_t & ne1,
|
|
|
- constant uint64_t & nb1,
|
|
|
- constant uint & r2,
|
|
|
- constant uint & r3,
|
|
|
- constant int & idx,
|
|
|
- threadgroup int8_t * shared_values [[threadgroup(0)]],
|
|
|
- uint3 tgpig[[threadgroup_position_in_grid]],
|
|
|
- uint tiitg[[thread_index_in_threadgroup]],
|
|
|
- uint tiisg[[thread_index_in_simdgroup]],
|
|
|
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
|
|
- const int64_t bid = tgpig.z/(ne12*ne13);
|
|
|
-
|
|
|
- tgpig.z = tgpig.z%(ne12*ne13);
|
|
|
-
|
|
|
- const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
|
|
|
- device const char * src0 = src0s + id*nb02;
|
|
|
-
|
|
|
- kernel_mul_mv_iq2_s_f32_impl(
|
|
|
- src0,
|
|
|
- (device const float *) (src1 + bid*nb11),
|
|
|
- dst + bid*ne0,
|
|
|
- ne00,
|
|
|
- ne01,
|
|
|
- ne02,
|
|
|
- ne10,
|
|
|
- ne12,
|
|
|
- ne0,
|
|
|
- ne1,
|
|
|
- r2,
|
|
|
- r3,
|
|
|
- shared_values,
|
|
|
- tgpig,
|
|
|
- tiisg,
|
|
|
- sgitg);
|
|
|
-}
|
|
|
-
|
|
|
-[[host_name("kernel_mul_mv_id_iq1_s_f32")]]
|
|
|
-kernel void kernel_mul_mv_id_iq1_s_f32(
|
|
|
- device const char * src0s,
|
|
|
- device const char * src1,
|
|
|
- device float * dst,
|
|
|
- device const char * ids,
|
|
|
- constant uint64_t & nbi1,
|
|
|
- constant int64_t & ne00,
|
|
|
- constant int64_t & ne01,
|
|
|
- constant int64_t & ne02,
|
|
|
- constant uint64_t & nb00,
|
|
|
- constant uint64_t & nb01,
|
|
|
- constant uint64_t & nb02,
|
|
|
- constant int64_t & ne10,
|
|
|
- constant int64_t & ne11,
|
|
|
- constant int64_t & ne12,
|
|
|
- constant int64_t & ne13,
|
|
|
- constant uint64_t & nb10,
|
|
|
- constant uint64_t & nb11,
|
|
|
- constant uint64_t & nb12,
|
|
|
- constant int64_t & ne0,
|
|
|
- constant int64_t & ne1,
|
|
|
- constant uint64_t & nb1,
|
|
|
- constant uint & r2,
|
|
|
- constant uint & r3,
|
|
|
- constant int & idx,
|
|
|
- uint3 tgpig[[threadgroup_position_in_grid]],
|
|
|
- uint tiitg[[thread_index_in_threadgroup]],
|
|
|
- uint tiisg[[thread_index_in_simdgroup]],
|
|
|
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
|
|
- const int64_t bid = tgpig.z/(ne12*ne13);
|
|
|
-
|
|
|
- tgpig.z = tgpig.z%(ne12*ne13);
|
|
|
-
|
|
|
- const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
|
|
|
- device const char * src0 = src0s + id*nb02;
|
|
|
-
|
|
|
- kernel_mul_mv_iq1_s_f32_impl(
|
|
|
- src0,
|
|
|
- (device const float *) (src1 + bid*nb11),
|
|
|
- dst + bid*ne0,
|
|
|
- ne00,
|
|
|
- ne01,
|
|
|
- ne02,
|
|
|
- ne10,
|
|
|
- ne12,
|
|
|
- ne0,
|
|
|
- ne1,
|
|
|
- r2,
|
|
|
- r3,
|
|
|
- tgpig,
|
|
|
- tiisg,
|
|
|
- sgitg);
|
|
|
-}
|
|
|
-
|
|
|
-[[host_name("kernel_mul_mv_id_iq1_m_f32")]]
|
|
|
-kernel void kernel_mul_mv_id_iq1_m_f32(
|
|
|
- device const char * src0s,
|
|
|
- device const char * src1,
|
|
|
- device float * dst,
|
|
|
- device const char * ids,
|
|
|
- constant uint64_t & nbi1,
|
|
|
- constant int64_t & ne00,
|
|
|
- constant int64_t & ne01,
|
|
|
- constant int64_t & ne02,
|
|
|
- constant uint64_t & nb00,
|
|
|
- constant uint64_t & nb01,
|
|
|
- constant uint64_t & nb02,
|
|
|
- constant int64_t & ne10,
|
|
|
- constant int64_t & ne11,
|
|
|
- constant int64_t & ne12,
|
|
|
- constant int64_t & ne13,
|
|
|
- constant uint64_t & nb10,
|
|
|
- constant uint64_t & nb11,
|
|
|
- constant uint64_t & nb12,
|
|
|
- constant int64_t & ne0,
|
|
|
- constant int64_t & ne1,
|
|
|
- constant uint64_t & nb1,
|
|
|
- constant uint & r2,
|
|
|
- constant uint & r3,
|
|
|
- constant int & idx,
|
|
|
- uint3 tgpig[[threadgroup_position_in_grid]],
|
|
|
- uint tiitg[[thread_index_in_threadgroup]],
|
|
|
- uint tiisg[[thread_index_in_simdgroup]],
|
|
|
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
|
|
- const int64_t bid = tgpig.z/(ne12*ne13);
|
|
|
-
|
|
|
- tgpig.z = tgpig.z%(ne12*ne13);
|
|
|
-
|
|
|
- const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
|
|
|
- device const char * src0 = src0s + id*nb02;
|
|
|
-
|
|
|
- kernel_mul_mv_iq1_m_f32_impl(
|
|
|
- src0,
|
|
|
- (device const float *) (src1 + bid*nb11),
|
|
|
- dst + bid*ne0,
|
|
|
- ne00,
|
|
|
- ne01,
|
|
|
- ne02,
|
|
|
- ne10,
|
|
|
- ne12,
|
|
|
- ne0,
|
|
|
- ne1,
|
|
|
- r2,
|
|
|
- r3,
|
|
|
- tgpig,
|
|
|
- tiisg,
|
|
|
- sgitg);
|
|
|
-}
|
|
|
-
|
|
|
-[[host_name("kernel_mul_mv_id_iq4_nl_f32")]]
|
|
|
-kernel void kernel_mul_mv_id_iq4_nl_f32(
|
|
|
- device const char * src0s,
|
|
|
- device const char * src1,
|
|
|
- device float * dst,
|
|
|
- device const char * ids,
|
|
|
- constant uint64_t & nbi1,
|
|
|
- constant int64_t & ne00,
|
|
|
- constant int64_t & ne01,
|
|
|
- constant int64_t & ne02,
|
|
|
- constant uint64_t & nb00,
|
|
|
- constant uint64_t & nb01,
|
|
|
- constant uint64_t & nb02,
|
|
|
- constant int64_t & ne10,
|
|
|
- constant int64_t & ne11,
|
|
|
- constant int64_t & ne12,
|
|
|
- constant int64_t & ne13,
|
|
|
- constant uint64_t & nb10,
|
|
|
- constant uint64_t & nb11,
|
|
|
- constant uint64_t & nb12,
|
|
|
- constant int64_t & ne0,
|
|
|
- constant int64_t & ne1,
|
|
|
- constant uint64_t & nb1,
|
|
|
- constant uint & r2,
|
|
|
- constant uint & r3,
|
|
|
- constant int & idx,
|
|
|
- threadgroup float * shared_values [[threadgroup(0)]],
|
|
|
- uint3 tgpig[[threadgroup_position_in_grid]],
|
|
|
- uint tiitg[[thread_index_in_threadgroup]],
|
|
|
- uint tiisg[[thread_index_in_simdgroup]],
|
|
|
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
|
|
- const int64_t bid = tgpig.z/(ne12*ne13);
|
|
|
-
|
|
|
- tgpig.z = tgpig.z%(ne12*ne13);
|
|
|
-
|
|
|
- const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
|
|
|
- device const char * src0 = src0s + id*nb02;
|
|
|
-
|
|
|
- kernel_mul_mv_iq4_nl_f32_impl(
|
|
|
- src0,
|
|
|
- (device const float *) (src1 + bid*nb11),
|
|
|
- dst + bid*ne0,
|
|
|
- ne00,
|
|
|
- ne01,
|
|
|
- ne02,
|
|
|
- ne10,
|
|
|
- ne12,
|
|
|
- ne0,
|
|
|
- ne1,
|
|
|
- r2,
|
|
|
- r3,
|
|
|
- shared_values,
|
|
|
- tgpig,
|
|
|
- tiisg,
|
|
|
- sgitg);
|
|
|
-}
|
|
|
-
|
|
|
-[[host_name("kernel_mul_mv_id_iq4_xs_f32")]]
|
|
|
-kernel void kernel_mul_mv_id_iq4_xs_f32(
|
|
|
- device const char * src0s,
|
|
|
- device const char * src1,
|
|
|
- device float * dst,
|
|
|
- device const char * ids,
|
|
|
- constant uint64_t & nbi1,
|
|
|
- constant int64_t & ne00,
|
|
|
- constant int64_t & ne01,
|
|
|
- constant int64_t & ne02,
|
|
|
- constant uint64_t & nb00,
|
|
|
- constant uint64_t & nb01,
|
|
|
- constant uint64_t & nb02,
|
|
|
- constant int64_t & ne10,
|
|
|
- constant int64_t & ne11,
|
|
|
- constant int64_t & ne12,
|
|
|
- constant int64_t & ne13,
|
|
|
- constant uint64_t & nb10,
|
|
|
- constant uint64_t & nb11,
|
|
|
- constant uint64_t & nb12,
|
|
|
- constant int64_t & ne0,
|
|
|
- constant int64_t & ne1,
|
|
|
- constant uint64_t & nb1,
|
|
|
- constant uint & r2,
|
|
|
- constant uint & r3,
|
|
|
- constant int & idx,
|
|
|
- threadgroup float * shared_values [[threadgroup(0)]],
|
|
|
- uint3 tgpig[[threadgroup_position_in_grid]],
|
|
|
- uint tiitg[[thread_index_in_threadgroup]],
|
|
|
- uint tiisg[[thread_index_in_simdgroup]],
|
|
|
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
|
|
- const int64_t bid = tgpig.z/(ne12*ne13);
|
|
|
-
|
|
|
- tgpig.z = tgpig.z%(ne12*ne13);
|
|
|
-
|
|
|
- const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
|
|
|
- device const char * src0 = src0s + id*nb02;
|
|
|
-
|
|
|
-#if QK_K == 64
|
|
|
- kernel_mul_mv_iq4_nl_f32_impl(
|
|
|
-#else
|
|
|
- kernel_mul_mv_iq4_xs_f32_impl(
|
|
|
-#endif
|
|
|
- src0,
|
|
|
- (device const float *) (src1 + bid*nb11),
|
|
|
- dst + bid*ne0,
|
|
|
- ne00,
|
|
|
- ne01,
|
|
|
- ne02,
|
|
|
- ne10,
|
|
|
- ne12,
|
|
|
- ne0,
|
|
|
- ne1,
|
|
|
- r2,
|
|
|
- r3,
|
|
|
- shared_values,
|
|
|
- tgpig,
|
|
|
- tiisg,
|
|
|
- sgitg);
|
|
|
-}
|