|
@@ -1735,7 +1735,14 @@ static FaHeadSizes fa_get_head_sizes(uint32_t hsk, uint32_t hsv) {
|
|
|
// number of rows/cols for flash attention shader
|
|
// number of rows/cols for flash attention shader
|
|
|
static constexpr uint32_t flash_attention_num_small_rows = 32;
|
|
static constexpr uint32_t flash_attention_num_small_rows = 32;
|
|
|
static constexpr uint32_t scalar_flash_attention_num_small_rows = 1;
|
|
static constexpr uint32_t scalar_flash_attention_num_small_rows = 1;
|
|
|
-static constexpr uint32_t scalar_flash_attention_num_large_rows = 8;
|
|
|
|
|
|
|
+
|
|
|
|
|
+static uint32_t get_fa_scalar_num_large_rows(uint32_t hsv) { // Row count for the scalar flash attention "large rows" case, selected from hsv; takes over from the removed constant scalar_flash_attention_num_large_rows (8).
|
|
|
|
|
+ if (hsv >= 512) { // NOTE(review): presumably hsv is the V head size and the smaller count limits per-workgroup resource use - confirm
|
|
|
|
|
+ return 2; // hsv of 512 or more: 2 rows
|
|
|
|
|
+ } else {
|
|
|
|
|
+ return 8; // every smaller hsv keeps the previous fixed value of 8
|
|
|
|
|
+ }
|
|
|
|
|
+}
|
|
|
|
|
|
|
|
// The FA coopmat1 shader assumes 16x16x16 matrix multiply support.
|
|
// The FA coopmat1 shader assumes 16x16x16 matrix multiply support.
|
|
|
// 128 threads split into four subgroups, each subgroup does 1/4
|
|
// 128 threads split into four subgroups, each subgroup does 1/4
|
|
@@ -1760,7 +1767,7 @@ static std::array<uint32_t, 2> fa_rows_cols(FaCodePath path, uint32_t hsk, uint3
|
|
|
if (small_rows) {
|
|
if (small_rows) {
|
|
|
return {scalar_flash_attention_num_small_rows, 64};
|
|
return {scalar_flash_attention_num_small_rows, 64};
|
|
|
} else {
|
|
} else {
|
|
|
- return {scalar_flash_attention_num_large_rows, 32};
|
|
|
|
|
|
|
+ return {get_fa_scalar_num_large_rows(hsv), 32};
|
|
|
}
|
|
}
|
|
|
}
|
|
}
|
|
|
|
|
|
|
@@ -1779,7 +1786,11 @@ static std::array<uint32_t, 2> fa_rows_cols(FaCodePath path, uint32_t hsk, uint3
|
|
|
|
|
|
|
|
// small cols to reduce register count
|
|
// small cols to reduce register count
|
|
|
if (ggml_is_quantized(type) || hsk >= 256) {
|
|
if (ggml_is_quantized(type) || hsk >= 256) {
|
|
|
- return {64, 32};
|
|
|
|
|
|
|
+ if (hsk >= 512) {
|
|
|
|
|
+ return {32, 32};
|
|
|
|
|
+ } else {
|
|
|
|
|
+ return {64, 32};
|
|
|
|
|
+ }
|
|
|
}
|
|
}
|
|
|
return {64, 64};
|
|
return {64, 64};
|
|
|
}
|
|
}
|
|
@@ -1821,7 +1832,7 @@ static bool ggml_vk_matmul_shmem_support(const vk_device& device, const std::vec
|
|
|
const uint32_t warps = warptile[0] / warptile[10];
|
|
const uint32_t warps = warptile[0] / warptile[10];
|
|
|
|
|
|
|
|
const uint32_t load_bufs = (warptile[1] + warptile[2]) * (warptile[3] + bank_conflict_offset) * type_size;
|
|
const uint32_t load_bufs = (warptile[1] + warptile[2]) * (warptile[3] + bank_conflict_offset) * type_size;
|
|
|
- const uint32_t mmid_row_ids = mul_mat_id ? 4096 * sizeof(uint32_t) : 0;
|
|
|
|
|
|
|
+ const uint32_t mmid_row_ids = mul_mat_id ? (4096 * sizeof(uint32_t) + 4/*_ne1*/) : 0;
|
|
|
const uint32_t coopmat_stage = device->coopmat_support ? warptile[7] * warptile[8] / warps * sizeof(float) : 0;
|
|
const uint32_t coopmat_stage = device->coopmat_support ? warptile[7] * warptile[8] / warps * sizeof(float) : 0;
|
|
|
|
|
|
|
|
const uint32_t total_size = load_bufs + mmid_row_ids + coopmat_stage + lut_size;
|
|
const uint32_t total_size = load_bufs + mmid_row_ids + coopmat_stage + lut_size;
|
|
@@ -1946,10 +1957,10 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
|
|
s_mmq_wg_denoms_k = { 32, 32, 1 };
|
|
s_mmq_wg_denoms_k = { 32, 32, 1 };
|
|
|
|
|
|
|
|
// spec constants and tile sizes for quant matmul_id
|
|
// spec constants and tile sizes for quant matmul_id
|
|
|
- l_warptile_mmqid = { 256, 128, 64, 16, 0 };
|
|
|
|
|
|
|
+ l_warptile_mmqid = { 256, 128, 128, 16, 0 };
|
|
|
m_warptile_mmqid = { 256, 128, 64, 16, 0 };
|
|
m_warptile_mmqid = { 256, 128, 64, 16, 0 };
|
|
|
s_warptile_mmqid = { 256, 128, 64, 16, 0 };
|
|
s_warptile_mmqid = { 256, 128, 64, 16, 0 };
|
|
|
- l_mmqid_wg_denoms = { 128, 64, 1 };
|
|
|
|
|
|
|
+ l_mmqid_wg_denoms = { 128, 128, 1 };
|
|
|
m_mmqid_wg_denoms = { 128, 64, 1 };
|
|
m_mmqid_wg_denoms = { 128, 64, 1 };
|
|
|
s_mmqid_wg_denoms = { 128, 64, 1 };
|
|
s_mmqid_wg_denoms = { 128, 64, 1 };
|
|
|
|
|
|
|
@@ -6048,7 +6059,7 @@ static bool ggml_vk_flash_attn_scalar_shmem_support(const vk_device& device, con
|
|
|
// Needs to be kept up to date on shader changes
|
|
// Needs to be kept up to date on shader changes
|
|
|
GGML_UNUSED(hsv);
|
|
GGML_UNUSED(hsv);
|
|
|
const uint32_t wg_size = scalar_flash_attention_workgroup_size;
|
|
const uint32_t wg_size = scalar_flash_attention_workgroup_size;
|
|
|
- const uint32_t Br = scalar_flash_attention_num_large_rows;
|
|
|
|
|
|
|
+ const uint32_t Br = get_fa_scalar_num_large_rows(hsv);
|
|
|
const uint32_t Bc = scalar_flash_attention_Bc;
|
|
const uint32_t Bc = scalar_flash_attention_Bc;
|
|
|
|
|
|
|
|
const uint32_t tmpsh = wg_size * sizeof(float);
|
|
const uint32_t tmpsh = wg_size * sizeof(float);
|
|
@@ -6173,7 +6184,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
|
|
|
case FA_SCALAR:
|
|
case FA_SCALAR:
|
|
|
case FA_COOPMAT1:
|
|
case FA_COOPMAT1:
|
|
|
// We may switch from coopmat1 to scalar, so use the scalar limit for both
|
|
// We may switch from coopmat1 to scalar, so use the scalar limit for both
|
|
|
- max_gqa = scalar_flash_attention_num_large_rows;
|
|
|
|
|
|
|
+ max_gqa = get_fa_scalar_num_large_rows(HSV);
|
|
|
break;
|
|
break;
|
|
|
case FA_COOPMAT2:
|
|
case FA_COOPMAT2:
|
|
|
max_gqa = get_fa_num_small_rows(FA_COOPMAT2);
|
|
max_gqa = get_fa_num_small_rows(FA_COOPMAT2);
|