// llama-android.cpp — JNI bindings exposing llama.cpp to the Android example app.
  1. #include <android/log.h>
  2. #include <jni.h>
  3. #include <iomanip>
  4. #include <math.h>
  5. #include <string>
  6. #include <unistd.h>
  7. #include "llama.h"
  8. #include "common/common.h"
  9. // Write C++ code here.
  10. //
  11. // Do not forget to dynamically load the C++ library into your application.
  12. //
  13. // For instance,
  14. //
  15. // In MainActivity.java:
  16. // static {
  17. // System.loadLibrary("llama-android");
  18. // }
  19. //
  20. // Or, in MainActivity.kt:
  21. // companion object {
  22. // init {
  23. // System.loadLibrary("llama-android")
  24. // }
  25. // }
#define TAG "llama-android.cpp"
// Logcat helpers: info- and error-level messages tagged with this file's name.
#define LOGi(...) __android_log_print(ANDROID_LOG_INFO, TAG, __VA_ARGS__)
#define LOGe(...) __android_log_print(ANDROID_LOG_ERROR, TAG, __VA_ARGS__)

// Lazily-cached JNI handles for the Kotlin-side IntVar object passed to
// completion_loop(): its class plus the `getValue()I` and `inc()V` method ids.
jclass la_int_var;
jmethodID la_int_var_value;
jmethodID la_int_var_inc;

// Bytes of generated token pieces buffered across completion_loop() calls
// until they form complete UTF-8 (see is_valid_utf8) and can be returned.
std::string cached_token_chars;
  33. bool is_valid_utf8(const char * string) {
  34. if (!string) {
  35. return true;
  36. }
  37. const unsigned char * bytes = (const unsigned char *)string;
  38. int num;
  39. while (*bytes != 0x00) {
  40. if ((*bytes & 0x80) == 0x00) {
  41. // U+0000 to U+007F
  42. num = 1;
  43. } else if ((*bytes & 0xE0) == 0xC0) {
  44. // U+0080 to U+07FF
  45. num = 2;
  46. } else if ((*bytes & 0xF0) == 0xE0) {
  47. // U+0800 to U+FFFF
  48. num = 3;
  49. } else if ((*bytes & 0xF8) == 0xF0) {
  50. // U+10000 to U+10FFFF
  51. num = 4;
  52. } else {
  53. return false;
  54. }
  55. bytes += 1;
  56. for (int i = 1; i < num; ++i) {
  57. if ((*bytes & 0xC0) != 0x80) {
  58. return false;
  59. }
  60. bytes += 1;
  61. }
  62. }
  63. return true;
  64. }
  65. static void log_callback(ggml_log_level level, const char * fmt, void * data) {
  66. if (level == GGML_LOG_LEVEL_ERROR) __android_log_print(ANDROID_LOG_ERROR, TAG, fmt, data);
  67. else if (level == GGML_LOG_LEVEL_INFO) __android_log_print(ANDROID_LOG_INFO, TAG, fmt, data);
  68. else if (level == GGML_LOG_LEVEL_WARN) __android_log_print(ANDROID_LOG_WARN, TAG, fmt, data);
  69. else __android_log_print(ANDROID_LOG_DEFAULT, TAG, fmt, data);
  70. }
  71. extern "C"
  72. JNIEXPORT jlong JNICALL
  73. Java_com_example_llama_Llm_load_1model(JNIEnv *env, jobject, jstring filename) {
  74. llama_model_params model_params = llama_model_default_params();
  75. auto path_to_model = env->GetStringUTFChars(filename, 0);
  76. LOGi("Loading model from %s", path_to_model);
  77. auto model = llama_load_model_from_file(path_to_model, model_params);
  78. env->ReleaseStringUTFChars(filename, path_to_model);
  79. if (!model) {
  80. LOGe("load_model() failed");
  81. env->ThrowNew(env->FindClass("java/lang/IllegalStateException"), "load_model() failed");
  82. return 0;
  83. }
  84. return reinterpret_cast<jlong>(model);
  85. }
  86. extern "C"
  87. JNIEXPORT void JNICALL
  88. Java_com_example_llama_Llm_free_1model(JNIEnv *, jobject, jlong model) {
  89. llama_free_model(reinterpret_cast<llama_model *>(model));
  90. }
  91. extern "C"
  92. JNIEXPORT jlong JNICALL
  93. Java_com_example_llama_Llm_new_1context(JNIEnv *env, jobject, jlong jmodel) {
  94. auto model = reinterpret_cast<llama_model *>(jmodel);
  95. if (!model) {
  96. LOGe("new_context(): model cannot be null");
  97. env->ThrowNew(env->FindClass("java/lang/IllegalArgumentException"), "Model cannot be null");
  98. return 0;
  99. }
  100. int n_threads = std::max(1, std::min(8, (int) sysconf(_SC_NPROCESSORS_ONLN) - 2));
  101. LOGi("Using %d threads", n_threads);
  102. llama_context_params ctx_params = llama_context_default_params();
  103. ctx_params.seed = 1234;
  104. ctx_params.n_ctx = 2048;
  105. ctx_params.n_threads = n_threads;
  106. ctx_params.n_threads_batch = n_threads;
  107. llama_context * context = llama_new_context_with_model(model, ctx_params);
  108. if (!context) {
  109. LOGe("llama_new_context_with_model() returned null)");
  110. env->ThrowNew(env->FindClass("java/lang/IllegalStateException"),
  111. "llama_new_context_with_model() returned null)");
  112. return 0;
  113. }
  114. return reinterpret_cast<jlong>(context);
  115. }
  116. extern "C"
  117. JNIEXPORT void JNICALL
  118. Java_com_example_llama_Llm_free_1context(JNIEnv *, jobject, jlong context) {
  119. llama_free(reinterpret_cast<llama_context *>(context));
  120. }
  121. extern "C"
  122. JNIEXPORT void JNICALL
  123. Java_com_example_llama_Llm_backend_1free(JNIEnv *, jobject) {
  124. llama_backend_free();
  125. }
  126. extern "C"
  127. JNIEXPORT void JNICALL
  128. Java_com_example_llama_Llm_log_1to_1android(JNIEnv *, jobject) {
  129. llama_log_set(log_callback, NULL);
  130. }
