@@ -261,12 +261,17 @@ void llm_graph_input_cross_embd::set_input(const llama_ubatch * ubatch) {
}
}
 
-static void print_mask(float * data, int64_t n_tokens, int64_t n_kv, int64_t n_swa, llama_swa_type swa_type) {
+static void print_mask(const float * data, int64_t n_tokens, int64_t n_kv, int64_t n_swa, llama_swa_type swa_type) {
LLAMA_LOG_DEBUG("%s: === Attention mask ===\n", __func__);
- const char * swa_type_str = (swa_type == LLAMA_SWA_TYPE_NONE) ? "LLAMA_SWA_TYPE_NONE" :
- (swa_type == LLAMA_SWA_TYPE_STANDARD) ? "LLAMA_SWA_TYPE_STANDARD" :
- (swa_type == LLAMA_SWA_TYPE_CHUNKED) ? "LLAMA_SWA_TYPE_CHUNKED" :
- (swa_type == LLAMA_SWA_TYPE_SYMMETRIC) ? "LLAMA_SWA_TYPE_SYMMETRIC" : "unknown";
+ const char * swa_type_str = "unknown";
+
+ switch (swa_type) {
+ case LLAMA_SWA_TYPE_NONE: swa_type_str = "LLAMA_SWA_TYPE_NONE"; break;
+ case LLAMA_SWA_TYPE_STANDARD: swa_type_str = "LLAMA_SWA_TYPE_STANDARD"; break;
+ case LLAMA_SWA_TYPE_CHUNKED: swa_type_str = "LLAMA_SWA_TYPE_CHUNKED"; break;
+ case LLAMA_SWA_TYPE_SYMMETRIC: swa_type_str = "LLAMA_SWA_TYPE_SYMMETRIC"; break;
+ };
+
LLAMA_LOG_DEBUG("%s: n_swa : %d, n_kv: %d, swq_type: %s\n", __func__, (int)n_swa, (int)n_kv, swa_type_str);
LLAMA_LOG_DEBUG("%s: '0' = can attend, '∞' = masked\n", __func__);
LLAMA_LOG_DEBUG("%s: Rows = query tokens, Columns = key/value tokens\n\n", __func__);
@@ -295,50 +300,67 @@ void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) {
const int64_t n_kv = ubatch->n_tokens;
const int64_t n_tokens = ubatch->n_tokens;
 
- GGML_ASSERT(kq_mask);
- GGML_ASSERT(ggml_backend_buffer_is_host(kq_mask->buffer));
-
- float * data = (float *) kq_mask->data;
-
- // [TAG_NO_CACHE_ISWA]
- GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "TODO: implement");
+ const auto fill_mask = [&](float * data, int n_swa, llama_swa_type swa_type) {
+ for (int h = 0; h < 1; ++h) {
+ for (int i1 = 0; i1 < n_tokens; ++i1) {
+ const llama_seq_id s1 = ubatch->seq_id[i1][0];
+ const llama_pos p1 = ubatch->pos[i1];
 
- for (int h = 0; h < 1; ++h) {
- for (int i1 = 0; i1 < n_tokens; ++i1) {
- const llama_seq_id s1 = ubatch->seq_id[i1][0];
+ const uint64_t idst = h*(n_kv*n_tokens) + i1*n_kv;
 
- for (int i0 = 0; i0 < n_tokens; ++i0) {
- float f = -INFINITY;
-
- for (int s = 0; s < ubatch->n_seq_id[i0]; ++s) {
+ for (int i0 = 0; i0 < n_tokens; ++i0) {
const llama_seq_id s0 = ubatch->seq_id[i0][0];
+ const llama_pos p0 = ubatch->pos[i0];
 
+ // mask different sequences
if (s0 != s1) {
- continue; // skip different sequences
+ continue;
}
 
- if (cparams.causal_attn && ubatch->pos[i0] > ubatch->pos[i1]) {
- continue; // skip future tokens for causal attention
+ // mask future tokens
+ if (cparams.causal_attn && p0 > p1) {
+ continue;
}
 
- // TODO: this does not take into account that some layers are SWA and others are note (i.e. iSWA) [TAG_NO_CACHE_ISWA]
- //if (hparams.is_masked_swa(ubatch->pos[i0], ubatch->pos[i1])) {
- // continue; // skip masked tokens for SWA
- //}
-
- // TODO: reimplement this like in llama_kv_cache_unified
- if (hparams.use_alibi) {
- f = -std::abs(ubatch->pos[i0] - ubatch->pos[i1]);
- } else {
- f = 0.0f;
+ // apply SWA if any
+ if (llama_hparams::is_masked_swa(n_swa, swa_type, p0, p1)) {
+ continue;
}
+
+ data[idst + i0] = hparams.use_alibi ? -std::abs(p0 - p1) : 0.0f;
}
- data[h*(n_kv*n_tokens) + i1*n_kv + i0] = f;
}
}
+ };
+
+ {
+ GGML_ASSERT(self_kq_mask);
+ GGML_ASSERT(ggml_backend_buffer_is_host(self_kq_mask->buffer));
+
+ float * data = (float *) self_kq_mask->data;
+
+ std::fill(data, data + ggml_nelements(self_kq_mask), -INFINITY);
+
+ fill_mask(data, 0, LLAMA_SWA_TYPE_NONE);
+
+ if (debug) {
+ print_mask(data, n_tokens, n_kv, 0, LLAMA_SWA_TYPE_NONE);
+ }
}
- if (debug) {
- print_mask(data, n_tokens, n_kv, hparams.n_swa, hparams.swa_type);
+
+ if (hparams.swa_type != LLAMA_SWA_TYPE_NONE) {
+ GGML_ASSERT(self_kq_mask_swa);
+ GGML_ASSERT(ggml_backend_buffer_is_host(self_kq_mask_swa->buffer));
+
+ float * data = (float *) self_kq_mask_swa->data;
+
+ std::fill(data, data + ggml_nelements(self_kq_mask_swa), -INFINITY);
+
+ fill_mask(data, hparams.n_swa, hparams.swa_type);
+
+ if (debug) {
+ print_mask(data, n_tokens, n_kv, hparams.n_swa, hparams.swa_type);
+ }
}
}
 
@@ -1299,12 +1321,10 @@ ggml_tensor * llm_graph_context::build_attn_mha(
k = ggml_permute(ctx0, k, 0, 2, 1, 3);
v = ggml_permute(ctx0, v, 0, 2, 1, 3);
 
- const auto n_kv = k->ne[1];
-
ggml_tensor * cur;
 
// TODO: replace hardcoded padding with ggml-provided padding
- if (cparams.flash_attn && (n_kv % 256 == 0) && kq_b == nullptr) {
+ if (cparams.flash_attn && kq_b == nullptr) {
GGML_ASSERT(kq_b == nullptr && "Flash attention does not support KQ bias yet");
 
if (v_trans) {
@@ -1419,10 +1439,20 @@ llm_graph_input_attn_no_cache * llm_graph_context::build_attn_inp_no_cache() con
auto inp = std::make_unique<llm_graph_input_attn_no_cache>(hparams, cparams);
 
// note: there is no KV cache, so the number of KV values is equal to the number of tokens in the batch
- inp->kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_tokens, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), 1, 1);
- ggml_set_input(inp->kq_mask);
+ inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_tokens, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), 1, 1);
+ ggml_set_input(inp->self_kq_mask);
+
+ inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
 
- inp->kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->kq_mask, GGML_TYPE_F16) : inp->kq_mask;
+ if (hparams.swa_type != LLAMA_SWA_TYPE_NONE) {
+ inp->self_kq_mask_swa = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_tokens, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), 1, 1);
+ ggml_set_input(inp->self_kq_mask_swa);
+
+ inp->self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask_swa, GGML_TYPE_F16) : inp->self_kq_mask_swa;
+ } else {
+ inp->self_kq_mask_swa = nullptr;
+ inp->self_kq_mask_swa_cnv = nullptr;
+ }
 
return (llm_graph_input_attn_no_cache *) res->add_input(std::move(inp));
}
@@ -1447,7 +1477,9 @@ ggml_tensor * llm_graph_context::build_attn(
ggml_build_forward_expand(gf, k_cur);
ggml_build_forward_expand(gf, v_cur);
 
- const auto & kq_mask = inp->get_kq_mask();
+ const bool is_swa = hparams.is_swa(il);
+
+ const auto & kq_mask = is_swa ? inp->get_kq_mask_swa() : inp->get_kq_mask();
 
// [TAG_NO_CACHE_PAD]
// TODO: if ubatch.equal_seqs() == true, we can split the three tensors below into ubatch.n_seqs_unq streams