| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452 |
#include <android/log.h>
#include <jni.h>
#include <math.h>
#include <unistd.h>

#include <algorithm>
#include <cstdlib>
#include <iomanip>
#include <sstream>
#include <string>

#include "llama.h"
#include "common.h"
// JNI bindings between the Java/Kotlin class android.llama.cpp.LLamaAndroid and llama.cpp.
//
// The shared library must be loaded before any of these native methods are called.
//
// For instance,
//
// In MainActivity.java:
//     static {
//         System.loadLibrary("llama-android");
//     }
//
// Or, in MainActivity.kt:
//     companion object {
//         init {
//             System.loadLibrary("llama-android")
//         }
//     }
- #define TAG "llama-android.cpp"
- #define LOGi(...) __android_log_print(ANDROID_LOG_INFO, TAG, __VA_ARGS__)
- #define LOGe(...) __android_log_print(ANDROID_LOG_ERROR, TAG, __VA_ARGS__)
- jclass la_int_var;
- jmethodID la_int_var_value;
- jmethodID la_int_var_inc;
- std::string cached_token_chars;
// Reports whether a NUL-terminated byte string consists only of complete
// UTF-8 sequences. Only the lead-byte pattern and the 10xxxxxx shape of the
// continuation bytes are checked. Overlong forms and surrogates are not
// rejected; the purpose is to detect a multi-byte character that has been
// split across tokens. A null pointer counts as valid.
bool is_valid_utf8(const char * string) {
    if (string == nullptr) {
        return true;
    }

    // Length of the sequence started by `lead`, or 0 when `lead` cannot start
    // one (a stray continuation byte or 0xF8..0xFF).
    const auto sequence_length = [](unsigned char lead) -> int {
        if (lead < 0x80)         { return 1; }  // 0xxxxxxx
        if ((lead >> 5) == 0x06) { return 2; }  // 110xxxxx
        if ((lead >> 4) == 0x0E) { return 3; }  // 1110xxxx
        if ((lead >> 3) == 0x1E) { return 4; }  // 11110xxx
        return 0;
    };

    const unsigned char * cursor = reinterpret_cast<const unsigned char *>(string);
    while (*cursor != 0) {
        const int length = sequence_length(*cursor++);
        if (length == 0) {
            return false;
        }
        // Each trailing byte must be 10xxxxxx. The terminating NUL fails this
        // test, so a truncated sequence is rejected without reading past the end.
        for (int k = 1; k < length; ++k, ++cursor) {
            if ((*cursor >> 6) != 0x02) {
                return false;
            }
        }
    }
    return true;
}
- static void log_callback(ggml_log_level level, const char * fmt, void * data) {
- if (level == GGML_LOG_LEVEL_ERROR) __android_log_print(ANDROID_LOG_ERROR, TAG, fmt, data);
- else if (level == GGML_LOG_LEVEL_INFO) __android_log_print(ANDROID_LOG_INFO, TAG, fmt, data);
- else if (level == GGML_LOG_LEVEL_WARN) __android_log_print(ANDROID_LOG_WARN, TAG, fmt, data);
- else __android_log_print(ANDROID_LOG_DEFAULT, TAG, fmt, data);
- }
- extern "C"
- JNIEXPORT jlong JNICALL
- Java_android_llama_cpp_LLamaAndroid_load_1model(JNIEnv *env, jobject, jstring filename) {
- llama_model_params model_params = llama_model_default_params();
- auto path_to_model = env->GetStringUTFChars(filename, 0);
- LOGi("Loading model from %s", path_to_model);
- auto model = llama_load_model_from_file(path_to_model, model_params);
- env->ReleaseStringUTFChars(filename, path_to_model);
- if (!model) {
- LOGe("load_model() failed");
- env->ThrowNew(env->FindClass("java/lang/IllegalStateException"), "load_model() failed");
- return 0;
- }
- return reinterpret_cast<jlong>(model);
- }
- extern "C"
- JNIEXPORT void JNICALL
- Java_android_llama_cpp_LLamaAndroid_free_1model(JNIEnv *, jobject, jlong model) {
- llama_free_model(reinterpret_cast<llama_model *>(model));
- }
- extern "C"
- JNIEXPORT jlong JNICALL
- Java_android_llama_cpp_LLamaAndroid_new_1context(JNIEnv *env, jobject, jlong jmodel) {
- auto model = reinterpret_cast<llama_model *>(jmodel);
- if (!model) {
- LOGe("new_context(): model cannot be null");
- env->ThrowNew(env->FindClass("java/lang/IllegalArgumentException"), "Model cannot be null");
- return 0;
- }
- int n_threads = std::max(1, std::min(8, (int) sysconf(_SC_NPROCESSORS_ONLN) - 2));
- LOGi("Using %d threads", n_threads);
- llama_context_params ctx_params = llama_context_default_params();
- ctx_params.n_ctx = 2048;
- ctx_params.n_threads = n_threads;
- ctx_params.n_threads_batch = n_threads;
- llama_context * context = llama_new_context_with_model(model, ctx_params);
- if (!context) {
- LOGe("llama_new_context_with_model() returned null)");
- env->ThrowNew(env->FindClass("java/lang/IllegalStateException"),
- "llama_new_context_with_model() returned null)");
- return 0;
- }
- return reinterpret_cast<jlong>(context);
- }
- extern "C"
- JNIEXPORT void JNICALL
- Java_android_llama_cpp_LLamaAndroid_free_1context(JNIEnv *, jobject, jlong context) {
- llama_free(reinterpret_cast<llama_context *>(context));
- }
- extern "C"
- JNIEXPORT void JNICALL
- Java_android_llama_cpp_LLamaAndroid_backend_1free(JNIEnv *, jobject) {
- llama_backend_free();
- }
- extern "C"
- JNIEXPORT void JNICALL
- Java_android_llama_cpp_LLamaAndroid_log_1to_1android(JNIEnv *, jobject) {
- llama_log_set(log_callback, NULL);
- }
- extern "C"
- JNIEXPORT jstring JNICALL
- Java_android_llama_cpp_LLamaAndroid_bench_1model(
- JNIEnv *env,
- jobject,
- jlong context_pointer,
- jlong model_pointer,
- jlong batch_pointer,
- jint pp,
- jint tg,
- jint pl,
- jint nr
- ) {
- auto pp_avg = 0.0;
- auto tg_avg = 0.0;
- auto pp_std = 0.0;
- auto tg_std = 0.0;
- const auto context = reinterpret_cast<llama_context *>(context_pointer);
- const auto model = reinterpret_cast<llama_model *>(model_pointer);
- const auto batch = reinterpret_cast<llama_batch *>(batch_pointer);
- const int n_ctx = llama_n_ctx(context);
- LOGi("n_ctx = %d", n_ctx);
- int i, j;
- int nri;
- for (nri = 0; nri < nr; nri++) {
- LOGi("Benchmark prompt processing (pp)");
- llama_batch_clear(*batch);
- const int n_tokens = pp;
- for (i = 0; i < n_tokens; i++) {
- llama_batch_add(*batch, 0, i, { 0 }, false);
- }
- batch->logits[batch->n_tokens - 1] = true;
- llama_kv_cache_clear(context);
- const auto t_pp_start = ggml_time_us();
- if (llama_decode(context, *batch) != 0) {
- LOGi("llama_decode() failed during prompt processing");
- }
- const auto t_pp_end = ggml_time_us();
- // bench text generation
- LOGi("Benchmark text generation (tg)");
- llama_kv_cache_clear(context);
- const auto t_tg_start = ggml_time_us();
- for (i = 0; i < tg; i++) {
- llama_batch_clear(*batch);
- for (j = 0; j < pl; j++) {
- llama_batch_add(*batch, 0, i, { j }, true);
- }
- LOGi("llama_decode() text generation: %d", i);
- if (llama_decode(context, *batch) != 0) {
- LOGi("llama_decode() failed during text generation");
- }
- }
- const auto t_tg_end = ggml_time_us();
- llama_kv_cache_clear(context);
- const auto t_pp = double(t_pp_end - t_pp_start) / 1000000.0;
- const auto t_tg = double(t_tg_end - t_tg_start) / 1000000.0;
- const auto speed_pp = double(pp) / t_pp;
- const auto speed_tg = double(pl * tg) / t_tg;
- pp_avg += speed_pp;
- tg_avg += speed_tg;
- pp_std += speed_pp * speed_pp;
- tg_std += speed_tg * speed_tg;
- LOGi("pp %f t/s, tg %f t/s", speed_pp, speed_tg);
- }
- pp_avg /= double(nr);
- tg_avg /= double(nr);
- if (nr > 1) {
- pp_std = sqrt(pp_std / double(nr - 1) - pp_avg * pp_avg * double(nr) / double(nr - 1));
- tg_std = sqrt(tg_std / double(nr - 1) - tg_avg * tg_avg * double(nr) / double(nr - 1));
- } else {
- pp_std = 0;
- tg_std = 0;
- }
- char model_desc[128];
- llama_model_desc(model, model_desc, sizeof(model_desc));
- const auto model_size = double(llama_model_size(model)) / 1024.0 / 1024.0 / 1024.0;
- const auto model_n_params = double(llama_model_n_params(model)) / 1e9;
- const auto backend = "(Android)"; // TODO: What should this be?
- std::stringstream result;
- result << std::setprecision(2);
- result << "| model | size | params | backend | test | t/s |\n";
- result << "| --- | --- | --- | --- | --- | --- |\n";
- result << "| " << model_desc << " | " << model_size << "GiB | " << model_n_params << "B | " << backend << " | pp " << pp << " | " << pp_avg << " ± " << pp_std << " |\n";
- result << "| " << model_desc << " | " << model_size << "GiB | " << model_n_params << "B | " << backend << " | tg " << tg << " | " << tg_avg << " ± " << tg_std << " |\n";
- return env->NewStringUTF(result.str().c_str());
- }
- extern "C"
- JNIEXPORT jlong JNICALL
- Java_android_llama_cpp_LLamaAndroid_new_1batch(JNIEnv *, jobject, jint n_tokens, jint embd, jint n_seq_max) {
- // Source: Copy of llama.cpp:llama_batch_init but heap-allocated.
- llama_batch *batch = new llama_batch {
- 0,
- nullptr,
- nullptr,
- nullptr,
- nullptr,
- nullptr,
- nullptr,
- 0,
- 0,
- 0,
- };
- if (embd) {
- batch->embd = (float *) malloc(sizeof(float) * n_tokens * embd);
- } else {
- batch->token = (llama_token *) malloc(sizeof(llama_token) * n_tokens);
- }
- batch->pos = (llama_pos *) malloc(sizeof(llama_pos) * n_tokens);
- batch->n_seq_id = (int32_t *) malloc(sizeof(int32_t) * n_tokens);
- batch->seq_id = (llama_seq_id **) malloc(sizeof(llama_seq_id *) * n_tokens);
- for (int i = 0; i < n_tokens; ++i) {
- batch->seq_id[i] = (llama_seq_id *) malloc(sizeof(llama_seq_id) * n_seq_max);
- }
- batch->logits = (int8_t *) malloc(sizeof(int8_t) * n_tokens);
- return reinterpret_cast<jlong>(batch);
- }
- extern "C"
- JNIEXPORT void JNICALL
- Java_android_llama_cpp_LLamaAndroid_free_1batch(JNIEnv *, jobject, jlong batch_pointer) {
- llama_batch_free(*reinterpret_cast<llama_batch *>(batch_pointer));
- }
- extern "C"
- JNIEXPORT jlong JNICALL
- Java_android_llama_cpp_LLamaAndroid_new_1sampler(JNIEnv *, jobject) {
- auto sparams = llama_sampler_chain_default_params();
- sparams.no_perf = true;
- llama_sampler * smpl = llama_sampler_chain_init(sparams);
- llama_sampler_chain_add(smpl, llama_sampler_init_greedy());
- return reinterpret_cast<jlong>(smpl);
- }
- extern "C"
- JNIEXPORT void JNICALL
- Java_android_llama_cpp_LLamaAndroid_free_1sampler(JNIEnv *, jobject, jlong sampler_pointer) {
- llama_sampler_free(reinterpret_cast<llama_sampler *>(sampler_pointer));
- }
- extern "C"
- JNIEXPORT void JNICALL
- Java_android_llama_cpp_LLamaAndroid_backend_1init(JNIEnv *, jobject) {
- llama_backend_init();
- }
- extern "C"
- JNIEXPORT jstring JNICALL
- Java_android_llama_cpp_LLamaAndroid_system_1info(JNIEnv *env, jobject) {
- return env->NewStringUTF(llama_print_system_info());
- }
- extern "C"
- JNIEXPORT jint JNICALL
- Java_android_llama_cpp_LLamaAndroid_completion_1init(
- JNIEnv *env,
- jobject,
- jlong context_pointer,
- jlong batch_pointer,
- jstring jtext,
- jint n_len
- ) {
- cached_token_chars.clear();
- const auto text = env->GetStringUTFChars(jtext, 0);
- const auto context = reinterpret_cast<llama_context *>(context_pointer);
- const auto batch = reinterpret_cast<llama_batch *>(batch_pointer);
- const auto tokens_list = llama_tokenize(context, text, 1);
- auto n_ctx = llama_n_ctx(context);
- auto n_kv_req = tokens_list.size() + (n_len - tokens_list.size());
- LOGi("n_len = %d, n_ctx = %d, n_kv_req = %d", n_len, n_ctx, n_kv_req);
- if (n_kv_req > n_ctx) {
- LOGe("error: n_kv_req > n_ctx, the required KV cache size is not big enough");
- }
- for (auto id : tokens_list) {
- LOGi("%s", llama_token_to_piece(context, id).c_str());
- }
- llama_batch_clear(*batch);
- // evaluate the initial prompt
- for (auto i = 0; i < tokens_list.size(); i++) {
- llama_batch_add(*batch, tokens_list[i], i, { 0 }, false);
- }
- // llama_decode will output logits only for the last token of the prompt
- batch->logits[batch->n_tokens - 1] = true;
- if (llama_decode(context, *batch) != 0) {
- LOGe("llama_decode() failed");
- }
- env->ReleaseStringUTFChars(jtext, text);
- return batch->n_tokens;
- }
- extern "C"
- JNIEXPORT jstring JNICALL
- Java_android_llama_cpp_LLamaAndroid_completion_1loop(
- JNIEnv * env,
- jobject,
- jlong context_pointer,
- jlong batch_pointer,
- jlong sampler_pointer,
- jint n_len,
- jobject intvar_ncur
- ) {
- const auto context = reinterpret_cast<llama_context *>(context_pointer);
- const auto batch = reinterpret_cast<llama_batch *>(batch_pointer);
- const auto sampler = reinterpret_cast<llama_sampler *>(sampler_pointer);
- const auto model = llama_get_model(context);
- if (!la_int_var) la_int_var = env->GetObjectClass(intvar_ncur);
- if (!la_int_var_value) la_int_var_value = env->GetMethodID(la_int_var, "getValue", "()I");
- if (!la_int_var_inc) la_int_var_inc = env->GetMethodID(la_int_var, "inc", "()V");
- // sample the most likely token
- const auto new_token_id = llama_sampler_sample(sampler, context, -1);
- llama_sampler_accept(sampler, new_token_id);
- const auto n_cur = env->CallIntMethod(intvar_ncur, la_int_var_value);
- if (llama_token_is_eog(model, new_token_id) || n_cur == n_len) {
- return nullptr;
- }
- auto new_token_chars = llama_token_to_piece(context, new_token_id);
- cached_token_chars += new_token_chars;
- jstring new_token = nullptr;
- if (is_valid_utf8(cached_token_chars.c_str())) {
- new_token = env->NewStringUTF(cached_token_chars.c_str());
- LOGi("cached: %s, new_token_chars: `%s`, id: %d", cached_token_chars.c_str(), new_token_chars.c_str(), new_token_id);
- cached_token_chars.clear();
- } else {
- new_token = env->NewStringUTF("");
- }
- llama_batch_clear(*batch);
- llama_batch_add(*batch, new_token_id, n_cur, { 0 }, true);
- env->CallVoidMethod(intvar_ncur, la_int_var_inc);
- if (llama_decode(context, *batch) != 0) {
- LOGe("llama_decode() returned null");
- }
- return new_token;
- }
- extern "C"
- JNIEXPORT void JNICALL
- Java_android_llama_cpp_LLamaAndroid_kv_1cache_1clear(JNIEnv *, jobject, jlong context) {
- llama_kv_cache_clear(reinterpret_cast<llama_context *>(context));
- }
|