|
|
@@ -4588,20 +4588,31 @@ struct test_topk_moe: public test_case {
|
|
|
struct test_sum : public test_case {
|
|
|
const ggml_type type;
|
|
|
const std::array<int64_t, 4> ne;
|
|
|
+ const std::array<int64_t, 4> permute;
|
|
|
+ bool _use_permute;
|
|
|
|
|
|
std::string vars() override {
|
|
|
- return VARS_TO_STR2(type, ne);
|
|
|
+ std::string v = VARS_TO_STR2(type, ne);
|
|
|
+ if (_use_permute) v += "," + VAR_TO_STR(permute);
|
|
|
+ return v;
|
|
|
}
|
|
|
|
|
|
test_sum(ggml_type type = GGML_TYPE_F32,
|
|
|
- std::array<int64_t, 4> ne = {10, 5, 4, 3})
|
|
|
- : type(type), ne(ne) {}
|
|
|
+ std::array<int64_t, 4> ne = {10, 5, 4, 3},
|
|
|
+ std::array<int64_t, 4> permute = {0, 0, 0, 0})
|
|
|
+ : type(type), ne(ne), permute(permute),
|
|
|
+ _use_permute(permute[0] + permute[1] + permute[2] + permute[3] > 0) {}
|
|
|
|
|
|
ggml_tensor * build_graph(ggml_context * ctx) override {
|
|
|
ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
|
|
|
ggml_set_param(a);
|
|
|
ggml_set_name(a, "a");
|
|
|
|
|
|
+ if (_use_permute) {
|
|
|
+ a = ggml_permute(ctx, a, permute[0], permute[1], permute[2], permute[3]);
|
|
|
+ ggml_set_name(a, "a_permuted");
|
|
|
+ }
|
|
|
+
|
|
|
ggml_tensor * out = ggml_sum(ctx, a);
|
|
|
ggml_set_name(out, "out");
|
|
|
|
|
|
@@ -6724,6 +6735,9 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
|
|
|
|
|
|
test_cases.emplace_back(new test_sum());
|
|
|
test_cases.emplace_back(new test_sum_rows());
|
|
|
+ test_cases.emplace_back(new test_sum(GGML_TYPE_F32, {11, 5, 6, 3}, {0, 2, 1, 3})); // row-contiguous but non-contiguous
|
|
|
+ test_cases.emplace_back(new test_sum(GGML_TYPE_F32, {11, 5, 6, 3}, {0, 3, 2, 1}));
|
|
|
+ test_cases.emplace_back(new test_sum(GGML_TYPE_F32, {11, 5, 6, 3}, {0, 1, 3, 2}));
|
|
|
test_cases.emplace_back(new test_sum_rows(GGML_TYPE_F32, { 11, 5, 6, 3 }, true, false));
|
|
|
test_cases.emplace_back(new test_sum_rows(GGML_TYPE_F32, { 11, 5, 6, 3 }, false, true));
|
|
|
test_cases.emplace_back(new test_sum_rows(GGML_TYPE_F32, { 11, 5, 6, 3 }, true, true));
|
|
|
@@ -6734,6 +6748,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
|
|
|
test_cases.emplace_back(new test_sum(GGML_TYPE_F32, { 33, 1024, 1, 1 }));
|
|
|
test_cases.emplace_back(new test_sum_rows(GGML_TYPE_F32, { 33, 1024, 1, 1 }));
|
|
|
test_cases.emplace_back(new test_sum(GGML_TYPE_F32, { 33, 256, 1, 1 }));
|
|
|
+ test_cases.emplace_back(new test_sum(GGML_TYPE_F32, { 33, 256, 1, 1 }, { 1, 0, 2, 3 })); // sum dst not-contiguous
|
|
|
test_cases.emplace_back(new test_sum_rows(GGML_TYPE_F32, { 33, 256, 1, 1 }));
|
|
|
test_cases.emplace_back(new test_mean(GGML_TYPE_F32, { 33, 256, 1, 1 }));
|
|
|
test_cases.emplace_back(new test_mean(GGML_TYPE_F32, { 32769, 1, 1, 1 }));
|