@@ -1518,6 +1518,15 @@ struct vk_quantize_q8_1_push_constants {
     uint32_t num_blocks;
 };

+struct vk_op_flash_attn_split_k_reduce_push_constants {
+    uint32_t D;
+    uint32_t ne1;
+    uint32_t ne2;
+    uint32_t ne3;
+    uint32_t k_num;
+    uint32_t sinks;
+};
+
 // Allow pre-recording command buffers
 struct vk_staging_memcpy {
     vk_staging_memcpy(void * _dst, const void * _src, size_t _n) : dst(_dst), src(_src), n(_n) {}
@@ -3982,7 +3991,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
     ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_MXFP4], "get_rows_mxfp4_f32", get_rows_mxfp4_f32_len, get_rows_mxfp4_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);

     ggml_vk_create_pipeline(device, device->pipeline_matmul_split_k_reduce, "split_k_reduce", split_k_reduce_len, split_k_reduce_data, "main", 2, 2 * sizeof(uint32_t), {256 * 4, 1, 1}, {}, 1);
-    ggml_vk_create_pipeline(device, device->pipeline_flash_attn_split_k_reduce, "fa_split_k_reduce", fa_split_k_reduce_len, fa_split_k_reduce_data, "main", 3, 5 * sizeof(uint32_t), {1, device->subgroup_size, 1}, {device->subgroup_size}, 1, true);
+    ggml_vk_create_pipeline(device, device->pipeline_flash_attn_split_k_reduce, "fa_split_k_reduce", fa_split_k_reduce_len, fa_split_k_reduce_data, "main", 3, sizeof(vk_op_flash_attn_split_k_reduce_push_constants), {1, device->subgroup_size, 1}, {device->subgroup_size}, 1, true);

     if (device->subgroup_clustered && device->subgroup_require_full_support) {
         ggml_vk_create_pipeline(device, device->pipeline_quantize_q8_1_x4, "quantize_q8_1_x4", quantize_q8_1_x4_subgroup_len, quantize_q8_1_x4_subgroup_data, "main", 2, sizeof(vk_quantize_q8_1_push_constants), {32 * device->subgroup_size / 8, 1, 1}, { device->subgroup_size }, 1, true, true);
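With the new ne2 field, the reduce pipeline's push-constant range grows from the old 5 * sizeof(uint32_t) (20 bytes) to sizeof(vk_op_flash_attn_split_k_reduce_push_constants) (24 bytes). A minimal standalone sketch of that size relationship; the redeclaration and static_assert below are illustrative only, not part of the patch:

    #include <cstdint>

    // Illustrative copy of the new push-constant block: six tightly packed
    // 32-bit fields, so sizeof() is 24 bytes, one field more than the old
    // 5 * sizeof(uint32_t) range used when creating the pipeline.
    struct vk_op_flash_attn_split_k_reduce_push_constants {
        uint32_t D, ne1, ne2, ne3, k_num, sinks;
    };

    static_assert(sizeof(vk_op_flash_attn_split_k_reduce_push_constants) == 6 * sizeof(uint32_t),
                  "push constants must stay tightly packed to match the shader-side layout");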
@@ -8457,14 +8466,14 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
         GGML_ASSERT(0);
     }

-    if (N == 1 && qk_ratio > 1 && qk_ratio <= max_gqa &&
+    if (N <= 8 && qk_ratio > 1 && qk_ratio <= max_gqa &&
         qk_ratio * nek2 == neq2 && nek2 == nev2 && nem2 <= 1) {
         // grouped query attention - make the N dimension equal to gqa_ratio, reduce
         // workgroups proportionally in y dimension. The shader will detect gqa_ratio > 1
         // and change addressing calculations to index Q's dimension 2.
         gqa_ratio = qk_ratio;
         N = gqa_ratio;
-        workgroups_y /= N;
+        workgroups_y /= gqa_ratio;
     }

     bool small_rows = N <= get_fa_num_small_rows(path);
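To illustrate the relaxed GQA branch, a small standalone sketch with made-up shapes (32 query heads sharing 8 K/V heads, so qk_ratio is 4); the head counts are hypothetical and the condition is simplified to the parts shown in this hunk:

    #include <cstdint>
    #include <cstdio>

    int main() {
        const uint32_t neq2 = 32, nek2 = 8;       // hypothetical Q and K/V head counts
        const uint32_t qk_ratio = neq2 / nek2;    // 4
        uint32_t N = 1;                           // single-token decode
        uint32_t workgroups_y = neq2;             // assuming one workgroup per Q head initially

        if (N <= 8 && qk_ratio > 1) {             // simplified form of the condition above
            N = qk_ratio;                         // one workgroup now covers 4 query heads
            workgroups_y /= qk_ratio;             // 32 -> 8, total work stays the same
        }
        std::printf("N=%u workgroups_y=%u\n", N, workgroups_y);
        return 0;
    }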
@@ -8526,6 +8535,8 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
     }

     assert(pipeline);
+    // Compile early to initialize wg_denoms.
+    ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1);

     uint32_t split_kv = KV;
     uint32_t split_k = 1;
@@ -8533,22 +8544,24 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
     // Use a placeholder core count if one isn't available. split_k is a big help for perf.
     const uint32_t shader_core_count = ctx->device->shader_core_count ? ctx->device->shader_core_count : 16;

-    // Try to use split_k when KV is large enough to be worth the overhead
-    if (workgroups_x == 1 && shader_core_count > 0) {
+    // Try to use split_k when KV is large enough to be worth the overhead.
+    // Must either be a single batch or be using gqa, we can't mix the two.
+    if (workgroups_x <= pipeline->wg_denoms[0] && (workgroups_x == 1 || gqa_ratio > 1)) {
         // Try to run two workgroups per SM.
-        split_k = shader_core_count * 2 / (workgroups_y * workgroups_z);
+        split_k = shader_core_count * 2 / (workgroups_x * workgroups_y * workgroups_z);
         if (split_k > 1) {
             // Try to evenly split KV into split_k chunks, but it needs to be a multiple
             // of "align", so recompute split_k based on that.
             split_kv = ROUNDUP_POW2(std::max(1u, KV / split_k), alignment);
             split_k = CEIL_DIV(KV, split_kv);
-            workgroups_x = split_k;
         }
     }

     // Reserve space for split_k temporaries. For each split x batch, we need to store the O matrix (D x ne1)
     // and the per-row m and L values (ne1 rows). We store all the matrices first, followed by the rows.
-    const uint64_t split_k_size = split_k > 1 ? (HSV * ne1 * sizeof(float) + ne1 * sizeof(float) * 2) * split_k * ne3 : 0;
+    // For matrices, the order is (inner to outer) [HSV, ne1, k, ne2, ne3].
+    // For L/M, the order is (inner to outer) [ne1, k, ne2, ne3].
+    const uint64_t split_k_size = split_k > 1 ? (HSV * ne1 * sizeof(float) + ne1 * sizeof(float) * 2) * split_k * ne2 * ne3 : 0;
     if (split_k_size > ctx->device->properties.limits.maxStorageBufferRange) {
         GGML_ABORT("Requested preallocation size is too large");
     }
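To make the sizing above concrete, a standalone sketch with hypothetical values (the shader core count, KV length, alignment and the HSV/ne1/ne2/ne3 shapes are all made up); CEIL_DIV and ROUNDUP_POW2 are local stand-ins for the helpers used in the file, and the last expression mirrors the new split_k_size formula with its ne2 factor:

    #include <algorithm>
    #include <cstdint>
    #include <cstdio>

    #define CEIL_DIV(a, b)      (((a) + (b) - 1) / (b))
    #define ROUNDUP_POW2(a, b)  (((a) + (b) - 1) & ~((b) - 1))

    int main() {
        const uint32_t shader_core_count = 16, KV = 8192, alignment = 256;   // hypothetical
        const uint32_t workgroups_x = 1, workgroups_y = 8, workgroups_z = 1; // hypothetical

        // Aim for roughly two workgroups per SM across the whole dispatch.
        uint32_t split_k  = shader_core_count * 2 / (workgroups_x * workgroups_y * workgroups_z); // 4
        uint32_t split_kv = ROUNDUP_POW2(std::max(1u, KV / split_k), alignment);                  // 2048
        split_k           = CEIL_DIV(KV, split_kv);                                               // 4

        // Temporary buffer: split_k * ne2 * ne3 copies of the O matrix (HSV x ne1 floats)
        // plus two floats (L and M) per row, matching the layout comment in the patch.
        const uint32_t HSV = 128, ne1 = 4, ne2 = 8, ne3 = 1;                 // hypothetical shapes
        const uint64_t split_k_size =
            (HSV * ne1 * sizeof(float) + ne1 * sizeof(float) * 2) * split_k * ne2 * ne3;

        std::printf("split_k=%u split_kv=%u bytes=%llu\n",
                    split_k, split_kv, (unsigned long long)split_k_size);
        return 0;
    }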
@@ -8559,7 +8572,6 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx

     {
         // Request descriptor sets
-        ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1);
         if (split_k > 1) {
             ggml_pipeline_request_descriptor_sets(ctx, ctx->device->pipeline_flash_attn_split_k_reduce, 1);
         }
@@ -8608,7 +8620,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
         if (ctx->prealloc_split_k_need_sync) {
             ggml_vk_sync_buffers(ctx, subctx);
         }
-
+        workgroups_x *= pipeline->wg_denoms[0];
         vk_subbuffer split_k_buf = ggml_vk_subbuffer(ctx, ctx->prealloc_split_k, 0);
         ggml_vk_dispatch_pipeline(ctx, subctx, pipeline,
             {q_buf, k_buf, v_buf, mask_buf, sinks_buf, split_k_buf},
@@ -8616,15 +8628,19 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
             // there's no more than one tile of rows (i.e. workgroups_x would have been
             // one). We reuse workgroups_x to mean the number of splits, so we need to
             // cancel out the divide by wg_denoms[0].
-            pc, { workgroups_x * pipeline->wg_denoms[0], workgroups_y, workgroups_z });
+            pc, { split_k * workgroups_x, workgroups_y, workgroups_z });

         ggml_vk_sync_buffers(ctx, subctx);
-        const std::array<uint32_t, 5> pc2 = { HSV, (uint32_t)ne1, (uint32_t)ne3, split_k, (sinks != nullptr) };
+        const vk_op_flash_attn_split_k_reduce_push_constants pc2 = { HSV, (uint32_t)ne1, (uint32_t)ne2, (uint32_t)ne3, split_k, (sinks != nullptr) };
         ggml_vk_dispatch_pipeline(ctx, subctx, ctx->device->pipeline_flash_attn_split_k_reduce,
             {split_k_buf, sinks_buf, dst_buf},
-            pc2, { (uint32_t)ne1, HSV, (uint32_t)ne3 });
+            pc2, { (uint32_t)ne1, HSV, (uint32_t)(ne2 * ne3) });
         ctx->prealloc_split_k_need_sync = true;
     } else {
+        if (gqa_ratio > 1) {
+            // When using gqa, we want one actual workgroup per batch, so cancel out wg_denoms
+            workgroups_x *= pipeline->wg_denoms[0];
+        }
         ggml_vk_dispatch_pipeline(ctx, subctx, pipeline,
             {q_buf, k_buf, v_buf, mask_buf, sinks_buf, dst_buf},
             pc, { workgroups_x, workgroups_y, workgroups_z });
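Both dispatch paths above lean on the same cancellation: the dispatch helper divides each dimension by the pipeline's wg_denoms, so pre-multiplying workgroups_x by wg_denoms[0] makes the value that reaches the dispatch come out as the intended workgroup count. A small standalone sketch with hypothetical numbers (wg_denoms[0] = 32, split_k = 4):

    #include <cstdint>
    #include <cstdio>

    // The per-dimension divide the dispatch helper applies (CEIL_DIV by wg_denoms).
    static uint32_t workgroups_launched(uint32_t elements_x, uint32_t wg_denom_x) {
        return (elements_x + wg_denom_x - 1) / wg_denom_x;
    }

    int main() {
        const uint32_t wg_denom_x = 32;   // hypothetical pipeline->wg_denoms[0]
        const uint32_t split_k    = 4;    // hypothetical number of KV splits
        uint32_t workgroups_x     = 1;    // desired workgroups in x

        workgroups_x *= wg_denom_x;       // the cancellation step added by the patch

        // split_k path: one workgroup per split.
        std::printf("%u\n", workgroups_launched(split_k * workgroups_x, wg_denom_x)); // prints 4
        // non-split GQA path: the desired count comes through unchanged.
        std::printf("%u\n", workgroups_launched(workgroups_x, wg_denom_x));           // prints 1
        return 0;
    }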