clip.cpp 137 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
71778177917801781178217831784178517861787178817891790179117921793179417951796179717981799180018011802180318041805180618071808180918101811181218131814181518161817181818191820182118221823182418251826182718281829183018311832183318341835183618371838183918401841184218431844184518461847184818491850185118521853185418551856185718581859186018611862186318641865186618671868186918701871187218731874187518761877187818791880188118821883188418851886188718881889189018911892189318941895189618971898189919001901190219031904190519061907190819091910191119121913191419151916191719181919192019211922192319241925192619271928192919301931193219331934193519361937193819391940194119421943194419451946194719481949195019511952195319541955195619571958195919601961196219631964196519661967196819691970197119721973197419751976197719781979198019811982198319841985198619871988198919901991199219931994199519961997199819992000200120022003200420052006200720082009201020112012201320142015201620172018201920202021202220232024202520262027202820292030203120322033203420352036203720382039204020412042204320442045204620472048204920502051205220532054205520562057205820592060206120622063206420652066206720682069207020712072207320742075207620772078207920802081208220832084208520862087208820892090209120922093209420952096209720982099210021012102210321042105210621072108210921102111211221132114211521162117211821192120212121222123212421252126212721282129213021312132213321342135213621372138213921402141214221432144214521462147214821492150215121522153215421552156215721582159216021612162216321642165216621672168216921702171217221732174217521762177217821792180218121822183218421852186218721882189219021912192219321942195219621972198219922002201220222032204220522062207220822092210221122122213221422152216221722182219222022212222222322242225222622272228222922302231223222332234223522362237223822392240224122422243224422452246224722482249225022512252225322542255225622572258225922602261226222632264226522662267226822692270227122722273227422752276227
72278227922802281228222832284228522862287228822892290229122922293229422952296229722982299230023012302230323042305230623072308230923102311231223132314231523162317231823192320232123222323232423252326232723282329233023312332233323342335233623372338233923402341234223432344234523462347234823492350235123522353235423552356235723582359236023612362236323642365236623672368236923702371237223732374237523762377237823792380238123822383238423852386238723882389239023912392239323942395239623972398239924002401240224032404240524062407240824092410241124122413241424152416241724182419242024212422242324242425242624272428242924302431243224332434243524362437243824392440244124422443244424452446244724482449245024512452245324542455245624572458245924602461246224632464246524662467246824692470247124722473247424752476247724782479248024812482248324842485248624872488248924902491249224932494249524962497249824992500250125022503250425052506250725082509251025112512251325142515251625172518251925202521252225232524252525262527252825292530253125322533253425352536253725382539254025412542254325442545254625472548254925502551255225532554255525562557255825592560256125622563256425652566256725682569257025712572257325742575257625772578257925802581258225832584258525862587258825892590259125922593259425952596259725982599260026012602260326042605260626072608260926102611261226132614261526162617261826192620262126222623262426252626262726282629263026312632263326342635263626372638263926402641264226432644264526462647264826492650265126522653265426552656265726582659266026612662266326642665266626672668266926702671267226732674267526762677267826792680268126822683268426852686268726882689269026912692269326942695269626972698269927002701270227032704270527062707270827092710271127122713271427152716271727182719272027212722272327242725272627272728272927302731273227332734273527362737273827392740274127422743274427452746274727482749275027512752275327542755275627572758275927602761276227632764276527662767276827692770277127722773277427752776277
7277827792780278127822783278427852786278727882789279027912792279327942795279627972798279928002801280228032804280528062807280828092810281128122813281428152816281728182819282028212822282328242825282628272828282928302831283228332834283528362837283828392840284128422843284428452846284728482849285028512852285328542855285628572858285928602861286228632864286528662867286828692870287128722873287428752876287728782879288028812882288328842885288628872888288928902891289228932894289528962897289828992900290129022903290429052906290729082909291029112912291329142915291629172918291929202921292229232924292529262927292829292930293129322933293429352936293729382939294029412942294329442945294629472948294929502951295229532954295529562957295829592960296129622963296429652966296729682969297029712972297329742975297629772978297929802981298229832984298529862987298829892990299129922993299429952996299729982999300030013002300330043005300630073008300930103011301230133014301530163017301830193020302130223023302430253026302730283029303030313032303330343035303630373038303930403041304230433044304530463047304830493050305130523053305430553056305730583059306030613062306330643065306630673068306930703071307230733074307530763077307830793080308130823083308430853086308730883089309030913092309330943095309630973098309931003101310231033104310531063107310831093110311131123113311431153116311731183119312031213122312331243125312631273128312931303131313231333134313531363137313831393140314131423143314431453146314731483149315031513152315331543155315631573158315931603161316231633164316531663167316831693170317131723173317431753176317731783179318031813182318331843185318631873188318931903191319231933194319531963197
  // NOTE: This is modified from clip.cpp only for LLaVA,
  // so there may still be unnecessary artifacts hanging around.
  // I'll gradually clean and extend it.
  // Note: Even when using identical normalized image inputs (see normalize_image_u8_to_f32()), the resulting embeddings differ significantly from those produced by PyTorch.
  5. #include "clip.h"
  6. #include "clip-impl.h"
  7. #include "ggml.h"
  8. #include "ggml-cpp.h"
  9. #include "ggml-cpu.h"
  10. #include "ggml-alloc.h"
  11. #include "ggml-backend.h"
  12. #include "gguf.h"
  13. #define STB_IMAGE_IMPLEMENTATION
  14. #include "stb_image.h"
  15. #include <cassert>
  16. #include <cmath>
  17. #include <cstdlib>
  18. #include <cstring>
  19. #include <fstream>
  20. #include <map>
  21. #include <regex>
  22. #include <stdexcept>
  23. #include <unordered_set>
  24. #include <vector>
  25. #include <sstream>
  26. #include <cinttypes>
  27. #include <limits>
  28. #include <array>
  29. struct clip_logger_state g_logger_state = {GGML_LOG_LEVEL_CONT, clip_log_callback_default, NULL};
  30. //#define CLIP_DEBUG_FUNCTIONS
  31. #ifdef CLIP_DEBUG_FUNCTIONS
  32. static void clip_image_write_image_to_ppm(const clip_image_u8& img, const std::string& filename) {
  33. std::ofstream file(filename, std::ios::binary);
  34. if (!file.is_open()) {
  35. LOG_ERR("Failed to open file for writing: %s\n", filename.c_str());
  36. return;
  37. }
  38. // PPM header: P6 format, width, height, and max color value
  39. file << "P6\n" << img.nx << " " << img.ny << "\n255\n";
  40. // Write pixel data
  41. for (size_t i = 0; i < img.buf.size(); i += 3) {
  42. // PPM expects binary data in RGB format, which matches our image buffer
  43. file.write(reinterpret_cast<const char*>(&img.buf[i]), 3);
  44. }
  45. file.close();
  46. }
  47. static void clip_image_save_to_bmp(const clip_image_u8& img, const std::string& filename) {
  48. std::ofstream file(filename, std::ios::binary);
  49. if (!file.is_open()) {
  50. LOG_ERR("Failed to open file for writing: %s\n", filename.c_str());
  51. return;
  52. }
  53. int fileSize = 54 + 3 * img.nx * img.ny; // File header + info header + pixel data
  54. int bytesPerPixel = 3;
  55. int widthInBytes = img.nx * bytesPerPixel;
  56. int paddingAmount = (4 - (widthInBytes % 4)) % 4;
  57. int stride = widthInBytes + paddingAmount;
  58. // Bitmap file header
  59. unsigned char fileHeader[14] = {
  60. 'B','M', // Signature
  61. 0,0,0,0, // Image file size in bytes
  62. 0,0,0,0, // Reserved
  63. 54,0,0,0 // Start of pixel array
  64. };
  65. // Total file size
  66. fileSize = 54 + (stride * img.ny);
  67. fileHeader[2] = (unsigned char)(fileSize);
  68. fileHeader[3] = (unsigned char)(fileSize >> 8);
  69. fileHeader[4] = (unsigned char)(fileSize >> 16);
  70. fileHeader[5] = (unsigned char)(fileSize >> 24);
  71. // Bitmap information header (BITMAPINFOHEADER)
  72. unsigned char infoHeader[40] = {
  73. 40,0,0,0, // Size of this header (40 bytes)
  74. 0,0,0,0, // Image width
  75. 0,0,0,0, // Image height
  76. 1,0, // Number of color planes
  77. 24,0, // Bits per pixel
  78. 0,0,0,0, // No compression
  79. 0,0,0,0, // Image size (can be 0 for no compression)
  80. 0,0,0,0, // X pixels per meter (not specified)
  81. 0,0,0,0, // Y pixels per meter (not specified)
  82. 0,0,0,0, // Total colors (color table not used)
  83. 0,0,0,0 // Important colors (all are important)
  84. };
  85. // Width and height in the information header
  86. infoHeader[4] = (unsigned char)(img.nx);
  87. infoHeader[5] = (unsigned char)(img.nx >> 8);
  88. infoHeader[6] = (unsigned char)(img.nx >> 16);
  89. infoHeader[7] = (unsigned char)(img.nx >> 24);
  90. infoHeader[8] = (unsigned char)(img.ny);
  91. infoHeader[9] = (unsigned char)(img.ny >> 8);
  92. infoHeader[10] = (unsigned char)(img.ny >> 16);
  93. infoHeader[11] = (unsigned char)(img.ny >> 24);
  94. // Write file headers
  95. file.write(reinterpret_cast<char*>(fileHeader), sizeof(fileHeader));
  96. file.write(reinterpret_cast<char*>(infoHeader), sizeof(infoHeader));
  97. // Pixel data
  98. std::vector<unsigned char> padding(3, 0); // Max padding size to be added to each row
  99. for (int y = img.ny - 1; y >= 0; --y) { // BMP files are stored bottom-to-top
  100. for (int x = 0; x < img.nx; ++x) {
  101. // Each pixel
  102. size_t pixelIndex = (y * img.nx + x) * 3;
  103. unsigned char pixel[3] = {
  104. img.buf[pixelIndex + 2], // BMP stores pixels in BGR format
  105. img.buf[pixelIndex + 1],
  106. img.buf[pixelIndex]
  107. };
  108. file.write(reinterpret_cast<char*>(pixel), 3);
  109. }
  110. // Write padding for the row
  111. file.write(reinterpret_cast<char*>(padding.data()), paddingAmount);
  112. }
  113. file.close();
  114. }
  115. // debug function to convert f32 to u8
  116. static void clip_image_convert_f32_to_u8(const clip_image_f32& src, clip_image_u8& dst) {
  117. dst.nx = src.nx;
  118. dst.ny = src.ny;
  119. dst.buf.resize(3 * src.nx * src.ny);
  120. for (size_t i = 0; i < src.buf.size(); ++i) {
  121. dst.buf[i] = static_cast<uint8_t>(std::min(std::max(int(src.buf[i] * 255.0f), 0), 255));
  122. }
  123. }
  124. #endif
  125. //
  126. // clip layers
  127. //
  // Strategy used when merging patch embeddings; clip_hparams defaults to PATCH_MERGE_FLAT.
  enum patch_merge_type {
      PATCH_MERGE_FLAT          = 0,
      PATCH_MERGE_SPATIAL_UNPAD = 1,
  };
  132. struct clip_hparams {
  133. int32_t image_size;
  134. int32_t patch_size;
  135. int32_t hidden_size;
  136. int32_t n_intermediate;
  137. int32_t projection_dim;
  138. int32_t n_head;
  139. int32_t n_layer;
  140. int32_t proj_scale_factor = 0; // idefics3
  141. patch_merge_type mm_patch_merge_type = PATCH_MERGE_FLAT;
  142. float eps = 1e-6;
  143. float rope_theta = 0.0;
  144. std::vector<int32_t> image_grid_pinpoints;
  145. int32_t image_crop_resolution;
  146. std::unordered_set<int32_t> vision_feature_layer;
  147. };
  // Weight tensors of one encoder layer. Every pointer is value-initialized to nullptr, so a
  // tensor the model does not provide stays null. Both the legacy ff_i/ff_o names and the
  // ff_up/ff_gate/ff_down names are kept; the SigLIP graph builder reads ff_i/ff_o.
  struct clip_layer {
      // --- self-attention: k, q, v, output (weight, bias) ---
      struct ggml_tensor * k_w{nullptr};
      struct ggml_tensor * k_b{nullptr};
      struct ggml_tensor * q_w{nullptr};
      struct ggml_tensor * q_b{nullptr};
      struct ggml_tensor * v_w{nullptr};
      struct ggml_tensor * v_b{nullptr};
      struct ggml_tensor * o_w{nullptr};
      struct ggml_tensor * o_b{nullptr};

      // --- layernorm before attention ---
      struct ggml_tensor * ln_1_w{nullptr};
      struct ggml_tensor * ln_1_b{nullptr};

      // --- feed-forward, legacy naming ---
      struct ggml_tensor * ff_i_w{nullptr};
      struct ggml_tensor * ff_i_b{nullptr};
      struct ggml_tensor * ff_o_w{nullptr};
      struct ggml_tensor * ff_o_b{nullptr};

      // --- feed-forward, up/gate/down naming ---
      struct ggml_tensor * ff_up_w{nullptr};
      struct ggml_tensor * ff_up_b{nullptr};
      struct ggml_tensor * ff_gate_w{nullptr};
      struct ggml_tensor * ff_gate_b{nullptr};
      struct ggml_tensor * ff_down_w{nullptr};
      struct ggml_tensor * ff_down_b{nullptr};

      // --- layernorm before feed-forward ---
      struct ggml_tensor * ln_2_w{nullptr};
      struct ggml_tensor * ln_2_b{nullptr};
  };
  176. struct clip_vision_model {
  177. struct clip_hparams hparams;
  178. // embeddings
  179. struct ggml_tensor * class_embedding = nullptr;
  180. struct ggml_tensor * patch_embeddings_0 = nullptr;
  181. struct ggml_tensor * patch_embeddings_1 = nullptr; // second Conv2D kernel when we decouple Conv3D along temproal dimension (Qwen2VL)
  182. struct ggml_tensor * patch_bias = nullptr;
  183. struct ggml_tensor * position_embeddings = nullptr;
  184. struct ggml_tensor * pre_ln_w = nullptr;
  185. struct ggml_tensor * pre_ln_b = nullptr;
  186. std::vector<clip_layer> layers;
  187. struct ggml_tensor * post_ln_w;
  188. struct ggml_tensor * post_ln_b;
  189. struct ggml_tensor * projection;
  190. // LLaVA projection
  191. struct ggml_tensor * mm_0_w = nullptr;
  192. struct ggml_tensor * mm_0_b = nullptr;
  193. struct ggml_tensor * mm_2_w = nullptr;
  194. struct ggml_tensor * mm_2_b = nullptr;
  195. struct ggml_tensor * image_newline = nullptr;
  196. // Yi type models with mlp+normalization projection
  197. struct ggml_tensor * mm_1_w = nullptr; // Yi type models have 0, 1, 3, 4
  198. struct ggml_tensor * mm_1_b = nullptr;
  199. struct ggml_tensor * mm_3_w = nullptr;
  200. struct ggml_tensor * mm_3_b = nullptr;
  201. struct ggml_tensor * mm_4_w = nullptr;
  202. struct ggml_tensor * mm_4_b = nullptr;
  203. //GLMV-Edge projection
  204. struct ggml_tensor * mm_model_adapter_conv_w = nullptr;
  205. struct ggml_tensor * mm_model_adapter_conv_b = nullptr;
  206. // MobileVLM projection
  207. struct ggml_tensor * mm_model_mlp_1_w = nullptr;
  208. struct ggml_tensor * mm_model_mlp_1_b = nullptr;
  209. struct ggml_tensor * mm_model_mlp_3_w = nullptr;
  210. struct ggml_tensor * mm_model_mlp_3_b = nullptr;
  211. struct ggml_tensor * mm_model_block_1_block_0_0_w = nullptr;
  212. struct ggml_tensor * mm_model_block_1_block_0_1_w = nullptr;
  213. struct ggml_tensor * mm_model_block_1_block_0_1_b = nullptr;
  214. struct ggml_tensor * mm_model_block_1_block_1_fc1_w = nullptr;
  215. struct ggml_tensor * mm_model_block_1_block_1_fc1_b = nullptr;
  216. struct ggml_tensor * mm_model_block_1_block_1_fc2_w = nullptr;
  217. struct ggml_tensor * mm_model_block_1_block_1_fc2_b = nullptr;
  218. struct ggml_tensor * mm_model_block_1_block_2_0_w = nullptr;
  219. struct ggml_tensor * mm_model_block_1_block_2_1_w = nullptr;
  220. struct ggml_tensor * mm_model_block_1_block_2_1_b = nullptr;
  221. struct ggml_tensor * mm_model_block_2_block_0_0_w = nullptr;
  222. struct ggml_tensor * mm_model_block_2_block_0_1_w = nullptr;
  223. struct ggml_tensor * mm_model_block_2_block_0_1_b = nullptr;
  224. struct ggml_tensor * mm_model_block_2_block_1_fc1_w = nullptr;
  225. struct ggml_tensor * mm_model_block_2_block_1_fc1_b = nullptr;
  226. struct ggml_tensor * mm_model_block_2_block_1_fc2_w = nullptr;
  227. struct ggml_tensor * mm_model_block_2_block_1_fc2_b = nullptr;
  228. struct ggml_tensor * mm_model_block_2_block_2_0_w = nullptr;
  229. struct ggml_tensor * mm_model_block_2_block_2_1_w = nullptr;
  230. struct ggml_tensor * mm_model_block_2_block_2_1_b = nullptr;
  231. // MobileVLM_V2 projection
  232. struct ggml_tensor * mm_model_mlp_0_w = nullptr;
  233. struct ggml_tensor * mm_model_mlp_0_b = nullptr;
  234. struct ggml_tensor * mm_model_mlp_2_w = nullptr;
  235. struct ggml_tensor * mm_model_mlp_2_b = nullptr;
  236. struct ggml_tensor * mm_model_peg_0_w = nullptr;
  237. struct ggml_tensor * mm_model_peg_0_b = nullptr;
  238. // MINICPMV projection
  239. struct ggml_tensor * mm_model_pos_embed_k = nullptr;
  240. struct ggml_tensor * mm_model_query = nullptr;
  241. struct ggml_tensor * mm_model_proj = nullptr;
  242. struct ggml_tensor * mm_model_kv_proj = nullptr;
  243. struct ggml_tensor * mm_model_attn_q_w = nullptr;
  244. struct ggml_tensor * mm_model_attn_q_b = nullptr;
  245. struct ggml_tensor * mm_model_attn_k_w = nullptr;
  246. struct ggml_tensor * mm_model_attn_k_b = nullptr;
  247. struct ggml_tensor * mm_model_attn_v_w = nullptr;
  248. struct ggml_tensor * mm_model_attn_v_b = nullptr;
  249. struct ggml_tensor * mm_model_attn_o_w = nullptr;
  250. struct ggml_tensor * mm_model_attn_o_b = nullptr;
  251. struct ggml_tensor * mm_model_ln_q_w = nullptr;
  252. struct ggml_tensor * mm_model_ln_q_b = nullptr;
  253. struct ggml_tensor * mm_model_ln_kv_w = nullptr;
  254. struct ggml_tensor * mm_model_ln_kv_b = nullptr;
  255. struct ggml_tensor * mm_model_ln_post_w = nullptr;
  256. struct ggml_tensor * mm_model_ln_post_b = nullptr;
  257. // gemma3
  258. struct ggml_tensor * mm_input_proj_w = nullptr;
  259. struct ggml_tensor * mm_soft_emb_norm_w = nullptr;
  260. // pixtral
  261. struct ggml_tensor * token_embd_img_break = nullptr;
  262. };
  263. struct clip_ctx {
  264. bool has_llava_projector = false;
  265. int minicpmv_version = 0;
  266. struct clip_vision_model vision_model;
  267. projector_type proj_type = PROJECTOR_TYPE_MLP;
  268. int32_t max_feature_layer; // unused in newer models like gemma3
  269. float image_mean[3];
  270. float image_std[3];
  271. bool use_gelu = false;
  272. bool use_silu = false;
  273. gguf_context_ptr ctx_gguf;
  274. ggml_context_ptr ctx_data;
  275. std::vector<uint8_t> buf_compute_meta;
  276. std::vector<ggml_backend_t> backend_ptrs;
  277. std::vector<ggml_backend_buffer_type_t> backend_buft;
  278. ggml_backend_t backend;
  279. ggml_backend_t backend_cpu;
  280. ggml_backend_buffer_ptr buf;
  281. int max_nodes = 8192;
  282. ggml_backend_sched_ptr sched;
  283. clip_image_size load_image_size;
  284. clip_ctx(clip_context_params & ctx_params) {
  285. backend_cpu = ggml_backend_init_by_type(GGML_BACKEND_DEVICE_TYPE_CPU, nullptr);
  286. backend = ctx_params.use_gpu
  287. ? ggml_backend_init_by_type(GGML_BACKEND_DEVICE_TYPE_GPU, nullptr)
  288. : nullptr;
  289. if (backend) {
  290. LOG_INF("%s: CLIP using %s backend\n", __func__, ggml_backend_name(backend));
  291. backend_ptrs.push_back(backend);
  292. backend_buft.push_back(ggml_backend_get_default_buffer_type(backend));
  293. } else {
  294. backend = backend_cpu;
  295. LOG_INF("%s: CLIP using CPU backend\n", __func__);
  296. }
  297. backend_ptrs.push_back(backend_cpu);
  298. backend_buft.push_back(ggml_backend_get_default_buffer_type(backend_cpu));
  299. sched.reset(
  300. ggml_backend_sched_new(backend_ptrs.data(), backend_buft.data(), backend_ptrs.size(), 8192, false)
  301. );
  302. }
  303. ~clip_ctx() {
  304. ggml_backend_free(backend);
  305. if (backend != backend_cpu) {
  306. ggml_backend_free(backend_cpu);
  307. }
  308. }
  309. };
  310. static ggml_cgraph * clip_image_build_graph_siglip(clip_ctx * ctx, const clip_image_f32 & img) {
  311. const auto & model = ctx->vision_model;
  312. const auto & hparams = model.hparams;
  313. int image_size_width = img.nx;
  314. int image_size_height = img.ny;
  315. const int patch_size = hparams.patch_size;
  316. const int num_patches = ((image_size_width / patch_size) * (image_size_height / patch_size));
  317. const int hidden_size = hparams.hidden_size;
  318. const int n_head = hparams.n_head;
  319. const int d_head = hidden_size / n_head;
  320. const int n_layer = hparams.n_layer;
  321. const float eps = hparams.eps;
  322. struct ggml_init_params params = {
  323. /*.mem_size =*/ ctx->buf_compute_meta.size(),
  324. /*.mem_buffer =*/ ctx->buf_compute_meta.data(),
  325. /*.no_alloc =*/ true,
  326. };
  327. ggml_context_ptr ctx0_ptr(ggml_init(params));
  328. auto ctx0 = ctx0_ptr.get();
  329. struct ggml_cgraph * gf = ggml_new_graph(ctx0);
  330. // input raw
  331. struct ggml_tensor * inp_raw = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, image_size_width, image_size_height, 3);
  332. ggml_set_name(inp_raw, "inp_raw");
  333. ggml_set_input(inp_raw);
  334. struct ggml_tensor * inp = ggml_conv_2d(ctx0, model.patch_embeddings_0, inp_raw, patch_size, patch_size, 0, 0, 1, 1);
  335. inp = ggml_reshape_2d(ctx0, inp, num_patches, hidden_size);
  336. inp = ggml_cont(ctx0, ggml_transpose(ctx0, inp));
  337. inp = ggml_add(ctx0, inp, model.patch_bias);
  338. // position embeddings
  339. struct ggml_tensor * embeddings = ggml_add(ctx0, inp, model.position_embeddings);
  340. // loop over layers
  341. for (int il = 0; il < n_layer; il++) {
  342. struct ggml_tensor * cur = embeddings; // embeddings = residual, cur = hidden_states
  343. // layernorm1
  344. {
  345. cur = ggml_norm(ctx0, cur, eps);
  346. cur = ggml_add(ctx0, ggml_mul(ctx0, cur, model.layers[il].ln_1_w), model.layers[il].ln_1_b);
  347. }
  348. // self-attention
  349. {
  350. struct ggml_tensor * Q =
  351. ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].q_w, cur), model.layers[il].q_b);
  352. Q = ggml_reshape_3d(ctx0, Q, d_head, n_head, num_patches);
  353. Q = ggml_cont(ctx0, ggml_permute(ctx0, Q, 0, 2, 1, 3));
  354. struct ggml_tensor * K =
  355. ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].k_w, cur), model.layers[il].k_b);
  356. K = ggml_reshape_3d(ctx0, K, d_head, n_head, num_patches);
  357. K = ggml_cont(ctx0, ggml_permute(ctx0, K, 0, 2, 1, 3));
  358. struct ggml_tensor * V =
  359. ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].v_w, cur), model.layers[il].v_b);
  360. V = ggml_reshape_3d(ctx0, V, d_head, n_head, num_patches);
  361. V = ggml_cont(ctx0, ggml_permute(ctx0, V, 1, 2, 0, 3));
  362. struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q);
  363. KQ = ggml_soft_max_ext(ctx0, KQ, nullptr, 1.0f / sqrtf((float)d_head), 0.0f);
  364. struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ);
  365. KQV = ggml_reshape_3d(ctx0, KQV, d_head, num_patches, n_head);
  366. KQV = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
  367. cur = ggml_cont_2d(ctx0, KQV, hidden_size, num_patches);
  368. }
  369. // attention output
  370. cur = ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].o_w, cur), model.layers[il].o_b);
  371. // re-add the layer input, e.g., residual
  372. cur = ggml_add(ctx0, cur, embeddings);
  373. embeddings = cur; // embeddings = residual, cur = hidden_states
  374. // layernorm2
  375. {
  376. cur = ggml_norm(ctx0, cur, eps);
  377. cur = ggml_add(ctx0, ggml_mul(ctx0, cur, model.layers[il].ln_2_w), model.layers[il].ln_2_b);
  378. }
  379. cur = ggml_mul_mat(ctx0, model.layers[il].ff_i_w, cur);
  380. cur = ggml_add(ctx0, cur, model.layers[il].ff_i_b);
  381. // siglip uses gelu
  382. cur = ggml_gelu(ctx0, cur);
  383. cur = ggml_mul_mat(ctx0, model.layers[il].ff_o_w, cur);
  384. cur = ggml_add(ctx0, cur, model.layers[il].ff_o_b);
  385. // residual 2
  386. cur = ggml_add(ctx0, embeddings, cur);
  387. embeddings = cur;
  388. }
  389. // post-layernorm
  390. if (model.post_ln_w) {
  391. embeddings = ggml_norm(ctx0, embeddings, eps);
  392. ggml_set_name(embeddings, "post_ln");
  393. embeddings = ggml_add(ctx0, ggml_mul(ctx0, embeddings, model.post_ln_w), model.post_ln_b);
  394. }
  395. if (ctx->proj_type == PROJECTOR_TYPE_GEMMA3) {
  396. const int batch_size = 1;
  397. const int mm_tokens_per_image = 256; // default value for gemma3
  398. const int tokens_per_side = sqrt(mm_tokens_per_image);
  399. const int patches_per_image = sqrt(num_patches);
  400. const int kernel_size = patches_per_image / tokens_per_side;
  401. embeddings = ggml_cont(ctx0, ggml_transpose(ctx0, embeddings));
  402. embeddings = ggml_reshape_4d(ctx0, embeddings, patches_per_image, patches_per_image, hidden_size, batch_size);
  403. // doing a pool2d to reduce the number of output tokens to 256
  404. embeddings = ggml_pool_2d(ctx0, embeddings, GGML_OP_POOL_AVG, kernel_size, kernel_size, kernel_size, kernel_size, 0, 0);
  405. embeddings = ggml_reshape_3d(ctx0, embeddings, embeddings->ne[0] * embeddings->ne[0], hidden_size, batch_size);
  406. embeddings = ggml_cont(ctx0, ggml_transpose(ctx0, embeddings));
  407. // apply norm before projection
  408. embeddings = ggml_rms_norm(ctx0, embeddings, eps);
  409. embeddings = ggml_mul(ctx0, embeddings, model.mm_soft_emb_norm_w);
  410. // apply projection
  411. embeddings = ggml_mul_mat(ctx0,
  412. ggml_cont(ctx0, ggml_transpose(ctx0, model.mm_input_proj_w)),
  413. embeddings);
  414. } else if (ctx->proj_type == PROJECTOR_TYPE_IDEFICS3) {
  415. // https://github.com/huggingface/transformers/blob/0a950e0bbe1ed58d5401a6b547af19f15f0c195e/src/transformers/models/idefics3/modeling_idefics3.py#L578
  416. ggml_tensor * cur = embeddings;
  417. const int scale_factor = model.hparams.proj_scale_factor;
  418. const int n_embd = cur->ne[0];
  419. const int seq = cur->ne[1];
  420. const int bsz = 1; // batch size, always 1 for now since we don't support batching
  421. const int height = std::sqrt(seq);
  422. const int width = std::sqrt(seq);
  423. GGML_ASSERT(scale_factor != 0);
  424. cur = ggml_reshape_4d(ctx0, cur, n_embd * scale_factor, width / scale_factor, height, bsz);
  425. cur = ggml_permute(ctx0, cur, 0, 2, 1, 3);
  426. cur = ggml_reshape_4d(ctx0, ggml_cont(ctx0, cur),
  427. n_embd * scale_factor * scale_factor,
  428. height / scale_factor,
  429. width / scale_factor,
  430. bsz);
  431. cur = ggml_permute(ctx0, cur, 0, 2, 1, 3);
  432. cur = ggml_reshape_3d(ctx0, ggml_cont(ctx0, cur),
  433. n_embd * scale_factor * scale_factor,
  434. seq / (scale_factor * scale_factor),
  435. bsz);
  436. cur = ggml_mul_mat(ctx0, model.projection, cur);
  437. embeddings = cur;
  438. } else {
  439. GGML_ABORT("SigLIP: Unsupported projector type");
  440. }
  441. // build the graph
  442. ggml_build_forward_expand(gf, embeddings);
  443. return gf;
  444. }
  445. // implementation of the 2D RoPE without adding a new op in ggml
  446. // this is not efficient (use double the memory), but works on all backends
  447. // TODO: there was a more efficient which relies on ggml_view and ggml_rope_ext_inplace, but the rope inplace does not work well with non-contiguous tensors ; we should fix that and revert back to the original implementation in https://github.com/ggml-org/llama.cpp/pull/13065
  448. static ggml_tensor * build_rope_2d(
  449. ggml_context * ctx0,
  450. ggml_tensor * cur,
  451. ggml_tensor * pos_h,
  452. ggml_tensor * pos_w,
  453. const float freq_base
  454. ) {
  455. const int64_t n_dim = cur->ne[0];
  456. const int64_t n_head = cur->ne[1];
  457. const int64_t n_pos = cur->ne[2];
  458. // for example, if we have cur tensor of shape (n_dim=8, n_head, n_pos)
  459. // we will have a list of 4 inv_freq: 1e-0, 1e-1, 1e-2, 1e-3
  460. // first half of cur will use 1e-0, 1e-2 (even)
  461. // second half of cur will use 1e-1, 1e-3 (odd)
  462. // the trick here is to rotate just half of n_dim, so inv_freq will automatically be even
  463. // ^ don't ask me why, it's math! -2(2i) / n_dim == -2i / (n_dim/2)
  464. // then for the second half, we use freq_scale to shift the inv_freq
  465. // ^ why? replace (2i) with (2i+1) in the above equation
  466. const float freq_scale_odd = std::pow(freq_base, (float)-2/n_dim);
  467. // first half
  468. ggml_tensor * first;
  469. {
  470. first = ggml_view_3d(ctx0, cur,
  471. n_dim/2, n_head, n_pos,
  472. ggml_row_size(cur->type, n_dim),
  473. ggml_row_size(cur->type, n_dim*n_head),
  474. 0);
  475. first = ggml_rope_ext(
  476. ctx0,
  477. first,
  478. pos_h, // positions
  479. nullptr, // freq factors
  480. n_dim/2, // n_dims
  481. 0, 0, freq_base,
  482. 1.0f, 0.0f, 1.0f, 0.0f, 0.0f
  483. );
  484. }
  485. // second half
  486. ggml_tensor * second;
  487. {
  488. second = ggml_view_3d(ctx0, cur,
  489. n_dim/2, n_head, n_pos,
  490. ggml_row_size(cur->type, n_dim),
  491. ggml_row_size(cur->type, n_dim*n_head),
  492. n_dim/2 * ggml_element_size(cur));
  493. second = ggml_cont(ctx0, second); // copy, because ggml_rope don't play well with non-contiguous tensors
  494. second = ggml_rope_ext(
  495. ctx0,
  496. second,
  497. pos_w, // positions
  498. nullptr, // freq factors
  499. n_dim/2, // n_dims
  500. 0, 0, freq_base,
  501. freq_scale_odd,
  502. 0.0f, 1.0f, 0.0f, 0.0f
  503. );
  504. }
  505. cur = ggml_concat(ctx0, first, second, 0);
  506. return cur;
  507. }
  508. static ggml_cgraph * clip_image_build_graph_pixtral(clip_ctx * ctx, const clip_image_f32 & img) {
  509. const auto & model = ctx->vision_model;
  510. const auto & hparams = model.hparams;
  511. GGML_ASSERT(ctx->proj_type == PROJECTOR_TYPE_PIXTRAL);
  512. int image_size_width = img.nx;
  513. int image_size_height = img.ny;
  514. const int patch_size = hparams.patch_size;
  515. const int n_patches_x = image_size_width / patch_size;
  516. const int n_patches_y = image_size_height / patch_size;
  517. const int num_patches = n_patches_x * n_patches_y;
  518. const int hidden_size = hparams.hidden_size;
  519. const int n_head = hparams.n_head;
  520. const int d_head = hidden_size / n_head;
  521. const int n_layer = hparams.n_layer;
  522. const float eps = hparams.eps;
  523. struct ggml_init_params params = {
  524. /*.mem_size =*/ ctx->buf_compute_meta.size(),
  525. /*.mem_buffer =*/ ctx->buf_compute_meta.data(),
  526. /*.no_alloc =*/ true,
  527. };
  528. ggml_context_ptr ctx0_ptr(ggml_init(params));
  529. auto ctx0 = ctx0_ptr.get();
  530. struct ggml_cgraph * gf = ggml_new_graph(ctx0);
  531. // input raw
  532. struct ggml_tensor * inp_raw = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, image_size_width, image_size_height, 3);
  533. ggml_set_name(inp_raw, "inp_raw");
  534. ggml_set_input(inp_raw);
  535. // 2D input positions
  536. struct ggml_tensor * pos_h = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, num_patches);
  537. ggml_set_name(pos_h, "pos_h");
  538. ggml_set_input(pos_h);
  539. struct ggml_tensor * pos_w = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, num_patches);
  540. ggml_set_name(pos_w, "pos_w");
  541. ggml_set_input(pos_w);
  542. struct ggml_tensor * inp = ggml_conv_2d(ctx0, model.patch_embeddings_0, inp_raw, patch_size, patch_size, 0, 0, 1, 1);
  543. inp = ggml_reshape_2d(ctx0, inp, num_patches, hidden_size);
  544. inp = ggml_cont(ctx0, ggml_transpose(ctx0, inp));
  545. struct ggml_tensor * embeddings = inp;
  546. // pre-layer norm
  547. embeddings = ggml_mul(ctx0, ggml_rms_norm(ctx0, embeddings, eps), model.pre_ln_w);
  548. // loop over layers
  549. for (int il = 0; il < n_layer; il++) {
  550. struct ggml_tensor * cur = embeddings;
  551. // pre-attention norm
  552. cur = ggml_mul(ctx0, ggml_rms_norm(ctx0, cur, eps), model.layers[il].ln_1_w);
  553. // self-attention
  554. {
  555. struct ggml_tensor * Q = ggml_mul_mat(ctx0, model.layers[il].q_w, cur);
  556. Q = ggml_reshape_3d(ctx0, Q, d_head, n_head, num_patches);
  557. Q = build_rope_2d(ctx0, Q, pos_h, pos_w, hparams.rope_theta);
  558. Q = ggml_cont(ctx0, ggml_permute(ctx0, Q, 0, 2, 1, 3));
  559. struct ggml_tensor * K = ggml_mul_mat(ctx0, model.layers[il].k_w, cur);
  560. K = ggml_reshape_3d(ctx0, K, d_head, n_head, num_patches);
  561. K = build_rope_2d(ctx0, K, pos_h, pos_w, hparams.rope_theta);
  562. K = ggml_cont(ctx0, ggml_permute(ctx0, K, 0, 2, 1, 3));
  563. struct ggml_tensor * V = ggml_mul_mat(ctx0, model.layers[il].v_w, cur);
  564. V = ggml_reshape_3d(ctx0, V, d_head, n_head, num_patches);
  565. V = ggml_cont(ctx0, ggml_permute(ctx0, V, 1, 2, 0, 3));
  566. struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q);
  567. KQ = ggml_soft_max_ext(ctx0, KQ, nullptr, 1.0f / sqrtf((float)d_head), 0.0f);
  568. struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ);
  569. KQV = ggml_reshape_3d(ctx0, KQV, d_head, num_patches, n_head);
  570. KQV = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
  571. cur = ggml_cont_2d(ctx0, KQV, hidden_size, num_patches);
  572. cur = ggml_mul_mat(ctx0, model.layers[il].o_w, cur);
  573. }
  574. // re-add the layer input, e.g., residual
  575. cur = ggml_add(ctx0, cur, embeddings);
  576. embeddings = cur; // embeddings = residual, cur = hidden_states
  577. // pre-ffn norm
  578. cur = ggml_mul(ctx0, ggml_rms_norm(ctx0, cur, eps), model.layers[il].ln_2_w);
  579. // feed-forward
  580. {
  581. ggml_tensor * gate_proj = ggml_mul_mat(ctx0, model.layers[il].ff_gate_w, cur);
  582. ggml_tensor * up_proj = ggml_mul_mat(ctx0, model.layers[il].ff_up_w, cur);
  583. gate_proj = ggml_silu(ctx0, gate_proj); // pixtral uses silu
  584. cur = ggml_mul(ctx0, up_proj, gate_proj);
  585. cur = ggml_mul_mat(ctx0, model.layers[il].ff_down_w, cur);
  586. }
  587. // residual 2
  588. cur = ggml_add(ctx0, embeddings, cur);
  589. embeddings = cur;
  590. }
  591. // LlavaMultiModalProjector (with GELU activation)
  592. {
  593. embeddings = ggml_mul_mat(ctx0, model.mm_1_w, embeddings);
  594. embeddings = ggml_add(ctx0, embeddings, model.mm_1_b);
  595. embeddings = ggml_gelu(ctx0, embeddings);
  596. embeddings = ggml_mul_mat(ctx0, model.mm_2_w, embeddings);
  597. embeddings = ggml_add(ctx0, embeddings, model.mm_2_b);
  598. }
  599. // arrangement of the [IMG_BREAK] token
  600. {
  601. // not efficient, but works
  602. // the trick is to view the embeddings as a 3D tensor with shape [hidden_size, n_patches_per_row, n_rows]
  603. // and then concatenate the [IMG_BREAK] token to the end of each row, aka n_patches_per_row dimension
  604. // after the concatenation, we have a tensor with shape [hidden_size, n_patches_per_row + 1, n_rows]
  605. const int n_embd_text = embeddings->ne[0];
  606. const int n_tokens_output = num_patches + n_patches_y - 1; // one [IMG_BREAK] per row, except the last row
  607. ggml_tensor * cur = ggml_reshape_3d(ctx0, embeddings, n_embd_text, n_patches_x, n_patches_y);
  608. ggml_tensor * tok = ggml_new_tensor_3d(ctx0, embeddings->type, n_embd_text, 1, n_patches_y);
  609. tok = ggml_scale(ctx0, tok, 0.0); // clear the tensor
  610. tok = ggml_add(ctx0, tok, model.token_embd_img_break);
  611. cur = ggml_concat(ctx0, cur, tok, 1);
  612. embeddings = ggml_view_2d(ctx0, cur,
  613. n_embd_text, n_tokens_output,
  614. ggml_row_size(cur->type, n_embd_text), 0);
  615. }
  616. // build the graph
  617. ggml_build_forward_expand(gf, embeddings);
  618. return gf;
  619. }
  620. static ggml_cgraph * clip_image_build_graph_legacy(clip_ctx * ctx, const clip_image_f32_batch & imgs, struct clip_image_size load_image_size, bool is_inf = false) {
  621. const auto & model = ctx->vision_model;
  622. const auto & hparams = model.hparams;
  623. const int image_size = hparams.image_size;
  624. int image_size_width = image_size;
  625. int image_size_height = image_size;
  626. if (ctx->proj_type == PROJECTOR_TYPE_MINICPMV) {
  627. LOG_DBG("%s: %d %d\n", __func__, load_image_size.width, load_image_size.height);
  628. image_size_width = load_image_size.width;
  629. image_size_height = load_image_size.height;
  630. if (is_inf) {
  631. image_size_width = imgs.entries[0]->nx;
  632. image_size_height = imgs.entries[0]->ny;
  633. }
  634. }
  635. else if (ctx->proj_type == PROJECTOR_TYPE_QWEN2VL) {
  636. // use the image's native resolution when image is avaible
  637. if (is_inf) {
  638. // if (imgs->data->nx && imgs->data->ny) {
  639. image_size_width = imgs.entries[0]->nx;
  640. image_size_height = imgs.entries[0]->ny;
  641. }
  642. }
  643. const int patch_size = hparams.patch_size;
  644. const int num_patches = ((image_size_width / patch_size) * (image_size_height / patch_size));
  645. const int patches_w = image_size_width / patch_size;
  646. const int patches_h = image_size_height / patch_size;
  647. const int num_positions = num_patches + (model.class_embedding ? 1 : 0);
  648. const int num_position_ids = ctx->proj_type == PROJECTOR_TYPE_QWEN2VL ? num_positions * 4 : num_positions;
  649. const int hidden_size = hparams.hidden_size;
  650. const int n_head = hparams.n_head;
  651. const int d_head = hidden_size / n_head;
  652. const float eps = hparams.eps;
  653. int mrope_sections[4] = {d_head/4, d_head/4, d_head/4, d_head/4};
  654. const int batch_size = imgs.entries.size();
  655. if (ctx->has_llava_projector
  656. || ctx->proj_type == PROJECTOR_TYPE_MINICPMV
  657. || ctx->proj_type == PROJECTOR_TYPE_GLM_EDGE) {
  658. GGML_ASSERT(batch_size == 1);
  659. }
  660. struct ggml_init_params params = {
  661. /*.mem_size =*/ ctx->buf_compute_meta.size(),
  662. /*.mem_buffer =*/ ctx->buf_compute_meta.data(),
  663. /*.no_alloc =*/ true,
  664. };
  665. ggml_context_ptr ctx0_ptr(ggml_init(params));
  666. auto ctx0 = ctx0_ptr.get();
  667. struct ggml_cgraph * gf = ggml_new_graph(ctx0);
  668. struct ggml_tensor * inp_raw = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, image_size_width, image_size_height, 3, batch_size);
  669. ggml_set_name(inp_raw, "inp_raw");
  670. ggml_set_input(inp_raw);
  671. struct ggml_tensor * inp = ggml_conv_2d(ctx0, model.patch_embeddings_0, inp_raw, patch_size, patch_size, 0, 0, 1, 1);
  672. if (ctx->proj_type == PROJECTOR_TYPE_QWEN2VL) {
  673. GGML_ASSERT(image_size_width % (patch_size * 2) == 0);
  674. GGML_ASSERT(image_size_height % (patch_size * 2) == 0);
  675. auto inp_1 = ggml_conv_2d(ctx0, model.patch_embeddings_1, inp_raw, patch_size, patch_size, 0, 0, 1, 1);
  676. inp = ggml_add(ctx0, inp, inp_1);
  677. inp = ggml_cont(ctx0, ggml_permute(ctx0, inp, 1, 2, 0, 3)); // [w, h, c, b] -> [c, w, h, b]
  678. inp = ggml_reshape_4d(
  679. ctx0, inp,
  680. hidden_size * 2, patches_w / 2, patches_h, batch_size);
  681. inp = ggml_reshape_4d(
  682. ctx0, inp,
  683. hidden_size * 2, patches_w / 2, 2, batch_size * (patches_h / 2));
  684. inp = ggml_cont(ctx0, ggml_permute(ctx0, inp, 0, 2, 1, 3));
  685. inp = ggml_reshape_3d(
  686. ctx0, inp,
  687. hidden_size, patches_w * patches_h, batch_size);
  688. }
  689. else {
  690. inp = ggml_reshape_3d(ctx0, inp, num_patches, hidden_size, batch_size);
  691. inp = ggml_cont(ctx0, ggml_permute(ctx0, inp, 1, 0, 2, 3));
  692. }
  693. if (model.patch_bias) {
  694. // inp = ggml_add(ctx0, inp, ggml_repeat(ctx0, model.patch_bias, inp));
  695. inp = ggml_add(ctx0, inp, model.patch_bias);
  696. }
  697. struct ggml_tensor * embeddings = inp;
  698. struct ggml_tensor * pos_embed = nullptr;
  699. // concat class_embeddings and patch_embeddings
  700. if (model.class_embedding) {
  701. embeddings = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, hidden_size, num_positions, batch_size);
  702. embeddings = ggml_scale(ctx0, embeddings, 0.0f); // set to all zeros
  703. embeddings = ggml_acc(ctx0, embeddings, model.class_embedding,
  704. embeddings->nb[1], embeddings->nb[2], embeddings->nb[3], 0);
  705. embeddings = ggml_acc(ctx0, embeddings, inp,
  706. embeddings->nb[1], embeddings->nb[2], embeddings->nb[3], model.class_embedding->nb[1]);
  707. }
  708. struct ggml_tensor * positions = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, num_position_ids);
  709. ggml_set_name(positions, "positions");
  710. ggml_set_input(positions);
  711. if (ctx->proj_type != PROJECTOR_TYPE_QWEN2VL) { // qwen2vl does NOT use learned position embeddings
  712. embeddings =
  713. ggml_add(ctx0, embeddings, ggml_get_rows(ctx0, model.position_embeddings, positions));
  714. }
  715. if (ctx->proj_type == PROJECTOR_TYPE_MINICPMV) {
  716. int pos_w = image_size_width/patch_size;
  717. int pos_h = image_size_height/patch_size;
  718. if (ctx->minicpmv_version == 2) {
  719. pos_embed = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, 4096, pos_w * pos_h, 1);
  720. }
  721. else if (ctx->minicpmv_version == 3) {
  722. pos_embed = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, 3584, pos_w * pos_h, 1);
  723. }
  724. else if (ctx->minicpmv_version == 4) {
  725. pos_embed = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, 3584, pos_w * pos_h, 1);
  726. }
  727. ggml_set_name(pos_embed, "pos_embed");
  728. ggml_set_input(pos_embed);
  729. }
  730. // pre-layernorm
  731. if (model.pre_ln_w) {
  732. embeddings = ggml_norm(ctx0, embeddings, eps);
  733. ggml_set_name(embeddings, "pre_ln");
  734. embeddings = ggml_add(ctx0, ggml_mul(ctx0, embeddings, model.pre_ln_w), model.pre_ln_b);
  735. }
  736. std::vector<struct ggml_tensor *> embedding_stack;
  737. const auto & vision_feature_layer = hparams.vision_feature_layer;
  738. // loop over layers
  739. for (int il = 0; il < ctx->max_feature_layer; il++) {
  740. struct ggml_tensor * cur = embeddings; // embeddings = residual, cur = hidden_states
  741. // If this is an embedding feature layer, save the output.
  742. // NOTE: 0 index here refers to the input to the encoder.
  743. if (vision_feature_layer.find(il) != vision_feature_layer.end()) {
  744. embedding_stack.push_back(embeddings);
  745. }
  746. //const size_t nb_q_w = model.layers[il].q_w->nb[0];
  747. // layernorm1
  748. {
  749. cur = ggml_norm(ctx0, cur, eps);
  750. cur = ggml_add(ctx0, ggml_mul(ctx0, cur, model.layers[il].ln_1_w),
  751. model.layers[il].ln_1_b);
  752. }
  753. // self-attention
  754. {
  755. struct ggml_tensor * Q =
  756. ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].q_w, cur), model.layers[il].q_b);
  757. Q = ggml_reshape_4d(ctx0, Q, d_head, n_head, num_positions, batch_size);
  758. if (ctx->proj_type == PROJECTOR_TYPE_QWEN2VL) {
  759. Q = ggml_rope_multi(
  760. ctx0, Q, positions, nullptr,
  761. d_head/2, mrope_sections, GGML_ROPE_TYPE_VISION, 32768, 10000, 1, 0, 1, 32, 1);
  762. }
  763. Q = ggml_cont(ctx0, ggml_permute(ctx0, Q, 0, 2, 1, 3));
  764. Q = ggml_reshape_3d(ctx0, Q, d_head, num_positions, n_head * batch_size);
  765. struct ggml_tensor * K =
  766. ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].k_w, cur), model.layers[il].k_b);
  767. K = ggml_reshape_4d(ctx0, K, d_head, n_head, num_positions, batch_size);
  768. if (ctx->proj_type == PROJECTOR_TYPE_QWEN2VL) {
  769. K = ggml_rope_multi(
  770. ctx0, K, positions, nullptr,
  771. d_head/2, mrope_sections, GGML_ROPE_TYPE_VISION, 32768, 10000, 1, 0, 1, 32, 1);
  772. }
  773. K = ggml_cont(ctx0, ggml_permute(ctx0, K, 0, 2, 1, 3));
  774. K = ggml_reshape_3d(ctx0, K, d_head, num_positions, n_head * batch_size);
  775. struct ggml_tensor * V =
  776. ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].v_w, cur), model.layers[il].v_b);
  777. V = ggml_reshape_4d(ctx0, V, d_head, n_head, num_positions, batch_size);
  778. V = ggml_cont(ctx0, ggml_permute(ctx0, V, 1, 2, 0, 3));
  779. V = ggml_reshape_3d(ctx0, V, num_positions, d_head, n_head * batch_size);
  780. struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q);
  781. KQ = ggml_soft_max_ext(ctx0, KQ, nullptr, 1.0f / sqrtf((float)d_head), 0.0f);
  782. struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ);
  783. KQV = ggml_reshape_4d(ctx0, KQV, d_head, num_positions, n_head, batch_size);
  784. KQV = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
  785. cur = ggml_cont_3d(ctx0, KQV, hidden_size, num_positions, batch_size);
  786. }
  787. // attention output
  788. cur = ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].o_w, cur), model.layers[il].o_b);
  789. // re-add the layer input, e.g., residual
  790. cur = ggml_add(ctx0, cur, embeddings);
  791. embeddings = cur; // embeddings = residual, cur = hidden_states
  792. // layernorm2
  793. {
  794. cur = ggml_norm(ctx0, cur, eps);
  795. cur = ggml_add(ctx0, ggml_mul(ctx0, cur, model.layers[il].ln_2_w), model.layers[il].ln_2_b);
  796. }
  797. cur = ggml_mul_mat(ctx0, model.layers[il].ff_i_w, cur);
  798. cur = ggml_add(ctx0, cur, model.layers[il].ff_i_b);
  799. if (ctx->use_gelu) {
  800. cur = ggml_gelu_inplace(ctx0, cur);
  801. } else if (ctx->use_silu) {
  802. cur = ggml_silu_inplace(ctx0, cur);
  803. } else {
  804. cur = ggml_gelu_quick_inplace(ctx0, cur);
  805. }
  806. cur = ggml_mul_mat(ctx0, model.layers[il].ff_o_w, cur);
  807. cur = ggml_add(ctx0, cur, model.layers[il].ff_o_b);
  808. // residual 2
  809. cur = ggml_add(ctx0, embeddings, cur);
  810. embeddings = cur;
  811. }
  812. // post-layernorm
  813. if (model.post_ln_w) {
  814. embeddings = ggml_norm(ctx0, embeddings, eps);
  815. ggml_set_name(embeddings, "post_ln");
  816. embeddings = ggml_add(ctx0, ggml_mul(ctx0, embeddings, model.post_ln_w), model.post_ln_b);
  817. }
  818. // final layer is a vision feature layer
  819. if (vision_feature_layer.find(ctx->max_feature_layer) != vision_feature_layer.end()) {
  820. embedding_stack.push_back(embeddings);
  821. }
  822. // If feature layers are explicitly set, stack them (if we have multiple)
  823. if (!embedding_stack.empty()) {
  824. embeddings = embedding_stack[0];
  825. for (size_t i = 1; i < embedding_stack.size(); i++) {
  826. embeddings = ggml_concat(ctx0, embeddings, embedding_stack[i], 0);
  827. }
  828. }
  829. // llava projector
  830. if (ctx->has_llava_projector) {
  831. embeddings = ggml_reshape_2d(ctx0, embeddings, embeddings->ne[0], embeddings->ne[1]);
  832. struct ggml_tensor * patches = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, num_patches);
  833. ggml_set_name(patches, "patches");
  834. ggml_set_input(patches);
  835. // shape [1, 576, 1024]
  836. // ne is whcn, ne = [1024, 576, 1, 1]
  837. embeddings = ggml_get_rows(ctx0, embeddings, patches);
  838. // print_tensor_info(embeddings, "embeddings");
  839. // llava projector
  840. if (ctx->proj_type == PROJECTOR_TYPE_MLP) {
  841. embeddings = ggml_mul_mat(ctx0, model.mm_0_w, embeddings);
  842. embeddings = ggml_add(ctx0, embeddings, model.mm_0_b);
  843. embeddings = ggml_gelu(ctx0, embeddings);
  844. if (model.mm_2_w) {
  845. embeddings = ggml_mul_mat(ctx0, model.mm_2_w, embeddings);
  846. embeddings = ggml_add(ctx0, embeddings, model.mm_2_b);
  847. }
  848. }
  849. else if (ctx->proj_type == PROJECTOR_TYPE_MLP_NORM) {
  850. embeddings = ggml_mul_mat(ctx0, model.mm_0_w, embeddings);
  851. embeddings = ggml_add(ctx0, embeddings, model.mm_0_b);
  852. // ggml_tensor_printf(embeddings, "mm_0_w",0,true,false);
  853. // First LayerNorm
  854. embeddings = ggml_norm(ctx0, embeddings, eps);
  855. embeddings = ggml_add(ctx0, ggml_mul(ctx0, embeddings, model.mm_1_w),
  856. model.mm_1_b);
  857. // GELU activation
  858. embeddings = ggml_gelu(ctx0, embeddings);
  859. // Second linear layer
  860. embeddings = ggml_mul_mat(ctx0, model.mm_3_w, embeddings);
  861. embeddings = ggml_add(ctx0, embeddings, model.mm_3_b);
  862. // Second LayerNorm
  863. embeddings = ggml_norm(ctx0, embeddings, eps);
  864. embeddings = ggml_add(ctx0, ggml_mul(ctx0, embeddings, model.mm_4_w),
  865. model.mm_4_b);
  866. }
  867. else if (ctx->proj_type == PROJECTOR_TYPE_LDP) {
  868. // MobileVLM projector
  869. int n_patch = 24;
  870. struct ggml_tensor * mlp_1 = ggml_mul_mat(ctx0, model.mm_model_mlp_1_w, embeddings);
  871. mlp_1 = ggml_add(ctx0, mlp_1, model.mm_model_mlp_1_b);
  872. mlp_1 = ggml_gelu(ctx0, mlp_1);
  873. struct ggml_tensor * mlp_3 = ggml_mul_mat(ctx0, model.mm_model_mlp_3_w, mlp_1);
  874. mlp_3 = ggml_add(ctx0, mlp_3, model.mm_model_mlp_3_b);
  875. // mlp_3 shape = [1, 576, 2048], ne = [2048, 576, 1, 1]
  876. // block 1
  877. struct ggml_tensor * block_1 = nullptr;
  878. {
  879. // transpose from [1, 576, 2048] --> [1, 2048, 576] --> [1, 2048, 24, 24]
  880. mlp_3 = ggml_cont(ctx0, ggml_permute(ctx0, mlp_3, 1, 0, 2, 3));
  881. mlp_3 = ggml_reshape_4d(ctx0, mlp_3, n_patch, n_patch, mlp_3->ne[1], mlp_3->ne[2]);
  882. // stride = 1, padding = 1, bias is nullptr
  883. block_1 = ggml_conv_2d_dw(ctx0, model.mm_model_block_1_block_0_0_w, mlp_3, 1, 1, 1, 1, 1, 1);
  884. // layer norm
  885. // // block_1 shape = [1, 2048, 24, 24], ne = [24, 24, 2048, 1]
  886. block_1 = ggml_cont(ctx0, ggml_permute(ctx0, block_1, 1, 2, 0, 3));
  887. // block_1 shape = [1, 24, 24, 2048], ne = [2048, 24, 24, 1]
  888. block_1 = ggml_norm(ctx0, block_1, eps);
  889. block_1 = ggml_add(ctx0, ggml_mul(ctx0, block_1, model.mm_model_block_1_block_0_1_w), model.mm_model_block_1_block_0_1_b);
  890. block_1 = ggml_cont(ctx0, ggml_permute(ctx0, block_1, 2, 0, 1, 3));
  891. // block_1 shape = [1, 2048, 24, 24], ne = [24, 24, 2048, 1]
  892. // hardswish
  893. struct ggml_tensor * block_1_hw = ggml_hardswish(ctx0, block_1);
  894. block_1 = ggml_pool_2d(ctx0, block_1_hw, GGML_OP_POOL_AVG, block_1_hw->ne[0], block_1_hw->ne[1], block_1_hw->ne[0], block_1_hw->ne[1], 0, 0);
  895. // block_1 shape = [1, 2048, 1, 1], ne = [1, 1, 2048, 1]
  896. // pointwise conv
  897. block_1 = ggml_reshape_2d(ctx0, block_1, block_1->ne[0]*block_1->ne[1]*block_1->ne[2], block_1->ne[3]);
  898. block_1 = ggml_mul_mat(ctx0, model.mm_model_block_1_block_1_fc1_w, block_1);
  899. block_1 = ggml_add(ctx0, block_1, model.mm_model_block_1_block_1_fc1_b);
  900. block_1 = ggml_relu(ctx0, block_1);
  901. block_1 = ggml_mul_mat(ctx0, model.mm_model_block_1_block_1_fc2_w, block_1);
  902. block_1 = ggml_add(ctx0, block_1, model.mm_model_block_1_block_1_fc2_b);
  903. block_1 = ggml_hardsigmoid(ctx0, block_1);
  904. // block_1_hw shape = [1, 2048, 24, 24], ne = [24, 24, 2048, 1], block_1 shape = [1, 2048], ne = [2048, 1, 1, 1]
  905. block_1 = ggml_reshape_4d(ctx0, block_1, 1, 1, block_1->ne[0], block_1->ne[1]);
  906. block_1 = ggml_mul(ctx0, block_1_hw, block_1);
  907. int w = block_1->ne[0], h = block_1->ne[1];
  908. block_1 = ggml_reshape_3d(ctx0, block_1, w*h, block_1->ne[2], block_1->ne[3]);
  909. block_1 = ggml_cont(ctx0, ggml_permute(ctx0, block_1, 1, 0, 2, 3));
  910. // block_1 shape = [1, 24*24, 2048], ne = [24*24, 2048, 1]
  911. block_1 = ggml_mul_mat(ctx0, model.mm_model_block_1_block_2_0_w, block_1);
  912. block_1 = ggml_reshape_4d(ctx0, block_1, block_1->ne[0], w, h, block_1->ne[3]);
  913. // block_1 shape = [1, 24, 24, 2048], ne = [2048, 24, 24, 1]
  914. block_1 = ggml_norm(ctx0, block_1, eps);
  915. block_1 = ggml_add(ctx0, ggml_mul(ctx0, block_1, model.mm_model_block_1_block_2_1_w), model.mm_model_block_1_block_2_1_b);
  916. block_1 = ggml_cont(ctx0, ggml_permute(ctx0, block_1, 2, 0, 1, 3));
  917. // block1 shape = [1, 2048, 24, 24], ne = [24, 24, 2048, 1]
  918. // residual
  919. block_1 = ggml_add(ctx0, mlp_3, block_1);
  920. }
  921. // block_2
  922. {
  923. // stride = 2
  924. block_1 = ggml_conv_2d_dw(ctx0, model.mm_model_block_2_block_0_0_w, block_1, 2, 2, 1, 1, 1, 1);
  925. // block_1 shape = [1, 2048, 12, 12], ne = [12, 12, 2048, 1]
  926. // layer norm
  927. block_1 = ggml_cont(ctx0, ggml_permute(ctx0, block_1, 1, 2, 0, 3));
  928. // block_1 shape = [1, 12, 12, 2048], ne = [2048, 12, 12, 1]
  929. block_1 = ggml_norm(ctx0, block_1, eps);
  930. block_1 = ggml_add(ctx0, ggml_mul(ctx0, block_1, model.mm_model_block_2_block_0_1_w), model.mm_model_block_2_block_0_1_b);
  931. block_1 = ggml_cont(ctx0, ggml_permute(ctx0, block_1, 2, 0, 1, 3));
  932. // block_1 shape = [1, 2048, 12, 12], ne = [12, 12, 2048, 1]
  933. // hardswish
  934. struct ggml_tensor * block_1_hw = ggml_hardswish(ctx0, block_1);
  935. // not sure the parameters is right for globalAvgPooling
  936. block_1 = ggml_pool_2d(ctx0, block_1_hw, GGML_OP_POOL_AVG, block_1_hw->ne[0], block_1_hw->ne[1], block_1_hw->ne[0], block_1_hw->ne[1], 0, 0);
  937. // block_1 shape = [1, 2048, 1, 1], ne = [1, 1, 2048, 1]
  938. // pointwise conv
  939. block_1 = ggml_reshape_2d(ctx0, block_1, block_1->ne[0]*block_1->ne[1]*block_1->ne[2], block_1->ne[3]);
  940. block_1 = ggml_mul_mat(ctx0, model.mm_model_block_2_block_1_fc1_w, block_1);
  941. block_1 = ggml_add(ctx0, block_1, model.mm_model_block_2_block_1_fc1_b);
  942. block_1 = ggml_relu(ctx0, block_1);
  943. block_1 = ggml_mul_mat(ctx0, model.mm_model_block_2_block_1_fc2_w, block_1);
  944. block_1 = ggml_add(ctx0, block_1, model.mm_model_block_2_block_1_fc2_b);
  945. block_1 = ggml_hardsigmoid(ctx0, block_1);
  946. // block_1_hw shape = [1, 2048, 12, 12], ne = [12, 12, 2048, 1], block_1 shape = [1, 2048, 1, 1], ne = [1, 1, 2048, 1]
  947. block_1 = ggml_reshape_4d(ctx0, block_1, 1, 1, block_1->ne[0], block_1->ne[1]);
  948. block_1 = ggml_mul(ctx0, block_1_hw, block_1);
  949. int w = block_1->ne[0], h = block_1->ne[1];
  950. block_1 = ggml_reshape_3d(ctx0, block_1, w*h, block_1->ne[2], block_1->ne[3]);
  951. block_1 = ggml_cont(ctx0, ggml_permute(ctx0, block_1, 1, 0, 2, 3));
  952. // block_1 shape = [1, 24*24, 2048], ne = [24*24, 2048, 1]
  953. block_1 = ggml_mul_mat(ctx0, model.mm_model_block_2_block_2_0_w, block_1);
  954. block_1 = ggml_reshape_4d(ctx0, block_1, block_1->ne[0], w, h, block_1->ne[3]);
  955. // block_1 shape = [1, 12, 12, 2048], ne = [2048, 12, 12, 1]
  956. block_1 = ggml_norm(ctx0, block_1, eps);
  957. block_1 = ggml_add(ctx0, ggml_mul(ctx0, block_1, model.mm_model_block_2_block_2_1_w), model.mm_model_block_2_block_2_1_b);
  958. block_1 = ggml_reshape_3d(ctx0, block_1, block_1->ne[0], block_1->ne[1] * block_1->ne[2], block_1->ne[3]);
  959. // block_1 shape = [1, 144, 2048], ne = [2048, 144, 1]
  960. }
  961. embeddings = block_1;
  962. }
  963. else if (ctx->proj_type == PROJECTOR_TYPE_LDPV2)
  964. {
  965. int n_patch = 24;
  966. struct ggml_tensor * mlp_0 = ggml_mul_mat(ctx0, model.mm_model_mlp_0_w, embeddings);
  967. mlp_0 = ggml_add(ctx0, mlp_0, model.mm_model_mlp_0_b);
  968. mlp_0 = ggml_gelu(ctx0, mlp_0);
  969. struct ggml_tensor * mlp_2 = ggml_mul_mat(ctx0, model.mm_model_mlp_2_w, mlp_0);
  970. mlp_2 = ggml_add(ctx0, mlp_2, model.mm_model_mlp_2_b);
  971. // mlp_2 ne = [2048, 576, 1, 1]
  972. // // AVG Pool Layer 2*2, strides = 2
  973. mlp_2 = ggml_cont(ctx0, ggml_permute(ctx0, mlp_2, 1, 0, 2, 3));
  974. // mlp_2 ne = [576, 2048, 1, 1]
  975. mlp_2 = ggml_reshape_4d(ctx0, mlp_2, n_patch, n_patch, mlp_2->ne[1], mlp_2->ne[2]);
  976. // mlp_2 ne [24, 24, 2048, 1]
  977. mlp_2 = ggml_pool_2d(ctx0, mlp_2, GGML_OP_POOL_AVG, 2, 2, 2, 2, 0, 0);
  978. // weight ne = [3, 3, 2048, 1]
  979. struct ggml_tensor * peg_0 = ggml_conv_2d_dw(ctx0, model.mm_model_peg_0_w, mlp_2, 1, 1, 1, 1, 1, 1);
  980. peg_0 = ggml_cont(ctx0, ggml_permute(ctx0, peg_0, 1, 2, 0, 3));
  981. peg_0 = ggml_add(ctx0, peg_0, model.mm_model_peg_0_b);
  982. mlp_2 = ggml_cont(ctx0, ggml_permute(ctx0, mlp_2, 1, 2, 0, 3));
  983. peg_0 = ggml_add(ctx0, peg_0, mlp_2);
  984. peg_0 = ggml_reshape_3d(ctx0, peg_0, peg_0->ne[0], peg_0->ne[1] * peg_0->ne[2], peg_0->ne[3]);
  985. embeddings = peg_0;
  986. }
  987. else {
  988. GGML_ABORT("fatal error");
  989. }
  990. }
  991. // minicpmv projector
  992. else if (ctx->proj_type == PROJECTOR_TYPE_MINICPMV) {
  993. struct ggml_tensor * q = model.mm_model_query;
  994. { // layernorm
  995. q = ggml_norm(ctx0, q, eps);
  996. q = ggml_add(ctx0, ggml_mul(ctx0, q, model.mm_model_ln_q_w), model.mm_model_ln_q_b);
  997. }
  998. struct ggml_tensor * v = ggml_mul_mat(ctx0, model.mm_model_kv_proj, embeddings);
  999. { // layernorm
  1000. v = ggml_norm(ctx0, v, eps);
  1001. v = ggml_add(ctx0, ggml_mul(ctx0, v, model.mm_model_ln_kv_w), model.mm_model_ln_kv_b);
  1002. }
  1003. struct ggml_tensor * k;
  1004. { // position
  1005. // q = ggml_add(ctx0, q, model.mm_model_pos_embed);
  1006. k = ggml_add(ctx0, v, pos_embed);
  1007. }
  1008. { // attention
  1009. int hidden_size = 4096;
  1010. const int d_head = 128;
  1011. int n_head = hidden_size/d_head;
  1012. int num_query = 96;
  1013. if (ctx->minicpmv_version == 2) {
  1014. hidden_size = 4096;
  1015. n_head = hidden_size/d_head;
  1016. num_query = 96;
  1017. }
  1018. else if (ctx->minicpmv_version == 3) {
  1019. hidden_size = 3584;
  1020. n_head = hidden_size/d_head;
  1021. num_query = 64;
  1022. }
  1023. else if (ctx->minicpmv_version == 4) {
  1024. hidden_size = 3584;
  1025. n_head = hidden_size/d_head;
  1026. num_query = 64;
  1027. }
  1028. struct ggml_tensor * Q = ggml_add(ctx0, ggml_mul_mat(ctx0, model.mm_model_attn_q_w, q), model.mm_model_attn_q_b);
  1029. struct ggml_tensor * K = ggml_add(ctx0, ggml_mul_mat(ctx0, model.mm_model_attn_k_w, k), model.mm_model_attn_k_b);
  1030. struct ggml_tensor * V = ggml_add(ctx0, ggml_mul_mat(ctx0, model.mm_model_attn_v_w, v), model.mm_model_attn_v_b);
  1031. // permute
  1032. Q = ggml_reshape_4d(ctx0, Q, d_head, n_head, num_query, batch_size);
  1033. Q = ggml_cont(ctx0, ggml_permute(ctx0, Q, 0, 2, 1, 3));
  1034. Q = ggml_reshape_3d(ctx0, Q, d_head, num_query, n_head * batch_size);
  1035. K = ggml_reshape_4d(ctx0, K, d_head, n_head, num_positions, batch_size);
  1036. K = ggml_cont(ctx0, ggml_permute(ctx0, K, 0, 2, 1, 3));
  1037. K = ggml_reshape_3d(ctx0, K, d_head, num_positions, n_head * batch_size);
  1038. V = ggml_reshape_4d(ctx0, V, d_head, n_head, num_positions, batch_size);
  1039. V = ggml_cont(ctx0, ggml_permute(ctx0, V, 1, 2, 0, 3));
  1040. V = ggml_reshape_3d(ctx0, V, num_positions, d_head, n_head * batch_size);
  1041. struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q);
  1042. KQ = ggml_soft_max_ext(ctx0, KQ, nullptr, 1.0f / sqrtf((float)d_head), 0.0f);
  1043. struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ);
  1044. KQV = ggml_reshape_4d(ctx0, KQV, d_head, num_query, n_head, batch_size);
  1045. KQV = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
  1046. KQV = ggml_cont_3d(ctx0, KQV, hidden_size, num_query, batch_size);
  1047. embeddings = ggml_add(ctx0, ggml_mul_mat(ctx0, model.mm_model_attn_o_w, KQV), model.mm_model_attn_o_b);
  1048. }
  1049. { // layernorm
  1050. embeddings = ggml_norm(ctx0, embeddings, eps);
  1051. embeddings = ggml_add(ctx0, ggml_mul(ctx0, embeddings, model.mm_model_ln_post_w), model.mm_model_ln_post_b);
  1052. }
  1053. embeddings = ggml_mul_mat(ctx0, model.mm_model_proj, embeddings);
  1054. }
  1055. // glm projector
  1056. else if (ctx->proj_type == PROJECTOR_TYPE_GLM_EDGE) {
  1057. size_t gridsz = (size_t)sqrt(embeddings->ne[1]);
  1058. embeddings = ggml_cont(ctx0, ggml_permute(ctx0,embeddings,1,0,2,3));
  1059. embeddings = ggml_reshape_3d(ctx0, embeddings, gridsz, gridsz, embeddings->ne[1]);
  1060. embeddings = ggml_conv_2d(ctx0, model.mm_model_adapter_conv_w, embeddings, 2, 2, 0, 0, 1, 1);
  1061. embeddings = ggml_reshape_3d(ctx0, embeddings,embeddings->ne[0]*embeddings->ne[1] , embeddings->ne[2], batch_size);
  1062. embeddings = ggml_cont(ctx0, ggml_permute(ctx0,embeddings, 1, 0, 2, 3));
  1063. embeddings = ggml_add(ctx0, embeddings, model.mm_model_adapter_conv_b);
  1064. // GLU
  1065. {
  1066. embeddings = ggml_mul_mat(ctx0, model.mm_model_mlp_0_w, embeddings);
  1067. embeddings = ggml_norm(ctx0, embeddings, eps);
  1068. embeddings = ggml_add(ctx0, ggml_mul(ctx0, embeddings, model.mm_model_ln_q_w), model.mm_model_ln_q_b);
  1069. embeddings = ggml_gelu_inplace(ctx0, embeddings);
  1070. struct ggml_tensor * x = embeddings;
  1071. embeddings = ggml_mul_mat(ctx0, model.mm_model_mlp_2_w, embeddings);
  1072. x = ggml_mul_mat(ctx0, model.mm_model_mlp_1_w,x);
  1073. embeddings = ggml_silu_inplace(ctx0, embeddings);
  1074. embeddings = ggml_mul(ctx0, embeddings,x);
  1075. embeddings = ggml_mul_mat(ctx0, model.mm_model_mlp_3_w, embeddings);
  1076. }
  1077. }
  1078. else if (ctx->proj_type == PROJECTOR_TYPE_QWEN2VL) {
  1079. embeddings = ggml_reshape_3d(ctx0, embeddings, hidden_size * 4, num_positions / 4, batch_size);
  1080. embeddings = ggml_mul_mat(ctx0, model.mm_0_w, embeddings);
  1081. embeddings = ggml_add(ctx0, embeddings, model.mm_0_b);
  1082. // GELU activation
  1083. embeddings = ggml_gelu(ctx0, embeddings);
  1084. // Second linear layer
  1085. embeddings = ggml_mul_mat(ctx0, model.mm_1_w, embeddings);
  1086. embeddings = ggml_add(ctx0, embeddings, model.mm_1_b);
  1087. }
  1088. // build the graph
  1089. ggml_build_forward_expand(gf, embeddings);
  1090. return gf;
  1091. }
  1092. static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32_batch & imgs, struct clip_image_size load_image_size, bool is_inf = false) {
  1093. ggml_cgraph * res;
  1094. switch (ctx->proj_type) {
  1095. case PROJECTOR_TYPE_GEMMA3:
  1096. case PROJECTOR_TYPE_IDEFICS3:
  1097. {
  1098. GGML_ASSERT(imgs.entries.size() == 1);
  1099. res = clip_image_build_graph_siglip(ctx, *imgs.entries[0]);
  1100. } break;
  1101. case PROJECTOR_TYPE_PIXTRAL:
  1102. {
  1103. GGML_ASSERT(imgs.entries.size() == 1);
  1104. res = clip_image_build_graph_pixtral(ctx, *imgs.entries[0]);
  1105. } break;
  1106. default:
  1107. {
  1108. // TODO: we should have one build_* function per model
  1109. res = clip_image_build_graph_legacy(ctx, imgs, load_image_size, is_inf);
  1110. } break;
  1111. }
  1112. return res;
  1113. }
  1114. struct clip_model_loader {
  1115. ggml_context_ptr ctx_meta;
  1116. gguf_context_ptr ctx_gguf;
  1117. clip_ctx & ctx_clip;
  1118. std::string fname;
  1119. size_t model_size; // in bytes
  1120. // TODO @ngxson : we should not pass clip_ctx here, it should be clip_vision_model
  1121. clip_model_loader(const char * fname, clip_ctx & ctx_clip) : ctx_clip(ctx_clip), fname(fname) {
  1122. struct ggml_context * meta = nullptr;
  1123. struct gguf_init_params params = {
  1124. /*.no_alloc = */ true,
  1125. /*.ctx = */ &meta,
  1126. };
  1127. ctx_gguf = gguf_context_ptr(gguf_init_from_file(fname, params));
  1128. if (!ctx_gguf.get()) {
  1129. throw std::runtime_error(string_format("%s: failed to load CLIP model from %s. Does this file exist?\n", __func__, fname));
  1130. }
  1131. ctx_meta.reset(meta);
  1132. const int n_tensors = gguf_get_n_tensors(ctx_gguf.get());
  1133. // print gguf info
  1134. {
  1135. std::string name;
  1136. get_string(KEY_NAME, name, false);
  1137. std::string description;
  1138. get_string(KEY_DESCRIPTION, description, false);
  1139. LOG_INF("%s: model name: %s\n", __func__, name.c_str());
  1140. LOG_INF("%s: description: %s\n", __func__, description.c_str());
  1141. LOG_INF("%s: GGUF version: %d\n", __func__, gguf_get_version(ctx_gguf.get()));
  1142. LOG_INF("%s: alignment: %zu\n", __func__, gguf_get_alignment(ctx_gguf.get()));
  1143. LOG_INF("%s: n_tensors: %d\n", __func__, n_tensors);
  1144. LOG_INF("%s: n_kv: %d\n", __func__, (int)gguf_get_n_kv(ctx_gguf.get()));
  1145. LOG_INF("\n");
  1146. }
  1147. // tensors
  1148. {
  1149. for (int i = 0; i < n_tensors; ++i) {
  1150. const char * name = gguf_get_tensor_name(ctx_gguf.get(), i);
  1151. const size_t offset = gguf_get_tensor_offset(ctx_gguf.get(), i);
  1152. enum ggml_type type = gguf_get_tensor_type(ctx_gguf.get(), i);
  1153. struct ggml_tensor * cur = ggml_get_tensor(meta, name);
  1154. size_t tensor_size = ggml_nbytes(cur);
  1155. model_size += tensor_size;
  1156. LOG_DBG("%s: tensor[%d]: n_dims = %d, name = %s, tensor_size=%zu, offset=%zu, shape:[%" PRIu64 ", %" PRIu64 ", %" PRIu64 ", %" PRIu64 "], type = %s\n",
  1157. __func__, i, ggml_n_dims(cur), cur->name, tensor_size, offset, cur->ne[0], cur->ne[1], cur->ne[2], cur->ne[3], ggml_type_name(type));
  1158. }
  1159. }
  1160. }
  1161. void load_hparams() {
  1162. auto & hparams = ctx_clip.vision_model.hparams;
  1163. // projector type
  1164. std::string proj_type;
  1165. {
  1166. get_string(KEY_PROJ_TYPE, proj_type, false);
  1167. if (!proj_type.empty()) {
  1168. ctx_clip.proj_type = clip_projector_type_from_string(proj_type);
  1169. }
  1170. if (ctx_clip.proj_type == PROJECTOR_TYPE_UNKNOWN) {
  1171. throw std::runtime_error(string_format("%s: unknown projector type: %s\n", __func__, proj_type.c_str()));
  1172. }
  1173. }
  1174. // other hparams
  1175. {
  1176. get_i32(KEY_MINICPMV_VERSION, ctx_clip.minicpmv_version, false);
  1177. get_bool(KEY_USE_GELU, ctx_clip.use_gelu, false);
  1178. get_bool(KEY_USE_SILU, ctx_clip.use_silu, false);
  1179. get_u32(KEY_N_EMBD, hparams.hidden_size);
  1180. get_u32(KEY_N_HEAD, hparams.n_head);
  1181. get_u32(KEY_N_FF, hparams.n_intermediate);
  1182. get_u32(KEY_N_BLOCK, hparams.n_layer);
  1183. get_u32(KEY_PROJ_DIM, hparams.projection_dim);
  1184. get_f32(KEY_LAYER_NORM_EPS, hparams.eps);
  1185. get_u32(KEY_IMAGE_SIZE, hparams.image_size);
  1186. get_u32(KEY_PATCH_SIZE, hparams.patch_size);
  1187. get_u32(KEY_IMAGE_CROP_RESOLUTION, hparams.image_crop_resolution, false);
  1188. get_arr_int(KEY_IMAGE_GRID_PINPOINTS, hparams.image_grid_pinpoints, false);
  1189. ctx_clip.has_llava_projector = ctx_clip.proj_type == PROJECTOR_TYPE_MLP
  1190. || ctx_clip.proj_type == PROJECTOR_TYPE_MLP_NORM
  1191. || ctx_clip.proj_type == PROJECTOR_TYPE_LDP
  1192. || ctx_clip.proj_type == PROJECTOR_TYPE_LDPV2;
  1193. {
  1194. std::string mm_patch_merge_type;
  1195. get_string(KEY_MM_PATCH_MERGE_TYPE, mm_patch_merge_type, false);
  1196. if (mm_patch_merge_type == "spatial_unpad") {
  1197. hparams.mm_patch_merge_type = PATCH_MERGE_SPATIAL_UNPAD;
  1198. }
  1199. }
  1200. {
  1201. int idx_mean = gguf_find_key(ctx_gguf.get(), KEY_IMAGE_MEAN);
  1202. int idx_std = gguf_find_key(ctx_gguf.get(), KEY_IMAGE_STD);
  1203. GGML_ASSERT(idx_mean >= 0 && "image_mean not found");
  1204. GGML_ASSERT(idx_std >= 0 && "image_std not found");
  1205. const float * mean_data = (const float *) gguf_get_arr_data(ctx_gguf.get(), idx_mean);
  1206. const float * std_data = (const float *) gguf_get_arr_data(ctx_gguf.get(), idx_std);
  1207. for (int i = 0; i < 3; ++i) {
  1208. ctx_clip.image_mean[i] = mean_data[i];
  1209. ctx_clip.image_std[i] = std_data[i];
  1210. }
  1211. }
  1212. // Load the vision feature layer indices if they are explicitly provided;
  1213. // if multiple vision feature layers are present, the values will be concatenated
  1214. // to form the final visual features.
  1215. // NOTE: gguf conversions should standardize the values of the vision feature layer to
  1216. // be non-negative, since we use -1 to mark values as unset here.
  1217. std::vector<int> vision_feature_layer;
  1218. get_arr_int(KEY_FEATURE_LAYER, vision_feature_layer, false);
  1219. // convert std::vector to std::unordered_set
  1220. for (auto & layer : vision_feature_layer) {
  1221. hparams.vision_feature_layer.insert(layer);
  1222. }
  1223. // Calculate the deepest feature layer based on hparams and projector type
  1224. // NOTE: This is only used by build_graph_legacy()
  1225. {
  1226. // Get the index of the second to last layer; this is the default for models that have a llava projector
  1227. int n_layer = hparams.n_layer - 1;
  1228. int deepest_feature_layer = -1;
  1229. if (ctx_clip.proj_type == PROJECTOR_TYPE_MINICPMV
  1230. || ctx_clip.proj_type == PROJECTOR_TYPE_GLM_EDGE
  1231. || ctx_clip.proj_type == PROJECTOR_TYPE_QWEN2VL) {
  1232. n_layer += 1;
  1233. }
  1234. // If we set explicit vision feature layers, only go up to the deepest one
  1235. // NOTE: only used by granite-vision models for now
  1236. for (const auto & feature_layer : hparams.vision_feature_layer) {
  1237. if (feature_layer > deepest_feature_layer) {
  1238. deepest_feature_layer = feature_layer;
  1239. }
  1240. }
  1241. ctx_clip.max_feature_layer = deepest_feature_layer < 0 ? n_layer : deepest_feature_layer;
  1242. }
  1243. // model-specific params
  1244. switch (ctx_clip.proj_type) {
  1245. case PROJECTOR_TYPE_MINICPMV:
  1246. {
  1247. if (ctx_clip.minicpmv_version == 0) {
  1248. ctx_clip.minicpmv_version = 2; // default to 2 if not set
  1249. }
  1250. } break;
  1251. case PROJECTOR_TYPE_IDEFICS3:
  1252. {
  1253. get_u32(KEY_PROJ_SCALE_FACTOR, hparams.proj_scale_factor, false);
  1254. } break;
  1255. case PROJECTOR_TYPE_PIXTRAL:
  1256. {
  1257. hparams.rope_theta = 10000.0f;
  1258. } break;
  1259. default:
  1260. break;
  1261. }
  1262. LOG_INF("%s: projector: %s\n", __func__, proj_type.c_str());
  1263. LOG_INF("%s: has_llava_proj: %d\n", __func__, ctx_clip.has_llava_projector);
  1264. LOG_INF("%s: minicpmv_version: %d\n", __func__, ctx_clip.minicpmv_version);
  1265. LOG_INF("%s: model size: %.2f MiB\n", __func__, model_size / 1024.0 / 1024.0);
  1266. LOG_INF("%s: metadata size: %.2f MiB\n", __func__, ggml_get_mem_size(ctx_meta.get()) / 1024.0 / 1024.0);
  1267. }
  1268. }
  1269. void load_tensors() {
  1270. std::map<std::string, size_t> tensor_offset;
  1271. std::vector<ggml_tensor *> tensors_to_load;
  1272. // get offsets
  1273. for (int64_t i = 0; i < gguf_get_n_tensors(ctx_gguf.get()); ++i) {
  1274. const char * name = gguf_get_tensor_name(ctx_gguf.get(), i);
  1275. tensor_offset[name] = gguf_get_data_offset(ctx_gguf.get()) + gguf_get_tensor_offset(ctx_gguf.get(), i);
  1276. }
  1277. // create data context
  1278. struct ggml_init_params params = {
  1279. /*.mem_size =*/ (gguf_get_n_tensors(ctx_gguf.get()) + 1) * ggml_tensor_overhead(),
  1280. /*.mem_buffer =*/ NULL,
  1281. /*.no_alloc =*/ true,
  1282. };
  1283. ctx_clip.ctx_data.reset(ggml_init(params));
  1284. if (!ctx_clip.ctx_data) {
  1285. throw std::runtime_error(string_format("%s: failed to init ggml context\n", __func__));
  1286. }
  1287. // helper function
  1288. auto get_tensor = [&](const std::string & name, bool required = true) {
  1289. struct ggml_tensor * cur = ggml_get_tensor(ctx_meta.get(), name.c_str());
  1290. if (!cur && required) {
  1291. throw std::runtime_error(string_format("%s: unable to find tensor %s\n", __func__, name.c_str()));
  1292. }
  1293. if (cur) {
  1294. tensors_to_load.push_back(cur);
  1295. // add tensors to context
  1296. struct ggml_tensor * data_tensor = ggml_dup_tensor(ctx_clip.ctx_data.get(), cur);
  1297. ggml_set_name(data_tensor, cur->name);
  1298. cur = data_tensor;
  1299. }
  1300. return cur;
  1301. };
  1302. auto & vision_model = ctx_clip.vision_model;
  1303. vision_model.class_embedding = get_tensor(TN_CLASS_EMBD, false);
  1304. vision_model.pre_ln_w = get_tensor(string_format(TN_LN_PRE, "v", "weight"), false);
  1305. vision_model.pre_ln_b = get_tensor(string_format(TN_LN_PRE, "v", "bias"), false);
  1306. vision_model.post_ln_w = get_tensor(string_format(TN_LN_POST, "v", "weight"), false);
  1307. vision_model.post_ln_b = get_tensor(string_format(TN_LN_POST, "v", "bias"), false);
  1308. vision_model.patch_bias = get_tensor(TN_PATCH_BIAS, false);
  1309. vision_model.patch_embeddings_0 = get_tensor(TN_PATCH_EMBD, false);
  1310. vision_model.patch_embeddings_1 = get_tensor(TN_PATCH_EMBD_1, false);
  1311. vision_model.position_embeddings = get_tensor(string_format(TN_POS_EMBD, "v"), false);
  1312. // layers
  1313. vision_model.layers.resize(vision_model.hparams.n_layer);
  1314. for (int il = 0; il < vision_model.hparams.n_layer; ++il) {
  1315. auto & layer = vision_model.layers[il];
  1316. layer.k_w = get_tensor(string_format(TN_ATTN_K, "v", il, "weight"));
  1317. layer.q_w = get_tensor(string_format(TN_ATTN_Q, "v", il, "weight"));
  1318. layer.v_w = get_tensor(string_format(TN_ATTN_V, "v", il, "weight"));
  1319. layer.o_w = get_tensor(string_format(TN_ATTN_OUTPUT, "v", il, "weight"));
  1320. layer.ln_1_w = get_tensor(string_format(TN_LN_1, "v", il, "weight"), false);
  1321. layer.ln_2_w = get_tensor(string_format(TN_LN_2, "v", il, "weight"), false);
  1322. layer.k_b = get_tensor(string_format(TN_ATTN_K, "v", il, "bias"), false);
  1323. layer.q_b = get_tensor(string_format(TN_ATTN_Q, "v", il, "bias"), false);
  1324. layer.v_b = get_tensor(string_format(TN_ATTN_V, "v", il, "bias"), false);
  1325. layer.o_b = get_tensor(string_format(TN_ATTN_OUTPUT, "v", il, "bias"), false);
  1326. layer.ln_1_b = get_tensor(string_format(TN_LN_1, "v", il, "bias"), false);
  1327. layer.ln_2_b = get_tensor(string_format(TN_LN_2, "v", il, "bias"), false);
  1328. // new naming
  1329. layer.ff_up_w = get_tensor(string_format(TN_FFN_UP, "v", il, "weight"));
  1330. layer.ff_up_b = get_tensor(string_format(TN_FFN_UP, "v", il, "bias"), false);
  1331. layer.ff_gate_w = get_tensor(string_format(TN_FFN_GATE, "v", il, "weight"), false);
  1332. layer.ff_gate_b = get_tensor(string_format(TN_FFN_GATE, "v", il, "bias"), false);
  1333. layer.ff_down_w = get_tensor(string_format(TN_FFN_DOWN, "v", il, "weight"));
  1334. layer.ff_down_b = get_tensor(string_format(TN_FFN_DOWN, "v", il, "bias"), false);
  1335. // legacy naming (the in and out is reversed! don't ask me why)
  1336. layer.ff_i_w = layer.ff_down_w;
  1337. layer.ff_o_w = layer.ff_up_w;
  1338. layer.ff_i_b = layer.ff_down_b;
  1339. layer.ff_o_b = layer.ff_up_b;
  1340. }
  1341. switch (ctx_clip.proj_type) {
  1342. case PROJECTOR_TYPE_MLP:
  1343. case PROJECTOR_TYPE_MLP_NORM:
  1344. {
  1345. // LLaVA projection
  1346. vision_model.mm_0_w = get_tensor(string_format(TN_LLAVA_PROJ, 0, "weight"), false);
  1347. vision_model.mm_0_b = get_tensor(string_format(TN_LLAVA_PROJ, 0, "bias"), false);
  1348. // Yi-type llava
  1349. vision_model.mm_1_w = get_tensor(string_format(TN_LLAVA_PROJ, 1, "weight"), false);
  1350. vision_model.mm_1_b = get_tensor(string_format(TN_LLAVA_PROJ, 1, "bias"), false);
  1351. // missing in Yi-type llava
  1352. vision_model.mm_2_w = get_tensor(string_format(TN_LLAVA_PROJ, 2, "weight"), false);
  1353. vision_model.mm_2_b = get_tensor(string_format(TN_LLAVA_PROJ, 2, "bias"), false);
  1354. // Yi-type llava
  1355. vision_model.mm_3_w = get_tensor(string_format(TN_LLAVA_PROJ, 3, "weight"), false);
  1356. vision_model.mm_3_b = get_tensor(string_format(TN_LLAVA_PROJ, 3, "bias"), false);
  1357. vision_model.mm_4_w = get_tensor(string_format(TN_LLAVA_PROJ, 4, "weight"), false);
  1358. vision_model.mm_4_b = get_tensor(string_format(TN_LLAVA_PROJ, 4, "bias"), false);
  1359. if (vision_model.mm_3_w) {
  1360. // TODO: this is a hack to support Yi-type llava
  1361. ctx_clip.proj_type = PROJECTOR_TYPE_MLP_NORM;
  1362. }
  1363. vision_model.image_newline = get_tensor(TN_IMAGE_NEWLINE, false);
  1364. } break;
  1365. case PROJECTOR_TYPE_LDP:
  1366. {
  1367. // MobileVLM projection
  1368. vision_model.mm_model_mlp_1_w = get_tensor(string_format(TN_MVLM_PROJ_MLP, 1, "weight"));
  1369. vision_model.mm_model_mlp_1_b = get_tensor(string_format(TN_MVLM_PROJ_MLP, 1, "bias"));
  1370. vision_model.mm_model_mlp_3_w = get_tensor(string_format(TN_MVLM_PROJ_MLP, 3, "weight"));
  1371. vision_model.mm_model_mlp_3_b = get_tensor(string_format(TN_MVLM_PROJ_MLP, 3, "bias"));
  1372. vision_model.mm_model_block_1_block_0_0_w = get_tensor(string_format(TN_MVLM_PROJ_BLOCK, 1, 0, "0.weight"));
  1373. vision_model.mm_model_block_1_block_0_1_w = get_tensor(string_format(TN_MVLM_PROJ_BLOCK, 1, 0, "1.weight"));
  1374. vision_model.mm_model_block_1_block_0_1_b = get_tensor(string_format(TN_MVLM_PROJ_BLOCK, 1, 0, "1.bias"));
  1375. vision_model.mm_model_block_1_block_1_fc1_w = get_tensor(string_format(TN_MVLM_PROJ_BLOCK, 1, 1, "fc1.weight"));
  1376. vision_model.mm_model_block_1_block_1_fc1_b = get_tensor(string_format(TN_MVLM_PROJ_BLOCK, 1, 1, "fc1.bias"));
  1377. vision_model.mm_model_block_1_block_1_fc2_w = get_tensor(string_format(TN_MVLM_PROJ_BLOCK, 1, 1, "fc2.weight"));
  1378. vision_model.mm_model_block_1_block_1_fc2_b = get_tensor(string_format(TN_MVLM_PROJ_BLOCK, 1, 1, "fc2.bias"));
  1379. vision_model.mm_model_block_1_block_2_0_w = get_tensor(string_format(TN_MVLM_PROJ_BLOCK, 1, 2, "0.weight"));
  1380. vision_model.mm_model_block_1_block_2_1_w = get_tensor(string_format(TN_MVLM_PROJ_BLOCK, 1, 2, "1.weight"));
  1381. vision_model.mm_model_block_1_block_2_1_b = get_tensor(string_format(TN_MVLM_PROJ_BLOCK, 1, 2, "1.bias"));
  1382. vision_model.mm_model_block_2_block_0_0_w = get_tensor(string_format(TN_MVLM_PROJ_BLOCK, 2, 0, "0.weight"));
  1383. vision_model.mm_model_block_2_block_0_1_w = get_tensor(string_format(TN_MVLM_PROJ_BLOCK, 2, 0, "1.weight"));
  1384. vision_model.mm_model_block_2_block_0_1_b = get_tensor(string_format(TN_MVLM_PROJ_BLOCK, 2, 0, "1.bias"));
  1385. vision_model.mm_model_block_2_block_1_fc1_w = get_tensor(string_format(TN_MVLM_PROJ_BLOCK, 2, 1, "fc1.weight"));
  1386. vision_model.mm_model_block_2_block_1_fc1_b = get_tensor(string_format(TN_MVLM_PROJ_BLOCK, 2, 1, "fc1.bias"));
  1387. vision_model.mm_model_block_2_block_1_fc2_w = get_tensor(string_format(TN_MVLM_PROJ_BLOCK, 2, 1, "fc2.weight"));
  1388. vision_model.mm_model_block_2_block_1_fc2_b = get_tensor(string_format(TN_MVLM_PROJ_BLOCK, 2, 1, "fc2.bias"));
  1389. vision_model.mm_model_block_2_block_2_0_w = get_tensor(string_format(TN_MVLM_PROJ_BLOCK, 2, 2, "0.weight"));
  1390. vision_model.mm_model_block_2_block_2_1_w = get_tensor(string_format(TN_MVLM_PROJ_BLOCK, 2, 2, "1.weight"));
  1391. vision_model.mm_model_block_2_block_2_1_b = get_tensor(string_format(TN_MVLM_PROJ_BLOCK, 2, 2, "1.bias"));
  1392. } break;
  1393. case PROJECTOR_TYPE_LDPV2:
  1394. {
  1395. // MobilVLM_V2 projection
  1396. vision_model.mm_model_mlp_0_w = get_tensor(string_format(TN_MVLM_PROJ_MLP, 0, "weight"));
  1397. vision_model.mm_model_mlp_0_b = get_tensor(string_format(TN_MVLM_PROJ_MLP, 0, "bias"));
  1398. vision_model.mm_model_mlp_2_w = get_tensor(string_format(TN_MVLM_PROJ_MLP, 2, "weight"));
  1399. vision_model.mm_model_mlp_2_b = get_tensor(string_format(TN_MVLM_PROJ_MLP, 2, "bias"));
  1400. vision_model.mm_model_peg_0_w = get_tensor(string_format(TN_MVLM_PROJ_PEG, 0, "weight"));
  1401. vision_model.mm_model_peg_0_b = get_tensor(string_format(TN_MVLM_PROJ_PEG, 0, "bias"));
  1402. } break;
  1403. case PROJECTOR_TYPE_MINICPMV:
  1404. {
  1405. // vision_model.mm_model_pos_embed = get_tensor(new_clip->ctx_data, TN_MINICPMV_POS_EMBD);
  1406. vision_model.mm_model_pos_embed_k = get_tensor(TN_MINICPMV_POS_EMBD_K);
  1407. vision_model.mm_model_query = get_tensor(TN_MINICPMV_QUERY);
  1408. vision_model.mm_model_proj = get_tensor(TN_MINICPMV_PROJ);
  1409. vision_model.mm_model_kv_proj = get_tensor(TN_MINICPMV_KV_PROJ);
  1410. vision_model.mm_model_attn_q_w = get_tensor(string_format(TN_MINICPMV_ATTN, "q", "weight"));
  1411. vision_model.mm_model_attn_k_w = get_tensor(string_format(TN_MINICPMV_ATTN, "k", "weight"));
  1412. vision_model.mm_model_attn_v_w = get_tensor(string_format(TN_MINICPMV_ATTN, "v", "weight"));
  1413. vision_model.mm_model_attn_q_b = get_tensor(string_format(TN_MINICPMV_ATTN, "q", "bias"));
  1414. vision_model.mm_model_attn_k_b = get_tensor(string_format(TN_MINICPMV_ATTN, "k", "bias"));
  1415. vision_model.mm_model_attn_v_b = get_tensor(string_format(TN_MINICPMV_ATTN, "v", "bias"));
  1416. vision_model.mm_model_attn_o_w = get_tensor(string_format(TN_MINICPMV_ATTN, "out", "weight"));
  1417. vision_model.mm_model_attn_o_b = get_tensor(string_format(TN_MINICPMV_ATTN, "out", "bias"));
  1418. vision_model.mm_model_ln_q_w = get_tensor(string_format(TN_MINICPMV_LN, "q", "weight"));
  1419. vision_model.mm_model_ln_q_b = get_tensor(string_format(TN_MINICPMV_LN, "q", "bias"));
  1420. vision_model.mm_model_ln_kv_w = get_tensor(string_format(TN_MINICPMV_LN, "kv", "weight"));
  1421. vision_model.mm_model_ln_kv_b = get_tensor(string_format(TN_MINICPMV_LN, "kv", "bias"));
  1422. vision_model.mm_model_ln_post_w = get_tensor(string_format(TN_MINICPMV_LN, "post", "weight"));
  1423. vision_model.mm_model_ln_post_b = get_tensor(string_format(TN_MINICPMV_LN, "post", "bias"));
  1424. } break;
  1425. case PROJECTOR_TYPE_GLM_EDGE:
  1426. {
  1427. vision_model.mm_model_adapter_conv_w = get_tensor(string_format(TN_GLM_ADAPER_CONV, "weight"));
  1428. vision_model.mm_model_adapter_conv_b = get_tensor(string_format(TN_GLM_ADAPER_CONV, "bias"));
  1429. vision_model.mm_model_mlp_0_w = get_tensor(string_format(TN_GLM_ADAPTER_LINEAR,"weight"));
  1430. vision_model.mm_model_ln_q_w = get_tensor(string_format(TN_GLM_ADAPTER_NORM_1,"weight"));
  1431. vision_model.mm_model_ln_q_b = get_tensor(string_format(TN_GLM_ADAPTER_NORM_1,"bias"));
  1432. vision_model.mm_model_mlp_1_w = get_tensor(string_format(TN_GLM_ADAPTER_D_H_2_4H,"weight"));
  1433. vision_model.mm_model_mlp_2_w = get_tensor(string_format(TN_GLM_ADAPTER_GATE,"weight"));
  1434. vision_model.mm_model_mlp_3_w = get_tensor(string_format(TN_GLM_ADAPTER_D_4H_2_H,"weight"));
  1435. } break;
  1436. case PROJECTOR_TYPE_QWEN2VL:
  1437. {
  1438. vision_model.mm_0_w = get_tensor(string_format(TN_LLAVA_PROJ, 0, "weight"));
  1439. vision_model.mm_0_b = get_tensor(string_format(TN_LLAVA_PROJ, 0, "bias"));
  1440. vision_model.mm_1_w = get_tensor(string_format(TN_LLAVA_PROJ, 2, "weight"));
  1441. vision_model.mm_1_b = get_tensor(string_format(TN_LLAVA_PROJ, 2, "bias"));
  1442. } break;
  1443. case PROJECTOR_TYPE_GEMMA3:
  1444. {
  1445. vision_model.mm_input_proj_w = get_tensor(TN_MM_INP_PROJ);
  1446. vision_model.mm_soft_emb_norm_w = get_tensor(TN_MM_SOFT_EMB_N);
  1447. } break;
  1448. case PROJECTOR_TYPE_IDEFICS3:
  1449. {
  1450. vision_model.projection = get_tensor(TN_MM_PROJECTOR);
  1451. } break;
  1452. case PROJECTOR_TYPE_PIXTRAL:
  1453. {
  1454. vision_model.mm_1_w = get_tensor(string_format(TN_LLAVA_PROJ, 1, "weight"));
  1455. vision_model.mm_1_b = get_tensor(string_format(TN_LLAVA_PROJ, 1, "bias"));
  1456. vision_model.mm_2_w = get_tensor(string_format(TN_LLAVA_PROJ, 2, "weight"));
  1457. vision_model.mm_2_b = get_tensor(string_format(TN_LLAVA_PROJ, 2, "bias"));
  1458. // [IMG_BREAK] token embedding
  1459. vision_model.token_embd_img_break = get_tensor(TN_TOK_IMG_BREAK);
  1460. } break;
  1461. default:
  1462. GGML_ASSERT(false && "unknown projector type");
  1463. }
  1464. // load data
  1465. {
  1466. std::vector<uint8_t> read_buf;
  1467. auto fin = std::ifstream(fname, std::ios::binary);
  1468. if (!fin) {
  1469. throw std::runtime_error(string_format("%s: failed to open %s\n", __func__, fname.c_str()));
  1470. }
  1471. // alloc memory and offload data
  1472. ggml_backend_buffer_type_t buft = ggml_backend_get_default_buffer_type(ctx_clip.backend);
  1473. ctx_clip.buf.reset(ggml_backend_alloc_ctx_tensors_from_buft(ctx_clip.ctx_data.get(), buft));
  1474. ggml_backend_buffer_set_usage(ctx_clip.buf.get(), GGML_BACKEND_BUFFER_USAGE_WEIGHTS);
  1475. for (auto & t : tensors_to_load) {
  1476. struct ggml_tensor * cur = ggml_get_tensor(ctx_clip.ctx_data.get(), t->name);
  1477. const size_t offset = tensor_offset[t->name];
  1478. fin.seekg(offset, std::ios::beg);
  1479. if (!fin) {
  1480. throw std::runtime_error(string_format("%s: failed to seek for tensor %s\n", __func__, t->name));
  1481. }
  1482. size_t num_bytes = ggml_nbytes(cur);
  1483. if (ggml_backend_buft_is_host(buft)) {
  1484. // for the CPU and Metal backend, we can read directly into the tensor
  1485. fin.read(reinterpret_cast<char *>(cur->data), num_bytes);
  1486. } else {
  1487. // read into a temporary buffer first, then copy to device memory
  1488. read_buf.resize(num_bytes);
  1489. fin.read(reinterpret_cast<char *>(read_buf.data()), num_bytes);
  1490. ggml_backend_tensor_set(cur, read_buf.data(), 0, num_bytes);
  1491. }
  1492. }
  1493. fin.close();
  1494. LOG_DBG("%s: loaded %zu tensors from %s\n", __func__, tensors_to_load.size(), fname.c_str());
  1495. }
  1496. }
  1497. void alloc_compute_meta() {
  1498. ctx_clip.buf_compute_meta.resize(ctx_clip.max_nodes * ggml_tensor_overhead() + ggml_graph_overhead());
  1499. // create a fake batch
  1500. clip_image_f32_batch batch;
  1501. clip_image_f32_ptr img(clip_image_f32_init());
  1502. clip_image_size image_size;
  1503. image_size.width = ctx_clip.vision_model.hparams.image_size;
  1504. image_size.height = ctx_clip.vision_model.hparams.image_size;
  1505. img->nx = image_size.width;
  1506. img->ny = image_size.height;
  1507. img->buf.resize(image_size.width * image_size.height * 3);
  1508. batch.entries.push_back(std::move(img));
  1509. ggml_cgraph * gf = clip_image_build_graph(&ctx_clip, batch, image_size, false);
  1510. ggml_backend_sched_reserve(ctx_clip.sched.get(), gf);
  1511. for (size_t i = 0; i < ctx_clip.backend_ptrs.size(); ++i) {
  1512. ggml_backend_t backend = ctx_clip.backend_ptrs[i];
  1513. ggml_backend_buffer_type_t buft = ctx_clip.backend_buft[i];
  1514. size_t size = ggml_backend_sched_get_buffer_size(ctx_clip.sched.get(), backend);
  1515. if (size > 1) {
  1516. LOG_INF("%s: %10s compute buffer size = %8.2f MiB\n", __func__,
  1517. ggml_backend_buft_name(buft),
  1518. size / 1024.0 / 1024.0);
  1519. }
  1520. }
  1521. }
  1522. void get_bool(const std::string & key, bool & output, bool required = true) {
  1523. const int i = gguf_find_key(ctx_gguf.get(), key.c_str());
  1524. if (i < 0) {
  1525. if (required) throw std::runtime_error("Key not found: " + key);
  1526. return;
  1527. }
  1528. output = gguf_get_val_bool(ctx_gguf.get(), i);
  1529. }
  1530. void get_i32(const std::string & key, int & output, bool required = true) {
  1531. const int i = gguf_find_key(ctx_gguf.get(), key.c_str());
  1532. if (i < 0) {
  1533. if (required) throw std::runtime_error("Key not found: " + key);
  1534. return;
  1535. }
  1536. output = gguf_get_val_i32(ctx_gguf.get(), i);
  1537. }
  1538. void get_u32(const std::string & key, int & output, bool required = true) {
  1539. const int i = gguf_find_key(ctx_gguf.get(), key.c_str());
  1540. if (i < 0) {
  1541. if (required) throw std::runtime_error("Key not found: " + key);
  1542. return;
  1543. }
  1544. output = gguf_get_val_u32(ctx_gguf.get(), i);
  1545. }
  1546. void get_f32(const std::string & key, float & output, bool required = true) {
  1547. const int i = gguf_find_key(ctx_gguf.get(), key.c_str());
  1548. if (i < 0) {
  1549. if (required) throw std::runtime_error("Key not found: " + key);
  1550. return;
  1551. }
  1552. output = gguf_get_val_f32(ctx_gguf.get(), i);
  1553. }
  1554. void get_string(const std::string & key, std::string & output, bool required = true) {
  1555. const int i = gguf_find_key(ctx_gguf.get(), key.c_str());
  1556. if (i < 0) {
  1557. if (required) throw std::runtime_error("Key not found: " + key);
  1558. return;
  1559. }
  1560. output = std::string(gguf_get_val_str(ctx_gguf.get(), i));
  1561. }
  1562. void get_arr_int(const std::string & key, std::vector<int> & output, bool required = true) {
  1563. const int i = gguf_find_key(ctx_gguf.get(), key.c_str());
  1564. if (i < 0) {
  1565. if (required) throw std::runtime_error("Key not found: " + key);
  1566. return;
  1567. }
  1568. int n = gguf_get_arr_n(ctx_gguf.get(), i);
  1569. output.resize(n);
  1570. const int32_t * values = (const int32_t *)gguf_get_arr_data(ctx_gguf.get(), i);
  1571. for (int i = 0; i < n; ++i) {
  1572. output[i] = values[i];
  1573. }
  1574. }
  1575. };
  1576. // read and create ggml_context containing the tensors and their data
  1577. struct clip_ctx * clip_model_load(const char * fname, const int verbosity) {
  1578. return clip_init(fname, clip_context_params{
  1579. /* use_gpu */ true,
  1580. /* verbosity */ static_cast<ggml_log_level>(verbosity),
  1581. });
  1582. }
  1583. struct clip_ctx * clip_init(const char * fname, struct clip_context_params ctx_params) {
  1584. g_logger_state.verbosity_thold = ctx_params.verbosity;
  1585. clip_ctx * ctx_clip = new clip_ctx(ctx_params);
  1586. try {
  1587. clip_model_loader loader(fname, *ctx_clip);
  1588. loader.load_hparams();
  1589. loader.load_tensors();
  1590. loader.alloc_compute_meta();
  1591. } catch (const std::exception & e) {
  1592. LOG_ERR("%s: failed to load model '%s': %s\n", __func__, fname, e.what());
  1593. delete ctx_clip;
  1594. return nullptr;
  1595. }
  1596. return ctx_clip;
  1597. }
  1598. void clip_add_load_image_size(struct clip_ctx * ctx_clip, struct clip_image_size * load_image_size) {
  1599. ctx_clip->load_image_size = *load_image_size; // copy
  1600. }
  1601. struct clip_image_size * clip_get_load_image_size(struct clip_ctx * ctx_clip) {
  1602. return &ctx_clip->load_image_size;
  1603. }
  1604. struct clip_image_size * clip_image_size_init() {
  1605. struct clip_image_size * load_image_size = new struct clip_image_size();
  1606. load_image_size->width = 448;
  1607. load_image_size->height = 448;
  1608. return load_image_size;
  1609. }
  1610. struct clip_image_u8 * clip_image_u8_init() {
  1611. return new clip_image_u8();
  1612. }
  1613. struct clip_image_f32 * clip_image_f32_init() {
  1614. return new clip_image_f32();
  1615. }
  1616. struct clip_image_f32_batch * clip_image_f32_batch_init() {
  1617. return new clip_image_f32_batch();
  1618. }
  1619. unsigned char * clip_image_u8_get_data(struct clip_image_u8 * img, uint32_t * nx, uint32_t * ny) {
  1620. if (nx) *nx = img->nx;
  1621. if (ny) *ny = img->ny;
  1622. return img->buf.data();
  1623. }
  // Frees a clip_image_size obtained from clip_image_size_init(); nullptr is accepted.
  void clip_image_size_free(struct clip_image_size * load_image_size) {
      delete load_image_size; // deleting a null pointer is a no-op, so no guard is needed
  }
  // Frees an image from clip_image_u8_init(); nullptr is accepted (delete of null is a no-op).
  void clip_image_u8_free(struct clip_image_u8 * img) { delete img; }
  // Frees an image from clip_image_f32_init(); nullptr is accepted (delete of null is a no-op).
  void clip_image_f32_free(struct clip_image_f32 * img) { delete img; }
  // Frees a batch of 8-bit images; nullptr is accepted (delete of null is a no-op).
  void clip_image_u8_batch_free(struct clip_image_u8_batch * batch) { delete batch; }
  // Frees a batch from clip_image_f32_batch_init(); nullptr is accepted (delete of null is a no-op).
  void clip_image_f32_batch_free(struct clip_image_f32_batch * batch) { delete batch; }
  1634. size_t clip_image_f32_batch_n_images(const struct clip_image_f32_batch * batch) {
  1635. return batch->entries.size();
  1636. }
  1637. size_t clip_image_f32_batch_nx(const struct clip_image_f32_batch * batch, int idx) {
  1638. if (idx < 0 || idx >= (int)batch->entries.size()) {
  1639. LOG_ERR("%s: invalid index %d\n", __func__, idx);
  1640. return 0;
  1641. }
  1642. return batch->entries[idx]->nx;
  1643. }
  1644. size_t clip_image_f32_batch_ny(const struct clip_image_f32_batch * batch, int idx) {
  1645. if (idx < 0 || idx >= (int)batch->entries.size()) {
  1646. LOG_ERR("%s: invalid index %d\n", __func__, idx);
  1647. return 0;
  1648. }
  1649. return batch->entries[idx]->ny;
  1650. }
  1651. clip_image_f32 * clip_image_f32_get_img(const struct clip_image_f32_batch * batch, int idx) {
  1652. if (idx < 0 || idx >= (int)batch->entries.size()) {
  1653. LOG_ERR("%s: invalid index %d\n", __func__, idx);
  1654. return nullptr;
  1655. }
  1656. return batch->entries[idx].get();
  1657. }
  1658. void clip_build_img_from_pixels(const unsigned char * rgb_pixels, int nx, int ny, clip_image_u8 * img) {
  1659. img->nx = nx;
  1660. img->ny = ny;
  1661. img->buf.resize(3 * nx * ny);
  1662. memcpy(img->buf.data(), rgb_pixels, img->buf.size());
  1663. }
  1664. bool clip_image_load_from_file(const char * fname, clip_image_u8 * img) {
  1665. int nx, ny, nc;
  1666. auto * data = stbi_load(fname, &nx, &ny, &nc, 3);
  1667. if (!data) {
  1668. LOG_ERR("%s: failed to load image '%s'\n", __func__, fname);
  1669. return false;
  1670. }
  1671. clip_build_img_from_pixels(data, nx, ny, img);
  1672. stbi_image_free(data);
  1673. return true;
  1674. }
  1675. bool clip_image_load_from_bytes(const unsigned char * bytes, size_t bytes_length, struct clip_image_u8 * img) {
  1676. int nx, ny, nc;
  1677. auto * data = stbi_load_from_memory(bytes, bytes_length, &nx, &ny, &nc, 3);
  1678. if (!data) {
  1679. LOG_ERR("%s: failed to decode image bytes\n", __func__);
  1680. return false;
  1681. }
  1682. clip_build_img_from_pixels(data, nx, ny, img);
  1683. stbi_image_free(data);
  1684. return true;
  1685. }
  1686. // Normalize image to float32 - careful with pytorch .to(model.device, dtype=torch.float16) - this sometimes reduces precision (32>16>32), sometimes not
  1687. static void normalize_image_u8_to_f32(const clip_image_u8 & src, clip_image_f32 & dst, const float mean[3], const float std[3]) {
  1688. dst.nx = src.nx;
  1689. dst.ny = src.ny;
  1690. dst.buf.resize(src.buf.size());
  1691. // TODO @ngxson : seems like this could be done more efficiently on cgraph
  1692. for (size_t i = 0; i < src.buf.size(); ++i) {
  1693. int c = i % 3; // rgb
  1694. dst.buf[i] = (static_cast<float>(src.buf[i]) / 255.0f - mean[c]) / std[c];
  1695. }
  1696. }
  1697. // set of tools to manupulate images
  1698. // in the future, we can have HW acceleration by allowing this struct to access 3rd party lib like imagick or opencv
  1699. struct image_manipulation {
  1700. // Bilinear resize function
  1701. static void bilinear_resize(const clip_image_u8& src, clip_image_u8& dst, int target_width, int target_height) {
  1702. dst.nx = target_width;
  1703. dst.ny = target_height;
  1704. dst.buf.resize(3 * target_width * target_height);
  1705. float x_ratio = static_cast<float>(src.nx - 1) / target_width;
  1706. float y_ratio = static_cast<float>(src.ny - 1) / target_height;
  1707. for (int y = 0; y < target_height; y++) {
  1708. for (int x = 0; x < target_width; x++) {
  1709. float px = x_ratio * x;
  1710. float py = y_ratio * y;
  1711. int x_floor = static_cast<int>(px);
  1712. int y_floor = static_cast<int>(py);
  1713. float x_lerp = px - x_floor;
  1714. float y_lerp = py - y_floor;
  1715. for (int c = 0; c < 3; c++) {
  1716. float top = lerp(
  1717. static_cast<float>(src.buf[3 * (y_floor * src.nx + x_floor) + c]),
  1718. static_cast<float>(src.buf[3 * (y_floor * src.nx + (x_floor + 1)) + c]),
  1719. x_lerp
  1720. );
  1721. float bottom = lerp(
  1722. static_cast<float>(src.buf[3 * ((y_floor + 1) * src.nx + x_floor) + c]),
  1723. static_cast<float>(src.buf[3 * ((y_floor + 1) * src.nx + (x_floor + 1)) + c]),
  1724. x_lerp
  1725. );
  1726. dst.buf[3 * (y * target_width + x) + c] = static_cast<uint8_t>(lerp(top, bottom, y_lerp));
  1727. }
  1728. }
  1729. }
  1730. }
  1731. // Bicubic resize function
  1732. // part of image will be cropped if the aspect ratio is different
  1733. static bool bicubic_resize(const clip_image_u8 & img, clip_image_u8 & dst, int target_width, int target_height) {
  1734. const int nx = img.nx;
  1735. const int ny = img.ny;
  1736. dst.nx = target_width;
  1737. dst.ny = target_height;
  1738. dst.buf.resize(3 * target_width * target_height);
  1739. float Cc;
  1740. float C[5];
  1741. float d0, d2, d3, a0, a1, a2, a3;
  1742. int i, j, k, jj;
  1743. int x, y;
  1744. float dx, dy;
  1745. float tx, ty;
  1746. tx = (float)nx / (float)target_width;
  1747. ty = (float)ny / (float)target_height;
  1748. // Bicubic interpolation; adapted from ViT.cpp, inspired from :
  1749. // -> https://github.com/yglukhov/bicubic-interpolation-image-processing/blob/master/libimage.c#L36
  1750. // -> https://en.wikipedia.org/wiki/Bicubic_interpolation
  1751. for (i = 0; i < target_height; i++) {
  1752. for (j = 0; j < target_width; j++) {
  1753. x = (int)(tx * j);
  1754. y = (int)(ty * i);
  1755. dx = tx * j - x;
  1756. dy = ty * i - y;
  1757. for (k = 0; k < 3; k++) {
  1758. for (jj = 0; jj <= 3; jj++) {
  1759. d0 = img.buf[(clip(y - 1 + jj, 0, ny - 1) * nx + clip(x - 1, 0, nx - 1)) * 3 + k] - img.buf[(clip(y - 1 + jj, 0, ny - 1) * nx + clip(x, 0, nx - 1)) * 3 + k];
  1760. d2 = img.buf[(clip(y - 1 + jj, 0, ny - 1) * nx + clip(x + 1, 0, nx - 1)) * 3 + k] - img.buf[(clip(y - 1 + jj, 0, ny - 1) * nx + clip(x, 0, nx - 1)) * 3 + k];
  1761. d3 = img.buf[(clip(y - 1 + jj, 0, ny - 1) * nx + clip(x + 2, 0, nx - 1)) * 3 + k] - img.buf[(clip(y - 1 + jj, 0, ny - 1) * nx + clip(x, 0, nx - 1)) * 3 + k];
  1762. a0 = img.buf[(clip(y - 1 + jj, 0, ny - 1) * nx + clip(x, 0, nx - 1)) * 3 + k];
  1763. a1 = -1.0 / 3 * d0 + d2 - 1.0 / 6 * d3;
  1764. a2 = 1.0 / 2 * d0 + 1.0 / 2 * d2;
  1765. a3 = -1.0 / 6 * d0 - 1.0 / 2 * d2 + 1.0 / 6 * d3;
  1766. C[jj] = a0 + a1 * dx + a2 * dx * dx + a3 * dx * dx * dx;
  1767. d0 = C[0] - C[1];
  1768. d2 = C[2] - C[1];
  1769. d3 = C[3] - C[1];
  1770. a0 = C[1];
  1771. a1 = -1.0 / 3 * d0 + d2 - 1.0 / 6 * d3;
  1772. a2 = 1.0 / 2 * d0 + 1.0 / 2 * d2;
  1773. a3 = -1.0 / 6 * d0 - 1.0 / 2 * d2 + 1.0 / 6 * d3;
  1774. Cc = a0 + a1 * dy + a2 * dy * dy + a3 * dy * dy * dy;
  1775. const uint8_t Cc2 = std::min(std::max(std::round(Cc), 0.0f), 255.0f);
  1776. dst.buf[(i * target_width + j) * 3 + k] = float(Cc2);
  1777. }
  1778. }
  1779. }
  1780. }
  1781. return true;
  1782. }
  1783. // llava-1.6 type of resize_and_pad
  1784. // if the ratio is not 1:1, padding with pad_color will be applied
  1785. // pad_color is single channel, default is 0 (black)
  1786. static void resize_and_pad_image(const clip_image_u8 & image, clip_image_u8 & dst, const clip_image_size & target_resolution, std::array<uint8_t, 3> pad_color = {0, 0, 0}) {
  1787. int target_width = target_resolution.width;
  1788. int target_height = target_resolution.height;
  1789. float scale_w = static_cast<float>(target_width) / image.nx;
  1790. float scale_h = static_cast<float>(target_height) / image.ny;
  1791. int new_width, new_height;
  1792. if (scale_w < scale_h) {
  1793. new_width = target_width;
  1794. new_height = std::min(static_cast<int>(std::ceil(image.ny * scale_w)), target_height);
  1795. } else {
  1796. new_height = target_height;
  1797. new_width = std::min(static_cast<int>(std::ceil(image.nx * scale_h)), target_width);
  1798. }
  1799. clip_image_u8 resized_image;
  1800. bicubic_resize(image, resized_image, new_width, new_height);
  1801. clip_image_u8 padded_image;
  1802. padded_image.nx = target_width;
  1803. padded_image.ny = target_height;
  1804. padded_image.buf.resize(3 * target_width * target_height);
  1805. // Fill the padded image with the fill color
  1806. for (size_t i = 0; i < padded_image.buf.size(); i += 3) {
  1807. padded_image.buf[i] = pad_color[0];
  1808. padded_image.buf[i + 1] = pad_color[1];
  1809. padded_image.buf[i + 2] = pad_color[2];
  1810. }
  1811. // Calculate padding offsets
  1812. int pad_x = (target_width - new_width) / 2;
  1813. int pad_y = (target_height - new_height) / 2;
  1814. // Copy the resized image into the center of the padded buffer
  1815. for (int y = 0; y < new_height; ++y) {
  1816. for (int x = 0; x < new_width; ++x) {
  1817. for (int c = 0; c < 3; ++c) {
  1818. padded_image.buf[3 * ((y + pad_y) * target_width + (x + pad_x)) + c] = resized_image.buf[3 * (y * new_width + x) + c];
  1819. }
  1820. }
  1821. }
  1822. dst = std::move(padded_image);
  1823. }
  1824. static void crop_image(const clip_image_u8 & image, clip_image_u8 & dst, int x, int y, int w, int h) {
  1825. dst.nx = w;
  1826. dst.ny = h;
  1827. dst.buf.resize(3 * w * h);
  1828. for (int i = 0; i < h; ++i) {
  1829. for (int j = 0; j < w; ++j) {
  1830. int src_idx = 3 * ((y + i)*image.nx + (x + j));
  1831. int dst_idx = 3 * (i*w + j);
  1832. dst.buf[dst_idx] = image.buf[src_idx];
  1833. dst.buf[dst_idx + 1] = image.buf[src_idx + 1];
  1834. dst.buf[dst_idx + 2] = image.buf[src_idx + 2];
  1835. }
  1836. }
  1837. }
  1838. // calculate the size of the **resized** image, while preserving the aspect ratio
  1839. // the calculated size will be aligned to the nearest multiple of align_size
  1840. // if H or W size is larger than max_dimension, it will be resized to max_dimension
  1841. static clip_image_size calc_size_preserved_ratio(const clip_image_size & inp_size, const int align_size, const int max_dimension) {
  1842. if (inp_size.width <= 0 || inp_size.height <= 0 || align_size <= 0 || max_dimension <= 0) {
  1843. return {0, 0};
  1844. }
  1845. float scale = std::min(1.0f, std::min(static_cast<float>(max_dimension) / inp_size.width,
  1846. static_cast<float>(max_dimension) / inp_size.height));
  1847. float target_width_f = static_cast<float>(inp_size.width) * scale;
  1848. float target_height_f = static_cast<float>(inp_size.height) * scale;
  1849. int aligned_width = GGML_PAD((int)target_width_f, align_size);
  1850. int aligned_height = GGML_PAD((int)target_height_f, align_size);
  1851. return {aligned_width, aligned_height};
  1852. }
  1853. private:
  1854. static inline int clip(int x, int lower, int upper) {
  1855. return std::max(lower, std::min(x, upper));
  1856. }
  1857. // Linear interpolation between two points
  1858. static inline float lerp(float s, float e, float t) {
  1859. return s + (e - s) * t;
  1860. }
  1861. };
  1862. /**
  1863. * implementation of LLaVA-UHD:
  1864. * - https://arxiv.org/pdf/2403.11703
  1865. * - https://github.com/thunlp/LLaVA-UHD
  1866. * - https://github.com/thunlp/LLaVA-UHD/blob/302301bc2175f7e717fb8548516188e89f649753/llava_uhd/train/llava-uhd/slice_logic.py#L118
  1867. *
  1868. * overview:
  1869. * - an image always have a single overview (downscaled image)
  1870. * - an image can have 0 or multiple slices, depending on the image size
  1871. * - each slice can then be considered as a separate image
  1872. *
  1873. * for example:
  1874. *
  1875. * [overview] --> [slice 1] --> [slice 2]
  1876. * | |
  1877. * +--> [slice 3] --> [slice 4]
  1878. */
  1879. struct llava_uhd {
  1880. struct slice_coordinates {
  1881. int x;
  1882. int y;
  1883. clip_image_size size;
  1884. };
  1885. struct slice_instructions {
  1886. clip_image_size overview_size; // size of downscaled image
  1887. clip_image_size refined_size; // size of image right before slicing (must be multiple of slice size)
  1888. clip_image_size grid_size; // grid_size.width * grid_size.height = number of slices
  1889. std::vector<slice_coordinates> slices;
  1890. bool padding_refined = false; // if true, refine image will be padded to the grid size (e.g. llava-1.6)
  1891. };
  1892. static int get_max_slices(struct clip_ctx * ctx) {
  1893. if (clip_is_minicpmv(ctx)) {
  1894. return 9;
  1895. }
  1896. return 0;
  1897. }
  1898. static slice_instructions get_slice_instructions(struct clip_ctx * ctx, const clip_image_size & original_size) {
  1899. slice_instructions res;
  1900. const int patch_size = clip_get_patch_size(ctx);
  1901. const int slice_size = clip_get_image_size(ctx);
  1902. const int max_slice_nums = get_max_slices(ctx);
  1903. const int original_width = original_size.width;
  1904. const int original_height = original_size.height;
  1905. const float log_ratio = log((float)original_width / original_height);
  1906. const float ratio = (float)original_width * original_height / (slice_size * slice_size);
  1907. const int multiple = fmin(ceil(ratio), max_slice_nums);
  1908. const bool has_slices = (multiple > 1);
  1909. const bool has_pinpoints = !ctx->vision_model.hparams.image_grid_pinpoints.empty();
  1910. if (has_pinpoints) {
  1911. // has pinpoints, use them to calculate the grid size (e.g. llava-1.6)
  1912. auto refine_size = llava_uhd::select_best_resolution(
  1913. ctx->vision_model.hparams.image_grid_pinpoints,
  1914. original_size);
  1915. res.overview_size = clip_image_size{slice_size, slice_size};
  1916. res.refined_size = refine_size;
  1917. res.grid_size = clip_image_size{0, 0};
  1918. res.padding_refined = true;
  1919. for (int y = 0; y < refine_size.height; y += slice_size) {
  1920. for (int x = 0; x < refine_size.width; x += slice_size) {
  1921. slice_coordinates slice;
  1922. slice.x = x;
  1923. slice.y = y;
  1924. slice.size.width = std::min(slice_size, refine_size.width - x);
  1925. slice.size.height = std::min(slice_size, refine_size.height - y);
  1926. res.slices.push_back(slice);
  1927. if (x == 0) {
  1928. res.grid_size.width++;
  1929. }
  1930. }
  1931. res.grid_size.height++;
  1932. }
  1933. return res;
  1934. }
  1935. // no pinpoints, dynamically calculate the grid size (e.g. minicpmv)
  1936. auto best_size = get_best_resize(original_size, slice_size, patch_size, has_slices);
  1937. res.overview_size = best_size;
  1938. if (!has_slices) {
  1939. // skip slicing logic
  1940. res.refined_size = clip_image_size{0, 0};
  1941. res.grid_size = clip_image_size{0, 0};
  1942. } else {
  1943. auto best_grid = get_best_grid(max_slice_nums, multiple, log_ratio);
  1944. auto refine_size = get_refine_size(original_size, best_grid, slice_size, patch_size, true);
  1945. res.grid_size = best_grid;
  1946. res.refined_size = refine_size;
  1947. int width = refine_size.width;
  1948. int height = refine_size.height;
  1949. int grid_x = int(width / best_grid.width);
  1950. int grid_y = int(height / best_grid.height);
  1951. for (int patches_y = 0, ic = 0;
  1952. patches_y < refine_size.height && ic < best_grid.height;
  1953. patches_y += grid_y, ic += 1) {
  1954. for (int patches_x = 0, jc = 0;
  1955. patches_x < refine_size.width && jc < best_grid.width;
  1956. patches_x += grid_x, jc += 1) {
  1957. slice_coordinates slice;
  1958. slice.x = patches_x;
  1959. slice.y = patches_y;
  1960. slice.size.width = grid_x;
  1961. slice.size.height = grid_y;
  1962. res.slices.push_back(slice);
  1963. // LOG_INF("slice %d: %d %d %d %d\n", ic, patches_i, patches_j, grid_x, grid_y);
  1964. }
  1965. }
  1966. }
  1967. return res;
  1968. }
  1969. static std::vector<clip_image_u8_ptr> slice_image(const clip_image_u8 * img, const slice_instructions & inst) {
  1970. std::vector<clip_image_u8_ptr> output;
  1971. // resize to overview size
  1972. clip_image_u8_ptr resized_img(clip_image_u8_init());
  1973. image_manipulation::bicubic_resize(*img, *resized_img, inst.overview_size.width, inst.overview_size.height);
  1974. output.push_back(std::move(resized_img));
  1975. if (inst.slices.empty()) {
  1976. // no slices, just return the resized image
  1977. return output;
  1978. }
  1979. // resize to refined size
  1980. clip_image_u8_ptr refined_img(clip_image_u8_init());
  1981. if (inst.padding_refined) {
  1982. image_manipulation::resize_and_pad_image(*img, *refined_img, inst.refined_size);
  1983. } else {
  1984. image_manipulation::bilinear_resize(*img, *refined_img, inst.refined_size.width, inst.refined_size.height);
  1985. }
  1986. // create slices
  1987. for (const auto & slice : inst.slices) {
  1988. int x = slice.x;
  1989. int y = slice.y;
  1990. int w = slice.size.width;
  1991. int h = slice.size.height;
  1992. clip_image_u8_ptr img_slice(clip_image_u8_init());
  1993. image_manipulation::crop_image(*refined_img, *img_slice, x, y, w, h);
  1994. output.push_back(std::move(img_slice));
  1995. }
  1996. return output;
  1997. }
  1998. private:
  1999. static clip_image_size get_best_resize(const clip_image_size & original_size, int scale_resolution, int patch_size, bool allow_upscale = false) {
  2000. int width = original_size.width;
  2001. int height = original_size.height;
  2002. if ((width * height > scale_resolution * scale_resolution) || allow_upscale) {
  2003. float r = static_cast<float>(width) / height;
  2004. height = static_cast<int>(scale_resolution / std::sqrt(r));
  2005. width = static_cast<int>(height * r);
  2006. }
  2007. clip_image_size res;
  2008. res.width = ensure_divide(width, patch_size);
  2009. res.height = ensure_divide(height, patch_size);
  2010. return res;
  2011. }
  2012. /**
  2013. * Selects the best resolution from a list of possible resolutions based on the original size.
  2014. *
  2015. * @param original_size The original size of the image
  2016. * @param possible_resolutions A list of possible resolutions
  2017. * @return The best fit resolution
  2018. */
  2019. static clip_image_size select_best_resolution(const clip_image_size & original_size, const std::vector<clip_image_size> & possible_resolutions) {
  2020. int original_width = original_size.width;
  2021. int original_height = original_size.height;
  2022. clip_image_size best_fit;
  2023. int max_effective_resolution = 0;
  2024. int min_wasted_resolution = std::numeric_limits<int>::max();
  2025. for (const auto & resolution : possible_resolutions) {
  2026. int width = resolution.width;
  2027. int height = resolution.height;
  2028. float scale = std::min(static_cast<float>(width) / original_width, static_cast<float>(height) / original_height);
  2029. int downscaled_width = static_cast<int>(original_width * scale);
  2030. int downscaled_height = static_cast<int>(original_height * scale);
  2031. int effective_resolution = std::min(downscaled_width * downscaled_height, original_width * original_height);
  2032. int wasted_resolution = (width * height) - effective_resolution;
  2033. // LOG_INF("resolution: %d %d, scale: %f, downscaled: %d %d, effective: %d, wasted: %d\n", width, height, scale, downscaled_width, downscaled_height, effective_resolution, wasted_resolution);
  2034. if (effective_resolution > max_effective_resolution || (effective_resolution == max_effective_resolution && wasted_resolution < min_wasted_resolution)) {
  2035. max_effective_resolution = effective_resolution;
  2036. min_wasted_resolution = wasted_resolution;
  2037. best_fit = resolution;
  2038. }
  2039. }
  2040. return best_fit;
  2041. }
  2042. // used by llava 1.6 with custom list of pinpoints
  2043. static clip_image_size select_best_resolution(const std::vector<int32_t> & pinpoints, const clip_image_size & original_size) {
  2044. std::vector<clip_image_size> possible_resolutions;
  2045. for (size_t i = 0; i < pinpoints.size(); i += 2) {
  2046. possible_resolutions.push_back(clip_image_size{pinpoints[i], pinpoints[i+1]});
  2047. }
  2048. return select_best_resolution(original_size, possible_resolutions);
  2049. }
  2050. static int ensure_divide(int length, int patch_size) {
  2051. return std::max(static_cast<int>(std::round(static_cast<float>(length) / patch_size) * patch_size), patch_size);
  2052. }
  2053. static clip_image_size get_refine_size(const clip_image_size & original_size, const clip_image_size & grid, int scale_resolution, int patch_size, bool allow_upscale = false) {
  2054. int width = original_size.width;
  2055. int height = original_size.height;
  2056. int grid_x = grid.width;
  2057. int grid_y = grid.height;
  2058. int refine_width = ensure_divide(width, grid_x);
  2059. int refine_height = ensure_divide(height, grid_y);
  2060. clip_image_size grid_size;
  2061. grid_size.width = refine_width / grid_x;
  2062. grid_size.height = refine_height / grid_y;
  2063. auto best_grid_size = get_best_resize(grid_size, scale_resolution, patch_size, allow_upscale);
  2064. int best_grid_width = best_grid_size.width;
  2065. int best_grid_height = best_grid_size.height;
  2066. clip_image_size refine_size;
  2067. refine_size.width = best_grid_width * grid_x;
  2068. refine_size.height = best_grid_height * grid_y;
  2069. return refine_size;
  2070. }
  2071. static clip_image_size get_best_grid(const int max_slice_nums, const int multiple, const float log_ratio) {
  2072. std::vector<int> candidate_split_grids_nums;
  2073. for (int i : {multiple - 1, multiple, multiple + 1}) {
  2074. if (i == 1 || i > max_slice_nums) {
  2075. continue;
  2076. }
  2077. candidate_split_grids_nums.push_back(i);
  2078. }
  2079. std::vector<clip_image_size> candidate_grids;
  2080. for (int split_grids_nums : candidate_split_grids_nums) {
  2081. int m = 1;
  2082. while (m <= split_grids_nums) {
  2083. if (split_grids_nums % m == 0) {
  2084. candidate_grids.push_back(clip_image_size{m, split_grids_nums / m});
  2085. }
  2086. ++m;
  2087. }
  2088. }
  2089. clip_image_size best_grid{1, 1};
  2090. float min_error = std::numeric_limits<float>::infinity();
  2091. for (const auto& grid : candidate_grids) {
  2092. float error = std::abs(log_ratio - std::log(1.0 * grid.width / grid.height));
  2093. if (error < min_error) {
  2094. best_grid = grid;
  2095. min_error = error;
  2096. }
  2097. }
  2098. return best_grid;
  2099. }
  2100. };
  2101. // TODO @ngxson : decprecate the load_image_size singleton pattern
  2102. int clip_uhd_num_image_embeds_col(struct clip_ctx * ctx_clip) {
  2103. const auto inst = llava_uhd::get_slice_instructions(ctx_clip, ctx_clip->load_image_size);
  2104. return inst.grid_size.width;
  2105. }
  2106. // returns the normalized float tensor for llava-1.5, for spatial_unpad with anyres processing for llava-1.6 it returns the normalized image patch tensors as a vector
  2107. // res_imgs memory is being allocated here, previous allocations will be freed if found
  2108. bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, struct clip_image_f32_batch * res_imgs) {
  2109. clip_image_size original_size{img->nx, img->ny};
  2110. bool pad_to_square = true;
  2111. auto & params = ctx->vision_model.hparams;
  2112. // The model config actually contains all we need to decide on how to preprocess, here we automatically switch to the new llava-1.6 preprocessing
  2113. if (params.mm_patch_merge_type == PATCH_MERGE_SPATIAL_UNPAD) {
  2114. pad_to_square = false;
  2115. }
  2116. if (clip_is_minicpmv(ctx)) {
  2117. auto const inst = llava_uhd::get_slice_instructions(ctx, original_size);
  2118. std::vector<clip_image_u8_ptr> imgs = llava_uhd::slice_image(img, inst);
  2119. for (size_t i = 0; i < imgs.size(); ++i) {
  2120. // clip_image_save_to_bmp(*imgs[i], "slice_" + std::to_string(i) + ".bmp");
  2121. clip_image_f32_ptr res(clip_image_f32_init());
  2122. normalize_image_u8_to_f32(*imgs[i], *res, ctx->image_mean, ctx->image_std);
  2123. res_imgs->entries.push_back(std::move(res));
  2124. }
  2125. return true;
  2126. }
  2127. else if (ctx->proj_type == PROJECTOR_TYPE_QWEN2VL) {
  2128. clip_image_u8 resized;
  2129. auto patch_size = clip_get_patch_size(ctx) * 2;
  2130. int nx = ceil((float)img->nx / patch_size) * patch_size;
  2131. int ny = ceil((float)img->ny / patch_size) * patch_size;
  2132. image_manipulation::bicubic_resize(*img, resized, nx, ny);
  2133. clip_image_f32_ptr img_f32(clip_image_f32_init());
  2134. // clip_image_f32_ptr res(clip_image_f32_init());
  2135. normalize_image_u8_to_f32(resized, *img_f32, ctx->image_mean, ctx->image_std);
  2136. // res_imgs->data[0] = *res;
  2137. res_imgs->entries.push_back(std::move(img_f32));
  2138. return true;
  2139. }
  2140. else if (ctx->proj_type == PROJECTOR_TYPE_GLM_EDGE
  2141. || ctx->proj_type == PROJECTOR_TYPE_GEMMA3
  2142. || ctx->proj_type == PROJECTOR_TYPE_IDEFICS3) {
  2143. clip_image_u8 resized_image;
  2144. int sz = params.image_size;
  2145. image_manipulation::resize_and_pad_image(*img, resized_image, {sz, sz});
  2146. clip_image_f32_ptr img_f32(clip_image_f32_init());
  2147. //clip_image_save_to_bmp(resized_image, "resized.bmp");
  2148. normalize_image_u8_to_f32(resized_image, *img_f32, ctx->image_mean, ctx->image_std);
  2149. res_imgs->entries.push_back(std::move(img_f32));
  2150. return true;
  2151. }
  2152. else if (ctx->proj_type == PROJECTOR_TYPE_PIXTRAL) {
  2153. clip_image_u8 resized_image;
  2154. auto new_size = image_manipulation::calc_size_preserved_ratio(original_size, params.patch_size, params.image_size);
  2155. image_manipulation::bilinear_resize(*img, resized_image, new_size.width, new_size.height);
  2156. clip_image_f32_ptr img_f32(clip_image_f32_init());
  2157. normalize_image_u8_to_f32(resized_image, *img_f32, ctx->image_mean, ctx->image_std);
  2158. res_imgs->entries.push_back(std::move(img_f32));
  2159. return true;
  2160. }
  2161. // the logic below is to pad the shorter side to the longer side with a background color: rgb(122, 116, 104)
  2162. // see https://github.com/haotian-liu/LLaVA/blob/e854a2bf85118c504f6f16bf5c3c7c92f8fa8c6b/llava/conversation.py#L113-L156
  2163. clip_image_u8_ptr temp(clip_image_u8_init()); // we will keep the input image data here temporarily
  2164. if (pad_to_square) {
  2165. // for llava-1.5, we resize image to a square, and pad the shorter side with a background color
  2166. // see https://github.com/haotian-liu/LLaVA/blob/e854a2bf85118c504f6f16bf5c3c7c92f8fa8c6b/llava/conversation.py#L113-L156
  2167. const int longer_side = std::max(img->nx, img->ny);
  2168. temp->nx = longer_side;
  2169. temp->ny = longer_side;
  2170. temp->buf.resize(3 * longer_side * longer_side);
  2171. // background color in RGB from LLaVA (this is the mean rgb color * 255)
  2172. const std::array<uint8_t, 3> pad_color = {122, 116, 104};
  2173. // resize the image to the target_size
  2174. image_manipulation::resize_and_pad_image(*img, *temp, clip_image_size{params.image_size, params.image_size}, pad_color);
  2175. clip_image_f32_ptr res(clip_image_f32_init());
  2176. normalize_image_u8_to_f32(*temp, *res, ctx->image_mean, ctx->image_std);
  2177. res_imgs->entries.push_back(std::move(res));
  2178. return true;
  2179. } else if (!params.image_grid_pinpoints.empty()) {
  2180. // "spatial_unpad" with "anyres" processing for llava-1.6
  2181. auto const inst = llava_uhd::get_slice_instructions(ctx, original_size);
  2182. std::vector<clip_image_u8_ptr> imgs = llava_uhd::slice_image(img, inst);
  2183. for (size_t i = 0; i < imgs.size(); ++i) {
  2184. // clip_image_save_to_bmp(*imgs[i], "slice_" + std::to_string(i) + ".bmp");
  2185. clip_image_f32_ptr res(clip_image_f32_init());
  2186. normalize_image_u8_to_f32(*imgs[i], *res, ctx->image_mean, ctx->image_std);
  2187. res_imgs->entries.push_back(std::move(res));
  2188. }
  2189. return true;
  2190. }
  2191. GGML_ASSERT(false && "Unknown image preprocessing type");
  2192. }
  2193. ggml_tensor * clip_get_newline_tensor(const struct clip_ctx * ctx) {
  2194. return ctx->vision_model.image_newline;
  2195. }
  2196. void clip_free(clip_ctx * ctx) {
  2197. if (ctx == nullptr) {
  2198. return;
  2199. }
  2200. delete ctx;
  2201. }
  2202. size_t clip_embd_nbytes(const struct clip_ctx * ctx) {
  2203. return clip_n_patches(ctx) * clip_n_mmproj_embd(ctx) * sizeof(float);
  2204. }
  2205. size_t clip_embd_nbytes_by_img(const struct clip_ctx * ctx, int img_h, int img_w) {
  2206. clip_image_f32 img;
  2207. img.nx = img_w;
  2208. img.ny = img_h;
  2209. return clip_n_patches_by_img(ctx, &img) * clip_n_mmproj_embd(ctx) * sizeof(float);
  2210. }
  2211. int32_t clip_get_image_size(const struct clip_ctx * ctx) {
  2212. return ctx->vision_model.hparams.image_size;
  2213. }
  2214. int32_t clip_get_patch_size(const struct clip_ctx * ctx) {
  2215. return ctx->vision_model.hparams.patch_size;
  2216. }
  2217. int32_t clip_get_hidden_size(const struct clip_ctx * ctx) {
  2218. return ctx->vision_model.hparams.hidden_size;
  2219. }
  2220. const char * clip_patch_merge_type(const struct clip_ctx * ctx) {
  2221. return ctx->vision_model.hparams.mm_patch_merge_type == PATCH_MERGE_SPATIAL_UNPAD ? "spatial_unpad" : "flat";
  2222. }
  2223. const int32_t * clip_image_grid(const struct clip_ctx * ctx) {
  2224. if (ctx->vision_model.hparams.image_grid_pinpoints.size()) {
  2225. return &ctx->vision_model.hparams.image_grid_pinpoints.front();
  2226. }
  2227. return nullptr;
  2228. }
  2229. size_t get_clip_image_grid_size(const struct clip_ctx * ctx) {
  2230. return ctx->vision_model.hparams.image_grid_pinpoints.size();
  2231. }
  2232. int clip_n_patches(const struct clip_ctx * ctx) {
  2233. clip_image_f32 img;
  2234. img.nx = ctx->vision_model.hparams.image_size;
  2235. img.ny = ctx->vision_model.hparams.image_size;
  2236. return clip_n_patches_by_img(ctx, &img);
  2237. }
  2238. int clip_n_patches_by_img(const struct clip_ctx * ctx, struct clip_image_f32 * img) {
  2239. const auto & params = ctx->vision_model.hparams;
  2240. int n_patches = (params.image_size / params.patch_size) * (params.image_size / params.patch_size);
  2241. if (ctx->proj_type == PROJECTOR_TYPE_LDP || ctx->proj_type == PROJECTOR_TYPE_LDPV2 || ctx->proj_type == PROJECTOR_TYPE_GLM_EDGE) {
  2242. n_patches /= 4;
  2243. } else if (ctx->proj_type == PROJECTOR_TYPE_MINICPMV) {
  2244. if (ctx->minicpmv_version == 2) {
  2245. n_patches = 96;
  2246. }
  2247. else if (ctx->minicpmv_version == 3) {
  2248. n_patches = 64;
  2249. }
  2250. else if (ctx->minicpmv_version == 4) {
  2251. n_patches = 64;
  2252. }
  2253. else {
  2254. GGML_ABORT("Unknown minicpmv version");
  2255. }
  2256. } else if (ctx->proj_type == PROJECTOR_TYPE_QWEN2VL) {
  2257. int patch_size = params.patch_size * 2;
  2258. int x_patch = img->nx / patch_size + (int)(img->nx % patch_size > 0);
  2259. int y_patch = img->ny / patch_size + (int)(img->ny % patch_size > 0);
  2260. n_patches = x_patch * y_patch;
  2261. } else if (ctx->proj_type == PROJECTOR_TYPE_GEMMA3) {
  2262. n_patches = 256;
  2263. } else if (ctx->proj_type == PROJECTOR_TYPE_IDEFICS3) {
  2264. n_patches /= ctx->vision_model.hparams.proj_scale_factor;
  2265. } else if (ctx->proj_type == PROJECTOR_TYPE_PIXTRAL) {
  2266. int n_patches_x = img->nx / params.patch_size;
  2267. int n_patches_y = img->ny / params.patch_size;
  2268. n_patches = n_patches_y*n_patches_x + n_patches_y - 1; // + one [IMG_BREAK] per row, except the last row
  2269. }
  2270. return n_patches;
  2271. }
  // 1-D sin/cos positional embedding for every value of the H x W grid `pos` (W taken from pos[0]).
  // With half = embed_dim / 2 and omega[d] = 1 / 10000^(d / half), position value p maps to
  //   out[d] = sin(p * omega[d]) and out[d + half] = cos(p * omega[d]) for d in [0, half).
  // Result shape: H x W x embed_dim. An empty `pos` yields an empty result.
  static std::vector<std::vector<std::vector<float>>> get_1d_sincos_pos_embed_from_grid_new(int embed_dim, const std::vector<std::vector<float>> & pos) {
      assert(embed_dim % 2 == 0);
      if (pos.empty()) {
          return {}; // avoid indexing pos[0] on an empty grid
      }
      int H = pos.size();
      int W = pos[0].size();

      std::vector<float> omega(embed_dim / 2);
      for (int i = 0; i < embed_dim / 2; ++i) {
          omega[i] = 1.0 / pow(10000.0, static_cast<float>(i) / (embed_dim / 2));
      }

      std::vector<std::vector<std::vector<float>>> emb(H, std::vector<std::vector<float>>(W, std::vector<float>(embed_dim)));
      for (int h = 0; h < H; ++h) {
          for (int w = 0; w < W; ++w) {
              for (int d = 0; d < embed_dim / 2; ++d) {
                  float out_value = pos[h][w] * omega[d];
                  emb[h][w][d] = sin(out_value);
                  emb[h][w][d + embed_dim / 2] = cos(out_value);
              }
          }
      }

      return emb;
  }
  2292. static std::vector<std::vector<std::vector<float>>> get_2d_sincos_pos_embed_from_grid(int embed_dim, const std::vector<std::vector<std::vector<float>>> & grid) {
  2293. assert(embed_dim % 2 == 0);
  2294. std::vector<std::vector<std::vector<float>>> emb_h = get_1d_sincos_pos_embed_from_grid_new(embed_dim / 2, grid[0]); // (H, W, D/2)
  2295. std::vector<std::vector<std::vector<float>>> emb_w = get_1d_sincos_pos_embed_from_grid_new(embed_dim / 2, grid[1]); // (H, W, D/2)
  2296. int H = emb_h.size();
  2297. int W = emb_h[0].size();
  2298. std::vector<std::vector<std::vector<float>>> emb(H, std::vector<std::vector<float>>(W, std::vector<float>(embed_dim)));
  2299. for (int h = 0; h < H; ++h) {
  2300. for (int w = 0; w < W; ++w) {
  2301. for (int d = 0; d < embed_dim / 2; ++d) {
  2302. emb[h][w][d] = emb_h[h][w][d];
  2303. emb[h][w][d + embed_dim / 2] = emb_w[h][w][d];
  2304. }
  2305. }
  2306. }
  2307. return emb;
  2308. }
  2309. static std::vector<std::vector<float>> get_2d_sincos_pos_embed(int embed_dim, const std::pair<int, int> image_size) {
  2310. int grid_h_size = image_size.first;
  2311. int grid_w_size = image_size.second;
  2312. std::vector<float> grid_h(grid_h_size);
  2313. std::vector<float> grid_w(grid_w_size);
  2314. for (int i = 0; i < grid_h_size; ++i) {
  2315. grid_h[i] = static_cast<float>(i);
  2316. }
  2317. for (int i = 0; i < grid_w_size; ++i) {
  2318. grid_w[i] = static_cast<float>(i);
  2319. }
  2320. std::vector<std::vector<float>> grid(grid_h_size, std::vector<float>(grid_w_size));
  2321. for (int h = 0; h < grid_h_size; ++h) {
  2322. for (int w = 0; w < grid_w_size; ++w) {
  2323. grid[h][w] = grid_w[w];
  2324. }
  2325. }
  2326. std::vector<std::vector<std::vector<float>>> grid_2d = {grid, grid};
  2327. for (int h = 0; h < grid_h_size; ++h) {
  2328. for (int w = 0; w < grid_w_size; ++w) {
  2329. grid_2d[0][h][w] = grid_h[h];
  2330. grid_2d[1][h][w] = grid_w[w];
  2331. }
  2332. }
  2333. std::vector<std::vector<std::vector<float>>> pos_embed_3d = get_2d_sincos_pos_embed_from_grid(embed_dim, grid_2d);
  2334. int H = image_size.first;
  2335. int W = image_size.second;
  2336. std::vector<std::vector<float>> pos_embed_2d(H * W, std::vector<float>(embed_dim));
  2337. for (int h = 0; h < H; ++h) {
  2338. for (int w = 0; w < W; ++w) {
  2339. pos_embed_2d[w * H + h] = pos_embed_3d[h][w];
  2340. }
  2341. }
  2342. return pos_embed_2d;
  2343. }
  2344. bool clip_image_encode(struct clip_ctx * ctx, const int n_threads, clip_image_f32 * img, float * vec) {
  2345. clip_image_f32_batch imgs;
  2346. clip_image_f32_ptr img_copy(clip_image_f32_init());
  2347. *img_copy = *img;
  2348. imgs.entries.push_back(std::move(img_copy));
  2349. return clip_image_batch_encode(ctx, n_threads, &imgs, vec);
  2350. }
  2351. bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_image_f32_batch * imgs_c_ptr, float * vec) {
  2352. const clip_image_f32_batch & imgs = *imgs_c_ptr;
  2353. int batch_size = imgs.entries.size();
  2354. if (ctx->has_llava_projector
  2355. || ctx->proj_type == PROJECTOR_TYPE_MINICPMV
  2356. || ctx->proj_type == PROJECTOR_TYPE_GLM_EDGE) {
  2357. GGML_ASSERT(batch_size == 1);
  2358. }
  2359. // build the inference graph
  2360. ggml_backend_sched_reset(ctx->sched.get());
  2361. ggml_cgraph * gf = clip_image_build_graph(ctx, imgs, ctx->load_image_size, true);
  2362. ggml_backend_sched_alloc_graph(ctx->sched.get(), gf);
  2363. // set inputs
  2364. const auto & model = ctx->vision_model;
  2365. const auto & hparams = model.hparams;
  2366. const int image_size_width = imgs.entries[0]->nx;
  2367. const int image_size_height = imgs.entries[0]->ny;
  2368. const int patch_size = hparams.patch_size;
  2369. const int num_patches = ((image_size_width / patch_size) * (image_size_height / patch_size));
  2370. const int num_positions = num_patches + (model.class_embedding ? 1 : 0);
  2371. const int pos_w = ctx->load_image_size.width / patch_size;
  2372. const int pos_h = ctx->load_image_size.height / patch_size;
  2373. {
  2374. struct ggml_tensor * inp_raw = ggml_graph_get_tensor(gf, "inp_raw");
  2375. std::vector<float> inp_data(ggml_nelements(inp_raw));
  2376. float * data = inp_data.data();
  2377. // layout of data (note: the channel dim is unrolled to better visualize the layout):
  2378. //
  2379. // ┌──W──┐
  2380. // │ H │ channel = R
  2381. // ├─────┤ │
  2382. // │ H │ channel = G
  2383. // ├─────┤ │
  2384. // │ H │ channel = B
  2385. // └─────┘ │
  2386. // ──────┘ x B
  2387. for (size_t i = 0; i < imgs.entries.size(); i++) {
  2388. const int nx = imgs.entries[i]->nx;
  2389. const int ny = imgs.entries[i]->ny;
  2390. const int n = nx * ny;
  2391. for (int b = 0; b < batch_size; b++) {
  2392. float * batch_entry = data + b * (3*n);
  2393. for (int y = 0; y < ny; y++) {
  2394. for (int x = 0; x < nx; x++) {
  2395. size_t base_src = 3*(y * nx + x); // idx of the first channel
  2396. size_t base_dst = y * nx + x; // idx of the first channel
  2397. batch_entry[ base_dst] = imgs.entries[b]->buf[base_src ];
  2398. batch_entry[1*n + base_dst] = imgs.entries[b]->buf[base_src + 1];
  2399. batch_entry[2*n + base_dst] = imgs.entries[b]->buf[base_src + 2];
  2400. }
  2401. }
  2402. }
  2403. }
  2404. ggml_backend_tensor_set(inp_raw, data, 0, ggml_nbytes(inp_raw));
  2405. }
  2406. if (ctx->proj_type == PROJECTOR_TYPE_MINICPMV) {
  2407. {
  2408. // inspired from siglip:
  2409. // -> https://huggingface.co/HuggingFaceM4/siglip-so400m-14-980-flash-attn2-navit
  2410. // -> https://huggingface.co/HuggingFaceM4/siglip-so400m-14-980-flash-attn2-navit/blob/d66538faeba44480d0bfaa42145eef26f9423199/modeling_siglip.py#L316
  2411. struct ggml_tensor * positions = ggml_graph_get_tensor(gf, "positions");
  2412. std::vector<int> pos_data(ggml_nelements(positions));
  2413. int * data = pos_data.data();
  2414. int bucket_coords_h[1024];
  2415. int bucket_coords_w[1024];
  2416. for (int i = 0; i < pos_h; i++){
  2417. bucket_coords_h[i] = std::floor(70.0*i/pos_h);
  2418. }
  2419. for (int i = 0; i < pos_w; i++){
  2420. bucket_coords_w[i] = std::floor(70.0*i/pos_w);
  2421. }
  2422. for (int i = 0, id = 0; i < pos_h; i++){
  2423. for (int j = 0; j < pos_w; j++){
  2424. data[id++] = bucket_coords_h[i]*70 + bucket_coords_w[j];
  2425. }
  2426. }
  2427. ggml_backend_tensor_set(positions, data, 0, ggml_nbytes(positions));
  2428. }
  2429. {
  2430. // inspired from resampler of Qwen-VL:
  2431. // -> https://huggingface.co/Qwen/Qwen-VL/tree/main
  2432. // -> https://huggingface.co/Qwen/Qwen-VL/blob/0547ed36a86561e2e42fecec8fd0c4f6953e33c4/visual.py#L23
  2433. struct ggml_tensor * pos_embed = ggml_graph_get_tensor(gf, "pos_embed");
  2434. int embed_dim = 4096;
  2435. if (ctx->minicpmv_version == 2) {
  2436. embed_dim = 4096;
  2437. }
  2438. else if (ctx->minicpmv_version == 3) {
  2439. embed_dim = 3584;
  2440. }
  2441. else if (ctx->minicpmv_version == 4) {
  2442. embed_dim = 3584;
  2443. }
  2444. else {
  2445. GGML_ABORT("Unknown minicpmv version");
  2446. }
  2447. // TODO @ngxson : this is very inefficient, can we do this using ggml_sin and ggml_cos?
  2448. auto pos_embed_t = get_2d_sincos_pos_embed(embed_dim, std::make_pair(pos_w, pos_h));
  2449. std::vector<float> pos_data(ggml_nelements(pos_embed));
  2450. float * data = pos_data.data();
  2451. for(int i = 0; i < pos_w * pos_h; ++i){
  2452. for(int j = 0; j < embed_dim; ++j){
  2453. data[i * embed_dim + j] = pos_embed_t[i][j];
  2454. }
  2455. }
  2456. ggml_backend_tensor_set(pos_embed, data, 0, ggml_nbytes(pos_embed));
  2457. }
  2458. }
  2459. else {
  2460. // non-minicpmv models
  2461. if (ctx->proj_type == PROJECTOR_TYPE_QWEN2VL) {
  2462. struct ggml_tensor * positions = ggml_graph_get_tensor(gf, "positions");
  2463. const int pw = image_size_width / patch_size;
  2464. const int ph = image_size_height / patch_size;
  2465. int* positions_data = (int*)malloc(ggml_nbytes(positions));
  2466. int ptr = 0;
  2467. for (int y = 0; y < ph; y+=2)
  2468. {
  2469. for (int x = 0; x < pw; x+=2)
  2470. {
  2471. for (int dy = 0; dy < 2; dy++) {
  2472. for (int dx = 0; dx < 2; dx++) {
  2473. positions_data[ptr] = y + dy;
  2474. positions_data[num_patches + ptr] = x + dx;
  2475. positions_data[num_patches * 2 + ptr] = y + dy;
  2476. positions_data[num_patches * 3 + ptr] = x + dx;
  2477. ptr++;
  2478. }
  2479. }
  2480. }
  2481. }
  2482. ggml_backend_tensor_set(positions, positions_data, 0, ggml_nbytes(positions));
  2483. free(positions_data);
  2484. }
  2485. else if (ctx->proj_type == PROJECTOR_TYPE_GEMMA3) {
  2486. // do nothing
  2487. }
  2488. else if (ctx->proj_type == PROJECTOR_TYPE_IDEFICS3) {
  2489. // do nothing
  2490. }
  2491. else if (ctx->proj_type == PROJECTOR_TYPE_PIXTRAL) {
  2492. // set the 2D positions
  2493. int n_patches_per_col = image_size_width / patch_size;
  2494. std::vector<int> pos_data(num_positions);
  2495. struct ggml_tensor * pos;
  2496. // dimension H
  2497. pos = ggml_graph_get_tensor(gf, "pos_h");
  2498. for (int i = 0; i < num_positions; i++) {
  2499. pos_data[i] = i / n_patches_per_col;
  2500. }
  2501. ggml_backend_tensor_set(pos, pos_data.data(), 0, ggml_nbytes(pos));
  2502. // dimension W
  2503. pos = ggml_graph_get_tensor(gf, "pos_w");
  2504. for (int i = 0; i < num_positions; i++) {
  2505. pos_data[i] = i % n_patches_per_col;
  2506. }
  2507. ggml_backend_tensor_set(pos, pos_data.data(), 0, ggml_nbytes(pos));
  2508. }
  2509. else {
  2510. // llava and other models
  2511. struct ggml_tensor * positions = ggml_graph_get_tensor(gf, "positions");
  2512. int* positions_data = (int*)malloc(ggml_nbytes(positions));
  2513. for (int i = 0; i < num_positions; i++) {
  2514. positions_data[i] = i;
  2515. }
  2516. ggml_backend_tensor_set(positions, positions_data, 0, ggml_nbytes(positions));
  2517. free(positions_data);
  2518. if (ctx->proj_type != PROJECTOR_TYPE_GLM_EDGE) {
  2519. struct ggml_tensor * patches = ggml_graph_get_tensor(gf, "patches");
  2520. // The patches vector is used to get rows to index into the embeds with;
  2521. // we should skip dim 0 only if we have CLS to avoid going out of bounds
  2522. // when retrieving the rows.
  2523. int patch_offset = model.class_embedding ? 1 : 0;
  2524. int* patches_data = (int*)malloc(ggml_nbytes(patches));
  2525. for (int i = 0; i < num_patches; i++) {
  2526. patches_data[i] = i + patch_offset;
  2527. }
  2528. ggml_backend_tensor_set(patches, patches_data, 0, ggml_nbytes(patches));
  2529. free(patches_data);
  2530. }
  2531. }
  2532. }
  2533. ggml_backend_cpu_set_n_threads(ctx->backend_cpu, n_threads);
  2534. auto status = ggml_backend_sched_graph_compute(ctx->sched.get(), gf);
  2535. if (status != GGML_STATUS_SUCCESS) {
  2536. LOG_ERR("%s: ggml_backend_sched_graph_compute failed with error %d\n", __func__, status);
  2537. return false;
  2538. }
  2539. // the last node is the embedding tensor
  2540. struct ggml_tensor * embeddings = ggml_graph_node(gf, -1);
  2541. // copy the embeddings to the location passed by the user
  2542. ggml_backend_tensor_get(embeddings, vec, 0, ggml_nbytes(embeddings));
  2543. return true;
  2544. }
  2545. bool clip_model_quantize(const char * fname_inp, const char * fname_out, const int itype) {
  2546. assert(itype < GGML_TYPE_COUNT);
  2547. ggml_type type = static_cast<ggml_type>(itype);
  2548. auto * ctx_clip = clip_init(fname_inp, clip_context_params{
  2549. /* use_gpu */ false,
  2550. /* verbosity */ GGML_LOG_LEVEL_ERROR,
  2551. });
  2552. const auto & ctx_src = ctx_clip->ctx_gguf.get();
  2553. const auto & ctx_data = ctx_clip->ctx_data.get();
  2554. auto * ctx_out = gguf_init_empty();
  2555. gguf_set_kv(ctx_out, ctx_src);
  2556. gguf_set_val_u32(ctx_out, "general.quantization_version", GGML_QNT_VERSION);
  2557. gguf_set_val_u32(ctx_out, "general.file_type", itype);
  2558. auto fout = std::ofstream(fname_out, std::ios::binary);
  2559. const int n_tensors = gguf_get_n_tensors(ctx_src);
  2560. for (int i = 0; i < n_tensors; ++i) {
  2561. const char * name = gguf_get_tensor_name(ctx_src, i);
  2562. struct ggml_tensor * cur = ggml_get_tensor(ctx_data, name);
  2563. gguf_add_tensor(ctx_out, cur);
  2564. }
  2565. const size_t meta_size = gguf_get_meta_size(ctx_out);
  2566. for (size_t i = 0; i < meta_size; ++i) {
  2567. fout.put(0);
  2568. }
  2569. // regexes of tensor names to be quantized
  2570. const std::vector<std::string> k_names = {
  2571. ".*weight",
  2572. };
  2573. std::vector<uint8_t> work(512);
  2574. std::vector<float> conv_buf(512);
  2575. size_t total_size_org = 0;
  2576. size_t total_size_new = 0;
  2577. for (int i = 0; i < n_tensors; ++i) {
  2578. const std::string name = gguf_get_tensor_name(ctx_src, i);
  2579. struct ggml_tensor * cur = ggml_get_tensor(ctx_data, name.c_str());
  2580. enum ggml_type new_type;
  2581. void * new_data;
  2582. size_t new_size;
  2583. bool quantize = false;
  2584. for (const auto & s : k_names) {
  2585. if (std::regex_match(name, std::regex(s))) {
  2586. quantize = true;
  2587. break;
  2588. }
  2589. }
  2590. // quantize only 2D tensors and bigger than block size
  2591. quantize &= (ggml_n_dims(cur) == 2) && cur->ne[0] > ggml_blck_size(type);
  2592. if (quantize) {
  2593. new_type = type;
  2594. if (new_type >= GGML_TYPE_Q2_K && name.find("embd") != std::string::npos) {
  2595. new_type = GGML_TYPE_Q8_0; // ggml_get_rows needs non K type
  2596. // LOG_ERR("%s: quantizing %s to %s\n", __func__, name.c_str(), ggml_type_name(new_type));
  2597. }
  2598. const size_t n_elms = ggml_nelements(cur);
  2599. float * f32_data;
  2600. switch (cur->type) {
  2601. case GGML_TYPE_F32:
  2602. f32_data = (float *)cur->data;
  2603. break;
  2604. case GGML_TYPE_F16:
  2605. if (conv_buf.size() < n_elms) {
  2606. conv_buf.resize(n_elms);
  2607. }
  2608. for (size_t j = 0; j < n_elms; ++j) {
  2609. conv_buf[j] = ggml_fp16_to_fp32(((ggml_fp16_t *)cur->data)[j]);
  2610. }
  2611. f32_data = (float *)conv_buf.data();
  2612. break;
  2613. default:
  2614. LOG_ERR("%s: Please use an input file in f32 or f16\n", __func__);
  2615. gguf_free(ctx_out);
  2616. return false;
  2617. }
  2618. if (work.size() < n_elms * 4) {
  2619. work.resize(n_elms * 4);
  2620. }
  2621. new_data = work.data();
  2622. new_size = ggml_quantize_chunk(new_type, f32_data, new_data, 0, n_elms/cur->ne[0], cur->ne[0], nullptr);
  2623. } else {
  2624. new_type = cur->type;
  2625. new_data = cur->data;
  2626. new_size = ggml_nbytes(cur);
  2627. }
  2628. const size_t orig_size = ggml_nbytes(cur);
  2629. total_size_org += orig_size;
  2630. total_size_new += new_size;
  2631. gguf_set_tensor_type(ctx_out, name.c_str(), new_type);
  2632. GGML_ASSERT(gguf_get_tensor_size(ctx_out, gguf_find_tensor(ctx_out, name.c_str())) == new_size);
  2633. gguf_set_tensor_data(ctx_out, name.c_str(), new_data);
  2634. fout.write((const char *)new_data, new_size);
  2635. size_t pad = GGML_PAD(new_size, gguf_get_alignment(ctx_out)) - new_size;
  2636. for (size_t j = 0; j < pad; ++j) {
  2637. fout.put(0);
  2638. }
  2639. LOG_INF("%s: n_dims = %d | quantize=%d | size = %f MB -> %f MB\n", name.c_str(), ggml_n_dims(cur), quantize,
  2640. orig_size / 1024.0 / 1024.0, new_size / 1024.0 / 1024.0);
  2641. }
  2642. // go back to beginning of file and write the updated metadata
  2643. fout.seekp(0, std::ios::beg);
  2644. std::vector<uint8_t> meta(meta_size);
  2645. gguf_get_meta_data(ctx_out, meta.data());
  2646. fout.write((const char *)meta.data(), meta_size);
  2647. fout.close();
  2648. clip_free(ctx_clip);
  2649. gguf_free(ctx_out);
  2650. {
  2651. LOG_INF("%s: original size = %8.2f MB\n", __func__, total_size_org / 1024.0 / 1024.0);
  2652. LOG_INF("%s: quantized size = %8.2f MB\n", __func__, total_size_new / 1024.0 / 1024.0);
  2653. }
  2654. return true;
  2655. }
  2656. int clip_n_mmproj_embd(const struct clip_ctx * ctx) {
  2657. switch (ctx->proj_type) {
  2658. case PROJECTOR_TYPE_LDP:
  2659. return ctx->vision_model.mm_model_block_1_block_2_1_b->ne[0];
  2660. case PROJECTOR_TYPE_LDPV2:
  2661. return ctx->vision_model.mm_model_peg_0_b->ne[0];
  2662. case PROJECTOR_TYPE_MLP:
  2663. case PROJECTOR_TYPE_PIXTRAL:
  2664. return ctx->vision_model.mm_2_b->ne[0];
  2665. case PROJECTOR_TYPE_MLP_NORM:
  2666. return ctx->vision_model.mm_3_b->ne[0];
  2667. case PROJECTOR_TYPE_MINICPMV:
  2668. if (ctx->minicpmv_version == 2) {
  2669. return 4096;
  2670. } else if (ctx->minicpmv_version == 3) {
  2671. return 3584;
  2672. } else if (ctx->minicpmv_version == 4) {
  2673. return 3584;
  2674. }
  2675. GGML_ABORT("Unknown minicpmv version");
  2676. case PROJECTOR_TYPE_GLM_EDGE:
  2677. return ctx->vision_model.mm_model_mlp_3_w->ne[1];
  2678. case PROJECTOR_TYPE_QWEN2VL:
  2679. return ctx->vision_model.mm_1_b->ne[0];
  2680. case PROJECTOR_TYPE_GEMMA3:
  2681. return ctx->vision_model.mm_input_proj_w->ne[0];
  2682. case PROJECTOR_TYPE_IDEFICS3:
  2683. return ctx->vision_model.projection->ne[1];
  2684. default:
  2685. GGML_ABORT("Unknown projector type");
  2686. }
  2687. }
  2688. int clip_is_minicpmv(const struct clip_ctx * ctx) {
  2689. if (ctx->proj_type == PROJECTOR_TYPE_MINICPMV) {
  2690. return ctx->minicpmv_version;
  2691. }
  2692. return 0;
  2693. }
  2694. bool clip_is_glm(const struct clip_ctx * ctx) {
  2695. return ctx->proj_type == PROJECTOR_TYPE_GLM_EDGE;
  2696. }
  2697. bool clip_is_qwen2vl(const struct clip_ctx * ctx) {
  2698. return ctx->proj_type == PROJECTOR_TYPE_QWEN2VL;
  2699. }
  2700. bool clip_is_llava(const struct clip_ctx * ctx) {
  2701. return ctx->has_llava_projector;
  2702. }
  2703. bool clip_is_gemma3(const struct clip_ctx * ctx) {
  2704. return ctx->proj_type == PROJECTOR_TYPE_GEMMA3;
  2705. }
  2706. bool clip_encode_float_image (struct clip_ctx * ctx, int n_threads, float * img, int h, int w, float * vec) {
  2707. clip_image_f32 clip_img;
  2708. clip_img.buf.resize(h * w * 3);
  2709. for (int i = 0; i < h*w*3; i++)
  2710. {
  2711. clip_img.buf[i] = img[i];
  2712. }
  2713. clip_img.nx = w;
  2714. clip_img.ny = h;
  2715. clip_image_encode(ctx, n_threads, &clip_img, vec);
  2716. return true;
  2717. }
  2718. //
  2719. // API used internally with mtmd
  2720. //
  2721. projector_type clip_get_projector_type(const struct clip_ctx * ctx) {
  2722. return ctx->proj_type;
  2723. }