|
@@ -96,11 +96,9 @@ void llm_graph_input_pos_bucket::set_input(const llama_ubatch * ubatch) {
|
|
|
|
|
|
|
|
int32_t * data = (int32_t *) pos_bucket->data;
|
|
int32_t * data = (int32_t *) pos_bucket->data;
|
|
|
|
|
|
|
|
- for (int h = 0; h < 1; ++h) {
|
|
|
|
|
- for (int j = 0; j < n_tokens; ++j) {
|
|
|
|
|
- for (int i = 0; i < n_tokens; ++i) {
|
|
|
|
|
- data[h*(n_tokens*n_tokens) + j*n_tokens + i] = llama_relative_position_bucket(ubatch->pos[i], ubatch->pos[j], hparams.n_rel_attn_bkts, true);
|
|
|
|
|
- }
|
|
|
|
|
|
|
+ for (int j = 0; j < n_tokens; ++j) {
|
|
|
|
|
+ for (int i = 0; i < n_tokens; ++i) {
|
|
|
|
|
+ data[j*n_tokens + i] = llama_relative_position_bucket(ubatch->pos[i], ubatch->pos[j], hparams.n_rel_attn_bkts, true);
|
|
|
}
|
|
}
|
|
|
}
|
|
}
|
|
|
}
|
|
}
|
|
@@ -323,34 +321,32 @@ void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) {
|
|
|
const int64_t n_tokens = ubatch->n_tokens;
|
|
const int64_t n_tokens = ubatch->n_tokens;
|
|
|
|
|
|
|
|
const auto fill_mask = [&](float * data, int n_swa, llama_swa_type swa_type) {
|
|
const auto fill_mask = [&](float * data, int n_swa, llama_swa_type swa_type) {
|
|
|
- for (int h = 0; h < 1; ++h) {
|
|
|
|
|
- for (int i1 = 0; i1 < n_tokens; ++i1) {
|
|
|
|
|
- const llama_seq_id s1 = ubatch->seq_id[i1][0];
|
|
|
|
|
- const llama_pos p1 = ubatch->pos[i1];
|
|
|
|
|
|
|
+ for (int i1 = 0; i1 < n_tokens; ++i1) {
|
|
|
|
|
+ const llama_seq_id s1 = ubatch->seq_id[i1][0];
|
|
|
|
|
+ const llama_pos p1 = ubatch->pos[i1];
|
|
|
|
|
|
|
|
- const uint64_t idst = h*(n_kv*n_tokens) + i1*n_kv;
|
|
|
|
|
|
|
+ const uint64_t idst = i1*n_kv;
|
|
|
|
|
|
|
|
- for (int i0 = 0; i0 < n_tokens; ++i0) {
|
|
|
|
|
- const llama_seq_id s0 = ubatch->seq_id[i0][0];
|
|
|
|
|
- const llama_pos p0 = ubatch->pos[i0];
|
|
|
|
|
-
|
|
|
|
|
- // mask different sequences
|
|
|
|
|
- if (s0 != s1) {
|
|
|
|
|
- continue;
|
|
|
|
|
- }
|
|
|
|
|
|
|
+ for (int i0 = 0; i0 < n_tokens; ++i0) {
|
|
|
|
|
+ const llama_seq_id s0 = ubatch->seq_id[i0][0];
|
|
|
|
|
+ const llama_pos p0 = ubatch->pos[i0];
|
|
|
|
|
|
|
|
- // mask future tokens
|
|
|
|
|
- if (cparams.causal_attn && p0 > p1) {
|
|
|
|
|
- continue;
|
|
|
|
|
- }
|
|
|
|
|
|
|
+ // mask different sequences
|
|
|
|
|
+ if (s0 != s1) {
|
|
|
|
|
+ continue;
|
|
|
|
|
+ }
|
|
|
|
|
|
|
|
- // apply SWA if any
|
|
|
|
|
- if (llama_hparams::is_masked_swa(n_swa, swa_type, p0, p1)) {
|
|
|
|
|
- continue;
|
|
|
|
|
- }
|
|
|
|
|
|
|
+ // mask future tokens
|
|
|
|
|
+ if (cparams.causal_attn && p0 > p1) {
|
|
|
|
|
+ continue;
|
|
|
|
|
+ }
|
|
|
|
|
|
|
|
- data[idst + i0] = hparams.use_alibi ? -std::abs(p0 - p1) : 0.0f;
|
|
|
|
|
|
|
+ // apply SWA if any
|
|
|
|
|
+ if (llama_hparams::is_masked_swa(n_swa, swa_type, p0, p1)) {
|
|
|
|
|
+ continue;
|
|
|
}
|
|
}
|
|
|
|
|
+
|
|
|
|
|
+ data[idst + i0] = hparams.use_alibi ? -std::abs(p0 - p1) : 0.0f;
|
|
|
}
|
|
}
|
|
|
}
|
|
}
|
|
|
};
|
|
};
|
|
@@ -454,27 +450,19 @@ void llm_graph_input_attn_cross::set_input(const llama_ubatch * ubatch) {
|
|
|
|
|
|
|
|
float * data = (float *) cross_kq_mask->data;
|
|
float * data = (float *) cross_kq_mask->data;
|
|
|
|
|
|
|
|
- for (int h = 0; h < 1; ++h) {
|
|
|
|
|
- for (int i = 0; i < n_tokens; ++i) {
|
|
|
|
|
- for (int j = 0; j < n_enc; ++j) {
|
|
|
|
|
- float f = -INFINITY;
|
|
|
|
|
|
|
+ for (int i = 0; i < n_tokens; ++i) {
|
|
|
|
|
+ for (int j = 0; j < n_enc; ++j) {
|
|
|
|
|
+ float f = -INFINITY;
|
|
|
|
|
|
|
|
- for (int s = 0; s < ubatch->n_seq_id[i]; ++s) {
|
|
|
|
|
- const llama_seq_id seq_id = ubatch->seq_id[i][s];
|
|
|
|
|
|
|
+ for (int s = 0; s < ubatch->n_seq_id[i]; ++s) {
|
|
|
|
|
+ const llama_seq_id seq_id = ubatch->seq_id[i][s];
|
|
|
|
|
|
|
|
- if (cross->seq_ids_enc[j].find(seq_id) != cross->seq_ids_enc[j].end()) {
|
|
|
|
|
- f = 0.0f;
|
|
|
|
|
- }
|
|
|
|
|
|
|
+ if (cross->seq_ids_enc[j].find(seq_id) != cross->seq_ids_enc[j].end()) {
|
|
|
|
|
+ f = 0.0f;
|
|
|
}
|
|
}
|
|
|
-
|
|
|
|
|
- data[h*(n_enc*n_tokens) + i*n_enc + j] = f;
|
|
|
|
|
}
|
|
}
|
|
|
- }
|
|
|
|
|
|
|
|
|
|
- for (int i = n_tokens; i < n_tokens; ++i) {
|
|
|
|
|
- for (int j = 0; j < n_enc; ++j) {
|
|
|
|
|
- data[h*(n_enc*n_tokens) + i*n_enc + j] = -INFINITY;
|
|
|
|
|
- }
|
|
|
|
|
|
|
+ data[i*n_enc + j] = f;
|
|
|
}
|
|
}
|
|
|
}
|
|
}
|
|
|
}
|
|
}
|