Browse Source

retrieval : use at most n_seq_max chunks (#18400)

Héctor Estrada Moreno 1 month ago
parent
commit
0c8986403b
1 changed files with 3 additions and 3 deletions
  1. 3 3
      examples/retrieval/retrieval.cpp

+ 3 - 3
examples/retrieval/retrieval.cpp

@@ -222,8 +222,8 @@ int main(int argc, char ** argv) {
     float * emb = embeddings.data();
 
     // break into batches
-    int p = 0; // number of prompts processed already
-    int s = 0; // number of prompts in current batch
+    unsigned int p = 0; // number of prompts processed already
+    unsigned int s = 0; // number of prompts in current batch
     for (int k = 0; k < n_chunks; k++) {
         // clamp to n_batch tokens
         auto & inp = chunks[k].tokens;
@@ -231,7 +231,7 @@ int main(int argc, char ** argv) {
         const uint64_t n_toks = inp.size();
 
         // encode if at capacity
-        if (batch.n_tokens + n_toks > n_batch) {
+        if (batch.n_tokens + n_toks > n_batch || s >= llama_n_seq_max(ctx)) {
             float * out = emb + p * n_embd;
             batch_process(ctx, batch, out, s, n_embd);
             common_batch_clear(batch);