|
|
@@ -1283,6 +1283,8 @@ void llama_kv_cache_unified::set_input_kq_mask(ggml_tensor * dst, const llama_ub
|
|
|
const int64_t n_tps = n_tokens/n_stream;
|
|
|
const int64_t n_tps_pad = GGML_PAD(n_tps, GGML_KQ_MASK_PAD);
|
|
|
|
|
|
+ std::fill(data, data + ggml_nelements(dst), -INFINITY);
|
|
|
+
|
|
|
// Use only the previous KV cells of the correct sequence for each token of the ubatch.
|
|
|
// It's assumed that if a token in the batch has multiple sequences, they are equivalent.
|
|
|
// Example with a cache of 10 tokens, 2 tokens populated in cache and 3 tokens in batch:
|
|
|
@@ -1306,44 +1308,31 @@ void llama_kv_cache_unified::set_input_kq_mask(ggml_tensor * dst, const llama_ub
|
|
|
|
|
|
const llama_pos p1 = ubatch->pos[i];
|
|
|
|
|
|
- for (uint32_t j = 0; j < n_kv; ++j) {
|
|
|
- float f = 0.0f;
|
|
|
-
|
|
|
- bool masked = false;
|
|
|
+ const uint64_t idst = n_kv*(h*n_stream*n_tps_pad + s*n_tps_pad + ii);
|
|
|
|
|
|
+ for (uint32_t j = 0; j < n_kv; ++j) {
|
|
|
if (cells.is_empty(j)) {
|
|
|
- masked = true;
|
|
|
- } else {
|
|
|
- const llama_pos p0 = cells.pos_get(j);
|
|
|
-
|
|
|
- // mask the token if not the same sequence
|
|
|
- masked = masked || (!cells.seq_has(j, seq_id));
|
|
|
+ continue;
|
|
|
+ }
|
|
|
|
|
|
- // mask future tokens
|
|
|
- masked = masked || (causal_attn && p0 > p1);
|
|
|
+ // mask the token if not the same sequence
|
|
|
+ if (!cells.seq_has(j, seq_id)) {
|
|
|
+ continue;
|
|
|
+ }
|
|
|
|
|
|
- // apply SWA if any
|
|
|
- masked = masked || (is_masked_swa(p0, p1));
|
|
|
+ const llama_pos p0 = cells.pos_get(j);
|
|
|
|
|
|
- if (!masked && hparams.use_alibi) {
|
|
|
- f = -std::abs(p0 - p1);
|
|
|
- }
|
|
|
+ // mask future tokens
|
|
|
+ if (causal_attn && p0 > p1) {
|
|
|
+ continue;
|
|
|
}
|
|
|
|
|
|
- if (masked) {
|
|
|
- f = -INFINITY;
|
|
|
+ // apply SWA if any
|
|
|
+ if (is_masked_swa(p0, p1)) {
|
|
|
+ continue;
|
|
|
}
|
|
|
|
|
|
- data[h*n_stream*n_tps_pad*n_kv + s*n_tps_pad*n_kv + ii*n_kv + j] = f;
|
|
|
- }
|
|
|
-
|
|
|
- // mask padded tokens
|
|
|
- if (data) {
|
|
|
- for (uint32_t ii = n_tps; ii < n_tps_pad; ++ii) {
|
|
|
- for (uint32_t j = 0; j < n_kv; ++j) {
|
|
|
- data[h*n_stream*n_tps_pad*n_kv + s*n_tps_pad*n_kv + ii*n_kv + j] = -INFINITY;
|
|
|
- }
|
|
|
- }
|
|
|
+ data[idst + j] = hparams.use_alibi ? -std::abs(p0 - p1) : 0.0f;
|
|
|
}
|
|
|
}
|
|
|
}
|