瀏覽代碼

tests : add test-jinja -py option for cross-checking (#18906)

* tests : add test-jinja -py option or cross-checking

* Update tests/test-jinja.cpp

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>

* fix + add source

* SandboxedEnvironment

* fix array.map case

---------

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>
Xuan-Son Nguyen 1 周之前
父節點
當前提交
fe44d35574
共有 1 個文件被更改,包括 112 次插入4 次删除
  1. 112 4
      tests/test-jinja.cpp

+ 112 - 4
tests/test-jinja.cpp

@@ -4,6 +4,7 @@
 #include <cstdlib>
 
 #include <nlohmann/json.hpp>
+#include <sheredom/subprocess.h>
 
 #include "jinja/runtime.h"
 #include "jinja/parser.h"
@@ -31,12 +32,24 @@ static void test_array_methods(testing & t);
 static void test_object_methods(testing & t);
 static void test_fuzzing(testing & t);
 
+static bool g_python_mode = false;
+
 int main(int argc, char *argv[]) {
     testing t(std::cout);
     t.verbose = true;
 
-    if (argc >= 2) {
-        t.set_filter(argv[1]);
+    // usage: test-jinja [-py] [filter_regex]
+    //  -py : enable python mode (use python jinja2 for rendering expected output)
+    //        only use this for cross-checking, not for correctness
+    //        note: the implementation of this flag is basic, only intented to be used by maintainers
+
+    for (int i = 1; i < argc; i++) {
+        std::string arg = argv[i];
+        if (arg == "-py") {
+            g_python_mode = true;
+        } else {
+            t.set_filter(arg);
+        }
     }
 
     t.test("whitespace control", test_whitespace_control);
@@ -53,7 +66,9 @@ int main(int argc, char *argv[]) {
     t.test("string methods", test_string_methods);
     t.test("array methods", test_array_methods);
     t.test("object methods", test_object_methods);
-    t.test("fuzzing", test_fuzzing);
+    if (!g_python_mode) {
+        t.test("fuzzing", test_fuzzing);
+    }
 
     return t.summary();
 }
@@ -1215,7 +1230,7 @@ static void test_object_methods(testing & t) {
     );
 }
 
-static void test_template(testing & t, const std::string & name, const std::string & tmpl, const json & vars, const std::string & expect) {
+static void test_template_cpp(testing & t, const std::string & name, const std::string & tmpl, const json & vars, const std::string & expect) {
     t.test(name, [&tmpl, &vars, &expect](testing & t) {
         jinja::lexer lexer;
         auto lexer_res = lexer.tokenize(tmpl);
@@ -1248,6 +1263,99 @@ static void test_template(testing & t, const std::string & name, const std::stri
     });
 }
 
+// keep this in-sync with https://github.com/huggingface/transformers/blob/main/src/transformers/utils/chat_template_utils.py
+// note: we use SandboxedEnvironment instead of ImmutableSandboxedEnvironment to allow usage of in-place array methods like append() and pop()
+static std::string py_script = R"(
+import jinja2
+import jinja2.ext as jinja2_ext
+import json
+import sys
+from datetime import datetime
+from jinja2.sandbox import SandboxedEnvironment
+
+tmpl = json.loads(sys.argv[1])
+vars_json = json.loads(sys.argv[2])
+
+env = SandboxedEnvironment(
+    trim_blocks=True,
+    lstrip_blocks=True,
+    extensions=[jinja2_ext.loopcontrols],
+)
+
+def raise_exception(message):
+    raise jinja2.exceptions.TemplateError(message)
+
+env.filters["tojson"] = lambda x, ensure_ascii=False, indent=None, separators=None, sort_keys=False: json.dumps(x, ensure_ascii=ensure_ascii, indent=indent, separators=separators, sort_keys=sort_keys)
+env.globals["strftime_now"] = lambda format: datetime.now().strftime(format)
+env.globals["raise_exception"] = raise_exception
+
+template = env.from_string(tmpl)
+result = template.render(**vars_json)
+print(result, end='')
+)";
+
+static void test_template_py(testing & t, const std::string & name, const std::string & tmpl, const json & vars, const std::string & expect) {
+    t.test(name, [&tmpl, &vars, &expect](testing & t) {
+        // Prepare arguments
+        std::string tmpl_json = json(tmpl).dump();
+        std::string vars_json = vars.dump();
+
+#ifdef _WIN32
+        const char * python_executable = "python.exe";
+#else
+        const char * python_executable = "python3";
+#endif
+
+        const char * command_line[] = {python_executable, "-c", py_script.c_str(), tmpl_json.c_str(), vars_json.c_str(), NULL};
+
+        struct subprocess_s subprocess;
+        int options = subprocess_option_combined_stdout_stderr
+                    | subprocess_option_no_window
+                    | subprocess_option_inherit_environment
+                    | subprocess_option_search_user_path;
+        int result = subprocess_create(command_line, options, &subprocess);
+
+        if (result != 0) {
+            t.log("Failed to create subprocess, error code: " + std::to_string(result));
+            t.assert_true("subprocess creation", false);
+            return;
+        }
+
+        // Read output
+        std::string output;
+        char buffer[1024];
+        FILE * p_stdout = subprocess_stdout(&subprocess);
+        while (fgets(buffer, sizeof(buffer), p_stdout)) {
+            output += buffer;
+        }
+
+        int process_return;
+        subprocess_join(&subprocess, &process_return);
+        subprocess_destroy(&subprocess);
+
+        if (process_return != 0) {
+            t.log("Python script failed with exit code: " + std::to_string(process_return));
+            t.log("Output: " + output);
+            t.assert_true("python execution", false);
+            return;
+        }
+
+        if (!t.assert_true("Template render mismatch", expect == output)) {
+            t.log("Template: " + json(tmpl).dump());
+            t.log("Expected: " + json(expect).dump());
+            t.log("Python  : " + json(output).dump());
+        }
+    });
+}
+
+static void test_template(testing & t, const std::string & name, const std::string & tmpl, const json & vars, const std::string & expect) {
+    if (g_python_mode) {
+        test_template_py(t, name, tmpl, vars, expect);
+    } else {
+        test_template_cpp(t, name, tmpl, vars, expect);
+    }
+}
+
 //
 // fuzz tests to ensure no crashes occur on malformed inputs
 //