@@ -2876,7 +2876,7 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx,
}
//if rms norm is the B operand, then we don't handle broadcast
- if (rms_norm == mul->src[1] && !ggml_are_same_shape(mul->src[0], rms_norm->src[1])) {
+ if (rms_norm == mul->src[1] && !ggml_are_same_shape(mul->src[0], rms_norm)) {
return false;
@@ -2686,7 +2686,7 @@ static bool ggml_opencl_can_fuse(const struct ggml_cgraph * cgraph, int node_idx
// if rms_norm is the B operand, then we don't handle broadcast
if (rms_norm == mul->src[1] &&
- !ggml_are_same_shape(mul->src[0], rms_norm->src[1])) {
+ !ggml_are_same_shape(mul->src[0], rms_norm)) {