|
|
@@ -1219,9 +1219,10 @@ kernel void kernel_mul_mv_q8_0_f32(
|
|
|
kernel_mul_mv_q8_0_f32_impl(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,nullptr,tgpig,tiisg,sgitg);
|
|
|
}
|
|
|
|
|
|
-#define N_F32_F32 4
|
|
|
+#define N_MV_T_T 4
|
|
|
|
|
|
-void kernel_mul_mv_f32_f32_impl(
|
|
|
+template<typename T0, typename T04, typename T1, typename T14>
|
|
|
+void kernel_mul_mv_impl(
|
|
|
device const char * src0,
|
|
|
device const char * src1,
|
|
|
device float * dst,
|
|
|
@@ -1239,13 +1240,12 @@ void kernel_mul_mv_f32_f32_impl(
|
|
|
uint64_t nb12,
|
|
|
int64_t ne0,
|
|
|
int64_t ne1,
|
|
|
- uint r2,
|
|
|
- uint r3,
|
|
|
- uint3 tgpig,
|
|
|
- uint tiisg) {
|
|
|
-
|
|
|
+ uint r2,
|
|
|
+ uint r3,
|
|
|
+ uint3 tgpig,
|
|
|
+ uint tiisg) {
|
|
|
const int64_t r0 = tgpig.x;
|
|
|
- const int64_t rb = tgpig.y*N_F32_F32;
|
|
|
+ const int64_t rb = tgpig.y*N_MV_T_T;
|
|
|
const int64_t im = tgpig.z;
|
|
|
|
|
|
const uint i12 = im%ne12;
|
|
|
@@ -1253,20 +1253,20 @@ void kernel_mul_mv_f32_f32_impl(
|
|
|
|
|
|
const uint offset0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb02*ne02;
|
|
|
|
|
|
- device const float * x = (device const float *) (src0 + offset0);
|
|
|
+ device const T0 * x = (device const T0 *) (src0 + offset0);
|
|
|
|
|
|
if (ne00 < 128) {
|
|
|
- for (int row = 0; row < N_F32_F32; ++row) {
|
|
|
+ for (int row = 0; row < N_MV_T_T; ++row) {
|
|
|
int r1 = rb + row;
|
|
|
if (r1 >= ne11) {
|
|
|
break;
|
|
|
}
|
|
|
|
|
|
- device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12);
|
|
|
+ device const T1 * y = (device const T1 *) (src1 + r1*nb11 + im*nb12);
|
|
|
|
|
|
float sumf = 0;
|
|
|
for (int i = tiisg; i < ne00; i += 32) {
|
|
|
- sumf += (float) x[i] * (float) y[i];
|
|
|
+ sumf += (T0) x[i] * (T1) y[i];
|
|
|
}
|
|
|
|
|
|
float all_sum = simd_sum(sumf);
|
|
|
@@ -1275,32 +1275,32 @@ void kernel_mul_mv_f32_f32_impl(
|
|
|
}
|
|
|
}
|
|
|
} else {
|
|
|
- device const float4 * x4 = (device const float4 *)x;
|
|
|
- for (int row = 0; row < N_F32_F32; ++row) {
|
|
|
+ device const T04 * x4 = (device const T04 *) x;
|
|
|
+ for (int row = 0; row < N_MV_T_T; ++row) {
|
|
|
int r1 = rb + row;
|
|
|
if (r1 >= ne11) {
|
|
|
break;
|
|
|
}
|
|
|
|
|
|
- device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12);
|
|
|
- device const float4 * y4 = (device const float4 *) y;
|
|
|
+ device const T1 * y = (device const T1 *) (src1 + r1*nb11 + im*nb12);
|
|
|
+ device const T14 * y4 = (device const T14 *) y;
|
|
|
|
|
|
float sumf = 0;
|
|
|
for (int i = tiisg; i < ne00/4; i += 32) {
|
|
|
- for (int k = 0; k < 4; ++k) sumf += (float) x4[i][k] * y4[i][k];
|
|
|
+ for (int k = 0; k < 4; ++k) sumf += (float) (x4[i][k] * y4[i][k]);
|
|
|
}
|
|
|
|
|
|
float all_sum = simd_sum(sumf);
|
|
|
if (tiisg == 0) {
|
|
|
- for (int i = 4*(ne00/4); i < ne00; ++i) all_sum += (float) x[i] * y[i];
|
|
|
+ for (int i = 4*(ne00/4); i < ne00; ++i) all_sum += (float) (x[i] * y[i]);
|
|
|
dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
|
|
|
}
|
|
|
}
|
|
|
}
|
|
|
}
|
|
|
|
|
|
-[[host_name("kernel_mul_mv_f32_f32")]]
|
|
|
-kernel void kernel_mul_mv_f32_f32(
|
|
|
+template<typename T0, typename T04, typename T1, typename T14>
|
|
|
+kernel void kernel_mul_mv(
|
|
|
device const char * src0,
|
|
|
device const char * src1,
|
|
|
device float * dst,
|
|
|
@@ -1322,90 +1322,38 @@ kernel void kernel_mul_mv_f32_f32(
|
|
|
constant uint & r3,
|
|
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
|
uint tiisg[[thread_index_in_simdgroup]]) {
|
|
|
- kernel_mul_mv_f32_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb00, nb01, nb02, ne10, ne11, ne12, nb10, nb11, nb12, ne0, ne1, r2, r3, tgpig, tiisg);
|
|
|
+ kernel_mul_mv_impl<T0, T04, T1, T14>(
|
|
|
+ src0,
|
|
|
+ src1,
|
|
|
+ dst,
|
|
|
+ ne00,
|
|
|
+ ne01,
|
|
|
+ ne02,
|
|
|
+ nb00,
|
|
|
+ nb01,
|
|
|
+ nb02,
|
|
|
+ ne10,
|
|
|
+ ne11,
|
|
|
+ ne12,
|
|
|
+ nb10,
|
|
|
+ nb11,
|
|
|
+ nb12,
|
|
|
+ ne0,
|
|
|
+ ne1,
|
|
|
+ r2,
|
|
|
+ r3,
|
|
|
+ tgpig,
|
|
|
+ tiisg);
|
|
|
}
|
|
|
|
|
|
-#define N_F16_F16 4
|
|
|
+typedef decltype(kernel_mul_mv<half, half4, half, half4>) mul_mv_t;
|
|
|
|
|
|
-kernel void kernel_mul_mv_f16_f16(
|
|
|
- device const char * src0,
|
|
|
- device const char * src1,
|
|
|
- device float * dst,
|
|
|
- constant int64_t & ne00,
|
|
|
- constant int64_t & ne01,
|
|
|
- constant int64_t & ne02,
|
|
|
- constant uint64_t & nb00,
|
|
|
- constant uint64_t & nb01,
|
|
|
- constant uint64_t & nb02,
|
|
|
- constant int64_t & ne10,
|
|
|
- constant int64_t & ne11,
|
|
|
- constant int64_t & ne12,
|
|
|
- constant uint64_t & nb10,
|
|
|
- constant uint64_t & nb11,
|
|
|
- constant uint64_t & nb12,
|
|
|
- constant int64_t & ne0,
|
|
|
- constant int64_t & ne1,
|
|
|
- constant uint & r2,
|
|
|
- constant uint & r3,
|
|
|
- uint3 tgpig[[threadgroup_position_in_grid]],
|
|
|
- uint tiisg[[thread_index_in_simdgroup]]) {
|
|
|
+template [[host_name("kernel_mul_mv_f32_f32")]] kernel mul_mv_t kernel_mul_mv<float, float4, float, float4>;
|
|
|
+template [[host_name("kernel_mul_mv_f16_f32")]] kernel mul_mv_t kernel_mul_mv<half, half4, float, float4>;
|
|
|
+template [[host_name("kernel_mul_mv_f16_f16")]] kernel mul_mv_t kernel_mul_mv<half, half4, half, half4>;
|
|
|
|
|
|
- const int64_t r0 = tgpig.x;
|
|
|
- const int64_t rb = tgpig.y*N_F16_F16;
|
|
|
- const int64_t im = tgpig.z;
|
|
|
-
|
|
|
- const uint i12 = im%ne12;
|
|
|
- const uint i13 = im/ne12;
|
|
|
-
|
|
|
- const uint offset0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb02*ne02;
|
|
|
-
|
|
|
- device const half * x = (device const half *) (src0 + offset0);
|
|
|
-
|
|
|
- if (ne00 < 128) {
|
|
|
- for (int row = 0; row < N_F16_F16; ++row) {
|
|
|
- int r1 = rb + row;
|
|
|
- if (r1 >= ne11) {
|
|
|
- break;
|
|
|
- }
|
|
|
-
|
|
|
- device const half * y = (device const half *) (src1 + r1*nb11 + im*nb12);
|
|
|
-
|
|
|
- float sumf = 0;
|
|
|
- for (int i = tiisg; i < ne00; i += 32) {
|
|
|
- sumf += (half) x[i] * (half) y[i];
|
|
|
- }
|
|
|
-
|
|
|
- float all_sum = simd_sum(sumf);
|
|
|
- if (tiisg == 0) {
|
|
|
- dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
|
|
|
- }
|
|
|
- }
|
|
|
- } else {
|
|
|
- device const half4 * x4 = (device const half4 *)x;
|
|
|
- for (int row = 0; row < N_F16_F16; ++row) {
|
|
|
- int r1 = rb + row;
|
|
|
- if (r1 >= ne11) {
|
|
|
- break;
|
|
|
- }
|
|
|
-
|
|
|
- device const half * y = (device const half *) (src1 + r1*nb11 + im*nb12);
|
|
|
- device const half4 * y4 = (device const half4 *) y;
|
|
|
-
|
|
|
- float sumf = 0;
|
|
|
- for (int i = tiisg; i < ne00/4; i += 32) {
|
|
|
- for (int k = 0; k < 4; ++k) sumf += (half) x4[i][k] * y4[i][k];
|
|
|
- }
|
|
|
-
|
|
|
- float all_sum = simd_sum(sumf);
|
|
|
- if (tiisg == 0) {
|
|
|
- for (int i = 4*(ne00/4); i < ne00; ++i) all_sum += (half) x[i] * y[i];
|
|
|
- dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
|
|
|
- }
|
|
|
- }
|
|
|
- }
|
|
|
-}
|
|
|
-
|
|
|
-void kernel_mul_mv_f16_f32_1row_impl(
|
|
|
+template<typename T, typename T4>
|
|
|
+kernel void kernel_mul_mv_1row(
|
|
|
device const char * src0,
|
|
|
device const char * src1,
|
|
|
device float * dst,
|
|
|
@@ -1437,7 +1385,7 @@ void kernel_mul_mv_f16_f32_1row_impl(
|
|
|
|
|
|
const uint offset0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb02*ne02;
|
|
|
|
|
|
- device const half * x = (device const half *) (src0 + offset0);
|
|
|
+ device const T * x = (device const T *) (src0 + offset0);
|
|
|
device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12);
|
|
|
|
|
|
float sumf = 0;
|
|
|
@@ -1450,153 +1398,29 @@ void kernel_mul_mv_f16_f32_1row_impl(
|
|
|
dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
|
|
|
}
|
|
|
} else {
|
|
|
- device const half4 * x4 = (device const half4 *) x;
|
|
|
+ device const T4 * x4 = (device const T4 *) x;
|
|
|
device const float4 * y4 = (device const float4 *) y;
|
|
|
+
|
|
|
for (int i = tiisg; i < ne00/4; i += 32) {
|
|
|
- for (int k = 0; k < 4; ++k) sumf += (float)x4[i][k] * y4[i][k];
|
|
|
+ for (int k = 0; k < 4; ++k) sumf += (float) (x4[i][k] * y4[i][k]);
|
|
|
}
|
|
|
+
|
|
|
float all_sum = simd_sum(sumf);
|
|
|
+
|
|
|
if (tiisg == 0) {
|
|
|
- for (int i = 4*(ne00/4); i < ne00; ++i) all_sum += (float) x[i] * y[i];
|
|
|
+ for (int i = 4*(ne00/4); i < ne00; ++i) all_sum += (float) (x[i] * y[i]);
|
|
|
dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
|
|
|
}
|
|
|
}
|
|
|
}
|
|
|
|
|
|
-[[host_name("kernel_mul_mv_f16_f32_1row")]]
|
|
|
-kernel void kernel_mul_mv_f16_f32_1row(
|
|
|
- device const char * src0,
|
|
|
- device const char * src1,
|
|
|
- device float * dst,
|
|
|
- constant int64_t & ne00,
|
|
|
- constant int64_t & ne01,
|
|
|
- constant int64_t & ne02,
|
|
|
- constant uint64_t & nb00,
|
|
|
- constant uint64_t & nb01,
|
|
|
- constant uint64_t & nb02,
|
|
|
- constant int64_t & ne10,
|
|
|
- constant int64_t & ne11,
|
|
|
- constant int64_t & ne12,
|
|
|
- constant uint64_t & nb10,
|
|
|
- constant uint64_t & nb11,
|
|
|
- constant uint64_t & nb12,
|
|
|
- constant int64_t & ne0,
|
|
|
- constant int64_t & ne1,
|
|
|
- constant uint & r2,
|
|
|
- constant uint & r3,
|
|
|
- uint3 tgpig[[threadgroup_position_in_grid]],
|
|
|
- uint tiisg[[thread_index_in_simdgroup]]) {
|
|
|
- kernel_mul_mv_f16_f32_1row_impl(src0, src1, dst, ne00, ne01, ne02, nb00, nb01, nb02, ne10, ne11, ne12, nb10, nb11, nb12, ne0, ne1, r2, r3, tgpig, tiisg);
|
|
|
-}
|
|
|
-
|
|
|
-#define N_F16_F32 4
|
|
|
-
|
|
|
-void kernel_mul_mv_f16_f32_impl(
|
|
|
- device const char * src0,
|
|
|
- device const char * src1,
|
|
|
- device float * dst,
|
|
|
- int64_t ne00,
|
|
|
- int64_t ne01,
|
|
|
- int64_t ne02,
|
|
|
- uint64_t nb00,
|
|
|
- uint64_t nb01,
|
|
|
- uint64_t nb02,
|
|
|
- int64_t ne10,
|
|
|
- int64_t ne11,
|
|
|
- int64_t ne12,
|
|
|
- uint64_t nb10,
|
|
|
- uint64_t nb11,
|
|
|
- uint64_t nb12,
|
|
|
- int64_t ne0,
|
|
|
- int64_t ne1,
|
|
|
- uint r2,
|
|
|
- uint r3,
|
|
|
- uint3 tgpig,
|
|
|
- uint tiisg) {
|
|
|
-
|
|
|
- const int64_t r0 = tgpig.x;
|
|
|
- const int64_t rb = tgpig.y*N_F16_F32;
|
|
|
- const int64_t im = tgpig.z;
|
|
|
-
|
|
|
- const uint i12 = im%ne12;
|
|
|
- const uint i13 = im/ne12;
|
|
|
-
|
|
|
- const uint offset0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb02*ne02;
|
|
|
-
|
|
|
- device const half * x = (device const half *) (src0 + offset0);
|
|
|
-
|
|
|
- if (ne00 < 128) {
|
|
|
- for (int row = 0; row < N_F16_F32; ++row) {
|
|
|
- int r1 = rb + row;
|
|
|
- if (r1 >= ne11) {
|
|
|
- break;
|
|
|
- }
|
|
|
-
|
|
|
- device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12);
|
|
|
-
|
|
|
- float sumf = 0;
|
|
|
- for (int i = tiisg; i < ne00; i += 32) {
|
|
|
- sumf += (float) x[i] * (float) y[i];
|
|
|
- }
|
|
|
-
|
|
|
- float all_sum = simd_sum(sumf);
|
|
|
- if (tiisg == 0) {
|
|
|
- dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
|
|
|
- }
|
|
|
- }
|
|
|
- } else {
|
|
|
- device const half4 * x4 = (device const half4 *)x;
|
|
|
- for (int row = 0; row < N_F16_F32; ++row) {
|
|
|
- int r1 = rb + row;
|
|
|
- if (r1 >= ne11) {
|
|
|
- break;
|
|
|
- }
|
|
|
-
|
|
|
- device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12);
|
|
|
- device const float4 * y4 = (device const float4 *) y;
|
|
|
-
|
|
|
- float sumf = 0;
|
|
|
- for (int i = tiisg; i < ne00/4; i += 32) {
|
|
|
- for (int k = 0; k < 4; ++k) sumf += (float) x4[i][k] * y4[i][k];
|
|
|
- }
|
|
|
+typedef decltype(kernel_mul_mv_1row<half, half4>) mul_mv_1row_t;
|
|
|
|
|
|
- float all_sum = simd_sum(sumf);
|
|
|
- if (tiisg == 0) {
|
|
|
- for (int i = 4*(ne00/4); i < ne00; ++i) all_sum += (float) x[i] * y[i];
|
|
|
- dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
|
|
|
- }
|
|
|
- }
|
|
|
- }
|
|
|
-}
|
|
|
-
|
|
|
-[[host_name("kernel_mul_mv_f16_f32")]]
|
|
|
-kernel void kernel_mul_mv_f16_f32(
|
|
|
- device const char * src0,
|
|
|
- device const char * src1,
|
|
|
- device float * dst,
|
|
|
- constant int64_t & ne00,
|
|
|
- constant int64_t & ne01,
|
|
|
- constant int64_t & ne02,
|
|
|
- constant uint64_t & nb00,
|
|
|
- constant uint64_t & nb01,
|
|
|
- constant uint64_t & nb02,
|
|
|
- constant int64_t & ne10,
|
|
|
- constant int64_t & ne11,
|
|
|
- constant int64_t & ne12,
|
|
|
- constant uint64_t & nb10,
|
|
|
- constant uint64_t & nb11,
|
|
|
- constant uint64_t & nb12,
|
|
|
- constant int64_t & ne0,
|
|
|
- constant int64_t & ne1,
|
|
|
- constant uint & r2,
|
|
|
- constant uint & r3,
|
|
|
- uint3 tgpig[[threadgroup_position_in_grid]],
|
|
|
- uint tiisg[[thread_index_in_simdgroup]]) {
|
|
|
- kernel_mul_mv_f16_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb00, nb01, nb02, ne10, ne11, ne12, nb10, nb11, nb12, ne0, ne1, r2, r3, tgpig, tiisg);
|
|
|
-}
|
|
|
+template [[host_name("kernel_mul_mv_f16_f32_1row")]] kernel mul_mv_1row_t kernel_mul_mv_1row<half, half4>;
|
|
|
|
|
|
// Assumes row size (ne00) is a multiple of 4
|
|
|
-kernel void kernel_mul_mv_f16_f32_l4(
|
|
|
+template<typename T, typename T4>
|
|
|
+kernel void kernel_mul_mv_l4(
|
|
|
device const char * src0,
|
|
|
device const char * src1,
|
|
|
device float * dst,
|
|
|
@@ -1628,14 +1452,14 @@ kernel void kernel_mul_mv_f16_f32_l4(
|
|
|
|
|
|
const uint offset0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb02*ne02;
|
|
|
|
|
|
- device const half4 * x4 = (device const half4 *) (src0 + offset0);
|
|
|
+ device const T4 * x4 = (device const T4 *) (src0 + offset0);
|
|
|
|
|
|
for (int r1 = 0; r1 < nrows; ++r1) {
|
|
|
device const float4 * y4 = (device const float4 *) (src1 + r1*nb11 + im*nb12);
|
|
|
|
|
|
float sumf = 0;
|
|
|
for (int i = tiisg; i < ne00/4; i += 32) {
|
|
|
- for (int k = 0; k < 4; ++k) sumf += (float) x4[i][k] * y4[i][k];
|
|
|
+ for (int k = 0; k < 4; ++k) sumf += (float) (x4[i][k] * y4[i][k]);
|
|
|
}
|
|
|
|
|
|
float all_sum = simd_sum(sumf);
|
|
|
@@ -1645,6 +1469,10 @@ kernel void kernel_mul_mv_f16_f32_l4(
|
|
|
}
|
|
|
}
|
|
|
|
|
|
+typedef decltype(kernel_mul_mv_l4<half, half4>) mul_mv_l4_t;
|
|
|
+
|
|
|
+template [[host_name("kernel_mul_mv_f16_f32_l4")]] kernel mul_mv_l4_t kernel_mul_mv_l4<half, half4>;
|
|
|
+
|
|
|
static float rope_yarn_ramp(const float low, const float high, const int i0) {
|
|
|
const float y = (i0 / 2 - low) / max(0.001f, high - low);
|
|
|
return 1.0f - min(1.0f, max(0.0f, y));
|
|
|
@@ -2765,91 +2593,10 @@ kernel void kernel_flash_attn_ext_vec_f16(
|
|
|
template [[host_name("kernel_flash_attn_ext_vec_f16_h128")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_vec_f16<128>;
|
|
|
//template [[host_name("kernel_flash_attn_ext_vec_f16_h256")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_vec_f16<256>;
|
|
|
|
|
|
-kernel void kernel_cpy_f16_f16(
|
|
|
- device const half * src0,
|
|
|
- device half * dst,
|
|
|
- constant int64_t & ne00,
|
|
|
- constant int64_t & ne01,
|
|
|
- constant int64_t & ne02,
|
|
|
- constant int64_t & ne03,
|
|
|
- constant uint64_t & nb00,
|
|
|
- constant uint64_t & nb01,
|
|
|
- constant uint64_t & nb02,
|
|
|
- constant uint64_t & nb03,
|
|
|
- constant int64_t & ne0,
|
|
|
- constant int64_t & ne1,
|
|
|
- constant int64_t & ne2,
|
|
|
- constant int64_t & ne3,
|
|
|
- constant uint64_t & nb0,
|
|
|
- constant uint64_t & nb1,
|
|
|
- constant uint64_t & nb2,
|
|
|
- constant uint64_t & nb3,
|
|
|
- uint3 tgpig[[threadgroup_position_in_grid]],
|
|
|
- uint3 tpitg[[thread_position_in_threadgroup]],
|
|
|
- uint3 ntg[[threads_per_threadgroup]]) {
|
|
|
- const int64_t i03 = tgpig[2];
|
|
|
- const int64_t i02 = tgpig[1];
|
|
|
- const int64_t i01 = tgpig[0];
|
|
|
-
|
|
|
- const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
|
|
|
-
|
|
|
- const int64_t i3 = n / (ne2*ne1*ne0);
|
|
|
- const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
|
|
|
- const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
|
|
|
- const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0);
|
|
|
-
|
|
|
- device half * dst_data = (device half *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
|
|
|
-
|
|
|
- for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) {
|
|
|
- device const half * src = (device half *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
|
|
|
- dst_data[i00] = src[0];
|
|
|
- }
|
|
|
-}
|
|
|
-
|
|
|
-kernel void kernel_cpy_f16_f32(
|
|
|
- device const half * src0,
|
|
|
- device float * dst,
|
|
|
- constant int64_t & ne00,
|
|
|
- constant int64_t & ne01,
|
|
|
- constant int64_t & ne02,
|
|
|
- constant int64_t & ne03,
|
|
|
- constant uint64_t & nb00,
|
|
|
- constant uint64_t & nb01,
|
|
|
- constant uint64_t & nb02,
|
|
|
- constant uint64_t & nb03,
|
|
|
- constant int64_t & ne0,
|
|
|
- constant int64_t & ne1,
|
|
|
- constant int64_t & ne2,
|
|
|
- constant int64_t & ne3,
|
|
|
- constant uint64_t & nb0,
|
|
|
- constant uint64_t & nb1,
|
|
|
- constant uint64_t & nb2,
|
|
|
- constant uint64_t & nb3,
|
|
|
- uint3 tgpig[[threadgroup_position_in_grid]],
|
|
|
- uint3 tpitg[[thread_position_in_threadgroup]],
|
|
|
- uint3 ntg[[threads_per_threadgroup]]) {
|
|
|
- const int64_t i03 = tgpig[2];
|
|
|
- const int64_t i02 = tgpig[1];
|
|
|
- const int64_t i01 = tgpig[0];
|
|
|
-
|
|
|
- const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
|
|
|
-
|
|
|
- const int64_t i3 = n / (ne2*ne1*ne0);
|
|
|
- const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
|
|
|
- const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
|
|
|
- const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0);
|
|
|
-
|
|
|
- device float * dst_data = (device float *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
|
|
|
-
|
|
|
- for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) {
|
|
|
- device const half * src = (device half *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
|
|
|
- dst_data[i00] = src[0];
|
|
|
- }
|
|
|
-}
|
|
|
-
|
|
|
-kernel void kernel_cpy_f32_f16(
|
|
|
- device const float * src0,
|
|
|
- device half * dst,
|
|
|
+template<typename T0, typename T1>
|
|
|
+kernel void kernel_cpy(
|
|
|
+ device const void * src0,
|
|
|
+ device void * dst,
|
|
|
constant int64_t & ne00,
|
|
|
constant int64_t & ne01,
|
|
|
constant int64_t & ne02,
|
|
|
@@ -2880,56 +2627,20 @@ kernel void kernel_cpy_f32_f16(
|
|
|
const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
|
|
|
const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0);
|
|
|
|
|
|
- device half * dst_data = (device half *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
|
|
|
+ device T1 * dst_data = (device T1 *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
|
|
|
|
|
|
for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) {
|
|
|
- device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
|
|
|
-
|
|
|
- dst_data[i00] = src[0];
|
|
|
+ device const T0 * src = (device T0 *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
|
|
|
+ dst_data[i00] = (T1) src[0];
|
|
|
}
|
|
|
}
|
|
|
|
|
|
-kernel void kernel_cpy_f32_f32(
|
|
|
- device const float * src0,
|
|
|
- device float * dst,
|
|
|
- constant int64_t & ne00,
|
|
|
- constant int64_t & ne01,
|
|
|
- constant int64_t & ne02,
|
|
|
- constant int64_t & ne03,
|
|
|
- constant uint64_t & nb00,
|
|
|
- constant uint64_t & nb01,
|
|
|
- constant uint64_t & nb02,
|
|
|
- constant uint64_t & nb03,
|
|
|
- constant int64_t & ne0,
|
|
|
- constant int64_t & ne1,
|
|
|
- constant int64_t & ne2,
|
|
|
- constant int64_t & ne3,
|
|
|
- constant uint64_t & nb0,
|
|
|
- constant uint64_t & nb1,
|
|
|
- constant uint64_t & nb2,
|
|
|
- constant uint64_t & nb3,
|
|
|
- uint3 tgpig[[threadgroup_position_in_grid]],
|
|
|
- uint3 tpitg[[thread_position_in_threadgroup]],
|
|
|
- uint3 ntg[[threads_per_threadgroup]]) {
|
|
|
- const int64_t i03 = tgpig[2];
|
|
|
- const int64_t i02 = tgpig[1];
|
|
|
- const int64_t i01 = tgpig[0];
|
|
|
-
|
|
|
- const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
|
|
|
+typedef decltype(kernel_cpy<float, float>) kernel_cpy_t;
|
|
|
|
|
|
- const int64_t i3 = n / (ne2*ne1*ne0);
|
|
|
- const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
|
|
|
- const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
|
|
|
- const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0);
|
|
|
-
|
|
|
- device float * dst_data = (device float *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
|
|
|
-
|
|
|
- for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) {
|
|
|
- device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
|
|
|
-
|
|
|
- dst_data[i00] = src[0];
|
|
|
- }
|
|
|
-}
|
|
|
+template [[host_name("kernel_cpy_f32_f32")]] kernel kernel_cpy_t kernel_cpy<float, float>;
|
|
|
+template [[host_name("kernel_cpy_f32_f16")]] kernel kernel_cpy_t kernel_cpy<float, half>;
|
|
|
+template [[host_name("kernel_cpy_f16_f16")]] kernel kernel_cpy_t kernel_cpy<half, half>;
|
|
|
+template [[host_name("kernel_cpy_f16_f32")]] kernel kernel_cpy_t kernel_cpy<half, float>;
|
|
|
|
|
|
kernel void kernel_cpy_f32_q8_0(
|
|
|
device const float * src0,
|
|
|
@@ -5730,9 +5441,9 @@ void dequantize_iq4_xs(device const block_iq4_xs * xb, short il, thread type4x4
|
|
|
}
|
|
|
|
|
|
template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread float4x4 &)>
|
|
|
-kernel void kernel_get_rows(
|
|
|
+kernel void kernel_get_rows_q(
|
|
|
device const void * src0,
|
|
|
- device const char * src1,
|
|
|
+ device const void * src1,
|
|
|
device float * dst,
|
|
|
constant int64_t & ne00,
|
|
|
constant uint64_t & nb01,
|
|
|
@@ -5745,55 +5456,24 @@ kernel void kernel_get_rows(
|
|
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
|
uint tiitg[[thread_index_in_threadgroup]],
|
|
|
uint3 tptg [[threads_per_threadgroup]]) {
|
|
|
- //const int64_t i = tgpig;
|
|
|
- //const int64_t r = ((device int32_t *) src1)[i];
|
|
|
-
|
|
|
const int64_t i10 = tgpig.x;
|
|
|
const int64_t i11 = tgpig.y;
|
|
|
|
|
|
- const int64_t r = ((device int32_t *) ((device char *) src1 + i11*nb11 + i10*nb10))[0];
|
|
|
+ const int64_t r = ((const device int32_t *) ((const device char *) src1 + i11*nb11 + i10*nb10))[0];
|
|
|
|
|
|
const int64_t i02 = i11;
|
|
|
|
|
|
for (int64_t ind = tiitg; ind < ne00/16; ind += tptg.x) {
|
|
|
float4x4 temp;
|
|
|
- dequantize_func(
|
|
|
- ((device const block_q *) ((device char *) src0 + r*nb01 + i02*nb02)) + ind/nl, ind%nl, temp);
|
|
|
+ dequantize_func(((device const block_q *) ((const device char *) src0 + r*nb01 + i02*nb02)) + ind/nl, ind%nl, temp);
|
|
|
*(((device float4x4 *) ((device char *) dst + i11*nb2 + i10*nb1)) + ind) = temp;
|
|
|
}
|
|
|
}
|
|
|
|
|
|
-kernel void kernel_get_rows_f32(
|
|
|
- device const void * src0,
|
|
|
- device const char * src1,
|
|
|
- device float * dst,
|
|
|
- constant int64_t & ne00,
|
|
|
- constant uint64_t & nb01,
|
|
|
- constant uint64_t & nb02,
|
|
|
- constant int64_t & ne10,
|
|
|
- constant uint64_t & nb10,
|
|
|
- constant uint64_t & nb11,
|
|
|
- constant uint64_t & nb1,
|
|
|
- constant uint64_t & nb2,
|
|
|
- uint3 tgpig[[threadgroup_position_in_grid]],
|
|
|
- uint tiitg[[thread_index_in_threadgroup]],
|
|
|
- uint3 tptg [[threads_per_threadgroup]]) {
|
|
|
- const int64_t i10 = tgpig.x;
|
|
|
- const int64_t i11 = tgpig.y;
|
|
|
-
|
|
|
- const int64_t r = ((device int32_t *) ((device char *) src1 + i11*nb11 + i10*nb10))[0];
|
|
|
-
|
|
|
- const int64_t i02 = i11;
|
|
|
-
|
|
|
- for (int ind = tiitg; ind < ne00; ind += tptg.x) {
|
|
|
- ((device float *) ((device char *) dst + i11*nb2 + i10*nb1))[ind] =
|
|
|
- ((device float *) ((device char *) src0 + r*nb01 + i02*nb02))[ind];
|
|
|
- }
|
|
|
-}
|
|
|
-
|
|
|
-kernel void kernel_get_rows_f16(
|
|
|
+template<typename T>
|
|
|
+kernel void kernel_get_rows_f(
|
|
|
device const void * src0,
|
|
|
- device const char * src1,
|
|
|
+ device const void * src1,
|
|
|
device float * dst,
|
|
|
constant int64_t & ne00,
|
|
|
constant uint64_t & nb01,
|
|
|
@@ -5809,19 +5489,19 @@ kernel void kernel_get_rows_f16(
|
|
|
const int64_t i10 = tgpig.x;
|
|
|
const int64_t i11 = tgpig.y;
|
|
|
|
|
|
- const int64_t r = ((device int32_t *) ((device char *) src1 + i11*nb11 + i10*nb10))[0];
|
|
|
+ const int64_t r = ((const device int32_t *) ((const device char *) src1 + i11*nb11 + i10*nb10))[0];
|
|
|
|
|
|
const int64_t i02 = i11;
|
|
|
|
|
|
for (int ind = tiitg; ind < ne00; ind += tptg.x) {
|
|
|
- ((device float *) ((device char *) dst + i11*nb2 + i10*nb1))[ind] =
|
|
|
- ((device half *) ((device char *) src0 + r*nb01 + i02*nb02))[ind];
|
|
|
+ (( device float *) (( device char *) dst + i11*nb2 + i10*nb1))[ind] =
|
|
|
+ ((const device T *) ((const device char *) src0 + i02*nb02 + r*nb01))[ind];
|
|
|
}
|
|
|
}
|
|
|
|
|
|
kernel void kernel_get_rows_i32(
|
|
|
device const void * src0,
|
|
|
- device const char * src1,
|
|
|
+ device const void * src1,
|
|
|
device int32_t * dst,
|
|
|
constant int64_t & ne00,
|
|
|
constant uint64_t & nb01,
|
|
|
@@ -5837,13 +5517,13 @@ kernel void kernel_get_rows_i32(
|
|
|
const int64_t i10 = tgpig.x;
|
|
|
const int64_t i11 = tgpig.y;
|
|
|
|
|
|
- const int64_t r = ((device int32_t *) ((device char *) src1 + i11*nb11 + i10*nb10))[0];
|
|
|
+ const int64_t r = ((const device int32_t *) ((const device char *) src1 + i11*nb11 + i10*nb10))[0];
|
|
|
|
|
|
const int64_t i02 = i11;
|
|
|
|
|
|
for (int ind = tiitg; ind < ne00; ind += tptg.x) {
|
|
|
- ((device int32_t *) ((device char *) dst + i11*nb2 + i10*nb1))[ind] =
|
|
|
- ((device int32_t *) ((device char *) src0 + r*nb01 + i02*nb02))[ind];
|
|
|
+ (( device int32_t *) (( device char *) dst + i11*nb2 + i10*nb1))[ind] =
|
|
|
+ ((const device int32_t *) ((const device char *) src0 + i02*nb02 + r*nb01))[ind];
|
|
|
}
|
|
|
}
|
|
|
|
|
|
@@ -5860,28 +5540,28 @@ kernel void kernel_get_rows_i32(
|
|
|
#define SG_MAT_ROW 8
|
|
|
|
|
|
// each block_q contains 16*nl weights
|
|
|
-template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread half4x4 &)>
|
|
|
-void kernel_mul_mm_impl(device const uchar * src0,
|
|
|
- device const uchar * src1,
|
|
|
- device float * dst,
|
|
|
- constant int64_t & ne00,
|
|
|
- constant int64_t & ne02,
|
|
|
- constant uint64_t & nb01,
|
|
|
- constant uint64_t & nb02,
|
|
|
- constant int64_t & ne12,
|
|
|
- constant uint64_t & nb10,
|
|
|
- constant uint64_t & nb11,
|
|
|
- constant uint64_t & nb12,
|
|
|
- constant int64_t & ne0,
|
|
|
- constant int64_t & ne1,
|
|
|
- constant uint & r2,
|
|
|
- constant uint & r3,
|
|
|
- threadgroup uchar * shared_memory [[threadgroup(0)]],
|
|
|
- uint3 tgpig[[threadgroup_position_in_grid]],
|
|
|
- uint tiitg[[thread_index_in_threadgroup]],
|
|
|
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
|
|
+template<typename T, typename T4x4, typename simdgroup_T8x8, typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread T4x4 &)>
|
|
|
+kernel void kernel_mul_mm(device const uchar * src0,
|
|
|
+ device const uchar * src1,
|
|
|
+ device float * dst,
|
|
|
+ constant int64_t & ne00,
|
|
|
+ constant int64_t & ne02,
|
|
|
+ constant uint64_t & nb01,
|
|
|
+ constant uint64_t & nb02,
|
|
|
+ constant int64_t & ne12,
|
|
|
+ constant uint64_t & nb10,
|
|
|
+ constant uint64_t & nb11,
|
|
|
+ constant uint64_t & nb12,
|
|
|
+ constant int64_t & ne0,
|
|
|
+ constant int64_t & ne1,
|
|
|
+ constant uint & r2,
|
|
|
+ constant uint & r3,
|
|
|
+ threadgroup uchar * shared_memory [[threadgroup(0)]],
|
|
|
+ uint3 tgpig[[threadgroup_position_in_grid]],
|
|
|
+ uint tiitg[[thread_index_in_threadgroup]],
|
|
|
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
|
|
|
|
|
- threadgroup half * sa = (threadgroup half *)(shared_memory);
|
|
|
+ threadgroup T * sa = (threadgroup T *)(shared_memory);
|
|
|
threadgroup float * sb = (threadgroup float *)(shared_memory + 4096);
|
|
|
|
|
|
const uint r0 = tgpig.y;
|
|
|
@@ -5896,7 +5576,7 @@ void kernel_mul_mm_impl(device const uchar * src0,
|
|
|
short thread_row = ((short)tiitg/THREAD_PER_ROW) < n_rows ? ((short)tiitg/THREAD_PER_ROW) : n_rows - 1;
|
|
|
short thread_col = ((short)tiitg/THREAD_PER_COL) < n_cols ? ((short)tiitg/THREAD_PER_COL) : n_cols - 1;
|
|
|
|
|
|
- simdgroup_half8x8 ma[4];
|
|
|
+ simdgroup_T8x8 ma[4];
|
|
|
simdgroup_float8x8 mb[2];
|
|
|
simdgroup_float8x8 c_res[8];
|
|
|
for (int i = 0; i < 8; i++){
|
|
|
@@ -5919,7 +5599,7 @@ void kernel_mul_mm_impl(device const uchar * src0,
|
|
|
|
|
|
for (int loop_k = 0; loop_k < ne00; loop_k += BLOCK_SIZE_K) {
|
|
|
// load data and store to threadgroup memory
|
|
|
- half4x4 temp_a;
|
|
|
+ T4x4 temp_a;
|
|
|
dequantize_func(x, il, temp_a);
|
|
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
|
|
|
|
@@ -5939,7 +5619,7 @@ void kernel_mul_mm_impl(device const uchar * src0,
|
|
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
|
|
|
|
// load matrices from threadgroup memory and conduct outer products
|
|
|
- threadgroup half * lsma = (sa + THREAD_MAT_M * SG_MAT_SIZE * (sgitg % 2));
|
|
|
+ threadgroup T * lsma = (sa + THREAD_MAT_M * SG_MAT_SIZE * (sgitg % 2));
|
|
|
threadgroup float * lsmb = (sb + THREAD_MAT_N * SG_MAT_SIZE * (sgitg / 2));
|
|
|
|
|
|
#pragma unroll(4)
|
|
|
@@ -6115,48 +5795,6 @@ void kernel_mul_mm_id_impl(
|
|
|
}
|
|
|
}
|
|
|
|
|
|
-template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread half4x4 &)>
|
|
|
-kernel void kernel_mul_mm(device const uchar * src0,
|
|
|
- device const uchar * src1,
|
|
|
- device float * dst,
|
|
|
- constant int64_t & ne00,
|
|
|
- constant int64_t & ne02,
|
|
|
- constant uint64_t & nb01,
|
|
|
- constant uint64_t & nb02,
|
|
|
- constant int64_t & ne12,
|
|
|
- constant uint64_t & nb10,
|
|
|
- constant uint64_t & nb11,
|
|
|
- constant uint64_t & nb12,
|
|
|
- constant int64_t & ne0,
|
|
|
- constant int64_t & ne1,
|
|
|
- constant uint & r2,
|
|
|
- constant uint & r3,
|
|
|
- threadgroup uchar * shared_memory [[threadgroup(0)]],
|
|
|
- uint3 tgpig[[threadgroup_position_in_grid]],
|
|
|
- uint tiitg[[thread_index_in_threadgroup]],
|
|
|
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
|
|
- kernel_mul_mm_impl<block_q, nl, dequantize_func>(
|
|
|
- src0,
|
|
|
- src1,
|
|
|
- dst,
|
|
|
- ne00,
|
|
|
- ne02,
|
|
|
- nb01,
|
|
|
- nb02,
|
|
|
- ne12,
|
|
|
- nb10,
|
|
|
- nb11,
|
|
|
- nb12,
|
|
|
- ne0,
|
|
|
- ne1,
|
|
|
- r2,
|
|
|
- r3,
|
|
|
- shared_memory,
|
|
|
- tgpig,
|
|
|
- tiitg,
|
|
|
- sgitg);
|
|
|
-}
|
|
|
-
|
|
|
template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread half4x4 &)>
|
|
|
kernel void kernel_mul_mm_id(
|
|
|
device const uchar * src0s,
|
|
|
@@ -6237,69 +5875,60 @@ kernel void kernel_mul_mm_id(
|
|
|
// get rows
|
|
|
//
|
|
|
|
|
|
-typedef void (get_rows_t)(
|
|
|
- device const void * src0,
|
|
|
- device const char * src1,
|
|
|
- device float * dst,
|
|
|
- constant int64_t & ne00,
|
|
|
- constant uint64_t & nb01,
|
|
|
- constant uint64_t & nb02,
|
|
|
- constant int64_t & ne10,
|
|
|
- constant uint64_t & nb10,
|
|
|
- constant uint64_t & nb11,
|
|
|
- constant uint64_t & nb1,
|
|
|
- constant uint64_t & nb2,
|
|
|
- uint3, uint, uint3);
|
|
|
-
|
|
|
-//template [[host_name("kernel_get_rows_f32")]] kernel get_rows_t kernel_get_rows<float4x4, 1, dequantize_f32>;
|
|
|
-//template [[host_name("kernel_get_rows_f16")]] kernel get_rows_t kernel_get_rows<half4x4, 1, dequantize_f16>;
|
|
|
-template [[host_name("kernel_get_rows_q4_0")]] kernel get_rows_t kernel_get_rows<block_q4_0, 2, dequantize_q4_0>;
|
|
|
-template [[host_name("kernel_get_rows_q4_1")]] kernel get_rows_t kernel_get_rows<block_q4_1, 2, dequantize_q4_1>;
|
|
|
-template [[host_name("kernel_get_rows_q5_0")]] kernel get_rows_t kernel_get_rows<block_q5_0, 2, dequantize_q5_0>;
|
|
|
-template [[host_name("kernel_get_rows_q5_1")]] kernel get_rows_t kernel_get_rows<block_q5_1, 2, dequantize_q5_1>;
|
|
|
-template [[host_name("kernel_get_rows_q8_0")]] kernel get_rows_t kernel_get_rows<block_q8_0, 2, dequantize_q8_0>;
|
|
|
-template [[host_name("kernel_get_rows_q2_K")]] kernel get_rows_t kernel_get_rows<block_q2_K, QK_NL, dequantize_q2_K>;
|
|
|
-template [[host_name("kernel_get_rows_q3_K")]] kernel get_rows_t kernel_get_rows<block_q3_K, QK_NL, dequantize_q3_K>;
|
|
|
-template [[host_name("kernel_get_rows_q4_K")]] kernel get_rows_t kernel_get_rows<block_q4_K, QK_NL, dequantize_q4_K>;
|
|
|
-template [[host_name("kernel_get_rows_q5_K")]] kernel get_rows_t kernel_get_rows<block_q5_K, QK_NL, dequantize_q5_K>;
|
|
|
-template [[host_name("kernel_get_rows_q6_K")]] kernel get_rows_t kernel_get_rows<block_q6_K, QK_NL, dequantize_q6_K>;
|
|
|
-template [[host_name("kernel_get_rows_iq2_xxs")]] kernel get_rows_t kernel_get_rows<block_iq2_xxs, QK_NL, dequantize_iq2_xxs>;
|
|
|
-template [[host_name("kernel_get_rows_iq2_xs")]] kernel get_rows_t kernel_get_rows<block_iq2_xs, QK_NL, dequantize_iq2_xs>;
|
|
|
-template [[host_name("kernel_get_rows_iq3_xxs")]] kernel get_rows_t kernel_get_rows<block_iq3_xxs, QK_NL, dequantize_iq3_xxs>;
|
|
|
-template [[host_name("kernel_get_rows_iq3_s")]] kernel get_rows_t kernel_get_rows<block_iq3_s, QK_NL, dequantize_iq3_s>;
|
|
|
-template [[host_name("kernel_get_rows_iq2_s")]] kernel get_rows_t kernel_get_rows<block_iq2_s, QK_NL, dequantize_iq2_s>;
|
|
|
-template [[host_name("kernel_get_rows_iq1_s")]] kernel get_rows_t kernel_get_rows<block_iq1_s, QK_NL, dequantize_iq1_s>;
|
|
|
-template [[host_name("kernel_get_rows_iq1_m")]] kernel get_rows_t kernel_get_rows<block_iq1_m, QK_NL, dequantize_iq1_m>;
|
|
|
-template [[host_name("kernel_get_rows_iq4_nl")]] kernel get_rows_t kernel_get_rows<block_iq4_nl, 2, dequantize_iq4_nl>;
|
|
|
-template [[host_name("kernel_get_rows_iq4_xs")]] kernel get_rows_t kernel_get_rows<block_iq4_xs, QK_NL, dequantize_iq4_xs>;
|
|
|
+typedef decltype(kernel_get_rows_f<float>) get_rows_f_t;
|
|
|
+
|
|
|
+template [[host_name("kernel_get_rows_f32")]] kernel get_rows_f_t kernel_get_rows_f<float>;
|
|
|
+template [[host_name("kernel_get_rows_f16")]] kernel get_rows_f_t kernel_get_rows_f<half>;
|
|
|
+
|
|
|
+typedef decltype(kernel_get_rows_q<block_q4_0, 2, dequantize_q4_0>) get_rows_q_t;
|
|
|
+
|
|
|
+template [[host_name("kernel_get_rows_q4_0")]] kernel get_rows_q_t kernel_get_rows_q<block_q4_0, 2, dequantize_q4_0>;
|
|
|
+template [[host_name("kernel_get_rows_q4_1")]] kernel get_rows_q_t kernel_get_rows_q<block_q4_1, 2, dequantize_q4_1>;
|
|
|
+template [[host_name("kernel_get_rows_q5_0")]] kernel get_rows_q_t kernel_get_rows_q<block_q5_0, 2, dequantize_q5_0>;
|
|
|
+template [[host_name("kernel_get_rows_q5_1")]] kernel get_rows_q_t kernel_get_rows_q<block_q5_1, 2, dequantize_q5_1>;
|
|
|
+template [[host_name("kernel_get_rows_q8_0")]] kernel get_rows_q_t kernel_get_rows_q<block_q8_0, 2, dequantize_q8_0>;
|
|
|
+template [[host_name("kernel_get_rows_q2_K")]] kernel get_rows_q_t kernel_get_rows_q<block_q2_K, QK_NL, dequantize_q2_K>;
|
|
|
+template [[host_name("kernel_get_rows_q3_K")]] kernel get_rows_q_t kernel_get_rows_q<block_q3_K, QK_NL, dequantize_q3_K>;
|
|
|
+template [[host_name("kernel_get_rows_q4_K")]] kernel get_rows_q_t kernel_get_rows_q<block_q4_K, QK_NL, dequantize_q4_K>;
|
|
|
+template [[host_name("kernel_get_rows_q5_K")]] kernel get_rows_q_t kernel_get_rows_q<block_q5_K, QK_NL, dequantize_q5_K>;
|
|
|
+template [[host_name("kernel_get_rows_q6_K")]] kernel get_rows_q_t kernel_get_rows_q<block_q6_K, QK_NL, dequantize_q6_K>;
|
|
|
+template [[host_name("kernel_get_rows_iq2_xxs")]] kernel get_rows_q_t kernel_get_rows_q<block_iq2_xxs, QK_NL, dequantize_iq2_xxs>;
|
|
|
+template [[host_name("kernel_get_rows_iq2_xs")]] kernel get_rows_q_t kernel_get_rows_q<block_iq2_xs, QK_NL, dequantize_iq2_xs>;
|
|
|
+template [[host_name("kernel_get_rows_iq3_xxs")]] kernel get_rows_q_t kernel_get_rows_q<block_iq3_xxs, QK_NL, dequantize_iq3_xxs>;
|
|
|
+template [[host_name("kernel_get_rows_iq3_s")]] kernel get_rows_q_t kernel_get_rows_q<block_iq3_s, QK_NL, dequantize_iq3_s>;
|
|
|
+template [[host_name("kernel_get_rows_iq2_s")]] kernel get_rows_q_t kernel_get_rows_q<block_iq2_s, QK_NL, dequantize_iq2_s>;
|
|
|
+template [[host_name("kernel_get_rows_iq1_s")]] kernel get_rows_q_t kernel_get_rows_q<block_iq1_s, QK_NL, dequantize_iq1_s>;
|
|
|
+template [[host_name("kernel_get_rows_iq1_m")]] kernel get_rows_q_t kernel_get_rows_q<block_iq1_m, QK_NL, dequantize_iq1_m>;
|
|
|
+template [[host_name("kernel_get_rows_iq4_nl")]] kernel get_rows_q_t kernel_get_rows_q<block_iq4_nl, 2, dequantize_iq4_nl>;
|
|
|
+template [[host_name("kernel_get_rows_iq4_xs")]] kernel get_rows_q_t kernel_get_rows_q<block_iq4_xs, QK_NL, dequantize_iq4_xs>;
|
|
|
|
|
|
//
|
|
|
// matrix-matrix multiplication
|
|
|
//
|
|
|
|
|
|
-typedef decltype(kernel_mul_mm<float4x4, 1, dequantize_f32>) mat_mm_t;
|
|
|
-
|
|
|
-template [[host_name("kernel_mul_mm_f32_f32")]] kernel mat_mm_t kernel_mul_mm<float4x4, 1, dequantize_f32>;
|
|
|
-template [[host_name("kernel_mul_mm_f16_f32")]] kernel mat_mm_t kernel_mul_mm<half4x4, 1, dequantize_f16>;
|
|
|
-template [[host_name("kernel_mul_mm_q4_0_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_0, 2, dequantize_q4_0>;
|
|
|
-template [[host_name("kernel_mul_mm_q4_1_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_1, 2, dequantize_q4_1>;
|
|
|
-template [[host_name("kernel_mul_mm_q5_0_f32")]] kernel mat_mm_t kernel_mul_mm<block_q5_0, 2, dequantize_q5_0>;
|
|
|
-template [[host_name("kernel_mul_mm_q5_1_f32")]] kernel mat_mm_t kernel_mul_mm<block_q5_1, 2, dequantize_q5_1>;
|
|
|
-template [[host_name("kernel_mul_mm_q8_0_f32")]] kernel mat_mm_t kernel_mul_mm<block_q8_0, 2, dequantize_q8_0>;
|
|
|
-template [[host_name("kernel_mul_mm_q2_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q2_K, QK_NL, dequantize_q2_K>;
|
|
|
-template [[host_name("kernel_mul_mm_q3_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q3_K, QK_NL, dequantize_q3_K>;
|
|
|
-template [[host_name("kernel_mul_mm_q4_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_K, QK_NL, dequantize_q4_K>;
|
|
|
-template [[host_name("kernel_mul_mm_q5_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q5_K, QK_NL, dequantize_q5_K>;
|
|
|
-template [[host_name("kernel_mul_mm_q6_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q6_K, QK_NL, dequantize_q6_K>;
|
|
|
-template [[host_name("kernel_mul_mm_iq2_xxs_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq2_xxs, QK_NL, dequantize_iq2_xxs>;
|
|
|
-template [[host_name("kernel_mul_mm_iq2_xs_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq2_xs, QK_NL, dequantize_iq2_xs>;
|
|
|
-template [[host_name("kernel_mul_mm_iq3_xxs_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq3_xxs, QK_NL, dequantize_iq3_xxs>;
|
|
|
-template [[host_name("kernel_mul_mm_iq3_s_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq3_s, QK_NL, dequantize_iq3_s>;
|
|
|
-template [[host_name("kernel_mul_mm_iq2_s_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq2_s, QK_NL, dequantize_iq2_s>;
|
|
|
-template [[host_name("kernel_mul_mm_iq1_s_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq1_s, QK_NL, dequantize_iq1_s>;
|
|
|
-template [[host_name("kernel_mul_mm_iq1_m_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq1_m, QK_NL, dequantize_iq1_m>;
|
|
|
-template [[host_name("kernel_mul_mm_iq4_nl_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq4_nl, 2, dequantize_iq4_nl>;
|
|
|
-template [[host_name("kernel_mul_mm_iq4_xs_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq4_xs, QK_NL, dequantize_iq4_xs>;
|
|
|
+typedef decltype(kernel_mul_mm<half, half4x4, simdgroup_half8x8, float4x4, 1, dequantize_f32>) mat_mm_t;
|
|
|
+
|
|
|
+template [[host_name("kernel_mul_mm_f32_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, float4x4, 1, dequantize_f32>;
|
|
|
+template [[host_name("kernel_mul_mm_f16_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half4x4, 1, dequantize_f16>;
|
|
|
+template [[host_name("kernel_mul_mm_q4_0_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q4_0, 2, dequantize_q4_0>;
|
|
|
+template [[host_name("kernel_mul_mm_q4_1_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q4_1, 2, dequantize_q4_1>;
|
|
|
+template [[host_name("kernel_mul_mm_q5_0_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q5_0, 2, dequantize_q5_0>;
|
|
|
+template [[host_name("kernel_mul_mm_q5_1_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q5_1, 2, dequantize_q5_1>;
|
|
|
+template [[host_name("kernel_mul_mm_q8_0_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q8_0, 2, dequantize_q8_0>;
|
|
|
+template [[host_name("kernel_mul_mm_q2_K_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q2_K, QK_NL, dequantize_q2_K>;
|
|
|
+template [[host_name("kernel_mul_mm_q3_K_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q3_K, QK_NL, dequantize_q3_K>;
|
|
|
+template [[host_name("kernel_mul_mm_q4_K_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q4_K, QK_NL, dequantize_q4_K>;
|
|
|
+template [[host_name("kernel_mul_mm_q5_K_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q5_K, QK_NL, dequantize_q5_K>;
|
|
|
+template [[host_name("kernel_mul_mm_q6_K_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q6_K, QK_NL, dequantize_q6_K>;
|
|
|
+template [[host_name("kernel_mul_mm_iq2_xxs_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq2_xxs, QK_NL, dequantize_iq2_xxs>;
|
|
|
+template [[host_name("kernel_mul_mm_iq2_xs_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq2_xs, QK_NL, dequantize_iq2_xs>;
|
|
|
+template [[host_name("kernel_mul_mm_iq3_xxs_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq3_xxs, QK_NL, dequantize_iq3_xxs>;
|
|
|
+template [[host_name("kernel_mul_mm_iq3_s_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq3_s, QK_NL, dequantize_iq3_s>;
|
|
|
+template [[host_name("kernel_mul_mm_iq2_s_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq2_s, QK_NL, dequantize_iq2_s>;
|
|
|
+template [[host_name("kernel_mul_mm_iq1_s_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq1_s, QK_NL, dequantize_iq1_s>;
|
|
|
+template [[host_name("kernel_mul_mm_iq1_m_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq1_m, QK_NL, dequantize_iq1_m>;
|
|
|
+template [[host_name("kernel_mul_mm_iq4_nl_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq4_nl, 2, dequantize_iq4_nl>;
|
|
|
+template [[host_name("kernel_mul_mm_iq4_xs_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq4_xs, QK_NL, dequantize_iq4_xs>;
|
|
|
|
|
|
//
|
|
|
// indirect matrix-matrix multiplication
|
|
|
@@ -6436,7 +6065,7 @@ void mmv_fn(
|
|
|
impl_fn(src0,(const device float *)src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,shared_values,tgpig,tiisg,sgitg);
|
|
|
}
|
|
|
|
|
|
-typedef decltype(mmv_fn<kernel_mul_mv_f32_f32_impl>) mul_mv_impl_fn_t;
|
|
|
+typedef decltype(mmv_fn<kernel_mul_mv_impl<half, half4, half, half4>>) mul_mv_impl_fn_t;
|
|
|
|
|
|
template<mul_mv_impl_fn_t impl_fn>
|
|
|
kernel void kernel_mul_mv_id(
|
|
|
@@ -6514,20 +6143,20 @@ kernel void kernel_mul_mv_id(
|
|
|
sgitg);
|
|
|
}
|
|
|
|
|
|
-typedef decltype(kernel_mul_mv_id<mmv_fn<kernel_mul_mv_f32_f32_impl>>) kernel_mul_mv_id_t;
|
|
|
-
|
|
|
-template [[host_name("kernel_mul_mv_id_f32_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_f32_f32_impl>>;
|
|
|
-template [[host_name("kernel_mul_mv_id_f16_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_f16_f32_impl>>;
|
|
|
-template [[host_name("kernel_mul_mv_id_q8_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q8_0_f32_impl>>;
|
|
|
-template [[host_name("kernel_mul_mv_id_q4_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q4_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>>>;
|
|
|
-template [[host_name("kernel_mul_mv_id_q4_1_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>>>;
|
|
|
-template [[host_name("kernel_mul_mv_id_q5_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q5_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>>>;
|
|
|
-template [[host_name("kernel_mul_mv_id_q5_1_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q5_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>>>;
|
|
|
-template [[host_name("kernel_mul_mv_id_q2_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q2_K_f32_impl>>;
|
|
|
-template [[host_name("kernel_mul_mv_id_q3_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q3_K_f32_impl>>;
|
|
|
-template [[host_name("kernel_mul_mv_id_q4_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q4_K_f32_impl>>;
|
|
|
-template [[host_name("kernel_mul_mv_id_q5_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q5_K_f32_impl>>;
|
|
|
-template [[host_name("kernel_mul_mv_id_q6_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q6_K_f32_impl>>;
|
|
|
+typedef decltype(kernel_mul_mv_id<mmv_fn<kernel_mul_mv_impl<float, float4, float, float4>>>) kernel_mul_mv_id_t;
|
|
|
+
|
|
|
+template [[host_name("kernel_mul_mv_id_f32_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_impl<float, float4, float, float4>>>;
|
|
|
+template [[host_name("kernel_mul_mv_id_f16_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_impl<half, half4, float, float4>>>;
|
|
|
+template [[host_name("kernel_mul_mv_id_q8_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q8_0_f32_impl>>;
|
|
|
+template [[host_name("kernel_mul_mv_id_q4_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q4_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>>>;
|
|
|
+template [[host_name("kernel_mul_mv_id_q4_1_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>>>;
|
|
|
+template [[host_name("kernel_mul_mv_id_q5_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q5_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>>>;
|
|
|
+template [[host_name("kernel_mul_mv_id_q5_1_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q5_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>>>;
|
|
|
+template [[host_name("kernel_mul_mv_id_q2_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q2_K_f32_impl>>;
|
|
|
+template [[host_name("kernel_mul_mv_id_q3_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q3_K_f32_impl>>;
|
|
|
+template [[host_name("kernel_mul_mv_id_q4_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q4_K_f32_impl>>;
|
|
|
+template [[host_name("kernel_mul_mv_id_q5_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q5_K_f32_impl>>;
|
|
|
+template [[host_name("kernel_mul_mv_id_q6_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q6_K_f32_impl>>;
|
|
|
template [[host_name("kernel_mul_mv_id_iq1_s_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq1_s_f32_impl>>;
|
|
|
template [[host_name("kernel_mul_mv_id_iq1_m_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq1_m_f32_impl>>;
|
|
|
template [[host_name("kernel_mul_mv_id_iq2_xxs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq2_xxs_f32_impl>>;
|