@@ -1,9 +1,9 @@
-// NOTE: This is modified from clip.cpp only for LLaVA,
-// so there might be still unnecessary artifacts hanging around
-// I'll gradually clean and extend it
-// Note: Even when using identical normalized image inputs (see normalize_image_u8_to_f32()) we have a significant difference in resulting embeddings compared to pytorch
 #include "clip.h"
 #include "clip-impl.h"
+#include "clip-model.h"
+#include "clip-graph.h"
+#include "models/models.h"
+
 #include "ggml.h"
 #include "ggml-cpp.h"
 #include "ggml-alloc.h"
@@ -26,18 +26,6 @@

 struct clip_logger_state g_logger_state = {clip_log_callback_default, NULL};

-enum ffn_op_type {
-    FFN_GELU,
-    FFN_GELU_ERF,
-    FFN_SILU,
-    FFN_GELU_QUICK,
-};
-
-enum norm_type {
-    NORM_TYPE_NORMAL,
-    NORM_TYPE_RMS,
-};
-
 //#define CLIP_DEBUG_FUNCTIONS

 #ifdef CLIP_DEBUG_FUNCTIONS
@@ -149,267 +137,6 @@ static void clip_image_convert_f32_to_u8(const clip_image_f32& src, clip_image_u
 #endif


-//
-// clip layers
-//
-
-enum patch_merge_type {
-    PATCH_MERGE_FLAT,
-    PATCH_MERGE_SPATIAL_UNPAD,
-};
-
-struct clip_hparams {
-    int32_t image_size = 0;
-    int32_t patch_size = 0;
-    int32_t n_embd = 0;
-    int32_t n_ff = 0;
-    int32_t projection_dim = 0;
-    int32_t n_head = 0;
-    int32_t n_layer = 0;
-    // idefics3
-    int32_t image_longest_edge = 0;
-    int32_t image_min_pixels = -1;
-    int32_t image_max_pixels = -1;
-    int32_t n_merge = 0; // number of patch merges **per-side**
-
-    float image_mean[3];
-    float image_std[3];
-
-    // for models using dynamic image size, we need to have a smaller image size to warmup
-    // otherwise, user will get OOM everytime they load the model
-    int32_t warmup_image_size = 0;
-    int32_t warmup_audio_size = 3000;
-
-    ffn_op_type ffn_op = FFN_GELU;
-
-    patch_merge_type mm_patch_merge_type = PATCH_MERGE_FLAT;
-
-    float eps = 1e-6;
-    float rope_theta = 0.0;
-
-    std::vector<clip_image_size> image_res_candidates; // for llava-uhd style models
-    int32_t image_crop_resolution;
-    std::unordered_set<int32_t> vision_feature_layer;
-    int32_t attn_window_size = 0;
-    int32_t n_wa_pattern = 0;
-
-    // audio
-    int32_t n_mel_bins = 0; // whisper preprocessor
-    int32_t proj_stack_factor = 0; // ultravox
-
-    // legacy
-    bool has_llava_projector = false;
-    int minicpmv_version = 0;
-    int32_t minicpmv_query_num = 0; // MiniCPM-V query number
-
-    // custom value provided by user, can be undefined if not set
-    int32_t custom_image_min_tokens = -1;
-    int32_t custom_image_max_tokens = -1;
-
-    void set_limit_image_tokens(int n_tokens_min, int n_tokens_max) {
-        const int cur_merge = n_merge == 0 ? 1 : n_merge;
-        const int patch_area = patch_size * patch_size * cur_merge * cur_merge;
-        image_min_pixels = (custom_image_min_tokens > 0 ? custom_image_min_tokens : n_tokens_min) * patch_area;
-        image_max_pixels = (custom_image_max_tokens > 0 ? custom_image_max_tokens : n_tokens_max) * patch_area;
-        warmup_image_size = static_cast<int>(std::sqrt(image_max_pixels));
-    }
-
-    void set_warmup_n_tokens(int n_tokens) {
-        int n_tok_per_side = static_cast<int>(std::sqrt(n_tokens));
-        GGML_ASSERT(n_tok_per_side * n_tok_per_side == n_tokens && "n_tokens must be n*n");
-        const int cur_merge = n_merge == 0 ? 1 : n_merge;
-        warmup_image_size = n_tok_per_side * patch_size * cur_merge;
-        // TODO: support warmup size for custom token numbers
-    }
-};
-
-struct clip_layer {
-    // attention
-    ggml_tensor * k_w = nullptr;
-    ggml_tensor * k_b = nullptr;
-    ggml_tensor * q_w = nullptr;
-    ggml_tensor * q_b = nullptr;
-    ggml_tensor * v_w = nullptr;
-    ggml_tensor * v_b = nullptr;
-    ggml_tensor * qkv_w = nullptr;
-    ggml_tensor * qkv_b = nullptr;
-
-    ggml_tensor * o_w = nullptr;
-    ggml_tensor * o_b = nullptr;
-
-    ggml_tensor * k_norm = nullptr;
-    ggml_tensor * q_norm = nullptr;
-
-    // layernorm 1
-    ggml_tensor * ln_1_w = nullptr;
-    ggml_tensor * ln_1_b = nullptr;
-
-    ggml_tensor * ff_up_w = nullptr;
-    ggml_tensor * ff_up_b = nullptr;
-    ggml_tensor * ff_gate_w = nullptr;
-    ggml_tensor * ff_gate_b = nullptr;
-    ggml_tensor * ff_down_w = nullptr;
-    ggml_tensor * ff_down_b = nullptr;
-
-    // layernorm 2
-    ggml_tensor * ln_2_w = nullptr;
-    ggml_tensor * ln_2_b = nullptr;
-
-    // layer scale (no bias)
-    ggml_tensor * ls_1_w = nullptr;
-    ggml_tensor * ls_2_w = nullptr;
-
-    // qwen3vl deepstack merger
-    ggml_tensor * deepstack_norm_w = nullptr;
-    ggml_tensor * deepstack_norm_b = nullptr;
-    ggml_tensor * deepstack_fc1_w = nullptr;
-    ggml_tensor * deepstack_fc1_b = nullptr;
-    ggml_tensor * deepstack_fc2_w = nullptr;
-    ggml_tensor * deepstack_fc2_b = nullptr;
-
-    bool has_deepstack() const {
-        return deepstack_fc1_w != nullptr;
-    }
-};
-
-struct clip_model {
-    clip_modality modality = CLIP_MODALITY_VISION;
-    projector_type proj_type = PROJECTOR_TYPE_MLP;
-    clip_hparams hparams;
-
-    // embeddings
-    ggml_tensor * class_embedding = nullptr;
-    ggml_tensor * patch_embeddings_0 = nullptr;
-    ggml_tensor * patch_embeddings_1 = nullptr; // second Conv2D kernel when we decouple Conv3D along temproal dimension (Qwen2VL)
-    ggml_tensor * patch_bias = nullptr;
-    ggml_tensor * position_embeddings = nullptr;
-
-    ggml_tensor * pre_ln_w = nullptr;
-    ggml_tensor * pre_ln_b = nullptr;
-
-    std::vector<clip_layer> layers;
-
-    int32_t n_deepstack_layers = 0; // used by Qwen3-VL, calculated from clip_layer
-
-    ggml_tensor * post_ln_w;
-    ggml_tensor * post_ln_b;
-
-    ggml_tensor * projection; // TODO: rename it to fc (fully connected layer)
-    ggml_tensor * mm_fc_w;
-    ggml_tensor * mm_fc_b;
-
-    // LLaVA projection
-    ggml_tensor * mm_input_norm_w = nullptr;
-    ggml_tensor * mm_input_norm_b = nullptr;
-    ggml_tensor * mm_0_w = nullptr;
-    ggml_tensor * mm_0_b = nullptr;
-    ggml_tensor * mm_2_w = nullptr;
-    ggml_tensor * mm_2_b = nullptr;
-
-    ggml_tensor * image_newline = nullptr;
-
-    // Yi type models with mlp+normalization projection
-    ggml_tensor * mm_1_w = nullptr; // Yi type models have 0, 1, 3, 4
-    ggml_tensor * mm_1_b = nullptr;
-    ggml_tensor * mm_3_w = nullptr;
-    ggml_tensor * mm_3_b = nullptr;
-    ggml_tensor * mm_4_w = nullptr;
-    ggml_tensor * mm_4_b = nullptr;
-
-    // GLMV-Edge projection
-    ggml_tensor * mm_model_adapter_conv_w = nullptr;
-    ggml_tensor * mm_model_adapter_conv_b = nullptr;
-
-    // MobileVLM projection
-    ggml_tensor * mm_model_mlp_1_w = nullptr;
-    ggml_tensor * mm_model_mlp_1_b = nullptr;
-    ggml_tensor * mm_model_mlp_3_w = nullptr;
-    ggml_tensor * mm_model_mlp_3_b = nullptr;
-    ggml_tensor * mm_model_block_1_block_0_0_w = nullptr;
-    ggml_tensor * mm_model_block_1_block_0_1_w = nullptr;
-    ggml_tensor * mm_model_block_1_block_0_1_b = nullptr;
-    ggml_tensor * mm_model_block_1_block_1_fc1_w = nullptr;
-    ggml_tensor * mm_model_block_1_block_1_fc1_b = nullptr;
-    ggml_tensor * mm_model_block_1_block_1_fc2_w = nullptr;
-    ggml_tensor * mm_model_block_1_block_1_fc2_b = nullptr;
-    ggml_tensor * mm_model_block_1_block_2_0_w = nullptr;
-    ggml_tensor * mm_model_block_1_block_2_1_w = nullptr;
-    ggml_tensor * mm_model_block_1_block_2_1_b = nullptr;
-    ggml_tensor * mm_model_block_2_block_0_0_w = nullptr;
-    ggml_tensor * mm_model_block_2_block_0_1_w = nullptr;
-    ggml_tensor * mm_model_block_2_block_0_1_b = nullptr;
-    ggml_tensor * mm_model_block_2_block_1_fc1_w = nullptr;
-    ggml_tensor * mm_model_block_2_block_1_fc1_b = nullptr;
-    ggml_tensor * mm_model_block_2_block_1_fc2_w = nullptr;
-    ggml_tensor * mm_model_block_2_block_1_fc2_b = nullptr;
-    ggml_tensor * mm_model_block_2_block_2_0_w = nullptr;
-    ggml_tensor * mm_model_block_2_block_2_1_w = nullptr;
-    ggml_tensor * mm_model_block_2_block_2_1_b = nullptr;
-
-    // MobileVLM_V2 projection
-    ggml_tensor * mm_model_mlp_0_w = nullptr;
-    ggml_tensor * mm_model_mlp_0_b = nullptr;
-    ggml_tensor * mm_model_mlp_2_w = nullptr;
-    ggml_tensor * mm_model_mlp_2_b = nullptr;
-    ggml_tensor * mm_model_peg_0_w = nullptr;
-    ggml_tensor * mm_model_peg_0_b = nullptr;
-
-    // MINICPMV projection
-    ggml_tensor * mm_model_pos_embed_k = nullptr;
-    ggml_tensor * mm_model_query = nullptr;
-    ggml_tensor * mm_model_proj = nullptr;
-    ggml_tensor * mm_model_kv_proj = nullptr;
-    ggml_tensor * mm_model_attn_q_w = nullptr;
-    ggml_tensor * mm_model_attn_q_b = nullptr;
-    ggml_tensor * mm_model_attn_k_w = nullptr;
-    ggml_tensor * mm_model_attn_k_b = nullptr;
-    ggml_tensor * mm_model_attn_v_w = nullptr;
-    ggml_tensor * mm_model_attn_v_b = nullptr;
-    ggml_tensor * mm_model_attn_o_w = nullptr;
-    ggml_tensor * mm_model_attn_o_b = nullptr;
-    ggml_tensor * mm_model_ln_q_w = nullptr;
-    ggml_tensor * mm_model_ln_q_b = nullptr;
-    ggml_tensor * mm_model_ln_kv_w = nullptr;
-    ggml_tensor * mm_model_ln_kv_b = nullptr;
-    ggml_tensor * mm_model_ln_post_w = nullptr;
-    ggml_tensor * mm_model_ln_post_b = nullptr;
-
-    // gemma3
-    ggml_tensor * mm_input_proj_w = nullptr;
-    ggml_tensor * mm_soft_emb_norm_w = nullptr;
-
-    // pixtral
-    ggml_tensor * token_embd_img_break = nullptr;
-    ggml_tensor * mm_patch_merger_w = nullptr;
-
-    // ultravox / whisper encoder
-    ggml_tensor * conv1d_1_w = nullptr;
-    ggml_tensor * conv1d_1_b = nullptr;
-    ggml_tensor * conv1d_2_w = nullptr;
-    ggml_tensor * conv1d_2_b = nullptr;
-    ggml_tensor * mm_norm_pre_w = nullptr;
-    ggml_tensor * mm_norm_mid_w = nullptr;
-
-    // cogvlm
-    ggml_tensor * mm_post_fc_norm_w = nullptr;
-    ggml_tensor * mm_post_fc_norm_b = nullptr;
-    ggml_tensor * mm_h_to_4h_w = nullptr;
-    ggml_tensor * mm_gate_w = nullptr;
-    ggml_tensor * mm_4h_to_h_w = nullptr;
-    ggml_tensor * mm_boi = nullptr;
-    ggml_tensor * mm_eoi = nullptr;
-
-    bool audio_has_avgpool() const {
-        return proj_type == PROJECTOR_TYPE_QWEN2A
-            || proj_type == PROJECTOR_TYPE_VOXTRAL;
-    }
-
-    bool audio_has_stack_frames() const {
-        return proj_type == PROJECTOR_TYPE_ULTRAVOX
-            || proj_type == PROJECTOR_TYPE_VOXTRAL;
-    }
-};
-
 struct clip_ctx {
     clip_model model;

@@ -492,2081 +219,613 @@ struct clip_ctx {
     }
 };

-struct clip_graph {
-    clip_ctx * ctx;
-    const clip_model & model;
-    const clip_hparams & hparams;
-
-    // we only support single image per batch
-    const clip_image_f32 & img;
-
-    const int patch_size;
-    const int n_patches_x;
-    const int n_patches_y;
-    const int n_patches;
-    const int n_embd;
-    const int n_head;
-    const int d_head;
-    const int n_layer;
-    const float eps;
-    const float kq_scale;
-
-    ggml_context_ptr ctx0_ptr;
-    ggml_context * ctx0;
-    ggml_cgraph * gf;
-
-    clip_graph(clip_ctx * ctx, const clip_image_f32 & img) :
-        ctx(ctx),
-        model(ctx->model),
-        hparams(model.hparams),
-        img(img),
-        patch_size(hparams.patch_size),
-        n_patches_x(img.nx / patch_size),
-        n_patches_y(img.ny / patch_size),
-        n_patches(n_patches_x * n_patches_y),
-        n_embd(hparams.n_embd),
-        n_head(hparams.n_head),
-        d_head(n_embd / n_head),
-        n_layer(hparams.n_layer),
-        eps(hparams.eps),
-        kq_scale(1.0f / sqrtf((float)d_head)) {
-        struct ggml_init_params params = {
-            /*.mem_size =*/ ctx->buf_compute_meta.size(),
-            /*.mem_buffer =*/ ctx->buf_compute_meta.data(),
-            /*.no_alloc =*/ true,
-        };
-        ctx0_ptr.reset(ggml_init(params));
-        ctx0 = ctx0_ptr.get();
-        gf = ggml_new_graph_custom(ctx0, ctx->max_nodes, false);
-    }
-
-    ggml_cgraph * build_siglip() {
-        ggml_tensor * inp = build_inp();
-
-        ggml_tensor * learned_pos_embd = model.position_embeddings;
-        if (ctx->proj_type() == PROJECTOR_TYPE_LFM2) {
-            learned_pos_embd = resize_position_embeddings();
-        }
-
-        ggml_tensor * cur = build_vit(
-            inp, n_patches,
-            NORM_TYPE_NORMAL,
-            hparams.ffn_op,
-            learned_pos_embd,
-            nullptr);
-
-        if (ctx->proj_type() == PROJECTOR_TYPE_GEMMA3) {
-            const int batch_size = 1;
-            GGML_ASSERT(n_patches_x == n_patches_y);
-            const int patches_per_image = n_patches_x;
-            const int kernel_size = hparams.n_merge;
-
-            cur = ggml_transpose(ctx0, cur);
-            cur = ggml_cont_4d(ctx0, cur, patches_per_image, patches_per_image, n_embd, batch_size);
-
-            // doing a pool2d to reduce the number of output tokens
-            cur = ggml_pool_2d(ctx0, cur, GGML_OP_POOL_AVG, kernel_size, kernel_size, kernel_size, kernel_size, 0, 0);
-            cur = ggml_reshape_3d(ctx0, cur, cur->ne[0] * cur->ne[0], n_embd, batch_size);
-            cur = ggml_cont(ctx0, ggml_transpose(ctx0, cur));
-
-            // apply norm before projection
-            cur = ggml_rms_norm(ctx0, cur, eps);
-            cur = ggml_mul(ctx0, cur, model.mm_soft_emb_norm_w);
-
-            // apply projection
-            cur = ggml_mul_mat(ctx0,
-                ggml_cont(ctx0, ggml_transpose(ctx0, model.mm_input_proj_w)),
-                cur);
-
-        } else if (ctx->proj_type() == PROJECTOR_TYPE_IDEFICS3) {
-            // pixel_shuffle
-            // https://github.com/huggingface/transformers/blob/0a950e0bbe1ed58d5401a6b547af19f15f0c195e/src/transformers/models/idefics3/modeling_idefics3.py#L578
-            const int scale_factor = model.hparams.n_merge;
-            cur = build_patch_merge_permute(cur, scale_factor);
-            cur = ggml_mul_mat(ctx0, model.projection, cur);
-
-        } else if (ctx->proj_type() == PROJECTOR_TYPE_LFM2) {
-            // pixel unshuffle block
-            const int scale_factor = model.hparams.n_merge;
-            cur = build_patch_merge_permute(cur, scale_factor);
-
-            // projection
-            cur = ggml_norm(ctx0, cur, 1e-5); // default nn.LayerNorm
-            cur = ggml_mul(ctx0, cur, model.mm_input_norm_w);
-            cur = ggml_add(ctx0, cur, model.mm_input_norm_b);
-
-            cur = build_ffn(cur,
-                model.mm_1_w, model.mm_1_b,
-                nullptr, nullptr,
-                model.mm_2_w, model.mm_2_b,
-                FFN_GELU,
-                -1);
-
-        } else if (ctx->proj_type() == PROJECTOR_TYPE_JANUS_PRO) {
-            cur = build_ffn(cur,
-                model.mm_0_w, model.mm_0_b,
-                nullptr, nullptr,
-                model.mm_1_w, model.mm_1_b,
-                hparams.ffn_op,
-                -1);
-
-        } else {
-            GGML_ABORT("SigLIP: Unsupported projector type");
-        }
-
-        // build the graph
-        ggml_build_forward_expand(gf, cur);
-
-        return gf;
-    }
-
-    ggml_cgraph * build_pixtral() {
-        const int n_merge = hparams.n_merge;
-
-        // 2D input positions
-        ggml_tensor * pos_h = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_patches);
-        ggml_set_name(pos_h, "pos_h");
-        ggml_set_input(pos_h);
-
-        ggml_tensor * pos_w = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_patches);
-        ggml_set_name(pos_w, "pos_w");
-        ggml_set_input(pos_w);
-
-        auto add_pos = [&](ggml_tensor * cur, const clip_layer &) {
-            return build_rope_2d(ctx0, cur, pos_h, pos_w, hparams.rope_theta, true);
-        };
-
-        ggml_tensor * inp = build_inp();
-        ggml_tensor * cur = build_vit(
-            inp, n_patches,
-            NORM_TYPE_RMS,
-            hparams.ffn_op,
-            nullptr, // no learned pos embd
-            add_pos);
-
-        // mistral small 3.1 patch merger
-        // ref: https://github.com/huggingface/transformers/blob/7a3e208892c06a5e278144eaf38c8599a42f53e7/src/transformers/models/mistral3/modeling_mistral3.py#L67
-        if (model.mm_patch_merger_w) {
-            GGML_ASSERT(hparams.n_merge > 0);
-
-            cur = ggml_mul(ctx0, ggml_rms_norm(ctx0, cur, eps), model.mm_input_norm_w);
-
-            // reshape image tokens to 2D grid
-            cur = ggml_reshape_3d(ctx0, cur, n_embd, n_patches_x, n_patches_y);
-            cur = ggml_permute(ctx0, cur, 2, 0, 1, 3); // [x, y, n_embd]
-            cur = ggml_cont(ctx0, cur);
-
-            // torch.nn.functional.unfold is just an im2col under the hood
-            // we just need a dummy kernel to make it work
-            ggml_tensor * kernel = ggml_view_3d(ctx0, cur, n_merge, n_merge, cur->ne[2], 0, 0, 0);
-            cur = ggml_im2col(ctx0, kernel, cur, n_merge, n_merge, 0, 0, 1, 1, true, inp->type);
-
-            // project to n_embd
-            cur = ggml_reshape_2d(ctx0, cur, cur->ne[0], cur->ne[1] * cur->ne[2]);
-            cur = ggml_mul_mat(ctx0, model.mm_patch_merger_w, cur);
-        }
-
-        // LlavaMultiModalProjector (always using GELU activation)
-        {
-            cur = build_ffn(cur,
-                model.mm_1_w, model.mm_1_b,
-                nullptr, nullptr,
-                model.mm_2_w, model.mm_2_b,
-                FFN_GELU,
-                -1);
-        }
+//
+// clip_graph
+//
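+// the model-specific graph builders (build_siglip, build_pixtral, build_qwen2vl,
+// ...) have been moved out of this file; see models/models.h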

-        // arrangement of the [IMG_BREAK] token
-        if (model.token_embd_img_break) {
-            // not efficient, but works
-            // the trick is to view the embeddings as a 3D tensor with shape [n_embd, n_patches_per_row, n_rows]
-            // and then concatenate the [IMG_BREAK] token to the end of each row, aka n_patches_per_row dimension
-            // after the concatenation, we have a tensor with shape [n_embd, n_patches_per_row + 1, n_rows]
-
-            const int p_y = n_merge > 0 ? n_patches_y / n_merge : n_patches_y;
-            const int p_x = n_merge > 0 ? n_patches_x / n_merge : n_patches_x;
-            const int p_total = p_x * p_y;
-            const int n_embd_text = cur->ne[0];
-            const int n_tokens_output = p_total + p_y - 1; // one [IMG_BREAK] per row, except the last row
-
-            ggml_tensor * tmp = ggml_reshape_3d(ctx0, cur, n_embd_text, p_x, p_y);
-            ggml_tensor * tok = ggml_new_tensor_3d(ctx0, tmp->type, n_embd_text, 1, p_y);
-            tok = ggml_scale(ctx0, tok, 0.0); // clear the tensor
-            tok = ggml_add(ctx0, tok, model.token_embd_img_break);
-            tmp = ggml_concat(ctx0, tmp, tok, 1);
-            cur = ggml_view_2d(ctx0, tmp,
-                n_embd_text, n_tokens_output,
-                ggml_row_size(tmp->type, n_embd_text), 0);
-        }
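+// note: the graph context below is created with no_alloc = true over
+// buf_compute_meta, so it only records tensor/graph metadata; the actual
+// buffers are allocated later, when the graph is scheduled on a backend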
+clip_graph::clip_graph(clip_ctx * ctx, const clip_image_f32 & img) :
+    model(ctx->model),
+    hparams(model.hparams),
+    proj_type(ctx->proj_type()),
+    img(img),
+    patch_size(hparams.patch_size),
+    n_patches_x(img.nx / patch_size),
+    n_patches_y(img.ny / patch_size),
+    n_patches(n_patches_x * n_patches_y),
+    n_embd(hparams.n_embd),
+    n_head(hparams.n_head),
+    d_head(n_embd / n_head),
+    n_layer(hparams.n_layer),
+    n_mmproj_embd(clip_n_mmproj_embd(ctx)),
+    eps(hparams.eps),
+    kq_scale(1.0f / sqrtf((float)d_head)),
+    flash_attn_type(ctx->flash_attn_type),
+    debug_graph(ctx->debug_graph),
+    debug_print_tensors(ctx->debug_print_tensors) {
+    struct ggml_init_params params = {
+        /*.mem_size =*/ ctx->buf_compute_meta.size(),
+        /*.mem_buffer =*/ ctx->buf_compute_meta.data(),
+        /*.no_alloc =*/ true,
+    };
+    ctx0_ptr.reset(ggml_init(params));
+    ctx0 = ctx0_ptr.get();
+    gf = ggml_new_graph_custom(ctx0, ctx->max_nodes, false);
+}

-        // build the graph
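+// debug helper: when debug_graph is enabled, snapshot an intermediate tensor
+// (duplicate it, name it "<name>_<il>", mark it as an output) so its values
+// can still be inspected after the graph has been computed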
+void clip_graph::cb(ggml_tensor * cur0, const char * name, int il) const {
+    if (debug_graph) {
+        ggml_tensor * cur = ggml_cpy(ctx0, cur0, ggml_dup_tensor(ctx0, cur0));
+        std::string cur_name = il >= 0 ? std::string(name) + "_" + std::to_string(il) : name;
+        ggml_set_name(cur, cur_name.c_str());
+        ggml_set_output(cur);
         ggml_build_forward_expand(gf, cur);
-
-        return gf;
+        debug_print_tensors.push_back(cur);
     }
+}

-    // Qwen2VL and Qwen2.5VL use M-RoPE
-    ggml_cgraph * build_qwen2vl() {
-        GGML_ASSERT(model.patch_bias == nullptr);
-        GGML_ASSERT(model.class_embedding == nullptr);
-
-        const int batch_size = 1;
-        const bool use_window_attn = hparams.n_wa_pattern > 0;
-        const int n_wa_pattern = hparams.n_wa_pattern;
-        const int n_pos = n_patches;
-        const int num_position_ids = n_pos * 4; // m-rope requires 4 dim per position
-
-        norm_type norm_t = ctx->proj_type() == PROJECTOR_TYPE_QWEN25VL
-            ? NORM_TYPE_RMS // qwen 2.5 vl
-            : NORM_TYPE_NORMAL; // qwen 2 vl
-
-        int mrope_sections[4] = {d_head/4, d_head/4, d_head/4, d_head/4};
-
-        ggml_tensor * inp_raw = build_inp_raw();
-        ggml_tensor * inp = ggml_conv_2d(ctx0, model.patch_embeddings_0, inp_raw, patch_size, patch_size, 0, 0, 1, 1);
-
-        GGML_ASSERT(img.nx % (patch_size * 2) == 0);
-        GGML_ASSERT(img.ny % (patch_size * 2) == 0);
-
-        // second conv dimension
-        {
-            auto inp_1 = ggml_conv_2d(ctx0, model.patch_embeddings_1, inp_raw, patch_size, patch_size, 0, 0, 1, 1);
-            inp = ggml_add(ctx0, inp, inp_1);
-
-            inp = ggml_permute(ctx0, inp, 1, 2, 0, 3); // [w, h, c, b] -> [c, w, h, b]
-            inp = ggml_cont_4d(
-                ctx0, inp,
-                n_embd * 2, n_patches_x / 2, n_patches_y, batch_size);
-            inp = ggml_reshape_4d(
-                ctx0, inp,
-                n_embd * 2, n_patches_x / 2, 2, batch_size * (n_patches_y / 2));
-            inp = ggml_permute(ctx0, inp, 0, 2, 1, 3);
-            inp = ggml_cont_3d(
-                ctx0, inp,
-                n_embd, n_patches_x * n_patches_y, batch_size);
-        }
-
-        ggml_tensor * inpL = inp;
-        ggml_tensor * window_mask = nullptr;
-        ggml_tensor * window_idx = nullptr;
-        ggml_tensor * inv_window_idx = nullptr;
-
-        ggml_tensor * positions = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, num_position_ids);
-        ggml_set_name(positions, "positions");
-        ggml_set_input(positions);
-
-        // pre-layernorm
-        if (model.pre_ln_w) {
-            inpL = build_norm(inpL, model.pre_ln_w, model.pre_ln_b, norm_t, eps, -1);
-        }
-
-        if (use_window_attn) {
-            // handle window attention inputs
-            inv_window_idx = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_pos / 4);
-            ggml_set_name(inv_window_idx, "inv_window_idx");
-            ggml_set_input(inv_window_idx);
-            // mask for window attention
-            window_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_pos, n_pos);
-            ggml_set_name(window_mask, "window_mask");
-            ggml_set_input(window_mask);
-
-            // if flash attn is used, we need to pad the mask and cast to f16
-            if (ctx->flash_attn_type == CLIP_FLASH_ATTN_TYPE_ENABLED) {
-                window_mask = ggml_cast(ctx0, window_mask, GGML_TYPE_F16);
-            }
-
-            // inpL shape: [n_embd, n_patches_x * n_patches_y, batch_size]
-            GGML_ASSERT(batch_size == 1);
-            inpL = ggml_reshape_2d(ctx0, inpL, n_embd * 4, n_patches_x * n_patches_y * batch_size / 4);
-            inpL = ggml_get_rows(ctx0, inpL, inv_window_idx);
-            inpL = ggml_reshape_3d(ctx0, inpL, n_embd, n_patches_x * n_patches_y, batch_size);
-        }
-
-        // loop over layers
-        for (int il = 0; il < n_layer; il++) {
-            const auto & layer = model.layers[il];
-            const bool full_attn = use_window_attn ? (il + 1) % n_wa_pattern == 0 : true;
-
-            ggml_tensor * cur = inpL; // inpL = residual, cur = hidden_states
-
-            // layernorm1
-            cur = build_norm(cur, layer.ln_1_w, layer.ln_1_b, norm_t, eps, il);
-            cb(cur, "ln1", il);
-
-            // self-attention
-            {
-                ggml_tensor * Qcur = ggml_add(ctx0,
-                    ggml_mul_mat(ctx0, layer.q_w, cur), layer.q_b);
-                ggml_tensor * Kcur = ggml_add(ctx0,
-                    ggml_mul_mat(ctx0, layer.k_w, cur), layer.k_b);
-                ggml_tensor * Vcur = ggml_add(ctx0,
-                    ggml_mul_mat(ctx0, layer.v_w, cur), layer.v_b);
-
-                Qcur = ggml_reshape_3d(ctx0, Qcur, d_head, n_head, n_patches);
-                Kcur = ggml_reshape_3d(ctx0, Kcur, d_head, n_head, n_patches);
-                Vcur = ggml_reshape_3d(ctx0, Vcur, d_head, n_head, n_patches);
-
-                cb(Qcur, "Qcur", il);
-                cb(Kcur, "Kcur", il);
-                cb(Vcur, "Vcur", il);
-
-                // apply M-RoPE
-                Qcur = ggml_rope_multi(
-                    ctx0, Qcur, positions, nullptr,
-                    d_head/2, mrope_sections, GGML_ROPE_TYPE_VISION, 32768, 10000, 1, 0, 1, 32, 1);
-                Kcur = ggml_rope_multi(
-                    ctx0, Kcur, positions, nullptr,
-                    d_head/2, mrope_sections, GGML_ROPE_TYPE_VISION, 32768, 10000, 1, 0, 1, 32, 1);
-
-                cb(Qcur, "Qcur_rope", il);
-                cb(Kcur, "Kcur_rope", il);
-
-                ggml_tensor * attn_mask = full_attn ? nullptr : window_mask;
-
-                cur = build_attn(layer.o_w, layer.o_b,
-                    Qcur, Kcur, Vcur, attn_mask, kq_scale, il);
-                cb(cur, "attn_out", il);
-            }
-
-            // re-add the layer input, e.g., residual
-            cur = ggml_add(ctx0, cur, inpL);
-
-            inpL = cur; // inpL = residual, cur = hidden_states
-
-            cb(cur, "ffn_inp", il);
-
-            // layernorm2
-            cur = build_norm(cur, layer.ln_2_w, layer.ln_2_b, norm_t, eps, il);
-            cb(cur, "ffn_inp_normed", il);
-
-            // ffn
-            cur = build_ffn(cur,
-                layer.ff_up_w, layer.ff_up_b,
-                layer.ff_gate_w, layer.ff_gate_b,
-                layer.ff_down_w, layer.ff_down_b,
-                hparams.ffn_op, il);
-
-            cb(cur, "ffn_out", il);
-
-            // residual 2
-            cur = ggml_add(ctx0, inpL, cur);
-            cb(cur, "layer_out", il);
-
-            inpL = cur;
-        }
-
-        // post-layernorm
-        if (model.post_ln_w) {
-            inpL = build_norm(inpL, model.post_ln_w, model.post_ln_b, norm_t, eps, n_layer);
-        }
-
-        // multimodal projection
-        ggml_tensor * embeddings = inpL;
-        embeddings = ggml_reshape_3d(ctx0, embeddings, n_embd * 4, n_pos / 4, batch_size);
-        embeddings = build_ffn(embeddings,
-            model.mm_0_w, model.mm_0_b,
-            nullptr, nullptr,
-            model.mm_1_w, model.mm_1_b,
-            FFN_GELU,
-            -1);
-
-        if (use_window_attn) {
-            window_idx = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_pos / 4);
-            ggml_set_name(window_idx, "window_idx");
-            ggml_set_input(window_idx);
-
-            // embeddings shape: [n_embd, n_patches_x * n_patches_y, batch_size]
-            GGML_ASSERT(batch_size == 1);
-            embeddings = ggml_reshape_2d(ctx0, embeddings, hparams.projection_dim, n_patches_x * n_patches_y / 4);
-            embeddings = ggml_get_rows(ctx0, embeddings, window_idx);
-            embeddings = ggml_reshape_3d(ctx0, embeddings, hparams.projection_dim, n_patches_x * n_patches_y / 4, batch_size);
-        }
+// siglip2 naflex: resize the learned position embeddings to match the
+// patch grid of the current (possibly non-square) image
+ggml_tensor * clip_graph::resize_position_embeddings() {
+    ggml_tensor * pos_embd = model.position_embeddings;
+    GGML_ASSERT(pos_embd);
+
+    const int height = img.ny / patch_size;
+    const int width = img.nx / patch_size;
+    const uint32_t mode = GGML_SCALE_MODE_BILINEAR | GGML_SCALE_FLAG_ANTIALIAS;
+    const int n_per_side = (int)std::sqrt(pos_embd->ne[1]);

-        // build the graph
-        ggml_build_forward_expand(gf, embeddings);

-        return gf;
+    if (height == n_per_side && width == n_per_side) {
+        return pos_embd;
+    }

-    // Qwen3VL
-    ggml_cgraph * build_qwen3vl() {
-        GGML_ASSERT(model.patch_bias != nullptr);
-        GGML_ASSERT(model.position_embeddings != nullptr);
-        GGML_ASSERT(model.class_embedding == nullptr);
-
-        const int batch_size = 1;
-        const int n_pos = n_patches;
-        const int num_position_ids = n_pos * 4; // m-rope requires 4 dim per position
-
-        norm_type norm_t = NORM_TYPE_NORMAL;
-
-        int mrope_sections[4] = {d_head/4, d_head/4, d_head/4, d_head/4};
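+    // the learned table is a square grid of n_per_side x n_per_side vectors;
+    // view it as a 2D image, bilinearly resize it to the target (width, height)
+    // patch grid, then flatten back to one vector per patch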
+    pos_embd = ggml_reshape_3d(ctx0, pos_embd, n_embd, n_per_side, n_per_side); // -> (n_embd, n_per_side, n_per_side)
+    pos_embd = ggml_permute(ctx0, pos_embd, 2, 0, 1, 3); // -> (n_per_side, n_per_side, n_embd)
+    pos_embd = ggml_interpolate(ctx0, pos_embd, width, height, n_embd, 1, mode); // -> (width, height, n_embd)
+    pos_embd = ggml_permute(ctx0, pos_embd, 1, 2, 0, 3); // -> (n_embd, width, height)
+    pos_embd = ggml_cont_2d(ctx0, pos_embd, n_embd, width * height); // -> (n_embd, width * height)

-        ggml_tensor * inp_raw = build_inp_raw();
-        ggml_tensor * inp = ggml_conv_2d(ctx0, model.patch_embeddings_0, inp_raw, patch_size, patch_size, 0, 0, 1, 1);
-
-        GGML_ASSERT(img.nx % (patch_size * 2) == 0);
-        GGML_ASSERT(img.ny % (patch_size * 2) == 0);
-
-        // second conv dimension
-        {
-            auto inp_1 = ggml_conv_2d(ctx0, model.patch_embeddings_1, inp_raw, patch_size, patch_size, 0, 0, 1, 1);
-            inp = ggml_add(ctx0, inp, inp_1);
-
-            inp = ggml_permute(ctx0, inp, 1, 2, 0, 3); // [w, h, c, b] -> [c, w, h, b]
-            inp = ggml_cont_4d(
-                ctx0, inp,
-                n_embd * 2, n_patches_x / 2, n_patches_y, batch_size);
-            inp = ggml_reshape_4d(
-                ctx0, inp,
-                n_embd * 2, n_patches_x / 2, 2, batch_size * (n_patches_y / 2));
-            inp = ggml_permute(ctx0, inp, 0, 2, 1, 3);
-            inp = ggml_cont_3d(
-                ctx0, inp,
-                n_embd, n_patches_x * n_patches_y, batch_size);
-        }
-
-        // add patch bias
-        if (model.patch_bias != nullptr) {
-            inp = ggml_add(ctx0, inp, model.patch_bias);
-            cb(inp, "patch_bias", -1);
-        }
+    return pos_embd;
+}

-        // calculate absolute position embedding and apply
-        ggml_tensor * learned_pos_embd = resize_position_embeddings();
-        learned_pos_embd = ggml_cont_4d(
-            ctx0, learned_pos_embd,
-            n_embd * 2, n_patches_x / 2, n_patches_y, batch_size);
-        learned_pos_embd = ggml_reshape_4d(
-            ctx0, learned_pos_embd,
-            n_embd * 2, n_patches_x / 2, 2, batch_size * (n_patches_y / 2));
-        learned_pos_embd = ggml_permute(ctx0, learned_pos_embd, 0, 2, 1, 3);
-        learned_pos_embd = ggml_cont_3d(
-            ctx0, learned_pos_embd,
-            n_embd, n_patches_x * n_patches_y, batch_size);
+// build the vision transformer (ViT) cgraph
+// this function should cover most models
+// if your model has specific features, you should probably duplicate this function
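+// learned_pos_embd, if given, is added once to the input embeddings;
+// add_pos, if given, is applied inside each layer (e.g. for RoPE-style 2D positions)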
+ggml_tensor * clip_graph::build_vit(
+        ggml_tensor * inp,
+        int64_t n_pos,
+        norm_type norm_t,
+        ffn_op_type ffn_t,
+        ggml_tensor * learned_pos_embd,
+        std::function<ggml_tensor *(ggml_tensor *, const clip_layer &)> add_pos
+    ) {
+    if (learned_pos_embd) {
         inp = ggml_add(ctx0, inp, learned_pos_embd);
-        cb(inp, "inp_pos_emb", -1);
-
-        ggml_tensor * inpL = inp;
-
-        ggml_tensor * positions = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, num_position_ids);
-        ggml_set_name(positions, "positions");
-        ggml_set_input(positions);
-
-        // pre-layernorm
-        if (model.pre_ln_w) {
-            inpL = build_norm(inpL, model.pre_ln_w, model.pre_ln_b, norm_t, eps, -1);
-        }
-
-        // deepstack features (stack along the feature dimension), [n_embd * len(deepstack_layers), n_patches_x * n_patches_y, batch_size]
-        ggml_tensor * deepstack_features = nullptr;
-        const int merge_factor = hparams.n_merge > 0 ? hparams.n_merge * hparams.n_merge : 4; // default 2x2=4 for qwen3vl
-
-        // loop over layers
-        for (int il = 0; il < n_layer; il++) {
-            auto & layer = model.layers[il];
-
-            ggml_tensor * cur = inpL; // inpL = residual, cur = hidden_states
-
-            // layernorm1
-            cur = build_norm(cur, layer.ln_1_w, layer.ln_1_b, norm_t, eps, il);
-            cb(cur, "ln1", il);
-
-            // self-attention
-            {
-                cur = ggml_mul_mat(ctx0, layer.qkv_w, cur);
-                cur = ggml_add(ctx0, cur, layer.qkv_b);
-
-                ggml_tensor * Qcur = ggml_view_3d(ctx0, cur, d_head, n_head, n_pos,
-                    /* nb1 */ ggml_row_size(cur->type, d_head),
-                    /* nb2 */ cur->nb[1],
-                    /* offset */ 0);
-
-                ggml_tensor * Kcur = ggml_view_3d(ctx0, cur, d_head, n_head, n_pos,
-                    /* nb1 */ ggml_row_size(cur->type, d_head),
-                    /* nb2 */ cur->nb[1],
-                    /* offset */ ggml_row_size(cur->type, n_embd));
-
-                ggml_tensor * Vcur = ggml_view_3d(ctx0, cur, d_head, n_head, n_pos,
-                    /* nb1 */ ggml_row_size(cur->type, d_head),
-                    /* nb2 */ cur->nb[1],
-                    /* offset */ ggml_row_size(cur->type, 2 * n_embd));
-
-                cb(Qcur, "Qcur", il);
-                cb(Kcur, "Kcur", il);
-                cb(Vcur, "Vcur", il);
-
-                // apply M-RoPE
-                Qcur = ggml_rope_multi(
-                    ctx0, Qcur, positions, nullptr,
-                    d_head/2, mrope_sections, GGML_ROPE_TYPE_VISION, 32768, 10000, 1, 0, 1, 32, 1);
-                Kcur = ggml_rope_multi(
-                    ctx0, Kcur, positions, nullptr,
-                    d_head/2, mrope_sections, GGML_ROPE_TYPE_VISION, 32768, 10000, 1, 0, 1, 32, 1);
-
-                cb(Qcur, "Qcur_rope", il);
-                cb(Kcur, "Kcur_rope", il);
-
-                cur = build_attn(layer.o_w, layer.o_b,
-                    Qcur, Kcur, Vcur, nullptr, kq_scale, il);
-                cb(cur, "attn_out", il);
-            }
-
-            // re-add the layer input, e.g., residual
-            cur = ggml_add(ctx0, cur, inpL);
-
-            inpL = cur; // inpL = residual, cur = hidden_states
-
-            cb(cur, "ffn_inp", il);
-
-            // layernorm2
-            cur = build_norm(cur, layer.ln_2_w, layer.ln_2_b, norm_t, eps, il);
-            cb(cur, "ffn_inp_normed", il);
-
-            // ffn
-            cur = build_ffn(cur,
-                layer.ff_up_w, layer.ff_up_b,
-                layer.ff_gate_w, layer.ff_gate_b,
-                layer.ff_down_w, layer.ff_down_b,
-                hparams.ffn_op, il);
-
-            cb(cur, "ffn_out", il);
-
-            // residual 2
-            cur = ggml_add(ctx0, inpL, cur);
-            cb(cur, "layer_out", il);
-
-            if (layer.has_deepstack()) {
-                ggml_tensor * feat = ggml_reshape_3d(ctx0, cur, n_embd * merge_factor, n_pos / merge_factor, batch_size);
-                feat = build_norm(feat, layer.deepstack_norm_w, layer.deepstack_norm_b, norm_t, eps, il);
-                feat = build_ffn(feat,
-                    layer.deepstack_fc1_w, layer.deepstack_fc1_b,
-                    nullptr, nullptr,
-                    layer.deepstack_fc2_w, layer.deepstack_fc2_b,
-                    ffn_op_type::FFN_GELU, il);
-
-                if(!deepstack_features) {
-                    deepstack_features = feat;
-                } else {
-                    // concat along the feature dimension
-                    deepstack_features = ggml_concat(ctx0, deepstack_features, feat, 0);
-                }
-            }
-
-            inpL = cur;
-        }
-
-        // post-layernorm
-        if (model.post_ln_w) {
-            inpL = build_norm(inpL, model.post_ln_w, model.post_ln_b, norm_t, eps, n_layer);
-        }
-
-        // multimodal projection
-        ggml_tensor * embeddings = inpL;
-        embeddings = ggml_reshape_3d(ctx0, embeddings, n_embd * 4, n_pos / 4, batch_size);
-
-        embeddings = build_ffn(embeddings,
-            model.mm_0_w, model.mm_0_b,
-            nullptr, nullptr,
-            model.mm_1_w, model.mm_1_b,
-            ffn_op_type::FFN_GELU, -1);
-
-        embeddings = ggml_concat(ctx0, embeddings, deepstack_features, 0); // concat along the feature dimension
-
-        // build the graph
-        ggml_build_forward_expand(gf, embeddings);
-
-        return gf;
-    }
-
-    ggml_cgraph * build_minicpmv() {
-        GGML_ASSERT(model.class_embedding == nullptr);
-        const int n_pos = n_patches;
-        const int n_embd_proj = clip_n_mmproj_embd(ctx);
-
-        // position embeddings for the projector (not for ViT)
-        // see: https://huggingface.co/openbmb/MiniCPM-o-2_6/blob/main/resampler.py#L70
-        // base frequency omega
-        ggml_tensor * omega = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, n_embd_proj / 4);
-        ggml_set_name(omega, "omega");
-        ggml_set_input(omega);
-
-        // 2D input positions (using float for sinusoidal embeddings)
-        ggml_tensor * pos_h = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, 1, n_pos);
-        ggml_set_name(pos_h, "pos_h");
-        ggml_set_input(pos_h);
-        ggml_tensor * pos_w = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, 1, n_pos);
-        ggml_set_name(pos_w, "pos_w");
-        ggml_set_input(pos_w);
-
-        // for selecting learned pos embd, used by ViT
-        struct ggml_tensor * positions = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_pos);
-        ggml_set_name(positions, "positions");
-        ggml_set_input(positions);
-
-        ggml_tensor * learned_pos_embd = ggml_get_rows(ctx0, model.position_embeddings, positions);
-
-        ggml_tensor * inp = build_inp();
-        ggml_tensor * embeddings = build_vit(
-            inp, n_pos,
-            NORM_TYPE_NORMAL,
-            hparams.ffn_op,
-            learned_pos_embd,
-            nullptr);
-
-        // resampler projector (it is just another transformer)
-
-        ggml_tensor * q = model.mm_model_query;
-        ggml_tensor * v = ggml_mul_mat(ctx0, model.mm_model_kv_proj, embeddings);
-
-        // norm
-        q = build_norm(q, model.mm_model_ln_q_w, model.mm_model_ln_q_b, NORM_TYPE_NORMAL, eps, -1);
-        v = build_norm(v, model.mm_model_ln_kv_w, model.mm_model_ln_kv_b, NORM_TYPE_NORMAL, eps, -1);
-
-        // calculate sinusoidal pos embd
-        ggml_tensor * pos_embed = nullptr;
-        {
-            // outer product
-            ggml_tensor * omega_b = ggml_repeat_4d(ctx0, omega, omega->ne[0], n_pos, 1, 1); // n_pos rows
-            ggml_tensor * theta_x = ggml_mul(ctx0, omega_b, pos_w);
-            ggml_tensor * theta_y = ggml_mul(ctx0, omega_b, pos_h);
-            // sin and cos
-            ggml_tensor * pos_embd_x = ggml_concat(
-                ctx0,
-                ggml_sin(ctx0, theta_x),
-                ggml_cos(ctx0, theta_x),
-                0 // concat on first dim
-            );
-            ggml_tensor * pos_embd_y = ggml_concat(
-                ctx0,
-                ggml_sin(ctx0, theta_y),
-                ggml_cos(ctx0, theta_y),
-                0 // concat on first dim
-            );
-            pos_embed = ggml_concat(ctx0, pos_embd_x, pos_embd_y, 0);
-        }
-
-        // k = v + pos_embed
-        ggml_tensor * k = ggml_add(ctx0, v, pos_embed);
-
-        // attention
-        {
-            const int d_head = 128;
-            int n_head = n_embd_proj/d_head;
-            // Use actual config value if available, otherwise fall back to hardcoded values
-            int num_query = ctx->model.hparams.minicpmv_query_num;
-            ggml_tensor * Q = ggml_add(ctx0,
-                ggml_mul_mat(ctx0, model.mm_model_attn_q_w, q),
-                model.mm_model_attn_q_b);
-            ggml_tensor * K = ggml_add(ctx0,
-                ggml_mul_mat(ctx0, model.mm_model_attn_k_w, k),
-                model.mm_model_attn_k_b);
-            ggml_tensor * V = ggml_add(ctx0,
-                ggml_mul_mat(ctx0, model.mm_model_attn_v_w, v),
-                model.mm_model_attn_v_b);
-
-            Q = ggml_reshape_3d(ctx0, Q, d_head, n_head, num_query);
-            K = ggml_reshape_3d(ctx0, K, d_head, n_head, n_pos);
-            V = ggml_reshape_3d(ctx0, V, d_head, n_head, n_pos);
-
-            cb(Q, "resampler_Q", -1);
-            cb(K, "resampler_K", -1);
-            cb(V, "resampler_V", -1);
-
-            float resampler_kq_scale = 1.0f/ sqrtf(float(d_head));
-            embeddings = build_attn(
-                model.mm_model_attn_o_w,
-                model.mm_model_attn_o_b,
-                Q, K, V, nullptr, resampler_kq_scale, -1);
-            cb(embeddings, "resampler_attn_out", -1);
-        }
-        // layernorm
-        embeddings = build_norm(embeddings, model.mm_model_ln_post_w, model.mm_model_ln_post_b, NORM_TYPE_NORMAL, eps, -1);
-
-        // projection
-        embeddings = ggml_mul_mat(ctx0, model.mm_model_proj, embeddings);
-
-        // build the graph
-        ggml_build_forward_expand(gf, embeddings);
-
-        return gf;
-    }
-
-    ggml_cgraph * build_internvl() {
-        GGML_ASSERT(model.class_embedding != nullptr);
-        GGML_ASSERT(model.position_embeddings != nullptr);
-
-        const int n_pos = n_patches + 1;
-        ggml_tensor * inp = build_inp();
-
-        // add CLS token
-        inp = ggml_concat(ctx0, inp, model.class_embedding, 1);
-
-        // The larger models use a different ViT, which uses RMS norm instead of layer norm
-        // ref: https://github.com/ggml-org/llama.cpp/pull/13443#issuecomment-2869786188
-        norm_type norm_t = (hparams.n_embd == 3200 && hparams.n_layer == 45)
-            ? NORM_TYPE_RMS // 6B ViT (Used by InternVL 2.5/3 - 26B, 38B, 78B)
-            : NORM_TYPE_NORMAL; // 300M ViT (Used by all smaller InternVL models)
-
-        ggml_tensor * cur = build_vit(
-            inp, n_pos,
-            norm_t,
-            hparams.ffn_op,
-            model.position_embeddings,
-            nullptr);
-
-        // remove CLS token
-        cur = ggml_view_2d(ctx0, cur,
-            n_embd, n_patches,
-            ggml_row_size(cur->type, n_embd), 0);
-
-        // pixel shuffle
-        {
-            const int scale_factor = model.hparams.n_merge;
-            const int bsz = 1; // batch size, always 1 for now since we don't support batching
-            const int height = n_patches_y;
-            const int width = n_patches_x;
-            GGML_ASSERT(scale_factor > 0);
-            cur = ggml_reshape_4d(ctx0, cur, n_embd * scale_factor, height / scale_factor, width, bsz);
-            cur = ggml_permute(ctx0, cur, 0, 2, 1, 3);
-            cur = ggml_cont_4d(ctx0, cur,
-                n_embd * scale_factor * scale_factor,
-                height / scale_factor,
-                width / scale_factor,
-                bsz);
-            cur = ggml_permute(ctx0, cur, 0, 2, 1, 3);
-            // flatten to 2D
-            cur = ggml_cont_2d(ctx0, cur,
-                n_embd * scale_factor * scale_factor,
-                cur->ne[1] * cur->ne[2]);
-        }
-
-        // projector (always using GELU activation)
-        {
-            // projector LayerNorm uses pytorch's default eps = 1e-5
-            // ref: https://huggingface.co/OpenGVLab/InternVL3-8B-Instruct/blob/a34d3e4e129a5856abfd6aa6de79776484caa14e/modeling_internvl_chat.py#L79
-            cur = build_norm(cur, model.mm_0_w, model.mm_0_b, NORM_TYPE_NORMAL, 1e-5, -1);
-            cur = build_ffn(cur,
-                model.mm_1_w, model.mm_1_b,
-                nullptr, nullptr,
-                model.mm_3_w, model.mm_3_b,
-                FFN_GELU,
-                -1);
-        }
-
-        // build the graph
-        ggml_build_forward_expand(gf, cur);
-
-        return gf;
-    }
-
-    ggml_cgraph * build_llama4() {
-        GGML_ASSERT(model.class_embedding != nullptr);
-        GGML_ASSERT(model.position_embeddings != nullptr);
-
-        const int n_pos = n_patches + 1; // +1 for [CLS]
-
-        // 2D input positions
-        ggml_tensor * pos_h = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_pos);
-        ggml_set_name(pos_h, "pos_h");
-        ggml_set_input(pos_h);
-
-        ggml_tensor * pos_w = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_pos);
-        ggml_set_name(pos_w, "pos_w");
-        ggml_set_input(pos_w);
-
-        ggml_tensor * inp = build_inp_raw();
-
-        // Llama4UnfoldConvolution
-        {
-            ggml_tensor * kernel = ggml_reshape_4d(ctx0, model.patch_embeddings_0,
-                patch_size, patch_size, 3, n_embd);
-            inp = ggml_im2col(ctx0, kernel, inp, patch_size, patch_size, 0, 0, 1, 1, true, inp->type);
-            inp = ggml_mul_mat(ctx0, model.patch_embeddings_0, inp);
-            inp = ggml_reshape_2d(ctx0, inp, n_embd, n_patches);
-            cb(inp, "patch_conv", -1);
-        }
-
-        // add CLS token
-        inp = ggml_concat(ctx0, inp, model.class_embedding, 1);
-
-        // build ViT with 2D position embeddings
-        auto add_pos = [&](ggml_tensor * cur, const clip_layer &) {
-            // first half is X axis and second half is Y axis
-            // ref: https://github.com/huggingface/transformers/blob/40a493c7ed4f19f08eadb0639cf26d49bfa5e180/src/transformers/models/llama4/modeling_llama4.py#L1312
-            // ref: https://github.com/Blaizzy/mlx-vlm/blob/a57156aa87b33cca6e5ee6cfc14dd4ef8f611be6/mlx_vlm/models/llama4/vision.py#L441
-            return build_rope_2d(ctx0, cur, pos_w, pos_h, hparams.rope_theta, false);
-        };
-        ggml_tensor * cur = build_vit(
-            inp, n_pos,
-            NORM_TYPE_NORMAL,
-            hparams.ffn_op,
-            model.position_embeddings,
-            add_pos);
-
-        // remove CLS token
-        cur = ggml_view_2d(ctx0, cur,
-            n_embd, n_patches,
-            ggml_row_size(cur->type, n_embd), 0);
-
-        // pixel shuffle
-        // based on Llama4VisionPixelShuffleMLP
-        // https://github.com/huggingface/transformers/blob/2932f318a20d9e54cc7aea052e040164d85de7d6/src/transformers/models/llama4/modeling_llama4.py#L1151
-        {
-            const int scale_factor = model.hparams.n_merge;
-            const int bsz = 1; // batch size, always 1 for now since we don't support batching
-            GGML_ASSERT(scale_factor > 0);
-            GGML_ASSERT(n_patches_x == n_patches_y); // llama4 only supports square images
-            cur = ggml_reshape_4d(ctx0, cur,
-                n_embd * scale_factor,
-                n_patches_x / scale_factor,
-                n_patches_y,
-                bsz);
-            cur = ggml_permute(ctx0, cur, 0, 2, 1, 3);
-            cur = ggml_cont_4d(ctx0, cur,
-                n_embd * scale_factor * scale_factor,
-                n_patches_x / scale_factor,
-                n_patches_y / scale_factor,
-                bsz);
-            //cur = ggml_permute(ctx0, cur, 0, 2, 1, 3);
-            // flatten to 2D
-            cur = ggml_cont_2d(ctx0, cur,
-                n_embd * scale_factor * scale_factor,
-                n_patches / scale_factor / scale_factor);
-            cb(cur, "pixel_shuffle", -1);
-        }
-
-        // based on Llama4VisionMLP2 (always uses GELU activation, no bias)
-        {
-            cur = ggml_mul_mat(ctx0, model.mm_model_mlp_1_w, cur);
-            cur = ggml_gelu(ctx0, cur);
-            cur = ggml_mul_mat(ctx0, model.mm_model_mlp_2_w, cur);
-            cur = ggml_gelu(ctx0, cur);
-            cb(cur, "adapter_mlp", -1);
-        }
-
-        // Llama4MultiModalProjector
-        cur = ggml_mul_mat(ctx0, model.mm_model_proj, cur);
-        cb(cur, "projected", -1);
-
-        // build the graph
-        ggml_build_forward_expand(gf, cur);
-
-        return gf;
+        cb(inp, "pos_embed", -1);
     }

-    ggml_cgraph * build_kimivl() {
-        // 2D input positions
-        ggml_tensor * pos_h = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_patches);
-        ggml_set_name(pos_h, "pos_h");
-        ggml_set_input(pos_h);
+    ggml_tensor * inpL = inp;

-        ggml_tensor * pos_w = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_patches);
-        ggml_set_name(pos_w, "pos_w");
-        ggml_set_input(pos_w);
-
-        ggml_tensor * learned_pos_embd = resize_position_embeddings();
-
-        // build ViT with 2D position embeddings
-        auto add_pos = [&](ggml_tensor * cur, const clip_layer &) {
-            // first half is X axis and second half is Y axis
-            return build_rope_2d(ctx0, cur, pos_w, pos_h, hparams.rope_theta, false);
-        };
-
-        ggml_tensor * inp = build_inp();
-        ggml_tensor * cur = build_vit(
-            inp, n_patches,
-            NORM_TYPE_NORMAL,
-            hparams.ffn_op,
-            learned_pos_embd,
-            add_pos);
-
-        cb(cur, "vit_out", -1);
-
-        {
-            // patch_merger
-            const int scale_factor = model.hparams.n_merge;
-            cur = build_patch_merge_permute(cur, scale_factor);
-
-            // projection norm
-            int proj_inp_dim = cur->ne[0];
-            cur = ggml_view_2d(ctx0, cur,
-                n_embd, cur->ne[1] * scale_factor * scale_factor,
-                ggml_row_size(cur->type, n_embd), 0);
-            cur = ggml_norm(ctx0, cur, 1e-5); // default nn.LayerNorm
-            cur = ggml_mul(ctx0, cur, model.mm_input_norm_w);
-            cur = ggml_add(ctx0, cur, model.mm_input_norm_b);
-            cur = ggml_view_2d(ctx0, cur,
-                proj_inp_dim, cur->ne[1] / scale_factor / scale_factor,
-                ggml_row_size(cur->type, proj_inp_dim), 0);
-            cb(cur, "proj_inp_normed", -1);
-
-            // projection mlp
-            cur = build_ffn(cur,
-                model.mm_1_w, model.mm_1_b,
-                nullptr, nullptr,
-                model.mm_2_w, model.mm_2_b,
-                FFN_GELU,
-                -1);
-            cb(cur, "proj_out", -1);
-        }
-
-        // build the graph
-        ggml_build_forward_expand(gf, cur);
-
-        return gf;
+    // pre-layernorm
+    if (model.pre_ln_w) {
+        inpL = build_norm(inpL, model.pre_ln_w, model.pre_ln_b, norm_t, eps, -1);
+        cb(inpL, "pre_ln", -1);
     }

-    // this graph is used by llava, granite and glm
-    // due to having embedding_stack (used by granite), we cannot reuse build_vit
-    ggml_cgraph * build_llava() {
-        const int batch_size = 1;
-        const int n_pos = n_patches + (model.class_embedding ? 1 : 0);
+    // loop over layers
+    for (int il = 0; il < n_layer; il++) {
+        auto & layer = model.layers[il];
+        ggml_tensor * cur = inpL; // inpL = residual, cur = hidden_states

-        GGML_ASSERT(n_patches_x == n_patches_y && "only square images supported");
+        // layernorm1
+        cur = build_norm(cur, layer.ln_1_w, layer.ln_1_b, norm_t, eps, il);
+        cb(cur, "layer_inp_normed", il);

-        // Calculate the deepest feature layer based on hparams and projector type
-        int max_feature_layer = n_layer;
+        // self-attention
         {
-            // Get the index of the second to last layer; this is the default for models that have a llava projector
-            int il_last = hparams.n_layer - 1;
-            int deepest_feature_layer = -1;
-
-            if (ctx->proj_type() == PROJECTOR_TYPE_MINICPMV || ctx->proj_type() == PROJECTOR_TYPE_GLM_EDGE) {
-                il_last += 1;
-            }
-
-            // If we set explicit vision feature layers, only go up to the deepest one
-            // NOTE: only used by granite-vision models for now
-            for (const auto & feature_layer : hparams.vision_feature_layer) {
-                if (feature_layer > deepest_feature_layer) {
-                    deepest_feature_layer = feature_layer;
+            ggml_tensor * Qcur = nullptr;
+            ggml_tensor * Kcur = nullptr;
+            ggml_tensor * Vcur = nullptr;
+            if (layer.qkv_w != nullptr) {
+                // fused qkv
+                cur = ggml_mul_mat(ctx0, layer.qkv_w, cur);
+                if (layer.qkv_b != nullptr) {
+                    cur = ggml_add(ctx0, cur, layer.qkv_b);
                 }
-            }
-            max_feature_layer = deepest_feature_layer < 0 ? il_last : deepest_feature_layer;
-        }
-
-        ggml_tensor * inp = build_inp();
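+                // the fused qkv projection lays its rows out as [Q; K; V] along
+                // dim 0, so Q/K/V are taken as three (d_head, n_head, n_pos) views
+                // into the same buffer, offset by 0, n_embd and 2*n_embd rows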
|
|
|
+ Qcur = ggml_view_3d(ctx0, cur, d_head, n_head, n_pos,
|
|
|
+ /* nb1 */ ggml_row_size(cur->type, d_head),
|
|
|
+ /* nb2 */ cur->nb[1],
|
|
|
+ /* offset */ 0);
|
|
|
|
|
|
- // concat class_embeddings and patch_embeddings
|
|
|
- if (model.class_embedding) {
|
|
|
- inp = ggml_concat(ctx0, inp, model.class_embedding, 1);
|
|
|
- }
|
|
|
-
|
|
|
- ggml_tensor * positions = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_pos);
|
|
|
- ggml_set_name(positions, "positions");
|
|
|
- ggml_set_input(positions);
|
|
|
-
|
|
|
- inp = ggml_add(ctx0, inp, ggml_get_rows(ctx0, model.position_embeddings, positions));
|
|
|
-
|
|
|
- ggml_tensor * inpL = inp;
|
|
|
-
|
|
|
- // pre-layernorm
|
|
|
- if (model.pre_ln_w) {
|
|
|
- inpL = build_norm(inpL, model.pre_ln_w, model.pre_ln_b, NORM_TYPE_NORMAL, eps, -1);
|
|
|
- cb(inpL, "pre_ln", -1);
|
|
|
- }
|
|
|
-
|
|
|
- std::vector<ggml_tensor *> embedding_stack;
|
|
|
- const auto & vision_feature_layer = hparams.vision_feature_layer;
|
|
|
+ Kcur = ggml_view_3d(ctx0, cur, d_head, n_head, n_pos,
|
|
|
+ /* nb1 */ ggml_row_size(cur->type, d_head),
|
|
|
+ /* nb2 */ cur->nb[1],
|
|
|
+ /* offset */ ggml_row_size(cur->type, n_embd));
|
|
|
|
|
|
- // loop over layers
|
|
|
- for (int il = 0; il < max_feature_layer; il++) {
|
|
|
- auto & layer = model.layers[il];
|
|
|
- ggml_tensor * cur = inpL; // inpL = residual, cur = hidden_states
|
|
|
-
|
|
|
- // If this is an embedding feature layer, save the output.
|
|
|
- // NOTE: 0 index here refers to the input to the encoder.
|
|
|
- if (vision_feature_layer.find(il) != vision_feature_layer.end()) {
|
|
|
- embedding_stack.push_back(cur);
|
|
|
- }
|
|
|
+ Vcur = ggml_view_3d(ctx0, cur, d_head, n_head, n_pos,
|
|
|
+ /* nb1 */ ggml_row_size(cur->type, d_head),
|
|
|
+ /* nb2 */ cur->nb[1],
|
|
|
+ /* offset */ ggml_row_size(cur->type, 2 * n_embd));
|
|
|
|
|
|
- // layernorm1
|
|
|
- cur = build_norm(cur, layer.ln_1_w, layer.ln_1_b, NORM_TYPE_NORMAL, eps, il);
|
|
|
- cb(cur, "layer_inp_normed", il);
|
|
|
+ // TODO: q/k norm requires row size == n_embd, while here it's d_head
|
|
|
+ // we can add support in the future if needed
|
|
|
+ GGML_ASSERT(layer.q_norm == nullptr && layer.k_norm == nullptr);

-            // self-attention
-            {
-                ggml_tensor * Qcur = ggml_mul_mat(ctx0, layer.q_w, cur);
+            } else {
+                // separate q, k, v
+                Qcur = ggml_mul_mat(ctx0, layer.q_w, cur);
                 if (layer.q_b) {
                     Qcur = ggml_add(ctx0, Qcur, layer.q_b);
                 }

-                ggml_tensor * Kcur = ggml_mul_mat(ctx0, layer.k_w, cur);
+                Kcur = ggml_mul_mat(ctx0, layer.k_w, cur);
                 if (layer.k_b) {
                     Kcur = ggml_add(ctx0, Kcur, layer.k_b);
                 }

-                ggml_tensor * Vcur = ggml_mul_mat(ctx0, layer.v_w, cur);
+                Vcur = ggml_mul_mat(ctx0, layer.v_w, cur);
                 if (layer.v_b) {
                     Vcur = ggml_add(ctx0, Vcur, layer.v_b);
                 }

-                Qcur = ggml_reshape_3d(ctx0, Qcur, d_head, n_head, n_pos);
-                Kcur = ggml_reshape_3d(ctx0, Kcur, d_head, n_head, n_pos);
-                Vcur = ggml_reshape_3d(ctx0, Vcur, d_head, n_head, n_pos);
-
-                cb(Qcur, "Qcur", il);
-                cb(Kcur, "Kcur", il);
-                cb(Vcur, "Vcur", il);
-
-                cur = build_attn(layer.o_w, layer.o_b,
-                    Qcur, Kcur, Vcur, nullptr, kq_scale, il);
-                cb(cur, "attn_out", il);
-            }
-
-            // re-add the layer input, e.g., residual
-            cur = ggml_add(ctx0, cur, inpL);
-
-            inpL = cur; // inpL = residual, cur = hidden_states
-
-            cb(cur, "ffn_inp", il);
-
-            // layernorm2
-            cur = build_norm(cur, layer.ln_2_w, layer.ln_2_b, NORM_TYPE_NORMAL, eps, il);
-            cb(cur, "ffn_inp_normed", il);
-
-            // ffn
-            cur = build_ffn(cur,
-                layer.ff_up_w, layer.ff_up_b,
-                layer.ff_gate_w, layer.ff_gate_b,
-                layer.ff_down_w, layer.ff_down_b,
-                hparams.ffn_op, il);
-
-            cb(cur, "ffn_out", il);
-
-            // residual 2
-            cur = ggml_add(ctx0, inpL, cur);
-            cb(cur, "layer_out", il);
-
-            inpL = cur;
-        }
-
-        // post-layernorm
-        if (model.post_ln_w) {
-            inpL = build_norm(inpL, model.post_ln_w, model.post_ln_b, NORM_TYPE_NORMAL, eps, -1);
-        }
-
-        ggml_tensor * embeddings = inpL;
-
-        // process vision feature layers (used by granite)
-        {
-            // final layer is a vision feature layer
-            if (vision_feature_layer.find(max_feature_layer) != vision_feature_layer.end()) {
-                embedding_stack.push_back(inpL);
-            }
-
-            // If feature layers are explicitly set, stack them (if we have multiple)
-            if (!embedding_stack.empty()) {
-                embeddings = embedding_stack[0];
-                for (size_t i = 1; i < embedding_stack.size(); i++) {
-                    embeddings = ggml_concat(ctx0, embeddings, embedding_stack[i], 0);
+                if (layer.q_norm) {
+                    Qcur = build_norm(Qcur, layer.q_norm, NULL, norm_t, eps, il);
+                    cb(Qcur, "Qcur_norm", il);
                 }
-                }
-            }
-        }
-
-        // llava projector (also used by granite)
-        if (ctx->model.hparams.has_llava_projector) {
-            embeddings = ggml_reshape_2d(ctx0, embeddings, embeddings->ne[0], embeddings->ne[1]);
-
-            ggml_tensor * patches = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_patches);
-            ggml_set_name(patches, "patches");
-            ggml_set_input(patches);
-
-            // shape [1, 576, 1024]
-            // ne is whcn, ne = [1024, 576, 1, 1]
-            embeddings = ggml_get_rows(ctx0, embeddings, patches);

-            // print_tensor_info(embeddings, "embeddings");
-
-            // llava projector
-            if (ctx->proj_type() == PROJECTOR_TYPE_MLP) {
-                embeddings = ggml_mul_mat(ctx0, model.mm_0_w, embeddings);
-                embeddings = ggml_add(ctx0, embeddings, model.mm_0_b);
-
-                embeddings = ggml_gelu(ctx0, embeddings);
-                if (model.mm_2_w) {
-                    embeddings = ggml_mul_mat(ctx0, model.mm_2_w, embeddings);
-                    embeddings = ggml_add(ctx0, embeddings, model.mm_2_b);
+                if (layer.k_norm) {
+                    Kcur = build_norm(Kcur, layer.k_norm, NULL, norm_t, eps, il);
+                    cb(Kcur, "Kcur_norm", il);
                 }
-            }
-            else if (ctx->proj_type() == PROJECTOR_TYPE_MLP_NORM) {
-                embeddings = ggml_mul_mat(ctx0, model.mm_0_w, embeddings);
-                embeddings = ggml_add(ctx0, embeddings, model.mm_0_b);
-                // ggml_tensor_printf(embeddings, "mm_0_w",0,true,false);
-                // First LayerNorm
-                embeddings = ggml_norm(ctx0, embeddings, eps);
-                embeddings = ggml_add(ctx0, ggml_mul(ctx0, embeddings, model.mm_1_w),
-                    model.mm_1_b);
-
-                // GELU activation
-                embeddings = ggml_gelu(ctx0, embeddings);
-
-                // Second linear layer
-                embeddings = ggml_mul_mat(ctx0, model.mm_3_w, embeddings);
-                embeddings = ggml_add(ctx0, embeddings, model.mm_3_b);
-
-                // Second LayerNorm
-                embeddings = ggml_norm(ctx0, embeddings, eps);
-                embeddings = ggml_add(ctx0, ggml_mul(ctx0, embeddings, model.mm_4_w),
-                    model.mm_4_b);
-            }
-            else if (ctx->proj_type() == PROJECTOR_TYPE_LDP) {
-                // MobileVLM projector
-                int n_patch = 24;
-                ggml_tensor * mlp_1 = ggml_mul_mat(ctx0, model.mm_model_mlp_1_w, embeddings);
-                mlp_1 = ggml_add(ctx0, mlp_1, model.mm_model_mlp_1_b);
-                mlp_1 = ggml_gelu(ctx0, mlp_1);
-                ggml_tensor * mlp_3 = ggml_mul_mat(ctx0, model.mm_model_mlp_3_w, mlp_1);
-                mlp_3 = ggml_add(ctx0, mlp_3, model.mm_model_mlp_3_b);
-                // mlp_3 shape = [1, 576, 2048], ne = [2048, 576, 1, 1]
-
-                // block 1
-                ggml_tensor * block_1 = nullptr;
-                {
-                    // transpose from [1, 576, 2048] --> [1, 2048, 576] --> [1, 2048, 24, 24]
-                    mlp_3 = ggml_permute(ctx0, mlp_3, 1, 0, 2, 3);
-                    mlp_3 = ggml_cont_4d(ctx0, mlp_3, n_patch, n_patch, mlp_3->ne[1], mlp_3->ne[2]);
-                    // stride = 1, padding = 1, bias is nullptr
-                    block_1 = ggml_conv_2d_dw(ctx0, model.mm_model_block_1_block_0_0_w, mlp_3, 1, 1, 1, 1, 1, 1);
-
-                    // layer norm
-                    // // block_1 shape = [1, 2048, 24, 24], ne = [24, 24, 2048, 1]
-                    block_1 = ggml_cont(ctx0, ggml_permute(ctx0, block_1, 1, 2, 0, 3));
-                    // block_1 shape = [1, 24, 24, 2048], ne = [2048, 24, 24, 1]
-                    block_1 = ggml_norm(ctx0, block_1, eps);
-                    block_1 = ggml_add(ctx0, ggml_mul(ctx0, block_1, model.mm_model_block_1_block_0_1_w), model.mm_model_block_1_block_0_1_b);
-                    block_1 = ggml_cont(ctx0, ggml_permute(ctx0, block_1, 2, 0, 1, 3));
-
-                    // block_1 shape = [1, 2048, 24, 24], ne = [24, 24, 2048, 1]
-                    // hardswish
-                    ggml_tensor * block_1_hw = ggml_hardswish(ctx0, block_1);
-
-                    block_1 = ggml_pool_2d(ctx0, block_1_hw, GGML_OP_POOL_AVG, block_1_hw->ne[0], block_1_hw->ne[1], block_1_hw->ne[0], block_1_hw->ne[1], 0, 0);
-                    // block_1 shape = [1, 2048, 1, 1], ne = [1, 1, 2048, 1]
-                    // pointwise conv
-                    block_1 = ggml_reshape_2d(ctx0, block_1, block_1->ne[0]*block_1->ne[1]*block_1->ne[2], block_1->ne[3]);
-                    block_1 = ggml_mul_mat(ctx0, model.mm_model_block_1_block_1_fc1_w, block_1);
-                    block_1 = ggml_add(ctx0, block_1, model.mm_model_block_1_block_1_fc1_b);
-                    block_1 = ggml_relu(ctx0, block_1);
-                    block_1 = ggml_mul_mat(ctx0, model.mm_model_block_1_block_1_fc2_w, block_1);
-                    block_1 = ggml_add(ctx0, block_1, model.mm_model_block_1_block_1_fc2_b);
-                    block_1 = ggml_hardsigmoid(ctx0, block_1);
-                    // block_1_hw shape = [1, 2048, 24, 24], ne = [24, 24, 2048, 1], block_1 shape = [1, 2048], ne = [2048, 1, 1, 1]
-                    block_1 = ggml_reshape_4d(ctx0, block_1, 1, 1, block_1->ne[0], block_1->ne[1]);
-                    block_1 = ggml_mul(ctx0, block_1_hw, block_1);
-
-                    int w = block_1->ne[0], h = block_1->ne[1];
-                    block_1 = ggml_reshape_3d(ctx0, block_1, w*h, block_1->ne[2], block_1->ne[3]);
-                    block_1 = ggml_cont(ctx0, ggml_permute(ctx0, block_1, 1, 0, 2, 3));
-
-                    // block_1 shape = [1, 24*24, 2048], ne = [24*24, 2048, 1]
-                    block_1 = ggml_mul_mat(ctx0, model.mm_model_block_1_block_2_0_w, block_1);
-                    block_1 = ggml_reshape_4d(ctx0, block_1, block_1->ne[0], w, h, block_1->ne[3]);
-
-                    // block_1 shape = [1, 24, 24, 2048], ne = [2048, 24, 24, 1]
-                    block_1 = ggml_norm(ctx0, block_1, eps);
-                    block_1 = ggml_add(ctx0, ggml_mul(ctx0, block_1, model.mm_model_block_1_block_2_1_w), model.mm_model_block_1_block_2_1_b);
-                    block_1 = ggml_cont(ctx0, ggml_permute(ctx0, block_1, 2, 0, 1, 3));
-                    // block1 shape = [1, 2048, 24, 24], ne = [24, 24, 2048, 1]
-                    // residual
-                    block_1 = ggml_add(ctx0, mlp_3, block_1);
-                }
-
-                // block_2
-                {
-                    // stride = 2
-                    block_1 = ggml_conv_2d_dw(ctx0, model.mm_model_block_2_block_0_0_w, block_1, 2, 2, 1, 1, 1, 1);
-
-                    // block_1 shape = [1, 2048, 12, 12], ne = [12, 12, 2048, 1]
-                    // layer norm
-                    block_1 = ggml_cont(ctx0, ggml_permute(ctx0, block_1, 1, 2, 0, 3));
-                    // block_1 shape = [1, 12, 12, 2048], ne = [2048, 12, 12, 1]
-                    block_1 = ggml_norm(ctx0, block_1, eps);
-                    block_1 = ggml_add(ctx0, ggml_mul(ctx0, block_1, model.mm_model_block_2_block_0_1_w), model.mm_model_block_2_block_0_1_b);
-                    block_1 = ggml_cont(ctx0, ggml_permute(ctx0, block_1, 2, 0, 1, 3));
-                    // block_1 shape = [1, 2048, 12, 12], ne = [12, 12, 2048, 1]
-                    // hardswish
-                    ggml_tensor * block_1_hw = ggml_hardswish(ctx0, block_1);
-
-                    // not sure the parameters is right for globalAvgPooling
-                    block_1 = ggml_pool_2d(ctx0, block_1_hw, GGML_OP_POOL_AVG, block_1_hw->ne[0], block_1_hw->ne[1], block_1_hw->ne[0], block_1_hw->ne[1], 0, 0);
-                    // block_1 shape = [1, 2048, 1, 1], ne = [1, 1, 2048, 1]
-                    // pointwise conv
-                    block_1 = ggml_reshape_2d(ctx0, block_1, block_1->ne[0]*block_1->ne[1]*block_1->ne[2], block_1->ne[3]);
-                    block_1 = ggml_mul_mat(ctx0, model.mm_model_block_2_block_1_fc1_w, block_1);
-                    block_1 = ggml_add(ctx0, block_1, model.mm_model_block_2_block_1_fc1_b);
-                    block_1 = ggml_relu(ctx0, block_1);
-                    block_1 = ggml_mul_mat(ctx0, model.mm_model_block_2_block_1_fc2_w, block_1);
-                    block_1 = ggml_add(ctx0, block_1, model.mm_model_block_2_block_1_fc2_b);
-                    block_1 = ggml_hardsigmoid(ctx0, block_1);
-
-                    // block_1_hw shape = [1, 2048, 12, 12], ne = [12, 12, 2048, 1], block_1 shape = [1, 2048, 1, 1], ne = [1, 1, 2048, 1]
-                    block_1 = ggml_reshape_4d(ctx0, block_1, 1, 1, block_1->ne[0], block_1->ne[1]);
-                    block_1 = ggml_mul(ctx0, block_1_hw, block_1);
-
-                    int w = block_1->ne[0], h = block_1->ne[1];
-                    block_1 = ggml_reshape_3d(ctx0, block_1, w*h, block_1->ne[2], block_1->ne[3]);
-                    block_1 = ggml_cont(ctx0, ggml_permute(ctx0, block_1, 1, 0, 2, 3));
-                    // block_1 shape = [1, 24*24, 2048], ne = [24*24, 2048, 1]
-                    block_1 = ggml_mul_mat(ctx0, model.mm_model_block_2_block_2_0_w, block_1);
-                    block_1 = ggml_reshape_4d(ctx0, block_1, block_1->ne[0], w, h, block_1->ne[3]);
-
-
-                    // block_1 shape = [1, 12, 12, 2048], ne = [2048, 12, 12, 1]
-                    block_1 = ggml_norm(ctx0, block_1, eps);
-                    block_1 = ggml_add(ctx0, ggml_mul(ctx0, block_1, model.mm_model_block_2_block_2_1_w), model.mm_model_block_2_block_2_1_b);
-                    block_1 = ggml_reshape_3d(ctx0, block_1, block_1->ne[0], block_1->ne[1] * block_1->ne[2], block_1->ne[3]);
-                    // block_1 shape = [1, 144, 2048], ne = [2048, 144, 1]
-                }
-                embeddings = block_1;
-            }
-            else if (ctx->proj_type() == PROJECTOR_TYPE_LDPV2)
-            {
-                int n_patch = 24;
-                ggml_tensor * mlp_0 = ggml_mul_mat(ctx0, model.mm_model_mlp_0_w, embeddings);
-                mlp_0 = ggml_add(ctx0, mlp_0, model.mm_model_mlp_0_b);
-                mlp_0 = ggml_gelu(ctx0, mlp_0);
-                ggml_tensor * mlp_2 = ggml_mul_mat(ctx0, model.mm_model_mlp_2_w, mlp_0);
-                mlp_2 = ggml_add(ctx0, mlp_2, model.mm_model_mlp_2_b);
-                // mlp_2 ne = [2048, 576, 1, 1]
-                // // AVG Pool Layer 2*2, strides = 2
-                mlp_2 = ggml_permute(ctx0, mlp_2, 1, 0, 2, 3);
-                // mlp_2 ne = [576, 2048, 1, 1]
-                mlp_2 = ggml_cont_4d(ctx0, mlp_2, n_patch, n_patch, mlp_2->ne[1], mlp_2->ne[2]);
-                // mlp_2 ne [24, 24, 2048, 1]
-                mlp_2 = ggml_pool_2d(ctx0, mlp_2, GGML_OP_POOL_AVG, 2, 2, 2, 2, 0, 0);
-                // weight ne = [3, 3, 2048, 1]
-                ggml_tensor * peg_0 = ggml_conv_2d_dw(ctx0, model.mm_model_peg_0_w, mlp_2, 1, 1, 1, 1, 1, 1);
-                peg_0 = ggml_cont(ctx0, ggml_permute(ctx0, peg_0, 1, 2, 0, 3));
-                peg_0 = ggml_add(ctx0, peg_0, model.mm_model_peg_0_b);
-                mlp_2 = ggml_cont(ctx0, ggml_permute(ctx0, mlp_2, 1, 2, 0, 3));
-                peg_0 = ggml_add(ctx0, peg_0, mlp_2);
-                peg_0 = ggml_reshape_3d(ctx0, peg_0, peg_0->ne[0], peg_0->ne[1] * peg_0->ne[2], peg_0->ne[3]);
-                embeddings = peg_0;
-            }
-            else {
-                GGML_ABORT("fatal error");
-            }
-        }

-        // glm projector
-        else if (ctx->proj_type() == PROJECTOR_TYPE_GLM_EDGE) {
-            size_t gridsz = (size_t)sqrt(embeddings->ne[1]);
-            embeddings = ggml_permute(ctx0,embeddings,1,0,2,3);
-            embeddings = ggml_cont_3d(ctx0, embeddings, gridsz, gridsz, embeddings->ne[1]);
-            embeddings = ggml_conv_2d(ctx0, model.mm_model_adapter_conv_w, embeddings, 2, 2, 0, 0, 1, 1);
-            embeddings = ggml_reshape_3d(ctx0, embeddings,embeddings->ne[0]*embeddings->ne[1] , embeddings->ne[2], batch_size);
-            embeddings = ggml_cont(ctx0, ggml_permute(ctx0,embeddings, 1, 0, 2, 3));
-            embeddings = ggml_add(ctx0, embeddings, model.mm_model_adapter_conv_b);
-            // GLU
-            {
-                embeddings = ggml_mul_mat(ctx0, model.mm_model_mlp_0_w, embeddings);
-                embeddings = ggml_norm(ctx0, embeddings, eps);
-                embeddings = ggml_add(ctx0, ggml_mul(ctx0, embeddings, model.mm_model_ln_q_w), model.mm_model_ln_q_b);
-                embeddings = ggml_gelu_inplace(ctx0, embeddings);
-                ggml_tensor * x = embeddings;
-                embeddings = ggml_mul_mat(ctx0, model.mm_model_mlp_2_w, embeddings);
-                x = ggml_mul_mat(ctx0, model.mm_model_mlp_1_w,x);
-                embeddings = ggml_swiglu_split(ctx0, embeddings, x);
-                embeddings = ggml_mul_mat(ctx0, model.mm_model_mlp_3_w, embeddings);
-            }
-            // arrangement of BOI/EOI token embeddings
-            // note: these embeddings are not present in text model, hence we cannot process them as text tokens
-            // see: https://huggingface.co/THUDM/glm-edge-v-2b/blob/main/siglip.py#L53
-            {
-                embeddings = ggml_concat(ctx0, model.mm_boi, embeddings, 1); // BOI
-                embeddings = ggml_concat(ctx0, embeddings, model.mm_eoi, 1); // EOI
-            }
-        }
-
-        else {
-            GGML_ABORT("llava: unknown projector type");
-        }
-
-        // build the graph
-        ggml_build_forward_expand(gf, embeddings);
-
-        return gf;
-    }
-    // whisper encoder with custom projector
-    ggml_cgraph * build_whisper_enc() {
-        const int n_frames = img.nx;
-        const int n_pos = n_frames / 2;
-        GGML_ASSERT(model.position_embeddings->ne[1] >= n_pos);
-
-        ggml_tensor * inp = build_inp_raw(1);
-
-        // conv1d block
-        {
-            // convolution + gelu
-            ggml_tensor * cur = ggml_conv_1d_ph(ctx0, model.conv1d_1_w, inp, 1, 1);
-            cur = ggml_add(ctx0, cur, model.conv1d_1_b);
-
-            cur = ggml_gelu_erf(ctx0, cur);
-
-            cur = ggml_conv_1d_ph(ctx0, model.conv1d_2_w, cur, 2, 1);
-            cur = ggml_add(ctx0, cur, model.conv1d_2_b);
-
-            cur = ggml_gelu_erf(ctx0, cur);
-            // transpose
-            inp = ggml_cont(ctx0, ggml_transpose(ctx0, cur));
-            cb(inp, "after_conv1d", -1);
-        }
-
-        // sanity check (only check one layer, but it should be the same for all)
-        GGML_ASSERT(model.layers[0].ln_1_w && model.layers[0].ln_1_b);
-        GGML_ASSERT(model.layers[0].ln_2_w && model.layers[0].ln_2_b);
-        GGML_ASSERT(model.layers[0].q_b);
-        GGML_ASSERT(model.layers[0].v_b);
-        GGML_ASSERT(!model.layers[0].k_b); // no bias for k
-        GGML_ASSERT(model.post_ln_w && model.post_ln_b);
-
-        ggml_tensor * pos_embd_selected = ggml_view_2d(
-            ctx0, model.position_embeddings,
-            model.position_embeddings->ne[0], n_pos,
-            model.position_embeddings->nb[1], 0
-        );
-        ggml_tensor * cur = build_vit(
-            inp, n_pos,
-            NORM_TYPE_NORMAL,
-            hparams.ffn_op,
-            pos_embd_selected,
-            nullptr);
-
-        cb(cur, "after_transformer", -1);
-
-        if (model.audio_has_stack_frames()) {
-            // StackAudioFrames
-            // https://huggingface.co/fixie-ai/ultravox-v0_5-llama-3_2-1b/blob/main/ultravox_model.py
-            int64_t stride = n_embd * hparams.proj_stack_factor;
-            int64_t padded_len = GGML_PAD(ggml_nelements(cur), stride);
-            int64_t pad = padded_len - ggml_nelements(cur);
-            if (pad > 0) {
-                cur = ggml_view_1d(ctx0, cur, ggml_nelements(cur), 0);
-                cur = ggml_pad(ctx0, cur, pad, 0, 0, 0);
+                Qcur = ggml_reshape_3d(ctx0, Qcur, d_head, n_head, n_pos);
+                Kcur = ggml_reshape_3d(ctx0, Kcur, d_head, n_head, n_pos);
+                Vcur = ggml_reshape_3d(ctx0, Vcur, d_head, n_head, n_pos);
             }
-            cur = ggml_view_2d(ctx0, cur, stride, padded_len / stride,
-                ggml_row_size(cur->type, stride), 0);
-            cb(cur, "after_stacked", -1);
-        }
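[editor's note] The removed StackAudioFrames logic above concatenates every hparams.proj_stack_factor consecutive audio frames into a single row, padding the flattened tensor so its length divides evenly by the stride. A worked example with illustrative values (not taken from the patch): with n_embd = 4 and proj_stack_factor = 2, the stride is 8; five frames give 20 elements, so padded_len = GGML_PAD(20, 8) = 24 and pad = 4, and the final ggml_view_2d yields an 8 x 3 matrix, i.e. three stacked super-frames.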
-
-        if (ctx->proj_type() == PROJECTOR_TYPE_ULTRAVOX) {
-            // UltravoxProjector
-            // pre-norm
-            cur = ggml_rms_norm(ctx0, cur, 1e-6);
-            cur = ggml_mul(ctx0, cur, model.mm_norm_pre_w);
-
-            // ffn in
-            cur = ggml_mul_mat(ctx0, model.mm_1_w, cur);
-
-            // swiglu
-            // see SwiGLU in ultravox_model.py, the second half passed through is silu, not the first half
-            cur = ggml_swiglu_swapped(ctx0, cur);
-
-            // mid-norm
-            cur = ggml_rms_norm(ctx0, cur, 1e-6);
-            cur = ggml_mul(ctx0, cur, model.mm_norm_mid_w);
-
-            // ffn out
-            cur = ggml_mul_mat(ctx0, model.mm_2_w, cur);
-
-        } else if (ctx->proj_type() == PROJECTOR_TYPE_QWEN2A) {
-            // projector
-            cur = ggml_mul_mat(ctx0, model.mm_fc_w, cur);
-            cur = ggml_add(ctx0, cur, model.mm_fc_b);
-
-        } else if (ctx->proj_type() == PROJECTOR_TYPE_VOXTRAL) {
-            // projector
-            cur = build_ffn(cur,
-                model.mm_1_w, model.mm_1_b,
-                nullptr, nullptr,
-                model.mm_2_w, model.mm_2_b,
-                FFN_GELU_ERF,
-                -1);
-
-        } else {
-            GGML_ABORT("%s: unknown projector type", __func__);
-        }
-
-        cb(cur, "projected", -1);
-
-        ggml_build_forward_expand(gf, cur);
-
-        return gf;
-    }
-
-    // cogvlm vision encoder
-    ggml_cgraph * build_cogvlm() {
-        GGML_ASSERT(model.class_embedding != nullptr);
-        GGML_ASSERT(model.position_embeddings != nullptr);
-
-        const int n_pos = n_patches + 1; // +1 for [CLS]
-
-        // build input and concatenate class embedding
-        ggml_tensor * inp = build_inp();
-        inp = ggml_concat(ctx0, inp, model.class_embedding, 1);
-
-        inp = ggml_add(ctx0, inp, model.position_embeddings);
-        cb(inp, "inp_pos", -1);
-
-        ggml_tensor * inpL = inp;
-
-        for (int il = 0; il < n_layer; il++) {
-            auto & layer = model.layers[il];
-            ggml_tensor * cur = inpL;
-
-            cur = ggml_mul_mat(ctx0, layer.qkv_w, cur);
-
-            cur = ggml_add(ctx0, cur, layer.qkv_b);
-
-            ggml_tensor * Qcur = ggml_view_3d(ctx0, cur, d_head, n_head, n_pos, d_head*sizeof(float),
-                cur->nb[1], 0);
-            ggml_tensor * Kcur = ggml_view_3d(ctx0, cur, d_head, n_head, n_pos, d_head*sizeof(float),
-                cur->nb[1], n_embd * sizeof(float));
-            ggml_tensor * Vcur = ggml_view_3d(ctx0, cur, d_head, n_head, n_pos, d_head*sizeof(float),
-                cur->nb[1], 2 * n_embd * sizeof(float));

            cb(Qcur, "Qcur", il);
            cb(Kcur, "Kcur", il);
            cb(Vcur, "Vcur", il);

+            if (add_pos) {
+                Qcur = add_pos(Qcur, layer);
+                Kcur = add_pos(Kcur, layer);
+                cb(Qcur, "Qcur_pos", il);
+                cb(Kcur, "Kcur_pos", il);
+            }
+
            cur = build_attn(layer.o_w, layer.o_b,
                Qcur, Kcur, Vcur, nullptr, kq_scale, il);
            cb(cur, "attn_out", il);
-
-            cur = build_norm(cur, layer.ln_1_w, layer.ln_1_b, NORM_TYPE_NORMAL, eps, il);
-            cb(cur, "attn_post_norm", il);
-
-            cur = ggml_add(ctx0, cur, inpL);
-            inpL = cur;
-
-            cur = build_ffn(cur,
-                layer.ff_up_w, layer.ff_up_b,
-                layer.ff_gate_w, layer.ff_gate_b,
-                layer.ff_down_w, layer.ff_down_b,
-                hparams.ffn_op, il);
-
-            cb(cur, "ffn_out", il);
-
-            cur = build_norm(cur, layer.ln_2_w, layer.ln_2_b, NORM_TYPE_NORMAL, eps, il);
-            cb(cur, "ffn_post_norm", il);
-
-            cur = ggml_add(ctx0, cur, inpL);
-            cb(cur, "layer_out", il);
-            inpL = cur;
-
        }

-        // remove CLS token (like build_llama4 does)
-        ggml_tensor * cur = ggml_view_2d(ctx0, inpL,
-            n_embd, n_patches,
-            ggml_row_size(inpL->type, n_embd), 0);
+        if (layer.ls_1_w) {
+            cur = ggml_mul(ctx0, cur, layer.ls_1_w);
+            cb(cur, "attn_out_scaled", il);
+        }

-        // Multiply with mm_model_proj
-        cur = ggml_mul_mat(ctx0, model.mm_model_proj, cur);
+        // re-add the layer input, e.g., residual
+        cur = ggml_add(ctx0, cur, inpL);

-        // Apply layernorm, weight, bias
-        cur = build_norm(cur, model.mm_post_fc_norm_w, model.mm_post_fc_norm_b, NORM_TYPE_NORMAL, 1e-5, -1);
+        inpL = cur; // inpL = residual, cur = hidden_states

-        // Apply GELU
-        cur = ggml_gelu_inplace(ctx0, cur);
+        cb(cur, "ffn_inp", il);

-        // Branch 1: multiply with mm_h_to_4h_w
-        ggml_tensor * h_to_4h = ggml_mul_mat(ctx0, model.mm_h_to_4h_w, cur);
+        // layernorm2
+        cur = build_norm(cur, layer.ln_2_w, layer.ln_2_b, norm_t, eps, il);
+        cb(cur, "ffn_inp_normed", il);

-        // Branch 2: multiply with mm_gate_w
-        ggml_tensor * gate = ggml_mul_mat(ctx0, model.mm_gate_w, cur);
+        // ffn
+        cur = build_ffn(cur,
+            layer.ff_up_w, layer.ff_up_b,
+            layer.ff_gate_w, layer.ff_gate_b,
+            layer.ff_down_w, layer.ff_down_b,
+            ffn_t, il);

-        // Apply silu
-        gate = ggml_swiglu_split(ctx0, gate, h_to_4h);
+        cb(cur, "ffn_out", il);

-        // Apply mm_4h_to_h_w
-        cur = ggml_mul_mat(ctx0, model.mm_4h_to_h_w, gate);
+        if (layer.ls_2_w) {
+            cur = ggml_mul(ctx0, cur, layer.ls_2_w);
+            cb(cur, "ffn_out_scaled", il);
+        }

-        // Concatenate with boi and eoi
-        cur = ggml_concat(ctx0, model.mm_boi, cur, 1);
-        cur = ggml_concat(ctx0, cur, model.mm_eoi, 1);
+        // residual 2
+        cur = ggml_add(ctx0, inpL, cur);
+        cb(cur, "layer_out", il);

-        // build the graph
-        ggml_build_forward_expand(gf, cur);
-
-        return gf;
+        inpL = cur;
    }

-private:
-    //
-    // utility functions
-    //
-
-    void cb(ggml_tensor * cur0, const char * name, int il) const {
-        if (ctx->debug_graph) {
-            ggml_tensor * cur = ggml_cpy(ctx0, cur0, ggml_dup_tensor(ctx0, cur0));
-            std::string cur_name = il >= 0 ? std::string(name) + "_" + std::to_string(il) : name;
-            ggml_set_name(cur, cur_name.c_str());
-            ggml_set_output(cur);
-            ggml_build_forward_expand(gf, cur);
-            ctx->debug_print_tensors.push_back(cur);
-        }
+    if (model.audio_has_avgpool()) {
+        ggml_tensor * cur = inpL;
+        cur = ggml_transpose(ctx0, cur);
+        cur = ggml_cont(ctx0, cur);
+        cur = ggml_pool_1d(ctx0, cur, GGML_OP_POOL_AVG, 2, 2, 0);
+        cur = ggml_transpose(ctx0, cur);
+        cur = ggml_cont(ctx0, cur);
+        inpL = cur;
    }

-    // siglip2 naflex
-    ggml_tensor * resize_position_embeddings() {
-        ggml_tensor * pos_embd = model.position_embeddings;
-        const int height = img.ny / patch_size;
-        const int width = img.nx / patch_size;
-        const uint32_t mode = GGML_SCALE_MODE_BILINEAR | GGML_SCALE_FLAG_ANTIALIAS;
-        const int n_per_side = (int)std::sqrt(pos_embd->ne[1]);
-
-        GGML_ASSERT(pos_embd);
-
-        if (height == n_per_side && width == n_per_side) {
-            return pos_embd;
-        }
-
-        pos_embd = ggml_reshape_3d(ctx0, pos_embd, n_embd, n_per_side, n_per_side); // -> (n_embd, n_per_side, n_per_side)
-        pos_embd = ggml_permute(ctx0, pos_embd, 2, 0, 1, 3); // -> (n_per_side, n_per_side, n_embd)
-        pos_embd = ggml_interpolate(ctx0, pos_embd, width, height, n_embd, 1, mode); // -> (width, height, n_embd)
-        pos_embd = ggml_permute(ctx0, pos_embd, 1, 2, 0, 3); // -> (n_embd, width, height)
-        pos_embd = ggml_cont_2d(ctx0, pos_embd, n_embd, width * height); // -> (n_embd, width * height)
-
-        return pos_embd;
+    // post-layernorm
+    if (model.post_ln_w) {
+        inpL = build_norm(inpL, model.post_ln_w, model.post_ln_b, norm_t, eps, -1);
    }
+    return inpL;
+}

-    // build vision transformer (ViT) cgraph
-    // this function should cover most of the models
-    // if your model has specific features, you should probably duplicate this function
-    ggml_tensor * build_vit(
-            ggml_tensor * inp,
-            int64_t n_pos,
-            norm_type norm_t,
-            ffn_op_type ffn_t,
-            ggml_tensor * learned_pos_embd,
-            std::function<ggml_tensor *(ggml_tensor *, const clip_layer &)> add_pos
-            ) {
-        if (learned_pos_embd) {
-            inp = ggml_add(ctx0, inp, learned_pos_embd);
-            cb(inp, "pos_embed", -1);
-        }
-
-        ggml_tensor * inpL = inp;
-
-        // pre-layernorm
-        if (model.pre_ln_w) {
-            inpL = build_norm(inpL, model.pre_ln_w, model.pre_ln_b, norm_t, eps, -1);
-            cb(inpL, "pre_ln", -1);
-        }
-
-        // loop over layers
-        for (int il = 0; il < n_layer; il++) {
-            auto & layer = model.layers[il];
-            ggml_tensor * cur = inpL; // inpL = residual, cur = hidden_states
-
-            // layernorm1
-            cur = build_norm(cur, layer.ln_1_w, layer.ln_1_b, norm_t, eps, il);
-            cb(cur, "layer_inp_normed", il);
-
-            // self-attention
-            {
-                ggml_tensor * Qcur = nullptr;
-                ggml_tensor * Kcur = nullptr;
-                ggml_tensor * Vcur = nullptr;
-                if (layer.qkv_w != nullptr) {
-                    // fused qkv
-                    cur = ggml_mul_mat(ctx0, layer.qkv_w, cur);
-                    if (layer.qkv_b != nullptr) {
-                        cur = ggml_add(ctx0, cur, layer.qkv_b);
-                    }
-
-                    Qcur = ggml_view_3d(ctx0, cur, d_head, n_head, n_pos,
-                            /* nb1 */ ggml_row_size(cur->type, d_head),
-                            /* nb2 */ cur->nb[1],
-                            /* offset */ 0);
-
-                    Kcur = ggml_view_3d(ctx0, cur, d_head, n_head, n_pos,
-                            /* nb1 */ ggml_row_size(cur->type, d_head),
-                            /* nb2 */ cur->nb[1],
-                            /* offset */ ggml_row_size(cur->type, n_embd));
-
-                    Vcur = ggml_view_3d(ctx0, cur, d_head, n_head, n_pos,
-                            /* nb1 */ ggml_row_size(cur->type, d_head),
-                            /* nb2 */ cur->nb[1],
-                            /* offset */ ggml_row_size(cur->type, 2 * n_embd));
-
-                    // TODO: q/k norm requires row size == n_embd, while here it's d_head
-                    // we can add support in the future if needed
-                    GGML_ASSERT(layer.q_norm == nullptr && layer.k_norm == nullptr);
-
-                } else {
-                    // separate q, k, v
-                    Qcur = ggml_mul_mat(ctx0, layer.q_w, cur);
-                    if (layer.q_b) {
-                        Qcur = ggml_add(ctx0, Qcur, layer.q_b);
-                    }
-
-                    Kcur = ggml_mul_mat(ctx0, layer.k_w, cur);
-                    if (layer.k_b) {
-                        Kcur = ggml_add(ctx0, Kcur, layer.k_b);
-                    }
-
-                    Vcur = ggml_mul_mat(ctx0, layer.v_w, cur);
-                    if (layer.v_b) {
-                        Vcur = ggml_add(ctx0, Vcur, layer.v_b);
-                    }
-
-                    if (layer.q_norm) {
-                        Qcur = build_norm(Qcur, layer.q_norm, NULL, norm_t, eps, il);
-                        cb(Qcur, "Qcur_norm", il);
-                    }
-
-                    if (layer.k_norm) {
-                        Kcur = build_norm(Kcur, layer.k_norm, NULL, norm_t, eps, il);
-                        cb(Kcur, "Kcur_norm", il);
-                    }
-
-                    Qcur = ggml_reshape_3d(ctx0, Qcur, d_head, n_head, n_pos);
-                    Kcur = ggml_reshape_3d(ctx0, Kcur, d_head, n_head, n_pos);
-                    Vcur = ggml_reshape_3d(ctx0, Vcur, d_head, n_head, n_pos);
-                }
-
-                cb(Qcur, "Qcur", il);
-                cb(Kcur, "Kcur", il);
-                cb(Vcur, "Vcur", il);
-
-                if (add_pos) {
-                    Qcur = add_pos(Qcur, layer);
-                    Kcur = add_pos(Kcur, layer);
-                    cb(Qcur, "Qcur_pos", il);
-                    cb(Kcur, "Kcur_pos", il);
-                }
-
-                cur = build_attn(layer.o_w, layer.o_b,
-                    Qcur, Kcur, Vcur, nullptr, kq_scale, il);
-                cb(cur, "attn_out", il);
-            }
-
-            if (layer.ls_1_w) {
-                cur = ggml_mul(ctx0, cur, layer.ls_1_w);
-                cb(cur, "attn_out_scaled", il);
-            }
-
-            // re-add the layer input, e.g., residual
-            cur = ggml_add(ctx0, cur, inpL);
-
-            inpL = cur; // inpL = residual, cur = hidden_states
-
-            cb(cur, "ffn_inp", il);
-
-            // layernorm2
-            cur = build_norm(cur, layer.ln_2_w, layer.ln_2_b, norm_t, eps, il);
-            cb(cur, "ffn_inp_normed", il);
-
-            // ffn
-            cur = build_ffn(cur,
-                layer.ff_up_w, layer.ff_up_b,
-                layer.ff_gate_w, layer.ff_gate_b,
-                layer.ff_down_w, layer.ff_down_b,
-                ffn_t, il);
-
-            cb(cur, "ffn_out", il);
+// build the input after conv2d (inp_raw --> patches)
+// returns tensor with shape [n_embd, n_patches]
+ggml_tensor * clip_graph::build_inp() {
+    ggml_tensor * inp_raw = build_inp_raw();
+    ggml_tensor * inp = ggml_conv_2d(ctx0, model.patch_embeddings_0, inp_raw, patch_size, patch_size, 0, 0, 1, 1);
+    inp = ggml_reshape_2d(ctx0, inp, n_patches, n_embd);
+    inp = ggml_cont(ctx0, ggml_transpose(ctx0, inp));
+    if (model.patch_bias) {
+        inp = ggml_add(ctx0, inp, model.patch_bias);
+        cb(inp, "patch_bias", -1);
+    }
+    return inp;
+}
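[editor's note] Since stride equals the kernel size in the ggml_conv_2d call above, each output column corresponds to exactly one non-overlapping patch. A worked shape trace under illustrative assumptions (a 224x224 input with patch_size = 14; the numbers are not from the patch):

    inp_raw            : ne = [224, 224, 3]        // img.nx, img.ny, channels
    after ggml_conv_2d : ne = [16, 16, n_embd]     // 224 / 14 = 16 patches per side
    after reshape_2d   : ne = [n_patches, n_embd]  // n_patches = 256
    after transpose    : ne = [n_embd, n_patches]  // one embedding column per patch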

-            if (layer.ls_2_w) {
-                cur = ggml_mul(ctx0, cur, layer.ls_2_w);
-                cb(cur, "ffn_out_scaled", il);
-            }
+ggml_tensor * clip_graph::build_inp_raw(int channels) {
+    ggml_tensor * inp_raw = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, img.nx, img.ny, channels);
+    ggml_set_name(inp_raw, "inp_raw");
+    ggml_set_input(inp_raw);
+    return inp_raw;
+}

-            // residual 2
-            cur = ggml_add(ctx0, inpL, cur);
-            cb(cur, "layer_out", il);
+ggml_tensor * clip_graph::build_norm(
+        ggml_tensor * cur,
+        ggml_tensor * mw,
+        ggml_tensor * mb,
+        norm_type type,
+        float norm_eps,
+        int il) const {

-            inpL = cur;
-        }
+    cur = type == NORM_TYPE_RMS
+        ? ggml_rms_norm(ctx0, cur, norm_eps)
+        : ggml_norm(ctx0, cur, norm_eps);

-        if (ctx->model.audio_has_avgpool()) {
-            ggml_tensor * cur = inpL;
-            cur = ggml_transpose(ctx0, cur);
-            cur = ggml_cont(ctx0, cur);
-            cur = ggml_pool_1d(ctx0, cur, GGML_OP_POOL_AVG, 2, 2, 0);
-            cur = ggml_transpose(ctx0, cur);
-            cur = ggml_cont(ctx0, cur);
-            inpL = cur;
-        }
-
-        // post-layernorm
-        if (model.post_ln_w) {
-            inpL = build_norm(inpL, model.post_ln_w, model.post_ln_b, norm_t, eps, -1);
-        }
-        return inpL;
+    if (mw || mb) {
+        cb(cur, "norm", il);
    }

-    // build the input after conv2d (inp_raw --> patches)
-    // returns tensor with shape [n_embd, n_patches]
-    ggml_tensor * build_inp() {
-        ggml_tensor * inp_raw = build_inp_raw();
-        ggml_tensor * inp = ggml_conv_2d(ctx0, model.patch_embeddings_0, inp_raw, patch_size, patch_size, 0, 0, 1, 1);
-        inp = ggml_reshape_2d(ctx0, inp, n_patches, n_embd);
-        inp = ggml_cont(ctx0, ggml_transpose(ctx0, inp));
-        if (model.patch_bias) {
-            inp = ggml_add(ctx0, inp, model.patch_bias);
-            cb(inp, "patch_bias", -1);
+    if (mw) {
+        cur = ggml_mul(ctx0, cur, mw);
+        if (mb) {
+            cb(cur, "norm_w", il);
        }
-        return inp;
    }

-    ggml_tensor * build_inp_raw(int channels = 3) {
-        ggml_tensor * inp_raw = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, img.nx, img.ny, channels);
-        ggml_set_name(inp_raw, "inp_raw");
-        ggml_set_input(inp_raw);
-        return inp_raw;
+    if (mb) {
+        cur = ggml_add(ctx0, cur, mb);
    }

-    ggml_tensor * build_norm(
-            ggml_tensor * cur,
-            ggml_tensor * mw,
-            ggml_tensor * mb,
-            norm_type type,
-            float norm_eps,
-            int il) const {
-
-        cur = type == NORM_TYPE_RMS
-            ? ggml_rms_norm(ctx0, cur, norm_eps)
-            : ggml_norm(ctx0, cur, norm_eps);
-
-        if (mw || mb) {
-            cb(cur, "norm", il);
-        }
+    return cur;
+}
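[editor's note] build_norm above selects between the two standard normalizations; for reference, the usual definitions, with w = mw and b = mb applied only when present:

    \mathrm{LN}(x) = \frac{x - \mu}{\sqrt{\sigma^2 + \epsilon}} \cdot w + b
    \qquad
    \mathrm{RMSNorm}(x) = \frac{x}{\sqrt{\tfrac{1}{n}\sum_i x_i^2 + \epsilon}} \cdot w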

-        if (mw) {
-            cur = ggml_mul(ctx0, cur, mw);
-            if (mb) {
-                cb(cur, "norm_w", il);
-            }
-        }
+ggml_tensor * clip_graph::build_ffn(
+        ggml_tensor * cur,
+        ggml_tensor * up,
+        ggml_tensor * up_b,
+        ggml_tensor * gate,
+        ggml_tensor * gate_b,
+        ggml_tensor * down,
+        ggml_tensor * down_b,
+        ffn_op_type type_op,
+        int il) const {

-        if (mb) {
-            cur = ggml_add(ctx0, cur, mb);
-        }
+    ggml_tensor * tmp = up ? ggml_mul_mat(ctx0, up, cur) : cur;
+    cb(tmp, "ffn_up", il);

-        return cur;
+    if (up_b) {
+        tmp = ggml_add(ctx0, tmp, up_b);
+        cb(tmp, "ffn_up_b", il);
    }

-    ggml_tensor * build_ffn(
-            ggml_tensor * cur,
-            ggml_tensor * up,
-            ggml_tensor * up_b,
-            ggml_tensor * gate,
-            ggml_tensor * gate_b,
-            ggml_tensor * down,
-            ggml_tensor * down_b,
-            ffn_op_type type_op,
-            int il) const {
-
-        ggml_tensor * tmp = up ? ggml_mul_mat(ctx0, up, cur) : cur;
-        cb(tmp, "ffn_up", il);
-
-        if (up_b) {
-            tmp = ggml_add(ctx0, tmp, up_b);
-            cb(tmp, "ffn_up_b", il);
-        }
-
-        if (gate) {
-            cur = ggml_mul_mat(ctx0, gate, cur);
-            cb(cur, "ffn_gate", il);
-
-            if (gate_b) {
-                cur = ggml_add(ctx0, cur, gate_b);
-                cb(cur, "ffn_gate_b", il);
-            }
-        } else {
-            cur = tmp;
-        }
-
-        // we only support parallel ffn for now
-        switch (type_op) {
-            case FFN_SILU:
-                if (gate) {
-                    cur = ggml_swiglu_split(ctx0, cur, tmp);
-                    cb(cur, "ffn_swiglu", il);
-                } else {
-                    cur = ggml_silu(ctx0, cur);
-                    cb(cur, "ffn_silu", il);
-                } break;
-            case FFN_GELU:
-                if (gate) {
-                    cur = ggml_geglu_split(ctx0, cur, tmp);
-                    cb(cur, "ffn_geglu", il);
-                } else {
-                    cur = ggml_gelu(ctx0, cur);
-                    cb(cur, "ffn_gelu", il);
-                } break;
-            case FFN_GELU_ERF:
-                if (gate) {
-                    cur = ggml_geglu_erf_split(ctx0, cur, tmp);
-                    cb(cur, "ffn_geglu_erf", il);
-                } else {
-                    cur = ggml_gelu_erf(ctx0, cur);
-                    cb(cur, "ffn_gelu_erf", il);
-                } break;
-            case FFN_GELU_QUICK:
-                if (gate) {
-                    cur = ggml_geglu_quick_split(ctx0, cur, tmp);
-                    cb(cur, "ffn_geglu_quick", il);
-                } else {
-                    cur = ggml_gelu_quick(ctx0, cur);
-                    cb(cur, "ffn_gelu_quick", il);
-                } break;
-        }
+    if (gate) {
+        cur = ggml_mul_mat(ctx0, gate, cur);
+        cb(cur, "ffn_gate", il);

-        if (down) {
-            cur = ggml_mul_mat(ctx0, down, cur);
+        if (gate_b) {
+            cur = ggml_add(ctx0, cur, gate_b);
+            cb(cur, "ffn_gate_b", il);
        }
+    } else {
+        cur = tmp;
+    }

-        if (down_b) {
-            cb(cur, "ffn_down", il);
-        }
+    // we only support parallel ffn for now
+    switch (type_op) {
+        case FFN_SILU:
+            if (gate) {
+                cur = ggml_swiglu_split(ctx0, cur, tmp);
+                cb(cur, "ffn_swiglu", il);
+            } else {
+                cur = ggml_silu(ctx0, cur);
+                cb(cur, "ffn_silu", il);
+            } break;
+        case FFN_GELU:
+            if (gate) {
+                cur = ggml_geglu_split(ctx0, cur, tmp);
+                cb(cur, "ffn_geglu", il);
+            } else {
+                cur = ggml_gelu(ctx0, cur);
+                cb(cur, "ffn_gelu", il);
+            } break;
+        case FFN_GELU_ERF:
+            if (gate) {
+                cur = ggml_geglu_erf_split(ctx0, cur, tmp);
+                cb(cur, "ffn_geglu_erf", il);
+            } else {
+                cur = ggml_gelu_erf(ctx0, cur);
+                cb(cur, "ffn_gelu_erf", il);
+            } break;
+        case FFN_GELU_QUICK:
+            if (gate) {
+                cur = ggml_geglu_quick_split(ctx0, cur, tmp);
+                cb(cur, "ffn_geglu_quick", il);
+            } else {
+                cur = ggml_gelu_quick(ctx0, cur);
+                cb(cur, "ffn_gelu_quick", il);
+            } break;
+    }

-        if (down_b) {
-            cur = ggml_add(ctx0, cur, down_b);
-        }
+    if (down) {
+        cur = ggml_mul_mat(ctx0, down, cur);
+    }

-        return cur;
+    if (down_b) {
+        cb(cur, "ffn_down", il);
    }

-    ggml_tensor * build_attn(
-            ggml_tensor * wo,
-            ggml_tensor * wo_b,
-            ggml_tensor * q_cur,
-            ggml_tensor * k_cur,
-            ggml_tensor * v_cur,
-            ggml_tensor * kq_mask,
-            float kq_scale,
-            int il) const {
-        // these nodes are added to the graph together so that they are not reordered
-        // by doing so, the number of splits in the graph is reduced
-        ggml_build_forward_expand(gf, q_cur);
-        ggml_build_forward_expand(gf, k_cur);
-        ggml_build_forward_expand(gf, v_cur);
+    if (down_b) {
+        cur = ggml_add(ctx0, cur, down_b);
+    }

-        ggml_tensor * q = ggml_permute(ctx0, q_cur, 0, 2, 1, 3);
-        //cb(q, "q", il);
+    return cur;
+}
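[editor's note] For reference, the gated path of build_ffn computes the standard parallel gated FFN; with the activation \sigma selected by type_op (SiLU gives SwiGLU, GELU gives GEGLU, and so on):

    \mathrm{FFN}(x) = W_{\mathrm{down}}\,\big(\sigma(W_{\mathrm{gate}}\,x + b_{\mathrm{gate}}) \odot (W_{\mathrm{up}}\,x + b_{\mathrm{up}})\big) + b_{\mathrm{down}}

Without a gate tensor this degenerates to the classic two-layer MLP, W_down \sigma(W_up x + b_up) + b_down.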

-        ggml_tensor * k = ggml_permute(ctx0, k_cur, 0, 2, 1, 3);
-        //cb(k, "k", il);
+ggml_tensor * clip_graph::build_attn(
+        ggml_tensor * wo,
+        ggml_tensor * wo_b,
+        ggml_tensor * q_cur,
+        ggml_tensor * k_cur,
+        ggml_tensor * v_cur,
+        ggml_tensor * kq_mask,
+        float kq_scale,
+        int il) const {
+    // these nodes are added to the graph together so that they are not reordered
+    // by doing so, the number of splits in the graph is reduced
+    ggml_build_forward_expand(gf, q_cur);
+    ggml_build_forward_expand(gf, k_cur);
+    ggml_build_forward_expand(gf, v_cur);

-        ggml_tensor * cur;
+    ggml_tensor * q = ggml_permute(ctx0, q_cur, 0, 2, 1, 3);
+    //cb(q, "q", il);

-        if (ctx->flash_attn_type == CLIP_FLASH_ATTN_TYPE_ENABLED) {
-            ggml_tensor * v = ggml_permute(ctx0, v_cur, 0, 2, 1, 3);
+    ggml_tensor * k = ggml_permute(ctx0, k_cur, 0, 2, 1, 3);
+    //cb(k, "k", il);

-            k = ggml_cast(ctx0, k, GGML_TYPE_F16);
-            v = ggml_cast(ctx0, v, GGML_TYPE_F16);
+    ggml_tensor * cur;

-            cur = ggml_flash_attn_ext(ctx0, q, k, v, kq_mask, kq_scale, 0.0f, 0.0f);
-            ggml_flash_attn_ext_set_prec(cur, GGML_PREC_F32);
+    if (flash_attn_type == CLIP_FLASH_ATTN_TYPE_ENABLED) {
+        ggml_tensor * v = ggml_permute(ctx0, v_cur, 0, 2, 1, 3);

-            cur = ggml_reshape_2d(ctx0, cur, cur->ne[0]*cur->ne[1], cur->ne[2]*cur->ne[3]);
+        k = ggml_cast(ctx0, k, GGML_TYPE_F16);
+        v = ggml_cast(ctx0, v, GGML_TYPE_F16);

-        } else {
-            ggml_tensor * v = ggml_permute(ctx0, v_cur, 1, 2, 0, 3);
-            v = ggml_cont(ctx0, v);
+        cur = ggml_flash_attn_ext(ctx0, q, k, v, kq_mask, kq_scale, 0.0f, 0.0f);
+        ggml_flash_attn_ext_set_prec(cur, GGML_PREC_F32);

-            const auto n_tokens = q->ne[1];
-            const auto n_head = q->ne[2];
+        cur = ggml_reshape_2d(ctx0, cur, cur->ne[0]*cur->ne[1], cur->ne[2]*cur->ne[3]);

-            ggml_tensor * kq = ggml_mul_mat(ctx0, k, q);
-            // F32 may not needed for vision encoders?
-            // ggml_mul_mat_set_prec(kq, GGML_PREC_F32);
+    } else {
+        ggml_tensor * v = ggml_permute(ctx0, v_cur, 1, 2, 0, 3);
+        v = ggml_cont(ctx0, v);

-            kq = ggml_soft_max_ext(ctx0, kq, kq_mask, kq_scale, 0.0f);
+        const auto n_tokens = q->ne[1];
+        const auto n_head = q->ne[2];

-            ggml_tensor * kqv = ggml_mul_mat(ctx0, v, kq);
-            cur = ggml_permute(ctx0, kqv, 0, 2, 1, 3);
-            cur = ggml_cont_2d(ctx0, cur, cur->ne[0]*n_head, n_tokens);
-        }
+        ggml_tensor * kq = ggml_mul_mat(ctx0, k, q);
+        // F32 may not be needed for vision encoders?
+        // ggml_mul_mat_set_prec(kq, GGML_PREC_F32);

-        cb(cur, "kqv_out", il);
+        kq = ggml_soft_max_ext(ctx0, kq, kq_mask, kq_scale, 0.0f);

-        if (wo) {
-            cur = ggml_mul_mat(ctx0, wo, cur);
-        }
+        ggml_tensor * kqv = ggml_mul_mat(ctx0, v, kq);
+        cur = ggml_permute(ctx0, kqv, 0, 2, 1, 3);
+        cur = ggml_cont_2d(ctx0, cur, cur->ne[0]*n_head, n_tokens);
+    }

-        if (wo_b) {
-            cur = ggml_add(ctx0, cur, wo_b);
-        }
+    cb(cur, "kqv_out", il);

-        return cur;
+    if (wo) {
+        cur = ggml_mul_mat(ctx0, wo, cur);
    }

-    // implementation of the 2D RoPE without adding a new op in ggml
-    // this is not efficient (use double the memory), but works on all backends
-    // TODO: there was a more efficient which relies on ggml_view and ggml_rope_ext_inplace, but the rope inplace does not work well with non-contiguous tensors ; we should fix that and revert back to the original implementation in https://github.com/ggml-org/llama.cpp/pull/13065
-    static ggml_tensor * build_rope_2d(
-        ggml_context * ctx0,
-        ggml_tensor * cur,
-        ggml_tensor * pos_a, // first half
-        ggml_tensor * pos_b, // second half
-        const float freq_base,
-        const bool interleave_freq
-    ) {
-        const int64_t n_dim = cur->ne[0];
-        const int64_t n_head = cur->ne[1];
-        const int64_t n_pos = cur->ne[2];
-
-        // for example, if we have cur tensor of shape (n_dim=8, n_head, n_pos)
-        // we will have a list of 4 inv_freq: 1e-0, 1e-1, 1e-2, 1e-3
-        // first half of cur will use 1e-0, 1e-2 (even)
-        // second half of cur will use 1e-1, 1e-3 (odd)
-        // the trick here is to rotate just half of n_dim, so inv_freq will automatically be even
-        // ^ don't ask me why, it's math! -2(2i) / n_dim == -2i / (n_dim/2)
-        // then for the second half, we use freq_scale to shift the inv_freq
-        // ^ why? replace (2i) with (2i+1) in the above equation
-        const float freq_scale_odd = interleave_freq
-            ? std::pow(freq_base, (float)-2/n_dim)
-            : 1.0;
-
-        // first half
-        ggml_tensor * first;
-        {
-            first = ggml_view_3d(ctx0, cur,
-                n_dim/2, n_head, n_pos,
-                ggml_row_size(cur->type, n_dim),
-                ggml_row_size(cur->type, n_dim*n_head),
-                0);
-            first = ggml_rope_ext(
-                ctx0,
-                first,
-                pos_a, // positions
-                nullptr, // freq factors
-                n_dim/2, // n_dims
-                0, 0, freq_base,
-                1.0f, 0.0f, 1.0f, 0.0f, 0.0f
-            );
-        }
+    if (wo_b) {
+        cur = ggml_add(ctx0, cur, wo_b);
+    }

-        // second half
-        ggml_tensor * second;
-        {
-            second = ggml_view_3d(ctx0, cur,
-                n_dim/2, n_head, n_pos,
-                ggml_row_size(cur->type, n_dim),
-                ggml_row_size(cur->type, n_dim*n_head),
-                n_dim/2 * ggml_element_size(cur));
-            second = ggml_rope_ext(
-                ctx0,
-                second,
-                pos_b, // positions
-                nullptr, // freq factors
-                n_dim/2, // n_dims
-                0, 0, freq_base,
-                freq_scale_odd,
-                0.0f, 1.0f, 0.0f, 0.0f
-            );
-        }
+    return cur;
+}
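[editor's note] Both branches of build_attn compute the same quantity, scaled dot-product attention; in matrix form, with s = kq_scale and M the optional additive kq_mask:

    \mathrm{Attn}(Q, K, V) = \mathrm{softmax}\!\left(s\,Q K^{\top} + M\right) V

The flash-attention branch fuses this into a single op with F16 K/V, while the fallback materializes the KQ matrix and the softmax explicitly.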

-        cur = ggml_concat(ctx0, first, second, 0);
-        return cur;
+// implementation of the 2D RoPE without adding a new op in ggml
+// this is not efficient (it uses double the memory), but works on all backends
+// TODO: there was a more efficient implementation which relied on ggml_view and ggml_rope_ext_inplace, but rope inplace does not work well with non-contiguous tensors; we should fix that and revert back to the original implementation in https://github.com/ggml-org/llama.cpp/pull/13065
+ggml_tensor * clip_graph::build_rope_2d(
+        ggml_context * ctx0,
+        ggml_tensor * cur,
+        ggml_tensor * pos_a, // first half
+        ggml_tensor * pos_b, // second half
+        const float freq_base,
+        const bool interleave_freq
+) {
+    const int64_t n_dim = cur->ne[0];
+    const int64_t n_head = cur->ne[1];
+    const int64_t n_pos = cur->ne[2];
+
+    // for example, if we have cur tensor of shape (n_dim=8, n_head, n_pos)
+    // we will have a list of 4 inv_freq: 1e-0, 1e-1, 1e-2, 1e-3
+    // first half of cur will use 1e-0, 1e-2 (even)
+    // second half of cur will use 1e-1, 1e-3 (odd)
+    // the trick here is to rotate just half of n_dim, so inv_freq will automatically be even
+    // ^ don't ask me why, it's math! -2(2i) / n_dim == -2i / (n_dim/2)
+    // then for the second half, we use freq_scale to shift the inv_freq
+    // ^ why? replace (2i) with (2i+1) in the above equation
+    const float freq_scale_odd = interleave_freq
+        ? std::pow(freq_base, (float)-2/n_dim)
+        : 1.0;
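[editor's note] The freq_scale_odd trick can be verified directly. Full-width RoPE over n_dim uses inverse frequencies \theta_j = b^{-2j/n_{dim}}; rotating only n_dim/2 dims yields, for index i:

    b^{-2i/(n_{dim}/2)} = b^{-2(2i)/n_{dim}}

which are exactly the even-index frequencies, and scaling them by freq_scale_odd = b^{-2/n_{dim}} gives

    b^{-2(2i)/n_{dim}} \cdot b^{-2/n_{dim}} = b^{-2(2i+1)/n_{dim}}

the odd-index frequencies used on the second half.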
+
+    // first half
+    ggml_tensor * first;
+    {
+        first = ggml_view_3d(ctx0, cur,
+            n_dim/2, n_head, n_pos,
+            ggml_row_size(cur->type, n_dim),
+            ggml_row_size(cur->type, n_dim*n_head),
+            0);
+        first = ggml_rope_ext(
+            ctx0,
+            first,
+            pos_a, // positions
+            nullptr, // freq factors
+            n_dim/2, // n_dims
+            0, 0, freq_base,
+            1.0f, 0.0f, 1.0f, 0.0f, 0.0f
+        );
    }

-    // aka pixel_shuffle / pixel_unshuffle / patch_merger (Kimi-VL)
-    // support dynamic resolution
-    ggml_tensor * build_patch_merge_permute(ggml_tensor * cur, int scale_factor) {
-        GGML_ASSERT(scale_factor > 1);
-
-        const int n_embd = cur->ne[0];
-        int width = img.nx / patch_size;
-        int height = img.ny / patch_size;
-
-        // pad width and height to factor
-        const int64_t pad_width = CLIP_ALIGN(width, scale_factor) - width;
-        const int64_t pad_height = CLIP_ALIGN(height, scale_factor) - height;
-        cur = ggml_reshape_3d(ctx0, cur, n_embd, width, height);
-        if (pad_width || pad_height) {
-            cur = ggml_pad(ctx0, cur, 0, pad_width, pad_height, 0);
-            width += pad_width;
-            height += pad_height;
-        }
+    // second half
+    ggml_tensor * second;
+    {
+        second = ggml_view_3d(ctx0, cur,
+            n_dim/2, n_head, n_pos,
+            ggml_row_size(cur->type, n_dim),
+            ggml_row_size(cur->type, n_dim*n_head),
+            n_dim/2 * ggml_element_size(cur));
+        second = ggml_rope_ext(
+            ctx0,
+            second,
+            pos_b, // positions
+            nullptr, // freq factors
+            n_dim/2, // n_dims
+            0, 0, freq_base,
+            freq_scale_odd,
+            0.0f, 1.0f, 0.0f, 0.0f
+        );
+    }

-        // unshuffle h
-        cur = ggml_reshape_3d(ctx0, cur, n_embd * scale_factor, width / scale_factor, height);
-        cur = ggml_permute(ctx0, cur, 0, 2, 1, 3);
+    cur = ggml_concat(ctx0, first, second, 0);
+    return cur;
+}

-        // unshuffle w
-        cur = ggml_cont_3d(ctx0, cur, n_embd * scale_factor * scale_factor, height / scale_factor, width / scale_factor);
-        cur = ggml_permute(ctx0, cur, 0, 2, 1, 3);
+// aka pixel_shuffle / pixel_unshuffle / patch_merger (Kimi-VL)
+// support dynamic resolution
+ggml_tensor * clip_graph::build_patch_merge_permute(ggml_tensor * cur, int scale_factor) {
+    GGML_ASSERT(scale_factor > 1);

-        cur = ggml_cont_2d(ctx0, cur, cur->ne[0], cur->ne[1] * cur->ne[2]);
-        cb(cur, "pixel_shuffle", -1);
+    const int n_embd = cur->ne[0];
+    int width = img.nx / patch_size;
+    int height = img.ny / patch_size;

-        return cur;
+    // pad width and height to factor
+    const int64_t pad_width = CLIP_ALIGN(width, scale_factor) - width;
+    const int64_t pad_height = CLIP_ALIGN(height, scale_factor) - height;
+    cur = ggml_reshape_3d(ctx0, cur, n_embd, width, height);
+    if (pad_width || pad_height) {
+        cur = ggml_pad(ctx0, cur, 0, pad_width, pad_height, 0);
+        width += pad_width;
+        height += pad_height;
    }

-};
+    // unshuffle h
+    cur = ggml_reshape_3d(ctx0, cur, n_embd * scale_factor, width / scale_factor, height);
+    cur = ggml_permute(ctx0, cur, 0, 2, 1, 3);
+
+    // unshuffle w
+    cur = ggml_cont_3d(ctx0, cur, n_embd * scale_factor * scale_factor, height / scale_factor, width / scale_factor);
+    cur = ggml_permute(ctx0, cur, 0, 2, 1, 3);
+
+    cur = ggml_cont_2d(ctx0, cur, cur->ne[0], cur->ne[1] * cur->ne[2]);
+    cb(cur, "pixel_shuffle", -1);
+
+    return cur;
+}
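[editor's note] A shape trace of the unshuffle above, assuming a 4x4 patch grid and scale_factor = 2 (illustrative only): [n_embd, 16] is first seen as [n_embd, 4, 4]; the h-unshuffle folds pairs along the width into [n_embd*2, 2, 4]; the w-unshuffle folds the remaining axis into [n_embd*4, 2, 2]; and the final ggml_cont_2d flattens this to [n_embd*4, 4], i.e. four tokens each carrying a 2x2 block of patches.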

 static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32_batch & imgs) {
     GGML_ASSERT(imgs.entries.size() == 1 && "n_batch > 1 is not supported");
-    clip_graph graph(ctx, *imgs.entries[0]);

-    ggml_cgraph * res;
+    const clip_image_f32 & img = *imgs.entries[0];
+    std::unique_ptr<clip_graph> builder;

     switch (ctx->proj_type()) {
         case PROJECTOR_TYPE_GEMMA3:
         case PROJECTOR_TYPE_IDEFICS3:
         case PROJECTOR_TYPE_LFM2:
+        case PROJECTOR_TYPE_JANUS_PRO:
             {
-                res = graph.build_siglip();
+                builder = std::make_unique<clip_graph_siglip>(ctx, img);
             } break;
         case PROJECTOR_TYPE_PIXTRAL:
         case PROJECTOR_TYPE_LIGHTONOCR:
             {
-                res = graph.build_pixtral();
+                builder = std::make_unique<clip_graph_pixtral>(ctx, img);
             } break;
         case PROJECTOR_TYPE_QWEN2VL:
         case PROJECTOR_TYPE_QWEN25VL:
             {
-                res = graph.build_qwen2vl();
+                builder = std::make_unique<clip_graph_qwen2vl>(ctx, img);
             } break;
         case PROJECTOR_TYPE_QWEN3VL:
             {
-                res = graph.build_qwen3vl();
+                builder = std::make_unique<clip_graph_qwen3vl>(ctx, img);
             } break;
         case PROJECTOR_TYPE_MINICPMV:
             {
-                res = graph.build_minicpmv();
+                builder = std::make_unique<clip_graph_minicpmv>(ctx, img);
             } break;
         case PROJECTOR_TYPE_INTERNVL:
             {
-                res = graph.build_internvl();
+                builder = std::make_unique<clip_graph_internvl>(ctx, img);
             } break;
         case PROJECTOR_TYPE_LLAMA4:
             {
-                res = graph.build_llama4();
+                builder = std::make_unique<clip_graph_llama4>(ctx, img);
             } break;
         case PROJECTOR_TYPE_ULTRAVOX:
         case PROJECTOR_TYPE_VOXTRAL:
         case PROJECTOR_TYPE_QWEN2A:
             {
-                res = graph.build_whisper_enc();
+                builder = std::make_unique<clip_graph_whisper_enc>(ctx, img);
             } break;
         case PROJECTOR_TYPE_KIMIVL:
             {
-                res = graph.build_kimivl();
-            } break;
-        case PROJECTOR_TYPE_JANUS_PRO:
-            {
-                res = graph.build_siglip();
+                builder = std::make_unique<clip_graph_kimivl>(ctx, img);
             } break;
         case PROJECTOR_TYPE_COGVLM:
             {
-                res = graph.build_cogvlm();
+                builder = std::make_unique<clip_graph_cogvlm>(ctx, img);
             } break;
-        default:
+        case PROJECTOR_TYPE_MLP:
+        case PROJECTOR_TYPE_MLP_NORM:
+        case PROJECTOR_TYPE_LDP:
+        case PROJECTOR_TYPE_LDPV2:
+        case PROJECTOR_TYPE_GLM_EDGE:
             {
-                res = graph.build_llava();
+                builder = std::make_unique<clip_graph_llava>(ctx, img);
             } break;
+        default:
+            GGML_ABORT("missing cgraph builder");
     }
-    return res;
+
+    return builder->build();
 }
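[editor's note] With the dispatch above, supporting a new model means adding a clip_graph subclass and one case label instead of growing a monolithic method switch. A hypothetical builder, sketched under the assumption that clip_graph exposes a virtual build() plus the helpers defined earlier (the class name and body are invented for illustration):

    // hypothetical example, not part of the patch
    struct clip_graph_mymodel : clip_graph {
        clip_graph_mymodel(clip_ctx * ctx, const clip_image_f32 & img) : clip_graph(ctx, img) {}
        ggml_cgraph * build() override {
            ggml_tensor * inp = build_inp();                    // conv2d patch embedding
            ggml_tensor * cur = build_vit(inp, n_patches,
                    NORM_TYPE_NORMAL, hparams.ffn_op,
                    model.position_embeddings, nullptr);        // plain ViT encoder
            ggml_build_forward_expand(gf, cur);
            return gf;
        }
    };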

+//
+// clip_model_loader
+//
+
 struct clip_model_loader {
     ggml_context_ptr ctx_meta;
     gguf_context_ptr ctx_gguf;