Bläddra i källkod

model : add glm-asr support (#17901)

* [model] add glm-asr support

* fix format for ci

* fix convert format for ci

* update glm_asr convert script & use build_ffn for glm_asr clip & use build_stack for padding and review

* check root architecture for convert hf script

* fix conficlt with upstream

* fix convert script for glm asr & format clip-impl

* format

* restore hparams text

* improved conversion

---------

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>
piDack 1 månad sedan
förälder
incheckning
745fa0e78b

+ 84 - 2
convert_hf_to_gguf.py

@@ -713,6 +713,9 @@ class ModelBase:
         if "llm_config" in config:
             # rename for InternVL
             config["text_config"] = config["llm_config"]
+        if "lm_config" in config:
+            # rename for GlmASR
+            config["text_config"] = config["lm_config"]
         if "thinker_config" in config:
             # rename for Qwen2.5-Omni
             config["text_config"] = config["thinker_config"]["text_config"]
@@ -1529,6 +1532,21 @@ class TextModel(ModelBase):
                 raise NotImplementedError("Only MEAN, CLS, and LAST pooling types supported")
             self.gguf_writer.add_pooling_type(pooling_type)
 
+    def _set_vocab_glmedge(self):
+        from transformers import AutoTokenizer
+        tokenizer = AutoTokenizer.from_pretrained(self.dir_model)
+        special_vocab = gguf.SpecialVocab(self.dir_model, load_merges=True)
+        tokens, toktypes, tokpre = self.get_vocab_base()
+        self.gguf_writer.add_tokenizer_model("gpt2")
+        self.gguf_writer.add_tokenizer_pre(tokpre)
+        self.gguf_writer.add_token_list(tokens)
+        self.gguf_writer.add_token_types(toktypes)
+        special_vocab._set_special_token("eos", tokenizer.get_added_vocab()["<|endoftext|>"])
+        special_vocab._set_special_token("eot", tokenizer.get_added_vocab()["<|user|>"])
+        special_vocab._set_special_token("unk", tokenizer.get_added_vocab()["<|endoftext|>"])
+        special_vocab._set_special_token("bos", tokenizer.get_added_vocab()["<|endoftext|>"])
+        special_vocab.add_to_gguf(self.gguf_writer)
+
     def _set_vocab_interns1(self):
         tokens: list[str] = []
         toktypes: list[int] = []
@@ -1658,7 +1676,7 @@ class MmprojModel(ModelBase):
     preprocessor_config: dict[str, Any]
     global_config: dict[str, Any]
 
-    n_block_keys = ["n_layers", "num_hidden_layers", "n_layer", "num_layers", "depth"]
+    n_block_keys = ["n_layers", "num_hidden_layers", "n_layer", "num_layers", "depth", "encoder_layers"]
 
     has_vision_encoder: bool = True # by default
     has_audio_encoder: bool = False
@@ -1734,7 +1752,8 @@ class MmprojModel(ModelBase):
         return self.global_config.get(config_name)
 
     def get_audio_config(self) -> dict[str, Any] | None:
-        return self.global_config.get("audio_config")
+        mm_config_key = "whisper_config" if "whisper_config" in self.hparams else "audio_config"
+        return self.global_config.get(mm_config_key)
 
     def set_type(self):
         self.gguf_writer.add_type(gguf.GGUFType.MMPROJ)
@@ -2372,8 +2391,13 @@ class LlamaModel(TextModel):
         # fix for SmolVLM2, missing `num_attention_heads` in config.json
         if self.hf_arch == "VLlama3ForCausalLM":
             self.hparams["num_attention_heads"] = self.hparams.get("num_attention_heads", 32)
+        hparams = ModelBase.load_hparams(self.dir_model, is_mistral_format=False)
+        self.origin_hf_arch = hparams.get('architectures', [None])[0]
 
     def set_vocab(self):
+        if self.origin_hf_arch == "GlmasrModel":
+            return self._set_vocab_glmedge()
+
         if self.is_mistral_format:
             return self._set_vocab_mistral()
 
@@ -2444,6 +2468,7 @@ class LlamaModel(TextModel):
             "vision_language_adapter.",
             "patch_merger.",
             "pre_mm_projector_norm",
+            "audio_encoder.",
         ]
 
         is_multimodal_tensor = "vision_tower" in name \
