@@ -291,6 +291,10 @@ enum ggml_metal_kernel_type {
     GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_1_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MV_Q8_0_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MV_MXFP4_F32,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F32_F32_R1_2,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F32_F32_R1_3,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F32_F32_R1_4,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F32_F32_R1_5,
     GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_2,
     GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_3,
     GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_4,
@@ -575,6 +579,7 @@ enum ggml_metal_kernel_type {
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_HK576_HV512,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_HK576_HV512,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_HK576_HV512,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_REDUCE,
     GGML_METAL_KERNEL_TYPE_SET_I32,
     GGML_METAL_KERNEL_TYPE_SET_F32,
     GGML_METAL_KERNEL_TYPE_CPY_F32_F32,
@@ -1324,6 +1329,10 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_1_F32, mul_mv_q5_1_f32, has_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q8_0_F32, mul_mv_q8_0_f32, has_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_MXFP4_F32, mul_mv_mxfp4_f32, has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F32_F32_R1_2, mul_mv_ext_f32_f32_r1_2, has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F32_F32_R1_3, mul_mv_ext_f32_f32_r1_3, has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F32_F32_R1_4, mul_mv_ext_f32_f32_r1_4, has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F32_F32_R1_5, mul_mv_ext_f32_f32_r1_5, has_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_2, mul_mv_ext_f16_f32_r1_2, has_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_3, mul_mv_ext_f16_f32_r1_3, has_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_4, mul_mv_ext_f16_f32_r1_4, has_simdgroup_reduction);
@@ -1609,6 +1618,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_HK576_HV512, flash_attn_ext_vec_q5_0_hk576_hv512, has_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_HK576_HV512, flash_attn_ext_vec_q5_1_hk576_hv512, has_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_HK576_HV512, flash_attn_ext_vec_q8_0_hk576_hv512, has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_REDUCE, flash_attn_ext_reduce, has_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_F32, set_f32, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_I32, set_i32, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F32, cpy_f32_f32, true);
@@ -3385,15 +3395,16 @@ static int ggml_metal_encode_node(

             // find the break-even point where the matrix-matrix kernel becomes more efficient compared
             // to the matrix-vector kernel
-            const int ne11_mm_min = 4;
+            const int ne11_mm_min = 8;

             // first try to use small-batch mat-mv kernels
             // these should be efficient for BS [2, ~8]
-            if (src1t == GGML_TYPE_F32 && (ne00%256 == 0) &&
+            if (src1t == GGML_TYPE_F32 && (ne00%128 == 0) &&
                 (
                  (
                   (
-                   src0t == GGML_TYPE_F16 || // TODO: helper function
+                   src0t == GGML_TYPE_F32 || // TODO: helper function
+                   src0t == GGML_TYPE_F16 ||
                    src0t == GGML_TYPE_Q4_0 ||
                    src0t == GGML_TYPE_Q4_1 ||
                    src0t == GGML_TYPE_Q5_0 ||
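
The two threshold changes above widen the fast path: src0 rows now only need to be a multiple of 128 (previously 256), F32 weights become eligible alongside F16 and the quantized types, and the matrix-matrix kernel takes over at batch size 8 instead of 4. A minimal sketch of the new gate, assuming the batch-size window and permutation checks that live outside this hunk; prefer_small_batch_mat_mv is a hypothetical helper, not code from the patch:

    #include <stdbool.h>

    // hedged sketch: only the conditions visible in this hunk
    static bool prefer_small_batch_mat_mv(int ne00, int ne11, bool src0t_supported) {
        const int ne11_mm_min = 8;   // was 4: the mat-mat break-even moved up
        return src0t_supported       // F32 now joins F16 and the quantized types
            && (ne00 % 128 == 0)     // was ne00 % 256 == 0
            && (ne11 < ne11_mm_min); // small batches stay on the mat-mv path
    }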
@@ -3421,7 +3432,17 @@ static int ggml_metal_encode_node(
             // values and there can be some tail effects when nsg is high. need to confirm this
             //
             const int nsg = 2; // num simdgroups per threadgroup
-            const int nxpsg = ne11 < 3 ? 16 : 8; // num threads along row per simdgroup
+
+            // num threads along row per simdgroup
+            int nxpsg = 0;
+            if (ne00 % 256 == 0 && ne11 < 3) {
+                nxpsg = 16;
+            } else if (ne00 % 128 == 0) {
+                nxpsg = 8;
+            } else {
+                nxpsg = 4;
+            }
+
             const int nypsg = 32/nxpsg;  // num threads along col per simdgroup (i.e. a simdgroup processes that many src0 rows at a time)
             const int r0ptg = nypsg*nsg; // num src0 rows per threadgroup
                   int r1ptg = 4;         // num src1 rows per threadgroup
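
With nsg = 2 simdgroups of 32 threads, nxpsg threads cooperate along a row, nypsg = 32/nxpsg rows are handled per simdgroup, and r0ptg = nypsg*nsg src0 rows per threadgroup. A worked run of the new selection (shapes illustrative, not from the patch):

    // ne00 = 4096, ne11 = 2: 4096 % 256 == 0 and ne11 < 3 -> nxpsg = 16, nypsg = 2, r0ptg = 4
    // ne00 = 4096, ne11 = 4: ne11 >= 3, 4096 % 128 == 0   -> nxpsg =  8, nypsg = 4, r0ptg = 8
    // ne00 = 2688, ne11 = 2: 2688 % 256 != 0, % 128 == 0  -> nxpsg =  8, nypsg = 4, r0ptg = 8
    // the nxpsg = 4 fallback (nypsg = 8, r0ptg = 16) would only trigger for rows not
    // divisible by 128, which the ne00 % 128 gate earlier already excludes; it reads as defensive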
@@ -3444,6 +3465,14 @@ static int ggml_metal_encode_node(
             id<MTLComputePipelineState> pipeline = nil;

             switch (src0->type) {
+                case GGML_TYPE_F32:
+                    switch (r1ptg) {
+                        case 2: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F32_F32_R1_2].pipeline; break;
+                        case 3: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F32_F32_R1_3].pipeline; break;
+                        case 4: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F32_F32_R1_4].pipeline; break;
+                        case 5: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F32_F32_R1_5].pipeline; break;
+                        default: GGML_ABORT("not implemented");
+                    } break;
                 case GGML_TYPE_F16:
                     switch (r1ptg) {
                         case 2: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_2].pipeline; break;
@@ -3598,7 +3627,7 @@ static int ggml_metal_encode_node(
                 case GGML_TYPE_Q5_0:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_0_F32  ].pipeline; break;
                 case GGML_TYPE_Q5_1:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_1_F32  ].pipeline; break;
                 case GGML_TYPE_Q8_0:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q8_0_F32  ].pipeline; break;
-                case GGML_TYPE_MXFP4: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_MXFP4_F32].pipeline; break;
+                case GGML_TYPE_MXFP4: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_MXFP4_F32 ].pipeline; break;
                 case GGML_TYPE_Q2_K:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q2_K_F32  ].pipeline; break;
                 case GGML_TYPE_Q3_K:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q3_K_F32  ].pipeline; break;
                 case GGML_TYPE_Q4_K:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_K_F32  ].pipeline; break;
@@ -5482,6 +5511,7 @@ static int ggml_metal_encode_node(
                 /*.nb33     =*/ nb33,
                 /*.ne1      =*/ ne1,
                 /*.ne2      =*/ ne2,
+                /*.ne3      =*/ ne3,
                 /*.scale    =*/ scale,
                 /*.max_bias =*/ max_bias,
                 /*.m0       =*/ m0,
@@ -5505,7 +5535,6 @@ static int ggml_metal_encode_node(
             } else {
                 [encoder setBuffer:id_src0 offset:offs_src0 atIndex:5];
             }
-            [encoder setBuffer:id_dst offset:offs_dst atIndex:6];

             if (!use_vec_kernel) {
                 // half8x8 kernel
@@ -5531,7 +5560,7 @@ static int ggml_metal_encode_node(

                 while (true) {
                     const size_t smem = FATTN_SMEM(nsgmax);
-                    if (smem > device.maxThreadgroupMemoryLength) {
+                    if (smem > device.maxThreadgroupMemoryLength/2) {
                         break;
                     }
                     nsgmax *= 2;
@@ -5543,15 +5572,18 @@ static int ggml_metal_encode_node(

                 const size_t smem = FATTN_SMEM(nsg);

+                [encoder setBuffer:id_dst offset:offs_dst atIndex:6];
+
                 //printf("smem: %zu, max: %zu, nsg = %d\n", smem, device.maxThreadgroupMemoryLength, (int) nsg);
                 GGML_ASSERT(smem <= device.maxThreadgroupMemoryLength);
                 [encoder setThreadgroupMemoryLength:smem atIndex:0];
-#undef FATTN_SMEM
                 [encoder dispatchThreadgroups:MTLSizeMake((ne01 + nqptg - 1)/nqptg, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)];
+#undef FATTN_SMEM
             } else {
                 // half4x4 kernel
                 const int64_t nqptg = 1;  // queries per threadgroup    !! sync with kernel template arguments !!
                 const int64_t ncpsg = 32; // cache values per simdgroup !! sync with kernel template arguments !!
+                const int64_t nkpsg = 1*ncpsg; // TODO: make adjustable

                 GGML_ASSERT(nqptg <= 32);
                 GGML_ASSERT(nqptg % 1 == 0);
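
Two details worth noting here. First, the dst binding at buffer index 6 moved from the shared setup (removed in the @@ -5505 hunk) into each branch: the half8x8 path always writes dst directly, while the half4x4 path below may point index 6 at a temporary buffer once it splits work across workgroups. Second, nkpsg fixes how many cache values one simdgroup consumes per pass; a small illustration (the KV length is made up):

    // nkpsg = 1*ncpsg = 32 cache values per simdgroup pass
    // e.g. ne11 = 8192 KV entries -> 8192/32 = 256 passes to spread across
    // simdgroups and, further below, across workgroups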
@@ -5561,15 +5593,17 @@ static int ggml_metal_encode_node(
                 // for each query, we load it as f16 in shared memory (ne00)
                 // and store the soft_max values and the mask
                 //
-                // ne00*(nsg)
+                // ne20*(nsg)
                 // each simdgroup has a full f32 head vector in shared mem to accumulate results
                 //
 #define FATTN_SMEM(nsg) (GGML_PAD((nqptg*(GGML_PAD(ne00, 128) + 4*ncpsg*(nsg)) + 2*ne20*(nsg))*(sizeof(float)/2), 16))
+//#define FATTN_SMEM(nsg) (GGML_PAD((nqptg*(GGML_PAD(ne00, 128) + 4*ncpsg*(nsg)))*(sizeof(float)/2), 16))

                 int64_t nsgmax = 2;
                 while (true) {
                     const size_t smem = FATTN_SMEM(nsgmax);
-                    if (smem > device.maxThreadgroupMemoryLength) {
+                    // avoid using more than half of the threadgroup memory - can cause slowdowns especially for large head sizes
+                    if (smem > device.maxThreadgroupMemoryLength/2) {
                         break;
                     }
                     nsgmax *= 2;
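
FATTN_SMEM counts half-precision elements and scales by sizeof(float)/2 = 2 bytes: per query it holds the padded head vector plus 4*ncpsg softmax/mask slots per simdgroup, and the 2*ne20*(nsg) term is the full f32 accumulator head vector each simdgroup keeps (ne20 floats = 2*ne20 halves). A worked instance, assuming a typical head size of 128 and a 32 KiB threadgroup memory limit (common on Apple GPUs):

    // nqptg = 1, ne00 = ne20 = 128, ncpsg = 32, nsg = 4:
    //   halves = 1*(GGML_PAD(128, 128) + 4*32*4) + 2*128*4 = 128 + 512 + 1024 = 1664
    //   smem   = GGML_PAD(1664*2, 16) = 3328 bytes
    // the new maxThreadgroupMemoryLength/2 cap (16 KiB of 32 KiB) stops the nsgmax
    // doubling one step earlier, trading simdgroups per threadgroup for occupancy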
@@ -5577,7 +5611,7 @@ static int ggml_metal_encode_node(
                 nsgmax /= 2;

                 // simdgroups per threadgroup (a.k.a. warps)
-                const int64_t nsgt = MAX(2, MIN(nsgmax, MIN(ne11/ncpsg, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32)));
+                const int64_t nsgt = MAX(2, MIN(nsgmax, MIN((ne11 + nkpsg - 1)/(nkpsg), (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32)));

                 int64_t nsg = 1;
                 while (nsg <= nsgt) {
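
The simdgroup budget now ceil-divides the KV length by nkpsg instead of floor-dividing by ncpsg, so a partial trailing chunk of cache values still counts toward the simdgroup count:

    // ne11 = 72, ncpsg = nkpsg = 32:
    //   old: 72/32            = 2
    //   new: (72 + 32 - 1)/32 = 3   // the 8 trailing cache values count as a full chunk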
|
@@ -5585,13 +5619,74 @@ static int ggml_metal_encode_node(
|
|
|
}
|
|
|
nsg /= 2;
|
|
|
|
|
|
- const size_t smem = FATTN_SMEM(nsg);
|
|
|
+ // workgroups
|
|
|
+ // each workgroup handles nsg*nkpsg cache values
|
|
|
+ uint16_t nwg = 1;
|
|
|
+ if (4*nsg*nkpsg >= ne11) {
|
|
|
+ const size_t smem = FATTN_SMEM(nsg);
|
|
|
|
|
|
- //printf("smem: %zu, max: %zu, nsg = %d\n", smem, device.maxThreadgroupMemoryLength, (int) nsg);
|
|
|
- GGML_ASSERT(smem <= device.maxThreadgroupMemoryLength);
|
|
|
- [encoder setThreadgroupMemoryLength:smem atIndex:0];
|
|
|
+ //printf("smem: %zu, max: %zu, nsg = %d, nsgmax = %d\n", smem, device.maxThreadgroupMemoryLength, (int) nsg, (int) nsgmax);
|
|
|
+ GGML_ASSERT(smem <= device.maxThreadgroupMemoryLength);
|
|
|
+
|
|
|
+ // using 1 workgroup -> write the result directly into dst
|
|
|
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:6];
|
|
|
+ [encoder setBytes:&nwg length:sizeof(uint16_t) atIndex:7];
|
|
|
+
|
|
|
+ [encoder setThreadgroupMemoryLength:smem atIndex:0];
|
|
|
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + nqptg - 1)/nqptg, ne02, ne03*nwg) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)];
|
|
|
+ } else {
|
|
|
+ nwg = 32;
|
|
|
+ nsg = MIN(4, nsg);
|
|
|
+
|
|
|
+ const size_t smem = FATTN_SMEM(nsg);
|
|
|
+
|
|
|
+ //printf("smem: %zu, max: %zu, nsg = %d, nsgmax = %d\n", smem, device.maxThreadgroupMemoryLength, (int) nsg, (int) nsgmax);
|
|
|
+ GGML_ASSERT(smem <= device.maxThreadgroupMemoryLength);
|
|
|
+
|
|
|
+ // sanity checks
|
|
|
+ GGML_ASSERT(ne01*ne02*ne03 == ne1*ne2*ne3);
|
|
|
+ GGML_ASSERT(ne1*ne2*ne3 <= (1u << 31));
|
|
|
+
|
|
|
+ const int32_t nrows = ne1*ne2*ne3;
|
|
|
+
|
|
|
+ // temp buffer for writing the results from each workgroup
|
|
|
+ // - ne20: the size of the head vector
|
|
|
+ // - + 2: the S and M values for each intermediate result
|
|
|
+ const size_t s_tmp = ggml_type_size(GGML_TYPE_F32)*(nrows*nwg*(ne20 + 2));
|
|
|
+ id<MTLBuffer> h_tmp = ggml_metal_mem_pool_alloc(mem_pool, s_tmp);
|
|
|
+ if (!h_tmp) {
|
|
|
+ GGML_LOG_ERROR("%s: failed to allocate buffer from memory pool, size = %zu\n", __func__, s_tmp);
|
|
|
+ return 0;
|
|
|
+ }
|
|
|
+
|
|
|
+ //printf("ne01 = %d, ne02 = %d, ne03 = %d, ne20 = %d\n", ne01, ne02, ne03, ne20);
|
|
|
+ //printf("needed memory: %.3f MiB\n", (float) (ne01*ne02*ne03*ne20*sizeof(float))/1024.0f/1024.0f);
|
|
|
+
|
|
|
+ [encoder setBuffer:h_tmp offset:0 atIndex:6];
|
|
|
+ [encoder setBytes:&nwg length:sizeof(uint16_t) atIndex:7];
|
|
|
+
|
|
|
+ [encoder setThreadgroupMemoryLength:smem atIndex:0];
|
|
|
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + nqptg - 1)/nqptg, ne02, ne03*nwg) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)];
|
|
|
+
|
|
|
+ // reduce the results from the workgroups
|
|
|
+ {
|
|
|
+ ggml_metal_kargs_flash_attn_ext_reduce args0 = {
|
|
|
+ nrows,
|
|
|
+ ne20,
|
|
|
+ };
|
|
|
+
|
|
|
+ id<MTLComputePipelineState> pipeline0 = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_REDUCE].pipeline;
|
|
|
+
|
|
|
+ [encoder setComputePipelineState:pipeline0];
|
|
|
+ [encoder setBytes:&args0 length:sizeof(args0) atIndex:0];
|
|
|
+ [encoder setBuffer:h_tmp offset:0 atIndex:1];
|
|
|
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
|
|
|
+
|
|
|
+ //printf("ne1 = %d, ne2 = %d, ne3 = %d, ne20 = %d\n", ne1, ne2, ne3, ne20);
|
|
|
+ [encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(32*32, 1, 1)];
|
|
|
+ }
|
|
|
+ }
|
|
|
#undef FATTN_SMEM
|
|
|
- [encoder dispatchThreadgroups:MTLSizeMake((ne01 + nqptg - 1)/nqptg, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)];
|
|
|
}
|
|
|
} break;
|
|
|
case GGML_OP_DUP:
|
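
This last hunk is the core of the change: when the KV cache is long relative to what the simdgroups can cover (4*nsg*nkpsg < ne11), the dispatch fans out to nwg = 32 workgroups along the third grid dimension, each writing a partial attention result into a pooled temporary buffer, and the new flash_attn_ext_reduce kernel folds the partials into dst. Per output row and per workgroup, the buffer holds the ne20-wide head vector plus the two softmax statistics S and M. A sizing example (shapes illustrative: single-token decode, 32 heads, head size 128):

    // nrows = ne1*ne2*ne3 = 1*32*1 = 32
    // s_tmp = 4 bytes * (32 rows * 32 workgroups * (128 + 2)) = 532480 bytes (~520 KiB)

The merge itself is the standard log-sum-exp combination of partial softmax results. The exact layout and normalization convention live in the Metal reduce kernel, which is not part of this diff; the CPU sketch below assumes each partial stores an unnormalized output row O[ne20] followed by its softmax sum S and running max M:

    #include <math.h>

    // hedged sketch of the per-row reduction; the [O[ne20], S, M] layout is an assumption
    static void fattn_reduce_row(const float * partials, int nwg, int ne20, float * dst) {
        // global maximum over the partial maxima
        float M = -INFINITY;
        for (int w = 0; w < nwg; ++w) {
            const float m = partials[w*(ne20 + 2) + ne20 + 1];
            if (m > M) {
                M = m;
            }
        }

        // rescale and accumulate the softmax sums
        float S = 0.0f;
        for (int w = 0; w < nwg; ++w) {
            const float * p = partials + w*(ne20 + 2);
            S += expf(p[ne20 + 1] - M)*p[ne20];
        }

        // rescale, accumulate and normalize the output vectors
        for (int i = 0; i < ne20; ++i) {
            float acc = 0.0f;
            for (int w = 0; w < nwg; ++w) {
                const float * p = partials + w*(ne20 + 2);
                acc += expf(p[ne20 + 1] - M)*p[i];
            }
            dst[i] = acc/S;
        }
    }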