// JNI: micro-benchmark in the style of llama-bench. Runs `nr` repetitions of
// (1) prompt processing: decoding `pp` dummy tokens in a single batch, and
// (2) text generation: `tg` decode steps with `pl` parallel sequences each.
// Returns a markdown table with mean ± stddev tokens/sec for both phases.
extern "C"
JNIEXPORT jstring JNICALL
Java_com_example_llama_Llm_bench_1model(
        JNIEnv *env,
        jobject,
        jlong context_pointer,
        jlong model_pointer,
        jlong batch_pointer,
        jint pp,   // prompt tokens per run
        jint tg,   // generation steps per run
        jint pl,   // parallel sequences per generation step
        jint nr    // number of repetitions
) {
    // Running sum and sum-of-squares, for mean and standard deviation.
    auto pp_avg = 0.0;
    auto tg_avg = 0.0;
    auto pp_std = 0.0;
    auto tg_std = 0.0;

    const auto context = reinterpret_cast<llama_context *>(context_pointer);
    const auto model = reinterpret_cast<llama_model *>(model_pointer);
    const auto batch = reinterpret_cast<llama_batch *>(batch_pointer);

    const int n_ctx = llama_n_ctx(context);
    LOGi("n_ctx = %d", n_ctx);

    int i, j;
    int nri;
    for (nri = 0; nri < nr; nri++) {
        LOGi("Benchmark prompt processing (pp)");

        llama_batch_clear(*batch);

        // One batch of `pp` dummy tokens (id 0) at positions 0..pp-1 in
        // sequence 0; only the last position requests logits.
        const int n_tokens = pp;
        for (i = 0; i < n_tokens; i++) {
            llama_batch_add(*batch, 0, i, { 0 }, false);
        }
        batch->logits[batch->n_tokens - 1] = true;
        llama_kv_cache_clear(context);

        // Time a single decode of the whole prompt batch.
        const auto t_pp_start = ggml_time_us();
        if (llama_decode(context, *batch) != 0) {
            LOGi("llama_decode() failed during prompt processing");
        }
        const auto t_pp_end = ggml_time_us();

        // bench text generation
        LOGi("Benchmark text generation (tg)");

        llama_kv_cache_clear(context);
        // Each step decodes one dummy token per sequence 0..pl-1 at position i.
        const auto t_tg_start = ggml_time_us();
        for (i = 0; i < tg; i++) {
            llama_batch_clear(*batch);
            for (j = 0; j < pl; j++) {
                llama_batch_add(*batch, 0, i, { j }, true);
            }
            LOGi("llama_decode() text generation: %d", i);
            if (llama_decode(context, *batch) != 0) {
                LOGi("llama_decode() failed during text generation");
            }
        }
        const auto t_tg_end = ggml_time_us();

        llama_kv_cache_clear(context);

        // ggml_time_us() reports microseconds; convert to seconds.
        const auto t_pp = double(t_pp_end - t_pp_start) / 1000000.0;
        const auto t_tg = double(t_tg_end - t_tg_start) / 1000000.0;

        const auto speed_pp = double(pp) / t_pp;
        const auto speed_tg = double(pl * tg) / t_tg;

        pp_avg += speed_pp;
        tg_avg += speed_tg;
        pp_std += speed_pp * speed_pp;
        tg_std += speed_tg * speed_tg;

        LOGi("pp %f t/s, tg %f t/s", speed_pp, speed_tg);
    }

    pp_avg /= double(nr);
    tg_avg /= double(nr);

    // Bessel-corrected sample standard deviation from the sum of squares.
    if (nr > 1) {
        pp_std = sqrt(pp_std / double(nr - 1) - pp_avg * pp_avg * double(nr) / double(nr - 1));
        tg_std = sqrt(tg_std / double(nr - 1) - tg_avg * tg_avg * double(nr) / double(nr - 1));
    } else {
        pp_std = 0;
        tg_std = 0;
    }

    char model_desc[128];
    llama_model_desc(model, model_desc, sizeof(model_desc));

    const auto model_size = double(llama_model_size(model)) / 1024.0 / 1024.0 / 1024.0;
    const auto model_n_params = double(llama_model_n_params(model)) / 1e9;

    const auto backend = "(Android)"; // TODO: What should this be?

    // Emit the llama-bench-style markdown result table.
    std::stringstream result;
    result << std::setprecision(2);
    result << "| model | size | params | backend | test | t/s |\n";
    result << "| --- | --- | --- | --- | --- | --- |\n";
    result << "| " << model_desc << " | " << model_size << "GiB | " << model_n_params << "B | " << backend << " | pp " << pp << " | " << pp_avg << " ± " << pp_std << " |\n";
    result << "| " << model_desc << " | " << model_size << "GiB | " << model_n_params << "B | " << backend << " | tg " << tg << " | " << tg_avg << " ± " << tg_std << " |\n";

    return env->NewStringUTF(result.str().c_str());
}
  217. extern "C"
  218. JNIEXPORT void JNICALL
  219. Java_com_example_llama_Llm_free_1batch(JNIEnv *, jobject, jlong batch_pointer) {
  220. llama_batch_free(*reinterpret_cast<llama_batch *>(batch_pointer));
  221. }
