|
@@ -1175,10 +1175,11 @@ struct clip_graph {
|
|
|
cb(K, "resampler_K", -1);
|
|
cb(K, "resampler_K", -1);
|
|
|
cb(V, "resampler_V", -1);
|
|
cb(V, "resampler_V", -1);
|
|
|
|
|
|
|
|
|
|
+ float resampler_kq_scale = 1.0f/ sqrtf(float(d_head));
|
|
|
embeddings = build_attn(
|
|
embeddings = build_attn(
|
|
|
model.mm_model_attn_o_w,
|
|
model.mm_model_attn_o_w,
|
|
|
model.mm_model_attn_o_b,
|
|
model.mm_model_attn_o_b,
|
|
|
- Q, K, V, nullptr, kq_scale, -1);
|
|
|
|
|
|
|
+ Q, K, V, nullptr, resampler_kq_scale, -1);
|
|
|
cb(embeddings, "resampler_attn_out", -1);
|
|
cb(embeddings, "resampler_attn_out", -1);
|
|
|
}
|
|
}
|
|
|
// layernorm
|
|
// layernorm
|