|
|
@@ -33,6 +33,45 @@ jclass la_int_var;
|
|
|
jmethodID la_int_var_value;
|
|
|
jmethodID la_int_var_inc;
|
|
|
|
|
|
+std::string cached_token_chars;
|
|
|
+
|
|
|
+bool is_valid_utf8(const char * string) {
|
|
|
+ if (!string) {
|
|
|
+ return true;
|
|
|
+ }
|
|
|
+
|
|
|
+ const unsigned char * bytes = (const unsigned char *)string;
|
|
|
+ int num;
|
|
|
+
|
|
|
+ while (*bytes != 0x00) {
|
|
|
+ if ((*bytes & 0x80) == 0x00) {
|
|
|
+ // U+0000 to U+007F
|
|
|
+ num = 1;
|
|
|
+ } else if ((*bytes & 0xE0) == 0xC0) {
|
|
|
+ // U+0080 to U+07FF
|
|
|
+ num = 2;
|
|
|
+ } else if ((*bytes & 0xF0) == 0xE0) {
|
|
|
+ // U+0800 to U+FFFF
|
|
|
+ num = 3;
|
|
|
+ } else if ((*bytes & 0xF8) == 0xF0) {
|
|
|
+ // U+10000 to U+10FFFF
|
|
|
+ num = 4;
|
|
|
+ } else {
|
|
|
+ return false;
|
|
|
+ }
|
|
|
+
|
|
|
+ bytes += 1;
|
|
|
+ for (int i = 1; i < num; ++i) {
|
|
|
+ if ((*bytes & 0xC0) != 0x80) {
|
|
|
+ return false;
|
|
|
+ }
|
|
|
+ bytes += 1;
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ return true;
|
|
|
+}
|
|
|
+
|
|
|
static void log_callback(ggml_log_level level, const char * fmt, void * data) {
|
|
|
if (level == GGML_LOG_LEVEL_ERROR) __android_log_print(ANDROID_LOG_ERROR, TAG, fmt, data);
|
|
|
else if (level == GGML_LOG_LEVEL_INFO) __android_log_print(ANDROID_LOG_INFO, TAG, fmt, data);
|
|
|
@@ -295,6 +334,8 @@ Java_com_example_llama_Llm_completion_1init(
|
|
|
jint n_len
|
|
|
) {
|
|
|
|
|
|
+ cached_token_chars.clear();
|
|
|
+
|
|
|
const auto text = env->GetStringUTFChars(jtext, 0);
|
|
|
const auto context = reinterpret_cast<llama_context *>(context_pointer);
|
|
|
const auto batch = reinterpret_cast<llama_batch *>(batch_pointer);
|
|
|
@@ -372,8 +413,16 @@ Java_com_example_llama_Llm_completion_1loop(
|
|
|
}
|
|
|
|
|
|
auto new_token_chars = llama_token_to_piece(context, new_token_id);
|
|
|
- LOGi("new_token_chars: `%s`", new_token_chars.c_str());
|
|
|
- auto new_token = env->NewStringUTF(new_token_chars.c_str());
|
|
|
+ cached_token_chars += new_token_chars;
|
|
|
+
|
|
|
+ jstring new_token = nullptr;
|
|
|
+ if (is_valid_utf8(cached_token_chars.c_str())) {
|
|
|
+ new_token = env->NewStringUTF(cached_token_chars.c_str());
|
|
|
+ LOGi("cached: %s, new_token_chars: `%s`, id: %d", cached_token_chars.c_str(), new_token_chars.c_str(), new_token_id);
|
|
|
+ cached_token_chars.clear();
|
|
|
+ } else {
|
|
|
+ new_token = env->NewStringUTF("");
|
|
|
+ }
|
|
|
|
|
|
llama_batch_clear(*batch);
|
|
|
llama_batch_add(*batch, new_token_id, n_cur, { 0 }, true);
|