Piotr Wilkin, 3 months ago
Parent
Commit
64de434118
2 changed files with 13 additions and 13 deletions
  1. src/models/llm_build_mamba.cpp (+12 -12)
  2. src/models/llm_build_mamba.h (+1 -1)

+ 12 - 12
src/models/llm_build_mamba.cpp

@@ -1,11 +1,13 @@
-#include "../llama-model.h"
+
 #include "../llama-graph.h"
-#include "llm_graph_context_mamba.h"
+#include "../llama-model.h"
 
+#include "llm_graph_context_mamba.h"
 #include "llm_build_mamba.h"
+
 #include <cmath>
 
-llm_build_mamba::llm_build_mamba(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
+llm_build_mamba::llm_build_mamba(const llama_model & model, const llm_graph_params & params) : llm_graph_context_mamba(params) {
     ggml_tensor * cur;
     ggml_tensor * inpL;
 
@@ -18,22 +20,20 @@ llm_build_mamba::llm_build_mamba(const llama_model & model, const llm_graph_para
 
     for (int il = 0; il < n_layer; ++il) {
         // norm
-        cur = build_norm(inpL,
-                model.layers[il].attn_norm, NULL,
-                LLM_NORM_RMS, il);
+        cur = build_norm(inpL, model.layers[il].attn_norm, NULL, LLM_NORM_RMS, il);
         cb(cur, "attn_norm", il);
 
         if (model.arch == LLM_ARCH_MAMBA2) {
-            // TODO: implement mamba2_layer inline
-            // cur = build_mamba2_layer(rs_inp, cur, model, ubatch, il);
+            cur = build_mamba2_layer(rs_inp, cur, model, ubatch, il);
         } else {
-            // TODO: implement mamba_layer inline
-            // cur = build_mamba_layer(rs_inp, cur, model, ubatch, il);
+            cur = build_mamba_layer(rs_inp, cur, model, ubatch, il);
         }
+
         if (il == n_layer - 1 && inp_out_ids) {
-            cur  = ggml_get_rows(ctx0,  cur, inp_out_ids);
+            cur  = ggml_get_rows(ctx0, cur, inp_out_ids);
             inpL = ggml_get_rows(ctx0, inpL, inp_out_ids);
         }
+
         // residual
         cur = ggml_add(ctx0, cur, inpL);
 
@@ -43,7 +43,7 @@ llm_build_mamba::llm_build_mamba(const llama_model & model, const llm_graph_para
         // input for next layer
         inpL = cur;
     }
-;
+
     // final rmsnorm
     cur = build_norm(inpL, model.output_norm, NULL, LLM_NORM_RMS, -1);
 

+ 1 - 1
src/models/llm_build_mamba.h

@@ -5,6 +5,6 @@
 
 #include <cmath>
 
-struct llm_build_mamba : public llm_graph_context {
+struct llm_build_mamba : public llm_graph_context_mamba {
     llm_build_mamba(const llama_model & model, const llm_graph_params & params);
 };
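
Taken together, the two files switch llm_build_mamba from the generic llm_graph_context base to the Mamba-specific llm_graph_context_mamba, so the former TODO stubs become direct calls to shared build_mamba_layer / build_mamba2_layer helpers inherited from that base. Below is a minimal sketch of what llm_graph_context_mamba presumably declares, inferred only from the call sites in this diff; the actual header in the repository may differ, and in particular the type of rs_inp (written here as llm_graph_input_rs) is an assumption, since the diff never shows its declaration.

    // Hypothetical sketch of llm_graph_context_mamba.h, inferred from this diff.
    // Not the actual llama.cpp source; signatures reconstructed from the calls
    //   cur = build_mamba_layer(rs_inp, cur, model, ubatch, il);
    //   cur = build_mamba2_layer(rs_inp, cur, model, ubatch, il);
    #pragma once

    #include "../llama-graph.h"

    struct llm_graph_context_mamba : public llm_graph_context {
        // Delegates straight to the generic graph context, matching the
        // llm_graph_context_mamba(params) initializer used in the .cpp diff.
        llm_graph_context_mamba(const llm_graph_params & params) : llm_graph_context(params) {}

        // Shared state-space layer builders used by both MAMBA and MAMBA2 graphs.
        // rs_inp carries the recurrent-state input; its exact type is an assumption.
        ggml_tensor * build_mamba_layer (llm_graph_input_rs * rs_inp, ggml_tensor * cur,
                                         const llama_model & model, const llama_ubatch & ubatch, int il);
        ggml_tensor * build_mamba2_layer(llm_graph_input_rs * rs_inp, ggml_tensor * cur,
                                         const llama_model & model, const llama_ubatch & ubatch, int il);
    };

Hoisting the two builders into a shared base, rather than implementing them inline as the removed TODO comments suggested, lets other recurrent architectures reuse the same layer code while llm_build_mamba stays a thin per-architecture graph driver.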