|
|
@@ -33,12 +33,15 @@ struct ggml_metal_buffer {
|
|
|
struct ggml_metal_context {
|
|
|
int n_cb;
|
|
|
|
|
|
- float * logits;
|
|
|
-
|
|
|
id<MTLDevice> device;
|
|
|
id<MTLCommandQueue> queue;
|
|
|
id<MTLLibrary> library;
|
|
|
|
|
|
+ id<MTLCommandBuffer> command_buffers [GGML_METAL_MAX_COMMAND_BUFFERS];
|
|
|
+ id<MTLComputeCommandEncoder> command_encoders[GGML_METAL_MAX_COMMAND_BUFFERS];
|
|
|
+
|
|
|
+ dispatch_queue_t d_queue;
|
|
|
+
|
|
|
int n_buffers;
|
|
|
struct ggml_metal_buffer buffers[GGML_METAL_MAX_BUFFERS];
|
|
|
|
|
|
@@ -114,12 +117,13 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
|
|
|
|
|
struct ggml_metal_context * ctx = malloc(sizeof(struct ggml_metal_context));
|
|
|
|
|
|
- ctx->n_cb = n_cb;
|
|
|
+ ctx->n_cb = MIN(n_cb, GGML_METAL_MAX_BUFFERS);
|
|
|
ctx->device = MTLCreateSystemDefaultDevice();
|
|
|
ctx->queue = [ctx->device newCommandQueue];
|
|
|
ctx->n_buffers = 0;
|
|
|
ctx->concur_list_len = 0;
|
|
|
|
|
|
+ ctx->d_queue = dispatch_queue_create("llama.cpp", DISPATCH_QUEUE_CONCURRENT);
|
|
|
|
|
|
#if 0
|
|
|
// compile from source string and show compile log
|
|
|
@@ -239,9 +243,67 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
|
|
|
|
|
void ggml_metal_free(struct ggml_metal_context * ctx) {
|
|
|
fprintf(stderr, "%s: deallocating\n", __func__);
|
|
|
+#define GGML_METAL_DEL_KERNEL(name) \
|
|
|
+ [ctx->function_##name release]; \
|
|
|
+ [ctx->pipeline_##name release];
|
|
|
+
|
|
|
+ GGML_METAL_DEL_KERNEL(add);
|
|
|
+ GGML_METAL_DEL_KERNEL(add_row);
|
|
|
+ GGML_METAL_DEL_KERNEL(mul);
|
|
|
+ GGML_METAL_DEL_KERNEL(mul_row);
|
|
|
+ GGML_METAL_DEL_KERNEL(scale);
|
|
|
+ GGML_METAL_DEL_KERNEL(silu);
|
|
|
+ GGML_METAL_DEL_KERNEL(relu);
|
|
|
+ GGML_METAL_DEL_KERNEL(gelu);
|
|
|
+ GGML_METAL_DEL_KERNEL(soft_max);
|
|
|
+ GGML_METAL_DEL_KERNEL(diag_mask_inf);
|
|
|
+ GGML_METAL_DEL_KERNEL(get_rows_f16);
|
|
|
+ GGML_METAL_DEL_KERNEL(get_rows_q4_0);
|
|
|
+ GGML_METAL_DEL_KERNEL(get_rows_q4_1);
|
|
|
+ GGML_METAL_DEL_KERNEL(get_rows_q8_0);
|
|
|
+ GGML_METAL_DEL_KERNEL(get_rows_q2_K);
|
|
|
+ GGML_METAL_DEL_KERNEL(get_rows_q3_K);
|
|
|
+ GGML_METAL_DEL_KERNEL(get_rows_q4_K);
|
|
|
+ GGML_METAL_DEL_KERNEL(get_rows_q5_K);
|
|
|
+ GGML_METAL_DEL_KERNEL(get_rows_q6_K);
|
|
|
+ GGML_METAL_DEL_KERNEL(rms_norm);
|
|
|
+ GGML_METAL_DEL_KERNEL(norm);
|
|
|
+ GGML_METAL_DEL_KERNEL(mul_mat_f16_f32);
|
|
|
+ GGML_METAL_DEL_KERNEL(mul_mat_q4_0_f32);
|
|
|
+ GGML_METAL_DEL_KERNEL(mul_mat_q4_1_f32);
|
|
|
+ GGML_METAL_DEL_KERNEL(mul_mat_q8_0_f32);
|
|
|
+ GGML_METAL_DEL_KERNEL(mul_mat_q2_K_f32);
|
|
|
+ GGML_METAL_DEL_KERNEL(mul_mat_q3_K_f32);
|
|
|
+ GGML_METAL_DEL_KERNEL(mul_mat_q4_K_f32);
|
|
|
+ GGML_METAL_DEL_KERNEL(mul_mat_q5_K_f32);
|
|
|
+ GGML_METAL_DEL_KERNEL(mul_mat_q6_K_f32);
|
|
|
+ GGML_METAL_DEL_KERNEL(mul_mm_f16_f32);
|
|
|
+ GGML_METAL_DEL_KERNEL(mul_mm_q4_0_f32);
|
|
|
+ GGML_METAL_DEL_KERNEL(mul_mm_q8_0_f32);
|
|
|
+ GGML_METAL_DEL_KERNEL(mul_mm_q4_1_f32);
|
|
|
+ GGML_METAL_DEL_KERNEL(mul_mm_q2_K_f32);
|
|
|
+ GGML_METAL_DEL_KERNEL(mul_mm_q3_K_f32);
|
|
|
+ GGML_METAL_DEL_KERNEL(mul_mm_q4_K_f32);
|
|
|
+ GGML_METAL_DEL_KERNEL(mul_mm_q5_K_f32);
|
|
|
+ GGML_METAL_DEL_KERNEL(mul_mm_q6_K_f32);
|
|
|
+ GGML_METAL_DEL_KERNEL(rope);
|
|
|
+ GGML_METAL_DEL_KERNEL(alibi_f32);
|
|
|
+ GGML_METAL_DEL_KERNEL(cpy_f32_f16);
|
|
|
+ GGML_METAL_DEL_KERNEL(cpy_f32_f32);
|
|
|
+ GGML_METAL_DEL_KERNEL(cpy_f16_f16);
|
|
|
+
|
|
|
+#undef GGML_METAL_DEL_KERNEL
|
|
|
+
|
|
|
for (int i = 0; i < ctx->n_buffers; ++i) {
|
|
|
[ctx->buffers[i].metal release];
|
|
|
}
|
|
|
+
|
|
|
+ [ctx->library release];
|
|
|
+ [ctx->queue release];
|
|
|
+ [ctx->device release];
|
|
|
+
|
|
|
+ dispatch_release(ctx->d_queue);
|
|
|
+
|
|
|
free(ctx);
|
|
|
}
|
|
|
|
|
|
@@ -261,7 +323,7 @@ void ggml_metal_host_free(void * data) {
|
|
|
}
|
|
|
|
|
|
void ggml_metal_set_n_cb(struct ggml_metal_context * ctx, int n_cb) {
|
|
|
- ctx->n_cb = n_cb;
|
|
|
+ ctx->n_cb = MIN(n_cb, GGML_METAL_MAX_BUFFERS);
|
|
|
}
|
|
|
|
|
|
int ggml_metal_if_optimized(struct ggml_metal_context * ctx) {
|
|
|
@@ -507,6 +569,8 @@ void ggml_metal_graph_compute(
|
|
|
struct ggml_cgraph * gf) {
|
|
|
metal_printf("%s: evaluating graph\n", __func__);
|
|
|
|
|
|
+ @autoreleasepool {
|
|
|
+
|
|
|
// if there is ctx->concur_list, dispatch concurrently
|
|
|
// else fallback to serial dispatch
|
|
|
MTLComputePassDescriptor * edesc = MTLComputePassDescriptor.computePassDescriptor;
|
|
|
@@ -521,29 +585,25 @@ void ggml_metal_graph_compute(
|
|
|
|
|
|
const int n_cb = ctx->n_cb;
|
|
|
|
|
|
- NSMutableArray * command_buffers = [NSMutableArray arrayWithCapacity:n_cb];
|
|
|
-
|
|
|
for (int i = 0; i < n_cb; ++i) {
|
|
|
- command_buffers[i] = [ctx->queue commandBuffer];
|
|
|
+ ctx->command_buffers[i] = [ctx->queue commandBuffer];
|
|
|
|
|
|
// enqueue the command buffers in order to specify their execution order
|
|
|
- [command_buffers[i] enqueue];
|
|
|
- }
|
|
|
+ [ctx->command_buffers[i] enqueue];
|
|
|
|
|
|
- // TODO: is this the best way to start threads?
|
|
|
- dispatch_queue_t queue = dispatch_queue_create("llama.cpp", DISPATCH_QUEUE_CONCURRENT);
|
|
|
+ ctx->command_encoders[i] = [ctx->command_buffers[i] computeCommandEncoderWithDescriptor: edesc];
|
|
|
+ }
|
|
|
|
|
|
for (int cb_idx = 0; cb_idx < n_cb; ++cb_idx) {
|
|
|
const int n_nodes_per_cb = (n_nodes + n_cb - 1) / n_cb;
|
|
|
|
|
|
- dispatch_async(queue, ^{
|
|
|
+ dispatch_async(ctx->d_queue, ^{
|
|
|
size_t offs_src0 = 0;
|
|
|
size_t offs_src1 = 0;
|
|
|
size_t offs_dst = 0;
|
|
|
|
|
|
- id<MTLCommandBuffer> command_buffer = command_buffers[cb_idx];
|
|
|
-
|
|
|
- id<MTLComputeCommandEncoder> encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
|
|
|
+ id<MTLCommandBuffer> command_buffer = ctx->command_buffers[cb_idx];
|
|
|
+ id<MTLComputeCommandEncoder> encoder = ctx->command_encoders[cb_idx];
|
|
|
|
|
|
const int node_start = (cb_idx + 0) * n_nodes_per_cb;
|
|
|
const int node_end = MIN((cb_idx == n_cb - 1) ? n_nodes : (cb_idx + 1) * n_nodes_per_cb, n_nodes);
|
|
|
@@ -1117,17 +1177,19 @@ void ggml_metal_graph_compute(
|
|
|
}
|
|
|
|
|
|
// wait for all threads to finish
|
|
|
- dispatch_barrier_sync(queue, ^{});
|
|
|
-
|
|
|
- [command_buffers[n_cb - 1] waitUntilCompleted];
|
|
|
+ dispatch_barrier_sync(ctx->d_queue, ^{});
|
|
|
|
|
|
// check status of command buffers
|
|
|
// needed to detect if the device ran out-of-memory for example (#1881)
|
|
|
for (int i = 0; i < n_cb; i++) {
|
|
|
- MTLCommandBufferStatus status = (MTLCommandBufferStatus) [command_buffers[i] status];
|
|
|
+ [ctx->command_buffers[i] waitUntilCompleted];
|
|
|
+
|
|
|
+ MTLCommandBufferStatus status = (MTLCommandBufferStatus) [ctx->command_buffers[i] status];
|
|
|
if (status != MTLCommandBufferStatusCompleted) {
|
|
|
fprintf(stderr, "%s: command buffer %d failed with status %lu\n", __func__, i, status);
|
|
|
GGML_ASSERT(false);
|
|
|
}
|
|
|
}
|
|
|
+
|
|
|
+ }
|
|
|
}
|