فهرست منبع

CUDA + openCL: fix bug in accessing rms_norm->src while doing fusion (#16577)

Aman Gupta 3 ماه پیش
والد
کامیت
120bf7046d
2 فایلهای تغییر یافته به همراه 2 افزوده شده و 2 حذف شده
  1. 1 1
      ggml/src/ggml-cuda/ggml-cuda.cu
  2. 1 1
      ggml/src/ggml-opencl/ggml-opencl.cpp

+ 1 - 1
ggml/src/ggml-cuda/ggml-cuda.cu

@@ -2876,7 +2876,7 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx,
         }
 
         //if rms norm is the B operand, then we don't handle broadcast
-        if (rms_norm == mul->src[1] && !ggml_are_same_shape(mul->src[0], rms_norm->src[1])) {
+        if (rms_norm == mul->src[1] && !ggml_are_same_shape(mul->src[0], rms_norm)) {
             return false;
         }
 

+ 1 - 1
ggml/src/ggml-opencl/ggml-opencl.cpp

@@ -2686,7 +2686,7 @@ static bool ggml_opencl_can_fuse(const struct ggml_cgraph * cgraph, int node_idx
 
         // if rms_norm is the B operand, then we don't handle broadcast
         if (rms_norm == mul->src[1] &&
-            !ggml_are_same_shape(mul->src[0], rms_norm->src[1])) {
+            !ggml_are_same_shape(mul->src[0], rms_norm)) {
             return false;
         }