|
|
@@ -18,6 +18,12 @@ typedef struct {
|
|
|
uint8_t qs[QK4_1 / 2]; // nibbles / quants
|
|
|
} block_q4_1;
|
|
|
|
|
|
+#define QK8_0 32
|
|
|
+typedef struct {
|
|
|
+ half d; // delta
|
|
|
+ int8_t qs[QK8_0]; // quants
|
|
|
+} block_q8_0;
|
|
|
+
|
|
|
kernel void kernel_add(
|
|
|
device const float * src0,
|
|
|
device const float * src1,
|
|
|
@@ -357,7 +363,7 @@ void mul_vec_q_n_f32(device const void * src0, device const float * src1, device
|
|
|
const int first_row = (r0 * nsg + sgitg) * nr;
|
|
|
const uint offset0 = first_row * nb + im/gqa*(nb*ne0);
|
|
|
device const block_q_type * x = (device const block_q_type *) src0 + offset0;
|
|
|
- device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
|
|
|
+ device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
|
|
|
float yl[16]; // src1 vector cache
|
|
|
float sumf[nr]={0.f};
|
|
|
|
|
|
@@ -429,6 +435,68 @@ kernel void kernel_mul_mat_q4_1_f32(
|
|
|
mul_vec_q_n_f32<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,gqa,tgpig,tiisg,sgitg);
|
|
|
}
|
|
|
|
|
|
+kernel void kernel_mul_mat_q8_0_f32(
|
|
|
+ device const void * src0,
|
|
|
+ device const float * src1,
|
|
|
+ device float * dst,
|
|
|
+ constant int64_t & ne00,
|
|
|
+ constant int64_t & ne01[[buffer(4)]],
|
|
|
+ constant int64_t & ne02[[buffer(5)]],
|
|
|
+ constant int64_t & ne10[[buffer(9)]],
|
|
|
+ constant int64_t & ne12[[buffer(11)]],
|
|
|
+ constant int64_t & ne0[[buffer(15)]],
|
|
|
+ constant int64_t & ne1[[buffer(16)]],
|
|
|
+ constant uint & gqa[[buffer(17)]],
|
|
|
+ uint3 tgpig[[threadgroup_position_in_grid]],
|
|
|
+ uint tiisg[[thread_index_in_simdgroup]],
|
|
|
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
|
|
+ const int nr = N_DST;
|
|
|
+ const int nsg = N_SIMDGROUP;
|
|
|
+ const int nw = N_SIMDWIDTH;
|
|
|
+
|
|
|
+ const int nb = ne00/QK8_0;
|
|
|
+ const int r0 = tgpig.x;
|
|
|
+ const int r1 = tgpig.y;
|
|
|
+ const int im = tgpig.z;
|
|
|
+ const int first_row = (r0 * nsg + sgitg) * nr;
|
|
|
+ const uint offset0 = first_row * nb + im/gqa*(nb*ne0);
|
|
|
+ device const block_q8_0 * x = (device const block_q8_0 *) src0 + offset0;
|
|
|
+ device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
|
|
|
+
|
|
|
+ float yl[16];
|
|
|
+ float sumf[nr]={0.f};
|
|
|
+
|
|
|
+ const int ix = tiisg/2;
|
|
|
+ const int il = tiisg%2;
|
|
|
+
|
|
|
+ device const float * yb = y + ix * QK8_0 + 16*il;
|
|
|
+
|
|
|
+ // each thread in a SIMD group deals with half a block.
|
|
|
+ for (int ib = ix; ib < nb; ib += nw/2) {
|
|
|
+ for (int i = 0; i < 16; ++i) {
|
|
|
+ yl[i] = yb[i];
|
|
|
+ }
|
|
|
+
|
|
|
+ for (int row = 0; row < nr; row++) {
|
|
|
+ device const int8_t * qs = x[ib+row*nb].qs + 16*il;
|
|
|
+ float sumq = 0.f;
|
|
|
+ for (int iq = 0; iq < 16; ++iq) {
|
|
|
+ sumq += qs[iq] * yl[iq];
|
|
|
+ }
|
|
|
+ sumf[row] += sumq*x[ib+row*nb].d;
|
|
|
+ }
|
|
|
+
|
|
|
+ yb += QK8_0 * 16;
|
|
|
+ }
|
|
|
+
|
|
|
+ for (int row = 0; row < nr; ++row) {
|
|
|
+ const float tot = simd_sum(sumf[row]);
|
|
|
+ if (tiisg == 0 && first_row + row < ne01) {
|
|
|
+ dst[r1*ne0 + im*ne0*ne1 + first_row + row] = tot;
|
|
|
+ }
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
kernel void kernel_mul_mat_f16_f32(
|
|
|
device const char * src0,
|
|
|
device const char * src1,
|
|
|
@@ -480,7 +548,6 @@ kernel void kernel_mul_mat_f16_f32(
|
|
|
}
|
|
|
}
|
|
|
|
|
|
-
|
|
|
kernel void kernel_alibi_f32(
|
|
|
device const float * src0,
|
|
|
device float * dst,
|
|
|
@@ -1621,12 +1688,12 @@ template <typename type4x4>
|
|
|
void dequantize_q4_0(device const block_q4_0 *xb, short il, thread type4x4 & reg) {
|
|
|
device const uint16_t * qs = ((device const uint16_t *)xb + 1);
|
|
|
const half d = il ? (xb->d / 16.h) : xb->d;
|
|
|
- const half m = il ? (-8.h * 16.h) : -8.h;
|
|
|
+ const half m = il ? ( -8.h * 16.h) : -8.h;
|
|
|
const ushort mask0 = il ? 0x00F0 : 0x000F;
|
|
|
const ushort mask1 = il ? 0xF000 : 0x0F00;
|
|
|
|
|
|
for (int i=0;i<8;i++) {
|
|
|
- reg[i/2][2*(i%2)] = (((qs[i] & mask0)) + m) * d;
|
|
|
+ reg[i/2][2*(i%2)] = (((qs[i] & mask0) ) + m) * d;
|
|
|
reg[i/2][2*(i%2)+1] = (((qs[i] & mask1) >> 8) + m) * d;
|
|
|
}
|
|
|
}
|
|
|
@@ -1640,11 +1707,21 @@ void dequantize_q4_1(device const block_q4_1 *xb, short il, thread type4x4 & reg
|
|
|
const ushort mask1 = il ? 0xF000 : 0x0F00;
|
|
|
|
|
|
for (int i=0;i<8;i++) {
|
|
|
- reg[i/2][2*(i%2)] = (((qs[i] & mask0)) * d) + m;
|
|
|
+ reg[i/2][2*(i%2)] = (((qs[i] & mask0) ) * d) + m;
|
|
|
reg[i/2][2*(i%2)+1] = (((qs[i] & mask1) >> 8) * d) + m;
|
|
|
}
|
|
|
}
|
|
|
|
|
|
+template <typename type4x4>
|
|
|
+void dequantize_q8_0(device const block_q8_0 *xb, short il, thread type4x4 & reg) {
|
|
|
+ device const int8_t * qs = ((device const int8_t *)xb->qs);
|
|
|
+ const half d = xb->d;
|
|
|
+
|
|
|
+ for (int i=0;i<16;i++) {
|
|
|
+ reg[i/4][i%4] = (qs[i + 16*il] * d);
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
template <typename type4x4>
|
|
|
void dequantize_q2_K(device const block_q2_K *xb, short il, thread type4x4 & reg) {
|
|
|
const half d = xb->d;
|
|
|
@@ -1947,9 +2024,10 @@ kernel void kernel_mul_mm(device const uchar * src0,
|
|
|
typedef void (get_rows_t)(device const void *, device const int *, device float *, constant int64_t &, \
|
|
|
constant uint64_t &, constant uint64_t &, uint, uint, uint);
|
|
|
|
|
|
-template [[host_name("kernel_get_rows_f16")]] kernel get_rows_t kernel_get_rows<half4x4, 1, dequantize_f16>;
|
|
|
+template [[host_name("kernel_get_rows_f16")]] kernel get_rows_t kernel_get_rows<half4x4, 1, dequantize_f16>;
|
|
|
template [[host_name("kernel_get_rows_q4_0")]] kernel get_rows_t kernel_get_rows<block_q4_0, 2, dequantize_q4_0>;
|
|
|
template [[host_name("kernel_get_rows_q4_1")]] kernel get_rows_t kernel_get_rows<block_q4_1, 2, dequantize_q4_1>;
|
|
|
+template [[host_name("kernel_get_rows_q8_0")]] kernel get_rows_t kernel_get_rows<block_q8_0, 2, dequantize_q8_0>;
|
|
|
template [[host_name("kernel_get_rows_q2_K")]] kernel get_rows_t kernel_get_rows<block_q2_K, QK_NL, dequantize_q2_K>;
|
|
|
template [[host_name("kernel_get_rows_q3_K")]] kernel get_rows_t kernel_get_rows<block_q3_K, QK_NL, dequantize_q3_K>;
|
|
|
template [[host_name("kernel_get_rows_q4_K")]] kernel get_rows_t kernel_get_rows<block_q4_K, QK_NL, dequantize_q4_K>;
|
|
|
@@ -1960,9 +2038,10 @@ typedef void (mat_mm_t)(device const uchar *, device const float *, device float
|
|
|
constant int64_t &, constant int64_t &, constant int64_t &, constant int64_t &, \
|
|
|
constant int64_t &, constant int64_t &, constant uint &, threadgroup uchar *, uint3, uint, uint);
|
|
|
|
|
|
-template [[host_name("kernel_mul_mm_f16_f32")]] kernel mat_mm_t kernel_mul_mm<half4x4, 1, dequantize_f16>;
|
|
|
+template [[host_name("kernel_mul_mm_f16_f32")]] kernel mat_mm_t kernel_mul_mm<half4x4, 1, dequantize_f16>;
|
|
|
template [[host_name("kernel_mul_mm_q4_0_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_0, 2, dequantize_q4_0>;
|
|
|
template [[host_name("kernel_mul_mm_q4_1_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_1, 2, dequantize_q4_1>;
|
|
|
+template [[host_name("kernel_mul_mm_q8_0_f32")]] kernel mat_mm_t kernel_mul_mm<block_q8_0, 2, dequantize_q8_0>;
|
|
|
template [[host_name("kernel_mul_mm_q2_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q2_K, QK_NL, dequantize_q2_K>;
|
|
|
template [[host_name("kernel_mul_mm_q3_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q3_K, QK_NL, dequantize_q3_K>;
|
|
|
template [[host_name("kernel_mul_mm_q4_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_K, QK_NL, dequantize_q4_K>;
|