json-partial.cpp 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256
  1. #include "json-partial.h"
  2. #include "log.h"
  3. #include <nlohmann/json.hpp>
  4. #include <string>
  5. using json = nlohmann::ordered_json;
  6. enum common_json_stack_element_type {
  7. COMMON_JSON_STACK_ELEMENT_OBJECT,
  8. COMMON_JSON_STACK_ELEMENT_KEY,
  9. COMMON_JSON_STACK_ELEMENT_ARRAY,
  10. };
  11. struct common_json_stack_element {
  12. common_json_stack_element_type type;
  13. std::string key;
  14. };
  15. bool common_json_parse(
  16. const std::string & input,
  17. const std::string & healing_marker,
  18. common_json & out)
  19. {
  20. std::string::const_iterator it = input.begin();
  21. const auto end = input.end();
  22. return common_json_parse(it, end, healing_marker, out);
  23. }
  24. bool common_json_parse(
  25. std::string::const_iterator & it,
  26. const std::string::const_iterator & end,
  27. const std::string & healing_marker,
  28. common_json & out)
  29. {
  30. // // https://json.nlohmann.me/features/parsing/sax_interface/
  31. struct json_error_locator : public nlohmann::json_sax<json> {
  32. std::size_t position;
  33. bool found_error;
  34. std::string last_token;
  35. std::string exception_message;
  36. std::vector<common_json_stack_element> stack;
  37. json_error_locator() : position(0), found_error(false) {}
  38. bool parse_error(std::size_t position, const std::string & last_token, const json::exception & ex) override { // NOLINT
  39. this->position = position - 1;
  40. this->found_error = true;
  41. this->last_token = last_token;
  42. this->exception_message = ex.what();
  43. return false;
  44. }
  45. void close_value() {
  46. if (!stack.empty() && (stack.back().type == COMMON_JSON_STACK_ELEMENT_KEY)) {
  47. stack.pop_back();
  48. }
  49. }
  50. bool null() override { // NOLINT
  51. close_value();
  52. return true;
  53. }
  54. bool boolean(bool) override { // NOLINT
  55. close_value();
  56. return true;
  57. }
  58. bool number_integer(number_integer_t) override { // NOLINT
  59. close_value();
  60. return true;
  61. }
  62. bool number_unsigned(number_unsigned_t) override { // NOLINT
  63. close_value();
  64. return true;
  65. }
  66. bool number_float(number_float_t, const string_t &) override { // NOLINT
  67. close_value();
  68. return true;
  69. }
  70. bool string(string_t &) override { // NOLINT
  71. close_value();
  72. return true;
  73. }
  74. bool binary(binary_t &) override { // NOLINT
  75. close_value();
  76. return true;
  77. }
  78. bool start_object(std::size_t) override { // NOLINT
  79. stack.push_back({COMMON_JSON_STACK_ELEMENT_OBJECT, ""});
  80. return true;
  81. }
  82. bool end_object() override {
  83. GGML_ASSERT(!stack.empty() && stack.back().type == COMMON_JSON_STACK_ELEMENT_OBJECT);
  84. stack.pop_back();
  85. close_value();
  86. return true;
  87. }
  88. bool key(string_t & key) override { // NOLINT
  89. stack.push_back({COMMON_JSON_STACK_ELEMENT_KEY, key});
  90. return true;
  91. }
  92. bool start_array(std::size_t) override { // NOLINT
  93. stack.push_back({COMMON_JSON_STACK_ELEMENT_ARRAY, ""});
  94. return true;
  95. }
  96. bool end_array() override {
  97. GGML_ASSERT(!stack.empty() && stack.back().type == COMMON_JSON_STACK_ELEMENT_ARRAY);
  98. stack.pop_back();
  99. close_value();
  100. return true;
  101. }
  102. };
  103. json_error_locator err_loc;
  104. auto start = it;
  105. json::sax_parse(it, end, &err_loc);
  106. if (err_loc.found_error) {
  107. it = start;
  108. auto temptative_end = it + err_loc.position;
  109. // LOG_DBG("Error at position %zu (is_end = %s): %s\n", err_loc.position, temptative_end == end ? "true" : "false", err_loc.exception_message.c_str());
  110. auto input = std::string(it, temptative_end);
  111. try {
  112. out.json = json::parse(input);
  113. // out.json = json::parse(it, temptative_end);
  114. it = temptative_end;
  115. return true;
  116. } catch (const std::exception & ex) {
  117. // No, needs healing.
  118. LOG_DBG("Failed to parse up to error: %s: <<<%s>>>\n", ex.what(), std::string(it, temptative_end).c_str());
  119. }
  120. auto can_parse = [](const std::string & str) {
  121. try {
  122. auto _ = json::parse(str); // NOLINT
  123. return true;
  124. } catch (const std::exception &) {
  125. return false;
  126. }
  127. };
  128. if (!healing_marker.empty() && !err_loc.stack.empty()) {
  129. std::string str(it, temptative_end);
  130. auto last_non_sp_pos = str.find_last_not_of(" \n\r\t");
  131. if (last_non_sp_pos == std::string::npos) {
  132. throw std::runtime_error("Cannot heal a truncated JSON that stopped in an unknown location");
  133. }
  134. auto last_non_sp_char = str[last_non_sp_pos];
  135. // Used to detect stops on a number, which may not be complete.
  136. auto was_maybe_number = [&]() {
  137. if (!str.empty() && std::isspace(str.back())) {
  138. return false;
  139. }
  140. return std::isdigit(last_non_sp_char) ||
  141. last_non_sp_char == '.' ||
  142. last_non_sp_char == 'e' ||
  143. last_non_sp_char == 'E' ||
  144. last_non_sp_char == '-';
  145. };
  146. std::string closing;
  147. for (size_t i = err_loc.stack.size(); i > 0; i--) {
  148. auto & el = err_loc.stack[i - 1];
  149. if (el.type == COMMON_JSON_STACK_ELEMENT_OBJECT) {
  150. closing += "}";
  151. } else if (el.type == COMMON_JSON_STACK_ELEMENT_ARRAY) {
  152. closing += "]";
  153. } else if (el.type != COMMON_JSON_STACK_ELEMENT_KEY) {
  154. throw std::runtime_error("Unexpected stack element type");
  155. }
  156. }
  157. const auto & magic_seed = out.healing_marker.marker = healing_marker;//"$llama.cpp.json$";
  158. if (err_loc.stack.back().type == COMMON_JSON_STACK_ELEMENT_KEY) {
  159. // We're inside an object value
  160. if (last_non_sp_char == ':' && can_parse(str + "1" + closing)) {
  161. // Was about to create an object value
  162. str += (out.healing_marker.json_dump_marker = "\"" + magic_seed) + "\"" + closing;
  163. } else if (can_parse(str + ": 1" + closing)) {
  164. str += (out.healing_marker.json_dump_marker = ":\"" + magic_seed) + "\"" + closing;
  165. } else if (last_non_sp_char == '{' && can_parse(str + closing)) {
  166. // Was about to create an object
  167. str += (out.healing_marker.json_dump_marker = "\"" + magic_seed) + "\": 1" + closing;
  168. } else if (can_parse(str + "\"" + closing)) {
  169. // Was inside an object value string
  170. str += (out.healing_marker.json_dump_marker = magic_seed) + "\"" + closing;
  171. } else if (str[str.length() - 1] == '\\' && can_parse(str + "\\\"" + closing)) {
  172. // Was inside an object value string after an escape
  173. str += (out.healing_marker.json_dump_marker = "\\" + magic_seed) + "\"" + closing;
  174. } else {
  175. // find last :
  176. auto last_pos = str.find_last_of(':');
  177. if (last_pos == std::string::npos) {
  178. throw std::runtime_error("Cannot heal a truncated JSON that stopped in an unknown location");
  179. }
  180. // Cutting back to opening : for object value
  181. str = str.substr(0, last_pos + 1) + (out.healing_marker.json_dump_marker = "\"" + magic_seed) + "\"" + closing;
  182. }
  183. } else if (err_loc.stack.back().type == COMMON_JSON_STACK_ELEMENT_ARRAY) {
  184. if ((last_non_sp_char == ',' || last_non_sp_char == '[') && can_parse(str + "1" + closing)) {
  185. // Was about to create an array value
  186. str += (out.healing_marker.json_dump_marker = "\"" + magic_seed) + "\"" + closing;
  187. } else if (can_parse(str + "\"" + closing)) {
  188. // Was inside an array value string
  189. str += (out.healing_marker.json_dump_marker = magic_seed) + "\"" + closing;
  190. } else if (str[str.length() - 1] == '\\' && can_parse(str + "\\\"" + closing)) {
  191. // Was inside an array value string after an escape
  192. str += (out.healing_marker.json_dump_marker = "\\" + magic_seed) + "\"" + closing;
  193. } else if (!was_maybe_number() && can_parse(str + ", 1" + closing)) {
  194. // Had just finished a value
  195. str += (out.healing_marker.json_dump_marker = ",\"" + magic_seed) + "\"" + closing;
  196. } else {
  197. auto last_pos = str.find_last_of("[,");
  198. if (last_pos == std::string::npos) {
  199. throw std::runtime_error("Cannot heal a truncated JSON array stopped in an unknown location");
  200. }
  201. // Cutting back to last [ or , for array value
  202. str = str.substr(0, last_pos + 1) + (out.healing_marker.json_dump_marker = "\"" + magic_seed) + "\"" + closing;
  203. }
  204. } else if (err_loc.stack.back().type == COMMON_JSON_STACK_ELEMENT_OBJECT) {
  205. if ((last_non_sp_char == '{' && can_parse(str + closing)) ||
  206. (last_non_sp_char == ',' && can_parse(str + "\"\": 1" + closing))) {
  207. // Was about to create an object key+value
  208. str += (out.healing_marker.json_dump_marker = "\"" + magic_seed) + "\": 1" + closing;
  209. } else if (!was_maybe_number() && can_parse(str + ",\"\": 1" + closing)) {
  210. // Was about to create an object key+value
  211. str += (out.healing_marker.json_dump_marker = ",\"" + magic_seed) + "\": 1" + closing;
  212. } else if (can_parse(str + "\": 1" + closing)) {
  213. // Was inside an object key string
  214. str += (out.healing_marker.json_dump_marker = magic_seed) + "\": 1" + closing;
  215. } else if (str[str.length() - 1] == '\\' && can_parse(str + "\\\": 1" + closing)) {
  216. // Was inside an object key string after an escape
  217. str += (out.healing_marker.json_dump_marker = "\\" + magic_seed) + "\": 1" + closing;
  218. } else {
  219. auto last_pos = str.find_last_of(':');
  220. if (last_pos == std::string::npos) {
  221. throw std::runtime_error("Cannot heal a truncated JSON object stopped in an unknown location");
  222. }
  223. // fprintf(stderr, "Cutting back to last : for object key+value\n");
  224. str = str.substr(0, last_pos + 1) + (out.healing_marker.json_dump_marker = "\"" + magic_seed) + "\"" + closing;
  225. }
  226. } else {
  227. throw std::runtime_error("Cannot heal a truncated JSON object stopped in an unknown location");
  228. }
  229. // fprintf(stderr, "HEALED:\nSTRING <<<\n%s\n>>>\n\nmagic_cut: <<<\n%s\n>>>\n\n", str.c_str(), out.healing_marker.json_dump_marker.c_str());
  230. out.json = json::parse(str);
  231. it = temptative_end;
  232. return true;
  233. }
  234. // TODO: handle unclosed top-level primitive if the stack was empty but we got an error (e.g. "tru", "\"", etc...)
  235. // fprintf(stderr, "Closing: TODO\n");
  236. return false;
  237. }
  238. out.json = json::parse(it, end);
  239. it = end;
  240. return true;
  241. }