// ggml.c

#define _CRT_SECURE_NO_DEPRECATE // Disables "unsafe" warnings on Windows
#define _USE_MATH_DEFINES // For M_PI on MSVC

#include "ggml-backend.h"
#include "ggml-impl.h"
#include "ggml-threading.h"
#include "ggml-cpu.h"
#include "ggml.h"

// FIXME: required here for quantization functions
#include "ggml-quants.h"

#ifdef GGML_USE_CPU_HBM
#include <hbwmalloc.h>
#endif

#if defined(_MSC_VER) || defined(__MINGW32__)
#include <malloc.h> // using malloc.h with MSC/MINGW
#elif !defined(__FreeBSD__) && !defined(__NetBSD__) && !defined(__OpenBSD__)
#include <alloca.h>
#endif

#include <assert.h>
#include <errno.h>
#include <time.h>
#include <math.h>
#include <stdlib.h>
#include <string.h>
#include <stdint.h>
#include <inttypes.h>
#include <stdio.h>
#include <float.h>
#include <limits.h>
#include <stdarg.h>
#include <signal.h>

#if defined(__gnu_linux__)
#include <syscall.h>
#endif

#if defined(__APPLE__)
#include <unistd.h>
#include <mach/mach.h>
#include <TargetConditionals.h>
#endif

#if defined(_WIN32)
#define WIN32_LEAN_AND_MEAN
#ifndef NOMINMAX
    #define NOMINMAX
#endif
#include <windows.h>
#endif

#define UNUSED GGML_UNUSED

#if defined(_MSC_VER)
#define m512bh(p) p
#define m512i(p) p
#else
#define m512bh(p) (__m512bh)(p)
#define m512i(p) (__m512i)(p)
#endif

#if defined(__linux__) || \
    defined(__FreeBSD__) || defined(__NetBSD__) || defined(__OpenBSD__) || \
    (defined(__APPLE__) && !TARGET_OS_TV && !TARGET_OS_WATCH)

#include <unistd.h>
#include <sys/types.h>
#include <sys/stat.h>
#include <sys/wait.h>
#if defined(__linux__)
#include <sys/prctl.h>
#endif

#if defined(__ANDROID__)
#include <unwind.h>
#include <dlfcn.h>
#include <stdio.h>

struct backtrace_state {
    void ** current;
    void ** end;
};

static _Unwind_Reason_Code unwind_callback(struct _Unwind_Context* context, void* arg) {
    struct backtrace_state * state = (struct backtrace_state *)arg;
    uintptr_t pc = _Unwind_GetIP(context);
    if (pc) {
        if (state->current == state->end) {
            return _URC_END_OF_STACK;
        } else {
            *state->current++ = (void*)pc;
        }
    }
    return _URC_NO_REASON;
}

static void ggml_print_backtrace_symbols(void) {
    const int max = 100;
    void* buffer[max];
    struct backtrace_state state = {buffer, buffer + max};
    _Unwind_Backtrace(unwind_callback, &state);

    int count = state.current - buffer;

    for (int idx = 0; idx < count; ++idx) {
        const void * addr = buffer[idx];
        const char * symbol = "";

        Dl_info info;
        if (dladdr(addr, &info) && info.dli_sname) {
            symbol = info.dli_sname;
        }

        fprintf(stderr, "%d: %p %s\n", idx, addr, symbol);
    }
}
#elif defined(__linux__) && defined(__GLIBC__)
#include <execinfo.h>
static void ggml_print_backtrace_symbols(void) {
    void * trace[100];
    int nptrs = backtrace(trace, sizeof(trace)/sizeof(trace[0]));
    backtrace_symbols_fd(trace, nptrs, STDERR_FILENO);
}
#else
static void ggml_print_backtrace_symbols(void) {
    // platform not supported
}
#endif

void ggml_print_backtrace(void) {
    const char * GGML_NO_BACKTRACE = getenv("GGML_NO_BACKTRACE");
    if (GGML_NO_BACKTRACE) {
        return;
    }
#if defined(__linux__)
    FILE * f = fopen("/proc/self/status", "r");
    size_t size = 0;
    char * line = NULL;
    ssize_t length = 0;
    while ((length = getline(&line, &size, f)) > 0) {
        if (!strncmp(line, "TracerPid:", sizeof("TracerPid:") - 1) &&
            (length != sizeof("TracerPid:\t0\n") - 1 || line[length - 2] != '0')) {
            // Already being debugged, and the breakpoint is the later abort()
            free(line);
            fclose(f);
            return;
        }
    }
    free(line);
    fclose(f);
    int lock[2] = { -1, -1 };
    (void) !pipe(lock); // Don't start gdb until after PR_SET_PTRACER
#endif
    const int parent_pid = getpid();
    const int child_pid = fork();
    if (child_pid < 0) { // error
#if defined(__linux__)
        close(lock[1]);
        close(lock[0]);
#endif
        return;
    } else if (child_pid == 0) { // child
        char attach[32];
        snprintf(attach, sizeof(attach), "attach %d", parent_pid);
#if defined(__linux__)
        close(lock[1]);
        (void) !read(lock[0], lock, 1);
        close(lock[0]);
#endif
        // try gdb
        execlp("gdb", "gdb", "--batch",
            "-ex", "set style enabled on",
            "-ex", attach,
            "-ex", "bt -frame-info source-and-location",
            "-ex", "detach",
            "-ex", "quit",
            (char *) NULL);
        // try lldb
        execlp("lldb", "lldb", "--batch",
            "-o", "bt",
            "-o", "quit",
            "-p", &attach[sizeof("attach ") - 1],
            (char *) NULL);
        // gdb failed, fallback to backtrace_symbols
        ggml_print_backtrace_symbols();
        _Exit(0);
    } else { // parent
#if defined(__linux__)
        prctl(PR_SET_PTRACER, child_pid);
        close(lock[1]);
        close(lock[0]);
#endif
        waitpid(child_pid, NULL, 0);
    }
}
#else
void ggml_print_backtrace(void) {
    // platform not supported
}
#endif
static ggml_abort_callback_t g_abort_callback = NULL;

// Set the abort callback (passing NULL restores the default behavior: printing the message and a backtrace to stderr)
GGML_API ggml_abort_callback_t ggml_set_abort_callback(ggml_abort_callback_t callback) {
    ggml_abort_callback_t ret_val = g_abort_callback;
    g_abort_callback = callback;
    return ret_val;
}

void ggml_abort(const char * file, int line, const char * fmt, ...) {
    fflush(stdout);

    char message[2048];
    int offset = snprintf(message, sizeof(message), "%s:%d: ", file, line);

    va_list args;
    va_start(args, fmt);
    vsnprintf(message + offset, sizeof(message) - offset, fmt, args);
    va_end(args);

    if (g_abort_callback) {
        g_abort_callback(message);
    } else {
        // default: print error and backtrace to stderr
        fprintf(stderr, "%s\n", message);
        ggml_print_backtrace();
    }

    abort();
}
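
// Usage sketch (illustrative, not part of the upstream code): an application can
// intercept fatal errors before the process aborts by installing its own handler.
// The handler name and its body are hypothetical; the callback is assumed to
// receive the already-formatted "file:line: message" string built above.
//
//   static void my_abort_handler(const char * msg) {   // hypothetical handler
//       fprintf(stderr, "[ggml fatal] %s\n", msg);      // e.g. forward to a crash log
//   }
//
//   ggml_abort_callback_t prev = ggml_set_abort_callback(my_abort_handler);
//   // a later GGML_ABORT("unexpected type %d", type) now reaches my_abort_handler
//   // before abort(); pass NULL to restore the default stderr + backtrace path.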
// ggml_print_backtrace is registered with std::set_terminate by ggml.cpp

//
// logging
//

struct ggml_logger_state {
    ggml_log_callback log_callback;
    void * log_callback_user_data;
};

static struct ggml_logger_state g_logger_state = {ggml_log_callback_default, NULL};

static void ggml_log_internal_v(enum ggml_log_level level, const char * format, va_list args) {
    if (format == NULL) {
        return;
    }
    va_list args_copy;
    va_copy(args_copy, args);
    char buffer[128];
    int len = vsnprintf(buffer, 128, format, args);
    if (len < 128) {
        g_logger_state.log_callback(level, buffer, g_logger_state.log_callback_user_data);
    } else {
        char * buffer2 = (char *) calloc(len + 1, sizeof(char));
        vsnprintf(buffer2, len + 1, format, args_copy);
        buffer2[len] = 0;
        g_logger_state.log_callback(level, buffer2, g_logger_state.log_callback_user_data);
        free(buffer2);
    }
    va_end(args_copy);
}

void ggml_log_internal(enum ggml_log_level level, const char * format, ...) {
    va_list args;
    va_start(args, format);
    ggml_log_internal_v(level, format, args);
    va_end(args);
}

void ggml_log_callback_default(enum ggml_log_level level, const char * text, void * user_data) {
    (void) level;
    (void) user_data;
    fputs(text, stderr);
    fflush(stderr);
}

//
// end of logging block
//
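
// Usage sketch (illustrative): messages up to 127 characters are formatted into the
// stack buffer above, longer ones take the calloc fallback. A custom sink can be
// installed via ggml_log_set() (declared in ggml.h, not in this file); the GGML_LOG_WARN /
// GGML_LOG_ERROR macros used below then route through it:
//
//   static void my_log_sink(enum ggml_log_level level, const char * text, void * ud) {
//       FILE * f = (FILE *) ud;                     // hypothetical user_data: a log stream
//       fprintf(f, "[%d] %s", (int) level, text);
//   }
//
//   ggml_log_set(my_log_sink, stderr);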
#ifdef GGML_USE_ACCELERATE
// uncomment to use vDSP for soft max computation
// note: not sure if it is actually faster
//#define GGML_SOFT_MAX_ACCELERATE
#endif

void * ggml_aligned_malloc(size_t size) {
#if defined(__s390x__)
    const int alignment = 256;
#else
    const int alignment = 64;
#endif

#if defined(_MSC_VER) || defined(__MINGW32__)
    return _aligned_malloc(size, alignment);
#else
    if (size == 0) {
        GGML_LOG_WARN("Behavior may be unexpected when allocating 0 bytes for ggml_aligned_malloc!\n");
        return NULL;
    }
    void * aligned_memory = NULL;
#ifdef GGML_USE_CPU_HBM
    int result = hbw_posix_memalign(&aligned_memory, alignment, size);
#elif TARGET_OS_OSX
    GGML_UNUSED(alignment);
    kern_return_t alloc_status = vm_allocate((vm_map_t) mach_task_self(), (vm_address_t *) &aligned_memory, size, VM_FLAGS_ANYWHERE);
    int result = EFAULT;
    switch (alloc_status) {
        case KERN_SUCCESS:
            result = 0;
            break;
        case KERN_INVALID_ADDRESS:
            result = EINVAL;
            break;
        case KERN_NO_SPACE:
            result = ENOMEM;
            break;
        default:
            result = EFAULT;
            break;
    }
#else
    int result = posix_memalign(&aligned_memory, alignment, size);
#endif
    if (result != 0) {
        // Handle allocation failure
        const char *error_desc = "unknown allocation error";
        switch (result) {
            case EINVAL:
                error_desc = "invalid alignment value";
                break;
            case ENOMEM:
                error_desc = "insufficient memory";
                break;
        }
        GGML_LOG_ERROR("%s: %s (attempted to allocate %6.2f MB)\n", __func__, error_desc, size/(1024.0*1024.0));
        return NULL;
    }
    return aligned_memory;
#endif
}

void ggml_aligned_free(void * ptr, size_t size) {
    GGML_UNUSED(size);
#if defined(_MSC_VER) || defined(__MINGW32__)
    _aligned_free(ptr);
#elif GGML_USE_CPU_HBM
    if (ptr != NULL) {
        hbw_free(ptr);
    }
#elif TARGET_OS_OSX
    if (ptr != NULL) {
        vm_deallocate((vm_map_t)mach_task_self(), (vm_address_t)ptr, size);
    }
#else
    free(ptr);
#endif
}
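
// Usage sketch (illustrative): callers should keep track of the allocation size,
// since ggml_aligned_free() needs it for the vm_deallocate() path on macOS (it is
// unused on the other branches):
//
//   size_t buf_size = 16u*1024*1024;
//   void * buf = ggml_aligned_malloc(buf_size);   // 64-byte aligned (256 on s390x)
//   if (buf != NULL) {
//       // ... use buf ...
//       ggml_aligned_free(buf, buf_size);         // pass the same size back
//   }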
inline static void * ggml_malloc(size_t size) {
    if (size == 0) {
        GGML_LOG_WARN("Behavior may be unexpected when allocating 0 bytes for ggml_malloc!\n");
        return NULL;
    }
    void * result = malloc(size);
    if (result == NULL) {
        GGML_LOG_ERROR("%s: failed to allocate %6.2f MB\n", __func__, size/(1024.0*1024.0));
        GGML_ABORT("fatal error");
    }
    return result;
}

// calloc
inline static void * ggml_calloc(size_t num, size_t size) {
    if (num == 0 || size == 0) {
        GGML_LOG_WARN("Behavior may be unexpected when allocating 0 bytes for ggml_calloc!\n");
        return NULL;
    }
    void * result = calloc(num, size);
    if (result == NULL) {
        GGML_LOG_ERROR("%s: failed to allocate %6.2f MB\n", __func__, size/(1024.0*1024.0));
        GGML_ABORT("fatal error");
    }
    return result;
}

#define GGML_MALLOC(size)      ggml_malloc(size)
#define GGML_CALLOC(num, size) ggml_calloc(num, size)
#define GGML_FREE(ptr)         free(ptr)

const char * ggml_status_to_string(enum ggml_status status) {
    switch (status) {
        case GGML_STATUS_ALLOC_FAILED: return "GGML status: error (failed to allocate memory)";
        case GGML_STATUS_FAILED:       return "GGML status: error (operation failed)";
        case GGML_STATUS_SUCCESS:      return "GGML status: success";
        case GGML_STATUS_ABORTED:      return "GGML status: warning (operation aborted)";
    }
    return "GGML status: unknown";
}
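
// Usage sketch (illustrative): ggml_status values returned by graph-compute style
// APIs can be turned into human-readable text:
//
//   enum ggml_status st = GGML_STATUS_ALLOC_FAILED;   // stand-in for a real return value
//   fprintf(stderr, "%s\n", ggml_status_to_string(st));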
float ggml_fp16_to_fp32(ggml_fp16_t x) {
#define ggml_fp16_to_fp32 do_not_use__ggml_fp16_to_fp32__in_ggml
    return GGML_FP16_TO_FP32(x);
}

ggml_fp16_t ggml_fp32_to_fp16(float x) {
#define ggml_fp32_to_fp16 do_not_use__ggml_fp32_to_fp16__in_ggml
    return GGML_FP32_TO_FP16(x);
}

float ggml_bf16_to_fp32(ggml_bf16_t x) {
#define ggml_bf16_to_fp32 do_not_use__ggml_bf16_to_fp32__in_ggml
    return GGML_BF16_TO_FP32(x);  // it just left shifts
}

ggml_bf16_t ggml_fp32_to_bf16(float x) {
#define ggml_fp32_to_bf16 do_not_use__ggml_fp32_to_bf16__in_ggml
    return GGML_FP32_TO_BF16(x);
}
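
// Note (added explanation): the #define inside each wrapper above rebinds the function
// name to a deliberately undeclared identifier for the rest of this translation unit,
// so any later call to e.g. ggml_fp16_to_fp32 inside ggml.c fails to compile. Internal
// code is thereby forced to use the GGML_FP16_TO_FP32 / GGML_FP32_TO_FP16 /
// GGML_BF16_TO_FP32 / GGML_FP32_TO_BF16 macros, while the out-of-line functions remain
// as the public entry points declared in ggml.h.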
void ggml_fp16_to_fp32_row(const ggml_fp16_t * x, float * y, int64_t n) {
    for (int64_t i = 0; i < n; i++) {
        y[i] = GGML_FP16_TO_FP32(x[i]);
    }
}

void ggml_fp32_to_fp16_row(const float * x, ggml_fp16_t * y, int64_t n) {
    int i = 0;
    for (; i < n; ++i) {
        y[i] = GGML_FP32_TO_FP16(x[i]);
    }
}

void ggml_bf16_to_fp32_row(const ggml_bf16_t * x, float * y, int64_t n) {
    int i = 0;
    for (; i < n; ++i) {
        y[i] = GGML_BF16_TO_FP32(x[i]);
    }
}

void ggml_fp32_to_bf16_row_ref(const float * x, ggml_bf16_t * y, int64_t n) {
    for (int i = 0; i < n; i++) {
        y[i] = ggml_compute_fp32_to_bf16(x[i]);
    }
}

void ggml_fp32_to_bf16_row(const float * x, ggml_bf16_t * y, int64_t n) {
    int i = 0;
#if defined(__AVX512BF16__)
    // subnormals are flushed to zero on this platform
    for (; i + 32 <= n; i += 32) {
        _mm512_storeu_si512(
            (__m512i *)(y + i),
            m512i(_mm512_cvtne2ps_pbh(_mm512_loadu_ps(x + i + 16),
                                      _mm512_loadu_ps(x + i))));
    }
#endif
    for (; i < n; i++) {
        y[i] = GGML_FP32_TO_BF16(x[i]);
    }
}
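
// Usage sketch (illustrative): converting a buffer of floats to bf16; on AVX512-BF16
// hardware the vectorized loop above handles 32 elements per iteration and the scalar
// tail finishes the remainder:
//
//   float       src[100];
//   ggml_bf16_t dst[100];
//   // ... fill src ...
//   ggml_fp32_to_bf16_row(src, dst, 100);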
bool ggml_guid_matches(ggml_guid_t guid_a, ggml_guid_t guid_b) {
    return memcmp(guid_a, guid_b, sizeof(ggml_guid)) == 0;
}

const char * ggml_version(void) {
    return GGML_VERSION;
}

const char * ggml_commit(void) {
    return GGML_COMMIT;
}

//
// timing
//

#if defined(_MSC_VER) || defined(__MINGW32__)
static int64_t timer_freq, timer_start;
void ggml_time_init(void) {
    LARGE_INTEGER t;
    QueryPerformanceFrequency(&t);
    timer_freq = t.QuadPart;

    // The multiplication by 1000 or 1000000 below can cause an overflow if timer_freq
    // and the uptime is high enough.
    // We subtract the program start time to reduce the likelihood of that happening.
    QueryPerformanceCounter(&t);
    timer_start = t.QuadPart;
}
int64_t ggml_time_ms(void) {
    LARGE_INTEGER t;
    QueryPerformanceCounter(&t);
    return ((t.QuadPart-timer_start) * 1000) / timer_freq;
}
int64_t ggml_time_us(void) {
    LARGE_INTEGER t;
    QueryPerformanceCounter(&t);
    return ((t.QuadPart-timer_start) * 1000000) / timer_freq;
}
#else
void ggml_time_init(void) {}
int64_t ggml_time_ms(void) {
    struct timespec ts;
    clock_gettime(CLOCK_MONOTONIC, &ts);
    return (int64_t)ts.tv_sec*1000 + (int64_t)ts.tv_nsec/1000000;
}

int64_t ggml_time_us(void) {
    struct timespec ts;
    clock_gettime(CLOCK_MONOTONIC, &ts);
    return (int64_t)ts.tv_sec*1000000 + (int64_t)ts.tv_nsec/1000;
}
#endif

int64_t ggml_cycles(void) {
    return clock();
}

int64_t ggml_cycles_per_ms(void) {
    return CLOCKS_PER_SEC/1000;
}
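
// Usage sketch (illustrative): on Windows the timers depend on ggml_time_init()
// having run once to set timer_freq / timer_start (on other platforms it is a no-op);
// a typical measurement looks like:
//
//   ggml_time_init();
//   const int64_t t_start_us = ggml_time_us();
//   // ... work to be timed ...
//   const int64_t t_us = ggml_time_us() - t_start_us;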
//
// cross-platform UTF-8 file paths
//

#ifdef _WIN32
static wchar_t * ggml_mbstowcs(const char * mbs) {
    int wlen = MultiByteToWideChar(CP_UTF8, 0, mbs, -1, NULL, 0);
    if (!wlen) {
        errno = EINVAL;
        return NULL;
    }

    wchar_t * wbuf = GGML_MALLOC(wlen * sizeof(wchar_t));
    wlen = MultiByteToWideChar(CP_UTF8, 0, mbs, -1, wbuf, wlen);
    if (!wlen) {
        GGML_FREE(wbuf);
        errno = EINVAL;
        return NULL;
    }

    return wbuf;
}
#endif

FILE * ggml_fopen(const char * fname, const char * mode) {
#ifdef _WIN32
    FILE * file = NULL;

    // convert fname (UTF-8)
    wchar_t * wfname = ggml_mbstowcs(fname);
    if (wfname) {
        // convert mode (ANSI)
        wchar_t * wmode = GGML_MALLOC((strlen(mode) + 1) * sizeof(wchar_t));
        wchar_t * wmode_p = wmode;
        do {
            *wmode_p++ = (wchar_t)*mode;
        } while (*mode++);

        // open file
        file = _wfopen(wfname, wmode);

        GGML_FREE(wfname);
        GGML_FREE(wmode);
    }

    return file;
#else
    return fopen(fname, mode);
#endif
}
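
// Usage sketch (illustrative): ggml_fopen() accepts UTF-8 paths on every platform,
// so paths with non-ASCII characters also work on Windows:
//
//   FILE * fp = ggml_fopen("models/模型.gguf", "rb");   // hypothetical path
//   if (fp != NULL) {
//       // ... read ...
//       fclose(fp);
//   }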
static const struct ggml_type_traits type_traits[GGML_TYPE_COUNT] = {
    [GGML_TYPE_I8] = {
        .type_name      = "i8",
        .blck_size      = 1,
        .type_size      = sizeof(int8_t),
        .is_quantized   = false,
    },
    [GGML_TYPE_I16] = {
        .type_name      = "i16",
        .blck_size      = 1,
        .type_size      = sizeof(int16_t),
        .is_quantized   = false,
    },
    [GGML_TYPE_I32] = {
        .type_name      = "i32",
        .blck_size      = 1,
        .type_size      = sizeof(int32_t),
        .is_quantized   = false,
    },
    [GGML_TYPE_I64] = {
        .type_name      = "i64",
        .blck_size      = 1,
        .type_size      = sizeof(int64_t),
        .is_quantized   = false,
    },
    [GGML_TYPE_F64] = {
        .type_name      = "f64",
        .blck_size      = 1,
        .type_size      = sizeof(double),
        .is_quantized   = false,
    },
    [GGML_TYPE_F32] = {
        .type_name      = "f32",
        .blck_size      = 1,
        .type_size      = sizeof(float),
        .is_quantized   = false,
    },
    [GGML_TYPE_F16] = {
        .type_name      = "f16",
        .blck_size      = 1,
        .type_size      = sizeof(ggml_fp16_t),
        .is_quantized   = false,
        .to_float       = (ggml_to_float_t) ggml_fp16_to_fp32_row,
        .from_float_ref = (ggml_from_float_t) ggml_fp32_to_fp16_row,
    },
    [GGML_TYPE_Q4_0] = {
        .type_name      = "q4_0",
        .blck_size      = QK4_0,
        .type_size      = sizeof(block_q4_0),
        .is_quantized   = true,
        .to_float       = (ggml_to_float_t) dequantize_row_q4_0,
        .from_float_ref = (ggml_from_float_t) quantize_row_q4_0_ref,
    },
    [GGML_TYPE_Q4_1] = {
        .type_name      = "q4_1",
        .blck_size      = QK4_1,
        .type_size      = sizeof(block_q4_1),
        .is_quantized   = true,
        .to_float       = (ggml_to_float_t) dequantize_row_q4_1,
        .from_float_ref = (ggml_from_float_t) quantize_row_q4_1_ref,
    },
    [4] = { // GGML_TYPE_Q4_2
        .type_name      = "DEPRECATED",
        .blck_size      = 0,
        .type_size      = 0,
        .is_quantized   = false,
    },
    [5] = { // GGML_TYPE_Q4_3
        .type_name      = "DEPRECATED",
        .blck_size      = 0,
        .type_size      = 0,
        .is_quantized   = false,
    },
    [GGML_TYPE_Q5_0] = {
        .type_name      = "q5_0",
        .blck_size      = QK5_0,
        .type_size      = sizeof(block_q5_0),
        .is_quantized   = true,
        .to_float       = (ggml_to_float_t) dequantize_row_q5_0,
        .from_float_ref = (ggml_from_float_t) quantize_row_q5_0_ref,
    },
    [GGML_TYPE_Q5_1] = {
        .type_name      = "q5_1",
        .blck_size      = QK5_1,
        .type_size      = sizeof(block_q5_1),
        .is_quantized   = true,
        .to_float       = (ggml_to_float_t) dequantize_row_q5_1,
        .from_float_ref = (ggml_from_float_t) quantize_row_q5_1_ref,
    },
    [GGML_TYPE_Q8_0] = {
        .type_name      = "q8_0",
        .blck_size      = QK8_0,
        .type_size      = sizeof(block_q8_0),
        .is_quantized   = true,
        .to_float       = (ggml_to_float_t) dequantize_row_q8_0,
        .from_float_ref = (ggml_from_float_t) quantize_row_q8_0_ref,
    },
    [GGML_TYPE_Q8_1] = {
        .type_name      = "q8_1",
        .blck_size      = QK8_1,
        .type_size      = sizeof(block_q8_1),
        .is_quantized   = true,
        .from_float_ref = (ggml_from_float_t) quantize_row_q8_1_ref,
    },
    [GGML_TYPE_MXFP4] = {
        .type_name      = "mxfp4",
        .blck_size      = QK_MXFP4,
        .type_size      = sizeof(block_mxfp4),
        .is_quantized   = true,
        .to_float       = (ggml_to_float_t) dequantize_row_mxfp4,
        .from_float_ref = (ggml_from_float_t) quantize_row_mxfp4_ref,
    },
    [GGML_TYPE_Q2_K] = {
        .type_name      = "q2_K",
        .blck_size      = QK_K,
        .type_size      = sizeof(block_q2_K),
        .is_quantized   = true,
        .to_float       = (ggml_to_float_t) dequantize_row_q2_K,
        .from_float_ref = (ggml_from_float_t) quantize_row_q2_K_ref,
    },
    [GGML_TYPE_Q3_K] = {
        .type_name      = "q3_K",
        .blck_size      = QK_K,
        .type_size      = sizeof(block_q3_K),
        .is_quantized   = true,
        .to_float       = (ggml_to_float_t) dequantize_row_q3_K,
        .from_float_ref = (ggml_from_float_t) quantize_row_q3_K_ref,
    },
    [GGML_TYPE_Q4_K] = {
        .type_name      = "q4_K",
        .blck_size      = QK_K,
        .type_size      = sizeof(block_q4_K),
        .is_quantized   = true,
        .to_float       = (ggml_to_float_t) dequantize_row_q4_K,
        .from_float_ref = (ggml_from_float_t) quantize_row_q4_K_ref,
    },
    [GGML_TYPE_Q5_K] = {
        .type_name      = "q5_K",
        .blck_size      = QK_K,
        .type_size      = sizeof(block_q5_K),
        .is_quantized   = true,
        .to_float       = (ggml_to_float_t) dequantize_row_q5_K,
        .from_float_ref = (ggml_from_float_t) quantize_row_q5_K_ref,
    },
    [GGML_TYPE_Q6_K] = {
        .type_name      = "q6_K",
        .blck_size      = QK_K,
        .type_size      = sizeof(block_q6_K),
        .is_quantized   = true,
        .to_float       = (ggml_to_float_t) dequantize_row_q6_K,
        .from_float_ref = (ggml_from_float_t) quantize_row_q6_K_ref,
    },
    [GGML_TYPE_IQ2_XXS] = {
        .type_name      = "iq2_xxs",
        .blck_size      = QK_K,
        .type_size      = sizeof(block_iq2_xxs),
        .is_quantized   = true,
        .to_float       = (ggml_to_float_t) dequantize_row_iq2_xxs,
        .from_float_ref = NULL,
    },
    [GGML_TYPE_IQ2_XS] = {
        .type_name      = "iq2_xs",
        .blck_size      = QK_K,
        .type_size      = sizeof(block_iq2_xs),
        .is_quantized   = true,
        .to_float       = (ggml_to_float_t) dequantize_row_iq2_xs,
        .from_float_ref = NULL,
    },
    [GGML_TYPE_IQ3_XXS] = {
        .type_name      = "iq3_xxs",
        .blck_size      = QK_K,
        .type_size      = sizeof(block_iq3_xxs),
        .is_quantized   = true,
        .to_float       = (ggml_to_float_t) dequantize_row_iq3_xxs,
        .from_float_ref = (ggml_from_float_t) quantize_row_iq3_xxs_ref,
    },
    [GGML_TYPE_IQ3_S] = {
        .type_name      = "iq3_s",
        .blck_size      = QK_K,
        .type_size      = sizeof(block_iq3_s),
        .is_quantized   = true,
        .to_float       = (ggml_to_float_t) dequantize_row_iq3_s,
        .from_float_ref = (ggml_from_float_t) quantize_row_iq3_s_ref,
    },
    [GGML_TYPE_IQ2_S] = {
        .type_name      = "iq2_s",
        .blck_size      = QK_K,
        .type_size      = sizeof(block_iq2_s),
        .is_quantized   = true,
        .to_float       = (ggml_to_float_t) dequantize_row_iq2_s,
        .from_float_ref = (ggml_from_float_t) quantize_row_iq2_s_ref,
    },
    [GGML_TYPE_IQ1_S] = {
        .type_name      = "iq1_s",
        .blck_size      = QK_K,
        .type_size      = sizeof(block_iq1_s),
        .is_quantized   = true,
        .to_float       = (ggml_to_float_t) dequantize_row_iq1_s,
        .from_float_ref = NULL,
    },
    [GGML_TYPE_IQ1_M] = {
        .type_name      = "iq1_m",
        .blck_size      = QK_K,
        .type_size      = sizeof(block_iq1_m),
        .is_quantized   = true,
        .to_float       = (ggml_to_float_t) dequantize_row_iq1_m,
        .from_float_ref = NULL,
    },
    [GGML_TYPE_IQ4_NL] = {
        .type_name      = "iq4_nl",
        .blck_size      = QK4_NL,
        .type_size      = sizeof(block_iq4_nl),
        .is_quantized   = true,
        .to_float       = (ggml_to_float_t) dequantize_row_iq4_nl,
        .from_float_ref = (ggml_from_float_t) quantize_row_iq4_nl_ref,
    },
    [GGML_TYPE_IQ4_XS] = {
        .type_name      = "iq4_xs",
        .blck_size      = QK_K,
        .type_size      = sizeof(block_iq4_xs),
        .is_quantized   = true,
        .to_float       = (ggml_to_float_t) dequantize_row_iq4_xs,
        .from_float_ref = (ggml_from_float_t) quantize_row_iq4_xs_ref,
    },
    [GGML_TYPE_Q8_K] = {
        .type_name      = "q8_K",
        .blck_size      = QK_K,
        .type_size      = sizeof(block_q8_K),
        .is_quantized   = true,
    },
    [GGML_TYPE_BF16] = {
        .type_name      = "bf16",
        .blck_size      = 1,
        .type_size      = sizeof(ggml_bf16_t),
        .is_quantized   = false,
        .to_float       = (ggml_to_float_t) ggml_bf16_to_fp32_row,
        .from_float_ref = (ggml_from_float_t) ggml_fp32_to_bf16_row_ref,
    },
    [31] = { // GGML_TYPE_Q4_0_4_4
        .type_name      = "TYPE_Q4_0_4_4 REMOVED, use Q4_0 with runtime repacking",
        .blck_size      = 0,
        .type_size      = 0,
        .is_quantized   = false,
    },
    [32] = { // GGML_TYPE_Q4_0_4_8
        .type_name      = "TYPE_Q4_0_4_8 REMOVED, use Q4_0 with runtime repacking",
        .blck_size      = 0,
        .type_size      = 0,
        .is_quantized   = false,
    },
    [33] = { // GGML_TYPE_Q4_0_8_8
        .type_name      = "TYPE_Q4_0_8_8 REMOVED, use Q4_0 with runtime repacking",
        .blck_size      = 0,
        .type_size      = 0,
        .is_quantized   = false,
    },
    [GGML_TYPE_TQ1_0] = {
        .type_name      = "tq1_0",
        .blck_size      = QK_K,
        .type_size      = sizeof(block_tq1_0),
        .is_quantized   = true,
        .to_float       = (ggml_to_float_t) dequantize_row_tq1_0,
        .from_float_ref = (ggml_from_float_t) quantize_row_tq1_0_ref,
    },
    [GGML_TYPE_TQ2_0] = {
        .type_name      = "tq2_0",
        .blck_size      = QK_K,
        .type_size      = sizeof(block_tq2_0),
        .is_quantized   = true,
        .to_float       = (ggml_to_float_t) dequantize_row_tq2_0,
        .from_float_ref = (ggml_from_float_t) quantize_row_tq2_0_ref,
    },
    [36] = { // GGML_TYPE_IQ4_NL_4_4
        .type_name      = "TYPE_IQ4_NL_4_4 REMOVED, use IQ4_NL with runtime repacking",
        .blck_size      = 0,
        .type_size      = 0,
        .is_quantized   = false,
    },
    [37] = { // GGML_TYPE_IQ4_NL_4_8
        .type_name      = "TYPE_IQ4_NL_4_8 REMOVED, use IQ4_NL with runtime repacking",
        .blck_size      = 0,
        .type_size      = 0,
        .is_quantized   = false,
    },
    [38] = { // GGML_TYPE_IQ4_NL_8_8
        .type_name      = "TYPE_IQ4_NL_8_8 REMOVED, use IQ4_NL with runtime repacking",
        .blck_size      = 0,
        .type_size      = 0,
        .is_quantized   = false,
    },
};

const struct ggml_type_traits * ggml_get_type_traits(enum ggml_type type) {
    GGML_ASSERT(type < GGML_TYPE_COUNT);
    return &type_traits[type];
}
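
// Usage sketch (illustrative): the traits table gives the block size and the per-block
// byte size, so the byte count of a row of n elements (n a multiple of blck_size) is:
//
//   const int64_t n = 4096;   // hypothetical element count, multiple of the block size
//   const struct ggml_type_traits * tt = ggml_get_type_traits(GGML_TYPE_Q4_0);
//   const size_t row_bytes = (size_t) tt->type_size * (n / tt->blck_size);
//   // ggml_row_size() (declared in ggml.h) computes the same quantity for callers.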
//
// ggml object
//

struct ggml_object {
    size_t offs;
    size_t size;

    struct ggml_object * next;

    enum ggml_object_type type;

    char padding[4];
};

static const size_t GGML_OBJECT_SIZE = sizeof(struct ggml_object);

//
// ggml context
//

struct ggml_context {
    size_t mem_size;
    void * mem_buffer;
    bool   mem_buffer_owned;
    bool   no_alloc;

    int    n_objects;

    struct ggml_object * objects_begin;
    struct ggml_object * objects_end;
};
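
// Usage sketch (illustrative): a context is created over a caller-provided or
// internally allocated arena via ggml_init() (declared in ggml.h); tensors and graph
// metadata are then carved out of mem_buffer as a linked list of ggml_object headers:
//
//   struct ggml_init_params params = {
//       /*.mem_size   =*/ 16u*1024*1024,
//       /*.mem_buffer =*/ NULL,     // let ggml allocate (mem_buffer_owned = true)
//       /*.no_alloc   =*/ false,
//   };
//   struct ggml_context * ctx = ggml_init(params);
//   // ... build tensors ...
//   ggml_free(ctx);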
  828. //
  829. // data types
  830. //
  831. static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
  832. "NONE",
  833. "DUP",
  834. "ADD",
  835. "ADD_ID",
  836. "ADD1",
  837. "ACC",
  838. "SUB",
  839. "MUL",
  840. "DIV",
  841. "SQR",
  842. "SQRT",
  843. "LOG",
  844. "SIN",
  845. "COS",
  846. "SUM",
  847. "SUM_ROWS",
  848. "MEAN",
  849. "ARGMAX",
  850. "COUNT_EQUAL",
  851. "REPEAT",
  852. "REPEAT_BACK",
  853. "CONCAT",
  854. "SILU_BACK",
  855. "NORM",
  856. "RMS_NORM",
  857. "RMS_NORM_BACK",
  858. "GROUP_NORM",
  859. "L2_NORM",
  860. "MUL_MAT",
  861. "MUL_MAT_ID",
  862. "OUT_PROD",
  863. "SCALE",
  864. "SET",
  865. "CPY",
  866. "CONT",
  867. "RESHAPE",
  868. "VIEW",
  869. "PERMUTE",
  870. "TRANSPOSE",
  871. "GET_ROWS",
  872. "GET_ROWS_BACK",
  873. "SET_ROWS",
  874. "DIAG",
  875. "DIAG_MASK_INF",
  876. "DIAG_MASK_ZERO",
  877. "SOFT_MAX",
  878. "SOFT_MAX_BACK",
  879. "ROPE",
  880. "ROPE_BACK",
  881. "CLAMP",
  882. "CONV_TRANSPOSE_1D",
  883. "IM2COL",
  884. "IM2COL_BACK",
  885. "IM2COL_3D",
  886. "CONV_2D",
  887. "CONV_3D",
  888. "CONV_2D_DW",
  889. "CONV_TRANSPOSE_2D",
  890. "POOL_1D",
  891. "POOL_2D",
  892. "POOL_2D_BACK",
  893. "UPSCALE",
  894. "PAD",
  895. "PAD_REFLECT_1D",
  896. "ROLL",
  897. "ARANGE",
  898. "TIMESTEP_EMBEDDING",
  899. "ARGSORT",
  900. "LEAKY_RELU",
  901. "FLASH_ATTN_EXT",
  902. "FLASH_ATTN_BACK",
  903. "SSM_CONV",
  904. "SSM_SCAN",
  905. "WIN_PART",
  906. "WIN_UNPART",
  907. "GET_REL_POS",
  908. "ADD_REL_POS",
  909. "RWKV_WKV6",
  910. "GATED_LINEAR_ATTN",
  911. "RWKV_WKV7",
  912. "UNARY",
  913. "MAP_CUSTOM1",
  914. "MAP_CUSTOM2",
  915. "MAP_CUSTOM3",
  916. "CUSTOM",
  917. "CROSS_ENTROPY_LOSS",
  918. "CROSS_ENTROPY_LOSS_BACK",
  919. "OPT_STEP_ADAMW",
  920. "OPT_STEP_SGD",
  921. "GLU",
  922. };
  923. static_assert(GGML_OP_COUNT == 90, "GGML_OP_COUNT != 90");
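// Maintenance note: the order of the strings above (and of GGML_OP_SYMBOL
// below) must match enum ggml_op in ggml.h; the static_assert only catches a
// count mismatch, not a reordered entry, which would silently mislabel ops in
// debug output and graph dumps.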
  924. static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
  925. "none",
  926. "x",
  927. "x+y",
  928. "x[i]+y",
  929. "x+y",
  930. "view(x,nb,offset)+=y->x",
  931. "x-y",
  932. "x*y",
  933. "x/y",
  934. "x^2",
  935. "√x",
  936. "log(x)",
  937. "sin(x)",
  938. "cos(x)",
  939. "Σx",
  940. "Σx_k",
  941. "Σx/n",
  942. "argmax(x)",
  943. "count_equal(x)",
  944. "repeat(x)",
  945. "repeat_back(x)",
  946. "concat(x, y)",
  947. "silu_back(x)",
  948. "norm(x)",
  949. "rms_norm(x)",
  950. "rms_norm_back(x)",
  951. "group_norm(x)",
  952. "l2_norm(x)",
  953. "X*Y",
  954. "X[i]*Y",
  955. "X*Y",
  956. "x*v",
  957. "y-\\>view(x)",
  958. "x-\\>y",
  959. "cont(x)",
  960. "reshape(x)",
  961. "view(x)",
  962. "permute(x)",
  963. "transpose(x)",
  964. "get_rows(x)",
  965. "get_rows_back(x)",
  966. "set_rows(x)",
  967. "diag(x)",
  968. "diag_mask_inf(x)",
  969. "diag_mask_zero(x)",
  970. "soft_max(x)",
  971. "soft_max_back(x)",
  972. "rope(x)",
  973. "rope_back(x)",
  974. "clamp(x)",
  975. "conv_transpose_1d(x)",
  976. "im2col(x)",
  977. "im2col_back(x)",
  978. "im2col_3d(x)",
  979. "conv_2d(x)",
  980. "conv_3d(x)",
  981. "conv_2d_dw(x)",
  982. "conv_transpose_2d(x)",
  983. "pool_1d(x)",
  984. "pool_2d(x)",
  985. "pool_2d_back(x)",
  986. "upscale(x)",
  987. "pad(x)",
  988. "pad_reflect_1d(x)",
  989. "roll(x)",
  990. "arange(start, stop, step)",
  991. "timestep_embedding(timesteps, dim, max_period)",
  992. "argsort(x)",
  993. "leaky_relu(x)",
  994. "flash_attn_ext(x)",
  995. "flash_attn_back(x)",
  996. "ssm_conv(x)",
  997. "ssm_scan(x)",
  998. "win_part(x)",
  999. "win_unpart(x)",
  1000. "get_rel_pos(x)",
  1001. "add_rel_pos(x)",
  1002. "rwkv_wkv6(k, v, r, tf, td, s)",
  1003. "gated_linear_attn(k, v, q, gate, s)",
  1004. "rwkv_wkv7(r, w, k, v, a, b, s)",
  1005. "unary(x)",
  1006. "map_custom(x)",
  1007. "map_custom(x,y)",
  1008. "map_custom(x,y,z)",
  1009. "custom(x)",
  1010. "cross_entropy_loss(x,y)",
  1011. "cross_entropy_loss_back(x,y)",
  1012. "adamw(x)",
  1013. "sgd(x)",
  1014. "glu(x)",
  1015. };
  1016. static_assert(GGML_OP_COUNT == 90, "GGML_OP_COUNT != 90");
  1017. static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");
  1018. static const char * GGML_UNARY_OP_NAME[GGML_UNARY_OP_COUNT] = {
  1019. "ABS",
  1020. "SGN",
  1021. "NEG",
  1022. "STEP",
  1023. "TANH",
  1024. "ELU",
  1025. "RELU",
  1026. "SIGMOID",
  1027. "GELU",
  1028. "GELU_QUICK",
  1029. "SILU",
  1030. "HARDSWISH",
  1031. "HARDSIGMOID",
  1032. "EXP",
  1033. "GELU_ERF",
  1034. "XIELU",
  1035. "FLOOR",
  1036. "CEIL",
  1037. "ROUND",
  1038. "TRUNC",
  1039. };
  1040. static_assert(GGML_UNARY_OP_COUNT == 20, "GGML_UNARY_OP_COUNT != 20");
  1041. static const char * GGML_GLU_OP_NAME[GGML_GLU_OP_COUNT] = {
  1042. "REGLU",
  1043. "GEGLU",
  1044. "SWIGLU",
  1045. "SWIGLU_OAI",
  1046. "GEGLU_ERF",
  1047. "GEGLU_QUICK",
  1048. };
  1049. static_assert(GGML_GLU_OP_COUNT == 6, "GGML_GLU_OP_COUNT != 6");
  1050. static_assert(sizeof(struct ggml_object)%GGML_MEM_ALIGN == 0, "ggml_object size must be a multiple of GGML_MEM_ALIGN");
  1051. static_assert(sizeof(struct ggml_tensor)%GGML_MEM_ALIGN == 0, "ggml_tensor size must be a multiple of GGML_MEM_ALIGN");
  1052. ////////////////////////////////////////////////////////////////////////////////
  1053. void ggml_print_object(const struct ggml_object * obj) {
  1054. GGML_LOG_INFO(" - ggml_object: type = %d, offset = %zu, size = %zu, next = %p\n",
  1055. obj->type, obj->offs, obj->size, (const void *) obj->next);
  1056. }
  1057. void ggml_print_objects(const struct ggml_context * ctx) {
  1058. struct ggml_object * obj = ctx->objects_begin;
  1059. GGML_LOG_INFO("%s: objects in context %p:\n", __func__, (const void *) ctx);
  1060. while (obj != NULL) {
  1061. ggml_print_object(obj);
  1062. obj = obj->next;
  1063. }
  1064. GGML_LOG_INFO("%s: --- end ---\n", __func__);
  1065. }
  1066. int64_t ggml_nelements(const struct ggml_tensor * tensor) {
  1067. static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
  1068. return tensor->ne[0]*tensor->ne[1]*tensor->ne[2]*tensor->ne[3];
  1069. }
  1070. int64_t ggml_nrows(const struct ggml_tensor * tensor) {
  1071. static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
  1072. return tensor->ne[1]*tensor->ne[2]*tensor->ne[3];
  1073. }
size_t ggml_nbytes(const struct ggml_tensor * tensor) {
    for (int i = 0; i < GGML_MAX_DIMS; ++i) {
        if (tensor->ne[i] <= 0) {
            return 0;
        }
    }

    size_t nbytes;
    const size_t blck_size = ggml_blck_size(tensor->type);
    if (blck_size == 1) {
        nbytes = ggml_type_size(tensor->type);
        for (int i = 0; i < GGML_MAX_DIMS; ++i) {
            nbytes += (tensor->ne[i] - 1)*tensor->nb[i];
        }
    }
    else {
        nbytes = tensor->ne[0]*tensor->nb[0]/blck_size;
        for (int i = 1; i < GGML_MAX_DIMS; ++i) {
            nbytes += (tensor->ne[i] - 1)*tensor->nb[i];
        }
    }

    return nbytes;
}
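// Worked example (assumed, freshly created contiguous tensor): for a
// GGML_TYPE_F32 tensor with ne = {8, 4, 1, 1} the strides are
// nb = {4, 32, 128, 128}, so ggml_nbytes() returns 4 + 7*4 + 3*32 = 128 bytes,
// i.e. ne[0]*ne[1]*sizeof(float). For a sliced or transposed view the nb[]
// values differ and the result is the byte span actually addressed by the
// view, not nelements*type_size.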
  1096. size_t ggml_nbytes_pad(const struct ggml_tensor * tensor) {
  1097. return GGML_PAD(ggml_nbytes(tensor), GGML_MEM_ALIGN);
  1098. }
  1099. int64_t ggml_blck_size(enum ggml_type type) {
  1100. return type_traits[type].blck_size;
  1101. }
  1102. size_t ggml_type_size(enum ggml_type type) {
  1103. return type_traits[type].type_size;
  1104. }
  1105. size_t ggml_row_size(enum ggml_type type, int64_t ne) {
  1106. assert(ne % ggml_blck_size(type) == 0);
  1107. return ggml_type_size(type)*ne/ggml_blck_size(type);
  1108. }
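// Example (block sizes assumed from the Q4_0 layout: a 32-element block
// stored as one fp16 scale plus 16 nibble-packed bytes, 18 bytes total):
// ggml_row_size(GGML_TYPE_Q4_0, 4096) == 18*4096/32 == 2304 bytes, versus
// 4096*4 == 16384 bytes for the same row in GGML_TYPE_F32.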
  1109. double ggml_type_sizef(enum ggml_type type) {
  1110. return ((double)(type_traits[type].type_size))/type_traits[type].blck_size;
  1111. }
  1112. const char * ggml_type_name(enum ggml_type type) {
  1113. return type < GGML_TYPE_COUNT ? type_traits[type].type_name : "NONE";
  1114. }
  1115. bool ggml_is_quantized(enum ggml_type type) {
  1116. return type_traits[type].is_quantized;
  1117. }
  1118. const char * ggml_op_name(enum ggml_op op) {
  1119. return GGML_OP_NAME[op];
  1120. }
  1121. const char * ggml_op_symbol(enum ggml_op op) {
  1122. return GGML_OP_SYMBOL[op];
  1123. }
  1124. const char * ggml_unary_op_name(enum ggml_unary_op op) {
  1125. return GGML_UNARY_OP_NAME[op];
  1126. }
  1127. const char * ggml_glu_op_name(enum ggml_glu_op op) {
  1128. return GGML_GLU_OP_NAME[op];
  1129. }
  1130. const char * ggml_op_desc(const struct ggml_tensor * t) {
  1131. if (t->op == GGML_OP_UNARY) {
  1132. enum ggml_unary_op uop = ggml_get_unary_op(t);
  1133. return ggml_unary_op_name(uop);
  1134. }
  1135. if (t->op == GGML_OP_GLU) {
  1136. enum ggml_glu_op gop = ggml_get_glu_op(t);
  1137. return ggml_glu_op_name(gop);
  1138. }
  1139. return ggml_op_name(t->op);
  1140. }
  1141. size_t ggml_element_size(const struct ggml_tensor * tensor) {
  1142. return ggml_type_size(tensor->type);
  1143. }
  1144. bool ggml_is_scalar(const struct ggml_tensor * tensor) {
  1145. static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
  1146. return tensor->ne[0] == 1 && tensor->ne[1] == 1 && tensor->ne[2] == 1 && tensor->ne[3] == 1;
  1147. }
  1148. bool ggml_is_vector(const struct ggml_tensor * tensor) {
  1149. static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
  1150. return tensor->ne[1] == 1 && tensor->ne[2] == 1 && tensor->ne[3] == 1;
  1151. }
  1152. bool ggml_is_matrix(const struct ggml_tensor * tensor) {
  1153. static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
  1154. return tensor->ne[2] == 1 && tensor->ne[3] == 1;
  1155. }
  1156. bool ggml_is_3d(const struct ggml_tensor * tensor) {
  1157. return tensor->ne[3] == 1;
  1158. }
int ggml_n_dims(const struct ggml_tensor * tensor) {
    for (int i = GGML_MAX_DIMS - 1; i >= 1; --i) {
        if (tensor->ne[i] > 1) {
            return i + 1;
        }
    }
    return 1;
}
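// Note: trailing dimensions of size 1 are ignored, so ne = {32, 1, 1, 1}
// reports 1 dimension, ne = {32, 4, 1, 1} reports 2, and a scalar also
// reports 1; ggml_n_dims() never returns 0.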
  1167. enum ggml_type ggml_ftype_to_ggml_type(enum ggml_ftype ftype) {
  1168. enum ggml_type wtype = GGML_TYPE_COUNT;
  1169. switch (ftype) {
  1170. case GGML_FTYPE_ALL_F32: wtype = GGML_TYPE_F32; break;
  1171. case GGML_FTYPE_MOSTLY_F16: wtype = GGML_TYPE_F16; break;
  1172. case GGML_FTYPE_MOSTLY_BF16: wtype = GGML_TYPE_BF16; break;
  1173. case GGML_FTYPE_MOSTLY_Q4_0: wtype = GGML_TYPE_Q4_0; break;
  1174. case GGML_FTYPE_MOSTLY_Q4_1: wtype = GGML_TYPE_Q4_1; break;
  1175. case GGML_FTYPE_MOSTLY_Q5_0: wtype = GGML_TYPE_Q5_0; break;
  1176. case GGML_FTYPE_MOSTLY_Q5_1: wtype = GGML_TYPE_Q5_1; break;
  1177. case GGML_FTYPE_MOSTLY_Q8_0: wtype = GGML_TYPE_Q8_0; break;
  1178. case GGML_FTYPE_MOSTLY_MXFP4: wtype = GGML_TYPE_MXFP4; break;
  1179. case GGML_FTYPE_MOSTLY_Q2_K: wtype = GGML_TYPE_Q2_K; break;
  1180. case GGML_FTYPE_MOSTLY_Q3_K: wtype = GGML_TYPE_Q3_K; break;
  1181. case GGML_FTYPE_MOSTLY_Q4_K: wtype = GGML_TYPE_Q4_K; break;
  1182. case GGML_FTYPE_MOSTLY_Q5_K: wtype = GGML_TYPE_Q5_K; break;
  1183. case GGML_FTYPE_MOSTLY_Q6_K: wtype = GGML_TYPE_Q6_K; break;
  1184. case GGML_FTYPE_MOSTLY_IQ2_XXS: wtype = GGML_TYPE_IQ2_XXS; break;
  1185. case GGML_FTYPE_MOSTLY_IQ2_XS: wtype = GGML_TYPE_IQ2_XS; break;
  1186. case GGML_FTYPE_MOSTLY_IQ3_XXS: wtype = GGML_TYPE_IQ3_XXS; break;
  1187. case GGML_FTYPE_MOSTLY_IQ1_S: wtype = GGML_TYPE_IQ1_S; break;
  1188. case GGML_FTYPE_MOSTLY_IQ1_M: wtype = GGML_TYPE_IQ1_M; break;
  1189. case GGML_FTYPE_MOSTLY_IQ4_NL: wtype = GGML_TYPE_IQ4_NL; break;
  1190. case GGML_FTYPE_MOSTLY_IQ4_XS: wtype = GGML_TYPE_IQ4_XS; break;
  1191. case GGML_FTYPE_MOSTLY_IQ3_S: wtype = GGML_TYPE_IQ3_S; break;
  1192. case GGML_FTYPE_MOSTLY_IQ2_S: wtype = GGML_TYPE_IQ2_S; break;
  1193. case GGML_FTYPE_UNKNOWN: wtype = GGML_TYPE_COUNT; break;
  1194. case GGML_FTYPE_MOSTLY_Q4_1_SOME_F16: wtype = GGML_TYPE_COUNT; break;
  1195. }
  1196. GGML_ASSERT(wtype != GGML_TYPE_COUNT);
  1197. return wtype;
  1198. }
  1199. size_t ggml_tensor_overhead(void) {
  1200. return GGML_OBJECT_SIZE + GGML_TENSOR_SIZE;
  1201. }
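// Sizing hint (common usage pattern, stated here as an assumption): with
// no_alloc = true each tensor still consumes ggml_tensor_overhead() bytes of
// metadata from the context pool, so mem_size = n_tensors*ggml_tensor_overhead()
// (plus ggml_graph_overhead() if a graph is built in the same context) is a
// reasonable lower bound for the pool size.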
  1202. bool ggml_is_transposed(const struct ggml_tensor * tensor) {
  1203. return tensor->nb[0] > tensor->nb[1];
  1204. }
static bool ggml_is_contiguous_n(const struct ggml_tensor * tensor, int n) {
    size_t next_nb = ggml_type_size(tensor->type);
    if (tensor->ne[0] != ggml_blck_size(tensor->type) && tensor->nb[0] != next_nb) {
        return false;
    }
    next_nb *= tensor->ne[0]/ggml_blck_size(tensor->type);
    for (int i = 1; i < GGML_MAX_DIMS; i++) {
        if (tensor->ne[i] != 1) {
            if (i > n) {
                if (tensor->nb[i] != next_nb) {
                    return false;
                }
                next_nb *= tensor->ne[i];
            } else {
                // this dimension does not need to be contiguous
                next_nb = tensor->ne[i]*tensor->nb[i];
            }
        }
    }
    return true;
}
  1226. bool ggml_is_contiguous(const struct ggml_tensor * tensor) {
  1227. return ggml_is_contiguous_0(tensor);
  1228. }
  1229. bool ggml_is_contiguous_0(const struct ggml_tensor * tensor) {
  1230. return ggml_is_contiguous_n(tensor, 0);
  1231. }
  1232. bool ggml_is_contiguous_1(const struct ggml_tensor * tensor) {
  1233. return ggml_is_contiguous_n(tensor, 1);
  1234. }
  1235. bool ggml_is_contiguous_2(const struct ggml_tensor * tensor) {
  1236. return ggml_is_contiguous_n(tensor, 2);
  1237. }
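// Illustration (assumed shapes): ggml_is_contiguous_1() only relaxes the row
// stride. A 2-D view whose rows are densely packed (nb[0] == type size) but
// padded between rows (nb[1] > ne[0]*nb[0]) is contiguous_1 yet not
// contiguous_0; contiguous_0 additionally requires nb[1] (and the higher
// strides) to be fully packed.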
  1238. bool ggml_is_contiguously_allocated(const struct ggml_tensor * tensor) {
  1239. return ggml_nbytes(tensor) == ggml_nelements(tensor) * ggml_type_size(tensor->type)/ggml_blck_size(tensor->type);
  1240. }
  1241. bool ggml_is_permuted(const struct ggml_tensor * tensor) {
  1242. static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
  1243. return tensor->nb[0] > tensor->nb[1] || tensor->nb[1] > tensor->nb[2] || tensor->nb[2] > tensor->nb[3];
  1244. }
  1245. bool ggml_is_contiguous_channels(const struct ggml_tensor * tensor) {
  1246. return
  1247. tensor->nb[0] > tensor->nb[2] &&
  1248. tensor->nb[1] > tensor->nb[0] &&
  1249. tensor->nb[2] == ggml_type_size(tensor->type);
  1250. }
  1251. bool ggml_is_contiguous_rows(const struct ggml_tensor * tensor) {
  1252. return
  1253. tensor->ne[0] == ggml_blck_size(tensor->type) ||
  1254. tensor->nb[0] == ggml_type_size(tensor->type);
  1255. }
  1256. static inline bool ggml_is_padded_1d(const struct ggml_tensor * tensor) {
  1257. static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
  1258. return
  1259. tensor->nb[0] == ggml_type_size(tensor->type) &&
  1260. tensor->nb[2] == tensor->nb[1]*tensor->ne[1] &&
  1261. tensor->nb[3] == tensor->nb[2]*tensor->ne[2];
  1262. }
  1263. bool ggml_is_empty(const struct ggml_tensor * tensor) {
  1264. for (int i = 0; i < GGML_MAX_DIMS; ++i) {
  1265. if (tensor->ne[i] == 0) {
  1266. // empty if any dimension has no elements
  1267. return true;
  1268. }
  1269. }
  1270. return false;
  1271. }
  1272. bool ggml_are_same_shape(const struct ggml_tensor * t0, const struct ggml_tensor * t1) {
  1273. static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
  1274. return
  1275. (t0->ne[0] == t1->ne[0]) &&
  1276. (t0->ne[1] == t1->ne[1]) &&
  1277. (t0->ne[2] == t1->ne[2]) &&
  1278. (t0->ne[3] == t1->ne[3]);
  1279. }
  1280. bool ggml_are_same_stride(const struct ggml_tensor * t0, const struct ggml_tensor * t1) {
  1281. static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
  1282. return
  1283. (t0->nb[0] == t1->nb[0]) &&
  1284. (t0->nb[1] == t1->nb[1]) &&
  1285. (t0->nb[2] == t1->nb[2]) &&
  1286. (t0->nb[3] == t1->nb[3]);
  1287. }
  1288. // check if t1 can be represented as a repetition of t0
  1289. bool ggml_can_repeat(const struct ggml_tensor * t0, const struct ggml_tensor * t1) {
  1290. static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
  1291. return ggml_is_empty(t0) ? ggml_is_empty(t1) :
  1292. (t1->ne[0]%t0->ne[0] == 0) &&
  1293. (t1->ne[1]%t0->ne[1] == 0) &&
  1294. (t1->ne[2]%t0->ne[2] == 0) &&
  1295. (t1->ne[3]%t0->ne[3] == 0);
  1296. }
  1297. static inline bool ggml_can_repeat_rows(const struct ggml_tensor * t0, const struct ggml_tensor * t1) {
  1298. static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
  1299. return (t0->ne[0] == t1->ne[0]) && ggml_can_repeat(t0, t1);
  1300. }
  1301. // assert that pointer is aligned to GGML_MEM_ALIGN
  1302. #define GGML_ASSERT_ALIGNED(ptr) \
  1303. GGML_ASSERT(((uintptr_t) (ptr))%GGML_MEM_ALIGN == 0)
  1304. ////////////////////////////////////////////////////////////////////////////////
struct ggml_context * ggml_init(struct ggml_init_params params) {
    static bool is_first_call = true;

    ggml_critical_section_start();

    if (is_first_call) {
        // initialize time system (required on Windows)
        ggml_time_init();

        is_first_call = false;
    }

    ggml_critical_section_end();

    struct ggml_context * ctx = GGML_MALLOC(sizeof(struct ggml_context));

    // allow to call ggml_init with 0 size
    if (params.mem_size == 0) {
        params.mem_size = GGML_MEM_ALIGN;
    }

    const size_t mem_size = params.mem_buffer ? params.mem_size : GGML_PAD(params.mem_size, GGML_MEM_ALIGN);

    *ctx = (struct ggml_context) {
        /*.mem_size         =*/ mem_size,
        /*.mem_buffer       =*/ params.mem_buffer ? params.mem_buffer : ggml_aligned_malloc(mem_size),
        /*.mem_buffer_owned =*/ params.mem_buffer ? false : true,
        /*.no_alloc         =*/ params.no_alloc,
        /*.n_objects        =*/ 0,
        /*.objects_begin    =*/ NULL,
        /*.objects_end      =*/ NULL,
    };

    GGML_ASSERT(ctx->mem_buffer != NULL);

    GGML_ASSERT_ALIGNED(ctx->mem_buffer);

    GGML_PRINT_DEBUG("%s: context initialized\n", __func__);

    return ctx;
}
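// Minimal usage sketch (illustrative only): create a context backed by a
// 16 MiB pool, allocate tensors from it, then release everything at once.
//
//     struct ggml_init_params params = {
//         /*.mem_size   =*/ 16*1024*1024,
//         /*.mem_buffer =*/ NULL,   // let ggml allocate the pool
//         /*.no_alloc   =*/ false,  // tensor data also lives in the pool
//     };
//     struct ggml_context * ctx = ggml_init(params);
//     struct ggml_tensor  * t   = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 64, 64);
//     // ... build and evaluate a graph ...
//     ggml_free(ctx);
//
// With no_alloc = true only the tensor metadata is placed in the pool and the
// data pointers are left for a backend allocator to fill in.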
  1334. void ggml_reset(struct ggml_context * ctx) {
  1335. if (ctx == NULL) {
  1336. return;
  1337. }
  1338. ctx->n_objects = 0;
  1339. ctx->objects_begin = NULL;
  1340. ctx->objects_end = NULL;
  1341. }
  1342. void ggml_free(struct ggml_context * ctx) {
  1343. if (ctx == NULL) {
  1344. return;
  1345. }
  1346. if (ctx->mem_buffer_owned) {
  1347. ggml_aligned_free(ctx->mem_buffer, ctx->mem_size);
  1348. }
  1349. GGML_FREE(ctx);
  1350. }
  1351. size_t ggml_used_mem(const struct ggml_context * ctx) {
  1352. return ctx->objects_end == NULL ? 0 : ctx->objects_end->offs + ctx->objects_end->size;
  1353. }
  1354. bool ggml_get_no_alloc(struct ggml_context * ctx) {
  1355. return ctx->no_alloc;
  1356. }
  1357. void ggml_set_no_alloc(struct ggml_context * ctx, bool no_alloc) {
  1358. ctx->no_alloc = no_alloc;
  1359. }
  1360. void * ggml_get_mem_buffer(const struct ggml_context * ctx) {
  1361. return ctx->mem_buffer;
  1362. }
  1363. size_t ggml_get_mem_size(const struct ggml_context * ctx) {
  1364. return ctx->mem_size;
  1365. }
  1366. size_t ggml_get_max_tensor_size(const struct ggml_context * ctx) {
  1367. size_t max_size = 0;
  1368. for (struct ggml_tensor * tensor = ggml_get_first_tensor(ctx); tensor != NULL; tensor = ggml_get_next_tensor(ctx, tensor)) {
  1369. size_t bytes = ggml_nbytes(tensor);
  1370. max_size = MAX(max_size, bytes);
  1371. }
  1372. return max_size;
  1373. }
  1374. ////////////////////////////////////////////////////////////////////////////////
static struct ggml_object * ggml_new_object(struct ggml_context * ctx, enum ggml_object_type type, size_t size) {
    // always insert objects at the end of the context's memory pool
    struct ggml_object * obj_cur = ctx->objects_end;

    const size_t cur_offs = obj_cur == NULL ? 0 : obj_cur->offs;
    const size_t cur_size = obj_cur == NULL ? 0 : obj_cur->size;
    const size_t cur_end  = cur_offs + cur_size;

    // align to GGML_MEM_ALIGN
    size_t size_needed = GGML_PAD(size, GGML_MEM_ALIGN);

    char * const mem_buffer = ctx->mem_buffer;
    struct ggml_object * const obj_new = (struct ggml_object *)(mem_buffer + cur_end);

    if (cur_end + size_needed + GGML_OBJECT_SIZE > ctx->mem_size) {
        GGML_LOG_WARN("%s: not enough space in the context's memory pool (needed %zu, available %zu)\n",
                __func__, cur_end + size_needed + GGML_OBJECT_SIZE, ctx->mem_size);
#ifndef NDEBUG
        GGML_ABORT("not enough space in the context's memory pool");
#endif
        return NULL;
    }

    *obj_new = (struct ggml_object) {
        .offs = cur_end + GGML_OBJECT_SIZE,
        .size = size_needed,
        .next = NULL,
        .type = type,
    };

    GGML_ASSERT_ALIGNED(mem_buffer + obj_new->offs);

    if (obj_cur != NULL) {
        obj_cur->next = obj_new;
    } else {
        // this is the first object in this context
        ctx->objects_begin = obj_new;
    }

    ctx->objects_end = obj_new;

    //printf("%s: inserted new object at %zu, size = %zu\n", __func__, cur_end, obj_new->size);

    return obj_new;
}
  1410. static struct ggml_tensor * ggml_new_tensor_impl(
  1411. struct ggml_context * ctx,
  1412. enum ggml_type type,
  1413. int n_dims,
  1414. const int64_t * ne,
  1415. struct ggml_tensor * view_src,
  1416. size_t view_offs) {
  1417. GGML_ASSERT(type >= 0 && type < GGML_TYPE_COUNT);
  1418. GGML_ASSERT(n_dims >= 1 && n_dims <= GGML_MAX_DIMS);
  1419. // find the base tensor and absolute offset
  1420. if (view_src != NULL && view_src->view_src != NULL) {
  1421. view_offs += view_src->view_offs;
  1422. view_src = view_src->view_src;
  1423. }
  1424. size_t data_size = ggml_row_size(type, ne[0]);
  1425. for (int i = 1; i < n_dims; i++) {
  1426. data_size *= ne[i];
  1427. }
  1428. GGML_ASSERT(view_src == NULL || data_size == 0 || data_size + view_offs <= ggml_nbytes(view_src));
  1429. void * data = view_src != NULL ? view_src->data : NULL;
  1430. if (data != NULL) {
  1431. data = (char *) data + view_offs;
  1432. }
  1433. size_t obj_alloc_size = 0;
  1434. if (view_src == NULL && !ctx->no_alloc) {
  1435. // allocate tensor data in the context's memory pool
  1436. obj_alloc_size = data_size;
  1437. }
  1438. struct ggml_object * const obj_new = ggml_new_object(ctx, GGML_OBJECT_TYPE_TENSOR, GGML_TENSOR_SIZE + obj_alloc_size);
  1439. GGML_ASSERT(obj_new);
  1440. struct ggml_tensor * const result = (struct ggml_tensor *)((char *)ctx->mem_buffer + obj_new->offs);
  1441. *result = (struct ggml_tensor) {
  1442. /*.type =*/ type,
  1443. /*.buffer =*/ NULL,
  1444. /*.ne =*/ { 1, 1, 1, 1 },
  1445. /*.nb =*/ { 0, 0, 0, 0 },
  1446. /*.op =*/ GGML_OP_NONE,
  1447. /*.op_params =*/ { 0 },
  1448. /*.flags =*/ 0,
  1449. /*.src =*/ { NULL },
  1450. /*.view_src =*/ view_src,
  1451. /*.view_offs =*/ view_offs,
  1452. /*.data =*/ obj_alloc_size > 0 ? (void *)(result + 1) : data,
  1453. /*.name =*/ { 0 },
  1454. /*.extra =*/ NULL,
  1455. /*.padding =*/ { 0 },
  1456. };
  1457. // TODO: this should not be needed as long as we don't rely on aligned SIMD loads
  1458. //GGML_ASSERT_ALIGNED(result->data);
  1459. for (int i = 0; i < n_dims; i++) {
  1460. result->ne[i] = ne[i];
  1461. }
  1462. result->nb[0] = ggml_type_size(type);
  1463. result->nb[1] = result->nb[0]*(result->ne[0]/ggml_blck_size(type));
  1464. for (int i = 2; i < GGML_MAX_DIMS; i++) {
  1465. result->nb[i] = result->nb[i - 1]*result->ne[i - 1];
  1466. }
  1467. ctx->n_objects++;
  1468. return result;
  1469. }
  1470. struct ggml_tensor * ggml_new_tensor(
  1471. struct ggml_context * ctx,
  1472. enum ggml_type type,
  1473. int n_dims,
  1474. const int64_t * ne) {
  1475. return ggml_new_tensor_impl(ctx, type, n_dims, ne, NULL, 0);
  1476. }
  1477. struct ggml_tensor * ggml_new_tensor_1d(
  1478. struct ggml_context * ctx,
  1479. enum ggml_type type,
  1480. int64_t ne0) {
  1481. return ggml_new_tensor(ctx, type, 1, &ne0);
  1482. }
  1483. struct ggml_tensor * ggml_new_tensor_2d(
  1484. struct ggml_context * ctx,
  1485. enum ggml_type type,
  1486. int64_t ne0,
  1487. int64_t ne1) {
  1488. const int64_t ne[2] = { ne0, ne1 };
  1489. return ggml_new_tensor(ctx, type, 2, ne);
  1490. }
  1491. struct ggml_tensor * ggml_new_tensor_3d(
  1492. struct ggml_context * ctx,
  1493. enum ggml_type type,
  1494. int64_t ne0,
  1495. int64_t ne1,
  1496. int64_t ne2) {
  1497. const int64_t ne[3] = { ne0, ne1, ne2 };
  1498. return ggml_new_tensor(ctx, type, 3, ne);
  1499. }
  1500. struct ggml_tensor * ggml_new_tensor_4d(
  1501. struct ggml_context * ctx,
  1502. enum ggml_type type,
  1503. int64_t ne0,
  1504. int64_t ne1,
  1505. int64_t ne2,
  1506. int64_t ne3) {
  1507. const int64_t ne[4] = { ne0, ne1, ne2, ne3 };
  1508. return ggml_new_tensor(ctx, type, 4, ne);
  1509. }
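// Example (illustrative): the _1d/_2d/_3d/_4d helpers only differ in how many
// entries of ne[] they fill in; the remaining dimensions default to 1, so
//
//     ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, n_tokens)
//
// creates the same shape as ggml_new_tensor_4d(ctx, GGML_TYPE_F32, n_embd,
// n_tokens, 1, 1). ne[0] is the innermost (fastest-varying) dimension.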
  1510. void * ggml_new_buffer(struct ggml_context * ctx, size_t nbytes) {
  1511. struct ggml_object * obj = ggml_new_object(ctx, GGML_OBJECT_TYPE_WORK_BUFFER, nbytes);
  1512. return (uint8_t *)ctx->mem_buffer + obj->offs;
  1513. }
  1514. struct ggml_tensor * ggml_dup_tensor(struct ggml_context * ctx, const struct ggml_tensor * src) {
  1515. return ggml_new_tensor(ctx, src->type, GGML_MAX_DIMS, src->ne);
  1516. }
void ggml_unravel_index(const struct ggml_tensor * tensor, int64_t i, int64_t * i0, int64_t * i1, int64_t * i2, int64_t * i3) {
    const int64_t ne2 = tensor->ne[2];
    const int64_t ne1 = tensor->ne[1];
    const int64_t ne0 = tensor->ne[0];

    const int64_t i3_ = (i/(ne2*ne1*ne0));
    const int64_t i2_ = (i - i3_*ne2*ne1*ne0)/(ne1*ne0);
    const int64_t i1_ = (i - i3_*ne2*ne1*ne0 - i2_*ne1*ne0)/ne0;
    const int64_t i0_ = (i - i3_*ne2*ne1*ne0 - i2_*ne1*ne0 - i1_*ne0);

    if (i0) {
        *i0 = i0_;
    }
    if (i1) {
        *i1 = i1_;
    }
    if (i2) {
        *i2 = i2_;
    }
    if (i3) {
        *i3 = i3_;
    }
}
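// Worked example (assumed shape): for ne = {4, 3, 2, 1} the flat index 10
// unravels as i3 = 10/24 = 0, i2 = 10/12 = 0, i1 = 10/4 = 2, i0 = 10 - 2*4 = 2,
// i.e. element (i0, i1, i2, i3) = (2, 2, 0, 0). NULL may be passed for any
// coordinate that is not needed.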
  1538. void * ggml_get_data(const struct ggml_tensor * tensor) {
  1539. return tensor->data;
  1540. }
  1541. float * ggml_get_data_f32(const struct ggml_tensor * tensor) {
  1542. assert(tensor->type == GGML_TYPE_F32);
  1543. return (float *)(tensor->data);
  1544. }
  1545. enum ggml_unary_op ggml_get_unary_op(const struct ggml_tensor * tensor) {
  1546. GGML_ASSERT(tensor->op == GGML_OP_UNARY);
  1547. return (enum ggml_unary_op) ggml_get_op_params_i32(tensor, 0);
  1548. }
  1549. enum ggml_glu_op ggml_get_glu_op(const struct ggml_tensor * tensor) {
  1550. GGML_ASSERT(tensor->op == GGML_OP_GLU);
  1551. return (enum ggml_glu_op) ggml_get_op_params_i32(tensor, 0);
  1552. }
  1553. const char * ggml_get_name(const struct ggml_tensor * tensor) {
  1554. return tensor->name;
  1555. }
  1556. struct ggml_tensor * ggml_set_name(struct ggml_tensor * tensor, const char * name) {
  1557. size_t i;
  1558. for (i = 0; i < sizeof(tensor->name) - 1 && name[i] != '\0'; i++) {
  1559. tensor->name[i] = name[i];
  1560. }
  1561. tensor->name[i] = '\0';
  1562. return tensor;
  1563. }
  1564. struct ggml_tensor * ggml_format_name(struct ggml_tensor * tensor, const char * fmt, ...) {
  1565. va_list args;
  1566. va_start(args, fmt);
  1567. vsnprintf(tensor->name, sizeof(tensor->name), fmt, args);
  1568. va_end(args);
  1569. return tensor;
  1570. }
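// Note: both setters truncate to sizeof(tensor->name) - 1 characters
// (GGML_MAX_NAME - 1), so names are always NUL-terminated but long formatted
// strings are silently cut, e.g.
//
//     ggml_format_name(cur, "%s (layer %d)", prefix, il);   // hypothetical names
//
// keeps only the first GGML_MAX_NAME - 1 bytes of the formatted result.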
  1571. struct ggml_tensor * ggml_view_tensor(
  1572. struct ggml_context * ctx,
  1573. struct ggml_tensor * src) {
  1574. struct ggml_tensor * result = ggml_new_tensor_impl(ctx, src->type, GGML_MAX_DIMS, src->ne, src, 0);
  1575. ggml_format_name(result, "%s (view)", src->name);
  1576. for (int i = 0; i < GGML_MAX_DIMS; i++) {
  1577. result->nb[i] = src->nb[i];
  1578. }
  1579. return result;
  1580. }
  1581. struct ggml_tensor * ggml_get_first_tensor(const struct ggml_context * ctx) {
  1582. struct ggml_object * obj = ctx->objects_begin;
  1583. char * const mem_buffer = ctx->mem_buffer;
  1584. while (obj != NULL) {
  1585. if (obj->type == GGML_OBJECT_TYPE_TENSOR) {
  1586. return (struct ggml_tensor *)(mem_buffer + obj->offs);
  1587. }
  1588. obj = obj->next;
  1589. }
  1590. return NULL;
  1591. }
  1592. struct ggml_tensor * ggml_get_next_tensor(const struct ggml_context * ctx, struct ggml_tensor * tensor) {
  1593. struct ggml_object * obj = (struct ggml_object *) ((char *)tensor - GGML_OBJECT_SIZE);
  1594. obj = obj->next;
  1595. char * const mem_buffer = ctx->mem_buffer;
  1596. while (obj != NULL) {
  1597. if (obj->type == GGML_OBJECT_TYPE_TENSOR) {
  1598. return (struct ggml_tensor *)(mem_buffer + obj->offs);
  1599. }
  1600. obj = obj->next;
  1601. }
  1602. return NULL;
  1603. }
  1604. struct ggml_tensor * ggml_get_tensor(struct ggml_context * ctx, const char * name) {
  1605. struct ggml_object * obj = ctx->objects_begin;
  1606. char * const mem_buffer = ctx->mem_buffer;
  1607. while (obj != NULL) {
  1608. if (obj->type == GGML_OBJECT_TYPE_TENSOR) {
  1609. struct ggml_tensor * cur = (struct ggml_tensor *)(mem_buffer + obj->offs);
  1610. if (strcmp(cur->name, name) == 0) {
  1611. return cur;
  1612. }
  1613. }
  1614. obj = obj->next;
  1615. }
  1616. return NULL;
  1617. }
  1618. ////////////////////////////////////////////////////////////////////////////////
  1619. // ggml_dup
  1620. static struct ggml_tensor * ggml_dup_impl(
  1621. struct ggml_context * ctx,
  1622. struct ggml_tensor * a,
  1623. bool inplace) {
  1624. struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
  1625. result->op = GGML_OP_DUP;
  1626. result->src[0] = a;
  1627. return result;
  1628. }
  1629. struct ggml_tensor * ggml_dup(
  1630. struct ggml_context * ctx,
  1631. struct ggml_tensor * a) {
  1632. return ggml_dup_impl(ctx, a, false);
  1633. }
  1634. struct ggml_tensor * ggml_dup_inplace(
  1635. struct ggml_context * ctx,
  1636. struct ggml_tensor * a) {
  1637. return ggml_dup_impl(ctx, a, true);
  1638. }
  1639. // ggml_add
  1640. static struct ggml_tensor * ggml_add_impl(
  1641. struct ggml_context * ctx,
  1642. struct ggml_tensor * a,
  1643. struct ggml_tensor * b,
  1644. bool inplace) {
  1645. GGML_ASSERT(ggml_can_repeat(b, a));
  1646. struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
  1647. result->op = GGML_OP_ADD;
  1648. result->src[0] = a;
  1649. result->src[1] = b;
  1650. return result;
  1651. }
  1652. struct ggml_tensor * ggml_add(
  1653. struct ggml_context * ctx,
  1654. struct ggml_tensor * a,
  1655. struct ggml_tensor * b) {
  1656. return ggml_add_impl(ctx, a, b, false);
  1657. }
  1658. struct ggml_tensor * ggml_add_inplace(
  1659. struct ggml_context * ctx,
  1660. struct ggml_tensor * a,
  1661. struct ggml_tensor * b) {
  1662. return ggml_add_impl(ctx, a, b, true);
  1663. }
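// Broadcasting sketch (illustrative): ggml_can_repeat(b, a) means b is
// broadcast onto a, so adding a per-row bias of shape {n_embd, 1, 1, 1} to
// activations of shape {n_embd, n_tokens, 1, 1} is simply
//
//     cur = ggml_add(ctx, cur, bias);
//
// and the result always takes the shape and type of the first operand.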
  1664. // ggml_add_cast
  1665. static struct ggml_tensor * ggml_add_cast_impl(
  1666. struct ggml_context * ctx,
  1667. struct ggml_tensor * a,
  1668. struct ggml_tensor * b,
  1669. enum ggml_type type) {
  1670. // TODO: support less-strict constraint
  1671. // GGML_ASSERT(ggml_can_repeat(b, a));
  1672. GGML_ASSERT(ggml_can_repeat_rows(b, a));
  1673. // currently only supported for quantized input and f16
  1674. GGML_ASSERT(ggml_is_quantized(a->type) ||
  1675. a->type == GGML_TYPE_F16 ||
  1676. a->type == GGML_TYPE_BF16);
  1677. struct ggml_tensor * result = ggml_new_tensor(ctx, type, GGML_MAX_DIMS, a->ne);
  1678. result->op = GGML_OP_ADD;
  1679. result->src[0] = a;
  1680. result->src[1] = b;
  1681. return result;
  1682. }
  1683. struct ggml_tensor * ggml_add_cast(
  1684. struct ggml_context * ctx,
  1685. struct ggml_tensor * a,
  1686. struct ggml_tensor * b,
  1687. enum ggml_type type) {
  1688. return ggml_add_cast_impl(ctx, a, b, type);
  1689. }
  1690. struct ggml_tensor * ggml_add_id(
  1691. struct ggml_context * ctx,
  1692. struct ggml_tensor * a,
  1693. struct ggml_tensor * b,
  1694. struct ggml_tensor * ids) {
  1695. GGML_ASSERT(a->ne[0] == b->ne[0]);
  1696. GGML_ASSERT(a->ne[1] == ids->ne[0]);
  1697. GGML_ASSERT(a->ne[2] == ids->ne[1]);
  1698. GGML_ASSERT(ids->type == GGML_TYPE_I32);
  1699. struct ggml_tensor * result = ggml_dup_tensor(ctx, a);
  1700. result->op = GGML_OP_ADD_ID;
  1701. result->src[0] = a;
  1702. result->src[1] = b;
  1703. result->src[2] = ids;
  1704. return result;
  1705. }
  1706. // ggml_add1
  1707. static struct ggml_tensor * ggml_add1_impl(
  1708. struct ggml_context * ctx,
  1709. struct ggml_tensor * a,
  1710. struct ggml_tensor * b,
  1711. bool inplace) {
  1712. GGML_ASSERT(ggml_is_scalar(b));
  1713. GGML_ASSERT(ggml_is_padded_1d(a));
  1714. struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
  1715. result->op = GGML_OP_ADD1;
  1716. result->src[0] = a;
  1717. result->src[1] = b;
  1718. return result;
  1719. }
  1720. struct ggml_tensor * ggml_add1(
  1721. struct ggml_context * ctx,
  1722. struct ggml_tensor * a,
  1723. struct ggml_tensor * b) {
  1724. return ggml_add1_impl(ctx, a, b, false);
  1725. }
  1726. struct ggml_tensor * ggml_add1_inplace(
  1727. struct ggml_context * ctx,
  1728. struct ggml_tensor * a,
  1729. struct ggml_tensor * b) {
  1730. return ggml_add1_impl(ctx, a, b, true);
  1731. }
  1732. // ggml_acc
  1733. static struct ggml_tensor * ggml_acc_impl(
  1734. struct ggml_context * ctx,
  1735. struct ggml_tensor * a,
  1736. struct ggml_tensor * b,
  1737. size_t nb1,
  1738. size_t nb2,
  1739. size_t nb3,
  1740. size_t offset,
  1741. bool inplace) {
  1742. GGML_ASSERT(ggml_nelements(b) <= ggml_nelements(a));
  1743. GGML_ASSERT(ggml_is_contiguous(a));
  1744. GGML_ASSERT(a->type == GGML_TYPE_F32);
  1745. GGML_ASSERT(b->type == GGML_TYPE_F32);
  1746. struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
  1747. int32_t params[] = { nb1, nb2, nb3, offset, inplace ? 1 : 0 };
  1748. ggml_set_op_params(result, params, sizeof(params));
  1749. result->op = GGML_OP_ACC;
  1750. result->src[0] = a;
  1751. result->src[1] = b;
  1752. return result;
  1753. }
  1754. struct ggml_tensor * ggml_acc(
  1755. struct ggml_context * ctx,
  1756. struct ggml_tensor * a,
  1757. struct ggml_tensor * b,
  1758. size_t nb1,
  1759. size_t nb2,
  1760. size_t nb3,
  1761. size_t offset) {
  1762. return ggml_acc_impl(ctx, a, b, nb1, nb2, nb3, offset, false);
  1763. }
  1764. struct ggml_tensor * ggml_acc_inplace(
  1765. struct ggml_context * ctx,
  1766. struct ggml_tensor * a,
  1767. struct ggml_tensor * b,
  1768. size_t nb1,
  1769. size_t nb2,
  1770. size_t nb3,
  1771. size_t offset) {
  1772. return ggml_acc_impl(ctx, a, b, nb1, nb2, nb3, offset, true);
  1773. }
  1774. // ggml_sub
  1775. static struct ggml_tensor * ggml_sub_impl(
  1776. struct ggml_context * ctx,
  1777. struct ggml_tensor * a,
  1778. struct ggml_tensor * b,
  1779. bool inplace) {
  1780. GGML_ASSERT(ggml_can_repeat(b, a));
  1781. struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
  1782. result->op = GGML_OP_SUB;
  1783. result->src[0] = a;
  1784. result->src[1] = b;
  1785. return result;
  1786. }
  1787. struct ggml_tensor * ggml_sub(
  1788. struct ggml_context * ctx,
  1789. struct ggml_tensor * a,
  1790. struct ggml_tensor * b) {
  1791. return ggml_sub_impl(ctx, a, b, false);
  1792. }
  1793. struct ggml_tensor * ggml_sub_inplace(
  1794. struct ggml_context * ctx,
  1795. struct ggml_tensor * a,
  1796. struct ggml_tensor * b) {
  1797. return ggml_sub_impl(ctx, a, b, true);
  1798. }
  1799. // ggml_mul
  1800. static struct ggml_tensor * ggml_mul_impl(
  1801. struct ggml_context * ctx,
  1802. struct ggml_tensor * a,
  1803. struct ggml_tensor * b,
  1804. bool inplace) {
  1805. GGML_ASSERT(ggml_can_repeat(b, a));
  1806. struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
  1807. result->op = GGML_OP_MUL;
  1808. result->src[0] = a;
  1809. result->src[1] = b;
  1810. return result;
  1811. }
  1812. struct ggml_tensor * ggml_mul(
  1813. struct ggml_context * ctx,
  1814. struct ggml_tensor * a,
  1815. struct ggml_tensor * b) {
  1816. return ggml_mul_impl(ctx, a, b, false);
  1817. }
  1818. struct ggml_tensor * ggml_mul_inplace(
  1819. struct ggml_context * ctx,
  1820. struct ggml_tensor * a,
  1821. struct ggml_tensor * b) {
  1822. return ggml_mul_impl(ctx, a, b, true);
  1823. }
  1824. // ggml_div
  1825. static struct ggml_tensor * ggml_div_impl(
  1826. struct ggml_context * ctx,
  1827. struct ggml_tensor * a,
  1828. struct ggml_tensor * b,
  1829. bool inplace) {
  1830. GGML_ASSERT(ggml_can_repeat(b, a));
  1831. struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
  1832. result->op = GGML_OP_DIV;
  1833. result->src[0] = a;
  1834. result->src[1] = b;
  1835. return result;
  1836. }
  1837. struct ggml_tensor * ggml_div(
  1838. struct ggml_context * ctx,
  1839. struct ggml_tensor * a,
  1840. struct ggml_tensor * b) {
  1841. return ggml_div_impl(ctx, a, b, false);
  1842. }
  1843. struct ggml_tensor * ggml_div_inplace(
  1844. struct ggml_context * ctx,
  1845. struct ggml_tensor * a,
  1846. struct ggml_tensor * b) {
  1847. return ggml_div_impl(ctx, a, b, true);
  1848. }
  1849. // ggml_sqr
  1850. static struct ggml_tensor * ggml_sqr_impl(
  1851. struct ggml_context * ctx,
  1852. struct ggml_tensor * a,
  1853. bool inplace) {
  1854. struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
  1855. result->op = GGML_OP_SQR;
  1856. result->src[0] = a;
  1857. return result;
  1858. }
  1859. struct ggml_tensor * ggml_sqr(
  1860. struct ggml_context * ctx,
  1861. struct ggml_tensor * a) {
  1862. return ggml_sqr_impl(ctx, a, false);
  1863. }
  1864. struct ggml_tensor * ggml_sqr_inplace(
  1865. struct ggml_context * ctx,
  1866. struct ggml_tensor * a) {
  1867. return ggml_sqr_impl(ctx, a, true);
  1868. }
  1869. // ggml_sqrt
  1870. static struct ggml_tensor * ggml_sqrt_impl(
  1871. struct ggml_context * ctx,
  1872. struct ggml_tensor * a,
  1873. bool inplace) {
  1874. struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
  1875. result->op = GGML_OP_SQRT;
  1876. result->src[0] = a;
  1877. return result;
  1878. }
  1879. struct ggml_tensor * ggml_sqrt(
  1880. struct ggml_context * ctx,
  1881. struct ggml_tensor * a) {
  1882. return ggml_sqrt_impl(ctx, a, false);
  1883. }
  1884. struct ggml_tensor * ggml_sqrt_inplace(
  1885. struct ggml_context * ctx,
  1886. struct ggml_tensor * a) {
  1887. return ggml_sqrt_impl(ctx, a, true);
  1888. }
  1889. // ggml_log
  1890. static struct ggml_tensor * ggml_log_impl(
  1891. struct ggml_context * ctx,
  1892. struct ggml_tensor * a,
  1893. bool inplace) {
  1894. struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
  1895. result->op = GGML_OP_LOG;
  1896. result->src[0] = a;
  1897. return result;
  1898. }
  1899. struct ggml_tensor * ggml_log(
  1900. struct ggml_context * ctx,
  1901. struct ggml_tensor * a) {
  1902. return ggml_log_impl(ctx, a, false);
  1903. }
  1904. struct ggml_tensor * ggml_log_inplace(
  1905. struct ggml_context * ctx,
  1906. struct ggml_tensor * a) {
  1907. return ggml_log_impl(ctx, a, true);
  1908. }
  1909. // ggml_sin
  1910. static struct ggml_tensor * ggml_sin_impl(
  1911. struct ggml_context * ctx,
  1912. struct ggml_tensor * a,
  1913. bool inplace) {
  1914. struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
  1915. result->op = GGML_OP_SIN;
  1916. result->src[0] = a;
  1917. return result;
  1918. }
  1919. struct ggml_tensor * ggml_sin(
  1920. struct ggml_context * ctx,
  1921. struct ggml_tensor * a) {
  1922. return ggml_sin_impl(ctx, a, false);
  1923. }
  1924. struct ggml_tensor * ggml_sin_inplace(
  1925. struct ggml_context * ctx,
  1926. struct ggml_tensor * a) {
  1927. return ggml_sin_impl(ctx, a, true);
  1928. }
  1929. // ggml_cos
  1930. static struct ggml_tensor * ggml_cos_impl(
  1931. struct ggml_context * ctx,
  1932. struct ggml_tensor * a,
  1933. bool inplace) {
  1934. struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
  1935. result->op = GGML_OP_COS;
  1936. result->src[0] = a;
  1937. return result;
  1938. }
  1939. struct ggml_tensor * ggml_cos(
  1940. struct ggml_context * ctx,
  1941. struct ggml_tensor * a) {
  1942. return ggml_cos_impl(ctx, a, false);
  1943. }
  1944. struct ggml_tensor * ggml_cos_inplace(
  1945. struct ggml_context * ctx,
  1946. struct ggml_tensor * a) {
  1947. return ggml_cos_impl(ctx, a, true);
  1948. }
  1949. // ggml_sum
  1950. struct ggml_tensor * ggml_sum(
  1951. struct ggml_context * ctx,
  1952. struct ggml_tensor * a) {
  1953. struct ggml_tensor * result = ggml_new_tensor_1d(ctx, a->type, 1);
  1954. result->op = GGML_OP_SUM;
  1955. result->src[0] = a;
  1956. return result;
  1957. }
  1958. // ggml_sum_rows
  1959. struct ggml_tensor * ggml_sum_rows(
  1960. struct ggml_context * ctx,
  1961. struct ggml_tensor * a) {
  1962. int64_t ne[GGML_MAX_DIMS] = { 1 };
  1963. for (int i = 1; i < GGML_MAX_DIMS; ++i) {
  1964. ne[i] = a->ne[i];
  1965. }
  1966. struct ggml_tensor * result = ggml_new_tensor(ctx, a->type, GGML_MAX_DIMS, ne);
  1967. result->op = GGML_OP_SUM_ROWS;
  1968. result->src[0] = a;
  1969. return result;
  1970. }
  1971. // ggml_mean
  1972. struct ggml_tensor * ggml_mean(
  1973. struct ggml_context * ctx,
  1974. struct ggml_tensor * a) {
  1975. int64_t ne[4] = { 1, a->ne[1], a->ne[2], a->ne[3] };
  1976. struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
  1977. result->op = GGML_OP_MEAN;
  1978. result->src[0] = a;
  1979. return result;
  1980. }
  1981. // ggml_argmax
  1982. struct ggml_tensor * ggml_argmax(
  1983. struct ggml_context * ctx,
  1984. struct ggml_tensor * a) {
  1985. GGML_ASSERT(ggml_is_matrix(a));
  1986. GGML_ASSERT(a->ne[0] <= INT32_MAX);
  1987. struct ggml_tensor * result = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, a->ne[1]);
  1988. result->op = GGML_OP_ARGMAX;
  1989. result->src[0] = a;
  1990. return result;
  1991. }
  1992. // ggml_count_equal
  1993. struct ggml_tensor * ggml_count_equal(
  1994. struct ggml_context * ctx,
  1995. struct ggml_tensor * a,
  1996. struct ggml_tensor * b) {
  1997. GGML_ASSERT(ggml_are_same_shape(a, b));
  1998. struct ggml_tensor * result = ggml_new_tensor_1d(ctx, GGML_TYPE_I64, 1);
  1999. result->op = GGML_OP_COUNT_EQUAL;
  2000. result->src[0] = a;
  2001. result->src[1] = b;
  2002. return result;
  2003. }
  2004. // ggml_repeat
  2005. struct ggml_tensor * ggml_repeat(
  2006. struct ggml_context * ctx,
  2007. struct ggml_tensor * a,
  2008. struct ggml_tensor * b) {
  2009. GGML_ASSERT(ggml_can_repeat(a, b));
  2010. struct ggml_tensor * result = ggml_new_tensor(ctx, a->type, GGML_MAX_DIMS, b->ne);
  2011. result->op = GGML_OP_REPEAT;
  2012. result->src[0] = a;
  2013. return result;
  2014. }
  2015. struct ggml_tensor * ggml_repeat_4d(
  2016. struct ggml_context * ctx,
  2017. struct ggml_tensor * a,
  2018. int64_t ne0, int64_t ne1, int64_t ne2, int64_t ne3) {
  2019. const bool can_repeat = ggml_is_empty(a) || (
  2020. (ne0 % a->ne[0] == 0) &&
  2021. (ne1 % a->ne[1] == 0) &&
  2022. (ne2 % a->ne[2] == 0) &&
  2023. (ne3 % a->ne[3] == 0)
  2024. );
  2025. GGML_ASSERT(can_repeat);
  2026. struct ggml_tensor * result = ggml_new_tensor_4d(ctx, a->type, ne0, ne1, ne2, ne3);
  2027. result->op = GGML_OP_REPEAT;
  2028. result->src[0] = a;
  2029. return result;
  2030. }
  2031. // ggml_repeat_back
  2032. struct ggml_tensor * ggml_repeat_back(
  2033. struct ggml_context * ctx,
  2034. struct ggml_tensor * a,
  2035. struct ggml_tensor * b) {
  2036. GGML_ASSERT(ggml_can_repeat(b, a));
  2037. struct ggml_tensor * result = ggml_new_tensor(ctx, a->type, GGML_MAX_DIMS, b->ne);
  2038. result->op = GGML_OP_REPEAT_BACK;
  2039. result->src[0] = a;
  2040. return result;
  2041. }
  2042. // ggml_concat
  2043. struct ggml_tensor * ggml_concat(
  2044. struct ggml_context * ctx,
  2045. struct ggml_tensor * a,
  2046. struct ggml_tensor * b,
  2047. int dim) {
  2048. GGML_ASSERT(dim >= 0 && dim < GGML_MAX_DIMS);
  2049. GGML_ASSERT(a->type == b->type);
  2050. int64_t ne[GGML_MAX_DIMS];
  2051. for (int d = 0; d < GGML_MAX_DIMS; ++d) {
  2052. if (d == dim) {
  2053. ne[d] = a->ne[d] + b->ne[d];
  2054. continue;
  2055. }
  2056. GGML_ASSERT(a->ne[d] == b->ne[d]);
  2057. ne[d] = a->ne[d];
  2058. }
  2059. struct ggml_tensor * result = ggml_new_tensor(ctx, a->type, GGML_MAX_DIMS, ne);
  2060. ggml_set_op_params_i32(result, 0, dim);
  2061. result->op = GGML_OP_CONCAT;
  2062. result->src[0] = a;
  2063. result->src[1] = b;
  2064. return result;
  2065. }
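// Example (assumed shapes): concatenating along the second dimension, e.g. a
// of shape {head_dim, n_past, n_head, 1} with b of shape
// {head_dim, n_new, n_head, 1}:
//
//     struct ggml_tensor * kv = ggml_concat(ctx, a, b, 1);
//
// yields {head_dim, n_past + n_new, n_head, 1}; every dimension other than
// `dim` must match exactly and both inputs must have the same type.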
  2066. // ggml_abs
  2067. struct ggml_tensor * ggml_abs(
  2068. struct ggml_context * ctx,
  2069. struct ggml_tensor * a) {
  2070. return ggml_unary(ctx, a, GGML_UNARY_OP_ABS);
  2071. }
  2072. struct ggml_tensor * ggml_abs_inplace(
  2073. struct ggml_context * ctx,
  2074. struct ggml_tensor * a) {
  2075. return ggml_unary_inplace(ctx, a, GGML_UNARY_OP_ABS);
  2076. }
  2077. // ggml_sgn
  2078. struct ggml_tensor * ggml_sgn(
  2079. struct ggml_context * ctx,
  2080. struct ggml_tensor * a) {
  2081. return ggml_unary(ctx, a, GGML_UNARY_OP_SGN);
  2082. }
  2083. struct ggml_tensor * ggml_sgn_inplace(
  2084. struct ggml_context * ctx,
  2085. struct ggml_tensor * a) {
  2086. return ggml_unary_inplace(ctx, a, GGML_UNARY_OP_SGN);
  2087. }
  2088. // ggml_neg
  2089. struct ggml_tensor * ggml_neg(
  2090. struct ggml_context * ctx,
  2091. struct ggml_tensor * a) {
  2092. return ggml_unary(ctx, a, GGML_UNARY_OP_NEG);
  2093. }
  2094. struct ggml_tensor * ggml_neg_inplace(
  2095. struct ggml_context * ctx,
  2096. struct ggml_tensor * a) {
  2097. return ggml_unary_inplace(ctx, a, GGML_UNARY_OP_NEG);
  2098. }
  2099. // ggml_step
  2100. struct ggml_tensor * ggml_step(
  2101. struct ggml_context * ctx,
  2102. struct ggml_tensor * a) {
  2103. return ggml_unary(ctx, a, GGML_UNARY_OP_STEP);
  2104. }
  2105. struct ggml_tensor * ggml_step_inplace(
  2106. struct ggml_context * ctx,
  2107. struct ggml_tensor * a) {
  2108. return ggml_unary_inplace(ctx, a, GGML_UNARY_OP_STEP);
  2109. }
  2110. // ggml_tanh
  2111. struct ggml_tensor * ggml_tanh(
  2112. struct ggml_context * ctx,
  2113. struct ggml_tensor * a) {
  2114. return ggml_unary(ctx, a, GGML_UNARY_OP_TANH);
  2115. }
  2116. struct ggml_tensor * ggml_tanh_inplace(
  2117. struct ggml_context * ctx,
  2118. struct ggml_tensor * a) {
  2119. return ggml_unary_inplace(ctx, a, GGML_UNARY_OP_TANH);
  2120. }
  2121. // ggml_elu
  2122. struct ggml_tensor * ggml_elu(
  2123. struct ggml_context * ctx,
  2124. struct ggml_tensor * a) {
  2125. return ggml_unary(ctx, a, GGML_UNARY_OP_ELU);
  2126. }
  2127. struct ggml_tensor * ggml_elu_inplace(
  2128. struct ggml_context * ctx,
  2129. struct ggml_tensor * a) {
  2130. return ggml_unary_inplace(ctx, a, GGML_UNARY_OP_ELU);
  2131. }
  2132. // ggml_relu
  2133. struct ggml_tensor * ggml_relu(
  2134. struct ggml_context * ctx,
  2135. struct ggml_tensor * a) {
  2136. return ggml_unary(ctx, a, GGML_UNARY_OP_RELU);
  2137. }
  2138. struct ggml_tensor * ggml_relu_inplace(
  2139. struct ggml_context * ctx,
  2140. struct ggml_tensor * a) {
  2141. return ggml_unary_inplace(ctx, a, GGML_UNARY_OP_RELU);
  2142. }
  2143. // ggml_leaky_relu
  2144. struct ggml_tensor * ggml_leaky_relu(
  2145. struct ggml_context * ctx,
  2146. struct ggml_tensor * a,
  2147. float negative_slope,
  2148. bool inplace) {
  2149. struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
  2150. ggml_set_op_params(result, &negative_slope, sizeof(negative_slope));
  2151. result->op = GGML_OP_LEAKY_RELU;
  2152. result->src[0] = a;
  2153. return result;
  2154. }
  2155. // ggml_sigmoid
  2156. struct ggml_tensor * ggml_sigmoid(
  2157. struct ggml_context * ctx,
  2158. struct ggml_tensor * a) {
  2159. return ggml_unary(ctx, a, GGML_UNARY_OP_SIGMOID);
  2160. }
  2161. struct ggml_tensor * ggml_sigmoid_inplace(
  2162. struct ggml_context * ctx,
  2163. struct ggml_tensor * a) {
  2164. return ggml_unary_inplace(ctx, a, GGML_UNARY_OP_SIGMOID);
  2165. }
  2166. // ggml_gelu
  2167. struct ggml_tensor * ggml_gelu(
  2168. struct ggml_context * ctx,
  2169. struct ggml_tensor * a) {
  2170. return ggml_unary(ctx, a, GGML_UNARY_OP_GELU);
  2171. }
  2172. struct ggml_tensor * ggml_gelu_inplace(
  2173. struct ggml_context * ctx,
  2174. struct ggml_tensor * a) {
  2175. return ggml_unary_inplace(ctx, a, GGML_UNARY_OP_GELU);
  2176. }
  2177. // ggml_gelu_erf
  2178. struct ggml_tensor * ggml_gelu_erf(
  2179. struct ggml_context * ctx,
  2180. struct ggml_tensor * a) {
  2181. return ggml_unary(ctx, a, GGML_UNARY_OP_GELU_ERF);
  2182. }
  2183. struct ggml_tensor * ggml_gelu_erf_inplace(
  2184. struct ggml_context * ctx,
  2185. struct ggml_tensor * a) {
  2186. return ggml_unary_inplace(ctx, a, GGML_UNARY_OP_GELU_ERF);
  2187. }
  2188. // ggml_gelu_quick
  2189. struct ggml_tensor * ggml_gelu_quick(
  2190. struct ggml_context * ctx,
  2191. struct ggml_tensor * a) {
  2192. return ggml_unary(ctx, a, GGML_UNARY_OP_GELU_QUICK);
  2193. }
  2194. struct ggml_tensor * ggml_gelu_quick_inplace(
  2195. struct ggml_context * ctx,
  2196. struct ggml_tensor * a) {
  2197. return ggml_unary_inplace(ctx, a, GGML_UNARY_OP_GELU_QUICK);
  2198. }
  2199. // ggml_silu
  2200. struct ggml_tensor * ggml_silu(
  2201. struct ggml_context * ctx,
  2202. struct ggml_tensor * a) {
  2203. return ggml_unary(ctx, a, GGML_UNARY_OP_SILU);
  2204. }
  2205. struct ggml_tensor * ggml_silu_inplace(
  2206. struct ggml_context * ctx,
  2207. struct ggml_tensor * a) {
  2208. return ggml_unary_inplace(ctx, a, GGML_UNARY_OP_SILU);
  2209. }
  2210. // ggml_xielu
  2211. struct ggml_tensor * ggml_xielu(
  2212. struct ggml_context * ctx,
  2213. struct ggml_tensor * a,
  2214. float alpha_n,
  2215. float alpha_p,
  2216. float beta,
  2217. float eps) {
  2218. struct ggml_tensor * result = ggml_dup_tensor(ctx, a);
  2219. ggml_set_op_params_i32(result, 0, (int32_t) GGML_UNARY_OP_XIELU);
  2220. ggml_set_op_params_f32(result, 1, beta + ggml_softplus(alpha_n));
  2221. ggml_set_op_params_f32(result, 2, ggml_softplus(alpha_p));
  2222. ggml_set_op_params_f32(result, 3, beta);
  2223. ggml_set_op_params_f32(result, 4, eps);
  2224. result->op = GGML_OP_UNARY;
  2225. result->src[0] = a;
  2226. return result;
  2227. }
  2228. // ggml_silu_back
  2229. struct ggml_tensor * ggml_silu_back(
  2230. struct ggml_context * ctx,
  2231. struct ggml_tensor * a,
  2232. struct ggml_tensor * b) {
  2233. struct ggml_tensor * result = ggml_dup_tensor(ctx, a);
  2234. result->op = GGML_OP_SILU_BACK;
  2235. result->src[0] = a;
  2236. result->src[1] = b;
  2237. return result;
  2238. }
  2239. // ggml hardswish
  2240. struct ggml_tensor * ggml_hardswish(
  2241. struct ggml_context * ctx,
  2242. struct ggml_tensor * a) {
  2243. return ggml_unary(ctx, a, GGML_UNARY_OP_HARDSWISH);
  2244. }
  2245. // ggml hardsigmoid
  2246. struct ggml_tensor * ggml_hardsigmoid(
  2247. struct ggml_context * ctx,
  2248. struct ggml_tensor * a) {
  2249. return ggml_unary(ctx, a, GGML_UNARY_OP_HARDSIGMOID);
  2250. }
  2251. // ggml exp
  2252. struct ggml_tensor * ggml_exp(
  2253. struct ggml_context * ctx,
  2254. struct ggml_tensor * a) {
  2255. return ggml_unary(ctx, a, GGML_UNARY_OP_EXP);
  2256. }
  2257. struct ggml_tensor * ggml_exp_inplace(
  2258. struct ggml_context * ctx,
  2259. struct ggml_tensor * a) {
  2260. return ggml_unary_inplace(ctx, a, GGML_UNARY_OP_EXP);
  2261. }
  2262. // ggml_glu
static struct ggml_tensor * ggml_glu_impl(
        struct ggml_context * ctx,
        struct ggml_tensor  * a,
        struct ggml_tensor  * b,
        enum ggml_glu_op      op,
        bool                  swapped) {
    GGML_ASSERT(ggml_is_contiguous_1(a));

    if (b) {
        GGML_ASSERT(ggml_is_contiguous_1(b));
        GGML_ASSERT(ggml_are_same_shape(a, b));
        GGML_ASSERT(a->type == b->type);
    }

    int64_t ne[GGML_MAX_DIMS] = { a->ne[0] / 2 };
    for (int i = 1; i < GGML_MAX_DIMS; i++) {
        ne[i] = a->ne[i];
    }

    struct ggml_tensor * result = ggml_new_tensor_impl(ctx, a->type, GGML_MAX_DIMS, b ? a->ne : ne, NULL, 0);

    ggml_set_op_params_i32(result, 0, (int32_t) op);
    ggml_set_op_params_i32(result, 1, (int32_t) swapped);

    result->op     = GGML_OP_GLU;
    result->src[0] = a;
    result->src[1] = b;

    return result;
}
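// Shape note (sketch): with a single input the two halves of a's innermost
// dimension act as the gated pair, so the result has ne[0] = a->ne[0]/2 and
// `swapped` selects which half is passed through the activation; in the split
// form (b != NULL) the halves come from a and b separately and the result
// keeps a's full shape.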
  2284. // ggml_floor
  2285. struct ggml_tensor * ggml_floor(
  2286. struct ggml_context * ctx,
  2287. struct ggml_tensor * a) {
  2288. return ggml_unary(ctx, a, GGML_UNARY_OP_FLOOR);
  2289. }
  2290. struct ggml_tensor * ggml_floor_inplace(
  2291. struct ggml_context * ctx,
  2292. struct ggml_tensor * a) {
  2293. return ggml_unary_inplace(ctx, a, GGML_UNARY_OP_FLOOR);
  2294. }
  2295. // ggml_ceil
  2296. struct ggml_tensor * ggml_ceil(
  2297. struct ggml_context * ctx,
  2298. struct ggml_tensor * a) {
  2299. return ggml_unary(ctx, a, GGML_UNARY_OP_CEIL);
  2300. }
  2301. struct ggml_tensor * ggml_ceil_inplace(
  2302. struct ggml_context * ctx,
  2303. struct ggml_tensor * a) {
  2304. return ggml_unary_inplace(ctx, a, GGML_UNARY_OP_CEIL);
  2305. }
  2306. //ggml_round
  2307. struct ggml_tensor * ggml_round(
  2308. struct ggml_context * ctx,
  2309. struct ggml_tensor * a) {
  2310. return ggml_unary(ctx, a, GGML_UNARY_OP_ROUND);
  2311. }
  2312. struct ggml_tensor * ggml_round_inplace(
  2313. struct ggml_context * ctx,
  2314. struct ggml_tensor * a) {
  2315. return ggml_unary_inplace(ctx, a, GGML_UNARY_OP_ROUND);
  2316. }
  2317. //ggml_trunc
  2318. struct ggml_tensor * ggml_trunc(
  2319. struct ggml_context * ctx,
  2320. struct ggml_tensor * a) {
  2321. return ggml_unary(ctx, a, GGML_UNARY_OP_TRUNC);
  2322. }
  2323. struct ggml_tensor * ggml_trunc_inplace(
  2324. struct ggml_context * ctx,
  2325. struct ggml_tensor * a) {
  2326. return ggml_unary_inplace(ctx, a, GGML_UNARY_OP_TRUNC);
  2327. }
  2328. struct ggml_tensor * ggml_glu(
  2329. struct ggml_context * ctx,
  2330. struct ggml_tensor * a,
  2331. enum ggml_glu_op op,
  2332. bool swapped) {
  2333. return ggml_glu_impl(ctx, a, NULL, op, swapped);
  2334. }
  2335. struct ggml_tensor * ggml_glu_split(
  2336. struct ggml_context * ctx,
  2337. struct ggml_tensor * a,
  2338. struct ggml_tensor * b,
  2339. enum ggml_glu_op op) {
  2340. return ggml_glu_impl(ctx, a, b, op, false);
  2341. }
  2342. // ggml_reglu
  2343. struct ggml_tensor * ggml_reglu(
  2344. struct ggml_context * ctx,
  2345. struct ggml_tensor * a) {
  2346. return ggml_glu_impl(ctx, a, NULL, GGML_GLU_OP_REGLU, false);
  2347. }
  2348. struct ggml_tensor * ggml_reglu_swapped(
  2349. struct ggml_context * ctx,
  2350. struct ggml_tensor * a) {
  2351. return ggml_glu_impl(ctx, a, NULL, GGML_GLU_OP_REGLU, true);
  2352. }
  2353. struct ggml_tensor * ggml_reglu_split(
  2354. struct ggml_context * ctx,
  2355. struct ggml_tensor * a,
  2356. struct ggml_tensor * b) {
  2357. return ggml_glu_impl(ctx, a, b, GGML_GLU_OP_REGLU, false);
  2358. }
  2359. // ggml_geglu
  2360. struct ggml_tensor * ggml_geglu(
  2361. struct ggml_context * ctx,
  2362. struct ggml_tensor * a) {
  2363. return ggml_glu_impl(ctx, a, NULL, GGML_GLU_OP_GEGLU, false);
  2364. }
  2365. struct ggml_tensor * ggml_geglu_swapped(
  2366. struct ggml_context * ctx,
  2367. struct ggml_tensor * a) {
  2368. return ggml_glu_impl(ctx, a, NULL, GGML_GLU_OP_GEGLU, true);
  2369. }
  2370. struct ggml_tensor * ggml_geglu_split(
  2371. struct ggml_context * ctx,
  2372. struct ggml_tensor * a,
  2373. struct ggml_tensor * b) {
  2374. return ggml_glu_impl(ctx, a, b, GGML_GLU_OP_GEGLU, false);
  2375. }
  2376. // ggml_swiglu
  2377. struct ggml_tensor * ggml_swiglu(
  2378. struct ggml_context * ctx,
  2379. struct ggml_tensor * a) {
  2380. return ggml_glu_impl(ctx, a, NULL, GGML_GLU_OP_SWIGLU, false);
  2381. }
  2382. struct ggml_tensor * ggml_swiglu_swapped(
  2383. struct ggml_context * ctx,
  2384. struct ggml_tensor * a) {
  2385. return ggml_glu_impl(ctx, a, NULL, GGML_GLU_OP_SWIGLU, true);
  2386. }
  2387. struct ggml_tensor * ggml_swiglu_split(
  2388. struct ggml_context * ctx,
  2389. struct ggml_tensor * a,
  2390. struct ggml_tensor * b) {
  2391. return ggml_glu_impl(ctx, a, b, GGML_GLU_OP_SWIGLU, false);
  2392. }
  2393. // ggml_geglu_erf
  2394. struct ggml_tensor * ggml_geglu_erf(
  2395. struct ggml_context * ctx,
  2396. struct ggml_tensor * a) {
  2397. return ggml_glu_impl(ctx, a, NULL, GGML_GLU_OP_GEGLU_ERF, false);
  2398. }
  2399. struct ggml_tensor * ggml_geglu_erf_swapped(
  2400. struct ggml_context * ctx,
  2401. struct ggml_tensor * a) {
  2402. return ggml_glu_impl(ctx, a, NULL, GGML_GLU_OP_GEGLU_ERF, true);
  2403. }
  2404. struct ggml_tensor * ggml_geglu_erf_split(
  2405. struct ggml_context * ctx,
  2406. struct ggml_tensor * a,
  2407. struct ggml_tensor * b) {
  2408. return ggml_glu_impl(ctx, a, b, GGML_GLU_OP_GEGLU_ERF, false);
  2409. }
  2410. // ggml_geglu_quick
  2411. struct ggml_tensor * ggml_geglu_quick(
  2412. struct ggml_context * ctx,
  2413. struct ggml_tensor * a) {
  2414. return ggml_glu_impl(ctx, a, NULL, GGML_GLU_OP_GEGLU_QUICK, false);
  2415. }
  2416. struct ggml_tensor * ggml_geglu_quick_swapped(
  2417. struct ggml_context * ctx,
  2418. struct ggml_tensor * a) {
  2419. return ggml_glu_impl(ctx, a, NULL, GGML_GLU_OP_GEGLU_QUICK, true);
  2420. }
  2421. struct ggml_tensor * ggml_geglu_quick_split(
  2422. struct ggml_context * ctx,
  2423. struct ggml_tensor * a,
  2424. struct ggml_tensor * b) {
  2425. return ggml_glu_impl(ctx, a, b, GGML_GLU_OP_GEGLU_QUICK, false);
  2426. }
  2427. struct ggml_tensor * ggml_swiglu_oai(
  2428. struct ggml_context * ctx,
  2429. struct ggml_tensor * a,
  2430. struct ggml_tensor * b,
  2431. float alpha,
  2432. float limit) {
  2433. struct ggml_tensor * result = ggml_glu_impl(ctx, a, b, GGML_GLU_OP_SWIGLU_OAI, false);
  2434. ggml_set_op_params_f32(result, 2, alpha);
  2435. ggml_set_op_params_f32(result, 3, limit);
  2436. return result;
  2437. }
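// Editorial usage sketch (not part of ggml.c): a SwiGLU feed-forward block built from the
// GLU helpers above. Assumes a translation unit that includes "ggml.h"; tensor names and
// shapes are illustrative: `cur` is [n_embd, n_tokens], `ffn_gate`/`ffn_up` are
// [n_embd, n_ff] and `ffn_down` is [n_ff, n_embd]. The split variant applies SiLU to its
// first operand and multiplies by the second; the fused ggml_swiglu(ctx, a) takes both
// halves from a single tensor instead.
static struct ggml_tensor * example_swiglu_ffn(
        struct ggml_context * ctx,
        struct ggml_tensor  * cur,
        struct ggml_tensor  * ffn_gate,
        struct ggml_tensor  * ffn_up,
        struct ggml_tensor  * ffn_down) {
    struct ggml_tensor * gate = ggml_mul_mat(ctx, ffn_gate, cur); // [n_ff, n_tokens]
    struct ggml_tensor * up   = ggml_mul_mat(ctx, ffn_up,   cur); // [n_ff, n_tokens]
    cur = ggml_swiglu_split(ctx, gate, up);                       // silu(gate) * up
    return ggml_mul_mat(ctx, ffn_down, cur);                      // [n_embd, n_tokens]
}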
  2438. // ggml_norm
  2439. static struct ggml_tensor * ggml_norm_impl(
  2440. struct ggml_context * ctx,
  2441. struct ggml_tensor * a,
  2442. float eps,
  2443. bool inplace) {
  2444. struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
  2445. ggml_set_op_params(result, &eps, sizeof(eps));
  2446. result->op = GGML_OP_NORM;
  2447. result->src[0] = a;
  2448. return result;
  2449. }
  2450. struct ggml_tensor * ggml_norm(
  2451. struct ggml_context * ctx,
  2452. struct ggml_tensor * a,
  2453. float eps) {
  2454. return ggml_norm_impl(ctx, a, eps, false);
  2455. }
  2456. struct ggml_tensor * ggml_norm_inplace(
  2457. struct ggml_context * ctx,
  2458. struct ggml_tensor * a,
  2459. float eps) {
  2460. return ggml_norm_impl(ctx, a, eps, true);
  2461. }
  2462. // ggml_rms_norm
  2463. static struct ggml_tensor * ggml_rms_norm_impl(
  2464. struct ggml_context * ctx,
  2465. struct ggml_tensor * a,
  2466. float eps,
  2467. bool inplace) {
  2468. struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
  2469. ggml_set_op_params(result, &eps, sizeof(eps));
  2470. result->op = GGML_OP_RMS_NORM;
  2471. result->src[0] = a;
  2472. return result;
  2473. }
  2474. struct ggml_tensor * ggml_rms_norm(
  2475. struct ggml_context * ctx,
  2476. struct ggml_tensor * a,
  2477. float eps) {
  2478. return ggml_rms_norm_impl(ctx, a, eps, false);
  2479. }
  2480. struct ggml_tensor * ggml_rms_norm_inplace(
  2481. struct ggml_context * ctx,
  2482. struct ggml_tensor * a,
  2483. float eps) {
  2484. return ggml_rms_norm_impl(ctx, a, eps, true);
  2485. }
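// Editorial usage sketch (not part of ggml.c): the usual pre-norm pattern around
// ggml_rms_norm. Normalization runs along ne[0] of each row; the learned weight is a
// separate [n_embd] tensor applied with ggml_mul (broadcast over tokens). Names, shapes
// and the eps value are illustrative, not prescriptive.
static struct ggml_tensor * example_rms_norm(
        struct ggml_context * ctx,
        struct ggml_tensor  * x,        // [n_embd, n_tokens]
        struct ggml_tensor  * norm_w) { // [n_embd]
    struct ggml_tensor * cur = ggml_rms_norm(ctx, x, 1e-5f);
    return ggml_mul(ctx, cur, norm_w);
}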
  2486. // ggml_rms_norm_back
  2487. struct ggml_tensor * ggml_rms_norm_back(
  2488. struct ggml_context * ctx,
  2489. struct ggml_tensor * a,
  2490. struct ggml_tensor * b,
  2491. float eps) {
  2492. struct ggml_tensor * result = ggml_dup_tensor(ctx, a);
  2493. ggml_set_op_params(result, &eps, sizeof(eps));
  2494. result->op = GGML_OP_RMS_NORM_BACK;
  2495. result->src[0] = a;
  2496. result->src[1] = b;
  2497. return result;
  2498. }
  2499. // ggml_group_norm
  2500. static struct ggml_tensor * ggml_group_norm_impl(
  2501. struct ggml_context * ctx,
  2502. struct ggml_tensor * a,
  2503. int n_groups,
  2504. float eps,
  2505. bool inplace) {
  2506. struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
  2507. ggml_set_op_params_i32(result, 0, n_groups);
  2508. ggml_set_op_params_f32(result, 1, eps);
  2509. result->op = GGML_OP_GROUP_NORM;
  2510. result->src[0] = a;
  2511. return result;
  2512. }
  2513. struct ggml_tensor * ggml_group_norm(
  2514. struct ggml_context * ctx,
  2515. struct ggml_tensor * a,
  2516. int n_groups,
  2517. float eps) {
  2518. return ggml_group_norm_impl(ctx, a, n_groups, eps, false);
  2519. }
  2520. struct ggml_tensor * ggml_group_norm_inplace(
  2521. struct ggml_context * ctx,
  2522. struct ggml_tensor * a,
  2523. int n_groups,
  2524. float eps) {
  2525. return ggml_group_norm_impl(ctx, a, n_groups, eps, true);
  2526. }
  2527. // ggml_l2_norm
  2528. static struct ggml_tensor * ggml_l2_norm_impl(
  2529. struct ggml_context * ctx,
  2530. struct ggml_tensor * a,
  2531. float eps,
  2532. bool inplace) {
  2533. struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
  2534. ggml_set_op_params_f32(result, 0, eps);
  2535. result->op = GGML_OP_L2_NORM;
  2536. result->src[0] = a;
  2537. return result;
  2538. }
  2539. struct ggml_tensor * ggml_l2_norm(
  2540. struct ggml_context * ctx,
  2541. struct ggml_tensor * a,
  2542. float eps) {
  2543. return ggml_l2_norm_impl(ctx, a, eps, false);
  2544. }
  2545. struct ggml_tensor * ggml_l2_norm_inplace(
  2546. struct ggml_context * ctx,
  2547. struct ggml_tensor * a,
  2548. float eps) {
  2549. return ggml_l2_norm_impl(ctx, a, eps, true);
  2550. }
  2551. // ggml_mul_mat
  2552. static inline bool ggml_can_mul_mat(const struct ggml_tensor * t0, const struct ggml_tensor * t1) {
  2553. static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
  2554. return (t0->ne[0] == t1->ne[0]) &&
  2555. (t1->ne[2]%t0->ne[2] == 0) && // verify t0 is broadcastable
  2556. (t1->ne[3]%t0->ne[3] == 0);
  2557. }
  2558. struct ggml_tensor * ggml_mul_mat(
  2559. struct ggml_context * ctx,
  2560. struct ggml_tensor * a,
  2561. struct ggml_tensor * b) {
  2562. GGML_ASSERT(ggml_can_mul_mat(a, b));
  2563. GGML_ASSERT(!ggml_is_transposed(a));
  2564. const int64_t ne[4] = { a->ne[1], b->ne[1], b->ne[2], b->ne[3] };
  2565. struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
  2566. result->op = GGML_OP_MUL_MAT;
  2567. result->src[0] = a;
  2568. result->src[1] = b;
  2569. return result;
  2570. }
  2571. void ggml_mul_mat_set_prec(
  2572. struct ggml_tensor * a,
  2573. enum ggml_prec prec) {
  2574. GGML_ASSERT(a->op == GGML_OP_MUL_MAT);
  2575. const int32_t prec_i32 = (int32_t) prec;
  2576. ggml_set_op_params_i32(a, 0, prec_i32);
  2577. }
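// Editorial usage sketch (not part of ggml.c): the ggml_mul_mat shape convention. ne[0] is
// the contiguous "row length"; both operands must agree on it, the contraction runs over
// that dimension, and the result has ne = { a->ne[1], b->ne[1], b->ne[2], b->ne[3] } in F32.
// Sizes below are illustrative.
static struct ggml_tensor * example_mul_mat(struct ggml_context * ctx) {
    struct ggml_tensor * a = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 64,  32); // K=64, M=32
    struct ggml_tensor * b = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 64, 128); // K=64, N=128
    struct ggml_tensor * c = ggml_mul_mat(ctx, a, b);                         // ne = { 32, 128 }
    // optionally request F32 accumulation (assumes the GGML_PREC_F32 value from ggml.h):
    // ggml_mul_mat_set_prec(c, GGML_PREC_F32);
    return c;
}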
  2578. // ggml_mul_mat_id
  2579. /*
  2580. c = ggml_mul_mat_id(ctx, as, b, ids);
  2581. as -> [cols, rows, n_expert]
  2582. b -> [cols, n_expert_used, n_tokens]
  2583. ids -> [n_expert_used, n_tokens] (i32)
  2584. c -> [rows, n_expert_used, n_tokens]
2585. in b, n_expert_used can be broadcast to match the n_expert_used of ids
2585. in b, n_expert_used can be broadcast to match the n_expert_used of ids
  2586. c ~= as[:,:,i] @ b[:,i%r,t], i = ids[e,t] for all e,t in ids
  2587. */
  2588. struct ggml_tensor * ggml_mul_mat_id(
  2589. struct ggml_context * ctx,
  2590. struct ggml_tensor * as,
  2591. struct ggml_tensor * b,
  2592. struct ggml_tensor * ids) {
  2593. GGML_ASSERT(!ggml_is_transposed(as));
  2594. GGML_ASSERT(ids->type == GGML_TYPE_I32);
  2595. GGML_ASSERT(as->ne[3] == 1); // as is 3d (one matrix per expert)
  2596. GGML_ASSERT(b->ne[3] == 1); // b is 3d
  2597. GGML_ASSERT(ids->ne[2] == 1 && ids->ne[3] == 1); // ids is 2d
  2598. GGML_ASSERT(ids->ne[1] == b->ne[2]); // must have an expert list per b row
  2599. GGML_ASSERT(as->ne[0] == b->ne[0]); // can_mul_mat
  2600. GGML_ASSERT(ids->ne[0] % b->ne[1] == 0); // can broadcast
  2601. const int64_t ne[4] = { as->ne[1], ids->ne[0], b->ne[2], 1 };
  2602. struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
  2603. result->op = GGML_OP_MUL_MAT_ID;
  2604. result->src[0] = as;
  2605. result->src[1] = b;
  2606. result->src[2] = ids;
  2607. return result;
  2608. }
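// Editorial usage sketch (not part of ggml.c): mixture-of-experts routing with
// ggml_mul_mat_id, following the shape comment above. All sizes are illustrative; in a real
// graph `ids` is produced by a router (e.g. top-k over gating logits) rather than left empty.
static struct ggml_tensor * example_moe(struct ggml_context * ctx) {
    const int64_t cols = 64, rows = 128, n_expert = 8, n_expert_used = 2, n_tokens = 4;
    struct ggml_tensor * as  = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, cols, rows, n_expert);
    struct ggml_tensor * b   = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, cols, n_expert_used, n_tokens);
    struct ggml_tensor * ids = ggml_new_tensor_2d(ctx, GGML_TYPE_I32,  n_expert_used, n_tokens);
    return ggml_mul_mat_id(ctx, as, b, ids); // [rows, n_expert_used, n_tokens]
}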
  2609. // ggml_out_prod
  2610. static inline bool ggml_can_out_prod(const struct ggml_tensor * t0, const struct ggml_tensor * t1) {
  2611. static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
  2612. return (t0->ne[1] == t1->ne[1]) &&
  2613. (t1->ne[2]%t0->ne[2] == 0) && // verify t0 is broadcastable
  2614. (t1->ne[3]%t0->ne[3] == 0);
  2615. }
  2616. struct ggml_tensor * ggml_out_prod(
  2617. struct ggml_context * ctx,
  2618. struct ggml_tensor * a,
  2619. struct ggml_tensor * b) {
  2620. GGML_ASSERT(ggml_can_out_prod(a, b));
  2621. GGML_ASSERT(!ggml_is_transposed(a));
  2622. // a is broadcastable to b for ne[2] and ne[3] -> use b->ne[2] and b->ne[3]
  2623. const int64_t ne[4] = { a->ne[0], b->ne[0], b->ne[2], b->ne[3] };
  2624. struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
  2625. result->op = GGML_OP_OUT_PROD;
  2626. result->src[0] = a;
  2627. result->src[1] = b;
  2628. return result;
  2629. }
  2630. // ggml_scale
  2631. static struct ggml_tensor * ggml_scale_impl(
  2632. struct ggml_context * ctx,
  2633. struct ggml_tensor * a,
  2634. float s,
  2635. float b,
  2636. bool inplace) {
  2637. GGML_ASSERT(ggml_is_padded_1d(a));
  2638. struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
  2639. float params[2] = { s, b };
  2640. ggml_set_op_params(result, &params, sizeof(params));
  2641. result->op = GGML_OP_SCALE;
  2642. result->src[0] = a;
  2643. return result;
  2644. }
  2645. struct ggml_tensor * ggml_scale(
  2646. struct ggml_context * ctx,
  2647. struct ggml_tensor * a,
  2648. float s) {
  2649. return ggml_scale_impl(ctx, a, s, 0.0, false);
  2650. }
  2651. struct ggml_tensor * ggml_scale_inplace(
  2652. struct ggml_context * ctx,
  2653. struct ggml_tensor * a,
  2654. float s) {
  2655. return ggml_scale_impl(ctx, a, s, 0.0, true);
  2656. }
  2657. struct ggml_tensor * ggml_scale_bias(
  2658. struct ggml_context * ctx,
  2659. struct ggml_tensor * a,
  2660. float s,
  2661. float b) {
  2662. return ggml_scale_impl(ctx, a, s, b, false);
  2663. }
  2664. struct ggml_tensor * ggml_scale_bias_inplace(
  2665. struct ggml_context * ctx,
  2666. struct ggml_tensor * a,
  2667. float s,
  2668. float b) {
  2669. return ggml_scale_impl(ctx, a, s, b, true);
  2670. }
  2671. // ggml_set
  2672. static struct ggml_tensor * ggml_set_impl(
  2673. struct ggml_context * ctx,
  2674. struct ggml_tensor * a,
  2675. struct ggml_tensor * b,
  2676. size_t nb1,
  2677. size_t nb2,
  2678. size_t nb3,
  2679. size_t offset,
  2680. bool inplace) {
  2681. GGML_ASSERT(ggml_nelements(a) >= ggml_nelements(b));
  2682. // make a view of the destination
  2683. struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
  2684. GGML_ASSERT(offset < (size_t)(1 << 30));
  2685. int32_t params[] = { nb1, nb2, nb3, offset, inplace ? 1 : 0 };
  2686. ggml_set_op_params(result, params, sizeof(params));
  2687. result->op = GGML_OP_SET;
  2688. result->src[0] = a;
  2689. result->src[1] = b;
  2690. return result;
  2691. }
  2692. struct ggml_tensor * ggml_set(
  2693. struct ggml_context * ctx,
  2694. struct ggml_tensor * a,
  2695. struct ggml_tensor * b,
  2696. size_t nb1,
  2697. size_t nb2,
  2698. size_t nb3,
  2699. size_t offset) {
  2700. return ggml_set_impl(ctx, a, b, nb1, nb2, nb3, offset, false);
  2701. }
  2702. struct ggml_tensor * ggml_set_inplace(
  2703. struct ggml_context * ctx,
  2704. struct ggml_tensor * a,
  2705. struct ggml_tensor * b,
  2706. size_t nb1,
  2707. size_t nb2,
  2708. size_t nb3,
  2709. size_t offset) {
  2710. return ggml_set_impl(ctx, a, b, nb1, nb2, nb3, offset, true);
  2711. }
  2712. struct ggml_tensor * ggml_set_1d(
  2713. struct ggml_context * ctx,
  2714. struct ggml_tensor * a,
  2715. struct ggml_tensor * b,
  2716. size_t offset) {
  2717. return ggml_set_impl(ctx, a, b, a->nb[1], a->nb[2], a->nb[3], offset, false);
  2718. }
  2719. struct ggml_tensor * ggml_set_1d_inplace(
  2720. struct ggml_context * ctx,
  2721. struct ggml_tensor * a,
  2722. struct ggml_tensor * b,
  2723. size_t offset) {
  2724. return ggml_set_impl(ctx, a, b, a->nb[1], a->nb[2], a->nb[3], offset, true);
  2725. }
  2726. struct ggml_tensor * ggml_set_2d(
  2727. struct ggml_context * ctx,
  2728. struct ggml_tensor * a,
  2729. struct ggml_tensor * b,
  2730. size_t nb1,
  2731. size_t offset) {
  2732. return ggml_set_impl(ctx, a, b, nb1, a->nb[2], a->nb[3], offset, false);
  2733. }
  2734. struct ggml_tensor * ggml_set_2d_inplace(
  2735. struct ggml_context * ctx,
  2736. struct ggml_tensor * a,
  2737. struct ggml_tensor * b,
  2738. size_t nb1,
  2739. size_t offset) {
  2740. return ggml_set_impl(ctx, a, b, nb1, a->nb[2], a->nb[3], offset, true);
  2741. }
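// Editorial usage sketch (not part of ggml.c): overwriting one row of a larger tensor with
// ggml_set_1d. Offsets are given in bytes, so the row stride nb[1] selects the target row.
// Names and shapes are illustrative.
static struct ggml_tensor * example_set_row(
        struct ggml_context * ctx,
        struct ggml_tensor  * dst,   // [n_embd, n_slots]
        struct ggml_tensor  * src,   // [n_embd]
        int64_t               slot) {
    return ggml_set_1d(ctx, dst, src, (size_t) slot * dst->nb[1]);
}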
  2742. // ggml_cpy
  2743. static struct ggml_tensor * ggml_cpy_impl(
  2744. struct ggml_context * ctx,
  2745. struct ggml_tensor * a,
  2746. struct ggml_tensor * b) {
  2747. GGML_ASSERT(ggml_nelements(a) == ggml_nelements(b));
  2748. // make a view of the destination
  2749. struct ggml_tensor * result = ggml_view_tensor(ctx, b);
  2750. if (strlen(b->name) > 0) {
  2751. ggml_format_name(result, "%s (copy of %s)", b->name, a->name);
  2752. } else {
  2753. ggml_format_name(result, "%s (copy)", a->name);
  2754. }
  2755. result->op = GGML_OP_CPY;
  2756. result->src[0] = a;
  2757. result->src[1] = b;
  2758. return result;
  2759. }
  2760. struct ggml_tensor * ggml_cpy(
  2761. struct ggml_context * ctx,
  2762. struct ggml_tensor * a,
  2763. struct ggml_tensor * b) {
  2764. return ggml_cpy_impl(ctx, a, b);
  2765. }
  2766. struct ggml_tensor * ggml_cast(
  2767. struct ggml_context * ctx,
  2768. struct ggml_tensor * a,
  2769. enum ggml_type type) {
  2770. struct ggml_tensor * result = ggml_new_tensor(ctx, type, GGML_MAX_DIMS, a->ne);
  2771. ggml_format_name(result, "%s (copy)", a->name);
  2772. result->op = GGML_OP_CPY;
  2773. result->src[0] = a;
  2774. result->src[1] = result;
  2775. return result;
  2776. }
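// Editorial usage sketch (not part of ggml.c): type conversion. ggml_cast allocates the typed
// destination itself, while ggml_cpy converts into an existing tensor (often a view of a
// cache). The F16 target here is illustrative.
static struct ggml_tensor * example_cast_f16(struct ggml_context * ctx, struct ggml_tensor * x_f32) {
    return ggml_cast(ctx, x_f32, GGML_TYPE_F16); // same shape, stored as F16
}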
  2777. // ggml_cont
  2778. static struct ggml_tensor * ggml_cont_impl(
  2779. struct ggml_context * ctx,
  2780. struct ggml_tensor * a) {
  2781. struct ggml_tensor * result = ggml_dup_tensor(ctx, a);
  2782. ggml_format_name(result, "%s (cont)", a->name);
  2783. result->op = GGML_OP_CONT;
  2784. result->src[0] = a;
  2785. return result;
  2786. }
  2787. struct ggml_tensor * ggml_cont(
  2788. struct ggml_context * ctx,
  2789. struct ggml_tensor * a) {
  2790. return ggml_cont_impl(ctx, a);
  2791. }
  2792. // make contiguous, with new shape
  2793. GGML_API struct ggml_tensor * ggml_cont_1d(
  2794. struct ggml_context * ctx,
  2795. struct ggml_tensor * a,
  2796. int64_t ne0) {
  2797. return ggml_cont_4d(ctx, a, ne0, 1, 1, 1);
  2798. }
  2799. GGML_API struct ggml_tensor * ggml_cont_2d(
  2800. struct ggml_context * ctx,
  2801. struct ggml_tensor * a,
  2802. int64_t ne0,
  2803. int64_t ne1) {
  2804. return ggml_cont_4d(ctx, a, ne0, ne1, 1, 1);
  2805. }
  2806. GGML_API struct ggml_tensor * ggml_cont_3d(
  2807. struct ggml_context * ctx,
  2808. struct ggml_tensor * a,
  2809. int64_t ne0,
  2810. int64_t ne1,
  2811. int64_t ne2) {
  2812. return ggml_cont_4d(ctx, a, ne0, ne1, ne2, 1);
  2813. }
  2814. struct ggml_tensor * ggml_cont_4d(
  2815. struct ggml_context * ctx,
  2816. struct ggml_tensor * a,
  2817. int64_t ne0,
  2818. int64_t ne1,
  2819. int64_t ne2,
  2820. int64_t ne3) {
  2821. GGML_ASSERT(ggml_nelements(a) == (ne0*ne1*ne2*ne3));
  2822. struct ggml_tensor * result = ggml_new_tensor_4d(ctx, a->type, ne0, ne1, ne2, ne3);
  2823. ggml_format_name(result, "%s (cont)", a->name);
  2824. result->op = GGML_OP_CONT;
  2825. result->src[0] = a;
  2826. return result;
  2827. }
  2828. // ggml_reshape
  2829. struct ggml_tensor * ggml_reshape(
  2830. struct ggml_context * ctx,
  2831. struct ggml_tensor * a,
  2832. struct ggml_tensor * b) {
  2833. GGML_ASSERT(ggml_is_contiguous(a));
2834. // as only the shape of b is relevant, and not its memory layout, b is allowed to be non-contiguous.
  2835. GGML_ASSERT(ggml_nelements(a) == ggml_nelements(b));
  2836. struct ggml_tensor * result = ggml_new_tensor_impl(ctx, a->type, GGML_MAX_DIMS, b->ne, a, 0);
  2837. ggml_format_name(result, "%s (reshaped)", a->name);
  2838. result->op = GGML_OP_RESHAPE;
  2839. result->src[0] = a;
  2840. return result;
  2841. }
  2842. struct ggml_tensor * ggml_reshape_1d(
  2843. struct ggml_context * ctx,
  2844. struct ggml_tensor * a,
  2845. int64_t ne0) {
  2846. GGML_ASSERT(ggml_is_contiguous(a));
  2847. GGML_ASSERT(ggml_nelements(a) == ne0);
  2848. const int64_t ne[1] = { ne0 };
  2849. struct ggml_tensor * result = ggml_new_tensor_impl(ctx, a->type, 1, ne, a, 0);
  2850. ggml_format_name(result, "%s (reshaped)", a->name);
  2851. result->op = GGML_OP_RESHAPE;
  2852. result->src[0] = a;
  2853. return result;
  2854. }
  2855. struct ggml_tensor * ggml_reshape_2d(
  2856. struct ggml_context * ctx,
  2857. struct ggml_tensor * a,
  2858. int64_t ne0,
  2859. int64_t ne1) {
  2860. GGML_ASSERT(ggml_is_contiguous(a));
  2861. GGML_ASSERT(ggml_nelements(a) == ne0*ne1);
  2862. const int64_t ne[2] = { ne0, ne1 };
  2863. struct ggml_tensor * result = ggml_new_tensor_impl(ctx, a->type, 2, ne, a, 0);
  2864. ggml_format_name(result, "%s (reshaped)", a->name);
  2865. result->op = GGML_OP_RESHAPE;
  2866. result->src[0] = a;
  2867. return result;
  2868. }
  2869. struct ggml_tensor * ggml_reshape_3d(
  2870. struct ggml_context * ctx,
  2871. struct ggml_tensor * a,
  2872. int64_t ne0,
  2873. int64_t ne1,
  2874. int64_t ne2) {
  2875. GGML_ASSERT(ggml_is_contiguous(a));
  2876. GGML_ASSERT(ggml_nelements(a) == ne0*ne1*ne2);
  2877. const int64_t ne[3] = { ne0, ne1, ne2 };
  2878. struct ggml_tensor * result = ggml_new_tensor_impl(ctx, a->type, 3, ne, a, 0);
  2879. ggml_format_name(result, "%s (reshaped)", a->name);
  2880. result->op = GGML_OP_RESHAPE;
  2881. result->src[0] = a;
  2882. return result;
  2883. }
  2884. struct ggml_tensor * ggml_reshape_4d(
  2885. struct ggml_context * ctx,
  2886. struct ggml_tensor * a,
  2887. int64_t ne0,
  2888. int64_t ne1,
  2889. int64_t ne2,
  2890. int64_t ne3) {
  2891. GGML_ASSERT(ggml_is_contiguous(a));
  2892. GGML_ASSERT(ggml_nelements(a) == ne0*ne1*ne2*ne3);
  2893. const int64_t ne[4] = { ne0, ne1, ne2, ne3 };
  2894. struct ggml_tensor * result = ggml_new_tensor_impl(ctx, a->type, 4, ne, a, 0);
  2895. ggml_format_name(result, "%s (reshaped)", a->name);
  2896. result->op = GGML_OP_RESHAPE;
  2897. result->src[0] = a;
  2898. return result;
  2899. }
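// Editorial usage sketch (not part of ggml.c): splitting a flat activation into attention
// heads with ggml_reshape_3d. The input must be contiguous (see the asserts above) and
// n_embd must equal head_dim * n_head; names are illustrative.
static struct ggml_tensor * example_split_heads(
        struct ggml_context * ctx,
        struct ggml_tensor  * cur,       // [n_embd, n_tokens], contiguous
        int64_t               head_dim,
        int64_t               n_head) {
    return ggml_reshape_3d(ctx, cur, head_dim, n_head, cur->ne[1]); // [head_dim, n_head, n_tokens]
}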
  2900. static struct ggml_tensor * ggml_view_impl(
  2901. struct ggml_context * ctx,
  2902. struct ggml_tensor * a,
  2903. int n_dims,
  2904. const int64_t * ne,
  2905. size_t offset) {
  2906. struct ggml_tensor * result = ggml_new_tensor_impl(ctx, a->type, n_dims, ne, a, offset);
  2907. ggml_format_name(result, "%s (view)", a->name);
  2908. ggml_set_op_params(result, &offset, sizeof(offset));
  2909. result->op = GGML_OP_VIEW;
  2910. result->src[0] = a;
  2911. return result;
  2912. }
  2913. // ggml_view_1d
  2914. struct ggml_tensor * ggml_view_1d(
  2915. struct ggml_context * ctx,
  2916. struct ggml_tensor * a,
  2917. int64_t ne0,
  2918. size_t offset) {
  2919. struct ggml_tensor * result = ggml_view_impl(ctx, a, 1, &ne0, offset);
  2920. return result;
  2921. }
  2922. // ggml_view_2d
  2923. struct ggml_tensor * ggml_view_2d(
  2924. struct ggml_context * ctx,
  2925. struct ggml_tensor * a,
  2926. int64_t ne0,
  2927. int64_t ne1,
  2928. size_t nb1,
  2929. size_t offset) {
  2930. const int64_t ne[2] = { ne0, ne1 };
  2931. struct ggml_tensor * result = ggml_view_impl(ctx, a, 2, ne, offset);
  2932. result->nb[1] = nb1;
  2933. result->nb[2] = result->nb[1]*ne1;
  2934. result->nb[3] = result->nb[2];
  2935. return result;
  2936. }
  2937. // ggml_view_3d
  2938. struct ggml_tensor * ggml_view_3d(
  2939. struct ggml_context * ctx,
  2940. struct ggml_tensor * a,
  2941. int64_t ne0,
  2942. int64_t ne1,
  2943. int64_t ne2,
  2944. size_t nb1,
  2945. size_t nb2,
  2946. size_t offset) {
  2947. const int64_t ne[3] = { ne0, ne1, ne2 };
  2948. struct ggml_tensor * result = ggml_view_impl(ctx, a, 3, ne, offset);
  2949. result->nb[1] = nb1;
  2950. result->nb[2] = nb2;
  2951. result->nb[3] = result->nb[2]*ne2;
  2952. return result;
  2953. }
  2954. // ggml_view_4d
  2955. struct ggml_tensor * ggml_view_4d(
  2956. struct ggml_context * ctx,
  2957. struct ggml_tensor * a,
  2958. int64_t ne0,
  2959. int64_t ne1,
  2960. int64_t ne2,
  2961. int64_t ne3,
  2962. size_t nb1,
  2963. size_t nb2,
  2964. size_t nb3,
  2965. size_t offset) {
  2966. const int64_t ne[4] = { ne0, ne1, ne2, ne3 };
  2967. struct ggml_tensor * result = ggml_view_impl(ctx, a, 4, ne, offset);
  2968. result->nb[1] = nb1;
  2969. result->nb[2] = nb2;
  2970. result->nb[3] = nb3;
  2971. return result;
  2972. }
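// Editorial usage sketch (not part of ggml.c): a zero-copy window over a block of rows of a
// 2D tensor. The view reuses the parent's row stride nb[1], and the offset is in bytes.
static struct ggml_tensor * example_view_rows(
        struct ggml_context * ctx,
        struct ggml_tensor  * a,        // [ne0, ne1]
        int64_t               row0,
        int64_t               n_rows) {
    return ggml_view_2d(ctx, a, a->ne[0], n_rows, a->nb[1], row0 * a->nb[1]);
}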
  2973. // ggml_permute
  2974. struct ggml_tensor * ggml_permute(
  2975. struct ggml_context * ctx,
  2976. struct ggml_tensor * a,
  2977. int axis0,
  2978. int axis1,
  2979. int axis2,
  2980. int axis3) {
  2981. GGML_ASSERT(axis0 >= 0 && axis0 < GGML_MAX_DIMS);
  2982. GGML_ASSERT(axis1 >= 0 && axis1 < GGML_MAX_DIMS);
  2983. GGML_ASSERT(axis2 >= 0 && axis2 < GGML_MAX_DIMS);
  2984. GGML_ASSERT(axis3 >= 0 && axis3 < GGML_MAX_DIMS);
  2985. GGML_ASSERT(axis0 != axis1);
  2986. GGML_ASSERT(axis0 != axis2);
  2987. GGML_ASSERT(axis0 != axis3);
  2988. GGML_ASSERT(axis1 != axis2);
  2989. GGML_ASSERT(axis1 != axis3);
  2990. GGML_ASSERT(axis2 != axis3);
  2991. struct ggml_tensor * result = ggml_view_tensor(ctx, a);
  2992. ggml_format_name(result, "%s (permuted)", a->name);
  2993. int ne[GGML_MAX_DIMS];
  2994. int nb[GGML_MAX_DIMS];
  2995. ne[axis0] = a->ne[0];
  2996. ne[axis1] = a->ne[1];
  2997. ne[axis2] = a->ne[2];
  2998. ne[axis3] = a->ne[3];
  2999. nb[axis0] = a->nb[0];
  3000. nb[axis1] = a->nb[1];
  3001. nb[axis2] = a->nb[2];
  3002. nb[axis3] = a->nb[3];
  3003. result->ne[0] = ne[0];
  3004. result->ne[1] = ne[1];
  3005. result->ne[2] = ne[2];
  3006. result->ne[3] = ne[3];
  3007. result->nb[0] = nb[0];
  3008. result->nb[1] = nb[1];
  3009. result->nb[2] = nb[2];
  3010. result->nb[3] = nb[3];
  3011. result->op = GGML_OP_PERMUTE;
  3012. result->src[0] = a;
  3013. int32_t params[] = { axis0, axis1, axis2, axis3 };
  3014. ggml_set_op_params(result, params, sizeof(params));
  3015. return result;
  3016. }
  3017. // ggml_transpose
  3018. struct ggml_tensor * ggml_transpose(
  3019. struct ggml_context * ctx,
  3020. struct ggml_tensor * a) {
  3021. struct ggml_tensor * result = ggml_view_tensor(ctx, a);
  3022. ggml_format_name(result, "%s (transposed)", a->name);
  3023. result->ne[0] = a->ne[1];
  3024. result->ne[1] = a->ne[0];
  3025. result->nb[0] = a->nb[1];
  3026. result->nb[1] = a->nb[0];
  3027. result->op = GGML_OP_TRANSPOSE;
  3028. result->src[0] = a;
  3029. return result;
  3030. }
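// Editorial usage sketch (not part of ggml.c): the head-major layout commonly used before
// attention. ggml_permute only rewrites ne/nb (no data movement), so ggml_cont is added when
// a later op needs contiguous memory. Names and shapes are illustrative.
static struct ggml_tensor * example_heads_permuted(
        struct ggml_context * ctx,
        struct ggml_tensor  * cur,       // [n_embd, n_tokens]
        int64_t               head_dim,
        int64_t               n_head) {
    struct ggml_tensor * t = ggml_reshape_3d(ctx, cur, head_dim, n_head, cur->ne[1]);
    t = ggml_permute(ctx, t, 0, 2, 1, 3); // -> [head_dim, n_tokens, n_head]
    return ggml_cont(ctx, t);             // materialize the permuted layout
}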
  3031. // ggml_get_rows
  3032. struct ggml_tensor * ggml_get_rows(
  3033. struct ggml_context * ctx,
  3034. struct ggml_tensor * a,
  3035. struct ggml_tensor * b) {
  3036. GGML_ASSERT(a->ne[2] == b->ne[1]);
  3037. GGML_ASSERT(a->ne[3] == b->ne[2]);
  3038. GGML_ASSERT(b->ne[3] == 1);
  3039. GGML_ASSERT(b->type == GGML_TYPE_I32);
  3040. // TODO: implement non F32 return
  3041. enum ggml_type type = GGML_TYPE_F32;
  3042. if (a->type == GGML_TYPE_I32) {
  3043. type = a->type;
  3044. }
  3045. struct ggml_tensor * result = ggml_new_tensor_4d(ctx, type, a->ne[0], b->ne[0], b->ne[1], b->ne[2]);
  3046. result->op = GGML_OP_GET_ROWS;
  3047. result->src[0] = a;
  3048. result->src[1] = b;
  3049. return result;
  3050. }
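// Editorial usage sketch (not part of ggml.c): token-embedding lookup with ggml_get_rows.
// `tok_embd` is the embedding table [n_embd, n_vocab] and `inp_tokens` is an I32 vector of
// token ids; the result is [n_embd, n_tokens] in F32. Names are illustrative.
static struct ggml_tensor * example_embed(
        struct ggml_context * ctx,
        struct ggml_tensor  * tok_embd,
        struct ggml_tensor  * inp_tokens) {
    return ggml_get_rows(ctx, tok_embd, inp_tokens);
}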
  3051. // ggml_get_rows_back
  3052. struct ggml_tensor * ggml_get_rows_back(
  3053. struct ggml_context * ctx,
  3054. struct ggml_tensor * a,
  3055. struct ggml_tensor * b,
  3056. struct ggml_tensor * c) {
  3057. GGML_ASSERT(ggml_is_matrix(a) && ggml_is_vector(b) && b->type == GGML_TYPE_I32);
  3058. GGML_ASSERT(ggml_is_matrix(c) && (a->ne[0] == c->ne[0]));
  3059. // TODO: implement non F32 return
  3060. //struct ggml_tensor * result = ggml_new_tensor_2d(ctx, a->type, a->ne[0], b->ne[0]);
  3061. struct ggml_tensor * result = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, c->ne[0], c->ne[1]);
  3062. result->op = GGML_OP_GET_ROWS_BACK;
  3063. result->src[0] = a;
  3064. result->src[1] = b;
  3065. return result;
  3066. }
  3067. // ggml_set_rows
  3068. struct ggml_tensor * ggml_set_rows(
  3069. struct ggml_context * ctx,
  3070. struct ggml_tensor * a,
  3071. struct ggml_tensor * b,
  3072. struct ggml_tensor * c) {
  3073. GGML_ASSERT(a->ne[0] == b->ne[0]);
  3074. GGML_ASSERT(a->ne[2] == b->ne[2]);
  3075. GGML_ASSERT(a->ne[3] == b->ne[3]);
  3076. GGML_ASSERT(b->ne[1] == c->ne[0]);
  3077. GGML_ASSERT(b->ne[2] % c->ne[1] == 0);
  3078. GGML_ASSERT(b->ne[3] % c->ne[2] == 0);
  3079. GGML_ASSERT(c->ne[3] == 1);
  3080. GGML_ASSERT(b->type == GGML_TYPE_F32);
  3081. GGML_ASSERT(c->type == GGML_TYPE_I64 || c->type == GGML_TYPE_I32);
  3082. GGML_ASSERT(ggml_is_contiguous_rows(a));
  3083. GGML_ASSERT(ggml_is_contiguous_rows(b));
  3084. struct ggml_tensor * result = ggml_view_tensor(ctx, a);
  3085. result->op = GGML_OP_SET_ROWS;
  3086. result->src[0] = b;
  3087. result->src[1] = c;
  3088. result->src[2] = a; // note: order is weird due to legacy reasons (https://github.com/ggml-org/llama.cpp/pull/16063#discussion_r2385795931)
  3089. return result;
  3090. }
  3091. // ggml_diag
  3092. struct ggml_tensor * ggml_diag(
  3093. struct ggml_context * ctx,
  3094. struct ggml_tensor * a) {
  3095. GGML_ASSERT(a->ne[1] == 1);
  3096. const int64_t ne[4] = { a->ne[0], a->ne[0], a->ne[2], a->ne[3] };
  3097. struct ggml_tensor * result = ggml_new_tensor(ctx, a->type, 4, ne);
  3098. result->op = GGML_OP_DIAG;
  3099. result->src[0] = a;
  3100. return result;
  3101. }
  3102. // ggml_diag_mask_inf
  3103. static struct ggml_tensor * ggml_diag_mask_inf_impl(
  3104. struct ggml_context * ctx,
  3105. struct ggml_tensor * a,
  3106. int n_past,
  3107. bool inplace) {
  3108. struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
  3109. int32_t params[] = { n_past };
  3110. ggml_set_op_params(result, params, sizeof(params));
  3111. result->op = GGML_OP_DIAG_MASK_INF;
  3112. result->src[0] = a;
  3113. return result;
  3114. }
  3115. struct ggml_tensor * ggml_diag_mask_inf(
  3116. struct ggml_context * ctx,
  3117. struct ggml_tensor * a,
  3118. int n_past) {
  3119. return ggml_diag_mask_inf_impl(ctx, a, n_past, false);
  3120. }
  3121. struct ggml_tensor * ggml_diag_mask_inf_inplace(
  3122. struct ggml_context * ctx,
  3123. struct ggml_tensor * a,
  3124. int n_past) {
  3125. return ggml_diag_mask_inf_impl(ctx, a, n_past, true);
  3126. }
  3127. // ggml_diag_mask_zero
  3128. static struct ggml_tensor * ggml_diag_mask_zero_impl(
  3129. struct ggml_context * ctx,
  3130. struct ggml_tensor * a,
  3131. int n_past,
  3132. bool inplace) {
  3133. struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
  3134. int32_t params[] = { n_past };
  3135. ggml_set_op_params(result, params, sizeof(params));
  3136. result->op = GGML_OP_DIAG_MASK_ZERO;
  3137. result->src[0] = a;
  3138. return result;
  3139. }
  3140. struct ggml_tensor * ggml_diag_mask_zero(
  3141. struct ggml_context * ctx,
  3142. struct ggml_tensor * a,
  3143. int n_past) {
  3144. return ggml_diag_mask_zero_impl(ctx, a, n_past, false);
  3145. }
  3146. struct ggml_tensor * ggml_diag_mask_zero_inplace(
  3147. struct ggml_context * ctx,
  3148. struct ggml_tensor * a,
  3149. int n_past) {
  3150. return ggml_diag_mask_zero_impl(ctx, a, n_past, true);
  3151. }
  3152. // ggml_soft_max
  3153. static struct ggml_tensor * ggml_soft_max_impl(
  3154. struct ggml_context * ctx,
  3155. struct ggml_tensor * a,
  3156. struct ggml_tensor * mask,
  3157. float scale,
  3158. float max_bias,
  3159. bool inplace) {
  3160. GGML_ASSERT(ggml_is_contiguous(a));
  3161. if (mask) {
  3162. GGML_ASSERT(mask->type == GGML_TYPE_F16 || mask->type == GGML_TYPE_F32);
  3163. GGML_ASSERT(ggml_is_contiguous(mask));
  3164. GGML_ASSERT(mask->ne[0] == a->ne[0]);
  3165. GGML_ASSERT(mask->ne[1] >= a->ne[1]);
  3166. GGML_ASSERT(a->ne[2]%mask->ne[2] == 0);
  3167. GGML_ASSERT(a->ne[3]%mask->ne[3] == 0);
  3168. }
  3169. if (max_bias > 0.0f) {
  3170. GGML_ASSERT(mask);
  3171. }
  3172. struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
  3173. float params[] = { scale, max_bias };
  3174. ggml_set_op_params(result, params, sizeof(params));
  3175. result->op = GGML_OP_SOFT_MAX;
  3176. result->src[0] = a;
  3177. result->src[1] = mask;
  3178. return result;
  3179. }
  3180. struct ggml_tensor * ggml_soft_max(
  3181. struct ggml_context * ctx,
  3182. struct ggml_tensor * a) {
  3183. return ggml_soft_max_impl(ctx, a, NULL, 1.0f, 0.0f, false);
  3184. }
  3185. struct ggml_tensor * ggml_soft_max_inplace(
  3186. struct ggml_context * ctx,
  3187. struct ggml_tensor * a) {
  3188. return ggml_soft_max_impl(ctx, a, NULL, 1.0f, 0.0f, true);
  3189. }
  3190. struct ggml_tensor * ggml_soft_max_ext(
  3191. struct ggml_context * ctx,
  3192. struct ggml_tensor * a,
  3193. struct ggml_tensor * mask,
  3194. float scale,
  3195. float max_bias) {
  3196. return ggml_soft_max_impl(ctx, a, mask, scale, max_bias, false);
  3197. }
  3198. struct ggml_tensor * ggml_soft_max_ext_inplace(
  3199. struct ggml_context * ctx,
  3200. struct ggml_tensor * a,
  3201. struct ggml_tensor * mask,
  3202. float scale,
  3203. float max_bias) {
  3204. return ggml_soft_max_impl(ctx, a, mask, scale, max_bias, true);
  3205. }
  3206. void ggml_soft_max_add_sinks(
  3207. struct ggml_tensor * a,
  3208. struct ggml_tensor * sinks) {
  3209. if (!sinks) {
  3210. a->src[2] = NULL;
  3211. return;
  3212. }
  3213. GGML_ASSERT(a->op == GGML_OP_SOFT_MAX);
  3214. GGML_ASSERT(a->src[2] == NULL);
  3215. GGML_ASSERT(a->src[0]->ne[2] == sinks->ne[0]);
  3216. GGML_ASSERT(sinks->type == GGML_TYPE_F32);
  3217. a->src[2] = sinks;
  3218. }
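// Editorial usage sketch (not part of ggml.c): fused scaled + masked softmax over attention
// scores. `kq` holds Q.K^T scores with n_kv along ne[0]; `mask` (F16 or F32) is broadcast
// over heads, the scale is typically 1/sqrt(head_dim), and a non-zero max_bias enables ALiBi.
// Assumes <math.h> for sqrtf; names are illustrative.
static struct ggml_tensor * example_attn_softmax(
        struct ggml_context * ctx,
        struct ggml_tensor  * kq,
        struct ggml_tensor  * mask,
        int                   head_dim) {
    return ggml_soft_max_ext(ctx, kq, mask, 1.0f/sqrtf((float) head_dim), 0.0f);
}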
  3219. // ggml_soft_max_ext_back
  3220. static struct ggml_tensor * ggml_soft_max_ext_back_impl(
  3221. struct ggml_context * ctx,
  3222. struct ggml_tensor * a,
  3223. struct ggml_tensor * b,
  3224. float scale,
  3225. float max_bias,
  3226. bool inplace) {
  3227. struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
  3228. result->op = GGML_OP_SOFT_MAX_BACK;
  3229. result->src[0] = a;
  3230. result->src[1] = b;
  3231. memcpy((float *) result->op_params + 0, &scale, sizeof(float));
  3232. memcpy((float *) result->op_params + 1, &max_bias, sizeof(float));
  3233. return result;
  3234. }
  3235. struct ggml_tensor * ggml_soft_max_ext_back(
  3236. struct ggml_context * ctx,
  3237. struct ggml_tensor * a,
  3238. struct ggml_tensor * b,
  3239. float scale,
  3240. float max_bias) {
  3241. return ggml_soft_max_ext_back_impl(ctx, a, b, scale, max_bias, false);
  3242. }
  3243. struct ggml_tensor * ggml_soft_max_ext_back_inplace(
  3244. struct ggml_context * ctx,
  3245. struct ggml_tensor * a,
  3246. struct ggml_tensor * b,
  3247. float scale,
  3248. float max_bias) {
  3249. return ggml_soft_max_ext_back_impl(ctx, a, b, scale, max_bias, true);
  3250. }
  3251. // ggml_rope
  3252. static struct ggml_tensor * ggml_rope_impl(
  3253. struct ggml_context * ctx,
  3254. struct ggml_tensor * a,
  3255. struct ggml_tensor * b,
  3256. struct ggml_tensor * c,
  3257. int n_dims,
  3258. int sections[GGML_MROPE_SECTIONS],
  3259. int mode,
  3260. int n_ctx_orig,
  3261. float freq_base,
  3262. float freq_scale,
  3263. float ext_factor,
  3264. float attn_factor,
  3265. float beta_fast,
  3266. float beta_slow,
  3267. bool inplace) {
  3268. GGML_ASSERT((mode & 1) == 0 && "mode & 1 == 1 is no longer supported");
  3269. GGML_ASSERT(ggml_is_vector(b));
  3270. GGML_ASSERT(b->type == GGML_TYPE_I32);
  3271. bool mrope_used = mode & GGML_ROPE_TYPE_MROPE;
  3272. if (mrope_used) {
3273. GGML_ASSERT(a->ne[2] * 4 == b->ne[0]); // mrope expects 4 position ids per token
  3274. } else {
  3275. GGML_ASSERT(a->ne[2] == b->ne[0]);
  3276. }
  3277. if (c) {
  3278. GGML_ASSERT(c->type == GGML_TYPE_F32);
  3279. GGML_ASSERT(c->ne[0] >= n_dims / 2);
  3280. }
  3281. struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
  3282. int32_t params[15] = { /*n_past*/ 0, n_dims, mode, /*n_ctx*/ 0, n_ctx_orig };
  3283. memcpy(params + 5, &freq_base, sizeof(float));
  3284. memcpy(params + 6, &freq_scale, sizeof(float));
  3285. memcpy(params + 7, &ext_factor, sizeof(float));
  3286. memcpy(params + 8, &attn_factor, sizeof(float));
  3287. memcpy(params + 9, &beta_fast, sizeof(float));
  3288. memcpy(params + 10, &beta_slow, sizeof(float));
  3289. if (mrope_used && sections) {
  3290. memcpy(params + 11, sections, sizeof(int32_t) * GGML_MROPE_SECTIONS);
  3291. } else {
  3292. memset(params + 11, 0, sizeof(int32_t) * GGML_MROPE_SECTIONS);
  3293. }
  3294. ggml_set_op_params(result, params, sizeof(params));
  3295. result->op = GGML_OP_ROPE;
  3296. result->src[0] = a;
  3297. result->src[1] = b;
  3298. result->src[2] = c;
  3299. return result;
  3300. }
  3301. struct ggml_tensor * ggml_rope(
  3302. struct ggml_context * ctx,
  3303. struct ggml_tensor * a,
  3304. struct ggml_tensor * b,
  3305. int n_dims,
  3306. int mode) {
  3307. return ggml_rope_impl(
  3308. ctx, a, b, NULL, n_dims, NULL, mode, 0, 10000.0f, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f, false
  3309. );
  3310. }
  3311. struct ggml_tensor * ggml_rope_multi(
  3312. struct ggml_context * ctx,
  3313. struct ggml_tensor * a,
  3314. struct ggml_tensor * b,
  3315. struct ggml_tensor * c,
  3316. int n_dims,
  3317. int sections[GGML_MROPE_SECTIONS],
  3318. int mode,
  3319. int n_ctx_orig,
  3320. float freq_base,
  3321. float freq_scale,
  3322. float ext_factor,
  3323. float attn_factor,
  3324. float beta_fast,
  3325. float beta_slow) {
  3326. return ggml_rope_impl(
  3327. ctx, a, b, c, n_dims, sections, mode, n_ctx_orig, freq_base, freq_scale,
  3328. ext_factor, attn_factor, beta_fast, beta_slow, false
  3329. );
  3330. }
  3331. struct ggml_tensor * ggml_rope_multi_inplace(
  3332. struct ggml_context * ctx,
  3333. struct ggml_tensor * a,
  3334. struct ggml_tensor * b,
  3335. struct ggml_tensor * c,
  3336. int n_dims,
  3337. int sections[GGML_MROPE_SECTIONS],
  3338. int mode,
  3339. int n_ctx_orig,
  3340. float freq_base,
  3341. float freq_scale,
  3342. float ext_factor,
  3343. float attn_factor,
  3344. float beta_fast,
  3345. float beta_slow) {
  3346. return ggml_rope_impl(
  3347. ctx, a, b, c, n_dims, sections, mode, n_ctx_orig, freq_base, freq_scale,
  3348. ext_factor, attn_factor, beta_fast, beta_slow, true
  3349. );
  3350. }
  3351. struct ggml_tensor * ggml_rope_inplace(
  3352. struct ggml_context * ctx,
  3353. struct ggml_tensor * a,
  3354. struct ggml_tensor * b,
  3355. int n_dims,
  3356. int mode) {
  3357. return ggml_rope_impl(
  3358. ctx, a, b, NULL, n_dims, NULL, mode, 0, 10000.0f, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f, true
  3359. );
  3360. }
  3361. struct ggml_tensor * ggml_rope_ext(
  3362. struct ggml_context * ctx,
  3363. struct ggml_tensor * a,
  3364. struct ggml_tensor * b,
  3365. struct ggml_tensor * c,
  3366. int n_dims,
  3367. int mode,
  3368. int n_ctx_orig,
  3369. float freq_base,
  3370. float freq_scale,
  3371. float ext_factor,
  3372. float attn_factor,
  3373. float beta_fast,
  3374. float beta_slow) {
  3375. return ggml_rope_impl(
  3376. ctx, a, b, c, n_dims, NULL, mode, n_ctx_orig, freq_base, freq_scale,
  3377. ext_factor, attn_factor, beta_fast, beta_slow, false
  3378. );
  3379. }
  3380. struct ggml_tensor * ggml_rope_ext_inplace(
  3381. struct ggml_context * ctx,
  3382. struct ggml_tensor * a,
  3383. struct ggml_tensor * b,
  3384. struct ggml_tensor * c,
  3385. int n_dims,
  3386. int mode,
  3387. int n_ctx_orig,
  3388. float freq_base,
  3389. float freq_scale,
  3390. float ext_factor,
  3391. float attn_factor,
  3392. float beta_fast,
  3393. float beta_slow) {
  3394. return ggml_rope_impl(
  3395. ctx, a, b, c, n_dims, NULL, mode, n_ctx_orig, freq_base, freq_scale,
  3396. ext_factor, attn_factor, beta_fast, beta_slow, true
  3397. );
  3398. }
  3399. struct ggml_tensor * ggml_rope_custom(
  3400. struct ggml_context * ctx,
  3401. struct ggml_tensor * a,
  3402. struct ggml_tensor * b,
  3403. int n_dims,
  3404. int mode,
  3405. int n_ctx_orig,
  3406. float freq_base,
  3407. float freq_scale,
  3408. float ext_factor,
  3409. float attn_factor,
  3410. float beta_fast,
  3411. float beta_slow) {
  3412. return ggml_rope_impl(
  3413. ctx, a, b, NULL, n_dims, NULL, mode, n_ctx_orig, freq_base, freq_scale,
  3414. ext_factor, attn_factor, beta_fast, beta_slow, false
  3415. );
  3416. }
  3417. struct ggml_tensor * ggml_rope_custom_inplace(
  3418. struct ggml_context * ctx,
  3419. struct ggml_tensor * a,
  3420. struct ggml_tensor * b,
  3421. int n_dims,
  3422. int mode,
  3423. int n_ctx_orig,
  3424. float freq_base,
  3425. float freq_scale,
  3426. float ext_factor,
  3427. float attn_factor,
  3428. float beta_fast,
  3429. float beta_slow) {
  3430. return ggml_rope_impl(
  3431. ctx, a, b, NULL, n_dims, NULL, mode, n_ctx_orig, freq_base, freq_scale,
  3432. ext_factor, attn_factor, beta_fast, beta_slow, true
  3433. );
  3434. }
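// Editorial usage sketch (not part of ggml.c): applying rotary position embeddings to
// per-head queries with ggml_rope_ext. `q` is [head_dim, n_head, n_tokens], `pos` is an I32
// vector with one position per token, and mode 0 is the original (non-NeoX) rotation. The
// frequency/YaRN parameters shown are the same defaults used by ggml_rope above, not a
// recommendation.
static struct ggml_tensor * example_rope_q(
        struct ggml_context * ctx,
        struct ggml_tensor  * q,
        struct ggml_tensor  * pos,
        int                   n_rot) {
    return ggml_rope_ext(ctx, q, pos, NULL, n_rot, 0 /*mode*/, 0 /*n_ctx_orig*/,
                         10000.0f, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f);
}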
  3435. // Apparently solving `n_rot = 2pi * x * base^((2 * max_pos_emb) / n_dims)` for x, we get
  3436. // `corr_dim(n_rot) = n_dims * log(max_pos_emb / (n_rot * 2pi)) / (2 * log(base))`
  3437. static float ggml_rope_yarn_corr_dim(int n_dims, int n_ctx_orig, float n_rot, float base) {
  3438. return n_dims * logf(n_ctx_orig / (n_rot * 2 * (float)M_PI)) / (2 * logf(base));
  3439. }
  3440. void ggml_rope_yarn_corr_dims(
  3441. int n_dims, int n_ctx_orig, float freq_base, float beta_fast, float beta_slow, float dims[2]
  3442. ) {
  3443. // start and end correction dims
  3444. float start = floorf(ggml_rope_yarn_corr_dim(n_dims, n_ctx_orig, beta_fast, freq_base));
  3445. float end = ceilf(ggml_rope_yarn_corr_dim(n_dims, n_ctx_orig, beta_slow, freq_base));
  3446. dims[0] = MAX(0, start);
  3447. dims[1] = MIN(n_dims - 1, end);
  3448. }
  3449. // ggml_rope_back
  3450. struct ggml_tensor * ggml_rope_ext_back(
  3451. struct ggml_context * ctx,
  3452. struct ggml_tensor * a,
  3453. struct ggml_tensor * b,
  3454. struct ggml_tensor * c,
  3455. int n_dims,
  3456. int mode,
  3457. int n_ctx_orig,
  3458. float freq_base,
  3459. float freq_scale,
  3460. float ext_factor,
  3461. float attn_factor,
  3462. float beta_fast,
  3463. float beta_slow) {
  3464. struct ggml_tensor * result = ggml_rope_ext(
  3465. ctx, a, b, c, n_dims, mode, n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow);
  3466. result->op = GGML_OP_ROPE_BACK;
  3467. return result;
  3468. }
  3469. struct ggml_tensor * ggml_rope_multi_back(
  3470. struct ggml_context * ctx,
  3471. struct ggml_tensor * a,
  3472. struct ggml_tensor * b,
  3473. struct ggml_tensor * c,
  3474. int n_dims,
  3475. int sections[4],
  3476. int mode,
  3477. int n_ctx_orig,
  3478. float freq_base,
  3479. float freq_scale,
  3480. float ext_factor,
  3481. float attn_factor,
  3482. float beta_fast,
  3483. float beta_slow) {
  3484. struct ggml_tensor * result = ggml_rope_multi(
  3485. ctx, a, b, c, n_dims, sections, mode, n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow);
  3486. result->op = GGML_OP_ROPE_BACK;
  3487. return result;
  3488. }
  3489. // ggml_clamp
  3490. struct ggml_tensor * ggml_clamp(
  3491. struct ggml_context * ctx,
  3492. struct ggml_tensor * a,
  3493. float min,
  3494. float max) {
3495. // TODO: when implementing backward, fix this:
  3496. struct ggml_tensor * result = ggml_view_tensor(ctx, a);
  3497. float params[] = { min, max };
  3498. ggml_set_op_params(result, params, sizeof(params));
  3499. result->op = GGML_OP_CLAMP;
  3500. result->src[0] = a;
  3501. return result;
  3502. }
  3503. static int64_t ggml_calc_conv_output_size(int64_t ins, int64_t ks, int s, int p, int d) {
  3504. return (ins + 2 * p - d * (ks - 1) - 1) / s + 1;
  3505. }
  3506. // im2col: [N, IC, IH, IW] => [N, OH, OW, IC*KH*KW]
  3507. // a: [OC,IC, KH, KW]
  3508. // b: [N, IC, IH, IW]
  3509. // result: [N, OH, OW, IC*KH*KW]
  3510. struct ggml_tensor * ggml_im2col(
  3511. struct ggml_context * ctx,
  3512. struct ggml_tensor * a,
  3513. struct ggml_tensor * b,
  3514. int s0,
  3515. int s1,
  3516. int p0,
  3517. int p1,
  3518. int d0,
  3519. int d1,
  3520. bool is_2D,
  3521. enum ggml_type dst_type) {
  3522. if (is_2D) {
  3523. GGML_ASSERT(a->ne[2] == b->ne[2]);
  3524. } else {
  3525. //GGML_ASSERT(b->ne[1] % a->ne[1] == 0);
  3526. GGML_ASSERT(b->ne[1] == a->ne[1]);
  3527. GGML_ASSERT(b->ne[3] == 1);
  3528. }
  3529. const int64_t OH = is_2D ? ggml_calc_conv_output_size(b->ne[1], a->ne[1], s1, p1, d1) : 0;
  3530. const int64_t OW = ggml_calc_conv_output_size(b->ne[0], a->ne[0], s0, p0, d0);
  3531. GGML_ASSERT((!is_2D || OH > 0) && "b too small compared to a");
  3532. GGML_ASSERT((OW > 0) && "b too small compared to a");
  3533. const int64_t ne[4] = {
  3534. is_2D ? (a->ne[2] * a->ne[1] * a->ne[0]) : a->ne[1] * a->ne[0],
  3535. OW,
  3536. is_2D ? OH : b->ne[2],
  3537. is_2D ? b->ne[3] : 1,
  3538. };
  3539. struct ggml_tensor * result = ggml_new_tensor(ctx, dst_type, 4, ne);
  3540. int32_t params[] = { s0, s1, p0, p1, d0, d1, (is_2D ? 1 : 0) };
  3541. ggml_set_op_params(result, params, sizeof(params));
  3542. result->op = GGML_OP_IM2COL;
  3543. result->src[0] = a;
  3544. result->src[1] = b;
  3545. return result;
  3546. }
  3547. struct ggml_tensor * ggml_im2col_back(
  3548. struct ggml_context * ctx,
  3549. struct ggml_tensor * a,
  3550. struct ggml_tensor * b,
  3551. int64_t * ne,
  3552. int s0,
  3553. int s1,
  3554. int p0,
  3555. int p1,
  3556. int d0,
  3557. int d1,
  3558. bool is_2D) {
  3559. struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
  3560. int32_t params[] = { s0, s1, p0, p1, d0, d1, (is_2D ? 1 : 0) };
  3561. ggml_set_op_params(result, params, sizeof(params));
  3562. result->op = GGML_OP_IM2COL_BACK;
  3563. result->src[0] = a;
  3564. result->src[1] = b;
  3565. return result;
  3566. }
  3567. // ggml_conv_1d
  3568. struct ggml_tensor * ggml_conv_1d(
  3569. struct ggml_context * ctx,
  3570. struct ggml_tensor * a,
  3571. struct ggml_tensor * b,
  3572. int s0,
  3573. int p0,
  3574. int d0) {
  3575. struct ggml_tensor * im2col = ggml_im2col(ctx, a, b, s0, 0, p0, 0, d0, 0, false, GGML_TYPE_F16); // [N, OL, IC * K]
  3576. struct ggml_tensor * result =
  3577. ggml_mul_mat(ctx,
  3578. ggml_reshape_2d(ctx, im2col, im2col->ne[0], (im2col->ne[2] * im2col->ne[1])), // [N, OL, IC * K] => [N*OL, IC * K]
  3579. ggml_reshape_2d(ctx, a, (a->ne[0] * a->ne[1]), a->ne[2])); // [OC,IC, K] => [OC, IC * K]
  3580. result = ggml_reshape_3d(ctx, result, im2col->ne[1], a->ne[2], im2col->ne[2]); // [N, OC, OL]
  3581. return result;
  3582. }
  3583. // ggml_conv_1d_ph
3584. struct ggml_tensor * ggml_conv_1d_ph(
  3585. struct ggml_context * ctx,
  3586. struct ggml_tensor * a,
  3587. struct ggml_tensor * b,
  3588. int s,
  3589. int d) {
  3590. return ggml_conv_1d(ctx, a, b, s, a->ne[0] / 2, d);
  3591. }
  3592. // ggml_conv_1d_dw
  3593. struct ggml_tensor * ggml_conv_1d_dw(
  3594. struct ggml_context * ctx,
  3595. struct ggml_tensor * a,
  3596. struct ggml_tensor * b,
  3597. int s0,
  3598. int p0,
  3599. int d0) {
  3600. struct ggml_tensor * new_b = ggml_reshape_4d(ctx, b, b->ne[0], 1, b->ne[1], b->ne[2]);
  3601. struct ggml_tensor * im2col = ggml_im2col(ctx, a, new_b, s0, 0, p0, 0, d0, 0, false, GGML_TYPE_F16);
  3602. struct ggml_tensor * result = ggml_mul_mat(ctx, im2col, a);
  3603. result = ggml_reshape_3d(ctx, result, result->ne[0], result->ne[2], 1);
  3604. return result;
  3605. }
  3606. // ggml_conv_1d_dw_ph
  3607. struct ggml_tensor * ggml_conv_1d_dw_ph(
  3608. struct ggml_context * ctx,
  3609. struct ggml_tensor * a,
  3610. struct ggml_tensor * b,
  3611. int s0,
  3612. int d0) {
  3613. return ggml_conv_1d_dw(ctx, a, b, s0, a->ne[0] / 2, d0);
  3614. }
  3615. // ggml_conv_transpose_1d
  3616. static int64_t ggml_calc_conv_transpose_1d_output_size(int64_t ins, int64_t ks, int s, int p, int d) {
  3617. return (ins - 1) * s - 2 * p + d * (ks - 1) + 1;
  3618. }
  3619. GGML_API struct ggml_tensor * ggml_conv_transpose_1d(
  3620. struct ggml_context * ctx,
  3621. struct ggml_tensor * a,
  3622. struct ggml_tensor * b,
  3623. int s0,
  3624. int p0,
  3625. int d0) {
  3626. GGML_ASSERT(ggml_is_matrix(b));
  3627. GGML_ASSERT(a->ne[2] == b->ne[1]);
  3628. GGML_ASSERT(a->ne[3] == 1);
  3629. GGML_ASSERT(p0 == 0);
  3630. GGML_ASSERT(d0 == 1);
  3631. const int64_t ne[4] = {
  3632. ggml_calc_conv_transpose_1d_output_size(b->ne[0], a->ne[0], s0, 0 /*p0*/, 1 /*d0*/),
  3633. a->ne[1], b->ne[2], 1,
  3634. };
  3635. struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
  3636. int32_t params[] = { s0, p0, d0 };
  3637. ggml_set_op_params(result, params, sizeof(params));
  3638. result->op = GGML_OP_CONV_TRANSPOSE_1D;
  3639. result->src[0] = a;
  3640. result->src[1] = b;
  3641. return result;
  3642. }
  3643. // ggml_conv_2d
  3644. // a: [OC,IC, KH, KW]
  3645. // b: [N, IC, IH, IW]
  3646. // result: [N, OC, OH, OW]
  3647. struct ggml_tensor * ggml_conv_2d(
  3648. struct ggml_context * ctx,
  3649. struct ggml_tensor * a,
  3650. struct ggml_tensor * b,
  3651. int s0,
  3652. int s1,
  3653. int p0,
  3654. int p1,
  3655. int d0,
  3656. int d1) {
  3657. struct ggml_tensor * im2col = ggml_im2col(ctx, a, b, s0, s1, p0, p1, d0, d1, true, a->type); // [N, OH, OW, IC * KH * KW]
  3658. struct ggml_tensor * result =
  3659. ggml_mul_mat(ctx,
  3660. ggml_reshape_2d(ctx, im2col, im2col->ne[0], im2col->ne[3] * im2col->ne[2] * im2col->ne[1]), // [N, OH, OW, IC * KH * KW] => [N*OH*OW, IC * KH * KW]
  3661. ggml_reshape_2d(ctx, a, (a->ne[0] * a->ne[1] * a->ne[2]), a->ne[3])); // [OC,IC, KH, KW] => [OC, IC * KH * KW]
  3662. result = ggml_reshape_4d(ctx, result, im2col->ne[1], im2col->ne[2], im2col->ne[3], a->ne[3]); // [OC, N, OH, OW]
  3663. result = ggml_cont(ctx, ggml_permute(ctx, result, 0, 1, 3, 2)); // [N, OC, OH, OW]
  3664. return result;
  3665. }
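// Editorial usage sketch (not part of ggml.c): a "same"-padded 3x3 convolution through the
// im2col + mul_mat path above. In the comment convention used here the kernel `a` is
// [OC, IC, 3, 3] and the input `b` is [N, IC, H, W]; stride 1, padding 1 and dilation 1 keep
// the spatial size, giving [N, OC, H, W].
static struct ggml_tensor * example_conv3x3(
        struct ggml_context * ctx,
        struct ggml_tensor  * kernel,
        struct ggml_tensor  * input) {
    return ggml_conv_2d(ctx, kernel, input, 1, 1, 1, 1, 1, 1);
}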
  3666. // a: [OC*IC, KD, KH, KW]
  3667. // b: [N*IC, ID, IH, IW]
  3668. // result: [N*OD, OH, OW, IC * KD * KH * KW]
  3669. struct ggml_tensor * ggml_im2col_3d(
  3670. struct ggml_context * ctx,
  3671. struct ggml_tensor * a,
  3672. struct ggml_tensor * b,
  3673. int64_t IC,
  3674. int s0, // stride width
  3675. int s1, // stride height
  3676. int s2, // stride depth
  3677. int p0, // padding width
  3678. int p1, // padding height
  3679. int p2, // padding depth
  3680. int d0, // dilation width
  3681. int d1, // dilation height
  3682. int d2, // dilation depth
  3683. enum ggml_type dst_type) {
  3684. const int64_t N = b->ne[3] / IC;
  3685. const int64_t ID = b->ne[2];
  3686. const int64_t IH = b->ne[1];
  3687. const int64_t IW = b->ne[0];
  3688. const int64_t OC = a->ne[3] / IC;
  3689. UNUSED(OC);
  3690. const int64_t KD = a->ne[2];
  3691. const int64_t KH = a->ne[1];
  3692. const int64_t KW = a->ne[0];
  3693. const int64_t OD = ggml_calc_conv_output_size(ID, KD, s2, p2, d2);
  3694. const int64_t OH = ggml_calc_conv_output_size(IH, KH, s1, p1, d1);
  3695. const int64_t OW = ggml_calc_conv_output_size(IW, KW, s0, p0, d0);
  3696. GGML_ASSERT((OD > 0) && "b too small compared to a");
  3697. GGML_ASSERT((OH > 0) && "b too small compared to a");
  3698. GGML_ASSERT((OW > 0) && "b too small compared to a");
  3699. const int64_t ne[4] = {KW*KH*KD*IC, OW, OH, OD*N};
  3700. struct ggml_tensor * result = ggml_new_tensor(ctx, dst_type, 4, ne);
  3701. int32_t params[] = { s0, s1, s2, p0, p1, p2, d0, d1, d2, (int32_t)IC};
  3702. ggml_set_op_params(result, params, sizeof(params));
  3703. result->op = GGML_OP_IM2COL_3D;
  3704. result->src[0] = a;
  3705. result->src[1] = b;
  3706. return result;
  3707. }
  3708. // a: [OC*IC, KD, KH, KW]
  3709. // b: [N*IC, ID, IH, IW]
  3710. // result: [N*OC, OD, OH, OW]
  3711. struct ggml_tensor * ggml_conv_3d(
  3712. struct ggml_context * ctx,
  3713. struct ggml_tensor * a,
  3714. struct ggml_tensor * b,
  3715. int64_t IC,
  3716. int s0, // stride width
  3717. int s1, // stride height
  3718. int s2, // stride depth
  3719. int p0, // padding width
  3720. int p1, // padding height
  3721. int p2, // padding depth
  3722. int d0, // dilation width
  3723. int d1, // dilation height
  3724. int d2 // dilation depth
  3725. ) {
  3726. struct ggml_tensor * im2col = ggml_im2col_3d(ctx, a, b, IC, s0, s1, s2, p0, p1, p2, d0, d1, d2, a->type); // [N*OD, OH, OW, IC * KD * KH * KW]
  3727. int64_t OC = a->ne[3] / IC;
  3728. int64_t N = b->ne[3] / IC;
  3729. struct ggml_tensor * result =
  3730. ggml_mul_mat(ctx,
  3731. ggml_reshape_2d(ctx, im2col, im2col->ne[0], im2col->ne[3] * im2col->ne[2] * im2col->ne[1]), // [N*OD, OH, OW, IC * KD * KH * KW] => [N*OD*OH*OW, IC * KD * KH * KW]
  3732. ggml_reshape_2d(ctx, a, (a->ne[0] * a->ne[1] * a->ne[2] * IC), OC)); // [OC*IC, KD, KH, KW] => [OC, IC * KD * KH * KW]
  3733. int64_t OD = im2col->ne[3] / N;
  3734. result = ggml_reshape_4d(ctx, result, im2col->ne[1]*im2col->ne[2], OD, N, OC); // [OC, N*OD*OH*OW] => [OC, N, OD, OH*OW]
  3735. result = ggml_cont(ctx, ggml_permute(ctx, result, 0, 1, 3, 2)); // [N, OC, OD, OH*OW]
  3736. result = ggml_reshape_4d(ctx, result, im2col->ne[1], im2col->ne[2], OD, OC * N); // [N*OC, OD, OH, OW]
  3737. return result;
  3738. }
  3739. // ggml_conv_2d_sk_p0
  3740. struct ggml_tensor * ggml_conv_2d_sk_p0(
  3741. struct ggml_context * ctx,
  3742. struct ggml_tensor * a,
  3743. struct ggml_tensor * b) {
  3744. return ggml_conv_2d(ctx, a, b, a->ne[0], a->ne[1], 0, 0, 1, 1);
  3745. }
  3746. // ggml_conv_2d_s1_ph
  3747. struct ggml_tensor * ggml_conv_2d_s1_ph(
  3748. struct ggml_context * ctx,
  3749. struct ggml_tensor * a,
  3750. struct ggml_tensor * b) {
  3751. return ggml_conv_2d(ctx, a, b, 1, 1, a->ne[0] / 2, a->ne[1] / 2, 1, 1);
  3752. }
  3753. // ggml_conv_2d_dw
  3754. struct ggml_tensor * ggml_conv_2d_dw(
  3755. struct ggml_context * ctx,
  3756. struct ggml_tensor * a,
  3757. struct ggml_tensor * b,
  3758. int s0,
  3759. int s1,
  3760. int p0,
  3761. int p1,
  3762. int d0,
  3763. int d1) {
  3764. struct ggml_tensor * new_a = ggml_reshape_4d(ctx, a, a->ne[0], a->ne[1], 1, a->ne[2] * a->ne[3]);
  3765. struct ggml_tensor * im2col = ggml_im2col(ctx, new_a,
  3766. ggml_reshape_4d(ctx, b, b->ne[0], b->ne[1], 1, b->ne[2] * b->ne[3]),
  3767. s0, s1, p0, p1, d0, d1, true, GGML_TYPE_F16); // [N * IC, OH, OW, KH * KW]
  3768. struct ggml_tensor * new_b = ggml_reshape_4d(ctx, im2col, im2col->ne[0], im2col->ne[2] * im2col->ne[1], b->ne[2], b->ne[3]); // [N * IC, OH, OW, KH * KW] => [N, IC, OH * OW, KH * KW]
  3769. new_a = ggml_reshape_4d(ctx, new_a, (new_a->ne[0] * new_a->ne[1]), new_a->ne[2], new_a->ne[3], 1); // [OC,1, KH, KW] => [1, OC, 1, KH * KW]
  3770. struct ggml_tensor * result = ggml_mul_mat(ctx, new_a, new_b);
  3771. result = ggml_reshape_4d(ctx, result, im2col->ne[1], im2col->ne[2], b->ne[2], b->ne[3]); // [N, OC, OH, OW]
  3772. return result;
  3773. }
  3774. // ggml_conv_2d_dw_direct
  3775. struct ggml_tensor * ggml_conv_2d_dw_direct(
  3776. struct ggml_context * ctx,
  3777. struct ggml_tensor * a,
  3778. struct ggml_tensor * b,
  3779. int stride0,
  3780. int stride1,
  3781. int pad0,
  3782. int pad1,
  3783. int dilation0,
  3784. int dilation1) {
  3785. GGML_ASSERT(a->ne[2] == 1);
  3786. GGML_ASSERT(a->ne[3] == b->ne[2]);
  3787. int64_t ne[4];
  3788. ne[0] = ggml_calc_conv_output_size(b->ne[0], a->ne[0], stride0, pad0, dilation0);
  3789. ne[1] = ggml_calc_conv_output_size(b->ne[1], a->ne[1], stride1, pad1, dilation1);
  3790. ne[2] = b->ne[2];
  3791. ne[3] = b->ne[3];
  3792. struct ggml_tensor * result = ggml_new_tensor(ctx, b->type, 4, ne);
  3793. if (ggml_is_contiguous_channels(b)) {
  3794. // Result will be permuted the same way as input (CWHN order)
  3795. const int64_t type_size = ggml_type_size(result->type);
  3796. GGML_ASSERT(ggml_blck_size(result->type) == 1);
  3797. result->nb[0] = result->ne[2] * type_size;
  3798. result->nb[1] = result->ne[0] * result->nb[0];
  3799. result->nb[2] = type_size;
  3800. }
  3801. int32_t params[] = { stride0, stride1, pad0, pad1, dilation0, dilation1 };
  3802. ggml_set_op_params(result, params, sizeof(params));
  3803. result->op = GGML_OP_CONV_2D_DW;
  3804. result->src[0] = a;
  3805. result->src[1] = b;
  3806. return result;
  3807. }
  3808. // ggml_conv_2d_direct
  3809. struct ggml_tensor * ggml_conv_2d_direct(
  3810. struct ggml_context * ctx,
  3811. struct ggml_tensor * a, // convolution kernel [KW, KH, IC, OC]
  3812. struct ggml_tensor * b, // input data [W, H, C, N]
  3813. int s0, // stride dimension 0
  3814. int s1, // stride dimension 1
  3815. int p0, // padding dimension 0
  3816. int p1, // padding dimension 1
  3817. int d0, // dilation dimension 0
  3818. int d1) {// dilation dimension 1
  3819. GGML_ASSERT(a->ne[2] == b->ne[2]);
  3820. //GGML_ASSERT(a->type == b->type);
  3821. int64_t ne[4];
  3822. ne[0] = ggml_calc_conv_output_size(b->ne[0], a->ne[0], s0, p0, d0);
  3823. ne[1] = ggml_calc_conv_output_size(b->ne[1], a->ne[1], s1, p1, d1);
  3824. ne[2] = a->ne[3];
  3825. ne[3] = b->ne[3];
  3826. struct ggml_tensor * result = ggml_new_tensor(ctx, b->type, 4, ne);
  3827. ggml_set_op_params_i32(result, 0, s0);
  3828. ggml_set_op_params_i32(result, 1, s1);
  3829. ggml_set_op_params_i32(result, 2, p0);
  3830. ggml_set_op_params_i32(result, 3, p1);
  3831. ggml_set_op_params_i32(result, 4, d0);
  3832. ggml_set_op_params_i32(result, 5, d1);
  3833. result->op = GGML_OP_CONV_2D;
  3834. result->src[0] = a;
  3835. result->src[1] = b;
  3836. return result;
  3837. }
  3838. // ggml_conv_3d_direct
  3839. struct ggml_tensor * ggml_conv_3d_direct(
  3840. struct ggml_context * ctx,
  3841. struct ggml_tensor * a,
  3842. struct ggml_tensor * b,
  3843. int s0,
  3844. int s1,
  3845. int s2,
  3846. int p0,
  3847. int p1,
  3848. int p2,
  3849. int d0,
  3850. int d1,
  3851. int d2,
  3852. int c,
  3853. int n,
  3854. int oc) {
  3855. GGML_ASSERT(a->ne[3] == (int64_t) c * oc);
  3856. GGML_ASSERT(b->ne[3] == (int64_t) c * n);
  3857. int64_t ne[4];
  3858. ne[0] = ggml_calc_conv_output_size(b->ne[0], a->ne[0], s0, p0, d0);
  3859. ne[1] = ggml_calc_conv_output_size(b->ne[1], a->ne[1], s1, p1, d1);
  3860. ne[2] = ggml_calc_conv_output_size(b->ne[2], a->ne[2], s2, p2, d2);
  3861. ne[3] = (int64_t) oc * n;
  3862. struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
  3863. ggml_set_op_params_i32(result, 0, s0);
  3864. ggml_set_op_params_i32(result, 1, s1);
  3865. ggml_set_op_params_i32(result, 2, s2);
  3866. ggml_set_op_params_i32(result, 3, p0);
  3867. ggml_set_op_params_i32(result, 4, p1);
  3868. ggml_set_op_params_i32(result, 5, p2);
  3869. ggml_set_op_params_i32(result, 6, d0);
  3870. ggml_set_op_params_i32(result, 7, d1);
  3871. ggml_set_op_params_i32(result, 8, d2);
  3872. ggml_set_op_params_i32(result, 9, c);
  3873. ggml_set_op_params_i32(result, 10, n);
  3874. ggml_set_op_params_i32(result, 11, oc);
  3875. result->op = GGML_OP_CONV_3D;
  3876. result->src[0] = a;
  3877. result->src[1] = b;
  3878. return result;
  3879. }
  3880. // ggml_conv_transpose_2d_p0
  3881. static int64_t ggml_calc_conv_transpose_output_size(int64_t ins, int64_t ks, int s, int p) {
  3882. return (ins - 1) * s - 2 * p + ks;
  3883. }
  3884. struct ggml_tensor * ggml_conv_transpose_2d_p0(
  3885. struct ggml_context * ctx,
  3886. struct ggml_tensor * a,
  3887. struct ggml_tensor * b,
  3888. int stride) {
  3889. GGML_ASSERT(a->ne[3] == b->ne[2]);
  3890. const int64_t ne[4] = {
  3891. ggml_calc_conv_transpose_output_size(b->ne[0], a->ne[0], stride, 0 /*p0*/),
  3892. ggml_calc_conv_transpose_output_size(b->ne[1], a->ne[1], stride, 0 /*p1*/),
  3893. a->ne[2], b->ne[3],
  3894. };
3895. struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
  3896. ggml_set_op_params_i32(result, 0, stride);
  3897. result->op = GGML_OP_CONV_TRANSPOSE_2D;
  3898. result->src[0] = a;
  3899. result->src[1] = b;
  3900. return result;
  3901. }
  3902. // ggml_pool_*
  3903. static int64_t ggml_calc_pool_output_size(int64_t ins, int ks, int s, float p) {
  3904. return (ins + 2 * p - ks) / s + 1;
  3905. }
  3906. // ggml_pool_1d
  3907. struct ggml_tensor * ggml_pool_1d(
  3908. struct ggml_context * ctx,
  3909. struct ggml_tensor * a,
  3910. enum ggml_op_pool op,
  3911. int k0,
  3912. int s0,
  3913. int p0) {
  3914. const int64_t ne[4] = {
  3915. ggml_calc_pool_output_size(a->ne[0], k0, s0, p0),
  3916. a->ne[1],
  3917. a->ne[2],
  3918. a->ne[3],
  3919. };
  3920. struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
  3921. int32_t params[] = { op, k0, s0, p0 };
  3922. ggml_set_op_params(result, params, sizeof(params));
  3923. result->op = GGML_OP_POOL_1D;
  3924. result->src[0] = a;
  3925. return result;
  3926. }
  3927. // ggml_pool_2d
  3928. struct ggml_tensor * ggml_pool_2d(
  3929. struct ggml_context * ctx,
  3930. struct ggml_tensor * a,
  3931. enum ggml_op_pool op,
  3932. int k0,
  3933. int k1,
  3934. int s0,
  3935. int s1,
  3936. float p0,
  3937. float p1) {
  3938. struct ggml_tensor * result;
  3939. const int64_t ne[4] = {
  3940. ggml_calc_pool_output_size(a->ne[0], k0, s0, p0),
  3941. ggml_calc_pool_output_size(a->ne[1], k1, s1, p1),
  3942. a->ne[2],
  3943. a->ne[3],
  3944. };
  3945. result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
  3946. int32_t params[] = { op, k0, k1, s0, s1, p0, p1 };
  3947. ggml_set_op_params(result, params, sizeof(params));
  3948. result->op = GGML_OP_POOL_2D;
  3949. result->src[0] = a;
  3950. return result;
  3951. }
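// editor's note — usage sketch (assumes a valid `ctx`, an F32 tensor `x` of shape
// [W, H, C, N], and GGML_OP_POOL_MAX from enum ggml_op_pool):
//
//   struct ggml_tensor * y = ggml_pool_2d(ctx, x, GGML_OP_POOL_MAX, 2, 2, 2, 2, 0.0f, 0.0f);
//   // for even W and H: y->ne[0] == W/2, y->ne[1] == H/2 (see ggml_calc_pool_output_size);
//   // the channel and batch dimensions are carried over unchanged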
  3952. struct ggml_tensor * ggml_pool_2d_back(
  3953. struct ggml_context * ctx,
  3954. struct ggml_tensor * a,
  3955. struct ggml_tensor * af,
  3956. enum ggml_op_pool op,
  3957. int k0,
  3958. int k1,
  3959. int s0,
  3960. int s1,
  3961. float p0,
  3962. float p1) {
  3963. struct ggml_tensor * result;
  3964. result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, af->ne);
  3965. int32_t params[] = { op, k0, k1, s0, s1, p0, p1 };
  3966. ggml_set_op_params(result, params, sizeof(params));
  3967. result->op = GGML_OP_POOL_2D_BACK;
  3968. result->src[0] = a;
  3969. result->src[1] = af;
  3970. return result;
  3971. }
  3972. // ggml_upscale / ggml_interpolate
  3973. static struct ggml_tensor * ggml_interpolate_impl(
  3974. struct ggml_context * ctx,
  3975. struct ggml_tensor * a,
  3976. int64_t ne0,
  3977. int64_t ne1,
  3978. int64_t ne2,
  3979. int64_t ne3,
  3980. uint32_t mode) {
  3981. GGML_ASSERT((mode & 0xFF) < GGML_SCALE_MODE_COUNT);
  3982. struct ggml_tensor * result = ggml_new_tensor_4d(ctx, a->type, ne0, ne1, ne2, ne3);
  3983. ggml_set_op_params_i32(result, 0, (int32_t)mode);
  3984. result->op = GGML_OP_UPSCALE;
  3985. result->src[0] = a;
  3986. return result;
  3987. }
  3988. struct ggml_tensor * ggml_upscale(
  3989. struct ggml_context * ctx,
  3990. struct ggml_tensor * a,
  3991. int scale_factor,
  3992. enum ggml_scale_mode mode) {
  3993. GGML_ASSERT(scale_factor > 1);
  3994. return ggml_interpolate_impl(ctx, a, a->ne[0] * scale_factor, a->ne[1] * scale_factor, a->ne[2], a->ne[3], mode);
  3995. }
  3996. struct ggml_tensor * ggml_upscale_ext(
  3997. struct ggml_context * ctx,
  3998. struct ggml_tensor * a,
  3999. int ne0,
  4000. int ne1,
  4001. int ne2,
  4002. int ne3,
  4003. enum ggml_scale_mode mode) {
  4004. return ggml_interpolate_impl(ctx, a, ne0, ne1, ne2, ne3, mode);
  4005. }
  4006. struct ggml_tensor * ggml_interpolate(
  4007. struct ggml_context * ctx,
  4008. struct ggml_tensor * a,
  4009. int64_t ne0,
  4010. int64_t ne1,
  4011. int64_t ne2,
  4012. int64_t ne3,
  4013. uint32_t mode) {
  4014. return ggml_interpolate_impl(ctx, a, ne0, ne1, ne2, ne3, mode);
  4015. }
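// editor's note — usage sketch (assumes `ctx`, a 4-D tensor `img`, and the mode names
// from enum ggml_scale_mode): ggml_upscale scales the first two dimensions by an integer
// factor, while ggml_interpolate lets the caller choose the exact target shape:
//
//   struct ggml_tensor * up2 = ggml_upscale(ctx, img, 2, GGML_SCALE_MODE_NEAREST);
//   struct ggml_tensor * up3 = ggml_interpolate(ctx, img,
//       3*img->ne[0], 3*img->ne[1], img->ne[2], img->ne[3], GGML_SCALE_MODE_BILINEAR);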
  4016. // ggml_pad
  4017. struct ggml_tensor * ggml_pad(
  4018. struct ggml_context * ctx,
  4019. struct ggml_tensor * a,
  4020. int p0,
  4021. int p1,
  4022. int p2,
  4023. int p3) {
  4024. return ggml_pad_ext(ctx, a, 0, p0, 0, p1, 0, p2, 0, p3);
  4025. }
  4026. struct ggml_tensor * ggml_pad_ext(
  4027. struct ggml_context * ctx,
  4028. struct ggml_tensor * a,
  4029. int lp0,
  4030. int rp0,
  4031. int lp1,
  4032. int rp1,
  4033. int lp2,
  4034. int rp2,
  4035. int lp3,
  4036. int rp3
  4037. ) {
  4038. struct ggml_tensor * result = ggml_new_tensor_4d(ctx, a->type,
  4039. a->ne[0] + lp0 + rp0,
  4040. a->ne[1] + lp1 + rp1,
  4041. a->ne[2] + lp2 + rp2,
  4042. a->ne[3] + lp3 + rp3);
  4043. ggml_set_op_params_i32(result, 0, lp0);
  4044. ggml_set_op_params_i32(result, 1, rp0);
  4045. ggml_set_op_params_i32(result, 2, lp1);
  4046. ggml_set_op_params_i32(result, 3, rp1);
  4047. ggml_set_op_params_i32(result, 4, lp2);
  4048. ggml_set_op_params_i32(result, 5, rp2);
  4049. ggml_set_op_params_i32(result, 6, lp3);
  4050. ggml_set_op_params_i32(result, 7, rp3);
  4051. result->op = GGML_OP_PAD;
  4052. result->src[0] = a;
  4053. return result;
  4054. }
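// editor's note (illustrative): ggml_pad above is right-padding only, so
// ggml_pad(ctx, a, 1, 2, 0, 0) yields shape [ne0+1, ne1+2, ne2, ne3]; ggml_pad_ext adds a
// left pad per dimension, e.g. lp0 = rp0 = 1 grows ne0 by 2.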
  4055. // ggml_pad_reflect_1d
  4056. struct ggml_tensor * ggml_pad_reflect_1d(
  4057. struct ggml_context * ctx,
  4058. struct ggml_tensor * a,
  4059. int p0,
  4060. int p1) {
  4061. GGML_ASSERT(p0 >= 0);
  4062. GGML_ASSERT(p1 >= 0);
4063. GGML_ASSERT(p0 < a->ne[0]); // padding length on each side must be less than the
  4064. GGML_ASSERT(p1 < a->ne[0]); // existing length of the dimension being padded
  4065. GGML_ASSERT(ggml_is_contiguous(a));
  4066. GGML_ASSERT(a->type == GGML_TYPE_F32);
  4067. struct ggml_tensor * result = ggml_new_tensor_4d(ctx, a->type,
  4068. a->ne[0] + p0 + p1,
  4069. a->ne[1],
  4070. a->ne[2],
  4071. a->ne[3]);
  4072. int32_t params[] = { p0, p1 };
  4073. ggml_set_op_params(result, params, sizeof(params));
  4074. result->op = GGML_OP_PAD_REFLECT_1D;
  4075. result->src[0] = a;
  4076. return result;
  4077. }
  4078. // ggml_roll
  4079. struct ggml_tensor * ggml_roll(
  4080. struct ggml_context * ctx,
  4081. struct ggml_tensor * a,
  4082. int shift0,
  4083. int shift1,
  4084. int shift2,
  4085. int shift3) {
  4086. GGML_ASSERT(a->nb[0] == ggml_type_size(a->type));
  4087. GGML_ASSERT(abs(shift0) < a->ne[0]);
  4088. GGML_ASSERT(abs(shift1) < a->ne[1]);
  4089. GGML_ASSERT(abs(shift2) < a->ne[2]);
  4090. GGML_ASSERT(abs(shift3) < a->ne[3]);
  4091. struct ggml_tensor * result = ggml_dup_tensor(ctx, a);
  4092. ggml_set_op_params_i32(result, 0, shift0);
  4093. ggml_set_op_params_i32(result, 1, shift1);
  4094. ggml_set_op_params_i32(result, 2, shift2);
  4095. ggml_set_op_params_i32(result, 3, shift3);
  4096. result->op = GGML_OP_ROLL;
  4097. result->src[0] = a;
  4098. return result;
  4099. }
  4100. // ggml_arange
  4101. struct ggml_tensor * ggml_arange(
  4102. struct ggml_context * ctx,
  4103. float start,
  4104. float stop,
  4105. float step) {
  4106. GGML_ASSERT(stop > start);
  4107. const int64_t steps = (int64_t) ceilf((stop - start) / step);
  4108. struct ggml_tensor * result = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, steps);
  4109. ggml_set_op_params_f32(result, 0, start);
  4110. ggml_set_op_params_f32(result, 1, stop);
  4111. ggml_set_op_params_f32(result, 2, step);
  4112. result->op = GGML_OP_ARANGE;
  4113. return result;
  4114. }
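// editor's note (illustrative): the element count is ceil((stop - start) / step), e.g.
// ggml_arange(ctx, 0.0f, 5.0f, 2.0f) creates a 3-element F32 tensor (conceptually
// {0, 2, 4}); the values themselves are filled in when the op is executed.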
  4115. // ggml_timestep_embedding
  4116. struct ggml_tensor * ggml_timestep_embedding(
  4117. struct ggml_context * ctx,
  4118. struct ggml_tensor * timesteps,
  4119. int dim,
  4120. int max_period) {
  4121. struct ggml_tensor * result = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, dim, timesteps->ne[0]);
  4122. ggml_set_op_params_i32(result, 0, dim);
  4123. ggml_set_op_params_i32(result, 1, max_period);
  4124. result->op = GGML_OP_TIMESTEP_EMBEDDING;
  4125. result->src[0] = timesteps;
  4126. return result;
  4127. }
  4128. // ggml_argsort
  4129. struct ggml_tensor * ggml_argsort(
  4130. struct ggml_context * ctx,
  4131. struct ggml_tensor * a,
  4132. enum ggml_sort_order order) {
  4133. GGML_ASSERT(a->ne[0] <= INT32_MAX);
  4134. struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_I32, GGML_MAX_DIMS, a->ne);
  4135. ggml_set_op_params_i32(result, 0, (int32_t) order);
  4136. result->op = GGML_OP_ARGSORT;
  4137. result->src[0] = a;
  4138. return result;
  4139. }
  4140. // ggml_top_k
  4141. struct ggml_tensor * ggml_top_k(
  4142. struct ggml_context * ctx,
  4143. struct ggml_tensor * a,
  4144. int k) {
  4145. GGML_ASSERT(a->ne[0] >= k);
  4146. struct ggml_tensor * result = ggml_argsort(ctx, a, GGML_SORT_ORDER_DESC);
  4147. result = ggml_view_4d(ctx, result,
  4148. k, result->ne[1], result->ne[2], result->ne[3],
  4149. result->nb[1], result->nb[2], result->nb[3],
  4150. 0);
  4151. return result;
  4152. }
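// editor's note: as implemented above, ggml_top_k is a descending argsort followed by a
// view of the first k entries of dim 0, so the result holds the I32 *indices* of the k
// largest values, not the values themselves. Usage sketch (assumes a 2-D F32 `logits`):
//
//   struct ggml_tensor * top5 = ggml_top_k(ctx, logits, 5); // [5, n_rows] of indices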
  4153. // ggml_flash_attn_ext
  4154. struct ggml_tensor * ggml_flash_attn_ext(
  4155. struct ggml_context * ctx,
  4156. struct ggml_tensor * q,
  4157. struct ggml_tensor * k,
  4158. struct ggml_tensor * v,
  4159. struct ggml_tensor * mask,
  4160. float scale,
  4161. float max_bias,
  4162. float logit_softcap) {
  4163. GGML_ASSERT(ggml_can_mul_mat(k, q));
  4164. // TODO: check if vT can be multiplied by (k*qT)
  4165. GGML_ASSERT(q->ne[3] == k->ne[3]);
  4166. GGML_ASSERT(q->ne[3] == v->ne[3]);
  4167. if (mask) {
  4168. GGML_ASSERT(ggml_is_contiguous(mask));
  4169. GGML_ASSERT(mask->ne[1] >= GGML_PAD(q->ne[1], GGML_KQ_MASK_PAD) &&
  4170. "the Flash-Attention kernel requires the mask to be padded to GGML_KQ_MASK_PAD and at least n_queries big");
  4171. //GGML_ASSERT(ggml_can_repeat_rows(mask, qk));
  4172. GGML_ASSERT(q->ne[2] % mask->ne[2] == 0);
  4173. GGML_ASSERT(q->ne[3] % mask->ne[3] == 0);
  4174. }
  4175. if (max_bias > 0.0f) {
  4176. GGML_ASSERT(mask);
  4177. }
  4178. // permute(0, 2, 1, 3)
  4179. int64_t ne[4] = { v->ne[0], q->ne[2], q->ne[1], q->ne[3] };
  4180. struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
  4181. float params[] = { scale, max_bias, logit_softcap };
  4182. ggml_set_op_params(result, params, sizeof(params));
  4183. result->op = GGML_OP_FLASH_ATTN_EXT;
  4184. result->src[0] = q;
  4185. result->src[1] = k;
  4186. result->src[2] = v;
  4187. result->src[3] = mask;
  4188. return result;
  4189. }
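// editor's note (shape sketch, read off the asserts and the ne[] computation above): with
//   q [d_k, n_q,  n_head,    n_batch]
//   k [d_k, n_kv, n_head_kv, n_batch]
//   v [d_v, n_kv, n_head_kv, n_batch]
// the result has shape [d_v, n_head, n_q, n_batch] (hence the permute(0, 2, 1, 3) note),
// and any mask must satisfy ne[1] >= GGML_PAD(n_q, GGML_KQ_MASK_PAD).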
  4190. void ggml_flash_attn_ext_set_prec(
  4191. struct ggml_tensor * a,
  4192. enum ggml_prec prec) {
  4193. GGML_ASSERT(a->op == GGML_OP_FLASH_ATTN_EXT);
  4194. const int32_t prec_i32 = (int32_t) prec;
  4195. ggml_set_op_params_i32(a, 3, prec_i32); // scale is on first pos, max_bias on second
  4196. }
  4197. enum ggml_prec ggml_flash_attn_ext_get_prec(
  4198. const struct ggml_tensor * a) {
  4199. GGML_ASSERT(a->op == GGML_OP_FLASH_ATTN_EXT);
  4200. const int32_t prec_i32 = ggml_get_op_params_i32(a, 3);
  4201. return (enum ggml_prec) prec_i32;
  4202. }
  4203. void ggml_flash_attn_ext_add_sinks(
  4204. struct ggml_tensor * a,
  4205. struct ggml_tensor * sinks) {
  4206. if (!sinks) {
  4207. a->src[4] = NULL;
  4208. return;
  4209. }
  4210. GGML_ASSERT(a->op == GGML_OP_FLASH_ATTN_EXT);
  4211. GGML_ASSERT(a->src[4] == NULL);
  4212. GGML_ASSERT(a->src[0]->ne[2] == sinks->ne[0]);
  4213. GGML_ASSERT(sinks->type == GGML_TYPE_F32);
  4214. a->src[4] = sinks;
  4215. }
  4216. // ggml_flash_attn_back
  4217. struct ggml_tensor * ggml_flash_attn_back(
  4218. struct ggml_context * ctx,
  4219. struct ggml_tensor * q,
  4220. struct ggml_tensor * k,
  4221. struct ggml_tensor * v,
  4222. struct ggml_tensor * d,
  4223. bool masked) {
  4224. GGML_ABORT("TODO: adapt to ggml_flash_attn_ext() changes");
  4225. GGML_ASSERT(ggml_can_mul_mat(k, q));
  4226. // TODO: check if vT can be multiplied by (k*qT)
  4227. // d shape [D,N,ne2,ne3]
  4228. // q shape [D,N,ne2,ne3]
  4229. // k shape [D,M,kvne2,ne3]
  4230. // v shape [M,D,kvne2,ne3]
  4231. const int64_t D = q->ne[0];
  4232. const int64_t N = q->ne[1];
  4233. const int64_t M = k->ne[1];
  4234. const int64_t ne2 = q->ne[2];
  4235. const int64_t ne3 = q->ne[3];
  4236. const int64_t kvne2 = k->ne[2];
  4237. GGML_ASSERT(k->ne[0] == D);
  4238. GGML_ASSERT(v->ne[0] == M);
  4239. GGML_ASSERT(v->ne[1] == D);
  4240. GGML_ASSERT(d->ne[0] == D);
  4241. GGML_ASSERT(d->ne[1] == N);
  4242. GGML_ASSERT(k->ne[2] == kvne2);
  4243. GGML_ASSERT(k->ne[3] == ne3);
  4244. GGML_ASSERT(v->ne[2] == kvne2);
  4245. GGML_ASSERT(v->ne[3] == ne3);
  4246. GGML_ASSERT(d->ne[2] == ne2);
  4247. GGML_ASSERT(d->ne[3] == ne3);
  4248. GGML_ASSERT(ne2 % kvne2 == 0);
4249. // store gradients of q, k and v as contiguous tensors concatenated in result.
  4250. // note: v and gradv are actually transposed, i.e. v->ne[0] != D.
  4251. const int64_t elem_q = ggml_nelements(q);
  4252. const int64_t elem_k = ggml_nelements(k);
  4253. const int64_t elem_v = ggml_nelements(v);
  4254. enum ggml_type result_type = GGML_TYPE_F32;
  4255. GGML_ASSERT(ggml_blck_size(result_type) == 1);
  4256. const size_t tsize = ggml_type_size(result_type);
  4257. const size_t offs_q = 0;
  4258. const size_t offs_k = offs_q + GGML_PAD(elem_q * tsize, GGML_MEM_ALIGN);
  4259. const size_t offs_v = offs_k + GGML_PAD(elem_k * tsize, GGML_MEM_ALIGN);
  4260. const size_t end = offs_v + GGML_PAD(elem_v * tsize, GGML_MEM_ALIGN);
  4261. const size_t nelements = (end + tsize - 1)/tsize;
  4262. struct ggml_tensor * result = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nelements);
  4263. int32_t masked_i = masked ? 1 : 0;
  4264. ggml_set_op_params(result, &masked_i, sizeof(masked_i));
  4265. result->op = GGML_OP_FLASH_ATTN_BACK;
  4266. result->src[0] = q;
  4267. result->src[1] = k;
  4268. result->src[2] = v;
  4269. result->src[3] = d;
  4270. return result;
  4271. }
  4272. // ggml_ssm_conv
  4273. struct ggml_tensor * ggml_ssm_conv(
  4274. struct ggml_context * ctx,
  4275. struct ggml_tensor * sx,
  4276. struct ggml_tensor * c) {
  4277. GGML_ASSERT(ggml_is_3d(sx));
  4278. GGML_ASSERT(ggml_is_matrix(c));
  4279. const int64_t d_conv = c->ne[0];
  4280. const int64_t d_inner = c->ne[1];
  4281. const int64_t n_t = sx->ne[0] - d_conv + 1; // tokens per sequence
  4282. const int64_t n_s = sx->ne[2];
4283. // TODO: maybe support strides other than 1?
  4284. GGML_ASSERT(sx->ne[0] == d_conv - 1 + n_t);
  4285. GGML_ASSERT(sx->ne[1] == d_inner);
  4286. GGML_ASSERT(n_t >= 0);
  4287. struct ggml_tensor * result = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, d_inner, n_t, n_s);
  4288. result->op = GGML_OP_SSM_CONV;
  4289. result->src[0] = sx;
  4290. result->src[1] = c;
  4291. return result;
  4292. }
  4293. // ggml_ssm_scan
  4294. struct ggml_tensor * ggml_ssm_scan(
  4295. struct ggml_context * ctx,
  4296. struct ggml_tensor * s,
  4297. struct ggml_tensor * x,
  4298. struct ggml_tensor * dt,
  4299. struct ggml_tensor * A,
  4300. struct ggml_tensor * B,
  4301. struct ggml_tensor * C,
  4302. struct ggml_tensor * ids) {
  4303. GGML_ASSERT(ggml_is_contiguous(s));
  4304. GGML_ASSERT(ggml_is_contiguous(dt));
  4305. GGML_ASSERT(ggml_is_contiguous(A));
  4306. GGML_ASSERT(x->nb[0] == ggml_type_size(x->type));
  4307. GGML_ASSERT(B->nb[0] == ggml_type_size(B->type));
  4308. GGML_ASSERT(C->nb[0] == ggml_type_size(C->type));
  4309. GGML_ASSERT(x->nb[1] == x->ne[0]*x->nb[0]);
  4310. GGML_ASSERT(B->nb[1] == B->ne[0]*B->nb[0]);
  4311. GGML_ASSERT(C->nb[1] == C->ne[0]*C->nb[0]);
  4312. GGML_ASSERT(ggml_are_same_shape(B, C));
  4313. GGML_ASSERT(ids->type == GGML_TYPE_I32);
  4314. {
  4315. const int64_t d_state = s->ne[0];
  4316. const int64_t head_dim = x->ne[0];
  4317. const int64_t n_head = x->ne[1];
  4318. const int64_t n_seq_tokens = x->ne[2];
  4319. const int64_t n_seqs = x->ne[3];
  4320. GGML_ASSERT(dt->ne[0] == n_head);
  4321. GGML_ASSERT(dt->ne[1] == n_seq_tokens);
  4322. GGML_ASSERT(dt->ne[2] == n_seqs);
  4323. GGML_ASSERT(ggml_is_3d(dt));
  4324. GGML_ASSERT(s->ne[1] == head_dim);
  4325. GGML_ASSERT(s->ne[2] == n_head);
  4326. GGML_ASSERT(B->ne[0] == d_state);
  4327. GGML_ASSERT(B->ne[2] == n_seq_tokens);
  4328. GGML_ASSERT(B->ne[3] == n_seqs);
  4329. GGML_ASSERT(ids->ne[0] == n_seqs);
  4330. GGML_ASSERT(ggml_is_vector(ids));
  4331. GGML_ASSERT(A->ne[1] == n_head);
  4332. GGML_ASSERT(ggml_is_matrix(A));
  4333. if (A->ne[0] != 1) {
  4334. // Mamba-1 has more granular decay factors
  4335. GGML_ASSERT(A->ne[0] == d_state);
  4336. }
  4337. }
  4338. // concatenated y + ssm_states
  4339. struct ggml_tensor * result = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, ggml_nelements(x) + s->ne[0]*s->ne[1]*s->ne[2]*ids->ne[0]);
  4340. result->op = GGML_OP_SSM_SCAN;
  4341. result->src[0] = s;
  4342. result->src[1] = x;
  4343. result->src[2] = dt;
  4344. result->src[3] = A;
  4345. result->src[4] = B;
  4346. result->src[5] = C;
  4347. result->src[6] = ids;
  4348. return result;
  4349. }
  4350. // ggml_win_part
  4351. struct ggml_tensor * ggml_win_part(
  4352. struct ggml_context * ctx,
  4353. struct ggml_tensor * a,
  4354. int w) {
  4355. GGML_ASSERT(a->ne[3] == 1);
  4356. GGML_ASSERT(a->type == GGML_TYPE_F32);
  4357. // padding
  4358. const int px = (w - a->ne[1]%w)%w;
  4359. const int py = (w - a->ne[2]%w)%w;
  4360. const int npx = (px + a->ne[1])/w;
  4361. const int npy = (py + a->ne[2])/w;
  4362. const int np = npx*npy;
  4363. const int64_t ne[4] = { a->ne[0], w, w, np, };
  4364. struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
  4365. int32_t params[] = { npx, npy, w };
  4366. ggml_set_op_params(result, params, sizeof(params));
  4367. result->op = GGML_OP_WIN_PART;
  4368. result->src[0] = a;
  4369. return result;
  4370. }
  4371. // ggml_win_unpart
  4372. struct ggml_tensor * ggml_win_unpart(
  4373. struct ggml_context * ctx,
  4374. struct ggml_tensor * a,
  4375. int w0,
  4376. int h0,
  4377. int w) {
  4378. GGML_ASSERT(a->type == GGML_TYPE_F32);
  4379. const int64_t ne[4] = { a->ne[0], w0, h0, 1, };
  4380. struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 3, ne);
  4381. int32_t params[] = { w };
  4382. ggml_set_op_params(result, params, sizeof(params));
  4383. result->op = GGML_OP_WIN_UNPART;
  4384. result->src[0] = a;
  4385. return result;
  4386. }
  4387. // ggml_get_rel_pos
  4388. struct ggml_tensor * ggml_get_rel_pos(
  4389. struct ggml_context * ctx,
  4390. struct ggml_tensor * a,
  4391. int qh,
  4392. int kh) {
  4393. GGML_ASSERT(qh == kh);
  4394. GGML_ASSERT(2*MAX(qh, kh) - 1 == a->ne[1]);
  4395. const int64_t ne[4] = { a->ne[0], kh, qh, 1, };
  4396. struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F16, 3, ne);
  4397. result->op = GGML_OP_GET_REL_POS;
  4398. result->src[0] = a;
  4399. return result;
  4400. }
  4401. // ggml_add_rel_pos
  4402. static struct ggml_tensor * ggml_add_rel_pos_impl(
  4403. struct ggml_context * ctx,
  4404. struct ggml_tensor * a,
  4405. struct ggml_tensor * pw,
  4406. struct ggml_tensor * ph,
  4407. bool inplace) {
  4408. GGML_ASSERT(ggml_are_same_shape(pw, ph));
  4409. GGML_ASSERT(ggml_is_contiguous(a));
  4410. GGML_ASSERT(ggml_is_contiguous(pw));
  4411. GGML_ASSERT(ggml_is_contiguous(ph));
  4412. GGML_ASSERT(ph->type == GGML_TYPE_F32);
  4413. GGML_ASSERT(pw->type == GGML_TYPE_F32);
  4414. GGML_ASSERT(pw->ne[3] == a->ne[2]);
  4415. GGML_ASSERT(pw->ne[0]*pw->ne[0] == a->ne[0]);
  4416. GGML_ASSERT(pw->ne[1]*pw->ne[2] == a->ne[1]);
  4417. struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
  4418. ggml_set_op_params_i32(result, 0, inplace ? 1 : 0);
  4419. result->op = GGML_OP_ADD_REL_POS;
  4420. result->src[0] = a;
  4421. result->src[1] = pw;
  4422. result->src[2] = ph;
  4423. return result;
  4424. }
  4425. struct ggml_tensor * ggml_add_rel_pos(
  4426. struct ggml_context * ctx,
  4427. struct ggml_tensor * a,
  4428. struct ggml_tensor * pw,
  4429. struct ggml_tensor * ph) {
  4430. return ggml_add_rel_pos_impl(ctx, a, pw, ph, false);
  4431. }
  4432. struct ggml_tensor * ggml_add_rel_pos_inplace(
  4433. struct ggml_context * ctx,
  4434. struct ggml_tensor * a,
  4435. struct ggml_tensor * pw,
  4436. struct ggml_tensor * ph) {
  4437. return ggml_add_rel_pos_impl(ctx, a, pw, ph, true);
  4438. }
  4439. // ggml_rwkv_wkv6
  4440. struct ggml_tensor * ggml_rwkv_wkv6(
  4441. struct ggml_context * ctx,
  4442. struct ggml_tensor * k,
  4443. struct ggml_tensor * v,
  4444. struct ggml_tensor * r,
  4445. struct ggml_tensor * tf,
  4446. struct ggml_tensor * td,
  4447. struct ggml_tensor * state) {
  4448. GGML_ASSERT(ggml_is_contiguous(k));
  4449. GGML_ASSERT(ggml_is_contiguous(v));
  4450. GGML_ASSERT(ggml_is_contiguous(r));
  4451. GGML_ASSERT(ggml_is_contiguous(tf));
  4452. GGML_ASSERT(ggml_is_contiguous(td));
  4453. GGML_ASSERT(ggml_is_contiguous(state));
  4454. const int64_t S = k->ne[0];
  4455. const int64_t H = k->ne[1];
  4456. const int64_t n_tokens = k->ne[2];
  4457. const int64_t n_seqs = state->ne[1];
  4458. {
  4459. GGML_ASSERT(v->ne[0] == S && v->ne[1] == H && v->ne[2] == n_tokens);
  4460. GGML_ASSERT(r->ne[0] == S && r->ne[1] == H && r->ne[2] == n_tokens);
  4461. GGML_ASSERT(td->ne[0] == S && td->ne[1] == H && td->ne[2] == n_tokens);
  4462. GGML_ASSERT(ggml_nelements(state) == S * S * H * n_seqs);
  4463. }
  4464. // concat output and new_state
  4465. const int64_t ne[4] = { S * H, n_tokens + S * n_seqs, 1, 1 };
  4466. struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
  4467. result->op = GGML_OP_RWKV_WKV6;
  4468. result->src[0] = k;
  4469. result->src[1] = v;
  4470. result->src[2] = r;
  4471. result->src[3] = tf;
  4472. result->src[4] = td;
  4473. result->src[5] = state;
  4474. return result;
  4475. }
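// editor's note — usage sketch (illustrative; S, H, n_tokens, n_seqs are the caller-side
// values matching the shapes above): judging by the "concat output and new_state" comment,
// a caller can split the packed result with views:
//
//   struct ggml_tensor * out = ggml_view_2d(ctx, result, S*H, n_tokens, result->nb[1], 0);
//   struct ggml_tensor * new_state = ggml_view_2d(ctx, result, S*H, S*n_seqs,
//       result->nb[1], n_tokens*result->nb[1]);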
  4476. // ggml_gated_linear_attn
  4477. struct ggml_tensor * ggml_gated_linear_attn(
  4478. struct ggml_context * ctx,
  4479. struct ggml_tensor * k,
  4480. struct ggml_tensor * v,
  4481. struct ggml_tensor * q,
  4482. struct ggml_tensor * g,
  4483. struct ggml_tensor * state,
  4484. float scale) {
  4485. GGML_ASSERT(ggml_is_contiguous(k));
  4486. GGML_ASSERT(ggml_is_contiguous(v));
  4487. GGML_ASSERT(ggml_is_contiguous(q));
  4488. GGML_ASSERT(ggml_is_contiguous(g));
  4489. GGML_ASSERT(ggml_is_contiguous(state));
  4490. const int64_t S = k->ne[0];
  4491. const int64_t H = k->ne[1];
  4492. const int64_t n_tokens = k->ne[2];
  4493. const int64_t n_seqs = state->ne[1];
  4494. {
  4495. GGML_ASSERT(v->ne[0] == S && v->ne[1] == H && v->ne[2] == n_tokens);
  4496. GGML_ASSERT(q->ne[0] == S && q->ne[1] == H && q->ne[2] == n_tokens);
  4497. GGML_ASSERT(g->ne[0] == S && g->ne[1] == H && g->ne[2] == n_tokens);
  4498. GGML_ASSERT(ggml_nelements(state) == S * S * H * n_seqs);
  4499. }
  4500. // concat output and new_state
  4501. const int64_t ne[4] = { S * H, n_tokens + S * n_seqs, 1, 1 };
  4502. struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
  4503. ggml_set_op_params_f32(result, 0, scale);
  4504. result->op = GGML_OP_GATED_LINEAR_ATTN;
  4505. result->src[0] = k;
  4506. result->src[1] = v;
  4507. result->src[2] = q;
  4508. result->src[3] = g;
  4509. result->src[4] = state;
  4510. return result;
  4511. }
  4512. // ggml_rwkv_wkv7
  4513. struct ggml_tensor * ggml_rwkv_wkv7(
  4514. struct ggml_context * ctx,
  4515. struct ggml_tensor * r,
  4516. struct ggml_tensor * w,
  4517. struct ggml_tensor * k,
  4518. struct ggml_tensor * v,
  4519. struct ggml_tensor * a,
  4520. struct ggml_tensor * b,
  4521. struct ggml_tensor * state) {
  4522. GGML_ASSERT(ggml_is_contiguous(r));
  4523. GGML_ASSERT(ggml_is_contiguous(w));
  4524. GGML_ASSERT(ggml_is_contiguous(k));
  4525. GGML_ASSERT(ggml_is_contiguous(v));
  4526. GGML_ASSERT(ggml_is_contiguous(a));
  4527. GGML_ASSERT(ggml_is_contiguous(b));
  4528. GGML_ASSERT(ggml_is_contiguous(state));
  4529. const int64_t S = k->ne[0];
  4530. const int64_t H = k->ne[1];
  4531. const int64_t n_tokens = k->ne[2];
  4532. const int64_t n_seqs = state->ne[1];
  4533. {
  4534. GGML_ASSERT(w->ne[0] == S && w->ne[1] == H && w->ne[2] == n_tokens);
  4535. GGML_ASSERT(k->ne[0] == S && k->ne[1] == H && k->ne[2] == n_tokens);
  4536. GGML_ASSERT(v->ne[0] == S && v->ne[1] == H && v->ne[2] == n_tokens);
  4537. GGML_ASSERT(a->ne[0] == S && a->ne[1] == H && a->ne[2] == n_tokens);
  4538. GGML_ASSERT(b->ne[0] == S && b->ne[1] == H && b->ne[2] == n_tokens);
  4539. GGML_ASSERT(ggml_nelements(state) == S * S * H * n_seqs);
  4540. }
  4541. // concat output and new_state
  4542. const int64_t ne[4] = { S * H, n_tokens + S * n_seqs, 1, 1 };
  4543. struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
  4544. result->op = GGML_OP_RWKV_WKV7;
  4545. result->src[0] = r;
  4546. result->src[1] = w;
  4547. result->src[2] = k;
  4548. result->src[3] = v;
  4549. result->src[4] = a;
  4550. result->src[5] = b;
  4551. result->src[6] = state;
  4552. return result;
  4553. }
  4554. // ggml_unary
  4555. static struct ggml_tensor * ggml_unary_impl(
  4556. struct ggml_context * ctx,
  4557. struct ggml_tensor * a,
  4558. enum ggml_unary_op op,
  4559. bool inplace) {
  4560. GGML_ASSERT(ggml_is_contiguous_1(a));
  4561. struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
  4562. ggml_set_op_params_i32(result, 0, (int32_t) op);
  4563. result->op = GGML_OP_UNARY;
  4564. result->src[0] = a;
  4565. return result;
  4566. }
  4567. struct ggml_tensor * ggml_unary(
  4568. struct ggml_context * ctx,
  4569. struct ggml_tensor * a,
  4570. enum ggml_unary_op op) {
  4571. return ggml_unary_impl(ctx, a, op, false);
  4572. }
  4573. struct ggml_tensor * ggml_unary_inplace(
  4574. struct ggml_context * ctx,
  4575. struct ggml_tensor * a,
  4576. enum ggml_unary_op op) {
  4577. return ggml_unary_impl(ctx, a, op, true);
  4578. }
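// editor's note — usage sketch (op names assumed from enum ggml_unary_op):
//
//   struct ggml_tensor * y = ggml_unary(ctx, x, GGML_UNARY_OP_GELU);         // new tensor
//   struct ggml_tensor * z = ggml_unary_inplace(ctx, x, GGML_UNARY_OP_RELU); // view of x
//
// note the ggml_is_contiguous_1 assertion above: the rows of the input must be contiguous.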
  4579. // ggml_map_custom1
  4580. static struct ggml_tensor * ggml_map_custom1_impl(
  4581. struct ggml_context * ctx,
  4582. struct ggml_tensor * a,
  4583. const ggml_custom1_op_t fun,
  4584. int n_tasks,
  4585. void * userdata,
  4586. bool inplace) {
  4587. GGML_ASSERT(n_tasks == GGML_N_TASKS_MAX || n_tasks > 0);
  4588. struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
  4589. struct ggml_map_custom1_op_params params = {
  4590. /*.fun =*/ fun,
  4591. /*.n_tasks =*/ n_tasks,
  4592. /*.userdata =*/ userdata
  4593. };
  4594. ggml_set_op_params(result, &params, sizeof(params));
  4595. result->op = GGML_OP_MAP_CUSTOM1;
  4596. result->src[0] = a;
  4597. return result;
  4598. }
  4599. struct ggml_tensor * ggml_map_custom1(
  4600. struct ggml_context * ctx,
  4601. struct ggml_tensor * a,
  4602. const ggml_custom1_op_t fun,
  4603. int n_tasks,
  4604. void * userdata) {
  4605. return ggml_map_custom1_impl(ctx, a, fun, n_tasks, userdata, false);
  4606. }
  4607. struct ggml_tensor * ggml_map_custom1_inplace(
  4608. struct ggml_context * ctx,
  4609. struct ggml_tensor * a,
  4610. const ggml_custom1_op_t fun,
  4611. int n_tasks,
  4612. void * userdata) {
  4613. return ggml_map_custom1_impl(ctx, a, fun, n_tasks, userdata, true);
  4614. }
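// editor's note — callback sketch; the ggml_custom1_op_t signature below is assumed from
// ggml.h (dst, a, ith, nth, userdata), where ith/nth identify the worker thread:
//
//   static void my_op(struct ggml_tensor * dst, const struct ggml_tensor * a,
//                     int ith, int nth, void * userdata) {
//       // partition the rows of `a` across the nth threads and write results into dst
//   }
//
//   struct ggml_tensor * y = ggml_map_custom1(ctx, x, my_op, GGML_N_TASKS_MAX, NULL);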
  4615. // ggml_map_custom2
  4616. static struct ggml_tensor * ggml_map_custom2_impl(
  4617. struct ggml_context * ctx,
  4618. struct ggml_tensor * a,
  4619. struct ggml_tensor * b,
  4620. const ggml_custom2_op_t fun,
  4621. int n_tasks,
  4622. void * userdata,
  4623. bool inplace) {
  4624. GGML_ASSERT(n_tasks == GGML_N_TASKS_MAX || n_tasks > 0);
  4625. struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
  4626. struct ggml_map_custom2_op_params params = {
  4627. /*.fun =*/ fun,
  4628. /*.n_tasks =*/ n_tasks,
  4629. /*.userdata =*/ userdata
  4630. };
  4631. ggml_set_op_params(result, &params, sizeof(params));
  4632. result->op = GGML_OP_MAP_CUSTOM2;
  4633. result->src[0] = a;
  4634. result->src[1] = b;
  4635. return result;
  4636. }
  4637. struct ggml_tensor * ggml_map_custom2(
  4638. struct ggml_context * ctx,
  4639. struct ggml_tensor * a,
  4640. struct ggml_tensor * b,
  4641. const ggml_custom2_op_t fun,
  4642. int n_tasks,
  4643. void * userdata) {
  4644. return ggml_map_custom2_impl(ctx, a, b, fun, n_tasks, userdata, false);
  4645. }
  4646. struct ggml_tensor * ggml_map_custom2_inplace(
  4647. struct ggml_context * ctx,
  4648. struct ggml_tensor * a,
  4649. struct ggml_tensor * b,
  4650. const ggml_custom2_op_t fun,
  4651. int n_tasks,
  4652. void * userdata) {
  4653. return ggml_map_custom2_impl(ctx, a, b, fun, n_tasks, userdata, true);
  4654. }
  4655. // ggml_map_custom3
  4656. static struct ggml_tensor * ggml_map_custom3_impl(
  4657. struct ggml_context * ctx,
  4658. struct ggml_tensor * a,
  4659. struct ggml_tensor * b,
  4660. struct ggml_tensor * c,
  4661. const ggml_custom3_op_t fun,
  4662. int n_tasks,
  4663. void * userdata,
  4664. bool inplace) {
  4665. GGML_ASSERT(n_tasks == GGML_N_TASKS_MAX || n_tasks > 0);
  4666. struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
  4667. struct ggml_map_custom3_op_params params = {
  4668. /*.fun =*/ fun,
  4669. /*.n_tasks =*/ n_tasks,
  4670. /*.userdata =*/ userdata
  4671. };
  4672. ggml_set_op_params(result, &params, sizeof(params));
  4673. result->op = GGML_OP_MAP_CUSTOM3;
  4674. result->src[0] = a;
  4675. result->src[1] = b;
  4676. result->src[2] = c;
  4677. return result;
  4678. }
  4679. struct ggml_tensor * ggml_map_custom3(
  4680. struct ggml_context * ctx,
  4681. struct ggml_tensor * a,
  4682. struct ggml_tensor * b,
  4683. struct ggml_tensor * c,
  4684. const ggml_custom3_op_t fun,
  4685. int n_tasks,
  4686. void * userdata) {
  4687. return ggml_map_custom3_impl(ctx, a, b, c, fun, n_tasks, userdata, false);
  4688. }
  4689. struct ggml_tensor * ggml_map_custom3_inplace(
  4690. struct ggml_context * ctx,
  4691. struct ggml_tensor * a,
  4692. struct ggml_tensor * b,
  4693. struct ggml_tensor * c,
  4694. const ggml_custom3_op_t fun,
  4695. int n_tasks,
  4696. void * userdata) {
  4697. return ggml_map_custom3_impl(ctx, a, b, c, fun, n_tasks, userdata, true);
  4698. }
  4699. struct ggml_tensor * ggml_custom_4d(
  4700. struct ggml_context * ctx,
  4701. enum ggml_type type,
  4702. int64_t ne0,
  4703. int64_t ne1,
  4704. int64_t ne2,
  4705. int64_t ne3,
  4706. struct ggml_tensor ** args,
  4707. int n_args,
  4708. ggml_custom_op_t fun,
  4709. int n_tasks,
  4710. void * userdata) {
  4711. GGML_ASSERT(n_args < GGML_MAX_SRC);
  4712. struct ggml_tensor * result = ggml_new_tensor_4d(ctx, type, ne0, ne1, ne2, ne3);
  4713. struct ggml_custom_op_params params = {
  4714. /*.fun =*/ fun,
  4715. /*.n_tasks =*/ n_tasks,
  4716. /*.userdata =*/ userdata
  4717. };
  4718. ggml_set_op_params(result, &params, sizeof(params));
  4719. result->op = GGML_OP_CUSTOM;
  4720. for (int i = 0; i < n_args; i++) {
  4721. result->src[i] = args[i];
  4722. }
  4723. return result;
  4724. }
  4725. struct ggml_tensor * ggml_custom_inplace(
  4726. struct ggml_context * ctx,
  4727. struct ggml_tensor * a,
  4728. struct ggml_tensor ** args,
  4729. int n_args,
  4730. ggml_custom_op_t fun,
  4731. int n_tasks,
  4732. void * userdata) {
  4733. GGML_ASSERT(n_args < GGML_MAX_SRC - 1);
  4734. struct ggml_tensor * result = ggml_view_tensor(ctx, a);
  4735. struct ggml_custom_op_params params = {
  4736. /*.fun =*/ fun,
  4737. /*.n_tasks =*/ n_tasks,
  4738. /*.userdata =*/ userdata
  4739. };
  4740. ggml_set_op_params(result, &params, sizeof(params));
  4741. result->op = GGML_OP_CUSTOM;
  4742. result->src[0] = a;
  4743. for (int i = 0; i < n_args; i++) {
  4744. result->src[i + 1] = args[i];
  4745. }
  4746. return result;
  4747. }
  4748. // ggml_cross_entropy_loss
  4749. struct ggml_tensor * ggml_cross_entropy_loss(
  4750. struct ggml_context * ctx,
  4751. struct ggml_tensor * a,
  4752. struct ggml_tensor * b) {
  4753. GGML_ASSERT(ggml_are_same_shape(a, b));
  4754. struct ggml_tensor * result = ggml_new_tensor_1d(ctx, a->type, 1);
  4755. result->op = GGML_OP_CROSS_ENTROPY_LOSS;
  4756. result->src[0] = a;
  4757. result->src[1] = b;
  4758. return result;
  4759. }
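// editor's note — usage sketch (assumes same-shape F32 tensors `logits` and `labels`;
// judging by the backward case further below, src1 = b plays the role of the labels):
//
//   struct ggml_tensor * loss = ggml_cross_entropy_loss(ctx, logits, labels); // 1 element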
  4760. // ggml_cross_entropy_loss_back
  4761. struct ggml_tensor * ggml_cross_entropy_loss_back(
  4762. struct ggml_context * ctx,
  4763. struct ggml_tensor * a,
  4764. struct ggml_tensor * b,
  4765. struct ggml_tensor * c) {
  4766. GGML_ASSERT(ggml_is_scalar(a));
  4767. GGML_ASSERT(ggml_are_same_shape(b, c));
  4768. struct ggml_tensor * result = ggml_dup_tensor(ctx, b);
  4769. result->op = GGML_OP_CROSS_ENTROPY_LOSS_BACK;
  4770. result->src[0] = a;
  4771. result->src[1] = b;
  4772. result->src[2] = c;
  4773. return result;
  4774. }
  4775. // opt_step_adamw
  4776. struct ggml_tensor * ggml_opt_step_adamw(
  4777. struct ggml_context * ctx,
  4778. struct ggml_tensor * a,
  4779. struct ggml_tensor * grad,
  4780. struct ggml_tensor * m,
  4781. struct ggml_tensor * v,
  4782. struct ggml_tensor * adamw_params) {
  4783. GGML_ASSERT(a->flags & GGML_TENSOR_FLAG_PARAM);
  4784. GGML_ASSERT(ggml_are_same_shape(a, grad));
  4785. GGML_ASSERT(ggml_are_same_shape(a, m));
  4786. GGML_ASSERT(ggml_are_same_shape(a, v));
  4787. GGML_ASSERT(adamw_params->type == GGML_TYPE_F32);
  4788. GGML_ASSERT(ggml_nelements(adamw_params) == 7);
  4789. struct ggml_tensor * result = ggml_view_tensor(ctx, a);
  4790. result->op = GGML_OP_OPT_STEP_ADAMW;
  4791. result->src[0] = a;
  4792. result->src[1] = grad;
  4793. result->src[2] = m;
  4794. result->src[3] = v;
  4795. result->src[4] = adamw_params;
  4796. return result;
  4797. }
  4798. // opt_step_sgd
  4799. struct ggml_tensor * ggml_opt_step_sgd(
  4800. struct ggml_context * ctx,
  4801. struct ggml_tensor * a,
  4802. struct ggml_tensor * grad,
  4803. struct ggml_tensor * params) {
  4804. GGML_ASSERT(a->flags & GGML_TENSOR_FLAG_PARAM);
  4805. GGML_ASSERT(ggml_are_same_shape(a, grad));
  4806. GGML_ASSERT(params->type == GGML_TYPE_F32);
  4807. GGML_ASSERT(ggml_nelements(params) == 2);
  4808. struct ggml_tensor * result = ggml_view_tensor(ctx, a);
  4809. result->op = GGML_OP_OPT_STEP_SGD;
  4810. result->src[0] = a;
  4811. result->src[1] = grad;
  4812. result->src[2] = params;
  4813. return result;
  4814. }
  4815. ////////////////////////////////////////////////////////////////////////////////
  4816. struct ggml_hash_set ggml_hash_set_new(size_t size) {
  4817. size = ggml_hash_size(size);
  4818. struct ggml_hash_set result;
  4819. result.size = size;
  4820. result.keys = GGML_MALLOC(sizeof(struct ggml_tensor *) * size);
  4821. result.used = GGML_CALLOC(ggml_bitset_size(size), sizeof(ggml_bitset_t));
  4822. return result;
  4823. }
  4824. void ggml_hash_set_reset(struct ggml_hash_set * hash_set) {
  4825. memset(hash_set->used, 0, sizeof(ggml_bitset_t) * ggml_bitset_size(hash_set->size));
  4826. }
  4827. void ggml_hash_set_free(struct ggml_hash_set * hash_set) {
  4828. GGML_FREE(hash_set->used);
  4829. GGML_FREE(hash_set->keys);
  4830. }
  4831. size_t ggml_hash_size(size_t min_sz) {
  4832. // next primes after powers of two
  4833. static const size_t primes[] = {
  4834. 2, 3, 5, 11, 17, 37, 67, 131, 257, 521, 1031,
  4835. 2053, 4099, 8209, 16411, 32771, 65537, 131101,
  4836. 262147, 524309, 1048583, 2097169, 4194319, 8388617,
  4837. 16777259, 33554467, 67108879, 134217757, 268435459,
  4838. 536870923, 1073741827, 2147483659
  4839. };
  4840. static const size_t n_primes = sizeof(primes)/sizeof(primes[0]);
4841. // find the smallest prime that is larger than or equal to min_sz
  4842. size_t l = 0;
  4843. size_t r = n_primes;
  4844. while (l < r) {
  4845. size_t m = (l + r)/2;
  4846. if (primes[m] < min_sz) {
  4847. l = m + 1;
  4848. } else {
  4849. r = m;
  4850. }
  4851. }
  4852. size_t sz = l < n_primes ? primes[l] : min_sz | 1;
  4853. return sz;
  4854. }
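// editor's note (illustrative): the binary search above returns the smallest listed prime
// >= min_sz, e.g. ggml_hash_size(100) == 131 and ggml_hash_size(131) == 131; requests
// beyond the table fall back to min_sz | 1 (forced odd).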
  4855. struct hash_map {
  4856. struct ggml_hash_set set;
  4857. struct ggml_tensor ** vals;
  4858. };
  4859. static struct hash_map * ggml_new_hash_map(size_t size) {
  4860. struct hash_map * result = GGML_MALLOC(sizeof(struct hash_map));
  4861. result->set = ggml_hash_set_new(size);
  4862. result->vals = GGML_CALLOC(result->set.size, sizeof(struct ggml_tensor *));
  4863. return result;
  4864. }
  4865. static void ggml_hash_map_free(struct hash_map * map) {
  4866. ggml_hash_set_free(&map->set);
  4867. GGML_FREE(map->vals);
  4868. GGML_FREE(map);
  4869. }
  4870. // utility functions to change gradients
4871. // isrc is the index of tensor in cgraph->visited_hash_set.keys
4872. // the corresponding gradients (accumulators) are also at position isrc
  4873. // if tensor has a gradient accumulator, modify that accumulator in-place
  4874. // else if there is no gradient for tensor, set the corresponding value
  4875. // else, just add/subtract/etc. the gradients
  4876. static void ggml_add_or_set(
  4877. struct ggml_context * ctx,
  4878. struct ggml_cgraph * cgraph,
  4879. size_t isrc,
  4880. struct ggml_tensor * tensor) {
  4881. struct ggml_tensor * src = cgraph->visited_hash_set.keys[isrc];
  4882. GGML_ASSERT(src);
  4883. if (cgraph->grads[isrc]) {
  4884. cgraph->grads[isrc] = ggml_add_impl(ctx, cgraph->grads[isrc], tensor, /*inplace =*/ cgraph->grad_accs[isrc]);
  4885. } else {
  4886. cgraph->grads[isrc] = tensor;
  4887. }
  4888. ggml_format_name(cgraph->grads[isrc], "grad for %s", src->name);
  4889. ggml_build_forward_expand(cgraph, cgraph->grads[isrc]);
  4890. }
  4891. static void ggml_acc_or_set(
  4892. struct ggml_context * ctx,
  4893. struct ggml_cgraph * cgraph,
  4894. size_t isrc,
  4895. struct ggml_tensor * tensor,
  4896. const size_t nb1,
  4897. const size_t nb2,
  4898. const size_t nb3,
  4899. const size_t offset) {
  4900. struct ggml_tensor * src = cgraph->visited_hash_set.keys[isrc];
  4901. GGML_ASSERT(src);
  4902. if (cgraph->grads[isrc]) {
  4903. cgraph->grads[isrc] = ggml_acc_impl(ctx, cgraph->grads[isrc], tensor, nb1, nb2, nb3, offset, cgraph->grad_accs[isrc]);
  4904. } else {
  4905. struct ggml_tensor * a_zero = ggml_scale(ctx, src, 0.0f); // FIXME this is going to produce NaN if a contains inf/NaN
  4906. cgraph->grads[isrc] = ggml_acc_impl(ctx, a_zero, tensor, nb1, nb2, nb3, offset, false);
  4907. }
  4908. ggml_format_name(cgraph->grads[isrc], "grad for %s", cgraph->visited_hash_set.keys[isrc]->name);
  4909. ggml_build_forward_expand(cgraph, cgraph->grads[isrc]);
  4910. }
  4911. static void ggml_add1_or_set(
  4912. struct ggml_context * ctx,
  4913. struct ggml_cgraph * cgraph,
  4914. size_t isrc,
  4915. struct ggml_tensor * tensor) {
  4916. struct ggml_tensor * src = cgraph->visited_hash_set.keys[isrc];
  4917. GGML_ASSERT(src);
  4918. if (cgraph->grads[isrc]) {
  4919. cgraph->grads[isrc] = ggml_add1_impl(ctx, cgraph->grads[isrc], tensor, cgraph->grad_accs[isrc]);
  4920. } else {
  4921. cgraph->grads[isrc] = ggml_repeat(ctx, tensor, src);
  4922. }
  4923. ggml_format_name(cgraph->grads[isrc], "grad for %s", src->name);
  4924. ggml_build_forward_expand(cgraph, cgraph->grads[isrc]);
  4925. }
  4926. static void ggml_sub_or_set(
  4927. struct ggml_context * ctx,
  4928. struct ggml_cgraph * cgraph,
  4929. size_t isrc,
  4930. struct ggml_tensor * tensor) {
  4931. struct ggml_tensor * src = cgraph->visited_hash_set.keys[isrc];
  4932. GGML_ASSERT(src);
  4933. if (cgraph->grads[isrc]) {
  4934. cgraph->grads[isrc] = ggml_sub_impl(ctx, cgraph->grads[isrc], tensor, cgraph->grad_accs[isrc]);
  4935. } else {
  4936. cgraph->grads[isrc] = ggml_neg(ctx, tensor);
  4937. }
  4938. ggml_format_name(cgraph->grads[isrc], "grad for %s", src->name);
  4939. ggml_build_forward_expand(cgraph, cgraph->grads[isrc]);
  4940. }
  4941. static void ggml_compute_backward(
  4942. struct ggml_context * ctx, struct ggml_cgraph * cgraph, int i, const bool * grads_needed) {
  4943. struct ggml_tensor * tensor = cgraph->nodes[i];
  4944. struct ggml_tensor * grad = ggml_graph_get_grad(cgraph, tensor);
  4945. if (!grad) {
  4946. return;
  4947. }
  4948. struct ggml_tensor * src0 = tensor->src[0];
  4949. struct ggml_tensor * src1 = tensor->src[1];
  4950. struct ggml_tensor * src2 = tensor->src[2];
  4951. struct ggml_hash_set * hash_set = &cgraph->visited_hash_set;
  4952. const size_t isrc0 = src0 ? ggml_hash_find(hash_set, src0) : (size_t) -1;
  4953. const size_t isrc1 = src1 ? ggml_hash_find(hash_set, src1) : (size_t) -1;
  4954. const size_t isrc2 = src2 ? ggml_hash_find(hash_set, src2) : (size_t) -1;
  4955. const bool src0_needs_grads = src0 && isrc0 != GGML_HASHSET_FULL && ggml_bitset_get(hash_set->used, isrc0) && grads_needed[isrc0];
  4956. const bool src1_needs_grads = src1 && isrc1 != GGML_HASHSET_FULL && ggml_bitset_get(hash_set->used, isrc1) && grads_needed[isrc1];
  4957. const bool src2_needs_grads = src2 && isrc2 != GGML_HASHSET_FULL && ggml_bitset_get(hash_set->used, isrc2) && grads_needed[isrc2];
  4958. switch (tensor->op) {
  4959. case GGML_OP_DUP: {
  4960. if (src0_needs_grads) {
  4961. ggml_add_or_set(ctx, cgraph, isrc0, grad);
  4962. }
  4963. } break;
  4964. case GGML_OP_ADD: {
  4965. if (src0_needs_grads) {
  4966. ggml_add_or_set(ctx, cgraph, isrc0, grad);
  4967. }
  4968. if (src1_needs_grads) {
  4969. struct ggml_tensor * tmp = grad;
  4970. if (!ggml_are_same_shape(src0, src1)) {
  4971. tmp = ggml_repeat_back(ctx, tmp, src1);
  4972. }
  4973. ggml_add_or_set(ctx, cgraph, isrc1, tmp);
  4974. }
  4975. } break;
  4976. case GGML_OP_ADD1: {
  4977. if (src0_needs_grads) {
  4978. ggml_add_or_set(ctx, cgraph, isrc0, grad);
  4979. }
  4980. if (src1_needs_grads) {
  4981. ggml_add_or_set(ctx, cgraph, isrc1, ggml_mean(ctx, grad)); // TODO: should probably be sum instead of mean
  4982. }
  4983. } break;
  4984. case GGML_OP_ACC: {
  4985. if (src0_needs_grads) {
  4986. ggml_add_or_set(ctx, cgraph, isrc0, grad);
  4987. }
  4988. if (src1_needs_grads) {
  4989. const size_t nb1 = ((int32_t *) tensor->op_params)[0];
  4990. const size_t nb2 = ((int32_t *) tensor->op_params)[1];
  4991. const size_t nb3 = ((int32_t *) tensor->op_params)[2];
  4992. const size_t offset = ((int32_t *) tensor->op_params)[3];
  4993. struct ggml_tensor * tensor_grad_view = ggml_view_4d(ctx,
  4994. grad, src1->ne[0], src1->ne[1], src1->ne[2], src1->ne[3],
  4995. nb1, nb2, nb3, offset);
  4996. ggml_add_or_set(ctx, cgraph, isrc1, ggml_reshape(ctx, ggml_cont(ctx, tensor_grad_view), src1));
  4997. }
  4998. } break;
  4999. case GGML_OP_SUB: {
  5000. if (src0_needs_grads) {
  5001. ggml_add_or_set(ctx, cgraph, isrc0, grad);
  5002. }
  5003. if (src1_needs_grads) {
  5004. ggml_sub_or_set(ctx, cgraph, isrc1, grad);
  5005. }
  5006. } break;
  5007. case GGML_OP_MUL: {
  5008. if (src0_needs_grads) {
  5009. ggml_add_or_set(ctx, cgraph, isrc0, ggml_mul(ctx, grad, src1));
  5010. }
  5011. if (src1_needs_grads) {
  5012. struct ggml_tensor * tmp = ggml_mul(ctx, src0, grad);
  5013. if (!ggml_are_same_shape(src0, src1)) {
  5014. tmp = ggml_repeat_back(ctx, tmp, src1);
  5015. }
  5016. ggml_add_or_set(ctx, cgraph, isrc1, tmp);
  5017. }
  5018. } break;
  5019. case GGML_OP_DIV: {
  5020. if (src0_needs_grads) {
  5021. ggml_add_or_set(ctx, cgraph, isrc0, ggml_div(ctx, grad, src1));
  5022. }
  5023. if (src1_needs_grads) {
  5024. ggml_sub_or_set(ctx, cgraph, isrc1, ggml_mul(ctx, grad, ggml_div(ctx, tensor, src1)));
  5025. }
  5026. } break;
  5027. case GGML_OP_SQR: {
  5028. if (src0_needs_grads) {
  5029. ggml_add_or_set(ctx, cgraph, isrc0, ggml_scale(ctx, ggml_mul(ctx, src0, grad), 2.0f));
  5030. }
  5031. } break;
  5032. case GGML_OP_SQRT: {
  5033. if (src0_needs_grads) {
  5034. ggml_add_or_set(ctx, cgraph, isrc0, ggml_scale(ctx, ggml_div(ctx, grad, tensor), 0.5f));
  5035. }
  5036. } break;
  5037. case GGML_OP_LOG: {
  5038. if (src0_needs_grads) {
  5039. ggml_add_or_set(ctx, cgraph, isrc0, ggml_div(ctx, grad, src0));
  5040. }
  5041. } break;
  5042. case GGML_OP_SIN: {
  5043. if (src0_needs_grads) {
  5044. ggml_add_or_set(ctx, cgraph, isrc0, ggml_mul(ctx, grad, ggml_cos(ctx, src0)));
  5045. }
  5046. } break;
  5047. case GGML_OP_COS: {
  5048. if (src0_needs_grads) {
  5049. ggml_sub_or_set(ctx, cgraph, isrc0, ggml_mul(ctx, grad, ggml_sin(ctx, src0)));
  5050. }
  5051. } break;
  5052. case GGML_OP_SUM: {
  5053. if (src0_needs_grads) {
  5054. ggml_add1_or_set(ctx, cgraph, isrc0, grad);
  5055. }
  5056. } break;
  5057. case GGML_OP_SUM_ROWS: {
  5058. if (src0_needs_grads) {
  5059. ggml_add_or_set(ctx, cgraph, isrc0, ggml_repeat(ctx, grad, src0));
  5060. }
  5061. } break;
  5062. case GGML_OP_MEAN: {
  5063. if (src0_needs_grads) {
  5064. ggml_add1_or_set(ctx, cgraph, isrc0, ggml_scale_impl(ctx, grad, 1.0f/src0->ne[0], 0.0, false));
  5065. }
  5066. } break;
  5067. case GGML_OP_REPEAT: {
  5068. if (src0_needs_grads) {
  5069. ggml_add_or_set(ctx, cgraph, isrc0, ggml_repeat_back(ctx, grad, src0));
  5070. }
  5071. } break;
  5072. case GGML_OP_REPEAT_BACK: {
  5073. if (src0_needs_grads) {
  5074. ggml_add_or_set(ctx, cgraph, isrc0, ggml_repeat(ctx, grad, src0));
  5075. }
  5076. } break;
  5077. case GGML_OP_RMS_NORM: {
  5078. if (src0_needs_grads) {
  5079. float eps;
  5080. memcpy(&eps, tensor->op_params, sizeof(float));
  5081. ggml_add_or_set(ctx, cgraph, isrc0, ggml_rms_norm_back(ctx, grad, src0, eps));
  5082. }
  5083. } break;
  5084. case GGML_OP_MUL_MAT: {
  5085. // https://cs231n.github.io/optimization-2/#staged
  5086. // # forward pass
  5087. // s0 = np.random.randn(5, 10)
  5088. // s1 = np.random.randn(10, 3)
  5089. // t = s0.dot(s1)
  5090. // # now suppose we had the gradient on t from above in the circuit
  5091. // dt = np.random.randn(*t.shape) # same shape as t
  5092. // ds0 = dt.dot(s1.T) #.T gives the transpose of the matrix
  5093. // ds1 = t.T.dot(dt)
  5094. // tensor.shape [m,p,qq,rr]
  5095. // src0.shape [n,m,q1,r1]
  5096. // src1.shape [n,p,qq,rr]
  5097. if (src0_needs_grads) {
  5098. GGML_ASSERT(grad->ne[2] == src1->ne[2]);
  5099. GGML_ASSERT(grad->ne[3] == src1->ne[3]);
  5100. struct ggml_tensor * tmp =
  5101. ggml_out_prod(ctx, // [n,m,qq,rr]
  5102. src1, // [n,p,qq,rr]
  5103. grad); // [m,p,qq,rr]
  5104. if (!ggml_are_same_shape(tmp, src0)) {
  5105. GGML_ASSERT(tmp->ne[0] == src0->ne[0]);
  5106. GGML_ASSERT(tmp->ne[1] == src0->ne[1]);
  5107. GGML_ASSERT(tmp->ne[3] == 1);
  5108. const int64_t nr2 = tmp->ne[2] / src0->ne[2];
  5109. const size_t nb2 = tmp->nb[2] * nr2;
  5110. const size_t nb3 = tmp->nb[2];
  5111. tmp = ggml_view_4d(ctx, tmp, src0->ne[0], src0->ne[1], src0->ne[2], nr2, tmp->nb[1], nb2, nb3, 0);
  5112. tmp = ggml_repeat_back(ctx, tmp, src0);
  5113. }
  5114. ggml_add_or_set(ctx, cgraph, isrc0, tmp);
  5115. }
  5116. if (src1_needs_grads) {
  5117. ggml_add_or_set(ctx, cgraph, isrc1,
  5118. // ggml_mul_mat(ctx, // [n,p,qq,rr]
  5119. // ggml_cont(ctx, // [m,n,q1,r1]
  5120. // ggml_transpose(ctx, src0)), // [m,n,q1,r1]
  5121. // grad), // [m,p,qq,rr]
  5122. // when src0 is bigger than tensor->grad (this is mostly the case in llama),
5123. // avoid transposing src0; instead, transpose the smaller tensor->grad
  5124. // and then use ggml_out_prod
  5125. ggml_out_prod(ctx, // [n,p,qq,rr]
  5126. src0, // [n,m,q1,r1]
  5127. ggml_transpose(ctx, // [p,m,qq,rr]
  5128. grad))); // [m,p,qq,rr]
  5129. }
  5130. } break;
  5131. case GGML_OP_SCALE: {
  5132. if (src0_needs_grads) {
  5133. float s;
  5134. memcpy(&s, tensor->op_params, sizeof(float));
  5135. ggml_add_or_set(ctx, cgraph, isrc0, ggml_scale_impl(ctx, grad, s, 0.0, false));
  5136. }
  5137. } break;
  5138. case GGML_OP_SET: {
  5139. const size_t nb1 = ((const int32_t *) tensor->op_params)[0];
  5140. const size_t nb2 = ((const int32_t *) tensor->op_params)[1];
  5141. const size_t nb3 = ((const int32_t *) tensor->op_params)[2];
  5142. const size_t offset = ((const int32_t *) tensor->op_params)[3];
  5143. struct ggml_tensor * tensor_grad_view = NULL;
  5144. if (src0_needs_grads || src1_needs_grads) {
  5145. GGML_ASSERT(src0->type == tensor->type);
  5146. GGML_ASSERT(!cgraph->grads[isrc0] || cgraph->grads[isrc0]->type == grad->type);
  5147. GGML_ASSERT(!cgraph->grads[isrc1] || !src1_needs_grads || cgraph->grads[isrc1]->type == grad->type);
  5148. tensor_grad_view = ggml_view_4d(ctx,
  5149. grad, src1->ne[0], src1->ne[1], src1->ne[2], src1->ne[3],
  5150. nb1, nb2, nb3, offset);
  5151. }
  5152. if (src0_needs_grads) {
  5153. struct ggml_tensor * tmp = ggml_neg(ctx, tensor_grad_view);
  5154. ggml_add_or_set(ctx, cgraph, isrc0, ggml_acc_impl(ctx, grad, tmp, nb1, nb2, nb3, offset, false));
  5155. }
  5156. if (src1_needs_grads) {
  5157. ggml_add_or_set(ctx, cgraph, isrc1, ggml_reshape(ctx, ggml_cont(ctx, tensor_grad_view), src1));
  5158. }
  5159. } break;
  5160. case GGML_OP_CPY: {
  5161. // cpy overwrites value of src1 by src0 and returns view(src1)
  5162. // the overwriting is mathematically equivalent to:
  5163. // tensor = src0 * 1 + src1 * 0
  5164. if (src0_needs_grads) {
  5165. // dsrc0 = dtensor * 1
  5166. ggml_add_or_set(ctx, cgraph, isrc0, ggml_reshape(ctx, grad, src0));
  5167. }
  5168. if (src1_needs_grads) {
  5169. // dsrc1 = dtensor * 0 -> noop
  5170. }
  5171. } break;
  5172. case GGML_OP_CONT: {
  5173. // same as cpy
  5174. if (src0_needs_grads) {
  5175. GGML_ASSERT(!cgraph->grads[isrc0] || ggml_is_contiguous(cgraph->grads[isrc0]));
  5176. GGML_ASSERT(ggml_is_contiguous(grad));
  5177. GGML_ASSERT(ggml_nelements(tensor) == ggml_nelements(src0));
  5178. ggml_add_or_set(ctx, cgraph, isrc0,
  5179. ggml_are_same_shape(tensor, src0) ? grad : ggml_reshape(ctx, grad, src0));
  5180. }
  5181. } break;
  5182. case GGML_OP_RESHAPE: {
  5183. if (src0_needs_grads) {
  5184. struct ggml_tensor * grad_cont = ggml_is_contiguous(grad) ? grad : ggml_cont(ctx, grad);
  5185. ggml_add_or_set(ctx, cgraph, isrc0, ggml_reshape(ctx, grad_cont, src0));
  5186. }
  5187. } break;
  5188. case GGML_OP_VIEW: {
  5189. if (src0_needs_grads) {
  5190. size_t offset;
  5191. memcpy(&offset, tensor->op_params, sizeof(offset));
  5192. size_t nb1 = tensor->nb[1];
  5193. size_t nb2 = tensor->nb[2];
  5194. size_t nb3 = tensor->nb[3];
  5195. if (cgraph->grads[isrc0] && src0->type != cgraph->grads[isrc0]->type) {
5196. // gradient is typically F32, but src0 could be of another type
  5197. size_t ng = ggml_element_size(cgraph->grads[isrc0]);
  5198. size_t n0 = ggml_element_size(src0);
  5199. GGML_ASSERT(offset % n0 == 0);
  5200. GGML_ASSERT(nb1 % n0 == 0);
  5201. GGML_ASSERT(nb2 % n0 == 0);
  5202. GGML_ASSERT(nb3 % n0 == 0);
  5203. offset = (offset / n0) * ng;
  5204. nb1 = (nb1 / n0) * ng;
  5205. nb2 = (nb2 / n0) * ng;
  5206. nb3 = (nb3 / n0) * ng;
  5207. }
  5208. ggml_acc_or_set(ctx, cgraph, isrc0, grad, nb1, nb2, nb3, offset);
  5209. }
  5210. } break;
  5211. case GGML_OP_PERMUTE: {
  5212. if (src0_needs_grads) {
  5213. const int32_t * axes = (const int32_t *) tensor->op_params;
  5214. const int axis0 = axes[0] & 0x3;
  5215. const int axis1 = axes[1] & 0x3;
  5216. const int axis2 = axes[2] & 0x3;
  5217. const int axis3 = axes[3] & 0x3;
  5218. int axb[4] = {0,0,0,0}; // axes backward
  5219. axb[axis0] = 0;
  5220. axb[axis1] = 1;
  5221. axb[axis2] = 2;
  5222. axb[axis3] = 3;
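// editor's note (illustrative): axb is the inverse permutation, e.g. for a forward
// ggml_permute(ctx, a, 2, 0, 1, 3) the assignments above yield axb = {1, 2, 0, 3}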
  5223. ggml_add_or_set(ctx, cgraph, isrc0, ggml_permute(ctx, grad, axb[0], axb[1], axb[2], axb[3]));
  5224. }
  5225. } break;
  5226. case GGML_OP_TRANSPOSE: {
  5227. if (src0_needs_grads) {
  5228. ggml_add_or_set(ctx, cgraph, isrc0, ggml_transpose(ctx, grad));
  5229. }
  5230. } break;
  5231. case GGML_OP_GET_ROWS: {
  5232. if (src0_needs_grads) {
  5233. ggml_add_or_set(ctx, cgraph, isrc0, ggml_get_rows_back(ctx, grad, src1, src0));
  5234. }
  5235. if (src1_needs_grads) {
  5236. // noop
  5237. }
  5238. } break;
  5239. case GGML_OP_DIAG_MASK_INF: {
  5240. if (src0_needs_grads) {
  5241. /* ggml_diag_mask_inf_impl() shouldn't be here */
  5242. /* ref: https://github.com/ggerganov/llama.cpp/pull/4203#discussion_r1412377992 */
  5243. const int n_past = ((const int32_t *) tensor->op_params)[0];
  5244. ggml_add_or_set(ctx, cgraph, isrc0, ggml_diag_mask_zero_impl(ctx, grad, n_past, false));
  5245. }
  5246. } break;
  5247. case GGML_OP_DIAG_MASK_ZERO: {
  5248. if (src0_needs_grads) {
  5249. const int n_past = ((const int32_t *) tensor->op_params)[0];
  5250. ggml_add_or_set(ctx, cgraph, isrc0, ggml_diag_mask_zero_impl(ctx, grad, n_past, false));
  5251. }
  5252. } break;
  5253. case GGML_OP_SOFT_MAX: {
  5254. if (src0_needs_grads) {
  5255. float scale = 1.0f;
  5256. float max_bias = 0.0f;
  5257. memcpy(&scale, (const float *) tensor->op_params + 0, sizeof(float));
  5258. memcpy(&max_bias, (const float *) tensor->op_params + 1, sizeof(float));
  5259. ggml_add_or_set(ctx, cgraph, isrc0, ggml_soft_max_ext_back(ctx, grad, tensor, scale, max_bias));
  5260. }
  5261. GGML_ASSERT((!src1 || !src1_needs_grads) && "backward pass for softmax mask not implemented");
  5262. } break;
  5263. case GGML_OP_ROPE: {
  5264. if (src0_needs_grads) {
  5265. //const int n_past = ((int32_t *) tensor->op_params)[0];
  5266. const int n_dims = ((const int32_t *) tensor->op_params)[1];
  5267. const int mode = ((const int32_t *) tensor->op_params)[2];
  5268. //const int n_ctx = ((int32_t *) tensor->op_params)[3];
  5269. const int n_ctx_orig = ((const int32_t *) tensor->op_params)[4];
  5270. float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
  5271. int sections[4] = {0, 0, 0, 0};
  5272. memcpy(&freq_base, (const float *) tensor->op_params + 5, sizeof(float));
  5273. memcpy(&freq_scale, (const float *) tensor->op_params + 6, sizeof(float));
  5274. memcpy(&ext_factor, (const float *) tensor->op_params + 7, sizeof(float));
  5275. memcpy(&attn_factor, (const float *) tensor->op_params + 8, sizeof(float));
  5276. memcpy(&beta_fast, (const float *) tensor->op_params + 9, sizeof(float));
  5277. memcpy(&beta_slow, (const float *) tensor->op_params + 10, sizeof(float));
  5278. memcpy(&sections, tensor->op_params + 11, sizeof(sections));
                struct ggml_tensor * rope_back = grad->ne[2] == src1->ne[0] ?
                    ggml_rope_ext_back(ctx, grad, src1, src2, n_dims,
                        mode, n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow) :
                    ggml_rope_multi_back(ctx, grad, src1, src2, n_dims, sections,
                        mode, n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow);
                ggml_add_or_set(ctx, cgraph, isrc0, rope_back);
            }
            GGML_ASSERT((!src2 || !src2_needs_grads) && "gradients for freq factors not implemented");
        } break;
        case GGML_OP_IM2COL: {
            if (src1_needs_grads) {
                const int32_t s0 = ggml_get_op_params_i32(tensor, 0);
                const int32_t s1 = ggml_get_op_params_i32(tensor, 1);
                const int32_t p0 = ggml_get_op_params_i32(tensor, 2);
                const int32_t p1 = ggml_get_op_params_i32(tensor, 3);
                const int32_t d0 = ggml_get_op_params_i32(tensor, 4);
                const int32_t d1 = ggml_get_op_params_i32(tensor, 5);
                const bool is_2D = ggml_get_op_params_i32(tensor, 6) == 1;
                ggml_add_or_set(ctx, cgraph, isrc1, ggml_im2col_back(ctx, grad, src0, src1->ne, s0, s1, p0, p1, d0, d1, is_2D));
            }
        } break;
        case GGML_OP_POOL_2D: {
            if (src0_needs_grads) {
                const enum ggml_op_pool op = ggml_get_op_params_i32(tensor, 0);
                const int32_t k0 = ggml_get_op_params_i32(tensor, 1);
                const int32_t k1 = ggml_get_op_params_i32(tensor, 2);
                const int32_t s0 = ggml_get_op_params_i32(tensor, 3);
                const int32_t s1 = ggml_get_op_params_i32(tensor, 4);
                const int32_t p0 = ggml_get_op_params_i32(tensor, 5);
                const int32_t p1 = ggml_get_op_params_i32(tensor, 6);
                ggml_add_or_set(ctx, cgraph, isrc0, ggml_pool_2d_back(ctx, grad, src0, op, k0, k1, s0, s1, p0, p1));
            }
        } break;
        case GGML_OP_WIN_PART:
        case GGML_OP_WIN_UNPART:
        case GGML_OP_UNARY: {
            switch (ggml_get_unary_op(tensor)) {
                case GGML_UNARY_OP_ABS: {
                    if (src0_needs_grads) {
                        ggml_add_or_set(ctx, cgraph, isrc0, ggml_mul(ctx, ggml_sgn(ctx, src0), grad));
                    }
                } break;
                case GGML_UNARY_OP_SGN: {
                    // noop
                } break;
                case GGML_UNARY_OP_NEG: {
                    if (src0_needs_grads) {
                        ggml_sub_or_set(ctx, cgraph, isrc0, grad);
                    }
                } break;
                case GGML_UNARY_OP_STEP: {
                    // noop
                } break;
                case GGML_UNARY_OP_RELU: {
                    if (src0_needs_grads) {
                        ggml_add_or_set(ctx, cgraph, isrc0, ggml_mul(ctx, ggml_step(ctx, src0), grad));
                    }
                } break;
                case GGML_UNARY_OP_SILU: {
                    if (src0_needs_grads) {
                        ggml_add_or_set(ctx, cgraph, isrc0, ggml_silu_back(ctx, grad, src0));
                    }
                } break;
                case GGML_UNARY_OP_EXP: {
                    if (src0_needs_grads) {
                        ggml_add_or_set(ctx, cgraph, isrc0, ggml_mul(ctx, tensor, grad));
                    }
                } break;
                default: {
                    fprintf(stderr, "%s: unsupported unary op for backward pass: %s\n",
                        __func__, ggml_unary_op_name(ggml_get_unary_op(tensor)));
                    GGML_ABORT("fatal error");
                } //break;
            }
        } break;
        case GGML_OP_CROSS_ENTROPY_LOSS: {
            if (src0_needs_grads) {
                ggml_add_or_set(ctx, cgraph, isrc0, ggml_cross_entropy_loss_back(ctx, grad, src0, src1));
            }
            GGML_ASSERT(!src1_needs_grads && "backward pass for labels not implemented");
        } break;
        case GGML_OP_GLU: {
            switch (ggml_get_glu_op(tensor)) {
                case GGML_GLU_OP_SWIGLU: {
                    if (src0_needs_grads) {
                        GGML_ASSERT(src1 && "backward pass only implemented for split swiglu");
                        ggml_add_or_set(ctx, cgraph, isrc0, ggml_silu_back(ctx, ggml_mul(ctx, grad, src1), src0));
                    }
                    if (src1_needs_grads) {
                        ggml_add_or_set(ctx, cgraph, isrc1, ggml_mul(ctx, ggml_silu(ctx, src0), grad));
                    }
                } break;
                default: {
                    GGML_ABORT("unsupported glu op for backward pass: %s", ggml_glu_op_name(ggml_get_glu_op(tensor)));
                } //break;
            }
        } break;
        case GGML_OP_NONE: {
            // noop
        } break;
        case GGML_OP_COUNT:
        default: {
            GGML_ABORT("%s: unsupported ggml op for backward pass: %s\n", __func__, ggml_op_name(tensor->op));
        } //break;
    }

    GGML_ASSERT(!src0_needs_grads || ggml_are_same_shape(src0, cgraph->grads[isrc0]));
    GGML_ASSERT(!src1_needs_grads || ggml_are_same_shape(src1, cgraph->grads[isrc1]));
    GGML_ASSERT(!src2_needs_grads || ggml_are_same_shape(src2, cgraph->grads[isrc2]));
}
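
// Illustrative note (not part of the upstream logic, just a summary of the cases above): each case
// applies the chain rule for one op and accumulates d(loss)/d(src) into cgraph->grads via
// ggml_add_or_set / ggml_sub_or_set. For y = exp(x) the local derivative equals the forward output,
// so the EXP case multiplies the incoming gradient by `tensor` itself; in pseudo-notation:
//
//     grad[src0] += tensor * grad[tensor]   // ggml_add_or_set(ctx, cgraph, isrc0, ggml_mul(ctx, tensor, grad))
//
// Similarly, RELU uses ggml_step(src0) as the local derivative and NEG simply subtracts the incoming gradient.
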
static size_t ggml_visit_parents(struct ggml_cgraph * cgraph, struct ggml_tensor * node) {
    // check if already visited
    size_t node_hash_pos = ggml_hash_find(&cgraph->visited_hash_set, node);
    GGML_ASSERT(node_hash_pos != GGML_HASHSET_FULL);
    if (!ggml_bitset_get(cgraph->visited_hash_set.used, node_hash_pos)) {
        // This is the first time we see this node in the current graph.
        cgraph->visited_hash_set.keys[node_hash_pos] = node;
        ggml_bitset_set(cgraph->visited_hash_set.used, node_hash_pos);
        cgraph->use_counts[node_hash_pos] = 0;
    } else {
        // already visited
        return node_hash_pos;
    }

    for (int i = 0; i < GGML_MAX_SRC; ++i) {
        const int k =
            (cgraph->order == GGML_CGRAPH_EVAL_ORDER_LEFT_TO_RIGHT) ? i :
            (cgraph->order == GGML_CGRAPH_EVAL_ORDER_RIGHT_TO_LEFT) ? (GGML_MAX_SRC-1-i) :
            /* unknown order, just fall back to using i */ i;

        struct ggml_tensor * src = node->src[k];
        if (src) {
            size_t src_hash_pos = ggml_visit_parents(cgraph, src);

            // Update the use count for this operand.
            cgraph->use_counts[src_hash_pos]++;
        }
    }

    if (node->op == GGML_OP_NONE && !(node->flags & GGML_TENSOR_FLAG_PARAM)) {
        // reached a leaf node, not part of the gradient graph (e.g. a constant)
        GGML_ASSERT(cgraph->n_leafs < cgraph->size);

        if (strlen(node->name) == 0) {
            ggml_format_name(node, "leaf_%d", cgraph->n_leafs);
        }

        cgraph->leafs[cgraph->n_leafs] = node;
        cgraph->n_leafs++;
    } else {
        GGML_ASSERT(cgraph->n_nodes < cgraph->size);

        if (strlen(node->name) == 0) {
            ggml_format_name(node, "node_%d", cgraph->n_nodes);
        }

        cgraph->nodes[cgraph->n_nodes] = node;
        cgraph->n_nodes++;
    }

    return node_hash_pos;
}
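
// Illustrative note: ggml_visit_parents() is a depth-first post-order traversal, so a tensor is
// appended to cgraph->nodes only after all of its sources - the resulting node order is therefore a
// valid topological (evaluation) order. Tensors with no op and no PARAM flag go to cgraph->leafs
// instead, and use_counts[] ends up holding, per hash slot, how many times a tensor appears as a
// source. A minimal sketch of the result, assuming c = ggml_add(ctx, a, b) with constants a and b:
//
//     leafs: a, b        nodes: c        use_counts: a -> 1, b -> 1, c -> 0
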
static void ggml_build_forward_impl(struct ggml_cgraph * cgraph, struct ggml_tensor * tensor, bool expand) {
    if (!expand) {
        // TODO: this branch isn't accessible anymore, maybe move this to ggml_build_forward_expand
        ggml_graph_clear(cgraph);
    }

    const int n0 = cgraph->n_nodes;

    ggml_visit_parents(cgraph, tensor);

    const int n_new = cgraph->n_nodes - n0;
    GGML_PRINT_DEBUG("%s: visited %d new nodes\n", __func__, n_new);

    if (n_new > 0) {
        // the last added node should always be starting point
        GGML_ASSERT(cgraph->nodes[cgraph->n_nodes - 1] == tensor);
    }
}

void ggml_build_forward_expand(struct ggml_cgraph * cgraph, struct ggml_tensor * tensor) {
    ggml_build_forward_impl(cgraph, tensor, true);
}

void ggml_build_backward_expand(
        struct ggml_context * ctx,
        struct ggml_cgraph * cgraph,
        struct ggml_tensor ** grad_accs) {
    GGML_ASSERT(cgraph->n_nodes > 0);
    GGML_ASSERT(cgraph->grads);
    GGML_ASSERT(cgraph->grad_accs);

    const int n_nodes_f = cgraph->n_nodes;

    memset(cgraph->grads, 0, cgraph->visited_hash_set.size*sizeof(struct ggml_tensor *));
    memset(cgraph->grad_accs, 0, cgraph->visited_hash_set.size*sizeof(struct ggml_tensor *));
    bool * grads_needed = calloc(cgraph->visited_hash_set.size, sizeof(bool));

    {
        bool any_params = false;
        bool any_loss = false;
        for (int i = 0; i < n_nodes_f; ++i) {
            struct ggml_tensor * node = cgraph->nodes[i];
            any_params = any_params || (node->flags & GGML_TENSOR_FLAG_PARAM);
            any_loss = any_loss || (node->flags & GGML_TENSOR_FLAG_LOSS);
        }
        GGML_ASSERT(any_params && "no trainable parameters found, did you forget to call ggml_set_param?");
        GGML_ASSERT(any_loss && "no training loss found, did you forget to call ggml_set_loss?");
    }

    for (int i = 0; i < n_nodes_f; ++i) {
        struct ggml_tensor * node = cgraph->nodes[i];

        if (node->type == GGML_TYPE_I32) {
            continue;
        }

        bool node_needs_grad = (node->flags & GGML_TENSOR_FLAG_PARAM) || (node->flags & GGML_TENSOR_FLAG_LOSS);
        bool ignore_src[GGML_MAX_SRC] = {false};
        switch (node->op) {
            // gradients in node->src[0] for one reason or another have no effect on output gradients
            case GGML_OP_IM2COL:      // only used for its shape
            case GGML_OP_IM2COL_BACK: // same as IM2COL
                ignore_src[0] = true;
                break;
            case GGML_OP_UNARY: {
                const enum ggml_unary_op uop = ggml_get_unary_op(node);
                // SGN and STEP unary ops are piecewise constant
                if (uop == GGML_UNARY_OP_SGN || uop == GGML_UNARY_OP_STEP) {
                    ignore_src[0] = true;
                }
            } break;

            // gradients in node->src[1] for one reason or another have no effect on output gradients
            case GGML_OP_CPY:           // gradients in CPY target are irrelevant
            case GGML_OP_GET_ROWS:      // row indices not differentiable
            case GGML_OP_GET_ROWS_BACK: // same as for GET_ROWS
            case GGML_OP_ROPE:          // positions not differentiable
                ignore_src[1] = true;
                break;

            default:
                break;
        }
        for (int j = 0; j < GGML_MAX_SRC; ++j) {
            if (!node->src[j] || ignore_src[j] || !grads_needed[ggml_hash_find(&cgraph->visited_hash_set, node->src[j])]) {
                continue;
            }
            GGML_ASSERT(node->src[j]->type == GGML_TYPE_F32 || node->src[j]->type == GGML_TYPE_F16);
            node_needs_grad = true;
            break;
        }
        if (!node_needs_grad) {
            continue;
        }

        // inplace operations are currently not supported
        GGML_ASSERT(!node->view_src || node->op == GGML_OP_CPY || node->op == GGML_OP_VIEW ||
            node->op == GGML_OP_RESHAPE || node->op == GGML_OP_PERMUTE || node->op == GGML_OP_TRANSPOSE);

        const size_t ihash = ggml_hash_find(&cgraph->visited_hash_set, node);
        GGML_ASSERT(ihash != GGML_HASHSET_FULL);
        GGML_ASSERT(ggml_bitset_get(cgraph->visited_hash_set.used, ihash));
        if (grad_accs && grad_accs[i]) {
            cgraph->grad_accs[ihash] = grad_accs[i];
            cgraph->grads[ihash] = cgraph->grad_accs[ihash];
        } else if (node->flags & GGML_TENSOR_FLAG_LOSS) {
            // loss tensors always need a gradient accumulator
            cgraph->grad_accs[ihash] = ggml_new_tensor(ctx, GGML_TYPE_F32, GGML_MAX_DIMS, node->ne);
            cgraph->grads[ihash] = cgraph->grad_accs[ihash];
        }
        grads_needed[ihash] = true;
    }

    for (int i = n_nodes_f - 1; i >= 0; --i) {
        // inplace operations to add gradients are not created by ggml_compute_backward except for gradient accumulation
        // use allocator to automatically make inplace operations
        ggml_compute_backward(ctx, cgraph, i, grads_needed);
    }

    free(grads_needed);
}
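
// Minimal usage sketch (hypothetical tensors and sizes, assuming a context `ctx` with enough memory
// and no_alloc == false). The sequence mirrors the assertions above: mark parameters, mark a scalar
// F32 loss, expand the forward graph into a grads-enabled cgraph, then expand the backward graph:
//
//     struct ggml_tensor * w    = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_in, n_out);
//     ggml_set_param(w);                                     // trainable -> GGML_TENSOR_FLAG_PARAM
//     struct ggml_tensor * x    = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_in);
//     struct ggml_tensor * loss = ggml_sum(ctx, ggml_mul_mat(ctx, w, x));
//     ggml_set_loss(loss);                                   // scalar F32 -> GGML_TENSOR_FLAG_LOSS
//
//     struct ggml_cgraph * gb = ggml_new_graph_custom(ctx, GGML_DEFAULT_GRAPH_SIZE, /*grads =*/ true);
//     ggml_build_forward_expand(gb, loss);
//     ggml_build_backward_expand(ctx, gb, /*grad_accs =*/ NULL);
//
//     ggml_graph_reset(gb);                                  // seeds d(loss)/d(loss) = 1.0f
//     struct ggml_tensor * gw = ggml_graph_get_grad(gb, w);  // gradient tensor for w (valid after computing gb)
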
static void * incr_ptr_aligned(void ** p, size_t size, size_t align) {
    void * ptr = *p;
    ptr = (void *) GGML_PAD((uintptr_t) ptr, align);
    *p = (void *) ((char *) ptr + size);
    return ptr;
}

static size_t ggml_graph_nbytes(size_t size, bool grads) {
    size_t hash_size = ggml_hash_size(size * 2);
    void * p = 0;
    incr_ptr_aligned(&p, sizeof(struct ggml_cgraph), 1);
    incr_ptr_aligned(&p, size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *)); // nodes
    incr_ptr_aligned(&p, size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *)); // leafs
    incr_ptr_aligned(&p, hash_size * sizeof(int32_t), sizeof(int32_t)); // use_counts
    incr_ptr_aligned(&p, hash_size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *)); // hash keys
    if (grads) {
        incr_ptr_aligned(&p, hash_size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *)); // grads
        incr_ptr_aligned(&p, hash_size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *)); // grad_accs
    }
    incr_ptr_aligned(&p, ggml_bitset_size(hash_size) * sizeof(ggml_bitset_t), sizeof(ggml_bitset_t));

    size_t nbytes = (size_t) p;
    return nbytes;
}

size_t ggml_graph_overhead_custom(size_t size, bool grads) {
    return GGML_OBJECT_SIZE + GGML_PAD(ggml_graph_nbytes(size, grads), GGML_MEM_ALIGN);
}

size_t ggml_graph_overhead(void) {
    return ggml_graph_overhead_custom(GGML_DEFAULT_GRAPH_SIZE, false);
}

struct ggml_cgraph * ggml_new_graph_custom(struct ggml_context * ctx, size_t size, bool grads) {
    const size_t obj_size = ggml_graph_nbytes(size, grads);
    struct ggml_object * obj = ggml_new_object(ctx, GGML_OBJECT_TYPE_GRAPH, obj_size);
    struct ggml_cgraph * cgraph = (struct ggml_cgraph *) ((char *) ctx->mem_buffer + obj->offs);

    // the size of the hash table is doubled since it needs to hold both nodes and leafs
    size_t hash_size = ggml_hash_size(size * 2);

    void * p = cgraph + 1;

    struct ggml_tensor ** nodes_ptr = incr_ptr_aligned(&p, size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *));
    struct ggml_tensor ** leafs_ptr = incr_ptr_aligned(&p, size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *));
    int32_t * use_counts_ptr = incr_ptr_aligned(&p, hash_size * sizeof(int32_t), sizeof(int32_t));
    struct ggml_tensor ** hash_keys_ptr = incr_ptr_aligned(&p, hash_size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *));
    struct ggml_tensor ** grads_ptr = grads ? incr_ptr_aligned(&p, hash_size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *)) : NULL;
    struct ggml_tensor ** grad_accs_ptr = grads ? incr_ptr_aligned(&p, hash_size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *)) : NULL;

    ggml_bitset_t * hash_used = incr_ptr_aligned(&p, ggml_bitset_size(hash_size) * sizeof(ggml_bitset_t), sizeof(ggml_bitset_t));

    // check that we allocated the correct amount of memory
    assert(obj_size == (size_t)((char *)p - (char *)cgraph));

    *cgraph = (struct ggml_cgraph) {
        /*.size =*/ size,
        /*.n_nodes =*/ 0,
        /*.n_leafs =*/ 0,
        /*.nodes =*/ nodes_ptr,
        /*.grads =*/ grads_ptr,
        /*.grad_accs =*/ grad_accs_ptr,
        /*.leafs =*/ leafs_ptr,
        /*.use_counts =*/ use_counts_ptr,
        /*.hash_table =*/ { hash_size, hash_used, hash_keys_ptr },
        /*.order =*/ GGML_CGRAPH_EVAL_ORDER_LEFT_TO_RIGHT,
    };

    ggml_hash_set_reset(&cgraph->visited_hash_set);
    if (grads) {
        memset(cgraph->grads, 0, hash_size*sizeof(struct ggml_tensor *));
        memset(cgraph->grad_accs, 0, hash_size*sizeof(struct ggml_tensor *));
    }

    return cgraph;
}
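
// Sizing sketch (assumed, illustrative numbers): the graph object lives inside the ggml context, so
// the context has to be created with at least ggml_graph_overhead_custom() bytes reserved for it on
// top of whatever the tensors themselves need, e.g.:
//
//     struct ggml_init_params params = {
//         /*.mem_size   =*/ ggml_graph_overhead_custom(4096, /*grads =*/ false) + 256*ggml_tensor_overhead() + 1024*1024,
//         /*.mem_buffer =*/ NULL,
//         /*.no_alloc   =*/ false,
//     };
//     struct ggml_context * ctx = ggml_init(params);
//     struct ggml_cgraph  * gf  = ggml_new_graph_custom(ctx, 4096, /*grads =*/ false);
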
struct ggml_cgraph * ggml_new_graph(struct ggml_context * ctx) {
    return ggml_new_graph_custom(ctx, GGML_DEFAULT_GRAPH_SIZE, false);
}

struct ggml_cgraph ggml_graph_view(struct ggml_cgraph * cgraph0, int i0, int i1) {
    struct ggml_cgraph cgraph = {
        /*.size =*/ 0,
        /*.n_nodes =*/ i1 - i0,
        /*.n_leafs =*/ 0,
        /*.nodes =*/ cgraph0->nodes + i0,
        /*.grads =*/ NULL, // gradients would need visited_hash_set
        /*.grad_accs =*/ NULL,
        /*.leafs =*/ NULL,
        /*.use_counts =*/ cgraph0->use_counts,
        /*.visited_hash_set =*/ cgraph0->visited_hash_set,
        /*.order =*/ cgraph0->order,
    };

    return cgraph;
}
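
// Usage sketch: a view shares the parent's node array, use counts and hash set, so it is only valid
// while the parent graph is alive and unchanged; e.g. taking the first half of a graph for inspection:
//
//     struct ggml_cgraph first_half = ggml_graph_view(gf, 0, ggml_graph_n_nodes(gf)/2);
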
void ggml_graph_cpy(struct ggml_cgraph * src, struct ggml_cgraph * dst) {
    GGML_ASSERT(dst->size >= src->n_leafs);
    GGML_ASSERT(dst->size >= src->n_nodes);
    GGML_ASSERT(dst->visited_hash_set.size >= src->visited_hash_set.size);

    dst->n_leafs = src->n_leafs;
    dst->n_nodes = src->n_nodes;
    dst->order = src->order;

    for (int i = 0; i < src->n_leafs; ++i) {
        dst->leafs[i] = src->leafs[i];
    }

    for (int i = 0; i < src->n_nodes; ++i) {
        dst->nodes[i] = src->nodes[i];
    }

    for (size_t i = 0; i < src->visited_hash_set.size; ++i) {
        // copy all hashset keys (tensors) that are in use
        if (ggml_bitset_get(src->visited_hash_set.used, i)) {
            size_t new_hash_pos = ggml_hash_insert(&dst->visited_hash_set, src->visited_hash_set.keys[i]);
            dst->use_counts[new_hash_pos] = src->use_counts[i];
        }
    }

    if (dst->grads) {
        memset(dst->grads, 0, dst->visited_hash_set.size*sizeof(struct ggml_tensor *));
        memset(dst->grad_accs, 0, dst->visited_hash_set.size*sizeof(struct ggml_tensor *));
    }
    if (src->grads) {
        GGML_ASSERT(dst->grads != NULL);
        GGML_ASSERT(dst->grad_accs != NULL);
        for (int i = 0; i < src->n_nodes; ++i) {
            const size_t igrad_src = ggml_hash_find(&src->visited_hash_set, src->nodes[i]);
            const size_t igrad_dst = ggml_hash_find(&dst->visited_hash_set, dst->nodes[i]);

            GGML_ASSERT(igrad_src != GGML_HASHSET_FULL);
            GGML_ASSERT(ggml_bitset_get(src->visited_hash_set.used, igrad_src));
            GGML_ASSERT(igrad_dst != GGML_HASHSET_FULL);
            GGML_ASSERT(ggml_bitset_get(dst->visited_hash_set.used, igrad_dst));

            dst->grads[igrad_dst] = src->grads[igrad_src];
            dst->grad_accs[igrad_dst] = src->grad_accs[igrad_src];
        }
    }
}

struct ggml_cgraph * ggml_graph_dup(struct ggml_context * ctx, struct ggml_cgraph * cgraph, bool force_grads) {
    struct ggml_cgraph * result = ggml_new_graph_custom(ctx, cgraph->size, cgraph->grads || force_grads);
    ggml_graph_cpy(cgraph, result);
    return result;
}

struct ggml_tensor * ggml_set_zero(struct ggml_tensor * tensor) {
    if (ggml_is_empty(tensor)) {
        return tensor;
    }
    if (tensor->buffer) {
        ggml_backend_tensor_memset(tensor, 0, 0, ggml_nbytes(tensor));
    } else {
        GGML_ASSERT(tensor->data);
        memset(tensor->data, 0, ggml_nbytes(tensor));
    }
    return tensor;
}

void ggml_graph_reset(struct ggml_cgraph * cgraph) {
    if (!cgraph) {
        return;
    }
    GGML_ASSERT(cgraph->grads != NULL);

    for (int i = 0; i < cgraph->n_nodes; i++) {
        struct ggml_tensor * node = cgraph->nodes[i];
        struct ggml_tensor * grad_acc = ggml_graph_get_grad_acc(cgraph, node);

        if (node->op == GGML_OP_OPT_STEP_ADAMW) {
            // clear momenta
            ggml_set_zero(node->src[2]);
            ggml_set_zero(node->src[3]);
        }

        // initial gradients of loss should be 1, 0 otherwise
        if (grad_acc) {
            if (node->flags & GGML_TENSOR_FLAG_LOSS) {
                GGML_ASSERT(grad_acc->type == GGML_TYPE_F32);
                GGML_ASSERT(ggml_is_scalar(grad_acc));

                const float onef = 1.0f;
                if (grad_acc->buffer) {
                    ggml_backend_tensor_set(grad_acc, &onef, 0, sizeof(float));
                } else {
                    GGML_ASSERT(grad_acc->data);
                    *((float *) grad_acc->data) = onef;
                }
            } else {
                ggml_set_zero(grad_acc);
            }
        }
    }
}
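
// Typical per-step usage (sketch, assuming a backward graph `gb` built as above and some compute
// routine provided by the backend in use):
//
//     for (int step = 0; step < n_steps; ++step) {
//         ggml_graph_reset(gb);  // loss gradient <- 1.0f, all other gradient accumulators <- 0
//         /* evaluate gb with the chosen backend, then apply the optimizer update */
//     }
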
void ggml_graph_clear(struct ggml_cgraph * cgraph) {
    cgraph->n_leafs = 0;
    cgraph->n_nodes = 0;
    ggml_hash_set_reset(&cgraph->visited_hash_set);
}

int ggml_graph_size(struct ggml_cgraph * cgraph) {
    return cgraph->size;
}

struct ggml_tensor * ggml_graph_node(struct ggml_cgraph * cgraph, int i) {
    if (i < 0) {
        GGML_ASSERT(cgraph->n_nodes + i >= 0);
        return cgraph->nodes[cgraph->n_nodes + i];
    }

    GGML_ASSERT(i < cgraph->n_nodes);
    return cgraph->nodes[i];
}

struct ggml_tensor ** ggml_graph_nodes(struct ggml_cgraph * cgraph) {
    return cgraph->nodes;
}

int ggml_graph_n_nodes(struct ggml_cgraph * cgraph) {
    return cgraph->n_nodes;
}

void ggml_graph_add_node(struct ggml_cgraph * cgraph, struct ggml_tensor * tensor) {
    GGML_ASSERT(cgraph->size > cgraph->n_nodes);
    cgraph->nodes[cgraph->n_nodes] = tensor;
    cgraph->n_nodes++;
}

struct ggml_tensor * ggml_graph_get_tensor(const struct ggml_cgraph * cgraph, const char * name) {
    for (int i = 0; i < cgraph->n_leafs; i++) {
        struct ggml_tensor * leaf = cgraph->leafs[i];

        if (strcmp(leaf->name, name) == 0) {
            return leaf;
        }
    }

    for (int i = 0; i < cgraph->n_nodes; i++) {
        struct ggml_tensor * node = cgraph->nodes[i];

        if (strcmp(node->name, name) == 0) {
            return node;
        }
    }

    return NULL;
}
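
// Lookup sketch: ggml_graph_get_tensor() does a linear scan over leafs and nodes by name, so tensors
// that need to be retrieved later should be named explicitly, e.g.:
//
//     ggml_set_name(loss, "loss");
//     struct ggml_tensor * t = ggml_graph_get_tensor(gb, "loss");  // NULL if not found
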
struct ggml_tensor * ggml_graph_get_grad(const struct ggml_cgraph * cgraph, const struct ggml_tensor * node) {
    const size_t igrad = ggml_hash_find(&cgraph->visited_hash_set, node);
    return igrad != GGML_HASHSET_FULL && ggml_bitset_get(cgraph->visited_hash_set.used, igrad) && cgraph->grads ? cgraph->grads[igrad] : NULL;
}

struct ggml_tensor * ggml_graph_get_grad_acc(const struct ggml_cgraph * cgraph, const struct ggml_tensor * node) {
    const size_t igrad = ggml_hash_find(&cgraph->visited_hash_set, node);
    return igrad != GGML_HASHSET_FULL && ggml_bitset_get(cgraph->visited_hash_set.used, igrad) && cgraph->grad_accs ? cgraph->grad_accs[igrad] : NULL;
}

void ggml_graph_print(const struct ggml_cgraph * cgraph) {
    GGML_LOG_INFO("=== GRAPH ===\n");

    GGML_LOG_INFO("n_nodes = %d\n", cgraph->n_nodes);
    for (int i = 0; i < cgraph->n_nodes; i++) {
        struct ggml_tensor * node = cgraph->nodes[i];

        GGML_LOG_INFO(" - %3d: [ %5" PRId64 ", %5" PRId64 ", %5" PRId64 "] %16s %s\n",
                i,
                node->ne[0], node->ne[1], node->ne[2],
                ggml_op_name(node->op), (node->flags & GGML_TENSOR_FLAG_PARAM) ? "x" :
                ggml_graph_get_grad(cgraph, node) ? "g" : " ");
    }

    GGML_LOG_INFO("n_leafs = %d\n", cgraph->n_leafs);
    for (int i = 0; i < cgraph->n_leafs; i++) {
        struct ggml_tensor * node = cgraph->leafs[i];

        GGML_LOG_INFO(" - %3d: [ %5" PRId64 ", %5" PRId64 "] %8s %16s\n",
                i,
                node->ne[0], node->ne[1],
                ggml_op_name(node->op),
                ggml_get_name(node));
    }

    GGML_LOG_INFO("========================================\n");
}

static int ggml_node_list_find_tensor(const struct ggml_cgraph * cgraph,
                                      const int * idxs,
                                      int count,
                                      const struct ggml_tensor * tensor) {
    GGML_ASSERT(cgraph && idxs);
    for (int i = 0; i < count; ++i) {
        const int node_idx = idxs[i];
        if (node_idx >= cgraph->n_nodes) {
            return -1;
        }

        if (cgraph->nodes[node_idx] == tensor) {
            return i;
        }
    }
    return -1;
}

bool ggml_can_fuse_subgraph_ext(const struct ggml_cgraph * cgraph,
                                const int * node_idxs,
                                int count,
                                const enum ggml_op * ops,
                                const int * outputs,
                                int num_outputs) {
    GGML_ASSERT(outputs && num_outputs > 0);

    for (int i = 0; i < count; ++i) {
        if (node_idxs[i] >= cgraph->n_nodes) {
            return false;
        }

        const struct ggml_tensor * node = cgraph->nodes[node_idxs[i]];

        if (node->op != ops[i]) {
            return false;
        }

        if (ggml_node_list_find_tensor(cgraph, outputs, num_outputs, node) != -1) {
            continue;
        }

        if (node->flags & GGML_TENSOR_FLAG_OUTPUT) {
            return false;
        }

        int subgraph_uses = 0;
        for (int j = i + 1; j < count; ++j) {
            const struct ggml_tensor * other_node = cgraph->nodes[node_idxs[j]];
            for (int src_idx = 0; src_idx < GGML_MAX_SRC; src_idx++) {
                if (other_node->src[src_idx] == node) {
                    subgraph_uses++;
                }
            }
        }

        if (subgraph_uses != ggml_node_get_use_count(cgraph, node_idxs[i])) {
            return false;
        }

        // if node is a view, check if the view_src and all its parent view_srcs are within the subgraph
        struct ggml_tensor * view_src = node->view_src;
        while (view_src) {
            if (ggml_node_list_find_tensor(cgraph, node_idxs, count, view_src) == -1) {
                return false;
            }
            view_src = view_src->view_src;
        }
    }

    return true;
}
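
// Illustrative summary: a candidate subgraph is fusable only if (1) each listed node matches the
// expected op, (2) every intermediate result (anything not listed in `outputs`) is consumed
// exclusively inside the subgraph - its in-subgraph use count must equal its graph-wide use count
// and it must not carry the OUTPUT flag - and (3) any view chain stays within the subgraph.
// Call sketch with hypothetical consecutive node indices i and i + 1 forming MUL_MAT -> ADD:
//
//     const enum ggml_op ops[2] = { GGML_OP_MUL_MAT, GGML_OP_ADD };
//     const int idxs[2] = { i, i + 1 };
//     const int outs[1] = { i + 1 };  // the ADD result is the only value that escapes
//     bool ok = ggml_can_fuse_subgraph_ext(cgraph, idxs, 2, ops, outs, 1);
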
// check if node is part of the graph
static bool ggml_graph_find(const struct ggml_cgraph * cgraph, const struct ggml_tensor * node) {
    if (cgraph == NULL) {
        return true;
    }

    for (int i = 0; i < cgraph->n_nodes; i++) {
        if (cgraph->nodes[i] == node) {
            return true;
        }
    }

    return false;
}

static struct ggml_tensor * ggml_graph_get_parent(const struct ggml_cgraph * cgraph, const struct ggml_tensor * node) {
    for (int i = 0; i < cgraph->n_nodes; i++) {
        struct ggml_tensor * parent = cgraph->nodes[i];
        struct ggml_tensor * grad = ggml_graph_get_grad(cgraph, parent);

        if (grad == node) {
            return parent;
        }
    }

    return NULL;
}

static void ggml_graph_dump_dot_node_edge(FILE * fp, const struct ggml_cgraph * gb, struct ggml_tensor * node, struct ggml_tensor * parent, const char * label) {
    struct ggml_tensor * gparent = ggml_graph_get_parent(gb, node);
    struct ggml_tensor * gparent0 = ggml_graph_get_parent(gb, parent);

    fprintf(fp, " \"%p\" -> \"%p\" [ arrowhead = %s; style = %s; label = \"%s\"; ]\n",
            gparent0 ? (void *) gparent0 : (void *) parent,
            gparent ? (void *) gparent : (void *) node,
            gparent ? "empty" : "vee",
            gparent ? "dashed" : "solid",
            label);
}

static void ggml_graph_dump_dot_leaf_edge(FILE * fp, struct ggml_tensor * node, struct ggml_tensor * parent, const char * label) {
    fprintf(fp, " \"%p\" -> \"%p\" [ label = \"%s\"; ]\n",
            (void *) parent,
            (void *) node,
            label);
}

void ggml_graph_dump_dot(const struct ggml_cgraph * gb, const struct ggml_cgraph * gf, const char * filename) {
    char color[16];

    FILE * fp = ggml_fopen(filename, "w");
    GGML_ASSERT(fp);

    fprintf(fp, "digraph G {\n");
    fprintf(fp, " newrank = true;\n");
    fprintf(fp, " rankdir = TB;\n");

    for (int i = 0; i < gb->n_nodes; i++) {
        struct ggml_tensor * node = gb->nodes[i];
        struct ggml_tensor * grad = ggml_graph_get_grad(gb, node);

        if (ggml_graph_get_parent(gb, node) != NULL) {
            continue;
        }

        if (node->flags & GGML_TENSOR_FLAG_PARAM) {
            snprintf(color, sizeof(color), "yellow");
        } else if (grad) {
            if (ggml_graph_find(gf, node)) {
                snprintf(color, sizeof(color), "green");
            } else {
                snprintf(color, sizeof(color), "lightblue");
            }
        } else {
            snprintf(color, sizeof(color), "white");
        }

        fprintf(fp, " \"%p\" [ "
                    "style = filled; fillcolor = %s; shape = record; "
                    "label=\"",
                (void *) node, color);

        if (strlen(node->name) > 0) {
            fprintf(fp, "%s (%s)|", node->name, ggml_type_name(node->type));
        } else {
            fprintf(fp, "(%s)|", ggml_type_name(node->type));
        }

        if (ggml_is_matrix(node)) {
            fprintf(fp, "%d [%" PRId64 ", %" PRId64 "] | <x>%s", i, node->ne[0], node->ne[1], ggml_op_symbol(node->op));
        } else {
            fprintf(fp, "%d [%" PRId64 ", %" PRId64 ", %" PRId64 "] | <x>%s", i, node->ne[0], node->ne[1], node->ne[2], ggml_op_symbol(node->op));
        }

        if (grad) {
            fprintf(fp, " | <g>%s\"; ]\n", ggml_op_symbol(grad->op));
        } else {
            fprintf(fp, "\"; ]\n");
        }
    }

    for (int i = 0; i < gb->n_leafs; i++) {
        struct ggml_tensor * node = gb->leafs[i];

        snprintf(color, sizeof(color), "pink");

        fprintf(fp, " \"%p\" [ "
                    "style = filled; fillcolor = %s; shape = record; "
                    "label=\"<x>",
                (void *) node, color);

        if (strlen(node->name) > 0) {
            fprintf(fp, "%s (%s)|", node->name, ggml_type_name(node->type));
        } else {
            fprintf(fp, "(%s)|", ggml_type_name(node->type));
        }

        fprintf(fp, "CONST %d [%" PRId64 ", %" PRId64 "]", i, node->ne[0], node->ne[1]);
        if (ggml_nelements(node) < 5 && node->data != NULL) {
            fprintf(fp, " | (");
            for (int j = 0; j < ggml_nelements(node); j++) {
                // FIXME: use ggml-backend to obtain the tensor data
                //if (node->type == GGML_TYPE_I8 || node->type == GGML_TYPE_I16 || node->type == GGML_TYPE_I32) {
                //    fprintf(fp, "%d", ggml_get_i32_1d(node, j));
                //}
                //else if (node->type == GGML_TYPE_F32 ||
                //         node->type == GGML_TYPE_F16 ||
                //         node->type == GGML_TYPE_BF16) {
                //    fprintf(fp, "%.1e", (double)ggml_get_f32_1d(node, j));
                //}
                //else
                {
                    fprintf(fp, "#");
                }
                if (j < ggml_nelements(node) - 1) {
                    fprintf(fp, ", ");
                }
            }
            fprintf(fp, ")");
        }
        fprintf(fp, "\"; ]\n");
    }

    for (int i = 0; i < gb->n_nodes; i++) {
        struct ggml_tensor * node = gb->nodes[i];

        for (int j = 0; j < GGML_MAX_SRC; j++) {
            if (node->src[j]) {
                char label[16];
                snprintf(label, sizeof(label), "src %d", j);
                ggml_graph_dump_dot_node_edge(fp, gb, node, node->src[j], label);
            }
        }
    }

    for (int i = 0; i < gb->n_leafs; i++) {
        struct ggml_tensor * node = gb->leafs[i];

        for (int j = 0; j < GGML_MAX_SRC; j++) {
            if (node->src[j]) {
                char label[16];
                snprintf(label, sizeof(label), "src %d", j);
                ggml_graph_dump_dot_leaf_edge(fp, node, node->src[j], label);
            }
        }
    }

    fprintf(fp, "}\n");

    fclose(fp);

    GGML_LOG_INFO("%s: dot -Tpng %s -o %s.png && open %s.png\n", __func__, filename, filename, filename);
}
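
// Usage sketch: `gf` may be NULL (ggml_graph_find() treats a NULL graph as containing every node, so
// all nodes with gradients are simply drawn green); render the output with graphviz:
//
//     ggml_graph_dump_dot(gb, NULL, "cgraph.dot");  // then: dot -Tpng cgraph.dot -o cgraph.png
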
////////////////////////////////////////////////////////////////////////////////

void ggml_set_input(struct ggml_tensor * tensor) {
    tensor->flags |= GGML_TENSOR_FLAG_INPUT;
}

void ggml_set_output(struct ggml_tensor * tensor) {
    tensor->flags |= GGML_TENSOR_FLAG_OUTPUT;
}

void ggml_set_param(struct ggml_tensor * tensor) {
    GGML_ASSERT(tensor->op == GGML_OP_NONE);
    tensor->flags |= GGML_TENSOR_FLAG_PARAM;
}

void ggml_set_loss(struct ggml_tensor * tensor) {
    GGML_ASSERT(ggml_is_scalar(tensor));
    GGML_ASSERT(tensor->type == GGML_TYPE_F32);
    tensor->flags |= GGML_TENSOR_FLAG_LOSS;
}

////////////////////////////////////////////////////////////////////////////////

void ggml_quantize_init(enum ggml_type type) {
    ggml_critical_section_start();

    switch (type) {
        case GGML_TYPE_IQ2_XXS:
        case GGML_TYPE_IQ2_XS:
        case GGML_TYPE_IQ2_S:
        case GGML_TYPE_IQ1_S:
        case GGML_TYPE_IQ1_M:   iq2xs_init_impl(type); break;
        case GGML_TYPE_IQ3_XXS: iq3xs_init_impl(256); break;
        case GGML_TYPE_IQ3_S:   iq3xs_init_impl(512); break;
        default: // nothing
            break;
    }

    ggml_critical_section_end();
}

void ggml_quantize_free(void) {
    ggml_critical_section_start();

    iq2xs_free_impl(GGML_TYPE_IQ2_XXS);
    iq2xs_free_impl(GGML_TYPE_IQ2_XS);
    iq2xs_free_impl(GGML_TYPE_IQ1_S);
    iq3xs_free_impl(256);

    ggml_critical_section_end();
}

bool ggml_quantize_requires_imatrix(enum ggml_type type) {
    return
        type == GGML_TYPE_IQ2_XXS ||
        type == GGML_TYPE_IQ2_XS ||
        type == GGML_TYPE_IQ1_S;// ||
        //type == GGML_TYPE_IQ1_M;
}

size_t ggml_quantize_chunk(
        enum ggml_type type,
        const float * src,
        void * dst,
        int64_t start,
        int64_t nrows,
        int64_t n_per_row,
        const float * imatrix) {
    const int64_t n = (int64_t) nrows * n_per_row;

    if (ggml_quantize_requires_imatrix(type)) {
        GGML_ASSERT(imatrix != NULL);
    }

    GGML_ASSERT(start % type_traits[type].blck_size == 0);
    GGML_ASSERT(start % n_per_row == 0);

    ggml_quantize_init(type); // this is noop if already initialized

    const size_t start_row = start / n_per_row;
    const size_t row_size = ggml_row_size(type, n_per_row);

    size_t result = 0;

    switch (type) {
        case GGML_TYPE_Q4_0:    result = quantize_q4_0(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
        case GGML_TYPE_Q4_1:    result = quantize_q4_1(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
        case GGML_TYPE_Q5_0:    result = quantize_q5_0(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
        case GGML_TYPE_Q5_1:    result = quantize_q5_1(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
        case GGML_TYPE_Q8_0:    result = quantize_q8_0(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
        case GGML_TYPE_MXFP4:   result = quantize_mxfp4(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
        case GGML_TYPE_Q2_K:    result = quantize_q2_K(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
        case GGML_TYPE_Q3_K:    result = quantize_q3_K(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
        case GGML_TYPE_Q4_K:    result = quantize_q4_K(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
        case GGML_TYPE_Q5_K:    result = quantize_q5_K(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
        case GGML_TYPE_Q6_K:    result = quantize_q6_K(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
        case GGML_TYPE_TQ1_0:   result = quantize_tq1_0(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
        case GGML_TYPE_TQ2_0:   result = quantize_tq2_0(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
        case GGML_TYPE_IQ2_XXS: result = quantize_iq2_xxs(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
        case GGML_TYPE_IQ2_XS:  result = quantize_iq2_xs (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
        case GGML_TYPE_IQ3_XXS: result = quantize_iq3_xxs(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
        case GGML_TYPE_IQ3_S:   result = quantize_iq3_s  (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
        case GGML_TYPE_IQ2_S:   result = quantize_iq2_s  (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
        case GGML_TYPE_IQ1_S:   result = quantize_iq1_s  (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
        case GGML_TYPE_IQ1_M:   result = quantize_iq1_m  (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
        case GGML_TYPE_IQ4_NL:  result = quantize_iq4_nl (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
        case GGML_TYPE_IQ4_XS:  result = quantize_iq4_xs (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
        case GGML_TYPE_F16:
            {
                size_t elemsize = sizeof(ggml_fp16_t);
                ggml_fp32_to_fp16_row(src + start, (ggml_fp16_t *)dst + start, n);
                result = n * elemsize;
            } break;
        case GGML_TYPE_BF16:
            {
                size_t elemsize = sizeof(ggml_bf16_t);
                ggml_fp32_to_bf16_row_ref(src + start, (ggml_bf16_t *)dst + start, n);
                result = n * elemsize;
            } break;
        case GGML_TYPE_F32:
            {
                size_t elemsize = sizeof(float);
                result = n * elemsize;
                memcpy((uint8_t *)dst + start * elemsize, src + start, result);
            } break;
        default:
            assert(false);
    }

    GGML_ASSERT(result == nrows * row_size);

    return result;
}
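
// Usage sketch (hypothetical sizes): quantize `nrows` contiguous rows of an F32 matrix with
// `n_per_row` elements per row; per the asserts above, `start` must be a whole-row, block-aligned
// offset, and the destination must hold nrows * ggml_row_size(type, n_per_row) bytes:
//
//     const size_t row_size = ggml_row_size(GGML_TYPE_Q4_0, n_per_row);
//     void * dst = malloc(nrows * row_size);
//     size_t written = ggml_quantize_chunk(GGML_TYPE_Q4_0, src_f32, dst,
//                                          /*start =*/ 0, nrows, n_per_row, /*imatrix =*/ NULL);
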
////////////////////////////////////////////////////////////////////////////////

void ggml_log_set(ggml_log_callback log_callback, void * user_data) {
    g_logger_state.log_callback = log_callback ? log_callback : ggml_log_callback_default;
    g_logger_state.log_callback_user_data = user_data;
}

void ggml_threadpool_params_init(struct ggml_threadpool_params * p, int n_threads) {
    p->n_threads  = n_threads;
    p->prio       = 0;     // default priority (usually means normal or inherited)
    p->poll       = 50;    // hybrid-polling enabled
    p->strict_cpu = false; // no strict placement (all threads share same cpumask)
    p->paused     = false; // threads are ready to go
    memset(p->cpumask, 0, GGML_MAX_N_THREADS); // all-zero means use the default affinity (usually inherited)
}

struct ggml_threadpool_params ggml_threadpool_params_default(int n_threads) {
    struct ggml_threadpool_params p;
    ggml_threadpool_params_init(&p, n_threads);
    return p;
}

bool ggml_threadpool_params_match(const struct ggml_threadpool_params * p0, const struct ggml_threadpool_params * p1) {
    if (p0->n_threads  != p1->n_threads ) return false;
    if (p0->prio       != p1->prio      ) return false;
    if (p0->poll       != p1->poll      ) return false;
    if (p0->strict_cpu != p1->strict_cpu) return false;
    return memcmp(p0->cpumask, p1->cpumask, GGML_MAX_N_THREADS) == 0;
}