|
|
@@ -269,12 +269,6 @@ Java_android_llama_cpp_LLamaAndroid_bench_1model(
|
|
|
return env->NewStringUTF(result.str().c_str());
|
|
|
}
|
|
|
|
|
|
-extern "C"
|
|
|
-JNIEXPORT void JNICALL
|
|
|
-Java_android_llama_cpp_LLamaAndroid_free_1batch(JNIEnv *, jobject, jlong batch_pointer) {
|
|
|
- llama_batch_free(*reinterpret_cast<llama_batch *>(batch_pointer));
|
|
|
-}
|
|
|
-
|
|
|
extern "C"
|
|
|
JNIEXPORT jlong JNICALL
|
|
|
Java_android_llama_cpp_LLamaAndroid_new_1batch(JNIEnv *, jobject, jint n_tokens, jint embd, jint n_seq_max) {
|
|
|
@@ -311,6 +305,29 @@ Java_android_llama_cpp_LLamaAndroid_new_1batch(JNIEnv *, jobject, jint n_tokens,
|
|
|
return reinterpret_cast<jlong>(batch);
|
|
|
}
|
|
|
|
|
|
+extern "C" // JNI entry point: releases a batch by calling llama_batch_free on it.
|
|
|
+JNIEXPORT void JNICALL
|
|
|
+Java_android_llama_cpp_LLamaAndroid_free_1batch(JNIEnv *, jobject, jlong batch_pointer) { // batch_pointer: address of a llama_batch carried as a jlong; JNIEnv and jobject are unused
|
|
|
+ llama_batch_free(*reinterpret_cast<llama_batch *>(batch_pointer)); // the pointed-to llama_batch is handed to llama_batch_free; NOTE(review): the llama_batch object itself is not deleted here - confirm against how new_batch allocates it
|
|
|
+}
|
|
|
+
|
|
|
+extern "C" // JNI entry point: builds a sampler chain holding one greedy sampler and returns it as a jlong handle.
|
|
|
+JNIEXPORT jlong JNICALL
|
|
|
+Java_android_llama_cpp_LLamaAndroid_new_1sampler(JNIEnv *, jobject) { // takes no Java-side arguments; JNIEnv and jobject are unused
|
|
|
+ auto sparams = llama_sampler_chain_default_params(); // start from the library's default chain parameters
|
|
|
+ sparams.no_perf = true; // presumably disables performance/timing collection for this chain - TODO confirm against llama.h
|
|
|
+ llama_sampler * smpl = llama_sampler_chain_init(sparams); // NOTE(review): result is not checked for null before use
|
|
|
+ llama_sampler_chain_add(smpl, llama_sampler_init_greedy()); // the only stage added to the chain is the greedy sampler
|
|
|
+
|
|
|
+ return reinterpret_cast<jlong>(smpl); // the pointer crosses JNI as an integer; free_sampler and completion_loop cast it back to llama_sampler *
|
|
|
+}
|
|
|
+
|
|
|
+extern "C" // JNI entry point: frees a sampler identified by its jlong handle.
|
|
|
+JNIEXPORT void JNICALL
|
|
|
+Java_android_llama_cpp_LLamaAndroid_free_1sampler(JNIEnv *, jobject, jlong sampler_pointer) { // sampler_pointer: address of a llama_sampler carried as a jlong; JNIEnv and jobject are unused
|
|
|
+ llama_sampler_free(reinterpret_cast<llama_sampler *>(sampler_pointer)); // the handle is reinterpreted as llama_sampler * and passed straight to llama_sampler_free, with no null check in this function
|
|
|
+}
|
|
|
+
|
|
|
extern "C"
|
|
|
JNIEXPORT void JNICALL
|
|
|
Java_android_llama_cpp_LLamaAndroid_backend_1init(JNIEnv *, jobject) {
|
|
|
@@ -380,14 +397,14 @@ Java_android_llama_cpp_LLamaAndroid_completion_1loop(
|
|
|
JNIEnv * env,
|
|
|
jobject,
|
|
|
jlong context_pointer,
|
|
|
- jlong sampling_pointer,
|
|
|
jlong batch_pointer,
|
|
|
+ jlong sampler_pointer,
|
|
|
jint n_len,
|
|
|
jobject intvar_ncur
|
|
|
) {
|
|
|
const auto context = reinterpret_cast<llama_context *>(context_pointer);
|
|
|
- const auto sampling = reinterpret_cast<llama_sampler *>(sampling_pointer);
|
|
|
- const auto batch = reinterpret_cast<llama_batch *>(batch_pointer);
|
|
|
+ const auto batch = reinterpret_cast<llama_batch *>(batch_pointer);
|
|
|
+ const auto sampler = reinterpret_cast<llama_sampler *>(sampler_pointer);
|
|
|
const auto model = llama_get_model(context);
|
|
|
|
|
|
if (!la_int_var) la_int_var = env->GetObjectClass(intvar_ncur);
|
|
|
@@ -395,9 +412,9 @@ Java_android_llama_cpp_LLamaAndroid_completion_1loop(
|
|
|
if (!la_int_var_inc) la_int_var_inc = env->GetMethodID(la_int_var, "inc", "()V");
|
|
|
|
|
|
// sample the most likely token
|
|
|
- const auto new_token_id = llama_sampler_sample(sampling, context, batch->n_tokens - 1);
|
|
|
+ const auto new_token_id = llama_sampler_sample(sampler, context, -1);
|
|
|
|
|
|
- llama_sampler_accept(sampling, new_token_id);
|
|
|
+ llama_sampler_accept(sampler, new_token_id);
|
|
|
|
|
|
const auto n_cur = env->CallIntMethod(intvar_ncur, la_int_var_value);
|
|
|
if (llama_token_is_eog(model, new_token_id) || n_cur == n_len) {
|