|
@@ -15665,6 +15665,34 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
|
|
|
}
|
|
}
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
|
|
+static_assert(GGML_GRAPH_HASHTABLE_SIZE > GGML_MAX_NODES * 2, "GGML_GRAPH_HT_SIZE is too small");
|
|
|
|
|
+
|
|
|
|
|
+static size_t hash(void * p) {
|
|
|
|
|
+ return (size_t)p % GGML_GRAPH_HASHTABLE_SIZE;
|
|
|
|
|
+}
|
|
|
|
|
+
|
|
|
|
|
+static bool hash_insert(void * hash_table[], void * p) {
|
|
|
|
|
+ size_t h = hash(p);
|
|
|
|
|
+
|
|
|
|
|
+ // linear probing
|
|
|
|
|
+ size_t i = h;
|
|
|
|
|
+ while (hash_table[i] != NULL && hash_table[i] != p) {
|
|
|
|
|
+ i = (i + 1) % GGML_GRAPH_HASHTABLE_SIZE;
|
|
|
|
|
+ if (i == h) {
|
|
|
|
|
+ // hash table is full
|
|
|
|
|
+ GGML_ASSERT(false);
|
|
|
|
|
+ }
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ if (hash_table[i] == p) {
|
|
|
|
|
+ return true;
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ // insert
|
|
|
|
|
+ hash_table[i] = p;
|
|
|
|
|
+ return false;
|
|
|
|
|
+}
|
|
|
|
|
+
|
|
|
static void ggml_visit_parents(struct ggml_cgraph * cgraph, struct ggml_tensor * node) {
|
|
static void ggml_visit_parents(struct ggml_cgraph * cgraph, struct ggml_tensor * node) {
|
|
|
if (node->grad == NULL) {
|
|
if (node->grad == NULL) {
|
|
|
// this usually happens when we generate intermediate nodes from constants in the backward pass
|
|
// this usually happens when we generate intermediate nodes from constants in the backward pass
|
|
@@ -15675,16 +15703,8 @@ static void ggml_visit_parents(struct ggml_cgraph * cgraph, struct ggml_tensor *
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
// check if already visited
|
|
// check if already visited
|
|
|
- for (int i = 0; i < cgraph->n_nodes; i++) {
|
|
|
|
|
- if (cgraph->nodes[i] == node) {
|
|
|
|
|
- return;
|
|
|
|
|
- }
|
|
|
|
|
- }
|
|
|
|
|
-
|
|
|
|
|
- for (int i = 0; i < cgraph->n_leafs; i++) {
|
|
|
|
|
- if (cgraph->leafs[i] == node) {
|
|
|
|
|
- return;
|
|
|
|
|
- }
|
|
|
|
|
|
|
+ if (hash_insert(cgraph->visited_hash_table, node)) {
|
|
|
|
|
+ return;
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
for (int i = 0; i < GGML_MAX_SRC; ++i) {
|
|
for (int i = 0; i < GGML_MAX_SRC; ++i) {
|
|
@@ -15747,6 +15767,7 @@ struct ggml_cgraph ggml_build_forward(struct ggml_tensor * tensor) {
|
|
|
/*.nodes =*/ { NULL },
|
|
/*.nodes =*/ { NULL },
|
|
|
/*.grads =*/ { NULL },
|
|
/*.grads =*/ { NULL },
|
|
|
/*.leafs =*/ { NULL },
|
|
/*.leafs =*/ { NULL },
|
|
|
|
|
+ /*.hash_table =*/ { NULL },
|
|
|
/*.perf_runs =*/ 0,
|
|
/*.perf_runs =*/ 0,
|
|
|
/*.perf_cycles =*/ 0,
|
|
/*.perf_cycles =*/ 0,
|
|
|
/*.perf_time_us =*/ 0,
|
|
/*.perf_time_us =*/ 0,
|
|
@@ -15788,7 +15809,7 @@ struct ggml_cgraph ggml_build_backward(struct ggml_context * ctx, struct ggml_cg
|
|
|
|
|
|
|
|
if (node->is_param) {
|
|
if (node->is_param) {
|
|
|
GGML_PRINT_DEBUG("%s: found root node %p\n", __func__, (void *) node);
|
|
GGML_PRINT_DEBUG("%s: found root node %p\n", __func__, (void *) node);
|
|
|
- ggml_build_forward_impl(&result, node->grad, true);
|
|
|
|
|
|
|
+ ggml_build_forward_expand(&result, node->grad);
|
|
|
}
|
|
}
|
|
|
}
|
|
}
|
|
|
|
|
|