llava.cpp

#include "clip.h"
#include "llava-utils.h"
#include "common.h"
#include "llama.h"

#include <cstdio>
#include <cstdlib>
#include <cstring> // for strcmp
#include <vector>
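
// minimal LLaVA example: the image is encoded with CLIP, projected into the
// LLaMA embedding space by the multimodal projector (mmproj), and the projected
// embeddings are evaluated in place of tokens before text generation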
static void show_additional_info(int /*argc*/, char ** argv) {
    printf("\n example usage: %s -m <llava-v1.5-7b/ggml-model-q5_k.gguf> --mmproj <llava-v1.5-7b/mmproj-model-f16.gguf> --image <path/to/an/image.jpg> [--temp 0.1] [-p \"describe the image in detail.\"]\n", argv[0]);
    printf("  note: a lower temperature value like 0.1 is recommended for better quality.\n");
}

int main(int argc, char ** argv) {
    ggml_time_init();

    gpt_params params;

    if (!gpt_params_parse(argc, argv, params)) {
        show_additional_info(argc, argv);
        return 1;
    }

    if (params.mmproj.empty() || params.image.empty()) {
        gpt_print_usage(argc, argv, params);
        show_additional_info(argc, argv);
        return 1;
    }

    const char * clip_path = params.mmproj.c_str();
    const char * img_path  = params.image.c_str();

    if (params.prompt.empty()) {
        params.prompt = "describe the image in detail.";
    }

    auto ctx_clip = clip_model_load(clip_path, /*verbosity=*/ 1);
    if (!ctx_clip) {
        fprintf(stderr, "%s: failed to load CLIP model from %s\n", __func__, clip_path);
        return 1;
    }

    // load and preprocess the image
    clip_image_u8  img;
    clip_image_f32 img_res;

    if (!clip_image_load_from_file(img_path, &img)) {
        fprintf(stderr, "%s: is %s really an image file?\n", __func__, img_path);

        clip_free(ctx_clip);
        return 1;
    }

    if (!clip_image_preprocess(ctx_clip, &img, &img_res, /*pad2square =*/ true)) {
        fprintf(stderr, "%s: unable to preprocess %s\n", __func__, img_path);

        clip_free(ctx_clip);
        return 1;
    }
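
    // each image patch becomes one embedding that the LLM consumes like a
    // token; the projector's output dimension must match the LLaMA embedding
    // size (verified after the language model is loaded)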
    int n_img_pos  = clip_n_patches(ctx_clip);
    int n_img_embd = clip_n_mmproj_embd(ctx_clip);

    float * image_embd = (float *) malloc(clip_embd_nbytes(ctx_clip));
    if (!image_embd) {
        fprintf(stderr, "Unable to allocate memory for image embeddings\n");

        clip_free(ctx_clip);
        return 1;
    }
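
    // encode: run the CLIP vision tower and the multimodal projector over the
    // preprocessed image, writing the projected embeddings into image_embd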
    const int64_t t_img_enc_start_us = ggml_time_us();
    if (!clip_image_encode(ctx_clip, params.n_threads, &img_res, image_embd)) {
        fprintf(stderr, "Unable to encode image\n");

        clip_free(ctx_clip);
        free(image_embd);
        return 1;
    }
    const int64_t t_img_enc_end_us = ggml_time_us();

    // we have the embeddings now, so free up the memory used by CLIP
    clip_free(ctx_clip);

    llama_backend_init(params.numa);

    llama_model_params model_params = llama_model_default_params();
    llama_model * model = llama_load_model_from_file(params.model.c_str(), model_params);
    if (model == NULL) {
        fprintf(stderr, "%s: error: unable to load model\n", __func__);

        free(image_embd);
        return 1;
    }

    llama_context_params ctx_params = llama_context_default_params();

    ctx_params.n_ctx           = params.n_ctx < 2048 ? 2048 : params.n_ctx; // we need a longer context size to process image embeddings
    ctx_params.n_threads       = params.n_threads;
    ctx_params.n_threads_batch = params.n_threads_batch == -1 ? params.n_threads : params.n_threads_batch;

    llama_context * ctx_llama = llama_new_context_with_model(model, ctx_params);
    if (ctx_llama == NULL) {
        fprintf(stderr, "%s: error: failed to create the llama_context\n", __func__);

        llama_free_model(model);
        free(image_embd);
        return 1;
    }

    // make sure that the correct mmproj was used, i.e., compare apples to apples
    const int n_llama_embd = llama_n_embd(llama_get_model(ctx_llama));
    if (n_img_embd != n_llama_embd) {
        fprintf(stderr, "%s: embedding dimension of the multimodal projector (%d) does not match that of LLaMA (%d). Make sure that you use the correct mmproj file.\n", __func__, n_img_embd, n_llama_embd);

        llama_free(ctx_llama);
        llama_free_model(model);
        llama_backend_free();
        free(image_embd);
        return 1;
    }

    // process the prompt
    // llava chat format is "<system_prompt>USER: <image_embeddings>\n<textual_prompt>\nASSISTANT:"

    int n_past = 0;

    const int max_tgt_len = params.n_predict < 0 ? 256 : params.n_predict;

    // GG: are we sure that there should be a trailing whitespace at the end of this string?
    eval_string(ctx_llama, "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.\nUSER: ", params.n_batch, &n_past);
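    // eval_image_embd injects the projected CLIP embeddings directly into the
    // model's context (bypassing tokenization), so the image occupies n_img_pos
    // positions and n_past advances accordingly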
    eval_image_embd(ctx_llama, image_embd, n_img_pos, params.n_batch, &n_past);
    eval_string(ctx_llama, params.prompt.c_str(), params.n_batch, &n_past);
    eval_string(ctx_llama, "\nASSISTANT:", params.n_batch, &n_past);

    // generate the response
    printf("\n");
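    // sample() (from llava-utils.h) picks the next token per the sampling
    // parameters, evaluates it, and returns its text; stop at the EOS string
    // "</s>" or after max_tgt_len tokens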
    for (int i = 0; i < max_tgt_len; i++) {
        const char * tmp = sample(ctx_llama, params, &n_past);
        if (strcmp(tmp, "</s>") == 0) break;

        printf("%s", tmp);
        fflush(stdout);
    }
    printf("\n");

    {
        const float t_img_enc_ms = (t_img_enc_end_us - t_img_enc_start_us) / 1000.0;

        printf("\n%s: image encoded in %8.2f ms by CLIP (%8.2f ms per image patch)\n", __func__, t_img_enc_ms, t_img_enc_ms / n_img_pos);
    }

    llama_print_timings(ctx_llama);

    llama_free(ctx_llama);
    llama_free_model(model);
    llama_backend_free();
    free(image_embd);

    return 0;
}