|
|
@@ -44,8 +44,8 @@ static struct ggml_backend_device g_ggml_backend_metal_device;
|
|
|
// note: assumes single GPU device - the default one
|
|
|
// TODO: support multiple GPU devices
|
|
|
static struct ggml_backend_metal_device_context {
|
|
|
- id<MTLDevice> mtl_device;
|
|
|
- int mtl_device_ref_count;
|
|
|
+ id<MTLDevice> mtl_device;
|
|
|
+ int mtl_device_ref_count;
|
|
|
id<MTLLibrary> mtl_library;
|
|
|
|
|
|
bool has_simdgroup_reduction;
|
|
|
@@ -490,7 +490,259 @@ enum ggml_metal_kernel_type {
|
|
|
GGML_METAL_KERNEL_TYPE_COUNT
|
|
|
};
|
|
|
|
|
|
+//
|
|
|
+// ggml_metal_heap
|
|
|
+//
|
|
|
+
|
|
|
+struct ggml_metal_heap {
|
|
|
+ // number of times the heap was unused
|
|
|
+ int n_unused;
|
|
|
+
|
|
|
+ // total number of buffer allocations in this heap across all computes
|
|
|
+ int64_t n_alloc;
|
|
|
+
|
|
|
+ // current offset in the heap - we reset this after each node in order to reuse the memory
|
|
|
+ size_t offs;
|
|
|
+
|
|
|
+ // the currently allocated MTLBuffer objects in this heap
|
|
|
+ id<MTLHeap> obj;
|
|
|
+
|
|
|
+ NSMutableArray * bufs;
|
|
|
+};
|
|
|
+
|
|
|
+static struct ggml_metal_heap * ggml_metal_heap_init(id<MTLDevice> device, size_t size) {
|
|
|
+ struct ggml_metal_heap * heap = calloc(1, sizeof(struct ggml_metal_heap));
|
|
|
+
|
|
|
+ MTLHeapDescriptor * desc = [[MTLHeapDescriptor alloc] init];
|
|
|
+ desc.storageMode = MTLStorageModePrivate;
|
|
|
+ desc.cpuCacheMode = MTLCPUCacheModeDefaultCache;
|
|
|
+ desc.type = MTLHeapTypePlacement;
|
|
|
+ desc.size = size;
|
|
|
+
|
|
|
+ heap->n_unused = 0;
|
|
|
+ heap->n_alloc = 0;
|
|
|
+
|
|
|
+ heap->obj = [device newHeapWithDescriptor:desc];
|
|
|
+ if (!heap->obj) {
|
|
|
+ GGML_LOG_ERROR("%s: error: failed to create MTLHeap with size %zu\n", __func__, size);
|
|
|
+
|
|
|
+ free(heap);
|
|
|
+
|
|
|
+ return false;
|
|
|
+ }
|
|
|
+
|
|
|
+ [desc release];
|
|
|
+
|
|
|
+ heap->bufs = [[NSMutableArray alloc] init];
|
|
|
+
|
|
|
+ return heap;
|
|
|
+}
|
|
|
+
|
|
|
+static void ggml_metal_heap_reset(struct ggml_metal_heap * heap) {
|
|
|
+ heap->offs = 0;
|
|
|
+
|
|
|
+ // count how many graph computes the heap ended up being unused
|
|
|
+ if ([heap->bufs count] > 0) {
|
|
|
+ heap->n_unused = 0;
|
|
|
+ } else {
|
|
|
+ heap->n_unused++;
|
|
|
+ }
|
|
|
+
|
|
|
+ for (id<MTLBuffer> buf in heap->bufs) {
|
|
|
+ [buf release];
|
|
|
+ }
|
|
|
+ [heap->bufs removeAllObjects];
|
|
|
+
|
|
|
+ // tell the OS that it can reuse this memory if needed
|
|
|
+ // ref: https://developer.apple.com/documentation/metal/mtlpurgeablestate?language=objc
|
|
|
+ [heap->obj setPurgeableState:MTLPurgeableStateVolatile];
|
|
|
+}
|
|
|
+
|
|
|
+static void ggml_metal_heap_free(struct ggml_metal_heap * heap) {
|
|
|
+ if (heap == nil) {
|
|
|
+ return;
|
|
|
+ }
|
|
|
+
|
|
|
+ ggml_metal_heap_reset(heap);
|
|
|
+
|
|
|
+ [heap->obj release];
|
|
|
+ [heap->bufs release];
|
|
|
+
|
|
|
+ free(heap);
|
|
|
+}
|
|
|
+
|
|
|
+@interface ggml_metal_heap_ptr : NSObject
|
|
|
+
|
|
|
+@property (nonatomic, assign) struct ggml_metal_heap * data;
|
|
|
+
|
|
|
+@end
|
|
|
+
|
|
|
+@implementation ggml_metal_heap_ptr
|
|
|
+@end
|
|
|
+
|
|
|
+//
|
|
|
+// ggml_metal_mem_pool
|
|
|
+//
|
|
|
+
|
|
|
+struct ggml_metal_mem_pool {
|
|
|
+ id<MTLDevice> device;
|
|
|
+
|
|
|
+ int n_heaps; // total number of heaps ever created (including those that were removed)
|
|
|
+
|
|
|
+ NSMutableArray * heaps;
|
|
|
+ NSMutableArray * heaps_to_remove;
|
|
|
+};
|
|
|
+
|
|
|
+static struct ggml_metal_mem_pool * ggml_metal_mem_pool_init(void) {
|
|
|
+ struct ggml_metal_mem_pool * mem_pool = calloc(1, sizeof(struct ggml_metal_mem_pool));
|
|
|
+
|
|
|
+ mem_pool->n_heaps = 0;
|
|
|
+
|
|
|
+ mem_pool->heaps = [[NSMutableArray alloc] init];
|
|
|
+ mem_pool->heaps_to_remove = [[NSMutableArray alloc] init];
|
|
|
+
|
|
|
+ return mem_pool;
|
|
|
+}
|
|
|
+
|
|
|
+static void ggml_metal_mem_pool_free(struct ggml_metal_mem_pool * mem_pool) {
|
|
|
+ GGML_LOG_DEBUG("%s: freeing memory pool, num heaps = %zu (total = %d)\n", __func__, [mem_pool->heaps count], mem_pool->n_heaps);
|
|
|
+
|
|
|
+ size_t size_all = 0;
|
|
|
+ size_t size_cur = 0;
|
|
|
+
|
|
|
+ for (ggml_metal_heap_ptr * ptr in mem_pool->heaps) {
|
|
|
+ GGML_LOG_DEBUG("%s: heap: %p\n", __func__, (void *) ptr.data);
|
|
|
+ GGML_LOG_DEBUG("%s: n_alloc: %" PRId64 "\n", __func__, ptr.data->n_alloc);
|
|
|
+ GGML_LOG_DEBUG("%s: n_unused: %d\n", __func__, ptr.data->n_unused);
|
|
|
+ GGML_LOG_DEBUG("%s: size: %.2f MiB\n", __func__, [ptr.data->obj size] / 1024.0 / 1024.0);
|
|
|
+ GGML_LOG_DEBUG("%s: bufs: %zu\n", __func__, [ptr.data->bufs count]);
|
|
|
+
|
|
|
+ if ([ptr.data->bufs count] > 0) {
|
|
|
+ size_cur += [ptr.data->obj size];
|
|
|
+ }
|
|
|
+ size_all += [ptr.data->obj size];
|
|
|
+
|
|
|
+ ggml_metal_heap_free(ptr.data);
|
|
|
+ [ptr release];
|
|
|
+ }
|
|
|
+ [mem_pool->heaps release];
|
|
|
+ [mem_pool->heaps_to_remove release];
|
|
|
+
|
|
|
+ if (size_all > 0) {
|
|
|
+ GGML_LOG_DEBUG("%s: size_all: %.2f MiB\n", __func__, size_all / 1024.0 / 1024.0);
|
|
|
+ GGML_LOG_DEBUG("%s: size_cur: %.2f MiB\n", __func__, size_cur / 1024.0 / 1024.0);
|
|
|
+ }
|
|
|
+
|
|
|
+ free(mem_pool);
|
|
|
+}
|
|
|
+
|
|
|
+static void ggml_metal_mem_pool_reset(struct ggml_metal_mem_pool * mem_pool) {
|
|
|
+ for (NSUInteger i = 0; i < [mem_pool->heaps count]; i++) {
|
|
|
+ ggml_metal_heap_ptr * ptr = [mem_pool->heaps objectAtIndex:i];
|
|
|
+
|
|
|
+ struct ggml_metal_heap * heap = ptr.data;
|
|
|
+ ggml_metal_heap_reset(heap);
|
|
|
+
|
|
|
+ // if the heap hasn't been used for a while, remove it
|
|
|
+ if (heap->n_unused >= 128) {
|
|
|
+ [mem_pool->heaps_to_remove addObject:@(i)];
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ if (mem_pool->heaps_to_remove.count > 0) {
|
|
|
+ for (NSUInteger i = 0; i < [mem_pool->heaps_to_remove count]; i++) {
|
|
|
+ NSUInteger index = [[mem_pool->heaps_to_remove objectAtIndex:i] intValue];
|
|
|
+ ggml_metal_heap_ptr * ptr = [mem_pool->heaps objectAtIndex:index];
|
|
|
+
|
|
|
+ struct ggml_metal_heap * heap = ptr.data;
|
|
|
+ ggml_metal_heap_free(heap);
|
|
|
+
|
|
|
+ [mem_pool->heaps removeObjectAtIndex:index];
|
|
|
+ [ptr release];
|
|
|
+ }
|
|
|
+
|
|
|
+ [mem_pool->heaps_to_remove removeAllObjects];
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+static void ggml_metal_mem_pool_clear(struct ggml_metal_mem_pool * mem_pool) {
|
|
|
+ for (ggml_metal_heap_ptr * ptr in mem_pool->heaps) {
|
|
|
+ ptr.data->offs = 0;
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+static id<MTLBuffer> ggml_metal_mem_pool_alloc(struct ggml_metal_mem_pool * mem_pool, size_t size) {
|
|
|
+ const size_t alignment = 32;
|
|
|
+
|
|
|
+ const size_t size_aligned = GGML_PAD(size, alignment);
|
|
|
+
|
|
|
+ // try one of the existing heaps
|
|
|
+ for (ggml_metal_heap_ptr * ptr in mem_pool->heaps) {
|
|
|
+ struct ggml_metal_heap * heap = ptr.data;
|
|
|
+ if (heap->offs + size_aligned <= [heap->obj size]) {
|
|
|
+ // if this is the first buffer in the heap for the current command buffer, tell the OS that
|
|
|
+ // it cannot free the memory used by the heap
|
|
|
+ // ref: https://developer.apple.com/documentation/metal/mtlpurgeablestate?language=objc
|
|
|
+ if ([heap->bufs count] == 0) {
|
|
|
+ [heap->obj setPurgeableState:MTLPurgeableStateNonVolatile];
|
|
|
+ }
|
|
|
+
|
|
|
+ id<MTLBuffer> buf = [heap->obj newBufferWithLength:size_aligned options:MTLResourceStorageModePrivate offset:heap->offs];
|
|
|
+ if (buf == nil) {
|
|
|
+ GGML_LOG_ERROR("%s: error: failed to create MTLBuffer with size %zu\n", __func__, size_aligned);
|
|
|
+ return nil;
|
|
|
+ }
|
|
|
+
|
|
|
+ heap->n_alloc++;
|
|
|
+ heap->offs += size_aligned;
|
|
|
+
|
|
|
+ [heap->bufs addObject:buf];
|
|
|
+
|
|
|
+ return buf;
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ // create a new heap that can fit this buffer
|
|
|
+ ggml_metal_heap_ptr * heap_ptr = [ggml_metal_heap_ptr new];
|
|
|
+
|
|
|
+ struct ggml_metal_heap * heap = ggml_metal_heap_init(mem_pool->device, size_aligned);
|
|
|
+ if (heap == NULL) {
|
|
|
+ GGML_LOG_ERROR("%s: error: failed to create heap of size %zu\n", __func__, size_aligned);
|
|
|
+ return NULL;
|
|
|
+ }
|
|
|
+
|
|
|
+ //GGML_LOG_DEBUG("%s: creating new heap of size %zu, got %zu\n", __func__, size_aligned, [heap->obj size]);
|
|
|
+
|
|
|
+ heap_ptr.data = heap;
|
|
|
+ ggml_metal_heap_reset(heap);
|
|
|
+
|
|
|
+ [heap->obj setPurgeableState:MTLPurgeableStateNonVolatile];
|
|
|
+ id<MTLBuffer> buf = [heap->obj newBufferWithLength:size_aligned options:MTLResourceStorageModePrivate offset:heap->offs];
|
|
|
+ if (buf == nil) {
|
|
|
+ GGML_LOG_ERROR("%s: error: failed to create MTLBuffer with size %zu\n", __func__, size_aligned);
|
|
|
+ return NULL;
|
|
|
+ }
|
|
|
+
|
|
|
+ heap->n_alloc++;
|
|
|
+ heap->offs += size_aligned;
|
|
|
+
|
|
|
+ [heap->bufs addObject:buf];
|
|
|
+
|
|
|
+ [mem_pool->heaps addObject:heap_ptr];
|
|
|
+ mem_pool->n_heaps++;
|
|
|
+
|
|
|
+ return buf;
|
|
|
+}
|
|
|
+
|
|
|
+struct ggml_metal_command_buffer {
|
|
|
+ id<MTLCommandBuffer> obj;
|
|
|
+
|
|
|
+ // each command buffer has a memory pool from which it can allocate temporary buffers during the compute
|
|
|
+ struct ggml_metal_mem_pool * mem_pool;
|
|
|
+};
|
|
|
+
|
|
|
struct ggml_backend_metal_context {
|
|
|
+ id<MTLDevice> device;
|
|
|
id<MTLCommandQueue> queue;
|
|
|
|
|
|
dispatch_queue_t d_queue;
|
|
|
@@ -515,7 +767,7 @@ struct ggml_backend_metal_context {
|
|
|
void (^encode_async)(size_t ith);
|
|
|
|
|
|
// n_cb command buffers + 1 used by the main thread
|
|
|
- id<MTLCommandBuffer> command_buffers[GGML_METAL_MAX_COMMAND_BUFFERS + 1];
|
|
|
+ struct ggml_metal_command_buffer cmd_bufs[GGML_METAL_MAX_COMMAND_BUFFERS + 1];
|
|
|
|
|
|
// abort ggml_metal_graph_compute if callback returns true
|
|
|
ggml_abort_callback abort_callback;
|
|
|
@@ -705,9 +957,11 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
|
|
|
struct ggml_backend_metal_device_context * ctx_dev = dev->context;
|
|
|
|
|
|
id<MTLDevice> device = ggml_backend_metal_device_acq(ctx_dev);
|
|
|
+
|
|
|
GGML_LOG_INFO("%s: picking default device: %s\n", __func__, [[device name] UTF8String]);
|
|
|
|
|
|
- ctx->queue = [device newCommandQueue];
|
|
|
+ ctx->device = device;
|
|
|
+ ctx->queue = [device newCommandQueue];
|
|
|
if (ctx->queue == nil) {
|
|
|
GGML_LOG_ERROR("%s: error: failed to create command queue\n", __func__);
|
|
|
return NULL;
|
|
|
@@ -768,7 +1022,10 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
|
|
|
ctx->gf = nil;
|
|
|
ctx->encode_async = nil;
|
|
|
for (int i = 0; i < GGML_METAL_MAX_COMMAND_BUFFERS; ++i) {
|
|
|
- ctx->command_buffers[i] = nil;
|
|
|
+ ctx->cmd_bufs[i].obj = nil;
|
|
|
+
|
|
|
+ ctx->cmd_bufs[i].mem_pool = ggml_metal_mem_pool_init();
|
|
|
+ ctx->cmd_bufs[i].mem_pool->device = device;
|
|
|
}
|
|
|
|
|
|
#if TARGET_OS_OSX || (TARGET_OS_IOS && __clang_major__ >= 15)
|
|
|
@@ -1181,6 +1438,12 @@ static void ggml_metal_free(struct ggml_backend_metal_context * ctx) {
|
|
|
|
|
|
[ctx->queue release];
|
|
|
|
|
|
+ for (int i = 0; i < GGML_METAL_MAX_COMMAND_BUFFERS; ++i) {
|
|
|
+ // ctx->cmd_bufs[i].obj is auto released
|
|
|
+
|
|
|
+ ggml_metal_mem_pool_free(ctx->cmd_bufs[i].mem_pool);
|
|
|
+ }
|
|
|
+
|
|
|
dispatch_release(ctx->d_queue);
|
|
|
|
|
|
free(ctx);
|
|
|
@@ -1486,10 +1749,11 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
|
|
|
}
|
|
|
}
|
|
|
|
|
|
-static void ggml_metal_encode_node(
|
|
|
+static bool ggml_metal_encode_node(
|
|
|
ggml_backend_t backend,
|
|
|
int idx,
|
|
|
- id<MTLComputeCommandEncoder> encoder) {
|
|
|
+ id<MTLComputeCommandEncoder> encoder,
|
|
|
+ struct ggml_metal_mem_pool * mem_pool) {
|
|
|
struct ggml_backend_metal_context * ctx = backend->context;
|
|
|
struct ggml_backend_metal_device_context * ctx_dev = backend->device->context;
|
|
|
|
|
|
@@ -1505,7 +1769,7 @@ static void ggml_metal_encode_node(
|
|
|
struct ggml_tensor * dst = node;
|
|
|
|
|
|
if (ggml_is_empty(dst)) {
|
|
|
- return;
|
|
|
+ return true;
|
|
|
}
|
|
|
|
|
|
switch (dst->op) {
|
|
|
@@ -1516,7 +1780,7 @@ static void ggml_metal_encode_node(
|
|
|
case GGML_OP_PERMUTE:
|
|
|
{
|
|
|
// noop -> next node
|
|
|
- } return;
|
|
|
+ } return true;
|
|
|
default:
|
|
|
{
|
|
|
} break;
|
|
|
@@ -1527,6 +1791,8 @@ static void ggml_metal_encode_node(
|
|
|
GGML_ABORT("unsupported op");
|
|
|
}
|
|
|
|
|
|
+ ggml_metal_mem_pool_clear(mem_pool);
|
|
|
+
|
|
|
const int64_t ne00 = src0 ? src0->ne[0] : 0;
|
|
|
const int64_t ne01 = src0 ? src0->ne[1] : 0;
|
|
|
const int64_t ne02 = src0 ? src0->ne[2] : 0;
|
|
|
@@ -2173,26 +2439,76 @@ static void ggml_metal_encode_node(
|
|
|
const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
|
|
|
const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
|
|
|
|
|
|
- ggml_metal_kargs_soft_max args = {
|
|
|
+// use this branch to test the ggml_metal_mem_pool functionality
|
|
|
+#if 0
|
|
|
+ // cpy to tmp buffer in MTLHeap
|
|
|
+
|
|
|
+ id<MTLBuffer> h_src0 = h_src0 = ggml_metal_mem_pool_alloc(mem_pool, ggml_nbytes(src0));
|
|
|
+ if (!h_src0) {
|
|
|
+ GGML_LOG_ERROR("%s: failed to allocate buffer from memory pool, size = %zu\n", __func__, ggml_nbytes(src0));
|
|
|
+ return false;
|
|
|
+ }
|
|
|
+
|
|
|
+ offs_src0 = 0;
|
|
|
+
|
|
|
+ ggml_metal_kargs_cpy args_cpy = {
|
|
|
/*.ne00 =*/ ne00,
|
|
|
/*.ne01 =*/ ne01,
|
|
|
/*.ne02 =*/ ne02,
|
|
|
- /*.scale =*/ scale,
|
|
|
- /*.max_bias =*/ max_bias,
|
|
|
- /*.m0 =*/ m0,
|
|
|
- /*.m1 =*/ m1,
|
|
|
+ /*.ne03 =*/ ne03,
|
|
|
+ /*.nb00 =*/ nb00,
|
|
|
+ /*.nb01 =*/ nb01,
|
|
|
+ /*.nb02 =*/ nb02,
|
|
|
+ /*.nb03 =*/ nb03,
|
|
|
+ /*.ne0 =*/ ne00,
|
|
|
+ /*.ne1 =*/ ne01,
|
|
|
+ /*.ne2 =*/ ne02,
|
|
|
+ /*.ne3 =*/ ne03,
|
|
|
+ /*.nb0 =*/ nb00,
|
|
|
+ /*.nb1 =*/ nb01,
|
|
|
+ /*.nb2 =*/ nb02,
|
|
|
+ /*.nb3 =*/ nb03,
|
|
|
+ };
|
|
|
+
|
|
|
+ if (src0->type == GGML_TYPE_F16) {
|
|
|
+ [encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F16_F16].pipeline];
|
|
|
+ } else {
|
|
|
+ [encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_F32].pipeline];
|
|
|
+ }
|
|
|
+ [encoder setBytes:&args_cpy length:sizeof(args_cpy) atIndex:0];
|
|
|
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
|
|
|
+ [encoder setBuffer:h_src0 offset:0 atIndex:2];
|
|
|
+
|
|
|
+ GGML_ASSERT(ne00 % ggml_blck_size(src0->type) == 0);
|
|
|
+ int nth_cpy = MIN(1024, ne00 / ggml_blck_size(src0->type));
|
|
|
+
|
|
|
+ [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth_cpy, 1, 1)];
|
|
|
+
|
|
|
+#else
|
|
|
+ id<MTLBuffer> h_src0 = id_src0;
|
|
|
+#endif
|
|
|
+ // softmax
|
|
|
+
|
|
|
+ ggml_metal_kargs_soft_max args = {
|
|
|
+ /*.ne00 =*/ ne00,
|
|
|
+ /*.ne01 =*/ ne01,
|
|
|
+ /*.ne02 =*/ ne02,
|
|
|
+ /*.scale =*/ scale,
|
|
|
+ /*.max_bias =*/ max_bias,
|
|
|
+ /*.m0 =*/ m0,
|
|
|
+ /*.m1 =*/ m1,
|
|
|
/*.n_head_log2 =*/ n_head_log2,
|
|
|
};
|
|
|
|
|
|
[encoder setComputePipelineState:pipeline];
|
|
|
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
|
|
+ [encoder setBuffer:h_src0 offset:offs_src0 atIndex:0];
|
|
|
if (id_src1) {
|
|
|
- [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
|
|
|
+ [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
|
|
|
} else {
|
|
|
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
|
|
|
+ [encoder setBuffer:h_src0 offset:offs_src0 atIndex:1];
|
|
|
}
|
|
|
- [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
|
|
|
- [encoder setBytes:&args length:sizeof(args) atIndex:3];
|
|
|
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
|
|
|
+ [encoder setBytes:&args length:sizeof(args) atIndex:3];
|
|
|
|
|
|
[encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
|
|
|
|
|
|
@@ -4601,6 +4917,8 @@ static void ggml_metal_encode_node(
|
|
|
GGML_ABORT("fatal error");
|
|
|
}
|
|
|
}
|
|
|
+
|
|
|
+ return true;
|
|
|
}
|
|
|
|
|
|
static enum ggml_status ggml_metal_graph_compute(
|
|
|
@@ -4654,25 +4972,25 @@ static enum ggml_status ggml_metal_graph_compute(
|
|
|
}
|
|
|
|
|
|
// the main thread commits the first few commands immediately
|
|
|
- // command_buffer[n_cb]
|
|
|
+ // cmd_buf[n_cb]
|
|
|
{
|
|
|
- id<MTLCommandBuffer> command_buffer = [ctx->queue commandBufferWithUnretainedReferences];
|
|
|
- ctx->command_buffers[n_cb] = command_buffer;
|
|
|
+ id<MTLCommandBuffer> cmd_buf = [ctx->queue commandBufferWithUnretainedReferences];
|
|
|
+ ctx->cmd_bufs[n_cb].obj = cmd_buf;
|
|
|
|
|
|
- [command_buffer enqueue];
|
|
|
+ [cmd_buf enqueue];
|
|
|
ctx->encode_async(n_cb);
|
|
|
}
|
|
|
|
|
|
// prepare the rest of the command buffers asynchronously
|
|
|
- // command_buffer[0.. n_cb)
|
|
|
+ // cmd_buf[0.. n_cb)
|
|
|
for (int cb_idx = 0; cb_idx < n_cb; ++cb_idx) {
|
|
|
- id<MTLCommandBuffer> command_buffer = [ctx->queue commandBufferWithUnretainedReferences];
|
|
|
- ctx->command_buffers[cb_idx] = command_buffer;
|
|
|
+ id<MTLCommandBuffer> cmd_buf = [ctx->queue commandBufferWithUnretainedReferences];
|
|
|
+ ctx->cmd_bufs[cb_idx].obj = cmd_buf;
|
|
|
|
|
|
// always enqueue the first two command buffers
|
|
|
// enqueue all of the command buffers if we don't need to abort
|
|
|
if (cb_idx < 2 || ctx->abort_callback == NULL) {
|
|
|
- [command_buffer enqueue];
|
|
|
+ [cmd_buf enqueue];
|
|
|
}
|
|
|
}
|
|
|
|
|
|
@@ -4681,14 +4999,14 @@ static enum ggml_status ggml_metal_graph_compute(
|
|
|
// wait for completion and check status of each command buffer
|
|
|
// needed to detect if the device ran out-of-memory for example (#1881)
|
|
|
{
|
|
|
- id<MTLCommandBuffer> command_buffer = ctx->command_buffers[n_cb];
|
|
|
- [command_buffer waitUntilCompleted];
|
|
|
+ id<MTLCommandBuffer> cmd_buf = ctx->cmd_bufs[n_cb].obj;
|
|
|
+ [cmd_buf waitUntilCompleted];
|
|
|
|
|
|
- MTLCommandBufferStatus status = [command_buffer status];
|
|
|
+ MTLCommandBufferStatus status = [cmd_buf status];
|
|
|
if (status != MTLCommandBufferStatusCompleted) {
|
|
|
GGML_LOG_INFO("%s: command buffer %d failed with status %lu\n", __func__, n_cb, status);
|
|
|
if (status == MTLCommandBufferStatusError) {
|
|
|
- GGML_LOG_INFO("error: %s\n", [[command_buffer error].localizedDescription UTF8String]);
|
|
|
+ GGML_LOG_INFO("error: %s\n", [[cmd_buf error].localizedDescription UTF8String]);
|
|
|
}
|
|
|
|
|
|
return GGML_STATUS_FAILED;
|
|
|
@@ -4696,20 +5014,20 @@ static enum ggml_status ggml_metal_graph_compute(
|
|
|
}
|
|
|
|
|
|
for (int i = 0; i < n_cb; ++i) {
|
|
|
- id<MTLCommandBuffer> command_buffer = ctx->command_buffers[i];
|
|
|
- [command_buffer waitUntilCompleted];
|
|
|
+ id<MTLCommandBuffer> cmd_buf = ctx->cmd_bufs[i].obj;
|
|
|
+ [cmd_buf waitUntilCompleted];
|
|
|
|
|
|
- MTLCommandBufferStatus status = [command_buffer status];
|
|
|
+ MTLCommandBufferStatus status = [cmd_buf status];
|
|
|
if (status != MTLCommandBufferStatusCompleted) {
|
|
|
GGML_LOG_INFO("%s: command buffer %d failed with status %lu\n", __func__, i, status);
|
|
|
if (status == MTLCommandBufferStatusError) {
|
|
|
- GGML_LOG_INFO("error: %s\n", [[command_buffer error].localizedDescription UTF8String]);
|
|
|
+ GGML_LOG_INFO("error: %s\n", [[cmd_buf error].localizedDescription UTF8String]);
|
|
|
}
|
|
|
|
|
|
return GGML_STATUS_FAILED;
|
|
|
}
|
|
|
|
|
|
- id<MTLCommandBuffer> next_buffer = (i + 1 < n_cb ? ctx->command_buffers[i + 1] : nil);
|
|
|
+ id<MTLCommandBuffer> next_buffer = (i + 1 < n_cb ? ctx->cmd_bufs[i + 1].obj : nil);
|
|
|
if (!next_buffer) {
|
|
|
continue;
|
|
|
}
|
|
|
@@ -5092,8 +5410,9 @@ static void ggml_backend_metal_set_n_cb(ggml_backend_t backend, int n_cb) {
|
|
|
|
|
|
const int n_nodes_per_cb = ctx->n_nodes_per_cb;
|
|
|
|
|
|
- id<MTLCommandBuffer> command_buffer = ctx->command_buffers[cb_idx];
|
|
|
- id<MTLComputeCommandEncoder> encoder = [command_buffer computeCommandEncoder];
|
|
|
+ id<MTLCommandBuffer> cmd_buf = ctx->cmd_bufs[cb_idx].obj;
|
|
|
+
|
|
|
+ id<MTLComputeCommandEncoder> encoder = [cmd_buf computeCommandEncoder];
|
|
|
|
|
|
int node_start = 0;
|
|
|
int node_end = n_nodes_0;
|
|
|
@@ -5105,22 +5424,29 @@ static void ggml_backend_metal_set_n_cb(ggml_backend_t backend, int n_cb) {
|
|
|
|
|
|
const bool should_capture = ctx->capture_next_compute;
|
|
|
|
|
|
+ struct ggml_metal_mem_pool * mem_pool = ctx->cmd_bufs[cb_idx].mem_pool;
|
|
|
+ ggml_metal_mem_pool_reset(mem_pool);
|
|
|
+
|
|
|
for (int idx = node_start; idx < node_end; ++idx) {
|
|
|
if (should_capture) {
|
|
|
[encoder pushDebugGroup:[NSString stringWithCString:ggml_op_desc(ggml_graph_node(ctx->gf, idx)) encoding:NSUTF8StringEncoding]];
|
|
|
}
|
|
|
|
|
|
- ggml_metal_encode_node(backend, idx, encoder);
|
|
|
+ const bool res = ggml_metal_encode_node(backend, idx, encoder, mem_pool);
|
|
|
|
|
|
if (should_capture) {
|
|
|
[encoder popDebugGroup];
|
|
|
}
|
|
|
+
|
|
|
+ if (!res) {
|
|
|
+ break;
|
|
|
+ }
|
|
|
}
|
|
|
|
|
|
[encoder endEncoding];
|
|
|
|
|
|
if (cb_idx < 2 || ctx->abort_callback == NULL) {
|
|
|
- [command_buffer commit];
|
|
|
+ [cmd_buf commit];
|
|
|
}
|
|
|
});
|
|
|
}
|