Ver código fonte

examples : add llama_init_from_gpt_params() common function (#1290)

Signed-off-by: deadprogram <ron@hybridgroup.com>
Ron Evans 2 anos atrás
pai
commit
67c77799e0

+ 31 - 0
examples/common.cpp

@@ -405,6 +405,37 @@ std::vector<llama_token> llama_tokenize(struct llama_context * ctx, const std::s
     return res;
 }
 
+struct llama_context * llama_init_from_gpt_params(const gpt_params & params) {
+    auto lparams = llama_context_default_params();
+
+    lparams.n_ctx      = params.n_ctx;
+    lparams.n_parts    = params.n_parts;
+    lparams.seed       = params.seed;
+    lparams.f16_kv     = params.memory_f16;
+    lparams.use_mmap   = params.use_mmap;
+    lparams.use_mlock  = params.use_mlock;
+
+    llama_context * lctx = llama_init_from_file(params.model.c_str(), lparams);
+
+    if (lctx == NULL) {
+        fprintf(stderr, "%s: error: failed to load model '%s'\n", __func__, params.model.c_str());
+        return NULL;
+    }
+
+    if (!params.lora_adapter.empty()) {
+        int err = llama_apply_lora_from_file(lctx,
+                                             params.lora_adapter.c_str(),
+                                             params.lora_base.empty() ? NULL : params.lora_base.c_str(),
+                                             params.n_threads);
+        if (err != 0) {
+            fprintf(stderr, "%s: error: failed to apply lora adapter\n", __func__);
+            return NULL;
+        }
+    }
+
+    return lctx;
+}
+
 /* Keep track of current color of output, and emit ANSI code if it changes. */
 void set_console_color(console_state & con_st, console_color_t color) {
     if (con_st.use_color && con_st.color != color) {

+ 6 - 0
examples/common.h

@@ -77,6 +77,12 @@ std::string gpt_random_prompt(std::mt19937 & rng);
 
 std::vector<llama_token> llama_tokenize(struct llama_context * ctx, const std::string & text, bool add_bos);
 
+//
+// Model utils
+//
+
+struct llama_context * llama_init_from_gpt_params(const gpt_params & params);
+
 //
 // Console utils
 //

+ 4 - 18
examples/embedding/embedding.cpp

@@ -35,24 +35,10 @@ int main(int argc, char ** argv) {
     llama_context * ctx;
 
     // load the model
-    {
-        auto lparams = llama_context_default_params();
-
-        lparams.n_ctx      = params.n_ctx;
-        lparams.n_parts    = params.n_parts;
-        lparams.seed       = params.seed;
-        lparams.f16_kv     = params.memory_f16;
-        lparams.logits_all = params.perplexity;
-        lparams.use_mmap   = params.use_mmap;
-        lparams.use_mlock  = params.use_mlock;
-        lparams.embedding  = params.embedding;
-
-        ctx = llama_init_from_file(params.model.c_str(), lparams);
-
-        if (ctx == NULL) {
-            fprintf(stderr, "%s: error: failed to load model '%s'\n", __func__, params.model.c_str());
-            return 1;
-        }
+    ctx = llama_init_from_gpt_params(params);
+    if (ctx == NULL) {
+        fprintf(stderr, "%s: error: unable to load model\n", __func__);
+        return 1;
     }
 
     // print system information

+ 5 - 28
examples/main/main.cpp

@@ -101,34 +101,11 @@ int main(int argc, char ** argv) {
     llama_context * ctx;
     g_ctx = &ctx;
 
-    // load the model
-    {
-        auto lparams = llama_context_default_params();
-
-        lparams.n_ctx      = params.n_ctx;
-        lparams.n_parts    = params.n_parts;
-        lparams.seed       = params.seed;
-        lparams.f16_kv     = params.memory_f16;
-        lparams.use_mmap   = params.use_mmap;
-        lparams.use_mlock  = params.use_mlock;
-
-        ctx = llama_init_from_file(params.model.c_str(), lparams);
-
-        if (ctx == NULL) {
-            fprintf(stderr, "%s: error: failed to load model '%s'\n", __func__, params.model.c_str());
-            return 1;
-        }
-    }
-
-    if (!params.lora_adapter.empty()) {
-        int err = llama_apply_lora_from_file(ctx,
-                                             params.lora_adapter.c_str(),
-                                             params.lora_base.empty() ? NULL : params.lora_base.c_str(),
-                                             params.n_threads);
-        if (err != 0) {
-            fprintf(stderr, "%s: error: failed to apply lora adapter\n", __func__);
-            return 1;
-        }
+    // load the model and apply lora adapter, if any
+    ctx = llama_init_from_gpt_params(params);
+    if (ctx == NULL) {
+        fprintf(stderr, "%s: error: unable to load model\n", __func__);
+        return 1;
     }
 
     // print system information

+ 5 - 30
examples/perplexity/perplexity.cpp

@@ -122,36 +122,11 @@ int main(int argc, char ** argv) {
 
     llama_context * ctx;
 
-    // load the model
-    {
-        auto lparams = llama_context_default_params();
-
-        lparams.n_ctx      = params.n_ctx;
-        lparams.n_parts    = params.n_parts;
-        lparams.seed       = params.seed;
-        lparams.f16_kv     = params.memory_f16;
-        lparams.logits_all = params.perplexity;
-        lparams.use_mmap   = params.use_mmap;
-        lparams.use_mlock  = params.use_mlock;
-        lparams.embedding  = params.embedding;
-
-        ctx = llama_init_from_file(params.model.c_str(), lparams);
-
-        if (ctx == NULL) {
-            fprintf(stderr, "%s: error: failed to load model '%s'\n", __func__, params.model.c_str());
-            return 1;
-        }
-    }
-
-    if (!params.lora_adapter.empty()) {
-        int err = llama_apply_lora_from_file(ctx,
-                                             params.lora_adapter.c_str(),
-                                             params.lora_base.empty() ? NULL : params.lora_base.c_str(),
-                                             params.n_threads);
-        if (err != 0) {
-            fprintf(stderr, "%s: error: failed to apply lora adapter\n", __func__);
-            return 1;
-        }
+    // load the model and apply lora adapter, if any
+    ctx = llama_init_from_gpt_params(params);
+    if (ctx == NULL) {
+        fprintf(stderr, "%s: error: unable to load model\n", __func__);
+        return 1;
     }
 
     // print system information