@@ -1,6 +1,10 @@
 #include "models.h"
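+#include <type_traits> // for std::conditional_t; may already be pulled in via models.h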
 
-llm_build_gemma3_iswa::llm_build_gemma3_iswa(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
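+// iswa selects the attention layout at compile time: true builds the graph with
+// interleaved sliding-window attention (iswa), false builds the regular variant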
+template <bool iswa>
+llm_build_gemma3<iswa>::llm_build_gemma3(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
     const int64_t n_embd_head = hparams.n_embd_head_k;
 
     ggml_tensor * cur;
@@ -17,13 +21,33 @@ llm_build_gemma3_iswa::llm_build_gemma3_iswa(const llama_model & model, const ll
     ggml_tensor * inp_pos = build_inp_pos();
 
     // TODO: is causal == true correct? might need some changes
-    auto * inp_attn = build_attn_inp_kv_iswa();
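+    // resolve the attention input type at compile time: iswa graphs read from the
+    // interleaved SWA KV cache, non-iswa graphs from the regular KV cache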
+    using inp_attn_type = std::conditional_t<iswa, llm_graph_input_attn_kv_iswa, llm_graph_input_attn_kv>;
+    inp_attn_type * inp_attn = nullptr;
+
+    if constexpr (iswa) {
+        inp_attn = build_attn_inp_kv_iswa();
+    } else {
+        inp_attn = build_attn_inp_kv();
+    }
 
     ggml_tensor * inp_out_ids = build_inp_out_ids();
 
     for (int il = 0; il < n_layer; ++il) {
-        const float freq_base_l  = model.get_rope_freq_base (cparams, il);
-        const float freq_scale_l = model.get_rope_freq_scale(cparams, il);
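+        // iswa models use per-layer RoPE parameters (SWA layers run with a
+        // different frequency base than full-attention layers); the non-iswa
+        // graph falls back to the global freq_base/freq_scale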
+        float freq_base_l  = 0.0f;
+        float freq_scale_l = 0.0f;
+
+        if constexpr (iswa) {
+            freq_base_l  = model.get_rope_freq_base (cparams, il);
+            freq_scale_l = model.get_rope_freq_scale(cparams, il);
+        } else {
+            freq_base_l  = freq_base;
+            freq_scale_l = freq_scale;
+        }
 
         // norm
         cur = build_norm(inpL, model.layers[il].attn_norm, NULL, LLM_NORM_RMS, il);
@@ -102,7 +126,8 @@ llm_build_gemma3_iswa::llm_build_gemma3_iswa(const llama_model & model, const ll
         cur = build_norm(cur,
                 model.layers[il].ffn_post_norm, NULL,
                 LLM_NORM_RMS, -1);
-        cb(cur, "ffn_post_norm", -1);
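+        // tag the tensor with its layer index (was -1) so debug callbacks can tell layers apart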
+        cb(cur, "ffn_post_norm", il);
 
         cur = ggml_add(ctx0, cur, sa_out);
 
@@ -124,8 +149,21 @@ llm_build_gemma3_iswa::llm_build_gemma3_iswa(const llama_model & model, const ll
     // lm_head
     cur = build_lora_mm(model.output, cur);
 
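+    // optional final logit soft-capping: logits = cap * tanh(logits / cap),
+    // which bounds the logits to (-cap, cap); a cap of 0.0f disables it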
+    if (hparams.f_final_logit_softcapping) {
+        cur = ggml_scale(ctx0, cur, 1.0f / hparams.f_final_logit_softcapping);
+        cur = ggml_tanh(ctx0, cur);
+        cur = ggml_scale(ctx0, cur, hparams.f_final_logit_softcapping);
+    }
+
     cb(cur, "result_output", -1);
     res->t_logits = cur;
 
     ggml_build_forward_expand(gf, cur);
 }
+
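+// explicit instantiations: the constructor is defined in this translation unit,
+// so both graph variants have to be instantiated here for the linker to find them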
+template struct llm_build_gemma3<false>;
+template struct llm_build_gemma3<true>;