// ops.cpp

#include "ops.h"

#include "ggml-cpu.h"
#include "ggml-impl.h"
#include "binary-ops.h"
#include "ggml.h"
#include "unary-ops.h"
#include "vec.h"

#include <float.h>

// ggml_compute_forward_dup
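//
// The dup ("duplicate") kernels below copy src0 into dst. They handle
// arbitrary strides, convert between element types where the destination
// type differs, and split the work across threads by row ranges (or, for
// the contiguous same-type case, by block ranges).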
static void ggml_compute_forward_dup_same_cont(
        const ggml_compute_params * params,
        ggml_tensor * dst) {

    const ggml_tensor * src0 = dst->src[0];

    GGML_ASSERT(ggml_nelements(dst) == ggml_nelements(src0));
    GGML_ASSERT(ggml_is_contiguous(dst) && ggml_is_contiguous(src0));
    GGML_ASSERT(src0->type == dst->type);

    const size_t nb0 = ggml_type_size(src0->type);

    const int ith = params->ith; // thread index
    const int nth = params->nth; // number of threads

    // parallelize by blocks
    const int nk = ggml_nelements(src0)/ggml_blck_size(src0->type);
    const int dr = (nk + nth - 1) / nth;
    const int k0 = dr * ith;
    const int k1 = MIN(k0 + dr, nk);

    if (k0 < k1) {
        memcpy(
            ((char *) dst->data + k0*nb0),
            ((char *) src0->data + k0*nb0),
            (k1 - k0) * nb0);
    }
}
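// dup for an F16 source: uses row-wise memcpy when the types and strides
// line up, otherwise converts element by element (to F16, F32, or any
// destination type that provides a from_float routine).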
static void ggml_compute_forward_dup_f16(
        const ggml_compute_params * params,
        ggml_tensor * dst) {

    const ggml_tensor * src0 = dst->src[0];

    GGML_ASSERT(ggml_nelements(dst) == ggml_nelements(src0));

    GGML_TENSOR_UNARY_OP_LOCALS

    const int ith = params->ith; // thread index
    const int nth = params->nth; // number of threads

    // parallelize by rows
    const int nr = ne01;
    // number of rows per thread
    const int dr = (nr + nth - 1) / nth;
    // row range for this thread
    const int ir0 = dr * ith;
    const int ir1 = MIN(ir0 + dr, nr);

    if (src0->type == dst->type &&
        ne00 == ne0 &&
        nb00 == ggml_type_size(src0->type) && nb0 == ggml_type_size(dst->type)) {
        // copy by rows
        const size_t rs = ne00*nb00;
        for (int64_t i03 = 0; i03 < ne03; i03++) {
            for (int64_t i02 = 0; i02 < ne02; i02++) {
                for (int64_t i01 = ir0; i01 < ir1; i01++) {
                    memcpy(
                        ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3),
                        ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03),
                        rs);
                }
            }
        }
        return;
    }

    // TODO: add more special-case implementations for tensor shapes/strides that can benefit from memcpy

    if (ggml_is_contiguous(dst)) {
        if (nb00 == sizeof(ggml_fp16_t)) {
            if (dst->type == GGML_TYPE_F16) {
                size_t id = 0;
                const size_t rs = ne00 * nb00;
                char * dst_ptr = (char *) dst->data;

                for (int i03 = 0; i03 < ne03; i03++) {
                    for (int i02 = 0; i02 < ne02; i02++) {
                        id += rs * ir0;
                        for (int i01 = ir0; i01 < ir1; i01++) {
                            const char * src0_ptr = (char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03;
                            memcpy(dst_ptr + id, src0_ptr, rs);
                            id += rs;
                        }
                        id += rs * (ne01 - ir1);
                    }
                }
            } else if (dst->type == GGML_TYPE_F32) {
                size_t id = 0;
                float * dst_ptr = (float *) dst->data;

                for (int i03 = 0; i03 < ne03; i03++) {
                    for (int i02 = 0; i02 < ne02; i02++) {
                        id += ne00 * ir0;
                        for (int i01 = ir0; i01 < ir1; i01++) {
                            const ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
                            for (int i00 = 0; i00 < ne00; i00++) {
                                dst_ptr[id] = GGML_CPU_FP16_TO_FP32(src0_ptr[i00]);
                                id++;
                            }
                        }
                        id += ne00 * (ne01 - ir1);
                    }
                }
            } else if (ggml_get_type_traits_cpu(dst->type)->from_float) {
                ggml_from_float_t const quantize_row_q = ggml_get_type_traits_cpu(dst->type)->from_float;
                float * src0_f32 = (float *) params->wdata + (ne00 + CACHE_LINE_SIZE_F32) * ith;

                size_t id = 0;
                size_t rs = nb0 * (ne00 / ggml_blck_size(dst->type));
                char * dst_ptr = (char *) dst->data;

                for (int i03 = 0; i03 < ne03; i03++) {
                    for (int i02 = 0; i02 < ne02; i02++) {
                        id += rs * ir0;
                        for (int i01 = ir0; i01 < ir1; i01++) {
                            const ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);

                            for (int i00 = 0; i00 < ne00; i00++) {
                                src0_f32[i00] = GGML_CPU_FP16_TO_FP32(src0_ptr[i00]);
                            }

                            quantize_row_q(src0_f32, dst_ptr + id, ne00);
                            id += rs;
                        }
                        id += rs * (ne01 - ir1);
                    }
                }
            } else {
                GGML_ABORT("fatal error"); // TODO: implement
            }
        } else {
            //printf("%s: this is not optimal - fix me\n", __func__);

            if (dst->type == GGML_TYPE_F32) {
                size_t id = 0;
                float * dst_ptr = (float *) dst->data;

                for (int i03 = 0; i03 < ne03; i03++) {
                    for (int i02 = 0; i02 < ne02; i02++) {
                        id += ne00 * ir0;
                        for (int i01 = ir0; i01 < ir1; i01++) {
                            for (int i00 = 0; i00 < ne00; i00++) {
                                const ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);

                                dst_ptr[id] = GGML_CPU_FP16_TO_FP32(*src0_ptr);
                                id++;
                            }
                        }
                        id += ne00 * (ne01 - ir1);
                    }
                }
            } else if (dst->type == GGML_TYPE_F16) {
                size_t id = 0;
                ggml_fp16_t * dst_ptr = (ggml_fp16_t *) dst->data;

                for (int i03 = 0; i03 < ne03; i03++) {
                    for (int i02 = 0; i02 < ne02; i02++) {
                        id += ne00 * ir0;
                        for (int i01 = ir0; i01 < ir1; i01++) {
                            for (int i00 = 0; i00 < ne00; i00++) {
                                const ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);

                                dst_ptr[id] = *src0_ptr;
                                id++;
                            }
                        }
                        id += ne00 * (ne01 - ir1);
                    }
                }
            } else {
                GGML_ABORT("fatal error"); // TODO: implement
            }
        }
        return;
    }

    // dst counters
    int64_t i10 = 0;
    int64_t i11 = 0;
    int64_t i12 = 0;
    int64_t i13 = 0;

    if (dst->type == GGML_TYPE_F16) {
        for (int64_t i03 = 0; i03 < ne03; i03++) {
            for (int64_t i02 = 0; i02 < ne02; i02++) {
                i10 += ne00 * ir0;
                while (i10 >= ne0) {
                    i10 -= ne0;
                    if (++i11 == ne1) {
                        i11 = 0;
                        if (++i12 == ne2) {
                            i12 = 0;
                            if (++i13 == ne3) {
                                i13 = 0;
                            }
                        }
                    }
                }
                for (int64_t i01 = ir0; i01 < ir1; i01++) {
                    for (int64_t i00 = 0; i00 < ne00; i00++) {
                        const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
                        char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);

                        memcpy(dst_ptr, src0_ptr, sizeof(ggml_fp16_t));

                        if (++i10 == ne00) {
                            i10 = 0;
                            if (++i11 == ne01) {
                                i11 = 0;
                                if (++i12 == ne02) {
                                    i12 = 0;
                                    if (++i13 == ne03) {
                                        i13 = 0;
                                    }
                                }
                            }
                        }
                    }
                }
                i10 += ne00 * (ne01 - ir1);
                while (i10 >= ne0) {
                    i10 -= ne0;
                    if (++i11 == ne1) {
                        i11 = 0;
                        if (++i12 == ne2) {
                            i12 = 0;
                            if (++i13 == ne3) {
                                i13 = 0;
                            }
                        }
                    }
                }
            }
        }
    } else if (dst->type == GGML_TYPE_F32) {
        for (int64_t i03 = 0; i03 < ne03; i03++) {
            for (int64_t i02 = 0; i02 < ne02; i02++) {
                i10 += ne00 * ir0;
                while (i10 >= ne0) {
                    i10 -= ne0;
                    if (++i11 == ne1) {
                        i11 = 0;
                        if (++i12 == ne2) {
                            i12 = 0;
                            if (++i13 == ne3) {
                                i13 = 0;
                            }
                        }
                    }
                }
                for (int64_t i01 = ir0; i01 < ir1; i01++) {
                    for (int64_t i00 = 0; i00 < ne00; i00++) {
                        const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
                        char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);

                        *(float *) dst_ptr = GGML_CPU_FP16_TO_FP32(*(const ggml_fp16_t *) src0_ptr);

                        if (++i10 == ne0) {
                            i10 = 0;
                            if (++i11 == ne1) {
                                i11 = 0;
                                if (++i12 == ne2) {
                                    i12 = 0;
                                    if (++i13 == ne3) {
                                        i13 = 0;
                                    }
                                }
                            }
                        }
                    }
                }
                i10 += ne00 * (ne01 - ir1);
                while (i10 >= ne0) {
                    i10 -= ne0;
                    if (++i11 == ne1) {
                        i11 = 0;
                        if (++i12 == ne2) {
                            i12 = 0;
                            if (++i13 == ne3) {
                                i13 = 0;
                            }
                        }
                    }
                }
            }
        }
    } else {
        GGML_ABORT("fatal error"); // TODO: implement
    }
}
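// dup for a BF16 source: same structure as the F16 variant, converting
// through GGML_BF16_TO_FP32 when the destination is F16, F32, or a
// quantized type with a from_float routine.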
static void ggml_compute_forward_dup_bf16(
        const ggml_compute_params * params,
        ggml_tensor * dst) {

    const ggml_tensor * src0 = dst->src[0];

    GGML_ASSERT(ggml_nelements(dst) == ggml_nelements(src0));

    GGML_TENSOR_UNARY_OP_LOCALS

    const int ith = params->ith; // thread index
    const int nth = params->nth; // number of threads

    // parallelize by rows
    const int nr = ne01;
    // number of rows per thread
    const int dr = (nr + nth - 1) / nth;
    // row range for this thread
    const int ir0 = dr * ith;
    const int ir1 = MIN(ir0 + dr, nr);

    if (src0->type == dst->type &&
        ne00 == ne0 &&
        nb00 == ggml_type_size(src0->type) && nb0 == ggml_type_size(dst->type)) {
        // copy by rows
        const size_t rs = ne00*nb00;
        for (int64_t i03 = 0; i03 < ne03; i03++) {
            for (int64_t i02 = 0; i02 < ne02; i02++) {
                for (int64_t i01 = ir0; i01 < ir1; i01++) {
                    memcpy(
                        ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3),
                        ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03),
                        rs);
                }
            }
        }
        return;
    }

    // TODO: add more special-case implementations for tensor shapes/strides that can benefit from memcpy

    if (ggml_is_contiguous(dst)) {
        if (nb00 == sizeof(ggml_bf16_t)) {
            if (dst->type == GGML_TYPE_BF16) {
                size_t id = 0;
                const size_t rs = ne00 * nb00;
                char * dst_ptr = (char *) dst->data;

                for (int i03 = 0; i03 < ne03; i03++) {
                    for (int i02 = 0; i02 < ne02; i02++) {
                        id += rs * ir0;
                        for (int i01 = ir0; i01 < ir1; i01++) {
                            const char * src0_ptr = (char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03;
                            memcpy(dst_ptr + id, src0_ptr, rs);
                            id += rs;
                        }
                        id += rs * (ne01 - ir1);
                    }
                }
            } else if (dst->type == GGML_TYPE_F16) {
                size_t id = 0;
                ggml_fp16_t * dst_ptr = (ggml_fp16_t *) dst->data;

                for (int i03 = 0; i03 < ne03; i03++) {
                    for (int i02 = 0; i02 < ne02; i02++) {
                        id += ne00 * ir0;
                        for (int i01 = ir0; i01 < ir1; i01++) {
                            const ggml_bf16_t * src0_ptr = (ggml_bf16_t *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
                            for (int i00 = 0; i00 < ne00; i00++) {
                                dst_ptr[id] = GGML_CPU_FP32_TO_FP16(GGML_BF16_TO_FP32(src0_ptr[i00]));
                                id++;
                            }
                        }
                        id += ne00 * (ne01 - ir1);
                    }
                }
            } else if (dst->type == GGML_TYPE_F32) {
                size_t id = 0;
                float * dst_ptr = (float *) dst->data;

                for (int i03 = 0; i03 < ne03; i03++) {
                    for (int i02 = 0; i02 < ne02; i02++) {
                        id += ne00 * ir0;
                        for (int i01 = ir0; i01 < ir1; i01++) {
                            const ggml_bf16_t * src0_ptr = (ggml_bf16_t *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
                            for (int i00 = 0; i00 < ne00; i00++) {
                                dst_ptr[id] = GGML_BF16_TO_FP32(src0_ptr[i00]);
                                id++;
                            }
                        }
                        id += ne00 * (ne01 - ir1);
                    }
                }
            } else if (ggml_get_type_traits_cpu(dst->type)->from_float) {
                ggml_from_float_t const quantize_row_q = ggml_get_type_traits_cpu(dst->type)->from_float;
                float * src0_f32 = (float *) params->wdata + (ne00 + CACHE_LINE_SIZE_F32) * ith;

                size_t id = 0;
                size_t rs = nb0 * (ne00 / ggml_blck_size(dst->type));
                char * dst_ptr = (char *) dst->data;

                for (int i03 = 0; i03 < ne03; i03++) {
                    for (int i02 = 0; i02 < ne02; i02++) {
                        id += rs * ir0;
                        for (int i01 = ir0; i01 < ir1; i01++) {
                            const ggml_bf16_t * src0_ptr = (ggml_bf16_t *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);

                            for (int i00 = 0; i00 < ne00; i00++) {
                                src0_f32[i00] = GGML_BF16_TO_FP32(src0_ptr[i00]);
                            }

                            quantize_row_q(src0_f32, dst_ptr + id, ne00);
                            id += rs;
                        }
                        id += rs * (ne01 - ir1);
                    }
                }
            } else {
                GGML_ABORT("fatal error"); // TODO: implement
            }
        } else {
            //printf("%s: this is not optimal - fix me\n", __func__);

            if (dst->type == GGML_TYPE_F32) {
                size_t id = 0;
                float * dst_ptr = (float *) dst->data;

                for (int i03 = 0; i03 < ne03; i03++) {
                    for (int i02 = 0; i02 < ne02; i02++) {
                        id += ne00 * ir0;
                        for (int i01 = ir0; i01 < ir1; i01++) {
                            for (int i00 = 0; i00 < ne00; i00++) {
                                const ggml_bf16_t * src0_ptr = (ggml_bf16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);

                                dst_ptr[id] = GGML_BF16_TO_FP32(*src0_ptr);
                                id++;
                            }
                        }
                        id += ne00 * (ne01 - ir1);
                    }
                }
            } else if (dst->type == GGML_TYPE_BF16) {
                size_t id = 0;
                ggml_bf16_t * dst_ptr = (ggml_bf16_t *) dst->data;

                for (int i03 = 0; i03 < ne03; i03++) {
                    for (int i02 = 0; i02 < ne02; i02++) {
                        id += ne00 * ir0;
                        for (int i01 = ir0; i01 < ir1; i01++) {
                            for (int i00 = 0; i00 < ne00; i00++) {
                                const ggml_bf16_t * src0_ptr = (ggml_bf16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);

                                dst_ptr[id] = *src0_ptr;
                                id++;
                            }
                        }
                        id += ne00 * (ne01 - ir1);
                    }
                }
            } else if (dst->type == GGML_TYPE_F16) {
                size_t id = 0;
                ggml_fp16_t * dst_ptr = (ggml_fp16_t *) dst->data;

                for (int i03 = 0; i03 < ne03; i03++) {
                    for (int i02 = 0; i02 < ne02; i02++) {
                        id += ne00 * ir0;
                        for (int i01 = ir0; i01 < ir1; i01++) {
                            for (int i00 = 0; i00 < ne00; i00++) {
                                const ggml_bf16_t * src0_ptr = (ggml_bf16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);

                                dst_ptr[id] = GGML_CPU_FP32_TO_FP16(GGML_BF16_TO_FP32(*src0_ptr));
                                id++;
                            }
                        }
                        id += ne00 * (ne01 - ir1);
                    }
                }
            } else {
                GGML_ABORT("fatal error"); // TODO: implement
            }
        }
        return;
    }

    // dst counters
    int64_t i10 = 0;
    int64_t i11 = 0;
    int64_t i12 = 0;
    int64_t i13 = 0;

    if (dst->type == GGML_TYPE_BF16) {
        for (int64_t i03 = 0; i03 < ne03; i03++) {
            for (int64_t i02 = 0; i02 < ne02; i02++) {
                i10 += ne00 * ir0;
                while (i10 >= ne0) {
                    i10 -= ne0;
                    if (++i11 == ne1) {
                        i11 = 0;
                        if (++i12 == ne2) {
                            i12 = 0;
                            if (++i13 == ne3) {
                                i13 = 0;
                            }
                        }
                    }
                }
                for (int64_t i01 = ir0; i01 < ir1; i01++) {
                    for (int64_t i00 = 0; i00 < ne00; i00++) {
                        const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
                        char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);

                        memcpy(dst_ptr, src0_ptr, sizeof(ggml_bf16_t));

                        if (++i10 == ne00) {
                            i10 = 0;
                            if (++i11 == ne01) {
                                i11 = 0;
                                if (++i12 == ne02) {
                                    i12 = 0;
                                    if (++i13 == ne03) {
                                        i13 = 0;
                                    }
                                }
                            }
                        }
                    }
                }
                i10 += ne00 * (ne01 - ir1);
                while (i10 >= ne0) {
                    i10 -= ne0;
                    if (++i11 == ne1) {
                        i11 = 0;
                        if (++i12 == ne2) {
                            i12 = 0;
                            if (++i13 == ne3) {
                                i13 = 0;
                            }
                        }
                    }
                }
            }
        }
    } else if (dst->type == GGML_TYPE_F16) {
        for (int64_t i03 = 0; i03 < ne03; i03++) {
            for (int64_t i02 = 0; i02 < ne02; i02++) {
                i10 += ne00 * ir0;
                while (i10 >= ne0) {
                    i10 -= ne0;
                    if (++i11 == ne1) {
                        i11 = 0;
                        if (++i12 == ne2) {
                            i12 = 0;
                            if (++i13 == ne3) {
                                i13 = 0;
                            }
                        }
                    }
                }
                for (int64_t i01 = ir0; i01 < ir1; i01++) {
                    for (int64_t i00 = 0; i00 < ne00; i00++) {
                        const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
                        char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);

                        *(ggml_fp16_t *) dst_ptr = GGML_CPU_FP32_TO_FP16(GGML_BF16_TO_FP32(*(const ggml_bf16_t *) src0_ptr));

                        if (++i10 == ne0) {
                            i10 = 0;
                            if (++i11 == ne1) {
                                i11 = 0;
                                if (++i12 == ne2) {
                                    i12 = 0;
                                    if (++i13 == ne3) {
                                        i13 = 0;
                                    }
                                }
                            }
                        }
                    }
                }
                i10 += ne00 * (ne01 - ir1);
                while (i10 >= ne0) {
                    i10 -= ne0;
                    if (++i11 == ne1) {
                        i11 = 0;
                        if (++i12 == ne2) {
                            i12 = 0;
                            if (++i13 == ne3) {
                                i13 = 0;
                            }
                        }
                    }
                }
            }
        }
    } else if (dst->type == GGML_TYPE_F32) {
        for (int64_t i03 = 0; i03 < ne03; i03++) {
            for (int64_t i02 = 0; i02 < ne02; i02++) {
                i10 += ne00 * ir0;
                while (i10 >= ne0) {
                    i10 -= ne0;
                    if (++i11 == ne1) {
                        i11 = 0;
                        if (++i12 == ne2) {
                            i12 = 0;
                            if (++i13 == ne3) {
                                i13 = 0;
                            }
                        }
                    }
                }
                for (int64_t i01 = ir0; i01 < ir1; i01++) {
                    for (int64_t i00 = 0; i00 < ne00; i00++) {
                        const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
                        char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);

                        *(float *) dst_ptr = GGML_BF16_TO_FP32(*(const ggml_bf16_t *) src0_ptr);

                        if (++i10 == ne0) {
                            i10 = 0;
                            if (++i11 == ne1) {
                                i11 = 0;
                                if (++i12 == ne2) {
                                    i12 = 0;
                                    if (++i13 == ne3) {
                                        i13 = 0;
                                    }
                                }
                            }
                        }
                    }
                }
                i10 += ne00 * (ne01 - ir1);
                while (i10 >= ne0) {
                    i10 -= ne0;
                    if (++i11 == ne1) {
                        i11 = 0;
                        if (++i12 == ne2) {
                            i12 = 0;
                            if (++i13 == ne3) {
                                i13 = 0;
                            }
                        }
                    }
                }
            }
        }
    } else {
        GGML_ABORT("fatal error"); // TODO: implement
    }
}
  590. static void ggml_compute_forward_dup_f32(
  591. const ggml_compute_params * params,
  592. ggml_tensor * dst) {
  593. const ggml_tensor * src0 = dst->src[0];
  594. GGML_ASSERT(ggml_nelements(dst) == ggml_nelements(src0));
  595. GGML_TENSOR_UNARY_OP_LOCALS
  596. const int ith = params->ith; // thread index
  597. const int nth = params->nth; // number of threads
  598. // parallelize by rows
  599. const int nr = ne01;
  600. // number of rows per thread
  601. const int dr = (nr + nth - 1) / nth;
  602. // row range for this thread
  603. const int ir0 = dr * ith;
  604. const int ir1 = MIN(ir0 + dr, nr);
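// each thread processes a contiguous block of dr src0 rows: thread ith gets rows [ir0, ir1)
// with ir0 = dr*ith; e.g. nr = 10 rows and nth = 4 threads gives dr = 3, so thread 0 handles
// rows [0, 3) and thread 3 handles rows [9, 10).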
  605. if (src0->type == dst->type &&
  606. ne00 == ne0 &&
  607. nb00 == ggml_type_size(src0->type) && nb0 == ggml_type_size(dst->type)) {
  608. // copy by rows
  609. const size_t rs = ne00*nb00;
  610. for (int64_t i03 = 0; i03 < ne03; i03++) {
  611. for (int64_t i02 = 0; i02 < ne02; i02++) {
  612. for (int64_t i01 = ir0; i01 < ir1; i01++) {
  613. memcpy(
  614. ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3),
  615. ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03),
  616. rs);
  617. }
  618. }
  619. }
  620. return;
  621. }
  622. if (ggml_is_contiguous(dst)) {
  623. // TODO: simplify
  624. if (nb00 == sizeof(float)) {
  625. if (ggml_get_type_traits_cpu(dst->type)->from_float) {
  626. ggml_from_float_t const from_float = ggml_get_type_traits_cpu(dst->type)->from_float;
  627. size_t id = 0;
  628. size_t rs = nb0 * (ne00 / ggml_blck_size(dst->type));
  629. char * dst_ptr = (char *) dst->data;
  630. for (int i03 = 0; i03 < ne03; i03++) {
  631. for (int i02 = 0; i02 < ne02; i02++) {
  632. id += rs * ir0;
  633. for (int i01 = ir0; i01 < ir1; i01++) {
  634. const float * src0_ptr = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
  635. from_float(src0_ptr, dst_ptr + id, ne00);
  636. id += rs;
  637. }
  638. id += rs * (ne01 - ir1);
  639. }
  640. }
  641. } else {
  642. GGML_ABORT("fatal error"); // TODO: implement
  643. }
  644. } else {
  645. //printf("%s: this is not optimal - fix me\n", __func__);
  646. if (dst->type == GGML_TYPE_F32) {
  647. size_t id = 0;
  648. float * dst_ptr = (float *) dst->data;
  649. for (int i03 = 0; i03 < ne03; i03++) {
  650. for (int i02 = 0; i02 < ne02; i02++) {
  651. id += ne00 * ir0;
  652. for (int i01 = ir0; i01 < ir1; i01++) {
  653. for (int i00 = 0; i00 < ne00; i00++) {
  654. const float * src0_ptr = (float *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
  655. dst_ptr[id] = *src0_ptr;
  656. id++;
  657. }
  658. }
  659. id += ne00 * (ne01 - ir1);
  660. }
  661. }
  662. } else if (dst->type == GGML_TYPE_F16) {
  663. size_t id = 0;
  664. ggml_fp16_t * dst_ptr = (ggml_fp16_t *) dst->data;
  665. for (int i03 = 0; i03 < ne03; i03++) {
  666. for (int i02 = 0; i02 < ne02; i02++) {
  667. id += ne00 * ir0;
  668. for (int i01 = ir0; i01 < ir1; i01++) {
  669. for (int i00 = 0; i00 < ne00; i00++) {
  670. const float * src0_ptr = (float *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
  671. dst_ptr[id] = GGML_CPU_FP32_TO_FP16(*src0_ptr);
  672. id++;
  673. }
  674. }
  675. id += ne00 * (ne01 - ir1);
  676. }
  677. }
  678. } else if (dst->type == GGML_TYPE_BF16) {
  679. size_t id = 0;
  680. ggml_bf16_t * dst_ptr = (ggml_bf16_t *) dst->data;
  681. for (int i03 = 0; i03 < ne03; i03++) {
  682. for (int i02 = 0; i02 < ne02; i02++) {
  683. id += ne00 * ir0;
  684. for (int i01 = ir0; i01 < ir1; i01++) {
  685. for (int i00 = 0; i00 < ne00; i00++) {
  686. const float * src0_ptr = (float *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
  687. dst_ptr[id] = GGML_FP32_TO_BF16(*src0_ptr);
  688. id++;
  689. }
  690. }
  691. id += ne00 * (ne01 - ir1);
  692. }
  693. }
  694. } else {
  695. GGML_ABORT("fatal error"); // TODO: implement
  696. }
  697. }
  698. return;
  699. }
  700. // dst counters
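// same element-wise carry scheme as in the bf16 path: i10..i13 are the dst write position,
// wrapped at ne0..ne3, and advanced without writing for rows outside [ir0, ir1).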
  701. int64_t i10 = 0;
  702. int64_t i11 = 0;
  703. int64_t i12 = 0;
  704. int64_t i13 = 0;
  705. if (dst->type == GGML_TYPE_F32) {
  706. for (int64_t i03 = 0; i03 < ne03; i03++) {
  707. for (int64_t i02 = 0; i02 < ne02; i02++) {
  708. i10 += ne00 * ir0;
  709. while (i10 >= ne0) {
  710. i10 -= ne0;
  711. if (++i11 == ne1) {
  712. i11 = 0;
  713. if (++i12 == ne2) {
  714. i12 = 0;
  715. if (++i13 == ne3) {
  716. i13 = 0;
  717. }
  718. }
  719. }
  720. }
  721. for (int64_t i01 = ir0; i01 < ir1; i01++) {
  722. for (int64_t i00 = 0; i00 < ne00; i00++) {
  723. const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
  724. char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);
  725. memcpy(dst_ptr, src0_ptr, sizeof(float));
  726. if (++i10 == ne0) {
  727. i10 = 0;
  728. if (++i11 == ne1) {
  729. i11 = 0;
  730. if (++i12 == ne2) {
  731. i12 = 0;
  732. if (++i13 == ne3) {
  733. i13 = 0;
  734. }
  735. }
  736. }
  737. }
  738. }
  739. }
  740. i10 += ne00 * (ne01 - ir1);
  741. while (i10 >= ne0) {
  742. i10 -= ne0;
  743. if (++i11 == ne1) {
  744. i11 = 0;
  745. if (++i12 == ne2) {
  746. i12 = 0;
  747. if (++i13 == ne3) {
  748. i13 = 0;
  749. }
  750. }
  751. }
  752. }
  753. }
  754. }
  755. } else if (dst->type == GGML_TYPE_F16) {
  756. for (int64_t i03 = 0; i03 < ne03; i03++) {
  757. for (int64_t i02 = 0; i02 < ne02; i02++) {
  758. i10 += ne00 * ir0;
  759. while (i10 >= ne0) {
  760. i10 -= ne0;
  761. if (++i11 == ne1) {
  762. i11 = 0;
  763. if (++i12 == ne2) {
  764. i12 = 0;
  765. if (++i13 == ne3) {
  766. i13 = 0;
  767. }
  768. }
  769. }
  770. }
  771. for (int64_t i01 = ir0; i01 < ir1; i01++) {
  772. for (int64_t i00 = 0; i00 < ne00; i00++) {
  773. const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
  774. char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);
  775. *(ggml_fp16_t *) dst_ptr = GGML_CPU_FP32_TO_FP16(*(const float *) src0_ptr);
  776. if (++i10 == ne0) {
  777. i10 = 0;
  778. if (++i11 == ne1) {
  779. i11 = 0;
  780. if (++i12 == ne2) {
  781. i12 = 0;
  782. if (++i13 == ne3) {
  783. i13 = 0;
  784. }
  785. }
  786. }
  787. }
  788. }
  789. }
  790. i10 += ne00 * (ne01 - ir1);
  791. while (i10 >= ne0) {
  792. i10 -= ne0;
  793. if (++i11 == ne1) {
  794. i11 = 0;
  795. if (++i12 == ne2) {
  796. i12 = 0;
  797. if (++i13 == ne3) {
  798. i13 = 0;
  799. }
  800. }
  801. }
  802. }
  803. }
  804. }
  805. } else if (dst->type == GGML_TYPE_BF16) {
  806. for (int64_t i03 = 0; i03 < ne03; i03++) {
  807. for (int64_t i02 = 0; i02 < ne02; i02++) {
  808. i10 += ne00 * ir0;
  809. while (i10 >= ne0) {
  810. i10 -= ne0;
  811. if (++i11 == ne1) {
  812. i11 = 0;
  813. if (++i12 == ne2) {
  814. i12 = 0;
  815. if (++i13 == ne3) {
  816. i13 = 0;
  817. }
  818. }
  819. }
  820. }
  821. for (int64_t i01 = ir0; i01 < ir1; i01++) {
  822. for (int64_t i00 = 0; i00 < ne00; i00++) {
  823. const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
  824. char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);
  825. *(ggml_bf16_t *) dst_ptr = GGML_FP32_TO_BF16(*(const float *) src0_ptr);
  826. if (++i10 == ne0) {
  827. i10 = 0;
  828. if (++i11 == ne1) {
  829. i11 = 0;
  830. if (++i12 == ne2) {
  831. i12 = 0;
  832. if (++i13 == ne3) {
  833. i13 = 0;
  834. }
  835. }
  836. }
  837. }
  838. }
  839. }
  840. i10 += ne00 * (ne01 - ir1);
  841. while (i10 >= ne0) {
  842. i10 -= ne0;
  843. if (++i11 == ne1) {
  844. i11 = 0;
  845. if (++i12 == ne2) {
  846. i12 = 0;
  847. if (++i13 == ne3) {
  848. i13 = 0;
  849. }
  850. }
  851. }
  852. }
  853. }
  854. }
  855. } else {
  856. GGML_ABORT("fatal error"); // TODO: implement
  857. }
  858. }
859. // A simplified version of ggml_compute_forward_dup that doesn't do any float upcasting and just uses plain old memcpy.
  860. static void ggml_compute_forward_dup_bytes(
  861. const ggml_compute_params * params,
  862. ggml_tensor * dst) {
  863. const ggml_tensor * src0 = dst->src[0];
  864. GGML_ASSERT(ggml_nelements(dst) == ggml_nelements(src0));
  865. GGML_ASSERT(src0->type == dst->type);
  866. GGML_TENSOR_UNARY_OP_LOCALS;
  867. if (ggml_is_contiguous(src0) && ggml_is_contiguous(dst)) {
  868. ggml_compute_forward_dup_same_cont(params, dst);
  869. return;
  870. }
  871. const size_t type_size = ggml_type_size(src0->type);
  872. const int ith = params->ith; // thread index
  873. const int nth = params->nth; // number of threads
  874. // parallelize by rows
  875. const int nr = ne01;
  876. // number of rows per thread
  877. const int dr = (nr + nth - 1) / nth;
  878. // row range for this thread
  879. const int ir0 = dr * ith;
  880. const int ir1 = MIN(ir0 + dr, nr);
  881. if (src0->type == dst->type &&
  882. ggml_are_same_shape(src0, dst) &&
  883. nb00 == type_size && nb0 == type_size) {
  884. // copy by rows
  885. const size_t rs = ggml_row_size(src0->type, ne00);
  886. for (int64_t i03 = 0; i03 < ne03; i03++) {
  887. for (int64_t i02 = 0; i02 < ne02; i02++) {
  888. for (int64_t i01 = ir0; i01 < ir1; i01++) {
  889. memcpy(
  890. ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3),
  891. ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03),
  892. rs);
  893. }
  894. }
  895. }
  896. return;
  897. }
  898. if (ggml_is_contiguous(dst)) {
  899. size_t id = 0;
  900. char * dst_ptr = (char *) dst->data;
  901. const size_t rs = ne00 * type_size;
  902. if (nb00 == type_size) {
903. // src0 is contiguous in the first dimension, copy by rows
  904. for (int64_t i03 = 0; i03 < ne03; i03++) {
  905. for (int64_t i02 = 0; i02 < ne02; i02++) {
  906. id += rs * ir0;
  907. for (int64_t i01 = ir0; i01 < ir1; i01++) {
  908. const char * src0_ptr = (char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03;
  909. memcpy(dst_ptr + id, src0_ptr, rs);
  910. id += rs;
  911. }
  912. id += rs * (ne01 - ir1);
  913. }
  914. }
  915. } else {
  916. //printf("%s: this is not optimal - fix me\n", __func__);
  917. for (int64_t i03 = 0; i03 < ne03; i03++) {
  918. for (int64_t i02 = 0; i02 < ne02; i02++) {
  919. id += rs * ir0;
  920. for (int64_t i01 = ir0; i01 < ir1; i01++) {
  921. for (int64_t i00 = 0; i00 < ne00; i00++) {
  922. const char * src0_ptr = (char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03;
  923. memcpy(dst_ptr + id, src0_ptr, type_size);
  924. id += type_size;
  925. }
  926. }
  927. id += rs * (ne01 - ir1);
  928. }
  929. }
  930. }
  931. return;
  932. }
  933. // dst counters
  934. int64_t k10 = 0;
  935. int64_t i11 = 0;
  936. int64_t i12 = 0;
  937. int64_t i13 = 0;
  938. // number of blocks in a row
  939. const int64_t nk00 = ne00 / ggml_blck_size(src0->type);
  940. const int64_t nk0 = ne0 / ggml_blck_size(dst->type);
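// k10 and k00 count blocks rather than elements: for quantized types one step of type_size
// bytes covers ggml_blck_size() values, so the dst counter wraps at the number of blocks per
// dst row (nk0) instead of ne0.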
  941. for (int64_t i03 = 0; i03 < ne03; i03++) {
  942. for (int64_t i02 = 0; i02 < ne02; i02++) {
  943. k10 += nk00 * ir0;
  944. while (k10 >= nk0) {
  945. k10 -= nk0;
  946. if (++i11 == ne1) {
  947. i11 = 0;
  948. if (++i12 == ne2) {
  949. i12 = 0;
  950. if (++i13 == ne3) {
  951. i13 = 0;
  952. }
  953. }
  954. }
  955. }
  956. for (int64_t i01 = ir0; i01 < ir1; i01++) {
  957. for (int64_t k00 = 0; k00 < nk00; k00++) {
  958. const char * src0_ptr = ((char *) src0->data + k00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
  959. char * dst_ptr = ((char *) dst->data + k10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);
  960. memcpy(dst_ptr, src0_ptr, type_size);
  961. if (++k10 == nk0) {
  962. k10 = 0;
  963. if (++i11 == ne1) {
  964. i11 = 0;
  965. if (++i12 == ne2) {
  966. i12 = 0;
  967. if (++i13 == ne3) {
  968. i13 = 0;
  969. }
  970. }
  971. }
  972. }
  973. }
  974. }
  975. k10 += nk00 * (ne01 - ir1);
  976. while (k10 >= nk0) {
  977. k10 -= nk0;
  978. if (++i11 == ne1) {
  979. i11 = 0;
  980. if (++i12 == ne2) {
  981. i12 = 0;
  982. if (++i13 == ne3) {
  983. i13 = 0;
  984. }
  985. }
  986. }
  987. }
  988. }
  989. }
  990. }
  991. static void ggml_compute_forward_dup_q(
  992. const ggml_compute_params * params,
  993. ggml_tensor * dst) {
  994. const ggml_tensor * src0 = dst->src[0];
  995. const ggml_tensor * src1 = dst->src[1];
  996. GGML_TENSOR_BINARY_OP_LOCALS
  997. const ggml_type type = src0->type;
  998. ggml_to_float_t const dequantize_row_q = ggml_get_type_traits(type)->to_float;
  999. size_t qk = ggml_blck_size(type);
  1000. const int64_t nr = ggml_nelements(src1) / qk;
  1001. // destination must be contiguous in the first dimension
  1002. GGML_ASSERT(nb10 == ggml_type_size(dst->type));
1003. // the destination's first dimension must hold a whole number of quantized blocks, or dst must be fully contiguous
  1004. GGML_ASSERT((ne10 % qk) == 0 || ggml_is_contiguous(dst));
  1005. const int ith = params->ith;
  1006. const int nth = params->nth;
  1007. const int dr = (nr + nth - 1)/nth;
  1008. // row range for this thread
  1009. const int ir0 = dr*ith;
  1010. const int ir1 = MIN(ir0 + dr, nr);
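// each ir is one quantized block of src0; i is its linear element index, which is decomposed
// into src0 coordinates (i00..i03) and dst coordinates (i10..i13), and the block is dequantized
// into qk consecutive floats of dst. e.g. with ne00 = 64 and qk = 32, block ir = 3 starts at
// i = 96, i.e. i01 = 1, i00 = 32 (for ne01 > 1).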
  1011. for (int64_t ir = ir0; ir < ir1; ++ir) {
1012. const int64_t i = ir * qk;
  1013. const int64_t i03 = i/(ne00 * ne01 * ne02);
  1014. const int64_t i02 = (i - i03*ne00*ne01*ne02 )/ (ne00*ne01);
  1015. const int64_t i01 = (i - i03*ne00*ne01*ne02 - i02*ne01*ne00) / ne00;
  1016. const int64_t i00 = i - i03*ne00*ne01*ne02 - i02*ne01*ne00 - i01*ne00;
  1017. const int64_t x_offset = (i00/qk)*nb00 + i01*nb01 + i02*nb02 + i03 * nb03;
  1018. const int64_t i13 = i/(ne10 * ne11 * ne12);
  1019. const int64_t i12 = (i - i13*ne10*ne11*ne12) / (ne10*ne11);
  1020. const int64_t i11 = (i - i13*ne10*ne11*ne12 - i12*ne10*ne11) / ne10;
  1021. const int64_t i10 = i - i13*ne10*ne11*ne12 - i12*ne10*ne11 - i11*ne10;
  1022. const int64_t dst_offset = i10*nb10 + i11*nb11 + i12*nb12 + i13*nb13;
  1023. dequantize_row_q(
  1024. (const void *) ((char *) src0->data + x_offset),
  1025. (float *) ((char *) dst->data + dst_offset), qk);
  1026. }
  1027. }
  1028. void ggml_compute_forward_dup(
  1029. const ggml_compute_params * params,
  1030. ggml_tensor * dst) {
  1031. const ggml_tensor * src0 = dst->src[0];
  1032. if (src0->type == dst->type) {
  1033. ggml_compute_forward_dup_bytes(params, dst);
  1034. return;
  1035. }
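// mixed-type copies are routed by the source type below; quantized sources are only supported
// when dst is F32, via the dequantizing path.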
  1036. switch (src0->type) {
  1037. case GGML_TYPE_F16:
  1038. {
  1039. ggml_compute_forward_dup_f16(params, dst);
  1040. } break;
  1041. case GGML_TYPE_BF16:
  1042. {
  1043. ggml_compute_forward_dup_bf16(params, dst);
  1044. } break;
  1045. case GGML_TYPE_F32:
  1046. {
  1047. ggml_compute_forward_dup_f32(params, dst);
  1048. } break;
  1049. default:
  1050. {
  1051. if (ggml_is_quantized(src0->type) && dst->type == GGML_TYPE_F32) {
  1052. ggml_compute_forward_dup_q(params, dst);
  1053. break;
  1054. }
  1055. GGML_ABORT("fatal error");
  1056. }
  1057. }
  1058. }
  1059. // ggml_compute_forward_add
  1060. static void ggml_compute_forward_add_q_f32(
  1061. const ggml_compute_params * params,
  1062. ggml_tensor * dst) {
  1063. const ggml_tensor * src0 = dst->src[0];
  1064. const ggml_tensor * src1 = dst->src[1];
  1065. GGML_ASSERT(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst));
  1066. const int nr = ggml_nrows(src0);
  1067. GGML_TENSOR_BINARY_OP_LOCALS
  1068. const int ith = params->ith;
  1069. const int nth = params->nth;
  1070. const ggml_type type = src0->type;
  1071. const ggml_type dtype = dst->type;
  1072. ggml_to_float_t const dequantize_row_q = ggml_get_type_traits(type)->to_float;
  1073. ggml_from_float_t const quantize_row_q = ggml_get_type_traits_cpu(dtype)->from_float;
  1074. // we don't support permuted src0 or src1
  1075. GGML_ASSERT(nb00 == ggml_type_size(type));
  1076. GGML_ASSERT(nb10 == sizeof(float));
  1077. // dst cannot be transposed or permuted
  1078. GGML_ASSERT(nb0 <= nb1);
  1079. GGML_ASSERT(nb1 <= nb2);
  1080. GGML_ASSERT(nb2 <= nb3);
  1081. GGML_ASSERT(ggml_is_quantized(src0->type));
  1082. GGML_ASSERT(src1->type == GGML_TYPE_F32);
  1083. // rows per thread
  1084. const int dr = (nr + nth - 1)/nth;
  1085. // row range for this thread
  1086. const int ir0 = dr*ith;
  1087. const int ir1 = MIN(ir0 + dr, nr);
  1088. float * wdata = (float *) params->wdata + (ne00 + CACHE_LINE_SIZE_F32) * ith;
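// wdata is a per-thread scratch row of ne00 floats; the CACHE_LINE_SIZE_F32 padding keeps each
// thread's scratch on its own cache lines. each src0 row is dequantized into it, the src1 row
// is accumulated, and the result is re-quantized (or copied) into dst.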
  1089. for (int ir = ir0; ir < ir1; ++ir) {
  1090. // src0 indices
  1091. const int i03 = ir/(ne02*ne01);
  1092. const int i02 = (ir - i03*ne02*ne01)/ne01;
  1093. const int i01 = (ir - i03*ne02*ne01 - i02*ne01);
  1094. // src1 and dst are same shape as src0 => same indices
  1095. const int i13 = i03;
  1096. const int i12 = i02;
  1097. const int i11 = i01;
  1098. const int i3 = i03;
  1099. const int i2 = i02;
  1100. const int i1 = i01;
  1101. void * src0_row = (void *) ((char *) src0->data + (i01*nb01 + i02*nb02 + i03*nb03));
  1102. float * src1_row = (float *)((char *) src1->data + (i11*nb11 + i12*nb12 + i13*nb13));
  1103. void * dst_row = (void *) ((char *) dst->data + ( i1*nb1 + i2*nb2 + i3*nb3));
  1104. assert(ne00 % 32 == 0);
1105. // dequantize row from src0 to temp buffer
  1106. dequantize_row_q(src0_row, wdata, ne00);
  1107. // add src1
  1108. ggml_vec_acc_f32(ne00, wdata, src1_row);
  1109. // quantize row to dst
  1110. if (quantize_row_q != NULL) {
  1111. quantize_row_q(wdata, dst_row, ne00);
  1112. } else {
  1113. memcpy(dst_row, wdata, ne0*nb0);
  1114. }
  1115. }
  1116. }
  1117. void ggml_compute_forward_add(
  1118. const ggml_compute_params * params,
  1119. ggml_tensor * dst) {
  1120. const ggml_tensor * src0 = dst->src[0];
  1121. switch (src0->type) {
  1122. case GGML_TYPE_F32:
  1123. case GGML_TYPE_F16:
  1124. case GGML_TYPE_BF16:
  1125. {
  1126. ggml_compute_forward_add_non_quantized(params, dst);
  1127. } break;
  1128. case GGML_TYPE_Q4_0:
  1129. case GGML_TYPE_Q4_1:
  1130. case GGML_TYPE_Q5_0:
  1131. case GGML_TYPE_Q5_1:
  1132. case GGML_TYPE_Q8_0:
  1133. case GGML_TYPE_Q2_K:
  1134. case GGML_TYPE_Q3_K:
  1135. case GGML_TYPE_Q4_K:
  1136. case GGML_TYPE_Q5_K:
  1137. case GGML_TYPE_Q6_K:
  1138. case GGML_TYPE_TQ1_0:
  1139. case GGML_TYPE_TQ2_0:
  1140. case GGML_TYPE_IQ2_XXS:
  1141. case GGML_TYPE_IQ2_XS:
  1142. case GGML_TYPE_IQ3_XXS:
  1143. case GGML_TYPE_IQ1_S:
  1144. case GGML_TYPE_IQ1_M:
  1145. case GGML_TYPE_IQ4_NL:
  1146. case GGML_TYPE_IQ4_XS:
  1147. case GGML_TYPE_IQ3_S:
  1148. case GGML_TYPE_IQ2_S:
  1149. {
  1150. ggml_compute_forward_add_q_f32(params, dst);
  1151. } break;
  1152. default:
  1153. {
  1154. GGML_ABORT("fatal error");
  1155. }
  1156. }
  1157. }
  1158. // ggml_compute_forward_add1
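// add1 adds a single scalar (src1) to every element of src0, with per-type variants below.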
  1159. static void ggml_compute_forward_add1_f32(
  1160. const ggml_compute_params * params,
  1161. ggml_tensor * dst) {
  1162. const ggml_tensor * src0 = dst->src[0];
  1163. const ggml_tensor * src1 = dst->src[1];
  1164. GGML_ASSERT(ggml_are_same_shape(src0, dst));
  1165. GGML_ASSERT(ggml_is_scalar(src1));
  1166. const int ith = params->ith;
  1167. const int nth = params->nth;
  1168. const int nr = ggml_nrows(src0);
  1169. GGML_TENSOR_UNARY_OP_LOCALS
  1170. GGML_ASSERT( nb0 == sizeof(float));
  1171. GGML_ASSERT(nb00 == sizeof(float));
  1172. // rows per thread
  1173. const int dr = (nr + nth - 1)/nth;
  1174. // row range for this thread
  1175. const int ir0 = dr*ith;
  1176. const int ir1 = MIN(ir0 + dr, nr);
  1177. for (int ir = ir0; ir < ir1; ++ir) {
  1178. // src0 and dst are same shape => same indices
  1179. const int i3 = ir/(ne2*ne1);
  1180. const int i2 = (ir - i3*ne2*ne1)/ne1;
  1181. const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
  1182. #ifdef GGML_USE_ACCELERATE
  1183. GGML_UNUSED(ggml_vec_add1_f32);
  1184. vDSP_vadd(
  1185. (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01), 1,
  1186. (float *) ((char *) src1->data), 0,
  1187. (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 ), 1,
  1188. ne0);
  1189. #else
  1190. ggml_vec_add1_f32(ne0,
  1191. (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 ),
  1192. (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01),
  1193. *(float *) src1->data);
  1194. #endif
  1195. }
  1196. }
  1197. static void ggml_compute_forward_add1_f16_f32(
  1198. const ggml_compute_params * params,
  1199. ggml_tensor * dst) {
  1200. const ggml_tensor * src0 = dst->src[0];
  1201. const ggml_tensor * src1 = dst->src[1];
  1202. GGML_ASSERT(ggml_are_same_shape(src0, dst));
  1203. GGML_ASSERT(ggml_is_scalar(src1));
  1204. // scalar to add
  1205. const float v = *(float *) src1->data;
  1206. const int ith = params->ith;
  1207. const int nth = params->nth;
  1208. const int nr = ggml_nrows(src0);
  1209. GGML_TENSOR_UNARY_OP_LOCALS
  1210. GGML_ASSERT(src0->type == GGML_TYPE_F16);
  1211. GGML_ASSERT(src1->type == GGML_TYPE_F32);
  1212. GGML_ASSERT(dst->type == GGML_TYPE_F16);
  1213. GGML_ASSERT( nb0 == sizeof(ggml_fp16_t));
  1214. GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
  1215. // rows per thread
  1216. const int dr = (nr + nth - 1)/nth;
  1217. // row range for this thread
  1218. const int ir0 = dr*ith;
  1219. const int ir1 = MIN(ir0 + dr, nr);
  1220. for (int ir = ir0; ir < ir1; ++ir) {
  1221. // src0 and dst are same shape => same indices
  1222. const int i3 = ir/(ne2*ne1);
  1223. const int i2 = (ir - i3*ne2*ne1)/ne1;
  1224. const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
  1225. ggml_fp16_t * dst_ptr = (ggml_fp16_t *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 );
  1226. ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01);
  1227. for (int i = 0; i < ne0; i++) {
  1228. dst_ptr[i] = GGML_CPU_FP32_TO_FP16(GGML_CPU_FP16_TO_FP32(src0_ptr[i]) + v);
  1229. }
  1230. }
  1231. }
  1232. static void ggml_compute_forward_add1_f16_f16(
  1233. const ggml_compute_params * params,
  1234. ggml_tensor * dst) {
  1235. const ggml_tensor * src0 = dst->src[0];
  1236. const ggml_tensor * src1 = dst->src[1];
  1237. GGML_ASSERT(ggml_are_same_shape(src0, dst));
  1238. GGML_ASSERT(ggml_is_scalar(src1));
  1239. // scalar to add
  1240. const float v = GGML_CPU_FP16_TO_FP32(*(ggml_fp16_t *) src1->data);
  1241. const int ith = params->ith;
  1242. const int nth = params->nth;
  1243. const int nr = ggml_nrows(src0);
  1244. GGML_TENSOR_UNARY_OP_LOCALS
  1245. GGML_ASSERT(src0->type == GGML_TYPE_F16);
  1246. GGML_ASSERT(src1->type == GGML_TYPE_F16);
  1247. GGML_ASSERT(dst->type == GGML_TYPE_F16);
  1248. GGML_ASSERT( nb0 == sizeof(ggml_fp16_t));
  1249. GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
  1250. // rows per thread
  1251. const int dr = (nr + nth - 1)/nth;
  1252. // row range for this thread
  1253. const int ir0 = dr*ith;
  1254. const int ir1 = MIN(ir0 + dr, nr);
  1255. for (int ir = ir0; ir < ir1; ++ir) {
  1256. // src0 and dst are same shape => same indices
  1257. const int i3 = ir/(ne2*ne1);
  1258. const int i2 = (ir - i3*ne2*ne1)/ne1;
  1259. const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
  1260. ggml_fp16_t * dst_ptr = (ggml_fp16_t *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 );
  1261. ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01);
  1262. for (int i = 0; i < ne0; i++) {
  1263. dst_ptr[i] = GGML_CPU_FP32_TO_FP16(GGML_CPU_FP16_TO_FP32(src0_ptr[i]) + v);
  1264. }
  1265. }
  1266. }
  1267. static void ggml_compute_forward_add1_q_f32(
  1268. const ggml_compute_params * params,
  1269. ggml_tensor * dst) {
  1270. const ggml_tensor * src0 = dst->src[0];
  1271. const ggml_tensor * src1 = dst->src[1];
  1272. GGML_ASSERT(ggml_are_same_shape(src0, dst));
  1273. GGML_ASSERT(ggml_is_scalar(src1));
  1274. // scalar to add
  1275. const float v = *(float *) src1->data;
  1276. const int ith = params->ith;
  1277. const int nth = params->nth;
  1278. const int nr = ggml_nrows(src0);
  1279. GGML_TENSOR_UNARY_OP_LOCALS
  1280. const ggml_type type = src0->type;
  1281. ggml_to_float_t const dequantize_row_q = ggml_get_type_traits(type)->to_float;
  1282. ggml_from_float_t const quantize_row_q = ggml_get_type_traits_cpu(type)->from_float;
  1283. // we don't support permuted src0
  1284. GGML_ASSERT(nb00 == ggml_type_size(type));
  1285. // dst cannot be transposed or permuted
  1286. GGML_ASSERT(nb0 <= nb1);
  1287. GGML_ASSERT(nb1 <= nb2);
  1288. GGML_ASSERT(nb2 <= nb3);
  1289. GGML_ASSERT(ggml_is_quantized(src0->type));
  1290. GGML_ASSERT(dst->type == src0->type);
  1291. GGML_ASSERT(src1->type == GGML_TYPE_F32);
  1292. // rows per thread
  1293. const int dr = (nr + nth - 1)/nth;
  1294. // row range for this thread
  1295. const int ir0 = dr*ith;
  1296. const int ir1 = MIN(ir0 + dr, nr);
  1297. float * wdata = (float *) params->wdata + (ne0 + CACHE_LINE_SIZE_F32) * ith;
  1298. for (int ir = ir0; ir < ir1; ++ir) {
  1299. // src0 and dst are same shape => same indices
  1300. const int i3 = ir/(ne2*ne1);
  1301. const int i2 = (ir - i3*ne2*ne1)/ne1;
  1302. const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
  1303. void * src0_row = (void *) ((char *) src0->data + (i1*nb01 + i2*nb02 + i3*nb03));
1304. void * dst_row = (void *) ((char *) dst->data + (i1*nb1 + i2*nb2 + i3*nb3 ));
  1305. assert(ne0 % 32 == 0);
1306. // dequantize row from src0 to temp buffer
  1307. dequantize_row_q(src0_row, wdata, ne0);
1308. // add the scalar value from src1
  1309. ggml_vec_acc1_f32(ne0, wdata, v);
  1310. // quantize row to dst
  1311. quantize_row_q(wdata, dst_row, ne0);
  1312. }
  1313. }
  1314. static void ggml_compute_forward_add1_bf16_f32(
  1315. const ggml_compute_params * params,
  1316. ggml_tensor * dst) {
  1317. const ggml_tensor * src0 = dst->src[0];
  1318. const ggml_tensor * src1 = dst->src[1];
  1319. GGML_ASSERT(ggml_are_same_shape(src0, dst));
  1320. GGML_ASSERT(ggml_is_scalar(src1));
  1321. // scalar to add
  1322. const float v = *(float *) src1->data;
  1323. const int ith = params->ith;
  1324. const int nth = params->nth;
  1325. const int nr = ggml_nrows(src0);
  1326. GGML_TENSOR_UNARY_OP_LOCALS
  1327. GGML_ASSERT(src0->type == GGML_TYPE_BF16);
  1328. GGML_ASSERT(src1->type == GGML_TYPE_F32);
  1329. GGML_ASSERT(dst->type == GGML_TYPE_BF16);
  1330. GGML_ASSERT( nb0 == sizeof(ggml_bf16_t));
  1331. GGML_ASSERT(nb00 == sizeof(ggml_bf16_t));
  1332. // rows per thread
  1333. const int dr = (nr + nth - 1)/nth;
  1334. // row range for this thread
  1335. const int ir0 = dr*ith;
  1336. const int ir1 = MIN(ir0 + dr, nr);
  1337. for (int ir = ir0; ir < ir1; ++ir) {
  1338. // src0 and dst are same shape => same indices
  1339. const int i3 = ir/(ne2*ne1);
  1340. const int i2 = (ir - i3*ne2*ne1)/ne1;
  1341. const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
  1342. ggml_bf16_t * dst_ptr = (ggml_bf16_t *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 );
  1343. ggml_bf16_t * src0_ptr = (ggml_bf16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01);
  1344. for (int i = 0; i < ne0; i++) {
  1345. dst_ptr[i] = GGML_FP32_TO_BF16(GGML_BF16_TO_FP32(src0_ptr[i]) + v);
  1346. }
  1347. }
  1348. }
  1349. static void ggml_compute_forward_add1_bf16_bf16(
  1350. const ggml_compute_params * params,
  1351. ggml_tensor * dst) {
  1352. const ggml_tensor * src0 = dst->src[0];
  1353. const ggml_tensor * src1 = dst->src[1];
  1354. GGML_ASSERT(ggml_are_same_shape(src0, dst));
  1355. GGML_ASSERT(ggml_is_scalar(src1));
  1356. // scalar to add
  1357. const float v = GGML_BF16_TO_FP32(*(ggml_bf16_t *) src1->data);
  1358. const int ith = params->ith;
  1359. const int nth = params->nth;
  1360. const int nr = ggml_nrows(src0);
  1361. GGML_TENSOR_UNARY_OP_LOCALS
  1362. GGML_ASSERT(src0->type == GGML_TYPE_BF16);
  1363. GGML_ASSERT(src1->type == GGML_TYPE_BF16);
  1364. GGML_ASSERT(dst->type == GGML_TYPE_BF16);
  1365. GGML_ASSERT( nb0 == sizeof(ggml_bf16_t));
  1366. GGML_ASSERT(nb00 == sizeof(ggml_bf16_t));
  1367. // rows per thread
  1368. const int dr = (nr + nth - 1)/nth;
  1369. // row range for this thread
  1370. const int ir0 = dr*ith;
  1371. const int ir1 = MIN(ir0 + dr, nr);
  1372. for (int ir = ir0; ir < ir1; ++ir) {
  1373. // src0 and dst are same shape => same indices
  1374. const int i3 = ir/(ne2*ne1);
  1375. const int i2 = (ir - i3*ne2*ne1)/ne1;
  1376. const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
  1377. ggml_bf16_t * dst_ptr = (ggml_bf16_t *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 );
  1378. ggml_bf16_t * src0_ptr = (ggml_bf16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01);
  1379. for (int i = 0; i < ne0; i++) {
  1380. dst_ptr[i] = GGML_FP32_TO_BF16(GGML_BF16_TO_FP32(src0_ptr[i]) + v);
  1381. }
  1382. }
  1383. }
  1384. void ggml_compute_forward_add1(
  1385. const ggml_compute_params * params,
  1386. ggml_tensor * dst) {
  1387. const ggml_tensor * src0 = dst->src[0];
  1388. const ggml_tensor * src1 = dst->src[1];
  1389. switch (src0->type) {
  1390. case GGML_TYPE_F32:
  1391. {
  1392. ggml_compute_forward_add1_f32(params, dst);
  1393. } break;
  1394. case GGML_TYPE_F16:
  1395. {
  1396. if (src1->type == GGML_TYPE_F16) {
  1397. ggml_compute_forward_add1_f16_f16(params, dst);
  1398. }
  1399. else if (src1->type == GGML_TYPE_F32) {
  1400. ggml_compute_forward_add1_f16_f32(params, dst);
  1401. }
  1402. else {
  1403. GGML_ABORT("fatal error");
  1404. }
  1405. } break;
  1406. case GGML_TYPE_BF16:
  1407. {
  1408. if (src1->type == GGML_TYPE_BF16) {
  1409. ggml_compute_forward_add1_bf16_bf16(params, dst);
  1410. }
  1411. else if (src1->type == GGML_TYPE_F32) {
  1412. ggml_compute_forward_add1_bf16_f32(params, dst);
  1413. }
  1414. else {
  1415. GGML_ABORT("fatal error");
  1416. }
  1417. } break;
  1418. case GGML_TYPE_Q4_0:
  1419. case GGML_TYPE_Q4_1:
  1420. case GGML_TYPE_Q5_0:
  1421. case GGML_TYPE_Q5_1:
  1422. case GGML_TYPE_Q8_0:
  1423. case GGML_TYPE_Q8_1:
  1424. case GGML_TYPE_Q2_K:
  1425. case GGML_TYPE_Q3_K:
  1426. case GGML_TYPE_Q4_K:
  1427. case GGML_TYPE_Q5_K:
  1428. case GGML_TYPE_Q6_K:
  1429. case GGML_TYPE_TQ1_0:
  1430. case GGML_TYPE_TQ2_0:
  1431. case GGML_TYPE_IQ2_XXS:
  1432. case GGML_TYPE_IQ2_XS:
  1433. case GGML_TYPE_IQ3_XXS:
  1434. case GGML_TYPE_IQ1_S:
  1435. case GGML_TYPE_IQ1_M:
  1436. case GGML_TYPE_IQ4_NL:
  1437. case GGML_TYPE_IQ4_XS:
  1438. case GGML_TYPE_IQ3_S:
  1439. case GGML_TYPE_IQ2_S:
  1440. {
  1441. ggml_compute_forward_add1_q_f32(params, dst);
  1442. } break;
  1443. default:
  1444. {
  1445. GGML_ABORT("fatal error");
  1446. }
  1447. }
  1448. }
  1449. // ggml_compute_forward_acc
  1450. static void ggml_compute_forward_acc_f32(
  1451. const ggml_compute_params * params,
  1452. ggml_tensor * dst) {
  1453. const ggml_tensor * src0 = dst->src[0];
  1454. const ggml_tensor * src1 = dst->src[1];
  1455. GGML_ASSERT(ggml_are_same_shape(src0, dst));
  1456. GGML_ASSERT(ggml_is_contiguous(dst) && ggml_is_contiguous(src0));
1457. // view src0 and dst with these strides and data offset in bytes during acc
  1458. // nb0 is implicitly element_size because src0 and dst are contiguous
  1459. size_t nb1 = ((int32_t *) dst->op_params)[0];
  1460. size_t nb2 = ((int32_t *) dst->op_params)[1];
  1461. size_t nb3 = ((int32_t *) dst->op_params)[2];
  1462. size_t offset = ((int32_t *) dst->op_params)[3];
  1463. bool inplace = (bool) ((int32_t *) dst->op_params)[4];
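// the five op_params unpacked above describe the view that src1 is accumulated into:
// strides nb1..nb3, a byte offset into dst/src0, and whether the op is performed in place.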
  1464. if (!inplace) {
  1465. if (params->ith == 0) {
  1466. // memcpy needs to be synchronized across threads to avoid race conditions.
  1467. // => do it in INIT phase
  1468. memcpy(
  1469. ((char *) dst->data),
  1470. ((char *) src0->data),
  1471. ggml_nbytes(dst));
  1472. }
  1473. ggml_barrier(params->threadpool);
  1474. }
  1475. const int ith = params->ith;
  1476. const int nth = params->nth;
  1477. const int nr = ggml_nrows(src1);
  1478. const int nc = src1->ne[0];
  1479. GGML_TENSOR_LOCALS(int64_t, ne1, src1, ne)
  1480. GGML_TENSOR_LOCALS(size_t, nb1, src1, nb)
  1481. // src0 and dst as viewed during acc
  1482. const size_t nb0 = ggml_element_size(src0);
  1483. const size_t nb00 = nb0;
  1484. const size_t nb01 = nb1;
  1485. const size_t nb02 = nb2;
  1486. const size_t nb03 = nb3;
  1487. GGML_ASSERT(offset + (ne10 == 0 ? 0 : ne10-1)*nb0 + (ne11 == 0 ? 0 : ne11-1)*nb1 + (ne12 == 0 ? 0 : ne12-1)*nb2 + (ne13 == 0 ? 0 : ne13-1)*nb3 < ggml_nbytes(dst));
  1488. GGML_ASSERT(offset + (ne10 == 0 ? 0 : ne10-1)*nb00 + (ne11 == 0 ? 0 : ne11-1)*nb01 + (ne12 == 0 ? 0 : ne12-1)*nb02 + (ne13 == 0 ? 0 : ne13-1)*nb03 < ggml_nbytes(src0));
  1489. GGML_ASSERT(nb10 == sizeof(float));
  1490. // rows per thread
  1491. const int dr = (nr + nth - 1)/nth;
  1492. // row range for this thread
  1493. const int ir0 = dr*ith;
  1494. const int ir1 = MIN(ir0 + dr, nr);
  1495. for (int ir = ir0; ir < ir1; ++ir) {
  1496. // src0 and dst are viewed with shape of src1 and offset
  1497. // => same indices
  1498. const int i3 = ir/(ne12*ne11);
  1499. const int i2 = (ir - i3*ne12*ne11)/ne11;
  1500. const int i1 = (ir - i3*ne12*ne11 - i2*ne11);
  1501. #ifdef GGML_USE_ACCELERATE
  1502. vDSP_vadd(
  1503. (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + offset), 1,
  1504. (float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11), 1,
  1505. (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + offset), 1, nc);
  1506. #else
  1507. ggml_vec_add_f32(nc,
  1508. (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + offset),
  1509. (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + offset),
  1510. (float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11));
  1511. #endif
  1512. }
  1513. }
  1514. void ggml_compute_forward_acc(
  1515. const ggml_compute_params * params,
  1516. ggml_tensor * dst) {
  1517. const ggml_tensor * src0 = dst->src[0];
  1518. switch (src0->type) {
  1519. case GGML_TYPE_F32:
  1520. {
  1521. ggml_compute_forward_acc_f32(params, dst);
  1522. } break;
  1523. case GGML_TYPE_F16:
  1524. case GGML_TYPE_BF16:
  1525. case GGML_TYPE_Q4_0:
  1526. case GGML_TYPE_Q4_1:
  1527. case GGML_TYPE_Q5_0:
  1528. case GGML_TYPE_Q5_1:
  1529. case GGML_TYPE_Q8_0:
  1530. case GGML_TYPE_Q8_1:
  1531. case GGML_TYPE_Q2_K:
  1532. case GGML_TYPE_Q3_K:
  1533. case GGML_TYPE_Q4_K:
  1534. case GGML_TYPE_Q5_K:
  1535. case GGML_TYPE_Q6_K:
  1536. case GGML_TYPE_TQ1_0:
  1537. case GGML_TYPE_TQ2_0:
  1538. case GGML_TYPE_IQ2_XXS:
  1539. case GGML_TYPE_IQ2_XS:
  1540. case GGML_TYPE_IQ3_XXS:
  1541. case GGML_TYPE_IQ1_S:
  1542. case GGML_TYPE_IQ1_M:
  1543. case GGML_TYPE_IQ4_NL:
  1544. case GGML_TYPE_IQ4_XS:
  1545. case GGML_TYPE_IQ3_S:
  1546. case GGML_TYPE_IQ2_S:
  1547. default:
  1548. {
  1549. GGML_ABORT("fatal error");
  1550. }
  1551. }
  1552. }
  1553. // ggml_compute_forward_sum
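// sums all elements of src0 into a scalar dst; the F32 variant accumulates in ggml_float
// (a wider accumulator type), while the F16/BF16 variants accumulate in float before
// converting the result back to the destination type.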
  1554. static void ggml_compute_forward_sum_f32(
  1555. const ggml_compute_params * params,
  1556. ggml_tensor * dst) {
  1557. const ggml_tensor * src0 = dst->src[0];
  1558. if (params->ith != 0) {
  1559. return;
  1560. }
  1561. assert(ggml_is_scalar(dst));
  1562. assert(src0->nb[0] == sizeof(float));
  1563. GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne)
  1564. GGML_TENSOR_LOCALS(size_t, nb0, src0, nb)
  1565. ggml_float sum = 0;
  1566. ggml_float row_sum = 0;
  1567. for (int64_t i03 = 0; i03 < ne03; i03++) {
  1568. for (int64_t i02 = 0; i02 < ne02; i02++) {
  1569. for (int64_t i01 = 0; i01 < ne01; i01++) {
  1570. ggml_vec_sum_f32_ggf(ne00,
  1571. &row_sum,
  1572. (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03));
  1573. sum += row_sum;
  1574. }
  1575. }
  1576. }
  1577. ((float *) dst->data)[0] = sum;
  1578. }
  1579. static void ggml_compute_forward_sum_f16(
  1580. const ggml_compute_params * params,
  1581. ggml_tensor * dst) {
  1582. const ggml_tensor * src0 = dst->src[0];
  1583. if (params->ith != 0) {
  1584. return;
  1585. }
  1586. assert(ggml_is_scalar(dst));
  1587. assert(src0->nb[0] == sizeof(ggml_fp16_t));
  1588. GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne)
  1589. GGML_TENSOR_LOCALS(size_t, nb0, src0, nb)
  1590. float sum = 0;
  1591. float row_sum = 0;
  1592. for (int64_t i03 = 0; i03 < ne03; i03++) {
  1593. for (int64_t i02 = 0; i02 < ne02; i02++) {
  1594. for (int64_t i01 = 0; i01 < ne01; i01++) {
  1595. ggml_vec_sum_f16_ggf(ne00,
  1596. &row_sum,
  1597. (ggml_fp16_t *) ((char *) src0->data + i01 * nb01 + i02 * nb02 + i03 * nb03));
  1598. sum += row_sum;
  1599. }
  1600. }
  1601. }
  1602. ((ggml_fp16_t *) dst->data)[0] = GGML_CPU_FP32_TO_FP16(sum);
  1603. }
  1604. static void ggml_compute_forward_sum_bf16(
  1605. const ggml_compute_params * params,
  1606. ggml_tensor * dst) {
  1607. const ggml_tensor * src0 = dst->src[0];
  1608. if (params->ith != 0) {
  1609. return;
  1610. }
  1611. assert(ggml_is_scalar(dst));
  1612. assert(src0->nb[0] == sizeof(ggml_bf16_t));
  1613. GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne)
  1614. GGML_TENSOR_LOCALS(size_t, nb0, src0, nb)
  1615. float sum = 0;
  1616. float row_sum = 0;
  1617. for (int64_t i03 = 0; i03 < ne03; i03++) {
  1618. for (int64_t i02 = 0; i02 < ne02; i02++) {
  1619. for (int64_t i01 = 0; i01 < ne01; i01++) {
  1620. ggml_vec_sum_bf16_ggf(ne00,
  1621. &row_sum,
  1622. (ggml_bf16_t *) ((char *) src0->data + i01 * nb01 + i02 * nb02 + i03 * nb03));
  1623. sum += row_sum;
  1624. }
  1625. }
  1626. }
  1627. ((ggml_bf16_t *) dst->data)[0] = GGML_FP32_TO_BF16(sum);
  1628. }
  1629. void ggml_compute_forward_sum(
  1630. const ggml_compute_params * params,
  1631. ggml_tensor * dst) {
  1632. const ggml_tensor * src0 = dst->src[0];
  1633. switch (src0->type) {
  1634. case GGML_TYPE_F32:
  1635. {
  1636. ggml_compute_forward_sum_f32(params, dst);
  1637. } break;
  1638. case GGML_TYPE_F16:
  1639. {
  1640. ggml_compute_forward_sum_f16(params, dst);
  1641. } break;
  1642. case GGML_TYPE_BF16:
  1643. {
  1644. ggml_compute_forward_sum_bf16(params, dst);
  1645. } break;
  1646. default:
  1647. {
  1648. GGML_ABORT("fatal error");
  1649. }
  1650. }
  1651. }
  1652. // ggml_compute_forward_sum_rows
  1653. static void ggml_compute_forward_sum_rows_f32(
  1654. const ggml_compute_params * params,
  1655. ggml_tensor * dst) {
  1656. const ggml_tensor * src0 = dst->src[0];
  1657. if (params->ith != 0) {
  1658. return;
  1659. }
  1660. GGML_ASSERT(src0->nb[0] == sizeof(float));
  1661. GGML_ASSERT(dst->nb[0] == sizeof(float));
  1662. GGML_TENSOR_UNARY_OP_LOCALS
  1663. GGML_ASSERT(ne0 == 1);
  1664. GGML_ASSERT(ne1 == ne01);
  1665. GGML_ASSERT(ne2 == ne02);
  1666. GGML_ASSERT(ne3 == ne03);
  1667. for (int64_t i3 = 0; i3 < ne03; i3++) {
  1668. for (int64_t i2 = 0; i2 < ne02; i2++) {
  1669. for (int64_t i1 = 0; i1 < ne01; i1++) {
  1670. float * src_row = (float *) ((char *) src0->data + i1*nb01 + i2*nb02 + i3*nb03);
  1671. float * dst_row = (float *) ((char *) dst->data + i1*nb1 + i2*nb2 + i3*nb3);
  1672. float row_sum = 0;
  1673. ggml_vec_sum_f32(ne00, &row_sum, src_row);
  1674. dst_row[0] = row_sum;
  1675. }
  1676. }
  1677. }
  1678. }
  1679. void ggml_compute_forward_sum_rows(
  1680. const ggml_compute_params * params,
  1681. ggml_tensor * dst) {
  1682. const ggml_tensor * src0 = dst->src[0];
  1683. switch (src0->type) {
  1684. case GGML_TYPE_F32:
  1685. {
  1686. ggml_compute_forward_sum_rows_f32(params, dst);
  1687. } break;
  1688. default:
  1689. {
  1690. GGML_ABORT("fatal error");
  1691. }
  1692. }
  1693. }
  1694. // ggml_compute_forward_mean
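// reduces each src0 row to its arithmetic mean (row sum divided by ne00).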
  1695. static void ggml_compute_forward_mean_f32(
  1696. const ggml_compute_params * params,
  1697. ggml_tensor * dst) {
  1698. const ggml_tensor * src0 = dst->src[0];
  1699. if (params->ith != 0) {
  1700. return;
  1701. }
  1702. assert(src0->nb[0] == sizeof(float));
  1703. GGML_TENSOR_UNARY_OP_LOCALS
  1704. assert(ne0 == 1);
  1705. assert(ne1 == ne01);
  1706. assert(ne2 == ne02);
  1707. assert(ne3 == ne03);
  1708. GGML_UNUSED(ne0);
  1709. GGML_UNUSED(ne1);
  1710. GGML_UNUSED(ne2);
  1711. GGML_UNUSED(ne3);
  1712. for (int64_t i03 = 0; i03 < ne03; i03++) {
  1713. for (int64_t i02 = 0; i02 < ne02; i02++) {
  1714. for (int64_t i01 = 0; i01 < ne01; i01++) {
  1715. ggml_vec_sum_f32(ne00,
  1716. (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3),
  1717. (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03));
  1718. *(float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3) /= (float) ne00;
  1719. }
  1720. }
  1721. }
  1722. }
  1723. void ggml_compute_forward_mean(
  1724. const ggml_compute_params * params,
  1725. ggml_tensor * dst) {
  1726. const ggml_tensor * src0 = dst->src[0];
  1727. switch (src0->type) {
  1728. case GGML_TYPE_F32:
  1729. {
  1730. ggml_compute_forward_mean_f32(params, dst);
  1731. } break;
  1732. default:
  1733. {
  1734. GGML_ABORT("fatal error");
  1735. }
  1736. }
  1737. }
  1738. // ggml_compute_forward_argmax
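// writes, for every src0 row, the index of its maximum element into an I32 dst.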
  1739. static void ggml_compute_forward_argmax_f32(
  1740. const ggml_compute_params * params,
  1741. ggml_tensor * dst) {
  1742. const ggml_tensor * src0 = dst->src[0];
  1743. if (params->ith != 0) {
  1744. return;
  1745. }
  1746. assert(src0->nb[0] == sizeof(float));
  1747. assert(dst->nb[0] == sizeof(float));
  1748. const int64_t ne00 = src0->ne[0];
  1749. const int64_t ne01 = src0->ne[1];
  1750. const size_t nb01 = src0->nb[1];
  1751. const size_t nb0 = dst->nb[0];
  1752. for (int64_t i1 = 0; i1 < ne01; i1++) {
  1753. float * src = (float *) ((char *) src0->data + i1*nb01);
  1754. int32_t * dst_ = (int32_t *) ((char *) dst->data + i1*nb0);
  1755. int v = 0;
  1756. ggml_vec_argmax_f32(ne00, &v, src);
  1757. dst_[0] = v;
  1758. }
  1759. }
  1760. void ggml_compute_forward_argmax(
  1761. const ggml_compute_params * params,
  1762. ggml_tensor * dst) {
  1763. const ggml_tensor * src0 = dst->src[0];
  1764. switch (src0->type) {
  1765. case GGML_TYPE_F32:
  1766. {
  1767. ggml_compute_forward_argmax_f32(params, dst);
  1768. } break;
  1769. default:
  1770. {
  1771. GGML_ABORT("fatal error");
  1772. }
  1773. }
  1774. }
  1775. // ggml_compute_forward_count_equal
  1776. static void ggml_compute_forward_count_equal_i32(
  1777. const ggml_compute_params * params,
  1778. ggml_tensor * dst) {
  1779. const ggml_tensor * src0 = dst->src[0];
  1780. const ggml_tensor * src1 = dst->src[1];
  1781. GGML_TENSOR_BINARY_OP_LOCALS;
  1782. GGML_ASSERT(src0->type == GGML_TYPE_I32);
  1783. GGML_ASSERT(src1->type == GGML_TYPE_I32);
  1784. GGML_ASSERT(ggml_are_same_shape(src0, src1));
  1785. GGML_ASSERT(ggml_is_scalar(dst));
  1786. GGML_ASSERT(dst->type == GGML_TYPE_I64);
  1787. const int64_t nr = ggml_nrows(src0);
  1788. const int ith = params->ith;
  1789. const int nth = params->nth;
  1790. int64_t * sums = (int64_t *) params->wdata;
  1791. int64_t sum_thread = 0;
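// each thread counts matches for its own row range; non-zero threads publish their partial
// count in wdata and thread 0 reduces them after the barrier below.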
  1792. // rows per thread
  1793. const int64_t dr = (nr + nth - 1)/nth;
  1794. // row range for this thread
  1795. const int64_t ir0 = dr*ith;
  1796. const int64_t ir1 = MIN(ir0 + dr, nr);
  1797. for (int64_t ir = ir0; ir < ir1; ++ir) {
  1798. const int64_t i03 = ir / (ne02*ne01);
1799. const int64_t i02 = (ir - i03*ne02*ne01) / ne01;
1800. const int64_t i01 = ir - i03*ne02*ne01 - i02*ne01;
  1801. const char * data0 = (const char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01;
  1802. const char * data1 = (const char *) src1->data + i03*nb13 + i02*nb12 + i01*nb11;
  1803. for (int64_t i00 = 0; i00 < ne00; ++i00) {
  1804. const int32_t val0 = *((const int32_t *) (data0 + i00*nb00));
  1805. const int32_t val1 = *((const int32_t *) (data1 + i00*nb10));
  1806. sum_thread += val0 == val1;
  1807. }
  1808. }
  1809. if (ith != 0) {
  1810. sums[ith] = sum_thread;
  1811. }
  1812. ggml_barrier(params->threadpool);
  1813. if (ith != 0) {
  1814. return;
  1815. }
  1816. for (int ith_other = 1; ith_other < nth; ++ith_other) {
  1817. sum_thread += sums[ith_other];
  1818. }
  1819. *((int64_t *) dst->data) = sum_thread;
  1820. }
  1821. void ggml_compute_forward_count_equal(
  1822. const ggml_compute_params * params,
  1823. ggml_tensor * dst) {
  1824. const ggml_tensor * src0 = dst->src[0];
  1825. switch (src0->type) {
  1826. case GGML_TYPE_I32:
  1827. {
  1828. ggml_compute_forward_count_equal_i32(params, dst);
  1829. } break;
  1830. default:
  1831. {
  1832. GGML_ABORT("fatal error");
  1833. }
  1834. }
  1835. }
  1836. // ggml_compute_forward_repeat
  1837. static void ggml_compute_forward_repeat_f32(
  1838. const ggml_compute_params * params,
  1839. ggml_tensor * dst) {
  1840. const ggml_tensor * src0 = dst->src[0];
  1841. if (params->ith != 0) {
  1842. return;
  1843. }
  1844. GGML_ASSERT(ggml_can_repeat(src0, dst));
  1845. GGML_TENSOR_UNARY_OP_LOCALS
  1846. // guaranteed to be an integer due to the check in ggml_can_repeat
  1847. const int nr0 = (int)(ne0/ne00);
  1848. const int nr1 = (int)(ne1/ne01);
  1849. const int nr2 = (int)(ne2/ne02);
  1850. const int nr3 = (int)(ne3/ne03);
  1851. // TODO: support for transposed / permuted tensors
  1852. GGML_ASSERT(nb0 == sizeof(float));
  1853. GGML_ASSERT(nb00 == sizeof(float));
  1854. // TODO: maybe this is not optimal?
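// the i* loops iterate over the repeat counts per dimension and the k* loops over the source
// extents; each innermost step copies a full src0 row into the corresponding tile of dst.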
  1855. for (int i3 = 0; i3 < nr3; i3++) {
  1856. for (int k3 = 0; k3 < ne03; k3++) {
  1857. for (int i2 = 0; i2 < nr2; i2++) {
  1858. for (int k2 = 0; k2 < ne02; k2++) {
  1859. for (int i1 = 0; i1 < nr1; i1++) {
  1860. for (int k1 = 0; k1 < ne01; k1++) {
  1861. for (int i0 = 0; i0 < nr0; i0++) {
  1862. ggml_vec_cpy_f32(ne00,
  1863. (float *) ((char *) dst->data + (i3*ne03 + k3)*nb3 + (i2*ne02 + k2)*nb2 + (i1*ne01 + k1)*nb1 + (i0*ne00)*nb0),
  1864. (float *) ((char *) src0->data + ( k3)*nb03 + ( k2)*nb02 + ( k1)*nb01));
  1865. }
  1866. }
  1867. }
  1868. }
  1869. }
  1870. }
  1871. }
  1872. }
  1873. static void ggml_compute_forward_repeat_f16(
  1874. const ggml_compute_params * params,
  1875. ggml_tensor * dst) {
  1876. const ggml_tensor * src0 = dst->src[0];
  1877. if (params->ith != 0) {
  1878. return;
  1879. }
  1880. GGML_ASSERT(ggml_can_repeat(src0, dst));
  1881. GGML_TENSOR_UNARY_OP_LOCALS
  1882. // guaranteed to be an integer due to the check in ggml_can_repeat
  1883. const int nr0 = (int)(ne0/ne00);
  1884. const int nr1 = (int)(ne1/ne01);
  1885. const int nr2 = (int)(ne2/ne02);
  1886. const int nr3 = (int)(ne3/ne03);
  1887. // TODO: support for transposed / permuted tensors
  1888. GGML_ASSERT(nb0 == sizeof(ggml_fp16_t));
  1889. GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
  1890. // TODO: maybe this is not optimal?
  1891. for (int i3 = 0; i3 < nr3; i3++) {
  1892. for (int k3 = 0; k3 < ne03; k3++) {
  1893. for (int i2 = 0; i2 < nr2; i2++) {
  1894. for (int k2 = 0; k2 < ne02; k2++) {
  1895. for (int i1 = 0; i1 < nr1; i1++) {
  1896. for (int k1 = 0; k1 < ne01; k1++) {
  1897. for (int i0 = 0; i0 < nr0; i0++) {
  1898. ggml_fp16_t * y = (ggml_fp16_t *) ((char *) dst->data + (i3*ne03 + k3)*nb3 + (i2*ne02 + k2)*nb2 + (i1*ne01 + k1)*nb1 + (i0*ne00)*nb0);
  1899. ggml_fp16_t * x = (ggml_fp16_t *) ((char *) src0->data + ( k3)*nb03 + ( k2)*nb02 + ( k1)*nb01);
  1900. // ggml_vec_cpy_f16(ne00, y, x)
  1901. for (int i = 0; i < ne00; ++i) {
  1902. y[i] = x[i];
  1903. }
  1904. }
  1905. }
  1906. }
  1907. }
  1908. }
  1909. }
  1910. }
  1911. }
  1912. void ggml_compute_forward_repeat(
  1913. const ggml_compute_params * params,
  1914. ggml_tensor * dst) {
  1915. const ggml_tensor * src0 = dst->src[0];
  1916. switch (src0->type) {
  1917. case GGML_TYPE_F16:
  1918. case GGML_TYPE_BF16:
  1919. case GGML_TYPE_I16:
  1920. {
  1921. ggml_compute_forward_repeat_f16(params, dst);
  1922. } break;
  1923. case GGML_TYPE_F32:
  1924. case GGML_TYPE_I32:
  1925. {
  1926. ggml_compute_forward_repeat_f32(params, dst);
  1927. } break;
1928. // TODO: templateify the implementation and add support for I64
  1929. // ref https://github.com/ggml-org/llama.cpp/pull/14274#discussion_r2169492225
  1930. //case GGML_TYPE_I64:
  1931. // {
  1932. // ggml_compute_forward_repeat_i64(params, dst);
  1933. // } break;
  1934. default:
  1935. {
  1936. GGML_ABORT("fatal error");
  1937. }
  1938. }
  1939. }
  1940. // ggml_compute_forward_repeat_back
  1941. static void ggml_compute_forward_repeat_back_f32(
  1942. const ggml_compute_params * params,
  1943. ggml_tensor * dst) {
  1944. const ggml_tensor * src0 = dst->src[0];
  1945. if (params->ith != 0) {
  1946. return;
  1947. }
  1948. GGML_ASSERT(ggml_can_repeat(dst, src0));
  1949. GGML_TENSOR_UNARY_OP_LOCALS
  1950. // guaranteed to be an integer due to the check in ggml_can_repeat
  1951. const int nr0 = (int)(ne00/ne0);
  1952. const int nr1 = (int)(ne01/ne1);
  1953. const int nr2 = (int)(ne02/ne2);
  1954. const int nr3 = (int)(ne03/ne3);
  1955. // TODO: support for transposed / permuted tensors
  1956. GGML_ASSERT(nb0 == sizeof(float));
  1957. GGML_ASSERT(nb00 == sizeof(float));
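// dst accumulates the sum over all repeated tiles (the adjoint of repeat), so it is cleared
// first: in one shot when contiguous, otherwise row by row.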
  1958. if (ggml_is_contiguous(dst)) {
  1959. ggml_vec_set_f32(ne0*ne1*ne2*ne3, (float *)dst->data, 0);
  1960. } else {
  1961. for (int k3 = 0; k3 < ne3; k3++) {
  1962. for (int k2 = 0; k2 < ne2; k2++) {
  1963. for (int k1 = 0; k1 < ne1; k1++) {
  1964. ggml_vec_set_f32(ne0,
  1965. (float *) ((char *) dst->data + k1*nb1 + k2*nb2 + k3*nb3),
  1966. 0);
  1967. }
  1968. }
  1969. }
  1970. }
  1971. // TODO: maybe this is not optimal?
  1972. for (int i3 = 0; i3 < nr3; i3++) {
  1973. for (int k3 = 0; k3 < ne3; k3++) {
  1974. for (int i2 = 0; i2 < nr2; i2++) {
  1975. for (int k2 = 0; k2 < ne2; k2++) {
  1976. for (int i1 = 0; i1 < nr1; i1++) {
  1977. for (int k1 = 0; k1 < ne1; k1++) {
  1978. for (int i0 = 0; i0 < nr0; i0++) {
  1979. ggml_vec_acc_f32(ne0,
  1980. (float *) ((char *) dst->data + ( k3)*nb3 + ( k2)*nb2 + ( k1)*nb1),
  1981. (float *) ((char *) src0->data + (i3*ne3 + k3)*nb03 + (i2*ne2 + k2)*nb02 + (i1*ne1 + k1)*nb01 + (i0*ne0)*nb00));
  1982. }
  1983. }
  1984. }
  1985. }
  1986. }
  1987. }
  1988. }
  1989. }
  1990. void ggml_compute_forward_repeat_back(
  1991. const ggml_compute_params * params,
  1992. ggml_tensor * dst) {
  1993. const ggml_tensor * src0 = dst->src[0];
  1994. switch (src0->type) {
  1995. case GGML_TYPE_F32:
  1996. {
  1997. ggml_compute_forward_repeat_back_f32(params, dst);
  1998. } break;
  1999. default:
  2000. {
  2001. GGML_ABORT("fatal error");
  2002. }
  2003. }
  2004. }
  2005. // ggml_compute_forward_concat
  2006. static void ggml_compute_forward_concat_any(
  2007. const ggml_compute_params * params,
  2008. ggml_tensor * dst) {
  2009. const ggml_tensor * src0 = dst->src[0];
  2010. const ggml_tensor * src1 = dst->src[1];
  2011. const size_t len = ggml_type_size(src0->type);
  2012. const int ith = params->ith;
  2013. const int nth = params->nth;
  2014. GGML_TENSOR_BINARY_OP_LOCALS
  2015. const int32_t dim = ggml_get_op_params_i32(dst, 0);
  2016. GGML_ASSERT(dim >= 0 && dim < 4);
  2017. int64_t o[4] = {0, 0, 0, 0};
  2018. o[dim] = src0->ne[dim];
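// elements whose index along `dim` is within src0's extent are read from src0; the rest come
// from src1, with the index shifted back by o[dim] = src0->ne[dim].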
  2019. const char * x;
2020. // TODO: smarter multi-threading
  2021. for (int i3 = 0; i3 < ne3; i3++) {
  2022. for (int i2 = ith; i2 < ne2; i2 += nth) {
  2023. for (int i1 = 0; i1 < ne1; i1++) {
  2024. for (int i0 = 0; i0 < ne0; i0++) {
  2025. if (i0 < ne00 && i1 < ne01 && i2 < ne02 && i3 < ne03) {
  2026. x = (const char *)src0->data + (i0 )*nb00 + (i1 )*nb01 + (i2 )*nb02 + (i3 )*nb03;
  2027. } else {
  2028. x = (const char *)src1->data + (i0 - o[0])*nb10 + (i1 - o[1])*nb11 + (i2 - o[2])*nb12 + (i3 - o[3])*nb13;
  2029. }
  2030. char * y = (char *)dst->data + i0*nb0 + i1*nb1 + i2*nb2 + i3*nb3;
  2031. memcpy(y, x, len);
  2032. }
  2033. }
  2034. }
  2035. }
  2036. }
  2037. static void ggml_compute_forward_concat_i8(
  2038. const ggml_compute_params * params,
  2039. ggml_tensor * dst) {
  2040. const ggml_tensor * src0 = dst->src[0];
  2041. const ggml_tensor * src1 = dst->src[1];
  2042. GGML_ASSERT(ggml_type_size(src0->type) == sizeof(int8_t));
  2043. const int ith = params->ith;
  2044. const int nth = params->nth;
  2045. GGML_TENSOR_BINARY_OP_LOCALS
  2046. const int32_t dim = ggml_get_op_params_i32(dst, 0);
  2047. GGML_ASSERT(dim >= 0 && dim < 4);
  2048. int64_t o[4] = {0, 0, 0, 0};
  2049. o[dim] = src0->ne[dim];
  2050. const int8_t * x;
2051. // TODO: smarter multi-threading
  2052. for (int i3 = 0; i3 < ne3; i3++) {
  2053. for (int i2 = ith; i2 < ne2; i2 += nth) {
  2054. for (int i1 = 0; i1 < ne1; i1++) {
  2055. for (int i0 = 0; i0 < ne0; i0++) {
  2056. if (i0 < ne00 && i1 < ne01 && i2 < ne02 && i3 < ne03) {
  2057. x = (const int8_t *) ((const char *)src0->data + (i0 )*nb00 + (i1 )*nb01 + (i2 )*nb02 + (i3 )*nb03);
  2058. } else {
  2059. x = (const int8_t *) ((const char *)src1->data + (i0 - o[0])*nb10 + (i1 - o[1])*nb11 + (i2 - o[2])*nb12 + (i3 - o[3])*nb13);
  2060. }
  2061. int8_t * y = (int8_t *)((char *)dst->data + i0*nb0 + i1*nb1 + i2*nb2 + i3*nb3);
  2062. *y = *x;
  2063. }
  2064. }
  2065. }
  2066. }
  2067. }
  2068. static void ggml_compute_forward_concat_f16(
  2069. const ggml_compute_params * params,
  2070. ggml_tensor * dst) {
  2071. const ggml_tensor * src0 = dst->src[0];
  2072. const ggml_tensor * src1 = dst->src[1];
  2073. GGML_ASSERT(ggml_type_size(src0->type) == sizeof(ggml_fp16_t));
  2074. const int ith = params->ith;
  2075. const int nth = params->nth;
  2076. GGML_TENSOR_BINARY_OP_LOCALS
  2077. const int32_t dim = ggml_get_op_params_i32(dst, 0);
  2078. GGML_ASSERT(dim >= 0 && dim < 4);
  2079. int64_t o[4] = {0, 0, 0, 0};
  2080. o[dim] = src0->ne[dim];
  2081. const ggml_fp16_t * x;
2082. // TODO: smarter multi-threading
  2083. for (int i3 = 0; i3 < ne3; i3++) {
  2084. for (int i2 = ith; i2 < ne2; i2 += nth) {
  2085. for (int i1 = 0; i1 < ne1; i1++) {
  2086. for (int i0 = 0; i0 < ne0; i0++) {
  2087. if (i0 < ne00 && i1 < ne01 && i2 < ne02 && i3 < ne03) {
  2088. x = (const ggml_fp16_t *) ((const char *)src0->data + (i0 )*nb00 + (i1 )*nb01 + (i2 )*nb02 + (i3 )*nb03);
  2089. } else {
  2090. x = (const ggml_fp16_t *) ((const char *)src1->data + (i0 - o[0])*nb10 + (i1 - o[1])*nb11 + (i2 - o[2])*nb12 + (i3 - o[3])*nb13);
  2091. }
  2092. ggml_fp16_t * y = (ggml_fp16_t *)((char *)dst->data + i0*nb0 + i1*nb1 + i2*nb2 + i3*nb3);
  2093. *y = *x;
  2094. }
  2095. }
  2096. }
  2097. }
  2098. }
  2099. static void ggml_compute_forward_concat_f32(
  2100. const ggml_compute_params * params,
  2101. ggml_tensor * dst) {
  2102. const ggml_tensor * src0 = dst->src[0];
  2103. const ggml_tensor * src1 = dst->src[1];
  2104. GGML_ASSERT(ggml_type_size(src0->type) == sizeof(float));
  2105. const int ith = params->ith;
  2106. const int nth = params->nth;
  2107. GGML_TENSOR_BINARY_OP_LOCALS
  2108. const int32_t dim = ggml_get_op_params_i32(dst, 0);
  2109. GGML_ASSERT(dim >= 0 && dim < 4);
  2110. int64_t o[4] = {0, 0, 0, 0};
  2111. o[dim] = src0->ne[dim];
  2112. const float * x;
2113. // TODO: smarter multi-threading
  2114. for (int i3 = 0; i3 < ne3; i3++) {
  2115. for (int i2 = ith; i2 < ne2; i2 += nth) {
  2116. for (int i1 = 0; i1 < ne1; i1++) {
  2117. for (int i0 = 0; i0 < ne0; i0++) {
  2118. if (i0 < ne00 && i1 < ne01 && i2 < ne02 && i3 < ne03) {
  2119. x = (const float *) ((const char *)src0->data + (i0 )*nb00 + (i1 )*nb01 + (i2 )*nb02 + (i3 )*nb03);
  2120. } else {
  2121. x = (const float *) ((const char *)src1->data + (i0 - o[0])*nb10 + (i1 - o[1])*nb11 + (i2 - o[2])*nb12 + (i3 - o[3])*nb13);
  2122. }
  2123. float * y = (float *)((char *)dst->data + i0*nb0 + i1*nb1 + i2*nb2 + i3*nb3);
  2124. *y = *x;
  2125. }
  2126. }
  2127. }
  2128. }
  2129. }
  2130. void ggml_compute_forward_concat(
  2131. const ggml_compute_params * params,
  2132. ggml_tensor * dst) {
  2133. const ggml_tensor * src0 = dst->src[0];
  2134. switch (src0->type) {
  2135. case GGML_TYPE_F16:
  2136. case GGML_TYPE_BF16:
  2137. case GGML_TYPE_I16:
  2138. {
  2139. ggml_compute_forward_concat_f16(params, dst);
  2140. } break;
  2141. case GGML_TYPE_I8:
  2142. {
  2143. ggml_compute_forward_concat_i8(params, dst);
  2144. } break;
  2145. case GGML_TYPE_F32:
  2146. case GGML_TYPE_I32:
  2147. {
  2148. ggml_compute_forward_concat_f32(params, dst);
  2149. } break;
  2150. default:
  2151. {
  2152. ggml_compute_forward_concat_any(params, dst);
  2153. }
  2154. }
  2155. }
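// Illustrative sketch (not part of the build): every typed concat variant above uses the
// same index mapping -- a destination element is read from src0 while its coordinates fit
// inside src0's extents, otherwise from src1 shifted back by o[dim] = src0->ne[dim].
// A minimal 1-D version of that mapping, assuming plain float buffers instead of ggml
// tensors (names here are hypothetical, for illustration only):
#if 0
#include <stddef.h>

static void concat_1d_sketch(const float * a, size_t na,
                             const float * b, size_t nb,
                             float * dst) {
    const size_t o = na;                      // offset along the concat dimension
    for (size_t i = 0; i < na + nb; i++) {
        dst[i] = (i < na) ? a[i] : b[i - o];  // same branch as the loops above
    }
}
#endif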
  2156. // ggml_compute_forward_gelu
  2157. static void ggml_compute_forward_gelu_f32(
  2158. const ggml_compute_params * params,
  2159. ggml_tensor * dst) {
  2160. const ggml_tensor * src0 = dst->src[0];
  2161. assert(ggml_is_contiguous_1(src0));
  2162. assert(ggml_is_contiguous_1(dst));
  2163. assert(ggml_are_same_shape(src0, dst));
  2164. const int ith = params->ith;
  2165. const int nth = params->nth;
  2166. const int nc = src0->ne[0];
  2167. const int nr = ggml_nrows(src0);
  2168. // rows per thread
  2169. const int dr = (nr + nth - 1)/nth;
  2170. // row range for this thread
  2171. const int ir0 = dr*ith;
  2172. const int ir1 = MIN(ir0 + dr, nr);
  2173. for (int i1 = ir0; i1 < ir1; i1++) {
  2174. ggml_vec_gelu_f32(nc,
  2175. (float *) ((char *) dst->data + i1*( dst->nb[1])),
  2176. (float *) ((char *) src0->data + i1*(src0->nb[1])));
  2177. #ifndef NDEBUG
  2178. for (int k = 0; k < nc; k++) {
  2179. const float x = ((float *) ((char *) dst->data + i1*( dst->nb[1])))[k];
  2180. GGML_UNUSED(x);
  2181. assert(!isnan(x));
  2182. assert(!isinf(x));
  2183. }
  2184. #endif
  2185. }
  2186. }
  2187. static void ggml_compute_forward_gelu_f16(
  2188. const ggml_compute_params * params,
  2189. ggml_tensor * dst) {
  2190. const ggml_tensor * src0 = dst->src[0];
  2191. assert(ggml_is_contiguous_1(src0));
  2192. assert(ggml_is_contiguous_1(dst));
  2193. assert(ggml_are_same_shape(src0, dst));
  2194. const int ith = params->ith;
  2195. const int nth = params->nth;
  2196. const int nc = src0->ne[0];
  2197. const int nr = ggml_nrows(src0);
  2198. // rows per thread
  2199. const int dr = (nr + nth - 1)/nth;
  2200. // row range for this thread
  2201. const int ir0 = dr*ith;
  2202. const int ir1 = MIN(ir0 + dr, nr);
  2203. for (int i1 = ir0; i1 < ir1; i1++) {
  2204. ggml_vec_gelu_f16(nc,
  2205. (ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])),
  2206. (ggml_fp16_t *) ((char *) src0->data + i1*(src0->nb[1])));
  2207. #ifndef NDEBUG
  2208. for (int k = 0; k < nc; k++) {
  2209. const ggml_fp16_t x = ((ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])))[k];
  2210. const float v = GGML_CPU_FP16_TO_FP32(x);
  2211. GGML_UNUSED(v);
  2212. assert(!isnan(v));
  2213. assert(!isinf(v));
  2214. }
  2215. #endif
  2216. }
  2217. }
  2218. static void ggml_compute_forward_gelu(
  2219. const ggml_compute_params * params,
  2220. ggml_tensor * dst) {
  2221. const ggml_tensor * src0 = dst->src[0];
  2222. switch (src0->type) {
  2223. case GGML_TYPE_F32:
  2224. {
  2225. ggml_compute_forward_gelu_f32(params, dst);
  2226. } break;
  2227. case GGML_TYPE_F16:
  2228. {
  2229. ggml_compute_forward_gelu_f16(params, dst);
  2230. } break;
  2231. default:
  2232. {
  2233. GGML_ABORT("fatal error");
  2234. }
  2235. }
  2236. }
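// Illustrative sketch (not part of the build): every row-wise op in this file splits the
// nr rows into contiguous chunks of dr = ceil(nr/nth) rows, so thread ith owns rows
// [dr*ith, min(dr*ith + dr, nr)). A standalone version of that partitioning, with
// hypothetical names, to make the bounds explicit:
#if 0
typedef struct { int ir0, ir1; } row_range_sketch;

static row_range_sketch thread_row_range_sketch(int nr, int ith, int nth) {
    const int dr  = (nr + nth - 1)/nth;             // rows per thread, rounded up
    const int ir0 = dr*ith;                         // first row owned by this thread
    const int ir1 = ir0 + dr < nr ? ir0 + dr : nr;  // one past the last row (clamped)
    row_range_sketch r = { ir0, ir1 };
    return r;  // e.g. nr=10, nth=4 -> {0,3}, {3,6}, {6,9}, {9,10}
}
#endif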
  2237. // ggml_compute_forward_gelu_erf
  2238. static void ggml_compute_forward_gelu_erf_f32(
  2239. const ggml_compute_params * params,
  2240. ggml_tensor * dst) {
  2241. const ggml_tensor * src0 = dst->src[0];
  2242. assert(ggml_is_contiguous_1(src0));
  2243. assert(ggml_is_contiguous_1(dst));
  2244. assert(ggml_are_same_shape(src0, dst));
  2245. const int ith = params->ith;
  2246. const int nth = params->nth;
  2247. const int nc = src0->ne[0];
  2248. const int nr = ggml_nrows(src0);
  2249. // rows per thread
  2250. const int dr = (nr + nth - 1)/nth;
  2251. // row range for this thread
  2252. const int ir0 = dr*ith;
  2253. const int ir1 = MIN(ir0 + dr, nr);
  2254. for (int i1 = ir0; i1 < ir1; i1++) {
  2255. ggml_vec_gelu_erf_f32(nc,
  2256. (float *) ((char *) dst->data + i1*( dst->nb[1])),
  2257. (float *) ((char *) src0->data + i1*(src0->nb[1])));
  2258. #ifndef NDEBUG
  2259. for (int k = 0; k < nc; k++) {
  2260. const float x = ((float *) ((char *) dst->data + i1*( dst->nb[1])))[k];
  2261. GGML_UNUSED(x);
  2262. assert(!isnan(x));
  2263. assert(!isinf(x));
  2264. }
  2265. #endif
  2266. }
  2267. }
  2268. static void ggml_compute_forward_gelu_erf_f16(
  2269. const ggml_compute_params * params,
  2270. ggml_tensor * dst) {
  2271. const ggml_tensor * src0 = dst->src[0];
  2272. assert(ggml_is_contiguous_1(src0));
  2273. assert(ggml_is_contiguous_1(dst));
  2274. assert(ggml_are_same_shape(src0, dst));
  2275. const int ith = params->ith;
  2276. const int nth = params->nth;
  2277. const int nc = src0->ne[0];
  2278. const int nr = ggml_nrows(src0);
  2279. // rows per thread
  2280. const int dr = (nr + nth - 1)/nth;
  2281. // row range for this thread
  2282. const int ir0 = dr*ith;
  2283. const int ir1 = MIN(ir0 + dr, nr);
  2284. for (int i1 = ir0; i1 < ir1; i1++) {
  2285. ggml_vec_gelu_erf_f16(nc,
  2286. (ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])),
  2287. (ggml_fp16_t *) ((char *) src0->data + i1*(src0->nb[1])));
  2288. #ifndef NDEBUG
  2289. for (int k = 0; k < nc; k++) {
  2290. const ggml_fp16_t x = ((ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])))[k];
  2291. const float v = GGML_CPU_FP16_TO_FP32(x);
  2292. GGML_UNUSED(v);
  2293. assert(!isnan(v));
  2294. assert(!isinf(v));
  2295. }
  2296. #endif
  2297. }
  2298. }
  2299. static void ggml_compute_forward_gelu_erf(
  2300. const ggml_compute_params * params,
  2301. ggml_tensor * dst) {
  2302. const ggml_tensor * src0 = dst->src[0];
  2303. switch (src0->type) {
  2304. case GGML_TYPE_F32:
  2305. {
  2306. ggml_compute_forward_gelu_erf_f32(params, dst);
  2307. } break;
  2308. case GGML_TYPE_F16:
  2309. {
  2310. ggml_compute_forward_gelu_erf_f16(params, dst);
  2311. } break;
  2312. default:
  2313. {
  2314. GGML_ABORT("fatal error");
  2315. }
  2316. }
  2317. }
  2318. // ggml_compute_forward_gelu_quick
  2319. static void ggml_compute_forward_gelu_quick_f32(
  2320. const ggml_compute_params * params,
  2321. ggml_tensor * dst) {
  2322. const ggml_tensor * src0 = dst->src[0];
  2323. assert(ggml_is_contiguous_1(src0));
  2324. assert(ggml_is_contiguous_1(dst));
  2325. assert(ggml_are_same_shape(src0, dst));
  2326. const int ith = params->ith;
  2327. const int nth = params->nth;
  2328. const int nc = src0->ne[0];
  2329. const int nr = ggml_nrows(src0);
  2330. // rows per thread
  2331. const int dr = (nr + nth - 1)/nth;
  2332. // row range for this thread
  2333. const int ir0 = dr*ith;
  2334. const int ir1 = MIN(ir0 + dr, nr);
  2335. for (int i1 = ir0; i1 < ir1; i1++) {
  2336. ggml_vec_gelu_quick_f32(nc,
  2337. (float *) ((char *) dst->data + i1*( dst->nb[1])),
  2338. (float *) ((char *) src0->data + i1*(src0->nb[1])));
  2339. #ifndef NDEBUG
  2340. for (int k = 0; k < nc; k++) {
  2341. const float x = ((float *) ((char *) dst->data + i1*( dst->nb[1])))[k];
  2342. GGML_UNUSED(x);
  2343. assert(!isnan(x));
  2344. assert(!isinf(x));
  2345. }
  2346. #endif
  2347. }
  2348. }
  2349. static void ggml_compute_forward_gelu_quick_f16(
  2350. const ggml_compute_params * params,
  2351. ggml_tensor * dst) {
  2352. const ggml_tensor * src0 = dst->src[0];
  2353. assert(ggml_is_contiguous_1(src0));
  2354. assert(ggml_is_contiguous_1(dst));
  2355. assert(ggml_are_same_shape(src0, dst));
  2356. const int ith = params->ith;
  2357. const int nth = params->nth;
  2358. const int nc = src0->ne[0];
  2359. const int nr = ggml_nrows(src0);
  2360. // rows per thread
  2361. const int dr = (nr + nth - 1)/nth;
  2362. // row range for this thread
  2363. const int ir0 = dr*ith;
  2364. const int ir1 = MIN(ir0 + dr, nr);
  2365. for (int i1 = ir0; i1 < ir1; i1++) {
  2366. ggml_vec_gelu_quick_f16(nc,
  2367. (ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])),
  2368. (ggml_fp16_t *) ((char *) src0->data + i1*(src0->nb[1])));
  2369. #ifndef NDEBUG
  2370. for (int k = 0; k < nc; k++) {
  2371. const ggml_fp16_t x = ((ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])))[k];
  2372. const float v = GGML_CPU_FP16_TO_FP32(x);
  2373. GGML_UNUSED(v);
  2374. assert(!isnan(v));
  2375. assert(!isinf(v));
  2376. }
  2377. #endif
  2378. }
  2379. }
  2380. static void ggml_compute_forward_gelu_quick(
  2381. const ggml_compute_params * params,
  2382. ggml_tensor * dst) {
  2383. const ggml_tensor * src0 = dst->src[0];
  2384. switch (src0->type) {
  2385. case GGML_TYPE_F32:
  2386. {
  2387. ggml_compute_forward_gelu_quick_f32(params, dst);
  2388. } break;
  2389. case GGML_TYPE_F16:
  2390. {
  2391. ggml_compute_forward_gelu_quick_f16(params, dst);
  2392. } break;
  2393. default:
  2394. {
  2395. GGML_ABORT("fatal error");
  2396. }
  2397. }
  2398. }
  2399. // ggml_compute_forward_silu
  2400. static void ggml_compute_forward_silu_f32(
  2401. const ggml_compute_params * params,
  2402. ggml_tensor * dst) {
  2403. const ggml_tensor * src0 = dst->src[0];
  2404. assert(ggml_is_contiguous_1(src0));
  2405. assert(ggml_is_contiguous_1(dst));
  2406. assert(ggml_are_same_shape(src0, dst));
  2407. const int ith = params->ith;
  2408. const int nth = params->nth;
  2409. const int nc = src0->ne[0];
  2410. const int nr = ggml_nrows(src0);
  2411. // rows per thread
  2412. const int dr = (nr + nth - 1)/nth;
  2413. // row range for this thread
  2414. const int ir0 = dr*ith;
  2415. const int ir1 = MIN(ir0 + dr, nr);
  2416. for (int i1 = ir0; i1 < ir1; i1++) {
  2417. ggml_vec_silu_f32(nc,
  2418. (float *) ((char *) dst->data + i1*( dst->nb[1])),
  2419. (float *) ((char *) src0->data + i1*(src0->nb[1])));
  2420. #ifndef NDEBUG
  2421. for (int k = 0; k < nc; k++) {
  2422. const float x = ((float *) ((char *) dst->data + i1*(dst->nb[1])))[k];
  2423. GGML_UNUSED(x);
  2424. assert(!isnan(x));
  2425. assert(!isinf(x));
  2426. }
  2427. #endif
  2428. }
  2429. }
  2430. static void ggml_compute_forward_silu_f16(
  2431. const ggml_compute_params * params,
  2432. ggml_tensor * dst) {
  2433. const ggml_tensor * src0 = dst->src[0];
  2434. assert(ggml_is_contiguous_1(src0));
  2435. assert(ggml_is_contiguous_1(dst));
  2436. assert(ggml_are_same_shape(src0, dst));
  2437. const int ith = params->ith;
  2438. const int nth = params->nth;
  2439. const int nc = src0->ne[0];
  2440. const int nr = ggml_nrows(src0);
  2441. // rows per thread
  2442. const int dr = (nr + nth - 1)/nth;
  2443. // row range for this thread
  2444. const int ir0 = dr*ith;
  2445. const int ir1 = MIN(ir0 + dr, nr);
  2446. for (int i1 = ir0; i1 < ir1; i1++) {
  2447. ggml_vec_silu_f16(nc,
  2448. (ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])),
  2449. (ggml_fp16_t *) ((char *) src0->data + i1*(src0->nb[1])));
  2450. #ifndef NDEBUG
  2451. for (int k = 0; k < nc; k++) {
  2452. const ggml_fp16_t x = ((ggml_fp16_t *) ((char *) dst->data + i1*(dst->nb[1])))[k];
  2453. const float v = GGML_CPU_FP16_TO_FP32(x);
  2454. GGML_UNUSED(v);
  2455. assert(!isnan(v));
  2456. assert(!isinf(v));
  2457. }
  2458. #endif
  2459. }
  2460. }
  2461. static void ggml_compute_forward_silu(
  2462. const ggml_compute_params * params,
  2463. ggml_tensor * dst) {
  2464. const ggml_tensor * src0 = dst->src[0];
  2465. switch (src0->type) {
  2466. case GGML_TYPE_F32:
  2467. {
  2468. ggml_compute_forward_silu_f32(params, dst);
  2469. } break;
  2470. case GGML_TYPE_F16:
  2471. {
  2472. ggml_compute_forward_silu_f16(params, dst);
  2473. } break;
  2474. default:
  2475. {
  2476. GGML_ABORT("fatal error");
  2477. }
  2478. }
  2479. }
  2480. // ggml_compute_forward_leaky_relu
  2481. static void ggml_compute_forward_leaky_relu_f32(
  2482. const ggml_compute_params * params,
  2483. ggml_tensor * dst) {
  2484. const ggml_tensor * src0 = dst->src[0];
  2485. if (params->ith != 0) {
  2486. return;
  2487. }
  2488. assert(ggml_is_contiguous_1(src0));
  2489. assert(ggml_is_contiguous_1(dst));
  2490. assert(ggml_are_same_shape(src0, dst));
  2491. const int n = ggml_nrows(src0);
  2492. const int nc = src0->ne[0];
  2493. float negative_slope;
  2494. memcpy(&negative_slope, dst->op_params, sizeof(float));
  2495. assert(dst->nb[0] == sizeof(float));
  2496. assert(src0->nb[0] == sizeof(float));
  2497. for (int i = 0; i < n; i++) {
  2498. ggml_vec_leaky_relu_f32(nc,
  2499. (float *) ((char *) dst->data + i*( dst->nb[1])),
  2500. (float *) ((char *) src0->data + i*(src0->nb[1])), negative_slope);
  2501. }
  2502. }
  2503. static void ggml_compute_forward_leaky_relu_f16(
  2504. const ggml_compute_params * params,
  2505. ggml_tensor * dst) {
  2506. const ggml_tensor * src0 = dst->src[0];
  2507. if (params->ith != 0) {
  2508. return;
  2509. }
  2510. assert(ggml_is_contiguous_1(src0));
  2511. assert(ggml_is_contiguous_1(dst));
  2512. assert(ggml_are_same_shape(src0, dst));
  2513. const int n = ggml_nrows(src0);
  2514. const int nc = src0->ne[0];
  2515. float negative_slope;
  2516. memcpy(&negative_slope, dst->op_params, sizeof(float));
  2517. assert(dst->nb[0] == sizeof(ggml_fp16_t));
  2518. assert(src0->nb[0] == sizeof(ggml_fp16_t));
  2519. for (int i = 0; i < n; i++) {
  2520. ggml_vec_leaky_relu_f16(nc,
  2521. (ggml_fp16_t *) ((char *) dst->data + i*( dst->nb[1])),
  2522. (ggml_fp16_t *) ((char *) src0->data + i*(src0->nb[1])), negative_slope);
  2523. }
  2524. }
  2525. void ggml_compute_forward_leaky_relu(
  2526. const ggml_compute_params * params,
  2527. ggml_tensor * dst) {
  2528. const ggml_tensor * src0 = dst->src[0];
  2529. switch (src0->type) {
  2530. case GGML_TYPE_F32:
  2531. {
  2532. ggml_compute_forward_leaky_relu_f32(params, dst);
  2533. } break;
  2534. case GGML_TYPE_F16:
  2535. {
  2536. ggml_compute_forward_leaky_relu_f16(params, dst);
  2537. } break;
  2538. default:
  2539. {
  2540. GGML_ABORT("fatal error");
  2541. }
  2542. }
  2543. }
  2544. // ggml_compute_forward_silu_back
  2545. static void ggml_compute_forward_silu_back_f32(
  2546. const ggml_compute_params * params,
  2547. ggml_tensor * dst) {
  2548. const ggml_tensor * grad = dst->src[0];
  2549. const ggml_tensor * src1 = dst->src[1];
  2550. assert(ggml_is_contiguous_1(grad));
  2551. assert(ggml_is_contiguous_1(src1));
  2552. assert(ggml_is_contiguous_1(dst));
  2553. assert(ggml_are_same_shape(src1, dst));
  2554. assert(ggml_are_same_shape(src1, grad));
  2555. const int ith = params->ith;
  2556. const int nth = params->nth;
  2557. const int nc = src1->ne[0];
  2558. const int nr = ggml_nrows(src1);
  2559. // rows per thread
  2560. const int dr = (nr + nth - 1)/nth;
  2561. // row range for this thread
  2562. const int ir0 = dr*ith;
  2563. const int ir1 = MIN(ir0 + dr, nr);
  2564. for (int i1 = ir0; i1 < ir1; i1++) {
  2565. ggml_vec_silu_backward_f32(nc,
  2566. (float *) ((char *) dst->data + i1*( dst->nb[1])),
  2567. (float *) ((char *) src1->data + i1*(src1->nb[1])),
  2568. (float *) ((char *) grad->data + i1*(grad->nb[1])));
  2569. #ifndef NDEBUG
  2570. for (int k = 0; k < nc; k++) {
  2571. const float x = ((float *) ((char *) dst->data + i1*( dst->nb[1])))[k];
  2572. GGML_UNUSED(x);
  2573. assert(!isnan(x));
  2574. assert(!isinf(x));
  2575. }
  2576. #endif
  2577. }
  2578. }
  2579. static void ggml_compute_forward_silu_back_f16(
  2580. const ggml_compute_params * params,
  2581. ggml_tensor * dst) {
  2582. const ggml_tensor * grad = dst->src[0];
  2583. const ggml_tensor * src1 = dst->src[1];
  2584. assert(ggml_is_contiguous_1(grad));
  2585. assert(ggml_is_contiguous_1(src1));
  2586. assert(ggml_is_contiguous_1(dst));
  2587. assert(ggml_are_same_shape(src1, dst));
  2588. assert(ggml_are_same_shape(src1, grad));
  2589. const int ith = params->ith;
  2590. const int nth = params->nth;
  2591. const int nc = src1->ne[0];
  2592. const int nr = ggml_nrows(src1);
  2593. // rows per thread
  2594. const int dr = (nr + nth - 1)/nth;
  2595. // row range for this thread
  2596. const int ir0 = dr*ith;
  2597. const int ir1 = MIN(ir0 + dr, nr);
  2598. for (int i1 = ir0; i1 < ir1; i1++) {
  2599. ggml_vec_silu_backward_f16(nc,
  2600. (ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])),
  2601. (ggml_fp16_t *) ((char *) src1->data + i1*(src1->nb[1])),
  2602. (ggml_fp16_t *) ((char *) grad->data + i1*(grad->nb[1])));
  2603. #ifndef NDEBUG
  2604. for (int k = 0; k < nc; k++) {
  2605. const ggml_fp16_t x = ((ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])))[k];
  2606. const float v = GGML_CPU_FP16_TO_FP32(x);
  2607. GGML_UNUSED(v);
  2608. assert(!isnan(v));
  2609. assert(!isinf(v));
  2610. }
  2611. #endif
  2612. }
  2613. }
  2614. void ggml_compute_forward_silu_back(
  2615. const ggml_compute_params * params,
  2616. ggml_tensor * dst) {
  2617. const ggml_tensor * src0 = dst->src[0];
  2618. switch (src0->type) {
  2619. case GGML_TYPE_F32:
  2620. {
  2621. ggml_compute_forward_silu_back_f32(params, dst);
  2622. } break;
  2623. case GGML_TYPE_F16:
  2624. {
  2625. ggml_compute_forward_silu_back_f16(params, dst);
  2626. } break;
  2627. default:
  2628. {
  2629. GGML_ABORT("fatal error");
  2630. }
  2631. }
  2632. }
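// Illustrative sketch (not part of the build): the backward kernels above combine the saved
// forward input x with the incoming gradient dy via ggml_vec_silu_backward_*. This is a
// sketch of the usual closed form of the SiLU derivative, written out element-wise -- the
// math behind the call, not a copy of the ggml_vec_* implementation:
#if 0
#include <math.h>

static float silu_backward_sketch(float x, float dy) {
    const float s = 1.0f/(1.0f + expf(-x));  // sigmoid(x)
    // d/dx [x*sigmoid(x)] = sigmoid(x) * (1 + x*(1 - sigmoid(x)))
    return dy * s * (1.0f + x*(1.0f - s));
}
#endif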
  2633. // ggml_compute_forward_reglu
  2634. static void ggml_compute_forward_reglu_f32(
  2635. const ggml_compute_params * params,
  2636. ggml_tensor * dst) {
  2637. const ggml_tensor * src0 = dst->src[0];
  2638. const ggml_tensor * src1 = dst->src[1];
  2639. char * src0_d = (char *) src0->data;
  2640. char * src1_d = (char *) (src1 ? src1->data : src0->data);
  2641. const size_t src0_o = src0->nb[1];
  2642. const size_t src1_o = src1 ? src1->nb[1] : src0->nb[1];
  2643. GGML_ASSERT(ggml_is_contiguous_1(src0));
  2644. GGML_ASSERT(ggml_is_contiguous_1(dst));
  2645. if (src1) {
  2646. GGML_ASSERT(ggml_is_contiguous_1(src1));
  2647. GGML_ASSERT(src0->type == src1->type);
  2648. }
  2649. const int ith = params->ith;
  2650. const int nth = params->nth;
  2651. const int nc = src1 ? src0->ne[0] : src0->ne[0] / 2;
  2652. const int nr = ggml_nrows(src0);
  2653. GGML_ASSERT(dst->ne[0] == nc);
  2654. GGML_ASSERT(ggml_nrows(dst) == nr);
  2655. const int32_t swapped = ggml_get_op_params_i32(dst, 1);
  2656. // rows per thread
  2657. const int dr = (nr + nth - 1)/nth;
  2658. // row range for this thread
  2659. const int ir0 = dr*ith;
  2660. const int ir1 = MIN(ir0 + dr, nr);
  2661. for (int i1 = ir0; i1 < ir1; i1++) {
  2662. float * src0_p = (float *) (src0_d + i1*src0_o);
  2663. float * src1_p = (float *) (src1_d + i1*src1_o);
  2664. if (!src1) {
  2665. src0_p += swapped ? nc : 0;
  2666. src1_p += swapped ? 0 : nc;
  2667. }
  2668. ggml_vec_reglu_f32(nc, (float *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p);
  2669. #ifndef NDEBUG
  2670. for (int k = 0; k < nc; k++) {
  2671. const float x = ((float *) ((char *) dst->data + i1*( dst->nb[1])))[k];
  2672. GGML_UNUSED(x);
  2673. assert(!isnan(x));
  2674. assert(!isinf(x));
  2675. }
  2676. #endif
  2677. }
  2678. }
  2679. static void ggml_compute_forward_reglu_f16(
  2680. const ggml_compute_params * params,
  2681. ggml_tensor * dst) {
  2682. const ggml_tensor * src0 = dst->src[0];
  2683. const ggml_tensor * src1 = dst->src[1];
  2684. char * src0_d = (char *) src0->data;
  2685. char * src1_d = (char *) (src1 ? src1->data : src0->data);
  2686. const size_t src0_o = src0->nb[1];
  2687. const size_t src1_o = src1 ? src1->nb[1] : src0->nb[1];
  2688. GGML_ASSERT(ggml_is_contiguous_1(src0));
  2689. GGML_ASSERT(ggml_is_contiguous_1(dst));
  2690. if (src1) {
  2691. GGML_ASSERT(ggml_is_contiguous_1(src1));
  2692. GGML_ASSERT(src0->type == src1->type);
  2693. }
  2694. const int ith = params->ith;
  2695. const int nth = params->nth;
  2696. const int nc = src1 ? src0->ne[0] : src0->ne[0] / 2;
  2697. const int nr = ggml_nrows(src0);
  2698. GGML_ASSERT(dst->ne[0] == nc);
  2699. GGML_ASSERT(ggml_nrows(dst) == nr);
  2700. const int32_t swapped = ggml_get_op_params_i32(dst, 1);
  2701. // rows per thread
  2702. const int dr = (nr + nth - 1)/nth;
  2703. // row range for this thread
  2704. const int ir0 = dr*ith;
  2705. const int ir1 = MIN(ir0 + dr, nr);
  2706. for (int i1 = ir0; i1 < ir1; i1++) {
  2707. ggml_fp16_t * src0_p = (ggml_fp16_t *) (src0_d + i1*src0_o);
  2708. ggml_fp16_t * src1_p = (ggml_fp16_t *) (src1_d + i1*src1_o);
  2709. if (!src1) {
  2710. src0_p += swapped ? nc : 0;
  2711. src1_p += swapped ? 0 : nc;
  2712. }
  2713. ggml_vec_reglu_f16(nc, (ggml_fp16_t *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p);
  2714. #ifndef NDEBUG
  2715. for (int k = 0; k < nc; k++) {
  2716. const ggml_fp16_t x = ((ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])))[k];
  2717. const float v = GGML_FP16_TO_FP32(x);
  2718. GGML_UNUSED(v);
  2719. assert(!isnan(v));
  2720. assert(!isinf(v));
  2721. }
  2722. #endif
  2723. }
  2724. }
  2725. static void ggml_compute_forward_reglu(
  2726. const ggml_compute_params * params,
  2727. ggml_tensor * dst) {
  2728. const ggml_tensor * src0 = dst->src[0];
  2729. switch (src0->type) {
  2730. case GGML_TYPE_F32:
  2731. {
  2732. ggml_compute_forward_reglu_f32(params, dst);
  2733. } break;
  2734. case GGML_TYPE_F16:
  2735. {
  2736. ggml_compute_forward_reglu_f16(params, dst);
  2737. } break;
  2738. default:
  2739. {
  2740. GGML_ABORT("fatal error");
  2741. }
  2742. }
  2743. }
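// Illustrative sketch (not part of the build): the GLU family above (reglu/geglu/swiglu/...)
// accepts either two separate tensors or a single tensor whose rows hold the gate and value
// halves side by side; in the single-tensor case nc = ne0/2 and the `swapped` op param decides
// which half feeds the activation. A minimal single-row version of that split over a plain
// float row, assuming ggml_vec_reglu applies ReLU to its src0_p argument (hypothetical names):
#if 0
static void reglu_row_sketch(const float * row, int ne0, int swapped, float * out) {
    const int nc = ne0/2;
    const float * a = row + (swapped ? nc : 0);  // half that goes through the activation
    const float * b = row + (swapped ? 0 : nc);  // linear half
    for (int i = 0; i < nc; i++) {
        const float r = a[i] > 0.0f ? a[i] : 0.0f;  // ReLU gate
        out[i] = r * b[i];
    }
}
#endif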
  2744. // ggml_compute_forward_geglu
  2745. static void ggml_compute_forward_geglu_f32(
  2746. const ggml_compute_params * params,
  2747. ggml_tensor * dst) {
  2748. const ggml_tensor * src0 = dst->src[0];
  2749. const ggml_tensor * src1 = dst->src[1];
  2750. char * src0_d = (char *) src0->data;
  2751. char * src1_d = (char *) (src1 ? src1->data : src0->data);
  2752. const size_t src0_o = src0->nb[1];
  2753. const size_t src1_o = src1 ? src1->nb[1] : src0->nb[1];
  2754. GGML_ASSERT(ggml_is_contiguous_1(src0));
  2755. GGML_ASSERT(ggml_is_contiguous_1(dst));
  2756. if (src1) {
  2757. GGML_ASSERT(ggml_is_contiguous_1(src1));
  2758. GGML_ASSERT(src0->type == src1->type);
  2759. }
  2760. const int ith = params->ith;
  2761. const int nth = params->nth;
  2762. const int nc = src1 ? src0->ne[0] : src0->ne[0] / 2;
  2763. const int nr = ggml_nrows(src0);
  2764. GGML_ASSERT(dst->ne[0] == nc);
  2765. GGML_ASSERT(ggml_nrows(dst) == nr);
  2766. const int32_t swapped = ggml_get_op_params_i32(dst, 1);
  2767. // rows per thread
  2768. const int dr = (nr + nth - 1)/nth;
  2769. // row range for this thread
  2770. const int ir0 = dr*ith;
  2771. const int ir1 = MIN(ir0 + dr, nr);
  2772. for (int i1 = ir0; i1 < ir1; i1++) {
  2773. float * src0_p = (float *) (src0_d + i1*src0_o);
  2774. float * src1_p = (float *) (src1_d + i1*src1_o);
  2775. if (!src1) {
  2776. src0_p += swapped ? nc : 0;
  2777. src1_p += swapped ? 0 : nc;
  2778. }
  2779. ggml_vec_geglu_f32(nc, (float *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p);
  2780. #ifndef NDEBUG
  2781. for (int k = 0; k < nc; k++) {
  2782. const float x = ((float *) ((char *) dst->data + i1*( dst->nb[1])))[k];
  2783. GGML_UNUSED(x);
  2784. assert(!isnan(x));
  2785. assert(!isinf(x));
  2786. }
  2787. #endif
  2788. }
  2789. }
  2790. static void ggml_compute_forward_geglu_f16(
  2791. const ggml_compute_params * params,
  2792. ggml_tensor * dst) {
  2793. const ggml_tensor * src0 = dst->src[0];
  2794. const ggml_tensor * src1 = dst->src[1];
  2795. char * src0_d = (char *) src0->data;
  2796. char * src1_d = (char *) (src1 ? src1->data : src0->data);
  2797. const size_t src0_o = src0->nb[1];
  2798. const size_t src1_o = src1 ? src1->nb[1] : src0->nb[1];
  2799. GGML_ASSERT(ggml_is_contiguous_1(src0));
  2800. GGML_ASSERT(ggml_is_contiguous_1(dst));
  2801. if (src1) {
  2802. GGML_ASSERT(ggml_is_contiguous_1(src1));
  2803. GGML_ASSERT(src0->type == src1->type);
  2804. }
  2805. const int ith = params->ith;
  2806. const int nth = params->nth;
  2807. const int nc = src1 ? src0->ne[0] : src0->ne[0] / 2;
  2808. const int nr = ggml_nrows(src0);
  2809. GGML_ASSERT(dst->ne[0] == nc);
  2810. GGML_ASSERT(ggml_nrows(dst) == nr);
  2811. const int32_t swapped = ggml_get_op_params_i32(dst, 1);
  2812. // rows per thread
  2813. const int dr = (nr + nth - 1)/nth;
  2814. // row range for this thread
  2815. const int ir0 = dr*ith;
  2816. const int ir1 = MIN(ir0 + dr, nr);
  2817. for (int i1 = ir0; i1 < ir1; i1++) {
  2818. ggml_fp16_t * src0_p = (ggml_fp16_t *) (src0_d + i1*src0_o);
  2819. ggml_fp16_t * src1_p = (ggml_fp16_t *) (src1_d + i1*src1_o);
  2820. if (!src1) {
  2821. src0_p += swapped ? nc : 0;
  2822. src1_p += swapped ? 0 : nc;
  2823. }
  2824. ggml_vec_geglu_f16(nc, (ggml_fp16_t *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p);
  2825. #ifndef NDEBUG
  2826. for (int k = 0; k < nc; k++) {
  2827. const ggml_fp16_t x = ((ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])))[k];
  2828. const float v = GGML_FP16_TO_FP32(x);
  2829. GGML_UNUSED(v);
  2830. assert(!isnan(v));
  2831. assert(!isinf(v));
  2832. }
  2833. #endif
  2834. }
  2835. }
  2836. static void ggml_compute_forward_geglu(
  2837. const ggml_compute_params * params,
  2838. ggml_tensor * dst) {
  2839. const ggml_tensor * src0 = dst->src[0];
  2840. switch (src0->type) {
  2841. case GGML_TYPE_F32:
  2842. {
  2843. ggml_compute_forward_geglu_f32(params, dst);
  2844. } break;
  2845. case GGML_TYPE_F16:
  2846. {
  2847. ggml_compute_forward_geglu_f16(params, dst);
  2848. } break;
  2849. default:
  2850. {
  2851. GGML_ABORT("fatal error");
  2852. }
  2853. }
  2854. }
  2855. // ggml_compute_forward_swiglu
  2856. static void ggml_compute_forward_swiglu_f32(
  2857. const ggml_compute_params * params,
  2858. ggml_tensor * dst) {
  2859. const ggml_tensor * src0 = dst->src[0];
  2860. const ggml_tensor * src1 = dst->src[1];
  2861. char * src0_d = (char *) src0->data;
  2862. char * src1_d = (char *) (src1 ? src1->data : src0->data);
  2863. const size_t src0_o = src0->nb[1];
  2864. const size_t src1_o = src1 ? src1->nb[1] : src0->nb[1];
  2865. GGML_ASSERT(ggml_is_contiguous_1(src0));
  2866. GGML_ASSERT(ggml_is_contiguous_1(dst));
  2867. if (src1) {
  2868. GGML_ASSERT(ggml_is_contiguous_1(src1));
  2869. GGML_ASSERT(src0->type == src1->type);
  2870. }
  2871. const int ith = params->ith;
  2872. const int nth = params->nth;
  2873. const int nc = src1 ? src0->ne[0] : src0->ne[0] / 2;
  2874. const int nr = ggml_nrows(src0);
  2875. GGML_ASSERT(dst->ne[0] == nc);
  2876. GGML_ASSERT(ggml_nrows(dst) == nr);
  2877. const int32_t swapped = ggml_get_op_params_i32(dst, 1);
  2878. // rows per thread
  2879. const int dr = (nr + nth - 1)/nth;
  2880. // row range for this thread
  2881. const int ir0 = dr*ith;
  2882. const int ir1 = MIN(ir0 + dr, nr);
  2883. for (int i1 = ir0; i1 < ir1; i1++) {
  2884. float * src0_p = (float *) (src0_d + i1*src0_o);
  2885. float * src1_p = (float *) (src1_d + i1*src1_o);
  2886. if (!src1) {
  2887. src0_p += swapped ? nc : 0;
  2888. src1_p += swapped ? 0 : nc;
  2889. }
  2890. ggml_vec_swiglu_f32(nc, (float *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p);
  2891. #ifndef NDEBUG
  2892. for (int k = 0; k < nc; k++) {
  2893. const float x = ((float *) ((char *) dst->data + i1*( dst->nb[1])))[k];
  2894. GGML_UNUSED(x);
  2895. assert(!isnan(x));
  2896. assert(!isinf(x));
  2897. }
  2898. #endif
  2899. }
  2900. }
  2901. static void ggml_compute_forward_swiglu_f16(
  2902. const ggml_compute_params * params,
  2903. ggml_tensor * dst) {
  2904. const ggml_tensor * src0 = dst->src[0];
  2905. const ggml_tensor * src1 = dst->src[1];
  2906. char * src0_d = (char *) src0->data;
  2907. char * src1_d = (char *) (src1 ? src1->data : src0->data);
  2908. const size_t src0_o = src0->nb[1];
  2909. const size_t src1_o = src1 ? src1->nb[1] : src0->nb[1];
  2910. GGML_ASSERT(ggml_is_contiguous_1(src0));
  2911. GGML_ASSERT(ggml_is_contiguous_1(dst));
  2912. if (src1) {
  2913. GGML_ASSERT(ggml_is_contiguous_1(src1));
  2914. GGML_ASSERT(src0->type == src1->type);
  2915. }
  2916. const int ith = params->ith;
  2917. const int nth = params->nth;
  2918. const int nc = src1 ? src0->ne[0] : src0->ne[0] / 2;
  2919. const int nr = ggml_nrows(src0);
  2920. GGML_ASSERT(dst->ne[0] == nc);
  2921. GGML_ASSERT(ggml_nrows(dst) == nr);
  2922. const int32_t swapped = ggml_get_op_params_i32(dst, 1);
  2923. // rows per thread
  2924. const int dr = (nr + nth - 1)/nth;
  2925. // row range for this thread
  2926. const int ir0 = dr*ith;
  2927. const int ir1 = MIN(ir0 + dr, nr);
  2928. for (int i1 = ir0; i1 < ir1; i1++) {
  2929. ggml_fp16_t * src0_p = (ggml_fp16_t *) (src0_d + i1*src0_o);
  2930. ggml_fp16_t * src1_p = (ggml_fp16_t *) (src1_d + i1*src1_o);
  2931. if (!src1) {
  2932. src0_p += swapped ? nc : 0;
  2933. src1_p += swapped ? 0 : nc;
  2934. }
  2935. ggml_vec_swiglu_f16(nc, (ggml_fp16_t *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p);
  2936. #ifndef NDEBUG
  2937. for (int k = 0; k < nc; k++) {
  2938. const ggml_fp16_t x = ((ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])))[k];
  2939. const float v = GGML_FP16_TO_FP32(x);
  2940. GGML_UNUSED(v);
  2941. assert(!isnan(v));
  2942. assert(!isinf(v));
  2943. }
  2944. #endif
  2945. }
  2946. }
  2947. static void ggml_compute_forward_swiglu(
  2948. const ggml_compute_params * params,
  2949. ggml_tensor * dst) {
  2950. const ggml_tensor * src0 = dst->src[0];
  2951. switch (src0->type) {
  2952. case GGML_TYPE_F32:
  2953. {
  2954. ggml_compute_forward_swiglu_f32(params, dst);
  2955. } break;
  2956. case GGML_TYPE_F16:
  2957. {
  2958. ggml_compute_forward_swiglu_f16(params, dst);
  2959. } break;
  2960. default:
  2961. {
  2962. GGML_ABORT("fatal error");
  2963. }
  2964. }
  2965. }
  2966. // ggml_compute_forward_geglu_erf
  2967. static void ggml_compute_forward_geglu_erf_f32(
  2968. const ggml_compute_params * params,
  2969. ggml_tensor * dst) {
  2970. const ggml_tensor * src0 = dst->src[0];
  2971. const ggml_tensor * src1 = dst->src[1];
  2972. char * src0_d = (char *) src0->data;
  2973. char * src1_d = (char *) (src1 ? src1->data : src0->data);
  2974. const size_t src0_o = src0->nb[1];
  2975. const size_t src1_o = src1 ? src1->nb[1] : src0->nb[1];
  2976. GGML_ASSERT(ggml_is_contiguous_1(src0));
  2977. GGML_ASSERT(ggml_is_contiguous_1(dst));
  2978. if (src1) {
  2979. GGML_ASSERT(ggml_is_contiguous_1(src1));
  2980. GGML_ASSERT(src0->type == src1->type);
  2981. }
  2982. const int ith = params->ith;
  2983. const int nth = params->nth;
  2984. const int nc = src1 ? src0->ne[0] : src0->ne[0] / 2;
  2985. const int nr = ggml_nrows(src0);
  2986. GGML_ASSERT(dst->ne[0] == nc);
  2987. GGML_ASSERT(ggml_nrows(dst) == nr);
  2988. const int32_t swapped = ggml_get_op_params_i32(dst, 1);
  2989. // rows per thread
  2990. const int dr = (nr + nth - 1)/nth;
  2991. // row range for this thread
  2992. const int ir0 = dr*ith;
  2993. const int ir1 = MIN(ir0 + dr, nr);
  2994. for (int i1 = ir0; i1 < ir1; i1++) {
  2995. float * src0_p = (float *) (src0_d + i1*src0_o);
  2996. float * src1_p = (float *) (src1_d + i1*src1_o);
  2997. if (!src1) {
  2998. src0_p += swapped ? nc : 0;
  2999. src1_p += swapped ? 0 : nc;
  3000. }
  3001. ggml_vec_geglu_erf_f32(nc, (float *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p);
  3002. #ifndef NDEBUG
  3003. for (int k = 0; k < nc; k++) {
  3004. const float x = ((float *) ((char *) dst->data + i1*( dst->nb[1])))[k];
  3005. GGML_UNUSED(x);
  3006. assert(!isnan(x));
  3007. assert(!isinf(x));
  3008. }
  3009. #endif
  3010. }
  3011. }
  3012. static void ggml_compute_forward_geglu_erf_f16(
  3013. const ggml_compute_params * params,
  3014. ggml_tensor * dst) {
  3015. const ggml_tensor * src0 = dst->src[0];
  3016. const ggml_tensor * src1 = dst->src[1];
  3017. char * src0_d = (char *) src0->data;
  3018. char * src1_d = (char *) (src1 ? src1->data : src0->data);
  3019. const size_t src0_o = src0->nb[1];
  3020. const size_t src1_o = src1 ? src1->nb[1] : src0->nb[1];
  3021. GGML_ASSERT(ggml_is_contiguous_1(src0));
  3022. GGML_ASSERT(ggml_is_contiguous_1(dst));
  3023. if (src1) {
  3024. GGML_ASSERT(ggml_is_contiguous_1(src1));
  3025. GGML_ASSERT(src0->type == src1->type);
  3026. }
  3027. const int ith = params->ith;
  3028. const int nth = params->nth;
  3029. const int nc = src1 ? src0->ne[0] : src0->ne[0] / 2;
  3030. const int nr = ggml_nrows(src0);
  3031. GGML_ASSERT(dst->ne[0] == nc);
  3032. GGML_ASSERT(ggml_nrows(dst) == nr);
  3033. const int32_t swapped = ggml_get_op_params_i32(dst, 1);
  3034. // rows per thread
  3035. const int dr = (nr + nth - 1)/nth;
  3036. // row range for this thread
  3037. const int ir0 = dr*ith;
  3038. const int ir1 = MIN(ir0 + dr, nr);
  3039. for (int i1 = ir0; i1 < ir1; i1++) {
  3040. ggml_fp16_t * src0_p = (ggml_fp16_t *) (src0_d + i1*src0_o);
  3041. ggml_fp16_t * src1_p = (ggml_fp16_t *) (src1_d + i1*src1_o);
  3042. if (!src1) {
  3043. src0_p += swapped ? nc : 0;
  3044. src1_p += swapped ? 0 : nc;
  3045. }
  3046. ggml_vec_geglu_erf_f16(nc, (ggml_fp16_t *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p);
  3047. #ifndef NDEBUG
  3048. for (int k = 0; k < nc; k++) {
  3049. const ggml_fp16_t x = ((ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])))[k];
  3050. const float v = GGML_FP16_TO_FP32(x);
  3051. GGML_UNUSED(v);
  3052. assert(!isnan(v));
  3053. assert(!isinf(v));
  3054. }
  3055. #endif
  3056. }
  3057. }
  3058. static void ggml_compute_forward_geglu_erf(
  3059. const ggml_compute_params * params,
  3060. ggml_tensor * dst) {
  3061. const ggml_tensor * src0 = dst->src[0];
  3062. switch (src0->type) {
  3063. case GGML_TYPE_F32:
  3064. {
  3065. ggml_compute_forward_geglu_erf_f32(params, dst);
  3066. } break;
  3067. case GGML_TYPE_F16:
  3068. {
  3069. ggml_compute_forward_geglu_erf_f16(params, dst);
  3070. } break;
  3071. default:
  3072. {
  3073. GGML_ABORT("fatal error");
  3074. }
  3075. }
  3076. }
  3077. // ggml_compute_forward_geglu_quick
  3078. static void ggml_compute_forward_geglu_quick_f32(
  3079. const ggml_compute_params * params,
  3080. ggml_tensor * dst) {
  3081. const ggml_tensor * src0 = dst->src[0];
  3082. const ggml_tensor * src1 = dst->src[1];
  3083. char * src0_d = (char *) src0->data;
  3084. char * src1_d = (char *) (src1 ? src1->data : src0->data);
  3085. const size_t src0_o = src0->nb[1];
  3086. const size_t src1_o = src1 ? src1->nb[1] : src0->nb[1];
  3087. GGML_ASSERT(ggml_is_contiguous_1(src0));
  3088. GGML_ASSERT(ggml_is_contiguous_1(dst));
  3089. if (src1) {
  3090. GGML_ASSERT(ggml_is_contiguous_1(src1));
  3091. GGML_ASSERT(src0->type == src1->type);
  3092. }
  3093. const int ith = params->ith;
  3094. const int nth = params->nth;
  3095. const int nc = src1 ? src0->ne[0] : src0->ne[0] / 2;
  3096. const int nr = ggml_nrows(src0);
  3097. GGML_ASSERT(dst->ne[0] == nc);
  3098. GGML_ASSERT(ggml_nrows(dst) == nr);
  3099. const int32_t swapped = ggml_get_op_params_i32(dst, 1);
  3100. // rows per thread
  3101. const int dr = (nr + nth - 1)/nth;
  3102. // row range for this thread
  3103. const int ir0 = dr*ith;
  3104. const int ir1 = MIN(ir0 + dr, nr);
  3105. for (int i1 = ir0; i1 < ir1; i1++) {
  3106. float * src0_p = (float *) (src0_d + i1*src0_o);
  3107. float * src1_p = (float *) (src1_d + i1*src1_o);
  3108. if (!src1) {
  3109. src0_p += swapped ? nc : 0;
  3110. src1_p += swapped ? 0 : nc;
  3111. }
  3112. ggml_vec_geglu_quick_f32(nc, (float *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p);
  3113. #ifndef NDEBUG
  3114. for (int k = 0; k < nc; k++) {
  3115. const float x = ((float *) ((char *) dst->data + i1*( dst->nb[1])))[k];
  3116. GGML_UNUSED(x);
  3117. assert(!isnan(x));
  3118. assert(!isinf(x));
  3119. }
  3120. #endif
  3121. }
  3122. }
  3123. static void ggml_compute_forward_geglu_quick_f16(
  3124. const ggml_compute_params * params,
  3125. ggml_tensor * dst) {
  3126. const ggml_tensor * src0 = dst->src[0];
  3127. const ggml_tensor * src1 = dst->src[1];
  3128. char * src0_d = (char *) src0->data;
  3129. char * src1_d = (char *) (src1 ? src1->data : src0->data);
  3130. const size_t src0_o = src0->nb[1];
  3131. const size_t src1_o = src1 ? src1->nb[1] : src0->nb[1];
  3132. GGML_ASSERT(ggml_is_contiguous_1(src0));
  3133. GGML_ASSERT(ggml_is_contiguous_1(dst));
  3134. if (src1) {
  3135. GGML_ASSERT(ggml_is_contiguous_1(src1));
  3136. GGML_ASSERT(src0->type == src1->type);
  3137. }
  3138. const int ith = params->ith;
  3139. const int nth = params->nth;
  3140. const int nc = src1 ? src0->ne[0] : src0->ne[0] / 2;
  3141. const int nr = ggml_nrows(src0);
  3142. GGML_ASSERT(dst->ne[0] == nc);
  3143. GGML_ASSERT(ggml_nrows(dst) == nr);
  3144. const int32_t swapped = ggml_get_op_params_i32(dst, 1);
  3145. // rows per thread
  3146. const int dr = (nr + nth - 1)/nth;
  3147. // row range for this thread
  3148. const int ir0 = dr*ith;
  3149. const int ir1 = MIN(ir0 + dr, nr);
  3150. for (int i1 = ir0; i1 < ir1; i1++) {
  3151. ggml_fp16_t * src0_p = (ggml_fp16_t *) (src0_d + i1*src0_o);
  3152. ggml_fp16_t * src1_p = (ggml_fp16_t *) (src1_d + i1*src1_o);
  3153. if (!src1) {
  3154. src0_p += swapped ? nc : 0;
  3155. src1_p += swapped ? 0 : nc;
  3156. }
  3157. ggml_vec_geglu_quick_f16(nc, (ggml_fp16_t *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p);
  3158. #ifndef NDEBUG
  3159. for (int k = 0; k < nc; k++) {
  3160. const ggml_fp16_t x = ((ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])))[k];
  3161. const float v = GGML_FP16_TO_FP32(x);
  3162. GGML_UNUSED(v);
  3163. assert(!isnan(v));
  3164. assert(!isinf(v));
  3165. }
  3166. #endif
  3167. }
  3168. }
  3169. static void ggml_compute_forward_geglu_quick(
  3170. const ggml_compute_params * params,
  3171. ggml_tensor * dst) {
  3172. const ggml_tensor * src0 = dst->src[0];
  3173. switch (src0->type) {
  3174. case GGML_TYPE_F32:
  3175. {
  3176. ggml_compute_forward_geglu_quick_f32(params, dst);
  3177. } break;
  3178. case GGML_TYPE_F16:
  3179. {
  3180. ggml_compute_forward_geglu_quick_f16(params, dst);
  3181. } break;
  3182. default:
  3183. {
  3184. GGML_ABORT("fatal error");
  3185. }
  3186. }
  3187. }
  3188. // ggml_compute_forward_norm
  3189. static void ggml_compute_forward_norm_f32(
  3190. const ggml_compute_params * params,
  3191. ggml_tensor * dst) {
  3192. const ggml_tensor * src0 = dst->src[0];
  3193. GGML_ASSERT(ggml_are_same_shape(src0, dst));
  3194. GGML_ASSERT(src0->nb[0] == sizeof(float));
  3195. const int ith = params->ith;
  3196. const int nth = params->nth;
  3197. GGML_TENSOR_UNARY_OP_LOCALS
  3198. float eps;
  3199. memcpy(&eps, dst->op_params, sizeof(float));
  3200. GGML_ASSERT(eps >= 0.0f);
  3201. // TODO: optimize
  3202. for (int64_t i03 = 0; i03 < ne03; i03++) {
  3203. for (int64_t i02 = 0; i02 < ne02; i02++) {
  3204. for (int64_t i01 = ith; i01 < ne01; i01 += nth) {
  3205. const float * x = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
  3206. ggml_float sum = 0.0;
  3207. for (int64_t i00 = 0; i00 < ne00; i00++) {
  3208. sum += (ggml_float)x[i00];
  3209. }
  3210. float mean = sum/ne00;
  3211. float * y = (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3);
  3212. ggml_float sum2 = 0.0;
  3213. for (int64_t i00 = 0; i00 < ne00; i00++) {
  3214. float v = x[i00] - mean;
  3215. y[i00] = v;
  3216. sum2 += (ggml_float)(v*v);
  3217. }
  3218. float variance = sum2/ne00;
  3219. const float scale = 1.0f/sqrtf(variance + eps);
  3220. ggml_vec_scale_f32(ne00, y, scale);
  3221. }
  3222. }
  3223. }
  3224. }
  3225. void ggml_compute_forward_norm(
  3226. const ggml_compute_params * params,
  3227. ggml_tensor * dst) {
  3228. const ggml_tensor * src0 = dst->src[0];
  3229. switch (src0->type) {
  3230. case GGML_TYPE_F32:
  3231. {
  3232. ggml_compute_forward_norm_f32(params, dst);
  3233. } break;
  3234. default:
  3235. {
  3236. GGML_ABORT("fatal error");
  3237. }
  3238. }
  3239. }
  3240. // ggml_compute_forward_rms_norm
  3241. static void ggml_compute_forward_rms_norm_f32(
  3242. const ggml_compute_params * params,
  3243. ggml_tensor * dst) {
  3244. const ggml_tensor * src0 = dst->src[0];
  3245. GGML_ASSERT(ggml_are_same_shape(src0, dst));
  3246. GGML_ASSERT(src0->nb[0] == sizeof(float));
  3247. const int ith = params->ith;
  3248. const int nth = params->nth;
  3249. GGML_TENSOR_UNARY_OP_LOCALS
  3250. float eps;
  3251. memcpy(&eps, dst->op_params, sizeof(float));
  3252. GGML_ASSERT(eps >= 0.0f);
  3253. // TODO: optimize
  3254. for (int64_t i03 = 0; i03 < ne03; i03++) {
  3255. for (int64_t i02 = 0; i02 < ne02; i02++) {
  3256. for (int64_t i01 = ith; i01 < ne01; i01 += nth) {
  3257. const float * x = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
  3258. ggml_float sum = 0.0;
  3259. for (int64_t i00 = 0; i00 < ne00; i00++) {
  3260. sum += (ggml_float)(x[i00] * x[i00]);
  3261. }
  3262. const float mean = sum/ne00;
  3263. float * y = (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3);
  3264. memcpy(y, x, ne00 * sizeof(float));
  3265. // for (int i00 = 0; i00 < ne00; i00++) {
  3266. // y[i00] = x[i00];
  3267. // }
  3268. const float scale = 1.0f/sqrtf(mean + eps);
  3269. ggml_vec_scale_f32(ne00, y, scale);
  3270. }
  3271. }
  3272. }
  3273. }
  3274. void ggml_compute_forward_rms_norm(
  3275. const ggml_compute_params * params,
  3276. ggml_tensor * dst) {
  3277. const ggml_tensor * src0 = dst->src[0];
  3278. switch (src0->type) {
  3279. case GGML_TYPE_F32:
  3280. {
  3281. ggml_compute_forward_rms_norm_f32(params, dst);
  3282. } break;
  3283. default:
  3284. {
  3285. GGML_ABORT("fatal error");
  3286. }
  3287. }
  3288. }
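// Illustrative sketch (not part of the build): the loop above normalizes each row by the
// root-mean-square of its elements, y = x / sqrt(mean(x^2) + eps). The same computation for
// a single plain float row, with hypothetical names:
#if 0
#include <math.h>

static void rms_norm_row_sketch(const float * x, float * y, int n, float eps) {
    double sum = 0.0;
    for (int i = 0; i < n; i++) {
        sum += (double) x[i] * x[i];              // sum of squares in higher precision
    }
    const float scale = 1.0f/sqrtf((float)(sum/n) + eps);
    for (int i = 0; i < n; i++) {
        y[i] = x[i]*scale;
    }
}
#endif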
  3289. static void ggml_compute_forward_rms_norm_back_f32(
  3290. const ggml_compute_params * params,
  3291. ggml_tensor * dst) {
  3292. const ggml_tensor * src0 = dst->src[0]; // gradients from forward pass output
  3293. const ggml_tensor * src1 = dst->src[1]; // src1 from forward pass
  3294. GGML_ASSERT(ggml_are_same_shape(src0, dst) && ggml_are_same_shape(src0, src1));
  3295. GGML_ASSERT(src0->nb[0] == sizeof(float));
  3296. GGML_ASSERT(src1->nb[0] == sizeof(float));
  3297. const int ith = params->ith;
  3298. const int nth = params->nth;
  3299. GGML_TENSOR_BINARY_OP_LOCALS
  3300. float eps;
  3301. memcpy(&eps, dst->op_params, sizeof(float));
  3302. // TODO: optimize
  3303. for (int64_t i03 = 0; i03 < ne03; i03++) {
  3304. for (int64_t i02 = 0; i02 < ne02; i02++) {
  3305. for (int64_t i01 = ith; i01 < ne01; i01 += nth) {
  3306. // src1 is same shape as src0 => same indices
  3307. const int64_t i11 = i01;
  3308. const int64_t i12 = i02;
  3309. const int64_t i13 = i03;
  3310. const float * dz = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
  3311. const float * x = (float *) ((char *) src1->data + i11*nb11 + i12*nb12 + i13*nb13);
  3312. ggml_float sum_xx = 0.0;
  3313. ggml_float sum_xdz = 0.0;
  3314. for (int64_t i00 = 0; i00 < ne00; i00++) {
  3315. sum_xx += (ggml_float)(x[i00] * x[i00]);
  3316. sum_xdz += (ggml_float)(x[i00] * dz[i00]);
  3317. }
  3318. //const float mean = (float)(sum_xx)/ne00;
  3319. const float mean_eps = (float)(sum_xx)/ne00 + eps;
  3320. const float sum_eps = (float)(sum_xx) + eps*ne00;
  3321. //const float mean_xdz = (float)(sum_xdz)/ne00;
  3322. // we could cache rms from forward pass to improve performance.
  3323. // to do this implement ggml_rms and compose ggml_rms_norm using ggml_rms.
  3324. //const float rms = sqrtf(mean_eps);
  3325. const float rrms = 1.0f / sqrtf(mean_eps);
  3326. //const float scale = -rrms/(ne00 * mean_eps); // -1/(n*rms**3)
  3327. {
  3328. // z = rms_norm(x)
  3329. //
  3330. // rms_norm(src1) =
  3331. // scale(
  3332. // src1,
  3333. // div(
  3334. // 1,
  3335. // sqrt(
  3336. // add(
  3337. // scale(
  3338. // sum(
  3339. // sqr(
  3340. // src1)),
  3341. // (1.0/N)),
  3342. // eps))));
  3343. // postorder:
  3344. // ## op args grad
  3345. // 00 param src1 grad[#00]
  3346. // 01 const 1
  3347. // 02 sqr (#00) grad[#02]
  3348. // 03 sum (#02) grad[#03]
  3349. // 04 const 1/N
  3350. // 05 scale (#03, #04) grad[#05]
  3351. // 06 const eps
  3352. // 07 add (#05, #06) grad[#07]
  3353. // 08 sqrt (#07) grad[#08]
  3354. // 09 div (#01,#08) grad[#09]
  3355. // 10 scale (#00,#09) grad[#10]
  3356. //
  3357. // backward pass, given grad[#10]
  3358. // #10: scale
  3359. // grad[#00] += scale(grad[#10],#09)
  3360. // grad[#09] += sum(mul(grad[#10],#00))
  3361. // #09: div
  3362. // grad[#08] += neg(mul(grad[#09], div(#09,#08)))
  3363. // #08: sqrt
  3364. // grad[#07] += mul(grad[#08], div(0.5, #08))
  3365. // #07: add
  3366. // grad[#05] += grad[#07]
  3367. // #05: scale
  3368. // grad[#03] += scale(grad[#05],#04)
  3369. // #03: sum
  3370. // grad[#02] += repeat(grad[#03], #02)
  3371. // #02:
  3372. // grad[#00] += scale(mul(#00, grad[#02]), 2.0)
  3373. //
  3374. // substitute and simplify:
  3375. // grad[#00] = scale(grad(#10), #09) + scale(mul(#00, grad[#02]), 2.0)
  3376. // grad[#02] = repeat(grad[#03], #02)
  3377. // grad[#02] = repeat(scale(grad[#05],#04), #02)
  3378. // grad[#02] = repeat(scale(grad[#07],#04), #02)
  3379. // grad[#02] = repeat(scale(mul(grad[#08], div(0.5, #08)),#04), #02)
  3380. // grad[#02] = repeat(scale(mul(neg(mul(grad[#09], div(#09,#08))), div(0.5, #08)),#04), #02)
  3381. // grad[#02] = repeat(scale(mul(neg(mul(sum(mul(grad[#10],#00)), div(#09,#08))), div(0.5, #08)),#04), #02)
  3382. // grad[#02] = repeat(-(sum(mul(grad[#10],#00)) * div(#09,#08) * div(0.5, #08) * (1/N)), #02)
  3383. // grad[#02] = repeat(-(sum(mul(grad[#10],#00)) * div(div(#01,#08),#08) * div(0.5, #08) * (1/N)), #02)
  3384. // grad[#02] = repeat(-(sum(mul(grad[#10],#00)) * div(1,#08*#08) * div(0.5, #08) * (1/N)), #02)
  3385. // grad[#02] = repeat(-(sum(mul(grad[#10],#00)) * div(1,#07) * div(0.5, #08) * (1/N)), #02)
  3386. // grad[#00] = scale(grad(#10), #09) + scale(mul(#00, grad[#02]), 2.0)
  3387. // grad[#00] = scale(grad(#10), #09) + scale(mul(#00, repeat(-(sum(mul(grad[#10],#00)) * div(1,#07) * div(0.5, #08) * (1/N)), #02)), 2.0)
  3388. // grad[#00] = scale(grad(#10), #09) + scale(scale(#00, -(sum(mul(grad[#10],#00)) * div(1,#07) * div(0.5, #08) * (1/N))), 2.0)
  3389. // grad[#00] = scale(grad(#10), #09) + scale(#00, -(sum(mul(grad[#10],#00)) * div(1,#07) * div(1,#08) * (1/N)))
  3390. // grad[#00] = scale(grad(#10), #09) + scale(#00, sum(mul(grad[#10],#00)) * div(1,#07*#08) * (-1/N))
  3391. // grad[#00] = scale(grad(#10), #09) + scale(#00, sum(mul(grad[#10],#00)) * div(1,#07*#08) * (-1/N))
  3392. // grad[#00] = scale(grad(#10), #09) + scale(#00, sum(mul(grad[#10],#00)) * div(1,mean_eps*rms) * (-1/N))
  3393. // grad[#00] = scale(grad(#10), #09) + scale(#00, sum(mul(grad[#10],#00)) * div(-1,rms*N*mean_eps))
  3394. // grad[#00] = scale(grad(#10), #09) + scale(#00, sum(mul(grad[#10],#00)) * div(-1,rms*N*(sum_xx/N+eps)))
  3395. // grad[#00] = scale(grad(#10), #09) + scale(#00, sum(mul(grad[#10],#00)) * div(-1,rms*N*sum_xx+rms*N*eps))
  3396. // grad[#00] = scale(dz, rrms) + scale(x, sum(mul(dz,x)) * div(-1,rms*N*mean_eps))
  3397. // grad[#00] = scale(dz, rrms) + scale(x, sum_xdz * div(-1,rms*N*mean_eps))
  3398. // a = b*c + d*e
  3399. // a = b*c*f/f + d*e*f/f
  3400. // a = (b*c*f + d*e*f)*(1/f)
  3401. // a = (b*c*(1/c) + d*e*(1/c))*(1/(1/c))
  3402. // a = (b + d*e/c)*c
  3403. // b = dz, c = rrms, d = x, e = sum_xdz * div(-1,rms*N*mean_eps)
  3404. // a = (dz + x*sum_xdz * div(-1,rms*N*mean_eps)/rrms)*rrms
  3405. // a = (dz + x*sum_xdz * div(-1,rms*N*mean_eps)*rms)*rrms
  3406. // a = (dz + x*sum_xdz * div(-rms,rms*N*mean_eps))*rrms
  3407. // a = (dz + x*sum_xdz * div(-1,N*mean_eps))*rrms
  3408. // a = (dz + x*div(-sum_xdz,N*mean_eps))*rrms
  3409. // a = (dz + x*div(-mean_xdz,mean_eps))*rrms
  3410. // grad[#00] = scale(dz + scale(x, div(-mean_xdz,mean_eps)),rrms)
  3411. // grad[#00] = scale(dz + scale(x, -mean_xdz/mean_eps),rrms)
  3412. // dx = scale(dz + scale(x, -mean_xdz/mean_eps),rrms)
  3413. }
  3414. // dx = scale(dz + scale(x, -mean_xdz/mean_eps),rrms)
  3415. // post-order:
  3416. // dx := x
  3417. // dx := scale(dx,-mean_xdz/mean_eps)
  3418. // dx := add(dx, dz)
  3419. // dx := scale(dx, rrms)
  3420. float * dx = (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3);
  3421. // dx[i00] = (x*(-sum_xdz/sum_eps) + dz) / sqrtf(mean_eps)
  3422. ggml_vec_cpy_f32 (ne00, dx, x);
  3423. // ggml_vec_scale_f32(ne00, dx, -mean_xdz/mean_eps);
  3424. ggml_vec_scale_f32(ne00, dx, (float)(-sum_xdz)/sum_eps);
  3425. ggml_vec_acc_f32 (ne00, dx, dz);
  3426. ggml_vec_scale_f32(ne00, dx, rrms);
  3427. }
  3428. }
  3429. }
  3430. }
  3431. void ggml_compute_forward_rms_norm_back(
  3432. const ggml_compute_params * params,
  3433. ggml_tensor * dst) {
  3434. const ggml_tensor * src0 = dst->src[0];
  3435. switch (src0->type) {
  3436. case GGML_TYPE_F32:
  3437. {
  3438. ggml_compute_forward_rms_norm_back_f32(params, dst);
  3439. } break;
  3440. default:
  3441. {
  3442. GGML_ABORT("fatal error");
  3443. }
  3444. }
  3445. }
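// Illustrative sketch (not part of the build): the derivation in the comments above collapses
// to dx = (dz + x * (-sum_xdz/sum_eps)) * rrms, which is exactly what the four ggml_vec_*
// calls apply. The same result written as one element-wise loop over a plain float row, to
// make the closed form visible (hypothetical names):
#if 0
#include <math.h>

static void rms_norm_back_row_sketch(const float * x, const float * dz,
                                     float * dx, int n, float eps) {
    double sum_xx = 0.0, sum_xdz = 0.0;
    for (int i = 0; i < n; i++) {
        sum_xx  += (double) x[i]*x[i];
        sum_xdz += (double) x[i]*dz[i];
    }
    const float mean_eps = (float)(sum_xx/n) + eps;   // mean(x^2) + eps
    const float sum_eps  = (float) sum_xx + eps*n;    // sum(x^2) + eps*N
    const float rrms     = 1.0f/sqrtf(mean_eps);
    const float c        = (float)(-sum_xdz)/sum_eps; // coefficient on x
    for (int i = 0; i < n; i++) {
        dx[i] = (dz[i] + x[i]*c) * rrms;
    }
}
#endif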
  3446. // ggml_compute_forward_group_norm
  3447. static void ggml_compute_forward_group_norm_f32(
  3448. const ggml_compute_params * params,
  3449. ggml_tensor * dst) {
  3450. const ggml_tensor * src0 = dst->src[0];
  3451. GGML_ASSERT(ggml_are_same_shape(src0, dst));
  3452. GGML_ASSERT(src0->nb[0] == sizeof(float));
  3453. const int ith = params->ith;
  3454. const int nth = params->nth;
  3455. GGML_TENSOR_UNARY_OP_LOCALS
  3456. // TODO: optimize
  3457. float eps;
  3458. memcpy(&eps, dst->op_params + 1, sizeof(float));
  3459. int n_channels = src0->ne[2];
  3460. int n_groups = dst->op_params[0];
  3461. int n_channels_per_group = (n_channels + n_groups - 1) / n_groups;
  3462. for (int i = ith; i < n_groups; i += nth) {
  3463. int start = i * n_channels_per_group;
  3464. int end = start + n_channels_per_group;
  3465. if (end > n_channels) {
  3466. end = n_channels;
  3467. }
  3468. int step = end - start;
  3469. for (int64_t i03 = 0; i03 < ne03; i03++) {
  3470. ggml_float sum = 0.0;
  3471. for (int64_t i02 = start; i02 < end; i02++) {
  3472. for (int64_t i01 = 0; i01 < ne01; i01++) {
  3473. const float * x = (float *)((char *) src0->data + i01 * nb01 + i02 * nb02 + i03 * nb03);
  3474. ggml_float sumr = 0.0;
  3475. for (int64_t i00 = 0; i00 < ne00; i00++) {
  3476. sumr += (ggml_float)x[i00];
  3477. }
  3478. sum += sumr;
  3479. }
  3480. }
  3481. const float mean = sum / (ne00 * ne01 * step);
  3482. ggml_float sum2 = 0.0;
  3483. for (int64_t i02 = start; i02 < end; i02++) {
  3484. for (int64_t i01 = 0; i01 < ne01; i01++) {
  3485. const float * x = (float *)((char *) src0->data + i01 * nb01 + i02 * nb02 + i03 * nb03);
  3486. float * y = (float *)((char *) dst->data + i01 * nb1 + i02 * nb2 + i03 * nb3);
  3487. ggml_float sumr = 0.0;
  3488. for (int64_t i00 = 0; i00 < ne00; i00++) {
  3489. float v = x[i00] - mean;
  3490. y[i00] = v;
  3491. sumr += (ggml_float)(v * v);
  3492. }
  3493. sum2 += sumr;
  3494. }
  3495. }
  3496. const float variance = sum2 / (ne00 * ne01 * step);
  3497. const float scale = 1.0f / sqrtf(variance + eps);
  3498. for (int64_t i02 = start; i02 < end; i02++) {
  3499. for (int64_t i01 = 0; i01 < ne01; i01++) {
  3500. float * y = (float *)((char *) dst->data + i01 * nb1 + i02 * nb2 + i03 * nb3);
  3501. ggml_vec_scale_f32(ne00, y, scale);
  3502. }
  3503. }
  3504. }
  3505. }
  3506. }
  3507. void ggml_compute_forward_group_norm(
  3508. const ggml_compute_params * params,
  3509. ggml_tensor * dst) {
  3510. const ggml_tensor * src0 = dst->src[0];
  3511. switch (src0->type) {
  3512. case GGML_TYPE_F32:
  3513. {
  3514. ggml_compute_forward_group_norm_f32(params, dst);
  3515. } break;
  3516. default:
  3517. {
  3518. GGML_ABORT("fatal error");
  3519. }
  3520. }
  3521. }
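// Illustrative sketch (not part of the build): group normalization above assigns
// n_channels_per_group = ceil(n_channels/n_groups) consecutive channels to each group and
// clamps the last group at n_channels. The channel range for group g, written out on its
// own (hypothetical names):
#if 0
static void group_channel_range_sketch(int n_channels, int n_groups, int g,
                                       int * start, int * end) {
    const int per_group = (n_channels + n_groups - 1)/n_groups;  // ceil division
    *start = g*per_group;
    *end   = *start + per_group;
    if (*end > n_channels) {
        *end = n_channels;  // last group may be smaller
    }
}
#endif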
  3522. // ggml_compute_forward_l2_norm
  3523. static void ggml_compute_forward_l2_norm_f32(
  3524. const ggml_compute_params * params,
  3525. ggml_tensor * dst) {
  3526. const ggml_tensor * src0 = dst->src[0];
  3527. GGML_ASSERT(ggml_are_same_shape(src0, dst));
  3528. GGML_ASSERT(src0->nb[0] == sizeof(float));
  3529. const int ith = params->ith;
  3530. const int nth = params->nth;
  3531. GGML_TENSOR_UNARY_OP_LOCALS
  3532. float eps;
  3533. memcpy(&eps, dst->op_params, sizeof(float));
  3534. GGML_ASSERT(eps >= 0.0f);
  3535. // TODO: optimize
  3536. for (int64_t i03 = 0; i03 < ne03; i03++) {
  3537. for (int64_t i02 = 0; i02 < ne02; i02++) {
  3538. for (int64_t i01 = ith; i01 < ne01; i01 += nth) {
  3539. const float * x = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
  3540. ggml_float sum = 0.0;
  3541. for (int64_t i00 = 0; i00 < ne00; i00++) {
  3542. sum += (ggml_float)(x[i00] * x[i00]);
  3543. }
  3544. float * y = (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3);
  3545. memcpy(y, x, ne00 * sizeof(float));
  3546. const float scale = 1.0f/fmaxf(sqrtf(sum), eps);
  3547. ggml_vec_scale_f32(ne00, y, scale);
  3548. }
  3549. }
  3550. }
  3551. }
  3552. void ggml_compute_forward_l2_norm(
  3553. const ggml_compute_params * params,
  3554. ggml_tensor * dst) {
  3555. const ggml_tensor * src0 = dst->src[0];
  3556. switch (src0->type) {
  3557. case GGML_TYPE_F32:
  3558. {
  3559. ggml_compute_forward_l2_norm_f32(params, dst);
  3560. } break;
  3561. default:
  3562. {
  3563. GGML_ABORT("fatal error");
  3564. }
  3565. }
  3566. }
  3567. // ggml_compute_forward_out_prod
  3568. static void ggml_compute_forward_out_prod_f32(
  3569. const ggml_compute_params * params,
  3570. ggml_tensor * dst) {
  3571. const ggml_tensor * src0 = dst->src[0];
  3572. const ggml_tensor * src1 = dst->src[1];
  3573. GGML_TENSOR_BINARY_OP_LOCALS
  3574. GGML_ASSERT(dst->type == GGML_TYPE_F32);
  3575. GGML_ASSERT(src0->type == GGML_TYPE_F32);
  3576. GGML_ASSERT(src1->type == GGML_TYPE_F32);
  3577. const int ith = params->ith;
  3578. const int nth = params->nth;
  3579. GGML_ASSERT(ne0 == ne00);
  3580. GGML_ASSERT(ne1 == ne10);
  3581. GGML_ASSERT(ne2 == ne12);
  3582. GGML_ASSERT(ne3 == ne13);
  3583. GGML_ASSERT(ne2 % ne02 == 0);
  3584. GGML_ASSERT(ne3 % ne03 == 0);
  3585. // we don't support permuted src0 or src1
  3586. GGML_ASSERT(nb00 == sizeof(float));
  3587. // dst cannot be transposed or permuted
  3588. GGML_ASSERT(nb0 == sizeof(float));
  3589. // GGML_ASSERT(nb0 <= nb1);
  3590. // GGML_ASSERT(nb1 <= nb2);
  3591. // GGML_ASSERT(nb2 <= nb3);
  3592. // nb01 >= nb00 - src0 is not transposed
  3593. // compute by src0 rows
  3594. if (ith == 0) {
  3595. ggml_vec_set_f32(ne0*ne1*ne2*ne3, (float *)dst->data, 0);
  3596. }
  3597. ggml_barrier(params->threadpool);
  3598. // dst[:,:,:,:] = 0
  3599. // for i2,i3:
  3600. // for i1:
  3601. // for i01:
  3602. // for i0:
  3603. // dst[i0,i1,i2,i3] += src0[i0,i01,i2,i3] * src1[i1,i01,i2,i3]
  3604. // parallelize by last three dimensions
  3605. // total rows in dst
  3606. const int64_t nr = ne1*ne2*ne3;
  3607. // rows per thread
  3608. const int64_t dr = (nr + nth - 1)/nth;
  3609. // row range for this thread
  3610. const int64_t ir0 = dr*ith;
  3611. const int64_t ir1 = MIN(ir0 + dr, nr);
  3612. // block-tiling attempt
  3613. const int64_t blck_0 = MAX(GGML_VEC_MAD_UNROLL, 32);
  3614. const int64_t blck_1 = 16;
  3615. // dps == dst per src0, used for group query attention
  3616. const int64_t dps2 = ne2 / ne02;
  3617. const int64_t dps3 = ne3 / ne03;
  3618. for (int64_t bir = ir0; bir < ir1; bir += blck_1) {
  3619. const int64_t bir1 = MIN(bir + blck_1, ir1);
  3620. for (int64_t bi01 = 0; bi01 < ne01; bi01 += blck_0) {
  3621. const int64_t bne01 = MIN(bi01 + blck_0, ne01);
  3622. for (int64_t ir = bir; ir < bir1; ++ir) {
  3623. // dst indices
  3624. const int64_t i3 = ir/(ne2*ne1);
  3625. const int64_t i2 = (ir - i3*ne2*ne1)/ne1;
  3626. const int64_t i1 = (ir - i3*ne2*ne1 - i2*ne1);
  3627. const int64_t i02 = i2 / dps2;
  3628. const int64_t i03 = i3 / dps3;
  3629. //const int64_t i10 = i1;
  3630. const int64_t i12 = i2;
  3631. const int64_t i13 = i3;
  3632. #if GGML_VEC_MAD_UNROLL > 2
  3633. const int64_t bne01_unroll = bne01 - (bne01 % GGML_VEC_MAD_UNROLL);
  3634. for (int64_t i01 = bi01; i01 < bne01_unroll; i01 += GGML_VEC_MAD_UNROLL) {
  3635. const int64_t i11 = i01;
  3636. float * s0 = (float *) ((char *) src0->data + ( i01*nb01 + i02*nb02 + i03*nb03));
  3637. float * s1 = (float *) ((char *) src1->data + (i1*nb10 + i11*nb11 + i12*nb12 + i13*nb13));
  3638. float * d = (float *) ((char *) dst->data + ( i1*nb1 + i2*nb2 + i3*nb3));
  3639. ggml_vec_mad_f32_unroll(ne0, nb01, nb11, d, s0, s1);
  3640. }
  3641. for (int64_t i01 = bne01_unroll; i01 < bne01; ++i01) {
  3642. const int64_t i11 = i01;
  3643. float * s0 = (float *) ((char *) src0->data + ( i01*nb01 + i02*nb02 + i03*nb03));
  3644. float * s1 = (float *) ((char *) src1->data + (i1*nb10 + i11*nb11 + i12*nb12 + i13*nb13));
  3645. float * d = (float *) ((char *) dst->data + ( i1*nb1 + i2*nb2 + i3*nb3));
  3646. ggml_vec_mad_f32(ne0, d, s0, *s1);
  3647. }
  3648. #else
  3649. for (int64_t i01 = bi01; i01 < bne01; ++i01) {
  3650. const int64_t i11 = i01;
  3651. float * s0 = (float *) ((char *) src0->data + ( i01*nb01 + i02*nb02 + i03*nb03));
  3652. float * s1 = (float *) ((char *) src1->data + (i1*nb10 + i11*nb11 + i12*nb12 + i13*nb13));
  3653. float * d = (float *) ((char *) dst->data + ( i1*nb1 + i2*nb2 + i3*nb3));
  3654. ggml_vec_mad_f32(ne0, d, s0, *s1);
  3655. }
  3656. #endif
  3657. }
  3658. }
  3659. }
  3660. }
  3661. static void ggml_compute_forward_out_prod_q_f32(
  3662. const ggml_compute_params * params,
  3663. ggml_tensor * dst) {
  3664. const ggml_tensor * src0 = dst->src[0];
  3665. const ggml_tensor * src1 = dst->src[1];
  3666. GGML_TENSOR_BINARY_OP_LOCALS;
  3667. const int ith = params->ith;
  3668. const int nth = params->nth;
  3669. const ggml_type type = src0->type;
  3670. ggml_to_float_t const dequantize_row_q = ggml_get_type_traits(type)->to_float;
  3671. GGML_ASSERT(ne02 == ne12);
  3672. GGML_ASSERT(ne03 == ne13);
  3673. GGML_ASSERT(ne2 == ne12);
  3674. GGML_ASSERT(ne3 == ne13);
  3675. // we don't support permuted src0 dim0
  3676. GGML_ASSERT(nb00 == ggml_type_size(type));
  3677. // dst dim0 cannot be transposed or permuted
  3678. GGML_ASSERT(nb0 == sizeof(float));
  3679. // GGML_ASSERT(nb0 <= nb1);
  3680. // GGML_ASSERT(nb1 <= nb2);
  3681. // GGML_ASSERT(nb2 <= nb3);
  3682. GGML_ASSERT(ne0 == ne00);
  3683. GGML_ASSERT(ne1 == ne10);
  3684. GGML_ASSERT(ne2 == ne02);
  3685. GGML_ASSERT(ne3 == ne03);
  3686. // nb01 >= nb00 - src0 is not transposed
  3687. // compute by src0 rows
  3688. if (ith == 0) {
  3689. ggml_vec_set_f32(ne0*ne1*ne2*ne3, (float *)dst->data, 0);
  3690. }
  3691. ggml_barrier(params->threadpool);
  3692. // parallelize by last three dimensions
  3693. // total rows in dst
  3694. const int64_t nr = ne1*ne2*ne3;
  3695. // rows per thread
  3696. const int64_t dr = (nr + nth - 1)/nth;
  3697. // row range for this thread
  3698. const int64_t ir0 = dr*ith;
  3699. const int64_t ir1 = MIN(ir0 + dr, nr);
  3700. // dst[:,:,:,:] = 0
  3701. // for i2,i3:
  3702. // for i1:
  3703. // for i01:
  3704. // for i0:
  3705. // dst[i0,i1,i2,i3] += src0[i0,i01,i2,i3] * src1[i1,i01,i2,i3]
  3706. float * wdata = (float *) params->wdata + (ne0 + CACHE_LINE_SIZE_F32) * ith;
  3707. for (int64_t ir = ir0; ir < ir1; ++ir) {
  3708. // dst indices
  3709. const int64_t i3 = ir/(ne2*ne1);
  3710. const int64_t i2 = (ir - i3*ne2*ne1)/ne1;
  3711. const int64_t i1 = (ir - i3*ne2*ne1 - i2*ne1);
  3712. const int64_t i02 = i2;
  3713. const int64_t i03 = i3;
  3714. //const int64_t i10 = i1;
  3715. const int64_t i12 = i2;
  3716. const int64_t i13 = i3;
  3717. for (int64_t i01 = 0; i01 < ne01; ++i01) {
  3718. const int64_t i11 = i01;
  3719. float * s0 = (float *) ((char *) src0->data + ( i01*nb01 + i02*nb02 + i03*nb03));
  3720. float * s1 = (float *) ((char *) src1->data + (i1*nb10 + i11*nb11 + i12*nb12 + i13*nb13));
  3721. float * d = (float *) ((char *) dst->data + ( i1*nb1 + i2*nb2 + i3*nb3));
  3722. dequantize_row_q(s0, wdata, ne0);
  3723. ggml_vec_mad_f32(ne0, d, wdata, *s1);
  3724. }
  3725. }
  3726. }
  3727. void ggml_compute_forward_out_prod(
  3728. const ggml_compute_params * params,
  3729. ggml_tensor * dst) {
  3730. const ggml_tensor * src0 = dst->src[0];
  3731. switch (src0->type) {
  3732. case GGML_TYPE_Q4_0:
  3733. case GGML_TYPE_Q4_1:
  3734. case GGML_TYPE_Q5_0:
  3735. case GGML_TYPE_Q5_1:
  3736. case GGML_TYPE_Q8_0:
  3737. case GGML_TYPE_Q2_K:
  3738. case GGML_TYPE_Q3_K:
  3739. case GGML_TYPE_Q4_K:
  3740. case GGML_TYPE_Q5_K:
  3741. case GGML_TYPE_Q6_K:
  3742. case GGML_TYPE_TQ1_0:
  3743. case GGML_TYPE_TQ2_0:
  3744. case GGML_TYPE_IQ2_XXS:
  3745. case GGML_TYPE_IQ2_XS:
  3746. case GGML_TYPE_IQ3_XXS:
  3747. case GGML_TYPE_IQ1_S:
  3748. case GGML_TYPE_IQ1_M:
  3749. case GGML_TYPE_IQ4_NL:
  3750. case GGML_TYPE_IQ4_XS:
  3751. case GGML_TYPE_IQ3_S:
  3752. case GGML_TYPE_IQ2_S:
  3753. {
  3754. ggml_compute_forward_out_prod_q_f32(params, dst);
  3755. } break;
  3756. case GGML_TYPE_F16:
  3757. {
  3758. GGML_ABORT("fatal error"); // todo
  3759. // ggml_compute_forward_out_prod_f16_f32(params, dst);
  3760. }
  3761. case GGML_TYPE_F32:
  3762. {
  3763. ggml_compute_forward_out_prod_f32(params, dst);
  3764. } break;
  3765. default:
  3766. {
  3767. GGML_ABORT("fatal error");
  3768. }
  3769. }
  3770. }
  3771. // ggml_compute_forward_scale
  3772. static void ggml_compute_forward_scale_f32(
  3773. const ggml_compute_params * params,
  3774. ggml_tensor * dst) {
  3775. const ggml_tensor * src0 = dst->src[0];
  3776. GGML_ASSERT(ggml_is_contiguous(src0));
  3777. GGML_ASSERT(ggml_is_contiguous(dst));
  3778. GGML_ASSERT(ggml_are_same_shape(src0, dst));
  3779. // scale factor
  3780. float v;
  3781. memcpy(&v, dst->op_params, sizeof(float));
  3782. const int ith = params->ith;
  3783. const int nth = params->nth;
  3784. const int nc = src0->ne[0];
  3785. const int nr = ggml_nrows(src0);
  3786. // rows per thread
  3787. const int dr = (nr + nth - 1)/nth;
  3788. // row range for this thread
  3789. const int ir0 = dr*ith;
  3790. const int ir1 = MIN(ir0 + dr, nr);
  3791. const size_t nb01 = src0->nb[1];
  3792. const size_t nb1 = dst->nb[1];
  3793. for (int i1 = ir0; i1 < ir1; i1++) {
  3794. if (dst->data != src0->data) {
  3795. // src0 is same shape as dst => same indices
  3796. memcpy((char *)dst->data + i1*nb1, (char *)src0->data + i1*nb01, nc * sizeof(float));
  3797. }
  3798. ggml_vec_scale_f32(nc, (float *) ((char *) dst->data + i1*nb1), v);
  3799. }
  3800. }
  3801. void ggml_compute_forward_scale(
  3802. const ggml_compute_params * params,
  3803. ggml_tensor * dst) {
  3804. const ggml_tensor * src0 = dst->src[0];
  3805. switch (src0->type) {
  3806. case GGML_TYPE_F32:
  3807. {
  3808. ggml_compute_forward_scale_f32(params, dst);
  3809. } break;
  3810. default:
  3811. {
  3812. GGML_ABORT("fatal error");
  3813. }
  3814. }
  3815. }
  3816. // ggml_compute_forward_set
  3817. static void ggml_compute_forward_set_f32(
  3818. const ggml_compute_params * params,
  3819. ggml_tensor * dst) {
  3820. const ggml_tensor * src0 = dst->src[0];
  3821. const ggml_tensor * src1 = dst->src[1];
  3822. GGML_ASSERT(ggml_are_same_shape(src0, dst));
  3823. GGML_ASSERT(ggml_is_contiguous(dst) && ggml_is_contiguous(src0));
  3824. // view src0 and dst with these strides and data offset inbytes during set
  3825. // nb0 is implicitly element_size because src0 and dst are contiguous
  3826. size_t nb1 = ((int32_t *) dst->op_params)[0];
  3827. size_t nb2 = ((int32_t *) dst->op_params)[1];
  3828. size_t nb3 = ((int32_t *) dst->op_params)[2];
  3829. size_t offset = ((int32_t *) dst->op_params)[3];
  3830. bool inplace = (bool) ((int32_t *) dst->op_params)[4];
  3831. if (!inplace) {
  3832. if (params->ith == 0) {
  3833. // memcpy needs to be synchronized across threads to avoid race conditions.
  3834. // => do it in INIT phase
  3835. memcpy(
  3836. ((char *) dst->data),
  3837. ((char *) src0->data),
  3838. ggml_nbytes(dst));
  3839. }
  3840. ggml_barrier(params->threadpool);
  3841. }
  3842. const int ith = params->ith;
  3843. const int nth = params->nth;
  3844. const int nr = ggml_nrows(src1);
  3845. const int nc = src1->ne[0];
  3846. GGML_TENSOR_LOCALS(int64_t, ne1, src1, ne)
  3847. GGML_TENSOR_LOCALS(size_t, nb1, src1, nb)
  3848. // src0 and dst as viewed during set
  3849. const size_t nb0 = ggml_element_size(src0);
  3850. const int im0 = (ne10 == 0 ? 0 : ne10-1);
  3851. const int im1 = (ne11 == 0 ? 0 : ne11-1);
  3852. const int im2 = (ne12 == 0 ? 0 : ne12-1);
  3853. const int im3 = (ne13 == 0 ? 0 : ne13-1);
  3854. GGML_ASSERT(offset + im0*nb0 + im1*nb1 + im2*nb2 + im3*nb3 <= ggml_nbytes(dst));
  3855. GGML_ASSERT(nb10 == sizeof(float));
  3856. // rows per thread
  3857. const int dr = (nr + nth - 1)/nth;
  3858. // row range for this thread
  3859. const int ir0 = dr*ith;
  3860. const int ir1 = MIN(ir0 + dr, nr);
  3861. for (int ir = ir0; ir < ir1; ++ir) {
  3862. // src0 and dst are viewed with shape of src1 and offset
  3863. // => same indices
  3864. const int i3 = ir/(ne12*ne11);
  3865. const int i2 = (ir - i3*ne12*ne11)/ne11;
  3866. const int i1 = (ir - i3*ne12*ne11 - i2*ne11);
  3867. ggml_vec_cpy_f32(nc,
  3868. (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + offset),
  3869. (float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11));
  3870. }
  3871. }
  3872. static void ggml_compute_forward_set_i32(
  3873. const ggml_compute_params * params,
  3874. ggml_tensor * dst) {
  3875. const ggml_tensor * src0 = dst->src[0];
  3876. const ggml_tensor * src1 = dst->src[1];
  3877. GGML_ASSERT(ggml_are_same_shape(src0, dst));
  3878. GGML_ASSERT(ggml_is_contiguous(dst) && ggml_is_contiguous(src0));
  3879. // view src0 and dst with these strides and data offset inbytes during set
  3880. // nb0 is implicitly element_size because src0 and dst are contiguous
  3881. size_t nb1 = ((int32_t *) dst->op_params)[0];
  3882. size_t nb2 = ((int32_t *) dst->op_params)[1];
  3883. size_t nb3 = ((int32_t *) dst->op_params)[2];
  3884. size_t offset = ((int32_t *) dst->op_params)[3];
  3885. bool inplace = (bool) ((int32_t *) dst->op_params)[4];
  3886. if (!inplace) {
  3887. if (params->ith == 0) {
  3888. // memcpy needs to be synchronized across threads to avoid race conditions.
  3889. // => do it in INIT phase
  3890. memcpy(
  3891. ((char *) dst->data),
  3892. ((char *) src0->data),
  3893. ggml_nbytes(dst));
  3894. }
  3895. ggml_barrier(params->threadpool);
  3896. }
  3897. const int ith = params->ith;
  3898. const int nth = params->nth;
  3899. const int nr = ggml_nrows(src1);
  3900. const int nc = src1->ne[0];
  3901. GGML_TENSOR_LOCALS(int64_t, ne1, src1, ne)
  3902. GGML_TENSOR_LOCALS(size_t, nb1, src1, nb)
  3903. // src0 and dst as viewed during set
  3904. const size_t nb0 = ggml_element_size(src0);
  3905. const int im0 = (ne10 == 0 ? 0 : ne10-1);
  3906. const int im1 = (ne11 == 0 ? 0 : ne11-1);
  3907. const int im2 = (ne12 == 0 ? 0 : ne12-1);
  3908. const int im3 = (ne13 == 0 ? 0 : ne13-1);
  3909. GGML_ASSERT(offset + im0*nb0 + im1*nb1 + im2*nb2 + im3*nb3 <= ggml_nbytes(dst));
  3910. GGML_ASSERT(nb10 == sizeof(int32_t));
  3911. // rows per thread
  3912. const int dr = (nr + nth - 1)/nth;
  3913. // row range for this thread
  3914. const int ir0 = dr*ith;
  3915. const int ir1 = MIN(ir0 + dr, nr);
  3916. for (int ir = ir0; ir < ir1; ++ir) {
  3917. // src0 and dst are viewed with shape of src1 and offset
  3918. // => same indices
  3919. const int i3 = ir/(ne12*ne11);
  3920. const int i2 = (ir - i3*ne12*ne11)/ne11;
  3921. const int i1 = (ir - i3*ne12*ne11 - i2*ne11);
  3922. ggml_vec_cpy_i32(nc,
  3923. (int32_t *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + offset),
  3924. (int32_t *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11));
  3925. }
  3926. }
  3927. void ggml_compute_forward_set(
  3928. const ggml_compute_params * params,
  3929. ggml_tensor * dst) {
  3930. const ggml_tensor * src0 = dst->src[0];
  3931. switch (src0->type) {
  3932. case GGML_TYPE_F32:
  3933. {
  3934. ggml_compute_forward_set_f32(params, dst);
  3935. } break;
  3936. case GGML_TYPE_I32:
  3937. {
  3938. ggml_compute_forward_set_i32(params, dst);
  3939. } break;
  3940. case GGML_TYPE_F16:
  3941. case GGML_TYPE_BF16:
  3942. case GGML_TYPE_Q4_0:
  3943. case GGML_TYPE_Q4_1:
  3944. case GGML_TYPE_Q5_0:
  3945. case GGML_TYPE_Q5_1:
  3946. case GGML_TYPE_Q8_0:
  3947. case GGML_TYPE_Q8_1:
  3948. case GGML_TYPE_Q2_K:
  3949. case GGML_TYPE_Q3_K:
  3950. case GGML_TYPE_Q4_K:
  3951. case GGML_TYPE_Q5_K:
  3952. case GGML_TYPE_Q6_K:
  3953. case GGML_TYPE_TQ1_0:
  3954. case GGML_TYPE_TQ2_0:
  3955. case GGML_TYPE_IQ2_XXS:
  3956. case GGML_TYPE_IQ2_XS:
  3957. case GGML_TYPE_IQ3_XXS:
  3958. case GGML_TYPE_IQ1_S:
  3959. case GGML_TYPE_IQ1_M:
  3960. case GGML_TYPE_IQ4_NL:
  3961. case GGML_TYPE_IQ4_XS:
  3962. case GGML_TYPE_IQ3_S:
  3963. case GGML_TYPE_IQ2_S:
  3964. default:
  3965. {
  3966. GGML_ABORT("fatal error");
  3967. }
  3968. }
  3969. }
  3970. // ggml_compute_forward_cpy
  3971. void ggml_compute_forward_cpy(
  3972. const ggml_compute_params * params,
  3973. ggml_tensor * dst) {
  3974. ggml_compute_forward_dup(params, dst);
  3975. }
  3976. // ggml_compute_forward_cont
  3977. void ggml_compute_forward_cont(
  3978. const ggml_compute_params * params,
  3979. ggml_tensor * dst) {
  3980. ggml_compute_forward_dup(params, dst);
  3981. }
  3982. // ggml_compute_forward_reshape
  3983. void ggml_compute_forward_reshape(
  3984. const ggml_compute_params * params,
  3985. ggml_tensor * dst) {
  3986. // NOP
  3987. GGML_UNUSED(params);
  3988. GGML_UNUSED(dst);
  3989. }
  3990. // ggml_compute_forward_view
  3991. void ggml_compute_forward_view(
  3992. const ggml_compute_params * params,
  3993. ggml_tensor * dst) {
  3994. // NOP
  3995. GGML_UNUSED(params);
  3996. GGML_UNUSED(dst);
  3997. }
  3998. // ggml_compute_forward_permute
  3999. void ggml_compute_forward_permute(
  4000. const ggml_compute_params * params,
  4001. ggml_tensor * dst) {
  4002. // NOP
  4003. GGML_UNUSED(params);
  4004. GGML_UNUSED(dst);
  4005. }
  4006. // ggml_compute_forward_transpose
  4007. void ggml_compute_forward_transpose(
  4008. const ggml_compute_params * params,
  4009. ggml_tensor * dst) {
  4010. // NOP
  4011. GGML_UNUSED(params);
  4012. GGML_UNUSED(dst);
  4013. }
  4014. // ggml_compute_forward_get_rows
  4015. static void ggml_compute_forward_get_rows_q(
  4016. const ggml_compute_params * params,
  4017. ggml_tensor * dst) {
  4018. const ggml_tensor * src0 = dst->src[0];
  4019. const ggml_tensor * src1 = dst->src[1];
  4020. GGML_TENSOR_BINARY_OP_LOCALS
  4021. const int64_t nc = ne00;
  4022. const int64_t nr = ggml_nelements(src1);
  4023. const ggml_type type = src0->type;
  4024. ggml_to_float_t const dequantize_row_q = ggml_get_type_traits(type)->to_float;
  4025. assert(ne0 == nc);
  4026. assert(ne02 == ne11);
  4027. assert(nb00 == ggml_type_size(type));
  4028. assert(ggml_nrows(dst) == nr);
  4029. const int ith = params->ith;
  4030. const int nth = params->nth;
  4031. // rows per thread
  4032. const int dr = (nr + nth - 1)/nth;
  4033. // row range for this thread
  4034. const int ir0 = dr*ith;
  4035. const int ir1 = MIN(ir0 + dr, nr);
  4036. for (int64_t i = ir0; i < ir1; ++i) {
  4037. const int64_t i12 = i/(ne11*ne10);
  4038. const int64_t i11 = (i - i12*ne11*ne10)/ne10;
  4039. const int64_t i10 = (i - i12*ne11*ne10 - i11*ne10);
  4040. const int64_t i01 = *(int32_t *) ((char *) src1->data + i10*nb10 + i11*nb11 + i12*nb12);
  4041. GGML_ASSERT(i01 >= 0 && i01 < ne01);
  4042. dequantize_row_q(
  4043. (const void *) ((char *) src0->data + i01*nb01 + i11*nb02 + i12*nb03),
  4044. (float *) ((char *) dst->data + i10*nb1 + i11*nb2 + i12*nb3), nc);
  4045. }
  4046. }
  4047. static void ggml_compute_forward_get_rows_f16(
  4048. const ggml_compute_params * params,
  4049. ggml_tensor * dst) {
  4050. const ggml_tensor * src0 = dst->src[0];
  4051. const ggml_tensor * src1 = dst->src[1];
  4052. GGML_TENSOR_BINARY_OP_LOCALS
  4053. const int64_t nc = ne00;
  4054. const int64_t nr = ggml_nelements(src1);
  4055. assert(ne0 == nc);
  4056. assert(ne02 == ne11);
  4057. assert(nb00 == sizeof(ggml_fp16_t));
  4058. assert(ggml_nrows(dst) == nr);
  4059. const int ith = params->ith;
  4060. const int nth = params->nth;
  4061. // rows per thread
  4062. const int dr = (nr + nth - 1)/nth;
  4063. // row range for this thread
  4064. const int ir0 = dr*ith;
  4065. const int ir1 = MIN(ir0 + dr, nr);
  4066. for (int64_t i = ir0; i < ir1; ++i) {
  4067. const int64_t i12 = i/(ne11*ne10);
  4068. const int64_t i11 = (i - i12*ne11*ne10)/ne10;
  4069. const int64_t i10 = (i - i12*ne11*ne10 - i11*ne10);
  4070. const int64_t i01 = *(int32_t *) ((char *) src1->data + i10*nb10 + i11*nb11 + i12*nb12);
  4071. GGML_ASSERT(i01 >= 0 && i01 < ne01);
  4072. ggml_cpu_fp16_to_fp32(
  4073. (const ggml_fp16_t*) ((char *) src0->data + i01*nb01 + i11*nb02 + i12*nb03),
  4074. (float *) ((char *) dst->data + i10*nb1 + i11*nb2 + i12*nb3), nc);
  4075. }
  4076. }
  4077. static void ggml_compute_forward_get_rows_bf16(
  4078. const ggml_compute_params * params,
  4079. ggml_tensor * dst) {
  4080. const ggml_tensor * src0 = dst->src[0];
  4081. const ggml_tensor * src1 = dst->src[1];
  4082. GGML_TENSOR_BINARY_OP_LOCALS
  4083. const int64_t nc = ne00;
  4084. const int64_t nr = ggml_nelements(src1);
  4085. assert(ne0 == nc);
  4086. assert(ne02 == ne11);
  4087. assert(nb00 == sizeof(ggml_bf16_t));
  4088. assert(ggml_nrows(dst) == nr);
  4089. const int ith = params->ith;
  4090. const int nth = params->nth;
  4091. // rows per thread
  4092. const int dr = (nr + nth - 1)/nth;
  4093. // row range for this thread
  4094. const int ir0 = dr*ith;
  4095. const int ir1 = MIN(ir0 + dr, nr);
  4096. for (int64_t i = ir0; i < ir1; ++i) {
  4097. const int64_t i12 = i/(ne11*ne10);
  4098. const int64_t i11 = (i - i12*ne11*ne10)/ne10;
  4099. const int64_t i10 = (i - i12*ne11*ne10 - i11*ne10);
  4100. const int64_t i01 = *(int32_t *) ((char *) src1->data + i10*nb10 + i11*nb11 + i12*nb12);
  4101. GGML_ASSERT(i01 >= 0 && i01 < ne01);
  4102. ggml_cpu_bf16_to_fp32(
  4103. (const ggml_bf16_t *) ((char *) src0->data + i01*nb01 + i11*nb02 + i12*nb03),
  4104. (float *) ((char *) dst->data + i10*nb1 + i11*nb2 + i12*nb3), nc);
  4105. }
  4106. }
  4107. static void ggml_compute_forward_get_rows_f32(
  4108. const ggml_compute_params * params,
  4109. ggml_tensor * dst) {
  4110. const ggml_tensor * src0 = dst->src[0];
  4111. const ggml_tensor * src1 = dst->src[1];
  4112. GGML_TENSOR_BINARY_OP_LOCALS
  4113. const int64_t nc = ne00;
  4114. const int64_t nr = ggml_nelements(src1);
  4115. assert(ne0 == nc);
  4116. assert(ne02 == ne11);
  4117. assert(nb00 == sizeof(float));
  4118. assert(ggml_nrows(dst) == nr);
  4119. const int ith = params->ith;
  4120. const int nth = params->nth;
  4121. // rows per thread
  4122. const int dr = (nr + nth - 1)/nth;
  4123. // row range for this thread
  4124. const int ir0 = dr*ith;
  4125. const int ir1 = MIN(ir0 + dr, nr);
  4126. for (int64_t i = ir0; i < ir1; ++i) {
  4127. const int64_t i12 = i/(ne11*ne10);
  4128. const int64_t i11 = (i - i12*ne11*ne10)/ne10;
  4129. const int64_t i10 = (i - i12*ne11*ne10 - i11*ne10);
  4130. const int64_t i01 = *(int32_t *) ((char *) src1->data + i10*nb10 + i11*nb11 + i12*nb12);
  4131. GGML_ASSERT(i01 >= 0 && i01 < ne01);
  4132. ggml_vec_cpy_f32(nc,
  4133. (float *) ((char *) dst->data + i10*nb1 + i11*nb2 + i12*nb3),
  4134. (float *) ((char *) src0->data + i01*nb01 + i11*nb02 + i12*nb03));
  4135. }
  4136. }
  4137. void ggml_compute_forward_get_rows(
  4138. const ggml_compute_params * params,
  4139. ggml_tensor * dst) {
  4140. const ggml_tensor * src0 = dst->src[0];
  4141. switch (src0->type) {
  4142. case GGML_TYPE_Q4_0:
  4143. case GGML_TYPE_Q4_1:
  4144. case GGML_TYPE_Q5_0:
  4145. case GGML_TYPE_Q5_1:
  4146. case GGML_TYPE_Q8_0:
  4147. case GGML_TYPE_Q8_1:
  4148. case GGML_TYPE_Q2_K:
  4149. case GGML_TYPE_Q3_K:
  4150. case GGML_TYPE_Q4_K:
  4151. case GGML_TYPE_Q5_K:
  4152. case GGML_TYPE_Q6_K:
  4153. case GGML_TYPE_TQ1_0:
  4154. case GGML_TYPE_TQ2_0:
  4155. case GGML_TYPE_IQ2_XXS:
  4156. case GGML_TYPE_IQ2_XS:
  4157. case GGML_TYPE_IQ3_XXS:
  4158. case GGML_TYPE_IQ1_S:
  4159. case GGML_TYPE_IQ1_M:
  4160. case GGML_TYPE_IQ4_NL:
  4161. case GGML_TYPE_IQ4_XS:
  4162. case GGML_TYPE_IQ3_S:
  4163. case GGML_TYPE_IQ2_S:
  4164. {
  4165. ggml_compute_forward_get_rows_q(params, dst);
  4166. } break;
  4167. case GGML_TYPE_F16:
  4168. {
  4169. ggml_compute_forward_get_rows_f16(params, dst);
  4170. } break;
  4171. case GGML_TYPE_BF16:
  4172. {
  4173. ggml_compute_forward_get_rows_bf16(params, dst);
  4174. } break;
  4175. case GGML_TYPE_F32:
  4176. case GGML_TYPE_I32:
  4177. {
  4178. ggml_compute_forward_get_rows_f32(params, dst);
  4179. } break;
  4180. default:
  4181. {
  4182. GGML_ABORT("fatal error");
  4183. }
  4184. }
  4185. //static bool first = true;
  4186. //printf("ne0 = %d, ne1 = %d, ne2 = %d\n", dst->ne[0], dst->ne[1], dst->ne[2]);
  4187. //if (first) {
  4188. // first = false;
  4189. //} else {
  4190. // for (int k = 0; k < dst->ne[1]; ++k) {
  4191. // for (int j = 0; j < dst->ne[0]/16; ++j) {
  4192. // for (int i = 0; i < 16; ++i) {
  4193. // printf("%8.4f ", ((float *) dst->data)[k*dst->ne[0] + j*16 + i]);
  4194. // }
  4195. // printf("\n");
  4196. // }
  4197. // printf("\n");
  4198. // }
  4199. // printf("\n");
  4200. // exit(0);
  4201. //}
  4202. }
  4203. static void ggml_compute_forward_set_rows_f32(
  4204. const ggml_compute_params * params,
  4205. ggml_tensor * dst) {
  4206. const ggml_tensor * src0 = dst->src[0];
  4207. const ggml_tensor * src1 = dst->src[1];
  4208. GGML_TENSOR_BINARY_OP_LOCALS
  4209. const int64_t nc = ne00;
  4210. const int64_t nr = ne01;
  4211. assert(ne0 == nc);
  4212. assert(ne2 == ne02);
  4213. assert(ne3 == ne03);
  4214. assert(src0->type == GGML_TYPE_F32);
  4215. assert(ne02 % ne11 == 0);
  4216. assert(ne03 % ne12 == 0);
  4217. const int ith = params->ith;
  4218. const int nth = params->nth;
  4219. // rows per thread
  4220. const int64_t dr = (nr + nth - 1)/nth;
  4221. // row range for this thread
  4222. const int64_t ir0 = dr*ith;
  4223. const int64_t ir1 = std::min(ir0 + dr, nr);
  4224. ggml_from_float_t const from_float = ggml_get_type_traits_cpu(dst->type)->from_float;
  4225. for (int64_t i03 = 0; i03 < ne03; ++i03) {
  4226. for (int64_t i02 = 0; i02 < ne02; ++i02) {
  4227. for (int64_t i = ir0; i < ir1; ++i) {
  4228. const int64_t i12 = i03%ne12;
  4229. const int64_t i11 = i02%ne11;
  4230. const int64_t i10 = i;
  4231. const int64_t i1 = *(int64_t *) ((char *) src1->data + i10*nb10 + i11*nb11 + i12*nb12);
  4232. GGML_ASSERT(i1 >= 0 && i1 < ne1);
  4233. from_float(
  4234. (const float *) ((char *) src0->data + i*nb01 + i02*nb02 + i03*nb03),
  4235. ((char *) dst->data + i1*nb1 + i02*nb2 + i03*nb3), nc);
  4236. }
  4237. }
  4238. }
  4239. }
  4240. void ggml_compute_forward_set_rows(
  4241. const ggml_compute_params * params,
  4242. ggml_tensor * dst) {
  4243. const ggml_tensor * src0 = dst->src[0];
  4244. switch (src0->type) {
  4245. case GGML_TYPE_F32:
  4246. {
  4247. ggml_compute_forward_set_rows_f32(params, dst);
  4248. } break;
  4249. default:
  4250. {
  4251. GGML_ABORT("src0->type = %d (%s) not supported", src0->type, ggml_type_name(src0->type));
  4252. }
  4253. }
  4254. }
  4255. // ggml_compute_forward_get_rows_back
  4256. static void ggml_compute_forward_get_rows_back_f32_f16(
  4257. const ggml_compute_params * params,
  4258. ggml_tensor * dst) {
  4259. const ggml_tensor * src0 = dst->src[0];
  4260. const ggml_tensor * src1 = dst->src[1];
  4261. if (params->ith != 0) {
  4262. return;
  4263. }
  4264. GGML_ASSERT(ggml_is_contiguous(dst));
  4265. // ggml_compute_forward_dup_same_cont(params, opt0, dst);
  4266. memset(dst->data, 0, ggml_nbytes(dst));
  4267. const int nc = src0->ne[0];
  4268. const int nr = ggml_nelements(src1);
  4269. GGML_ASSERT( dst->ne[0] == nc);
  4270. GGML_ASSERT(src0->nb[0] == sizeof(ggml_fp16_t));
  4271. for (int i = 0; i < nr; ++i) {
  4272. const int r = ((int32_t *) src1->data)[i];
  4273. for (int j = 0; j < nc; ++j) {
  4274. ggml_fp16_t v = ((ggml_fp16_t *) ((char *) src0->data + i*src0->nb[1]))[j];
  4275. ((float *) ((char *) dst->data + r*dst->nb[1]))[j] += GGML_CPU_FP16_TO_FP32(v);
  4276. }
  4277. }
  4278. }
  4279. static void ggml_compute_forward_get_rows_back_f32(
  4280. const ggml_compute_params * params,
  4281. ggml_tensor * dst) {
  4282. const ggml_tensor * src0 = dst->src[0];
  4283. const ggml_tensor * src1 = dst->src[1];
  4284. if (params->ith != 0) {
  4285. return;
  4286. }
  4287. GGML_ASSERT(ggml_is_contiguous(dst));
  4288. // ggml_compute_forward_dup_same_cont(params, opt0, dst);
  4289. memset(dst->data, 0, ggml_nbytes(dst));
  4290. const int nc = src0->ne[0];
  4291. const int nr = ggml_nelements(src1);
  4292. GGML_ASSERT( dst->ne[0] == nc);
  4293. GGML_ASSERT(src0->nb[0] == sizeof(float));
  4294. for (int i = 0; i < nr; ++i) {
  4295. const int r = ((int32_t *) src1->data)[i];
  4296. ggml_vec_add_f32(nc,
  4297. (float *) ((char *) dst->data + r*dst->nb[1]),
  4298. (float *) ((char *) dst->data + r*dst->nb[1]),
  4299. (float *) ((char *) src0->data + i*src0->nb[1]));
  4300. }
  4301. }
  4302. void ggml_compute_forward_get_rows_back(
  4303. const ggml_compute_params * params,
  4304. ggml_tensor * dst) {
  4305. const ggml_tensor * src0 = dst->src[0];
  4306. switch (src0->type) {
  4307. case GGML_TYPE_F16:
  4308. {
  4309. ggml_compute_forward_get_rows_back_f32_f16(params, dst);
  4310. } break;
  4311. case GGML_TYPE_F32:
  4312. {
  4313. ggml_compute_forward_get_rows_back_f32(params, dst);
  4314. } break;
  4315. default:
  4316. {
  4317. GGML_ABORT("fatal error");
  4318. }
  4319. }
  4320. //static bool first = true;
  4321. //printf("ne0 = %d, ne1 = %d, ne2 = %d\n", dst->ne[0], dst->ne[1], dst->ne[2]);
  4322. //if (first) {
  4323. // first = false;
  4324. //} else {
  4325. // for (int k = 0; k < dst->ne[1]; ++k) {
  4326. // for (int j = 0; j < dst->ne[0]/16; ++j) {
  4327. // for (int i = 0; i < 16; ++i) {
  4328. // printf("%8.4f ", ((float *) dst->data)[k*dst->ne[0] + j*16 + i]);
  4329. // }
  4330. // printf("\n");
  4331. // }
  4332. // printf("\n");
  4333. // }
  4334. // printf("\n");
  4335. // exit(0);
  4336. //}
  4337. }
  4338. // ggml_compute_forward_diag
  4339. static void ggml_compute_forward_diag_f32(
  4340. const ggml_compute_params * params,
  4341. ggml_tensor * dst) {
  4342. const ggml_tensor * src0 = dst->src[0];
  4343. if (params->ith != 0) {
  4344. return;
  4345. }
  4346. // TODO: handle transposed/permuted matrices
  4347. GGML_TENSOR_UNARY_OP_LOCALS
  4348. GGML_ASSERT(ne00 == ne0);
  4349. GGML_ASSERT(ne00 == ne1);
  4350. GGML_ASSERT(ne01 == 1);
  4351. GGML_ASSERT(ne02 == ne2);
  4352. GGML_ASSERT(ne03 == ne3);
  4353. GGML_ASSERT(nb00 == sizeof(float));
  4354. GGML_ASSERT(nb0 == sizeof(float));
  4355. for (int i3 = 0; i3 < ne3; i3++) {
  4356. for (int i2 = 0; i2 < ne2; i2++) {
  4357. for (int i1 = 0; i1 < ne1; i1++) {
  4358. float * d = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1);
  4359. float * s = (float *)((char *) src0->data + i3*nb03 + i2*nb02);
  4360. for (int i0 = 0; i0 < i1; i0++) {
  4361. d[i0] = 0;
  4362. }
  4363. d[i1] = s[i1];
  4364. for (int i0 = i1+1; i0 < ne0; i0++) {
  4365. d[i0] = 0;
  4366. }
  4367. }
  4368. }
  4369. }
  4370. }
  4371. void ggml_compute_forward_diag(
  4372. const ggml_compute_params * params,
  4373. ggml_tensor * dst) {
  4374. const ggml_tensor * src0 = dst->src[0];
  4375. switch (src0->type) {
  4376. case GGML_TYPE_F32:
  4377. {
  4378. ggml_compute_forward_diag_f32(params, dst);
  4379. } break;
  4380. default:
  4381. {
  4382. GGML_ABORT("fatal error");
  4383. }
  4384. }
  4385. }
  4386. // ggml_compute_forward_diag_mask_inf
  4387. static void ggml_compute_forward_diag_mask_f32(
  4388. const ggml_compute_params * params,
  4389. ggml_tensor * dst,
  4390. const float value) {
  4391. const ggml_tensor * src0 = dst->src[0];
  4392. const int ith = params->ith;
  4393. const int nth = params->nth;
  4394. const int n_past = ((int32_t *) dst->op_params)[0];
  4395. const bool inplace = src0->data == dst->data;
  4396. GGML_ASSERT(n_past >= 0);
  4397. if (!inplace) {
  4398. if (ith == 0) {
  4399. // memcpy needs to be synchronized across threads to avoid race conditions.
  4400. // => do it in INIT phase
  4401. GGML_ASSERT(ggml_nelements(dst) == ggml_nelements(src0));
  4402. GGML_ASSERT(ggml_is_contiguous(dst) && ggml_is_contiguous(src0));
  4403. memcpy(
  4404. ((char *) dst->data),
  4405. ((char *) src0->data),
  4406. ggml_nbytes(dst));
  4407. }
  4408. ggml_barrier(params->threadpool);
  4409. }
  4410. // TODO: handle transposed/permuted matrices
  4411. const int n = ggml_nrows(src0);
  4412. const int nc = src0->ne[0];
  4413. const int nr = src0->ne[1];
  4414. const int nz = n/nr;
  4415. GGML_ASSERT( dst->nb[0] == sizeof(float));
  4416. GGML_ASSERT(src0->nb[0] == sizeof(float));
  4417. for (int k = 0; k < nz; k++) {
  4418. for (int j = ith; j < nr; j += nth) {
  4419. for (int i = n_past; i < nc; i++) {
  4420. if (i > n_past + j) {
  4421. *(float *)((char *) dst->data + k*dst->nb[2] + j*dst->nb[1] + i*dst->nb[0]) = value;
  4422. }
  4423. }
  4424. }
  4425. }
  4426. }
  4427. void ggml_compute_forward_diag_mask_inf(
  4428. const ggml_compute_params * params,
  4429. ggml_tensor * dst) {
  4430. const ggml_tensor * src0 = dst->src[0];
  4431. switch (src0->type) {
  4432. case GGML_TYPE_F32:
  4433. {
  4434. ggml_compute_forward_diag_mask_f32(params, dst, -INFINITY);
  4435. } break;
  4436. default:
  4437. {
  4438. GGML_ABORT("fatal error");
  4439. }
  4440. }
  4441. }
  4442. void ggml_compute_forward_diag_mask_zero(
  4443. const ggml_compute_params * params,
  4444. ggml_tensor * dst) {
  4445. const ggml_tensor * src0 = dst->src[0];
  4446. switch (src0->type) {
  4447. case GGML_TYPE_F32:
  4448. {
  4449. ggml_compute_forward_diag_mask_f32(params, dst, 0);
  4450. } break;
  4451. default:
  4452. {
  4453. GGML_ABORT("fatal error");
  4454. }
  4455. }
  4456. }
  4457. // ggml_compute_forward_soft_max
  4458. static void ggml_compute_forward_soft_max_f32(
  4459. const ggml_compute_params * params,
  4460. ggml_tensor * dst) {
  4461. const ggml_tensor * src0 = dst->src[0];
  4462. const ggml_tensor * src1 = dst->src[1];
  4463. assert(ggml_is_contiguous(dst));
  4464. assert(ggml_are_same_shape(src0, dst));
  4465. float scale = 1.0f;
  4466. float max_bias = 0.0f;
  4467. memcpy(&scale, (float *) dst->op_params + 0, sizeof(float));
  4468. memcpy(&max_bias, (float *) dst->op_params + 1, sizeof(float));
  4469. const int ith = params->ith;
  4470. const int nth = params->nth;
  4471. GGML_TENSOR_UNARY_OP_LOCALS
  4472. const int64_t nb11 = src1 ? src1->nb[1] : 1;
  4473. const int64_t nb12 = src1 ? src1->nb[2] : 1;
  4474. const int64_t nb13 = src1 ? src1->nb[3] : 1;
  4475. const int64_t ne12 = src1 ? src1->ne[2] : 1;
  4476. const int64_t ne13 = src1 ? src1->ne[3] : 1;
  4477. // TODO: is this supposed to be ceil instead of floor?
  4478. // https://huggingface.co/mosaicml/mpt-7b/blob/main/attention.py#L370
  4479. const uint32_t n_head = ne02;
  4480. const uint32_t n_head_log2 = 1u << (uint32_t) floor(log2(n_head));
  4481. const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
  4482. const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
  4483. float * wp = (float *) params->wdata + (ne00 + CACHE_LINE_SIZE_F32) * ith;
  4484. const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16);
  4485. for (int64_t i03 = 0; i03 < ne03; i03++) {
  4486. for (int64_t i02 = 0; i02 < ne02; i02++) {
  4487. for (int64_t i01 = ith; i01 < ne01; i01 += nth) {
  4488. const int64_t i11 = i01;
  4489. const int64_t i12 = i02%ne12;
  4490. const int64_t i13 = i03%ne13;
  4491. // ALiBi
  4492. const uint32_t h = i02; // head
  4493. const float slope = (max_bias > 0.0f) ? h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2*(h - n_head_log2) + 1) : 1.0f;
  4494. float * sp = (float *)((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
  4495. float * dp = (float *)((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3);
  4496. // broadcast the mask across rows
  4497. ggml_fp16_t * mp_f16 = src1 ? (ggml_fp16_t *)((char *) src1->data + i11*nb11 + i12*nb12 + i13*nb13) : NULL;
  4498. float * mp_f32 = src1 ? (float *)((char *) src1->data + i11*nb11 + i12*nb12 + i13*nb13) : NULL;
  4499. ggml_vec_cpy_f32 (ne00, wp, sp);
  4500. ggml_vec_scale_f32(ne00, wp, scale);
  4501. if (mp_f32) {
  4502. if (use_f16) {
  4503. for (int i = 0; i < ne00; ++i) {
  4504. wp[i] += slope*GGML_CPU_FP16_TO_FP32(mp_f16[i]);
  4505. }
  4506. } else {
  4507. for (int i = 0; i < ne00; ++i) {
  4508. wp[i] += slope*mp_f32[i];
  4509. }
  4510. }
  4511. }
  4512. #ifndef NDEBUG
  4513. for (int i = 0; i < ne00; ++i) {
  4514. //printf("p[%d] = %f\n", i, p[i]);
  4515. assert(!isnan(wp[i]));
  4516. }
  4517. #endif
  4518. float max = -INFINITY;
  4519. ggml_vec_max_f32(ne00, &max, wp);
  4520. ggml_float sum = ggml_vec_soft_max_f32(ne00, dp, wp, max);
  4521. assert(sum > 0.0);
  4522. sum = 1.0/sum;
  4523. ggml_vec_scale_f32(ne00, dp, sum);
  4524. #ifndef NDEBUG
  4525. for (int i = 0; i < ne00; ++i) {
  4526. assert(!isnan(dp[i]));
  4527. assert(!isinf(dp[i]));
  4528. }
  4529. #endif
  4530. }
  4531. }
  4532. }
  4533. }
  4534. void ggml_compute_forward_soft_max(
  4535. const ggml_compute_params * params,
  4536. ggml_tensor * dst) {
  4537. const ggml_tensor * src0 = dst->src[0];
  4538. switch (src0->type) {
  4539. case GGML_TYPE_F32:
  4540. {
  4541. ggml_compute_forward_soft_max_f32(params, dst);
  4542. } break;
  4543. default:
  4544. {
  4545. GGML_ABORT("fatal error");
  4546. }
  4547. }
  4548. }
  4549. // ggml_compute_forward_soft_max_ext_back
  4550. static void ggml_compute_forward_soft_max_ext_back_f32(
  4551. const ggml_compute_params * params,
  4552. ggml_tensor * dst) {
  4553. const ggml_tensor * src0 = dst->src[0];
  4554. const ggml_tensor * src1 = dst->src[1];
  4555. GGML_ASSERT(ggml_is_contiguous(src0));
  4556. GGML_ASSERT(ggml_is_contiguous(src1));
  4557. GGML_ASSERT(ggml_is_contiguous(dst));
  4558. GGML_ASSERT(ggml_are_same_shape(src0, dst));
  4559. GGML_ASSERT(ggml_are_same_shape(src1, dst));
  4560. float scale = 1.0f;
  4561. float max_bias = 0.0f;
  4562. memcpy(&scale, (const float *) dst->op_params + 0, sizeof(float));
  4563. memcpy(&max_bias, (const float *) dst->op_params + 1, sizeof(float));
  4564. GGML_ASSERT(max_bias == 0.0f);
  4565. // TODO: handle transposed/permuted matrices
  4566. const int ith = params->ith;
  4567. const int nth = params->nth;
  4568. const int nc = src0->ne[0];
  4569. const int nr = ggml_nrows(src0);
  4570. // rows per thread
  4571. const int dr = (nr + nth - 1)/nth;
  4572. // row range for this thread
  4573. const int ir0 = dr*ith;
  4574. const int ir1 = MIN(ir0 + dr, nr);
  4575. for (int i1 = ir0; i1 < ir1; i1++) {
  4576. float *dy = (float *)((char *) src0->data + i1*src0->nb[1]);
  4577. float *y = (float *)((char *) src1->data + i1*src1->nb[1]);
  4578. float *dx = (float *)((char *) dst->data + i1*dst->nb[1]);
  4579. #ifndef NDEBUG
  4580. for (int i = 0; i < nc; ++i) {
  4581. //printf("p[%d] = %f\n", i, p[i]);
  4582. assert(!isnan(dy[i]));
  4583. assert(!isnan(y[i]));
  4584. }
  4585. #endif
  4586. // Jii = yi - yi*yi
  4587. // Jij = -yi*yj
  4588. // J = diag(y)-y.T*y
  4589. // dx = J * dy
  4590. // dxk = sum_i(Jki * dyi)
  4591. // dxk = sum_i(-yk*yi * dyi) - (-yk*yk)*dyk + (yk - yk*yk)*dyk
  4592. // dxk = sum_i(-yk*yi * dyi) + yk*yk*dyk + yk*dyk - yk*yk*dyk
  4593. // dxk = sum_i(-yk*yi * dyi) + yk*dyk
  4594. // dxk = -yk * sum_i(yi * dyi) + yk*dyk
  4595. // dxk = -yk * dot(y, dy) + yk*dyk
  4596. // dxk = yk * (- dot(y, dy) + dyk)
  4597. // dxk = yk * (dyk - dot(y, dy))
  4598. //
  4599. // post-order:
  4600. // dot_y_dy := dot(y, dy)
  4601. // dx := dy
  4602. // dx := dx - dot_y_dy
  4603. // dx := dx * y
  4604. // linear runtime, no additional memory
  4605. float dot_y_dy = 0;
  4606. ggml_vec_dot_f32 (nc, &dot_y_dy, 0, y, 0, dy, 0, 1);
  4607. ggml_vec_cpy_f32 (nc, dx, dy);
  4608. ggml_vec_acc1_f32 (nc, dx, -dot_y_dy);
  4609. ggml_vec_mul_f32 (nc, dx, dx, y);
  4610. ggml_vec_scale_f32(nc, dx, scale);
  4611. #ifndef NDEBUG
  4612. for (int i = 0; i < nc; ++i) {
  4613. assert(!isnan(dx[i]));
  4614. assert(!isinf(dx[i]));
  4615. }
  4616. #endif
  4617. }
  4618. }
  4619. void ggml_compute_forward_soft_max_ext_back(
  4620. const ggml_compute_params * params,
  4621. ggml_tensor * dst) {
  4622. const ggml_tensor * src0 = dst->src[0];
  4623. switch (src0->type) {
  4624. case GGML_TYPE_F32:
  4625. {
  4626. ggml_compute_forward_soft_max_ext_back_f32(params, dst);
  4627. } break;
  4628. default:
  4629. {
  4630. GGML_ABORT("fatal error");
  4631. }
  4632. }
  4633. }
  4634. // ggml_compute_forward_clamp
  4635. static void ggml_compute_forward_clamp_f32(
  4636. const ggml_compute_params * params,
  4637. ggml_tensor * dst) {
  4638. const ggml_tensor * src0 = dst->src[0];
  4639. float min;
  4640. float max;
  4641. memcpy(&min, (float *) dst->op_params + 0, sizeof(float));
  4642. memcpy(&max, (float *) dst->op_params + 1, sizeof(float));
  4643. const int ith = params->ith;
  4644. const int nth = params->nth;
  4645. const int n = ggml_nrows(src0);
  4646. const int nc = src0->ne[0];
  4647. const size_t nb00 = src0->nb[0];
  4648. const size_t nb01 = src0->nb[1];
  4649. const size_t nb0 = dst->nb[0];
  4650. const size_t nb1 = dst->nb[1];
  4651. GGML_ASSERT( nb0 == sizeof(float));
  4652. GGML_ASSERT(nb00 == sizeof(float));
  4653. for (int j = ith; j < n; j += nth) {
  4654. float * dst_ptr = (float *) ((char *) dst->data + j*nb1);
  4655. float * src0_ptr = (float *) ((char *) src0->data + j*nb01);
  4656. for (int i = 0; i < nc; i++) {
  4657. dst_ptr[i] = MAX(MIN(src0_ptr[i], max), min);
  4658. }
  4659. }
  4660. }
  4661. static void ggml_compute_forward_clamp_f16(
  4662. const ggml_compute_params * params,
  4663. ggml_tensor * dst) {
  4664. const ggml_tensor * src0 = dst->src[0];
  4665. float min;
  4666. float max;
  4667. memcpy(&min, (float *) dst->op_params + 0, sizeof(float));
  4668. memcpy(&max, (float *) dst->op_params + 1, sizeof(float));
  4669. const int ith = params->ith;
  4670. const int nth = params->nth;
  4671. const int n = ggml_nrows(src0);
  4672. const int nc = src0->ne[0];
  4673. const size_t nb00 = src0->nb[0];
  4674. const size_t nb01 = src0->nb[1];
  4675. const size_t nb0 = dst->nb[0];
  4676. const size_t nb1 = dst->nb[1];
  4677. GGML_ASSERT( nb0 == sizeof(ggml_fp16_t));
  4678. GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
  4679. for (int j = ith; j < n; j += nth) {
  4680. ggml_fp16_t * dst_ptr = (ggml_fp16_t *) ((char *) dst->data + j*nb1);
  4681. ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + j*nb01);
  4682. for (int i = 0; i < nc; i++) {
  4683. float v = GGML_CPU_FP16_TO_FP32(src0_ptr[i]);
  4684. dst_ptr[i] = GGML_CPU_FP32_TO_FP16(MAX(MIN(v, max), min));
  4685. }
  4686. }
  4687. }
  4688. void ggml_compute_forward_clamp(
  4689. const ggml_compute_params * params,
  4690. ggml_tensor * dst) {
  4691. const ggml_tensor * src0 = dst->src[0];
  4692. switch (src0->type) {
  4693. case GGML_TYPE_F32:
  4694. {
  4695. ggml_compute_forward_clamp_f32(params, dst);
  4696. } break;
  4697. case GGML_TYPE_F16:
  4698. {
  4699. ggml_compute_forward_clamp_f16(params, dst);
  4700. } break;
  4701. case GGML_TYPE_BF16:
  4702. case GGML_TYPE_Q4_0:
  4703. case GGML_TYPE_Q4_1:
  4704. case GGML_TYPE_Q5_0:
  4705. case GGML_TYPE_Q5_1:
  4706. case GGML_TYPE_Q8_0:
  4707. case GGML_TYPE_Q8_1:
  4708. case GGML_TYPE_Q2_K:
  4709. case GGML_TYPE_Q3_K:
  4710. case GGML_TYPE_Q4_K:
  4711. case GGML_TYPE_Q5_K:
  4712. case GGML_TYPE_Q6_K:
  4713. case GGML_TYPE_TQ1_0:
  4714. case GGML_TYPE_TQ2_0:
  4715. case GGML_TYPE_IQ2_XXS:
  4716. case GGML_TYPE_IQ2_XS:
  4717. case GGML_TYPE_IQ3_XXS:
  4718. case GGML_TYPE_IQ1_S:
  4719. case GGML_TYPE_IQ1_M:
  4720. case GGML_TYPE_IQ4_NL:
  4721. case GGML_TYPE_IQ4_XS:
  4722. case GGML_TYPE_IQ3_S:
  4723. case GGML_TYPE_IQ2_S:
  4724. case GGML_TYPE_Q8_K:
  4725. case GGML_TYPE_I8:
  4726. case GGML_TYPE_I16:
  4727. case GGML_TYPE_I32:
  4728. case GGML_TYPE_I64:
  4729. case GGML_TYPE_F64:
  4730. case GGML_TYPE_COUNT:
  4731. {
  4732. GGML_ABORT("fatal error");
  4733. }
  4734. }
  4735. }
  4736. // ggml_compute_forward_rope
  4737. static float rope_yarn_ramp(const float low, const float high, const int i0) {
  4738. const float y = (i0 / 2 - low) / MAX(0.001f, high - low);
  4739. return 1 - MIN(1, MAX(0, y));
  4740. }
  4741. // YaRN algorithm based on LlamaYaRNScaledRotaryEmbedding.py from https://github.com/jquesnelle/yarn
  4742. // MIT licensed. Copyright (c) 2023 Jeffrey Quesnelle and Bowen Peng.
  4743. static void rope_yarn(
  4744. float theta_extrap, float freq_scale, float corr_dims[2], int64_t i0, float ext_factor, float mscale,
  4745. float * cos_theta, float * sin_theta) {
  4746. // Get n-d rotational scaling corrected for extrapolation
  4747. float theta_interp = freq_scale * theta_extrap;
  4748. float theta = theta_interp;
  4749. if (ext_factor != 0.0f) {
  4750. float ramp_mix = rope_yarn_ramp(corr_dims[0], corr_dims[1], i0) * ext_factor;
  4751. theta = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix;
  4752. // Get n-d magnitude scaling corrected for interpolation
  4753. mscale *= 1.0f + 0.1f * logf(1.0f / freq_scale);
  4754. }
  4755. *cos_theta = cosf(theta) * mscale;
  4756. *sin_theta = sinf(theta) * mscale;
  4757. }
  4758. static void ggml_rope_cache_init(
  4759. float theta_base, float freq_scale, const float * freq_factors, float corr_dims[2], int64_t ne0, float ext_factor, float mscale,
  4760. float * cache, float sin_sign, float theta_scale) {
  4761. // ref: https://github.com/jquesnelle/yarn/blob/master/scaled_rope/LlamaYaRNScaledRotaryEmbedding.py
  4762. float theta = theta_base;
  4763. for (int64_t i0 = 0; i0 < ne0; i0 += 2) {
  4764. const float ff = freq_factors ? freq_factors[i0/2] : 1.0f;
  4765. rope_yarn(
  4766. theta/ff, freq_scale, corr_dims, i0, ext_factor, mscale, &cache[i0 + 0], &cache[i0 + 1]
  4767. );
  4768. cache[i0 + 1] *= sin_sign;
  4769. theta *= theta_scale;
  4770. }
  4771. }
  4772. static void ggml_mrope_cache_init(
  4773. float theta_base_t, float theta_base_h, float theta_base_w, float theta_base_e, int sections[4], bool indep_sects,
  4774. float freq_scale, const float * freq_factors, float corr_dims[2], int64_t ne0, float ext_factor, float mscale,
  4775. float * cache, float sin_sign, float theta_scale) {
  4776. // ref: https://github.com/jquesnelle/yarn/blob/master/scaled_rope/LlamaYaRNScaledRotaryEmbedding.py
  4777. float theta_t = theta_base_t;
  4778. float theta_h = theta_base_h;
  4779. float theta_w = theta_base_w;
  4780. float theta_e = theta_base_e; // extra position id for vision encoder
  4781. int sect_dims = sections[0] + sections[1] + sections[2] + sections[3];
  4782. int sec_w = sections[1] + sections[0];
  4783. int sec_e = sections[2] + sec_w;
  4784. GGML_ASSERT(sect_dims <= ne0);
  4785. for (int64_t i0 = 0; i0 < ne0; i0 += 2) {
  4786. const float ff = freq_factors ? freq_factors[i0/2] : 1.0f;
  4787. int sector = (i0 / 2) % sect_dims;
  4788. if (indep_sects) {
  4789. // compute theta independently for each dim sections
  4790. // (i.e. reset corresponding theta when `i0` go from one section to another)
  4791. if (sector == 0) {
  4792. theta_t = theta_base_t;
  4793. }
  4794. else if (sector == sections[0]) {
  4795. theta_h = theta_base_h;;
  4796. }
  4797. else if (sector == sec_w) {
  4798. theta_w = theta_base_w;
  4799. }
  4800. else if (sector == sec_e) {
  4801. theta_e = theta_base_e;
  4802. }
  4803. }
  4804. float theta = theta_t;
  4805. if (sector >= sections[0] && sector < sec_w) {
  4806. theta = theta_h;
  4807. }
  4808. else if (sector >= sec_w && sector < sec_w + sections[2]) {
  4809. theta = theta_w;
  4810. }
  4811. else if (sector >= sec_w + sections[2]) {
  4812. theta = theta_e;
  4813. }
  4814. rope_yarn(
  4815. theta/ff, freq_scale, corr_dims, i0, ext_factor, mscale, &cache[i0 + 0], &cache[i0 + 1]
  4816. );
  4817. cache[i0 + 1] *= sin_sign;
  4818. theta_t *= theta_scale;
  4819. theta_w *= theta_scale;
  4820. theta_h *= theta_scale;
  4821. theta_e *= theta_scale;
  4822. }
  4823. }
  4824. static void ggml_compute_forward_rope_f32(
  4825. const ggml_compute_params * params,
  4826. ggml_tensor * dst,
  4827. const bool forward) {
  4828. const ggml_tensor * src0 = dst->src[0];
  4829. const ggml_tensor * src1 = dst->src[1];
  4830. const ggml_tensor * src2 = dst->src[2];
  4831. float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
  4832. int sections[4];
  4833. //const int n_past = ((int32_t *) dst->op_params)[0];
  4834. const int n_dims = ((int32_t *) dst->op_params)[1];
  4835. const int mode = ((int32_t *) dst->op_params)[2];
  4836. //const int n_ctx = ((int32_t *) dst->op_params)[3];
  4837. const int n_ctx_orig = ((int32_t *) dst->op_params)[4];
  4838. memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float));
  4839. memcpy(&freq_scale, (int32_t *) dst->op_params + 6, sizeof(float));
  4840. memcpy(&ext_factor, (int32_t *) dst->op_params + 7, sizeof(float));
  4841. memcpy(&attn_factor, (int32_t *) dst->op_params + 8, sizeof(float));
  4842. memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float));
  4843. memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float));
  4844. memcpy(&sections, (int32_t *) dst->op_params + 11, sizeof(int)*4);
  4845. GGML_TENSOR_UNARY_OP_LOCALS
  4846. //printf("ne0: %d, ne1: %d, ne2: %d, ne3: %d\n", ne0, ne1, ne2, ne3);
  4847. //printf("n_past = %d, ne2 = %d\n", n_past, ne2);
  4848. GGML_ASSERT(nb00 == sizeof(float));
  4849. const int ith = params->ith;
  4850. const int nth = params->nth;
  4851. const int nr = ggml_nrows(dst);
  4852. GGML_ASSERT(n_dims <= ne0);
  4853. GGML_ASSERT(n_dims % 2 == 0);
  4854. // rows per thread
  4855. const int dr = (nr + nth - 1)/nth;
  4856. // row range for this thread
  4857. const int ir0 = dr*ith;
  4858. const int ir1 = MIN(ir0 + dr, nr);
  4859. // row index used to determine which thread to use
  4860. int ir = 0;
  4861. const float theta_scale = powf(freq_base, -2.0f/n_dims);
  4862. float corr_dims[2];
  4863. ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims);
  4864. const bool is_neox = mode & GGML_ROPE_TYPE_NEOX;
  4865. const bool is_mrope = mode & GGML_ROPE_TYPE_MROPE; // ggml_rope_multi, multimodal rotary position embedding
  4866. const bool is_vision = mode == GGML_ROPE_TYPE_VISION;
  4867. if (is_mrope) {
  4868. GGML_ASSERT(sections[0] > 0 || sections[1] > 0 || sections[2] > 0);
  4869. }
  4870. if (is_vision) {
  4871. GGML_ASSERT(n_dims == ne0/2);
  4872. }
  4873. const float * freq_factors = NULL;
  4874. if (src2 != NULL) {
  4875. GGML_ASSERT(src2->type == GGML_TYPE_F32);
  4876. GGML_ASSERT(src2->ne[0] >= n_dims / 2);
  4877. freq_factors = (const float *) src2->data;
  4878. }
  4879. // backward process uses inverse rotation by cos and sin.
  4880. // cos and sin build a rotation matrix, where the inverse is the transpose.
  4881. // this essentially just switches the sign of sin.
  4882. const float sin_sign = forward ? 1.0f : -1.0f;
  4883. const int32_t * pos = (const int32_t *) src1->data;
  4884. for (int64_t i3 = 0; i3 < ne3; i3++) { // batch
  4885. for (int64_t i2 = 0; i2 < ne2; i2++) { // seq-len
  4886. float * cache = (float *) params->wdata + (ne0 + CACHE_LINE_SIZE_F32)*ith;
  4887. if (!is_mrope) {
  4888. const int64_t p = pos[i2];
  4889. ggml_rope_cache_init(p, freq_scale, freq_factors, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale);
  4890. }
  4891. else {
  4892. const int64_t p_t = pos[i2];
  4893. const int64_t p_h = pos[i2 + ne2];
  4894. const int64_t p_w = pos[i2 + ne2 * 2];
  4895. const int64_t p_e = pos[i2 + ne2 * 3];
  4896. ggml_mrope_cache_init(
  4897. p_t, p_h, p_w, p_e, sections, is_vision,
  4898. freq_scale, freq_factors, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale);
  4899. }
  4900. for (int64_t i1 = 0; i1 < ne1; i1++) { // attn-heads
  4901. if (ir++ < ir0) continue;
  4902. if (ir > ir1) break;
  4903. if (is_neox || is_mrope) {
  4904. if (is_vision){
  4905. for (int64_t i0 = 0; i0 < n_dims; i0 += 2) {
  4906. const int64_t ic = i0/2;
  4907. const float cos_theta = cache[i0 + 0];
  4908. const float sin_theta = cache[i0 + 1];
  4909. const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00);
  4910. float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0);
  4911. const float x0 = src[0];
  4912. const float x1 = src[n_dims];
  4913. dst_data[0] = x0*cos_theta - x1*sin_theta;
  4914. dst_data[n_dims] = x0*sin_theta + x1*cos_theta;
  4915. }
  4916. } else {
  4917. for (int64_t i0 = 0; i0 < n_dims; i0 += 2) {
  4918. const int64_t ic = i0/2;
  4919. const float cos_theta = cache[i0 + 0];
  4920. const float sin_theta = cache[i0 + 1];
  4921. const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00);
  4922. float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0);
  4923. const float x0 = src[0];
  4924. const float x1 = src[n_dims/2];
  4925. dst_data[0] = x0*cos_theta - x1*sin_theta;
  4926. dst_data[n_dims/2] = x0*sin_theta + x1*cos_theta;
  4927. }
  4928. }
  4929. } else {
  4930. for (int64_t i0 = 0; i0 < n_dims; i0 += 2) {
  4931. const float cos_theta = cache[i0 + 0];
  4932. const float sin_theta = cache[i0 + 1];
  4933. const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
  4934. float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
  4935. const float x0 = src[0];
  4936. const float x1 = src[1];
  4937. dst_data[0] = x0*cos_theta - x1*sin_theta;
  4938. dst_data[1] = x0*sin_theta + x1*cos_theta;
  4939. }
  4940. }
  4941. if (is_vision) {
  4942. for (int64_t i0 = n_dims; i0 < ne0; i0 += 2) {
  4943. const int64_t ic = i0/2;
  4944. const float cos_theta = cache[i0 + 0];
  4945. const float sin_theta = cache[i0 + 1];
  4946. const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00);
  4947. float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0);
  4948. const float x0 = src[0];
  4949. const float x1 = src[n_dims];
  4950. dst_data[0] = x0*cos_theta - x1*sin_theta;
  4951. dst_data[n_dims] = x0*sin_theta + x1*cos_theta;
  4952. }
  4953. } else {
  4954. // fill the remain channels with data from src tensor
  4955. for (int64_t i0 = n_dims; i0 < ne0; i0 += 2) {
  4956. const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
  4957. float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
  4958. dst_data[0] = src[0];
  4959. dst_data[1] = src[1];
  4960. }
  4961. }
  4962. }
  4963. }
  4964. }
  4965. }
  4966. // TODO: deduplicate f16/f32 code
  4967. static void ggml_compute_forward_rope_f16(
  4968. const ggml_compute_params * params,
  4969. ggml_tensor * dst,
  4970. const bool forward) {
  4971. const ggml_tensor * src0 = dst->src[0];
  4972. const ggml_tensor * src1 = dst->src[1];
  4973. const ggml_tensor * src2 = dst->src[2];
  4974. float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
  4975. int sections[4];
  4976. //const int n_past = ((int32_t *) dst->op_params)[0];
  4977. const int n_dims = ((int32_t *) dst->op_params)[1];
  4978. const int mode = ((int32_t *) dst->op_params)[2];
  4979. //const int n_ctx = ((int32_t *) dst->op_params)[3];
  4980. const int n_ctx_orig = ((int32_t *) dst->op_params)[4];
  4981. memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float));
  4982. memcpy(&freq_scale, (int32_t *) dst->op_params + 6, sizeof(float));
  4983. memcpy(&ext_factor, (int32_t *) dst->op_params + 7, sizeof(float));
  4984. memcpy(&attn_factor, (int32_t *) dst->op_params + 8, sizeof(float));
  4985. memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float));
  4986. memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float));
  4987. memcpy(&sections, (int32_t *) dst->op_params + 11, sizeof(int)*4);
  4988. GGML_TENSOR_UNARY_OP_LOCALS
  4989. //printf("ne0: %d, ne1: %d, ne2: %d, ne3: %d\n", ne0, ne1, ne2, ne3);
  4990. //printf("n_past = %d, ne2 = %d\n", n_past, ne2);
  4991. GGML_ASSERT(nb0 == sizeof(ggml_fp16_t));
  4992. const int ith = params->ith;
  4993. const int nth = params->nth;
  4994. const int nr = ggml_nrows(dst);
  4995. GGML_ASSERT(n_dims <= ne0);
  4996. GGML_ASSERT(n_dims % 2 == 0);
  4997. // rows per thread
  4998. const int dr = (nr + nth - 1)/nth;
  4999. // row range for this thread
  5000. const int ir0 = dr*ith;
  5001. const int ir1 = MIN(ir0 + dr, nr);
  5002. // row index used to determine which thread to use
  5003. int ir = 0;
  5004. const float theta_scale = powf(freq_base, -2.0f/n_dims);
  5005. float corr_dims[2];
  5006. ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims);
  5007. const bool is_neox = mode & GGML_ROPE_TYPE_NEOX;
  5008. const bool is_mrope = mode & GGML_ROPE_TYPE_MROPE;
  5009. const bool is_vision = mode == GGML_ROPE_TYPE_VISION;
  5010. if (is_mrope) {
  5011. GGML_ASSERT(sections[0] > 0 || sections[1] > 0 || sections[2] > 0);
  5012. }
  5013. if (is_vision) {
  5014. GGML_ASSERT(n_dims == ne0/2);
  5015. }
  5016. const float * freq_factors = NULL;
  5017. if (src2 != NULL) {
  5018. GGML_ASSERT(src2->type == GGML_TYPE_F32);
  5019. GGML_ASSERT(src2->ne[0] >= n_dims / 2);
  5020. freq_factors = (const float *) src2->data;
  5021. }
  5022. // backward process uses inverse rotation by cos and sin.
  5023. // cos and sin build a rotation matrix, where the inverse is the transpose.
  5024. // this essentially just switches the sign of sin.
  5025. const float sin_sign = forward ? 1.0f : -1.0f;
  5026. const int32_t * pos = (const int32_t *) src1->data;
  5027. for (int64_t i3 = 0; i3 < ne3; i3++) {
  5028. for (int64_t i2 = 0; i2 < ne2; i2++) {
  5029. float * cache = (float *) params->wdata + (ne0 + CACHE_LINE_SIZE_F32)*ith;
  5030. if (!is_mrope) {
  5031. const int64_t p = pos[i2];
  5032. ggml_rope_cache_init(p, freq_scale, freq_factors, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale);
  5033. }
  5034. else {
  5035. const int64_t p_t = pos[i2];
  5036. const int64_t p_h = pos[i2 + ne2];
  5037. const int64_t p_w = pos[i2 + ne2 * 2];
  5038. const int64_t p_e = pos[i2 + ne2 * 3];
  5039. ggml_mrope_cache_init(
  5040. p_t, p_h, p_w, p_e, sections, is_vision,
  5041. freq_scale, freq_factors, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale);
  5042. }
  5043. for (int64_t i1 = 0; i1 < ne1; i1++) {
  5044. if (ir++ < ir0) continue;
  5045. if (ir > ir1) break;
  5046. if (is_neox || is_mrope) {
  5047. if (is_vision) {
  5048. for (int64_t i0 = 0; i0 < n_dims; i0 += 2) {
  5049. const int64_t ic = i0/2;
  5050. const float cos_theta = cache[i0 + 0];
  5051. const float sin_theta = cache[i0 + 1];
  5052. const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00);
  5053. ggml_fp16_t * dst_data = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0);
  5054. const float x0 = GGML_CPU_FP16_TO_FP32(src[0]);
  5055. const float x1 = GGML_CPU_FP16_TO_FP32(src[n_dims]);
  5056. dst_data[0] = GGML_CPU_FP32_TO_FP16(x0*cos_theta - x1*sin_theta);
  5057. dst_data[n_dims] = GGML_CPU_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
  5058. }
  5059. } else {
  5060. for (int64_t i0 = 0; i0 < n_dims; i0 += 2) {
  5061. const int64_t ic = i0/2;
  5062. const float cos_theta = cache[i0 + 0];
  5063. const float sin_theta = cache[i0 + 1];
  5064. const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00);
  5065. ggml_fp16_t * dst_data = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0);
  5066. const float x0 = GGML_CPU_FP16_TO_FP32(src[0]);
  5067. const float x1 = GGML_CPU_FP16_TO_FP32(src[n_dims/2]);
  5068. dst_data[0] = GGML_CPU_FP32_TO_FP16(x0*cos_theta - x1*sin_theta);
  5069. dst_data[n_dims/2] = GGML_CPU_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
  5070. }
  5071. }
  5072. } else {
  5073. for (int64_t i0 = 0; i0 < n_dims; i0 += 2) {
  5074. const float cos_theta = cache[i0 + 0];
  5075. const float sin_theta = cache[i0 + 1];
  5076. const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
  5077. ggml_fp16_t * dst_data = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
  5078. const float x0 = GGML_CPU_FP16_TO_FP32(src[0]);
  5079. const float x1 = GGML_CPU_FP16_TO_FP32(src[1]);
  5080. dst_data[0] = GGML_CPU_FP32_TO_FP16(x0*cos_theta - x1*sin_theta);
  5081. dst_data[1] = GGML_CPU_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
  5082. }
  5083. }
  5084. if (is_vision) {
  5085. for (int64_t i0 = n_dims; i0 < ne0; i0 += 2) {
  5086. const int64_t ic = i0/2;
  5087. const float cos_theta = cache[i0 + 0];
  5088. const float sin_theta = cache[i0 + 1];
  5089. const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00);
  5090. ggml_fp16_t * dst_data = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0);
  5091. const float x0 = GGML_CPU_FP16_TO_FP32(src[0]);
  5092. const float x1 = GGML_CPU_FP16_TO_FP32(src[n_dims]);
  5093. dst_data[0] = GGML_CPU_FP32_TO_FP16(x0*cos_theta - x1*sin_theta);
  5094. dst_data[n_dims] = GGML_CPU_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
  5095. }
  5096. } else {
  5097. for (int64_t i0 = n_dims; i0 < ne0; i0 += 2) {
  5098. const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
  5099. ggml_fp16_t * dst_data = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
  5100. dst_data[0] = src[0];
  5101. dst_data[1] = src[1];
  5102. }
  5103. }
  5104. }
  5105. }
  5106. }
  5107. }
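// [illustrative note] the three rope layouts above differ only in which element is paired
// with x0 for the 2D rotation (a sketch of the common step, using the names from the loops):
//   // normal: pair (i0, i0 + 1);  neox/mrope: pair (ic, ic + n_dims/2);  vision: pair (ic, ic + n_dims)
//   dst0 = x0*cos_theta - x1*sin_theta;
//   dst1 = x0*sin_theta + x1*cos_theta;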
  5108. void ggml_compute_forward_rope(
  5109. const ggml_compute_params * params,
  5110. ggml_tensor * dst) {
  5111. const ggml_tensor * src0 = dst->src[0];
  5112. switch (src0->type) {
  5113. case GGML_TYPE_F16:
  5114. {
  5115. ggml_compute_forward_rope_f16(params, dst, true);
  5116. } break;
  5117. case GGML_TYPE_F32:
  5118. {
  5119. ggml_compute_forward_rope_f32(params, dst, true);
  5120. } break;
  5121. default:
  5122. {
  5123. GGML_ABORT("fatal error");
  5124. }
  5125. }
  5126. }
  5127. // ggml_compute_forward_rope_back
  5128. void ggml_compute_forward_rope_back(
  5129. const ggml_compute_params * params,
  5130. ggml_tensor * dst) {
  5131. const ggml_tensor * src0 = dst->src[0];
  5132. switch (src0->type) {
  5133. case GGML_TYPE_F16:
  5134. {
  5135. ggml_compute_forward_rope_f16(params, dst, false);
  5136. } break;
  5137. case GGML_TYPE_F32:
  5138. {
  5139. ggml_compute_forward_rope_f32(params, dst, false);
  5140. } break;
  5141. default:
  5142. {
  5143. GGML_ABORT("fatal error");
  5144. }
  5145. }
  5146. }
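// [illustrative note] forward and backward rope share the same kernel and differ only in
// sin_sign: the rotation matrix [[cos, -sin], [sin, cos]] is orthogonal, so its inverse is
// its transpose [[cos, sin], [-sin, cos]], i.e. the same rotation with the sign of sin
// flipped, which is what forward == false (sin_sign = -1.0f) selects.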
  5147. // ggml_compute_forward_conv_transpose_1d
  5148. static void ggml_compute_forward_conv_transpose_1d_f16_f32(
  5149. const ggml_compute_params * params,
  5150. ggml_tensor * dst) {
  5151. const ggml_tensor * src0 = dst->src[0];
  5152. const ggml_tensor * src1 = dst->src[1];
  5153. GGML_ASSERT(src0->type == GGML_TYPE_F16);
  5154. GGML_ASSERT(src1->type == GGML_TYPE_F32);
  5155. GGML_ASSERT( dst->type == GGML_TYPE_F32);
  5156. GGML_TENSOR_BINARY_OP_LOCALS
  5157. const int ith = params->ith;
  5158. const int nth = params->nth;
  5159. const int nk = ne00*ne01*ne02;
  5160. GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
  5161. GGML_ASSERT(nb10 == sizeof(float));
  5162. if (ith == 0) {
  5163. memset(params->wdata, 0, params->wsize);
  5164. // permute kernel data (src0) from (K x Cout x Cin) to (Cin x K x Cout)
  5165. {
  5166. ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + 0;
  5167. for (int64_t i02 = 0; i02 < ne02; i02++) {
  5168. for (int64_t i01 = 0; i01 < ne01; i01++) {
  5169. const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i02*nb02 + i01*nb01);
  5170. ggml_fp16_t * dst_data = wdata + i01*ne00*ne02;
  5171. for (int64_t i00 = 0; i00 < ne00; i00++) {
  5172. dst_data[i00*ne02 + i02] = src[i00];
  5173. }
  5174. }
  5175. }
  5176. }
  5177. // permute source data (src1) from (L x Cin) to (Cin x L)
  5178. {
  5179. ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + nk;
  5180. ggml_fp16_t * dst_data = wdata;
  5181. for (int64_t i11 = 0; i11 < ne11; i11++) {
  5182. const float * const src = (float *)((char *) src1->data + i11*nb11);
  5183. for (int64_t i10 = 0; i10 < ne10; i10++) {
  5184. dst_data[i10*ne11 + i11] = GGML_CPU_FP32_TO_FP16(src[i10]);
  5185. }
  5186. }
  5187. }
  5188. // need to zero dst since we are accumulating into it
  5189. memset(dst->data, 0, ggml_nbytes(dst));
  5190. }
  5191. ggml_barrier(params->threadpool);
  5192. const int32_t s0 = ((const int32_t*)(dst->op_params))[0];
  5193. // total rows in dst
  5194. const int nr = ne1;
  5195. // rows per thread
  5196. const int dr = (nr + nth - 1)/nth;
  5197. // row range for this thread
  5198. const int ir0 = dr*ith;
  5199. const int ir1 = MIN(ir0 + dr, nr);
  5200. ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + 0;
  5201. ggml_fp16_t * const wdata_src = wdata + nk;
  5202. for (int i1 = ir0; i1 < ir1; i1++) {
  5203. float * dst_data = (float *)((char *) dst->data + i1*nb1);
  5204. ggml_fp16_t * wdata_kernel = wdata + i1*ne02*ne00;
  5205. for (int i10 = 0; i10 < ne10; i10++) {
  5206. const int i1n = i10*ne11;
  5207. for (int i00 = 0; i00 < ne00; i00++) {
  5208. float v = 0;
  5209. ggml_vec_dot_f16(ne02, &v, 0,
  5210. (ggml_fp16_t *) wdata_src + i1n, 0,
  5211. (ggml_fp16_t *) wdata_kernel + i00*ne02, 0, 1);
  5212. dst_data[i10*s0 + i00] += v;
  5213. }
  5214. }
  5215. }
  5216. }
  5217. static void ggml_compute_forward_conv_transpose_1d_f32(
  5218. const ggml_compute_params * params,
  5219. ggml_tensor * dst) {
  5220. const ggml_tensor * src0 = dst->src[0];
  5221. const ggml_tensor * src1 = dst->src[1];
  5222. GGML_ASSERT(src0->type == GGML_TYPE_F32);
  5223. GGML_ASSERT(src1->type == GGML_TYPE_F32);
  5224. GGML_ASSERT( dst->type == GGML_TYPE_F32);
  5225. GGML_TENSOR_BINARY_OP_LOCALS
  5226. const int ith = params->ith;
  5227. const int nth = params->nth;
  5228. const int nk = ne00*ne01*ne02;
  5229. GGML_ASSERT(nb00 == sizeof(float));
  5230. GGML_ASSERT(nb10 == sizeof(float));
  5231. if (ith == 0) {
  5232. memset(params->wdata, 0, params->wsize);
  5233. // prepare kernel data (src0) from (K x Cout x Cin) to (Cin x K x Cout)
  5234. {
  5235. float * const wdata = (float *) params->wdata + 0;
  5236. for (int64_t i02 = 0; i02 < ne02; i02++) {
  5237. for (int64_t i01 = 0; i01 < ne01; i01++) {
  5238. const float * const src = (float *)((char *) src0->data + i02*nb02 + i01*nb01);
  5239. float * dst_data = wdata + i01*ne00*ne02;
  5240. for (int64_t i00 = 0; i00 < ne00; i00++) {
  5241. dst_data[i00*ne02 + i02] = src[i00];
  5242. }
  5243. }
  5244. }
  5245. }
  5246. // prepare source data (src1)
  5247. {
  5248. float * const wdata = (float *) params->wdata + nk;
  5249. float * dst_data = wdata;
  5250. for (int64_t i11 = 0; i11 < ne11; i11++) {
  5251. const float * const src = (float *)((char *) src1->data + i11*nb11);
  5252. for (int64_t i10 = 0; i10 < ne10; i10++) {
  5253. dst_data[i10*ne11 + i11] = src[i10];
  5254. }
  5255. }
  5256. }
  5257. // need to zero dst since we are accumulating into it
  5258. memset(dst->data, 0, ggml_nbytes(dst));
  5259. }
  5260. ggml_barrier(params->threadpool);
  5261. const int32_t s0 = ((const int32_t*)(dst->op_params))[0];
  5262. // total rows in dst
  5263. const int nr = ne1;
  5264. // rows per thread
  5265. const int dr = (nr + nth - 1)/nth;
  5266. // row range for this thread
  5267. const int ir0 = dr*ith;
  5268. const int ir1 = MIN(ir0 + dr, nr);
  5269. float * const wdata = (float *) params->wdata + 0;
  5270. float * const wdata_src = wdata + nk;
  5271. for (int i1 = ir0; i1 < ir1; i1++) {
  5272. float * dst_data = (float *)((char *) dst->data + i1*nb1);
  5273. float * wdata_kernel = wdata + i1*ne02*ne00;
  5274. for (int i10 = 0; i10 < ne10; i10++) {
  5275. const int i1n = i10*ne11;
  5276. for (int i00 = 0; i00 < ne00; i00++) {
  5277. float v = 0;
  5278. ggml_vec_dot_f32(ne02, &v, 0,
  5279. wdata_src + i1n, 0,
  5280. wdata_kernel + i00*ne02, 0, 1);
  5281. dst_data[i10*s0 + i00] += v;
  5282. }
  5283. }
  5284. }
  5285. }
  5286. void ggml_compute_forward_conv_transpose_1d(
  5287. const ggml_compute_params * params,
  5288. ggml_tensor * dst) {
  5289. const ggml_tensor * src0 = dst->src[0];
  5290. switch (src0->type) {
  5291. case GGML_TYPE_F16:
  5292. {
  5293. ggml_compute_forward_conv_transpose_1d_f16_f32(params, dst);
  5294. } break;
  5295. case GGML_TYPE_F32:
  5296. {
  5297. ggml_compute_forward_conv_transpose_1d_f32(params, dst);
  5298. } break;
  5299. default:
  5300. {
  5301. GGML_ABORT("fatal error");
  5302. }
  5303. }
  5304. }
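// [illustrative note] shape arithmetic assumed by the accumulation above (only the stride s0
// is read from op_params, no padding or dilation): dst->ne[0] == (src1->ne[0] - 1)*s0 + src0->ne[0],
// i.e. (L_in - 1)*stride + K; every input position i10 scatters a length-K contribution
// starting at i10*s0, hence dst is zeroed first and updated with "+=".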
  5305. // ggml_compute_forward_im2col_f32
  5306. // src0: kernel [OC, IC, KH, KW]
  5307. // src1: image [N, IC, IH, IW]
  5308. // dst: result [N, OH, OW, IC*KH*KW]
  5309. static void ggml_compute_forward_im2col_f32(
  5310. const ggml_compute_params * params,
  5311. ggml_tensor * dst) {
  5312. const ggml_tensor * src0 = dst->src[0];
  5313. const ggml_tensor * src1 = dst->src[1];
  5314. GGML_ASSERT(src1->type == GGML_TYPE_F32);
  5315. GGML_ASSERT( dst->type == GGML_TYPE_F32);
  5316. GGML_TENSOR_BINARY_OP_LOCALS;
  5317. const int32_t s0 = ((const int32_t *)(dst->op_params))[0];
  5318. const int32_t s1 = ((const int32_t *)(dst->op_params))[1];
  5319. const int32_t p0 = ((const int32_t *)(dst->op_params))[2];
  5320. const int32_t p1 = ((const int32_t *)(dst->op_params))[3];
  5321. const int32_t d0 = ((const int32_t *)(dst->op_params))[4];
  5322. const int32_t d1 = ((const int32_t *)(dst->op_params))[5];
  5323. const bool is_2D = ((const int32_t *)(dst->op_params))[6] == 1;
  5324. const int ith = params->ith;
  5325. const int nth = params->nth;
  5326. const int64_t N = is_2D ? ne13 : ne12;
  5327. const int64_t IC = is_2D ? ne12 : ne11;
  5328. const int64_t IH = is_2D ? ne11 : 1;
  5329. const int64_t IW = ne10;
  5330. const int64_t KH = is_2D ? ne01 : 1;
  5331. const int64_t KW = ne00;
  5332. const int64_t OH = is_2D ? ne2 : 1;
  5333. const int64_t OW = ne1;
  5334. int ofs0 = is_2D ? nb13 : nb12;
  5335. int ofs1 = is_2D ? nb12 : nb11;
  5336. GGML_ASSERT(nb10 == sizeof(float));
  5337. // im2col: [N, IC, IH, IW] => [N, OH, OW, IC*KH*KW]
  5338. {
  5339. float * const wdata = (float *) dst->data;
  5340. for (int64_t in = 0; in < N; in++) {
  5341. for (int64_t ioh = 0; ioh < OH; ioh++) { // 1
  5342. for (int64_t iow = 0; iow < OW; iow++) {
  5343. for (int64_t iic = ith; iic < IC; iic += nth) {
  5344. // micro kernel
  5345. float * dst_data = wdata + (in*OH*OW + ioh*OW + iow)*(IC*KH*KW); // [IC, KH, KW]
  5346. const float * const src_data = (float *)((char *) src1->data + in*ofs0 + iic*ofs1); // [IH, IW]
  5347. for (int64_t ikh = 0; ikh < KH; ikh++) { // 1
  5348. for (int64_t ikw = 0; ikw < KW; ikw++) {
  5349. const int64_t iiw = iow*s0 + ikw*d0 - p0;
  5350. const int64_t iih = ioh*s1 + ikh*d1 - p1;
  5351. if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) {
  5352. dst_data[iic*(KH*KW) + ikh*KW + ikw] = 0;
  5353. } else {
  5354. dst_data[iic*(KH*KW) + ikh*KW + ikw] = (src_data[iih*IW + iiw]);
  5355. }
  5356. }
  5357. }
  5358. }
  5359. }
  5360. }
  5361. }
  5362. }
  5363. }
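// [illustrative note] each output location (in, ioh, iow) above owns one contiguous row of
// length IC*KH*KW; the sampled input coordinates follow the usual convolution indexing
//   iiw = iow*s0 + ikw*d0 - p0,  iih = ioh*s1 + ikh*d1 - p1,
// and out-of-range taps are written as 0, which implements zero padding.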
  5364. // ggml_compute_forward_im2col_f16
  5365. // src0: kernel [OC, IC, KH, KW]
  5366. // src1: image [N, IC, IH, IW]
  5367. // dst: result [N, OH, OW, IC*KH*KW]
  5368. static void ggml_compute_forward_im2col_f16(
  5369. const ggml_compute_params * params,
  5370. ggml_tensor * dst) {
  5371. const ggml_tensor * src0 = dst->src[0];
  5372. const ggml_tensor * src1 = dst->src[1];
  5373. GGML_ASSERT(src0->type == GGML_TYPE_F16);
  5374. GGML_ASSERT(src1->type == GGML_TYPE_F32);
  5375. GGML_ASSERT( dst->type == GGML_TYPE_F16);
  5376. GGML_TENSOR_BINARY_OP_LOCALS;
  5377. const int32_t s0 = ((const int32_t *)(dst->op_params))[0];
  5378. const int32_t s1 = ((const int32_t *)(dst->op_params))[1];
  5379. const int32_t p0 = ((const int32_t *)(dst->op_params))[2];
  5380. const int32_t p1 = ((const int32_t *)(dst->op_params))[3];
  5381. const int32_t d0 = ((const int32_t *)(dst->op_params))[4];
  5382. const int32_t d1 = ((const int32_t *)(dst->op_params))[5];
  5383. const bool is_2D = ((const int32_t *)(dst->op_params))[6] == 1;
  5384. const int ith = params->ith;
  5385. const int nth = params->nth;
  5386. const int64_t N = is_2D ? ne13 : ne12;
  5387. const int64_t IC = is_2D ? ne12 : ne11;
  5388. const int64_t IH = is_2D ? ne11 : 1;
  5389. const int64_t IW = ne10;
  5390. const int64_t KH = is_2D ? ne01 : 1;
  5391. const int64_t KW = ne00;
  5392. const int64_t OH = is_2D ? ne2 : 1;
  5393. const int64_t OW = ne1;
  5394. int ofs0 = is_2D ? nb13 : nb12;
  5395. int ofs1 = is_2D ? nb12 : nb11;
  5396. GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
  5397. GGML_ASSERT(nb10 == sizeof(float));
  5398. // im2col: [N, IC, IH, IW] => [N, OH, OW, IC*KH*KW]
  5399. {
  5400. ggml_fp16_t * const wdata = (ggml_fp16_t *) dst->data;
  5401. for (int64_t in = 0; in < N; in++) {
  5402. for (int64_t ioh = 0; ioh < OH; ioh++) { // 1
  5403. for (int64_t iow = 0; iow < OW; iow++) {
  5404. for (int64_t iic = ith; iic < IC; iic += nth) {
  5405. // micro kernel
  5406. ggml_fp16_t * dst_data = wdata + (in*OH*OW + ioh*OW + iow)*(IC*KH*KW); // [IC, KH, KW]
  5407. const float * const src_data = (float *)((char *) src1->data + in*ofs0 + iic*ofs1); // [IH, IW]
  5408. for (int64_t ikh = 0; ikh < KH; ikh++) { // 1
  5409. for (int64_t ikw = 0; ikw < KW; ikw++) {
  5410. const int64_t iiw = iow*s0 + ikw*d0 - p0;
  5411. const int64_t iih = ioh*s1 + ikh*d1 - p1;
  5412. if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) {
  5413. dst_data[iic*(KH*KW) + ikh*KW + ikw] = 0;
  5414. } else {
  5415. dst_data[iic*(KH*KW) + ikh*KW + ikw] = GGML_CPU_FP32_TO_FP16(src_data[iih*IW + iiw]);
  5416. }
  5417. }
  5418. }
  5419. }
  5420. }
  5421. }
  5422. }
  5423. }
  5424. }
  5425. void ggml_compute_forward_im2col(
  5426. const ggml_compute_params * params,
  5427. ggml_tensor * dst) {
  5428. switch (dst->type) {
  5429. case GGML_TYPE_F16:
  5430. {
  5431. ggml_compute_forward_im2col_f16(params, dst);
  5432. } break;
  5433. case GGML_TYPE_F32:
  5434. {
  5435. ggml_compute_forward_im2col_f32(params, dst);
  5436. } break;
  5437. default:
  5438. {
  5439. GGML_ABORT("fatal error");
  5440. }
  5441. }
  5442. }
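// [illustrative note] im2col reduces convolution to matrix multiplication: with dst viewed as
// a [IC*KH*KW, N*OH*OW] matrix and the kernel viewed as [IC*KH*KW, OC], a graph would
// typically follow this op with a mul_mat (plus reshape/permute) to obtain the convolution
// result; this is a sketch of how the op is consumed, not an extra code path here.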
  5443. // ggml_compute_forward_im2col_back_f32
  5444. void ggml_compute_forward_im2col_back_f32(
  5445. const ggml_compute_params * params,
  5446. ggml_tensor * dst) {
  5447. const ggml_tensor * src0 = dst->src[0]; // gradients of forward pass output
  5448. const ggml_tensor * src1 = dst->src[1]; // convolution kernel
  5449. GGML_ASSERT(src0->type == GGML_TYPE_F32);
  5450. GGML_ASSERT(src1->type == GGML_TYPE_F32);
  5451. GGML_ASSERT( dst->type == GGML_TYPE_F32);
  5452. GGML_TENSOR_BINARY_OP_LOCALS;
  5453. const int32_t s0 = ((const int32_t *)(dst->op_params))[0];
  5454. const int32_t s1 = ((const int32_t *)(dst->op_params))[1];
  5455. const int32_t p0 = ((const int32_t *)(dst->op_params))[2];
  5456. const int32_t p1 = ((const int32_t *)(dst->op_params))[3];
  5457. const int32_t d0 = ((const int32_t *)(dst->op_params))[4];
  5458. const int32_t d1 = ((const int32_t *)(dst->op_params))[5];
  5459. const bool is_2D = ((const int32_t *)(dst->op_params))[6] == 1;
  5460. const int ith = params->ith;
  5461. const int nth = params->nth;
  5462. const int64_t N = is_2D ? ne3 : ne2;
  5463. const int64_t IC = is_2D ? ne2 : ne1;
  5464. const int64_t IH = is_2D ? ne1 : 1;
  5465. const int64_t IW = ne0;
  5466. const int64_t KH = is_2D ? ne11 : 1;
  5467. const int64_t KW = ne10;
  5468. const int64_t OH = is_2D ? ne02 : 1;
  5469. const int64_t OW = ne01;
  5470. int ofs0 = is_2D ? nb3 : nb2;
  5471. int ofs1 = is_2D ? nb2 : nb1;
  5472. GGML_ASSERT(nb0 == sizeof(float));
  5473. // im2col: [N, IC, IH, IW] => [N, OH, OW, IC*KH*KW]
  5474. {
  5475. float * const wdata = (float *) dst->data;
  5476. for (int64_t in = 0; in < N; in++) {
  5477. for (int64_t iic = ith; iic < IC; iic += nth) {
  5478. for (int64_t iih = 0; iih < IH; iih++) {
  5479. for (int64_t iiw = 0; iiw < IW; iiw++) {
  5480. // micro kernel
  5481. float grad = 0.0f;
  5482. for (int64_t ikh = 0; ikh < KH; ikh++) {
  5483. for (int64_t ikw = 0; ikw < KW; ikw++) {
  5484. // For s0 > 1 some values were skipped over in the forward pass.
  5485. // These values have tmpw % s0 != 0 and need to be skipped in the backwards pass as well.
  5486. const int64_t tmpw = (iiw + p0 - ikw*d0);
  5487. if (tmpw % s0 != 0) {
  5488. continue;
  5489. }
  5490. const int64_t iow = tmpw / s0;
  5491. // Equivalent logic as above except for s1.
  5492. int64_t ioh;
  5493. if (is_2D) {
  5494. const int64_t tmph = iih + p1 - ikh*d1;
  5495. if (tmph % s1 != 0) {
  5496. continue;
  5497. }
  5498. ioh = tmph / s1;
  5499. } else {
  5500. ioh = 0;
  5501. }
  5502. if (iow < 0 || iow >= OW || ioh < 0 || ioh >= OH) {
  5503. continue;
  5504. }
  5505. const float * const grad_in = (const float *) src0->data
  5506. + (in*OH*OW + ioh*OW + iow)*(IC*KH*KW); // [IC, KH, KW]
  5507. grad += grad_in[iic*(KH*KW) + ikh*KW + ikw];
  5508. }
  5509. }
  5510. float * dst_data = (float *)((char *) wdata + (in*ofs0 + iic*ofs1)); // [IH, IW]
  5511. dst_data[iih*IW + iiw] = grad;
  5512. }
  5513. }
  5514. }
  5515. }
  5516. }
  5517. }
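// [illustrative note] the backward pass iterates over input pixels and inverts
// iiw = iow*s0 + ikw*d0 - p0: a kernel tap contributes only when (iiw + p0 - ikw*d0) is
// divisible by s0 (positions skipped by the stride in the forward pass get no gradient),
// and the recovered (iow, ioh) must land inside [0, OW) x [0, OH).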
  5518. static void ggml_call_mul_mat(ggml_type type, const ggml_compute_params * params, int64_t m, int64_t n, int64_t k,
  5519. void * a, void * b, float * c) {
  5520. const ggml_type_traits * traits = ggml_get_type_traits(type);
  5521. struct ggml_tensor src1 = {};
  5522. src1.type = type;
  5523. src1.ne[0] = k;
  5524. src1.ne[1] = m;
  5525. src1.ne[2] = 1;
  5526. src1.ne[3] = 1;
  5527. src1.nb[0] = traits->type_size;
  5528. src1.nb[1] = k * traits->type_size;
  5529. src1.nb[2] = src1.nb[1];
  5530. src1.nb[3] = src1.nb[2];
  5531. src1.data = a;
  5532. struct ggml_tensor src0 = {};
  5533. src0.type = type;
  5534. src0.ne[0] = k;
  5535. src0.ne[1] = n;
  5536. src0.ne[2] = 1;
  5537. src0.ne[3] = 1;
  5538. src0.nb[0] = traits->type_size;
  5539. src0.nb[1] = k * traits->type_size;
  5540. src0.nb[2] = src0.nb[1];
  5541. src0.nb[3] = src0.nb[2];
  5542. src0.data = b;
  5543. struct ggml_tensor dst = {};
  5544. dst.ne[0] = n;
  5545. dst.ne[1] = m;
  5546. dst.ne[2] = 1;
  5547. dst.ne[3] = 1;
  5548. dst.nb[0] = sizeof(float);
  5549. dst.nb[1] = n * sizeof(float);
  5550. dst.nb[2] = dst.nb[1];
  5551. dst.nb[3] = dst.nb[2];
  5552. dst.data = c;
  5553. dst.src[0] = &src0;
  5554. dst.src[1] = &src1;
  5555. ggml_compute_forward_mul_mat(params, &dst);
  5556. }
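// [illustrative note] ggml_call_mul_mat wraps ggml_compute_forward_mul_mat as a plain GEMM
// by building temporary tensor views on the stack: a is presented as src1 with ne = [k, m],
// b as src0 with ne = [k, n], and c receives the float result with ne = [n, m]; only
// type/ne/nb/data need to be filled in for the mul_mat kernel.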
  5557. // ggml_compute_forward_conv_2d
  5558. static void ggml_compute_forward_conv_2d_impl(const ggml_compute_params * params,
  5559. const ggml_tensor * kernel, // [KW, KH, IC, OC]
  5560. const ggml_tensor * src, // [W, H, C, N]
  5561. ggml_tensor * dst, // [OW, OH, OC, N]
  5562. ggml_type kernel_type) {
  5563. GGML_ASSERT(ggml_is_contiguous(kernel));
  5564. GGML_ASSERT(kernel_type == GGML_TYPE_F16 || kernel_type == GGML_TYPE_F32);
  5565. GGML_ASSERT(kernel->type == kernel_type);
  5566. const ggml_type_traits * traits = ggml_get_type_traits(kernel_type);
  5567. const int32_t stride_x = dst->op_params[0];
  5568. const int32_t stride_y = dst->op_params[1];
  5569. const int32_t pad_x = dst->op_params[2];
  5570. const int32_t pad_y = dst->op_params[3];
  5571. const int32_t dilation_x = dst->op_params[4];
  5572. const int32_t dilation_y = dst->op_params[5];
  5573. const int64_t c_in = src->ne[2];
  5574. const int64_t c_out = kernel->ne[3];
  5575. GGML_ASSERT(c_in == kernel->ne[2]);
  5576. const int64_t src_w = src->ne[0];
  5577. const int64_t src_h = src->ne[1];
  5578. const int64_t knl_w = kernel->ne[0];
  5579. const int64_t knl_h = kernel->ne[1];
  5580. const int64_t dst_w = dst->ne[0];
  5581. const int64_t dst_h = dst->ne[1];
  5582. const float * src_data = (float *) src->data;
  5583. void * knl_data = kernel->data;
  5584. float * dst_data = (float *) dst->data;
  5585. const int64_t knl_n = knl_w * knl_h * c_in;
  5586. const int64_t patch_total = dst->ne[3] * dst_w * dst_h;
  5587. const int64_t space_per_patch = knl_n * traits->type_size + c_out * sizeof(float);
  5588. const int64_t batch_size = params->wsize / space_per_patch;
  5589. const int64_t patches_per_batch = batch_size > 8 ? (batch_size / 8) * 8 : batch_size;
  5590. const int64_t batch_n = (patch_total + patches_per_batch - 1) / patches_per_batch;
  5591. GGML_ASSERT(patches_per_batch > 0 && batch_size >= 1);
  5592. void * tmp = params->wdata;
  5593. for (int64_t batch_i = 0; batch_i < batch_n; ++batch_i) {
  5594. const int64_t patch_start_batch = batch_i * patches_per_batch;
  5595. const int64_t patch_end_batch = std::min(patch_start_batch + patches_per_batch,
  5596. patch_total);
  5597. const int64_t patch_n = patch_end_batch - patch_start_batch;
  5598. const int64_t patch_per_thread = (patch_n + params->nth - 1) / params->nth;
  5599. const int64_t patch_start = patch_start_batch + params->ith * patch_per_thread;
  5600. const int64_t patch_end = std::min(patch_start + patch_per_thread, patch_end_batch);
5601. // im2col for the output patches assigned to this thread
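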
  5602. for (int64_t p = patch_start; p < patch_end; ++p) {
  5603. const int64_t batch_n = p / (dst_w * dst_h);
  5604. const int64_t src_x = (p / dst_w) % dst_h;
  5605. const int64_t src_y = p % dst_w;
  5606. const float * src_base = (const float *)((const char *)src_data + batch_n * src->nb[3]);
  5607. char * dst_row = (char *) tmp + (p % patches_per_batch) * knl_n * traits->type_size;
  5608. for (int64_t ic = 0; ic < c_in; ++ic) {
  5609. for (int64_t ky = 0; ky < knl_h; ++ky) {
  5610. for (int64_t kx = 0; kx < knl_w; ++kx) {
  5611. const int64_t sy = src_x * stride_y + ky * dilation_y - pad_y;
  5612. const int64_t sx = src_y * stride_x + kx * dilation_x - pad_x;
  5613. int64_t dst_idx = ic * (knl_h * knl_w) + ky * knl_w + kx;
  5614. float src_val;
  5615. if (sy < 0 || sy >= src_h || sx < 0 || sx >= src_w) {
  5616. src_val = 0.0f;
  5617. } else {
  5618. const float * src_ptr = (const float *)((const char *)src_base + sx * src->nb[0] + sy * src->nb[1] + ic * src->nb[2]);
  5619. src_val = *src_ptr;
  5620. }
  5621. char * element_ptr = dst_row + dst_idx * traits->type_size;
  5622. if (kernel_type == GGML_TYPE_F32) {
  5623. *(float *) element_ptr = src_val;
  5624. } else if (kernel_type == GGML_TYPE_F16) {
  5625. *(ggml_fp16_t *) element_ptr = GGML_CPU_FP32_TO_FP16(src_val);
  5626. }
  5627. }
  5628. }
  5629. }
  5630. } // patches handled by this thread
  5631. ggml_barrier(params->threadpool);
  5632. float * gemm_output = (float *) ((char *) tmp + patches_per_batch * knl_n * traits->type_size);
  5633. GGML_ASSERT(gemm_output + patch_n * c_out <= (float*)tmp + params->wsize);
  5634. // GEMM: patches[patch_n, knl_n] × kernel[knl_n, c_out] = output[patch_n, c_out]
  5635. ggml_call_mul_mat(kernel_type, params, patch_n, c_out, knl_n, tmp, knl_data, gemm_output);
  5636. ggml_barrier(params->threadpool);
5637. // scatter the [patch_n, c_out] GEMM output rows back into dst ([OW, OH, OC, N])
  5638. const int64_t permute_per_thread = (patch_n + params->nth - 1) / params->nth;
  5639. const int64_t permute_start = params->ith * permute_per_thread;
  5640. const int64_t permute_end = std::min(permute_start + permute_per_thread, patch_n);
  5641. for (int64_t i = permute_start; i < permute_end; ++i) {
  5642. const int64_t p = patch_start_batch + i;
  5643. const int64_t batch_n = p / (dst_w * dst_h);
  5644. const int64_t dst_y = (p / dst_w) % dst_h;
  5645. const int64_t dst_x = p % dst_w;
  5646. for (int64_t oc = 0; oc < c_out; ++oc) {
  5647. const float value = gemm_output[i * c_out + oc];
  5648. float * dst_ptr = (float *)((char *)dst_data + dst_x * dst->nb[0] + dst_y * dst->nb[1] + oc * dst->nb[2] + batch_n * dst->nb[3]);
  5649. *dst_ptr = value;
  5650. }
  5651. }
  5652. }
  5653. }
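// [illustrative note] the implementation above processes output patches in batches sized to
// fit params->wsize; each batch performs (1) a threaded im2col into wdata, (2) one GEMM of
// patches[patch_n, knl_n] x kernel[knl_n, c_out] via ggml_call_mul_mat, and (3) a threaded
// scatter of the [patch_n, c_out] result into dst, with ggml_barrier separating the phases.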
  5654. void ggml_compute_forward_conv_2d(
  5655. const ggml_compute_params * params,
  5656. ggml_tensor * dst) {
  5657. const ggml_tensor * src0 = dst->src[0];
  5658. const ggml_tensor * src1 = dst->src[1];
  5659. ggml_compute_forward_conv_2d_impl(params, src0, src1, dst, src0->type);
  5660. }
  5661. // ggml_compute_forward_conv_transpose_2d
  5662. void ggml_compute_forward_conv_transpose_2d(
  5663. const ggml_compute_params * params,
  5664. ggml_tensor * dst) {
  5665. const ggml_tensor * src0 = dst->src[0];
  5666. const ggml_tensor * src1 = dst->src[1];
  5667. GGML_ASSERT(src0->type == GGML_TYPE_F16);
  5668. GGML_ASSERT(src1->type == GGML_TYPE_F32);
  5669. GGML_ASSERT( dst->type == GGML_TYPE_F32);
  5670. GGML_TENSOR_BINARY_OP_LOCALS
  5671. const int ith = params->ith;
  5672. const int nth = params->nth;
  5673. const int nk = ne00*ne01*ne02*ne03;
  5674. GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
  5675. GGML_ASSERT(nb10 == sizeof(float));
  5676. if (ith == 0) {
  5677. memset(params->wdata, 0, params->wsize);
  5678. // permute kernel data (src0) from (Kw x Kh x Cout x Cin) to (Cin x Kw x Kh x Cout)
  5679. {
  5680. ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + 0;
  5681. for (int64_t i03 = 0; i03 < ne03; i03++) {
  5682. for (int64_t i02 = 0; i02 < ne02; i02++) {
  5683. const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i03*nb03 + i02*nb02);
  5684. ggml_fp16_t * dst_data = wdata + i02*ne01*ne00*ne03;
  5685. for (int64_t i01 = 0; i01 < ne01; i01++) {
  5686. for (int64_t i00 = 0; i00 < ne00; i00++) {
  5687. dst_data[i01*ne00*ne03 + i00*ne03 + i03] = src[i01 * ne00 + i00];
  5688. }
  5689. }
  5690. }
  5691. }
  5692. }
  5693. // permute source data (src1) from (Sw x Sh x Cin) to (Cin x Sw x Sh)
  5694. {
  5695. ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + nk;
  5696. for (int i12 = 0; i12 < ne12; i12++) {
  5697. for (int i11 = 0; i11 < ne11; i11++) {
  5698. const float * const src = (float *)((char *) src1->data + i12*nb12 + i11*nb11);
  5699. ggml_fp16_t * dst_data = wdata + i11*ne10*ne12;
  5700. for (int i10 = 0; i10 < ne10; i10++) {
  5701. dst_data[i10*ne12 + i12] = GGML_CPU_FP32_TO_FP16(src[i10]);
  5702. }
  5703. }
  5704. }
  5705. }
  5706. memset(dst->data, 0, ggml_nbytes(dst));
  5707. }
  5708. ggml_barrier(params->threadpool);
  5709. const int32_t stride = ggml_get_op_params_i32(dst, 0);
  5710. // total patches in dst
  5711. const int np = ne2;
  5712. // patches per thread
  5713. const int dp = (np + nth - 1)/nth;
  5714. // patch range for this thread
  5715. const int ip0 = dp*ith;
  5716. const int ip1 = MIN(ip0 + dp, np);
  5717. ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + 0;
  5718. ggml_fp16_t * const wdata_src = wdata + nk;
  5719. for (int i2 = ip0; i2 < ip1; i2++) { // Cout
  5720. float * dst_data = (float *)((char *) dst->data + i2*nb2);
  5721. ggml_fp16_t * wdata_kernel = wdata + i2*ne01*ne00*ne03;
  5722. for (int i11 = 0; i11 < ne11; i11++) {
  5723. for (int i10 = 0; i10 < ne10; i10++) {
  5724. const int i1n = i11*ne10*ne12 + i10*ne12;
  5725. for (int i01 = 0; i01 < ne01; i01++) {
  5726. for (int i00 = 0; i00 < ne00; i00++) {
  5727. float v = 0;
  5728. ggml_vec_dot_f16(ne03, &v, 0,
  5729. wdata_src + i1n, 0,
  5730. wdata_kernel + i01*ne00*ne03 + i00*ne03, 0, 1);
  5731. dst_data[(i11*stride + i01)*ne0 + i10*stride + i00] += v;
  5732. }
  5733. }
  5734. }
  5735. }
  5736. }
  5737. }
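// [illustrative note] each source pixel scatters a Kw x Kh tile into dst at an offset scaled
// by `stride`, so the shape arithmetic assumed by the index (i11*stride + i01)*ne0 + i10*stride + i00
// is dst->ne[0] == (ne10 - 1)*stride + ne00 and dst->ne[1] == (ne11 - 1)*stride + ne01
// (no padding in this op).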
  5738. // ggml_compute_forward_conv_2d_dw
  5739. struct ggml_conv_2d_dw_params {
  5740. int64_t channels;
  5741. int64_t batch;
  5742. int64_t src_w;
  5743. int64_t src_h;
  5744. int64_t dst_w;
  5745. int64_t dst_h;
  5746. int64_t knl_w;
  5747. int64_t knl_h;
  5748. int stride_x;
  5749. int stride_y;
  5750. int pad_x;
  5751. int pad_y;
  5752. int dilation_x;
  5753. int dilation_y;
  5754. };
  5755. static void ggml_compute_forward_conv_2d_dw_cwhn(
  5756. const ggml_compute_params * params,
  5757. const ggml_tensor * src,
  5758. const ggml_tensor * kernel,
  5759. ggml_tensor * dst,
  5760. const ggml_conv_2d_dw_params & p) {
  5761. const int64_t c = p.channels;
  5762. const float * knl_data = (const float *)kernel->data;
  5763. const int64_t rows_total = p.dst_h * p.batch;
  5764. const int64_t rows_per_thread = (rows_total + params->nth - 1) / params->nth;
  5765. const int64_t row_start = params->ith * rows_per_thread;
  5766. const int64_t row_end = MIN(row_start + rows_per_thread, rows_total);
  5767. #ifdef GGML_SIMD
  5768. const int64_t pkg_size = GGML_F32_EPR;
  5769. const int64_t pkg_count = c / pkg_size;
  5770. const int64_t c_pkg_end = pkg_count * pkg_size;
  5771. #else
  5772. const int64_t c_pkg_end = 0;
  5773. #endif
  5774. for (int64_t row = row_start; row < row_end; ++row) {
  5775. const int64_t dst_y = row % p.dst_h;
  5776. const float * src_data = (const float *)src->data + (row / p.dst_h) * p.src_w * p.src_h * c;
  5777. for (int64_t dst_x = 0; dst_x < p.dst_w; ++dst_x) {
  5778. float * dst_data = (float *)dst->data + (row * p.dst_w + dst_x) * c;
  5779. const int64_t src_y_base = dst_y * p.stride_y - p.pad_y;
  5780. const int64_t src_x_base = dst_x * p.stride_x - p.pad_x;
  5781. #ifdef GGML_SIMD
  5782. // Vectorized loop
  5783. for (int64_t c_i = 0; c_i < c_pkg_end; c_i += pkg_size) {
  5784. GGML_F32_VEC sum = GGML_F32_VEC_ZERO;
  5785. for (int64_t knl_y = 0; knl_y < p.knl_h; ++knl_y) {
  5786. const int64_t src_y = src_y_base + knl_y * p.dilation_y;
  5787. if (src_y < 0 || src_y >= p.src_h) {
  5788. continue;
  5789. }
  5790. for (int64_t knl_x = 0; knl_x < p.knl_w; ++knl_x) {
  5791. const int64_t src_x = src_x_base + knl_x * p.dilation_x;
  5792. if (src_x < 0 || src_x >= p.src_w) {
  5793. continue;
  5794. }
  5795. GGML_F32_VEC k = GGML_F32_VEC_LOAD(knl_data + (knl_y * p.knl_w + knl_x) * c + c_i);
  5796. GGML_F32_VEC s = GGML_F32_VEC_LOAD(src_data + (src_y * p.src_w + src_x) * c + c_i);
  5797. sum = GGML_F32_VEC_FMA(sum, k, s);
  5798. }
  5799. }
  5800. GGML_F32_VEC_STORE(dst_data + c_i, sum);
  5801. }
  5802. #endif
  5803. // Scalar loop
  5804. for (int64_t c_i = c_pkg_end; c_i < c; ++c_i) {
  5805. float sum = 0.0f;
  5806. for (int64_t knl_y = 0; knl_y < p.knl_h; ++knl_y) {
  5807. const int64_t src_y = src_y_base + knl_y * p.dilation_y;
  5808. if (src_y < 0 || src_y >= p.src_h) {
  5809. continue;
  5810. }
  5811. for (int64_t knl_x = 0; knl_x < p.knl_w; ++knl_x) {
  5812. const int64_t src_x = src_x_base + knl_x * p.dilation_x;
  5813. if (src_x < 0 || src_x >= p.src_w) {
  5814. continue;
  5815. }
  5816. sum += knl_data[(knl_y * p.knl_w + knl_x) * c + c_i]
  5817. * src_data[(src_y * p.src_w + src_x) * c + c_i];
  5818. }
  5819. }
  5820. dst_data[c_i] = sum;
  5821. }
  5822. }
  5823. }
  5824. }
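// [illustrative note] in the CWHN path the channel dimension is innermost, so the C
// accumulators of one output pixel are contiguous: the GGML_SIMD branch handles GGML_F32_EPR
// channels per iteration and the scalar loop starting at c_pkg_end finishes the remaining
// c % GGML_F32_EPR channels (or all of them when GGML_SIMD is not defined).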
  5825. static void ggml_compute_forward_conv_2d_dw_whcn(
  5826. const ggml_compute_params * params,
  5827. const ggml_tensor * src,
  5828. const ggml_tensor * kernel,
  5829. ggml_tensor * dst,
  5830. const ggml_conv_2d_dw_params & p) {
  5831. const int64_t n = p.channels * p.batch;
  5832. const int64_t per_thread = (n + params->nth - 1) / params->nth;
  5833. const int64_t start = params->ith * per_thread;
  5834. const int64_t end = MIN(start + per_thread, n);
  5835. for (int64_t i = start; i < end; ++i) {
  5836. const float * knl_data = (const float *)kernel->data + (i % p.channels) * p.knl_w * p.knl_h;
  5837. const float * src_data = (const float *)src->data + i * p.src_w * p.src_h;
  5838. float * dst_data = (float *)dst->data + i * p.dst_w * p.dst_h;
  5839. for (int64_t dst_y = 0; dst_y < p.dst_h; ++dst_y) {
  5840. for (int64_t dst_x = 0; dst_x < p.dst_w; ++dst_x) {
  5841. float sum = 0.0f;
  5842. for (int64_t knl_y = 0; knl_y < p.knl_h; ++knl_y) {
  5843. const int64_t src_y = dst_y * p.stride_y + knl_y * p.dilation_y - p.pad_y;
  5844. if (src_y < 0 || src_y >= p.src_h) {
  5845. continue;
  5846. }
  5847. for (int64_t knl_x = 0; knl_x < p.knl_w; ++knl_x) {
  5848. const int64_t src_x = dst_x * p.stride_x + knl_x * p.dilation_x - p.pad_x;
  5849. if (src_x < 0 || src_x >= p.src_w) {
  5850. continue;
  5851. }
  5852. sum += knl_data[knl_y * p.knl_w + knl_x]
  5853. * src_data[src_y * p.src_w + src_x];
  5854. }
  5855. }
  5856. dst_data[dst_y * p.dst_w + dst_x] = sum;
  5857. }
  5858. }
  5859. }
  5860. }
  5861. void ggml_compute_forward_conv_2d_dw(
  5862. const ggml_compute_params * params,
  5863. ggml_tensor * dst) {
  5864. const ggml_tensor * kernel = dst->src[0];
  5865. const ggml_tensor * src = dst->src[1];
  5866. ggml_conv_2d_dw_params p;
  5867. p.channels = src->ne[2];
  5868. p.batch = src->ne[3];
  5869. p.src_w = src->ne[0];
  5870. p.src_h = src->ne[1];
  5871. p.dst_w = dst->ne[0];
  5872. p.dst_h = dst->ne[1];
  5873. p.knl_w = kernel->ne[0];
  5874. p.knl_h = kernel->ne[1];
  5875. p.stride_x = dst->op_params[0];
  5876. p.stride_y = dst->op_params[1];
  5877. p.pad_x = dst->op_params[2];
  5878. p.pad_y = dst->op_params[3];
  5879. p.dilation_x = dst->op_params[4];
  5880. p.dilation_y = dst->op_params[5];
  5881. GGML_ASSERT(kernel->ne[3] == p.channels);
  5882. GGML_ASSERT(dst->ne[3] == p.batch);
  5883. if (ggml_is_contiguous(src)) {
  5884. ggml_compute_forward_conv_2d_dw_whcn(params, src, kernel, dst, p);
  5885. } else if (ggml_is_contiguous_channels(src)) {
  5886. // kernel should also have channels most contiguous in memory
  5887. GGML_ASSERT(kernel->nb[0] >= kernel->nb[2] && kernel->nb[1] >= kernel->nb[0]);
  5888. ggml_compute_forward_conv_2d_dw_cwhn(params, src, kernel, dst, p);
  5889. } else {
  5890. GGML_ABORT("non-contiguous memory layout not supported");
  5891. }
  5892. }
  5893. // ggml_compute_forward_pool_1d_sk_p0
  5894. static void ggml_compute_forward_pool_1d_sk_p0(
  5895. const ggml_compute_params * params,
  5896. const ggml_op_pool op,
  5897. const int k,
  5898. ggml_tensor * dst) {
  5899. const ggml_tensor * src = dst->src[0];
  5900. assert(src->type == GGML_TYPE_F32 || src->type == GGML_TYPE_F16);
  5901. if (params->ith != 0) {
  5902. return;
  5903. }
  5904. const char * cdata = (const char *)src->data;
  5905. const char * const data_end = cdata + ggml_nbytes(src);
  5906. float * drow = (float *)dst->data;
  5907. const int64_t rs = dst->ne[0];
  5908. while (cdata < data_end) {
  5909. const void * srow = (const void *)cdata;
  5910. int j = 0;
  5911. for (int64_t i = 0; i < rs; ++i) {
  5912. switch (op) {
  5913. case GGML_OP_POOL_AVG: drow[i] = 0; break;
  5914. case GGML_OP_POOL_MAX: drow[i] = -FLT_MAX; break;
  5915. case GGML_OP_POOL_COUNT: GGML_ABORT("fatal error");
  5916. }
  5917. for (int ki = 0; ki < k; ++ki) {
  5918. const float srow_j = (src->type == GGML_TYPE_F32) ? ((const float*)srow)[j] : GGML_CPU_FP16_TO_FP32(((const ggml_fp16_t*)srow)[j]);
  5919. switch (op) {
  5920. case GGML_OP_POOL_AVG: drow[i] += srow_j; break;
  5921. case GGML_OP_POOL_MAX: if (srow_j > drow[i]) drow[i] = srow_j; break;
  5922. case GGML_OP_POOL_COUNT: GGML_ABORT("fatal error");
  5923. }
  5924. ++j;
  5925. }
  5926. switch (op) {
  5927. case GGML_OP_POOL_AVG: drow[i] /= k; break;
  5928. case GGML_OP_POOL_MAX: break;
  5929. case GGML_OP_POOL_COUNT: GGML_ABORT("fatal error");
  5930. }
  5931. }
  5932. cdata += src->nb[1];
  5933. drow += rs;
  5934. }
  5935. }
  5936. // ggml_compute_forward_pool_1d
  5937. void ggml_compute_forward_pool_1d(
  5938. const ggml_compute_params * params,
  5939. ggml_tensor * dst) {
  5940. const int32_t * opts = (const int32_t *)dst->op_params;
  5941. ggml_op_pool op = static_cast<ggml_op_pool>(opts[0]);
  5942. const int k0 = opts[1];
  5943. const int s0 = opts[2];
  5944. const int p0 = opts[3];
  5945. GGML_ASSERT(p0 == 0); // padding not supported
  5946. GGML_ASSERT(k0 == s0); // only s = k supported
  5947. ggml_compute_forward_pool_1d_sk_p0(params, op, k0, dst);
  5948. }
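// [illustrative note] only the k == s, p == 0 configuration is implemented, i.e.
// non-overlapping windows: dst->ne[0] == src->ne[0]/k and output element i reduces
// src[i*k .. i*k + k - 1] (sum divided by k for AVG, running max for MAX).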
  5949. // ggml_compute_forward_pool_2d
  5950. void ggml_compute_forward_pool_2d(
  5951. const ggml_compute_params * params,
  5952. ggml_tensor * dst) {
  5953. const ggml_tensor * src = dst->src[0];
  5954. assert(src->type == GGML_TYPE_F32 || src->type == GGML_TYPE_F16);
  5955. if (params->ith != 0) {
  5956. return;
  5957. }
  5958. const int32_t * opts = (const int32_t *)dst->op_params;
  5959. ggml_op_pool op = static_cast<ggml_op_pool>(opts[0]);
  5960. const int k0 = opts[1];
  5961. const int k1 = opts[2];
  5962. const int s0 = opts[3];
  5963. const int s1 = opts[4];
  5964. const int p0 = opts[5];
  5965. const int p1 = opts[6];
  5966. const char * cdata = (const char*)src->data;
  5967. const char * const data_end = cdata + ggml_nbytes(src);
  5968. const int64_t px = dst->ne[0];
  5969. const int64_t py = dst->ne[1];
  5970. const int64_t pa = px * py;
  5971. float * dplane = (float *)dst->data;
  5972. const int ka = k0 * k1;
  5973. const int offset0 = -p0;
  5974. const int offset1 = -p1;
  5975. while (cdata < data_end) {
  5976. for (int oy = 0; oy < py; ++oy) {
  5977. float * const drow = dplane + oy * px;
  5978. for (int ox = 0; ox < px; ++ox) {
  5979. float * const out = drow + ox;
  5980. switch (op) {
  5981. case GGML_OP_POOL_AVG: *out = 0; break;
  5982. case GGML_OP_POOL_MAX: *out = -FLT_MAX; break;
  5983. case GGML_OP_POOL_COUNT: GGML_ABORT("fatal error");
  5984. }
  5985. const int ix = offset0 + ox * s0;
  5986. const int iy = offset1 + oy * s1;
  5987. for (int ky = 0; ky < k1; ++ky) {
  5988. if (iy + ky < 0 || iy + ky >= src->ne[1]) continue;
  5989. const void * srow = (const void *)(cdata + src->nb[1] * (iy + ky));
  5990. for (int kx = 0; kx < k0; ++kx) {
  5991. int j = ix + kx;
  5992. if (j < 0 || j >= src->ne[0]) continue;
  5993. const float srow_j = (src->type == GGML_TYPE_F32) ? ((const float*)srow)[j] : GGML_CPU_FP16_TO_FP32(((const ggml_fp16_t*)srow)[j]);
  5994. switch (op) {
  5995. case GGML_OP_POOL_AVG: *out += srow_j; break;
  5996. case GGML_OP_POOL_MAX: if (srow_j > *out) *out = srow_j; break;
  5997. case GGML_OP_POOL_COUNT: GGML_ABORT("fatal error");
  5998. }
  5999. }
  6000. }
  6001. switch (op) {
  6002. case GGML_OP_POOL_AVG: *out /= ka; break;
  6003. case GGML_OP_POOL_MAX: break;
  6004. case GGML_OP_POOL_COUNT: GGML_ABORT("fatal error");
  6005. }
  6006. }
  6007. }
  6008. cdata += src->nb[2];
  6009. dplane += pa;
  6010. }
  6011. }
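// [illustrative note] the output extent assumed here follows the usual pooling arithmetic,
// roughly dst->ne[0] == (src->ne[0] + 2*p0 - k0)/s0 + 1 (and likewise for ne[1]);
// out-of-range taps are skipped, while AVG still divides by the full window size k0*k1.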
  6012. // ggml_compute_forward_pool_2d_back
  6013. void ggml_compute_forward_pool_2d_back(
  6014. const ggml_compute_params * params,
  6015. ggml_tensor * dst) {
  6016. const ggml_tensor * src = dst->src[0];
  6017. const ggml_tensor * dstf = dst->src[1]; // forward tensor of dst
  6018. assert(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
  6019. if (params->ith != 0) {
  6020. return;
  6021. }
  6022. const int32_t * opts = (const int32_t *)dst->op_params;
  6023. ggml_op_pool op = static_cast<ggml_op_pool>(opts[0]);
  6024. const int k0 = opts[1];
  6025. const int k1 = opts[2];
  6026. const int s0 = opts[3];
  6027. const int s1 = opts[4];
  6028. const int p0 = opts[5];
  6029. const int p1 = opts[6];
  6030. char * cdata = (char *) dst->data;
  6031. const char * cdataf = (const char *) dstf->data;
  6032. const char * const data_end = cdata + ggml_nbytes(dst);
  6033. GGML_ASSERT(params->ith == 0);
  6034. memset(cdata, 0, ggml_nbytes(dst));
  6035. const int64_t px = src->ne[0];
  6036. const int64_t py = src->ne[1];
  6037. const int64_t pa = px * py;
  6038. const float * splane = (const float *) src->data;
  6039. const int ka = k0 * k1;
  6040. const int offset0 = -p0;
  6041. const int offset1 = -p1;
  6042. while (cdata < data_end) {
  6043. for (int oy = 0; oy < py; ++oy) {
  6044. const float * const srow = splane + oy * px;
  6045. for (int ox = 0; ox < px; ++ox) {
  6046. const float grad0 = srow[ox];
  6047. const int ix = offset0 + ox * s0;
  6048. const int iy = offset1 + oy * s1;
  6049. if (op == GGML_OP_POOL_MAX) {
  6050. float maxval = -FLT_MAX;
  6051. int kxmax = -1;
  6052. int kymax = -1;
  6053. for (int ky = 0; ky < k1; ++ky) {
  6054. if (iy + ky < 0 || iy + ky >= dst->ne[1]) {
  6055. continue;
  6056. }
  6057. const void * drowf = (const void *)(cdataf + dst->nb[1] * (iy + ky));
  6058. for (int kx = 0; kx < k0; ++kx) {
  6059. int j = ix + kx;
  6060. if (j < 0 || j >= dst->ne[0]) {
  6061. continue;
  6062. }
  6063. const float val = dst->type == GGML_TYPE_F32 ?
  6064. ((const float *) drowf)[j] : GGML_CPU_FP16_TO_FP32(((const ggml_fp16_t *) drowf)[j]);
  6065. if (val <= maxval) {
  6066. continue;
  6067. }
  6068. maxval = val;
  6069. kxmax = kx;
  6070. kymax = ky;
  6071. }
  6072. }
  6073. if (kxmax == -1 || kymax == -1) {
  6074. continue;
  6075. }
  6076. void * drow = (void *)(cdata + dst->nb[1] * (iy + kymax));
  6077. const int j = ix + kxmax;
  6078. if (dst->type == GGML_TYPE_F32) {
  6079. ((float *) drow)[j] += grad0;
  6080. } else {
  6081. ((ggml_fp16_t *) drow)[j] = GGML_CPU_FP32_TO_FP16(grad0 + GGML_CPU_FP16_TO_FP32(((const ggml_fp16_t *) drow)[j]));
  6082. }
  6083. } else if (op == GGML_OP_POOL_AVG) {
  6084. const float grad = grad0 / ka;
  6085. for (int ky = 0; ky < k1; ++ky) {
  6086. if (iy + ky < 0 || iy + ky >= dst->ne[1]) {
  6087. continue;
  6088. }
  6089. void * drow = (void *)(cdata + dst->nb[1] * (iy + ky));
  6090. for (int kx = 0; kx < k0; ++kx) {
  6091. int j = ix + kx;
  6092. if (j < 0 || j >= dst->ne[0]) {
  6093. continue;
  6094. }
  6095. if (dst->type == GGML_TYPE_F32) {
  6096. ((float *) drow)[j] += grad;
  6097. } else {
  6098. ((ggml_fp16_t *) drow)[j] += GGML_CPU_FP32_TO_FP16(grad);
  6099. }
  6100. }
  6101. }
  6102. } else {
  6103. GGML_ASSERT(false);
  6104. }
  6105. }
  6106. }
  6107. cdata += dst->nb[2];
  6108. cdataf += dst->nb[2];
  6109. splane += pa;
  6110. }
  6111. }
  6112. // ggml_compute_forward_upscale
  6113. static void ggml_compute_forward_upscale_f32(
  6114. const ggml_compute_params * params,
  6115. ggml_tensor * dst) {
  6116. const ggml_tensor * src0 = dst->src[0];
  6117. GGML_ASSERT(src0->type == GGML_TYPE_F32);
  6118. const int ith = params->ith;
  6119. const int nth = params->nth;
  6120. GGML_TENSOR_UNARY_OP_LOCALS
  6121. float sf0 = (float)ne0/src0->ne[0];
  6122. float sf1 = (float)ne1/src0->ne[1];
  6123. float sf2 = (float)ne2/src0->ne[2];
  6124. float sf3 = (float)ne3/src0->ne[3];
  6125. const int32_t mode_flags = ggml_get_op_params_i32(dst, 0);
  6126. const ggml_scale_mode mode = (ggml_scale_mode) (mode_flags & 0xFF);
  6127. if (mode == GGML_SCALE_MODE_NEAREST) {
  6128. for (int64_t i3 = 0; i3 < ne3; i3++) {
  6129. const int64_t i03 = i3 / sf3;
  6130. for (int64_t i2 = ith; i2 < ne2; i2 += nth) {
  6131. const int64_t i02 = i2 / sf2;
  6132. for (int64_t i1 = 0; i1 < ne1; i1++) {
  6133. const int64_t i01 = i1 / sf1;
  6134. for (int64_t i0 = 0; i0 < ne0; i0++) {
  6135. const int64_t i00 = i0 / sf0;
  6136. const float * x = (float *)((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
  6137. float * y = (float *)((char *) dst->data + i0*nb0 + i1*nb1 + i2*nb2 + i3*nb3);
  6138. *y = *x;
  6139. }
  6140. }
  6141. }
  6142. }
  6143. } else if (mode == GGML_SCALE_MODE_BILINEAR) {
  6144. float pixel_offset = 0.5f;
  6145. if (mode_flags & GGML_SCALE_FLAG_ALIGN_CORNERS) {
  6146. pixel_offset = 0.0f;
  6147. sf0 = (float)(ne0 - 1) / (src0->ne[0] - 1);
  6148. sf1 = (float)(ne1 - 1) / (src0->ne[1] - 1);
  6149. }
  6150. for (int64_t i3 = 0; i3 < ne3; i3++) {
  6151. const int64_t i03 = i3 / sf3;
  6152. for (int64_t i2 = ith; i2 < ne2; i2 += nth) {
  6153. const int64_t i02 = i2 / sf2;
  6154. for (int64_t i1 = 0; i1 < ne1; i1++) {
  6155. const float y = ((float)i1 + pixel_offset) / sf1 - pixel_offset;
  6156. int64_t y0 = (int64_t)floorf(y);
  6157. int64_t y1 = y0 + 1;
  6158. y0 = std::max(int64_t(0), std::min(y0, ne01 - 1));
  6159. y1 = std::max(int64_t(0), std::min(y1, ne01 - 1));
  6160. float dy = y - (float)y0;
  6161. dy = std::max(0.0f, std::min(dy, 1.0f));
  6162. for (int64_t i0 = 0; i0 < ne0; i0++) {
  6163. const float x = ((float)i0 + pixel_offset) / sf0 - pixel_offset;
  6164. int64_t x0 = (int64_t)floorf(x);
  6165. int64_t x1 = x0 + 1;
  6166. x0 = std::max(int64_t(0), std::min(x0, ne00 - 1));
  6167. x1 = std::max(int64_t(0), std::min(x1, ne00 - 1));
  6168. float dx = x - (float)x0;
  6169. dx = std::max(0.0f, std::min(dx, 1.0f));
  6170. // fetch the four surrounding pixel values and interpolate
  6171. const float a = *(const float *)((const char *)src0->data + x0*nb00 + y0*nb01 + i02*nb02 + i03*nb03);
  6172. const float b = *(const float *)((const char *)src0->data + x1*nb00 + y0*nb01 + i02*nb02 + i03*nb03);
  6173. const float c = *(const float *)((const char *)src0->data + x0*nb00 + y1*nb01 + i02*nb02 + i03*nb03);
  6174. const float d = *(const float *)((const char *)src0->data + x1*nb00 + y1*nb01 + i02*nb02 + i03*nb03);
  6175. const float val = a*(1 - dx)*(1 - dy) + b*dx*(1 - dy) + c*(1 - dx)*dy + d*dx*dy;
  6176. float * y_dst = (float *)((char *)dst->data + i0*nb0 + i1*nb1 + i2*nb2 + i3*nb3);
  6177. *y_dst = val;
  6178. }
  6179. }
  6180. }
  6181. }
  6182. } else {
  6183. GGML_ABORT("unsupported upscale mode");
  6184. }
  6185. }
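// [illustrative note] with GGML_SCALE_FLAG_ALIGN_CORNERS the sample position is
// x = i0/sf0 with sf0 = (ne0 - 1)/(ne00 - 1), so the first and last output pixels map exactly
// onto the first and last input pixels; the default half-pixel mode uses
// x = (i0 + 0.5f)/sf0 - 0.5f instead, and both clamp to the valid source range.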
  6186. void ggml_compute_forward_upscale(
  6187. const ggml_compute_params * params,
  6188. ggml_tensor * dst) {
  6189. const ggml_tensor * src0 = dst->src[0];
  6190. switch (src0->type) {
  6191. case GGML_TYPE_F32:
  6192. {
  6193. ggml_compute_forward_upscale_f32(params, dst);
  6194. } break;
  6195. default:
  6196. {
  6197. GGML_ABORT("fatal error");
  6198. }
  6199. }
  6200. }
  6201. // ggml_compute_forward_pad
  6202. static void ggml_compute_forward_pad_f32(
  6203. const ggml_compute_params * params,
  6204. ggml_tensor * dst) {
  6205. const ggml_tensor * src0 = dst->src[0];
  6206. GGML_ASSERT(src0->nb[0] == sizeof(float));
  6207. GGML_ASSERT( dst->nb[0] == sizeof(float));
  6208. const int ith = params->ith;
  6209. const int nth = params->nth;
  6210. GGML_TENSOR_UNARY_OP_LOCALS
  6211. float * dst_ptr = (float *) dst->data;
  6212. // TODO: optimize
  6213. for (int64_t i2 = 0; i2 < ne2; ++i2) {
  6214. for (int64_t i1 = ith; i1 < ne1; i1 += nth) {
  6215. for (int64_t i0 = 0; i0 < ne0; ++i0) {
  6216. for (int64_t i3 = 0; i3 < ne3; ++i3) {
  6217. const int64_t dst_idx = i3*(ne0*ne1*ne2) + i2*(ne0*ne1) + i1*ne0 + i0;
  6218. const float * src_ptr = (const float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
  6219. if (i0 < ne00 && i1 < ne01 && i2 < ne02 && i3 < ne03) {
  6220. dst_ptr[dst_idx] = *src_ptr;
  6221. } else {
  6222. dst_ptr[dst_idx] = 0;
  6223. }
  6224. }
  6225. }
  6226. }
  6227. }
  6228. }
  6229. void ggml_compute_forward_pad(
  6230. const ggml_compute_params * params,
  6231. ggml_tensor * dst) {
  6232. const ggml_tensor * src0 = dst->src[0];
  6233. switch (src0->type) {
  6234. case GGML_TYPE_F32:
  6235. {
  6236. ggml_compute_forward_pad_f32(params, dst);
  6237. } break;
  6238. default:
  6239. {
  6240. GGML_ABORT("fatal error");
  6241. }
  6242. }
  6243. }
  6244. // ggml_compute_forward_pad_reflect_1d
  6245. void ggml_compute_forward_pad_reflect_1d(
  6246. const ggml_compute_params * params,
  6247. ggml_tensor * dst) {
  6248. const ggml_tensor * src0 = dst->src[0];
  6249. GGML_ASSERT(src0->type == GGML_TYPE_F32);
  6250. GGML_ASSERT( dst->type == GGML_TYPE_F32);
  6251. const int ith = params->ith;
  6252. const int nth = params->nth;
  6253. const int32_t * opts = (const int32_t *) dst->op_params;
  6254. const int p0 = opts[0];
  6255. const int p1 = opts[1];
  6256. GGML_TENSOR_UNARY_OP_LOCALS
  6257. for (int64_t i3 = 0; i3 < ne3; i3++) {
  6258. for (int64_t i2 = 0; i2 < ne2; i2++) {
  6259. for (int64_t i1 = ith; i1 < ne1; i1 += nth) {
  6260. float * left = (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + p0*nb0);
  6261. float * right = (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + (ne0-p1-1)*nb0);
  6262. ggml_vec_cpy_f32(ne00, left, (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01));
  6263. for (int i0 = 1; i0 <= p0; i0++) { left[-i0] = left[i0]; }
  6264. for (int i0 = 1; i0 <= p1; i0++) { right[i0] = right[-i0]; }
  6265. }
  6266. }
  6267. }
  6268. }
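// [illustrative note] reflection padding mirrors around the edge element without repeating
// it: a row [a b c d] with p0 = 2, p1 = 1 becomes [c b | a b c d | c], which is exactly what
// left[-i0] = left[i0] and right[i0] = right[-i0] produce above.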
  6269. // ggml_compute_forward_roll
  6270. static int64_t ggml_wrap_index(int64_t i, int64_t ne) {
  6271. if (i < 0) {
  6272. return i + ne;
  6273. } else if (i >= ne) {
  6274. return i - ne;
  6275. }
  6276. return i;
  6277. }
  6278. static void ggml_compute_forward_roll_f32(
  6279. const ggml_compute_params * params,
  6280. ggml_tensor * dst) {
  6281. const ggml_tensor * src0 = dst->src[0];
  6282. const float * src_data = (const float *) src0->data;
  6283. float * dst_data = (float *) dst->data;
  6284. GGML_TENSOR_UNARY_OP_LOCALS
  6285. const int s0 = ggml_get_op_params_i32(dst, 0);
  6286. const int s1 = ggml_get_op_params_i32(dst, 1);
  6287. const int s2 = ggml_get_op_params_i32(dst, 2);
  6288. const int s3 = ggml_get_op_params_i32(dst, 3);
  6289. const int64_t total = ne1 * ne2 * ne3;
  6290. const int64_t per_thread = (total + params->nth) / params->nth;
  6291. const int64_t start = params->ith * per_thread;
  6292. const int64_t end = std::min(start + per_thread, total);
  6293. for (int64_t i = start; i < end; ++i) {
  6294. const int64_t i1 = i % ne1;
  6295. const int64_t i2 = (i / ne1) % ne2;
  6296. const int64_t i3 = i / (ne2 * ne1);
  6297. float * dst_row = dst_data + (i3*nb3 + i2*nb2 + i1*nb1) / sizeof(float);
  6298. const int64_t i01 = ggml_wrap_index(i1 - s1, ne01);
  6299. const int64_t i02 = ggml_wrap_index(i2 - s2, ne02);
  6300. const int64_t i03 = ggml_wrap_index(i3 - s3, ne03);
  6301. const float * src_row = src_data + (i03*nb03 + i02*nb02 + i01*nb01) / sizeof(float);
  6302. const int64_t s = ggml_wrap_index(-s0, ne00);
  6303. const int64_t n = ne00 - s;
  6304. ggml_vec_cpy_f32(n, dst_row, src_row + s);
  6305. ggml_vec_cpy_f32(s, dst_row + n, src_row);
  6306. }
  6307. }
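// [illustrative note] each destination row is assembled from one wrapped source row with two
// contiguous copies: with s = ggml_wrap_index(-s0, ne00) and n = ne00 - s,
// dst[0..n) = src[s..ne00) and dst[n..ne00) = src[0..s);
// e.g. ne00 = 5, s0 = 2 gives {x3, x4, x0, x1, x2}.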
  6308. void ggml_compute_forward_roll(
  6309. const ggml_compute_params * params,
  6310. ggml_tensor * dst) {
  6311. const ggml_tensor * src0 = dst->src[0];
  6312. switch (src0->type) {
  6313. case GGML_TYPE_F32:
  6314. {
  6315. ggml_compute_forward_roll_f32(params, dst);
  6316. } break;
  6317. default:
  6318. {
  6319. GGML_ABORT("fatal error");
  6320. }
  6321. }
  6322. }
  6323. // ggml_compute_forward_arange
  6324. static void ggml_compute_forward_arange_f32(
  6325. const ggml_compute_params * params,
  6326. ggml_tensor * dst) {
  6327. GGML_ASSERT(dst->nb[0] == sizeof(float));
  6328. const int ith = params->ith;
  6329. const int nth = params->nth;
  6330. const float start = ggml_get_op_params_f32(dst, 0);
  6331. const float stop = ggml_get_op_params_f32(dst, 1);
  6332. const float step = ggml_get_op_params_f32(dst, 2);
  6333. const int64_t steps = (int64_t) ceilf((stop - start) / step);
  6334. GGML_ASSERT(ggml_nelements(dst) == steps);
  6335. for (int64_t i = ith; i < steps; i+= nth) {
  6336. float value = start + step * i;
  6337. ((float *)dst->data)[i] = value;
  6338. }
  6339. }
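// [illustrative note] the destination must already hold ceil((stop - start)/step) elements;
// e.g. start = 0, stop = 5, step = 2 yields 3 elements {0, 2, 4}, with thread ith filling
// indices ith, ith + nth, ith + 2*nth, ...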
  6340. void ggml_compute_forward_arange(
  6341. const ggml_compute_params * params,
  6342. ggml_tensor * dst) {
  6343. switch (dst->type) {
  6344. case GGML_TYPE_F32:
  6345. {
  6346. ggml_compute_forward_arange_f32(params, dst);
  6347. } break;
  6348. default:
  6349. {
  6350. GGML_ABORT("fatal error");
  6351. }
  6352. }
  6353. }
  6354. static void ggml_compute_forward_timestep_embedding_f32(
  6355. const ggml_compute_params * params,
  6356. ggml_tensor * dst) {
  6357. const ggml_tensor * src0 = dst->src[0];
  6358. GGML_ASSERT(src0->nb[0] == sizeof(float));
  6359. const int ith = params->ith;
  6360. const int nth = params->nth;
  6361. GGML_TENSOR_UNARY_OP_LOCALS
  6362. const int dim = ggml_get_op_params_i32(dst, 0);
  6363. const int max_period = ggml_get_op_params_i32(dst, 1);
  6364. int half = dim / 2;
  6365. for (int64_t i = 0; i < ne00; i++) {
  6366. float * embed_data = (float *)((char *) dst->data + i*nb1);
  6367. for (int64_t j = ith; j < half; j += nth) {
  6368. float timestep = ((float *)src0->data)[i];
  6369. float freq = (float)expf(-logf(max_period) * j / half);
  6370. float arg = timestep * freq;
  6371. embed_data[j] = cosf(arg);
  6372. embed_data[j + half] = sinf(arg);
  6373. }
  6374. if (dim % 2 != 0 && ith == 0) {
  6375. embed_data[dim] = 0.f;
  6376. }
  6377. }
  6378. }
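// [illustrative note] this is the standard sinusoidal timestep embedding: for j < half,
//   freq_j = expf(-logf(max_period)*j/half), embed[j] = cosf(t*freq_j), embed[j + half] = sinf(t*freq_j),
// with a single trailing zero written when dim is odd.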
  6379. void ggml_compute_forward_timestep_embedding(
  6380. const ggml_compute_params * params,
  6381. ggml_tensor * dst) {
  6382. const ggml_tensor * src0 = dst->src[0];
  6383. switch (src0->type) {
  6384. case GGML_TYPE_F32:
  6385. {
  6386. ggml_compute_forward_timestep_embedding_f32(params, dst);
  6387. } break;
  6388. default:
  6389. {
  6390. GGML_ABORT("fatal error");
  6391. }
  6392. }
  6393. }
  6394. // ggml_compute_forward_argsort
  6395. static void ggml_compute_forward_argsort_f32(
  6396. const ggml_compute_params * params,
  6397. ggml_tensor * dst) {
  6398. const ggml_tensor * src0 = dst->src[0];
  6399. GGML_TENSOR_UNARY_OP_LOCALS
  6400. GGML_ASSERT(nb0 == sizeof(float));
  6401. const int ith = params->ith;
  6402. const int nth = params->nth;
  6403. const int64_t nr = ggml_nrows(src0);
  6404. ggml_sort_order order = (ggml_sort_order) ggml_get_op_params_i32(dst, 0);
  6405. for (int64_t i = ith; i < nr; i += nth) {
  6406. int32_t * dst_data = (int32_t *)((char *) dst->data + i*nb1);
  6407. const float * src_data = (float *)((char *) src0->data + i*nb01);
  6408. for (int64_t j = 0; j < ne0; j++) {
  6409. dst_data[j] = j;
  6410. }
6411. // reference path: a simple O(n^2) exchange sort of the index array (keeps the kernel allocation-free)
  6412. for (int64_t j = 0; j < ne0; j++) {
  6413. for (int64_t k = j + 1; k < ne0; k++) {
  6414. if ((order == GGML_SORT_ORDER_ASC && src_data[dst_data[j]] > src_data[dst_data[k]]) ||
  6415. (order == GGML_SORT_ORDER_DESC && src_data[dst_data[j]] < src_data[dst_data[k]])) {
  6416. int32_t tmp = dst_data[j];
  6417. dst_data[j] = dst_data[k];
  6418. dst_data[k] = tmp;
  6419. }
  6420. }
  6421. }
  6422. }
  6423. }
  6424. void ggml_compute_forward_argsort(
  6425. const ggml_compute_params * params,
  6426. ggml_tensor * dst) {
  6427. const ggml_tensor * src0 = dst->src[0];
  6428. switch (src0->type) {
  6429. case GGML_TYPE_F32:
  6430. {
  6431. ggml_compute_forward_argsort_f32(params, dst);
  6432. } break;
  6433. default:
  6434. {
  6435. GGML_ABORT("fatal error");
  6436. }
  6437. }
  6438. }
  6439. // ggml_compute_forward_flash_attn_ext
  6440. static void ggml_compute_forward_flash_attn_ext_f16(
  6441. const ggml_compute_params * params,
  6442. const ggml_tensor * q,
  6443. const ggml_tensor * k,
  6444. const ggml_tensor * v,
  6445. const ggml_tensor * mask,
  6446. ggml_tensor * dst) {
  6447. GGML_TENSOR_LOCALS(int64_t, neq, q, ne)
  6448. GGML_TENSOR_LOCALS(size_t, nbq, q, nb)
  6449. GGML_TENSOR_LOCALS(int64_t, nek, k, ne)
  6450. GGML_TENSOR_LOCALS(size_t, nbk, k, nb)
  6451. GGML_TENSOR_LOCALS(int64_t, nev, v, ne)
  6452. GGML_TENSOR_LOCALS(size_t, nbv, v, nb)
  6453. GGML_TENSOR_LOCALS(int64_t, ne, dst, ne)
  6454. GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
  6455. const int ith = params->ith;
  6456. const int nth = params->nth;
  6457. const int64_t DK = nek0;
  6458. const int64_t DV = nev0;
  6459. const int64_t N = neq1;
  6460. GGML_ASSERT(ne0 == DV);
  6461. GGML_ASSERT(ne2 == N);
  6462. // input tensor rows must be contiguous
  6463. GGML_ASSERT(nbq0 == ggml_type_size(q->type));
  6464. GGML_ASSERT(nbk0 == ggml_type_size(k->type));
  6465. GGML_ASSERT(nbv0 == ggml_type_size(v->type));
  6466. GGML_ASSERT(neq0 == DK);
  6467. GGML_ASSERT(nek0 == DK);
  6468. GGML_ASSERT(nev0 == DV);
  6469. GGML_ASSERT(neq1 == N);
  6470. // dst cannot be transposed or permuted
  6471. GGML_ASSERT(nb0 == sizeof(float));
  6472. GGML_ASSERT(nb0 <= nb1);
  6473. GGML_ASSERT(nb1 <= nb2);
  6474. GGML_ASSERT(nb2 <= nb3);
  6475. // broadcast factors
  6476. const int64_t rk2 = neq2/nek2;
  6477. const int64_t rk3 = neq3/nek3;
  6478. const int64_t rv2 = neq2/nev2;
  6479. const int64_t rv3 = neq3/nev3;
  6480. // parallelize by q rows using ggml_vec_dot_f32
  6481. // total rows in q
  6482. const int nr = neq1*neq2*neq3;
  6483. // rows per thread
  6484. const int dr = (nr + nth - 1)/nth;
  6485. // row range for this thread
  6486. const int ir0 = dr*ith;
  6487. const int ir1 = MIN(ir0 + dr, nr);
  6488. float scale = 1.0f;
  6489. float max_bias = 0.0f;
  6490. float logit_softcap = 0.0f;
  6491. memcpy(&scale, (float *) dst->op_params + 0, sizeof(float));
  6492. memcpy(&max_bias, (float *) dst->op_params + 1, sizeof(float));
  6493. memcpy(&logit_softcap, (float *) dst->op_params + 2, sizeof(float));
  6494. if (logit_softcap != 0) {
  6495. scale /= logit_softcap;
  6496. }
  6497. const uint32_t n_head = neq2;
  6498. const uint32_t n_head_log2 = 1u << (uint32_t) floor(log2(n_head));
  6499. const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
  6500. const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
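// m0/m1 follow the usual ALiBi head-slope schedule: for example, with
// max_bias > 0 and n_head_log2 == 4, heads 0..3 get slope m0^(h+1) and heads
// h >= 4 get slope m1^(2*(h-4)+1) (see the per-head `slope` computation below).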
  6501. ggml_type const k_vec_dot_type = ggml_get_type_traits_cpu(k->type)->vec_dot_type;
  6502. ggml_from_float_t const q_to_vec_dot = ggml_get_type_traits_cpu(k_vec_dot_type)->from_float;
  6503. ggml_vec_dot_t const kq_vec_dot = ggml_get_type_traits_cpu(k->type)->vec_dot;
  6504. ggml_to_float_t const v_to_float = ggml_get_type_traits(v->type)->to_float;
  6505. GGML_ASSERT(( q_to_vec_dot) && "fattn: unsupported K-type");
  6506. GGML_ASSERT((v->type == GGML_TYPE_F32 || v_to_float ) && "fattn: unsupported V-type");
  6507. // loop over n_batch and n_head
  6508. for (int ir = ir0; ir < ir1; ++ir) {
  6509. // q indices
  6510. const int iq3 = ir/(neq2*neq1);
  6511. const int iq2 = (ir - iq3*neq2*neq1)/neq1;
  6512. const int iq1 = (ir - iq3*neq2*neq1 - iq2*neq1);
  6513. const uint32_t h = iq2; // head index
  6514. const float slope = (max_bias > 0.0f) ? h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2*(h - n_head_log2) + 1) : 1.0f;
  6515. float S = 0.0f; // sum
  6516. float M = -INFINITY; // maximum KQ value
  6517. float * VKQ32 = (float *) params->wdata + ith*(1*DK + 2*DV + CACHE_LINE_SIZE_F32); // FP32 VKQ accumulator
  6518. float * V32 = (VKQ32 + 1*DV); // (temporary) FP32 V buffer
  6519. ggml_fp16_t * VKQ16 = (ggml_fp16_t *) (VKQ32 + 1*DV); // (temporary) FP16 VKQ accumulator
  6520. ggml_fp16_t * Q_q = (ggml_fp16_t *) (VKQ32 + 2*DV); // (temporary) buffer for Q converted to quantized/FP16
  6521. if (v->type == GGML_TYPE_F16) {
  6522. memset(VKQ16, 0, DV*sizeof(ggml_fp16_t));
  6523. } else {
  6524. memset(VKQ32, 0, DV*sizeof(float));
  6525. }
  6526. const ggml_fp16_t * mp = mask ? (ggml_fp16_t *)((char *) mask->data + iq1*mask->nb[1] + (iq2%mask->ne[2])*mask->nb[2] + (iq3%mask->ne[3])*mask->nb[3]) : NULL;
  6527. // k indices
  6528. const int ik3 = iq3 / rk3;
  6529. const int ik2 = iq2 / rk2;
  6530. // v indices
  6531. const int iv3 = iq3 / rv3;
  6532. const int iv2 = iq2 / rv2;
  6533. const float * pq = (const float *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3));
  6534. q_to_vec_dot(pq, Q_q, DK);
  6535. // online softmax / attention
  6536. // loop over n_kv and n_head_kv
  6537. // ref: https://arxiv.org/pdf/2112.05682.pdf
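// Streaming-softmax bookkeeping used below: with running maximum M and running
// denominator S, each new scaled logit s is folded in as
//   M_new = max(M, s);  ms = expf(M - M_new);  vs = expf(s - M_new);
//   VKQ   = VKQ*ms + vs*V[ic];  S = S*ms + vs;
// so that after the loop VKQ/S equals softmax(Q*K^T)*V for this query row.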
  6538. for (int64_t ic = 0; ic < nek1; ++ic) {
  6539. const float mv = mp ? slope*GGML_CPU_FP16_TO_FP32(mp[ic]) : 0.0f;
  6540. if (mv == -INFINITY) {
  6541. continue;
  6542. }
  6543. float s; // KQ value
  6544. const char * k_data = (const char *) k->data + ( ic*nbk1 + ik2*nbk2 + ik3*nbk3);
  6545. kq_vec_dot(DK, &s, 0, k_data, 0, Q_q, 0, 1);
  6546. s = s*scale; // scale KQ value
  6547. if (logit_softcap != 0.0f) {
  6548. s = logit_softcap*tanhf(s);
  6549. }
  6550. s += mv; // apply mask
  6551. const float Mold = M;
  6552. float ms = 1.0f; // upon new higher max val, scale VKQ and KQ sum with this value
  6553. float vs = 1.0f; // post-softmax KQ value, expf(s - M)
  6554. const char * v_data = ((const char *) v->data + (ic*nbv1 + iv2*nbv2 + iv3*nbv3));
  6555. if (v->type == GGML_TYPE_F16) {
  6556. if (s > M) {
  6557. // s is new maximum, ms < 1.0f, vs == expf(s - s) == 1.0f
  6558. M = s;
  6559. ms = expf(Mold - M);
  6560. // V = V*expf(Mold - M)
  6561. ggml_vec_scale_f16(DV, VKQ16, ms);
  6562. } else {
  6563. // no new maximum, ms == 1.0f, vs != 1.0f
  6564. vs = expf(s - M);
  6565. }
  6566. // V += v*expf(s - M)
  6567. ggml_vec_mad_f16(DV, VKQ16, (const ggml_fp16_t *) v_data, vs);
  6568. } else {
  6569. if (s > M) {
  6570. // s is new maximum, ms < 1.0f, vs == expf(s - s) == 1.0f
  6571. M = s;
  6572. ms = expf(Mold - M);
  6573. // V = V*expf(Mold - M)
  6574. ggml_vec_scale_f32(DV, VKQ32, ms);
  6575. } else {
  6576. // no new maximum, ms == 1.0f, vs != 1.0f
  6577. vs = expf(s - M);
  6578. }
  6579. // V += v*expf(s - M)
  6580. if (v_to_float) {
  6581. v_to_float(v_data, V32, DV);
  6582. ggml_vec_mad_f32(DV, VKQ32, V32, vs);
  6583. } else {
  6584. // V is F32
  6585. ggml_vec_mad_f32(DV, VKQ32, (const float *) v_data, vs);
  6586. }
  6587. }
  6588. S = S*ms + vs; // scale and increment sum with partial sum
  6589. }
  6590. if (v->type == GGML_TYPE_F16) {
  6591. for (int64_t d = 0; d < DV; ++d) {
  6592. VKQ32[d] = GGML_CPU_FP16_TO_FP32(VKQ16[d]);
  6593. }
  6594. }
  6595. // V /= S
  6596. const float S_inv = 1.0f/S;
  6597. ggml_vec_scale_f32(DV, VKQ32, S_inv);
  6598. // dst indices
  6599. const int i1 = iq1;
  6600. const int i2 = iq2;
  6601. const int i3 = iq3;
  6602. // original
  6603. //memcpy((char *) dst->data + (i1*nb1 + i2*nb2 + i3*nb3), V, nev0*sizeof(float));
  6604. // permute(0, 2, 1, 3)
  6605. memcpy((char *) dst->data + (i3*ne2*ne1 + i2 + i1*ne1)*nb1, VKQ32, nb1);
  6606. }
  6607. }
  6608. void ggml_compute_forward_flash_attn_ext(
  6609. const ggml_compute_params * params,
  6610. const ggml_tensor * q,
  6611. const ggml_tensor * k,
  6612. const ggml_tensor * v,
  6613. const ggml_tensor * mask,
  6614. ggml_tensor * dst) {
  6615. switch (dst->op_params[3]) {
  6616. case GGML_PREC_DEFAULT:
  6617. case GGML_PREC_F32:
  6618. {
  6619. // uses F32 accumulators
  6620. ggml_compute_forward_flash_attn_ext_f16(params, q, k, v, mask, dst);
  6621. } break;
  6622. default:
  6623. {
  6624. GGML_ABORT("fatal error");
  6625. }
  6626. }
  6627. }
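// Hedged illustration (not part of ggml): a minimal scalar reference of the
// streaming-softmax attention computed above, for a single query row, assuming
// contiguous float arrays, no mask, no ALiBi slope and no logit softcap. The
// function name and signature are invented for this sketch; it only relies on
// expf/INFINITY from <math.h>, which this file already uses.
static void flash_attn_row_ref(const float * q,   // [DK]
                               const float * k,   // [n_kv][DK], row-major
                               const float * v,   // [n_kv][DV], row-major
                               int64_t n_kv, int64_t DK, int64_t DV,
                               float scale, float * out /* [DV] */) {
    float M = -INFINITY; // running maximum of the scaled logits
    float S = 0.0f;      // running softmax denominator
    for (int64_t d = 0; d < DV; ++d) out[d] = 0.0f;
    for (int64_t ic = 0; ic < n_kv; ++ic) {
        float s = 0.0f;
        for (int64_t d = 0; d < DK; ++d) s += q[d]*k[ic*DK + d];
        s *= scale;
        const float Mnew = s > M ? s : M;
        const float ms = expf(M - Mnew); // rescale the previous accumulators
        const float vs = expf(s - Mnew); // weight of the current key/value
        for (int64_t d = 0; d < DV; ++d) out[d] = out[d]*ms + vs*v[ic*DV + d];
        S = S*ms + vs;
        M = Mnew;
    }
    for (int64_t d = 0; d < DV; ++d) out[d] /= S; // final normalization
}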
  6628. // ggml_compute_forward_flash_attn_back
  6629. static void ggml_compute_forward_flash_attn_back_f32(
  6630. const ggml_compute_params * params,
  6631. const bool masked,
  6632. ggml_tensor * dst) {
  6633. const ggml_tensor * q = dst->src[0];
  6634. const ggml_tensor * k = dst->src[1];
  6635. const ggml_tensor * v = dst->src[2];
  6636. const ggml_tensor * d = dst->src[3];
  6637. GGML_TENSOR_LOCALS(int64_t, neq, q, ne)
  6638. GGML_TENSOR_LOCALS(size_t, nbq, q, nb)
  6639. GGML_TENSOR_LOCALS(int64_t, nek, k, ne)
  6640. GGML_TENSOR_LOCALS(size_t, nbk, k, nb)
  6641. GGML_TENSOR_LOCALS(int64_t, nev, v, ne)
  6642. GGML_TENSOR_LOCALS(size_t, nbv, v, nb)
  6643. GGML_TENSOR_LOCALS(int64_t, ned, d, ne)
  6644. GGML_TENSOR_LOCALS(size_t, nbd, d, nb)
  6645. GGML_TENSOR_LOCALS(int64_t, ne, dst, ne)
  6646. GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
  6647. const int ith = params->ith;
  6648. const int nth = params->nth;
  6649. const int64_t D = neq0;
  6650. const int64_t N = neq1;
  6651. const int64_t P = nek1 - N;
  6652. const int64_t M = P + N;
  6653. const int Mup = ggml_up(M, GGML_SOFT_MAX_UNROLL);
  6654. const int mxDM = MAX(D, Mup);
  6655. // GGML_ASSERT(ne0 == D);
  6656. // GGML_ASSERT(ne1 == N);
  6657. GGML_ASSERT(P >= 0);
  6658. GGML_ASSERT(nbq0 == sizeof(float));
  6659. GGML_ASSERT(nbk0 == sizeof(float));
  6660. GGML_ASSERT(nbv0 == sizeof(float));
  6661. GGML_ASSERT(neq0 == D);
  6662. GGML_ASSERT(nek0 == D);
  6663. GGML_ASSERT(nev1 == D);
  6664. GGML_ASSERT(ned0 == D);
  6665. GGML_ASSERT(neq1 == N);
  6666. GGML_ASSERT(nek1 == N + P);
  6667. GGML_ASSERT(nev1 == D);
  6668. GGML_ASSERT(ned1 == N);
  6669. // dst cannot be transposed or permuted
  6670. GGML_ASSERT(nb0 == sizeof(float));
  6671. GGML_ASSERT(nb0 <= nb1);
  6672. GGML_ASSERT(nb1 <= nb2);
  6673. GGML_ASSERT(nb2 <= nb3);
  6674. if (ith == 0) {
  6675. memset(dst->data, 0, nb0*ne0*ne1*ne2*ne3);
  6676. }
  6677. ggml_barrier(params->threadpool);
  6678. const int64_t elem_q = ggml_nelements(q);
  6679. const int64_t elem_k = ggml_nelements(k);
  6680. ggml_type result_type = dst->type;
  6681. GGML_ASSERT(ggml_blck_size(result_type) == 1);
  6682. const size_t tsize = ggml_type_size(result_type);
  6683. const size_t offs_q = 0;
  6684. const size_t offs_k = offs_q + GGML_PAD(elem_q * tsize, GGML_MEM_ALIGN);
  6685. const size_t offs_v = offs_k + GGML_PAD(elem_k * tsize, GGML_MEM_ALIGN);
  6686. void * grad_q = (char *) dst->data;
  6687. void * grad_k = (char *) dst->data + offs_k;
  6688. void * grad_v = (char *) dst->data + offs_v;
  6689. const size_t nbgq1 = nb0*neq0;
  6690. const size_t nbgq2 = nb0*neq0*neq1;
  6691. const size_t nbgq3 = nb0*neq0*neq1*neq2;
  6692. const size_t nbgk1 = nb0*nek0;
  6693. const size_t nbgk2 = nb0*nek0*nek1;
  6694. const size_t nbgk3 = nb0*nek0*nek1*neq2;
  6695. const size_t nbgv1 = nb0*nev0;
  6696. const size_t nbgv2 = nb0*nev0*nev1;
  6697. const size_t nbgv3 = nb0*nev0*nev1*neq2;
  6698. // parallelize by k rows using ggml_vec_dot_f32
  6699. // total rows in k
  6700. const int nr = nek2*nek3;
  6701. // rows per thread
  6702. const int dr = (nr + nth - 1)/nth;
  6703. // row range for this thread
  6704. const int ir0 = dr*ith;
  6705. const int ir1 = MIN(ir0 + dr, nr);
  6706. const float scale = 1.0f/sqrtf(D);
  6707. //printf("P=%d N=%d D=%d ir0=%d ir1=%d scale = %f\n", P, N, D, ir0, ir1, scale);
  6708. // how often k2 (and v2) is repeated in q2
  6709. int nrep = neq2/nek2;
  6710. for (int ir = ir0; ir < ir1; ++ir) {
  6711. // q indices
  6712. const int ik3 = ir/(nek2);
  6713. const int ik2 = ir - ik3*nek2;
  6714. const int iq3 = ik3;
  6715. const int id3 = ik3;
  6716. const int iv3 = ik3;
  6717. const int iv2 = ik2;
  6718. for (int irep = 0; irep < nrep; ++irep) {
  6719. const int iq2 = ik2 + irep*nek2;
  6720. const int id2 = iq2;
  6721. // (ik2 + irep*nek2) % nek2 == ik2
  6722. for (int iq1 = 0; iq1 < neq1; ++iq1) {
  6723. const int id1 = iq1;
6724. // not sure about CACHE_LINE_SIZE_F32:
6725. // - maybe it should not be multiplied by 2 and excluded from the SM offset 1*(..)?
  6726. float * S = (float *) params->wdata + ith*2*(mxDM + CACHE_LINE_SIZE_F32) + 0*(mxDM+CACHE_LINE_SIZE_F32);
  6727. float * SM = (float *) params->wdata + ith*2*(mxDM + CACHE_LINE_SIZE_F32) + 1*(mxDM+CACHE_LINE_SIZE_F32);
  6728. for (int i = M; i < Mup; ++i) {
  6729. S[i] = -INFINITY;
  6730. }
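// with causal masking, query row iq1 may only attend to keys [0, P + iq1];
// everything at or past masked_begin is known to be -INF / zero and is skipped below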
  6731. const int64_t masked_begin = masked ? (P + iq1 + 1) : M;
  6732. for (int64_t ic = 0; ic < masked_begin; ++ic) {
  6733. // k indices
  6734. const int ik1 = ic;
  6735. // S indices
  6736. const int i1 = ik1;
  6737. ggml_vec_dot_f32(neq0,
  6738. S + i1, 0,
  6739. (float *) ((char *) k->data + (ik1*nbk1 + ik2*nbk2 + ik3*nbk3)), 0,
  6740. (float *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3)), 0, 1);
  6741. }
  6742. // scale
  6743. ggml_vec_scale_f32(masked_begin, S, scale);
  6744. for (int64_t i = masked_begin; i < M; i++) {
  6745. S[i] = -INFINITY;
  6746. }
  6747. // softmax
  6748. // exclude known -INF S[..] values from max and loop
6749. // don't forget to set their SM values to zero
  6750. {
  6751. float max = -INFINITY;
  6752. ggml_vec_max_f32(masked_begin, &max, S);
  6753. ggml_float sum = 0.0;
  6754. {
  6755. #ifdef GGML_SOFT_MAX_ACCELERATE
  6756. max = -max;
  6757. vDSP_vsadd(SM, 1, &max, SM, 1, Mup);
  6758. vvexpf(SM, SM, &Mup);
  6759. ggml_vec_sum_f32(Mup, &sum, SM);
  6760. #else
  6761. sum = ggml_vec_soft_max_f32(Mup, SM, S, max);
  6762. #endif
  6763. }
  6764. assert(sum > 0.0);
  6765. sum = 1.0/sum;
  6766. ggml_vec_scale_f32(masked_begin, SM, sum);
  6767. }
  6768. // step-by-step explanation
  6769. {
  6770. // forward-process shape grads from backward process
  6771. // parallel_for ik2,ik3:
  6772. // for irep:
  6773. // iq2 = ik2 + irep*nek2
  6774. // k[:D,:M,:,:] [D,M,:,:] grad[k][:D,:M,ik2,ik3] += grad[kcur]
  6775. // q[:D,:N,:,:] [D,N,:,:] grad[q][:D,iq1,iq2,iq3] += grad[qcur]
  6776. // v[:M,:D,:,:] [M,D,:,:] grad[v][:M,:D,iv2,iv3] += grad[vcur]
  6777. // for iq1:
  6778. // kcur = k[:D,:M,ik2,ik3] [D,M,1,1] grad[kcur] = grad[S1].T @ qcur
  6779. // qcur = q[:D,iq1,iq2,iq3] [D,1,1,1] grad[qcur] = grad[S1] @ kcur
  6780. // vcur = v[:M,:D,iv2,iv3] [M,D,1,1] grad[vcur] = grad[S5].T @ S4
  6781. // S0 = -Inf [D,1,1,1]
  6782. // ~S1[i] = dot(kcur[:D,i], qcur)
  6783. // S1 = qcur @ kcur.T [M,1,1,1] grad[S1] = grad[S2] * scale
  6784. // S2 = S1 * scale [M,1,1,1] grad[S2] = diag_mask_zero(grad[S3], P)
  6785. // S3 = diag_mask_inf(S2, P) [M,1,1,1] grad[S3] = S4 * (grad[S4] - dot(S4, grad[S4]))
  6786. // S4 = softmax(S3) [M,1,1,1] grad[S4] = grad[S5] @ vcur
  6787. // ~S5[i] = dot(vcur[:,i], S4)
  6788. // S5 = S4 @ vcur.T [D,1,1,1] grad[S5] = d[:D,id1,id2,id3]
  6789. // ~dst[i,iq1,iq2,iq3] = S5[i] ^
  6790. // dst[:D,iq1,iq2,iq3] = S5 | grad[dst[:D,iq1,iq2,iq3]] = d[:D,id1,id2,id3]
  6791. // dst backward-/ grad[dst] = d
  6792. //
  6793. // output gradients with their dependencies:
  6794. //
  6795. // grad[kcur] = grad[S1].T @ qcur
  6796. // grad[S1] = diag_mask_zero(grad[S3], P) * scale
  6797. // grad[S3] = S4 * (grad[S4] - dot(S4, grad[S4]))
  6798. // grad[S4] = grad[S5] @ vcur
  6799. // grad[S4] = d[:D,id1,id2,id3] @ vcur
  6800. // grad[qcur] = grad[S1] @ kcur
  6801. // grad[vcur] = grad[S5].T @ S4
  6802. // grad[vcur] = d[:D,id1,id2,id3].T @ S4
  6803. //
  6804. // in post-order:
  6805. //
  6806. // S1 = qcur @ kcur.T
  6807. // S2 = S1 * scale
  6808. // S3 = diag_mask_inf(S2, P)
  6809. // S4 = softmax(S3)
  6810. // grad[S4] = d[:D,id1,id2,id3] @ vcur
  6811. // grad[S3] = S4 * (grad[S4] - dot(S4, grad[S4]))
  6812. // grad[S1] = diag_mask_zero(grad[S3], P) * scale
  6813. // grad[qcur] = grad[S1] @ kcur
  6814. // grad[kcur] = grad[S1].T @ qcur
  6815. // grad[vcur] = d[:D,id1,id2,id3].T @ S4
  6816. //
  6817. // using less variables (SM=S4):
  6818. //
  6819. // S = diag_mask_inf(qcur @ kcur.T * scale, P)
  6820. // SM = softmax(S)
  6821. // S = d[:D,iq1,iq2,iq3] @ vcur
  6822. // dot_SM_gradSM = dot(SM, S)
  6823. // S = SM * (S - dot(SM, S))
  6824. // S = diag_mask_zero(S, P) * scale
  6825. //
  6826. // grad[q][:D,iq1,iq2,iq3] += S @ kcur
  6827. // grad[k][:D,:M,ik2,ik3] += S.T @ qcur
  6828. // grad[v][:M,:D,iv2,iv3] += d[:D,id1,id2,id3].T @ SM
  6829. }
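// Note: the grad[S3] formula above is the standard softmax backward identity:
// for S4 = softmax(S3), dL/dS3 = S4 * (dL/dS4 - dot(S4, dL/dS4)).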
  6830. // S = gradSM = d[:D,id1,id2,id3] @ vcur[:,:,iv2,iv3]
  6831. // S = d[:D,id1,id2,id3] @ vcur[:,:,iv2,iv3]
  6832. // for ic:
  6833. // S[:M] += vcur[:M,ic,iv2,iv3] * d[ic,id1,id2,id3]
  6834. // exclude known future zero S[..] values from operation
  6835. ggml_vec_set_f32(masked_begin, S, 0);
  6836. for (int64_t ic = 0; ic < D; ++ic) {
  6837. ggml_vec_mad_f32(masked_begin,
  6838. S,
  6839. (float *) ((char *) v->data + ( ic*nbv1 + iv2*nbv2 + iv3*nbv3)),
  6840. *(float *) ((char *) d->data + (ic*nbd0 + id1*nbd1 + id2*nbd2 + id3*nbd3)));
  6841. }
  6842. // S = SM * (S - dot(SM, S))
  6843. float dot_SM_gradSM = 0;
  6844. ggml_vec_dot_f32 (masked_begin, &dot_SM_gradSM, 0, SM, 0, S, 0, 1);
  6845. ggml_vec_acc1_f32(M, S, -dot_SM_gradSM);
  6846. ggml_vec_mul_f32 (masked_begin, S, S, SM);
  6847. // S = diag_mask_zero(S, P) * scale
  6848. // already done by above ggml_vec_set_f32
  6849. // exclude known zero S[..] values from operation
  6850. ggml_vec_scale_f32(masked_begin, S, scale);
  6851. // S shape [M,1]
  6852. // SM shape [M,1]
  6853. // kcur shape [D,M]
  6854. // qcur shape [D,1]
  6855. // vcur shape [M,D]
  6856. // grad[q][:D,iq1,iq2,iq3] += S @ kcur
  6857. // grad[q][:D,iq1,iq2,iq3] += shape[M,1] @ shape[D,M]
  6858. // for ic:
  6859. // grad[q][:D,iq1,iq2,iq3] += S[ic] * kcur[:D,ic,ik2,ik3]
  6860. // exclude known zero S[..] values from loop
  6861. for (int64_t ic = 0; ic < masked_begin; ++ic) {
  6862. ggml_vec_mad_f32(D,
  6863. (float *) ((char *) grad_q + (iq1*nbgq1 + iq2*nbgq2 + iq3*nbgq3)),
  6864. (float *) ((char *) k->data + (ic*nbk1 + ik2*nbk2 + ik3*nbk3)),
  6865. S[ic]);
  6866. }
  6867. // grad[k][:D,:M,iq2,iq3] += S.T @ qcur
  6868. // for ic:
  6869. // grad[k][:D,ic,iq2,iq3] += S.T[0,ic] * qcur[:D,0]
  6870. // grad[k][:D,ic,iq2,iq3] += S[ic] * qcur[:D,0]
  6871. // exclude known zero S[..] values from loop
  6872. for (int64_t ic = 0; ic < masked_begin; ++ic) {
  6873. ggml_vec_mad_f32(D,
  6874. (float *) ((char *) grad_k + (ic*nbgk1 + ik2*nbgk2 + ik3*nbgk3)),
  6875. (float *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3)),
  6876. S[ic]);
  6877. }
  6878. // grad[v][:M,:D,iv2,iv3] += d[:D,id1,id2,id3].T @ SM
  6879. // for ic:
  6880. // grad[v][:M,ic,iv2,iv3] += d[:D,id1,id2,id3].T[0,ic] * SM[:M]
  6881. // grad[v][:M,ic,iv2,iv3] += d[ic,id1,id2,id3] * SM[:M]
  6882. // exclude known zero SM[..] values from mad
  6883. for (int64_t ic = 0; ic < D; ++ic) {
  6884. ggml_vec_mad_f32(masked_begin,
  6885. (float *) ((char *) grad_v + ( ic*nbgv1 + iv2*nbgv2 + iv3*nbgv3)),
  6886. SM,
  6887. *(float *) ((char *) d->data + (ic*nbd0 + id1*nbd1 + id2*nbd2 + id3*nbd3)));
  6888. }
  6889. }
  6890. }
  6891. }
  6892. }
  6893. void ggml_compute_forward_flash_attn_back(
  6894. const ggml_compute_params * params,
  6895. const bool masked,
  6896. ggml_tensor * dst) {
  6897. const ggml_tensor * q = dst->src[0];
  6898. switch (q->type) {
  6899. case GGML_TYPE_F32:
  6900. {
  6901. ggml_compute_forward_flash_attn_back_f32(params, masked, dst);
  6902. } break;
  6903. default:
  6904. {
  6905. GGML_ABORT("fatal error");
  6906. }
  6907. }
  6908. }
  6909. // ggml_compute_forward_ssm_conv
  6910. static void ggml_compute_forward_ssm_conv_f32(
  6911. const ggml_compute_params * params,
  6912. ggml_tensor * dst) {
  6913. const ggml_tensor * src0 = dst->src[0]; // conv_x
  6914. const ggml_tensor * src1 = dst->src[1]; // conv1d.weight
  6915. const int ith = params->ith;
  6916. const int nth = params->nth;
  6917. const int nc = src1->ne[0]; // d_conv
  6918. const int ncs = src0->ne[0]; // d_conv - 1 + n_t
  6919. const int nr = src0->ne[1]; // d_inner
  6920. const int n_t = dst->ne[1]; // tokens per sequence
  6921. const int n_s = dst->ne[2]; // number of sequences in the batch
  6922. GGML_ASSERT( dst->ne[0] == nr);
  6923. GGML_ASSERT(src0->nb[0] == sizeof(float));
  6924. GGML_ASSERT(src1->nb[0] == sizeof(float));
  6925. GGML_ASSERT(src0->nb[1] == src0->ne[0]*sizeof(float));
  6926. // rows per thread
  6927. const int dr = (nr + nth - 1)/nth;
  6928. // row range for this thread
  6929. const int ir0 = dr*ith;
  6930. const int ir1 = MIN(ir0 + dr, nr);
  6931. const int ir = ir1 - ir0;
  6932. for (int i3 = 0; i3 < n_s; ++i3) {
  6933. for (int i2 = 0; i2 < n_t; ++i2) {
  6934. // {d_conv - 1 + n_t, d_inner, n_seqs}
  6935. // sliding window
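// for each output token i2 and channel i1, x[i1] below is the dot product of
// the d_conv weights with the window conv_x[i2 .. i2 + d_conv - 1] of that
// channel, i.e. a causal depthwise 1D convolution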
  6936. const float * s = (const float *) ((const char *) src0->data + ir0*(src0->nb[1]) + i2*(src0->nb[0]) + i3*(src0->nb[2])); // {d_conv, d_inner, n_s}
  6937. const float * c = (const float *) ((const char *) src1->data + ir0*(src1->nb[1])); // {d_conv, d_inner}
  6938. float * x = (float *) ((char *) dst->data + ir0*(dst->nb[0]) + i2*(dst->nb[1]) + i3*(dst->nb[2])); // {d_inner, n_t, n_s}
  6939. // TODO: transpose the output for smaller strides for big batches?
  6940. // d_inner
  6941. for (int i1 = 0; i1 < ir; ++i1) {
  6942. // rowwise dot product
  6943. // NOTE: not using ggml_vec_dot_f32, because its sum is in double precision
  6944. float sumf = 0.0f;
  6945. // d_conv
  6946. for (int i0 = 0; i0 < nc; ++i0) {
  6947. sumf += s[i0 + i1*ncs] * c[i0 + i1*nc];
  6948. }
  6949. x[i1] = sumf;
  6950. }
  6951. }
  6952. }
  6953. }
  6954. void ggml_compute_forward_ssm_conv(
  6955. const ggml_compute_params * params,
  6956. ggml_tensor * dst) {
  6957. switch (dst->src[0]->type) {
  6958. case GGML_TYPE_F32:
  6959. {
  6960. ggml_compute_forward_ssm_conv_f32(params, dst);
  6961. } break;
  6962. default:
  6963. {
  6964. GGML_ABORT("fatal error");
  6965. }
  6966. }
  6967. }
  6968. // ggml_compute_forward_ssm_scan
  6969. static void ggml_compute_forward_ssm_scan_f32(
  6970. const ggml_compute_params * params,
  6971. ggml_tensor * dst) {
  6972. const ggml_tensor * src0 = dst->src[0]; // s {d_state, dim, n_head, n_seqs+}
  6973. const ggml_tensor * src1 = dst->src[1]; // x {dim, n_head, n_seq_tokens, n_seqs}
  6974. const ggml_tensor * src2 = dst->src[2]; // dt {n_head, n_seq_tokens, n_seqs}
  6975. const ggml_tensor * src3 = dst->src[3]; // A {d_state, n_head} or {1, n_head}
  6976. const ggml_tensor * src4 = dst->src[4]; // B {d_state, n_group, n_seq_tokens, n_seqs}
  6977. const ggml_tensor * src5 = dst->src[5]; // C {d_state, n_group, n_seq_tokens, n_seqs}
  6978. const ggml_tensor * src6 = dst->src[6]; // ids {n_seqs}
  6979. const int ith = params->ith;
  6980. const int nth = params->nth;
  6981. const int64_t nc = src0->ne[0]; // d_state
  6982. const int64_t nr = src0->ne[1]; // dim
  6983. const int64_t nh = src1->ne[1]; // n_head
6984. const int64_t ng = src4->ne[1]; // n_group
  6985. const int64_t nt = src1->ne[2]; // number of tokens per sequence
  6986. const int64_t ns = src1->ne[3]; // number of sequences in the batch
  6987. // can't use ggml_nbytes because src1 is not necessarily contiguous
  6988. const int64_t s_off = ggml_nelements(src1) * ggml_element_size(src1);
  6989. GGML_ASSERT(ggml_nelements(src1) + nc*nr*nh*ns == ggml_nelements(dst));
  6990. GGML_ASSERT(src0->nb[0] == sizeof(float));
  6991. GGML_ASSERT(src1->nb[0] == sizeof(float));
  6992. GGML_ASSERT(src2->nb[0] == sizeof(float));
  6993. GGML_ASSERT(src3->nb[0] == sizeof(float));
  6994. GGML_ASSERT(src4->nb[0] == sizeof(float));
  6995. GGML_ASSERT(src5->nb[0] == sizeof(float));
  6996. GGML_ASSERT(src6->nb[0] == sizeof(int32_t));
  6997. // allows optimizing the modulo since n_group should be a power of 2
  6998. GGML_ASSERT((ng & -ng) == ng);
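// recurrence computed per (head h, dim i1, state i0), with dt' = softplus(dt):
//   s_t[i0] = exp(dt'*A) * s_{t-1}[i0] + (dt'*x_t) * B_t[i0]
//   y_t     = sum over i0 of s_t[i0] * C_t[i0]
// A is per-(state, head) for Mamba-1 and a single scalar per head for Mamba-2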
  6999. // heads per thread
  7000. const int dh = (nh + nth - 1)/nth;
  7001. // head range for this thread
  7002. const int ih0 = dh*ith;
  7003. const int ih1 = MIN(ih0 + dh, nh);
  7004. const int32_t * ids = (const int32_t *) src6->data;
  7005. for (int i3 = 0; i3 < ns; ++i3) {
  7006. const float * s0 = (const float *) ((const char *) src0->data + ids[i3]*(src0->nb[3])); // {d_state, dim, nh, ns}
  7007. float * s = ( float *) (( char *) dst->data + i3*(src0->nb[3]) + s_off); // {d_state, dim, nh, ns}
  7008. for (int i2 = 0; i2 < nt; ++i2) {
  7009. const float * x = (const float *) ((const char *) src1->data + i2*(src1->nb[2]) + i3*(src1->nb[3])); // {dim, nh, nt, ns}
  7010. const float * dt = (const float *) ((const char *) src2->data + i2*(src2->nb[1]) + i3*(src2->nb[2])); // {nh, nt, ns}
  7011. const float * A = (const float *) ((const char *) src3->data); // {d_state, nh} or {1, nh}
  7012. const float * B = (const float *) ((const char *) src4->data + i2*(src4->nb[2]) + i3*(src4->nb[3])); // {d_state, ng, nt, ns}
  7013. const float * C = (const float *) ((const char *) src5->data + i2*(src5->nb[2]) + i3*(src5->nb[3])); // {d_state, ng, nt, ns}
  7014. float * y = ( float *) (( char *) dst->data + i2*(nh*nr*sizeof(float)) + i3*(nt*nh*nr*sizeof(float))); // {dim, nh, nt, ns}
  7015. if (src3->ne[0] == 1) {
  7016. // Mamba-2 has a scalar decay factor per head; dA can be outside the state-wise loop
  7017. // n_head
  7018. for (int h = ih0; h < ih1; ++h) {
  7019. // ref: https://github.com/state-spaces/mamba/blob/62db608da60f6fc790b8ed9f4b3225e95ca15fde/mamba_ssm/ops/triton/softplus.py#L16
  7020. const float dt_soft_plus = dt[h] <= 20.0f ? log1pf(expf(dt[h])) : dt[h];
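// softplus(x) = log(1 + exp(x)); for x > 20 the result equals x within float
// precision, so the raw value is used to avoid overflowing expf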
  7021. const float dA = expf(dt_soft_plus * A[h]);
  7022. // dim
  7023. for (int i1 = 0; i1 < nr; ++i1) {
  7024. const int ii = i1 + h*nr;
  7025. const float x_dt = x[ii] * dt_soft_plus;
  7026. float sumf = 0.0f;
  7027. #if defined(GGML_SIMD)
  7028. #if defined(__ARM_FEATURE_SVE)
  7029. const int ggml_f32_epr = svcntw();
  7030. const int ggml_f32_step = 1 * ggml_f32_epr;
  7031. const int np = (nc & ~(ggml_f32_step - 1));
  7032. GGML_F32_VEC sum = GGML_F32_VEC_ZERO;
  7033. GGML_F32_VEC adA = GGML_F32_VEC_SET1(dA);
  7034. GGML_F32_VEC axdt = GGML_F32_VEC_SET1(x_dt);
  7035. for (int i = 0; i < np; i += ggml_f32_step) {
  7036. // TODO: maybe unroll more?
  7037. for (int j = 0; j < 1; j++) {
  7038. GGML_F32_VEC t0 = GGML_F32_VEC_LOAD(s0 + i + j*ggml_f32_epr + ii*nc);
  7039. GGML_F32_VEC t1 = GGML_F32_VEC_LOAD(B + i + j*ggml_f32_epr + (h & (ng - 1))*nc);
  7040. GGML_F32_VEC t2 = GGML_F32_VEC_LOAD(C + i + j*ggml_f32_epr + (h & (ng - 1))*nc);
  7041. t0 = GGML_F32_VEC_MUL(t0, adA);
  7042. t1 = GGML_F32_VEC_MUL(t1, axdt);
  7043. t0 = GGML_F32_VEC_ADD(t0, t1);
  7044. sum = GGML_F32_VEC_FMA(sum, t0, t2);
  7045. GGML_F32_VEC_STORE(s + i + j*ggml_f32_epr + ii*nc, t0);
  7046. }
  7047. }
  7048. sumf = GGML_F32xt_REDUCE_ONE(sum);
  7049. #else
  7050. const int np = (nc & ~(GGML_F32_STEP - 1));
  7051. GGML_F32_VEC sum[GGML_F32_ARR] = { GGML_F32_VEC_ZERO };
  7052. GGML_F32_VEC adA = GGML_F32_VEC_SET1(dA);
  7053. GGML_F32_VEC axdt = GGML_F32_VEC_SET1(x_dt);
  7054. GGML_F32_VEC ax[GGML_F32_ARR];
  7055. GGML_F32_VEC ay[GGML_F32_ARR];
  7056. GGML_F32_VEC az[GGML_F32_ARR];
  7057. for (int i = 0; i < np; i += GGML_F32_STEP) {
  7058. for (int j = 0; j < GGML_F32_ARR; j++) {
  7059. ax[j] = GGML_F32_VEC_LOAD(s0 + i + j*GGML_F32_EPR + ii*nc);
  7060. ay[j] = GGML_F32_VEC_LOAD(B + i + j*GGML_F32_EPR + (h & (ng - 1))*nc);
  7061. az[j] = GGML_F32_VEC_LOAD(C + i + j*GGML_F32_EPR + (h & (ng - 1))*nc);
  7062. ax[j] = GGML_F32_VEC_MUL(ax[j], adA);
  7063. ay[j] = GGML_F32_VEC_MUL(ay[j], axdt);
  7064. ax[j] = GGML_F32_VEC_ADD(ax[j], ay[j]);
  7065. sum[j] = GGML_F32_VEC_FMA(sum[j], ax[j], az[j]);
  7066. GGML_F32_VEC_STORE(s + i + j*GGML_F32_EPR + ii*nc, ax[j]);
  7067. }
  7068. }
  7069. // reduce sum0..sum3 to sum0
  7070. GGML_F32_VEC_REDUCE(sumf, sum);
  7071. #endif
  7072. #else
  7073. const int np = 0;
  7074. #endif
  7075. // d_state
  7076. for (int i0 = np; i0 < nc; ++i0) {
  7077. const int i = i0 + ii*nc;
  7078. const int ig = i0 + (h & (ng - 1))*nc;
  7079. // state = prev_state * dA + dB * x
  7080. const float state = (s0[i] * dA) + (B[ig] * x_dt);
  7081. // y = rowwise_dotprod(state, C)
  7082. sumf += state * C[ig];
  7083. s[i] = state;
  7084. }
  7085. y[ii] = sumf;
  7086. }
  7087. }
  7088. } else {
  7089. // Mamba-1 has an element-wise decay factor for the states
  7090. // n_head
  7091. for (int h = ih0; h < ih1; ++h) {
  7092. // ref: https://github.com/state-spaces/mamba/blob/62db608da60f6fc790b8ed9f4b3225e95ca15fde/mamba_ssm/ops/triton/softplus.py#L16
  7093. const float dt_soft_plus = dt[h] <= 20.0f ? log1pf(expf(dt[h])) : dt[h];
  7094. // dim
  7095. for (int i1 = 0; i1 < nr; ++i1) {
  7096. const int ii = i1 + h*nr;
  7097. const float x_dt = x[ii] * dt_soft_plus;
  7098. #if defined(__ARM_FEATURE_SVE)
  7099. svfloat32_t vx_dt = GGML_F32_VEC_SET1(x_dt);
  7100. svfloat32_t vdt_soft_plus = GGML_F32_VEC_SET1(dt_soft_plus);
  7101. svfloat32_t r1_vector = GGML_F32_VEC_ZERO;
  7102. // d_state
  7103. // TODO: what happens when (d_state % svcntw()) != 0?
  7104. for (int64_t k = 0; k < nc; k += svcntw()) {
  7105. svfloat32_t vA = GGML_F32_VEC_LOAD(&A[h*nc + k]);
  7106. svfloat32_t vB = GGML_F32_VEC_LOAD(&B[k + (h & (ng - 1))*nc]);
  7107. svfloat32_t vC = GGML_F32_VEC_LOAD(&C[k + (h & (ng - 1))*nc]);
  7108. svfloat32_t vs0 = GGML_F32_VEC_LOAD(&s0[ii*nc + k]);
  7109. svfloat32_t t1 = GGML_F32_VEC_MUL(vdt_soft_plus, vA);
  7110. t1 = exp_ps_sve(svptrue_b32(), t1);
  7111. svfloat32_t t2 = GGML_F32_VEC_MUL(vx_dt, vB);
  7112. vs0 = GGML_F32_VEC_FMA(t2, vs0, t1);
  7113. r1_vector = GGML_F32_VEC_ADD(GGML_F32_VEC_MUL(vs0, vC), r1_vector);
  7114. GGML_F32_VEC_STORE(&s[ii*nc + k], vs0);
  7115. }
  7116. y[ii] = GGML_F32xt_REDUCE_ONE(r1_vector);
  7117. #else
  7118. float sumf = 0.0f;
  7119. // NOTE: can't really use GGML_SIMD here because d_state is usually 16
  7120. // and also because expf is used within the loop.
  7121. // d_state
  7122. for (int i0 = 0; i0 < nc; ++i0) {
  7123. const int i = i0 + ii*nc;
  7124. const int ig = i0 + (h & (ng - 1))*nc;
  7125. // state = prev_state * dA + dB * x
  7126. const float state = (s0[i] * expf(dt_soft_plus * A[i0 + h*nc])) + (B[ig] * x_dt);
  7127. // y = rowwise_dotprod(state, C)
  7128. sumf += state * C[ig];
  7129. s[i] = state;
  7130. }
  7131. y[ii] = sumf;
  7132. #endif
  7133. }
  7134. }
  7135. }
  7136. // use the output as the source when it's not the first token-wise iteration
  7137. s0 = s;
  7138. }
  7139. }
  7140. }
  7141. void ggml_compute_forward_ssm_scan(
  7142. const ggml_compute_params * params,
  7143. ggml_tensor * dst) {
  7144. switch (dst->src[0]->type) {
  7145. case GGML_TYPE_F32:
  7146. {
  7147. ggml_compute_forward_ssm_scan_f32(params, dst);
  7148. } break;
  7149. default:
  7150. {
  7151. GGML_ABORT("fatal error");
  7152. }
  7153. }
  7154. }
  7155. // ggml_compute_forward_win_part
  7156. static void ggml_compute_forward_win_part_f32(
  7157. const ggml_compute_params * params,
  7158. ggml_tensor * dst) {
  7159. GGML_UNUSED(params);
  7160. const ggml_tensor * src0 = dst->src[0];
  7161. GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne)
  7162. GGML_TENSOR_LOCALS(int64_t, ne, dst, ne)
  7163. const int32_t nep0 = ((const int32_t *)(dst->op_params))[0];
  7164. const int32_t nep1 = ((const int32_t *)(dst->op_params))[1];
  7165. const int32_t w = ((const int32_t *)(dst->op_params))[2];
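// nep0/nep1: number of window partitions along x/y; w: window size.
// each dst slice i3 = py*nep0 + px below holds one w x w window of src0,
// zero-padded where the window extends past ne01/ne02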
  7166. assert(ne00 == ne0);
  7167. assert(ne3 == nep0*nep1);
  7168. // TODO: optimize / multi-thread
  7169. for (int py = 0; py < nep1; ++py) {
  7170. for (int px = 0; px < nep0; ++px) {
  7171. const int64_t i3 = py*nep0 + px;
  7172. for (int64_t i2 = 0; i2 < ne2; ++i2) {
  7173. for (int64_t i1 = 0; i1 < ne1; ++i1) {
  7174. for (int64_t i0 = 0; i0 < ne0; ++i0) {
  7175. const int64_t i02 = py*w + i2;
  7176. const int64_t i01 = px*w + i1;
  7177. const int64_t i00 = i0;
  7178. const int64_t i = i3*ne2*ne1*ne0 + i2*ne1*ne0 + i1*ne0 + i0;
  7179. const int64_t j = i02*ne01*ne00 + i01*ne00 + i00;
  7180. if (py*w + i2 >= ne02 || px*w + i1 >= ne01) {
  7181. ((float *) dst->data)[i] = 0.0f;
  7182. } else {
  7183. ((float *) dst->data)[i] = ((float *) src0->data)[j];
  7184. }
  7185. }
  7186. }
  7187. }
  7188. }
  7189. }
  7190. }
  7191. void ggml_compute_forward_win_part(
  7192. const ggml_compute_params * params,
  7193. ggml_tensor * dst) {
  7194. const ggml_tensor * src0 = dst->src[0];
  7195. switch (src0->type) {
  7196. case GGML_TYPE_F32:
  7197. {
  7198. ggml_compute_forward_win_part_f32(params, dst);
  7199. } break;
  7200. default:
  7201. {
  7202. GGML_ABORT("fatal error");
  7203. }
  7204. }
  7205. }
  7206. // ggml_compute_forward_win_unpart
  7207. static void ggml_compute_forward_win_unpart_f32(
  7208. const ggml_compute_params * params,
  7209. ggml_tensor * dst) {
  7210. GGML_UNUSED(params);
  7211. const ggml_tensor * src0 = dst->src[0];
  7212. GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne)
  7213. GGML_TENSOR_LOCALS(int64_t, ne, dst, ne)
  7214. const int32_t w = ((const int32_t *)(dst->op_params))[0];
  7215. // padding
  7216. const int px = (w - ne1%w)%w;
  7217. //const int py = (w - ne2%w)%w;
  7218. const int npx = (px + ne1)/w;
  7219. //const int npy = (py + ne2)/w;
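// px is the horizontal zero-padding added by win_part and npx the number of
// windows per row; the loops below invert the partitioning, reading output
// element (i2, i1, i0) back from window (ip2, ip1) at offset (i2%w, i1%w, i0)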
  7220. assert(ne0 == ne00);
  7221. // TODO: optimize / multi-thread
  7222. for (int64_t i2 = 0; i2 < ne2; ++i2) {
  7223. for (int64_t i1 = 0; i1 < ne1; ++i1) {
  7224. for (int64_t i0 = 0; i0 < ne0; ++i0) {
  7225. const int ip2 = i2/w;
  7226. const int ip1 = i1/w;
  7227. const int64_t i02 = i2%w;
  7228. const int64_t i01 = i1%w;
  7229. const int64_t i00 = i0;
  7230. const int64_t i = (ip2*npx + ip1)*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00 + i00;
  7231. const int64_t j = i2*ne1*ne0 + i1*ne0 + i0;
  7232. ((float *) dst->data)[j] = ((float *) src0->data)[i];
  7233. }
  7234. }
  7235. }
  7236. }
  7237. void ggml_compute_forward_win_unpart(
  7238. const ggml_compute_params * params,
  7239. ggml_tensor * dst) {
  7240. const ggml_tensor * src0 = dst->src[0];
  7241. switch (src0->type) {
  7242. case GGML_TYPE_F32:
  7243. {
  7244. ggml_compute_forward_win_unpart_f32(params, dst);
  7245. } break;
  7246. default:
  7247. {
  7248. GGML_ABORT("fatal error");
  7249. }
  7250. }
  7251. }
7252. // ggml_compute_forward_unary
  7253. void ggml_compute_forward_unary(
  7254. const ggml_compute_params * params,
  7255. ggml_tensor * dst) {
  7256. const ggml_unary_op op = ggml_get_unary_op(dst);
  7257. switch (op) {
  7258. case GGML_UNARY_OP_ABS:
  7259. {
  7260. ggml_compute_forward_abs(params, dst);
  7261. } break;
  7262. case GGML_UNARY_OP_SGN:
  7263. {
  7264. ggml_compute_forward_sgn(params, dst);
  7265. } break;
  7266. case GGML_UNARY_OP_NEG:
  7267. {
  7268. ggml_compute_forward_neg(params, dst);
  7269. } break;
  7270. case GGML_UNARY_OP_STEP:
  7271. {
  7272. ggml_compute_forward_step(params, dst);
  7273. } break;
  7274. case GGML_UNARY_OP_TANH:
  7275. {
  7276. ggml_compute_forward_tanh(params, dst);
  7277. } break;
  7278. case GGML_UNARY_OP_ELU:
  7279. {
  7280. ggml_compute_forward_elu(params, dst);
  7281. } break;
  7282. case GGML_UNARY_OP_RELU:
  7283. {
  7284. ggml_compute_forward_relu(params, dst);
  7285. } break;
  7286. case GGML_UNARY_OP_SIGMOID:
  7287. {
  7288. ggml_compute_forward_sigmoid(params, dst);
  7289. } break;
  7290. case GGML_UNARY_OP_GELU:
  7291. {
  7292. ggml_compute_forward_gelu(params, dst);
  7293. } break;
  7294. case GGML_UNARY_OP_GELU_ERF:
  7295. {
  7296. ggml_compute_forward_gelu_erf(params, dst);
  7297. } break;
  7298. case GGML_UNARY_OP_GELU_QUICK:
  7299. {
  7300. ggml_compute_forward_gelu_quick(params, dst);
  7301. } break;
  7302. case GGML_UNARY_OP_SILU:
  7303. {
  7304. ggml_compute_forward_silu(params, dst);
  7305. } break;
  7306. case GGML_UNARY_OP_HARDSWISH:
  7307. {
  7308. ggml_compute_forward_hardswish(params, dst);
  7309. } break;
  7310. case GGML_UNARY_OP_HARDSIGMOID:
  7311. {
  7312. ggml_compute_forward_hardsigmoid(params, dst);
  7313. } break;
  7314. case GGML_UNARY_OP_EXP:
  7315. {
  7316. ggml_compute_forward_exp(params, dst);
  7317. } break;
  7318. default:
  7319. {
  7320. GGML_ABORT("fatal error");
  7321. }
  7322. }
  7323. }
7324. // ggml_compute_forward_glu
  7325. void ggml_compute_forward_glu(
  7326. const ggml_compute_params * params,
  7327. ggml_tensor * dst) {
  7328. const ggml_glu_op op = ggml_get_glu_op(dst);
  7329. switch (op) {
  7330. case GGML_GLU_OP_REGLU:
  7331. {
  7332. ggml_compute_forward_reglu(params, dst);
  7333. } break;
  7334. case GGML_GLU_OP_GEGLU:
  7335. {
  7336. ggml_compute_forward_geglu(params, dst);
  7337. } break;
  7338. case GGML_GLU_OP_SWIGLU:
  7339. {
  7340. ggml_compute_forward_swiglu(params, dst);
  7341. } break;
  7342. case GGML_GLU_OP_GEGLU_ERF:
  7343. {
  7344. ggml_compute_forward_geglu_erf(params, dst);
  7345. } break;
  7346. case GGML_GLU_OP_GEGLU_QUICK:
  7347. {
  7348. ggml_compute_forward_geglu_quick(params, dst);
  7349. } break;
  7350. default:
  7351. {
  7352. GGML_ABORT("fatal error");
  7353. }
  7354. }
  7355. }
  7356. // ggml_compute_forward_get_rel_pos
  7357. static void ggml_compute_forward_get_rel_pos_f16(
  7358. const ggml_compute_params * params,
  7359. ggml_tensor * dst) {
  7360. GGML_UNUSED(params);
  7361. const ggml_tensor * src0 = dst->src[0];
  7362. // ref: https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/modeling/image_encoder.py#L292-L322
  7363. GGML_TENSOR_UNARY_OP_LOCALS
  7364. const int64_t w = ne1;
  7365. ggml_fp16_t * src0_data = (ggml_fp16_t *) src0->data;
  7366. ggml_fp16_t * dst_data = (ggml_fp16_t *) dst->data;
  7367. for (int64_t i2 = 0; i2 < ne2; ++i2) {
  7368. for (int64_t i1 = 0; i1 < ne1; ++i1) {
  7369. const int64_t pos = (w - i1 - 1) + i2;
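// pos == (i2 - i1) + (w - 1): the relative position between indices i2 and i1,
// shifted so that the range [-(w-1), w-1] maps onto table rows [0, 2*w - 2]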
  7370. for (int64_t i0 = 0; i0 < ne0; ++i0) {
  7371. dst_data[i2*ne1*ne0 + i1*ne0 + i0] = src0_data[pos*ne00 + i0];
  7372. }
  7373. }
  7374. }
  7375. }
  7376. void ggml_compute_forward_get_rel_pos(
  7377. const ggml_compute_params * params,
  7378. ggml_tensor * dst) {
  7379. const ggml_tensor * src0 = dst->src[0];
  7380. switch (src0->type) {
  7381. case GGML_TYPE_F16:
  7382. case GGML_TYPE_BF16:
  7383. {
  7384. ggml_compute_forward_get_rel_pos_f16(params, dst);
  7385. } break;
  7386. default:
  7387. {
  7388. GGML_ABORT("fatal error");
  7389. }
  7390. }
  7391. }
  7392. // ggml_compute_forward_add_rel_pos
  7393. static void ggml_compute_forward_add_rel_pos_f32(
  7394. const ggml_compute_params * params,
  7395. ggml_tensor * dst) {
  7396. const ggml_tensor * src0 = dst->src[0];
  7397. const ggml_tensor * src1 = dst->src[1];
  7398. const ggml_tensor * src2 = dst->src[2];
  7399. const bool inplace = (bool) ((int32_t *) dst->op_params)[0];
  7400. if (!inplace) {
  7401. if (params->ith == 0) {
  7402. memcpy((char *) dst->data, (char *) src0->data, ggml_nbytes(dst));
  7403. }
  7404. ggml_barrier(params->threadpool);
  7405. }
  7406. // ref: https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/modeling/image_encoder.py#L357-L359
  7407. float * src1_data = (float *) src1->data;
  7408. float * src2_data = (float *) src2->data;
  7409. float * dst_data = (float *) dst->data;
  7410. const int64_t ne10 = src1->ne[0];
  7411. const int64_t ne11 = src1->ne[1];
  7412. const int64_t ne12 = src1->ne[2];
  7413. const int64_t ne13 = src1->ne[3];
  7414. const int ith = params->ith;
  7415. const int nth = params->nth;
  7416. // total patches in dst
  7417. const int np = ne13;
  7418. // patches per thread
  7419. const int dp = (np + nth - 1)/nth;
  7420. // patch range for this thread
  7421. const int ip0 = dp*ith;
  7422. const int ip1 = MIN(ip0 + dp, np);
  7423. for (int64_t i13 = ip0; i13 < ip1; ++i13) {
  7424. for (int64_t i12 = 0; i12 < ne12; ++i12) {
  7425. for (int64_t i11 = 0; i11 < ne11; ++i11) {
  7426. const int64_t jp1 = i13*ne12*ne11*ne10 + i12*ne11*ne10 + i11*ne10;
  7427. for (int64_t i10 = 0; i10 < ne10; ++i10) {
  7428. const int64_t jp0 = jp1 + i10;
  7429. const float src1_e = src1_data[jp0];
  7430. const float src2_e = src2_data[jp0];
  7431. const int64_t jdh = jp0 * ne10;
  7432. const int64_t jdw = jdh - (ne10 - 1) * i10;
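// hedged note: following the reference above, the attention map is treated as
// (..., ne10, ne10); src2 is broadcast along the contiguous axis via
// dst_data[jdh + j], while src1 is broadcast along the stride-ne10 axis via
// dst_data[jdw + j*ne10], mirroring the rel_h/rel_w broadcast in the reference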
  7433. for (int64_t j = 0; j < ne10; ++j) {
  7434. dst_data[jdh + j ] += src2_e;
  7435. dst_data[jdw + j*ne10] += src1_e;
  7436. }
  7437. }
  7438. }
  7439. }
  7440. }
  7441. }
  7442. void ggml_compute_forward_add_rel_pos(
  7443. const ggml_compute_params * params,
  7444. ggml_tensor * dst) {
  7445. const ggml_tensor * src0 = dst->src[0];
  7446. switch (src0->type) {
  7447. case GGML_TYPE_F32:
  7448. {
  7449. ggml_compute_forward_add_rel_pos_f32(params, dst);
  7450. } break;
  7451. default:
  7452. {
  7453. GGML_ABORT("fatal error");
  7454. }
  7455. }
  7456. }
  7457. // ggml_compute_forward_rwkv_wkv6
  7458. static void ggml_compute_forward_rwkv_wkv6_f32(
  7459. const ggml_compute_params * params,
  7460. ggml_tensor * dst) {
  7461. const int64_t T = dst->src[1]->ne[2];
  7462. const int64_t C = dst->ne[0];
  7463. const int64_t HEADS = dst->src[1]->ne[1];
  7464. const int64_t n_seqs = dst->src[5]->ne[1];
  7465. const int64_t head_size = C / HEADS;
  7466. float * dst_data = (float *) dst->data;
  7467. float * state = ((float *) dst->data) + C * T;
  7468. const int ith = params->ith;
  7469. const int nth = params->nth;
  7470. if (ith >= HEADS) {
  7471. return;
  7472. }
  7473. const int h_start = (HEADS * ith) / nth;
  7474. const int h_end = ((HEADS * (ith + 1)) / nth < HEADS) ?
  7475. (HEADS * (ith + 1)) / nth : HEADS;
  7476. float * k = (float *) dst->src[0]->data;
  7477. float * v = (float *) dst->src[1]->data;
  7478. float * r = (float *) dst->src[2]->data;
  7479. float * time_faaaa = (float *) dst->src[3]->data;
  7480. float * time_decay = (float *) dst->src[4]->data;
7481. size_t t_stride = HEADS * head_size; // Same as C
  7482. size_t h_stride = C / HEADS;
  7483. GGML_ASSERT(C % HEADS == 0); // C must be divisible by HEADS
  7484. size_t h_stride_2d = head_size * head_size;
  7485. if (ith == 0) {
  7486. memset(dst_data, 0, T * C * sizeof(float));
  7487. }
  7488. ggml_barrier(params->threadpool);
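// per-element recurrence implemented by both the SIMD and scalar paths below
// (indices within one head; i runs over one head dim, j over the other):
//   kv          = k[t,i] * v[t,j]
//   dst[t,j]   += r[t,i] * (time_faaaa[i] * kv + state[i,j])   (accumulated over i)
//   state[i,j]  = time_decay[t,i] * state[i,j] + kv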
  7489. #if defined(__AVX__) && !defined(__AVX512F__)
  7490. #define GGML_F32X GGML_F32x8
  7491. #define GGML_F32X_SET1 GGML_F32x8_SET1
  7492. #define GGML_F32X_LOAD GGML_F32x8_LOAD
  7493. #define GGML_F32X_STORE GGML_F32x8_STORE
  7494. #define GGML_F32X_MUL GGML_F32x8_MUL
  7495. #define GGML_F32X_FMA GGML_F32x8_FMA
  7496. #define WKV_VECTOR_SIZE 8
  7497. #elif defined(__AVX512F__)
  7498. #define GGML_F32X GGML_F32x16
  7499. #define GGML_F32X_SET1 GGML_F32x16_SET1
  7500. #define GGML_F32X_LOAD GGML_F32x16_LOAD
  7501. #define GGML_F32X_STORE GGML_F32x16_STORE
  7502. #define GGML_F32X_MUL GGML_F32x16_MUL
  7503. #define GGML_F32X_FMA GGML_F32x16_FMA
  7504. #define WKV_VECTOR_SIZE 16
  7505. #elif defined(__ARM_FEATURE_SVE) && defined(__aarch64__)
  7506. #define GGML_F32X GGML_F32xt
  7507. #define GGML_F32X_SET1 GGML_F32xt_SET1
  7508. #define GGML_F32X_LOAD GGML_F32xt_LOAD
  7509. #define GGML_F32X_STORE GGML_F32xt_STORE
  7510. #define GGML_F32X_MUL GGML_F32xt_MUL
  7511. #define GGML_F32X_FMA GGML_F32xt_FMA
  7512. #define WKV_VECTOR_SIZE 8
  7513. #elif defined(__ARM_NEON) && defined(__aarch64__)
  7514. #define GGML_F32X GGML_F32x4
  7515. #define GGML_F32X_SET1 GGML_F32x4_SET1
  7516. #define GGML_F32X_LOAD GGML_F32x4_LOAD
  7517. #define GGML_F32X_STORE GGML_F32x4_STORE
  7518. #define GGML_F32X_MUL GGML_F32x4_MUL
  7519. #define GGML_F32X_FMA GGML_F32x4_FMA
  7520. #define WKV_VECTOR_SIZE 4
  7521. #endif
  7522. #ifdef WKV_VECTOR_SIZE
  7523. int wkv_vector_size;
  7524. #if defined(__ARM_FEATURE_SVE)
  7525. wkv_vector_size = svcntw();
  7526. #else
  7527. wkv_vector_size = WKV_VECTOR_SIZE;
  7528. #endif
  7529. const int64_t vec_count = head_size / wkv_vector_size;
  7530. for (int64_t t = 0; t < T; t++) {
  7531. size_t t_offset = t * t_stride;
  7532. size_t state_offset = head_size * C * (t / (T / n_seqs));
  7533. float * state_cur = state + state_offset;
  7534. float * state_prev = t % (T / n_seqs) ? state_cur : (float*)dst->src[5]->data + state_offset;
  7535. for (int64_t h = h_start; h < h_end; h++) {
  7536. size_t h_offset = h * h_stride;
  7537. size_t t_h_offset = t_offset + h_offset;
  7538. size_t h_2d_offset = h * h_stride_2d;
  7539. for (int64_t i = 0; i < head_size; i++) {
  7540. size_t t_h_i_offset = t_h_offset + i;
  7541. size_t h_i_offset = h_offset + i;
  7542. size_t h_2d_i_offset = h_2d_offset + i * h_stride;
  7543. float k_val = k[t_h_i_offset];
  7544. float r_val = r[t_h_i_offset];
  7545. float time_faaaa_val = time_faaaa[h_i_offset];
  7546. float time_decay_val = time_decay[t_h_i_offset];
  7547. // Broadcast scalar values to vectors
  7548. GGML_F32X k_vec = GGML_F32X_SET1(k_val);
  7549. GGML_F32X r_vec = GGML_F32X_SET1(r_val);
  7550. GGML_F32X time_faaaa_vec = GGML_F32X_SET1(time_faaaa_val);
  7551. GGML_F32X time_decay_vec = GGML_F32X_SET1(time_decay_val);
  7552. for (int64_t j = 0; j < vec_count; j++) {
  7553. size_t base_j = j * wkv_vector_size;
  7554. size_t t_h_j_offset = t_h_offset + base_j;
  7555. size_t h_2d_i_j_offset = h_2d_i_offset + base_j;
7556. // Load wkv_vector_size elements at once
  7557. GGML_F32X v_vec = GGML_F32X_LOAD(&v[t_h_j_offset]);
  7558. GGML_F32X prev_state_vec = GGML_F32X_LOAD(&state_prev[h_2d_i_j_offset]);
  7559. GGML_F32X dst_vec = GGML_F32X_LOAD(&dst_data[t_h_j_offset]);
  7560. // Compute kv = v * k
  7561. GGML_F32X kv_vec = GGML_F32X_MUL(v_vec, k_vec);
  7562. // Compute temp = kv * time_faaaa + prev_state
  7563. GGML_F32X temp_vec = GGML_F32X_FMA(prev_state_vec, kv_vec, time_faaaa_vec);
  7564. // Update dst: dst += temp * r
  7565. dst_vec = GGML_F32X_FMA(dst_vec, temp_vec, r_vec);
  7566. GGML_F32X_STORE(&dst_data[t_h_j_offset], dst_vec);
  7567. // Update state: state = prev_state * time_decay + kv
  7568. GGML_F32X new_state_vec = GGML_F32X_FMA(kv_vec, prev_state_vec, time_decay_vec);
  7569. GGML_F32X_STORE(&state_cur[h_2d_i_j_offset], new_state_vec);
  7570. }
7571. // Handle remaining elements (unused when head_size is a multiple of the vector width)
  7572. for (int64_t j = vec_count * wkv_vector_size; j < head_size; j++) {
  7573. size_t t_h_j_offset = t_h_offset + j;
  7574. size_t h_2d_i_j_offset = h_2d_i_offset + j;
  7575. float v_val = v[t_h_j_offset];
  7576. float kv_val = v_val * k_val;
  7577. float prev_state_val = state_prev[h_2d_i_j_offset];
  7578. float temp_val = kv_val * time_faaaa_val + prev_state_val;
  7579. dst_data[t_h_j_offset] += temp_val * r_val;
  7580. state_cur[h_2d_i_j_offset] = prev_state_val * time_decay_val + kv_val;
  7581. }
  7582. }
  7583. }
  7584. }
  7585. #else
  7586. // basically fused operations:
  7587. // dst = r @ (time_faaaa * (k @ v) + state),
  7588. // state = time_decay * state + (k @ v),
7589. // applied recursively over the tokens
  7590. for (int64_t t = 0; t < T; t++) {
  7591. size_t t_offset = t * t_stride;
  7592. size_t state_offset = head_size * C * (t / (T / n_seqs));
  7593. float * state_cur = state + state_offset;
  7594. float * state_prev = t % (T / n_seqs) ? state_cur : (float*)dst->src[5]->data + state_offset;
  7595. for (int64_t h = h_start; h < h_end; h++) {
  7596. size_t h_offset = h * h_stride;
  7597. size_t t_h_offset = t_offset + h_offset;
  7598. size_t h_2d_offset = h * h_stride_2d;
  7599. for (int64_t i = 0; i < head_size; i++) {
  7600. size_t t_h_i_offset = t_h_offset + i;
  7601. size_t h_i_offset = h_offset + i;
  7602. size_t h_2d_i_offset = h_2d_offset + i * h_stride;
  7603. float k_val = k[t_h_i_offset];
  7604. float r_val = r[t_h_i_offset];
  7605. float time_faaaa_val = time_faaaa[h_i_offset];
  7606. // RWKV v6: different time_decay for each token.
  7607. float time_decay_val = time_decay[t_h_i_offset];
  7608. for (int64_t j = 0; j < head_size; j++) {
  7609. size_t t_h_j_offset = t_h_offset + j;
  7610. size_t h_2d_i_j_offset = h_2d_i_offset + j;
  7611. float v_val = v[t_h_j_offset];
  7612. float kv_val = v_val * k_val;
  7613. float prev_state_val = state_prev[h_2d_i_j_offset];
  7614. float temp_val = kv_val * time_faaaa_val + prev_state_val;
  7615. dst_data[t_h_j_offset] += temp_val * r_val;
  7616. state_cur[h_2d_i_j_offset] = prev_state_val * time_decay_val + kv_val;
  7617. }
  7618. }
  7619. }
  7620. }
  7621. #endif
  7622. }
  7623. void ggml_compute_forward_rwkv_wkv6(
  7624. const ggml_compute_params * params,
  7625. ggml_tensor * dst) {
  7626. const ggml_tensor * src0 = dst->src[0];
  7627. switch (src0->type) {
  7628. case GGML_TYPE_F32:
  7629. {
  7630. ggml_compute_forward_rwkv_wkv6_f32(params, dst);
  7631. } break;
  7632. default:
  7633. {
  7634. GGML_ABORT("fatal error");
  7635. }
  7636. }
  7637. }
  7638. // ggml_compute_forward_gla
  7639. static void ggml_compute_forward_gla_f32(
  7640. const ggml_compute_params * params,
  7641. ggml_tensor * dst) {
  7642. const int64_t T = dst->src[1]->ne[2];
  7643. const int64_t C = dst->ne[0];
  7644. const int64_t HEADS = dst->src[1]->ne[1];
  7645. const int64_t n_seqs = dst->src[4]->ne[1];
  7646. const int64_t head_size = C / HEADS;
  7647. const float scale = ggml_get_op_params_f32(dst, 0);
  7648. float * dst_data = (float *) dst->data;
  7649. float * state = ((float *) dst->data) + C * T;
  7650. const int ith = params->ith;
  7651. const int nth = params->nth;
  7652. if (ith >= HEADS) {
  7653. return;
  7654. }
  7655. const int h_start = (HEADS * ith) / nth;
  7656. const int h_end = ((HEADS * (ith + 1)) / nth < HEADS) ?
  7657. (HEADS * (ith + 1)) / nth : HEADS;
  7658. float * k = (float *) dst->src[0]->data;
  7659. float * v = (float *) dst->src[1]->data;
  7660. float * q = (float *) dst->src[2]->data;
  7661. float * g = (float *) dst->src[3]->data;
7662. size_t t_stride = HEADS * head_size; // Same as C
  7663. size_t h_stride = C / HEADS;
  7664. GGML_ASSERT(C % HEADS == 0); // C must be divisible by HEADS
  7665. size_t h_stride_2d = head_size * head_size;
  7666. if (ith == 0) {
  7667. memset(dst_data, 0, T * C * sizeof(float));
  7668. }
  7669. ggml_barrier(params->threadpool);
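// per-element recurrence (gated linear attention) for both paths below:
//   state[i,j] = g[t,i] * state[i,j] + k[t,i] * v[t,j]
//   dst[t,j]  += (q[t,i] * scale) * state[i,j]          (accumulated over i)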
  7670. #if defined(__AVX__) && !defined(__AVX512F__)
  7671. #define GGML_F32X GGML_F32x8
  7672. #define GGML_F32X_SET1 GGML_F32x8_SET1
  7673. #define GGML_F32X_LOAD GGML_F32x8_LOAD
  7674. #define GGML_F32X_STORE GGML_F32x8_STORE
  7675. #define GGML_F32X_MUL GGML_F32x8_MUL
  7676. #define GGML_F32X_FMA GGML_F32x8_FMA
  7677. #define GLA_VECTOR_SIZE 8
  7678. #elif defined(__AVX512F__)
  7679. #define GGML_F32X GGML_F32x16
  7680. #define GGML_F32X_SET1 GGML_F32x16_SET1
  7681. #define GGML_F32X_LOAD GGML_F32x16_LOAD
  7682. #define GGML_F32X_STORE GGML_F32x16_STORE
  7683. #define GGML_F32X_MUL GGML_F32x16_MUL
  7684. #define GGML_F32X_FMA GGML_F32x16_FMA
  7685. #define GLA_VECTOR_SIZE 16
  7686. #elif defined(__ARM_FEATURE_SVE) && defined(__aarch64__)
  7687. #define GGML_F32X GGML_F32xt
  7688. #define GGML_F32X_SET1 GGML_F32xt_SET1
  7689. #define GGML_F32X_LOAD GGML_F32xt_LOAD
  7690. #define GGML_F32X_STORE GGML_F32xt_STORE
  7691. #define GGML_F32X_MUL GGML_F32xt_MUL
  7692. #define GGML_F32X_FMA GGML_F32xt_FMA
  7693. #define GLA_VECTOR_SIZE 8
  7694. #elif defined(__ARM_NEON) && defined(__aarch64__)
  7695. #define GGML_F32X GGML_F32x4
  7696. #define GGML_F32X_SET1 GGML_F32x4_SET1
  7697. #define GGML_F32X_LOAD GGML_F32x4_LOAD
  7698. #define GGML_F32X_STORE GGML_F32x4_STORE
  7699. #define GGML_F32X_MUL GGML_F32x4_MUL
  7700. #define GGML_F32X_FMA GGML_F32x4_FMA
  7701. #define GLA_VECTOR_SIZE 4
  7702. #endif
  7703. #ifdef GLA_VECTOR_SIZE
  7704. int gla_vector_size;
  7705. #if defined(__ARM_FEATURE_SVE)
  7706. gla_vector_size = svcntw();
  7707. #else
  7708. gla_vector_size = GLA_VECTOR_SIZE;
  7709. #endif
  7710. const int64_t vec_count = head_size / gla_vector_size;
  7711. for (int64_t t = 0; t < T; t++) {
  7712. size_t t_offset = t * t_stride;
  7713. size_t state_offset = head_size * C * (t / (T / n_seqs));
  7714. float * state_cur = state + state_offset;
  7715. float * state_prev = t % (T / n_seqs) ? state_cur : (float*)dst->src[4]->data + state_offset;
  7716. for (int64_t h = h_start; h < h_end; h++) {
  7717. size_t h_offset = h * h_stride;
  7718. size_t t_h_offset = t_offset + h_offset;
  7719. size_t h_2d_offset = h * h_stride_2d;
  7720. for (int64_t i = 0; i < head_size; i++) {
  7721. size_t t_h_i_offset = t_h_offset + i;
  7722. size_t h_2d_i_offset = h_2d_offset + i * h_stride;
  7723. float k_val = k[t_h_i_offset];
  7724. float q_val = q[t_h_i_offset] * scale;
  7725. float g_val = g[t_h_i_offset];
  7726. // Broadcast scalar values to vectors
  7727. GGML_F32X k_vec = GGML_F32X_SET1(k_val);
  7728. GGML_F32X q_vec = GGML_F32X_SET1(q_val);
  7729. GGML_F32X g_vec = GGML_F32X_SET1(g_val);
  7730. for (int64_t j = 0; j < vec_count; j++) {
  7731. size_t base_j = j * gla_vector_size;
  7732. size_t t_h_j_offset = t_h_offset + base_j;
  7733. size_t h_2d_i_j_offset = h_2d_i_offset + base_j;
7734. // Load gla_vector_size elements at once
  7735. GGML_F32X v_vec = GGML_F32X_LOAD(&v[t_h_j_offset]);
  7736. GGML_F32X prev_state_vec = GGML_F32X_LOAD(&state_prev[h_2d_i_j_offset]);
  7737. GGML_F32X dst_vec = GGML_F32X_LOAD(&dst_data[t_h_j_offset]);
  7738. // Compute kv = v * k
  7739. GGML_F32X kv_vec = GGML_F32X_MUL(v_vec, k_vec);
  7740. // Compute temp = prev_state * g + kv
  7741. GGML_F32X temp_vec = GGML_F32X_FMA(kv_vec, prev_state_vec, g_vec);
  7742. // Update dst: dst += temp * q
  7743. dst_vec = GGML_F32X_FMA(dst_vec, temp_vec, q_vec);
  7744. GGML_F32X_STORE(&dst_data[t_h_j_offset], dst_vec);
  7745. // Update state
  7746. GGML_F32X_STORE(&state_cur[h_2d_i_j_offset], temp_vec);
  7747. }
7748. // Handle remaining elements (unused when head_size is a multiple of the vector width)
  7749. for (int64_t j = vec_count * gla_vector_size; j < head_size; j++) {
  7750. size_t t_h_j_offset = t_h_offset + j;
  7751. size_t h_2d_i_j_offset = h_2d_i_offset + j;
  7752. float v_val = v[t_h_j_offset];
  7753. float kv_val = v_val * k_val;
  7754. float prev_state_val = state_prev[h_2d_i_j_offset];
  7755. float temp_val = kv_val + prev_state_val * g_val;
  7756. dst_data[t_h_j_offset] += temp_val * q_val;
  7757. state_cur[h_2d_i_j_offset] = temp_val;
  7758. }
  7759. }
  7760. }
  7761. }
  7762. #else
  7763. for (int64_t t = 0; t < T; t++) {
  7764. size_t t_offset = t * t_stride;
  7765. size_t state_offset = head_size * C * (t / (T / n_seqs));
  7766. float * state_cur = state + state_offset;
  7767. float * state_prev = t % (T / n_seqs) ? state_cur : (float*)dst->src[4]->data + state_offset;
  7768. for (int64_t h = h_start; h < h_end; h++) {
  7769. size_t h_offset = h * h_stride;
  7770. size_t t_h_offset = t_offset + h_offset;
  7771. size_t h_2d_offset = h * h_stride_2d;
  7772. for (int64_t i = 0; i < head_size; i++) {
  7773. size_t t_h_i_offset = t_h_offset + i;
  7774. size_t h_2d_i_offset = h_2d_offset + i * h_stride;
  7775. float k_val = k[t_h_i_offset];
  7776. float q_val = q[t_h_i_offset] * scale;
  7777. float g_val = g[t_h_i_offset];
  7778. for (int64_t j = 0; j < head_size; j++) {
  7779. size_t t_h_j_offset = t_h_offset + j;
  7780. size_t h_2d_i_j_offset = h_2d_i_offset + j;
  7781. float v_val = v[t_h_j_offset];
  7782. float kv_val = v_val * k_val;
  7783. float prev_state_val = state_prev[h_2d_i_j_offset];
  7784. float temp_val = prev_state_val * g_val + kv_val;
  7785. dst_data[t_h_j_offset] += temp_val * q_val;
  7786. state_cur[h_2d_i_j_offset] = temp_val;
  7787. }
  7788. }
  7789. }
  7790. }
  7791. #endif
  7792. }
  7793. void ggml_compute_forward_gla(
  7794. const ggml_compute_params * params,
  7795. ggml_tensor * dst) {
  7796. const ggml_tensor * src0 = dst->src[0];
  7797. switch (src0->type) {
  7798. case GGML_TYPE_F32:
  7799. {
  7800. ggml_compute_forward_gla_f32(params, dst);
  7801. } break;
  7802. default:
  7803. {
  7804. GGML_ABORT("fatal error");
  7805. }
  7806. }
  7807. }
  7808. // ggml_compute_forward_rwkv_wkv7
  7809. static void ggml_compute_forward_rwkv_wkv7_f32(
  7810. const ggml_compute_params * params,
  7811. ggml_tensor * dst) {
  7812. const int64_t T = dst->src[1]->ne[2];
  7813. const int64_t C = dst->ne[0];
  7814. const int64_t HEADS = dst->src[1]->ne[1];
  7815. const int64_t n_seqs = dst->src[6]->ne[1];
  7816. const int64_t head_size = C / HEADS;
  7817. float * dst_data = (float *) dst->data;
  7818. float * state = ((float *) dst->data) + C * T;
  7819. const int ith = params->ith;
  7820. const int nth = params->nth;
  7821. if (ith >= HEADS) {
  7822. return;
  7823. }
  7824. const int h_start = (HEADS * ith) / nth;
  7825. const int h_end = ((HEADS * (ith + 1)) / nth < HEADS) ?
  7826. (HEADS * (ith + 1)) / nth : HEADS;
  7827. float * r = (float *) dst->src[0]->data;
  7828. float * w = (float *) dst->src[1]->data;
  7829. float * k = (float *) dst->src[2]->data;
  7830. float * v = (float *) dst->src[3]->data;
  7831. float * a = (float *) dst->src[4]->data;
  7832. float * b = (float *) dst->src[5]->data;
7833. int64_t t_stride = HEADS * head_size; // Same as C
  7834. int64_t h_stride = C / HEADS;
  7835. GGML_ASSERT(C % HEADS == 0); // C must be divisible by HEADS
  7836. int64_t h_stride_2d = head_size * head_size;
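// per-row recurrence (RWKV v7) implemented by both the SIMD and scalar paths below:
//   sa         = sum over j of a[t,j] * state[i,j]
//   state[i,j] = state[i,j] * w[t,j] + k[t,j] * v[t,i] + sa * b[t,j]
//   dst[t,i]   = sum over j of state[i,j] * r[t,j]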
  7837. #if defined(GGML_SIMD)
  7838. #if defined(__ARM_FEATURE_SVE)
7839. // route to the scalar implementation // TODO: write SVE code
        for (int64_t t = 0; t < T; t++) {
            int64_t t_offset = t * t_stride;
            int64_t state_offset = head_size * C * (t / (T / n_seqs));
            float * state_cur = state + state_offset;
            float * state_prev = t % (T / n_seqs) ? state_cur : (float*)dst->src[6]->data + state_offset;

            for (int64_t h = h_start; h < h_end; h++) {
                int64_t h_offset = h * h_stride;
                int64_t t_h_offset = t_offset + h_offset;
                int64_t h_2d_offset = h * h_stride_2d;

                for (int64_t i = 0; i < head_size; i++) {
                    int64_t t_h_i_offset = t_h_offset + i;
                    int64_t h_2d_i_offset = h_2d_offset + i * h_stride;

                    float v_val = v[t_h_i_offset];

                    float sa = 0, result = 0;
                    for (int64_t j = 0; j < head_size; j++) {
                        sa += a[t_h_offset + j] * state_prev[h_2d_i_offset + j];
                    }

                    for (int64_t j = 0; j < head_size; j++) {
                        int64_t t_h_j_offset = t_h_offset + j;
                        int64_t h_2d_i_j_offset = h_2d_i_offset + j;

                        float r_val = r[t_h_j_offset];
                        float w_val = w[t_h_j_offset];
                        float k_val = k[t_h_j_offset];
                        float b_val = b[t_h_j_offset];
                        float kv_val = v_val * k_val;

                        float prev_state_val = state_prev[h_2d_i_j_offset];
                        state_cur[h_2d_i_j_offset] = prev_state_val * w_val + kv_val + sa * b_val;
                        result += state_cur[h_2d_i_j_offset] * r_val;
                    }

                    dst_data[t_h_i_offset] = result;
                }
            }
        }
    #else
        for (int64_t t = 0; t < T; t++) {
            int64_t t_offset = t * t_stride;
            int64_t state_offset = head_size * C * (t / (T / n_seqs));
            float * state_cur = state + state_offset;
            float * state_prev = t % (T / n_seqs) ? state_cur : (float*)dst->src[6]->data + state_offset;

            for (int64_t h = h_start; h < h_end; h++) {
                int64_t h_offset = h * h_stride;
                int64_t t_h_offset = t_offset + h_offset;
                int64_t h_2d_offset = h * h_stride_2d;

                for (int64_t ii = 0; ii < head_size; ii++) {
                    int64_t t_h_i_offset = t_h_offset + ii;
                    int64_t h_2d_i_offset = h_2d_offset + ii * h_stride;

                    GGML_F32_VEC v_vec = GGML_F32_VEC_SET1(v[t_h_i_offset]);

                    float sa = 0;
                    {
                        GGML_F32_VEC sum[GGML_F32_ARR] = { GGML_F32_VEC_ZERO };
                        GGML_F32_VEC ax[GGML_F32_ARR];
                        GGML_F32_VEC ay[GGML_F32_ARR];
                        for (int64_t j = 0; j < head_size; j += GGML_F32_STEP) {
                            for (int64_t kk = 0; kk < GGML_F32_ARR; kk++) {
                                ax[kk] = GGML_F32_VEC_LOAD(&a[t_h_offset + j + kk * GGML_F32_EPR]);
                                ay[kk] = GGML_F32_VEC_LOAD(&state_prev[h_2d_i_offset + j + kk * GGML_F32_EPR]);
                                sum[kk] = GGML_F32_VEC_FMA(sum[kk], ax[kk], ay[kk]);
                            }
                        }
                        GGML_F32_VEC_REDUCE(sa, sum);
                    }

                    GGML_F32_VEC sa_vec = GGML_F32_VEC_SET1(sa);

                    int64_t j = 0;
                    GGML_F32_VEC result_vec[GGML_F32_ARR] = { GGML_F32_VEC_ZERO };
                    for (; j < head_size; j += GGML_F32_STEP) {
                        for (int64_t kk = 0; kk < GGML_F32_ARR; kk++) {
                            int64_t t_h_j_offset = t_h_offset + j + kk * GGML_F32_EPR;
                            int64_t h_2d_i_j_offset = h_2d_i_offset + j + kk * GGML_F32_EPR;

                            GGML_F32_VEC r_vec = GGML_F32_VEC_LOAD(&r[t_h_j_offset]);
                            GGML_F32_VEC w_vec = GGML_F32_VEC_LOAD(&w[t_h_j_offset]);
                            GGML_F32_VEC k_vec = GGML_F32_VEC_LOAD(&k[t_h_j_offset]);
                            GGML_F32_VEC b_vec = GGML_F32_VEC_LOAD(&b[t_h_j_offset]);

                            k_vec = GGML_F32_VEC_MUL(v_vec, k_vec);

                            GGML_F32_VEC state_vec = GGML_F32_VEC_LOAD(&state_prev[h_2d_i_j_offset]);
                            // kv + s * decay + sa * b
                            state_vec = GGML_F32_VEC_FMA(k_vec, state_vec, w_vec);
                            state_vec = GGML_F32_VEC_FMA(state_vec, sa_vec, b_vec);
                            GGML_F32_VEC_STORE(&state_cur[h_2d_i_j_offset], state_vec);

                            result_vec[kk] = GGML_F32_VEC_FMA(result_vec[kk], state_vec, r_vec);
                        }
                    }
                    GGML_F32_VEC_REDUCE(dst_data[t_h_i_offset], result_vec);

                    // there shouldn't be any left-overs, but handle them just in case
                    for (; j < head_size; j++) {
                        int64_t t_h_j_offset = t_h_offset + j;
                        int64_t h_2d_i_j_offset = h_2d_i_offset + j;

                        float r_val = r[t_h_j_offset];
                        float w_val = w[t_h_j_offset];
                        float k_val = k[t_h_j_offset];
                        float b_val = b[t_h_j_offset];
                        float kv_val = v[t_h_i_offset] * k_val;

                        float prev_state_val = state_prev[h_2d_i_j_offset];
                        state_cur[h_2d_i_j_offset] = prev_state_val * w_val + kv_val + sa * b_val;
                        dst_data[t_h_i_offset] += state_cur[h_2d_i_j_offset] * r_val;
                    }
                }
            }
        }
    #endif
#else
    for (int64_t t = 0; t < T; t++) {
        int64_t t_offset = t * t_stride;
        int64_t state_offset = head_size * C * (t / (T / n_seqs));
        float * state_cur = state + state_offset;
        float * state_prev = t % (T / n_seqs) ? state_cur : (float*)dst->src[6]->data + state_offset;

        for (int64_t h = h_start; h < h_end; h++) {
            int64_t h_offset = h * h_stride;
            int64_t t_h_offset = t_offset + h_offset;
            int64_t h_2d_offset = h * h_stride_2d;

            for (int64_t i = 0; i < head_size; i++) {
                int64_t t_h_i_offset = t_h_offset + i;
                int64_t h_2d_i_offset = h_2d_offset + i * h_stride;

                float v_val = v[t_h_i_offset];

                float sa = 0, result = 0;
                for (int64_t j = 0; j < head_size; j++) {
                    sa += a[t_h_offset + j] * state_prev[h_2d_i_offset + j];
                }

                for (int64_t j = 0; j < head_size; j++) {
                    int64_t t_h_j_offset = t_h_offset + j;
                    int64_t h_2d_i_j_offset = h_2d_i_offset + j;

                    float r_val = r[t_h_j_offset];
                    float w_val = w[t_h_j_offset];
                    float k_val = k[t_h_j_offset];
                    float b_val = b[t_h_j_offset];
                    float kv_val = v_val * k_val;

                    float prev_state_val = state_prev[h_2d_i_j_offset];
                    state_cur[h_2d_i_j_offset] = prev_state_val * w_val + kv_val + sa * b_val;
                    result += state_cur[h_2d_i_j_offset] * r_val;
                }

                dst_data[t_h_i_offset] = result;
            }
        }
    }
#endif
}

void ggml_compute_forward_rwkv_wkv7(
        const ggml_compute_params * params,
        ggml_tensor * dst) {

    const ggml_tensor * src0 = dst->src[0];

    switch (src0->type) {
        case GGML_TYPE_F32:
            {
                ggml_compute_forward_rwkv_wkv7_f32(params, dst);
            } break;
        default:
            {
                GGML_ABORT("fatal error");
            }
    }
}

// ggml_compute_forward_map_custom1
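//
// The map_custom{1,2,3} and custom forwarders below do no computation themselves:
// they unpack the user callback and its userdata pointer from dst->op_params and
// call the callback with the source tensors plus the thread index/count (ith/nth),
// presumably leaving it to the callback to split the work across threads.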
void ggml_compute_forward_map_custom1(
        const ggml_compute_params * params,
        ggml_tensor * dst) {

    const ggml_tensor * a = dst->src[0];

    struct ggml_map_custom1_op_params p;
    memcpy(&p, dst->op_params, sizeof(p));

    p.fun(dst, a, params->ith, params->nth, p.userdata);
}

// ggml_compute_forward_map_custom2

void ggml_compute_forward_map_custom2(
        const ggml_compute_params * params,
        ggml_tensor * dst) {

    const ggml_tensor * a = dst->src[0];
    const ggml_tensor * b = dst->src[1];

    struct ggml_map_custom2_op_params p;
    memcpy(&p, dst->op_params, sizeof(p));

    p.fun(dst, a, b, params->ith, params->nth, p.userdata);
}

// ggml_compute_forward_map_custom3

void ggml_compute_forward_map_custom3(
        const ggml_compute_params * params,
        ggml_tensor * dst) {

    const ggml_tensor * a = dst->src[0];
    const ggml_tensor * b = dst->src[1];
    const ggml_tensor * c = dst->src[2];

    struct ggml_map_custom3_op_params p;
    memcpy(&p, dst->op_params, sizeof(p));

    p.fun(dst, a, b, c, params->ith, params->nth, p.userdata);
}

// ggml_compute_forward_custom

void ggml_compute_forward_custom(
        const struct ggml_compute_params * params,
        struct ggml_tensor * dst) {

    struct ggml_custom_op_params p;
    memcpy(&p, dst->op_params, sizeof(p));

    p.fun(dst, params->ith, params->nth, p.userdata);
}

// ggml_compute_forward_cross_entropy_loss
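//
// The f32 kernel below computes a single scalar,
//
//   dst[0] = -1/nr * sum_{rows i} sum_{cols c} src1[i][c] * log_softmax(src0[i])[c]
//
// i.e. the cross-entropy between the targets in src1 and softmax(src0), averaged
// over the nr rows (assuming ggml_vec_log_soft_max_f32 writes the max-shifted
// logits and returns the log of their exp-sum). Each thread accumulates a partial
// sum into params->wdata and thread 0 reduces after the barrier.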
static void ggml_compute_forward_cross_entropy_loss_f32(
        const ggml_compute_params * params,
        ggml_tensor * dst) {

    const ggml_tensor * src0 = dst->src[0];
    const ggml_tensor * src1 = dst->src[1];

    GGML_ASSERT(src0->type == GGML_TYPE_F32);
    GGML_ASSERT(src1->type == GGML_TYPE_F32);
    GGML_ASSERT(src0->nb[0] == ggml_type_size(src0->type));
    GGML_ASSERT(src1->nb[0] == ggml_type_size(src1->type));
    GGML_ASSERT(ggml_are_same_shape(src0, src1));
    GGML_ASSERT(ggml_is_scalar(dst));
    GGML_ASSERT(dst->type == GGML_TYPE_F32);

    // TODO: handle transposed/permuted matrices
    const int64_t nc = src0->ne[0];
    const int64_t nr = ggml_nrows(src0);

    const int ith = params->ith;
    const int nth = params->nth;

    float * sums = (float *) params->wdata;
    float * st = ((float *) params->wdata) + nth + ith*nc;
    float sum_thread = 0.0f;

    GGML_ASSERT(params->wsize >= sizeof(float) * (nth + nth * nc));

    // rows per thread
    const int64_t dr = (nr + nth - 1)/nth;

    // row range for this thread
    const int64_t ir0 = dr*ith;
    const int64_t ir1 = MIN(ir0 + dr, nr);

    for (int64_t i1 = ir0; i1 < ir1; ++i1) {
        const float * s0 = (const float *)((const char *) src0->data + i1*src0->nb[1]);
        const float * s1 = (const float *)((const char *) src1->data + i1*src1->nb[1]);

#ifndef NDEBUG
        for (int64_t i = 0; i < nc; ++i) {
            //printf("p[%d] = %f\n", i, p[i]);
            assert(!isnan(s0[i]));
            assert(!isnan(s1[i]));
        }
#endif

        float max = -INFINITY;
        ggml_vec_max_f32(nc, &max, s0);
        const ggml_float sum_softmax = ggml_vec_log_soft_max_f32(nc, st, s0, max);
        assert(sum_softmax >= 0.0);

        ggml_vec_add1_f32(nc, st, st, -sum_softmax);
        ggml_vec_mul_f32(nc, st, st, s1);

        float sum_st = 0.0f;
        ggml_vec_sum_f32(nc, &sum_st, st);
        sum_thread += sum_st;

#ifndef NDEBUG
        for (int64_t i = 0; i < nc; ++i) {
            assert(!isnan(st[i]));
            assert(!isinf(st[i]));
        }
#endif
    }

    sums[ith] = sum_thread;
    ggml_barrier(params->threadpool);

    if (ith == 0) {
        float * dp = (float *) dst->data;
        ggml_vec_sum_f32(nth, dp, sums);
        dp[0] *= -1.0f / (float) nr;
    }
}

void ggml_compute_forward_cross_entropy_loss(
        const ggml_compute_params * params,
        ggml_tensor * dst) {

    const ggml_tensor * src0 = dst->src[0];

    switch (src0->type) {
        case GGML_TYPE_F32:
            {
                ggml_compute_forward_cross_entropy_loss_f32(params, dst);
            } break;
        default:
            {
                GGML_ABORT("fatal error");
            }
    }
}

// ggml_compute_forward_cross_entropy_loss_back
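//
// Backward counterpart of the op above: src[0] holds the scalar gradient of the
// forward output, src[1]/src[2] are the forward inputs (logits and targets), and
// each row of dst receives
//
//   grad(src0f) = (softmax(src0f) - src1f) * grad / nr
//
// as stated by the inline comment in the loop below. The row/thread split mirrors
// the forward kernel.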
static void ggml_compute_forward_cross_entropy_loss_back_f32(
        const ggml_compute_params * params,
        ggml_tensor * dst) {

    const ggml_tensor * grad  = dst->src[0]; // gradient of forward pass output
    const ggml_tensor * src0f = dst->src[1]; // src0 of forward pass
    const ggml_tensor * src1f = dst->src[2]; // src1 of forward pass

    GGML_ASSERT(ggml_is_contiguous(dst));
    GGML_ASSERT(ggml_is_contiguous(src0f));
    GGML_ASSERT(ggml_is_contiguous(src1f));
    GGML_ASSERT(ggml_is_contiguous(grad));
    GGML_ASSERT(ggml_are_same_shape(src0f, src1f) && ggml_are_same_shape(src0f, dst));

    const int64_t ith = params->ith;
    const int64_t nth = params->nth;

    // TODO: handle transposed/permuted matrices
    const int64_t nc = src0f->ne[0];
    const int64_t nr = ggml_nrows(src0f);

    // rows per thread
    const int64_t dr = (nr + nth - 1)/nth;

    // row range for this thread
    const int64_t ir0 = dr*ith;
    const int64_t ir1 = MIN(ir0 + dr, nr);

    const float d_by_nr = ((const float *) grad->data)[0] / (float) nr;

    for (int64_t i1 = ir0; i1 < ir1; i1++) {
        float * ds0 = (float *)((char *) dst->data + i1*dst->nb[1]);
        const float * s0 = (const float *)((const char *) src0f->data + i1*src0f->nb[1]);
        const float * s1 = (const float *)((const char *) src1f->data + i1*src1f->nb[1]);

#ifndef NDEBUG
        for (int64_t i = 0; i < nc; ++i) {
            //printf("p[%d] = %f\n", i, p[i]);
            assert(!isnan(s0[i]));
            assert(!isnan(s1[i]));
        }
#endif

        // soft_max
        float max = -INFINITY;
        ggml_vec_max_f32(nc, &max, s0);
        const ggml_float sum = ggml_vec_soft_max_f32(nc, ds0, s0, max);
        assert(sum > 0.0);
        ggml_vec_scale_f32(nc, ds0, 1.0/sum);

        // grad(src0f) = (softmax(src0f) - src1f) * grad(cross_entropy_loss(src0f, src1f)) / nr
        ggml_vec_sub_f32(nc, ds0, ds0, s1);
        ggml_vec_scale_f32(nc, ds0, d_by_nr);

#ifndef NDEBUG
        for (int64_t i = 0; i < nc; ++i) {
            assert(!isnan(ds0[i]));
            assert(!isinf(ds0[i]));
        }
#endif
    }
}

void ggml_compute_forward_cross_entropy_loss_back(
        const ggml_compute_params * params,
        ggml_tensor * dst) {

    const ggml_tensor * src0 = dst->src[0];

    switch (src0->type) {
        case GGML_TYPE_F32:
            {
                ggml_compute_forward_cross_entropy_loss_back_f32(params, dst);
            } break;
        default:
            {
                GGML_ABORT("fatal error");
            }
    }
}
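
// ggml_compute_forward_opt_step_adamw
//
// Per element, the loop below applies the AdamW update
//
//   m = beta1*m + (1 - beta1)*g
//   v = beta2*v + (1 - beta2)*g^2
//   w = w*(1 - alpha*wd) - alpha*(m*beta1h) / (sqrt(v*beta2h) + eps)
//
// with alpha, beta1, beta2, eps, wd, beta1h and beta2h read from the adamw_params
// tensor. beta1h and beta2h are presumably the precomputed per-iteration bias
// correction factors 1/(1 - beta1^t) and 1/(1 - beta2^t), supplied by the caller
// rather than computed here.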
static void ggml_compute_forward_opt_step_adamw_f32(
        const ggml_compute_params * params,
        ggml_tensor * dst) {

    const ggml_tensor * src0         = dst->src[0];
    const ggml_tensor * src0_grad    = dst->src[1];
    const ggml_tensor * src0_grad_m  = dst->src[2];
    const ggml_tensor * src0_grad_v  = dst->src[3];
    const ggml_tensor * adamw_params = dst->src[4];

    GGML_ASSERT(ggml_are_same_shape(src0, src0_grad));
    GGML_ASSERT(ggml_are_same_shape(src0, src0_grad_m));
    GGML_ASSERT(ggml_are_same_shape(src0, src0_grad_v));
    GGML_ASSERT(ggml_nelements(adamw_params) == 7);

    const int ith = params->ith;
    const int nth = params->nth;

    const int nr = ggml_nrows(src0);

    GGML_TENSOR_UNARY_OP_LOCALS
    GGML_ASSERT(nb00 == sizeof(float));

    // rows per thread
    const int dr = (nr + nth - 1)/nth;

    // row range for this thread
    const int ir0 = dr*ith;
    const int ir1 = MIN(ir0 + dr, nr);

    const float * adamw_params_ptr = ggml_get_data_f32(adamw_params);
    const float alpha  = adamw_params_ptr[0];
    const float beta1  = adamw_params_ptr[1];
    const float beta2  = adamw_params_ptr[2];
    const float eps    = adamw_params_ptr[3];
    const float wd     = adamw_params_ptr[4];
    const float beta1h = adamw_params_ptr[5];
    const float beta2h = adamw_params_ptr[6];

    for (int ir = ir0; ir < ir1; ++ir) {
        const int64_t i03 = ir/(ne02*ne01);
        const int64_t i02 = (ir - i03*ne02*ne01)/ne01;
        const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01);

        const size_t offset = i03*nb03 + i02*nb02 + i01*nb01;

        float       * w = (float       *) ((char       *) src0->data        + offset); // weight
        const float * g = (const float *) ((const char *) src0_grad->data   + offset); // grad
        float       * m = (float       *) ((char       *) src0_grad_m->data + offset);
        float       * v = (float       *) ((char       *) src0_grad_v->data + offset);

        for (int i00 = 0; i00 < ne00; ++i00) {
            m[i00] = m[i00]*beta1 +        g[i00]*(1.0f - beta1);
            v[i00] = v[i00]*beta2 + g[i00]*g[i00]*(1.0f - beta2);

            const float mh =       m[i00]*beta1h;
            const float vh = sqrtf(v[i00]*beta2h) + eps;

            // The weight decay is applied independently of the Adam momenta m and v.
            // This is NOT equivalent to l2 regularization that adds w[i00]*w[i00] to the loss.
            // See: https://arxiv.org/pdf/1711.05101v3.pdf
            w[i00] = w[i00]*(1.0f - alpha*wd) - alpha*mh/vh;
        }
    }
}

void ggml_compute_forward_opt_step_adamw(
        const ggml_compute_params * params,
        ggml_tensor * dst) {

    const ggml_tensor * src0 = dst->src[0];

    switch (src0->type) {
        case GGML_TYPE_F32:
            {
                ggml_compute_forward_opt_step_adamw_f32(params, dst);
            } break;
        default:
            {
                GGML_ABORT("fatal error");
            }
    }
}