|
|
@@ -1833,6 +1833,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
|
|
// can't use 256 for D==80.
|
|
|
uint32_t wg_size = (small_rows && (D % 32) == 0) ? 256 : 128;
|
|
|
auto rows_cols = fa_rows_cols(D, clamp, type, small_rows);
|
|
|
+ // mask dim1 is padded to GGML_KQ_MASK_PAD (64); rows_cols[0] must divide that padding so we can avoid clamping mask loads
|
|
|
+ GGML_ASSERT((GGML_KQ_MASK_PAD % rows_cols[0]) == 0);
|
|
|
return {wg_size, rows_cols[0], rows_cols[1], (D), clamp};
|
|
|
};
|
|
|
|
|
|
@@ -5511,6 +5513,9 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
|
|
|
// the "aligned" shader variant will forcibly align strides, for performance
|
|
|
(q_stride & 7) == 0 && (k_stride & 7) == 0 && (v_stride & 7) == 0;
|
|
|
|
|
|
+ // mask dim1 (nem1) is padded to a multiple of GGML_KQ_MASK_PAD (64); we rely on this to avoid clamping mask loads
|
|
|
+ GGML_ASSERT((nem1 % GGML_KQ_MASK_PAD) == 0);
|
|
|
+
|
|
|
vk_pipeline pipeline = pipelines[aligned];
|
|
|
assert(pipeline);
|
|
|
|