Browse Source

jinja : attribute support for join, map and sort (#18883)

* support negative array index and default value

* attribute support (int and str) for join, map and sort

* add tests

* update CODEOWNERS

* improve fixme sorting comment
Sigbjørn Skjæret 1 week ago
parent
commit
d03c45c9c5
4 changed files with 144 additions and 38 deletions
  1. 1 0
      CODEOWNERS
  2. 52 31
      common/jinja/value.cpp
  3. 14 2
      common/jinja/value.h
  4. 77 5
      tests/test-jinja.cpp

+ 1 - 0
CODEOWNERS

@@ -15,6 +15,7 @@
 /common/common.*                        @ggerganov
 /common/console.*                       @ggerganov
 /common/http.*                          @angt
+/common/jinja/                          @ngxson @CISC @aldehir
 /common/llguidance.*                    @ggerganov
 /common/log.*                           @ggerganov
 /common/peg-parser.*                    @aldehir

+ 52 - 31
common/jinja/value.cpp

@@ -776,19 +776,30 @@ const func_builtins & value_array_t::get_builtins() const {
             if (!is_val<value_array>(args.get_pos(0))) {
                 throw raised_exception("join() first argument must be an array");
             }
-            value val_delim     = args.get_kwarg_or_pos("d",         1);
-            value val_attribute = args.get_kwarg_or_pos("attribute", 2);
-            if (!val_attribute->is_undefined()) {
-                throw not_implemented_exception("array attribute join not implemented");
-            }
+            value val_delim = args.get_kwarg_or_pos("d",         1);
+            value attribute = args.get_kwarg_or_pos("attribute", 2);
             const auto & arr = args.get_pos(0)->as_array();
-            std::string delim = is_val<value_string>(val_delim) ? val_delim->as_string().str() : "";
+            const bool attr_is_int = is_val<value_int>(attribute);
+            if (!attribute->is_undefined() && !is_val<value_string>(attribute) && !attr_is_int) {
+                throw raised_exception("join() attribute must be string or integer");
+            }
+            const int64_t attr_int = attr_is_int ? attribute->as_int() : 0;
+            const std::string delim = val_delim->is_undefined() ? "" : val_delim->as_string().str();
+            const std::string attr_name = attribute->is_undefined() ? "" : attribute->as_string().str();
             std::string result;
             for (size_t i = 0; i < arr.size(); ++i) {
-                if (!is_val<value_string>(arr[i]) && !is_val<value_int>(arr[i]) && !is_val<value_float>(arr[i])) {
+                value val_arr = arr[i];
+                if (!attribute->is_undefined()) {
+                    if (attr_is_int && is_val<value_array>(val_arr)) {
+                        val_arr = val_arr->at(attr_int);
+                    } else if (!attr_is_int && !attr_name.empty() && is_val<value_object>(val_arr)) {
+                        val_arr = val_arr->at(attr_name);
+                    }
+                }
+                if (!is_val<value_string>(val_arr) && !is_val<value_int>(val_arr) && !is_val<value_float>(val_arr)) {
                     throw raised_exception("join() can only join arrays of strings or numerics");
                 }
-                result += arr[i]->as_string().str();
+                result += val_arr->as_string().str();
                 if (i < arr.size() - 1) {
                     result += delim;
                 }
@@ -803,26 +814,30 @@ const func_builtins & value_array_t::get_builtins() const {
         }},
         {"tojson", tojson},
         {"map", [](const func_args & args) -> value {
-            args.ensure_count(2, 3);
+            args.ensure_count(2);
             if (!is_val<value_array>(args.get_pos(0))) {
                 throw raised_exception("map: first argument must be an array");
             }
-            value attribute = args.get_kwarg_or_pos("attribute", 1);
-            if (is_val<value_int>(attribute)) {
-                throw not_implemented_exception("map: integer attribute not implemented");
+            if (!is_val<value_kwarg>(args.get_args().at(1))) {
+                throw not_implemented_exception("map: filter-mapping not implemented");
             }
-            if (!is_val<value_string>(attribute)) {
+            value attribute = args.get_kwarg_or_pos("attribute", 1);
+            const bool attr_is_int = is_val<value_int>(attribute);
+            if (!is_val<value_string>(attribute) && !attr_is_int) {
                 throw raised_exception("map: attribute must be string or integer");
             }
-            std::string attr_name = attribute->as_string().str();
+            const int64_t attr_int = attr_is_int ? attribute->as_int() : 0;
+            const std::string attr_name = attribute->as_string().str();
             value default_val = args.get_kwarg("default", mk_val<value_undefined>());
             auto out = mk_val<value_array>();
             auto arr = args.get_pos(0)->as_array();
             for (const auto & item : arr) {
-                if (!is_val<value_object>(item)) {
-                    throw raised_exception("map: item is not an object");
+                value attr_val;
+                if (attr_is_int) {
+                    attr_val = is_val<value_array>(item) ? item->at(attr_int, default_val) : default_val;
+                } else {
+                    attr_val = is_val<value_object>(item) ? item->at(attr_name, default_val) : default_val;
                 }
-                value attr_val = item->at(attr_name, default_val);
                 out->push_back(attr_val);
             }
             return out;
@@ -848,29 +863,35 @@ const func_builtins & value_array_t::get_builtins() const {
             return arr_editable->pop_at(index);
         }},
         {"sort", [](const func_args & args) -> value {
-            args.ensure_count(1, 3);
+            args.ensure_count(1, 4);
             if (!is_val<value_array>(args.get_pos(0))) {
                 throw raised_exception("sort: first argument must be an array");
             }
-            bool reverse = args.get_kwarg("reverse", mk_val<value_undefined>())->as_bool();
-            value attribute = args.get_kwarg("attribute", mk_val<value_undefined>());
-            std::string attr = attribute->is_undefined() ? "" : attribute->as_string().str();
+            value val_reverse = args.get_kwarg_or_pos("reverse",        1);
+            value val_case    = args.get_kwarg_or_pos("case_sensitive", 2);
+            value attribute   = args.get_kwarg_or_pos("attribute",      3);
+            // FIXME: sorting is currently always case sensitive
+            //const bool case_sensitive = val_case->as_bool(); // undefined == false
+            const bool reverse = val_reverse->as_bool(); // undefined == false
+            const bool attr_is_int = is_val<value_int>(attribute);
+            const int64_t attr_int = attr_is_int ? attribute->as_int() : 0;
+            const std::string attr_name = attribute->is_undefined() ? "" : attribute->as_string().str();
             std::vector<value> arr = cast_val<value_array>(args.get_pos(0))->as_array(); // copy
             std::sort(arr.begin(), arr.end(),[&](const value & a, const value & b) {
                 value val_a = a;
                 value val_b = b;
                 if (!attribute->is_undefined()) {
-                    if (!is_val<value_object>(a) || !is_val<value_object>(b)) {
-                        throw raised_exception("sort: items are not objects");
+                    if (attr_is_int && is_val<value_array>(a) && is_val<value_array>(b)) {
+                        val_a = a->at(attr_int);
+                        val_b = b->at(attr_int);
+                    } else if (!attr_is_int && !attr_name.empty() && is_val<value_object>(a) && is_val<value_object>(b)) {
+                        val_a = a->at(attr_name);
+                        val_b = b->at(attr_name);
+                    } else {
+                        throw raised_exception("sort: unsupported object attribute comparison");
                     }
-                    val_a = attr.empty() ? a : a->at(attr);
-                    val_b = attr.empty() ? b : b->at(attr);
-                }
-                if (reverse) {
-                    return value_compare(val_a, val_b, value_compare_op::gt);
-                } else {
-                    return !value_compare(val_a, val_b, value_compare_op::gt);
                 }
+                return value_compare(val_a, val_b, reverse ? value_compare_op::gt : value_compare_op::lt);
             });
             return mk_val<value_array>(arr);
         }},
@@ -964,7 +985,7 @@ const func_builtins & value_object_t::get_builtins() const {
             value val_case    = args.get_kwarg_or_pos("case_sensitive", 1);
             value val_by      = args.get_kwarg_or_pos("by",             2);
             value val_reverse = args.get_kwarg_or_pos("reverse",        3);
-            // FIXME: sorting is case sensitive
+            // FIXME: sorting is currently always case sensitive
             //const bool case_sensitive = val_case->as_bool(); // undefined == false
             const bool reverse = val_reverse->as_bool(); // undefined == false
             if (!val_by->is_undefined()) {

+ 14 - 2
common/jinja/value.h

@@ -168,8 +168,20 @@ struct value_t {
         }
         return val_obj.unordered.at(key);
     }
-    virtual value & at(size_t index) {
-        if (index >= val_arr.size()) {
+    virtual value & at(int64_t index, value & default_val) {
+        if (index < 0) {
+            index += val_arr.size();
+        }
+        if (index < 0 || static_cast<size_t>(index) >= val_arr.size()) {
+            return default_val;
+        }
+        return val_arr[index];
+    }
+    virtual value & at(int64_t index) {
+        if (index < 0) {
+            index += val_arr.size();
+        }
+        if (index < 0 || static_cast<size_t>(index) >= val_arr.size()) {
             throw std::runtime_error("Index " + std::to_string(index) + " out of bounds for array of size " + std::to_string(val_arr.size()));
         }
         return val_arr[index];

+ 77 - 5
tests/test-jinja.cpp

@@ -389,6 +389,32 @@ static void test_filters(testing & t) {
         "123"
     );
 
+    test_template(t, "sort reverse",
+        "{% for i in items|sort(true) %}{{ i }}{% endfor %}",
+        {{"items", json::array({3, 1, 2})}},
+        "321"
+    );
+
+    test_template(t, "sort with attribute",
+        "{{ items|sort(attribute='name')|join(attribute='age') }}",
+        {{"items", json::array({
+            json({{"name", "c"}, {"age", 3}}),
+            json({{"name", "a"}, {"age", 1}}),
+            json({{"name", "b"}, {"age", 2}}),
+        })}},
+        "123"
+    );
+
+    test_template(t, "sort with numeric attribute",
+        "{{ items|sort(attribute=0)|join(attribute=1) }}",
+        {{"items", json::array({
+            json::array({3, "z"}),
+            json::array({1, "x"}),
+            json::array({2, "y"}),
+        })}},
+        "xyz"
+    );
+
     test_template(t, "join",
         "{{ items|join(', ') }}",
         {{"items", json::array({"a", "b", "c"})}},
@@ -1000,7 +1026,17 @@ static void test_array_methods(testing & t) {
     );
 
     test_template(t, "array|join attribute",
-        "{{ arr|join(attribute=0) }}",
+        "{{ arr|join(attribute='age') }}",
+        {{"arr", json::array({
+            json({{"name", "a"}, {"age", 1}}),
+            json({{"name", "b"}, {"age", 2}}),
+            json({{"name", "c"}, {"age", 3}}),
+        })}},
+        "123"
+    );
+
+    test_template(t, "array|join numeric attribute",
+        "{{ arr|join(attribute=-1) }}",
         {{"arr", json::array({json::array({1}), json::array({2}), json::array({3})})}},
         "123"
     );
@@ -1023,8 +1059,8 @@ static void test_array_methods(testing & t) {
         "a,b,c,d"
     );
 
-    test_template(t, "array.map() with attribute",
-        "{% for v in arr.map('age') %}{{ v }} {% endfor %}",
+    test_template(t, "array|map with attribute",
+        "{% for v in arr|map(attribute='age') %}{{ v }} {% endfor %}",
         {{"arr", json::array({
             json({{"name", "a"}, {"age", 1}}),
             json({{"name", "b"}, {"age", 2}}),
@@ -1033,8 +1069,28 @@ static void test_array_methods(testing & t) {
         "1 2 3 "
     );
 
-    test_template(t, "array.map() with numeric attribute",
-        "{% for v in arr.map(0) %}{{ v }} {% endfor %}",
+    test_template(t, "array|map with attribute default",
+        "{% for v in arr|map(attribute='age', default=3) %}{{ v }} {% endfor %}",
+        {{"arr", json::array({
+            json({{"name", "a"}, {"age", 1}}),
+            json({{"name", "b"}, {"age", 2}}),
+            json({{"name", "c"}}),
+        })}},
+        "1 2 3 "
+    );
+
+    test_template(t, "array|map without attribute default",
+        "{% for v in arr|map(attribute='age') %}{{ v }} {% endfor %}",
+        {{"arr", json::array({
+            json({{"name", "a"}, {"age", 1}}),
+            json({{"name", "b"}, {"age", 2}}),
+            json({{"name", "c"}}),
+        })}},
+        "1 2  "
+    );
+
+    test_template(t, "array|map with numeric attribute",
+        "{% for v in arr|map(attribute=0) %}{{ v }} {% endfor %}",
         {{"arr", json::array({
             json::array({10, "x"}),
             json::array({20, "y"}),
@@ -1043,6 +1099,22 @@ static void test_array_methods(testing & t) {
         "10 20 30 "
     );
 
+    test_template(t, "array|map with negative attribute",
+        "{% for v in arr|map(attribute=-1) %}{{ v }} {% endfor %}",
+        {{"arr", json::array({
+            json::array({10, "x"}),
+            json::array({20, "y"}),
+            json::array({30, "z"}),
+        })}},
+        "x y z "
+    );
+
+    test_template(t, "array|map with filter",
+        "{{ arr|map('int')|sum }}",
+        {{"arr", json::array({"1", "2", "3"})}},
+        "6"
+    );
+
     // not used by any chat templates
     // test_template(t, "array.insert()",
     //     "{% set _ = arr.insert(1, 'x') %}{{ arr|join(',') }}",