Browse Source

ggml : add asserts for type conversion in fattn kernels (#9971)

ggml-ci
Georgi Gerganov 1 year ago
parent
commit
f594bc80ba
3 changed files with 8 additions and 4 deletions
  1. 2 2
      common/common.cpp
  2. 5 1
      ggml/src/ggml.c
  3. 1 1
      src/llama.cpp

+ 2 - 2
common/common.cpp

@@ -1035,7 +1035,7 @@ static ggml_type kv_cache_type_from_str(const std::string & s) {
         return GGML_TYPE_Q5_1;
     }
 
-    throw std::runtime_error("Invalid cache type: " + s);
+    throw std::runtime_error("Unsupported cache type: " + s);
 }
 
 struct llama_context_params common_context_params_to_llama(const common_params & params) {
@@ -1047,7 +1047,7 @@ struct llama_context_params common_context_params_to_llama(const common_params &
     cparams.n_ubatch          = params.n_ubatch;
     cparams.n_threads         = params.cpuparams.n_threads;
     cparams.n_threads_batch   = params.cpuparams_batch.n_threads == -1 ?
-                                    params.cpuparams.n_threads : params.cpuparams_batch.n_threads;
+                                params.cpuparams.n_threads : params.cpuparams_batch.n_threads;
     cparams.logits_all        = params.logits_all;
     cparams.embeddings        = params.embedding;
     cparams.rope_scaling_type = params.rope_scaling_type;

+ 5 - 1
ggml/src/ggml.c

@@ -324,8 +324,9 @@ struct ggml_logger_state {
 static struct ggml_logger_state g_logger_state = {ggml_log_callback_default, NULL};
 
 static void ggml_log_internal_v(enum ggml_log_level level, const char * format, va_list args) {
-    if (format == NULL)
+    if (format == NULL) {
         return;
+    }
     va_list args_copy;
     va_copy(args_copy, args);
     char buffer[128];
@@ -15723,6 +15724,9 @@ static void ggml_compute_forward_flash_attn_ext_f16(
     ggml_vec_dot_t    const kq_vec_dot     = type_traits[k->type].vec_dot;
     ggml_to_float_t   const v_to_float     = type_traits[v->type].to_float;
 
+    GGML_ASSERT(q_to_vec_dot && "fattn: unsupported K-type");
+    GGML_ASSERT(v_to_float   && "fattn: unsupported V-type");
+
     // loop over n_batch and n_head
     for (int ir = ir0; ir < ir1; ++ir) {
         // q indices

+ 1 - 1
src/llama.cpp

@@ -19243,7 +19243,7 @@ struct llama_context * llama_new_context_with_model(
         params.flash_attn = false;
     }
 
-    if (params.type_v != GGML_TYPE_F16 && !params.flash_attn) {
+    if (ggml_is_quantized(params.type_v) && !params.flash_attn) {
         LLAMA_LOG_ERROR("%s: V cache quantization requires flash_attn\n", __func__);
         return nullptr;
     }