|
|
@@ -39,6 +39,7 @@ static struct ggml_backend_metal_device_context {
|
|
|
bool has_simdgroup_reduction;
|
|
|
bool has_simdgroup_mm;
|
|
|
bool has_bfloat;
|
|
|
+ bool use_bfloat;
|
|
|
|
|
|
char name[128];
|
|
|
} g_ggml_ctx_dev_main = {
|
|
|
@@ -47,6 +48,7 @@ static struct ggml_backend_metal_device_context {
|
|
|
/*.has_simdgroup_reduction =*/ false,
|
|
|
/*.has_simdgroup_mm =*/ false,
|
|
|
/*.has_bfloat =*/ false,
|
|
|
+ /*.use_bfloat =*/ false,
|
|
|
/*.name =*/ "",
|
|
|
};
|
|
|
|
|
|
@@ -65,6 +67,12 @@ static id<MTLDevice> ggml_backend_metal_device_acq(struct ggml_backend_metal_dev
|
|
|
ctx->has_bfloat = [ctx->mtl_device supportsFamily:MTLGPUFamilyMetal3_GGML];
|
|
|
ctx->has_bfloat |= [ctx->mtl_device supportsFamily:MTLGPUFamilyApple6];
|
|
|
|
|
|
+#if defined(GGML_METAL_USE_BF16)
|
|
|
+ ctx->use_bfloat = ctx->has_bfloat;
|
|
|
+#else
|
|
|
+ ctx->use_bfloat = false;
|
|
|
+#endif
|
|
|
+
|
|
|
strncpy(ctx->name, [[ctx->mtl_device name] UTF8String], sizeof(ctx->name) - 1);
|
|
|
}
|
|
|
|
|
|
@@ -504,6 +512,10 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
|
|
|
// dictionary of preprocessor macros
|
|
|
NSMutableDictionary * prep = [NSMutableDictionary dictionary];
|
|
|
|
|
|
+ if (ctx_dev->use_bfloat) {
|
|
|
+ [prep setObject:@"1" forKey:@"GGML_METAL_USE_BF16"];
|
|
|
+ }
|
|
|
+
|
|
|
MTLCompileOptions * options = [MTLCompileOptions new];
|
|
|
options.preprocessorMacros = prep;
|
|
|
|
|
|
@@ -556,7 +568,8 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
|
|
|
|
|
|
GGML_LOG_INFO("%s: simdgroup reduction = %s\n", __func__, ctx_dev->has_simdgroup_reduction ? "true" : "false");
|
|
|
GGML_LOG_INFO("%s: simdgroup matrix mul. = %s\n", __func__, ctx_dev->has_simdgroup_mm ? "true" : "false");
|
|
|
- GGML_LOG_INFO("%s: bfloat = %s\n", __func__, ctx_dev->has_bfloat ? "true" : "false");
|
|
|
+ GGML_LOG_INFO("%s: has bfloat = %s\n", __func__, ctx_dev->has_bfloat ? "true" : "false");
|
|
|
+ GGML_LOG_INFO("%s: use bfloat = %s\n", __func__, ctx_dev->use_bfloat ? "true" : "false");
|
|
|
GGML_LOG_INFO("%s: hasUnifiedMemory = %s\n", __func__, ctx_dev->mtl_device.hasUnifiedMemory ? "true" : "false");
|
|
|
|
|
|
ctx->capture_next_compute = false;
|
|
|
@@ -608,7 +621,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
|
|
|
|
|
|
const bool has_simdgroup_mm = ctx_dev->has_simdgroup_mm;
|
|
|
const bool has_simdgroup_reduction = ctx_dev->has_simdgroup_reduction;
|
|
|
- const bool has_bfloat = ctx_dev->has_bfloat;
|
|
|
+ const bool use_bfloat = ctx_dev->use_bfloat;
|
|
|
|
|
|
// simd_sum and simd_max requires MTLGPUFamilyApple7
|
|
|
|
|
|
@@ -644,7 +657,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
|
|
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF_8, diag_mask_inf_8, true);
|
|
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_F32, get_rows_f32, true);
|
|
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_F16, get_rows_f16, true);
|
|
|
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_BF16, get_rows_bf16, has_bfloat);
|
|
|
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_BF16, get_rows_bf16, use_bfloat);
|
|
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_0, get_rows_q4_0, true);
|
|
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_1, get_rows_q4_1, true);
|
|
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_0, get_rows_q5_0, true);
|
|
|
@@ -671,10 +684,10 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
|
|
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SSM_CONV_F32, ssm_conv_f32, true);
|
|
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32, ssm_scan_f32, true);
|
|
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32, mul_mv_f32_f32, has_simdgroup_reduction);
|
|
|
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32, mul_mv_bf16_f32, has_simdgroup_reduction && has_bfloat);
|
|
|
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_1ROW, mul_mv_bf16_f32_1row, has_simdgroup_reduction && has_bfloat);
|
|
|
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_L4, mul_mv_bf16_f32_l4, has_simdgroup_reduction && has_bfloat);
|
|
|
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_BF16, mul_mv_bf16_bf16, has_simdgroup_reduction && has_bfloat);
|
|
|
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32, mul_mv_bf16_f32, has_simdgroup_reduction && use_bfloat);
|
|
|
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_1ROW, mul_mv_bf16_f32_1row, has_simdgroup_reduction && use_bfloat);
|
|
|
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_L4, mul_mv_bf16_f32_l4, has_simdgroup_reduction && use_bfloat);
|
|
|
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_BF16, mul_mv_bf16_bf16, has_simdgroup_reduction && use_bfloat);
|
|
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32, mul_mv_f16_f32, has_simdgroup_reduction);
|
|
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW, mul_mv_f16_f32_1row, has_simdgroup_reduction);
|
|
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_L4, mul_mv_f16_f32_l4, has_simdgroup_reduction);
|
|
|
@@ -703,7 +716,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
|
|
|
//GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32_1ROW, mul_mv_id_f16_f32_1row, has_simdgroup_reduction);
|
|
|
//GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32_L4, mul_mv_id_f16_f32_l4, has_simdgroup_reduction);
|
|
|
//GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F16, mul_mv_id_f16_f16, has_simdgroup_reduction);
|
|
|
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_BF16_F32, mul_mv_id_bf16_f32, has_simdgroup_reduction && has_bfloat);
|
|
|
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_BF16_F32, mul_mv_id_bf16_f32, has_simdgroup_reduction && use_bfloat);
|
|
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_0_F32, mul_mv_id_q4_0_f32, has_simdgroup_reduction);
|
|
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_1_F32, mul_mv_id_q4_1_f32, has_simdgroup_reduction);
|
|
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_0_F32, mul_mv_id_q5_0_f32, has_simdgroup_reduction);
|
|
|
@@ -725,7 +738,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
|
|
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_XS_F32, mul_mv_id_iq4_xs_f32, has_simdgroup_reduction);
|
|
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32, mul_mm_f32_f32, has_simdgroup_mm);
|
|
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32, mul_mm_f16_f32, has_simdgroup_mm);
|
|
|
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_BF16_F32, mul_mm_bf16_f32, has_simdgroup_mm && has_bfloat);
|
|
|
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_BF16_F32, mul_mm_bf16_f32, has_simdgroup_mm && use_bfloat);
|
|
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32, mul_mm_q4_0_f32, has_simdgroup_mm);
|
|
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_1_F32, mul_mm_q4_1_f32, has_simdgroup_mm);
|
|
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_0_F32, mul_mm_q5_0_f32, has_simdgroup_mm);
|
|
|
@@ -747,7 +760,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
|
|
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32, mul_mm_iq4_xs_f32, has_simdgroup_mm);
|
|
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32, mul_mm_id_f32_f32, has_simdgroup_mm);
|
|
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32, mul_mm_id_f16_f32, has_simdgroup_mm);
|
|
|
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_BF16_F32, mul_mm_id_bf16_f32, has_simdgroup_mm && has_bfloat);
|
|
|
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_BF16_F32, mul_mm_id_bf16_f32, has_simdgroup_mm && use_bfloat);
|
|
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32, mul_mm_id_q4_0_f32, has_simdgroup_mm);
|
|
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_1_F32, mul_mm_id_q4_1_f32, has_simdgroup_mm);
|
|
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F32, mul_mm_id_q5_0_f32, has_simdgroup_mm);
|
|
|
@@ -788,12 +801,12 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
|
|
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112, flash_attn_ext_f16_h112, has_simdgroup_mm);
|
|
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128, flash_attn_ext_f16_h128, has_simdgroup_mm);
|
|
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256, flash_attn_ext_f16_h256, has_simdgroup_mm);
|
|
|
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H64, flash_attn_ext_bf16_h64, has_simdgroup_mm && has_bfloat);
|
|
|
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H80, flash_attn_ext_bf16_h80, has_simdgroup_mm && has_bfloat);
|
|
|
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H96, flash_attn_ext_bf16_h96, has_simdgroup_mm && has_bfloat);
|
|
|
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H112, flash_attn_ext_bf16_h112, has_simdgroup_mm && has_bfloat);
|
|
|
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H128, flash_attn_ext_bf16_h128, has_simdgroup_mm && has_bfloat);
|
|
|
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H256, flash_attn_ext_bf16_h256, has_simdgroup_mm && has_bfloat);
|
|
|
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H64, flash_attn_ext_bf16_h64, has_simdgroup_mm && use_bfloat);
|
|
|
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H80, flash_attn_ext_bf16_h80, has_simdgroup_mm && use_bfloat);
|
|
|
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H96, flash_attn_ext_bf16_h96, has_simdgroup_mm && use_bfloat);
|
|
|
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H112, flash_attn_ext_bf16_h112, has_simdgroup_mm && use_bfloat);
|
|
|
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H128, flash_attn_ext_bf16_h128, has_simdgroup_mm && use_bfloat);
|
|
|
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H256, flash_attn_ext_bf16_h256, has_simdgroup_mm && use_bfloat);
|
|
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H64, flash_attn_ext_q4_0_h64, has_simdgroup_mm);
|
|
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H80, flash_attn_ext_q4_0_h80, has_simdgroup_mm);
|
|
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H96, flash_attn_ext_q4_0_h96, has_simdgroup_mm);
|
|
|
@@ -825,14 +838,14 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
|
|
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H128, flash_attn_ext_q8_0_h128, has_simdgroup_mm);
|
|
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H256, flash_attn_ext_q8_0_h256, has_simdgroup_mm);
|
|
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128, flash_attn_ext_vec_f16_h128, has_simdgroup_reduction);
|
|
|
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H128, flash_attn_ext_vec_bf16_h128, has_simdgroup_reduction && has_bfloat);
|
|
|
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H128, flash_attn_ext_vec_bf16_h128, has_simdgroup_reduction && use_bfloat);
|
|
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H128, flash_attn_ext_vec_q4_0_h128, has_simdgroup_reduction);
|
|
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H128, flash_attn_ext_vec_q4_1_h128, has_simdgroup_reduction);
|
|
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H128, flash_attn_ext_vec_q5_0_h128, has_simdgroup_reduction);
|
|
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H128, flash_attn_ext_vec_q5_1_h128, has_simdgroup_reduction);
|
|
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H128, flash_attn_ext_vec_q8_0_h128, has_simdgroup_reduction);
|
|
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256, flash_attn_ext_vec_f16_h256, has_simdgroup_reduction);
|
|
|
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H256, flash_attn_ext_vec_bf16_h256, has_simdgroup_reduction && has_bfloat);
|
|
|
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H256, flash_attn_ext_vec_bf16_h256, has_simdgroup_reduction && use_bfloat);
|
|
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H256, flash_attn_ext_vec_q4_0_h256, has_simdgroup_reduction);
|
|
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H256, flash_attn_ext_vec_q4_1_h256, has_simdgroup_reduction);
|
|
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H256, flash_attn_ext_vec_q5_0_h256, has_simdgroup_reduction);
|
|
|
@@ -840,11 +853,11 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
|
|
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H256, flash_attn_ext_vec_q8_0_h256, has_simdgroup_reduction);
|
|
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F32, cpy_f32_f32, true);
|
|
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F16, cpy_f32_f16, true);
|
|
|
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_BF16, cpy_f32_bf16, has_bfloat);
|
|
|
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_BF16, cpy_f32_bf16, use_bfloat);
|
|
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F16_F32, cpy_f16_f32, true);
|
|
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F16_F16, cpy_f16_f16, true);
|
|
|
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_BF16_F32, cpy_bf16_f32, has_bfloat);
|
|
|
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_BF16_BF16, cpy_bf16_bf16, has_bfloat);
|
|
|
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_BF16_F32, cpy_bf16_f32, use_bfloat);
|
|
|
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_BF16_BF16, cpy_bf16_bf16, use_bfloat);
|
|
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0, cpy_f32_q8_0, true);
|
|
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_0, cpy_f32_q4_0, true);
|
|
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1, cpy_f32_q4_1, true);
|
|
|
@@ -936,9 +949,9 @@ static id<MTLBuffer> ggml_metal_get_buffer(struct ggml_tensor * t, size_t * offs
|
|
|
static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_context * ctx_dev, const struct ggml_tensor * op) {
|
|
|
const bool has_simdgroup_mm = ctx_dev->has_simdgroup_mm;
|
|
|
const bool has_simdgroup_reduction = ctx_dev->has_simdgroup_reduction;
|
|
|
- const bool has_bfloat = ctx_dev->has_bfloat;
|
|
|
+ const bool use_bfloat = ctx_dev->use_bfloat;
|
|
|
|
|
|
- if (!has_bfloat) {
|
|
|
+ if (!use_bfloat) {
|
|
|
for (size_t i = 0, n = 3; i < n; ++i) {
|
|
|
if (op->src[i] != NULL && op->src[i]->type == GGML_TYPE_BF16) {
|
|
|
return false;
|