llava.cpp 5.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163
#include "clip.h"
#include "common.h"
#include "llama.h"
#include "llava.h"

#include <cerrno>
#include <climits>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <vector>

#include "base64.hpp"
  9. static bool encode_image_with_clip(clip_ctx * ctx_clip, int n_threads, const clip_image_u8 * img, float * image_embd, int * n_img_pos) {
  10. clip_image_f32 * img_res = make_clip_image_f32();
  11. if (!clip_image_preprocess(ctx_clip, img, img_res, /*pad2square =*/ true)) {
  12. fprintf(stderr, "%s: unable to preprocess image\n", __func__);
  13. clip_image_f32_free(img_res);
  14. return false;
  15. }
  16. *n_img_pos = clip_n_patches(ctx_clip);
  17. const int64_t t_img_enc_start_us = ggml_time_us();
  18. bool encoded = clip_image_encode(ctx_clip, n_threads, img_res, image_embd);
  19. clip_image_f32_free(img_res);
  20. if (!encoded) {
  21. fprintf(stderr, "Unable to encode image\n");
  22. return false;
  23. }
  24. const int64_t t_img_enc_end_us = ggml_time_us();
  25. float t_img_enc_ms = (t_img_enc_end_us - t_img_enc_start_us) / 1000.0;
  26. printf("\n%s: image encoded in %8.2f ms by CLIP (%8.2f ms per image patch)\n", __func__, t_img_enc_ms, t_img_enc_ms / *n_img_pos);
  27. return true;
  28. }
  29. bool llava_validate_embed_size(const llama_context * ctx_llama, const clip_ctx * ctx_clip) {
  30. // make sure that the correct mmproj was used, i.e., compare apples to apples
  31. int n_llama_embd = llama_n_embd(llama_get_model(ctx_llama));
  32. auto n_image_embd = clip_n_mmproj_embd(ctx_clip);
  33. if (n_image_embd != n_llama_embd) {
  34. printf("%s: embedding dim of the multimodal projector (%d) is not equal to that of LLaMA (%d). Make sure that you use the correct mmproj file.\n", __func__, n_image_embd, n_llama_embd);
  35. return false;
  36. }
  37. return true;
  38. }
  39. static bool llava_image_embed_make_with_clip_img(clip_ctx * ctx_clip, int n_threads, const clip_image_u8 * img, float ** image_embd_out, int * n_img_pos_out) {
  40. float * image_embd = (float *)malloc(clip_embd_nbytes(ctx_clip));
  41. if (!image_embd) {
  42. fprintf(stderr, "Unable to allocate memory for image embeddings\n");
  43. free(image_embd);
  44. return false;
  45. }
  46. int n_img_pos;
  47. if (!encode_image_with_clip(ctx_clip, n_threads, img, image_embd, &n_img_pos)) {
  48. fprintf(stderr, "%s: cannot encode image, aborting\n", __func__);
  49. free(image_embd);
  50. return false;
  51. }
  52. *image_embd_out = image_embd;
  53. *n_img_pos_out = n_img_pos;
  54. return true;
  55. }
  56. bool llava_eval_image_embed(llama_context * ctx_llama, const struct llava_image_embed * image_embed, int n_batch, int * n_past) {
  57. int n_embd = llama_n_embd(llama_get_model(ctx_llama));
  58. for (int i = 0; i < image_embed->n_image_pos; i += n_batch) {
  59. int n_eval = image_embed->n_image_pos - i;
  60. if (n_eval > n_batch) {
  61. n_eval = n_batch;
  62. }
  63. llama_batch batch = {int32_t(n_eval), nullptr, (image_embed->embed+i*n_embd), nullptr, nullptr, nullptr, nullptr, *n_past, 1, 0, };
  64. if (llama_decode(ctx_llama, batch)) {
  65. fprintf(stderr, "%s : failed to eval\n", __func__);
  66. return false;
  67. }
  68. *n_past += n_eval;
  69. }
  70. return true;
  71. }
  72. LLAVA_API struct llava_image_embed * llava_image_embed_make_with_bytes(struct clip_ctx * ctx_clip, int n_threads, const unsigned char * image_bytes, int image_bytes_length) {
  73. clip_image_u8 * img = make_clip_image_u8();
  74. if (!clip_image_load_from_bytes(image_bytes, image_bytes_length, img)) {
  75. clip_image_u8_free(img);
  76. fprintf(stderr, "%s: can't load image from bytes, is it a valid image?", __func__);
  77. return NULL;
  78. }
  79. float* image_embed = NULL;
  80. int n_image_pos = 0;
  81. bool image_embed_result = llava_image_embed_make_with_clip_img(ctx_clip, n_threads, img, &image_embed, &n_image_pos);
  82. if (!image_embed_result) {
  83. clip_image_u8_free(img);
  84. fprintf(stderr, "%s: coulnd't embed the image\n", __func__);
  85. return NULL;
  86. }
  87. clip_image_u8_free(img);
  88. auto result = (llava_image_embed*)malloc(sizeof(llava_image_embed));
  89. result->embed = image_embed;
  90. result->n_image_pos = n_image_pos;
  91. return result;
  92. }
  93. static bool load_file_to_bytes(const char* path, unsigned char** bytesOut, long *sizeOut) {
  94. auto file = fopen(path, "rb");
  95. if (file == NULL) {
  96. fprintf(stderr, "%s: can't read file %s\n", __func__, path);
  97. return false;
  98. }
  99. fseek(file, 0, SEEK_END);
  100. auto fileSize = ftell(file);
  101. fseek(file, 0, SEEK_SET);
  102. auto buffer = (unsigned char *)malloc(fileSize); // Allocate memory to hold the file data
  103. if (buffer == NULL) {
  104. fprintf(stderr, "%s: failed to alloc %ld bytes for file %s\n", __func__, fileSize, path);
  105. perror("Memory allocation error");
  106. fclose(file);
  107. return false;
  108. }
  109. errno = 0;
  110. size_t ret = fread(buffer, 1, fileSize, file); // Read the file into the buffer
  111. if (ferror(file)) {
  112. die_fmt("read error: %s", strerror(errno));
  113. }
  114. if (ret != (size_t) fileSize) {
  115. die("unexpectedly reached end of file");
  116. }
  117. fclose(file); // Close the file
  118. *bytesOut = buffer;
  119. *sizeOut = fileSize;
  120. return true;
  121. }
  122. LLAVA_API struct llava_image_embed * llava_image_embed_make_with_filename(struct clip_ctx * ctx_clip, int n_threads, const char * image_path) {
  123. unsigned char* image_bytes;
  124. long image_bytes_length;
  125. auto loaded = load_file_to_bytes(image_path, &image_bytes, &image_bytes_length);
  126. if (!loaded) {
  127. fprintf(stderr, "%s: failed to load %s\n", __func__, image_path);
  128. return NULL;
  129. }
  130. auto embed = llava_image_embed_make_with_bytes(ctx_clip, n_threads, image_bytes, image_bytes_length);
  131. free(image_bytes);
  132. return embed;
  133. }
  134. LLAVA_API void llava_image_embed_free(struct llava_image_embed * embed) {
  135. free(embed->embed);
  136. free(embed);
  137. }