@@ -95,8 +95,13 @@ int main(int argc, char ** argv) {
         params.n_batch = params.n_ctx;
     }
 
-    // For non-causal models, batch size must be equal to ubatch size
-    params.n_ubatch = params.n_batch;
+    // for non-causal models, batch size must be equal to ubatch size
+    if (params.attention_type != LLAMA_ATTENTION_TYPE_CAUSAL) {
+        params.n_ubatch = params.n_batch;
+    }
+
+    // get max number of sequences per batch
+    const int n_seq_max = llama_max_parallel_sequences();
 
     llama_backend_init();
     llama_numa_init(params.numa);
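
For non-causal (embedding / rerank) models the whole logical batch has to fit into a single micro-batch, so the guard above only forces `n_ubatch = n_batch` when the attention type is not causal; the new `n_seq_max` value, queried once via `llama_max_parallel_sequences()`, caps how many distinct sequences one batch may carry and feeds the flush condition in the last hunk below. A minimal sketch of the same sizing expressed directly on `llama_context_params` (the 2048 figure is illustrative, not taken from this change):

```cpp
#include "llama.h"

// sketch: context parameters for a non-causal embedding/rerank model,
// assuming the whole logical batch should be processed as one micro-batch
static llama_context_params make_embd_ctx_params() {
    llama_context_params cparams = llama_context_default_params();

    cparams.n_ctx      = 2048;            // illustrative size
    cparams.n_batch    = 2048;            // logical batch (tokens per encode call)
    cparams.n_ubatch   = cparams.n_batch; // non-causal: ubatch must cover the batch
    cparams.embeddings = true;            // request embedding output

    return cparams;
}
```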
@@ -144,6 +149,7 @@ int main(int argc, char ** argv) {
     // get added sep and eos token, if any
     const std::string added_sep_token = llama_vocab_get_add_sep(vocab) ? llama_vocab_get_text(vocab, llama_vocab_sep(vocab)) : "";
     const std::string added_eos_token = llama_vocab_get_add_eos(vocab) ? llama_vocab_get_text(vocab, llama_vocab_eos(vocab)) : "";
+    const char * rerank_prompt = llama_model_chat_template(model, "rerank");
 
     // tokenize the prompts and trim
     std::vector<std::vector<int32_t>> inputs;
@@ -153,21 +159,28 @@ int main(int argc, char ** argv) {
         // split classification pairs and insert expected separator tokens
         if (pooling_type == LLAMA_POOLING_TYPE_RANK && prompt.find(params.cls_sep) != std::string::npos) {
             std::vector<std::string> pairs = split_lines(prompt, params.cls_sep);
-            std::string final_prompt;
-
-            for (size_t i = 0; i < pairs.size(); i++) {
-                final_prompt += pairs[i];
-                if (i != pairs.size() - 1) {
-                    if (!added_eos_token.empty()) {
-                        final_prompt += added_eos_token;
-                    }
-                    if (!added_sep_token.empty()) {
-                        final_prompt += added_sep_token;
+            if (rerank_prompt != nullptr) {
+                const std::string query = pairs[0];
+                const std::string doc   = pairs[1];
+                std::string final_prompt = rerank_prompt;
+                string_replace_all(final_prompt, "{query}"   , query);
+                string_replace_all(final_prompt, "{document}", doc  );
+                inp = common_tokenize(vocab, final_prompt, true, true);
+            } else {
+                std::string final_prompt;
+                for (size_t i = 0; i < pairs.size(); i++) {
+                    final_prompt += pairs[i];
+                    if (i != pairs.size() - 1) {
+                        if (!added_eos_token.empty()) {
+                            final_prompt += added_eos_token;
+                        }
+                        if (!added_sep_token.empty()) {
+                            final_prompt += added_sep_token;
+                        }
                     }
                 }
+                inp = common_tokenize(ctx, final_prompt, true, true);
             }
-
-            inp = common_tokenize(ctx, final_prompt, true, true);
         } else {
             inp = common_tokenize(ctx, prompt, true, true);
         }
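
The new `rerank_prompt` path uses the model's built-in `rerank` chat template, fetched above with `llama_model_chat_template(model, "rerank")` (nullptr when the GGUF ships none), and splices the query/document pair into its `{query}` and `{document}` placeholders instead of joining the pair with EOS/SEP tokens. A self-contained sketch of that substitution; the template string and the replace helper are illustrative stand-ins for the model metadata and llama.cpp's `string_replace_all()`:

```cpp
#include <cstdio>
#include <string>

// stand-in for string_replace_all(): replace every occurrence of `search` in `s` with `repl`
static void replace_all(std::string & s, const std::string & search, const std::string & repl) {
    for (size_t pos = 0; (pos = s.find(search, pos)) != std::string::npos; pos += repl.size()) {
        s.replace(pos, search.size(), repl);
    }
}

int main() {
    // illustrative template only - the real one comes from the model's GGUF metadata
    std::string prompt = "query: {query}\ndocument: {document}\nrelevance:";

    replace_all(prompt, "{query}",    "what is panda?");
    replace_all(prompt, "{document}", "The giant panda is a bear species endemic to China.");

    // this is the string that would be handed to common_tokenize()
    printf("%s\n", prompt.c_str());
    return 0;
}
```

Note that the rerank path tokenizes the filled-in template with the vocab overload of `common_tokenize()`, while the EOS/SEP fallback keeps using the context overload.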
@@ -229,7 +242,7 @@ int main(int argc, char ** argv) {
         const uint64_t n_toks = inp.size();
 
         // encode if at capacity
-        if (batch.n_tokens + n_toks > n_batch) {
+        if (batch.n_tokens + n_toks > n_batch || s >= n_seq_max) {
             float * out = emb + e * n_embd;
             batch_decode(ctx, batch, out, s, n_embd, params.embd_normalize);
             e += pooling_type == LLAMA_POOLING_TYPE_NONE ? batch.n_tokens : s;
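
With the extra `s >= n_seq_max` clause, the pending batch is flushed not only when the next prompt would overflow the token budget but also once it already holds the maximum number of parallel sequences. A small self-contained simulation of that packing rule; the prompt lengths and limits are made up, and plain counters stand in for `batch.n_tokens` and the real decode call:

```cpp
#include <cstdio>
#include <vector>

int main() {
    // illustrative limits - in the example they come from params.n_batch
    // and llama_max_parallel_sequences()
    const int n_batch   = 32; // token budget per decode call
    const int n_seq_max = 4;  // sequence budget per decode call

    const std::vector<int> prompt_lens = { 10, 7, 9, 3, 5, 12, 4, 8 };

    int n_tokens = 0; // tokens queued in the pending batch
    int s        = 0; // sequences queued in the pending batch

    for (size_t k = 0; k < prompt_lens.size(); k++) {
        const int n_toks = prompt_lens[k];

        // flush if adding this prompt would overflow the token budget
        // or if the batch already holds n_seq_max sequences
        if (n_tokens + n_toks > n_batch || s >= n_seq_max) {
            printf("flush: %d tokens, %d sequences\n", n_tokens, s);
            n_tokens = 0;
            s        = 0;
        }

        n_tokens += n_toks;
        s        += 1;
    }

    // final flush for whatever is left
    printf("flush: %d tokens, %d sequences\n", n_tokens, s);
    return 0;
}
```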