json-partial.cpp 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307
  1. #include "json-partial.h"
  2. #include "log.h"
  3. #include <nlohmann/json.hpp>
  4. #include <string>
  5. #include <regex>
  6. using json = nlohmann::ordered_json;
  7. enum common_json_stack_element_type {
  8. COMMON_JSON_STACK_ELEMENT_OBJECT,
  9. COMMON_JSON_STACK_ELEMENT_KEY,
  10. COMMON_JSON_STACK_ELEMENT_ARRAY,
  11. };
// One entry of the nesting stack maintained while SAX-parsing a document.
// It is brace-initialised as {type, key} at the push sites, so the member
// order must not change.
struct common_json_stack_element {
    // Which construct this entry represents (object, pending key, or array).
    common_json_stack_element_type type;
    // Key text for KEY entries; pushed as "" for OBJECT and ARRAY entries.
    std::string key;
};
  16. bool common_json_parse(
  17. const std::string & input,
  18. const std::string & healing_marker,
  19. common_json & out)
  20. {
  21. std::string::const_iterator it = input.begin();
  22. const auto end = input.end();
  23. return common_json_parse(it, end, healing_marker, out);
  24. }
  25. bool common_json_parse(
  26. std::string::const_iterator & it,
  27. const std::string::const_iterator & end,
  28. const std::string & healing_marker,
  29. common_json & out)
  30. {
  31. // // https://json.nlohmann.me/features/parsing/sax_interface/
  32. struct json_error_locator : public nlohmann::json_sax<json> {
  33. std::size_t position;
  34. bool found_error;
  35. std::string last_token;
  36. std::string exception_message;
  37. std::vector<common_json_stack_element> stack;
  38. json_error_locator() : position(0), found_error(false) {}
  39. bool parse_error(std::size_t position, const std::string & last_token, const json::exception & ex) override { // NOLINT
  40. this->position = position - 1;
  41. this->found_error = true;
  42. this->last_token = last_token;
  43. this->exception_message = ex.what();
  44. return false;
  45. }
  46. void close_value() {
  47. if (!stack.empty() && (stack.back().type == COMMON_JSON_STACK_ELEMENT_KEY)) {
  48. stack.pop_back();
  49. }
  50. }
  51. bool null() override { // NOLINT
  52. close_value();
  53. return true;
  54. }
  55. bool boolean(bool) override { // NOLINT
  56. close_value();
  57. return true;
  58. }
  59. bool number_integer(number_integer_t) override { // NOLINT
  60. close_value();
  61. return true;
  62. }
  63. bool number_unsigned(number_unsigned_t) override { // NOLINT
  64. close_value();
  65. return true;
  66. }
  67. bool number_float(number_float_t, const string_t &) override { // NOLINT
  68. close_value();
  69. return true;
  70. }
  71. bool string(string_t &) override { // NOLINT
  72. close_value();
  73. return true;
  74. }
  75. bool binary(binary_t &) override { // NOLINT
  76. close_value();
  77. return true;
  78. }
  79. bool start_object(std::size_t) override { // NOLINT
  80. stack.push_back({COMMON_JSON_STACK_ELEMENT_OBJECT, ""});
  81. return true;
  82. }
  83. bool end_object() override {
  84. GGML_ASSERT(!stack.empty() && stack.back().type == COMMON_JSON_STACK_ELEMENT_OBJECT);
  85. stack.pop_back();
  86. close_value();
  87. return true;
  88. }
  89. bool key(string_t & key) override { // NOLINT
  90. stack.push_back({COMMON_JSON_STACK_ELEMENT_KEY, key});
  91. return true;
  92. }
  93. bool start_array(std::size_t) override { // NOLINT
  94. stack.push_back({COMMON_JSON_STACK_ELEMENT_ARRAY, ""});
  95. return true;
  96. }
  97. bool end_array() override {
  98. GGML_ASSERT(!stack.empty() && stack.back().type == COMMON_JSON_STACK_ELEMENT_ARRAY);
  99. stack.pop_back();
  100. close_value();
  101. return true;
  102. }
  103. };
  104. json_error_locator err_loc;
  105. auto start = it;
  106. json::sax_parse(it, end, &err_loc);
  107. if (err_loc.found_error) {
  108. it = start;
  109. auto temptative_end = it + err_loc.position;
  110. // LOG_DBG("Error at position %zu (is_end = %s): %s\n", err_loc.position, temptative_end == end ? "true" : "false", err_loc.exception_message.c_str());
  111. auto input = std::string(it, temptative_end);
  112. try {
  113. out.json = json::parse(input);
  114. // out.json = json::parse(it, temptative_end);
  115. it = temptative_end;
  116. return true;
  117. } catch (const std::exception & ex) {
  118. // No, needs healing.
  119. LOG_DBG("Failed to parse up to error: %s: <<<%s>>>\n", ex.what(), std::string(it, temptative_end).c_str());
  120. }
  121. auto can_parse = [](const std::string & str) {
  122. try {
  123. auto _ = json::parse(str); // NOLINT
  124. return true;
  125. } catch (const std::exception &) {
  126. return false;
  127. }
  128. };
  129. if (!healing_marker.empty() && !err_loc.stack.empty()) {
  130. std::string str(it, temptative_end);
  131. auto last_non_sp_pos = str.find_last_not_of(" \n\r\t");
  132. if (last_non_sp_pos == std::string::npos) {
  133. throw std::runtime_error("Cannot heal a truncated JSON that stopped in an unknown location");
  134. }
  135. auto last_non_sp_char = str[last_non_sp_pos];
  136. // Used to detect stops on a number, which may not be complete.
  137. auto was_maybe_number = [&]() {
  138. if (!str.empty() && std::isspace(str.back())) {
  139. return false;
  140. }
  141. return std::isdigit(last_non_sp_char) ||
  142. last_non_sp_char == '.' ||
  143. last_non_sp_char == 'e' ||
  144. last_non_sp_char == 'E' ||
  145. last_non_sp_char == '-';
  146. };
  147. std::string closing;
  148. for (size_t i = err_loc.stack.size(); i > 0; i--) {
  149. auto & el = err_loc.stack[i - 1];
  150. if (el.type == COMMON_JSON_STACK_ELEMENT_OBJECT) {
  151. closing += "}";
  152. } else if (el.type == COMMON_JSON_STACK_ELEMENT_ARRAY) {
  153. closing += "]";
  154. } else if (el.type != COMMON_JSON_STACK_ELEMENT_KEY) {
  155. throw std::runtime_error("Unexpected stack element type");
  156. }
  157. }
  158. // Matches a potentially partial unicode escape sequence, e.g. \u, \uX, \uXX, \uXXX, \uXXXX
  159. static const std::regex partial_unicode_regex(R"(\\u(?:[0-9a-fA-F](?:[0-9a-fA-F](?:[0-9a-fA-F](?:[0-9a-fA-F])?)?)?)?$)");
  160. auto is_high_surrogate = [&](const std::string & s) {
  161. // Check if a partial of a high surrogate (U+D800-U+DBFF)
  162. return s.length() >= 4 &&
  163. s[0] == '\\' && s[1] == 'u' &&
  164. std::tolower(s[2]) == 'd' &&
  165. (s[3] == '8' || s[3] == '9' || std::tolower(s[3]) == 'a' || std::tolower(s[3]) == 'b');
  166. };
  167. // Initialize the unicode marker to a low surrogate to handle the edge case
  168. // where a high surrogate (U+D800-U+DBFF) is immediately followed by a
  169. // backslash (\)
  170. std::string unicode_marker_padding = "udc00";
  171. std::smatch last_unicode_seq;
  172. if (std::regex_search(str, last_unicode_seq, partial_unicode_regex)) {
  173. std::smatch second_last_seq;
  174. std::string prelude = str.substr(0, last_unicode_seq.position());
  175. // Pad the escape sequence with 0s until it forms a complete sequence of 6 characters
  176. unicode_marker_padding = std::string(6 - last_unicode_seq.length(), '0');
  177. if (is_high_surrogate(last_unicode_seq.str())) {
  178. // If the sequence is a partial match for a high surrogate, add a low surrogate (U+DC00-U+UDFF)
  179. unicode_marker_padding += "\\udc00";
  180. } else if (std::regex_search(prelude, second_last_seq, partial_unicode_regex)) {
  181. if (is_high_surrogate(second_last_seq.str())) {
  182. // If this follows a high surrogate, pad it to be a low surrogate
  183. if (last_unicode_seq.length() == 2) {
  184. unicode_marker_padding = "dc00";
  185. } else if (last_unicode_seq.length() == 3) {
  186. unicode_marker_padding = "c00";
  187. } else {
  188. // The original unicode_marker_padding is already padded with 0s
  189. }
  190. }
  191. }
  192. }
  193. const auto & magic_seed = out.healing_marker.marker = healing_marker;//"$llama.cpp.json$";
  194. if (err_loc.stack.back().type == COMMON_JSON_STACK_ELEMENT_KEY) {
  195. // We're inside an object value
  196. if (last_non_sp_char == ':' && can_parse(str + "1" + closing)) {
  197. // Was about to create an object value
  198. str += (out.healing_marker.json_dump_marker = "\"" + magic_seed) + "\"" + closing;
  199. } else if (can_parse(str + ": 1" + closing)) {
  200. str += (out.healing_marker.json_dump_marker = ":\"" + magic_seed) + "\"" + closing;
  201. } else if (last_non_sp_char == '{' && can_parse(str + closing)) {
  202. // Was about to create an object
  203. str += (out.healing_marker.json_dump_marker = "\"" + magic_seed) + "\": 1" + closing;
  204. } else if (can_parse(str + "\"" + closing)) {
  205. // Was inside an object value string
  206. str += (out.healing_marker.json_dump_marker = magic_seed) + "\"" + closing;
  207. } else if (str[str.length() - 1] == '\\' && can_parse(str + "\\\"" + closing)) {
  208. // Was inside an object value string after an escape
  209. str += (out.healing_marker.json_dump_marker = "\\" + magic_seed) + "\"" + closing;
  210. } else if (can_parse(str + unicode_marker_padding + "\"" + closing)) {
  211. // Was inside an object value string after a partial unicode escape
  212. str += (out.healing_marker.json_dump_marker = unicode_marker_padding + magic_seed) + "\"" + closing;
  213. } else {
  214. // find last :
  215. auto last_pos = str.find_last_of(':');
  216. if (last_pos == std::string::npos) {
  217. throw std::runtime_error("Cannot heal a truncated JSON that stopped in an unknown location");
  218. }
  219. // Cutting back to opening : for object value
  220. str = str.substr(0, last_pos + 1) + (out.healing_marker.json_dump_marker = "\"" + magic_seed) + "\"" + closing;
  221. }
  222. } else if (err_loc.stack.back().type == COMMON_JSON_STACK_ELEMENT_ARRAY) {
  223. if ((last_non_sp_char == ',' || last_non_sp_char == '[') && can_parse(str + "1" + closing)) {
  224. // Was about to create an array value
  225. str += (out.healing_marker.json_dump_marker = "\"" + magic_seed) + "\"" + closing;
  226. } else if (can_parse(str + "\"" + closing)) {
  227. // Was inside an array value string
  228. str += (out.healing_marker.json_dump_marker = magic_seed) + "\"" + closing;
  229. } else if (str[str.length() - 1] == '\\' && can_parse(str + "\\\"" + closing)) {
  230. // Was inside an array value string after an escape
  231. str += (out.healing_marker.json_dump_marker = "\\" + magic_seed) + "\"" + closing;
  232. } else if (can_parse(str + unicode_marker_padding + "\"" + closing)) {
  233. // Was inside an array value string after a partial unicode escape
  234. str += (out.healing_marker.json_dump_marker = unicode_marker_padding + magic_seed) + "\"" + closing;
  235. } else if (!was_maybe_number() && can_parse(str + ", 1" + closing)) {
  236. // Had just finished a value
  237. str += (out.healing_marker.json_dump_marker = ",\"" + magic_seed) + "\"" + closing;
  238. } else {
  239. auto last_pos = str.find_last_of("[,");
  240. if (last_pos == std::string::npos) {
  241. throw std::runtime_error("Cannot heal a truncated JSON array stopped in an unknown location");
  242. }
  243. // Cutting back to last [ or , for array value
  244. str = str.substr(0, last_pos + 1) + (out.healing_marker.json_dump_marker = "\"" + magic_seed) + "\"" + closing;
  245. }
  246. } else if (err_loc.stack.back().type == COMMON_JSON_STACK_ELEMENT_OBJECT) {
  247. if ((last_non_sp_char == '{' && can_parse(str + closing)) ||
  248. (last_non_sp_char == ',' && can_parse(str + "\"\": 1" + closing))) {
  249. // Was about to create an object key+value
  250. str += (out.healing_marker.json_dump_marker = "\"" + magic_seed) + "\": 1" + closing;
  251. } else if (!was_maybe_number() && can_parse(str + ",\"\": 1" + closing)) {
  252. // Was about to create an object key+value
  253. str += (out.healing_marker.json_dump_marker = ",\"" + magic_seed) + "\": 1" + closing;
  254. } else if (can_parse(str + "\": 1" + closing)) {
  255. // Was inside an object key string
  256. str += (out.healing_marker.json_dump_marker = magic_seed) + "\": 1" + closing;
  257. } else if (str[str.length() - 1] == '\\' && can_parse(str + "\\\": 1" + closing)) {
  258. // Was inside an object key string after an escape
  259. str += (out.healing_marker.json_dump_marker = "\\" + magic_seed) + "\": 1" + closing;
  260. } else if (can_parse(str + unicode_marker_padding + "\": 1" + closing)) {
  261. // Was inside an object key string after a partial unicode escape
  262. str += (out.healing_marker.json_dump_marker = unicode_marker_padding + magic_seed) + "\": 1" + closing;
  263. } else {
  264. auto last_pos = str.find_last_of(':');
  265. if (last_pos == std::string::npos) {
  266. throw std::runtime_error("Cannot heal a truncated JSON object stopped in an unknown location");
  267. }
  268. // fprintf(stderr, "Cutting back to last : for object key+value\n");
  269. str = str.substr(0, last_pos + 1) + (out.healing_marker.json_dump_marker = "\"" + magic_seed) + "\"" + closing;
  270. }
  271. } else {
  272. throw std::runtime_error("Cannot heal a truncated JSON object stopped in an unknown location");
  273. }
  274. // fprintf(stderr, "HEALED:\nSTRING <<<\n%s\n>>>\n\nmagic_cut: <<<\n%s\n>>>\n\n", str.c_str(), out.healing_marker.json_dump_marker.c_str());
  275. out.json = json::parse(str);
  276. it = temptative_end;
  277. return true;
  278. }
  279. // TODO: handle unclosed top-level primitive if the stack was empty but we got an error (e.g. "tru", "\"", etc...)
  280. // fprintf(stderr, "Closing: TODO\n");
  281. return false;
  282. }
  283. out.json = json::parse(it, end);
  284. it = end;
  285. return true;
  286. }