|
|
@@ -32,7 +32,7 @@ bool llm_graph_input_embd::can_reuse(const llm_graph_params & params) {
|
|
|
bool res = true;
|
|
|
|
|
|
res &= (!tokens && !params.ubatch.token) || (tokens && tokens->ne[0] == params.ubatch.n_tokens);
|
|
|
- res &= (!embd && !params.ubatch.embd) || (embd && embd->ne[0] == params.ubatch.n_tokens);
|
|
|
+ res &= (!embd && !params.ubatch.embd) || (embd && embd->ne[1] == params.ubatch.n_tokens);
|
|
|
|
|
|
return res;
|
|
|
}
|
|
|
@@ -62,7 +62,7 @@ void llm_graph_input_pos::set_input(const llama_ubatch * ubatch) {
|
|
|
bool llm_graph_input_pos::can_reuse(const llm_graph_params & params) {
|
|
|
bool res = true;
|
|
|
|
|
|
- res &= pos->ne[0] == params.ubatch.n_tokens;
|
|
|
+ res &= pos->ne[0] == params.ubatch.n_tokens*n_pos_per_embd;
|
|
|
|
|
|
return res;
|
|
|
}
|