// mtmd.cpp
#include "clip.h"
#include "clip-impl.h"
#include "mtmd.h"
#include "llama.h"

#include <algorithm>
#include <cerrno>
#include <cinttypes>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <limits>
#include <stdexcept>
#include <string>
#include <vector>
  12. struct mtmd_context {
  13. struct clip_ctx * ctx_clip;
  14. const struct llama_model * text_model;
  15. std::vector<float> image_embd_v; // image embedding vector
  16. bool print_timings;
  17. int n_threads;
  18. std::string image_marker;
  19. // TODO @ngxson : add timings
  20. mtmd_context(const char * mmproj_fname,
  21. const llama_model * text_model,
  22. const mtmd_context_params & ctx_params) :
  23. print_timings(ctx_params.print_timings),
  24. n_threads (ctx_params.n_threads),
  25. image_marker (ctx_params.image_marker)
  26. {
  27. clip_context_params ctx_clip_params;
  28. ctx_clip_params.use_gpu = ctx_params.use_gpu;
  29. ctx_clip_params.verbosity = ctx_params.verbosity;
  30. ctx_clip = clip_init(mmproj_fname, ctx_clip_params);
  31. if (!ctx_clip) {
  32. throw std::runtime_error(string_format("Failed to load CLIP model from %s\n", mmproj_fname));
  33. }
  34. this->text_model = text_model;
  35. }
  36. ~mtmd_context() {
  37. clip_free(ctx_clip);
  38. }
  39. };
