@@ -3431,6 +3431,65 @@ struct test_rms_norm_mul_add : public test_case {
     }
 };
 
+// GGML_OP_ADD + GGML_OP_RMS_NORM (fused operation)
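+// Verifies that backends which fuse these two ops match the sequential
+// reference: out = (a + b) / sqrt(mean((a + b)^2) + eps), with the mean
+// taken along ne[0].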
+struct test_add_rms_norm : public test_case {
+    const ggml_type type;
+    const std::array<int64_t, 4> ne;
+    const float eps;
+    const bool broadcast;
+
+    std::string op_desc(ggml_tensor * t) override {
+        GGML_UNUSED(t);
+        return "ADD_RMS_NORM";
+    }
+
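+    // evaluate the whole graph at once instead of op-by-op so the backend
+    // gets the chance to apply its ADD + RMS_NORM fusion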
+    bool run_whole_graph() override { return true; }
+
+    std::string vars() override {
+        return VARS_TO_STR4(type, ne, eps, broadcast);
+    }
+
+    test_add_rms_norm(ggml_type type = GGML_TYPE_F32,
+            std::array<int64_t, 4> ne = {64, 5, 4, 3},
+            float eps = 1e-6f, bool broadcast = false)
+        : type(type), ne(ne), eps(eps), broadcast(broadcast) {}
+
+    ggml_tensor * build_graph(ggml_context * ctx) override {
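+        // in broadcast mode a gets larger dims; ggml_add then repeats b
+        // across a (each dim of a is an integer multiple of b's)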
+        std::array<int64_t, 4> broadcast_dims = {ne[0]*2, ne[1]*3, ne[2]*3, ne[3]*4};
+
+        ggml_tensor * a = ggml_new_tensor(ctx, type, 4, broadcast ? broadcast_dims.data() : ne.data());
+        ggml_tensor * b = ggml_new_tensor(ctx, type, 4, ne.data());
+
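+        // mark both inputs as parameters so their gradients are checked too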
+        ggml_set_param(a);
+        ggml_set_name(a, "a");
+        ggml_set_param(b);
+        ggml_set_name(b, "b");
+
+        // ADD operation followed by RMS_NORM
+        ggml_tensor * add_result = ggml_add(ctx, a, b);
+        ggml_set_name(add_result, "add_result");
+
+        ggml_tensor * out = ggml_rms_norm(ctx, add_result, eps);
+        ggml_set_name(out, "out");
+
+        return out;
+    }
+
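+    // uniform values in [-10, 10] keep mean((a + b)^2) well above eps, so the
+    // test exercises the normalization itself rather than the eps term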
+    void initialize_tensors(ggml_context * ctx) override {
+        for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
+            init_tensor_uniform(t, -10.f, 10.f);
+        }
+    }
+
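+    // grad_eps sets the step size for the numerical gradient; grad_precise
+    // selects the slower, higher-accuracy estimate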
+    float grad_eps() override {
+        return 1.0f;
+    }
+
+    bool grad_precise() override {
+        return true;
+    }
+};
+
 // GGML_OP_SSM_CONV
 struct test_ssm_conv : public test_case {
     const ggml_type type;
@@ -7393,11 +7452,14 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
         test_cases.emplace_back(new test_rms_norm_mul_add(GGML_TYPE_F32, {64, 5, 4, 3}, eps, true));
         test_cases.emplace_back(new test_norm_mul_add(GGML_TYPE_F32, {64, 5, 4, 3}, eps, false));
         test_cases.emplace_back(new test_norm_mul_add(GGML_TYPE_F32, {64, 5, 4, 3}, eps, true));
+        test_cases.emplace_back(new test_add_rms_norm(GGML_TYPE_F32, {64, 5, 4, 3}, eps, false));
+        test_cases.emplace_back(new test_add_rms_norm(GGML_TYPE_F32, {64, 5, 4, 3}, eps, true));
     }
     for (uint32_t n : {1, 511, 1025, 8192, 33*512}) {
         for (bool multi_add : {false, true}) {
             test_cases.emplace_back(new test_rms_norm_mul_add(GGML_TYPE_F32, {n, 1, 1, 1}, 1e-6f, false, multi_add));
         }
+        test_cases.emplace_back(new test_add_rms_norm(GGML_TYPE_F32, {n, 1, 1, 1}, 1e-6f, false));
     }
 
     for (auto multi_add : {false, true}) {