@@ -523,13 +523,6 @@ enum ggml_metal_kernel_type {
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_HK192_HV128,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H256,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_HK576_HV512,
-    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H40,
-    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H40,
-    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H40,
-    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H40,
-    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H40,
-    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H40,
-    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H40,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H64,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H64,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H64,
@@ -1562,13 +1555,6 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_HK192_HV128, flash_attn_ext_q8_0_hk192_hv128, has_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H256, flash_attn_ext_q8_0_h256, has_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_HK576_HV512, flash_attn_ext_q8_0_hk576_hv512, has_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H40, flash_attn_ext_vec_f16_h40, has_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H40, flash_attn_ext_vec_bf16_h40, has_simdgroup_reduction && use_bfloat);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H40, flash_attn_ext_vec_q4_0_h40, has_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H40, flash_attn_ext_vec_q4_1_h40, has_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H40, flash_attn_ext_vec_q5_0_h40, has_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H40, flash_attn_ext_vec_q5_1_h40, has_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H40, flash_attn_ext_vec_q8_0_h40, has_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H64, flash_attn_ext_vec_f16_h64, has_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H64, flash_attn_ext_vec_bf16_h64, has_simdgroup_reduction && use_bfloat);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H64, flash_attn_ext_vec_q4_0_h64, has_simdgroup_reduction);
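Context for the two capability flags above: the matrix flash-attention kernels are gated on simdgroup matrix multiplication (has_simdgroup_mm), the vec variants only on simdgroup reductions (has_simdgroup_reduction), and BF16 variants additionally on bfloat support. A minimal, self-contained sketch of that gating pattern — illustrative only, not the actual GGML_METAL_ADD_KERNEL macro:

    #include <stdbool.h>
    #include <stddef.h>
    #include <stdio.h>

    struct kernel_desc {
        const char *name;
        bool        needs_simdgroup_mm; /* matrix kernels; vec kernels need only reductions */
        bool        needs_bfloat;       /* BF16 variants */
    };

    static void register_kernels(bool has_mm, bool has_reduction, bool use_bfloat) {
        const struct kernel_desc kernels[] = {
            { "flash_attn_ext_q8_0_h256",    true,  false },
            { "flash_attn_ext_vec_f16_h64",  false, false },
            { "flash_attn_ext_vec_bf16_h64", false, true  },
        };
        for (size_t i = 0; i < sizeof(kernels)/sizeof(kernels[0]); ++i) {
            const bool ok = (kernels[i].needs_simdgroup_mm ? has_mm : has_reduction) &&
                            (!kernels[i].needs_bfloat || use_bfloat);
            /* assumption: an unsupported kernel is simply left without a pipeline */
            printf("%-30s %s\n", kernels[i].name, ok ? "registered" : "skipped");
        }
    }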
@@ -1909,9 +1895,15 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
         case GGML_OP_ARANGE:
             return true;
         case GGML_OP_FLASH_ATTN_EXT:
-            if (op->src[0]->ne[0] == 32) {
-                // head size == 32 (e.g. bert-bge-small)
-                // TODO: not sure if it is worth adding kernels for this size
+            // for new head sizes, add checks here
+            if (op->src[0]->ne[0] != 40 &&
+                op->src[0]->ne[0] != 64 &&
+                op->src[0]->ne[0] != 80 &&
+                op->src[0]->ne[0] != 96 &&
+                op->src[0]->ne[0] != 112 &&
+                op->src[0]->ne[0] != 128 &&
+                op->src[0]->ne[0] != 192 &&
+                op->src[0]->ne[0] != 256) {
                 return false;
             }
             if (op->src[0]->ne[0] == 576) {
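The new gate inverts the old logic: instead of rejecting a single known-unsupported head size (32), ggml_metal_supports_op now whitelists exactly the head sizes that have kernels, 40 included. Head size 576 is not in the whitelist, so it returns false before the separate == 576 check that follows is even reached. A standalone restatement of the gate, with a hypothetical helper name:

    #include <stdbool.h>
    #include <stdint.h>

    // hypothetical helper mirroring the whitelist in the hunk above
    static bool fa_head_size_supported(int64_t head_size) {
        switch (head_size) {
            case 40: case 64: case 80: case 96:
            case 112: case 128: case 192: case 256:
                return true;
            default:
                return false; // e.g. 32 (bert-bge-small) and 576 stay rejected
        }
    }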
@@ -5138,10 +5130,8 @@ static int ggml_metal_encode_node(

                 bool use_vec_kernel = false;

-                // TODO: add vec kernels for (ne00%64 == 0) and maybe also for (ne00%32 == 0)
-                // for now avoiding mainly to keep the number of templates/kernels a bit lower
-                // these are now trivial to add after: https://github.com/ggml-org/llama.cpp/pull/12612
-                if (ne01 >= 20 || (ne00%128 != 0 && ne00 != 64 && ne00 != 96 && ne00 != 192 && ne00 != 576)) {
+                // use non-vec kernel if the batch size is large or if the vec-kernel is not supported for this head size
+                if (ne01 >= 20 || (ne00 == 40 || ne00 == 80 || ne00 == 112)) {
                     switch (src1->type) {
                         case GGML_TYPE_F16:
                             {
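The routing comment now matches the condition: the vec (simdgroup-reduction) flash-attention kernels are meant for small batches, and with the H40 variants gone they no longer exist for head sizes 40, 80 and 112, so those always fall through to the matrix (simdgroup-mm) kernels. Restated as a hypothetical standalone predicate:

    #include <stdbool.h>
    #include <stdint.h>

    // hypothetical helper: true when the vec kernel should handle the op
    static bool use_vec_fa_kernel(int64_t ne01 /* batch size */, int64_t ne00 /* head size */) {
        const bool vec_kernel_exists = ne00 != 40 && ne00 != 80 && ne00 != 112;
        return ne01 < 20 && vec_kernel_exists;
    }

So a single-token decode (ne01 == 1) with head size 64 keeps the vec path, while head size 40 takes the matrix path at any batch size.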
@@ -5329,24 +5319,6 @@ static int ggml_metal_encode_node(
                     use_vec_kernel = true;

                     switch (ne00) {
-                        case 40:
-                            {
-                                switch (src1->type) {
-                                    case GGML_TYPE_F16:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H40].pipeline; break;
-                                    case GGML_TYPE_BF16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H40].pipeline; break;
-                                    case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H40].pipeline; break;
-                                    case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H40].pipeline; break;
-                                    case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H40].pipeline; break;
-                                    case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H40].pipeline; break;
-                                    case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H40].pipeline; break;
-                                    default:
-                                        {
-                                            GGML_LOG_ERROR("unsupported type: %d\n", src1->type);
-                                            GGML_LOG_ERROR("add template specialization for this type\n");
-                                            GGML_ABORT("add template specialization for this type");
-                                        }
-                                }
-                            } break;
                         case 64:
                             {
                                 switch (src1->type) {