// internal holder for one image's preprocessed patches
// NOTE(review): appears unused within this file — mtmd_image_tokens below
// carries its own batch_f32; confirm whether this struct is still needed
struct mtmd_image_tokens_data {
    clip_image_f32_batch batch_f32; // preprocessed image patches
};
// result of preprocessing one image: occupies an nx*ny grid of token slots
// in the text sequence, backed by the preprocessed patch batch
struct mtmd_image_tokens {
    uint32_t nx; // number of tokens in x direction
    uint32_t ny; // number of tokens in y direction
    uint32_t n_tokens() const { return nx * ny; } // total token slots used by this image
    clip_image_f32_batch batch_f32; // preprocessed image patches
    std::string id; // optional user-defined ID, useful for KV cache tracking
};
  50. mtmd_context * mtmd_init_from_file(const char * mmproj_fname,
  51. const struct llama_model * text_model,
  52. const struct mtmd_context_params ctx_params) {
  53. try {
  54. return new mtmd_context(mmproj_fname, text_model, ctx_params);
  55. } catch (const std::exception & e) {
  56. LOG_ERR("%s: error: %s\n", __func__, e.what());
  57. return nullptr;
  58. }
  59. }
  60. void mtmd_free(mtmd_context * ctx) {
  61. if (ctx) {
  62. delete ctx;
  63. }
  64. }
  65. // copied from common_tokenize
  66. static std::vector<llama_token> mtmd_tokenize_text_internal(
  67. const struct llama_vocab * vocab,
  68. const std::string & text,
  69. bool add_special,
  70. bool parse_special) {
  71. // upper limit for the number of tokens
  72. int n_tokens = text.length() + 2 * add_special;
  73. std::vector<llama_token> result(n_tokens);
  74. n_tokens = llama_tokenize(vocab, text.data(), text.length(), result.data(), result.size(), add_special, parse_special);
  75. if (n_tokens < 0) {
  76. result.resize(-n_tokens);
  77. int check = llama_tokenize(vocab, text.data(), text.length(), result.data(), result.size(), add_special, parse_special);
  78. GGML_ASSERT(check == -n_tokens);
  79. } else {
  80. result.resize(n_tokens);
  81. }
  82. return result;
  83. }
  84. int32_t mtmd_tokenize(mtmd_context * ctx,
  85. std::vector<mtmd_input_chunk> & output,
  86. const mtmd_input_text & text,
  87. const std::vector<mtmd_bitmap> & bitmaps) {
  88. auto vocab = llama_model_get_vocab(ctx->text_model);
  89. std::string prompt_modified(text.text);
  90. std::string marker_modified(ctx->image_marker);
  91. projector_type proj_type = clip_get_projector_type(ctx->ctx_clip);
  92. // a bit hacky here, but works for now
  93. // for some models, we need to add prefix and suffix to the image embeddings
  94. if (proj_type == PROJECTOR_TYPE_GEMMA3) {
  95. // <start_of_image> ... (image embeddings) ... <end_of_image>
  96. marker_modified = "<start_of_image>" + ctx->image_marker + "<end_of_image>";
  97. string_replace_all(prompt_modified, ctx->image_marker, marker_modified);
  98. }
  99. std::vector<std::string> parts = string_split_str(prompt_modified, ctx->image_marker);
  100. output.clear();
  101. output.reserve(parts.size());
  102. size_t i_img = 0;
  103. for (const auto & part : parts) {
  104. //printf("tokenizing part: %s\n", part.c_str());
  105. bool add_bos = &parts.front() == &part;
  106. auto tokens = mtmd_tokenize_text_internal(vocab, part, text.add_special && add_bos, text.parse_special);
  107. if (tokens.empty()) {
  108. continue;
  109. }
  110. mtmd_input_chunk chunk{
  111. MTMD_INPUT_CHUNK_TYPE_TEXT,
  112. std::move(tokens),
  113. {},
  114. };
  115. output.emplace_back(std::move(chunk));
  116. if (&parts.back() != &part) {
  117. // add image token to middle of 2 parts
  118. if (i_img >= bitmaps.size()) {
  119. LOG_ERR("%s: error: not enough images for %d parts\n", __func__, (int)parts.size());
  120. return 1;
  121. }
  122. // shim layer
  123. clip_image_u8_ptr img_u8(clip_image_u8_init());
  124. img_u8->nx = bitmaps[i_img].nx;
  125. img_u8->ny = bitmaps[i_img].ny;
  126. img_u8->buf.resize(bitmaps[i_img].data.size());
  127. std::memcpy(img_u8->buf.data(), bitmaps[i_img].data.data(), img_u8->nx * img_u8->ny * 3);
  128. // preprocess image
  129. clip_image_f32_batch batch_f32;
  130. bool ok = clip_image_preprocess(ctx->ctx_clip, img_u8.get(), &batch_f32);
  131. if (!ok) {
  132. LOG_ERR("Unable to preprocess image\n");
  133. return 2;
  134. }
  135. mtmd_image_tokens_ptr image_tokens(new mtmd_image_tokens);
  136. image_tokens->nx = clip_n_patches(ctx->ctx_clip); // TODO @ngxson : use clip_n_patches_by_image
  137. image_tokens->ny = 1; // TODO
  138. image_tokens->batch_f32 = std::move(batch_f32);
  139. image_tokens->id = bitmaps[i_img].id; // optional
  140. mtmd_input_chunk chunk{
  141. MTMD_INPUT_CHUNK_TYPE_IMAGE,
  142. {},
  143. std::move(image_tokens),
  144. };
  145. output.emplace_back(std::move(chunk));
  146. i_img++;
  147. }
  148. }
  149. return 0;
  150. }
  151. void mtmd_image_tokens_free(mtmd_image_tokens * image_tokens) {
  152. if (image_tokens) {
  153. delete image_tokens;
  154. }
  155. }
// total number of token slots the image occupies (nx * ny)
size_t mtmd_image_tokens_get_n_tokens(const mtmd_image_tokens * image_tokens) {
    return image_tokens->n_tokens();
}
// number of tokens in the x direction
size_t mtmd_image_tokens_get_nx(const mtmd_image_tokens * image_tokens) {
    return image_tokens->nx;
}
// number of tokens in the y direction
size_t mtmd_image_tokens_get_ny(const mtmd_image_tokens * image_tokens) {
    return image_tokens->ny;
}
// user-defined ID set at tokenization time (may be empty); returned by copy
std::string mtmd_image_tokens_get_id(const mtmd_image_tokens * image_tokens) {
    return image_tokens->id;
}
// encode the preprocessed image patches into ctx->image_embd_v
// (n_tokens() * n_mmproj_embd floats); fetch the result with
// mtmd_get_output_embd(). returns 0 on success, 1 on failure.
int32_t mtmd_encode(mtmd_context * ctx, const mtmd_image_tokens * image_tokens) {
    int n_mmproj_embd = clip_n_mmproj_embd(ctx->ctx_clip);
    ctx->image_embd_v.resize(image_tokens->n_tokens() * n_mmproj_embd);
    bool ok = clip_image_batch_encode(
        ctx->ctx_clip,
        ctx->n_threads,
        &image_tokens->batch_f32,
        ctx->image_embd_v.data());
    return ok ? 0 : 1;
}
// pointer into the embeddings written by the last mtmd_encode() call;
// invalidated by the next mtmd_encode() (the vector may reallocate) and by
// freeing the context
float * mtmd_get_output_embd(mtmd_context * ctx) {
    return ctx->image_embd_v.data();
}
  181. size_t mtmd_helper_get_n_tokens(mtmd_input_chunks & chunks) {
  182. size_t n_tokens = 0;
  183. for (auto & chunk : chunks) {
  184. if (chunk.type == MTMD_INPUT_CHUNK_TYPE_TEXT) {
  185. n_tokens += chunk.tokens_text.size();
  186. } else if (chunk.type == MTMD_INPUT_CHUNK_TYPE_IMAGE) {
  187. n_tokens += chunk.tokens_image->n_tokens();
  188. } else {
  189. GGML_ASSERT(false && "chunk type not supported");
  190. }
  191. }
  192. return n_tokens;
  193. }
