|
|
@@ -2649,11 +2649,13 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
|
|
} \
|
|
|
}
|
|
|
|
|
|
+ CREATE_FA(GGML_TYPE_F32, f32, FA_SCALAR, )
|
|
|
CREATE_FA(GGML_TYPE_F16, f16, FA_SCALAR, )
|
|
|
CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_SCALAR, )
|
|
|
CREATE_FA(GGML_TYPE_Q8_0, q8_0, FA_SCALAR, )
|
|
|
#if defined(VK_KHR_cooperative_matrix) && defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT)
|
|
|
if (device->coopmat1_fa_support) {
|
|
|
+ CREATE_FA(GGML_TYPE_F32, f32, FA_COOPMAT1, _cm1)
|
|
|
CREATE_FA(GGML_TYPE_F16, f16, FA_COOPMAT1, _cm1)
|
|
|
CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_COOPMAT1, _cm1)
|
|
|
CREATE_FA(GGML_TYPE_Q8_0, q8_0, FA_COOPMAT1, _cm1)
|
|
|
@@ -2661,6 +2663,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
|
|
#endif
|
|
|
#if defined(VK_NV_cooperative_matrix2) && defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
|
|
|
if (device->coopmat2) {
|
|
|
+ CREATE_FA(GGML_TYPE_F32, f32, FA_COOPMAT2, _cm2)
|
|
|
CREATE_FA(GGML_TYPE_F16, f16, FA_COOPMAT2, _cm2)
|
|
|
CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_COOPMAT2, _cm2)
|
|
|
CREATE_FA(GGML_TYPE_Q4_1, q4_1, FA_COOPMAT2, _cm2)
|
|
|
@@ -7457,8 +7460,16 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
|
|
|
}
|
|
|
|
|
|
const uint32_t q_stride = (uint32_t)(nbq1 / ggml_type_size(q->type));
|
|
|
- const uint32_t k_stride = (uint32_t)(nbk1 / ggml_type_size(k->type));
|
|
|
- const uint32_t v_stride = (uint32_t)(nbv1 / ggml_type_size(v->type));
|
|
|
+ uint32_t k_stride = (uint32_t)(nbk1 / ggml_type_size(k->type));
|
|
|
+ uint32_t v_stride = (uint32_t)(nbv1 / ggml_type_size(v->type));
|
|
|
+
|
|
|
+ // For F32, the shader reads K/V as blocks of 4 floats (vec4 loads), so convert the strides from elements to 4-element blocks
|
|
|
+ if (k->type == GGML_TYPE_F32) {
|
|
|
+ k_stride /= 4;
|
|
|
+ }
|
|
|
+ if (v->type == GGML_TYPE_F32) {
|
|
|
+ v_stride /= 4;
|
|
|
+ }
|
|
|
|
|
|
uint32_t alignment = fa_align(path, HSK, HSV, k->type, small_rows);
|
|
|
bool aligned = (KV % alignment) == 0 &&
|
|
|
@@ -12660,6 +12671,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
|
|
|
}
|
|
|
switch (op->src[1]->type) {
|
|
|
case GGML_TYPE_F16:
|
|
|
+ case GGML_TYPE_F32:
|
|
|
case GGML_TYPE_Q4_0:
|
|
|
case GGML_TYPE_Q8_0:
|
|
|
// supported in scalar and coopmat2 paths
|