Просмотр исходного кода

Fix FlashAttention debug test, FP32 assert (#7684)

Johannes Gäßler 1 год назад
Родитель
Сommit
e141ce624a
2 измененных файлов с 5 добавлено и 7 удалено
  1. 0 4
      ggml-cuda/fattn-vec-f32.cuh
  2. 5 3
      tests/test-backend-ops.cpp

+ 0 - 4
ggml-cuda/fattn-vec-f32.cuh

@@ -278,14 +278,10 @@ void ggml_cuda_flash_attn_ext_vec_f32_case_impl(ggml_backend_cuda_context & ctx,
 
 template <int D, ggml_type type_K, ggml_type type_V>
 void ggml_cuda_flash_attn_ext_vec_f32_case(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
-    ggml_tensor * KQV = dst;
     ggml_tensor * Q   = dst->src[0];
     ggml_tensor * K   = dst->src[1];
     ggml_tensor * V   = dst->src[2];
 
-    const int32_t precision = KQV->op_params[2];
-    GGML_ASSERT(precision == GGML_PREC_DEFAULT);
-
     GGML_ASSERT(K->type == type_K);
     GGML_ASSERT(V->type == type_V);
 

+ 5 - 3
tests/test-backend-ops.cpp

@@ -1584,9 +1584,11 @@ struct test_flash_attn_ext : public test_case {
         : hs(hs), nh(nh), kv(kv), nb(nb), mask(mask), max_bias(max_bias), type_KV(type_KV) {}
 
     ggml_tensor * build_graph(ggml_context * ctx) override {
-        ggml_tensor * q = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, hs, nb, nh, 1);
-        ggml_tensor * k = ggml_new_tensor_4d(ctx, type_KV,       hs, kv, nh, 1);
-        ggml_tensor * v = ggml_new_tensor_4d(ctx, type_KV,       hs, kv, nh, 1);
+        const int64_t hs_padded = GGML_PAD(hs, ggml_blck_size(type_KV));
+
+        ggml_tensor * q = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, hs_padded, nb, nh, 1);
+        ggml_tensor * k = ggml_new_tensor_4d(ctx, type_KV,       hs_padded, kv, nh, 1);
+        ggml_tensor * v = ggml_new_tensor_4d(ctx, type_KV,       hs_padded, kv, nh, 1);
         ggml_tensor * m = mask ? ggml_new_tensor_4d(ctx, GGML_TYPE_F16, kv, GGML_PAD(nb, GGML_KQ_MASK_PAD), 1, 1) : nullptr;
         ggml_tensor * out = ggml_flash_attn_ext(ctx, q, k, v, m, 1.0f/sqrtf(hs), max_bias);
         return out;