|
|
@@ -1142,6 +1142,26 @@ void llama_model::load_hparams(llama_model_loader & ml) {
|
|
|
default: type = LLM_TYPE_UNKNOWN;
|
|
|
}
|
|
|
} break;
|
|
|
+ case LLM_ARCH_GEMMA_EMBEDDING:
|
|
|
+ {
|
|
|
+ hparams.swa_type = LLAMA_SWA_TYPE_SYMMETRIC;
|
|
|
+ hparams.set_swa_pattern(6);
|
|
|
+
|
|
|
+ hparams.causal_attn = false; // embeddings do not use causal attention
|
|
|
+ hparams.rope_freq_base_train_swa = 10000.0f;
|
|
|
+ hparams.rope_freq_scale_train_swa = 1.0f;
|
|
|
+
|
|
|
+ ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa);
|
|
|
+ ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
|
|
|
+ ml.get_key(LLM_KV_POOLING_TYPE, hparams.pooling_type);
|
|
|
+
|
|
|
+ switch (hparams.n_layer) {
|
|
|
+ case 24: type = LLM_TYPE_0_3B; break;
|
|
|
+ default: type = LLM_TYPE_UNKNOWN;
|
|
|
+ }
|
|
|
+ hparams.f_attention_scale = 1.0f / std::sqrt(float(hparams.n_embd_head_k));
|
|
|
+
|
|
|
+ } break;
|
|
|
case LLM_ARCH_STARCODER2:
|
|
|
{
|
|
|
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps);
|
|
|
@@ -3484,6 +3504,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
|
|
|
}
|
|
|
} break;
|
|
|
case LLM_ARCH_GEMMA3:
|
|
|
+ case LLM_ARCH_GEMMA_EMBEDDING:
|
|
|
{
|
|
|
tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
|
|
|
|
|
|
@@ -11045,6 +11066,136 @@ struct llm_build_gemma3n_iswa : public llm_graph_context {
|
|
|
}
|
|
|
};
|
|
|
|
|
|
+struct llm_build_gemma_embedding_iswa : public llm_graph_context {
|
|
|
+ llm_build_gemma_embedding_iswa(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
|
|
|
+ const int64_t n_embd_head = hparams.n_embd_head_k;
|
|
|
+
|
|
|
+ ggml_tensor * cur;
|
|
|
+ ggml_tensor * inpL;
|
|
|
+
|
|
|
+ inpL = build_inp_embd(model.tok_embd);
|
|
|
+
|
|
|
+ // important: do not normalize weights for raw embeddings input (i.e. encoded image emdeddings)
|
|
|
+ if (ubatch.token) {
|
|
|
+ inpL = ggml_scale(ctx0, inpL, sqrtf(n_embd));
|
|
|
+ cb(inpL, "inp_scaled", -1);
|
|
|
+ }
|
|
|
+
|
|
|
+ // inp_pos - contains the positions
|
|
|
+ ggml_tensor * inp_pos = build_inp_pos();
|
|
|
+
|
|
|
+ auto * inp_attn = build_attn_inp_no_cache();
|
|
|
+
|
|
|
+ ggml_tensor * inp_out_ids = build_inp_out_ids();
|
|
|
+
|
|
|
+ for (int il = 0; il < n_layer; ++il) {
|
|
|
+ const float freq_base_l = model.get_rope_freq_base (cparams, il);
|
|
|
+ const float freq_scale_l = model.get_rope_freq_scale(cparams, il);
|
|
|
+
|
|
|
+ // norm
|
|
|
+ cur = build_norm(inpL, model.layers[il].attn_norm, NULL, LLM_NORM_RMS, il);
|
|
|
+ cb(cur, "attn_norm", il);
|
|
|
+
|
|
|
+ // self-attention
|
|
|
+ {
|
|
|
+ // compute Q and K and RoPE them
|
|
|
+ ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
|
|
|
+ cb(Qcur, "Qcur", il);
|
|
|
+
|
|
|
+ ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur);
|
|
|
+ cb(Kcur, "Kcur", il);
|
|
|
+
|
|
|
+ ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
|
|
|
+ cb(Vcur, "Vcur", il);
|
|
|
+
|
|
|
+ Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
|
|
|
+ Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
|
|
|
+ Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
|
|
|
+
|
|
|
+ Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, il);
|
|
|
+ cb(Qcur, "Qcur_normed", il);
|
|
|
+
|
|
|
+ Qcur = ggml_rope_ext(
|
|
|
+ ctx0, Qcur, inp_pos, nullptr,
|
|
|
+ n_rot, rope_type, n_ctx_orig, freq_base_l, freq_scale_l,
|
|
|
+ ext_factor, attn_factor, beta_fast, beta_slow);
|
|
|
+
|
|
|
+ Kcur = build_norm(Kcur, model.layers[il].attn_k_norm, NULL, LLM_NORM_RMS, il);
|
|
|
+ cb(Kcur, "Kcur_normed", il);
|
|
|
+
|
|
|
+ Kcur = ggml_rope_ext(
|
|
|
+ ctx0, Kcur, inp_pos, nullptr,
|
|
|
+ n_rot, rope_type, n_ctx_orig, freq_base_l, freq_scale_l,
|
|
|
+ ext_factor, attn_factor, beta_fast, beta_slow);
|
|
|
+
|
|
|
+ cb(Qcur, "Qcur", il);
|
|
|
+ cb(Kcur, "Kcur", il);
|
|
|
+ cb(Vcur, "Vcur", il);
|
|
|
+
|
|
|
+ // ref: https://github.com/google/gemma_pytorch/blob/014acb7ac4563a5f77c76d7ff98f31b568c16508/gemma/model.py#L315
|
|
|
+ Qcur = ggml_scale(ctx0, Qcur, hparams.f_attention_scale);
|
|
|
+
|
|
|
+ cur = build_attn(inp_attn,
|
|
|
+ model.layers[il].wo, NULL,
|
|
|
+ Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f, il);
|
|
|
+ }
|
|
|
+
|
|
|
+ if (il == n_layer - 1 && inp_out_ids) {
|
|
|
+ cur = ggml_get_rows(ctx0, cur, inp_out_ids);
|
|
|
+ inpL = ggml_get_rows(ctx0, inpL, inp_out_ids);
|
|
|
+ }
|
|
|
+
|
|
|
+ cur = build_norm(cur,
|
|
|
+ model.layers[il].attn_post_norm, NULL,
|
|
|
+ LLM_NORM_RMS, il);
|
|
|
+ cb(cur, "attn_post_norm", il);
|
|
|
+
|
|
|
+ ggml_tensor * sa_out = ggml_add(ctx0, cur, inpL);
|
|
|
+ cb(sa_out, "sa_out", il);
|
|
|
+
|
|
|
+ cur = build_norm(sa_out,
|
|
|
+ model.layers[il].ffn_norm, NULL,
|
|
|
+ LLM_NORM_RMS, il);
|
|
|
+ cb(cur, "ffn_norm", il);
|
|
|
+
|
|
|
+ // feed-forward network
|
|
|
+ {
|
|
|
+ cur = build_ffn(cur,
|
|
|
+ model.layers[il].ffn_up, NULL, NULL,
|
|
|
+ model.layers[il].ffn_gate, NULL, NULL,
|
|
|
+ model.layers[il].ffn_down, NULL, NULL,
|
|
|
+ NULL,
|
|
|
+ LLM_FFN_GELU, LLM_FFN_PAR, il);
|
|
|
+ cb(cur, "ffn_out", il);
|
|
|
+ }
|
|
|
+
|
|
|
+ cur = build_norm(cur,
|
|
|
+ model.layers[il].ffn_post_norm, NULL,
|
|
|
+ LLM_NORM_RMS, -1);
|
|
|
+ cb(cur, "ffn_post_norm", -1);
|
|
|
+
|
|
|
+ cur = ggml_add(ctx0, cur, sa_out);
|
|
|
+
|
|
|
+ cur = build_cvec(cur, il);
|
|
|
+ cb(cur, "l_out", il);
|
|
|
+
|
|
|
+ // input for next layer
|
|
|
+ inpL = cur;
|
|
|
+ }
|
|
|
+
|
|
|
+ cur = inpL;
|
|
|
+
|
|
|
+ cur = build_norm(cur,
|
|
|
+ model.output_norm, NULL,
|
|
|
+ LLM_NORM_RMS, -1);
|
|
|
+
|
|
|
+ cb(cur, "result_norm", -1);
|
|
|
+ res->t_embd = cur;
|
|
|
+
|
|
|
+ ggml_build_forward_expand(gf, cur);
|
|
|
+ }
|
|
|
+};
|
|
|
+
|
|
|
// TODO: move up next to build_starcoder
|
|
|
struct llm_build_starcoder2 : public llm_graph_context {
|
|
|
llm_build_starcoder2(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
|
|
|
@@ -18481,6 +18632,7 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
|
|
|
case LLM_ARCH_NOMIC_BERT_MOE:
|
|
|
case LLM_ARCH_NEO_BERT:
|
|
|
case LLM_ARCH_WAVTOKENIZER_DEC:
|
|
|
+ case LLM_ARCH_GEMMA_EMBEDDING:
|
|
|
case LLM_ARCH_DREAM:
|
|
|
case LLM_ARCH_LLADA:
|
|
|
{
|
|
|
@@ -18529,7 +18681,6 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
|
|
|
/* attn_kv_size */ cparams.n_ctx,
|
|
|
/* attn_n_pad */ padding,
|
|
|
/* attn_n_swa */ hparams.n_swa,
|
|
|
- /* attn_swa_type */ hparams.swa_type,
|
|
|
/* recurrent_type_k */ GGML_TYPE_F32,
|
|
|
/* recurrent_type_v */ GGML_TYPE_F32,
|
|
|
/* recurrent_kv_size */ std::max((uint32_t) 1, cparams.n_seq_max),
|
|
|
@@ -18599,7 +18750,6 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
|
|
|
cparams.n_seq_max,
|
|
|
padding,
|
|
|
hparams.n_swa,
|
|
|
- hparams.swa_type,
|
|
|
nullptr,
|
|
|
nullptr);
|
|
|
}
|
|
|
@@ -18761,6 +18911,10 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const {
|
|
|
{
|
|
|
llm = std::make_unique<llm_build_gemma3n_iswa>(*this, params);
|
|
|
} break;
|
|
|
+ case LLM_ARCH_GEMMA_EMBEDDING:
|
|
|
+ {
|
|
|
+ llm = std::make_unique<llm_build_gemma_embedding_iswa>(*this, params);
|
|
|
+ } break;
|
|
|
case LLM_ARCH_STARCODER2:
|
|
|
{
|
|
|
llm = std::make_unique<llm_build_starcoder2>(*this, params);
|
|
|
@@ -19161,6 +19315,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) {
|
|
|
case LLM_ARCH_GEMMA2:
|
|
|
case LLM_ARCH_GEMMA3:
|
|
|
case LLM_ARCH_GEMMA3N:
|
|
|
+ case LLM_ARCH_GEMMA_EMBEDDING:
|
|
|
case LLM_ARCH_STARCODER2:
|
|
|
case LLM_ARCH_OPENELM:
|
|
|
case LLM_ARCH_GPTNEOX:
|