@@ -1561,9 +1561,11 @@ void llama_kv_cache::state_read(llama_io_read_i & io, llama_seq_id seq_id, llama
         const uint32_t strm = seq_id == -1 ? s : seq_to_stream[seq_id];

+        slot_info sinfo;
+
         bool res = true;

-        res = res && state_read_meta(io, strm, cell_count, seq_id);
-        res = res && state_read_data(io, strm, cell_count);
+        res = res && state_read_meta(io, strm, cell_count, sinfo, seq_id);
+        res = res && state_read_data(io, strm, cell_count, sinfo);

         if (!res) {
             if (seq_id == -1) {
@@ -1702,7 +1704,7 @@ void llama_kv_cache::state_write_data(llama_io_write_i & io, const cell_ranges_t
     }
 }

-bool llama_kv_cache::state_read_meta(llama_io_read_i & io, uint32_t strm, uint32_t cell_count, llama_seq_id dest_seq_id) {
+bool llama_kv_cache::state_read_meta(llama_io_read_i & io, uint32_t strm, uint32_t cell_count, slot_info & sinfo, llama_seq_id dest_seq_id) {
     auto & cells = v_cells[strm];
     auto & head  = v_heads[strm];

@@ -1739,7 +1741,7 @@ bool llama_kv_cache::state_read_meta(llama_io_read_i & io, uint32_t strm, uint32
             ubatch.seq_id[i] = &dest_seq_id;
         }

-        const auto sinfo = find_slot(ubatch, true);
+        sinfo = find_slot(ubatch, false);
         if (sinfo.empty()) {
             LLAMA_LOG_ERROR("%s: failed to find available cells in kv cache\n", __func__);
             return false;
@@ -1749,20 +1751,16 @@ bool llama_kv_cache::state_read_meta(llama_io_read_i & io, uint32_t strm, uint32

         // see: https://github.com/ggml-org/llama.cpp/pull/16825#issuecomment-3460868350
         apply_ubatch(sinfo, ubatch);

-        const auto head_cur = sinfo.head();
-
-        // keep the head at the old position because we will read the KV data into it in state_read_data()
-        head = head_cur;
-
-        LLAMA_LOG_DEBUG("%s: head_cur = %d, head = %d, cell_count = %d, dest_seq_id = %d\n", __func__, head_cur, head, cell_count, dest_seq_id);
+        LLAMA_LOG_DEBUG("%s: cell_count = %d, dest_seq_id = %d\n", __func__, cell_count, dest_seq_id);

-        // DEBUG CHECK: head_cur should be our first cell, head_cur + cell_count - 1 should be our last cell (verify seq_id and pos values)
-        // Assume that this is one contiguous block of cells
-        GGML_ASSERT(head_cur + cell_count <= cells.size());
-        GGML_ASSERT(cells.pos_get(head_cur) == ubatch.pos[0]);
-        GGML_ASSERT(cells.pos_get(head_cur + cell_count - 1) == ubatch.pos[cell_count - 1]);
-        GGML_ASSERT(cells.seq_has(head_cur, dest_seq_id));
-        GGML_ASSERT(cells.seq_has(head_cur + cell_count - 1, dest_seq_id));
+        // DEBUG CHECK: verify that all cells were allocated and have correct seq_id and pos values
+        GGML_ASSERT(sinfo.n_stream() == 1);
+        GGML_ASSERT(sinfo.idxs[0].size() == cell_count);
+        for (uint32_t i = 0; i < cell_count; ++i) {
+            const uint32_t idx = sinfo.idxs[0][i];
+            GGML_ASSERT(cells.pos_get(idx) == ubatch.pos[i]);
+            GGML_ASSERT(cells.seq_has(idx, dest_seq_id));
+        }
     } else {
         // whole KV cache restore

@@ -1795,15 +1793,24 @@ bool llama_kv_cache::state_read_meta(llama_io_read_i & io, uint32_t strm, uint32
             }
         }

+        // create a contiguous slot_info covering cells [0, cell_count) for the whole-cache restore
+        sinfo.s0 = strm;
+        sinfo.s1 = strm;
+        sinfo.resize(1);
+        sinfo.strm[0] = strm;
+        sinfo.idxs[0].resize(cell_count);
+        for (uint32_t i = 0; i < cell_count; ++i) {
+            sinfo.idxs[0][i] = i;
+        }
+
         head = 0;
     }

     return true;
 }

-bool llama_kv_cache::state_read_data(llama_io_read_i & io, uint32_t strm, uint32_t cell_count) {
+bool llama_kv_cache::state_read_data(llama_io_read_i & io, uint32_t strm, uint32_t cell_count, const slot_info & sinfo) {
     auto & cells = v_cells[strm];
-    auto & head  = v_heads[strm];

     uint32_t v_trans;
     uint32_t n_layer;
@@ -1853,8 +1860,17 @@ bool llama_kv_cache::state_read_data(llama_io_read_i & io, uint32_t strm, uint32
         }

         if (cell_count) {
-            // Read and set the keys for the whole cell range
-            ggml_backend_tensor_set(k, io.read(cell_count * k_size_row), head * k_size_row, cell_count * k_size_row);
+            if (sinfo.is_contiguous()) {
+                // Fast path: the cells are contiguous, set the keys for the whole range in a single write
+                ggml_backend_tensor_set(k, io.read(cell_count * k_size_row), sinfo.head() * k_size_row, cell_count * k_size_row);
+            } else {
+                // Slow path: scatter the key rows to the non-contiguous cells
+                const void * src = io.read(cell_count * k_size_row);
+                for (uint32_t i = 0; i < cell_count; ++i) {
+                    const size_t dst_offset = sinfo.idxs[0][i] * k_size_row;
+                    ggml_backend_tensor_set(k, (const char *) src + i * k_size_row, dst_offset, k_size_row);
+                }
+            }
         }
     }

@@ -1885,8 +1901,17 @@ bool llama_kv_cache::state_read_data(llama_io_read_i & io, uint32_t strm, uint32
             }

             if (cell_count) {
-                // Read and set the values for the whole cell range
-                ggml_backend_tensor_set(v, io.read(cell_count * v_size_row), head * v_size_row, cell_count * v_size_row);
+                if (sinfo.is_contiguous()) {
+                    // Fast path: the cells are contiguous, set the values for the whole range in a single write
+                    ggml_backend_tensor_set(v, io.read(cell_count * v_size_row), sinfo.head() * v_size_row, cell_count * v_size_row);
+                } else {
+                    // Slow path: scatter the value rows to the non-contiguous cells
+                    const void * src = io.read(cell_count * v_size_row);
+                    for (uint32_t i = 0; i < cell_count; ++i) {
+                        const size_t dst_offset = sinfo.idxs[0][i] * v_size_row;
+                        ggml_backend_tensor_set(v, (const char *) src + i * v_size_row, dst_offset, v_size_row);
+                    }
+                }
             }
         }
     } else {
@@ -1925,10 +1950,22 @@ bool llama_kv_cache::state_read_data(llama_io_read_i & io, uint32_t strm, uint32
             }

             if (cell_count) {
-                // For each row in the transposed matrix, read the values for the whole cell range
-                for (uint32_t j = 0; j < n_embd_v_gqa; ++j) {
-                    const size_t dst_offset = (head + j * cells.size()) * v_size_el;
-                    ggml_backend_tensor_set(v, io.read(cell_count * v_size_el), dst_offset, cell_count * v_size_el);
+                if (sinfo.is_contiguous()) {
+                    // Fast path: for each row of the transposed matrix, read the values for the whole contiguous cell range
+                    const uint32_t h = sinfo.head();
+                    for (uint32_t j = 0; j < n_embd_v_gqa; ++j) {
+                        const size_t dst_offset = (h + j * cells.size()) * v_size_el;
+                        ggml_backend_tensor_set(v, io.read(cell_count * v_size_el), dst_offset, cell_count * v_size_el);
+                    }
+                } else {
+                    // Slow path: scatter the elements of each row to the non-contiguous cells
+                    for (uint32_t j = 0; j < n_embd_v_gqa; ++j) {
+                        const void * src = io.read(cell_count * v_size_el);
+                        for (uint32_t i = 0; i < cell_count; ++i) {
+                            const size_t dst_offset = (sinfo.idxs[0][i] + j * cells.size()) * v_size_el;
+                            ggml_backend_tensor_set(v, (const char *) src + i * v_size_el, dst_offset, v_size_el);
+                        }
+                    }
                 }
             }
         }
     }
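The core of the change: state_read_meta now accepts whatever cell layout find_slot(ubatch, false) returns and passes it to state_read_data, which scatters the saved rows into those cells instead of assuming one contiguous block starting at head. Below is a minimal standalone sketch of that scatter idea, not part of the patch: the scatter_rows helper is hypothetical, and a host-side memcpy stands in for ggml_backend_tensor_set, which may actually target device memory.

    #include <cassert>
    #include <cstdint>
    #include <cstring>
    #include <vector>

    // Scatter cell_count packed rows of row_size bytes from src into dst,
    // placing row i at cell index idxs[i] - this mirrors the slow path above,
    // where idxs plays the role of sinfo.idxs[0].
    static void scatter_rows(uint8_t * dst, const uint8_t * src,
                             const std::vector<uint32_t> & idxs, size_t row_size) {
        for (size_t i = 0; i < idxs.size(); ++i) {
            std::memcpy(dst + (size_t) idxs[i] * row_size, src + i * row_size, row_size);
        }
    }

    int main() {
        const size_t row_size = 4;
        std::vector<uint8_t> cache(8 * row_size, 0);               // 8 cells, all empty
        std::vector<uint8_t> packed = {1,1,1,1, 2,2,2,2, 3,3,3,3}; // 3 saved rows

        // non-contiguous destination cells, as find_slot() may return
        // when the cache is fragmented: rows land in cells 1, 4 and 6
        scatter_rows(cache.data(), packed.data(), {1, 4, 6}, row_size);

        assert(cache[1 * row_size] == 1);
        assert(cache[4 * row_size] == 2);
        assert(cache[6 * row_size] == 3);
        return 0;
    }

The contiguous fast path is kept because one large tensor write is cheaper than cell_count small ones, particularly when the KV tensors live in device memory.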