|
|
@@ -6336,127 +6336,219 @@ kernel void kernel_mul_mm(
|
|
|
}
|
|
|
}
|
|
|
|
|
|
-// same as kernel_mul_mm_impl, but src1 and dst are accessed via indices stored in rowids
|
|
|
-// TODO: this kernel needs to be reimplemented from scratch for better performance
|
|
|
-template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread half4x4 &)>
|
|
|
-void kernel_mul_mm_id_impl(
|
|
|
- int32_t ne00,
|
|
|
- int32_t ne02,
|
|
|
- uint64_t nb01,
|
|
|
- uint64_t nb02,
|
|
|
- int32_t ne11,
|
|
|
- int32_t ne12,
|
|
|
- uint64_t nb10,
|
|
|
- uint64_t nb11,
|
|
|
- uint64_t nb12,
|
|
|
- int32_t ne0,
|
|
|
- int32_t ne1,
|
|
|
- int64_t ne0ne1,
|
|
|
- device const char * src0,
|
|
|
- device const char * src1,
|
|
|
- threadgroup ushort2 * rowids,
|
|
|
- device char * dst,
|
|
|
- threadgroup char * shmem,
|
|
|
+template<typename T4>
|
|
|
+kernel void kernel_mul_mm_id_map0(
|
|
|
+ constant ggml_metal_kargs_mul_mm_id_map0 & args,
|
|
|
+ device const char * src1,
|
|
|
+ device const char * src2,
|
|
|
+ device char * hsrc1,
|
|
|
+ device char * htpe,
|
|
|
+ device char * hids,
|
|
|
+ uint3 tgpig[[threadgroup_position_in_grid]],
|
|
|
+ ushort3 tpitg[[thread_position_in_threadgroup]],
|
|
|
+ ushort3 ntg[[threads_per_threadgroup]]) {
|
|
|
+ const int ide = tgpig[0]; // expert id
|
|
|
+
|
|
|
+ int n_all = 0;
|
|
|
+
|
|
|
+ device int32_t * ids_i32 = (device int32_t *) (hids);
|
|
|
+
|
|
|
+ for (int i21 = 0; i21 < args.neh11; i21++) { // n_tokens
|
|
|
+ device const int32_t * src2_i32 = (device const int32_t *) (src2 + i21*args.nb21);
|
|
|
+
|
|
|
+ for (int i20 = 0; i20 < args.ne20; i20++) { // n_expert_used
|
|
|
+ if (src2_i32[i20] != ide) {
|
|
|
+ continue;
|
|
|
+ }
|
|
|
+
|
|
|
+ device const float4 * src1_f32x4 = (device const float4 *) ( src1 + i21*args.nb12 + (i20%args.ne11)*args.nb11);
|
|
|
+ device T4 * hsrc1_f32x4 = (device T4 *) (hsrc1 + (ide*args.neh11 + n_all)*args.nbh11);
|
|
|
+
|
|
|
+ for (int64_t i00 = tpitg.x; i00 < args.ne10/4; i00 += ntg.x) {
|
|
|
+ hsrc1_f32x4[i00] = (T4) (src1_f32x4[i00]);
|
|
|
+ }
|
|
|
+
|
|
|
+ if (tpitg.x == 0) {
|
|
|
+ ids_i32[i21*args.ne20 + i20] = ide*args.neh11 + n_all;
|
|
|
+ }
|
|
|
+
|
|
|
+ ++n_all;
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ if (tpitg.x == 0) {
|
|
|
+ device int32_t * tpe_i32 = (device int32_t *) (htpe);
|
|
|
+ tpe_i32[ide] = n_all;
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+typedef decltype(kernel_mul_mm_id_map0<half4>) kernel_mul_mm_id_map0_t;
|
|
|
+
|
|
|
+template [[host_name("kernel_mul_mm_id_map0_f16")]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<half4>;
|
|
|
+
|
|
|
+template<typename T>
|
|
|
+kernel void kernel_mul_mm_id_map1(
|
|
|
+ constant ggml_metal_kargs_mul_mm_id_map1 & args,
|
|
|
+ device const char * hdst,
|
|
|
+ device const char * hids,
|
|
|
+ device char * dst,
|
|
|
+ uint3 tgpig[[threadgroup_position_in_grid]],
|
|
|
+ ushort3 tpitg[[thread_position_in_threadgroup]],
|
|
|
+ ushort3 ntg[[threads_per_threadgroup]]) {
|
|
|
+ const int i20 = tgpig[0]; // used expert
|
|
|
+ const int i21 = tgpig[1]; // token
|
|
|
+
|
|
|
+ device const int32_t * ids_i32 = (device const int32_t *) (hids);
|
|
|
+ device float4 * dst_f32x4 = (device float4 *) (dst + i20*args.nb1 + i21*args.nb2);
|
|
|
+
|
|
|
+ const int id = ids_i32[i21*args.ne20 + i20];
|
|
|
+
|
|
|
+ const int ide = id / args.neh1;
|
|
|
+ const int idt = id % args.neh1;
|
|
|
+
|
|
|
+ device const float4 * hdst_f32x4 = (device const float4 *) (hdst + idt*args.nbh1 + ide*args.nbh2);
|
|
|
+
|
|
|
+ for (int64_t i0 = tpitg.x; i0 < args.neh0/4; i0 += ntg.x) {
|
|
|
+ dst_f32x4[i0] = hdst_f32x4[i0];
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+typedef decltype(kernel_mul_mm_id_map1<float>) kernel_mul_mm_id_map1_t;
|
|
|
+
|
|
|
+template [[host_name("kernel_mul_mm_id_map1_f32")]] kernel kernel_mul_mm_id_map1_t kernel_mul_mm_id_map1<float>;
|
|
|
+
|
|
|
+template<typename T, typename T4x4, typename simdgroup_T8x8, typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread T4x4 &)>
|
|
|
+kernel void kernel_mul_mm_id(
|
|
|
+ constant ggml_metal_kargs_mul_mm_id & args,
|
|
|
+ device const char * src0,
|
|
|
+ device const char * src1,
|
|
|
+ device const char * tpe,
|
|
|
+ device char * dst,
|
|
|
+ threadgroup char * shmem [[threadgroup(0)]],
|
|
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
|
ushort tiitg[[thread_index_in_threadgroup]],
|
|
|
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
|
|
|
|
|
|
- threadgroup half * sa = (threadgroup half *)(shmem);
|
|
|
- threadgroup float * sb = (threadgroup float *)(shmem + 4096);
|
|
|
+ threadgroup T * sa = (threadgroup T *)(shmem);
|
|
|
+ threadgroup half * sb = (threadgroup half *)(shmem + 4096);
|
|
|
|
|
|
const int r0 = tgpig.y;
|
|
|
const int r1 = tgpig.x;
|
|
|
+ const int im = tgpig.z;
|
|
|
+
|
|
|
+ device const int32_t * tpe_i32 = (device const int32_t *) (tpe);
|
|
|
+
|
|
|
+ const int neh1 = tpe_i32[im];
|
|
|
|
|
|
- if (r1*BLOCK_SIZE_N >= ne1) return;
|
|
|
+ if (r1*BLOCK_SIZE_N >= neh1) {
|
|
|
+ return;
|
|
|
+ }
|
|
|
|
|
|
// if this block is of 64x32 shape or smaller
|
|
|
- short n_rows = (ne0 - r0 * BLOCK_SIZE_M < BLOCK_SIZE_M) ? (ne0 - r0 * BLOCK_SIZE_M) : BLOCK_SIZE_M;
|
|
|
- short n_cols = (ne1 - r1 * BLOCK_SIZE_N < BLOCK_SIZE_N) ? (ne1 - r1 * BLOCK_SIZE_N) : BLOCK_SIZE_N;
|
|
|
+ const short n_rows = (args.neh0 - r0*BLOCK_SIZE_M < BLOCK_SIZE_M) ? (args.neh0 - r0*BLOCK_SIZE_M) : BLOCK_SIZE_M;
|
|
|
+ const short n_cols = ( neh1 - r1*BLOCK_SIZE_N < BLOCK_SIZE_N) ? ( neh1 - r1*BLOCK_SIZE_N) : BLOCK_SIZE_N;
|
|
|
|
|
|
// a thread shouldn't load data outside of the matrix
|
|
|
- short thread_row = ((short)tiitg/THREAD_PER_ROW) < n_rows ? ((short)tiitg/THREAD_PER_ROW) : n_rows - 1;
|
|
|
- short thread_col = ((short)tiitg/THREAD_PER_COL) < n_cols ? ((short)tiitg/THREAD_PER_COL) : n_cols - 1;
|
|
|
+ const short thread_row = ((short)tiitg/THREAD_PER_ROW) < n_rows ? ((short)tiitg/THREAD_PER_ROW) : n_rows - 1;
|
|
|
+ const short thread_col = ((short)tiitg/THREAD_PER_COL) < n_cols ? ((short)tiitg/THREAD_PER_COL) : n_cols - 1;
|
|
|
|
|
|
- simdgroup_half8x8 ma[4];
|
|
|
- simdgroup_float8x8 mb[2];
|
|
|
+ simdgroup_T8x8 ma[4];
|
|
|
+ simdgroup_half8x8 mb[2];
|
|
|
simdgroup_float8x8 mc[8];
|
|
|
- for (int i = 0; i < 8; i++){
|
|
|
+
|
|
|
+ for (short i = 0; i < 8; i++){
|
|
|
mc[i] = make_filled_simdgroup_matrix<float, 8>(0.f);
|
|
|
}
|
|
|
+
|
|
|
short il = (tiitg % THREAD_PER_ROW);
|
|
|
|
|
|
- ushort offset1 = il/nl;
|
|
|
+ const int i12 = im%args.neh12;
|
|
|
+ const int i13 = im/args.neh12;
|
|
|
|
|
|
- threadgroup const auto & id = rowids[r1 * BLOCK_SIZE_N + thread_col];
|
|
|
+ const uint64_t offset0 = (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
|
|
|
+ const short offset1 = il/nl;
|
|
|
|
|
|
- device const block_q * x = (device const block_q *)(src0 + (r0 * BLOCK_SIZE_M + thread_row) * nb01) + offset1;
|
|
|
- device const float * y = (device const float *)(src1
|
|
|
- + nb12 * id[1]
|
|
|
- + nb11 * (id[0] % ne11)
|
|
|
- + nb10 * (BLOCK_SIZE_K / THREAD_PER_COL * (tiitg % THREAD_PER_COL)));
|
|
|
+ device const block_q * x = (device const block_q *)(src0
|
|
|
+ + args.nb01*(r0*BLOCK_SIZE_M + thread_row) + offset0) + offset1;
|
|
|
|
|
|
- for (int loop_k = 0; loop_k < ne00; loop_k += BLOCK_SIZE_K) {
|
|
|
+ device const half * y = (device const half *)(src1
|
|
|
+ + args.nbh13*i13
|
|
|
+ + args.nbh12*i12
|
|
|
+ + args.nbh11*(r1*BLOCK_SIZE_N + thread_col)
|
|
|
+ + args.nbh10*(BLOCK_SIZE_K / THREAD_PER_COL * (tiitg % THREAD_PER_COL)));
|
|
|
+
|
|
|
+ for (int loop_k = 0; loop_k < args.ne00; loop_k += BLOCK_SIZE_K) {
|
|
|
// load data and store to threadgroup memory
|
|
|
- half4x4 temp_a;
|
|
|
+ T4x4 temp_a;
|
|
|
dequantize_func(x, il, temp_a);
|
|
|
+
|
|
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
|
|
|
|
- for (int i = 0; i < 16; i++) {
|
|
|
- *(sa + SG_MAT_SIZE * ((tiitg / THREAD_PER_ROW / 8) \
|
|
|
- + (tiitg % THREAD_PER_ROW) * 16 + (i / 8) * 8) \
|
|
|
- + (tiitg / THREAD_PER_ROW) % 8 + (i & 7) * 8) = temp_a[i/4][i%4];
|
|
|
+ #pragma unroll(16)
|
|
|
+ for (short i = 0; i < 16; i++) {
|
|
|
+ *(sa + SG_MAT_SIZE * ((tiitg/THREAD_PER_ROW/8) \
|
|
|
+ + (tiitg%THREAD_PER_ROW)*16 + (i/8)*8) \
|
|
|
+ + (tiitg/THREAD_PER_ROW)%8 + (i&7)*8) = temp_a[i/4][i%4];
|
|
|
}
|
|
|
|
|
|
- *(threadgroup float2x4 *)(sb + (tiitg % THREAD_PER_COL) * 8 * 32 + 8 * (tiitg / THREAD_PER_COL)) = *((device float2x4 *)y);
|
|
|
+ *(threadgroup half2x4 *)(sb + 32*8*(tiitg%THREAD_PER_COL) + 8*(tiitg/THREAD_PER_COL)) = *((device half2x4 *) y);
|
|
|
|
|
|
il = (il + 2 < nl) ? il + 2 : il % 2;
|
|
|
- x = (il < 2) ? x + (2+nl-1)/nl : x;
|
|
|
+ x = (il < 2) ? x + (2 + nl - 1)/nl : x;
|
|
|
y += BLOCK_SIZE_K;
|
|
|
|
|
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
|
|
|
|
// load matrices from threadgroup memory and conduct outer products
|
|
|
- threadgroup half * lsma = (sa + THREAD_MAT_M * SG_MAT_SIZE * (sgitg % 2));
|
|
|
- threadgroup float * lsmb = (sb + THREAD_MAT_N * SG_MAT_SIZE * (sgitg / 2));
|
|
|
+ threadgroup const T * lsma = (sa + THREAD_MAT_M*SG_MAT_SIZE*(sgitg%2));
|
|
|
+ threadgroup const half * lsmb = (sb + THREAD_MAT_N*SG_MAT_SIZE*(sgitg/2));
|
|
|
|
|
|
- #pragma unroll(BLOCK_SIZE_K/8)
|
|
|
- for (int ik = 0; ik < BLOCK_SIZE_K / 8; ik++) {
|
|
|
+ #pragma unroll(4)
|
|
|
+ for (short ik = 0; ik < BLOCK_SIZE_K/8; ik++) {
|
|
|
#pragma unroll(4)
|
|
|
- for (int i = 0; i < 4; i++) {
|
|
|
+ for (short i = 0; i < 4; i++) {
|
|
|
simdgroup_load(ma[i], lsma + SG_MAT_SIZE * i);
|
|
|
}
|
|
|
+
|
|
|
simdgroup_barrier(mem_flags::mem_none);
|
|
|
+
|
|
|
#pragma unroll(2)
|
|
|
- for (int i = 0; i < 2; i++) {
|
|
|
+ for (short i = 0; i < 2; i++) {
|
|
|
simdgroup_load(mb[i], lsmb + SG_MAT_SIZE * i);
|
|
|
}
|
|
|
|
|
|
- lsma += BLOCK_SIZE_M / SG_MAT_ROW * SG_MAT_SIZE;
|
|
|
- lsmb += BLOCK_SIZE_N / SG_MAT_ROW * SG_MAT_SIZE;
|
|
|
-
|
|
|
#pragma unroll(8)
|
|
|
- for (int i = 0; i < 8; i++){
|
|
|
+ for (short i = 0; i < 8; i++){
|
|
|
simdgroup_multiply_accumulate(mc[i], mb[i/4], ma[i%4], mc[i]);
|
|
|
}
|
|
|
+
|
|
|
+ lsma += (BLOCK_SIZE_M/SG_MAT_ROW)*SG_MAT_SIZE;
|
|
|
+ lsmb += (BLOCK_SIZE_N/SG_MAT_ROW)*SG_MAT_SIZE;
|
|
|
}
|
|
|
}
|
|
|
|
|
|
- {
|
|
|
+ if ((r0 + 1) * BLOCK_SIZE_M <= args.neh0 && (r1 + 1) * BLOCK_SIZE_N <= neh1) {
|
|
|
+ device float * C = (device float *) dst +
|
|
|
+ (BLOCK_SIZE_M * r0 + 32*(sgitg & 1)) + \
|
|
|
+ (BLOCK_SIZE_N * r1 + 16*(sgitg >> 1)) * args.neh0 + im*args.neh1*args.neh0;
|
|
|
+
|
|
|
+ for (short i = 0; i < 8; i++) {
|
|
|
+ simdgroup_store(mc[i], C + 8 * (i%4) + 8 * args.neh0 * (i/4), args.neh0);
|
|
|
+ }
|
|
|
+ } else {
|
|
|
+ // block is smaller than 64x32, we should avoid writing data outside of the matrix
|
|
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
|
threadgroup float * temp_str = ((threadgroup float *) shmem) \
|
|
|
- + 32 * (sgitg&1) + (16 * (sgitg>>1)) * BLOCK_SIZE_M;
|
|
|
- for (int i = 0; i < 8; i++) {
|
|
|
- simdgroup_store(mc[i], temp_str + 8 * (i%4) + 8 * BLOCK_SIZE_M * (i/4), BLOCK_SIZE_M);
|
|
|
+ + 32*(sgitg&1) + (16*(sgitg >> 1))*BLOCK_SIZE_M;
|
|
|
+ for (short i = 0; i < 8; i++) {
|
|
|
+ simdgroup_store(mc[i], temp_str + 8*(i%4) + 8*BLOCK_SIZE_M*(i/4), BLOCK_SIZE_M);
|
|
|
}
|
|
|
|
|
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
|
|
|
|
if (sgitg == 0) {
|
|
|
for (int j = tiitg; j < n_cols; j += BLOCK_SIZE_N) {
|
|
|
- threadgroup const auto & jid = rowids[r1 * BLOCK_SIZE_N + j];
|
|
|
- int64_t joff = jid[0]*ne0 + jid[1]*ne0ne1;
|
|
|
-
|
|
|
- device float * D = (device float *) dst + (r0*BLOCK_SIZE_M) + joff;
|
|
|
+ device float * D = (device float *) dst + (r0*BLOCK_SIZE_M) + (r1*BLOCK_SIZE_N + j)*args.neh0 + im*args.neh1*args.neh0;
|
|
|
device float4 * D4 = (device float4 *) D;
|
|
|
|
|
|
threadgroup float * C = temp_str + (j*BLOCK_SIZE_M);
|
|
|
@@ -6476,66 +6568,6 @@ void kernel_mul_mm_id_impl(
|
|
|
}
|
|
|
}
|
|
|
|
|
|
-template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread half4x4 &)>
|
|
|
-kernel void kernel_mul_mm_id(
|
|
|
- constant ggml_metal_kargs_mul_mm_id & args,
|
|
|
- device const char * src0s,
|
|
|
- device const char * src1,
|
|
|
- device char * dst,
|
|
|
- device const char * ids,
|
|
|
- threadgroup char * shmem [[threadgroup(0)]],
|
|
|
- uint3 tgpig[[threadgroup_position_in_grid]],
|
|
|
- ushort tiitg[[thread_index_in_threadgroup]],
|
|
|
- ushort sgitg[[simdgroup_index_in_threadgroup]]) {
|
|
|
-
|
|
|
- const int32_t i02 = tgpig.z;
|
|
|
-
|
|
|
- tgpig.z = 0;
|
|
|
-
|
|
|
- device const char * src0 = src0s + i02*args.nb02;
|
|
|
-
|
|
|
- // row indices
|
|
|
- threadgroup ushort2 * rowids = (threadgroup ushort2 *)(shmem + 8192);
|
|
|
-
|
|
|
- // TODO: parallelize this loop
|
|
|
- int32_t _ne1 = 0;
|
|
|
- for (ushort ii1 = 0; ii1 < args.nei1; ii1++) {
|
|
|
- for (ushort ii0 = 0; ii0 < args.nei0; ii0++) {
|
|
|
- int32_t id = ((device int32_t *) (ids + ii1*args.nbi1))[ii0];
|
|
|
- if (id == i02) {
|
|
|
- if (tiitg == 0) {
|
|
|
- rowids[_ne1] = ushort2(ii0, ii1);
|
|
|
- }
|
|
|
- _ne1++;
|
|
|
- }
|
|
|
- }
|
|
|
- }
|
|
|
-
|
|
|
- threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
|
-
|
|
|
- kernel_mul_mm_id_impl<block_q, nl, dequantize_func>(
|
|
|
- args.ne00,
|
|
|
- args.ne02,
|
|
|
- args.nb01,
|
|
|
- args.nb02,
|
|
|
- args.ne11,
|
|
|
- args.ne12,
|
|
|
- args.nb10,
|
|
|
- args.nb11,
|
|
|
- args.nb12,
|
|
|
- args.ne0,
|
|
|
- _ne1,
|
|
|
- (int64_t)args.ne0*args.ne1,
|
|
|
- src0,
|
|
|
- src1,
|
|
|
- rowids,
|
|
|
- dst,
|
|
|
- shmem,
|
|
|
- tgpig,
|
|
|
- tiitg,
|
|
|
- sgitg);
|
|
|
-}
|
|
|
-
|
|
|
#define QK_NL 16
|
|
|
|
|
|
//
|
|
|
@@ -6576,63 +6608,64 @@ template [[host_name("kernel_get_rows_iq4_xs")]] kernel get_rows_q_t kernel_get
|
|
|
// matrix-matrix multiplication
|
|
|
//
|
|
|
|
|
|
-typedef decltype(kernel_mul_mm<half, half4x4, simdgroup_half8x8, float4x4, 1, dequantize_f32>) mat_mm_t;
|
|
|
+typedef decltype(kernel_mul_mm<half, half4x4, simdgroup_half8x8, float4x4, 1, dequantize_f32>) mul_mm_t;
|
|
|
|
|
|
-template [[host_name("kernel_mul_mm_f32_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, float4x4, 1, dequantize_f32>;
|
|
|
-template [[host_name("kernel_mul_mm_f16_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half4x4, 1, dequantize_f16>;
|
|
|
+template [[host_name("kernel_mul_mm_f32_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, float4x4, 1, dequantize_f32>;
|
|
|
+template [[host_name("kernel_mul_mm_f16_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half4x4, 1, dequantize_f16>;
|
|
|
#if defined(GGML_METAL_USE_BF16)
|
|
|
-template [[host_name("kernel_mul_mm_bf16_f32")]] kernel mat_mm_t kernel_mul_mm<bfloat, bfloat4x4, simdgroup_bfloat8x8, bfloat4x4, 1, dequantize_bf16>;
|
|
|
+template [[host_name("kernel_mul_mm_bf16_f32")]] kernel mul_mm_t kernel_mul_mm<bfloat, bfloat4x4, simdgroup_bfloat8x8, bfloat4x4, 1, dequantize_bf16>;
|
|
|
#endif
|
|
|
-template [[host_name("kernel_mul_mm_q4_0_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q4_0, 2, dequantize_q4_0>;
|
|
|
-template [[host_name("kernel_mul_mm_q4_1_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q4_1, 2, dequantize_q4_1>;
|
|
|
-template [[host_name("kernel_mul_mm_q5_0_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q5_0, 2, dequantize_q5_0>;
|
|
|
-template [[host_name("kernel_mul_mm_q5_1_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q5_1, 2, dequantize_q5_1>;
|
|
|
-template [[host_name("kernel_mul_mm_q8_0_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q8_0, 2, dequantize_q8_0>;
|
|
|
-template [[host_name("kernel_mul_mm_q2_K_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q2_K, QK_NL, dequantize_q2_K>;
|
|
|
-template [[host_name("kernel_mul_mm_q3_K_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q3_K, QK_NL, dequantize_q3_K>;
|
|
|
-template [[host_name("kernel_mul_mm_q4_K_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q4_K, QK_NL, dequantize_q4_K>;
|
|
|
-template [[host_name("kernel_mul_mm_q5_K_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q5_K, QK_NL, dequantize_q5_K>;
|
|
|
-template [[host_name("kernel_mul_mm_q6_K_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q6_K, QK_NL, dequantize_q6_K>;
|
|
|
-template [[host_name("kernel_mul_mm_iq2_xxs_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq2_xxs, QK_NL, dequantize_iq2_xxs>;
|
|
|
-template [[host_name("kernel_mul_mm_iq2_xs_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq2_xs, QK_NL, dequantize_iq2_xs>;
|
|
|
-template [[host_name("kernel_mul_mm_iq3_xxs_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq3_xxs, QK_NL, dequantize_iq3_xxs>;
|
|
|
-template [[host_name("kernel_mul_mm_iq3_s_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq3_s, QK_NL, dequantize_iq3_s>;
|
|
|
-template [[host_name("kernel_mul_mm_iq2_s_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq2_s, QK_NL, dequantize_iq2_s>;
|
|
|
-template [[host_name("kernel_mul_mm_iq1_s_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq1_s, QK_NL, dequantize_iq1_s>;
|
|
|
-template [[host_name("kernel_mul_mm_iq1_m_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq1_m, QK_NL, dequantize_iq1_m>;
|
|
|
-template [[host_name("kernel_mul_mm_iq4_nl_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq4_nl, 2, dequantize_iq4_nl>;
|
|
|
-template [[host_name("kernel_mul_mm_iq4_xs_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq4_xs, QK_NL, dequantize_iq4_xs>;
|
|
|
+template [[host_name("kernel_mul_mm_q4_0_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q4_0, 2, dequantize_q4_0>;
|
|
|
+template [[host_name("kernel_mul_mm_q4_1_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q4_1, 2, dequantize_q4_1>;
|
|
|
+template [[host_name("kernel_mul_mm_q5_0_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q5_0, 2, dequantize_q5_0>;
|
|
|
+template [[host_name("kernel_mul_mm_q5_1_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q5_1, 2, dequantize_q5_1>;
|
|
|
+template [[host_name("kernel_mul_mm_q8_0_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q8_0, 2, dequantize_q8_0>;
|
|
|
+template [[host_name("kernel_mul_mm_q2_K_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q2_K, QK_NL, dequantize_q2_K>;
|
|
|
+template [[host_name("kernel_mul_mm_q3_K_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q3_K, QK_NL, dequantize_q3_K>;
|
|
|
+template [[host_name("kernel_mul_mm_q4_K_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q4_K, QK_NL, dequantize_q4_K>;
|
|
|
+template [[host_name("kernel_mul_mm_q5_K_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q5_K, QK_NL, dequantize_q5_K>;
|
|
|
+template [[host_name("kernel_mul_mm_q6_K_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q6_K, QK_NL, dequantize_q6_K>;
|
|
|
+template [[host_name("kernel_mul_mm_iq2_xxs_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq2_xxs, QK_NL, dequantize_iq2_xxs>;
|
|
|
+template [[host_name("kernel_mul_mm_iq2_xs_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq2_xs, QK_NL, dequantize_iq2_xs>;
|
|
|
+template [[host_name("kernel_mul_mm_iq3_xxs_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq3_xxs, QK_NL, dequantize_iq3_xxs>;
|
|
|
+template [[host_name("kernel_mul_mm_iq3_s_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq3_s, QK_NL, dequantize_iq3_s>;
|
|
|
+template [[host_name("kernel_mul_mm_iq2_s_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq2_s, QK_NL, dequantize_iq2_s>;
|
|
|
+template [[host_name("kernel_mul_mm_iq1_s_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq1_s, QK_NL, dequantize_iq1_s>;
|
|
|
+template [[host_name("kernel_mul_mm_iq1_m_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq1_m, QK_NL, dequantize_iq1_m>;
|
|
|
+template [[host_name("kernel_mul_mm_iq4_nl_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq4_nl, 2, dequantize_iq4_nl>;
|
|
|
+template [[host_name("kernel_mul_mm_iq4_xs_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq4_xs, QK_NL, dequantize_iq4_xs>;
|
|
|
|
|
|
//
|
|
|
// indirect matrix-matrix multiplication
|
|
|
//
|
|
|
|
|
|
-typedef decltype(kernel_mul_mm_id<float4x4, 1, dequantize_f32>) mat_mm_id_t;
|
|
|
+typedef decltype(kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, float4x4, 1, dequantize_f32>) mul_mm_id;
|
|
|
|
|
|
-template [[host_name("kernel_mul_mm_id_f32_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<float4x4, 1, dequantize_f32>;
|
|
|
-template [[host_name("kernel_mul_mm_id_f16_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<half4x4, 1, dequantize_f16>;
|
|
|
+template [[host_name("kernel_mul_mm_id_f32_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, float4x4, 1, dequantize_f32>;
|
|
|
+template [[host_name("kernel_mul_mm_id_f16_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half4x4, 1, dequantize_f16>;
|
|
|
#if defined(GGML_METAL_USE_BF16)
|
|
|
-template [[host_name("kernel_mul_mm_id_bf16_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<bfloat4x4, 1, dequantize_bf16>;
|
|
|
+template [[host_name("kernel_mul_mm_id_bf16_f16")]] kernel mul_mm_id kernel_mul_mm_id<bfloat, bfloat4x4, simdgroup_bfloat8x8, bfloat4x4, 1, dequantize_bf16>;
|
|
|
#endif
|
|
|
-template [[host_name("kernel_mul_mm_id_q4_0_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q4_0, 2, dequantize_q4_0>;
|
|
|
-template [[host_name("kernel_mul_mm_id_q4_1_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q4_1, 2, dequantize_q4_1>;
|
|
|
-template [[host_name("kernel_mul_mm_id_q5_0_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q5_0, 2, dequantize_q5_0>;
|
|
|
-template [[host_name("kernel_mul_mm_id_q5_1_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q5_1, 2, dequantize_q5_1>;
|
|
|
-template [[host_name("kernel_mul_mm_id_q8_0_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q8_0, 2, dequantize_q8_0>;
|
|
|
-template [[host_name("kernel_mul_mm_id_q2_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q2_K, QK_NL, dequantize_q2_K>;
|
|
|
-template [[host_name("kernel_mul_mm_id_q3_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q3_K, QK_NL, dequantize_q3_K>;
|
|
|
-template [[host_name("kernel_mul_mm_id_q4_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q4_K, QK_NL, dequantize_q4_K>;
|
|
|
-template [[host_name("kernel_mul_mm_id_q5_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q5_K, QK_NL, dequantize_q5_K>;
|
|
|
-template [[host_name("kernel_mul_mm_id_q6_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q6_K, QK_NL, dequantize_q6_K>;
|
|
|
-template [[host_name("kernel_mul_mm_id_iq2_xxs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq2_xxs, QK_NL, dequantize_iq2_xxs>;
|
|
|
-template [[host_name("kernel_mul_mm_id_iq2_xs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq2_xs, QK_NL, dequantize_iq2_xs>;
|
|
|
-template [[host_name("kernel_mul_mm_id_iq3_xxs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq3_xxs, QK_NL, dequantize_iq3_xxs>;
|
|
|
-template [[host_name("kernel_mul_mm_id_iq3_s_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq3_s, QK_NL, dequantize_iq3_s>;
|
|
|
-template [[host_name("kernel_mul_mm_id_iq2_s_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq2_s, QK_NL, dequantize_iq2_s>;
|
|
|
-template [[host_name("kernel_mul_mm_id_iq1_s_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq1_s, QK_NL, dequantize_iq1_s>;
|
|
|
-template [[host_name("kernel_mul_mm_id_iq1_m_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq1_m, QK_NL, dequantize_iq1_m>;
|
|
|
-template [[host_name("kernel_mul_mm_id_iq4_nl_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq4_nl, 2, dequantize_iq4_nl>;
|
|
|
-template [[host_name("kernel_mul_mm_id_iq4_xs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq4_xs, QK_NL, dequantize_iq4_xs>;
|
|
|
+template [[host_name("kernel_mul_mm_id_q4_0_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_q4_0, 2, dequantize_q4_0>;
|
|
|
+template [[host_name("kernel_mul_mm_id_q4_1_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_q4_1, 2, dequantize_q4_1>;
|
|
|
+template [[host_name("kernel_mul_mm_id_q5_0_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_q5_0, 2, dequantize_q5_0>;
|
|
|
+template [[host_name("kernel_mul_mm_id_q5_1_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_q5_1, 2, dequantize_q5_1>;
|
|
|
+template [[host_name("kernel_mul_mm_id_q8_0_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_q8_0, 2, dequantize_q8_0>;
|
|
|
+template [[host_name("kernel_mul_mm_id_q2_K_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_q2_K, QK_NL, dequantize_q2_K>;
|
|
|
+template [[host_name("kernel_mul_mm_id_q3_K_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_q3_K, QK_NL, dequantize_q3_K>;
|
|
|
+template [[host_name("kernel_mul_mm_id_q4_K_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_q4_K, QK_NL, dequantize_q4_K>;
|
|
|
+template [[host_name("kernel_mul_mm_id_q5_K_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_q5_K, QK_NL, dequantize_q5_K>;
|
|
|
+template [[host_name("kernel_mul_mm_id_q6_K_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_q6_K, QK_NL, dequantize_q6_K>;
|
|
|
+template [[host_name("kernel_mul_mm_id_iq2_xxs_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_iq2_xxs, QK_NL, dequantize_iq2_xxs>;
|
|
|
+template [[host_name("kernel_mul_mm_id_iq2_xs_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_iq2_xs, QK_NL, dequantize_iq2_xs>;
|
|
|
+template [[host_name("kernel_mul_mm_id_iq3_xxs_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_iq3_xxs, QK_NL, dequantize_iq3_xxs>;
|
|
|
+template [[host_name("kernel_mul_mm_id_iq3_s_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_iq3_s, QK_NL, dequantize_iq3_s>;
|
|
|
+template [[host_name("kernel_mul_mm_id_iq2_s_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_iq2_s, QK_NL, dequantize_iq2_s>;
|
|
|
+template [[host_name("kernel_mul_mm_id_iq1_s_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_iq1_s, QK_NL, dequantize_iq1_s>;
|
|
|
+template [[host_name("kernel_mul_mm_id_iq1_m_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_iq1_m, QK_NL, dequantize_iq1_m>;
|
|
|
+template [[host_name("kernel_mul_mm_id_iq4_nl_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_iq4_nl, 2, dequantize_iq4_nl>;
|
|
|
+template [[host_name("kernel_mul_mm_id_iq4_xs_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_iq4_xs, QK_NL, dequantize_iq4_xs>;
|
|
|
+
|
|
|
|
|
|
//
|
|
|
// matrix-vector multiplication
|