|
|
@@ -275,6 +275,7 @@ struct vk_device_struct {
|
|
|
bool prefer_host_memory;
|
|
|
bool float_controls_rte_fp16;
|
|
|
bool subgroup_add;
|
|
|
+ bool subgroup_shuffle;
|
|
|
|
|
|
bool integer_dot_product;
|
|
|
|
|
|
@@ -402,12 +403,20 @@ struct vk_device_struct {
|
|
|
vk_pipeline pipeline_conv2d_dw_cwhn_f32;
|
|
|
|
|
|
// [2][2][2] is for {f16acc,f32acc}x{large,small_rows}x{unaligned, aligned}
|
|
|
+ vk_pipeline pipeline_flash_attn_f32_f16_D64_cm2[GGML_TYPE_COUNT][2][2][2];
|
|
|
+ vk_pipeline pipeline_flash_attn_f32_f16_D80_cm2[GGML_TYPE_COUNT][2][2][2];
|
|
|
+ vk_pipeline pipeline_flash_attn_f32_f16_D96_cm2[GGML_TYPE_COUNT][2][2][2];
|
|
|
+ vk_pipeline pipeline_flash_attn_f32_f16_D112_cm2[GGML_TYPE_COUNT][2][2][2];
|
|
|
+ vk_pipeline pipeline_flash_attn_f32_f16_D128_cm2[GGML_TYPE_COUNT][2][2][2];
|
|
|
+ vk_pipeline pipeline_flash_attn_f32_f16_D256_cm2[GGML_TYPE_COUNT][2][2][2];
|
|
|
+
|
|
|
vk_pipeline pipeline_flash_attn_f32_f16_D64[GGML_TYPE_COUNT][2][2][2];
|
|
|
vk_pipeline pipeline_flash_attn_f32_f16_D80[GGML_TYPE_COUNT][2][2][2];
|
|
|
vk_pipeline pipeline_flash_attn_f32_f16_D96[GGML_TYPE_COUNT][2][2][2];
|
|
|
vk_pipeline pipeline_flash_attn_f32_f16_D112[GGML_TYPE_COUNT][2][2][2];
|
|
|
vk_pipeline pipeline_flash_attn_f32_f16_D128[GGML_TYPE_COUNT][2][2][2];
|
|
|
vk_pipeline pipeline_flash_attn_f32_f16_D256[GGML_TYPE_COUNT][2][2][2];
|
|
|
+
|
|
|
vk_pipeline pipeline_flash_attn_split_k_reduce;
|
|
|
|
|
|
std::unordered_map<std::string, vk_pipeline_ref> pipelines;
|
|
|
@@ -1581,13 +1590,29 @@ static void ggml_vk_wait_events(vk_context& ctx, std::vector<vk::Event>&& events
|
|
|
|
|
|
// number of rows/cols for flash attention shader
|
|
|
static constexpr uint32_t flash_attention_num_small_rows = 32;
|
|
|
-static std::array<uint32_t, 2> fa_rows_cols(uint32_t D, uint32_t clamp, ggml_type type, bool small_rows) {
|
|
|
+static constexpr uint32_t scalar_flash_attention_num_small_rows = 1;
|
|
|
+static constexpr uint32_t scalar_flash_attention_num_large_rows = 8;
|
|
|
+
|
|
|
+static uint32_t get_fa_num_small_rows(bool scalar) {
|
|
|
+ return scalar ? scalar_flash_attention_num_small_rows : flash_attention_num_small_rows;
|
|
|
+}
|
|
|
+
|
|
|
+static std::array<uint32_t, 2> fa_rows_cols(bool scalar, uint32_t D, uint32_t clamp, ggml_type type, bool small_rows) {
|
|
|
GGML_UNUSED(clamp);
|
|
|
|
|
|
+ if (scalar) {
|
|
|
+ if (small_rows) {
|
|
|
+ return {scalar_flash_attention_num_small_rows, 64};
|
|
|
+ } else {
|
|
|
+ return {scalar_flash_attention_num_large_rows, 32};
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
// small rows, large cols
|
|
|
if (small_rows) {
|
|
|
- return {flash_attention_num_small_rows, 64};
|
|
|
+ return {get_fa_num_small_rows(scalar), 32};
|
|
|
}
|
|
|
+
|
|
|
// small cols to reduce register count
|
|
|
if (ggml_is_quantized(type) || D == 256) {
|
|
|
return {64, 32};
|
|
|
@@ -1882,65 +1907,66 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
|
|
parameter_count, wg_denoms, specialization_constants, disable_robustness, require_full_subgroups, required_subgroup_size));
|
|
|
};
|
|
|
|
|
|
-#if defined(VK_NV_cooperative_matrix2) && defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
|
|
|
- if (device->coopmat2) {
|
|
|
-
|
|
|
- auto const &fa_wg_denoms = [&](uint32_t D, uint32_t clamp, ggml_type type, bool small_rows) -> std::array<uint32_t, 3> {
|
|
|
- return {fa_rows_cols(D, clamp, type, small_rows)[0], 1, 1};
|
|
|
- };
|
|
|
+ auto const &fa_wg_denoms = [&](bool scalar, uint32_t D, uint32_t clamp, ggml_type type, bool small_rows) -> std::array<uint32_t, 3> {
|
|
|
+ return {fa_rows_cols(scalar, D, clamp, type, small_rows)[0], 1, 1};
|
|
|
+ };
|
|
|
|
|
|
- auto const &fa_spec_constants = [&](uint32_t D, uint32_t clamp, ggml_type type, bool small_rows) -> std::vector<uint32_t> {
|
|
|
- // For large number of rows, 128 invocations seems to work best.
|
|
|
- // For small number of rows (e.g. N==1), 256 works better. But matrix granularity for 256 is 32, so we
|
|
|
- // can't use 256 for D==80.
|
|
|
- uint32_t wg_size = (small_rows && (D % 32) == 0) ? 256 : 128;
|
|
|
- auto rows_cols = fa_rows_cols(D, clamp, type, small_rows);
|
|
|
- // mask dim1 is padded to 64, we rely on this to avoid clamping mask loads
|
|
|
- GGML_ASSERT((GGML_KQ_MASK_PAD % rows_cols[0]) == 0);
|
|
|
- return {wg_size, rows_cols[0], rows_cols[1], (D), clamp};
|
|
|
- };
|
|
|
+ auto const &fa_spec_constants = [&](bool scalar, uint32_t D, uint32_t clamp, ggml_type type, bool small_rows) -> std::vector<uint32_t> {
|
|
|
+ // For large number of rows, 128 invocations seems to work best.
|
|
|
+ // For small number of rows (e.g. N==1), 256 works better. But matrix granularity for 256 is 32, so we
|
|
|
+ // can't use 256 for D==80.
|
|
|
+ // For scalar, use 128 (arbitrary)
|
|
|
+ uint32_t wg_size = scalar ? 128 : ((small_rows && (D % 32) == 0) ? 256 : 128);
|
|
|
+ auto rows_cols = fa_rows_cols(scalar, D, clamp, type, small_rows);
|
|
|
+
|
|
|
+ // D_split can't be larger than a subgroup because we use subgroupShuffle to reduce it.
|
|
|
+ // D_split can't be larger than the LSB of D divided by 4 due to vectorization in the shader.
|
|
|
+ const uint32_t D_lsb = D ^ (D & (D-1));
|
|
|
+ uint32_t D_split = std::min(std::min(device->subgroup_size, 8u), D_lsb / 4);
|
|
|
+
|
|
|
+ // mask dim1 is padded to 64, we rely on this to avoid clamping mask loads
|
|
|
+ GGML_ASSERT((GGML_KQ_MASK_PAD % rows_cols[0]) == 0);
|
|
|
+ return {wg_size, rows_cols[0], rows_cols[1], (D), clamp, D_split};
|
|
|
+ };
|
|
|
|
|
|
-#define CREATE_FA2(TYPE, NAMELC, D) \
|
|
|
- ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D[TYPE][0][0][0], "flash_attn_f32_f16_D" #D "_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc_cm2_len, flash_attn_f32_f16_ ## NAMELC ## _f16acc_cm2_data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(D,1,TYPE,false), fa_spec_constants(D,1,TYPE,false), 1); \
|
|
|
- ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D[TYPE][0][0][1], "flash_attn_f32_f16_D" #D "_aligned_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc_cm2_len, flash_attn_f32_f16_ ## NAMELC ## _f16acc_cm2_data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(D,0,TYPE,false), fa_spec_constants(D,0,TYPE,false), fa_rows_cols(D,0,TYPE,false)[1]); \
|
|
|
- ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D[TYPE][1][0][0], "flash_attn_f32_f16_D" #D "_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _cm2_len, flash_attn_f32_f16_ ## NAMELC ## _cm2_data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(D,1,TYPE,false), fa_spec_constants(D,1,TYPE,false), 1); \
|
|
|
- ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D[TYPE][1][0][1], "flash_attn_f32_f16_D" #D "_aligned_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _cm2_len, flash_attn_f32_f16_ ## NAMELC ## _cm2_data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(D,0,TYPE,false), fa_spec_constants(D,0,TYPE,false), fa_rows_cols(D,0,TYPE,false)[1]); \
|
|
|
- ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D[TYPE][0][1][0], "flash_attn_f32_f16_D" #D "_f16acc_smallrows" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc_cm2_len, flash_attn_f32_f16_ ## NAMELC ## _f16acc_cm2_data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(D,1,TYPE,true), fa_spec_constants(D,1,TYPE,true), 1); \
|
|
|
- ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D[TYPE][0][1][1], "flash_attn_f32_f16_D" #D "_aligned_f16acc_smallrows" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc_cm2_len, flash_attn_f32_f16_ ## NAMELC ## _f16acc_cm2_data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(D,0,TYPE,true), fa_spec_constants(D,0,TYPE,true), fa_rows_cols(D,0,TYPE,true)[1]); \
|
|
|
- ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D[TYPE][1][1][0], "flash_attn_f32_f16_D" #D "_f32acc_smallrows" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _cm2_len, flash_attn_f32_f16_ ## NAMELC ## _cm2_data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(D,1,TYPE,true), fa_spec_constants(D,1,TYPE,true), 1); \
|
|
|
- ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D[TYPE][1][1][1], "flash_attn_f32_f16_D" #D "_aligned_f32acc_smallrows" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _cm2_len, flash_attn_f32_f16_ ## NAMELC ## _cm2_data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(D,0,TYPE,true), fa_spec_constants(D,0,TYPE,true), fa_rows_cols(D,0,TYPE,true)[1]); \
|
|
|
-
|
|
|
-#define CREATE_FA(TYPE, NAMELC) \
|
|
|
- CREATE_FA2(TYPE, NAMELC, 64) \
|
|
|
- CREATE_FA2(TYPE, NAMELC, 80) \
|
|
|
- CREATE_FA2(TYPE, NAMELC, 96) \
|
|
|
- CREATE_FA2(TYPE, NAMELC, 112) \
|
|
|
- CREATE_FA2(TYPE, NAMELC, 128) \
|
|
|
- CREATE_FA2(TYPE, NAMELC, 256)
|
|
|
-
|
|
|
- CREATE_FA(GGML_TYPE_F16, f16)
|
|
|
- CREATE_FA(GGML_TYPE_Q4_0, q4_0)
|
|
|
- CREATE_FA(GGML_TYPE_Q4_1, q4_1)
|
|
|
- CREATE_FA(GGML_TYPE_Q5_0, q5_0)
|
|
|
- CREATE_FA(GGML_TYPE_Q5_1, q5_1)
|
|
|
- CREATE_FA(GGML_TYPE_Q8_0, q8_0)
|
|
|
- // K dequants currently disabled because D dimension is rounded up to 256 and runs inefficiently
|
|
|
- //CREATE_FA(GGML_TYPE_Q2_K, q2_k)
|
|
|
- //CREATE_FA(GGML_TYPE_Q3_K, q3_k)
|
|
|
- //CREATE_FA(GGML_TYPE_Q4_K, q4_k)
|
|
|
- //CREATE_FA(GGML_TYPE_Q5_K, q5_k)
|
|
|
- //CREATE_FA(GGML_TYPE_Q6_K, q6_k)
|
|
|
- //CREATE_FA(GGML_TYPE_IQ1_S, iq1_s)
|
|
|
- //CREATE_FA(GGML_TYPE_IQ1_M, iq1_m)
|
|
|
- //CREATE_FA(GGML_TYPE_IQ2_XXS, iq2_xxs)
|
|
|
- //CREATE_FA(GGML_TYPE_IQ2_XS, iq2_xs)
|
|
|
- //CREATE_FA(GGML_TYPE_IQ2_S, iq2_s)
|
|
|
- //CREATE_FA(GGML_TYPE_IQ3_XXS, iq3_xxs)
|
|
|
- //CREATE_FA(GGML_TYPE_IQ3_S, iq3_s)
|
|
|
- //CREATE_FA(GGML_TYPE_IQ4_XS, iq4_xs)
|
|
|
- CREATE_FA(GGML_TYPE_IQ4_NL, iq4_nl)
|
|
|
+#define CREATE_FA2(TYPE, NAMELC, SCALAR, SUFFIX, D) \
|
|
|
+ ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][0][0][0], "flash_attn_f32_f16_D" #D "_f16acc" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(SCALAR, D,1,TYPE,false), fa_spec_constants(SCALAR, D,1,TYPE,false), 1, true); \
|
|
|
+ ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][0][0][1], "flash_attn_f32_f16_D" #D "_aligned_f16acc" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(SCALAR, D,0,TYPE,false), fa_spec_constants(SCALAR, D,0,TYPE,false), fa_rows_cols(SCALAR,D,0,TYPE,false)[1], true); \
|
|
|
+ ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][1][0][0], "flash_attn_f32_f16_D" #D "_f32acc" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(SCALAR, D,1,TYPE,false), fa_spec_constants(SCALAR, D,1,TYPE,false), 1, true); \
|
|
|
+ ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][1][0][1], "flash_attn_f32_f16_D" #D "_aligned_f32acc" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(SCALAR, D,0,TYPE,false), fa_spec_constants(SCALAR, D,0,TYPE,false), fa_rows_cols(SCALAR,D,0,TYPE,false)[1], true); \
|
|
|
+ ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][0][1][0], "flash_attn_f32_f16_D" #D "_f16acc_smallrows" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(SCALAR, D,1,TYPE,true), fa_spec_constants(SCALAR, D,1,TYPE,true), 1, true); \
|
|
|
+ ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][0][1][1], "flash_attn_f32_f16_D" #D "_aligned_f16acc_smallrows" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(SCALAR, D,0,TYPE,true), fa_spec_constants(SCALAR, D,0,TYPE,true), fa_rows_cols(SCALAR,D,0,TYPE,true)[1], true); \
|
|
|
+ ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][1][1][0], "flash_attn_f32_f16_D" #D "_f32acc_smallrows" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(SCALAR, D,1,TYPE,true), fa_spec_constants(SCALAR, D,1,TYPE,true), 1, true); \
|
|
|
+ ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][1][1][1], "flash_attn_f32_f16_D" #D "_aligned_f32acc_smallrows" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(SCALAR, D,0,TYPE,true), fa_spec_constants(SCALAR, D,0,TYPE,true), fa_rows_cols(SCALAR,D,0,TYPE,true)[1], true); \
|
|
|
+
|
|
|
+#define CREATE_FA(TYPE, NAMELC, SCALAR, SUFFIX) \
|
|
|
+ CREATE_FA2(TYPE, NAMELC, SCALAR, SUFFIX, 64) \
|
|
|
+ CREATE_FA2(TYPE, NAMELC, SCALAR, SUFFIX, 80) \
|
|
|
+ CREATE_FA2(TYPE, NAMELC, SCALAR, SUFFIX, 96) \
|
|
|
+ CREATE_FA2(TYPE, NAMELC, SCALAR, SUFFIX, 112) \
|
|
|
+ CREATE_FA2(TYPE, NAMELC, SCALAR, SUFFIX, 128) \
|
|
|
+ CREATE_FA2(TYPE, NAMELC, SCALAR, SUFFIX, 256)
|
|
|
+
|
|
|
+ CREATE_FA(GGML_TYPE_F16, f16, true, )
|
|
|
+ CREATE_FA(GGML_TYPE_Q4_0, q4_0, true, )
|
|
|
+ CREATE_FA(GGML_TYPE_Q8_0, q8_0, true, )
|
|
|
+#if defined(VK_NV_cooperative_matrix2) && defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
|
|
|
+ if (device->coopmat2) {
|
|
|
+ CREATE_FA(GGML_TYPE_F16, f16, false, _cm2)
|
|
|
+ CREATE_FA(GGML_TYPE_Q4_0, q4_0, false, _cm2)
|
|
|
+ CREATE_FA(GGML_TYPE_Q4_1, q4_1, false, _cm2)
|
|
|
+ CREATE_FA(GGML_TYPE_Q5_0, q5_0, false, _cm2)
|
|
|
+ CREATE_FA(GGML_TYPE_Q5_1, q5_1, false, _cm2)
|
|
|
+ CREATE_FA(GGML_TYPE_Q8_0, q8_0, false, _cm2)
|
|
|
+ CREATE_FA(GGML_TYPE_IQ4_NL, iq4_nl, false, _cm2)
|
|
|
+ }
|
|
|
+#endif
|
|
|
+#undef CREATE_FA2
|
|
|
#undef CREATE_FA
|
|
|
|
|
|
+#if defined(VK_NV_cooperative_matrix2) && defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
|
|
|
+ if (device->coopmat2) {
|
|
|
+
|
|
|
// Create 6 variants, {s,m,l}x{unaligned,aligned}
|
|
|
#define CREATE_MM(PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT) \
|
|
|
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _cm2_len, NAMELC ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1); \
|
|
|
@@ -2837,6 +2863,9 @@ static vk_device ggml_vk_get_device(size_t idx) {
|
|
|
device->subgroup_add = (vk11_props.subgroupSupportedStages & vk::ShaderStageFlagBits::eCompute) &&
|
|
|
(vk11_props.subgroupSupportedOperations & vk::SubgroupFeatureFlagBits::eArithmetic);
|
|
|
|
|
|
+ device->subgroup_shuffle = (vk11_props.subgroupSupportedStages & vk::ShaderStageFlagBits::eCompute) &&
|
|
|
+ (vk11_props.subgroupSupportedOperations & vk::SubgroupFeatureFlagBits::eShuffle);
|
|
|
+
|
|
|
const bool force_disable_f16 = getenv("GGML_VK_DISABLE_F16") != nullptr;
|
|
|
|
|
|
device->fp16 = !force_disable_f16 && fp16_storage && fp16_compute;
|
|
|
@@ -5709,20 +5738,57 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
|
|
|
assert(q->type == GGML_TYPE_F32);
|
|
|
assert(k->type == v->type);
|
|
|
|
|
|
+ bool scalar = !ctx->device->coopmat2;
|
|
|
+
|
|
|
+ uint32_t gqa_ratio = 1;
|
|
|
+ uint32_t qk_ratio = neq2 / nek2;
|
|
|
+ uint32_t workgroups_x = (uint32_t)neq1;
|
|
|
+ uint32_t workgroups_y = (uint32_t)neq2;
|
|
|
+ uint32_t workgroups_z = (uint32_t)neq3;
|
|
|
+
|
|
|
+ // For scalar FA, we can use the "large" size to accommodate gqa.
|
|
|
+ // For coopmat FA, we always use the small size (which is still pretty large for gqa).
|
|
|
+ const uint32_t max_gqa = scalar ? scalar_flash_attention_num_large_rows : get_fa_num_small_rows(false);
|
|
|
+
|
|
|
+ if (N == 1 && qk_ratio > 1 && qk_ratio <= max_gqa &&
|
|
|
+ qk_ratio * nek2 == neq2 && nek2 == nev2 && neq3 == 1 && nek3 == 1 && nev3 == 1) {
|
|
|
+ // grouped query attention - make the N dimension equal to gqa_ratio, reduce
|
|
|
+ // workgroups proportionally in y dimension. The shader will detect gqa_ratio > 1
|
|
|
+ // and change addressing calculations to index Q's dimension 2.
|
|
|
+ gqa_ratio = qk_ratio;
|
|
|
+ N = gqa_ratio;
|
|
|
+ workgroups_y /= N;
|
|
|
+ }
|
|
|
+
|
|
|
vk_pipeline *pipelines;
|
|
|
// XXX TODO other backends may be changing accumulator precision to default to f32 soon
|
|
|
- bool f32acc = dst->op_params[3] == GGML_PREC_F32;
|
|
|
- bool small_rows = N <= flash_attention_num_small_rows;
|
|
|
- switch (D) {
|
|
|
- case 64: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D64[k->type][f32acc][small_rows][0]; break;
|
|
|
- case 80: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D80[k->type][f32acc][small_rows][0]; break;
|
|
|
- case 96: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D96[k->type][f32acc][small_rows][0]; break;
|
|
|
- case 112: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D112[k->type][f32acc][small_rows][0]; break;
|
|
|
- case 128: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D128[k->type][f32acc][small_rows][0]; break;
|
|
|
- case 256: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D256[k->type][f32acc][small_rows][0]; break;
|
|
|
- default:
|
|
|
- assert(!"unsupported D value");
|
|
|
- return;
|
|
|
+ bool f32acc = scalar || dst->op_params[3] == GGML_PREC_F32;
|
|
|
+ bool small_rows = N <= get_fa_num_small_rows(scalar);
|
|
|
+
|
|
|
+ if (scalar) {
|
|
|
+ switch (D) {
|
|
|
+ case 64: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D64[k->type][f32acc][small_rows][0]; break;
|
|
|
+ case 80: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D80[k->type][f32acc][small_rows][0]; break;
|
|
|
+ case 96: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D96[k->type][f32acc][small_rows][0]; break;
|
|
|
+ case 112: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D112[k->type][f32acc][small_rows][0]; break;
|
|
|
+ case 128: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D128[k->type][f32acc][small_rows][0]; break;
|
|
|
+ case 256: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D256[k->type][f32acc][small_rows][0]; break;
|
|
|
+ default:
|
|
|
+ GGML_ASSERT(!"unsupported D value");
|
|
|
+ return;
|
|
|
+ }
|
|
|
+ } else {
|
|
|
+ switch (D) {
|
|
|
+ case 64: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D64_cm2[k->type][f32acc][small_rows][0]; break;
|
|
|
+ case 80: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D80_cm2[k->type][f32acc][small_rows][0]; break;
|
|
|
+ case 96: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D96_cm2[k->type][f32acc][small_rows][0]; break;
|
|
|
+ case 112: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D112_cm2[k->type][f32acc][small_rows][0]; break;
|
|
|
+ case 128: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D128_cm2[k->type][f32acc][small_rows][0]; break;
|
|
|
+ case 256: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D256_cm2[k->type][f32acc][small_rows][0]; break;
|
|
|
+ default:
|
|
|
+ GGML_ASSERT(!"unsupported D value");
|
|
|
+ return;
|
|
|
+ }
|
|
|
}
|
|
|
assert(pipelines);
|
|
|
|
|
|
@@ -5740,27 +5806,14 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
|
|
|
vk_pipeline pipeline = pipelines[aligned];
|
|
|
assert(pipeline);
|
|
|
|
|
|
- uint32_t gqa_ratio = 1;
|
|
|
- uint32_t qk_ratio = neq2 / nek2;
|
|
|
- uint32_t workgroups_x = (uint32_t)neq1;
|
|
|
- uint32_t workgroups_y = (uint32_t)neq2;
|
|
|
- uint32_t workgroups_z = (uint32_t)neq3;
|
|
|
-
|
|
|
- if (N == 1 && qk_ratio > 1 && gqa_ratio <= flash_attention_num_small_rows &&
|
|
|
- qk_ratio * nek2 == neq2 && nek2 == nev2 && neq3 == 1 && nek3 == 1 && nev3 == 1) {
|
|
|
- // grouped query attention - make the N dimension equal to gqa_ratio, reduce
|
|
|
- // workgroups proportionally in y dimension. The shader will detect gqa_ratio > 1
|
|
|
- // and change addressing calculations to index Q's dimension 2.
|
|
|
- gqa_ratio = qk_ratio;
|
|
|
- N = gqa_ratio;
|
|
|
- workgroups_y /= N;
|
|
|
- }
|
|
|
-
|
|
|
uint32_t split_kv = KV;
|
|
|
uint32_t split_k = 1;
|
|
|
|
|
|
+ // Use a placeholder core count if one isn't available. split_k is a big help for perf.
|
|
|
+ const uint32_t shader_core_count = ctx->device->shader_core_count ? ctx->device->shader_core_count : 16;
|
|
|
+
|
|
|
// Try to use split_k when KV is large enough to be worth the overhead
|
|
|
- if (workgroups_x == 1 && ctx->device->shader_core_count > 0 && KV >= 512) {
|
|
|
+ if (workgroups_x == 1 && shader_core_count > 0 && KV >= 512) {
|
|
|
// Try to run two workgroups per SM.
|
|
|
split_k = ctx->device->shader_core_count * 2 / workgroups_y;
|
|
|
if (split_k > 1) {
|
|
|
@@ -9530,9 +9583,8 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
|
|
|
case GGML_OP_FLASH_ATTN_EXT:
|
|
|
{
|
|
|
ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context;
|
|
|
- if (!ggml_vk_get_device(ctx->device)->coopmat2) {
|
|
|
- return false;
|
|
|
- }
|
|
|
+ auto device = ggml_vk_get_device(ctx->device);
|
|
|
+ bool coopmat2 = device->coopmat2;
|
|
|
switch (op->src[0]->ne[0]) {
|
|
|
case 64:
|
|
|
case 80:
|
|
|
@@ -9540,7 +9592,6 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
|
|
|
case 112:
|
|
|
case 128:
|
|
|
case 256:
|
|
|
- case 575: // DeepSeek MLA
|
|
|
break;
|
|
|
default:
|
|
|
return false;
|
|
|
@@ -9566,10 +9617,12 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
|
|
|
switch (op->src[1]->type) {
|
|
|
case GGML_TYPE_F16:
|
|
|
case GGML_TYPE_Q4_0:
|
|
|
+ case GGML_TYPE_Q8_0:
|
|
|
+ // supported in scalar and coopmat2 paths
|
|
|
+ break;
|
|
|
case GGML_TYPE_Q4_1:
|
|
|
case GGML_TYPE_Q5_0:
|
|
|
case GGML_TYPE_Q5_1:
|
|
|
- case GGML_TYPE_Q8_0:
|
|
|
// K dequants currently disabled because D dimension is rounded up to 256 and runs inefficiently
|
|
|
//case GGML_TYPE_Q2_K:
|
|
|
//case GGML_TYPE_Q3_K:
|
|
|
@@ -9585,10 +9638,18 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
|
|
|
//case GGML_TYPE_IQ3_S:
|
|
|
//case GGML_TYPE_IQ4_XS:
|
|
|
case GGML_TYPE_IQ4_NL:
|
|
|
+ // currently supported only in coopmat2 path
|
|
|
+ if (!coopmat2) {
|
|
|
+ return false;
|
|
|
+ }
|
|
|
break;
|
|
|
default:
|
|
|
return false;
|
|
|
}
|
|
|
+ if (!coopmat2 && !device->subgroup_shuffle) {
|
|
|
+ // scalar FA uses subgroupShuffle
|
|
|
+ return false;
|
|
|
+ }
|
|
|
return true;
|
|
|
}
|
|
|
case GGML_OP_GET_ROWS:
|