|
|
@@ -73,8 +73,18 @@ int main(int argc, char ** argv, char ** envp) {
|
|
|
return 1;
|
|
|
}
|
|
|
|
|
|
+ // validate batch size for embeddings
|
|
|
+ // embeddings require all tokens to be processed in a single ubatch
|
|
|
+ // see https://github.com/ggml-org/llama.cpp/issues/12836
|
|
|
+ if (params.embedding && params.n_batch > params.n_ubatch) {
|
|
|
+ LOG_WRN("%s: embeddings enabled with n_batch (%d) > n_ubatch (%d)\n", __func__, params.n_batch, params.n_ubatch);
|
|
|
+ LOG_WRN("%s: setting n_batch = n_ubatch = %d to avoid assertion failure\n", __func__, params.n_ubatch);
|
|
|
+ params.n_batch = params.n_ubatch;
|
|
|
+ }
|
|
|
+
|
|
|
if (params.n_parallel < 0) {
|
|
|
LOG_INF("%s: n_parallel is set to auto, using n_parallel = 4 and kv_unified = true\n", __func__);
|
|
|
+
|
|
|
params.n_parallel = 4;
|
|
|
params.kv_unified = true;
|
|
|
}
|