Bläddra i källkod

vulkan: Support FA with K/V in F32 (#16543)

Jeff Bolz 3 månader sedan
förälder
incheckning
4258e0cfe7

+ 14 - 2
ggml/src/ggml-vulkan/ggml-vulkan.cpp

@@ -2649,11 +2649,13 @@ static void ggml_vk_load_shaders(vk_device& device) {
             } \
             } \
         }
         }
 
 
+    CREATE_FA(GGML_TYPE_F32, f32, FA_SCALAR, )
     CREATE_FA(GGML_TYPE_F16, f16, FA_SCALAR, )
     CREATE_FA(GGML_TYPE_F16, f16, FA_SCALAR, )
     CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_SCALAR, )
     CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_SCALAR, )
     CREATE_FA(GGML_TYPE_Q8_0, q8_0, FA_SCALAR, )
     CREATE_FA(GGML_TYPE_Q8_0, q8_0, FA_SCALAR, )
 #if defined(VK_KHR_cooperative_matrix) && defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT)
 #if defined(VK_KHR_cooperative_matrix) && defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT)
     if (device->coopmat1_fa_support) {
     if (device->coopmat1_fa_support) {
+        CREATE_FA(GGML_TYPE_F32, f32, FA_COOPMAT1, _cm1)
         CREATE_FA(GGML_TYPE_F16, f16, FA_COOPMAT1, _cm1)
         CREATE_FA(GGML_TYPE_F16, f16, FA_COOPMAT1, _cm1)
         CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_COOPMAT1, _cm1)
         CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_COOPMAT1, _cm1)
         CREATE_FA(GGML_TYPE_Q8_0, q8_0, FA_COOPMAT1, _cm1)
         CREATE_FA(GGML_TYPE_Q8_0, q8_0, FA_COOPMAT1, _cm1)
@@ -2661,6 +2663,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
 #endif
 #endif
 #if defined(VK_NV_cooperative_matrix2) && defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
 #if defined(VK_NV_cooperative_matrix2) && defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
     if (device->coopmat2) {
     if (device->coopmat2) {
+        CREATE_FA(GGML_TYPE_F32, f32, FA_COOPMAT2, _cm2)
         CREATE_FA(GGML_TYPE_F16, f16, FA_COOPMAT2, _cm2)
         CREATE_FA(GGML_TYPE_F16, f16, FA_COOPMAT2, _cm2)
         CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_COOPMAT2, _cm2)
         CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_COOPMAT2, _cm2)
         CREATE_FA(GGML_TYPE_Q4_1, q4_1, FA_COOPMAT2, _cm2)
         CREATE_FA(GGML_TYPE_Q4_1, q4_1, FA_COOPMAT2, _cm2)
@@ -7457,8 +7460,16 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
     }
     }
 
 
     const uint32_t q_stride = (uint32_t)(nbq1 / ggml_type_size(q->type));
     const uint32_t q_stride = (uint32_t)(nbq1 / ggml_type_size(q->type));
-    const uint32_t k_stride = (uint32_t)(nbk1 / ggml_type_size(k->type));
-    const uint32_t v_stride = (uint32_t)(nbv1 / ggml_type_size(v->type));
+    uint32_t k_stride = (uint32_t)(nbk1 / ggml_type_size(k->type));
+    uint32_t v_stride = (uint32_t)(nbv1 / ggml_type_size(v->type));
+
+    // For F32, the shader treats it as a block of size 4 (for vec4 loads)
+    if (k->type == GGML_TYPE_F32) {
+        k_stride /= 4;
+    }
+    if (v->type == GGML_TYPE_F32) {
+        v_stride /= 4;
+    }
 
 
     uint32_t alignment = fa_align(path, HSK, HSV, k->type, small_rows);
     uint32_t alignment = fa_align(path, HSK, HSV, k->type, small_rows);
     bool aligned = (KV % alignment) == 0 &&
     bool aligned = (KV % alignment) == 0 &&
