@@ -117,7 +117,7 @@ struct slot_params {
int32_t n_keep = 0; // number of tokens to keep from initial prompt
int32_t n_discard = 0; // number of tokens after n_keep that may be discarded when shifting context, 0 defaults to half
int32_t n_predict = -1; // new tokens to predict
- int32_t n_indent = 0; // mininum line indentation for the generated text in number of whitespace characters
+ int32_t n_indent = 0; // minimum line indentation for the generated text in number of whitespace characters

int64_t t_max_prompt_ms = -1; // TODO: implement
int64_t t_max_predict_ms = -1; // if positive, limit the generation phase to this time limit
@@ -1382,6 +1382,7 @@ struct server_slot {
common_speculative * spec = nullptr;

std::vector<common_adapter_lora_info> lora;
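+ // index into the prompt where the alora invocation sequence starts; -1 when no alora invocation has been found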
+ int32_t alora_invocation_start = -1;

// the index relative to completion multi-task request
size_t index = 0;
@@ -1476,6 +1477,9 @@ struct server_slot {
// clear speculative decoding stats
n_draft_total = 0;
n_draft_accepted = 0;
+
+ // clear alora start
+ alora_invocation_start = -1;
}

bool need_embd() const {
@@ -2367,11 +2371,65 @@ struct server_context {
slot.prompt_tokens = std::move(task.prompt_tokens);

if (!are_lora_equal(slot.params.lora, slot.lora)) {
- // if lora is changed, we cannot reuse cached tokens
- slot.cache_tokens.clear();
+ // if lora has changed, check to see if the cache should be cleared
+ if (lora_should_clear_cache(slot.lora, slot.params.lora)) {
+ SLT_INF(slot, "clearing cache for lora change. %zu loras -> %zu loras\n", slot.lora.size(), slot.params.lora.size());
+ slot.cache_tokens.clear();
+ } else {
+ SLT_INF(slot, "keeping cache for alora. %zu target loras\n", slot.params.lora.size());
+ }
slot.lora = slot.params.lora;
}

+ // if using alora, make sure it's only a single one requested and active
+ size_t alora_invocation_start = slot.prompt_tokens.size();
+ if (lora_all_alora(slot.lora)) {
+
+ const auto & enabled_ids = lora_get_enabled_ids(slot.lora);
+ // TODO: This will error out if a user requests two aloras, but only
+ // provides the activation string for one. We could instead search
+ // for all requested alora activation strings and then either keep
+ // only the last one, or reject if multiple are found.
+ if (enabled_ids.size() != 1) {
+ send_error(task, "Cannot run multiple aLoRAs in a single request", ERROR_TYPE_INVALID_REQUEST);
+ return false;
+ }
+ const auto & lora = slot.lora[enabled_ids[0]].ptr;
+
+ // get the pointer and count for the invocation tokens
+ const uint64_t n_invocation_tokens = llama_adapter_get_alora_n_invocation_tokens(lora);
+ const llama_token * invocation_tokens = llama_adapter_get_alora_invocation_tokens (lora);
+
+ // scan backwards through the prompt tokens to find the last
+ // occurrence of the invocation sequence
+ int match_idx = static_cast<int>(n_invocation_tokens) - 1;
+ for (int i = slot.prompt_tokens.size() - 1; i >= 0; --i) {
+ // the token in this position matches the next token to find in
+ // the invocation sequence
+ if (slot.prompt_tokens[i] == invocation_tokens[match_idx]) {
+ // if it's a full match, we've found the start
+ if (match_idx == 0) {
+ alora_invocation_start = i;
+ break;
+ }
+ // otherwise, check the next token in the sequence
+ --match_idx;
+ } else {
+ // no match in this position, so start looking over again
+ match_idx = static_cast<int>(n_invocation_tokens) - 1;
+ }
+ }
+
+ // if the activation string is not found, disable the alora
+ if (alora_invocation_start == slot.prompt_tokens.size()) {
+ SLT_DBG(slot, "alora %zu requested, but not found. deactivating\n", enabled_ids[0]);
+ slot.lora[enabled_ids[0]].scale = 0.0f;
+ } else {
+ SLT_DBG(slot, "alora %zu activated starting at %zu\n", enabled_ids[0], alora_invocation_start);
+ slot.alora_invocation_start = alora_invocation_start;
+ }
+ }
+
if (!slot.prompt_tokens.validate(ctx)) {
send_error(task, "Prompt contains invalid tokens", ERROR_TYPE_INVALID_REQUEST);
return false;
@@ -3247,6 +3305,8 @@ struct server_context {
int32_t n_ubatch = llama_n_ubatch(ctx);

// next, batch any pending prompts without exceeding n_batch
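+ // when an alora has uncached tokens before its invocation sequence, the adapter is temporarily
+ // disabled while those tokens are batched; these record its original scale and index so it can be restored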
+ float alora_scale = -1.0f;
+ size_t alora_disabled_id = 0;
if (params_base.cont_batching || batch.n_tokens == 0) {
for (auto & slot : slots) {
// check if we can batch this slot with the previous one
@@ -3367,6 +3427,12 @@ struct server_context {
// reuse any previously computed tokens that are common with the new prompt
slot.n_past = slot.cache_tokens.get_common_prefix(prompt_tokens);

+ // if there is an alora invoked, don't cache after the invocation start
+ if (slot.alora_invocation_start >= 0) {
+ SLT_DBG(slot, "only caching to alora invocation start (n_past=%d, alora_invocation_start=%d)\n", slot.n_past, slot.alora_invocation_start);
+ slot.n_past = std::min(slot.n_past, slot.alora_invocation_start - 1);
+ }
+
// reuse chunks from the cached prompt by shifting their KV cache in the new position
if (params_base.n_cache_reuse > 0) {
size_t head_c = slot.n_past; // cache
@@ -3539,6 +3605,20 @@ struct server_context {
slot.n_prompt_tokens_processed += n_pos;
}

+ // If using an alora, there may be uncached tokens that come
+ // before the invocation sequence. When this happens, the
+ // tokens before the invocation sequence need to be
+ // processed without the adapter in a separate batch, then
+ // the adapter needs to be enabled for the remaining tokens.
+ if (lora_all_alora(slot.lora) && slot.alora_invocation_start - 1 > slot.n_past) {
+ SLT_DBG(slot, "processing pre-alora tokens without the adapter (n_past = %d, alora_invocation_start = %d)\n", slot.n_past, slot.alora_invocation_start);
+ const auto & enabled_loras = lora_get_enabled_ids(slot.lora);
+ GGML_ASSERT(enabled_loras.size() == 1);
+ alora_scale = slot.lora[enabled_loras[0]].scale;
+ slot.lora[enabled_loras[0]].scale = 0.0f;
+ alora_disabled_id = enabled_loras[0];
+ }
+
// add prompt tokens for processing in the current batch
while (slot.n_past < slot.n_prompt_tokens && batch.n_tokens < n_batch) {
// get next token to process
@@ -3547,6 +3627,14 @@ struct server_context {
break; // end of text chunk
}

+ // if this is an alora request with pre-invocation
+ // tokens that are not cached, we need to stop filling
+ // this batch at those pre-invocation tokens.
+ if (alora_scale > 0 && slot.n_past == slot.alora_invocation_start - 1) {
+ SLT_DBG(slot, "stop prompt batch filling at (n_past = %d, alora_invocation_start = %d)\n", slot.n_past, slot.alora_invocation_start);
+ break;
+ }
+
// embedding requires all tokens in the batch to be output
const bool need_embd = server_task_type_need_embd(slot.task_type);
@@ -3605,6 +3693,13 @@ struct server_context {
// apply lora, only need to do it once per batch
common_set_adapter_lora(ctx, slot_batched->lora);

+ // if the lora is temporarily disabled for an alora, re-enable it
+ // for next time
+ if (alora_scale > 0.0f) {
+ SRV_DBG("re-enabling alora with scale %f\n", alora_scale);
+ slot_batched->lora[alora_disabled_id].scale = alora_scale;
+ }
+
llama_set_embeddings(ctx, slot_batched->need_embd());
}
@@ -4990,13 +5085,26 @@ int main(int argc, char ** argv) {
const auto & loras = ctx_server.params_base.lora_adapters;
for (size_t i = 0; i < loras.size(); ++i) {
auto & lora = loras[i];
- result.push_back({
+ json entry = {
{"id", i},
{"path", lora.path},
{"scale", lora.scale},
{"task_name", lora.task_name},
{"prompt_prefix", lora.prompt_prefix},
- });
+ };
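+ // also expose the alora invocation sequence (as a string and as tokens) for adapters that define one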
+ std::string alora_invocation_string = "";
+ const uint64_t n_alora_tokens = llama_adapter_get_alora_n_invocation_tokens(lora.ptr);
+ std::vector<llama_token> alora_invocation_tokens;
+ if (n_alora_tokens) {
+ const llama_token * alora_tokens = llama_adapter_get_alora_invocation_tokens(lora.ptr);
+ for (uint64_t i = 0; i < n_alora_tokens; ++i) {
+ alora_invocation_string += common_token_to_piece(ctx_server.ctx, alora_tokens[i]);
+ alora_invocation_tokens.push_back(alora_tokens[i]);
+ }
+ entry["alora_invocation_string"] = alora_invocation_string;
+ entry["alora_invocation_tokens"] = alora_invocation_tokens;
+ }
+ result.push_back(std::move(entry));
}
res_ok(res, result);
res.status = 200; // HTTP OK