#include "ops.h"
#include "ggml-cpu.h"
#include "ggml-impl.h"
#include "binary-ops.h"
#include "ggml.h"
#include "unary-ops.h"
#include "vec.h"

#include <cfloat>
#include <algorithm>
#include <functional>

// ggml_compute_forward_dup
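
// copy a contiguous tensor into a contiguous destination of the same type:
// the data is treated as raw type-sized blocks and each thread memcpy's its share of the block range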
static void ggml_compute_forward_dup_same_cont(
        const ggml_compute_params * params,
        ggml_tensor * dst) {

    const ggml_tensor * src0 = dst->src[0];

    GGML_ASSERT(ggml_nelements(dst) == ggml_nelements(src0));
    GGML_ASSERT(ggml_is_contiguous(dst) && ggml_is_contiguous(src0));
    GGML_ASSERT(src0->type == dst->type);

    const size_t nb0 = ggml_type_size(src0->type);

    const int ith = params->ith; // thread index
    const int nth = params->nth; // number of threads

    // parallelize by blocks
    const int nk = ggml_nelements(src0)/ggml_blck_size(src0->type);
    const int dr = (nk + nth - 1) / nth;
    const int k0 = dr * ith;
    const int k1 = MIN(k0 + dr, nk);

    if (k0 < k1) {
        memcpy(
            ((char *)  dst->data + k0*nb0),
            ((char *) src0->data + k0*nb0),
            (k1 - k0) * nb0);
    }
}
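
// dup between non-quantized types (f32/f16/bf16/i32): work is split by rows;
// same-type contiguous rows are copied with memcpy, otherwise every element is
// converted through an intermediate f32 value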
template<typename src_t, typename dst_t>
static void ggml_compute_forward_dup_flt(
        const ggml_compute_params * params,
        ggml_tensor * dst) {

    const ggml_tensor * src0 = dst->src[0];

    GGML_ASSERT(ggml_nelements(dst) == ggml_nelements(src0));
    GGML_ASSERT(!ggml_is_quantized(src0->type) && !ggml_is_quantized(dst->type));

    GGML_TENSOR_UNARY_OP_LOCALS

    const int ith = params->ith; // thread index
    const int nth = params->nth; // number of threads

    // parallelize by rows
    const int nr = ne01;
    // number of rows per thread
    const int dr = (nr + nth - 1) / nth;
    // row range for this thread
    const int ir0 = dr * ith;
    const int ir1 = MIN(ir0 + dr, nr);

    // case: type & row size equal
    if (src0->type == dst->type &&
        ne00 == ne0 &&
        nb00 == ggml_type_size(src0->type) && nb0 == ggml_type_size(dst->type)) {
        // copy by rows
        const size_t rs = ne00*nb00;
        for (int64_t i03 = 0; i03 < ne03; i03++) {
            for (int64_t i02 = 0; i02 < ne02; i02++) {
                for (int64_t i01 = ir0; i01 < ir1; i01++) {
                    memcpy(
                        ((char *)  dst->data + i01*nb1  + i02*nb2  + i03*nb3),
                        ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03),
                        rs);
                }
            }
        }
        return;
    }

    // case: dst tensor is contiguous
    if (ggml_is_contiguous(dst)) {
        if (nb00 == sizeof(src_t)) {
            if constexpr (std::is_same_v<dst_t, src_t>) {
                // same type
                size_t id = 0;
                const size_t rs = ne00 * nb00;
                char * dst_ptr = (char *) dst->data;

                for (int i03 = 0; i03 < ne03; i03++) {
                    for (int i02 = 0; i02 < ne02; i02++) {
                        id += rs * ir0;
                        for (int i01 = ir0; i01 < ir1; i01++) {
                            const char * src0_ptr = (char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03;
                            memcpy(dst_ptr + id, src0_ptr, rs);
                            id += rs;
                        }
                        id += rs * (ne01 - ir1);
                    }
                }
            } else {
                // casting between non-quantized types
                size_t id = 0;
                dst_t * dst_ptr = (dst_t *) dst->data;

                for (int i03 = 0; i03 < ne03; i03++) {
                    for (int i02 = 0; i02 < ne02; i02++) {
                        id += ne00 * ir0;
                        for (int i01 = ir0; i01 < ir1; i01++) {
                            const src_t * src0_ptr = (src_t *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
                            for (int i00 = 0; i00 < ne00; i00++) {
                                float tmp = type_conversion_table<src_t>::to_f32(src0_ptr[i00]);
                                dst_ptr[id] = type_conversion_table<dst_t>::from_f32(tmp);
                                id++;
                            }
                        }
                        id += ne00 * (ne01 - ir1);
                    }
                }
            }
        } else {
            //printf("%s: this is not optimal - fix me\n", __func__);
            size_t id = 0;
            dst_t * dst_ptr = (dst_t *) dst->data;

            for (int i03 = 0; i03 < ne03; i03++) {
                for (int i02 = 0; i02 < ne02; i02++) {
                    id += ne00 * ir0;
                    for (int i01 = ir0; i01 < ir1; i01++) {
                        for (int i00 = 0; i00 < ne00; i00++) {
                            const src_t * src0_ptr = (src_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);

                            float tmp = type_conversion_table<src_t>::to_f32(*src0_ptr);
                            dst_ptr[id] = type_conversion_table<dst_t>::from_f32(tmp);
                            id++;
                        }
                    }
                    id += ne00 * (ne01 - ir1);
                }
            }
        }
        return;
    }

    // dst counters
    int64_t i10 = 0;
    int64_t i11 = 0;
    int64_t i12 = 0;
    int64_t i13 = 0;

    if constexpr (std::is_same_v<dst_t, src_t>) {
        for (int64_t i03 = 0; i03 < ne03; i03++) {
            for (int64_t i02 = 0; i02 < ne02; i02++) {
                i10 += ne00 * ir0;
                while (i10 >= ne0) {
                    i10 -= ne0;
                    if (++i11 == ne1) {
                        i11 = 0;
                        if (++i12 == ne2) {
                            i12 = 0;
                            if (++i13 == ne3) {
                                i13 = 0;
                            }
                        }
                    }
                }
                for (int64_t i01 = ir0; i01 < ir1; i01++) {
                    for (int64_t i00 = 0; i00 < ne00; i00++) {
                        const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
                              char * dst_ptr  = ((char *)  dst->data + i10*nb0  + i11*nb1  + i12*nb2  + i13*nb3);

                        memcpy(dst_ptr, src0_ptr, sizeof(dst_t));

                        if (++i10 == ne00) {
                            i10 = 0;
                            if (++i11 == ne01) {
                                i11 = 0;
                                if (++i12 == ne02) {
                                    i12 = 0;
                                    if (++i13 == ne03) {
                                        i13 = 0;
                                    }
                                }
                            }
                        }
                    }
                }
                i10 += ne00 * (ne01 - ir1);
                while (i10 >= ne0) {
                    i10 -= ne0;
                    if (++i11 == ne1) {
                        i11 = 0;
                        if (++i12 == ne2) {
                            i12 = 0;
                            if (++i13 == ne3) {
                                i13 = 0;
                            }
                        }
                    }
                }
            }
        }
    } else {
        for (int64_t i03 = 0; i03 < ne03; i03++) {
            for (int64_t i02 = 0; i02 < ne02; i02++) {
                i10 += ne00 * ir0;
                while (i10 >= ne0) {
                    i10 -= ne0;
                    if (++i11 == ne1) {
                        i11 = 0;
                        if (++i12 == ne2) {
                            i12 = 0;
                            if (++i13 == ne3) {
                                i13 = 0;
                            }
                        }
                    }
                }
                for (int64_t i01 = ir0; i01 < ir1; i01++) {
                    for (int64_t i00 = 0; i00 < ne00; i00++) {
                        const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
                              char * dst_ptr  = ((char *)  dst->data + i10*nb0  + i11*nb1  + i12*nb2  + i13*nb3);

                        float tmp = type_conversion_table<src_t>::to_f32(*(const src_t *) src0_ptr);
                        *(dst_t *) dst_ptr = type_conversion_table<dst_t>::from_f32(tmp);

                        if (++i10 == ne0) {
                            i10 = 0;
                            if (++i11 == ne1) {
                                i11 = 0;
                                if (++i12 == ne2) {
                                    i12 = 0;
                                    if (++i13 == ne3) {
                                        i13 = 0;
                                    }
                                }
                            }
                        }
                    }
                }
                i10 += ne00 * (ne01 - ir1);
                while (i10 >= ne0) {
                    i10 -= ne0;
                    if (++i11 == ne1) {
                        i11 = 0;
                        if (++i12 == ne2) {
                            i12 = 0;
                            if (++i13 == ne3) {
                                i13 = 0;
                            }
                        }
                    }
                }
            }
        }
    }
}
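
// dup from a non-quantized type into a quantized dst: each source row is converted
// into a per-thread f32 scratch buffer in params->wdata and then quantized row-by-row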
template<typename src_t>
static void ggml_compute_forward_dup_to_q(
        const ggml_compute_params * params,
        ggml_tensor * dst) {

    const ggml_tensor * src0 = dst->src[0];

    GGML_ASSERT(ggml_nelements(dst) == ggml_nelements(src0));
    GGML_ASSERT(!ggml_is_quantized(src0->type));

    GGML_TENSOR_UNARY_OP_LOCALS

    const int ith = params->ith; // thread index
    const int nth = params->nth; // number of threads

    // parallelize by rows
    const int nr = ne01;
    // number of rows per thread
    const int dr = (nr + nth - 1) / nth;
    // row range for this thread
    const int ir0 = dr * ith;
    const int ir1 = MIN(ir0 + dr, nr);

    if (ggml_is_contiguous(dst) &&
        nb00 == sizeof(src_t) &&
        ggml_get_type_traits_cpu(dst->type)->from_float) {
        // casting non-quantized types --> intermediate f32 --> quantized
        ggml_from_float_t const quantize_row_q = ggml_get_type_traits_cpu(dst->type)->from_float;

        float * src0_f32 = (float *) params->wdata + (ne00 + CACHE_LINE_SIZE_F32) * ith;

        size_t id = 0;
        size_t rs = nb0 * (ne00 / ggml_blck_size(dst->type));
        char * dst_ptr = (char *) dst->data;

        for (int i03 = 0; i03 < ne03; i03++) {
            for (int i02 = 0; i02 < ne02; i02++) {
                id += rs * ir0;
                for (int i01 = ir0; i01 < ir1; i01++) {
                    const src_t * src0_ptr = (src_t *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);

                    for (int i00 = 0; i00 < ne00; i00++) {
                        src0_f32[i00] = type_conversion_table<src_t>::to_f32(src0_ptr[i00]);
                    }

                    quantize_row_q(src0_f32, dst_ptr + id, ne00);
                    id += rs;
                }
                id += rs * (ne01 - ir1);
            }
        }
    } else {
        // printf("%s %s\n", ggml_type_name(src0->type), ggml_type_name(dst->type));
        GGML_ABORT("not implemented");
    }
}
// A simplified version of ggml_compute_forward_dup that doesn't do float upcasting, and just plain old memcpy.
static void ggml_compute_forward_dup_bytes(
        const ggml_compute_params * params,
        ggml_tensor * dst) {

    const ggml_tensor * src0 = dst->src[0];

    GGML_ASSERT(ggml_nelements(dst) == ggml_nelements(src0));
    GGML_ASSERT(src0->type == dst->type);

    GGML_TENSOR_UNARY_OP_LOCALS;

    if (ggml_is_contiguous(src0) && ggml_is_contiguous(dst)) {
        ggml_compute_forward_dup_same_cont(params, dst);
        return;
    }

    const size_t type_size = ggml_type_size(src0->type);

    const int ith = params->ith; // thread index
    const int nth = params->nth; // number of threads

    // parallelize by rows
    const int nr = ne01;
    // number of rows per thread
    const int dr = (nr + nth - 1) / nth;
    // row range for this thread
    const int ir0 = dr * ith;
    const int ir1 = MIN(ir0 + dr, nr);

    if (src0->type == dst->type &&
        ggml_are_same_shape(src0, dst) &&
        nb00 == type_size && nb0 == type_size) {
        // copy by rows
        const size_t rs = ggml_row_size(src0->type, ne00);
        for (int64_t i03 = 0; i03 < ne03; i03++) {
            for (int64_t i02 = 0; i02 < ne02; i02++) {
                for (int64_t i01 = ir0; i01 < ir1; i01++) {
                    memcpy(
                        ((char *)  dst->data + i01*nb1  + i02*nb2  + i03*nb3),
                        ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03),
                        rs);
                }
            }
        }
        return;
    }

    if (ggml_is_contiguous(dst)) {
        size_t id = 0;
        char * dst_ptr = (char *) dst->data;
        const size_t rs = ne00 * type_size;

        if (nb00 == type_size) {
            // src0 is contiguous on first dimension, copy by rows
            for (int64_t i03 = 0; i03 < ne03; i03++) {
                for (int64_t i02 = 0; i02 < ne02; i02++) {
                    id += rs * ir0;
                    for (int64_t i01 = ir0; i01 < ir1; i01++) {
                        const char * src0_ptr = (char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03;
                        memcpy(dst_ptr + id, src0_ptr, rs);
                        id += rs;
                    }
                    id += rs * (ne01 - ir1);
                }
            }
        } else {
            //printf("%s: this is not optimal - fix me\n", __func__);
            for (int64_t i03 = 0; i03 < ne03; i03++) {
                for (int64_t i02 = 0; i02 < ne02; i02++) {
                    id += rs * ir0;
                    for (int64_t i01 = ir0; i01 < ir1; i01++) {
                        for (int64_t i00 = 0; i00 < ne00; i00++) {
                            const char * src0_ptr = (char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03;
                            memcpy(dst_ptr + id, src0_ptr, type_size);
                            id += type_size;
                        }
                    }
                    id += rs * (ne01 - ir1);
                }
            }
        }
        return;
    }

    // dst counters
    int64_t k10 = 0;
    int64_t i11 = 0;
    int64_t i12 = 0;
    int64_t i13 = 0;

    // number of blocks in a row
    const int64_t nk00 = ne00 / ggml_blck_size(src0->type);
    const int64_t nk0  = ne0  / ggml_blck_size(dst->type);

    for (int64_t i03 = 0; i03 < ne03; i03++) {
        for (int64_t i02 = 0; i02 < ne02; i02++) {
            k10 += nk00 * ir0;
            while (k10 >= nk0) {
                k10 -= nk0;
                if (++i11 == ne1) {
                    i11 = 0;
                    if (++i12 == ne2) {
                        i12 = 0;
                        if (++i13 == ne3) {
                            i13 = 0;
                        }
                    }
                }
            }
            for (int64_t i01 = ir0; i01 < ir1; i01++) {
                for (int64_t k00 = 0; k00 < nk00; k00++) {
                    const char * src0_ptr = ((char *) src0->data + k00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
                          char * dst_ptr  = ((char *)  dst->data + k10*nb0  + i11*nb1  + i12*nb2  + i13*nb3);

                    memcpy(dst_ptr, src0_ptr, type_size);

                    if (++k10 == nk0) {
                        k10 = 0;
                        if (++i11 == ne1) {
                            i11 = 0;
                            if (++i12 == ne2) {
                                i12 = 0;
                                if (++i13 == ne3) {
                                    i13 = 0;
                                }
                            }
                        }
                    }
                }
            }
            k10 += nk00 * (ne01 - ir1);
            while (k10 >= nk0) {
                k10 -= nk0;
                if (++i11 == ne1) {
                    i11 = 0;
                    if (++i12 == ne2) {
                        i12 = 0;
                        if (++i13 == ne3) {
                            i13 = 0;
                        }
                    }
                }
            }
        }
    }
}
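
// dup from a quantized src0 into a float dst: elements are processed one quantization
// block (qk elements) at a time, dequantizing each block straight into its destination
// offset; src1 describes the destination layout and blocks are split across threads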
static void ggml_compute_forward_dup_from_q(
        const ggml_compute_params * params,
        ggml_tensor * dst) {

    const ggml_tensor * src0 = dst->src[0];
    const ggml_tensor * src1 = dst->src[1];

    GGML_TENSOR_BINARY_OP_LOCALS

    const ggml_type type = src0->type;
    ggml_to_float_t const dequantize_row_q = ggml_get_type_traits(type)->to_float;

    size_t qk = ggml_blck_size(type);
    const int64_t nr = ggml_nelements(src1) / qk;

    // destination must be contiguous in the first dimension
    GGML_ASSERT(nb10 == ggml_type_size(dst->type));
    // must either have first dimension large enough to hold a row, or fully contiguous
    GGML_ASSERT((ne10 % qk) == 0 || ggml_is_contiguous(dst));

    const int ith = params->ith;
    const int nth = params->nth;

    const int dr = (nr + nth - 1)/nth;

    // row range for this thread
    const int ir0 = dr*ith;
    const int ir1 = MIN(ir0 + dr, nr);

    for (int64_t ir = ir0; ir < ir1; ++ir) {
        uint32_t i = ir * qk;

        const int64_t i03 = i/(ne00 * ne01 * ne02);
        const int64_t i02 = (i - i03*ne00*ne01*ne02 )/ (ne00*ne01);
        const int64_t i01 = (i - i03*ne00*ne01*ne02  -  i02*ne01*ne00) / ne00;
        const int64_t i00 = i - i03*ne00*ne01*ne02 - i02*ne01*ne00 - i01*ne00;
        const int64_t x_offset = (i00/qk)*nb00 + i01*nb01 + i02*nb02 + i03 * nb03;

        const int64_t i13 = i/(ne10 * ne11 * ne12);
        const int64_t i12 = (i - i13*ne10*ne11*ne12) / (ne10*ne11);
        const int64_t i11 = (i - i13*ne10*ne11*ne12 - i12*ne10*ne11) / ne10;
        const int64_t i10 = i - i13*ne10*ne11*ne12 - i12*ne10*ne11 - i11*ne10;
        const int64_t dst_offset = i10*nb10 + i11*nb11 + i12*nb12 + i13*nb13;

        dequantize_row_q(
                (const void *) ((char *) src0->data + x_offset),
                     (float *) ((char *)  dst->data + dst_offset), qk);
    }
}
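
// top-level dup dispatch: same-type copies use the byte-wise path, float/integer
// conversions use the templated path, and quantized source or destination types are
// routed to the dedicated (de)quantization helpers above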
void ggml_compute_forward_dup(
        const ggml_compute_params * params,
        ggml_tensor * dst) {

    const ggml_tensor * src0 = dst->src[0];

    if (src0->type == dst->type) {
        ggml_compute_forward_dup_bytes(params, dst);
        return;
    }

    switch (src0->type) {
        case GGML_TYPE_F16:
            {
                /**/ if (dst->type == GGML_TYPE_F16)  ggml_compute_forward_dup_flt<ggml_fp16_t, ggml_fp16_t>(params, dst);
                else if (dst->type == GGML_TYPE_BF16) ggml_compute_forward_dup_flt<ggml_fp16_t, ggml_bf16_t>(params, dst);
                else if (dst->type == GGML_TYPE_F32)  ggml_compute_forward_dup_flt<ggml_fp16_t, float      >(params, dst);
                else ggml_compute_forward_dup_to_q<ggml_fp16_t>(params, dst);
            } break;
        case GGML_TYPE_BF16:
            {
                /**/ if (dst->type == GGML_TYPE_F16)  ggml_compute_forward_dup_flt<ggml_bf16_t, ggml_fp16_t>(params, dst);
                else if (dst->type == GGML_TYPE_BF16) ggml_compute_forward_dup_flt<ggml_bf16_t, ggml_bf16_t>(params, dst);
                else if (dst->type == GGML_TYPE_F32)  ggml_compute_forward_dup_flt<ggml_bf16_t, float      >(params, dst);
                else ggml_compute_forward_dup_to_q<ggml_bf16_t>(params, dst);
            } break;
        case GGML_TYPE_F32:
            {
                /**/ if (dst->type == GGML_TYPE_F16)  ggml_compute_forward_dup_flt<float, ggml_fp16_t>(params, dst);
                else if (dst->type == GGML_TYPE_BF16) ggml_compute_forward_dup_flt<float, ggml_bf16_t>(params, dst);
                else if (dst->type == GGML_TYPE_F32)  ggml_compute_forward_dup_flt<float, float      >(params, dst);
                else if (dst->type == GGML_TYPE_I32)  ggml_compute_forward_dup_flt<float, int32_t    >(params, dst);
                else ggml_compute_forward_dup_to_q<float>(params, dst);
            } break;
        case GGML_TYPE_I32:
            {
                if (dst->type == GGML_TYPE_F32) ggml_compute_forward_dup_flt<int32_t, float>(params, dst);
                else GGML_ABORT("not implemented");
            } break;
        default:
            {
                if (ggml_is_quantized(src0->type) && dst->type == GGML_TYPE_F32) {
                    ggml_compute_forward_dup_from_q(params, dst);
                    break;
                }
                GGML_ABORT("fatal error");
            }
    }
}
// ggml_compute_forward_add
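
// add where src0 is quantized and src1 is f32: each src0 row is dequantized into a
// per-thread f32 scratch buffer, src1 is accumulated into it, and the result is
// re-quantized (or copied as f32) into dst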
static void ggml_compute_forward_add_q_f32(
        const ggml_compute_params * params,
        ggml_tensor * dst) {

    const ggml_tensor * src0 = dst->src[0];
    const ggml_tensor * src1 = dst->src[1];

    GGML_ASSERT(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst));

    const int nr = ggml_nrows(src0);

    GGML_TENSOR_BINARY_OP_LOCALS

    const int ith = params->ith;
    const int nth = params->nth;

    const ggml_type type = src0->type;
    const ggml_type dtype = dst->type;
    ggml_to_float_t const dequantize_row_q = ggml_get_type_traits(type)->to_float;
    ggml_from_float_t const quantize_row_q = ggml_get_type_traits_cpu(dtype)->from_float;

    // we don't support permuted src0 or src1
    GGML_ASSERT(nb00 == ggml_type_size(type));
    GGML_ASSERT(nb10 == sizeof(float));

    // dst cannot be transposed or permuted
    GGML_ASSERT(nb0 <= nb1);
    GGML_ASSERT(nb1 <= nb2);
    GGML_ASSERT(nb2 <= nb3);

    GGML_ASSERT(ggml_is_quantized(src0->type));
    GGML_ASSERT(src1->type == GGML_TYPE_F32);

    // rows per thread
    const int dr = (nr + nth - 1)/nth;

    // row range for this thread
    const int ir0 = dr*ith;
    const int ir1 = MIN(ir0 + dr, nr);

    float * wdata = (float *) params->wdata + (ne00 + CACHE_LINE_SIZE_F32) * ith;

    for (int ir = ir0; ir < ir1; ++ir) {
        // src0 indices
        const int i03 = ir/(ne02*ne01);
        const int i02 = (ir - i03*ne02*ne01)/ne01;
        const int i01 = (ir - i03*ne02*ne01 - i02*ne01);

        // src1 and dst are same shape as src0 => same indices
        const int i13 = i03;
        const int i12 = i02;
        const int i11 = i01;

        const int i3 = i03;
        const int i2 = i02;
        const int i1 = i01;

        void  * src0_row = (void *) ((char *) src0->data + (i01*nb01 + i02*nb02 + i03*nb03));
        float * src1_row = (float *)((char *) src1->data + (i11*nb11 + i12*nb12 + i13*nb13));
        void  * dst_row  = (void *) ((char *)  dst->data + ( i1*nb1  +  i2*nb2  +  i3*nb3));

        assert(ne00 % 32 == 0);

        // unquantize row from src0 to temp buffer
        dequantize_row_q(src0_row, wdata, ne00);
        // add src1
        ggml_vec_acc_f32(ne00, wdata, src1_row);
        // quantize row to dst
        if (quantize_row_q != NULL) {
            quantize_row_q(wdata, dst_row, ne00);
        } else {
            memcpy(dst_row, wdata, ne0*nb0);
        }
    }
}
  554. void ggml_compute_forward_add(
  555. const ggml_compute_params * params,
  556. ggml_tensor * dst) {
  557. const ggml_tensor * src0 = dst->src[0];
  558. switch (src0->type) {
  559. case GGML_TYPE_F32:
  560. case GGML_TYPE_F16:
  561. case GGML_TYPE_BF16:
  562. {
  563. ggml_compute_forward_add_non_quantized(params, dst);
  564. } break;
  565. case GGML_TYPE_Q4_0:
  566. case GGML_TYPE_Q4_1:
  567. case GGML_TYPE_Q5_0:
  568. case GGML_TYPE_Q5_1:
  569. case GGML_TYPE_Q8_0:
  570. case GGML_TYPE_MXFP4:
  571. case GGML_TYPE_Q2_K:
  572. case GGML_TYPE_Q3_K:
  573. case GGML_TYPE_Q4_K:
  574. case GGML_TYPE_Q5_K:
  575. case GGML_TYPE_Q6_K:
  576. case GGML_TYPE_TQ1_0:
  577. case GGML_TYPE_TQ2_0:
  578. case GGML_TYPE_IQ2_XXS:
  579. case GGML_TYPE_IQ2_XS:
  580. case GGML_TYPE_IQ3_XXS:
  581. case GGML_TYPE_IQ1_S:
  582. case GGML_TYPE_IQ1_M:
  583. case GGML_TYPE_IQ4_NL:
  584. case GGML_TYPE_IQ4_XS:
  585. case GGML_TYPE_IQ3_S:
  586. case GGML_TYPE_IQ2_S:
  587. {
  588. ggml_compute_forward_add_q_f32(params, dst);
  589. } break;
  590. default:
  591. {
  592. GGML_ABORT("fatal error");
  593. }
  594. }
  595. }
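// illustrative sketch (not part of this file): these forward functions are normally reached
// through graph evaluation rather than called directly - assuming the usual ggml public API,
// an F32 add is set up roughly like this:
//
//   struct ggml_init_params ip = { /*mem_size =*/ 16*1024*1024, /*mem_buffer =*/ NULL, /*no_alloc =*/ false };
//   struct ggml_context * ctx = ggml_init(ip);
//   struct ggml_tensor  * a   = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 64);
//   struct ggml_tensor  * b   = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 64);
//   struct ggml_tensor  * c   = ggml_add(ctx, a, b);
//   struct ggml_cgraph  * gf  = ggml_new_graph(ctx);
//   ggml_build_forward_expand(gf, c);
//   ggml_graph_compute_with_ctx(ctx, gf, /*n_threads =*/ 4);
//
// with a quantized `a` (e.g. GGML_TYPE_Q8_0) and an F32 `b`, the same graph routes through
// ggml_compute_forward_add_q_f32 above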
  596. // ggml_compute_forward_add_id
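// ADD_ID adds one selected row of src1 to each row of src0: the row index into src1 for the
// row at (i1, i2) is read from the I32 tensor src2 (e.g. per-token expert ids, with src1
// holding per-expert biases); only the F32 path is implemented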
  597. static void ggml_compute_forward_add_id_f32(
  598. const ggml_compute_params * params,
  599. ggml_tensor * dst) {
  600. const ggml_tensor * src0 = dst->src[0];
  601. const ggml_tensor * src1 = dst->src[1];
  602. const ggml_tensor * src2 = dst->src[2];
  603. GGML_ASSERT(dst->type == GGML_TYPE_F32);
  604. GGML_ASSERT(src0->type == GGML_TYPE_F32);
  605. GGML_ASSERT(src1->type == GGML_TYPE_F32);
  606. GGML_ASSERT(src2->type == GGML_TYPE_I32);
  607. GGML_ASSERT(src0->nb[0] == sizeof(float));
  608. GGML_ASSERT(src1->nb[0] == sizeof(float));
  609. const int ith = params->ith;
  610. const int nth = params->nth;
  611. const int nr = ggml_nrows(src0);
  612. GGML_TENSOR_TERNARY_OP_LOCALS
  613. GGML_ASSERT( nb0 == sizeof(float));
  614. GGML_ASSERT(nb10 == sizeof(float));
  615. // rows per thread
  616. const int dr = (nr + nth - 1)/nth;
  617. // row range for this thread
  618. const int ir0 = dr*ith;
  619. const int ir1 = MIN(ir0 + dr, nr);
  620. for (int ir = ir0; ir < ir1; ++ir) {
  621. // src0 indices
  622. const int i3 = ir/(ne2*ne1);
  623. const int i2 = (ir - i3*ne2*ne1)/ne1;
  624. const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
  625. // src1 indices
  626. const int i11 = *(int32_t *) ((char *) src2->data + i1*nb20 + i2*nb21);
  627. GGML_ASSERT(i11 >= 0 && i11 < ne11);
  628. ggml_vec_add_f32(ne0,
  629. (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 ),
  630. (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01),
  631. (float *) ((char *) src1->data + i11*nb11));
  632. }
  633. }
  634. void ggml_compute_forward_add_id(
  635. const ggml_compute_params * params,
  636. ggml_tensor * dst) {
  637. const ggml_tensor * src0 = dst->src[0];
  638. switch (src0->type) {
  639. case GGML_TYPE_F32:
  640. {
  641. ggml_compute_forward_add_id_f32(params, dst);
  642. } break;
  643. default:
  644. {
  645. GGML_ABORT("unsupported type for ggml_compute_forward_add_id: %s", ggml_type_name(src0->type));
  646. }
  647. }
  648. }
  649. // ggml_compute_forward_add1
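// ADD1 adds the scalar src1 to every element of src0; the variants below cover F32, F16 and
// BF16 (with a matching-type or F32 scalar) and quantized src0 via a dequantize -> add ->
// re-quantize round trip; the Accelerate path broadcasts the scalar by giving vDSP_vadd a
// zero stride for src1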
  650. static void ggml_compute_forward_add1_f32(
  651. const ggml_compute_params * params,
  652. ggml_tensor * dst) {
  653. const ggml_tensor * src0 = dst->src[0];
  654. const ggml_tensor * src1 = dst->src[1];
  655. GGML_ASSERT(ggml_are_same_shape(src0, dst));
  656. GGML_ASSERT(ggml_is_scalar(src1));
  657. const int ith = params->ith;
  658. const int nth = params->nth;
  659. const int nr = ggml_nrows(src0);
  660. GGML_TENSOR_UNARY_OP_LOCALS
  661. GGML_ASSERT( nb0 == sizeof(float));
  662. GGML_ASSERT(nb00 == sizeof(float));
  663. // rows per thread
  664. const int dr = (nr + nth - 1)/nth;
  665. // row range for this thread
  666. const int ir0 = dr*ith;
  667. const int ir1 = MIN(ir0 + dr, nr);
  668. for (int ir = ir0; ir < ir1; ++ir) {
  669. // src0 and dst are same shape => same indices
  670. const int i3 = ir/(ne2*ne1);
  671. const int i2 = (ir - i3*ne2*ne1)/ne1;
  672. const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
  673. #ifdef GGML_USE_ACCELERATE
  674. GGML_UNUSED(ggml_vec_add1_f32);
  675. vDSP_vadd(
  676. (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01), 1,
  677. (float *) ((char *) src1->data), 0,
  678. (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 ), 1,
  679. ne0);
  680. #else
  681. ggml_vec_add1_f32(ne0,
  682. (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 ),
  683. (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01),
  684. *(float *) src1->data);
  685. #endif
  686. }
  687. }
  688. static void ggml_compute_forward_add1_f16_f32(
  689. const ggml_compute_params * params,
  690. ggml_tensor * dst) {
  691. const ggml_tensor * src0 = dst->src[0];
  692. const ggml_tensor * src1 = dst->src[1];
  693. GGML_ASSERT(ggml_are_same_shape(src0, dst));
  694. GGML_ASSERT(ggml_is_scalar(src1));
  695. // scalar to add
  696. const float v = *(float *) src1->data;
  697. const int ith = params->ith;
  698. const int nth = params->nth;
  699. const int nr = ggml_nrows(src0);
  700. GGML_TENSOR_UNARY_OP_LOCALS
  701. GGML_ASSERT(src0->type == GGML_TYPE_F16);
  702. GGML_ASSERT(src1->type == GGML_TYPE_F32);
  703. GGML_ASSERT(dst->type == GGML_TYPE_F16);
  704. GGML_ASSERT( nb0 == sizeof(ggml_fp16_t));
  705. GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
  706. // rows per thread
  707. const int dr = (nr + nth - 1)/nth;
  708. // row range for this thread
  709. const int ir0 = dr*ith;
  710. const int ir1 = MIN(ir0 + dr, nr);
  711. for (int ir = ir0; ir < ir1; ++ir) {
  712. // src0 and dst are same shape => same indices
  713. const int i3 = ir/(ne2*ne1);
  714. const int i2 = (ir - i3*ne2*ne1)/ne1;
  715. const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
  716. ggml_fp16_t * dst_ptr = (ggml_fp16_t *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 );
  717. ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01);
  718. for (int i = 0; i < ne0; i++) {
  719. dst_ptr[i] = GGML_CPU_FP32_TO_FP16(GGML_CPU_FP16_TO_FP32(src0_ptr[i]) + v);
  720. }
  721. }
  722. }
  723. static void ggml_compute_forward_add1_f16_f16(
  724. const ggml_compute_params * params,
  725. ggml_tensor * dst) {
  726. const ggml_tensor * src0 = dst->src[0];
  727. const ggml_tensor * src1 = dst->src[1];
  728. GGML_ASSERT(ggml_are_same_shape(src0, dst));
  729. GGML_ASSERT(ggml_is_scalar(src1));
  730. // scalar to add
  731. const float v = GGML_CPU_FP16_TO_FP32(*(ggml_fp16_t *) src1->data);
  732. const int ith = params->ith;
  733. const int nth = params->nth;
  734. const int nr = ggml_nrows(src0);
  735. GGML_TENSOR_UNARY_OP_LOCALS
  736. GGML_ASSERT(src0->type == GGML_TYPE_F16);
  737. GGML_ASSERT(src1->type == GGML_TYPE_F16);
  738. GGML_ASSERT(dst->type == GGML_TYPE_F16);
  739. GGML_ASSERT( nb0 == sizeof(ggml_fp16_t));
  740. GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
  741. // rows per thread
  742. const int dr = (nr + nth - 1)/nth;
  743. // row range for this thread
  744. const int ir0 = dr*ith;
  745. const int ir1 = MIN(ir0 + dr, nr);
  746. for (int ir = ir0; ir < ir1; ++ir) {
  747. // src0 and dst are same shape => same indices
  748. const int i3 = ir/(ne2*ne1);
  749. const int i2 = (ir - i3*ne2*ne1)/ne1;
  750. const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
  751. ggml_fp16_t * dst_ptr = (ggml_fp16_t *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 );
  752. ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01);
  753. for (int i = 0; i < ne0; i++) {
  754. dst_ptr[i] = GGML_CPU_FP32_TO_FP16(GGML_CPU_FP16_TO_FP32(src0_ptr[i]) + v);
  755. }
  756. }
  757. }
  758. static void ggml_compute_forward_add1_q_f32(
  759. const ggml_compute_params * params,
  760. ggml_tensor * dst) {
  761. const ggml_tensor * src0 = dst->src[0];
  762. const ggml_tensor * src1 = dst->src[1];
  763. GGML_ASSERT(ggml_are_same_shape(src0, dst));
  764. GGML_ASSERT(ggml_is_scalar(src1));
  765. // scalar to add
  766. const float v = *(float *) src1->data;
  767. const int ith = params->ith;
  768. const int nth = params->nth;
  769. const int nr = ggml_nrows(src0);
  770. GGML_TENSOR_UNARY_OP_LOCALS
  771. const ggml_type type = src0->type;
  772. ggml_to_float_t const dequantize_row_q = ggml_get_type_traits(type)->to_float;
  773. ggml_from_float_t const quantize_row_q = ggml_get_type_traits_cpu(type)->from_float;
  774. // we don't support permuted src0
  775. GGML_ASSERT(nb00 == ggml_type_size(type));
  776. // dst cannot be transposed or permuted
  777. GGML_ASSERT(nb0 <= nb1);
  778. GGML_ASSERT(nb1 <= nb2);
  779. GGML_ASSERT(nb2 <= nb3);
  780. GGML_ASSERT(ggml_is_quantized(src0->type));
  781. GGML_ASSERT(dst->type == src0->type);
  782. GGML_ASSERT(src1->type == GGML_TYPE_F32);
  783. // rows per thread
  784. const int dr = (nr + nth - 1)/nth;
  785. // row range for this thread
  786. const int ir0 = dr*ith;
  787. const int ir1 = MIN(ir0 + dr, nr);
  788. float * wdata = (float *) params->wdata + (ne0 + CACHE_LINE_SIZE_F32) * ith;
  789. for (int ir = ir0; ir < ir1; ++ir) {
  790. // src0 and dst are same shape => same indices
  791. const int i3 = ir/(ne2*ne1);
  792. const int i2 = (ir - i3*ne2*ne1)/ne1;
  793. const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
  794. void * src0_row = (void *) ((char *) src0->data + (i1*nb01 + i2*nb02 + i3*nb03));
795. void * dst_row = (void *) ((char *) dst->data + (i1*nb1 + i2*nb2 + i3*nb3));
  796. assert(ne0 % 32 == 0);
  797. // unquantize row from src0 to temp buffer
  798. dequantize_row_q(src0_row, wdata, ne0);
  799. // add src1
  800. ggml_vec_acc1_f32(ne0, wdata, v);
  801. // quantize row to dst
  802. quantize_row_q(wdata, dst_row, ne0);
  803. }
  804. }
  805. static void ggml_compute_forward_add1_bf16_f32(
  806. const ggml_compute_params * params,
  807. ggml_tensor * dst) {
  808. const ggml_tensor * src0 = dst->src[0];
  809. const ggml_tensor * src1 = dst->src[1];
  810. GGML_ASSERT(ggml_are_same_shape(src0, dst));
  811. GGML_ASSERT(ggml_is_scalar(src1));
  812. // scalar to add
  813. const float v = *(float *) src1->data;
  814. const int ith = params->ith;
  815. const int nth = params->nth;
  816. const int nr = ggml_nrows(src0);
  817. GGML_TENSOR_UNARY_OP_LOCALS
  818. GGML_ASSERT(src0->type == GGML_TYPE_BF16);
  819. GGML_ASSERT(src1->type == GGML_TYPE_F32);
  820. GGML_ASSERT(dst->type == GGML_TYPE_BF16);
  821. GGML_ASSERT( nb0 == sizeof(ggml_bf16_t));
  822. GGML_ASSERT(nb00 == sizeof(ggml_bf16_t));
  823. // rows per thread
  824. const int dr = (nr + nth - 1)/nth;
  825. // row range for this thread
  826. const int ir0 = dr*ith;
  827. const int ir1 = MIN(ir0 + dr, nr);
  828. for (int ir = ir0; ir < ir1; ++ir) {
  829. // src0 and dst are same shape => same indices
  830. const int i3 = ir/(ne2*ne1);
  831. const int i2 = (ir - i3*ne2*ne1)/ne1;
  832. const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
  833. ggml_bf16_t * dst_ptr = (ggml_bf16_t *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 );
  834. ggml_bf16_t * src0_ptr = (ggml_bf16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01);
  835. for (int i = 0; i < ne0; i++) {
  836. dst_ptr[i] = GGML_FP32_TO_BF16(GGML_BF16_TO_FP32(src0_ptr[i]) + v);
  837. }
  838. }
  839. }
  840. static void ggml_compute_forward_add1_bf16_bf16(
  841. const ggml_compute_params * params,
  842. ggml_tensor * dst) {
  843. const ggml_tensor * src0 = dst->src[0];
  844. const ggml_tensor * src1 = dst->src[1];
  845. GGML_ASSERT(ggml_are_same_shape(src0, dst));
  846. GGML_ASSERT(ggml_is_scalar(src1));
  847. // scalar to add
  848. const float v = GGML_BF16_TO_FP32(*(ggml_bf16_t *) src1->data);
  849. const int ith = params->ith;
  850. const int nth = params->nth;
  851. const int nr = ggml_nrows(src0);
  852. GGML_TENSOR_UNARY_OP_LOCALS
  853. GGML_ASSERT(src0->type == GGML_TYPE_BF16);
  854. GGML_ASSERT(src1->type == GGML_TYPE_BF16);
  855. GGML_ASSERT(dst->type == GGML_TYPE_BF16);
  856. GGML_ASSERT( nb0 == sizeof(ggml_bf16_t));
  857. GGML_ASSERT(nb00 == sizeof(ggml_bf16_t));
  858. // rows per thread
  859. const int dr = (nr + nth - 1)/nth;
  860. // row range for this thread
  861. const int ir0 = dr*ith;
  862. const int ir1 = MIN(ir0 + dr, nr);
  863. for (int ir = ir0; ir < ir1; ++ir) {
  864. // src0 and dst are same shape => same indices
  865. const int i3 = ir/(ne2*ne1);
  866. const int i2 = (ir - i3*ne2*ne1)/ne1;
  867. const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
  868. ggml_bf16_t * dst_ptr = (ggml_bf16_t *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 );
  869. ggml_bf16_t * src0_ptr = (ggml_bf16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01);
  870. for (int i = 0; i < ne0; i++) {
  871. dst_ptr[i] = GGML_FP32_TO_BF16(GGML_BF16_TO_FP32(src0_ptr[i]) + v);
  872. }
  873. }
  874. }
  875. void ggml_compute_forward_add1(
  876. const ggml_compute_params * params,
  877. ggml_tensor * dst) {
  878. const ggml_tensor * src0 = dst->src[0];
  879. const ggml_tensor * src1 = dst->src[1];
  880. switch (src0->type) {
  881. case GGML_TYPE_F32:
  882. {
  883. ggml_compute_forward_add1_f32(params, dst);
  884. } break;
  885. case GGML_TYPE_F16:
  886. {
  887. if (src1->type == GGML_TYPE_F16) {
  888. ggml_compute_forward_add1_f16_f16(params, dst);
  889. }
  890. else if (src1->type == GGML_TYPE_F32) {
  891. ggml_compute_forward_add1_f16_f32(params, dst);
  892. }
  893. else {
  894. GGML_ABORT("fatal error");
  895. }
  896. } break;
  897. case GGML_TYPE_BF16:
  898. {
  899. if (src1->type == GGML_TYPE_BF16) {
  900. ggml_compute_forward_add1_bf16_bf16(params, dst);
  901. }
  902. else if (src1->type == GGML_TYPE_F32) {
  903. ggml_compute_forward_add1_bf16_f32(params, dst);
  904. }
  905. else {
  906. GGML_ABORT("fatal error");
  907. }
  908. } break;
  909. case GGML_TYPE_Q4_0:
  910. case GGML_TYPE_Q4_1:
  911. case GGML_TYPE_Q5_0:
  912. case GGML_TYPE_Q5_1:
  913. case GGML_TYPE_Q8_0:
  914. case GGML_TYPE_Q8_1:
  915. case GGML_TYPE_MXFP4:
  916. case GGML_TYPE_Q2_K:
  917. case GGML_TYPE_Q3_K:
  918. case GGML_TYPE_Q4_K:
  919. case GGML_TYPE_Q5_K:
  920. case GGML_TYPE_Q6_K:
  921. case GGML_TYPE_TQ1_0:
  922. case GGML_TYPE_TQ2_0:
  923. case GGML_TYPE_IQ2_XXS:
  924. case GGML_TYPE_IQ2_XS:
  925. case GGML_TYPE_IQ3_XXS:
  926. case GGML_TYPE_IQ1_S:
  927. case GGML_TYPE_IQ1_M:
  928. case GGML_TYPE_IQ4_NL:
  929. case GGML_TYPE_IQ4_XS:
  930. case GGML_TYPE_IQ3_S:
  931. case GGML_TYPE_IQ2_S:
  932. {
  933. ggml_compute_forward_add1_q_f32(params, dst);
  934. } break;
  935. default:
  936. {
  937. GGML_ABORT("fatal error");
  938. }
  939. }
  940. }
  941. // ggml_compute_forward_acc
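// ACC accumulates src1 into a strided view of src0/dst: op_params carry the view strides
// nb1/nb2/nb3, the byte offset and an "inplace" flag; when not inplace, thread 0 first copies
// src0 into dst and all threads synchronize on a barrier before accumulating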
  942. static void ggml_compute_forward_acc_f32(
  943. const ggml_compute_params * params,
  944. ggml_tensor * dst) {
  945. const ggml_tensor * src0 = dst->src[0];
  946. const ggml_tensor * src1 = dst->src[1];
  947. GGML_ASSERT(ggml_are_same_shape(src0, dst));
  948. GGML_ASSERT(ggml_is_contiguous(dst) && ggml_is_contiguous(src0));
949. // view src0 and dst with these strides and data offset in bytes during acc
  950. // nb0 is implicitly element_size because src0 and dst are contiguous
  951. size_t nb1 = ((int32_t *) dst->op_params)[0];
  952. size_t nb2 = ((int32_t *) dst->op_params)[1];
  953. size_t nb3 = ((int32_t *) dst->op_params)[2];
  954. size_t offset = ((int32_t *) dst->op_params)[3];
  955. bool inplace = (bool) ((int32_t *) dst->op_params)[4];
  956. if (!inplace) {
  957. if (params->ith == 0) {
  958. // memcpy needs to be synchronized across threads to avoid race conditions.
  959. // => do it in INIT phase
  960. memcpy(
  961. ((char *) dst->data),
  962. ((char *) src0->data),
  963. ggml_nbytes(dst));
  964. }
  965. ggml_barrier(params->threadpool);
  966. }
  967. const int ith = params->ith;
  968. const int nth = params->nth;
  969. const int nr = ggml_nrows(src1);
  970. const int nc = src1->ne[0];
  971. GGML_TENSOR_LOCALS(int64_t, ne1, src1, ne)
  972. GGML_TENSOR_LOCALS(size_t, nb1, src1, nb)
  973. // src0 and dst as viewed during acc
  974. const size_t nb0 = ggml_element_size(src0);
  975. const size_t nb00 = nb0;
  976. const size_t nb01 = nb1;
  977. const size_t nb02 = nb2;
  978. const size_t nb03 = nb3;
  979. GGML_ASSERT(offset + (ne10 == 0 ? 0 : ne10-1)*nb0 + (ne11 == 0 ? 0 : ne11-1)*nb1 + (ne12 == 0 ? 0 : ne12-1)*nb2 + (ne13 == 0 ? 0 : ne13-1)*nb3 < ggml_nbytes(dst));
  980. GGML_ASSERT(offset + (ne10 == 0 ? 0 : ne10-1)*nb00 + (ne11 == 0 ? 0 : ne11-1)*nb01 + (ne12 == 0 ? 0 : ne12-1)*nb02 + (ne13 == 0 ? 0 : ne13-1)*nb03 < ggml_nbytes(src0));
  981. GGML_ASSERT(nb10 == sizeof(float));
  982. // rows per thread
  983. const int dr = (nr + nth - 1)/nth;
  984. // row range for this thread
  985. const int ir0 = dr*ith;
  986. const int ir1 = MIN(ir0 + dr, nr);
  987. for (int ir = ir0; ir < ir1; ++ir) {
  988. // src0 and dst are viewed with shape of src1 and offset
  989. // => same indices
  990. const int i3 = ir/(ne12*ne11);
  991. const int i2 = (ir - i3*ne12*ne11)/ne11;
  992. const int i1 = (ir - i3*ne12*ne11 - i2*ne11);
  993. #ifdef GGML_USE_ACCELERATE
  994. vDSP_vadd(
  995. (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + offset), 1,
  996. (float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11), 1,
  997. (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + offset), 1, nc);
  998. #else
  999. ggml_vec_add_f32(nc,
  1000. (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + offset),
  1001. (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + offset),
  1002. (float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11));
  1003. #endif
  1004. }
  1005. }
  1006. void ggml_compute_forward_acc(
  1007. const ggml_compute_params * params,
  1008. ggml_tensor * dst) {
  1009. const ggml_tensor * src0 = dst->src[0];
  1010. switch (src0->type) {
  1011. case GGML_TYPE_F32:
  1012. {
  1013. ggml_compute_forward_acc_f32(params, dst);
  1014. } break;
  1015. case GGML_TYPE_F16:
  1016. case GGML_TYPE_BF16:
  1017. case GGML_TYPE_Q4_0:
  1018. case GGML_TYPE_Q4_1:
  1019. case GGML_TYPE_Q5_0:
  1020. case GGML_TYPE_Q5_1:
  1021. case GGML_TYPE_Q8_0:
  1022. case GGML_TYPE_Q8_1:
  1023. case GGML_TYPE_MXFP4:
  1024. case GGML_TYPE_Q2_K:
  1025. case GGML_TYPE_Q3_K:
  1026. case GGML_TYPE_Q4_K:
  1027. case GGML_TYPE_Q5_K:
  1028. case GGML_TYPE_Q6_K:
  1029. case GGML_TYPE_TQ1_0:
  1030. case GGML_TYPE_TQ2_0:
  1031. case GGML_TYPE_IQ2_XXS:
  1032. case GGML_TYPE_IQ2_XS:
  1033. case GGML_TYPE_IQ3_XXS:
  1034. case GGML_TYPE_IQ1_S:
  1035. case GGML_TYPE_IQ1_M:
  1036. case GGML_TYPE_IQ4_NL:
  1037. case GGML_TYPE_IQ4_XS:
  1038. case GGML_TYPE_IQ3_S:
  1039. case GGML_TYPE_IQ2_S:
  1040. default:
  1041. {
  1042. GGML_ABORT("fatal error");
  1043. }
  1044. }
  1045. }
  1046. // ggml_compute_forward_sum
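// SUM reduces the whole tensor to a scalar; the reductions below run on thread 0 only - the
// F32 variant accumulates in ggml_float (double precision), while the F16/BF16 variants
// accumulate row sums in float and convert the final result back to the storage type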
  1047. static void ggml_compute_forward_sum_f32(
  1048. const ggml_compute_params * params,
  1049. ggml_tensor * dst) {
  1050. const ggml_tensor * src0 = dst->src[0];
  1051. if (params->ith != 0) {
  1052. return;
  1053. }
  1054. assert(ggml_is_scalar(dst));
  1055. assert(src0->nb[0] == sizeof(float));
  1056. GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne)
  1057. GGML_TENSOR_LOCALS(size_t, nb0, src0, nb)
  1058. ggml_float sum = 0;
  1059. ggml_float row_sum = 0;
  1060. for (int64_t i03 = 0; i03 < ne03; i03++) {
  1061. for (int64_t i02 = 0; i02 < ne02; i02++) {
  1062. for (int64_t i01 = 0; i01 < ne01; i01++) {
  1063. ggml_vec_sum_f32_ggf(ne00,
  1064. &row_sum,
  1065. (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03));
  1066. sum += row_sum;
  1067. }
  1068. }
  1069. }
  1070. ((float *) dst->data)[0] = sum;
  1071. }
  1072. static void ggml_compute_forward_sum_f16(
  1073. const ggml_compute_params * params,
  1074. ggml_tensor * dst) {
  1075. const ggml_tensor * src0 = dst->src[0];
  1076. if (params->ith != 0) {
  1077. return;
  1078. }
  1079. assert(ggml_is_scalar(dst));
  1080. assert(src0->nb[0] == sizeof(ggml_fp16_t));
  1081. GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne)
  1082. GGML_TENSOR_LOCALS(size_t, nb0, src0, nb)
  1083. float sum = 0;
  1084. float row_sum = 0;
  1085. for (int64_t i03 = 0; i03 < ne03; i03++) {
  1086. for (int64_t i02 = 0; i02 < ne02; i02++) {
  1087. for (int64_t i01 = 0; i01 < ne01; i01++) {
  1088. ggml_vec_sum_f16_ggf(ne00,
  1089. &row_sum,
  1090. (ggml_fp16_t *) ((char *) src0->data + i01 * nb01 + i02 * nb02 + i03 * nb03));
  1091. sum += row_sum;
  1092. }
  1093. }
  1094. }
  1095. ((ggml_fp16_t *) dst->data)[0] = GGML_CPU_FP32_TO_FP16(sum);
  1096. }
  1097. static void ggml_compute_forward_sum_bf16(
  1098. const ggml_compute_params * params,
  1099. ggml_tensor * dst) {
  1100. const ggml_tensor * src0 = dst->src[0];
  1101. if (params->ith != 0) {
  1102. return;
  1103. }
  1104. assert(ggml_is_scalar(dst));
  1105. assert(src0->nb[0] == sizeof(ggml_bf16_t));
  1106. GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne)
  1107. GGML_TENSOR_LOCALS(size_t, nb0, src0, nb)
  1108. float sum = 0;
  1109. float row_sum = 0;
  1110. for (int64_t i03 = 0; i03 < ne03; i03++) {
  1111. for (int64_t i02 = 0; i02 < ne02; i02++) {
  1112. for (int64_t i01 = 0; i01 < ne01; i01++) {
  1113. ggml_vec_sum_bf16_ggf(ne00,
  1114. &row_sum,
  1115. (ggml_bf16_t *) ((char *) src0->data + i01 * nb01 + i02 * nb02 + i03 * nb03));
  1116. sum += row_sum;
  1117. }
  1118. }
  1119. }
  1120. ((ggml_bf16_t *) dst->data)[0] = GGML_FP32_TO_BF16(sum);
  1121. }
  1122. void ggml_compute_forward_sum(
  1123. const ggml_compute_params * params,
  1124. ggml_tensor * dst) {
  1125. const ggml_tensor * src0 = dst->src[0];
  1126. switch (src0->type) {
  1127. case GGML_TYPE_F32:
  1128. {
  1129. ggml_compute_forward_sum_f32(params, dst);
  1130. } break;
  1131. case GGML_TYPE_F16:
  1132. {
  1133. ggml_compute_forward_sum_f16(params, dst);
  1134. } break;
  1135. case GGML_TYPE_BF16:
  1136. {
  1137. ggml_compute_forward_sum_bf16(params, dst);
  1138. } break;
  1139. default:
  1140. {
  1141. GGML_ABORT("fatal error");
  1142. }
  1143. }
  1144. }
  1145. // ggml_compute_forward_sum_rows
  1146. static void ggml_compute_forward_sum_rows_f32(
  1147. const ggml_compute_params * params,
  1148. ggml_tensor * dst) {
  1149. const ggml_tensor * src0 = dst->src[0];
  1150. if (params->ith != 0) {
  1151. return;
  1152. }
  1153. GGML_ASSERT(src0->nb[0] == sizeof(float));
  1154. GGML_ASSERT(dst->nb[0] == sizeof(float));
  1155. GGML_TENSOR_UNARY_OP_LOCALS
  1156. GGML_ASSERT(ne0 == 1);
  1157. GGML_ASSERT(ne1 == ne01);
  1158. GGML_ASSERT(ne2 == ne02);
  1159. GGML_ASSERT(ne3 == ne03);
  1160. for (int64_t i3 = 0; i3 < ne03; i3++) {
  1161. for (int64_t i2 = 0; i2 < ne02; i2++) {
  1162. for (int64_t i1 = 0; i1 < ne01; i1++) {
  1163. float * src_row = (float *) ((char *) src0->data + i1*nb01 + i2*nb02 + i3*nb03);
  1164. float * dst_row = (float *) ((char *) dst->data + i1*nb1 + i2*nb2 + i3*nb3);
  1165. float row_sum = 0;
  1166. ggml_vec_sum_f32(ne00, &row_sum, src_row);
  1167. dst_row[0] = row_sum;
  1168. }
  1169. }
  1170. }
  1171. }
  1172. void ggml_compute_forward_sum_rows(
  1173. const ggml_compute_params * params,
  1174. ggml_tensor * dst) {
  1175. const ggml_tensor * src0 = dst->src[0];
  1176. switch (src0->type) {
  1177. case GGML_TYPE_F32:
  1178. {
  1179. ggml_compute_forward_sum_rows_f32(params, dst);
  1180. } break;
  1181. default:
  1182. {
  1183. GGML_ABORT("fatal error");
  1184. }
  1185. }
  1186. }
  1187. // ggml_compute_forward_mean
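// MEAN computes the per-row sum (as in SUM_ROWS) and divides it by ne00, the row length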
  1188. static void ggml_compute_forward_mean_f32(
  1189. const ggml_compute_params * params,
  1190. ggml_tensor * dst) {
  1191. const ggml_tensor * src0 = dst->src[0];
  1192. if (params->ith != 0) {
  1193. return;
  1194. }
  1195. assert(src0->nb[0] == sizeof(float));
  1196. GGML_TENSOR_UNARY_OP_LOCALS
  1197. assert(ne0 == 1);
  1198. assert(ne1 == ne01);
  1199. assert(ne2 == ne02);
  1200. assert(ne3 == ne03);
  1201. GGML_UNUSED(ne0);
  1202. GGML_UNUSED(ne1);
  1203. GGML_UNUSED(ne2);
  1204. GGML_UNUSED(ne3);
  1205. for (int64_t i03 = 0; i03 < ne03; i03++) {
  1206. for (int64_t i02 = 0; i02 < ne02; i02++) {
  1207. for (int64_t i01 = 0; i01 < ne01; i01++) {
  1208. ggml_vec_sum_f32(ne00,
  1209. (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3),
  1210. (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03));
  1211. *(float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3) /= (float) ne00;
  1212. }
  1213. }
  1214. }
  1215. }
  1216. void ggml_compute_forward_mean(
  1217. const ggml_compute_params * params,
  1218. ggml_tensor * dst) {
  1219. const ggml_tensor * src0 = dst->src[0];
  1220. switch (src0->type) {
  1221. case GGML_TYPE_F32:
  1222. {
  1223. ggml_compute_forward_mean_f32(params, dst);
  1224. } break;
  1225. default:
  1226. {
  1227. GGML_ABORT("fatal error");
  1228. }
  1229. }
  1230. }
  1231. // ggml_compute_forward_argmax
  1232. static void ggml_compute_forward_argmax_f32(
  1233. const ggml_compute_params * params,
  1234. ggml_tensor * dst) {
  1235. const ggml_tensor * src0 = dst->src[0];
  1236. if (params->ith != 0) {
  1237. return;
  1238. }
  1239. assert(src0->nb[0] == sizeof(float));
  1240. assert(dst->nb[0] == sizeof(float));
  1241. const int64_t ne00 = src0->ne[0];
  1242. const int64_t ne01 = src0->ne[1];
  1243. const size_t nb01 = src0->nb[1];
  1244. const size_t nb0 = dst->nb[0];
  1245. for (int64_t i1 = 0; i1 < ne01; i1++) {
  1246. float * src = (float *) ((char *) src0->data + i1*nb01);
  1247. int32_t * dst_ = (int32_t *) ((char *) dst->data + i1*nb0);
  1248. int v = 0;
  1249. ggml_vec_argmax_f32(ne00, &v, src);
  1250. dst_[0] = v;
  1251. }
  1252. }
  1253. void ggml_compute_forward_argmax(
  1254. const ggml_compute_params * params,
  1255. ggml_tensor * dst) {
  1256. const ggml_tensor * src0 = dst->src[0];
  1257. switch (src0->type) {
  1258. case GGML_TYPE_F32:
  1259. {
  1260. ggml_compute_forward_argmax_f32(params, dst);
  1261. } break;
  1262. default:
  1263. {
  1264. GGML_ABORT("fatal error");
  1265. }
  1266. }
  1267. }
  1268. // ggml_compute_forward_count_equal
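// COUNT_EQUAL compares two I32 tensors element-wise and writes the number of matches into an
// I64 scalar: each thread counts over its own row range, non-zero threads park their partial
// sums in params->wdata, and after the barrier thread 0 adds them up and stores the total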
  1269. static void ggml_compute_forward_count_equal_i32(
  1270. const ggml_compute_params * params,
  1271. ggml_tensor * dst) {
  1272. const ggml_tensor * src0 = dst->src[0];
  1273. const ggml_tensor * src1 = dst->src[1];
  1274. GGML_TENSOR_BINARY_OP_LOCALS;
  1275. GGML_ASSERT(src0->type == GGML_TYPE_I32);
  1276. GGML_ASSERT(src1->type == GGML_TYPE_I32);
  1277. GGML_ASSERT(ggml_are_same_shape(src0, src1));
  1278. GGML_ASSERT(ggml_is_scalar(dst));
  1279. GGML_ASSERT(dst->type == GGML_TYPE_I64);
  1280. const int64_t nr = ggml_nrows(src0);
  1281. const int ith = params->ith;
  1282. const int nth = params->nth;
  1283. int64_t * sums = (int64_t *) params->wdata;
  1284. int64_t sum_thread = 0;
  1285. // rows per thread
  1286. const int64_t dr = (nr + nth - 1)/nth;
  1287. // row range for this thread
  1288. const int64_t ir0 = dr*ith;
  1289. const int64_t ir1 = MIN(ir0 + dr, nr);
  1290. for (int64_t ir = ir0; ir < ir1; ++ir) {
  1291. const int64_t i03 = ir / (ne02*ne01);
1292. const int64_t i02 = (ir - i03*ne02*ne01) / ne01;
1293. const int64_t i01 = ir - i03*ne02*ne01 - i02*ne01;
  1294. const char * data0 = (const char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01;
  1295. const char * data1 = (const char *) src1->data + i03*nb13 + i02*nb12 + i01*nb11;
  1296. for (int64_t i00 = 0; i00 < ne00; ++i00) {
  1297. const int32_t val0 = *((const int32_t *) (data0 + i00*nb00));
  1298. const int32_t val1 = *((const int32_t *) (data1 + i00*nb10));
  1299. sum_thread += val0 == val1;
  1300. }
  1301. }
  1302. if (ith != 0) {
  1303. sums[ith] = sum_thread;
  1304. }
  1305. ggml_barrier(params->threadpool);
  1306. if (ith != 0) {
  1307. return;
  1308. }
  1309. for (int ith_other = 1; ith_other < nth; ++ith_other) {
  1310. sum_thread += sums[ith_other];
  1311. }
  1312. *((int64_t *) dst->data) = sum_thread;
  1313. }
  1314. void ggml_compute_forward_count_equal(
  1315. const ggml_compute_params * params,
  1316. ggml_tensor * dst) {
  1317. const ggml_tensor * src0 = dst->src[0];
  1318. switch (src0->type) {
  1319. case GGML_TYPE_I32:
  1320. {
  1321. ggml_compute_forward_count_equal_i32(params, dst);
  1322. } break;
  1323. default:
  1324. {
  1325. GGML_ABORT("fatal error");
  1326. }
  1327. }
  1328. }
  1329. // ggml_compute_forward_repeat
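// REPEAT tiles src0 across dst; the repeat factors nr0..nr3 = ne{0..3}/ne0{0..3} are integral
// by construction (ggml_can_repeat) - e.g. repeating a [2,3,1,1] tensor into a [4,3,1,1] dst
// gives nr0 = 2, so each source row is copied twice along dim 0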
  1330. static void ggml_compute_forward_repeat_f32(
  1331. const ggml_compute_params * params,
  1332. ggml_tensor * dst) {
  1333. const ggml_tensor * src0 = dst->src[0];
  1334. if (params->ith != 0) {
  1335. return;
  1336. }
  1337. GGML_ASSERT(ggml_can_repeat(src0, dst));
  1338. GGML_TENSOR_UNARY_OP_LOCALS
  1339. // guaranteed to be an integer due to the check in ggml_can_repeat
  1340. const int nr0 = (int)(ne0/ne00);
  1341. const int nr1 = (int)(ne1/ne01);
  1342. const int nr2 = (int)(ne2/ne02);
  1343. const int nr3 = (int)(ne3/ne03);
  1344. // TODO: support for transposed / permuted tensors
  1345. GGML_ASSERT(nb0 == sizeof(float));
  1346. GGML_ASSERT(nb00 == sizeof(float));
  1347. // TODO: maybe this is not optimal?
  1348. for (int i3 = 0; i3 < nr3; i3++) {
  1349. for (int k3 = 0; k3 < ne03; k3++) {
  1350. for (int i2 = 0; i2 < nr2; i2++) {
  1351. for (int k2 = 0; k2 < ne02; k2++) {
  1352. for (int i1 = 0; i1 < nr1; i1++) {
  1353. for (int k1 = 0; k1 < ne01; k1++) {
  1354. for (int i0 = 0; i0 < nr0; i0++) {
  1355. ggml_vec_cpy_f32(ne00,
  1356. (float *) ((char *) dst->data + (i3*ne03 + k3)*nb3 + (i2*ne02 + k2)*nb2 + (i1*ne01 + k1)*nb1 + (i0*ne00)*nb0),
  1357. (float *) ((char *) src0->data + ( k3)*nb03 + ( k2)*nb02 + ( k1)*nb01));
  1358. }
  1359. }
  1360. }
  1361. }
  1362. }
  1363. }
  1364. }
  1365. }
  1366. static void ggml_compute_forward_repeat_f16(
  1367. const ggml_compute_params * params,
  1368. ggml_tensor * dst) {
  1369. const ggml_tensor * src0 = dst->src[0];
  1370. if (params->ith != 0) {
  1371. return;
  1372. }
  1373. GGML_ASSERT(ggml_can_repeat(src0, dst));
  1374. GGML_TENSOR_UNARY_OP_LOCALS
  1375. // guaranteed to be an integer due to the check in ggml_can_repeat
  1376. const int nr0 = (int)(ne0/ne00);
  1377. const int nr1 = (int)(ne1/ne01);
  1378. const int nr2 = (int)(ne2/ne02);
  1379. const int nr3 = (int)(ne3/ne03);
  1380. // TODO: support for transposed / permuted tensors
  1381. GGML_ASSERT(nb0 == sizeof(ggml_fp16_t));
  1382. GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
  1383. // TODO: maybe this is not optimal?
  1384. for (int i3 = 0; i3 < nr3; i3++) {
  1385. for (int k3 = 0; k3 < ne03; k3++) {
  1386. for (int i2 = 0; i2 < nr2; i2++) {
  1387. for (int k2 = 0; k2 < ne02; k2++) {
  1388. for (int i1 = 0; i1 < nr1; i1++) {
  1389. for (int k1 = 0; k1 < ne01; k1++) {
  1390. for (int i0 = 0; i0 < nr0; i0++) {
  1391. ggml_fp16_t * y = (ggml_fp16_t *) ((char *) dst->data + (i3*ne03 + k3)*nb3 + (i2*ne02 + k2)*nb2 + (i1*ne01 + k1)*nb1 + (i0*ne00)*nb0);
  1392. ggml_fp16_t * x = (ggml_fp16_t *) ((char *) src0->data + ( k3)*nb03 + ( k2)*nb02 + ( k1)*nb01);
  1393. // ggml_vec_cpy_f16(ne00, y, x)
  1394. for (int i = 0; i < ne00; ++i) {
  1395. y[i] = x[i];
  1396. }
  1397. }
  1398. }
  1399. }
  1400. }
  1401. }
  1402. }
  1403. }
  1404. }
  1405. void ggml_compute_forward_repeat(
  1406. const ggml_compute_params * params,
  1407. ggml_tensor * dst) {
  1408. const ggml_tensor * src0 = dst->src[0];
  1409. switch (src0->type) {
  1410. case GGML_TYPE_F16:
  1411. case GGML_TYPE_BF16:
  1412. case GGML_TYPE_I16:
  1413. {
  1414. ggml_compute_forward_repeat_f16(params, dst);
  1415. } break;
  1416. case GGML_TYPE_F32:
  1417. case GGML_TYPE_I32:
  1418. {
  1419. ggml_compute_forward_repeat_f32(params, dst);
  1420. } break;
1421. // TODO: templatify the implementation and add support for I64
  1422. // ref https://github.com/ggml-org/llama.cpp/pull/14274#discussion_r2169492225
  1423. //case GGML_TYPE_I64:
  1424. // {
  1425. // ggml_compute_forward_repeat_i64(params, dst);
  1426. // } break;
  1427. default:
  1428. {
  1429. GGML_ABORT("fatal error");
  1430. }
  1431. }
  1432. }
  1433. // ggml_compute_forward_repeat_back
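// REPEAT_BACK is the reverse of REPEAT (used for gradients): dst is zeroed first, then every
// repeated tile of src0 is accumulated back into the corresponding region of dst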
  1434. static void ggml_compute_forward_repeat_back_f32(
  1435. const ggml_compute_params * params,
  1436. ggml_tensor * dst) {
  1437. const ggml_tensor * src0 = dst->src[0];
  1438. if (params->ith != 0) {
  1439. return;
  1440. }
  1441. GGML_ASSERT(ggml_can_repeat(dst, src0));
  1442. GGML_TENSOR_UNARY_OP_LOCALS
  1443. // guaranteed to be an integer due to the check in ggml_can_repeat
  1444. const int nr0 = (int)(ne00/ne0);
  1445. const int nr1 = (int)(ne01/ne1);
  1446. const int nr2 = (int)(ne02/ne2);
  1447. const int nr3 = (int)(ne03/ne3);
  1448. // TODO: support for transposed / permuted tensors
  1449. GGML_ASSERT(nb0 == sizeof(float));
  1450. GGML_ASSERT(nb00 == sizeof(float));
  1451. if (ggml_is_contiguous(dst)) {
  1452. ggml_vec_set_f32(ne0*ne1*ne2*ne3, (float *)dst->data, 0);
  1453. } else {
  1454. for (int k3 = 0; k3 < ne3; k3++) {
  1455. for (int k2 = 0; k2 < ne2; k2++) {
  1456. for (int k1 = 0; k1 < ne1; k1++) {
  1457. ggml_vec_set_f32(ne0,
  1458. (float *) ((char *) dst->data + k1*nb1 + k2*nb2 + k3*nb3),
  1459. 0);
  1460. }
  1461. }
  1462. }
  1463. }
  1464. // TODO: maybe this is not optimal?
  1465. for (int i3 = 0; i3 < nr3; i3++) {
  1466. for (int k3 = 0; k3 < ne3; k3++) {
  1467. for (int i2 = 0; i2 < nr2; i2++) {
  1468. for (int k2 = 0; k2 < ne2; k2++) {
  1469. for (int i1 = 0; i1 < nr1; i1++) {
  1470. for (int k1 = 0; k1 < ne1; k1++) {
  1471. for (int i0 = 0; i0 < nr0; i0++) {
  1472. ggml_vec_acc_f32(ne0,
  1473. (float *) ((char *) dst->data + ( k3)*nb3 + ( k2)*nb2 + ( k1)*nb1),
  1474. (float *) ((char *) src0->data + (i3*ne3 + k3)*nb03 + (i2*ne2 + k2)*nb02 + (i1*ne1 + k1)*nb01 + (i0*ne0)*nb00));
  1475. }
  1476. }
  1477. }
  1478. }
  1479. }
  1480. }
  1481. }
  1482. }
  1483. void ggml_compute_forward_repeat_back(
  1484. const ggml_compute_params * params,
  1485. ggml_tensor * dst) {
  1486. const ggml_tensor * src0 = dst->src[0];
  1487. switch (src0->type) {
  1488. case GGML_TYPE_F32:
  1489. {
  1490. ggml_compute_forward_repeat_back_f32(params, dst);
  1491. } break;
  1492. default:
  1493. {
  1494. GGML_ABORT("fatal error");
  1495. }
  1496. }
  1497. }
  1498. // ggml_compute_forward_concat
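// CONCAT joins src0 and src1 along the dimension stored in op_params[0]; the offset vector o[]
// shifts indices that fall past src0 into src1's coordinate space - type-specialized copies
// handle 4-, 2- and 1-byte element types, and ggml_compute_forward_concat_any falls back to a
// per-element memcpy of ggml_type_size(src0->type) bytes for everything else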
  1499. static void ggml_compute_forward_concat_any(
  1500. const ggml_compute_params * params,
  1501. ggml_tensor * dst) {
  1502. const ggml_tensor * src0 = dst->src[0];
  1503. const ggml_tensor * src1 = dst->src[1];
  1504. const size_t len = ggml_type_size(src0->type);
  1505. const int ith = params->ith;
  1506. const int nth = params->nth;
  1507. GGML_TENSOR_BINARY_OP_LOCALS
  1508. const int32_t dim = ggml_get_op_params_i32(dst, 0);
  1509. GGML_ASSERT(dim >= 0 && dim < 4);
  1510. int64_t o[4] = {0, 0, 0, 0};
  1511. o[dim] = src0->ne[dim];
  1512. const char * x;
1513. // TODO: smarter multi-threading
  1514. for (int i3 = 0; i3 < ne3; i3++) {
  1515. for (int i2 = ith; i2 < ne2; i2 += nth) {
  1516. for (int i1 = 0; i1 < ne1; i1++) {
  1517. for (int i0 = 0; i0 < ne0; i0++) {
  1518. if (i0 < ne00 && i1 < ne01 && i2 < ne02 && i3 < ne03) {
  1519. x = (const char *)src0->data + (i0 )*nb00 + (i1 )*nb01 + (i2 )*nb02 + (i3 )*nb03;
  1520. } else {
  1521. x = (const char *)src1->data + (i0 - o[0])*nb10 + (i1 - o[1])*nb11 + (i2 - o[2])*nb12 + (i3 - o[3])*nb13;
  1522. }
  1523. char * y = (char *)dst->data + i0*nb0 + i1*nb1 + i2*nb2 + i3*nb3;
  1524. memcpy(y, x, len);
  1525. }
  1526. }
  1527. }
  1528. }
  1529. }
  1530. static void ggml_compute_forward_concat_i8(
  1531. const ggml_compute_params * params,
  1532. ggml_tensor * dst) {
  1533. const ggml_tensor * src0 = dst->src[0];
  1534. const ggml_tensor * src1 = dst->src[1];
  1535. GGML_ASSERT(ggml_type_size(src0->type) == sizeof(int8_t));
  1536. const int ith = params->ith;
  1537. const int nth = params->nth;
  1538. GGML_TENSOR_BINARY_OP_LOCALS
  1539. const int32_t dim = ggml_get_op_params_i32(dst, 0);
  1540. GGML_ASSERT(dim >= 0 && dim < 4);
  1541. int64_t o[4] = {0, 0, 0, 0};
  1542. o[dim] = src0->ne[dim];
  1543. const int8_t * x;
1544. // TODO: smarter multi-threading
  1545. for (int i3 = 0; i3 < ne3; i3++) {
  1546. for (int i2 = ith; i2 < ne2; i2 += nth) {
  1547. for (int i1 = 0; i1 < ne1; i1++) {
  1548. for (int i0 = 0; i0 < ne0; i0++) {
  1549. if (i0 < ne00 && i1 < ne01 && i2 < ne02 && i3 < ne03) {
  1550. x = (const int8_t *) ((const char *)src0->data + (i0 )*nb00 + (i1 )*nb01 + (i2 )*nb02 + (i3 )*nb03);
  1551. } else {
  1552. x = (const int8_t *) ((const char *)src1->data + (i0 - o[0])*nb10 + (i1 - o[1])*nb11 + (i2 - o[2])*nb12 + (i3 - o[3])*nb13);
  1553. }
  1554. int8_t * y = (int8_t *)((char *)dst->data + i0*nb0 + i1*nb1 + i2*nb2 + i3*nb3);
  1555. *y = *x;
  1556. }
  1557. }
  1558. }
  1559. }
  1560. }
  1561. static void ggml_compute_forward_concat_f16(
  1562. const ggml_compute_params * params,
  1563. ggml_tensor * dst) {
  1564. const ggml_tensor * src0 = dst->src[0];
  1565. const ggml_tensor * src1 = dst->src[1];
  1566. GGML_ASSERT(ggml_type_size(src0->type) == sizeof(ggml_fp16_t));
  1567. const int ith = params->ith;
  1568. const int nth = params->nth;
  1569. GGML_TENSOR_BINARY_OP_LOCALS
  1570. const int32_t dim = ggml_get_op_params_i32(dst, 0);
  1571. GGML_ASSERT(dim >= 0 && dim < 4);
  1572. int64_t o[4] = {0, 0, 0, 0};
  1573. o[dim] = src0->ne[dim];
  1574. const ggml_fp16_t * x;
1575. // TODO: smarter multi-threading
  1576. for (int i3 = 0; i3 < ne3; i3++) {
  1577. for (int i2 = ith; i2 < ne2; i2 += nth) {
  1578. for (int i1 = 0; i1 < ne1; i1++) {
  1579. for (int i0 = 0; i0 < ne0; i0++) {
  1580. if (i0 < ne00 && i1 < ne01 && i2 < ne02 && i3 < ne03) {
  1581. x = (const ggml_fp16_t *) ((const char *)src0->data + (i0 )*nb00 + (i1 )*nb01 + (i2 )*nb02 + (i3 )*nb03);
  1582. } else {
  1583. x = (const ggml_fp16_t *) ((const char *)src1->data + (i0 - o[0])*nb10 + (i1 - o[1])*nb11 + (i2 - o[2])*nb12 + (i3 - o[3])*nb13);
  1584. }
  1585. ggml_fp16_t * y = (ggml_fp16_t *)((char *)dst->data + i0*nb0 + i1*nb1 + i2*nb2 + i3*nb3);
  1586. *y = *x;
  1587. }
  1588. }
  1589. }
  1590. }
  1591. }
  1592. static void ggml_compute_forward_concat_f32(
  1593. const ggml_compute_params * params,
  1594. ggml_tensor * dst) {
  1595. const ggml_tensor * src0 = dst->src[0];
  1596. const ggml_tensor * src1 = dst->src[1];
  1597. GGML_ASSERT(ggml_type_size(src0->type) == sizeof(float));
  1598. const int ith = params->ith;
  1599. const int nth = params->nth;
  1600. GGML_TENSOR_BINARY_OP_LOCALS
  1601. const int32_t dim = ggml_get_op_params_i32(dst, 0);
  1602. GGML_ASSERT(dim >= 0 && dim < 4);
  1603. int64_t o[4] = {0, 0, 0, 0};
  1604. o[dim] = src0->ne[dim];
  1605. const float * x;
1606. // TODO: smarter multi-threading
  1607. for (int i3 = 0; i3 < ne3; i3++) {
  1608. for (int i2 = ith; i2 < ne2; i2 += nth) {
  1609. for (int i1 = 0; i1 < ne1; i1++) {
  1610. for (int i0 = 0; i0 < ne0; i0++) {
  1611. if (i0 < ne00 && i1 < ne01 && i2 < ne02 && i3 < ne03) {
  1612. x = (const float *) ((const char *)src0->data + (i0 )*nb00 + (i1 )*nb01 + (i2 )*nb02 + (i3 )*nb03);
  1613. } else {
  1614. x = (const float *) ((const char *)src1->data + (i0 - o[0])*nb10 + (i1 - o[1])*nb11 + (i2 - o[2])*nb12 + (i3 - o[3])*nb13);
  1615. }
  1616. float * y = (float *)((char *)dst->data + i0*nb0 + i1*nb1 + i2*nb2 + i3*nb3);
  1617. *y = *x;
  1618. }
  1619. }
  1620. }
  1621. }
  1622. }
  1623. void ggml_compute_forward_concat(
  1624. const ggml_compute_params * params,
  1625. ggml_tensor * dst) {
  1626. const ggml_tensor * src0 = dst->src[0];
  1627. switch (src0->type) {
  1628. case GGML_TYPE_F16:
  1629. case GGML_TYPE_BF16:
  1630. case GGML_TYPE_I16:
  1631. {
  1632. ggml_compute_forward_concat_f16(params, dst);
  1633. } break;
  1634. case GGML_TYPE_I8:
  1635. {
  1636. ggml_compute_forward_concat_i8(params, dst);
  1637. } break;
  1638. case GGML_TYPE_F32:
  1639. case GGML_TYPE_I32:
  1640. {
  1641. ggml_compute_forward_concat_f32(params, dst);
  1642. } break;
  1643. default:
  1644. {
  1645. ggml_compute_forward_concat_any(params, dst);
  1646. }
  1647. }
  1648. }
  1649. // ggml_compute_forward_gelu
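// GELU and the activation kernels that follow (gelu_erf, gelu_quick, silu) share one pattern:
// rows are contiguous in dim 0, the row range is split across threads with the usual
// ceil-divide, and a per-element ggml_vec_* routine does the work; debug builds additionally
// check every produced row for NaN/Inf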
  1650. static void ggml_compute_forward_gelu_f32(
  1651. const ggml_compute_params * params,
  1652. ggml_tensor * dst) {
  1653. const ggml_tensor * src0 = dst->src[0];
  1654. assert(ggml_is_contiguous_1(src0));
  1655. assert(ggml_is_contiguous_1(dst));
  1656. assert(ggml_are_same_shape(src0, dst));
  1657. const int ith = params->ith;
  1658. const int nth = params->nth;
  1659. const int nc = src0->ne[0];
  1660. const int nr = ggml_nrows(src0);
  1661. // rows per thread
  1662. const int dr = (nr + nth - 1)/nth;
  1663. // row range for this thread
  1664. const int ir0 = dr*ith;
  1665. const int ir1 = MIN(ir0 + dr, nr);
  1666. for (int i1 = ir0; i1 < ir1; i1++) {
  1667. ggml_vec_gelu_f32(nc,
  1668. (float *) ((char *) dst->data + i1*( dst->nb[1])),
  1669. (float *) ((char *) src0->data + i1*(src0->nb[1])));
  1670. #ifndef NDEBUG
  1671. for (int k = 0; k < nc; k++) {
  1672. const float x = ((float *) ((char *) dst->data + i1*( dst->nb[1])))[k];
  1673. GGML_UNUSED(x);
  1674. assert(!isnan(x));
  1675. assert(!isinf(x));
  1676. }
  1677. #endif
  1678. }
  1679. }
  1680. static void ggml_compute_forward_gelu_f16(
  1681. const ggml_compute_params * params,
  1682. ggml_tensor * dst) {
  1683. const ggml_tensor * src0 = dst->src[0];
  1684. assert(ggml_is_contiguous_1(src0));
  1685. assert(ggml_is_contiguous_1(dst));
  1686. assert(ggml_are_same_shape(src0, dst));
  1687. const int ith = params->ith;
  1688. const int nth = params->nth;
  1689. const int nc = src0->ne[0];
  1690. const int nr = ggml_nrows(src0);
  1691. // rows per thread
  1692. const int dr = (nr + nth - 1)/nth;
  1693. // row range for this thread
  1694. const int ir0 = dr*ith;
  1695. const int ir1 = MIN(ir0 + dr, nr);
  1696. for (int i1 = ir0; i1 < ir1; i1++) {
  1697. ggml_vec_gelu_f16(nc,
  1698. (ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])),
  1699. (ggml_fp16_t *) ((char *) src0->data + i1*(src0->nb[1])));
  1700. #ifndef NDEBUG
  1701. for (int k = 0; k < nc; k++) {
  1702. const ggml_fp16_t x = ((ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])))[k];
  1703. const float v = GGML_CPU_FP16_TO_FP32(x);
  1704. GGML_UNUSED(v);
  1705. assert(!isnan(v));
  1706. assert(!isinf(v));
  1707. }
  1708. #endif
  1709. }
  1710. }
  1711. static void ggml_compute_forward_gelu(
  1712. const ggml_compute_params * params,
  1713. ggml_tensor * dst) {
  1714. const ggml_tensor * src0 = dst->src[0];
  1715. switch (src0->type) {
  1716. case GGML_TYPE_F32:
  1717. {
  1718. ggml_compute_forward_gelu_f32(params, dst);
  1719. } break;
  1720. case GGML_TYPE_F16:
  1721. {
  1722. ggml_compute_forward_gelu_f16(params, dst);
  1723. } break;
  1724. default:
  1725. {
  1726. GGML_ABORT("fatal error");
  1727. }
  1728. }
  1729. }
  1730. // ggml_compute_forward_gelu_erf
  1731. static void ggml_compute_forward_gelu_erf_f32(
  1732. const ggml_compute_params * params,
  1733. ggml_tensor * dst) {
  1734. const ggml_tensor * src0 = dst->src[0];
  1735. assert(ggml_is_contiguous_1(src0));
  1736. assert(ggml_is_contiguous_1(dst));
  1737. assert(ggml_are_same_shape(src0, dst));
  1738. const int ith = params->ith;
  1739. const int nth = params->nth;
  1740. const int nc = src0->ne[0];
  1741. const int nr = ggml_nrows(src0);
  1742. // rows per thread
  1743. const int dr = (nr + nth - 1)/nth;
  1744. // row range for this thread
  1745. const int ir0 = dr*ith;
  1746. const int ir1 = MIN(ir0 + dr, nr);
  1747. for (int i1 = ir0; i1 < ir1; i1++) {
  1748. ggml_vec_gelu_erf_f32(nc,
  1749. (float *) ((char *) dst->data + i1*( dst->nb[1])),
  1750. (float *) ((char *) src0->data + i1*(src0->nb[1])));
  1751. #ifndef NDEBUG
  1752. for (int k = 0; k < nc; k++) {
  1753. const float x = ((float *) ((char *) dst->data + i1*( dst->nb[1])))[k];
  1754. GGML_UNUSED(x);
  1755. assert(!isnan(x));
  1756. assert(!isinf(x));
  1757. }
  1758. #endif
  1759. }
  1760. }
  1761. static void ggml_compute_forward_gelu_erf_f16(
  1762. const ggml_compute_params * params,
  1763. ggml_tensor * dst) {
  1764. const ggml_tensor * src0 = dst->src[0];
  1765. assert(ggml_is_contiguous_1(src0));
  1766. assert(ggml_is_contiguous_1(dst));
  1767. assert(ggml_are_same_shape(src0, dst));
  1768. const int ith = params->ith;
  1769. const int nth = params->nth;
  1770. const int nc = src0->ne[0];
  1771. const int nr = ggml_nrows(src0);
  1772. // rows per thread
  1773. const int dr = (nr + nth - 1)/nth;
  1774. // row range for this thread
  1775. const int ir0 = dr*ith;
  1776. const int ir1 = MIN(ir0 + dr, nr);
  1777. for (int i1 = ir0; i1 < ir1; i1++) {
  1778. ggml_vec_gelu_erf_f16(nc,
  1779. (ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])),
  1780. (ggml_fp16_t *) ((char *) src0->data + i1*(src0->nb[1])));
  1781. #ifndef NDEBUG
  1782. for (int k = 0; k < nc; k++) {
  1783. const ggml_fp16_t x = ((ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])))[k];
  1784. const float v = GGML_CPU_FP16_TO_FP32(x);
  1785. GGML_UNUSED(v);
  1786. assert(!isnan(v));
  1787. assert(!isinf(v));
  1788. }
  1789. #endif
  1790. }
  1791. }
  1792. static void ggml_compute_forward_gelu_erf(
  1793. const ggml_compute_params * params,
  1794. ggml_tensor * dst) {
  1795. const ggml_tensor * src0 = dst->src[0];
  1796. switch (src0->type) {
  1797. case GGML_TYPE_F32:
  1798. {
  1799. ggml_compute_forward_gelu_erf_f32(params, dst);
  1800. } break;
  1801. case GGML_TYPE_F16:
  1802. {
  1803. ggml_compute_forward_gelu_erf_f16(params, dst);
  1804. } break;
  1805. default:
  1806. {
  1807. GGML_ABORT("fatal error");
  1808. }
  1809. }
  1810. }
  1811. // ggml_compute_forward_gelu_quick
  1812. static void ggml_compute_forward_gelu_quick_f32(
  1813. const ggml_compute_params * params,
  1814. ggml_tensor * dst) {
  1815. const ggml_tensor * src0 = dst->src[0];
  1816. assert(ggml_is_contiguous_1(src0));
  1817. assert(ggml_is_contiguous_1(dst));
  1818. assert(ggml_are_same_shape(src0, dst));
  1819. const int ith = params->ith;
  1820. const int nth = params->nth;
  1821. const int nc = src0->ne[0];
  1822. const int nr = ggml_nrows(src0);
  1823. // rows per thread
  1824. const int dr = (nr + nth - 1)/nth;
  1825. // row range for this thread
  1826. const int ir0 = dr*ith;
  1827. const int ir1 = MIN(ir0 + dr, nr);
  1828. for (int i1 = ir0; i1 < ir1; i1++) {
  1829. ggml_vec_gelu_quick_f32(nc,
  1830. (float *) ((char *) dst->data + i1*( dst->nb[1])),
  1831. (float *) ((char *) src0->data + i1*(src0->nb[1])));
  1832. #ifndef NDEBUG
  1833. for (int k = 0; k < nc; k++) {
  1834. const float x = ((float *) ((char *) dst->data + i1*( dst->nb[1])))[k];
  1835. GGML_UNUSED(x);
  1836. assert(!isnan(x));
  1837. assert(!isinf(x));
  1838. }
  1839. #endif
  1840. }
  1841. }
  1842. static void ggml_compute_forward_gelu_quick_f16(
  1843. const ggml_compute_params * params,
  1844. ggml_tensor * dst) {
  1845. const ggml_tensor * src0 = dst->src[0];
  1846. assert(ggml_is_contiguous_1(src0));
  1847. assert(ggml_is_contiguous_1(dst));
  1848. assert(ggml_are_same_shape(src0, dst));
  1849. const int ith = params->ith;
  1850. const int nth = params->nth;
  1851. const int nc = src0->ne[0];
  1852. const int nr = ggml_nrows(src0);
  1853. // rows per thread
  1854. const int dr = (nr + nth - 1)/nth;
  1855. // row range for this thread
  1856. const int ir0 = dr*ith;
  1857. const int ir1 = MIN(ir0 + dr, nr);
  1858. for (int i1 = ir0; i1 < ir1; i1++) {
  1859. ggml_vec_gelu_quick_f16(nc,
  1860. (ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])),
  1861. (ggml_fp16_t *) ((char *) src0->data + i1*(src0->nb[1])));
  1862. #ifndef NDEBUG
  1863. for (int k = 0; k < nc; k++) {
  1864. const ggml_fp16_t x = ((ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])))[k];
  1865. const float v = GGML_CPU_FP16_TO_FP32(x);
  1866. GGML_UNUSED(v);
  1867. assert(!isnan(v));
  1868. assert(!isinf(v));
  1869. }
  1870. #endif
  1871. }
  1872. }
  1873. static void ggml_compute_forward_gelu_quick(
  1874. const ggml_compute_params * params,
  1875. ggml_tensor * dst) {
  1876. const ggml_tensor * src0 = dst->src[0];
  1877. switch (src0->type) {
  1878. case GGML_TYPE_F32:
  1879. {
  1880. ggml_compute_forward_gelu_quick_f32(params, dst);
  1881. } break;
  1882. case GGML_TYPE_F16:
  1883. {
  1884. ggml_compute_forward_gelu_quick_f16(params, dst);
  1885. } break;
  1886. default:
  1887. {
  1888. GGML_ABORT("fatal error");
  1889. }
  1890. }
  1891. }
  1892. // ggml_compute_forward_silu
  1893. static void ggml_compute_forward_silu_f32(
  1894. const ggml_compute_params * params,
  1895. ggml_tensor * dst) {
  1896. const ggml_tensor * src0 = dst->src[0];
  1897. assert(ggml_is_contiguous_1(src0));
  1898. assert(ggml_is_contiguous_1(dst));
  1899. assert(ggml_are_same_shape(src0, dst));
  1900. const int ith = params->ith;
  1901. const int nth = params->nth;
  1902. const int nc = src0->ne[0];
  1903. const int nr = ggml_nrows(src0);
  1904. // rows per thread
  1905. const int dr = (nr + nth - 1)/nth;
  1906. // row range for this thread
  1907. const int ir0 = dr*ith;
  1908. const int ir1 = MIN(ir0 + dr, nr);
  1909. for (int i1 = ir0; i1 < ir1; i1++) {
  1910. ggml_vec_silu_f32(nc,
  1911. (float *) ((char *) dst->data + i1*( dst->nb[1])),
  1912. (float *) ((char *) src0->data + i1*(src0->nb[1])));
  1913. #ifndef NDEBUG
  1914. for (int k = 0; k < nc; k++) {
  1915. const float x = ((float *) ((char *) dst->data + i1*(dst->nb[1])))[k];
  1916. GGML_UNUSED(x);
  1917. assert(!isnan(x));
  1918. assert(!isinf(x));
  1919. }
  1920. #endif
  1921. }
  1922. }
  1923. static void ggml_compute_forward_silu_f16(
  1924. const ggml_compute_params * params,
  1925. ggml_tensor * dst) {
  1926. const ggml_tensor * src0 = dst->src[0];
  1927. assert(ggml_is_contiguous_1(src0));
  1928. assert(ggml_is_contiguous_1(dst));
  1929. assert(ggml_are_same_shape(src0, dst));
  1930. const int ith = params->ith;
  1931. const int nth = params->nth;
  1932. const int nc = src0->ne[0];
  1933. const int nr = ggml_nrows(src0);
  1934. // rows per thread
  1935. const int dr = (nr + nth - 1)/nth;
  1936. // row range for this thread
  1937. const int ir0 = dr*ith;
  1938. const int ir1 = MIN(ir0 + dr, nr);
  1939. for (int i1 = ir0; i1 < ir1; i1++) {
  1940. ggml_vec_silu_f16(nc,
  1941. (ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])),
  1942. (ggml_fp16_t *) ((char *) src0->data + i1*(src0->nb[1])));
  1943. #ifndef NDEBUG
  1944. for (int k = 0; k < nc; k++) {
  1945. const ggml_fp16_t x = ((ggml_fp16_t *) ((char *) dst->data + i1*(dst->nb[1])))[k];
  1946. const float v = GGML_CPU_FP16_TO_FP32(x);
  1947. GGML_UNUSED(v);
  1948. assert(!isnan(v));
  1949. assert(!isinf(v));
  1950. }
  1951. #endif
  1952. }
  1953. }
  1954. static void ggml_compute_forward_silu(
  1955. const ggml_compute_params * params,
  1956. ggml_tensor * dst) {
  1957. const ggml_tensor * src0 = dst->src[0];
  1958. switch (src0->type) {
  1959. case GGML_TYPE_F32:
  1960. {
  1961. ggml_compute_forward_silu_f32(params, dst);
  1962. } break;
  1963. case GGML_TYPE_F16:
  1964. {
  1965. ggml_compute_forward_silu_f16(params, dst);
  1966. } break;
  1967. default:
  1968. {
  1969. GGML_ABORT("fatal error");
  1970. }
  1971. }
  1972. }
  1973. // ggml_compute_forward_leaky_relu
  1974. static void ggml_compute_forward_leaky_relu_f32(
  1975. const ggml_compute_params * params,
  1976. ggml_tensor * dst) {
  1977. const ggml_tensor * src0 = dst->src[0];
  1978. if (params->ith != 0) {
  1979. return;
  1980. }
  1981. assert(ggml_is_contiguous_1(src0));
  1982. assert(ggml_is_contiguous_1(dst));
  1983. assert(ggml_are_same_shape(src0, dst));
  1984. const int n = ggml_nrows(src0);
  1985. const int nc = src0->ne[0];
  1986. float negative_slope;
  1987. memcpy(&negative_slope, dst->op_params, sizeof(float));
  1988. assert(dst->nb[0] == sizeof(float));
  1989. assert(src0->nb[0] == sizeof(float));
  1990. for (int i = 0; i < n; i++) {
  1991. ggml_vec_leaky_relu_f32(nc,
  1992. (float *) ((char *) dst->data + i*( dst->nb[1])),
  1993. (float *) ((char *) src0->data + i*(src0->nb[1])), negative_slope);
  1994. }
  1995. }
  1996. static void ggml_compute_forward_leaky_relu_f16(
  1997. const ggml_compute_params * params,
  1998. ggml_tensor * dst) {
  1999. const ggml_tensor * src0 = dst->src[0];
  2000. if (params->ith != 0) {
  2001. return;
  2002. }
  2003. assert(ggml_is_contiguous_1(src0));
  2004. assert(ggml_is_contiguous_1(dst));
  2005. assert(ggml_are_same_shape(src0, dst));
  2006. const int n = ggml_nrows(src0);
  2007. const int nc = src0->ne[0];
  2008. float negative_slope;
  2009. memcpy(&negative_slope, dst->op_params, sizeof(float));
  2010. assert(dst->nb[0] == sizeof(ggml_fp16_t));
  2011. assert(src0->nb[0] == sizeof(ggml_fp16_t));
  2012. for (int i = 0; i < n; i++) {
  2013. ggml_vec_leaky_relu_f16(nc,
  2014. (ggml_fp16_t *) ((char *) dst->data + i*( dst->nb[1])),
  2015. (ggml_fp16_t *) ((char *) src0->data + i*(src0->nb[1])), negative_slope);
  2016. }
  2017. }
  2018. void ggml_compute_forward_leaky_relu(
  2019. const ggml_compute_params * params,
  2020. ggml_tensor * dst) {
  2021. const ggml_tensor * src0 = dst->src[0];
  2022. switch (src0->type) {
  2023. case GGML_TYPE_F32:
  2024. {
  2025. ggml_compute_forward_leaky_relu_f32(params, dst);
  2026. } break;
  2027. case GGML_TYPE_F16:
  2028. {
  2029. ggml_compute_forward_leaky_relu_f16(params, dst);
  2030. } break;
  2031. default:
  2032. {
  2033. GGML_ABORT("fatal error");
  2034. }
  2035. }
  2036. }
  2037. // ggml_compute_forward_silu_back
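// SILU_BACK computes the SiLU gradient: with s(x) = 1/(1 + exp(-x)) and silu(x) = x*s(x),
//   d/dx silu(x) = s(x) * (1 + x*(1 - s(x)))
// dst is the incoming gradient (dst->src[0]) multiplied element-wise by this derivative
// evaluated at src1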
  2038. static void ggml_compute_forward_silu_back_f32(
  2039. const ggml_compute_params * params,
  2040. ggml_tensor * dst) {
  2041. const ggml_tensor * grad = dst->src[0];
  2042. const ggml_tensor * src1 = dst->src[1];
  2043. assert(ggml_is_contiguous_1(grad));
  2044. assert(ggml_is_contiguous_1(src1));
  2045. assert(ggml_is_contiguous_1(dst));
  2046. assert(ggml_are_same_shape(src1, dst));
  2047. assert(ggml_are_same_shape(src1, grad));
  2048. const int ith = params->ith;
  2049. const int nth = params->nth;
  2050. const int nc = src1->ne[0];
  2051. const int nr = ggml_nrows(src1);
  2052. // rows per thread
  2053. const int dr = (nr + nth - 1)/nth;
  2054. // row range for this thread
  2055. const int ir0 = dr*ith;
  2056. const int ir1 = MIN(ir0 + dr, nr);
  2057. for (int i1 = ir0; i1 < ir1; i1++) {
  2058. ggml_vec_silu_backward_f32(nc,
  2059. (float *) ((char *) dst->data + i1*( dst->nb[1])),
  2060. (float *) ((char *) src1->data + i1*(src1->nb[1])),
  2061. (float *) ((char *) grad->data + i1*(grad->nb[1])));
  2062. #ifndef NDEBUG
  2063. for (int k = 0; k < nc; k++) {
  2064. const float x = ((float *) ((char *) dst->data + i1*( dst->nb[1])))[k];
  2065. GGML_UNUSED(x);
  2066. assert(!isnan(x));
  2067. assert(!isinf(x));
  2068. }
  2069. #endif
  2070. }
  2071. }
  2072. static void ggml_compute_forward_silu_back_f16(
  2073. const ggml_compute_params * params,
  2074. ggml_tensor * dst) {
  2075. const ggml_tensor * grad = dst->src[0];
  2076. const ggml_tensor * src1 = dst->src[1];
  2077. assert(ggml_is_contiguous_1(grad));
  2078. assert(ggml_is_contiguous_1(src1));
  2079. assert(ggml_is_contiguous_1(dst));
  2080. assert(ggml_are_same_shape(src1, dst));
  2081. assert(ggml_are_same_shape(src1, grad));
  2082. const int ith = params->ith;
  2083. const int nth = params->nth;
  2084. const int nc = src1->ne[0];
  2085. const int nr = ggml_nrows(src1);
  2086. // rows per thread
  2087. const int dr = (nr + nth - 1)/nth;
  2088. // row range for this thread
  2089. const int ir0 = dr*ith;
  2090. const int ir1 = MIN(ir0 + dr, nr);
  2091. for (int i1 = ir0; i1 < ir1; i1++) {
  2092. ggml_vec_silu_backward_f16(nc,
  2093. (ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])),
  2094. (ggml_fp16_t *) ((char *) src1->data + i1*(src1->nb[1])),
  2095. (ggml_fp16_t *) ((char *) grad->data + i1*(grad->nb[1])));
  2096. #ifndef NDEBUG
  2097. for (int k = 0; k < nc; k++) {
2098. const ggml_fp16_t x = ((ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])))[k];
  2099. const float v = GGML_CPU_FP16_TO_FP32(x);
  2100. GGML_UNUSED(v);
  2101. assert(!isnan(v));
  2102. assert(!isinf(v));
  2103. }
  2104. #endif
  2105. }
  2106. }
  2107. void ggml_compute_forward_silu_back(
  2108. const ggml_compute_params * params,
  2109. ggml_tensor * dst) {
  2110. const ggml_tensor * src0 = dst->src[0];
  2111. switch (src0->type) {
  2112. case GGML_TYPE_F32:
  2113. {
  2114. ggml_compute_forward_silu_back_f32(params, dst);
  2115. } break;
  2116. case GGML_TYPE_F16:
  2117. {
  2118. ggml_compute_forward_silu_back_f16(params, dst);
  2119. } break;
  2120. default:
  2121. {
  2122. GGML_ABORT("fatal error");
  2123. }
  2124. }
  2125. }
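// Illustrative sketch, not ggml API: a scalar reference for the rule the
// ggml_vec_silu_backward_f32/_f16 kernels above are assumed to apply element-wise,
// dx = dy * d/dx[x * sigmoid(x)]. Hypothetical helper name; assumes <cmath> is included.
static inline float example_silu_backward_scalar(float x, float dy) {
    const float s     = 1.0f / (1.0f + expf(-x));    // sigmoid(x)
    const float dsilu = s * (1.0f + x * (1.0f - s)); // derivative of x * sigmoid(x)
    return dy * dsilu;
}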
  2126. // ggml_compute_forward_reglu
  2127. static void ggml_compute_forward_reglu_f32(
  2128. const ggml_compute_params * params,
  2129. ggml_tensor * dst) {
  2130. const ggml_tensor * src0 = dst->src[0];
  2131. const ggml_tensor * src1 = dst->src[1];
  2132. char * src0_d = (char *) src0->data;
  2133. char * src1_d = (char *) (src1 ? src1->data : src0->data);
  2134. const size_t src0_o = src0->nb[1];
  2135. const size_t src1_o = src1 ? src1->nb[1] : src0->nb[1];
  2136. GGML_ASSERT(ggml_is_contiguous_1(src0));
  2137. GGML_ASSERT(ggml_is_contiguous_1(dst));
  2138. if (src1) {
  2139. GGML_ASSERT(ggml_is_contiguous_1(src1));
  2140. GGML_ASSERT(src0->type == src1->type);
  2141. }
  2142. const int ith = params->ith;
  2143. const int nth = params->nth;
  2144. const int nc = src1 ? src0->ne[0] : src0->ne[0] / 2;
  2145. const int nr = ggml_nrows(src0);
  2146. GGML_ASSERT(dst->ne[0] == nc);
  2147. GGML_ASSERT(ggml_nrows(dst) == nr);
  2148. const int32_t swapped = ggml_get_op_params_i32(dst, 1);
  2149. // rows per thread
  2150. const int dr = (nr + nth - 1)/nth;
  2151. // row range for this thread
  2152. const int ir0 = dr*ith;
  2153. const int ir1 = MIN(ir0 + dr, nr);
  2154. for (int i1 = ir0; i1 < ir1; i1++) {
  2155. float * src0_p = (float *) (src0_d + i1*src0_o);
  2156. float * src1_p = (float *) (src1_d + i1*src1_o);
  2157. if (!src1) {
  2158. src0_p += swapped ? nc : 0;
  2159. src1_p += swapped ? 0 : nc;
  2160. }
  2161. ggml_vec_reglu_f32(nc, (float *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p);
  2162. #ifndef NDEBUG
  2163. for (int k = 0; k < nc; k++) {
  2164. const float x = ((float *) ((char *) dst->data + i1*( dst->nb[1])))[k];
  2165. GGML_UNUSED(x);
  2166. assert(!isnan(x));
  2167. assert(!isinf(x));
  2168. }
  2169. #endif
  2170. }
  2171. }
  2172. static void ggml_compute_forward_reglu_f16(
  2173. const ggml_compute_params * params,
  2174. ggml_tensor * dst) {
  2175. const ggml_tensor * src0 = dst->src[0];
  2176. const ggml_tensor * src1 = dst->src[1];
  2177. char * src0_d = (char *) src0->data;
  2178. char * src1_d = (char *) (src1 ? src1->data : src0->data);
  2179. const size_t src0_o = src0->nb[1];
  2180. const size_t src1_o = src1 ? src1->nb[1] : src0->nb[1];
  2181. GGML_ASSERT(ggml_is_contiguous_1(src0));
  2182. GGML_ASSERT(ggml_is_contiguous_1(dst));
  2183. if (src1) {
  2184. GGML_ASSERT(ggml_is_contiguous_1(src1));
  2185. GGML_ASSERT(src0->type == src1->type);
  2186. }
  2187. const int ith = params->ith;
  2188. const int nth = params->nth;
  2189. const int nc = src1 ? src0->ne[0] : src0->ne[0] / 2;
  2190. const int nr = ggml_nrows(src0);
  2191. GGML_ASSERT(dst->ne[0] == nc);
  2192. GGML_ASSERT(ggml_nrows(dst) == nr);
  2193. const int32_t swapped = ggml_get_op_params_i32(dst, 1);
  2194. // rows per thread
  2195. const int dr = (nr + nth - 1)/nth;
  2196. // row range for this thread
  2197. const int ir0 = dr*ith;
  2198. const int ir1 = MIN(ir0 + dr, nr);
  2199. for (int i1 = ir0; i1 < ir1; i1++) {
  2200. ggml_fp16_t * src0_p = (ggml_fp16_t *) (src0_d + i1*src0_o);
  2201. ggml_fp16_t * src1_p = (ggml_fp16_t *) (src1_d + i1*src1_o);
  2202. if (!src1) {
  2203. src0_p += swapped ? nc : 0;
  2204. src1_p += swapped ? 0 : nc;
  2205. }
  2206. ggml_vec_reglu_f16(nc, (ggml_fp16_t *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p);
  2207. #ifndef NDEBUG
  2208. for (int k = 0; k < nc; k++) {
  2209. const ggml_fp16_t x = ((ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])))[k];
  2210. const float v = GGML_FP16_TO_FP32(x);
  2211. GGML_UNUSED(v);
  2212. assert(!isnan(v));
  2213. assert(!isinf(v));
  2214. }
  2215. #endif
  2216. }
  2217. }
  2218. static void ggml_compute_forward_reglu(
  2219. const ggml_compute_params * params,
  2220. ggml_tensor * dst) {
  2221. const ggml_tensor * src0 = dst->src[0];
  2222. switch (src0->type) {
  2223. case GGML_TYPE_F32:
  2224. {
  2225. ggml_compute_forward_reglu_f32(params, dst);
  2226. } break;
  2227. case GGML_TYPE_F16:
  2228. {
  2229. ggml_compute_forward_reglu_f16(params, dst);
  2230. } break;
  2231. default:
  2232. {
  2233. GGML_ABORT("fatal error");
  2234. }
  2235. }
  2236. }
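// Illustrative sketch, not ggml API: the fused GLU kernels above (and the geglu/swiglu
// variants below) share one row layout. With a single input (src1 == NULL) each src0 row of
// length 2*nc holds the activated half and the gate half side by side, and op_params[1]
// ("swapped") selects which half is which. The helper name is hypothetical, and the
// activated/gate roles are an assumption based on the argument order passed to the
// ggml_vec_*_f32 kernels.
static inline void example_split_glu_row(float * row, int nc, bool swapped,
                                         float ** act_half, float ** gate_half) {
    *act_half  = row + (swapped ? nc : 0); // same pointer arithmetic as src0_p above
    *gate_half = row + (swapped ? 0 : nc); // same pointer arithmetic as src1_p above
}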
  2237. // ggml_compute_forward_geglu
  2238. static void ggml_compute_forward_geglu_f32(
  2239. const ggml_compute_params * params,
  2240. ggml_tensor * dst) {
  2241. const ggml_tensor * src0 = dst->src[0];
  2242. const ggml_tensor * src1 = dst->src[1];
  2243. char * src0_d = (char *) src0->data;
  2244. char * src1_d = (char *) (src1 ? src1->data : src0->data);
  2245. const size_t src0_o = src0->nb[1];
  2246. const size_t src1_o = src1 ? src1->nb[1] : src0->nb[1];
  2247. GGML_ASSERT(ggml_is_contiguous_1(src0));
  2248. GGML_ASSERT(ggml_is_contiguous_1(dst));
  2249. if (src1) {
  2250. GGML_ASSERT(ggml_is_contiguous_1(src1));
  2251. GGML_ASSERT(src0->type == src1->type);
  2252. }
  2253. const int ith = params->ith;
  2254. const int nth = params->nth;
  2255. const int nc = src1 ? src0->ne[0] : src0->ne[0] / 2;
  2256. const int nr = ggml_nrows(src0);
  2257. GGML_ASSERT(dst->ne[0] == nc);
  2258. GGML_ASSERT(ggml_nrows(dst) == nr);
  2259. const int32_t swapped = ggml_get_op_params_i32(dst, 1);
  2260. // rows per thread
  2261. const int dr = (nr + nth - 1)/nth;
  2262. // row range for this thread
  2263. const int ir0 = dr*ith;
  2264. const int ir1 = MIN(ir0 + dr, nr);
  2265. for (int i1 = ir0; i1 < ir1; i1++) {
  2266. float * src0_p = (float *) (src0_d + i1*src0_o);
  2267. float * src1_p = (float *) (src1_d + i1*src1_o);
  2268. if (!src1) {
  2269. src0_p += swapped ? nc : 0;
  2270. src1_p += swapped ? 0 : nc;
  2271. }
  2272. ggml_vec_geglu_f32(nc, (float *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p);
  2273. #ifndef NDEBUG
  2274. for (int k = 0; k < nc; k++) {
  2275. const float x = ((float *) ((char *) dst->data + i1*( dst->nb[1])))[k];
  2276. GGML_UNUSED(x);
  2277. assert(!isnan(x));
  2278. assert(!isinf(x));
  2279. }
  2280. #endif
  2281. }
  2282. }
  2283. static void ggml_compute_forward_geglu_f16(
  2284. const ggml_compute_params * params,
  2285. ggml_tensor * dst) {
  2286. const ggml_tensor * src0 = dst->src[0];
  2287. const ggml_tensor * src1 = dst->src[1];
  2288. char * src0_d = (char *) src0->data;
  2289. char * src1_d = (char *) (src1 ? src1->data : src0->data);
  2290. const size_t src0_o = src0->nb[1];
  2291. const size_t src1_o = src1 ? src1->nb[1] : src0->nb[1];
  2292. GGML_ASSERT(ggml_is_contiguous_1(src0));
  2293. GGML_ASSERT(ggml_is_contiguous_1(dst));
  2294. if (src1) {
  2295. GGML_ASSERT(ggml_is_contiguous_1(src1));
  2296. GGML_ASSERT(src0->type == src1->type);
  2297. }
  2298. const int ith = params->ith;
  2299. const int nth = params->nth;
  2300. const int nc = src1 ? src0->ne[0] : src0->ne[0] / 2;
  2301. const int nr = ggml_nrows(src0);
  2302. GGML_ASSERT(dst->ne[0] == nc);
  2303. GGML_ASSERT(ggml_nrows(dst) == nr);
  2304. const int32_t swapped = ggml_get_op_params_i32(dst, 1);
  2305. // rows per thread
  2306. const int dr = (nr + nth - 1)/nth;
  2307. // row range for this thread
  2308. const int ir0 = dr*ith;
  2309. const int ir1 = MIN(ir0 + dr, nr);
  2310. for (int i1 = ir0; i1 < ir1; i1++) {
  2311. ggml_fp16_t * src0_p = (ggml_fp16_t *) (src0_d + i1*src0_o);
  2312. ggml_fp16_t * src1_p = (ggml_fp16_t *) (src1_d + i1*src1_o);
  2313. if (!src1) {
  2314. src0_p += swapped ? nc : 0;
  2315. src1_p += swapped ? 0 : nc;
  2316. }
  2317. ggml_vec_geglu_f16(nc, (ggml_fp16_t *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p);
  2318. #ifndef NDEBUG
  2319. for (int k = 0; k < nc; k++) {
  2320. const ggml_fp16_t x = ((ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])))[k];
  2321. const float v = GGML_FP16_TO_FP32(x);
  2322. GGML_UNUSED(v);
  2323. assert(!isnan(v));
  2324. assert(!isinf(v));
  2325. }
  2326. #endif
  2327. }
  2328. }
  2329. static void ggml_compute_forward_geglu(
  2330. const ggml_compute_params * params,
  2331. ggml_tensor * dst) {
  2332. const ggml_tensor * src0 = dst->src[0];
  2333. switch (src0->type) {
  2334. case GGML_TYPE_F32:
  2335. {
  2336. ggml_compute_forward_geglu_f32(params, dst);
  2337. } break;
  2338. case GGML_TYPE_F16:
  2339. {
  2340. ggml_compute_forward_geglu_f16(params, dst);
  2341. } break;
  2342. default:
  2343. {
  2344. GGML_ABORT("fatal error");
  2345. }
  2346. }
  2347. }
  2348. // ggml_compute_forward_swiglu
  2349. static void ggml_compute_forward_swiglu_f32(
  2350. const ggml_compute_params * params,
  2351. ggml_tensor * dst) {
  2352. const ggml_tensor * src0 = dst->src[0];
  2353. const ggml_tensor * src1 = dst->src[1];
  2354. char * src0_d = (char *) src0->data;
  2355. char * src1_d = (char *) (src1 ? src1->data : src0->data);
  2356. const size_t src0_o = src0->nb[1];
  2357. const size_t src1_o = src1 ? src1->nb[1] : src0->nb[1];
  2358. GGML_ASSERT(ggml_is_contiguous_1(src0));
  2359. GGML_ASSERT(ggml_is_contiguous_1(dst));
  2360. if (src1) {
  2361. GGML_ASSERT(ggml_is_contiguous_1(src1));
  2362. GGML_ASSERT(src0->type == src1->type);
  2363. }
  2364. const int ith = params->ith;
  2365. const int nth = params->nth;
  2366. const int nc = src1 ? src0->ne[0] : src0->ne[0] / 2;
  2367. const int nr = ggml_nrows(src0);
  2368. GGML_ASSERT(dst->ne[0] == nc);
  2369. GGML_ASSERT(ggml_nrows(dst) == nr);
  2370. const int32_t swapped = ggml_get_op_params_i32(dst, 1);
  2371. // rows per thread
  2372. const int dr = (nr + nth - 1)/nth;
  2373. // row range for this thread
  2374. const int ir0 = dr*ith;
  2375. const int ir1 = MIN(ir0 + dr, nr);
  2376. for (int i1 = ir0; i1 < ir1; i1++) {
  2377. float * src0_p = (float *) (src0_d + i1*src0_o);
  2378. float * src1_p = (float *) (src1_d + i1*src1_o);
  2379. if (!src1) {
  2380. src0_p += swapped ? nc : 0;
  2381. src1_p += swapped ? 0 : nc;
  2382. }
  2383. ggml_vec_swiglu_f32(nc, (float *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p);
  2384. #ifndef NDEBUG
  2385. for (int k = 0; k < nc; k++) {
  2386. const float x = ((float *) ((char *) dst->data + i1*( dst->nb[1])))[k];
  2387. GGML_UNUSED(x);
  2388. assert(!isnan(x));
  2389. assert(!isinf(x));
  2390. }
  2391. #endif
  2392. }
  2393. }
  2394. static void ggml_compute_forward_swiglu_f16(
  2395. const ggml_compute_params * params,
  2396. ggml_tensor * dst) {
  2397. const ggml_tensor * src0 = dst->src[0];
  2398. const ggml_tensor * src1 = dst->src[1];
  2399. char * src0_d = (char *) src0->data;
  2400. char * src1_d = (char *) (src1 ? src1->data : src0->data);
  2401. const size_t src0_o = src0->nb[1];
  2402. const size_t src1_o = src1 ? src1->nb[1] : src0->nb[1];
  2403. GGML_ASSERT(ggml_is_contiguous_1(src0));
  2404. GGML_ASSERT(ggml_is_contiguous_1(dst));
  2405. if (src1) {
  2406. GGML_ASSERT(ggml_is_contiguous_1(src1));
  2407. GGML_ASSERT(src0->type == src1->type);
  2408. }
  2409. const int ith = params->ith;
  2410. const int nth = params->nth;
  2411. const int nc = src1 ? src0->ne[0] : src0->ne[0] / 2;
  2412. const int nr = ggml_nrows(src0);
  2413. GGML_ASSERT(dst->ne[0] == nc);
  2414. GGML_ASSERT(ggml_nrows(dst) == nr);
  2415. const int32_t swapped = ggml_get_op_params_i32(dst, 1);
  2416. // rows per thread
  2417. const int dr = (nr + nth - 1)/nth;
  2418. // row range for this thread
  2419. const int ir0 = dr*ith;
  2420. const int ir1 = MIN(ir0 + dr, nr);
  2421. for (int i1 = ir0; i1 < ir1; i1++) {
  2422. ggml_fp16_t * src0_p = (ggml_fp16_t *) (src0_d + i1*src0_o);
  2423. ggml_fp16_t * src1_p = (ggml_fp16_t *) (src1_d + i1*src1_o);
  2424. if (!src1) {
  2425. src0_p += swapped ? nc : 0;
  2426. src1_p += swapped ? 0 : nc;
  2427. }
  2428. ggml_vec_swiglu_f16(nc, (ggml_fp16_t *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p);
  2429. #ifndef NDEBUG
  2430. for (int k = 0; k < nc; k++) {
  2431. const ggml_fp16_t x = ((ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])))[k];
  2432. const float v = GGML_FP16_TO_FP32(x);
  2433. GGML_UNUSED(v);
  2434. assert(!isnan(v));
  2435. assert(!isinf(v));
  2436. }
  2437. #endif
  2438. }
  2439. }
  2440. static void ggml_compute_forward_swiglu(
  2441. const ggml_compute_params * params,
  2442. ggml_tensor * dst) {
  2443. const ggml_tensor * src0 = dst->src[0];
  2444. switch (src0->type) {
  2445. case GGML_TYPE_F32:
  2446. {
  2447. ggml_compute_forward_swiglu_f32(params, dst);
  2448. } break;
  2449. case GGML_TYPE_F16:
  2450. {
  2451. ggml_compute_forward_swiglu_f16(params, dst);
  2452. } break;
  2453. default:
  2454. {
  2455. GGML_ABORT("fatal error");
  2456. }
  2457. }
  2458. }
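// Illustrative sketch, not ggml API: a scalar reference for the value ggml_vec_swiglu_f32 is
// assumed to produce per element, silu(x) * g, with x and g taken from the two row halves
// described earlier. Hypothetical helper name; assumes <cmath> is included.
static inline float example_swiglu_scalar(float x, float g) {
    const float silu = x / (1.0f + expf(-x)); // x * sigmoid(x)
    return silu * g;
}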
  2459. // ggml_compute_forward_swiglu_oai
  2460. static void ggml_compute_forward_swiglu_oai_f32(
  2461. const ggml_compute_params * params,
  2462. ggml_tensor * dst) {
  2463. const ggml_tensor * src0 = dst->src[0];
  2464. const ggml_tensor * src1 = dst->src[1];
  2465. char * src0_d = (char *) src0->data;
  2466. char * src1_d = (char *) (src1 ? src1->data : src0->data);
  2467. const size_t src0_o = src0->nb[1];
  2468. const size_t src1_o = src1 ? src1->nb[1] : src0->nb[1];
  2469. GGML_ASSERT(ggml_is_contiguous_1(src0));
  2470. GGML_ASSERT(ggml_is_contiguous_1(dst));
  2471. if (src1) {
  2472. GGML_ASSERT(ggml_is_contiguous_1(src1));
  2473. GGML_ASSERT(src0->type == src1->type);
  2474. }
  2475. const int ith = params->ith;
  2476. const int nth = params->nth;
  2477. const int nc = src1 ? src0->ne[0] : src0->ne[0] / 2;
  2478. const int nr = ggml_nrows(src0);
  2479. GGML_ASSERT(dst->ne[0] == nc);
  2480. GGML_ASSERT(ggml_nrows(dst) == nr);
  2481. const int32_t swapped = ggml_get_op_params_i32(dst, 1);
  2482. const float alpha = ggml_get_op_params_f32(dst, 2);
  2483. const float limit = ggml_get_op_params_f32(dst, 3);
  2484. // rows per thread
  2485. const int dr = (nr + nth - 1)/nth;
  2486. // row range for this thread
  2487. const int ir0 = dr*ith;
  2488. const int ir1 = MIN(ir0 + dr, nr);
  2489. for (int i1 = ir0; i1 < ir1; i1++) {
  2490. float * src0_p = (float *) (src0_d + i1*src0_o);
  2491. float * src1_p = (float *) (src1_d + i1*src1_o);
  2492. float * dst_p = (float *) ((char *) dst->data + i1*(dst->nb[1]));
  2493. if (!src1) {
  2494. src0_p += swapped ? nc : 0;
  2495. src1_p += swapped ? 0 : nc;
  2496. }
  2497. for (int k = 0; k < nc; k++) {
  2498. const float x = std::min(src0_p[k], limit);
  2499. const float y = std::clamp(src1_p[k], -limit, limit);
  2500. const float out_glu = x / (1.f + expf(alpha * (-x)));
  2501. dst_p[k] = out_glu * (y + 1.f);
  2502. }
  2503. #ifndef NDEBUG
  2504. for (int k = 0; k < nc; k++) {
  2505. const float x = dst_p[k];
  2506. GGML_UNUSED(x);
  2507. assert(!isnan(x));
  2508. assert(!isinf(x));
  2509. }
  2510. #endif
  2511. }
  2512. }
  2513. static void ggml_compute_forward_swiglu_oai(
  2514. const ggml_compute_params * params,
  2515. ggml_tensor * dst) {
  2516. const ggml_tensor * src0 = dst->src[0];
  2517. switch (src0->type) {
  2518. case GGML_TYPE_F32:
  2519. {
  2520. ggml_compute_forward_swiglu_oai_f32(params, dst);
  2521. } break;
  2522. default:
  2523. {
  2524. GGML_ABORT("fatal error");
  2525. }
  2526. }
  2527. }
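// Illustrative sketch, not ggml API: the per-element formula of the inner loop in
// ggml_compute_forward_swiglu_oai_f32 above, written as a standalone scalar helper.
// alpha scales the sigmoid, limit clamps both inputs, and the gate receives a +1 bias.
// Hypothetical helper name; relies on <algorithm> and <cmath> already being included.
static inline float example_swiglu_oai_scalar(float x, float g, float alpha, float limit) {
    x = std::min(x, limit);
    g = std::clamp(g, -limit, limit);
    const float out_glu = x / (1.0f + expf(alpha * (-x))); // x * sigmoid(alpha * x)
    return out_glu * (g + 1.0f);
}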
  2528. // ggml_compute_forward_geglu_erf
  2529. static void ggml_compute_forward_geglu_erf_f32(
  2530. const ggml_compute_params * params,
  2531. ggml_tensor * dst) {
  2532. const ggml_tensor * src0 = dst->src[0];
  2533. const ggml_tensor * src1 = dst->src[1];
  2534. char * src0_d = (char *) src0->data;
  2535. char * src1_d = (char *) (src1 ? src1->data : src0->data);
  2536. const size_t src0_o = src0->nb[1];
  2537. const size_t src1_o = src1 ? src1->nb[1] : src0->nb[1];
  2538. GGML_ASSERT(ggml_is_contiguous_1(src0));
  2539. GGML_ASSERT(ggml_is_contiguous_1(dst));
  2540. if (src1) {
  2541. GGML_ASSERT(ggml_is_contiguous_1(src1));
  2542. GGML_ASSERT(src0->type == src1->type);
  2543. }
  2544. const int ith = params->ith;
  2545. const int nth = params->nth;
  2546. const int nc = src1 ? src0->ne[0] : src0->ne[0] / 2;
  2547. const int nr = ggml_nrows(src0);
  2548. GGML_ASSERT(dst->ne[0] == nc);
  2549. GGML_ASSERT(ggml_nrows(dst) == nr);
  2550. const int32_t swapped = ggml_get_op_params_i32(dst, 1);
  2551. // rows per thread
  2552. const int dr = (nr + nth - 1)/nth;
  2553. // row range for this thread
  2554. const int ir0 = dr*ith;
  2555. const int ir1 = MIN(ir0 + dr, nr);
  2556. for (int i1 = ir0; i1 < ir1; i1++) {
  2557. float * src0_p = (float *) (src0_d + i1*src0_o);
  2558. float * src1_p = (float *) (src1_d + i1*src1_o);
  2559. if (!src1) {
  2560. src0_p += swapped ? nc : 0;
  2561. src1_p += swapped ? 0 : nc;
  2562. }
  2563. ggml_vec_geglu_erf_f32(nc, (float *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p);
  2564. #ifndef NDEBUG
  2565. for (int k = 0; k < nc; k++) {
  2566. const float x = ((float *) ((char *) dst->data + i1*( dst->nb[1])))[k];
  2567. GGML_UNUSED(x);
  2568. assert(!isnan(x));
  2569. assert(!isinf(x));
  2570. }
  2571. #endif
  2572. }
  2573. }
  2574. static void ggml_compute_forward_geglu_erf_f16(
  2575. const ggml_compute_params * params,
  2576. ggml_tensor * dst) {
  2577. const ggml_tensor * src0 = dst->src[0];
  2578. const ggml_tensor * src1 = dst->src[1];
  2579. char * src0_d = (char *) src0->data;
  2580. char * src1_d = (char *) (src1 ? src1->data : src0->data);
  2581. const size_t src0_o = src0->nb[1];
  2582. const size_t src1_o = src1 ? src1->nb[1] : src0->nb[1];
  2583. GGML_ASSERT(ggml_is_contiguous_1(src0));
  2584. GGML_ASSERT(ggml_is_contiguous_1(dst));
  2585. if (src1) {
  2586. GGML_ASSERT(ggml_is_contiguous_1(src1));
  2587. GGML_ASSERT(src0->type == src1->type);
  2588. }
  2589. const int ith = params->ith;
  2590. const int nth = params->nth;
  2591. const int nc = src1 ? src0->ne[0] : src0->ne[0] / 2;
  2592. const int nr = ggml_nrows(src0);
  2593. GGML_ASSERT(dst->ne[0] == nc);
  2594. GGML_ASSERT(ggml_nrows(dst) == nr);
  2595. const int32_t swapped = ggml_get_op_params_i32(dst, 1);
  2596. // rows per thread
  2597. const int dr = (nr + nth - 1)/nth;
  2598. // row range for this thread
  2599. const int ir0 = dr*ith;
  2600. const int ir1 = MIN(ir0 + dr, nr);
  2601. for (int i1 = ir0; i1 < ir1; i1++) {
  2602. ggml_fp16_t * src0_p = (ggml_fp16_t *) (src0_d + i1*src0_o);
  2603. ggml_fp16_t * src1_p = (ggml_fp16_t *) (src1_d + i1*src1_o);
  2604. if (!src1) {
  2605. src0_p += swapped ? nc : 0;
  2606. src1_p += swapped ? 0 : nc;
  2607. }
  2608. ggml_vec_geglu_erf_f16(nc, (ggml_fp16_t *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p);
  2609. #ifndef NDEBUG
  2610. for (int k = 0; k < nc; k++) {
  2611. const ggml_fp16_t x = ((ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])))[k];
  2612. const float v = GGML_FP16_TO_FP32(x);
  2613. GGML_UNUSED(v);
  2614. assert(!isnan(v));
  2615. assert(!isinf(v));
  2616. }
  2617. #endif
  2618. }
  2619. }
  2620. static void ggml_compute_forward_geglu_erf(
  2621. const ggml_compute_params * params,
  2622. ggml_tensor * dst) {
  2623. const ggml_tensor * src0 = dst->src[0];
  2624. switch (src0->type) {
  2625. case GGML_TYPE_F32:
  2626. {
  2627. ggml_compute_forward_geglu_erf_f32(params, dst);
  2628. } break;
  2629. case GGML_TYPE_F16:
  2630. {
  2631. ggml_compute_forward_geglu_erf_f16(params, dst);
  2632. } break;
  2633. default:
  2634. {
  2635. GGML_ABORT("fatal error");
  2636. }
  2637. }
  2638. }
  2639. // ggml_compute_forward_geglu_quick
  2640. static void ggml_compute_forward_geglu_quick_f32(
  2641. const ggml_compute_params * params,
  2642. ggml_tensor * dst) {
  2643. const ggml_tensor * src0 = dst->src[0];
  2644. const ggml_tensor * src1 = dst->src[1];
  2645. char * src0_d = (char *) src0->data;
  2646. char * src1_d = (char *) (src1 ? src1->data : src0->data);
  2647. const size_t src0_o = src0->nb[1];
  2648. const size_t src1_o = src1 ? src1->nb[1] : src0->nb[1];
  2649. GGML_ASSERT(ggml_is_contiguous_1(src0));
  2650. GGML_ASSERT(ggml_is_contiguous_1(dst));
  2651. if (src1) {
  2652. GGML_ASSERT(ggml_is_contiguous_1(src1));
  2653. GGML_ASSERT(src0->type == src1->type);
  2654. }
  2655. const int ith = params->ith;
  2656. const int nth = params->nth;
  2657. const int nc = src1 ? src0->ne[0] : src0->ne[0] / 2;
  2658. const int nr = ggml_nrows(src0);
  2659. GGML_ASSERT(dst->ne[0] == nc);
  2660. GGML_ASSERT(ggml_nrows(dst) == nr);
  2661. const int32_t swapped = ggml_get_op_params_i32(dst, 1);
  2662. // rows per thread
  2663. const int dr = (nr + nth - 1)/nth;
  2664. // row range for this thread
  2665. const int ir0 = dr*ith;
  2666. const int ir1 = MIN(ir0 + dr, nr);
  2667. for (int i1 = ir0; i1 < ir1; i1++) {
  2668. float * src0_p = (float *) (src0_d + i1*src0_o);
  2669. float * src1_p = (float *) (src1_d + i1*src1_o);
  2670. if (!src1) {
  2671. src0_p += swapped ? nc : 0;
  2672. src1_p += swapped ? 0 : nc;
  2673. }
  2674. ggml_vec_geglu_quick_f32(nc, (float *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p);
  2675. #ifndef NDEBUG
  2676. for (int k = 0; k < nc; k++) {
  2677. const float x = ((float *) ((char *) dst->data + i1*( dst->nb[1])))[k];
  2678. GGML_UNUSED(x);
  2679. assert(!isnan(x));
  2680. assert(!isinf(x));
  2681. }
  2682. #endif
  2683. }
  2684. }
  2685. static void ggml_compute_forward_geglu_quick_f16(
  2686. const ggml_compute_params * params,
  2687. ggml_tensor * dst) {
  2688. const ggml_tensor * src0 = dst->src[0];
  2689. const ggml_tensor * src1 = dst->src[1];
  2690. char * src0_d = (char *) src0->data;
  2691. char * src1_d = (char *) (src1 ? src1->data : src0->data);
  2692. const size_t src0_o = src0->nb[1];
  2693. const size_t src1_o = src1 ? src1->nb[1] : src0->nb[1];
  2694. GGML_ASSERT(ggml_is_contiguous_1(src0));
  2695. GGML_ASSERT(ggml_is_contiguous_1(dst));
  2696. if (src1) {
  2697. GGML_ASSERT(ggml_is_contiguous_1(src1));
  2698. GGML_ASSERT(src0->type == src1->type);
  2699. }
  2700. const int ith = params->ith;
  2701. const int nth = params->nth;
  2702. const int nc = src1 ? src0->ne[0] : src0->ne[0] / 2;
  2703. const int nr = ggml_nrows(src0);
  2704. GGML_ASSERT(dst->ne[0] == nc);
  2705. GGML_ASSERT(ggml_nrows(dst) == nr);
  2706. const int32_t swapped = ggml_get_op_params_i32(dst, 1);
  2707. // rows per thread
  2708. const int dr = (nr + nth - 1)/nth;
  2709. // row range for this thread
  2710. const int ir0 = dr*ith;
  2711. const int ir1 = MIN(ir0 + dr, nr);
  2712. for (int i1 = ir0; i1 < ir1; i1++) {
  2713. ggml_fp16_t * src0_p = (ggml_fp16_t *) (src0_d + i1*src0_o);
  2714. ggml_fp16_t * src1_p = (ggml_fp16_t *) (src1_d + i1*src1_o);
  2715. if (!src1) {
  2716. src0_p += swapped ? nc : 0;
  2717. src1_p += swapped ? 0 : nc;
  2718. }
  2719. ggml_vec_geglu_quick_f16(nc, (ggml_fp16_t *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p);
  2720. #ifndef NDEBUG
  2721. for (int k = 0; k < nc; k++) {
  2722. const ggml_fp16_t x = ((ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])))[k];
  2723. const float v = GGML_FP16_TO_FP32(x);
  2724. GGML_UNUSED(v);
  2725. assert(!isnan(v));
  2726. assert(!isinf(v));
  2727. }
  2728. #endif
  2729. }
  2730. }
  2731. static void ggml_compute_forward_geglu_quick(
  2732. const ggml_compute_params * params,
  2733. ggml_tensor * dst) {
  2734. const ggml_tensor * src0 = dst->src[0];
  2735. switch (src0->type) {
  2736. case GGML_TYPE_F32:
  2737. {
  2738. ggml_compute_forward_geglu_quick_f32(params, dst);
  2739. } break;
  2740. case GGML_TYPE_F16:
  2741. {
  2742. ggml_compute_forward_geglu_quick_f16(params, dst);
  2743. } break;
  2744. default:
  2745. {
  2746. GGML_ABORT("fatal error");
  2747. }
  2748. }
  2749. }
  2750. // ggml_compute_forward_norm
  2751. static void ggml_compute_forward_norm_f32(
  2752. const ggml_compute_params * params,
  2753. ggml_tensor * dst) {
  2754. const ggml_tensor * src0 = dst->src[0];
  2755. GGML_ASSERT(ggml_are_same_shape(src0, dst));
  2756. GGML_ASSERT(src0->nb[0] == sizeof(float));
  2757. const int ith = params->ith;
  2758. const int nth = params->nth;
  2759. GGML_TENSOR_UNARY_OP_LOCALS
  2760. float eps;
  2761. memcpy(&eps, dst->op_params, sizeof(float));
  2762. GGML_ASSERT(eps >= 0.0f);
  2763. for (int64_t i03 = 0; i03 < ne03; i03++) {
  2764. for (int64_t i02 = 0; i02 < ne02; i02++) {
  2765. for (int64_t i01 = ith; i01 < ne01; i01 += nth) {
  2766. const float * x = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
  2767. float sum = 0.0;
  2768. ggml_vec_sum_f32(ne00, &sum, x);
  2769. float mean = sum/ne00;
  2770. float * y = (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3);
  2771. float variance = 0;
  2772. #ifdef GGML_USE_ACCELERATE
  2773. mean = -mean;
  2774. vDSP_vsadd(x, 1, &mean, y, 1, ne00);
  2775. vDSP_measqv(y, 1, &variance, ne00);
  2776. #else
  2777. variance = ggml_vec_cvar_f32(ne00, y, x, mean);
  2778. #endif //GGML_USE_ACCELERATE
  2779. const float scale = 1.0f/sqrtf(variance + eps);
  2780. ggml_vec_scale_f32(ne00, y, scale);
  2781. }
  2782. }
  2783. }
  2784. }
  2785. void ggml_compute_forward_norm(
  2786. const ggml_compute_params * params,
  2787. ggml_tensor * dst) {
  2788. const ggml_tensor * src0 = dst->src[0];
  2789. switch (src0->type) {
  2790. case GGML_TYPE_F32:
  2791. {
  2792. ggml_compute_forward_norm_f32(params, dst);
  2793. } break;
  2794. default:
  2795. {
  2796. GGML_ABORT("fatal error");
  2797. }
  2798. }
  2799. }
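// Illustrative sketch, not ggml API: a plain, unvectorized reference for what the per-row
// loop above computes, y = (x - mean(x)) / sqrt(var(x) + eps). Hypothetical helper name;
// assumes <cmath> is included.
static void example_norm_row_reference(int n, float * y, const float * x, float eps) {
    float sum = 0.0f;
    for (int i = 0; i < n; i++) {
        sum += x[i];
    }
    const float mean = sum / n;
    float var = 0.0f;
    for (int i = 0; i < n; i++) {
        const float d = x[i] - mean;
        y[i] = d;      // store the centered value, like the kernel does
        var += d * d;
    }
    var /= n;
    const float scale = 1.0f / sqrtf(var + eps);
    for (int i = 0; i < n; i++) {
        y[i] *= scale;
    }
}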
2800. // ggml_compute_forward_rms_norm
  2801. static void ggml_compute_forward_rms_norm_f32(
  2802. const ggml_compute_params * params,
  2803. ggml_tensor * dst) {
  2804. const ggml_tensor * src0 = dst->src[0];
  2805. GGML_ASSERT(ggml_are_same_shape(src0, dst));
  2806. GGML_ASSERT(src0->nb[0] == sizeof(float));
  2807. const int ith = params->ith;
  2808. const int nth = params->nth;
  2809. GGML_TENSOR_UNARY_OP_LOCALS
  2810. float eps;
  2811. memcpy(&eps, dst->op_params, sizeof(float));
  2812. GGML_ASSERT(eps >= 0.0f);
  2813. // TODO: optimize
  2814. for (int64_t i03 = 0; i03 < ne03; i03++) {
  2815. for (int64_t i02 = 0; i02 < ne02; i02++) {
  2816. for (int64_t i01 = ith; i01 < ne01; i01 += nth) {
  2817. const float * x = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
  2818. ggml_float sum = 0.0;
  2819. for (int64_t i00 = 0; i00 < ne00; i00++) {
  2820. sum += (ggml_float)(x[i00] * x[i00]);
  2821. }
  2822. const float mean = sum/ne00;
  2823. float * y = (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3);
  2824. memcpy(y, x, ne00 * sizeof(float));
  2825. // for (int i00 = 0; i00 < ne00; i00++) {
  2826. // y[i00] = x[i00];
  2827. // }
  2828. const float scale = 1.0f/sqrtf(mean + eps);
  2829. // if you hit this, likely you got an inf somewhere earlier
  2830. assert(scale > 0.0f);
  2831. ggml_vec_scale_f32(ne00, y, scale);
  2832. }
  2833. }
  2834. }
  2835. }
  2836. void ggml_compute_forward_rms_norm(
  2837. const ggml_compute_params * params,
  2838. ggml_tensor * dst) {
  2839. const ggml_tensor * src0 = dst->src[0];
  2840. switch (src0->type) {
  2841. case GGML_TYPE_F32:
  2842. {
  2843. ggml_compute_forward_rms_norm_f32(params, dst);
  2844. } break;
  2845. default:
  2846. {
  2847. GGML_ABORT("fatal error");
  2848. }
  2849. }
  2850. }
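// Illustrative sketch, not ggml API: reference for the row-wise RMS norm above,
// y = x / sqrt(mean(x^2) + eps). Hypothetical helper name.
static void example_rms_norm_row_reference(int n, float * y, const float * x, float eps) {
    float sum_sq = 0.0f;
    for (int i = 0; i < n; i++) {
        sum_sq += x[i] * x[i];
    }
    const float scale = 1.0f / sqrtf(sum_sq / n + eps);
    for (int i = 0; i < n; i++) {
        y[i] = x[i] * scale;
    }
}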
  2851. static void ggml_compute_forward_rms_norm_back_f32(
  2852. const ggml_compute_params * params,
  2853. ggml_tensor * dst) {
  2854. const ggml_tensor * src0 = dst->src[0]; // gradients from forward pass output
  2855. const ggml_tensor * src1 = dst->src[1]; // src1 from forward pass
  2856. GGML_ASSERT(ggml_are_same_shape(src0, dst) && ggml_are_same_shape(src0, src1));
  2857. GGML_ASSERT(src0->nb[0] == sizeof(float));
  2858. GGML_ASSERT(src1->nb[0] == sizeof(float));
  2859. const int ith = params->ith;
  2860. const int nth = params->nth;
  2861. GGML_TENSOR_BINARY_OP_LOCALS
  2862. float eps;
  2863. memcpy(&eps, dst->op_params, sizeof(float));
  2864. // TODO: optimize
  2865. for (int64_t i03 = 0; i03 < ne03; i03++) {
  2866. for (int64_t i02 = 0; i02 < ne02; i02++) {
  2867. for (int64_t i01 = ith; i01 < ne01; i01 += nth) {
  2868. // src1 is same shape as src0 => same indices
  2869. const int64_t i11 = i01;
  2870. const int64_t i12 = i02;
  2871. const int64_t i13 = i03;
  2872. const float * dz = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
  2873. const float * x = (float *) ((char *) src1->data + i11*nb11 + i12*nb12 + i13*nb13);
  2874. ggml_float sum_xx = 0.0;
  2875. ggml_float sum_xdz = 0.0;
  2876. for (int64_t i00 = 0; i00 < ne00; i00++) {
  2877. sum_xx += (ggml_float)(x[i00] * x[i00]);
  2878. sum_xdz += (ggml_float)(x[i00] * dz[i00]);
  2879. }
  2880. //const float mean = (float)(sum_xx)/ne00;
  2881. const float mean_eps = (float)(sum_xx)/ne00 + eps;
  2882. const float sum_eps = (float)(sum_xx) + eps*ne00;
  2883. //const float mean_xdz = (float)(sum_xdz)/ne00;
  2884. // we could cache rms from forward pass to improve performance.
  2885. // to do this implement ggml_rms and compose ggml_rms_norm using ggml_rms.
  2886. //const float rms = sqrtf(mean_eps);
  2887. const float rrms = 1.0f / sqrtf(mean_eps);
  2888. //const float scale = -rrms/(ne00 * mean_eps); // -1/(n*rms**3)
  2889. {
  2890. // z = rms_norm(x)
  2891. //
  2892. // rms_norm(src1) =
  2893. // scale(
  2894. // src1,
  2895. // div(
  2896. // 1,
  2897. // sqrt(
  2898. // add(
  2899. // scale(
  2900. // sum(
  2901. // sqr(
  2902. // src1)),
  2903. // (1.0/N)),
  2904. // eps))));
  2905. // postorder:
  2906. // ## op args grad
  2907. // 00 param src1 grad[#00]
  2908. // 01 const 1
  2909. // 02 sqr (#00) grad[#02]
  2910. // 03 sum (#02) grad[#03]
  2911. // 04 const 1/N
  2912. // 05 scale (#03, #04) grad[#05]
  2913. // 06 const eps
  2914. // 07 add (#05, #06) grad[#07]
  2915. // 08 sqrt (#07) grad[#08]
  2916. // 09 div (#01,#08) grad[#09]
  2917. // 10 scale (#00,#09) grad[#10]
  2918. //
  2919. // backward pass, given grad[#10]
  2920. // #10: scale
  2921. // grad[#00] += scale(grad[#10],#09)
  2922. // grad[#09] += sum(mul(grad[#10],#00))
  2923. // #09: div
  2924. // grad[#08] += neg(mul(grad[#09], div(#09,#08)))
  2925. // #08: sqrt
  2926. // grad[#07] += mul(grad[#08], div(0.5, #08))
  2927. // #07: add
  2928. // grad[#05] += grad[#07]
  2929. // #05: scale
  2930. // grad[#03] += scale(grad[#05],#04)
  2931. // #03: sum
  2932. // grad[#02] += repeat(grad[#03], #02)
  2933. // #02:
  2934. // grad[#00] += scale(mul(#00, grad[#02]), 2.0)
  2935. //
  2936. // substitute and simplify:
  2937. // grad[#00] = scale(grad(#10), #09) + scale(mul(#00, grad[#02]), 2.0)
  2938. // grad[#02] = repeat(grad[#03], #02)
  2939. // grad[#02] = repeat(scale(grad[#05],#04), #02)
  2940. // grad[#02] = repeat(scale(grad[#07],#04), #02)
  2941. // grad[#02] = repeat(scale(mul(grad[#08], div(0.5, #08)),#04), #02)
  2942. // grad[#02] = repeat(scale(mul(neg(mul(grad[#09], div(#09,#08))), div(0.5, #08)),#04), #02)
  2943. // grad[#02] = repeat(scale(mul(neg(mul(sum(mul(grad[#10],#00)), div(#09,#08))), div(0.5, #08)),#04), #02)
  2944. // grad[#02] = repeat(-(sum(mul(grad[#10],#00)) * div(#09,#08) * div(0.5, #08) * (1/N)), #02)
  2945. // grad[#02] = repeat(-(sum(mul(grad[#10],#00)) * div(div(#01,#08),#08) * div(0.5, #08) * (1/N)), #02)
  2946. // grad[#02] = repeat(-(sum(mul(grad[#10],#00)) * div(1,#08*#08) * div(0.5, #08) * (1/N)), #02)
  2947. // grad[#02] = repeat(-(sum(mul(grad[#10],#00)) * div(1,#07) * div(0.5, #08) * (1/N)), #02)
  2948. // grad[#00] = scale(grad(#10), #09) + scale(mul(#00, grad[#02]), 2.0)
  2949. // grad[#00] = scale(grad(#10), #09) + scale(mul(#00, repeat(-(sum(mul(grad[#10],#00)) * div(1,#07) * div(0.5, #08) * (1/N)), #02)), 2.0)
  2950. // grad[#00] = scale(grad(#10), #09) + scale(scale(#00, -(sum(mul(grad[#10],#00)) * div(1,#07) * div(0.5, #08) * (1/N))), 2.0)
  2951. // grad[#00] = scale(grad(#10), #09) + scale(#00, -(sum(mul(grad[#10],#00)) * div(1,#07) * div(1,#08) * (1/N)))
  2952. // grad[#00] = scale(grad(#10), #09) + scale(#00, sum(mul(grad[#10],#00)) * div(1,#07*#08) * (-1/N))
  2953. // grad[#00] = scale(grad(#10), #09) + scale(#00, sum(mul(grad[#10],#00)) * div(1,#07*#08) * (-1/N))
  2954. // grad[#00] = scale(grad(#10), #09) + scale(#00, sum(mul(grad[#10],#00)) * div(1,mean_eps*rms) * (-1/N))
  2955. // grad[#00] = scale(grad(#10), #09) + scale(#00, sum(mul(grad[#10],#00)) * div(-1,rms*N*mean_eps))
  2956. // grad[#00] = scale(grad(#10), #09) + scale(#00, sum(mul(grad[#10],#00)) * div(-1,rms*N*(sum_xx/N+eps)))
  2957. // grad[#00] = scale(grad(#10), #09) + scale(#00, sum(mul(grad[#10],#00)) * div(-1,rms*N*sum_xx+rms*N*eps))
  2958. // grad[#00] = scale(dz, rrms) + scale(x, sum(mul(dz,x)) * div(-1,rms*N*mean_eps))
  2959. // grad[#00] = scale(dz, rrms) + scale(x, sum_xdz * div(-1,rms*N*mean_eps))
  2960. // a = b*c + d*e
  2961. // a = b*c*f/f + d*e*f/f
  2962. // a = (b*c*f + d*e*f)*(1/f)
  2963. // a = (b*c*(1/c) + d*e*(1/c))*(1/(1/c))
  2964. // a = (b + d*e/c)*c
  2965. // b = dz, c = rrms, d = x, e = sum_xdz * div(-1,rms*N*mean_eps)
  2966. // a = (dz + x*sum_xdz * div(-1,rms*N*mean_eps)/rrms)*rrms
  2967. // a = (dz + x*sum_xdz * div(-1,rms*N*mean_eps)*rms)*rrms
  2968. // a = (dz + x*sum_xdz * div(-rms,rms*N*mean_eps))*rrms
  2969. // a = (dz + x*sum_xdz * div(-1,N*mean_eps))*rrms
  2970. // a = (dz + x*div(-sum_xdz,N*mean_eps))*rrms
  2971. // a = (dz + x*div(-mean_xdz,mean_eps))*rrms
  2972. // grad[#00] = scale(dz + scale(x, div(-mean_xdz,mean_eps)),rrms)
  2973. // grad[#00] = scale(dz + scale(x, -mean_xdz/mean_eps),rrms)
  2974. // dx = scale(dz + scale(x, -mean_xdz/mean_eps),rrms)
  2975. }
  2976. // dx = scale(dz + scale(x, -mean_xdz/mean_eps),rrms)
  2977. // post-order:
  2978. // dx := x
  2979. // dx := scale(dx,-mean_xdz/mean_eps)
  2980. // dx := add(dx, dz)
  2981. // dx := scale(dx, rrms)
  2982. float * dx = (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3);
  2983. // dx[i00] = (x*(-sum_xdz/sum_eps) + dz) / sqrtf(mean_eps)
  2984. ggml_vec_cpy_f32 (ne00, dx, x);
  2985. // ggml_vec_scale_f32(ne00, dx, -mean_xdz/mean_eps);
  2986. ggml_vec_scale_f32(ne00, dx, (float)(-sum_xdz)/sum_eps);
  2987. ggml_vec_acc_f32 (ne00, dx, dz);
  2988. ggml_vec_scale_f32(ne00, dx, rrms);
  2989. }
  2990. }
  2991. }
  2992. }
  2993. void ggml_compute_forward_rms_norm_back(
  2994. const ggml_compute_params * params,
  2995. ggml_tensor * dst) {
  2996. const ggml_tensor * src0 = dst->src[0];
  2997. switch (src0->type) {
  2998. case GGML_TYPE_F32:
  2999. {
  3000. ggml_compute_forward_rms_norm_back_f32(params, dst);
  3001. } break;
  3002. default:
  3003. {
  3004. GGML_ABORT("fatal error");
  3005. }
  3006. }
  3007. }
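// Illustrative sketch, not ggml API: a row-wise reference for the expression the long
// derivation above arrives at, dx = (dz + x * (-sum_xdz / sum_eps)) * rrms, mirroring the
// ggml_vec_* sequence at the end of the loop. Hypothetical helper name; double accumulation
// stands in for ggml_float.
static void example_rms_norm_back_row_reference(int64_t n, float * dx,
                                                const float * x, const float * dz, float eps) {
    double sum_xx  = 0.0;
    double sum_xdz = 0.0;
    for (int64_t i = 0; i < n; i++) {
        sum_xx  += (double) x[i] * x[i];
        sum_xdz += (double) x[i] * dz[i];
    }
    const float mean_eps = (float) (sum_xx / n) + eps;
    const float sum_eps  = (float) sum_xx + eps * n;
    const float rrms     = 1.0f / sqrtf(mean_eps);
    for (int64_t i = 0; i < n; i++) {
        dx[i] = (dz[i] + x[i] * ((float) -sum_xdz / sum_eps)) * rrms;
    }
}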
  3008. // ggml_compute_forward_group_norm
  3009. static void ggml_compute_forward_group_norm_f32(
  3010. const ggml_compute_params * params,
  3011. ggml_tensor * dst) {
  3012. const ggml_tensor * src0 = dst->src[0];
  3013. GGML_ASSERT(ggml_are_same_shape(src0, dst));
  3014. GGML_ASSERT(src0->nb[0] == sizeof(float));
  3015. const int ith = params->ith;
  3016. const int nth = params->nth;
  3017. GGML_TENSOR_UNARY_OP_LOCALS
  3018. // TODO: optimize
  3019. float eps;
  3020. memcpy(&eps, dst->op_params + 1, sizeof(float));
  3021. int n_channels = src0->ne[2];
  3022. int n_groups = dst->op_params[0];
  3023. int n_channels_per_group = (n_channels + n_groups - 1) / n_groups;
  3024. for (int i = ith; i < n_groups; i += nth) {
  3025. int start = i * n_channels_per_group;
  3026. int end = start + n_channels_per_group;
  3027. if (end > n_channels) {
  3028. end = n_channels;
  3029. }
  3030. int step = end - start;
  3031. for (int64_t i03 = 0; i03 < ne03; i03++) {
  3032. ggml_float sum = 0.0;
  3033. for (int64_t i02 = start; i02 < end; i02++) {
  3034. for (int64_t i01 = 0; i01 < ne01; i01++) {
  3035. const float * x = (float *)((char *) src0->data + i01 * nb01 + i02 * nb02 + i03 * nb03);
  3036. ggml_float sumr = 0.0;
  3037. for (int64_t i00 = 0; i00 < ne00; i00++) {
  3038. sumr += (ggml_float)x[i00];
  3039. }
  3040. sum += sumr;
  3041. }
  3042. }
  3043. const float mean = sum / (ne00 * ne01 * step);
  3044. ggml_float sum2 = 0.0;
  3045. for (int64_t i02 = start; i02 < end; i02++) {
  3046. for (int64_t i01 = 0; i01 < ne01; i01++) {
  3047. const float * x = (float *)((char *) src0->data + i01 * nb01 + i02 * nb02 + i03 * nb03);
  3048. float * y = (float *)((char *) dst->data + i01 * nb1 + i02 * nb2 + i03 * nb3);
  3049. ggml_float sumr = 0.0;
  3050. for (int64_t i00 = 0; i00 < ne00; i00++) {
  3051. float v = x[i00] - mean;
  3052. y[i00] = v;
  3053. sumr += (ggml_float)(v * v);
  3054. }
  3055. sum2 += sumr;
  3056. }
  3057. }
  3058. const float variance = sum2 / (ne00 * ne01 * step);
  3059. const float scale = 1.0f / sqrtf(variance + eps);
  3060. for (int64_t i02 = start; i02 < end; i02++) {
  3061. for (int64_t i01 = 0; i01 < ne01; i01++) {
  3062. float * y = (float *)((char *) dst->data + i01 * nb1 + i02 * nb2 + i03 * nb3);
  3063. ggml_vec_scale_f32(ne00, y, scale);
  3064. }
  3065. }
  3066. }
  3067. }
  3068. }
  3069. void ggml_compute_forward_group_norm(
  3070. const ggml_compute_params * params,
  3071. ggml_tensor * dst) {
  3072. const ggml_tensor * src0 = dst->src[0];
  3073. switch (src0->type) {
  3074. case GGML_TYPE_F32:
  3075. {
  3076. ggml_compute_forward_group_norm_f32(params, dst);
  3077. } break;
  3078. default:
  3079. {
  3080. GGML_ABORT("fatal error");
  3081. }
  3082. }
  3083. }
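// Illustrative sketch, not ggml API: how the loop above maps a group index to its channel
// range. Channels are split into ceil-sized groups, so the last group may be smaller when
// n_channels is not divisible by n_groups. Hypothetical helper name; uses std::min, which
// this file already relies on.
static inline void example_group_channel_range(int n_channels, int n_groups, int group,
                                               int * start, int * end) {
    const int per_group = (n_channels + n_groups - 1) / n_groups; // ceiling division
    *start = group * per_group;
    *end   = std::min(*start + per_group, n_channels);
}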
  3084. // ggml_compute_forward_l2_norm
  3085. static void ggml_compute_forward_l2_norm_f32(
  3086. const ggml_compute_params * params,
  3087. ggml_tensor * dst) {
  3088. const ggml_tensor * src0 = dst->src[0];
  3089. GGML_ASSERT(ggml_are_same_shape(src0, dst));
  3090. GGML_ASSERT(src0->nb[0] == sizeof(float));
  3091. const int ith = params->ith;
  3092. const int nth = params->nth;
  3093. GGML_TENSOR_UNARY_OP_LOCALS
  3094. float eps;
  3095. memcpy(&eps, dst->op_params, sizeof(float));
  3096. GGML_ASSERT(eps >= 0.0f);
  3097. // TODO: optimize
  3098. for (int64_t i03 = 0; i03 < ne03; i03++) {
  3099. for (int64_t i02 = 0; i02 < ne02; i02++) {
  3100. for (int64_t i01 = ith; i01 < ne01; i01 += nth) {
  3101. const float * x = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
  3102. ggml_float sum = 0.0;
  3103. for (int64_t i00 = 0; i00 < ne00; i00++) {
  3104. sum += (ggml_float)(x[i00] * x[i00]);
  3105. }
  3106. float * y = (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3);
  3107. memcpy(y, x, ne00 * sizeof(float));
  3108. const float scale = 1.0f/fmaxf(sqrtf(sum), eps);
  3109. ggml_vec_scale_f32(ne00, y, scale);
  3110. }
  3111. }
  3112. }
  3113. }
  3114. void ggml_compute_forward_l2_norm(
  3115. const ggml_compute_params * params,
  3116. ggml_tensor * dst) {
  3117. const ggml_tensor * src0 = dst->src[0];
  3118. switch (src0->type) {
  3119. case GGML_TYPE_F32:
  3120. {
  3121. ggml_compute_forward_l2_norm_f32(params, dst);
  3122. } break;
  3123. default:
  3124. {
  3125. GGML_ABORT("fatal error");
  3126. }
  3127. }
  3128. }
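// Illustrative sketch, not ggml API: reference for the row-wise L2 normalization above,
// y = x / max(||x||_2, eps). Hypothetical helper name.
static void example_l2_norm_row_reference(int n, float * y, const float * x, float eps) {
    float sum_sq = 0.0f;
    for (int i = 0; i < n; i++) {
        sum_sq += x[i] * x[i];
    }
    const float scale = 1.0f / fmaxf(sqrtf(sum_sq), eps);
    for (int i = 0; i < n; i++) {
        y[i] = x[i] * scale;
    }
}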
  3129. // ggml_compute_forward_out_prod
  3130. static void ggml_compute_forward_out_prod_f32(
  3131. const ggml_compute_params * params,
  3132. ggml_tensor * dst) {
  3133. const ggml_tensor * src0 = dst->src[0];
  3134. const ggml_tensor * src1 = dst->src[1];
  3135. GGML_TENSOR_BINARY_OP_LOCALS
  3136. GGML_ASSERT(dst->type == GGML_TYPE_F32);
  3137. GGML_ASSERT(src0->type == GGML_TYPE_F32);
  3138. GGML_ASSERT(src1->type == GGML_TYPE_F32);
  3139. const int ith = params->ith;
  3140. const int nth = params->nth;
  3141. GGML_ASSERT(ne0 == ne00);
  3142. GGML_ASSERT(ne1 == ne10);
  3143. GGML_ASSERT(ne2 == ne12);
  3144. GGML_ASSERT(ne3 == ne13);
  3145. GGML_ASSERT(ne2 % ne02 == 0);
  3146. GGML_ASSERT(ne3 % ne03 == 0);
  3147. // we don't support permuted src0 or src1
  3148. GGML_ASSERT(nb00 == sizeof(float));
  3149. // dst cannot be transposed or permuted
  3150. GGML_ASSERT(nb0 == sizeof(float));
  3151. // GGML_ASSERT(nb0 <= nb1);
  3152. // GGML_ASSERT(nb1 <= nb2);
  3153. // GGML_ASSERT(nb2 <= nb3);
  3154. // nb01 >= nb00 - src0 is not transposed
  3155. // compute by src0 rows
  3156. if (ith == 0) {
  3157. ggml_vec_set_f32(ne0*ne1*ne2*ne3, (float *)dst->data, 0);
  3158. }
  3159. ggml_barrier(params->threadpool);
  3160. // dst[:,:,:,:] = 0
  3161. // for i2,i3:
  3162. // for i1:
  3163. // for i01:
  3164. // for i0:
  3165. // dst[i0,i1,i2,i3] += src0[i0,i01,i2,i3] * src1[i1,i01,i2,i3]
  3166. // parallelize by last three dimensions
  3167. // total rows in dst
  3168. const int64_t nr = ne1*ne2*ne3;
  3169. // rows per thread
  3170. const int64_t dr = (nr + nth - 1)/nth;
  3171. // row range for this thread
  3172. const int64_t ir0 = dr*ith;
  3173. const int64_t ir1 = MIN(ir0 + dr, nr);
  3174. // block-tiling attempt
  3175. const int64_t blck_0 = MAX(GGML_VEC_MAD_UNROLL, 32);
  3176. const int64_t blck_1 = 16;
  3177. // dps == dst per src0, used for group query attention
  3178. const int64_t dps2 = ne2 / ne02;
  3179. const int64_t dps3 = ne3 / ne03;
  3180. for (int64_t bir = ir0; bir < ir1; bir += blck_1) {
  3181. const int64_t bir1 = MIN(bir + blck_1, ir1);
  3182. for (int64_t bi01 = 0; bi01 < ne01; bi01 += blck_0) {
  3183. const int64_t bne01 = MIN(bi01 + blck_0, ne01);
  3184. for (int64_t ir = bir; ir < bir1; ++ir) {
  3185. // dst indices
  3186. const int64_t i3 = ir/(ne2*ne1);
  3187. const int64_t i2 = (ir - i3*ne2*ne1)/ne1;
  3188. const int64_t i1 = (ir - i3*ne2*ne1 - i2*ne1);
  3189. const int64_t i02 = i2 / dps2;
  3190. const int64_t i03 = i3 / dps3;
  3191. //const int64_t i10 = i1;
  3192. const int64_t i12 = i2;
  3193. const int64_t i13 = i3;
  3194. #if GGML_VEC_MAD_UNROLL > 2
  3195. const int64_t bne01_unroll = bne01 - (bne01 % GGML_VEC_MAD_UNROLL);
  3196. for (int64_t i01 = bi01; i01 < bne01_unroll; i01 += GGML_VEC_MAD_UNROLL) {
  3197. const int64_t i11 = i01;
  3198. float * s0 = (float *) ((char *) src0->data + ( i01*nb01 + i02*nb02 + i03*nb03));
  3199. float * s1 = (float *) ((char *) src1->data + (i1*nb10 + i11*nb11 + i12*nb12 + i13*nb13));
  3200. float * d = (float *) ((char *) dst->data + ( i1*nb1 + i2*nb2 + i3*nb3));
  3201. ggml_vec_mad_f32_unroll(ne0, nb01, nb11, d, s0, s1);
  3202. }
  3203. for (int64_t i01 = bne01_unroll; i01 < bne01; ++i01) {
  3204. const int64_t i11 = i01;
  3205. float * s0 = (float *) ((char *) src0->data + ( i01*nb01 + i02*nb02 + i03*nb03));
  3206. float * s1 = (float *) ((char *) src1->data + (i1*nb10 + i11*nb11 + i12*nb12 + i13*nb13));
  3207. float * d = (float *) ((char *) dst->data + ( i1*nb1 + i2*nb2 + i3*nb3));
  3208. ggml_vec_mad_f32(ne0, d, s0, *s1);
  3209. }
  3210. #else
  3211. for (int64_t i01 = bi01; i01 < bne01; ++i01) {
  3212. const int64_t i11 = i01;
  3213. float * s0 = (float *) ((char *) src0->data + ( i01*nb01 + i02*nb02 + i03*nb03));
  3214. float * s1 = (float *) ((char *) src1->data + (i1*nb10 + i11*nb11 + i12*nb12 + i13*nb13));
  3215. float * d = (float *) ((char *) dst->data + ( i1*nb1 + i2*nb2 + i3*nb3));
  3216. ggml_vec_mad_f32(ne0, d, s0, *s1);
  3217. }
  3218. #endif
  3219. }
  3220. }
  3221. }
  3222. }
  3223. static void ggml_compute_forward_out_prod_q_f32(
  3224. const ggml_compute_params * params,
  3225. ggml_tensor * dst) {
  3226. const ggml_tensor * src0 = dst->src[0];
  3227. const ggml_tensor * src1 = dst->src[1];
  3228. GGML_TENSOR_BINARY_OP_LOCALS;
  3229. const int ith = params->ith;
  3230. const int nth = params->nth;
  3231. const ggml_type type = src0->type;
  3232. ggml_to_float_t const dequantize_row_q = ggml_get_type_traits(type)->to_float;
  3233. GGML_ASSERT(ne02 == ne12);
  3234. GGML_ASSERT(ne03 == ne13);
  3235. GGML_ASSERT(ne2 == ne12);
  3236. GGML_ASSERT(ne3 == ne13);
  3237. // we don't support permuted src0 dim0
  3238. GGML_ASSERT(nb00 == ggml_type_size(type));
  3239. // dst dim0 cannot be transposed or permuted
  3240. GGML_ASSERT(nb0 == sizeof(float));
  3241. // GGML_ASSERT(nb0 <= nb1);
  3242. // GGML_ASSERT(nb1 <= nb2);
  3243. // GGML_ASSERT(nb2 <= nb3);
  3244. GGML_ASSERT(ne0 == ne00);
  3245. GGML_ASSERT(ne1 == ne10);
  3246. GGML_ASSERT(ne2 == ne02);
  3247. GGML_ASSERT(ne3 == ne03);
  3248. // nb01 >= nb00 - src0 is not transposed
  3249. // compute by src0 rows
  3250. if (ith == 0) {
  3251. ggml_vec_set_f32(ne0*ne1*ne2*ne3, (float *)dst->data, 0);
  3252. }
  3253. ggml_barrier(params->threadpool);
  3254. // parallelize by last three dimensions
  3255. // total rows in dst
  3256. const int64_t nr = ne1*ne2*ne3;
  3257. // rows per thread
  3258. const int64_t dr = (nr + nth - 1)/nth;
  3259. // row range for this thread
  3260. const int64_t ir0 = dr*ith;
  3261. const int64_t ir1 = MIN(ir0 + dr, nr);
  3262. // dst[:,:,:,:] = 0
  3263. // for i2,i3:
  3264. // for i1:
  3265. // for i01:
  3266. // for i0:
  3267. // dst[i0,i1,i2,i3] += src0[i0,i01,i2,i3] * src1[i1,i01,i2,i3]
  3268. float * wdata = (float *) params->wdata + (ne0 + CACHE_LINE_SIZE_F32) * ith;
  3269. for (int64_t ir = ir0; ir < ir1; ++ir) {
  3270. // dst indices
  3271. const int64_t i3 = ir/(ne2*ne1);
  3272. const int64_t i2 = (ir - i3*ne2*ne1)/ne1;
  3273. const int64_t i1 = (ir - i3*ne2*ne1 - i2*ne1);
  3274. const int64_t i02 = i2;
  3275. const int64_t i03 = i3;
  3276. //const int64_t i10 = i1;
  3277. const int64_t i12 = i2;
  3278. const int64_t i13 = i3;
  3279. for (int64_t i01 = 0; i01 < ne01; ++i01) {
  3280. const int64_t i11 = i01;
  3281. float * s0 = (float *) ((char *) src0->data + ( i01*nb01 + i02*nb02 + i03*nb03));
  3282. float * s1 = (float *) ((char *) src1->data + (i1*nb10 + i11*nb11 + i12*nb12 + i13*nb13));
  3283. float * d = (float *) ((char *) dst->data + ( i1*nb1 + i2*nb2 + i3*nb3));
  3284. dequantize_row_q(s0, wdata, ne0);
  3285. ggml_vec_mad_f32(ne0, d, wdata, *s1);
  3286. }
  3287. }
  3288. }
  3289. void ggml_compute_forward_out_prod(
  3290. const ggml_compute_params * params,
  3291. ggml_tensor * dst) {
  3292. const ggml_tensor * src0 = dst->src[0];
  3293. switch (src0->type) {
  3294. case GGML_TYPE_Q4_0:
  3295. case GGML_TYPE_Q4_1:
  3296. case GGML_TYPE_Q5_0:
  3297. case GGML_TYPE_Q5_1:
  3298. case GGML_TYPE_Q8_0:
  3299. case GGML_TYPE_MXFP4:
  3300. case GGML_TYPE_Q2_K:
  3301. case GGML_TYPE_Q3_K:
  3302. case GGML_TYPE_Q4_K:
  3303. case GGML_TYPE_Q5_K:
  3304. case GGML_TYPE_Q6_K:
  3305. case GGML_TYPE_TQ1_0:
  3306. case GGML_TYPE_TQ2_0:
  3307. case GGML_TYPE_IQ2_XXS:
  3308. case GGML_TYPE_IQ2_XS:
  3309. case GGML_TYPE_IQ3_XXS:
  3310. case GGML_TYPE_IQ1_S:
  3311. case GGML_TYPE_IQ1_M:
  3312. case GGML_TYPE_IQ4_NL:
  3313. case GGML_TYPE_IQ4_XS:
  3314. case GGML_TYPE_IQ3_S:
  3315. case GGML_TYPE_IQ2_S:
  3316. {
  3317. ggml_compute_forward_out_prod_q_f32(params, dst);
  3318. } break;
  3319. case GGML_TYPE_F16:
  3320. {
  3321. GGML_ABORT("fatal error"); // todo
  3322. // ggml_compute_forward_out_prod_f16_f32(params, dst);
  3323. }
  3324. case GGML_TYPE_F32:
  3325. {
  3326. ggml_compute_forward_out_prod_f32(params, dst);
  3327. } break;
  3328. default:
  3329. {
  3330. GGML_ABORT("fatal error");
  3331. }
  3332. }
  3333. }
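// Illustrative sketch, not ggml API: the 2-D core of the outer-product accumulation above,
// dst[i0,i1] = sum_k src0[i0,k] * src1[i1,k], written with plain row-major arrays instead of
// ggml strides and ignoring the batch/broadcast dimensions and block tiling.
// Hypothetical helper name.
static void example_out_prod_2d_reference(int64_t ne0, int64_t ne1, int64_t ne01,
                                          float * dst, const float * s0, const float * s1) {
    // dst: ne1 rows of length ne0; s0: ne01 rows of length ne0; s1: ne01 rows of length ne1
    for (int64_t i1 = 0; i1 < ne1; i1++) {
        for (int64_t i0 = 0; i0 < ne0; i0++) {
            float acc = 0.0f;
            for (int64_t k = 0; k < ne01; k++) {
                acc += s0[k*ne0 + i0] * s1[k*ne1 + i1];
            }
            dst[i1*ne0 + i0] = acc;
        }
    }
}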
  3334. // ggml_compute_forward_scale
  3335. static void ggml_compute_forward_scale_f32(
  3336. const ggml_compute_params * params,
  3337. ggml_tensor * dst) {
  3338. const ggml_tensor * src0 = dst->src[0];
  3339. GGML_ASSERT(ggml_is_contiguous(src0));
  3340. GGML_ASSERT(ggml_is_contiguous(dst));
  3341. GGML_ASSERT(ggml_are_same_shape(src0, dst));
  3342. float s; // scale factor
  3343. float b; // bias
  3344. memcpy(&s, (float *) dst->op_params + 0, sizeof(float));
  3345. memcpy(&b, (float *) dst->op_params + 1, sizeof(float));
  3346. const int ith = params->ith;
  3347. const int nth = params->nth;
  3348. const int nc = src0->ne[0];
  3349. const int nr = ggml_nrows(src0);
  3350. // rows per thread
  3351. const int dr = (nr + nth - 1)/nth;
  3352. // row range for this thread
  3353. const int ir0 = dr*ith;
  3354. const int ir1 = MIN(ir0 + dr, nr);
  3355. const size_t nb01 = src0->nb[1];
  3356. const size_t nb1 = dst->nb[1];
  3357. if (b == 0.0f) {
  3358. for (int i1 = ir0; i1 < ir1; i1++) {
  3359. if (dst->data != src0->data) {
  3360. // src0 is same shape as dst => same indices
  3361. // TODO: add x parameter to ggml_vec_scale_f32 and remove this memcpy
  3362. memcpy((char *)dst->data + i1*nb1, (char *)src0->data + i1*nb01, nc * sizeof(float));
  3363. }
  3364. ggml_vec_scale_f32(nc, (float *) ((char *) dst->data + i1*nb1), s);
  3365. }
  3366. } else {
  3367. for (int i1 = ir0; i1 < ir1; i1++) {
  3368. ggml_vec_mad1_f32(nc,
  3369. (float *) ((char *) dst->data + i1*nb1),
  3370. (float *) ((char *) src0->data + i1*nb1),
  3371. s, b);
  3372. }
  3373. }
  3374. }
  3375. void ggml_compute_forward_scale(
  3376. const ggml_compute_params * params,
  3377. ggml_tensor * dst) {
  3378. const ggml_tensor * src0 = dst->src[0];
  3379. switch (src0->type) {
  3380. case GGML_TYPE_F32:
  3381. {
  3382. ggml_compute_forward_scale_f32(params, dst);
  3383. } break;
  3384. default:
  3385. {
  3386. GGML_ABORT("fatal error");
  3387. }
  3388. }
  3389. }
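// Illustrative sketch, not ggml API: the element-wise effect of the op above, y = s*x + b,
// assuming ggml_vec_mad1_f32 computes exactly that (the b == 0 branch then reduces to a pure
// scale). Hypothetical helper name.
static inline void example_scale_row_reference(int n, float * y, const float * x, float s, float b) {
    for (int i = 0; i < n; i++) {
        y[i] = x[i] * s + b;
    }
}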
  3390. // ggml_compute_forward_set
  3391. static void ggml_compute_forward_set_f32(
  3392. const ggml_compute_params * params,
  3393. ggml_tensor * dst) {
  3394. const ggml_tensor * src0 = dst->src[0];
  3395. const ggml_tensor * src1 = dst->src[1];
  3396. GGML_ASSERT(ggml_are_same_shape(src0, dst));
  3397. GGML_ASSERT(ggml_is_contiguous(dst) && ggml_is_contiguous(src0));
3398. // view src0 and dst with these strides and data offset in bytes during set
  3399. // nb0 is implicitly element_size because src0 and dst are contiguous
  3400. size_t nb1 = ((int32_t *) dst->op_params)[0];
  3401. size_t nb2 = ((int32_t *) dst->op_params)[1];
  3402. size_t nb3 = ((int32_t *) dst->op_params)[2];
  3403. size_t offset = ((int32_t *) dst->op_params)[3];
  3404. bool inplace = (bool) ((int32_t *) dst->op_params)[4];
  3405. if (!inplace) {
  3406. if (params->ith == 0) {
  3407. // memcpy needs to be synchronized across threads to avoid race conditions.
  3408. // => do it in INIT phase
  3409. memcpy(
  3410. ((char *) dst->data),
  3411. ((char *) src0->data),
  3412. ggml_nbytes(dst));
  3413. }
  3414. ggml_barrier(params->threadpool);
  3415. }
  3416. const int ith = params->ith;
  3417. const int nth = params->nth;
  3418. const int nr = ggml_nrows(src1);
  3419. const int nc = src1->ne[0];
  3420. GGML_TENSOR_LOCALS(int64_t, ne1, src1, ne)
  3421. GGML_TENSOR_LOCALS(size_t, nb1, src1, nb)
  3422. // src0 and dst as viewed during set
  3423. const size_t nb0 = ggml_element_size(src0);
  3424. const int im0 = (ne10 == 0 ? 0 : ne10-1);
  3425. const int im1 = (ne11 == 0 ? 0 : ne11-1);
  3426. const int im2 = (ne12 == 0 ? 0 : ne12-1);
  3427. const int im3 = (ne13 == 0 ? 0 : ne13-1);
  3428. GGML_ASSERT(offset + im0*nb0 + im1*nb1 + im2*nb2 + im3*nb3 <= ggml_nbytes(dst));
  3429. GGML_ASSERT(nb10 == sizeof(float));
  3430. // rows per thread
  3431. const int dr = (nr + nth - 1)/nth;
  3432. // row range for this thread
  3433. const int ir0 = dr*ith;
  3434. const int ir1 = MIN(ir0 + dr, nr);
  3435. for (int ir = ir0; ir < ir1; ++ir) {
  3436. // src0 and dst are viewed with shape of src1 and offset
  3437. // => same indices
  3438. const int i3 = ir/(ne12*ne11);
  3439. const int i2 = (ir - i3*ne12*ne11)/ne11;
  3440. const int i1 = (ir - i3*ne12*ne11 - i2*ne11);
  3441. ggml_vec_cpy_f32(nc,
  3442. (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + offset),
  3443. (float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11));
  3444. }
  3445. }
  3446. static void ggml_compute_forward_set_i32(
  3447. const ggml_compute_params * params,
  3448. ggml_tensor * dst) {
  3449. const ggml_tensor * src0 = dst->src[0];
  3450. const ggml_tensor * src1 = dst->src[1];
  3451. GGML_ASSERT(ggml_are_same_shape(src0, dst));
  3452. GGML_ASSERT(ggml_is_contiguous(dst) && ggml_is_contiguous(src0));
3453. // view src0 and dst with these strides and data offset in bytes during set
  3454. // nb0 is implicitly element_size because src0 and dst are contiguous
  3455. size_t nb1 = ((int32_t *) dst->op_params)[0];
  3456. size_t nb2 = ((int32_t *) dst->op_params)[1];
  3457. size_t nb3 = ((int32_t *) dst->op_params)[2];
  3458. size_t offset = ((int32_t *) dst->op_params)[3];
  3459. bool inplace = (bool) ((int32_t *) dst->op_params)[4];
  3460. if (!inplace) {
  3461. if (params->ith == 0) {
  3462. // memcpy needs to be synchronized across threads to avoid race conditions.
  3463. // => do it in INIT phase
  3464. memcpy(
  3465. ((char *) dst->data),
  3466. ((char *) src0->data),
  3467. ggml_nbytes(dst));
  3468. }
  3469. ggml_barrier(params->threadpool);
  3470. }
  3471. const int ith = params->ith;
  3472. const int nth = params->nth;
  3473. const int nr = ggml_nrows(src1);
  3474. const int nc = src1->ne[0];
  3475. GGML_TENSOR_LOCALS(int64_t, ne1, src1, ne)
  3476. GGML_TENSOR_LOCALS(size_t, nb1, src1, nb)
  3477. // src0 and dst as viewed during set
  3478. const size_t nb0 = ggml_element_size(src0);
  3479. const int im0 = (ne10 == 0 ? 0 : ne10-1);
  3480. const int im1 = (ne11 == 0 ? 0 : ne11-1);
  3481. const int im2 = (ne12 == 0 ? 0 : ne12-1);
  3482. const int im3 = (ne13 == 0 ? 0 : ne13-1);
  3483. GGML_ASSERT(offset + im0*nb0 + im1*nb1 + im2*nb2 + im3*nb3 <= ggml_nbytes(dst));
  3484. GGML_ASSERT(nb10 == sizeof(int32_t));
  3485. // rows per thread
  3486. const int dr = (nr + nth - 1)/nth;
  3487. // row range for this thread
  3488. const int ir0 = dr*ith;
  3489. const int ir1 = MIN(ir0 + dr, nr);
  3490. for (int ir = ir0; ir < ir1; ++ir) {
  3491. // src0 and dst are viewed with shape of src1 and offset
  3492. // => same indices
  3493. const int i3 = ir/(ne12*ne11);
  3494. const int i2 = (ir - i3*ne12*ne11)/ne11;
  3495. const int i1 = (ir - i3*ne12*ne11 - i2*ne11);
  3496. ggml_vec_cpy_i32(nc,
  3497. (int32_t *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + offset),
  3498. (int32_t *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11));
  3499. }
  3500. }
  3501. void ggml_compute_forward_set(
  3502. const ggml_compute_params * params,
  3503. ggml_tensor * dst) {
  3504. const ggml_tensor * src0 = dst->src[0];
  3505. switch (src0->type) {
  3506. case GGML_TYPE_F32:
  3507. {
  3508. ggml_compute_forward_set_f32(params, dst);
  3509. } break;
  3510. case GGML_TYPE_I32:
  3511. {
  3512. ggml_compute_forward_set_i32(params, dst);
  3513. } break;
  3514. case GGML_TYPE_F16:
  3515. case GGML_TYPE_BF16:
  3516. case GGML_TYPE_Q4_0:
  3517. case GGML_TYPE_Q4_1:
  3518. case GGML_TYPE_Q5_0:
  3519. case GGML_TYPE_Q5_1:
  3520. case GGML_TYPE_Q8_0:
  3521. case GGML_TYPE_Q8_1:
  3522. case GGML_TYPE_MXFP4:
  3523. case GGML_TYPE_Q2_K:
  3524. case GGML_TYPE_Q3_K:
  3525. case GGML_TYPE_Q4_K:
  3526. case GGML_TYPE_Q5_K:
  3527. case GGML_TYPE_Q6_K:
  3528. case GGML_TYPE_TQ1_0:
  3529. case GGML_TYPE_TQ2_0:
  3530. case GGML_TYPE_IQ2_XXS:
  3531. case GGML_TYPE_IQ2_XS:
  3532. case GGML_TYPE_IQ3_XXS:
  3533. case GGML_TYPE_IQ1_S:
  3534. case GGML_TYPE_IQ1_M:
  3535. case GGML_TYPE_IQ4_NL:
  3536. case GGML_TYPE_IQ4_XS:
  3537. case GGML_TYPE_IQ3_S:
  3538. case GGML_TYPE_IQ2_S:
  3539. default:
  3540. {
  3541. GGML_ABORT("fatal error");
  3542. }
  3543. }
  3544. }
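// Reading the SET kernels above: unless the inplace flag (op_params[4]) is set, dst is first
// initialized as a copy of src0, then src1 is written into a strided view of dst described by
// op_params = { nb1, nb2, nb3, offset } (all in bytes), with src1's shape driving the loops.
// A minimal usage sketch, kept as a comment; it assumes the usual graph-building helpers from
// ggml.h (e.g. ggml_set_1d) with the parameter order shown:
//
//   // write 1-D tensor b over row r of a  =>  the view offset is r * a->nb[1] bytes
//   // struct ggml_tensor * c = ggml_set_1d(ctx, a, b, r*a->nb[1]);
//
// Inside the kernel, row ir of src1 then lands at
//   (char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + offset
// with (i1, i2, i3) decoded from ir exactly as in the loops above.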
  3545. // ggml_compute_forward_cpy
  3546. void ggml_compute_forward_cpy(
  3547. const ggml_compute_params * params,
  3548. ggml_tensor * dst) {
  3549. ggml_compute_forward_dup(params, dst);
  3550. }
  3551. // ggml_compute_forward_cont
  3552. void ggml_compute_forward_cont(
  3553. const ggml_compute_params * params,
  3554. ggml_tensor * dst) {
  3555. ggml_compute_forward_dup(params, dst);
  3556. }
  3557. // ggml_compute_forward_get_rows
  3558. static void ggml_compute_forward_get_rows_q(
  3559. const ggml_compute_params * params,
  3560. ggml_tensor * dst) {
  3561. const ggml_tensor * src0 = dst->src[0];
  3562. const ggml_tensor * src1 = dst->src[1];
  3563. GGML_TENSOR_BINARY_OP_LOCALS
  3564. const int64_t nc = ne00;
  3565. const int64_t nr = ggml_nelements(src1);
  3566. const ggml_type type = src0->type;
  3567. ggml_to_float_t const dequantize_row_q = ggml_get_type_traits(type)->to_float;
  3568. assert(ne0 == nc);
  3569. assert(ne02 == ne11);
  3570. assert(nb00 == ggml_type_size(type));
  3571. assert(ggml_nrows(dst) == nr);
  3572. const int ith = params->ith;
  3573. const int nth = params->nth;
  3574. // rows per thread
  3575. const int dr = (nr + nth - 1)/nth;
  3576. // row range for this thread
  3577. const int ir0 = dr*ith;
  3578. const int ir1 = MIN(ir0 + dr, nr);
  3579. for (int64_t i = ir0; i < ir1; ++i) {
  3580. const int64_t i12 = i/(ne11*ne10);
  3581. const int64_t i11 = (i - i12*ne11*ne10)/ne10;
  3582. const int64_t i10 = (i - i12*ne11*ne10 - i11*ne10);
  3583. const int64_t i01 = *(int32_t *) ((char *) src1->data + i10*nb10 + i11*nb11 + i12*nb12);
  3584. GGML_ASSERT(i01 >= 0 && i01 < ne01);
  3585. dequantize_row_q(
  3586. (const void *) ((char *) src0->data + i01*nb01 + i11*nb02 + i12*nb03),
  3587. (float *) ((char *) dst->data + i10*nb1 + i11*nb2 + i12*nb3), nc);
  3588. }
  3589. }
  3590. static void ggml_compute_forward_get_rows_f16(
  3591. const ggml_compute_params * params,
  3592. ggml_tensor * dst) {
  3593. const ggml_tensor * src0 = dst->src[0];
  3594. const ggml_tensor * src1 = dst->src[1];
  3595. GGML_TENSOR_BINARY_OP_LOCALS
  3596. const int64_t nc = ne00;
  3597. const int64_t nr = ggml_nelements(src1);
  3598. assert(ne0 == nc);
  3599. assert(ne02 == ne11);
  3600. assert(nb00 == sizeof(ggml_fp16_t));
  3601. assert(ggml_nrows(dst) == nr);
  3602. const int ith = params->ith;
  3603. const int nth = params->nth;
  3604. // rows per thread
  3605. const int dr = (nr + nth - 1)/nth;
  3606. // row range for this thread
  3607. const int ir0 = dr*ith;
  3608. const int ir1 = MIN(ir0 + dr, nr);
  3609. for (int64_t i = ir0; i < ir1; ++i) {
  3610. const int64_t i12 = i/(ne11*ne10);
  3611. const int64_t i11 = (i - i12*ne11*ne10)/ne10;
  3612. const int64_t i10 = (i - i12*ne11*ne10 - i11*ne10);
  3613. const int64_t i01 = *(int32_t *) ((char *) src1->data + i10*nb10 + i11*nb11 + i12*nb12);
  3614. GGML_ASSERT(i01 >= 0 && i01 < ne01);
  3615. ggml_cpu_fp16_to_fp32(
  3616. (const ggml_fp16_t*) ((char *) src0->data + i01*nb01 + i11*nb02 + i12*nb03),
  3617. (float *) ((char *) dst->data + i10*nb1 + i11*nb2 + i12*nb3), nc);
  3618. }
  3619. }
  3620. static void ggml_compute_forward_get_rows_bf16(
  3621. const ggml_compute_params * params,
  3622. ggml_tensor * dst) {
  3623. const ggml_tensor * src0 = dst->src[0];
  3624. const ggml_tensor * src1 = dst->src[1];
  3625. GGML_TENSOR_BINARY_OP_LOCALS
  3626. const int64_t nc = ne00;
  3627. const int64_t nr = ggml_nelements(src1);
  3628. assert(ne0 == nc);
  3629. assert(ne02 == ne11);
  3630. assert(nb00 == sizeof(ggml_bf16_t));
  3631. assert(ggml_nrows(dst) == nr);
  3632. const int ith = params->ith;
  3633. const int nth = params->nth;
  3634. // rows per thread
  3635. const int dr = (nr + nth - 1)/nth;
  3636. // row range for this thread
  3637. const int ir0 = dr*ith;
  3638. const int ir1 = MIN(ir0 + dr, nr);
  3639. for (int64_t i = ir0; i < ir1; ++i) {
  3640. const int64_t i12 = i/(ne11*ne10);
  3641. const int64_t i11 = (i - i12*ne11*ne10)/ne10;
  3642. const int64_t i10 = (i - i12*ne11*ne10 - i11*ne10);
  3643. const int64_t i01 = *(int32_t *) ((char *) src1->data + i10*nb10 + i11*nb11 + i12*nb12);
  3644. GGML_ASSERT(i01 >= 0 && i01 < ne01);
  3645. ggml_cpu_bf16_to_fp32(
  3646. (const ggml_bf16_t *) ((char *) src0->data + i01*nb01 + i11*nb02 + i12*nb03),
  3647. (float *) ((char *) dst->data + i10*nb1 + i11*nb2 + i12*nb3), nc);
  3648. }
  3649. }
  3650. static void ggml_compute_forward_get_rows_f32(
  3651. const ggml_compute_params * params,
  3652. ggml_tensor * dst) {
  3653. const ggml_tensor * src0 = dst->src[0];
  3654. const ggml_tensor * src1 = dst->src[1];
  3655. GGML_TENSOR_BINARY_OP_LOCALS
  3656. const int64_t nc = ne00;
  3657. const int64_t nr = ggml_nelements(src1);
  3658. assert(ne0 == nc);
  3659. assert(ne02 == ne11);
  3660. assert(nb00 == sizeof(float));
  3661. assert(ggml_nrows(dst) == nr);
  3662. const int ith = params->ith;
  3663. const int nth = params->nth;
  3664. // rows per thread
  3665. const int dr = (nr + nth - 1)/nth;
  3666. // row range for this thread
  3667. const int ir0 = dr*ith;
  3668. const int ir1 = MIN(ir0 + dr, nr);
  3669. for (int64_t i = ir0; i < ir1; ++i) {
  3670. const int64_t i12 = i/(ne11*ne10);
  3671. const int64_t i11 = (i - i12*ne11*ne10)/ne10;
  3672. const int64_t i10 = (i - i12*ne11*ne10 - i11*ne10);
  3673. const int64_t i01 = *(int32_t *) ((char *) src1->data + i10*nb10 + i11*nb11 + i12*nb12);
  3674. GGML_ASSERT(i01 >= 0 && i01 < ne01);
  3675. ggml_vec_cpy_f32(nc,
  3676. (float *) ((char *) dst->data + i10*nb1 + i11*nb2 + i12*nb3),
  3677. (float *) ((char *) src0->data + i01*nb01 + i11*nb02 + i12*nb03));
  3678. }
  3679. }
  3680. void ggml_compute_forward_get_rows(
  3681. const ggml_compute_params * params,
  3682. ggml_tensor * dst) {
  3683. const ggml_tensor * src0 = dst->src[0];
  3684. switch (src0->type) {
  3685. case GGML_TYPE_Q4_0:
  3686. case GGML_TYPE_Q4_1:
  3687. case GGML_TYPE_Q5_0:
  3688. case GGML_TYPE_Q5_1:
  3689. case GGML_TYPE_Q8_0:
  3690. case GGML_TYPE_Q8_1:
  3691. case GGML_TYPE_MXFP4:
  3692. case GGML_TYPE_Q2_K:
  3693. case GGML_TYPE_Q3_K:
  3694. case GGML_TYPE_Q4_K:
  3695. case GGML_TYPE_Q5_K:
  3696. case GGML_TYPE_Q6_K:
  3697. case GGML_TYPE_TQ1_0:
  3698. case GGML_TYPE_TQ2_0:
  3699. case GGML_TYPE_IQ2_XXS:
  3700. case GGML_TYPE_IQ2_XS:
  3701. case GGML_TYPE_IQ3_XXS:
  3702. case GGML_TYPE_IQ1_S:
  3703. case GGML_TYPE_IQ1_M:
  3704. case GGML_TYPE_IQ4_NL:
  3705. case GGML_TYPE_IQ4_XS:
  3706. case GGML_TYPE_IQ3_S:
  3707. case GGML_TYPE_IQ2_S:
  3708. {
  3709. ggml_compute_forward_get_rows_q(params, dst);
  3710. } break;
  3711. case GGML_TYPE_F16:
  3712. {
  3713. ggml_compute_forward_get_rows_f16(params, dst);
  3714. } break;
  3715. case GGML_TYPE_BF16:
  3716. {
  3717. ggml_compute_forward_get_rows_bf16(params, dst);
  3718. } break;
  3719. case GGML_TYPE_F32:
  3720. case GGML_TYPE_I32:
  3721. {
  3722. ggml_compute_forward_get_rows_f32(params, dst);
  3723. } break;
  3724. default:
  3725. {
  3726. GGML_ABORT("fatal error");
  3727. }
  3728. }
  3729. //static bool first = true;
  3730. //printf("ne0 = %d, ne1 = %d, ne2 = %d\n", dst->ne[0], dst->ne[1], dst->ne[2]);
  3731. //if (first) {
  3732. // first = false;
  3733. //} else {
  3734. // for (int k = 0; k < dst->ne[1]; ++k) {
  3735. // for (int j = 0; j < dst->ne[0]/16; ++j) {
  3736. // for (int i = 0; i < 16; ++i) {
  3737. // printf("%8.4f ", ((float *) dst->data)[k*dst->ne[0] + j*16 + i]);
  3738. // }
  3739. // printf("\n");
  3740. // }
  3741. // printf("\n");
  3742. // }
  3743. // printf("\n");
  3744. // exit(0);
  3745. //}
  3746. }
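// Worked sketch of the row split used by the get_rows kernels above (and by most kernels in this
// file): with nr = 10 rows and nth = 4 threads, dr = (10 + 4 - 1)/4 = 3, so the threads process the
// half-open ranges [0,3), [3,6), [6,9) and [9,10). The flat index i over src1's elements is decoded
// back into coordinates
//   i12 = i/(ne11*ne10);   i11 = (i - i12*ne11*ne10)/ne10;   i10 = i - i12*ne11*ne10 - i11*ne10;
// and the looked-up value i01 = src1[i10, i11, i12] selects row i01 from the (i11, i12) slice of
// src0 (note the assert ne02 == ne11: src0's batch dim must match src1's second dim).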
  3747. template<typename idx_t>
  3748. static void ggml_compute_forward_set_rows_f32(
  3749. const ggml_compute_params * params,
  3750. ggml_tensor * dst) {
  3751. const ggml_tensor * src0 = dst->src[0];
  3752. const ggml_tensor * src1 = dst->src[1];
  3753. GGML_TENSOR_BINARY_OP_LOCALS
  3754. const int64_t nc = ne00;
  3755. const int64_t nr = ne01;
  3756. assert(ne0 == nc);
  3757. assert(ne2 == ne02);
  3758. assert(ne3 == ne03);
  3759. assert(src0->type == GGML_TYPE_F32);
  3760. assert(ne02 % ne11 == 0);
  3761. assert(ne03 % ne12 == 0);
  3762. const int ith = params->ith;
  3763. const int nth = params->nth;
  3764. // rows per thread
  3765. const int64_t dr = (nr + nth - 1)/nth;
  3766. // row range for this thread
  3767. const int64_t ir0 = dr*ith;
  3768. const int64_t ir1 = std::min(ir0 + dr, nr);
  3769. ggml_from_float_t const from_float = ggml_get_type_traits_cpu(dst->type)->from_float;
  3770. for (int64_t i03 = 0; i03 < ne03; ++i03) {
  3771. for (int64_t i02 = 0; i02 < ne02; ++i02) {
  3772. for (int64_t i = ir0; i < ir1; ++i) {
  3773. const int64_t i12 = i03%ne12;
  3774. const int64_t i11 = i02%ne11;
  3775. const int64_t i10 = i;
  3776. const int64_t i1 = *(idx_t *) ((char *) src1->data + i10*nb10 + i11*nb11 + i12*nb12);
  3777. GGML_ASSERT(i1 >= 0 && i1 < ne1);
  3778. from_float(
  3779. (const float *) ((char *) src0->data + i*nb01 + i02*nb02 + i03*nb03),
  3780. ((char *) dst->data + i1*nb1 + i02*nb2 + i03*nb3), nc);
  3781. }
  3782. }
  3783. }
  3784. }
  3785. void ggml_compute_forward_set_rows(
  3786. const ggml_compute_params * params,
  3787. ggml_tensor * dst) {
  3788. const ggml_tensor * src0 = dst->src[0];
  3789. const ggml_tensor * src1 = dst->src[1];
  3790. switch (src0->type) {
  3791. case GGML_TYPE_F32:
  3792. {
  3793. if (src1->type == GGML_TYPE_I64) {
  3794. ggml_compute_forward_set_rows_f32<int64_t>(params, dst);
  3795. } else if (src1->type == GGML_TYPE_I32) {
  3796. ggml_compute_forward_set_rows_f32<int32_t>(params, dst);
  3797. } else {
  3798. GGML_ABORT("src1->type = %d (%s) not supported", src1->type, ggml_type_name(src1->type));
  3799. }
  3800. } break;
  3801. default:
  3802. {
  3803. GGML_ABORT("src0->type = %d (%s) not supported", src0->type, ggml_type_name(src0->type));
  3804. }
  3805. }
  3806. }
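// SET_ROWS is the scatter counterpart of GET_ROWS above: src0 supplies F32 rows, src1 supplies the
// destination row indices (I32 or I64), and each row goes through the destination type's from_float
// conversion, so dst may be F32, F16 or a quantized type. A hedged usage sketch (assuming a
// ggml_set_rows(ctx, dst, src, idx) builder with this operand order):
//
//   // for every source row i:  dst[idx[i], :] = convert_to_dst_type(src[i, :])
//
// Threads split over source rows (ir0..ir1); if the same destination index appears more than once,
// which source row ends up in dst is unspecified.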
  3807. // ggml_compute_forward_get_rows_back
  3808. static void ggml_compute_forward_get_rows_back_f32_f16(
  3809. const ggml_compute_params * params,
  3810. ggml_tensor * dst) {
  3811. const ggml_tensor * src0 = dst->src[0];
  3812. const ggml_tensor * src1 = dst->src[1];
  3813. if (params->ith != 0) {
  3814. return;
  3815. }
  3816. GGML_ASSERT(ggml_is_contiguous(dst));
  3817. // ggml_compute_forward_dup_same_cont(params, opt0, dst);
  3818. memset(dst->data, 0, ggml_nbytes(dst));
  3819. const int nc = src0->ne[0];
  3820. const int nr = ggml_nelements(src1);
  3821. GGML_ASSERT( dst->ne[0] == nc);
  3822. GGML_ASSERT(src0->nb[0] == sizeof(ggml_fp16_t));
  3823. for (int i = 0; i < nr; ++i) {
  3824. const int r = ((int32_t *) src1->data)[i];
  3825. for (int j = 0; j < nc; ++j) {
  3826. ggml_fp16_t v = ((ggml_fp16_t *) ((char *) src0->data + i*src0->nb[1]))[j];
  3827. ((float *) ((char *) dst->data + r*dst->nb[1]))[j] += GGML_CPU_FP16_TO_FP32(v);
  3828. }
  3829. }
  3830. }
  3831. static void ggml_compute_forward_get_rows_back_f32(
  3832. const ggml_compute_params * params,
  3833. ggml_tensor * dst) {
  3834. const ggml_tensor * src0 = dst->src[0];
  3835. const ggml_tensor * src1 = dst->src[1];
  3836. if (params->ith != 0) {
  3837. return;
  3838. }
  3839. GGML_ASSERT(ggml_is_contiguous(dst));
  3840. // ggml_compute_forward_dup_same_cont(params, opt0, dst);
  3841. memset(dst->data, 0, ggml_nbytes(dst));
  3842. const int nc = src0->ne[0];
  3843. const int nr = ggml_nelements(src1);
  3844. GGML_ASSERT( dst->ne[0] == nc);
  3845. GGML_ASSERT(src0->nb[0] == sizeof(float));
  3846. for (int i = 0; i < nr; ++i) {
  3847. const int r = ((int32_t *) src1->data)[i];
  3848. ggml_vec_add_f32(nc,
  3849. (float *) ((char *) dst->data + r*dst->nb[1]),
  3850. (float *) ((char *) dst->data + r*dst->nb[1]),
  3851. (float *) ((char *) src0->data + i*src0->nb[1]));
  3852. }
  3853. }
  3854. void ggml_compute_forward_get_rows_back(
  3855. const ggml_compute_params * params,
  3856. ggml_tensor * dst) {
  3857. const ggml_tensor * src0 = dst->src[0];
  3858. switch (src0->type) {
  3859. case GGML_TYPE_F16:
  3860. {
  3861. ggml_compute_forward_get_rows_back_f32_f16(params, dst);
  3862. } break;
  3863. case GGML_TYPE_F32:
  3864. {
  3865. ggml_compute_forward_get_rows_back_f32(params, dst);
  3866. } break;
  3867. default:
  3868. {
  3869. GGML_ABORT("fatal error");
  3870. }
  3871. }
  3872. //static bool first = true;
  3873. //printf("ne0 = %d, ne1 = %d, ne2 = %d\n", dst->ne[0], dst->ne[1], dst->ne[2]);
  3874. //if (first) {
  3875. // first = false;
  3876. //} else {
  3877. // for (int k = 0; k < dst->ne[1]; ++k) {
  3878. // for (int j = 0; j < dst->ne[0]/16; ++j) {
  3879. // for (int i = 0; i < 16; ++i) {
  3880. // printf("%8.4f ", ((float *) dst->data)[k*dst->ne[0] + j*16 + i]);
  3881. // }
  3882. // printf("\n");
  3883. // }
  3884. // printf("\n");
  3885. // }
  3886. // printf("\n");
  3887. // exit(0);
  3888. //}
  3889. }
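// get_rows_back is the gradient of get_rows: each incoming gradient row i of src0 is scatter-added
// into row r = src1[i] of dst, and rows that were gathered more than once accumulate. Tiny example:
//   src1 = {2, 0, 2}  =>  dst[2] += src0[0];  dst[0] += src0[1];  dst[2] += src0[2];
// which is why dst is zeroed with memset first and the whole op runs on thread 0 only.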
  3890. // ggml_compute_forward_diag
  3891. static void ggml_compute_forward_diag_f32(
  3892. const ggml_compute_params * params,
  3893. ggml_tensor * dst) {
  3894. const ggml_tensor * src0 = dst->src[0];
  3895. if (params->ith != 0) {
  3896. return;
  3897. }
  3898. // TODO: handle transposed/permuted matrices
  3899. GGML_TENSOR_UNARY_OP_LOCALS
  3900. GGML_ASSERT(ne00 == ne0);
  3901. GGML_ASSERT(ne00 == ne1);
  3902. GGML_ASSERT(ne01 == 1);
  3903. GGML_ASSERT(ne02 == ne2);
  3904. GGML_ASSERT(ne03 == ne3);
  3905. GGML_ASSERT(nb00 == sizeof(float));
  3906. GGML_ASSERT(nb0 == sizeof(float));
  3907. for (int i3 = 0; i3 < ne3; i3++) {
  3908. for (int i2 = 0; i2 < ne2; i2++) {
  3909. for (int i1 = 0; i1 < ne1; i1++) {
  3910. float * d = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1);
  3911. float * s = (float *)((char *) src0->data + i3*nb03 + i2*nb02);
  3912. for (int i0 = 0; i0 < i1; i0++) {
  3913. d[i0] = 0;
  3914. }
  3915. d[i1] = s[i1];
  3916. for (int i0 = i1+1; i0 < ne0; i0++) {
  3917. d[i0] = 0;
  3918. }
  3919. }
  3920. }
  3921. }
  3922. }
  3923. void ggml_compute_forward_diag(
  3924. const ggml_compute_params * params,
  3925. ggml_tensor * dst) {
  3926. const ggml_tensor * src0 = dst->src[0];
  3927. switch (src0->type) {
  3928. case GGML_TYPE_F32:
  3929. {
  3930. ggml_compute_forward_diag_f32(params, dst);
  3931. } break;
  3932. default:
  3933. {
  3934. GGML_ABORT("fatal error");
  3935. }
  3936. }
  3937. }
  3938. // ggml_compute_forward_diag_mask_inf
  3939. static void ggml_compute_forward_diag_mask_f32(
  3940. const ggml_compute_params * params,
  3941. ggml_tensor * dst,
  3942. const float value) {
  3943. const ggml_tensor * src0 = dst->src[0];
  3944. const int ith = params->ith;
  3945. const int nth = params->nth;
  3946. const int n_past = ((int32_t *) dst->op_params)[0];
  3947. const bool inplace = src0->data == dst->data;
  3948. GGML_ASSERT(n_past >= 0);
  3949. if (!inplace) {
  3950. if (ith == 0) {
3951. // memcpy needs to be synchronized across threads to avoid race conditions,
3952. // => do it on thread 0 only and sync with the ggml_barrier below
  3953. GGML_ASSERT(ggml_nelements(dst) == ggml_nelements(src0));
  3954. GGML_ASSERT(ggml_is_contiguous(dst) && ggml_is_contiguous(src0));
  3955. memcpy(
  3956. ((char *) dst->data),
  3957. ((char *) src0->data),
  3958. ggml_nbytes(dst));
  3959. }
  3960. ggml_barrier(params->threadpool);
  3961. }
  3962. // TODO: handle transposed/permuted matrices
  3963. const int n = ggml_nrows(src0);
  3964. const int nc = src0->ne[0];
  3965. const int nr = src0->ne[1];
  3966. const int nz = n/nr;
  3967. GGML_ASSERT( dst->nb[0] == sizeof(float));
  3968. GGML_ASSERT(src0->nb[0] == sizeof(float));
  3969. for (int k = 0; k < nz; k++) {
  3970. for (int j = ith; j < nr; j += nth) {
  3971. for (int i = n_past; i < nc; i++) {
  3972. if (i > n_past + j) {
  3973. *(float *)((char *) dst->data + k*dst->nb[2] + j*dst->nb[1] + i*dst->nb[0]) = value;
  3974. }
  3975. }
  3976. }
  3977. }
  3978. }
  3979. void ggml_compute_forward_diag_mask_inf(
  3980. const ggml_compute_params * params,
  3981. ggml_tensor * dst) {
  3982. const ggml_tensor * src0 = dst->src[0];
  3983. switch (src0->type) {
  3984. case GGML_TYPE_F32:
  3985. {
  3986. ggml_compute_forward_diag_mask_f32(params, dst, -INFINITY);
  3987. } break;
  3988. default:
  3989. {
  3990. GGML_ABORT("fatal error");
  3991. }
  3992. }
  3993. }
  3994. void ggml_compute_forward_diag_mask_zero(
  3995. const ggml_compute_params * params,
  3996. ggml_tensor * dst) {
  3997. const ggml_tensor * src0 = dst->src[0];
  3998. switch (src0->type) {
  3999. case GGML_TYPE_F32:
  4000. {
  4001. ggml_compute_forward_diag_mask_f32(params, dst, 0);
  4002. } break;
  4003. default:
  4004. {
  4005. GGML_ABORT("fatal error");
  4006. }
  4007. }
  4008. }
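// The masking condition above is `i > n_past + j`: in row j every column strictly beyond n_past + j
// is overwritten with `value` (-INFINITY for diag_mask_inf, 0 for diag_mask_zero). Small worked
// example with n_past = 1 and 3 rows x 5 columns (x = kept, m = masked):
//   row 0:  x x m m m
//   row 1:  x x x m m
//   row 2:  x x x x m
// i.e. the usual causal attention mask, shifted by the n_past positions that are always visible.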
  4009. // ggml_compute_forward_soft_max
  4010. static void ggml_compute_forward_soft_max_f32(
  4011. const ggml_compute_params * params,
  4012. ggml_tensor * dst) {
  4013. const ggml_tensor * src0 = dst->src[0];
  4014. const ggml_tensor * src1 = dst->src[1];
  4015. const ggml_tensor * src2 = dst->src[2];
  4016. assert(ggml_is_contiguous(dst));
  4017. assert(ggml_are_same_shape(src0, dst));
  4018. float scale = 1.0f;
  4019. float max_bias = 0.0f;
  4020. memcpy(&scale, (float *) dst->op_params + 0, sizeof(float));
  4021. memcpy(&max_bias, (float *) dst->op_params + 1, sizeof(float));
  4022. const int ith = params->ith;
  4023. const int nth = params->nth;
  4024. GGML_TENSOR_UNARY_OP_LOCALS
  4025. const int64_t nb11 = src1 ? src1->nb[1] : 1;
  4026. const int64_t nb12 = src1 ? src1->nb[2] : 1;
  4027. const int64_t nb13 = src1 ? src1->nb[3] : 1;
  4028. const int64_t ne12 = src1 ? src1->ne[2] : 1;
  4029. const int64_t ne13 = src1 ? src1->ne[3] : 1;
  4030. // TODO: is this supposed to be ceil instead of floor?
  4031. // https://huggingface.co/mosaicml/mpt-7b/blob/main/attention.py#L370
  4032. const uint32_t n_head = ne02;
  4033. const uint32_t n_head_log2 = 1u << (uint32_t) floor(log2(n_head));
  4034. const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
  4035. const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
  4036. float * wp = (float *) params->wdata + (ne00 + CACHE_LINE_SIZE_F32) * ith;
  4037. const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16);
  4038. // sinks
  4039. const float * sk = src2 ? (float *)((char *) src2->data) : nullptr;
  4040. for (int64_t i03 = 0; i03 < ne03; i03++) {
  4041. for (int64_t i02 = 0; i02 < ne02; i02++) {
  4042. for (int64_t i01 = ith; i01 < ne01; i01 += nth) {
  4043. const int64_t i11 = i01;
  4044. const int64_t i12 = i02%ne12;
  4045. const int64_t i13 = i03%ne13;
  4046. // ALiBi
  4047. const uint32_t h = i02; // head
  4048. const float slope = (max_bias > 0.0f) ? h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2*(h - n_head_log2) + 1) : 1.0f;
  4049. float * sp = (float *)((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
  4050. float * dp = (float *)((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3);
  4051. // broadcast the mask across rows
  4052. ggml_fp16_t * mp_f16 = src1 ? (ggml_fp16_t *)((char *) src1->data + i11*nb11 + i12*nb12 + i13*nb13) : NULL;
  4053. float * mp_f32 = src1 ? (float *)((char *) src1->data + i11*nb11 + i12*nb12 + i13*nb13) : NULL;
  4054. ggml_vec_cpy_f32 (ne00, wp, sp);
  4055. ggml_vec_scale_f32(ne00, wp, scale);
  4056. if (mp_f32) {
  4057. if (use_f16) {
  4058. for (int i = 0; i < ne00; ++i) {
  4059. wp[i] += slope*GGML_CPU_FP16_TO_FP32(mp_f16[i]);
  4060. }
  4061. } else {
  4062. for (int i = 0; i < ne00; ++i) {
  4063. wp[i] += slope*mp_f32[i];
  4064. }
  4065. }
  4066. }
  4067. #ifndef NDEBUG
  4068. for (int i = 0; i < ne00; ++i) {
  4069. //printf("p[%d] = %f\n", i, p[i]);
  4070. assert(!isnan(wp[i]));
  4071. }
  4072. #endif
  4073. float max = -INFINITY;
  4074. ggml_vec_max_f32(ne00, &max, wp);
  4075. // if we have sinks, make a correction as if they were included in the softmax
  4076. if (sk) {
  4077. max = MAX(max, sk[i02]);
  4078. }
  4079. ggml_float sum = ggml_vec_soft_max_f32(ne00, dp, wp, max);
  4080. assert(sum > 0.0);
  4081. if (sk) {
  4082. sum += (ggml_float) expf(sk[i02] - max);
  4083. }
  4084. sum = 1.0/sum;
  4085. ggml_vec_scale_f32(ne00, dp, sum);
  4086. #ifndef NDEBUG
  4087. for (int i = 0; i < ne00; ++i) {
  4088. assert(!isnan(dp[i]));
  4089. assert(!isinf(dp[i]));
  4090. }
  4091. #endif
  4092. }
  4093. }
  4094. }
  4095. }
  4096. void ggml_compute_forward_soft_max(
  4097. const ggml_compute_params * params,
  4098. ggml_tensor * dst) {
  4099. const ggml_tensor * src0 = dst->src[0];
  4100. switch (src0->type) {
  4101. case GGML_TYPE_F32:
  4102. {
  4103. ggml_compute_forward_soft_max_f32(params, dst);
  4104. } break;
  4105. default:
  4106. {
  4107. GGML_ABORT("fatal error");
  4108. }
  4109. }
  4110. }
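// Two notes on the soft_max kernel above, as a worked sketch:
// * ALiBi slopes: with max_bias > 0, head h uses slope m0^(h+1) if h < n_head_log2 and
//   m1^(2*(h - n_head_log2) + 1) otherwise, where n_head_log2 is the largest power of two <= n_head.
//   E.g. n_head = 8, max_bias = 8.0: m0 = 2^(-8/8) = 0.5, so head 0 gets slope 0.5, head 1 gets 0.25, ...
// * Sinks (src2): the row maximum and the normalizer are corrected as if one extra logit sk[i02]
//   took part in the softmax, i.e. sum = sum_j exp(w_j - max) + exp(sk[i02] - max); no output element
//   is written for the sink itself, it only absorbs part of the probability mass.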
  4111. // ggml_compute_forward_soft_max_ext_back
  4112. static void ggml_compute_forward_soft_max_ext_back_f32(
  4113. const ggml_compute_params * params,
  4114. ggml_tensor * dst) {
  4115. const ggml_tensor * src0 = dst->src[0];
  4116. const ggml_tensor * src1 = dst->src[1];
  4117. GGML_ASSERT(ggml_is_contiguous(src0));
  4118. GGML_ASSERT(ggml_is_contiguous(src1));
  4119. GGML_ASSERT(ggml_is_contiguous(dst));
  4120. GGML_ASSERT(ggml_are_same_shape(src0, dst));
  4121. GGML_ASSERT(ggml_are_same_shape(src1, dst));
  4122. float scale = 1.0f;
  4123. float max_bias = 0.0f;
  4124. memcpy(&scale, (const float *) dst->op_params + 0, sizeof(float));
  4125. memcpy(&max_bias, (const float *) dst->op_params + 1, sizeof(float));
  4126. GGML_ASSERT(max_bias == 0.0f);
  4127. // TODO: handle transposed/permuted matrices
  4128. const int ith = params->ith;
  4129. const int nth = params->nth;
  4130. const int nc = src0->ne[0];
  4131. const int nr = ggml_nrows(src0);
  4132. // rows per thread
  4133. const int dr = (nr + nth - 1)/nth;
  4134. // row range for this thread
  4135. const int ir0 = dr*ith;
  4136. const int ir1 = MIN(ir0 + dr, nr);
  4137. for (int i1 = ir0; i1 < ir1; i1++) {
  4138. float *dy = (float *)((char *) src0->data + i1*src0->nb[1]);
  4139. float *y = (float *)((char *) src1->data + i1*src1->nb[1]);
  4140. float *dx = (float *)((char *) dst->data + i1*dst->nb[1]);
  4141. #ifndef NDEBUG
  4142. for (int i = 0; i < nc; ++i) {
  4143. //printf("p[%d] = %f\n", i, p[i]);
  4144. assert(!isnan(dy[i]));
  4145. assert(!isnan(y[i]));
  4146. }
  4147. #endif
  4148. // Jii = yi - yi*yi
  4149. // Jij = -yi*yj
  4150. // J = diag(y)-y.T*y
  4151. // dx = J * dy
  4152. // dxk = sum_i(Jki * dyi)
  4153. // dxk = sum_i(-yk*yi * dyi) - (-yk*yk)*dyk + (yk - yk*yk)*dyk
  4154. // dxk = sum_i(-yk*yi * dyi) + yk*yk*dyk + yk*dyk - yk*yk*dyk
  4155. // dxk = sum_i(-yk*yi * dyi) + yk*dyk
  4156. // dxk = -yk * sum_i(yi * dyi) + yk*dyk
  4157. // dxk = -yk * dot(y, dy) + yk*dyk
  4158. // dxk = yk * (- dot(y, dy) + dyk)
  4159. // dxk = yk * (dyk - dot(y, dy))
  4160. //
  4161. // post-order:
  4162. // dot_y_dy := dot(y, dy)
  4163. // dx := dy
  4164. // dx := dx - dot_y_dy
  4165. // dx := dx * y
  4166. // linear runtime, no additional memory
  4167. float dot_y_dy = 0;
  4168. ggml_vec_dot_f32 (nc, &dot_y_dy, 0, y, 0, dy, 0, 1);
  4169. ggml_vec_cpy_f32 (nc, dx, dy);
  4170. ggml_vec_acc1_f32 (nc, dx, -dot_y_dy);
  4171. ggml_vec_mul_f32 (nc, dx, dx, y);
  4172. ggml_vec_scale_f32(nc, dx, scale);
  4173. #ifndef NDEBUG
  4174. for (int i = 0; i < nc; ++i) {
  4175. assert(!isnan(dx[i]));
  4176. assert(!isinf(dx[i]));
  4177. }
  4178. #endif
  4179. }
  4180. }
  4181. void ggml_compute_forward_soft_max_ext_back(
  4182. const ggml_compute_params * params,
  4183. ggml_tensor * dst) {
  4184. const ggml_tensor * src0 = dst->src[0];
  4185. switch (src0->type) {
  4186. case GGML_TYPE_F32:
  4187. {
  4188. ggml_compute_forward_soft_max_ext_back_f32(params, dst);
  4189. } break;
  4190. default:
  4191. {
  4192. GGML_ABORT("fatal error");
  4193. }
  4194. }
  4195. }
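// Quick numeric check of the closed form implemented above, dx = scale * y * (dy - dot(y, dy)):
// take y = {0.5, 0.5}, dy = {1, 0}, scale = 1. Then dot(y, dy) = 0.5 and
//   dx = {0.5*(1 - 0.5), 0.5*(0 - 0.5)} = {0.25, -0.25},
// which matches J*dy with J = diag(y) - y*y^T = [[0.25, -0.25], [-0.25, 0.25]].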
  4196. // ggml_compute_forward_clamp
  4197. static void ggml_compute_forward_clamp_f32(
  4198. const ggml_compute_params * params,
  4199. ggml_tensor * dst) {
  4200. const ggml_tensor * src0 = dst->src[0];
  4201. float min;
  4202. float max;
  4203. memcpy(&min, (float *) dst->op_params + 0, sizeof(float));
  4204. memcpy(&max, (float *) dst->op_params + 1, sizeof(float));
  4205. const int ith = params->ith;
  4206. const int nth = params->nth;
  4207. const int n = ggml_nrows(src0);
  4208. const int nc = src0->ne[0];
  4209. const size_t nb00 = src0->nb[0];
  4210. const size_t nb01 = src0->nb[1];
  4211. const size_t nb0 = dst->nb[0];
  4212. const size_t nb1 = dst->nb[1];
  4213. GGML_ASSERT( nb0 == sizeof(float));
  4214. GGML_ASSERT(nb00 == sizeof(float));
  4215. for (int j = ith; j < n; j += nth) {
  4216. float * dst_ptr = (float *) ((char *) dst->data + j*nb1);
  4217. float * src0_ptr = (float *) ((char *) src0->data + j*nb01);
  4218. for (int i = 0; i < nc; i++) {
  4219. dst_ptr[i] = MAX(MIN(src0_ptr[i], max), min);
  4220. }
  4221. }
  4222. }
  4223. static void ggml_compute_forward_clamp_f16(
  4224. const ggml_compute_params * params,
  4225. ggml_tensor * dst) {
  4226. const ggml_tensor * src0 = dst->src[0];
  4227. float min;
  4228. float max;
  4229. memcpy(&min, (float *) dst->op_params + 0, sizeof(float));
  4230. memcpy(&max, (float *) dst->op_params + 1, sizeof(float));
  4231. const int ith = params->ith;
  4232. const int nth = params->nth;
  4233. const int n = ggml_nrows(src0);
  4234. const int nc = src0->ne[0];
  4235. const size_t nb00 = src0->nb[0];
  4236. const size_t nb01 = src0->nb[1];
  4237. const size_t nb0 = dst->nb[0];
  4238. const size_t nb1 = dst->nb[1];
  4239. GGML_ASSERT( nb0 == sizeof(ggml_fp16_t));
  4240. GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
  4241. for (int j = ith; j < n; j += nth) {
  4242. ggml_fp16_t * dst_ptr = (ggml_fp16_t *) ((char *) dst->data + j*nb1);
  4243. ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + j*nb01);
  4244. for (int i = 0; i < nc; i++) {
  4245. float v = GGML_CPU_FP16_TO_FP32(src0_ptr[i]);
  4246. dst_ptr[i] = GGML_CPU_FP32_TO_FP16(MAX(MIN(v, max), min));
  4247. }
  4248. }
  4249. }
  4250. void ggml_compute_forward_clamp(
  4251. const ggml_compute_params * params,
  4252. ggml_tensor * dst) {
  4253. const ggml_tensor * src0 = dst->src[0];
  4254. switch (src0->type) {
  4255. case GGML_TYPE_F32:
  4256. {
  4257. ggml_compute_forward_clamp_f32(params, dst);
  4258. } break;
  4259. case GGML_TYPE_F16:
  4260. {
  4261. ggml_compute_forward_clamp_f16(params, dst);
  4262. } break;
  4263. case GGML_TYPE_BF16:
  4264. case GGML_TYPE_Q4_0:
  4265. case GGML_TYPE_Q4_1:
  4266. case GGML_TYPE_Q5_0:
  4267. case GGML_TYPE_Q5_1:
  4268. case GGML_TYPE_Q8_0:
  4269. case GGML_TYPE_Q8_1:
  4270. case GGML_TYPE_MXFP4:
  4271. case GGML_TYPE_Q2_K:
  4272. case GGML_TYPE_Q3_K:
  4273. case GGML_TYPE_Q4_K:
  4274. case GGML_TYPE_Q5_K:
  4275. case GGML_TYPE_Q6_K:
  4276. case GGML_TYPE_TQ1_0:
  4277. case GGML_TYPE_TQ2_0:
  4278. case GGML_TYPE_IQ2_XXS:
  4279. case GGML_TYPE_IQ2_XS:
  4280. case GGML_TYPE_IQ3_XXS:
  4281. case GGML_TYPE_IQ1_S:
  4282. case GGML_TYPE_IQ1_M:
  4283. case GGML_TYPE_IQ4_NL:
  4284. case GGML_TYPE_IQ4_XS:
  4285. case GGML_TYPE_IQ3_S:
  4286. case GGML_TYPE_IQ2_S:
  4287. case GGML_TYPE_Q8_K:
  4288. case GGML_TYPE_I8:
  4289. case GGML_TYPE_I16:
  4290. case GGML_TYPE_I32:
  4291. case GGML_TYPE_I64:
  4292. case GGML_TYPE_F64:
  4293. case GGML_TYPE_COUNT:
  4294. {
  4295. GGML_ABORT("fatal error");
  4296. }
  4297. }
  4298. }
  4299. // ggml_compute_forward_rope
  4300. static float rope_yarn_ramp(const float low, const float high, const int i0) {
  4301. const float y = (i0 / 2 - low) / MAX(0.001f, high - low);
  4302. return 1 - MIN(1, MAX(0, y));
  4303. }
  4304. // YaRN algorithm based on LlamaYaRNScaledRotaryEmbedding.py from https://github.com/jquesnelle/yarn
  4305. // MIT licensed. Copyright (c) 2023 Jeffrey Quesnelle and Bowen Peng.
  4306. static void rope_yarn(
  4307. float theta_extrap, float freq_scale, float corr_dims[2], int64_t i0, float ext_factor, float mscale,
  4308. float * cos_theta, float * sin_theta) {
  4309. // Get n-d rotational scaling corrected for extrapolation
  4310. float theta_interp = freq_scale * theta_extrap;
  4311. float theta = theta_interp;
  4312. if (ext_factor != 0.0f) {
  4313. float ramp_mix = rope_yarn_ramp(corr_dims[0], corr_dims[1], i0) * ext_factor;
  4314. theta = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix;
  4315. // Get n-d magnitude scaling corrected for interpolation
  4316. mscale *= 1.0f + 0.1f * logf(1.0f / freq_scale);
  4317. }
  4318. *cos_theta = cosf(theta) * mscale;
  4319. *sin_theta = sinf(theta) * mscale;
  4320. }
  4321. static void ggml_rope_cache_init(
  4322. float theta_base, float freq_scale, const float * freq_factors, float corr_dims[2], int64_t ne0, float ext_factor, float mscale,
  4323. float * cache, float sin_sign, float theta_scale) {
  4324. // ref: https://github.com/jquesnelle/yarn/blob/master/scaled_rope/LlamaYaRNScaledRotaryEmbedding.py
  4325. float theta = theta_base;
  4326. for (int64_t i0 = 0; i0 < ne0; i0 += 2) {
  4327. const float ff = freq_factors ? freq_factors[i0/2] : 1.0f;
  4328. rope_yarn(
  4329. theta/ff, freq_scale, corr_dims, i0, ext_factor, mscale, &cache[i0 + 0], &cache[i0 + 1]
  4330. );
  4331. cache[i0 + 1] *= sin_sign;
  4332. theta *= theta_scale;
  4333. }
  4334. }
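// Worked sketch of the cache filled above for plain RoPE (ext_factor == 0, freq_factors == NULL):
// theta starts at theta_base (the token position) and is multiplied by the caller-supplied
// theta_scale = freq_base^(-2/n_dims) once per pair, so cache slots (2k, 2k+1) hold
//   cos(freq_scale * pos * freq_base^(-2k/n_dims)) * mscale   and the matching sin (times sin_sign).
// E.g. freq_base = 10000, n_dims = 128: pair 0 rotates by pos, pair 1 by pos * 10000^(-1/64), etc.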
  4335. static void ggml_mrope_cache_init(
  4336. float theta_base_t, float theta_base_h, float theta_base_w, float theta_base_e, int sections[4], bool is_imrope, bool indep_sects,
  4337. float freq_scale, const float * freq_factors, float corr_dims[2], int64_t ne0, float ext_factor, float mscale,
  4338. float * cache, float sin_sign, float theta_scale) {
  4339. // ref: https://github.com/jquesnelle/yarn/blob/master/scaled_rope/LlamaYaRNScaledRotaryEmbedding.py
  4340. float theta_t = theta_base_t;
  4341. float theta_h = theta_base_h;
  4342. float theta_w = theta_base_w;
  4343. float theta_e = theta_base_e; // extra position id for vision encoder
  4344. int sect_dims = sections[0] + sections[1] + sections[2] + sections[3];
  4345. int sec_w = sections[1] + sections[0];
  4346. int sec_e = sections[2] + sec_w;
  4347. GGML_ASSERT(sect_dims <= ne0);
  4348. for (int64_t i0 = 0; i0 < ne0; i0 += 2) {
  4349. const float ff = freq_factors ? freq_factors[i0/2] : 1.0f;
  4350. int sector = (i0 / 2) % sect_dims;
  4351. if (indep_sects) {
4352. // compute theta independently for each dim section
4353. // (i.e. reset the corresponding theta when `i0` goes from one section to another)
  4354. if (sector == 0) {
  4355. theta_t = theta_base_t;
  4356. }
  4357. else if (sector == sections[0]) {
4358. theta_h = theta_base_h;
  4359. }
  4360. else if (sector == sec_w) {
  4361. theta_w = theta_base_w;
  4362. }
  4363. else if (sector == sec_e) {
  4364. theta_e = theta_base_e;
  4365. }
  4366. }
  4367. float theta = theta_t;
4368. if (is_imrope) { // qwen3vl applies interleaved mrope
  4369. if (sector % 3 == 1 && sector < 3 * sections[1]) {
  4370. theta = theta_h;
  4371. } else if (sector % 3 == 2 && sector < 3 * sections[2]) {
  4372. theta = theta_w;
  4373. } else if (sector % 3 == 0 && sector < 3 * sections[0]) {
  4374. theta = theta_t;
  4375. } else {
  4376. theta = theta_e;
  4377. }
  4378. } else {
  4379. if (sector >= sections[0] && sector < sec_w) {
  4380. theta = theta_h;
  4381. }
  4382. else if (sector >= sec_w && sector < sec_w + sections[2]) {
  4383. theta = theta_w;
  4384. }
  4385. else if (sector >= sec_w + sections[2]) {
  4386. theta = theta_e;
  4387. }
  4388. }
  4389. rope_yarn(
  4390. theta/ff, freq_scale, corr_dims, i0, ext_factor, mscale, &cache[i0 + 0], &cache[i0 + 1]
  4391. );
  4392. cache[i0 + 1] *= sin_sign;
  4393. theta_t *= theta_scale;
  4394. theta_w *= theta_scale;
  4395. theta_h *= theta_scale;
  4396. theta_e *= theta_scale;
  4397. }
  4398. }
  4399. template<typename T>
  4400. static void rotate_pairs(const int64_t n, const int64_t n_offset, const float * cache, const T * src_data, T * dst_data, const int scale = 2) {
  4401. for (int64_t i0 = 0; i0 < n; i0 += 2) {
  4402. const int64_t ic = i0/scale; // hack for GGML_ROPE_TYPE_NORMAL, where we need ic = i0; for all other cases, ic = i0/2
  4403. const float cos_theta = cache[i0 + 0];
  4404. const float sin_theta = cache[i0 + 1];
  4405. const T * const src = src_data + ic;
  4406. T * dst = dst_data + ic;
  4407. const float x0 = type_conversion_table<T>::to_f32(src[0]);
  4408. const float x1 = type_conversion_table<T>::to_f32(src[n_offset]);
  4409. dst[0] = type_conversion_table<T>::from_f32(x0*cos_theta - x1*sin_theta);
  4410. dst[n_offset] = type_conversion_table<T>::from_f32(x0*sin_theta + x1*cos_theta);
  4411. }
  4412. }
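// rotate_pairs applies the 2-D rotation (x0, x1) -> (x0*cos - x1*sin, x0*sin + x1*cos) to channel
// pairs; scale and n_offset only decide which two channels form a pair:
//   GGML_ROPE_TYPE_NORMAL:   (0,1), (2,3), ...                       (scale = 1, n_offset = 1)
//   NEOX / MROPE / IMROPE:   (0, n_dims/2), (1, n_dims/2 + 1), ...   (scale = 2, n_offset = n_dims/2)
//   GGML_ROPE_TYPE_VISION:   (0, n_dims), (1, n_dims + 1), ...       with n_dims == ne0/2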
  4413. template<typename T> //float or ggml_fp16_t
  4414. static void ggml_compute_forward_rope_flt(
  4415. const ggml_compute_params * params,
  4416. ggml_tensor * dst,
  4417. const bool forward) {
  4418. const ggml_tensor * src0 = dst->src[0];
  4419. const ggml_tensor * src1 = dst->src[1];
  4420. const ggml_tensor * src2 = dst->src[2];
  4421. GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
  4422. GGML_ASSERT(src1->type == GGML_TYPE_I32);
  4423. float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
  4424. int sections[4];
  4425. //const int n_past = ((int32_t *) dst->op_params)[0];
  4426. const int n_dims = ((int32_t *) dst->op_params)[1];
  4427. const int mode = ((int32_t *) dst->op_params)[2];
  4428. //const int n_ctx = ((int32_t *) dst->op_params)[3];
  4429. const int n_ctx_orig = ((int32_t *) dst->op_params)[4];
  4430. memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float));
  4431. memcpy(&freq_scale, (int32_t *) dst->op_params + 6, sizeof(float));
  4432. memcpy(&ext_factor, (int32_t *) dst->op_params + 7, sizeof(float));
  4433. memcpy(&attn_factor, (int32_t *) dst->op_params + 8, sizeof(float));
  4434. memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float));
  4435. memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float));
  4436. memcpy(&sections, (int32_t *) dst->op_params + 11, sizeof(int)*4);
  4437. GGML_TENSOR_UNARY_OP_LOCALS
  4438. //printf("ne0: %d, ne1: %d, ne2: %d, ne3: %d\n", ne0, ne1, ne2, ne3);
  4439. //printf("n_past = %d, ne2 = %d\n", n_past, ne2);
  4440. GGML_ASSERT(nb0 == nb00);
  4441. GGML_ASSERT(nb0 == sizeof(T));
  4442. const int ith = params->ith;
  4443. const int nth = params->nth;
  4444. const int nr = ggml_nrows(dst);
  4445. GGML_ASSERT(n_dims <= ne0);
  4446. GGML_ASSERT(n_dims % 2 == 0);
  4447. // rows per thread
  4448. const int dr = (nr + nth - 1)/nth;
  4449. // row range for this thread
  4450. const int ir0 = dr*ith;
  4451. const int ir1 = MIN(ir0 + dr, nr);
  4452. // row index used to determine which thread to use
  4453. int ir = 0;
  4454. const float theta_scale = powf(freq_base, -2.0f/n_dims);
  4455. float corr_dims[2];
  4456. ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims);
4457. const bool is_imrope = mode == GGML_ROPE_TYPE_IMROPE; // qwen3vl applies interleaved mrope
4458. const bool mrope_used = mode & GGML_ROPE_TYPE_MROPE; // ggml_rope_multi; also true for vision (24 & 8 != 0) and for imrope, since those modes include the MROPE bit
  4459. const bool is_vision = mode == GGML_ROPE_TYPE_VISION;
  4460. if (mrope_used) {
  4461. GGML_ASSERT(sections[0] > 0 || sections[1] > 0 || sections[2] > 0);
  4462. }
  4463. if (is_vision) {
  4464. GGML_ASSERT(n_dims == ne0/2);
  4465. }
  4466. const float * freq_factors = NULL;
  4467. if (src2 != NULL) {
  4468. GGML_ASSERT(src2->type == GGML_TYPE_F32);
  4469. GGML_ASSERT(src2->ne[0] >= n_dims / 2);
  4470. freq_factors = (const float *) src2->data;
  4471. }
  4472. // backward process uses inverse rotation by cos and sin.
  4473. // cos and sin build a rotation matrix, where the inverse is the transpose.
  4474. // this essentially just switches the sign of sin.
  4475. const float sin_sign = forward ? 1.0f : -1.0f;
  4476. const int32_t * pos = (const int32_t *) src1->data;
  4477. for (int64_t i3 = 0; i3 < ne3; i3++) { // batch
  4478. for (int64_t i2 = 0; i2 < ne2; i2++) { // seq-len
  4479. float * cache = (float *) params->wdata + (ne0 + CACHE_LINE_SIZE_F32)*ith;
  4480. if (!mrope_used) {
  4481. const int64_t p = pos[i2];
  4482. ggml_rope_cache_init(p, freq_scale, freq_factors, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale);
  4483. }
  4484. else {
  4485. const int64_t p_t = pos[i2];
  4486. const int64_t p_h = pos[i2 + ne2];
  4487. const int64_t p_w = pos[i2 + ne2 * 2];
  4488. const int64_t p_e = pos[i2 + ne2 * 3];
  4489. ggml_mrope_cache_init(
  4490. p_t, p_h, p_w, p_e, sections, is_imrope, is_vision,
  4491. freq_scale, freq_factors, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale);
  4492. }
  4493. for (int64_t i1 = 0; i1 < ne1; i1++) { // attn-heads
  4494. if (ir++ < ir0) continue;
  4495. if (ir > ir1) break;
  4496. T * src = (T *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01);
  4497. T * dst_data = (T *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1);
  4498. switch (mode) {
  4499. case GGML_ROPE_TYPE_NORMAL:
  4500. rotate_pairs<T>(n_dims, 1, cache, src, dst_data, 1);
  4501. break;
  4502. case GGML_ROPE_TYPE_NEOX:
  4503. case GGML_ROPE_TYPE_MROPE:
  4504. case GGML_ROPE_TYPE_IMROPE:
  4505. rotate_pairs<T>(n_dims, n_dims/2, cache, src, dst_data);
  4506. break;
  4507. case GGML_ROPE_TYPE_VISION:
  4508. rotate_pairs<T>(ne0, n_dims, cache, src, dst_data);
  4509. break;
  4510. default:
  4511. GGML_ABORT("rope type not supported");
  4512. }
  4513. if (!is_vision) {
4514. // fill the remaining channels with data from the src tensor
  4515. for (int64_t i0 = n_dims; i0 < ne0; i0 += 2) {
  4516. const T * const src = (T *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
  4517. T * dst_data = (T *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
  4518. dst_data[0] = src[0];
  4519. dst_data[1] = src[1];
  4520. }
  4521. }
  4522. } //attn-heads
  4523. }
  4524. }
  4525. }
  4526. void ggml_compute_forward_rope(
  4527. const ggml_compute_params * params,
  4528. ggml_tensor * dst) {
  4529. const ggml_tensor * src0 = dst->src[0];
  4530. switch (src0->type) {
  4531. case GGML_TYPE_F16:
  4532. {
  4533. ggml_compute_forward_rope_flt<ggml_fp16_t>(params, dst, true);
  4534. } break;
  4535. case GGML_TYPE_F32:
  4536. {
  4537. ggml_compute_forward_rope_flt<float>(params, dst, true);
  4538. } break;
  4539. default:
  4540. {
  4541. GGML_ABORT("fatal error");
  4542. }
  4543. }
  4544. }
  4545. // ggml_compute_forward_rope_back
  4546. void ggml_compute_forward_rope_back(
  4547. const ggml_compute_params * params,
  4548. ggml_tensor * dst) {
  4549. const ggml_tensor * src0 = dst->src[0];
  4550. switch (src0->type) {
  4551. case GGML_TYPE_F16:
  4552. {
  4553. ggml_compute_forward_rope_flt<ggml_fp16_t>(params, dst, false);
  4554. } break;
  4555. case GGML_TYPE_F32:
  4556. {
  4557. ggml_compute_forward_rope_flt<float>(params, dst, false);
  4558. } break;
  4559. default:
  4560. {
  4561. GGML_ABORT("fatal error");
  4562. }
  4563. }
  4564. }
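// The backward pass reuses the same kernel with forward == false, which flips sin_sign: the 2-D
// rotation matrix R(theta) is orthogonal, so its inverse is its transpose R(-theta), and rotating
// the incoming gradient by -theta exactly undoes the forward rotation.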
  4565. // ggml_compute_forward_conv_transpose_1d
  4566. static void ggml_compute_forward_conv_transpose_1d_f16_f32(
  4567. const ggml_compute_params * params,
  4568. ggml_tensor * dst) {
  4569. const ggml_tensor * src0 = dst->src[0];
  4570. const ggml_tensor * src1 = dst->src[1];
  4571. GGML_ASSERT(src0->type == GGML_TYPE_F16);
  4572. GGML_ASSERT(src1->type == GGML_TYPE_F32);
  4573. GGML_ASSERT( dst->type == GGML_TYPE_F32);
  4574. GGML_TENSOR_BINARY_OP_LOCALS
  4575. const int ith = params->ith;
  4576. const int nth = params->nth;
  4577. const int nk = ne00*ne01*ne02;
  4578. GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
  4579. GGML_ASSERT(nb10 == sizeof(float));
  4580. if (ith == 0) {
  4581. memset(params->wdata, 0, params->wsize);
  4582. // permute kernel data (src0) from (K x Cout x Cin) to (Cin x K x Cout)
  4583. {
  4584. ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + 0;
  4585. for (int64_t i02 = 0; i02 < ne02; i02++) {
  4586. for (int64_t i01 = 0; i01 < ne01; i01++) {
  4587. const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i02*nb02 + i01*nb01);
  4588. ggml_fp16_t * dst_data = wdata + i01*ne00*ne02;
  4589. for (int64_t i00 = 0; i00 < ne00; i00++) {
  4590. dst_data[i00*ne02 + i02] = src[i00];
  4591. }
  4592. }
  4593. }
  4594. }
  4595. // permute source data (src1) from (L x Cin) to (Cin x L)
  4596. {
  4597. ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + nk;
  4598. ggml_fp16_t * dst_data = wdata;
  4599. for (int64_t i11 = 0; i11 < ne11; i11++) {
  4600. const float * const src = (float *)((char *) src1->data + i11*nb11);
  4601. for (int64_t i10 = 0; i10 < ne10; i10++) {
  4602. dst_data[i10*ne11 + i11] = GGML_CPU_FP32_TO_FP16(src[i10]);
  4603. }
  4604. }
  4605. }
  4606. // need to zero dst since we are accumulating into it
  4607. memset(dst->data, 0, ggml_nbytes(dst));
  4608. }
  4609. ggml_barrier(params->threadpool);
  4610. const int32_t s0 = ((const int32_t*)(dst->op_params))[0];
  4611. // total rows in dst
  4612. const int nr = ne1;
  4613. // rows per thread
  4614. const int dr = (nr + nth - 1)/nth;
  4615. // row range for this thread
  4616. const int ir0 = dr*ith;
  4617. const int ir1 = MIN(ir0 + dr, nr);
  4618. ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + 0;
  4619. ggml_fp16_t * const wdata_src = wdata + nk;
  4620. for (int i1 = ir0; i1 < ir1; i1++) {
  4621. float * dst_data = (float *)((char *) dst->data + i1*nb1);
  4622. ggml_fp16_t * wdata_kernel = wdata + i1*ne02*ne00;
  4623. for (int i10 = 0; i10 < ne10; i10++) {
  4624. const int i1n = i10*ne11;
  4625. for (int i00 = 0; i00 < ne00; i00++) {
  4626. float v = 0;
  4627. ggml_vec_dot_f16(ne02, &v, 0,
  4628. (ggml_fp16_t *) wdata_src + i1n, 0,
  4629. (ggml_fp16_t *) wdata_kernel + i00*ne02, 0, 1);
  4630. dst_data[i10*s0 + i00] += v;
  4631. }
  4632. }
  4633. }
  4634. }
  4635. static void ggml_compute_forward_conv_transpose_1d_f32(
  4636. const ggml_compute_params * params,
  4637. ggml_tensor * dst) {
  4638. const ggml_tensor * src0 = dst->src[0];
  4639. const ggml_tensor * src1 = dst->src[1];
  4640. GGML_ASSERT(src0->type == GGML_TYPE_F32);
  4641. GGML_ASSERT(src1->type == GGML_TYPE_F32);
  4642. GGML_ASSERT( dst->type == GGML_TYPE_F32);
  4643. GGML_TENSOR_BINARY_OP_LOCALS
  4644. const int ith = params->ith;
  4645. const int nth = params->nth;
  4646. const int nk = ne00*ne01*ne02;
  4647. GGML_ASSERT(nb00 == sizeof(float));
  4648. GGML_ASSERT(nb10 == sizeof(float));
  4649. if (ith == 0) {
  4650. memset(params->wdata, 0, params->wsize);
  4651. // prepare kernel data (src0) from (K x Cout x Cin) to (Cin x K x Cout)
  4652. {
  4653. float * const wdata = (float *) params->wdata + 0;
  4654. for (int64_t i02 = 0; i02 < ne02; i02++) {
  4655. for (int64_t i01 = 0; i01 < ne01; i01++) {
  4656. const float * const src = (float *)((char *) src0->data + i02*nb02 + i01*nb01);
  4657. float * dst_data = wdata + i01*ne00*ne02;
  4658. for (int64_t i00 = 0; i00 < ne00; i00++) {
  4659. dst_data[i00*ne02 + i02] = src[i00];
  4660. }
  4661. }
  4662. }
  4663. }
4664. // permute source data (src1) from (L x Cin) to (Cin x L)
  4665. {
  4666. float * const wdata = (float *) params->wdata + nk;
  4667. float * dst_data = wdata;
  4668. for (int64_t i11 = 0; i11 < ne11; i11++) {
  4669. const float * const src = (float *)((char *) src1->data + i11*nb11);
  4670. for (int64_t i10 = 0; i10 < ne10; i10++) {
  4671. dst_data[i10*ne11 + i11] = src[i10];
  4672. }
  4673. }
  4674. }
  4675. // need to zero dst since we are accumulating into it
  4676. memset(dst->data, 0, ggml_nbytes(dst));
  4677. }
  4678. ggml_barrier(params->threadpool);
  4679. const int32_t s0 = ((const int32_t*)(dst->op_params))[0];
  4680. // total rows in dst
  4681. const int nr = ne1;
  4682. // rows per thread
  4683. const int dr = (nr + nth - 1)/nth;
  4684. // row range for this thread
  4685. const int ir0 = dr*ith;
  4686. const int ir1 = MIN(ir0 + dr, nr);
  4687. float * const wdata = (float *) params->wdata + 0;
  4688. float * const wdata_src = wdata + nk;
  4689. for (int i1 = ir0; i1 < ir1; i1++) {
  4690. float * dst_data = (float *)((char *) dst->data + i1*nb1);
  4691. float * wdata_kernel = wdata + i1*ne02*ne00;
  4692. for (int i10 = 0; i10 < ne10; i10++) {
  4693. const int i1n = i10*ne11;
  4694. for (int i00 = 0; i00 < ne00; i00++) {
  4695. float v = 0;
  4696. ggml_vec_dot_f32(ne02, &v, 0,
  4697. wdata_src + i1n, 0,
  4698. wdata_kernel + i00*ne02, 0, 1);
  4699. dst_data[i10*s0 + i00] += v;
  4700. }
  4701. }
  4702. }
  4703. }
  4704. void ggml_compute_forward_conv_transpose_1d(
  4705. const ggml_compute_params * params,
  4706. ggml_tensor * dst) {
  4707. const ggml_tensor * src0 = dst->src[0];
  4708. switch (src0->type) {
  4709. case GGML_TYPE_F16:
  4710. {
  4711. ggml_compute_forward_conv_transpose_1d_f16_f32(params, dst);
  4712. } break;
  4713. case GGML_TYPE_F32:
  4714. {
  4715. ggml_compute_forward_conv_transpose_1d_f32(params, dst);
  4716. } break;
  4717. default:
  4718. {
  4719. GGML_ABORT("fatal error");
  4720. }
  4721. }
  4722. }
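// Shape sketch for conv_transpose_1d, following the accumulation `dst_data[i10*s0 + i00] += v`
// above: input position i10 scatters a length-K kernel window starting at output index i10*s0, so
// with input length L, kernel size K and stride s0 the written output length is (L - 1)*s0 + K.
// E.g. L = 4, K = 3, s0 = 2  =>  (4 - 1)*2 + 3 = 9 output positions.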
  4723. // ggml_compute_forward_im2col_f32
  4724. // src0: kernel [OC, IC, KH, KW]
  4725. // src1: image [N, IC, IH, IW]
  4726. // dst: result [N, OH, OW, IC*KH*KW]
  4727. static void ggml_compute_forward_im2col_f32(
  4728. const ggml_compute_params * params,
  4729. ggml_tensor * dst) {
  4730. const ggml_tensor * src0 = dst->src[0];
  4731. const ggml_tensor * src1 = dst->src[1];
  4732. GGML_ASSERT(src1->type == GGML_TYPE_F32);
  4733. GGML_ASSERT( dst->type == GGML_TYPE_F32);
  4734. GGML_TENSOR_BINARY_OP_LOCALS;
  4735. const int32_t s0 = ((const int32_t *)(dst->op_params))[0];
  4736. const int32_t s1 = ((const int32_t *)(dst->op_params))[1];
  4737. const int32_t p0 = ((const int32_t *)(dst->op_params))[2];
  4738. const int32_t p1 = ((const int32_t *)(dst->op_params))[3];
  4739. const int32_t d0 = ((const int32_t *)(dst->op_params))[4];
  4740. const int32_t d1 = ((const int32_t *)(dst->op_params))[5];
  4741. const bool is_2D = ((const int32_t *)(dst->op_params))[6] == 1;
  4742. const int ith = params->ith;
  4743. const int nth = params->nth;
  4744. const int64_t N = is_2D ? ne13 : ne12;
  4745. const int64_t IC = is_2D ? ne12 : ne11;
  4746. const int64_t IH = is_2D ? ne11 : 1;
  4747. const int64_t IW = ne10;
  4748. const int64_t KH = is_2D ? ne01 : 1;
  4749. const int64_t KW = ne00;
  4750. const int64_t OH = is_2D ? ne2 : 1;
  4751. const int64_t OW = ne1;
  4752. int ofs0 = is_2D ? nb13 : nb12;
  4753. int ofs1 = is_2D ? nb12 : nb11;
  4754. GGML_ASSERT(nb10 == sizeof(float));
  4755. // im2col: [N, IC, IH, IW] => [N, OH, OW, IC*KH*KW]
  4756. {
  4757. float * const wdata = (float *) dst->data;
  4758. for (int64_t in = 0; in < N; in++) {
  4759. for (int64_t ioh = 0; ioh < OH; ioh++) { // 1
  4760. for (int64_t iow = 0; iow < OW; iow++) {
  4761. for (int64_t iic = ith; iic < IC; iic += nth) {
  4762. // micro kernel
  4763. float * dst_data = wdata + (in*OH*OW + ioh*OW + iow)*(IC*KH*KW); // [IC, KH, KW]
  4764. const float * const src_data = (float *)((char *) src1->data + in*ofs0 + iic*ofs1); // [IH, IW]
  4765. for (int64_t ikh = 0; ikh < KH; ikh++) { // 1
  4766. for (int64_t ikw = 0; ikw < KW; ikw++) {
  4767. const int64_t iiw = iow*s0 + ikw*d0 - p0;
  4768. const int64_t iih = ioh*s1 + ikh*d1 - p1;
  4769. if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) {
  4770. dst_data[iic*(KH*KW) + ikh*KW + ikw] = 0;
  4771. } else {
  4772. dst_data[iic*(KH*KW) + ikh*KW + ikw] = (src_data[iih*IW + iiw]);
  4773. }
  4774. }
  4775. }
  4776. }
  4777. }
  4778. }
  4779. }
  4780. }
  4781. }
  4782. // ggml_compute_forward_im2col_f16
  4783. // src0: kernel [OC, IC, KH, KW]
  4784. // src1: image [N, IC, IH, IW]
  4785. // dst: result [N, OH, OW, IC*KH*KW]
  4786. static void ggml_compute_forward_im2col_f16(
  4787. const ggml_compute_params * params,
  4788. ggml_tensor * dst) {
  4789. const ggml_tensor * src0 = dst->src[0];
  4790. const ggml_tensor * src1 = dst->src[1];
  4791. GGML_ASSERT(src0->type == GGML_TYPE_F16);
  4792. GGML_ASSERT(src1->type == GGML_TYPE_F32);
  4793. GGML_ASSERT( dst->type == GGML_TYPE_F16);
  4794. GGML_TENSOR_BINARY_OP_LOCALS;
  4795. const int32_t s0 = ((const int32_t *)(dst->op_params))[0];
  4796. const int32_t s1 = ((const int32_t *)(dst->op_params))[1];
  4797. const int32_t p0 = ((const int32_t *)(dst->op_params))[2];
  4798. const int32_t p1 = ((const int32_t *)(dst->op_params))[3];
  4799. const int32_t d0 = ((const int32_t *)(dst->op_params))[4];
  4800. const int32_t d1 = ((const int32_t *)(dst->op_params))[5];
  4801. const bool is_2D = ((const int32_t *)(dst->op_params))[6] == 1;
  4802. const int ith = params->ith;
  4803. const int nth = params->nth;
  4804. const int64_t N = is_2D ? ne13 : ne12;
  4805. const int64_t IC = is_2D ? ne12 : ne11;
  4806. const int64_t IH = is_2D ? ne11 : 1;
  4807. const int64_t IW = ne10;
  4808. const int64_t KH = is_2D ? ne01 : 1;
  4809. const int64_t KW = ne00;
  4810. const int64_t OH = is_2D ? ne2 : 1;
  4811. const int64_t OW = ne1;
  4812. int ofs0 = is_2D ? nb13 : nb12;
  4813. int ofs1 = is_2D ? nb12 : nb11;
  4814. GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
  4815. GGML_ASSERT(nb10 == sizeof(float));
  4816. // im2col: [N, IC, IH, IW] => [N, OH, OW, IC*KH*KW]
  4817. {
  4818. ggml_fp16_t * const wdata = (ggml_fp16_t *) dst->data;
  4819. for (int64_t in = 0; in < N; in++) {
  4820. for (int64_t ioh = 0; ioh < OH; ioh++) { // 1
  4821. for (int64_t iow = 0; iow < OW; iow++) {
  4822. for (int64_t iic = ith; iic < IC; iic += nth) {
  4823. // micro kernel
  4824. ggml_fp16_t * dst_data = wdata + (in*OH*OW + ioh*OW + iow)*(IC*KH*KW); // [IC, KH, KW]
  4825. const float * const src_data = (float *)((char *) src1->data + in*ofs0 + iic*ofs1); // [IH, IW]
  4826. for (int64_t ikh = 0; ikh < KH; ikh++) { // 1
  4827. for (int64_t ikw = 0; ikw < KW; ikw++) {
  4828. const int64_t iiw = iow*s0 + ikw*d0 - p0;
  4829. const int64_t iih = ioh*s1 + ikh*d1 - p1;
  4830. if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) {
  4831. dst_data[iic*(KH*KW) + ikh*KW + ikw] = 0;
  4832. } else {
  4833. dst_data[iic*(KH*KW) + ikh*KW + ikw] = GGML_CPU_FP32_TO_FP16(src_data[iih*IW + iiw]);
  4834. }
  4835. }
  4836. }
  4837. }
  4838. }
  4839. }
  4840. }
  4841. }
  4842. }
  4843. void ggml_compute_forward_im2col(
  4844. const ggml_compute_params * params,
  4845. ggml_tensor * dst) {
  4846. switch (dst->type) {
  4847. case GGML_TYPE_F16:
  4848. {
  4849. ggml_compute_forward_im2col_f16(params, dst);
  4850. } break;
  4851. case GGML_TYPE_F32:
  4852. {
  4853. ggml_compute_forward_im2col_f32(params, dst);
  4854. } break;
  4855. default:
  4856. {
  4857. GGML_ABORT("fatal error");
  4858. }
  4859. }
  4860. }
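// Shape sketch for im2col (2-D case): the index math above is iiw = iow*s0 + ikw*d0 - p0 (and the
// same for rows), which corresponds to the usual convolution output size prepared at graph-build
// time, OW = (IW + 2*p0 - d0*(KW - 1) - 1)/s0 + 1 (likewise OH). Each output position stores an
// IC*KH*KW patch, so a single matrix multiply of dst against the flattened kernel yields the
// convolution. E.g. IW = 8, KW = 3, s0 = 1, p0 = 1, d0 = 1  =>  OW = (8 + 2 - 2 - 1)/1 + 1 = 8.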
  4861. // ggml_compute_forward_im2col_back_f32
  4862. void ggml_compute_forward_im2col_back_f32(
  4863. const ggml_compute_params * params,
  4864. ggml_tensor * dst) {
  4865. const ggml_tensor * src0 = dst->src[0]; // gradients of forward pass output
  4866. const ggml_tensor * src1 = dst->src[1]; // convolution kernel
  4867. GGML_ASSERT(src0->type == GGML_TYPE_F32);
  4868. GGML_ASSERT(src1->type == GGML_TYPE_F32);
  4869. GGML_ASSERT( dst->type == GGML_TYPE_F32);
  4870. GGML_TENSOR_BINARY_OP_LOCALS;
  4871. const int32_t s0 = ((const int32_t *)(dst->op_params))[0];
  4872. const int32_t s1 = ((const int32_t *)(dst->op_params))[1];
  4873. const int32_t p0 = ((const int32_t *)(dst->op_params))[2];
  4874. const int32_t p1 = ((const int32_t *)(dst->op_params))[3];
  4875. const int32_t d0 = ((const int32_t *)(dst->op_params))[4];
  4876. const int32_t d1 = ((const int32_t *)(dst->op_params))[5];
  4877. const bool is_2D = ((const int32_t *)(dst->op_params))[6] == 1;
  4878. const int ith = params->ith;
  4879. const int nth = params->nth;
  4880. const int64_t N = is_2D ? ne3 : ne2;
  4881. const int64_t IC = is_2D ? ne2 : ne1;
  4882. const int64_t IH = is_2D ? ne1 : 1;
  4883. const int64_t IW = ne0;
  4884. const int64_t KH = is_2D ? ne11 : 1;
  4885. const int64_t KW = ne10;
  4886. const int64_t OH = is_2D ? ne02 : 1;
  4887. const int64_t OW = ne01;
  4888. int ofs0 = is_2D ? nb3 : nb2;
  4889. int ofs1 = is_2D ? nb2 : nb1;
  4890. GGML_ASSERT(nb0 == sizeof(float));
  4891. // im2col: [N, IC, IH, IW] => [N, OH, OW, IC*KH*KW]
  4892. {
  4893. float * const wdata = (float *) dst->data;
  4894. for (int64_t in = 0; in < N; in++) {
  4895. for (int64_t iic = ith; iic < IC; iic += nth) {
  4896. for (int64_t iih = 0; iih < IH; iih++) {
  4897. for (int64_t iiw = 0; iiw < IW; iiw++) {
  4898. // micro kernel
  4899. float grad = 0.0f;
  4900. for (int64_t ikh = 0; ikh < KH; ikh++) {
  4901. for (int64_t ikw = 0; ikw < KW; ikw++) {
  4902. // For s0 > 1 some values were skipped over in the forward pass.
  4903. // These values have tmpw % s0 != 0 and need to be skipped in the backwards pass as well.
  4904. const int64_t tmpw = (iiw + p0 - ikw*d0);
  4905. if (tmpw % s0 != 0) {
  4906. continue;
  4907. }
  4908. const int64_t iow = tmpw / s0;
  4909. // Equivalent logic as above except for s1.
  4910. int64_t ioh;
  4911. if (is_2D) {
  4912. const int64_t tmph = iih + p1 - ikh*d1;
  4913. if (tmph % s1 != 0) {
  4914. continue;
  4915. }
  4916. ioh = tmph / s1;
  4917. } else {
  4918. ioh = 0;
  4919. }
  4920. if (iow < 0 || iow >= OW || ioh < 0 || ioh >= OH) {
  4921. continue;
  4922. }
  4923. const float * const grad_in = (const float *) src0->data
  4924. + (in*OH*OW + ioh*OW + iow)*(IC*KH*KW); // [IC, KH, KW]
  4925. grad += grad_in[iic*(KH*KW) + ikh*KW + ikw];
  4926. }
  4927. }
  4928. float * dst_data = (float *)((char *) wdata + (in*ofs0 + iic*ofs1)); // [IH, IW]
  4929. dst_data[iih*IW + iiw] = grad;
  4930. }
  4931. }
  4932. }
  4933. }
  4934. }
  4935. }
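// im2col_back inverts the forward index map: the forward pass used iiw = iow*s0 + ikw*d0 - p0, so
// for a given input column iiw and kernel tap ikw the contributing output column is
//   iow = (iiw + p0 - ikw*d0) / s0,
// valid only when the division is exact (tmpw % s0 == 0) and 0 <= iow < OW -- exactly the guards in
// the loop above. Summing those contributions over all (ikh, ikw) gives the gradient wrt the image.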
  4936. // ggml_compute_forward_im2col_3d_f16
  4937. // src0: kernel [OC*IC, KD, KH, KW]
  4938. // src1: image [N*IC, ID, IH, IW]
  4939. // dst: result [N*OD, OH, OW, IC * KD * KH * KW]
  4940. static void ggml_compute_forward_im2col_3d_f16(
  4941. const ggml_compute_params * params,
  4942. ggml_tensor * dst) {
  4943. const ggml_tensor * src0 = dst->src[0];
  4944. const ggml_tensor * src1 = dst->src[1];
  4945. GGML_ASSERT(src0->type == GGML_TYPE_F16);
  4946. GGML_ASSERT(src1->type == GGML_TYPE_F32);
  4947. GGML_ASSERT( dst->type == GGML_TYPE_F16);
  4948. GGML_TENSOR_BINARY_OP_LOCALS;
  4949. const int32_t s0 = ((const int32_t *)(dst->op_params))[0];
  4950. const int32_t s1 = ((const int32_t *)(dst->op_params))[1];
  4951. const int32_t s2 = ((const int32_t *)(dst->op_params))[2];
  4952. const int32_t p0 = ((const int32_t *)(dst->op_params))[3];
  4953. const int32_t p1 = ((const int32_t *)(dst->op_params))[4];
  4954. const int32_t p2 = ((const int32_t *)(dst->op_params))[5];
  4955. const int32_t d0 = ((const int32_t *)(dst->op_params))[6];
  4956. const int32_t d1 = ((const int32_t *)(dst->op_params))[7];
  4957. const int32_t d2 = ((const int32_t *)(dst->op_params))[8];
  4958. const int32_t IC = ((const int32_t *)(dst->op_params))[9];
  4959. const int ith = params->ith;
  4960. const int nth = params->nth;
  4961. const int64_t N = ne13 / IC;
  4962. const int64_t ID = ne12;
  4963. const int64_t IH = ne11;
  4964. const int64_t IW = ne10;
  4965. const int64_t OC = ne03 / IC;
  4966. GGML_UNUSED(OC);
  4967. const int64_t KD = ne02;
  4968. const int64_t KH = ne01;
  4969. const int64_t KW = ne00;
  4970. const int64_t OD = ne3 / N;
  4971. const int64_t OH = ne2;
  4972. const int64_t OW = ne1;
  4973. const int64_t OH_OW = OH*OW;
  4974. const int64_t KD_KH_KW = KD*KH*KW;
  4975. const int64_t KH_KW = KH*KW;
  4976. const int64_t IC_KD_KH_KW = IC*KD*KH*KW;
  4977. GGML_ASSERT(nb10 == sizeof(float));
  4978. // im2col: [N*IC, ID, IH, IW] => [N*OD, OH, OW, IC * KD * KH * KW]
  4979. {
  4980. ggml_fp16_t * const wdata = (ggml_fp16_t *) dst->data;
  4981. for (int64_t in = 0; in < N; in++) {
  4982. for (int64_t iod = 0; iod < OD; iod++) {
  4983. for (int64_t ioh = 0; ioh < OH; ioh++) {
  4984. for (int64_t iow = 0; iow < OW; iow++) {
  4985. for (int64_t iic = ith; iic < IC; iic += nth) {
  4986. // micro kernel
  4987. ggml_fp16_t * dst_data = wdata + (in*OD*OH_OW + iod*OH_OW + ioh*OW + iow)*IC_KD_KH_KW; // [IC, KD, KH, KW]
  4988. const float * const src_data = (const float *) ((const char *)src1->data + (in*IC + iic)*nb13); // [ID, IH, IW]
  4989. for (int64_t ikd = 0; ikd < KD; ikd++) {
  4990. for (int64_t ikh = 0; ikh < KH; ikh++) {
  4991. for (int64_t ikw = 0; ikw < KW; ikw++) {
  4992. const int64_t iiw = iow*s0 + ikw*d0 - p0;
  4993. const int64_t iih = ioh*s1 + ikh*d1 - p1;
  4994. const int64_t iid = iod*s2 + ikd*d2 - p2;
4995. if (iid < 0 || iid >= ID || iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) {
  4996. dst_data[iic*KD_KH_KW + ikd * KH_KW + ikh*KW + ikw] = 0;
  4997. } else {
  4998. const float * const s = (const float *) ((const char *)src_data + iid*nb12 + iih*nb11 + iiw*nb10); // [ID, IH, IW]
  4999. dst_data[iic*KD_KH_KW + ikd * KH_KW + ikh*KW + ikw] = GGML_CPU_FP32_TO_FP16(*s);
  5000. }
  5001. }
  5002. }
  5003. }
  5004. }
  5005. }
  5006. }
  5007. }
  5008. }
  5009. }
  5010. }
  5011. // ggml_compute_forward_im2col_3d_f32
  5012. // src0: kernel [OC*IC, KD, KH, KW]
  5013. // src1: image [N*IC, ID, IH, IW]
  5014. // dst: result [N*OD, OH, OW, IC * KD * KH * KW]
  5015. static void ggml_compute_forward_im2col_3d_f32(
  5016. const ggml_compute_params * params,
  5017. ggml_tensor * dst) {
  5018. const ggml_tensor * src0 = dst->src[0];
  5019. const ggml_tensor * src1 = dst->src[1];
  5020. GGML_ASSERT(src1->type == GGML_TYPE_F32);
  5021. GGML_ASSERT( dst->type == GGML_TYPE_F32);
  5022. GGML_TENSOR_BINARY_OP_LOCALS;
  5023. const int32_t s0 = ((const int32_t *)(dst->op_params))[0];
  5024. const int32_t s1 = ((const int32_t *)(dst->op_params))[1];
  5025. const int32_t s2 = ((const int32_t *)(dst->op_params))[2];
  5026. const int32_t p0 = ((const int32_t *)(dst->op_params))[3];
  5027. const int32_t p1 = ((const int32_t *)(dst->op_params))[4];
  5028. const int32_t p2 = ((const int32_t *)(dst->op_params))[5];
  5029. const int32_t d0 = ((const int32_t *)(dst->op_params))[6];
  5030. const int32_t d1 = ((const int32_t *)(dst->op_params))[7];
  5031. const int32_t d2 = ((const int32_t *)(dst->op_params))[8];
  5032. const int32_t IC = ((const int32_t *)(dst->op_params))[9];
  5033. const int ith = params->ith;
  5034. const int nth = params->nth;
  5035. const int64_t N = ne13 / IC;
  5036. const int64_t ID = ne12;
  5037. const int64_t IH = ne11;
  5038. const int64_t IW = ne10;
  5039. const int64_t OC = ne03 / IC;
  5040. GGML_UNUSED(OC);
  5041. const int64_t KD = ne02;
  5042. const int64_t KH = ne01;
  5043. const int64_t KW = ne00;
  5044. const int64_t OD = ne3 / N;
  5045. const int64_t OH = ne2;
  5046. const int64_t OW = ne1;
  5047. const int64_t OH_OW = OH*OW;
  5048. const int64_t KD_KH_KW = KD*KH*KW;
  5049. const int64_t KH_KW = KH*KW;
  5050. const int64_t IC_KD_KH_KW = IC*KD*KH*KW;
  5051. GGML_ASSERT(nb10 == sizeof(float));
  5052. // im2col: [N*IC, ID, IH, IW] => [N*OD, OH, OW, IC * KD * KH * KW]
  5053. {
  5054. float * const wdata = (float *) dst->data;
  5055. for (int64_t in = 0; in < N; in++) {
  5056. for (int64_t iod = 0; iod < OD; iod++) {
  5057. for (int64_t ioh = 0; ioh < OH; ioh++) {
  5058. for (int64_t iow = 0; iow < OW; iow++) {
  5059. for (int64_t iic = ith; iic < IC; iic += nth) {
  5060. // micro kernel
  5061. float * dst_data = wdata + (in*OD*OH_OW + iod*OH_OW + ioh*OW + iow)*IC_KD_KH_KW; // [IC, KD, KH, KW]
  5062. const float * const src_data = (const float *) ((const char *)src1->data + (in*IC + iic)*nb13); // [ID, IH, IW]
  5063. for (int64_t ikd = 0; ikd < KD; ikd++) {
  5064. for (int64_t ikh = 0; ikh < KH; ikh++) {
  5065. for (int64_t ikw = 0; ikw < KW; ikw++) {
  5066. const int64_t iiw = iow*s0 + ikw*d0 - p0;
  5067. const int64_t iih = ioh*s1 + ikh*d1 - p1;
  5068. const int64_t iid = iod*s2 + ikd*d2 - p2;
5069. if (iid < 0 || iid >= ID || iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) {
  5070. dst_data[iic*KD_KH_KW + ikd * KH_KW + ikh*KW + ikw] = 0;
  5071. } else {
  5072. const float * const s = (const float *) ((const char *)src_data + iid*nb12 + iih*nb11 + iiw*nb10); // [ID, IH, IW]
  5073. dst_data[iic*KD_KH_KW + ikd * KH_KW + ikh*KW + ikw] = *s;
  5074. }
  5075. }
  5076. }
  5077. }
  5078. }
  5079. }
  5080. }
  5081. }
  5082. }
  5083. }
  5084. }
  5085. void ggml_compute_forward_im2col_3d(
  5086. const ggml_compute_params * params,
  5087. ggml_tensor * dst) {
  5088. switch (dst->type) {
  5089. case GGML_TYPE_F16:
  5090. {
  5091. ggml_compute_forward_im2col_3d_f16(params, dst);
  5092. } break;
  5093. case GGML_TYPE_F32:
  5094. {
  5095. ggml_compute_forward_im2col_3d_f32(params, dst);
  5096. } break;
  5097. default:
  5098. {
  5099. GGML_ABORT("fatal error");
  5100. }
  5101. }
  5102. }
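// Illustrative note (assumption about usage, not part of this file's kernels): once im2col_3d has
// produced one row of IC*KD*KH*KW values per output position, a 3D convolution reduces to a plain
// matrix product out[patch, oc] = sum_col cols[patch, col] * kernel[oc, col], where
// patch = ((in*OD + iod)*OH + ioh)*OW + iow and col = ((iic*KD + ikd)*KH + ikh)*KW + ikw.
#if 0
// Minimal scalar reference (illustrative only; `cols`, `kernel` and `out` are hypothetical fp32
// buffers laid out as described above).
static void conv3d_from_im2col_reference(const float * cols,   // [N*OD*OH*OW, IC*KD*KH*KW]
                                         const float * kernel, // [OC, IC*KD*KH*KW]
                                         float * out,          // [N*OD*OH*OW, OC]
                                         int64_t n_patches, int64_t n_cols, int64_t n_oc) {
    for (int64_t p = 0; p < n_patches; ++p) {
        for (int64_t oc = 0; oc < n_oc; ++oc) {
            float acc = 0.0f;
            for (int64_t c = 0; c < n_cols; ++c) {
                acc += cols[p*n_cols + c] * kernel[oc*n_cols + c];
            }
            out[p*n_oc + oc] = acc;
        }
    }
}
#endif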
  5103. static void ggml_call_mul_mat(ggml_type type, const ggml_compute_params * params, int64_t m, int64_t n, int64_t k,
  5104. void * a, void * b, float * c) {
  5105. const ggml_type_traits * traits = ggml_get_type_traits(type);
  5106. struct ggml_tensor src1 = {};
  5107. src1.type = type;
  5108. src1.ne[0] = k;
  5109. src1.ne[1] = m;
  5110. src1.ne[2] = 1;
  5111. src1.ne[3] = 1;
  5112. src1.nb[0] = traits->type_size;
  5113. src1.nb[1] = k * traits->type_size;
  5114. src1.nb[2] = src1.nb[1];
  5115. src1.nb[3] = src1.nb[2];
  5116. src1.data = a;
  5117. struct ggml_tensor src0 = {};
  5118. src0.type = type;
  5119. src0.ne[0] = k;
  5120. src0.ne[1] = n;
  5121. src0.ne[2] = 1;
  5122. src0.ne[3] = 1;
  5123. src0.nb[0] = traits->type_size;
  5124. src0.nb[1] = k * traits->type_size;
  5125. src0.nb[2] = src0.nb[1];
  5126. src0.nb[3] = src0.nb[2];
  5127. src0.data = b;
  5128. struct ggml_tensor dst = {};
  5129. dst.ne[0] = n;
  5130. dst.ne[1] = m;
  5131. dst.ne[2] = 1;
  5132. dst.ne[3] = 1;
  5133. dst.nb[0] = sizeof(float);
  5134. dst.nb[1] = n * sizeof(float);
  5135. dst.nb[2] = dst.nb[1];
  5136. dst.nb[3] = dst.nb[2];
  5137. dst.data = c;
  5138. dst.src[0] = &src0;
  5139. dst.src[1] = &src1;
  5140. ggml_compute_forward_mul_mat(params, &dst);
  5141. }
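// Shape note: with the tensors set up above, ggml_compute_forward_mul_mat produces
// c[i*n + j] = sum_t a[i*k + t] * b[j*k + t] for i in [0, m) and j in [0, n), i.e. C = A * B^T
// with A = a:[m, k] (the im2col patches), B = b:[n, k] (the flattened kernel) and C = c:[m, n]
// in fp32, ignoring the fp16/quantized dot-product details handled inside mul_mat.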
  5142. // ggml_compute_forward_conv_2d
  5143. static void ggml_compute_forward_conv_2d_impl(const ggml_compute_params * params,
  5144. const ggml_tensor * kernel, // [KW, KH, IC, OC]
  5145. const ggml_tensor * src, // [W, H, C, N]
  5146. ggml_tensor * dst, // [OW, OH, OC, N]
  5147. ggml_type kernel_type) {
  5148. GGML_ASSERT(ggml_is_contiguous(kernel));
  5149. GGML_ASSERT(kernel_type == GGML_TYPE_F16 || kernel_type == GGML_TYPE_F32);
  5150. GGML_ASSERT(kernel->type == kernel_type);
  5151. const ggml_type_traits * traits = ggml_get_type_traits(kernel_type);
  5152. const int32_t stride_x = dst->op_params[0];
  5153. const int32_t stride_y = dst->op_params[1];
  5154. const int32_t pad_x = dst->op_params[2];
  5155. const int32_t pad_y = dst->op_params[3];
  5156. const int32_t dilation_x = dst->op_params[4];
  5157. const int32_t dilation_y = dst->op_params[5];
  5158. const int64_t c_in = src->ne[2];
  5159. const int64_t c_out = kernel->ne[3];
  5160. GGML_ASSERT(c_in == kernel->ne[2]);
  5161. const int64_t src_w = src->ne[0];
  5162. const int64_t src_h = src->ne[1];
  5163. const int64_t knl_w = kernel->ne[0];
  5164. const int64_t knl_h = kernel->ne[1];
  5165. const int64_t dst_w = dst->ne[0];
  5166. const int64_t dst_h = dst->ne[1];
  5167. const float * src_data = (float *) src->data;
  5168. void * knl_data = kernel->data;
  5169. float * dst_data = (float *) dst->data;
  5170. const int64_t knl_n = knl_w * knl_h * c_in;
  5171. const int64_t patch_total = dst->ne[3] * dst_w * dst_h;
  5172. const int64_t space_per_patch = knl_n * traits->type_size + c_out * sizeof(float);
  5173. const int64_t batch_size = params->wsize / space_per_patch;
  5174. const int64_t patches_per_batch = batch_size > 8 ? (batch_size / 8) * 8 : batch_size;
  5175. const int64_t batch_n = (patch_total + patches_per_batch - 1) / patches_per_batch;
  5176. GGML_ASSERT(patches_per_batch > 0 && batch_size >= 1);
  5177. void * tmp = params->wdata;
  5178. for (int64_t batch_i = 0; batch_i < batch_n; ++batch_i) {
  5179. const int64_t patch_start_batch = batch_i * patches_per_batch;
  5180. const int64_t patch_end_batch = std::min(patch_start_batch + patches_per_batch,
  5181. patch_total);
  5182. const int64_t patch_n = patch_end_batch - patch_start_batch;
  5183. const int64_t patch_per_thread = (patch_n + params->nth - 1) / params->nth;
  5184. const int64_t patch_start = patch_start_batch + params->ith * patch_per_thread;
  5185. const int64_t patch_end = std::min(patch_start + patch_per_thread, patch_end_batch);
  5186. //im2col for a patch
  5187. for (int64_t p = patch_start; p < patch_end; ++p) {
5188. const int64_t batch_idx = p / (dst_w * dst_h); // image index (avoids shadowing batch_n above)
5189. const int64_t dst_y = (p / dst_w) % dst_h;     // output row of this patch
5190. const int64_t dst_x = p % dst_w;               // output column of this patch
5191. const float * src_base = (const float *)((const char *)src_data + batch_idx * src->nb[3]);
  5192. char * dst_row = (char *) tmp + (p % patches_per_batch) * knl_n * traits->type_size;
  5193. for (int64_t ic = 0; ic < c_in; ++ic) {
  5194. for (int64_t ky = 0; ky < knl_h; ++ky) {
  5195. for (int64_t kx = 0; kx < knl_w; ++kx) {
5196. const int64_t sy = dst_y * stride_y + ky * dilation_y - pad_y;
5197. const int64_t sx = dst_x * stride_x + kx * dilation_x - pad_x;
  5198. int64_t dst_idx = ic * (knl_h * knl_w) + ky * knl_w + kx;
  5199. float src_val;
  5200. if (sy < 0 || sy >= src_h || sx < 0 || sx >= src_w) {
  5201. src_val = 0.0f;
  5202. } else {
  5203. const float * src_ptr = (const float *)((const char *)src_base + sx * src->nb[0] + sy * src->nb[1] + ic * src->nb[2]);
  5204. src_val = *src_ptr;
  5205. }
  5206. char * element_ptr = dst_row + dst_idx * traits->type_size;
  5207. if (kernel_type == GGML_TYPE_F32) {
  5208. *(float *) element_ptr = src_val;
  5209. } else if (kernel_type == GGML_TYPE_F16) {
  5210. *(ggml_fp16_t *) element_ptr = GGML_CPU_FP32_TO_FP16(src_val);
  5211. }
  5212. }
  5213. }
  5214. }
  5215. } // patches handled by this thread
  5216. ggml_barrier(params->threadpool);
  5217. float * gemm_output = (float *) ((char *) tmp + patches_per_batch * knl_n * traits->type_size);
5218. GGML_ASSERT((char *) (gemm_output + patch_n * c_out) <= (char *) tmp + params->wsize);
  5219. // GEMM: patches[patch_n, knl_n] × kernel[knl_n, c_out] = output[patch_n, c_out]
  5220. ggml_call_mul_mat(kernel_type, params, patch_n, c_out, knl_n, tmp, knl_data, gemm_output);
  5221. ggml_barrier(params->threadpool);
5222. // scatter the GEMM output (one row of c_out values per patch) back into dst, laid out as [N, OC, OH, OW]
  5223. const int64_t permute_per_thread = (patch_n + params->nth - 1) / params->nth;
  5224. const int64_t permute_start = params->ith * permute_per_thread;
  5225. const int64_t permute_end = std::min(permute_start + permute_per_thread, patch_n);
  5226. for (int64_t i = permute_start; i < permute_end; ++i) {
  5227. const int64_t p = patch_start_batch + i;
5228. const int64_t batch_idx = p / (dst_w * dst_h);
  5229. const int64_t dst_y = (p / dst_w) % dst_h;
  5230. const int64_t dst_x = p % dst_w;
  5231. for (int64_t oc = 0; oc < c_out; ++oc) {
  5232. const float value = gemm_output[i * c_out + oc];
5233. float * dst_ptr = (float *)((char *)dst_data + dst_x * dst->nb[0] + dst_y * dst->nb[1] + oc * dst->nb[2] + batch_idx * dst->nb[3]);
  5234. *dst_ptr = value;
  5235. }
  5236. }
  5237. }
  5238. }
  5239. void ggml_compute_forward_conv_2d(
  5240. const ggml_compute_params * params,
  5241. ggml_tensor * dst) {
  5242. const ggml_tensor * src0 = dst->src[0];
  5243. const ggml_tensor * src1 = dst->src[1];
  5244. ggml_compute_forward_conv_2d_impl(params, src0, src1, dst, src0->type);
  5245. }
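// Worked example for the batching above (illustrative numbers): with a 3x3 fp32 kernel, c_in = 64
// and c_out = 128, each patch needs knl_n = 3*3*64 = 576 floats for its im2col row plus 128 floats
// of GEMM output, i.e. space_per_patch = (576 + 128)*4 = 2816 bytes. A 1 MiB work buffer then holds
// batch_size = 372 patches, rounded down to patches_per_batch = 368 (a multiple of 8), and
// patch_total is processed in ceil(patch_total / 368) batches.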
  5246. // ggml_compute_forward_conv_3d
  5247. static void ggml_compute_forward_conv_3d_impl(const ggml_compute_params * params,
  5248. const ggml_tensor * kernel,
  5249. const ggml_tensor * src,
  5250. ggml_tensor * dst,
  5251. ggml_type kernel_type) {
  5252. GGML_ASSERT(ggml_is_contiguous(kernel));
  5253. GGML_ASSERT(kernel_type == GGML_TYPE_F16 || kernel_type == GGML_TYPE_F32);
  5254. GGML_ASSERT(kernel->type == kernel_type);
  5255. const ggml_type_traits * traits = ggml_get_type_traits(kernel_type);
  5256. const int32_t s0 = dst->op_params[0];
  5257. const int32_t s1 = dst->op_params[1];
  5258. const int32_t s2 = dst->op_params[2];
  5259. const int32_t p0 = dst->op_params[3];
  5260. const int32_t p1 = dst->op_params[4];
  5261. const int32_t p2 = dst->op_params[5];
  5262. const int32_t d0 = dst->op_params[6];
  5263. const int32_t d1 = dst->op_params[7];
  5264. const int32_t d2 = dst->op_params[8];
  5265. const int32_t c = dst->op_params[9];
  5266. const int32_t n = dst->op_params[10];
  5267. const int32_t oc = dst->op_params[11];
  5268. const int64_t src_w = src->ne[0];
  5269. const int64_t src_h = src->ne[1];
  5270. const int64_t src_d = src->ne[2];
  5271. const int64_t knl_w = kernel->ne[0];
  5272. const int64_t knl_h = kernel->ne[1];
  5273. const int64_t knl_d = kernel->ne[2];
  5274. const int64_t dst_w = dst->ne[0];
  5275. const int64_t dst_h = dst->ne[1];
  5276. const int64_t dst_d = dst->ne[2];
  5277. const float * src_data = (float *) src->data;
  5278. void * knl_data = kernel->data;
  5279. float * dst_data = (float *) dst->data;
  5280. const int64_t knl_n_per_channel = knl_w * knl_h * knl_d;
  5281. const int64_t knl_n_total = knl_n_per_channel * c;
  5282. const int64_t patch_total = n * dst_w * dst_h * dst_d;
  5283. const int64_t space_per_patch = knl_n_total * traits->type_size + oc * sizeof(float);
  5284. const int64_t batch_size = params->wsize / space_per_patch;
  5285. const int64_t patches_per_batch = batch_size > 8 ? (batch_size / 8) * 8 : batch_size;
  5286. const int64_t batch_n = (patch_total + patches_per_batch - 1) / patches_per_batch;
  5287. GGML_ASSERT(patches_per_batch > 0 && batch_size >= 1);
  5288. void * tmp = params->wdata;
  5289. for (int64_t batch_i = 0; batch_i < batch_n; ++batch_i) {
  5290. const int64_t patch_start_batch = batch_i * patches_per_batch;
  5291. const int64_t patch_end_batch = std::min(patch_start_batch + patches_per_batch, patch_total);
  5292. const int64_t patch_n_in_batch = patch_end_batch - patch_start_batch;
  5293. const int64_t patch_per_thread = (patch_n_in_batch + params->nth - 1) / params->nth;
  5294. const int64_t patch_start = patch_start_batch + params->ith * patch_per_thread;
  5295. const int64_t patch_end = std::min(patch_start + patch_per_thread, patch_end_batch);
  5296. for (int64_t p = patch_start; p < patch_end; ++p) {
  5297. const int64_t p_in_batch = p % (dst_w * dst_h * dst_d);
  5298. const int64_t p_in_depth = p_in_batch % (dst_w * dst_h);
  5299. const int64_t batch_idx = p / (dst_w * dst_h * dst_d);
  5300. const int64_t dst_z = p_in_batch / (dst_w * dst_h);
  5301. const int64_t dst_y = p_in_depth / dst_w;
  5302. const int64_t dst_x = p_in_depth % dst_w;
  5303. char * dst_row = (char *) tmp + (p % patches_per_batch) * knl_n_total * traits->type_size;
  5304. for (int64_t ic = 0; ic < c; ++ic) {
  5305. for (int64_t kz = 0; kz < knl_d; ++kz) {
  5306. for (int64_t ky = 0; ky < knl_h; ++ky) {
  5307. for (int64_t kx = 0; kx < knl_w; ++kx) {
  5308. const int64_t sz = dst_z * s2 + kz * d2 - p2;
  5309. const int64_t sy = dst_y * s1 + ky * d1 - p1;
  5310. const int64_t sx = dst_x * s0 + kx * d0 - p0;
  5311. int64_t dst_idx = ic * knl_n_per_channel + kz * (knl_h * knl_w) + ky * knl_w + kx;
  5312. float src_val;
  5313. if (sz < 0 || sz >= src_d || sy < 0 || sy >= src_h || sx < 0 || sx >= src_w) {
  5314. src_val = 0.0f;
  5315. } else {
  5316. const int64_t cn_idx = batch_idx * c + ic;
  5317. const float * src_ptr = (const float *)((const char *)src_data + sx*src->nb[0] + sy*src->nb[1] + sz*src->nb[2] + cn_idx*src->nb[3]);
  5318. src_val = *src_ptr;
  5319. }
  5320. char * element_ptr = dst_row + dst_idx * traits->type_size;
  5321. if (kernel_type == GGML_TYPE_F32) {
  5322. *(float *)element_ptr = src_val;
  5323. } else if (kernel_type == GGML_TYPE_F16) {
  5324. *(ggml_fp16_t *)element_ptr = GGML_CPU_FP32_TO_FP16(src_val);
  5325. }
  5326. }
  5327. }
  5328. }
  5329. }
  5330. }
  5331. ggml_barrier(params->threadpool);
  5332. float * gemm_output = (float *) ((char *) tmp + patches_per_batch * knl_n_total * traits->type_size);
  5333. ggml_call_mul_mat(kernel_type, params, patch_n_in_batch, oc, knl_n_total, tmp, knl_data, gemm_output);
  5334. ggml_barrier(params->threadpool);
  5335. const int64_t permute_per_thread = (patch_n_in_batch + params->nth - 1) / params->nth;
  5336. const int64_t permute_start = params->ith * permute_per_thread;
  5337. const int64_t permute_end = std::min(permute_start + permute_per_thread, patch_n_in_batch);
  5338. for (int64_t i = permute_start; i < permute_end; ++i) {
  5339. const int64_t p = patch_start_batch + i;
  5340. const int64_t p_in_batch = p % (dst_w * dst_h * dst_d);
  5341. const int64_t p_in_depth = p_in_batch % (dst_w * dst_h);
  5342. const int64_t batch_idx = p / (dst_w * dst_h * dst_d);
  5343. const int64_t dst_z = p_in_batch / (dst_w * dst_h);
  5344. const int64_t dst_y = p_in_depth / dst_w;
  5345. const int64_t dst_x = p_in_depth % dst_w;
  5346. for (int64_t ioc = 0; ioc < oc; ++ioc) {
  5347. const float value = gemm_output[i * oc + ioc];
  5348. const int64_t ocn_idx = batch_idx * oc + ioc;
  5349. float * dst_ptr = (float *)((char *)dst_data + dst_x*dst->nb[0] + dst_y*dst->nb[1] + dst_z*dst->nb[2] + ocn_idx*dst->nb[3]);
  5350. *dst_ptr = value;
  5351. }
  5352. }
  5353. }
  5354. }
  5355. void ggml_compute_forward_conv_3d(
  5356. const ggml_compute_params * params,
  5357. ggml_tensor * dst) {
  5358. const ggml_tensor * src0 = dst->src[0];
  5359. const ggml_tensor * src1 = dst->src[1];
  5360. ggml_compute_forward_conv_3d_impl(params, src0, src1, dst, src0->type);
  5361. }
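// Patch indexing note: patches are enumerated as p = ((batch*dst_d + z)*dst_h + y)*dst_w + x, so the
// decomposition used above recovers batch_idx = p / (dst_w*dst_h*dst_d) and then dst_z, dst_y, dst_x;
// e.g. with dst_w = 4, dst_h = 3, dst_d = 2, p = 30 gives batch_idx = 1, dst_z = 0, dst_y = 1, dst_x = 2.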
  5362. // ggml_compute_forward_conv_transpose_2d
  5363. void ggml_compute_forward_conv_transpose_2d(
  5364. const ggml_compute_params * params,
  5365. ggml_tensor * dst) {
  5366. const ggml_tensor * src0 = dst->src[0];
  5367. const ggml_tensor * src1 = dst->src[1];
  5368. GGML_ASSERT(src0->type == GGML_TYPE_F16);
  5369. GGML_ASSERT(src1->type == GGML_TYPE_F32);
  5370. GGML_ASSERT( dst->type == GGML_TYPE_F32);
  5371. GGML_TENSOR_BINARY_OP_LOCALS
  5372. const int ith = params->ith;
  5373. const int nth = params->nth;
  5374. const int nk = ne00*ne01*ne02*ne03;
  5375. GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
  5376. GGML_ASSERT(nb10 == sizeof(float));
  5377. if (ith == 0) {
  5378. memset(params->wdata, 0, params->wsize);
  5379. // permute kernel data (src0) from (Kw x Kh x Cout x Cin) to (Cin x Kw x Kh x Cout)
  5380. {
  5381. ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + 0;
  5382. for (int64_t i03 = 0; i03 < ne03; i03++) {
  5383. for (int64_t i02 = 0; i02 < ne02; i02++) {
  5384. const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i03*nb03 + i02*nb02);
  5385. ggml_fp16_t * dst_data = wdata + i02*ne01*ne00*ne03;
  5386. for (int64_t i01 = 0; i01 < ne01; i01++) {
  5387. for (int64_t i00 = 0; i00 < ne00; i00++) {
  5388. dst_data[i01*ne00*ne03 + i00*ne03 + i03] = src[i01 * ne00 + i00];
  5389. }
  5390. }
  5391. }
  5392. }
  5393. }
  5394. // permute source data (src1) from (Sw x Sh x Cin) to (Cin x Sw x Sh)
  5395. {
  5396. ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + nk;
  5397. for (int i12 = 0; i12 < ne12; i12++) {
  5398. for (int i11 = 0; i11 < ne11; i11++) {
  5399. const float * const src = (float *)((char *) src1->data + i12*nb12 + i11*nb11);
  5400. ggml_fp16_t * dst_data = wdata + i11*ne10*ne12;
  5401. for (int i10 = 0; i10 < ne10; i10++) {
  5402. dst_data[i10*ne12 + i12] = GGML_CPU_FP32_TO_FP16(src[i10]);
  5403. }
  5404. }
  5405. }
  5406. }
  5407. memset(dst->data, 0, ggml_nbytes(dst));
  5408. }
  5409. ggml_barrier(params->threadpool);
  5410. const int32_t stride = ggml_get_op_params_i32(dst, 0);
  5411. // total patches in dst
  5412. const int np = ne2;
  5413. // patches per thread
  5414. const int dp = (np + nth - 1)/nth;
  5415. // patch range for this thread
  5416. const int ip0 = dp*ith;
  5417. const int ip1 = MIN(ip0 + dp, np);
  5418. ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + 0;
  5419. ggml_fp16_t * const wdata_src = wdata + nk;
  5420. for (int i2 = ip0; i2 < ip1; i2++) { // Cout
  5421. float * dst_data = (float *)((char *) dst->data + i2*nb2);
  5422. ggml_fp16_t * wdata_kernel = wdata + i2*ne01*ne00*ne03;
  5423. for (int i11 = 0; i11 < ne11; i11++) {
  5424. for (int i10 = 0; i10 < ne10; i10++) {
  5425. const int i1n = i11*ne10*ne12 + i10*ne12;
  5426. for (int i01 = 0; i01 < ne01; i01++) {
  5427. for (int i00 = 0; i00 < ne00; i00++) {
  5428. float v = 0;
  5429. ggml_vec_dot_f16(ne03, &v, 0,
  5430. wdata_src + i1n, 0,
  5431. wdata_kernel + i01*ne00*ne03 + i00*ne03, 0, 1);
  5432. dst_data[(i11*stride + i01)*ne0 + i10*stride + i00] += v;
  5433. }
  5434. }
  5435. }
  5436. }
  5437. }
  5438. }
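// Illustrative note: each source pixel (i10, i11) scatters the full Kw x Kh kernel into the output at
// (i10*stride + i00, i11*stride + i01); e.g. with stride = 2, source pixel (3, 1) and kernel tap (1, 0)
// contribute to output column 3*2 + 1 = 7 and row 1*2 + 0 = 2, i.e. dst_data[2*ne0 + 7]. Overlapping
// taps accumulate via the += above.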
  5439. // ggml_compute_forward_conv_2d_dw
  5440. struct ggml_conv_2d_dw_params {
  5441. int64_t channels;
  5442. int64_t batch;
  5443. int64_t src_w;
  5444. int64_t src_h;
  5445. int64_t dst_w;
  5446. int64_t dst_h;
  5447. int64_t knl_w;
  5448. int64_t knl_h;
  5449. int stride_x;
  5450. int stride_y;
  5451. int pad_x;
  5452. int pad_y;
  5453. int dilation_x;
  5454. int dilation_y;
  5455. };
  5456. static void ggml_compute_forward_conv_2d_dw_cwhn(
  5457. const ggml_compute_params * params,
  5458. const ggml_tensor * src,
  5459. const ggml_tensor * kernel,
  5460. ggml_tensor * dst,
  5461. const ggml_conv_2d_dw_params & p) {
  5462. const int64_t c = p.channels;
  5463. const float * knl_data = (const float *)kernel->data;
  5464. const int64_t rows_total = p.dst_h * p.batch;
  5465. const int64_t rows_per_thread = (rows_total + params->nth - 1) / params->nth;
  5466. const int64_t row_start = params->ith * rows_per_thread;
  5467. const int64_t row_end = MIN(row_start + rows_per_thread, rows_total);
  5468. #ifdef GGML_SIMD
  5469. #if defined(__ARM_FEATURE_SVE)
  5470. const int64_t pkg_size = svcntw();
  5471. #else
  5472. const int64_t pkg_size = GGML_F32_EPR;
  5473. #endif
  5474. const int64_t pkg_count = c / pkg_size;
  5475. const int64_t c_pkg_end = pkg_count * pkg_size;
  5476. #else
  5477. const int64_t c_pkg_end = 0;
  5478. #endif
  5479. for (int64_t row = row_start; row < row_end; ++row) {
  5480. const int64_t dst_y = row % p.dst_h;
  5481. const float * src_data = (const float *)src->data + (row / p.dst_h) * p.src_w * p.src_h * c;
  5482. for (int64_t dst_x = 0; dst_x < p.dst_w; ++dst_x) {
  5483. float * dst_data = (float *)dst->data + (row * p.dst_w + dst_x) * c;
  5484. const int64_t src_y_base = dst_y * p.stride_y - p.pad_y;
  5485. const int64_t src_x_base = dst_x * p.stride_x - p.pad_x;
  5486. #ifdef GGML_SIMD
  5487. // Vectorized loop
  5488. for (int64_t c_i = 0; c_i < c_pkg_end; c_i += pkg_size) {
  5489. GGML_F32_VEC sum = GGML_F32_VEC_ZERO;
  5490. for (int64_t knl_y = 0; knl_y < p.knl_h; ++knl_y) {
  5491. const int64_t src_y = src_y_base + knl_y * p.dilation_y;
  5492. if (src_y < 0 || src_y >= p.src_h) {
  5493. continue;
  5494. }
  5495. for (int64_t knl_x = 0; knl_x < p.knl_w; ++knl_x) {
  5496. const int64_t src_x = src_x_base + knl_x * p.dilation_x;
  5497. if (src_x < 0 || src_x >= p.src_w) {
  5498. continue;
  5499. }
  5500. GGML_F32_VEC k = GGML_F32_VEC_LOAD(knl_data + (knl_y * p.knl_w + knl_x) * c + c_i);
  5501. GGML_F32_VEC s = GGML_F32_VEC_LOAD(src_data + (src_y * p.src_w + src_x) * c + c_i);
  5502. sum = GGML_F32_VEC_FMA(sum, k, s);
  5503. }
  5504. }
  5505. GGML_F32_VEC_STORE(dst_data + c_i, sum);
  5506. }
  5507. #endif
  5508. // Scalar loop
  5509. for (int64_t c_i = c_pkg_end; c_i < c; ++c_i) {
  5510. float sum = 0.0f;
  5511. for (int64_t knl_y = 0; knl_y < p.knl_h; ++knl_y) {
  5512. const int64_t src_y = src_y_base + knl_y * p.dilation_y;
  5513. if (src_y < 0 || src_y >= p.src_h) {
  5514. continue;
  5515. }
  5516. for (int64_t knl_x = 0; knl_x < p.knl_w; ++knl_x) {
  5517. const int64_t src_x = src_x_base + knl_x * p.dilation_x;
  5518. if (src_x < 0 || src_x >= p.src_w) {
  5519. continue;
  5520. }
  5521. sum += knl_data[(knl_y * p.knl_w + knl_x) * c + c_i]
  5522. * src_data[(src_y * p.src_w + src_x) * c + c_i];
  5523. }
  5524. }
  5525. dst_data[c_i] = sum;
  5526. }
  5527. }
  5528. }
  5529. }
  5530. static void ggml_compute_forward_conv_2d_dw_whcn(
  5531. const ggml_compute_params * params,
  5532. const ggml_tensor * src,
  5533. const ggml_tensor * kernel,
  5534. ggml_tensor * dst,
  5535. const ggml_conv_2d_dw_params & p) {
  5536. const int64_t n = p.channels * p.batch;
  5537. const int64_t per_thread = (n + params->nth - 1) / params->nth;
  5538. const int64_t start = params->ith * per_thread;
  5539. const int64_t end = MIN(start + per_thread, n);
  5540. for (int64_t i = start; i < end; ++i) {
  5541. const float * knl_data = (const float *)kernel->data + (i % p.channels) * p.knl_w * p.knl_h;
  5542. const float * src_data = (const float *)src->data + i * p.src_w * p.src_h;
  5543. float * dst_data = (float *)dst->data + i * p.dst_w * p.dst_h;
  5544. for (int64_t dst_y = 0; dst_y < p.dst_h; ++dst_y) {
  5545. for (int64_t dst_x = 0; dst_x < p.dst_w; ++dst_x) {
  5546. float sum = 0.0f;
  5547. for (int64_t knl_y = 0; knl_y < p.knl_h; ++knl_y) {
  5548. const int64_t src_y = dst_y * p.stride_y + knl_y * p.dilation_y - p.pad_y;
  5549. if (src_y < 0 || src_y >= p.src_h) {
  5550. continue;
  5551. }
  5552. for (int64_t knl_x = 0; knl_x < p.knl_w; ++knl_x) {
  5553. const int64_t src_x = dst_x * p.stride_x + knl_x * p.dilation_x - p.pad_x;
  5554. if (src_x < 0 || src_x >= p.src_w) {
  5555. continue;
  5556. }
  5557. sum += knl_data[knl_y * p.knl_w + knl_x]
  5558. * src_data[src_y * p.src_w + src_x];
  5559. }
  5560. }
  5561. dst_data[dst_y * p.dst_w + dst_x] = sum;
  5562. }
  5563. }
  5564. }
  5565. }
  5566. void ggml_compute_forward_conv_2d_dw(
  5567. const ggml_compute_params * params,
  5568. ggml_tensor * dst) {
  5569. const ggml_tensor * kernel = dst->src[0];
  5570. const ggml_tensor * src = dst->src[1];
  5571. ggml_conv_2d_dw_params p;
  5572. p.channels = src->ne[2];
  5573. p.batch = src->ne[3];
  5574. p.src_w = src->ne[0];
  5575. p.src_h = src->ne[1];
  5576. p.dst_w = dst->ne[0];
  5577. p.dst_h = dst->ne[1];
  5578. p.knl_w = kernel->ne[0];
  5579. p.knl_h = kernel->ne[1];
  5580. p.stride_x = dst->op_params[0];
  5581. p.stride_y = dst->op_params[1];
  5582. p.pad_x = dst->op_params[2];
  5583. p.pad_y = dst->op_params[3];
  5584. p.dilation_x = dst->op_params[4];
  5585. p.dilation_y = dst->op_params[5];
  5586. GGML_ASSERT(kernel->ne[3] == p.channels);
  5587. GGML_ASSERT(dst->ne[3] == p.batch);
  5588. if (ggml_is_contiguous(src)) {
  5589. ggml_compute_forward_conv_2d_dw_whcn(params, src, kernel, dst, p);
  5590. } else if (ggml_is_contiguous_channels(src)) {
  5591. // kernel should also have channels most contiguous in memory
  5592. GGML_ASSERT(kernel->nb[0] >= kernel->nb[2] && kernel->nb[1] >= kernel->nb[0]);
  5593. ggml_compute_forward_conv_2d_dw_cwhn(params, src, kernel, dst, p);
  5594. } else {
  5595. GGML_ABORT("non-contiguous memory layout not supported");
  5596. }
  5597. }
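// Layout note: the WHCN path processes one (channel, batch) plane at a time, so its inner loops read
// contiguous spatial rows of a single channel; the CWHN path keeps the channel dimension innermost,
// which is what lets the SIMD variant process GGML_F32_EPR (or, with SVE, svcntw()) channels per
// iteration for a fixed spatial tap.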
  5598. // ggml_compute_forward_pool_1d_sk_p0
  5599. static void ggml_compute_forward_pool_1d_sk_p0(
  5600. const ggml_compute_params * params,
  5601. const ggml_op_pool op,
  5602. const int k,
  5603. ggml_tensor * dst) {
  5604. const ggml_tensor * src = dst->src[0];
  5605. assert(src->type == GGML_TYPE_F32 || src->type == GGML_TYPE_F16);
  5606. if (params->ith != 0) {
  5607. return;
  5608. }
  5609. const char * cdata = (const char *)src->data;
  5610. const char * const data_end = cdata + ggml_nbytes(src);
  5611. float * drow = (float *)dst->data;
  5612. const int64_t rs = dst->ne[0];
  5613. while (cdata < data_end) {
  5614. const void * srow = (const void *)cdata;
  5615. int j = 0;
  5616. for (int64_t i = 0; i < rs; ++i) {
  5617. switch (op) {
  5618. case GGML_OP_POOL_AVG: drow[i] = 0; break;
  5619. case GGML_OP_POOL_MAX: drow[i] = -FLT_MAX; break;
  5620. case GGML_OP_POOL_COUNT: GGML_ABORT("fatal error");
  5621. }
  5622. for (int ki = 0; ki < k; ++ki) {
  5623. const float srow_j = (src->type == GGML_TYPE_F32) ? ((const float*)srow)[j] : GGML_CPU_FP16_TO_FP32(((const ggml_fp16_t*)srow)[j]);
  5624. switch (op) {
  5625. case GGML_OP_POOL_AVG: drow[i] += srow_j; break;
  5626. case GGML_OP_POOL_MAX: if (srow_j > drow[i]) drow[i] = srow_j; break;
  5627. case GGML_OP_POOL_COUNT: GGML_ABORT("fatal error");
  5628. }
  5629. ++j;
  5630. }
  5631. switch (op) {
  5632. case GGML_OP_POOL_AVG: drow[i] /= k; break;
  5633. case GGML_OP_POOL_MAX: break;
  5634. case GGML_OP_POOL_COUNT: GGML_ABORT("fatal error");
  5635. }
  5636. }
  5637. cdata += src->nb[1];
  5638. drow += rs;
  5639. }
  5640. }
  5641. // ggml_compute_forward_pool_1d
  5642. void ggml_compute_forward_pool_1d(
  5643. const ggml_compute_params * params,
  5644. ggml_tensor * dst) {
  5645. const int32_t * opts = (const int32_t *)dst->op_params;
  5646. ggml_op_pool op = static_cast<ggml_op_pool>(opts[0]);
  5647. const int k0 = opts[1];
  5648. const int s0 = opts[2];
  5649. const int p0 = opts[3];
  5650. GGML_ASSERT(p0 == 0); // padding not supported
  5651. GGML_ASSERT(k0 == s0); // only s = k supported
  5652. ggml_compute_forward_pool_1d_sk_p0(params, op, k0, dst);
  5653. }
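// Worked example: with op = GGML_OP_POOL_AVG and k0 = s0 = 2 (the only supported case, per the asserts
// above), an input row [1, 2, 3, 4] produces dst->ne[0] = 2 outputs [(1+2)/2, (3+4)/2] = [1.5, 3.5];
// GGML_OP_POOL_MAX would give [2, 4].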
  5654. // ggml_compute_forward_pool_2d
  5655. void ggml_compute_forward_pool_2d(
  5656. const ggml_compute_params * params,
  5657. ggml_tensor * dst) {
  5658. const ggml_tensor * src = dst->src[0];
  5659. assert(src->type == GGML_TYPE_F32 || src->type == GGML_TYPE_F16);
  5660. if (params->ith != 0) {
  5661. return;
  5662. }
  5663. const int32_t * opts = (const int32_t *)dst->op_params;
  5664. ggml_op_pool op = static_cast<ggml_op_pool>(opts[0]);
  5665. const int k0 = opts[1];
  5666. const int k1 = opts[2];
  5667. const int s0 = opts[3];
  5668. const int s1 = opts[4];
  5669. const int p0 = opts[5];
  5670. const int p1 = opts[6];
  5671. const char * cdata = (const char*)src->data;
  5672. const char * const data_end = cdata + ggml_nbytes(src);
  5673. const int64_t px = dst->ne[0];
  5674. const int64_t py = dst->ne[1];
  5675. const int64_t pa = px * py;
  5676. float * dplane = (float *)dst->data;
  5677. const int ka = k0 * k1;
  5678. const int offset0 = -p0;
  5679. const int offset1 = -p1;
  5680. while (cdata < data_end) {
  5681. for (int oy = 0; oy < py; ++oy) {
  5682. float * const drow = dplane + oy * px;
  5683. for (int ox = 0; ox < px; ++ox) {
  5684. float * const out = drow + ox;
  5685. switch (op) {
  5686. case GGML_OP_POOL_AVG: *out = 0; break;
  5687. case GGML_OP_POOL_MAX: *out = -FLT_MAX; break;
  5688. case GGML_OP_POOL_COUNT: GGML_ABORT("fatal error");
  5689. }
  5690. const int ix = offset0 + ox * s0;
  5691. const int iy = offset1 + oy * s1;
  5692. for (int ky = 0; ky < k1; ++ky) {
  5693. if (iy + ky < 0 || iy + ky >= src->ne[1]) continue;
  5694. const void * srow = (const void *)(cdata + src->nb[1] * (iy + ky));
  5695. for (int kx = 0; kx < k0; ++kx) {
  5696. int j = ix + kx;
  5697. if (j < 0 || j >= src->ne[0]) continue;
  5698. const float srow_j = (src->type == GGML_TYPE_F32) ? ((const float*)srow)[j] : GGML_CPU_FP16_TO_FP32(((const ggml_fp16_t*)srow)[j]);
  5699. switch (op) {
  5700. case GGML_OP_POOL_AVG: *out += srow_j; break;
  5701. case GGML_OP_POOL_MAX: if (srow_j > *out) *out = srow_j; break;
  5702. case GGML_OP_POOL_COUNT: GGML_ABORT("fatal error");
  5703. }
  5704. }
  5705. }
  5706. switch (op) {
  5707. case GGML_OP_POOL_AVG: *out /= ka; break;
  5708. case GGML_OP_POOL_MAX: break;
  5709. case GGML_OP_POOL_COUNT: GGML_ABORT("fatal error");
  5710. }
  5711. }
  5712. }
  5713. cdata += src->nb[2];
  5714. dplane += pa;
  5715. }
  5716. }
  5717. // ggml_compute_forward_pool_2d_back
  5718. void ggml_compute_forward_pool_2d_back(
  5719. const ggml_compute_params * params,
  5720. ggml_tensor * dst) {
  5721. const ggml_tensor * src = dst->src[0];
  5722. const ggml_tensor * dstf = dst->src[1]; // forward tensor of dst
  5723. assert(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
  5724. if (params->ith != 0) {
  5725. return;
  5726. }
  5727. const int32_t * opts = (const int32_t *)dst->op_params;
  5728. ggml_op_pool op = static_cast<ggml_op_pool>(opts[0]);
  5729. const int k0 = opts[1];
  5730. const int k1 = opts[2];
  5731. const int s0 = opts[3];
  5732. const int s1 = opts[4];
  5733. const int p0 = opts[5];
  5734. const int p1 = opts[6];
  5735. char * cdata = (char *) dst->data;
  5736. const char * cdataf = (const char *) dstf->data;
  5737. const char * const data_end = cdata + ggml_nbytes(dst);
  5738. GGML_ASSERT(params->ith == 0);
  5739. memset(cdata, 0, ggml_nbytes(dst));
  5740. const int64_t px = src->ne[0];
  5741. const int64_t py = src->ne[1];
  5742. const int64_t pa = px * py;
  5743. const float * splane = (const float *) src->data;
  5744. const int ka = k0 * k1;
  5745. const int offset0 = -p0;
  5746. const int offset1 = -p1;
  5747. while (cdata < data_end) {
  5748. for (int oy = 0; oy < py; ++oy) {
  5749. const float * const srow = splane + oy * px;
  5750. for (int ox = 0; ox < px; ++ox) {
  5751. const float grad0 = srow[ox];
  5752. const int ix = offset0 + ox * s0;
  5753. const int iy = offset1 + oy * s1;
  5754. if (op == GGML_OP_POOL_MAX) {
  5755. float maxval = -FLT_MAX;
  5756. int kxmax = -1;
  5757. int kymax = -1;
  5758. for (int ky = 0; ky < k1; ++ky) {
  5759. if (iy + ky < 0 || iy + ky >= dst->ne[1]) {
  5760. continue;
  5761. }
  5762. const void * drowf = (const void *)(cdataf + dst->nb[1] * (iy + ky));
  5763. for (int kx = 0; kx < k0; ++kx) {
  5764. int j = ix + kx;
  5765. if (j < 0 || j >= dst->ne[0]) {
  5766. continue;
  5767. }
  5768. const float val = dst->type == GGML_TYPE_F32 ?
  5769. ((const float *) drowf)[j] : GGML_CPU_FP16_TO_FP32(((const ggml_fp16_t *) drowf)[j]);
  5770. if (val <= maxval) {
  5771. continue;
  5772. }
  5773. maxval = val;
  5774. kxmax = kx;
  5775. kymax = ky;
  5776. }
  5777. }
  5778. if (kxmax == -1 || kymax == -1) {
  5779. continue;
  5780. }
  5781. void * drow = (void *)(cdata + dst->nb[1] * (iy + kymax));
  5782. const int j = ix + kxmax;
  5783. if (dst->type == GGML_TYPE_F32) {
  5784. ((float *) drow)[j] += grad0;
  5785. } else {
  5786. ((ggml_fp16_t *) drow)[j] = GGML_CPU_FP32_TO_FP16(grad0 + GGML_CPU_FP16_TO_FP32(((const ggml_fp16_t *) drow)[j]));
  5787. }
  5788. } else if (op == GGML_OP_POOL_AVG) {
  5789. const float grad = grad0 / ka;
  5790. for (int ky = 0; ky < k1; ++ky) {
  5791. if (iy + ky < 0 || iy + ky >= dst->ne[1]) {
  5792. continue;
  5793. }
  5794. void * drow = (void *)(cdata + dst->nb[1] * (iy + ky));
  5795. for (int kx = 0; kx < k0; ++kx) {
  5796. int j = ix + kx;
  5797. if (j < 0 || j >= dst->ne[0]) {
  5798. continue;
  5799. }
  5800. if (dst->type == GGML_TYPE_F32) {
  5801. ((float *) drow)[j] += grad;
  5802. } else {
5803. ((ggml_fp16_t *) drow)[j] = GGML_CPU_FP32_TO_FP16(grad + GGML_CPU_FP16_TO_FP32(((const ggml_fp16_t *) drow)[j]));
  5804. }
  5805. }
  5806. }
  5807. } else {
  5808. GGML_ASSERT(false);
  5809. }
  5810. }
  5811. }
  5812. cdata += dst->nb[2];
  5813. cdataf += dst->nb[2];
  5814. splane += pa;
  5815. }
  5816. }
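// Gradient routing summary: for GGML_OP_POOL_MAX the incoming gradient grad0 is added only at the
// position that held the window maximum in the forward output (recomputed here from dstf), while for
// GGML_OP_POOL_AVG it is spread uniformly as grad0 / (k0*k1) over every in-bounds cell of the window;
// e.g. each cell of a 2x2 average window receives grad0/4.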
  5817. // ggml_compute_forward_upscale
  5818. static void ggml_compute_forward_upscale_f32(
  5819. const ggml_compute_params * params,
  5820. ggml_tensor * dst) {
  5821. const ggml_tensor * src0 = dst->src[0];
  5822. GGML_ASSERT(src0->type == GGML_TYPE_F32);
  5823. const int ith = params->ith;
  5824. const int nth = params->nth;
  5825. GGML_TENSOR_UNARY_OP_LOCALS
  5826. float sf0 = (float)ne0/src0->ne[0];
  5827. float sf1 = (float)ne1/src0->ne[1];
  5828. float sf2 = (float)ne2/src0->ne[2];
  5829. float sf3 = (float)ne3/src0->ne[3];
  5830. float pixel_offset = 0.5f;
  5831. const int32_t mode_flags = ggml_get_op_params_i32(dst, 0);
  5832. const ggml_scale_mode mode = (ggml_scale_mode) (mode_flags & 0xFF);
  5833. if (mode_flags & GGML_SCALE_FLAG_ALIGN_CORNERS) {
  5834. pixel_offset = 0.0f;
  5835. sf0 = ne0 > 1 && ne00 > 1 ? (float)(ne0 - 1) / (ne00 - 1) : sf0;
  5836. sf1 = ne1 > 1 && ne01 > 1 ? (float)(ne1 - 1) / (ne01 - 1) : sf1;
  5837. }
  5838. if (mode == GGML_SCALE_MODE_NEAREST) {
  5839. for (int64_t i3 = 0; i3 < ne3; i3++) {
  5840. const int64_t i03 = i3 / sf3;
  5841. for (int64_t i2 = ith; i2 < ne2; i2 += nth) {
  5842. const int64_t i02 = i2 / sf2;
  5843. for (int64_t i1 = 0; i1 < ne1; i1++) {
  5844. const int64_t i01 = i1 / sf1;
  5845. for (int64_t i0 = 0; i0 < ne0; i0++) {
  5846. const int64_t i00 = i0 / sf0;
  5847. const float * x = (float *)((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
  5848. float * y = (float *)((char *) dst->data + i0*nb0 + i1*nb1 + i2*nb2 + i3*nb3);
  5849. *y = *x;
  5850. }
  5851. }
  5852. }
  5853. }
  5854. } else if (mode == GGML_SCALE_MODE_BILINEAR) {
  5855. for (int64_t i3 = 0; i3 < ne3; i3++) {
  5856. const int64_t i03 = i3 / sf3;
  5857. for (int64_t i2 = ith; i2 < ne2; i2 += nth) {
  5858. const int64_t i02 = i2 / sf2;
  5859. for (int64_t i1 = 0; i1 < ne1; i1++) {
  5860. const float y = ((float)i1 + pixel_offset) / sf1 - pixel_offset;
  5861. int64_t y0 = (int64_t)floorf(y);
  5862. int64_t y1 = y0 + 1;
  5863. y0 = std::max(int64_t(0), std::min(y0, ne01 - 1));
  5864. y1 = std::max(int64_t(0), std::min(y1, ne01 - 1));
  5865. float dy = y - (float)y0;
  5866. dy = std::max(0.0f, std::min(dy, 1.0f));
  5867. for (int64_t i0 = 0; i0 < ne0; i0++) {
  5868. const float x = ((float)i0 + pixel_offset) / sf0 - pixel_offset;
  5869. int64_t x0 = (int64_t)floorf(x);
  5870. int64_t x1 = x0 + 1;
  5871. x0 = std::max(int64_t(0), std::min(x0, ne00 - 1));
  5872. x1 = std::max(int64_t(0), std::min(x1, ne00 - 1));
  5873. float dx = x - (float)x0;
  5874. dx = std::max(0.0f, std::min(dx, 1.0f));
  5875. // fetch the four surrounding pixel values and interpolate
  5876. const float a = *(const float *)((const char *)src0->data + x0*nb00 + y0*nb01 + i02*nb02 + i03*nb03);
  5877. const float b = *(const float *)((const char *)src0->data + x1*nb00 + y0*nb01 + i02*nb02 + i03*nb03);
  5878. const float c = *(const float *)((const char *)src0->data + x0*nb00 + y1*nb01 + i02*nb02 + i03*nb03);
  5879. const float d = *(const float *)((const char *)src0->data + x1*nb00 + y1*nb01 + i02*nb02 + i03*nb03);
  5880. const float val = a*(1 - dx)*(1 - dy) + b*dx*(1 - dy) + c*(1 - dx)*dy + d*dx*dy;
  5881. float * y_dst = (float *)((char *)dst->data + i0*nb0 + i1*nb1 + i2*nb2 + i3*nb3);
  5882. *y_dst = val;
  5883. }
  5884. }
  5885. }
  5886. }
  5887. } else if (mode == GGML_SCALE_MODE_BICUBIC) {
  5888. // https://en.wikipedia.org/wiki/Bicubic_interpolation#Bicubic_convolution_algorithm
  5889. const float a = -0.75f; // use alpha = -0.75 (same as PyTorch)
  5890. auto weight1 = [a](float x) { return ((a + 2) * x - (a + 3)) * x * x + 1; };
  5891. auto weight2 = [a](float x) { return ((a * x - 5 * a) * x + 8 * a) * x - 4 * a; };
  5892. auto bicubic = [=](float p0, float p1, float p2, float p3, float x) {
  5893. const float w0 = weight2(x + 1);
  5894. const float w1 = weight1(x + 0);
  5895. const float w2 = weight1(1 - x);
  5896. const float w3 = weight2(2 - x);
  5897. return p0*w0 + p1*w1 + p2*w2 + p3*w3;
  5898. };
  5899. for (int64_t i3 = 0; i3 < ne3; i3++) {
  5900. const int64_t i03 = i3 / sf3;
  5901. for (int64_t i2 = ith; i2 < ne2; i2 += nth) {
  5902. const int64_t i02 = i2 / sf2;
  5903. for (int64_t i1 = 0; i1 < ne1; i1++) {
  5904. const float y = ((float)i1 + pixel_offset) / sf1 - pixel_offset;
  5905. const int64_t y0 = (int64_t)floorf(y);
  5906. const float dy = y - (float)y0;
  5907. for (int64_t i0 = 0; i0 < ne0; i0++) {
  5908. const float x = ((float)i0 + pixel_offset) / sf0 - pixel_offset;
  5909. const int64_t x0 = (int64_t)floorf(x);
  5910. const float dx = x - (float)x0;
  5911. auto p = [=](int64_t x_off, int64_t y_off) -> float {
  5912. int64_t i00 = std::max(int64_t(0), std::min(x0 + x_off, ne00 - 1));
  5913. int64_t i01 = std::max(int64_t(0), std::min(y0 + y_off, ne01 - 1));
  5914. return *(const float *)((const char *)src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
  5915. };
  5916. const float val = bicubic(
  5917. bicubic(p(-1,-1), p(0,-1), p(1,-1), p(2,-1), dx),
  5918. bicubic(p(-1, 0), p(0, 0), p(1, 0), p(2, 0), dx),
  5919. bicubic(p(-1, 1), p(0, 1), p(1, 1), p(2, 1), dx),
  5920. bicubic(p(-1, 2), p(0, 2), p(1, 2), p(2, 2), dx), dy);
  5921. float * y_dst = (float *)((char *)dst->data + i0*nb0 + i1*nb1 + i2*nb2 + i3*nb3);
  5922. *y_dst = val;
  5923. }
  5924. }
  5925. }
  5926. }
  5927. } else {
  5928. GGML_ABORT("unsupported upscale mode");
  5929. }
  5930. }
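// Worked example for the bilinear path: upscaling width 4 -> 8 gives sf0 = 2, so output column i0 = 3
// maps to x = (3 + 0.5)/2 - 0.5 = 1.25, i.e. x0 = 1, x1 = 2, dx = 0.25 and the result along that axis
// is 0.75*src[1] + 0.25*src[2]. With GGML_SCALE_FLAG_ALIGN_CORNERS, pixel_offset = 0 and
// sf0 = (8-1)/(4-1) = 7/3, so the last output column i0 = 7 maps exactly onto source column 3.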
  5931. void ggml_compute_forward_upscale(
  5932. const ggml_compute_params * params,
  5933. ggml_tensor * dst) {
  5934. const ggml_tensor * src0 = dst->src[0];
  5935. switch (src0->type) {
  5936. case GGML_TYPE_F32:
  5937. {
  5938. ggml_compute_forward_upscale_f32(params, dst);
  5939. } break;
  5940. default:
  5941. {
  5942. GGML_ABORT("fatal error");
  5943. }
  5944. }
  5945. }
  5946. // ggml_compute_forward_pad
  5947. static void ggml_compute_forward_pad_f32(
  5948. const ggml_compute_params * params,
  5949. ggml_tensor * dst) {
  5950. const ggml_tensor * src0 = dst->src[0];
  5951. GGML_ASSERT(src0->nb[0] == sizeof(float));
  5952. GGML_ASSERT( dst->nb[0] == sizeof(float));
  5953. const int ith = params->ith;
  5954. const int nth = params->nth;
  5955. GGML_TENSOR_UNARY_OP_LOCALS
  5956. float * dst_ptr = (float *) dst->data;
  5957. const int32_t lp0 = ggml_get_op_params_i32(dst, 0);
  5958. const int32_t rp0 = ggml_get_op_params_i32(dst, 1);
  5959. const int32_t lp1 = ggml_get_op_params_i32(dst, 2);
  5960. const int32_t rp1 = ggml_get_op_params_i32(dst, 3);
  5961. const int32_t lp2 = ggml_get_op_params_i32(dst, 4);
  5962. const int32_t rp2 = ggml_get_op_params_i32(dst, 5);
  5963. const int32_t lp3 = ggml_get_op_params_i32(dst, 6);
  5964. const int32_t rp3 = ggml_get_op_params_i32(dst, 7);
  5965. // TODO: optimize
  5966. for (int64_t i2 = 0; i2 < ne2; ++i2) {
  5967. for (int64_t i1 = ith; i1 < ne1; i1 += nth) {
  5968. for (int64_t i0 = 0; i0 < ne0; ++i0) {
  5969. for (int64_t i3 = 0; i3 < ne3; ++i3) {
  5970. const int64_t dst_idx = i3*(ne0*ne1*ne2) + i2*(ne0*ne1) + i1*ne0 + i0;
  5971. if ((i0 >= lp0 && i0 < ne0 - rp0) \
  5972. && (i1 >= lp1 && i1 < ne1 - rp1) \
  5973. && (i2 >= lp2 && i2 < ne2 - rp2) \
  5974. && (i3 >= lp3 && i3 < ne3 - rp3)) {
  5975. const int64_t src_idx = (i3 - lp3)*nb03 + (i2 - lp2)*nb02 + (i1 - lp1)*nb01 + (i0 - lp0)*nb00;
  5976. const float * src_ptr = (const float *)((char *) src0->data + src_idx);
  5977. dst_ptr[dst_idx] = *src_ptr;
  5978. } else {
  5979. dst_ptr[dst_idx] = 0;
  5980. }
  5981. }
  5982. }
  5983. }
  5984. }
  5985. }
  5986. void ggml_compute_forward_pad(
  5987. const ggml_compute_params * params,
  5988. ggml_tensor * dst) {
  5989. const ggml_tensor * src0 = dst->src[0];
  5990. switch (src0->type) {
  5991. case GGML_TYPE_F32:
  5992. {
  5993. ggml_compute_forward_pad_f32(params, dst);
  5994. } break;
  5995. default:
  5996. {
  5997. GGML_ABORT("fatal error");
  5998. }
  5999. }
  6000. }
  6001. // ggml_compute_forward_pad_reflect_1d
  6002. void ggml_compute_forward_pad_reflect_1d(
  6003. const ggml_compute_params * params,
  6004. ggml_tensor * dst) {
  6005. const ggml_tensor * src0 = dst->src[0];
  6006. GGML_ASSERT(src0->type == GGML_TYPE_F32);
  6007. GGML_ASSERT( dst->type == GGML_TYPE_F32);
  6008. const int ith = params->ith;
  6009. const int nth = params->nth;
  6010. const int32_t * opts = (const int32_t *) dst->op_params;
  6011. const int p0 = opts[0];
  6012. const int p1 = opts[1];
  6013. GGML_TENSOR_UNARY_OP_LOCALS
  6014. for (int64_t i3 = 0; i3 < ne3; i3++) {
  6015. for (int64_t i2 = 0; i2 < ne2; i2++) {
  6016. for (int64_t i1 = ith; i1 < ne1; i1 += nth) {
  6017. float * left = (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + p0*nb0);
  6018. float * right = (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + (ne0-p1-1)*nb0);
  6019. ggml_vec_cpy_f32(ne00, left, (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01));
  6020. for (int i0 = 1; i0 <= p0; i0++) { left[-i0] = left[i0]; }
  6021. for (int i0 = 1; i0 <= p1; i0++) { right[i0] = right[-i0]; }
  6022. }
  6023. }
  6024. }
  6025. }
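// Worked example: with p0 = 2, p1 = 1 and a source row [a, b, c, d], the row is first copied into dst
// starting at offset p0 and then reflected without repeating the edge samples
// (left[-i0] = left[i0], right[i0] = right[-i0]), giving dst = [c, b, a, b, c, d, c].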
  6026. // ggml_compute_forward_roll
  6027. static int64_t ggml_wrap_index(int64_t i, int64_t ne) {
  6028. if (i < 0) {
  6029. return i + ne;
  6030. } else if (i >= ne) {
  6031. return i - ne;
  6032. }
  6033. return i;
  6034. }
  6035. static void ggml_compute_forward_roll_f32(
  6036. const ggml_compute_params * params,
  6037. ggml_tensor * dst) {
  6038. const ggml_tensor * src0 = dst->src[0];
  6039. const float * src_data = (const float *) src0->data;
  6040. float * dst_data = (float *) dst->data;
  6041. GGML_TENSOR_UNARY_OP_LOCALS
  6042. const int s0 = ggml_get_op_params_i32(dst, 0);
  6043. const int s1 = ggml_get_op_params_i32(dst, 1);
  6044. const int s2 = ggml_get_op_params_i32(dst, 2);
  6045. const int s3 = ggml_get_op_params_i32(dst, 3);
  6046. const int64_t total = ne1 * ne2 * ne3;
6047. const int64_t per_thread = (total + params->nth - 1) / params->nth;
  6048. const int64_t start = params->ith * per_thread;
  6049. const int64_t end = std::min(start + per_thread, total);
  6050. for (int64_t i = start; i < end; ++i) {
  6051. const int64_t i1 = i % ne1;
  6052. const int64_t i2 = (i / ne1) % ne2;
  6053. const int64_t i3 = i / (ne2 * ne1);
  6054. float * dst_row = dst_data + (i3*nb3 + i2*nb2 + i1*nb1) / sizeof(float);
  6055. const int64_t i01 = ggml_wrap_index(i1 - s1, ne01);
  6056. const int64_t i02 = ggml_wrap_index(i2 - s2, ne02);
  6057. const int64_t i03 = ggml_wrap_index(i3 - s3, ne03);
  6058. const float * src_row = src_data + (i03*nb03 + i02*nb02 + i01*nb01) / sizeof(float);
  6059. const int64_t s = ggml_wrap_index(-s0, ne00);
  6060. const int64_t n = ne00 - s;
  6061. ggml_vec_cpy_f32(n, dst_row, src_row + s);
  6062. ggml_vec_cpy_f32(s, dst_row + n, src_row);
  6063. }
  6064. }
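// Worked example: rolling a row of ne00 = 5 elements by s0 = 2 reads from
// s = ggml_wrap_index(-2, 5) = 3, so [x0, x1, x2, x3, x4] becomes [x3, x4, x0, x1, x2],
// i.e. every element moves s0 positions to the right with wrap-around.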
  6065. void ggml_compute_forward_roll(
  6066. const ggml_compute_params * params,
  6067. ggml_tensor * dst) {
  6068. const ggml_tensor * src0 = dst->src[0];
  6069. switch (src0->type) {
  6070. case GGML_TYPE_F32:
  6071. {
  6072. ggml_compute_forward_roll_f32(params, dst);
  6073. } break;
  6074. default:
  6075. {
  6076. GGML_ABORT("fatal error");
  6077. }
  6078. }
  6079. }
  6080. // ggml_compute_forward_arange
  6081. static void ggml_compute_forward_arange_f32(
  6082. const ggml_compute_params * params,
  6083. ggml_tensor * dst) {
  6084. GGML_ASSERT(dst->nb[0] == sizeof(float));
  6085. const int ith = params->ith;
  6086. const int nth = params->nth;
  6087. const float start = ggml_get_op_params_f32(dst, 0);
  6088. const float stop = ggml_get_op_params_f32(dst, 1);
  6089. const float step = ggml_get_op_params_f32(dst, 2);
  6090. const int64_t steps = (int64_t) ceilf((stop - start) / step);
  6091. GGML_ASSERT(ggml_nelements(dst) == steps);
  6092. for (int64_t i = ith; i < steps; i+= nth) {
  6093. float value = start + step * i;
  6094. ((float *)dst->data)[i] = value;
  6095. }
  6096. }
  6097. void ggml_compute_forward_arange(
  6098. const ggml_compute_params * params,
  6099. ggml_tensor * dst) {
  6100. switch (dst->type) {
  6101. case GGML_TYPE_F32:
  6102. {
  6103. ggml_compute_forward_arange_f32(params, dst);
  6104. } break;
  6105. default:
  6106. {
  6107. GGML_ABORT("fatal error");
  6108. }
  6109. }
  6110. }
  6111. static void ggml_compute_forward_timestep_embedding_f32(
  6112. const ggml_compute_params * params,
  6113. ggml_tensor * dst) {
  6114. const ggml_tensor * src0 = dst->src[0];
  6115. GGML_ASSERT(src0->nb[0] == sizeof(float));
  6116. const int ith = params->ith;
  6117. const int nth = params->nth;
  6118. GGML_TENSOR_UNARY_OP_LOCALS
  6119. const int dim = ggml_get_op_params_i32(dst, 0);
  6120. const int max_period = ggml_get_op_params_i32(dst, 1);
  6121. int half = dim / 2;
  6122. for (int64_t i = 0; i < ne00; i++) {
  6123. float * embed_data = (float *)((char *) dst->data + i*nb1);
  6124. for (int64_t j = ith; j < half; j += nth) {
  6125. float timestep = ((float *)src0->data)[i];
  6126. float freq = (float)expf(-logf(max_period) * j / half);
  6127. float arg = timestep * freq;
  6128. embed_data[j] = cosf(arg);
  6129. embed_data[j + half] = sinf(arg);
  6130. }
  6131. if (dim % 2 != 0 && ith == 0) {
  6132. embed_data[2 * half] = 0.f;
  6133. }
  6134. }
  6135. }
  6136. void ggml_compute_forward_timestep_embedding(
  6137. const ggml_compute_params * params,
  6138. ggml_tensor * dst) {
  6139. const ggml_tensor * src0 = dst->src[0];
  6140. switch (src0->type) {
  6141. case GGML_TYPE_F32:
  6142. {
  6143. ggml_compute_forward_timestep_embedding_f32(params, dst);
  6144. } break;
  6145. default:
  6146. {
  6147. GGML_ABORT("fatal error");
  6148. }
  6149. }
  6150. }
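// Reference formula: for each timestep t and j in [0, dim/2), the embedding stores
// cos(t * max_period^(-j/(dim/2))) in slot j and the matching sin in slot j + dim/2
// (freq = expf(-logf(max_period) * j / half) above); an odd dim gets a trailing zero.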
  6151. // ggml_compute_forward_argsort
  6152. static void ggml_compute_forward_argsort_f32(
  6153. const ggml_compute_params * params,
  6154. ggml_tensor * dst) {
  6155. const ggml_tensor * src0 = dst->src[0];
  6156. GGML_TENSOR_UNARY_OP_LOCALS
  6157. GGML_ASSERT(nb0 == sizeof(float));
  6158. const int ith = params->ith;
  6159. const int nth = params->nth;
  6160. const int64_t nr = ggml_nrows(src0);
  6161. ggml_sort_order order = (ggml_sort_order) ggml_get_op_params_i32(dst, 0);
  6162. for (int64_t i = ith; i < nr; i += nth) {
  6163. const float * src_data = (float *)((char *) src0->data + i*nb01);
  6164. int32_t * dst_data = (int32_t *)((char *) dst->data + i*nb1);
  6165. for (int64_t j = 0; j < ne0; j++) {
  6166. dst_data[j] = j;
  6167. }
  6168. std::function<bool(int32_t, int32_t)> cmp;
6169. // note: std::function may heap-allocate for the captured lambda; ideally this should be avoided in the hot path
  6170. switch (order) {
  6171. case GGML_SORT_ORDER_ASC: cmp = [src_data](int32_t a, int32_t b) { return src_data[a] < src_data[b]; }; break;
  6172. case GGML_SORT_ORDER_DESC: cmp = [src_data](int32_t a, int32_t b) { return src_data[a] > src_data[b]; }; break;
  6173. default: GGML_ABORT("invalid sort order");
  6174. }
  6175. std::sort(dst_data, dst_data + ne0, cmp);
  6176. }
  6177. }
  6178. void ggml_compute_forward_argsort(
  6179. const ggml_compute_params * params,
  6180. ggml_tensor * dst) {
  6181. const ggml_tensor * src0 = dst->src[0];
  6182. switch (src0->type) {
  6183. case GGML_TYPE_F32:
  6184. {
  6185. ggml_compute_forward_argsort_f32(params, dst);
  6186. } break;
  6187. default:
  6188. {
  6189. GGML_ABORT("fatal error");
  6190. }
  6191. }
  6192. }
  6193. // ggml_compute_forward_flash_attn_ext
  6194. static void ggml_compute_forward_flash_attn_ext_f16_one_chunk(
  6195. const ggml_compute_params * params,
  6196. ggml_tensor * dst,
  6197. int ir0, int ir1) {
  6198. const ggml_tensor * q = dst->src[0];
  6199. const ggml_tensor * k = dst->src[1];
  6200. const ggml_tensor * v = dst->src[2];
  6201. const ggml_tensor * mask = dst->src[3];
  6202. const ggml_tensor * sinks = dst->src[4];
  6203. GGML_TENSOR_LOCALS(int64_t, neq, q, ne)
  6204. GGML_TENSOR_LOCALS(size_t, nbq, q, nb)
  6205. GGML_TENSOR_LOCALS(int64_t, nek, k, ne)
  6206. GGML_TENSOR_LOCALS(size_t, nbk, k, nb)
  6207. GGML_TENSOR_LOCALS(int64_t, nev, v, ne)
  6208. GGML_TENSOR_LOCALS(size_t, nbv, v, nb)
  6209. GGML_TENSOR_LOCALS(int64_t, ne, dst, ne)
  6210. GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
  6211. const int64_t DK = nek0;
  6212. const int64_t DV = nev0;
  6213. const int64_t N = neq1;
  6214. GGML_ASSERT(ne0 == DV);
  6215. GGML_ASSERT(ne2 == N);
  6216. // input tensor rows must be contiguous
  6217. GGML_ASSERT(nbq0 == ggml_type_size(q->type));
  6218. GGML_ASSERT(nbk0 == ggml_type_size(k->type));
  6219. GGML_ASSERT(nbv0 == ggml_type_size(v->type));
  6220. GGML_ASSERT(neq0 == DK);
  6221. GGML_ASSERT(nek0 == DK);
  6222. GGML_ASSERT(nev0 == DV);
  6223. GGML_ASSERT(neq1 == N);
  6224. // dst cannot be transposed or permuted
  6225. GGML_ASSERT(nb0 == sizeof(float));
  6226. GGML_ASSERT(nb0 <= nb1);
  6227. GGML_ASSERT(nb1 <= nb2);
  6228. GGML_ASSERT(nb2 <= nb3);
  6229. // broadcast factors
  6230. const int64_t rk2 = neq2/nek2;
  6231. const int64_t rk3 = neq3/nek3;
  6232. const int64_t rv2 = neq2/nev2;
  6233. const int64_t rv3 = neq3/nev3;
  6234. // parallelize by q rows using ggml_vec_dot_f32
  6235. float scale = 1.0f;
  6236. float max_bias = 0.0f;
  6237. float logit_softcap = 0.0f;
  6238. memcpy(&scale, (float *) dst->op_params + 0, sizeof(float));
  6239. memcpy(&max_bias, (float *) dst->op_params + 1, sizeof(float));
  6240. memcpy(&logit_softcap, (float *) dst->op_params + 2, sizeof(float));
  6241. if (logit_softcap != 0) {
  6242. scale /= logit_softcap;
  6243. }
  6244. const uint32_t n_head = neq2;
  6245. const uint32_t n_head_log2 = 1u << (uint32_t) floor(log2(n_head));
  6246. const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
  6247. const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
  6248. ggml_type const k_vec_dot_type = ggml_get_type_traits_cpu(k->type)->vec_dot_type;
  6249. ggml_from_float_t const q_to_vec_dot = ggml_get_type_traits_cpu(k_vec_dot_type)->from_float;
  6250. ggml_vec_dot_t const kq_vec_dot = ggml_get_type_traits_cpu(k->type)->vec_dot;
  6251. ggml_to_float_t const v_to_float = ggml_get_type_traits(v->type)->to_float;
  6252. GGML_ASSERT(( q_to_vec_dot) && "fattn: unsupported K-type");
  6253. GGML_ASSERT((v->type == GGML_TYPE_F32 || v_to_float ) && "fattn: unsupported V-type");
  6254. int ith = params->ith;
  6255. // loop over n_batch and n_head
  6256. for (int ir = ir0; ir < ir1; ++ir) {
  6257. // q indices
  6258. const int iq3 = ir/(neq2*neq1);
  6259. const int iq2 = (ir - iq3*neq2*neq1)/neq1;
  6260. const int iq1 = (ir - iq3*neq2*neq1 - iq2*neq1);
  6261. const uint32_t h = iq2; // head index
  6262. const float slope = (max_bias > 0.0f) ? h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2*(h - n_head_log2) + 1) : 1.0f;
  6263. float S = 0.0f; // sum
  6264. float M = -INFINITY; // maximum KQ value
  6265. float * VKQ32 = (float *) params->wdata + ith*(1*DK + 2*DV + CACHE_LINE_SIZE_F32); // FP32 VKQ accumulator
  6266. float * V32 = (VKQ32 + 1*DV); // (temporary) FP32 V buffer
  6267. ggml_fp16_t * VKQ16 = (ggml_fp16_t *) (VKQ32 + 1*DV); // (temporary) FP16 VKQ accumulator
  6268. ggml_fp16_t * Q_q = (ggml_fp16_t *) (VKQ32 + 2*DV); // (temporary) buffer for Q converted to quantized/FP16
  6269. if (v->type == GGML_TYPE_F16) {
  6270. memset(VKQ16, 0, DV*sizeof(ggml_fp16_t));
  6271. } else {
  6272. memset(VKQ32, 0, DV*sizeof(float));
  6273. }
  6274. const ggml_fp16_t * mp = mask ? (ggml_fp16_t *)((char *) mask->data + iq1*mask->nb[1] + (iq2%mask->ne[2])*mask->nb[2] + (iq3%mask->ne[3])*mask->nb[3]) : NULL;
  6275. // k indices
  6276. const int ik3 = iq3 / rk3;
  6277. const int ik2 = iq2 / rk2;
  6278. // v indices
  6279. const int iv3 = iq3 / rv3;
  6280. const int iv2 = iq2 / rv2;
  6281. const float * pq = (const float *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3));
  6282. q_to_vec_dot(pq, Q_q, DK);
  6283. // online softmax / attention
  6284. // loop over n_kv and n_head_kv
  6285. // ref: https://arxiv.org/pdf/2112.05682.pdf
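  // A rough sketch of the recurrence maintained below (illustration only, not extra code):
  //   M_new = max(M, s)                                  // running maximum of the scaled KQ logits
  //   S     = S*expf(M - M_new)   + expf(s - M_new)      // running softmax denominator
  //   VKQ   = VKQ*expf(M - M_new) + v*expf(s - M_new)    // running weighted sum of V rows
  // so after the loop VKQ/S equals softmax(KQ)*V for this query row without ever
  // materializing the whole KQ row.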
  6286. for (int64_t ic = 0; ic < nek1; ++ic) {
  6287. const float mv = mp ? slope*GGML_CPU_FP16_TO_FP32(mp[ic]) : 0.0f;
  6288. if (mv == -INFINITY) {
  6289. continue;
  6290. }
  6291. float s; // KQ value
  6292. const char * k_data = (const char *) k->data + ( ic*nbk1 + ik2*nbk2 + ik3*nbk3);
  6293. kq_vec_dot(DK, &s, 0, k_data, 0, Q_q, 0, 1);
  6294. s = s*scale; // scale KQ value
  6295. if (logit_softcap != 0.0f) {
  6296. s = logit_softcap*tanhf(s);
  6297. }
  6298. s += mv; // apply mask
  6299. const float Mold = M;
  6300. float ms = 1.0f; // upon new higher max val, scale VKQ and KQ sum with this value
  6301. float vs = 1.0f; // post-softmax KQ value, expf(s - M)
  6302. const char * v_data = ((const char *) v->data + (ic*nbv1 + iv2*nbv2 + iv3*nbv3));
  6303. if (v->type == GGML_TYPE_F16) {
  6304. if (s > M) {
  6305. // s is new maximum, ms < 1.0f, vs == expf(s - s) == 1.0f
  6306. M = s;
  6307. ms = expf(Mold - M);
  6308. // V = V*expf(Mold - M)
  6309. ggml_vec_scale_f16(DV, VKQ16, ms);
  6310. } else {
  6311. // no new maximum, ms == 1.0f, vs != 1.0f
  6312. vs = expf(s - M);
  6313. }
  6314. // V += v*expf(s - M)
  6315. ggml_vec_mad_f16(DV, VKQ16, (const ggml_fp16_t *) v_data, vs);
  6316. } else {
  6317. if (s > M) {
  6318. // s is new maximum, ms < 1.0f, vs == expf(s - s) == 1.0f
  6319. M = s;
  6320. ms = expf(Mold - M);
  6321. // V = V*expf(Mold - M)
  6322. ggml_vec_scale_f32(DV, VKQ32, ms);
  6323. } else {
  6324. // no new maximum, ms == 1.0f, vs != 1.0f
  6325. vs = expf(s - M);
  6326. }
  6327. // V += v*expf(s - M)
  6328. if (v_to_float) {
  6329. v_to_float(v_data, V32, DV);
  6330. ggml_vec_mad_f32(DV, VKQ32, V32, vs);
  6331. } else {
  6332. // V is F32
  6333. ggml_vec_mad_f32(DV, VKQ32, (const float *) v_data, vs);
  6334. }
  6335. }
  6336. S = S*ms + vs; // scale and increment sum with partial sum
  6337. }
  6338. if (v->type == GGML_TYPE_F16) {
  6339. for (int64_t d = 0; d < DV; ++d) {
  6340. VKQ32[d] = GGML_CPU_FP16_TO_FP32(VKQ16[d]);
  6341. }
  6342. }
  6343. // sinks
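  // A sink is an extra per-head logit that takes part in the softmax normalization but
  // contributes no value vector: below it only rescales VKQ32 (when it becomes the new
  // running maximum) and adds its own weight to the denominator S.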
  6344. if (sinks) {
  6345. const float s = ((float *)((char *) sinks->data))[h];
  6346. float ms = 1.0f;
  6347. float vs = 1.0f;
  6348. if (s > M) {
  6349. ms = expf(M - s);
  6350. ggml_vec_scale_f32(DV, VKQ32, ms);
  6351. } else {
  6352. vs = expf(s - M);
  6353. }
  6354. S = S*ms + vs;
  6355. }
  6356. // V /= S
  6357. const float S_inv = S == 0.0f ? 0.0f : 1.0f/S;
  6358. ggml_vec_scale_f32(DV, VKQ32, S_inv);
  6359. // dst indices
  6360. const int i1 = iq1;
  6361. const int i2 = iq2;
  6362. const int i3 = iq3;
  6363. // original
  6364. //memcpy((char *) dst->data + (i1*nb1 + i2*nb2 + i3*nb3), V, nev0*sizeof(float));
  6365. // permute(0, 2, 1, 3)
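  // Sketch of the layout below: for batch i3 the row for (query i1, head i2) lands at flat
  // row i1*ne1 + i2, so the stored result has shape (DV, n_head, N, batch), i.e. a
  // permute(0, 2, 1, 3) of the (DV, N, n_head, batch) order in which the rows are computed.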
  6366. memcpy((char *) dst->data + (i3*ne2*ne1 + i2 + i1*ne1)*nb1, VKQ32, nb1);
  6367. }
  6368. }
  6369. static void ggml_compute_forward_flash_attn_ext_f16(
  6370. const ggml_compute_params * params,
  6371. ggml_tensor * dst) {
  6372. const ggml_tensor * q = dst->src[0];
  6373. const ggml_tensor * k = dst->src[1];
  6374. const ggml_tensor * v = dst->src[2];
  6375. GGML_TENSOR_LOCALS(int64_t, neq, q, ne)
  6376. GGML_TENSOR_LOCALS(size_t, nbq, q, nb)
  6377. GGML_TENSOR_LOCALS(int64_t, nek, k, ne)
  6378. GGML_TENSOR_LOCALS(size_t, nbk, k, nb)
  6379. GGML_TENSOR_LOCALS(int64_t, nev, v, ne)
  6380. GGML_TENSOR_LOCALS(size_t, nbv, v, nb)
  6381. GGML_TENSOR_LOCALS(int64_t, ne, dst, ne)
  6382. GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
  6383. const int64_t DK = nek0;
  6384. const int64_t DV = nev0;
  6385. const int64_t N = neq1;
  6386. GGML_ASSERT(ne0 == DV);
  6387. GGML_ASSERT(ne2 == N);
  6388. // input tensor rows must be contiguous
  6389. GGML_ASSERT(nbq0 == ggml_type_size(q->type));
  6390. GGML_ASSERT(nbk0 == ggml_type_size(k->type));
  6391. GGML_ASSERT(nbv0 == ggml_type_size(v->type));
  6392. GGML_ASSERT(neq0 == DK);
  6393. GGML_ASSERT(nek0 == DK);
  6394. GGML_ASSERT(nev0 == DV);
  6395. GGML_ASSERT(neq1 == N);
  6396. // dst cannot be transposed or permuted
  6397. GGML_ASSERT(nb0 == sizeof(float));
  6398. GGML_ASSERT(nb0 <= nb1);
  6399. GGML_ASSERT(nb1 <= nb2);
  6400. GGML_ASSERT(nb2 <= nb3);
  6401. // parallelize by q rows using ggml_vec_dot_f32
  6402. // total rows in q
  6403. const int64_t nr = neq1*neq2*neq3;
  6404. // rows per thread
  6405. const int ith = params->ith;
  6406. const int nth = params->nth;
  6407. // disable for NUMA
  6408. const bool disable_chunking = ggml_is_numa();
  6409. // 4x chunks per thread
  6410. int nth_scaled = nth * 4;
  6411. int64_t chunk_size = (nr + nth_scaled - 1) / nth_scaled;
  6412. int64_t nchunk = (nr + chunk_size - 1) / chunk_size;
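  // Worked example with illustrative numbers: nr = 64 q rows and nth = 4 threads give
  // nth_scaled = 16, chunk_size = (64 + 15)/16 = 4 and nchunk = 16, i.e. about 4 chunks
  // per thread, which helps balance the load when rows have uneven cost.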
  6413. if (nth == 1 || nchunk < nth || disable_chunking) {
  6414. nchunk = nth;
  6415. }
  6416. if (ith == 0) {
  6417. // Every thread starts at chunk ith, so the first unprocessed chunk is nth. This saves a bit of coordination right at the start.
  6418. ggml_threadpool_chunk_set(params->threadpool, nth);
  6419. }
  6420. ggml_barrier(params->threadpool);
  6421. // The number of rows in each chunk
  6422. const int64_t dr = (nr + nchunk - 1) / nchunk;
  6423. // The first chunk comes from our thread_id, the rest will get auto-assigned.
  6424. int current_chunk = ith;
  6425. while (current_chunk < nchunk) {
  6426. const int64_t ir0 = dr * current_chunk;
  6427. const int64_t ir1 = MIN(ir0 + dr, nr);
  6428. ggml_compute_forward_flash_attn_ext_f16_one_chunk(params, dst, ir0, ir1);
  6429. current_chunk = ggml_threadpool_chunk_add(params->threadpool, 1);
  6430. }
  6431. }
  6432. void ggml_compute_forward_flash_attn_ext(
  6433. const ggml_compute_params * params,
  6434. ggml_tensor * dst) {
  6435. switch (dst->op_params[3]) {
  6436. case GGML_PREC_DEFAULT:
  6437. case GGML_PREC_F32:
  6438. {
  6439. // uses F32 accumulators
  6440. ggml_compute_forward_flash_attn_ext_f16(params, dst);
  6441. } break;
  6442. default:
  6443. {
  6444. GGML_ABORT("fatal error");
  6445. }
  6446. }
  6447. }
  6448. // ggml_compute_forward_flash_attn_back
  6449. static void ggml_compute_forward_flash_attn_back_f32(
  6450. const ggml_compute_params * params,
  6451. const bool masked,
  6452. ggml_tensor * dst) {
  6453. const ggml_tensor * q = dst->src[0];
  6454. const ggml_tensor * k = dst->src[1];
  6455. const ggml_tensor * v = dst->src[2];
  6456. const ggml_tensor * d = dst->src[3];
  6457. GGML_TENSOR_LOCALS(int64_t, neq, q, ne)
  6458. GGML_TENSOR_LOCALS(size_t, nbq, q, nb)
  6459. GGML_TENSOR_LOCALS(int64_t, nek, k, ne)
  6460. GGML_TENSOR_LOCALS(size_t, nbk, k, nb)
  6461. GGML_TENSOR_LOCALS(int64_t, nev, v, ne)
  6462. GGML_TENSOR_LOCALS(size_t, nbv, v, nb)
  6463. GGML_TENSOR_LOCALS(int64_t, ned, d, ne)
  6464. GGML_TENSOR_LOCALS(size_t, nbd, d, nb)
  6465. GGML_TENSOR_LOCALS(int64_t, ne, dst, ne)
  6466. GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
  6467. const int ith = params->ith;
  6468. const int nth = params->nth;
  6469. const int64_t D = neq0;
  6470. const int64_t N = neq1;
  6471. const int64_t P = nek1 - N;
  6472. const int64_t M = P + N;
  6473. const int Mup = ggml_up(M, GGML_SOFT_MAX_UNROLL);
  6474. const int mxDM = MAX(D, Mup);
  6475. // GGML_ASSERT(ne0 == D);
  6476. // GGML_ASSERT(ne1 == N);
  6477. GGML_ASSERT(P >= 0);
  6478. GGML_ASSERT(nbq0 == sizeof(float));
  6479. GGML_ASSERT(nbk0 == sizeof(float));
  6480. GGML_ASSERT(nbv0 == sizeof(float));
  6481. GGML_ASSERT(neq0 == D);
  6482. GGML_ASSERT(nek0 == D);
  6483. GGML_ASSERT(nev1 == D);
  6484. GGML_ASSERT(ned0 == D);
  6485. GGML_ASSERT(neq1 == N);
  6486. GGML_ASSERT(nek1 == N + P);
  6487. GGML_ASSERT(nev1 == D);
  6488. GGML_ASSERT(ned1 == N);
  6489. // dst cannot be transposed or permuted
  6490. GGML_ASSERT(nb0 == sizeof(float));
  6491. GGML_ASSERT(nb0 <= nb1);
  6492. GGML_ASSERT(nb1 <= nb2);
  6493. GGML_ASSERT(nb2 <= nb3);
  6494. if (ith == 0) {
  6495. memset(dst->data, 0, nb0*ne0*ne1*ne2*ne3);
  6496. }
  6497. ggml_barrier(params->threadpool);
  6498. const int64_t elem_q = ggml_nelements(q);
  6499. const int64_t elem_k = ggml_nelements(k);
  6500. ggml_type result_type = dst->type;
  6501. GGML_ASSERT(ggml_blck_size(result_type) == 1);
  6502. const size_t tsize = ggml_type_size(result_type);
  6503. const size_t offs_q = 0;
  6504. const size_t offs_k = offs_q + GGML_PAD(elem_q * tsize, GGML_MEM_ALIGN);
  6505. const size_t offs_v = offs_k + GGML_PAD(elem_k * tsize, GGML_MEM_ALIGN);
  6506. void * grad_q = (char *) dst->data;
  6507. void * grad_k = (char *) dst->data + offs_k;
  6508. void * grad_v = (char *) dst->data + offs_v;
  6509. const size_t nbgq1 = nb0*neq0;
  6510. const size_t nbgq2 = nb0*neq0*neq1;
  6511. const size_t nbgq3 = nb0*neq0*neq1*neq2;
  6512. const size_t nbgk1 = nb0*nek0;
  6513. const size_t nbgk2 = nb0*nek0*nek1;
  6514. const size_t nbgk3 = nb0*nek0*nek1*neq2;
  6515. const size_t nbgv1 = nb0*nev0;
  6516. const size_t nbgv2 = nb0*nev0*nev1;
  6517. const size_t nbgv3 = nb0*nev0*nev1*neq2;
  6518. // parallelize by k rows using ggml_vec_dot_f32
  6519. // total rows in k
  6520. const int nr = nek2*nek3;
  6521. // rows per thread
  6522. const int dr = (nr + nth - 1)/nth;
  6523. // row range for this thread
  6524. const int ir0 = dr*ith;
  6525. const int ir1 = MIN(ir0 + dr, nr);
  6526. const float scale = 1.0f/sqrtf(D);
  6527. //printf("P=%d N=%d D=%d ir0=%d ir1=%d scale = %f\n", P, N, D, ir0, ir1, scale);
  6528. // how often k2 (and v2) is repeated in q2
  6529. int nrep = neq2/nek2;
  6530. for (int ir = ir0; ir < ir1; ++ir) {
  6531. // q indices
  6532. const int ik3 = ir/(nek2);
  6533. const int ik2 = ir - ik3*nek2;
  6534. const int iq3 = ik3;
  6535. const int id3 = ik3;
  6536. const int iv3 = ik3;
  6537. const int iv2 = ik2;
  6538. for (int irep = 0; irep < nrep; ++irep) {
  6539. const int iq2 = ik2 + irep*nek2;
  6540. const int id2 = iq2;
  6541. // (ik2 + irep*nek2) % nek2 == ik2
  6542. for (int iq1 = 0; iq1 < neq1; ++iq1) {
  6543. const int id1 = iq1;
  6544. // not sure about CACHE_LINE_SIZE_F32...
  6545. // - maybe it should not be multiplied by 2, and excluded from the 1*(..) offset used for SM?
  6546. float * S = (float *) params->wdata + ith*2*(mxDM + CACHE_LINE_SIZE_F32) + 0*(mxDM+CACHE_LINE_SIZE_F32);
  6547. float * SM = (float *) params->wdata + ith*2*(mxDM + CACHE_LINE_SIZE_F32) + 1*(mxDM+CACHE_LINE_SIZE_F32);
  6548. for (int i = M; i < Mup; ++i) {
  6549. S[i] = -INFINITY;
  6550. }
  6551. const int64_t masked_begin = masked ? (P + iq1 + 1) : M;
  6552. for (int64_t ic = 0; ic < masked_begin; ++ic) {
  6553. // k indices
  6554. const int ik1 = ic;
  6555. // S indices
  6556. const int i1 = ik1;
  6557. ggml_vec_dot_f32(neq0,
  6558. S + i1, 0,
  6559. (float *) ((char *) k->data + (ik1*nbk1 + ik2*nbk2 + ik3*nbk3)), 0,
  6560. (float *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3)), 0, 1);
  6561. }
  6562. // scale
  6563. ggml_vec_scale_f32(masked_begin, S, scale);
  6564. for (int64_t i = masked_begin; i < M; i++) {
  6565. S[i] = -INFINITY;
  6566. }
  6567. // softmax
  6568. // exclude known -INF S[..] values from max and loop
  6569. // don't forget to set their SM values to zero
  6570. {
  6571. float max = -INFINITY;
  6572. ggml_vec_max_f32(masked_begin, &max, S);
  6573. ggml_float sum = 0.0;
  6574. {
  6575. #ifdef GGML_SOFT_MAX_ACCELERATE
  6576. max = -max;
  6577. vDSP_vsadd(SM, 1, &max, SM, 1, Mup);
  6578. vvexpf(SM, SM, &Mup);
  6579. ggml_vec_sum_f32(Mup, &sum, SM);
  6580. #else
  6581. sum = ggml_vec_soft_max_f32(Mup, SM, S, max);
  6582. #endif
  6583. }
  6584. assert(sum > 0.0);
  6585. sum = 1.0/sum;
  6586. ggml_vec_scale_f32(masked_begin, SM, sum);
  6587. }
  6588. // step-by-step explanation
  6589. {
  6590. // forward-process shape grads from backward process
  6591. // parallel_for ik2,ik3:
  6592. // for irep:
  6593. // iq2 = ik2 + irep*nek2
  6594. // k[:D,:M,:,:] [D,M,:,:] grad[k][:D,:M,ik2,ik3] += grad[kcur]
  6595. // q[:D,:N,:,:] [D,N,:,:] grad[q][:D,iq1,iq2,iq3] += grad[qcur]
  6596. // v[:M,:D,:,:] [M,D,:,:] grad[v][:M,:D,iv2,iv3] += grad[vcur]
  6597. // for iq1:
  6598. // kcur = k[:D,:M,ik2,ik3] [D,M,1,1] grad[kcur] = grad[S1].T @ qcur
  6599. // qcur = q[:D,iq1,iq2,iq3] [D,1,1,1] grad[qcur] = grad[S1] @ kcur
  6600. // vcur = v[:M,:D,iv2,iv3] [M,D,1,1] grad[vcur] = grad[S5].T @ S4
  6601. // S0 = -Inf [D,1,1,1]
  6602. // ~S1[i] = dot(kcur[:D,i], qcur)
  6603. // S1 = qcur @ kcur.T [M,1,1,1] grad[S1] = grad[S2] * scale
  6604. // S2 = S1 * scale [M,1,1,1] grad[S2] = diag_mask_zero(grad[S3], P)
  6605. // S3 = diag_mask_inf(S2, P) [M,1,1,1] grad[S3] = S4 * (grad[S4] - dot(S4, grad[S4]))
  6606. // S4 = softmax(S3) [M,1,1,1] grad[S4] = grad[S5] @ vcur
  6607. // ~S5[i] = dot(vcur[:,i], S4)
  6608. // S5 = S4 @ vcur.T [D,1,1,1] grad[S5] = d[:D,id1,id2,id3]
  6609. // ~dst[i,iq1,iq2,iq3] = S5[i] ^
  6610. // dst[:D,iq1,iq2,iq3] = S5 | grad[dst[:D,iq1,iq2,iq3]] = d[:D,id1,id2,id3]
  6611. // dst backward-/ grad[dst] = d
  6612. //
  6613. // output gradients with their dependencies:
  6614. //
  6615. // grad[kcur] = grad[S1].T @ qcur
  6616. // grad[S1] = diag_mask_zero(grad[S3], P) * scale
  6617. // grad[S3] = S4 * (grad[S4] - dot(S4, grad[S4]))
  6618. // grad[S4] = grad[S5] @ vcur
  6619. // grad[S4] = d[:D,id1,id2,id3] @ vcur
  6620. // grad[qcur] = grad[S1] @ kcur
  6621. // grad[vcur] = grad[S5].T @ S4
  6622. // grad[vcur] = d[:D,id1,id2,id3].T @ S4
  6623. //
  6624. // in post-order:
  6625. //
  6626. // S1 = qcur @ kcur.T
  6627. // S2 = S1 * scale
  6628. // S3 = diag_mask_inf(S2, P)
  6629. // S4 = softmax(S3)
  6630. // grad[S4] = d[:D,id1,id2,id3] @ vcur
  6631. // grad[S3] = S4 * (grad[S4] - dot(S4, grad[S4]))
  6632. // grad[S1] = diag_mask_zero(grad[S3], P) * scale
  6633. // grad[qcur] = grad[S1] @ kcur
  6634. // grad[kcur] = grad[S1].T @ qcur
  6635. // grad[vcur] = d[:D,id1,id2,id3].T @ S4
  6636. //
  6637. // using less variables (SM=S4):
  6638. //
  6639. // S = diag_mask_inf(qcur @ kcur.T * scale, P)
  6640. // SM = softmax(S)
  6641. // S = d[:D,iq1,iq2,iq3] @ vcur
  6642. // dot_SM_gradSM = dot(SM, S)
  6643. // S = SM * (S - dot(SM, S))
  6644. // S = diag_mask_zero(S, P) * scale
  6645. //
  6646. // grad[q][:D,iq1,iq2,iq3] += S @ kcur
  6647. // grad[k][:D,:M,ik2,ik3] += S.T @ qcur
  6648. // grad[v][:M,:D,iv2,iv3] += d[:D,id1,id2,id3].T @ SM
  6649. }
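  // Side note on the softmax gradient used above (standard result, stated for reference):
  // with S4 = softmax(S3) the Jacobian is dS4_i/dS3_j = S4_i*(delta_ij - S4_j), hence
  //   grad[S3]_i = S4_i*(grad[S4]_i - dot(S4, grad[S4]))
  // which is exactly the "S = SM * (S - dot(SM, S))" step computed below.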
  6650. // S = gradSM = d[:D,id1,id2,id3] @ vcur[:,:,iv2,iv3]
  6651. // S = d[:D,id1,id2,id3] @ vcur[:,:,iv2,iv3]
  6652. // for ic:
  6653. // S[:M] += vcur[:M,ic,iv2,iv3] * d[ic,id1,id2,id3]
  6654. // exclude known future zero S[..] values from operation
  6655. ggml_vec_set_f32(masked_begin, S, 0);
  6656. for (int64_t ic = 0; ic < D; ++ic) {
  6657. ggml_vec_mad_f32(masked_begin,
  6658. S,
  6659. (float *) ((char *) v->data + ( ic*nbv1 + iv2*nbv2 + iv3*nbv3)),
  6660. *(float *) ((char *) d->data + (ic*nbd0 + id1*nbd1 + id2*nbd2 + id3*nbd3)));
  6661. }
  6662. // S = SM * (S - dot(SM, S))
  6663. float dot_SM_gradSM = 0;
  6664. ggml_vec_dot_f32 (masked_begin, &dot_SM_gradSM, 0, SM, 0, S, 0, 1);
  6665. ggml_vec_acc1_f32(M, S, -dot_SM_gradSM);
  6666. ggml_vec_mul_f32 (masked_begin, S, S, SM);
  6667. // S = diag_mask_zero(S, P) * scale
  6668. // already done by above ggml_vec_set_f32
  6669. // exclude known zero S[..] values from operation
  6670. ggml_vec_scale_f32(masked_begin, S, scale);
  6671. // S shape [M,1]
  6672. // SM shape [M,1]
  6673. // kcur shape [D,M]
  6674. // qcur shape [D,1]
  6675. // vcur shape [M,D]
  6676. // grad[q][:D,iq1,iq2,iq3] += S @ kcur
  6677. // grad[q][:D,iq1,iq2,iq3] += shape[M,1] @ shape[D,M]
  6678. // for ic:
  6679. // grad[q][:D,iq1,iq2,iq3] += S[ic] * kcur[:D,ic,ik2,ik3]
  6680. // exclude known zero S[..] values from loop
  6681. for (int64_t ic = 0; ic < masked_begin; ++ic) {
  6682. ggml_vec_mad_f32(D,
  6683. (float *) ((char *) grad_q + (iq1*nbgq1 + iq2*nbgq2 + iq3*nbgq3)),
  6684. (float *) ((char *) k->data + (ic*nbk1 + ik2*nbk2 + ik3*nbk3)),
  6685. S[ic]);
  6686. }
  6687. // grad[k][:D,:M,iq2,iq3] += S.T @ qcur
  6688. // for ic:
  6689. // grad[k][:D,ic,iq2,iq3] += S.T[0,ic] * qcur[:D,0]
  6690. // grad[k][:D,ic,iq2,iq3] += S[ic] * qcur[:D,0]
  6691. // exclude known zero S[..] values from loop
  6692. for (int64_t ic = 0; ic < masked_begin; ++ic) {
  6693. ggml_vec_mad_f32(D,
  6694. (float *) ((char *) grad_k + (ic*nbgk1 + ik2*nbgk2 + ik3*nbgk3)),
  6695. (float *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3)),
  6696. S[ic]);
  6697. }
  6698. // grad[v][:M,:D,iv2,iv3] += d[:D,id1,id2,id3].T @ SM
  6699. // for ic:
  6700. // grad[v][:M,ic,iv2,iv3] += d[:D,id1,id2,id3].T[0,ic] * SM[:M]
  6701. // grad[v][:M,ic,iv2,iv3] += d[ic,id1,id2,id3] * SM[:M]
  6702. // exclude known zero SM[..] values from mad
  6703. for (int64_t ic = 0; ic < D; ++ic) {
  6704. ggml_vec_mad_f32(masked_begin,
  6705. (float *) ((char *) grad_v + ( ic*nbgv1 + iv2*nbgv2 + iv3*nbgv3)),
  6706. SM,
  6707. *(float *) ((char *) d->data + (ic*nbd0 + id1*nbd1 + id2*nbd2 + id3*nbd3)));
  6708. }
  6709. }
  6710. }
  6711. }
  6712. }
  6713. void ggml_compute_forward_flash_attn_back(
  6714. const ggml_compute_params * params,
  6715. const bool masked,
  6716. ggml_tensor * dst) {
  6717. const ggml_tensor * q = dst->src[0];
  6718. switch (q->type) {
  6719. case GGML_TYPE_F32:
  6720. {
  6721. ggml_compute_forward_flash_attn_back_f32(params, masked, dst);
  6722. } break;
  6723. default:
  6724. {
  6725. GGML_ABORT("fatal error");
  6726. }
  6727. }
  6728. }
  6729. // ggml_compute_forward_ssm_conv
  6730. static void ggml_compute_forward_ssm_conv_f32(
  6731. const ggml_compute_params * params,
  6732. ggml_tensor * dst) {
  6733. const ggml_tensor * src0 = dst->src[0]; // conv_x
  6734. const ggml_tensor * src1 = dst->src[1]; // conv1d.weight
  6735. const int ith = params->ith;
  6736. const int nth = params->nth;
  6737. const int nc = src1->ne[0]; // d_conv
  6738. const int ncs = src0->ne[0]; // d_conv - 1 + n_t
  6739. const int nr = src0->ne[1]; // d_inner
  6740. const int n_t = dst->ne[1]; // tokens per sequence
  6741. const int n_s = dst->ne[2]; // number of sequences in the batch
  6742. GGML_ASSERT( dst->ne[0] == nr);
  6743. GGML_ASSERT(src0->nb[0] == sizeof(float));
  6744. GGML_ASSERT(src1->nb[0] == sizeof(float));
  6745. GGML_ASSERT(src0->nb[1] == src0->ne[0]*sizeof(float));
  6746. // rows per thread
  6747. const int dr = (nr + nth - 1)/nth;
  6748. // row range for this thread
  6749. const int ir0 = dr*ith;
  6750. const int ir1 = MIN(ir0 + dr, nr);
  6751. const int ir = ir1 - ir0;
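  // What the loops below compute, in equation form (sketch): for each inner channel i1
  // and token i2,
  //   dst[i1, i2, i3] = sum_{i0 < d_conv} conv_x[i2 + i0, i1, i3] * weight[i0, i1]
  // i.e. a depthwise causal convolution over a sliding window of d_conv columns of conv_x.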
  6752. for (int i3 = 0; i3 < n_s; ++i3) {
  6753. for (int i2 = 0; i2 < n_t; ++i2) {
  6754. // {d_conv - 1 + n_t, d_inner, n_seqs}
  6755. // sliding window
  6756. const float * s = (const float *) ((const char *) src0->data + ir0*(src0->nb[1]) + i2*(src0->nb[0]) + i3*(src0->nb[2])); // {d_conv, d_inner, n_s}
  6757. const float * c = (const float *) ((const char *) src1->data + ir0*(src1->nb[1])); // {d_conv, d_inner}
  6758. float * x = (float *) ((char *) dst->data + ir0*(dst->nb[0]) + i2*(dst->nb[1]) + i3*(dst->nb[2])); // {d_inner, n_t, n_s}
  6759. // TODO: transpose the output for smaller strides for big batches?
  6760. // d_inner
  6761. for (int i1 = 0; i1 < ir; ++i1) {
  6762. // rowwise dot product
  6763. // NOTE: not using ggml_vec_dot_f32, because its sum is in double precision
  6764. float sumf = 0.0f;
  6765. // d_conv
  6766. for (int i0 = 0; i0 < nc; ++i0) {
  6767. sumf += s[i0 + i1*ncs] * c[i0 + i1*nc];
  6768. }
  6769. x[i1] = sumf;
  6770. }
  6771. }
  6772. }
  6773. }
  6774. void ggml_compute_forward_ssm_conv(
  6775. const ggml_compute_params * params,
  6776. ggml_tensor * dst) {
  6777. switch (dst->src[0]->type) {
  6778. case GGML_TYPE_F32:
  6779. {
  6780. ggml_compute_forward_ssm_conv_f32(params, dst);
  6781. } break;
  6782. default:
  6783. {
  6784. GGML_ABORT("fatal error");
  6785. }
  6786. }
  6787. }
  6788. // ggml_compute_forward_ssm_scan
  6789. static void ggml_compute_forward_ssm_scan_f32(
  6790. const ggml_compute_params * params,
  6791. ggml_tensor * dst) {
  6792. const ggml_tensor * src0 = dst->src[0]; // s {d_state, dim, n_head, n_seqs+}
  6793. const ggml_tensor * src1 = dst->src[1]; // x {dim, n_head, n_seq_tokens, n_seqs}
  6794. const ggml_tensor * src2 = dst->src[2]; // dt {n_head, n_seq_tokens, n_seqs}
  6795. const ggml_tensor * src3 = dst->src[3]; // A {d_state, n_head} or {1, n_head}
  6796. const ggml_tensor * src4 = dst->src[4]; // B {d_state, n_group, n_seq_tokens, n_seqs}
  6797. const ggml_tensor * src5 = dst->src[5]; // C {d_state, n_group, n_seq_tokens, n_seqs}
  6798. const ggml_tensor * src6 = dst->src[6]; // ids {n_seqs}
  6799. const int ith = params->ith;
  6800. const int nth = params->nth;
  6801. const int64_t nc = src0->ne[0]; // d_state
  6802. const int64_t nr = src0->ne[1]; // dim
  6803. const int64_t nh = src1->ne[1]; // n_head
  6804. const int64_t ng = src4->ne[1]; // n_group
  6805. const int64_t nt = src1->ne[2]; // number of tokens per sequence
  6806. const int64_t ns = src1->ne[3]; // number of sequences in the batch
  6807. // can't use ggml_nbytes because src1 is not necessarily contiguous
  6808. const int64_t s_off = ggml_nelements(src1) * ggml_element_size(src1);
  6809. GGML_ASSERT(ggml_nelements(src1) + nc*nr*nh*ns == ggml_nelements(dst));
  6810. GGML_ASSERT(src0->nb[0] == sizeof(float));
  6811. GGML_ASSERT(src1->nb[0] == sizeof(float));
  6812. GGML_ASSERT(src2->nb[0] == sizeof(float));
  6813. GGML_ASSERT(src3->nb[0] == sizeof(float));
  6814. GGML_ASSERT(src4->nb[0] == sizeof(float));
  6815. GGML_ASSERT(src5->nb[0] == sizeof(float));
  6816. GGML_ASSERT(src6->nb[0] == sizeof(int32_t));
  6817. GGML_ASSERT(nh % ng == 0);
  6818. // heads per thread
  6819. const int dh = (nh + nth - 1)/nth;
  6820. // head range for this thread
  6821. const int ih0 = dh*ith;
  6822. const int ih1 = MIN(ih0 + dh, nh);
  6823. const int32_t * ids = (const int32_t *) src6->data;
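  // Recurrence implemented below (sketch), carried token by token within each sequence:
  //   dt'   = softplus(dt[h])
  //   state = state*exp(dt'*A) + B*(x*dt')   (Mamba-2: one scalar A per head,
  //                                           Mamba-1: one A value per state element)
  //   y     = dot(state, C)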
  6824. for (int i3 = 0; i3 < ns; ++i3) {
  6825. const float * s0 = (const float *) ((const char *) src0->data + ids[i3]*(src0->nb[3])); // {d_state, dim, nh, ns}
  6826. float * s = ( float *) (( char *) dst->data + i3*(src0->nb[3]) + s_off); // {d_state, dim, nh, ns}
  6827. for (int i2 = 0; i2 < nt; ++i2) {
  6828. const float * x = (const float *) ((const char *) src1->data + i2*(src1->nb[2]) + i3*(src1->nb[3])); // {dim, nh, nt, ns}
  6829. const float * dt = (const float *) ((const char *) src2->data + i2*(src2->nb[1]) + i3*(src2->nb[2])); // {nh, nt, ns}
  6830. const float * A = (const float *) ((const char *) src3->data); // {d_state, nh} or {1, nh}
  6831. const float * B = (const float *) ((const char *) src4->data + i2*(src4->nb[2]) + i3*(src4->nb[3])); // {d_state, ng, nt, ns}
  6832. const float * C = (const float *) ((const char *) src5->data + i2*(src5->nb[2]) + i3*(src5->nb[3])); // {d_state, ng, nt, ns}
  6833. float * y = ( float *) (( char *) dst->data + i2*(nh*nr*sizeof(float)) + i3*(nt*nh*nr*sizeof(float))); // {dim, nh, nt, ns}
  6834. if (src3->ne[0] == 1) {
  6835. // Mamba-2 has a scalar decay factor per head; dA can be outside the state-wise loop
  6836. // n_head
  6837. for (int h = ih0; h < ih1; ++h) {
  6838. // ref: https://github.com/state-spaces/mamba/blob/62db608da60f6fc790b8ed9f4b3225e95ca15fde/mamba_ssm/ops/triton/softplus.py#L16
  6839. const float dt_soft_plus = ggml_softplus(dt[h]);
  6840. const float dA = expf(dt_soft_plus * A[h]);
  6841. const int g = h / (nh / ng); // repeat_interleave
  6842. // dim
  6843. for (int i1 = 0; i1 < nr; ++i1) {
  6844. const int ii = i1 + h*nr;
  6845. const float x_dt = x[ii] * dt_soft_plus;
  6846. float sumf = 0.0f;
  6847. #if defined(GGML_SIMD)
  6848. #if defined(__ARM_FEATURE_SVE)
  6849. const int ggml_f32_epr = svcntw();
  6850. const int ggml_f32_step = 1 * ggml_f32_epr;
  6851. const int np = (nc & ~(ggml_f32_step - 1));
  6852. GGML_F32_VEC sum = GGML_F32_VEC_ZERO;
  6853. GGML_F32_VEC adA = GGML_F32_VEC_SET1(dA);
  6854. GGML_F32_VEC axdt = GGML_F32_VEC_SET1(x_dt);
  6855. for (int i = 0; i < np; i += ggml_f32_step) {
  6856. // TODO: maybe unroll more?
  6857. for (int j = 0; j < 1; j++) {
  6858. GGML_F32_VEC t0 = GGML_F32_VEC_LOAD(s0 + i + j*ggml_f32_epr + ii*nc);
  6859. GGML_F32_VEC t1 = GGML_F32_VEC_LOAD(B + i + j*ggml_f32_epr + g*nc);
  6860. GGML_F32_VEC t2 = GGML_F32_VEC_LOAD(C + i + j*ggml_f32_epr + g*nc);
  6861. t0 = GGML_F32_VEC_MUL(t0, adA);
  6862. t1 = GGML_F32_VEC_MUL(t1, axdt);
  6863. t0 = GGML_F32_VEC_ADD(t0, t1);
  6864. sum = GGML_F32_VEC_FMA(sum, t0, t2);
  6865. GGML_F32_VEC_STORE(s + i + j*ggml_f32_epr + ii*nc, t0);
  6866. }
  6867. }
  6868. sumf = GGML_F32xt_REDUCE_ONE(sum);
  6869. #elif defined(__riscv_v_intrinsic)
  6870. // TODO: RVV implementation
  6871. const int np = 0;
  6872. #else
  6873. const int np = (nc & ~(GGML_F32_STEP - 1));
  6874. GGML_F32_VEC sum[GGML_F32_ARR] = { GGML_F32_VEC_ZERO };
  6875. GGML_F32_VEC adA = GGML_F32_VEC_SET1(dA);
  6876. GGML_F32_VEC axdt = GGML_F32_VEC_SET1(x_dt);
  6877. GGML_F32_VEC ax[GGML_F32_ARR];
  6878. GGML_F32_VEC ay[GGML_F32_ARR];
  6879. GGML_F32_VEC az[GGML_F32_ARR];
  6880. for (int i = 0; i < np; i += GGML_F32_STEP) {
  6881. for (int j = 0; j < GGML_F32_ARR; j++) {
  6882. ax[j] = GGML_F32_VEC_LOAD(s0 + i + j*GGML_F32_EPR + ii*nc);
  6883. ay[j] = GGML_F32_VEC_LOAD(B + i + j*GGML_F32_EPR + g*nc);
  6884. az[j] = GGML_F32_VEC_LOAD(C + i + j*GGML_F32_EPR + g*nc);
  6885. ax[j] = GGML_F32_VEC_MUL(ax[j], adA);
  6886. ay[j] = GGML_F32_VEC_MUL(ay[j], axdt);
  6887. ax[j] = GGML_F32_VEC_ADD(ax[j], ay[j]);
  6888. sum[j] = GGML_F32_VEC_FMA(sum[j], ax[j], az[j]);
  6889. GGML_F32_VEC_STORE(s + i + j*GGML_F32_EPR + ii*nc, ax[j]);
  6890. }
  6891. }
  6892. // reduce sum0..sum3 to sum0
  6893. GGML_F32_VEC_REDUCE(sumf, sum);
  6894. #endif
  6895. #else
  6896. const int np = 0;
  6897. #endif
  6898. // d_state
  6899. for (int i0 = np; i0 < nc; ++i0) {
  6900. const int i = i0 + ii*nc;
  6901. const int ig = i0 + g*nc;
  6902. // state = prev_state * dA + dB * x
  6903. const float state = (s0[i] * dA) + (B[ig] * x_dt);
  6904. // y = rowwise_dotprod(state, C)
  6905. sumf += state * C[ig];
  6906. s[i] = state;
  6907. }
  6908. y[ii] = sumf;
  6909. }
  6910. }
  6911. } else {
  6912. // Mamba-1 has an element-wise decay factor for the states
  6913. // n_head
  6914. for (int h = ih0; h < ih1; ++h) {
  6915. // ref: https://github.com/state-spaces/mamba/blob/62db608da60f6fc790b8ed9f4b3225e95ca15fde/mamba_ssm/ops/triton/softplus.py#L16
  6916. const float dt_soft_plus = ggml_softplus(dt[h]);
  6917. const int g = h / (nh / ng); // repeat_interleave
  6918. // dim
  6919. for (int i1 = 0; i1 < nr; ++i1) {
  6920. const int ii = i1 + h*nr;
  6921. const float x_dt = x[ii] * dt_soft_plus;
  6922. #if defined(__ARM_FEATURE_SVE)
  6923. svfloat32_t vx_dt = GGML_F32_VEC_SET1(x_dt);
  6924. svfloat32_t vdt_soft_plus = GGML_F32_VEC_SET1(dt_soft_plus);
  6925. svfloat32_t r1_vector = GGML_F32_VEC_ZERO;
  6926. // d_state
  6927. // TODO: what happens when (d_state % svcntw()) != 0?
  6928. for (int64_t k = 0; k < nc; k += svcntw()) {
  6929. svfloat32_t vA = GGML_F32_VEC_LOAD(&A[h*nc + k]);
  6930. svfloat32_t vB = GGML_F32_VEC_LOAD(&B[k + g*nc]);
  6931. svfloat32_t vC = GGML_F32_VEC_LOAD(&C[k + g*nc]);
  6932. svfloat32_t vs0 = GGML_F32_VEC_LOAD(&s0[ii*nc + k]);
  6933. svfloat32_t t1 = GGML_F32_VEC_MUL(vdt_soft_plus, vA);
  6934. t1 = exp_ps_sve(svptrue_b32(), t1);
  6935. svfloat32_t t2 = GGML_F32_VEC_MUL(vx_dt, vB);
  6936. vs0 = GGML_F32_VEC_FMA(t2, vs0, t1);
  6937. r1_vector = GGML_F32_VEC_ADD(GGML_F32_VEC_MUL(vs0, vC), r1_vector);
  6938. GGML_F32_VEC_STORE(&s[ii*nc + k], vs0);
  6939. }
  6940. y[ii] = GGML_F32xt_REDUCE_ONE(r1_vector);
  6941. #else
  6942. float sumf = 0.0f;
  6943. // NOTE: can't really use GGML_SIMD here because d_state is usually 16
  6944. // and also because expf is used within the loop.
  6945. // d_state
  6946. for (int i0 = 0; i0 < nc; ++i0) {
  6947. const int i = i0 + ii*nc;
  6948. const int ig = i0 + g*nc;
  6949. // state = prev_state * dA + dB * x
  6950. const float state = (s0[i] * expf(dt_soft_plus * A[i0 + h*nc])) + (B[ig] * x_dt);
  6951. // y = rowwise_dotprod(state, C)
  6952. sumf += state * C[ig];
  6953. s[i] = state;
  6954. }
  6955. y[ii] = sumf;
  6956. #endif
  6957. }
  6958. }
  6959. }
  6960. // use the output as the source when it's not the first token-wise iteration
  6961. s0 = s;
  6962. }
  6963. }
  6964. }
  6965. void ggml_compute_forward_ssm_scan(
  6966. const ggml_compute_params * params,
  6967. ggml_tensor * dst) {
  6968. switch (dst->src[0]->type) {
  6969. case GGML_TYPE_F32:
  6970. {
  6971. ggml_compute_forward_ssm_scan_f32(params, dst);
  6972. } break;
  6973. default:
  6974. {
  6975. GGML_ABORT("fatal error");
  6976. }
  6977. }
  6978. }
  6979. // ggml_compute_forward_win_part
  6980. static void ggml_compute_forward_win_part_f32(
  6981. const ggml_compute_params * params,
  6982. ggml_tensor * dst) {
  6983. GGML_UNUSED(params);
  6984. const ggml_tensor * src0 = dst->src[0];
  6985. GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne)
  6986. GGML_TENSOR_LOCALS(int64_t, ne, dst, ne)
  6987. const int32_t nep0 = ((const int32_t *)(dst->op_params))[0];
  6988. const int32_t nep1 = ((const int32_t *)(dst->op_params))[1];
  6989. const int32_t w = ((const int32_t *)(dst->op_params))[2];
  6990. assert(ne00 == ne0);
  6991. assert(ne3 == nep0*nep1);
  6992. // TODO: optimize / multi-thread
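  // Sketch of the mapping below: the (ne01, ne02) spatial grid is cut into nep0 x nep1
  // windows of size w x w; window (px, py) becomes dst slice i3 = py*nep0 + px, and
  // positions that fall outside the source (the padded border) are written as 0.0f.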
  6993. for (int py = 0; py < nep1; ++py) {
  6994. for (int px = 0; px < nep0; ++px) {
  6995. const int64_t i3 = py*nep0 + px;
  6996. for (int64_t i2 = 0; i2 < ne2; ++i2) {
  6997. for (int64_t i1 = 0; i1 < ne1; ++i1) {
  6998. for (int64_t i0 = 0; i0 < ne0; ++i0) {
  6999. const int64_t i02 = py*w + i2;
  7000. const int64_t i01 = px*w + i1;
  7001. const int64_t i00 = i0;
  7002. const int64_t i = i3*ne2*ne1*ne0 + i2*ne1*ne0 + i1*ne0 + i0;
  7003. const int64_t j = i02*ne01*ne00 + i01*ne00 + i00;
  7004. if (py*w + i2 >= ne02 || px*w + i1 >= ne01) {
  7005. ((float *) dst->data)[i] = 0.0f;
  7006. } else {
  7007. ((float *) dst->data)[i] = ((float *) src0->data)[j];
  7008. }
  7009. }
  7010. }
  7011. }
  7012. }
  7013. }
  7014. }
  7015. void ggml_compute_forward_win_part(
  7016. const ggml_compute_params * params,
  7017. ggml_tensor * dst) {
  7018. const ggml_tensor * src0 = dst->src[0];
  7019. switch (src0->type) {
  7020. case GGML_TYPE_F32:
  7021. {
  7022. ggml_compute_forward_win_part_f32(params, dst);
  7023. } break;
  7024. default:
  7025. {
  7026. GGML_ABORT("fatal error");
  7027. }
  7028. }
  7029. }
  7030. // ggml_compute_forward_win_unpart
  7031. static void ggml_compute_forward_win_unpart_f32(
  7032. const ggml_compute_params * params,
  7033. ggml_tensor * dst) {
  7034. GGML_UNUSED(params);
  7035. const ggml_tensor * src0 = dst->src[0];
  7036. GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne)
  7037. GGML_TENSOR_LOCALS(int64_t, ne, dst, ne)
  7038. const int32_t w = ((const int32_t *)(dst->op_params))[0];
  7039. // padding
  7040. const int px = (w - ne1%w)%w;
  7041. //const int py = (w - ne2%w)%w;
  7042. const int npx = (px + ne1)/w;
  7043. //const int npy = (py + ne2)/w;
  7044. assert(ne0 == ne00);
  7045. // TODO: optimize / multi-thread
  7046. for (int64_t i2 = 0; i2 < ne2; ++i2) {
  7047. for (int64_t i1 = 0; i1 < ne1; ++i1) {
  7048. for (int64_t i0 = 0; i0 < ne0; ++i0) {
  7049. const int ip2 = i2/w;
  7050. const int ip1 = i1/w;
  7051. const int64_t i02 = i2%w;
  7052. const int64_t i01 = i1%w;
  7053. const int64_t i00 = i0;
  7054. const int64_t i = (ip2*npx + ip1)*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00 + i00;
  7055. const int64_t j = i2*ne1*ne0 + i1*ne0 + i0;
  7056. ((float *) dst->data)[j] = ((float *) src0->data)[i];
  7057. }
  7058. }
  7059. }
  7060. }
  7061. void ggml_compute_forward_win_unpart(
  7062. const ggml_compute_params * params,
  7063. ggml_tensor * dst) {
  7064. const ggml_tensor * src0 = dst->src[0];
  7065. switch (src0->type) {
  7066. case GGML_TYPE_F32:
  7067. {
  7068. ggml_compute_forward_win_unpart_f32(params, dst);
  7069. } break;
  7070. default:
  7071. {
  7072. GGML_ABORT("fatal error");
  7073. }
  7074. }
  7075. }
  7076. // ggml_compute_forward_unary
  7077. void ggml_compute_forward_unary(
  7078. const ggml_compute_params * params,
  7079. ggml_tensor * dst) {
  7080. const ggml_unary_op op = ggml_get_unary_op(dst);
  7081. switch (op) {
  7082. case GGML_UNARY_OP_ABS:
  7083. {
  7084. ggml_compute_forward_abs(params, dst);
  7085. } break;
  7086. case GGML_UNARY_OP_SGN:
  7087. {
  7088. ggml_compute_forward_sgn(params, dst);
  7089. } break;
  7090. case GGML_UNARY_OP_NEG:
  7091. {
  7092. ggml_compute_forward_neg(params, dst);
  7093. } break;
  7094. case GGML_UNARY_OP_STEP:
  7095. {
  7096. ggml_compute_forward_step(params, dst);
  7097. } break;
  7098. case GGML_UNARY_OP_TANH:
  7099. {
  7100. ggml_compute_forward_tanh(params, dst);
  7101. } break;
  7102. case GGML_UNARY_OP_ELU:
  7103. {
  7104. ggml_compute_forward_elu(params, dst);
  7105. } break;
  7106. case GGML_UNARY_OP_RELU:
  7107. {
  7108. ggml_compute_forward_relu(params, dst);
  7109. } break;
  7110. case GGML_UNARY_OP_SIGMOID:
  7111. {
  7112. ggml_compute_forward_sigmoid(params, dst);
  7113. } break;
  7114. case GGML_UNARY_OP_GELU:
  7115. {
  7116. ggml_compute_forward_gelu(params, dst);
  7117. } break;
  7118. case GGML_UNARY_OP_GELU_ERF:
  7119. {
  7120. ggml_compute_forward_gelu_erf(params, dst);
  7121. } break;
  7122. case GGML_UNARY_OP_GELU_QUICK:
  7123. {
  7124. ggml_compute_forward_gelu_quick(params, dst);
  7125. } break;
  7126. case GGML_UNARY_OP_SILU:
  7127. {
  7128. ggml_compute_forward_silu(params, dst);
  7129. } break;
  7130. case GGML_UNARY_OP_HARDSWISH:
  7131. {
  7132. ggml_compute_forward_hardswish(params, dst);
  7133. } break;
  7134. case GGML_UNARY_OP_HARDSIGMOID:
  7135. {
  7136. ggml_compute_forward_hardsigmoid(params, dst);
  7137. } break;
  7138. case GGML_UNARY_OP_EXP:
  7139. {
  7140. ggml_compute_forward_exp(params, dst);
  7141. } break;
  7142. case GGML_UNARY_OP_FLOOR:
  7143. {
  7144. ggml_compute_forward_floor(params, dst);
  7145. } break;
  7146. case GGML_UNARY_OP_CEIL:
  7147. {
  7148. ggml_compute_forward_ceil(params, dst);
  7149. } break;
  7150. case GGML_UNARY_OP_ROUND:
  7151. {
  7152. ggml_compute_forward_round(params, dst);
  7153. } break;
  7154. case GGML_UNARY_OP_TRUNC:
  7155. {
  7156. ggml_compute_forward_trunc(params, dst);
  7157. } break;
  7158. case GGML_UNARY_OP_XIELU:
  7159. {
  7160. ggml_compute_forward_xielu(params, dst);
  7161. } break;
  7162. default:
  7163. {
  7164. GGML_ABORT("fatal error");
  7165. }
  7166. }
  7167. }
  7168. // ggml_compute_forward_glu
  7169. void ggml_compute_forward_glu(
  7170. const ggml_compute_params * params,
  7171. ggml_tensor * dst) {
  7172. const ggml_glu_op op = ggml_get_glu_op(dst);
  7173. switch (op) {
  7174. case GGML_GLU_OP_REGLU:
  7175. {
  7176. ggml_compute_forward_reglu(params, dst);
  7177. } break;
  7178. case GGML_GLU_OP_GEGLU:
  7179. {
  7180. ggml_compute_forward_geglu(params, dst);
  7181. } break;
  7182. case GGML_GLU_OP_SWIGLU:
  7183. {
  7184. ggml_compute_forward_swiglu(params, dst);
  7185. } break;
  7186. case GGML_GLU_OP_SWIGLU_OAI:
  7187. {
  7188. ggml_compute_forward_swiglu_oai(params, dst);
  7189. } break;
  7190. case GGML_GLU_OP_GEGLU_ERF:
  7191. {
  7192. ggml_compute_forward_geglu_erf(params, dst);
  7193. } break;
  7194. case GGML_GLU_OP_GEGLU_QUICK:
  7195. {
  7196. ggml_compute_forward_geglu_quick(params, dst);
  7197. } break;
  7198. default:
  7199. {
  7200. GGML_ABORT("fatal error");
  7201. }
  7202. }
  7203. }
  7204. // ggml_compute_forward_get_rel_pos
  7205. static void ggml_compute_forward_get_rel_pos_f16(
  7206. const ggml_compute_params * params,
  7207. ggml_tensor * dst) {
  7208. GGML_UNUSED(params);
  7209. const ggml_tensor * src0 = dst->src[0];
  7210. // ref: https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/modeling/image_encoder.py#L292-L322
  7211. GGML_TENSOR_UNARY_OP_LOCALS
  7212. const int64_t w = ne1;
  7213. ggml_fp16_t * src0_data = (ggml_fp16_t *) src0->data;
  7214. ggml_fp16_t * dst_data = (ggml_fp16_t *) dst->data;
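  // Sketch: for output indices (i1, i2) the relative offset is (i2 - i1), shifted by
  // (w - 1) to stay non-negative, giving pos = (w - i1 - 1) + i2, the row of the
  // relative-position table that is copied into dst below.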
  7215. for (int64_t i2 = 0; i2 < ne2; ++i2) {
  7216. for (int64_t i1 = 0; i1 < ne1; ++i1) {
  7217. const int64_t pos = (w - i1 - 1) + i2;
  7218. for (int64_t i0 = 0; i0 < ne0; ++i0) {
  7219. dst_data[i2*ne1*ne0 + i1*ne0 + i0] = src0_data[pos*ne00 + i0];
  7220. }
  7221. }
  7222. }
  7223. }
  7224. void ggml_compute_forward_get_rel_pos(
  7225. const ggml_compute_params * params,
  7226. ggml_tensor * dst) {
  7227. const ggml_tensor * src0 = dst->src[0];
  7228. switch (src0->type) {
  7229. case GGML_TYPE_F16:
  7230. case GGML_TYPE_BF16:
  7231. {
  7232. ggml_compute_forward_get_rel_pos_f16(params, dst);
  7233. } break;
  7234. default:
  7235. {
  7236. GGML_ABORT("fatal error");
  7237. }
  7238. }
  7239. }
  7240. // ggml_compute_forward_add_rel_pos
  7241. static void ggml_compute_forward_add_rel_pos_f32(
  7242. const ggml_compute_params * params,
  7243. ggml_tensor * dst) {
  7244. const ggml_tensor * src0 = dst->src[0];
  7245. const ggml_tensor * src1 = dst->src[1];
  7246. const ggml_tensor * src2 = dst->src[2];
  7247. const bool inplace = (bool) ((int32_t *) dst->op_params)[0];
  7248. if (!inplace) {
  7249. if (params->ith == 0) {
  7250. memcpy((char *) dst->data, (char *) src0->data, ggml_nbytes(dst));
  7251. }
  7252. ggml_barrier(params->threadpool);
  7253. }
  7254. // ref: https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/modeling/image_encoder.py#L357-L359
  7255. float * src1_data = (float *) src1->data;
  7256. float * src2_data = (float *) src2->data;
  7257. float * dst_data = (float *) dst->data;
  7258. const int64_t ne10 = src1->ne[0];
  7259. const int64_t ne11 = src1->ne[1];
  7260. const int64_t ne12 = src1->ne[2];
  7261. const int64_t ne13 = src1->ne[3];
  7262. const int ith = params->ith;
  7263. const int nth = params->nth;
  7264. // total patches in dst
  7265. const int np = ne13;
  7266. // patches per thread
  7267. const int dp = (np + nth - 1)/nth;
  7268. // patch range for this thread
  7269. const int ip0 = dp*ith;
  7270. const int ip1 = MIN(ip0 + dp, np);
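  // Sketch of the broadcast below: dst (the attention scores) is viewed with two trailing
  // dimensions of size ne10 each; for every source position, src2_e is added along the
  // contiguous one (stride 1, offset jdh) and src1_e along the other (stride ne10, offset
  // jdw), matching the decomposed relative-position update in the referenced SAM code.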
  7271. for (int64_t i13 = ip0; i13 < ip1; ++i13) {
  7272. for (int64_t i12 = 0; i12 < ne12; ++i12) {
  7273. for (int64_t i11 = 0; i11 < ne11; ++i11) {
  7274. const int64_t jp1 = i13*ne12*ne11*ne10 + i12*ne11*ne10 + i11*ne10;
  7275. for (int64_t i10 = 0; i10 < ne10; ++i10) {
  7276. const int64_t jp0 = jp1 + i10;
  7277. const float src1_e = src1_data[jp0];
  7278. const float src2_e = src2_data[jp0];
  7279. const int64_t jdh = jp0 * ne10;
  7280. const int64_t jdw = jdh - (ne10 - 1) * i10;
  7281. for (int64_t j = 0; j < ne10; ++j) {
  7282. dst_data[jdh + j ] += src2_e;
  7283. dst_data[jdw + j*ne10] += src1_e;
  7284. }
  7285. }
  7286. }
  7287. }
  7288. }
  7289. }
  7290. void ggml_compute_forward_add_rel_pos(
  7291. const ggml_compute_params * params,
  7292. ggml_tensor * dst) {
  7293. const ggml_tensor * src0 = dst->src[0];
  7294. switch (src0->type) {
  7295. case GGML_TYPE_F32:
  7296. {
  7297. ggml_compute_forward_add_rel_pos_f32(params, dst);
  7298. } break;
  7299. default:
  7300. {
  7301. GGML_ABORT("fatal error");
  7302. }
  7303. }
  7304. }
  7305. // ggml_compute_forward_rwkv_wkv6
  7306. static void ggml_compute_forward_rwkv_wkv6_f32(
  7307. const ggml_compute_params * params,
  7308. ggml_tensor * dst) {
  7309. const int64_t T = dst->src[1]->ne[2];
  7310. const int64_t C = dst->ne[0];
  7311. const int64_t HEADS = dst->src[1]->ne[1];
  7312. const int64_t n_seqs = dst->src[5]->ne[1];
  7313. const int64_t head_size = C / HEADS;
  7314. float * dst_data = (float *) dst->data;
  7315. float * state = ((float *) dst->data) + C * T;
  7316. const int ith = params->ith;
  7317. const int nth = params->nth;
  7318. if (ith >= HEADS) {
  7319. return;
  7320. }
  7321. const int h_start = (HEADS * ith) / nth;
  7322. const int h_end = ((HEADS * (ith + 1)) / nth < HEADS) ?
  7323. (HEADS * (ith + 1)) / nth : HEADS;
  7324. float * k = (float *) dst->src[0]->data;
  7325. float * v = (float *) dst->src[1]->data;
  7326. float * r = (float *) dst->src[2]->data;
  7327. float * time_faaaa = (float *) dst->src[3]->data;
  7328. float * time_decay = (float *) dst->src[4]->data;
  7329. size_t t_stride = HEADS * head_size; // same as C
  7330. size_t h_stride = C / HEADS;
  7331. GGML_ASSERT(C % HEADS == 0); // C must be divisible by HEADS
  7332. size_t h_stride_2d = head_size * head_size;
  7333. if (ith == 0) {
  7334. memset(dst_data, 0, T * C * sizeof(float));
  7335. }
  7336. ggml_barrier(params->threadpool);
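  // Per-head recurrence computed below (sketch), for every state row i and column j:
  //   kv           = k[i] * v[j]
  //   dst[j]      += r[i] * (time_faaaa[i] * kv + state[i][j])
  //   state[i][j]  = state[i][j] * time_decay[i] + kv
  // i.e. the RWKV v6 "wkv" update with a per-token, per-channel decay.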
  7337. #if defined(__AVX__) && !defined(__AVX512F__)
  7338. #define GGML_F32X GGML_F32x8
  7339. #define GGML_F32X_SET1 GGML_F32x8_SET1
  7340. #define GGML_F32X_LOAD GGML_F32x8_LOAD
  7341. #define GGML_F32X_STORE GGML_F32x8_STORE
  7342. #define GGML_F32X_MUL GGML_F32x8_MUL
  7343. #define GGML_F32X_FMA GGML_F32x8_FMA
  7344. #define WKV_VECTOR_SIZE 8
  7345. #elif defined(__AVX512F__)
  7346. #define GGML_F32X GGML_F32x16
  7347. #define GGML_F32X_SET1 GGML_F32x16_SET1
  7348. #define GGML_F32X_LOAD GGML_F32x16_LOAD
  7349. #define GGML_F32X_STORE GGML_F32x16_STORE
  7350. #define GGML_F32X_MUL GGML_F32x16_MUL
  7351. #define GGML_F32X_FMA GGML_F32x16_FMA
  7352. #define WKV_VECTOR_SIZE 16
  7353. #elif defined(__ARM_FEATURE_SVE) && defined(__aarch64__)
  7354. #define GGML_F32X GGML_F32xt
  7355. #define GGML_F32X_SET1 GGML_F32xt_SET1
  7356. #define GGML_F32X_LOAD GGML_F32xt_LOAD
  7357. #define GGML_F32X_STORE GGML_F32xt_STORE
  7358. #define GGML_F32X_MUL GGML_F32xt_MUL
  7359. #define GGML_F32X_FMA GGML_F32xt_FMA
  7360. #define WKV_VECTOR_SIZE 8
  7361. #elif defined(__ARM_NEON) && defined(__aarch64__)
  7362. #define GGML_F32X GGML_F32x4
  7363. #define GGML_F32X_SET1 GGML_F32x4_SET1
  7364. #define GGML_F32X_LOAD GGML_F32x4_LOAD
  7365. #define GGML_F32X_STORE GGML_F32x4_STORE
  7366. #define GGML_F32X_MUL GGML_F32x4_MUL
  7367. #define GGML_F32X_FMA GGML_F32x4_FMA
  7368. #define WKV_VECTOR_SIZE 4
  7369. #endif
  7370. #ifdef WKV_VECTOR_SIZE
  7371. int wkv_vector_size;
  7372. #if defined(__ARM_FEATURE_SVE)
  7373. wkv_vector_size = svcntw();
  7374. #else
  7375. wkv_vector_size = WKV_VECTOR_SIZE;
  7376. #endif
  7377. const int64_t vec_count = head_size / wkv_vector_size;
  7378. for (int64_t t = 0; t < T; t++) {
  7379. size_t t_offset = t * t_stride;
  7380. size_t state_offset = head_size * C * (t / (T / n_seqs));
  7381. float * state_cur = state + state_offset;
  7382. float * state_prev = t % (T / n_seqs) ? state_cur : (float*)dst->src[5]->data + state_offset;
  7383. for (int64_t h = h_start; h < h_end; h++) {
  7384. size_t h_offset = h * h_stride;
  7385. size_t t_h_offset = t_offset + h_offset;
  7386. size_t h_2d_offset = h * h_stride_2d;
  7387. for (int64_t i = 0; i < head_size; i++) {
  7388. size_t t_h_i_offset = t_h_offset + i;
  7389. size_t h_i_offset = h_offset + i;
  7390. size_t h_2d_i_offset = h_2d_offset + i * h_stride;
  7391. float k_val = k[t_h_i_offset];
  7392. float r_val = r[t_h_i_offset];
  7393. float time_faaaa_val = time_faaaa[h_i_offset];
  7394. float time_decay_val = time_decay[t_h_i_offset];
  7395. // Broadcast scalar values to vectors
  7396. GGML_F32X k_vec = GGML_F32X_SET1(k_val);
  7397. GGML_F32X r_vec = GGML_F32X_SET1(r_val);
  7398. GGML_F32X time_faaaa_vec = GGML_F32X_SET1(time_faaaa_val);
  7399. GGML_F32X time_decay_vec = GGML_F32X_SET1(time_decay_val);
  7400. for (int64_t j = 0; j < vec_count; j++) {
  7401. size_t base_j = j * wkv_vector_size;
  7402. size_t t_h_j_offset = t_h_offset + base_j;
  7403. size_t h_2d_i_j_offset = h_2d_i_offset + base_j;
  7404. // Load wkv_vector_size elements at once
  7405. GGML_F32X v_vec = GGML_F32X_LOAD(&v[t_h_j_offset]);
  7406. GGML_F32X prev_state_vec = GGML_F32X_LOAD(&state_prev[h_2d_i_j_offset]);
  7407. GGML_F32X dst_vec = GGML_F32X_LOAD(&dst_data[t_h_j_offset]);
  7408. // Compute kv = v * k
  7409. GGML_F32X kv_vec = GGML_F32X_MUL(v_vec, k_vec);
  7410. // Compute temp = kv * time_faaaa + prev_state
  7411. GGML_F32X temp_vec = GGML_F32X_FMA(prev_state_vec, kv_vec, time_faaaa_vec);
  7412. // Update dst: dst += temp * r
  7413. dst_vec = GGML_F32X_FMA(dst_vec, temp_vec, r_vec);
  7414. GGML_F32X_STORE(&dst_data[t_h_j_offset], dst_vec);
  7415. // Update state: state = prev_state * time_decay + kv
  7416. GGML_F32X new_state_vec = GGML_F32X_FMA(kv_vec, prev_state_vec, time_decay_vec);
  7417. GGML_F32X_STORE(&state_cur[h_2d_i_j_offset], new_state_vec);
  7418. }
  7419. // Handle remaining elements (normally unused, since head_size is a multiple of the vector width).
  7420. for (int64_t j = vec_count * wkv_vector_size; j < head_size; j++) {
  7421. size_t t_h_j_offset = t_h_offset + j;
  7422. size_t h_2d_i_j_offset = h_2d_i_offset + j;
  7423. float v_val = v[t_h_j_offset];
  7424. float kv_val = v_val * k_val;
  7425. float prev_state_val = state_prev[h_2d_i_j_offset];
  7426. float temp_val = kv_val * time_faaaa_val + prev_state_val;
  7427. dst_data[t_h_j_offset] += temp_val * r_val;
  7428. state_cur[h_2d_i_j_offset] = prev_state_val * time_decay_val + kv_val;
  7429. }
  7430. }
  7431. }
  7432. }
  7433. #else
  7434. // basically fused operations:
  7435. // dst = r @ (time_faaaa * (k @ v) + state),
  7436. // state = time_decay * state + (k @ v),
  7437. // applied recursively token by token
  7438. for (int64_t t = 0; t < T; t++) {
  7439. size_t t_offset = t * t_stride;
  7440. size_t state_offset = head_size * C * (t / (T / n_seqs));
  7441. float * state_cur = state + state_offset;
  7442. float * state_prev = t % (T / n_seqs) ? state_cur : (float*)dst->src[5]->data + state_offset;
  7443. for (int64_t h = h_start; h < h_end; h++) {
  7444. size_t h_offset = h * h_stride;
  7445. size_t t_h_offset = t_offset + h_offset;
  7446. size_t h_2d_offset = h * h_stride_2d;
  7447. for (int64_t i = 0; i < head_size; i++) {
  7448. size_t t_h_i_offset = t_h_offset + i;
  7449. size_t h_i_offset = h_offset + i;
  7450. size_t h_2d_i_offset = h_2d_offset + i * h_stride;
  7451. float k_val = k[t_h_i_offset];
  7452. float r_val = r[t_h_i_offset];
  7453. float time_faaaa_val = time_faaaa[h_i_offset];
  7454. // RWKV v6: different time_decay for each token.
  7455. float time_decay_val = time_decay[t_h_i_offset];
  7456. for (int64_t j = 0; j < head_size; j++) {
  7457. size_t t_h_j_offset = t_h_offset + j;
  7458. size_t h_2d_i_j_offset = h_2d_i_offset + j;
  7459. float v_val = v[t_h_j_offset];
  7460. float kv_val = v_val * k_val;
  7461. float prev_state_val = state_prev[h_2d_i_j_offset];
  7462. float temp_val = kv_val * time_faaaa_val + prev_state_val;
  7463. dst_data[t_h_j_offset] += temp_val * r_val;
  7464. state_cur[h_2d_i_j_offset] = prev_state_val * time_decay_val + kv_val;
  7465. }
  7466. }
  7467. }
  7468. }
  7469. #endif
  7470. }
  7471. void ggml_compute_forward_rwkv_wkv6(
  7472. const ggml_compute_params * params,
  7473. ggml_tensor * dst) {
  7474. const ggml_tensor * src0 = dst->src[0];
  7475. switch (src0->type) {
  7476. case GGML_TYPE_F32:
  7477. {
  7478. ggml_compute_forward_rwkv_wkv6_f32(params, dst);
  7479. } break;
  7480. default:
  7481. {
  7482. GGML_ABORT("fatal error");
  7483. }
  7484. }
  7485. }
  7486. // ggml_compute_forward_gla
  7487. static void ggml_compute_forward_gla_f32(
  7488. const ggml_compute_params * params,
  7489. ggml_tensor * dst) {
  7490. const int64_t T = dst->src[1]->ne[2];
  7491. const int64_t C = dst->ne[0];
  7492. const int64_t HEADS = dst->src[1]->ne[1];
  7493. const int64_t n_seqs = dst->src[4]->ne[1];
  7494. const int64_t head_size = C / HEADS;
  7495. const float scale = ggml_get_op_params_f32(dst, 0);
  7496. float * dst_data = (float *) dst->data;
  7497. float * state = ((float *) dst->data) + C * T;
  7498. const int ith = params->ith;
  7499. const int nth = params->nth;
  7500. if (ith >= HEADS) {
  7501. return;
  7502. }
  7503. const int h_start = (HEADS * ith) / nth;
  7504. const int h_end = ((HEADS * (ith + 1)) / nth < HEADS) ?
  7505. (HEADS * (ith + 1)) / nth : HEADS;
  7506. float * k = (float *) dst->src[0]->data;
  7507. float * v = (float *) dst->src[1]->data;
  7508. float * q = (float *) dst->src[2]->data;
  7509. float * g = (float *) dst->src[3]->data;
  7510. size_t t_stride = HEADS * head_size; // same as C
  7511. size_t h_stride = C / HEADS;
  7512. GGML_ASSERT(C % HEADS == 0); // C must be divisible by HEADS
  7513. size_t h_stride_2d = head_size * head_size;
  7514. if (ith == 0) {
  7515. memset(dst_data, 0, T * C * sizeof(float));
  7516. }
  7517. ggml_barrier(params->threadpool);
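  // Per-head recurrence computed below (sketch), for state row i and column j:
  //   state[i][j] = state[i][j] * g[i] + k[i] * v[j]
  //   dst[j]     += q[i] * scale * state[i][j]
  // i.e. gated linear attention with a per-channel gate g.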
  7518. #if defined(__AVX__) && !defined(__AVX512F__)
  7519. #define GGML_F32X GGML_F32x8
  7520. #define GGML_F32X_SET1 GGML_F32x8_SET1
  7521. #define GGML_F32X_LOAD GGML_F32x8_LOAD
  7522. #define GGML_F32X_STORE GGML_F32x8_STORE
  7523. #define GGML_F32X_MUL GGML_F32x8_MUL
  7524. #define GGML_F32X_FMA GGML_F32x8_FMA
  7525. #define GLA_VECTOR_SIZE 8
  7526. #elif defined(__AVX512F__)
  7527. #define GGML_F32X GGML_F32x16
  7528. #define GGML_F32X_SET1 GGML_F32x16_SET1
  7529. #define GGML_F32X_LOAD GGML_F32x16_LOAD
  7530. #define GGML_F32X_STORE GGML_F32x16_STORE
  7531. #define GGML_F32X_MUL GGML_F32x16_MUL
  7532. #define GGML_F32X_FMA GGML_F32x16_FMA
  7533. #define GLA_VECTOR_SIZE 16
  7534. #elif defined(__ARM_FEATURE_SVE) && defined(__aarch64__)
  7535. #define GGML_F32X GGML_F32xt
  7536. #define GGML_F32X_SET1 GGML_F32xt_SET1
  7537. #define GGML_F32X_LOAD GGML_F32xt_LOAD
  7538. #define GGML_F32X_STORE GGML_F32xt_STORE
  7539. #define GGML_F32X_MUL GGML_F32xt_MUL
  7540. #define GGML_F32X_FMA GGML_F32xt_FMA
  7541. #define GLA_VECTOR_SIZE 8
  7542. #elif defined(__ARM_NEON) && defined(__aarch64__)
  7543. #define GGML_F32X GGML_F32x4
  7544. #define GGML_F32X_SET1 GGML_F32x4_SET1
  7545. #define GGML_F32X_LOAD GGML_F32x4_LOAD
  7546. #define GGML_F32X_STORE GGML_F32x4_STORE
  7547. #define GGML_F32X_MUL GGML_F32x4_MUL
  7548. #define GGML_F32X_FMA GGML_F32x4_FMA
  7549. #define GLA_VECTOR_SIZE 4
  7550. #endif
  7551. #ifdef GLA_VECTOR_SIZE
  7552. int gla_vector_size;
  7553. #if defined(__ARM_FEATURE_SVE)
  7554. gla_vector_size = svcntw();
  7555. #else
  7556. gla_vector_size = GLA_VECTOR_SIZE;
  7557. #endif
  7558. const int64_t vec_count = head_size / gla_vector_size;
  7559. for (int64_t t = 0; t < T; t++) {
  7560. size_t t_offset = t * t_stride;
  7561. size_t state_offset = head_size * C * (t / (T / n_seqs));
  7562. float * state_cur = state + state_offset;
  7563. float * state_prev = t % (T / n_seqs) ? state_cur : (float*)dst->src[4]->data + state_offset;
  7564. for (int64_t h = h_start; h < h_end; h++) {
  7565. size_t h_offset = h * h_stride;
  7566. size_t t_h_offset = t_offset + h_offset;
  7567. size_t h_2d_offset = h * h_stride_2d;
  7568. for (int64_t i = 0; i < head_size; i++) {
  7569. size_t t_h_i_offset = t_h_offset + i;
  7570. size_t h_2d_i_offset = h_2d_offset + i * h_stride;
  7571. float k_val = k[t_h_i_offset];
  7572. float q_val = q[t_h_i_offset] * scale;
  7573. float g_val = g[t_h_i_offset];
  7574. // Broadcast scalar values to vectors
  7575. GGML_F32X k_vec = GGML_F32X_SET1(k_val);
  7576. GGML_F32X q_vec = GGML_F32X_SET1(q_val);
  7577. GGML_F32X g_vec = GGML_F32X_SET1(g_val);
  7578. for (int64_t j = 0; j < vec_count; j++) {
  7579. size_t base_j = j * gla_vector_size;
  7580. size_t t_h_j_offset = t_h_offset + base_j;
  7581. size_t h_2d_i_j_offset = h_2d_i_offset + base_j;
  7582. // Load gla_vector_size elements at once
  7583. GGML_F32X v_vec = GGML_F32X_LOAD(&v[t_h_j_offset]);
  7584. GGML_F32X prev_state_vec = GGML_F32X_LOAD(&state_prev[h_2d_i_j_offset]);
  7585. GGML_F32X dst_vec = GGML_F32X_LOAD(&dst_data[t_h_j_offset]);
  7586. // Compute kv = v * k
  7587. GGML_F32X kv_vec = GGML_F32X_MUL(v_vec, k_vec);
  7588. // Compute temp = prev_state * g + kv
  7589. GGML_F32X temp_vec = GGML_F32X_FMA(kv_vec, prev_state_vec, g_vec);
  7590. // Update dst: dst += temp * q
  7591. dst_vec = GGML_F32X_FMA(dst_vec, temp_vec, q_vec);
  7592. GGML_F32X_STORE(&dst_data[t_h_j_offset], dst_vec);
  7593. // Update state
  7594. GGML_F32X_STORE(&state_cur[h_2d_i_j_offset], temp_vec);
  7595. }
  7596. // Handle remaining elements (normally unused, since head_size is a multiple of the vector width).
  7597. for (int64_t j = vec_count * gla_vector_size; j < head_size; j++) {
  7598. size_t t_h_j_offset = t_h_offset + j;
  7599. size_t h_2d_i_j_offset = h_2d_i_offset + j;
  7600. float v_val = v[t_h_j_offset];
  7601. float kv_val = v_val * k_val;
  7602. float prev_state_val = state_prev[h_2d_i_j_offset];
  7603. float temp_val = kv_val + prev_state_val * g_val;
  7604. dst_data[t_h_j_offset] += temp_val * q_val;
  7605. state_cur[h_2d_i_j_offset] = temp_val;
  7606. }
  7607. }
  7608. }
  7609. }
  7610. #else
  7611. for (int64_t t = 0; t < T; t++) {
  7612. size_t t_offset = t * t_stride;
  7613. size_t state_offset = head_size * C * (t / (T / n_seqs));
  7614. float * state_cur = state + state_offset;
  7615. float * state_prev = t % (T / n_seqs) ? state_cur : (float*)dst->src[4]->data + state_offset;
  7616. for (int64_t h = h_start; h < h_end; h++) {
  7617. size_t h_offset = h * h_stride;
  7618. size_t t_h_offset = t_offset + h_offset;
  7619. size_t h_2d_offset = h * h_stride_2d;
  7620. for (int64_t i = 0; i < head_size; i++) {
  7621. size_t t_h_i_offset = t_h_offset + i;
  7622. size_t h_2d_i_offset = h_2d_offset + i * h_stride;
  7623. float k_val = k[t_h_i_offset];
  7624. float q_val = q[t_h_i_offset] * scale;
  7625. float g_val = g[t_h_i_offset];
  7626. for (int64_t j = 0; j < head_size; j++) {
  7627. size_t t_h_j_offset = t_h_offset + j;
  7628. size_t h_2d_i_j_offset = h_2d_i_offset + j;
  7629. float v_val = v[t_h_j_offset];
  7630. float kv_val = v_val * k_val;
  7631. float prev_state_val = state_prev[h_2d_i_j_offset];
  7632. float temp_val = prev_state_val * g_val + kv_val;
  7633. dst_data[t_h_j_offset] += temp_val * q_val;
  7634. state_cur[h_2d_i_j_offset] = temp_val;
  7635. }
  7636. }
  7637. }
  7638. }
  7639. #endif
  7640. }
  7641. void ggml_compute_forward_gla(
  7642. const ggml_compute_params * params,
  7643. ggml_tensor * dst) {
  7644. const ggml_tensor * src0 = dst->src[0];
  7645. switch (src0->type) {
  7646. case GGML_TYPE_F32:
  7647. {
  7648. ggml_compute_forward_gla_f32(params, dst);
  7649. } break;
  7650. default:
  7651. {
  7652. GGML_ABORT("fatal error");
  7653. }
  7654. }
  7655. }
  7656. // ggml_compute_forward_rwkv_wkv7
  7657. static void ggml_compute_forward_rwkv_wkv7_f32(
  7658. const ggml_compute_params * params,
  7659. ggml_tensor * dst) {
  7660. const int64_t T = dst->src[1]->ne[2];
  7661. const int64_t C = dst->ne[0];
  7662. const int64_t HEADS = dst->src[1]->ne[1];
  7663. const int64_t n_seqs = dst->src[6]->ne[1];
  7664. const int64_t head_size = C / HEADS;
  7665. float * dst_data = (float *) dst->data;
  7666. float * state = ((float *) dst->data) + C * T;
  7667. const int ith = params->ith;
  7668. const int nth = params->nth;
  7669. if (ith >= HEADS) {
  7670. return;
  7671. }
  7672. const int h_start = (HEADS * ith) / nth;
  7673. const int h_end = ((HEADS * (ith + 1)) / nth < HEADS) ?
  7674. (HEADS * (ith + 1)) / nth : HEADS;
  7675. float * r = (float *) dst->src[0]->data;
  7676. float * w = (float *) dst->src[1]->data;
  7677. float * k = (float *) dst->src[2]->data;
  7678. float * v = (float *) dst->src[3]->data;
  7679. float * a = (float *) dst->src[4]->data;
  7680. float * b = (float *) dst->src[5]->data;
  7681. int64_t t_stride = HEADS * head_size; // same as C
  7682. int64_t h_stride = C / HEADS;
  7683. GGML_ASSERT(C % HEADS == 0); // C must be divisible by HEADS
  7684. int64_t h_stride_2d = head_size * head_size;
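  // Per-head recurrence computed below (sketch), for state row i and column j:
  //   sa          = sum_j a[j] * state[i][j]                   (state "removal" term)
  //   state[i][j] = state[i][j] * w[j] + k[j] * v[i] + sa * b[j]
  //   dst[i]      = sum_j r[j] * state[i][j]
  // i.e. the RWKV v7 update with decay w and the a/b pair applying the state correction.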
  7685. #if defined(GGML_SIMD)
  7686. #if defined(__ARM_FEATURE_SVE) || defined(__riscv_v_intrinsic)
  7687. // Route to the scalar implementation. TODO: write SVE and RVV code
    for (int64_t t = 0; t < T; t++) {
        int64_t t_offset = t * t_stride;
        int64_t state_offset = head_size * C * (t / (T / n_seqs));
        float * state_cur = state + state_offset;
        float * state_prev = t % (T / n_seqs) ? state_cur : (float*)dst->src[6]->data + state_offset;

        for (int64_t h = h_start; h < h_end; h++) {
            int64_t h_offset = h * h_stride;
            int64_t t_h_offset = t_offset + h_offset;
            int64_t h_2d_offset = h * h_stride_2d;

            for (int64_t i = 0; i < head_size; i++) {
                int64_t t_h_i_offset = t_h_offset + i;
                int64_t h_2d_i_offset = h_2d_offset + i * h_stride;

                float v_val = v[t_h_i_offset];

                float sa = 0, result = 0;
                for (int64_t j = 0; j < head_size; j++) {
                    sa += a[t_h_offset + j] * state_prev[h_2d_i_offset + j];
                }

                for (int64_t j = 0; j < head_size; j++) {
                    int64_t t_h_j_offset = t_h_offset + j;
                    int64_t h_2d_i_j_offset = h_2d_i_offset + j;

                    float r_val = r[t_h_j_offset];
                    float w_val = w[t_h_j_offset];
                    float k_val = k[t_h_j_offset];
                    float b_val = b[t_h_j_offset];
                    float kv_val = v_val * k_val;

                    float prev_state_val = state_prev[h_2d_i_j_offset];
                    state_cur[h_2d_i_j_offset] = prev_state_val * w_val + kv_val + sa * b_val;
                    result += state_cur[h_2d_i_j_offset] * r_val;
                }

                dst_data[t_h_i_offset] = result;
            }
        }
    }
#else
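    // vectorized path: same recurrence as the scalar branch above, processed in
    // chunks of GGML_F32_STEP floats (GGML_F32_ARR vectors of GGML_F32_EPR lanes each);
    // both the `sa` dot product and the state update/readout use fused multiply-adds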
    for (int64_t t = 0; t < T; t++) {
        int64_t t_offset = t * t_stride;
        int64_t state_offset = head_size * C * (t / (T / n_seqs));
        float * state_cur = state + state_offset;
        float * state_prev = t % (T / n_seqs) ? state_cur : (float*)dst->src[6]->data + state_offset;

        for (int64_t h = h_start; h < h_end; h++) {
            int64_t h_offset = h * h_stride;
            int64_t t_h_offset = t_offset + h_offset;
            int64_t h_2d_offset = h * h_stride_2d;

            for (int64_t ii = 0; ii < head_size; ii++) {
                int64_t t_h_i_offset = t_h_offset + ii;
                int64_t h_2d_i_offset = h_2d_offset + ii * h_stride;

                GGML_F32_VEC v_vec = GGML_F32_VEC_SET1(v[t_h_i_offset]);

                float sa = 0;
                {
                    GGML_F32_VEC sum[GGML_F32_ARR] = { GGML_F32_VEC_ZERO };
                    GGML_F32_VEC ax[GGML_F32_ARR];
                    GGML_F32_VEC ay[GGML_F32_ARR];
                    for (int64_t j = 0; j < head_size; j += GGML_F32_STEP) {
                        for (int64_t kk = 0; kk < GGML_F32_ARR; kk++) {
                            ax[kk]  = GGML_F32_VEC_LOAD(&a[t_h_offset + j + kk * GGML_F32_EPR]);
                            ay[kk]  = GGML_F32_VEC_LOAD(&state_prev[h_2d_i_offset + j + kk * GGML_F32_EPR]);
                            sum[kk] = GGML_F32_VEC_FMA(sum[kk], ax[kk], ay[kk]);
                        }
                    }
                    GGML_F32_VEC_REDUCE(sa, sum);
                }

                GGML_F32_VEC sa_vec = GGML_F32_VEC_SET1(sa);

                int64_t j = 0;
                GGML_F32_VEC result_vec[GGML_F32_ARR] = { GGML_F32_VEC_ZERO };
                for (; j < head_size; j += GGML_F32_STEP) {
                    for (int64_t kk = 0; kk < GGML_F32_ARR; kk++) {
                        int64_t t_h_j_offset = t_h_offset + j + kk * GGML_F32_EPR;
                        int64_t h_2d_i_j_offset = h_2d_i_offset + j + kk * GGML_F32_EPR;

                        GGML_F32_VEC r_vec = GGML_F32_VEC_LOAD(&r[t_h_j_offset]);
                        GGML_F32_VEC w_vec = GGML_F32_VEC_LOAD(&w[t_h_j_offset]);
                        GGML_F32_VEC k_vec = GGML_F32_VEC_LOAD(&k[t_h_j_offset]);
                        GGML_F32_VEC b_vec = GGML_F32_VEC_LOAD(&b[t_h_j_offset]);

                        k_vec = GGML_F32_VEC_MUL(v_vec, k_vec);

                        GGML_F32_VEC state_vec = GGML_F32_VEC_LOAD(&state_prev[h_2d_i_j_offset]);
                        // kv + s * decay + sa * b
                        state_vec = GGML_F32_VEC_FMA(k_vec, state_vec, w_vec);
                        state_vec = GGML_F32_VEC_FMA(state_vec, sa_vec, b_vec);
                        GGML_F32_VEC_STORE(&state_cur[h_2d_i_j_offset], state_vec);

                        result_vec[kk] = GGML_F32_VEC_FMA(result_vec[kk], state_vec, r_vec);
                    }
                }
                GGML_F32_VEC_REDUCE(dst_data[t_h_i_offset], result_vec);
                // scalar tail; there shouldn't be any leftovers, since head_size should be a multiple of GGML_F32_STEP
                for (; j < head_size; j++) {
                    int64_t t_h_j_offset = t_h_offset + j;
                    int64_t h_2d_i_j_offset = h_2d_i_offset + j;

                    float r_val = r[t_h_j_offset];
                    float w_val = w[t_h_j_offset];
                    float k_val = k[t_h_j_offset];
                    float b_val = b[t_h_j_offset];
                    float kv_val = v[t_h_i_offset] * k_val;

                    float prev_state_val = state_prev[h_2d_i_j_offset];
                    state_cur[h_2d_i_j_offset] = prev_state_val * w_val + kv_val + sa * b_val;
                    dst_data[t_h_i_offset] += state_cur[h_2d_i_j_offset] * r_val;
                }
            }
        }
    }
#endif
#else
    for (int64_t t = 0; t < T; t++) {
        int64_t t_offset = t * t_stride;
        int64_t state_offset = head_size * C * (t / (T / n_seqs));
        float * state_cur = state + state_offset;
        float * state_prev = t % (T / n_seqs) ? state_cur : (float*)dst->src[6]->data + state_offset;

        for (int64_t h = h_start; h < h_end; h++) {
            int64_t h_offset = h * h_stride;
            int64_t t_h_offset = t_offset + h_offset;
            int64_t h_2d_offset = h * h_stride_2d;

            for (int64_t i = 0; i < head_size; i++) {
                int64_t t_h_i_offset = t_h_offset + i;
                int64_t h_2d_i_offset = h_2d_offset + i * h_stride;

                float v_val = v[t_h_i_offset];

                float sa = 0, result = 0;
                for (int64_t j = 0; j < head_size; j++) {
                    sa += a[t_h_offset + j] * state_prev[h_2d_i_offset + j];
                }

                for (int64_t j = 0; j < head_size; j++) {
                    int64_t t_h_j_offset = t_h_offset + j;
                    int64_t h_2d_i_j_offset = h_2d_i_offset + j;

                    float r_val = r[t_h_j_offset];
                    float w_val = w[t_h_j_offset];
                    float k_val = k[t_h_j_offset];
                    float b_val = b[t_h_j_offset];
                    float kv_val = v_val * k_val;

                    float prev_state_val = state_prev[h_2d_i_j_offset];
                    state_cur[h_2d_i_j_offset] = prev_state_val * w_val + kv_val + sa * b_val;
                    result += state_cur[h_2d_i_j_offset] * r_val;
                }

                dst_data[t_h_i_offset] = result;
            }
        }
    }
#endif
}

void ggml_compute_forward_rwkv_wkv7(
        const ggml_compute_params * params,
        ggml_tensor * dst) {

    const ggml_tensor * src0 = dst->src[0];

    switch (src0->type) {
        case GGML_TYPE_F32:
            {
                ggml_compute_forward_rwkv_wkv7_f32(params, dst);
            } break;
        default:
            {
                GGML_ABORT("fatal error");
            }
    }
}

// ggml_compute_forward_map_custom1

void ggml_compute_forward_map_custom1(
        const ggml_compute_params * params,
        ggml_tensor * dst) {

    const ggml_tensor * a = dst->src[0];

    struct ggml_map_custom1_op_params p;
    memcpy(&p, dst->op_params, sizeof(p));

    p.fun(dst, a, params->ith, params->nth, p.userdata);
}

// ggml_compute_forward_map_custom2

void ggml_compute_forward_map_custom2(
        const ggml_compute_params * params,
        ggml_tensor * dst) {

    const ggml_tensor * a = dst->src[0];
    const ggml_tensor * b = dst->src[1];

    struct ggml_map_custom2_op_params p;
    memcpy(&p, dst->op_params, sizeof(p));

    p.fun(dst, a, b, params->ith, params->nth, p.userdata);
}

// ggml_compute_forward_map_custom3

void ggml_compute_forward_map_custom3(
        const ggml_compute_params * params,
        ggml_tensor * dst) {

    const ggml_tensor * a = dst->src[0];
    const ggml_tensor * b = dst->src[1];
    const ggml_tensor * c = dst->src[2];

    struct ggml_map_custom3_op_params p;
    memcpy(&p, dst->op_params, sizeof(p));

    p.fun(dst, a, b, c, params->ith, params->nth, p.userdata);
}

// ggml_compute_forward_custom

void ggml_compute_forward_custom(
        const struct ggml_compute_params * params,
        struct ggml_tensor * dst) {

    struct ggml_custom_op_params p;
    memcpy(&p, dst->op_params, sizeof(p));

    p.fun(dst, params->ith, params->nth, p.userdata);
}
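// Hedged usage sketch (documentation only, not used by this file): the map_custom/custom
// forward functions above simply unpack the callback stored in dst->op_params and invoke
// it with this thread's (ith, nth) pair, so the callback is expected to partition the work
// across threads itself. Assuming the public ggml_map_custom1() builder and ggml_custom1_op_t
// callback type from ggml.h, and a hypothetical user-side callback my_relu operating on
// contiguous F32 tensors, this could look roughly like:
//
//   static void my_relu(struct ggml_tensor * dst, const struct ggml_tensor * a,
//                       int ith, int nth, void * userdata) {
//       const int64_t n  = ggml_nelements(dst);
//       const int64_t dr = (n + nth - 1)/nth;   // elements per thread
//       const int64_t i0 = dr*ith;
//       const int64_t i1 = MIN(i0 + dr, n);
//       const float * src = (const float *) a->data;   // assumes contiguous F32
//       float       * out = (float       *) dst->data;
//       for (int64_t i = i0; i < i1; i++) {
//           out[i] = src[i] > 0.0f ? src[i] : 0.0f;
//       }
//   }
//
//   // ggml_tensor * y = ggml_map_custom1(ctx, x, my_relu, GGML_N_TASKS_MAX, /*userdata=*/NULL);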
// ggml_compute_forward_cross_entropy_loss
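// Documentation-only sketch of what the f32 kernel below computes. For logits s0 and
// target probabilities s1 (same shape, nr rows of nc classes):
//
//   dst = -(1/nr) * sum_over_rows( sum_c s1[c] * log_softmax(s0)[c] )
//
// Each thread accumulates its rows into sums[ith]; thread 0 reduces the partial sums
// and applies the -1/nr factor after the barrier.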
static void ggml_compute_forward_cross_entropy_loss_f32(
        const ggml_compute_params * params,
        ggml_tensor * dst) {
    const ggml_tensor * src0 = dst->src[0];
    const ggml_tensor * src1 = dst->src[1];

    GGML_ASSERT(src0->type == GGML_TYPE_F32);
    GGML_ASSERT(src1->type == GGML_TYPE_F32);
    GGML_ASSERT(src0->nb[0] == ggml_type_size(src0->type));
    GGML_ASSERT(src1->nb[0] == ggml_type_size(src1->type));
    GGML_ASSERT(ggml_are_same_shape(src0, src1));
    GGML_ASSERT(ggml_is_scalar(dst));
    GGML_ASSERT(dst->type == GGML_TYPE_F32);

    // TODO: handle transposed/permuted matrices
    const int64_t nc = src0->ne[0];
    const int64_t nr = ggml_nrows(src0);

    const int ith = params->ith;
    const int nth = params->nth;

    float * sums = (float *) params->wdata;
    float * st   = ((float *) params->wdata) + nth + ith*nc;
    float sum_thread = 0.0f;

    GGML_ASSERT(params->wsize >= sizeof(float) * (nth + nth * nc));

    // rows per thread
    const int64_t dr = (nr + nth - 1)/nth;

    // row range for this thread
    const int64_t ir0 = dr*ith;
    const int64_t ir1 = MIN(ir0 + dr, nr);

    for (int64_t i1 = ir0; i1 < ir1; ++i1) {
        const float * s0 = (const float *)((const char *) src0->data + i1*src0->nb[1]);
        const float * s1 = (const float *)((const char *) src1->data + i1*src1->nb[1]);

#ifndef NDEBUG
        for (int64_t i = 0; i < nc; ++i) {
            //printf("p[%d] = %f\n", i, p[i]);
            assert(!isnan(s0[i]));
            assert(!isnan(s1[i]));
        }
#endif

        float max = -INFINITY;
        ggml_vec_max_f32(nc, &max, s0);
        const ggml_float sum_softmax = ggml_vec_log_soft_max_f32(nc, st, s0, max);
        assert(sum_softmax >= 0.0);

        ggml_vec_add1_f32(nc, st, st, -sum_softmax);
        ggml_vec_mul_f32(nc, st, st, s1);

        float sum_st = 0.0f;
        ggml_vec_sum_f32(nc, &sum_st, st);
        sum_thread += sum_st;

#ifndef NDEBUG
        for (int64_t i = 0; i < nc; ++i) {
            assert(!isnan(st[i]));
            assert(!isinf(st[i]));
        }
#endif
    }
    sums[ith] = sum_thread;
    ggml_barrier(params->threadpool);

    if (ith == 0) {
        float * dp = (float *) dst->data;
        ggml_vec_sum_f32(nth, dp, sums);
        dp[0] *= -1.0f / (float) nr;
    }
}

void ggml_compute_forward_cross_entropy_loss(
        const ggml_compute_params * params,
        ggml_tensor * dst) {

    const ggml_tensor * src0 = dst->src[0];

    switch (src0->type) {
        case GGML_TYPE_F32:
            {
                ggml_compute_forward_cross_entropy_loss_f32(params, dst);
            } break;
        default:
            {
                GGML_ABORT("fatal error");
            }
    }
}

// ggml_compute_forward_cross_entropy_loss_back

static void ggml_compute_forward_cross_entropy_loss_back_f32(
        const ggml_compute_params * params,
        ggml_tensor * dst) {
    const ggml_tensor * grad  = dst->src[0]; // gradient of forward pass output
    const ggml_tensor * src0f = dst->src[1]; // src0 of forward pass
    const ggml_tensor * src1f = dst->src[2]; // src1 of forward pass

    GGML_ASSERT(ggml_is_contiguous(dst));
    GGML_ASSERT(ggml_is_contiguous(src0f));
    GGML_ASSERT(ggml_is_contiguous(src1f));
    GGML_ASSERT(ggml_is_contiguous(grad));
    GGML_ASSERT(ggml_are_same_shape(src0f, src1f) && ggml_are_same_shape(src0f, dst));

    const int64_t ith = params->ith;
    const int64_t nth = params->nth;

    // TODO: handle transposed/permuted matrices
    const int64_t nc = src0f->ne[0];
    const int64_t nr = ggml_nrows(src0f);

    // rows per thread
    const int64_t dr = (nr + nth - 1)/nth;

    // row range for this thread
    const int64_t ir0 = dr*ith;
    const int64_t ir1 = MIN(ir0 + dr, nr);

    const float d_by_nr = ((const float *) grad->data)[0] / (float) nr;

    for (int64_t i1 = ir0; i1 < ir1; i1++) {
        float       * ds0 = (float       *)((char       *) dst->data   + i1*dst->nb[1]);
        const float * s0  = (const float *)((const char *) src0f->data + i1*src0f->nb[1]);
        const float * s1  = (const float *)((const char *) src1f->data + i1*src1f->nb[1]);

#ifndef NDEBUG
        for (int64_t i = 0; i < nc; ++i) {
            //printf("p[%d] = %f\n", i, p[i]);
            assert(!isnan(s0[i]));
            assert(!isnan(s1[i]));
        }
#endif

        // soft_max
        float max = -INFINITY;
        ggml_vec_max_f32(nc, &max, s0);
        const ggml_float sum = ggml_vec_soft_max_f32(nc, ds0, s0, max);
        assert(sum > 0.0);
        ggml_vec_scale_f32(nc, ds0, 1.0/sum);

        // grad(src0f) = (softmax(src0f) - src1f) * grad(cross_entropy_loss(src0f, src1f)) / nr
        ggml_vec_sub_f32(nc, ds0, ds0, s1);
        ggml_vec_scale_f32(nc, ds0, d_by_nr);

#ifndef NDEBUG
        for (int64_t i = 0; i < nc; ++i) {
            assert(!isnan(ds0[i]));
            assert(!isinf(ds0[i]));
        }
#endif
    }
}

void ggml_compute_forward_cross_entropy_loss_back(
        const ggml_compute_params * params,
        ggml_tensor * dst) {

    const ggml_tensor * src0 = dst->src[0];

    switch (src0->type) {
        case GGML_TYPE_F32:
            {
                ggml_compute_forward_cross_entropy_loss_back_f32(params, dst);
            } break;
        default:
            {
                GGML_ABORT("fatal error");
            }
    }
}
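// Documentation-only sketch of the AdamW update performed below. adamw_params packs
// { alpha, beta1, beta2, eps, wd, beta1h, beta2h }; beta1h/beta2h are presumably the
// precomputed bias-correction factors 1/(1 - beta1^t) and 1/(1 - beta2^t) supplied by
// the optimizer driver. Per parameter:
//
//   m = beta1*m + (1 - beta1)*g
//   v = beta2*v + (1 - beta2)*g*g
//   w = w*(1 - alpha*wd) - alpha * (m*beta1h) / (sqrt(v*beta2h) + eps)
//
// i.e. decoupled weight decay applied to w directly, independent of m and v.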
static void ggml_compute_forward_opt_step_adamw_f32(
        const ggml_compute_params * params,
        ggml_tensor * dst) {
    const ggml_tensor * src0         = dst->src[0];
    const ggml_tensor * src0_grad    = dst->src[1];
    const ggml_tensor * src0_grad_m  = dst->src[2];
    const ggml_tensor * src0_grad_v  = dst->src[3];
    const ggml_tensor * adamw_params = dst->src[4];

    GGML_ASSERT(ggml_are_same_shape(src0, src0_grad));
    GGML_ASSERT(ggml_are_same_shape(src0, src0_grad_m));
    GGML_ASSERT(ggml_are_same_shape(src0, src0_grad_v));
    GGML_ASSERT(ggml_nelements(adamw_params) == 7);

    const int ith = params->ith;
    const int nth = params->nth;

    const int nr = ggml_nrows(src0);

    GGML_TENSOR_UNARY_OP_LOCALS
    GGML_ASSERT(nb00 == sizeof(float));

    // rows per thread
    const int dr = (nr + nth - 1)/nth;

    // row range for this thread
    const int ir0 = dr*ith;
    const int ir1 = MIN(ir0 + dr, nr);

    const float * adamw_params_ptr = ggml_get_data_f32(adamw_params);

    const float alpha  = adamw_params_ptr[0];
    const float beta1  = adamw_params_ptr[1];
    const float beta2  = adamw_params_ptr[2];
    const float eps    = adamw_params_ptr[3];
    const float wd     = adamw_params_ptr[4];
    const float beta1h = adamw_params_ptr[5];
    const float beta2h = adamw_params_ptr[6];
    const float keep   = 1.f - alpha * wd;

    for (int ir = ir0; ir < ir1; ++ir) {
        const int64_t i03 = ir/(ne02*ne01);
        const int64_t i02 = (ir - i03*ne02*ne01)/ne01;
        const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01);

        const size_t offset = i03*nb03 + i02*nb02 + i01*nb01;

        float       * w = (float       *) ((char       *) src0->data        + offset); // weight
        const float * g = (const float *) ((const char *) src0_grad->data   + offset); // grad
        float       * m = (float       *) ((char       *) src0_grad_m->data + offset);
        float       * v = (float       *) ((char       *) src0_grad_v->data + offset);

        for (int i00 = 0; i00 < ne00; ++i00) {
            m[i00] = m[i00]*beta1 +        g[i00]*(1.0f - beta1);
            v[i00] = v[i00]*beta2 + g[i00]*g[i00]*(1.0f - beta2);

            const float mh =       m[i00]*beta1h;
            const float vh = sqrtf(v[i00]*beta2h) + eps;

            // The weight decay is applied independently of the Adam momenta m and v.
            // This is NOT equivalent to L2 regularization, which would add w[i00]*w[i00] to the loss.
            // See: https://arxiv.org/pdf/1711.05101v3.pdf
            w[i00] = w[i00] * keep - alpha * mh / vh;
        }
    }
}

void ggml_compute_forward_opt_step_adamw(
        const ggml_compute_params * params,
        ggml_tensor * dst) {

    const ggml_tensor * src0 = dst->src[0];

    switch (src0->type) {
        case GGML_TYPE_F32:
            {
                ggml_compute_forward_opt_step_adamw_f32(params, dst);
            } break;
        default:
            {
                GGML_ABORT("fatal error");
            }
    }
}
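// Documentation-only sketch of the SGD step performed below. sgd_params packs
// { alpha, wd } (learning rate and weight-decay coefficient); per parameter:
//
//   w = w*(1 - alpha*wd) - alpha*g
//
// i.e. plain SGD with the same decoupled weight decay as the AdamW step above.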
static void ggml_compute_forward_opt_step_sgd_f32(const ggml_compute_params * params, ggml_tensor * dst) {
    const ggml_tensor * src0       = dst->src[0];
    const ggml_tensor * src0_grad  = dst->src[1];
    const ggml_tensor * sgd_params = dst->src[2];

    GGML_ASSERT(ggml_are_same_shape(src0, src0_grad));
    GGML_ASSERT(ggml_nelements(sgd_params) == 2);

    const int ith = params->ith;
    const int nth = params->nth;

    const int nr = ggml_nrows(src0);

    GGML_TENSOR_UNARY_OP_LOCALS
    GGML_ASSERT(nb00 == sizeof(float));

    // rows per thread
    const int dr = (nr + nth - 1) / nth;

    // row range for this thread
    const int ir0 = dr * ith;
    const int ir1 = MIN(ir0 + dr, nr);
    // uses the subset of the adamw params we care about (alpha, wd); could have a separate struct
    const float * sgd_params_ptr = ggml_get_data_f32(sgd_params);
    const float   alpha          = sgd_params_ptr[0];
    const float   keep           = 1.f - alpha * sgd_params_ptr[1];

    for (int ir = ir0; ir < ir1; ++ir) {
        const int64_t i03 = ir / (ne02 * ne01);
        const int64_t i02 = (ir - i03 * ne02 * ne01) / ne01;
        const int64_t i01 = (ir - i03 * ne02 * ne01 - i02 * ne01);

        const size_t offset = i03 * nb03 + i02 * nb02 + i01 * nb01;

        float       * w = (float       *) ((char       *) src0->data      + offset); // weight
        const float * g = (const float *) ((const char *) src0_grad->data + offset); // grad

        for (int i00 = 0; i00 < ne00; ++i00) {
            w[i00] = w[i00] * keep - alpha * g[i00];
        }
    }
}

void ggml_compute_forward_opt_step_sgd(const ggml_compute_params * params, ggml_tensor * dst) {
    const ggml_tensor * src0 = dst->src[0];

    switch (src0->type) {
        case GGML_TYPE_F32:
            {
                ggml_compute_forward_opt_step_sgd_f32(params, dst);
            }
            break;
        default:
            {
                GGML_ABORT("fatal error - sgd is F32 only");
            }
    }
}