@@ -974,9 +974,16 @@ kernel void kernel_mul(
     device const char * src1_ptr = src1 + i13*args.nb13 + i12*args.nb12 + i11*args.nb11 + args.o1[0];
     device char * dst_ptr = dst + i03*args.nb3 + i02*args.nb2 + i01*args.nb1 + args.offs;

-    for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
-        const int i10 = i0%args.ne10;
-        *((device float *)(dst_ptr + i0*args.nb0)) = *((device float *)(src0_ptr + i0*args.nb00)) * *((device float *)(src1_ptr + i10*args.nb10));
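+    // ne10 == 1: src1 supplies a single value for the whole row, so load it
+    // once instead of re-deriving i10 = i0 % ne10 for every element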
+    if (args.ne10 == 1) {
+        const float x = *((device float *)(src1_ptr));
+        for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
+            *((device float *)(dst_ptr + i0*args.nb0)) = *((device float *)(src0_ptr + i0*args.nb00)) * x;
+        }
+    } else {
+        for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
+            const int i10 = i0%args.ne10;
+            *((device float *)(dst_ptr + i0*args.nb0)) = *((device float *)(src0_ptr + i0*args.nb00)) * *((device float *)(src1_ptr + i10*args.nb10));
+        }
     }
 }

@@ -1000,9 +1007,16 @@ kernel void kernel_div(
     device const char * src1_ptr = src1 + i13*args.nb13 + i12*args.nb12 + i11*args.nb11 + args.o1[0];
     device char * dst_ptr = dst + i03*args.nb3 + i02*args.nb2 + i01*args.nb1 + args.offs;

-    for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
-        const int i10 = i0%args.ne10;
-        *((device float *)(dst_ptr + i0*args.nb0)) = *((device float *)(src0_ptr + i0*args.nb00)) / *((device float *)(src1_ptr + i10*args.nb10));
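+    // same broadcast fast path as kernel_mul: take the reciprocal once and
+    // replace the per-element division with a multiply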
+    if (args.ne10 == 1) {
+        const float x = 1.0f / *((device float *)(src1_ptr));
+        for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
+            *((device float *)(dst_ptr + i0*args.nb0)) = *((device float *)(src0_ptr + i0*args.nb00)) * x;
+        }
+    } else {
+        for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
+            const int i10 = i0%args.ne10;
+            *((device float *)(dst_ptr + i0*args.nb0)) = *((device float *)(src0_ptr + i0*args.nb00)) / *((device float *)(src1_ptr + i10*args.nb10));
+        }
     }
 }

@@ -7491,97 +7505,81 @@ kernel void kernel_mul_mm(
     }
 }

-template<typename T4>
+template<short ne20> // n_expert_used
 kernel void kernel_mul_mm_id_map0(
         constant ggml_metal_kargs_mul_mm_id_map0 & args,
-        device const char * src1,
         device const char * src2,
-        device char * hsrc1,
         device char * htpe,
         device char * hids,
-        uint3 tgpig[[threadgroup_position_in_grid]],
-        ushort3 tpitg[[thread_position_in_threadgroup]],
-        ushort3 ntg[[threads_per_threadgroup]]) {
-    const int ide = tgpig[0]; // expert id
-
-    int n_all = 0;
+        threadgroup char * shmem [[threadgroup(0)]],
+        ushort tpitg[[thread_position_in_threadgroup]],
+        ushort ntg[[threads_per_threadgroup]]) {
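+    // each thread of the threadgroup handles one expert and builds the list of
+    // (token, slot) ids routed to that expert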
+    const short ide = tpitg; // expert id

-    device int32_t * ids_i32 = (device int32_t *) (hids);
+    uint32_t n_all = 0;

-    for (int i21 = 0; i21 < args.neh11; i21++) { // n_tokens
-        device const int32_t * src2_i32 = (device const int32_t *) (src2 + i21*args.nb21);
+    device int32_t * ids_i32 = (device int32_t *) hids + ide*args.ne21;

-        for (int i20 = 0; i20 < args.ne20; i20++) { // n_expert_used
-            if (src2_i32[i20] != ide) {
-                continue;
-            }
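+    // process tokens in blocks of ntg: each thread stages the ids of one token,
+    // then every expert-thread scans the whole staged block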
+    for (int i21 = 0; i21 < args.ne21; i21 += ntg) { // n_tokens
+        if (i21 + tpitg < args.ne21) {
+            device const int32_t * src2_i32 = (device const int32_t *) (src2 + (i21 + tpitg)*args.nb21);

-            device const float4 * src1_f32x4 = (device const float4 *) ( src1 + i21*args.nb12 + (i20%args.ne11)*args.nb11);
-            device T4 * hsrc1_f32x4 = (device T4 *) (hsrc1 + (ide*args.neh11 + n_all)*args.nbh11);
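+            // stage the ne20 expert ids of this thread's token in threadgroup
+            // memory so that all expert-threads can scan them below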
+            threadgroup uint16_t * sids = (threadgroup uint16_t *) shmem + tpitg*ne20;

-            for (int64_t i00 = tpitg.x; i00 < args.ne10/4; i00 += ntg.x) {
-                hsrc1_f32x4[i00] = (T4) (src1_f32x4[i00]);
+            #pragma unroll(ne20)
+            for (short i20 = 0; i20 < ne20; i20++) {
+                sids[i20] = src2_i32[i20];
             }
-
-            if (tpitg.x == 0) {
-                ids_i32[i21*args.ne20 + i20] = ide*args.neh11 + n_all;
-            }
-
-            ++n_all;
         }
-    }

-    if (tpitg.x == 0) {
-        device int32_t * tpe_i32 = (device int32_t *) (htpe);
-        tpe_i32[ide] = n_all;
-    }
-}
-
-typedef decltype(kernel_mul_mm_id_map0<half4>) kernel_mul_mm_id_map0_t;
-
-template [[host_name("kernel_mul_mm_id_map0_f16")]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<half4>;
+        threadgroup_barrier(mem_flags::mem_threadgroup);

-template<typename T>
-kernel void kernel_mul_mm_id_map1(
-        constant ggml_metal_kargs_mul_mm_id_map1 & args,
-        device const char * hdst,
-        device const char * hids,
-        device char * dst,
-        uint3 tgpig[[threadgroup_position_in_grid]],
-        ushort3 tpitg[[thread_position_in_threadgroup]],
-        ushort3 ntg[[threads_per_threadgroup]]) {
-    const int i20 = tgpig[0]; // used expert
-    const int i21 = tgpig[1]; // token
+        for (short t = 0; t < ntg; t++) {
+            if (i21 + t >= args.ne21) {
+                break;
+            }

-    device const int32_t * ids_i32 = (device const int32_t *) (hids);
-    device float4 * dst_f32x4 = (device float4 *) (dst + i20*args.nb1 + i21*args.nb2);
+            threadgroup const uint16_t * sids = (threadgroup const uint16_t *) shmem + t*ne20;

-    const int id = ids_i32[i21*args.ne20 + i20];
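+            // branchless scan (assumes each expert id appears at most once per
+            // token): sel = i20 + 1 for the matching slot, or 0 if no match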
+            short sel = 0;
+            #pragma unroll(ne20)
+            for (short i20 = 0; i20 < ne20; i20++) {
+                sel += (sids[i20] == ide)*(i20 + 1);
+            }

-    const int ide = id / args.neh1;
-    const int idt = id % args.neh1;
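+            // pack (token, slot) as token*ne20 + slot; a non-matching token
+            // writes a stale value here, but n_all does not advance, so the
+            // entry is overwritten by the next match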
+            ids_i32[n_all] = (i21 + t)*ne20 + sel - 1;

-    device const float4 * hdst_f32x4 = (device const float4 *) (hdst + idt*args.nbh1 + ide*args.nbh2);
+            n_all += sel > 0;
+        }

-    for (int64_t i0 = tpitg.x; i0 < args.neh0/4; i0 += ntg.x) {
-        dst_f32x4[i0] = hdst_f32x4[i0];
+        threadgroup_barrier(mem_flags::mem_threadgroup);
     }
+
+    device uint32_t * tpe_u32 = (device uint32_t *) (htpe);
+    tpe_u32[ide] = n_all;
 }

-typedef decltype(kernel_mul_mm_id_map1<float>) kernel_mul_mm_id_map1_t;
+typedef decltype(kernel_mul_mm_id_map0<1>) kernel_mul_mm_id_map0_t;

-template [[host_name("kernel_mul_mm_id_map1_f32")]] kernel kernel_mul_mm_id_map1_t kernel_mul_mm_id_map1<float>;
+template [[host_name("kernel_mul_mm_id_map0_f16_ne20_1" )]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<1>;
+template [[host_name("kernel_mul_mm_id_map0_f16_ne20_2" )]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<2>;
+template [[host_name("kernel_mul_mm_id_map0_f16_ne20_4" )]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<4>;
+template [[host_name("kernel_mul_mm_id_map0_f16_ne20_6" )]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<6>;
+template [[host_name("kernel_mul_mm_id_map0_f16_ne20_8" )]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<8>;
+template [[host_name("kernel_mul_mm_id_map0_f16_ne20_16")]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<16>;

 template<typename T, typename T4x4, typename simdgroup_T8x8, typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread T4x4 &)>
 kernel void kernel_mul_mm_id(
         constant ggml_metal_kargs_mul_mm_id & args,
         device const char * src0,
         device const char * src1,
-        device const char * tpe,
+        device const char * htpe,
+        device const char * hids,
         device char * dst,
         threadgroup char * shmem [[threadgroup(0)]],
         uint3 tgpig[[threadgroup_position_in_grid]],
         ushort tiitg[[thread_index_in_threadgroup]],
+        ushort tiisg[[thread_index_in_simdgroup]],
         ushort sgitg[[simdgroup_index_in_threadgroup]]) {

     threadgroup T * sa = (threadgroup T *)(shmem);
@@ -7589,19 +7587,20 @@ kernel void kernel_mul_mm_id(

     const int r0 = tgpig.y;
     const int r1 = tgpig.x;
-    const int im = tgpig.z;
+    const int im = tgpig.z; // expert

-    device const int32_t * tpe_i32 = (device const int32_t *) (tpe);
+    device const uint32_t * tpe_u32 = (device const uint32_t *) (htpe);
+    device const int32_t * ids_i32 = (device const int32_t *) (hids);

-    const int neh1 = tpe_i32[im];
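+    // tpe holds the number of rows (token/slot pairs) routed to each expert;
+    // threadgroups beyond this expert's row count exit immediately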
+    const int32_t neh1 = tpe_u32[im];

     if (r1*BLOCK_SIZE_N >= neh1) {
         return;
     }

     // if this block is of 64x32 shape or smaller
-    const short n_rows = (args.neh0 - r0*BLOCK_SIZE_M < BLOCK_SIZE_M) ? (args.neh0 - r0*BLOCK_SIZE_M) : BLOCK_SIZE_M;
-    const short n_cols = ( neh1 - r1*BLOCK_SIZE_N < BLOCK_SIZE_N) ? ( neh1 - r1*BLOCK_SIZE_N) : BLOCK_SIZE_N;
+    const short n_rows = (args.ne0 - r0*BLOCK_SIZE_M < BLOCK_SIZE_M) ? (args.ne0 - r0*BLOCK_SIZE_M) : BLOCK_SIZE_M;
+    const short n_cols = ( neh1 - r1*BLOCK_SIZE_N < BLOCK_SIZE_N) ? ( neh1 - r1*BLOCK_SIZE_N) : BLOCK_SIZE_N;

     // a thread shouldn't load data outside of the matrix
     const short thread_row = ((short)tiitg/THREAD_PER_ROW) < n_rows ? ((short)tiitg/THREAD_PER_ROW) : n_rows - 1;
@@ -7617,20 +7616,23 @@ kernel void kernel_mul_mm_id(

     short il = (tiitg % THREAD_PER_ROW);

-    const int i12 = im%args.neh12;
-    const int i13 = im/args.neh12;
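+    // map this output column back to its (token, slot) pair via the id list
+    // built by map0: id % ne20 is the slot within the token, id / ne20 the token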
+    const int id = ids_i32[im*args.ne21 + r1*BLOCK_SIZE_N + thread_col];

-    const uint64_t offset0 = (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
+    const short i11 = (id % args.ne20) % args.ne11;
+    const short i12 = (id / args.ne20);
+    const short i13 = 0;
+
+    const uint64_t offset0 = im*args.nb02 + i13*args.nb03;
     const short offset1 = il/nl;

     device const block_q * x = (device const block_q *)(src0
         + args.nb01*(r0*BLOCK_SIZE_M + thread_row) + offset0) + offset1;

-    device const half * y = (device const half *)(src1
-        + args.nbh13*i13
-        + args.nbh12*i12
-        + args.nbh11*(r1*BLOCK_SIZE_N + thread_col)
-        + args.nbh10*(BLOCK_SIZE_K / THREAD_PER_COL * (tiitg % THREAD_PER_COL)));
+    device const float * y = (device const float *)(src1
+        + args.nb13*i13
+        + args.nb12*i12
+        + args.nb11*i11
+        + args.nb10*(BLOCK_SIZE_K / THREAD_PER_COL * (tiitg % THREAD_PER_COL)));

     for (int loop_k = 0; loop_k < args.ne00; loop_k += BLOCK_SIZE_K) {
         // load data and store to threadgroup memory
@@ -7646,7 +7648,7 @@ kernel void kernel_mul_mm_id(
             + (tiitg/THREAD_PER_ROW)%8 + (i&7)*8) = temp_a[i/4][i%4];
         }

-        *(threadgroup half2x4 *)(sb + 32*8*(tiitg%THREAD_PER_COL) + 8*(tiitg/THREAD_PER_COL)) = *((device half2x4 *) y);
+        *(threadgroup half2x4 *)(sb + 32*8*(tiitg%THREAD_PER_COL) + 8*(tiitg/THREAD_PER_COL)) = (half2x4)(*((device float2x4 *) y));

         il = (il + 2 < nl) ? il + 2 : il % 2;
         x = (il < 2) ? x + (2 + nl - 1)/nl : x;
@@ -7682,43 +7684,38 @@ kernel void kernel_mul_mm_id(
         }
     }

-    if ((r0 + 1) * BLOCK_SIZE_M <= args.neh0 && (r1 + 1) * BLOCK_SIZE_N <= neh1) {
-        device float * C = (device float *) dst +
-            (BLOCK_SIZE_M * r0 + 32*(sgitg & 1)) + \
-            (BLOCK_SIZE_N * r1 + 16*(sgitg >> 1)) * args.neh0 + im*args.neh1*args.neh0;
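+    // the output tile can no longer be stored with a direct simdgroup_store:
+    // each column belongs to a different (token, slot) row of dst, so always
+    // stage the result in threadgroup memory first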
+    threadgroup_barrier(mem_flags::mem_threadgroup);

-        for (short i = 0; i < 8; i++) {
-            simdgroup_store(mc[i], C + 8 * (i%4) + 8 * args.neh0 * (i/4), args.neh0);
-        }
-    } else {
-        // block is smaller than 64x32, we should avoid writing data outside of the matrix
-        threadgroup_barrier(mem_flags::mem_threadgroup);
-        threadgroup float * temp_str = ((threadgroup float *) shmem) \
-            + 32*(sgitg&1) + (16*(sgitg >> 1))*BLOCK_SIZE_M;
-        for (short i = 0; i < 8; i++) {
-            simdgroup_store(mc[i], temp_str + 8*(i%4) + 8*BLOCK_SIZE_M*(i/4), BLOCK_SIZE_M);
-        }
+    threadgroup float * temp_str = ((threadgroup float *) shmem) \
+        + 32*(sgitg&1) + (16*(sgitg >> 1))*BLOCK_SIZE_M;

-        threadgroup_barrier(mem_flags::mem_threadgroup);
+    #pragma unroll(8)
+    for (short i = 0; i < 8; i++) {
+        simdgroup_store(mc[i], temp_str + 8*(i%4) + 8*BLOCK_SIZE_M*(i/4), BLOCK_SIZE_M);
+    }

-        if (sgitg == 0) {
-            for (int j = tiitg; j < n_cols; j += BLOCK_SIZE_N) {
-                device float * D = (device float *) dst + (r0*BLOCK_SIZE_M) + (r1*BLOCK_SIZE_N + j)*args.neh0 + im*args.neh1*args.neh0;
-                device float4 * D4 = (device float4 *) D;
+    threadgroup_barrier(mem_flags::mem_threadgroup);

-                threadgroup float * C = temp_str + (j*BLOCK_SIZE_M);
-                threadgroup float4 * C4 = (threadgroup float4 *) C;
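+    // all 4 simdgroups now help write the tile out: each takes every 4th
+    // column, looks up its (token, slot) id, and the 32 lanes copy the rows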
+    for (short j = sgitg; j < n_cols; j += 4) {
+        const int id = ids_i32[im*args.ne21 + r1*BLOCK_SIZE_N + j];

-                int i = 0;
-                for (; i < n_rows/4; i++) {
-                    *(D4 + i) = *(C4 + i);
-                }
+        const short ide = id % args.ne20;
+        const short idt = id / args.ne20;

-                i *= 4;
-                for (; i < n_rows; i++) {
-                    *(D + i) = *(C + i);
-                }
-            }
+        device float * D = (device float *) dst + (r0*BLOCK_SIZE_M) + ide*args.ne0 + idt*args.ne1*args.ne0;
+        device float4 * D4 = (device float4 *) D;
+
+        threadgroup float * C = (threadgroup float *) shmem + (j*BLOCK_SIZE_M);
+        threadgroup float4 * C4 = (threadgroup float4 *) C;
+
+        int i = tiisg;
+        for (; i < n_rows/4; i += 32) {
+            *(D4 + i) = *(C4 + i);
+        }
+
+        i = (4*(n_rows/4)) + tiisg;
+        for (; i < n_rows; i += 32) {
+            *(D + i) = *(C + i);
+        }
     }
 }