Просмотр исходного кода

tests : sync test-grad0 from ggml

Georgi Gerganov 2 лет назад
Родитель
Сommit
65bdd52a86
1 измененных файлов с 20 добавлено и 0 удалено
  1. 20 0
      tests/test-grad0.c

+ 20 - 0
tests/test-grad0.c

@@ -1,3 +1,4 @@
+#define _CRT_SECURE_NO_DEPRECATE // Disables ridiculous "unsafe" warnigns on Windows
 #include "ggml.h"
 
 #include <math.h>
@@ -5,6 +6,10 @@
 #include <stdlib.h>
 #include <assert.h>
 
+#if defined(_MSC_VER)
+#pragma warning(disable: 4244 4267) // possible loss of data
+#endif
+
 #define MAX_NARGS 3
 
 #undef MIN
@@ -197,8 +202,23 @@ bool check_gradient(
         float max_error_abs,
         float max_error_rel) {
 
+    static int n_threads = -1;
+    if (n_threads < 0) {
+        n_threads = GGML_DEFAULT_N_THREADS;
+
+        const char *env = getenv("GGML_N_THREADS");
+        if (env) {
+            n_threads = atoi(env);
+        }
+
+        printf("GGML_N_THREADS = %d\n", n_threads);
+    }
+
     struct ggml_cgraph gf = ggml_build_forward (f);
+    gf.n_threads = n_threads;
+
     struct ggml_cgraph gb = ggml_build_backward(ctx0, &gf, false);
+    gb.n_threads = n_threads;
 
     ggml_graph_compute(ctx0, &gf);
     ggml_graph_reset  (&gf);