1
0

json-partial.cpp 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255
  1. #include <json-partial.h>
  2. #include "ggml.h"
  3. #include "log.h"
  4. #include <string>
  5. #include <json.hpp>
  6. using json = nlohmann::ordered_json;
  7. enum common_json_stack_element_type {
  8. COMMON_JSON_STACK_ELEMENT_OBJECT,
  9. COMMON_JSON_STACK_ELEMENT_KEY,
  10. COMMON_JSON_STACK_ELEMENT_ARRAY,
  11. };
  12. struct common_json_stack_element {
  13. common_json_stack_element_type type;
  14. std::string key;
  15. };
  16. bool common_json_parse(
  17. const std::string & input,
  18. const std::string & healing_marker,
  19. common_json & out)
  20. {
  21. std::string::const_iterator it = input.begin();
  22. const auto end = input.end();
  23. return common_json_parse(it, end, healing_marker, out);
  24. }
  25. bool common_json_parse(
  26. std::string::const_iterator & it,
  27. const std::string::const_iterator & end,
  28. const std::string & healing_marker,
  29. common_json & out)
  30. {
  31. // // https://json.nlohmann.me/features/parsing/sax_interface/
  32. struct json_error_locator : public nlohmann::json_sax<json> {
  33. std::size_t position;
  34. bool found_error;
  35. std::string last_token;
  36. std::string exception_message;
  37. std::vector<common_json_stack_element> stack;
  38. json_error_locator() : position(0), found_error(false) {}
  39. bool parse_error(std::size_t position, const std::string & last_token, const json::exception & ex) override { // NOLINT
  40. this->position = position - 1;
  41. this->found_error = true;
  42. this->last_token = last_token;
  43. this->exception_message = ex.what();
  44. return false;
  45. }
  46. void close_value() {
  47. if (!stack.empty() && (stack.back().type == COMMON_JSON_STACK_ELEMENT_KEY)) {
  48. stack.pop_back();
  49. }
  50. }
  51. bool null() override { // NOLINT
  52. close_value();
  53. return true;
  54. }
  55. bool boolean(bool) override { // NOLINT
  56. close_value();
  57. return true;
  58. }
  59. bool number_integer(number_integer_t) override { // NOLINT
  60. close_value();
  61. return true;
  62. }
  63. bool number_unsigned(number_unsigned_t) override { // NOLINT
  64. close_value();
  65. return true;
  66. }
  67. bool number_float(number_float_t, const string_t &) override { // NOLINT
  68. close_value();
  69. return true;
  70. }
  71. bool string(string_t &) override { // NOLINT
  72. close_value();
  73. return true;
  74. }
  75. bool binary(binary_t &) override { // NOLINT
  76. close_value();
  77. return true;
  78. }
  79. bool start_object(std::size_t) override { // NOLINT
  80. stack.push_back({COMMON_JSON_STACK_ELEMENT_OBJECT, ""});
  81. return true;
  82. }
  83. bool end_object() override {
  84. GGML_ASSERT(!stack.empty() && stack.back().type == COMMON_JSON_STACK_ELEMENT_OBJECT);
  85. stack.pop_back();
  86. close_value();
  87. return true;
  88. }
  89. bool key(string_t & key) override { // NOLINT
  90. stack.push_back({COMMON_JSON_STACK_ELEMENT_KEY, key});
  91. return true;
  92. }
  93. bool start_array(std::size_t) override { // NOLINT
  94. stack.push_back({COMMON_JSON_STACK_ELEMENT_ARRAY, ""});
  95. return true;
  96. }
  97. bool end_array() override {
  98. GGML_ASSERT(!stack.empty() && stack.back().type == COMMON_JSON_STACK_ELEMENT_ARRAY);
  99. stack.pop_back();
  100. close_value();
  101. return true;
  102. }
  103. };
  104. json_error_locator err_loc;
  105. auto start = it;
  106. json::sax_parse(it, end, &err_loc);
  107. if (err_loc.found_error) {
  108. it = start;
  109. auto temptative_end = it + err_loc.position;
  110. // LOG_DBG("Error at position %zu (is_end = %s): %s\n", err_loc.position, temptative_end == end ? "true" : "false", err_loc.exception_message.c_str());
  111. auto input = std::string(it, temptative_end);
  112. try {
  113. out.json = json::parse(input);
  114. // out.json = json::parse(it, temptative_end);
  115. it = temptative_end;
  116. return true;
  117. } catch (const std::exception & ex) {
  118. // No, needs healing.
  119. LOG_DBG("Failed to parse up to error: %s: <<<%s>>>\n", ex.what(), std::string(it, temptative_end).c_str());
  120. }
  121. auto can_parse = [](const std::string & str) {
  122. try {
  123. auto _ = json::parse(str); // NOLINT
  124. return true;
  125. } catch (const std::exception &) {
  126. return false;
  127. }
  128. };
  129. if (!healing_marker.empty() && !err_loc.stack.empty()) {
  130. std::string str(it, temptative_end);
  131. auto last_non_sp_pos = str.find_last_not_of(" \n\r\t");
  132. if (last_non_sp_pos == std::string::npos) {
  133. throw std::runtime_error("Cannot heal a truncated JSON that stopped in an unknown location");
  134. }
  135. auto last_non_sp_char = str[last_non_sp_pos];
  136. // Used to detect stops on a number, which may not be complete.
  137. auto was_maybe_number = [&]() {
  138. if (!str.empty() && std::isspace(str.back())) {
  139. return false;
  140. }
  141. return std::isdigit(last_non_sp_char) ||
  142. last_non_sp_char == '.' ||
  143. last_non_sp_char == 'e' ||
  144. last_non_sp_char == 'E' ||
  145. last_non_sp_char == '-';
  146. };
  147. std::string closing;
  148. for (size_t i = err_loc.stack.size(); i > 0; i--) {
  149. auto & el = err_loc.stack[i - 1];
  150. if (el.type == COMMON_JSON_STACK_ELEMENT_OBJECT) {
  151. closing += "}";
  152. } else if (el.type == COMMON_JSON_STACK_ELEMENT_ARRAY) {
  153. closing += "]";
  154. } else if (el.type != COMMON_JSON_STACK_ELEMENT_KEY) {
  155. throw std::runtime_error("Unexpected stack element type");
  156. }
  157. }
  158. const auto & magic_seed = out.healing_marker.marker = healing_marker;//"$llama.cpp.json$";
  159. if (err_loc.stack.back().type == COMMON_JSON_STACK_ELEMENT_KEY) {
  160. // We're inside an object value
  161. if (last_non_sp_char == ':' && can_parse(str + "1" + closing)) {
  162. // Was about to create an object value
  163. str += (out.healing_marker.json_dump_marker = "\"" + magic_seed) + "\"" + closing;
  164. } else if (can_parse(str + ": 1" + closing)) {
  165. str += (out.healing_marker.json_dump_marker = ":\"" + magic_seed) + "\"" + closing;
  166. } else if (last_non_sp_char == '{' && can_parse(str + closing)) {
  167. // Was about to create an object
  168. str += (out.healing_marker.json_dump_marker = "\"" + magic_seed) + "\": 1" + closing;
  169. } else if (can_parse(str + "\"" + closing)) {
  170. // Was inside an object value string
  171. str += (out.healing_marker.json_dump_marker = magic_seed) + "\"" + closing;
  172. } else if (str[str.length() - 1] == '\\' && can_parse(str + "\\\"" + closing)) {
  173. // Was inside an object value string after an escape
  174. str += (out.healing_marker.json_dump_marker = "\\" + magic_seed) + "\"" + closing;
  175. } else {
  176. // find last :
  177. auto last_pos = str.find_last_of(':');
  178. if (last_pos == std::string::npos) {
  179. throw std::runtime_error("Cannot heal a truncated JSON that stopped in an unknown location");
  180. }
  181. // Cutting back to opening : for object value
  182. str = str.substr(0, last_pos + 1) + (out.healing_marker.json_dump_marker = "\"" + magic_seed) + "\"" + closing;
  183. }
  184. } else if (err_loc.stack.back().type == COMMON_JSON_STACK_ELEMENT_ARRAY) {
  185. if ((last_non_sp_char == ',' || last_non_sp_char == '[') && can_parse(str + "1" + closing)) {
  186. // Was about to create an array value
  187. str += (out.healing_marker.json_dump_marker = "\"" + magic_seed) + "\"" + closing;
  188. } else if (can_parse(str + "\"" + closing)) {
  189. // Was inside an array value string
  190. str += (out.healing_marker.json_dump_marker = magic_seed) + "\"" + closing;
  191. } else if (str[str.length() - 1] == '\\' && can_parse(str + "\\\"" + closing)) {
  192. // Was inside an array value string after an escape
  193. str += (out.healing_marker.json_dump_marker = "\\" + magic_seed) + "\"" + closing;
  194. } else if (!was_maybe_number() && can_parse(str + ", 1" + closing)) {
  195. // Had just finished a value
  196. str += (out.healing_marker.json_dump_marker = ",\"" + magic_seed) + "\"" + closing;
  197. } else {
  198. auto last_pos = str.find_last_of("[,");
  199. if (last_pos == std::string::npos) {
  200. throw std::runtime_error("Cannot heal a truncated JSON array stopped in an unknown location");
  201. }
  202. // Cutting back to last [ or , for array value
  203. str = str.substr(0, last_pos + 1) + (out.healing_marker.json_dump_marker = "\"" + magic_seed) + "\"" + closing;
  204. }
  205. } else if (err_loc.stack.back().type == COMMON_JSON_STACK_ELEMENT_OBJECT) {
  206. if ((last_non_sp_char == '{' && can_parse(str + closing)) ||
  207. (last_non_sp_char == ',' && can_parse(str + "\"\": 1" + closing))) {
  208. // Was about to create an object key+value
  209. str += (out.healing_marker.json_dump_marker = "\"" + magic_seed) + "\": 1" + closing;
  210. } else if (!was_maybe_number() && can_parse(str + ",\"\": 1" + closing)) {
  211. // Was about to create an object key+value
  212. str += (out.healing_marker.json_dump_marker = ",\"" + magic_seed) + "\": 1" + closing;
  213. } else if (can_parse(str + "\": 1" + closing)) {
  214. // Was inside an object key string
  215. str += (out.healing_marker.json_dump_marker = magic_seed) + "\": 1" + closing;
  216. } else if (str[str.length() - 1] == '\\' && can_parse(str + "\\\": 1" + closing)) {
  217. // Was inside an object key string after an escape
  218. str += (out.healing_marker.json_dump_marker = "\\" + magic_seed) + "\": 1" + closing;
  219. } else {
  220. auto last_pos = str.find_last_of(':');
  221. if (last_pos == std::string::npos) {
  222. throw std::runtime_error("Cannot heal a truncated JSON object stopped in an unknown location");
  223. }
  224. // fprintf(stderr, "Cutting back to last : for object key+value\n");
  225. str = str.substr(0, last_pos + 1) + (out.healing_marker.json_dump_marker = "\"" + magic_seed) + "\"" + closing;
  226. }
  227. } else {
  228. throw std::runtime_error("Cannot heal a truncated JSON object stopped in an unknown location");
  229. }
  230. // fprintf(stderr, "HEALED:\nSTRING <<<\n%s\n>>>\n\nmagic_cut: <<<\n%s\n>>>\n\n", str.c_str(), out.healing_marker.json_dump_marker.c_str());
  231. out.json = json::parse(str);
  232. it = temptative_end;
  233. return true;
  234. }
  235. // TODO: handle unclosed top-level primitive if the stack was empty but we got an error (e.g. "tru", "\"", etc...)
  236. // fprintf(stderr, "Closing: TODO\n");
  237. return false;
  238. }
  239. out.json = json::parse(it, end);
  240. it = end;
  241. return true;
  242. }