@@ -8846,6 +8871,63 @@ class UltravoxModel(TextModel):
         raise NotImplementedError("Ultravox does not have text decoder. Instead, it uses Llama or other models for text. If you want to get the audio encoder, please use --mmproj argument")
 
 
+@ModelBase.register("GlmasrModel")
+class GlmASRWhisperEncoderModel(MmprojModel):
+    has_vision_encoder = False
+    has_audio_encoder = True
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        if "hidden_size" not in self.hparams and "intermediate_size" not in self.hparams:
+            self.hparams["hidden_size"] = self.hparams["d_model"]
+            self.hparams["intermediate_size"] = self.hparams["encoder_ffn_dim"]
+            self.hparams["num_attention_heads"] = self.hparams["encoder_attention_heads"]
+
+    def set_gguf_parameters(self):
+        super().set_gguf_parameters()
+        self.gguf_writer.add_clip_projector_type(gguf.VisionProjectorType.GLMA)
+        self.gguf_writer.add_audio_num_mel_bins(self.hparams["num_mel_bins"])
+        self.gguf_writer.add_audio_attention_layernorm_eps(self.hparams.get("layer_norm_eps", 1e-5))
+        self.gguf_writer.add_audio_stack_factor(self.global_config["merge_factor"])
+
+    def tensor_force_quant(self, name, new_name, bid, n_dims):
+        if ".conv" in name and ".weight" in name:
+            return gguf.GGMLQuantizationType.F16
+        return super().tensor_force_quant(name, new_name, bid, n_dims)
+
+    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
+        del bid  # unused
+
+        if name.startswith("model.") or name.startswith("lm_head."):
+            # skip language model tensors
+            return []
+
+        if name.startswith("audio_encoder.whisper."):
+            name = name.replace("audio_encoder.whisper.","audio_tower.")
+        if "audio_encoder.layer_norm." in name or "audio_encoder.proj." in name:
+            name = name.replace("audio_encoder.", "audio_encoder.adapting.")
+
+        if name.startswith("audio_encoder.audio_bos_eos_token."):
+            return [(self.map_tensor_name("model.vision.boi"), data_torch[0]), (self.map_tensor_name("model.vision.eoi"), data_torch[1])]
+
+        if name.startswith("audio_encoder.adapting."):
+            name = name.replace("audio_encoder.adapting.","audio.multi_modal_projector.")
+            if ".layer_norm." in name:
+                name = name.replace(".layer_norm.", ".ln_pre.")
+            if ".0." in name:
+                name = name.replace(".0.", ".linear_1.")
+            if ".2." in name:
+                name = name.replace(".2.", ".linear_2.")
+            if ".proj." in name:
+                return []
+
+        if "conv1.bias" in name or "conv2.bias" in name:
+            # transpose conv1 and conv2 bias
+            data_torch = data_torch.unsqueeze(-1)
+
+        return [(self.map_tensor_name(name), data_torch)]
+
+
 @ModelBase.register("Qwen2AudioForConditionalGeneration")
 class WhisperEncoderModel(MmprojModel):
     has_vision_encoder = False # no vision encoder

+ 1 - 0
gguf-py/gguf/constants.py

@@ -3320,6 +3320,7 @@ class VisionProjectorType:
     ULTRAVOX = "ultravox"
     INTERNVL = "internvl"
     QWEN2A = "qwen2a" # audio
+    GLMA = "glma" # audio
     QWEN25O = "qwen2.5o" # omni
     VOXTRAL = "voxtral"
     LFM2 = "lfm2"

+ 4 - 0
tools/mtmd/clip-graph.h