// helper struct to make working with embd batch easier
// note: this will be removed after llama_batch_ext refactoring
//
// builds a llama_batch of n_tokens embedding rows (token ids are null), all
// assigned to a single sequence, with consecutive positions starting at pos_0
// and logits disabled for every row. The member vectors own the storage the
// batch points into, so this object must outlive any use of `batch`.
struct decode_embd_batch {
    std::vector<llama_pos> pos;
    std::vector<int32_t> n_seq_id;
    std::vector<llama_seq_id> seq_id_0; // the single seq id shared by every row
    std::vector<llama_seq_id *> seq_ids;
    std::vector<int8_t> logits;
    llama_batch batch;
    decode_embd_batch(float * embd, int32_t n_tokens, llama_pos pos_0, llama_seq_id seq_id) {
        pos .resize(n_tokens);
        n_seq_id.resize(n_tokens);
        seq_ids .resize(n_tokens + 1);
        logits .resize(n_tokens);
        seq_id_0.resize(1);
        seq_id_0[0] = seq_id;
        // extra trailing entry set to nullptr — presumably a sentinel expected
        // by llama_batch consumers; verify against llama.h conventions
        seq_ids [n_tokens] = nullptr;
        batch = {
            /*n_tokens =*/ n_tokens,
            /*tokens =*/ nullptr,
            /*embd =*/ embd,
            /*pos =*/ pos.data(),
            /*n_seq_id =*/ n_seq_id.data(),
            /*seq_id =*/ seq_ids.data(),
            /*logits =*/ logits.data(),
        };
        // fill per-row metadata: consecutive positions, one sequence, no logits
        for (int i = 0; i < n_tokens; i++) {
            batch.pos [i] = pos_0 + i;
            batch.n_seq_id[i] = 1;
            batch.seq_id [i] = seq_id_0.data();
            batch.logits [i] = false;
        }
    }
};
// evaluate a tokenized chunk list: text chunks are decoded directly through
// llama_decode; image chunks are first encoded with mtmd_encode and then
// decoded as an embedding batch. Positions are assigned consecutively from
// pos0 on sequence seq_id; logits are requested only for the last token of
// the final chunk. Returns 0 on success, otherwise the error code from
// llama_decode / mtmd_encode (the text batch is freed on every exit path).
int32_t mtmd_helper_eval(mtmd_context * ctx,
        llama_context * lctx,
        mtmd_input_chunks & chunks,
        llama_pos pos0,
        llama_seq_id seq_id,
        int32_t n_batch) {
    int32_t ret;
    llama_pos n_past = pos0;
    // one reusable batch sized for n_batch tokens, 1 seq id per token
    llama_batch text_batch = llama_batch_init(n_batch, 0, 1);

    for (auto & chunk : chunks) {
        bool is_last = &chunk == &chunks.back();
        if (chunk.type == MTMD_INPUT_CHUNK_TYPE_TEXT) {
            // TODO @ngxson : may need to split into smaller batches
            // NOTE(review): assumes chunk.tokens_text.size() <= n_batch — a
            // longer text chunk would overrun the buffers allocated above
            text_batch.n_tokens = chunk.tokens_text.size();
            for (size_t i = 0; i < chunk.tokens_text.size(); i++) {
                text_batch.token [i] = chunk.tokens_text[i];
                text_batch.pos [i] = n_past++;
                text_batch.n_seq_id[i] = 1;
                text_batch.seq_id [i][0] = seq_id;
                text_batch.logits [i] = false;
            }
            if (is_last) {
                // always get logits for last input chunk
                text_batch.logits[text_batch.n_tokens - 1] = true;
            }
            ret = llama_decode(lctx, text_batch);
            if (ret != 0) {
                LOG_ERR("failed to decode text\n");
                llama_batch_free(text_batch);
                return ret;
            }
        } else if (chunk.type == MTMD_INPUT_CHUNK_TYPE_IMAGE) {
            // logits cannot come from an embedding batch yet, so an image may
            // not be the final chunk
            GGML_ASSERT(!is_last && "logits for last image chunk is not yet support");
            GGML_ASSERT(chunk.tokens_image != nullptr);
            int64_t t0 = ggml_time_ms();
            if (ctx->print_timings) {
                LOG_INF("encoding image...\n");
            }
            ret = mtmd_encode(ctx, chunk.tokens_image.get());
            if (ret != 0) {
                LOG_ERR("failed to encode image\n");
                llama_batch_free(text_batch);
                return ret;
            }
            if (ctx->print_timings) {
                LOG_INF("image encoded in %" PRId64 " ms\n", ggml_time_ms() - t0);
            }
            int32_t n_tokens = mtmd_image_tokens_get_n_tokens(chunk.tokens_image.get());
            float * embd = mtmd_get_output_embd(ctx);
            // NOTE(review): the image batch is decoded on sequence 0, not
            // seq_id — confirm this is intentional
            decode_embd_batch batch_img(embd, n_tokens, n_past, 0);
            int64_t t1 = ggml_time_ms();
            ret = llama_decode(lctx, batch_img.batch);
            if (ret != 0) {
                LOG_ERR("failed to decode image\n");
                llama_batch_free(text_batch);
                return ret;
            }
            if (ctx->print_timings) {
                LOG_INF("image decoded in %" PRId64 " ms\n", ggml_time_ms() - t1);
            }
            n_past += n_tokens;
        } else {
            GGML_ASSERT(false && "chunk type not supported");
        }
    }

    llama_batch_free(text_batch);
    return 0;
}
// decode an image held in memory into `output` as tightly-packed RGB
// (3 bytes per pixel). returns 0 on success, 1 if the buffer cannot be decoded.
int32_t mtmd_helper_bitmap_init_from_buf(const unsigned char * buf, size_t len, mtmd_bitmap & output) {
    clip_image_u8_ptr img_u8(clip_image_u8_init());
    bool ok = clip_image_load_from_bytes(buf, len, img_u8.get());
    if (!ok) {
        LOG_ERR("Unable to load image from buffer\n");
        return 1;
    }
    // clip_image_u8_get_data also reports the decoded dimensions via nx/ny
    unsigned char * data = clip_image_u8_get_data(img_u8.get(), &output.nx, &output.ny);
    output.data.resize(output.nx * output.ny * 3);
    std::memcpy(output.data.data(), data, output.nx * output.ny * 3);
    return 0;
}
// decode an image file from disk into `output` as tightly-packed RGB
// (3 bytes per pixel). returns 0 on success, 1 if the file cannot be loaded.
int32_t mtmd_helper_bitmap_init_from_file(const char * fname, mtmd_bitmap & output) {
    clip_image_u8_ptr img_u8(clip_image_u8_init());
    bool ok = clip_image_load_from_file(fname, img_u8.get());
    if (!ok) {
        LOG_ERR("Unable to load image %s\n", fname);
        return 1;
    }
    // clip_image_u8_get_data also reports the decoded dimensions via nx/ny
    unsigned char * data = clip_image_u8_get_data(img_u8.get(), &output.nx, &output.ny);
    output.data.resize(output.nx * output.ny * 3);
    std::memcpy(output.data.data(), data, output.nx * output.ny * 3);
    return 0;
}
  320. bool mtmd_decode_use_non_causal(mtmd_context * ctx) {
  321. projector_type proj_type = clip_get_projector_type(ctx->ctx_clip);
  322. if (proj_type == PROJECTOR_TYPE_GEMMA3) {
  323. return true;
  324. }
  325. return false;
  326. }
// custom deleter (presumably used by mtmd_image_tokens_ptr — confirm in
// mtmd.h): forwards to the C-style free function
void mtmd_image_tokens_deleter::operator()(mtmd_image_tokens * val) {
    mtmd_image_tokens_free(val);
}