|
|
@@ -4,6 +4,7 @@
|
|
|
#include <cstdlib>
|
|
|
|
|
|
#include <nlohmann/json.hpp>
|
|
|
+#include <sheredom/subprocess.h>
|
|
|
|
|
|
#include "jinja/runtime.h"
|
|
|
#include "jinja/parser.h"
|
|
|
@@ -31,12 +32,24 @@ static void test_array_methods(testing & t);
|
|
|
static void test_object_methods(testing & t);
|
|
|
static void test_fuzzing(testing & t);
|
|
|
|
|
|
+static bool g_python_mode = false;
|
|
|
+
|
|
|
int main(int argc, char *argv[]) {
|
|
|
testing t(std::cout);
|
|
|
t.verbose = true;
|
|
|
|
|
|
- if (argc >= 2) {
|
|
|
- t.set_filter(argv[1]);
|
|
|
+ // usage: test-jinja [-py] [filter_regex]
|
|
|
+ // -py : enable python mode (use python jinja2 for rendering expected output)
|
|
|
+ // only use this for cross-checking, not for correctness
|
|
|
+ // note: the implementation of this flag is basic, only intended to be used by maintainers
|
|
|
+
|
|
|
+ for (int i = 1; i < argc; i++) {
|
|
|
+ std::string arg = argv[i];
|
|
|
+ if (arg == "-py") {
|
|
|
+ g_python_mode = true;
|
|
|
+ } else {
|
|
|
+ t.set_filter(arg);
|
|
|
+ }
|
|
|
}
|
|
|
|
|
|
t.test("whitespace control", test_whitespace_control);
|
|
|
@@ -53,7 +66,9 @@ int main(int argc, char *argv[]) {
|
|
|
t.test("string methods", test_string_methods);
|
|
|
t.test("array methods", test_array_methods);
|
|
|
t.test("object methods", test_object_methods);
|
|
|
- t.test("fuzzing", test_fuzzing);
|
|
|
+ if (!g_python_mode) {
|
|
|
+ t.test("fuzzing", test_fuzzing);
|
|
|
+ }
|
|
|
|
|
|
return t.summary();
|
|
|
}
|
|
|
@@ -1215,7 +1230,7 @@ static void test_object_methods(testing & t) {
|
|
|
);
|
|
|
}
|
|
|
|
|
|
-static void test_template(testing & t, const std::string & name, const std::string & tmpl, const json & vars, const std::string & expect) {
|
|
|
+static void test_template_cpp(testing & t, const std::string & name, const std::string & tmpl, const json & vars, const std::string & expect) {
|
|
|
t.test(name, [&tmpl, &vars, &expect](testing & t) {
|
|
|
jinja::lexer lexer;
|
|
|
auto lexer_res = lexer.tokenize(tmpl);
|
|
|
@@ -1248,6 +1263,99 @@ static void test_template(testing & t, const std::string & name, const std::stri
|
|
|
});
|
|
|
}
|
|
|
|
|
|
+// keep this in sync with https://github.com/huggingface/transformers/blob/main/src/transformers/utils/chat_template_utils.py
|
|
|
+// note: we use SandboxedEnvironment instead of ImmutableSandboxedEnvironment to allow usage of in-place array methods like append() and pop()
|
|
|
+static std::string py_script = R"(
|
|
|
+import jinja2
|
|
|
+import jinja2.ext as jinja2_ext
|
|
|
+import json
|
|
|
+import sys
|
|
|
+from datetime import datetime
|
|
|
+from jinja2.sandbox import SandboxedEnvironment
|
|
|
+
|
|
|
+tmpl = json.loads(sys.argv[1])
|
|
|
+vars_json = json.loads(sys.argv[2])
|
|
|
+
|
|
|
+env = SandboxedEnvironment(
|
|
|
+ trim_blocks=True,
|
|
|
+ lstrip_blocks=True,
|
|
|
+ extensions=[jinja2_ext.loopcontrols],
|
|
|
+)
|
|
|
+
|
|
|
+def raise_exception(message):
|
|
|
+ raise jinja2.exceptions.TemplateError(message)
|
|
|
+
|
|
|
+env.filters["tojson"] = lambda x, ensure_ascii=False, indent=None, separators=None, sort_keys=False: json.dumps(x, ensure_ascii=ensure_ascii, indent=indent, separators=separators, sort_keys=sort_keys)
|
|
|
+env.globals["strftime_now"] = lambda format: datetime.now().strftime(format)
|
|
|
+env.globals["raise_exception"] = raise_exception
|
|
|
+
|
|
|
+template = env.from_string(tmpl)
|
|
|
+result = template.render(**vars_json)
|
|
|
+print(result, end='')
|
|
|
+)";
|
|
|
+
|
|
|
+static void test_template_py(testing & t, const std::string & name, const std::string & tmpl, const json & vars, const std::string & expect) {
|
|
|
+ t.test(name, [&tmpl, &vars, &expect](testing & t) {
|
|
|
+ // Prepare arguments
|
|
|
+ std::string tmpl_json = json(tmpl).dump();
|
|
|
+ std::string vars_json = vars.dump();
|
|
|
+
|
|
|
+#ifdef _WIN32
|
|
|
+ const char * python_executable = "python.exe";
|
|
|
+#else
|
|
|
+ const char * python_executable = "python3";
|
|
|
+#endif
|
|
|
+
|
|
|
+ const char * command_line[] = {python_executable, "-c", py_script.c_str(), tmpl_json.c_str(), vars_json.c_str(), NULL};
|
|
|
+
|
|
|
+ struct subprocess_s subprocess;
|
|
|
+ int options = subprocess_option_combined_stdout_stderr
|
|
|
+ | subprocess_option_no_window
|
|
|
+ | subprocess_option_inherit_environment
|
|
|
+ | subprocess_option_search_user_path;
|
|
|
+ int result = subprocess_create(command_line, options, &subprocess);
|
|
|
+
|
|
|
+ if (result != 0) {
|
|
|
+ t.log("Failed to create subprocess, error code: " + std::to_string(result));
|
|
|
+ t.assert_true("subprocess creation", false);
|
|
|
+ return;
|
|
|
+ }
|
|
|
+
|
|
|
+ // Read output
|
|
|
+ std::string output;
|
|
|
+ char buffer[1024];
|
|
|
+ FILE * p_stdout = subprocess_stdout(&subprocess);
|
|
|
+ while (fgets(buffer, sizeof(buffer), p_stdout)) {
|
|
|
+ output += buffer;
|
|
|
+ }
|
|
|
+
|
|
|
+ int process_return;
|
|
|
+ subprocess_join(&subprocess, &process_return);
|
|
|
+ subprocess_destroy(&subprocess);
|
|
|
+
|
|
|
+ if (process_return != 0) {
|
|
|
+ t.log("Python script failed with exit code: " + std::to_string(process_return));
|
|
|
+ t.log("Output: " + output);
|
|
|
+ t.assert_true("python execution", false);
|
|
|
+ return;
|
|
|
+ }
|
|
|
+
|
|
|
+ if (!t.assert_true("Template render mismatch", expect == output)) {
|
|
|
+ t.log("Template: " + json(tmpl).dump());
|
|
|
+ t.log("Expected: " + json(expect).dump());
|
|
|
+ t.log("Python : " + json(output).dump());
|
|
|
+ }
|
|
|
+ });
|
|
|
+}
|
|
|
+
|
|
|
+static void test_template(testing & t, const std::string & name, const std::string & tmpl, const json & vars, const std::string & expect) {
|
|
|
+ if (g_python_mode) {
|
|
|
+ test_template_py(t, name, tmpl, vars, expect);
|
|
|
+ } else {
|
|
|
+ test_template_cpp(t, name, tmpl, vars, expect);
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
//
|
|
|
// fuzz tests to ensure no crashes occur on malformed inputs
|
|
|
//
|