|
@@ -852,7 +852,7 @@ llama_kv_cache::slot_info llama_kv_cache::find_slot(const llama_ubatch & ubatch,
|
|
|
const llama_seq_id seq_id_cell = cells.seq_get(idx);
|
|
const llama_seq_id seq_id_cell = cells.seq_get(idx);
|
|
|
|
|
|
|
|
// SWA mask
|
|
// SWA mask
|
|
|
- if (is_masked_swa(pos_cell, cells.seq_pos_max(seq_id_cell) + 1)) {
|
|
|
|
|
|
|
+ if (llama_hparams::is_masked_swa(n_swa, swa_type, pos_cell, cells.seq_pos_max(seq_id_cell) + 1)) {
|
|
|
can_use = true;
|
|
can_use = true;
|
|
|
}
|
|
}
|
|
|
}
|
|
}
|
|
@@ -1237,90 +1237,236 @@ void llama_kv_cache::set_input_k_shift(ggml_tensor * dst) const {
|
|
|
}
|
|
}
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
-void llama_kv_cache::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const {
|
|
|
|
|
- const uint32_t n_tokens = ubatch->n_tokens;
|
|
|
|
|
|
|
+struct args_set_input_kq_mask {
|
|
|
|
|
+ const llama_hparams & hparams;
|
|
|
|
|
+ const llama_ubatch * ubatch;
|
|
|
|
|
|
|
|
- GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer));
|
|
|
|
|
- float * data = (float *) dst->data;
|
|
|
|
|
|
|
+ const std::vector<llama_kv_cells> & v_cells;
|
|
|
|
|
+ const std::vector<uint32_t> & seq_to_stream;
|
|
|
|
|
|
|
|
- const int64_t n_kv = dst->ne[0];
|
|
|
|
|
- const int64_t n_stream = dst->ne[3]; // num streams in the current ubatch
|
|
|
|
|
|
|
+ uint32_t n_swa;
|
|
|
|
|
+ llama_swa_type swa_type;
|
|
|
|
|
|
|
|
- GGML_ASSERT(n_tokens%n_stream == 0);
|
|
|
|
|
|
|
+ int64_t n_kv;
|
|
|
|
|
+ int64_t n_stream;
|
|
|
|
|
+ int64_t n_tps;
|
|
|
|
|
+};
|
|
|
|
|
|
|
|
- // n_tps == n_tokens_per_stream
|
|
|
|
|
- const int64_t n_tps = n_tokens/n_stream;
|
|
|
|
|
|
|
+template<bool causal, bool swa, bool is_2d, bool alibi>
|
|
|
|
|
+static void set_input_kq_mask_impl(const args_set_input_kq_mask & args, float * data) {
|
|
|
|
|
+ //const auto & hparams = args.hparams;
|
|
|
|
|
+ const auto & ubatch = args.ubatch;
|
|
|
|
|
|
|
|
- std::fill(data, data + ggml_nelements(dst), -INFINITY);
|
|
|
|
|
-
|
|
|
|
|
- // Use only the previous KV cells of the correct sequence for each token of the ubatch.
|
|
|
|
|
- // It's assumed that if a token in the batch has multiple sequences, they are equivalent.
|
|
|
|
|
- // Example with a cache of 10 tokens, 2 tokens populated in cache and 3 tokens in batch:
|
|
|
|
|
- // Causal mask:
|
|
|
|
|
- // xxx-------
|
|
|
|
|
- // xxxx------
|
|
|
|
|
- // xxxxx-----
|
|
|
|
|
- // Non-causal mask:
|
|
|
|
|
- // xxxxx-----
|
|
|
|
|
- // xxxxx-----
|
|
|
|
|
- // xxxxx-----
|
|
|
|
|
- // To visualize the mask, see https://github.com/ggml-org/llama.cpp/pull/12615
|
|
|
|
|
- // TODO: optimize this section
|
|
|
|
|
- for (uint32_t h = 0; h < 1; ++h) {
|
|
|
|
|
- for (uint32_t s = 0; s < n_stream; ++s) {
|
|
|
|
|
- for (uint32_t ii = 0; ii < n_tps; ++ii) {
|
|
|
|
|
- const uint32_t i = s*n_tps + ii;
|
|
|
|
|
|
|
+ const auto & v_cells = args.v_cells;
|
|
|
|
|
+ const auto & seq_to_stream = args.seq_to_stream;
|
|
|
|
|
|
|
|
- const llama_seq_id seq_id = ubatch->seq_id[i][0];
|
|
|
|
|
|
|
+ const uint32_t n_swa = args.n_swa;
|
|
|
|
|
+ const llama_swa_type swa_type = args.swa_type;
|
|
|
|
|
|
|
|
- const auto & cells = v_cells[seq_to_stream[seq_id]];
|
|
|
|
|
|
|
+ const int64_t n_kv = args.n_kv;
|
|
|
|
|
+ const int64_t n_stream = args.n_stream;
|
|
|
|
|
+ const int64_t n_tps = args.n_tps;
|
|
|
|
|
|
|
|
- const llama_pos p1 = ubatch->pos[i];
|
|
|
|
|
|
|
+ // the min position in the batch for each sequence
|
|
|
|
|
+ llama_pos seq_pos_min[LLAMA_MAX_SEQ];
|
|
|
|
|
+ std::fill(seq_pos_min, seq_pos_min + LLAMA_MAX_SEQ, INT32_MAX);
|
|
|
|
|
|
|
|
- // for M-RoPE
|
|
|
|
|
- const bool is_2d = ubatch->is_pos_2d();
|
|
|
|
|
- const llama_pos p1_x = is_2d ? ubatch->pos[i + ubatch->n_tokens*2] : 0;
|
|
|
|
|
- const llama_pos p1_y = is_2d ? ubatch->pos[i + ubatch->n_tokens] : 0;
|
|
|
|
|
|
|
+ for (uint32_t i = 0; i < ubatch->n_tokens; ++i) {
|
|
|
|
|
+ const llama_seq_id seq_id = ubatch->seq_id[i][0];
|
|
|
|
|
|
|
|
- const uint64_t idst = n_kv*(h*n_stream*n_tps + s*n_tps + ii);
|
|
|
|
|
|
|
+ seq_pos_min[seq_id] = std::min(seq_pos_min[seq_id], ubatch->pos[i]);
|
|
|
|
|
+ }
|
|
|
|
|
|
|
|
- for (uint32_t j = 0; j < n_kv; ++j) {
|
|
|
|
|
- if (cells.is_empty(j)) {
|
|
|
|
|
- continue;
|
|
|
|
|
- }
|
|
|
|
|
|
|
+ for (uint32_t s = 0; s < n_stream; ++s) {
|
|
|
|
|
+ // bookeeping of the KQ mask cells that could change for other tokens of the same sequence
|
|
|
|
|
+ std::unordered_map<llama_seq_id, uint32_t> seq_srct;
|
|
|
|
|
+ std::unordered_map<llama_seq_id, std::vector<uint32_t>> seq_idxs;
|
|
|
|
|
+
|
|
|
|
|
+ for (uint32_t ii = 0; ii < n_tps; ++ii) {
|
|
|
|
|
+ const uint32_t i = s*n_tps + ii;
|
|
|
|
|
+
|
|
|
|
|
+ const llama_seq_id seq_id = ubatch->seq_id[i][0];
|
|
|
|
|
+
|
|
|
|
|
+ const auto & cells = v_cells.at(seq_to_stream[seq_id]);
|
|
|
|
|
+
|
|
|
|
|
+ llama_pos p0 = -1;
|
|
|
|
|
+ const llama_pos p1 = ubatch->pos[i];
|
|
|
|
|
+
|
|
|
|
|
+ // for M-RoPE
|
|
|
|
|
+ const llama_pos p1_x = is_2d ? ubatch->pos[i + ubatch->n_tokens*2] : 0;
|
|
|
|
|
+ const llama_pos p1_y = is_2d ? ubatch->pos[i + ubatch->n_tokens] : 0;
|
|
|
|
|
+
|
|
|
|
|
+ const uint64_t idst = n_kv*i;
|
|
|
|
|
+
|
|
|
|
|
+ // for tokens of the same sequence, the mask is mostly the same, so we can reuse it
|
|
|
|
|
+ // the only cells that could change are the ones that are with similar positions as the
|
|
|
|
|
+ // ones in the batch (i.e. due to causal masking, SWA, etc.)
|
|
|
|
|
+ // keep track of those cells and shortcut the loop to save time
|
|
|
|
|
+ // note: this optimization is not compatible with Alibi position encoding
|
|
|
|
|
+ // ref: https://github.com/ggml-org/llama.cpp/pull/18842
|
|
|
|
|
+ bool prev = false;
|
|
|
|
|
|
|
|
- // mask the token if not the same sequence
|
|
|
|
|
- if (!cells.seq_has(j, seq_id)) {
|
|
|
|
|
- continue;
|
|
|
|
|
|
|
+ auto & idxs = seq_idxs[seq_id];
|
|
|
|
|
+
|
|
|
|
|
+ if (!alibi) {
|
|
|
|
|
+ if (seq_srct.find(seq_id) != seq_srct.end()) {
|
|
|
|
|
+ const uint32_t srct = seq_srct[seq_id];
|
|
|
|
|
+
|
|
|
|
|
+ const uint64_t idst_prev = n_kv*srct;
|
|
|
|
|
+
|
|
|
|
|
+ std::copy(data + idst_prev, data + idst_prev + n_kv, data + idst);
|
|
|
|
|
+
|
|
|
|
|
+ prev = true;
|
|
|
|
|
+ } else {
|
|
|
|
|
+ idxs.clear();
|
|
|
|
|
+ idxs.reserve(ubatch->n_tokens + n_swa + 32);
|
|
|
|
|
+
|
|
|
|
|
+ seq_srct[seq_id] = i;
|
|
|
|
|
+ }
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ for (uint32_t jj = 0; jj < n_kv; ++jj) {
|
|
|
|
|
+ uint32_t j = jj;
|
|
|
|
|
+
|
|
|
|
|
+ // we have an exiting mask for this sequence -> update just seq_idxs
|
|
|
|
|
+ if (!alibi) {
|
|
|
|
|
+ if (prev) {
|
|
|
|
|
+ if (jj >= idxs.size()) {
|
|
|
|
|
+ break;
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ j = idxs[jj];
|
|
|
}
|
|
}
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ if (cells.is_empty(j)) {
|
|
|
|
|
+ goto skip;
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ // mask the token if not the same sequence
|
|
|
|
|
+ if (!cells.seq_has(j, seq_id)) {
|
|
|
|
|
+ goto skip;
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ p0 = cells.pos_get(j);
|
|
|
|
|
|
|
|
- const llama_pos p0 = cells.pos_get(j);
|
|
|
|
|
|
|
+ if (!alibi) {
|
|
|
|
|
+ if (!prev) {
|
|
|
|
|
+ // record all cells for which: p0 >= seq_pos_min[seq_id] - n_swa - 32
|
|
|
|
|
+ if (p0 + (int32_t) (n_swa + 32) >= seq_pos_min[seq_id]) {
|
|
|
|
|
+ idxs.push_back(j);
|
|
|
|
|
+ }
|
|
|
|
|
+ }
|
|
|
|
|
+ }
|
|
|
|
|
|
|
|
|
|
+ if (causal) {
|
|
|
// mask future tokens
|
|
// mask future tokens
|
|
|
- if (causal_attn && p0 > p1) {
|
|
|
|
|
- continue;
|
|
|
|
|
|
|
+ if (p0 > p1) {
|
|
|
|
|
+ goto skip;
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
// M-RoPE causal mask
|
|
// M-RoPE causal mask
|
|
|
- if (causal_attn && is_2d && p0 == p1) {
|
|
|
|
|
- const auto & p0_ext = cells.ext_get(j);
|
|
|
|
|
- if (p0_ext.is_2d_gt(p1_x, p1_y)) {
|
|
|
|
|
- continue;
|
|
|
|
|
|
|
+ if (is_2d) {
|
|
|
|
|
+ if (p0 == p1) {
|
|
|
|
|
+ const auto & p0_ext = cells.ext_get(j);
|
|
|
|
|
+
|
|
|
|
|
+ if (p0_ext.is_2d_gt(p1_x, p1_y)) {
|
|
|
|
|
+ goto skip;
|
|
|
|
|
+ }
|
|
|
}
|
|
}
|
|
|
}
|
|
}
|
|
|
|
|
+ }
|
|
|
|
|
|
|
|
- // apply SWA if any
|
|
|
|
|
- if (is_masked_swa(p0, p1)) {
|
|
|
|
|
- continue;
|
|
|
|
|
|
|
+ // apply SWA if any
|
|
|
|
|
+ if (swa) {
|
|
|
|
|
+ if (llama_hparams::is_masked_swa(n_swa, swa_type, p0, p1)) {
|
|
|
|
|
+ goto skip;
|
|
|
}
|
|
}
|
|
|
|
|
+ }
|
|
|
|
|
|
|
|
- data[idst + j] = hparams.use_alibi ? -std::abs(p0 - p1) : 0.0f;
|
|
|
|
|
|
|
+ if (alibi) {
|
|
|
|
|
+ data[idst + j] = -std::abs(p0 - p1);
|
|
|
|
|
+ } else {
|
|
|
|
|
+ data[idst + j] = 0.0f;
|
|
|
}
|
|
}
|
|
|
|
|
+
|
|
|
|
|
+ continue;
|
|
|
|
|
+skip:
|
|
|
|
|
+ data[idst + j] = -INFINITY;
|
|
|
}
|
|
}
|
|
|
}
|
|
}
|
|
|
}
|
|
}
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
|
|
+template<bool causal, bool swa, bool is_2d>
|
|
|
|
|
+static void set_input_kq_mask_impl(const args_set_input_kq_mask & args, float * data) {
|
|
|
|
|
+ const bool alibi = args.hparams.use_alibi;
|
|
|
|
|
+ if (alibi) {
|
|
|
|
|
+ set_input_kq_mask_impl<causal, swa, is_2d, true> (args, data);
|
|
|
|
|
+ } else {
|
|
|
|
|
+ set_input_kq_mask_impl<causal, swa, is_2d, false>(args, data);
|
|
|
|
|
+ }
|
|
|
|
|
+}
|
|
|
|
|
+
|
|
|
|
|
+template<bool causal, bool swa>
|
|
|
|
|
+static void set_input_kq_mask_impl(const args_set_input_kq_mask & args, float * data) {
|
|
|
|
|
+ const bool is_2d = args.ubatch->is_pos_2d();
|
|
|
|
|
+ if (is_2d) {
|
|
|
|
|
+ set_input_kq_mask_impl<causal, swa, true> (args, data);
|
|
|
|
|
+ } else {
|
|
|
|
|
+ set_input_kq_mask_impl<causal, swa, false>(args, data);
|
|
|
|
|
+ }
|
|
|
|
|
+}
|
|
|
|
|
+
|
|
|
|
|
+template<bool causal>
|
|
|
|
|
+static void set_input_kq_mask_impl(const args_set_input_kq_mask & args, float * data) {
|
|
|
|
|
+ const bool swa = args.swa_type != LLAMA_SWA_TYPE_NONE;
|
|
|
|
|
+ if (swa) {
|
|
|
|
|
+ set_input_kq_mask_impl<causal, true> (args, data);
|
|
|
|
|
+ } else {
|
|
|
|
|
+ set_input_kq_mask_impl<causal, false>(args, data);
|
|
|
|
|
+ }
|
|
|
|
|
+}
|
|
|
|
|
+
|
|
|
|
|
+void llama_kv_cache::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const {
|
|
|
|
|
+ const uint32_t n_tokens = ubatch->n_tokens;
|
|
|
|
|
+
|
|
|
|
|
+ GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer));
|
|
|
|
|
+ float * data = (float *) dst->data;
|
|
|
|
|
+
|
|
|
|
|
+ const int64_t n_kv = dst->ne[0];
|
|
|
|
|
+ const int64_t n_stream = dst->ne[3]; // num streams in the current ubatch
|
|
|
|
|
+
|
|
|
|
|
+ GGML_ASSERT(n_tokens%n_stream == 0);
|
|
|
|
|
+
|
|
|
|
|
+ // n_tps == n_tokens_per_stream
|
|
|
|
|
+ const int64_t n_tps = n_tokens/n_stream;
|
|
|
|
|
+
|
|
|
|
|
+ //const int64_t t_start = ggml_time_us();
|
|
|
|
|
+
|
|
|
|
|
+ const args_set_input_kq_mask args = {
|
|
|
|
|
+ /*.hparams =*/ hparams,
|
|
|
|
|
+ /*.ubatch =*/ ubatch,
|
|
|
|
|
+ /*.v_cells =*/ v_cells,
|
|
|
|
|
+ /*.seq_to_stream =*/ seq_to_stream,
|
|
|
|
|
+ /*.n_swa =*/ n_swa,
|
|
|
|
|
+ /*.swa_type =*/ swa_type,
|
|
|
|
|
+ /*.n_kv =*/ n_kv,
|
|
|
|
|
+ /*.n_stream =*/ n_stream,
|
|
|
|
|
+ /*.n_tps =*/ n_tps,
|
|
|
|
|
+ };
|
|
|
|
|
+
|
|
|
|
|
+ if (causal_attn) {
|
|
|
|
|
+ set_input_kq_mask_impl<true> (args, data);
|
|
|
|
|
+ } else {
|
|
|
|
|
+ set_input_kq_mask_impl<false>(args, data);
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ //const int64_t t_end = ggml_time_us();
|
|
|
|
|
+
|
|
|
|
|
+ //LLAMA_LOG_ERROR("%s: kq mask time: %0.3f ms\n", __func__, (t_end - t_start)/1000.0);
|
|
|
|
|
+}
|
|
|
|
|
+
|
|
|
void llama_kv_cache::set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const {
|
|
void llama_kv_cache::set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const {
|
|
|
const int64_t n_tokens = ubatch->n_tokens;
|
|
const int64_t n_tokens = ubatch->n_tokens;
|
|
|
|
|
|
|
@@ -1483,10 +1629,6 @@ ggml_cgraph * llama_kv_cache::build_graph_shift(llm_graph_result * res, llama_co
|
|
|
return gf;
|
|
return gf;
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
-bool llama_kv_cache::is_masked_swa(llama_pos p0, llama_pos p1) const {
|
|
|
|
|
- return llama_hparams::is_masked_swa(n_swa, swa_type, p0, p1);
|
|
|
|
|
-}
|
|
|
|
|
-
|
|
|
|
|
void llama_kv_cache::state_write(llama_io_write_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) const {
|
|
void llama_kv_cache::state_write(llama_io_write_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) const {
|
|
|
GGML_UNUSED(flags);
|
|
GGML_UNUSED(flags);
|
|
|
|
|
|