// JNI: heap-allocates a llama_batch with capacity for `n_tokens` tokens and
// `n_seq_max` sequence ids per token, returned as a jlong handle. If `embd`
// is nonzero the batch stores embeddings (n_tokens * embd floats) instead of
// token ids. Must be released with free_batch().
extern "C"
JNIEXPORT jlong JNICALL
Java_com_example_llama_Llm_new_1batch(JNIEnv *, jobject, jint n_tokens, jint embd, jint n_seq_max) {
    // Source: Copy of llama.cpp:llama_batch_init but heap-allocated.

    // Positional aggregate init: zeroes/nulls every field. NOTE(review): this
    // relies on the exact field order of llama_batch in the pinned llama.h
    // (n_tokens, then six pointers, then three scalars) — re-verify whenever
    // the llama.cpp dependency is upgraded.
    llama_batch *batch = new llama_batch {
        0,
        nullptr,
        nullptr,
        nullptr,
        nullptr,
        nullptr,
        nullptr,
        0,
        0,
        0,
    };

    // A batch carries either embeddings or token ids, never both.
    if (embd) {
        batch->embd = (float *) malloc(sizeof(float) * n_tokens * embd);
    } else {
        batch->token = (llama_token *) malloc(sizeof(llama_token) * n_tokens);
    }

    batch->pos = (llama_pos *) malloc(sizeof(llama_pos) * n_tokens);
    batch->n_seq_id = (int32_t *) malloc(sizeof(int32_t) * n_tokens);
    batch->seq_id = (llama_seq_id **) malloc(sizeof(llama_seq_id *) * n_tokens);
    // One sequence-id array (capacity n_seq_max) per token slot.
    for (int i = 0; i < n_tokens; ++i) {
        batch->seq_id[i] = (llama_seq_id *) malloc(sizeof(llama_seq_id) * n_seq_max);
    }
    batch->logits = (int8_t *) malloc(sizeof(int8_t) * n_tokens);

    return reinterpret_cast<jlong>(batch);
}
  252. extern "C"
  253. JNIEXPORT void JNICALL
  254. Java_com_example_llama_Llm_backend_1init(JNIEnv *, jobject) {
  255. llama_backend_init();
  256. }
  257. extern "C"
  258. JNIEXPORT jstring JNICALL
  259. Java_com_example_llama_Llm_system_1info(JNIEnv *env, jobject) {
  260. return env->NewStringUTF(llama_print_system_info());
  261. }
  262. extern "C"
  263. JNIEXPORT jint JNICALL
  264. Java_com_example_llama_Llm_completion_1init(
  265. JNIEnv *env,
  266. jobject,
  267. jlong context_pointer,
  268. jlong batch_pointer,
  269. jstring jtext,
  270. jint n_len
  271. ) {
  272. cached_token_chars.clear();
  273. const auto text = env->GetStringUTFChars(jtext, 0);
  274. const auto context = reinterpret_cast<llama_context *>(context_pointer);
  275. const auto batch = reinterpret_cast<llama_batch *>(batch_pointer);
  276. const auto tokens_list = llama_tokenize(context, text, 1);
  277. auto n_ctx = llama_n_ctx(context);
  278. auto n_kv_req = tokens_list.size() + (n_len - tokens_list.size());
  279. LOGi("n_len = %d, n_ctx = %d, n_kv_req = %d", n_len, n_ctx, n_kv_req);
  280. if (n_kv_req > n_ctx) {
  281. LOGe("error: n_kv_req > n_ctx, the required KV cache size is not big enough");
  282. }
  283. for (auto id : tokens_list) {
  284. LOGi("%s", llama_token_to_piece(context, id).c_str());
  285. }
  286. llama_batch_clear(*batch);
  287. // evaluate the initial prompt
  288. for (auto i = 0; i < tokens_list.size(); i++) {
  289. llama_batch_add(*batch, tokens_list[i], i, { 0 }, false);
  290. }
  291. // llama_decode will output logits only for the last token of the prompt
  292. batch->logits[batch->n_tokens - 1] = true;
  293. if (llama_decode(context, *batch) != 0) {
  294. LOGe("llama_decode() failed");
  295. }
  296. env->ReleaseStringUTFChars(jtext, text);
  297. return batch->n_tokens;
  298. }
  299. extern "C"
  300. JNIEXPORT jstring JNICALL
  301. Java_com_example_llama_Llm_completion_1loop(
  302. JNIEnv * env,
  303. jobject,
  304. jlong context_pointer,
  305. jlong batch_pointer,
  306. jint n_len,
  307. jobject intvar_ncur
  308. ) {
  309. const auto context = reinterpret_cast<llama_context *>(context_pointer);
  310. const auto batch = reinterpret_cast<llama_batch *>(batch_pointer);
  311. const auto model = llama_get_model(context);
  312. if (!la_int_var) la_int_var = env->GetObjectClass(intvar_ncur);
  313. if (!la_int_var_value) la_int_var_value = env->GetMethodID(la_int_var, "getValue", "()I");
  314. if (!la_int_var_inc) la_int_var_inc = env->GetMethodID(la_int_var, "inc", "()V");
  315. auto n_vocab = llama_n_vocab(model);
  316. auto logits = llama_get_logits_ith(context, batch->n_tokens - 1);
  317. std::vector<llama_token_data> candidates;
  318. candidates.reserve(n_vocab);
  319. for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
  320. candidates.emplace_back(llama_token_data{ token_id, logits[token_id], 0.0f });
  321. }
  322. llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
  323. // sample the most likely token
  324. const auto new_token_id = llama_sample_token_greedy(context, &candidates_p);
  325. const auto n_cur = env->CallIntMethod(intvar_ncur, la_int_var_value);
  326. if (new_token_id == llama_token_eos(model) || n_cur == n_len) {
  327. return env->NewStringUTF("");
  328. }
  329. auto new_token_chars = llama_token_to_piece(context, new_token_id);
  330. cached_token_chars += new_token_chars;
  331. jstring new_token = nullptr;
  332. if (is_valid_utf8(cached_token_chars.c_str())) {
  333. new_token = env->NewStringUTF(cached_token_chars.c_str());
  334. LOGi("cached: %s, new_token_chars: `%s`, id: %d", cached_token_chars.c_str(), new_token_chars.c_str(), new_token_id);
  335. cached_token_chars.clear();
  336. } else {
  337. new_token = env->NewStringUTF("");
  338. }
  339. llama_batch_clear(*batch);
  340. llama_batch_add(*batch, new_token_id, n_cur, { 0 }, true);
  341. env->CallVoidMethod(intvar_ncur, la_int_var_inc);
  342. if (llama_decode(context, *batch) != 0) {
  343. LOGe("llama_decode() returned null");
  344. }
  345. return new_token;
  346. }
  347. extern "C"
  348. JNIEXPORT void JNICALL
  349. Java_com_example_llama_Llm_kv_1cache_1clear(JNIEnv *, jobject, jlong context) {
  350. llama_kv_cache_clear(reinterpret_cast<llama_context *>(context));
  351. }