|
|
@@ -1213,12 +1213,9 @@ void ggml_metal_graph_compute(
|
|
|
float max_bias;
|
|
|
memcpy(&max_bias, (int32_t *) dst->op_params + 2, sizeof(float));
|
|
|
|
|
|
- if (__builtin_popcount(n_head) != 1) {
|
|
|
- GGML_ASSERT(false && "only power-of-two n_head implemented");
|
|
|
- }
|
|
|
-
|
|
|
const int n_heads_log2_floor = 1 << (int) floor(log2(n_head));
|
|
|
const float m0 = powf(2.0f, -(max_bias) / n_heads_log2_floor);
|
|
|
+ const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_heads_log2_floor);
|
|
|
|
|
|
[encoder setComputePipelineState:ctx->pipeline_alibi_f32];
|
|
|
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
|
|
@@ -1239,7 +1236,9 @@ void ggml_metal_graph_compute(
|
|
|
[encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:15];
|
|
|
[encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:16];
|
|
|
[encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:17];
|
|
|
- [encoder setBytes:&m0 length:sizeof( float) atIndex:18];
|
|
|
+ [encoder setBytes:&m0 length:sizeof( float) atIndex:18];
|
|
|
+ [encoder setBytes:&m1 length:sizeof( float) atIndex:19];
|
|
|
+ [encoder setBytes:&n_heads_log2_floor length:sizeof(int) atIndex:20];
|
|
|
|
|
|
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
|
|
} break;
|