|
|
@@ -802,15 +802,6 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_context * ctx
|
|
|
if (op->src[0]->ne[0] == 256) {
|
|
|
return false;
|
|
|
}
|
|
|
- {
|
|
|
- float logit_softcap;
|
|
|
-
|
|
|
- memcpy(&logit_softcap, ((const float *) op->op_params) + 2, sizeof(logit_softcap));
|
|
|
-
|
|
|
- if (logit_softcap != 0.0f) {
|
|
|
- return false;
|
|
|
- }
|
|
|
- }
|
|
|
return ctx->support_simdgroup_mm; // TODO: over-restricted for vec-kernels
|
|
|
case GGML_OP_MUL_MAT:
|
|
|
case GGML_OP_MUL_MAT_ID:
|
|
|
@@ -2633,9 +2624,14 @@ static enum ggml_status ggml_metal_graph_compute(
|
|
|
|
|
|
float scale;
|
|
|
float max_bias;
|
|
|
+ float logit_softcap;
|
|
|
+ memcpy(&scale, ((int32_t *) dst->op_params) + 0, sizeof(scale));
|
|
|
+ memcpy(&max_bias, ((int32_t *) dst->op_params) + 1, sizeof(max_bias));
|
|
|
+ memcpy(&logit_softcap, ((int32_t *) dst->op_params) + 2, sizeof(logit_softcap));
|
|
|
|
|
|
- memcpy(&scale, ((int32_t *) dst->op_params) + 0, sizeof(scale));
|
|
|
- memcpy(&max_bias, ((int32_t *) dst->op_params) + 1, sizeof(max_bias));
|
|
|
+ if (logit_softcap != 0.0f) {
|
|
|
+ scale /= logit_softcap;
|
|
|
+ }
|
|
|
|
|
|
const uint32_t n_head = src0->ne[2];
|
|
|
const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));
|
|
|
@@ -2686,30 +2682,31 @@ static enum ggml_status ggml_metal_graph_compute(
|
|
|
} else {
|
|
|
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:3];
|
|
|
}
|
|
|
- [encoder setBuffer:id_dst offset:offs_dst atIndex:4];
|
|
|
- [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:5];
|
|
|
- [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:6];
|
|
|
- [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:7];
|
|
|
- [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:8];
|
|
|
- [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:9];
|
|
|
- [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:10];
|
|
|
- [encoder setBytes:&ne11 length:sizeof( int64_t) atIndex:11];
|
|
|
- [encoder setBytes:&ne12 length:sizeof( int64_t) atIndex:12];
|
|
|
- [encoder setBytes:&ne13 length:sizeof( int64_t) atIndex:13];
|
|
|
- [encoder setBytes:&nb11 length:sizeof(uint64_t) atIndex:14];
|
|
|
- [encoder setBytes:&nb12 length:sizeof(uint64_t) atIndex:15];
|
|
|
- [encoder setBytes:&nb13 length:sizeof(uint64_t) atIndex:16];
|
|
|
- [encoder setBytes:&nb21 length:sizeof(uint64_t) atIndex:17];
|
|
|
- [encoder setBytes:&nb22 length:sizeof(uint64_t) atIndex:18];
|
|
|
- [encoder setBytes:&nb23 length:sizeof(uint64_t) atIndex:19];
|
|
|
- [encoder setBytes:&nb31 length:sizeof(uint64_t) atIndex:20];
|
|
|
- [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:21];
|
|
|
- [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:22];
|
|
|
- [encoder setBytes:&scale length:sizeof( float) atIndex:23];
|
|
|
- [encoder setBytes:&max_bias length:sizeof( float) atIndex:24];
|
|
|
- [encoder setBytes:&m0 length:sizeof(m0) atIndex:25];
|
|
|
- [encoder setBytes:&m1 length:sizeof(m1) atIndex:26];
|
|
|
- [encoder setBytes:&n_head_log2 length:sizeof(n_head_log2) atIndex:27];
|
|
|
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:4];
|
|
|
+ [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:5];
|
|
|
+ [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:6];
|
|
|
+ [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:7];
|
|
|
+ [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:8];
|
|
|
+ [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:9];
|
|
|
+ [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:10];
|
|
|
+ [encoder setBytes:&ne11 length:sizeof( int64_t) atIndex:11];
|
|
|
+ [encoder setBytes:&ne12 length:sizeof( int64_t) atIndex:12];
|
|
|
+ [encoder setBytes:&ne13 length:sizeof( int64_t) atIndex:13];
|
|
|
+ [encoder setBytes:&nb11 length:sizeof(uint64_t) atIndex:14];
|
|
|
+ [encoder setBytes:&nb12 length:sizeof(uint64_t) atIndex:15];
|
|
|
+ [encoder setBytes:&nb13 length:sizeof(uint64_t) atIndex:16];
|
|
|
+ [encoder setBytes:&nb21 length:sizeof(uint64_t) atIndex:17];
|
|
|
+ [encoder setBytes:&nb22 length:sizeof(uint64_t) atIndex:18];
|
|
|
+ [encoder setBytes:&nb23 length:sizeof(uint64_t) atIndex:19];
|
|
|
+ [encoder setBytes:&nb31 length:sizeof(uint64_t) atIndex:20];
|
|
|
+ [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:21];
|
|
|
+ [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:22];
|
|
|
+ [encoder setBytes:&scale length:sizeof( float) atIndex:23];
|
|
|
+ [encoder setBytes:&max_bias length:sizeof( float) atIndex:24];
|
|
|
+ [encoder setBytes:&m0 length:sizeof(m0) atIndex:25];
|
|
|
+ [encoder setBytes:&m1 length:sizeof(m1) atIndex:26];
|
|
|
+ [encoder setBytes:&n_head_log2 length:sizeof(n_head_log2) atIndex:27];
|
|
|
+ [encoder setBytes:&logit_softcap length:sizeof(logit_softcap) atIndex:28];
|
|
|
|
|
|
if (!use_vec_kernel) {
|
|
|
// half8x8 kernel
|