|
|
@@ -2531,6 +2531,12 @@ typedef struct {
|
|
|
uint8_t scales[QK_K/16];
|
|
|
} block_iq1_s;
|
|
|
|
|
|
+// Non-linear quants
|
|
|
// Number of weights stored in one block_iq4_nl.
+#define QK4_NL 32
|
|
|
// One IQ4_NL block: QK4_NL weights stored as 4-bit codes plus one scale.
// The kernels in this file decode a code by looking it up in kvalues_iq4nl_f
// and multiplying the table entry by d.
+typedef struct {
|
|
|
// Per-block scale applied to every decoded table value.
+ half d;
|
|
|
// Two 4-bit codes per byte. As consumed by kernel_mul_mv_iq4_nl_f32_impl,
// the low nibble of byte j pairs with element j of the block and the high
// nibble with element j + 16.
+ uint8_t qs[QK4_NL/2];
|
|
|
+} block_iq4_nl;
|
|
|
|
|
|
//====================================== dot products =========================
|
|
|
|
|
|
@@ -4384,7 +4390,6 @@ void kernel_mul_mv_iq1_s_f32_impl(
|
|
|
const uint i13 = im/ne12;
|
|
|
|
|
|
const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
|
|
|
-
|
|
|
device const block_iq1_s * x = (device const block_iq1_s *) src0 + ib_row + offset0;
|
|
|
device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
|
|
|
|
|
|
@@ -4447,6 +4452,103 @@ void kernel_mul_mv_iq1_s_f32_impl(
|
|
|
}
|
|
|
}
|
|
|
|
|
|
// Lookup table for IQ4_NL: maps a 4-bit code (0..15) to its unscaled value.
// The entries are increasing but not evenly spaced. Users of the table
// multiply the selected entry by the block scale d.
+constexpr constant static float kvalues_iq4nl_f[16] = {
|
|
|
+ -127.f, -104.f, -83.f, -65.f, -49.f, -35.f, -22.f, -10.f, 1.f, 13.f, 25.f, 38.f, 53.f, 69.f, 89.f, 113.f
|
|
|
+};
|
|
|
+
|
|
|
// Matrix-vector product of IQ4_NL-quantized src0 with float src1.
//
// Each simdgroup accumulates dot products for two consecutive src0 rows
// (first_row and first_row + 1); lane 0 of the simdgroup writes both results
// to dst.
//
//   src0           quantized weights, read as an array of block_iq4_nl
//   src1           float activations
//   dst            float output
//   ne00           row length of src0 in elements (nb = ne00/QK4_NL blocks per row)
//   ne01, ne02     used only to compute the batch offset into src0
//   ne10, ne1      used to locate the src1 row for this threadgroup
//   ne12           splits tgpig.z into the two batch indices i12 and i13
//   ne0            row stride of dst in floats
//   r2, r3         divisors applied to i12 / i13 when indexing src0
//                  (presumably the src1/src0 broadcast ratios; confirm with caller)
//   shared_values  threadgroup scratch that receives a copy of kvalues_iq4nl_f;
//                  it is written at index tiisg, so it must hold one float per
//                  lane index (presumably 32; confirm host-side allocation)
+void kernel_mul_mv_iq4_nl_f32_impl(
|
|
|
+ device const void * src0,
|
|
|
+ device const float * src1,
|
|
|
+ device float * dst,
|
|
|
+ constant int64_t & ne00,
|
|
|
+ constant int64_t & ne01,
|
|
|
+ constant int64_t & ne02,
|
|
|
+ constant int64_t & ne10,
|
|
|
+ constant int64_t & ne12,
|
|
|
+ constant int64_t & ne0,
|
|
|
+ constant int64_t & ne1,
|
|
|
+ constant uint & r2,
|
|
|
+ constant uint & r3,
|
|
|
+ threadgroup float * shared_values [[threadgroup(0)]],
|
|
|
+ uint3 tgpig[[threadgroup_position_in_grid]],
|
|
|
+ uint tiisg[[thread_index_in_simdgroup]],
|
|
|
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
|
|
+
|
|
|
// Blocks per src0 row. Integer division: assumes ne00 is a multiple of
// QK4_NL (TODO confirm with caller).
+ const int nb = ne00/QK4_NL;
|
|
|
+ const int r0 = tgpig.x;
|
|
|
+ const int r1 = tgpig.y;
|
|
|
+ const int im = tgpig.z;
|
|
|
// Two rows per simdgroup. The r0 * 2 term suggests two simdgroups per
// threadgroup; confirm against the dispatch configuration.
// NOTE(review): nothing here checks first_row + 1 < ne01; presumably the
// host guarantees it. Verify.
+ const int first_row = (r0 * 2 + sgitg) * 2;
|
|
|
// Index of the first block of first_row.
+ const int ib_row = first_row * nb;
|
|
|
+
|
|
|
// Split the flat batch index im into its two components.
+ const uint i12 = im%ne12;
|
|
|
+ const uint i13 = im/ne12;
|
|
|
+
|
|
|
// Batch offset into src0, measured in blocks.
+ const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
|
|
|
+ device const block_iq4_nl * x = (device const block_iq4_nl *) src0 + ib_row + offset0;
|
|
|
+ device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
|
|
|
+
|
|
|
// Lanes work in pairs: ix picks the block within a group of 16 blocks,
// it picks which 8-byte half of that block's qs this lane decodes.
+ const int ix = tiisg/2; // 0...15
|
|
|
+ const int it = tiisg%2; // 0 or 1
|
|
|
+
|
|
|
// Copy the lookup table into threadgroup memory. Every lane writes slot
// tiisg, so the 16-entry table is repeated for lane indices >= 16; only
// slots 0..15 are read below because all indices are masked to 4 bits.
+ shared_values[tiisg] = kvalues_iq4nl_f[tiisg%16];
|
|
|
// Make the table writes visible before any lane reads shared_values.
+ threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
|
+
|
|
|
+ float4 yl[4];
|
|
|
// Per-row partial sums for this lane.
+ float sumf[2]={0.f}, all_sum;
|
|
|
+
|
|
|
// First activation handled by this lane: block ix, plus 8 floats for the
// second lane of the pair.
+ device const float * yb = y + ix * QK4_NL + it * 8;
|
|
|
+
|
|
|
// aux32 holds unpacked nibbles; q8 views the same 8 bytes one at a time.
// q8[0..3] alias aux32[0] and q8[4..7] alias aux32[1].
+ uint32_t aux32[2];
|
|
|
+ thread const uint8_t * q8 = (thread const uint8_t *)aux32;
|
|
|
+
|
|
|
+ float4 qf1, qf2;
|
|
|
+
|
|
|
// The 16 lane pairs step through the row 16 blocks at a time.
+ for (int ib = ix; ib < nb; ib += 16) {
|
|
|
+
|
|
|
// Load 16 activations relative to yb: floats 0..3 and 4..7 (matched with
// low nibbles) and floats 16..19 and 20..23 (matched with high nibbles).
+ device const float4 * y4 = (device const float4 *)yb;
|
|
|
+ yl[0] = y4[0]; yl[1] = y4[4]; yl[2] = y4[1]; yl[3] = y4[5];
|
|
|
+
|
|
|
// The same activations are reused for both rows owned by this simdgroup.
+ for (int row = 0; row < 2; ++row) {
|
|
|
+
|
|
|
+ device const block_iq4_nl & xb = x[row*nb + ib];
|
|
|
// This lane's 8 quant bytes, read as four 16-bit words.
+ device const uint16_t * q4 = (device const uint16_t *)(xb.qs + 8*it);
|
|
|
+
|
|
|
+ float4 acc1 = {0.f}, acc2 = {0.f};
|
|
|
+
|
|
|
// First 4 quant bytes: aux32[1] receives the high nibbles, aux32[0] is
// reduced to the low nibbles.
+ aux32[0] = q4[0] | (q4[1] << 16);
|
|
|
+ aux32[1] = (aux32[0] >> 4) & 0x0f0f0f0f;
|
|
|
+ aux32[0] &= 0x0f0f0f0f;
|
|
|
// Table lookups: qf1 from the low nibbles, qf2 from the high nibbles.
+ qf1 = {shared_values[q8[0]], shared_values[q8[1]], shared_values[q8[2]], shared_values[q8[3]]};
|
|
|
+ qf2 = {shared_values[q8[4]], shared_values[q8[5]], shared_values[q8[6]], shared_values[q8[7]]};
|
|
|
+ acc1 += yl[0] * qf1;
|
|
|
+ acc2 += yl[1] * qf2;
|
|
|
+
|
|
|
// Same unpacking for the next 4 quant bytes.
+ aux32[0] = q4[2] | (q4[3] << 16);
|
|
|
+ aux32[1] = (aux32[0] >> 4) & 0x0f0f0f0f;
|
|
|
+ aux32[0] &= 0x0f0f0f0f;
|
|
|
+ qf1 = {shared_values[q8[0]], shared_values[q8[1]], shared_values[q8[2]], shared_values[q8[3]]};
|
|
|
+ qf2 = {shared_values[q8[4]], shared_values[q8[5]], shared_values[q8[6]], shared_values[q8[7]]};
|
|
|
+ acc1 += yl[2] * qf1;
|
|
|
+ acc2 += yl[3] * qf2;
|
|
|
+
|
|
|
+ acc1 += acc2;
|
|
|
+
|
|
|
// Apply the block scale once to the summed unscaled products.
+ sumf[row] += (float)xb.d * (acc1[0] + acc1[1] + acc1[2] + acc1[3]);
|
|
|
+
|
|
|
+ }
|
|
|
+
|
|
|
// Advance the activation pointer by 16 blocks to match ib += 16.
+ yb += 16 * QK4_NL;
|
|
|
+ }
|
|
|
+
|
|
|
// Reduce each row's partial sums across the simdgroup; lane 0 stores them.
+ for (int row = 0; row < 2; ++row) {
|
|
|
+ all_sum = simd_sum(sumf[row]);
|
|
|
+ if (tiisg == 0) {
|
|
|
+ dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum;
|
|
|
+ }
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
[[host_name("kernel_mul_mv_iq1_s_f32")]]
|
|
|
kernel void kernel_mul_mv_iq1_s_f32(
|
|
|
device const void * src0,
|
|
|
@@ -4475,6 +4577,34 @@ kernel void kernel_mul_mv_iq1_s_f32(
|
|
|
kernel_mul_mv_iq1_s_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, tgpig, tiisg, sgitg);
|
|
|
}
|
|
|
|
|
|
// Kernel entry point for the IQ4_NL matrix-vector product.
// It forwards to kernel_mul_mv_iq4_nl_f32_impl and does no work of its own.
// nb00, nb01, nb02, ne11, nb10, nb11 and nb12 are accepted but not used in
// this body (presumably kept so the argument list matches the other mul_mv
// kernels; confirm against the host-side argument binding).
+[[host_name("kernel_mul_mv_iq4_nl_f32")]]
|
|
|
+kernel void kernel_mul_mv_iq4_nl_f32(
|
|
|
+ device const void * src0,
|
|
|
+ device const float * src1,
|
|
|
+ device float * dst,
|
|
|
+ constant int64_t & ne00,
|
|
|
+ constant int64_t & ne01,
|
|
|
+ constant int64_t & ne02,
|
|
|
+ constant uint64_t & nb00,
|
|
|
+ constant uint64_t & nb01,
|
|
|
+ constant uint64_t & nb02,
|
|
|
+ constant int64_t & ne10,
|
|
|
+ constant int64_t & ne11,
|
|
|
+ constant int64_t & ne12,
|
|
|
+ constant uint64_t & nb10,
|
|
|
+ constant uint64_t & nb11,
|
|
|
+ constant uint64_t & nb12,
|
|
|
+ constant int64_t & ne0,
|
|
|
+ constant int64_t & ne1,
|
|
|
+ constant uint & r2,
|
|
|
+ constant uint & r3,
|
|
|
// Threadgroup scratch handed to the impl, which fills it with the lookup table.
+ threadgroup float * shared_values [[threadgroup(0)]],
|
|
|
+ uint3 tgpig[[threadgroup_position_in_grid]],
|
|
|
+ uint tiisg[[thread_index_in_simdgroup]],
|
|
|
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
|
|
+
|
|
|
+ kernel_mul_mv_iq4_nl_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg);
|
|
|
+}
|
|
|
|
|
|
//============================= templates and their specializations =============================
|
|
|
|
|
|
@@ -4838,6 +4968,21 @@ void dequantize_iq1_s(device const block_iq1_s * xb, short il, thread type4x4 &
|
|
|
}
|
|
|
}
|
|
|
|
|
|
// Decodes 16 of the values stored in one block_iq4_nl into reg.
//   xb   block to decode
//   il   nibble selector used as a shift of 4*il bits: 0 takes the low nibble
//        of each of the 16 qs bytes, 1 takes the high nibble
//   reg  output; reg[i][0..3] receive the values decoded from qs bytes 4i..4i+3
+template <typename type4x4>
|
|
|
+void dequantize_iq4_nl(device const block_iq4_nl * xb, short il, thread type4x4 & reg) {
|
|
|
// View the 16 quant bytes as eight 16-bit words.
+ device const uint16_t * q4 = (device const uint16_t *)xb->qs;
|
|
|
+ const float d = xb->d;
|
|
|
// aux32 holds four selected nibbles, one per byte; q8 reads them byte-wise.
+ uint32_t aux32;
|
|
|
+ thread const uint8_t * q8 = (thread const uint8_t *)&aux32;
|
|
|
+ for (int i = 0; i < 4; ++i) {
|
|
|
// Combine 4 quant bytes, shift the wanted nibble into the low position
// of each byte, then mask so every byte is a 4-bit table index.
+ aux32 = ((q4[2*i] | (q4[2*i+1] << 16)) >> 4*il) & 0x0f0f0f0f;
|
|
|
// Table lookup scaled by the block scale.
+ reg[i][0] = d * kvalues_iq4nl_f[q8[0]];
|
|
|
+ reg[i][1] = d * kvalues_iq4nl_f[q8[1]];
|
|
|
+ reg[i][2] = d * kvalues_iq4nl_f[q8[2]];
|
|
|
+ reg[i][3] = d * kvalues_iq4nl_f[q8[3]];
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread float4x4 &)>
|
|
|
kernel void kernel_get_rows(
|
|
|
device const void * src0,
|
|
|
@@ -5381,6 +5526,7 @@ template [[host_name("kernel_get_rows_iq2_xxs")]] kernel get_rows_t kernel_get_r
|
|
|
template [[host_name("kernel_get_rows_iq2_xs")]] kernel get_rows_t kernel_get_rows<block_iq2_xs, QK_NL, dequantize_iq2_xs>;
|
|
|
template [[host_name("kernel_get_rows_iq3_xxs")]] kernel get_rows_t kernel_get_rows<block_iq3_xxs, QK_NL, dequantize_iq3_xxs>;
|
|
|
template [[host_name("kernel_get_rows_iq1_s")]] kernel get_rows_t kernel_get_rows<block_iq1_s, QK_NL, dequantize_iq1_s>;
|
|
|
// IQ4_NL row gather. The literal 2 is the nl template argument: presumably
// the number of dequantize_iq4_nl calls (il = 0, 1) per 32-value block, since
// each call yields 16 values. Confirm against kernel_get_rows.
+template [[host_name("kernel_get_rows_iq4_nl")]] kernel get_rows_t kernel_get_rows<block_iq4_nl, 2, dequantize_iq4_nl>;
|
|
|
|
|
|
//
|
|
|
// matrix-matrix multiplication
|
|
|
@@ -5421,6 +5567,7 @@ template [[host_name("kernel_mul_mm_iq2_xxs_f32")]] kernel mat_mm_t kernel_mul_m
|
|
|
template [[host_name("kernel_mul_mm_iq2_xs_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq2_xs, QK_NL, dequantize_iq2_xs>;
|
|
|
template [[host_name("kernel_mul_mm_iq3_xxs_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq3_xxs, QK_NL, dequantize_iq3_xxs>;
|
|
|
template [[host_name("kernel_mul_mm_iq1_s_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq1_s, QK_NL, dequantize_iq1_s>;
|
|
|
// IQ4_NL matrix-matrix product. The literal 2 is the nl template argument:
// presumably the number of dequantize_iq4_nl calls (il = 0, 1) per 32-value
// block. Confirm against kernel_mul_mm.
+template [[host_name("kernel_mul_mm_iq4_nl_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq4_nl, 2, dequantize_iq4_nl>;
|
|
|
|
|
|
//
|
|
|
// indirect matrix-matrix multiplication
|
|
|
@@ -5473,6 +5620,7 @@ template [[host_name("kernel_mul_mm_id_iq2_xxs_f32")]] kernel mat_mm_id_t kernel
|
|
|
template [[host_name("kernel_mul_mm_id_iq2_xs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq2_xs, QK_NL, dequantize_iq2_xs>;
|
|
|
template [[host_name("kernel_mul_mm_id_iq3_xxs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq3_xxs, QK_NL, dequantize_iq3_xxs>;
|
|
|
template [[host_name("kernel_mul_mm_id_iq1_s_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq1_s, QK_NL, dequantize_iq1_s>;
|
|
|
// IQ4_NL indirect matrix-matrix product. The literal 2 is the nl template
// argument: presumably the number of dequantize_iq4_nl calls (il = 0, 1) per
// 32-value block. Confirm against kernel_mul_mm_id.
+template [[host_name("kernel_mul_mm_id_iq4_nl_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq4_nl, 2, dequantize_iq4_nl>;
|
|
|
|
|
|
//
|
|
|
// matrix-vector multiplication
|
|
|
@@ -6503,3 +6651,68 @@ kernel void kernel_mul_mv_id_iq1_s_f32(
|
|
|
tiisg,
|
|
|
sgitg);
|
|
|
}
|
|
|
+
|
|
|
// Indirect IQ4_NL matrix-vector product.
// It reads an index from ids, uses it to pick one of the eight src0x
// pointers, and runs kernel_mul_mv_iq4_nl_f32_impl on the selected matrix.
// nb00, nb01, nb02, ne11, nb10, nb12, nb1 and tiitg are accepted but not used
// in this body.
+[[host_name("kernel_mul_mv_id_iq4_nl_f32")]]
|
|
|
+kernel void kernel_mul_mv_id_iq4_nl_f32(
|
|
|
+ device const char * ids,
|
|
|
+ device const char * src1,
|
|
|
+ device float * dst,
|
|
|
+ constant uint64_t & nbi1,
|
|
|
+ constant int64_t & ne00,
|
|
|
+ constant int64_t & ne01,
|
|
|
+ constant int64_t & ne02,
|
|
|
+ constant uint64_t & nb00,
|
|
|
+ constant uint64_t & nb01,
|
|
|
+ constant uint64_t & nb02,
|
|
|
+ constant int64_t & ne10,
|
|
|
+ constant int64_t & ne11,
|
|
|
+ constant int64_t & ne12,
|
|
|
+ constant int64_t & ne13,
|
|
|
+ constant uint64_t & nb10,
|
|
|
+ constant uint64_t & nb11,
|
|
|
+ constant uint64_t & nb12,
|
|
|
+ constant int64_t & ne0,
|
|
|
+ constant int64_t & ne1,
|
|
|
+ constant uint64_t & nb1,
|
|
|
+ constant uint & r2,
|
|
|
+ constant uint & r3,
|
|
|
+ constant int & idx,
|
|
|
+ device const char * src00,
|
|
|
+ device const char * src01,
|
|
|
+ device const char * src02,
|
|
|
+ device const char * src03,
|
|
|
+ device const char * src04,
|
|
|
+ device const char * src05,
|
|
|
+ device const char * src06,
|
|
|
+ device const char * src07,
|
|
|
+ threadgroup float * shared_values [[threadgroup(0)]],
|
|
|
+ uint3 tgpig[[threadgroup_position_in_grid]],
|
|
|
+ uint tiitg[[thread_index_in_threadgroup]],
|
|
|
+ uint tiisg[[thread_index_in_simdgroup]],
|
|
|
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
|
|
// Gather the candidate matrices so one can be chosen by index.
+ device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
|
|
|
+
|
|
|
// tgpig.z carries both a leading index (bid) and the batch index; split them.
+ const int64_t bid = tgpig.z/(ne12*ne13);
|
|
|
+
|
|
|
+ tgpig.z = tgpig.z%(ne12*ne13);
|
|
|
+
|
|
|
// Read the selector from entry idx of row bid in ids (row stride nbi1 bytes).
// NOTE(review): id indexes the 8-entry src0 array without a range check;
// this assumes 0 <= id < 8. Confirm the host guarantees it.
+ const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
|
|
|
+
|
|
|
+ kernel_mul_mv_iq4_nl_f32_impl(
|
|
|
+ src0[id],
|
|
|
// src1 is a byte pointer here, so bid*nb11 is a byte offset.
+ (device const float *) (src1 + bid*nb11),
|
|
|
// dst is a float pointer, so bid*ne0 is an offset in floats.
+ dst + bid*ne0,
|
|
|
+ ne00,
|
|
|
+ ne01,
|
|
|
+ ne02,
|
|
|
+ ne10,
|
|
|
+ ne12,
|
|
|
+ ne0,
|
|
|
+ ne1,
|
|
|
+ r2,
|
|
|
+ r3,
|
|
|
+ shared_values,
|
|
|
+ tgpig,
|
|
|
+ tiisg,
|
|
|
+ sgitg);
|
|
|
+}
|