瀏覽代碼

llama : check all graph nodes when searching for result_embd_pooled (#8956)

Co-authored-by: Stanisław Szymczyk <sszymczy@gmail.com>
fairydreaming 1 年之前
父節點
當前提交
33309f661a
共有 1 個文件被更改,包括 8 次插入5 次删除
  1. 8 5
      src/llama.cpp

+ 8 - 5
src/llama.cpp

@@ -14722,12 +14722,15 @@ static int llama_decode_internal(
             res  = nullptr;
             res  = nullptr;
             embd = nullptr;
             embd = nullptr;
         } else if (cparams.embeddings) {
         } else if (cparams.embeddings) {
-            res = nullptr; // do not extract logits for embedding case
-            embd = gf->nodes[gf->n_nodes - 1];
-            if (strcmp(embd->name, "result_embd_pooled") != 0) {
-                embd = gf->nodes[gf->n_nodes - 2];
+            res  = nullptr; // do not extract logits for embedding case
+            embd = nullptr;
+            for (int i = gf->n_nodes - 1; i >= 0; --i) {
+                if (strcmp(gf->nodes[i]->name, "result_embd_pooled") == 0) {
+                    embd = gf->nodes[i];
+                    break;
+                }
             }
             }
-            GGML_ASSERT(strcmp(embd->name, "result_embd_pooled") == 0 && "missing embeddings tensor");
+            GGML_ASSERT(embd != nullptr && "missing embeddings tensor");
         } else {
         } else {
             embd = nullptr; // do not extract embeddings when not needed
             embd = nullptr; // do not extract embeddings when not needed
             GGML_ASSERT(strcmp(res->name, "result_output") == 0 && "missing result_output tensor");
             GGML_ASSERT(strcmp(res->name, "result_output") == 0 && "missing result_output tensor");