|
@@ -6721,8 +6721,8 @@ static void ggml_compute_forward_flash_attn_ext_f16(
|
|
|
ggml_vec_dot_t const kq_vec_dot = ggml_get_type_traits_cpu(k->type)->vec_dot;
|
|
ggml_vec_dot_t const kq_vec_dot = ggml_get_type_traits_cpu(k->type)->vec_dot;
|
|
|
ggml_to_float_t const v_to_float = ggml_get_type_traits(v->type)->to_float;
|
|
ggml_to_float_t const v_to_float = ggml_get_type_traits(v->type)->to_float;
|
|
|
|
|
|
|
|
- GGML_ASSERT(q_to_vec_dot && "fattn: unsupported K-type");
|
|
|
|
|
- GGML_ASSERT(v_to_float && "fattn: unsupported V-type");
|
|
|
|
|
|
|
+ GGML_ASSERT(( q_to_vec_dot) && "fattn: unsupported K-type");
|
|
|
|
|
+ GGML_ASSERT((v->type == GGML_TYPE_F32 || v_to_float ) && "fattn: unsupported V-type");
|
|
|
|
|
|
|
|
// loop over n_batch and n_head
|
|
// loop over n_batch and n_head
|
|
|
for (int ir = ir0; ir < ir1; ++ir) {
|
|
for (int ir = ir0; ir < ir1; ++ir) {
|
|
@@ -6818,10 +6818,14 @@ static void ggml_compute_forward_flash_attn_ext_f16(
|
|
|
vs = expf(s - M);
|
|
vs = expf(s - M);
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
- v_to_float(v_data, V32, DV);
|
|
|
|
|
-
|
|
|
|
|
// V += v*expf(s - M)
|
|
// V += v*expf(s - M)
|
|
|
- ggml_vec_mad_f32(DV, VKQ32, V32, vs);
|
|
|
|
|
|
|
+ if (v_to_float) {
|
|
|
|
|
+ v_to_float(v_data, V32, DV);
|
|
|
|
|
+ ggml_vec_mad_f32(DV, VKQ32, V32, vs);
|
|
|
|
|
+ } else {
|
|
|
|
|
+ // V is F32
|
|
|
|
|
+ ggml_vec_mad_f32(DV, VKQ32, (const float *) v_data, vs);
|
|
|
|
|
+ }
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
S = S*ms + vs; // scale and increment sum with partial sum
|
|
S = S*ms + vs; // scale and increment sum with partial sum
|