@@ -323,6 +323,9 @@ void matmul_shaders(bool fp16, MatMulIdType matmul_id_type, bool coopmat, bool c
     }
 
     base_dict["ACC_TYPE"] = f16acc ? "float16_t" : "float";
+    if (f16acc) {
+        base_dict["ACC_TYPE_MAX"] = "\"float16_t(65504.0)\"";
+    }
 
     if (coopmat) {
         base_dict["COOPMAT"] = "1";
@@ -437,8 +440,12 @@ void process_shaders() {
 
     // flash attention
     for (const auto& f16acc : {false, true}) {
-        std::string acctype = f16acc ? "float16_t" : "float";
-        std::string acctypev4 = f16acc ? "f16vec4" : "vec4";
+        std::map<std::string, std::string> fa_base_dict = base_dict;
+        fa_base_dict["ACC_TYPE"] = f16acc ? "float16_t" : "float";
+        fa_base_dict["ACC_TYPEV4"] = f16acc ? "f16vec4" : "vec4";
+        if (f16acc) {
+            fa_base_dict["ACC_TYPE_MAX"] = "\"float16_t(65504.0)\"";
+        }
 
         for (const auto& tname : type_names) {
             if (tname == "f32") {
@@ -449,30 +456,30 @@ void process_shaders() {
 #if defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
             if (tname == "f16") {
                 string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm2.comp",
-                    merge_maps(base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"ACC_TYPE", acctype}}), true, false, true, f16acc);
+                    merge_maps(fa_base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}}), true, false, true, f16acc);
             } else {
                 std::string data_a_key = "DATA_A_" + to_uppercase(tname);
                 string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm2.comp",
-                    merge_maps(base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"ACC_TYPE", acctype}, {"DEQUANTFUNC", "dequantFunc"+to_uppercase(tname) }, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname) }}), true, false, true, f16acc);
+                    merge_maps(fa_base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"DEQUANTFUNC", "dequantFunc"+to_uppercase(tname) }, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname) }}), true, false, true, f16acc);
             }
 #endif
 #if defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT)
             if (tname == "f16") {
                 string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm1.comp",
-                    merge_maps(base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"ACC_TYPE", acctype}, {"ACC_TYPEV4", acctypev4}, {"COOPMAT", "1"}}), true, true, false, f16acc);
+                    merge_maps(fa_base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"COOPMAT", "1"}}), true, true, false, f16acc);
             } else if (tname == "q4_0" || tname == "q8_0") {
                 std::string data_a_key = "DATA_A_" + to_uppercase(tname);
                 string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm1.comp",
-                    merge_maps(base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"ACC_TYPE", acctype}, {"ACC_TYPEV4", acctypev4}, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname)}, {"COOPMAT", "1"}}), true, true, false, f16acc);
+                    merge_maps(fa_base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname)}, {"COOPMAT", "1"}}), true, true, false, f16acc);
             }
 #endif
             if (tname == "f16") {
                 string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn.comp",
-                    merge_maps(base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"ACC_TYPE", acctype}}), true, false, false, f16acc);
+                    merge_maps(fa_base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}}), true, false, false, f16acc);
             } else if (tname == "q4_0" || tname == "q8_0") {
                 std::string data_a_key = "DATA_A_" + to_uppercase(tname);
                 string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn.comp",
-                    merge_maps(base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"ACC_TYPE", acctype}, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname) }}), true, false, false, f16acc);
+                    merge_maps(fa_base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname) }}), true, false, false, f16acc);
             }
         }
     }
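
For context on the new ACC_TYPE_MAX entry: 65504 is the largest finite value representable in IEEE 754 half precision, so when the accumulator type is float16_t this bound presumably lets the flash-attention shaders saturate instead of overflowing to infinity. The escaped quotes inside the C++ string literal keep float16_t(65504.0) together as a single define value when it reaches the shader compiler command line. The sketch below only illustrates that map-entry-to-define translation under those assumptions; defines_from_dict is a hypothetical helper, not the actual string_to_spv implementation.

// Hypothetical sketch, not the real string_to_spv: shows how a shader macro map
// such as fa_base_dict could be expanded into glslc-style -D arguments,
// assuming one define per map entry.
#include <map>
#include <string>
#include <vector>

static std::vector<std::string> defines_from_dict(const std::map<std::string, std::string>& dict) {
    std::vector<std::string> args;
    for (const auto& [key, value] : dict) {
        // e.g. {"ACC_TYPE_MAX", "\"float16_t(65504.0)\""} becomes
        //      -DACC_TYPE_MAX="float16_t(65504.0)"
        args.push_back("-D" + key + "=" + value);
    }
    return args;
}

With fa_base_dict carrying ACC_TYPE, ACC_TYPEV4 and, for f16 accumulation, ACC_TYPE_MAX, the per-variant merge_maps calls above no longer need to pass the accumulator types explicitly, which is what the diff removes.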