@@ -215,6 +215,10 @@ struct clip_layer {
     // layernorm 2
     ggml_tensor * ln_2_w = nullptr;
     ggml_tensor * ln_2_b = nullptr;
+
+    // layer scale (no bias)
+    ggml_tensor * ls_1_w = nullptr;
+    ggml_tensor * ls_2_w = nullptr;
 };

 struct clip_vision_model {
@@ -589,6 +593,9 @@ struct clip_graph {
     // Qwen2VL and Qwen2.5VL use M-RoPE
     ggml_cgraph * build_qwen2vl() {
+        GGML_ASSERT(model.patch_bias == nullptr);
+        GGML_ASSERT(model.class_embedding == nullptr);
+
         const int batch_size = 1;
         const bool use_window_attn = hparams.n_wa_pattern > 0;
         const int n_wa_pattern = hparams.n_wa_pattern;
@@ -625,10 +632,6 @@ struct clip_graph {
                 n_embd, n_patches_x * n_patches_y, batch_size);
         }

-        if (model.patch_bias) {
-            inp = ggml_add(ctx0, inp, model.patch_bias);
-        }
-
         ggml_tensor * inpL = inp;
         ggml_tensor * window_mask = nullptr;
         ggml_tensor * window_idx = nullptr;
@@ -859,6 +862,67 @@ struct clip_graph {
         return gf;
     }

+    ggml_cgraph * build_internvl() {
+        GGML_ASSERT(model.class_embedding != nullptr);
+        GGML_ASSERT(model.position_embeddings != nullptr);
+
+        const int n_pos = n_patches + 1;
+        ggml_tensor * inp = build_inp();
+
+        // add CLS token
+        inp = ggml_concat(ctx0, inp, model.class_embedding, 1);
+
+        ggml_tensor * cur = build_vit(
+                                inp, n_pos,
+                                NORM_TYPE_NORMAL,
+                                hparams.ffn_op,
+                                model.position_embeddings,
+                                nullptr);
+
+        // remove CLS token
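+        // (CLS was concatenated as the last row, so a view of the first n_patches rows drops it)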
+        cur = ggml_view_2d(ctx0, cur,
+            n_embd, n_patches,
+            ggml_row_size(cur->type, n_embd), 0);
+
+        // pixel shuffle
+        {
+            const int scale_factor = model.hparams.proj_scale_factor;
+            const int bsz = 1; // batch size, always 1 for now since we don't support batching
+            const int height = n_patches_y;
+            const int width = n_patches_x;
+            GGML_ASSERT(scale_factor > 0);
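+            // fold each scale_factor x scale_factor block of patches into the channel dimension:
+            // (n_embd, W*H) -> (n_embd * scale_factor^2, W*H / scale_factor^2)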
+            cur = ggml_reshape_4d(ctx0, cur, n_embd * scale_factor, height / scale_factor, width, bsz);
+            cur = ggml_permute(ctx0, cur, 0, 2, 1, 3);
+            cur = ggml_reshape_4d(ctx0, ggml_cont(ctx0, cur),
+                n_embd * scale_factor * scale_factor,
+                height / scale_factor,
+                width / scale_factor,
+                bsz);
+            cur = ggml_permute(ctx0, cur, 0, 2, 1, 3);
+            // flatten to 2D
+            cur = ggml_reshape_2d(ctx0, ggml_cont(ctx0, cur),
+                n_embd * scale_factor * scale_factor,
+                cur->ne[1] * cur->ne[2]);
+        }
+
+        // projector (always using GELU activation)
+        {
+            // projector LayerNorm uses pytorch's default eps = 1e-5
+            // ref: https://huggingface.co/OpenGVLab/InternVL3-8B-Instruct/blob/a34d3e4e129a5856abfd6aa6de79776484caa14e/modeling_internvl_chat.py#L79
+            cur = build_norm(cur, model.mm_0_w, model.mm_0_b, NORM_TYPE_NORMAL, 1e-5, -1);
+            cur = ggml_mul_mat(ctx0, model.mm_1_w, cur);
+            cur = ggml_add(ctx0, cur, model.mm_1_b);
+            cur = ggml_gelu(ctx0, cur);
+            cur = ggml_mul_mat(ctx0, model.mm_3_w, cur);
+            cur = ggml_add(ctx0, cur, model.mm_3_b);
+        }
+
+        // build the graph
+        ggml_build_forward_expand(gf, cur);
+
+        return gf;
+    }
+
     // this graph is used by llava, granite and glm
     // due to having embedding_stack (used by granite), we cannot reuse build_vit
     ggml_cgraph * build_llava() {
@@ -890,10 +954,6 @@ struct clip_graph {
         ggml_tensor * inp = build_inp();

-        if (model.patch_bias) {
-            inp = ggml_add(ctx0, inp, model.patch_bias);
-        }
-
         // concat class_embeddings and patch_embeddings
         if (model.class_embedding) {
             inp = ggml_concat(ctx0, inp, model.class_embedding, 1);
@@ -1260,11 +1320,6 @@ private:
             ggml_tensor * learned_pos_embd,
             std::function<ggml_tensor *(ggml_tensor *, const clip_layer &)> add_pos
         ) {
-        if (model.patch_bias) {
-            inp = ggml_add(ctx0, inp, model.patch_bias);
-            cb(inp, "patch_bias", -1);
-        }
-
         if (learned_pos_embd) {
             inp = ggml_add(ctx0, inp, learned_pos_embd);
             cb(inp, "pos_embed", -1);
@@ -1324,6 +1379,11 @@ private:
                 cb(cur, "attn_out", il);
             }

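+            // layer scale: learned per-channel scaling of the attention output, when the model provides it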
+            if (layer.ls_1_w) {
+                cur = ggml_mul(ctx0, cur, layer.ls_1_w);
+                cb(cur, "attn_out_scaled", il);
+            }
+
             // re-add the layer input, e.g., residual
             cur = ggml_add(ctx0, cur, inpL);
@@ -1344,6 +1404,11 @@ private:
             cb(cur, "ffn_out", il);

+            if (layer.ls_2_w) {
+                cur = ggml_mul(ctx0, cur, layer.ls_2_w);
+                cb(cur, "ffn_out_scaled", il);
+            }
+
             // residual 2
             cur = ggml_add(ctx0, inpL, cur);
             cb(cur, "layer_out", il);
@@ -1365,6 +1430,10 @@ private:
         ggml_tensor * inp = ggml_conv_2d(ctx0, model.patch_embeddings_0, inp_raw, patch_size, patch_size, 0, 0, 1, 1);
         inp = ggml_reshape_2d(ctx0, inp, n_patches, n_embd);
         inp = ggml_cont(ctx0, ggml_transpose(ctx0, inp));
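+        // apply patch bias here in build_inp(), so the individual graph builders do not add it themselves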
+        if (model.patch_bias) {
+            inp = ggml_add(ctx0, inp, model.patch_bias);
+            cb(inp, "patch_bias", -1);
+        }
         return inp;
     }
@@ -1627,6 +1696,10 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
             {
                 res = graph.build_minicpmv();
             } break;
+        case PROJECTOR_TYPE_INTERNVL:
+            {
+                res = graph.build_internvl();
+            } break;
         default:
             {
                 res = graph.build_llava();
@@ -1790,6 +1863,7 @@ struct clip_model_loader {
                     }
                 } break;
             case PROJECTOR_TYPE_IDEFICS3:
+            case PROJECTOR_TYPE_INTERNVL:
                 {
                     get_u32(KEY_PROJ_SCALE_FACTOR, hparams.proj_scale_factor, false);
                 } break;
@@ -1897,6 +1971,9 @@ struct clip_model_loader {
             layer.o_w = get_tensor(string_format(TN_ATTN_OUTPUT, "v", il, "weight"));
             layer.ln_1_w = get_tensor(string_format(TN_LN_1, "v", il, "weight"), false);
             layer.ln_2_w = get_tensor(string_format(TN_LN_2, "v", il, "weight"), false);
+            layer.ls_1_w = get_tensor(string_format(TN_LS_1, "v", il, "weight"), false); // no bias
+            layer.ls_2_w = get_tensor(string_format(TN_LS_2, "v", il, "weight"), false); // no bias
+
             layer.k_b = get_tensor(string_format(TN_ATTN_K, "v", il, "bias"), false);
             layer.q_b = get_tensor(string_format(TN_ATTN_Q, "v", il, "bias"), false);
             layer.v_b = get_tensor(string_format(TN_ATTN_V, "v", il, "bias"), false);
@@ -1904,7 +1981,7 @@ struct clip_model_loader {
             layer.ln_1_b = get_tensor(string_format(TN_LN_1, "v", il, "bias"), false);
             layer.ln_2_b = get_tensor(string_format(TN_LN_2, "v", il, "bias"), false);

-            // new naming
+            // ffn
             layer.ff_up_w = get_tensor(string_format(TN_FFN_UP, "v", il, "weight"));
             layer.ff_up_b = get_tensor(string_format(TN_FFN_UP, "v", il, "bias"), false);
             layer.ff_gate_w = get_tensor(string_format(TN_FFN_GATE, "v", il, "weight"), false);
@@ -2052,6 +2129,15 @@ struct clip_model_loader {
                     vision_model.mm_input_norm_w = get_tensor(TN_MM_INP_NORM, false);
                     vision_model.mm_patch_merger_w = get_tensor(TN_MM_PATCH_MERGER, false);
                 } break;
+            case PROJECTOR_TYPE_INTERNVL:
+                {
+                    vision_model.mm_0_w = get_tensor(string_format(TN_MVLM_PROJ_MLP, 0, "weight"));
+                    vision_model.mm_0_b = get_tensor(string_format(TN_MVLM_PROJ_MLP, 0, "bias"));
+                    vision_model.mm_1_w = get_tensor(string_format(TN_MVLM_PROJ_MLP, 1, "weight"));
+                    vision_model.mm_1_b = get_tensor(string_format(TN_MVLM_PROJ_MLP, 1, "bias"));
+                    vision_model.mm_3_w = get_tensor(string_format(TN_MVLM_PROJ_MLP, 3, "weight"));
+                    vision_model.mm_3_b = get_tensor(string_format(TN_MVLM_PROJ_MLP, 3, "bias"));
+                } break;
             default:
                 GGML_ASSERT(false && "unknown projector type");
         }
@@ -2838,7 +2924,9 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, str
         }
     else if (ctx->proj_type == PROJECTOR_TYPE_GLM_EDGE
             || ctx->proj_type == PROJECTOR_TYPE_GEMMA3
-            || ctx->proj_type == PROJECTOR_TYPE_IDEFICS3) {
+            || ctx->proj_type == PROJECTOR_TYPE_IDEFICS3
+            || ctx->proj_type == PROJECTOR_TYPE_INTERNVL // TODO @ngxson : support dynamic resolution
+            ) {
         clip_image_u8 resized_image;
         int sz = params.image_size;
         image_manipulation::resize_and_pad_image(*img, resized_image, {sz, sz});
@@ -2988,9 +3076,13 @@ int clip_n_output_tokens(const struct clip_ctx * ctx, struct clip_image_f32 * im

     int n_patches = (params.image_size / params.patch_size) * (params.image_size / params.patch_size);

-    if (ctx->proj_type == PROJECTOR_TYPE_LDP || ctx->proj_type == PROJECTOR_TYPE_LDPV2 || ctx->proj_type == PROJECTOR_TYPE_GLM_EDGE) {
+    if (ctx->proj_type == PROJECTOR_TYPE_LDP
+            || ctx->proj_type == PROJECTOR_TYPE_LDPV2
+            || ctx->proj_type == PROJECTOR_TYPE_GLM_EDGE) {
         n_patches /= 4;
-        n_patches += 2; // for BOI and EOI token embeddings
+        if (ctx->vision_model.mm_glm_tok_boi) {
+            n_patches += 2; // for BOI and EOI token embeddings
+        }
     } else if (ctx->proj_type == PROJECTOR_TYPE_MINICPMV) {
         if (ctx->minicpmv_version == 2) {
             n_patches = 96;
@@ -3013,7 +3105,8 @@ int clip_n_output_tokens(const struct clip_ctx * ctx, struct clip_image_f32 * im
         int n_per_side = params.image_size / params.patch_size;
         int n_per_side_2d_pool = n_per_side / params.proj_scale_factor;
         n_patches = n_per_side_2d_pool * n_per_side_2d_pool;
-    } else if (ctx->proj_type == PROJECTOR_TYPE_IDEFICS3) {
+    } else if (ctx->proj_type == PROJECTOR_TYPE_IDEFICS3 || ctx->proj_type == PROJECTOR_TYPE_INTERNVL) {
+        // both W and H are divided by proj_scale_factor
         n_patches /= (params.proj_scale_factor * params.proj_scale_factor);
     } else if (ctx->proj_type == PROJECTOR_TYPE_PIXTRAL) {
         int n_merge = params.spatial_merge_size;
@@ -3408,6 +3501,7 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
             } break;
         case PROJECTOR_TYPE_GEMMA3:
         case PROJECTOR_TYPE_IDEFICS3:
+        case PROJECTOR_TYPE_INTERNVL:
             {
                 // do nothing
             } break;
@@ -3434,6 +3528,14 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
     // the last node is the embedding tensor
     ggml_tensor * embeddings = ggml_graph_node(gf, -1);

+    // sanity check (only support batch size of 1 for now)
+    const int n_tokens_out = embeddings->ne[1];
+    const int expected_n_tokens_out = clip_n_output_tokens(ctx, imgs.entries[0].get());
+    if (n_tokens_out != expected_n_tokens_out) {
+        LOG_ERR("%s: expected %d tokens, got %d\n", __func__, expected_n_tokens_out, n_tokens_out);
+        GGML_ABORT("Invalid number of output tokens");
+    }
+
     // copy the embeddings to the location passed by the user
     ggml_backend_tensor_get(embeddings, vec, 0, ggml_nbytes(embeddings));
@@ -3604,6 +3706,8 @@ int clip_n_mmproj_embd(const struct clip_ctx * ctx) {
             return ctx->vision_model.mm_input_proj_w->ne[0];
         case PROJECTOR_TYPE_IDEFICS3:
             return ctx->vision_model.projection->ne[1];
+        case PROJECTOR_TYPE_INTERNVL:
+            return ctx->vision_model.mm_3_w->ne[1];
         default:
             GGML_ABORT("Unknown projector type");
     }