|
|
@@ -754,13 +754,13 @@ struct server_context {
|
|
|
default_generation_settings_for_props = get_formated_generation(slots.front());
|
|
|
default_generation_settings_for_props["seed"] = -1;
|
|
|
|
|
|
- // the update_slots() logic will always submit a maximum of n_batch tokens
|
|
|
+ // the update_slots() logic will always submit a maximum of n_batch or n_parallel tokens
|
|
|
// note that n_batch can be > n_ctx (e.g. for non-causal attention models such as BERT where the KV cache is not used)
|
|
|
{
|
|
|
const int32_t n_batch = llama_n_batch(ctx);
|
|
|
|
|
|
// only a single seq_id per token is needed
|
|
|
- batch = llama_batch_init(n_batch, 0, 1);
|
|
|
+ batch = llama_batch_init(std::max(n_batch, params.n_parallel), 0, 1);
|
|
|
}
|
|
|
|
|
|
metrics.init();
|
|
|
@@ -1137,28 +1137,19 @@ struct server_context {
|
|
|
if (!system_prompt.empty()) {
|
|
|
system_tokens = ::llama_tokenize(ctx, system_prompt, true);
|
|
|
|
|
|
- llama_batch_clear(batch);
|
|
|
+ const int32_t n_batch = llama_n_batch(ctx);
|
|
|
+ const int32_t n_tokens_prompt = system_tokens.size();
|
|
|
|
|
|
- for (int i = 0; i < (int)system_tokens.size(); ++i) {
|
|
|
- llama_batch_add(batch, system_tokens[i], i, { 0 }, false);
|
|
|
- }
|
|
|
+ for (int32_t i = 0; i < n_tokens_prompt; i += n_batch) {
|
|
|
+ const int32_t n_tokens = std::min(n_batch, n_tokens_prompt - i);
|
|
|
|
|
|
- const int32_t n_batch = llama_n_batch(ctx);
|
|
|
+ llama_batch_clear(batch);
|
|
|
|
|
|
- for (int32_t i = 0; i < batch.n_tokens; i += n_batch) {
|
|
|
- const int32_t n_tokens = std::min(params.n_batch, batch.n_tokens - i);
|
|
|
- llama_batch batch_view = {
|
|
|
- n_tokens,
|
|
|
- batch.token + i,
|
|
|
- nullptr,
|
|
|
- batch.pos + i,
|
|
|
- batch.n_seq_id + i,
|
|
|
- batch.seq_id + i,
|
|
|
- batch.logits + i,
|
|
|
- 0, 0, 0, // unused
|
|
|
- };
|
|
|
+ for (int32_t j = 0; j < n_tokens; ++j) {
|
|
|
+ llama_batch_add(batch, system_tokens[i + j], i + j, { 0 }, false);
|
|
|
+ }
|
|
|
|
|
|
- if (llama_decode(ctx, batch_view) != 0) {
|
|
|
+ if (llama_decode(ctx, batch) != 0) {
|
|
|
LOG_ERROR("llama_decode() failed", {});
|
|
|
return;
|
|
|
}
|