Sfoglia il codice sorgente

opencl: fix support ops condition for `rms_norm` (#15560)

lhez 4 mesi fa
parent
commit
f7207b0415
1 file modificato con 2 aggiunte e 1 eliminazione
  1. 2 1
      ggml/src/ggml-opencl/ggml-opencl.cpp

+ 2 - 1
ggml/src/ggml-opencl/ggml-opencl.cpp

@@ -2647,8 +2647,9 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te
             return op->src[0]->type == GGML_TYPE_F32;
             return op->src[0]->type == GGML_TYPE_F32;
         case GGML_OP_SOFT_MAX:
         case GGML_OP_SOFT_MAX:
         case GGML_OP_NORM:
         case GGML_OP_NORM:
-        case GGML_OP_RMS_NORM:
             return true;
             return true;
+        case GGML_OP_RMS_NORM:
+            return op->ne[0] % 4 == 0 && ggml_is_contiguous_rows(op->src[0]);
         case GGML_OP_REPEAT:
         case GGML_OP_REPEAT:
             return op->src[0]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32; // Assuming F32 for now, can be expanded
             return op->src[0]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32; // Assuming F32 for now, can be expanded
         case GGML_OP_PAD:
         case GGML_OP_PAD: