|
|
@@ -20,6 +20,69 @@
|
|
|
|
|
|
#define UNUSED(x) (void)(x)
|
|
|
|
|
|
+// globals
|
|
|
+
|
|
|
+// overload of MTLGPUFamilyMetal3 (not available in some environments)
|
|
|
+static const NSInteger MTLGPUFamilyMetal3_GGML = 5001;
|
|
|
+
|
|
|
+// initialized in ggml_backend_metal_reg
|
|
|
+static struct ggml_backend_reg g_ggml_backend_metal_reg;
|
|
|
+static struct ggml_backend_device g_ggml_backend_metal_device;
|
|
|
+
|
|
|
+// information about a Metal device
|
|
|
+// note: assumes single GPU device - the default one
|
|
|
+// TODO: support multiple GPU devices
|
|
|
+static struct ggml_backend_metal_device_context {
|
|
|
+ id<MTLDevice> mtl_device;
|
|
|
+ int mtl_device_ref_count;
|
|
|
+
|
|
|
+ bool support_simdgroup_reduction;
|
|
|
+ bool support_simdgroup_mm;
|
|
|
+
|
|
|
+ char name[128];
|
|
|
+} g_ggml_ctx_dev_main = {
|
|
|
+ /*.mtl_device =*/ nil,
|
|
|
+ /*.mtl_device_ref_count =*/ 0,
|
|
|
+ /*.support_simdgroup_reduction =*/ false,
|
|
|
+ /*.support_simdgroup_mm =*/ false,
|
|
|
+ /*.name =*/ "",
|
|
|
+};
|
|
|
+
|
|
|
+// acquire
|
|
|
+static id<MTLDevice> ggml_backend_metal_device_acq(struct ggml_backend_metal_device_context * ctx) {
|
|
|
+ assert(ctx != NULL);
|
|
|
+
|
|
|
+ if (ctx->mtl_device == nil) {
|
|
|
+ ctx->mtl_device = MTLCreateSystemDefaultDevice();
|
|
|
+
|
|
|
+ ctx->support_simdgroup_reduction = [ctx->mtl_device supportsFamily:MTLGPUFamilyApple7];
|
|
|
+ ctx->support_simdgroup_reduction |= [ctx->mtl_device supportsFamily:MTLGPUFamilyMetal3_GGML];
|
|
|
+
|
|
|
+ ctx->support_simdgroup_mm = [ctx->mtl_device supportsFamily:MTLGPUFamilyApple7];
|
|
|
+
|
|
|
+ strncpy(ctx->name, [[ctx->mtl_device name] UTF8String], sizeof(ctx->name) - 1);
|
|
|
+ }
|
|
|
+
|
|
|
+ ctx->mtl_device_ref_count++;
|
|
|
+
|
|
|
+ return ctx->mtl_device;
|
|
|
+}
|
|
|
+
|
|
|
+// release
|
|
|
+static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_context * ctx) {
|
|
|
+ assert(ctx != NULL);
|
|
|
+ assert(ctx->mtl_device_ref_count > 0);
|
|
|
+
|
|
|
+ ctx->mtl_device_ref_count--;
|
|
|
+
|
|
|
+ if (ctx->mtl_device_ref_count == 0) {
|
|
|
+ [ctx->mtl_device release];
|
|
|
+ ctx->mtl_device = nil;
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+// kernels
|
|
|
+
|
|
|
struct ggml_metal_kernel {
|
|
|
id<MTLComputePipelineState> pipeline;
|
|
|
};
|
|
|
@@ -214,16 +277,12 @@ enum ggml_metal_kernel_type {
|
|
|
};
|
|
|
|
|
|
struct ggml_backend_metal_context {
|
|
|
- id<MTLDevice> device;
|
|
|
id<MTLCommandQueue> queue;
|
|
|
|
|
|
dispatch_queue_t d_queue;
|
|
|
|
|
|
struct ggml_metal_kernel kernels[GGML_METAL_KERNEL_TYPE_COUNT];
|
|
|
|
|
|
- bool support_simdgroup_reduction;
|
|
|
- bool support_simdgroup_mm;
|
|
|
-
|
|
|
// capture state
|
|
|
bool capture_next_compute;
|
|
|
bool capture_started;
|
|
|
@@ -280,7 +339,7 @@ static void * ggml_metal_host_malloc(size_t n) {
|
|
|
return data;
|
|
|
}
|
|
|
|
|
|
-static struct ggml_backend_metal_context * ggml_metal_init(void) {
|
|
|
+static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t dev) {
|
|
|
GGML_LOG_INFO("%s: allocating\n", __func__);
|
|
|
|
|
|
#if TARGET_OS_OSX && !GGML_METAL_NDEBUG
|
|
|
@@ -292,14 +351,14 @@ static struct ggml_backend_metal_context * ggml_metal_init(void) {
|
|
|
[devices release]; // since it was created by a *Copy* C method
|
|
|
#endif
|
|
|
|
|
|
- // Pick and show default Metal device
|
|
|
- id<MTLDevice> device = MTLCreateSystemDefaultDevice();
|
|
|
+ // init context
|
|
|
+ struct ggml_backend_metal_context * ctx = calloc(1, sizeof(struct ggml_backend_metal_context));
|
|
|
+ struct ggml_backend_metal_device_context * ctx_dev = dev->context;
|
|
|
+
|
|
|
+ id<MTLDevice> device = ggml_backend_metal_device_acq(ctx_dev);
|
|
|
GGML_LOG_INFO("%s: picking default device: %s\n", __func__, [[device name] UTF8String]);
|
|
|
|
|
|
- // Configure context
|
|
|
- struct ggml_backend_metal_context * ctx = calloc(1, sizeof(struct ggml_backend_metal_context));
|
|
|
- ctx->device = device;
|
|
|
- ctx->queue = [ctx->device newCommandQueue];
|
|
|
+ ctx->queue = [device newCommandQueue];
|
|
|
ctx->d_queue = dispatch_queue_create("ggml-metal", DISPATCH_QUEUE_CONCURRENT);
|
|
|
|
|
|
id<MTLLibrary> metal_library;
|
|
|
@@ -332,7 +391,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(void) {
|
|
|
NSURL * libURL = [NSURL fileURLWithPath:path_lib];
|
|
|
GGML_LOG_INFO("%s: loading '%s'\n", __func__, [path_lib UTF8String]);
|
|
|
|
|
|
- metal_library = [ctx->device newLibraryWithURL:libURL error:&error];
|
|
|
+ metal_library = [device newLibraryWithURL:libURL error:&error];
|
|
|
if (error) {
|
|
|
GGML_LOG_ERROR("%s: error: %s\n", __func__, [[error description] UTF8String]);
|
|
|
return NULL;
|
|
|
@@ -382,7 +441,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(void) {
|
|
|
|
|
|
//[options setFastMathEnabled:false];
|
|
|
|
|
|
- metal_library = [ctx->device newLibraryWithSource:src options:options error:&error];
|
|
|
+ metal_library = [device newLibraryWithSource:src options:options error:&error];
|
|
|
if (error) {
|
|
|
GGML_LOG_ERROR("%s: error: %s\n", __func__, [[error description] UTF8String]);
|
|
|
return NULL;
|
|
|
@@ -392,44 +451,37 @@ static struct ggml_backend_metal_context * ggml_metal_init(void) {
|
|
|
}
|
|
|
|
|
|
// print MTL GPU family:
|
|
|
- GGML_LOG_INFO("%s: GPU name: %s\n", __func__, [[ctx->device name] UTF8String]);
|
|
|
-
|
|
|
- const NSInteger MTLGPUFamilyMetal3 = 5001;
|
|
|
+ GGML_LOG_INFO("%s: GPU name: %s\n", __func__, [[device name] UTF8String]);
|
|
|
|
|
|
// determine max supported GPU family
|
|
|
// https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf
|
|
|
// https://developer.apple.com/metal/Metal-Feature-Set-Tables.pdf
|
|
|
{
|
|
|
for (int i = MTLGPUFamilyApple1 + 20; i >= MTLGPUFamilyApple1; --i) {
|
|
|
- if ([ctx->device supportsFamily:i]) {
|
|
|
+ if ([device supportsFamily:i]) {
|
|
|
GGML_LOG_INFO("%s: GPU family: MTLGPUFamilyApple%d (%d)\n", __func__, i - (int) MTLGPUFamilyApple1 + 1, i);
|
|
|
break;
|
|
|
}
|
|
|
}
|
|
|
|
|
|
for (int i = MTLGPUFamilyCommon1 + 5; i >= MTLGPUFamilyCommon1; --i) {
|
|
|
- if ([ctx->device supportsFamily:i]) {
|
|
|
+ if ([device supportsFamily:i]) {
|
|
|
GGML_LOG_INFO("%s: GPU family: MTLGPUFamilyCommon%d (%d)\n", __func__, i - (int) MTLGPUFamilyCommon1 + 1, i);
|
|
|
break;
|
|
|
}
|
|
|
}
|
|
|
|
|
|
- for (int i = MTLGPUFamilyMetal3 + 5; i >= MTLGPUFamilyMetal3; --i) {
|
|
|
- if ([ctx->device supportsFamily:i]) {
|
|
|
- GGML_LOG_INFO("%s: GPU family: MTLGPUFamilyMetal%d (%d)\n", __func__, i - (int) MTLGPUFamilyMetal3 + 3, i);
|
|
|
+ for (int i = MTLGPUFamilyMetal3_GGML + 5; i >= MTLGPUFamilyMetal3_GGML; --i) {
|
|
|
+ if ([device supportsFamily:i]) {
|
|
|
+ GGML_LOG_INFO("%s: GPU family: MTLGPUFamilyMetal%d (%d)\n", __func__, i - (int) MTLGPUFamilyMetal3_GGML + 3, i);
|
|
|
break;
|
|
|
}
|
|
|
}
|
|
|
}
|
|
|
|
|
|
- ctx->support_simdgroup_reduction = [ctx->device supportsFamily:MTLGPUFamilyApple7];
|
|
|
- ctx->support_simdgroup_reduction |= [ctx->device supportsFamily:MTLGPUFamilyMetal3];
|
|
|
-
|
|
|
- ctx->support_simdgroup_mm = [ctx->device supportsFamily:MTLGPUFamilyApple7];
|
|
|
-
|
|
|
- GGML_LOG_INFO("%s: simdgroup reduction support = %s\n", __func__, ctx->support_simdgroup_reduction ? "true" : "false");
|
|
|
- GGML_LOG_INFO("%s: simdgroup matrix mul. support = %s\n", __func__, ctx->support_simdgroup_mm ? "true" : "false");
|
|
|
- GGML_LOG_INFO("%s: hasUnifiedMemory = %s\n", __func__, ctx->device.hasUnifiedMemory ? "true" : "false");
|
|
|
+ GGML_LOG_INFO("%s: simdgroup reduction support = %s\n", __func__, ctx_dev->support_simdgroup_reduction ? "true" : "false");
|
|
|
+ GGML_LOG_INFO("%s: simdgroup matrix mul. support = %s\n", __func__, ctx_dev->support_simdgroup_mm ? "true" : "false");
|
|
|
+ GGML_LOG_INFO("%s: hasUnifiedMemory = %s\n", __func__, ctx_dev->mtl_device.hasUnifiedMemory ? "true" : "false");
|
|
|
|
|
|
ctx->capture_next_compute = false;
|
|
|
ctx->capture_started = false;
|
|
|
@@ -443,13 +495,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(void) {
|
|
|
|
|
|
#if TARGET_OS_OSX || (TARGET_OS_IOS && __clang_major__ >= 15)
|
|
|
if (@available(macOS 10.12, iOS 16.0, *)) {
|
|
|
- GGML_LOG_INFO("%s: recommendedMaxWorkingSetSize = %8.2f MB\n", __func__, ctx->device.recommendedMaxWorkingSetSize / 1e6);
|
|
|
- }
|
|
|
-#elif TARGET_OS_OSX
|
|
|
- if (ctx->device.maxTransferRate != 0) {
|
|
|
- GGML_LOG_INFO("%s: maxTransferRate = %8.2f MB/s\n", __func__, ctx->device.maxTransferRate / 1e6);
|
|
|
- } else {
|
|
|
- GGML_LOG_INFO("%s: maxTransferRate = built-in GPU\n", __func__);
|
|
|
+ GGML_LOG_INFO("%s: recommendedMaxWorkingSetSize = %8.2f MB\n", __func__, device.recommendedMaxWorkingSetSize / 1e6);
|
|
|
}
|
|
|
#endif
|
|
|
|
|
|
@@ -470,7 +516,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(void) {
|
|
|
if (supported) { \
|
|
|
struct ggml_metal_kernel * kernel = &ctx->kernels[e]; \
|
|
|
id<MTLFunction> metal_function = [metal_library newFunctionWithName:@"kernel_"#name]; \
|
|
|
- kernel->pipeline = [ctx->device newComputePipelineStateWithFunction:metal_function error:&error]; \
|
|
|
+ kernel->pipeline = [device newComputePipelineStateWithFunction:metal_function error:&error]; \
|
|
|
[metal_function release]; \
|
|
|
if (error) { \
|
|
|
GGML_LOG_ERROR("%s: error: load pipeline error: %s\n", __func__, [[error description] UTF8String]); \
|
|
|
@@ -481,6 +527,9 @@ static struct ggml_backend_metal_context * ggml_metal_init(void) {
|
|
|
GGML_LOG_WARN("%s: skipping %-40s (not supported)\n", __func__, "kernel_"#name); \
|
|
|
}
|
|
|
|
|
|
+ const bool support_simdgroup_mm = ctx_dev->support_simdgroup_mm;
|
|
|
+ const bool support_simdgroup_reduction = ctx_dev->support_simdgroup_reduction;
|
|
|
+
|
|
|
// simd_sum and simd_max requires MTLGPUFamilyApple7
|
|
|
|
|
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD, add, true);
|
|
|
@@ -507,10 +556,10 @@ static struct ggml_backend_metal_context * ggml_metal_init(void) {
|
|
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_QUICK_4, gelu_quick_4, true);
|
|
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SILU, silu, true);
|
|
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SILU_4, silu_4, true);
|
|
|
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16, soft_max_f16, ctx->support_simdgroup_reduction);
|
|
|
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16_4, soft_max_f16_4, ctx->support_simdgroup_reduction);
|
|
|
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32, soft_max_f32, ctx->support_simdgroup_reduction);
|
|
|
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32_4, soft_max_f32_4, ctx->support_simdgroup_reduction);
|
|
|
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16, soft_max_f16, support_simdgroup_reduction);
|
|
|
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16_4, soft_max_f16_4, support_simdgroup_reduction);
|
|
|
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32, soft_max_f32, support_simdgroup_reduction);
|
|
|
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32_4, soft_max_f32_4, support_simdgroup_reduction);
|
|
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF, diag_mask_inf, true);
|
|
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF_8, diag_mask_inf_8, true);
|
|
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_F32, get_rows_f32, true);
|
|
|
@@ -535,101 +584,101 @@ static struct ggml_backend_metal_context * ggml_metal_init(void) {
|
|
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_NL, get_rows_iq4_nl, true);
|
|
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_XS, get_rows_iq4_xs, true);
|
|
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_I32, get_rows_i32, true);
|
|
|
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RMS_NORM, rms_norm, ctx->support_simdgroup_reduction);
|
|
|
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GROUP_NORM, group_norm, ctx->support_simdgroup_reduction);
|
|
|
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RMS_NORM, rms_norm, support_simdgroup_reduction);
|
|
|
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GROUP_NORM, group_norm, support_simdgroup_reduction);
|
|
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_NORM, norm, true);
|
|
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SSM_CONV_F32, ssm_conv_f32, true);
|
|
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32, ssm_scan_f32, true);
|
|
|
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32, mul_mv_f32_f32, ctx->support_simdgroup_reduction);
|
|
|
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F16, mul_mv_f16_f16, ctx->support_simdgroup_reduction);
|
|
|
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32, mul_mv_f16_f32, ctx->support_simdgroup_reduction);
|
|
|
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW, mul_mv_f16_f32_1row, ctx->support_simdgroup_reduction);
|
|
|
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_L4, mul_mv_f16_f32_l4, ctx->support_simdgroup_reduction);
|
|
|
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_0_F32, mul_mv_q4_0_f32, ctx->support_simdgroup_reduction);
|
|
|
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_1_F32, mul_mv_q4_1_f32, ctx->support_simdgroup_reduction);
|
|
|
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_0_F32, mul_mv_q5_0_f32, ctx->support_simdgroup_reduction);
|
|
|
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_1_F32, mul_mv_q5_1_f32, ctx->support_simdgroup_reduction);
|
|
|
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q8_0_F32, mul_mv_q8_0_f32, ctx->support_simdgroup_reduction);
|
|
|
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q2_K_F32, mul_mv_q2_K_f32, ctx->support_simdgroup_reduction);
|
|
|
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q3_K_F32, mul_mv_q3_K_f32, ctx->support_simdgroup_reduction);
|
|
|
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_K_F32, mul_mv_q4_K_f32, ctx->support_simdgroup_reduction);
|
|
|
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_K_F32, mul_mv_q5_K_f32, ctx->support_simdgroup_reduction);
|
|
|
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q6_K_F32, mul_mv_q6_K_f32, ctx->support_simdgroup_reduction);
|
|
|
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XXS_F32, mul_mv_iq2_xxs_f32, ctx->support_simdgroup_reduction);
|
|
|
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XS_F32, mul_mv_iq2_xs_f32, ctx->support_simdgroup_reduction);
|
|
|
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_XXS_F32, mul_mv_iq3_xxs_f32, ctx->support_simdgroup_reduction);
|
|
|
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_S_F32, mul_mv_iq3_s_f32, ctx->support_simdgroup_reduction);
|
|
|
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_S_F32, mul_mv_iq2_s_f32, ctx->support_simdgroup_reduction);
|
|
|
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_S_F32, mul_mv_iq1_s_f32, ctx->support_simdgroup_reduction);
|
|
|
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_M_F32, mul_mv_iq1_m_f32, ctx->support_simdgroup_reduction);
|
|
|
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_NL_F32, mul_mv_iq4_nl_f32, ctx->support_simdgroup_reduction);
|
|
|
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_XS_F32, mul_mv_iq4_xs_f32, ctx->support_simdgroup_reduction);
|
|
|
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F32_F32, mul_mv_id_f32_f32, ctx->support_simdgroup_reduction);
|
|
|
- //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F16, mul_mv_id_f16_f16, ctx->support_simdgroup_reduction);
|
|
|
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32, mul_mv_id_f16_f32, ctx->support_simdgroup_reduction);
|
|
|
- //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32_1ROW, mul_mv_id_f16_f32_1row, ctx->support_simdgroup_reduction);
|
|
|
- //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32_L4, mul_mv_id_f16_f32_l4, ctx->support_simdgroup_reduction);
|
|
|
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_0_F32, mul_mv_id_q4_0_f32, ctx->support_simdgroup_reduction);
|
|
|
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_1_F32, mul_mv_id_q4_1_f32, ctx->support_simdgroup_reduction);
|
|
|
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_0_F32, mul_mv_id_q5_0_f32, ctx->support_simdgroup_reduction);
|
|
|
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_1_F32, mul_mv_id_q5_1_f32, ctx->support_simdgroup_reduction);
|
|
|
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q8_0_F32, mul_mv_id_q8_0_f32, ctx->support_simdgroup_reduction);
|
|
|
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q2_K_F32, mul_mv_id_q2_K_f32, ctx->support_simdgroup_reduction);
|
|
|
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q3_K_F32, mul_mv_id_q3_K_f32, ctx->support_simdgroup_reduction);
|
|
|
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_K_F32, mul_mv_id_q4_K_f32, ctx->support_simdgroup_reduction);
|
|
|
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_K_F32, mul_mv_id_q5_K_f32, ctx->support_simdgroup_reduction);
|
|
|
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q6_K_F32, mul_mv_id_q6_K_f32, ctx->support_simdgroup_reduction);
|
|
|
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XXS_F32, mul_mv_id_iq2_xxs_f32, ctx->support_simdgroup_reduction);
|
|
|
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XS_F32, mul_mv_id_iq2_xs_f32, ctx->support_simdgroup_reduction);
|
|
|
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_XXS_F32, mul_mv_id_iq3_xxs_f32, ctx->support_simdgroup_reduction);
|
|
|
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_S_F32, mul_mv_id_iq3_s_f32, ctx->support_simdgroup_reduction);
|
|
|
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_S_F32, mul_mv_id_iq2_s_f32, ctx->support_simdgroup_reduction);
|
|
|
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_S_F32, mul_mv_id_iq1_s_f32, ctx->support_simdgroup_reduction);
|
|
|
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_M_F32, mul_mv_id_iq1_m_f32, ctx->support_simdgroup_reduction);
|
|
|
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_NL_F32, mul_mv_id_iq4_nl_f32, ctx->support_simdgroup_reduction);
|
|
|
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_XS_F32, mul_mv_id_iq4_xs_f32, ctx->support_simdgroup_reduction);
|
|
|
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32, mul_mm_f32_f32, ctx->support_simdgroup_mm);
|
|
|
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32, mul_mm_f16_f32, ctx->support_simdgroup_mm);
|
|
|
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32, mul_mm_q4_0_f32, ctx->support_simdgroup_mm);
|
|
|
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_1_F32, mul_mm_q4_1_f32, ctx->support_simdgroup_mm);
|
|
|
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_0_F32, mul_mm_q5_0_f32, ctx->support_simdgroup_mm);
|
|
|
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_1_F32, mul_mm_q5_1_f32, ctx->support_simdgroup_mm);
|
|
|
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q8_0_F32, mul_mm_q8_0_f32, ctx->support_simdgroup_mm);
|
|
|
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q2_K_F32, mul_mm_q2_K_f32, ctx->support_simdgroup_mm);
|
|
|
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q3_K_F32, mul_mm_q3_K_f32, ctx->support_simdgroup_mm);
|
|
|
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_K_F32, mul_mm_q4_K_f32, ctx->support_simdgroup_mm);
|
|
|
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_K_F32, mul_mm_q5_K_f32, ctx->support_simdgroup_mm);
|
|
|
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q6_K_F32, mul_mm_q6_K_f32, ctx->support_simdgroup_mm);
|
|
|
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XXS_F32, mul_mm_iq2_xxs_f32, ctx->support_simdgroup_mm);
|
|
|
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XS_F32, mul_mm_iq2_xs_f32, ctx->support_simdgroup_mm);
|
|
|
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_XXS_F32, mul_mm_iq3_xxs_f32, ctx->support_simdgroup_mm);
|
|
|
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_S_F32, mul_mm_iq3_s_f32, ctx->support_simdgroup_mm);
|
|
|
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_S_F32, mul_mm_iq2_s_f32, ctx->support_simdgroup_mm);
|
|
|
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_S_F32, mul_mm_iq1_s_f32, ctx->support_simdgroup_mm);
|
|
|
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_M_F32, mul_mm_iq1_m_f32, ctx->support_simdgroup_mm);
|
|
|
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32, mul_mm_iq4_nl_f32, ctx->support_simdgroup_mm);
|
|
|
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32, mul_mm_iq4_xs_f32, ctx->support_simdgroup_mm);
|
|
|
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32, mul_mm_id_f32_f32, ctx->support_simdgroup_mm);
|
|
|
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32, mul_mm_id_f16_f32, ctx->support_simdgroup_mm);
|
|
|
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32, mul_mm_id_q4_0_f32, ctx->support_simdgroup_mm);
|
|
|
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_1_F32, mul_mm_id_q4_1_f32, ctx->support_simdgroup_mm);
|
|
|
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F32, mul_mm_id_q5_0_f32, ctx->support_simdgroup_mm);
|
|
|
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_1_F32, mul_mm_id_q5_1_f32, ctx->support_simdgroup_mm);
|
|
|
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q8_0_F32, mul_mm_id_q8_0_f32, ctx->support_simdgroup_mm);
|
|
|
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q2_K_F32, mul_mm_id_q2_K_f32, ctx->support_simdgroup_mm);
|
|
|
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q3_K_F32, mul_mm_id_q3_K_f32, ctx->support_simdgroup_mm);
|
|
|
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_K_F32, mul_mm_id_q4_K_f32, ctx->support_simdgroup_mm);
|
|
|
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_K_F32, mul_mm_id_q5_K_f32, ctx->support_simdgroup_mm);
|
|
|
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q6_K_F32, mul_mm_id_q6_K_f32, ctx->support_simdgroup_mm);
|
|
|
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F32, mul_mm_id_iq2_xxs_f32, ctx->support_simdgroup_mm);
|
|
|
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F32, mul_mm_id_iq2_xs_f32, ctx->support_simdgroup_mm);
|
|
|
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_XXS_F32, mul_mm_id_iq3_xxs_f32, ctx->support_simdgroup_mm);
|
|
|
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_S_F32, mul_mm_id_iq3_s_f32, ctx->support_simdgroup_mm);
|
|
|
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_S_F32, mul_mm_id_iq2_s_f32, ctx->support_simdgroup_mm);
|
|
|
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F32, mul_mm_id_iq1_s_f32, ctx->support_simdgroup_mm);
|
|
|
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_M_F32, mul_mm_id_iq1_m_f32, ctx->support_simdgroup_mm);
|
|
|
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32, mul_mm_id_iq4_nl_f32, ctx->support_simdgroup_mm);
|
|
|
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32, mul_mm_id_iq4_xs_f32, ctx->support_simdgroup_mm);
|
|
|
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32, mul_mv_f32_f32, support_simdgroup_reduction);
|
|
|
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F16, mul_mv_f16_f16, support_simdgroup_reduction);
|
|
|
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32, mul_mv_f16_f32, support_simdgroup_reduction);
|
|
|
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW, mul_mv_f16_f32_1row, support_simdgroup_reduction);
|
|
|
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_L4, mul_mv_f16_f32_l4, support_simdgroup_reduction);
|
|
|
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_0_F32, mul_mv_q4_0_f32, support_simdgroup_reduction);
|
|
|
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_1_F32, mul_mv_q4_1_f32, support_simdgroup_reduction);
|
|
|
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_0_F32, mul_mv_q5_0_f32, support_simdgroup_reduction);
|
|
|
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_1_F32, mul_mv_q5_1_f32, support_simdgroup_reduction);
|
|
|
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q8_0_F32, mul_mv_q8_0_f32, support_simdgroup_reduction);
|
|
|
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q2_K_F32, mul_mv_q2_K_f32, support_simdgroup_reduction);
|
|
|
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q3_K_F32, mul_mv_q3_K_f32, support_simdgroup_reduction);
|
|
|
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_K_F32, mul_mv_q4_K_f32, support_simdgroup_reduction);
|
|
|
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_K_F32, mul_mv_q5_K_f32, support_simdgroup_reduction);
|
|
|
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q6_K_F32, mul_mv_q6_K_f32, support_simdgroup_reduction);
|
|
|
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XXS_F32, mul_mv_iq2_xxs_f32, support_simdgroup_reduction);
|
|
|
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XS_F32, mul_mv_iq2_xs_f32, support_simdgroup_reduction);
|
|
|
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_XXS_F32, mul_mv_iq3_xxs_f32, support_simdgroup_reduction);
|
|
|
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_S_F32, mul_mv_iq3_s_f32, support_simdgroup_reduction);
|
|
|
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_S_F32, mul_mv_iq2_s_f32, support_simdgroup_reduction);
|
|
|
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_S_F32, mul_mv_iq1_s_f32, support_simdgroup_reduction);
|
|
|
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_M_F32, mul_mv_iq1_m_f32, support_simdgroup_reduction);
|
|
|
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_NL_F32, mul_mv_iq4_nl_f32, support_simdgroup_reduction);
|
|
|
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_XS_F32, mul_mv_iq4_xs_f32, support_simdgroup_reduction);
|
|
|
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F32_F32, mul_mv_id_f32_f32, support_simdgroup_reduction);
|
|
|
+ //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F16, mul_mv_id_f16_f16, support_simdgroup_reduction);
|
|
|
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32, mul_mv_id_f16_f32, support_simdgroup_reduction);
|
|
|
+ //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32_1ROW, mul_mv_id_f16_f32_1row, support_simdgroup_reduction);
|
|
|
+ //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32_L4, mul_mv_id_f16_f32_l4, support_simdgroup_reduction);
|
|
|
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_0_F32, mul_mv_id_q4_0_f32, support_simdgroup_reduction);
|
|
|
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_1_F32, mul_mv_id_q4_1_f32, support_simdgroup_reduction);
|
|
|
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_0_F32, mul_mv_id_q5_0_f32, support_simdgroup_reduction);
|
|
|
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_1_F32, mul_mv_id_q5_1_f32, support_simdgroup_reduction);
|
|
|
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q8_0_F32, mul_mv_id_q8_0_f32, support_simdgroup_reduction);
|
|
|
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q2_K_F32, mul_mv_id_q2_K_f32, support_simdgroup_reduction);
|
|
|
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q3_K_F32, mul_mv_id_q3_K_f32, support_simdgroup_reduction);
|
|
|
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_K_F32, mul_mv_id_q4_K_f32, support_simdgroup_reduction);
|
|
|
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_K_F32, mul_mv_id_q5_K_f32, support_simdgroup_reduction);
|
|
|
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q6_K_F32, mul_mv_id_q6_K_f32, support_simdgroup_reduction);
|
|
|
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XXS_F32, mul_mv_id_iq2_xxs_f32, support_simdgroup_reduction);
|
|
|
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XS_F32, mul_mv_id_iq2_xs_f32, support_simdgroup_reduction);
|
|
|
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_XXS_F32, mul_mv_id_iq3_xxs_f32, support_simdgroup_reduction);
|
|
|
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_S_F32, mul_mv_id_iq3_s_f32, support_simdgroup_reduction);
|
|
|
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_S_F32, mul_mv_id_iq2_s_f32, support_simdgroup_reduction);
|
|
|
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_S_F32, mul_mv_id_iq1_s_f32, support_simdgroup_reduction);
|
|
|
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_M_F32, mul_mv_id_iq1_m_f32, support_simdgroup_reduction);
|
|
|
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_NL_F32, mul_mv_id_iq4_nl_f32, support_simdgroup_reduction);
|
|
|
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_XS_F32, mul_mv_id_iq4_xs_f32, support_simdgroup_reduction);
|
|
|
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32, mul_mm_f32_f32, support_simdgroup_mm);
|
|
|
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32, mul_mm_f16_f32, support_simdgroup_mm);
|
|
|
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32, mul_mm_q4_0_f32, support_simdgroup_mm);
|
|
|
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_1_F32, mul_mm_q4_1_f32, support_simdgroup_mm);
|
|
|
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_0_F32, mul_mm_q5_0_f32, support_simdgroup_mm);
|
|
|
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_1_F32, mul_mm_q5_1_f32, support_simdgroup_mm);
|
|
|
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q8_0_F32, mul_mm_q8_0_f32, support_simdgroup_mm);
|
|
|
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q2_K_F32, mul_mm_q2_K_f32, support_simdgroup_mm);
|
|
|
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q3_K_F32, mul_mm_q3_K_f32, support_simdgroup_mm);
|
|
|
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_K_F32, mul_mm_q4_K_f32, support_simdgroup_mm);
|
|
|
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_K_F32, mul_mm_q5_K_f32, support_simdgroup_mm);
|
|
|
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q6_K_F32, mul_mm_q6_K_f32, support_simdgroup_mm);
|
|
|
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XXS_F32, mul_mm_iq2_xxs_f32, support_simdgroup_mm);
|
|
|
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XS_F32, mul_mm_iq2_xs_f32, support_simdgroup_mm);
|
|
|
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_XXS_F32, mul_mm_iq3_xxs_f32, support_simdgroup_mm);
|
|
|
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_S_F32, mul_mm_iq3_s_f32, support_simdgroup_mm);
|
|
|
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_S_F32, mul_mm_iq2_s_f32, support_simdgroup_mm);
|
|
|
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_S_F32, mul_mm_iq1_s_f32, support_simdgroup_mm);
|
|
|
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_M_F32, mul_mm_iq1_m_f32, support_simdgroup_mm);
|
|
|
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32, mul_mm_iq4_nl_f32, support_simdgroup_mm);
|
|
|
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32, mul_mm_iq4_xs_f32, support_simdgroup_mm);
|
|
|
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32, mul_mm_id_f32_f32, support_simdgroup_mm);
|
|
|
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32, mul_mm_id_f16_f32, support_simdgroup_mm);
|
|
|
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32, mul_mm_id_q4_0_f32, support_simdgroup_mm);
|
|
|
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_1_F32, mul_mm_id_q4_1_f32, support_simdgroup_mm);
|
|
|
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F32, mul_mm_id_q5_0_f32, support_simdgroup_mm);
|
|
|
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_1_F32, mul_mm_id_q5_1_f32, support_simdgroup_mm);
|
|
|
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q8_0_F32, mul_mm_id_q8_0_f32, support_simdgroup_mm);
|
|
|
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q2_K_F32, mul_mm_id_q2_K_f32, support_simdgroup_mm);
|
|
|
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q3_K_F32, mul_mm_id_q3_K_f32, support_simdgroup_mm);
|
|
|
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_K_F32, mul_mm_id_q4_K_f32, support_simdgroup_mm);
|
|
|
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_K_F32, mul_mm_id_q5_K_f32, support_simdgroup_mm);
|
|
|
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q6_K_F32, mul_mm_id_q6_K_f32, support_simdgroup_mm);
|
|
|
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F32, mul_mm_id_iq2_xxs_f32, support_simdgroup_mm);
|
|
|
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F32, mul_mm_id_iq2_xs_f32, support_simdgroup_mm);
|
|
|
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_XXS_F32, mul_mm_id_iq3_xxs_f32, support_simdgroup_mm);
|
|
|
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_S_F32, mul_mm_id_iq3_s_f32, support_simdgroup_mm);
|
|
|
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_S_F32, mul_mm_id_iq2_s_f32, support_simdgroup_mm);
|
|
|
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F32, mul_mm_id_iq1_s_f32, support_simdgroup_mm);
|
|
|
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_M_F32, mul_mm_id_iq1_m_f32, support_simdgroup_mm);
|
|
|
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32, mul_mm_id_iq4_nl_f32, support_simdgroup_mm);
|
|
|
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32, mul_mm_id_iq4_xs_f32, support_simdgroup_mm);
|
|
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NORM_F32, rope_norm_f32, true);
|
|
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NORM_F16, rope_norm_f16, true);
|
|
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F32, rope_neox_f32, true);
|
|
|
@@ -643,14 +692,14 @@ static struct ggml_backend_metal_context * ggml_metal_init(void) {
|
|
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC, argsort_f32_i32_asc, true);
|
|
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_DESC, argsort_f32_i32_desc, true);
|
|
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32, leaky_relu_f32, true);
|
|
|
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64, flash_attn_ext_f16_h64, ctx->support_simdgroup_mm);
|
|
|
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80, flash_attn_ext_f16_h80, ctx->support_simdgroup_mm);
|
|
|
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96, flash_attn_ext_f16_h96, ctx->support_simdgroup_mm);
|
|
|
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112, flash_attn_ext_f16_h112, ctx->support_simdgroup_mm);
|
|
|
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128, flash_attn_ext_f16_h128, ctx->support_simdgroup_mm);
|
|
|
- //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256, flash_attn_ext_f16_h256, ctx->support_simdgroup_mm);
|
|
|
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128, flash_attn_ext_vec_f16_h128, ctx->support_simdgroup_reduction);
|
|
|
- //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256, flash_attn_ext_vec_f16_h256, ctx->support_simdgroup_reduction);
|
|
|
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64, flash_attn_ext_f16_h64, support_simdgroup_mm);
|
|
|
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80, flash_attn_ext_f16_h80, support_simdgroup_mm);
|
|
|
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96, flash_attn_ext_f16_h96, support_simdgroup_mm);
|
|
|
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112, flash_attn_ext_f16_h112, support_simdgroup_mm);
|
|
|
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128, flash_attn_ext_f16_h128, support_simdgroup_mm);
|
|
|
+ //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256, flash_attn_ext_f16_h256, support_simdgroup_mm);
|
|
|
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128, flash_attn_ext_vec_f16_h128, support_simdgroup_reduction);
|
|
|
+ //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256, flash_attn_ext_vec_f16_h256, support_simdgroup_reduction);
|
|
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F16, cpy_f32_f16, true);
|
|
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F32, cpy_f32_f32, true);
|
|
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F16_F16, cpy_f16_f16, true);
|
|
|
@@ -684,7 +733,6 @@ static void ggml_metal_free(struct ggml_backend_metal_context * ctx) {
|
|
|
Block_release(ctx->encode_async);
|
|
|
|
|
|
[ctx->queue release];
|
|
|
- [ctx->device release];
|
|
|
|
|
|
dispatch_release(ctx->d_queue);
|
|
|
|
|
|
@@ -742,13 +790,16 @@ static id<MTLBuffer> ggml_metal_get_buffer(struct ggml_tensor * t, size_t * offs
|
|
|
return nil;
|
|
|
}
|
|
|
|
|
|
-static bool ggml_metal_supports_op(const struct ggml_backend_metal_context * ctx, const struct ggml_tensor * op) {
|
|
|
+static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_context * ctx_dev, const struct ggml_tensor * op) {
|
|
|
for (size_t i = 0, n = 3; i < n; ++i) {
|
|
|
if (op->src[i] != NULL && op->src[i]->type == GGML_TYPE_BF16) {
|
|
|
return false;
|
|
|
}
|
|
|
}
|
|
|
|
|
|
+ const bool support_simdgroup_mm = ctx_dev->support_simdgroup_mm;
|
|
|
+ const bool support_simdgroup_reduction = ctx_dev->support_simdgroup_reduction;
|
|
|
+
|
|
|
switch (op->op) {
|
|
|
case GGML_OP_UNARY:
|
|
|
switch (ggml_get_unary_op(op)) {
|
|
|
@@ -786,7 +837,7 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_context * ctx
|
|
|
case GGML_OP_SOFT_MAX:
|
|
|
case GGML_OP_RMS_NORM:
|
|
|
case GGML_OP_GROUP_NORM:
|
|
|
- return ctx->support_simdgroup_reduction;
|
|
|
+ return support_simdgroup_reduction;
|
|
|
case GGML_OP_NORM:
|
|
|
case GGML_OP_ROPE:
|
|
|
return true;
|
|
|
@@ -812,13 +863,13 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_context * ctx
|
|
|
if (op->src[0]->ne[0] == 256) {
|
|
|
return false;
|
|
|
}
|
|
|
- return ctx->support_simdgroup_mm; // TODO: over-restricted for vec-kernels
|
|
|
+ return support_simdgroup_mm; // TODO: over-restricted for vec-kernels
|
|
|
case GGML_OP_SSM_CONV:
|
|
|
case GGML_OP_SSM_SCAN:
|
|
|
return true;
|
|
|
case GGML_OP_MUL_MAT:
|
|
|
case GGML_OP_MUL_MAT_ID:
|
|
|
- return ctx->support_simdgroup_reduction &&
|
|
|
+ return support_simdgroup_reduction &&
|
|
|
(op->src[0]->type != GGML_TYPE_F32 || op->src[1]->type == GGML_TYPE_F32);
|
|
|
case GGML_OP_CPY:
|
|
|
case GGML_OP_DUP:
|
|
|
@@ -862,9 +913,12 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_context * ctx
|
|
|
}
|
|
|
|
|
|
static void ggml_metal_encode_node(
|
|
|
- struct ggml_backend_metal_context * ctx,
|
|
|
+ ggml_backend_t backend,
|
|
|
int idx,
|
|
|
id<MTLComputeCommandEncoder> encoder) {
|
|
|
+ struct ggml_backend_metal_context * ctx = backend->context;
|
|
|
+ struct ggml_backend_metal_device_context * ctx_dev = backend->device->context;
|
|
|
+
|
|
|
struct ggml_cgraph * gf = ctx->gf;
|
|
|
|
|
|
struct ggml_tensor * node = ggml_graph_node(gf, idx);
|
|
|
@@ -894,7 +948,7 @@ static void ggml_metal_encode_node(
|
|
|
} break;
|
|
|
}
|
|
|
|
|
|
- if (!ggml_metal_supports_op(ctx, dst)) {
|
|
|
+ if (!ggml_metal_supports_op(ctx_dev, dst)) {
|
|
|
GGML_LOG_ERROR("%s: error: unsupported op '%s'\n", __func__, ggml_op_desc(dst));
|
|
|
GGML_ABORT("unsupported op");
|
|
|
}
|
|
|
@@ -967,6 +1021,8 @@ static void ggml_metal_encode_node(
|
|
|
// dst->name);
|
|
|
//}
|
|
|
|
|
|
+ id<MTLDevice> device = ctx_dev->mtl_device;
|
|
|
+
|
|
|
switch (dst->op) {
|
|
|
case GGML_OP_CONCAT:
|
|
|
{
|
|
|
@@ -1675,7 +1731,7 @@ static void ggml_metal_encode_node(
|
|
|
// the numbers below are measured on M2 Ultra for 7B and 13B models
|
|
|
// these numbers do not translate to other devices or model sizes
|
|
|
// TODO: need to find a better approach
|
|
|
- if ([ctx->device.name isEqualToString:@"Apple M2 Ultra"]) {
|
|
|
+ if ([device.name isEqualToString:@"Apple M2 Ultra"]) {
|
|
|
switch (src0t) {
|
|
|
case GGML_TYPE_F16: ne11_mm_min = 2; break;
|
|
|
case GGML_TYPE_Q8_0: ne11_mm_min = 7; break;
|
|
|
@@ -1695,7 +1751,7 @@ static void ggml_metal_encode_node(
|
|
|
|
|
|
// for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs
|
|
|
// AMD GPU and older A-chips will reuse matrix-vector multiplication kernel
|
|
|
- if ([ctx->device supportsFamily:MTLGPUFamilyApple7] &&
|
|
|
+ if ([device supportsFamily:MTLGPUFamilyApple7] &&
|
|
|
!ggml_is_transposed(src0) &&
|
|
|
!ggml_is_transposed(src1) &&
|
|
|
src1t == GGML_TYPE_F32 &&
|
|
|
@@ -1990,7 +2046,7 @@ static void ggml_metal_encode_node(
|
|
|
// ne21 = n_rows
|
|
|
const int dst_rows = ne20*ne21;
|
|
|
const int dst_rows_min = n_as;
|
|
|
- const int dst_rows_max = (ctx->device.maxThreadgroupMemoryLength - 32 - 8192)/4;
|
|
|
+ const int dst_rows_max = (device.maxThreadgroupMemoryLength - 32 - 8192)/4;
|
|
|
|
|
|
// max size of the rowids array in the kernel shared buffer
|
|
|
GGML_ASSERT(dst_rows <= dst_rows_max);
|
|
|
@@ -2001,7 +2057,7 @@ static void ggml_metal_encode_node(
|
|
|
// TODO: for now, always use mat-vec kernels until we figure out how to improve the
|
|
|
// indirect matrix multiplication
|
|
|
// !!!
|
|
|
- if ([ctx->device supportsFamily:MTLGPUFamilyApple7] &&
|
|
|
+ if ([device supportsFamily:MTLGPUFamilyApple7] &&
|
|
|
ne00 % 32 == 0 && ne00 >= 64 &&
|
|
|
dst_rows > dst_rows_min) {
|
|
|
|
|
|
@@ -2840,7 +2896,7 @@ static void ggml_metal_encode_node(
|
|
|
|
|
|
while (true) {
|
|
|
const size_t smem = nqptg*(ne00 + 2*nsgmax*(ncpsg + nqptg))*(sizeof(float)/2);
|
|
|
- if (smem > ctx->device.maxThreadgroupMemoryLength) {
|
|
|
+ if (smem > device.maxThreadgroupMemoryLength) {
|
|
|
break;
|
|
|
}
|
|
|
nsgmax *= 2;
|
|
|
@@ -2852,8 +2908,8 @@ static void ggml_metal_encode_node(
|
|
|
|
|
|
const size_t smem = nqptg*(ne00 + 2*nsg*(ncpsg + nqptg))*(sizeof(float)/2);
|
|
|
|
|
|
- //printf("smem: %zu, max: %zu\n", smem, ctx->device.maxThreadgroupMemoryLength);
|
|
|
- GGML_ASSERT(smem <= ctx->device.maxThreadgroupMemoryLength);
|
|
|
+ //printf("smem: %zu, max: %zu\n", smem, device.maxThreadgroupMemoryLength);
|
|
|
+ GGML_ASSERT(smem <= device.maxThreadgroupMemoryLength);
|
|
|
|
|
|
[encoder setThreadgroupMemoryLength:GGML_PAD(smem, 16) atIndex:0];
|
|
|
|
|
|
@@ -2878,8 +2934,8 @@ static void ggml_metal_encode_node(
|
|
|
|
|
|
const size_t smem = (nqptg*(ne00 + 2*nsg*(ncpsg + nqptg)) + nsg*ne00)*(sizeof(float)/2);
|
|
|
|
|
|
- //printf("smem: %zu, max: %zu\n", smem, ctx->device.maxThreadgroupMemoryLength);
|
|
|
- GGML_ASSERT(smem <= ctx->device.maxThreadgroupMemoryLength);
|
|
|
+ //printf("smem: %zu, max: %zu\n", smem, device.maxThreadgroupMemoryLength);
|
|
|
+ GGML_ASSERT(smem <= device.maxThreadgroupMemoryLength);
|
|
|
[encoder setThreadgroupMemoryLength:GGML_PAD(smem, 16) atIndex:0];
|
|
|
|
|
|
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + nqptg - 1)/nqptg, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)];
|
|
|
@@ -2954,8 +3010,11 @@ static void ggml_metal_encode_node(
|
|
|
}
|
|
|
|
|
|
static enum ggml_status ggml_metal_graph_compute(
|
|
|
- struct ggml_backend_metal_context * ctx,
|
|
|
- struct ggml_cgraph * gf) {
|
|
|
+ ggml_backend_t backend,
|
|
|
+ struct ggml_cgraph * gf) {
|
|
|
+ struct ggml_backend_metal_context * ctx = backend->context;
|
|
|
+ struct ggml_backend_metal_device_context * ctx_dev = backend->device->context;
|
|
|
+
|
|
|
// number of nodes encoded by the main thread (empirically determined)
|
|
|
const int n_main = 128;
|
|
|
|
|
|
@@ -2983,7 +3042,7 @@ static enum ggml_status ggml_metal_graph_compute(
|
|
|
|
|
|
if (!ctx->capture_started) {
|
|
|
// create capture scope
|
|
|
- ctx->capture_scope = [[MTLCaptureManager sharedCaptureManager] newCaptureScopeWithDevice:ctx->device];
|
|
|
+ ctx->capture_scope = [[MTLCaptureManager sharedCaptureManager] newCaptureScopeWithDevice:ctx_dev->mtl_device];
|
|
|
|
|
|
MTLCaptureDescriptor * descriptor = [MTLCaptureDescriptor new];
|
|
|
descriptor.captureObject = ctx->capture_scope;
|
|
|
@@ -3087,31 +3146,6 @@ static enum ggml_status ggml_metal_graph_compute(
|
|
|
|
|
|
// backend interface
|
|
|
|
|
|
-// default buffer
|
|
|
-static id<MTLDevice> g_backend_device = nil;
|
|
|
-static int g_backend_device_ref_count = 0;
|
|
|
-
|
|
|
-static id<MTLDevice> ggml_backend_metal_get_device(void) {
|
|
|
- if (g_backend_device == nil) {
|
|
|
- g_backend_device = MTLCreateSystemDefaultDevice();
|
|
|
- }
|
|
|
-
|
|
|
- g_backend_device_ref_count++;
|
|
|
-
|
|
|
- return g_backend_device;
|
|
|
-}
|
|
|
-
|
|
|
-static void ggml_backend_metal_free_device(void) {
|
|
|
- assert(g_backend_device_ref_count > 0);
|
|
|
-
|
|
|
- g_backend_device_ref_count--;
|
|
|
-
|
|
|
- if (g_backend_device_ref_count == 0) {
|
|
|
- [g_backend_device release];
|
|
|
- g_backend_device = nil;
|
|
|
- }
|
|
|
-}
|
|
|
-
|
|
|
static const char * ggml_backend_metal_buffer_get_name(ggml_backend_buffer_t buffer) {
|
|
|
return "Metal";
|
|
|
|
|
|
@@ -3124,7 +3158,7 @@ static void ggml_backend_metal_buffer_free_buffer(ggml_backend_buffer_t buffer)
|
|
|
for (int i = 0; i < ctx->n_buffers; i++) {
|
|
|
[ctx->buffers[i].metal release];
|
|
|
}
|
|
|
- ggml_backend_metal_free_device();
|
|
|
+ ggml_backend_metal_device_rel(buffer->buft->device->context);
|
|
|
|
|
|
if (ctx->owned) {
|
|
|
#if TARGET_OS_OSX
|
|
|
@@ -3227,7 +3261,7 @@ static ggml_backend_buffer_t ggml_backend_metal_buffer_type_alloc_buffer(ggml_ba
|
|
|
size_aligned += (size_page - (size_aligned % size_page));
|
|
|
}
|
|
|
|
|
|
- id<MTLDevice> device = ggml_backend_metal_get_device();
|
|
|
+ id<MTLDevice> device = ggml_backend_metal_device_acq(buft->device->context);
|
|
|
|
|
|
ctx->all_data = ggml_metal_host_malloc(size_aligned);
|
|
|
ctx->all_size = size_aligned;
|
|
|
@@ -3241,16 +3275,16 @@ static ggml_backend_buffer_t ggml_backend_metal_buffer_type_alloc_buffer(ggml_ba
|
|
|
|
|
|
if (size_aligned > 0) {
|
|
|
ctx->buffers[0].metal = [device newBufferWithBytesNoCopy:ctx->all_data
|
|
|
- length:size_aligned
|
|
|
- options:MTLResourceStorageModeShared
|
|
|
- deallocator:nil];
|
|
|
+ length:size_aligned
|
|
|
+ options:MTLResourceStorageModeShared
|
|
|
+ deallocator:nil];
|
|
|
}
|
|
|
}
|
|
|
|
|
|
if (size_aligned > 0 && (ctx->all_data == NULL || ctx->buffers[0].metal == nil)) {
|
|
|
GGML_LOG_ERROR("%s: error: failed to allocate buffer, size = %8.2f MiB\n", __func__, size_aligned / 1024.0 / 1024.0);
|
|
|
free(ctx);
|
|
|
- ggml_backend_metal_free_device();
|
|
|
+ ggml_backend_metal_device_rel(buft->device->context);
|
|
|
return NULL;
|
|
|
}
|
|
|
|
|
|
@@ -3265,9 +3299,9 @@ static size_t ggml_backend_metal_buffer_type_get_alignment(ggml_backend_buffer_t
|
|
|
}
|
|
|
|
|
|
static size_t ggml_backend_metal_buffer_type_get_max_size(ggml_backend_buffer_type_t buft) {
|
|
|
- id<MTLDevice> device = ggml_backend_metal_get_device();
|
|
|
- size_t max_size = device.maxBufferLength;
|
|
|
- ggml_backend_metal_free_device();
|
|
|
+ id<MTLDevice> device = ggml_backend_metal_device_acq(buft->device->context);
|
|
|
+ const size_t max_size = device.maxBufferLength;
|
|
|
+ ggml_backend_metal_device_rel(buft->device->context);
|
|
|
|
|
|
return max_size;
|
|
|
|
|
|
@@ -3290,15 +3324,14 @@ ggml_backend_buffer_type_t ggml_backend_metal_buffer_type(void) {
|
|
|
/* .get_alloc_size = */ NULL, // defaults to ggml_nbytes
|
|
|
/* .is_host = */ ggml_backend_metal_buffer_type_is_host,
|
|
|
},
|
|
|
- /* .device = */ NULL,
|
|
|
+ /* .device = */ &g_ggml_backend_metal_device,
|
|
|
/* .context = */ NULL,
|
|
|
};
|
|
|
|
|
|
return &ggml_backend_buffer_type_metal;
|
|
|
}
|
|
|
|
|
|
-// buffer from ptr
|
|
|
-
|
|
|
+// TODO: obsoleted by ggml_backend_metal_device_buffer_from_ptr
|
|
|
ggml_backend_buffer_t ggml_backend_metal_buffer_from_ptr(void * data, size_t size, size_t max_size) {
|
|
|
struct ggml_backend_metal_buffer_context * ctx = calloc(1, sizeof(struct ggml_backend_metal_buffer_context));
|
|
|
|
|
|
@@ -3321,7 +3354,7 @@ ggml_backend_buffer_t ggml_backend_metal_buffer_from_ptr(void * data, size_t siz
|
|
|
size_aligned += (size_page - (size_aligned % size_page));
|
|
|
}
|
|
|
|
|
|
- id<MTLDevice> device = ggml_backend_metal_get_device();
|
|
|
+ id<MTLDevice> device = ggml_backend_metal_device_acq(&g_ggml_ctx_dev_main);
|
|
|
|
|
|
// the buffer fits into the max buffer size allowed by the device
|
|
|
if (size_aligned <= device.maxBufferLength) {
|
|
|
@@ -3386,8 +3419,12 @@ static const char * ggml_backend_metal_name(ggml_backend_t backend) {
|
|
|
}
|
|
|
|
|
|
static void ggml_backend_metal_free(ggml_backend_t backend) {
|
|
|
- struct ggml_backend_metal_context * ctx = (struct ggml_backend_metal_context *)backend->context;
|
|
|
+ struct ggml_backend_metal_context * ctx = backend->context;
|
|
|
+ struct ggml_backend_metal_device_context * ctx_dev = backend->device->context;
|
|
|
+
|
|
|
+ ggml_backend_metal_device_rel(ctx_dev);
|
|
|
ggml_metal_free(ctx);
|
|
|
+
|
|
|
free(backend);
|
|
|
}
|
|
|
|
|
|
@@ -3398,21 +3435,7 @@ static ggml_backend_buffer_type_t ggml_backend_metal_get_default_buffer_type(ggm
|
|
|
}
|
|
|
|
|
|
static enum ggml_status ggml_backend_metal_graph_compute(ggml_backend_t backend, struct ggml_cgraph * cgraph) {
|
|
|
- struct ggml_backend_metal_context * metal_ctx = (struct ggml_backend_metal_context *)backend->context;
|
|
|
-
|
|
|
- return ggml_metal_graph_compute(metal_ctx, cgraph);
|
|
|
-}
|
|
|
-
|
|
|
-static bool ggml_backend_metal_supports_op(ggml_backend_t backend, const struct ggml_tensor * op) {
|
|
|
- struct ggml_backend_metal_context * metal_ctx = (struct ggml_backend_metal_context *)backend->context;
|
|
|
-
|
|
|
- return ggml_metal_supports_op(metal_ctx, op);
|
|
|
-}
|
|
|
-
|
|
|
-static bool ggml_backend_metal_supports_buft(ggml_backend_t backend, ggml_backend_buffer_type_t buft) {
|
|
|
- return buft->iface.get_name == ggml_backend_metal_buffer_type_get_name;
|
|
|
-
|
|
|
- UNUSED(backend);
|
|
|
+ return ggml_metal_graph_compute(backend, cgraph);
|
|
|
}
|
|
|
|
|
|
static void ggml_backend_metal_set_n_cb(ggml_backend_t backend, int n_cb) {
|
|
|
@@ -3459,7 +3482,7 @@ static void ggml_backend_metal_set_n_cb(ggml_backend_t backend, int n_cb) {
|
|
|
[encoder pushDebugGroup:[NSString stringWithCString:ggml_op_desc(ggml_graph_node(ctx->gf, idx)) encoding:NSUTF8StringEncoding]];
|
|
|
}
|
|
|
|
|
|
- ggml_metal_encode_node(ctx, idx, encoder);
|
|
|
+ ggml_metal_encode_node(backend, idx, encoder);
|
|
|
|
|
|
if (should_capture) {
|
|
|
[encoder popDebugGroup];
|
|
|
@@ -3487,8 +3510,8 @@ static struct ggml_backend_i ggml_backend_metal_i = {
|
|
|
/* .graph_plan_update = */ NULL,
|
|
|
/* .graph_plan_compute = */ NULL,
|
|
|
/* .graph_compute = */ ggml_backend_metal_graph_compute,
|
|
|
- /* .supports_op = */ ggml_backend_metal_supports_op,
|
|
|
- /* .supports_buft = */ ggml_backend_metal_supports_buft,
|
|
|
+ /* .supports_op = */ NULL,
|
|
|
+ /* .supports_buft = */ NULL,
|
|
|
/* .offload_op = */ NULL,
|
|
|
/* .event_record = */ NULL,
|
|
|
/* .event_wait = */ NULL,
|
|
|
@@ -3499,8 +3522,11 @@ static ggml_guid_t ggml_backend_metal_guid(void) {
|
|
|
return &guid;
|
|
|
}
|
|
|
|
|
|
+// TODO: remove in the future
|
|
|
ggml_backend_t ggml_backend_metal_init(void) {
|
|
|
- struct ggml_backend_metal_context * ctx = ggml_metal_init();
|
|
|
+ ggml_backend_dev_t dev = ggml_backend_reg_dev_get(ggml_backend_metal_reg(), 0);
|
|
|
+
|
|
|
+ struct ggml_backend_metal_context * ctx = ggml_metal_init(dev);
|
|
|
if (ctx == NULL) {
|
|
|
GGML_LOG_ERROR("%s: error: failed to allocate context\n", __func__);
|
|
|
return NULL;
|
|
|
@@ -3511,7 +3537,7 @@ ggml_backend_t ggml_backend_metal_init(void) {
|
|
|
*backend = (struct ggml_backend) {
|
|
|
/* .guid = */ ggml_backend_metal_guid(),
|
|
|
/* .interface = */ ggml_backend_metal_i,
|
|
|
- /* .device = */ NULL,
|
|
|
+ /* .device = */ dev,
|
|
|
/* .context = */ ctx,
|
|
|
};
|
|
|
|
|
|
@@ -3536,9 +3562,9 @@ void ggml_backend_metal_set_abort_callback(ggml_backend_t backend, ggml_abort_ca
|
|
|
bool ggml_backend_metal_supports_family(ggml_backend_t backend, int family) {
|
|
|
GGML_ASSERT(ggml_backend_is_metal(backend));
|
|
|
|
|
|
- struct ggml_backend_metal_context * ctx = (struct ggml_backend_metal_context *)backend->context;
|
|
|
+ struct ggml_backend_metal_device_context * ctx_dev = backend->device->context;
|
|
|
|
|
|
- return [ctx->device supportsFamily:(MTLGPUFamilyApple1 + family - 1)];
|
|
|
+ return [ctx_dev->mtl_device supportsFamily:(MTLGPUFamilyApple1 + family - 1)];
|
|
|
}
|
|
|
|
|
|
void ggml_backend_metal_capture_next_compute(ggml_backend_t backend) {
|
|
|
@@ -3548,11 +3574,246 @@ void ggml_backend_metal_capture_next_compute(ggml_backend_t backend) {
|
|
|
ctx->capture_next_compute = true;
|
|
|
}
|
|
|
|
|
|
-ggml_backend_t ggml_backend_reg_metal_init(const char * params, void * user_data); // silence warning
|
|
|
+// backend device
|
|
|
+
|
|
|
+static const char * ggml_backend_metal_device_get_name(ggml_backend_dev_t dev) {
|
|
|
+ return "Metal";
|
|
|
|
|
|
-ggml_backend_t ggml_backend_reg_metal_init(const char * params, void * user_data) {
|
|
|
- return ggml_backend_metal_init();
|
|
|
+ GGML_UNUSED(dev);
|
|
|
+}
|
|
|
+
|
|
|
+static const char * ggml_backend_metal_device_get_description(ggml_backend_dev_t dev) {
|
|
|
+ // acq/rel just to populate ctx->name in case it hasn't been done yet
|
|
|
+ struct ggml_backend_metal_device_context * ctx_dev = (struct ggml_backend_metal_device_context *)dev->context;
|
|
|
+ ggml_backend_metal_device_acq(ctx_dev);
|
|
|
+ ggml_backend_metal_device_rel(ctx_dev);
|
|
|
+
|
|
|
+ return ctx_dev->name;
|
|
|
+}
|
|
|
+
|
|
|
+static void ggml_backend_metal_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) {
|
|
|
+ if (@available(macOS 10.12, iOS 16.0, *)) {
|
|
|
+ struct ggml_backend_metal_device_context * ctx_dev = (struct ggml_backend_metal_device_context *)dev->context;
|
|
|
+ id<MTLDevice> device = ggml_backend_metal_device_acq(ctx_dev);
|
|
|
+
|
|
|
+ *total = device.recommendedMaxWorkingSetSize;
|
|
|
+ *free = *total - device.currentAllocatedSize;
|
|
|
+
|
|
|
+ ggml_backend_metal_device_rel(ctx_dev);
|
|
|
+ } else {
|
|
|
+ *free = 1;
|
|
|
+ *total = 1;
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+static enum ggml_backend_dev_type ggml_backend_metal_device_get_type(ggml_backend_dev_t dev) {
|
|
|
+ return GGML_BACKEND_DEVICE_TYPE_GPU_FULL;
|
|
|
+
|
|
|
+ GGML_UNUSED(dev);
|
|
|
+}
|
|
|
+
|
|
|
+static void ggml_backend_metal_device_get_props(ggml_backend_dev_t dev, struct ggml_backend_dev_props * props) {
|
|
|
+ props->name = ggml_backend_metal_device_get_name(dev);
|
|
|
+ props->description = ggml_backend_metal_device_get_description(dev);
|
|
|
+ props->type = ggml_backend_metal_device_get_type(dev);
|
|
|
+ ggml_backend_metal_device_get_memory(dev, &props->memory_free, &props->memory_total);
|
|
|
+ props->caps = (struct ggml_backend_dev_caps) {
|
|
|
+ /* .async = */ false,
|
|
|
+ /* .host_buffer = */ false,
|
|
|
+ /* .buffer_from_host_ptr = */ true,
|
|
|
+ /* .events = */ false,
|
|
|
+ };
|
|
|
+}
|
|
|
+
|
|
|
+static ggml_backend_t ggml_backend_metal_device_init(ggml_backend_dev_t dev, const char * params) {
|
|
|
+ struct ggml_backend_metal_context * ctx = ggml_metal_init(dev);
|
|
|
+ if (ctx == NULL) {
|
|
|
+ GGML_LOG_ERROR("%s: error: failed to allocate context\n", __func__);
|
|
|
+ return NULL;
|
|
|
+ }
|
|
|
+
|
|
|
+ ggml_backend_t backend = malloc(sizeof(struct ggml_backend));
|
|
|
+
|
|
|
+ *backend = (struct ggml_backend) {
|
|
|
+ /* .guid = */ ggml_backend_metal_guid(),
|
|
|
+ /* .interface = */ ggml_backend_metal_i,
|
|
|
+ /* .device = */ dev,
|
|
|
+ /* .context = */ ctx,
|
|
|
+ };
|
|
|
+
|
|
|
+ ggml_backend_metal_set_n_cb(backend, 1);
|
|
|
+
|
|
|
+ return backend;
|
|
|
|
|
|
GGML_UNUSED(params);
|
|
|
- GGML_UNUSED(user_data);
|
|
|
+}
|
|
|
+
|
|
|
+static ggml_backend_buffer_type_t ggml_backend_metal_device_get_buffer_type(ggml_backend_dev_t dev) {
|
|
|
+ return ggml_backend_metal_buffer_type();
|
|
|
+
|
|
|
+ GGML_UNUSED(dev);
|
|
|
+}
|
|
|
+
|
|
|
+static ggml_backend_buffer_t ggml_backend_metal_device_buffer_from_ptr(ggml_backend_dev_t dev, void * ptr, size_t size, size_t max_tensor_size) {
|
|
|
+ struct ggml_backend_metal_buffer_context * ctx = calloc(1, sizeof(struct ggml_backend_metal_buffer_context));
|
|
|
+
|
|
|
+ ctx->all_data = ptr;
|
|
|
+ ctx->all_size = size;
|
|
|
+ ctx->owned = false;
|
|
|
+ ctx->n_buffers = 0;
|
|
|
+
|
|
|
+ const size_t size_page = sysconf(_SC_PAGESIZE);
|
|
|
+
|
|
|
+ // page-align the data ptr
|
|
|
+ {
|
|
|
+ const uintptr_t offs = (uintptr_t) ptr % size_page;
|
|
|
+ ptr = (void *) ((char *) ptr - offs);
|
|
|
+ size += offs;
|
|
|
+ }
|
|
|
+
|
|
|
+ size_t size_aligned = size;
|
|
|
+ if ((size_aligned % size_page) != 0) {
|
|
|
+ size_aligned += (size_page - (size_aligned % size_page));
|
|
|
+ }
|
|
|
+
|
|
|
+ struct ggml_backend_metal_device_context * ctx_dev = (struct ggml_backend_metal_device_context *)dev->context;
|
|
|
+ id<MTLDevice> device = ggml_backend_metal_device_acq(ctx_dev);
|
|
|
+
|
|
|
+ // the buffer fits into the max buffer size allowed by the device
|
|
|
+ if (size_aligned <= device.maxBufferLength) {
|
|
|
+ ctx->buffers[ctx->n_buffers].data = ptr;
|
|
|
+ ctx->buffers[ctx->n_buffers].size = size;
|
|
|
+ ctx->buffers[ctx->n_buffers].metal = nil;
|
|
|
+
|
|
|
+ if (size_aligned > 0) {
|
|
|
+ ctx->buffers[ctx->n_buffers].metal = [device newBufferWithBytesNoCopy:ptr length:size_aligned options:MTLResourceStorageModeShared deallocator:nil];
|
|
|
+
|
|
|
+ if (ctx->buffers[ctx->n_buffers].metal == nil) {
|
|
|
+ GGML_LOG_ERROR("%s: error: failed to allocate buffer, size = %8.2f MiB\n", __func__, size_aligned / 1024.0 / 1024.0);
|
|
|
+ return false;
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ ggml_backend_metal_log_allocated_size(device, size_aligned);
|
|
|
+
|
|
|
+ ++ctx->n_buffers;
|
|
|
+ } else {
|
|
|
+ // this overlap between the views will guarantee that the tensor with the maximum size will fully fit into
|
|
|
+ // one of the views
|
|
|
+ const size_t size_ovlp = ((max_tensor_size + size_page - 1) / size_page + 1) * size_page; // round-up 2 pages just in case
|
|
|
+ const size_t size_step = device.maxBufferLength - size_ovlp;
|
|
|
+ const size_t size_view = device.maxBufferLength;
|
|
|
+
|
|
|
+ for (size_t i = 0; i < size; i += size_step) {
|
|
|
+ const size_t size_step_aligned = (i + size_view <= size) ? size_view : (size_aligned - i);
|
|
|
+
|
|
|
+ ctx->buffers[ctx->n_buffers].data = (void *) ((uint8_t *) ptr + i);
|
|
|
+ ctx->buffers[ctx->n_buffers].size = size_step_aligned;
|
|
|
+ ctx->buffers[ctx->n_buffers].metal = nil;
|
|
|
+
|
|
|
+ if (size_step_aligned > 0) {
|
|
|
+ ctx->buffers[ctx->n_buffers].metal = [device newBufferWithBytesNoCopy:(void *) ((uint8_t *) ptr + i) length:size_step_aligned options:MTLResourceStorageModeShared deallocator:nil];
|
|
|
+
|
|
|
+ if (ctx->buffers[ctx->n_buffers].metal == nil) {
|
|
|
+ GGML_LOG_ERROR("%s: error: failed to allocate buffer, size = %8.2f MiB\n", __func__, size_step_aligned / 1024.0 / 1024.0);
|
|
|
+ return false;
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ ggml_backend_metal_log_allocated_size(device, size_step_aligned);
|
|
|
+
|
|
|
+ if (i + size_step < size) {
|
|
|
+ GGML_LOG_INFO("\n");
|
|
|
+ }
|
|
|
+
|
|
|
+ ++ctx->n_buffers;
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ return ggml_backend_buffer_init(ggml_backend_metal_buffer_type(), ggml_backend_metal_buffer_i, ctx, size);
|
|
|
+}
|
|
|
+
|
|
|
+static bool ggml_backend_metal_device_supports_op(ggml_backend_dev_t dev, const struct ggml_tensor * op) {
|
|
|
+ struct ggml_backend_metal_device_context * ctx_dev = dev->context;
|
|
|
+
|
|
|
+ return ggml_metal_supports_op(ctx_dev, op);
|
|
|
+}
|
|
|
+
|
|
|
+static bool ggml_backend_metal_device_supports_buft(ggml_backend_dev_t dev, ggml_backend_buffer_type_t buft) {
|
|
|
+ return buft->iface.get_name == ggml_backend_metal_buffer_type_get_name;
|
|
|
+
|
|
|
+ UNUSED(dev);
|
|
|
+}
|
|
|
+
|
|
|
+static bool ggml_backend_metal_device_offload_op(ggml_backend_dev_t dev, const struct ggml_tensor * op) {
|
|
|
+ return false;
|
|
|
+
|
|
|
+ GGML_UNUSED(dev);
|
|
|
+ GGML_UNUSED(op);
|
|
|
+}
|
|
|
+
|
|
|
+static struct ggml_backend_device_i ggml_backend_metal_device_i = {
|
|
|
+ /* .get_name = */ ggml_backend_metal_device_get_name,
|
|
|
+ /* .get_description = */ ggml_backend_metal_device_get_description,
|
|
|
+ /* .get_memory = */ ggml_backend_metal_device_get_memory,
|
|
|
+ /* .get_type = */ ggml_backend_metal_device_get_type,
|
|
|
+ /* .get_props = */ ggml_backend_metal_device_get_props,
|
|
|
+ /* .init_backend = */ ggml_backend_metal_device_init,
|
|
|
+ /* .get_buffer_type = */ ggml_backend_metal_device_get_buffer_type,
|
|
|
+ /* .get_host_buffer_type = */ NULL,
|
|
|
+ /* .buffer_from_host_ptr = */ ggml_backend_metal_device_buffer_from_ptr,
|
|
|
+ /* .supports_op = */ ggml_backend_metal_device_supports_op,
|
|
|
+ /* .supports_buft = */ ggml_backend_metal_device_supports_buft,
|
|
|
+ /* .offload_op = */ ggml_backend_metal_device_offload_op,
|
|
|
+ /* .event_new = */ NULL,
|
|
|
+ /* .event_free = */ NULL,
|
|
|
+ /* .event_synchronize = */ NULL,
|
|
|
+};
|
|
|
+
|
|
|
+// backend registry
|
|
|
+
|
|
|
+static const char * ggml_backend_metal_reg_get_name(ggml_backend_reg_t reg) {
|
|
|
+ return "Metal";
|
|
|
+
|
|
|
+ GGML_UNUSED(reg);
|
|
|
+}
|
|
|
+
|
|
|
+static size_t ggml_backend_metal_reg_device_count(ggml_backend_reg_t reg) {
|
|
|
+ return 1;
|
|
|
+
|
|
|
+ GGML_UNUSED(reg);
|
|
|
+}
|
|
|
+
|
|
|
+static ggml_backend_dev_t ggml_backend_metal_reg_device_get(ggml_backend_reg_t reg, size_t index) {
|
|
|
+ GGML_ASSERT(index == 0);
|
|
|
+
|
|
|
+ return &g_ggml_backend_metal_device;
|
|
|
+
|
|
|
+ GGML_UNUSED(reg);
|
|
|
+ GGML_UNUSED(index);
|
|
|
+}
|
|
|
+
|
|
|
+static struct ggml_backend_reg_i ggml_backend_metal_reg_i = {
|
|
|
+ /* .get_name = */ ggml_backend_metal_reg_get_name,
|
|
|
+ /* .device_count = */ ggml_backend_metal_reg_device_count,
|
|
|
+ /* .device_get = */ ggml_backend_metal_reg_device_get,
|
|
|
+ /* .get_proc_address = */ NULL,
|
|
|
+};
|
|
|
+
|
|
|
+ggml_backend_reg_t ggml_backend_metal_reg(void) {
|
|
|
+ // TODO: make this thread-safe somehow?
|
|
|
+ {
|
|
|
+ g_ggml_backend_metal_reg = (struct ggml_backend_reg) {
|
|
|
+ /* .iface = */ ggml_backend_metal_reg_i,
|
|
|
+ /* .context = */ NULL,
|
|
|
+ };
|
|
|
+
|
|
|
+ g_ggml_backend_metal_device = (struct ggml_backend_device) {
|
|
|
+ /* .iface = */ ggml_backend_metal_device_i,
|
|
|
+ /* .reg = */ &g_ggml_backend_metal_reg,
|
|
|
+ /* .context = */ &g_ggml_ctx_dev_main,
|
|
|
+ };
|
|
|
+ }
|
|
|
+
|
|
|
+ return &g_ggml_backend_metal_reg;
|
|
|
}
|