@@ -226,6 +226,10 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) {
     GGML_TENSOR_LOCALS(uint64_t, nb0, node->src[0], nb);
     GGML_TENSOR_LOCALS( int64_t, ne1, node->src[1], ne);
     GGML_TENSOR_LOCALS(uint64_t, nb1, node->src[1], nb);
+    GGML_TENSOR_LOCALS( int64_t, ne2, node->src[2], ne);
+    GGML_TENSOR_LOCALS(uint64_t, nb2, node->src[2], nb);
+    GGML_TENSOR_LOCALS( int64_t, ne3, node->src[3], ne);
+    GGML_TENSOR_LOCALS(uint64_t, nb3, node->src[3], nb);
     GGML_TENSOR_LOCALS( int64_t, ne,  node, ne);
     GGML_TENSOR_LOCALS(uint64_t, nb,  node, nb);

@@ -237,6 +241,14 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) {
         GGML_LOG_DEBUG("%s: src1 - %4s [%5lld, %5lld, %5lld, %5lld] [%5lld, %5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(node->src[1]->type), ne10, ne11, ne12, ne13, nb10, nb11, nb12, nb13,
                 ggml_is_contiguous(node->src[1]), node->src[1]->name);
     }
+    if (node->src[2]) {
+        GGML_LOG_DEBUG("%s: src2 - %4s [%5lld, %5lld, %5lld, %5lld] [%5lld, %5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(node->src[2]->type), ne20, ne21, ne22, ne23, nb20, nb21, nb22, nb23,
+                ggml_is_contiguous(node->src[2]), node->src[2]->name);
+    }
+    if (node->src[3]) {
+        GGML_LOG_DEBUG("%s: src3 - %4s [%5lld, %5lld, %5lld, %5lld] [%5lld, %5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(node->src[3]->type), ne30, ne31, ne32, ne33, nb30, nb31, nb32, nb33,
+                ggml_is_contiguous(node->src[3]), node->src[3]->name);
+    }
     if (node) {
         GGML_LOG_DEBUG("%s: node - %4s [%5lld, %5lld, %5lld, %5lld] [%5lld, %5lld, %5lld, %5lld], 1, %s\n", __func__, ggml_type_name(node->type), ne0, ne1, ne2, ne3, nb0, nb1, nb2, nb3,
                 node->name);
@@ -1889,20 +1901,69 @@ bool ggml_metal_op_flash_attn_ext_use_vec(const ggml_tensor * op) {
    return (ne01 < 20) && (ne00 % 32 == 0);
}

+size_t ggml_metal_op_flash_attn_ext_extra_pad(const ggml_tensor * op) {
+    assert(op->op == GGML_OP_FLASH_ATTN_EXT);
+
+    GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
+    GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
+    GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
+    GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
+    GGML_TENSOR_LOCALS( int32_t, ne2, op->src[2], ne);
+    GGML_TENSOR_LOCALS(uint64_t, nb2, op->src[2], nb);
+    GGML_TENSOR_LOCALS( int32_t, ne3, op->src[3], ne);
+    GGML_TENSOR_LOCALS(uint64_t, nb3, op->src[3], nb);
+
+    size_t res = 0;
+
+    const bool has_mask = op->src[3] != nullptr;
+
+    if (ggml_metal_op_flash_attn_ext_use_vec(op)) {
+        const bool has_kvpad = ne11 % 32 != 0;
+
+        if (has_kvpad) {
+            res += 32*(
+                nb11*ne12*ne13 +
+                nb21*ne22*ne23 +
+                (has_mask ? ggml_type_size(GGML_TYPE_F16)*ne31*ne32*ne33 : 0));
+        }
+    } else {
+        const bool has_kvpad = ne11 % 64 != 0;
+
+        if (has_kvpad) {
+            res += 64*(
+                nb11*ne12*ne13 +
+                nb21*ne22*ne23 +
+                (has_mask ? ggml_type_size(GGML_TYPE_F16)*ne31*ne32*ne33 : 0));
+        }
+    }
+
+    return res;
+}
+
 size_t ggml_metal_op_flash_attn_ext_extra_tmp(const ggml_tensor * op) {
     assert(op->op == GGML_OP_FLASH_ATTN_EXT);

-    const int64_t nwg = 32;
+    GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
+    GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
+    //GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
+    //GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
+    GGML_TENSOR_LOCALS( int32_t, ne2, op->src[2], ne);
+    GGML_TENSOR_LOCALS(uint64_t, nb2, op->src[2], nb);
+    //GGML_TENSOR_LOCALS( int32_t, ne3, op->src[3], ne);
+    //GGML_TENSOR_LOCALS(uint64_t, nb3, op->src[3], nb);
+
+    size_t res = 0;

-    const int64_t ne01 = op->src[0]->ne[1];
-    const int64_t ne02 = op->src[0]->ne[2];
-    const int64_t ne03 = op->src[0]->ne[3];
-    const int64_t ne20 = op->src[2]->ne[0];
+    if (ggml_metal_op_flash_attn_ext_use_vec(op)) {
+        const int64_t nwg = 32;

-    // temp buffer for writing the results from each workgroup
-    // - ne20: the size of the Value head
-    // - + 2:  the S and M values for each intermediate result
-    return ggml_type_size(GGML_TYPE_F32)*(ne01*ne02*ne03*nwg*(ne20 + 2));
+        // temp buffer for writing the results from each workgroup
+        // - ne20: the size of the Value head
+        // - + 2:  the S and M values for each intermediate result
+        res += ggml_type_size(GGML_TYPE_F32)*(ne01*ne02*ne03*nwg*(ne20 + 2));
+    }
+
+    return res;
 }

 int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
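For intuition, a worked sketch of the two sizing helpers above, with hypothetical shapes (all concrete numbers are illustrative, not taken from this patch):

    // vec path (ne01 < 20): ne11 = 100 KV cells, 100 % 32 != 0 -> has_kvpad
    // K: nb11 = 256 bytes/row, ne12 = 8, ne13 = 1; V: nb21 = 256, ne22 = 8, ne23 = 1
    // mask present (f16, 2 bytes/element): ne31 = 32, ne32 = ne33 = 1
    //   extra_pad = 32*(256*8*1 + 256*8*1 + 2*32*1*1) = 32*4160 = 133120 bytes
    // i.e. room for 32 padded K rows, 32 padded V rows and 32 padded mask rows
    //
    // same op with nwg = 32 workgroups, ne01 = 1, ne02 = 32, ne03 = 1, ne20 = 128:
    //   extra_tmp = sizeof(float)*(1*32*1*32*(128 + 2)) = 4*133120 = 532480 bytes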
@@ -1924,8 +1985,7 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
    GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
    GGML_TENSOR_LOCALS( int32_t, nb, op, nb);

-    GGML_ASSERT(ne00 % 4 == 0);
-    GGML_ASSERT(ne11 % 32 == 0);
+    GGML_ASSERT(ne00 % 4 == 0);

    GGML_ASSERT(op->src[0]->type == GGML_TYPE_F32);
    GGML_ASSERT(op->src[1]->type == op->src[2]->type);
@@ -1935,8 +1995,8 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
    GGML_ASSERT(ne12 == ne22);

    GGML_ASSERT(!op->src[3] || op->src[3]->type == GGML_TYPE_F16);
-    GGML_ASSERT(!op->src[3] || op->src[3]->ne[1] >= GGML_PAD(op->src[0]->ne[1], 8) &&
-            "the Flash-Attention Metal kernel requires the mask to be padded to 8 and at least n_queries big");
+    GGML_ASSERT(!op->src[3] || op->src[3]->ne[1] >= op->src[0]->ne[1] &&
+            "the Flash-Attention Metal kernel requires the mask to be at least n_queries big");

    float scale;
    float max_bias;
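A quick sketch of what the relaxed mask assertion changes, with an illustrative query count:

    // before: ne01 = 5 queries required ne31 >= GGML_PAD(5, 8) = 8 mask rows
    // after:  ne31 >= 5 suffices; the pad pass introduced below writes the
    //         missing padded rows into the scratch region instead of
    //         requiring the caller to over-allocate the mask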
@@ -1963,6 +2023,20 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {

    GGML_ASSERT(ne01 < 65536);

+    ggml_metal_buffer_id bid_src0 = ggml_metal_get_buffer_id(op->src[0]);
+    ggml_metal_buffer_id bid_src1 = ggml_metal_get_buffer_id(op->src[1]);
+    ggml_metal_buffer_id bid_src2 = ggml_metal_get_buffer_id(op->src[2]);
+    ggml_metal_buffer_id bid_src3 = has_mask  ? ggml_metal_get_buffer_id(op->src[3]) : bid_src0;
+    ggml_metal_buffer_id bid_src4 = has_sinks ? ggml_metal_get_buffer_id(op->src[4]) : bid_src0;
+
+    ggml_metal_buffer_id bid_dst = ggml_metal_get_buffer_id(op);
+
+    ggml_metal_buffer_id bid_pad = bid_dst;
+    bid_pad.offs += ggml_nbytes(op);
+
+    ggml_metal_buffer_id bid_tmp = bid_pad;
+    bid_tmp.offs += ggml_metal_op_flash_attn_ext_extra_pad(op);
+
    if (!ggml_metal_op_flash_attn_ext_use_vec(op)) {
        // half8x8 kernel
        const int64_t nqptg = 8; // queries per threadgroup    !! sync with kernel template arguments !!
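The bid_* setup above carves the op's output allocation into three consecutive regions; a sketch of the intended layout (region sizes per the two helpers above):

    // bid_dst.offs                bid_pad.offs               bid_tmp.offs
    //  |<--- ggml_nbytes(op) --->|<--- ...extra_pad(op) --->|<--- ...extra_tmp(op) --->|
    //  [ attention output (dst)  ][ padded K/V/mask rows    ][ per-workgroup partials  ]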
@@ -1972,6 +2046,48 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
        GGML_ASSERT(nqptg % 8  == 0);
        GGML_ASSERT(ncpsg % 32 == 0);

+        const bool has_kvpad = ne11 % ncpsg != 0;
+
+        if (has_kvpad) {
+            assert(ggml_metal_op_flash_attn_ext_extra_pad(op) != 0);
+
+            ggml_metal_kargs_flash_attn_ext_pad args0 = {
+                /*.ne11    =*/ ne11,
+                /*.ne_12_2 =*/ ne12,
+                /*.ne_12_3 =*/ ne13,
+                /*.nb11    =*/ nb11,
+                /*.nb12    =*/ nb12,
+                /*.nb13    =*/ nb13,
+                /*.nb21    =*/ nb21,
+                /*.nb22    =*/ nb22,
+                /*.nb23    =*/ nb23,
+                /*.ne31    =*/ ne31,
+                /*.ne32    =*/ ne32,
+                /*.ne33    =*/ ne33,
+                /*.nb31    =*/ nb31,
+                /*.nb32    =*/ nb32,
+                /*.nb33    =*/ nb33,
+            };
+
+            ggml_metal_pipeline_t pipeline0 = ggml_metal_library_get_pipeline_flash_attn_ext_pad(lib, op, has_mask, ncpsg);
+
+            ggml_metal_encoder_set_pipeline(enc, pipeline0);
+            ggml_metal_encoder_set_bytes   (enc, &args0, sizeof(args0), 0);
+            ggml_metal_encoder_set_buffer  (enc, bid_src1, 1);
+            ggml_metal_encoder_set_buffer  (enc, bid_src2, 2);
+            ggml_metal_encoder_set_buffer  (enc, bid_src3, 3);
+            ggml_metal_encoder_set_buffer  (enc, bid_pad,  4);
+
+            assert(ne12 == ne22);
+            assert(ne13 == ne23);
+
+            ggml_metal_encoder_dispatch_threadgroups(enc, ncpsg, std::max(ne12, ne32), std::max(ne13, ne33), 32, 1, 1);
+
+            ggml_metal_op_concurrency_reset(ctx);
+        } else {
+            assert(ggml_metal_op_flash_attn_ext_extra_pad(op) == 0);
+        }
+
        const int is_q = ggml_is_quantized(op->src[1]->type) ? 1 : 0;

        // 2*(2*ncpsg)
@@ -2021,6 +2137,7 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
            /*.nb21          =*/ nb21,
            /*.nb22          =*/ nb22,
            /*.nb23          =*/ nb23,
+            /*.ne31          =*/ ne31,
            /*.ne32          =*/ ne32,
            /*.ne33          =*/ ne33,
            /*.nb31          =*/ nb31,
@@ -2037,24 +2154,17 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
            /*.logit_softcap =*/ logit_softcap,
        };

-        ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_flash_attn_ext(lib, op, has_mask, has_sinks, has_bias, has_scap, nsg);
+        ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_flash_attn_ext(lib, op, has_mask, has_sinks, has_bias, has_scap, has_kvpad, nsg);

        ggml_metal_encoder_set_pipeline(enc, pipeline);
        ggml_metal_encoder_set_bytes   (enc, &args, sizeof(args), 0);
-        ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
-        ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[1]), 2);
-        ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[2]), 3);
-        if (op->src[3]) {
-            ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[3]), 4);
-        } else {
-            ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[0]), 4);
-        }
-        if (op->src[4]) {
-            ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[4]), 5);
-        } else {
-            ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[0]), 5);
-        }
-        ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op), 6);
+        ggml_metal_encoder_set_buffer  (enc, bid_src0, 1);
+        ggml_metal_encoder_set_buffer  (enc, bid_src1, 2);
+        ggml_metal_encoder_set_buffer  (enc, bid_src2, 3);
+        ggml_metal_encoder_set_buffer  (enc, bid_src3, 4);
+        ggml_metal_encoder_set_buffer  (enc, bid_src4, 5);
+        ggml_metal_encoder_set_buffer  (enc, bid_pad,  6);
+        ggml_metal_encoder_set_buffer  (enc, bid_dst,  7);

        ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
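With the pad region in play, both branches now bind the same argument table; for reference, the index assignments used above (src roles per the flash-attention convention: src0 = Q, src1 = K, src2 = V, src3 = mask, src4 = sinks):

    // 0: kargs (set_bytes)     4: bid_src3 (mask, or Q as dummy)
    // 1: bid_src0 (Q)          5: bid_src4 (sinks, or Q as dummy)
    // 2: bid_src1 (K)          6: bid_pad  (padded K/V/mask tail)
    // 3: bid_src2 (V)          7: bid_dst  (or bid_tmp when nwg > 1 in the vec path)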

@@ -2070,6 +2180,48 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
        GGML_ASSERT(nqptg % 1  == 0);
        GGML_ASSERT(ncpsg % 32 == 0);

+        const bool has_kvpad = ne11 % ncpsg != 0;
+
+        if (has_kvpad) {
+            assert(ggml_metal_op_flash_attn_ext_extra_pad(op) != 0);
+
+            ggml_metal_kargs_flash_attn_ext_pad args0 = {
+                /*.ne11    =*/ ne11,
+                /*.ne_12_2 =*/ ne12,
+                /*.ne_12_3 =*/ ne13,
+                /*.nb11    =*/ nb11,
+                /*.nb12    =*/ nb12,
+                /*.nb13    =*/ nb13,
+                /*.nb21    =*/ nb21,
+                /*.nb22    =*/ nb22,
+                /*.nb23    =*/ nb23,
+                /*.ne31    =*/ ne31,
+                /*.ne32    =*/ ne32,
+                /*.ne33    =*/ ne33,
+                /*.nb31    =*/ nb31,
+                /*.nb32    =*/ nb32,
+                /*.nb33    =*/ nb33,
+            };
+
+            ggml_metal_pipeline_t pipeline0 = ggml_metal_library_get_pipeline_flash_attn_ext_pad(lib, op, has_mask, ncpsg);
+
+            ggml_metal_encoder_set_pipeline(enc, pipeline0);
+            ggml_metal_encoder_set_bytes   (enc, &args0, sizeof(args0), 0);
+            ggml_metal_encoder_set_buffer  (enc, bid_src1, 1);
+            ggml_metal_encoder_set_buffer  (enc, bid_src2, 2);
+            ggml_metal_encoder_set_buffer  (enc, bid_src3, 3);
+            ggml_metal_encoder_set_buffer  (enc, bid_pad,  4);
+
+            assert(ne12 == ne22);
+            assert(ne13 == ne23);
+
+            ggml_metal_encoder_dispatch_threadgroups(enc, ncpsg, std::max(ne12, ne32), std::max(ne13, ne33), 32, 1, 1);
+
+            ggml_metal_op_concurrency_reset(ctx);
+        } else {
+            assert(ggml_metal_op_flash_attn_ext_extra_pad(op) == 0);
+        }
+
        // ne00 + 2*ncpsg*(nsg)
        // for each query, we load it as f16 in shared memory (ne00)
        // and store the soft_max values and the mask
@@ -2134,6 +2286,7 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
            /*.nb21          =*/ nb21,
            /*.nb22          =*/ nb22,
            /*.nb23          =*/ nb23,
+            /*.ne31          =*/ ne31,
            /*.ne32          =*/ ne32,
            /*.ne33          =*/ ne33,
            /*.nb31          =*/ nb31,
@@ -2150,25 +2303,17 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
            /*.logit_softcap =*/ logit_softcap,
        };

-        ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_flash_attn_ext_vec(lib, op, has_mask, has_sinks, has_bias, has_scap, nsg, nwg);
+        ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_flash_attn_ext_vec(lib, op, has_mask, has_sinks, has_bias, has_scap, has_kvpad, nsg, nwg);

        GGML_ASSERT(nsg*32 <= ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));

        ggml_metal_encoder_set_pipeline(enc, pipeline);
        ggml_metal_encoder_set_bytes   (enc, &args, sizeof(args), 0);
-        ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
-        ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[1]), 2);
-        ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[2]), 3);
-        if (op->src[3]) {
-            ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[3]), 4);
-        } else {
-            ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[0]), 4);
-        }
-        if (op->src[4]) {
-            ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[4]), 5);
-        } else {
-            ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[0]), 5);
-        }
+        ggml_metal_encoder_set_buffer  (enc, bid_src0, 1);
+        ggml_metal_encoder_set_buffer  (enc, bid_src1, 2);
+        ggml_metal_encoder_set_buffer  (enc, bid_src2, 3);
+        ggml_metal_encoder_set_buffer  (enc, bid_src3, 4);
+        ggml_metal_encoder_set_buffer  (enc, bid_src4, 5);

        const size_t smem = FATTN_SMEM(nsg);

@@ -2176,23 +2321,25 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
        GGML_ASSERT(smem <= props_dev->max_theadgroup_memory_size);

        if (nwg == 1) {
+            assert(ggml_metal_op_flash_attn_ext_extra_tmp(op) == 0);
+
            // using 1 workgroup -> write the result directly into dst
-            ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op), 6);
+            ggml_metal_encoder_set_buffer(enc, bid_pad, 6);
+            ggml_metal_encoder_set_buffer(enc, bid_dst, 7);

            ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);

            ggml_metal_encoder_dispatch_threadgroups(enc, (ne01 + nqptg - 1)/nqptg, ne02, ne03*nwg, 32, nsg, 1);
        } else {
            // sanity checks
+            assert(ggml_metal_op_flash_attn_ext_extra_tmp(op) != 0);
+
            GGML_ASSERT(ne01*ne02*ne03 == ne1*ne2*ne3);
            GGML_ASSERT((uint64_t)ne1*ne2*ne3 <= (1u << 31));

-            ggml_metal_buffer_id bid_dst = ggml_metal_get_buffer_id(op);
-
            // write the results from each workgroup into a temp buffer
-            ggml_metal_buffer_id bid_tmp = bid_dst;
-            bid_tmp.offs += ggml_nbytes(op);
-            ggml_metal_encoder_set_buffer(enc, bid_tmp, 6);
+            ggml_metal_encoder_set_buffer(enc, bid_pad, 6);
+            ggml_metal_encoder_set_buffer(enc, bid_tmp, 7);

            ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
            ggml_metal_encoder_dispatch_threadgroups(enc, (ne01 + nqptg - 1)/nqptg, ne02, ne03*nwg, 32, nsg, 1);
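As a sanity check on the dispatch above, a worked sketch with hypothetical sizes (nqptg = 1 in the vec path):

    // ne01 = 4 queries, ne02 = 32 heads, ne03 = 1, nwg = 32:
    //   threadgroups = ((4 + 1 - 1)/1, 32, 1*32) = (4, 32, 32), 32*nsg threads each
    // each of the nwg workgroups writes ne20 + 2 floats per result row into bid_tmp,
    // which is why ggml_metal_op_flash_attn_ext_extra_tmp() must be non-zero here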