|
@@ -5374,6 +5374,7 @@ static struct ggml_tensor * ggml_group_norm_impl(
|
|
|
struct ggml_context * ctx,
|
|
struct ggml_context * ctx,
|
|
|
struct ggml_tensor * a,
|
|
struct ggml_tensor * a,
|
|
|
int n_groups,
|
|
int n_groups,
|
|
|
|
|
+ float eps,
|
|
|
bool inplace) {
|
|
bool inplace) {
|
|
|
|
|
|
|
|
bool is_node = false;
|
|
bool is_node = false;
|
|
@@ -5384,7 +5385,8 @@ static struct ggml_tensor * ggml_group_norm_impl(
|
|
|
|
|
|
|
|
struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
|
|
struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
|
|
|
|
|
|
|
|
- result->op_params[0] = n_groups;
|
|
|
|
|
|
|
+ ggml_set_op_params_i32(result, 0, n_groups);
|
|
|
|
|
+ ggml_set_op_params_f32(result, 1, eps);
|
|
|
|
|
|
|
|
result->op = GGML_OP_GROUP_NORM;
|
|
result->op = GGML_OP_GROUP_NORM;
|
|
|
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
|
|
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
|
|
@@ -5396,15 +5398,17 @@ static struct ggml_tensor * ggml_group_norm_impl(
|
|
|
struct ggml_tensor * ggml_group_norm(
|
|
struct ggml_tensor * ggml_group_norm(
|
|
|
struct ggml_context * ctx,
|
|
struct ggml_context * ctx,
|
|
|
struct ggml_tensor * a,
|
|
struct ggml_tensor * a,
|
|
|
- int n_groups) {
|
|
|
|
|
- return ggml_group_norm_impl(ctx, a, n_groups, false);
|
|
|
|
|
|
|
+ int n_groups,
|
|
|
|
|
+ float eps) {
|
|
|
|
|
+ return ggml_group_norm_impl(ctx, a, n_groups, eps, false);
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
struct ggml_tensor * ggml_group_norm_inplace(
|
|
struct ggml_tensor * ggml_group_norm_inplace(
|
|
|
struct ggml_context * ctx,
|
|
struct ggml_context * ctx,
|
|
|
struct ggml_tensor * a,
|
|
struct ggml_tensor * a,
|
|
|
- int n_groups) {
|
|
|
|
|
- return ggml_group_norm_impl(ctx, a, n_groups, true);
|
|
|
|
|
|
|
+ int n_groups,
|
|
|
|
|
+ float eps) {
|
|
|
|
|
+ return ggml_group_norm_impl(ctx, a, n_groups, eps, true);
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
// ggml_mul_mat
|
|
// ggml_mul_mat
|
|
@@ -12095,10 +12099,11 @@ static void ggml_compute_forward_group_norm_f32(
|
|
|
|
|
|
|
|
GGML_TENSOR_UNARY_OP_LOCALS
|
|
GGML_TENSOR_UNARY_OP_LOCALS
|
|
|
|
|
|
|
|
- const float eps = 1e-6f; // TODO: make this a parameter
|
|
|
|
|
-
|
|
|
|
|
// TODO: optimize
|
|
// TODO: optimize
|
|
|
|
|
|
|
|
|
|
+ float eps;
|
|
|
|
|
+ memcpy(&eps, dst->op_params + 1, sizeof(float));
|
|
|
|
|
+
|
|
|
int n_channels = src0->ne[2];
|
|
int n_channels = src0->ne[2];
|
|
|
int n_groups = dst->op_params[0];
|
|
int n_groups = dst->op_params[0];
|
|
|
int n_channels_per_group = (n_channels + n_groups - 1) / n_groups;
|
|
int n_channels_per_group = (n_channels + n_groups - 1) / n_groups;
|