@@ -2949,9 +2949,6 @@ struct llama_sbatch_seq {
     llama_seq_id * seq_id;
     size_t offset;
     size_t length;
-
-    // helper for smoother batch API transition -- can be deprecated in the future
-    llama_seq_id all_seq_id; // used if seq_id == NULL
 };

 // sequence-length-aware batch splitting
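For reference, the struct that remains after this hunk only describes a run of tokens sharing the same sequence ids. A sketch of the resulting definition, reconstructed from the surrounding hunks (the field order matches the four-element {n_seqs, seq_ids, i, 1} initializer further down):

    struct llama_sbatch_seq {
        int32_t        n_seq_id; // number of sequence ids for this run of tokens
        llama_seq_id * seq_id;   // the sequence ids (may be null for a simple split)
        size_t         offset;   // first token of the run, in sorted-id order
        size_t         length;   // number of tokens in the run
    };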
@@ -3046,30 +3043,18 @@ struct llama_sbatch {
         } else {
             ubatch.embd = nullptr;
         }
-        // from here on, the else branches are deprecated;
-        // they are helpers for smoother batch API transition
-        if (batch->pos) {
-            if (ubatch.equal_seqs) {
-                for (size_t i = 0; i < length; ++i) {
-                    ubatch.pos[ubatch.n_tokens + i] = batch->pos[ids[seq.offset + i]];
-                }
-            } else {
-                // simple split
-                ubatch.pos = batch->pos + seq.offset;
-            }
-        } else {
+        if (ubatch.equal_seqs) {
             for (size_t i = 0; i < length; ++i) {
-                llama_pos bi = ids[seq.offset + i];
-                ubatch.pos[ubatch.n_tokens + i] = batch->all_pos_0 + (bi * batch->all_pos_1);
+                ubatch.pos[ubatch.n_tokens + i] = batch->pos[ids[seq.offset + i]];
             }
+        } else {
+            // simple split
+            ubatch.pos = batch->pos + seq.offset;
         }
         if (ubatch.equal_seqs) {
             ubatch.n_seq_id[ubatch.n_seqs] = seq.n_seq_id;
             if (seq.seq_id) {
                 ubatch.seq_id[ubatch.n_seqs] = seq.seq_id;
-            } else {
-                GGML_ASSERT(seq.n_seq_id == 1);
-                ubatch.seq_id[ubatch.n_seqs] = &seq.all_seq_id;
             }
         } else {
             // simple split
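The removed else branch synthesized positions as all_pos_0 + i * all_pos_1 when batch->pos was null; after this change the split code always reads batch->pos, and the null case is handled earlier by llama_batch_allocr (last hunk below), which materializes the default positions up front, roughly:

    // default filled by llama_batch_allocr when batch.pos == nullptr:
    // consecutive positions (stride 1) starting right after the last
    // position of sequence 0 already present in the KV cache
    for (int32_t i = 0; i < batch.n_tokens; i++) {
        pos[i] = last_pos + i;
    }

Note that a position stride other than 1 (all_pos_1 != 1) is no longer expressible through defaults; callers that need it must fill batch.pos themselves.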
@@ -3082,10 +3067,6 @@ struct llama_sbatch {
             }
             if (batch->seq_id) {
                 ubatch.seq_id = batch->seq_id + seq.offset;
-            } else {
-                for (size_t i = 0; i < length; ++i) {
-                    ubatch.seq_id[ubatch.n_seqs + i] = &seq.all_seq_id;
-                }
             }
         }
         if (logits_all) {
@@ -3204,7 +3185,6 @@ struct llama_sbatch {
             s.seq_id = nullptr;
             s.offset = 0;
             s.length = n_tokens;
-            s.all_seq_id = batch.all_seq_id;
             return;
         }
         std::sort(ids.begin(), ids.end(),
@@ -3227,7 +3207,7 @@ struct llama_sbatch {
                     if (batch.pos) {
                         return batch.pos[a] < batch.pos[b];
                     }
-                    // no pos, sort by id (assuming batch.all_pos_1 is positive)
+                    // no pos, sort by id
                     return a < b;
                 }
                 // shared prompts go first
@@ -3237,30 +3217,25 @@ struct llama_sbatch {
         // init seq
         llama_sbatch_seq * last_seq = nullptr;

-        if (batch.n_seq_id != nullptr && batch.seq_id != nullptr) {
-            for (size_t i = 0; i < n_tokens; ++i) {
-                const size_t bi = ids[i];
-                const int32_t n_seqs = batch.n_seq_id[bi];
-                llama_seq_id * seq_ids = batch.seq_id[bi];
-                if (last_seq != nullptr) {
-                    bool same = n_seqs == last_seq->n_seq_id;
-                    for (int32_t j = 0; same && j < n_seqs; ++j) {
-                        if (seq_ids[j] != last_seq->seq_id[j]) {
-                            same = false;
-                        }
-                    }
-                    if (same) {
-                        last_seq->length += 1;
-                        continue;
+        for (size_t i = 0; i < n_tokens; ++i) {
+            const size_t bi = ids[i];
+            const int32_t n_seqs = batch.n_seq_id[bi];
+            llama_seq_id * seq_ids = batch.seq_id[bi];
+            if (last_seq != nullptr) {
+                bool same = n_seqs == last_seq->n_seq_id;
+                for (int32_t j = 0; same && j < n_seqs; ++j) {
+                    if (seq_ids[j] != last_seq->seq_id[j]) {
+                        same = false;
                     }
                 }
-                llama_sbatch_seq new_seq = {n_seqs, seq_ids, i, 1, batch.all_seq_id};
-                seq.push_back(new_seq);
-                last_seq = &seq.back();
+                if (same) {
+                    last_seq->length += 1;
+                    continue;
+                }
             }
-        } else {
-            llama_sbatch_seq new_seq = {1, nullptr, 0, n_tokens, batch.all_seq_id};
+            llama_sbatch_seq new_seq = {n_seqs, seq_ids, i, 1};
             seq.push_back(new_seq);
+            last_seq = &seq.back();
         }
         // keep shared prompts first at the end, then sort by length descending.
         std::sort(seq.begin(), seq.end(),
@@ -21096,9 +21071,7 @@ void llama_set_causal_attn(struct llama_context * ctx, bool causal_attn) {

 struct llama_batch llama_batch_get_one(
         llama_token * tokens,
-        int32_t n_tokens,
-        llama_pos pos_0,
-        llama_seq_id seq_id) {
+        int32_t n_tokens) {
     return {
         /*n_tokens =*/ n_tokens,
         /*tokens =*/ tokens,
@@ -21107,9 +21080,6 @@ struct llama_batch llama_batch_get_one(
         /*n_seq_id =*/ nullptr,
         /*seq_id =*/ nullptr,
         /*logits =*/ nullptr,
-        /*all_pos_0 =*/ pos_0,
-        /*all_pos_1 =*/ 1,
-        /*all_seq_id =*/ seq_id,
     };
 }

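Taken together, these two hunks shrink the public helper from llama_batch_get_one(tokens, n_tokens, pos_0, seq_id) to llama_batch_get_one(tokens, n_tokens). A minimal sketch of the caller-side migration, assuming the usual single-sequence case (the n_past counter is illustrative):

    // before: starting position and sequence id passed explicitly
    // llama_batch batch = llama_batch_get_one(tokens, n_tokens, n_past, 0);

    // after: only the tokens are passed; pos, n_seq_id, seq_id and logits stay
    // null and are filled in by llama_batch_allocr inside llama_decode/llama_encode
    llama_batch batch = llama_batch_get_one(tokens, n_tokens);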
@@ -21122,9 +21092,6 @@ struct llama_batch llama_batch_init(int32_t n_tokens_alloc, int32_t embd, int32_
         /*n_seq_id =*/ nullptr,
         /*seq_id =*/ nullptr,
         /*logits =*/ nullptr,
-        /*all_pos_0 =*/ 0,
-        /*all_pos_1 =*/ 0,
-        /*all_seq_id =*/ 0,
     };

     if (embd) {
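With the all_pos_0/all_pos_1/all_seq_id defaults gone from llama_batch, callers that need explicit positions or sequence ids fill the per-token arrays themselves. A rough sketch, assuming a single sequence 0 starting at position p0 and the three-argument llama_batch_init shown above (token capacity, embedding size, sequence capacity per token); the helper name is illustrative:

    #include "llama.h"

    // build an explicit batch for one sequence starting at position p0;
    // release the result with llama_batch_free() when done
    static llama_batch make_seq_batch(const llama_token * tokens, int32_t n_tokens, llama_pos p0) {
        llama_batch batch = llama_batch_init(n_tokens, /*embd =*/ 0, /*n_seq_max =*/ 1);
        for (int32_t i = 0; i < n_tokens; i++) {
            batch.token   [i]    = tokens[i];
            batch.pos     [i]    = p0 + i;
            batch.n_seq_id[i]    = 1;
            batch.seq_id  [i][0] = 0;
            batch.logits  [i]    = (i == n_tokens - 1); // request logits only for the last token
        }
        batch.n_tokens = n_tokens;
        return batch;
    }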
@@ -21160,11 +21127,62 @@ void llama_batch_free(struct llama_batch batch) {
     if (batch.logits) free(batch.logits);
 }

+// temporary allocate memory for the input batch if needed
+static const llama_seq_id batch_default_seq_id = 0;
+struct llama_batch_allocr {
+    std::array<llama_seq_id, 1> seq_id_0 = {batch_default_seq_id};
+    std::vector<llama_pos> pos;
+    std::vector<int32_t> n_seq_id;
+    std::vector<llama_seq_id *> seq_id;
+    std::vector<int8_t> logits;
+    struct llama_batch batch;
+    // optionally fulfill the batch returned by llama_batch_get_one
+    llama_batch_allocr(struct llama_context * ctx, struct llama_batch in_batch) {
+        batch = in_batch;
+        if (!batch.pos) {
+            // determine the last position in KV cache
+            llama_pos last_pos = -1;
+            for (const auto & cell : ctx->kv_self.cells) {
+                if (cell.has_seq_id(batch_default_seq_id)) {
+                    last_pos = std::max(last_pos, cell.pos);
+                }
+            }
+            last_pos++; // next position
+            pos.resize(batch.n_tokens);
+            for (int32_t i = 0; i < batch.n_tokens; i++) {
+                pos[i] = i+last_pos;
+            }
+            batch.pos = pos.data();
+        }
+        if (!batch.n_seq_id) {
+            n_seq_id.resize(batch.n_tokens);
+            for (int32_t i = 0; i < batch.n_tokens; i++) {
+                n_seq_id[i] = seq_id_0.size();
+            }
+            batch.n_seq_id = n_seq_id.data();
+        }
+        if (!batch.seq_id) {
+            seq_id.resize(batch.n_tokens + 1);
+            seq_id[batch.n_tokens] = NULL;
+            for (int32_t i = 0; i < batch.n_tokens; i++) {
+                seq_id[i] = seq_id_0.data();
+            }
+            batch.seq_id = seq_id.data();
+        }
+        if (!batch.logits) {
+            logits.resize(batch.n_tokens);
+            logits[logits.size() - 1] = true;
+            batch.logits = logits.data();
+        }
+    }
+};
+
 int32_t llama_encode(
         struct llama_context * ctx,
         struct llama_batch batch) {
-    const int ret = llama_encode_internal(*ctx, batch);
-    if (ret < 0) {
+    llama_batch_allocr batch_allocr(ctx, batch);
+    const int ret = llama_encode_internal(*ctx, batch_allocr.batch);
+    if (ret != 0) {
         LLAMA_LOG_ERROR("%s: failed to encode, ret = %d\n", __func__, ret);
     }

@@ -21174,8 +21192,9 @@ int32_t llama_encode(
 int32_t llama_decode(
         struct llama_context * ctx,
         struct llama_batch batch) {
-    const int ret = llama_decode_internal(*ctx, batch);
-    if (ret < 0) {
+    llama_batch_allocr batch_allocr(ctx, batch);
+    const int ret = llama_decode_internal(*ctx, batch_allocr.batch);
+    if (ret != 0) {
         LLAMA_LOG_ERROR("%s: failed to decode, ret = %d\n", __func__, ret);
     }

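End-to-end, llama_batch_allocr means a batch that carries only tokens still decodes: missing positions continue from the last position of sequence 0 in the KV cache, every token is assigned to sequence 0, and logits are requested only for the last token. A minimal sketch of a decode step under those defaults (ctx and new_token_id are assumed to exist; the error check mirrors the ret != 0 handling above):

    #include "llama.h"
    #include <cstdio>

    static bool decode_one(llama_context * ctx, llama_token new_token_id) {
        // positions, sequence ids and logits flags are defaulted by
        // llama_batch_allocr inside llama_decode
        llama_batch batch = llama_batch_get_one(&new_token_id, 1);
        if (llama_decode(ctx, batch) != 0) {
            fprintf(stderr, "decode_one: llama_decode failed\n");
            return false;
        }
        return true;
    }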