|
|
@@ -1,3 +1,4 @@
|
|
|
+#define _CRT_SECURE_NO_DEPRECATE // Disables ridiculous "unsafe" warnigns on Windows
|
|
|
#include "ggml.h"
|
|
|
|
|
|
#include <math.h>
|
|
|
@@ -5,6 +6,10 @@
|
|
|
#include <stdlib.h>
|
|
|
#include <assert.h>
|
|
|
|
|
|
+#if defined(_MSC_VER)
|
|
|
+#pragma warning(disable: 4244 4267) // possible loss of data
|
|
|
+#endif
|
|
|
+
|
|
|
#define MAX_NARGS 3
|
|
|
|
|
|
#undef MIN
|
|
|
@@ -197,8 +202,23 @@ bool check_gradient(
|
|
|
float max_error_abs,
|
|
|
float max_error_rel) {
|
|
|
|
|
|
+ static int n_threads = -1;
|
|
|
+ if (n_threads < 0) {
|
|
|
+ n_threads = GGML_DEFAULT_N_THREADS;
|
|
|
+
|
|
|
+ const char *env = getenv("GGML_N_THREADS");
|
|
|
+ if (env) {
|
|
|
+ n_threads = atoi(env);
|
|
|
+ }
|
|
|
+
|
|
|
+ printf("GGML_N_THREADS = %d\n", n_threads);
|
|
|
+ }
|
|
|
+
|
|
|
struct ggml_cgraph gf = ggml_build_forward (f);
|
|
|
+ gf.n_threads = n_threads;
|
|
|
+
|
|
|
struct ggml_cgraph gb = ggml_build_backward(ctx0, &gf, false);
|
|
|
+ gb.n_threads = n_threads;
|
|
|
|
|
|
ggml_graph_compute(ctx0, &gf);
|
|
|
ggml_graph_reset (&gf);
|