Преглед изворни кода

llama : check all graph nodes when searching for result_embd_pooled (#8956)

Co-authored-by: Stanisław Szymczyk <sszymczy@gmail.com>
fairydreaming пре 1 година
родитељ
комит
33309f661a
1 измењених фајлова са 8 додато и 5 уклоњено
  1. 8 5
      src/llama.cpp

+ 8 - 5
src/llama.cpp

@@ -14722,12 +14722,15 @@ static int llama_decode_internal(
             res  = nullptr;
             embd = nullptr;
         } else if (cparams.embeddings) {
-            res = nullptr; // do not extract logits for embedding case
-            embd = gf->nodes[gf->n_nodes - 1];
-            if (strcmp(embd->name, "result_embd_pooled") != 0) {
-                embd = gf->nodes[gf->n_nodes - 2];
+            res  = nullptr; // do not extract logits for embedding case
+            embd = nullptr;
+            for (int i = gf->n_nodes - 1; i >= 0; --i) {
+                if (strcmp(gf->nodes[i]->name, "result_embd_pooled") == 0) {
+                    embd = gf->nodes[i];
+                    break;
+                }
             }
-            GGML_ASSERT(strcmp(embd->name, "result_embd_pooled") == 0 && "missing embeddings tensor");
+            GGML_ASSERT(embd != nullptr && "missing embeddings tensor");
         } else {
             embd = nullptr; // do not extract embeddings when not needed
             GGML_ASSERT(strcmp(res->name, "result_output") == 0 && "missing result_output tensor");