Sfoglia il codice sorgente

examples : support encoder-decoder models in the simple example (#16002)

Signed-off-by: Jie Fu <jiefu@tencent.com>
Jie Fu (傅杰) 4 mesi fa
parent
commit
1cbd80f8cf
1 ha cambiato i file con 14 aggiunte e 0 eliminazioni
  1. 14 0
      examples/simple/simple.cpp

+ 14 - 0
examples/simple/simple.cpp

@@ -145,6 +145,20 @@ int main(int argc, char ** argv) {
 
     llama_batch batch = llama_batch_get_one(prompt_tokens.data(), prompt_tokens.size());
 
+    if (llama_model_has_encoder(model)) {
+        if (llama_encode(ctx, batch)) {
+            fprintf(stderr, "%s : failed to eval\n", __func__);
+            return 1;
+        }
+
+        llama_token decoder_start_token_id = llama_model_decoder_start_token(model);
+        if (decoder_start_token_id == LLAMA_TOKEN_NULL) {
+            decoder_start_token_id = llama_vocab_bos(vocab);
+        }
+
+        batch = llama_batch_get_one(&decoder_start_token_id, 1);
+    }
+
     // main loop
 
     const auto t_main_start = ggml_time_us();