Forráskód Böngészése

metal : update support condition for im2col + fix warning (#0)

Georgi Gerganov 1 éve
szülő
commit
a876861455
2 módosított fájl, 4 hozzáadás és 3 törlés
  1. 2 1
      ggml/src/ggml-metal.m
  2. 2 2
      tests/test-backend-ops.cpp

+ 2 - 1
ggml/src/ggml-metal.m

@@ -799,8 +799,9 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_context * ctx
             return ctx->support_simdgroup_reduction;
         case GGML_OP_NORM:
         case GGML_OP_ROPE:
-        case GGML_OP_IM2COL:
             return true;
+        case GGML_OP_IM2COL:
+            return op->src[0]->type == GGML_TYPE_F16;
         case GGML_OP_POOL_1D:
         case GGML_OP_POOL_2D:
             return false;

+ 2 - 2
tests/test-backend-ops.cpp

@@ -24,6 +24,7 @@
 #include <cfloat>
 #include <cstdint>
 #include <cstring>
+#include <cinttypes>
 #include <functional>
 #include <memory>
 #include <random>
@@ -33,7 +34,6 @@
 #include <thread>
 #include <vector>
 
-
 static void init_tensor_uniform(ggml_tensor * tensor, float min = -1.0f, float max = 1.0f) {
     // static RNG initialization (revisit if n_threads stops being constant)
     static const size_t n_threads = std::thread::hardware_concurrency();
@@ -869,7 +869,7 @@ struct test_case {
             for (int64_t i = 0; i < ne; ++i) { // gradient algebraic
                 // check for nans
                 if (!std::isfinite(ga[i])) {
-                    printf("[%s] nonfinite gradient at index %zu (%s=%f) ", ggml_op_desc(t), i, bn, ga[i]);
+                    printf("[%s] nonfinite gradient at index %" PRId64 " (%s=%f) ", ggml_op_desc(t), i, bn, ga[i]);
                     ok = false;
                     break;
                 }