@@ -112,4 +112,8 @@ struct clip_graph {
     // aka pixel_shuffle / pixel_unshuffle / patch_merger (Kimi-VL)
     // support dynamic resolution
     ggml_tensor * build_patch_merge_permute(ggml_tensor * cur, int scale_factor);
+
+    // Generic function to stack frames for audio processing
+    // Abstracts out the StackAudioFrames logic used by ultravox
+    ggml_tensor * build_stack(ggml_tensor * cur, int32_t stack_factor, int32_t n_embed);
 };

+ 2 - 0
tools/mtmd/clip-impl.h

@@ -157,6 +157,7 @@ enum projector_type {
     PROJECTOR_TYPE_INTERNVL,
     PROJECTOR_TYPE_LLAMA4,
     PROJECTOR_TYPE_QWEN2A,
+    PROJECTOR_TYPE_GLMA,
     PROJECTOR_TYPE_QWEN25O, // will be replaced by QWEN2A or QWEN25VL depending on clip_ctx
     PROJECTOR_TYPE_VOXTRAL,
     PROJECTOR_TYPE_LFM2,
@@ -183,6 +184,7 @@ static std::map<projector_type, std::string> PROJECTOR_TYPE_NAMES = {
     { PROJECTOR_TYPE_INTERNVL,  "internvl"},
     { PROJECTOR_TYPE_LLAMA4,    "llama4"},
     { PROJECTOR_TYPE_QWEN2A,    "qwen2a"},
+    { PROJECTOR_TYPE_GLMA,      "glma"},
     { PROJECTOR_TYPE_QWEN25O,   "qwen2.5o"},
     { PROJECTOR_TYPE_VOXTRAL,   "voxtral"},
     { PROJECTOR_TYPE_LFM2,      "lfm2"},

+ 1 - 0
tools/mtmd/clip-model.h

@@ -256,6 +256,7 @@ struct clip_model {
     ggml_tensor * conv1d_2_w = nullptr;
     ggml_tensor * conv1d_2_b = nullptr;
     ggml_tensor * mm_norm_pre_w = nullptr;
+    ggml_tensor * mm_norm_pre_b = nullptr;
     ggml_tensor * mm_norm_mid_w = nullptr;
 
     // cogvlm

+ 59 - 1
tools/mtmd/clip.cpp

@@ -720,6 +720,32 @@ ggml_tensor * clip_graph::build_rope_2d(
     return cur;
 }
 
+// Generic function to stack frames for audio processing
+// Abstracts out the StackAudioFrames logic used by ultravox
+ggml_tensor * clip_graph::build_stack(ggml_tensor * cur, int32_t stack_factor, int32_t n_embed) {
+    if (stack_factor <= 1) {
+        return cur;
+    }
+
+    int64_t total_elements = ggml_nelements(cur);
+    int64_t stride = n_embed * stack_factor;
+
+    // Calculate padded length
+    int64_t padded_len = GGML_PAD(total_elements, stride);
+    int64_t pad = padded_len - total_elements;
+
+    if (pad > 0) {
+        // Pad the tensor to make it divisible by stride
+        cur = ggml_view_1d(ctx0, cur, total_elements, 0);
+        cur = ggml_pad(ctx0, cur, pad, 0, 0, 0);
+    }
+
+    // Reshape to [stride, padded_len / stride]
+    cur = ggml_view_2d(ctx0, cur, stride, padded_len / stride,
+                        ggml_row_size(cur->type, stride), 0);
+    return cur;
+}
+
 // aka pixel_shuffle / pixel_unshuffle / patch_merger (Kimi-VL)
 // support dynamic resolution
 ggml_tensor * clip_graph::build_patch_merge_permute(ggml_tensor * cur, int scale_factor) {
@@ -796,6 +822,7 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
         case PROJECTOR_TYPE_ULTRAVOX:
         case PROJECTOR_TYPE_VOXTRAL:
         case PROJECTOR_TYPE_QWEN2A:
+        case PROJECTOR_TYPE_GLMA:
             {
                 builder = std::make_unique<clip_graph_whisper_enc>(ctx, img);
             } break;
@@ -1136,10 +1163,12 @@ struct clip_model_loader {
                     } break;
                 case PROJECTOR_TYPE_ULTRAVOX:
                 case PROJECTOR_TYPE_QWEN2A:
+                case PROJECTOR_TYPE_GLMA:
                 case PROJECTOR_TYPE_VOXTRAL:
                     {
                         bool require_stack = model.proj_type == PROJECTOR_TYPE_ULTRAVOX ||
-                                             model.proj_type == PROJECTOR_TYPE_VOXTRAL;
+                                             model.proj_type == PROJECTOR_TYPE_VOXTRAL ||
+                                             model.proj_type == PROJECTOR_TYPE_GLMA;
                         get_u32(KEY_A_PROJ_STACK_FACTOR, hparams.proj_stack_factor, require_stack);
                         if (hparams.n_mel_bins != 128) {
                             throw std::runtime_error(string_format("%s: only 128 mel bins are supported for ultravox\n", __func__));
@@ -1510,6 +1539,21 @@ struct clip_model_loader {
                     model.mm_3_w = get_tensor(string_format(TN_MVLM_PROJ_MLP, 3, "weight"));
                     model.mm_3_b = get_tensor(string_format(TN_MVLM_PROJ_MLP, 3, "bias"));
                 } break;
+            case PROJECTOR_TYPE_GLMA:
+                {
+                    model.conv1d_1_w = get_tensor(string_format(TN_CONV1D, 1, "weight"));
+                    model.conv1d_1_b = get_tensor(string_format(TN_CONV1D, 1, "bias"));
+                    model.conv1d_2_w = get_tensor(string_format(TN_CONV1D, 2, "weight"));
+                    model.conv1d_2_b = get_tensor(string_format(TN_CONV1D, 2, "bias"));
+                    model.mm_1_w = get_tensor(string_format(TN_MM_AUDIO_MLP, 1, "weight"));
+                    model.mm_1_b = get_tensor(string_format(TN_MM_AUDIO_MLP, 1, "bias"));
+                    model.mm_2_w = get_tensor(string_format(TN_MM_AUDIO_MLP, 2, "weight"));
+                    model.mm_2_b = get_tensor(string_format(TN_MM_AUDIO_MLP, 2, "bias"));
+                    model.mm_norm_pre_w = get_tensor(string_format(TN_MM_NORM_PRE, "weight"));
+                    model.mm_norm_pre_b = get_tensor(string_format(TN_MM_NORM_PRE, "bias"));
+                    model.mm_boi = get_tensor(string_format(TN_TOK_BOI, "weight"));
+                    model.mm_eoi = get_tensor(string_format(TN_TOK_EOI, "weight"));
+                } break;
             case PROJECTOR_TYPE_LLAMA4:
                 {
                     model.mm_model_proj    = get_tensor(TN_MM_PROJECTOR);
@@ -2895,6 +2939,16 @@ int clip_n_output_tokens(const struct clip_ctx * ctx, struct clip_image_f32 * im
                     n_patches /= 2;
                 }
             } break;
+        case PROJECTOR_TYPE_GLMA:
+            {
+                n_patches = img->nx;
+                // whisper downscales input token by half after conv1d
+                n_patches /= 2;
+                // reshape by merge_factor
+                n_patches /= ctx->model.hparams.proj_stack_factor;
+                // for BOI and EOI token embeddings
+                n_patches += 2;
+            } break;
         case PROJECTOR_TYPE_COGVLM:
             {
                 n_patches += 2; // for BOI and EOI token embeddings
@@ -3230,6 +3284,7 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
         case PROJECTOR_TYPE_IDEFICS3:
         case PROJECTOR_TYPE_INTERNVL:
         case PROJECTOR_TYPE_QWEN2A:
+        case PROJECTOR_TYPE_GLMA:
         case PROJECTOR_TYPE_ULTRAVOX:
         case PROJECTOR_TYPE_LFM2:
         case PROJECTOR_TYPE_VOXTRAL:
@@ -3340,6 +3395,8 @@ int clip_n_mmproj_embd(const struct clip_ctx * ctx) {
             return ctx->model.mm_model_proj->ne[1];
         case PROJECTOR_TYPE_QWEN2A:
             return ctx->model.mm_fc_w->ne[1];
+        case PROJECTOR_TYPE_GLMA:
+            return ctx->model.mm_2_w->ne[1];
         case PROJECTOR_TYPE_LFM2:
         case PROJECTOR_TYPE_KIMIVL:
             return ctx->model.mm_2_w->ne[1];
@@ -3386,6 +3443,7 @@ bool clip_has_audio_encoder(const struct clip_ctx * ctx) {
 bool clip_has_whisper_encoder(const struct clip_ctx * ctx) {
     return ctx->proj_type() == PROJECTOR_TYPE_ULTRAVOX
         || ctx->proj_type() == PROJECTOR_TYPE_QWEN2A
+        || ctx->proj_type() == PROJECTOR_TYPE_GLMA
         || ctx->proj_type() == PROJECTOR_TYPE_VOXTRAL;
 }
 

+ 9 - 10
tools/mtmd/models/whisper-enc.cpp

@@ -30,7 +30,6 @@ ggml_cgraph * clip_graph_whisper_enc::build() {
     GGML_ASSERT(model.layers[0].q_b);
     GGML_ASSERT(model.layers[0].v_b);
     GGML_ASSERT(!model.layers[0].k_b); // no bias for k
-    GGML_ASSERT(model.post_ln_w && model.post_ln_b);
 
     ggml_tensor * pos_embd_selected = ggml_view_2d(
         ctx0, model.position_embeddings,
@@ -49,15 +48,7 @@ ggml_cgraph * clip_graph_whisper_enc::build() {
     if (model.audio_has_stack_frames()) {
         // StackAudioFrames
         // https://huggingface.co/fixie-ai/ultravox-v0_5-llama-3_2-1b/blob/main/ultravox_model.py
-        int64_t stride = n_embd * hparams.proj_stack_factor;
-        int64_t padded_len = GGML_PAD(ggml_nelements(cur), stride);
-        int64_t pad = padded_len - ggml_nelements(cur);
-        if (pad > 0) {
-            cur = ggml_view_1d(ctx0, cur, ggml_nelements(cur), 0);
-            cur = ggml_pad(ctx0, cur, pad, 0, 0, 0);
-        }
-        cur = ggml_view_2d(ctx0, cur, stride, padded_len / stride,
-                            ggml_row_size(cur->type, stride), 0);
+        cur = build_stack(cur, hparams.proj_stack_factor, n_embd);
         cb(cur, "after_stacked", -1);
     }
 
@@ -95,6 +86,14 @@ ggml_cgraph * clip_graph_whisper_enc::build() {
             FFN_GELU_ERF,
             -1);
 
+    } else if (proj_type == PROJECTOR_TYPE_GLMA) {
+            cur = ggml_norm(ctx0, cur, hparams.eps);
+            cur = ggml_mul(ctx0, cur, model.mm_norm_pre_w);
+            cur = ggml_add(ctx0, cur, model.mm_norm_pre_b);
+            cur = build_stack(cur, hparams.proj_stack_factor, n_embd);
+            cur = build_ffn(cur, model.mm_1_w, model.mm_1_b, nullptr, nullptr, model.mm_2_w, model.mm_2_b, hparams.ffn_op, 0);
+            cur = ggml_concat(ctx0, model.mm_boi, cur, 1);
+            cur = ggml_concat(ctx0, cur, model.mm_eoi, 1);
     } else {
         GGML_ABORT("%s: unknown projector type", __func__);
     }