Explorar o código

graph : fix geglu (#14077)

ggml-ci
Georgi Gerganov hai 7 meses
pai
achega
201b31dc2e
Modificáronse 1 ficheiros con 8 adicións e 16 borrados
  1. 8 16
      src/llama-graph.cpp

+ 8 - 16
src/llama-graph.cpp

@@ -663,22 +663,14 @@ ggml_tensor * llm_graph_context::build_ffn(
             {
                 // Split into two equal parts
                 int64_t split_point = cur->ne[0] / 2;
-                ggml_tensor * output_ffn_up = ggml_cont(ctx0, ggml_view_2d(
-                                                ctx0, cur, split_point,
-                                                cur->ne[1], cur->nb[1], 0
-                                            ));
-                ggml_tensor * output_ffn_gate = ggml_cont(ctx0, ggml_view_2d(
-                                                ctx0, cur, split_point,
-                                                cur->ne[1], cur->nb[1],
-                                                split_point * ggml_element_size(cur)
-                                            ));
-
-                // Apply GELU activation function to the first part
-                output_ffn_up = ggml_gelu(ctx0, output_ffn_up);
-                cb(output_ffn_up, "ffn_gelu", il);
-
-                // Element-wise multiplication between the activated part and the gate part
-                cur = ggml_mul(ctx0, output_ffn_up, output_ffn_gate);
+                // TODO: these conts should not be needed
+                ggml_tensor * x0 = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, split_point, cur->ne[1], cur->nb[1], 0));
+                ggml_tensor * x1 = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, split_point, cur->ne[1], cur->nb[1], split_point * ggml_element_size(cur)));
+
+                x0 = ggml_gelu(ctx0, x0);
+                cb(x0, "ffn_gelu", il);
+
+                cur = ggml_mul(ctx0, x0, x1);
                 cb(cur, "ffn_geglu", il);
             } break;
     }