Просмотр исходного кода

model : add EXAONE MoE (#18543)

* Add EXAONE MoE implementations

Co-authored-by: Junwon Hwang <nuclear1221@gmail.com>

* Address PR feedback

* Address PR feedback

* [WIP] Add MTP for EXAONE-MoE

* Address PR feedback

* Address PR feedback

* Address PR feedback

* Address PR feedback

* Address PR feedback

* Address PR feedback

* Address PR feedback

---------

Co-authored-by: LG-AI-EXAONE <exaonemodels@lgresearch.ai>
Junwon Hwang 2 недель назад
Родитель
Сommit
60591f01d4

+ 115 - 0
common/chat-parser.cpp

@@ -1403,6 +1403,118 @@ static void common_chat_parse_solar_open(common_chat_msg_parser & builder) {
     builder.add_content(builder.consume_rest());
 }
 
+static void common_chat_parse_exaone_moe_content(common_chat_msg_parser & builder) {
+    // 1) <tool_call>{ "name": "...", "arguments": {...} }</tool_call>
+    // 2) <tool_call>{ "id": "...", "type": "function", "function": { "name": "...", "arguments": {...} } }</tool_call>
+    static const common_regex tool_call_open(R"(<tool_call[^>]*>)");
+
+    if (!builder.syntax().parse_tool_calls) {
+        LOG_DBG("%s: not parse_tool_calls\n", __func__);
+        builder.add_content(builder.consume_rest());
+        return;
+    }
+
+    LOG_DBG("%s: parse_tool_calls\n", __func__);
+
+    // Find all <tool_call></tool_call> blocks
+    while (auto first = builder.try_find_regex(tool_call_open, std::string::npos, /* add_prelude_to_content= */ true)) {
+        builder.move_to(first->groups[0].end);
+        builder.consume_spaces();
+
+        builder.try_consume_literal("```json");
+        builder.try_consume_literal("```");
+        builder.consume_spaces();
+
+        // Consume JSON object
+        auto data = builder.consume_json();
+
+        builder.consume_spaces();
+        builder.try_consume_literal("```");
+        builder.consume_spaces();
+
+        if (!builder.try_consume_literal("</tool_call>")) {
+            throw common_chat_msg_partial_exception("incomplete tool call");
+        }
+        builder.consume_spaces();
+
+        // Extract name and arguments
+        std::string name;
+        std::string id;
+        nlohmann::ordered_json arguments;
+
+        const auto extract_args = [&](const nlohmann::ordered_json & obj) -> bool {
+            if (!obj.contains("name") || !obj.contains("arguments")) {
+                return false;
+            }
+            name = obj.at("name").get<std::string>();
+            arguments = obj.at("arguments");
+            if (obj.contains("id") && obj.at("id").is_string()) {
+                id = obj.at("id").get<std::string>();
+            }
+            return true;
+        };
+
+        if (!extract_args(data.json)) {
+            if (data.json.contains("function") && data.json.at("function").is_object()) {
+                auto fn = data.json.at("function");
+                extract_args(fn);
+                if (id.empty() && data.json.contains("id") && data.json.at("id").is_string()) {
+                    id = data.json.at("id").get<std::string>();
+                }
+            }
+        }
+
+        // If name is empty, treat the JSON object as content
+        if (name.empty()) {
+            LOG_DBG("%s: tool call missing name, treating as content\n", __func__);
+            builder.add_content(data.json.dump());
+            continue;
+        }
+
+        std::string args_str = arguments.dump();
+        if (!builder.add_tool_call(name, id, args_str)) {
+            throw common_chat_msg_partial_exception("incomplete tool call");
+        }
+    }
+
+    builder.add_content(builder.consume_rest());
+}
+
+static void common_chat_parse_exaone_moe(common_chat_msg_parser & builder) {
+    LOG_DBG("%s: parsing exaone_moe\n", __func__);
+    // EXAONE MoE outputs reasoning content between "<think>" and "</think>" tags, followed by regular content
+    // First try to parse using the standard reasoning parsing method
+    LOG_DBG("%s: thinking_forced_open: %s\n", __func__, std::to_string(builder.syntax().thinking_forced_open).c_str());
+
+    auto start_pos = builder.pos();
+    auto found_end_think = builder.try_find_literal("</think>");
+    builder.move_to(start_pos);
+
+    if (builder.syntax().thinking_forced_open && !builder.is_partial() && !found_end_think) {
+        LOG_DBG("%s: no end_think, not partial, adding content\n", __func__);
+        common_chat_parse_exaone_moe_content(builder);
+    } else if (builder.try_parse_reasoning("<think>", "</think>")) {
+        // If reasoning was parsed successfully, the remaining content is regular content
+        LOG_DBG("%s: parsed reasoning, adding content\n", __func__);
+        common_chat_parse_exaone_moe_content(builder);
+    } else {
+        if (builder.syntax().reasoning_format == COMMON_REASONING_FORMAT_NONE) {
+          LOG_DBG("%s: reasoning_format none, adding content\n", __func__);
+          common_chat_parse_exaone_moe_content(builder);
+          return;
+        }
+        // If no reasoning tags found, check if we should treat everything as reasoning
+        if (builder.syntax().thinking_forced_open) {
+            // If thinking is forced open but no tags found, treat everything as reasoning
+            LOG_DBG("%s: thinking_forced_open, adding reasoning content\n", __func__);
+            builder.add_reasoning_content(builder.consume_rest());
+        } else {
+            LOG_DBG("%s: no thinking_forced_open, adding content\n", __func__);
+            common_chat_parse_exaone_moe_content(builder);
+        }
+    }
+}
+
 static void common_chat_parse_content_only(common_chat_msg_parser & builder) {
     builder.try_parse_reasoning("<think>", "</think>");
     builder.add_content(builder.consume_rest());
@@ -1490,6 +1602,9 @@ static void common_chat_parse(common_chat_msg_parser & builder) {
         case COMMON_CHAT_FORMAT_SOLAR_OPEN:
             common_chat_parse_solar_open(builder);
             break;
+        case COMMON_CHAT_FORMAT_EXAONE_MOE:
+            common_chat_parse_exaone_moe(builder);
+            break;
         default:
             throw std::runtime_error(std::string("Unsupported format: ") + common_chat_format_name(builder.syntax().format));
     }

+ 67 - 0
common/chat.cpp

@@ -670,6 +670,7 @@ const char * common_chat_format_name(common_chat_format format) {
         case COMMON_CHAT_FORMAT_APRIEL_1_5: return "Apriel 1.5";
         case COMMON_CHAT_FORMAT_XIAOMI_MIMO: return "Xiaomi MiMo";
         case COMMON_CHAT_FORMAT_SOLAR_OPEN: return "Solar Open";
+        case COMMON_CHAT_FORMAT_EXAONE_MOE: return "EXAONE MoE";
         case COMMON_CHAT_FORMAT_PEG_SIMPLE: return "peg-simple";
         case COMMON_CHAT_FORMAT_PEG_NATIVE: return "peg-native";
         case COMMON_CHAT_FORMAT_PEG_CONSTRUCTED: return "peg-constructed";
@@ -2539,6 +2540,65 @@ static common_chat_params common_chat_params_init_solar_open(const common_chat_t
     return data;
 }
 
+static common_chat_params common_chat_params_init_exaone_moe(const common_chat_template & tmpl, const struct templates_params & inputs) {
+    common_chat_params data;
+
+    data.prompt = apply(tmpl, inputs);
+    data.format = COMMON_CHAT_FORMAT_EXAONE_MOE;
+    if (string_ends_with(data.prompt, "<think>\n")) {
+        if (!inputs.enable_thinking) {
+            data.prompt += "</think>\n\n";
+        } else {
+            data.thinking_forced_open = true;
+        }
+    }
+
+    if (inputs.tools.is_array() && !inputs.tools.empty()) {
+        data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED && inputs.json_schema.is_null();
+        data.grammar = build_grammar([&](const common_grammar_builder & builder) {
+            std::vector<std::string> tool_rules;
+            foreach_function(inputs.tools, [&](const json & tool) {
+                const auto & function = tool.at("function");
+                std::string name = function.at("name");
+                auto parameters = function.at("parameters");
+                builder.resolve_refs(parameters);
+                // Expect: <tool_call>{"name": "<name>", "arguments": {...}}</tool_call>
+                tool_rules.push_back(builder.add_rule(
+                    name + "-call",
+                    "\"<tool_call>\" space " +
+                        builder.add_schema(name + "-obj", json{
+                            {"type", "object"},
+                            {"properties", {
+                                {"name",      json{{"const", name}}},
+                                {"arguments", parameters},
+                            }},
+                            {"required", json::array({"name", "arguments"})},
+                        }) +
+                    " space \"</tool_call>\" space"));
+            });
+
+            auto tool_call = builder.add_rule("tool_call", string_join(tool_rules, " | "));
+            builder.add_rule("root",
+                std::string(data.thinking_forced_open ? "( \"</think>\" space )? " : "") +
+                (inputs.parallel_tool_calls ? "(" + tool_call + ")+" : tool_call));
+
+            data.grammar_triggers.push_back({
+                COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_FULL,
+                std::string(data.thinking_forced_open ? "[\\s\\S]*?(</think>\\s*)?" : "") +
+                    "(<tool_call>)[\\s\\S]*"
+            });
+            data.preserved_tokens = {
+                "<think>",
+                "</think>",
+                "<tool_call>",
+                "</tool_call>",
+            };
+        });
+    }
+
+    return data;
+}
+
 static common_chat_params common_chat_params_init_without_tools(const common_chat_template & tmpl, const struct templates_params & inputs) {
     common_chat_params data;
     data.prompt = apply(tmpl, inputs);
@@ -2709,6 +2769,13 @@ static common_chat_params common_chat_templates_apply_jinja(
         return common_chat_params_init_xiaomi_mimo(tmpl, params);
     }
 
+    // EXAONE MoE format detection
+    if (src.find("<tool_call>") != std::string::npos &&
+        src.find("<tool_result>") != std::string::npos &&
+        src.find("<|tool_declare|>") != std::string::npos) {
+        return common_chat_params_init_exaone_moe(tmpl, params);
+    }
+
     // Hermes 2/3 Pro, Qwen 2.5 Instruct (w/ tools)
     if (src.find("<tool_call>") != std::string::npos && params.json_schema.is_null()) {
         return common_chat_params_init_hermes_2_pro(tmpl, params);

+ 1 - 0
common/chat.h

@@ -125,6 +125,7 @@ enum common_chat_format {
     COMMON_CHAT_FORMAT_APRIEL_1_5,
     COMMON_CHAT_FORMAT_XIAOMI_MIMO,
     COMMON_CHAT_FORMAT_SOLAR_OPEN,
+    COMMON_CHAT_FORMAT_EXAONE_MOE,
 
     // These are intended to be parsed by the PEG parser
     COMMON_CHAT_FORMAT_PEG_SIMPLE,

+ 103 - 0
convert_hf_to_gguf.py

@@ -1252,6 +1252,9 @@ class TextModel(ModelBase):
         if chkhsh == "16389f0a1f51ee53e562ffd51c371dc508639ab0e4261502071836e50e223e91":
             # ref: https://huggingface.co/upstage/Solar-Open-100B
             res = "solar-open"
+        if chkhsh == "6c81ce329e0802883b22eabab0d3fa48357337ef1ecb45443828bf1f6254833f":
+            # ref: https://huggingface.co/LGAI-EXAONE/K-EXAONE-236B-A23B
+            res = "exaone-moe"
 
         if res is None:
             logger.warning("\n")
@@ -8748,6 +8751,106 @@ class Exaone4Model(TextModel):
                 yield (self.format_tensor_name(gguf.MODEL_TENSOR.ROPE_FREQS), torch.tensor(rope_factors, dtype=torch.float32))
 
 
+@ModelBase.register("ExaoneMoEForCausalLM")
+class ExaoneMoEModel(Exaone4Model):
+    model_arch = gguf.MODEL_ARCH.EXAONE_MOE
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.block_count = self.hparams["num_hidden_layers"] + self.hparams.get("num_nextn_predict_layers", 0)
+        self.tensor_map = gguf.get_tensor_name_map(self.model_arch, self.block_count)
+
+    def set_gguf_parameters(self):
+        super().set_gguf_parameters()
+        self.gguf_writer.add_expert_count(self.hparams["num_experts"])
+        moe_intermediate_size = self.hparams["moe_intermediate_size"]
+        num_shared_experts = self.hparams["num_shared_experts"]
+        self.gguf_writer.add_expert_feed_forward_length(moe_intermediate_size)
+        self.gguf_writer.add_expert_shared_count(num_shared_experts)
+        self.gguf_writer.add_expert_shared_feed_forward_length(moe_intermediate_size * num_shared_experts)
+        self.gguf_writer.add_expert_weights_scale(self.hparams["routed_scaling_factor"])
+        self.gguf_writer.add_expert_weights_norm(self.hparams["norm_topk_prob"])
+        n_dense_layer = self.hparams.get("first_k_dense_replace", self.hparams.get("first_last_k_dense_replace", 0))
+        self.gguf_writer.add_leading_dense_block_count(n_dense_layer)
+        # For here, we hard-code the number of NextN/MTP layers to 1 for K-EXAONE,
+        # so that we can convert MTP weights to GGUF format for speculative decoding.
+        # This is because HF config of K-EXAONE does not have `num_nextn_predict_layers` at now.
+        # Will be updated when HF config is updated.
+        self.gguf_writer.add_nextn_predict_layers(self.hparams.get("num_nextn_predict_layers", 1))
+
+        self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.NONE)
+
+    _experts: list[dict[str, Tensor]] | None = None
+
+    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
+        if name.startswith("mtp."):
+            if name.find("layers.") != -1:
+                # `mtp.layers.0.[module_name]` format
+                name = name.replace(f"mtp.layers.{bid}", f"model.layers.{bid + self.hparams['num_hidden_layers']}")
+            else:
+                # mtp fc/norm weights
+                remapper = {
+                    "mtp.fc": "model.layers.{bid}.eh_proj",
+                    "mtp.pre_fc_norm_embedding": "model.layers.{bid}.enorm",
+                    "mtp.pre_fc_norm_hidden": "model.layers.{bid}.hnorm",
+                    "mtp.norm": "model.layers.{bid}.shared_head.norm",
+                }
+                _n = Path(name)
+                new_name = remapper[_n.stem] + _n.suffix
+
+                # set shared weights for all NextN/MTP layers
+                tensors = []
+                for bid in range(self.hparams['num_hidden_layers'], self.block_count):
+                    new_name = new_name.format(bid=bid)
+                    tensors.append((self.map_tensor_name(new_name), data_torch))
+                return tensors
+
+        if name.endswith("e_score_correction_bias"):
+            name = name.replace("e_score_correction_bias", "e_score_correction.bias")
+
+        if name.find("mlp.experts") != -1:
+            n_experts = self.hparams["num_experts"]
+            assert bid is not None
+
+            if self._experts is None:
+                self._experts = [{} for _ in range(self.block_count)]
+
+            self._experts[bid][name] = data_torch
+
+            if len(self._experts[bid]) >= n_experts * 3:
+                tensors: list[tuple[str, Tensor]] = []
+
+                # merge the experts into a single 3d tensor
+                for w_name in ["down_proj", "gate_proj", "up_proj"]:
+                    datas: list[Tensor] = []
+
+                    for xid in range(n_experts):
+                        ename = f"model.layers.{bid}.mlp.experts.{xid}.{w_name}.weight"
+                        datas.append(self._experts[bid][ename])
+                        del self._experts[bid][ename]
+
+                    data_torch = torch.stack(datas, dim=0)
+
+                    merged_name = f"model.layers.{bid}.mlp.experts.{w_name}.weight"
+
+                    new_name = self.map_tensor_name(merged_name)
+
+                    tensors.append((new_name, data_torch))
+                return tensors
+            else:
+                return []
+
+        return [(self.map_tensor_name(name), data_torch)]
+
+    def prepare_tensors(self):
+        super().prepare_tensors()
+        if self._experts is not None:
+            # flatten `list[dict[str, Tensor]]` into `list[str]`
+            experts = [k for d in self._experts for k in d.keys()]
+            if len(experts) > 0:
+                raise ValueError(f"Unprocessed experts: {experts}")
+
+
 @ModelBase.register("GraniteForCausalLM")
 class GraniteModel(LlamaModel):
     """Conversion for IBM's GraniteForCausalLM"""

+ 1 - 0
convert_hf_to_gguf_update.py

@@ -147,6 +147,7 @@ models = [
     {"name": "kormo",            "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/KORMo-Team/KORMo-tokenizer", },
     {"name": "youtu",            "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/tencent/Youtu-LLM-2B", },
     {"name": "solar-open",       "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/upstage/Solar-Open-100B", },
+    {"name": "exaone-moe",       "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/LGAI-EXAONE/K-EXAONE-236B-A23B", },
 ]
 
 # some models are known to be broken upstream, so we will skip them as exceptions

+ 34 - 0
gguf-py/gguf/constants.py

@@ -424,6 +424,7 @@ class MODEL_ARCH(IntEnum):
     NEMOTRON_H_MOE   = auto()
     EXAONE           = auto()
     EXAONE4          = auto()
+    EXAONE_MOE       = auto()
     GRANITE          = auto()
     GRANITE_MOE      = auto()
     GRANITE_HYBRID   = auto()
@@ -843,6 +844,7 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
     MODEL_ARCH.NEMOTRON_H_MOE:   "nemotron_h_moe",
     MODEL_ARCH.EXAONE:           "exaone",
     MODEL_ARCH.EXAONE4:          "exaone4",
+    MODEL_ARCH.EXAONE_MOE:       "exaone-moe",
     MODEL_ARCH.GRANITE:          "granite",
     MODEL_ARCH.GRANITE_MOE:      "granitemoe",
     MODEL_ARCH.GRANITE_HYBRID:   "granitehybrid",
@@ -2754,6 +2756,38 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
         MODEL_TENSOR.FFN_UP,
         MODEL_TENSOR.FFN_POST_NORM,
     ],
+    MODEL_ARCH.EXAONE_MOE: [
+        MODEL_TENSOR.TOKEN_EMBD,
+        MODEL_TENSOR.OUTPUT_NORM,
+        MODEL_TENSOR.OUTPUT,
+        MODEL_TENSOR.ROPE_FREQS,
+        MODEL_TENSOR.ATTN_NORM,
+        MODEL_TENSOR.ATTN_Q,
+        MODEL_TENSOR.ATTN_Q_NORM,
+        MODEL_TENSOR.ATTN_K,
+        MODEL_TENSOR.ATTN_K_NORM,
+        MODEL_TENSOR.ATTN_V,
+        MODEL_TENSOR.ATTN_OUT,
+        MODEL_TENSOR.FFN_NORM,
+        MODEL_TENSOR.FFN_GATE,
+        MODEL_TENSOR.FFN_DOWN,
+        MODEL_TENSOR.FFN_UP,
+        MODEL_TENSOR.FFN_GATE_INP,
+        MODEL_TENSOR.FFN_GATE_EXP,
+        MODEL_TENSOR.FFN_DOWN_EXP,
+        MODEL_TENSOR.FFN_UP_EXP,
+        MODEL_TENSOR.FFN_GATE_SHEXP,
+        MODEL_TENSOR.FFN_DOWN_SHEXP,
+        MODEL_TENSOR.FFN_UP_SHEXP,
+        MODEL_TENSOR.FFN_EXP_PROBS_B,
+        # NextN/MTP tensors - preserved but unused
+        MODEL_TENSOR.NEXTN_EH_PROJ,
+        MODEL_TENSOR.NEXTN_EMBED_TOKENS,
+        MODEL_TENSOR.NEXTN_ENORM,
+        MODEL_TENSOR.NEXTN_HNORM,
+        MODEL_TENSOR.NEXTN_SHARED_HEAD_HEAD,
+        MODEL_TENSOR.NEXTN_SHARED_HEAD_NORM,
+    ],
     MODEL_ARCH.GRANITE: [
         MODEL_TENSOR.TOKEN_EMBD,
         MODEL_TENSOR.OUTPUT_NORM,

+ 3 - 2
gguf-py/gguf/tensor_mapping.py

@@ -436,7 +436,8 @@ class TensorNameMap:
             "model.layers.{bid}.mlp.expert_bias",                           # afmoe
             "model.layers.{bid}.feed_forward.expert_bias",                  # lfm2moe
             "model.layers.{bid}.block_sparse_moe.e_score_correction",       # minimax-m2
-            "backbone.layers.{bid}.mixer.gate.e_score_correction"           # nemotron-h-moe
+            "backbone.layers.{bid}.mixer.gate.e_score_correction",          # nemotron-h-moe
+            "model.layers.{bid}.mlp.e_score_correction",                    # exaone-moe
         ),
 
         # Feed-forward up
@@ -1794,7 +1795,7 @@ class TensorNameMap:
             "model.embed_audio.soft_embedding_norm", # gemma3n
         ),
 
-        # NextN/MTP tensors for GLM4_MOE
+        # NextN/MTP tensors
         MODEL_TENSOR.NEXTN_EH_PROJ: (
             "model.layers.{bid}.eh_proj",
         ),

+ 1 - 0
src/CMakeLists.txt

@@ -62,6 +62,7 @@ add_library(llama
             models/ernie4-5.cpp
             models/exaone.cpp
             models/exaone4.cpp
+            models/exaone-moe.cpp
             models/falcon-h1.cpp
             models/falcon.cpp
             models/gemma-embedding.cpp

+ 33 - 0
src/llama-arch.cpp

@@ -81,6 +81,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
     { LLM_ARCH_NEMOTRON_H_MOE,   "nemotron_h_moe"   },
     { LLM_ARCH_EXAONE,           "exaone"           },
     { LLM_ARCH_EXAONE4,          "exaone4"          },
+    { LLM_ARCH_EXAONE_MOE,       "exaone-moe"       },
     { LLM_ARCH_RWKV6,            "rwkv6"            },
     { LLM_ARCH_RWKV6QWEN2,       "rwkv6qwen2"       },
     { LLM_ARCH_RWKV7,            "rwkv7"            },
@@ -1728,6 +1729,38 @@ static std::set<llm_tensor> llm_get_tensor_names(llm_arch arch) {
                 LLM_TENSOR_FFN_UP,
                 LLM_TENSOR_FFN_POST_NORM,
             };
+        case LLM_ARCH_EXAONE_MOE:
+            return {
+                LLM_TENSOR_TOKEN_EMBD,
+                LLM_TENSOR_OUTPUT_NORM,
+                LLM_TENSOR_OUTPUT,
+                LLM_TENSOR_ROPE_FREQS,
+                LLM_TENSOR_ATTN_NORM,
+                LLM_TENSOR_ATTN_Q,
+                LLM_TENSOR_ATTN_Q_NORM,
+                LLM_TENSOR_ATTN_K,
+                LLM_TENSOR_ATTN_K_NORM,
+                LLM_TENSOR_ATTN_V,
+                LLM_TENSOR_ATTN_OUT,
+                LLM_TENSOR_FFN_NORM,
+                LLM_TENSOR_FFN_GATE,
+                LLM_TENSOR_FFN_DOWN,
+                LLM_TENSOR_FFN_UP,
+                LLM_TENSOR_FFN_GATE_INP,
+                LLM_TENSOR_FFN_GATE_EXPS,
+                LLM_TENSOR_FFN_DOWN_EXPS,
+                LLM_TENSOR_FFN_UP_EXPS,
+                LLM_TENSOR_FFN_GATE_SHEXP,
+                LLM_TENSOR_FFN_UP_SHEXP,
+                LLM_TENSOR_FFN_DOWN_SHEXP,
+                LLM_TENSOR_FFN_EXP_PROBS_B,
+                LLM_TENSOR_NEXTN_EH_PROJ,
+                LLM_TENSOR_NEXTN_EMBED_TOKENS,
+                LLM_TENSOR_NEXTN_ENORM,
+                LLM_TENSOR_NEXTN_HNORM,
+                LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD,
+                LLM_TENSOR_NEXTN_SHARED_HEAD_NORM,
+            };
         case LLM_ARCH_RWKV6:
             return {
                 LLM_TENSOR_TOKEN_EMBD,

+ 1 - 0
src/llama-arch.h

@@ -85,6 +85,7 @@ enum llm_arch {
     LLM_ARCH_NEMOTRON_H_MOE,
     LLM_ARCH_EXAONE,
     LLM_ARCH_EXAONE4,
+    LLM_ARCH_EXAONE_MOE,
     LLM_ARCH_RWKV6,
     LLM_ARCH_RWKV6QWEN2,
     LLM_ARCH_RWKV7,

+ 20 - 0
src/llama-chat.cpp

@@ -57,6 +57,7 @@ static const std::map<std::string, llm_chat_template> LLM_CHAT_TEMPLATES = {
     { "minicpm",           LLM_CHAT_TEMPLATE_MINICPM           },
     { "exaone3",           LLM_CHAT_TEMPLATE_EXAONE_3          },
     { "exaone4",           LLM_CHAT_TEMPLATE_EXAONE_4          },
+    { "exaone-moe",        LLM_CHAT_TEMPLATE_EXAONE_MOE        },
     { "rwkv-world",        LLM_CHAT_TEMPLATE_RWKV_WORLD        },
     { "granite",           LLM_CHAT_TEMPLATE_GRANITE           },
     { "gigachat",          LLM_CHAT_TEMPLATE_GIGACHAT          },
@@ -137,6 +138,9 @@ llm_chat_template llm_chat_detect_template(const std::string & tmpl) {
     } else if (tmpl_contains("[gMASK]<sop>")) {
         return LLM_CHAT_TEMPLATE_CHATGLM_4;
     } else if (tmpl_contains("<|assistant|>") && tmpl_contains("<|user|>")) {
+        if (tmpl_contains("<|tool_declare|>")) {
+            return LLM_CHAT_TEMPLATE_EXAONE_MOE;
+        }
         return tmpl_contains("</s>") ? LLM_CHAT_TEMPLATE_FALCON_3 : LLM_CHAT_TEMPLATE_GLMEDGE;
     } else if (tmpl_contains("<|{{ item['role'] }}|>") && tmpl_contains("<|begin_of_image|>")) {
         return LLM_CHAT_TEMPLATE_GLMEDGE;
@@ -576,6 +580,22 @@ int32_t llm_chat_apply_template(
         if (add_ass) {
             ss << "[|assistant|]";
         }
+    } else if (tmpl == LLM_CHAT_TEMPLATE_EXAONE_MOE) {
+        for (auto message : chat) {
+            std::string role(message->role);
+            if (role == "system") {
+                ss << "<|system|>\n" << trim(message->content) << "<|endofturn|>\n";
+            } else if (role == "user") {
+                ss << "<|user|>\n" << trim(message->content) << "<|endofturn|>\n";
+            } else if (role == "assistant") {
+                ss << "<|assistant|>\n" << trim(message->content) << "<|endofturn|>\n";
+            } else if (role == "tool") {
+                ss << "<|tool|>\n" << trim(message->content) << "<|endofturn|>\n";
+            }
+        }
+        if (add_ass) {
+            ss << "<|assistant|>\n";
+        }
     } else if (tmpl == LLM_CHAT_TEMPLATE_RWKV_WORLD) {
         // this template requires the model to have "\n\n" as EOT token
         for (size_t i = 0; i < chat.size(); i++) {

+ 1 - 0
src/llama-chat.h

@@ -36,6 +36,7 @@ enum llm_chat_template {
     LLM_CHAT_TEMPLATE_MINICPM,
     LLM_CHAT_TEMPLATE_EXAONE_3,
     LLM_CHAT_TEMPLATE_EXAONE_4,
+    LLM_CHAT_TEMPLATE_EXAONE_MOE,
     LLM_CHAT_TEMPLATE_RWKV_WORLD,
     LLM_CHAT_TEMPLATE_GRANITE,
     LLM_CHAT_TEMPLATE_GIGACHAT,

+ 115 - 0
src/llama-model.cpp

@@ -1933,6 +1933,38 @@ void llama_model::load_hparams(llama_model_loader & ml) {
                     default: type = LLM_TYPE_UNKNOWN;
                 }
             } break;
+        case LLM_ARCH_EXAONE_MOE:
+            {
+                hparams.swa_type = LLAMA_SWA_TYPE_STANDARD;
+                hparams.n_swa = 128;
+                hparams.set_swa_pattern(4);
+                hparams.rope_freq_base_train_swa  = hparams.rope_freq_base_train;
+                hparams.rope_freq_scale_train_swa = hparams.rope_freq_scale_train;
+
+                ml.get_key(LLM_KV_ROPE_FREQ_BASE_SWA,                hparams.rope_freq_base_train_swa, false);
+                ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW,          hparams.n_swa, true);
+                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS,       hparams.f_norm_rms_eps);
+                ml.get_key(LLM_KV_EXPERT_COUNT,                      hparams.n_expert);
+                ml.get_key(LLM_KV_EXPERT_USED_COUNT,                 hparams.n_expert_used);
+                ml.get_key(LLM_KV_EXPERT_SHARED_COUNT,               hparams.n_expert_shared, false);
+                ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH,        hparams.n_ff_exp);
+                ml.get_key(LLM_KV_EXPERT_SHARED_FEED_FORWARD_LENGTH, hparams.n_ff_shexp, false);
+                ml.get_key(LLM_KV_EXPERT_GROUP_COUNT,                hparams.n_expert_groups, false);
+                ml.get_key(LLM_KV_EXPERT_GROUP_USED_COUNT,           hparams.n_group_used, false);
+                ml.get_key(LLM_KV_EXPERT_GATING_FUNC,                hparams.expert_gating_func, false);
+                ml.get_key(LLM_KV_EXPERT_WEIGHTS_SCALE,              hparams.expert_weights_scale, false);
+                ml.get_key(LLM_KV_EXPERT_WEIGHTS_NORM,               hparams.expert_weights_norm, false);
+                ml.get_key(LLM_KV_LEADING_DENSE_BLOCK_COUNT,         hparams.n_layer_dense_lead);
+
+                ml.get_key(LLM_KV_NEXTN_PREDICT_LAYERS,              hparams.nextn_predict_layers, false);
+
+                switch (hparams.n_layer) {
+                    case 32: type = LLM_TYPE_30B_A3B; break;
+                    case 48:
+                    case 49: type = LLM_TYPE_235B_A22B; break;
+                    default: type = LLM_TYPE_UNKNOWN;
+                }
+            } break;
         case LLM_ARCH_RWKV6:
         case LLM_ARCH_RWKV6QWEN2:
             {
@@ -5516,6 +5548,84 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
                         layer.ffn_post_norm  = create_tensor(tn(LLM_TENSOR_FFN_POST_NORM, "weight", i), {n_embd}, 0);
                     }
                 } break;
+            case LLM_ARCH_EXAONE_MOE:
+                {
+                    const int64_t n_ff_exp       = hparams.n_ff_exp;
+                    const int64_t n_expert       = hparams.n_expert;
+                    const int64_t n_expert_used  = hparams.n_expert_used;
+                    const int64_t n_ff_shexp     = hparams.n_ff_shexp;
+                    const int64_t head_dim       = hparams.n_embd_head_k;
+                    const int64_t n_qo_dim       = n_head * head_dim;
+                    const int64_t n_kv_dim       = n_head_kv * head_dim;
+
+                    tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
+
+                    // output
+                    output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
+                    output      = create_tensor(tn(LLM_TENSOR_OUTPUT,      "weight"), {n_embd, n_vocab}, 0);
+
+                    if (output == NULL) {
+                        output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED);
+                    }
+
+                    for (int i = 0; i < n_layer; ++i) {
+                        int flags = 0;
+                        if (hparams.nextn_predict_layers > 0 && static_cast<uint32_t>(i) >= n_layer - hparams.nextn_predict_layers) {
+                            // skip all tensors in the NextN layers
+                            flags |= TENSOR_SKIP;
+                        }
+
+                        auto & layer = layers[i];
+                        layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q,   "weight", i), {n_embd, n_qo_dim}, flags);
+                        layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K,   "weight", i), {n_embd, n_kv_dim}, flags);
+                        layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V,   "weight", i), {n_embd, n_kv_dim}, flags);
+                        layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_qo_dim, n_embd}, flags);
+
+                        layer.rope_freqs   = create_tensor(tn(LLM_TENSOR_ROPE_FREQS,  "weight", i), {n_rot/2}, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0) | flags);
+
+                        layer.attn_norm    = create_tensor(tn(LLM_TENSOR_ATTN_NORM,   "weight", i), {n_embd}, flags);
+                        layer.attn_q_norm  = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k}, flags);
+                        layer.attn_k_norm  = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k}, flags);
+
+                        layer.ffn_norm     = create_tensor(tn(LLM_TENSOR_FFN_NORM,    "weight", i), {n_embd}, flags);
+
+                        // dense layers for first n_layer_dense_lead layers or nextn_predict_layers layers at the end
+                        if (i < (int) hparams.n_layer_dense_lead || (hparams.nextn_predict_layers > 0 && static_cast<uint32_t>(i) >= n_layer - hparams.nextn_predict_layers)) {
+                            layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, flags);
+                            layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, flags);
+                            layer.ffn_up   = create_tensor(tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd, n_ff}, flags);
+                        } else {
+                            layer.ffn_gate_inp    = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP,  "weight", i), {n_embd, n_expert}, flags);
+                            layer.ffn_exp_probs_b = create_tensor(tn(LLM_TENSOR_FFN_EXP_PROBS_B, "bias", i), {n_expert}, TENSOR_NOT_REQUIRED | flags);
+
+                            if (n_expert == 0) {
+                                throw std::runtime_error("n_expert must be > 0");
+                            }
+                            if (n_expert_used == 0) {
+                                throw std::runtime_error("n_expert_used must be > 0");
+                            }
+
+                            layer.ffn_gate_exps  = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS,  "weight", i), {n_embd, n_ff_exp, n_expert}, flags);
+                            layer.ffn_down_exps  = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS,  "weight", i), {n_ff_exp, n_embd, n_expert}, flags);
+                            layer.ffn_up_exps    = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS,    "weight", i), {n_embd, n_ff_exp, n_expert}, flags);
+
+                            layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), {n_embd, n_ff_shexp}, flags);
+                            layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), {n_ff_shexp, n_embd}, flags);
+                            layer.ffn_up_shexp   = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP,   "weight", i), {n_embd, n_ff_shexp}, flags);
+                        }
+
+                        // NextN/MTP tensors (preserved but unused) - conditionally load for last nextn_predict_layers
+                        if (hparams.nextn_predict_layers > 0 && static_cast<uint32_t>(i) >= n_layer - hparams.nextn_predict_layers) {
+                            layer.nextn.eh_proj          = create_tensor(tn(LLM_TENSOR_NEXTN_EH_PROJ, "weight", i), {2 * n_embd, n_embd}, flags);
+                            layer.nextn.enorm            = create_tensor(tn(LLM_TENSOR_NEXTN_ENORM,   "weight", i), {n_embd}, flags);
+                            layer.nextn.hnorm            = create_tensor(tn(LLM_TENSOR_NEXTN_HNORM,   "weight", i), {n_embd}, flags);
+
+                            layer.nextn.shared_head_norm = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, "weight", i), {n_embd}, flags | TENSOR_NOT_REQUIRED);
+                            layer.nextn.embed_tokens     = create_tensor(tn(LLM_TENSOR_NEXTN_EMBED_TOKENS,     "weight", i), {n_embd, n_vocab}, flags | TENSOR_NOT_REQUIRED);
+                            layer.nextn.shared_head_head = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, "weight", i), {n_embd, n_vocab}, flags | TENSOR_NOT_REQUIRED);
+                        }
+                    }
+                } break;
             case LLM_ARCH_RWKV6:
                 {
                     tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
@@ -7811,6 +7921,10 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const {
                     llm = std::make_unique<llm_build_exaone4<false>>(*this, params);
                 }
             } break;
