|
|
@@ -289,10 +289,10 @@ llama_batch_allocr::llama_batch_allocr() {
|
|
|
const char * LLAMA_BATCH_DEBUG = getenv("LLAMA_BATCH_DEBUG");
|
|
|
debug = LLAMA_BATCH_DEBUG ? atoi(LLAMA_BATCH_DEBUG) : 0;
|
|
|
|
|
|
- seq_pos.resize(LLAMA_MAX_PARALLEL_SEQUENCES);
|
|
|
- seq_cpl.resize(LLAMA_MAX_PARALLEL_SEQUENCES);
|
|
|
+ seq_pos.resize(LLAMA_MAX_SEQ);
|
|
|
+ seq_cpl.resize(LLAMA_MAX_SEQ);
|
|
|
for (auto & cur : seq_cpl) {
|
|
|
- cur.resize(LLAMA_MAX_PARALLEL_SEQUENCES);
|
|
|
+ cur.resize(LLAMA_MAX_SEQ);
|
|
|
}
|
|
|
}
|
|
|
|
|
|
@@ -322,8 +322,8 @@ bool llama_batch_allocr::init(
|
|
|
if (batch.seq_id) {
|
|
|
for (int32_t i = 0; i < batch.n_tokens; ++i) {
|
|
|
for (int32_t s = 0; s < batch.n_seq_id[i]; ++s) {
|
|
|
- if (batch.seq_id && (batch.seq_id[i][s] < 0 || batch.seq_id[i][s] >= LLAMA_MAX_PARALLEL_SEQUENCES)) {
|
|
|
- LLAMA_LOG_ERROR("%s: invalid seq_id[%d][%d] = %d > %d\n", __func__, i, s, batch.seq_id[i][s], LLAMA_MAX_PARALLEL_SEQUENCES);
|
|
|
+ if (batch.seq_id && (batch.seq_id[i][s] < 0 || batch.seq_id[i][s] >= LLAMA_MAX_SEQ)) {
|
|
|
+ LLAMA_LOG_ERROR("%s: invalid seq_id[%d][%d] = %d > %d\n", __func__, i, s, batch.seq_id[i][s], LLAMA_MAX_SEQ);
|
|
|
return false;
|
|
|
}
|
|
|
}
|
|
|
@@ -355,8 +355,8 @@ bool llama_batch_allocr::init(
|
|
|
pos.resize(batch.n_tokens);
|
|
|
|
|
|
// initialize the starting position for each sequence based on the positions in the memory
|
|
|
- llama_pos p0[LLAMA_MAX_PARALLEL_SEQUENCES];
|
|
|
- for (int32_t s = 0; s < LLAMA_MAX_PARALLEL_SEQUENCES; ++s) {
|
|
|
+ llama_pos p0[LLAMA_MAX_SEQ];
|
|
|
+ for (int32_t s = 0; s < LLAMA_MAX_SEQ; ++s) {
|
|
|
if (!memory) {
|
|
|
p0[s] = 0;
|
|
|
} else {
|
|
|
@@ -480,7 +480,7 @@ bool llama_batch_allocr::init(
|
|
|
// consistency checks
|
|
|
//
|
|
|
|
|
|
- for (int32_t s = 0; s < LLAMA_MAX_PARALLEL_SEQUENCES; ++s) {
|
|
|
+ for (int32_t s = 0; s < LLAMA_MAX_SEQ; ++s) {
|
|
|
if (seq_pos[s].empty()) {
|
|
|
continue;
|
|
|
}
|
|
|
@@ -497,8 +497,8 @@ bool llama_batch_allocr::init(
|
|
|
}
|
|
|
|
|
|
if (memory) {
|
|
|
- for (int32_t s0 = 0; s0 < LLAMA_MAX_PARALLEL_SEQUENCES; ++s0) {
|
|
|
- for (int32_t s1 = 0; s1 < LLAMA_MAX_PARALLEL_SEQUENCES; ++s1) {
|
|
|
+ for (int32_t s0 = 0; s0 < LLAMA_MAX_SEQ; ++s0) {
|
|
|
+ for (int32_t s1 = 0; s1 < LLAMA_MAX_SEQ; ++s1) {
|
|
|
if (seq_cpl[s0][s1]) {
|
|
|
if (memory->seq_pos_min(s0) != memory->seq_pos_min(s1) ||
|
|
|
memory->seq_pos_max(s0) != memory->seq_pos_max(s1)) {
|