|
|
@@ -254,6 +254,24 @@ void llm_graph_input_rs::set_input(const llama_ubatch * ubatch) {
|
|
|
}
|
|
|
}
|
|
|
|
|
|
+bool llm_graph_input_rs::can_reuse(const llm_graph_params & params) {
|
|
|
+ const auto * mctx = static_cast<const llama_memory_recurrent_context *>(params.mctx);
|
|
|
+
|
|
|
+ this->mctx = mctx;
|
|
|
+
|
|
|
+ bool res = true;
|
|
|
+
|
|
|
+ res &= s_copy->ne[0] == mctx->get_n_rs();
|
|
|
+
|
|
|
+ res &= s_copy_main->ne[0] == params.ubatch.n_seqs;
|
|
|
+ res &= s_copy_extra->ne[0] == mctx->get_n_rs() - params.ubatch.n_seqs;
|
|
|
+
|
|
|
+ res &= head == mctx->get_head();
|
|
|
+ res &= rs_z == mctx->get_rs_z();
|
|
|
+
|
|
|
+ return res;
|
|
|
+}
|
|
|
+
|
|
|
void llm_graph_input_cross_embd::set_input(const llama_ubatch * ubatch) {
|
|
|
GGML_UNUSED(ubatch);
|
|
|
|
|
|
@@ -461,8 +479,46 @@ void llm_graph_input_attn_cross::set_input(const llama_ubatch * ubatch) {
|
|
|
}
|
|
|
|
|
|
void llm_graph_input_mem_hybrid::set_input(const llama_ubatch * ubatch) {
|
|
|
- inp_attn->set_input(ubatch);
|
|
|
- inp_rs->set_input(ubatch);
|
|
|
+ mctx->get_attn()->set_input_k_idxs(inp_attn->self_k_idxs, ubatch);
|
|
|
+ mctx->get_attn()->set_input_v_idxs(inp_attn->self_v_idxs, ubatch);
|
|
|
+
|
|
|
+ mctx->get_attn()->set_input_kq_mask(inp_attn->self_kq_mask, ubatch, cparams.causal_attn);
|
|
|
+
|
|
|
+ const int64_t n_rs = mctx->get_recr()->get_n_rs();
|
|
|
+
|
|
|
+ if (inp_rs->s_copy) {
|
|
|
+ GGML_ASSERT(ggml_backend_buffer_is_host(inp_rs->s_copy->buffer));
|
|
|
+ int32_t * data = (int32_t *) inp_rs->s_copy->data;
|
|
|
+
|
|
|
+ // assuming copy destinations ALWAYS happen ONLY on the cells between head and head+n
|
|
|
+ for (uint32_t i = 0; i < n_rs; ++i) {
|
|
|
+ data[i] = mctx->get_recr()->s_copy(i);
|
|
|
+ }
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+bool llm_graph_input_mem_hybrid::can_reuse(const llm_graph_params & params) {
|
|
|
+ const auto * mctx = static_cast<const llama_memory_hybrid_context *>(params.mctx);
|
|
|
+
|
|
|
+ this->mctx = mctx;
|
|
|
+
|
|
|
+ bool res = true;
|
|
|
+
|
|
|
+ res &= inp_attn->self_k_idxs->ne[0] == params.ubatch.n_tokens;
|
|
|
+ //res &= inp_attn->self_v_idxs->ne[0] == params.ubatch.n_tokens; // TODO: need to move this to the unified cache and check there
|
|
|
+
|
|
|
+ res &= inp_attn->self_kq_mask->ne[0] == mctx->get_attn()->get_n_kv();
|
|
|
+ res &= inp_attn->self_kq_mask->ne[1] == params.ubatch.n_tokens;
|
|
|
+
|
|
|
+ res &= inp_rs->s_copy->ne[0] == mctx->get_recr()->get_n_rs();
|
|
|
+
|
|
|
+ res &= inp_rs->s_copy_main->ne[0] == params.ubatch.n_seqs;
|
|
|
+ res &= inp_rs->s_copy_extra->ne[0] == mctx->get_recr()->get_n_rs() - params.ubatch.n_seqs;
|
|
|
+
|
|
|
+ res &= inp_rs->head == mctx->get_recr()->get_head();
|
|
|
+ res &= inp_rs->rs_z == mctx->get_recr()->get_rs_z();
|
|
|
+
|
|
|
+ return res;
|
|
|
}
|
|
|
|
|
|
//
|
|
|
@@ -1850,6 +1906,9 @@ static std::unique_ptr<llm_graph_input_rs> build_rs_inp_impl(
|
|
|
inp->s_copy_main = ggml_view_1d(ctx0, inp->s_copy, n_seqs, 0);
|
|
|
inp->s_copy_extra = ggml_view_1d(ctx0, inp->s_copy, n_rs - n_seqs, n_seqs * inp->s_copy->nb[0]);
|
|
|
|
|
|
+ inp->head = mctx_cur->get_head();
|
|
|
+ inp->rs_z = mctx_cur->get_rs_z();
|
|
|
+
|
|
|
return inp;
|
|
|
}
|
|
|
|
|
|
@@ -1918,10 +1977,10 @@ ggml_tensor * llm_graph_context::build_rwkv_token_shift_store(
|
|
|
llm_graph_input_mem_hybrid * llm_graph_context::build_inp_mem_hybrid() const {
|
|
|
const auto * mctx_cur = static_cast<const llama_memory_hybrid_context *>(mctx);
|
|
|
|
|
|
- auto inp_rs = build_rs_inp_impl(ctx0, ubatch, mctx_cur->get_recr());
|
|
|
+ auto inp_rs = build_rs_inp_impl (ctx0, ubatch, mctx_cur->get_recr());
|
|
|
auto inp_attn = build_attn_inp_kv_impl(ctx0, ubatch, hparams, cparams, mctx_cur->get_attn());
|
|
|
|
|
|
- auto inp = std::make_unique<llm_graph_input_mem_hybrid>(std::move(inp_attn), std::move(inp_rs), mctx_cur);
|
|
|
+ auto inp = std::make_unique<llm_graph_input_mem_hybrid>(cparams, std::move(inp_attn), std::move(inp_rs), mctx_cur);
|
|
|
|
|
|
return (llm_graph_input_mem_hybrid *) res->add_input(std::move(inp));
|
|
|
}
|