|
|
@@ -1562,20 +1562,25 @@ void llm_graph_context::build_pooling(
|
|
|
ggml_tensor * inp_cls = build_inp_cls();
|
|
|
inp = ggml_get_rows(ctx0, inp, inp_cls);
|
|
|
|
|
|
- // classification head
|
|
|
- // https://github.com/huggingface/transformers/blob/5af7d41e49bbfc8319f462eb45253dcb3863dfb7/src/transformers/models/roberta/modeling_roberta.py#L1566
|
|
|
- GGML_ASSERT(cls != nullptr);
|
|
|
- GGML_ASSERT(cls_b != nullptr);
|
|
|
-
|
|
|
- cur = ggml_add (ctx0, ggml_mul_mat(ctx0, cls, inp), cls_b);
|
|
|
- cur = ggml_tanh(ctx0, cur);
|
|
|
-
|
|
|
- // some models don't have `cls_out`, for example: https://huggingface.co/jinaai/jina-reranker-v1-tiny-en
|
|
|
- // https://huggingface.co/jinaai/jina-reranker-v1-tiny-en/blob/cb5347e43979c3084a890e3f99491952603ae1b7/modeling_bert.py#L884-L896
|
|
|
- if (cls_out) {
|
|
|
+ if (cls != nullptr && cls_b != nullptr) {
|
|
|
+ // classification head
|
|
|
+ // https://github.com/huggingface/transformers/blob/5af7d41e49bbfc8319f462eb45253dcb3863dfb7/src/transformers/models/roberta/modeling_roberta.py#L1566
|
|
|
+ cur = ggml_add(ctx0, ggml_mul_mat(ctx0, cls, inp), cls_b);
|
|
|
+ cur = ggml_tanh(ctx0, cur);
|
|
|
+
|
|
|
+ // some models don't have `cls_out`, for example: https://huggingface.co/jinaai/jina-reranker-v1-tiny-en
|
|
|
+ // https://huggingface.co/jinaai/jina-reranker-v1-tiny-en/blob/cb5347e43979c3084a890e3f99491952603ae1b7/modeling_bert.py#L884-L896
|
|
|
+ if (cls_out) {
|
|
|
+ GGML_ASSERT(cls_out_b != nullptr);
|
|
|
+ cur = ggml_add(ctx0, ggml_mul_mat(ctx0, cls_out, cur), cls_out_b);
|
|
|
+ }
|
|
|
+ } else if (cls_out) {
|
|
|
+ // Single layer classification head (direct projection)
|
|
|
+ // https://github.com/huggingface/transformers/blob/f4fc42216cd56ab6b68270bf80d811614d8d59e4/src/transformers/models/bert/modeling_bert.py#L1476
|
|
|
GGML_ASSERT(cls_out_b != nullptr);
|
|
|
-
|
|
|
- cur = ggml_add (ctx0, ggml_mul_mat(ctx0, cls_out, cur), cls_out_b);
|
|
|
+ cur = ggml_add(ctx0, ggml_mul_mat(ctx0, cls_out, inp), cls_out_b);
|
|
|
+ } else {
|
|
|
+ GGML_ABORT("RANK pooling requires either cls+cls_b or cls_out+cls_out_b");
|
|
|
}
|
|
|
} break;
|
|
|
default:
|