|
|
@@ -214,6 +214,7 @@ enum llm_arch {
|
|
|
LLM_ARCH_NEMOTRON,
|
|
|
LLM_ARCH_EXAONE,
|
|
|
LLM_ARCH_RWKV6,
|
|
|
+ LLM_ARCH_GRANITE,
|
|
|
LLM_ARCH_UNKNOWN,
|
|
|
};
|
|
|
|
|
|
@@ -264,6 +265,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
|
|
|
{ LLM_ARCH_NEMOTRON, "nemotron" },
|
|
|
{ LLM_ARCH_EXAONE, "exaone" },
|
|
|
{ LLM_ARCH_RWKV6, "rwkv6" },
|
|
|
+ { LLM_ARCH_GRANITE, "granite" },
|
|
|
{ LLM_ARCH_UNKNOWN, "(unknown)" },
|
|
|
};
|
|
|
|
|
|
@@ -303,6 +305,8 @@ enum llm_kv {
|
|
|
LLM_KV_RESCALE_EVERY_N_LAYERS,
|
|
|
LLM_KV_TIME_MIX_EXTRA_DIM,
|
|
|
LLM_KV_TIME_DECAY_EXTRA_DIM,
|
|
|
+ LLM_KV_RESIDUAL_SCALE,
|
|
|
+ LLM_KV_EMBEDDING_SCALE,
|
|
|
|
|
|
LLM_KV_ATTENTION_HEAD_COUNT,
|
|
|
LLM_KV_ATTENTION_HEAD_COUNT_KV,
|
|
|
@@ -317,6 +321,7 @@ enum llm_kv {
|
|
|
LLM_KV_ATTENTION_KV_LORA_RANK,
|
|
|
LLM_KV_ATTENTION_RELATIVE_BUCKETS_COUNT,
|
|
|
LLM_KV_ATTENTION_SLIDING_WINDOW,
|
|
|
+ LLM_KV_ATTENTION_SCALE,
|
|
|
|
|
|
LLM_KV_ROPE_DIMENSION_COUNT,
|
|
|
LLM_KV_ROPE_FREQ_BASE,
|
|
|
@@ -407,6 +412,8 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
|
|
|
{ LLM_KV_RESCALE_EVERY_N_LAYERS, "%s.rescale_every_n_layers" },
|
|
|
{ LLM_KV_TIME_MIX_EXTRA_DIM, "%s.time_mix_extra_dim" },
|
|
|
{ LLM_KV_TIME_DECAY_EXTRA_DIM, "%s.time_decay_extra_dim" },
|
|
|
+ { LLM_KV_RESIDUAL_SCALE, "%s.residual_scale" },
|
|
|
+ { LLM_KV_EMBEDDING_SCALE, "%s.embedding_scale" },
|
|
|
|
|
|
{ LLM_KV_ATTENTION_HEAD_COUNT, "%s.attention.head_count" },
|
|
|
{ LLM_KV_ATTENTION_HEAD_COUNT_KV, "%s.attention.head_count_kv" },
|
|
|
@@ -421,6 +428,7 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
|
|
|
{ LLM_KV_ATTENTION_KV_LORA_RANK, "%s.attention.kv_lora_rank" },
|
|
|
{ LLM_KV_ATTENTION_RELATIVE_BUCKETS_COUNT, "%s.attention.relative_buckets_count" },
|
|
|
{ LLM_KV_ATTENTION_SLIDING_WINDOW, "%s.attention.sliding_window" },
|
|
|
+ { LLM_KV_ATTENTION_SCALE, "%s.attention.scale" },
|
|
|
|
|
|
{ LLM_KV_ROPE_DIMENSION_COUNT, "%s.rope.dimension_count" },
|
|
|
{ LLM_KV_ROPE_FREQ_BASE, "%s.rope.freq_base" },
|
|
|
@@ -1454,6 +1462,22 @@ static const std::map<llm_arch, std::map<llm_tensor, std::string>> LLM_TENSOR_NA
|
|
|
{ LLM_TENSOR_CHANNEL_MIX_RECEPTANCE, "blk.%d.channel_mix_receptance" },
|
|
|
},
|
|
|
},
|
|
|
+ {
|
|
|
+ LLM_ARCH_GRANITE,
|
|
|
+ {
|
|
|
+ { LLM_TENSOR_TOKEN_EMBD, "token_embd" },
|
|
|
+ { LLM_TENSOR_OUTPUT_NORM, "output_norm" },
|
|
|
+ { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
|
|
|
+ { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
|
|
|
+ { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
|
|
|
+ { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
|
|
|
+ { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
|
|
|
+ { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
|
|
|
+ { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
|
|
|
+ { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
|
|
|
+ { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
|
|
|
+ },
|
|
|
+ },
|
|
|
{
|
|
|
LLM_ARCH_UNKNOWN,
|
|
|
{
|
|
|
@@ -2372,6 +2396,11 @@ struct llama_hparams {
|
|
|
float f_max_alibi_bias = 0.0f;
|
|
|
float f_logit_scale = 0.0f;
|
|
|
|
|
|
+ // Additional scale factors (Granite)
|
|
|
+ float f_residual_scale = 0.0f;
|
|
|
+ float f_embedding_scale = 0.0f;
|
|
|
+ float f_attention_scale = 0.0f;
|
|
|
+
|
|
|
bool causal_attn = true;
|
|
|
bool use_alibi = false;
|
|
|
bool attn_soft_cap = false;
|
|
|
@@ -2434,6 +2463,9 @@ struct llama_hparams {
|
|
|
if (!is_float_close(this->rope_freq_scale_train, other.rope_freq_scale_train, EPSILON)) return true;
|
|
|
if (!is_float_close(this->expert_weights_scale, other.expert_weights_scale, EPSILON)) return true;
|
|
|
if (!is_float_close(this->rope_yarn_log_mul, other.rope_yarn_log_mul, EPSILON)) return true;
|
|
|
+ if (!is_float_close(this->f_residual_scale, other.f_residual_scale, EPSILON)) return true;
|
|
|
+ if (!is_float_close(this->f_embedding_scale, other.f_embedding_scale, EPSILON)) return true;
|
|
|
+ if (!is_float_close(this->f_attention_scale, other.f_attention_scale, EPSILON)) return true;
|
|
|
|
|
|
return false;
|
|
|
}
|
|
|
@@ -6019,6 +6051,20 @@ static void llm_load_hparams(
|
|
|
default: model.type = e_model::MODEL_UNKNOWN;
|
|
|
}
|
|
|
} break;
|
|
|
+ case LLM_ARCH_GRANITE:
|
|
|
+ {
|
|
|
+ ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
|
|
|
+ ml.get_key(LLM_KV_LOGIT_SCALE, hparams.f_logit_scale);
|
|
|
+ ml.get_key(LLM_KV_RESIDUAL_SCALE, hparams.f_residual_scale);
|
|
|
+ ml.get_key(LLM_KV_EMBEDDING_SCALE, hparams.f_embedding_scale);
|
|
|
+ ml.get_key(LLM_KV_ATTENTION_SCALE, hparams.f_attention_scale);
|
|
|
+
|
|
|
+ switch (hparams.n_layer) {
|
|
|
+ case 40: model.type = e_model::MODEL_3B; break;
|
|
|
+ // Add additional layer/vocab/etc checks here for other model sizes
|
|
|
+ default: model.type = e_model::MODEL_UNKNOWN;
|
|
|
+ }
|
|
|
+ } break;
|
|
|
default: (void)0;
|
|
|
}
|
|
|
|
|
|
@@ -6717,6 +6763,12 @@ static void llm_load_print_meta(llama_model_loader & ml, llama_model & model) {
|
|
|
LLAMA_LOG_INFO("%s: n_ff_exp = %d\n", __func__, hparams.n_ff_exp);
|
|
|
LLAMA_LOG_INFO("%s: n_ff_shexp = %d\n", __func__, hparams.n_ff_shexp);
|
|
|
}
|
|
|
+
|
|
|
+ if (model.arch == LLM_ARCH_GRANITE) {
|
|
|
+ LLAMA_LOG_INFO("%s: f_embedding_scale = %f\n", __func__, hparams.f_embedding_scale);
|
|
|
+ LLAMA_LOG_INFO("%s: f_residual_scale = %f\n", __func__, hparams.f_residual_scale);
|
|
|
+ LLAMA_LOG_INFO("%s: f_attention_scale = %f\n", __func__, hparams.f_attention_scale);
|
|
|
+ }
|
|
|
}
|
|
|
|
|
|
// Returns false if cancelled by progress_callback
|
|
|
@@ -6885,6 +6937,7 @@ static bool llm_load_tensors(
|
|
|
case LLM_ARCH_LLAMA:
|
|
|
case LLM_ARCH_REFACT:
|
|
|
case LLM_ARCH_MINICPM:
|
|
|
+ case LLM_ARCH_GRANITE:
|
|
|
{
|
|
|
model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab});
|
|
|
|
|
|
@@ -8868,6 +8921,11 @@ static struct ggml_tensor * llm_build_inp_embd(
|
|
|
ggml_set_input(lctx.inp_embd);
|
|
|
}
|
|
|
|
|
|
+ // For Granite architecture
|
|
|
+ if (hparams.f_embedding_scale != 0.0f) {
|
|
|
+ inpL = ggml_scale(ctx, inpL, hparams.f_embedding_scale);
|
|
|
+ }
|
|
|
+
|
|
|
cb(inpL, "inp_embd", -1);
|
|
|
|
|
|
return inpL;
|
|
|
@@ -10146,6 +10204,7 @@ struct llm_build_context {
|
|
|
// KQ_mask (mask for 1 head, it will be broadcasted to all heads)
|
|
|
struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
|
|
|
|
|
|
+ const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f/sqrtf(float(n_embd_head)) : hparams.f_attention_scale;
|
|
|
for (int il = 0; il < n_layer; ++il) {
|
|
|
struct ggml_tensor * inpSA = inpL;
|
|
|
|
|
|
@@ -10198,7 +10257,7 @@ struct llm_build_context {
|
|
|
|
|
|
cur = llm_build_kv(ctx0, lctx, kv_self, gf,
|
|
|
model.layers[il].wo, model.layers[il].bo,
|
|
|
- Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
|
|
|
+ Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, kq_scale, cb, il);
|
|
|
}
|
|
|
|
|
|
if (il == n_layer - 1) {
|
|
|
@@ -10209,6 +10268,11 @@ struct llm_build_context {
|
|
|
inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
|
|
|
}
|
|
|
|
|
|
+ // For Granite architecture
|
|
|
+ if (hparams.f_residual_scale) {
|
|
|
+ cur = ggml_scale(ctx0, cur, hparams.f_residual_scale);
|
|
|
+ }
|
|
|
+
|
|
|
struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
|
|
|
cb(ffn_inp, "ffn_inp", il);
|
|
|
|
|
|
@@ -10245,6 +10309,11 @@ struct llm_build_context {
|
|
|
cb(cur, "ffn_moe_out", il);
|
|
|
}
|
|
|
|
|
|
+ // For Granite architecture
|
|
|
+ if (hparams.f_residual_scale) {
|
|
|
+ cur = ggml_scale(ctx0, cur, hparams.f_residual_scale);
|
|
|
+ }
|
|
|
+
|
|
|
cur = ggml_add(ctx0, cur, ffn_inp);
|
|
|
cb(cur, "ffn_out", il);
|
|
|
|
|
|
@@ -10264,6 +10333,12 @@ struct llm_build_context {
|
|
|
|
|
|
// lm_head
|
|
|
cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
|
|
|
+
|
|
|
+ // For Granite architecture
|
|
|
+ if (hparams.f_logit_scale) {
|
|
|
+ cur = ggml_scale(ctx0, cur, 1.0f / hparams.f_logit_scale);
|
|
|
+ }
|
|
|
+
|
|
|
cb(cur, "result_output", -1);
|
|
|
|
|
|
ggml_build_forward_expand(gf, cur);
|
|
|
@@ -15789,6 +15864,7 @@ static struct ggml_cgraph * llama_build_graph(
|
|
|
|
|
|
switch (model.arch) {
|
|
|
case LLM_ARCH_LLAMA:
|
|
|
+ case LLM_ARCH_GRANITE:
|
|
|
{
|
|
|
result = llm.build_llama();
|
|
|
} break;
|
|
|
@@ -19089,6 +19165,7 @@ enum llama_rope_type llama_rope_type(const struct llama_model * model) {
|
|
|
case LLM_ARCH_ARCTIC:
|
|
|
case LLM_ARCH_DEEPSEEK2:
|
|
|
case LLM_ARCH_CHATGLM:
|
|
|
+ case LLM_ARCH_GRANITE:
|
|
|
return LLAMA_ROPE_TYPE_NORM;
|
|
|
|
|
|
// the pairs of head values are offset by n_rot/2
|