|
|
@@ -718,7 +718,8 @@ void ggml_metal_graph_compute(
|
|
|
// TODO: needs to be updated after PR: https://github.com/ggerganov/ggml/pull/224
|
|
|
|
|
|
GGML_ASSERT(ne00 == ne10);
|
|
|
- GGML_ASSERT(ne02 == ne12);
|
|
|
+ // GGML_ASSERT(ne02 == ne12); // Should be checked on individual data types until broadcast is implemented everywhere
|
|
|
+ GGML_ASSERT(ne03 == ne13);
|
|
|
|
|
|
if (ggml_is_contiguous(src0) &&
|
|
|
ggml_is_contiguous(src1) &&
|
|
|
@@ -746,11 +747,11 @@ void ggml_metal_graph_compute(
|
|
|
initWithDevice:ctx->device transposeLeft:false transposeRight:true
|
|
|
resultRows:ne11 resultColumns:ne01 interiorColumns:ne00 alpha:1.0 beta:0.0];
|
|
|
|
|
|
- // we need to do ne02 multiplications
|
|
|
+ // we need to do ne12 multiplications
|
|
|
// TODO: is there a way to do this in parallel - currently very slow ..
|
|
|
// TODO: might be possible to offload part of the computation to ANE using Accelerate's CBLAS
|
|
|
- for (int64_t i02 = 0; i02 < ne02; ++i02) {
|
|
|
- size_t offs_src0_cur = offs_src0 + i02*nb02;
|
|
|
+ for (int64_t i02 = 0; i02 < ne12; ++i02) {
|
|
|
+ size_t offs_src0_cur = offs_src0 + i02/(ne12/ne02)*nb02; // gqa not used for now
|
|
|
size_t offs_src1_cur = offs_src1 + i02*nb12;
|
|
|
size_t offs_dst_cur = offs_dst + i02*nb2;
|
|
|
|
|
|
@@ -772,8 +773,6 @@ void ggml_metal_graph_compute(
|
|
|
switch (src0t) {
|
|
|
case GGML_TYPE_F16:
|
|
|
{
|
|
|
- GGML_ASSERT(ne02 == ne12);
|
|
|
-
|
|
|
nth0 = 64;
|
|
|
nth1 = 1;
|
|
|
[encoder setComputePipelineState:ctx->pipeline_mul_mat_f16_f32];
|
|
|
@@ -853,16 +852,18 @@ void ggml_metal_graph_compute(
|
|
|
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
|
|
|
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
|
|
|
[encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
|
|
|
- [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:5];
|
|
|
- [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:6];
|
|
|
- [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:7];
|
|
|
- [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:8];
|
|
|
- [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:9];
|
|
|
- [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:10];
|
|
|
- [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:11];
|
|
|
- [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:12];
|
|
|
- [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:13];
|
|
|
- [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:14];
|
|
|
+ [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
|
|
|
+ [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6];
|
|
|
+ [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7];
|
|
|
+ [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8];
|
|
|
+ [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:9];
|
|
|
+ [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:10];
|
|
|
+ [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:11];
|
|
|
+ [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:12];
|
|
|
+ [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:13];
|
|
|
+ [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:14];
|
|
|
+ [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:15];
|
|
|
+ [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:16];
|
|
|
|
|
|
if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 ||
|
|
|
src0t == GGML_TYPE_Q2_K || src0t == GGML_TYPE_Q4_K) {
|