|
@@ -3,6 +3,7 @@
|
|
|
#include "llama-impl.h"
|
|
#include "llama-impl.h"
|
|
|
#include "llama-cparams.h"
|
|
#include "llama-cparams.h"
|
|
|
#include "llama-vocab.h"
|
|
#include "llama-vocab.h"
|
|
|
|
|
+#include "llama-memory.h"
|
|
|
|
|
|
|
|
#include <cassert>
|
|
#include <cassert>
|
|
|
#include <cstring>
|
|
#include <cstring>
|
|
@@ -287,21 +288,27 @@ llama_sbatch::llama_sbatch(const llama_batch & batch, size_t n_embd, bool simple
|
|
|
llama_batch_allocr::llama_batch_allocr() {
|
|
llama_batch_allocr::llama_batch_allocr() {
|
|
|
const char * LLAMA_BATCH_DEBUG = getenv("LLAMA_BATCH_DEBUG");
|
|
const char * LLAMA_BATCH_DEBUG = getenv("LLAMA_BATCH_DEBUG");
|
|
|
debug = LLAMA_BATCH_DEBUG ? atoi(LLAMA_BATCH_DEBUG) : 0;
|
|
debug = LLAMA_BATCH_DEBUG ? atoi(LLAMA_BATCH_DEBUG) : 0;
|
|
|
|
|
+
|
|
|
|
|
+ seq_pos.resize(LLAMA_MAX_PARALLEL_SEQUENCES);
|
|
|
|
|
+ seq_cpl.resize(LLAMA_MAX_PARALLEL_SEQUENCES);
|
|
|
|
|
+ for (auto & cur : seq_cpl) {
|
|
|
|
|
+ cur.resize(LLAMA_MAX_PARALLEL_SEQUENCES);
|
|
|
|
|
+ }
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
-bool llama_batch_allocr::init(const llama_batch & batch_inp, const llama_vocab & vocab, llama_pos p0) {
|
|
|
|
|
|
|
+bool llama_batch_allocr::init(
|
|
|
|
|
+ const llama_batch & batch_inp,
|
|
|
|
|
+ const llama_vocab & vocab,
|
|
|
|
|
+ const llama_memory_i * memory) {
|
|
|
clear();
|
|
clear();
|
|
|
|
|
|
|
|
batch = batch_inp;
|
|
batch = batch_inp;
|
|
|
|
|
|
|
|
GGML_ASSERT(batch.n_tokens > 0);
|
|
GGML_ASSERT(batch.n_tokens > 0);
|
|
|
|
|
|
|
|
- if (!batch.pos) {
|
|
|
|
|
- if (batch.seq_id) {
|
|
|
|
|
- LLAMA_LOG_ERROR("%s: pos == NULL, but seq_id != NULL\n", __func__);
|
|
|
|
|
- return false;
|
|
|
|
|
- }
|
|
|
|
|
- }
|
|
|
|
|
|
|
+ //
|
|
|
|
|
+ // validate input batch
|
|
|
|
|
+ //
|
|
|
|
|
|
|
|
if (batch.token) {
|
|
if (batch.token) {
|
|
|
for (int32_t i = 0; i < batch.n_tokens; ++i) {
|
|
for (int32_t i = 0; i < batch.n_tokens; ++i) {
|
|
@@ -323,14 +330,9 @@ bool llama_batch_allocr::init(const llama_batch & batch_inp, const llama_vocab &
|
|
|
}
|
|
}
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
- if (!batch.pos) {
|
|
|
|
|
- assert(p0 >= 0);
|
|
|
|
|
- pos.resize(batch.n_tokens);
|
|
|
|
|
- for (int32_t i = 0; i < batch.n_tokens; i++) {
|
|
|
|
|
- pos[i] = p0 + i;
|
|
|
|
|
- }
|
|
|
|
|
- batch.pos = pos.data();
|
|
|
|
|
- }
|
|
|
|
|
|
|
+ //
|
|
|
|
|
+ // auto-generate missing fields
|
|
|
|
|
+ //
|
|
|
|
|
|
|
|
if (!batch.n_seq_id) {
|
|
if (!batch.n_seq_id) {
|
|
|
n_seq_id.resize(batch.n_tokens);
|
|
n_seq_id.resize(batch.n_tokens);
|
|
@@ -349,6 +351,32 @@ bool llama_batch_allocr::init(const llama_batch & batch_inp, const llama_vocab &
|
|
|
batch.seq_id = seq_id.data();
|
|
batch.seq_id = seq_id.data();
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
|
|
+ if (!batch.pos) {
|
|
|
|
|
+ pos.resize(batch.n_tokens);
|
|
|
|
|
+
|
|
|
|
|
+ // initialize the starting position for each sequence based on the positions in the memory
|
|
|
|
|
+ llama_pos p0[LLAMA_MAX_PARALLEL_SEQUENCES];
|
|
|
|
|
+ for (int32_t s = 0; s < LLAMA_MAX_PARALLEL_SEQUENCES; ++s) {
|
|
|
|
|
+ if (!memory) {
|
|
|
|
|
+ p0[s] = 0;
|
|
|
|
|
+ } else {
|
|
|
|
|
+ p0[s] = memory->seq_pos_max(s) + 1;
|
|
|
|
|
+ }
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ for (int32_t i = 0; i < batch.n_tokens; i++) {
|
|
|
|
|
+ const llama_seq_id seq_id = batch.seq_id[i][0];
|
|
|
|
|
+
|
|
|
|
|
+ pos[i] = p0[seq_id];
|
|
|
|
|
+
|
|
|
|
|
+ for (int32_t s = 0; s < batch.n_seq_id[i]; ++s) {
|
|
|
|
|
+ p0[batch.seq_id[i][s]] = pos[i] + 1;
|
|
|
|
|
+ }
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ batch.pos = pos.data();
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
if (!batch.logits) {
|
|
if (!batch.logits) {
|
|
|
// by default return the output only for the last token
|
|
// by default return the output only for the last token
|
|
|
output.resize(batch.n_tokens);
|
|
output.resize(batch.n_tokens);
|
|
@@ -356,13 +384,36 @@ bool llama_batch_allocr::init(const llama_batch & batch_inp, const llama_vocab &
|
|
|
batch.logits = output.data();
|
|
batch.logits = output.data();
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
|
|
+ //
|
|
|
|
|
+ // compute stats
|
|
|
|
|
+ //
|
|
|
|
|
+
|
|
|
for (int32_t i = 0; i < batch.n_tokens; ++i) {
|
|
for (int32_t i = 0; i < batch.n_tokens; ++i) {
|
|
|
n_outputs += batch.logits[i] != 0;
|
|
n_outputs += batch.logits[i] != 0;
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
|
|
+ // determine coupled sequences
|
|
|
|
|
+ // these are pairs of sequences that have at least one token in the input batch that is assigned to both of them
|
|
|
|
|
+ for (int32_t i = 0; i < batch.n_tokens; ++i) {
|
|
|
|
|
+ for (int32_t s = 0; s < batch.n_seq_id[i]; ++s) {
|
|
|
|
|
+ seq_pos[batch.seq_id[i][s]].insert(batch.pos[i]);
|
|
|
|
|
+
|
|
|
|
|
+ if (s > 0) {
|
|
|
|
|
+ const llama_seq_id s0 = batch.seq_id[i][0];
|
|
|
|
|
+ const llama_seq_id s1 = batch.seq_id[i][s];
|
|
|
|
|
+
|
|
|
|
|
+ // mark that sequence s1 is coupled to s0
|
|
|
|
|
+ seq_cpl[s1][s0] = true;
|
|
|
|
|
+
|
|
|
|
|
+ // note: the other way around is not necessary for now
|
|
|
|
|
+ //seq_cpl[s0][s1] = true;
|
|
|
|
|
+ }
|
|
|
|
|
+ }
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
if (debug > 0) {
|
|
if (debug > 0) {
|
|
|
- LLAMA_LOG_DEBUG("%s: input batch info (p0 = %d):\n", __func__, p0);
|
|
|
|
|
- LLAMA_LOG_DEBUG("%s: n_tokens = %d\n", __func__, batch.n_tokens);
|
|
|
|
|
|
|
+ LLAMA_LOG_DEBUG("%s: input batch info:\n", __func__);
|
|
|
|
|
+ LLAMA_LOG_DEBUG("%s: n_tokens = %d\n", __func__, batch.n_tokens);
|
|
|
LLAMA_LOG_DEBUG("%s: token = %p\n", __func__, (void *) batch.token);
|
|
LLAMA_LOG_DEBUG("%s: token = %p\n", __func__, (void *) batch.token);
|
|
|
LLAMA_LOG_DEBUG("%s: embd = %p\n", __func__, (void *) batch.embd);
|
|
LLAMA_LOG_DEBUG("%s: embd = %p\n", __func__, (void *) batch.embd);
|
|
|
LLAMA_LOG_DEBUG("%s: pos = %p\n", __func__, (void *) batch.pos);
|
|
LLAMA_LOG_DEBUG("%s: pos = %p\n", __func__, (void *) batch.pos);
|
|
@@ -404,6 +455,58 @@ bool llama_batch_allocr::init(const llama_batch & batch_inp, const llama_vocab &
|
|
|
batch.pos[i], batch.n_seq_id[i], ss.str().c_str(), batch.logits[i]);
|
|
batch.pos[i], batch.n_seq_id[i], ss.str().c_str(), batch.logits[i]);
|
|
|
}
|
|
}
|
|
|
LLAMA_LOG_DEBUG("%s: ]\n", __func__);
|
|
LLAMA_LOG_DEBUG("%s: ]\n", __func__);
|
|
|
|
|
+
|
|
|
|
|
+ LLAMA_LOG_DEBUG("%s: seq = [\n", __func__);
|
|
|
|
|
+ for (int s0 = 0; s0 < (int) seq_pos.size(); ++s0) {
|
|
|
|
|
+ if (seq_pos[s0].empty()) {
|
|
|
|
|
+ continue;
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ std::stringstream ss;
|
|
|
|
|
+ for (int s1 = 0; s1 < (int) seq_cpl[s0].size(); ++s1) {
|
|
|
|
|
+ if (seq_cpl[s0][s1]) {
|
|
|
|
|
+ ss << s1 << " ";
|
|
|
|
|
+ }
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ LLAMA_LOG_DEBUG("%s: %4d: pos = [%4d, %4d], cpl = %s\n",
|
|
|
|
|
+ __func__, s0, seq_pos_min(s0), seq_pos_max(s0), ss.str().empty() ? "-" : ss.str().c_str());
|
|
|
|
|
+ }
|
|
|
|
|
+ LLAMA_LOG_DEBUG("%s: ]\n", __func__);
|
|
|
|
|
+ }
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ //
|
|
|
|
|
+ // consistency checks
|
|
|
|
|
+ //
|
|
|
|
|
+
|
|
|
|
|
+ for (int32_t s = 0; s < LLAMA_MAX_PARALLEL_SEQUENCES; ++s) {
|
|
|
|
|
+ if (seq_pos[s].empty()) {
|
|
|
|
|
+ continue;
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ if (memory && seq_pos_min(s) != memory->seq_pos_max(s) + 1) {
|
|
|
|
|
+ LLAMA_LOG_ERROR("%s: sequence %d does not start from the last position stored in the memory\n", __func__, s);
|
|
|
|
|
+ return false;
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ if (seq_pos_max(s) - seq_pos_min(s) + 1 > (int) seq_pos[s].size()) {
|
|
|
|
|
+ LLAMA_LOG_ERROR("%s: sequence %d positions are not continuous\n", __func__, s);
|
|
|
|
|
+ return false;
|
|
|
|
|
+ }
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ if (memory) {
|
|
|
|
|
+ for (int32_t s0 = 0; s0 < LLAMA_MAX_PARALLEL_SEQUENCES; ++s0) {
|
|
|
|
|
+ for (int32_t s1 = 0; s1 < LLAMA_MAX_PARALLEL_SEQUENCES; ++s1) {
|
|
|
|
|
+ if (seq_cpl[s0][s1]) {
|
|
|
|
|
+ if (memory->seq_pos_min(s0) != memory->seq_pos_min(s1) ||
|
|
|
|
|
+ memory->seq_pos_max(s0) != memory->seq_pos_max(s1)) {
|
|
|
|
|
+ LLAMA_LOG_ERROR("%s: sequence %d is coupled to %d in the input batch, but have divereged\n", __func__, s0, s1);
|
|
|
|
|
+ return false;
|
|
|
|
|
+ }
|
|
|
|
|
+ }
|
|
|
|
|
+ }
|
|
|
}
|
|
}
|
|
|
}
|
|
}
|
|
|
|
|
|
|
@@ -418,6 +521,14 @@ uint32_t llama_batch_allocr::get_n_outputs() const {
|
|
|
return n_outputs;
|
|
return n_outputs;
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
|
|
+llama_pos llama_batch_allocr::seq_pos_min(llama_seq_id seq_id) const {
|
|
|
|
|
+ return seq_pos[seq_id].empty() ? -1 : *seq_pos[seq_id].begin();
|
|
|
|
|
+}
|
|
|
|
|
+
|
|
|
|
|
+llama_pos llama_batch_allocr::seq_pos_max(llama_seq_id seq_id) const {
|
|
|
|
|
+ return seq_pos[seq_id].empty() ? -1 : *seq_pos[seq_id].rbegin();
|
|
|
|
|
+}
|
|
|
|
|
+
|
|
|
void llama_batch_allocr::clear() {
|
|
void llama_batch_allocr::clear() {
|
|
|
n_outputs = 0;
|
|
n_outputs = 0;
|
|
|
|
|
|
|
@@ -426,6 +537,14 @@ void llama_batch_allocr::clear() {
|
|
|
n_seq_id.clear();
|
|
n_seq_id.clear();
|
|
|
seq_id.clear();
|
|
seq_id.clear();
|
|
|
output.clear();
|
|
output.clear();
|
|
|
|
|
+
|
|
|
|
|
+ for (auto & cur : seq_pos) {
|
|
|
|
|
+ cur.clear();
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ for (auto & cur : seq_cpl) {
|
|
|
|
|
+ std::fill(cur.begin(), cur.end(), false);
|
|
|
|
|
+ }
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
//
|
|
//
|