|
@@ -338,6 +338,8 @@ void llama_kv_cache::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id_dst, ll
|
|
|
llama_pos pos = v_cells[s0].pos_get(i);
|
|
llama_pos pos = v_cells[s0].pos_get(i);
|
|
|
llama_pos shift = v_cells[s0].get_shift(i);
|
|
llama_pos shift = v_cells[s0].get_shift(i);
|
|
|
|
|
|
|
|
|
|
+ llama_kv_cell_ext ext = v_cells[s0].ext_get(i);
|
|
|
|
|
+
|
|
|
if (shift != 0) {
|
|
if (shift != 0) {
|
|
|
pos -= shift;
|
|
pos -= shift;
|
|
|
assert(pos >= 0);
|
|
assert(pos >= 0);
|
|
@@ -349,6 +351,8 @@ void llama_kv_cache::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id_dst, ll
|
|
|
if (shift != 0) {
|
|
if (shift != 0) {
|
|
|
v_cells[s1].pos_add(i, shift);
|
|
v_cells[s1].pos_add(i, shift);
|
|
|
}
|
|
}
|
|
|
|
|
+
|
|
|
|
|
+ v_cells[s1].ext_set(i, ext);
|
|
|
}
|
|
}
|
|
|
}
|
|
}
|
|
|
|
|
|
|
@@ -383,6 +387,7 @@ void llama_kv_cache::seq_keep(llama_seq_id seq_id) {
|
|
|
|
|
|
|
|
void llama_kv_cache::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) {
|
|
void llama_kv_cache::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) {
|
|
|
GGML_ASSERT(seq_id >= 0 && (size_t) seq_id < seq_to_stream.size());
|
|
GGML_ASSERT(seq_id >= 0 && (size_t) seq_id < seq_to_stream.size());
|
|
|
|
|
+ GGML_ASSERT(hparams.n_pos_per_embd() == 1 && "seq_add() is only supported for n_pos_per_embd() == 1");
|
|
|
|
|
|
|
|
auto & cells = v_cells[seq_to_stream[seq_id]];
|
|
auto & cells = v_cells[seq_to_stream[seq_id]];
|
|
|
auto & head = v_heads[seq_to_stream[seq_id]];
|
|
auto & head = v_heads[seq_to_stream[seq_id]];
|
|
@@ -427,6 +432,7 @@ void llama_kv_cache::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, ll
|
|
|
|
|
|
|
|
void llama_kv_cache::seq_div(llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) {
|
|
void llama_kv_cache::seq_div(llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) {
|
|
|
GGML_ASSERT(seq_id >= 0 && (size_t) seq_id < seq_to_stream.size());
|
|
GGML_ASSERT(seq_id >= 0 && (size_t) seq_id < seq_to_stream.size());
|
|
|
|
|
+ GGML_ASSERT(hparams.n_pos_per_embd() == 1 && "seq_div() is only supported for n_pos_per_embd() == 1");
|
|
|
|
|
|
|
|
auto & cells = v_cells[seq_to_stream[seq_id]];
|
|
auto & cells = v_cells[seq_to_stream[seq_id]];
|
|
|
|
|
|
|
@@ -900,6 +906,14 @@ void llama_kv_cache::apply_ubatch(const slot_info & sinfo, const llama_ubatch &
|
|
|
|
|
|
|
|
cells.pos_set(idx, ubatch.pos[i]);
|
|
cells.pos_set(idx, ubatch.pos[i]);
|
|
|
|
|
|
|
|
|
|
+ if (ubatch.is_pos_2d()) {
|
|
|
|
|
+ llama_kv_cell_ext ext {
|
|
|
|
|
+ /*.x =*/ ubatch.pos[i + ubatch.n_tokens*2],
|
|
|
|
|
+ /*.y =*/ ubatch.pos[i + ubatch.n_tokens],
|
|
|
|
|
+ };
|
|
|
|
|
+ cells.ext_set(idx, ext);
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
for (int32_t s = 0; s < ubatch.n_seq_id[i]; s++) {
|
|
for (int32_t s = 0; s < ubatch.n_seq_id[i]; s++) {
|
|
|
cells.seq_add(idx, ubatch.seq_id[i][s]);
|
|
cells.seq_add(idx, ubatch.seq_id[i][s]);
|
|
|
}
|
|
}
|
|
@@ -1247,6 +1261,11 @@ void llama_kv_cache::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * u
|
|
|
|
|
|
|
|
const llama_pos p1 = ubatch->pos[i];
|
|
const llama_pos p1 = ubatch->pos[i];
|
|
|
|
|
|
|
|
|
|
+ // for M-RoPE
|
|
|
|
|
+ const bool is_2d = ubatch->is_pos_2d();
|
|
|
|
|
+ const llama_pos p1_x = is_2d ? ubatch->pos[i + ubatch->n_tokens*2] : 0;
|
|
|
|
|
+ const llama_pos p1_y = is_2d ? ubatch->pos[i + ubatch->n_tokens] : 0;
|
|
|
|
|
+
|
|
|
const uint64_t idst = n_kv*(h*n_stream*n_tps_pad + s*n_tps_pad + ii);
|
|
const uint64_t idst = n_kv*(h*n_stream*n_tps_pad + s*n_tps_pad + ii);
|
|
|
|
|
|
|
|
for (uint32_t j = 0; j < n_kv; ++j) {
|
|
for (uint32_t j = 0; j < n_kv; ++j) {
|
|
@@ -1266,6 +1285,14 @@ void llama_kv_cache::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * u
|
|
|
continue;
|
|
continue;
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
|
|
+ // M-RoPE causal mask
|
|
|
|
|
+ if (causal_attn && is_2d && p0 == p1) {
|
|
|
|
|
+ const auto & p0_ext = cells.ext_get(j);
|
|
|
|
|
+ if (p0_ext.is_2d_gt(p1_x, p1_y)) {
|
|
|
|
|
+ continue;
|
|
|
|
|
+ }
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
// apply SWA if any
|
|
// apply SWA if any
|
|
|
if (is_masked_swa(p0, p1)) {
|
|
if (is_masked_swa(p0, p1)) {
|
|
|
continue;
|
|
continue;
|
|
@@ -1559,6 +1586,9 @@ void llama_kv_cache::state_write_meta(llama_io_write_i & io, const cell_ranges_t
|
|
|
io.write(&pos, sizeof(pos));
|
|
io.write(&pos, sizeof(pos));
|
|
|
io.write(&n_seq_id, sizeof(n_seq_id));
|
|
io.write(&n_seq_id, sizeof(n_seq_id));
|
|
|
|
|
|
|
|
|
|
+ // TODO: we also need to save llama_kv_cell_ext when apply_ubatch() support loading it
|
|
|
|
|
+ // see: https://github.com/ggml-org/llama.cpp/pull/16825#issuecomment-3460868350
|
|
|
|
|
+
|
|
|
for (const auto & seq_id : seq_ids) {
|
|
for (const auto & seq_id : seq_ids) {
|
|
|
io.write(&seq_id, sizeof(seq_id));
|
|
io.write(&seq_id, sizeof(seq_id));
|
|
|
}
|
|
}
|
|
@@ -1704,6 +1734,8 @@ bool llama_kv_cache::state_read_meta(llama_io_read_i & io, uint32_t strm, uint32
|
|
|
return false;
|
|
return false;
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
|
|
+ // TODO: we cannot yet restore llama_kv_cell_ext as the apply_ubatch() does not support it yet
|
|
|
|
|
+ // see: https://github.com/ggml-org/llama.cpp/pull/16825#issuecomment-3460868350
|
|
|
apply_ubatch(sinfo, ubatch);
|
|
apply_ubatch(sinfo, ubatch);
|
|
|
|
|
|
|
|
const auto head_cur = sinfo.head();
|
|
const auto head_cur = sinfo.head();
|