@@ -2181,7 +2181,11 @@ size_t ggml_metal_op_flash_attn_ext_extra_pad(const ggml_tensor * op) {
 
     const bool has_mask = op->src[3] != nullptr;
 
-    if (ggml_metal_op_flash_attn_ext_use_vec(op)) {
+    // note: the non-vec kernel requires more extra memory, so always reserve for it
+    GGML_ASSERT(OP_FLASH_ATTN_EXT_NCPSG >= OP_FLASH_ATTN_EXT_VEC_NCPSG);
+
+    //if (ggml_metal_op_flash_attn_ext_use_vec(op)) {
+    if (false) {
         // note: always reserve the padding space to avoid graph reallocations
         //const bool has_kvpad = ne11 % OP_FLASH_ATTN_EXT_VEC_NCPSG != 0;
         const bool has_kvpad = true;
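
For context, a minimal sketch (not the actual ggml implementation) of the reservation rule this hunk enforces: the extra padding is always sized for the coarser non-vec granularity (`OP_FLASH_ATTN_EXT_NCPSG`), so switching between the vec and non-vec kernels never needs more memory than was reserved. The constant values and the helper names `pad_up`/`extra_pad_bytes` below are assumptions for illustration only.

```cpp
#include <cstddef>
#include <cstdint>

// granularity constants as named in the diff; the values here are assumptions
constexpr int64_t OP_FLASH_ATTN_EXT_NCPSG     = 64;
constexpr int64_t OP_FLASH_ATTN_EXT_VEC_NCPSG = 32;

// round x up to the next multiple of n (same idea as ggml's GGML_PAD macro)
static int64_t pad_up(int64_t x, int64_t n) {
    return (x + n - 1) / n * n;
}

// hypothetical helper: extra bytes needed so that a KV length of ne11 can be
// padded to the non-vec kernel's granularity, given nb bytes per KV position.
// Reserving for the larger NCPSG covers the vec kernel's smaller requirement
// as well, which is what the always-reserve change above relies on.
static size_t extra_pad_bytes(int64_t ne11, size_t nb) {
    static_assert(OP_FLASH_ATTN_EXT_NCPSG >= OP_FLASH_ATTN_EXT_VEC_NCPSG,
                  "non-vec granularity must be the larger one");
    const int64_t padded = pad_up(ne11, OP_FLASH_ATTN_EXT_NCPSG);
    return (size_t)(padded - ne11) * nb;
}
```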