@@ -846,7 +846,7 @@ inline float block_q_n_dot_y(device const block_q5_1 * qb_curr, float sumy, thre
 #define N_SIMDGROUP 2 // number of SIMD groups in a thread group
 //Note: This is a template, but strictly speaking it only applies to
 //      quantizations where the block size is 32. It also does not
-//      giard against the number of rows not being divisible by
+//      guard against the number of rows not being divisible by
 //      N_DST, so this is another explicit assumption of the implementation.
 template<typename block_q_type, int nr, int nsg, int nw>
 void mul_vec_q_n_f32_impl(
@@ -3973,6 +3973,131 @@ void kernel_mul_mm_impl(device const uchar * src0,
     }
 }
 
+// same as kernel_mul_mm_impl, but src1 and dst are accessed via indices stored in src1ids
+template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread half4x4 &)>
+void kernel_mul_mm_id_impl(
+        device const uchar * src0,
+        device const uchar * src1,
+        thread short * src1ids,
+        device float * dst,
+        constant int64_t & ne00,
+        constant int64_t & ne02,
+        constant uint64_t & nb01,
+        constant uint64_t & nb02,
+        constant int64_t & ne12,
+        constant uint64_t & nb10,
+        constant uint64_t & nb11,
+        constant uint64_t & nb12,
+        constant int64_t & ne0,
+        int64_t ne1,
+        constant uint & r2,
+        constant uint & r3,
+        threadgroup uchar * shared_memory,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint tiitg[[thread_index_in_threadgroup]],
+        uint sgitg[[simdgroup_index_in_threadgroup]]) {
+
+    threadgroup half * sa = (threadgroup half *)(shared_memory);
+    threadgroup float * sb = (threadgroup float *)(shared_memory + 4096);
+
+    const uint r0 = tgpig.y;
+    const uint r1 = tgpig.x;
+    const uint im = tgpig.z;
+
+    if (r1 * BLOCK_SIZE_N >= ne1) return;
+
+    // if this block is of 64x32 shape or smaller
+    short n_rows = (ne0 - r0 * BLOCK_SIZE_M < BLOCK_SIZE_M) ? (ne0 - r0 * BLOCK_SIZE_M) : BLOCK_SIZE_M;
+    short n_cols = (ne1 - r1 * BLOCK_SIZE_N < BLOCK_SIZE_N) ? (ne1 - r1 * BLOCK_SIZE_N) : BLOCK_SIZE_N;
+
+    // a thread shouldn't load data outside of the matrix
+    short thread_row = ((short)tiitg/THREAD_PER_ROW) < n_rows ? ((short)tiitg/THREAD_PER_ROW) : n_rows - 1;
+    short thread_col = ((short)tiitg/THREAD_PER_COL) < n_cols ? ((short)tiitg/THREAD_PER_COL) : n_cols - 1;
+
+    simdgroup_half8x8 ma[4];
+    simdgroup_float8x8 mb[2];
+    simdgroup_float8x8 c_res[8];
+    for (int i = 0; i < 8; i++){
+        c_res[i] = make_filled_simdgroup_matrix<float, 8>(0.f);
+    }
+
+    short il = (tiitg % THREAD_PER_ROW);
+
+    const uint i12 = im%ne12;
+    const uint i13 = im/ne12;
+
+    uint offset0 = (i12/r2)*nb02 + (i13/r3)*(nb02*ne02);
+    ushort offset1 = il/nl;
+
+    device const block_q * x = (device const block_q *)(src0 + (r0 * BLOCK_SIZE_M + thread_row) * nb01 + offset0) + offset1;
+    device const float * y = (device const float *)(src1
+        + nb12 * im
+        + nb11 * src1ids[r1 * BLOCK_SIZE_N + thread_col]
+        + nb10 * (BLOCK_SIZE_K / THREAD_PER_COL * (tiitg % THREAD_PER_COL)));
+
+    for (int loop_k = 0; loop_k < ne00; loop_k += BLOCK_SIZE_K) {
+        // load data and store to threadgroup memory
+        half4x4 temp_a;
+        dequantize_func(x, il, temp_a);
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+
+        for (int i = 0; i < 16; i++) {
+            *(sa + SG_MAT_SIZE * ((tiitg / THREAD_PER_ROW / 8) \
+            + (tiitg % THREAD_PER_ROW) * 16 + (i / 8) * 8) \
+            + (tiitg / THREAD_PER_ROW) % 8 + (i & 7) * 8) = temp_a[i/4][i%4];
+        }
+
+        *(threadgroup float2x4 *)(sb + (tiitg % THREAD_PER_COL) * 8 * 32 + 8 * (tiitg / THREAD_PER_COL)) = *((device float2x4 *)y);
+
+        il = (il + 2 < nl) ? il + 2 : il % 2;
+        x = (il < 2) ? x + (2+nl-1)/nl : x;
+        y += BLOCK_SIZE_K;
+
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+
+        // load matrices from threadgroup memory and conduct outer products
+        threadgroup half * lsma = (sa + THREAD_MAT_M * SG_MAT_SIZE * (sgitg % 2));
+        threadgroup float * lsmb = (sb + THREAD_MAT_N * SG_MAT_SIZE * (sgitg / 2));
+
+        for (int ik = 0; ik < BLOCK_SIZE_K / 8; ik++) {
+            for (int i = 0; i < 4; i++) {
+                simdgroup_load(ma[i],lsma + SG_MAT_SIZE * i);
+            }
+            simdgroup_barrier(mem_flags::mem_none);
+            for (int i = 0; i < 2; i++) {
+                simdgroup_load(mb[i],lsmb + SG_MAT_SIZE * i);
+            }
+
+            lsma += BLOCK_SIZE_M / SG_MAT_ROW * SG_MAT_SIZE;
+            lsmb += BLOCK_SIZE_N / SG_MAT_ROW * SG_MAT_SIZE;
+
+            for (int i = 0; i < 8; i++){
+                simdgroup_multiply_accumulate(c_res[i], mb[i/4], ma[i%4], c_res[i]);
+            }
+        }
+    }
+
+    {
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+        threadgroup float * temp_str = ((threadgroup float *)shared_memory) \
+                                     + 32 * (sgitg&1) + (16 * (sgitg>>1)) * BLOCK_SIZE_M;
+        for (int i = 0; i < 8; i++) {
+            simdgroup_store(c_res[i], temp_str + 8 * (i%4) + 8 * BLOCK_SIZE_M * (i/4), BLOCK_SIZE_M);
+        }
+
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+
+        device float * C = dst + (BLOCK_SIZE_M * r0) + im*ne1*ne0;
+        if (sgitg == 0) {
+            for (int i = 0; i < n_rows; i++) {
+                for (int j = tiitg; j < n_cols; j += BLOCK_SIZE_N) {
+                    *(C + i + src1ids[j + r1*BLOCK_SIZE_N] * ne0) = *(temp_str + i + j * BLOCK_SIZE_M);
+                }
+            }
+        }
+    }
+}
+
 template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread half4x4 &)>
 kernel void kernel_mul_mm(device const uchar * src0,
                           device const uchar * src1,
@@ -4019,7 +4144,7 @@ template<typename block_q, short nl, void (*dequantize_func)(device const block_
 kernel void kernel_mul_mm_id(
         device const uchar * ids,
         device const uchar * src1,
-        device uchar * dst,
+        device float * dst,
         constant uint64_t & nbi1,
         constant int64_t & ne00,
         constant int64_t & ne02,
@@ -4048,18 +4173,28 @@ kernel void kernel_mul_mm_id(
         uint3 tgpig[[threadgroup_position_in_grid]],
         uint tiitg[[thread_index_in_threadgroup]],
         uint sgitg[[simdgroup_index_in_threadgroup]]) {
 
-    device const uchar * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
+    device const uchar * src0s[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
 
-    const int64_t bid = tgpig.z/(ne12*ne13);
+    // expert id
+    const int32_t id = tgpig.z/(ne12*ne13);
     tgpig.z = tgpig.z%(ne12*ne13);
 
-    const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
+    // row indices of src1 for expert id
+    int64_t _ne1 = 0;
+    short src1ids[512];
 
-    kernel_mul_mm_impl<block_q, nl, dequantize_func>(
-        src0[id],
-        src1 + bid*nb11,
-        (device float *) (dst + bid*nb1),
+    for (int64_t i1 = 0; i1 < ne1; i1++) {
+        if (((device int32_t *) (ids + i1*nbi1))[idx] == id) {
+            src1ids[_ne1++] = i1;
+        }
+    }
+
+    kernel_mul_mm_id_impl<block_q, nl, dequantize_func>(
+        src0s[id],
+        src1,
+        src1ids,
+        dst,
         ne00,
         ne02,
         nb01,
|
|
|
nb11,
|
|
nb11,
|
|
|
nb12,
|
|
nb12,
|
|
|
ne0,
|
|
ne0,
|
|
|
- ne1,
|
|
|
|
|
|
|
+ _ne1,
|
|
|
r2,
|
|
r2,
|
|
|
r3,
|
|
r3,
|
|
|
shared_memory,
|
|
shared_memory,
|
|
@@ -4158,7 +4293,7 @@ template [[host_name("kernel_mul_mm_q6_K_f32")]] kernel mat_mm_t kernel_mul_mm<b
 typedef void (mat_mm_id_t)(
         device const uchar * ids,
         device const uchar * src1,
-        device uchar * dst,
+        device float * dst,
         constant uint64_t & nbi1,
         constant int64_t & ne00,
         constant int64_t & ne02,
@@ -4207,7 +4342,7 @@ template [[host_name("kernel_mul_mm_id_q6_K_f32")]] kernel mat_mm_id_t kernel_mu
 kernel void kernel_mul_mv_id_f32_f32(
         device const char * ids,
         device const char * src1,
-        device uchar * dst,
+        device float * dst,
         constant uint64_t & nbi1,
         constant int64_t & ne00,
         constant int64_t & ne01,
@@ -4251,7 +4386,7 @@ kernel void kernel_mul_mv_id_f32_f32(
     kernel_mul_mv_f32_f32_impl(
         src0[id],
         src1 + bid*nb11,
-        (device float *) (dst + bid*nb1),
+        dst + bid*ne0,
         ne00,
         ne01,
         ne02,
@@ -4276,7 +4411,7 @@ kernel void kernel_mul_mv_id_f32_f32(
 kernel void kernel_mul_mv_id_f16_f32(
         device const char * ids,
         device const char * src1,
-        device uchar * dst,
+        device float * dst,
         constant uint64_t & nbi1,
         constant int64_t & ne00,
         constant int64_t & ne01,
@@ -4320,7 +4455,7 @@ kernel void kernel_mul_mv_id_f16_f32(
     kernel_mul_mv_f16_f32_impl(
         src0[id],
         src1 + bid*nb11,
-        (device float *) (dst + bid*nb1),
+        dst + bid*ne0,
         ne00,
         ne01,
         ne02,
@@ -4345,7 +4480,7 @@ kernel void kernel_mul_mv_id_f16_f32(
 kernel void kernel_mul_mv_id_q8_0_f32(
         device const char * ids,
         device const char * src1,
-        device uchar * dst,
+        device float * dst,
         constant uint64_t & nbi1,
         constant int64_t & ne00,
         constant int64_t & ne01,
@@ -4389,7 +4524,7 @@ kernel void kernel_mul_mv_id_q8_0_f32(
     kernel_mul_mv_q8_0_f32_impl(
         src0[id],
         (device const float *) (src1 + bid*nb11),
-        (device float *) ( dst + bid*nb1),
+        dst + bid*ne0,
         ne00,
         ne01,
         ne02,
@@ -4408,7 +4543,7 @@ kernel void kernel_mul_mv_id_q8_0_f32(
 kernel void kernel_mul_mv_id_q4_0_f32(
         device const char * ids,
         device const char * src1,
-        device uchar * dst,
+        device float * dst,
         constant uint64_t & nbi1,
         constant int64_t & ne00,
         constant int64_t & ne01,
@@ -4452,7 +4587,7 @@ kernel void kernel_mul_mv_id_q4_0_f32(
     mul_vec_q_n_f32_impl<block_q4_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(
         src0[id],
         (device const float *) (src1 + bid*nb11),
-        (device float *) ( dst + bid*nb1),
+        dst + bid*ne0,
         ne00,
         ne01,
         ne02,
@@ -4471,7 +4606,7 @@ kernel void kernel_mul_mv_id_q4_0_f32(
 kernel void kernel_mul_mv_id_q4_1_f32(
         device const char * ids,
         device const char * src1,
-        device uchar * dst,
+        device float * dst,
         constant uint64_t & nbi1,
         constant int64_t & ne00,
         constant int64_t & ne01,
@@ -4515,7 +4650,7 @@ kernel void kernel_mul_mv_id_q4_1_f32(
     mul_vec_q_n_f32_impl<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(
         src0[id],
         (device const float *) (src1 + bid*nb11),
-        (device float *) ( dst + bid*nb1),
+        dst + bid*ne0,
         ne00,
         ne01,
         ne02,
@@ -4534,7 +4669,7 @@ kernel void kernel_mul_mv_id_q4_1_f32(
 kernel void kernel_mul_mv_id_q5_0_f32(
         device const char * ids,
         device const char * src1,
-        device uchar * dst,
+        device float * dst,
         constant uint64_t & nbi1,
         constant int64_t & ne00,
         constant int64_t & ne01,
@@ -4578,7 +4713,7 @@ kernel void kernel_mul_mv_id_q5_0_f32(
     mul_vec_q_n_f32_impl<block_q5_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(
         src0[id],
         (device const float *) (src1 + bid*nb11),
-        (device float *) ( dst + bid*nb1),
+        dst + bid*ne0,
         ne00,
         ne01,
         ne02,
@@ -4597,7 +4732,7 @@ kernel void kernel_mul_mv_id_q5_0_f32(
 kernel void kernel_mul_mv_id_q5_1_f32(
         device const char * ids,
         device const char * src1,
-        device uchar * dst,
+        device float * dst,
         constant uint64_t & nbi1,
         constant int64_t & ne00,
         constant int64_t & ne01,
@@ -4641,7 +4776,7 @@ kernel void kernel_mul_mv_id_q5_1_f32(
     mul_vec_q_n_f32_impl<block_q5_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(
         src0[id],
         (device const float *) (src1 + bid*nb11),
-        (device float *) ( dst + bid*nb1),
+        dst + bid*ne0,
         ne00,
         ne01,
         ne02,
@@ -4660,7 +4795,7 @@ kernel void kernel_mul_mv_id_q5_1_f32(
 kernel void kernel_mul_mv_id_q2_K_f32(
         device const char * ids,
         device const char * src1,
-        device uchar * dst,
+        device float * dst,
         constant uint64_t & nbi1,
         constant int64_t & ne00,
         constant int64_t & ne01,
@@ -4704,7 +4839,7 @@ kernel void kernel_mul_mv_id_q2_K_f32(
     kernel_mul_mv_q2_K_f32_impl(
         src0[id],
         (device const float *) (src1 + bid*nb11),
-        (device float *) ( dst + bid*nb1),
+        dst + bid*ne0,
         ne00,
         ne01,
         ne02,
@@ -4723,7 +4858,7 @@ kernel void kernel_mul_mv_id_q2_K_f32(
 kernel void kernel_mul_mv_id_q3_K_f32(
         device const char * ids,
         device const char * src1,
-        device uchar * dst,
+        device float * dst,
         constant uint64_t & nbi1,
         constant int64_t & ne00,
         constant int64_t & ne01,
@@ -4767,7 +4902,7 @@ kernel void kernel_mul_mv_id_q3_K_f32(
     kernel_mul_mv_q3_K_f32_impl(
         src0[id],
         (device const float *) (src1 + bid*nb11),
-        (device float *) ( dst + bid*nb1),
+        dst + bid*ne0,
         ne00,
         ne01,
         ne02,
@@ -4786,7 +4921,7 @@ kernel void kernel_mul_mv_id_q3_K_f32(
 kernel void kernel_mul_mv_id_q4_K_f32(
         device const char * ids,
         device const char * src1,
-        device uchar * dst,
+        device float * dst,
         constant uint64_t & nbi1,
         constant int64_t & ne00,
         constant int64_t & ne01,
@@ -4830,7 +4965,7 @@ kernel void kernel_mul_mv_id_q4_K_f32(
     kernel_mul_mv_q4_K_f32_impl(
         src0[id],
         (device const float *) (src1 + bid*nb11),
-        (device float *) ( dst + bid*nb1),
+        dst + bid*ne0,
         ne00,
         ne01,
         ne02,
@@ -4849,7 +4984,7 @@ kernel void kernel_mul_mv_id_q4_K_f32(
 kernel void kernel_mul_mv_id_q5_K_f32(
         device const char * ids,
         device const char * src1,
-        device uchar * dst,
+        device float * dst,
         constant uint64_t & nbi1,
         constant int64_t & ne00,
         constant int64_t & ne01,
@@ -4893,7 +5028,7 @@ kernel void kernel_mul_mv_id_q5_K_f32(
     kernel_mul_mv_q5_K_f32_impl(
         src0[id],
         (device const float *) (src1 + bid*nb11),
-        (device float *) ( dst + bid*nb1),
+        dst + bid*ne0,
         ne00,
         ne01,
         ne02,
@@ -4912,7 +5047,7 @@ kernel void kernel_mul_mv_id_q5_K_f32(
 kernel void kernel_mul_mv_id_q6_K_f32(
         device const char * ids,
         device const char * src1,
-        device uchar * dst,
+        device float * dst,
         constant uint64_t & nbi1,
         constant int64_t & ne00,
         constant int64_t & ne01,
@@ -4956,7 +5091,7 @@ kernel void kernel_mul_mv_id_q6_K_f32(
     kernel_mul_mv_q6_K_f32_impl(
        src0[id],
        (device const float *) (src1 + bid*nb11),
-       (device float *) ( dst + bid*nb1),
+       dst + bid*ne0,
        ne00,
        ne01,
        ne02,
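
Note on the routing logic in the kernel_mul_mm_id hunk above: the kernel first scans the ids tensor and collects, into src1ids, the src1 row indices that were routed to the expert handled by this launch, then hands that index list (and its length _ne1) to kernel_mul_mm_id_impl, which gathers its inputs through src1ids and scatters results back through the same indices. A minimal host-side C++ sketch of that gather step follows; the function name gather_rows_for_expert and the flat one-expert-per-row ids layout are assumptions made for illustration only, not part of the patch (the real kernel reads slot idx of each ids row via ((device int32_t *) (ids + i1*nbi1))[idx]).

    // Illustrative sketch only: mirrors the scan over `ids` in kernel_mul_mm_id.
    // Assumes one chosen expert per src1 row, stored as a single int32_t per row.
    #include <cstdint>
    #include <vector>

    std::vector<int16_t> gather_rows_for_expert(const std::vector<int32_t> & ids, int32_t expert_id) {
        std::vector<int16_t> src1ids;
        for (int64_t i1 = 0; i1 < (int64_t) ids.size(); ++i1) {
            if (ids[i1] == expert_id) {
                src1ids.push_back((int16_t) i1); // same role as src1ids[_ne1++] = i1 in the kernel
            }
        }
        return src1ids; // src1ids.size() plays the role of _ne1
    }

This is also why the write-out in kernel_mul_mm_id_impl stores to *(C + i + src1ids[j + r1*BLOCK_SIZE_N] * ne0): the gathered rows are multiplied as one dense block against the selected expert's weights, and the results land back at their original row positions in dst.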