llama : fix K-shift with quantized K and BLAS backend (#13113)

Diego Devesa, 8 months ago
Commit 295354ea68

2 changed files with 4 additions and 16 deletions:
  1. src/llama-context.cpp (+3 -14)
  2. src/llama-context.h (+1 -2)

+ 3 - 14
src/llama-context.cpp

@@ -469,8 +469,7 @@ ggml_tensor * llama_context::build_rope_shift(
         ggml_tensor * shift,
         ggml_tensor * factors,
               float   freq_base,
-              float   freq_scale,
-        ggml_backend_buffer * bbuf) const {
+              float   freq_scale) const {
     const auto & n_ctx_orig = cparams.n_ctx_orig_yarn;
 
     const auto & yarn_ext_factor  = cparams.yarn_ext_factor;
@@ -492,17 +491,7 @@ ggml_tensor * llama_context::build_rope_shift(
         // dequantize to f32 -> RoPE -> quantize back
         tmp = ggml_cast(ctx0, cur, GGML_TYPE_F32);
 
-        if (bbuf) {
-            for (const auto & backend : backends) {
-                // Figure out which backend KV cache belongs to
-                if (ggml_backend_supports_buft(backend.get(), ggml_backend_buffer_get_type(bbuf))) {
-                    ggml_backend_sched_set_tensor_backend(sched.get(), tmp, backend.get());
-                    break;
-                }
-            }
-        }
-
-        tmp = ggml_rope_ext_inplace(ctx0, tmp,
+        tmp = ggml_rope_ext(ctx0, tmp,
                 shift, factors, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
                 yarn_ext_factor, yarn_attn_factor, yarn_beta_fast, yarn_beta_slow);
 
@@ -582,7 +571,7 @@ llm_graph_result_ptr llama_context::build_kv_self_shift(
                 ggml_row_size(kv_self->k_l[il]->type, n_embd_k_gqa),
                 0);
 
-        ggml_tensor * cur = build_rope_shift(ctx0, k, inp->k_shift, rope_factors, freq_base_l, freq_scale_l, kv_self->k_l[il]->buffer);
+        ggml_tensor * cur = build_rope_shift(ctx0, k, inp->k_shift, rope_factors, freq_base_l, freq_scale_l);
 
         ggml_build_forward_expand(gf, cur);
     }
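
The net effect of these two hunks, written out flat: the sketch below paraphrases the quantized-K branch of build_rope_shift after the patch. It is assembled only from the lines visible above; the final ggml_cpy that quantizes the shifted values back into the K cache view is an assumption inferred from the "dequantize to f32 -> RoPE -> quantize back" comment and is not part of this diff.

    // Sketch of the quantized-K branch after this change. ctx0, cur, shift,
    // factors and the RoPE/YaRN parameters come from the surrounding function;
    // the trailing ggml_cpy is assumed, not shown in the hunks above.

    // dequantize the quantized K view to f32 so RoPE can operate on it
    ggml_tensor * tmp = ggml_cast(ctx0, cur, GGML_TYPE_F32);

    // apply the rotary shift out of place: ggml_rope_ext replaces the
    // _inplace variant, and the manual ggml_backend_sched_set_tensor_backend
    // pinning is gone, leaving tensor placement to the scheduler
    tmp = ggml_rope_ext(ctx0, tmp,
            shift, factors, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
            yarn_ext_factor, yarn_attn_factor, yarn_beta_fast, yarn_beta_slow);

    // quantize the result back into the original K cache view (assumed step)
    cur = ggml_cpy(ctx0, tmp, cur);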

+ 1 - 2
src/llama-context.h

@@ -170,8 +170,7 @@ private:
         ggml_tensor * shift,
         ggml_tensor * factors,
               float   freq_base,
-              float   freq_scale,
-        ggml_backend_buffer * bbuf) const;
+              float   freq_scale) const;
 
     llm_graph_result_ptr build_kv_self_shift(
             ggml_context * ctx0,