// ops.cpp

#include "ops.h"

#include "ggml-cpu.h"
#include "ggml-impl.h"
#include "binary-ops.h"
#include "unary-ops.h"
#include "vec.h"

#include <float.h>

#if defined(_MSC_VER)
// disable "possible loss of data" to avoid hundreds of casts
// we should just be careful :)
#pragma warning(disable: 4244 4267)

// disable POSIX deprecation warnings
// these functions are never going away, anyway
#pragma warning(disable: 4996)

// unreachable code because of multiple instances of code after GGML_ABORT
#pragma warning(disable: 4702)
#endif

// ggml_compute_forward_dup
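//
// The ggml_compute_forward_dup_* helpers implement the CPU forward pass of the dup (tensor
// copy / convert) operation: src0 is copied into dst, converting between element types where
// necessary. Each variant handles one source type and dispatches internally on the destination
// type and memory layout; work is split across threads by rows of dimension 1 (or by blocks of
// elements in the same-type contiguous case below).
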
static void ggml_compute_forward_dup_same_cont(
        const ggml_compute_params * params,
        ggml_tensor * dst) {

    const ggml_tensor * src0 = dst->src[0];

    GGML_ASSERT(ggml_nelements(dst) == ggml_nelements(src0));
    GGML_ASSERT(ggml_is_contiguous(dst) && ggml_is_contiguous(src0));
    GGML_ASSERT(src0->type == dst->type);

    const size_t nb0 = ggml_type_size(src0->type);

    const int ith = params->ith; // thread index
    const int nth = params->nth; // number of threads

    // parallelize by blocks
    const int nk = ggml_nelements(src0)/ggml_blck_size(src0->type);
    const int dr = (nk + nth - 1) / nth;
    const int k0 = dr * ith;
    const int k1 = MIN(k0 + dr, nk);

    if (k0 < k1) {
        memcpy(
            ((char *) dst->data + k0*nb0),
            ((char *) src0->data + k0*nb0),
            (k1 - k0) * nb0);
    }
}
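
// Duplicate an F16 src0 into dst. When the types and row layout match, whole rows are copied
// with memcpy; for a contiguous dst, rows are converted to F16, F32, or (via from_float) a
// quantized type; otherwise a generic element-by-element fallback handles F16 and F32
// destinations with arbitrary strides.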
static void ggml_compute_forward_dup_f16(
        const ggml_compute_params * params,
        ggml_tensor * dst) {

    const ggml_tensor * src0 = dst->src[0];

    GGML_ASSERT(ggml_nelements(dst) == ggml_nelements(src0));

    GGML_TENSOR_UNARY_OP_LOCALS

    const int ith = params->ith; // thread index
    const int nth = params->nth; // number of threads

    // parallelize by rows
    const int nr = ne01;
    // number of rows per thread
    const int dr = (nr + nth - 1) / nth;
    // row range for this thread
    const int ir0 = dr * ith;
    const int ir1 = MIN(ir0 + dr, nr);

    if (src0->type == dst->type &&
        ne00 == ne0 &&
        nb00 == ggml_type_size(src0->type) && nb0 == ggml_type_size(dst->type)) {
        // copy by rows
        const size_t rs = ne00*nb00;
        for (int64_t i03 = 0; i03 < ne03; i03++) {
            for (int64_t i02 = 0; i02 < ne02; i02++) {
                for (int64_t i01 = ir0; i01 < ir1; i01++) {
                    memcpy(
                        ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3),
                        ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03),
                        rs);
                }
            }
        }
        return;
    }

    // TODO: add more special-case implementations for tensor shapes/strides that can benefit from memcpy

    if (ggml_is_contiguous(dst)) {
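        // Contiguous destination: rows are written at a flat offset `id`. Rows owned by other
        // threads are skipped by advancing `id` before and after the [ir0, ir1) row loop, so
        // every thread computes the same offsets without synchronization.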
        if (nb00 == sizeof(ggml_fp16_t)) {
            if (dst->type == GGML_TYPE_F16) {
                size_t id = 0;
                const size_t rs = ne00 * nb00;
                char * dst_ptr = (char *) dst->data;

                for (int i03 = 0; i03 < ne03; i03++) {
                    for (int i02 = 0; i02 < ne02; i02++) {
                        id += rs * ir0;
                        for (int i01 = ir0; i01 < ir1; i01++) {
                            const char * src0_ptr = (char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03;
                            memcpy(dst_ptr + id, src0_ptr, rs);
                            id += rs;
                        }
                        id += rs * (ne01 - ir1);
                    }
                }
            } else if (dst->type == GGML_TYPE_F32) {
                size_t id = 0;
                float * dst_ptr = (float *) dst->data;

                for (int i03 = 0; i03 < ne03; i03++) {
                    for (int i02 = 0; i02 < ne02; i02++) {
                        id += ne00 * ir0;
                        for (int i01 = ir0; i01 < ir1; i01++) {
                            const ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
                            for (int i00 = 0; i00 < ne00; i00++) {
                                dst_ptr[id] = GGML_FP16_TO_FP32(src0_ptr[i00]);
                                id++;
                            }
                        }
                        id += ne00 * (ne01 - ir1);
                    }
                }
            } else if (ggml_get_type_traits_cpu(dst->type)->from_float) {
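                // Converted (typically quantized) destination: each F16 source row is first
                // expanded to F32 in this thread's slice of the params->wdata scratch buffer,
                // then packed into dst with the destination type's from_float routine.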
                ggml_from_float_t const quantize_row_q = ggml_get_type_traits_cpu(dst->type)->from_float;
                float * src0_f32 = (float *) params->wdata + (ne00 + CACHE_LINE_SIZE_F32) * ith;

                size_t id = 0;
                size_t rs = nb0 * (ne00 / ggml_blck_size(dst->type));
                char * dst_ptr = (char *) dst->data;

                for (int i03 = 0; i03 < ne03; i03++) {
                    for (int i02 = 0; i02 < ne02; i02++) {
                        id += rs * ir0;
                        for (int i01 = ir0; i01 < ir1; i01++) {
                            const ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);

                            for (int i00 = 0; i00 < ne00; i00++) {
                                src0_f32[i00] = GGML_FP16_TO_FP32(src0_ptr[i00]);
                            }

                            quantize_row_q(src0_f32, dst_ptr + id, ne00);
                            id += rs;
                        }
                        id += rs * (ne01 - ir1);
                    }
                }
            } else {
                GGML_ABORT("fatal error"); // TODO: implement
            }
        } else {
            //printf("%s: this is not optimal - fix me\n", __func__);

            if (dst->type == GGML_TYPE_F32) {
                size_t id = 0;
                float * dst_ptr = (float *) dst->data;

                for (int i03 = 0; i03 < ne03; i03++) {
                    for (int i02 = 0; i02 < ne02; i02++) {
                        id += ne00 * ir0;
                        for (int i01 = ir0; i01 < ir1; i01++) {
                            for (int i00 = 0; i00 < ne00; i00++) {
                                const ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);

                                dst_ptr[id] = GGML_FP16_TO_FP32(*src0_ptr);
                                id++;
                            }
                        }
                        id += ne00 * (ne01 - ir1);
                    }
                }
            } else if (dst->type == GGML_TYPE_F16) {
                size_t id = 0;
                ggml_fp16_t * dst_ptr = (ggml_fp16_t *) dst->data;

                for (int i03 = 0; i03 < ne03; i03++) {
                    for (int i02 = 0; i02 < ne02; i02++) {
                        id += ne00 * ir0;
                        for (int i01 = ir0; i01 < ir1; i01++) {
                            for (int i00 = 0; i00 < ne00; i00++) {
                                const ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);

                                dst_ptr[id] = *src0_ptr;
                                id++;
                            }
                        }
                        id += ne00 * (ne01 - ir1);
                    }
                }
            } else {
                GGML_ABORT("fatal error"); // TODO: implement
            }
        }
        return;
    }
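
    // Non-contiguous destination: copy element by element, tracking the destination position
    // with the rolling indices i10..i13 over dims 0..3 of dst. The counters are first advanced
    // past the rows that belong to other threads, then wrapped (i10 -> i11 -> i12 -> i13) as
    // each element is written.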
    // dst counters
    int64_t i10 = 0;
    int64_t i11 = 0;
    int64_t i12 = 0;
    int64_t i13 = 0;

    if (dst->type == GGML_TYPE_F16) {
        for (int64_t i03 = 0; i03 < ne03; i03++) {
            for (int64_t i02 = 0; i02 < ne02; i02++) {
                i10 += ne00 * ir0;
                while (i10 >= ne0) {
                    i10 -= ne0;
                    if (++i11 == ne1) {
                        i11 = 0;
                        if (++i12 == ne2) {
                            i12 = 0;
                            if (++i13 == ne3) {
                                i13 = 0;
                            }
                        }
                    }
                }
                for (int64_t i01 = ir0; i01 < ir1; i01++) {
                    for (int64_t i00 = 0; i00 < ne00; i00++) {
                        const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
                        char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);

                        memcpy(dst_ptr, src0_ptr, sizeof(ggml_fp16_t));

                        if (++i10 == ne00) {
                            i10 = 0;
                            if (++i11 == ne01) {
                                i11 = 0;
                                if (++i12 == ne02) {
                                    i12 = 0;
                                    if (++i13 == ne03) {
                                        i13 = 0;
                                    }
                                }
                            }
                        }
                    }
                }
                i10 += ne00 * (ne01 - ir1);
                while (i10 >= ne0) {
                    i10 -= ne0;
                    if (++i11 == ne1) {
                        i11 = 0;
                        if (++i12 == ne2) {
                            i12 = 0;
                            if (++i13 == ne3) {
                                i13 = 0;
                            }
                        }
                    }
                }
            }
        }
    } else if (dst->type == GGML_TYPE_F32) {
        for (int64_t i03 = 0; i03 < ne03; i03++) {
            for (int64_t i02 = 0; i02 < ne02; i02++) {
                i10 += ne00 * ir0;
                while (i10 >= ne0) {
                    i10 -= ne0;
                    if (++i11 == ne1) {
                        i11 = 0;
                        if (++i12 == ne2) {
                            i12 = 0;
                            if (++i13 == ne3) {
                                i13 = 0;
                            }
                        }
                    }
                }
                for (int64_t i01 = ir0; i01 < ir1; i01++) {
                    for (int64_t i00 = 0; i00 < ne00; i00++) {
                        const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
                        char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);

                        *(float *) dst_ptr = GGML_FP16_TO_FP32(*(const ggml_fp16_t *) src0_ptr);

                        if (++i10 == ne0) {
                            i10 = 0;
                            if (++i11 == ne1) {
                                i11 = 0;
                                if (++i12 == ne2) {
                                    i12 = 0;
                                    if (++i13 == ne3) {
                                        i13 = 0;
                                    }
                                }
                            }
                        }
                    }
                }
                i10 += ne00 * (ne01 - ir1);
                while (i10 >= ne0) {
                    i10 -= ne0;
                    if (++i11 == ne1) {
                        i11 = 0;
                        if (++i12 == ne2) {
                            i12 = 0;
                            if (++i13 == ne3) {
                                i13 = 0;
                            }
                        }
                    }
                }
            }
        }
    } else {
        GGML_ABORT("fatal error"); // TODO: implement
    }
}
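
// Same as the F16 variant above, but for a BF16 source. A contiguous dst may be BF16, F16,
// F32, or any type with a from_float conversion; the strided fallback handles BF16, F16, and
// F32 destinations element by element.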
static void ggml_compute_forward_dup_bf16(
        const ggml_compute_params * params,
        ggml_tensor * dst) {

    const ggml_tensor * src0 = dst->src[0];

    GGML_ASSERT(ggml_nelements(dst) == ggml_nelements(src0));

    GGML_TENSOR_UNARY_OP_LOCALS

    const int ith = params->ith; // thread index
    const int nth = params->nth; // number of threads

    // parallelize by rows
    const int nr = ne01;
    // number of rows per thread
    const int dr = (nr + nth - 1) / nth;
    // row range for this thread
    const int ir0 = dr * ith;
    const int ir1 = MIN(ir0 + dr, nr);

    if (src0->type == dst->type &&
        ne00 == ne0 &&
        nb00 == ggml_type_size(src0->type) && nb0 == ggml_type_size(dst->type)) {
        // copy by rows
        const size_t rs = ne00*nb00;
        for (int64_t i03 = 0; i03 < ne03; i03++) {
            for (int64_t i02 = 0; i02 < ne02; i02++) {
                for (int64_t i01 = ir0; i01 < ir1; i01++) {
                    memcpy(
                        ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3),
                        ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03),
                        rs);
                }
            }
        }
        return;
    }

    // TODO: add more special-case implementations for tensor shapes/strides that can benefit from memcpy

    if (ggml_is_contiguous(dst)) {
        if (nb00 == sizeof(ggml_bf16_t)) {
            if (dst->type == GGML_TYPE_BF16) {
                size_t id = 0;
                const size_t rs = ne00 * nb00;
                char * dst_ptr = (char *) dst->data;

                for (int i03 = 0; i03 < ne03; i03++) {
                    for (int i02 = 0; i02 < ne02; i02++) {
                        id += rs * ir0;
                        for (int i01 = ir0; i01 < ir1; i01++) {
                            const char * src0_ptr = (char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03;
                            memcpy(dst_ptr + id, src0_ptr, rs);
                            id += rs;
                        }
                        id += rs * (ne01 - ir1);
                    }
                }
            } else if (dst->type == GGML_TYPE_F16) {
                size_t id = 0;
                ggml_fp16_t * dst_ptr = (ggml_fp16_t *) dst->data;

                for (int i03 = 0; i03 < ne03; i03++) {
                    for (int i02 = 0; i02 < ne02; i02++) {
                        id += ne00 * ir0;
                        for (int i01 = ir0; i01 < ir1; i01++) {
                            const ggml_bf16_t * src0_ptr = (ggml_bf16_t *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
                            for (int i00 = 0; i00 < ne00; i00++) {
                                dst_ptr[id] = GGML_FP32_TO_FP16(GGML_BF16_TO_FP32(src0_ptr[i00]));
                                id++;
                            }
                        }
                        id += ne00 * (ne01 - ir1);
                    }
                }
            } else if (dst->type == GGML_TYPE_F32) {
                size_t id = 0;
                float * dst_ptr = (float *) dst->data;

                for (int i03 = 0; i03 < ne03; i03++) {
                    for (int i02 = 0; i02 < ne02; i02++) {
                        id += ne00 * ir0;
                        for (int i01 = ir0; i01 < ir1; i01++) {
                            const ggml_bf16_t * src0_ptr = (ggml_bf16_t *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
                            for (int i00 = 0; i00 < ne00; i00++) {
                                dst_ptr[id] = GGML_BF16_TO_FP32(src0_ptr[i00]);
                                id++;
                            }
                        }
                        id += ne00 * (ne01 - ir1);
                    }
                }
            } else if (ggml_get_type_traits_cpu(dst->type)->from_float) {
                ggml_from_float_t const quantize_row_q = ggml_get_type_traits_cpu(dst->type)->from_float;
                float * src0_f32 = (float *) params->wdata + (ne00 + CACHE_LINE_SIZE_F32) * ith;

                size_t id = 0;
                size_t rs = nb0 * (ne00 / ggml_blck_size(dst->type));
                char * dst_ptr = (char *) dst->data;

                for (int i03 = 0; i03 < ne03; i03++) {
                    for (int i02 = 0; i02 < ne02; i02++) {
                        id += rs * ir0;
                        for (int i01 = ir0; i01 < ir1; i01++) {
                            const ggml_bf16_t * src0_ptr = (ggml_bf16_t *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);

                            for (int i00 = 0; i00 < ne00; i00++) {
                                src0_f32[i00] = GGML_BF16_TO_FP32(src0_ptr[i00]);
                            }

                            quantize_row_q(src0_f32, dst_ptr + id, ne00);
                            id += rs;
                        }
                        id += rs * (ne01 - ir1);
                    }
                }
            } else {
                GGML_ABORT("fatal error"); // TODO: implement
            }
        } else {
            //printf("%s: this is not optimal - fix me\n", __func__);

            if (dst->type == GGML_TYPE_F32) {
                size_t id = 0;
                float * dst_ptr = (float *) dst->data;

                for (int i03 = 0; i03 < ne03; i03++) {
                    for (int i02 = 0; i02 < ne02; i02++) {
                        id += ne00 * ir0;
                        for (int i01 = ir0; i01 < ir1; i01++) {
                            for (int i00 = 0; i00 < ne00; i00++) {
                                const ggml_bf16_t * src0_ptr = (ggml_bf16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);

                                dst_ptr[id] = GGML_BF16_TO_FP32(*src0_ptr);
                                id++;
                            }
                        }
                        id += ne00 * (ne01 - ir1);
                    }
                }
            } else if (dst->type == GGML_TYPE_BF16) {
                size_t id = 0;
                ggml_bf16_t * dst_ptr = (ggml_bf16_t *) dst->data;

                for (int i03 = 0; i03 < ne03; i03++) {
                    for (int i02 = 0; i02 < ne02; i02++) {
                        id += ne00 * ir0;
                        for (int i01 = ir0; i01 < ir1; i01++) {
                            for (int i00 = 0; i00 < ne00; i00++) {
                                const ggml_bf16_t * src0_ptr = (ggml_bf16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);

                                dst_ptr[id] = *src0_ptr;
                                id++;
                            }
                        }
                        id += ne00 * (ne01 - ir1);
                    }
                }
            } else if (dst->type == GGML_TYPE_F16) {
                size_t id = 0;
                ggml_fp16_t * dst_ptr = (ggml_fp16_t *) dst->data;

                for (int i03 = 0; i03 < ne03; i03++) {
                    for (int i02 = 0; i02 < ne02; i02++) {
                        id += ne00 * ir0;
                        for (int i01 = ir0; i01 < ir1; i01++) {
                            for (int i00 = 0; i00 < ne00; i00++) {
                                const ggml_bf16_t * src0_ptr = (ggml_bf16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);

                                dst_ptr[id] = GGML_FP32_TO_FP16(GGML_BF16_TO_FP32(*src0_ptr));
                                id++;
                            }
                        }
                        id += ne00 * (ne01 - ir1);
                    }
                }
            } else {
                GGML_ABORT("fatal error"); // TODO: implement
            }
        }
        return;
    }

    // dst counters
    int64_t i10 = 0;
    int64_t i11 = 0;
    int64_t i12 = 0;
    int64_t i13 = 0;

    if (dst->type == GGML_TYPE_BF16) {
        for (int64_t i03 = 0; i03 < ne03; i03++) {
            for (int64_t i02 = 0; i02 < ne02; i02++) {
                i10 += ne00 * ir0;
                while (i10 >= ne0) {
                    i10 -= ne0;
                    if (++i11 == ne1) {
                        i11 = 0;
                        if (++i12 == ne2) {
                            i12 = 0;
                            if (++i13 == ne3) {
                                i13 = 0;
                            }
                        }
                    }
                }
                for (int64_t i01 = ir0; i01 < ir1; i01++) {
                    for (int64_t i00 = 0; i00 < ne00; i00++) {
                        const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
                        char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);

                        memcpy(dst_ptr, src0_ptr, sizeof(ggml_bf16_t));

                        if (++i10 == ne00) {
                            i10 = 0;
                            if (++i11 == ne01) {
                                i11 = 0;
                                if (++i12 == ne02) {
                                    i12 = 0;
                                    if (++i13 == ne03) {
                                        i13 = 0;
                                    }
                                }
                            }
                        }
                    }
                }
                i10 += ne00 * (ne01 - ir1);
                while (i10 >= ne0) {
                    i10 -= ne0;
                    if (++i11 == ne1) {
                        i11 = 0;
                        if (++i12 == ne2) {
                            i12 = 0;
                            if (++i13 == ne3) {
                                i13 = 0;
                            }
                        }
                    }
                }
            }
        }
    } else if (dst->type == GGML_TYPE_F16) {
        for (int64_t i03 = 0; i03 < ne03; i03++) {
            for (int64_t i02 = 0; i02 < ne02; i02++) {
                i10 += ne00 * ir0;
                while (i10 >= ne0) {
                    i10 -= ne0;
                    if (++i11 == ne1) {
                        i11 = 0;
                        if (++i12 == ne2) {
                            i12 = 0;
                            if (++i13 == ne3) {
                                i13 = 0;
                            }
                        }
                    }
                }
                for (int64_t i01 = ir0; i01 < ir1; i01++) {
                    for (int64_t i00 = 0; i00 < ne00; i00++) {
                        const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
                        char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);

                        *(ggml_fp16_t *) dst_ptr = GGML_FP32_TO_FP16(GGML_BF16_TO_FP32(*(const ggml_bf16_t *) src0_ptr));

                        if (++i10 == ne0) {
                            i10 = 0;
                            if (++i11 == ne1) {
                                i11 = 0;
                                if (++i12 == ne2) {
                                    i12 = 0;
                                    if (++i13 == ne3) {
                                        i13 = 0;
                                    }
                                }
                            }
                        }
                    }
                }
                i10 += ne00 * (ne01 - ir1);
                while (i10 >= ne0) {
                    i10 -= ne0;
                    if (++i11 == ne1) {
                        i11 = 0;
                        if (++i12 == ne2) {
                            i12 = 0;
                            if (++i13 == ne3) {
                                i13 = 0;
                            }
                        }
                    }
                }
            }
        }
    } else if (dst->type == GGML_TYPE_F32) {
        for (int64_t i03 = 0; i03 < ne03; i03++) {
            for (int64_t i02 = 0; i02 < ne02; i02++) {
                i10 += ne00 * ir0;
                while (i10 >= ne0) {
                    i10 -= ne0;
                    if (++i11 == ne1) {
                        i11 = 0;
                        if (++i12 == ne2) {
                            i12 = 0;
                            if (++i13 == ne3) {
                                i13 = 0;
                            }
                        }
                    }
                }
                for (int64_t i01 = ir0; i01 < ir1; i01++) {
                    for (int64_t i00 = 0; i00 < ne00; i00++) {
                        const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
                        char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);

                        *(float *) dst_ptr = GGML_BF16_TO_FP32(*(const ggml_bf16_t *) src0_ptr);

                        if (++i10 == ne0) {
                            i10 = 0;
                            if (++i11 == ne1) {
                                i11 = 0;
                                if (++i12 == ne2) {
                                    i12 = 0;
                                    if (++i13 == ne3) {
                                        i13 = 0;
                                    }
                                }
                            }
                        }
                    }
                }
                i10 += ne00 * (ne01 - ir1);
                while (i10 >= ne0) {
                    i10 -= ne0;
                    if (++i11 == ne1) {
                        i11 = 0;
                        if (++i12 == ne2) {
                            i12 = 0;
                            if (++i13 == ne3) {
                                i13 = 0;
                            }
                        }
                    }
                }
            }
        }
    } else {
        GGML_ABORT("fatal error"); // TODO: implement
    }
}
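
// F32 source variant: a contiguous dst may be F32 (plain row copy) or any type with a
// from_float conversion; the element-by-element paths handle F32, F16, and BF16 destinations.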
static void ggml_compute_forward_dup_f32(
        const ggml_compute_params * params,
        ggml_tensor * dst) {

    const ggml_tensor * src0 = dst->src[0];

    GGML_ASSERT(ggml_nelements(dst) == ggml_nelements(src0));

    GGML_TENSOR_UNARY_OP_LOCALS

    const int ith = params->ith; // thread index
    const int nth = params->nth; // number of threads

    // parallelize by rows
    const int nr = ne01;
    // number of rows per thread
    const int dr = (nr + nth - 1) / nth;
    // row range for this thread
    const int ir0 = dr * ith;
    const int ir1 = MIN(ir0 + dr, nr);

    if (src0->type == dst->type &&
        ne00 == ne0 &&
        nb00 == ggml_type_size(src0->type) && nb0 == ggml_type_size(dst->type)) {
        // copy by rows
        const size_t rs = ne00*nb00;
        for (int64_t i03 = 0; i03 < ne03; i03++) {
            for (int64_t i02 = 0; i02 < ne02; i02++) {
                for (int64_t i01 = ir0; i01 < ir1; i01++) {
                    memcpy(
                        ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3),
                        ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03),
                        rs);
                }
            }
        }
        return;
    }

    if (ggml_is_contiguous(dst)) {
        // TODO: simplify
        if (nb00 == sizeof(float)) {
            if (dst->type == GGML_TYPE_F32) {
                size_t id = 0;
                const size_t rs = ne00 * nb00;
                char * dst_ptr = (char *) dst->data;

                for (int i03 = 0; i03 < ne03; i03++) {
                    for (int i02 = 0; i02 < ne02; i02++) {
                        id += rs * ir0;
                        for (int i01 = ir0; i01 < ir1; i01++) {
                            const char * src0_ptr = (char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03;
                            memcpy(dst_ptr + id, src0_ptr, rs);
                            id += rs;
                        }
                        id += rs * (ne01 - ir1);
                    }
                }
            } else if (ggml_get_type_traits_cpu(dst->type)->from_float) {
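                // src0 rows are already F32, so they are passed straight to from_float
                // without the per-thread wdata staging buffer used in the F16/BF16 variants.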
                ggml_from_float_t const quantize_row_q = ggml_get_type_traits_cpu(dst->type)->from_float;

                size_t id = 0;
                size_t rs = nb0 * (ne00 / ggml_blck_size(dst->type));
                char * dst_ptr = (char *) dst->data;

                for (int i03 = 0; i03 < ne03; i03++) {
                    for (int i02 = 0; i02 < ne02; i02++) {
                        id += rs * ir0;
                        for (int i01 = ir0; i01 < ir1; i01++) {
                            const float * src0_ptr = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
                            quantize_row_q(src0_ptr, dst_ptr + id, ne00);
                            id += rs;
                        }
                        id += rs * (ne01 - ir1);
                    }
                }
            } else {
                GGML_ABORT("fatal error"); // TODO: implement
            }
        } else {
            //printf("%s: this is not optimal - fix me\n", __func__);

            if (dst->type == GGML_TYPE_F32) {
                size_t id = 0;
                float * dst_ptr = (float *) dst->data;

                for (int i03 = 0; i03 < ne03; i03++) {
                    for (int i02 = 0; i02 < ne02; i02++) {
                        id += ne00 * ir0;
                        for (int i01 = ir0; i01 < ir1; i01++) {
                            for (int i00 = 0; i00 < ne00; i00++) {
                                const float * src0_ptr = (float *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);

                                dst_ptr[id] = *src0_ptr;
                                id++;
                            }
                        }
                        id += ne00 * (ne01 - ir1);
                    }
                }
            } else if (dst->type == GGML_TYPE_F16) {
                size_t id = 0;
                ggml_fp16_t * dst_ptr = (ggml_fp16_t *) dst->data;

                for (int i03 = 0; i03 < ne03; i03++) {
                    for (int i02 = 0; i02 < ne02; i02++) {
                        id += ne00 * ir0;
                        for (int i01 = ir0; i01 < ir1; i01++) {
                            for (int i00 = 0; i00 < ne00; i00++) {
                                const float * src0_ptr = (float *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);

                                dst_ptr[id] = GGML_FP32_TO_FP16(*src0_ptr);
                                id++;
                            }
                        }
                        id += ne00 * (ne01 - ir1);
                    }
                }
            } else if (dst->type == GGML_TYPE_BF16) {
                size_t id = 0;
                ggml_bf16_t * dst_ptr = (ggml_bf16_t *) dst->data;

                for (int i03 = 0; i03 < ne03; i03++) {
                    for (int i02 = 0; i02 < ne02; i02++) {
                        id += ne00 * ir0;
                        for (int i01 = ir0; i01 < ir1; i01++) {
                            for (int i00 = 0; i00 < ne00; i00++) {
                                const float * src0_ptr = (float *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);

                                dst_ptr[id] = GGML_FP32_TO_BF16(*src0_ptr);
                                id++;
                            }
                        }
                        id += ne00 * (ne01 - ir1);
                    }
                }
            } else {
                GGML_ABORT("fatal error"); // TODO: implement
            }
        }
        return;
    }

    // dst counters
    int64_t i10 = 0;
    int64_t i11 = 0;
    int64_t i12 = 0;
    int64_t i13 = 0;

    if (dst->type == GGML_TYPE_F32) {
        for (int64_t i03 = 0; i03 < ne03; i03++) {
            for (int64_t i02 = 0; i02 < ne02; i02++) {
                i10 += ne00 * ir0;
                while (i10 >= ne0) {
                    i10 -= ne0;
                    if (++i11 == ne1) {
                        i11 = 0;
                        if (++i12 == ne2) {
                            i12 = 0;
                            if (++i13 == ne3) {
                                i13 = 0;
                            }
                        }
                    }
                }
                for (int64_t i01 = ir0; i01 < ir1; i01++) {
                    for (int64_t i00 = 0; i00 < ne00; i00++) {
                        const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
                        char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);

                        memcpy(dst_ptr, src0_ptr, sizeof(float));

                        if (++i10 == ne0) {
                            i10 = 0;
                            if (++i11 == ne1) {
                                i11 = 0;
                                if (++i12 == ne2) {
                                    i12 = 0;
                                    if (++i13 == ne3) {
                                        i13 = 0;
                                    }
                                }
                            }
                        }
                    }
                }
                i10 += ne00 * (ne01 - ir1);
                while (i10 >= ne0) {
                    i10 -= ne0;
                    if (++i11 == ne1) {
                        i11 = 0;
                        if (++i12 == ne2) {
                            i12 = 0;
                            if (++i13 == ne3) {
                                i13 = 0;
                            }
                        }
                    }
                }
            }
        }
    } else if (dst->type == GGML_TYPE_F16) {
        for (int64_t i03 = 0; i03 < ne03; i03++) {
            for (int64_t i02 = 0; i02 < ne02; i02++) {
                i10 += ne00 * ir0;
                while (i10 >= ne0) {
                    i10 -= ne0;
                    if (++i11 == ne1) {
                        i11 = 0;
                        if (++i12 == ne2) {
                            i12 = 0;
                            if (++i13 == ne3) {
                                i13 = 0;
                            }
                        }
                    }
                }
                for (int64_t i01 = ir0; i01 < ir1; i01++) {
                    for (int64_t i00 = 0; i00 < ne00; i00++) {
                        const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
                        char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);

                        *(ggml_fp16_t *) dst_ptr = GGML_FP32_TO_FP16(*(const float *) src0_ptr);

                        if (++i10 == ne0) {
                            i10 = 0;
                            if (++i11 == ne1) {
                                i11 = 0;
                                if (++i12 == ne2) {
                                    i12 = 0;
                                    if (++i13 == ne3) {
                                        i13 = 0;
                                    }
                                }
                            }
                        }
                    }
                }
                i10 += ne00 * (ne01 - ir1);
                while (i10 >= ne0) {
                    i10 -= ne0;
                    if (++i11 == ne1) {
                        i11 = 0;
                        if (++i12 == ne2) {
                            i12 = 0;
                            if (++i13 == ne3) {
                                i13 = 0;
                            }
                        }
                    }
                }
            }
        }
    } else if (dst->type == GGML_TYPE_BF16) {
        for (int64_t i03 = 0; i03 < ne03; i03++) {
            for (int64_t i02 = 0; i02 < ne02; i02++) {
                i10 += ne00 * ir0;
                while (i10 >= ne0) {
                    i10 -= ne0;
                    if (++i11 == ne1) {
                        i11 = 0;
                        if (++i12 == ne2) {
                            i12 = 0;
                            if (++i13 == ne3) {
                                i13 = 0;
                            }
                        }
                    }
                }
                for (int64_t i01 = ir0; i01 < ir1; i01++) {
                    for (int64_t i00 = 0; i00 < ne00; i00++) {
                        const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
                        char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);

                        *(ggml_bf16_t *) dst_ptr = GGML_FP32_TO_BF16(*(const float *) src0_ptr);

                        if (++i10 == ne0) {
                            i10 = 0;
                            if (++i11 == ne1) {
                                i11 = 0;
                                if (++i12 == ne2) {
                                    i12 = 0;
                                    if (++i13 == ne3) {
                                        i13 = 0;
                                    }
                                }
                            }
                        }
                    }
                }
                i10 += ne00 * (ne01 - ir1);
                while (i10 >= ne0) {
                    i10 -= ne0;
                    if (++i11 == ne1) {
                        i11 = 0;
                        if (++i12 == ne2) {
                            i12 = 0;
                            if (++i13 == ne3) {
                                i13 = 0;
                            }
                        }
                    }
  875. }
  876. }
  877. }
  878. }
  879. } else {
  880. GGML_ABORT("fatal error"); // TODO: implement
  881. }
  882. }
// A simplified version of ggml_compute_forward_dup that does no float upcasting and just uses plain memcpy.
  884. static void ggml_compute_forward_dup_bytes(
  885. const ggml_compute_params * params,
  886. ggml_tensor * dst) {
  887. const ggml_tensor * src0 = dst->src[0];
  888. GGML_ASSERT(ggml_nelements(dst) == ggml_nelements(src0));
  889. GGML_ASSERT(src0->type == dst->type);
  890. GGML_TENSOR_UNARY_OP_LOCALS;
  891. if (ggml_is_contiguous(src0) && ggml_is_contiguous(dst)) {
  892. ggml_compute_forward_dup_same_cont(params, dst);
  893. return;
  894. }
  895. const size_t type_size = ggml_type_size(src0->type);
  896. const int ith = params->ith; // thread index
  897. const int nth = params->nth; // number of threads
  898. // parallelize by rows
  899. const int nr = ne01;
  900. // number of rows per thread
  901. const int dr = (nr + nth - 1) / nth;
  902. // row range for this thread
  903. const int ir0 = dr * ith;
  904. const int ir1 = MIN(ir0 + dr, nr);
  905. if (src0->type == dst->type &&
  906. ggml_are_same_shape(src0, dst) &&
  907. nb00 == type_size && nb0 == type_size) {
  908. // copy by rows
  909. const size_t rs = ggml_row_size(src0->type, ne00);
  910. for (int64_t i03 = 0; i03 < ne03; i03++) {
  911. for (int64_t i02 = 0; i02 < ne02; i02++) {
  912. for (int64_t i01 = ir0; i01 < ir1; i01++) {
  913. memcpy(
  914. ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3),
  915. ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03),
  916. rs);
  917. }
  918. }
  919. }
  920. return;
  921. }
  922. if (ggml_is_contiguous(dst)) {
  923. size_t id = 0;
  924. char * dst_ptr = (char *) dst->data;
  925. const size_t rs = ne00 * type_size;
  926. if (nb00 == type_size) {
// src0 is contiguous in the first dimension, copy by rows
  928. for (int64_t i03 = 0; i03 < ne03; i03++) {
  929. for (int64_t i02 = 0; i02 < ne02; i02++) {
  930. id += rs * ir0;
  931. for (int64_t i01 = ir0; i01 < ir1; i01++) {
  932. const char * src0_ptr = (char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03;
  933. memcpy(dst_ptr + id, src0_ptr, rs);
  934. id += rs;
  935. }
  936. id += rs * (ne01 - ir1);
  937. }
  938. }
  939. } else {
  940. //printf("%s: this is not optimal - fix me\n", __func__);
  941. for (int64_t i03 = 0; i03 < ne03; i03++) {
  942. for (int64_t i02 = 0; i02 < ne02; i02++) {
  943. id += rs * ir0;
  944. for (int64_t i01 = ir0; i01 < ir1; i01++) {
  945. for (int64_t i00 = 0; i00 < ne00; i00++) {
  946. const char * src0_ptr = (char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03;
  947. memcpy(dst_ptr + id, src0_ptr, type_size);
  948. id += type_size;
  949. }
  950. }
  951. id += rs * (ne01 - ir1);
  952. }
  953. }
  954. }
  955. return;
  956. }
// dst counters: k10 counts blocks (not elements) along dim 0; i11..i13 wrap like an odometer
  958. int64_t k10 = 0;
  959. int64_t i11 = 0;
  960. int64_t i12 = 0;
  961. int64_t i13 = 0;
  962. // number of blocks in a row
  963. const int64_t nk00 = ne00 / ggml_blck_size(src0->type);
  964. const int64_t nk0 = ne0 / ggml_blck_size(dst->type);
  965. for (int64_t i03 = 0; i03 < ne03; i03++) {
  966. for (int64_t i02 = 0; i02 < ne02; i02++) {
  967. k10 += nk00 * ir0;
  968. while (k10 >= nk0) {
  969. k10 -= nk0;
  970. if (++i11 == ne1) {
  971. i11 = 0;
  972. if (++i12 == ne2) {
  973. i12 = 0;
  974. if (++i13 == ne3) {
  975. i13 = 0;
  976. }
  977. }
  978. }
  979. }
  980. for (int64_t i01 = ir0; i01 < ir1; i01++) {
  981. for (int64_t k00 = 0; k00 < nk00; k00++) {
  982. const char * src0_ptr = ((char *) src0->data + k00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
  983. char * dst_ptr = ((char *) dst->data + k10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);
  984. memcpy(dst_ptr, src0_ptr, type_size);
  985. if (++k10 == nk0) {
  986. k10 = 0;
  987. if (++i11 == ne1) {
  988. i11 = 0;
  989. if (++i12 == ne2) {
  990. i12 = 0;
  991. if (++i13 == ne3) {
  992. i13 = 0;
  993. }
  994. }
  995. }
  996. }
  997. }
  998. }
  999. k10 += nk00 * (ne01 - ir1);
  1000. while (k10 >= nk0) {
  1001. k10 -= nk0;
  1002. if (++i11 == ne1) {
  1003. i11 = 0;
  1004. if (++i12 == ne2) {
  1005. i12 = 0;
  1006. if (++i13 == ne3) {
  1007. i13 = 0;
  1008. }
  1009. }
  1010. }
  1011. }
  1012. }
  1013. }
  1014. }
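// Copy a quantized src0 into a float dst, one quantization block (qk values) at a time.
// The flat element index i is decomposed into source coordinates as
//   i = i00 + i01*ne00 + i02*ne00*ne01 + i03*ne00*ne01*ne02
// and analogously into destination coordinates using ne10..ne12.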
  1015. static void ggml_compute_forward_dup_q(
  1016. const ggml_compute_params * params,
  1017. ggml_tensor * dst) {
  1018. const ggml_tensor * src0 = dst->src[0];
  1019. const ggml_tensor * src1 = dst->src[1];
  1020. GGML_TENSOR_BINARY_OP_LOCALS
  1021. const ggml_type type = src0->type;
  1022. ggml_to_float_t const dequantize_row_q = ggml_get_type_traits(type)->to_float;
  1023. size_t qk = ggml_blck_size(type);
  1024. const int64_t nr = ggml_nelements(src1) / qk;
  1025. // destination must be contiguous in the first dimension
  1026. GGML_ASSERT(nb10 == ggml_type_size(dst->type));
  1027. // must either have first dimension large enough to hold a row, or fully contiguous
  1028. GGML_ASSERT((ne10 % qk) == 0 || ggml_is_contiguous(dst));
  1029. const int ith = params->ith;
  1030. const int nth = params->nth;
  1031. const int dr = (nr + nth - 1)/nth;
  1032. // row range for this thread
  1033. const int ir0 = dr*ith;
  1034. const int ir1 = MIN(ir0 + dr, nr);
  1035. for (int64_t ir = ir0; ir < ir1; ++ir) {
const int64_t i = ir * qk; // use 64-bit here to avoid overflow for very large tensors
  1037. const int64_t i03 = i/(ne00 * ne01 * ne02);
  1038. const int64_t i02 = (i - i03*ne00*ne01*ne02 )/ (ne00*ne01);
  1039. const int64_t i01 = (i - i03*ne00*ne01*ne02 - i02*ne01*ne00) / ne00;
  1040. const int64_t i00 = i - i03*ne00*ne01*ne02 - i02*ne01*ne00 - i01*ne00;
  1041. const int64_t x_offset = (i00/qk)*nb00 + i01*nb01 + i02*nb02 + i03 * nb03;
  1042. const int64_t i13 = i/(ne10 * ne11 * ne12);
  1043. const int64_t i12 = (i - i13*ne10*ne11*ne12) / (ne10*ne11);
  1044. const int64_t i11 = (i - i13*ne10*ne11*ne12 - i12*ne10*ne11) / ne10;
  1045. const int64_t i10 = i - i13*ne10*ne11*ne12 - i12*ne10*ne11 - i11*ne10;
  1046. const int64_t dst_offset = i10*nb10 + i11*nb11 + i12*nb12 + i13*nb13;
  1047. dequantize_row_q(
  1048. (const void *) ((char *) src0->data + x_offset),
  1049. (float *) ((char *) dst->data + dst_offset), qk);
  1050. }
  1051. }
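// Dispatcher: same-type copies take the byte-wise path above; otherwise the copy is
// routed by the source type, with quantized -> F32 handled by ggml_compute_forward_dup_q.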
  1052. void ggml_compute_forward_dup(
  1053. const ggml_compute_params * params,
  1054. ggml_tensor * dst) {
  1055. const ggml_tensor * src0 = dst->src[0];
  1056. if (src0->type == dst->type) {
  1057. ggml_compute_forward_dup_bytes(params, dst);
  1058. return;
  1059. }
  1060. switch (src0->type) {
  1061. case GGML_TYPE_F16:
  1062. {
  1063. ggml_compute_forward_dup_f16(params, dst);
  1064. } break;
  1065. case GGML_TYPE_BF16:
  1066. {
  1067. ggml_compute_forward_dup_bf16(params, dst);
  1068. } break;
  1069. case GGML_TYPE_F32:
  1070. {
  1071. ggml_compute_forward_dup_f32(params, dst);
  1072. } break;
  1073. default:
  1074. {
  1075. if (ggml_is_quantized(src0->type) && dst->type == GGML_TYPE_F32) {
  1076. ggml_compute_forward_dup_q(params, dst);
  1077. break;
  1078. }
  1079. GGML_ABORT("fatal error");
  1080. }
  1081. }
  1082. }
  1083. // ggml_compute_forward_add
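// Add an F32 src1 to a quantized src0: each row of src0 is dequantized into a
// per-thread scratch buffer in params->wdata, src1 is accumulated into it, and the
// result is re-quantized into dst (or copied as raw floats when dst has no
// from_float conversion).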
  1084. static void ggml_compute_forward_add_q_f32(
  1085. const ggml_compute_params * params,
  1086. ggml_tensor * dst) {
  1087. const ggml_tensor * src0 = dst->src[0];
  1088. const ggml_tensor * src1 = dst->src[1];
  1089. GGML_ASSERT(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst));
  1090. const int nr = ggml_nrows(src0);
  1091. GGML_TENSOR_BINARY_OP_LOCALS
  1092. const int ith = params->ith;
  1093. const int nth = params->nth;
  1094. const ggml_type type = src0->type;
  1095. const ggml_type dtype = dst->type;
  1096. ggml_to_float_t const dequantize_row_q = ggml_get_type_traits(type)->to_float;
  1097. ggml_from_float_t const quantize_row_q = ggml_get_type_traits_cpu(dtype)->from_float;
  1098. // we don't support permuted src0 or src1
  1099. GGML_ASSERT(nb00 == ggml_type_size(type));
  1100. GGML_ASSERT(nb10 == sizeof(float));
  1101. // dst cannot be transposed or permuted
  1102. GGML_ASSERT(nb0 <= nb1);
  1103. GGML_ASSERT(nb1 <= nb2);
  1104. GGML_ASSERT(nb2 <= nb3);
  1105. GGML_ASSERT(ggml_is_quantized(src0->type));
  1106. GGML_ASSERT(src1->type == GGML_TYPE_F32);
  1107. // rows per thread
  1108. const int dr = (nr + nth - 1)/nth;
  1109. // row range for this thread
  1110. const int ir0 = dr*ith;
  1111. const int ir1 = MIN(ir0 + dr, nr);
  1112. float * wdata = (float *) params->wdata + (ne00 + CACHE_LINE_SIZE_F32) * ith;
  1113. for (int ir = ir0; ir < ir1; ++ir) {
  1114. // src0 indices
  1115. const int i03 = ir/(ne02*ne01);
  1116. const int i02 = (ir - i03*ne02*ne01)/ne01;
  1117. const int i01 = (ir - i03*ne02*ne01 - i02*ne01);
  1118. // src1 and dst are same shape as src0 => same indices
  1119. const int i13 = i03;
  1120. const int i12 = i02;
  1121. const int i11 = i01;
  1122. const int i3 = i03;
  1123. const int i2 = i02;
  1124. const int i1 = i01;
  1125. void * src0_row = (void *) ((char *) src0->data + (i01*nb01 + i02*nb02 + i03*nb03));
  1126. float * src1_row = (float *)((char *) src1->data + (i11*nb11 + i12*nb12 + i13*nb13));
  1127. void * dst_row = (void *) ((char *) dst->data + ( i1*nb1 + i2*nb2 + i3*nb3));
  1128. assert(ne00 % 32 == 0);
// dequantize row from src0 to temp buffer
  1130. dequantize_row_q(src0_row, wdata, ne00);
  1131. // add src1
  1132. ggml_vec_acc_f32(ne00, wdata, src1_row);
  1133. // quantize row to dst
  1134. if (quantize_row_q != NULL) {
  1135. quantize_row_q(wdata, dst_row, ne00);
  1136. } else {
  1137. memcpy(dst_row, wdata, ne0*nb0);
  1138. }
  1139. }
  1140. }
  1141. void ggml_compute_forward_add(
  1142. const ggml_compute_params * params,
  1143. ggml_tensor * dst) {
  1144. const ggml_tensor * src0 = dst->src[0];
  1145. switch (src0->type) {
  1146. case GGML_TYPE_F32:
  1147. case GGML_TYPE_F16:
  1148. case GGML_TYPE_BF16:
  1149. {
  1150. ggml_compute_forward_add_non_quantized(params, dst);
  1151. } break;
  1152. case GGML_TYPE_Q4_0:
  1153. case GGML_TYPE_Q4_1:
  1154. case GGML_TYPE_Q5_0:
  1155. case GGML_TYPE_Q5_1:
  1156. case GGML_TYPE_Q8_0:
  1157. case GGML_TYPE_Q2_K:
  1158. case GGML_TYPE_Q3_K:
  1159. case GGML_TYPE_Q4_K:
  1160. case GGML_TYPE_Q5_K:
  1161. case GGML_TYPE_Q6_K:
  1162. case GGML_TYPE_TQ1_0:
  1163. case GGML_TYPE_TQ2_0:
  1164. case GGML_TYPE_IQ2_XXS:
  1165. case GGML_TYPE_IQ2_XS:
  1166. case GGML_TYPE_IQ3_XXS:
  1167. case GGML_TYPE_IQ1_S:
  1168. case GGML_TYPE_IQ1_M:
  1169. case GGML_TYPE_IQ4_NL:
  1170. case GGML_TYPE_IQ4_XS:
  1171. case GGML_TYPE_IQ3_S:
  1172. case GGML_TYPE_IQ2_S:
  1173. {
  1174. ggml_compute_forward_add_q_f32(params, dst);
  1175. } break;
  1176. default:
  1177. {
  1178. GGML_ABORT("fatal error");
  1179. }
  1180. }
  1181. }
  1182. // ggml_compute_forward_add1
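// add1 adds the scalar stored in src1 to every element of src0; the variants below
// differ only in the element types of src0, src1 and dst.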
  1183. static void ggml_compute_forward_add1_f32(
  1184. const ggml_compute_params * params,
  1185. ggml_tensor * dst) {
  1186. const ggml_tensor * src0 = dst->src[0];
  1187. const ggml_tensor * src1 = dst->src[1];
  1188. GGML_ASSERT(ggml_are_same_shape(src0, dst));
  1189. GGML_ASSERT(ggml_is_scalar(src1));
  1190. const int ith = params->ith;
  1191. const int nth = params->nth;
  1192. const int nr = ggml_nrows(src0);
  1193. GGML_TENSOR_UNARY_OP_LOCALS
  1194. GGML_ASSERT( nb0 == sizeof(float));
  1195. GGML_ASSERT(nb00 == sizeof(float));
  1196. // rows per thread
  1197. const int dr = (nr + nth - 1)/nth;
  1198. // row range for this thread
  1199. const int ir0 = dr*ith;
  1200. const int ir1 = MIN(ir0 + dr, nr);
  1201. for (int ir = ir0; ir < ir1; ++ir) {
  1202. // src0 and dst are same shape => same indices
  1203. const int i3 = ir/(ne2*ne1);
  1204. const int i2 = (ir - i3*ne2*ne1)/ne1;
  1205. const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
  1206. #ifdef GGML_USE_ACCELERATE
  1207. GGML_UNUSED(ggml_vec_add1_f32);
  1208. vDSP_vadd(
  1209. (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01), 1,
  1210. (float *) ((char *) src1->data), 0,
  1211. (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 ), 1,
  1212. ne0);
  1213. #else
  1214. ggml_vec_add1_f32(ne0,
  1215. (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 ),
  1216. (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01),
  1217. *(float *) src1->data);
  1218. #endif
  1219. }
  1220. }
  1221. static void ggml_compute_forward_add1_f16_f32(
  1222. const ggml_compute_params * params,
  1223. ggml_tensor * dst) {
  1224. const ggml_tensor * src0 = dst->src[0];
  1225. const ggml_tensor * src1 = dst->src[1];
  1226. GGML_ASSERT(ggml_are_same_shape(src0, dst));
  1227. GGML_ASSERT(ggml_is_scalar(src1));
  1228. // scalar to add
  1229. const float v = *(float *) src1->data;
  1230. const int ith = params->ith;
  1231. const int nth = params->nth;
  1232. const int nr = ggml_nrows(src0);
  1233. GGML_TENSOR_UNARY_OP_LOCALS
  1234. GGML_ASSERT(src0->type == GGML_TYPE_F16);
  1235. GGML_ASSERT(src1->type == GGML_TYPE_F32);
  1236. GGML_ASSERT(dst->type == GGML_TYPE_F16);
  1237. GGML_ASSERT( nb0 == sizeof(ggml_fp16_t));
  1238. GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
  1239. // rows per thread
  1240. const int dr = (nr + nth - 1)/nth;
  1241. // row range for this thread
  1242. const int ir0 = dr*ith;
  1243. const int ir1 = MIN(ir0 + dr, nr);
  1244. for (int ir = ir0; ir < ir1; ++ir) {
  1245. // src0 and dst are same shape => same indices
  1246. const int i3 = ir/(ne2*ne1);
  1247. const int i2 = (ir - i3*ne2*ne1)/ne1;
  1248. const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
  1249. ggml_fp16_t * dst_ptr = (ggml_fp16_t *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 );
  1250. ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01);
  1251. for (int i = 0; i < ne0; i++) {
  1252. dst_ptr[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(src0_ptr[i]) + v);
  1253. }
  1254. }
  1255. }
  1256. static void ggml_compute_forward_add1_f16_f16(
  1257. const ggml_compute_params * params,
  1258. ggml_tensor * dst) {
  1259. const ggml_tensor * src0 = dst->src[0];
  1260. const ggml_tensor * src1 = dst->src[1];
  1261. GGML_ASSERT(ggml_are_same_shape(src0, dst));
  1262. GGML_ASSERT(ggml_is_scalar(src1));
  1263. // scalar to add
  1264. const float v = GGML_FP16_TO_FP32(*(ggml_fp16_t *) src1->data);
  1265. const int ith = params->ith;
  1266. const int nth = params->nth;
  1267. const int nr = ggml_nrows(src0);
  1268. GGML_TENSOR_UNARY_OP_LOCALS
  1269. GGML_ASSERT(src0->type == GGML_TYPE_F16);
  1270. GGML_ASSERT(src1->type == GGML_TYPE_F16);
  1271. GGML_ASSERT(dst->type == GGML_TYPE_F16);
  1272. GGML_ASSERT( nb0 == sizeof(ggml_fp16_t));
  1273. GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
  1274. // rows per thread
  1275. const int dr = (nr + nth - 1)/nth;
  1276. // row range for this thread
  1277. const int ir0 = dr*ith;
  1278. const int ir1 = MIN(ir0 + dr, nr);
  1279. for (int ir = ir0; ir < ir1; ++ir) {
  1280. // src0 and dst are same shape => same indices
  1281. const int i3 = ir/(ne2*ne1);
  1282. const int i2 = (ir - i3*ne2*ne1)/ne1;
  1283. const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
  1284. ggml_fp16_t * dst_ptr = (ggml_fp16_t *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 );
  1285. ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01);
  1286. for (int i = 0; i < ne0; i++) {
  1287. dst_ptr[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(src0_ptr[i]) + v);
  1288. }
  1289. }
  1290. }
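// Quantized variant of add1: dequantize each row into the per-thread scratch buffer,
// add the scalar, then re-quantize the row into dst (same quantized type as src0).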
  1291. static void ggml_compute_forward_add1_q_f32(
  1292. const ggml_compute_params * params,
  1293. ggml_tensor * dst) {
  1294. const ggml_tensor * src0 = dst->src[0];
  1295. const ggml_tensor * src1 = dst->src[1];
  1296. GGML_ASSERT(ggml_are_same_shape(src0, dst));
  1297. GGML_ASSERT(ggml_is_scalar(src1));
  1298. // scalar to add
  1299. const float v = *(float *) src1->data;
  1300. const int ith = params->ith;
  1301. const int nth = params->nth;
  1302. const int nr = ggml_nrows(src0);
  1303. GGML_TENSOR_UNARY_OP_LOCALS
  1304. const ggml_type type = src0->type;
  1305. ggml_to_float_t const dequantize_row_q = ggml_get_type_traits(type)->to_float;
  1306. ggml_from_float_t const quantize_row_q = ggml_get_type_traits_cpu(type)->from_float;
  1307. // we don't support permuted src0
  1308. GGML_ASSERT(nb00 == ggml_type_size(type));
  1309. // dst cannot be transposed or permuted
  1310. GGML_ASSERT(nb0 <= nb1);
  1311. GGML_ASSERT(nb1 <= nb2);
  1312. GGML_ASSERT(nb2 <= nb3);
  1313. GGML_ASSERT(ggml_is_quantized(src0->type));
  1314. GGML_ASSERT(dst->type == src0->type);
  1315. GGML_ASSERT(src1->type == GGML_TYPE_F32);
  1316. // rows per thread
  1317. const int dr = (nr + nth - 1)/nth;
  1318. // row range for this thread
  1319. const int ir0 = dr*ith;
  1320. const int ir1 = MIN(ir0 + dr, nr);
  1321. float * wdata = (float *) params->wdata + (ne0 + CACHE_LINE_SIZE_F32) * ith;
  1322. for (int ir = ir0; ir < ir1; ++ir) {
  1323. // src0 and dst are same shape => same indices
  1324. const int i3 = ir/(ne2*ne1);
  1325. const int i2 = (ir - i3*ne2*ne1)/ne1;
  1326. const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
  1327. void * src0_row = (void *) ((char *) src0->data + (i1*nb01 + i2*nb02 + i3*nb03));
void * dst_row = (void *) ((char *) dst->data + (i1*nb1 + i2*nb2 + i3*nb3));
  1329. assert(ne0 % 32 == 0);
// dequantize row from src0 to temp buffer
dequantize_row_q(src0_row, wdata, ne0);
// add the scalar value from src1
  1333. ggml_vec_acc1_f32(ne0, wdata, v);
  1334. // quantize row to dst
  1335. quantize_row_q(wdata, dst_row, ne0);
  1336. }
  1337. }
  1338. static void ggml_compute_forward_add1_bf16_f32(
  1339. const ggml_compute_params * params,
  1340. ggml_tensor * dst) {
  1341. const ggml_tensor * src0 = dst->src[0];
  1342. const ggml_tensor * src1 = dst->src[1];
  1343. GGML_ASSERT(ggml_are_same_shape(src0, dst));
  1344. GGML_ASSERT(ggml_is_scalar(src1));
  1345. // scalar to add
  1346. const float v = *(float *) src1->data;
  1347. const int ith = params->ith;
  1348. const int nth = params->nth;
  1349. const int nr = ggml_nrows(src0);
  1350. GGML_TENSOR_UNARY_OP_LOCALS
  1351. GGML_ASSERT(src0->type == GGML_TYPE_BF16);
  1352. GGML_ASSERT(src1->type == GGML_TYPE_F32);
  1353. GGML_ASSERT(dst->type == GGML_TYPE_BF16);
  1354. GGML_ASSERT( nb0 == sizeof(ggml_bf16_t));
  1355. GGML_ASSERT(nb00 == sizeof(ggml_bf16_t));
  1356. // rows per thread
  1357. const int dr = (nr + nth - 1)/nth;
  1358. // row range for this thread
  1359. const int ir0 = dr*ith;
  1360. const int ir1 = MIN(ir0 + dr, nr);
  1361. for (int ir = ir0; ir < ir1; ++ir) {
  1362. // src0 and dst are same shape => same indices
  1363. const int i3 = ir/(ne2*ne1);
  1364. const int i2 = (ir - i3*ne2*ne1)/ne1;
  1365. const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
  1366. ggml_bf16_t * dst_ptr = (ggml_bf16_t *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 );
  1367. ggml_bf16_t * src0_ptr = (ggml_bf16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01);
  1368. for (int i = 0; i < ne0; i++) {
  1369. dst_ptr[i] = GGML_FP32_TO_BF16(GGML_BF16_TO_FP32(src0_ptr[i]) + v);
  1370. }
  1371. }
  1372. }
  1373. static void ggml_compute_forward_add1_bf16_bf16(
  1374. const ggml_compute_params * params,
  1375. ggml_tensor * dst) {
  1376. const ggml_tensor * src0 = dst->src[0];
  1377. const ggml_tensor * src1 = dst->src[1];
  1378. GGML_ASSERT(ggml_are_same_shape(src0, dst));
  1379. GGML_ASSERT(ggml_is_scalar(src1));
  1380. // scalar to add
  1381. const float v = GGML_BF16_TO_FP32(*(ggml_bf16_t *) src1->data);
  1382. const int ith = params->ith;
  1383. const int nth = params->nth;
  1384. const int nr = ggml_nrows(src0);
  1385. GGML_TENSOR_UNARY_OP_LOCALS
  1386. GGML_ASSERT(src0->type == GGML_TYPE_BF16);
  1387. GGML_ASSERT(src1->type == GGML_TYPE_BF16);
  1388. GGML_ASSERT(dst->type == GGML_TYPE_BF16);
  1389. GGML_ASSERT( nb0 == sizeof(ggml_bf16_t));
  1390. GGML_ASSERT(nb00 == sizeof(ggml_bf16_t));
  1391. // rows per thread
  1392. const int dr = (nr + nth - 1)/nth;
  1393. // row range for this thread
  1394. const int ir0 = dr*ith;
  1395. const int ir1 = MIN(ir0 + dr, nr);
  1396. for (int ir = ir0; ir < ir1; ++ir) {
  1397. // src0 and dst are same shape => same indices
  1398. const int i3 = ir/(ne2*ne1);
  1399. const int i2 = (ir - i3*ne2*ne1)/ne1;
  1400. const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
  1401. ggml_bf16_t * dst_ptr = (ggml_bf16_t *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 );
  1402. ggml_bf16_t * src0_ptr = (ggml_bf16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01);
  1403. for (int i = 0; i < ne0; i++) {
  1404. dst_ptr[i] = GGML_FP32_TO_BF16(GGML_BF16_TO_FP32(src0_ptr[i]) + v);
  1405. }
  1406. }
  1407. }
  1408. void ggml_compute_forward_add1(
  1409. const ggml_compute_params * params,
  1410. ggml_tensor * dst) {
  1411. const ggml_tensor * src0 = dst->src[0];
  1412. const ggml_tensor * src1 = dst->src[1];
  1413. switch (src0->type) {
  1414. case GGML_TYPE_F32:
  1415. {
  1416. ggml_compute_forward_add1_f32(params, dst);
  1417. } break;
  1418. case GGML_TYPE_F16:
  1419. {
  1420. if (src1->type == GGML_TYPE_F16) {
  1421. ggml_compute_forward_add1_f16_f16(params, dst);
  1422. }
  1423. else if (src1->type == GGML_TYPE_F32) {
  1424. ggml_compute_forward_add1_f16_f32(params, dst);
  1425. }
  1426. else {
  1427. GGML_ABORT("fatal error");
  1428. }
  1429. } break;
  1430. case GGML_TYPE_BF16:
  1431. {
  1432. if (src1->type == GGML_TYPE_BF16) {
  1433. ggml_compute_forward_add1_bf16_bf16(params, dst);
  1434. }
  1435. else if (src1->type == GGML_TYPE_F32) {
  1436. ggml_compute_forward_add1_bf16_f32(params, dst);
  1437. }
  1438. else {
  1439. GGML_ABORT("fatal error");
  1440. }
  1441. } break;
  1442. case GGML_TYPE_Q4_0:
  1443. case GGML_TYPE_Q4_1:
  1444. case GGML_TYPE_Q5_0:
  1445. case GGML_TYPE_Q5_1:
  1446. case GGML_TYPE_Q8_0:
  1447. case GGML_TYPE_Q8_1:
  1448. case GGML_TYPE_Q2_K:
  1449. case GGML_TYPE_Q3_K:
  1450. case GGML_TYPE_Q4_K:
  1451. case GGML_TYPE_Q5_K:
  1452. case GGML_TYPE_Q6_K:
  1453. case GGML_TYPE_TQ1_0:
  1454. case GGML_TYPE_TQ2_0:
  1455. case GGML_TYPE_IQ2_XXS:
  1456. case GGML_TYPE_IQ2_XS:
  1457. case GGML_TYPE_IQ3_XXS:
  1458. case GGML_TYPE_IQ1_S:
  1459. case GGML_TYPE_IQ1_M:
  1460. case GGML_TYPE_IQ4_NL:
  1461. case GGML_TYPE_IQ4_XS:
  1462. case GGML_TYPE_IQ3_S:
  1463. case GGML_TYPE_IQ2_S:
  1464. {
  1465. ggml_compute_forward_add1_q_f32(params, dst);
  1466. } break;
  1467. default:
  1468. {
  1469. GGML_ABORT("fatal error");
  1470. }
  1471. }
  1472. }
  1473. // ggml_compute_forward_acc
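// acc accumulates src1 into a view of src0/dst described by the strides and byte
// offset stored in dst->op_params; unless the op is in-place, thread 0 first copies
// src0 into dst and all threads synchronize on a barrier before adding.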
  1474. static void ggml_compute_forward_acc_f32(
  1475. const ggml_compute_params * params,
  1476. ggml_tensor * dst) {
  1477. const ggml_tensor * src0 = dst->src[0];
  1478. const ggml_tensor * src1 = dst->src[1];
  1479. GGML_ASSERT(ggml_are_same_shape(src0, dst));
  1480. GGML_ASSERT(ggml_is_contiguous(dst) && ggml_is_contiguous(src0));
// view src0 and dst with these strides and data offset in bytes during acc
  1482. // nb0 is implicitly element_size because src0 and dst are contiguous
  1483. size_t nb1 = ((int32_t *) dst->op_params)[0];
  1484. size_t nb2 = ((int32_t *) dst->op_params)[1];
  1485. size_t nb3 = ((int32_t *) dst->op_params)[2];
  1486. size_t offset = ((int32_t *) dst->op_params)[3];
  1487. bool inplace = (bool) ((int32_t *) dst->op_params)[4];
  1488. if (!inplace) {
  1489. if (params->ith == 0) {
  1490. // memcpy needs to be synchronized across threads to avoid race conditions.
  1491. // => do it in INIT phase
  1492. memcpy(
  1493. ((char *) dst->data),
  1494. ((char *) src0->data),
  1495. ggml_nbytes(dst));
  1496. }
  1497. ggml_barrier(params->threadpool);
  1498. }
  1499. const int ith = params->ith;
  1500. const int nth = params->nth;
  1501. const int nr = ggml_nrows(src1);
  1502. const int nc = src1->ne[0];
  1503. GGML_TENSOR_LOCALS(int64_t, ne1, src1, ne)
  1504. GGML_TENSOR_LOCALS(size_t, nb1, src1, nb)
  1505. // src0 and dst as viewed during acc
  1506. const size_t nb0 = ggml_element_size(src0);
  1507. const size_t nb00 = nb0;
  1508. const size_t nb01 = nb1;
  1509. const size_t nb02 = nb2;
  1510. const size_t nb03 = nb3;
  1511. GGML_ASSERT(offset + (ne10 == 0 ? 0 : ne10-1)*nb0 + (ne11 == 0 ? 0 : ne11-1)*nb1 + (ne12 == 0 ? 0 : ne12-1)*nb2 + (ne13 == 0 ? 0 : ne13-1)*nb3 < ggml_nbytes(dst));
  1512. GGML_ASSERT(offset + (ne10 == 0 ? 0 : ne10-1)*nb00 + (ne11 == 0 ? 0 : ne11-1)*nb01 + (ne12 == 0 ? 0 : ne12-1)*nb02 + (ne13 == 0 ? 0 : ne13-1)*nb03 < ggml_nbytes(src0));
  1513. GGML_ASSERT(nb10 == sizeof(float));
  1514. // rows per thread
  1515. const int dr = (nr + nth - 1)/nth;
  1516. // row range for this thread
  1517. const int ir0 = dr*ith;
  1518. const int ir1 = MIN(ir0 + dr, nr);
  1519. for (int ir = ir0; ir < ir1; ++ir) {
  1520. // src0 and dst are viewed with shape of src1 and offset
  1521. // => same indices
  1522. const int i3 = ir/(ne12*ne11);
  1523. const int i2 = (ir - i3*ne12*ne11)/ne11;
  1524. const int i1 = (ir - i3*ne12*ne11 - i2*ne11);
  1525. #ifdef GGML_USE_ACCELERATE
  1526. vDSP_vadd(
  1527. (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + offset), 1,
  1528. (float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11), 1,
  1529. (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + offset), 1, nc);
  1530. #else
  1531. ggml_vec_add_f32(nc,
  1532. (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + offset),
  1533. (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + offset),
  1534. (float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11));
  1535. #endif
  1536. }
  1537. }
  1538. void ggml_compute_forward_acc(
  1539. const ggml_compute_params * params,
  1540. ggml_tensor * dst) {
  1541. const ggml_tensor * src0 = dst->src[0];
  1542. switch (src0->type) {
  1543. case GGML_TYPE_F32:
  1544. {
  1545. ggml_compute_forward_acc_f32(params, dst);
  1546. } break;
  1547. case GGML_TYPE_F16:
  1548. case GGML_TYPE_BF16:
  1549. case GGML_TYPE_Q4_0:
  1550. case GGML_TYPE_Q4_1:
  1551. case GGML_TYPE_Q5_0:
  1552. case GGML_TYPE_Q5_1:
  1553. case GGML_TYPE_Q8_0:
  1554. case GGML_TYPE_Q8_1:
  1555. case GGML_TYPE_Q2_K:
  1556. case GGML_TYPE_Q3_K:
  1557. case GGML_TYPE_Q4_K:
  1558. case GGML_TYPE_Q5_K:
  1559. case GGML_TYPE_Q6_K:
  1560. case GGML_TYPE_TQ1_0:
  1561. case GGML_TYPE_TQ2_0:
  1562. case GGML_TYPE_IQ2_XXS:
  1563. case GGML_TYPE_IQ2_XS:
  1564. case GGML_TYPE_IQ3_XXS:
  1565. case GGML_TYPE_IQ1_S:
  1566. case GGML_TYPE_IQ1_M:
  1567. case GGML_TYPE_IQ4_NL:
  1568. case GGML_TYPE_IQ4_XS:
  1569. case GGML_TYPE_IQ3_S:
  1570. case GGML_TYPE_IQ2_S:
  1571. default:
  1572. {
  1573. GGML_ABORT("fatal error");
  1574. }
  1575. }
  1576. }
  1577. // ggml_compute_forward_sum
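// sum reduces all elements of src0 to a scalar; only thread 0 does the work,
// accumulating one row at a time.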
  1578. static void ggml_compute_forward_sum_f32(
  1579. const ggml_compute_params * params,
  1580. ggml_tensor * dst) {
  1581. const ggml_tensor * src0 = dst->src[0];
  1582. if (params->ith != 0) {
  1583. return;
  1584. }
  1585. assert(ggml_is_scalar(dst));
  1586. assert(src0->nb[0] == sizeof(float));
  1587. GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne)
  1588. GGML_TENSOR_LOCALS(size_t, nb0, src0, nb)
  1589. ggml_float sum = 0;
  1590. ggml_float row_sum = 0;
  1591. for (int64_t i03 = 0; i03 < ne03; i03++) {
  1592. for (int64_t i02 = 0; i02 < ne02; i02++) {
  1593. for (int64_t i01 = 0; i01 < ne01; i01++) {
  1594. ggml_vec_sum_f32_ggf(ne00,
  1595. &row_sum,
  1596. (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03));
  1597. sum += row_sum;
  1598. }
  1599. }
  1600. }
  1601. ((float *) dst->data)[0] = sum;
  1602. }
  1603. static void ggml_compute_forward_sum_f16(
  1604. const ggml_compute_params * params,
  1605. ggml_tensor * dst) {
  1606. const ggml_tensor * src0 = dst->src[0];
  1607. if (params->ith != 0) {
  1608. return;
  1609. }
  1610. assert(ggml_is_scalar(dst));
  1611. assert(src0->nb[0] == sizeof(ggml_fp16_t));
  1612. GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne)
  1613. GGML_TENSOR_LOCALS(size_t, nb0, src0, nb)
  1614. float sum = 0;
  1615. float row_sum = 0;
  1616. for (int64_t i03 = 0; i03 < ne03; i03++) {
  1617. for (int64_t i02 = 0; i02 < ne02; i02++) {
  1618. for (int64_t i01 = 0; i01 < ne01; i01++) {
  1619. ggml_vec_sum_f16_ggf(ne00,
  1620. &row_sum,
  1621. (ggml_fp16_t *) ((char *) src0->data + i01 * nb01 + i02 * nb02 + i03 * nb03));
  1622. sum += row_sum;
  1623. }
  1624. }
  1625. }
  1626. ((ggml_fp16_t *) dst->data)[0] = GGML_FP32_TO_FP16(sum);
  1627. }
  1628. static void ggml_compute_forward_sum_bf16(
  1629. const ggml_compute_params * params,
  1630. ggml_tensor * dst) {
  1631. const ggml_tensor * src0 = dst->src[0];
  1632. if (params->ith != 0) {
  1633. return;
  1634. }
  1635. assert(ggml_is_scalar(dst));
  1636. assert(src0->nb[0] == sizeof(ggml_bf16_t));
  1637. GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne)
  1638. GGML_TENSOR_LOCALS(size_t, nb0, src0, nb)
  1639. float sum = 0;
  1640. float row_sum = 0;
  1641. for (int64_t i03 = 0; i03 < ne03; i03++) {
  1642. for (int64_t i02 = 0; i02 < ne02; i02++) {
  1643. for (int64_t i01 = 0; i01 < ne01; i01++) {
  1644. ggml_vec_sum_bf16_ggf(ne00,
  1645. &row_sum,
  1646. (ggml_bf16_t *) ((char *) src0->data + i01 * nb01 + i02 * nb02 + i03 * nb03));
  1647. sum += row_sum;
  1648. }
  1649. }
  1650. }
  1651. ((ggml_bf16_t *) dst->data)[0] = GGML_FP32_TO_BF16(sum);
  1652. }
  1653. void ggml_compute_forward_sum(
  1654. const ggml_compute_params * params,
  1655. ggml_tensor * dst) {
  1656. const ggml_tensor * src0 = dst->src[0];
  1657. switch (src0->type) {
  1658. case GGML_TYPE_F32:
  1659. {
  1660. ggml_compute_forward_sum_f32(params, dst);
  1661. } break;
  1662. case GGML_TYPE_F16:
  1663. {
  1664. ggml_compute_forward_sum_f16(params, dst);
  1665. } break;
  1666. case GGML_TYPE_BF16:
  1667. {
  1668. ggml_compute_forward_sum_bf16(params, dst);
  1669. } break;
  1670. default:
  1671. {
  1672. GGML_ABORT("fatal error");
  1673. }
  1674. }
  1675. }
  1676. // ggml_compute_forward_sum_rows
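// sum_rows reduces along the first dimension only: dst has ne0 == 1 and otherwise
// matches the shape of src0.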
  1677. static void ggml_compute_forward_sum_rows_f32(
  1678. const ggml_compute_params * params,
  1679. ggml_tensor * dst) {
  1680. const ggml_tensor * src0 = dst->src[0];
  1681. if (params->ith != 0) {
  1682. return;
  1683. }
  1684. GGML_ASSERT(src0->nb[0] == sizeof(float));
  1685. GGML_ASSERT(dst->nb[0] == sizeof(float));
  1686. GGML_TENSOR_UNARY_OP_LOCALS
  1687. GGML_ASSERT(ne0 == 1);
  1688. GGML_ASSERT(ne1 == ne01);
  1689. GGML_ASSERT(ne2 == ne02);
  1690. GGML_ASSERT(ne3 == ne03);
  1691. for (int64_t i3 = 0; i3 < ne03; i3++) {
  1692. for (int64_t i2 = 0; i2 < ne02; i2++) {
  1693. for (int64_t i1 = 0; i1 < ne01; i1++) {
  1694. float * src_row = (float *) ((char *) src0->data + i1*nb01 + i2*nb02 + i3*nb03);
  1695. float * dst_row = (float *) ((char *) dst->data + i1*nb1 + i2*nb2 + i3*nb3);
  1696. float row_sum = 0;
  1697. ggml_vec_sum_f32(ne00, &row_sum, src_row);
  1698. dst_row[0] = row_sum;
  1699. }
  1700. }
  1701. }
  1702. }
  1703. void ggml_compute_forward_sum_rows(
  1704. const ggml_compute_params * params,
  1705. ggml_tensor * dst) {
  1706. const ggml_tensor * src0 = dst->src[0];
  1707. switch (src0->type) {
  1708. case GGML_TYPE_F32:
  1709. {
  1710. ggml_compute_forward_sum_rows_f32(params, dst);
  1711. } break;
  1712. default:
  1713. {
  1714. GGML_ABORT("fatal error");
  1715. }
  1716. }
  1717. }
  1718. // ggml_compute_forward_mean
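// mean is sum_rows followed by a division by ne00 (the row length).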
  1719. static void ggml_compute_forward_mean_f32(
  1720. const ggml_compute_params * params,
  1721. ggml_tensor * dst) {
  1722. const ggml_tensor * src0 = dst->src[0];
  1723. if (params->ith != 0) {
  1724. return;
  1725. }
  1726. assert(src0->nb[0] == sizeof(float));
  1727. GGML_TENSOR_UNARY_OP_LOCALS
  1728. assert(ne0 == 1);
  1729. assert(ne1 == ne01);
  1730. assert(ne2 == ne02);
  1731. assert(ne3 == ne03);
  1732. GGML_UNUSED(ne0);
  1733. GGML_UNUSED(ne1);
  1734. GGML_UNUSED(ne2);
  1735. GGML_UNUSED(ne3);
  1736. for (int64_t i03 = 0; i03 < ne03; i03++) {
  1737. for (int64_t i02 = 0; i02 < ne02; i02++) {
  1738. for (int64_t i01 = 0; i01 < ne01; i01++) {
  1739. ggml_vec_sum_f32(ne00,
  1740. (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3),
  1741. (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03));
  1742. *(float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3) /= (float) ne00;
  1743. }
  1744. }
  1745. }
  1746. }
  1747. void ggml_compute_forward_mean(
  1748. const ggml_compute_params * params,
  1749. ggml_tensor * dst) {
  1750. const ggml_tensor * src0 = dst->src[0];
  1751. switch (src0->type) {
  1752. case GGML_TYPE_F32:
  1753. {
  1754. ggml_compute_forward_mean_f32(params, dst);
  1755. } break;
  1756. default:
  1757. {
  1758. GGML_ABORT("fatal error");
  1759. }
  1760. }
  1761. }
  1762. // ggml_compute_forward_argmax
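// argmax writes, for each row of src0, the int32 index of that row's maximum element.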
  1763. static void ggml_compute_forward_argmax_f32(
  1764. const ggml_compute_params * params,
  1765. ggml_tensor * dst) {
  1766. const ggml_tensor * src0 = dst->src[0];
  1767. if (params->ith != 0) {
  1768. return;
  1769. }
  1770. assert(src0->nb[0] == sizeof(float));
  1771. assert(dst->nb[0] == sizeof(float));
  1772. const int64_t ne00 = src0->ne[0];
  1773. const int64_t ne01 = src0->ne[1];
  1774. const size_t nb01 = src0->nb[1];
  1775. const size_t nb0 = dst->nb[0];
  1776. for (int64_t i1 = 0; i1 < ne01; i1++) {
  1777. float * src = (float *) ((char *) src0->data + i1*nb01);
  1778. int32_t * dst_ = (int32_t *) ((char *) dst->data + i1*nb0);
  1779. int v = 0;
  1780. ggml_vec_argmax_f32(ne00, &v, src);
  1781. dst_[0] = v;
  1782. }
  1783. }
  1784. void ggml_compute_forward_argmax(
  1785. const ggml_compute_params * params,
  1786. ggml_tensor * dst) {
  1787. const ggml_tensor * src0 = dst->src[0];
  1788. switch (src0->type) {
  1789. case GGML_TYPE_F32:
  1790. {
  1791. ggml_compute_forward_argmax_f32(params, dst);
  1792. } break;
  1793. default:
  1794. {
  1795. GGML_ABORT("fatal error");
  1796. }
  1797. }
  1798. }
  1799. // ggml_compute_forward_count_equal
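// count_equal counts matching elements of two I32 tensors of the same shape; threads
// other than 0 store their partial counts in params->wdata and thread 0 adds them to
// its own count after the barrier.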
  1800. static void ggml_compute_forward_count_equal_i32(
  1801. const ggml_compute_params * params,
  1802. ggml_tensor * dst) {
  1803. const ggml_tensor * src0 = dst->src[0];
  1804. const ggml_tensor * src1 = dst->src[1];
  1805. GGML_TENSOR_BINARY_OP_LOCALS;
  1806. GGML_ASSERT(src0->type == GGML_TYPE_I32);
  1807. GGML_ASSERT(src1->type == GGML_TYPE_I32);
  1808. GGML_ASSERT(ggml_are_same_shape(src0, src1));
  1809. GGML_ASSERT(ggml_is_scalar(dst));
  1810. GGML_ASSERT(dst->type == GGML_TYPE_I64);
  1811. const int64_t nr = ggml_nrows(src0);
  1812. const int ith = params->ith;
  1813. const int nth = params->nth;
  1814. int64_t * sums = (int64_t *) params->wdata;
  1815. int64_t sum_thread = 0;
  1816. // rows per thread
  1817. const int64_t dr = (nr + nth - 1)/nth;
  1818. // row range for this thread
  1819. const int64_t ir0 = dr*ith;
  1820. const int64_t ir1 = MIN(ir0 + dr, nr);
  1821. for (int64_t ir = ir0; ir < ir1; ++ir) {
const int64_t i03 = ir / (ne02*ne01);
const int64_t i02 = (ir - i03*ne02*ne01) / ne01;
const int64_t i01 = ir - i03*ne02*ne01 - i02*ne01;
  1825. const char * data0 = (const char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01;
  1826. const char * data1 = (const char *) src1->data + i03*nb13 + i02*nb12 + i01*nb11;
  1827. for (int64_t i00 = 0; i00 < ne00; ++i00) {
  1828. const int32_t val0 = *((const int32_t *) (data0 + i00*nb00));
  1829. const int32_t val1 = *((const int32_t *) (data1 + i00*nb10));
  1830. sum_thread += val0 == val1;
  1831. }
  1832. }
  1833. if (ith != 0) {
  1834. sums[ith] = sum_thread;
  1835. }
  1836. ggml_barrier(params->threadpool);
  1837. if (ith != 0) {
  1838. return;
  1839. }
  1840. for (int ith_other = 1; ith_other < nth; ++ith_other) {
  1841. sum_thread += sums[ith_other];
  1842. }
  1843. *((int64_t *) dst->data) = sum_thread;
  1844. }
  1845. void ggml_compute_forward_count_equal(
  1846. const ggml_compute_params * params,
  1847. ggml_tensor * dst) {
  1848. const ggml_tensor * src0 = dst->src[0];
  1849. switch (src0->type) {
  1850. case GGML_TYPE_I32:
  1851. {
  1852. ggml_compute_forward_count_equal_i32(params, dst);
  1853. } break;
  1854. default:
  1855. {
  1856. GGML_ABORT("fatal error");
  1857. }
  1858. }
  1859. }
  1860. // ggml_compute_forward_repeat
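// repeat tiles src0 across dst; the per-dimension repeat counts nr0..nr3 are the
// ratios ne{0..3}/ne0{0..3}, which ggml_can_repeat guarantees to be integers.
// e.g. repeating a [2,3,1,1] tensor into a [4,3,2,1] dst gives nr0=2, nr1=1, nr2=2, nr3=1.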
  1861. static void ggml_compute_forward_repeat_f32(
  1862. const ggml_compute_params * params,
  1863. ggml_tensor * dst) {
  1864. const ggml_tensor * src0 = dst->src[0];
  1865. if (params->ith != 0) {
  1866. return;
  1867. }
  1868. GGML_ASSERT(ggml_can_repeat(src0, dst));
  1869. GGML_TENSOR_UNARY_OP_LOCALS
  1870. // guaranteed to be an integer due to the check in ggml_can_repeat
  1871. const int nr0 = (int)(ne0/ne00);
  1872. const int nr1 = (int)(ne1/ne01);
  1873. const int nr2 = (int)(ne2/ne02);
  1874. const int nr3 = (int)(ne3/ne03);
  1875. // TODO: support for transposed / permuted tensors
  1876. GGML_ASSERT(nb0 == sizeof(float));
  1877. GGML_ASSERT(nb00 == sizeof(float));
  1878. // TODO: maybe this is not optimal?
  1879. for (int i3 = 0; i3 < nr3; i3++) {
  1880. for (int k3 = 0; k3 < ne03; k3++) {
  1881. for (int i2 = 0; i2 < nr2; i2++) {
  1882. for (int k2 = 0; k2 < ne02; k2++) {
  1883. for (int i1 = 0; i1 < nr1; i1++) {
  1884. for (int k1 = 0; k1 < ne01; k1++) {
  1885. for (int i0 = 0; i0 < nr0; i0++) {
  1886. ggml_vec_cpy_f32(ne00,
  1887. (float *) ((char *) dst->data + (i3*ne03 + k3)*nb3 + (i2*ne02 + k2)*nb2 + (i1*ne01 + k1)*nb1 + (i0*ne00)*nb0),
  1888. (float *) ((char *) src0->data + ( k3)*nb03 + ( k2)*nb02 + ( k1)*nb01));
  1889. }
  1890. }
  1891. }
  1892. }
  1893. }
  1894. }
  1895. }
  1896. }
  1897. static void ggml_compute_forward_repeat_f16(
  1898. const ggml_compute_params * params,
  1899. ggml_tensor * dst) {
  1900. const ggml_tensor * src0 = dst->src[0];
  1901. if (params->ith != 0) {
  1902. return;
  1903. }
  1904. GGML_ASSERT(ggml_can_repeat(src0, dst));
  1905. GGML_TENSOR_UNARY_OP_LOCALS
  1906. // guaranteed to be an integer due to the check in ggml_can_repeat
  1907. const int nr0 = (int)(ne0/ne00);
  1908. const int nr1 = (int)(ne1/ne01);
  1909. const int nr2 = (int)(ne2/ne02);
  1910. const int nr3 = (int)(ne3/ne03);
  1911. // TODO: support for transposed / permuted tensors
  1912. GGML_ASSERT(nb0 == sizeof(ggml_fp16_t));
  1913. GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
  1914. // TODO: maybe this is not optimal?
  1915. for (int i3 = 0; i3 < nr3; i3++) {
  1916. for (int k3 = 0; k3 < ne03; k3++) {
  1917. for (int i2 = 0; i2 < nr2; i2++) {
  1918. for (int k2 = 0; k2 < ne02; k2++) {
  1919. for (int i1 = 0; i1 < nr1; i1++) {
  1920. for (int k1 = 0; k1 < ne01; k1++) {
  1921. for (int i0 = 0; i0 < nr0; i0++) {
  1922. ggml_fp16_t * y = (ggml_fp16_t *) ((char *) dst->data + (i3*ne03 + k3)*nb3 + (i2*ne02 + k2)*nb2 + (i1*ne01 + k1)*nb1 + (i0*ne00)*nb0);
  1923. ggml_fp16_t * x = (ggml_fp16_t *) ((char *) src0->data + ( k3)*nb03 + ( k2)*nb02 + ( k1)*nb01);
  1924. // ggml_vec_cpy_f16(ne00, y, x)
  1925. for (int i = 0; i < ne00; ++i) {
  1926. y[i] = x[i];
  1927. }
  1928. }
  1929. }
  1930. }
  1931. }
  1932. }
  1933. }
  1934. }
  1935. }
  1936. void ggml_compute_forward_repeat(
  1937. const ggml_compute_params * params,
  1938. ggml_tensor * dst) {
  1939. const ggml_tensor * src0 = dst->src[0];
  1940. switch (src0->type) {
  1941. case GGML_TYPE_F16:
  1942. case GGML_TYPE_BF16:
  1943. case GGML_TYPE_I16:
  1944. {
  1945. ggml_compute_forward_repeat_f16(params, dst);
  1946. } break;
  1947. case GGML_TYPE_F32:
  1948. case GGML_TYPE_I32:
  1949. {
  1950. ggml_compute_forward_repeat_f32(params, dst);
  1951. } break;
  1952. default:
  1953. {
  1954. GGML_ABORT("fatal error");
  1955. }
  1956. }
  1957. }
  1958. // ggml_compute_forward_repeat_back
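// repeat_back is the reverse of repeat: dst is zeroed first and every repeated tile
// of src0 is then accumulated into it.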
  1959. static void ggml_compute_forward_repeat_back_f32(
  1960. const ggml_compute_params * params,
  1961. ggml_tensor * dst) {
  1962. const ggml_tensor * src0 = dst->src[0];
  1963. if (params->ith != 0) {
  1964. return;
  1965. }
  1966. GGML_ASSERT(ggml_can_repeat(dst, src0));
  1967. GGML_TENSOR_UNARY_OP_LOCALS
  1968. // guaranteed to be an integer due to the check in ggml_can_repeat
  1969. const int nr0 = (int)(ne00/ne0);
  1970. const int nr1 = (int)(ne01/ne1);
  1971. const int nr2 = (int)(ne02/ne2);
  1972. const int nr3 = (int)(ne03/ne3);
  1973. // TODO: support for transposed / permuted tensors
  1974. GGML_ASSERT(nb0 == sizeof(float));
  1975. GGML_ASSERT(nb00 == sizeof(float));
  1976. if (ggml_is_contiguous(dst)) {
  1977. ggml_vec_set_f32(ne0*ne1*ne2*ne3, (float *)dst->data, 0);
  1978. } else {
  1979. for (int k3 = 0; k3 < ne3; k3++) {
  1980. for (int k2 = 0; k2 < ne2; k2++) {
  1981. for (int k1 = 0; k1 < ne1; k1++) {
  1982. ggml_vec_set_f32(ne0,
  1983. (float *) ((char *) dst->data + k1*nb1 + k2*nb2 + k3*nb3),
  1984. 0);
  1985. }
  1986. }
  1987. }
  1988. }
  1989. // TODO: maybe this is not optimal?
  1990. for (int i3 = 0; i3 < nr3; i3++) {
  1991. for (int k3 = 0; k3 < ne3; k3++) {
  1992. for (int i2 = 0; i2 < nr2; i2++) {
  1993. for (int k2 = 0; k2 < ne2; k2++) {
  1994. for (int i1 = 0; i1 < nr1; i1++) {
  1995. for (int k1 = 0; k1 < ne1; k1++) {
  1996. for (int i0 = 0; i0 < nr0; i0++) {
  1997. ggml_vec_acc_f32(ne0,
  1998. (float *) ((char *) dst->data + ( k3)*nb3 + ( k2)*nb2 + ( k1)*nb1),
  1999. (float *) ((char *) src0->data + (i3*ne3 + k3)*nb03 + (i2*ne2 + k2)*nb02 + (i1*ne1 + k1)*nb01 + (i0*ne0)*nb00));
  2000. }
  2001. }
  2002. }
  2003. }
  2004. }
  2005. }
  2006. }
  2007. }
  2008. void ggml_compute_forward_repeat_back(
  2009. const ggml_compute_params * params,
  2010. ggml_tensor * dst) {
  2011. const ggml_tensor * src0 = dst->src[0];
  2012. switch (src0->type) {
  2013. case GGML_TYPE_F32:
  2014. {
  2015. ggml_compute_forward_repeat_back_f32(params, dst);
  2016. } break;
  2017. default:
  2018. {
  2019. GGML_ABORT("fatal error");
  2020. }
  2021. }
  2022. }
  2023. // ggml_compute_forward_concat
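// concat joins src0 and src1 along the dimension stored in op_params[0]; indices
// that fall inside src0's extents read from src0, all others read from src1 shifted
// back by o[dim] = src0->ne[dim].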
  2024. static void ggml_compute_forward_concat_any(
  2025. const ggml_compute_params * params,
  2026. ggml_tensor * dst) {
  2027. const ggml_tensor * src0 = dst->src[0];
  2028. const ggml_tensor * src1 = dst->src[1];
  2029. const size_t len = ggml_type_size(src0->type);
  2030. const int ith = params->ith;
  2031. const int nth = params->nth;
  2032. GGML_TENSOR_BINARY_OP_LOCALS
  2033. const int32_t dim = ggml_get_op_params_i32(dst, 0);
  2034. GGML_ASSERT(dim >= 0 && dim < 4);
  2035. int64_t o[4] = {0, 0, 0, 0};
  2036. o[dim] = src0->ne[dim];
  2037. const char * x;
// TODO: smarter multi-threading
  2039. for (int i3 = 0; i3 < ne3; i3++) {
  2040. for (int i2 = ith; i2 < ne2; i2 += nth) {
  2041. for (int i1 = 0; i1 < ne1; i1++) {
  2042. for (int i0 = 0; i0 < ne0; i0++) {
  2043. if (i0 < ne00 && i1 < ne01 && i2 < ne02 && i3 < ne03) {
  2044. x = (const char *)src0->data + (i0 )*nb00 + (i1 )*nb01 + (i2 )*nb02 + (i3 )*nb03;
  2045. } else {
  2046. x = (const char *)src1->data + (i0 - o[0])*nb10 + (i1 - o[1])*nb11 + (i2 - o[2])*nb12 + (i3 - o[3])*nb13;
  2047. }
  2048. char * y = (char *)dst->data + i0*nb0 + i1*nb1 + i2*nb2 + i3*nb3;
  2049. memcpy(y, x, len);
  2050. }
  2051. }
  2052. }
  2053. }
  2054. }
  2055. static void ggml_compute_forward_concat_i8(
  2056. const ggml_compute_params * params,
  2057. ggml_tensor * dst) {
  2058. const ggml_tensor * src0 = dst->src[0];
  2059. const ggml_tensor * src1 = dst->src[1];
  2060. GGML_ASSERT(ggml_type_size(src0->type) == sizeof(int8_t));
  2061. const int ith = params->ith;
  2062. const int nth = params->nth;
  2063. GGML_TENSOR_BINARY_OP_LOCALS
  2064. const int32_t dim = ggml_get_op_params_i32(dst, 0);
  2065. GGML_ASSERT(dim >= 0 && dim < 4);
  2066. int64_t o[4] = {0, 0, 0, 0};
  2067. o[dim] = src0->ne[dim];
  2068. const int8_t * x;
// TODO: smarter multi-threading
  2070. for (int i3 = 0; i3 < ne3; i3++) {
  2071. for (int i2 = ith; i2 < ne2; i2 += nth) {
  2072. for (int i1 = 0; i1 < ne1; i1++) {
  2073. for (int i0 = 0; i0 < ne0; i0++) {
  2074. if (i0 < ne00 && i1 < ne01 && i2 < ne02 && i3 < ne03) {
  2075. x = (const int8_t *) ((const char *)src0->data + (i0 )*nb00 + (i1 )*nb01 + (i2 )*nb02 + (i3 )*nb03);
  2076. } else {
  2077. x = (const int8_t *) ((const char *)src1->data + (i0 - o[0])*nb10 + (i1 - o[1])*nb11 + (i2 - o[2])*nb12 + (i3 - o[3])*nb13);
  2078. }
  2079. int8_t * y = (int8_t *)((char *)dst->data + i0*nb0 + i1*nb1 + i2*nb2 + i3*nb3);
  2080. *y = *x;
  2081. }
  2082. }
  2083. }
  2084. }
  2085. }
  2086. static void ggml_compute_forward_concat_f16(
  2087. const ggml_compute_params * params,
  2088. ggml_tensor * dst) {
  2089. const ggml_tensor * src0 = dst->src[0];
  2090. const ggml_tensor * src1 = dst->src[1];
  2091. GGML_ASSERT(ggml_type_size(src0->type) == sizeof(ggml_fp16_t));
  2092. const int ith = params->ith;
  2093. const int nth = params->nth;
  2094. GGML_TENSOR_BINARY_OP_LOCALS
  2095. const int32_t dim = ggml_get_op_params_i32(dst, 0);
  2096. GGML_ASSERT(dim >= 0 && dim < 4);
  2097. int64_t o[4] = {0, 0, 0, 0};
  2098. o[dim] = src0->ne[dim];
  2099. const ggml_fp16_t * x;
// TODO: smarter multi-threading
  2101. for (int i3 = 0; i3 < ne3; i3++) {
  2102. for (int i2 = ith; i2 < ne2; i2 += nth) {
  2103. for (int i1 = 0; i1 < ne1; i1++) {
  2104. for (int i0 = 0; i0 < ne0; i0++) {
  2105. if (i0 < ne00 && i1 < ne01 && i2 < ne02 && i3 < ne03) {
  2106. x = (const ggml_fp16_t *) ((const char *)src0->data + (i0 )*nb00 + (i1 )*nb01 + (i2 )*nb02 + (i3 )*nb03);
  2107. } else {
  2108. x = (const ggml_fp16_t *) ((const char *)src1->data + (i0 - o[0])*nb10 + (i1 - o[1])*nb11 + (i2 - o[2])*nb12 + (i3 - o[3])*nb13);
  2109. }
  2110. ggml_fp16_t * y = (ggml_fp16_t *)((char *)dst->data + i0*nb0 + i1*nb1 + i2*nb2 + i3*nb3);
  2111. *y = *x;
  2112. }
  2113. }
  2114. }
  2115. }
  2116. }
  2117. static void ggml_compute_forward_concat_f32(
  2118. const ggml_compute_params * params,
  2119. ggml_tensor * dst) {
  2120. const ggml_tensor * src0 = dst->src[0];
  2121. const ggml_tensor * src1 = dst->src[1];
  2122. GGML_ASSERT(ggml_type_size(src0->type) == sizeof(float));
  2123. const int ith = params->ith;
  2124. const int nth = params->nth;
  2125. GGML_TENSOR_BINARY_OP_LOCALS
  2126. const int32_t dim = ggml_get_op_params_i32(dst, 0);
  2127. GGML_ASSERT(dim >= 0 && dim < 4);
  2128. int64_t o[4] = {0, 0, 0, 0};
  2129. o[dim] = src0->ne[dim];
  2130. const float * x;
// TODO: smarter multi-threading
  2132. for (int i3 = 0; i3 < ne3; i3++) {
  2133. for (int i2 = ith; i2 < ne2; i2 += nth) {
  2134. for (int i1 = 0; i1 < ne1; i1++) {
  2135. for (int i0 = 0; i0 < ne0; i0++) {
  2136. if (i0 < ne00 && i1 < ne01 && i2 < ne02 && i3 < ne03) {
  2137. x = (const float *) ((const char *)src0->data + (i0 )*nb00 + (i1 )*nb01 + (i2 )*nb02 + (i3 )*nb03);
  2138. } else {
  2139. x = (const float *) ((const char *)src1->data + (i0 - o[0])*nb10 + (i1 - o[1])*nb11 + (i2 - o[2])*nb12 + (i3 - o[3])*nb13);
  2140. }
  2141. float * y = (float *)((char *)dst->data + i0*nb0 + i1*nb1 + i2*nb2 + i3*nb3);
  2142. *y = *x;
  2143. }
  2144. }
  2145. }
  2146. }
  2147. }
  2148. void ggml_compute_forward_concat(
  2149. const ggml_compute_params * params,
  2150. ggml_tensor * dst) {
  2151. const ggml_tensor * src0 = dst->src[0];
  2152. switch (src0->type) {
  2153. case GGML_TYPE_F16:
  2154. case GGML_TYPE_BF16:
  2155. case GGML_TYPE_I16:
  2156. {
  2157. ggml_compute_forward_concat_f16(params, dst);
  2158. } break;
  2159. case GGML_TYPE_I8:
  2160. {
  2161. ggml_compute_forward_concat_i8(params, dst);
  2162. } break;
  2163. case GGML_TYPE_F32:
  2164. case GGML_TYPE_I32:
  2165. {
  2166. ggml_compute_forward_concat_f32(params, dst);
  2167. } break;
  2168. default:
  2169. {
  2170. ggml_compute_forward_concat_any(params, dst);
  2171. }
  2172. }
  2173. }
  2174. // ggml_compute_forward_gelu
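// Row-wise GELU, parallelized over rows; the NDEBUG-only loop checks the output for NaN/Inf.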
  2175. static void ggml_compute_forward_gelu_f32(
  2176. const ggml_compute_params * params,
  2177. ggml_tensor * dst) {
  2178. const ggml_tensor * src0 = dst->src[0];
  2179. assert(ggml_is_contiguous_1(src0));
  2180. assert(ggml_is_contiguous_1(dst));
  2181. assert(ggml_are_same_shape(src0, dst));
  2182. const int ith = params->ith;
  2183. const int nth = params->nth;
  2184. const int nc = src0->ne[0];
  2185. const int nr = ggml_nrows(src0);
  2186. // rows per thread
  2187. const int dr = (nr + nth - 1)/nth;
  2188. // row range for this thread
  2189. const int ir0 = dr*ith;
  2190. const int ir1 = MIN(ir0 + dr, nr);
  2191. for (int i1 = ir0; i1 < ir1; i1++) {
  2192. ggml_vec_gelu_f32(nc,
  2193. (float *) ((char *) dst->data + i1*( dst->nb[1])),
  2194. (float *) ((char *) src0->data + i1*(src0->nb[1])));
  2195. #ifndef NDEBUG
  2196. for (int k = 0; k < nc; k++) {
  2197. const float x = ((float *) ((char *) dst->data + i1*( dst->nb[1])))[k];
  2198. GGML_UNUSED(x);
  2199. assert(!isnan(x));
  2200. assert(!isinf(x));
  2201. }
  2202. #endif
  2203. }
  2204. }
  2205. static void ggml_compute_forward_gelu_f16(
  2206. const ggml_compute_params * params,
  2207. ggml_tensor * dst) {
  2208. const ggml_tensor * src0 = dst->src[0];
  2209. assert(ggml_is_contiguous_1(src0));
  2210. assert(ggml_is_contiguous_1(dst));
  2211. assert(ggml_are_same_shape(src0, dst));
  2212. const int ith = params->ith;
  2213. const int nth = params->nth;
  2214. const int nc = src0->ne[0];
  2215. const int nr = ggml_nrows(src0);
  2216. // rows per thread
  2217. const int dr = (nr + nth - 1)/nth;
  2218. // row range for this thread
  2219. const int ir0 = dr*ith;
  2220. const int ir1 = MIN(ir0 + dr, nr);
  2221. for (int i1 = ir0; i1 < ir1; i1++) {
  2222. ggml_vec_gelu_f16(nc,
  2223. (ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])),
  2224. (ggml_fp16_t *) ((char *) src0->data + i1*(src0->nb[1])));
  2225. #ifndef NDEBUG
  2226. for (int k = 0; k < nc; k++) {
  2227. const ggml_fp16_t x = ((ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])))[k];
  2228. const float v = GGML_FP16_TO_FP32(x);
  2229. GGML_UNUSED(v);
  2230. assert(!isnan(v));
  2231. assert(!isinf(v));
  2232. }
  2233. #endif
  2234. }
  2235. }
  2236. static void ggml_compute_forward_gelu(
  2237. const ggml_compute_params * params,
  2238. ggml_tensor * dst) {
  2239. const ggml_tensor * src0 = dst->src[0];
  2240. switch (src0->type) {
  2241. case GGML_TYPE_F32:
  2242. {
  2243. ggml_compute_forward_gelu_f32(params, dst);
  2244. } break;
  2245. case GGML_TYPE_F16:
  2246. {
  2247. ggml_compute_forward_gelu_f16(params, dst);
  2248. } break;
  2249. default:
  2250. {
  2251. GGML_ABORT("fatal error");
  2252. }
  2253. }
  2254. }
  2255. // ggml_compute_forward_gelu_quick
  2256. static void ggml_compute_forward_gelu_quick_f32(
  2257. const ggml_compute_params * params,
  2258. ggml_tensor * dst) {
  2259. const ggml_tensor * src0 = dst->src[0];
  2260. assert(ggml_is_contiguous_1(src0));
  2261. assert(ggml_is_contiguous_1(dst));
  2262. assert(ggml_are_same_shape(src0, dst));
  2263. const int ith = params->ith;
  2264. const int nth = params->nth;
  2265. const int nc = src0->ne[0];
  2266. const int nr = ggml_nrows(src0);
  2267. // rows per thread
  2268. const int dr = (nr + nth - 1)/nth;
  2269. // row range for this thread
  2270. const int ir0 = dr*ith;
  2271. const int ir1 = MIN(ir0 + dr, nr);
  2272. for (int i1 = ir0; i1 < ir1; i1++) {
  2273. ggml_vec_gelu_quick_f32(nc,
  2274. (float *) ((char *) dst->data + i1*( dst->nb[1])),
  2275. (float *) ((char *) src0->data + i1*(src0->nb[1])));
  2276. #ifndef NDEBUG
  2277. for (int k = 0; k < nc; k++) {
  2278. const float x = ((float *) ((char *) dst->data + i1*( dst->nb[1])))[k];
  2279. GGML_UNUSED(x);
  2280. assert(!isnan(x));
  2281. assert(!isinf(x));
  2282. }
  2283. #endif
  2284. }
  2285. }
  2286. static void ggml_compute_forward_gelu_quick_f16(
  2287. const ggml_compute_params * params,
  2288. ggml_tensor * dst) {
  2289. const ggml_tensor * src0 = dst->src[0];
  2290. assert(ggml_is_contiguous_1(src0));
  2291. assert(ggml_is_contiguous_1(dst));
  2292. assert(ggml_are_same_shape(src0, dst));
  2293. const int ith = params->ith;
  2294. const int nth = params->nth;
  2295. const int nc = src0->ne[0];
  2296. const int nr = ggml_nrows(src0);
  2297. // rows per thread
  2298. const int dr = (nr + nth - 1)/nth;
  2299. // row range for this thread
  2300. const int ir0 = dr*ith;
  2301. const int ir1 = MIN(ir0 + dr, nr);
  2302. for (int i1 = ir0; i1 < ir1; i1++) {
  2303. ggml_vec_gelu_quick_f16(nc,
  2304. (ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])),
  2305. (ggml_fp16_t *) ((char *) src0->data + i1*(src0->nb[1])));
  2306. #ifndef NDEBUG
  2307. for (int k = 0; k < nc; k++) {
  2308. const ggml_fp16_t x = ((ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])))[k];
  2309. const float v = GGML_FP16_TO_FP32(x);
  2310. GGML_UNUSED(v);
  2311. assert(!isnan(v));
  2312. assert(!isinf(v));
  2313. }
  2314. #endif
  2315. }
  2316. }
  2317. static void ggml_compute_forward_gelu_quick(
  2318. const ggml_compute_params * params,
  2319. ggml_tensor * dst) {
  2320. const ggml_tensor * src0 = dst->src[0];
  2321. switch (src0->type) {
  2322. case GGML_TYPE_F32:
  2323. {
  2324. ggml_compute_forward_gelu_quick_f32(params, dst);
  2325. } break;
  2326. case GGML_TYPE_F16:
  2327. {
  2328. ggml_compute_forward_gelu_quick_f16(params, dst);
  2329. } break;
  2330. default:
  2331. {
  2332. GGML_ABORT("fatal error");
  2333. }
  2334. }
  2335. }
  2336. // ggml_compute_forward_silu
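// SiLU (a.k.a. swish): silu(x) = x * sigmoid(x) = x / (1 + exp(-x)).
// The per-row loop below has the same structure as the GELU kernels above; only the
// vectorized primitive (ggml_vec_silu_*) changes.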
  2337. static void ggml_compute_forward_silu_f32(
  2338. const ggml_compute_params * params,
  2339. ggml_tensor * dst) {
  2340. const ggml_tensor * src0 = dst->src[0];
  2341. assert(ggml_is_contiguous_1(src0));
  2342. assert(ggml_is_contiguous_1(dst));
  2343. assert(ggml_are_same_shape(src0, dst));
  2344. const int ith = params->ith;
  2345. const int nth = params->nth;
  2346. const int nc = src0->ne[0];
  2347. const int nr = ggml_nrows(src0);
  2348. // rows per thread
  2349. const int dr = (nr + nth - 1)/nth;
  2350. // row range for this thread
  2351. const int ir0 = dr*ith;
  2352. const int ir1 = MIN(ir0 + dr, nr);
  2353. for (int i1 = ir0; i1 < ir1; i1++) {
  2354. ggml_vec_silu_f32(nc,
  2355. (float *) ((char *) dst->data + i1*( dst->nb[1])),
  2356. (float *) ((char *) src0->data + i1*(src0->nb[1])));
  2357. #ifndef NDEBUG
  2358. for (int k = 0; k < nc; k++) {
  2359. const float x = ((float *) ((char *) dst->data + i1*(dst->nb[1])))[k];
  2360. GGML_UNUSED(x);
  2361. assert(!isnan(x));
  2362. assert(!isinf(x));
  2363. }
  2364. #endif
  2365. }
  2366. }
  2367. static void ggml_compute_forward_silu_f16(
  2368. const ggml_compute_params * params,
  2369. ggml_tensor * dst) {
  2370. const ggml_tensor * src0 = dst->src[0];
  2371. assert(ggml_is_contiguous_1(src0));
  2372. assert(ggml_is_contiguous_1(dst));
  2373. assert(ggml_are_same_shape(src0, dst));
  2374. const int ith = params->ith;
  2375. const int nth = params->nth;
  2376. const int nc = src0->ne[0];
  2377. const int nr = ggml_nrows(src0);
  2378. // rows per thread
  2379. const int dr = (nr + nth - 1)/nth;
  2380. // row range for this thread
  2381. const int ir0 = dr*ith;
  2382. const int ir1 = MIN(ir0 + dr, nr);
  2383. for (int i1 = ir0; i1 < ir1; i1++) {
  2384. ggml_vec_silu_f16(nc,
  2385. (ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])),
  2386. (ggml_fp16_t *) ((char *) src0->data + i1*(src0->nb[1])));
  2387. #ifndef NDEBUG
  2388. for (int k = 0; k < nc; k++) {
  2389. const ggml_fp16_t x = ((ggml_fp16_t *) ((char *) dst->data + i1*(dst->nb[1])))[k];
  2390. const float v = GGML_FP16_TO_FP32(x);
  2391. GGML_UNUSED(v);
  2392. assert(!isnan(v));
  2393. assert(!isinf(v));
  2394. }
  2395. #endif
  2396. }
  2397. }
  2398. static void ggml_compute_forward_silu(
  2399. const ggml_compute_params * params,
  2400. ggml_tensor * dst) {
  2401. const ggml_tensor * src0 = dst->src[0];
  2402. switch (src0->type) {
  2403. case GGML_TYPE_F32:
  2404. {
  2405. ggml_compute_forward_silu_f32(params, dst);
  2406. } break;
  2407. case GGML_TYPE_F16:
  2408. {
  2409. ggml_compute_forward_silu_f16(params, dst);
  2410. } break;
  2411. default:
  2412. {
  2413. GGML_ABORT("fatal error");
  2414. }
  2415. }
  2416. }
  2417. // ggml_compute_forward_leaky_relu
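// leaky_relu(x) = x for x > 0, negative_slope * x otherwise.
// The slope is passed through op_params as a raw float, hence the memcpy below.
// Unlike the kernels above, this op runs on a single thread (ith != 0 returns early).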
  2418. static void ggml_compute_forward_leaky_relu_f32(
  2419. const ggml_compute_params * params,
  2420. ggml_tensor * dst) {
  2421. const ggml_tensor * src0 = dst->src[0];
  2422. if (params->ith != 0) {
  2423. return;
  2424. }
  2425. assert(ggml_is_contiguous_1(src0));
  2426. assert(ggml_is_contiguous_1(dst));
  2427. assert(ggml_are_same_shape(src0, dst));
  2428. const int n = ggml_nrows(src0);
  2429. const int nc = src0->ne[0];
  2430. float negative_slope;
  2431. memcpy(&negative_slope, dst->op_params, sizeof(float));
  2432. assert(dst->nb[0] == sizeof(float));
  2433. assert(src0->nb[0] == sizeof(float));
  2434. for (int i = 0; i < n; i++) {
  2435. ggml_vec_leaky_relu_f32(nc,
  2436. (float *) ((char *) dst->data + i*( dst->nb[1])),
  2437. (float *) ((char *) src0->data + i*(src0->nb[1])), negative_slope);
  2438. }
  2439. }
  2440. static void ggml_compute_forward_leaky_relu_f16(
  2441. const ggml_compute_params * params,
  2442. ggml_tensor * dst) {
  2443. const ggml_tensor * src0 = dst->src[0];
  2444. if (params->ith != 0) {
  2445. return;
  2446. }
  2447. assert(ggml_is_contiguous_1(src0));
  2448. assert(ggml_is_contiguous_1(dst));
  2449. assert(ggml_are_same_shape(src0, dst));
  2450. const int n = ggml_nrows(src0);
  2451. const int nc = src0->ne[0];
  2452. float negative_slope;
  2453. memcpy(&negative_slope, dst->op_params, sizeof(float));
  2454. assert(dst->nb[0] == sizeof(ggml_fp16_t));
  2455. assert(src0->nb[0] == sizeof(ggml_fp16_t));
  2456. for (int i = 0; i < n; i++) {
  2457. ggml_vec_leaky_relu_f16(nc,
  2458. (ggml_fp16_t *) ((char *) dst->data + i*( dst->nb[1])),
  2459. (ggml_fp16_t *) ((char *) src0->data + i*(src0->nb[1])), negative_slope);
  2460. }
  2461. }
  2462. void ggml_compute_forward_leaky_relu(
  2463. const ggml_compute_params * params,
  2464. ggml_tensor * dst) {
  2465. const ggml_tensor * src0 = dst->src[0];
  2466. switch (src0->type) {
  2467. case GGML_TYPE_F32:
  2468. {
  2469. ggml_compute_forward_leaky_relu_f32(params, dst);
  2470. } break;
  2471. case GGML_TYPE_F16:
  2472. {
  2473. ggml_compute_forward_leaky_relu_f16(params, dst);
  2474. } break;
  2475. default:
  2476. {
  2477. GGML_ABORT("fatal error");
  2478. }
  2479. }
  2480. }
  2481. // ggml_compute_forward_silu_back
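// Backward of SiLU: with s = sigmoid(x), d/dx [x*s] = s * (1 + x * (1 - s)),
// so dst = grad * s * (1 + x * (1 - s)). src[0] carries the incoming gradient and
// src[1] the forward input; the exact elementwise form is in ggml_vec_silu_backward_*.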
  2482. static void ggml_compute_forward_silu_back_f32(
  2483. const ggml_compute_params * params,
  2484. ggml_tensor * dst) {
  2485. const ggml_tensor * grad = dst->src[0];
  2486. const ggml_tensor * src1 = dst->src[1];
  2487. assert(ggml_is_contiguous_1(grad));
  2488. assert(ggml_is_contiguous_1(src1));
  2489. assert(ggml_is_contiguous_1(dst));
  2490. assert(ggml_are_same_shape(src1, dst));
  2491. assert(ggml_are_same_shape(src1, grad));
  2492. const int ith = params->ith;
  2493. const int nth = params->nth;
  2494. const int nc = src1->ne[0];
  2495. const int nr = ggml_nrows(src1);
  2496. // rows per thread
  2497. const int dr = (nr + nth - 1)/nth;
  2498. // row range for this thread
  2499. const int ir0 = dr*ith;
  2500. const int ir1 = MIN(ir0 + dr, nr);
  2501. for (int i1 = ir0; i1 < ir1; i1++) {
  2502. ggml_vec_silu_backward_f32(nc,
  2503. (float *) ((char *) dst->data + i1*( dst->nb[1])),
  2504. (float *) ((char *) src1->data + i1*(src1->nb[1])),
  2505. (float *) ((char *) grad->data + i1*(grad->nb[1])));
  2506. #ifndef NDEBUG
  2507. for (int k = 0; k < nc; k++) {
  2508. const float x = ((float *) ((char *) dst->data + i1*( dst->nb[1])))[k];
  2509. GGML_UNUSED(x);
  2510. assert(!isnan(x));
  2511. assert(!isinf(x));
  2512. }
  2513. #endif
  2514. }
  2515. }
  2516. static void ggml_compute_forward_silu_back_f16(
  2517. const ggml_compute_params * params,
  2518. ggml_tensor * dst) {
  2519. const ggml_tensor * grad = dst->src[0];
  2520. const ggml_tensor * src1 = dst->src[1];
  2521. assert(ggml_is_contiguous_1(grad));
  2522. assert(ggml_is_contiguous_1(src1));
  2523. assert(ggml_is_contiguous_1(dst));
  2524. assert(ggml_are_same_shape(src1, dst));
  2525. assert(ggml_are_same_shape(src1, grad));
  2526. const int ith = params->ith;
  2527. const int nth = params->nth;
  2528. const int nc = src1->ne[0];
  2529. const int nr = ggml_nrows(src1);
  2530. // rows per thread
  2531. const int dr = (nr + nth - 1)/nth;
  2532. // row range for this thread
  2533. const int ir0 = dr*ith;
  2534. const int ir1 = MIN(ir0 + dr, nr);
  2535. for (int i1 = ir0; i1 < ir1; i1++) {
  2536. ggml_vec_silu_backward_f16(nc,
  2537. (ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])),
  2538. (ggml_fp16_t *) ((char *) src1->data + i1*(src1->nb[1])),
  2539. (ggml_fp16_t *) ((char *) grad->data + i1*(grad->nb[1])));
  2540. #ifndef NDEBUG
  2541. for (int k = 0; k < nc; k++) {
2542. const ggml_fp16_t x = ((ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])))[k];
  2543. const float v = GGML_FP16_TO_FP32(x);
  2544. GGML_UNUSED(v);
  2545. assert(!isnan(v));
  2546. assert(!isinf(v));
  2547. }
  2548. #endif
  2549. }
  2550. }
  2551. void ggml_compute_forward_silu_back(
  2552. const ggml_compute_params * params,
  2553. ggml_tensor * dst) {
  2554. const ggml_tensor * src0 = dst->src[0];
  2555. switch (src0->type) {
  2556. case GGML_TYPE_F32:
  2557. {
  2558. ggml_compute_forward_silu_back_f32(params, dst);
  2559. } break;
  2560. case GGML_TYPE_F16:
  2561. {
  2562. ggml_compute_forward_silu_back_f16(params, dst);
  2563. } break;
  2564. default:
  2565. {
  2566. GGML_ABORT("fatal error");
  2567. }
  2568. }
  2569. }
  2570. // ggml_compute_forward_norm
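// Classic layer normalization over the innermost dimension (ne00):
//   y = (x - mean(x)) / sqrt(var(x) + eps)
// Mean and variance are accumulated in ggml_float for extra precision, and the final
// normalization is applied with a single ggml_vec_scale_f32 per row.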
  2571. static void ggml_compute_forward_norm_f32(
  2572. const ggml_compute_params * params,
  2573. ggml_tensor * dst) {
  2574. const ggml_tensor * src0 = dst->src[0];
  2575. GGML_ASSERT(ggml_are_same_shape(src0, dst));
  2576. GGML_ASSERT(src0->nb[0] == sizeof(float));
  2577. const int ith = params->ith;
  2578. const int nth = params->nth;
  2579. GGML_TENSOR_UNARY_OP_LOCALS
  2580. float eps;
  2581. memcpy(&eps, dst->op_params, sizeof(float));
  2582. GGML_ASSERT(eps >= 0.0f);
  2583. // TODO: optimize
  2584. for (int64_t i03 = 0; i03 < ne03; i03++) {
  2585. for (int64_t i02 = 0; i02 < ne02; i02++) {
  2586. for (int64_t i01 = ith; i01 < ne01; i01 += nth) {
  2587. const float * x = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
  2588. ggml_float sum = 0.0;
  2589. for (int64_t i00 = 0; i00 < ne00; i00++) {
  2590. sum += (ggml_float)x[i00];
  2591. }
  2592. float mean = sum/ne00;
  2593. float * y = (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3);
  2594. ggml_float sum2 = 0.0;
  2595. for (int64_t i00 = 0; i00 < ne00; i00++) {
  2596. float v = x[i00] - mean;
  2597. y[i00] = v;
  2598. sum2 += (ggml_float)(v*v);
  2599. }
  2600. float variance = sum2/ne00;
  2601. const float scale = 1.0f/sqrtf(variance + eps);
  2602. ggml_vec_scale_f32(ne00, y, scale);
  2603. }
  2604. }
  2605. }
  2606. }
  2607. void ggml_compute_forward_norm(
  2608. const ggml_compute_params * params,
  2609. ggml_tensor * dst) {
  2610. const ggml_tensor * src0 = dst->src[0];
  2611. switch (src0->type) {
  2612. case GGML_TYPE_F32:
  2613. {
  2614. ggml_compute_forward_norm_f32(params, dst);
  2615. } break;
  2616. default:
  2617. {
  2618. GGML_ABORT("fatal error");
  2619. }
  2620. }
  2621. }
2622. // ggml_compute_forward_rms_norm
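// RMS norm keeps the mean un-subtracted: y = x / sqrt(mean(x^2) + eps).
// A minimal reference sketch for one row (kept as a comment, not compiled), assuming x and y
// point to ne00 floats:
//   float ss = 0.0f; for (int i = 0; i < ne00; i++) ss += x[i]*x[i];
//   const float s = 1.0f/sqrtf(ss/ne00 + eps);
//   for (int i = 0; i < ne00; i++) y[i] = x[i]*s;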
  2623. static void ggml_compute_forward_rms_norm_f32(
  2624. const ggml_compute_params * params,
  2625. ggml_tensor * dst) {
  2626. const ggml_tensor * src0 = dst->src[0];
  2627. GGML_ASSERT(ggml_are_same_shape(src0, dst));
  2628. GGML_ASSERT(src0->nb[0] == sizeof(float));
  2629. const int ith = params->ith;
  2630. const int nth = params->nth;
  2631. GGML_TENSOR_UNARY_OP_LOCALS
  2632. float eps;
  2633. memcpy(&eps, dst->op_params, sizeof(float));
  2634. GGML_ASSERT(eps >= 0.0f);
  2635. // TODO: optimize
  2636. for (int64_t i03 = 0; i03 < ne03; i03++) {
  2637. for (int64_t i02 = 0; i02 < ne02; i02++) {
  2638. for (int64_t i01 = ith; i01 < ne01; i01 += nth) {
  2639. const float * x = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
  2640. ggml_float sum = 0.0;
  2641. for (int64_t i00 = 0; i00 < ne00; i00++) {
  2642. sum += (ggml_float)(x[i00] * x[i00]);
  2643. }
  2644. const float mean = sum/ne00;
  2645. float * y = (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3);
  2646. memcpy(y, x, ne00 * sizeof(float));
  2647. // for (int i00 = 0; i00 < ne00; i00++) {
  2648. // y[i00] = x[i00];
  2649. // }
  2650. const float scale = 1.0f/sqrtf(mean + eps);
  2651. ggml_vec_scale_f32(ne00, y, scale);
  2652. }
  2653. }
  2654. }
  2655. }
  2656. void ggml_compute_forward_rms_norm(
  2657. const ggml_compute_params * params,
  2658. ggml_tensor * dst) {
  2659. const ggml_tensor * src0 = dst->src[0];
  2660. switch (src0->type) {
  2661. case GGML_TYPE_F32:
  2662. {
  2663. ggml_compute_forward_rms_norm_f32(params, dst);
  2664. } break;
  2665. default:
  2666. {
  2667. GGML_ABORT("fatal error");
  2668. }
  2669. }
  2670. }
  2671. static void ggml_compute_forward_rms_norm_back_f32(
  2672. const ggml_compute_params * params,
  2673. ggml_tensor * dst) {
  2674. const ggml_tensor * src0 = dst->src[0]; // gradients from forward pass output
  2675. const ggml_tensor * src1 = dst->src[1]; // src1 from forward pass
  2676. GGML_ASSERT(ggml_are_same_shape(src0, dst) && ggml_are_same_shape(src0, src1));
  2677. GGML_ASSERT(src0->nb[0] == sizeof(float));
  2678. GGML_ASSERT(src1->nb[0] == sizeof(float));
  2679. const int ith = params->ith;
  2680. const int nth = params->nth;
  2681. GGML_TENSOR_BINARY_OP_LOCALS
  2682. float eps;
  2683. memcpy(&eps, dst->op_params, sizeof(float));
  2684. // TODO: optimize
  2685. for (int64_t i03 = 0; i03 < ne03; i03++) {
  2686. for (int64_t i02 = 0; i02 < ne02; i02++) {
  2687. for (int64_t i01 = ith; i01 < ne01; i01 += nth) {
  2688. // src1 is same shape as src0 => same indices
  2689. const int64_t i11 = i01;
  2690. const int64_t i12 = i02;
  2691. const int64_t i13 = i03;
  2692. const float * dz = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
  2693. const float * x = (float *) ((char *) src1->data + i11*nb11 + i12*nb12 + i13*nb13);
  2694. ggml_float sum_xx = 0.0;
  2695. ggml_float sum_xdz = 0.0;
  2696. for (int64_t i00 = 0; i00 < ne00; i00++) {
  2697. sum_xx += (ggml_float)(x[i00] * x[i00]);
  2698. sum_xdz += (ggml_float)(x[i00] * dz[i00]);
  2699. }
  2700. //const float mean = (float)(sum_xx)/ne00;
  2701. const float mean_eps = (float)(sum_xx)/ne00 + eps;
  2702. const float sum_eps = (float)(sum_xx) + eps*ne00;
  2703. //const float mean_xdz = (float)(sum_xdz)/ne00;
  2704. // we could cache rms from forward pass to improve performance.
  2705. // to do this implement ggml_rms and compose ggml_rms_norm using ggml_rms.
  2706. //const float rms = sqrtf(mean_eps);
  2707. const float rrms = 1.0f / sqrtf(mean_eps);
  2708. //const float scale = -rrms/(ne00 * mean_eps); // -1/(n*rms**3)
  2709. {
  2710. // z = rms_norm(x)
  2711. //
  2712. // rms_norm(src1) =
  2713. // scale(
  2714. // src1,
  2715. // div(
  2716. // 1,
  2717. // sqrt(
  2718. // add(
  2719. // scale(
  2720. // sum(
  2721. // sqr(
  2722. // src1)),
  2723. // (1.0/N)),
  2724. // eps))));
  2725. // postorder:
  2726. // ## op args grad
  2727. // 00 param src1 grad[#00]
  2728. // 01 const 1
  2729. // 02 sqr (#00) grad[#02]
  2730. // 03 sum (#02) grad[#03]
  2731. // 04 const 1/N
  2732. // 05 scale (#03, #04) grad[#05]
  2733. // 06 const eps
  2734. // 07 add (#05, #06) grad[#07]
  2735. // 08 sqrt (#07) grad[#08]
  2736. // 09 div (#01,#08) grad[#09]
  2737. // 10 scale (#00,#09) grad[#10]
  2738. //
  2739. // backward pass, given grad[#10]
  2740. // #10: scale
  2741. // grad[#00] += scale(grad[#10],#09)
  2742. // grad[#09] += sum(mul(grad[#10],#00))
  2743. // #09: div
  2744. // grad[#08] += neg(mul(grad[#09], div(#09,#08)))
  2745. // #08: sqrt
  2746. // grad[#07] += mul(grad[#08], div(0.5, #08))
  2747. // #07: add
  2748. // grad[#05] += grad[#07]
  2749. // #05: scale
  2750. // grad[#03] += scale(grad[#05],#04)
  2751. // #03: sum
  2752. // grad[#02] += repeat(grad[#03], #02)
2753. // #02: sqr
  2754. // grad[#00] += scale(mul(#00, grad[#02]), 2.0)
  2755. //
  2756. // substitute and simplify:
  2757. // grad[#00] = scale(grad(#10), #09) + scale(mul(#00, grad[#02]), 2.0)
  2758. // grad[#02] = repeat(grad[#03], #02)
  2759. // grad[#02] = repeat(scale(grad[#05],#04), #02)
  2760. // grad[#02] = repeat(scale(grad[#07],#04), #02)
  2761. // grad[#02] = repeat(scale(mul(grad[#08], div(0.5, #08)),#04), #02)
  2762. // grad[#02] = repeat(scale(mul(neg(mul(grad[#09], div(#09,#08))), div(0.5, #08)),#04), #02)
  2763. // grad[#02] = repeat(scale(mul(neg(mul(sum(mul(grad[#10],#00)), div(#09,#08))), div(0.5, #08)),#04), #02)
  2764. // grad[#02] = repeat(-(sum(mul(grad[#10],#00)) * div(#09,#08) * div(0.5, #08) * (1/N)), #02)
  2765. // grad[#02] = repeat(-(sum(mul(grad[#10],#00)) * div(div(#01,#08),#08) * div(0.5, #08) * (1/N)), #02)
  2766. // grad[#02] = repeat(-(sum(mul(grad[#10],#00)) * div(1,#08*#08) * div(0.5, #08) * (1/N)), #02)
  2767. // grad[#02] = repeat(-(sum(mul(grad[#10],#00)) * div(1,#07) * div(0.5, #08) * (1/N)), #02)
  2768. // grad[#00] = scale(grad(#10), #09) + scale(mul(#00, grad[#02]), 2.0)
  2769. // grad[#00] = scale(grad(#10), #09) + scale(mul(#00, repeat(-(sum(mul(grad[#10],#00)) * div(1,#07) * div(0.5, #08) * (1/N)), #02)), 2.0)
  2770. // grad[#00] = scale(grad(#10), #09) + scale(scale(#00, -(sum(mul(grad[#10],#00)) * div(1,#07) * div(0.5, #08) * (1/N))), 2.0)
  2771. // grad[#00] = scale(grad(#10), #09) + scale(#00, -(sum(mul(grad[#10],#00)) * div(1,#07) * div(1,#08) * (1/N)))
  2772. // grad[#00] = scale(grad(#10), #09) + scale(#00, sum(mul(grad[#10],#00)) * div(1,#07*#08) * (-1/N))
  2773. // grad[#00] = scale(grad(#10), #09) + scale(#00, sum(mul(grad[#10],#00)) * div(1,#07*#08) * (-1/N))
  2774. // grad[#00] = scale(grad(#10), #09) + scale(#00, sum(mul(grad[#10],#00)) * div(1,mean_eps*rms) * (-1/N))
  2775. // grad[#00] = scale(grad(#10), #09) + scale(#00, sum(mul(grad[#10],#00)) * div(-1,rms*N*mean_eps))
  2776. // grad[#00] = scale(grad(#10), #09) + scale(#00, sum(mul(grad[#10],#00)) * div(-1,rms*N*(sum_xx/N+eps)))
  2777. // grad[#00] = scale(grad(#10), #09) + scale(#00, sum(mul(grad[#10],#00)) * div(-1,rms*N*sum_xx+rms*N*eps))
  2778. // grad[#00] = scale(dz, rrms) + scale(x, sum(mul(dz,x)) * div(-1,rms*N*mean_eps))
  2779. // grad[#00] = scale(dz, rrms) + scale(x, sum_xdz * div(-1,rms*N*mean_eps))
  2780. // a = b*c + d*e
  2781. // a = b*c*f/f + d*e*f/f
  2782. // a = (b*c*f + d*e*f)*(1/f)
  2783. // a = (b*c*(1/c) + d*e*(1/c))*(1/(1/c))
  2784. // a = (b + d*e/c)*c
  2785. // b = dz, c = rrms, d = x, e = sum_xdz * div(-1,rms*N*mean_eps)
  2786. // a = (dz + x*sum_xdz * div(-1,rms*N*mean_eps)/rrms)*rrms
  2787. // a = (dz + x*sum_xdz * div(-1,rms*N*mean_eps)*rms)*rrms
  2788. // a = (dz + x*sum_xdz * div(-rms,rms*N*mean_eps))*rrms
  2789. // a = (dz + x*sum_xdz * div(-1,N*mean_eps))*rrms
  2790. // a = (dz + x*div(-sum_xdz,N*mean_eps))*rrms
  2791. // a = (dz + x*div(-mean_xdz,mean_eps))*rrms
  2792. // grad[#00] = scale(dz + scale(x, div(-mean_xdz,mean_eps)),rrms)
  2793. // grad[#00] = scale(dz + scale(x, -mean_xdz/mean_eps),rrms)
  2794. // dx = scale(dz + scale(x, -mean_xdz/mean_eps),rrms)
  2795. }
  2796. // dx = scale(dz + scale(x, -mean_xdz/mean_eps),rrms)
  2797. // post-order:
  2798. // dx := x
  2799. // dx := scale(dx,-mean_xdz/mean_eps)
  2800. // dx := add(dx, dz)
  2801. // dx := scale(dx, rrms)
  2802. float * dx = (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3);
  2803. // dx[i00] = (x*(-sum_xdz/sum_eps) + dz) / sqrtf(mean_eps)
  2804. ggml_vec_cpy_f32 (ne00, dx, x);
  2805. // ggml_vec_scale_f32(ne00, dx, -mean_xdz/mean_eps);
  2806. ggml_vec_scale_f32(ne00, dx, (float)(-sum_xdz)/sum_eps);
  2807. ggml_vec_acc_f32 (ne00, dx, dz);
  2808. ggml_vec_scale_f32(ne00, dx, rrms);
  2809. }
  2810. }
  2811. }
  2812. }
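// The derivation above boils down to dx = (dz - x * mean_xdz / mean_eps) * rrms.
// Illustrative scalar reference for one row (kept as a comment, not compiled), assuming dz, x
// and dx point to ne00 floats and eps matches the forward pass:
//   double sum_xx = 0.0, sum_xdz = 0.0;
//   for (int i = 0; i < ne00; i++) { sum_xx += x[i]*x[i]; sum_xdz += x[i]*dz[i]; }
//   const float mean_eps = sum_xx/ne00 + eps;
//   const float rrms     = 1.0f/sqrtf(mean_eps);
//   for (int i = 0; i < ne00; i++) dx[i] = (dz[i] - x[i]*(sum_xdz/ne00)/mean_eps) * rrms;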
  2813. void ggml_compute_forward_rms_norm_back(
  2814. const ggml_compute_params * params,
  2815. ggml_tensor * dst) {
  2816. const ggml_tensor * src0 = dst->src[0];
  2817. switch (src0->type) {
  2818. case GGML_TYPE_F32:
  2819. {
  2820. ggml_compute_forward_rms_norm_back_f32(params, dst);
  2821. } break;
  2822. default:
  2823. {
  2824. GGML_ABORT("fatal error");
  2825. }
  2826. }
  2827. }
  2828. // ggml_compute_forward_group_norm
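// Group normalization: the channel dimension (ne[2]) is split into n_groups groups and the
// mean/variance are computed over ne00*ne01*step elements, i.e. every element of every channel
// in the group. op_params[0] holds n_groups and the float at op_params + 1 holds eps, which is
// why the memcpy below reads one int32 past the start of op_params.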
  2829. static void ggml_compute_forward_group_norm_f32(
  2830. const ggml_compute_params * params,
  2831. ggml_tensor * dst) {
  2832. const ggml_tensor * src0 = dst->src[0];
  2833. GGML_ASSERT(ggml_are_same_shape(src0, dst));
  2834. GGML_ASSERT(src0->nb[0] == sizeof(float));
  2835. const int ith = params->ith;
  2836. const int nth = params->nth;
  2837. GGML_TENSOR_UNARY_OP_LOCALS
  2838. // TODO: optimize
  2839. float eps;
  2840. memcpy(&eps, dst->op_params + 1, sizeof(float));
  2841. int n_channels = src0->ne[2];
  2842. int n_groups = dst->op_params[0];
  2843. int n_channels_per_group = (n_channels + n_groups - 1) / n_groups;
  2844. for (int i = ith; i < n_groups; i += nth) {
  2845. int start = i * n_channels_per_group;
  2846. int end = start + n_channels_per_group;
  2847. if (end > n_channels) {
  2848. end = n_channels;
  2849. }
  2850. int step = end - start;
  2851. for (int64_t i03 = 0; i03 < ne03; i03++) {
  2852. ggml_float sum = 0.0;
  2853. for (int64_t i02 = start; i02 < end; i02++) {
  2854. for (int64_t i01 = 0; i01 < ne01; i01++) {
  2855. const float * x = (float *)((char *) src0->data + i01 * nb01 + i02 * nb02 + i03 * nb03);
  2856. ggml_float sumr = 0.0;
  2857. for (int64_t i00 = 0; i00 < ne00; i00++) {
  2858. sumr += (ggml_float)x[i00];
  2859. }
  2860. sum += sumr;
  2861. }
  2862. }
  2863. const float mean = sum / (ne00 * ne01 * step);
  2864. ggml_float sum2 = 0.0;
  2865. for (int64_t i02 = start; i02 < end; i02++) {
  2866. for (int64_t i01 = 0; i01 < ne01; i01++) {
  2867. const float * x = (float *)((char *) src0->data + i01 * nb01 + i02 * nb02 + i03 * nb03);
  2868. float * y = (float *)((char *) dst->data + i01 * nb1 + i02 * nb2 + i03 * nb3);
  2869. ggml_float sumr = 0.0;
  2870. for (int64_t i00 = 0; i00 < ne00; i00++) {
  2871. float v = x[i00] - mean;
  2872. y[i00] = v;
  2873. sumr += (ggml_float)(v * v);
  2874. }
  2875. sum2 += sumr;
  2876. }
  2877. }
  2878. const float variance = sum2 / (ne00 * ne01 * step);
  2879. const float scale = 1.0f / sqrtf(variance + eps);
  2880. for (int64_t i02 = start; i02 < end; i02++) {
  2881. for (int64_t i01 = 0; i01 < ne01; i01++) {
  2882. float * y = (float *)((char *) dst->data + i01 * nb1 + i02 * nb2 + i03 * nb3);
  2883. ggml_vec_scale_f32(ne00, y, scale);
  2884. }
  2885. }
  2886. }
  2887. }
  2888. }
  2889. void ggml_compute_forward_group_norm(
  2890. const ggml_compute_params * params,
  2891. ggml_tensor * dst) {
  2892. const ggml_tensor * src0 = dst->src[0];
  2893. switch (src0->type) {
  2894. case GGML_TYPE_F32:
  2895. {
  2896. ggml_compute_forward_group_norm_f32(params, dst);
  2897. } break;
  2898. default:
  2899. {
  2900. GGML_ABORT("fatal error");
  2901. }
  2902. }
  2903. }
  2904. // ggml_compute_forward_l2_norm
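// L2 normalization per row: y = x / max(||x||_2, eps), i.e. each row is rescaled to unit
// Euclidean length, with eps acting as a floor so an all-zero row does not divide by zero.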
  2905. static void ggml_compute_forward_l2_norm_f32(
  2906. const ggml_compute_params * params,
  2907. ggml_tensor * dst) {
  2908. const ggml_tensor * src0 = dst->src[0];
  2909. GGML_ASSERT(ggml_are_same_shape(src0, dst));
  2910. GGML_ASSERT(src0->nb[0] == sizeof(float));
  2911. const int ith = params->ith;
  2912. const int nth = params->nth;
  2913. GGML_TENSOR_UNARY_OP_LOCALS
  2914. float eps;
  2915. memcpy(&eps, dst->op_params, sizeof(float));
  2916. GGML_ASSERT(eps >= 0.0f);
  2917. // TODO: optimize
  2918. for (int64_t i03 = 0; i03 < ne03; i03++) {
  2919. for (int64_t i02 = 0; i02 < ne02; i02++) {
  2920. for (int64_t i01 = ith; i01 < ne01; i01 += nth) {
  2921. const float * x = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
  2922. ggml_float sum = 0.0;
  2923. for (int64_t i00 = 0; i00 < ne00; i00++) {
  2924. sum += (ggml_float)(x[i00] * x[i00]);
  2925. }
  2926. float * y = (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3);
  2927. memcpy(y, x, ne00 * sizeof(float));
  2928. const float scale = 1.0f/fmaxf(sqrtf(sum), eps);
  2929. ggml_vec_scale_f32(ne00, y, scale);
  2930. }
  2931. }
  2932. }
  2933. }
  2934. void ggml_compute_forward_l2_norm(
  2935. const ggml_compute_params * params,
  2936. ggml_tensor * dst) {
  2937. const ggml_tensor * src0 = dst->src[0];
  2938. switch (src0->type) {
  2939. case GGML_TYPE_F32:
  2940. {
  2941. ggml_compute_forward_l2_norm_f32(params, dst);
  2942. } break;
  2943. default:
  2944. {
  2945. GGML_ABORT("fatal error");
  2946. }
  2947. }
  2948. }
  2949. // ggml_compute_forward_out_prod
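// out_prod accumulates dst[i0,i1,i2,i3] += src0[i0,i01,i2,i3] * src1[i1,i01,i2,i3] over i01,
// i.e. every column of src0 is scaled by one element of src1 and added into a column of dst.
// Thread 0 zeroes dst first, the barrier makes that visible to all threads, and the dps2/dps3
// ratios broadcast src0 over dst in dims 2/3 (used for grouped-query attention).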
  2950. static void ggml_compute_forward_out_prod_f32(
  2951. const ggml_compute_params * params,
  2952. ggml_tensor * dst) {
  2953. const ggml_tensor * src0 = dst->src[0];
  2954. const ggml_tensor * src1 = dst->src[1];
  2955. GGML_TENSOR_BINARY_OP_LOCALS
  2956. GGML_ASSERT(dst->type == GGML_TYPE_F32);
  2957. GGML_ASSERT(src0->type == GGML_TYPE_F32);
  2958. GGML_ASSERT(src1->type == GGML_TYPE_F32);
  2959. const int ith = params->ith;
  2960. const int nth = params->nth;
  2961. GGML_ASSERT(ne0 == ne00);
  2962. GGML_ASSERT(ne1 == ne10);
  2963. GGML_ASSERT(ne2 == ne12);
  2964. GGML_ASSERT(ne3 == ne13);
  2965. GGML_ASSERT(ne2 % ne02 == 0);
  2966. GGML_ASSERT(ne3 % ne03 == 0);
  2967. // we don't support permuted src0 or src1
  2968. GGML_ASSERT(nb00 == sizeof(float));
  2969. // dst cannot be transposed or permuted
  2970. GGML_ASSERT(nb0 == sizeof(float));
  2971. // GGML_ASSERT(nb0 <= nb1);
  2972. // GGML_ASSERT(nb1 <= nb2);
  2973. // GGML_ASSERT(nb2 <= nb3);
  2974. // nb01 >= nb00 - src0 is not transposed
  2975. // compute by src0 rows
  2976. if (ith == 0) {
  2977. ggml_vec_set_f32(ne0*ne1*ne2*ne3, (float *)dst->data, 0);
  2978. }
  2979. ggml_barrier(params->threadpool);
  2980. // dst[:,:,:,:] = 0
  2981. // for i2,i3:
  2982. // for i1:
  2983. // for i01:
  2984. // for i0:
  2985. // dst[i0,i1,i2,i3] += src0[i0,i01,i2,i3] * src1[i1,i01,i2,i3]
  2986. // parallelize by last three dimensions
  2987. // total rows in dst
  2988. const int64_t nr = ne1*ne2*ne3;
  2989. // rows per thread
  2990. const int64_t dr = (nr + nth - 1)/nth;
  2991. // row range for this thread
  2992. const int64_t ir0 = dr*ith;
  2993. const int64_t ir1 = MIN(ir0 + dr, nr);
  2994. // block-tiling attempt
  2995. const int64_t blck_0 = MAX(GGML_VEC_MAD_UNROLL, 32);
  2996. const int64_t blck_1 = 16;
  2997. // dps == dst per src0, used for group query attention
  2998. const int64_t dps2 = ne2 / ne02;
  2999. const int64_t dps3 = ne3 / ne03;
  3000. for (int64_t bir = ir0; bir < ir1; bir += blck_1) {
  3001. const int64_t bir1 = MIN(bir + blck_1, ir1);
  3002. for (int64_t bi01 = 0; bi01 < ne01; bi01 += blck_0) {
  3003. const int64_t bne01 = MIN(bi01 + blck_0, ne01);
  3004. for (int64_t ir = bir; ir < bir1; ++ir) {
  3005. // dst indices
  3006. const int64_t i3 = ir/(ne2*ne1);
  3007. const int64_t i2 = (ir - i3*ne2*ne1)/ne1;
  3008. const int64_t i1 = (ir - i3*ne2*ne1 - i2*ne1);
  3009. const int64_t i02 = i2 / dps2;
  3010. const int64_t i03 = i3 / dps3;
  3011. //const int64_t i10 = i1;
  3012. const int64_t i12 = i2;
  3013. const int64_t i13 = i3;
  3014. #if GGML_VEC_MAD_UNROLL > 2
  3015. const int64_t bne01_unroll = bne01 - (bne01 % GGML_VEC_MAD_UNROLL);
  3016. for (int64_t i01 = bi01; i01 < bne01_unroll; i01 += GGML_VEC_MAD_UNROLL) {
  3017. const int64_t i11 = i01;
  3018. float * s0 = (float *) ((char *) src0->data + ( i01*nb01 + i02*nb02 + i03*nb03));
  3019. float * s1 = (float *) ((char *) src1->data + (i1*nb10 + i11*nb11 + i12*nb12 + i13*nb13));
  3020. float * d = (float *) ((char *) dst->data + ( i1*nb1 + i2*nb2 + i3*nb3));
  3021. ggml_vec_mad_f32_unroll(ne0, nb01, nb11, d, s0, s1);
  3022. }
  3023. for (int64_t i01 = bne01_unroll; i01 < bne01; ++i01) {
  3024. const int64_t i11 = i01;
  3025. float * s0 = (float *) ((char *) src0->data + ( i01*nb01 + i02*nb02 + i03*nb03));
  3026. float * s1 = (float *) ((char *) src1->data + (i1*nb10 + i11*nb11 + i12*nb12 + i13*nb13));
  3027. float * d = (float *) ((char *) dst->data + ( i1*nb1 + i2*nb2 + i3*nb3));
  3028. ggml_vec_mad_f32(ne0, d, s0, *s1);
  3029. }
  3030. #else
  3031. for (int64_t i01 = bi01; i01 < bne01; ++i01) {
  3032. const int64_t i11 = i01;
  3033. float * s0 = (float *) ((char *) src0->data + ( i01*nb01 + i02*nb02 + i03*nb03));
  3034. float * s1 = (float *) ((char *) src1->data + (i1*nb10 + i11*nb11 + i12*nb12 + i13*nb13));
  3035. float * d = (float *) ((char *) dst->data + ( i1*nb1 + i2*nb2 + i3*nb3));
  3036. ggml_vec_mad_f32(ne0, d, s0, *s1);
  3037. }
  3038. #endif
  3039. }
  3040. }
  3041. }
  3042. }
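// Quantized variant: same accumulation as above, except each src0 row is first dequantized
// into a per-thread scratch region of params->wdata (offset by ith and padded by a cache line
// to avoid false sharing) before being fed to ggml_vec_mad_f32.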
  3043. static void ggml_compute_forward_out_prod_q_f32(
  3044. const ggml_compute_params * params,
  3045. ggml_tensor * dst) {
  3046. const ggml_tensor * src0 = dst->src[0];
  3047. const ggml_tensor * src1 = dst->src[1];
  3048. GGML_TENSOR_BINARY_OP_LOCALS;
  3049. const int ith = params->ith;
  3050. const int nth = params->nth;
  3051. const ggml_type type = src0->type;
  3052. ggml_to_float_t const dequantize_row_q = ggml_get_type_traits(type)->to_float;
  3053. GGML_ASSERT(ne02 == ne12);
  3054. GGML_ASSERT(ne03 == ne13);
  3055. GGML_ASSERT(ne2 == ne12);
  3056. GGML_ASSERT(ne3 == ne13);
  3057. // we don't support permuted src0 dim0
  3058. GGML_ASSERT(nb00 == ggml_type_size(type));
  3059. // dst dim0 cannot be transposed or permuted
  3060. GGML_ASSERT(nb0 == sizeof(float));
  3061. // GGML_ASSERT(nb0 <= nb1);
  3062. // GGML_ASSERT(nb1 <= nb2);
  3063. // GGML_ASSERT(nb2 <= nb3);
  3064. GGML_ASSERT(ne0 == ne00);
  3065. GGML_ASSERT(ne1 == ne10);
  3066. GGML_ASSERT(ne2 == ne02);
  3067. GGML_ASSERT(ne3 == ne03);
  3068. // nb01 >= nb00 - src0 is not transposed
  3069. // compute by src0 rows
  3070. if (ith == 0) {
  3071. ggml_vec_set_f32(ne0*ne1*ne2*ne3, (float *)dst->data, 0);
  3072. }
  3073. ggml_barrier(params->threadpool);
  3074. // parallelize by last three dimensions
  3075. // total rows in dst
  3076. const int64_t nr = ne1*ne2*ne3;
  3077. // rows per thread
  3078. const int64_t dr = (nr + nth - 1)/nth;
  3079. // row range for this thread
  3080. const int64_t ir0 = dr*ith;
  3081. const int64_t ir1 = MIN(ir0 + dr, nr);
  3082. // dst[:,:,:,:] = 0
  3083. // for i2,i3:
  3084. // for i1:
  3085. // for i01:
  3086. // for i0:
  3087. // dst[i0,i1,i2,i3] += src0[i0,i01,i2,i3] * src1[i1,i01,i2,i3]
  3088. float * wdata = (float *) params->wdata + (ne0 + CACHE_LINE_SIZE_F32) * ith;
  3089. for (int64_t ir = ir0; ir < ir1; ++ir) {
  3090. // dst indices
  3091. const int64_t i3 = ir/(ne2*ne1);
  3092. const int64_t i2 = (ir - i3*ne2*ne1)/ne1;
  3093. const int64_t i1 = (ir - i3*ne2*ne1 - i2*ne1);
  3094. const int64_t i02 = i2;
  3095. const int64_t i03 = i3;
  3096. //const int64_t i10 = i1;
  3097. const int64_t i12 = i2;
  3098. const int64_t i13 = i3;
  3099. for (int64_t i01 = 0; i01 < ne01; ++i01) {
  3100. const int64_t i11 = i01;
  3101. float * s0 = (float *) ((char *) src0->data + ( i01*nb01 + i02*nb02 + i03*nb03));
  3102. float * s1 = (float *) ((char *) src1->data + (i1*nb10 + i11*nb11 + i12*nb12 + i13*nb13));
  3103. float * d = (float *) ((char *) dst->data + ( i1*nb1 + i2*nb2 + i3*nb3));
  3104. dequantize_row_q(s0, wdata, ne0);
  3105. ggml_vec_mad_f32(ne0, d, wdata, *s1);
  3106. }
  3107. }
  3108. }
  3109. void ggml_compute_forward_out_prod(
  3110. const ggml_compute_params * params,
  3111. ggml_tensor * dst) {
  3112. const ggml_tensor * src0 = dst->src[0];
  3113. switch (src0->type) {
  3114. case GGML_TYPE_Q4_0:
  3115. case GGML_TYPE_Q4_1:
  3116. case GGML_TYPE_Q5_0:
  3117. case GGML_TYPE_Q5_1:
  3118. case GGML_TYPE_Q8_0:
  3119. case GGML_TYPE_Q2_K:
  3120. case GGML_TYPE_Q3_K:
  3121. case GGML_TYPE_Q4_K:
  3122. case GGML_TYPE_Q5_K:
  3123. case GGML_TYPE_Q6_K:
  3124. case GGML_TYPE_TQ1_0:
  3125. case GGML_TYPE_TQ2_0:
  3126. case GGML_TYPE_IQ2_XXS:
  3127. case GGML_TYPE_IQ2_XS:
  3128. case GGML_TYPE_IQ3_XXS:
  3129. case GGML_TYPE_IQ1_S:
  3130. case GGML_TYPE_IQ1_M:
  3131. case GGML_TYPE_IQ4_NL:
  3132. case GGML_TYPE_IQ4_XS:
  3133. case GGML_TYPE_IQ3_S:
  3134. case GGML_TYPE_IQ2_S:
  3135. {
  3136. ggml_compute_forward_out_prod_q_f32(params, dst);
  3137. } break;
  3138. case GGML_TYPE_F16:
  3139. {
  3140. GGML_ABORT("fatal error"); // todo
  3141. // ggml_compute_forward_out_prod_f16_f32(params, dst);
  3142. }
  3143. case GGML_TYPE_F32:
  3144. {
  3145. ggml_compute_forward_out_prod_f32(params, dst);
  3146. } break;
  3147. default:
  3148. {
  3149. GGML_ABORT("fatal error");
  3150. }
  3151. }
  3152. }
  3153. // ggml_compute_forward_scale
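// scale: dst = v * src0 with the factor v stored as a raw float in op_params.
// When dst and src0 do not alias, each row is first copied and then scaled in place,
// so the op works both in-place and out-of-place.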
  3154. static void ggml_compute_forward_scale_f32(
  3155. const ggml_compute_params * params,
  3156. ggml_tensor * dst) {
  3157. const ggml_tensor * src0 = dst->src[0];
  3158. GGML_ASSERT(ggml_is_contiguous(src0));
  3159. GGML_ASSERT(ggml_is_contiguous(dst));
  3160. GGML_ASSERT(ggml_are_same_shape(src0, dst));
  3161. // scale factor
  3162. float v;
  3163. memcpy(&v, dst->op_params, sizeof(float));
  3164. const int ith = params->ith;
  3165. const int nth = params->nth;
  3166. const int nc = src0->ne[0];
  3167. const int nr = ggml_nrows(src0);
  3168. // rows per thread
  3169. const int dr = (nr + nth - 1)/nth;
  3170. // row range for this thread
  3171. const int ir0 = dr*ith;
  3172. const int ir1 = MIN(ir0 + dr, nr);
  3173. const size_t nb01 = src0->nb[1];
  3174. const size_t nb1 = dst->nb[1];
  3175. for (int i1 = ir0; i1 < ir1; i1++) {
  3176. if (dst->data != src0->data) {
  3177. // src0 is same shape as dst => same indices
  3178. memcpy((char *)dst->data + i1*nb1, (char *)src0->data + i1*nb01, nc * sizeof(float));
  3179. }
  3180. ggml_vec_scale_f32(nc, (float *) ((char *) dst->data + i1*nb1), v);
  3181. }
  3182. }
  3183. void ggml_compute_forward_scale(
  3184. const ggml_compute_params * params,
  3185. ggml_tensor * dst) {
  3186. const ggml_tensor * src0 = dst->src[0];
  3187. switch (src0->type) {
  3188. case GGML_TYPE_F32:
  3189. {
  3190. ggml_compute_forward_scale_f32(params, dst);
  3191. } break;
  3192. default:
  3193. {
  3194. GGML_ABORT("fatal error");
  3195. }
  3196. }
  3197. }
  3198. // ggml_compute_forward_set
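// set: writes src1 into a strided view of dst. op_params packs five int32 values:
//   [0] nb1, [1] nb2, [2] nb3  - byte strides of the view
//   [3] offset                 - byte offset of the view into dst
//   [4] inplace                - if 0, dst is first initialized with a copy of src0
// The initializing copy is done by thread 0 only and fenced with ggml_barrier before the
// strided writes.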
  3199. static void ggml_compute_forward_set_f32(
  3200. const ggml_compute_params * params,
  3201. ggml_tensor * dst) {
  3202. const ggml_tensor * src0 = dst->src[0];
  3203. const ggml_tensor * src1 = dst->src[1];
  3204. GGML_ASSERT(ggml_are_same_shape(src0, dst));
  3205. GGML_ASSERT(ggml_is_contiguous(dst) && ggml_is_contiguous(src0));
3206. // view src0 and dst with these strides and data offset in bytes during set
  3207. // nb0 is implicitly element_size because src0 and dst are contiguous
  3208. size_t nb1 = ((int32_t *) dst->op_params)[0];
  3209. size_t nb2 = ((int32_t *) dst->op_params)[1];
  3210. size_t nb3 = ((int32_t *) dst->op_params)[2];
  3211. size_t offset = ((int32_t *) dst->op_params)[3];
  3212. bool inplace = (bool) ((int32_t *) dst->op_params)[4];
  3213. if (!inplace) {
  3214. if (params->ith == 0) {
  3215. // memcpy needs to be synchronized across threads to avoid race conditions.
  3216. // => do it in INIT phase
  3217. memcpy(
  3218. ((char *) dst->data),
  3219. ((char *) src0->data),
  3220. ggml_nbytes(dst));
  3221. }
  3222. ggml_barrier(params->threadpool);
  3223. }
  3224. const int ith = params->ith;
  3225. const int nth = params->nth;
  3226. const int nr = ggml_nrows(src1);
  3227. const int nc = src1->ne[0];
  3228. GGML_TENSOR_LOCALS(int64_t, ne1, src1, ne)
  3229. GGML_TENSOR_LOCALS(size_t, nb1, src1, nb)
  3230. // src0 and dst as viewed during set
  3231. const size_t nb0 = ggml_element_size(src0);
  3232. const int im0 = (ne10 == 0 ? 0 : ne10-1);
  3233. const int im1 = (ne11 == 0 ? 0 : ne11-1);
  3234. const int im2 = (ne12 == 0 ? 0 : ne12-1);
  3235. const int im3 = (ne13 == 0 ? 0 : ne13-1);
  3236. GGML_ASSERT(offset + im0*nb0 + im1*nb1 + im2*nb2 + im3*nb3 <= ggml_nbytes(dst));
  3237. GGML_ASSERT(nb10 == sizeof(float));
  3238. // rows per thread
  3239. const int dr = (nr + nth - 1)/nth;
  3240. // row range for this thread
  3241. const int ir0 = dr*ith;
  3242. const int ir1 = MIN(ir0 + dr, nr);
  3243. for (int ir = ir0; ir < ir1; ++ir) {
  3244. // src0 and dst are viewed with shape of src1 and offset
  3245. // => same indices
  3246. const int i3 = ir/(ne12*ne11);
  3247. const int i2 = (ir - i3*ne12*ne11)/ne11;
  3248. const int i1 = (ir - i3*ne12*ne11 - i2*ne11);
  3249. ggml_vec_cpy_f32(nc,
  3250. (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + offset),
  3251. (float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11));
  3252. }
  3253. }
  3254. static void ggml_compute_forward_set_i32(
  3255. const ggml_compute_params * params,
  3256. ggml_tensor * dst) {
  3257. const ggml_tensor * src0 = dst->src[0];
  3258. const ggml_tensor * src1 = dst->src[1];
  3259. GGML_ASSERT(ggml_are_same_shape(src0, dst));
  3260. GGML_ASSERT(ggml_is_contiguous(dst) && ggml_is_contiguous(src0));
3261. // view src0 and dst with these strides and data offset in bytes during set
  3262. // nb0 is implicitly element_size because src0 and dst are contiguous
  3263. size_t nb1 = ((int32_t *) dst->op_params)[0];
  3264. size_t nb2 = ((int32_t *) dst->op_params)[1];
  3265. size_t nb3 = ((int32_t *) dst->op_params)[2];
  3266. size_t offset = ((int32_t *) dst->op_params)[3];
  3267. bool inplace = (bool) ((int32_t *) dst->op_params)[4];
  3268. if (!inplace) {
  3269. if (params->ith == 0) {
  3270. // memcpy needs to be synchronized across threads to avoid race conditions.
  3271. // => do it in INIT phase
  3272. memcpy(
  3273. ((char *) dst->data),
  3274. ((char *) src0->data),
  3275. ggml_nbytes(dst));
  3276. }
  3277. ggml_barrier(params->threadpool);
  3278. }
  3279. const int ith = params->ith;
  3280. const int nth = params->nth;
  3281. const int nr = ggml_nrows(src1);
  3282. const int nc = src1->ne[0];
  3283. GGML_TENSOR_LOCALS(int64_t, ne1, src1, ne)
  3284. GGML_TENSOR_LOCALS(size_t, nb1, src1, nb)
  3285. // src0 and dst as viewed during set
  3286. const size_t nb0 = ggml_element_size(src0);
  3287. const int im0 = (ne10 == 0 ? 0 : ne10-1);
  3288. const int im1 = (ne11 == 0 ? 0 : ne11-1);
  3289. const int im2 = (ne12 == 0 ? 0 : ne12-1);
  3290. const int im3 = (ne13 == 0 ? 0 : ne13-1);
  3291. GGML_ASSERT(offset + im0*nb0 + im1*nb1 + im2*nb2 + im3*nb3 <= ggml_nbytes(dst));
  3292. GGML_ASSERT(nb10 == sizeof(int32_t));
  3293. // rows per thread
  3294. const int dr = (nr + nth - 1)/nth;
  3295. // row range for this thread
  3296. const int ir0 = dr*ith;
  3297. const int ir1 = MIN(ir0 + dr, nr);
  3298. for (int ir = ir0; ir < ir1; ++ir) {
  3299. // src0 and dst are viewed with shape of src1 and offset
  3300. // => same indices
  3301. const int i3 = ir/(ne12*ne11);
  3302. const int i2 = (ir - i3*ne12*ne11)/ne11;
  3303. const int i1 = (ir - i3*ne12*ne11 - i2*ne11);
  3304. ggml_vec_cpy_i32(nc,
  3305. (int32_t *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + offset),
  3306. (int32_t *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11));
  3307. }
  3308. }
  3309. void ggml_compute_forward_set(
  3310. const ggml_compute_params * params,
  3311. ggml_tensor * dst) {
  3312. const ggml_tensor * src0 = dst->src[0];
  3313. switch (src0->type) {
  3314. case GGML_TYPE_F32:
  3315. {
  3316. ggml_compute_forward_set_f32(params, dst);
  3317. } break;
  3318. case GGML_TYPE_I32:
  3319. {
  3320. ggml_compute_forward_set_i32(params, dst);
  3321. } break;
  3322. case GGML_TYPE_F16:
  3323. case GGML_TYPE_BF16:
  3324. case GGML_TYPE_Q4_0:
  3325. case GGML_TYPE_Q4_1:
  3326. case GGML_TYPE_Q5_0:
  3327. case GGML_TYPE_Q5_1:
  3328. case GGML_TYPE_Q8_0:
  3329. case GGML_TYPE_Q8_1:
  3330. case GGML_TYPE_Q2_K:
  3331. case GGML_TYPE_Q3_K:
  3332. case GGML_TYPE_Q4_K:
  3333. case GGML_TYPE_Q5_K:
  3334. case GGML_TYPE_Q6_K:
  3335. case GGML_TYPE_TQ1_0:
  3336. case GGML_TYPE_TQ2_0:
  3337. case GGML_TYPE_IQ2_XXS:
  3338. case GGML_TYPE_IQ2_XS:
  3339. case GGML_TYPE_IQ3_XXS:
  3340. case GGML_TYPE_IQ1_S:
  3341. case GGML_TYPE_IQ1_M:
  3342. case GGML_TYPE_IQ4_NL:
  3343. case GGML_TYPE_IQ4_XS:
  3344. case GGML_TYPE_IQ3_S:
  3345. case GGML_TYPE_IQ2_S:
  3346. default:
  3347. {
  3348. GGML_ABORT("fatal error");
  3349. }
  3350. }
  3351. }
  3352. // ggml_compute_forward_cpy
  3353. void ggml_compute_forward_cpy(
  3354. const ggml_compute_params * params,
  3355. ggml_tensor * dst) {
  3356. ggml_compute_forward_dup(params, dst);
  3357. }
  3358. // ggml_compute_forward_cont
  3359. void ggml_compute_forward_cont(
  3360. const ggml_compute_params * params,
  3361. ggml_tensor * dst) {
  3362. ggml_compute_forward_dup(params, dst);
  3363. }
  3364. // ggml_compute_forward_reshape
  3365. void ggml_compute_forward_reshape(
  3366. const ggml_compute_params * params,
  3367. ggml_tensor * dst) {
  3368. // NOP
  3369. GGML_UNUSED(params);
  3370. GGML_UNUSED(dst);
  3371. }
  3372. // ggml_compute_forward_view
  3373. void ggml_compute_forward_view(
  3374. const ggml_compute_params * params,
  3375. ggml_tensor * dst) {
  3376. // NOP
  3377. GGML_UNUSED(params);
  3378. GGML_UNUSED(dst);
  3379. }
  3380. // ggml_compute_forward_permute
  3381. void ggml_compute_forward_permute(
  3382. const ggml_compute_params * params,
  3383. ggml_tensor * dst) {
  3384. // NOP
  3385. GGML_UNUSED(params);
  3386. GGML_UNUSED(dst);
  3387. }
  3388. // ggml_compute_forward_transpose
  3389. void ggml_compute_forward_transpose(
  3390. const ggml_compute_params * params,
  3391. ggml_tensor * dst) {
  3392. // NOP
  3393. GGML_UNUSED(params);
  3394. GGML_UNUSED(dst);
  3395. }
  3396. // ggml_compute_forward_get_rows
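// get_rows is a gather: src1 holds int32 row indices and, for every index,
//   dst[:, i10, i11, i12] = src0[:, src1[i10, i11, i12], i11, i12]
// Quantized rows are dequantized on the way out, f16/bf16 rows are widened to f32, and
// f32 rows are copied directly.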
  3397. static void ggml_compute_forward_get_rows_q(
  3398. const ggml_compute_params * params,
  3399. ggml_tensor * dst) {
  3400. const ggml_tensor * src0 = dst->src[0];
  3401. const ggml_tensor * src1 = dst->src[1];
  3402. GGML_TENSOR_BINARY_OP_LOCALS
  3403. const int64_t nc = ne00;
  3404. const int64_t nr = ggml_nelements(src1);
  3405. const ggml_type type = src0->type;
  3406. ggml_to_float_t const dequantize_row_q = ggml_get_type_traits(type)->to_float;
  3407. assert(ne0 == nc);
  3408. assert(ne02 == ne11);
  3409. assert(nb00 == ggml_type_size(type));
  3410. assert(ggml_nrows(dst) == nr);
  3411. const int ith = params->ith;
  3412. const int nth = params->nth;
  3413. // rows per thread
  3414. const int dr = (nr + nth - 1)/nth;
  3415. // row range for this thread
  3416. const int ir0 = dr*ith;
  3417. const int ir1 = MIN(ir0 + dr, nr);
  3418. for (int64_t i = ir0; i < ir1; ++i) {
  3419. const int64_t i12 = i/(ne11*ne10);
  3420. const int64_t i11 = (i - i12*ne11*ne10)/ne10;
  3421. const int64_t i10 = (i - i12*ne11*ne10 - i11*ne10);
  3422. const int64_t i01 = *(int32_t *) ((char *) src1->data + i10*nb10 + i11*nb11 + i12*nb12);
  3423. GGML_ASSERT(i01 >= 0 && i01 < ne01);
  3424. dequantize_row_q(
  3425. (const void *) ((char *) src0->data + i01*nb01 + i11*nb02 + i12*nb03),
  3426. (float *) ((char *) dst->data + i10*nb1 + i11*nb2 + i12*nb3), nc);
  3427. }
  3428. }
  3429. static void ggml_compute_forward_get_rows_f16(
  3430. const ggml_compute_params * params,
  3431. ggml_tensor * dst) {
  3432. const ggml_tensor * src0 = dst->src[0];
  3433. const ggml_tensor * src1 = dst->src[1];
  3434. GGML_TENSOR_BINARY_OP_LOCALS
  3435. const int64_t nc = ne00;
  3436. const int64_t nr = ggml_nelements(src1);
  3437. assert(ne0 == nc);
  3438. assert(ne02 == ne11);
  3439. assert(nb00 == sizeof(ggml_fp16_t));
  3440. assert(ggml_nrows(dst) == nr);
  3441. const int ith = params->ith;
  3442. const int nth = params->nth;
  3443. // rows per thread
  3444. const int dr = (nr + nth - 1)/nth;
  3445. // row range for this thread
  3446. const int ir0 = dr*ith;
  3447. const int ir1 = MIN(ir0 + dr, nr);
  3448. for (int64_t i = ir0; i < ir1; ++i) {
  3449. const int64_t i12 = i/(ne11*ne10);
  3450. const int64_t i11 = (i - i12*ne11*ne10)/ne10;
  3451. const int64_t i10 = (i - i12*ne11*ne10 - i11*ne10);
  3452. const int64_t i01 = *(int32_t *) ((char *) src1->data + i10*nb10 + i11*nb11 + i12*nb12);
  3453. GGML_ASSERT(i01 >= 0 && i01 < ne01);
  3454. ggml_fp16_to_fp32_row(
  3455. (const ggml_fp16_t*) ((char *) src0->data + i01*nb01 + i11*nb02 + i12*nb03),
  3456. (float *) ((char *) dst->data + i10*nb1 + i11*nb2 + i12*nb3), nc);
  3457. }
  3458. }
  3459. static void ggml_compute_forward_get_rows_bf16(
  3460. const ggml_compute_params * params,
  3461. ggml_tensor * dst) {
  3462. const ggml_tensor * src0 = dst->src[0];
  3463. const ggml_tensor * src1 = dst->src[1];
  3464. GGML_TENSOR_BINARY_OP_LOCALS
  3465. const int64_t nc = ne00;
  3466. const int64_t nr = ggml_nelements(src1);
  3467. assert(ne0 == nc);
  3468. assert(ne02 == ne11);
  3469. assert(nb00 == sizeof(ggml_bf16_t));
  3470. assert(ggml_nrows(dst) == nr);
  3471. const int ith = params->ith;
  3472. const int nth = params->nth;
  3473. // rows per thread
  3474. const int dr = (nr + nth - 1)/nth;
  3475. // row range for this thread
  3476. const int ir0 = dr*ith;
  3477. const int ir1 = MIN(ir0 + dr, nr);
  3478. for (int64_t i = ir0; i < ir1; ++i) {
  3479. const int64_t i12 = i/(ne11*ne10);
  3480. const int64_t i11 = (i - i12*ne11*ne10)/ne10;
  3481. const int64_t i10 = (i - i12*ne11*ne10 - i11*ne10);
  3482. const int64_t i01 = *(int32_t *) ((char *) src1->data + i10*nb10 + i11*nb11 + i12*nb12);
  3483. GGML_ASSERT(i01 >= 0 && i01 < ne01);
  3484. ggml_bf16_to_fp32_row(
  3485. (const ggml_bf16_t *) ((char *) src0->data + i01*nb01 + i11*nb02 + i12*nb03),
  3486. (float *) ((char *) dst->data + i10*nb1 + i11*nb2 + i12*nb3), nc);
  3487. }
  3488. }
  3489. static void ggml_compute_forward_get_rows_f32(
  3490. const ggml_compute_params * params,
  3491. ggml_tensor * dst) {
  3492. const ggml_tensor * src0 = dst->src[0];
  3493. const ggml_tensor * src1 = dst->src[1];
  3494. GGML_TENSOR_BINARY_OP_LOCALS
  3495. const int64_t nc = ne00;
  3496. const int64_t nr = ggml_nelements(src1);
  3497. assert(ne0 == nc);
  3498. assert(ne02 == ne11);
  3499. assert(nb00 == sizeof(float));
  3500. assert(ggml_nrows(dst) == nr);
  3501. const int ith = params->ith;
  3502. const int nth = params->nth;
  3503. // rows per thread
  3504. const int dr = (nr + nth - 1)/nth;
  3505. // row range for this thread
  3506. const int ir0 = dr*ith;
  3507. const int ir1 = MIN(ir0 + dr, nr);
  3508. for (int64_t i = ir0; i < ir1; ++i) {
  3509. const int64_t i12 = i/(ne11*ne10);
  3510. const int64_t i11 = (i - i12*ne11*ne10)/ne10;
  3511. const int64_t i10 = (i - i12*ne11*ne10 - i11*ne10);
  3512. const int64_t i01 = *(int32_t *) ((char *) src1->data + i10*nb10 + i11*nb11 + i12*nb12);
  3513. GGML_ASSERT(i01 >= 0 && i01 < ne01);
  3514. ggml_vec_cpy_f32(nc,
  3515. (float *) ((char *) dst->data + i10*nb1 + i11*nb2 + i12*nb3),
  3516. (float *) ((char *) src0->data + i01*nb01 + i11*nb02 + i12*nb03));
  3517. }
  3518. }
  3519. void ggml_compute_forward_get_rows(
  3520. const ggml_compute_params * params,
  3521. ggml_tensor * dst) {
  3522. const ggml_tensor * src0 = dst->src[0];
  3523. switch (src0->type) {
  3524. case GGML_TYPE_Q4_0:
  3525. case GGML_TYPE_Q4_1:
  3526. case GGML_TYPE_Q5_0:
  3527. case GGML_TYPE_Q5_1:
  3528. case GGML_TYPE_Q8_0:
  3529. case GGML_TYPE_Q8_1:
  3530. case GGML_TYPE_Q2_K:
  3531. case GGML_TYPE_Q3_K:
  3532. case GGML_TYPE_Q4_K:
  3533. case GGML_TYPE_Q5_K:
  3534. case GGML_TYPE_Q6_K:
  3535. case GGML_TYPE_TQ1_0:
  3536. case GGML_TYPE_TQ2_0:
  3537. case GGML_TYPE_IQ2_XXS:
  3538. case GGML_TYPE_IQ2_XS:
  3539. case GGML_TYPE_IQ3_XXS:
  3540. case GGML_TYPE_IQ1_S:
  3541. case GGML_TYPE_IQ1_M:
  3542. case GGML_TYPE_IQ4_NL:
  3543. case GGML_TYPE_IQ4_XS:
  3544. case GGML_TYPE_IQ3_S:
  3545. case GGML_TYPE_IQ2_S:
  3546. {
  3547. ggml_compute_forward_get_rows_q(params, dst);
  3548. } break;
  3549. case GGML_TYPE_F16:
  3550. {
  3551. ggml_compute_forward_get_rows_f16(params, dst);
  3552. } break;
  3553. case GGML_TYPE_BF16:
  3554. {
  3555. ggml_compute_forward_get_rows_bf16(params, dst);
  3556. } break;
  3557. case GGML_TYPE_F32:
  3558. case GGML_TYPE_I32:
  3559. {
  3560. ggml_compute_forward_get_rows_f32(params, dst);
  3561. } break;
  3562. default:
  3563. {
  3564. GGML_ABORT("fatal error");
  3565. }
  3566. }
  3567. //static bool first = true;
  3568. //printf("ne0 = %d, ne1 = %d, ne2 = %d\n", dst->ne[0], dst->ne[1], dst->ne[2]);
  3569. //if (first) {
  3570. // first = false;
  3571. //} else {
  3572. // for (int k = 0; k < dst->ne[1]; ++k) {
  3573. // for (int j = 0; j < dst->ne[0]/16; ++j) {
  3574. // for (int i = 0; i < 16; ++i) {
  3575. // printf("%8.4f ", ((float *) dst->data)[k*dst->ne[0] + j*16 + i]);
  3576. // }
  3577. // printf("\n");
  3578. // }
  3579. // printf("\n");
  3580. // }
  3581. // printf("\n");
  3582. // exit(0);
  3583. //}
  3584. }
  3585. // ggml_compute_forward_get_rows_back
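// get_rows_back is the scatter-add adjoint of get_rows: for every gradient row i,
//   dst[src1[i], :] += src0[i, :]
// dst is zeroed first and, because different i can map to the same destination row,
// the loop runs on a single thread (ith != 0 returns early).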
  3586. static void ggml_compute_forward_get_rows_back_f32_f16(
  3587. const ggml_compute_params * params,
  3588. ggml_tensor * dst) {
  3589. const ggml_tensor * src0 = dst->src[0];
  3590. const ggml_tensor * src1 = dst->src[1];
  3591. if (params->ith != 0) {
  3592. return;
  3593. }
  3594. GGML_ASSERT(ggml_is_contiguous(dst));
  3595. // ggml_compute_forward_dup_same_cont(params, opt0, dst);
  3596. memset(dst->data, 0, ggml_nbytes(dst));
  3597. const int nc = src0->ne[0];
  3598. const int nr = ggml_nelements(src1);
  3599. GGML_ASSERT( dst->ne[0] == nc);
  3600. GGML_ASSERT(src0->nb[0] == sizeof(ggml_fp16_t));
  3601. for (int i = 0; i < nr; ++i) {
  3602. const int r = ((int32_t *) src1->data)[i];
  3603. for (int j = 0; j < nc; ++j) {
  3604. ggml_fp16_t v = ((ggml_fp16_t *) ((char *) src0->data + i*src0->nb[1]))[j];
  3605. ((float *) ((char *) dst->data + r*dst->nb[1]))[j] += GGML_FP16_TO_FP32(v);
  3606. }
  3607. }
  3608. }
  3609. static void ggml_compute_forward_get_rows_back_f32(
  3610. const ggml_compute_params * params,
  3611. ggml_tensor * dst) {
  3612. const ggml_tensor * src0 = dst->src[0];
  3613. const ggml_tensor * src1 = dst->src[1];
  3614. if (params->ith != 0) {
  3615. return;
  3616. }
  3617. GGML_ASSERT(ggml_is_contiguous(dst));
  3618. // ggml_compute_forward_dup_same_cont(params, opt0, dst);
  3619. memset(dst->data, 0, ggml_nbytes(dst));
  3620. const int nc = src0->ne[0];
  3621. const int nr = ggml_nelements(src1);
  3622. GGML_ASSERT( dst->ne[0] == nc);
  3623. GGML_ASSERT(src0->nb[0] == sizeof(float));
  3624. for (int i = 0; i < nr; ++i) {
  3625. const int r = ((int32_t *) src1->data)[i];
  3626. ggml_vec_add_f32(nc,
  3627. (float *) ((char *) dst->data + r*dst->nb[1]),
  3628. (float *) ((char *) dst->data + r*dst->nb[1]),
  3629. (float *) ((char *) src0->data + i*src0->nb[1]));
  3630. }
  3631. }
  3632. void ggml_compute_forward_get_rows_back(
  3633. const ggml_compute_params * params,
  3634. ggml_tensor * dst) {
  3635. const ggml_tensor * src0 = dst->src[0];
  3636. switch (src0->type) {
  3637. case GGML_TYPE_F16:
  3638. {
  3639. ggml_compute_forward_get_rows_back_f32_f16(params, dst);
  3640. } break;
  3641. case GGML_TYPE_F32:
  3642. {
  3643. ggml_compute_forward_get_rows_back_f32(params, dst);
  3644. } break;
  3645. default:
  3646. {
  3647. GGML_ABORT("fatal error");
  3648. }
  3649. }
  3650. //static bool first = true;
  3651. //printf("ne0 = %d, ne1 = %d, ne2 = %d\n", dst->ne[0], dst->ne[1], dst->ne[2]);
  3652. //if (first) {
  3653. // first = false;
  3654. //} else {
  3655. // for (int k = 0; k < dst->ne[1]; ++k) {
  3656. // for (int j = 0; j < dst->ne[0]/16; ++j) {
  3657. // for (int i = 0; i < 16; ++i) {
  3658. // printf("%8.4f ", ((float *) dst->data)[k*dst->ne[0] + j*16 + i]);
  3659. // }
  3660. // printf("\n");
  3661. // }
  3662. // printf("\n");
  3663. // }
  3664. // printf("\n");
  3665. // exit(0);
  3666. //}
  3667. }
  3668. // ggml_compute_forward_diag
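// diag embeds each length-ne00 row of src0 (ne01 must be 1) as the diagonal of an
// ne0 x ne1 matrix; all off-diagonal elements of dst are written as 0.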
  3669. static void ggml_compute_forward_diag_f32(
  3670. const ggml_compute_params * params,
  3671. ggml_tensor * dst) {
  3672. const ggml_tensor * src0 = dst->src[0];
  3673. if (params->ith != 0) {
  3674. return;
  3675. }
  3676. // TODO: handle transposed/permuted matrices
  3677. GGML_TENSOR_UNARY_OP_LOCALS
  3678. GGML_ASSERT(ne00 == ne0);
  3679. GGML_ASSERT(ne00 == ne1);
  3680. GGML_ASSERT(ne01 == 1);
  3681. GGML_ASSERT(ne02 == ne2);
  3682. GGML_ASSERT(ne03 == ne3);
  3683. GGML_ASSERT(nb00 == sizeof(float));
  3684. GGML_ASSERT(nb0 == sizeof(float));
  3685. for (int i3 = 0; i3 < ne3; i3++) {
  3686. for (int i2 = 0; i2 < ne2; i2++) {
  3687. for (int i1 = 0; i1 < ne1; i1++) {
  3688. float * d = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1);
  3689. float * s = (float *)((char *) src0->data + i3*nb03 + i2*nb02);
  3690. for (int i0 = 0; i0 < i1; i0++) {
  3691. d[i0] = 0;
  3692. }
  3693. d[i1] = s[i1];
  3694. for (int i0 = i1+1; i0 < ne0; i0++) {
  3695. d[i0] = 0;
  3696. }
  3697. }
  3698. }
  3699. }
  3700. }
  3701. void ggml_compute_forward_diag(
  3702. const ggml_compute_params * params,
  3703. ggml_tensor * dst) {
  3704. const ggml_tensor * src0 = dst->src[0];
  3705. switch (src0->type) {
  3706. case GGML_TYPE_F32:
  3707. {
  3708. ggml_compute_forward_diag_f32(params, dst);
  3709. } break;
  3710. default:
  3711. {
  3712. GGML_ABORT("fatal error");
  3713. }
  3714. }
  3715. }
  3716. // ggml_compute_forward_diag_mask_inf
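// Causal masking: for each row j, every element with column index i > n_past + j is
// overwritten with `value` (-INFINITY for diag_mask_inf, 0 for diag_mask_zero); elements
// at or below the shifted diagonal are left untouched.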
  3717. static void ggml_compute_forward_diag_mask_f32(
  3718. const ggml_compute_params * params,
  3719. ggml_tensor * dst,
  3720. const float value) {
  3721. const ggml_tensor * src0 = dst->src[0];
  3722. const int ith = params->ith;
  3723. const int nth = params->nth;
  3724. const int n_past = ((int32_t *) dst->op_params)[0];
  3725. const bool inplace = src0->data == dst->data;
  3726. GGML_ASSERT(n_past >= 0);
  3727. if (!inplace) {
  3728. if (ith == 0) {
  3729. // memcpy needs to be synchronized across threads to avoid race conditions.
  3730. // => do it in INIT phase
  3731. GGML_ASSERT(ggml_nelements(dst) == ggml_nelements(src0));
  3732. GGML_ASSERT(ggml_is_contiguous(dst) && ggml_is_contiguous(src0));
  3733. memcpy(
  3734. ((char *) dst->data),
  3735. ((char *) src0->data),
  3736. ggml_nbytes(dst));
  3737. }
  3738. ggml_barrier(params->threadpool);
  3739. }
  3740. // TODO: handle transposed/permuted matrices
  3741. const int n = ggml_nrows(src0);
  3742. const int nc = src0->ne[0];
  3743. const int nr = src0->ne[1];
  3744. const int nz = n/nr;
  3745. GGML_ASSERT( dst->nb[0] == sizeof(float));
  3746. GGML_ASSERT(src0->nb[0] == sizeof(float));
  3747. for (int k = 0; k < nz; k++) {
  3748. for (int j = ith; j < nr; j += nth) {
  3749. for (int i = n_past; i < nc; i++) {
  3750. if (i > n_past + j) {
  3751. *(float *)((char *) dst->data + k*dst->nb[2] + j*dst->nb[1] + i*dst->nb[0]) = value;
  3752. }
  3753. }
  3754. }
  3755. }
  3756. }
  3757. void ggml_compute_forward_diag_mask_inf(
  3758. const ggml_compute_params * params,
  3759. ggml_tensor * dst) {
  3760. const ggml_tensor * src0 = dst->src[0];
  3761. switch (src0->type) {
  3762. case GGML_TYPE_F32:
  3763. {
  3764. ggml_compute_forward_diag_mask_f32(params, dst, -INFINITY);
  3765. } break;
  3766. default:
  3767. {
  3768. GGML_ABORT("fatal error");
  3769. }
  3770. }
  3771. }
  3772. void ggml_compute_forward_diag_mask_zero(
  3773. const ggml_compute_params * params,
  3774. ggml_tensor * dst) {
  3775. const ggml_tensor * src0 = dst->src[0];
  3776. switch (src0->type) {
  3777. case GGML_TYPE_F32:
  3778. {
  3779. ggml_compute_forward_diag_mask_f32(params, dst, 0);
  3780. } break;
  3781. default:
  3782. {
  3783. GGML_ABORT("fatal error");
  3784. }
  3785. }
  3786. }
  3787. // ggml_compute_forward_soft_max
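// soft_max with optional scaling, mask and ALiBi bias:
//   p_i = exp(scale*x_i + slope*mask_i - max_k(scale*x_k + slope*mask_k)) / sum_j exp(...)
// The running maximum is subtracted before exponentiation for numerical stability, and
// `slope` is the per-head ALiBi factor derived from max_bias (1.0 when max_bias == 0).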
  3788. static void ggml_compute_forward_soft_max_f32(
  3789. const ggml_compute_params * params,
  3790. ggml_tensor * dst) {
  3791. const ggml_tensor * src0 = dst->src[0];
  3792. const ggml_tensor * src1 = dst->src[1];
  3793. assert(ggml_is_contiguous(dst));
  3794. assert(ggml_are_same_shape(src0, dst));
  3795. float scale = 1.0f;
  3796. float max_bias = 0.0f;
  3797. memcpy(&scale, (float *) dst->op_params + 0, sizeof(float));
  3798. memcpy(&max_bias, (float *) dst->op_params + 1, sizeof(float));
  3799. // TODO: handle transposed/permuted matrices
  3800. const int ith = params->ith;
  3801. const int nth = params->nth;
  3802. GGML_TENSOR_UNARY_OP_LOCALS
  3803. //const int64_t ne11 = src1 ? src1->ne[1] : 1;
  3804. // TODO: is this supposed to be ceil instead of floor?
  3805. // https://huggingface.co/mosaicml/mpt-7b/blob/main/attention.py#L370
  3806. const uint32_t n_head = ne02;
  3807. const uint32_t n_head_log2 = 1u << (uint32_t) floor(log2(n_head));
  3808. const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
  3809. const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
  3810. const int nc = src0->ne[0];
  3811. const int nr = ggml_nrows(src0);
  3812. // rows per thread
  3813. const int dr = (nr + nth - 1)/nth;
  3814. // row range for this thread
  3815. const int ir0 = dr*ith;
  3816. const int ir1 = MIN(ir0 + dr, nr);
  3817. float * wp = (float *) params->wdata + (nc + CACHE_LINE_SIZE_F32) * ith;
  3818. const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16);
  3819. for (int i1 = ir0; i1 < ir1; i1++) {
  3820. // ALiBi
  3821. const uint32_t h = (i1/ne01)%ne02; // head
  3822. const float slope = (max_bias > 0.0f) ? h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2*(h - n_head_log2) + 1) : 1.0f;
  3823. float * sp = (float *)((char *) src0->data + i1*src0->nb[1]);
  3824. float * dp = (float *)((char *) dst->data + i1*dst->nb[1]);
  3825. // broadcast the mask across rows
  3826. ggml_fp16_t * mp_f16 = src1 ? (ggml_fp16_t *)((char *) src1->data) + (i1%ne01)*ne00 : NULL;
  3827. float * mp_f32 = src1 ? (float *)((char *) src1->data) + (i1%ne01)*ne00 : NULL;
  3828. ggml_vec_cpy_f32 (nc, wp, sp);
  3829. ggml_vec_scale_f32(nc, wp, scale);
  3830. if (mp_f32) {
  3831. if (use_f16) {
  3832. for (int i = 0; i < nc; ++i) {
  3833. wp[i] += slope*GGML_FP16_TO_FP32(mp_f16[i]);
  3834. }
  3835. } else {
  3836. for (int i = 0; i < nc; ++i) {
  3837. wp[i] += slope*mp_f32[i];
  3838. }
  3839. }
  3840. }
  3841. #ifndef NDEBUG
  3842. for (int i = 0; i < nc; ++i) {
  3843. //printf("p[%d] = %f\n", i, p[i]);
  3844. assert(!isnan(wp[i]));
  3845. }
  3846. #endif
  3847. float max = -INFINITY;
  3848. ggml_vec_max_f32(nc, &max, wp);
  3849. ggml_float sum = ggml_vec_soft_max_f32(nc, dp, wp, max);
  3850. assert(sum > 0.0);
  3851. sum = 1.0/sum;
  3852. ggml_vec_scale_f32(nc, dp, sum);
  3853. #ifndef NDEBUG
  3854. for (int i = 0; i < nc; ++i) {
  3855. assert(!isnan(dp[i]));
  3856. assert(!isinf(dp[i]));
  3857. }
  3858. #endif
  3859. }
  3860. }
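// Worked example (comment only, assuming the op_params layout read above): with
// n_head = 8 and max_bias = 8.0f we get n_head_log2 = 8, m0 = 2^(-8/8) = 0.5, and the
// per-head ALiBi slopes are 0.5, 0.25, ..., 0.5^8 (a head h < n_head_log2 uses m0^(h+1)).
// Each row is then computed as
//
//   dst[i] = exp(scale*src0[i] + slope*mask[i] - max) / sum_j exp(scale*src0[j] + slope*mask[j] - max)
//
// where max is the row maximum of the scaled+biased values, used purely for numerical stability.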
  3861. void ggml_compute_forward_soft_max(
  3862. const ggml_compute_params * params,
  3863. ggml_tensor * dst) {
  3864. const ggml_tensor * src0 = dst->src[0];
  3865. switch (src0->type) {
  3866. case GGML_TYPE_F32:
  3867. {
  3868. ggml_compute_forward_soft_max_f32(params, dst);
  3869. } break;
  3870. default:
  3871. {
  3872. GGML_ABORT("fatal error");
  3873. }
  3874. }
  3875. }
  3876. // ggml_compute_forward_soft_max_ext_back
  3877. static void ggml_compute_forward_soft_max_ext_back_f32(
  3878. const ggml_compute_params * params,
  3879. ggml_tensor * dst) {
  3880. const ggml_tensor * src0 = dst->src[0];
  3881. const ggml_tensor * src1 = dst->src[1];
  3882. GGML_ASSERT(ggml_is_contiguous(src0));
  3883. GGML_ASSERT(ggml_is_contiguous(src1));
  3884. GGML_ASSERT(ggml_is_contiguous(dst));
  3885. GGML_ASSERT(ggml_are_same_shape(src0, dst));
  3886. GGML_ASSERT(ggml_are_same_shape(src1, dst));
  3887. float scale = 1.0f;
  3888. float max_bias = 0.0f;
  3889. memcpy(&scale, (const float *) dst->op_params + 0, sizeof(float));
  3890. memcpy(&max_bias, (const float *) dst->op_params + 1, sizeof(float));
  3891. GGML_ASSERT(max_bias == 0.0f);
  3892. // TODO: handle transposed/permuted matrices
  3893. const int ith = params->ith;
  3894. const int nth = params->nth;
  3895. const int nc = src0->ne[0];
  3896. const int nr = ggml_nrows(src0);
  3897. // rows per thread
  3898. const int dr = (nr + nth - 1)/nth;
  3899. // row range for this thread
  3900. const int ir0 = dr*ith;
  3901. const int ir1 = MIN(ir0 + dr, nr);
  3902. for (int i1 = ir0; i1 < ir1; i1++) {
  3903. float *dy = (float *)((char *) src0->data + i1*src0->nb[1]);
  3904. float *y = (float *)((char *) src1->data + i1*src1->nb[1]);
  3905. float *dx = (float *)((char *) dst->data + i1*dst->nb[1]);
  3906. #ifndef NDEBUG
  3907. for (int i = 0; i < nc; ++i) {
  3908. //printf("p[%d] = %f\n", i, p[i]);
  3909. assert(!isnan(dy[i]));
  3910. assert(!isnan(y[i]));
  3911. }
  3912. #endif
  3913. // Jii = yi - yi*yi
  3914. // Jij = -yi*yj
  3915. // J = diag(y)-y.T*y
  3916. // dx = J * dy
  3917. // dxk = sum_i(Jki * dyi)
  3918. // dxk = sum_i(-yk*yi * dyi) - (-yk*yk)*dyk + (yk - yk*yk)*dyk
  3919. // dxk = sum_i(-yk*yi * dyi) + yk*yk*dyk + yk*dyk - yk*yk*dyk
  3920. // dxk = sum_i(-yk*yi * dyi) + yk*dyk
  3921. // dxk = -yk * sum_i(yi * dyi) + yk*dyk
  3922. // dxk = -yk * dot(y, dy) + yk*dyk
  3923. // dxk = yk * (- dot(y, dy) + dyk)
  3924. // dxk = yk * (dyk - dot(y, dy))
  3925. //
  3926. // post-order:
  3927. // dot_y_dy := dot(y, dy)
  3928. // dx := dy
  3929. // dx := dx - dot_y_dy
  3930. // dx := dx * y
  3931. // linear runtime, no additional memory
  3932. float dot_y_dy = 0;
  3933. ggml_vec_dot_f32 (nc, &dot_y_dy, 0, y, 0, dy, 0, 1);
  3934. ggml_vec_cpy_f32 (nc, dx, dy);
  3935. ggml_vec_acc1_f32 (nc, dx, -dot_y_dy);
  3936. ggml_vec_mul_f32 (nc, dx, dx, y);
  3937. ggml_vec_scale_f32(nc, dx, scale);
  3938. #ifndef NDEBUG
  3939. for (int i = 0; i < nc; ++i) {
  3940. assert(!isnan(dx[i]));
  3941. assert(!isinf(dx[i]));
  3942. }
  3943. #endif
  3944. }
  3945. }
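// Tiny numeric check of the formula dx = scale * y * (dy - dot(y, dy)) implemented above
// (comment only): for a softmax output y = [0.7, 0.3], upstream gradient dy = [1, 0] and
// scale = 1,
//
//   dot(y, dy) = 0.7
//   dx = [0.7*(1 - 0.7), 0.3*(0 - 0.7)] = [0.21, -0.21]
//
// which matches J*dy with the softmax Jacobian J = diag(y) - y*y^T.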
  3946. void ggml_compute_forward_soft_max_ext_back(
  3947. const ggml_compute_params * params,
  3948. ggml_tensor * dst) {
  3949. const ggml_tensor * src0 = dst->src[0];
  3950. switch (src0->type) {
  3951. case GGML_TYPE_F32:
  3952. {
  3953. ggml_compute_forward_soft_max_ext_back_f32(params, dst);
  3954. } break;
  3955. default:
  3956. {
  3957. GGML_ABORT("fatal error");
  3958. }
  3959. }
  3960. }
  3961. // ggml_compute_forward_clamp
  3962. static void ggml_compute_forward_clamp_f32(
  3963. const ggml_compute_params * params,
  3964. ggml_tensor * dst) {
  3965. const ggml_tensor * src0 = dst->src[0];
  3966. float min;
  3967. float max;
  3968. memcpy(&min, (float *) dst->op_params + 0, sizeof(float));
  3969. memcpy(&max, (float *) dst->op_params + 1, sizeof(float));
  3970. const int ith = params->ith;
  3971. const int nth = params->nth;
  3972. const int n = ggml_nrows(src0);
  3973. const int nc = src0->ne[0];
  3974. const size_t nb00 = src0->nb[0];
  3975. const size_t nb01 = src0->nb[1];
  3976. const size_t nb0 = dst->nb[0];
  3977. const size_t nb1 = dst->nb[1];
  3978. GGML_ASSERT( nb0 == sizeof(float));
  3979. GGML_ASSERT(nb00 == sizeof(float));
  3980. for (int j = ith; j < n; j += nth) {
  3981. float * dst_ptr = (float *) ((char *) dst->data + j*nb1);
  3982. float * src0_ptr = (float *) ((char *) src0->data + j*nb01);
  3983. for (int i = 0; i < nc; i++) {
  3984. dst_ptr[i] = MAX(MIN(src0_ptr[i], max), min);
  3985. }
  3986. }
  3987. }
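// Behavior sketch (comment only): MAX(MIN(x, max), min) clamps each element into
// [min, max]; e.g. with min = 0.0f and max = 1.0f the inputs {-0.5f, 0.3f, 2.0f} map to
// {0.0f, 0.3f, 1.0f}. NaN handling follows whatever MIN/MAX expand to here.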
  3988. static void ggml_compute_forward_clamp_f16(
  3989. const ggml_compute_params * params,
  3990. ggml_tensor * dst) {
  3991. const ggml_tensor * src0 = dst->src[0];
  3992. float min;
  3993. float max;
  3994. memcpy(&min, (float *) dst->op_params + 0, sizeof(float));
  3995. memcpy(&max, (float *) dst->op_params + 1, sizeof(float));
  3996. const int ith = params->ith;
  3997. const int nth = params->nth;
  3998. const int n = ggml_nrows(src0);
  3999. const int nc = src0->ne[0];
  4000. const size_t nb00 = src0->nb[0];
  4001. const size_t nb01 = src0->nb[1];
  4002. const size_t nb0 = dst->nb[0];
  4003. const size_t nb1 = dst->nb[1];
  4004. GGML_ASSERT( nb0 == sizeof(ggml_fp16_t));
  4005. GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
  4006. for (int j = ith; j < n; j += nth) {
  4007. ggml_fp16_t * dst_ptr = (ggml_fp16_t *) ((char *) dst->data + j*nb1);
  4008. ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + j*nb01);
  4009. for (int i = 0; i < nc; i++) {
  4010. float v = GGML_FP16_TO_FP32(src0_ptr[i]);
  4011. dst_ptr[i] = GGML_FP32_TO_FP16(MAX(MIN(v, max), min));
  4012. }
  4013. }
  4014. }
  4015. void ggml_compute_forward_clamp(
  4016. const ggml_compute_params * params,
  4017. ggml_tensor * dst) {
  4018. const ggml_tensor * src0 = dst->src[0];
  4019. switch (src0->type) {
  4020. case GGML_TYPE_F32:
  4021. {
  4022. ggml_compute_forward_clamp_f32(params, dst);
  4023. } break;
  4024. case GGML_TYPE_F16:
  4025. {
  4026. ggml_compute_forward_clamp_f16(params, dst);
  4027. } break;
  4028. case GGML_TYPE_BF16:
  4029. case GGML_TYPE_Q4_0:
  4030. case GGML_TYPE_Q4_1:
  4031. case GGML_TYPE_Q5_0:
  4032. case GGML_TYPE_Q5_1:
  4033. case GGML_TYPE_Q8_0:
  4034. case GGML_TYPE_Q8_1:
  4035. case GGML_TYPE_Q2_K:
  4036. case GGML_TYPE_Q3_K:
  4037. case GGML_TYPE_Q4_K:
  4038. case GGML_TYPE_Q5_K:
  4039. case GGML_TYPE_Q6_K:
  4040. case GGML_TYPE_TQ1_0:
  4041. case GGML_TYPE_TQ2_0:
  4042. case GGML_TYPE_IQ2_XXS:
  4043. case GGML_TYPE_IQ2_XS:
  4044. case GGML_TYPE_IQ3_XXS:
  4045. case GGML_TYPE_IQ1_S:
  4046. case GGML_TYPE_IQ1_M:
  4047. case GGML_TYPE_IQ4_NL:
  4048. case GGML_TYPE_IQ4_XS:
  4049. case GGML_TYPE_IQ3_S:
  4050. case GGML_TYPE_IQ2_S:
  4051. case GGML_TYPE_Q8_K:
  4052. case GGML_TYPE_I8:
  4053. case GGML_TYPE_I16:
  4054. case GGML_TYPE_I32:
  4055. case GGML_TYPE_I64:
  4056. case GGML_TYPE_F64:
  4057. case GGML_TYPE_COUNT:
  4058. {
  4059. GGML_ABORT("fatal error");
  4060. }
  4061. }
  4062. }
  4063. // ggml_compute_forward_rope
  4064. static float rope_yarn_ramp(const float low, const float high, const int i0) {
  4065. const float y = (i0 / 2 - low) / MAX(0.001f, high - low);
  4066. return 1 - MIN(1, MAX(0, y));
  4067. }
  4068. // YaRN algorithm based on LlamaYaRNScaledRotaryEmbedding.py from https://github.com/jquesnelle/yarn
  4069. // MIT licensed. Copyright (c) 2023 Jeffrey Quesnelle and Bowen Peng.
  4070. static void rope_yarn(
  4071. float theta_extrap, float freq_scale, float corr_dims[2], int64_t i0, float ext_factor, float mscale,
  4072. float * cos_theta, float * sin_theta) {
  4073. // Get n-d rotational scaling corrected for extrapolation
  4074. float theta_interp = freq_scale * theta_extrap;
  4075. float theta = theta_interp;
  4076. if (ext_factor != 0.0f) {
  4077. float ramp_mix = rope_yarn_ramp(corr_dims[0], corr_dims[1], i0) * ext_factor;
  4078. theta = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix;
  4079. // Get n-d magnitude scaling corrected for interpolation
  4080. mscale *= 1.0f + 0.1f * logf(1.0f / freq_scale);
  4081. }
  4082. *cos_theta = cosf(theta) * mscale;
  4083. *sin_theta = sinf(theta) * mscale;
  4084. }
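// Sketch of the YaRN mixing above (comment only): for a dimension pair where the ramp
// evaluates to 1 and ext_factor = 1.0f, theta falls back to the pure extrapolated value
// theta_extrap, while for ramp = 0 it stays at the interpolated value
// freq_scale*theta_extrap. The magnitude correction only applies when ext_factor != 0;
// e.g. with freq_scale = 0.25f it becomes mscale *= 1.0f + 0.1f*logf(4.0f) ~= 1.139f.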
  4085. static void ggml_rope_cache_init(
  4086. float theta_base, float freq_scale, const float * freq_factors, float corr_dims[2], int64_t ne0, float ext_factor, float mscale,
  4087. float * cache, float sin_sign, float theta_scale) {
  4088. // ref: https://github.com/jquesnelle/yarn/blob/master/scaled_rope/LlamaYaRNScaledRotaryEmbedding.py
  4089. float theta = theta_base;
  4090. for (int64_t i0 = 0; i0 < ne0; i0 += 2) {
  4091. const float ff = freq_factors ? freq_factors[i0/2] : 1.0f;
  4092. rope_yarn(
  4093. theta/ff, freq_scale, corr_dims, i0, ext_factor, mscale, &cache[i0 + 0], &cache[i0 + 1]
  4094. );
  4095. cache[i0 + 1] *= sin_sign;
  4096. theta *= theta_scale;
  4097. }
  4098. }
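// Cache layout produced above (comment only): for a position p passed in as theta_base,
// entry 2*k holds cos(theta_k) and entry 2*k + 1 holds sin(theta_k)*sin_sign, both scaled
// by the (possibly YaRN-corrected) mscale, where theta_k = p * theta_scale^k, optionally
// divided by freq_factors[k]. Since the caller passes theta_scale = freq_base^(-2/n_dims),
// consecutive pairs rotate at geometrically decreasing frequencies, as in standard RoPE.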
  4099. static void ggml_mrope_cache_init(
  4100. float theta_base_t, float theta_base_h, float theta_base_w, float theta_base_e, int sections[4], bool indep_sects,
  4101. float freq_scale, const float * freq_factors, float corr_dims[2], int64_t ne0, float ext_factor, float mscale,
  4102. float * cache, float sin_sign, float theta_scale) {
  4103. // ref: https://github.com/jquesnelle/yarn/blob/master/scaled_rope/LlamaYaRNScaledRotaryEmbedding.py
  4104. float theta_t = theta_base_t;
  4105. float theta_h = theta_base_h;
  4106. float theta_w = theta_base_w;
  4107. float theta_e = theta_base_e; // extra position id for vision encoder
  4108. int sect_dims = sections[0] + sections[1] + sections[2] + sections[3];
  4109. int sec_w = sections[1] + sections[0];
  4110. int sec_e = sections[2] + sec_w;
  4111. GGML_ASSERT(sect_dims <= ne0);
  4112. for (int64_t i0 = 0; i0 < ne0; i0 += 2) {
  4113. const float ff = freq_factors ? freq_factors[i0/2] : 1.0f;
  4114. int sector = (i0 / 2) % sect_dims;
  4115. if (indep_sects) {
// compute theta independently for each dim section
// (i.e. reset the corresponding theta when `i0` moves from one section to another)
  4118. if (sector == 0) {
  4119. theta_t = theta_base_t;
  4120. }
  4121. else if (sector == sections[0]) {
theta_h = theta_base_h;
  4123. }
  4124. else if (sector == sec_w) {
  4125. theta_w = theta_base_w;
  4126. }
  4127. else if (sector == sec_e) {
  4128. theta_e = theta_base_e;
  4129. }
  4130. }
  4131. float theta = theta_t;
  4132. if (sector >= sections[0] && sector < sec_w) {
  4133. theta = theta_h;
  4134. }
  4135. else if (sector >= sec_w && sector < sec_w + sections[2]) {
  4136. theta = theta_w;
  4137. }
  4138. else if (sector >= sec_w + sections[2]) {
  4139. theta = theta_e;
  4140. }
  4141. rope_yarn(
  4142. theta/ff, freq_scale, corr_dims, i0, ext_factor, mscale, &cache[i0 + 0], &cache[i0 + 1]
  4143. );
  4144. cache[i0 + 1] *= sin_sign;
  4145. theta_t *= theta_scale;
  4146. theta_w *= theta_scale;
  4147. theta_h *= theta_scale;
  4148. theta_e *= theta_scale;
  4149. }
  4150. }
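// Section mapping used above (comment only, with a hypothetical sections = {16, 24, 24, 0}):
// the pair index sector = (i0/2) % 64 selects which position id drives the rotation:
// sectors [0, 16) use theta_t (temporal), [16, 40) use theta_h (height), [40, 64) use
// theta_w (width), and any sector at or past sec_w + sections[2] would use theta_e (the
// extra vision id). With indep_sects each theta is reset at its section boundary.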
  4151. static void ggml_compute_forward_rope_f32(
  4152. const ggml_compute_params * params,
  4153. ggml_tensor * dst,
  4154. const bool forward) {
  4155. const ggml_tensor * src0 = dst->src[0];
  4156. const ggml_tensor * src1 = dst->src[1];
  4157. const ggml_tensor * src2 = dst->src[2];
  4158. float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
  4159. int sections[4];
  4160. //const int n_past = ((int32_t *) dst->op_params)[0];
  4161. const int n_dims = ((int32_t *) dst->op_params)[1];
  4162. const int mode = ((int32_t *) dst->op_params)[2];
  4163. //const int n_ctx = ((int32_t *) dst->op_params)[3];
  4164. const int n_ctx_orig = ((int32_t *) dst->op_params)[4];
  4165. memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float));
  4166. memcpy(&freq_scale, (int32_t *) dst->op_params + 6, sizeof(float));
  4167. memcpy(&ext_factor, (int32_t *) dst->op_params + 7, sizeof(float));
  4168. memcpy(&attn_factor, (int32_t *) dst->op_params + 8, sizeof(float));
  4169. memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float));
  4170. memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float));
  4171. memcpy(&sections, (int32_t *) dst->op_params + 11, sizeof(int)*4);
  4172. GGML_TENSOR_UNARY_OP_LOCALS
  4173. //printf("ne0: %d, ne1: %d, ne2: %d, ne3: %d\n", ne0, ne1, ne2, ne3);
  4174. //printf("n_past = %d, ne2 = %d\n", n_past, ne2);
  4175. GGML_ASSERT(nb00 == sizeof(float));
  4176. const int ith = params->ith;
  4177. const int nth = params->nth;
  4178. const int nr = ggml_nrows(dst);
  4179. GGML_ASSERT(n_dims <= ne0);
  4180. GGML_ASSERT(n_dims % 2 == 0);
  4181. // rows per thread
  4182. const int dr = (nr + nth - 1)/nth;
  4183. // row range for this thread
  4184. const int ir0 = dr*ith;
  4185. const int ir1 = MIN(ir0 + dr, nr);
// running row counter used to select the rows [ir0, ir1) assigned to this thread
  4187. int ir = 0;
  4188. const float theta_scale = powf(freq_base, -2.0f/n_dims);
  4189. float corr_dims[2];
  4190. ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims);
  4191. const bool is_neox = mode & GGML_ROPE_TYPE_NEOX;
  4192. const bool is_mrope = mode & GGML_ROPE_TYPE_MROPE; // ggml_rope_multi, multimodal rotary position embedding
  4193. const bool is_vision = mode == GGML_ROPE_TYPE_VISION;
  4194. if (is_mrope) {
  4195. GGML_ASSERT(sections[0] > 0 || sections[1] > 0 || sections[2] > 0);
  4196. }
  4197. if (is_vision) {
  4198. GGML_ASSERT(n_dims == ne0/2);
  4199. }
  4200. const float * freq_factors = NULL;
  4201. if (src2 != NULL) {
  4202. GGML_ASSERT(src2->type == GGML_TYPE_F32);
  4203. GGML_ASSERT(src2->ne[0] >= n_dims / 2);
  4204. freq_factors = (const float *) src2->data;
  4205. }
  4206. // backward process uses inverse rotation by cos and sin.
  4207. // cos and sin build a rotation matrix, where the inverse is the transpose.
  4208. // this essentially just switches the sign of sin.
  4209. const float sin_sign = forward ? 1.0f : -1.0f;
  4210. const int32_t * pos = (const int32_t *) src1->data;
  4211. for (int64_t i3 = 0; i3 < ne3; i3++) { // batch
  4212. for (int64_t i2 = 0; i2 < ne2; i2++) { // seq-len
  4213. float * cache = (float *) params->wdata + (ne0 + CACHE_LINE_SIZE_F32)*ith;
  4214. if (!is_mrope) {
  4215. const int64_t p = pos[i2];
  4216. ggml_rope_cache_init(p, freq_scale, freq_factors, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale);
  4217. }
  4218. else {
  4219. const int64_t p_t = pos[i2];
  4220. const int64_t p_h = pos[i2 + ne2];
  4221. const int64_t p_w = pos[i2 + ne2 * 2];
  4222. const int64_t p_e = pos[i2 + ne2 * 3];
  4223. ggml_mrope_cache_init(
  4224. p_t, p_h, p_w, p_e, sections, is_vision,
  4225. freq_scale, freq_factors, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale);
  4226. }
  4227. for (int64_t i1 = 0; i1 < ne1; i1++) { // attn-heads
  4228. if (ir++ < ir0) continue;
  4229. if (ir > ir1) break;
  4230. if (is_neox || is_mrope) {
  4231. if (is_vision){
  4232. for (int64_t i0 = 0; i0 < n_dims; i0 += 2) {
  4233. const int64_t ic = i0/2;
  4234. const float cos_theta = cache[i0 + 0];
  4235. const float sin_theta = cache[i0 + 1];
  4236. const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00);
  4237. float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0);
  4238. const float x0 = src[0];
  4239. const float x1 = src[n_dims];
  4240. dst_data[0] = x0*cos_theta - x1*sin_theta;
  4241. dst_data[n_dims] = x0*sin_theta + x1*cos_theta;
  4242. }
  4243. } else {
  4244. for (int64_t i0 = 0; i0 < n_dims; i0 += 2) {
  4245. const int64_t ic = i0/2;
  4246. const float cos_theta = cache[i0 + 0];
  4247. const float sin_theta = cache[i0 + 1];
  4248. const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00);
  4249. float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0);
  4250. const float x0 = src[0];
  4251. const float x1 = src[n_dims/2];
  4252. dst_data[0] = x0*cos_theta - x1*sin_theta;
  4253. dst_data[n_dims/2] = x0*sin_theta + x1*cos_theta;
  4254. }
  4255. }
  4256. } else {
  4257. for (int64_t i0 = 0; i0 < n_dims; i0 += 2) {
  4258. const float cos_theta = cache[i0 + 0];
  4259. const float sin_theta = cache[i0 + 1];
  4260. const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
  4261. float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
  4262. const float x0 = src[0];
  4263. const float x1 = src[1];
  4264. dst_data[0] = x0*cos_theta - x1*sin_theta;
  4265. dst_data[1] = x0*sin_theta + x1*cos_theta;
  4266. }
  4267. }
  4268. if (is_vision) {
  4269. for (int64_t i0 = n_dims; i0 < ne0; i0 += 2) {
  4270. const int64_t ic = i0/2;
  4271. const float cos_theta = cache[i0 + 0];
  4272. const float sin_theta = cache[i0 + 1];
  4273. const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00);
  4274. float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0);
  4275. const float x0 = src[0];
  4276. const float x1 = src[n_dims];
  4277. dst_data[0] = x0*cos_theta - x1*sin_theta;
  4278. dst_data[n_dims] = x0*sin_theta + x1*cos_theta;
  4279. }
  4280. } else {
// fill the remaining channels with data from the src tensor
  4282. for (int64_t i0 = n_dims; i0 < ne0; i0 += 2) {
  4283. const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
  4284. float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
  4285. dst_data[0] = src[0];
  4286. dst_data[1] = src[1];
  4287. }
  4288. }
  4289. }
  4290. }
  4291. }
  4292. }
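// Rotation pairing summary for the branches above (comment only, e.g. n_dims = 8):
//   - normal mode pairs adjacent elements:       (x[0], x[1]), (x[2], x[3]), ...
//   - neox/mrope mode pairs the two halves:      (x[0], x[4]), (x[1], x[5]), ...
//   - vision mode pairs across the full width:   (x[0], x[8]), (x[1], x[9]), ...
// Each pair (x0, x1) is rotated as
//   x0' = x0*cos_theta - x1*sin_theta
//   x1' = x0*sin_theta + x1*cos_theta
// and channels beyond n_dims are copied through unchanged (except in vision mode, which
// rotates the full row).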
  4293. // TODO: deduplicate f16/f32 code
  4294. static void ggml_compute_forward_rope_f16(
  4295. const ggml_compute_params * params,
  4296. ggml_tensor * dst,
  4297. const bool forward) {
  4298. const ggml_tensor * src0 = dst->src[0];
  4299. const ggml_tensor * src1 = dst->src[1];
  4300. const ggml_tensor * src2 = dst->src[2];
  4301. float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
  4302. int sections[4];
  4303. //const int n_past = ((int32_t *) dst->op_params)[0];
  4304. const int n_dims = ((int32_t *) dst->op_params)[1];
  4305. const int mode = ((int32_t *) dst->op_params)[2];
  4306. //const int n_ctx = ((int32_t *) dst->op_params)[3];
  4307. const int n_ctx_orig = ((int32_t *) dst->op_params)[4];
  4308. memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float));
  4309. memcpy(&freq_scale, (int32_t *) dst->op_params + 6, sizeof(float));
  4310. memcpy(&ext_factor, (int32_t *) dst->op_params + 7, sizeof(float));
  4311. memcpy(&attn_factor, (int32_t *) dst->op_params + 8, sizeof(float));
  4312. memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float));
  4313. memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float));
  4314. memcpy(&sections, (int32_t *) dst->op_params + 11, sizeof(int)*4);
  4315. GGML_TENSOR_UNARY_OP_LOCALS
  4316. //printf("ne0: %d, ne1: %d, ne2: %d, ne3: %d\n", ne0, ne1, ne2, ne3);
  4317. //printf("n_past = %d, ne2 = %d\n", n_past, ne2);
  4318. GGML_ASSERT(nb0 == sizeof(ggml_fp16_t));
  4319. const int ith = params->ith;
  4320. const int nth = params->nth;
  4321. const int nr = ggml_nrows(dst);
  4322. GGML_ASSERT(n_dims <= ne0);
  4323. GGML_ASSERT(n_dims % 2 == 0);
  4324. // rows per thread
  4325. const int dr = (nr + nth - 1)/nth;
  4326. // row range for this thread
  4327. const int ir0 = dr*ith;
  4328. const int ir1 = MIN(ir0 + dr, nr);
// running row counter used to select the rows [ir0, ir1) assigned to this thread
  4330. int ir = 0;
  4331. const float theta_scale = powf(freq_base, -2.0f/n_dims);
  4332. float corr_dims[2];
  4333. ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims);
  4334. const bool is_neox = mode & GGML_ROPE_TYPE_NEOX;
  4335. const bool is_mrope = mode & GGML_ROPE_TYPE_MROPE;
  4336. const bool is_vision = mode == GGML_ROPE_TYPE_VISION;
  4337. if (is_mrope) {
  4338. GGML_ASSERT(sections[0] > 0 || sections[1] > 0 || sections[2] > 0);
  4339. }
  4340. if (is_vision) {
  4341. GGML_ASSERT(n_dims == ne0/2);
  4342. }
  4343. const float * freq_factors = NULL;
  4344. if (src2 != NULL) {
  4345. GGML_ASSERT(src2->type == GGML_TYPE_F32);
  4346. GGML_ASSERT(src2->ne[0] >= n_dims / 2);
  4347. freq_factors = (const float *) src2->data;
  4348. }
  4349. // backward process uses inverse rotation by cos and sin.
  4350. // cos and sin build a rotation matrix, where the inverse is the transpose.
  4351. // this essentially just switches the sign of sin.
  4352. const float sin_sign = forward ? 1.0f : -1.0f;
  4353. const int32_t * pos = (const int32_t *) src1->data;
  4354. for (int64_t i3 = 0; i3 < ne3; i3++) {
  4355. for (int64_t i2 = 0; i2 < ne2; i2++) {
  4356. float * cache = (float *) params->wdata + (ne0 + CACHE_LINE_SIZE_F32)*ith;
  4357. if (!is_mrope) {
  4358. const int64_t p = pos[i2];
  4359. ggml_rope_cache_init(p, freq_scale, freq_factors, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale);
  4360. }
  4361. else {
  4362. const int64_t p_t = pos[i2];
  4363. const int64_t p_h = pos[i2 + ne2];
  4364. const int64_t p_w = pos[i2 + ne2 * 2];
  4365. const int64_t p_e = pos[i2 + ne2 * 3];
  4366. ggml_mrope_cache_init(
  4367. p_t, p_h, p_w, p_e, sections, is_vision,
  4368. freq_scale, freq_factors, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale);
  4369. }
  4370. for (int64_t i1 = 0; i1 < ne1; i1++) {
  4371. if (ir++ < ir0) continue;
  4372. if (ir > ir1) break;
  4373. if (is_neox || is_mrope) {
  4374. if (is_vision) {
  4375. for (int64_t i0 = 0; i0 < n_dims; i0 += 2) {
  4376. const int64_t ic = i0/2;
  4377. const float cos_theta = cache[i0 + 0];
  4378. const float sin_theta = cache[i0 + 1];
  4379. const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00);
  4380. ggml_fp16_t * dst_data = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0);
  4381. const float x0 = GGML_FP16_TO_FP32(src[0]);
  4382. const float x1 = GGML_FP16_TO_FP32(src[n_dims]);
  4383. dst_data[0] = GGML_FP32_TO_FP16(x0*cos_theta - x1*sin_theta);
  4384. dst_data[n_dims] = GGML_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
  4385. }
  4386. } else {
  4387. for (int64_t i0 = 0; i0 < n_dims; i0 += 2) {
  4388. const int64_t ic = i0/2;
  4389. const float cos_theta = cache[i0 + 0];
  4390. const float sin_theta = cache[i0 + 1];
  4391. const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00);
  4392. ggml_fp16_t * dst_data = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0);
  4393. const float x0 = GGML_FP16_TO_FP32(src[0]);
  4394. const float x1 = GGML_FP16_TO_FP32(src[n_dims/2]);
  4395. dst_data[0] = GGML_FP32_TO_FP16(x0*cos_theta - x1*sin_theta);
  4396. dst_data[n_dims/2] = GGML_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
  4397. }
  4398. }
  4399. } else {
  4400. for (int64_t i0 = 0; i0 < n_dims; i0 += 2) {
  4401. const float cos_theta = cache[i0 + 0];
  4402. const float sin_theta = cache[i0 + 1];
  4403. const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
  4404. ggml_fp16_t * dst_data = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
  4405. const float x0 = GGML_FP16_TO_FP32(src[0]);
  4406. const float x1 = GGML_FP16_TO_FP32(src[1]);
  4407. dst_data[0] = GGML_FP32_TO_FP16(x0*cos_theta - x1*sin_theta);
  4408. dst_data[1] = GGML_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
  4409. }
  4410. }
  4411. if (is_vision) {
  4412. for (int64_t i0 = n_dims; i0 < ne0; i0 += 2) {
  4413. const int64_t ic = i0/2;
  4414. const float cos_theta = cache[i0 + 0];
  4415. const float sin_theta = cache[i0 + 1];
  4416. const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00);
  4417. ggml_fp16_t * dst_data = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0);
  4418. const float x0 = GGML_FP16_TO_FP32(src[0]);
  4419. const float x1 = GGML_FP16_TO_FP32(src[n_dims]);
  4420. dst_data[0] = GGML_FP32_TO_FP16(x0*cos_theta - x1*sin_theta);
  4421. dst_data[n_dims] = GGML_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
  4422. }
  4423. } else {
  4424. for (int64_t i0 = n_dims; i0 < ne0; i0 += 2) {
  4425. const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
  4426. ggml_fp16_t * dst_data = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
  4427. dst_data[0] = src[0];
  4428. dst_data[1] = src[1];
  4429. }
  4430. }
  4431. }
  4432. }
  4433. }
  4434. }
  4435. void ggml_compute_forward_rope(
  4436. const ggml_compute_params * params,
  4437. ggml_tensor * dst) {
  4438. const ggml_tensor * src0 = dst->src[0];
  4439. switch (src0->type) {
  4440. case GGML_TYPE_F16:
  4441. {
  4442. ggml_compute_forward_rope_f16(params, dst, true);
  4443. } break;
  4444. case GGML_TYPE_F32:
  4445. {
  4446. ggml_compute_forward_rope_f32(params, dst, true);
  4447. } break;
  4448. default:
  4449. {
  4450. GGML_ABORT("fatal error");
  4451. }
  4452. }
  4453. }
  4454. // ggml_compute_forward_rope_back
  4455. void ggml_compute_forward_rope_back(
  4456. const ggml_compute_params * params,
  4457. ggml_tensor * dst) {
  4458. const ggml_tensor * src0 = dst->src[0];
  4459. switch (src0->type) {
  4460. case GGML_TYPE_F16:
  4461. {
  4462. ggml_compute_forward_rope_f16(params, dst, false);
  4463. } break;
  4464. case GGML_TYPE_F32:
  4465. {
  4466. ggml_compute_forward_rope_f32(params, dst, false);
  4467. } break;
  4468. default:
  4469. {
  4470. GGML_ABORT("fatal error");
  4471. }
  4472. }
  4473. }
  4474. // ggml_compute_forward_conv_transpose_1d
  4475. static void ggml_compute_forward_conv_transpose_1d_f16_f32(
  4476. const ggml_compute_params * params,
  4477. ggml_tensor * dst) {
  4478. const ggml_tensor * src0 = dst->src[0];
  4479. const ggml_tensor * src1 = dst->src[1];
  4480. GGML_ASSERT(src0->type == GGML_TYPE_F16);
  4481. GGML_ASSERT(src1->type == GGML_TYPE_F32);
  4482. GGML_ASSERT( dst->type == GGML_TYPE_F32);
  4483. GGML_TENSOR_BINARY_OP_LOCALS
  4484. const int ith = params->ith;
  4485. const int nth = params->nth;
  4486. const int nk = ne00*ne01*ne02;
  4487. GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
  4488. GGML_ASSERT(nb10 == sizeof(float));
  4489. if (ith == 0) {
  4490. memset(params->wdata, 0, params->wsize);
  4491. // permute kernel data (src0) from (K x Cout x Cin) to (Cin x K x Cout)
  4492. {
  4493. ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + 0;
  4494. for (int64_t i02 = 0; i02 < ne02; i02++) {
  4495. for (int64_t i01 = 0; i01 < ne01; i01++) {
  4496. const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i02*nb02 + i01*nb01);
  4497. ggml_fp16_t * dst_data = wdata + i01*ne00*ne02;
  4498. for (int64_t i00 = 0; i00 < ne00; i00++) {
  4499. dst_data[i00*ne02 + i02] = src[i00];
  4500. }
  4501. }
  4502. }
  4503. }
  4504. // permute source data (src1) from (L x Cin) to (Cin x L)
  4505. {
  4506. ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + nk;
  4507. ggml_fp16_t * dst_data = wdata;
  4508. for (int64_t i11 = 0; i11 < ne11; i11++) {
  4509. const float * const src = (float *)((char *) src1->data + i11*nb11);
  4510. for (int64_t i10 = 0; i10 < ne10; i10++) {
  4511. dst_data[i10*ne11 + i11] = GGML_FP32_TO_FP16(src[i10]);
  4512. }
  4513. }
  4514. }
  4515. // need to zero dst since we are accumulating into it
  4516. memset(dst->data, 0, ggml_nbytes(dst));
  4517. }
  4518. ggml_barrier(params->threadpool);
  4519. const int32_t s0 = ((const int32_t*)(dst->op_params))[0];
  4520. // total rows in dst
  4521. const int nr = ne1;
  4522. // rows per thread
  4523. const int dr = (nr + nth - 1)/nth;
  4524. // row range for this thread
  4525. const int ir0 = dr*ith;
  4526. const int ir1 = MIN(ir0 + dr, nr);
  4527. ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + 0;
  4528. ggml_fp16_t * const wdata_src = wdata + nk;
  4529. for (int i1 = ir0; i1 < ir1; i1++) {
  4530. float * dst_data = (float *)((char *) dst->data + i1*nb1);
  4531. ggml_fp16_t * wdata_kernel = wdata + i1*ne02*ne00;
  4532. for (int i10 = 0; i10 < ne10; i10++) {
  4533. const int i1n = i10*ne11;
  4534. for (int i00 = 0; i00 < ne00; i00++) {
  4535. float v = 0;
  4536. ggml_vec_dot_f16(ne02, &v, 0,
  4537. (ggml_fp16_t *) wdata_src + i1n, 0,
  4538. (ggml_fp16_t *) wdata_kernel + i00*ne02, 0, 1);
  4539. dst_data[i10*s0 + i00] += v;
  4540. }
  4541. }
  4542. }
  4543. }
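// Accumulation pattern used above (comment only, and the same in the f32 variant below):
// after the repack, each output row i1 (one per output channel) is built as
//
//   dst[i1][i10*s0 + i00] += dot_over_Cin( src1[:, i10], kernel[i1][:, i00] )
//
// i.e. every input position i10 scatters a length-ne00 kernel window into the output at
// stride s0, which is the transpose of a strided 1d convolution. Overlapping windows
// accumulate, which is why dst is zeroed before the barrier.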
  4544. static void ggml_compute_forward_conv_transpose_1d_f32(
  4545. const ggml_compute_params * params,
  4546. ggml_tensor * dst) {
  4547. const ggml_tensor * src0 = dst->src[0];
  4548. const ggml_tensor * src1 = dst->src[1];
  4549. GGML_ASSERT(src0->type == GGML_TYPE_F32);
  4550. GGML_ASSERT(src1->type == GGML_TYPE_F32);
  4551. GGML_ASSERT( dst->type == GGML_TYPE_F32);
  4552. GGML_TENSOR_BINARY_OP_LOCALS
  4553. const int ith = params->ith;
  4554. const int nth = params->nth;
  4555. const int nk = ne00*ne01*ne02;
  4556. GGML_ASSERT(nb00 == sizeof(float));
  4557. GGML_ASSERT(nb10 == sizeof(float));
  4558. if (ith == 0) {
  4559. memset(params->wdata, 0, params->wsize);
// permute kernel data (src0) from (K x Cout x Cin) to (Cin x K x Cout)
  4561. {
  4562. float * const wdata = (float *) params->wdata + 0;
  4563. for (int64_t i02 = 0; i02 < ne02; i02++) {
  4564. for (int64_t i01 = 0; i01 < ne01; i01++) {
  4565. const float * const src = (float *)((char *) src0->data + i02*nb02 + i01*nb01);
  4566. float * dst_data = wdata + i01*ne00*ne02;
  4567. for (int64_t i00 = 0; i00 < ne00; i00++) {
  4568. dst_data[i00*ne02 + i02] = src[i00];
  4569. }
  4570. }
  4571. }
  4572. }
  4573. // prepare source data (src1)
  4574. {
  4575. float * const wdata = (float *) params->wdata + nk;
  4576. float * dst_data = wdata;
  4577. for (int64_t i11 = 0; i11 < ne11; i11++) {
  4578. const float * const src = (float *)((char *) src1->data + i11*nb11);
  4579. for (int64_t i10 = 0; i10 < ne10; i10++) {
  4580. dst_data[i10*ne11 + i11] = src[i10];
  4581. }
  4582. }
  4583. }
  4584. // need to zero dst since we are accumulating into it
  4585. memset(dst->data, 0, ggml_nbytes(dst));
  4586. }
  4587. ggml_barrier(params->threadpool);
  4588. const int32_t s0 = ((const int32_t*)(dst->op_params))[0];
  4589. // total rows in dst
  4590. const int nr = ne1;
  4591. // rows per thread
  4592. const int dr = (nr + nth - 1)/nth;
  4593. // row range for this thread
  4594. const int ir0 = dr*ith;
  4595. const int ir1 = MIN(ir0 + dr, nr);
  4596. float * const wdata = (float *) params->wdata + 0;
  4597. float * const wdata_src = wdata + nk;
  4598. for (int i1 = ir0; i1 < ir1; i1++) {
  4599. float * dst_data = (float *)((char *) dst->data + i1*nb1);
  4600. float * wdata_kernel = wdata + i1*ne02*ne00;
  4601. for (int i10 = 0; i10 < ne10; i10++) {
  4602. const int i1n = i10*ne11;
  4603. for (int i00 = 0; i00 < ne00; i00++) {
  4604. float v = 0;
  4605. ggml_vec_dot_f32(ne02, &v, 0,
  4606. wdata_src + i1n, 0,
  4607. wdata_kernel + i00*ne02, 0, 1);
  4608. dst_data[i10*s0 + i00] += v;
  4609. }
  4610. }
  4611. }
  4612. }
  4613. void ggml_compute_forward_conv_transpose_1d(
  4614. const ggml_compute_params * params,
  4615. ggml_tensor * dst) {
  4616. const ggml_tensor * src0 = dst->src[0];
  4617. switch (src0->type) {
  4618. case GGML_TYPE_F16:
  4619. {
  4620. ggml_compute_forward_conv_transpose_1d_f16_f32(params, dst);
  4621. } break;
  4622. case GGML_TYPE_F32:
  4623. {
  4624. ggml_compute_forward_conv_transpose_1d_f32(params, dst);
  4625. } break;
  4626. default:
  4627. {
  4628. GGML_ABORT("fatal error");
  4629. }
  4630. }
  4631. }
  4632. // ggml_compute_forward_im2col_f32
  4633. // src0: kernel [OC, IC, KH, KW]
  4634. // src1: image [N, IC, IH, IW]
  4635. // dst: result [N, OH, OW, IC*KH*KW]
  4636. static void ggml_compute_forward_im2col_f32(
  4637. const ggml_compute_params * params,
  4638. ggml_tensor * dst) {
  4639. const ggml_tensor * src0 = dst->src[0];
  4640. const ggml_tensor * src1 = dst->src[1];
  4641. GGML_ASSERT(src1->type == GGML_TYPE_F32);
  4642. GGML_ASSERT( dst->type == GGML_TYPE_F32);
  4643. GGML_TENSOR_BINARY_OP_LOCALS;
  4644. const int32_t s0 = ((const int32_t *)(dst->op_params))[0];
  4645. const int32_t s1 = ((const int32_t *)(dst->op_params))[1];
  4646. const int32_t p0 = ((const int32_t *)(dst->op_params))[2];
  4647. const int32_t p1 = ((const int32_t *)(dst->op_params))[3];
  4648. const int32_t d0 = ((const int32_t *)(dst->op_params))[4];
  4649. const int32_t d1 = ((const int32_t *)(dst->op_params))[5];
  4650. const bool is_2D = ((const int32_t *)(dst->op_params))[6] == 1;
  4651. const int ith = params->ith;
  4652. const int nth = params->nth;
  4653. const int64_t N = is_2D ? ne13 : ne12;
  4654. const int64_t IC = is_2D ? ne12 : ne11;
  4655. const int64_t IH = is_2D ? ne11 : 1;
  4656. const int64_t IW = ne10;
  4657. const int64_t KH = is_2D ? ne01 : 1;
  4658. const int64_t KW = ne00;
  4659. const int64_t OH = is_2D ? ne2 : 1;
  4660. const int64_t OW = ne1;
  4661. int ofs0 = is_2D ? nb13 : nb12;
  4662. int ofs1 = is_2D ? nb12 : nb11;
  4663. GGML_ASSERT(nb10 == sizeof(float));
  4664. // im2col: [N, IC, IH, IW] => [N, OH, OW, IC*KH*KW]
  4665. {
  4666. float * const wdata = (float *) dst->data;
  4667. for (int64_t in = 0; in < N; in++) {
  4668. for (int64_t ioh = 0; ioh < OH; ioh++) { // 1
  4669. for (int64_t iow = 0; iow < OW; iow++) {
  4670. for (int64_t iic = ith; iic < IC; iic += nth) {
  4671. // micro kernel
  4672. float * dst_data = wdata + (in*OH*OW + ioh*OW + iow)*(IC*KH*KW); // [IC, KH, KW]
  4673. const float * const src_data = (float *)((char *) src1->data + in*ofs0 + iic*ofs1); // [IH, IW]
  4674. for (int64_t ikh = 0; ikh < KH; ikh++) { // 1
  4675. for (int64_t ikw = 0; ikw < KW; ikw++) {
  4676. const int64_t iiw = iow*s0 + ikw*d0 - p0;
  4677. const int64_t iih = ioh*s1 + ikh*d1 - p1;
  4678. if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) {
  4679. dst_data[iic*(KH*KW) + ikh*KW + ikw] = 0;
  4680. } else {
  4681. dst_data[iic*(KH*KW) + ikh*KW + ikw] = (src_data[iih*IW + iiw]);
  4682. }
  4683. }
  4684. }
  4685. }
  4686. }
  4687. }
  4688. }
  4689. }
  4690. }
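// Index mapping sketch (comment only): for the 2D case the loop above implements
//
//   dst[in][ioh][iow][iic*KH*KW + ikh*KW + ikw] = src1[in][iic][ioh*s1 + ikh*d1 - p1][iow*s0 + ikw*d0 - p0]
//
// with out-of-bounds taps written as 0. A later matrix multiplication of this buffer with
// the flattened kernel yields the convolution result (the usual im2col + GEMM decomposition).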
  4691. // ggml_compute_forward_im2col_f16
  4692. // src0: kernel [OC, IC, KH, KW]
  4693. // src1: image [N, IC, IH, IW]
  4694. // dst: result [N, OH, OW, IC*KH*KW]
  4695. static void ggml_compute_forward_im2col_f16(
  4696. const ggml_compute_params * params,
  4697. ggml_tensor * dst) {
  4698. const ggml_tensor * src0 = dst->src[0];
  4699. const ggml_tensor * src1 = dst->src[1];
  4700. GGML_ASSERT(src0->type == GGML_TYPE_F16);
  4701. GGML_ASSERT(src1->type == GGML_TYPE_F32);
  4702. GGML_ASSERT( dst->type == GGML_TYPE_F16);
  4703. GGML_TENSOR_BINARY_OP_LOCALS;
  4704. const int32_t s0 = ((const int32_t *)(dst->op_params))[0];
  4705. const int32_t s1 = ((const int32_t *)(dst->op_params))[1];
  4706. const int32_t p0 = ((const int32_t *)(dst->op_params))[2];
  4707. const int32_t p1 = ((const int32_t *)(dst->op_params))[3];
  4708. const int32_t d0 = ((const int32_t *)(dst->op_params))[4];
  4709. const int32_t d1 = ((const int32_t *)(dst->op_params))[5];
  4710. const bool is_2D = ((const int32_t *)(dst->op_params))[6] == 1;
  4711. const int ith = params->ith;
  4712. const int nth = params->nth;
  4713. const int64_t N = is_2D ? ne13 : ne12;
  4714. const int64_t IC = is_2D ? ne12 : ne11;
  4715. const int64_t IH = is_2D ? ne11 : 1;
  4716. const int64_t IW = ne10;
  4717. const int64_t KH = is_2D ? ne01 : 1;
  4718. const int64_t KW = ne00;
  4719. const int64_t OH = is_2D ? ne2 : 1;
  4720. const int64_t OW = ne1;
  4721. int ofs0 = is_2D ? nb13 : nb12;
  4722. int ofs1 = is_2D ? nb12 : nb11;
  4723. GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
  4724. GGML_ASSERT(nb10 == sizeof(float));
  4725. // im2col: [N, IC, IH, IW] => [N, OH, OW, IC*KH*KW]
  4726. {
  4727. ggml_fp16_t * const wdata = (ggml_fp16_t *) dst->data;
  4728. for (int64_t in = 0; in < N; in++) {
  4729. for (int64_t ioh = 0; ioh < OH; ioh++) { // 1
  4730. for (int64_t iow = 0; iow < OW; iow++) {
  4731. for (int64_t iic = ith; iic < IC; iic += nth) {
  4732. // micro kernel
  4733. ggml_fp16_t * dst_data = wdata + (in*OH*OW + ioh*OW + iow)*(IC*KH*KW); // [IC, KH, KW]
  4734. const float * const src_data = (float *)((char *) src1->data + in*ofs0 + iic*ofs1); // [IH, IW]
  4735. for (int64_t ikh = 0; ikh < KH; ikh++) { // 1
  4736. for (int64_t ikw = 0; ikw < KW; ikw++) {
  4737. const int64_t iiw = iow*s0 + ikw*d0 - p0;
  4738. const int64_t iih = ioh*s1 + ikh*d1 - p1;
  4739. if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) {
  4740. dst_data[iic*(KH*KW) + ikh*KW + ikw] = 0;
  4741. } else {
  4742. dst_data[iic*(KH*KW) + ikh*KW + ikw] = GGML_FP32_TO_FP16(src_data[iih*IW + iiw]);
  4743. }
  4744. }
  4745. }
  4746. }
  4747. }
  4748. }
  4749. }
  4750. }
  4751. }
  4752. void ggml_compute_forward_im2col(
  4753. const ggml_compute_params * params,
  4754. ggml_tensor * dst) {
  4755. switch (dst->type) {
  4756. case GGML_TYPE_F16:
  4757. {
  4758. ggml_compute_forward_im2col_f16(params, dst);
  4759. } break;
  4760. case GGML_TYPE_F32:
  4761. {
  4762. ggml_compute_forward_im2col_f32(params, dst);
  4763. } break;
  4764. default:
  4765. {
  4766. GGML_ABORT("fatal error");
  4767. }
  4768. }
  4769. }
  4770. // ggml_compute_forward_im2col_back_f32
  4771. void ggml_compute_forward_im2col_back_f32(
  4772. const ggml_compute_params * params,
  4773. ggml_tensor * dst) {
  4774. const ggml_tensor * src0 = dst->src[0]; // gradients of forward pass output
  4775. const ggml_tensor * src1 = dst->src[1]; // convolution kernel
  4776. GGML_ASSERT(src0->type == GGML_TYPE_F32);
  4777. GGML_ASSERT(src1->type == GGML_TYPE_F32);
  4778. GGML_ASSERT( dst->type == GGML_TYPE_F32);
  4779. GGML_TENSOR_BINARY_OP_LOCALS;
  4780. const int32_t s0 = ((const int32_t *)(dst->op_params))[0];
  4781. const int32_t s1 = ((const int32_t *)(dst->op_params))[1];
  4782. const int32_t p0 = ((const int32_t *)(dst->op_params))[2];
  4783. const int32_t p1 = ((const int32_t *)(dst->op_params))[3];
  4784. const int32_t d0 = ((const int32_t *)(dst->op_params))[4];
  4785. const int32_t d1 = ((const int32_t *)(dst->op_params))[5];
  4786. const bool is_2D = ((const int32_t *)(dst->op_params))[6] == 1;
  4787. const int ith = params->ith;
  4788. const int nth = params->nth;
  4789. const int64_t N = is_2D ? ne3 : ne2;
  4790. const int64_t IC = is_2D ? ne2 : ne1;
  4791. const int64_t IH = is_2D ? ne1 : 1;
  4792. const int64_t IW = ne0;
  4793. const int64_t KH = is_2D ? ne11 : 1;
  4794. const int64_t KW = ne10;
  4795. const int64_t OH = is_2D ? ne02 : 1;
  4796. const int64_t OW = ne01;
  4797. int ofs0 = is_2D ? nb3 : nb2;
  4798. int ofs1 = is_2D ? nb2 : nb1;
  4799. GGML_ASSERT(nb0 == sizeof(float));
  4800. // im2col: [N, IC, IH, IW] => [N, OH, OW, IC*KH*KW]
  4801. {
  4802. float * const wdata = (float *) dst->data;
  4803. for (int64_t in = 0; in < N; in++) {
  4804. for (int64_t iic = ith; iic < IC; iic += nth) {
  4805. for (int64_t iih = 0; iih < IH; iih++) {
  4806. for (int64_t iiw = 0; iiw < IW; iiw++) {
  4807. // micro kernel
  4808. float grad = 0.0f;
  4809. for (int64_t ikh = 0; ikh < KH; ikh++) {
  4810. for (int64_t ikw = 0; ikw < KW; ikw++) {
  4811. // For s0 > 1 some values were skipped over in the forward pass.
  4812. // These values have tmpw % s0 != 0 and need to be skipped in the backwards pass as well.
  4813. const int64_t tmpw = (iiw + p0 - ikw*d0);
  4814. if (tmpw % s0 != 0) {
  4815. continue;
  4816. }
  4817. const int64_t iow = tmpw / s0;
// Same logic as above, but for s1.
  4819. int64_t ioh;
  4820. if (is_2D) {
  4821. const int64_t tmph = iih + p1 - ikh*d1;
  4822. if (tmph % s1 != 0) {
  4823. continue;
  4824. }
  4825. ioh = tmph / s1;
  4826. } else {
  4827. ioh = 0;
  4828. }
  4829. if (iow < 0 || iow >= OW || ioh < 0 || ioh >= OH) {
  4830. continue;
  4831. }
  4832. const float * const grad_in = (const float *) src0->data
  4833. + (in*OH*OW + ioh*OW + iow)*(IC*KH*KW); // [IC, KH, KW]
  4834. grad += grad_in[iic*(KH*KW) + ikh*KW + ikw];
  4835. }
  4836. }
  4837. float * dst_data = (float *)((char *) wdata + (in*ofs0 + iic*ofs1)); // [IH, IW]
  4838. dst_data[iih*IW + iiw] = grad;
  4839. }
  4840. }
  4841. }
  4842. }
  4843. }
  4844. }
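// Backward mapping sketch (comment only): each input pixel (iih, iiw) accumulates the
// gradients of every forward tap that read it. A tap (ikw, iow) contributed only if
// iiw = iow*s0 + ikw*d0 - p0, i.e. iow = (iiw + p0 - ikw*d0)/s0, so taps where that
// quantity is not divisible by s0 (positions skipped by the stride in the forward pass)
// or where iow falls outside [0, OW) are ignored, exactly as the `continue` checks do.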
  4845. // ggml_compute_forward_conv_transpose_2d
  4846. void ggml_compute_forward_conv_transpose_2d(
  4847. const ggml_compute_params * params,
  4848. ggml_tensor * dst) {
  4849. const ggml_tensor * src0 = dst->src[0];
  4850. const ggml_tensor * src1 = dst->src[1];
  4851. GGML_ASSERT(src0->type == GGML_TYPE_F16);
  4852. GGML_ASSERT(src1->type == GGML_TYPE_F32);
  4853. GGML_ASSERT( dst->type == GGML_TYPE_F32);
  4854. GGML_TENSOR_BINARY_OP_LOCALS
  4855. const int ith = params->ith;
  4856. const int nth = params->nth;
  4857. const int nk = ne00*ne01*ne02*ne03;
  4858. GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
  4859. GGML_ASSERT(nb10 == sizeof(float));
  4860. if (ith == 0) {
  4861. memset(params->wdata, 0, params->wsize);
  4862. // permute kernel data (src0) from (Kw x Kh x Cout x Cin) to (Cin x Kw x Kh x Cout)
  4863. {
  4864. ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + 0;
  4865. for (int64_t i03 = 0; i03 < ne03; i03++) {
  4866. for (int64_t i02 = 0; i02 < ne02; i02++) {
  4867. const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i03*nb03 + i02*nb02);
  4868. ggml_fp16_t * dst_data = wdata + i02*ne01*ne00*ne03;
  4869. for (int64_t i01 = 0; i01 < ne01; i01++) {
  4870. for (int64_t i00 = 0; i00 < ne00; i00++) {
  4871. dst_data[i01*ne00*ne03 + i00*ne03 + i03] = src[i01 * ne00 + i00];
  4872. }
  4873. }
  4874. }
  4875. }
  4876. }
  4877. // permute source data (src1) from (Sw x Sh x Cin) to (Cin x Sw x Sh)
  4878. {
  4879. ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + nk;
  4880. for (int i12 = 0; i12 < ne12; i12++) {
  4881. for (int i11 = 0; i11 < ne11; i11++) {
  4882. const float * const src = (float *)((char *) src1->data + i12*nb12 + i11*nb11);
  4883. ggml_fp16_t * dst_data = wdata + i11*ne10*ne12;
  4884. for (int i10 = 0; i10 < ne10; i10++) {
  4885. dst_data[i10*ne12 + i12] = GGML_FP32_TO_FP16(src[i10]);
  4886. }
  4887. }
  4888. }
  4889. }
  4890. memset(dst->data, 0, ggml_nbytes(dst));
  4891. }
  4892. ggml_barrier(params->threadpool);
  4893. const int32_t stride = ggml_get_op_params_i32(dst, 0);
  4894. // total patches in dst
  4895. const int np = ne2;
  4896. // patches per thread
  4897. const int dp = (np + nth - 1)/nth;
  4898. // patch range for this thread
  4899. const int ip0 = dp*ith;
  4900. const int ip1 = MIN(ip0 + dp, np);
  4901. ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + 0;
  4902. ggml_fp16_t * const wdata_src = wdata + nk;
  4903. for (int i2 = ip0; i2 < ip1; i2++) { // Cout
  4904. float * dst_data = (float *)((char *) dst->data + i2*nb2);
  4905. ggml_fp16_t * wdata_kernel = wdata + i2*ne01*ne00*ne03;
  4906. for (int i11 = 0; i11 < ne11; i11++) {
  4907. for (int i10 = 0; i10 < ne10; i10++) {
  4908. const int i1n = i11*ne10*ne12 + i10*ne12;
  4909. for (int i01 = 0; i01 < ne01; i01++) {
  4910. for (int i00 = 0; i00 < ne00; i00++) {
  4911. float v = 0;
  4912. ggml_vec_dot_f16(ne03, &v, 0,
  4913. wdata_src + i1n, 0,
  4914. wdata_kernel + i01*ne00*ne03 + i00*ne03, 0, 1);
  4915. dst_data[(i11*stride + i01)*ne0 + i10*stride + i00] += v;
  4916. }
  4917. }
  4918. }
  4919. }
  4920. }
  4921. }
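// Output indexing sketch (comment only): analogous to the 1d case, every input position
// (i11, i10) scatters a Kh x Kw kernel window into the output of channel i2 at
//
//   dst[i2][i11*stride + i01][i10*stride + i00] += dot_over_Cin(...)
//
// so both spatial dimensions are upsampled by `stride` and overlapping windows accumulate,
// which is why dst is zeroed before the barrier.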
  4922. // ggml_compute_forward_pool_1d_sk_p0
  4923. static void ggml_compute_forward_pool_1d_sk_p0(
  4924. const ggml_compute_params * params,
  4925. const ggml_op_pool op,
  4926. const int k,
  4927. ggml_tensor * dst) {
  4928. const ggml_tensor * src = dst->src[0];
  4929. assert(src->type == GGML_TYPE_F32 || src->type == GGML_TYPE_F16);
  4930. if (params->ith != 0) {
  4931. return;
  4932. }
  4933. const char * cdata = (const char *)src->data;
  4934. const char * const data_end = cdata + ggml_nbytes(src);
  4935. float * drow = (float *)dst->data;
  4936. const int64_t rs = dst->ne[0];
  4937. while (cdata < data_end) {
  4938. const void * srow = (const void *)cdata;
  4939. int j = 0;
  4940. for (int64_t i = 0; i < rs; ++i) {
  4941. switch (op) {
  4942. case GGML_OP_POOL_AVG: drow[i] = 0; break;
  4943. case GGML_OP_POOL_MAX: drow[i] = -FLT_MAX; break;
  4944. case GGML_OP_POOL_COUNT: GGML_ABORT("fatal error");
  4945. }
  4946. for (int ki = 0; ki < k; ++ki) {
  4947. const float srow_j = (src->type == GGML_TYPE_F32) ? ((const float*)srow)[j] : GGML_FP16_TO_FP32(((const ggml_fp16_t*)srow)[j]);
  4948. switch (op) {
  4949. case GGML_OP_POOL_AVG: drow[i] += srow_j; break;
  4950. case GGML_OP_POOL_MAX: if (srow_j > drow[i]) drow[i] = srow_j; break;
  4951. case GGML_OP_POOL_COUNT: GGML_ABORT("fatal error");
  4952. }
  4953. ++j;
  4954. }
  4955. switch (op) {
  4956. case GGML_OP_POOL_AVG: drow[i] /= k; break;
  4957. case GGML_OP_POOL_MAX: break;
  4958. case GGML_OP_POOL_COUNT: GGML_ABORT("fatal error");
  4959. }
  4960. }
  4961. cdata += src->nb[1];
  4962. drow += rs;
  4963. }
  4964. }
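// Pooling sketch (comment only): with k = s and no padding (the only case handled here,
// see the asserts in the caller), the input row is split into disjoint windows of k
// elements; e.g. src = {1, 2, 3, 4} with k = 2 gives {1.5, 3.5} for GGML_OP_POOL_AVG and
// {2, 4} for GGML_OP_POOL_MAX.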
  4965. // ggml_compute_forward_pool_1d
  4966. void ggml_compute_forward_pool_1d(
  4967. const ggml_compute_params * params,
  4968. ggml_tensor * dst) {
  4969. const int32_t * opts = (const int32_t *)dst->op_params;
  4970. ggml_op_pool op = static_cast<ggml_op_pool>(opts[0]);
  4971. const int k0 = opts[1];
  4972. const int s0 = opts[2];
  4973. const int p0 = opts[3];
  4974. GGML_ASSERT(p0 == 0); // padding not supported
  4975. GGML_ASSERT(k0 == s0); // only s = k supported
  4976. ggml_compute_forward_pool_1d_sk_p0(params, op, k0, dst);
  4977. }
  4978. // ggml_compute_forward_pool_2d
  4979. void ggml_compute_forward_pool_2d(
  4980. const ggml_compute_params * params,
  4981. ggml_tensor * dst) {
  4982. const ggml_tensor * src = dst->src[0];
  4983. assert(src->type == GGML_TYPE_F32 || src->type == GGML_TYPE_F16);
  4984. if (params->ith != 0) {
  4985. return;
  4986. }
  4987. const int32_t * opts = (const int32_t *)dst->op_params;
  4988. ggml_op_pool op = static_cast<ggml_op_pool>(opts[0]);
  4989. const int k0 = opts[1];
  4990. const int k1 = opts[2];
  4991. const int s0 = opts[3];
  4992. const int s1 = opts[4];
  4993. const int p0 = opts[5];
  4994. const int p1 = opts[6];
  4995. const char * cdata = (const char*)src->data;
  4996. const char * const data_end = cdata + ggml_nbytes(src);
  4997. const int64_t px = dst->ne[0];
  4998. const int64_t py = dst->ne[1];
  4999. const int64_t pa = px * py;
  5000. float * dplane = (float *)dst->data;
  5001. const int ka = k0 * k1;
  5002. const int offset0 = -p0;
  5003. const int offset1 = -p1;
  5004. while (cdata < data_end) {
  5005. for (int oy = 0; oy < py; ++oy) {
  5006. float * const drow = dplane + oy * px;
  5007. for (int ox = 0; ox < px; ++ox) {
  5008. float * const out = drow + ox;
  5009. switch (op) {
  5010. case GGML_OP_POOL_AVG: *out = 0; break;
  5011. case GGML_OP_POOL_MAX: *out = -FLT_MAX; break;
  5012. case GGML_OP_POOL_COUNT: GGML_ABORT("fatal error");
  5013. }
  5014. const int ix = offset0 + ox * s0;
  5015. const int iy = offset1 + oy * s1;
  5016. for (int ky = 0; ky < k1; ++ky) {
  5017. if (iy + ky < 0 || iy + ky >= src->ne[1]) continue;
  5018. const void * srow = (const void *)(cdata + src->nb[1] * (iy + ky));
  5019. for (int kx = 0; kx < k0; ++kx) {
  5020. int j = ix + kx;
  5021. if (j < 0 || j >= src->ne[0]) continue;
  5022. const float srow_j = (src->type == GGML_TYPE_F32) ? ((const float*)srow)[j] : GGML_FP16_TO_FP32(((const ggml_fp16_t*)srow)[j]);
  5023. switch (op) {
  5024. case GGML_OP_POOL_AVG: *out += srow_j; break;
  5025. case GGML_OP_POOL_MAX: if (srow_j > *out) *out = srow_j; break;
  5026. case GGML_OP_POOL_COUNT: GGML_ABORT("fatal error");
  5027. }
  5028. }
  5029. }
  5030. switch (op) {
  5031. case GGML_OP_POOL_AVG: *out /= ka; break;
  5032. case GGML_OP_POOL_MAX: break;
  5033. case GGML_OP_POOL_COUNT: GGML_ABORT("fatal error");
  5034. }
  5035. }
  5036. }
  5037. cdata += src->nb[2];
  5038. dplane += pa;
  5039. }
  5040. }
  5041. // ggml_compute_forward_pool_2d_back
  5042. void ggml_compute_forward_pool_2d_back(
  5043. const ggml_compute_params * params,
  5044. ggml_tensor * dst) {
  5045. const ggml_tensor * src = dst->src[0];
  5046. const ggml_tensor * dstf = dst->src[1]; // forward tensor of dst
  5047. assert(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
  5048. if (params->ith != 0) {
  5049. return;
  5050. }
  5051. const int32_t * opts = (const int32_t *)dst->op_params;
  5052. ggml_op_pool op = static_cast<ggml_op_pool>(opts[0]);
  5053. const int k0 = opts[1];
  5054. const int k1 = opts[2];
  5055. const int s0 = opts[3];
  5056. const int s1 = opts[4];
  5057. const int p0 = opts[5];
  5058. const int p1 = opts[6];
  5059. char * cdata = (char *) dst->data;
  5060. const char * cdataf = (const char *) dstf->data;
  5061. const char * const data_end = cdata + ggml_nbytes(dst);
  5062. GGML_ASSERT(params->ith == 0);
  5063. memset(cdata, 0, ggml_nbytes(dst));
  5064. const int64_t px = src->ne[0];
  5065. const int64_t py = src->ne[1];
  5066. const int64_t pa = px * py;
  5067. const float * splane = (const float *) src->data;
  5068. const int ka = k0 * k1;
  5069. const int offset0 = -p0;
  5070. const int offset1 = -p1;
  5071. while (cdata < data_end) {
  5072. for (int oy = 0; oy < py; ++oy) {
  5073. const float * const srow = splane + oy * px;
  5074. for (int ox = 0; ox < px; ++ox) {
  5075. const float grad0 = srow[ox];
  5076. const int ix = offset0 + ox * s0;
  5077. const int iy = offset1 + oy * s1;
  5078. if (op == GGML_OP_POOL_MAX) {
  5079. float maxval = -FLT_MAX;
  5080. int kxmax = -1;
  5081. int kymax = -1;
  5082. for (int ky = 0; ky < k1; ++ky) {
  5083. if (iy + ky < 0 || iy + ky >= dst->ne[1]) {
  5084. continue;
  5085. }
  5086. const void * drowf = (const void *)(cdataf + dst->nb[1] * (iy + ky));
  5087. for (int kx = 0; kx < k0; ++kx) {
  5088. int j = ix + kx;
  5089. if (j < 0 || j >= dst->ne[0]) {
  5090. continue;
  5091. }
  5092. const float val = dst->type == GGML_TYPE_F32 ?
  5093. ((const float *) drowf)[j] : GGML_FP16_TO_FP32(((const ggml_fp16_t *) drowf)[j]);
  5094. if (val <= maxval) {
  5095. continue;
  5096. }
  5097. maxval = val;
  5098. kxmax = kx;
  5099. kymax = ky;
  5100. }
  5101. }
  5102. if (kxmax == -1 || kymax == -1) {
  5103. continue;
  5104. }
  5105. void * drow = (void *)(cdata + dst->nb[1] * (iy + kymax));
  5106. const int j = ix + kxmax;
  5107. if (dst->type == GGML_TYPE_F32) {
  5108. ((float *) drow)[j] += grad0;
  5109. } else {
  5110. ((ggml_fp16_t *) drow)[j] = GGML_FP32_TO_FP16(grad0 + GGML_FP16_TO_FP32(((const ggml_fp16_t *) drow)[j]));
  5111. }
  5112. } else if (op == GGML_OP_POOL_AVG) {
  5113. const float grad = grad0 / ka;
  5114. for (int ky = 0; ky < k1; ++ky) {
  5115. if (iy + ky < 0 || iy + ky >= dst->ne[1]) {
  5116. continue;
  5117. }
  5118. void * drow = (void *)(cdata + dst->nb[1] * (iy + ky));
  5119. for (int kx = 0; kx < k0; ++kx) {
  5120. int j = ix + kx;
  5121. if (j < 0 || j >= dst->ne[0]) {
  5122. continue;
  5123. }
  5124. if (dst->type == GGML_TYPE_F32) {
  5125. ((float *) drow)[j] += grad;
  5126. } else {
  5127. ((ggml_fp16_t *) drow)[j] += GGML_FP32_TO_FP16(grad);
  5128. }
  5129. }
  5130. }
  5131. } else {
  5132. GGML_ASSERT(false);
  5133. }
  5134. }
  5135. }
  5136. cdata += dst->nb[2];
  5137. cdataf += dst->nb[2];
  5138. splane += pa;
  5139. }
  5140. }
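// Gradient routing sketch (comment only): for GGML_OP_POOL_MAX each output gradient grad0
// is added only at the position that held the window maximum in the forward input
// (re-discovered here by scanning dstf), while for GGML_OP_POOL_AVG it is spread uniformly
// as grad0/(k0*k1) over every in-bounds position of the window.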
  5141. // ggml_compute_forward_upscale
  5142. static void ggml_compute_forward_upscale_f32(
  5143. const ggml_compute_params * params,
  5144. ggml_tensor * dst) {
  5145. const ggml_tensor * src0 = dst->src[0];
  5146. GGML_ASSERT(src0->type == GGML_TYPE_F32);
  5147. const int ith = params->ith;
  5148. const int nth = params->nth;
  5149. GGML_TENSOR_UNARY_OP_LOCALS
  5150. const float sf0 = (float)ne0/src0->ne[0];
  5151. const float sf1 = (float)ne1/src0->ne[1];
  5152. const float sf2 = (float)ne2/src0->ne[2];
  5153. const float sf3 = (float)ne3/src0->ne[3];
  5154. const ggml_scale_mode mode = (ggml_scale_mode) ggml_get_op_params_i32(dst, 0);
  5155. if (mode == GGML_SCALE_MODE_NEAREST) {
  5156. for (int64_t i3 = 0; i3 < ne3; i3++) {
  5157. const int64_t i03 = i3 / sf3;
  5158. for (int64_t i2 = ith; i2 < ne2; i2 += nth) {
  5159. const int64_t i02 = i2 / sf2;
  5160. for (int64_t i1 = 0; i1 < ne1; i1++) {
  5161. const int64_t i01 = i1 / sf1;
  5162. for (int64_t i0 = 0; i0 < ne0; i0++) {
  5163. const int64_t i00 = i0 / sf0;
  5164. const float * x = (float *)((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
  5165. float * y = (float *)((char *) dst->data + i0*nb0 + i1*nb1 + i2*nb2 + i3*nb3);
  5166. *y = *x;
  5167. }
  5168. }
  5169. }
  5170. }
  5171. } else if (mode == GGML_SCALE_MODE_BILINEAR) {
  5172. // setting a pixel offset of 0 would replicate the behavior of pytorch interpolate with align_corners=True
  5173. const float pixel_offset = 0.5f;
  5174. for (int64_t i3 = 0; i3 < ne3; i3++) {
  5175. const int64_t i03 = i3 / sf3;
  5176. for (int64_t i2 = ith; i2 < ne2; i2 += nth) {
  5177. const int64_t i02 = i2 / sf2;
  5178. for (int64_t i1 = 0; i1 < ne1; i1++) {
  5179. const float y = ((float)i1 + pixel_offset) / sf1 - pixel_offset;
  5180. int64_t y0 = (int64_t)floorf(y);
  5181. int64_t y1 = y0 + 1;
  5182. y0 = std::max(int64_t(0), std::min(y0, ne01 - 1));
  5183. y1 = std::max(int64_t(0), std::min(y1, ne01 - 1));
  5184. float dy = y - (float)y0;
  5185. dy = std::max(0.0f, std::min(dy, 1.0f));
  5186. for (int64_t i0 = 0; i0 < ne0; i0++) {
  5187. const float x = ((float)i0 + pixel_offset) / sf0 - pixel_offset;
  5188. int64_t x0 = (int64_t)floorf(x);
  5189. int64_t x1 = x0 + 1;
  5190. x0 = std::max(int64_t(0), std::min(x0, ne00 - 1));
  5191. x1 = std::max(int64_t(0), std::min(x1, ne00 - 1));
  5192. float dx = x - (float)x0;
  5193. dx = std::max(0.0f, std::min(dx, 1.0f));
  5194. // fetch the four surrounding pixel values and interpolate
  5195. const float a = *(const float *)((const char *)src0->data + x0*nb00 + y0*nb01 + i02*nb02 + i03*nb03);
  5196. const float b = *(const float *)((const char *)src0->data + x1*nb00 + y0*nb01 + i02*nb02 + i03*nb03);
  5197. const float c = *(const float *)((const char *)src0->data + x0*nb00 + y1*nb01 + i02*nb02 + i03*nb03);
  5198. const float d = *(const float *)((const char *)src0->data + x1*nb00 + y1*nb01 + i02*nb02 + i03*nb03);
  5199. const float val = a*(1 - dx)*(1 - dy) + b*dx*(1 - dy) + c*(1 - dx)*dy + d*dx*dy;
  5200. float * y_dst = (float *)((char *)dst->data + i0*nb0 + i1*nb1 + i2*nb2 + i3*nb3);
  5201. *y_dst = val;
  5202. }
  5203. }
  5204. }
  5205. }
  5206. } else {
  5207. GGML_ABORT("unsupported upscale mode");
  5208. }
  5209. }
  5210. void ggml_compute_forward_upscale(
  5211. const ggml_compute_params * params,
  5212. ggml_tensor * dst) {
  5213. const ggml_tensor * src0 = dst->src[0];
  5214. switch (src0->type) {
  5215. case GGML_TYPE_F32:
  5216. {
  5217. ggml_compute_forward_upscale_f32(params, dst);
  5218. } break;
  5219. default:
  5220. {
  5221. GGML_ABORT("fatal error");
  5222. }
  5223. }
  5224. }
  5225. // ggml_compute_forward_pad
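// copies src0 into the low-index corner of the (larger) dst tensor and
// fills every element outside the source extent with zero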
  5226. static void ggml_compute_forward_pad_f32(
  5227. const ggml_compute_params * params,
  5228. ggml_tensor * dst) {
  5229. const ggml_tensor * src0 = dst->src[0];
  5230. GGML_ASSERT(src0->nb[0] == sizeof(float));
  5231. GGML_ASSERT( dst->nb[0] == sizeof(float));
  5232. const int ith = params->ith;
  5233. const int nth = params->nth;
  5234. GGML_TENSOR_UNARY_OP_LOCALS
  5235. float * dst_ptr = (float *) dst->data;
  5236. // TODO: optimize
  5237. for (int64_t i2 = 0; i2 < ne2; ++i2) {
  5238. for (int64_t i1 = ith; i1 < ne1; i1 += nth) {
  5239. for (int64_t i0 = 0; i0 < ne0; ++i0) {
  5240. for (int64_t i3 = 0; i3 < ne3; ++i3) {
  5241. const int64_t dst_idx = i3*(ne0*ne1*ne2) + i2*(ne0*ne1) + i1*ne0 + i0;
  5242. const float * src_ptr = (const float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
  5243. if (i0 < ne00 && i1 < ne01 && i2 < ne02 && i3 < ne03) {
  5244. dst_ptr[dst_idx] = *src_ptr;
  5245. } else {
  5246. dst_ptr[dst_idx] = 0;
  5247. }
  5248. }
  5249. }
  5250. }
  5251. }
  5252. }
  5253. void ggml_compute_forward_pad(
  5254. const ggml_compute_params * params,
  5255. ggml_tensor * dst) {
  5256. const ggml_tensor * src0 = dst->src[0];
  5257. switch (src0->type) {
  5258. case GGML_TYPE_F32:
  5259. {
  5260. ggml_compute_forward_pad_f32(params, dst);
  5261. } break;
  5262. default:
  5263. {
  5264. GGML_ABORT("fatal error");
  5265. }
  5266. }
  5267. }
  5268. // ggml_compute_forward_pad_reflect_1d
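// reflect-pads dim 0 by p0 elements on the left and p1 on the right without repeating
// the border sample, e.g. [1 2 3 4] with p0 = p1 = 2 -> [3 2 1 2 3 4 3 2]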
void ggml_compute_forward_pad_reflect_1d(
        const ggml_compute_params * params,
        ggml_tensor * dst) {

    const ggml_tensor * src0 = dst->src[0];

    GGML_ASSERT(src0->type == GGML_TYPE_F32);
    GGML_ASSERT( dst->type == GGML_TYPE_F32);

    const int ith = params->ith;
    const int nth = params->nth;

    const int32_t * opts = (const int32_t *) dst->op_params;
    const int p0 = opts[0];
    const int p1 = opts[1];

    GGML_TENSOR_UNARY_OP_LOCALS

    for (int64_t i3 = 0; i3 < ne3; i3++) {
        for (int64_t i2 = 0; i2 < ne2; i2++) {
            for (int64_t i1 = ith; i1 < ne1; i1 += nth) {
                float * left  = (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 +         p0*nb0);
                float * right = (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + (ne0-p1-1)*nb0);

                ggml_vec_cpy_f32(ne00, left, (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01));

                for (int i0 = 1; i0 <= p0; i0++) { left[-i0]  = left[i0];   }
                for (int i0 = 1; i0 <= p1; i0++) { right[i0] = right[-i0]; }
            }
        }
    }
}
  5293. // ggml_compute_forward_arange
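// fills dst with the sequence start, start + step, start + 2*step, ...
// (same semantics as torch.arange), e.g. start = 0, stop = 5, step = 2 -> [0 2 4]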
static void ggml_compute_forward_arange_f32(
        const ggml_compute_params * params,
        ggml_tensor * dst) {

    GGML_ASSERT(dst->nb[0] == sizeof(float));

    const int ith = params->ith;
    const int nth = params->nth;

    const float start = ggml_get_op_params_f32(dst, 0);
    const float stop  = ggml_get_op_params_f32(dst, 1);
    const float step  = ggml_get_op_params_f32(dst, 2);

    const int64_t steps = (int64_t) ceilf((stop - start) / step);

    GGML_ASSERT(ggml_nelements(dst) == steps);

    for (int64_t i = ith; i < steps; i += nth) {
        float value = start + step * i;
        ((float *)dst->data)[i] = value;
    }
}

void ggml_compute_forward_arange(
        const ggml_compute_params * params,
        ggml_tensor * dst) {
    switch (dst->type) {
        case GGML_TYPE_F32:
            {
                ggml_compute_forward_arange_f32(params, dst);
            } break;
        default:
            {
                GGML_ABORT("fatal error");
            }
    }
}
  5324. static void ggml_compute_forward_timestep_embedding_f32(
  5325. const ggml_compute_params * params,
  5326. ggml_tensor * dst) {
  5327. const ggml_tensor * src0 = dst->src[0];
  5328. GGML_ASSERT(src0->nb[0] == sizeof(float));
  5329. const int ith = params->ith;
  5330. const int nth = params->nth;
  5331. GGML_TENSOR_UNARY_OP_LOCALS
  5332. const int dim = ggml_get_op_params_i32(dst, 0);
  5333. const int max_period = ggml_get_op_params_i32(dst, 1);
  5334. int half = dim / 2;
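// sinusoidal timestep embedding: for j in [0, half) the frequency is
//   freq_j = exp(-ln(max_period) * j / half)
// and timestep t = src0[i] is embedded in row i of dst as
//   [cos(t*freq_0), ..., cos(t*freq_{half-1}), sin(t*freq_0), ..., sin(t*freq_{half-1})]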
  5335. for (int64_t i = 0; i < ne00; i++) {
  5336. float * embed_data = (float *)((char *) dst->data + i*nb1);
  5337. for (int64_t j = ith; j < half; j += nth) {
  5338. float timestep = ((float *)src0->data)[i];
  5339. float freq = (float)expf(-logf(max_period) * j / half);
  5340. float arg = timestep * freq;
  5341. embed_data[j] = cosf(arg);
  5342. embed_data[j + half] = sinf(arg);
  5343. }
  5344. if (dim % 2 != 0 && ith == 0) {
  5345. embed_data[dim] = 0.f;
  5346. }
  5347. }
  5348. }
  5349. void ggml_compute_forward_timestep_embedding(
  5350. const ggml_compute_params * params,
  5351. ggml_tensor * dst) {
  5352. const ggml_tensor * src0 = dst->src[0];
  5353. switch (src0->type) {
  5354. case GGML_TYPE_F32:
  5355. {
  5356. ggml_compute_forward_timestep_embedding_f32(params, dst);
  5357. } break;
  5358. default:
  5359. {
  5360. GGML_ABORT("fatal error");
  5361. }
  5362. }
  5363. }
  5364. // ggml_compute_forward_argsort
  5365. static void ggml_compute_forward_argsort_f32(
  5366. const ggml_compute_params * params,
  5367. ggml_tensor * dst) {
  5368. const ggml_tensor * src0 = dst->src[0];
  5369. GGML_TENSOR_UNARY_OP_LOCALS
  5370. GGML_ASSERT(nb0 == sizeof(float));
  5371. const int ith = params->ith;
  5372. const int nth = params->nth;
  5373. const int64_t nr = ggml_nrows(src0);
  5374. ggml_sort_order order = (ggml_sort_order) ggml_get_op_params_i32(dst, 0);
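// for every row, dst receives the permutation of indices [0, ne0) that sorts the row,
// e.g. [0.3 0.1 0.2] with GGML_SORT_ORDER_ASC -> [1 2 0]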
  5375. for (int64_t i = ith; i < nr; i += nth) {
  5376. int32_t * dst_data = (int32_t *)((char *) dst->data + i*nb1);
  5377. const float * src_data = (float *)((char *) src0->data + i*nb01);
  5378. for (int64_t j = 0; j < ne0; j++) {
  5379. dst_data[j] = j;
  5380. }
// C doesn't have a functional sort, so do a simple O(n^2) exchange sort of the index array instead
  5382. for (int64_t j = 0; j < ne0; j++) {
  5383. for (int64_t k = j + 1; k < ne0; k++) {
  5384. if ((order == GGML_SORT_ORDER_ASC && src_data[dst_data[j]] > src_data[dst_data[k]]) ||
  5385. (order == GGML_SORT_ORDER_DESC && src_data[dst_data[j]] < src_data[dst_data[k]])) {
  5386. int32_t tmp = dst_data[j];
  5387. dst_data[j] = dst_data[k];
  5388. dst_data[k] = tmp;
  5389. }
  5390. }
  5391. }
  5392. }
  5393. }
  5394. void ggml_compute_forward_argsort(
  5395. const ggml_compute_params * params,
  5396. ggml_tensor * dst) {
  5397. const ggml_tensor * src0 = dst->src[0];
  5398. switch (src0->type) {
  5399. case GGML_TYPE_F32:
  5400. {
  5401. ggml_compute_forward_argsort_f32(params, dst);
  5402. } break;
  5403. default:
  5404. {
  5405. GGML_ABORT("fatal error");
  5406. }
  5407. }
  5408. }
  5409. // ggml_compute_forward_flash_attn_ext
  5410. static void ggml_compute_forward_flash_attn_ext_f16(
  5411. const ggml_compute_params * params,
  5412. const ggml_tensor * q,
  5413. const ggml_tensor * k,
  5414. const ggml_tensor * v,
  5415. const ggml_tensor * mask,
  5416. ggml_tensor * dst) {
  5417. GGML_TENSOR_LOCALS(int64_t, neq, q, ne)
  5418. GGML_TENSOR_LOCALS(size_t, nbq, q, nb)
  5419. GGML_TENSOR_LOCALS(int64_t, nek, k, ne)
  5420. GGML_TENSOR_LOCALS(size_t, nbk, k, nb)
  5421. GGML_TENSOR_LOCALS(int64_t, nev, v, ne)
  5422. GGML_TENSOR_LOCALS(size_t, nbv, v, nb)
  5423. GGML_TENSOR_LOCALS(int64_t, ne, dst, ne)
  5424. GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
  5425. const int ith = params->ith;
  5426. const int nth = params->nth;
  5427. const int64_t DK = nek0;
  5428. const int64_t DV = nev0;
  5429. const int64_t N = neq1;
  5430. GGML_ASSERT(ne0 == DV);
  5431. GGML_ASSERT(ne2 == N);
  5432. // input tensor rows must be contiguous
  5433. GGML_ASSERT(nbq0 == ggml_type_size(q->type));
  5434. GGML_ASSERT(nbk0 == ggml_type_size(k->type));
  5435. GGML_ASSERT(nbv0 == ggml_type_size(v->type));
  5436. GGML_ASSERT(neq0 == DK);
  5437. GGML_ASSERT(nek0 == DK);
  5438. GGML_ASSERT(nev0 == DV);
  5439. GGML_ASSERT(neq1 == N);
  5440. // dst cannot be transposed or permuted
  5441. GGML_ASSERT(nb0 == sizeof(float));
  5442. GGML_ASSERT(nb0 <= nb1);
  5443. GGML_ASSERT(nb1 <= nb2);
  5444. GGML_ASSERT(nb2 <= nb3);
  5445. // broadcast factors
  5446. const int64_t rk2 = neq2/nek2;
  5447. const int64_t rk3 = neq3/nek3;
  5448. const int64_t rv2 = neq2/nev2;
  5449. const int64_t rv3 = neq3/nev3;
  5450. // parallelize by q rows using ggml_vec_dot_f32
  5451. // total rows in q
  5452. const int nr = neq1*neq2*neq3;
  5453. // rows per thread
  5454. const int dr = (nr + nth - 1)/nth;
  5455. // row range for this thread
  5456. const int ir0 = dr*ith;
  5457. const int ir1 = MIN(ir0 + dr, nr);
  5458. float scale = 1.0f;
  5459. float max_bias = 0.0f;
  5460. float logit_softcap = 0.0f;
  5461. memcpy(&scale, (float *) dst->op_params + 0, sizeof(float));
  5462. memcpy(&max_bias, (float *) dst->op_params + 1, sizeof(float));
  5463. memcpy(&logit_softcap, (float *) dst->op_params + 2, sizeof(float));
  5464. if (logit_softcap != 0) {
  5465. scale /= logit_softcap;
  5466. }
  5467. const uint32_t n_head = neq2;
  5468. const uint32_t n_head_log2 = 1u << (uint32_t) floor(log2(n_head));
  5469. const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
  5470. const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
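// ALiBi-style slopes: head h uses m0^(h+1) for h < n_head_log2 and
// m1^(2*(h - n_head_log2) + 1) otherwise; with max_bias == 0 the slope is 1
// and the mask is applied without a positional bias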
  5471. ggml_type const k_vec_dot_type = ggml_get_type_traits_cpu(k->type)->vec_dot_type;
  5472. ggml_from_float_t const q_to_vec_dot = ggml_get_type_traits_cpu(k_vec_dot_type)->from_float;
  5473. ggml_vec_dot_t const kq_vec_dot = ggml_get_type_traits_cpu(k->type)->vec_dot;
  5474. ggml_to_float_t const v_to_float = ggml_get_type_traits(v->type)->to_float;
  5475. GGML_ASSERT(( q_to_vec_dot) && "fattn: unsupported K-type");
  5476. GGML_ASSERT((v->type == GGML_TYPE_F32 || v_to_float ) && "fattn: unsupported V-type");
  5477. // loop over n_batch and n_head
  5478. for (int ir = ir0; ir < ir1; ++ir) {
  5479. // q indices
  5480. const int iq3 = ir/(neq2*neq1);
  5481. const int iq2 = (ir - iq3*neq2*neq1)/neq1;
  5482. const int iq1 = (ir - iq3*neq2*neq1 - iq2*neq1);
  5483. const uint32_t h = iq2; // head index
  5484. const float slope = (max_bias > 0.0f) ? h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2*(h - n_head_log2) + 1) : 1.0f;
  5485. float S = 0.0f; // sum
  5486. float M = -INFINITY; // maximum KQ value
  5487. float * VKQ32 = (float *) params->wdata + ith*(1*DK + 2*DV + CACHE_LINE_SIZE_F32); // FP32 VKQ accumulator
  5488. float * V32 = (VKQ32 + 1*DV); // (temporary) FP32 V buffer
  5489. ggml_fp16_t * VKQ16 = (ggml_fp16_t *) (VKQ32 + 1*DV); // (temporary) FP16 VKQ accumulator
  5490. ggml_fp16_t * Q_q = (ggml_fp16_t *) (VKQ32 + 2*DV); // (temporary) buffer for Q converted to quantized/FP16
  5491. if (v->type == GGML_TYPE_F16) {
  5492. memset(VKQ16, 0, DV*sizeof(ggml_fp16_t));
  5493. } else {
  5494. memset(VKQ32, 0, DV*sizeof(float));
  5495. }
  5496. const ggml_fp16_t * mp = mask ? (ggml_fp16_t *)((char *) mask->data + iq1*mask->nb[1]) : NULL;
  5497. // k indices
  5498. const int ik3 = iq3 / rk3;
  5499. const int ik2 = iq2 / rk2;
  5500. // v indices
  5501. const int iv3 = iq3 / rv3;
  5502. const int iv2 = iq2 / rv2;
  5503. const float * pq = (const float *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3));
  5504. q_to_vec_dot(pq, Q_q, DK);
  5505. // online softmax / attention
  5506. // loop over n_kv and n_head_kv
  5507. // ref: https://arxiv.org/pdf/2112.05682.pdf
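// running state per query row: M = maximum score so far, S = sum of exp(score - M),
// VKQ = exp-weighted sum of V rows; when a new score s exceeds M the accumulators are
// rescaled by exp(M_old - M_new), otherwise the new row is added with weight exp(s - M);
// the final output is VKQ / S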
  5508. for (int64_t ic = 0; ic < nek1; ++ic) {
  5509. const float mv = mp ? slope*GGML_FP16_TO_FP32(mp[ic]) : 0.0f;
  5510. if (mv == -INFINITY) {
  5511. continue;
  5512. }
  5513. float s; // KQ value
  5514. const char * k_data = (const char *) k->data + ( ic*nbk1 + ik2*nbk2 + ik3*nbk3);
  5515. kq_vec_dot(DK, &s, 0, k_data, 0, Q_q, 0, 1);
  5516. s = s*scale; // scale KQ value
  5517. if (logit_softcap != 0.0f) {
  5518. s = logit_softcap*tanhf(s);
  5519. }
  5520. s += mv; // apply mask
  5521. const float Mold = M;
  5522. float ms = 1.0f; // upon new higher max val, scale VKQ and KQ sum with this value
  5523. float vs = 1.0f; // post-softmax KQ value, expf(s - M)
  5524. const char * v_data = ((const char *) v->data + (ic*nbv1 + iv2*nbv2 + iv3*nbv3));
  5525. if (v->type == GGML_TYPE_F16) {
  5526. if (s > M) {
  5527. // s is new maximum, ms < 1.0f, vs == expf(s - s) == 1.0f
  5528. M = s;
  5529. ms = expf(Mold - M);
  5530. // V = V*expf(Mold - M)
  5531. ggml_vec_scale_f16(DV, VKQ16, ms);
  5532. } else {
  5533. // no new maximum, ms == 1.0f, vs != 1.0f
  5534. vs = expf(s - M);
  5535. }
  5536. // V += v*expf(s - M)
  5537. ggml_vec_mad_f16(DV, VKQ16, (const ggml_fp16_t *) v_data, vs);
  5538. } else {
  5539. if (s > M) {
  5540. // s is new maximum, ms < 1.0f, vs == expf(s - s) == 1.0f
  5541. M = s;
  5542. ms = expf(Mold - M);
  5543. // V = V*expf(Mold - M)
  5544. ggml_vec_scale_f32(DV, VKQ32, ms);
  5545. } else {
  5546. // no new maximum, ms == 1.0f, vs != 1.0f
  5547. vs = expf(s - M);
  5548. }
  5549. // V += v*expf(s - M)
  5550. if (v_to_float) {
  5551. v_to_float(v_data, V32, DV);
  5552. ggml_vec_mad_f32(DV, VKQ32, V32, vs);
  5553. } else {
  5554. // V is F32
  5555. ggml_vec_mad_f32(DV, VKQ32, (const float *) v_data, vs);
  5556. }
  5557. }
  5558. S = S*ms + vs; // scale and increment sum with partial sum
  5559. }
  5560. if (v->type == GGML_TYPE_F16) {
  5561. for (int64_t d = 0; d < DV; ++d) {
  5562. VKQ32[d] = GGML_FP16_TO_FP32(VKQ16[d]);
  5563. }
  5564. }
  5565. // V /= S
  5566. const float S_inv = 1.0f/S;
  5567. ggml_vec_scale_f32(DV, VKQ32, S_inv);
  5568. // dst indices
  5569. const int i1 = iq1;
  5570. const int i2 = iq2;
  5571. const int i3 = iq3;
  5572. // original
  5573. //memcpy((char *) dst->data + (i1*nb1 + i2*nb2 + i3*nb3), V, nev0*sizeof(float));
  5574. // permute(0, 2, 1, 3)
  5575. memcpy((char *) dst->data + (i3*ne2*ne1 + i2 + i1*ne1)*nb1, VKQ32, nb1);
  5576. }
  5577. }
  5578. void ggml_compute_forward_flash_attn_ext(
  5579. const ggml_compute_params * params,
  5580. const ggml_tensor * q,
  5581. const ggml_tensor * k,
  5582. const ggml_tensor * v,
  5583. const ggml_tensor * mask,
  5584. ggml_tensor * dst) {
  5585. switch (dst->op_params[3]) {
  5586. case GGML_PREC_DEFAULT:
  5587. case GGML_PREC_F32:
  5588. {
  5589. // uses F32 accumulators
  5590. ggml_compute_forward_flash_attn_ext_f16(params, q, k, v, mask, dst);
  5591. } break;
  5592. default:
  5593. {
  5594. GGML_ABORT("fatal error");
  5595. }
  5596. }
  5597. }
  5598. // ggml_compute_forward_flash_attn_back
  5599. static void ggml_compute_forward_flash_attn_back_f32(
  5600. const ggml_compute_params * params,
  5601. const bool masked,
  5602. ggml_tensor * dst) {
  5603. const ggml_tensor * q = dst->src[0];
  5604. const ggml_tensor * k = dst->src[1];
  5605. const ggml_tensor * v = dst->src[2];
  5606. const ggml_tensor * d = dst->src[3];
  5607. GGML_TENSOR_LOCALS(int64_t, neq, q, ne)
  5608. GGML_TENSOR_LOCALS(size_t, nbq, q, nb)
  5609. GGML_TENSOR_LOCALS(int64_t, nek, k, ne)
  5610. GGML_TENSOR_LOCALS(size_t, nbk, k, nb)
  5611. GGML_TENSOR_LOCALS(int64_t, nev, v, ne)
  5612. GGML_TENSOR_LOCALS(size_t, nbv, v, nb)
  5613. GGML_TENSOR_LOCALS(int64_t, ned, d, ne)
  5614. GGML_TENSOR_LOCALS(size_t, nbd, d, nb)
  5615. GGML_TENSOR_LOCALS(int64_t, ne, dst, ne)
  5616. GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
  5617. const int ith = params->ith;
  5618. const int nth = params->nth;
  5619. const int64_t D = neq0;
  5620. const int64_t N = neq1;
  5621. const int64_t P = nek1 - N;
  5622. const int64_t M = P + N;
  5623. const int Mup = ggml_up(M, GGML_SOFT_MAX_UNROLL);
  5624. const int mxDM = MAX(D, Mup);
  5625. // GGML_ASSERT(ne0 == D);
  5626. // GGML_ASSERT(ne1 == N);
  5627. GGML_ASSERT(P >= 0);
  5628. GGML_ASSERT(nbq0 == sizeof(float));
  5629. GGML_ASSERT(nbk0 == sizeof(float));
  5630. GGML_ASSERT(nbv0 == sizeof(float));
  5631. GGML_ASSERT(neq0 == D);
  5632. GGML_ASSERT(nek0 == D);
  5633. GGML_ASSERT(nev1 == D);
  5634. GGML_ASSERT(ned0 == D);
  5635. GGML_ASSERT(neq1 == N);
  5636. GGML_ASSERT(nek1 == N + P);
  5637. GGML_ASSERT(nev1 == D);
  5638. GGML_ASSERT(ned1 == N);
  5639. // dst cannot be transposed or permuted
  5640. GGML_ASSERT(nb0 == sizeof(float));
  5641. GGML_ASSERT(nb0 <= nb1);
  5642. GGML_ASSERT(nb1 <= nb2);
  5643. GGML_ASSERT(nb2 <= nb3);
  5644. if (ith == 0) {
  5645. memset(dst->data, 0, nb0*ne0*ne1*ne2*ne3);
  5646. }
  5647. ggml_barrier(params->threadpool);
  5648. const int64_t elem_q = ggml_nelements(q);
  5649. const int64_t elem_k = ggml_nelements(k);
  5650. ggml_type result_type = dst->type;
  5651. GGML_ASSERT(ggml_blck_size(result_type) == 1);
  5652. const size_t tsize = ggml_type_size(result_type);
  5653. const size_t offs_q = 0;
  5654. const size_t offs_k = offs_q + GGML_PAD(elem_q * tsize, GGML_MEM_ALIGN);
  5655. const size_t offs_v = offs_k + GGML_PAD(elem_k * tsize, GGML_MEM_ALIGN);
  5656. void * grad_q = (char *) dst->data;
  5657. void * grad_k = (char *) dst->data + offs_k;
  5658. void * grad_v = (char *) dst->data + offs_v;
  5659. const size_t nbgq1 = nb0*neq0;
  5660. const size_t nbgq2 = nb0*neq0*neq1;
  5661. const size_t nbgq3 = nb0*neq0*neq1*neq2;
  5662. const size_t nbgk1 = nb0*nek0;
  5663. const size_t nbgk2 = nb0*nek0*nek1;
  5664. const size_t nbgk3 = nb0*nek0*nek1*neq2;
  5665. const size_t nbgv1 = nb0*nev0;
  5666. const size_t nbgv2 = nb0*nev0*nev1;
  5667. const size_t nbgv3 = nb0*nev0*nev1*neq2;
  5668. // parallelize by k rows using ggml_vec_dot_f32
  5669. // total rows in k
  5670. const int nr = nek2*nek3;
  5671. // rows per thread
  5672. const int dr = (nr + nth - 1)/nth;
  5673. // row range for this thread
  5674. const int ir0 = dr*ith;
  5675. const int ir1 = MIN(ir0 + dr, nr);
  5676. const float scale = 1.0f/sqrtf(D);
  5677. //printf("P=%d N=%d D=%d ir0=%d ir1=%d scale = %f\n", P, N, D, ir0, ir1, scale);
  5678. // how often k2 (and v2) is repeated in q2
  5679. int nrep = neq2/nek2;
  5680. for (int ir = ir0; ir < ir1; ++ir) {
// k indices
  5682. const int ik3 = ir/(nek2);
  5683. const int ik2 = ir - ik3*nek2;
  5684. const int iq3 = ik3;
  5685. const int id3 = ik3;
  5686. const int iv3 = ik3;
  5687. const int iv2 = ik2;
  5688. for (int irep = 0; irep < nrep; ++irep) {
  5689. const int iq2 = ik2 + irep*nek2;
  5690. const int id2 = iq2;
  5691. // (ik2 + irep*nek2) % nek2 == ik2
  5692. for (int iq1 = 0; iq1 < neq1; ++iq1) {
  5693. const int id1 = iq1;
// not sure about CACHE_LINE_SIZE_F32:
// - maybe it should not be multiplied by 2 and should not be part of the 1*(..) offset used for SM?
  5696. float * S = (float *) params->wdata + ith*2*(mxDM + CACHE_LINE_SIZE_F32) + 0*(mxDM+CACHE_LINE_SIZE_F32);
  5697. float * SM = (float *) params->wdata + ith*2*(mxDM + CACHE_LINE_SIZE_F32) + 1*(mxDM+CACHE_LINE_SIZE_F32);
  5698. for (int i = M; i < Mup; ++i) {
  5699. S[i] = -INFINITY;
  5700. }
  5701. const int64_t masked_begin = masked ? (P + iq1 + 1) : M;
  5702. for (int64_t ic = 0; ic < masked_begin; ++ic) {
  5703. // k indices
  5704. const int ik1 = ic;
  5705. // S indices
  5706. const int i1 = ik1;
  5707. ggml_vec_dot_f32(neq0,
  5708. S + i1, 0,
  5709. (float *) ((char *) k->data + (ik1*nbk1 + ik2*nbk2 + ik3*nbk3)), 0,
  5710. (float *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3)), 0, 1);
  5711. }
  5712. // scale
  5713. ggml_vec_scale_f32(masked_begin, S, scale);
  5714. for (int64_t i = masked_begin; i < M; i++) {
  5715. S[i] = -INFINITY;
  5716. }
  5717. // softmax
// exclude known -INF S[..] values from max and loop
// don't forget to set their SM values to zero
  5720. {
  5721. float max = -INFINITY;
  5722. ggml_vec_max_f32(masked_begin, &max, S);
  5723. ggml_float sum = 0.0;
  5724. {
  5725. #ifdef GGML_SOFT_MAX_ACCELERATE
  5726. max = -max;
  5727. vDSP_vsadd(SM, 1, &max, SM, 1, Mup);
  5728. vvexpf(SM, SM, &Mup);
  5729. ggml_vec_sum_f32(Mup, &sum, SM);
  5730. #else
  5731. sum = ggml_vec_soft_max_f32(Mup, SM, S, max);
  5732. #endif
  5733. }
  5734. assert(sum > 0.0);
  5735. sum = 1.0/sum;
  5736. ggml_vec_scale_f32(masked_begin, SM, sum);
  5737. }
  5738. // step-by-step explanation
  5739. {
  5740. // forward-process shape grads from backward process
  5741. // parallel_for ik2,ik3:
  5742. // for irep:
  5743. // iq2 = ik2 + irep*nek2
  5744. // k[:D,:M,:,:] [D,M,:,:] grad[k][:D,:M,ik2,ik3] += grad[kcur]
  5745. // q[:D,:N,:,:] [D,N,:,:] grad[q][:D,iq1,iq2,iq3] += grad[qcur]
  5746. // v[:M,:D,:,:] [M,D,:,:] grad[v][:M,:D,iv2,iv3] += grad[vcur]
  5747. // for iq1:
  5748. // kcur = k[:D,:M,ik2,ik3] [D,M,1,1] grad[kcur] = grad[S1].T @ qcur
  5749. // qcur = q[:D,iq1,iq2,iq3] [D,1,1,1] grad[qcur] = grad[S1] @ kcur
  5750. // vcur = v[:M,:D,iv2,iv3] [M,D,1,1] grad[vcur] = grad[S5].T @ S4
  5751. // S0 = -Inf [D,1,1,1]
  5752. // ~S1[i] = dot(kcur[:D,i], qcur)
  5753. // S1 = qcur @ kcur.T [M,1,1,1] grad[S1] = grad[S2] * scale
  5754. // S2 = S1 * scale [M,1,1,1] grad[S2] = diag_mask_zero(grad[S3], P)
  5755. // S3 = diag_mask_inf(S2, P) [M,1,1,1] grad[S3] = S4 * (grad[S4] - dot(S4, grad[S4]))
  5756. // S4 = softmax(S3) [M,1,1,1] grad[S4] = grad[S5] @ vcur
  5757. // ~S5[i] = dot(vcur[:,i], S4)
  5758. // S5 = S4 @ vcur.T [D,1,1,1] grad[S5] = d[:D,id1,id2,id3]
  5759. // ~dst[i,iq1,iq2,iq3] = S5[i] ^
  5760. // dst[:D,iq1,iq2,iq3] = S5 | grad[dst[:D,iq1,iq2,iq3]] = d[:D,id1,id2,id3]
  5761. // dst backward-/ grad[dst] = d
  5762. //
  5763. // output gradients with their dependencies:
  5764. //
  5765. // grad[kcur] = grad[S1].T @ qcur
  5766. // grad[S1] = diag_mask_zero(grad[S3], P) * scale
  5767. // grad[S3] = S4 * (grad[S4] - dot(S4, grad[S4]))
  5768. // grad[S4] = grad[S5] @ vcur
  5769. // grad[S4] = d[:D,id1,id2,id3] @ vcur
  5770. // grad[qcur] = grad[S1] @ kcur
  5771. // grad[vcur] = grad[S5].T @ S4
  5772. // grad[vcur] = d[:D,id1,id2,id3].T @ S4
  5773. //
  5774. // in post-order:
  5775. //
  5776. // S1 = qcur @ kcur.T
  5777. // S2 = S1 * scale
  5778. // S3 = diag_mask_inf(S2, P)
  5779. // S4 = softmax(S3)
  5780. // grad[S4] = d[:D,id1,id2,id3] @ vcur
  5781. // grad[S3] = S4 * (grad[S4] - dot(S4, grad[S4]))
  5782. // grad[S1] = diag_mask_zero(grad[S3], P) * scale
  5783. // grad[qcur] = grad[S1] @ kcur
  5784. // grad[kcur] = grad[S1].T @ qcur
  5785. // grad[vcur] = d[:D,id1,id2,id3].T @ S4
  5786. //
  5787. // using less variables (SM=S4):
  5788. //
  5789. // S = diag_mask_inf(qcur @ kcur.T * scale, P)
  5790. // SM = softmax(S)
  5791. // S = d[:D,iq1,iq2,iq3] @ vcur
  5792. // dot_SM_gradSM = dot(SM, S)
  5793. // S = SM * (S - dot(SM, S))
  5794. // S = diag_mask_zero(S, P) * scale
  5795. //
  5796. // grad[q][:D,iq1,iq2,iq3] += S @ kcur
  5797. // grad[k][:D,:M,ik2,ik3] += S.T @ qcur
  5798. // grad[v][:M,:D,iv2,iv3] += d[:D,id1,id2,id3].T @ SM
  5799. }
  5800. // S = gradSM = d[:D,id1,id2,id3] @ vcur[:,:,iv2,iv3]
  5801. // S = d[:D,id1,id2,id3] @ vcur[:,:,iv2,iv3]
  5802. // for ic:
  5803. // S[:M] += vcur[:M,ic,iv2,iv3] * d[ic,id1,id2,id3]
  5804. // exclude known future zero S[..] values from operation
  5805. ggml_vec_set_f32(masked_begin, S, 0);
  5806. for (int64_t ic = 0; ic < D; ++ic) {
  5807. ggml_vec_mad_f32(masked_begin,
  5808. S,
  5809. (float *) ((char *) v->data + ( ic*nbv1 + iv2*nbv2 + iv3*nbv3)),
  5810. *(float *) ((char *) d->data + (ic*nbd0 + id1*nbd1 + id2*nbd2 + id3*nbd3)));
  5811. }
  5812. // S = SM * (S - dot(SM, S))
  5813. float dot_SM_gradSM = 0;
  5814. ggml_vec_dot_f32 (masked_begin, &dot_SM_gradSM, 0, SM, 0, S, 0, 1);
  5815. ggml_vec_acc1_f32(M, S, -dot_SM_gradSM);
  5816. ggml_vec_mul_f32 (masked_begin, S, S, SM);
  5817. // S = diag_mask_zero(S, P) * scale
  5818. // already done by above ggml_vec_set_f32
  5819. // exclude known zero S[..] values from operation
  5820. ggml_vec_scale_f32(masked_begin, S, scale);
  5821. // S shape [M,1]
  5822. // SM shape [M,1]
  5823. // kcur shape [D,M]
  5824. // qcur shape [D,1]
  5825. // vcur shape [M,D]
  5826. // grad[q][:D,iq1,iq2,iq3] += S @ kcur
  5827. // grad[q][:D,iq1,iq2,iq3] += shape[M,1] @ shape[D,M]
  5828. // for ic:
  5829. // grad[q][:D,iq1,iq2,iq3] += S[ic] * kcur[:D,ic,ik2,ik3]
  5830. // exclude known zero S[..] values from loop
  5831. for (int64_t ic = 0; ic < masked_begin; ++ic) {
  5832. ggml_vec_mad_f32(D,
  5833. (float *) ((char *) grad_q + (iq1*nbgq1 + iq2*nbgq2 + iq3*nbgq3)),
  5834. (float *) ((char *) k->data + (ic*nbk1 + ik2*nbk2 + ik3*nbk3)),
  5835. S[ic]);
  5836. }
  5837. // grad[k][:D,:M,iq2,iq3] += S.T @ qcur
  5838. // for ic:
  5839. // grad[k][:D,ic,iq2,iq3] += S.T[0,ic] * qcur[:D,0]
  5840. // grad[k][:D,ic,iq2,iq3] += S[ic] * qcur[:D,0]
  5841. // exclude known zero S[..] values from loop
  5842. for (int64_t ic = 0; ic < masked_begin; ++ic) {
  5843. ggml_vec_mad_f32(D,
  5844. (float *) ((char *) grad_k + (ic*nbgk1 + ik2*nbgk2 + ik3*nbgk3)),
  5845. (float *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3)),
  5846. S[ic]);
  5847. }
  5848. // grad[v][:M,:D,iv2,iv3] += d[:D,id1,id2,id3].T @ SM
  5849. // for ic:
  5850. // grad[v][:M,ic,iv2,iv3] += d[:D,id1,id2,id3].T[0,ic] * SM[:M]
  5851. // grad[v][:M,ic,iv2,iv3] += d[ic,id1,id2,id3] * SM[:M]
  5852. // exclude known zero SM[..] values from mad
  5853. for (int64_t ic = 0; ic < D; ++ic) {
  5854. ggml_vec_mad_f32(masked_begin,
  5855. (float *) ((char *) grad_v + ( ic*nbgv1 + iv2*nbgv2 + iv3*nbgv3)),
  5856. SM,
  5857. *(float *) ((char *) d->data + (ic*nbd0 + id1*nbd1 + id2*nbd2 + id3*nbd3)));
  5858. }
  5859. }
  5860. }
  5861. }
  5862. }
  5863. void ggml_compute_forward_flash_attn_back(
  5864. const ggml_compute_params * params,
  5865. const bool masked,
  5866. ggml_tensor * dst) {
  5867. const ggml_tensor * q = dst->src[0];
  5868. switch (q->type) {
  5869. case GGML_TYPE_F32:
  5870. {
  5871. ggml_compute_forward_flash_attn_back_f32(params, masked, dst);
  5872. } break;
  5873. default:
  5874. {
  5875. GGML_ABORT("fatal error");
  5876. }
  5877. }
  5878. }
  5879. // ggml_compute_forward_ssm_conv
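// causal depthwise 1d convolution used by Mamba blocks: for every inner channel, the
// output at token i2 is the dot product of the d_conv weights with a d_conv-wide
// sliding window of conv_x starting at column i2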
  5880. static void ggml_compute_forward_ssm_conv_f32(
  5881. const ggml_compute_params * params,
  5882. ggml_tensor * dst) {
  5883. const ggml_tensor * src0 = dst->src[0]; // conv_x
  5884. const ggml_tensor * src1 = dst->src[1]; // conv1d.weight
  5885. const int ith = params->ith;
  5886. const int nth = params->nth;
  5887. const int nc = src1->ne[0]; // d_conv
  5888. const int ncs = src0->ne[0]; // d_conv - 1 + n_t
  5889. const int nr = src0->ne[1]; // d_inner
  5890. const int n_t = dst->ne[1]; // tokens per sequence
  5891. const int n_s = dst->ne[2]; // number of sequences in the batch
  5892. GGML_ASSERT( dst->ne[0] == nr);
  5893. GGML_ASSERT(src0->nb[0] == sizeof(float));
  5894. GGML_ASSERT(src1->nb[0] == sizeof(float));
  5895. GGML_ASSERT(src0->nb[1] == src0->ne[0]*sizeof(float));
  5896. // rows per thread
  5897. const int dr = (nr + nth - 1)/nth;
  5898. // row range for this thread
  5899. const int ir0 = dr*ith;
  5900. const int ir1 = MIN(ir0 + dr, nr);
  5901. const int ir = ir1 - ir0;
  5902. for (int i3 = 0; i3 < n_s; ++i3) {
  5903. for (int i2 = 0; i2 < n_t; ++i2) {
  5904. // {d_conv - 1 + n_t, d_inner, n_seqs}
  5905. // sliding window
  5906. const float * s = (const float *) ((const char *) src0->data + ir0*(src0->nb[1]) + i2*(src0->nb[0]) + i3*(src0->nb[2])); // {d_conv, d_inner, n_s}
  5907. const float * c = (const float *) ((const char *) src1->data + ir0*(src1->nb[1])); // {d_conv, d_inner}
  5908. float * x = (float *) ((char *) dst->data + ir0*(dst->nb[0]) + i2*(dst->nb[1]) + i3*(dst->nb[2])); // {d_inner, n_t, n_s}
  5909. // TODO: transpose the output for smaller strides for big batches?
  5910. // d_inner
  5911. for (int i1 = 0; i1 < ir; ++i1) {
  5912. // rowwise dot product
  5913. // NOTE: not using ggml_vec_dot_f32, because its sum is in double precision
  5914. float sumf = 0.0f;
  5915. // d_conv
  5916. for (int i0 = 0; i0 < nc; ++i0) {
  5917. sumf += s[i0 + i1*ncs] * c[i0 + i1*nc];
  5918. }
  5919. x[i1] = sumf;
  5920. }
  5921. }
  5922. }
  5923. }
  5924. void ggml_compute_forward_ssm_conv(
  5925. const ggml_compute_params * params,
  5926. ggml_tensor * dst) {
  5927. switch (dst->src[0]->type) {
  5928. case GGML_TYPE_F32:
  5929. {
  5930. ggml_compute_forward_ssm_conv_f32(params, dst);
  5931. } break;
  5932. default:
  5933. {
  5934. GGML_ABORT("fatal error");
  5935. }
  5936. }
  5937. }
  5938. // ggml_compute_forward_ssm_scan
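// selective state-space scan (Mamba): per channel, with dt' = softplus(dt),
//   state = state * exp(dt' * A) + B * (x * dt')
//   y     = sum over the d_state dimension of state * C
// the updated states are written after the y values in dst so they can be reused
// for the next token of the same sequence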
  5939. static void ggml_compute_forward_ssm_scan_f32(
  5940. const ggml_compute_params * params,
  5941. ggml_tensor * dst) {
  5942. const ggml_tensor * src0 = dst->src[0]; // s
  5943. const ggml_tensor * src1 = dst->src[1]; // x
  5944. const ggml_tensor * src2 = dst->src[2]; // dt
  5945. const ggml_tensor * src3 = dst->src[3]; // A
  5946. const ggml_tensor * src4 = dst->src[4]; // B
  5947. const ggml_tensor * src5 = dst->src[5]; // C
  5948. const int ith = params->ith;
  5949. const int nth = params->nth;
  5950. const int64_t nc = src0->ne[0]; // d_state
  5951. const int64_t nr = src0->ne[1]; // d_inner
  5952. const int64_t n_t = src1->ne[1]; // number of tokens per sequence
  5953. const int64_t n_s = src0->ne[2]; // number of sequences in the batch
  5954. GGML_ASSERT(ggml_nelements(src1) + ggml_nelements(src0) == ggml_nelements(dst));
  5955. GGML_ASSERT(src0->nb[0] == sizeof(float));
  5956. GGML_ASSERT(src1->nb[0] == sizeof(float));
  5957. GGML_ASSERT(src2->nb[0] == sizeof(float));
  5958. GGML_ASSERT(src3->nb[0] == sizeof(float));
  5959. GGML_ASSERT(src4->nb[0] == sizeof(float));
  5960. GGML_ASSERT(src5->nb[0] == sizeof(float));
  5961. // required for the dot product between s and C
  5962. GGML_ASSERT(src0->nb[1] == src0->ne[0]*sizeof(float));
  5963. // required for per-sequence offsets for states
  5964. GGML_ASSERT(src0->nb[2] == src0->ne[0]*src0->ne[1]*sizeof(float));
  5965. // required to get correct offset for state destination (i.e. src1->nb[3])
  5966. GGML_ASSERT(src1->nb[3] == src1->ne[0]*src1->ne[1]*src1->ne[2]*sizeof(float));
  5967. // rows per thread
  5968. const int dr = (nr + nth - 1)/nth;
  5969. // row range for this thread
  5970. const int ir0 = dr*ith;
  5971. const int ir1 = MIN(ir0 + dr, nr);
  5972. const int ir = ir1 - ir0;
  5973. for (int i3 = 0; i3 < n_s; ++i3) {
  5974. for (int i2 = 0; i2 < n_t; ++i2) {
  5975. const float * s0 = (const float *) ((const char *) src0->data + ir0*(src0->nb[1]) + i3*(src0->nb[2])); // {d_state, d_inner, n_s}
  5976. const float * x = (const float *) ((const char *) src1->data + ir0*(src1->nb[0]) + i2*(src1->nb[1]) + i3*(src1->nb[2])); // {d_inner, n_t, n_s}
  5977. const float * dt = (const float *) ((const char *) src2->data + ir0*(src2->nb[0]) + i2*(src2->nb[1]) + i3*(src2->nb[2])); // {d_inner, n_t, n_s}
  5978. const float * A = (const float *) ((const char *) src3->data + ir0*(src3->nb[1])); // {d_state, d_inner}
  5979. const float * B = (const float *) ((const char *) src4->data + i2*(src4->nb[1]) + i3*(src4->nb[2])); // {d_state, n_t, n_s}
  5980. const float * C = (const float *) ((const char *) src5->data + i2*(src5->nb[1]) + i3*(src5->nb[2])); // {d_state, n_t, n_s}
  5981. float * y = ( float *) (( char *) dst->data + ir0*(src1->nb[0]) + i2*(src1->nb[1]) + i3*(src1->nb[2])); // {d_inner, n_t, n_s}
  5982. float * s = ( float *) (( char *) dst->data + ir0*(src0->nb[1]) + i3*(src0->nb[2]) + src1->nb[3]); // {d_state, d_inner, n_s}
  5983. // use the output as the source for the next token-wise iterations
  5984. if (i2 > 0) { s0 = s; }
  5985. // d_inner
  5986. for (int i1 = 0; i1 < ir; ++i1) {
  5987. // ref: https://github.com/state-spaces/mamba/blob/34076d664838588a3c97727b263478ab9f621a07/mamba_ssm/ops/triton/selective_state_update.py#L78
  5988. float dt_soft_plus = dt[i1] <= 20.0f ? log1pf(expf(dt[i1])) : dt[i1];
  5989. float x_dt = x[i1] * dt_soft_plus;
  5990. float sumf = 0.0f;
  5991. // d_state
  5992. for (int i0 = 0; i0 < nc; ++i0) {
  5993. int i = i0 + i1*nc;
  5994. // state = prev_state * dA + dB * x
  5995. float state = (s0[i] * expf(dt_soft_plus * A[i])) + (B[i0] * x_dt);
  5996. // y = rowwise_dotprod(state, C)
  5997. sumf += state * C[i0];
  5998. s[i] = state;
  5999. }
  6000. y[i1] = sumf;
  6001. }
  6002. }
  6003. }
  6004. }
  6005. void ggml_compute_forward_ssm_scan(
  6006. const ggml_compute_params * params,
  6007. ggml_tensor * dst) {
  6008. switch (dst->src[0]->type) {
  6009. case GGML_TYPE_F32:
  6010. {
  6011. ggml_compute_forward_ssm_scan_f32(params, dst);
  6012. } break;
  6013. default:
  6014. {
  6015. GGML_ABORT("fatal error");
  6016. }
  6017. }
  6018. }
  6019. // ggml_compute_forward_win_part
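// partitions src0 into non-overlapping w x w windows along dims 1 and 2
// (one window per dst slice in dim 3, nep0*nep1 in total), zero-padding windows
// that extend past the input; used for windowed attention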
  6020. static void ggml_compute_forward_win_part_f32(
  6021. const ggml_compute_params * params,
  6022. ggml_tensor * dst) {
  6023. GGML_UNUSED(params);
  6024. const ggml_tensor * src0 = dst->src[0];
  6025. GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne)
  6026. GGML_TENSOR_LOCALS(int64_t, ne, dst, ne)
  6027. const int32_t nep0 = ((const int32_t *)(dst->op_params))[0];
  6028. const int32_t nep1 = ((const int32_t *)(dst->op_params))[1];
  6029. const int32_t w = ((const int32_t *)(dst->op_params))[2];
  6030. assert(ne00 == ne0);
  6031. assert(ne3 == nep0*nep1);
  6032. // TODO: optimize / multi-thread
  6033. for (int py = 0; py < nep1; ++py) {
  6034. for (int px = 0; px < nep0; ++px) {
  6035. const int64_t i3 = py*nep0 + px;
  6036. for (int64_t i2 = 0; i2 < ne2; ++i2) {
  6037. for (int64_t i1 = 0; i1 < ne1; ++i1) {
  6038. for (int64_t i0 = 0; i0 < ne0; ++i0) {
  6039. const int64_t i02 = py*w + i2;
  6040. const int64_t i01 = px*w + i1;
  6041. const int64_t i00 = i0;
  6042. const int64_t i = i3*ne2*ne1*ne0 + i2*ne1*ne0 + i1*ne0 + i0;
  6043. const int64_t j = i02*ne01*ne00 + i01*ne00 + i00;
  6044. if (py*w + i2 >= ne02 || px*w + i1 >= ne01) {
  6045. ((float *) dst->data)[i] = 0.0f;
  6046. } else {
  6047. ((float *) dst->data)[i] = ((float *) src0->data)[j];
  6048. }
  6049. }
  6050. }
  6051. }
  6052. }
  6053. }
  6054. }
  6055. void ggml_compute_forward_win_part(
  6056. const ggml_compute_params * params,
  6057. ggml_tensor * dst) {
  6058. const ggml_tensor * src0 = dst->src[0];
  6059. switch (src0->type) {
  6060. case GGML_TYPE_F32:
  6061. {
  6062. ggml_compute_forward_win_part_f32(params, dst);
  6063. } break;
  6064. default:
  6065. {
  6066. GGML_ABORT("fatal error");
  6067. }
  6068. }
  6069. }
  6070. // ggml_compute_forward_win_unpart
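// inverse of win_part: stitches the w x w windows back into a contiguous tensor
// and drops the zero padding that win_part added at the borders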
  6071. static void ggml_compute_forward_win_unpart_f32(
  6072. const ggml_compute_params * params,
  6073. ggml_tensor * dst) {
  6074. GGML_UNUSED(params);
  6075. const ggml_tensor * src0 = dst->src[0];
  6076. GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne)
  6077. GGML_TENSOR_LOCALS(int64_t, ne, dst, ne)
  6078. const int32_t w = ((const int32_t *)(dst->op_params))[0];
  6079. // padding
  6080. const int px = (w - ne1%w)%w;
  6081. //const int py = (w - ne2%w)%w;
  6082. const int npx = (px + ne1)/w;
  6083. //const int npy = (py + ne2)/w;
  6084. assert(ne0 == ne00);
  6085. // TODO: optimize / multi-thread
  6086. for (int64_t i2 = 0; i2 < ne2; ++i2) {
  6087. for (int64_t i1 = 0; i1 < ne1; ++i1) {
  6088. for (int64_t i0 = 0; i0 < ne0; ++i0) {
  6089. const int ip2 = i2/w;
  6090. const int ip1 = i1/w;
  6091. const int64_t i02 = i2%w;
  6092. const int64_t i01 = i1%w;
  6093. const int64_t i00 = i0;
  6094. const int64_t i = (ip2*npx + ip1)*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00 + i00;
  6095. const int64_t j = i2*ne1*ne0 + i1*ne0 + i0;
  6096. ((float *) dst->data)[j] = ((float *) src0->data)[i];
  6097. }
  6098. }
  6099. }
  6100. }
  6101. void ggml_compute_forward_win_unpart(
  6102. const ggml_compute_params * params,
  6103. ggml_tensor * dst) {
  6104. const ggml_tensor * src0 = dst->src[0];
  6105. switch (src0->type) {
  6106. case GGML_TYPE_F32:
  6107. {
  6108. ggml_compute_forward_win_unpart_f32(params, dst);
  6109. } break;
  6110. default:
  6111. {
  6112. GGML_ABORT("fatal error");
  6113. }
  6114. }
  6115. }
// ggml_compute_forward_unary
  6117. void ggml_compute_forward_unary(
  6118. const ggml_compute_params * params,
  6119. ggml_tensor * dst) {
  6120. const ggml_unary_op op = ggml_get_unary_op(dst);
  6121. switch (op) {
  6122. case GGML_UNARY_OP_ABS:
  6123. {
  6124. ggml_compute_forward_abs(params, dst);
  6125. } break;
  6126. case GGML_UNARY_OP_SGN:
  6127. {
  6128. ggml_compute_forward_sgn(params, dst);
  6129. } break;
  6130. case GGML_UNARY_OP_NEG:
  6131. {
  6132. ggml_compute_forward_neg(params, dst);
  6133. } break;
  6134. case GGML_UNARY_OP_STEP:
  6135. {
  6136. ggml_compute_forward_step(params, dst);
  6137. } break;
  6138. case GGML_UNARY_OP_TANH:
  6139. {
  6140. ggml_compute_forward_tanh(params, dst);
  6141. } break;
  6142. case GGML_UNARY_OP_ELU:
  6143. {
  6144. ggml_compute_forward_elu(params, dst);
  6145. } break;
  6146. case GGML_UNARY_OP_RELU:
  6147. {
  6148. ggml_compute_forward_relu(params, dst);
  6149. } break;
  6150. case GGML_UNARY_OP_SIGMOID:
  6151. {
  6152. ggml_compute_forward_sigmoid(params, dst);
  6153. } break;
  6154. case GGML_UNARY_OP_GELU:
  6155. {
  6156. ggml_compute_forward_gelu(params, dst);
  6157. } break;
  6158. case GGML_UNARY_OP_GELU_QUICK:
  6159. {
  6160. ggml_compute_forward_gelu_quick(params, dst);
  6161. } break;
  6162. case GGML_UNARY_OP_SILU:
  6163. {
  6164. ggml_compute_forward_silu(params, dst);
  6165. } break;
  6166. case GGML_UNARY_OP_HARDSWISH:
  6167. {
  6168. ggml_compute_forward_hardswish(params, dst);
  6169. } break;
  6170. case GGML_UNARY_OP_HARDSIGMOID:
  6171. {
  6172. ggml_compute_forward_hardsigmoid(params, dst);
  6173. } break;
  6174. case GGML_UNARY_OP_EXP:
  6175. {
  6176. ggml_compute_forward_exp(params, dst);
  6177. } break;
  6178. default:
  6179. {
  6180. GGML_ABORT("fatal error");
  6181. }
  6182. }
  6183. }
  6184. // ggml_compute_forward_get_rel_pos
  6185. static void ggml_compute_forward_get_rel_pos_f16(
  6186. const ggml_compute_params * params,
  6187. ggml_tensor * dst) {
  6188. GGML_UNUSED(params);
  6189. const ggml_tensor * src0 = dst->src[0];
  6190. // ref: https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/modeling/image_encoder.py#L292-L322
  6191. GGML_TENSOR_UNARY_OP_LOCALS
  6192. const int64_t w = ne1;
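// gathers relative positional embeddings: for destination indices (i1, i2) the source
// row is (i2 - i1) + (w - 1), i.e. one embedding row per relative offset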
  6193. ggml_fp16_t * src0_data = (ggml_fp16_t *) src0->data;
  6194. ggml_fp16_t * dst_data = (ggml_fp16_t *) dst->data;
  6195. for (int64_t i2 = 0; i2 < ne2; ++i2) {
  6196. for (int64_t i1 = 0; i1 < ne1; ++i1) {
  6197. const int64_t pos = (w - i1 - 1) + i2;
  6198. for (int64_t i0 = 0; i0 < ne0; ++i0) {
  6199. dst_data[i2*ne1*ne0 + i1*ne0 + i0] = src0_data[pos*ne00 + i0];
  6200. }
  6201. }
  6202. }
  6203. }
  6204. void ggml_compute_forward_get_rel_pos(
  6205. const ggml_compute_params * params,
  6206. ggml_tensor * dst) {
  6207. const ggml_tensor * src0 = dst->src[0];
  6208. switch (src0->type) {
  6209. case GGML_TYPE_F16:
  6210. case GGML_TYPE_BF16:
  6211. {
  6212. ggml_compute_forward_get_rel_pos_f16(params, dst);
  6213. } break;
  6214. default:
  6215. {
  6216. GGML_ABORT("fatal error");
  6217. }
  6218. }
  6219. }
  6220. // ggml_compute_forward_add_rel_pos
  6221. static void ggml_compute_forward_add_rel_pos_f32(
  6222. const ggml_compute_params * params,
  6223. ggml_tensor * dst) {
  6224. const ggml_tensor * src0 = dst->src[0];
  6225. const ggml_tensor * src1 = dst->src[1];
  6226. const ggml_tensor * src2 = dst->src[2];
  6227. const bool inplace = (bool) ((int32_t *) dst->op_params)[0];
  6228. if (!inplace) {
  6229. if (params->ith == 0) {
  6230. memcpy((char *) dst->data, (char *) src0->data, ggml_nbytes(dst));
  6231. }
  6232. ggml_barrier(params->threadpool);
  6233. }
  6234. // ref: https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/modeling/image_encoder.py#L357-L359
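// adds the decomposed relative position terms to the attention scores: viewing each
// group of ne10*ne10 dst values as an (ne10 x ne10) block, src2[i10] is added to every
// element of row i10 (offset jdh) and src1[i10] to every element of column i10 (offset jdw)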
  6235. float * src1_data = (float *) src1->data;
  6236. float * src2_data = (float *) src2->data;
  6237. float * dst_data = (float *) dst->data;
  6238. const int64_t ne10 = src1->ne[0];
  6239. const int64_t ne11 = src1->ne[1];
  6240. const int64_t ne12 = src1->ne[2];
  6241. const int64_t ne13 = src1->ne[3];
  6242. const int ith = params->ith;
  6243. const int nth = params->nth;
  6244. // total patches in dst
  6245. const int np = ne13;
  6246. // patches per thread
  6247. const int dp = (np + nth - 1)/nth;
  6248. // patch range for this thread
  6249. const int ip0 = dp*ith;
  6250. const int ip1 = MIN(ip0 + dp, np);
  6251. for (int64_t i13 = ip0; i13 < ip1; ++i13) {
  6252. for (int64_t i12 = 0; i12 < ne12; ++i12) {
  6253. for (int64_t i11 = 0; i11 < ne11; ++i11) {
  6254. const int64_t jp1 = i13*ne12*ne11*ne10 + i12*ne11*ne10 + i11*ne10;
  6255. for (int64_t i10 = 0; i10 < ne10; ++i10) {
  6256. const int64_t jp0 = jp1 + i10;
  6257. const float src1_e = src1_data[jp0];
  6258. const float src2_e = src2_data[jp0];
  6259. const int64_t jdh = jp0 * ne10;
  6260. const int64_t jdw = jdh - (ne10 - 1) * i10;
  6261. for (int64_t j = 0; j < ne10; ++j) {
  6262. dst_data[jdh + j ] += src2_e;
  6263. dst_data[jdw + j*ne10] += src1_e;
  6264. }
  6265. }
  6266. }
  6267. }
  6268. }
  6269. }
  6270. void ggml_compute_forward_add_rel_pos(
  6271. const ggml_compute_params * params,
  6272. ggml_tensor * dst) {
  6273. const ggml_tensor * src0 = dst->src[0];
  6274. switch (src0->type) {
  6275. case GGML_TYPE_F32:
  6276. {
  6277. ggml_compute_forward_add_rel_pos_f32(params, dst);
  6278. } break;
  6279. default:
  6280. {
  6281. GGML_ABORT("fatal error");
  6282. }
  6283. }
  6284. }
  6285. // ggml_compute_forward_rwkv_wkv6
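// RWKV v6 wkv kernel; per head, with a head_size x head_size state S:
//   out[j] += sum_i r[i] * (time_faaaa[i] * k[i] * v[j] + S[i][j])
//   S[i][j] = S[i][j] * time_decay[i] + k[i] * v[j]
// applied token by token, carrying the state over within each sequence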
  6286. static void ggml_compute_forward_rwkv_wkv6_f32(
  6287. const ggml_compute_params * params,
  6288. ggml_tensor * dst) {
  6289. const int64_t T = dst->src[1]->ne[2];
  6290. const int64_t C = dst->ne[0];
  6291. const int64_t HEADS = dst->src[1]->ne[1];
  6292. const int64_t n_seqs = dst->src[5]->ne[1];
  6293. const int64_t head_size = C / HEADS;
  6294. float * dst_data = (float *) dst->data;
  6295. float * state = ((float *) dst->data) + C * T;
  6296. const int ith = params->ith;
  6297. const int nth = params->nth;
  6298. if (ith >= HEADS) {
  6299. return;
  6300. }
  6301. const int h_start = (HEADS * ith) / nth;
  6302. const int h_end = ((HEADS * (ith + 1)) / nth < HEADS) ?
  6303. (HEADS * (ith + 1)) / nth : HEADS;
  6304. float * k = (float *) dst->src[0]->data;
  6305. float * v = (float *) dst->src[1]->data;
  6306. float * r = (float *) dst->src[2]->data;
  6307. float * time_faaaa = (float *) dst->src[3]->data;
  6308. float * time_decay = (float *) dst->src[4]->data;
size_t t_stride = HEADS * head_size; // same as C
  6310. size_t h_stride = C / HEADS;
  6311. GGML_ASSERT(C % HEADS == 0); // C must be divisible by HEADS
  6312. size_t h_stride_2d = head_size * head_size;
  6313. if (ith == 0) {
  6314. memset(dst_data, 0, T * C * sizeof(float));
  6315. }
  6316. ggml_barrier(params->threadpool);
  6317. #if defined(__AVX__) && !defined(__AVX512F__)
  6318. #define GGML_F32X GGML_F32x8
  6319. #define GGML_F32X_SET1 GGML_F32x8_SET1
  6320. #define GGML_F32X_LOAD GGML_F32x8_LOAD
  6321. #define GGML_F32X_STORE GGML_F32x8_STORE
  6322. #define GGML_F32X_MUL GGML_F32x8_MUL
  6323. #define GGML_F32X_FMA GGML_F32x8_FMA
  6324. #define WKV_VECTOR_SIZE 8
  6325. #elif defined(__AVX512F__)
  6326. #define GGML_F32X GGML_F32x16
  6327. #define GGML_F32X_SET1 GGML_F32x16_SET1
  6328. #define GGML_F32X_LOAD GGML_F32x16_LOAD
  6329. #define GGML_F32X_STORE GGML_F32x16_STORE
  6330. #define GGML_F32X_MUL GGML_F32x16_MUL
  6331. #define GGML_F32X_FMA GGML_F32x16_FMA
  6332. #define WKV_VECTOR_SIZE 16
  6333. #elif defined(__ARM_NEON) && defined(__aarch64__)
  6334. #define GGML_F32X GGML_F32x4
  6335. #define GGML_F32X_SET1 GGML_F32x4_SET1
  6336. #define GGML_F32X_LOAD GGML_F32x4_LOAD
  6337. #define GGML_F32X_STORE GGML_F32x4_STORE
  6338. #define GGML_F32X_MUL GGML_F32x4_MUL
  6339. #define GGML_F32X_FMA GGML_F32x4_FMA
  6340. #define WKV_VECTOR_SIZE 4
  6341. #endif
  6342. #ifdef WKV_VECTOR_SIZE
  6343. const int64_t vec_count = head_size / WKV_VECTOR_SIZE;
  6344. for (int64_t t = 0; t < T; t++) {
  6345. size_t t_offset = t * t_stride;
  6346. size_t state_offset = head_size * C * (t / (T / n_seqs));
  6347. float * state_cur = state + state_offset;
  6348. float * state_prev = t % (T / n_seqs) ? state_cur : (float*)dst->src[5]->data + state_offset;
  6349. for (int64_t h = h_start; h < h_end; h++) {
  6350. size_t h_offset = h * h_stride;
  6351. size_t t_h_offset = t_offset + h_offset;
  6352. size_t h_2d_offset = h * h_stride_2d;
  6353. for (int64_t i = 0; i < head_size; i++) {
  6354. size_t t_h_i_offset = t_h_offset + i;
  6355. size_t h_i_offset = h_offset + i;
  6356. size_t h_2d_i_offset = h_2d_offset + i * h_stride;
  6357. float k_val = k[t_h_i_offset];
  6358. float r_val = r[t_h_i_offset];
  6359. float time_faaaa_val = time_faaaa[h_i_offset];
  6360. float time_decay_val = time_decay[t_h_i_offset];
  6361. // Broadcast scalar values to vectors
  6362. GGML_F32X k_vec = GGML_F32X_SET1(k_val);
  6363. GGML_F32X r_vec = GGML_F32X_SET1(r_val);
  6364. GGML_F32X time_faaaa_vec = GGML_F32X_SET1(time_faaaa_val);
  6365. GGML_F32X time_decay_vec = GGML_F32X_SET1(time_decay_val);
  6366. for (int64_t j = 0; j < vec_count; j++) {
  6367. size_t base_j = j * WKV_VECTOR_SIZE;
  6368. size_t t_h_j_offset = t_h_offset + base_j;
  6369. size_t h_2d_i_j_offset = h_2d_i_offset + base_j;
// Load WKV_VECTOR_SIZE elements at once
  6371. GGML_F32X v_vec = GGML_F32X_LOAD(&v[t_h_j_offset]);
  6372. GGML_F32X prev_state_vec = GGML_F32X_LOAD(&state_prev[h_2d_i_j_offset]);
  6373. GGML_F32X dst_vec = GGML_F32X_LOAD(&dst_data[t_h_j_offset]);
  6374. // Compute kv = v * k
  6375. GGML_F32X kv_vec = GGML_F32X_MUL(v_vec, k_vec);
  6376. // Compute temp = kv * time_faaaa + prev_state
  6377. GGML_F32X temp_vec = GGML_F32X_FMA(prev_state_vec, kv_vec, time_faaaa_vec);
  6378. // Update dst: dst += temp * r
  6379. dst_vec = GGML_F32X_FMA(dst_vec, temp_vec, r_vec);
  6380. GGML_F32X_STORE(&dst_data[t_h_j_offset], dst_vec);
  6381. // Update state: state = prev_state * time_decay + kv
  6382. GGML_F32X new_state_vec = GGML_F32X_FMA(kv_vec, prev_state_vec, time_decay_vec);
  6383. GGML_F32X_STORE(&state_cur[h_2d_i_j_offset], new_state_vec);
  6384. }
// Handle any remaining elements (in practice head_size is a multiple of WKV_VECTOR_SIZE, so this tail loop does not run)
  6386. for (int64_t j = vec_count * WKV_VECTOR_SIZE; j < head_size; j++) {
  6387. size_t t_h_j_offset = t_h_offset + j;
  6388. size_t h_2d_i_j_offset = h_2d_i_offset + j;
  6389. float v_val = v[t_h_j_offset];
  6390. float kv_val = v_val * k_val;
  6391. float prev_state_val = state_prev[h_2d_i_j_offset];
  6392. float temp_val = kv_val * time_faaaa_val + prev_state_val;
  6393. dst_data[t_h_j_offset] += temp_val * r_val;
  6394. state_cur[h_2d_i_j_offset] = prev_state_val * time_decay_val + kv_val;
  6395. }
  6396. }
  6397. }
  6398. }
  6399. #else
  6400. // basically fused operations:
  6401. // dst = r @ (time_faaaa * (k @ v) + state),
  6402. // state = time_decay * state + (k @ v),
// applied recurrently, token by token
  6404. for (int64_t t = 0; t < T; t++) {
  6405. size_t t_offset = t * t_stride;
  6406. size_t state_offset = head_size * C * (t / (T / n_seqs));
  6407. float * state_cur = state + state_offset;
  6408. float * state_prev = t % (T / n_seqs) ? state_cur : (float*)dst->src[5]->data + state_offset;
  6409. for (int64_t h = h_start; h < h_end; h++) {
  6410. size_t h_offset = h * h_stride;
  6411. size_t t_h_offset = t_offset + h_offset;
  6412. size_t h_2d_offset = h * h_stride_2d;
  6413. for (int64_t i = 0; i < head_size; i++) {
  6414. size_t t_h_i_offset = t_h_offset + i;
  6415. size_t h_i_offset = h_offset + i;
  6416. size_t h_2d_i_offset = h_2d_offset + i * h_stride;
  6417. float k_val = k[t_h_i_offset];
  6418. float r_val = r[t_h_i_offset];
  6419. float time_faaaa_val = time_faaaa[h_i_offset];
  6420. // RWKV v6: different time_decay for each token.
  6421. float time_decay_val = time_decay[t_h_i_offset];
  6422. for (int64_t j = 0; j < head_size; j++) {
  6423. size_t t_h_j_offset = t_h_offset + j;
  6424. size_t h_2d_i_j_offset = h_2d_i_offset + j;
  6425. float v_val = v[t_h_j_offset];
  6426. float kv_val = v_val * k_val;
  6427. float prev_state_val = state_prev[h_2d_i_j_offset];
  6428. float temp_val = kv_val * time_faaaa_val + prev_state_val;
  6429. dst_data[t_h_j_offset] += temp_val * r_val;
  6430. state_cur[h_2d_i_j_offset] = prev_state_val * time_decay_val + kv_val;
  6431. }
  6432. }
  6433. }
  6434. }
  6435. #endif
  6436. }
  6437. void ggml_compute_forward_rwkv_wkv6(
  6438. const ggml_compute_params * params,
  6439. ggml_tensor * dst) {
  6440. const ggml_tensor * src0 = dst->src[0];
  6441. switch (src0->type) {
  6442. case GGML_TYPE_F32:
  6443. {
  6444. ggml_compute_forward_rwkv_wkv6_f32(params, dst);
  6445. } break;
  6446. default:
  6447. {
  6448. GGML_ABORT("fatal error");
  6449. }
  6450. }
  6451. }
  6452. // ggml_compute_forward_gla
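// gated linear attention; per head, with a head_size x head_size state S:
//   S[i][j] = S[i][j] * g[i] + k[i] * v[j]
//   out[j] += (q[i] * scale) * S[i][j]   (summed over i, using the updated state)
// the state is carried over between tokens of the same sequence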
  6453. static void ggml_compute_forward_gla_f32(
  6454. const ggml_compute_params * params,
  6455. ggml_tensor * dst) {
  6456. const int64_t T = dst->src[1]->ne[2];
  6457. const int64_t C = dst->ne[0];
  6458. const int64_t HEADS = dst->src[1]->ne[1];
  6459. const int64_t n_seqs = dst->src[4]->ne[1];
  6460. const int64_t head_size = C / HEADS;
  6461. const float scale = ggml_get_op_params_f32(dst, 0);
  6462. float * dst_data = (float *) dst->data;
  6463. float * state = ((float *) dst->data) + C * T;
  6464. const int ith = params->ith;
  6465. const int nth = params->nth;
  6466. if (ith >= HEADS) {
  6467. return;
  6468. }
  6469. const int h_start = (HEADS * ith) / nth;
  6470. const int h_end = ((HEADS * (ith + 1)) / nth < HEADS) ?
  6471. (HEADS * (ith + 1)) / nth : HEADS;
  6472. float * k = (float *) dst->src[0]->data;
  6473. float * v = (float *) dst->src[1]->data;
  6474. float * q = (float *) dst->src[2]->data;
  6475. float * g = (float *) dst->src[3]->data;
size_t t_stride = HEADS * head_size; // same as C
  6477. size_t h_stride = C / HEADS;
  6478. GGML_ASSERT(C % HEADS == 0); // C must be divisible by HEADS
  6479. size_t h_stride_2d = head_size * head_size;
  6480. if (ith == 0) {
  6481. memset(dst_data, 0, T * C * sizeof(float));
  6482. }
  6483. ggml_barrier(params->threadpool);
  6484. #if defined(__AVX__) && !defined(__AVX512F__)
  6485. #define GGML_F32X GGML_F32x8
  6486. #define GGML_F32X_SET1 GGML_F32x8_SET1
  6487. #define GGML_F32X_LOAD GGML_F32x8_LOAD
  6488. #define GGML_F32X_STORE GGML_F32x8_STORE
  6489. #define GGML_F32X_MUL GGML_F32x8_MUL
  6490. #define GGML_F32X_FMA GGML_F32x8_FMA
  6491. #define GLA_VECTOR_SIZE 8
  6492. #elif defined(__AVX512F__)
  6493. #define GGML_F32X GGML_F32x16
  6494. #define GGML_F32X_SET1 GGML_F32x16_SET1
  6495. #define GGML_F32X_LOAD GGML_F32x16_LOAD
  6496. #define GGML_F32X_STORE GGML_F32x16_STORE
  6497. #define GGML_F32X_MUL GGML_F32x16_MUL
  6498. #define GGML_F32X_FMA GGML_F32x16_FMA
  6499. #define GLA_VECTOR_SIZE 16
  6500. #elif defined(__ARM_NEON) && defined(__aarch64__)
  6501. #define GGML_F32X GGML_F32x4
  6502. #define GGML_F32X_SET1 GGML_F32x4_SET1
  6503. #define GGML_F32X_LOAD GGML_F32x4_LOAD
  6504. #define GGML_F32X_STORE GGML_F32x4_STORE
  6505. #define GGML_F32X_MUL GGML_F32x4_MUL
  6506. #define GGML_F32X_FMA GGML_F32x4_FMA
  6507. #define GLA_VECTOR_SIZE 4
  6508. #endif
  6509. #ifdef GLA_VECTOR_SIZE
  6510. const int64_t vec_count = head_size / GLA_VECTOR_SIZE;
  6511. for (int64_t t = 0; t < T; t++) {
  6512. size_t t_offset = t * t_stride;
  6513. size_t state_offset = head_size * C * (t / (T / n_seqs));
  6514. float * state_cur = state + state_offset;
  6515. float * state_prev = t % (T / n_seqs) ? state_cur : (float*)dst->src[4]->data + state_offset;
  6516. for (int64_t h = h_start; h < h_end; h++) {
  6517. size_t h_offset = h * h_stride;
  6518. size_t t_h_offset = t_offset + h_offset;
  6519. size_t h_2d_offset = h * h_stride_2d;
  6520. for (int64_t i = 0; i < head_size; i++) {
  6521. size_t t_h_i_offset = t_h_offset + i;
  6522. size_t h_2d_i_offset = h_2d_offset + i * h_stride;
  6523. float k_val = k[t_h_i_offset];
  6524. float q_val = q[t_h_i_offset] * scale;
  6525. float g_val = g[t_h_i_offset];
  6526. // Broadcast scalar values to vectors
  6527. GGML_F32X k_vec = GGML_F32X_SET1(k_val);
  6528. GGML_F32X q_vec = GGML_F32X_SET1(q_val);
  6529. GGML_F32X g_vec = GGML_F32X_SET1(g_val);
  6530. for (int64_t j = 0; j < vec_count; j++) {
  6531. size_t base_j = j * GLA_VECTOR_SIZE;
  6532. size_t t_h_j_offset = t_h_offset + base_j;
  6533. size_t h_2d_i_j_offset = h_2d_i_offset + base_j;
// Load GLA_VECTOR_SIZE elements at once
  6535. GGML_F32X v_vec = GGML_F32X_LOAD(&v[t_h_j_offset]);
  6536. GGML_F32X prev_state_vec = GGML_F32X_LOAD(&state_prev[h_2d_i_j_offset]);
  6537. GGML_F32X dst_vec = GGML_F32X_LOAD(&dst_data[t_h_j_offset]);
  6538. // Compute kv = v * k
  6539. GGML_F32X kv_vec = GGML_F32X_MUL(v_vec, k_vec);
  6540. // Compute temp = prev_state * g + kv
  6541. GGML_F32X temp_vec = GGML_F32X_FMA(kv_vec, prev_state_vec, g_vec);
  6542. // Update dst: dst += temp * q
  6543. dst_vec = GGML_F32X_FMA(dst_vec, temp_vec, q_vec);
  6544. GGML_F32X_STORE(&dst_data[t_h_j_offset], dst_vec);
  6545. // Update state
  6546. GGML_F32X_STORE(&state_cur[h_2d_i_j_offset], temp_vec);
  6547. }
// Handle any remaining elements (in practice head_size is a multiple of GLA_VECTOR_SIZE, so this tail loop does not run)
  6549. for (int64_t j = vec_count * GLA_VECTOR_SIZE; j < head_size; j++) {
  6550. size_t t_h_j_offset = t_h_offset + j;
  6551. size_t h_2d_i_j_offset = h_2d_i_offset + j;
  6552. float v_val = v[t_h_j_offset];
  6553. float kv_val = v_val * k_val;
  6554. float prev_state_val = state_prev[h_2d_i_j_offset];
  6555. float temp_val = kv_val + prev_state_val * g_val;
  6556. dst_data[t_h_j_offset] += temp_val * q_val;
  6557. state_cur[h_2d_i_j_offset] = temp_val;
  6558. }
  6559. }
  6560. }
  6561. }
  6562. #else
  6563. for (int64_t t = 0; t < T; t++) {
  6564. size_t t_offset = t * t_stride;
  6565. size_t state_offset = head_size * C * (t / (T / n_seqs));
  6566. float * state_cur = state + state_offset;
  6567. float * state_prev = t % (T / n_seqs) ? state_cur : (float*)dst->src[4]->data + state_offset;
  6568. for (int64_t h = h_start; h < h_end; h++) {
  6569. size_t h_offset = h * h_stride;
  6570. size_t t_h_offset = t_offset + h_offset;
  6571. size_t h_2d_offset = h * h_stride_2d;
  6572. for (int64_t i = 0; i < head_size; i++) {
  6573. size_t t_h_i_offset = t_h_offset + i;
  6574. size_t h_2d_i_offset = h_2d_offset + i * h_stride;
  6575. float k_val = k[t_h_i_offset];
  6576. float q_val = q[t_h_i_offset] * scale;
  6577. float g_val = g[t_h_i_offset];
  6578. for (int64_t j = 0; j < head_size; j++) {
  6579. size_t t_h_j_offset = t_h_offset + j;
  6580. size_t h_2d_i_j_offset = h_2d_i_offset + j;
  6581. float v_val = v[t_h_j_offset];
  6582. float kv_val = v_val * k_val;
  6583. float prev_state_val = state_prev[h_2d_i_j_offset];
  6584. float temp_val = prev_state_val * g_val + kv_val;
  6585. dst_data[t_h_j_offset] += temp_val * q_val;
  6586. state_cur[h_2d_i_j_offset] = temp_val;
  6587. }
  6588. }
  6589. }
  6590. }
  6591. #endif
  6592. }
  6593. void ggml_compute_forward_gla(
  6594. const ggml_compute_params * params,
  6595. ggml_tensor * dst) {
  6596. const ggml_tensor * src0 = dst->src[0];
  6597. switch (src0->type) {
  6598. case GGML_TYPE_F32:
  6599. {
  6600. ggml_compute_forward_gla_f32(params, dst);
  6601. } break;
  6602. default:
  6603. {
  6604. GGML_ABORT("fatal error");
  6605. }
  6606. }
  6607. }
  6608. // ggml_compute_forward_rwkv_wkv7
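// RWKV v7 wkv kernel; per head, with a head_size x head_size state S, for each row i:
//   sa      = sum_j a[j] * S[i][j]
//   S[i][j] = v[i] * k[j] + S[i][j] * w[j] + sa * b[j]
//   out[i]  = sum_j r[j] * S[i][j]
// the state is carried over between tokens of the same sequence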
  6609. static void ggml_compute_forward_rwkv_wkv7_f32(
  6610. const ggml_compute_params * params,
  6611. ggml_tensor * dst) {
  6612. const int64_t T = dst->src[1]->ne[2];
  6613. const int64_t C = dst->ne[0];
  6614. const int64_t HEADS = dst->src[1]->ne[1];
  6615. const int64_t n_seqs = dst->src[6]->ne[1];
  6616. const int64_t head_size = C / HEADS;
  6617. float * dst_data = (float *) dst->data;
  6618. float * state = ((float *) dst->data) + C * T;
  6619. const int ith = params->ith;
  6620. const int nth = params->nth;
  6621. if (ith >= HEADS) {
  6622. return;
  6623. }
  6624. const int h_start = (HEADS * ith) / nth;
  6625. const int h_end = ((HEADS * (ith + 1)) / nth < HEADS) ?
  6626. (HEADS * (ith + 1)) / nth : HEADS;
  6627. float * r = (float *) dst->src[0]->data;
  6628. float * w = (float *) dst->src[1]->data;
  6629. float * k = (float *) dst->src[2]->data;
  6630. float * v = (float *) dst->src[3]->data;
  6631. float * a = (float *) dst->src[4]->data;
  6632. float * b = (float *) dst->src[5]->data;
    int64_t t_stride = HEADS * head_size; // same as C

    int64_t h_stride = C / HEADS;
    GGML_ASSERT(C % HEADS == 0); // C must be divisible by HEADS
    int64_t h_stride_2d = head_size * head_size;
    #if defined(GGML_SIMD)
        for (int64_t t = 0; t < T; t++) {
            int64_t t_offset = t * t_stride;
            int64_t state_offset = head_size * C * (t / (T / n_seqs));
            float * state_cur = state + state_offset;
            float * state_prev = t % (T / n_seqs) ? state_cur : (float*)dst->src[6]->data + state_offset;

            for (int64_t h = h_start; h < h_end; h++) {
                int64_t h_offset = h * h_stride;
                int64_t t_h_offset = t_offset + h_offset;
                int64_t h_2d_offset = h * h_stride_2d;

                for (int64_t ii = 0; ii < head_size; ii++) {
                    int64_t t_h_i_offset = t_h_offset + ii;
                    int64_t h_2d_i_offset = h_2d_offset + ii * h_stride;

                    GGML_F32_VEC v_vec = GGML_F32_VEC_SET1(v[t_h_i_offset]);

                    float sa = 0;
                    {
                        GGML_F32_VEC sum[GGML_F32_ARR] = { GGML_F32_VEC_ZERO };
                        GGML_F32_VEC ax[GGML_F32_ARR];
                        GGML_F32_VEC ay[GGML_F32_ARR];
                        for (int64_t j = 0; j < head_size; j += GGML_F32_STEP) {
                            for (int64_t kk = 0; kk < GGML_F32_ARR; kk++) {
                                ax[kk] = GGML_F32_VEC_LOAD(&a[t_h_offset + j + kk * GGML_F32_EPR]);
                                ay[kk] = GGML_F32_VEC_LOAD(&state_prev[h_2d_i_offset + j + kk * GGML_F32_EPR]);
                                sum[kk] = GGML_F32_VEC_FMA(sum[kk], ax[kk], ay[kk]);
                            }
                        }
                        GGML_F32_VEC_REDUCE(sa, sum);
                    }

                    GGML_F32_VEC sa_vec = GGML_F32_VEC_SET1(sa);

                    int64_t j = 0;
                    GGML_F32_VEC result_vec[GGML_F32_ARR] = { GGML_F32_VEC_ZERO };
                    for (; j < head_size; j += GGML_F32_STEP) {
                        for (int64_t kk = 0; kk < GGML_F32_ARR; kk++) {
                            int64_t t_h_j_offset = t_h_offset + j + kk * GGML_F32_EPR;
                            int64_t h_2d_i_j_offset = h_2d_i_offset + j + kk * GGML_F32_EPR;

                            GGML_F32_VEC r_vec = GGML_F32_VEC_LOAD(&r[t_h_j_offset]);
                            GGML_F32_VEC w_vec = GGML_F32_VEC_LOAD(&w[t_h_j_offset]);
                            GGML_F32_VEC k_vec = GGML_F32_VEC_LOAD(&k[t_h_j_offset]);
                            GGML_F32_VEC b_vec = GGML_F32_VEC_LOAD(&b[t_h_j_offset]);

                            k_vec = GGML_F32_VEC_MUL(v_vec, k_vec);

                            GGML_F32_VEC state_vec = GGML_F32_VEC_LOAD(&state_prev[h_2d_i_j_offset]);
                            // kv + s * decay + sa * b
                            state_vec = GGML_F32_VEC_FMA(k_vec, state_vec, w_vec);
                            state_vec = GGML_F32_VEC_FMA(state_vec, sa_vec, b_vec);
                            GGML_F32_VEC_STORE(&state_cur[h_2d_i_j_offset], state_vec);

                            result_vec[kk] = GGML_F32_VEC_FMA(result_vec[kk], state_vec, r_vec);
                        }
                    }
                    GGML_F32_VEC_REDUCE(dst_data[t_h_i_offset], result_vec);
                    // handle any remaining elements (there shouldn't be any when head_size is a multiple of GGML_F32_STEP)
                    for (; j < head_size; j++) {
                        int64_t t_h_j_offset = t_h_offset + j;
                        int64_t h_2d_i_j_offset = h_2d_i_offset + j;

                        float r_val = r[t_h_j_offset];
                        float w_val = w[t_h_j_offset];
                        float k_val = k[t_h_j_offset];
                        float b_val = b[t_h_j_offset];
                        float kv_val = v[t_h_i_offset] * k_val;

                        float prev_state_val = state_prev[h_2d_i_j_offset];
                        state_cur[h_2d_i_j_offset] = prev_state_val * w_val + kv_val + sa * b_val;
                        dst_data[t_h_i_offset] += state_cur[h_2d_i_j_offset] * r_val;
                    }
                }
            }
        }
    #else
        for (int64_t t = 0; t < T; t++) {
            int64_t t_offset = t * t_stride;
            int64_t state_offset = head_size * C * (t / (T / n_seqs));
            float * state_cur = state + state_offset;
            float * state_prev = t % (T / n_seqs) ? state_cur : (float*)dst->src[6]->data + state_offset;

            for (int64_t h = h_start; h < h_end; h++) {
                int64_t h_offset = h * h_stride;
                int64_t t_h_offset = t_offset + h_offset;
                int64_t h_2d_offset = h * h_stride_2d;

                for (int64_t i = 0; i < head_size; i++) {
                    int64_t t_h_i_offset = t_h_offset + i;
                    int64_t h_2d_i_offset = h_2d_offset + i * h_stride;

                    float v_val = v[t_h_i_offset];

                    float sa = 0, result = 0;
                    for (int64_t j = 0; j < head_size; j++) {
                        sa += a[t_h_offset + j] * state_prev[h_2d_i_offset + j];
                    }

                    for (int64_t j = 0; j < head_size; j++) {
                        int64_t t_h_j_offset = t_h_offset + j;
                        int64_t h_2d_i_j_offset = h_2d_i_offset + j;

                        float r_val = r[t_h_j_offset];
                        float w_val = w[t_h_j_offset];
                        float k_val = k[t_h_j_offset];
                        float b_val = b[t_h_j_offset];
                        float kv_val = v_val * k_val;

                        float prev_state_val = state_prev[h_2d_i_j_offset];
                        state_cur[h_2d_i_j_offset] = prev_state_val * w_val + kv_val + sa * b_val;
                        result += state_cur[h_2d_i_j_offset] * r_val;
                    }
                    dst_data[t_h_i_offset] = result;
                }
            }
        }
    #endif
}

void ggml_compute_forward_rwkv_wkv7(
        const ggml_compute_params * params,
        ggml_tensor * dst) {

    const ggml_tensor * src0 = dst->src[0];

    switch (src0->type) {
        case GGML_TYPE_F32:
            {
                ggml_compute_forward_rwkv_wkv7_f32(params, dst);
            } break;
        default:
            {
                GGML_ABORT("fatal error");
            }
    }
}
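
// NOTE: illustrative sketch only, not referenced by the kernels above.
// Per head, ggml_compute_forward_rwkv_wkv7_f32 updates a head_size x head_size
// state matrix S row by row (row index i follows the value channel v[i]):
//
//     sa        = sum_j a_t[j] * S_{t-1}[i][j]
//     S_t[i][j] = S_{t-1}[i][j] * w_t[j] + sa * b_t[j] + v_t[i] * k_t[j]
//     out_t[i]  = sum_j S_t[i][j] * r_t[j]
//
// The scalar sketch below mirrors the #else branch for one head at one
// timestep; the name wkv7_step_ref is local to this sketch.
static inline void wkv7_step_ref(int64_t head_size,
        const float * r, const float * w, const float * k, const float * v,
        const float * a, const float * b, float * S, float * out) {
    for (int64_t i = 0; i < head_size; i++) {
        // dot product of a with the previous state row
        float sa = 0.0f;
        for (int64_t j = 0; j < head_size; j++) {
            sa += a[j] * S[i*head_size + j];
        }

        // decay, correct with sa*b, add v x k, then read out with r
        float result = 0.0f;
        for (int64_t j = 0; j < head_size; j++) {
            const float s_new = S[i*head_size + j] * w[j] + sa * b[j] + v[i] * k[j];
            S[i*head_size + j] = s_new;
            result += s_new * r[j];
        }
        out[i] = result;
    }
}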

// ggml_compute_forward_map_custom1

void ggml_compute_forward_map_custom1(
        const ggml_compute_params * params,
        ggml_tensor * dst) {

    const ggml_tensor * a = dst->src[0];

    struct ggml_map_custom1_op_params p;
    memcpy(&p, dst->op_params, sizeof(p));

    p.fun(dst, a, params->ith, params->nth, p.userdata);
}

// ggml_compute_forward_map_custom2

void ggml_compute_forward_map_custom2(
        const ggml_compute_params * params,
        ggml_tensor * dst) {

    const ggml_tensor * a = dst->src[0];
    const ggml_tensor * b = dst->src[1];

    struct ggml_map_custom2_op_params p;
    memcpy(&p, dst->op_params, sizeof(p));

    p.fun(dst, a, b, params->ith, params->nth, p.userdata);
}

// ggml_compute_forward_map_custom3

void ggml_compute_forward_map_custom3(
        const ggml_compute_params * params,
        ggml_tensor * dst) {

    const ggml_tensor * a = dst->src[0];
    const ggml_tensor * b = dst->src[1];
    const ggml_tensor * c = dst->src[2];

    struct ggml_map_custom3_op_params p;
    memcpy(&p, dst->op_params, sizeof(p));

    p.fun(dst, a, b, c, params->ith, params->nth, p.userdata);
}

// ggml_compute_forward_custom

void ggml_compute_forward_custom(
        const struct ggml_compute_params * params,
        struct ggml_tensor * dst) {

    struct ggml_custom_op_params p;
    memcpy(&p, dst->op_params, sizeof(p));

    p.fun(dst, params->ith, params->nth, p.userdata);
}
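
// NOTE: illustrative sketch only, not referenced by the kernels above.
// The wrappers above only forward (ith, nth) to the user-supplied callback; the
// callback itself is responsible for splitting the work across threads. A
// user-side callback matching the call shape used by
// ggml_compute_forward_map_custom1 -- p.fun(dst, a, ith, nth, userdata) --
// could partition rows the same way the built-in kernels in this file do.
// The name example_custom1_scale and the userdata layout are hypothetical, and
// F32 tensors with contiguous rows are assumed.
static inline void example_custom1_scale(struct ggml_tensor * dst, const struct ggml_tensor * a,
        int ith, int nth, void * userdata) {
    const float factor = userdata ? *(const float *) userdata : 1.0f; // hypothetical parameter

    const int64_t nr = ggml_nrows(a);

    // rows per thread
    const int64_t dr = (nr + nth - 1)/nth;

    // row range for this thread
    const int64_t ir0 = dr*ith;
    const int64_t ir1 = MIN(ir0 + dr, nr);

    for (int64_t ir = ir0; ir < ir1; ir++) {
        const float * src = (const float *) ((const char *) a->data   + ir*a->nb[1]);
        float       * out = (float       *) ((char       *) dst->data + ir*dst->nb[1]);
        for (int64_t i0 = 0; i0 < a->ne[0]; i0++) {
            out[i0] = src[i0]*factor;
        }
    }
}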

// ggml_compute_forward_cross_entropy_loss

static void ggml_compute_forward_cross_entropy_loss_f32(
        const ggml_compute_params * params,
        ggml_tensor * dst) {

    const ggml_tensor * src0 = dst->src[0];
    const ggml_tensor * src1 = dst->src[1];

    GGML_ASSERT(src0->type == GGML_TYPE_F32);
    GGML_ASSERT(src1->type == GGML_TYPE_F32);
    GGML_ASSERT(src0->nb[0] == ggml_type_size(src0->type));
    GGML_ASSERT(src1->nb[0] == ggml_type_size(src1->type));
    GGML_ASSERT(ggml_are_same_shape(src0, src1));
    GGML_ASSERT(ggml_is_scalar(dst));
    GGML_ASSERT(dst->type == GGML_TYPE_F32);

    // TODO: handle transposed/permuted matrices
    const int64_t nc = src0->ne[0];
    const int64_t nr = ggml_nrows(src0);

    const int ith = params->ith;
    const int nth = params->nth;

    float * sums = (float *) params->wdata;
    float * st = ((float *) params->wdata) + nth + ith*nc;
    float sum_thread = 0.0f;

    GGML_ASSERT(params->wsize >= sizeof(float) * (nth + nth * nc));

    // rows per thread
    const int64_t dr = (nr + nth - 1)/nth;

    // row range for this thread
    const int64_t ir0 = dr*ith;
    const int64_t ir1 = MIN(ir0 + dr, nr);

    for (int64_t i1 = ir0; i1 < ir1; ++i1) {
        const float * s0 = (const float *)((const char *) src0->data + i1*src0->nb[1]);
        const float * s1 = (const float *)((const char *) src1->data + i1*src1->nb[1]);

#ifndef NDEBUG
        for (int64_t i = 0; i < nc; ++i) {
            //printf("p[%d] = %f\n", i, p[i]);
            assert(!isnan(s0[i]));
            assert(!isnan(s1[i]));
        }
#endif

        float max = -INFINITY;
        ggml_vec_max_f32(nc, &max, s0);
        const ggml_float sum_softmax = ggml_vec_log_soft_max_f32(nc, st, s0, max);
        assert(sum_softmax >= 0.0);

        ggml_vec_add1_f32(nc, st, st, -sum_softmax);
        ggml_vec_mul_f32(nc, st, st, s1);

        float sum_st = 0.0f;
        ggml_vec_sum_f32(nc, &sum_st, st);
        sum_thread += sum_st;

#ifndef NDEBUG
        for (int64_t i = 0; i < nc; ++i) {
            assert(!isnan(st[i]));
            assert(!isinf(st[i]));
        }
#endif
    }

    sums[ith] = sum_thread;
    ggml_barrier(params->threadpool);

    if (ith == 0) {
        float * dp = (float *) dst->data;
        ggml_vec_sum_f32(nth, dp, sums);
        dp[0] *= -1.0f / (float) nr;
    }
}

void ggml_compute_forward_cross_entropy_loss(
        const ggml_compute_params * params,
        ggml_tensor * dst) {

    const ggml_tensor * src0 = dst->src[0];

    switch (src0->type) {
        case GGML_TYPE_F32:
            {
                ggml_compute_forward_cross_entropy_loss_f32(params, dst);
            } break;
        default:
            {
                GGML_ABORT("fatal error");
            }
    }
}
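
// NOTE: illustrative sketch only, not referenced by the kernels above.
// Per row, the kernel accumulates sum_c s1[c] * log_softmax(s0)[c], with
// log_softmax computed via the usual max-subtraction for numerical stability;
// the per-thread partials are summed and the final scalar is scaled by -1/nr.
// A plain scalar version of the per-row term (the name cross_entropy_row_ref is
// local to this sketch):
static inline double cross_entropy_row_ref(int64_t nc, const float * s0, const float * s1) {
    float max = -INFINITY;
    for (int64_t i = 0; i < nc; i++) {
        max = s0[i] > max ? s0[i] : max;
    }

    double sum_exp = 0.0;
    for (int64_t i = 0; i < nc; i++) {
        sum_exp += exp((double) s0[i] - max);
    }
    const double log_sum_exp = log(sum_exp);

    double acc = 0.0;
    for (int64_t i = 0; i < nc; i++) {
        acc += (double) s1[i] * ((double) s0[i] - max - log_sum_exp); // s1 * log_softmax(s0)
    }
    return acc; // the caller would sum this over all rows and multiply by -1/nr
}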

// ggml_compute_forward_cross_entropy_loss_back

static void ggml_compute_forward_cross_entropy_loss_back_f32(
        const ggml_compute_params * params,
        ggml_tensor * dst) {

    const ggml_tensor * grad  = dst->src[0]; // gradient of forward pass output
    const ggml_tensor * src0f = dst->src[1]; // src0 of forward pass
    const ggml_tensor * src1f = dst->src[2]; // src1 of forward pass

    GGML_ASSERT(ggml_is_contiguous(dst));
    GGML_ASSERT(ggml_is_contiguous(src0f));
    GGML_ASSERT(ggml_is_contiguous(src1f));
    GGML_ASSERT(ggml_is_contiguous(grad));
    GGML_ASSERT(ggml_are_same_shape(src0f, src1f) && ggml_are_same_shape(src0f, dst));

    const int64_t ith = params->ith;
    const int64_t nth = params->nth;

    // TODO: handle transposed/permuted matrices
    const int64_t nc = src0f->ne[0];
    const int64_t nr = ggml_nrows(src0f);

    // rows per thread
    const int64_t dr = (nr + nth - 1)/nth;

    // row range for this thread
    const int64_t ir0 = dr*ith;
    const int64_t ir1 = MIN(ir0 + dr, nr);

    const float d_by_nr = ((const float *) grad->data)[0] / (float) nr;

    for (int64_t i1 = ir0; i1 < ir1; i1++) {
        float       * ds0 = (float       *)((char       *) dst->data   + i1*dst->nb[1]);
        const float * s0  = (const float *)((const char *) src0f->data + i1*src0f->nb[1]);
        const float * s1  = (const float *)((const char *) src1f->data + i1*src1f->nb[1]);

#ifndef NDEBUG
        for (int64_t i = 0; i < nc; ++i) {
            //printf("p[%d] = %f\n", i, p[i]);
            assert(!isnan(s0[i]));
            assert(!isnan(s1[i]));
        }
#endif

        // soft_max
        float max = -INFINITY;
        ggml_vec_max_f32(nc, &max, s0);
        const ggml_float sum = ggml_vec_soft_max_f32(nc, ds0, s0, max);
        assert(sum > 0.0);
        ggml_vec_scale_f32(nc, ds0, 1.0/sum);

        // grad(src0f) = (softmax(src0f) - src1f) * grad(cross_entropy_loss(src0f, src1f)) / nr
        ggml_vec_sub_f32(nc, ds0, ds0, s1);
        ggml_vec_scale_f32(nc, ds0, d_by_nr);

#ifndef NDEBUG
        for (int64_t i = 0; i < nc; ++i) {
            assert(!isnan(ds0[i]));
            assert(!isinf(ds0[i]));
        }
#endif
    }
}

void ggml_compute_forward_cross_entropy_loss_back(
        const ggml_compute_params * params,
        ggml_tensor * dst) {

    const ggml_tensor * src0 = dst->src[0];

    switch (src0->type) {
        case GGML_TYPE_F32:
            {
                ggml_compute_forward_cross_entropy_loss_back_f32(params, dst);
            } break;
        default:
            {
                GGML_ABORT("fatal error");
            }
    }
}
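
// NOTE: illustrative sketch only, not referenced by the kernels above.
// For loss = -1/nr * sum_rows sum_c s1[c] * log_softmax(s0)[c], and assuming each
// row of src1f sums to 1, the gradient of a row of src0f is
// (softmax(s0) - s1) * d/nr, where d is the incoming scalar gradient. A scalar
// version of the per-row computation above (the name cross_entropy_row_grad_ref
// is local to this sketch):
static inline void cross_entropy_row_grad_ref(int64_t nc, float d_by_nr,
        const float * s0, const float * s1, float * ds0) {
    float max = -INFINITY;
    for (int64_t i = 0; i < nc; i++) {
        max = s0[i] > max ? s0[i] : max;
    }

    // unnormalized softmax, accumulated into the output buffer
    double sum = 0.0;
    for (int64_t i = 0; i < nc; i++) {
        ds0[i] = expf(s0[i] - max);
        sum += ds0[i];
    }

    for (int64_t i = 0; i < nc; i++) {
        ds0[i] = (float) (ds0[i]/sum - s1[i]) * d_by_nr; // (softmax - target) * d/nr
    }
}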

static void ggml_compute_forward_opt_step_adamw_f32(
        const ggml_compute_params * params,
        ggml_tensor * dst) {

    const ggml_tensor * src0         = dst->src[0];
    const ggml_tensor * src0_grad    = dst->src[1];
    const ggml_tensor * src0_grad_m  = dst->src[2];
    const ggml_tensor * src0_grad_v  = dst->src[3];
    const ggml_tensor * adamw_params = dst->src[4];

    GGML_ASSERT(ggml_are_same_shape(src0, src0_grad));
    GGML_ASSERT(ggml_are_same_shape(src0, src0_grad_m));
    GGML_ASSERT(ggml_are_same_shape(src0, src0_grad_v));
    GGML_ASSERT(ggml_nelements(adamw_params) == 7);

    const int ith = params->ith;
    const int nth = params->nth;

    const int nr = ggml_nrows(src0);

    GGML_TENSOR_UNARY_OP_LOCALS
    GGML_ASSERT(nb00 == sizeof(float));

    // rows per thread
    const int dr = (nr + nth - 1)/nth;

    // row range for this thread
    const int ir0 = dr*ith;
    const int ir1 = MIN(ir0 + dr, nr);

    const float * adamw_params_ptr = ggml_get_data_f32(adamw_params);

    const float alpha  = adamw_params_ptr[0];
    const float beta1  = adamw_params_ptr[1];
    const float beta2  = adamw_params_ptr[2];
    const float eps    = adamw_params_ptr[3];
    const float wd     = adamw_params_ptr[4];
    const float beta1h = adamw_params_ptr[5];
    const float beta2h = adamw_params_ptr[6];

    for (int ir = ir0; ir < ir1; ++ir) {
        const int64_t i03 = ir/(ne02*ne01);
        const int64_t i02 = (ir - i03*ne02*ne01)/ne01;
        const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01);

        const size_t offset = i03*nb03 + i02*nb02 + i01*nb01;

        float       * w = (float       *) ((char       *) src0->data        + offset); // weight
        const float * g = (const float *) ((const char *) src0_grad->data   + offset); // grad
        float       * m = (float       *) ((char       *) src0_grad_m->data + offset);
        float       * v = (float       *) ((char       *) src0_grad_v->data + offset);

        for (int i00 = 0; i00 < ne00; ++i00) {
            m[i00] = m[i00]*beta1 +        g[i00]*(1.0f - beta1);
            v[i00] = v[i00]*beta2 + g[i00]*g[i00]*(1.0f - beta2);

            const float mh =       m[i00]*beta1h;
            const float vh = sqrtf(v[i00]*beta2h) + eps;

            // The weight decay is applied independently of the Adam momenta m and v.
            // This is NOT equivalent to l2 regularization that adds w[i00]*w[i00] to the loss.
            // See: https://arxiv.org/pdf/1711.05101v3.pdf
            w[i00] = w[i00]*(1.0f - alpha*wd) - alpha*mh/vh;
        }
    }
}

void ggml_compute_forward_opt_step_adamw(
        const ggml_compute_params * params,
        ggml_tensor * dst) {

    const ggml_tensor * src0 = dst->src[0];

    switch (src0->type) {
        case GGML_TYPE_F32:
            {
                ggml_compute_forward_opt_step_adamw_f32(params, dst);
            } break;
        default:
            {
                GGML_ABORT("fatal error");
            }
    }
}
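
// NOTE: illustrative sketch only, not referenced by the kernels above.
// Per element, the AdamW update performed above is:
//
//     m = beta1*m + (1 - beta1)*g
//     v = beta2*v + (1 - beta2)*g^2
//     w = w*(1 - alpha*wd) - alpha*(m*beta1h) / (sqrt(v*beta2h) + eps)
//
// beta1h and beta2h arrive precomputed through adamw_params; presumably they
// carry the usual bias-correction factors 1/(1 - beta1^t) and 1/(1 - beta2^t),
// but this kernel treats them as opaque scalars. A scalar version for a single
// parameter (the name adamw_step_ref is local to this sketch):
static inline float adamw_step_ref(float w, float g, float * m, float * v,
        float alpha, float beta1, float beta2, float eps, float wd,
        float beta1h, float beta2h) {
    *m = *m*beta1 +   g*(1.0f - beta1);
    *v = *v*beta2 + g*g*(1.0f - beta2);

    const float mh =       *m*beta1h;
    const float vh = sqrtf(*v*beta2h) + eps;

    // decoupled weight decay, then the Adam step
    return w*(1.0f - alpha*wd) - alpha*mh/vh;
}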