|
|
@@ -3069,12 +3069,14 @@ static struct ggml_tensor * ggml_scale_impl(
|
|
|
struct ggml_context * ctx,
|
|
|
struct ggml_tensor * a,
|
|
|
float s,
|
|
|
+ float b,
|
|
|
bool inplace) {
|
|
|
GGML_ASSERT(ggml_is_padded_1d(a));
|
|
|
|
|
|
struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
|
|
|
|
|
|
- ggml_set_op_params(result, &s, sizeof(s));
|
|
|
+ float params[2] = { s, b };
|
|
|
+ ggml_set_op_params(result, ¶ms, sizeof(params));
|
|
|
|
|
|
result->op = GGML_OP_SCALE;
|
|
|
result->src[0] = a;
|
|
|
@@ -3086,14 +3088,30 @@ struct ggml_tensor * ggml_scale(
|
|
|
struct ggml_context * ctx,
|
|
|
struct ggml_tensor * a,
|
|
|
float s) {
|
|
|
- return ggml_scale_impl(ctx, a, s, false);
|
|
|
+ return ggml_scale_impl(ctx, a, s, 0.0, false);
|
|
|
}
|
|
|
|
|
|
struct ggml_tensor * ggml_scale_inplace(
|
|
|
struct ggml_context * ctx,
|
|
|
struct ggml_tensor * a,
|
|
|
float s) {
|
|
|
- return ggml_scale_impl(ctx, a, s, true);
|
|
|
+ return ggml_scale_impl(ctx, a, s, 0.0, true);
|
|
|
+}
|
|
|
+
|
|
|
+struct ggml_tensor * ggml_scale_bias(
|
|
|
+ struct ggml_context * ctx,
|
|
|
+ struct ggml_tensor * a,
|
|
|
+ float s,
|
|
|
+ float b) {
|
|
|
+ return ggml_scale_impl(ctx, a, s, b, false);
|
|
|
+}
|
|
|
+
|
|
|
+struct ggml_tensor * ggml_scale_bias_inplace(
|
|
|
+ struct ggml_context * ctx,
|
|
|
+ struct ggml_tensor * a,
|
|
|
+ float s,
|
|
|
+ float b) {
|
|
|
+ return ggml_scale_impl(ctx, a, s, b, true);
|
|
|
}
|
|
|
|
|
|
// ggml_set
|
|
|
@@ -5777,7 +5795,7 @@ static void ggml_compute_backward(
|
|
|
} break;
|
|
|
case GGML_OP_MEAN: {
|
|
|
if (src0_needs_grads) {
|
|
|
- ggml_add1_or_set(ctx, cgraph, isrc0, ggml_scale_impl(ctx, grad, 1.0f/src0->ne[0], false));
|
|
|
+ ggml_add1_or_set(ctx, cgraph, isrc0, ggml_scale_impl(ctx, grad, 1.0f/src0->ne[0], 0.0, false));
|
|
|
}
|
|
|
} break;
|
|
|
case GGML_OP_REPEAT: {
|
|
|
@@ -5854,7 +5872,7 @@ static void ggml_compute_backward(
|
|
|
if (src0_needs_grads) {
|
|
|
float s;
|
|
|
memcpy(&s, tensor->op_params, sizeof(float));
|
|
|
- ggml_add_or_set(ctx, cgraph, isrc0, ggml_scale_impl(ctx, grad, s, false));
|
|
|
+ ggml_add_or_set(ctx, cgraph, isrc0, ggml_scale_impl(ctx, grad, s, 0.0, false));
|
|
|
}
|
|
|
} break;
|
|
|
case GGML_OP_SET: {
|