+        case LLM_ARCH_EXAONE_MOE:
+            {
+                llm = std::make_unique<llm_build_exaone_moe>(*this, params);
+            } break;
         case LLM_ARCH_RWKV6:
             {
                 llm = std::make_unique<llm_build_rwkv6>(*this, params);
@@ -8171,6 +8285,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) {
         case LLM_ARCH_NEMOTRON:
         case LLM_ARCH_EXAONE:
         case LLM_ARCH_EXAONE4:
+        case LLM_ARCH_EXAONE_MOE:
         case LLM_ARCH_MINICPM3:
         case LLM_ARCH_BAILINGMOE2:
         case LLM_ARCH_DOTS1:

+ 10 - 0
src/llama-vocab.cpp

@@ -461,6 +461,13 @@ struct llm_tokenizer_bpe : llm_tokenizer {
                     "[!\"#$%&'()*+,\\-./:;<=>?@\\[\\\\\\]^_`{|}~][A-Za-z]+|[^\\r\\n\\p{L}\\p{P}\\p{S}]?[\\p{L}\\p{M}]+| ?[\\p{P}\\p{S}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+",
                 };
                 break;
+            case LLAMA_VOCAB_PRE_TYPE_EXAONE_MOE:
+                regex_exprs = {
+                    // original regex from tokenizer.json
+                    // "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?(?:\\p{L}\\p{M}*(?: \\p{L}\\p{M}*)*)+|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n/]?|\\s*[\\r\\n]|\\s+(?!\\S)|\\s+"
+                    "(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])|[^\\r\\n\\p{L}\\p{N}]?(?:\\p{L}\\p{M}*(?: \\p{L}\\p{M}*)*)+|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n/]?|\\s*[\\r\\n]|\\s+(?!\\S)|\\s+",
+                };
+                break;
             default:
                 // default regex for BPE tokenization pre-processing
                 regex_exprs = {
@@ -1965,6 +1972,9 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
             } else if (
                 tokenizer_pre == "exaone4") {
                 pre_type = LLAMA_VOCAB_PRE_TYPE_GPT2;
+            } else if (
+                tokenizer_pre == "exaone-moe") {
+                pre_type = LLAMA_VOCAB_PRE_TYPE_EXAONE_MOE;
             } else if (
                 tokenizer_pre == "chameleon") {
                 pre_type = LLAMA_VOCAB_PRE_TYPE_CHAMELEON;

+ 1 - 0
src/llama-vocab.h

@@ -53,6 +53,7 @@ enum llama_vocab_pre_type {
     LLAMA_VOCAB_PRE_TYPE_AFMOE           = 42,
     LLAMA_VOCAB_PRE_TYPE_SOLAR_OPEN      = 43,
     LLAMA_VOCAB_PRE_TYPE_YOUTU           = 44,
+    LLAMA_VOCAB_PRE_TYPE_EXAONE_MOE      = 45,
 };
 
 struct LLM_KV;

+ 146 - 0
src/models/exaone-moe.cpp

@@ -0,0 +1,146 @@
+#include "models.h"
+
+
+llm_build_exaone_moe::llm_build_exaone_moe(const llama_model & model, const llm_graph_params & params) :
+    llm_graph_context(params) {
+    const int64_t n_embd_head = hparams.n_embd_head_k;
+
+    GGML_ASSERT(n_embd_head == hparams.n_embd_head_v);
+    GGML_ASSERT(n_embd_head == hparams.n_rot);
+
+    ggml_tensor * cur;
+    ggml_tensor * inpL;
+
+    inpL = build_inp_embd(model.tok_embd);
+
+    // inp_pos - contains the positions
+    ggml_tensor * inp_pos = build_inp_pos();
+
+    auto * inp_attn_iswa = build_attn_inp_kv_iswa();
+
+    ggml_tensor * inp_out_ids = build_inp_out_ids();
+
+    const int n_transformer_layers = n_layer - hparams.nextn_predict_layers;
+    for (int il = 0; il < n_transformer_layers; ++il) {
+        ggml_tensor * inpSA = inpL;
+
+        // use RoPE for SWA layers
+        const bool is_local_layer = hparams.is_swa(il);
+
+        // norm
+        cur = build_norm(inpL, model.layers[il].attn_norm, NULL, LLM_NORM_RMS, il);
+        cb(cur, "attn_norm", il);
+
+        // self-attention
+        {
+            ggml_tensor * rope_factors = model.get_rope_factors(cparams, il);
+
+            // compute Q and K and RoPE them
+            ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
+            cb(Qcur, "Qcur", il);
+
+            ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur);
+            cb(Kcur, "Kcur", il);
+
+            ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
+            cb(Vcur, "Vcur", il);
+
+            Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
+            Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
+            Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
+
+            Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, il);
+            Kcur = build_norm(Kcur, model.layers[il].attn_k_norm, NULL, LLM_NORM_RMS, il);
+            cb(Qcur, "Qcur_normed", il);
+            cb(Kcur, "Kcur_normed", il);
+
+            if (is_local_layer) {
+                Qcur = ggml_rope_ext(ctx0, Qcur, inp_pos, rope_factors, n_rot, rope_type, n_ctx_orig, freq_base,
+                                     freq_scale, ext_factor, attn_factor, beta_fast, beta_slow);
+
+                Kcur = ggml_rope_ext(ctx0, Kcur, inp_pos, rope_factors, n_rot, rope_type, n_ctx_orig, freq_base,
+                                     freq_scale, ext_factor, attn_factor, beta_fast, beta_slow);
+            }
+            cb(Qcur, "Qcur", il);
+            cb(Kcur, "Kcur", il);
+            cb(Vcur, "Vcur", il);
+
+            cur = build_attn(inp_attn_iswa,
+                model.layers[il].wo, NULL,
+                Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f / sqrtf(float(n_embd_head)), il);
+            cb(cur, "attn_out", il);
+        }
+        if (il == n_transformer_layers - 1 && inp_out_ids) {
+            cur   = ggml_get_rows(ctx0, cur, inp_out_ids);
+            inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
+        }
+        ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
+        cb(ffn_inp, "ffn_inp", il);
+
+        // norm
+        cur = build_norm(ffn_inp, model.layers[il].ffn_norm, NULL, LLM_NORM_RMS, il);
+        cb(cur, "ffn_norm", il);
+
+        // feed-forward network
+        if (model.layers[il].ffn_gate_inp == nullptr) {
+            // dense branch
+            cur = build_ffn(cur,
+                    model.layers[il].ffn_up, NULL, NULL,
+                    model.layers[il].ffn_gate, NULL, NULL,
+                    model.layers[il].ffn_down, NULL, NULL, NULL,
+                    LLM_FFN_SILU, LLM_FFN_PAR, il);
+            cb(cur, "ffn_out", il);
+        } else {
+            // MoE branch
+            ggml_tensor * moe_out = build_moe_ffn(cur,
+                model.layers[il].ffn_gate_inp,
+                model.layers[il].ffn_up_exps,
+                model.layers[il].ffn_gate_exps,
+                model.layers[il].ffn_down_exps,
+                model.layers[il].ffn_exp_probs_b,
+                n_expert, n_expert_used,
+                LLM_FFN_SILU, hparams.expert_weights_norm,
+                true, hparams.expert_weights_scale,
+                (llama_expert_gating_func_type) hparams.expert_gating_func,
+                il);
+            cb(moe_out, "ffn_moe_out", il);
+
+            // FFN shared expert
+            {
+                ggml_tensor * ffn_shexp =
+                    build_ffn(cur,
+                        model.layers[il].ffn_up_shexp, NULL, NULL,
+                        model.layers[il].ffn_gate_shexp, NULL, NULL,
+                        model.layers[il].ffn_down_shexp, NULL, NULL,
+                        NULL, LLM_FFN_SILU, LLM_FFN_PAR, il);
+                cb(ffn_shexp, "ffn_shexp", il);
+
+                cur = ggml_add(ctx0, moe_out, ffn_shexp);
+                cb(cur, "ffn_out", il);
+            }
+        }
+
+        cur = ggml_add(ctx0, cur, ffn_inp);
+
+        cur = build_cvec(cur, il);
+        cb(cur, "l_out", il);
+
+        // input for next layer
+        inpL = cur;
+    }
+    cur = inpL;
+
+    // final norm
+    cur = build_norm(cur, model.output_norm, NULL, LLM_NORM_RMS, -1);
+
+    cb(cur, "result_norm", -1);
+    res->t_embd = cur;
+
+    // lm_head
+    cur = build_lora_mm(model.output, cur);
+
+    cb(cur, "result_output", -1);
+    res->t_logits = cur;
+
+    ggml_build_forward_expand(gf, cur);
+}

+ 4 - 0
src/models/models.h

@@ -167,6 +167,10 @@ struct llm_build_exaone : public llm_graph_context {
     llm_build_exaone(const llama_model & model, const llm_graph_params & params);
 };
 
+struct llm_build_exaone_moe : public llm_graph_context {
+    llm_build_exaone_moe(const llama_model & model, const llm_graph_params & params);
+};
+
 struct llm_build_falcon : public llm_graph_context {
     llm_build_falcon(const llama_model & model, const llm_graph_params & params);
 };