@@ -12660,6 +12671,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
                 }
                 }
                 switch (op->src[1]->type) {
                 switch (op->src[1]->type) {
                 case GGML_TYPE_F16:
                 case GGML_TYPE_F16:
+                case GGML_TYPE_F32:
                 case GGML_TYPE_Q4_0:
                 case GGML_TYPE_Q4_0:
                 case GGML_TYPE_Q8_0:
                 case GGML_TYPE_Q8_0:
                     // supported in scalar and coopmat2 paths
                     // supported in scalar and coopmat2 paths

+ 14 - 0
ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.glsl

@@ -1,6 +1,18 @@
 
 
 #include "types.glsl"
 #include "types.glsl"
 
 
+layout(buffer_reference, std430, buffer_reference_align = 16) buffer decodeBufF32 {
+   vec4 block;
+};
+
+float16_t dequantFuncF32(const in decodeBufF32 bl, const in uint blockCoords[2], const in uint coordInBlock[2])
+{
+    const vec4 v = bl.block;
+    const uint idx = coordInBlock[1];
+    const f16vec4 vf16 = f16vec4(v);
+    return vf16[idx];
+}
+
 layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufQ4_0 {
 layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufQ4_0 {
    block_q4_0_packed16 block;
    block_q4_0_packed16 block;
 };
 };
@@ -717,4 +729,6 @@ float16_t dequantFuncMXFP4(const in decodeBufMXFP4 bl, const in uint blockCoords
 #define dequantFuncA dequantFuncIQ4_NL
 #define dequantFuncA dequantFuncIQ4_NL
 #elif defined(DATA_A_MXFP4)
 #elif defined(DATA_A_MXFP4)
 #define dequantFuncA dequantFuncMXFP4
 #define dequantFuncA dequantFuncMXFP4
+#elif defined(DATA_A_F32)
+#define dequantFuncA dequantFuncF32
 #endif
 #endif

+ 19 - 1
ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl

@@ -64,13 +64,31 @@ layout (binding = 4) readonly buffer S {float data_s[];};
 
 
 layout (binding = 5) writeonly buffer O {D_TYPE data_o[];};
 layout (binding = 5) writeonly buffer O {D_TYPE data_o[];};
 
 
-#if defined(A_TYPE_PACKED16)
 #define BINDING_IDX_K 0
 #define BINDING_IDX_K 0
 #define BINDING_IDX_V 1
 #define BINDING_IDX_V 1
+#if defined(DATA_A_F32)
+layout (binding = 1) readonly buffer K_PACKED {vec4 k_data_packed[];} k_packed;
+layout (binding = 2) readonly buffer V_PACKED {vec4 v_data_packed[];} v_packed;
+#elif defined(A_TYPE_PACKED16)
 layout (binding = 1) readonly buffer K_PACKED16 {A_TYPE_PACKED16 k_data_packed16[];} k_packed;
 layout (binding = 1) readonly buffer K_PACKED16 {A_TYPE_PACKED16 k_data_packed16[];} k_packed;
 layout (binding = 2) readonly buffer V_PACKED16 {A_TYPE_PACKED16 v_data_packed16[];} v_packed;
 layout (binding = 2) readonly buffer V_PACKED16 {A_TYPE_PACKED16 v_data_packed16[];} v_packed;
 #endif
 #endif
 
 
+#if defined(DATA_A_F32)
+#undef BLOCK_SIZE
+#define BLOCK_SIZE 4
+#define BLOCK_BYTE_SIZE 16
+
+vec4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) {
+    // iqs is currently always zero in the flash attention shaders
+    if (binding_idx == BINDING_IDX_K) {
+        return k_packed.k_data_packed[a_offset + ib];
+    } else {
+        return v_packed.v_data_packed[a_offset + ib];
+    }
+}
+#endif
+
 #if defined(DATA_A_Q4_0)
 #if defined(DATA_A_Q4_0)
 #define BLOCK_BYTE_SIZE 18
 #define BLOCK_BYTE_SIZE 18
 
 

+ 2 - 5
ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp

@@ -611,9 +611,6 @@ void process_shaders() {
         }
         }
 
 
         for (const auto& tname : type_names) {
         for (const auto& tname : type_names) {
-            if (tname == "f32") {
-                continue;
-            }
             if (tname == "bf16") continue;
             if (tname == "bf16") continue;
 
 
 #if defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
 #if defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
@@ -630,7 +627,7 @@ void process_shaders() {
             if (tname == "f16") {
             if (tname == "f16") {
                 string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm1.comp",
                 string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm1.comp",
                     merge_maps(fa_base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"COOPMAT", "1"}}), true, true, false, f16acc);
                     merge_maps(fa_base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"COOPMAT", "1"}}), true, true, false, f16acc);
-            } else if (tname == "q4_0" || tname == "q8_0") {
+            } else if (tname == "q4_0" || tname == "q8_0" || tname == "f32") {
                 std::string data_a_key = "DATA_A_" + to_uppercase(tname);
                 std::string data_a_key = "DATA_A_" + to_uppercase(tname);
                 string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm1.comp",
                 string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm1.comp",
                     merge_maps(fa_base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname)}, {"COOPMAT", "1"}}), true, true, false, f16acc);
                     merge_maps(fa_base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname)}, {"COOPMAT", "1"}}), true, true, false, f16acc);
@@ -639,7 +636,7 @@ void process_shaders() {
             if (tname == "f16") {
             if (tname == "f16") {
                 string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn.comp",
                 string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn.comp",
                     merge_maps(fa_base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}}), true, false, false, f16acc);
                     merge_maps(fa_base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}}), true, false, false, f16acc);
-            } else if (tname == "q4_0" || tname == "q8_0") {
+            } else if (tname == "q4_0" || tname == "q8_0" || tname == "f32") {
                 std::string data_a_key = "DATA_A_" + to_uppercase(tname);
                 std::string data_a_key = "DATA_A_" + to_uppercase(tname);
                 string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn.comp",
                 string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn.comp",
                     merge_maps(fa_base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname) }}), true, false, false, f16acc);
                     merge_maps(fa_base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname) }}), true, false, false, f16acc);