|
|
@@ -222,8 +222,8 @@ int main(int argc, char ** argv) {
|
|
|
float * emb = embeddings.data();
|
|
|
|
|
|
// break into batches
|
|
|
- int p = 0; // number of prompts processed already
|
|
|
- int s = 0; // number of prompts in current batch
|
|
|
+ unsigned int p = 0; // number of prompts processed already
|
|
|
+ unsigned int s = 0; // number of prompts in current batch
|
|
|
for (int k = 0; k < n_chunks; k++) {
|
|
|
// clamp to n_batch tokens
|
|
|
auto & inp = chunks[k].tokens;
|
|
|
@@ -231,7 +231,7 @@ int main(int argc, char ** argv) {
|
|
|
const uint64_t n_toks = inp.size();
|
|
|
|
|
|
// encode if at capacity
|
|
|
- if (batch.n_tokens + n_toks > n_batch) {
|
|
|
+ if (batch.n_tokens + n_toks > n_batch || s >= llama_n_seq_max(ctx)) {
|
|
|
float * out = emb + p * n_embd;
|
|
|
batch_process(ctx, batch, out, s, n_embd);
|
|
|
common_batch_clear(batch);
|