@@ -14722,12 +14722,15 @@ static int llama_decode_internal(
             res  = nullptr;
             embd = nullptr;
         } else if (cparams.embeddings) {
-            res = nullptr; // do not extract logits for embedding case
-            embd = gf->nodes[gf->n_nodes - 1];
-            if (strcmp(embd->name, "result_embd_pooled") != 0) {
-                embd = gf->nodes[gf->n_nodes - 2];
+            res  = nullptr; // do not extract logits for embedding case
+            embd = nullptr;
+            for (int i = gf->n_nodes - 1; i >= 0; --i) {
+                if (strcmp(gf->nodes[i]->name, "result_embd_pooled") == 0) {
+                    embd = gf->nodes[i];
+                    break;
+                }
             }
-            GGML_ASSERT(strcmp(embd->name, "result_embd_pooled") == 0 && "missing embeddings tensor");
+            GGML_ASSERT(embd != nullptr && "missing embeddings tensor");
         } else {
             embd = nullptr; // do not extract embeddings when not needed
             GGML_ASSERT(strcmp(res->name, "result_output") == 0 && "missing result_output tensor");
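
The change above replaces a positional lookup (assume the pooled-embedding tensor is the last or second-to-last graph node) with a lookup by name. A minimal sketch of that pattern, using the gf->n_nodes / gf->nodes fields exactly as they appear in the hunk; the helper name find_node_by_name is hypothetical, not part of the ggml API:

#include <string.h>
#include "ggml.h"

// Walk the compute graph from the last node backwards and return the
// first tensor whose name matches, or NULL if no such node exists.
static struct ggml_tensor * find_node_by_name(struct ggml_cgraph * gf, const char * name) {
    for (int i = gf->n_nodes - 1; i >= 0; --i) {
        if (strcmp(gf->nodes[i]->name, name) == 0) {
            return gf->nodes[i];
        }
    }
    return NULL;
}

Scanning from the back keeps the common case cheap: as the removed code shows, "result_embd_pooled" normally sits at or near the end of the graph, so the loop usually terminates after a node or two, while still handling graphs where it lands elsewhere.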