Browse Source

graph : fix graph reuse logic when `n_pos_per_embd > 1` (#18566)

Georgi Gerganov 3 weeks ago
parent
commit
c69c7ebc90
1 changed files with 2 additions and 2 deletions
  1. 2 2
      src/llama-graph.cpp

+ 2 - 2
src/llama-graph.cpp

@@ -32,7 +32,7 @@ bool llm_graph_input_embd::can_reuse(const llm_graph_params & params) {
     bool res = true;
 
     res &= (!tokens && !params.ubatch.token) || (tokens && tokens->ne[0] == params.ubatch.n_tokens);
-    res &= (!embd   && !params.ubatch.embd)  || (embd   &&   embd->ne[0] == params.ubatch.n_tokens);
+    res &= (!embd   && !params.ubatch.embd)  || (embd   &&   embd->ne[1] == params.ubatch.n_tokens);
 
     return res;
 }
@@ -62,7 +62,7 @@ void llm_graph_input_pos::set_input(const llama_ubatch * ubatch) {
 bool llm_graph_input_pos::can_reuse(const llm_graph_params & params) {
     bool res = true;
 
-    res &= pos->ne[0] == params.ubatch.n_tokens;
+    res &= pos->ne[0] == params.ubatch.n_tokens*n_pos_per_embd;
 
     return res;
 }