ggml-cpu.c

  1. #define _CRT_SECURE_NO_DEPRECATE // Disables "unsafe" warnings on Windows
  2. #define _USE_MATH_DEFINES // For M_PI on MSVC
  3. #include "ggml-aarch64.h"
  4. #include "ggml-backend-impl.h"
  5. #include "ggml-backend.h"
  6. #include "ggml-cpu-impl.h"
  7. #include "ggml-cpu.h"
  8. #include "ggml-impl.h"
  9. #include "ggml-quants.h"
  10. #include "ggml.h"
  11. #if defined(_MSC_VER) || defined(__MINGW32__)
  12. #include <malloc.h> // using malloc.h with MSC/MINGW
  13. #elif !defined(__FreeBSD__) && !defined(__NetBSD__) && !defined(__OpenBSD__)
  14. #include <alloca.h>
  15. #endif
  16. #include <assert.h>
  17. #include <errno.h>
  18. #include <time.h>
  19. #include <math.h>
  20. #include <stdlib.h>
  21. #include <string.h>
  22. #include <stdint.h>
  23. #include <inttypes.h>
  24. #include <stdio.h>
  25. #include <float.h>
  26. #include <limits.h>
  27. #include <stdarg.h>
  28. #include <signal.h>
  29. #if defined(__gnu_linux__)
  30. #include <syscall.h>
  31. #endif
  32. #ifdef GGML_USE_OPENMP
  33. #include <omp.h>
  34. #endif
  35. #if defined(__ARM_FEATURE_SVE) || defined(__ARM_FEATURE_MATMUL_INT8)
  36. #undef GGML_USE_LLAMAFILE
  37. #endif
  38. #ifdef GGML_USE_LLAMAFILE
  39. #include <llamafile/sgemm.h>
  40. #endif
  41. #if defined(_MSC_VER)
  42. // disable "possible loss of data" to avoid hundreds of casts
  43. // we should just be careful :)
  44. #pragma warning(disable: 4244 4267)
  45. // disable POSIX deprecation warnings
  46. // these functions are never going away, anyway
  47. #pragma warning(disable: 4996)
  48. // unreachable code because of multiple instances of code after GGML_ABORT
  49. #pragma warning(disable: 4702)
  50. #endif
  51. // Note: once we move threading into a separate C++ file,
  52. // we will use std::hardware_destructive_interference_size instead of hardcoding it here,
  53. // and we'll use C++ attribute syntax.
  54. #define GGML_CACHE_LINE 64
  55. #if defined(__clang__) || defined(__GNUC__)
  56. #define GGML_CACHE_ALIGN __attribute__((aligned(GGML_CACHE_LINE)))
  57. #endif
  58. #if defined(__has_feature)
  59. #if __has_feature(thread_sanitizer)
  60. #define GGML_TSAN_ENABLED 1
  61. #endif
  62. #else // __has_feature
  63. #if defined(__SANITIZE_THREAD__)
  64. #define GGML_TSAN_ENABLED 1
  65. #endif
  66. #endif // __has_feature
  67. #define UNUSED GGML_UNUSED
  68. #define SWAP(x, y, T) do { T SWAP = x; (x) = y; (y) = SWAP; } while (0)
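// Hedged usage sketch (illustrative, not part of the original file): SWAP
// exchanges two lvalues of type T in place through a temporary, e.g.:
static inline void ggml_example_swap(void) {
    float a = 1.0f, b = 2.0f;
    SWAP(a, b, float); // now a == 2.0f and b == 1.0f
    UNUSED(a);
    UNUSED(b);
}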
  69. #if defined(GGML_USE_ACCELERATE)
  70. #include <Accelerate/Accelerate.h>
  71. #endif
  72. // floating point type used to accumulate sums
  73. typedef double ggml_float;
  74. #define GGML_GELU_FP16
  75. #define GGML_GELU_QUICK_FP16
  76. #define GGML_SOFT_MAX_UNROLL 4
  77. #define GGML_VEC_DOT_UNROLL 2
  78. #define GGML_VEC_MAD_UNROLL 32
  79. //
  80. // global data
  81. //
  82. // precomputed gelu table for f16 (128 KB)
  83. static ggml_fp16_t ggml_table_gelu_f16[1 << 16];
  84. // precomputed quick gelu table for f16 (128 KB)
  85. static ggml_fp16_t ggml_table_gelu_quick_f16[1 << 16];
  86. // precomputed f32 table for f16 (256 KB) (ggml-impl.h)
  87. float ggml_table_f32_f16[1 << 16];
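// Hedged illustration (not part of the original file): with 1 << 16 entries
// indexed by the raw fp16 bit pattern, tables like the ones above turn a
// conversion or an activation such as GELU into a single array access,
// assuming ggml_fp16_t holds the raw 16-bit payload.
static inline float ggml_example_table_lookup_f16(ggml_fp16_t h) {
    uint16_t bits;
    memcpy(&bits, &h, sizeof(bits)); // reinterpret the fp16 bits as a table index
    return ggml_table_f32_f16[bits]; // precomputed f16 -> f32 value
}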
  88. #if defined(__ARM_ARCH)
  89. struct ggml_arm_arch_features_type {
  90. int has_neon;
  91. int has_i8mm;
  92. int has_sve;
  93. int sve_cnt;
  94. } ggml_arm_arch_features = {-1, -1, -1, 0};
  95. #endif
  96. #if defined(_WIN32)
  97. #define WIN32_LEAN_AND_MEAN
  98. #ifndef NOMINMAX
  99. #define NOMINMAX
  100. #endif
  101. #include <windows.h>
  102. #if !defined(__clang__)
  103. #define GGML_CACHE_ALIGN __declspec(align(GGML_CACHE_LINE))
  104. typedef volatile LONG atomic_int;
  105. typedef atomic_int atomic_bool;
  106. typedef atomic_int atomic_flag;
  107. #define ATOMIC_FLAG_INIT 0
  108. typedef enum {
  109. memory_order_relaxed,
  110. memory_order_consume,
  111. memory_order_acquire,
  112. memory_order_release,
  113. memory_order_acq_rel,
  114. memory_order_seq_cst
  115. } memory_order;
  116. static void atomic_store(atomic_int * ptr, LONG val) {
  117. InterlockedExchange(ptr, val);
  118. }
  119. static void atomic_store_explicit(atomic_int * ptr, LONG val, memory_order mo) {
  120. // TODO: add support for explicit memory order
  121. InterlockedExchange(ptr, val);
  122. }
  123. static LONG atomic_load(atomic_int * ptr) {
  124. return InterlockedCompareExchange(ptr, 0, 0);
  125. }
  126. static LONG atomic_load_explicit(atomic_int * ptr, memory_order mo) {
  127. // TODO: add support for explicit memory order
  128. return InterlockedCompareExchange(ptr, 0, 0);
  129. }
  130. static LONG atomic_fetch_add(atomic_int * ptr, LONG inc) {
  131. return InterlockedExchangeAdd(ptr, inc);
  132. }
  133. static LONG atomic_fetch_add_explicit(atomic_int * ptr, LONG inc, memory_order mo) {
  134. // TODO: add support for explicit memory order
  135. return InterlockedExchangeAdd(ptr, inc);
  136. }
  137. static atomic_bool atomic_flag_test_and_set(atomic_flag * ptr) {
  138. return InterlockedExchange(ptr, 1);
  139. }
  140. static void atomic_flag_clear(atomic_flag * ptr) {
  141. InterlockedExchange(ptr, 0);
  142. }
  143. static void atomic_thread_fence(memory_order mo) {
  144. MemoryBarrier();
  145. }
  146. #else // clang
  147. #include <stdatomic.h>
  148. #endif
  149. typedef HANDLE pthread_t;
  150. typedef DWORD thread_ret_t;
  151. static int pthread_create(pthread_t * out, void * unused, thread_ret_t(*func)(void *), void * arg) {
  152. (void) unused;
  153. HANDLE handle = CreateThread(NULL, 0, (LPTHREAD_START_ROUTINE) func, arg, 0, NULL);
  154. if (handle == NULL)
  155. {
  156. return EAGAIN;
  157. }
  158. *out = handle;
  159. return 0;
  160. }
  161. static int pthread_join(pthread_t thread, void * unused) {
  162. (void) unused;
  163. int ret = (int) WaitForSingleObject(thread, INFINITE);
  164. CloseHandle(thread);
  165. return ret;
  166. }
  167. static int sched_yield (void) {
  168. Sleep (0);
  169. return 0;
  170. }
  171. #else
  172. #include <pthread.h>
  173. #include <stdatomic.h>
  174. #include <sched.h>
  175. #if defined(__FreeBSD__)
  176. #include <pthread_np.h>
  177. #endif
  178. typedef void * thread_ret_t;
  179. #include <sys/types.h>
  180. #include <sys/stat.h>
  181. #include <unistd.h>
  182. #endif
  183. typedef pthread_t ggml_thread_t;
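// Hedged sketch (illustrative, not part of the original file): the shims above
// expose the same C11-style atomics and pthread-style calls on both Windows
// and POSIX, so portable code like the following compiles on either; the
// names here are hypothetical.
static atomic_int ggml_example_counter;

static thread_ret_t ggml_example_worker(void * arg) {
    (void) arg;
    atomic_fetch_add_explicit(&ggml_example_counter, 1, memory_order_relaxed);
    return 0;
}

static void ggml_example_run_two_workers(void) {
    ggml_thread_t t0;
    ggml_thread_t t1;
    pthread_create(&t0, NULL, ggml_example_worker, NULL);
    pthread_create(&t1, NULL, ggml_example_worker, NULL);
    pthread_join(t0, NULL);
    pthread_join(t1, NULL);
}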
  184. #ifdef GGML_USE_CPU_HBM
  185. #include <hbwmalloc.h>
  186. #endif
  187. #if defined(__APPLE__)
  188. #include <unistd.h>
  189. #include <mach/mach.h>
  190. #include <TargetConditionals.h>
  191. #endif
  192. //
  193. // cache line
  194. //
  195. #if defined(__cpp_lib_hardware_interference_size)
  196. #define CACHE_LINE_SIZE hardware_destructive_interference_size
  197. #else
  198. #if defined(__POWER9_VECTOR__)
  199. #define CACHE_LINE_SIZE 128
  200. #else
  201. #define CACHE_LINE_SIZE 64
  202. #endif
  203. #endif
  204. static const size_t CACHE_LINE_SIZE_F32 = CACHE_LINE_SIZE/sizeof(float);
  205. static void ggml_vec_dot_f32(int n, float * restrict s, size_t bs, const float * restrict x, size_t bx, const float * restrict y, size_t by, int nrc);
  206. static void ggml_vec_dot_f16(int n, float * restrict s, size_t bs, ggml_fp16_t * restrict x, size_t bx, ggml_fp16_t * restrict y, size_t by, int nrc);
  207. static void ggml_vec_dot_bf16(int n, float * restrict s, size_t bs, ggml_bf16_t * restrict x, size_t bx, ggml_bf16_t * restrict y, size_t by, int nrc);
  208. static const struct ggml_type_traits_cpu type_traits_cpu[GGML_TYPE_COUNT] = {
  209. [GGML_TYPE_F32] = {
  210. .vec_dot = (ggml_vec_dot_t) ggml_vec_dot_f32,
  211. .vec_dot_type = GGML_TYPE_F32,
  212. .nrows = 1,
  213. },
  214. [GGML_TYPE_F16] = {
  215. .vec_dot = (ggml_vec_dot_t) ggml_vec_dot_f16,
  216. .vec_dot_type = GGML_TYPE_F16,
  217. .nrows = 1,
  218. },
  219. [GGML_TYPE_Q4_0] = {
  220. .vec_dot = ggml_vec_dot_q4_0_q8_0,
  221. .vec_dot_type = GGML_TYPE_Q8_0,
  222. #if defined (__ARM_FEATURE_MATMUL_INT8)
  223. .nrows = 2,
  224. #else
  225. .nrows = 1,
  226. #endif
  227. },
  228. [GGML_TYPE_Q4_1] = {
  229. .vec_dot = ggml_vec_dot_q4_1_q8_1,
  230. .vec_dot_type = GGML_TYPE_Q8_1,
  231. #if defined (__ARM_FEATURE_MATMUL_INT8)
  232. .nrows = 2,
  233. #else
  234. .nrows = 1,
  235. #endif
  236. },
  237. [4] = { // GGML_TYPE_Q4_2
  238. .vec_dot = NULL,
  239. .vec_dot_type = GGML_TYPE_COUNT,
  240. .nrows = 1,
  241. },
  242. [5] = { // GGML_TYPE_Q4_3
  243. .vec_dot = NULL,
  244. .vec_dot_type = GGML_TYPE_COUNT,
  245. .nrows = 1,
  246. },
  247. [GGML_TYPE_Q5_0] = {
  248. .vec_dot = ggml_vec_dot_q5_0_q8_0,
  249. .vec_dot_type = GGML_TYPE_Q8_0,
  250. .nrows = 1,
  251. },
  252. [GGML_TYPE_Q5_1] = {
  253. .vec_dot = ggml_vec_dot_q5_1_q8_1,
  254. .vec_dot_type = GGML_TYPE_Q8_1,
  255. .nrows = 1,
  256. },
  257. [GGML_TYPE_Q8_0] = {
  258. .from_float_to_mat = quantize_mat_q8_0,
  259. .vec_dot = ggml_vec_dot_q8_0_q8_0,
  260. .vec_dot_type = GGML_TYPE_Q8_0,
  261. #if defined (__ARM_FEATURE_MATMUL_INT8)
  262. .nrows = 2,
  263. #else
  264. .nrows = 1,
  265. #endif
  266. },
  267. [GGML_TYPE_Q8_1] = {
  268. .vec_dot_type = GGML_TYPE_Q8_1,
  269. .nrows = 1,
  270. },
  271. [GGML_TYPE_Q2_K] = {
  272. .vec_dot = ggml_vec_dot_q2_K_q8_K,
  273. .vec_dot_type = GGML_TYPE_Q8_K,
  274. .nrows = 1,
  275. },
  276. [GGML_TYPE_Q3_K] = {
  277. .vec_dot = ggml_vec_dot_q3_K_q8_K,
  278. .vec_dot_type = GGML_TYPE_Q8_K,
  279. .nrows = 1,
  280. },
  281. [GGML_TYPE_Q4_K] = {
  282. .vec_dot = ggml_vec_dot_q4_K_q8_K,
  283. .vec_dot_type = GGML_TYPE_Q8_K,
  284. .nrows = 1,
  285. },
  286. [GGML_TYPE_Q5_K] = {
  287. .vec_dot = ggml_vec_dot_q5_K_q8_K,
  288. .vec_dot_type = GGML_TYPE_Q8_K,
  289. .nrows = 1,
  290. },
  291. [GGML_TYPE_Q6_K] = {
  292. .vec_dot = ggml_vec_dot_q6_K_q8_K,
  293. .vec_dot_type = GGML_TYPE_Q8_K,
  294. .nrows = 1,
  295. },
  296. [GGML_TYPE_IQ2_XXS] = {
  297. .vec_dot = ggml_vec_dot_iq2_xxs_q8_K,
  298. .vec_dot_type = GGML_TYPE_Q8_K,
  299. .nrows = 1,
  300. },
  301. [GGML_TYPE_IQ2_XS] = {
  302. .vec_dot = ggml_vec_dot_iq2_xs_q8_K,
  303. .vec_dot_type = GGML_TYPE_Q8_K,
  304. .nrows = 1,
  305. },
  306. [GGML_TYPE_IQ3_XXS] = {
  307. .vec_dot = ggml_vec_dot_iq3_xxs_q8_K,
  308. .vec_dot_type = GGML_TYPE_Q8_K,
  309. .nrows = 1,
  310. },
  311. [GGML_TYPE_IQ3_S] = {
  312. .vec_dot = ggml_vec_dot_iq3_s_q8_K,
  313. .vec_dot_type = GGML_TYPE_Q8_K,
  314. .nrows = 1,
  315. },
  316. [GGML_TYPE_IQ2_S] = {
  317. .vec_dot = ggml_vec_dot_iq2_s_q8_K,
  318. .vec_dot_type = GGML_TYPE_Q8_K,
  319. .nrows = 1,
  320. },
  321. [GGML_TYPE_IQ1_S] = {
  322. .vec_dot = ggml_vec_dot_iq1_s_q8_K,
  323. .vec_dot_type = GGML_TYPE_Q8_K,
  324. .nrows = 1,
  325. },
  326. [GGML_TYPE_IQ1_M] = {
  327. .vec_dot = ggml_vec_dot_iq1_m_q8_K,
  328. .vec_dot_type = GGML_TYPE_Q8_K,
  329. .nrows = 1,
  330. },
  331. [GGML_TYPE_IQ4_NL] = {
  332. .vec_dot = ggml_vec_dot_iq4_nl_q8_0,
  333. .vec_dot_type = GGML_TYPE_Q8_0,
  334. .nrows = 1,
  335. },
  336. [GGML_TYPE_IQ4_XS] = {
  337. .vec_dot = ggml_vec_dot_iq4_xs_q8_K,
  338. .vec_dot_type = GGML_TYPE_Q8_K,
  339. .nrows = 1,
  340. },
  341. [GGML_TYPE_BF16] = {
  342. .vec_dot = (ggml_vec_dot_t) ggml_vec_dot_bf16,
  343. .vec_dot_type = GGML_TYPE_BF16,
  344. .nrows = 1,
  345. },
  346. [GGML_TYPE_Q4_0_4_4] = {
  347. .vec_dot = NULL,
  348. .vec_dot_type = GGML_TYPE_Q8_0,
  349. .nrows = 1,
  350. .ncols = 4,
  351. .gemv = ggml_gemv_q4_0_4x4_q8_0,
  352. .gemm = ggml_gemm_q4_0_4x4_q8_0,
  353. },
  354. [GGML_TYPE_Q4_0_4_8] = {
  355. .vec_dot = NULL,
  356. .vec_dot_type = GGML_TYPE_Q8_0,
  357. .nrows = 1,
  358. .ncols = 4,
  359. .gemv = ggml_gemv_q4_0_4x8_q8_0,
  360. .gemm = ggml_gemm_q4_0_4x8_q8_0,
  361. },
  362. [GGML_TYPE_Q4_0_8_8] = {
  363. .nrows = 1,
  364. .ncols = 8,
  365. .gemv = ggml_gemv_q4_0_8x8_q8_0,
  366. .gemm = ggml_gemm_q4_0_8x8_q8_0,
  367. },
  368. [GGML_TYPE_TQ1_0] = {
  369. .vec_dot = ggml_vec_dot_tq1_0_q8_K,
  370. .vec_dot_type = GGML_TYPE_Q8_K,
  371. .nrows = 1,
  372. },
  373. [GGML_TYPE_TQ2_0] = {
  374. .vec_dot = ggml_vec_dot_tq2_0_q8_K,
  375. .vec_dot_type = GGML_TYPE_Q8_K,
  376. .nrows = 1,
  377. },
  378. };
  379. const struct ggml_type_traits_cpu * ggml_get_type_traits_cpu(enum ggml_type type) {
  380. return &type_traits_cpu[type];
  381. }
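// Hedged usage sketch (not part of the original file): the table above is the
// CPU backend's dispatch point - for a given tensor type it yields the
// dot-product kernel, the quantization format the other operand must be
// converted to (vec_dot_type), and how many rows the kernel handles per call.
static inline void ggml_example_print_type_traits(enum ggml_type type) {
    const struct ggml_type_traits_cpu * tt = ggml_get_type_traits_cpu(type);
    printf("type %d: vec_dot_type = %d, nrows = %d\n",
           (int) type, (int) tt->vec_dot_type, (int) tt->nrows);
}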
  382. //
  383. // simd mappings
  384. //
  385. // we define a common set of C macros which map to specific intrinsics based on the current architecture
  386. // we then implement the fundamental computation operations below using only these macros
  387. // adding support for new architectures requires defining the corresponding SIMD macros
  388. //
  389. // GGML_F32_STEP / GGML_F16_STEP
  390. // number of elements to process in a single step
  391. //
  392. // GGML_F32_EPR / GGML_F16_EPR
  393. // number of elements to fit in a single register
  394. //
  395. #if defined(__ARM_NEON) && defined(__ARM_FEATURE_FMA)
  396. #define GGML_SIMD
  397. // F32 NEON
  398. #define GGML_F32_STEP 16
  399. #define GGML_F32_EPR 4
  400. #define GGML_F32x4 float32x4_t
  401. #define GGML_F32x4_ZERO vdupq_n_f32(0.0f)
  402. #define GGML_F32x4_SET1(x) vdupq_n_f32(x)
  403. #define GGML_F32x4_LOAD vld1q_f32
  404. #define GGML_F32x4_STORE vst1q_f32
  405. #define GGML_F32x4_FMA(a, b, c) vfmaq_f32(a, b, c)
  406. #define GGML_F32x4_ADD vaddq_f32
  407. #define GGML_F32x4_MUL vmulq_f32
  408. #define GGML_F32x4_REDUCE_ONE(x) vaddvq_f32(x)
  409. #define GGML_F32x4_REDUCE(res, x) \
  410. { \
  411. int offset = GGML_F32_ARR >> 1; \
  412. for (int i = 0; i < offset; ++i) { \
  413. (x)[i] = vaddq_f32((x)[i], (x)[offset+i]); \
  414. } \
  415. offset >>= 1; \
  416. for (int i = 0; i < offset; ++i) { \
  417. (x)[i] = vaddq_f32((x)[i], (x)[offset+i]); \
  418. } \
  419. offset >>= 1; \
  420. for (int i = 0; i < offset; ++i) { \
  421. (x)[i] = vaddq_f32((x)[i], (x)[offset+i]); \
  422. } \
  423. (res) = GGML_F32x4_REDUCE_ONE((x)[0]); \
  424. }
  425. #define GGML_F32_VEC GGML_F32x4
  426. #define GGML_F32_VEC_ZERO GGML_F32x4_ZERO
  427. #define GGML_F32_VEC_SET1 GGML_F32x4_SET1
  428. #define GGML_F32_VEC_LOAD GGML_F32x4_LOAD
  429. #define GGML_F32_VEC_STORE GGML_F32x4_STORE
  430. #define GGML_F32_VEC_FMA GGML_F32x4_FMA
  431. #define GGML_F32_VEC_ADD GGML_F32x4_ADD
  432. #define GGML_F32_VEC_MUL GGML_F32x4_MUL
  433. #define GGML_F32_VEC_REDUCE GGML_F32x4_REDUCE
  434. // F16 NEON
  435. #if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC)
  436. #define GGML_F16_STEP 32
  437. #define GGML_F16_EPR 8
  438. #define GGML_F16x8 float16x8_t
  439. #define GGML_F16x8_ZERO vdupq_n_f16(0.0f)
  440. #define GGML_F16x8_SET1(x) vdupq_n_f16(x)
  441. #define GGML_F16x8_LOAD(x) vld1q_f16((const ggml_fp16_internal_t *)(x))
  442. #define GGML_F16x8_STORE vst1q_f16
  443. #define GGML_F16x8_FMA(a, b, c) vfmaq_f16(a, b, c)
  444. #define GGML_F16x8_ADD vaddq_f16
  445. #define GGML_F16x8_MUL vmulq_f16
  446. #define GGML_F16x8_REDUCE(res, x) \
  447. do { \
  448. int offset = GGML_F16_ARR >> 1; \
  449. for (int i = 0; i < offset; ++i) { \
  450. (x)[i] = vaddq_f16((x)[i], (x)[offset+i]); \
  451. } \
  452. offset >>= 1; \
  453. for (int i = 0; i < offset; ++i) { \
  454. (x)[i] = vaddq_f16((x)[i], (x)[offset+i]); \
  455. } \
  456. offset >>= 1; \
  457. for (int i = 0; i < offset; ++i) { \
  458. (x)[i] = vaddq_f16((x)[i], (x)[offset+i]); \
  459. } \
  460. const float32x4_t t0 = vcvt_f32_f16(vget_low_f16 ((x)[0])); \
  461. const float32x4_t t1 = vcvt_f32_f16(vget_high_f16((x)[0])); \
  462. (res) = (ggml_float) vaddvq_f32(vaddq_f32(t0, t1)); \
  463. } while (0)
  464. #define GGML_F16_VEC GGML_F16x8
  465. #define GGML_F16_VEC_ZERO GGML_F16x8_ZERO
  466. #define GGML_F16_VEC_SET1 GGML_F16x8_SET1
  467. #define GGML_F16_VEC_LOAD(p, i) GGML_F16x8_LOAD(p)
  468. #define GGML_F16_VEC_STORE(p, r, i) GGML_F16x8_STORE((ggml_fp16_internal_t *)(p), (r)[i])
  469. #define GGML_F16_VEC_FMA GGML_F16x8_FMA
  470. #define GGML_F16_VEC_ADD GGML_F16x8_ADD
  471. #define GGML_F16_VEC_MUL GGML_F16x8_MUL
  472. #define GGML_F16_VEC_REDUCE GGML_F16x8_REDUCE
  473. #else
  474. // if FP16 vector arithmetic is not supported, we use FP32 instead
  475. // and take advantage of the vcvt_ functions to convert to/from FP16
  476. #define GGML_F16_STEP 16
  477. #define GGML_F16_EPR 4
  478. #define GGML_F32Cx4 float32x4_t
  479. #define GGML_F32Cx4_ZERO vdupq_n_f32(0.0f)
  480. #define GGML_F32Cx4_SET1(x) vdupq_n_f32(x)
  481. #define GGML_F32Cx4_LOAD(x) vcvt_f32_f16(vld1_f16((const ggml_fp16_internal_t *)(x)))
  482. #define GGML_F32Cx4_STORE(x, y) vst1_f16(x, vcvt_f16_f32(y))
  483. #define GGML_F32Cx4_FMA(a, b, c) vfmaq_f32(a, b, c)
  484. #define GGML_F32Cx4_ADD vaddq_f32
  485. #define GGML_F32Cx4_MUL vmulq_f32
  486. #define GGML_F32Cx4_REDUCE GGML_F32x4_REDUCE
  487. #define GGML_F16_VEC GGML_F32Cx4
  488. #define GGML_F16_VEC_ZERO GGML_F32Cx4_ZERO
  489. #define GGML_F16_VEC_SET1 GGML_F32Cx4_SET1
  490. #define GGML_F16_VEC_LOAD(p, i) GGML_F32Cx4_LOAD(p)
  491. #define GGML_F16_VEC_STORE(p, r, i) GGML_F32Cx4_STORE((ggml_fp16_internal_t *)(p), r[i])
  492. #define GGML_F16_VEC_FMA GGML_F32Cx4_FMA
  493. #define GGML_F16_VEC_ADD GGML_F32Cx4_ADD
  494. #define GGML_F16_VEC_MUL GGML_F32Cx4_MUL
  495. #define GGML_F16_VEC_REDUCE GGML_F32Cx4_REDUCE
  496. #endif
  497. #elif defined(__AVX512F__)
  498. #define GGML_SIMD
  499. // F32 AVX512
  500. #define GGML_F32_STEP 64
  501. #define GGML_F32_EPR 16
  502. #define GGML_F32x16 __m512
  503. #define GGML_F32x16_ZERO _mm512_setzero_ps()
  504. #define GGML_F32x16_SET1(x) _mm512_set1_ps(x)
  505. #define GGML_F32x16_LOAD _mm512_loadu_ps
  506. #define GGML_F32x16_STORE _mm512_storeu_ps
  507. // _mm512_fmadd_ps is defined in AVX512F so no guard is required
  508. #define GGML_F32x16_FMA(a, b, c) _mm512_fmadd_ps(b, c, a)
  509. #define GGML_F32x16_ADD _mm512_add_ps
  510. #define GGML_F32x16_MUL _mm512_mul_ps
  511. #define GGML_F32x16_REDUCE(res, x) \
  512. do { \
  513. int offset = GGML_F32_ARR >> 1; \
  514. for (int i = 0; i < offset; ++i) { \
  515. x[i] = _mm512_add_ps(x[i], x[offset+i]); \
  516. } \
  517. offset >>= 1; \
  518. for (int i = 0; i < offset; ++i) { \
  519. x[i] = _mm512_add_ps(x[i], x[offset+i]); \
  520. } \
  521. offset >>= 1; \
  522. for (int i = 0; i < offset; ++i) { \
  523. x[i] = _mm512_add_ps(x[i], x[offset+i]); \
  524. } \
  525. res = _mm512_reduce_add_ps(x[0]); \
  526. } while (0)
  527. // TODO: is this optimal ?
  528. #define GGML_F32_VEC GGML_F32x16
  529. #define GGML_F32_VEC_ZERO GGML_F32x16_ZERO
  530. #define GGML_F32_VEC_SET1 GGML_F32x16_SET1
  531. #define GGML_F32_VEC_LOAD GGML_F32x16_LOAD
  532. #define GGML_F32_VEC_STORE GGML_F32x16_STORE
  533. #define GGML_F32_VEC_FMA GGML_F32x16_FMA
  534. #define GGML_F32_VEC_ADD GGML_F32x16_ADD
  535. #define GGML_F32_VEC_MUL GGML_F32x16_MUL
  536. #define GGML_F32_VEC_REDUCE GGML_F32x16_REDUCE
  537. // F16 AVX512
  538. // F16 AVX
  539. #define GGML_F16_STEP 64
  540. #define GGML_F16_EPR 16
  541. // AVX512 has FP16 extension (AVX512_FP16) but I don't have it on my machine so I use FP32 instead
  542. #define GGML_F32Cx16 __m512
  543. #define GGML_F32Cx16_ZERO _mm512_setzero_ps()
  544. #define GGML_F32Cx16_SET1(x) _mm512_set1_ps(x)
  545. // unlike _mm256_cvt intrinsics that require F16C, _mm512_cvt is defined in AVX512F
  546. // so F16C guard isn't required
  547. #define GGML_F32Cx16_LOAD(x) _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)(x)))
  548. #define GGML_F32Cx16_STORE(x, y) _mm256_storeu_si256((__m256i *)(x), _mm512_cvtps_ph(y, 0))
  549. #define GGML_F32Cx16_FMA(a, b, c) _mm512_fmadd_ps(b, c, a)
  550. #define GGML_F32Cx16_ADD _mm512_add_ps
  551. #define GGML_F32Cx16_MUL _mm512_mul_ps
  552. #define GGML_F32Cx16_REDUCE(res, x) \
  553. do { \
  554. int offset = GGML_F32_ARR >> 1; \
  555. for (int i = 0; i < offset; ++i) { \
  556. x[i] = _mm512_add_ps(x[i], x[offset+i]); \
  557. } \
  558. offset >>= 1; \
  559. for (int i = 0; i < offset; ++i) { \
  560. x[i] = _mm512_add_ps(x[i], x[offset+i]); \
  561. } \
  562. offset >>= 1; \
  563. for (int i = 0; i < offset; ++i) { \
  564. x[i] = _mm512_add_ps(x[i], x[offset+i]); \
  565. } \
  566. res = _mm512_reduce_add_ps(x[0]); \
  567. } while (0)
  568. #define GGML_F16_VEC GGML_F32Cx16
  569. #define GGML_F16_VEC_ZERO GGML_F32Cx16_ZERO
  570. #define GGML_F16_VEC_SET1 GGML_F32Cx16_SET1
  571. #define GGML_F16_VEC_LOAD(p, i) GGML_F32Cx16_LOAD(p)
  572. #define GGML_F16_VEC_STORE(p, r, i) GGML_F32Cx16_STORE(p, r[i])
  573. #define GGML_F16_VEC_FMA GGML_F32Cx16_FMA
  574. #define GGML_F16_VEC_ADD GGML_F32Cx16_ADD
  575. #define GGML_F16_VEC_MUL GGML_F32Cx16_MUL
  576. #define GGML_F16_VEC_REDUCE GGML_F32Cx16_REDUCE
  577. #elif defined(__AVX__)
  578. #define GGML_SIMD
  579. // F32 AVX
  580. #define GGML_F32_STEP 32
  581. #define GGML_F32_EPR 8
  582. #define GGML_F32x8 __m256
  583. #define GGML_F32x8_ZERO _mm256_setzero_ps()
  584. #define GGML_F32x8_SET1(x) _mm256_set1_ps(x)
  585. #define GGML_F32x8_LOAD _mm256_loadu_ps
  586. #define GGML_F32x8_STORE _mm256_storeu_ps
  587. #if defined(__FMA__)
  588. #define GGML_F32x8_FMA(a, b, c) _mm256_fmadd_ps(b, c, a)
  589. #else
  590. #define GGML_F32x8_FMA(a, b, c) _mm256_add_ps(_mm256_mul_ps(b, c), a)
  591. #endif
  592. #define GGML_F32x8_ADD _mm256_add_ps
  593. #define GGML_F32x8_MUL _mm256_mul_ps
  594. #define GGML_F32x8_REDUCE(res, x) \
  595. do { \
  596. int offset = GGML_F32_ARR >> 1; \
  597. for (int i = 0; i < offset; ++i) { \
  598. x[i] = _mm256_add_ps(x[i], x[offset+i]); \
  599. } \
  600. offset >>= 1; \
  601. for (int i = 0; i < offset; ++i) { \
  602. x[i] = _mm256_add_ps(x[i], x[offset+i]); \
  603. } \
  604. offset >>= 1; \
  605. for (int i = 0; i < offset; ++i) { \
  606. x[i] = _mm256_add_ps(x[i], x[offset+i]); \
  607. } \
  608. const __m128 t0 = _mm_add_ps(_mm256_castps256_ps128(x[0]), \
  609. _mm256_extractf128_ps(x[0], 1)); \
  610. const __m128 t1 = _mm_hadd_ps(t0, t0); \
  611. res = (ggml_float) _mm_cvtss_f32(_mm_hadd_ps(t1, t1)); \
  612. } while (0)
  613. // TODO: is this optimal ?
  614. #define GGML_F32_VEC GGML_F32x8
  615. #define GGML_F32_VEC_ZERO GGML_F32x8_ZERO
  616. #define GGML_F32_VEC_SET1 GGML_F32x8_SET1
  617. #define GGML_F32_VEC_LOAD GGML_F32x8_LOAD
  618. #define GGML_F32_VEC_STORE GGML_F32x8_STORE
  619. #define GGML_F32_VEC_FMA GGML_F32x8_FMA
  620. #define GGML_F32_VEC_ADD GGML_F32x8_ADD
  621. #define GGML_F32_VEC_MUL GGML_F32x8_MUL
  622. #define GGML_F32_VEC_REDUCE GGML_F32x8_REDUCE
  623. // F16 AVX
  624. #define GGML_F16_STEP 32
  625. #define GGML_F16_EPR 8
  626. // F16 arithmetic is not supported by AVX, so we use F32 instead
  627. #define GGML_F32Cx8 __m256
  628. #define GGML_F32Cx8_ZERO _mm256_setzero_ps()
  629. #define GGML_F32Cx8_SET1(x) _mm256_set1_ps(x)
  630. #if defined(__F16C__)
  631. // the _mm256_cvt intrinsics require F16C
  632. #define GGML_F32Cx8_LOAD(x) _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)(x)))
  633. #define GGML_F32Cx8_STORE(x, y) _mm_storeu_si128((__m128i *)(x), _mm256_cvtps_ph(y, 0))
  634. #else
  635. static inline __m256 __avx_f32cx8_load(ggml_fp16_t *x) {
  636. float tmp[8];
  637. for (int i = 0; i < 8; i++) {
  638. tmp[i] = GGML_FP16_TO_FP32(x[i]);
  639. }
  640. return _mm256_loadu_ps(tmp);
  641. }
  642. static inline void __avx_f32cx8_store(ggml_fp16_t *x, __m256 y) {
  643. float arr[8];
  644. _mm256_storeu_ps(arr, y);
  645. for (int i = 0; i < 8; i++)
  646. x[i] = GGML_FP32_TO_FP16(arr[i]);
  647. }
  648. #define GGML_F32Cx8_LOAD(x) __avx_f32cx8_load(x)
  649. #define GGML_F32Cx8_STORE(x, y) __avx_f32cx8_store(x, y)
  650. #endif
  651. #define GGML_F32Cx8_FMA GGML_F32x8_FMA
  652. #define GGML_F32Cx8_ADD _mm256_add_ps
  653. #define GGML_F32Cx8_MUL _mm256_mul_ps
  654. #define GGML_F32Cx8_REDUCE GGML_F32x8_REDUCE
  655. #define GGML_F16_VEC GGML_F32Cx8
  656. #define GGML_F16_VEC_ZERO GGML_F32Cx8_ZERO
  657. #define GGML_F16_VEC_SET1 GGML_F32Cx8_SET1
  658. #define GGML_F16_VEC_LOAD(p, i) GGML_F32Cx8_LOAD(p)
  659. #define GGML_F16_VEC_STORE(p, r, i) GGML_F32Cx8_STORE(p, r[i])
  660. #define GGML_F16_VEC_FMA GGML_F32Cx8_FMA
  661. #define GGML_F16_VEC_ADD GGML_F32Cx8_ADD
  662. #define GGML_F16_VEC_MUL GGML_F32Cx8_MUL
  663. #define GGML_F16_VEC_REDUCE GGML_F32Cx8_REDUCE
  664. #elif defined(__POWER9_VECTOR__)
  665. #define GGML_SIMD
  666. // F32 POWER9
  667. #define GGML_F32_STEP 32
  668. #define GGML_F32_EPR 4
  669. #define GGML_F32x4 vector float
  670. #define GGML_F32x4_ZERO 0.0f
  671. #define GGML_F32x4_SET1 vec_splats
  672. #define GGML_F32x4_LOAD(p) vec_xl(0, p)
  673. #define GGML_F32x4_STORE(p, r) vec_xst(r, 0, p)
  674. #define GGML_F32x4_FMA(a, b, c) vec_madd(b, c, a)
  675. #define GGML_F32x4_ADD vec_add
  676. #define GGML_F32x4_MUL vec_mul
  677. #define GGML_F32x4_REDUCE(res, x) \
  678. { \
  679. int offset = GGML_F32_ARR >> 1; \
  680. for (int i = 0; i < offset; ++i) { \
  681. x[i] = vec_add(x[i], x[offset+i]); \
  682. } \
  683. offset >>= 1; \
  684. for (int i = 0; i < offset; ++i) { \
  685. x[i] = vec_add(x[i], x[offset+i]); \
  686. } \
  687. offset >>= 1; \
  688. for (int i = 0; i < offset; ++i) { \
  689. x[i] = vec_add(x[i], x[offset+i]); \
  690. } \
  691. res = vec_extract(x[0], 0) + \
  692. vec_extract(x[0], 1) + \
  693. vec_extract(x[0], 2) + \
  694. vec_extract(x[0], 3); \
  695. }
  696. #define GGML_F32_VEC GGML_F32x4
  697. #define GGML_F32_VEC_ZERO GGML_F32x4_ZERO
  698. #define GGML_F32_VEC_SET1 GGML_F32x4_SET1
  699. #define GGML_F32_VEC_LOAD GGML_F32x4_LOAD
  700. #define GGML_F32_VEC_STORE GGML_F32x4_STORE
  701. #define GGML_F32_VEC_FMA GGML_F32x4_FMA
  702. #define GGML_F32_VEC_ADD GGML_F32x4_ADD
  703. #define GGML_F32_VEC_MUL GGML_F32x4_MUL
  704. #define GGML_F32_VEC_REDUCE GGML_F32x4_REDUCE
  705. // F16 POWER9
  706. #define GGML_F16_STEP GGML_F32_STEP
  707. #define GGML_F16_EPR GGML_F32_EPR
  708. #define GGML_F16_VEC GGML_F32x4
  709. #define GGML_F16_VEC_ZERO GGML_F32x4_ZERO
  710. #define GGML_F16_VEC_SET1 GGML_F32x4_SET1
  711. #define GGML_F16_VEC_FMA GGML_F32x4_FMA
  712. #define GGML_F16_VEC_ADD GGML_F32x4_ADD
  713. #define GGML_F16_VEC_MUL GGML_F32x4_MUL
  714. #define GGML_F16_VEC_REDUCE GGML_F32x4_REDUCE
  715. // Use vec_xl, not vec_ld, in case the load address is not aligned.
  716. #define GGML_F16_VEC_LOAD(p, i) (i & 0x1) ? \
  717. vec_extract_fp32_from_shorth(vec_xl(0, p - GGML_F16_EPR)) : \
  718. vec_extract_fp32_from_shortl(vec_xl(0, p))
  719. #define GGML_ENDIAN_BYTE(i) ((unsigned char *)&(uint16_t){1})[i]
  720. #define GGML_F16_VEC_STORE(p, r, i) \
  721. if (i & 0x1) \
  722. vec_xst(vec_pack_to_short_fp32(r[i - GGML_ENDIAN_BYTE(1)], \
  723. r[i - GGML_ENDIAN_BYTE(0)]), \
  724. 0, p - GGML_F16_EPR)
  725. #elif defined(__wasm_simd128__)
  726. #define GGML_SIMD
  727. // F32 WASM
  728. #define GGML_F32_STEP 16
  729. #define GGML_F32_EPR 4
  730. #define GGML_F32x4 v128_t
  731. #define GGML_F32x4_ZERO wasm_f32x4_splat(0.0f)
  732. #define GGML_F32x4_SET1(x) wasm_f32x4_splat(x)
  733. #define GGML_F32x4_LOAD wasm_v128_load
  734. #define GGML_F32x4_STORE wasm_v128_store
  735. #define GGML_F32x4_FMA(a, b, c) wasm_f32x4_add(wasm_f32x4_mul(b, c), a)
  736. #define GGML_F32x4_ADD wasm_f32x4_add
  737. #define GGML_F32x4_MUL wasm_f32x4_mul
  738. #define GGML_F32x4_REDUCE(res, x) \
  739. { \
  740. int offset = GGML_F32_ARR >> 1; \
  741. for (int i = 0; i < offset; ++i) { \
  742. x[i] = wasm_f32x4_add(x[i], x[offset+i]); \
  743. } \
  744. offset >>= 1; \
  745. for (int i = 0; i < offset; ++i) { \
  746. x[i] = wasm_f32x4_add(x[i], x[offset+i]); \
  747. } \
  748. offset >>= 1; \
  749. for (int i = 0; i < offset; ++i) { \
  750. x[i] = wasm_f32x4_add(x[i], x[offset+i]); \
  751. } \
  752. res = wasm_f32x4_extract_lane(x[0], 0) + \
  753. wasm_f32x4_extract_lane(x[0], 1) + \
  754. wasm_f32x4_extract_lane(x[0], 2) + \
  755. wasm_f32x4_extract_lane(x[0], 3); \
  756. }
  757. #define GGML_F32_VEC GGML_F32x4
  758. #define GGML_F32_VEC_ZERO GGML_F32x4_ZERO
  759. #define GGML_F32_VEC_SET1 GGML_F32x4_SET1
  760. #define GGML_F32_VEC_LOAD GGML_F32x4_LOAD
  761. #define GGML_F32_VEC_STORE GGML_F32x4_STORE
  762. #define GGML_F32_VEC_FMA GGML_F32x4_FMA
  763. #define GGML_F32_VEC_ADD GGML_F32x4_ADD
  764. #define GGML_F32_VEC_MUL GGML_F32x4_MUL
  765. #define GGML_F32_VEC_REDUCE GGML_F32x4_REDUCE
  766. // F16 WASM
  767. #define GGML_F16_STEP 16
  768. #define GGML_F16_EPR 4
  769. inline static v128_t __wasm_f16x4_load(const ggml_fp16_t * p) {
  770. float tmp[4];
  771. tmp[0] = GGML_FP16_TO_FP32(p[0]);
  772. tmp[1] = GGML_FP16_TO_FP32(p[1]);
  773. tmp[2] = GGML_FP16_TO_FP32(p[2]);
  774. tmp[3] = GGML_FP16_TO_FP32(p[3]);
  775. return wasm_v128_load(tmp);
  776. }
  777. inline static void __wasm_f16x4_store(ggml_fp16_t * p, v128_t x) {
  778. float tmp[4];
  779. wasm_v128_store(tmp, x);
  780. p[0] = GGML_FP32_TO_FP16(tmp[0]);
  781. p[1] = GGML_FP32_TO_FP16(tmp[1]);
  782. p[2] = GGML_FP32_TO_FP16(tmp[2]);
  783. p[3] = GGML_FP32_TO_FP16(tmp[3]);
  784. }
  785. #define GGML_F16x4 v128_t
  786. #define GGML_F16x4_ZERO wasm_f32x4_splat(0.0f)
  787. #define GGML_F16x4_SET1(x) wasm_f32x4_splat(x)
  788. #define GGML_F16x4_LOAD(x) __wasm_f16x4_load(x)
  789. #define GGML_F16x4_STORE(x, y) __wasm_f16x4_store(x, y)
  790. #define GGML_F16x4_FMA GGML_F32x4_FMA
  791. #define GGML_F16x4_ADD wasm_f32x4_add
  792. #define GGML_F16x4_MUL wasm_f32x4_mul
  793. #define GGML_F16x4_REDUCE(res, x) \
  794. { \
  795. int offset = GGML_F16_ARR >> 1; \
  796. for (int i = 0; i < offset; ++i) { \
  797. x[i] = wasm_f32x4_add(x[i], x[offset+i]); \
  798. } \
  799. offset >>= 1; \
  800. for (int i = 0; i < offset; ++i) { \
  801. x[i] = wasm_f32x4_add(x[i], x[offset+i]); \
  802. } \
  803. offset >>= 1; \
  804. for (int i = 0; i < offset; ++i) { \
  805. x[i] = wasm_f32x4_add(x[i], x[offset+i]); \
  806. } \
  807. res = wasm_f32x4_extract_lane(x[0], 0) + \
  808. wasm_f32x4_extract_lane(x[0], 1) + \
  809. wasm_f32x4_extract_lane(x[0], 2) + \
  810. wasm_f32x4_extract_lane(x[0], 3); \
  811. }
  812. #define GGML_F16_VEC GGML_F16x4
  813. #define GGML_F16_VEC_ZERO GGML_F16x4_ZERO
  814. #define GGML_F16_VEC_SET1 GGML_F16x4_SET1
  815. #define GGML_F16_VEC_LOAD(p, i) GGML_F16x4_LOAD(p)
  816. #define GGML_F16_VEC_STORE(p, r, i) GGML_F16x4_STORE(p, r[i])
  817. #define GGML_F16_VEC_FMA GGML_F16x4_FMA
  818. #define GGML_F16_VEC_ADD GGML_F16x4_ADD
  819. #define GGML_F16_VEC_MUL GGML_F16x4_MUL
  820. #define GGML_F16_VEC_REDUCE GGML_F16x4_REDUCE
  821. #elif defined(__SSE3__)
  822. #define GGML_SIMD
  823. // F32 SSE
  824. #define GGML_F32_STEP 32
  825. #define GGML_F32_EPR 4
  826. #define GGML_F32x4 __m128
  827. #define GGML_F32x4_ZERO _mm_setzero_ps()
  828. #define GGML_F32x4_SET1(x) _mm_set1_ps(x)
  829. #define GGML_F32x4_LOAD _mm_loadu_ps
  830. #define GGML_F32x4_STORE _mm_storeu_ps
  831. #if defined(__FMA__)
  832. // TODO: Does this work?
  833. #define GGML_F32x4_FMA(a, b, c) _mm_fmadd_ps(b, c, a)
  834. #else
  835. #define GGML_F32x4_FMA(a, b, c) _mm_add_ps(_mm_mul_ps(b, c), a)
  836. #endif
  837. #define GGML_F32x4_ADD _mm_add_ps
  838. #define GGML_F32x4_MUL _mm_mul_ps
  839. #define GGML_F32x4_REDUCE(res, x) \
  840. { \
  841. int offset = GGML_F32_ARR >> 1; \
  842. for (int i = 0; i < offset; ++i) { \
  843. x[i] = _mm_add_ps(x[i], x[offset+i]); \
  844. } \
  845. offset >>= 1; \
  846. for (int i = 0; i < offset; ++i) { \
  847. x[i] = _mm_add_ps(x[i], x[offset+i]); \
  848. } \
  849. offset >>= 1; \
  850. for (int i = 0; i < offset; ++i) { \
  851. x[i] = _mm_add_ps(x[i], x[offset+i]); \
  852. } \
  853. const __m128 t0 = _mm_hadd_ps(x[0], x[0]); \
  854. res = (ggml_float) _mm_cvtss_f32(_mm_hadd_ps(t0, t0)); \
  855. }
  856. // TODO: is this optimal ?
  857. #define GGML_F32_VEC GGML_F32x4
  858. #define GGML_F32_VEC_ZERO GGML_F32x4_ZERO
  859. #define GGML_F32_VEC_SET1 GGML_F32x4_SET1
  860. #define GGML_F32_VEC_LOAD GGML_F32x4_LOAD
  861. #define GGML_F32_VEC_STORE GGML_F32x4_STORE
  862. #define GGML_F32_VEC_FMA GGML_F32x4_FMA
  863. #define GGML_F32_VEC_ADD GGML_F32x4_ADD
  864. #define GGML_F32_VEC_MUL GGML_F32x4_MUL
  865. #define GGML_F32_VEC_REDUCE GGML_F32x4_REDUCE
  866. // F16 SSE
  867. #define GGML_F16_STEP 32
  868. #define GGML_F16_EPR 4
  869. static inline __m128 __sse_f16x4_load(ggml_fp16_t *x) {
  870. float tmp[4];
  871. tmp[0] = GGML_FP16_TO_FP32(x[0]);
  872. tmp[1] = GGML_FP16_TO_FP32(x[1]);
  873. tmp[2] = GGML_FP16_TO_FP32(x[2]);
  874. tmp[3] = GGML_FP16_TO_FP32(x[3]);
  875. return _mm_loadu_ps(tmp);
  876. }
  877. static inline void __sse_f16x4_store(ggml_fp16_t *x, __m128 y) {
  878. float arr[4];
  879. _mm_storeu_ps(arr, y);
  880. x[0] = GGML_FP32_TO_FP16(arr[0]);
  881. x[1] = GGML_FP32_TO_FP16(arr[1]);
  882. x[2] = GGML_FP32_TO_FP16(arr[2]);
  883. x[3] = GGML_FP32_TO_FP16(arr[3]);
  884. }
  885. #define GGML_F32Cx4 __m128
  886. #define GGML_F32Cx4_ZERO _mm_setzero_ps()
  887. #define GGML_F32Cx4_SET1(x) _mm_set1_ps(x)
  888. #define GGML_F32Cx4_LOAD(x) __sse_f16x4_load(x)
  889. #define GGML_F32Cx4_STORE(x, y) __sse_f16x4_store(x, y)
  890. #define GGML_F32Cx4_FMA GGML_F32x4_FMA
  891. #define GGML_F32Cx4_ADD _mm_add_ps
  892. #define GGML_F32Cx4_MUL _mm_mul_ps
  893. #define GGML_F32Cx4_REDUCE GGML_F32x4_REDUCE
  894. #define GGML_F16_VEC GGML_F32Cx4
  895. #define GGML_F16_VEC_ZERO GGML_F32Cx4_ZERO
  896. #define GGML_F16_VEC_SET1 GGML_F32Cx4_SET1
  897. #define GGML_F16_VEC_LOAD(p, i) GGML_F32Cx4_LOAD(p)
  898. #define GGML_F16_VEC_STORE(p, r, i) GGML_F32Cx4_STORE(p, r[i])
  899. #define GGML_F16_VEC_FMA GGML_F32Cx4_FMA
  900. #define GGML_F16_VEC_ADD GGML_F32Cx4_ADD
  901. #define GGML_F16_VEC_MUL GGML_F32Cx4_MUL
  902. #define GGML_F16_VEC_REDUCE GGML_F32Cx4_REDUCE
  903. #elif defined(__loongarch_asx)
  904. #define GGML_SIMD
  905. // F32 LASX
  906. #define GGML_F32_STEP 32
  907. #define GGML_F32_EPR 8
  908. #define GGML_F32x8 __m256
  909. #define GGML_F32x8_ZERO (__m256)__lasx_xvldi(0)
  910. #define GGML_F32x8_SET1(x) (__m256)__lasx_xvreplfr2vr_s((x))
  911. #define GGML_F32x8_LOAD(x) (__m256)__lasx_xvld((x), 0)
  912. #define GGML_F32x8_STORE(x,y) __lasx_xvst((y), (x), 0)
  913. #define GGML_F32x8_FMA(a, b, c) __lasx_xvfmadd_s(b, c, a)
  914. #define GGML_F32x8_ADD __lasx_xvfadd_s
  915. #define GGML_F32x8_MUL __lasx_xvfmul_s
  916. #define GGML_F32x8_REDUCE(res, x) \
  917. do { \
  918. int offset = GGML_F32_ARR >> 1; \
  919. for (int i = 0; i < offset; ++i) { \
  920. x[i] = __lasx_xvfadd_s(x[i], x[offset+i]); \
  921. } \
  922. offset >>= 1; \
  923. for (int i = 0; i < offset; ++i) { \
  924. x[i] = __lasx_xvfadd_s(x[i], x[offset+i]); \
  925. } \
  926. offset >>= 1; \
  927. for (int i = 0; i < offset; ++i) { \
  928. x[i] = __lasx_xvfadd_s(x[i], x[offset+i]); \
  929. } \
  930. float *tmp_p = (float *)&x[0]; \
  931. res = tmp_p[0] + tmp_p[1] + tmp_p[2] + tmp_p[3] + tmp_p[4] + tmp_p[5] + tmp_p[6] + tmp_p[7]; \
  932. } while (0)
  933. // TODO: is this optimal ?
  934. #define GGML_F32_VEC GGML_F32x8
  935. #define GGML_F32_VEC_ZERO GGML_F32x8_ZERO
  936. #define GGML_F32_VEC_SET1 GGML_F32x8_SET1
  937. #define GGML_F32_VEC_LOAD GGML_F32x8_LOAD
  938. #define GGML_F32_VEC_STORE GGML_F32x8_STORE
  939. #define GGML_F32_VEC_FMA GGML_F32x8_FMA
  940. #define GGML_F32_VEC_ADD GGML_F32x8_ADD
  941. #define GGML_F32_VEC_MUL GGML_F32x8_MUL
  942. #define GGML_F32_VEC_REDUCE GGML_F32x8_REDUCE
  943. // F16 LASX
  944. #define GGML_F16_STEP 32
  945. #define GGML_F16_EPR 8
  946. // F16 arithmetic is not supported by LASX, so we use F32 instead
  947. #define GGML_F32Cx8 __m256
  948. #define GGML_F32Cx8_ZERO (__m256)__lasx_xvldi(0)
  949. #define GGML_F32Cx8_SET1(x) (__m256)__lasx_xvreplgr2vr_w((x))
  950. static inline __m256 __lasx_f32cx8_load(const ggml_fp16_t * x) {
  951. float tmp[8];
  952. for (int i = 0; i < 8; i++) {
  953. tmp[i] = GGML_FP16_TO_FP32(x[i]);
  954. }
  955. return (__m256)__lasx_xvld(tmp, 0);
  956. }
  957. static inline void __lasx_f32cx8_store(ggml_fp16_t * x, __m256 y) {
  958. float arr[8];
  959. __lasx_xvst(y, arr, 0);
  960. for (int i = 0; i < 8; i++) {
  961. x[i] = GGML_FP32_TO_FP16(arr[i]);
  962. }
  963. }
  964. #define GGML_F32Cx8_LOAD(x) __lasx_f32cx8_load(x)
  965. #define GGML_F32Cx8_STORE(x, y) __lasx_f32cx8_store(x, y)
  966. #define GGML_F32Cx8_FMA GGML_F32x8_FMA
  967. #define GGML_F32Cx8_ADD __lasx_xvfadd_s
  968. #define GGML_F32Cx8_MUL __lasx_xvfmul_s
  969. #define GGML_F32Cx8_REDUCE GGML_F32x8_REDUCE
  970. #define GGML_F16_VEC GGML_F32Cx8
  971. #define GGML_F16_VEC_ZERO GGML_F32Cx8_ZERO
  972. #define GGML_F16_VEC_SET1 GGML_F32Cx8_SET1
  973. #define GGML_F16_VEC_LOAD(p, i) GGML_F32Cx8_LOAD(p)
  974. #define GGML_F16_VEC_STORE(p, r, i) GGML_F32Cx8_STORE(p, r[i])
  975. #define GGML_F16_VEC_FMA GGML_F32Cx8_FMA
  976. #define GGML_F16_VEC_ADD GGML_F32Cx8_ADD
  977. #define GGML_F16_VEC_MUL GGML_F32Cx8_MUL
  978. #define GGML_F16_VEC_REDUCE GGML_F32Cx8_REDUCE
  979. #elif defined(__loongarch_sx)
  980. #define GGML_SIMD
  981. // F32 LSX
  982. #define GGML_F32_STEP 32
  983. #define GGML_F32_EPR 4
  984. #define GGML_F32x4 __m128
  985. #define GGML_F32x4_ZERO __lsx_vldi(0)
  986. #define GGML_F32x4_SET1(x) __lsx_vinsgr2vr_w(__lsx_vldi(0),(x), 0)
  987. #define GGML_F32x4_LOAD(x) __lsx_vld((x), 0)
  988. #define GGML_F32x4_STORE(x, y) __lsx_vst((y), (x), 0)
  989. #define GGML_F32x4_FMA(a, b, c) __lsx_vfmadd_s(b, c, a)
  990. #define GGML_F32x4_ADD __lsx_vfadd_s
  991. #define GGML_F32x4_MUL __lsx_vfmul_s
  992. #define GGML_F32x4_REDUCE(res, x) \
  993. { \
  994. int offset = GGML_F32_ARR >> 1; \
  995. for (int i = 0; i < offset; ++i) { \
  996. x[i] = __lsx_vfadd_s(x[i], x[offset+i]); \
  997. } \
  998. offset >>= 1; \
  999. for (int i = 0; i < offset; ++i) { \
  1000. x[i] = __lsx_vfadd_s(x[i], x[offset+i]); \
  1001. } \
  1002. offset >>= 1; \
  1003. for (int i = 0; i < offset; ++i) { \
  1004. x[i] = __lsx_vfadd_s(x[i], x[offset+i]); \
  1005. } \
  1006. __m128i tmp = __lsx_vsrli_d((__m128i)x[0], 32); \
  1007. tmp = (__m128i)__lsx_vfadd_s((__m128)tmp, x[0]); \
  1008. tmp = __lsx_vpickev_w(__lsx_vldi(0), tmp); \
  1009. const __m128 t0 = __lsx_vshuf4i_w(tmp, 0x88); \
  1010. tmp = __lsx_vsrli_d((__m128i)t0, 32); \
  1011. tmp = (__m128i)__lsx_vfadd_s((__m128)tmp, t0); \
  1012. tmp = __lsx_vpickev_w(__lsx_vldi(0), tmp); \
  1013. res = (ggml_float) __lsx_vpickve2gr_w(__lsx_vshuf4i_w(tmp, 0x88), 0); \
  1014. }
  1015. #define GGML_F32_VEC GGML_F32x4
  1016. #define GGML_F32_VEC_ZERO GGML_F32x4_ZERO
  1017. #define GGML_F32_VEC_SET1 GGML_F32x4_SET1
  1018. #define GGML_F32_VEC_LOAD GGML_F32x4_LOAD
  1019. #define GGML_F32_VEC_STORE GGML_F32x4_STORE
  1020. #define GGML_F32_VEC_FMA GGML_F32x4_FMA
  1021. #define GGML_F32_VEC_ADD GGML_F32x4_ADD
  1022. #define GGML_F32_VEC_MUL GGML_F32x4_MUL
  1023. #define GGML_F32_VEC_REDUCE GGML_F32x4_REDUCE
  1024. // F16 LSX
  1025. #define GGML_F16_STEP 32
  1026. #define GGML_F16_EPR 4
  1027. static inline __m128 __lsx_f16x4_load(const ggml_fp16_t * x) {
  1028. float tmp[4];
  1029. tmp[0] = GGML_FP16_TO_FP32(x[0]);
  1030. tmp[1] = GGML_FP16_TO_FP32(x[1]);
  1031. tmp[2] = GGML_FP16_TO_FP32(x[2]);
  1032. tmp[3] = GGML_FP16_TO_FP32(x[3]);
  1033. return __lsx_vld(tmp, 0);
  1034. }
  1035. static inline void __lsx_f16x4_store(ggml_fp16_t * x, __m128 y) {
  1036. float arr[4];
  1037. __lsx_vst(y, arr, 0);
  1038. x[0] = GGML_FP32_TO_FP16(arr[0]);
  1039. x[1] = GGML_FP32_TO_FP16(arr[1]);
  1040. x[2] = GGML_FP32_TO_FP16(arr[2]);
  1041. x[3] = GGML_FP32_TO_FP16(arr[3]);
  1042. }
  1043. #define GGML_F32Cx4 __m128
  1044. #define GGML_F32Cx4_ZERO __lsx_vldi(0)
  1045. #define GGML_F32Cx4_SET1(x) __lsx_vinsgr2vr_w(__lsx_vldi(0),(x), 0)
  1046. #define GGML_F32Cx4_LOAD(x) __lsx_f16x4_load(x)
  1047. #define GGML_F32Cx4_STORE(x, y) __lsx_f16x4_store(x, y)
  1048. #define GGML_F32Cx4_FMA GGML_F32x4_FMA
  1049. #define GGML_F32Cx4_ADD __lsx_vfadd_s
  1050. #define GGML_F32Cx4_MUL __lsx_vfmul_s
  1051. #define GGML_F32Cx4_REDUCE GGML_F32x4_REDUCE
  1052. #define GGML_F16_VEC GGML_F32Cx4
  1053. #define GGML_F16_VEC_ZERO GGML_F32Cx4_ZERO
  1054. #define GGML_F16_VEC_SET1 GGML_F32Cx4_SET1
  1055. #define GGML_F16_VEC_LOAD(p, i) GGML_F32Cx4_LOAD(p)
  1056. #define GGML_F16_VEC_STORE(p, r, i) GGML_F32Cx4_STORE(p, r[i])
  1057. #define GGML_F16_VEC_FMA GGML_F32Cx4_FMA
  1058. #define GGML_F16_VEC_ADD GGML_F32Cx4_ADD
  1059. #define GGML_F16_VEC_MUL GGML_F32Cx4_MUL
  1060. #define GGML_F16_VEC_REDUCE GGML_F32Cx4_REDUCE
  1061. #endif
  1062. // GGML_F32_ARR / GGML_F16_ARR
  1063. // number of registers to use per step
  1064. #ifdef GGML_SIMD
  1065. #define GGML_F32_ARR (GGML_F32_STEP/GGML_F32_EPR)
  1066. #define GGML_F16_ARR (GGML_F16_STEP/GGML_F16_EPR)
  1067. #endif
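// Hedged sketch (illustrative, not part of the original file): the canonical
// loop shape the mappings above enable - process GGML_F32_STEP elements per
// iteration using GGML_F32_ARR registers of GGML_F32_EPR lanes each, then
// handle the remainder with scalar code. Shown here for an element-wise
// multiply; the vector helpers defined later in this file follow the same pattern.
#if defined(GGML_SIMD)
inline static void ggml_example_simd_mul_f32(const int n, float * z, const float * x, const float * y) {
    const int np = (n & ~(GGML_F32_STEP - 1)); // largest multiple of GGML_F32_STEP <= n
    for (int i = 0; i < np; i += GGML_F32_STEP) {
        for (int j = 0; j < GGML_F32_ARR; j++) {
            GGML_F32_VEC ax = GGML_F32_VEC_LOAD(x + i + j*GGML_F32_EPR);
            GGML_F32_VEC ay = GGML_F32_VEC_LOAD(y + i + j*GGML_F32_EPR);
            GGML_F32_VEC_STORE(z + i + j*GGML_F32_EPR, GGML_F32_VEC_MUL(ax, ay));
        }
    }
    // leftovers
    for (int i = np; i < n; ++i) {
        z[i] = x[i]*y[i];
    }
}
#endif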
  1068. //
  1069. // Threading defs
  1070. //
  1071. typedef pthread_t ggml_thread_t;
  1072. #if defined(_WIN32)
  1073. typedef CONDITION_VARIABLE ggml_cond_t;
  1074. typedef SRWLOCK ggml_mutex_t;
  1075. #define ggml_mutex_init(m) InitializeSRWLock(m)
  1076. #define ggml_mutex_destroy(m)
  1077. #define ggml_mutex_lock(m) AcquireSRWLockExclusive(m)
  1078. #define ggml_mutex_unlock(m) ReleaseSRWLockExclusive(m)
  1079. #define ggml_mutex_lock_shared(m) AcquireSRWLockShared(m)
  1080. #define ggml_mutex_unlock_shared(m) ReleaseSRWLockShared(m)
  1081. #define ggml_cond_init(c) InitializeConditionVariable(c)
  1082. #define ggml_cond_destroy(c)
  1083. #define ggml_cond_wait(c, m) SleepConditionVariableSRW(c, m, INFINITE, CONDITION_VARIABLE_LOCKMODE_SHARED)
  1084. #define ggml_cond_broadcast(c) WakeAllConditionVariable(c)
  1085. #define ggml_thread_create pthread_create
  1086. #define ggml_thread_join pthread_join
  1087. #else
  1088. typedef pthread_cond_t ggml_cond_t;
  1089. typedef pthread_mutex_t ggml_mutex_t;
  1090. #define ggml_mutex_init(m) pthread_mutex_init(m, NULL)
  1091. #define ggml_mutex_destroy(m) pthread_mutex_destroy(m)
  1092. #define ggml_mutex_lock(m) pthread_mutex_lock(m)
  1093. #define ggml_mutex_unlock(m) pthread_mutex_unlock(m)
  1094. #define ggml_mutex_lock_shared(m) pthread_mutex_lock(m)
  1095. #define ggml_mutex_unlock_shared(m) pthread_mutex_unlock(m)
  1096. #define ggml_lock_init(x) UNUSED(x)
  1097. #define ggml_lock_destroy(x) UNUSED(x)
  1098. #if defined(__x86_64__) || (defined(_MSC_VER) && defined(_M_AMD64))
  1099. #define ggml_lock_lock(x) _mm_pause()
  1100. #else
  1101. #define ggml_lock_lock(x) UNUSED(x)
  1102. #endif
  1103. #define ggml_lock_unlock(x) UNUSED(x)
  1104. #define GGML_LOCK_INITIALIZER 0
  1105. #define ggml_cond_init(c) pthread_cond_init(c, NULL)
  1106. #define ggml_cond_destroy(c) pthread_cond_destroy(c)
  1107. #define ggml_cond_wait(c, m) pthread_cond_wait(c, m)
  1108. #define ggml_cond_broadcast(c) pthread_cond_broadcast(c)
  1109. #define ggml_thread_create pthread_create
  1110. #define ggml_thread_join pthread_join
  1111. #endif
  1112. // Threadpool def
  1113. struct ggml_threadpool {
  1114. ggml_mutex_t mutex; // mutex for cond.var
  1115. ggml_cond_t cond; // cond.var for waiting for new work
  1116. struct ggml_cgraph * cgraph;
  1117. struct ggml_cplan * cplan;
  1118. // synchronization primitives
  1119. atomic_int n_graph; // incremented when there is work to be done (i.e. each graph)
  1120. atomic_int GGML_CACHE_ALIGN n_barrier;
  1121. atomic_int GGML_CACHE_ALIGN n_barrier_passed;
  1122. atomic_int current_chunk; // currently processing chunk during Mat_Mul, shared between all the threads.
  1123. // these are atomic as an annotation for thread-sanitizer
  1124. atomic_bool stop; // Used for stopping the threadpool altogether
  1125. atomic_bool pause; // Used for pausing the threadpool or individual threads
  1126. atomic_bool abort; // Used for aborting processing of a graph
  1127. struct ggml_compute_state * workers; // per thread state
  1128. int n_threads_max; // number of threads in the pool
  1129. atomic_int n_threads_cur; // number of threads used in the current graph
  1130. int32_t prio; // Scheduling priority
  1131. uint32_t poll; // Polling level (0 - no polling)
  1132. enum ggml_status ec;
  1133. };
  1134. // Per-thread state
  1135. struct ggml_compute_state {
  1136. #ifndef GGML_USE_OPENMP
  1137. ggml_thread_t thrd;
  1138. bool cpumask[GGML_MAX_N_THREADS];
  1139. int last_graph;
  1140. bool pending;
  1141. #endif
  1142. struct ggml_threadpool * threadpool;
  1143. int ith;
  1144. };
  1145. struct ggml_compute_params {
  1146. // ith = thread index, nth = number of threads
  1147. int ith, nth;
  1148. // work buffer for all threads
  1149. size_t wsize;
  1150. void * wdata;
  1151. struct ggml_threadpool * threadpool;
  1152. };
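// Hedged sketch (illustrative, not part of the original file): how ith/nth are
// typically used - each of the nth threads takes a contiguous slice of the nr
// rows, so the work splits without any further synchronization.
static inline void ggml_example_row_range(const struct ggml_compute_params * params, int64_t nr, int64_t * ir0, int64_t * ir1) {
    const int64_t dr = (nr + params->nth - 1)/params->nth; // rows per thread, rounded up
    *ir0 = dr*params->ith;                                 // first row owned by this thread
    *ir1 = *ir0 + dr < nr ? *ir0 + dr : nr;                // one past the last row
}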
  1153. //
  1154. // fundamental operations
  1155. //
  1156. inline static void ggml_vec_set_i8(const int n, int8_t * x, const int8_t v) { for (int i = 0; i < n; ++i) x[i] = v; }
  1157. inline static void ggml_vec_set_i16(const int n, int16_t * x, const int16_t v) { for (int i = 0; i < n; ++i) x[i] = v; }
  1158. inline static void ggml_vec_set_i32(const int n, int32_t * x, const int32_t v) { for (int i = 0; i < n; ++i) x[i] = v; }
  1159. inline static void ggml_vec_set_f16(const int n, ggml_fp16_t * x, const int32_t v) { for (int i = 0; i < n; ++i) x[i] = v; }
  1160. inline static void ggml_vec_set_bf16(const int n, ggml_bf16_t * x, const ggml_bf16_t v) { for (int i = 0; i < n; ++i) x[i] = v; }
  1161. inline static void ggml_vec_add_f32 (const int n, float * z, const float * x, const float * y) { for (int i = 0; i < n; ++i) z[i] = x[i] + y[i]; }
  1162. inline static void ggml_vec_add1_f32(const int n, float * z, const float * x, const float v) { for (int i = 0; i < n; ++i) z[i] = x[i] + v; }
  1163. inline static void ggml_vec_acc_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] += x[i]; }
  1164. inline static void ggml_vec_acc1_f32(const int n, float * y, const float v) { for (int i = 0; i < n; ++i) y[i] += v; }
  1165. inline static void ggml_vec_sub_f32 (const int n, float * z, const float * x, const float * y) { for (int i = 0; i < n; ++i) z[i] = x[i] - y[i]; }
  1166. inline static void ggml_vec_set_f32 (const int n, float * x, const float v) { for (int i = 0; i < n; ++i) x[i] = v; }
  1167. inline static void ggml_vec_cpy_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = x[i]; }
  1168. inline static void ggml_vec_neg_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = -x[i]; }
  1169. inline static void ggml_vec_mul_f32 (const int n, float * z, const float * x, const float * y) { for (int i = 0; i < n; ++i) z[i] = x[i]*y[i]; }
  1170. inline static void ggml_vec_div_f32 (const int n, float * z, const float * x, const float * y) { for (int i = 0; i < n; ++i) z[i] = x[i]/y[i]; }
  1171. static void ggml_vec_dot_f32(int n, float * restrict s, size_t bs, const float * restrict x, size_t bx, const float * restrict y, size_t by, int nrc) {
  1172. assert(nrc == 1);
  1173. UNUSED(nrc);
  1174. UNUSED(bx);
  1175. UNUSED(by);
  1176. UNUSED(bs);
  1177. #if defined(GGML_SIMD)
  1178. float sumf = 0.0f;
  1179. const int np = (n & ~(GGML_F32_STEP - 1));
  1180. GGML_F32_VEC sum[GGML_F32_ARR] = { GGML_F32_VEC_ZERO };
  1181. GGML_F32_VEC ax[GGML_F32_ARR];
  1182. GGML_F32_VEC ay[GGML_F32_ARR];
  1183. for (int i = 0; i < np; i += GGML_F32_STEP) {
  1184. for (int j = 0; j < GGML_F32_ARR; j++) {
  1185. ax[j] = GGML_F32_VEC_LOAD(x + i + j*GGML_F32_EPR);
  1186. ay[j] = GGML_F32_VEC_LOAD(y + i + j*GGML_F32_EPR);
  1187. sum[j] = GGML_F32_VEC_FMA(sum[j], ax[j], ay[j]);
  1188. }
  1189. }
  1190. // reduce sum0..sum3 to sum0
  1191. GGML_F32_VEC_REDUCE(sumf, sum);
  1192. // leftovers
  1193. for (int i = np; i < n; ++i) {
  1194. sumf += x[i]*y[i];
  1195. }
  1196. #else
  1197. // scalar
  1198. ggml_float sumf = 0.0;
  1199. for (int i = 0; i < n; ++i) {
  1200. sumf += (ggml_float)(x[i]*y[i]);
  1201. }
  1202. #endif
  1203. *s = sumf;
  1204. }
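// Hedged usage sketch (not part of the original file): for F32 the stride and
// nrc arguments are unused, so a plain dot product looks like this.
static inline float ggml_example_dot_f32(void) {
    const float a[4] = { 1.0f, 2.0f, 3.0f, 4.0f };
    const float b[4] = { 5.0f, 6.0f, 7.0f, 8.0f };
    float s = 0.0f;
    ggml_vec_dot_f32(4, &s, 0, a, 0, b, 0, 1); // s == 70.0f
    return s;
}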
  1205. static void ggml_vec_dot_bf16(int n, float * restrict s, size_t bs, ggml_bf16_t * restrict x, size_t bx, ggml_bf16_t * restrict y, size_t by, int nrc) {
  1206. assert(nrc == 1);
  1207. UNUSED(nrc);
  1208. UNUSED(bx);
  1209. UNUSED(by);
  1210. UNUSED(bs);
  1211. int i = 0;
  1212. ggml_float sumf = 0;
  1213. #if defined(__AVX512BF16__)
  1214. __m512 c1 = _mm512_setzero_ps();
  1215. __m512 c2 = _mm512_setzero_ps();
  1216. for (; i + 64 <= n; i += 64) {
  1217. c1 = _mm512_dpbf16_ps(c1, m512bh(_mm512_loadu_si512((x + i))),
  1218. m512bh(_mm512_loadu_si512((y + i))));
  1219. c2 = _mm512_dpbf16_ps(c2, m512bh(_mm512_loadu_si512((x + i + 32))),
  1220. m512bh(_mm512_loadu_si512((y + i + 32))));
  1221. }
  1222. sumf += (ggml_float)_mm512_reduce_add_ps(c1);
  1223. sumf += (ggml_float)_mm512_reduce_add_ps(c2);
  1224. #elif defined(__AVX512F__)
  1225. #define LOAD(p) _mm512_castsi512_ps(_mm512_slli_epi32(_mm512_cvtepu16_epi32(_mm256_loadu_si256((const __m256i *)(p))), 16))
  1226. __m512 c1 = _mm512_setzero_ps();
  1227. __m512 c2 = _mm512_setzero_ps();
  1228. for (; i + 32 <= n; i += 32) {
  1229. c1 = _mm512_add_ps(_mm512_mul_ps(LOAD(x + i), LOAD(y + i)), c1);
  1230. c2 = _mm512_add_ps(_mm512_mul_ps(LOAD(x + i + 16), LOAD(y + i + 16)), c2);
  1231. }
  1232. sumf += (ggml_float)_mm512_reduce_add_ps(c1);
  1233. sumf += (ggml_float)_mm512_reduce_add_ps(c2);
  1234. #undef LOAD
  1235. #elif defined(__AVX2__)
  1236. #define LOAD(p) _mm256_castsi256_ps(_mm256_slli_epi32(_mm256_cvtepu16_epi32(_mm_loadu_si128((const __m128i *)(p))), 16))
  1237. __m256 c1 = _mm256_setzero_ps();
  1238. __m256 c2 = _mm256_setzero_ps();
  1239. __m256 c3 = _mm256_setzero_ps();
  1240. __m256 c4 = _mm256_setzero_ps();
  1241. for (; i + 32 <= n; i += 32) {
  1242. c1 = _mm256_add_ps(_mm256_mul_ps(LOAD(x + i), LOAD(y + i)), c1);
  1243. c2 = _mm256_add_ps(_mm256_mul_ps(LOAD(x + i + 8), LOAD(y + i + 8)), c2);
  1244. c3 = _mm256_add_ps(_mm256_mul_ps(LOAD(x + i + 16), LOAD(y + i + 16)), c3);
  1245. c4 = _mm256_add_ps(_mm256_mul_ps(LOAD(x + i + 24), LOAD(y + i + 24)), c4);
  1246. }
  1247. __m128 g;
  1248. c1 = _mm256_add_ps(_mm256_add_ps(c1, c3),
  1249. _mm256_add_ps(c2, c4));
  1250. g = _mm_add_ps(_mm256_extractf128_ps(c1, 1),
  1251. _mm256_castps256_ps128(c1));
  1252. g = _mm_add_ps(g, _mm_movehl_ps(g, g));
  1253. g = _mm_add_ss(g, _mm_movehdup_ps(g));
  1254. sumf += (ggml_float)_mm_cvtss_f32(g);
  1255. #undef LOAD
  1256. #endif
  1257. for (; i < n; ++i) {
  1258. sumf += (ggml_float)(GGML_BF16_TO_FP32(x[i]) *
  1259. GGML_BF16_TO_FP32(y[i]));
  1260. }
  1261. *s = sumf;
  1262. }
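// Hedged illustration (not part of the original file): the AVX paths above
// widen bf16 to f32 by shifting the stored 16 bits into the upper half of a
// 32-bit word; the same conversion in scalar form, taking the raw bits:
static inline float ggml_example_bf16_bits_to_f32(uint16_t bits) {
    union { uint32_t u; float f; } tmp;
    tmp.u = (uint32_t) bits << 16; // bf16 keeps the top 16 bits of an f32
    return tmp.f;
}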
  1263. static void ggml_vec_dot_f16(int n, float * restrict s, size_t bs, ggml_fp16_t * restrict x, size_t bx, ggml_fp16_t * restrict y, size_t by, int nrc) {
  1264. assert(nrc == 1);
  1265. UNUSED(nrc);
  1266. UNUSED(bx);
  1267. UNUSED(by);
  1268. UNUSED(bs);
  1269. ggml_float sumf = 0.0;
  1270. #if defined(GGML_SIMD)
  1271. const int np = (n & ~(GGML_F16_STEP - 1));
  1272. GGML_F16_VEC sum[GGML_F16_ARR] = { GGML_F16_VEC_ZERO };
  1273. GGML_F16_VEC ax[GGML_F16_ARR];
  1274. GGML_F16_VEC ay[GGML_F16_ARR];
  1275. for (int i = 0; i < np; i += GGML_F16_STEP) {
  1276. for (int j = 0; j < GGML_F16_ARR; j++) {
  1277. ax[j] = GGML_F16_VEC_LOAD(x + i + j*GGML_F16_EPR, j);
  1278. ay[j] = GGML_F16_VEC_LOAD(y + i + j*GGML_F16_EPR, j);
  1279. sum[j] = GGML_F16_VEC_FMA(sum[j], ax[j], ay[j]);
  1280. }
  1281. }
  1282. // reduce sum0..sum3 to sum0
  1283. GGML_F16_VEC_REDUCE(sumf, sum);
  1284. // leftovers
  1285. for (int i = np; i < n; ++i) {
  1286. sumf += (ggml_float)(GGML_FP16_TO_FP32(x[i])*GGML_FP16_TO_FP32(y[i]));
  1287. }
  1288. #else
  1289. for (int i = 0; i < n; ++i) {
  1290. sumf += (ggml_float)(GGML_FP16_TO_FP32(x[i])*GGML_FP16_TO_FP32(y[i]));
  1291. }
  1292. #endif
  1293. *s = sumf;
  1294. }
  1295. // compute GGML_VEC_DOT_UNROLL dot products at once
  1296. // xs - x row stride in bytes
  1297. inline static void ggml_vec_dot_f16_unroll(const int n, const int xs, float * restrict s, void * restrict xv, ggml_fp16_t * restrict y) {
  1298. ggml_float sumf[GGML_VEC_DOT_UNROLL] = { 0.0 };
  1299. ggml_fp16_t * restrict x[GGML_VEC_DOT_UNROLL];
  1300. for (int i = 0; i < GGML_VEC_DOT_UNROLL; ++i) {
  1301. x[i] = (ggml_fp16_t *) ((char *) xv + i*xs);
  1302. }
  1303. #if defined(GGML_SIMD)
  1304. const int np = (n & ~(GGML_F16_STEP - 1));
  1305. GGML_F16_VEC sum[GGML_VEC_DOT_UNROLL][GGML_F16_ARR] = { { GGML_F16_VEC_ZERO } };
  1306. GGML_F16_VEC ax[GGML_F16_ARR];
  1307. GGML_F16_VEC ay[GGML_F16_ARR];
  1308. for (int i = 0; i < np; i += GGML_F16_STEP) {
  1309. for (int j = 0; j < GGML_F16_ARR; j++) {
  1310. ay[j] = GGML_F16_VEC_LOAD(y + i + j*GGML_F16_EPR, j);
  1311. for (int k = 0; k < GGML_VEC_DOT_UNROLL; ++k) {
  1312. ax[j] = GGML_F16_VEC_LOAD(x[k] + i + j*GGML_F16_EPR, j);
  1313. sum[k][j] = GGML_F16_VEC_FMA(sum[k][j], ax[j], ay[j]);
  1314. }
  1315. }
  1316. }
  1317. // reduce sum0..sum3 to sum0
  1318. for (int k = 0; k < GGML_VEC_DOT_UNROLL; ++k) {
  1319. GGML_F16_VEC_REDUCE(sumf[k], sum[k]);
  1320. }
  1321. // leftovers
  1322. for (int i = np; i < n; ++i) {
  1323. for (int j = 0; j < GGML_VEC_DOT_UNROLL; ++j) {
  1324. sumf[j] += (ggml_float)(GGML_FP16_TO_FP32(x[j][i])*GGML_FP16_TO_FP32(y[i]));
  1325. }
  1326. }
  1327. #else
  1328. for (int i = 0; i < n; ++i) {
  1329. for (int j = 0; j < GGML_VEC_DOT_UNROLL; ++j) {
  1330. sumf[j] += (ggml_float)(GGML_FP16_TO_FP32(x[j][i])*GGML_FP16_TO_FP32(y[i]));
  1331. }
  1332. }
  1333. #endif
  1334. for (int i = 0; i < GGML_VEC_DOT_UNROLL; ++i) {
  1335. s[i] = sumf[i];
  1336. }
  1337. }
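// Hedged usage sketch (not part of the original file): with
// GGML_VEC_DOT_UNROLL == 2, one call produces two dot products against the
// same y; xs is the byte stride between consecutive x rows (packed back to
// back here).
static inline void ggml_example_dot_f16_two_rows(const int n, ggml_fp16_t * x2rows, ggml_fp16_t * y, float s[2]) {
    ggml_vec_dot_f16_unroll(n, (int)(n*sizeof(ggml_fp16_t)), s, x2rows, y);
}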
  1338. inline static void ggml_vec_mad_f32(const int n, float * restrict y, const float * restrict x, const float v) {
  1339. #if defined(GGML_SIMD)
  1340. const int np = (n & ~(GGML_F32_STEP - 1));
  1341. GGML_F32_VEC vx = GGML_F32_VEC_SET1(v);
  1342. GGML_F32_VEC ax[GGML_F32_ARR];
  1343. GGML_F32_VEC ay[GGML_F32_ARR];
  1344. for (int i = 0; i < np; i += GGML_F32_STEP) {
  1345. for (int j = 0; j < GGML_F32_ARR; j++) {
  1346. ax[j] = GGML_F32_VEC_LOAD(x + i + j*GGML_F32_EPR);
  1347. ay[j] = GGML_F32_VEC_LOAD(y + i + j*GGML_F32_EPR);
  1348. ay[j] = GGML_F32_VEC_FMA(ay[j], ax[j], vx);
  1349. GGML_F32_VEC_STORE(y + i + j*GGML_F32_EPR, ay[j]);
  1350. }
  1351. }
  1352. // leftovers
  1353. for (int i = np; i < n; ++i) {
  1354. y[i] += x[i]*v;
  1355. }
  1356. #else
  1357. // scalar
  1358. for (int i = 0; i < n; ++i) {
  1359. y[i] += x[i]*v;
  1360. }
  1361. #endif
  1362. }
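// Hedged usage sketch (not part of the original file): ggml_vec_mad_f32 is an
// axpy-style update, y[i] += x[i]*v.
static inline void ggml_example_mad_f32(void) {
    float y[3] = { 1.0f, 1.0f, 1.0f };
    const float x[3] = { 1.0f, 2.0f, 3.0f };
    ggml_vec_mad_f32(3, y, x, 2.0f); // y becomes { 3.0f, 5.0f, 7.0f }
    UNUSED(y);
}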
  1363. inline static void ggml_vec_mad_f16(const int n, ggml_fp16_t * restrict y, const ggml_fp16_t * restrict x, const float v) {
  1364. #if defined(GGML_SIMD)
  1365. const int np = (n & ~(GGML_F16_STEP - 1));
  1366. GGML_F16_VEC vx = GGML_F16_VEC_SET1(v);
  1367. GGML_F16_VEC ax[GGML_F16_ARR];
  1368. GGML_F16_VEC ay[GGML_F16_ARR];
  1369. for (int i = 0; i < np; i += GGML_F16_STEP) {
  1370. for (int j = 0; j < GGML_F16_ARR; j++) {
  1371. ax[j] = GGML_F16_VEC_LOAD(x + i + j*GGML_F16_EPR, j);
  1372. ay[j] = GGML_F16_VEC_LOAD(y + i + j*GGML_F16_EPR, j);
  1373. ay[j] = GGML_F16_VEC_FMA(ay[j], ax[j], vx);
  1374. GGML_F16_VEC_STORE(y + i + j*GGML_F16_EPR, ay, j);
  1375. }
  1376. }
  1377. // leftovers
  1378. for (int i = np; i < n; ++i) {
  1379. y[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(y[i]) + GGML_FP16_TO_FP32(x[i])*v);
  1380. }
  1381. #else
  1382. // scalar
  1383. for (int i = 0; i < n; ++i) {
  1384. y[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(y[i]) + GGML_FP16_TO_FP32(x[i])*v);
  1385. }
  1386. #endif
  1387. }
  1388. // xs and vs are byte strides of x and v
  1389. inline static void ggml_vec_mad_f32_unroll(const int n, const int xs, const int vs, float * restrict y, const float * restrict xv, const float * restrict vv) {
  1390. const float * restrict x[GGML_VEC_MAD_UNROLL];
  1391. const float * restrict v[GGML_VEC_MAD_UNROLL];
  1392. for (int i = 0; i < GGML_VEC_MAD_UNROLL; ++i) {
  1393. x[i] = (const float *) ((const char *) xv + i*xs);
  1394. v[i] = (const float *) ((const char *) vv + i*vs);
  1395. }
  1396. #if defined(GGML_SIMD)
  1397. const int np = (n & ~(GGML_F32_STEP - 1));
  1398. GGML_F32_VEC vx[GGML_VEC_MAD_UNROLL];
  1399. for (int k = 0; k < GGML_VEC_MAD_UNROLL; ++k) {
  1400. vx[k] = GGML_F32_VEC_SET1(v[k][0]);
  1401. }
  1402. GGML_F32_VEC ax[GGML_VEC_MAD_UNROLL][GGML_F32_ARR];
  1403. GGML_F32_VEC ay[GGML_F32_ARR];
  1404. for (int i = 0; i < np; i += GGML_F32_STEP) {
  1405. for (int j = 0; j < GGML_F32_ARR; j++) {
  1406. ay[j] = GGML_F32_VEC_LOAD(y + i + j*GGML_F32_EPR);
  1407. for (int k = 0; k < GGML_VEC_MAD_UNROLL; ++k) {
  1408. ax[k][j] = GGML_F32_VEC_LOAD(x[k] + i + j*GGML_F32_EPR);
  1409. ay[j] = GGML_F32_VEC_FMA(ay[j], ax[k][j], vx[k]);
  1410. }
  1411. GGML_F32_VEC_STORE(y + i + j*GGML_F32_EPR, ay[j]);
  1412. }
  1413. }
  1414. // leftovers
  1415. for (int k = 0; k < GGML_VEC_MAD_UNROLL; ++k) {
  1416. for (int i = np; i < n; ++i) {
  1417. y[i] += x[k][i]*v[k][0];
  1418. }
  1419. }
  1420. #else
  1421. // scalar
  1422. for (int k = 0; k < GGML_VEC_MAD_UNROLL; ++k) {
  1423. for (int i = 0; i < n; ++i) {
  1424. y[i] += x[k][i]*v[k][0];
  1425. }
  1426. }
  1427. #endif
  1428. }
  1429. //inline static void ggml_vec_scale_f32(const int n, float * y, const float v) { for (int i = 0; i < n; ++i) y[i] *= v; }
  1430. inline static void ggml_vec_scale_f32(const int n, float * y, const float v) {
  1431. #if defined(GGML_USE_ACCELERATE)
  1432. vDSP_vsmul(y, 1, &v, y, 1, n);
  1433. #elif defined(GGML_SIMD)
  1434. const int np = (n & ~(GGML_F32_STEP - 1));
  1435. GGML_F32_VEC vx = GGML_F32_VEC_SET1(v);
  1436. GGML_F32_VEC ay[GGML_F32_ARR];
  1437. for (int i = 0; i < np; i += GGML_F32_STEP) {
  1438. for (int j = 0; j < GGML_F32_ARR; j++) {
  1439. ay[j] = GGML_F32_VEC_LOAD(y + i + j*GGML_F32_EPR);
  1440. ay[j] = GGML_F32_VEC_MUL(ay[j], vx);
  1441. GGML_F32_VEC_STORE(y + i + j*GGML_F32_EPR, ay[j]);
  1442. }
  1443. }
  1444. // leftovers
  1445. for (int i = np; i < n; ++i) {
  1446. y[i] *= v;
  1447. }
  1448. #else
  1449. // scalar
  1450. for (int i = 0; i < n; ++i) {
  1451. y[i] *= v;
  1452. }
  1453. #endif
  1454. }
  1455. inline static void ggml_vec_scale_f16(const int n, ggml_fp16_t * y, const float v) {
  1456. #if defined(GGML_SIMD)
  1457. const int np = (n & ~(GGML_F16_STEP - 1));
  1458. GGML_F16_VEC vx = GGML_F16_VEC_SET1(v);
  1459. GGML_F16_VEC ay[GGML_F16_ARR];
  1460. for (int i = 0; i < np; i += GGML_F16_STEP) {
  1461. for (int j = 0; j < GGML_F16_ARR; j++) {
  1462. ay[j] = GGML_F16_VEC_LOAD(y + i + j*GGML_F16_EPR, j);
  1463. ay[j] = GGML_F16_VEC_MUL(ay[j], vx);
  1464. GGML_F16_VEC_STORE(y + i + j*GGML_F16_EPR, ay, j);
  1465. }
  1466. }
  1467. // leftovers
  1468. for (int i = np; i < n; ++i) {
  1469. y[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(y[i])*v);
  1470. }
  1471. #else
  1472. // scalar
  1473. for (int i = 0; i < n; ++i) {
  1474. y[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(y[i])*v);
  1475. }
  1476. #endif
  1477. }
  1478. inline static void ggml_vec_norm_f32 (const int n, float * s, const float * x) { ggml_vec_dot_f32(n, s, 0, x, 0, x, 0, 1); *s = sqrtf(*s); }
  1479. inline static void ggml_vec_sqr_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = x[i]*x[i]; }
  1480. inline static void ggml_vec_sqrt_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = sqrtf(x[i]); }
  1481. inline static void ggml_vec_log_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = logf(x[i]); }
  1482. inline static void ggml_vec_sin_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = sinf(x[i]); }
  1483. inline static void ggml_vec_cos_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = cosf(x[i]); }
  1484. inline static void ggml_vec_abs_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = fabsf(x[i]); }
  1485. inline static void ggml_vec_sgn_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = (x[i] > 0.f) ? 1.f : ((x[i] < 0.f) ? -1.f : 0.f); }
  1486. inline static void ggml_vec_step_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = (x[i] > 0.f) ? 1.f : 0.f; }
  1487. inline static void ggml_vec_tanh_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = tanhf(x[i]); }
  1488. inline static void ggml_vec_elu_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = (x[i] > 0.f) ? x[i] : expm1f(x[i]); }
  1489. inline static void ggml_vec_relu_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = (x[i] > 0.f) ? x[i] : 0.f; }
  1490. inline static void ggml_vec_leaky_relu_f32 (const int n, float * y, const float * x, const float ns) { for (int i = 0; i < n; ++i) y[i] = ((x[i] > 0.f) ? x[i] : 0.f) + ns * ((x[i] < 0.0f) ? x[i] : 0.f); }
  1491. inline static void ggml_vec_sigmoid_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = 1.f / (1.f + expf(-x[i])); }
  1492. // TODO: optimize performance
  1493. inline static void ggml_vec_hardswish_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = x[i] * fminf(1.0f, fmaxf(0.0f, (x[i] + 3.0f) / 6.0f)); }
  1494. inline static void ggml_vec_hardsigmoid_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = fminf(1.0f, fmaxf(0.0f, (x[i] + 3.0f) / 6.0f)); }
  1495. inline static void ggml_vec_exp_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = expf(x[i]); }
  1496. static const float GELU_COEF_A = 0.044715f;
  1497. static const float GELU_QUICK_COEF = -1.702f;
  1498. static const float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
  1499. inline static float ggml_gelu_f32(float x) {
  1500. return 0.5f*x*(1.0f + tanhf(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x)));
  1501. }
  1502. inline static void ggml_vec_gelu_f16(const int n, ggml_fp16_t * y, const ggml_fp16_t * x) {
  1503. const uint16_t * i16 = (const uint16_t *) x;
  1504. for (int i = 0; i < n; ++i) {
  1505. y[i] = ggml_table_gelu_f16[i16[i]];
  1506. }
  1507. }
  1508. #ifdef GGML_GELU_FP16
  1509. inline static void ggml_vec_gelu_f32(const int n, float * y, const float * x) {
  1510. uint16_t t;
  1511. for (int i = 0; i < n; ++i) {
  1512. if (x[i] <= -10.0f) {
  1513. y[i] = 0.0f;
  1514. } else if (x[i] >= 10.0f) {
  1515. y[i] = x[i];
  1516. } else {
  1517. ggml_fp16_t fp16 = GGML_FP32_TO_FP16(x[i]);
  1518. memcpy(&t, &fp16, sizeof(uint16_t));
  1519. y[i] = GGML_FP16_TO_FP32(ggml_table_gelu_f16[t]);
  1520. }
  1521. }
  1522. }
  1523. #else
  1524. inline static void ggml_vec_gelu_f32(const int n, float * y, const float * x) {
  1525. for (int i = 0; i < n; ++i) {
  1526. y[i] = ggml_gelu_f32(x[i]);
  1527. }
  1528. }
  1529. #endif
  1530. inline static float ggml_gelu_quick_f32(float x) {
  1531. return x*(1.0f/(1.0f+expf(GELU_QUICK_COEF*x)));
  1532. }
  1533. //inline static void ggml_vec_gelu_quick_f16(const int n, ggml_fp16_t * y, const ggml_fp16_t * x) {
  1534. // const uint16_t * i16 = (const uint16_t *) x;
  1535. // for (int i = 0; i < n; ++i) {
  1536. // y[i] = ggml_table_gelu_quick_f16[i16[i]];
  1537. // }
  1538. //}
  1539. #ifdef GGML_GELU_QUICK_FP16
  1540. inline static void ggml_vec_gelu_quick_f32(const int n, float * y, const float * x) {
  1541. uint16_t t;
  1542. for (int i = 0; i < n; ++i) {
  1543. ggml_fp16_t fp16 = GGML_FP32_TO_FP16(x[i]);
  1544. memcpy(&t, &fp16, sizeof(uint16_t));
  1545. y[i] = GGML_FP16_TO_FP32(ggml_table_gelu_quick_f16[t]);
  1546. }
  1547. }
  1548. #else
  1549. inline static void ggml_vec_gelu_quick_f32(const int n, float * y, const float * x) {
  1550. for (int i = 0; i < n; ++i) {
  1551. y[i] = ggml_gelu_quick_f32(x[i]);
  1552. }
  1553. }
  1554. #endif
  1555. // Sigmoid Linear Unit (SiLU) function
  1556. inline static float ggml_silu_f32(float x) {
  1557. return x/(1.0f + expf(-x));
  1558. }
  1559. #if __FINITE_MATH_ONLY__
  1560. #error "some routines in ggml.c require non-finite math arithmetics -- pass -fno-finite-math-only to the compiler to fix"
  1561. #error "ref: https://github.com/ggerganov/llama.cpp/pull/7154#issuecomment-2143844461"
  1562. #endif
  1563. #if defined(__ARM_NEON) && defined(__aarch64__)
  1564. // adapted from arm limited optimized routine
  1565. // the maximum error is 1.45358 plus 0.5 ulps
  1566. // numbers above 88.38 will flush to infinity
  1567. // numbers beneath -103.97 will flush to zero
  1568. inline static float32x4_t ggml_v_expf(float32x4_t x) {
  1569. const float32x4_t r = vdupq_n_f32(0x1.8p23f);
  1570. const float32x4_t z = vfmaq_f32(r, x, vdupq_n_f32(0x1.715476p+0f));
  1571. const float32x4_t n = vsubq_f32(z, r);
  1572. const float32x4_t b = vfmsq_f32(vfmsq_f32(x, n, vdupq_n_f32(0x1.62e4p-1f)), n,
  1573. vdupq_n_f32(0x1.7f7d1cp-20f));
  1574. const uint32x4_t e = vshlq_n_u32(vreinterpretq_u32_f32(z), 23);
  1575. const float32x4_t k = vreinterpretq_f32_u32(vaddq_u32(e, vreinterpretq_u32_f32(vdupq_n_f32(1))));
  1576. const uint32x4_t c = vcagtq_f32(n, vdupq_n_f32(126));
  1577. const float32x4_t u = vmulq_f32(b, b);
  1578. const float32x4_t j = vfmaq_f32(
  1579. vmulq_f32(vdupq_n_f32(0x1.ffffecp-1f), b),
  1580. vfmaq_f32(vfmaq_f32(vdupq_n_f32(0x1.fffdb6p-2f), vdupq_n_f32(0x1.555e66p-3f), b),
  1581. vfmaq_f32(vdupq_n_f32(0x1.573e2ep-5f), vdupq_n_f32(0x1.0e4020p-7f), b), u), u);
  1582. if (!vpaddd_u64(vreinterpretq_u64_u32(c)))
  1583. return vfmaq_f32(k, j, k);
  1584. const uint32x4_t d = vandq_u32(vclezq_f32(n), vdupq_n_u32(0x82000000));
  1585. const float32x4_t s1 = vreinterpretq_f32_u32(vaddq_u32(d, vdupq_n_u32(0x7f000000)));
  1586. const float32x4_t s2 = vreinterpretq_f32_u32(vsubq_u32(e, d));
  1587. return vbslq_f32(vcagtq_f32(n, vdupq_n_f32(192)), vmulq_f32(s1, s1),
  1588. vbslq_f32(c, vmulq_f32(vfmaq_f32(s2, s2, j), s1), vfmaq_f32(k, k, j)));
  1589. }
  1590. // computes silu x/(1+exp(-x)) in single precision vector
  1591. inline static float32x4_t ggml_v_silu(float32x4_t x) {
  1592. const float32x4_t one = vdupq_n_f32(1.0f);
  1593. const float32x4_t zero = vdupq_n_f32(0.0f);
  1594. const float32x4_t neg_x = vsubq_f32(zero, x);
  1595. const float32x4_t exp_neg_x = ggml_v_expf(neg_x);
  1596. const float32x4_t one_plus_exp_neg_x = vaddq_f32(one, exp_neg_x);
  1597. return vdivq_f32(x, one_plus_exp_neg_x);
  1598. }
  1599. #elif defined(__AVX512F__) && defined(__AVX512DQ__)
  1600. // adapted from arm limited optimized routine
  1601. // the maximum error is 1.45358 plus 0.5 ulps
  1602. // numbers above 88.38 will flush to infinity
  1603. // numbers beneath -103.97 will flush to zero
  1604. inline static __m512 ggml_v_expf(__m512 x) {
  1605. const __m512 r = _mm512_set1_ps(0x1.8p23f);
  1606. const __m512 z = _mm512_fmadd_ps(x, _mm512_set1_ps(0x1.715476p+0f), r);
  1607. const __m512 n = _mm512_sub_ps(z, r);
  1608. const __m512 b =
  1609. _mm512_fnmadd_ps(n, _mm512_set1_ps(0x1.7f7d1cp-20f),
  1610. _mm512_fnmadd_ps(n, _mm512_set1_ps(0x1.62e4p-1f), x));
  1611. const __mmask16 d =
  1612. _mm512_cmp_ps_mask(_mm512_abs_ps(n), _mm512_set1_ps(192), _CMP_GT_OQ);
  1613. const __m512 u = _mm512_mul_ps(b, b);
  1614. const __m512 j = _mm512_fmadd_ps(
  1615. _mm512_fmadd_ps(_mm512_fmadd_ps(_mm512_set1_ps(0x1.0e4020p-7f), b,
  1616. _mm512_set1_ps(0x1.573e2ep-5f)),
  1617. u,
  1618. _mm512_fmadd_ps(_mm512_set1_ps(0x1.555e66p-3f), b,
  1619. _mm512_set1_ps(0x1.fffdb6p-2f))),
  1620. u,
  1621. _mm512_fmadd_ps(_mm512_set1_ps(0x1.ffffecp-1f), b, _mm512_set1_ps(1.0F)));
  1622. const __m512 res = _mm512_scalef_ps(j, n);
  1623. if (_mm512_kortestz(d, d))
  1624. return res;
  1625. const __m512 zero = _mm512_setzero_ps();
  1626. const __m512 alt = _mm512_mask_blend_ps(
  1627. _mm512_cmp_ps_mask(n, zero, _CMP_LE_OQ), _mm512_set1_ps(INFINITY), zero);
  1628. return _mm512_mask_blend_ps(d, res, alt);
  1629. }
  1630. // computes silu x/(1+exp(-x)) in single precision vector
  1631. inline static __m512 ggml_v_silu(__m512 x) {
  1632. const __m512 one = _mm512_set1_ps(1);
  1633. const __m512 zero = _mm512_setzero_ps();
  1634. const __m512 neg_x = _mm512_sub_ps(zero, x);
  1635. const __m512 exp_neg_x = ggml_v_expf(neg_x);
  1636. const __m512 one_plus_exp_neg_x = _mm512_add_ps(one, exp_neg_x);
  1637. return _mm512_div_ps(x, one_plus_exp_neg_x);
  1638. }
  1639. #elif defined(__AVX2__) && defined(__FMA__)
  1640. // adapted from arm limited optimized routine
  1641. // the maximum error is 1.45358 plus 0.5 ulps
  1642. // numbers above 88.38 will flush to infinity
  1643. // numbers beneath -103.97 will flush to zero
  1644. inline static __m256 ggml_v_expf(__m256 x) {
  1645. const __m256 r = _mm256_set1_ps(0x1.8p23f);
  1646. const __m256 z = _mm256_fmadd_ps(x, _mm256_set1_ps(0x1.715476p+0f), r);
  1647. const __m256 n = _mm256_sub_ps(z, r);
  1648. const __m256 b = _mm256_fnmadd_ps(n, _mm256_set1_ps(0x1.7f7d1cp-20f),
  1649. _mm256_fnmadd_ps(n, _mm256_set1_ps(0x1.62e4p-1f), x));
  1650. const __m256i e = _mm256_slli_epi32(_mm256_castps_si256(z), 23);
  1651. const __m256 k = _mm256_castsi256_ps(
  1652. _mm256_add_epi32(e, _mm256_castps_si256(_mm256_set1_ps(1))));
  1653. const __m256i c = _mm256_castps_si256(
  1654. _mm256_cmp_ps(_mm256_andnot_ps(_mm256_set1_ps(-0.f), n),
  1655. _mm256_set1_ps(126), _CMP_GT_OQ));
  1656. const __m256 u = _mm256_mul_ps(b, b);
  1657. const __m256 j = _mm256_fmadd_ps(_mm256_fmadd_ps(_mm256_fmadd_ps(_mm256_set1_ps(0x1.0e4020p-7f), b,
  1658. _mm256_set1_ps(0x1.573e2ep-5f)), u,
  1659. _mm256_fmadd_ps(_mm256_set1_ps(0x1.555e66p-3f), b,
  1660. _mm256_set1_ps(0x1.fffdb6p-2f))),
  1661. u, _mm256_mul_ps(_mm256_set1_ps(0x1.ffffecp-1f), b));
  1662. if (!_mm256_movemask_ps(_mm256_castsi256_ps(c)))
  1663. return _mm256_fmadd_ps(j, k, k);
  1664. const __m256i g = _mm256_and_si256(
  1665. _mm256_castps_si256(_mm256_cmp_ps(n, _mm256_setzero_ps(), _CMP_LE_OQ)),
  1666. _mm256_set1_epi32(0x82000000u));
  1667. const __m256 s1 =
  1668. _mm256_castsi256_ps(_mm256_add_epi32(g, _mm256_set1_epi32(0x7f000000u)));
  1669. const __m256 s2 = _mm256_castsi256_ps(_mm256_sub_epi32(e, g));
  1670. const __m256i d = _mm256_castps_si256(
  1671. _mm256_cmp_ps(_mm256_andnot_ps(_mm256_set1_ps(-0.f), n),
  1672. _mm256_set1_ps(192), _CMP_GT_OQ));
  1673. return _mm256_or_ps(
  1674. _mm256_and_ps(_mm256_castsi256_ps(d), _mm256_mul_ps(s1, s1)),
  1675. _mm256_andnot_ps(
  1676. _mm256_castsi256_ps(d),
  1677. _mm256_or_ps(
  1678. _mm256_and_ps(_mm256_castsi256_ps(c),
  1679. _mm256_mul_ps(_mm256_fmadd_ps(s2, j, s2), s1)),
  1680. _mm256_andnot_ps(_mm256_castsi256_ps(c), _mm256_fmadd_ps(k, j, k)))));
  1681. }
  1682. // computes silu x/(1+exp(-x)) in single precision vector
  1683. inline static __m256 ggml_v_silu(__m256 x) {
  1684. const __m256 one = _mm256_set1_ps(1);
  1685. const __m256 zero = _mm256_setzero_ps();
  1686. const __m256 neg_x = _mm256_sub_ps(zero, x);
  1687. const __m256 exp_neg_x = ggml_v_expf(neg_x);
  1688. const __m256 one_plus_exp_neg_x = _mm256_add_ps(one, exp_neg_x);
  1689. return _mm256_div_ps(x, one_plus_exp_neg_x);
  1690. }
  1691. #elif defined(__SSE2__) // __AVX2__ / __ARM_NEON
  1692. #if defined(__FMA__)
  1693. #define MADD128(x, y, z) _mm_fmadd_ps(x, y, z)
  1694. #define NMADD128(x, y, z) _mm_fnmadd_ps(x, y, z)
  1695. #else
  1696. #define MADD128(x, y, z) _mm_add_ps(_mm_mul_ps(x, y), z)
  1697. #define NMADD128(x, y, z) _mm_sub_ps(z, _mm_mul_ps(x, y))
  1698. #endif
  1699. // adapted from arm limited optimized routine
  1700. // the maximum error is 1.45358 plus 0.5 ulps
  1701. // numbers above 88.38 will flush to infinity
  1702. // numbers beneath -103.97 will flush to zero
  1703. inline static __m128 ggml_v_expf(__m128 x) {
  1704. const __m128 r = _mm_set1_ps(0x1.8p23f);
  1705. const __m128 z = MADD128(x, _mm_set1_ps(0x1.715476p+0f), r);
  1706. const __m128 n = _mm_sub_ps(z, r);
  1707. const __m128 b =
  1708. NMADD128(n, _mm_set1_ps(0x1.7f7d1cp-20f), NMADD128(n, _mm_set1_ps(0x1.62e4p-1f), x));
  1709. const __m128i e = _mm_slli_epi32(_mm_castps_si128(z), 23);
  1710. const __m128 k = _mm_castsi128_ps(_mm_add_epi32(e, _mm_castps_si128(_mm_set1_ps(1))));
  1711. const __m128i c =
  1712. _mm_castps_si128(_mm_cmpgt_ps(_mm_andnot_ps(_mm_set1_ps(-0.f), n), _mm_set1_ps(126)));
  1713. const __m128 u = _mm_mul_ps(b, b);
  1714. const __m128 j =
  1715. MADD128(MADD128(MADD128(_mm_set1_ps(0x1.0e4020p-7f), b, _mm_set1_ps(0x1.573e2ep-5f)), u,
  1716. MADD128(_mm_set1_ps(0x1.555e66p-3f), b, _mm_set1_ps(0x1.fffdb6p-2f))),
  1717. u, _mm_mul_ps(_mm_set1_ps(0x1.ffffecp-1f), b));
  1718. if (!_mm_movemask_epi8(c))
  1719. return MADD128(j, k, k);
  1720. const __m128i g = _mm_and_si128(_mm_castps_si128(_mm_cmple_ps(n, _mm_setzero_ps())),
  1721. _mm_set1_epi32(0x82000000u));
  1722. const __m128 s1 = _mm_castsi128_ps(_mm_add_epi32(g, _mm_set1_epi32(0x7f000000u)));
  1723. const __m128 s2 = _mm_castsi128_ps(_mm_sub_epi32(e, g));
  1724. const __m128i d =
  1725. _mm_castps_si128(_mm_cmpgt_ps(_mm_andnot_ps(_mm_set1_ps(-0.f), n), _mm_set1_ps(192)));
  1726. return _mm_or_ps(
  1727. _mm_and_ps(_mm_castsi128_ps(d), _mm_mul_ps(s1, s1)),
  1728. _mm_andnot_ps(_mm_castsi128_ps(d),
  1729. _mm_or_ps(_mm_and_ps(_mm_castsi128_ps(c), _mm_mul_ps(MADD128(s2, j, s2), s1)),
  1730. _mm_andnot_ps(_mm_castsi128_ps(c), MADD128(k, j, k)))));
  1731. }
  1732. // computes silu x/(1+exp(-x)) in single precision vector
  1733. inline static __m128 ggml_v_silu(__m128 x) {
  1734. const __m128 one = _mm_set1_ps(1);
  1735. const __m128 zero = _mm_setzero_ps();
  1736. const __m128 neg_x = _mm_sub_ps(zero, x);
  1737. const __m128 exp_neg_x = ggml_v_expf(neg_x);
  1738. const __m128 one_plus_exp_neg_x = _mm_add_ps(one, exp_neg_x);
  1739. return _mm_div_ps(x, one_plus_exp_neg_x);
  1740. }
  1741. #endif // __ARM_NEON / __AVX2__ / __SSE2__
  1742. static void ggml_vec_silu_f32(const int n, float * y, const float * x) {
  1743. int i = 0;
  1744. #if defined(__AVX512F__) && defined(__AVX512DQ__)
  1745. for (; i + 15 < n; i += 16) {
  1746. _mm512_storeu_ps(y + i, ggml_v_silu(_mm512_loadu_ps(x + i)));
  1747. }
  1748. #elif defined(__AVX2__) && defined(__FMA__)
  1749. for (; i + 7 < n; i += 8) {
  1750. _mm256_storeu_ps(y + i, ggml_v_silu(_mm256_loadu_ps(x + i)));
  1751. }
  1752. #elif defined(__SSE2__)
  1753. for (; i + 3 < n; i += 4) {
  1754. _mm_storeu_ps(y + i, ggml_v_silu(_mm_loadu_ps(x + i)));
  1755. }
  1756. #elif defined(__ARM_NEON) && defined(__aarch64__)
  1757. for (; i + 3 < n; i += 4) {
  1758. vst1q_f32(y + i, ggml_v_silu(vld1q_f32(x + i)));
  1759. }
  1760. #endif
  1761. for (; i < n; ++i) {
  1762. y[i] = ggml_silu_f32(x[i]);
  1763. }
  1764. }
  1765. static ggml_float ggml_vec_soft_max_f32(const int n, float * y, const float * x, float max) {
  1766. int i = 0;
  1767. ggml_float sum = 0;
  1768. #if defined(__AVX512F__) && defined(__AVX512DQ__)
  1769. for (; i + 15 < n; i += 16) {
  1770. __m512 val = ggml_v_expf(_mm512_sub_ps(_mm512_loadu_ps(x + i),
  1771. _mm512_set1_ps(max)));
  1772. _mm512_storeu_ps(y + i, val);
  1773. sum += (ggml_float)_mm512_reduce_add_ps(val);
  1774. }
  1775. #elif defined(__AVX2__) && defined(__FMA__)
  1776. for (; i + 7 < n; i += 8) {
  1777. __m256 val = ggml_v_expf(_mm256_sub_ps(_mm256_loadu_ps(x + i),
  1778. _mm256_set1_ps(max)));
  1779. _mm256_storeu_ps(y + i, val);
  1780. __m128 val2 = _mm_add_ps(_mm256_extractf128_ps(val, 1),
  1781. _mm256_castps256_ps128(val));
  1782. val2 = _mm_add_ps(val2, _mm_movehl_ps(val2, val2));
  1783. val2 = _mm_add_ss(val2, _mm_movehdup_ps(val2));
  1784. sum += (ggml_float)_mm_cvtss_f32(val2);
  1785. }
  1786. #elif defined(__SSE2__)
  1787. for (; i + 3 < n; i += 4) {
  1788. __m128 val = ggml_v_expf(_mm_sub_ps(_mm_loadu_ps(x + i),
  1789. _mm_set1_ps(max)));
  1790. _mm_storeu_ps(y + i, val);
  1791. #if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
  1792. val = _mm_add_ps(val, _mm_movehl_ps(val, val));
  1793. val = _mm_add_ss(val, _mm_movehdup_ps(val));
  1794. #else
  1795. __m128 tmp = _mm_shuffle_ps(val, val, _MM_SHUFFLE(2, 3, 0, 1));
  1796. val = _mm_add_ps(val, tmp);
  1797. tmp = _mm_movehl_ps(tmp, val);
  1798. val = _mm_add_ss(val, tmp);
  1799. #endif
  1800. sum += (ggml_float)_mm_cvtss_f32(val);
  1801. }
  1802. #elif defined(__ARM_NEON) && defined(__aarch64__)
  1803. for (; i + 3 < n; i += 4) {
  1804. float32x4_t val = ggml_v_expf(vsubq_f32(vld1q_f32(x + i),
  1805. vdupq_n_f32(max)));
  1806. vst1q_f32(y + i, val);
  1807. sum += (ggml_float)vaddvq_f32(val);
  1808. }
  1809. #endif
  1810. for (; i < n; ++i) {
  1811. float val = expf(x[i] - max);
  1812. sum += (ggml_float)val;
  1813. y[i] = val;
  1814. }
  1815. return sum;
  1816. }
  1817. static ggml_float ggml_vec_log_soft_max_f32(const int n, float * y, const float * x, float max) {
  1818. // log(soft_max) = log(soft_max_i / soft_max_sum) = log(soft_max_i) - log(soft_max_sum) = (logit_i - max) - log(soft_max_i)
  1819. int i = 0;
  1820. ggml_float sum = 0;
  1821. for (; i < n; ++i) {
  1822. float val = x[i] - max;
  1823. y[i] = val;
  1824. sum += (ggml_float)expf(val);
  1825. }
  1826. return sum = (ggml_float)logf(sum);
  1827. }
  1828. inline static float ggml_silu_backward_f32(float x, float dy) {
  1829. const float s = 1.0f/(1.0f + expf(-x));
  1830. return dy*s*(1.0f + x*(1.0f - s));
  1831. }
  1832. inline static void ggml_vec_silu_backward_f32(const int n, float * dx, const float * x, const float * dy) {
  1833. for (int i = 0; i < n; ++i) {
  1834. dx[i] = ggml_silu_backward_f32(x[i], dy[i]);
  1835. }
  1836. }
  1837. inline static void ggml_vec_sum_f32(const int n, float * s, const float * x) {
  1838. #ifndef GGML_USE_ACCELERATE
  1839. ggml_float sum = 0.0;
  1840. for (int i = 0; i < n; ++i) {
  1841. sum += (ggml_float)x[i];
  1842. }
  1843. *s = sum;
  1844. #else
  1845. vDSP_sve(x, 1, s, n);
  1846. #endif
  1847. }
  1848. inline static void ggml_vec_sum_f32_ggf(const int n, ggml_float * s, const float * x) {
  1849. ggml_float sum = 0.0;
  1850. for (int i = 0; i < n; ++i) {
  1851. sum += (ggml_float)x[i];
  1852. }
  1853. *s = sum;
  1854. }
  1855. inline static void ggml_vec_sum_f16_ggf(const int n, float * s, const ggml_fp16_t * x) {
  1856. float sum = 0.0f;
  1857. for (int i = 0; i < n; ++i) {
  1858. sum += GGML_FP16_TO_FP32(x[i]);
  1859. }
  1860. *s = sum;
  1861. }
  1862. inline static void ggml_vec_sum_bf16_ggf(const int n, float * s, const ggml_bf16_t * x) {
  1863. float sum = 0.0f;
  1864. for (int i = 0; i < n; ++i) {
  1865. sum += GGML_BF16_TO_FP32(x[i]);
  1866. }
  1867. *s = sum;
  1868. }
  1869. inline static void ggml_vec_max_f32(const int n, float * s, const float * x) {
  1870. #ifndef GGML_USE_ACCELERATE
  1871. float max = -INFINITY;
  1872. for (int i = 0; i < n; ++i) {
  1873. max = MAX(max, x[i]);
  1874. }
  1875. *s = max;
  1876. #else
  1877. vDSP_maxv(x, 1, s, n);
  1878. #endif
  1879. }
  1880. inline static void ggml_vec_norm_inv_f32(const int n, float * s, const float * x) {
  1881. ggml_vec_norm_f32(n, s, x);
  1882. *s = 1.f/(*s);
  1883. }
  1884. inline static void ggml_vec_argmax_f32(const int n, int * s, const float * x) {
  1885. float max = -INFINITY;
  1886. int idx = 0;
  1887. for (int i = 0; i < n; ++i) {
  1888. max = MAX(max, x[i]);
  1889. if (max == x[i]) { idx = i; }
  1890. }
  1891. *s = idx;
  1892. }
  1893. // Helpers for polling loops
  1894. #if defined(__aarch64__) && ( defined(__clang__) || defined(__GNUC__) )
  1895. static inline void ggml_thread_cpu_relax(void) {
  1896. __asm__ volatile("yield" ::: "memory");
  1897. }
  1898. #elif defined(__x86_64__)
  1899. static inline void ggml_thread_cpu_relax(void) {
  1900. _mm_pause();
  1901. }
  1902. #else
  1903. static inline void ggml_thread_cpu_relax(void) {;}
  1904. #endif
  1905. //
  1906. // NUMA support
  1907. //
  1908. #define GGML_NUMA_MAX_NODES 8
  1909. #define GGML_NUMA_MAX_CPUS 512
  1910. struct ggml_numa_node {
  1911. uint32_t cpus[GGML_NUMA_MAX_CPUS]; // hardware threads on this node
  1912. uint32_t n_cpus;
  1913. };
  1914. struct ggml_numa_nodes {
  1915. enum ggml_numa_strategy numa_strategy;
  1916. struct ggml_numa_node nodes[GGML_NUMA_MAX_NODES];
  1917. uint32_t n_nodes;
  1918. uint32_t total_cpus; // hardware threads on system
  1919. uint32_t current_node; // node on which main process is execting
  1920. #if defined(__gnu_linux__)
  1921. cpu_set_t cpuset; // cpuset from numactl
  1922. #else
  1923. uint32_t cpuset; // no NUMA support outside of Linux at this time. Use a portable datatype
  1924. #endif
  1925. };
  1926. //
  1927. // ggml state
  1928. //
  1929. struct ggml_state {
  1930. struct ggml_numa_nodes numa;
  1931. };
  1932. // global state
  1933. static struct ggml_state g_state = {0};
  1934. static atomic_flag g_state_critical = ATOMIC_FLAG_INIT;
  1935. // TODO: move to threading file
  1936. // critical section via spin lock
  1937. void ggml_critical_section_start(void) {
  1938. while (atomic_flag_test_and_set(&g_state_critical)) {
  1939. // spin
  1940. sched_yield();
  1941. }
  1942. }
  1943. void ggml_critical_section_end(void) {
  1944. atomic_flag_clear(&g_state_critical);
  1945. }
  1946. static void ggml_barrier(struct ggml_threadpool * tp) {
  1947. int n_threads = atomic_load_explicit(&tp->n_threads_cur, memory_order_relaxed);
  1948. if (n_threads == 1) {
  1949. return;
  1950. }
  1951. #ifdef GGML_USE_OPENMP
  1952. #pragma omp barrier
  1953. #else
  1954. int n_passed = atomic_load_explicit(&tp->n_barrier_passed, memory_order_relaxed);
  1955. // enter barrier (full seq-cst fence)
  1956. int n_barrier = atomic_fetch_add_explicit(&tp->n_barrier, 1, memory_order_seq_cst);
  1957. if (n_barrier == (n_threads - 1)) {
  1958. // last thread
  1959. atomic_store_explicit(&tp->n_barrier, 0, memory_order_relaxed);
  1960. // exit barrier (fill seq-cst fence)
  1961. atomic_fetch_add_explicit(&tp->n_barrier_passed, 1, memory_order_seq_cst);
  1962. return;
  1963. }
  1964. // wait for other threads
  1965. while (atomic_load_explicit(&tp->n_barrier_passed, memory_order_relaxed) == n_passed) {
  1966. ggml_thread_cpu_relax();
  1967. }
  1968. // exit barrier (full seq-cst fence)
  1969. // TSAN doesn't support standalone fence yet, we use a dummy read-modify-write instead
  1970. #ifdef GGML_TSAN_ENABLED
  1971. atomic_fetch_add_explicit(&tp->n_barrier_passed, 0, memory_order_seq_cst);
  1972. #else
  1973. atomic_thread_fence(memory_order_seq_cst);
  1974. #endif
  1975. #endif
  1976. }
  1977. #if defined(__gnu_linux__)
  1978. static cpu_set_t ggml_get_numa_affinity(void) {
  1979. cpu_set_t cpuset;
  1980. pthread_t thread;
  1981. thread = pthread_self();
  1982. CPU_ZERO(&cpuset);
  1983. pthread_getaffinity_np(thread, sizeof(cpu_set_t), &cpuset);
  1984. return cpuset;
  1985. }
  1986. #else
  1987. static uint32_t ggml_get_numa_affinity(void) {
  1988. return 0; // no NUMA support
  1989. }
  1990. #endif
  1991. void ggml_numa_init(enum ggml_numa_strategy numa_flag) {
  1992. if (g_state.numa.n_nodes > 0) {
  1993. fprintf(stderr, "ggml_numa_init: NUMA already initialized\n");
  1994. return;
  1995. }
  1996. #if defined(__gnu_linux__)
  1997. struct stat st;
  1998. char path[256];
  1999. int rv;
  2000. // set numa scheme
  2001. g_state.numa.numa_strategy = numa_flag;
  2002. GGML_PRINT_DEBUG("numa strategy %u\n",g_state.numa.numa_strategy);
  2003. g_state.numa.cpuset = ggml_get_numa_affinity();
  2004. // enumerate nodes
  2005. while (g_state.numa.n_nodes < GGML_NUMA_MAX_NODES) {
  2006. rv = snprintf(path, sizeof(path), "/sys/devices/system/node/node%u", g_state.numa.n_nodes);
  2007. GGML_ASSERT(rv > 0 && (unsigned)rv < sizeof(path));
  2008. if (stat(path, &st) != 0) { break; }
  2009. ++g_state.numa.n_nodes;
  2010. }
  2011. // enumerate CPUs
  2012. while (g_state.numa.total_cpus < GGML_NUMA_MAX_CPUS) {
  2013. rv = snprintf(path, sizeof(path), "/sys/devices/system/cpu/cpu%u", g_state.numa.total_cpus);
  2014. GGML_ASSERT(rv > 0 && (unsigned)rv < sizeof(path));
  2015. if (stat(path, &st) != 0) { break; }
  2016. ++g_state.numa.total_cpus;
  2017. }
  2018. GGML_PRINT_DEBUG("found %u numa nodes, %u CPUs\n", g_state.numa.n_nodes, g_state.numa.total_cpus);
  2019. // figure out which node we're on
  2020. uint current_cpu;
  2021. int getcpu_ret = 0;
  2022. #if __GLIBC__ > 2 || (__GLIBC__ == 2 && __GLIBC_MINOR__ > 28) || defined(__COSMOPOLITAN__)
  2023. getcpu_ret = getcpu(&current_cpu, &g_state.numa.current_node);
  2024. #else
  2025. // old glibc doesn't have a wrapper for this call. Fall back on direct syscall
  2026. # if !defined(SYS_getcpu) && defined(SYS_get_cpu)
  2027. # define SYS_getcpu SYS_get_cpu // some older glibc versions use this name
  2028. # endif
  2029. getcpu_ret = syscall(SYS_getcpu, &current_cpu, &g_state.numa.current_node);
  2030. #endif
  2031. if (g_state.numa.n_nodes < 1 || g_state.numa.total_cpus < 1 || getcpu_ret != 0) {
  2032. g_state.numa.n_nodes = 0;
  2033. return;
  2034. }
  2035. GGML_PRINT_DEBUG("found our process on numa node %u, CPU %u\n", g_state.numa.current_node, current_cpu);
  2036. for (uint32_t n = 0; n < g_state.numa.n_nodes; ++n) {
  2037. struct ggml_numa_node * node = &g_state.numa.nodes[n];
  2038. GGML_PRINT_DEBUG("CPUs on node %u:", n);
  2039. node->n_cpus = 0;
  2040. for (uint32_t c = 0; c < g_state.numa.total_cpus; ++c) {
  2041. rv = snprintf(path, sizeof(path), "/sys/devices/system/node/node%u/cpu%u", n, c);
  2042. GGML_ASSERT(rv > 0 && (unsigned)rv < sizeof(path));
  2043. if (stat(path, &st) == 0) {
  2044. node->cpus[node->n_cpus++] = c;
  2045. GGML_PRINT_DEBUG(" %u", c);
  2046. }
  2047. }
  2048. GGML_PRINT_DEBUG("\n");
  2049. }
  2050. if (ggml_is_numa()) {
  2051. FILE *fptr = fopen("/proc/sys/kernel/numa_balancing", "r");
  2052. if (fptr != NULL) {
  2053. char buf[42];
  2054. if (fgets(buf, sizeof(buf), fptr) && strncmp(buf, "0\n", sizeof(buf)) != 0) {
  2055. GGML_LOG_WARN("/proc/sys/kernel/numa_balancing is enabled, this has been observed to impair performance\n");
  2056. }
  2057. fclose(fptr);
  2058. }
  2059. }
  2060. #else
  2061. UNUSED(numa_flag);
  2062. // TODO
  2063. #endif
  2064. }
  2065. bool ggml_is_numa(void) {
  2066. return g_state.numa.n_nodes > 1;
  2067. }
  2068. #if defined(__ARM_ARCH)
  2069. #if defined(__linux__) && defined(__aarch64__)
  2070. #include <sys/auxv.h>
  2071. #elif defined(__APPLE__)
  2072. #include <sys/sysctl.h>
  2073. #endif
  2074. #if !defined(HWCAP2_I8MM)
  2075. #define HWCAP2_I8MM 0
  2076. #endif
  2077. static void ggml_init_arm_arch_features(void) {
  2078. #if defined(__linux__) && defined(__aarch64__)
  2079. uint32_t hwcap = getauxval(AT_HWCAP);
  2080. uint32_t hwcap2 = getauxval(AT_HWCAP2);
  2081. ggml_arm_arch_features.has_neon = !!(hwcap & HWCAP_ASIMD);
  2082. ggml_arm_arch_features.has_i8mm = !!(hwcap2 & HWCAP2_I8MM);
  2083. ggml_arm_arch_features.has_sve = !!(hwcap & HWCAP_SVE);
  2084. #if defined(__ARM_FEATURE_SVE)
  2085. ggml_arm_arch_features.sve_cnt = PR_SVE_VL_LEN_MASK & prctl(PR_SVE_GET_VL);
  2086. #endif
  2087. #elif defined(__APPLE__)
  2088. int oldp = 0;
  2089. size_t size = sizeof(oldp);
  2090. if (sysctlbyname("hw.optional.AdvSIMD", &oldp, &size, NULL, 0) != 0) {
  2091. oldp = 0;
  2092. }
  2093. ggml_arm_arch_features.has_neon = oldp;
  2094. if (sysctlbyname("hw.optional.arm.FEAT_I8MM", &oldp, &size, NULL, 0) != 0) {
  2095. oldp = 0;
  2096. }
  2097. ggml_arm_arch_features.has_i8mm = oldp;
  2098. ggml_arm_arch_features.has_sve = 0;
  2099. ggml_arm_arch_features.sve_cnt = 0;
  2100. #else
  2101. // Run-time CPU feature detection not implemented for this platform, fallback to compile time
  2102. #if defined(__ARM_NEON)
  2103. ggml_arm_arch_features.has_neon = 1;
  2104. #else
  2105. ggml_arm_arch_features.has_neon = 0;
  2106. #endif
  2107. #if defined(__ARM_FEATURE_MATMUL_INT8)
  2108. ggml_arm_arch_features.has_i8mm = 1;
  2109. #else
  2110. ggml_arm_arch_features.has_i8mm = 0;
  2111. #endif
  2112. #if defined(__ARM_FEATURE_SVE)
  2113. ggml_arm_arch_features.has_sve = 1;
  2114. ggml_arm_arch_features.sve_cnt = 16;
  2115. #else
  2116. ggml_arm_arch_features.has_sve = 0;
  2117. ggml_arm_arch_features.sve_cnt = 0;
  2118. #endif
  2119. #endif
  2120. }
  2121. #endif
  2122. struct ggml_tensor * ggml_new_i32(struct ggml_context * ctx, int32_t value) {
  2123. GGML_ASSERT(!ggml_get_no_alloc(ctx));
  2124. struct ggml_tensor * result = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 1);
  2125. ggml_set_i32(result, value);
  2126. return result;
  2127. }
  2128. struct ggml_tensor * ggml_new_f32(struct ggml_context * ctx, float value) {
  2129. GGML_ASSERT(!ggml_get_no_alloc(ctx));
  2130. struct ggml_tensor * result = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 1);
  2131. ggml_set_f32(result, value);
  2132. return result;
  2133. }
  2134. struct ggml_tensor * ggml_set_i32 (struct ggml_tensor * tensor, int32_t value) {
  2135. const int n = ggml_nrows(tensor);
  2136. const int nc = tensor->ne[0];
  2137. const size_t n1 = tensor->nb[1];
  2138. char * const data = tensor->data;
  2139. switch (tensor->type) {
  2140. case GGML_TYPE_I8:
  2141. {
  2142. assert(tensor->nb[0] == sizeof(int8_t));
  2143. for (int i = 0; i < n; i++) {
  2144. ggml_vec_set_i8(nc, (int8_t *)(data + i*n1), value);
  2145. }
  2146. } break;
  2147. case GGML_TYPE_I16:
  2148. {
  2149. assert(tensor->nb[0] == sizeof(int16_t));
  2150. for (int i = 0; i < n; i++) {
  2151. ggml_vec_set_i16(nc, (int16_t *)(data + i*n1), value);
  2152. }
  2153. } break;
  2154. case GGML_TYPE_I32:
  2155. {
  2156. assert(tensor->nb[0] == sizeof(int32_t));
  2157. for (int i = 0; i < n; i++) {
  2158. ggml_vec_set_i32(nc, (int32_t *)(data + i*n1), value);
  2159. }
  2160. } break;
  2161. case GGML_TYPE_F16:
  2162. {
  2163. assert(tensor->nb[0] == sizeof(ggml_fp16_t));
  2164. for (int i = 0; i < n; i++) {
  2165. ggml_vec_set_f16(nc, (ggml_fp16_t *)(data + i*n1), GGML_FP32_TO_FP16(value));
  2166. }
  2167. } break;
  2168. case GGML_TYPE_BF16:
  2169. {
  2170. assert(tensor->nb[0] == sizeof(ggml_fp16_t));
  2171. for (int i = 0; i < n; i++) {
  2172. ggml_vec_set_bf16(nc, (ggml_bf16_t *)(data + i*n1), GGML_FP32_TO_BF16(value));
  2173. }
  2174. } break;
  2175. case GGML_TYPE_F32:
  2176. {
  2177. assert(tensor->nb[0] == sizeof(float));
  2178. for (int i = 0; i < n; i++) {
  2179. ggml_vec_set_f32(nc, (float *)(data + i*n1), value);
  2180. }
  2181. } break;
  2182. default:
  2183. {
  2184. GGML_ABORT("fatal error");
  2185. }
  2186. }
  2187. return tensor;
  2188. }
  2189. struct ggml_tensor * ggml_set_f32(struct ggml_tensor * tensor, float value) {
  2190. const int n = ggml_nrows(tensor);
  2191. const int nc = tensor->ne[0];
  2192. const size_t n1 = tensor->nb[1];
  2193. char * const data = tensor->data;
  2194. switch (tensor->type) {
  2195. case GGML_TYPE_I8:
  2196. {
  2197. assert(tensor->nb[0] == sizeof(int8_t));
  2198. for (int i = 0; i < n; i++) {
  2199. ggml_vec_set_i8(nc, (int8_t *)(data + i*n1), value);
  2200. }
  2201. } break;
  2202. case GGML_TYPE_I16:
  2203. {
  2204. assert(tensor->nb[0] == sizeof(int16_t));
  2205. for (int i = 0; i < n; i++) {
  2206. ggml_vec_set_i16(nc, (int16_t *)(data + i*n1), value);
  2207. }
  2208. } break;
  2209. case GGML_TYPE_I32:
  2210. {
  2211. assert(tensor->nb[0] == sizeof(int32_t));
  2212. for (int i = 0; i < n; i++) {
  2213. ggml_vec_set_i32(nc, (int32_t *)(data + i*n1), value);
  2214. }
  2215. } break;
  2216. case GGML_TYPE_F16:
  2217. {
  2218. assert(tensor->nb[0] == sizeof(ggml_fp16_t));
  2219. for (int i = 0; i < n; i++) {
  2220. ggml_vec_set_f16(nc, (ggml_fp16_t *)(data + i*n1), GGML_FP32_TO_FP16(value));
  2221. }
  2222. } break;
  2223. case GGML_TYPE_BF16:
  2224. {
  2225. assert(tensor->nb[0] == sizeof(ggml_bf16_t));
  2226. for (int i = 0; i < n; i++) {
  2227. ggml_vec_set_bf16(nc, (ggml_bf16_t *)(data + i*n1), GGML_FP32_TO_BF16(value));
  2228. }
  2229. } break;
  2230. case GGML_TYPE_F32:
  2231. {
  2232. assert(tensor->nb[0] == sizeof(float));
  2233. for (int i = 0; i < n; i++) {
  2234. ggml_vec_set_f32(nc, (float *)(data + i*n1), value);
  2235. }
  2236. } break;
  2237. default:
  2238. {
  2239. GGML_ABORT("fatal error");
  2240. }
  2241. }
  2242. return tensor;
  2243. }
  2244. int32_t ggml_get_i32_1d(const struct ggml_tensor * tensor, int i) {
  2245. if (!ggml_is_contiguous(tensor)) {
  2246. int64_t id[4] = { 0, 0, 0, 0 };
  2247. ggml_unravel_index(tensor, i, &id[0], &id[1], &id[2], &id[3]);
  2248. return ggml_get_i32_nd(tensor, id[0], id[1], id[2], id[3]);
  2249. }
  2250. switch (tensor->type) {
  2251. case GGML_TYPE_I8:
  2252. {
  2253. GGML_ASSERT(tensor->nb[0] == sizeof(int8_t));
  2254. return ((int8_t *)(tensor->data))[i];
  2255. }
  2256. case GGML_TYPE_I16:
  2257. {
  2258. GGML_ASSERT(tensor->nb[0] == sizeof(int16_t));
  2259. return ((int16_t *)(tensor->data))[i];
  2260. }
  2261. case GGML_TYPE_I32:
  2262. {
  2263. GGML_ASSERT(tensor->nb[0] == sizeof(int32_t));
  2264. return ((int32_t *)(tensor->data))[i];
  2265. }
  2266. case GGML_TYPE_F16:
  2267. {
  2268. GGML_ASSERT(tensor->nb[0] == sizeof(ggml_fp16_t));
  2269. return GGML_FP16_TO_FP32(((ggml_fp16_t *)(tensor->data))[i]);
  2270. }
  2271. case GGML_TYPE_BF16:
  2272. {
  2273. GGML_ASSERT(tensor->nb[0] == sizeof(ggml_bf16_t));
  2274. return GGML_BF16_TO_FP32(((ggml_bf16_t *)(tensor->data))[i]);
  2275. }
  2276. case GGML_TYPE_F32:
  2277. {
  2278. GGML_ASSERT(tensor->nb[0] == sizeof(float));
  2279. return ((float *)(tensor->data))[i];
  2280. }
  2281. default:
  2282. {
  2283. GGML_ABORT("fatal error");
  2284. }
  2285. }
  2286. }
  2287. void ggml_set_i32_1d(const struct ggml_tensor * tensor, int i, int32_t value) {
  2288. if (!ggml_is_contiguous(tensor)) {
  2289. int64_t id[4] = { 0, 0, 0, 0 };
  2290. ggml_unravel_index(tensor, i, &id[0], &id[1], &id[2], &id[3]);
  2291. ggml_set_i32_nd(tensor, id[0], id[1], id[2], id[3], value);
  2292. return;
  2293. }
  2294. switch (tensor->type) {
  2295. case GGML_TYPE_I8:
  2296. {
  2297. GGML_ASSERT(tensor->nb[0] == sizeof(int8_t));
  2298. ((int8_t *)(tensor->data))[i] = value;
  2299. } break;
  2300. case GGML_TYPE_I16:
  2301. {
  2302. GGML_ASSERT(tensor->nb[0] == sizeof(int16_t));
  2303. ((int16_t *)(tensor->data))[i] = value;
  2304. } break;
  2305. case GGML_TYPE_I32:
  2306. {
  2307. GGML_ASSERT(tensor->nb[0] == sizeof(int32_t));
  2308. ((int32_t *)(tensor->data))[i] = value;
  2309. } break;
  2310. case GGML_TYPE_F16:
  2311. {
  2312. GGML_ASSERT(tensor->nb[0] == sizeof(ggml_fp16_t));
  2313. ((ggml_fp16_t *)(tensor->data))[i] = GGML_FP32_TO_FP16(value);
  2314. } break;
  2315. case GGML_TYPE_BF16:
  2316. {
  2317. GGML_ASSERT(tensor->nb[0] == sizeof(ggml_bf16_t));
  2318. ((ggml_bf16_t *)(tensor->data))[i] = GGML_FP32_TO_BF16(value);
  2319. } break;
  2320. case GGML_TYPE_F32:
  2321. {
  2322. GGML_ASSERT(tensor->nb[0] == sizeof(float));
  2323. ((float *)(tensor->data))[i] = value;
  2324. } break;
  2325. default:
  2326. {
  2327. GGML_ABORT("fatal error");
  2328. }
  2329. }
  2330. }
  2331. int32_t ggml_get_i32_nd(const struct ggml_tensor * tensor, int i0, int i1, int i2, int i3) {
  2332. void * data = (char *) tensor->data + i0*tensor->nb[0] + i1*tensor->nb[1] + i2*tensor->nb[2] + i3*tensor->nb[3];
  2333. switch (tensor->type) {
  2334. case GGML_TYPE_I8:
  2335. return ((int8_t *) data)[0];
  2336. case GGML_TYPE_I16:
  2337. return ((int16_t *) data)[0];
  2338. case GGML_TYPE_I32:
  2339. return ((int32_t *) data)[0];
  2340. case GGML_TYPE_F16:
  2341. return GGML_FP16_TO_FP32(((ggml_fp16_t *) data)[0]);
  2342. case GGML_TYPE_BF16:
  2343. return GGML_BF16_TO_FP32(((ggml_bf16_t *) data)[0]);
  2344. case GGML_TYPE_F32:
  2345. return ((float *) data)[0];
  2346. default:
  2347. GGML_ABORT("fatal error");
  2348. }
  2349. }
  2350. void ggml_set_i32_nd(const struct ggml_tensor * tensor, int i0, int i1, int i2, int i3, int32_t value) {
  2351. void * data = (char *) tensor->data + i0*tensor->nb[0] + i1*tensor->nb[1] + i2*tensor->nb[2] + i3*tensor->nb[3];
  2352. switch (tensor->type) {
  2353. case GGML_TYPE_I8:
  2354. {
  2355. ((int8_t *)(data))[0] = value;
  2356. } break;
  2357. case GGML_TYPE_I16:
  2358. {
  2359. ((int16_t *)(data))[0] = value;
  2360. } break;
  2361. case GGML_TYPE_I32:
  2362. {
  2363. ((int32_t *)(data))[0] = value;
  2364. } break;
  2365. case GGML_TYPE_F16:
  2366. {
  2367. ((ggml_fp16_t *)(data))[0] = GGML_FP32_TO_FP16(value);
  2368. } break;
  2369. case GGML_TYPE_BF16:
  2370. {
  2371. ((ggml_bf16_t *)(data))[0] = GGML_FP32_TO_BF16(value);
  2372. } break;
  2373. case GGML_TYPE_F32:
  2374. {
  2375. ((float *)(data))[0] = value;
  2376. } break;
  2377. default:
  2378. {
  2379. GGML_ABORT("fatal error");
  2380. }
  2381. }
  2382. }
  2383. float ggml_get_f32_1d(const struct ggml_tensor * tensor, int i) {
  2384. if (!ggml_is_contiguous(tensor)) {
  2385. int64_t id[4] = { 0, 0, 0, 0 };
  2386. ggml_unravel_index(tensor, i, &id[0], &id[1], &id[2], &id[3]);
  2387. return ggml_get_f32_nd(tensor, id[0], id[1], id[2], id[3]);
  2388. }
  2389. switch (tensor->type) {
  2390. case GGML_TYPE_I8:
  2391. {
  2392. return ((int8_t *)(tensor->data))[i];
  2393. }
  2394. case GGML_TYPE_I16:
  2395. {
  2396. return ((int16_t *)(tensor->data))[i];
  2397. }
  2398. case GGML_TYPE_I32:
  2399. {
  2400. return ((int32_t *)(tensor->data))[i];
  2401. }
  2402. case GGML_TYPE_F16:
  2403. {
  2404. return GGML_FP16_TO_FP32(((ggml_fp16_t *)(tensor->data))[i]);
  2405. }
  2406. case GGML_TYPE_BF16:
  2407. {
  2408. return GGML_BF16_TO_FP32(((ggml_bf16_t *)(tensor->data))[i]);
  2409. }
  2410. case GGML_TYPE_F32:
  2411. {
  2412. return ((float *)(tensor->data))[i];
  2413. }
  2414. default:
  2415. {
  2416. GGML_ABORT("fatal error");
  2417. }
  2418. }
  2419. }
  2420. void ggml_set_f32_1d(const struct ggml_tensor * tensor, int i, float value) {
  2421. if (!ggml_is_contiguous(tensor)) {
  2422. int64_t id[4] = { 0, 0, 0, 0 };
  2423. ggml_unravel_index(tensor, i, &id[0], &id[1], &id[2], &id[3]);
  2424. ggml_set_f32_nd(tensor, id[0], id[1], id[2], id[3], value);
  2425. return;
  2426. }
  2427. switch (tensor->type) {
  2428. case GGML_TYPE_I8:
  2429. {
  2430. ((int8_t *)(tensor->data))[i] = value;
  2431. } break;
  2432. case GGML_TYPE_I16:
  2433. {
  2434. ((int16_t *)(tensor->data))[i] = value;
  2435. } break;
  2436. case GGML_TYPE_I32:
  2437. {
  2438. ((int32_t *)(tensor->data))[i] = value;
  2439. } break;
  2440. case GGML_TYPE_F16:
  2441. {
  2442. ((ggml_fp16_t *)(tensor->data))[i] = GGML_FP32_TO_FP16(value);
  2443. } break;
  2444. case GGML_TYPE_BF16:
  2445. {
  2446. ((ggml_bf16_t *)(tensor->data))[i] = GGML_FP32_TO_BF16(value);
  2447. } break;
  2448. case GGML_TYPE_F32:
  2449. {
  2450. ((float *)(tensor->data))[i] = value;
  2451. } break;
  2452. default:
  2453. {
  2454. GGML_ABORT("fatal error");
  2455. }
  2456. }
  2457. }
  2458. float ggml_get_f32_nd(const struct ggml_tensor * tensor, int i0, int i1, int i2, int i3) {
  2459. void * data = (char *) tensor->data + i0*tensor->nb[0] + i1*tensor->nb[1] + i2*tensor->nb[2] + i3*tensor->nb[3];
  2460. switch (tensor->type) {
  2461. case GGML_TYPE_I8:
  2462. return ((int8_t *) data)[0];
  2463. case GGML_TYPE_I16:
  2464. return ((int16_t *) data)[0];
  2465. case GGML_TYPE_I32:
  2466. return ((int32_t *) data)[0];
  2467. case GGML_TYPE_F16:
  2468. return GGML_FP16_TO_FP32(((ggml_fp16_t *) data)[0]);
  2469. case GGML_TYPE_BF16:
  2470. return GGML_BF16_TO_FP32(((ggml_bf16_t *) data)[0]);
  2471. case GGML_TYPE_F32:
  2472. return ((float *) data)[0];
  2473. default:
  2474. GGML_ABORT("fatal error");
  2475. }
  2476. }
  2477. void ggml_set_f32_nd(const struct ggml_tensor * tensor, int i0, int i1, int i2, int i3, float value) {
  2478. void * data = (char *) tensor->data + i0*tensor->nb[0] + i1*tensor->nb[1] + i2*tensor->nb[2] + i3*tensor->nb[3];
  2479. switch (tensor->type) {
  2480. case GGML_TYPE_I8:
  2481. {
  2482. ((int8_t *)(data))[0] = value;
  2483. } break;
  2484. case GGML_TYPE_I16:
  2485. {
  2486. ((int16_t *)(data))[0] = value;
  2487. } break;
  2488. case GGML_TYPE_I32:
  2489. {
  2490. ((int32_t *)(data))[0] = value;
  2491. } break;
  2492. case GGML_TYPE_F16:
  2493. {
  2494. ((ggml_fp16_t *)(data))[0] = GGML_FP32_TO_FP16(value);
  2495. } break;
  2496. case GGML_TYPE_BF16:
  2497. {
  2498. ((ggml_bf16_t *)(data))[0] = GGML_FP32_TO_BF16(value);
  2499. } break;
  2500. case GGML_TYPE_F32:
  2501. {
  2502. ((float *)(data))[0] = value;
  2503. } break;
  2504. default:
  2505. {
  2506. GGML_ABORT("fatal error");
  2507. }
  2508. }
  2509. }
  2510. ////////////////////////////////////////////////////////////////////////////////
  2511. // ggml_compute_forward_dup
  2512. static void ggml_compute_forward_dup_same_cont(
  2513. const struct ggml_compute_params * params,
  2514. struct ggml_tensor * dst) {
  2515. const struct ggml_tensor * src0 = dst->src[0];
  2516. GGML_ASSERT(ggml_nelements(dst) == ggml_nelements(src0));
  2517. GGML_ASSERT(ggml_is_contiguous(dst) && ggml_is_contiguous(src0));
  2518. GGML_ASSERT(src0->type == dst->type);
  2519. const size_t nb0 = ggml_type_size(src0->type);
  2520. const int ith = params->ith; // thread index
  2521. const int nth = params->nth; // number of threads
  2522. // parallelize by elements
  2523. const int ne = ggml_nelements(dst);
  2524. const int dr = (ne + nth - 1) / nth;
  2525. const int ie0 = dr * ith;
  2526. const int ie1 = MIN(ie0 + dr, ne);
  2527. if (ie0 < ie1) {
  2528. memcpy(
  2529. ((char *) dst->data + ie0*nb0),
  2530. ((char *) src0->data + ie0*nb0),
  2531. (ie1 - ie0) * nb0);
  2532. }
  2533. }
  2534. static void ggml_compute_forward_dup_f16(
  2535. const struct ggml_compute_params * params,
  2536. struct ggml_tensor * dst) {
  2537. const struct ggml_tensor * src0 = dst->src[0];
  2538. GGML_ASSERT(ggml_nelements(dst) == ggml_nelements(src0));
  2539. GGML_TENSOR_UNARY_OP_LOCALS
  2540. const int ith = params->ith; // thread index
  2541. const int nth = params->nth; // number of threads
  2542. // parallelize by rows
  2543. const int nr = ne01;
  2544. // number of rows per thread
  2545. const int dr = (nr + nth - 1) / nth;
  2546. // row range for this thread
  2547. const int ir0 = dr * ith;
  2548. const int ir1 = MIN(ir0 + dr, nr);
  2549. if (src0->type == dst->type &&
  2550. ne00 == ne0 &&
  2551. nb00 == ggml_type_size(src0->type) && nb0 == ggml_type_size(dst->type)) {
  2552. // copy by rows
  2553. const size_t rs = ne00*nb00;
  2554. for (int64_t i03 = 0; i03 < ne03; i03++) {
  2555. for (int64_t i02 = 0; i02 < ne02; i02++) {
  2556. for (int64_t i01 = ir0; i01 < ir1; i01++) {
  2557. memcpy(
  2558. ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3),
  2559. ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03),
  2560. rs);
  2561. }
  2562. }
  2563. }
  2564. return;
  2565. }
  2566. // TODO: add more special-case implementations for tensor shapes/strides that can benefit from memcpy
  2567. if (ggml_is_contiguous(dst)) {
  2568. if (nb00 == sizeof(ggml_fp16_t)) {
  2569. if (dst->type == GGML_TYPE_F16) {
  2570. size_t id = 0;
  2571. const size_t rs = ne00 * nb00;
  2572. char * dst_ptr = (char *) dst->data;
  2573. for (int i03 = 0; i03 < ne03; i03++) {
  2574. for (int i02 = 0; i02 < ne02; i02++) {
  2575. id += rs * ir0;
  2576. for (int i01 = ir0; i01 < ir1; i01++) {
  2577. const char * src0_ptr = (char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03;
  2578. memcpy(dst_ptr + id, src0_ptr, rs);
  2579. id += rs;
  2580. }
  2581. id += rs * (ne01 - ir1);
  2582. }
  2583. }
  2584. } else if (dst->type == GGML_TYPE_F32) {
  2585. size_t id = 0;
  2586. float * dst_ptr = (float *) dst->data;
  2587. for (int i03 = 0; i03 < ne03; i03++) {
  2588. for (int i02 = 0; i02 < ne02; i02++) {
  2589. id += ne00 * ir0;
  2590. for (int i01 = ir0; i01 < ir1; i01++) {
  2591. const ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
  2592. for (int i00 = 0; i00 < ne00; i00++) {
  2593. dst_ptr[id] = GGML_FP16_TO_FP32(src0_ptr[i00]);
  2594. id++;
  2595. }
  2596. }
  2597. id += ne00 * (ne01 - ir1);
  2598. }
  2599. }
  2600. } else if (ggml_get_type_traits(dst->type)->from_float) {
  2601. ggml_from_float_t const quantize_row_q = ggml_get_type_traits(dst->type)->from_float;
  2602. float * src0_f32 = (float *) params->wdata + (ne00 + CACHE_LINE_SIZE_F32) * ith;
  2603. size_t id = 0;
  2604. size_t rs = nb0 * (ne00 / ggml_blck_size(dst->type));
  2605. char * dst_ptr = (char *) dst->data;
  2606. for (int i03 = 0; i03 < ne03; i03++) {
  2607. for (int i02 = 0; i02 < ne02; i02++) {
  2608. id += rs * ir0;
  2609. for (int i01 = ir0; i01 < ir1; i01++) {
  2610. const ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
  2611. for (int i00 = 0; i00 < ne00; i00++) {
  2612. src0_f32[i00] = GGML_FP16_TO_FP32(src0_ptr[i00]);
  2613. }
  2614. quantize_row_q(src0_f32, dst_ptr + id, ne00);
  2615. id += rs;
  2616. }
  2617. id += rs * (ne01 - ir1);
  2618. }
  2619. }
  2620. } else {
  2621. GGML_ABORT("fatal error"); // TODO: implement
  2622. }
  2623. } else {
  2624. //printf("%s: this is not optimal - fix me\n", __func__);
  2625. if (dst->type == GGML_TYPE_F32) {
  2626. size_t id = 0;
  2627. float * dst_ptr = (float *) dst->data;
  2628. for (int i03 = 0; i03 < ne03; i03++) {
  2629. for (int i02 = 0; i02 < ne02; i02++) {
  2630. id += ne00 * ir0;
  2631. for (int i01 = ir0; i01 < ir1; i01++) {
  2632. for (int i00 = 0; i00 < ne00; i00++) {
  2633. const ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
  2634. dst_ptr[id] = GGML_FP16_TO_FP32(*src0_ptr);
  2635. id++;
  2636. }
  2637. }
  2638. id += ne00 * (ne01 - ir1);
  2639. }
  2640. }
  2641. } else if (dst->type == GGML_TYPE_F16) {
  2642. size_t id = 0;
  2643. ggml_fp16_t * dst_ptr = (ggml_fp16_t *) dst->data;
  2644. for (int i03 = 0; i03 < ne03; i03++) {
  2645. for (int i02 = 0; i02 < ne02; i02++) {
  2646. id += ne00 * ir0;
  2647. for (int i01 = ir0; i01 < ir1; i01++) {
  2648. for (int i00 = 0; i00 < ne00; i00++) {
  2649. const ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
  2650. dst_ptr[id] = *src0_ptr;
  2651. id++;
  2652. }
  2653. }
  2654. id += ne00 * (ne01 - ir1);
  2655. }
  2656. }
  2657. } else {
  2658. GGML_ABORT("fatal error"); // TODO: implement
  2659. }
  2660. }
  2661. return;
  2662. }
  2663. // dst counters
  2664. int64_t i10 = 0;
  2665. int64_t i11 = 0;
  2666. int64_t i12 = 0;
  2667. int64_t i13 = 0;
  2668. if (dst->type == GGML_TYPE_F16) {
  2669. for (int64_t i03 = 0; i03 < ne03; i03++) {
  2670. for (int64_t i02 = 0; i02 < ne02; i02++) {
  2671. i10 += ne00 * ir0;
  2672. while (i10 >= ne0) {
  2673. i10 -= ne0;
  2674. if (++i11 == ne1) {
  2675. i11 = 0;
  2676. if (++i12 == ne2) {
  2677. i12 = 0;
  2678. if (++i13 == ne3) {
  2679. i13 = 0;
  2680. }
  2681. }
  2682. }
  2683. }
  2684. for (int64_t i01 = ir0; i01 < ir1; i01++) {
  2685. for (int64_t i00 = 0; i00 < ne00; i00++) {
  2686. const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
  2687. char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);
  2688. memcpy(dst_ptr, src0_ptr, sizeof(ggml_fp16_t));
  2689. if (++i10 == ne00) {
  2690. i10 = 0;
  2691. if (++i11 == ne01) {
  2692. i11 = 0;
  2693. if (++i12 == ne02) {
  2694. i12 = 0;
  2695. if (++i13 == ne03) {
  2696. i13 = 0;
  2697. }
  2698. }
  2699. }
  2700. }
  2701. }
  2702. }
  2703. i10 += ne00 * (ne01 - ir1);
  2704. while (i10 >= ne0) {
  2705. i10 -= ne0;
  2706. if (++i11 == ne1) {
  2707. i11 = 0;
  2708. if (++i12 == ne2) {
  2709. i12 = 0;
  2710. if (++i13 == ne3) {
  2711. i13 = 0;
  2712. }
  2713. }
  2714. }
  2715. }
  2716. }
  2717. }
  2718. } else if (dst->type == GGML_TYPE_F32) {
  2719. for (int64_t i03 = 0; i03 < ne03; i03++) {
  2720. for (int64_t i02 = 0; i02 < ne02; i02++) {
  2721. i10 += ne00 * ir0;
  2722. while (i10 >= ne0) {
  2723. i10 -= ne0;
  2724. if (++i11 == ne1) {
  2725. i11 = 0;
  2726. if (++i12 == ne2) {
  2727. i12 = 0;
  2728. if (++i13 == ne3) {
  2729. i13 = 0;
  2730. }
  2731. }
  2732. }
  2733. }
  2734. for (int64_t i01 = ir0; i01 < ir1; i01++) {
  2735. for (int64_t i00 = 0; i00 < ne00; i00++) {
  2736. const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
  2737. char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);
  2738. *(float *) dst_ptr = GGML_FP16_TO_FP32(*(const ggml_fp16_t *) src0_ptr);
  2739. if (++i10 == ne0) {
  2740. i10 = 0;
  2741. if (++i11 == ne1) {
  2742. i11 = 0;
  2743. if (++i12 == ne2) {
  2744. i12 = 0;
  2745. if (++i13 == ne3) {
  2746. i13 = 0;
  2747. }
  2748. }
  2749. }
  2750. }
  2751. }
  2752. }
  2753. i10 += ne00 * (ne01 - ir1);
  2754. while (i10 >= ne0) {
  2755. i10 -= ne0;
  2756. if (++i11 == ne1) {
  2757. i11 = 0;
  2758. if (++i12 == ne2) {
  2759. i12 = 0;
  2760. if (++i13 == ne3) {
  2761. i13 = 0;
  2762. }
  2763. }
  2764. }
  2765. }
  2766. }
  2767. }
  2768. } else {
  2769. GGML_ABORT("fatal error"); // TODO: implement
  2770. }
  2771. }
  2772. static void ggml_compute_forward_dup_bf16(
  2773. const struct ggml_compute_params * params,
  2774. struct ggml_tensor * dst) {
  2775. const struct ggml_tensor * src0 = dst->src[0];
  2776. GGML_ASSERT(ggml_nelements(dst) == ggml_nelements(src0));
  2777. GGML_TENSOR_UNARY_OP_LOCALS
  2778. const int ith = params->ith; // thread index
  2779. const int nth = params->nth; // number of threads
  2780. // parallelize by rows
  2781. const int nr = ne01;
  2782. // number of rows per thread
  2783. const int dr = (nr + nth - 1) / nth;
  2784. // row range for this thread
  2785. const int ir0 = dr * ith;
  2786. const int ir1 = MIN(ir0 + dr, nr);
  2787. if (src0->type == dst->type &&
  2788. ne00 == ne0 &&
  2789. nb00 == ggml_type_size(src0->type) && nb0 == ggml_type_size(dst->type)) {
  2790. // copy by rows
  2791. const size_t rs = ne00*nb00;
  2792. for (int64_t i03 = 0; i03 < ne03; i03++) {
  2793. for (int64_t i02 = 0; i02 < ne02; i02++) {
  2794. for (int64_t i01 = ir0; i01 < ir1; i01++) {
  2795. memcpy(
  2796. ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3),
  2797. ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03),
  2798. rs);
  2799. }
  2800. }
  2801. }
  2802. return;
  2803. }
  2804. // TODO: add more special-case implementations for tensor shapes/strides that can benefit from memcpy
  2805. if (ggml_is_contiguous(dst)) {
  2806. if (nb00 == sizeof(ggml_bf16_t)) {
  2807. if (dst->type == GGML_TYPE_BF16) {
  2808. size_t id = 0;
  2809. const size_t rs = ne00 * nb00;
  2810. char * dst_ptr = (char *) dst->data;
  2811. for (int i03 = 0; i03 < ne03; i03++) {
  2812. for (int i02 = 0; i02 < ne02; i02++) {
  2813. id += rs * ir0;
  2814. for (int i01 = ir0; i01 < ir1; i01++) {
  2815. const char * src0_ptr = (char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03;
  2816. memcpy(dst_ptr + id, src0_ptr, rs);
  2817. id += rs;
  2818. }
  2819. id += rs * (ne01 - ir1);
  2820. }
  2821. }
  2822. } else if (dst->type == GGML_TYPE_F16) {
  2823. size_t id = 0;
  2824. ggml_fp16_t * dst_ptr = (ggml_fp16_t *) dst->data;
  2825. for (int i03 = 0; i03 < ne03; i03++) {
  2826. for (int i02 = 0; i02 < ne02; i02++) {
  2827. id += ne00 * ir0;
  2828. for (int i01 = ir0; i01 < ir1; i01++) {
  2829. const ggml_bf16_t * src0_ptr = (ggml_bf16_t *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
  2830. for (int i00 = 0; i00 < ne00; i00++) {
  2831. dst_ptr[id] = GGML_FP32_TO_FP16(GGML_BF16_TO_FP32(src0_ptr[i00]));
  2832. id++;
  2833. }
  2834. }
  2835. id += ne00 * (ne01 - ir1);
  2836. }
  2837. }
  2838. } else if (dst->type == GGML_TYPE_F32) {
  2839. size_t id = 0;
  2840. float * dst_ptr = (float *) dst->data;
  2841. for (int i03 = 0; i03 < ne03; i03++) {
  2842. for (int i02 = 0; i02 < ne02; i02++) {
  2843. id += ne00 * ir0;
  2844. for (int i01 = ir0; i01 < ir1; i01++) {
  2845. const ggml_bf16_t * src0_ptr = (ggml_bf16_t *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
  2846. for (int i00 = 0; i00 < ne00; i00++) {
  2847. dst_ptr[id] = GGML_BF16_TO_FP32(src0_ptr[i00]);
  2848. id++;
  2849. }
  2850. }
  2851. id += ne00 * (ne01 - ir1);
  2852. }
  2853. }
  2854. } else if (ggml_get_type_traits(dst->type)->from_float) {
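// bf16 -> quantized: convert the row to f32 in this thread's slice of wdata, then
// quantize it straight into dst with the destination type's from_float routine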
  2855. ggml_from_float_t const quantize_row_q = ggml_get_type_traits(dst->type)->from_float;
  2856. float * src0_f32 = (float *) params->wdata + (ne00 + CACHE_LINE_SIZE_F32) * ith;
  2857. size_t id = 0;
  2858. size_t rs = nb0 * (ne00 / ggml_blck_size(dst->type));
  2859. char * dst_ptr = (char *) dst->data;
  2860. for (int i03 = 0; i03 < ne03; i03++) {
  2861. for (int i02 = 0; i02 < ne02; i02++) {
  2862. id += rs * ir0;
  2863. for (int i01 = ir0; i01 < ir1; i01++) {
  2864. const ggml_bf16_t * src0_ptr = (ggml_bf16_t *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
  2865. for (int i00 = 0; i00 < ne00; i00++) {
  2866. src0_f32[i00] = GGML_BF16_TO_FP32(src0_ptr[i00]);
  2867. }
  2868. quantize_row_q(src0_f32, dst_ptr + id, ne00);
  2869. id += rs;
  2870. }
  2871. id += rs * (ne01 - ir1);
  2872. }
  2873. }
  2874. } else {
  2875. GGML_ABORT("fatal error"); // TODO: implement
  2876. }
  2877. } else {
  2878. //printf("%s: this is not optimal - fix me\n", __func__);
  2879. if (dst->type == GGML_TYPE_F32) {
  2880. size_t id = 0;
  2881. float * dst_ptr = (float *) dst->data;
  2882. for (int i03 = 0; i03 < ne03; i03++) {
  2883. for (int i02 = 0; i02 < ne02; i02++) {
  2884. id += ne00 * ir0;
  2885. for (int i01 = ir0; i01 < ir1; i01++) {
  2886. for (int i00 = 0; i00 < ne00; i00++) {
  2887. const ggml_bf16_t * src0_ptr = (ggml_bf16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
  2888. dst_ptr[id] = GGML_BF16_TO_FP32(*src0_ptr);
  2889. id++;
  2890. }
  2891. }
  2892. id += ne00 * (ne01 - ir1);
  2893. }
  2894. }
  2895. } else if (dst->type == GGML_TYPE_BF16) {
  2896. size_t id = 0;
  2897. ggml_bf16_t * dst_ptr = (ggml_bf16_t *) dst->data;
  2898. for (int i03 = 0; i03 < ne03; i03++) {
  2899. for (int i02 = 0; i02 < ne02; i02++) {
  2900. id += ne00 * ir0;
  2901. for (int i01 = ir0; i01 < ir1; i01++) {
  2902. for (int i00 = 0; i00 < ne00; i00++) {
  2903. const ggml_bf16_t * src0_ptr = (ggml_bf16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
  2904. dst_ptr[id] = *src0_ptr;
  2905. id++;
  2906. }
  2907. }
  2908. id += ne00 * (ne01 - ir1);
  2909. }
  2910. }
  2911. } else if (dst->type == GGML_TYPE_F16) {
  2912. size_t id = 0;
  2913. ggml_fp16_t * dst_ptr = (ggml_fp16_t *) dst->data;
  2914. for (int i03 = 0; i03 < ne03; i03++) {
  2915. for (int i02 = 0; i02 < ne02; i02++) {
  2916. id += ne00 * ir0;
  2917. for (int i01 = ir0; i01 < ir1; i01++) {
  2918. for (int i00 = 0; i00 < ne00; i00++) {
  2919. const ggml_bf16_t * src0_ptr = (ggml_bf16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
  2920. dst_ptr[id] = GGML_FP32_TO_FP16(GGML_BF16_TO_FP32(*src0_ptr));
  2921. id++;
  2922. }
  2923. }
  2924. id += ne00 * (ne01 - ir1);
  2925. }
  2926. }
  2927. } else {
  2928. GGML_ABORT("fatal error"); // TODO: implement
  2929. }
  2930. }
  2931. return;
  2932. }
  2933. // dst counters
  2934. int64_t i10 = 0;
  2935. int64_t i11 = 0;
  2936. int64_t i12 = 0;
  2937. int64_t i13 = 0;
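// general case: dst can have a different shape/stride layout (same element count),
// so the write position is tracked with the explicit counters i10..i13, advanced one
// element at a time; the i10 += ne00*ir0 / ne00*(ne01 - ir1) adjustments skip the
// elements that belong to the other threads' row ranges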
  2938. if (dst->type == GGML_TYPE_BF16) {
  2939. for (int64_t i03 = 0; i03 < ne03; i03++) {
  2940. for (int64_t i02 = 0; i02 < ne02; i02++) {
  2941. i10 += ne00 * ir0;
  2942. while (i10 >= ne0) {
  2943. i10 -= ne0;
  2944. if (++i11 == ne1) {
  2945. i11 = 0;
  2946. if (++i12 == ne2) {
  2947. i12 = 0;
  2948. if (++i13 == ne3) {
  2949. i13 = 0;
  2950. }
  2951. }
  2952. }
  2953. }
  2954. for (int64_t i01 = ir0; i01 < ir1; i01++) {
  2955. for (int64_t i00 = 0; i00 < ne00; i00++) {
  2956. const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
  2957. char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);
  2958. memcpy(dst_ptr, src0_ptr, sizeof(ggml_bf16_t));
  2959. if (++i10 == ne00) {
  2960. i10 = 0;
  2961. if (++i11 == ne01) {
  2962. i11 = 0;
  2963. if (++i12 == ne02) {
  2964. i12 = 0;
  2965. if (++i13 == ne03) {
  2966. i13 = 0;
  2967. }
  2968. }
  2969. }
  2970. }
  2971. }
  2972. }
  2973. i10 += ne00 * (ne01 - ir1);
  2974. while (i10 >= ne0) {
  2975. i10 -= ne0;
  2976. if (++i11 == ne1) {
  2977. i11 = 0;
  2978. if (++i12 == ne2) {
  2979. i12 = 0;
  2980. if (++i13 == ne3) {
  2981. i13 = 0;
  2982. }
  2983. }
  2984. }
  2985. }
  2986. }
  2987. }
  2988. } else if (dst->type == GGML_TYPE_F16) {
  2989. for (int64_t i03 = 0; i03 < ne03; i03++) {
  2990. for (int64_t i02 = 0; i02 < ne02; i02++) {
  2991. i10 += ne00 * ir0;
  2992. while (i10 >= ne0) {
  2993. i10 -= ne0;
  2994. if (++i11 == ne1) {
  2995. i11 = 0;
  2996. if (++i12 == ne2) {
  2997. i12 = 0;
  2998. if (++i13 == ne3) {
  2999. i13 = 0;
  3000. }
  3001. }
  3002. }
  3003. }
  3004. for (int64_t i01 = ir0; i01 < ir1; i01++) {
  3005. for (int64_t i00 = 0; i00 < ne00; i00++) {
  3006. const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
  3007. char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);
  3008. *(ggml_fp16_t *) dst_ptr = GGML_FP32_TO_FP16(GGML_BF16_TO_FP32(*(const ggml_bf16_t *) src0_ptr));
  3009. if (++i10 == ne0) {
  3010. i10 = 0;
  3011. if (++i11 == ne1) {
  3012. i11 = 0;
  3013. if (++i12 == ne2) {
  3014. i12 = 0;
  3015. if (++i13 == ne3) {
  3016. i13 = 0;
  3017. }
  3018. }
  3019. }
  3020. }
  3021. }
  3022. }
  3023. i10 += ne00 * (ne01 - ir1);
  3024. while (i10 >= ne0) {
  3025. i10 -= ne0;
  3026. if (++i11 == ne1) {
  3027. i11 = 0;
  3028. if (++i12 == ne2) {
  3029. i12 = 0;
  3030. if (++i13 == ne3) {
  3031. i13 = 0;
  3032. }
  3033. }
  3034. }
  3035. }
  3036. }
  3037. }
  3038. } else if (dst->type == GGML_TYPE_F32) {
  3039. for (int64_t i03 = 0; i03 < ne03; i03++) {
  3040. for (int64_t i02 = 0; i02 < ne02; i02++) {
  3041. i10 += ne00 * ir0;
  3042. while (i10 >= ne0) {
  3043. i10 -= ne0;
  3044. if (++i11 == ne1) {
  3045. i11 = 0;
  3046. if (++i12 == ne2) {
  3047. i12 = 0;
  3048. if (++i13 == ne3) {
  3049. i13 = 0;
  3050. }
  3051. }
  3052. }
  3053. }
  3054. for (int64_t i01 = ir0; i01 < ir1; i01++) {
  3055. for (int64_t i00 = 0; i00 < ne00; i00++) {
  3056. const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
  3057. char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);
  3058. *(float *) dst_ptr = GGML_BF16_TO_FP32(*(const ggml_bf16_t *) src0_ptr);
  3059. if (++i10 == ne0) {
  3060. i10 = 0;
  3061. if (++i11 == ne1) {
  3062. i11 = 0;
  3063. if (++i12 == ne2) {
  3064. i12 = 0;
  3065. if (++i13 == ne3) {
  3066. i13 = 0;
  3067. }
  3068. }
  3069. }
  3070. }
  3071. }
  3072. }
  3073. i10 += ne00 * (ne01 - ir1);
  3074. while (i10 >= ne0) {
  3075. i10 -= ne0;
  3076. if (++i11 == ne1) {
  3077. i11 = 0;
  3078. if (++i12 == ne2) {
  3079. i12 = 0;
  3080. if (++i13 == ne3) {
  3081. i13 = 0;
  3082. }
  3083. }
  3084. }
  3085. }
  3086. }
  3087. }
  3088. } else {
  3089. GGML_ABORT("fatal error"); // TODO: implement
  3090. }
  3091. }
  3092. static void ggml_compute_forward_dup_f32(
  3093. const struct ggml_compute_params * params,
  3094. struct ggml_tensor * dst) {
  3095. const struct ggml_tensor * src0 = dst->src[0];
  3096. GGML_ASSERT(ggml_nelements(dst) == ggml_nelements(src0));
  3097. GGML_TENSOR_UNARY_OP_LOCALS
  3098. const int ith = params->ith; // thread index
  3099. const int nth = params->nth; // number of threads
  3100. // parallelize by rows
  3101. const int nr = ne01;
  3102. // number of rows per thread
  3103. const int dr = (nr + nth - 1) / nth;
  3104. // row range for this thread
  3105. const int ir0 = dr * ith;
  3106. const int ir1 = MIN(ir0 + dr, nr);
  3107. if (src0->type == dst->type &&
  3108. ne00 == ne0 &&
  3109. nb00 == ggml_type_size(src0->type) && nb0 == ggml_type_size(dst->type)) {
  3110. // copy by rows
  3111. const size_t rs = ne00*nb00;
  3112. for (int64_t i03 = 0; i03 < ne03; i03++) {
  3113. for (int64_t i02 = 0; i02 < ne02; i02++) {
  3114. for (int64_t i01 = ir0; i01 < ir1; i01++) {
  3115. memcpy(
  3116. ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3),
  3117. ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03),
  3118. rs);
  3119. }
  3120. }
  3121. }
  3122. return;
  3123. }
  3124. if (ggml_is_contiguous(dst)) {
  3125. // TODO: simplify
  3126. if (nb00 == sizeof(float)) {
  3127. if (dst->type == GGML_TYPE_F32) {
  3128. size_t id = 0;
  3129. const size_t rs = ne00 * nb00;
  3130. char * dst_ptr = (char *) dst->data;
  3131. for (int i03 = 0; i03 < ne03; i03++) {
  3132. for (int i02 = 0; i02 < ne02; i02++) {
  3133. id += rs * ir0;
  3134. for (int i01 = ir0; i01 < ir1; i01++) {
  3135. const char * src0_ptr = (char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03;
  3136. memcpy(dst_ptr + id, src0_ptr, rs);
  3137. id += rs;
  3138. }
  3139. id += rs * (ne01 - ir1);
  3140. }
  3141. }
  3142. } else if (ggml_get_type_traits(dst->type)->from_float) {
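// f32 -> quantized: rows are fed directly to the destination type's from_float
// routine; rs below is the size in bytes of one quantized destination row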
  3143. ggml_from_float_t const quantize_row_q = ggml_get_type_traits(dst->type)->from_float;
  3144. size_t id = 0;
  3145. size_t rs = nb0 * (ne00 / ggml_blck_size(dst->type));
  3146. char * dst_ptr = (char *) dst->data;
  3147. for (int i03 = 0; i03 < ne03; i03++) {
  3148. for (int i02 = 0; i02 < ne02; i02++) {
  3149. id += rs * ir0;
  3150. for (int i01 = ir0; i01 < ir1; i01++) {
  3151. const float * src0_ptr = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
  3152. quantize_row_q(src0_ptr, dst_ptr + id, ne00);
  3153. id += rs;
  3154. }
  3155. id += rs * (ne01 - ir1);
  3156. }
  3157. }
  3158. } else {
  3159. GGML_ABORT("fatal error"); // TODO: implement
  3160. }
  3161. } else {
  3162. //printf("%s: this is not optimal - fix me\n", __func__);
  3163. if (dst->type == GGML_TYPE_F32) {
  3164. size_t id = 0;
  3165. float * dst_ptr = (float *) dst->data;
  3166. for (int i03 = 0; i03 < ne03; i03++) {
  3167. for (int i02 = 0; i02 < ne02; i02++) {
  3168. id += ne00 * ir0;
  3169. for (int i01 = ir0; i01 < ir1; i01++) {
  3170. for (int i00 = 0; i00 < ne00; i00++) {
  3171. const float * src0_ptr = (float *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
  3172. dst_ptr[id] = *src0_ptr;
  3173. id++;
  3174. }
  3175. }
  3176. id += ne00 * (ne01 - ir1);
  3177. }
  3178. }
  3179. } else if (dst->type == GGML_TYPE_F16) {
  3180. size_t id = 0;
  3181. ggml_fp16_t * dst_ptr = (ggml_fp16_t *) dst->data;
  3182. for (int i03 = 0; i03 < ne03; i03++) {
  3183. for (int i02 = 0; i02 < ne02; i02++) {
  3184. id += ne00 * ir0;
  3185. for (int i01 = ir0; i01 < ir1; i01++) {
  3186. for (int i00 = 0; i00 < ne00; i00++) {
  3187. const float * src0_ptr = (float *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
  3188. dst_ptr[id] = GGML_FP32_TO_FP16(*src0_ptr);
  3189. id++;
  3190. }
  3191. }
  3192. id += ne00 * (ne01 - ir1);
  3193. }
  3194. }
  3195. } else if (dst->type == GGML_TYPE_BF16) {
  3196. size_t id = 0;
  3197. ggml_bf16_t * dst_ptr = (ggml_bf16_t *) dst->data;
  3198. for (int i03 = 0; i03 < ne03; i03++) {
  3199. for (int i02 = 0; i02 < ne02; i02++) {
  3200. id += ne00 * ir0;
  3201. for (int i01 = ir0; i01 < ir1; i01++) {
  3202. for (int i00 = 0; i00 < ne00; i00++) {
  3203. const float * src0_ptr = (float *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
  3204. dst_ptr[id] = GGML_FP32_TO_BF16(*src0_ptr);
  3205. id++;
  3206. }
  3207. }
  3208. id += ne00 * (ne01 - ir1);
  3209. }
  3210. }
  3211. } else {
  3212. GGML_ABORT("fatal error"); // TODO: implement
  3213. }
  3214. }
  3215. return;
  3216. }
  3217. // dst counters
  3218. int64_t i10 = 0;
  3219. int64_t i11 = 0;
  3220. int64_t i12 = 0;
  3221. int64_t i13 = 0;
  3222. if (dst->type == GGML_TYPE_F32) {
  3223. for (int64_t i03 = 0; i03 < ne03; i03++) {
  3224. for (int64_t i02 = 0; i02 < ne02; i02++) {
  3225. i10 += ne00 * ir0;
  3226. while (i10 >= ne0) {
  3227. i10 -= ne0;
  3228. if (++i11 == ne1) {
  3229. i11 = 0;
  3230. if (++i12 == ne2) {
  3231. i12 = 0;
  3232. if (++i13 == ne3) {
  3233. i13 = 0;
  3234. }
  3235. }
  3236. }
  3237. }
  3238. for (int64_t i01 = ir0; i01 < ir1; i01++) {
  3239. for (int64_t i00 = 0; i00 < ne00; i00++) {
  3240. const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
  3241. char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);
  3242. memcpy(dst_ptr, src0_ptr, sizeof(float));
  3243. if (++i10 == ne0) {
  3244. i10 = 0;
  3245. if (++i11 == ne1) {
  3246. i11 = 0;
  3247. if (++i12 == ne2) {
  3248. i12 = 0;
  3249. if (++i13 == ne3) {
  3250. i13 = 0;
  3251. }
  3252. }
  3253. }
  3254. }
  3255. }
  3256. }
  3257. i10 += ne00 * (ne01 - ir1);
  3258. while (i10 >= ne0) {
  3259. i10 -= ne0;
  3260. if (++i11 == ne1) {
  3261. i11 = 0;
  3262. if (++i12 == ne2) {
  3263. i12 = 0;
  3264. if (++i13 == ne3) {
  3265. i13 = 0;
  3266. }
  3267. }
  3268. }
  3269. }
  3270. }
  3271. }
  3272. } else if (dst->type == GGML_TYPE_F16) {
  3273. for (int64_t i03 = 0; i03 < ne03; i03++) {
  3274. for (int64_t i02 = 0; i02 < ne02; i02++) {
  3275. i10 += ne00 * ir0;
  3276. while (i10 >= ne0) {
  3277. i10 -= ne0;
  3278. if (++i11 == ne1) {
  3279. i11 = 0;
  3280. if (++i12 == ne2) {
  3281. i12 = 0;
  3282. if (++i13 == ne3) {
  3283. i13 = 0;
  3284. }
  3285. }
  3286. }
  3287. }
  3288. for (int64_t i01 = ir0; i01 < ir1; i01++) {
  3289. for (int64_t i00 = 0; i00 < ne00; i00++) {
  3290. const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
  3291. char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);
  3292. *(ggml_fp16_t *) dst_ptr = GGML_FP32_TO_FP16(*(const float *) src0_ptr);
  3293. if (++i10 == ne0) {
  3294. i10 = 0;
  3295. if (++i11 == ne1) {
  3296. i11 = 0;
  3297. if (++i12 == ne2) {
  3298. i12 = 0;
  3299. if (++i13 == ne3) {
  3300. i13 = 0;
  3301. }
  3302. }
  3303. }
  3304. }
  3305. }
  3306. }
  3307. i10 += ne00 * (ne01 - ir1);
  3308. while (i10 >= ne0) {
  3309. i10 -= ne0;
  3310. if (++i11 == ne1) {
  3311. i11 = 0;
  3312. if (++i12 == ne2) {
  3313. i12 = 0;
  3314. if (++i13 == ne3) {
  3315. i13 = 0;
  3316. }
  3317. }
  3318. }
  3319. }
  3320. }
  3321. }
  3322. } else if (dst->type == GGML_TYPE_BF16) {
  3323. for (int64_t i03 = 0; i03 < ne03; i03++) {
  3324. for (int64_t i02 = 0; i02 < ne02; i02++) {
  3325. i10 += ne00 * ir0;
  3326. while (i10 >= ne0) {
  3327. i10 -= ne0;
  3328. if (++i11 == ne1) {
  3329. i11 = 0;
  3330. if (++i12 == ne2) {
  3331. i12 = 0;
  3332. if (++i13 == ne3) {
  3333. i13 = 0;
  3334. }
  3335. }
  3336. }
  3337. }
  3338. for (int64_t i01 = ir0; i01 < ir1; i01++) {
  3339. for (int64_t i00 = 0; i00 < ne00; i00++) {
  3340. const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
  3341. char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);
  3342. *(ggml_bf16_t *) dst_ptr = GGML_FP32_TO_BF16(*(const float *) src0_ptr);
  3343. if (++i10 == ne0) {
  3344. i10 = 0;
  3345. if (++i11 == ne1) {
  3346. i11 = 0;
  3347. if (++i12 == ne2) {
  3348. i12 = 0;
  3349. if (++i13 == ne3) {
  3350. i13 = 0;
  3351. }
  3352. }
  3353. }
  3354. }
  3355. }
  3356. }
  3357. i10 += ne00 * (ne01 - ir1);
  3358. while (i10 >= ne0) {
  3359. i10 -= ne0;
  3360. if (++i11 == ne1) {
  3361. i11 = 0;
  3362. if (++i12 == ne2) {
  3363. i12 = 0;
  3364. if (++i13 == ne3) {
  3365. i13 = 0;
  3366. }
  3367. }
  3368. }
  3369. }
  3370. }
  3371. }
  3372. } else {
  3373. GGML_ABORT("fatal error"); // TODO: implement
  3374. }
  3375. }
// A simplified version of ggml_compute_forward_dup that doesn't do float upcasting and just uses plain old memcpy.
  3377. static void ggml_compute_forward_dup_bytes(
  3378. const struct ggml_compute_params * params,
  3379. struct ggml_tensor * dst) {
  3380. const struct ggml_tensor * src0 = dst->src[0];
  3381. GGML_ASSERT(ggml_nelements(dst) == ggml_nelements(src0));
  3382. GGML_ASSERT(src0->type == dst->type);
  3383. GGML_TENSOR_UNARY_OP_LOCALS;
  3384. if (ggml_is_contiguous(src0) && ggml_is_contiguous(dst)) {
  3385. ggml_compute_forward_dup_same_cont(params, dst);
  3386. return;
  3387. }
  3388. const size_t type_size = ggml_type_size(src0->type);
  3389. const int ith = params->ith; // thread index
  3390. const int nth = params->nth; // number of threads
  3391. // parallelize by rows
  3392. const int nr = ne01;
  3393. // number of rows per thread
  3394. const int dr = (nr + nth - 1) / nth;
  3395. // row range for this thread
  3396. const int ir0 = dr * ith;
  3397. const int ir1 = MIN(ir0 + dr, nr);
  3398. if (src0->type == dst->type &&
  3399. ne00 == ne0 &&
  3400. nb00 == type_size && nb0 == type_size) {
  3401. // copy by rows
  3402. const size_t rs = ne00 * type_size;
  3403. for (int64_t i03 = 0; i03 < ne03; i03++) {
  3404. for (int64_t i02 = 0; i02 < ne02; i02++) {
  3405. for (int64_t i01 = ir0; i01 < ir1; i01++) {
  3406. memcpy(
  3407. ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3),
  3408. ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03),
  3409. rs);
  3410. }
  3411. }
  3412. }
  3413. return;
  3414. }
  3415. if (ggml_is_contiguous(dst)) {
  3416. size_t id = 0;
  3417. char * dst_ptr = (char *) dst->data;
  3418. const size_t rs = ne00 * type_size;
  3419. if (nb00 == type_size) {
// src0 is contiguous on first dimension, copy by rows
  3421. for (int64_t i03 = 0; i03 < ne03; i03++) {
  3422. for (int64_t i02 = 0; i02 < ne02; i02++) {
  3423. id += rs * ir0;
  3424. for (int64_t i01 = ir0; i01 < ir1; i01++) {
  3425. const char * src0_ptr = (char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03;
  3426. memcpy(dst_ptr + id, src0_ptr, rs);
  3427. id += rs;
  3428. }
  3429. id += rs * (ne01 - ir1);
  3430. }
  3431. }
  3432. } else {
  3433. //printf("%s: this is not optimal - fix me\n", __func__);
  3434. for (int64_t i03 = 0; i03 < ne03; i03++) {
  3435. for (int64_t i02 = 0; i02 < ne02; i02++) {
  3436. id += rs * ir0;
  3437. for (int64_t i01 = ir0; i01 < ir1; i01++) {
  3438. for (int64_t i00 = 0; i00 < ne00; i00++) {
  3439. const char * src0_ptr = (char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03;
  3440. memcpy(dst_ptr + id, src0_ptr, type_size);
  3441. id += type_size;
  3442. }
  3443. }
  3444. id += rs * (ne01 - ir1);
  3445. }
  3446. }
  3447. }
  3448. return;
  3449. }
  3450. // dst counters
  3451. int64_t i10 = 0;
  3452. int64_t i11 = 0;
  3453. int64_t i12 = 0;
  3454. int64_t i13 = 0;
  3455. for (int64_t i03 = 0; i03 < ne03; i03++) {
  3456. for (int64_t i02 = 0; i02 < ne02; i02++) {
  3457. i10 += ne00 * ir0;
  3458. while (i10 >= ne0) {
  3459. i10 -= ne0;
  3460. if (++i11 == ne1) {
  3461. i11 = 0;
  3462. if (++i12 == ne2) {
  3463. i12 = 0;
  3464. if (++i13 == ne3) {
  3465. i13 = 0;
  3466. }
  3467. }
  3468. }
  3469. }
  3470. for (int64_t i01 = ir0; i01 < ir1; i01++) {
  3471. for (int64_t i00 = 0; i00 < ne00; i00++) {
  3472. const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
  3473. char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);
  3474. memcpy(dst_ptr, src0_ptr, type_size);
  3475. if (++i10 == ne0) {
  3476. i10 = 0;
  3477. if (++i11 == ne1) {
  3478. i11 = 0;
  3479. if (++i12 == ne2) {
  3480. i12 = 0;
  3481. if (++i13 == ne3) {
  3482. i13 = 0;
  3483. }
  3484. }
  3485. }
  3486. }
  3487. }
  3488. }
  3489. i10 += ne00 * (ne01 - ir1);
  3490. while (i10 >= ne0) {
  3491. i10 -= ne0;
  3492. if (++i11 == ne1) {
  3493. i11 = 0;
  3494. if (++i12 == ne2) {
  3495. i12 = 0;
  3496. if (++i13 == ne3) {
  3497. i13 = 0;
  3498. }
  3499. }
  3500. }
  3501. }
  3502. }
  3503. }
  3504. }
  3505. static void ggml_compute_forward_dup(
  3506. const struct ggml_compute_params * params,
  3507. struct ggml_tensor * dst) {
  3508. const struct ggml_tensor * src0 = dst->src[0];
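// same-type copies never need element conversion, so they take the byte-wise path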
  3509. if (src0->type == dst->type) {
  3510. ggml_compute_forward_dup_bytes(params, dst);
  3511. return;
  3512. }
  3513. switch (src0->type) {
  3514. case GGML_TYPE_F16:
  3515. {
  3516. ggml_compute_forward_dup_f16(params, dst);
  3517. } break;
  3518. case GGML_TYPE_BF16:
  3519. {
  3520. ggml_compute_forward_dup_bf16(params, dst);
  3521. } break;
  3522. case GGML_TYPE_F32:
  3523. {
  3524. ggml_compute_forward_dup_f32(params, dst);
  3525. } break;
  3526. default:
  3527. {
  3528. GGML_ABORT("fatal error");
  3529. }
  3530. }
  3531. }
  3532. // ggml_compute_forward_add
  3533. static void ggml_compute_forward_add_f32(
  3534. const struct ggml_compute_params * params,
  3535. struct ggml_tensor * dst) {
  3536. const struct ggml_tensor * src0 = dst->src[0];
  3537. const struct ggml_tensor * src1 = dst->src[1];
  3538. GGML_ASSERT(ggml_can_repeat(src1, src0) && ggml_are_same_shape(src0, dst));
  3539. const int ith = params->ith;
  3540. const int nth = params->nth;
  3541. const int nr = ggml_nrows(src0);
  3542. GGML_TENSOR_BINARY_OP_LOCALS
  3543. GGML_ASSERT( nb0 == sizeof(float));
  3544. GGML_ASSERT(nb00 == sizeof(float));
  3545. // rows per thread
  3546. const int dr = (nr + nth - 1)/nth;
  3547. // row range for this thread
  3548. const int ir0 = dr*ith;
  3549. const int ir1 = MIN(ir0 + dr, nr);
  3550. if (nb10 == sizeof(float)) {
  3551. for (int ir = ir0; ir < ir1; ++ir) {
  3552. // src1 is broadcastable across src0 and dst in i1, i2, i3
  3553. const int64_t i03 = ir/(ne02*ne01);
  3554. const int64_t i02 = (ir - i03*ne02*ne01)/ne01;
  3555. const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01);
  3556. const int64_t i13 = i03 % ne13;
  3557. const int64_t i12 = i02 % ne12;
  3558. const int64_t i11 = i01 % ne11;
  3559. const int64_t nr0 = ne00 / ne10;
  3560. float * dst_ptr = (float *) ((char *) dst->data + i03*nb3 + i02*nb2 + i01*nb1 );
  3561. float * src0_ptr = (float *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01);
  3562. float * src1_ptr = (float *) ((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11);
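// src1 has ne10 columns; the same src1 row is added to each of the nr0 consecutive
// ne10-element chunks of the src0/dst row (broadcast along dim 0)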
  3563. for (int64_t r = 0; r < nr0; ++r) {
  3564. #ifdef GGML_USE_ACCELERATE
  3565. vDSP_vadd(src0_ptr + r*ne10, 1, src1_ptr, 1, dst_ptr + r*ne10, 1, ne10);
  3566. #else
  3567. ggml_vec_add_f32(ne10, dst_ptr + r*ne10, src0_ptr + r*ne10, src1_ptr);
  3568. #endif
  3569. }
  3570. }
  3571. } else {
  3572. // src1 is not contiguous
  3573. for (int ir = ir0; ir < ir1; ++ir) {
  3574. // src1 is broadcastable across src0 and dst in i1, i2, i3
  3575. const int64_t i03 = ir/(ne02*ne01);
  3576. const int64_t i02 = (ir - i03*ne02*ne01)/ne01;
  3577. const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01);
  3578. const int64_t i13 = i03 % ne13;
  3579. const int64_t i12 = i02 % ne12;
  3580. const int64_t i11 = i01 % ne11;
  3581. float * dst_ptr = (float *) ((char *) dst->data + i03*nb3 + i02*nb2 + i01*nb1 );
  3582. float * src0_ptr = (float *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01);
  3583. for (int64_t i0 = 0; i0 < ne0; ++i0) {
  3584. const int64_t i10 = i0 % ne10;
  3585. float * src1_ptr = (float *) ((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11 + i10*nb10);
  3586. dst_ptr[i0] = src0_ptr[i0] + *src1_ptr;
  3587. }
  3588. }
  3589. }
  3590. }
  3591. static void ggml_compute_forward_add_f16_f32(
  3592. const struct ggml_compute_params * params,
  3593. struct ggml_tensor * dst) {
  3594. const struct ggml_tensor * src0 = dst->src[0];
  3595. const struct ggml_tensor * src1 = dst->src[1];
  3596. GGML_ASSERT(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst));
  3597. const int ith = params->ith;
  3598. const int nth = params->nth;
  3599. const int nr = ggml_nrows(src0);
  3600. GGML_TENSOR_BINARY_OP_LOCALS
  3601. GGML_ASSERT(src0->type == GGML_TYPE_F16);
  3602. GGML_ASSERT(src1->type == GGML_TYPE_F32);
  3603. if (dst->type == GGML_TYPE_F32) {
  3604. GGML_ASSERT( nb0 == sizeof(float));
  3605. }
  3606. else {
  3607. GGML_ASSERT(dst->type == GGML_TYPE_F16);
  3608. GGML_ASSERT( nb0 == sizeof(ggml_fp16_t));
  3609. }
  3610. GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
  3611. // rows per thread
  3612. const int dr = (nr + nth - 1)/nth;
  3613. // row range for this thread
  3614. const int ir0 = dr*ith;
  3615. const int ir1 = MIN(ir0 + dr, nr);
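// f16 src0 is upcast to f32, added to the f32 src1, and stored as f16 or f32
// depending on dst->type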
  3616. if (nb10 == sizeof(float)) {
  3617. if (dst->type == GGML_TYPE_F16) {
  3618. for (int ir = ir0; ir < ir1; ++ir) {
  3619. // src0, src1 and dst are same shape => same indices
  3620. const int i3 = ir/(ne2*ne1);
  3621. const int i2 = (ir - i3*ne2*ne1)/ne1;
  3622. const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
  3623. ggml_fp16_t * dst_ptr = (ggml_fp16_t *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1);
  3624. ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01);
  3625. float * src1_ptr = (float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11);
  3626. for (int i = 0; i < ne0; i++) {
  3627. dst_ptr[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(src0_ptr[i]) + src1_ptr[i]);
  3628. }
  3629. }
  3630. } else {
  3631. for (int ir = ir0; ir < ir1; ++ir) {
  3632. // src0, src1 and dst are same shape => same indices
  3633. const int i3 = ir/(ne2*ne1);
  3634. const int i2 = (ir - i3*ne2*ne1)/ne1;
  3635. const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
  3636. float * dst_ptr = (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1);
  3637. ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01);
  3638. float * src1_ptr = (float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11);
  3639. for (int i = 0; i < ne0; i++) {
  3640. dst_ptr[i] = GGML_FP16_TO_FP32(src0_ptr[i]) + src1_ptr[i];
  3641. }
  3642. }
  3643. }
  3644. }
  3645. else {
  3646. // src1 is not contiguous
  3647. GGML_ABORT("fatal error");
  3648. }
  3649. }
  3650. static void ggml_compute_forward_add_bf16_f32(
  3651. const struct ggml_compute_params * params,
  3652. struct ggml_tensor * dst) {
  3653. const struct ggml_tensor * src0 = dst->src[0];
  3654. const struct ggml_tensor * src1 = dst->src[1];
  3655. GGML_ASSERT(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst));
  3656. const int ith = params->ith;
  3657. const int nth = params->nth;
  3658. const int nr = ggml_nrows(src0);
  3659. GGML_TENSOR_BINARY_OP_LOCALS
  3660. GGML_ASSERT(src0->type == GGML_TYPE_BF16);
  3661. GGML_ASSERT(src1->type == GGML_TYPE_F32);
  3662. if (dst->type == GGML_TYPE_F32) {
  3663. GGML_ASSERT( nb0 == sizeof(float));
  3664. }
  3665. else {
  3666. GGML_ASSERT(dst->type == GGML_TYPE_BF16);
  3667. GGML_ASSERT( nb0 == sizeof(ggml_bf16_t));
  3668. }
  3669. GGML_ASSERT(nb00 == sizeof(ggml_bf16_t));
  3670. // rows per thread
  3671. const int dr = (nr + nth - 1)/nth;
  3672. // row range for this thread
  3673. const int ir0 = dr*ith;
  3674. const int ir1 = MIN(ir0 + dr, nr);
  3675. if (nb10 == sizeof(float)) {
  3676. if (dst->type == GGML_TYPE_BF16) {
  3677. for (int ir = ir0; ir < ir1; ++ir) {
  3678. // src0, src1 and dst are same shape => same indices
  3679. const int i3 = ir/(ne2*ne1);
  3680. const int i2 = (ir - i3*ne2*ne1)/ne1;
  3681. const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
  3682. ggml_bf16_t * dst_ptr = (ggml_bf16_t *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1);
  3683. ggml_bf16_t * src0_ptr = (ggml_bf16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01);
  3684. float * src1_ptr = (float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11);
  3685. for (int i = 0; i < ne0; i++) {
  3686. dst_ptr[i] = GGML_FP32_TO_BF16(GGML_BF16_TO_FP32(src0_ptr[i]) + src1_ptr[i]);
  3687. }
  3688. }
  3689. } else {
  3690. for (int ir = ir0; ir < ir1; ++ir) {
  3691. // src0, src1 and dst are same shape => same indices
  3692. const int i3 = ir/(ne2*ne1);
  3693. const int i2 = (ir - i3*ne2*ne1)/ne1;
  3694. const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
  3695. float * dst_ptr = (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1);
  3696. ggml_bf16_t * src0_ptr = (ggml_bf16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01);
  3697. float * src1_ptr = (float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11);
  3698. for (int i = 0; i < ne0; i++) {
  3699. dst_ptr[i] = GGML_BF16_TO_FP32(src0_ptr[i]) + src1_ptr[i];
  3700. }
  3701. }
  3702. }
  3703. }
  3704. else {
  3705. // src1 is not contiguous
  3706. GGML_ABORT("fatal error");
  3707. }
  3708. }
  3709. static void ggml_compute_forward_add_f16_f16(
  3710. const struct ggml_compute_params * params,
  3711. struct ggml_tensor * dst) {
  3712. const struct ggml_tensor * src0 = dst->src[0];
  3713. const struct ggml_tensor * src1 = dst->src[1];
  3714. GGML_ASSERT(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst));
  3715. const int ith = params->ith;
  3716. const int nth = params->nth;
  3717. const int nr = ggml_nrows(src0);
  3718. GGML_TENSOR_BINARY_OP_LOCALS
  3719. GGML_ASSERT(src0->type == GGML_TYPE_F16);
  3720. GGML_ASSERT(src1->type == GGML_TYPE_F16);
  3721. GGML_ASSERT(dst->type == GGML_TYPE_F16);
  3722. GGML_ASSERT( nb0 == sizeof(ggml_fp16_t));
  3723. GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
  3724. // rows per thread
  3725. const int dr = (nr + nth - 1)/nth;
  3726. // row range for this thread
  3727. const int ir0 = dr*ith;
  3728. const int ir1 = MIN(ir0 + dr, nr);
  3729. if (nb10 == sizeof(ggml_fp16_t)) {
  3730. for (int ir = ir0; ir < ir1; ++ir) {
  3731. // src0, src1 and dst are same shape => same indices
  3732. const int i3 = ir/(ne2*ne1);
  3733. const int i2 = (ir - i3*ne2*ne1)/ne1;
  3734. const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
  3735. ggml_fp16_t * dst_ptr = (ggml_fp16_t *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1);
  3736. ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01);
  3737. ggml_fp16_t * src1_ptr = (ggml_fp16_t *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11);
  3738. for (int i = 0; i < ne0; i++) {
  3739. dst_ptr[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(src0_ptr[i]) + GGML_FP16_TO_FP32(src1_ptr[i]));
  3740. }
  3741. }
  3742. }
  3743. else {
  3744. // src1 is not contiguous
  3745. GGML_ABORT("fatal error");
  3746. }
  3747. }
  3748. static void ggml_compute_forward_add_bf16_bf16(
  3749. const struct ggml_compute_params * params,
  3750. struct ggml_tensor * dst) {
  3751. const struct ggml_tensor * src0 = dst->src[0];
  3752. const struct ggml_tensor * src1 = dst->src[1];
  3753. GGML_ASSERT(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst));
  3754. const int ith = params->ith;
  3755. const int nth = params->nth;
  3756. const int nr = ggml_nrows(src0);
  3757. GGML_TENSOR_BINARY_OP_LOCALS
  3758. GGML_ASSERT(src0->type == GGML_TYPE_BF16);
  3759. GGML_ASSERT(src1->type == GGML_TYPE_BF16);
  3760. GGML_ASSERT(dst->type == GGML_TYPE_BF16);
  3761. GGML_ASSERT( nb0 == sizeof(ggml_bf16_t));
  3762. GGML_ASSERT(nb00 == sizeof(ggml_bf16_t));
  3763. // rows per thread
  3764. const int dr = (nr + nth - 1)/nth;
  3765. // row range for this thread
  3766. const int ir0 = dr*ith;
  3767. const int ir1 = MIN(ir0 + dr, nr);
  3768. if (nb10 == sizeof(ggml_bf16_t)) {
  3769. for (int ir = ir0; ir < ir1; ++ir) {
  3770. // src0, src1 and dst are same shape => same indices
  3771. const int i3 = ir/(ne2*ne1);
  3772. const int i2 = (ir - i3*ne2*ne1)/ne1;
  3773. const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
  3774. ggml_bf16_t * dst_ptr = (ggml_bf16_t *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1);
  3775. ggml_bf16_t * src0_ptr = (ggml_bf16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01);
  3776. ggml_bf16_t * src1_ptr = (ggml_bf16_t *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11);
  3777. for (int i = 0; i < ne0; i++) {
  3778. dst_ptr[i] = GGML_FP32_TO_BF16(GGML_BF16_TO_FP32(src0_ptr[i]) + GGML_BF16_TO_FP32(src1_ptr[i]));
  3779. }
  3780. }
  3781. }
  3782. else {
  3783. // src1 is not contiguous
  3784. GGML_ABORT("fatal error");
  3785. }
  3786. }
  3787. static void ggml_compute_forward_add_q_f32(
  3788. const struct ggml_compute_params * params,
  3789. struct ggml_tensor * dst) {
  3790. const struct ggml_tensor * src0 = dst->src[0];
  3791. const struct ggml_tensor * src1 = dst->src[1];
  3792. GGML_ASSERT(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst));
  3793. const int nr = ggml_nrows(src0);
  3794. GGML_TENSOR_BINARY_OP_LOCALS
  3795. const int ith = params->ith;
  3796. const int nth = params->nth;
  3797. const enum ggml_type type = src0->type;
  3798. const enum ggml_type dtype = dst->type;
  3799. ggml_to_float_t const dequantize_row_q = ggml_get_type_traits(type)->to_float;
  3800. ggml_from_float_t const quantize_row_q = ggml_get_type_traits(dtype)->from_float;
  3801. // we don't support permuted src0 or src1
  3802. GGML_ASSERT(nb00 == ggml_type_size(type));
  3803. GGML_ASSERT(nb10 == sizeof(float));
  3804. // dst cannot be transposed or permuted
  3805. GGML_ASSERT(nb0 <= nb1);
  3806. GGML_ASSERT(nb1 <= nb2);
  3807. GGML_ASSERT(nb2 <= nb3);
  3808. GGML_ASSERT(ggml_is_quantized(src0->type));
  3809. GGML_ASSERT(src1->type == GGML_TYPE_F32);
  3810. // rows per thread
  3811. const int dr = (nr + nth - 1)/nth;
  3812. // row range for this thread
  3813. const int ir0 = dr*ith;
  3814. const int ir1 = MIN(ir0 + dr, nr);
  3815. float * wdata = (float *) params->wdata + (ne00 + CACHE_LINE_SIZE_F32) * ith;
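// wdata above is this thread's private ne00-float scratch row; the CACHE_LINE_SIZE_F32
// padding keeps the per-thread slices on separate cache lines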
  3816. for (int ir = ir0; ir < ir1; ++ir) {
  3817. // src0 indices
  3818. const int i03 = ir/(ne02*ne01);
  3819. const int i02 = (ir - i03*ne02*ne01)/ne01;
  3820. const int i01 = (ir - i03*ne02*ne01 - i02*ne01);
  3821. // src1 and dst are same shape as src0 => same indices
  3822. const int i13 = i03;
  3823. const int i12 = i02;
  3824. const int i11 = i01;
  3825. const int i3 = i03;
  3826. const int i2 = i02;
  3827. const int i1 = i01;
  3828. void * src0_row = (void *) ((char *) src0->data + (i01*nb01 + i02*nb02 + i03*nb03));
  3829. float * src1_row = (float *)((char *) src1->data + (i11*nb11 + i12*nb12 + i13*nb13));
  3830. void * dst_row = (void *) ((char *) dst->data + ( i1*nb1 + i2*nb2 + i3*nb3));
  3831. assert(ne00 % 32 == 0);
  3832. // unquantize row from src0 to temp buffer
  3833. dequantize_row_q(src0_row, wdata, ne00);
  3834. // add src1
  3835. ggml_vec_acc_f32(ne00, wdata, src1_row);
  3836. // quantize row to dst
  3837. if (quantize_row_q != NULL) {
  3838. quantize_row_q(wdata, dst_row, ne00);
  3839. } else {
  3840. memcpy(dst_row, wdata, ne0*nb0);
  3841. }
  3842. }
  3843. }
  3844. static void ggml_compute_forward_add(
  3845. const struct ggml_compute_params * params,
  3846. struct ggml_tensor * dst) {
  3847. const struct ggml_tensor * src0 = dst->src[0];
  3848. const struct ggml_tensor * src1 = dst->src[1];
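// dispatch on the src0/src1 type pair; every quantized src0 type shares the q + f32 path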
  3849. switch (src0->type) {
  3850. case GGML_TYPE_F32:
  3851. {
  3852. if (src1->type == GGML_TYPE_F32) {
  3853. ggml_compute_forward_add_f32(params, dst);
  3854. }
  3855. else {
  3856. GGML_ABORT("fatal error");
  3857. }
  3858. } break;
  3859. case GGML_TYPE_F16:
  3860. {
  3861. if (src1->type == GGML_TYPE_F16) {
  3862. ggml_compute_forward_add_f16_f16(params, dst);
  3863. }
  3864. else if (src1->type == GGML_TYPE_F32) {
  3865. ggml_compute_forward_add_f16_f32(params, dst);
  3866. }
  3867. else {
  3868. GGML_ABORT("fatal error");
  3869. }
  3870. } break;
  3871. case GGML_TYPE_BF16:
  3872. {
  3873. if (src1->type == GGML_TYPE_BF16) {
  3874. ggml_compute_forward_add_bf16_bf16(params, dst);
  3875. }
  3876. else if (src1->type == GGML_TYPE_F32) {
  3877. ggml_compute_forward_add_bf16_f32(params, dst);
  3878. }
  3879. else {
  3880. GGML_ABORT("fatal error");
  3881. }
  3882. } break;
  3883. case GGML_TYPE_Q4_0:
  3884. case GGML_TYPE_Q4_1:
  3885. case GGML_TYPE_Q5_0:
  3886. case GGML_TYPE_Q5_1:
  3887. case GGML_TYPE_Q8_0:
  3888. case GGML_TYPE_Q2_K:
  3889. case GGML_TYPE_Q3_K:
  3890. case GGML_TYPE_Q4_K:
  3891. case GGML_TYPE_Q5_K:
  3892. case GGML_TYPE_Q6_K:
  3893. case GGML_TYPE_TQ1_0:
  3894. case GGML_TYPE_TQ2_0:
  3895. case GGML_TYPE_IQ2_XXS:
  3896. case GGML_TYPE_IQ2_XS:
  3897. case GGML_TYPE_IQ3_XXS:
  3898. case GGML_TYPE_IQ1_S:
  3899. case GGML_TYPE_IQ1_M:
  3900. case GGML_TYPE_IQ4_NL:
  3901. case GGML_TYPE_IQ4_XS:
  3902. case GGML_TYPE_IQ3_S:
  3903. case GGML_TYPE_IQ2_S:
  3904. case GGML_TYPE_Q4_0_4_4:
  3905. case GGML_TYPE_Q4_0_4_8:
  3906. case GGML_TYPE_Q4_0_8_8:
  3907. {
  3908. ggml_compute_forward_add_q_f32(params, dst);
  3909. } break;
  3910. default:
  3911. {
  3912. GGML_ABORT("fatal error");
  3913. }
  3914. }
  3915. }
  3916. // ggml_compute_forward_add1
  3917. static void ggml_compute_forward_add1_f32(
  3918. const struct ggml_compute_params * params,
  3919. struct ggml_tensor * dst) {
  3920. const struct ggml_tensor * src0 = dst->src[0];
  3921. const struct ggml_tensor * src1 = dst->src[1];
  3922. GGML_ASSERT(ggml_are_same_shape(src0, dst));
  3923. GGML_ASSERT(ggml_is_scalar(src1));
  3924. const int ith = params->ith;
  3925. const int nth = params->nth;
  3926. const int nr = ggml_nrows(src0);
  3927. GGML_TENSOR_UNARY_OP_LOCALS
  3928. GGML_ASSERT( nb0 == sizeof(float));
  3929. GGML_ASSERT(nb00 == sizeof(float));
  3930. // rows per thread
  3931. const int dr = (nr + nth - 1)/nth;
  3932. // row range for this thread
  3933. const int ir0 = dr*ith;
  3934. const int ir1 = MIN(ir0 + dr, nr);
  3935. for (int ir = ir0; ir < ir1; ++ir) {
  3936. // src0 and dst are same shape => same indices
  3937. const int i3 = ir/(ne2*ne1);
  3938. const int i2 = (ir - i3*ne2*ne1)/ne1;
  3939. const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
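// vDSP_vadd is called with a stride of 0 for src1, so the same scalar is read for
// every output element; the fallback uses ggml_vec_add1_f32 instead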
  3940. #ifdef GGML_USE_ACCELERATE
  3941. UNUSED(ggml_vec_add1_f32);
  3942. vDSP_vadd(
  3943. (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01), 1,
  3944. (float *) ((char *) src1->data), 0,
  3945. (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 ), 1,
  3946. ne0);
  3947. #else
  3948. ggml_vec_add1_f32(ne0,
  3949. (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 ),
  3950. (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01),
  3951. *(float *) src1->data);
  3952. #endif
  3953. }
  3954. }
  3955. static void ggml_compute_forward_add1_f16_f32(
  3956. const struct ggml_compute_params * params,
  3957. struct ggml_tensor * dst) {
  3958. const struct ggml_tensor * src0 = dst->src[0];
  3959. const struct ggml_tensor * src1 = dst->src[1];
  3960. GGML_ASSERT(ggml_are_same_shape(src0, dst));
  3961. GGML_ASSERT(ggml_is_scalar(src1));
  3962. // scalar to add
  3963. const float v = *(float *) src1->data;
  3964. const int ith = params->ith;
  3965. const int nth = params->nth;
  3966. const int nr = ggml_nrows(src0);
  3967. GGML_TENSOR_UNARY_OP_LOCALS
  3968. GGML_ASSERT(src0->type == GGML_TYPE_F16);
  3969. GGML_ASSERT(src1->type == GGML_TYPE_F32);
  3970. GGML_ASSERT(dst->type == GGML_TYPE_F16);
  3971. GGML_ASSERT( nb0 == sizeof(ggml_fp16_t));
  3972. GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
  3973. // rows per thread
  3974. const int dr = (nr + nth - 1)/nth;
  3975. // row range for this thread
  3976. const int ir0 = dr*ith;
  3977. const int ir1 = MIN(ir0 + dr, nr);
  3978. for (int ir = ir0; ir < ir1; ++ir) {
  3979. // src0 and dst are same shape => same indices
  3980. const int i3 = ir/(ne2*ne1);
  3981. const int i2 = (ir - i3*ne2*ne1)/ne1;
  3982. const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
  3983. ggml_fp16_t * dst_ptr = (ggml_fp16_t *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 );
  3984. ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01);
  3985. for (int i = 0; i < ne0; i++) {
  3986. dst_ptr[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(src0_ptr[i]) + v);
  3987. }
  3988. }
  3989. }
  3990. static void ggml_compute_forward_add1_f16_f16(
  3991. const struct ggml_compute_params * params,
  3992. struct ggml_tensor * dst) {
  3993. const struct ggml_tensor * src0 = dst->src[0];
  3994. const struct ggml_tensor * src1 = dst->src[1];
  3995. GGML_ASSERT(ggml_are_same_shape(src0, dst));
  3996. GGML_ASSERT(ggml_is_scalar(src1));
  3997. // scalar to add
  3998. const float v = GGML_FP16_TO_FP32(*(ggml_fp16_t *) src1->data);
  3999. const int ith = params->ith;
  4000. const int nth = params->nth;
  4001. const int nr = ggml_nrows(src0);
  4002. GGML_TENSOR_UNARY_OP_LOCALS
  4003. GGML_ASSERT(src0->type == GGML_TYPE_F16);
  4004. GGML_ASSERT(src1->type == GGML_TYPE_F16);
  4005. GGML_ASSERT(dst->type == GGML_TYPE_F16);
  4006. GGML_ASSERT( nb0 == sizeof(ggml_fp16_t));
  4007. GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
  4008. // rows per thread
  4009. const int dr = (nr + nth - 1)/nth;
  4010. // row range for this thread
  4011. const int ir0 = dr*ith;
  4012. const int ir1 = MIN(ir0 + dr, nr);
  4013. for (int ir = ir0; ir < ir1; ++ir) {
  4014. // src0 and dst are same shape => same indices
  4015. const int i3 = ir/(ne2*ne1);
  4016. const int i2 = (ir - i3*ne2*ne1)/ne1;
  4017. const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
  4018. ggml_fp16_t * dst_ptr = (ggml_fp16_t *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 );
  4019. ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01);
  4020. for (int i = 0; i < ne0; i++) {
  4021. dst_ptr[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(src0_ptr[i]) + v);
  4022. }
  4023. }
  4024. }
  4025. static void ggml_compute_forward_add1_q_f32(
  4026. const struct ggml_compute_params * params,
  4027. struct ggml_tensor * dst) {
  4028. const struct ggml_tensor * src0 = dst->src[0];
  4029. const struct ggml_tensor * src1 = dst->src[1];
  4030. GGML_ASSERT(ggml_are_same_shape(src0, dst));
  4031. GGML_ASSERT(ggml_is_scalar(src1));
  4032. // scalar to add
  4033. const float v = *(float *) src1->data;
  4034. const int ith = params->ith;
  4035. const int nth = params->nth;
  4036. const int nr = ggml_nrows(src0);
  4037. GGML_TENSOR_UNARY_OP_LOCALS
  4038. const enum ggml_type type = src0->type;
  4039. ggml_to_float_t const dequantize_row_q = ggml_get_type_traits(type)->to_float;
  4040. ggml_from_float_t const quantize_row_q = ggml_get_type_traits(type)->from_float;
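// dst stays in the same quantized type as src0 (asserted below), so both a
// dequantizer and a quantizer for that type are required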
  4041. // we don't support permuted src0
  4042. GGML_ASSERT(nb00 == ggml_type_size(type));
  4043. // dst cannot be transposed or permuted
  4044. GGML_ASSERT(nb0 <= nb1);
  4045. GGML_ASSERT(nb1 <= nb2);
  4046. GGML_ASSERT(nb2 <= nb3);
  4047. GGML_ASSERT(ggml_is_quantized(src0->type));
  4048. GGML_ASSERT(dst->type == src0->type);
  4049. GGML_ASSERT(src1->type == GGML_TYPE_F32);
  4050. // rows per thread
  4051. const int dr = (nr + nth - 1)/nth;
  4052. // row range for this thread
  4053. const int ir0 = dr*ith;
  4054. const int ir1 = MIN(ir0 + dr, nr);
  4055. float * wdata = (float *) params->wdata + (ne0 + CACHE_LINE_SIZE_F32) * ith;
  4056. for (int ir = ir0; ir < ir1; ++ir) {
  4057. // src0 and dst are same shape => same indices
  4058. const int i3 = ir/(ne2*ne1);
  4059. const int i2 = (ir - i3*ne2*ne1)/ne1;
  4060. const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
  4061. void * src0_row = (void *) ((char *) src0->data + (i1*nb01 + i2*nb02 + i3*nb03));
void * dst_row = (void *) ((char *) dst->data + (i1*nb1 + i2*nb2 + i3*nb3));
  4063. assert(ne0 % 32 == 0);
  4064. // unquantize row from src0 to temp buffer
  4065. dequantize_row_q(src0_row, wdata, ne0);
  4066. // add src1
  4067. ggml_vec_acc1_f32(ne0, wdata, v);
  4068. // quantize row to dst
  4069. quantize_row_q(wdata, dst_row, ne0);
  4070. }
  4071. }
  4072. static void ggml_compute_forward_add1_bf16_f32(
  4073. const struct ggml_compute_params * params,
  4074. struct ggml_tensor * dst) {
  4075. const struct ggml_tensor * src0 = dst->src[0];
  4076. const struct ggml_tensor * src1 = dst->src[1];
  4077. GGML_ASSERT(ggml_are_same_shape(src0, dst));
  4078. GGML_ASSERT(ggml_is_scalar(src1));
  4079. // scalar to add
  4080. const float v = *(float *) src1->data;
  4081. const int ith = params->ith;
  4082. const int nth = params->nth;
  4083. const int nr = ggml_nrows(src0);
  4084. GGML_TENSOR_UNARY_OP_LOCALS
  4085. GGML_ASSERT(src0->type == GGML_TYPE_BF16);
  4086. GGML_ASSERT(src1->type == GGML_TYPE_F32);
  4087. GGML_ASSERT(dst->type == GGML_TYPE_BF16);
  4088. GGML_ASSERT( nb0 == sizeof(ggml_bf16_t));
  4089. GGML_ASSERT(nb00 == sizeof(ggml_bf16_t));
  4090. // rows per thread
  4091. const int dr = (nr + nth - 1)/nth;
  4092. // row range for this thread
  4093. const int ir0 = dr*ith;
  4094. const int ir1 = MIN(ir0 + dr, nr);
  4095. for (int ir = ir0; ir < ir1; ++ir) {
  4096. // src0 and dst are same shape => same indices
  4097. const int i3 = ir/(ne2*ne1);
  4098. const int i2 = (ir - i3*ne2*ne1)/ne1;
  4099. const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
  4100. ggml_bf16_t * dst_ptr = (ggml_bf16_t *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 );
  4101. ggml_bf16_t * src0_ptr = (ggml_bf16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01);
  4102. for (int i = 0; i < ne0; i++) {
  4103. dst_ptr[i] = GGML_FP32_TO_BF16(GGML_BF16_TO_FP32(src0_ptr[i]) + v);
  4104. }
  4105. }
  4106. }
  4107. static void ggml_compute_forward_add1_bf16_bf16(
  4108. const struct ggml_compute_params * params,
  4109. struct ggml_tensor * dst) {
  4110. const struct ggml_tensor * src0 = dst->src[0];
  4111. const struct ggml_tensor * src1 = dst->src[1];
  4112. GGML_ASSERT(ggml_are_same_shape(src0, dst));
  4113. GGML_ASSERT(ggml_is_scalar(src1));
  4114. // scalar to add
  4115. const float v = GGML_BF16_TO_FP32(*(ggml_bf16_t *) src1->data);
  4116. const int ith = params->ith;
  4117. const int nth = params->nth;
  4118. const int nr = ggml_nrows(src0);
  4119. GGML_TENSOR_UNARY_OP_LOCALS
  4120. GGML_ASSERT(src0->type == GGML_TYPE_BF16);
  4121. GGML_ASSERT(src1->type == GGML_TYPE_BF16);
  4122. GGML_ASSERT(dst->type == GGML_TYPE_BF16);
  4123. GGML_ASSERT( nb0 == sizeof(ggml_bf16_t));
  4124. GGML_ASSERT(nb00 == sizeof(ggml_bf16_t));
  4125. // rows per thread
  4126. const int dr = (nr + nth - 1)/nth;
  4127. // row range for this thread
  4128. const int ir0 = dr*ith;
  4129. const int ir1 = MIN(ir0 + dr, nr);
  4130. for (int ir = ir0; ir < ir1; ++ir) {
  4131. // src0 and dst are same shape => same indices
  4132. const int i3 = ir/(ne2*ne1);
  4133. const int i2 = (ir - i3*ne2*ne1)/ne1;
  4134. const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
  4135. ggml_bf16_t * dst_ptr = (ggml_bf16_t *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 );
  4136. ggml_bf16_t * src0_ptr = (ggml_bf16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01);
  4137. for (int i = 0; i < ne0; i++) {
  4138. dst_ptr[i] = GGML_FP32_TO_BF16(GGML_BF16_TO_FP32(src0_ptr[i]) + v);
  4139. }
  4140. }
  4141. }
  4142. static void ggml_compute_forward_add1(
  4143. const struct ggml_compute_params * params,
  4144. struct ggml_tensor * dst) {
  4145. const struct ggml_tensor * src0 = dst->src[0];
  4146. const struct ggml_tensor * src1 = dst->src[1];
  4147. switch (src0->type) {
  4148. case GGML_TYPE_F32:
  4149. {
  4150. ggml_compute_forward_add1_f32(params, dst);
  4151. } break;
  4152. case GGML_TYPE_F16:
  4153. {
  4154. if (src1->type == GGML_TYPE_F16) {
  4155. ggml_compute_forward_add1_f16_f16(params, dst);
  4156. }
  4157. else if (src1->type == GGML_TYPE_F32) {
  4158. ggml_compute_forward_add1_f16_f32(params, dst);
  4159. }
  4160. else {
  4161. GGML_ABORT("fatal error");
  4162. }
  4163. } break;
  4164. case GGML_TYPE_BF16:
  4165. {
  4166. if (src1->type == GGML_TYPE_BF16) {
  4167. ggml_compute_forward_add1_bf16_bf16(params, dst);
  4168. }
  4169. else if (src1->type == GGML_TYPE_F32) {
  4170. ggml_compute_forward_add1_bf16_f32(params, dst);
  4171. }
  4172. else {
  4173. GGML_ABORT("fatal error");
  4174. }
  4175. } break;
  4176. case GGML_TYPE_Q4_0:
  4177. case GGML_TYPE_Q4_1:
  4178. case GGML_TYPE_Q5_0:
  4179. case GGML_TYPE_Q5_1:
  4180. case GGML_TYPE_Q8_0:
  4181. case GGML_TYPE_Q8_1:
  4182. case GGML_TYPE_Q2_K:
  4183. case GGML_TYPE_Q3_K:
  4184. case GGML_TYPE_Q4_K:
  4185. case GGML_TYPE_Q5_K:
  4186. case GGML_TYPE_Q6_K:
  4187. case GGML_TYPE_TQ1_0:
  4188. case GGML_TYPE_TQ2_0:
  4189. case GGML_TYPE_IQ2_XXS:
  4190. case GGML_TYPE_IQ2_XS:
  4191. case GGML_TYPE_IQ3_XXS:
  4192. case GGML_TYPE_IQ1_S:
  4193. case GGML_TYPE_IQ1_M:
  4194. case GGML_TYPE_IQ4_NL:
  4195. case GGML_TYPE_IQ4_XS:
  4196. case GGML_TYPE_IQ3_S:
  4197. case GGML_TYPE_IQ2_S:
  4198. case GGML_TYPE_Q4_0_4_4:
  4199. case GGML_TYPE_Q4_0_4_8:
  4200. case GGML_TYPE_Q4_0_8_8:
  4201. {
  4202. ggml_compute_forward_add1_q_f32(params, dst);
  4203. } break;
  4204. default:
  4205. {
  4206. GGML_ABORT("fatal error");
  4207. }
  4208. }
  4209. }
  4210. // ggml_compute_forward_acc
  4211. static void ggml_compute_forward_acc_f32(
  4212. const struct ggml_compute_params * params,
  4213. struct ggml_tensor * dst) {
  4214. const struct ggml_tensor * src0 = dst->src[0];
  4215. const struct ggml_tensor * src1 = dst->src[1];
  4216. GGML_ASSERT(ggml_are_same_shape(src0, dst));
  4217. GGML_ASSERT(ggml_is_contiguous(dst) && ggml_is_contiguous(src0));
// view src0 and dst with these strides and data offset in bytes during acc
  4219. // nb0 is implicitly element_size because src0 and dst are contiguous
  4220. size_t nb1 = ((int32_t *) dst->op_params)[0];
  4221. size_t nb2 = ((int32_t *) dst->op_params)[1];
  4222. size_t nb3 = ((int32_t *) dst->op_params)[2];
  4223. size_t offset = ((int32_t *) dst->op_params)[3];
  4224. bool inplace = (bool) ((int32_t *) dst->op_params)[4];
  4225. if (!inplace) {
  4226. if (params->ith == 0) {
  4227. // memcpy needs to be synchronized across threads to avoid race conditions.
  4228. // => do it in INIT phase
  4229. memcpy(
  4230. ((char *) dst->data),
  4231. ((char *) src0->data),
  4232. ggml_nbytes(dst));
  4233. }
  4234. ggml_barrier(params->threadpool);
  4235. }
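// from here on the result is accumulated directly into dst, which now holds a copy
// of src0 (the copy is skipped for the inplace variant)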
  4236. const int ith = params->ith;
  4237. const int nth = params->nth;
  4238. const int nr = ggml_nrows(src1);
  4239. const int nc = src1->ne[0];
  4240. GGML_TENSOR_LOCALS(int64_t, ne1, src1, ne)
  4241. GGML_TENSOR_LOCALS(size_t, nb1, src1, nb)
  4242. // src0 and dst as viewed during acc
  4243. const size_t nb0 = ggml_element_size(src0);
  4244. const size_t nb00 = nb0;
  4245. const size_t nb01 = nb1;
  4246. const size_t nb02 = nb2;
  4247. const size_t nb03 = nb3;
  4248. GGML_ASSERT(offset + (ne10 == 0 ? 0 : ne10-1)*nb0 + (ne11 == 0 ? 0 : ne11-1)*nb1 + (ne12 == 0 ? 0 : ne12-1)*nb2 + (ne13 == 0 ? 0 : ne13-1)*nb3 < ggml_nbytes(dst));
  4249. GGML_ASSERT(offset + (ne10 == 0 ? 0 : ne10-1)*nb00 + (ne11 == 0 ? 0 : ne11-1)*nb01 + (ne12 == 0 ? 0 : ne12-1)*nb02 + (ne13 == 0 ? 0 : ne13-1)*nb03 < ggml_nbytes(src0));
  4250. GGML_ASSERT(nb10 == sizeof(float));
  4251. // rows per thread
  4252. const int dr = (nr + nth - 1)/nth;
  4253. // row range for this thread
  4254. const int ir0 = dr*ith;
  4255. const int ir1 = MIN(ir0 + dr, nr);
  4256. for (int ir = ir0; ir < ir1; ++ir) {
  4257. // src0 and dst are viewed with shape of src1 and offset
  4258. // => same indices
  4259. const int i3 = ir/(ne12*ne11);
  4260. const int i2 = (ir - i3*ne12*ne11)/ne11;
  4261. const int i1 = (ir - i3*ne12*ne11 - i2*ne11);
  4262. #ifdef GGML_USE_ACCELERATE
  4263. vDSP_vadd(
  4264. (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + offset), 1,
  4265. (float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11), 1,
  4266. (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + offset), 1, nc);
  4267. #else
  4268. ggml_vec_add_f32(nc,
  4269. (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + offset),
  4270. (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + offset),
  4271. (float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11));
  4272. #endif
  4273. }
  4274. }
  4275. static void ggml_compute_forward_acc(
  4276. const struct ggml_compute_params * params,
  4277. struct ggml_tensor * dst) {
  4278. const struct ggml_tensor * src0 = dst->src[0];
  4279. switch (src0->type) {
  4280. case GGML_TYPE_F32:
  4281. {
  4282. ggml_compute_forward_acc_f32(params, dst);
  4283. } break;
  4284. case GGML_TYPE_F16:
  4285. case GGML_TYPE_BF16:
  4286. case GGML_TYPE_Q4_0:
  4287. case GGML_TYPE_Q4_1:
  4288. case GGML_TYPE_Q5_0:
  4289. case GGML_TYPE_Q5_1:
  4290. case GGML_TYPE_Q8_0:
  4291. case GGML_TYPE_Q8_1:
  4292. case GGML_TYPE_Q2_K:
  4293. case GGML_TYPE_Q3_K:
  4294. case GGML_TYPE_Q4_K:
  4295. case GGML_TYPE_Q5_K:
  4296. case GGML_TYPE_Q6_K:
  4297. case GGML_TYPE_TQ1_0:
  4298. case GGML_TYPE_TQ2_0:
  4299. case GGML_TYPE_IQ2_XXS:
  4300. case GGML_TYPE_IQ2_XS:
  4301. case GGML_TYPE_IQ3_XXS:
  4302. case GGML_TYPE_IQ1_S:
  4303. case GGML_TYPE_IQ1_M:
  4304. case GGML_TYPE_IQ4_NL:
  4305. case GGML_TYPE_IQ4_XS:
  4306. case GGML_TYPE_IQ3_S:
  4307. case GGML_TYPE_IQ2_S:
  4308. case GGML_TYPE_Q4_0_4_4:
  4309. case GGML_TYPE_Q4_0_4_8:
  4310. case GGML_TYPE_Q4_0_8_8:
  4311. default:
  4312. {
  4313. GGML_ABORT("fatal error");
  4314. }
  4315. }
  4316. }
  4317. // ggml_compute_forward_sub
  4318. static void ggml_compute_forward_sub_f32(
  4319. const struct ggml_compute_params * params,
  4320. struct ggml_tensor * dst) {
  4321. const struct ggml_tensor * src0 = dst->src[0];
  4322. const struct ggml_tensor * src1 = dst->src[1];
  4323. assert(ggml_can_repeat(src1, src0) && ggml_are_same_shape(src0, dst));
  4324. const int ith = params->ith;
  4325. const int nth = params->nth;
  4326. const int nr = ggml_nrows(src0);
  4327. GGML_TENSOR_BINARY_OP_LOCALS
  4328. GGML_ASSERT( nb0 == sizeof(float));
  4329. GGML_ASSERT(nb00 == sizeof(float));
  4330. // rows per thread
  4331. const int dr = (nr + nth - 1)/nth;
  4332. // row range for this thread
  4333. const int ir0 = dr*ith;
  4334. const int ir1 = MIN(ir0 + dr, nr);
  4335. if (nb10 == sizeof(float)) {
  4336. for (int ir = ir0; ir < ir1; ++ir) {
  4337. // src1 is broadcastable across src0 and dst in i1, i2, i3
  4338. const int64_t i03 = ir/(ne02*ne01);
  4339. const int64_t i02 = (ir - i03*ne02*ne01)/ne01;
  4340. const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01);
  4341. const int64_t i13 = i03 % ne13;
  4342. const int64_t i12 = i02 % ne12;
  4343. const int64_t i11 = i01 % ne11;
  4344. const int64_t nr0 = ne00 / ne10;
  4345. float * dst_ptr = (float *) ((char *) dst->data + i03*nb3 + i02*nb2 + i01*nb1 );
  4346. float * src0_ptr = (float *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01);
  4347. float * src1_ptr = (float *) ((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11);
  4348. for (int64_t r = 0; r < nr0; ++r) {
  4349. #ifdef GGML_USE_ACCELERATE
  4350. vDSP_vsub(src1_ptr, 1, src0_ptr + r*ne10, 1, dst_ptr + r*ne10, 1, ne10);
  4351. #else
  4352. ggml_vec_sub_f32(ne10, dst_ptr + r*ne10, src0_ptr + r*ne10, src1_ptr);
  4353. #endif
  4354. }
  4355. }
  4356. } else {
  4357. // src1 is not contiguous
  4358. for (int ir = ir0; ir < ir1; ++ir) {
  4359. // src1 is broadcastable across src0 and dst in i1, i2, i3
  4360. const int64_t i03 = ir/(ne02*ne01);
  4361. const int64_t i02 = (ir - i03*ne02*ne01)/ne01;
  4362. const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01);
  4363. const int64_t i13 = i03 % ne13;
  4364. const int64_t i12 = i02 % ne12;
  4365. const int64_t i11 = i01 % ne11;
  4366. float * dst_ptr = (float *) ((char *) dst->data + i03*nb3 + i02*nb2 + i01*nb1 );
  4367. float * src0_ptr = (float *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01);
  4368. for (int64_t i0 = 0; i0 < ne0; ++i0) {
  4369. const int64_t i10 = i0 % ne10;
  4370. float * src1_ptr = (float *) ((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11 + i10*nb10);
  4371. dst_ptr[i0] = src0_ptr[i0] - *src1_ptr;
  4372. }
  4373. }
  4374. }
  4375. }
  4376. static void ggml_compute_forward_sub(
  4377. const struct ggml_compute_params * params,
  4378. struct ggml_tensor * dst) {
  4379. const struct ggml_tensor * src0 = dst->src[0];
  4380. switch (src0->type) {
  4381. case GGML_TYPE_F32:
  4382. {
  4383. ggml_compute_forward_sub_f32(params, dst);
  4384. } break;
  4385. default:
  4386. {
  4387. GGML_ABORT("fatal error");
  4388. }
  4389. }
  4390. }
  4391. // ggml_compute_forward_mul
  4392. static void ggml_compute_forward_mul_f32(
  4393. const struct ggml_compute_params * params,
  4394. struct ggml_tensor * dst) {
  4395. const struct ggml_tensor * src0 = dst->src[0];
  4396. const struct ggml_tensor * src1 = dst->src[1];
  4397. GGML_ASSERT(ggml_can_repeat(src1, src0) && ggml_are_same_shape(src0, dst));
  4398. const int ith = params->ith;
  4399. const int nth = params->nth;
  4400. const int64_t nr = ggml_nrows(src0);
  4401. GGML_TENSOR_BINARY_OP_LOCALS
  4402. GGML_ASSERT( nb0 == sizeof(float));
  4403. GGML_ASSERT(nb00 == sizeof(float));
  4404. if (nb10 == sizeof(float)) {
  4405. for (int64_t ir = ith; ir < nr; ir += nth) {
  4406. // src0 and dst are same shape => same indices
  4407. const int64_t i03 = ir/(ne02*ne01);
  4408. const int64_t i02 = (ir - i03*ne02*ne01)/ne01;
  4409. const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01);
  4410. const int64_t i13 = i03 % ne13;
  4411. const int64_t i12 = i02 % ne12;
  4412. const int64_t i11 = i01 % ne11;
  4413. const int64_t nr0 = ne00 / ne10;
  4414. float * dst_ptr = (float *) ((char *) dst->data + i03*nb3 + i02*nb2 + i01*nb1 );
  4415. float * src0_ptr = (float *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01);
  4416. float * src1_ptr = (float *) ((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11);
  4417. for (int64_t r = 0 ; r < nr0; ++r) {
  4418. #ifdef GGML_USE_ACCELERATE
  4419. UNUSED(ggml_vec_mul_f32);
  4420. vDSP_vmul(src0_ptr + r*ne10, 1, src1_ptr, 1, dst_ptr + r*ne10, 1, ne10);
  4421. #else
  4422. ggml_vec_mul_f32(ne10, dst_ptr + r*ne10, src0_ptr + r*ne10, src1_ptr);
  4423. #endif
  4424. }
  4425. }
  4426. } else {
  4427. // src1 is not contiguous
  4428. for (int64_t ir = ith; ir < nr; ir += nth) {
  4429. // src0 and dst are same shape => same indices
  4430. // src1 is broadcastable across src0 and dst in i1, i2, i3
  4431. const int64_t i03 = ir/(ne02*ne01);
  4432. const int64_t i02 = (ir - i03*ne02*ne01)/ne01;
  4433. const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01);
  4434. const int64_t i13 = i03 % ne13;
  4435. const int64_t i12 = i02 % ne12;
  4436. const int64_t i11 = i01 % ne11;
  4437. float * dst_ptr = (float *) ((char *) dst->data + i03*nb3 + i02*nb2 + i01*nb1 );
  4438. float * src0_ptr = (float *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01);
  4439. for (int64_t i0 = 0; i0 < ne00; ++i0) {
  4440. const int64_t i10 = i0 % ne10;
  4441. float * src1_ptr = (float *) ((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11 + i10*nb10);
  4442. dst_ptr[i0] = src0_ptr[i0] * (*src1_ptr);
  4443. }
  4444. }
  4445. }
  4446. }
  4447. static void ggml_compute_forward_mul(
  4448. const struct ggml_compute_params * params,
  4449. struct ggml_tensor * dst) {
  4450. const struct ggml_tensor * src0 = dst->src[0];
  4451. const struct ggml_tensor * src1 = dst->src[1];
  4452. GGML_ASSERT(src1->type == GGML_TYPE_F32 && "only f32 src1 supported for now");
  4453. switch (src0->type) {
  4454. case GGML_TYPE_F32:
  4455. {
  4456. ggml_compute_forward_mul_f32(params, dst);
  4457. } break;
  4458. default:
  4459. {
  4460. GGML_ABORT("fatal error");
  4461. }
  4462. }
  4463. }
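// Editorial sketch (not part of the original source): broadcasting of src1 over
// src0/dst in the binary ops above is done by reducing each dst index modulo the
// corresponding src1 dimension, plus repeating the src1 row ne00/ne10 times along
// dim 0. A scalar reference of the same index mapping for a contiguous src1
// (hypothetical helper name):
static inline float example_broadcast_src1_f32(const float * src1, // contiguous [ne10,ne11,ne12,ne13]
                                               int64_t ne10, int64_t ne11, int64_t ne12, int64_t ne13,
                                               int64_t i0, int64_t i1, int64_t i2, int64_t i3) {
    const int64_t i10 = i0 % ne10;
    const int64_t i11 = i1 % ne11;
    const int64_t i12 = i2 % ne12;
    const int64_t i13 = i3 % ne13;
    return src1[((i13*ne12 + i12)*ne11 + i11)*ne10 + i10];
}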
  4464. // ggml_compute_forward_div
  4465. static void ggml_compute_forward_div_f32(
  4466. const struct ggml_compute_params * params,
  4467. struct ggml_tensor * dst) {
  4468. const struct ggml_tensor * src0 = dst->src[0];
  4469. const struct ggml_tensor * src1 = dst->src[1];
  4470. GGML_ASSERT(ggml_can_repeat(src1, src0) && ggml_are_same_shape(src0, dst));
  4471. const int ith = params->ith;
  4472. const int nth = params->nth;
  4473. const int64_t nr = ggml_nrows(src0);
  4474. GGML_TENSOR_BINARY_OP_LOCALS
  4475. GGML_ASSERT( nb0 == sizeof(float));
  4476. GGML_ASSERT(nb00 == sizeof(float));
  4477. if (nb10 == sizeof(float)) {
  4478. for (int64_t ir = ith; ir < nr; ir += nth) {
  4479. // src0 and dst are same shape => same indices
  4480. const int64_t i03 = ir/(ne02*ne01);
  4481. const int64_t i02 = (ir - i03*ne02*ne01)/ne01;
  4482. const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01);
  4483. const int64_t i13 = i03 % ne13;
  4484. const int64_t i12 = i02 % ne12;
  4485. const int64_t i11 = i01 % ne11;
  4486. const int64_t nr0 = ne00 / ne10;
  4487. float * dst_ptr = (float *) ((char *) dst->data + i03*nb3 + i02*nb2 + i01*nb1 );
  4488. float * src0_ptr = (float *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01);
  4489. float * src1_ptr = (float *) ((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11);
  4490. for (int64_t r = 0; r < nr0; ++r) {
  4491. #ifdef GGML_USE_ACCELERATE
  4492. UNUSED(ggml_vec_div_f32);
  4493. vDSP_vdiv(src1_ptr, 1, src0_ptr + r*ne10, 1, dst_ptr + r*ne10, 1, ne10);
  4494. #else
  4495. ggml_vec_div_f32(ne10, dst_ptr + r*ne10, src0_ptr + r*ne10, src1_ptr);
  4496. #endif
  4497. }
  4498. }
  4499. } else {
  4500. // src1 is not contiguous
  4501. for (int64_t ir = ith; ir < nr; ir += nth) {
  4502. // src0 and dst are same shape => same indices
  4503. // src1 is broadcastable across src0 and dst in i1, i2, i3
  4504. const int64_t i03 = ir/(ne02*ne01);
  4505. const int64_t i02 = (ir - i03*ne02*ne01)/ne01;
  4506. const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01);
  4507. const int64_t i13 = i03 % ne13;
  4508. const int64_t i12 = i02 % ne12;
  4509. const int64_t i11 = i01 % ne11;
  4510. float * dst_ptr = (float *) ((char *) dst->data + i03*nb3 + i02*nb2 + i01*nb1 );
  4511. float * src0_ptr = (float *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01);
  4512. for (int64_t i0 = 0; i0 < ne00; ++i0) {
  4513. const int64_t i10 = i0 % ne10;
  4514. float * src1_ptr = (float *) ((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11 + i10*nb10);
  4515. dst_ptr[i0] = src0_ptr[i0] / (*src1_ptr);
  4516. }
  4517. }
  4518. }
  4519. }
  4520. static void ggml_compute_forward_div(
  4521. const struct ggml_compute_params * params,
  4522. struct ggml_tensor * dst) {
  4523. const struct ggml_tensor * src0 = dst->src[0];
  4524. switch (src0->type) {
  4525. case GGML_TYPE_F32:
  4526. {
  4527. ggml_compute_forward_div_f32(params, dst);
  4528. } break;
  4529. default:
  4530. {
  4531. GGML_ABORT("fatal error");
  4532. }
  4533. }
  4534. }
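// Editorial note (based on Accelerate's documented operand order, stated here as an
// assumption): vDSP_vsub(B, 1, A, 1, C, 1, N) computes C = A - B and
// vDSP_vdiv(B, 1, A, 1, C, 1, N) computes C = A / B, which is why src1_ptr is passed
// as the first argument in the GGML_USE_ACCELERATE branches of sub and div above.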
  4535. // ggml_compute_forward_sqr
  4536. static void ggml_compute_forward_sqr_f32(
  4537. const struct ggml_compute_params * params,
  4538. struct ggml_tensor * dst) {
  4539. const struct ggml_tensor * src0 = dst->src[0];
  4540. if (params->ith != 0) {
  4541. return;
  4542. }
  4543. assert(ggml_are_same_shape(src0, dst));
  4544. const int n = ggml_nrows(src0);
  4545. const int nc = src0->ne[0];
  4546. assert( dst->nb[0] == sizeof(float));
  4547. assert(src0->nb[0] == sizeof(float));
  4548. for (int i = 0; i < n; i++) {
  4549. ggml_vec_sqr_f32(nc,
  4550. (float *) ((char *) dst->data + i*( dst->nb[1])),
  4551. (float *) ((char *) src0->data + i*(src0->nb[1])));
  4552. }
  4553. }
  4554. static void ggml_compute_forward_sqr(
  4555. const struct ggml_compute_params * params,
  4556. struct ggml_tensor * dst) {
  4557. const struct ggml_tensor * src0 = dst->src[0];
  4558. switch (src0->type) {
  4559. case GGML_TYPE_F32:
  4560. {
  4561. ggml_compute_forward_sqr_f32(params, dst);
  4562. } break;
  4563. default:
  4564. {
  4565. GGML_ABORT("fatal error");
  4566. }
  4567. }
  4568. }
  4569. // ggml_compute_forward_sqrt
  4570. static void ggml_compute_forward_sqrt_f32(
  4571. const struct ggml_compute_params * params,
  4572. struct ggml_tensor * dst) {
  4573. const struct ggml_tensor * src0 = dst->src[0];
  4574. if (params->ith != 0) {
  4575. return;
  4576. }
  4577. assert(ggml_are_same_shape(src0, dst));
  4578. const int n = ggml_nrows(src0);
  4579. const int nc = src0->ne[0];
  4580. assert( dst->nb[0] == sizeof(float));
  4581. assert(src0->nb[0] == sizeof(float));
  4582. for (int i = 0; i < n; i++) {
  4583. ggml_vec_sqrt_f32(nc,
  4584. (float *) ((char *) dst->data + i*( dst->nb[1])),
  4585. (float *) ((char *) src0->data + i*(src0->nb[1])));
  4586. }
  4587. }
  4588. static void ggml_compute_forward_sqrt(
  4589. const struct ggml_compute_params * params,
  4590. struct ggml_tensor * dst) {
  4591. const struct ggml_tensor * src0 = dst->src[0];
  4592. switch (src0->type) {
  4593. case GGML_TYPE_F32:
  4594. {
  4595. ggml_compute_forward_sqrt_f32(params, dst);
  4596. } break;
  4597. default:
  4598. {
  4599. GGML_ABORT("fatal error");
  4600. }
  4601. }
  4602. }
  4603. // ggml_compute_forward_log
  4604. static void ggml_compute_forward_log_f32(
  4605. const struct ggml_compute_params * params,
  4606. struct ggml_tensor * dst) {
  4607. const struct ggml_tensor * src0 = dst->src[0];
  4608. if (params->ith != 0) {
  4609. return;
  4610. }
  4611. GGML_ASSERT(ggml_are_same_shape(src0, dst));
  4612. const int n = ggml_nrows(src0);
  4613. const int nc = src0->ne[0];
  4614. GGML_ASSERT( dst->nb[0] == sizeof(float));
  4615. GGML_ASSERT(src0->nb[0] == sizeof(float));
  4616. for (int i = 0; i < n; i++) {
  4617. ggml_vec_log_f32(nc,
  4618. (float *) ((char *) dst->data + i*( dst->nb[1])),
  4619. (float *) ((char *) src0->data + i*(src0->nb[1])));
  4620. }
  4621. }
  4622. static void ggml_compute_forward_log(
  4623. const struct ggml_compute_params * params,
  4624. struct ggml_tensor * dst) {
  4625. const struct ggml_tensor * src0 = dst->src[0];
  4626. switch (src0->type) {
  4627. case GGML_TYPE_F32:
  4628. {
  4629. ggml_compute_forward_log_f32(params, dst);
  4630. } break;
  4631. default:
  4632. {
  4633. GGML_ABORT("fatal error");
  4634. }
  4635. }
  4636. }
  4637. // ggml_compute_forward_sin
  4638. static void ggml_compute_forward_sin_f32(
  4639. const struct ggml_compute_params * params,
  4640. struct ggml_tensor * dst) {
  4641. const struct ggml_tensor * src0 = dst->src[0];
  4642. if (params->ith != 0) {
  4643. return;
  4644. }
  4645. GGML_ASSERT(ggml_are_same_shape(src0, dst));
  4646. const int n = ggml_nrows(src0);
  4647. const int nc = src0->ne[0];
  4648. GGML_ASSERT( dst->nb[0] == sizeof(float));
  4649. GGML_ASSERT(src0->nb[0] == sizeof(float));
  4650. for (int i = 0; i < n; i++) {
  4651. ggml_vec_sin_f32(nc,
  4652. (float *) ((char *) dst->data + i*( dst->nb[1])),
  4653. (float *) ((char *) src0->data + i*(src0->nb[1])));
  4654. }
  4655. }
  4656. static void ggml_compute_forward_sin(
  4657. const struct ggml_compute_params * params,
  4658. struct ggml_tensor * dst) {
  4659. const struct ggml_tensor * src0 = dst->src[0];
  4660. switch (src0->type) {
  4661. case GGML_TYPE_F32:
  4662. {
  4663. ggml_compute_forward_sin_f32(params, dst);
  4664. } break;
  4665. default:
  4666. {
  4667. GGML_ABORT("fatal error");
  4668. }
  4669. }
  4670. }
  4671. // ggml_compute_forward_cos
  4672. static void ggml_compute_forward_cos_f32(
  4673. const struct ggml_compute_params * params,
  4674. struct ggml_tensor * dst) {
  4675. const struct ggml_tensor * src0 = dst->src[0];
  4676. if (params->ith != 0) {
  4677. return;
  4678. }
  4679. GGML_ASSERT(ggml_are_same_shape(src0, dst));
  4680. const int n = ggml_nrows(src0);
  4681. const int nc = src0->ne[0];
  4682. GGML_ASSERT( dst->nb[0] == sizeof(float));
  4683. GGML_ASSERT(src0->nb[0] == sizeof(float));
  4684. for (int i = 0; i < n; i++) {
  4685. ggml_vec_cos_f32(nc,
  4686. (float *) ((char *) dst->data + i*( dst->nb[1])),
  4687. (float *) ((char *) src0->data + i*(src0->nb[1])));
  4688. }
  4689. }
  4690. static void ggml_compute_forward_cos(
  4691. const struct ggml_compute_params * params,
  4692. struct ggml_tensor * dst) {
  4693. const struct ggml_tensor * src0 = dst->src[0];
  4694. switch (src0->type) {
  4695. case GGML_TYPE_F32:
  4696. {
  4697. ggml_compute_forward_cos_f32(params, dst);
  4698. } break;
  4699. default:
  4700. {
  4701. GGML_ABORT("fatal error");
  4702. }
  4703. }
  4704. }
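// Editorial sketch (not part of the original source): sqr/sqrt/log/sin/cos above all
// follow one single-threaded, row-wise unary pattern: only thread 0 runs, it walks
// ggml_nrows(src0) rows and applies a ggml_vec_*_f32 kernel per row, addressing rows
// through the byte stride nb[1]. A condensed reference of that shape (hypothetical
// name; op is any elementwise float function):
static inline void example_unary_rows_f32(struct ggml_tensor * dst, const struct ggml_tensor * src0, float (*op)(float)) {
    const int64_t n  = ggml_nrows(src0);
    const int64_t nc = src0->ne[0];
    for (int64_t i = 0; i < n; i++) {
        const float * x = (const float *) ((const char *) src0->data + i*src0->nb[1]);
        float       * y = (float       *) ((      char *)  dst->data + i* dst->nb[1]);
        for (int64_t j = 0; j < nc; j++) {
            y[j] = op(x[j]);
        }
    }
}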
  4705. // ggml_compute_forward_sum
  4706. static void ggml_compute_forward_sum_f32(
  4707. const struct ggml_compute_params * params,
  4708. struct ggml_tensor * dst) {
  4709. const struct ggml_tensor * src0 = dst->src[0];
  4710. if (params->ith != 0) {
  4711. return;
  4712. }
  4713. assert(ggml_is_scalar(dst));
  4714. assert(src0->nb[0] == sizeof(float));
  4715. GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne)
  4716. GGML_TENSOR_LOCALS(size_t, nb0, src0, nb)
  4717. ggml_float sum = 0;
  4718. ggml_float row_sum = 0;
  4719. for (int64_t i03 = 0; i03 < ne03; i03++) {
  4720. for (int64_t i02 = 0; i02 < ne02; i02++) {
  4721. for (int64_t i01 = 0; i01 < ne01; i01++) {
  4722. ggml_vec_sum_f32_ggf(ne00,
  4723. &row_sum,
  4724. (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03));
  4725. sum += row_sum;
  4726. }
  4727. }
  4728. }
  4729. ((float *) dst->data)[0] = sum;
  4730. }
  4731. static void ggml_compute_forward_sum_f16(
  4732. const struct ggml_compute_params * params,
  4733. struct ggml_tensor * dst) {
  4734. const struct ggml_tensor * src0 = dst->src[0];
  4735. if (params->ith != 0) {
  4736. return;
  4737. }
  4738. assert(ggml_is_scalar(dst));
  4739. assert(src0->nb[0] == sizeof(ggml_fp16_t));
  4740. GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne)
  4741. GGML_TENSOR_LOCALS(size_t, nb0, src0, nb)
  4742. float sum = 0;
  4743. float row_sum = 0;
  4744. for (int64_t i03 = 0; i03 < ne03; i03++) {
  4745. for (int64_t i02 = 0; i02 < ne02; i02++) {
  4746. for (int64_t i01 = 0; i01 < ne01; i01++) {
  4747. ggml_vec_sum_f16_ggf(ne00,
  4748. &row_sum,
  4749. (ggml_fp16_t *) ((char *) src0->data + i01 * nb01 + i02 * nb02 + i03 * nb03));
  4750. sum += row_sum;
  4751. }
  4752. }
  4753. }
  4754. ((ggml_fp16_t *) dst->data)[0] = GGML_FP32_TO_FP16(sum);
  4755. }
  4756. static void ggml_compute_forward_sum_bf16(
  4757. const struct ggml_compute_params * params,
  4758. struct ggml_tensor * dst) {
  4759. const struct ggml_tensor * src0 = dst->src[0];
  4760. if (params->ith != 0) {
  4761. return;
  4762. }
  4763. assert(ggml_is_scalar(dst));
  4764. assert(src0->nb[0] == sizeof(ggml_bf16_t));
  4765. GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne)
  4766. GGML_TENSOR_LOCALS(size_t, nb0, src0, nb)
  4767. float sum = 0;
  4768. float row_sum = 0;
  4769. for (int64_t i03 = 0; i03 < ne03; i03++) {
  4770. for (int64_t i02 = 0; i02 < ne02; i02++) {
  4771. for (int64_t i01 = 0; i01 < ne01; i01++) {
  4772. ggml_vec_sum_bf16_ggf(ne00,
  4773. &row_sum,
  4774. (ggml_bf16_t *) ((char *) src0->data + i01 * nb01 + i02 * nb02 + i03 * nb03));
  4775. sum += row_sum;
  4776. }
  4777. }
  4778. }
  4779. ((ggml_bf16_t *) dst->data)[0] = GGML_FP32_TO_BF16(sum);
  4780. }
  4781. static void ggml_compute_forward_sum(
  4782. const struct ggml_compute_params * params,
  4783. struct ggml_tensor * dst) {
  4784. const struct ggml_tensor * src0 = dst->src[0];
  4785. switch (src0->type) {
  4786. case GGML_TYPE_F32:
  4787. {
  4788. ggml_compute_forward_sum_f32(params, dst);
  4789. } break;
  4790. case GGML_TYPE_F16:
  4791. {
  4792. ggml_compute_forward_sum_f16(params, dst);
  4793. } break;
  4794. case GGML_TYPE_BF16:
  4795. {
  4796. ggml_compute_forward_sum_bf16(params, dst);
  4797. } break;
  4798. default:
  4799. {
  4800. GGML_ABORT("fatal error");
  4801. }
  4802. }
  4803. }
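// Editorial note (assumption: ggml_float is a wider accumulator type than float, as
// its use in the f32 sum above suggests): summing many f32 values directly in float
// discards low-order bits, so the f32 path accumulates row sums in ggml_float and
// converts to float only once at the end. Minimal illustration of the same idea
// (hypothetical name):
static inline float example_sum_f32_wide(const float * x, int64_t n) {
    double acc = 0.0;                   // wide accumulator, as in the f32 path above
    for (int64_t i = 0; i < n; i++) {
        acc += (double) x[i];
    }
    return (float) acc;                 // single rounding at the end
}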
  4804. // ggml_compute_forward_sum_rows
  4805. static void ggml_compute_forward_sum_rows_f32(
  4806. const struct ggml_compute_params * params,
  4807. struct ggml_tensor * dst) {
  4808. const struct ggml_tensor * src0 = dst->src[0];
  4809. if (params->ith != 0) {
  4810. return;
  4811. }
  4812. GGML_ASSERT(src0->nb[0] == sizeof(float));
  4813. GGML_ASSERT(dst->nb[0] == sizeof(float));
  4814. GGML_TENSOR_UNARY_OP_LOCALS
  4815. GGML_ASSERT(ne0 == 1);
  4816. GGML_ASSERT(ne1 == ne01);
  4817. GGML_ASSERT(ne2 == ne02);
  4818. GGML_ASSERT(ne3 == ne03);
  4819. for (int64_t i3 = 0; i3 < ne03; i3++) {
  4820. for (int64_t i2 = 0; i2 < ne02; i2++) {
  4821. for (int64_t i1 = 0; i1 < ne01; i1++) {
  4822. float * src_row = (float *) ((char *) src0->data + i1*nb01 + i2*nb02 + i3*nb03);
  4823. float * dst_row = (float *) ((char *) dst->data + i1*nb1 + i2*nb2 + i3*nb3);
  4824. float row_sum = 0;
  4825. ggml_vec_sum_f32(ne00, &row_sum, src_row);
  4826. dst_row[0] = row_sum;
  4827. }
  4828. }
  4829. }
  4830. }
  4831. static void ggml_compute_forward_sum_rows(
  4832. const struct ggml_compute_params * params,
  4833. struct ggml_tensor * dst) {
  4834. const struct ggml_tensor * src0 = dst->src[0];
  4835. switch (src0->type) {
  4836. case GGML_TYPE_F32:
  4837. {
  4838. ggml_compute_forward_sum_rows_f32(params, dst);
  4839. } break;
  4840. default:
  4841. {
  4842. GGML_ABORT("fatal error");
  4843. }
  4844. }
  4845. }
  4846. // ggml_compute_forward_mean
  4847. static void ggml_compute_forward_mean_f32(
  4848. const struct ggml_compute_params * params,
  4849. struct ggml_tensor * dst) {
  4850. const struct ggml_tensor * src0 = dst->src[0];
  4851. if (params->ith != 0) {
  4852. return;
  4853. }
  4854. assert(src0->nb[0] == sizeof(float));
  4855. GGML_TENSOR_UNARY_OP_LOCALS
  4856. assert(ne0 == 1);
  4857. assert(ne1 == ne01);
  4858. assert(ne2 == ne02);
  4859. assert(ne3 == ne03);
  4860. UNUSED(ne0);
  4861. UNUSED(ne1);
  4862. UNUSED(ne2);
  4863. UNUSED(ne3);
  4864. for (int64_t i03 = 0; i03 < ne03; i03++) {
  4865. for (int64_t i02 = 0; i02 < ne02; i02++) {
  4866. for (int64_t i01 = 0; i01 < ne01; i01++) {
  4867. ggml_vec_sum_f32(ne00,
  4868. (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3),
  4869. (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03));
  4870. *(float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3) /= (float) ne00;
  4871. }
  4872. }
  4873. }
  4874. }
  4875. static void ggml_compute_forward_mean(
  4876. const struct ggml_compute_params * params,
  4877. struct ggml_tensor * dst) {
  4878. const struct ggml_tensor * src0 = dst->src[0];
  4879. switch (src0->type) {
  4880. case GGML_TYPE_F32:
  4881. {
  4882. ggml_compute_forward_mean_f32(params, dst);
  4883. } break;
  4884. default:
  4885. {
  4886. GGML_ABORT("fatal error");
  4887. }
  4888. }
  4889. }
  4890. // ggml_compute_forward_argmax
  4891. static void ggml_compute_forward_argmax_f32(
  4892. const struct ggml_compute_params * params,
  4893. struct ggml_tensor * dst) {
  4894. const struct ggml_tensor * src0 = dst->src[0];
  4895. if (params->ith != 0) {
  4896. return;
  4897. }
  4898. assert(src0->nb[0] == sizeof(float));
  4899. assert(dst->nb[0] == sizeof(float));
  4900. const int64_t ne00 = src0->ne[0];
  4901. const int64_t ne01 = src0->ne[1];
  4902. const size_t nb01 = src0->nb[1];
  4903. const size_t nb0 = dst->nb[0];
  4904. for (int64_t i1 = 0; i1 < ne01; i1++) {
  4905. float * src = (float *) ((char *) src0->data + i1*nb01);
  4906. int32_t * dst_ = (int32_t *) ((char *) dst->data + i1*nb0);
  4907. int v = 0;
  4908. ggml_vec_argmax_f32(ne00, &v, src);
  4909. dst_[0] = v;
  4910. }
  4911. }
  4912. static void ggml_compute_forward_argmax(
  4913. const struct ggml_compute_params * params,
  4914. struct ggml_tensor * dst) {
  4915. const struct ggml_tensor * src0 = dst->src[0];
  4916. switch (src0->type) {
  4917. case GGML_TYPE_F32:
  4918. {
  4919. ggml_compute_forward_argmax_f32(params, dst);
  4920. } break;
  4921. default:
  4922. {
  4923. GGML_ABORT("fatal error");
  4924. }
  4925. }
  4926. }
  4927. // ggml_compute_forward_count_equal
  4928. static void ggml_compute_forward_count_equal_i32(
  4929. const struct ggml_compute_params * params,
  4930. struct ggml_tensor * dst) {
  4931. const struct ggml_tensor * src0 = dst->src[0];
  4932. const struct ggml_tensor * src1 = dst->src[1];
  4933. GGML_TENSOR_BINARY_OP_LOCALS;
  4934. GGML_ASSERT(src0->type == GGML_TYPE_I32);
  4935. GGML_ASSERT(src1->type == GGML_TYPE_I32);
  4936. GGML_ASSERT(ggml_are_same_shape(src0, src1));
  4937. GGML_ASSERT(ggml_is_scalar(dst));
  4938. GGML_ASSERT(dst->type == GGML_TYPE_I64);
  4939. const int64_t nr = ggml_nrows(src0);
  4940. const int ith = params->ith;
  4941. const int nth = params->nth;
  4942. int64_t * sums = (int64_t *) params->wdata;
  4943. int64_t sum_thread = 0;
  4944. // rows per thread
  4945. const int64_t dr = (nr + nth - 1)/nth;
  4946. // row range for this thread
  4947. const int64_t ir0 = dr*ith;
  4948. const int64_t ir1 = MIN(ir0 + dr, nr);
  4949. for (int64_t ir = ir0; ir < ir1; ++ir) {
  4950. const int64_t i03 = ir / (ne02*ne01);
4951. const int64_t i02 = (ir - i03*ne02*ne01) / ne01;
4952. const int64_t i01 = ir - i03*ne02*ne01 - i02*ne01;

  4953. const char * data0 = (const char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01;
  4954. const char * data1 = (const char *) src1->data + i03*nb13 + i02*nb12 + i01*nb11;
  4955. for (int64_t i00 = 0; i00 < ne00; ++i00) {
  4956. const int32_t val0 = *((const int32_t *) (data0 + i00*nb00));
  4957. const int32_t val1 = *((const int32_t *) (data1 + i00*nb10));
  4958. sum_thread += val0 == val1;
  4959. }
  4960. }
  4961. if (ith != 0) {
  4962. sums[ith] = sum_thread;
  4963. }
  4964. ggml_barrier(params->threadpool);
  4965. if (ith != 0) {
  4966. return;
  4967. }
  4968. for (int ith_other = 1; ith_other < nth; ++ith_other) {
  4969. sum_thread += sums[ith_other];
  4970. }
  4971. *((int64_t *) dst->data) = sum_thread;
  4972. }
  4973. static void ggml_compute_forward_count_equal(
  4974. const struct ggml_compute_params * params,
  4975. struct ggml_tensor * dst) {
  4976. const struct ggml_tensor * src0 = dst->src[0];
  4977. switch (src0->type) {
  4978. case GGML_TYPE_I32:
  4979. {
  4980. ggml_compute_forward_count_equal_i32(params, dst);
  4981. } break;
  4982. default:
  4983. {
  4984. GGML_ABORT("fatal error");
  4985. }
  4986. }
  4987. }
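// Editorial sketch (not part of the original source): count_equal above uses a
// simple two-phase reduction. Every thread accumulates a private count over its row
// range; non-zero threads publish it into params->wdata; all threads meet at
// ggml_barrier(); thread 0 then folds the published partials into its own count and
// writes the scalar result. In pseudo-C (names follow the code above):
//
//   sum_thread = local count over rows [ir0, ir1)
//   if (ith != 0) sums[ith] = sum_thread;     // publish partial
//   ggml_barrier(threadpool);                 // wait for all partials
//   if (ith == 0) {
//       for (t = 1; t < nth; t++) sum_thread += sums[t];
//       *(int64_t *) dst->data = sum_thread;
//   }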
  4988. // ggml_compute_forward_repeat
  4989. static void ggml_compute_forward_repeat_f32(
  4990. const struct ggml_compute_params * params,
  4991. struct ggml_tensor * dst) {
  4992. const struct ggml_tensor * src0 = dst->src[0];
  4993. if (params->ith != 0) {
  4994. return;
  4995. }
  4996. GGML_ASSERT(ggml_can_repeat(src0, dst));
  4997. GGML_TENSOR_UNARY_OP_LOCALS
  4998. // guaranteed to be an integer due to the check in ggml_can_repeat
  4999. const int nr0 = (int)(ne0/ne00);
  5000. const int nr1 = (int)(ne1/ne01);
  5001. const int nr2 = (int)(ne2/ne02);
  5002. const int nr3 = (int)(ne3/ne03);
  5003. // TODO: support for transposed / permuted tensors
  5004. GGML_ASSERT(nb0 == sizeof(float));
  5005. GGML_ASSERT(nb00 == sizeof(float));
  5006. // TODO: maybe this is not optimal?
  5007. for (int i3 = 0; i3 < nr3; i3++) {
  5008. for (int k3 = 0; k3 < ne03; k3++) {
  5009. for (int i2 = 0; i2 < nr2; i2++) {
  5010. for (int k2 = 0; k2 < ne02; k2++) {
  5011. for (int i1 = 0; i1 < nr1; i1++) {
  5012. for (int k1 = 0; k1 < ne01; k1++) {
  5013. for (int i0 = 0; i0 < nr0; i0++) {
  5014. ggml_vec_cpy_f32(ne00,
  5015. (float *) ((char *) dst->data + (i3*ne03 + k3)*nb3 + (i2*ne02 + k2)*nb2 + (i1*ne01 + k1)*nb1 + (i0*ne00)*nb0),
  5016. (float *) ((char *) src0->data + ( k3)*nb03 + ( k2)*nb02 + ( k1)*nb01));
  5017. }
  5018. }
  5019. }
  5020. }
  5021. }
  5022. }
  5023. }
  5024. }
  5025. static void ggml_compute_forward_repeat_f16(
  5026. const struct ggml_compute_params * params,
  5027. struct ggml_tensor * dst) {
  5028. const struct ggml_tensor * src0 = dst->src[0];
  5029. if (params->ith != 0) {
  5030. return;
  5031. }
  5032. GGML_ASSERT(ggml_can_repeat(src0, dst));
  5033. GGML_TENSOR_UNARY_OP_LOCALS
  5034. // guaranteed to be an integer due to the check in ggml_can_repeat
  5035. const int nr0 = (int)(ne0/ne00);
  5036. const int nr1 = (int)(ne1/ne01);
  5037. const int nr2 = (int)(ne2/ne02);
  5038. const int nr3 = (int)(ne3/ne03);
  5039. // TODO: support for transposed / permuted tensors
  5040. GGML_ASSERT(nb0 == sizeof(ggml_fp16_t));
  5041. GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
  5042. // TODO: maybe this is not optimal?
  5043. for (int i3 = 0; i3 < nr3; i3++) {
  5044. for (int k3 = 0; k3 < ne03; k3++) {
  5045. for (int i2 = 0; i2 < nr2; i2++) {
  5046. for (int k2 = 0; k2 < ne02; k2++) {
  5047. for (int i1 = 0; i1 < nr1; i1++) {
  5048. for (int k1 = 0; k1 < ne01; k1++) {
  5049. for (int i0 = 0; i0 < nr0; i0++) {
  5050. ggml_fp16_t * y = (ggml_fp16_t *) ((char *) dst->data + (i3*ne03 + k3)*nb3 + (i2*ne02 + k2)*nb2 + (i1*ne01 + k1)*nb1 + (i0*ne00)*nb0);
  5051. ggml_fp16_t * x = (ggml_fp16_t *) ((char *) src0->data + ( k3)*nb03 + ( k2)*nb02 + ( k1)*nb01);
  5052. // ggml_vec_cpy_f16(ne00, y, x)
  5053. for (int i = 0; i < ne00; ++i) {
  5054. y[i] = x[i];
  5055. }
  5056. }
  5057. }
  5058. }
  5059. }
  5060. }
  5061. }
  5062. }
  5063. }
  5064. static void ggml_compute_forward_repeat(
  5065. const struct ggml_compute_params * params,
  5066. struct ggml_tensor * dst) {
  5067. const struct ggml_tensor * src0 = dst->src[0];
  5068. switch (src0->type) {
  5069. case GGML_TYPE_F16:
  5070. case GGML_TYPE_BF16:
  5071. case GGML_TYPE_I16:
  5072. {
  5073. ggml_compute_forward_repeat_f16(params, dst);
  5074. } break;
  5075. case GGML_TYPE_F32:
  5076. case GGML_TYPE_I32:
  5077. {
  5078. ggml_compute_forward_repeat_f32(params, dst);
  5079. } break;
  5080. default:
  5081. {
  5082. GGML_ABORT("fatal error");
  5083. }
  5084. }
  5085. }
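// Editorial sketch (not part of the original source): ggml_repeat tiles src0 so that
// dst index (i0,i1,i2,i3) reads src0 at (i0 % ne00, i1 % ne01, i2 % ne02, i3 % ne03);
// the nested loops above realize the same mapping block-wise with ggml_vec_cpy_f32.
// A 1-D reference for the innermost dimension (hypothetical name):
static inline void example_repeat_1d_f32(float * dst, int64_t ne0, const float * src, int64_t ne00) {
    // assumes ne0 % ne00 == 0, as guaranteed by ggml_can_repeat
    for (int64_t i0 = 0; i0 < ne0; i0++) {
        dst[i0] = src[i0 % ne00];
    }
}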
  5086. // ggml_compute_forward_repeat_back
  5087. static void ggml_compute_forward_repeat_back_f32(
  5088. const struct ggml_compute_params * params,
  5089. struct ggml_tensor * dst) {
  5090. const struct ggml_tensor * src0 = dst->src[0];
  5091. if (params->ith != 0) {
  5092. return;
  5093. }
  5094. GGML_ASSERT(ggml_can_repeat(dst, src0));
  5095. GGML_TENSOR_UNARY_OP_LOCALS
  5096. // guaranteed to be an integer due to the check in ggml_can_repeat
  5097. const int nr0 = (int)(ne00/ne0);
  5098. const int nr1 = (int)(ne01/ne1);
  5099. const int nr2 = (int)(ne02/ne2);
  5100. const int nr3 = (int)(ne03/ne3);
  5101. // TODO: support for transposed / permuted tensors
  5102. GGML_ASSERT(nb0 == sizeof(float));
  5103. GGML_ASSERT(nb00 == sizeof(float));
  5104. if (ggml_is_contiguous(dst)) {
  5105. ggml_vec_set_f32(ne0*ne1*ne2*ne3, dst->data, 0);
  5106. } else {
  5107. for (int k3 = 0; k3 < ne3; k3++) {
  5108. for (int k2 = 0; k2 < ne2; k2++) {
  5109. for (int k1 = 0; k1 < ne1; k1++) {
  5110. ggml_vec_set_f32(ne0,
  5111. (float *) ((char *) dst->data + k1*nb1 + k2*nb2 + k3*nb3),
  5112. 0);
  5113. }
  5114. }
  5115. }
  5116. }
  5117. // TODO: maybe this is not optimal?
  5118. for (int i3 = 0; i3 < nr3; i3++) {
  5119. for (int k3 = 0; k3 < ne3; k3++) {
  5120. for (int i2 = 0; i2 < nr2; i2++) {
  5121. for (int k2 = 0; k2 < ne2; k2++) {
  5122. for (int i1 = 0; i1 < nr1; i1++) {
  5123. for (int k1 = 0; k1 < ne1; k1++) {
  5124. for (int i0 = 0; i0 < nr0; i0++) {
  5125. ggml_vec_acc_f32(ne0,
  5126. (float *) ((char *) dst->data + ( k3)*nb3 + ( k2)*nb2 + ( k1)*nb1),
  5127. (float *) ((char *) src0->data + (i3*ne3 + k3)*nb03 + (i2*ne2 + k2)*nb02 + (i1*ne1 + k1)*nb01 + (i0*ne0)*nb00));
  5128. }
  5129. }
  5130. }
  5131. }
  5132. }
  5133. }
  5134. }
  5135. }
  5136. static void ggml_compute_forward_repeat_back(
  5137. const struct ggml_compute_params * params,
  5138. struct ggml_tensor * dst) {
  5139. const struct ggml_tensor * src0 = dst->src[0];
  5140. switch (src0->type) {
  5141. case GGML_TYPE_F32:
  5142. {
  5143. ggml_compute_forward_repeat_back_f32(params, dst);
  5144. } break;
  5145. default:
  5146. {
  5147. GGML_ABORT("fatal error");
  5148. }
  5149. }
  5150. }
  5151. // ggml_compute_forward_concat
  5152. static void ggml_compute_forward_concat_f32(
  5153. const struct ggml_compute_params * params,
  5154. struct ggml_tensor * dst) {
  5155. const struct ggml_tensor * src0 = dst->src[0];
  5156. const struct ggml_tensor * src1 = dst->src[1];
  5157. GGML_ASSERT(src0->nb[0] == sizeof(float));
  5158. const int ith = params->ith;
  5159. const int nth = params->nth;
  5160. GGML_TENSOR_BINARY_OP_LOCALS
  5161. const int32_t dim = ggml_get_op_params_i32(dst, 0);
  5162. GGML_ASSERT(dim >= 0 && dim < 4);
  5163. int64_t o[4] = {0, 0, 0, 0};
  5164. o[dim] = src0->ne[dim];
  5165. const float * x;
5166. // TODO: smarter multi-threading
  5167. for (int i3 = 0; i3 < ne3; i3++) {
  5168. for (int i2 = ith; i2 < ne2; i2 += nth) {
  5169. for (int i1 = 0; i1 < ne1; i1++) {
  5170. for (int i0 = 0; i0 < ne0; i0++) {
  5171. if (i0 < ne00 && i1 < ne01 && i2 < ne02 && i3 < ne03) {
  5172. x = (const float *) ((const char *)src0->data + (i0 )*nb00 + (i1 )*nb01 + (i2 )*nb02 + (i3 )*nb03);
  5173. } else {
  5174. x = (const float *) ((const char *)src1->data + (i0 - o[0])*nb10 + (i1 - o[1])*nb11 + (i2 - o[2])*nb12 + (i3 - o[3])*nb13);
  5175. }
  5176. float * y = (float *)((char *)dst->data + i0*nb0 + i1*nb1 + i2*nb2 + i3*nb3);
  5177. *y = *x;
  5178. }
  5179. }
  5180. }
  5181. }
  5182. }
  5183. static void ggml_compute_forward_concat(
  5184. const struct ggml_compute_params * params,
  5185. struct ggml_tensor * dst) {
  5186. const struct ggml_tensor * src0 = dst->src[0];
  5187. switch (src0->type) {
  5188. case GGML_TYPE_F32:
  5189. case GGML_TYPE_I32:
  5190. {
  5191. ggml_compute_forward_concat_f32(params, dst);
  5192. } break;
  5193. default:
  5194. {
  5195. GGML_ABORT("fatal error");
  5196. }
  5197. }
  5198. }
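// Editorial sketch (not part of the original source): concat copies src0 into the low
// part of dst along `dim` and src1 into the remainder, with src1 indices shifted back
// by o[dim] = src0->ne[dim]. A 1-D reference along dim 0 (hypothetical name):
static inline void example_concat_dim0_f32(float * dst, const float * src0, int64_t ne00,
                                           const float * src1, int64_t ne10) {
    for (int64_t i0 = 0; i0 < ne00 + ne10; i0++) {
        dst[i0] = i0 < ne00 ? src0[i0] : src1[i0 - ne00];
    }
}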
  5199. // ggml_compute_forward_abs
  5200. static void ggml_compute_forward_abs_f32(
  5201. const struct ggml_compute_params * params,
  5202. struct ggml_tensor * dst) {
  5203. const struct ggml_tensor * src0 = dst->src[0];
  5204. if (params->ith != 0) {
  5205. return;
  5206. }
  5207. assert(ggml_is_contiguous_1(src0));
  5208. assert(ggml_is_contiguous_1(dst));
  5209. assert(ggml_are_same_shape(src0, dst));
  5210. const int n = ggml_nrows(src0);
  5211. const int nc = src0->ne[0];
  5212. for (int i = 0; i < n; i++) {
  5213. ggml_vec_abs_f32(nc,
  5214. (float *) ((char *) dst->data + i*( dst->nb[1])),
  5215. (float *) ((char *) src0->data + i*(src0->nb[1])));
  5216. }
  5217. }
  5218. static void ggml_compute_forward_abs(
  5219. const struct ggml_compute_params * params,
  5220. struct ggml_tensor * dst) {
  5221. const struct ggml_tensor * src0 = dst->src[0];
  5222. switch (src0->type) {
  5223. case GGML_TYPE_F32:
  5224. {
  5225. ggml_compute_forward_abs_f32(params, dst);
  5226. } break;
  5227. default:
  5228. {
  5229. GGML_ABORT("fatal error");
  5230. }
  5231. }
  5232. }
  5233. // ggml_compute_forward_sgn
  5234. static void ggml_compute_forward_sgn_f32(
  5235. const struct ggml_compute_params * params,
  5236. struct ggml_tensor * dst) {
  5237. const struct ggml_tensor * src0 = dst->src[0];
  5238. if (params->ith != 0) {
  5239. return;
  5240. }
  5241. assert(ggml_is_contiguous_1(src0));
  5242. assert(ggml_is_contiguous_1(dst));
  5243. assert(ggml_are_same_shape(src0, dst));
  5244. const int n = ggml_nrows(src0);
  5245. const int nc = src0->ne[0];
  5246. for (int i = 0; i < n; i++) {
  5247. ggml_vec_sgn_f32(nc,
  5248. (float *) ((char *) dst->data + i*( dst->nb[1])),
  5249. (float *) ((char *) src0->data + i*(src0->nb[1])));
  5250. }
  5251. }
  5252. static void ggml_compute_forward_sgn(
  5253. const struct ggml_compute_params * params,
  5254. struct ggml_tensor * dst) {
  5255. const struct ggml_tensor * src0 = dst->src[0];
  5256. switch (src0->type) {
  5257. case GGML_TYPE_F32:
  5258. {
  5259. ggml_compute_forward_sgn_f32(params, dst);
  5260. } break;
  5261. default:
  5262. {
  5263. GGML_ABORT("fatal error");
  5264. }
  5265. }
  5266. }
  5267. // ggml_compute_forward_neg
  5268. static void ggml_compute_forward_neg_f32(
  5269. const struct ggml_compute_params * params,
  5270. struct ggml_tensor * dst) {
  5271. const struct ggml_tensor * src0 = dst->src[0];
  5272. if (params->ith != 0) {
  5273. return;
  5274. }
  5275. assert(ggml_is_contiguous_1(src0));
  5276. assert(ggml_is_contiguous_1(dst));
  5277. assert(ggml_are_same_shape(src0, dst));
  5278. const int n = ggml_nrows(src0);
  5279. const int nc = src0->ne[0];
  5280. for (int i = 0; i < n; i++) {
  5281. ggml_vec_neg_f32(nc,
  5282. (float *) ((char *) dst->data + i*( dst->nb[1])),
  5283. (float *) ((char *) src0->data + i*(src0->nb[1])));
  5284. }
  5285. }
  5286. static void ggml_compute_forward_neg(
  5287. const struct ggml_compute_params * params,
  5288. struct ggml_tensor * dst) {
  5289. const struct ggml_tensor * src0 = dst->src[0];
  5290. switch (src0->type) {
  5291. case GGML_TYPE_F32:
  5292. {
  5293. ggml_compute_forward_neg_f32(params, dst);
  5294. } break;
  5295. default:
  5296. {
  5297. GGML_ABORT("fatal error");
  5298. }
  5299. }
  5300. }
  5301. // ggml_compute_forward_step
  5302. static void ggml_compute_forward_step_f32(
  5303. const struct ggml_compute_params * params,
  5304. struct ggml_tensor * dst) {
  5305. const struct ggml_tensor * src0 = dst->src[0];
  5306. if (params->ith != 0) {
  5307. return;
  5308. }
  5309. assert(ggml_is_contiguous_1(src0));
  5310. assert(ggml_is_contiguous_1(dst));
  5311. assert(ggml_are_same_shape(src0, dst));
  5312. const int n = ggml_nrows(src0);
  5313. const int nc = src0->ne[0];
  5314. for (int i = 0; i < n; i++) {
  5315. ggml_vec_step_f32(nc,
  5316. (float *) ((char *) dst->data + i*( dst->nb[1])),
  5317. (float *) ((char *) src0->data + i*(src0->nb[1])));
  5318. }
  5319. }
  5320. static void ggml_compute_forward_step(
  5321. const struct ggml_compute_params * params,
  5322. struct ggml_tensor * dst) {
  5323. const struct ggml_tensor * src0 = dst->src[0];
  5324. switch (src0->type) {
  5325. case GGML_TYPE_F32:
  5326. {
  5327. ggml_compute_forward_step_f32(params, dst);
  5328. } break;
  5329. default:
  5330. {
  5331. GGML_ABORT("fatal error");
  5332. }
  5333. }
  5334. }
  5335. // ggml_compute_forward_tanh
  5336. static void ggml_compute_forward_tanh_f32(
  5337. const struct ggml_compute_params * params,
  5338. struct ggml_tensor * dst) {
  5339. const struct ggml_tensor * src0 = dst->src[0];
  5340. if (params->ith != 0) {
  5341. return;
  5342. }
  5343. assert(ggml_is_contiguous_1(src0));
  5344. assert(ggml_is_contiguous_1(dst));
  5345. assert(ggml_are_same_shape(src0, dst));
  5346. const int n = ggml_nrows(src0);
  5347. const int nc = src0->ne[0];
  5348. for (int i = 0; i < n; i++) {
  5349. ggml_vec_tanh_f32(nc,
  5350. (float *) ((char *) dst->data + i*( dst->nb[1])),
  5351. (float *) ((char *) src0->data + i*(src0->nb[1])));
  5352. }
  5353. }
  5354. static void ggml_compute_forward_tanh(
  5355. const struct ggml_compute_params * params,
  5356. struct ggml_tensor * dst) {
  5357. const struct ggml_tensor * src0 = dst->src[0];
  5358. switch (src0->type) {
  5359. case GGML_TYPE_F32:
  5360. {
  5361. ggml_compute_forward_tanh_f32(params, dst);
  5362. } break;
  5363. default:
  5364. {
  5365. GGML_ABORT("fatal error");
  5366. }
  5367. }
  5368. }
  5369. // ggml_compute_forward_elu
  5370. static void ggml_compute_forward_elu_f32(
  5371. const struct ggml_compute_params * params,
  5372. struct ggml_tensor * dst) {
  5373. const struct ggml_tensor * src0 = dst->src[0];
  5374. if (params->ith != 0) {
  5375. return;
  5376. }
  5377. assert(ggml_is_contiguous_1(src0));
  5378. assert(ggml_is_contiguous_1(dst));
  5379. assert(ggml_are_same_shape(src0, dst));
  5380. const int n = ggml_nrows(src0);
  5381. const int nc = src0->ne[0];
  5382. for (int i = 0; i < n; i++) {
  5383. ggml_vec_elu_f32(nc,
  5384. (float *) ((char *) dst->data + i*( dst->nb[1])),
  5385. (float *) ((char *) src0->data + i*(src0->nb[1])));
  5386. }
  5387. }
  5388. static void ggml_compute_forward_elu(
  5389. const struct ggml_compute_params * params,
  5390. struct ggml_tensor * dst) {
  5391. const struct ggml_tensor * src0 = dst->src[0];
  5392. switch (src0->type) {
  5393. case GGML_TYPE_F32:
  5394. {
  5395. ggml_compute_forward_elu_f32(params, dst);
  5396. } break;
  5397. default:
  5398. {
  5399. GGML_ABORT("fatal error");
  5400. }
  5401. }
  5402. }
  5403. // ggml_compute_forward_relu
  5404. static void ggml_compute_forward_relu_f32(
  5405. const struct ggml_compute_params * params,
  5406. struct ggml_tensor * dst) {
  5407. const struct ggml_tensor * src0 = dst->src[0];
  5408. if (params->ith != 0) {
  5409. return;
  5410. }
  5411. assert(ggml_is_contiguous_1(src0));
  5412. assert(ggml_is_contiguous_1(dst));
  5413. assert(ggml_are_same_shape(src0, dst));
  5414. const int n = ggml_nrows(src0);
  5415. const int nc = src0->ne[0];
  5416. for (int i = 0; i < n; i++) {
  5417. ggml_vec_relu_f32(nc,
  5418. (float *) ((char *) dst->data + i*( dst->nb[1])),
  5419. (float *) ((char *) src0->data + i*(src0->nb[1])));
  5420. }
  5421. }
  5422. static void ggml_compute_forward_relu(
  5423. const struct ggml_compute_params * params,
  5424. struct ggml_tensor * dst) {
  5425. const struct ggml_tensor * src0 = dst->src[0];
  5426. switch (src0->type) {
  5427. case GGML_TYPE_F32:
  5428. {
  5429. ggml_compute_forward_relu_f32(params, dst);
  5430. } break;
  5431. default:
  5432. {
  5433. GGML_ABORT("fatal error");
  5434. }
  5435. }
  5436. }
  5437. // ggml_compute_forward_sigmoid
  5438. static void ggml_compute_forward_sigmoid_f32(
  5439. const struct ggml_compute_params * params,
  5440. struct ggml_tensor * dst) {
  5441. const struct ggml_tensor * src0 = dst->src[0];
  5442. if (params->ith != 0) {
  5443. return;
  5444. }
  5445. assert(ggml_is_contiguous_1(src0));
  5446. assert(ggml_is_contiguous_1(dst));
  5447. assert(ggml_are_same_shape(src0, dst));
  5448. const int n = ggml_nrows(src0);
  5449. const int nc = src0->ne[0];
  5450. for (int i = 0; i < n; i++) {
  5451. ggml_vec_sigmoid_f32(nc,
  5452. (float *) ((char *) dst->data + i*( dst->nb[1])),
  5453. (float *) ((char *) src0->data + i*(src0->nb[1])));
  5454. }
  5455. }
  5456. static void ggml_compute_forward_sigmoid(
  5457. const struct ggml_compute_params * params,
  5458. struct ggml_tensor * dst) {
  5459. const struct ggml_tensor * src0 = dst->src[0];
  5460. switch (src0->type) {
  5461. case GGML_TYPE_F32:
  5462. {
  5463. ggml_compute_forward_sigmoid_f32(params, dst);
  5464. } break;
  5465. default:
  5466. {
  5467. GGML_ABORT("fatal error");
  5468. }
  5469. }
  5470. }
  5471. // ggml_compute_forward_gelu
  5472. static void ggml_compute_forward_gelu_f32(
  5473. const struct ggml_compute_params * params,
  5474. struct ggml_tensor * dst) {
  5475. const struct ggml_tensor * src0 = dst->src[0];
  5476. assert(ggml_is_contiguous_1(src0));
  5477. assert(ggml_is_contiguous_1(dst));
  5478. assert(ggml_are_same_shape(src0, dst));
  5479. const int ith = params->ith;
  5480. const int nth = params->nth;
  5481. const int nc = src0->ne[0];
  5482. const int nr = ggml_nrows(src0);
  5483. // rows per thread
  5484. const int dr = (nr + nth - 1)/nth;
  5485. // row range for this thread
  5486. const int ir0 = dr*ith;
  5487. const int ir1 = MIN(ir0 + dr, nr);
  5488. for (int i1 = ir0; i1 < ir1; i1++) {
  5489. ggml_vec_gelu_f32(nc,
  5490. (float *) ((char *) dst->data + i1*( dst->nb[1])),
  5491. (float *) ((char *) src0->data + i1*(src0->nb[1])));
  5492. #ifndef NDEBUG
  5493. for (int k = 0; k < nc; k++) {
  5494. const float x = ((float *) ((char *) dst->data + i1*( dst->nb[1])))[k];
  5495. UNUSED(x);
  5496. assert(!isnan(x));
  5497. assert(!isinf(x));
  5498. }
  5499. #endif
  5500. }
  5501. }
  5502. static void ggml_compute_forward_gelu(
  5503. const struct ggml_compute_params * params,
  5504. struct ggml_tensor * dst) {
  5505. const struct ggml_tensor * src0 = dst->src[0];
  5506. switch (src0->type) {
  5507. case GGML_TYPE_F32:
  5508. {
  5509. ggml_compute_forward_gelu_f32(params, dst);
  5510. } break;
  5511. default:
  5512. {
  5513. GGML_ABORT("fatal error");
  5514. }
  5515. }
  5516. }
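// Editorial sketch (assumption: ggml_vec_gelu_f32 uses the common tanh-based GELU
// approximation; the actual kernel lives elsewhere in ggml). Scalar reference of that
// approximation, GELU(x) ~= 0.5*x*(1 + tanh(sqrt(2/pi)*(x + 0.044715*x^3))), assuming
// <math.h> is already included above (hypothetical name):
static inline float example_gelu_ref_f32(float x) {
    const float c = 0.79788456080286535588f; // sqrt(2/pi)
    return 0.5f*x*(1.0f + tanhf(c*(x + 0.044715f*x*x*x)));
}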
  5517. // ggml_compute_forward_gelu_quick
  5518. static void ggml_compute_forward_gelu_quick_f32(
  5519. const struct ggml_compute_params * params,
  5520. struct ggml_tensor * dst) {
  5521. const struct ggml_tensor * src0 = dst->src[0];
  5522. assert(ggml_is_contiguous_1(src0));
  5523. assert(ggml_is_contiguous_1(dst));
  5524. assert(ggml_are_same_shape(src0, dst));
  5525. const int ith = params->ith;
  5526. const int nth = params->nth;
  5527. const int nc = src0->ne[0];
  5528. const int nr = ggml_nrows(src0);
  5529. // rows per thread
  5530. const int dr = (nr + nth - 1)/nth;
  5531. // row range for this thread
  5532. const int ir0 = dr*ith;
  5533. const int ir1 = MIN(ir0 + dr, nr);
  5534. for (int i1 = ir0; i1 < ir1; i1++) {
  5535. ggml_vec_gelu_quick_f32(nc,
  5536. (float *) ((char *) dst->data + i1*( dst->nb[1])),
  5537. (float *) ((char *) src0->data + i1*(src0->nb[1])));
  5538. #ifndef NDEBUG
  5539. for (int k = 0; k < nc; k++) {
  5540. const float x = ((float *) ((char *) dst->data + i1*( dst->nb[1])))[k];
  5541. UNUSED(x);
  5542. assert(!isnan(x));
  5543. assert(!isinf(x));
  5544. }
  5545. #endif
  5546. }
  5547. }
  5548. static void ggml_compute_forward_gelu_quick(
  5549. const struct ggml_compute_params * params,
  5550. struct ggml_tensor * dst) {
  5551. const struct ggml_tensor * src0 = dst->src[0];
  5552. switch (src0->type) {
  5553. case GGML_TYPE_F32:
  5554. {
  5555. ggml_compute_forward_gelu_quick_f32(params, dst);
  5556. } break;
  5557. default:
  5558. {
  5559. GGML_ABORT("fatal error");
  5560. }
  5561. }
  5562. }
  5563. // ggml_compute_forward_silu
  5564. static void ggml_compute_forward_silu_f32(
  5565. const struct ggml_compute_params * params,
  5566. struct ggml_tensor * dst) {
  5567. const struct ggml_tensor * src0 = dst->src[0];
  5568. assert(ggml_is_contiguous_1(src0));
  5569. assert(ggml_is_contiguous_1(dst));
  5570. assert(ggml_are_same_shape(src0, dst));
  5571. const int ith = params->ith;
  5572. const int nth = params->nth;
  5573. const int nc = src0->ne[0];
  5574. const int nr = ggml_nrows(src0);
  5575. // rows per thread
  5576. const int dr = (nr + nth - 1)/nth;
  5577. // row range for this thread
  5578. const int ir0 = dr*ith;
  5579. const int ir1 = MIN(ir0 + dr, nr);
  5580. for (int i1 = ir0; i1 < ir1; i1++) {
  5581. ggml_vec_silu_f32(nc,
  5582. (float *) ((char *) dst->data + i1*( dst->nb[1])),
  5583. (float *) ((char *) src0->data + i1*(src0->nb[1])));
  5584. #ifndef NDEBUG
  5585. for (int k = 0; k < nc; k++) {
  5586. const float x = ((float *) ((char *) dst->data + i1*(dst->nb[1])))[k];
  5587. UNUSED(x);
  5588. assert(!isnan(x));
  5589. assert(!isinf(x));
  5590. }
  5591. #endif
  5592. }
  5593. }
  5594. static void ggml_compute_forward_silu(
  5595. const struct ggml_compute_params * params,
  5596. struct ggml_tensor * dst) {
  5597. const struct ggml_tensor * src0 = dst->src[0];
  5598. switch (src0->type) {
  5599. case GGML_TYPE_F32:
  5600. {
  5601. ggml_compute_forward_silu_f32(params, dst);
  5602. } break;
  5603. default:
  5604. {
  5605. GGML_ABORT("fatal error");
  5606. }
  5607. }
  5608. }
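// Editorial sketch (not part of the original source): SiLU (a.k.a. swish) is
// x * sigmoid(x). Scalar reference, assuming <math.h> is already included above
// (hypothetical name):
static inline float example_silu_ref_f32(float x) {
    return x/(1.0f + expf(-x));
}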
  5609. // ggml_compute_forward_leaky_relu
  5610. static void ggml_compute_forward_leaky_relu_f32(
  5611. const struct ggml_compute_params * params,
  5612. struct ggml_tensor * dst) {
  5613. const struct ggml_tensor * src0 = dst->src[0];
  5614. if (params->ith != 0) {
  5615. return;
  5616. }
  5617. assert(ggml_is_contiguous_1(src0));
  5618. assert(ggml_is_contiguous_1(dst));
  5619. assert(ggml_are_same_shape(src0, dst));
  5620. const int n = ggml_nrows(src0);
  5621. const int nc = src0->ne[0];
  5622. float negative_slope;
  5623. memcpy(&negative_slope, dst->op_params, sizeof(float));
  5624. assert(dst->nb[0] == sizeof(float));
  5625. assert(src0->nb[0] == sizeof(float));
  5626. for (int i = 0; i < n; i++) {
  5627. ggml_vec_leaky_relu_f32(nc,
  5628. (float *) ((char *) dst->data + i*( dst->nb[1])),
  5629. (float *) ((char *) src0->data + i*(src0->nb[1])), negative_slope);
  5630. }
  5631. }
  5632. static void ggml_compute_forward_leaky_relu(
  5633. const struct ggml_compute_params * params,
  5634. struct ggml_tensor * dst) {
  5635. const struct ggml_tensor * src0 = dst->src[0];
  5636. switch (src0->type) {
  5637. case GGML_TYPE_F32:
  5638. {
  5639. ggml_compute_forward_leaky_relu_f32(params, dst);
  5640. } break;
  5641. default:
  5642. {
  5643. GGML_ABORT("fatal error");
  5644. }
  5645. }
  5646. }
  5647. // ggml_compute_forward_silu_back
  5648. static void ggml_compute_forward_silu_back_f32(
  5649. const struct ggml_compute_params * params,
  5650. struct ggml_tensor * dst) {
  5651. const struct ggml_tensor * src0 = dst->src[0];
  5652. const struct ggml_tensor * grad = dst->src[1];
  5653. assert(ggml_is_contiguous_1(grad));
  5654. assert(ggml_is_contiguous_1(src0));
  5655. assert(ggml_is_contiguous_1(dst));
  5656. assert(ggml_are_same_shape(src0, dst));
  5657. assert(ggml_are_same_shape(src0, grad));
  5658. const int ith = params->ith;
  5659. const int nth = params->nth;
  5660. const int nc = src0->ne[0];
  5661. const int nr = ggml_nrows(src0);
  5662. // rows per thread
  5663. const int dr = (nr + nth - 1)/nth;
  5664. // row range for this thread
  5665. const int ir0 = dr*ith;
  5666. const int ir1 = MIN(ir0 + dr, nr);
  5667. for (int i1 = ir0; i1 < ir1; i1++) {
  5668. ggml_vec_silu_backward_f32(nc,
  5669. (float *) ((char *) dst->data + i1*( dst->nb[1])),
  5670. (float *) ((char *) src0->data + i1*(src0->nb[1])),
  5671. (float *) ((char *) grad->data + i1*(grad->nb[1])));
  5672. #ifndef NDEBUG
  5673. for (int k = 0; k < nc; k++) {
  5674. const float x = ((float *) ((char *) dst->data + i1*( dst->nb[1])))[k];
  5675. UNUSED(x);
  5676. assert(!isnan(x));
  5677. assert(!isinf(x));
  5678. }
  5679. #endif
  5680. }
  5681. }
  5682. static void ggml_compute_forward_silu_back(
  5683. const struct ggml_compute_params * params,
  5684. struct ggml_tensor * dst) {
  5685. const struct ggml_tensor * src0 = dst->src[0];
  5686. switch (src0->type) {
  5687. case GGML_TYPE_F32:
  5688. {
  5689. ggml_compute_forward_silu_back_f32(params, dst);
  5690. } break;
  5691. default:
  5692. {
  5693. GGML_ABORT("fatal error");
  5694. }
  5695. }
  5696. }
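// Editorial sketch (not part of the original source): the SiLU backward kernel applies
// dx = dz * d/dx[x*sigmoid(x)] = dz * s*(1 + x*(1 - s)) with s = sigmoid(x). Scalar
// reference, assuming <math.h> is already included above (hypothetical name):
static inline float example_silu_backward_ref_f32(float x, float dz) {
    const float s = 1.0f/(1.0f + expf(-x));
    return dz*s*(1.0f + x*(1.0f - s));
}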
  5697. static void ggml_compute_forward_hardswish_f32(
  5698. const struct ggml_compute_params * params,
  5699. struct ggml_tensor * dst) {
  5700. const struct ggml_tensor * src0 = dst->src[0];
  5701. if (params->ith != 0) {
  5702. return;
  5703. }
  5704. assert(ggml_is_contiguous_1(src0));
  5705. assert(ggml_is_contiguous_1(dst));
  5706. assert(ggml_are_same_shape(src0, dst));
  5707. const int n = ggml_nrows(src0);
  5708. const int nc = src0->ne[0];
  5709. for (int i = 0; i < n; i++) {
  5710. ggml_vec_hardswish_f32(nc,
  5711. (float *) ((char *) dst->data + i*( dst->nb[1])),
  5712. (float *) ((char *) src0->data + i*(src0->nb[1])));
  5713. }
  5714. }
  5715. static void ggml_compute_forward_hardswish(
  5716. const struct ggml_compute_params * params,
  5717. struct ggml_tensor * dst) {
  5718. const struct ggml_tensor * src0 = dst->src[0];
  5719. switch (src0->type) {
  5720. case GGML_TYPE_F32:
  5721. {
  5722. ggml_compute_forward_hardswish_f32(params, dst);
  5723. } break;
  5724. default:
  5725. {
  5726. GGML_ABORT("fatal error");
  5727. }
  5728. }
  5729. }
  5730. static void ggml_compute_forward_hardsigmoid_f32(
  5731. const struct ggml_compute_params * params,
  5732. struct ggml_tensor * dst) {
  5733. const struct ggml_tensor * src0 = dst->src[0];
  5734. if (params->ith != 0) {
  5735. return;
  5736. }
  5737. assert(ggml_is_contiguous_1(src0));
  5738. assert(ggml_is_contiguous_1(dst));
  5739. assert(ggml_are_same_shape(src0, dst));
  5740. const int n = ggml_nrows(src0);
  5741. const int nc = src0->ne[0];
  5742. for (int i = 0; i < n; i++) {
  5743. ggml_vec_hardsigmoid_f32(nc,
  5744. (float *) ((char *) dst->data + i*( dst->nb[1])),
  5745. (float *) ((char *) src0->data + i*(src0->nb[1])));
  5746. }
  5747. }
  5748. static void ggml_compute_forward_hardsigmoid(
  5749. const struct ggml_compute_params * params,
  5750. struct ggml_tensor * dst) {
  5751. const struct ggml_tensor * src0 = dst->src[0];
  5752. switch (src0->type) {
  5753. case GGML_TYPE_F32:
  5754. {
  5755. ggml_compute_forward_hardsigmoid_f32(params, dst);
  5756. } break;
  5757. default:
  5758. {
  5759. GGML_ABORT("fatal error");
  5760. }
  5761. }
  5762. }
  5763. static void ggml_compute_forward_exp_f32(
  5764. const struct ggml_compute_params * params,
  5765. struct ggml_tensor * dst) {
  5766. const struct ggml_tensor * src0 = dst->src[0];
  5767. if (params->ith != 0) {
  5768. return;
  5769. }
  5770. assert(ggml_is_contiguous_1(src0));
  5771. assert(ggml_is_contiguous_1(dst));
  5772. assert(ggml_are_same_shape(src0, dst));
  5773. const int n = ggml_nrows(src0);
  5774. const int nc = src0->ne[0];
  5775. for (int i = 0; i < n; i++) {
  5776. ggml_vec_exp_f32(nc,
  5777. (float *) ((char *) dst->data + i*( dst->nb[1])),
  5778. (float *) ((char *) src0->data + i*(src0->nb[1])));
  5779. }
  5780. }
  5781. static void ggml_compute_forward_exp(
  5782. const struct ggml_compute_params * params,
  5783. struct ggml_tensor * dst) {
  5784. const struct ggml_tensor * src0 = dst->src[0];
  5785. switch (src0->type) {
  5786. case GGML_TYPE_F32:
  5787. {
  5788. ggml_compute_forward_exp_f32(params, dst);
  5789. } break;
  5790. default:
  5791. {
  5792. GGML_ABORT("fatal error");
  5793. }
  5794. }
  5795. }
  5796. // ggml_compute_forward_norm
  5797. static void ggml_compute_forward_norm_f32(
  5798. const struct ggml_compute_params * params,
  5799. struct ggml_tensor * dst) {
  5800. const struct ggml_tensor * src0 = dst->src[0];
  5801. GGML_ASSERT(ggml_are_same_shape(src0, dst));
  5802. GGML_ASSERT(src0->nb[0] == sizeof(float));
  5803. const int ith = params->ith;
  5804. const int nth = params->nth;
  5805. GGML_TENSOR_UNARY_OP_LOCALS
  5806. float eps;
  5807. memcpy(&eps, dst->op_params, sizeof(float));
  5808. GGML_ASSERT(eps > 0.0f);
  5809. // TODO: optimize
  5810. for (int64_t i03 = 0; i03 < ne03; i03++) {
  5811. for (int64_t i02 = 0; i02 < ne02; i02++) {
  5812. for (int64_t i01 = ith; i01 < ne01; i01 += nth) {
  5813. const float * x = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
  5814. ggml_float sum = 0.0;
  5815. for (int64_t i00 = 0; i00 < ne00; i00++) {
  5816. sum += (ggml_float)x[i00];
  5817. }
  5818. float mean = sum/ne00;
  5819. float * y = (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3);
  5820. ggml_float sum2 = 0.0;
  5821. for (int64_t i00 = 0; i00 < ne00; i00++) {
  5822. float v = x[i00] - mean;
  5823. y[i00] = v;
  5824. sum2 += (ggml_float)(v*v);
  5825. }
  5826. float variance = sum2/ne00;
  5827. const float scale = 1.0f/sqrtf(variance + eps);
  5828. ggml_vec_scale_f32(ne00, y, scale);
  5829. }
  5830. }
  5831. }
  5832. }
  5833. static void ggml_compute_forward_norm(
  5834. const struct ggml_compute_params * params,
  5835. struct ggml_tensor * dst) {
  5836. const struct ggml_tensor * src0 = dst->src[0];
  5837. switch (src0->type) {
  5838. case GGML_TYPE_F32:
  5839. {
  5840. ggml_compute_forward_norm_f32(params, dst);
  5841. } break;
  5842. default:
  5843. {
  5844. GGML_ABORT("fatal error");
  5845. }
  5846. }
  5847. }
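// Editorial sketch (not part of the original source): per row, the norm above computes
// y = (x - mean(x)) / sqrt(var(x) + eps). Scalar reference for one row, assuming
// <math.h> is already included above (hypothetical name):
static inline void example_norm_row_ref_f32(float * y, const float * x, int64_t n, float eps) {
    double sum = 0.0;
    for (int64_t i = 0; i < n; i++) sum += (double) x[i];
    const float mean = (float)(sum/n);
    double sum2 = 0.0;
    for (int64_t i = 0; i < n; i++) {
        const float v = x[i] - mean;
        y[i] = v;
        sum2 += (double)(v*v);
    }
    const float scale = 1.0f/sqrtf((float)(sum2/n) + eps);
    for (int64_t i = 0; i < n; i++) y[i] *= scale;
}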
5848. // ggml_compute_forward_rms_norm
  5849. static void ggml_compute_forward_rms_norm_f32(
  5850. const struct ggml_compute_params * params,
  5851. struct ggml_tensor * dst) {
  5852. const struct ggml_tensor * src0 = dst->src[0];
  5853. GGML_ASSERT(ggml_are_same_shape(src0, dst));
  5854. GGML_ASSERT(src0->nb[0] == sizeof(float));
  5855. const int ith = params->ith;
  5856. const int nth = params->nth;
  5857. GGML_TENSOR_UNARY_OP_LOCALS
  5858. float eps;
  5859. memcpy(&eps, dst->op_params, sizeof(float));
  5860. GGML_ASSERT(eps > 0.0f);
  5861. // TODO: optimize
  5862. for (int64_t i03 = 0; i03 < ne03; i03++) {
  5863. for (int64_t i02 = 0; i02 < ne02; i02++) {
  5864. for (int64_t i01 = ith; i01 < ne01; i01 += nth) {
  5865. const float * x = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
  5866. ggml_float sum = 0.0;
  5867. for (int64_t i00 = 0; i00 < ne00; i00++) {
  5868. sum += (ggml_float)(x[i00] * x[i00]);
  5869. }
  5870. const float mean = sum/ne00;
  5871. float * y = (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3);
  5872. memcpy(y, x, ne00 * sizeof(float));
  5873. // for (int i00 = 0; i00 < ne00; i00++) {
  5874. // y[i00] = x[i00];
  5875. // }
  5876. const float scale = 1.0f/sqrtf(mean + eps);
  5877. ggml_vec_scale_f32(ne00, y, scale);
  5878. }
  5879. }
  5880. }
  5881. }
  5882. static void ggml_compute_forward_rms_norm(
  5883. const struct ggml_compute_params * params,
  5884. struct ggml_tensor * dst) {
  5885. const struct ggml_tensor * src0 = dst->src[0];
  5886. switch (src0->type) {
  5887. case GGML_TYPE_F32:
  5888. {
  5889. ggml_compute_forward_rms_norm_f32(params, dst);
  5890. } break;
  5891. default:
  5892. {
  5893. GGML_ABORT("fatal error");
  5894. }
  5895. }
  5896. }
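// Editorial sketch (not part of the original source): per row, rms_norm computes
// y = x / sqrt(mean(x^2) + eps), i.e. it rescales without subtracting the mean.
// Scalar reference for one row, assuming <math.h> is already included above
// (hypothetical name):
static inline void example_rms_norm_row_ref_f32(float * y, const float * x, int64_t n, float eps) {
    double sum = 0.0;
    for (int64_t i = 0; i < n; i++) sum += (double)(x[i]*x[i]);
    const float scale = 1.0f/sqrtf((float)(sum/n) + eps);
    for (int64_t i = 0; i < n; i++) y[i] = x[i]*scale;
}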
  5897. static void ggml_compute_forward_rms_norm_back_f32(
  5898. const struct ggml_compute_params * params,
  5899. struct ggml_tensor * dst) {
  5900. const struct ggml_tensor * src0 = dst->src[0];
  5901. const struct ggml_tensor * src1 = dst->src[1];
  5902. GGML_ASSERT(ggml_are_same_shape(src0, dst) && ggml_are_same_shape(src0, src1));
  5903. GGML_ASSERT(src0->nb[0] == sizeof(float));
  5904. const int ith = params->ith;
  5905. const int nth = params->nth;
  5906. GGML_TENSOR_BINARY_OP_LOCALS
  5907. float eps;
  5908. memcpy(&eps, dst->op_params, sizeof(float));
  5909. // TODO: optimize
  5910. for (int64_t i03 = 0; i03 < ne03; i03++) {
  5911. for (int64_t i02 = 0; i02 < ne02; i02++) {
  5912. for (int64_t i01 = ith; i01 < ne01; i01 += nth) {
  5913. // src1 is same shape as src0 => same indices
  5914. const int64_t i11 = i01;
  5915. const int64_t i12 = i02;
  5916. const int64_t i13 = i03;
  5917. const float * x = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
  5918. const float * dz = (float *) ((char *) src1->data + i11*nb11 + i12*nb12 + i13*nb13);
  5919. ggml_float sum_xx = 0.0;
  5920. ggml_float sum_xdz = 0.0;
  5921. for (int64_t i00 = 0; i00 < ne00; i00++) {
  5922. sum_xx += (ggml_float)(x[i00] * x[i00]);
  5923. sum_xdz += (ggml_float)(x[i00] * dz[i00]);
  5924. }
  5925. //const float mean = (float)(sum_xx)/ne00;
  5926. const float mean_eps = (float)(sum_xx)/ne00 + eps;
  5927. const float sum_eps = (float)(sum_xx) + eps*ne00;
  5928. //const float mean_xdz = (float)(sum_xdz)/ne00;
  5929. // we could cache rms from forward pass to improve performance.
  5930. // to do this implement ggml_rms and compose ggml_rms_norm using ggml_rms.
  5931. //const float rms = sqrtf(mean_eps);
  5932. const float rrms = 1.0f / sqrtf(mean_eps);
  5933. //const float scale = -rrms/(ne00 * mean_eps); // -1/(n*rms**3)
  5934. {
  5935. // z = rms_norm(x)
  5936. //
  5937. // rms_norm(src0) =
  5938. // scale(
  5939. // src0,
  5940. // div(
  5941. // 1,
  5942. // sqrt(
  5943. // add(
  5944. // scale(
  5945. // sum(
  5946. // sqr(
  5947. // src0)),
  5948. // (1.0/N)),
  5949. // eps))));
  5950. // postorder:
  5951. // ## op args grad
  5952. // 00 param src0 grad[#00]
  5953. // 01 const 1
  5954. // 02 sqr (#00) grad[#02]
  5955. // 03 sum (#02) grad[#03]
  5956. // 04 const 1/N
  5957. // 05 scale (#03, #04) grad[#05]
  5958. // 06 const eps
  5959. // 07 add (#05, #06) grad[#07]
  5960. // 08 sqrt (#07) grad[#08]
  5961. // 09 div (#01,#08) grad[#09]
  5962. // 10 scale (#00,#09) grad[#10]
  5963. //
  5964. // backward pass, given grad[#10]
  5965. // #10: scale
  5966. // grad[#00] += scale(grad[#10],#09)
  5967. // grad[#09] += sum(mul(grad[#10],#00))
  5968. // #09: div
  5969. // grad[#08] += neg(mul(grad[#09], div(#09,#08)))
  5970. // #08: sqrt
  5971. // grad[#07] += mul(grad[#08], div(0.5, #08))
  5972. // #07: add
  5973. // grad[#05] += grad[#07]
  5974. // #05: scale
  5975. // grad[#03] += scale(grad[#05],#04)
  5976. // #03: sum
  5977. // grad[#02] += repeat(grad[#03], #02)
  5978. // #02:
  5979. // grad[#00] += scale(mul(#00, grad[#02]), 2.0)
  5980. //
  5981. // substitute and simplify:
  5982. // grad[#00] = scale(grad(#10), #09) + scale(mul(#00, grad[#02]), 2.0)
  5983. // grad[#02] = repeat(grad[#03], #02)
  5984. // grad[#02] = repeat(scale(grad[#05],#04), #02)
  5985. // grad[#02] = repeat(scale(grad[#07],#04), #02)
  5986. // grad[#02] = repeat(scale(mul(grad[#08], div(0.5, #08)),#04), #02)
  5987. // grad[#02] = repeat(scale(mul(neg(mul(grad[#09], div(#09,#08))), div(0.5, #08)),#04), #02)
  5988. // grad[#02] = repeat(scale(mul(neg(mul(sum(mul(grad[#10],#00)), div(#09,#08))), div(0.5, #08)),#04), #02)
  5989. // grad[#02] = repeat(-(sum(mul(grad[#10],#00)) * div(#09,#08) * div(0.5, #08) * (1/N)), #02)
  5990. // grad[#02] = repeat(-(sum(mul(grad[#10],#00)) * div(div(#01,#08),#08) * div(0.5, #08) * (1/N)), #02)
  5991. // grad[#02] = repeat(-(sum(mul(grad[#10],#00)) * div(1,#08*#08) * div(0.5, #08) * (1/N)), #02)
  5992. // grad[#02] = repeat(-(sum(mul(grad[#10],#00)) * div(1,#07) * div(0.5, #08) * (1/N)), #02)
  5993. // grad[#00] = scale(grad(#10), #09) + scale(mul(#00, grad[#02]), 2.0)
  5994. // grad[#00] = scale(grad(#10), #09) + scale(mul(#00, repeat(-(sum(mul(grad[#10],#00)) * div(1,#07) * div(0.5, #08) * (1/N)), #02)), 2.0)
  5995. // grad[#00] = scale(grad(#10), #09) + scale(scale(#00, -(sum(mul(grad[#10],#00)) * div(1,#07) * div(0.5, #08) * (1/N))), 2.0)
  5996. // grad[#00] = scale(grad(#10), #09) + scale(#00, -(sum(mul(grad[#10],#00)) * div(1,#07) * div(1,#08) * (1/N)))
  5997. // grad[#00] = scale(grad(#10), #09) + scale(#00, sum(mul(grad[#10],#00)) * div(1,#07*#08) * (-1/N))
  5998. // grad[#00] = scale(grad(#10), #09) + scale(#00, sum(mul(grad[#10],#00)) * div(1,#07*#08) * (-1/N))
  5999. // grad[#00] = scale(grad(#10), #09) + scale(#00, sum(mul(grad[#10],#00)) * div(1,mean_eps*rms) * (-1/N))
  6000. // grad[#00] = scale(grad(#10), #09) + scale(#00, sum(mul(grad[#10],#00)) * div(-1,rms*N*mean_eps))
  6001. // grad[#00] = scale(grad(#10), #09) + scale(#00, sum(mul(grad[#10],#00)) * div(-1,rms*N*(sum_xx/N+eps)))
  6002. // grad[#00] = scale(grad(#10), #09) + scale(#00, sum(mul(grad[#10],#00)) * div(-1,rms*N*sum_xx+rms*N*eps))
  6003. // grad[#00] = scale(dz, rrms) + scale(x, sum(mul(dz,x)) * div(-1,rms*N*mean_eps))
  6004. // grad[#00] = scale(dz, rrms) + scale(x, sum_xdz * div(-1,rms*N*mean_eps))
  6005. // a = b*c + d*e
  6006. // a = b*c*f/f + d*e*f/f
  6007. // a = (b*c*f + d*e*f)*(1/f)
  6008. // a = (b*c*(1/c) + d*e*(1/c))*(1/(1/c))
  6009. // a = (b + d*e/c)*c
  6010. // b = dz, c = rrms, d = x, e = sum_xdz * div(-1,rms*N*mean_eps)
  6011. // a = (dz + x*sum_xdz * div(-1,rms*N*mean_eps)/rrms)*rrms
  6012. // a = (dz + x*sum_xdz * div(-1,rms*N*mean_eps)*rms)*rrms
  6013. // a = (dz + x*sum_xdz * div(-rms,rms*N*mean_eps))*rrms
  6014. // a = (dz + x*sum_xdz * div(-1,N*mean_eps))*rrms
  6015. // a = (dz + x*div(-sum_xdz,N*mean_eps))*rrms
  6016. // a = (dz + x*div(-mean_xdz,mean_eps))*rrms
  6017. // grad[#00] = scale(dz + scale(x, div(-mean_xdz,mean_eps)),rrms)
  6018. // grad[#00] = scale(dz + scale(x, -mean_xdz/mean_eps),rrms)
  6019. // dx = scale(dz + scale(x, -mean_xdz/mean_eps),rrms)
  6020. }
  6021. // dx = scale(dz + scale(x, -mean_xdz/mean_eps),rrms)
  6022. // post-order:
  6023. // dx := x
  6024. // dx := scale(dx,-mean_xdz/mean_eps)
  6025. // dx := add(dx, dz)
  6026. // dx := scale(dx, rrms)
  6027. float * dx = (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3);
  6028. ggml_vec_cpy_f32 (ne00, dx, x);
  6029. // ggml_vec_scale_f32(ne00, dx, -mean_xdz/mean_eps);
  6030. ggml_vec_scale_f32(ne00, dx, (float)(-sum_xdz)/sum_eps);
  6031. ggml_vec_acc_f32 (ne00, dx, dz);
  6032. ggml_vec_scale_f32(ne00, dx, rrms);
  6033. }
  6034. }
  6035. }
  6036. }
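// Illustrative sketch: the long derivation above boils the rms_norm backward pass
// down to, per row of N elements,
//   dx = (dz + x * (-sum_xdz / (N*mean_eps))) * rrms
// which is exactly what the ggml_vec_cpy/scale/acc/scale sequence implements.
// A minimal standalone restatement on plain float arrays follows, kept commented out
// like the other reference snippets in this file; rms_norm_back_row_ref is a
// hypothetical helper name, not a ggml symbol.
//
// static void rms_norm_back_row_ref(int n, float * dx, const float * x, const float * dz, float eps) {
//     double sum_xx  = 0.0;
//     double sum_xdz = 0.0;
//     for (int i = 0; i < n; i++) {
//         sum_xx  += (double) x[i] * x[i];   // sum of x^2
//         sum_xdz += (double) x[i] * dz[i];  // sum of x*dz
//     }
//     const float mean_eps = (float)(sum_xx / n) + eps;  // mean(x^2) + eps
//     const float sum_eps  = (float) sum_xx  + eps * n;  // N*mean_eps
//     const float rrms     = 1.0f / sqrtf(mean_eps);     // 1/rms
//     for (int i = 0; i < n; i++) {
//         dx[i] = (dz[i] + x[i] * (float)(-sum_xdz) / sum_eps) * rrms;
//     }
// }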
  6037. static void ggml_compute_forward_rms_norm_back(
  6038. const struct ggml_compute_params * params,
  6039. struct ggml_tensor * dst) {
  6040. const struct ggml_tensor * src0 = dst->src[0];
  6041. switch (src0->type) {
  6042. case GGML_TYPE_F32:
  6043. {
  6044. ggml_compute_forward_rms_norm_back_f32(params, dst);
  6045. } break;
  6046. default:
  6047. {
  6048. GGML_ABORT("fatal error");
  6049. }
  6050. }
  6051. }
  6052. // ggml_compute_forward_group_norm
  6053. static void ggml_compute_forward_group_norm_f32(
  6054. const struct ggml_compute_params * params,
  6055. struct ggml_tensor * dst) {
  6056. const struct ggml_tensor * src0 = dst->src[0];
  6057. GGML_ASSERT(ggml_are_same_shape(src0, dst));
  6058. GGML_ASSERT(src0->nb[0] == sizeof(float));
  6059. const int ith = params->ith;
  6060. const int nth = params->nth;
  6061. GGML_TENSOR_UNARY_OP_LOCALS
  6062. // TODO: optimize
  6063. float eps;
  6064. memcpy(&eps, dst->op_params + 1, sizeof(float));
  6065. int n_channels = src0->ne[2];
  6066. int n_groups = dst->op_params[0];
  6067. int n_channels_per_group = (n_channels + n_groups - 1) / n_groups;
  6068. for (int i = ith; i < n_groups; i += nth) {
  6069. int start = i * n_channels_per_group;
  6070. int end = start + n_channels_per_group;
  6071. if (end > n_channels) {
  6072. end = n_channels;
  6073. }
  6074. int step = end - start;
  6075. for (int64_t i03 = 0; i03 < ne03; i03++) {
  6076. ggml_float sum = 0.0;
  6077. for (int64_t i02 = start; i02 < end; i02++) {
  6078. for (int64_t i01 = 0; i01 < ne01; i01++) {
  6079. const float * x = (float *)((char *) src0->data + i01 * nb01 + i02 * nb02 + i03 * nb03);
  6080. ggml_float sumr = 0.0;
  6081. for (int64_t i00 = 0; i00 < ne00; i00++) {
  6082. sumr += (ggml_float)x[i00];
  6083. }
  6084. sum += sumr;
  6085. }
  6086. }
  6087. const float mean = sum / (ne00 * ne01 * step);
  6088. ggml_float sum2 = 0.0;
  6089. for (int64_t i02 = start; i02 < end; i02++) {
  6090. for (int64_t i01 = 0; i01 < ne01; i01++) {
  6091. const float * x = (float *)((char *) src0->data + i01 * nb01 + i02 * nb02 + i03 * nb03);
  6092. float * y = (float *)((char *) dst->data + i01 * nb1 + i02 * nb2 + i03 * nb3);
  6093. ggml_float sumr = 0.0;
  6094. for (int64_t i00 = 0; i00 < ne00; i00++) {
  6095. float v = x[i00] - mean;
  6096. y[i00] = v;
  6097. sumr += (ggml_float)(v * v);
  6098. }
  6099. sum2 += sumr;
  6100. }
  6101. }
  6102. const float variance = sum2 / (ne00 * ne01 * step);
  6103. const float scale = 1.0f / sqrtf(variance + eps);
  6104. for (int64_t i02 = start; i02 < end; i02++) {
  6105. for (int64_t i01 = 0; i01 < ne01; i01++) {
  6106. float * y = (float *)((char *) dst->data + i01 * nb1 + i02 * nb2 + i03 * nb3);
  6107. ggml_vec_scale_f32(ne00, y, scale);
  6108. }
  6109. }
  6110. }
  6111. }
  6112. }
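// Illustrative sketch: for each group the loops above compute
//   y = (x - mean(x)) / sqrt(var(x) + eps)
// with mean and variance taken over all elements of the channels in the group.
// A minimal single-group restatement on a flat array follows (commented out);
// group_norm_ref is a hypothetical helper name, not a ggml symbol.
//
// static void group_norm_ref(int n, float * y, const float * x, float eps) {
//     double sum = 0.0;
//     for (int i = 0; i < n; i++) {
//         sum += x[i];
//     }
//     const float mean = (float)(sum / n);
//     double sum2 = 0.0;
//     for (int i = 0; i < n; i++) {
//         const float v = x[i] - mean;
//         y[i]  = v;                 // store the centered value, as the main loop does
//         sum2 += (double)(v * v);
//     }
//     const float variance = (float)(sum2 / n);
//     const float scale    = 1.0f / sqrtf(variance + eps);
//     for (int i = 0; i < n; i++) {
//         y[i] *= scale;
//     }
// }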
  6113. static void ggml_compute_forward_group_norm(
  6114. const struct ggml_compute_params * params,
  6115. struct ggml_tensor * dst) {
  6116. const struct ggml_tensor * src0 = dst->src[0];
  6117. switch (src0->type) {
  6118. case GGML_TYPE_F32:
  6119. {
  6120. ggml_compute_forward_group_norm_f32(params, dst);
  6121. } break;
  6122. default:
  6123. {
  6124. GGML_ABORT("fatal error");
  6125. }
  6126. }
  6127. }
  6128. // ggml_compute_forward_mul_mat
  6129. static void ggml_compute_forward_mul_mat_one_chunk(
  6130. const struct ggml_compute_params * params,
  6131. struct ggml_tensor * dst,
  6132. const int64_t num_rows_per_vec_dot,
  6133. const int64_t ir0_start,
  6134. const int64_t ir0_end,
  6135. const int64_t ir1_start,
  6136. const int64_t ir1_end) {
  6137. const struct ggml_tensor * src0 = dst->src[0];
  6138. const struct ggml_tensor * src1 = dst->src[1];
  6139. GGML_TENSOR_BINARY_OP_LOCALS
  6140. const enum ggml_type type = src0->type;
  6141. const bool src1_cont = ggml_is_contiguous(src1);
  6142. ggml_vec_dot_t const vec_dot = type_traits_cpu[type].vec_dot;
  6143. enum ggml_type const vec_dot_type = type_traits_cpu[type].vec_dot_type;
  6144. // broadcast factors
  6145. const int64_t r2 = ne12 / ne02;
  6146. const int64_t r3 = ne13 / ne03;
  6147. //printf("ir0_start = %6lld, ir0_end = %6lld, ir1_start = %6lld, ir1_end = %6lld\n", ir0_start, ir0_end, ir1_start, ir1_end);
6148. // threads with no work simply return early (not sure if yielding instead would help)
  6149. if (ir0_start >= ir0_end || ir1_start >= ir1_end) {
  6150. return;
  6151. }
  6152. const void * wdata = (src1->type == vec_dot_type) ? src1->data : params->wdata;
  6153. const size_t row_size = ggml_row_size(vec_dot_type, ne10);
  6154. assert(ne12 % ne02 == 0);
  6155. assert(ne13 % ne03 == 0);
  6156. // block-tiling attempt
  6157. const int64_t blck_0 = 16;
  6158. const int64_t blck_1 = 16;
  6159. const size_t src1_col_stride = src1_cont || src1->type != vec_dot_type ? row_size : nb11;
  6160. // attempt to reduce false-sharing (does not seem to make a difference)
  6161. // 16 * 2, accounting for mmla kernels
  6162. float tmp[32];
  6163. for (int64_t iir1 = ir1_start; iir1 < ir1_end; iir1 += blck_1) {
  6164. for (int64_t iir0 = ir0_start; iir0 < ir0_end; iir0 += blck_0) {
  6165. for (int64_t ir1 = iir1; ir1 < iir1 + blck_1 && ir1 < ir1_end; ir1 += num_rows_per_vec_dot) {
  6166. const int64_t i13 = (ir1 / (ne12 * ne1));
  6167. const int64_t i12 = (ir1 - i13 * ne12 * ne1) / ne1;
  6168. const int64_t i11 = (ir1 - i13 * ne12 * ne1 - i12 * ne1);
  6169. // broadcast src0 into src1
  6170. const int64_t i03 = i13 / r3;
  6171. const int64_t i02 = i12 / r2;
  6172. const int64_t i1 = i11;
  6173. const int64_t i2 = i12;
  6174. const int64_t i3 = i13;
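// Worked example with hypothetical sizes: for ne1 = 4 and ne12 = 2, result row
// ir1 = 5 decomposes into i13 = 5/(2*4) = 0, i12 = (5 - 0)/4 = 1 and i11 = 5 - 4 = 1;
// the matching src0 batch indices then follow from the broadcast factors r2 and r3.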
  6175. const char * src0_row = (const char*)src0->data + (0 + i02 * nb02 + i03 * nb03);
  6176. // desc: when src1 is not a contiguous memory block we have to calculate the offset using the strides
  6177. // if it is, then we have either copied the data to params->wdata and made it contiguous or we are using
  6178. // the original src1 data pointer, so we should index using the indices directly
  6179. // TODO: this is a bit of a hack, we should probably have a better way to handle this
  6180. const char * src1_col = (const char*)wdata +
  6181. (src1_cont || src1->type != vec_dot_type
  6182. ? (i11 + i12 * ne11 + i13 * ne12 * ne11) * row_size
  6183. : (i11 * nb11 + i12 * nb12 + i13 * nb13));
  6184. float * dst_col = (float*)((char*)dst->data + (i1 * nb1 + i2 * nb2 + i3 * nb3));
  6185. //for (int64_t ir0 = iir0; ir0 < iir0 + blck_0 && ir0 < ir0_end; ++ir0) {
  6186. // vec_dot(ne00, &dst_col[ir0], src0_row + ir0*nb01, src1_col);
  6187. //}
  6188. for (int64_t ir0 = iir0; ir0 < iir0 + blck_0 && ir0 < ir0_end; ir0 += num_rows_per_vec_dot) {
  6189. vec_dot(ne00, &tmp[ir0 - iir0], (num_rows_per_vec_dot > 1 ? 16 : 0), src0_row + ir0 * nb01, (num_rows_per_vec_dot > 1 ? nb01 : 0), src1_col, (num_rows_per_vec_dot > 1 ? src1_col_stride : 0), num_rows_per_vec_dot);
  6190. }
  6191. for (int cn = 0; cn < num_rows_per_vec_dot; ++cn) {
  6192. memcpy(&dst_col[iir0 + cn * nb1 / nb0], tmp + (cn * 16), (MIN(iir0 + blck_0, ir0_end) - iir0) * sizeof(float));
  6193. }
  6194. }
  6195. }
  6196. }
  6197. }
  6198. static void ggml_compute_forward_mul_mat(
  6199. const struct ggml_compute_params * params,
  6200. struct ggml_tensor * dst) {
  6201. const struct ggml_tensor * src0 = dst->src[0];
  6202. const struct ggml_tensor * src1 = dst->src[1];
  6203. GGML_TENSOR_BINARY_OP_LOCALS
  6204. const int ith = params->ith;
  6205. const int nth = params->nth;
  6206. const enum ggml_type type = src0->type;
  6207. enum ggml_type const vec_dot_type = type_traits_cpu[type].vec_dot_type;
  6208. ggml_from_float_t const from_float = ggml_get_type_traits(vec_dot_type)->from_float;
  6209. ggml_from_float_to_mat_t const from_float_to_mat = type_traits_cpu[vec_dot_type].from_float_to_mat;
  6210. int64_t const vec_dot_num_rows = type_traits_cpu[type].nrows;
  6211. int64_t const matmul_num_cols = type_traits_cpu[type].ncols;
  6212. int64_t const blck_size_interleave = ggml_get_type_traits(type)->blck_size_interleave;
  6213. ggml_gemv_t const gemv = type_traits_cpu[type].gemv;
  6214. ggml_gemm_t const gemm = type_traits_cpu[type].gemm;
  6215. GGML_ASSERT(ne0 == ne01);
  6216. GGML_ASSERT(ne1 == ne11);
  6217. GGML_ASSERT(ne2 == ne12);
  6218. GGML_ASSERT(ne3 == ne13);
  6219. // we don't support permuted src0 or src1
  6220. GGML_ASSERT(nb00 == ggml_type_size(type));
  6221. GGML_ASSERT(nb10 == ggml_type_size(src1->type));
  6222. // dst cannot be transposed or permuted
  6223. GGML_ASSERT(nb0 == sizeof(float));
  6224. GGML_ASSERT(nb0 <= nb1);
  6225. GGML_ASSERT(nb1 <= nb2);
  6226. GGML_ASSERT(nb2 <= nb3);
  6227. // nb01 >= nb00 - src0 is not transposed
  6228. // compute by src0 rows
  6229. #if GGML_USE_LLAMAFILE
  6230. // broadcast factors
  6231. const int64_t r2 = ne12 / ne02;
  6232. const int64_t r3 = ne13 / ne03;
  6233. const bool src1_cont = ggml_is_contiguous(src1);
  6234. if (src1_cont) {
  6235. for (int64_t i13 = 0; i13 < ne13; i13++)
  6236. for (int64_t i12 = 0; i12 < ne12; i12++)
  6237. if (!llamafile_sgemm(ne01, ne11, ne00/ggml_blck_size(src0->type),
  6238. (const char *)src0->data + i12/r2*nb02 + i13/r3*nb03,
  6239. nb01/ggml_type_size(src0->type),
  6240. (const char *)src1->data + i12*nb12 + i13*nb13,
  6241. nb11/ggml_type_size(src1->type),
  6242. (char *)dst->data + i12*nb2 + i13*nb3,
  6243. nb1/ggml_type_size(dst->type),
  6244. ith, nth,
  6245. src0->type,
  6246. src1->type,
  6247. dst->type))
  6248. goto UseGgmlGemm1;
  6249. return;
  6250. }
  6251. UseGgmlGemm1:;
  6252. #endif
  6253. if (src1->type != vec_dot_type) {
  6254. char * wdata = params->wdata;
  6255. const size_t nbw1 = ggml_row_size(vec_dot_type, ne10);
  6256. const size_t nbw2 = nbw1*ne11;
  6257. const size_t nbw3 = nbw2*ne12;
  6258. assert(params->wsize >= ne13*nbw3);
  6259. GGML_ASSERT(src1->type == GGML_TYPE_F32);
  6260. for (int64_t i13 = 0; i13 < ne13; ++i13) {
  6261. for (int64_t i12 = 0; i12 < ne12; ++i12) {
  6262. int64_t i11_processed = 0;
  6263. if ((ggml_n_dims(src1) == 2) && from_float_to_mat && gemm) {
  6264. for (int64_t i11 = ith * 4; i11 < ne11 - ne11 % 4; i11 += nth * 4) {
  6265. from_float_to_mat((float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11),
  6266. (void *) (wdata + i13*nbw3 + i12*nbw2 + i11*nbw1),
  6267. 4, ne10, blck_size_interleave);
  6268. }
  6269. i11_processed = ne11 - ne11 % 4;
  6270. }
  6271. for (int64_t i11 = i11_processed + ith; i11 < ne11; i11 += nth) {
  6272. from_float((float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11),
  6273. (void *) (wdata + i13*nbw3 + i12*nbw2 + i11*nbw1),
  6274. ne10);
  6275. }
  6276. }
  6277. }
  6278. }
  6279. if (ith == 0) {
6280. // Every thread starts at ith, so the first unprocessed chunk is nth. This saves a bit of coordination right at the start.
  6281. atomic_store_explicit(&params->threadpool->current_chunk, nth, memory_order_relaxed);
  6282. }
  6283. ggml_barrier(params->threadpool);
  6284. #if GGML_USE_LLAMAFILE
  6285. if (src1->type != vec_dot_type) {
  6286. const void* wdata = (src1->type == vec_dot_type) ? src1->data : params->wdata;
  6287. const size_t row_size = ggml_row_size(vec_dot_type, ne10);
  6288. for (int64_t i13 = 0; i13 < ne13; i13++)
  6289. for (int64_t i12 = 0; i12 < ne12; i12++)
  6290. if (!llamafile_sgemm(ne01, ne11, ne00/ggml_blck_size(src0->type),
  6291. (const char *)src0->data + i12/r2*nb02 + i13/r3*nb03,
  6292. nb01/ggml_type_size(src0->type),
  6293. (const char *)wdata + (i12*ne11 + i13*ne12*ne11)*row_size,
  6294. row_size/ggml_type_size(vec_dot_type),
  6295. (char *)dst->data + i12*nb2 + i13*nb3,
  6296. nb1/ggml_type_size(dst->type),
  6297. ith, nth,
  6298. src0->type,
  6299. vec_dot_type,
  6300. dst->type))
  6301. goto UseGgmlGemm2;
  6302. return;
  6303. }
  6304. UseGgmlGemm2:;
  6305. #endif
  6306. // This is the size of the first dimension of the result, so we can iterate that way. (see the ASSERT above, these are the same numbers)
  6307. const int64_t nr0 = ne0;
  6308. // This is the size of the rest of the dimensions of the result
  6309. const int64_t nr1 = ne1 * ne2 * ne3;
  6310. // dot kernels can handle 1 row and col at a time, but mmla kernels can process 2 rows and cols
  6311. int64_t num_rows_per_vec_dot = vec_dot_num_rows;
6312. // TODO: currently the mmla kernels support only even-numbered rows/cols.
6313. // this check can be removed once they are extended to support odd-numbered rows/cols too
  6314. if ((nr0 % 2 != 0) || (ne11 % 2 != 0)) {
  6315. num_rows_per_vec_dot = 1;
  6316. }
  6317. // Now select a reasonable chunk size.
  6318. int chunk_size = 16;
  6319. // We need to step up the size if it's small
  6320. if (nr0 == 1 || nr1 == 1) {
  6321. chunk_size = 64;
  6322. }
  6323. // distribute the work across the inner or outer loop based on which one is larger
  6324. // The number of chunks in the 0/1 dim.
  6325. // CEIL(nr0/chunk_size)
  6326. int64_t nchunk0 = (nr0 + chunk_size - 1) / chunk_size;
  6327. int64_t nchunk1 = (nr1 + chunk_size - 1) / chunk_size;
  6328. // If the chunking is poor for the number of threads on this setup, scrap the whole plan. Re-chunk it by thread.
6329. // Also, chunking by thread was measured to perform better on NUMA systems. See https://github.com/ggerganov/llama.cpp/pull/6915
6330. // In theory, chunking should be just as useful on NUMA and non-NUMA systems, but testing disagreed with that.
  6331. if (nchunk0 * nchunk1 < nth * 4 || ggml_is_numa()) {
  6332. // distribute the thread work across the inner or outer loop based on which one is larger
  6333. nchunk0 = nr0 > nr1 ? nth : 1; // parallelize by src0 rows
  6334. nchunk1 = nr0 > nr1 ? 1 : nth; // parallelize by src1 rows
  6335. }
  6336. // The number of elements in each chunk
  6337. const int64_t dr0 = (nr0 + nchunk0 - 1) / nchunk0;
  6338. const int64_t dr1 = (nr1 + nchunk1 - 1) / nchunk1;
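// Worked example with hypothetical sizes (and assuming ggml_is_numa() is false):
// for nr0 = 4096, nr1 = 8 and nth = 8, chunk_size stays 16, so nchunk0 = 256 and
// nchunk1 = 1; 256 chunks >= 4*nth, the chunked plan is kept and dr0 = 16, dr1 = 8.
// With nr1 = 1 instead, chunk_size becomes 64, giving nchunk0 = 64, nchunk1 = 1,
// dr0 = 64 and dr1 = 1.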
  6339. if ((ggml_n_dims(src0) == 2) && gemv) {
  6340. const void * src1_wdata = (src1->type == vec_dot_type) ? src1->data : params->wdata;
  6341. const size_t src1_col_stride = ggml_is_contiguous(src1) || src1->type != vec_dot_type ? ggml_row_size(vec_dot_type, ne10) : nb11;
  6342. int64_t src0_start = (ith * ne01) / nth;
  6343. int64_t src0_end = ((ith + 1) * ne01) / nth;
  6344. src0_start = (src0_start % matmul_num_cols) ? src0_start + matmul_num_cols - (src0_start % matmul_num_cols): src0_start;
  6345. src0_end = (src0_end % matmul_num_cols) ? src0_end + matmul_num_cols - (src0_end % matmul_num_cols): src0_end;
  6346. if (src0_start >= src0_end) return;
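// Worked example with hypothetical sizes: for ne01 = 128, nth = 3 and
// matmul_num_cols = 8, the raw splits 0..42, 42..85, 85..128 are rounded up to
// multiples of 8, giving per-thread ranges [0,48), [48,88) and [88,128), so each
// thread hands the gemm/gemv kernels whole column blocks.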
  6347. // If there are more than three rows in src1, use gemm; otherwise, use gemv.
  6348. if (gemm && (ne11 > 3)) {
  6349. gemm(ne00, (float *)((char *) dst->data) + src0_start, ne01, (const char *) src0->data + src0_start * nb01,
  6350. (const char *) src1_wdata, ne11 - ne11 % 4, src0_end - src0_start);
  6351. }
  6352. for (int iter = gemm ? ne11 - ne11 % 4 : 0; iter < ne11; iter++) {
  6353. gemv(ne00, (float *)((char *) dst->data + (iter * nb1)) + src0_start, ne01,
  6354. (const char *) src0->data + src0_start * nb01, (const char *) src1_wdata + (src1_col_stride * iter), 1,
  6355. src0_end - src0_start);
  6356. }
  6357. return;
  6358. }
  6359. // The first chunk comes from our thread_id, the rest will get auto-assigned.
  6360. int current_chunk = ith;
  6361. while (current_chunk < nchunk0 * nchunk1) {
  6362. const int64_t ith0 = current_chunk % nchunk0;
  6363. const int64_t ith1 = current_chunk / nchunk0;
  6364. const int64_t ir0_start = dr0 * ith0;
  6365. const int64_t ir0_end = MIN(ir0_start + dr0, nr0);
  6366. const int64_t ir1_start = dr1 * ith1;
  6367. const int64_t ir1_end = MIN(ir1_start + dr1, nr1);
  6368. ggml_compute_forward_mul_mat_one_chunk(params, dst, num_rows_per_vec_dot, ir0_start, ir0_end, ir1_start, ir1_end);
  6369. if (nth >= nchunk0 * nchunk1) {
  6370. break;
  6371. }
  6372. current_chunk = atomic_fetch_add_explicit(&params->threadpool->current_chunk, 1, memory_order_relaxed);
  6373. }
  6374. }
  6375. // ggml_compute_forward_mul_mat_id
  6376. static void ggml_compute_forward_mul_mat_id(
  6377. const struct ggml_compute_params * params,
  6378. struct ggml_tensor * dst) {
  6379. const struct ggml_tensor * src0 = dst->src[0];
  6380. const struct ggml_tensor * src1 = dst->src[1];
  6381. const struct ggml_tensor * ids = dst->src[2];
  6382. GGML_TENSOR_BINARY_OP_LOCALS
  6383. const int ith = params->ith;
  6384. const int nth = params->nth;
  6385. const enum ggml_type type = src0->type;
  6386. const bool src1_cont = ggml_is_contiguous(src1);
  6387. ggml_vec_dot_t const vec_dot = type_traits_cpu[type].vec_dot;
  6388. enum ggml_type const vec_dot_type = type_traits_cpu[type].vec_dot_type;
  6389. ggml_from_float_t const from_float = ggml_get_type_traits(vec_dot_type)->from_float;
  6390. int64_t const matmul_num_cols = type_traits_cpu[type].ncols;
  6391. ggml_gemv_t const gemv = type_traits_cpu[type].gemv;
  6392. // we don't support permuted src0 or src1
  6393. GGML_ASSERT(nb00 == ggml_type_size(type));
  6394. GGML_ASSERT(nb10 == ggml_type_size(src1->type));
  6395. // dst cannot be transposed or permuted
  6396. GGML_ASSERT(nb0 == sizeof(float));
  6397. GGML_ASSERT(nb0 <= nb1);
  6398. GGML_ASSERT(nb1 <= nb2);
  6399. GGML_ASSERT(nb2 <= nb3);
  6400. // row groups
  6401. const int n_ids = ids->ne[0]; // n_expert_used
  6402. const int n_as = ne02; // n_expert
  6403. char * wdata_src1_end = (src1->type == vec_dot_type) ?
  6404. (char *) params->wdata :
  6405. (char *) params->wdata + GGML_PAD(ggml_row_size(vec_dot_type, ggml_nelements(src1)), sizeof(int64_t));
  6406. struct mmid_row_mapping {
  6407. int32_t i1;
  6408. int32_t i2;
  6409. };
  6410. int64_t * matrix_row_counts = (int64_t *) (wdata_src1_end); // [n_as]
  6411. struct mmid_row_mapping * matrix_rows = (struct mmid_row_mapping *)(matrix_row_counts + n_as); // [n_as][ne11]
  6412. if (src1->type != vec_dot_type) {
  6413. char * wdata = params->wdata;
  6414. const size_t nbw1 = ggml_row_size(vec_dot_type, ne10);
  6415. const size_t nbw2 = nbw1*ne11;
  6416. const size_t nbw3 = nbw2*ne12;
  6417. assert(params->wsize >= ne13*nbw3);
  6418. GGML_ASSERT(src1->type == GGML_TYPE_F32);
  6419. for (int64_t i13 = 0; i13 < ne13; ++i13) {
  6420. for (int64_t i12 = 0; i12 < ne12; ++i12) {
  6421. for (int64_t i11 = ith; i11 < ne11; i11 += nth) {
  6422. from_float((float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11),
  6423. (void *) (wdata + i13*nbw3 + i12*nbw2 + i11*nbw1),
  6424. ne10);
  6425. }
  6426. }
  6427. }
  6428. }
  6429. #define MMID_MATRIX_ROW(row_id, i1) matrix_rows[(row_id)*ne12 + (i1)]
  6430. if (ith == 0) {
  6431. // initialize matrix_row_counts
  6432. memset(matrix_row_counts, 0, n_as*sizeof(int64_t));
  6433. // group rows by src0 matrix
  6434. for (int64_t iid1 = 0; iid1 < ids->ne[1]; ++iid1) {
  6435. for (int id = 0; id < n_ids; ++id) {
  6436. const int32_t i02 = *(const int32_t *) ((const char *) ids->data + iid1*ids->nb[1] + id*ids->nb[0]);
  6437. assert(i02 >= 0 && i02 < n_as);
  6438. MMID_MATRIX_ROW(i02, matrix_row_counts[i02]) = (struct mmid_row_mapping) {id, iid1};
  6439. matrix_row_counts[i02] += 1;
  6440. }
  6441. }
  6442. }
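// Worked example with hypothetical values: for n_as = 4 experts, n_ids = 2 experts
// used per token and ids = [[2,0],[2,3]] over two tokens, the grouping above yields
// matrix_row_counts = [1,0,2,1] and, via MMID_MATRIX_ROW,
//   expert 0: {i1=1,i2=0}              (token 0, slot 1)
//   expert 2: {i1=0,i2=0}, {i1=0,i2=1} (slot 0 of tokens 0 and 1)
//   expert 3: {i1=1,i2=1}              (token 1, slot 1)
// so each expert later multiplies only the rows that were routed to it.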
  6443. ggml_barrier(params->threadpool);
  6444. // compute each matrix multiplication in sequence
  6445. for (int cur_a = 0; cur_a < n_as; ++cur_a) {
  6446. const int64_t cne1 = matrix_row_counts[cur_a];
  6447. if (cne1 == 0) {
  6448. continue;
  6449. }
  6450. const char * src0_cur = (const char *) src0->data + cur_a*nb02;
  6451. const void * wdata = (src1->type == vec_dot_type) ? src1->data : params->wdata;
  6452. const size_t row_size = ggml_row_size(vec_dot_type, ne10);
  6453. const int64_t nr0 = ne01; // src0 rows
  6454. const int64_t nr1 = cne1; // src1 rows
  6455. if (((ggml_n_dims(src0) - 1) == 2) && gemv) {
  6456. int64_t src0_cur_start = (ith * ne01) / nth;
  6457. int64_t src0_cur_end = ((ith + 1) * ne01) / nth;
  6458. src0_cur_start = (src0_cur_start % matmul_num_cols) ? src0_cur_start + matmul_num_cols - (src0_cur_start % matmul_num_cols): src0_cur_start;
  6459. src0_cur_end = (src0_cur_end % matmul_num_cols) ? src0_cur_end + matmul_num_cols - (src0_cur_end % matmul_num_cols): src0_cur_end;
  6460. if (src0_cur_start >= src0_cur_end) return;
  6461. for (int ir1 = 0; ir1 < nr1; ir1++) {
  6462. struct mmid_row_mapping row_mapping = MMID_MATRIX_ROW(cur_a, ir1);
  6463. const int id = row_mapping.i1; // selected expert index
  6464. const int64_t i11 = id % ne11;
  6465. const int64_t i12 = row_mapping.i2; // row index in src1
  6466. const int64_t i1 = id; // selected expert index
  6467. const int64_t i2 = i12; // row
  6468. const char * src1_col = (const char *) wdata +
  6469. (src1_cont || src1->type != vec_dot_type
  6470. ? (i11 + i12 * ne11) * row_size
  6471. : (i11 * nb11 + i12 * nb12));
  6472. gemv(ne00, (float *)((char *) dst->data + (i1 * nb1 + i2 * nb2)) + src0_cur_start, ne01,
  6473. (const char *) src0_cur + src0_cur_start * nb01, src1_col, 1, src0_cur_end - src0_cur_start);
  6474. }
  6475. continue;
  6476. }
  6477. // distribute the thread work across the inner or outer loop based on which one is larger
  6478. const int64_t nth0 = nr0 > nr1 ? nth : 1; // parallelize by src0 rows
  6479. const int64_t nth1 = nr0 > nr1 ? 1 : nth; // parallelize by src1 rows
  6480. const int64_t ith0 = ith % nth0;
  6481. const int64_t ith1 = ith / nth0;
  6482. const int64_t dr0 = (nr0 + nth0 - 1)/nth0;
  6483. const int64_t dr1 = (nr1 + nth1 - 1)/nth1;
  6484. const int64_t ir010 = dr0*ith0;
  6485. const int64_t ir011 = MIN(ir010 + dr0, nr0);
  6486. const int64_t ir110 = dr1*ith1;
  6487. const int64_t ir111 = MIN(ir110 + dr1, nr1);
  6488. // threads with no work simply yield (not sure if it helps)
  6489. //if (ir010 >= ir011 || ir110 >= ir111) {
  6490. // sched_yield();
  6491. // continue;
  6492. //}
  6493. // block-tiling attempt
  6494. const int64_t blck_0 = 16;
  6495. const int64_t blck_1 = 16;
  6496. // attempt to reduce false-sharing (does not seem to make a difference)
  6497. float tmp[16];
  6498. for (int64_t iir1 = ir110; iir1 < ir111; iir1 += blck_1) {
  6499. for (int64_t iir0 = ir010; iir0 < ir011; iir0 += blck_0) {
  6500. for (int64_t ir1 = iir1; ir1 < iir1 + blck_1 && ir1 < ir111; ++ir1) {
  6501. const int64_t _i12 = ir1; // logical row index for this expert
  6502. struct mmid_row_mapping row_mapping = MMID_MATRIX_ROW(cur_a, _i12);
  6503. const int id = row_mapping.i1; // selected expert index
  6504. const int64_t i11 = id % ne11;
  6505. const int64_t i12 = row_mapping.i2; // row index in src1
  6506. const int64_t i1 = id; // selected expert index
  6507. const int64_t i2 = i12; // row
  6508. // desc: when src1 is not a contiguous memory block we have to calculate the offset using the strides
  6509. // if it is, then we have either copied the data to params->wdata and made it contiguous or we are using
  6510. // the original src1 data pointer, so we should index using the indices directly
  6511. // TODO: this is a bit of a hack, we should probably have a better way to handle this
  6512. const char * src1_col = (const char *) wdata +
  6513. (src1_cont || src1->type != vec_dot_type
  6514. ? (i11 + i12*ne11)*row_size
  6515. : (i11*nb11 + i12*nb12));
  6516. float * dst_col = (float *) ((char *) dst->data + (i1*nb1 + i2*nb2));
  6517. //for (int64_t ir0 = iir0; ir0 < iir0 + blck_0 && ir0 < ir011; ++ir0) {
  6518. // vec_dot(ne00, &dst_col[ir0], src0_row + ir0*nb01, src1_col);
  6519. //}
  6520. for (int64_t ir0 = iir0; ir0 < iir0 + blck_0 && ir0 < ir011; ++ir0) {
  6521. vec_dot(ne00, &tmp[ir0 - iir0], 0, src0_cur + ir0*nb01, 0, src1_col, 0, 1);
  6522. }
  6523. memcpy(&dst_col[iir0], tmp, (MIN(iir0 + blck_0, ir011) - iir0)*sizeof(float));
  6524. }
  6525. }
  6526. }
  6527. }
  6528. #undef MMID_MATRIX_ROW
  6529. }
  6530. // ggml_compute_forward_out_prod
  6531. static void ggml_compute_forward_out_prod_f32(
  6532. const struct ggml_compute_params * params,
  6533. struct ggml_tensor * dst) {
  6534. const struct ggml_tensor * src0 = dst->src[0];
  6535. const struct ggml_tensor * src1 = dst->src[1];
  6536. GGML_TENSOR_BINARY_OP_LOCALS
  6537. GGML_ASSERT(dst->type == GGML_TYPE_F32);
  6538. GGML_ASSERT(src0->type == GGML_TYPE_F32);
  6539. GGML_ASSERT(src1->type == GGML_TYPE_F32);
  6540. const int ith = params->ith;
  6541. const int nth = params->nth;
  6542. GGML_ASSERT(ne0 == ne00);
  6543. GGML_ASSERT(ne1 == ne10);
  6544. GGML_ASSERT(ne2 == ne02);
  6545. GGML_ASSERT(ne02 == ne12);
  6546. GGML_ASSERT(ne3 == ne13);
  6547. GGML_ASSERT(ne03 == ne13);
  6548. // we don't support permuted src0 or src1
  6549. GGML_ASSERT(nb00 == sizeof(float));
  6550. // dst cannot be transposed or permuted
  6551. GGML_ASSERT(nb0 == sizeof(float));
  6552. // GGML_ASSERT(nb0 <= nb1);
  6553. // GGML_ASSERT(nb1 <= nb2);
  6554. // GGML_ASSERT(nb2 <= nb3);
  6555. // nb01 >= nb00 - src0 is not transposed
  6556. // compute by src0 rows
  6557. if (ith == 0) {
  6558. ggml_vec_set_f32(ne0*ne1*ne2*ne3, dst->data, 0);
  6559. }
  6560. ggml_barrier(params->threadpool);
  6561. // dst[:,:,:,:] = 0
  6562. // for i2,i3:
  6563. // for i1:
  6564. // for i01:
  6565. // for i0:
  6566. // dst[i0,i1,i2,i3] += src0[i0,i01,i2,i3] * src1[i1,i01,i2,i3]
  6567. // parallelize by last three dimensions
  6568. // total rows in dst
  6569. const int64_t nr = ne1*ne2*ne3;
  6570. // rows per thread
  6571. const int64_t dr = (nr + nth - 1)/nth;
  6572. // row range for this thread
  6573. const int64_t ir0 = dr*ith;
  6574. const int64_t ir1 = MIN(ir0 + dr, nr);
  6575. // block-tiling attempt
  6576. const int64_t blck_0 = MAX(GGML_VEC_MAD_UNROLL, 32);
  6577. const int64_t blck_1 = 16;
  6578. for (int64_t bir = ir0; bir < ir1; bir += blck_1) {
  6579. const int64_t bir1 = MIN(bir + blck_1, ir1);
  6580. for (int64_t bi01 = 0; bi01 < ne01; bi01 += blck_0) {
  6581. const int64_t bne01 = MIN(bi01 + blck_0, ne01);
  6582. for (int64_t ir = bir; ir < bir1; ++ir) {
  6583. // dst indices
  6584. const int64_t i3 = ir/(ne2*ne1);
  6585. const int64_t i2 = (ir - i3*ne2*ne1)/ne1;
  6586. const int64_t i1 = (ir - i3*ne2*ne1 - i2*ne1);
  6587. const int64_t i02 = i2;
  6588. const int64_t i03 = i3;
  6589. //const int64_t i10 = i1;
  6590. const int64_t i12 = i2;
  6591. const int64_t i13 = i3;
  6592. #if GGML_VEC_MAD_UNROLL > 2
  6593. const int64_t bne01_unroll = bne01 - (bne01 % GGML_VEC_MAD_UNROLL);
  6594. for (int64_t i01 = bi01; i01 < bne01_unroll; i01 += GGML_VEC_MAD_UNROLL) {
  6595. const int64_t i11 = i01;
  6596. float * s0 = (float *) ((char *) src0->data + ( i01*nb01 + i02*nb02 + i03*nb03));
  6597. float * s1 = (float *) ((char *) src1->data + (i1*nb10 + i11*nb11 + i12*nb12 + i13*nb13));
  6598. float * d = (float *) ((char *) dst->data + ( i1*nb1 + i2*nb2 + i3*nb3));
  6599. ggml_vec_mad_f32_unroll(ne0, nb01, nb11, d, s0, s1);
  6600. }
  6601. for (int64_t i01 = bne01_unroll; i01 < bne01; ++i01) {
  6602. const int64_t i11 = i01;
  6603. float * s0 = (float *) ((char *) src0->data + ( i01*nb01 + i02*nb02 + i03*nb03));
  6604. float * s1 = (float *) ((char *) src1->data + (i1*nb10 + i11*nb11 + i12*nb12 + i13*nb13));
  6605. float * d = (float *) ((char *) dst->data + ( i1*nb1 + i2*nb2 + i3*nb3));
  6606. ggml_vec_mad_f32(ne0, d, s0, *s1);
  6607. }
  6608. #else
  6609. for (int64_t i01 = bi01; i01 < bne01; ++i01) {
  6610. const int64_t i11 = i01;
  6611. float * s0 = (float *) ((char *) src0->data + ( i01*nb01 + i02*nb02 + i03*nb03));
  6612. float * s1 = (float *) ((char *) src1->data + (i1*nb10 + i11*nb11 + i12*nb12 + i13*nb13));
  6613. float * d = (float *) ((char *) dst->data + ( i1*nb1 + i2*nb2 + i3*nb3));
  6614. ggml_vec_mad_f32(ne0, d, s0, *s1);
  6615. }
  6616. #endif
  6617. }
  6618. }
  6619. }
  6620. }
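// Illustrative sketch: ignoring blocking, unrolling and threading, the loops above
// accumulate, per (i2,i3) slice,
//   dst[i0,i1] += sum_{i01} src0[i0,i01] * src1[i1,i01]
// i.e. a sum of outer products. A minimal restatement on row-major 2D arrays follows
// (commented out); out_prod_ref is a hypothetical helper name, not a ggml symbol.
//
// static void out_prod_ref(int ne0, int ne1, int k, float * dst, const float * src0, const float * src1) {
//     // dst:  ne1 rows of ne0 floats, zero-initialized by the caller
//     // src0: k   rows of ne0 floats
//     // src1: k   rows of ne1 floats
//     for (int i1 = 0; i1 < ne1; i1++) {
//         for (int i01 = 0; i01 < k; i01++) {
//             const float s1 = src1[i01*ne1 + i1];
//             for (int i0 = 0; i0 < ne0; i0++) {
//                 dst[i1*ne0 + i0] += src0[i01*ne0 + i0] * s1;
//             }
//         }
//     }
// }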
  6621. static void ggml_compute_forward_out_prod_q_f32(
  6622. const struct ggml_compute_params * params,
  6623. struct ggml_tensor * dst) {
  6624. const struct ggml_tensor * src0 = dst->src[0];
  6625. const struct ggml_tensor * src1 = dst->src[1];
  6626. GGML_TENSOR_BINARY_OP_LOCALS;
  6627. const int ith = params->ith;
  6628. const int nth = params->nth;
  6629. const enum ggml_type type = src0->type;
  6630. ggml_to_float_t const dequantize_row_q = ggml_get_type_traits(type)->to_float;
  6631. GGML_ASSERT(ne02 == ne12);
  6632. GGML_ASSERT(ne03 == ne13);
  6633. GGML_ASSERT(ne2 == ne12);
  6634. GGML_ASSERT(ne3 == ne13);
  6635. // we don't support permuted src0 dim0
  6636. GGML_ASSERT(nb00 == ggml_type_size(type));
  6637. // dst dim0 cannot be transposed or permuted
  6638. GGML_ASSERT(nb0 == sizeof(float));
  6639. // GGML_ASSERT(nb0 <= nb1);
  6640. // GGML_ASSERT(nb1 <= nb2);
  6641. // GGML_ASSERT(nb2 <= nb3);
  6642. GGML_ASSERT(ne0 == ne00);
  6643. GGML_ASSERT(ne1 == ne10);
  6644. GGML_ASSERT(ne2 == ne02);
  6645. GGML_ASSERT(ne3 == ne03);
  6646. // nb01 >= nb00 - src0 is not transposed
  6647. // compute by src0 rows
  6648. if (ith == 0) {
  6649. ggml_vec_set_f32(ne0*ne1*ne2*ne3, dst->data, 0);
  6650. }
  6651. ggml_barrier(params->threadpool);
  6652. // parallelize by last three dimensions
  6653. // total rows in dst
  6654. const int64_t nr = ne1*ne2*ne3;
  6655. // rows per thread
  6656. const int64_t dr = (nr + nth - 1)/nth;
  6657. // row range for this thread
  6658. const int64_t ir0 = dr*ith;
  6659. const int64_t ir1 = MIN(ir0 + dr, nr);
  6660. // dst[:,:,:,:] = 0
  6661. // for i2,i3:
  6662. // for i1:
  6663. // for i01:
  6664. // for i0:
  6665. // dst[i0,i1,i2,i3] += src0[i0,i01,i2,i3] * src1[i1,i01,i2,i3]
  6666. float * wdata = (float *) params->wdata + (ne0 + CACHE_LINE_SIZE_F32) * ith;
  6667. for (int64_t ir = ir0; ir < ir1; ++ir) {
  6668. // dst indices
  6669. const int64_t i3 = ir/(ne2*ne1);
  6670. const int64_t i2 = (ir - i3*ne2*ne1)/ne1;
  6671. const int64_t i1 = (ir - i3*ne2*ne1 - i2*ne1);
  6672. const int64_t i02 = i2;
  6673. const int64_t i03 = i3;
  6674. //const int64_t i10 = i1;
  6675. const int64_t i12 = i2;
  6676. const int64_t i13 = i3;
  6677. for (int64_t i01 = 0; i01 < ne01; ++i01) {
  6678. const int64_t i11 = i01;
  6679. float * s0 = (float *) ((char *) src0->data + ( i01*nb01 + i02*nb02 + i03*nb03));
  6680. float * s1 = (float *) ((char *) src1->data + (i1*nb10 + i11*nb11 + i12*nb12 + i13*nb13));
  6681. float * d = (float *) ((char *) dst->data + ( i1*nb1 + i2*nb2 + i3*nb3));
  6682. dequantize_row_q(s0, wdata, ne0);
  6683. ggml_vec_mad_f32(ne0, d, wdata, *s1);
  6684. }
  6685. }
  6686. }
  6687. static void ggml_compute_forward_out_prod(
  6688. const struct ggml_compute_params * params,
  6689. struct ggml_tensor * dst) {
  6690. const struct ggml_tensor * src0 = dst->src[0];
  6691. switch (src0->type) {
  6692. case GGML_TYPE_Q4_0:
  6693. case GGML_TYPE_Q4_1:
  6694. case GGML_TYPE_Q5_0:
  6695. case GGML_TYPE_Q5_1:
  6696. case GGML_TYPE_Q8_0:
  6697. case GGML_TYPE_Q2_K:
  6698. case GGML_TYPE_Q3_K:
  6699. case GGML_TYPE_Q4_K:
  6700. case GGML_TYPE_Q5_K:
  6701. case GGML_TYPE_Q6_K:
  6702. case GGML_TYPE_TQ1_0:
  6703. case GGML_TYPE_TQ2_0:
  6704. case GGML_TYPE_IQ2_XXS:
  6705. case GGML_TYPE_IQ2_XS:
  6706. case GGML_TYPE_IQ3_XXS:
  6707. case GGML_TYPE_IQ1_S:
  6708. case GGML_TYPE_IQ1_M:
  6709. case GGML_TYPE_IQ4_NL:
  6710. case GGML_TYPE_IQ4_XS:
  6711. case GGML_TYPE_IQ3_S:
  6712. case GGML_TYPE_IQ2_S:
  6713. case GGML_TYPE_Q4_0_4_4:
  6714. case GGML_TYPE_Q4_0_4_8:
  6715. case GGML_TYPE_Q4_0_8_8:
  6716. {
  6717. ggml_compute_forward_out_prod_q_f32(params, dst);
  6718. } break;
  6719. case GGML_TYPE_F16:
  6720. {
  6721. GGML_ABORT("fatal error"); // todo
  6722. // ggml_compute_forward_out_prod_f16_f32(params, dst);
  6723. }
  6724. case GGML_TYPE_F32:
  6725. {
  6726. ggml_compute_forward_out_prod_f32(params, dst);
  6727. } break;
  6728. default:
  6729. {
  6730. GGML_ABORT("fatal error");
  6731. }
  6732. }
  6733. }
  6734. // ggml_compute_forward_scale
  6735. static void ggml_compute_forward_scale_f32(
  6736. const struct ggml_compute_params * params,
  6737. struct ggml_tensor * dst) {
  6738. const struct ggml_tensor * src0 = dst->src[0];
  6739. GGML_ASSERT(ggml_is_contiguous(src0));
  6740. GGML_ASSERT(ggml_is_contiguous(dst));
  6741. GGML_ASSERT(ggml_are_same_shape(src0, dst));
  6742. // scale factor
  6743. float v;
  6744. memcpy(&v, dst->op_params, sizeof(float));
  6745. const int ith = params->ith;
  6746. const int nth = params->nth;
  6747. const int nc = src0->ne[0];
  6748. const int nr = ggml_nrows(src0);
  6749. // rows per thread
  6750. const int dr = (nr + nth - 1)/nth;
  6751. // row range for this thread
  6752. const int ir0 = dr*ith;
  6753. const int ir1 = MIN(ir0 + dr, nr);
  6754. const size_t nb01 = src0->nb[1];
  6755. const size_t nb1 = dst->nb[1];
  6756. for (int i1 = ir0; i1 < ir1; i1++) {
  6757. if (dst->data != src0->data) {
  6758. // src0 is same shape as dst => same indices
  6759. memcpy((char *)dst->data + i1*nb1, (char *)src0->data + i1*nb01, nc * sizeof(float));
  6760. }
  6761. ggml_vec_scale_f32(nc, (float *) ((char *) dst->data + i1*nb1), v);
  6762. }
  6763. }
  6764. static void ggml_compute_forward_scale(
  6765. const struct ggml_compute_params * params,
  6766. struct ggml_tensor * dst) {
  6767. const struct ggml_tensor * src0 = dst->src[0];
  6768. switch (src0->type) {
  6769. case GGML_TYPE_F32:
  6770. {
  6771. ggml_compute_forward_scale_f32(params, dst);
  6772. } break;
  6773. default:
  6774. {
  6775. GGML_ABORT("fatal error");
  6776. }
  6777. }
  6778. }
  6779. // ggml_compute_forward_set
  6780. static void ggml_compute_forward_set_f32(
  6781. const struct ggml_compute_params * params,
  6782. struct ggml_tensor * dst) {
  6783. const struct ggml_tensor * src0 = dst->src[0];
  6784. const struct ggml_tensor * src1 = dst->src[1];
  6785. GGML_ASSERT(ggml_are_same_shape(src0, dst));
  6786. GGML_ASSERT(ggml_is_contiguous(dst) && ggml_is_contiguous(src0));
6788. // view src0 and dst with these strides and data offset in bytes during set
  6788. // nb0 is implicitly element_size because src0 and dst are contiguous
  6789. size_t nb1 = ((int32_t *) dst->op_params)[0];
  6790. size_t nb2 = ((int32_t *) dst->op_params)[1];
  6791. size_t nb3 = ((int32_t *) dst->op_params)[2];
  6792. size_t offset = ((int32_t *) dst->op_params)[3];
  6793. bool inplace = (bool) ((int32_t *) dst->op_params)[4];
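// Worked example with hypothetical shapes: if dst is a contiguous 8x8 f32 tensor,
// src1 has ne = [4, 2], nb1 = 8*sizeof(float) and offset = (2*8 + 3)*sizeof(float),
// then the two src1 rows land in columns 3..6 of dst rows 2 and 3, while the rest of
// dst keeps the values copied from src0 below.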
  6794. if (!inplace) {
  6795. if (params->ith == 0) {
  6796. // memcpy needs to be synchronized across threads to avoid race conditions.
  6797. // => do it in INIT phase
  6798. memcpy(
  6799. ((char *) dst->data),
  6800. ((char *) src0->data),
  6801. ggml_nbytes(dst));
  6802. }
  6803. ggml_barrier(params->threadpool);
  6804. }
  6805. const int ith = params->ith;
  6806. const int nth = params->nth;
  6807. const int nr = ggml_nrows(src1);
  6808. const int nc = src1->ne[0];
  6809. GGML_TENSOR_LOCALS(int64_t, ne1, src1, ne)
  6810. GGML_TENSOR_LOCALS(size_t, nb1, src1, nb)
  6811. // src0 and dst as viewed during set
  6812. const size_t nb0 = ggml_element_size(src0);
  6813. const int im0 = (ne10 == 0 ? 0 : ne10-1);
  6814. const int im1 = (ne11 == 0 ? 0 : ne11-1);
  6815. const int im2 = (ne12 == 0 ? 0 : ne12-1);
  6816. const int im3 = (ne13 == 0 ? 0 : ne13-1);
  6817. GGML_ASSERT(offset + im0*nb0 + im1*nb1 + im2*nb2 + im3*nb3 <= ggml_nbytes(dst));
  6818. GGML_ASSERT(nb10 == sizeof(float));
  6819. // rows per thread
  6820. const int dr = (nr + nth - 1)/nth;
  6821. // row range for this thread
  6822. const int ir0 = dr*ith;
  6823. const int ir1 = MIN(ir0 + dr, nr);
  6824. for (int ir = ir0; ir < ir1; ++ir) {
  6825. // src0 and dst are viewed with shape of src1 and offset
  6826. // => same indices
  6827. const int i3 = ir/(ne12*ne11);
  6828. const int i2 = (ir - i3*ne12*ne11)/ne11;
  6829. const int i1 = (ir - i3*ne12*ne11 - i2*ne11);
  6830. ggml_vec_cpy_f32(nc,
  6831. (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + offset),
  6832. (float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11));
  6833. }
  6834. }
  6835. static void ggml_compute_forward_set(
  6836. const struct ggml_compute_params * params,
  6837. struct ggml_tensor * dst) {
  6838. const struct ggml_tensor * src0 = dst->src[0];
  6839. switch (src0->type) {
  6840. case GGML_TYPE_F32:
  6841. {
  6842. ggml_compute_forward_set_f32(params, dst);
  6843. } break;
  6844. case GGML_TYPE_F16:
  6845. case GGML_TYPE_BF16:
  6846. case GGML_TYPE_Q4_0:
  6847. case GGML_TYPE_Q4_1:
  6848. case GGML_TYPE_Q5_0:
  6849. case GGML_TYPE_Q5_1:
  6850. case GGML_TYPE_Q8_0:
  6851. case GGML_TYPE_Q8_1:
  6852. case GGML_TYPE_Q2_K:
  6853. case GGML_TYPE_Q3_K:
  6854. case GGML_TYPE_Q4_K:
  6855. case GGML_TYPE_Q5_K:
  6856. case GGML_TYPE_Q6_K:
  6857. case GGML_TYPE_TQ1_0:
  6858. case GGML_TYPE_TQ2_0:
  6859. case GGML_TYPE_IQ2_XXS:
  6860. case GGML_TYPE_IQ2_XS:
  6861. case GGML_TYPE_IQ3_XXS:
  6862. case GGML_TYPE_IQ1_S:
  6863. case GGML_TYPE_IQ1_M:
  6864. case GGML_TYPE_IQ4_NL:
  6865. case GGML_TYPE_IQ4_XS:
  6866. case GGML_TYPE_IQ3_S:
  6867. case GGML_TYPE_IQ2_S:
  6868. case GGML_TYPE_Q4_0_4_4:
  6869. case GGML_TYPE_Q4_0_4_8:
  6870. case GGML_TYPE_Q4_0_8_8:
  6871. default:
  6872. {
  6873. GGML_ABORT("fatal error");
  6874. }
  6875. }
  6876. }
  6877. // ggml_compute_forward_cpy
  6878. static void ggml_compute_forward_cpy(
  6879. const struct ggml_compute_params * params,
  6880. struct ggml_tensor * dst) {
  6881. ggml_compute_forward_dup(params, dst);
  6882. }
  6883. // ggml_compute_forward_cont
  6884. static void ggml_compute_forward_cont(
  6885. const struct ggml_compute_params * params,
  6886. struct ggml_tensor * dst) {
  6887. ggml_compute_forward_dup(params, dst);
  6888. }
  6889. // ggml_compute_forward_reshape
  6890. static void ggml_compute_forward_reshape(
  6891. const struct ggml_compute_params * params,
  6892. struct ggml_tensor * dst) {
  6893. // NOP
  6894. UNUSED(params);
  6895. UNUSED(dst);
  6896. }
  6897. // ggml_compute_forward_view
  6898. static void ggml_compute_forward_view(
  6899. const struct ggml_compute_params * params,
  6900. const struct ggml_tensor * dst) {
  6901. // NOP
  6902. UNUSED(params);
  6903. UNUSED(dst);
  6904. }
  6905. // ggml_compute_forward_permute
  6906. static void ggml_compute_forward_permute(
  6907. const struct ggml_compute_params * params,
  6908. const struct ggml_tensor * dst) {
  6909. // NOP
  6910. UNUSED(params);
  6911. UNUSED(dst);
  6912. }
  6913. // ggml_compute_forward_transpose
  6914. static void ggml_compute_forward_transpose(
  6915. const struct ggml_compute_params * params,
  6916. const struct ggml_tensor * dst) {
  6917. // NOP
  6918. UNUSED(params);
  6919. UNUSED(dst);
  6920. }
  6921. // ggml_compute_forward_get_rows
  6922. static void ggml_compute_forward_get_rows_q(
  6923. const struct ggml_compute_params * params,
  6924. struct ggml_tensor * dst) {
  6925. const struct ggml_tensor * src0 = dst->src[0];
  6926. const struct ggml_tensor * src1 = dst->src[1];
  6927. GGML_TENSOR_BINARY_OP_LOCALS
  6928. const int64_t nc = ne00;
  6929. const int64_t nr = ggml_nelements(src1);
  6930. const enum ggml_type type = src0->type;
  6931. ggml_to_float_t const dequantize_row_q = ggml_get_type_traits(type)->to_float;
  6932. assert(ne0 == nc);
  6933. assert(ne02 == ne11);
  6934. assert(nb00 == ggml_type_size(type));
  6935. assert(ggml_nrows(dst) == nr);
  6936. const int ith = params->ith;
  6937. const int nth = params->nth;
  6938. // rows per thread
  6939. const int dr = (nr + nth - 1)/nth;
  6940. // row range for this thread
  6941. const int ir0 = dr*ith;
  6942. const int ir1 = MIN(ir0 + dr, nr);
  6943. for (int64_t i = ir0; i < ir1; ++i) {
  6944. const int64_t i12 = i/(ne11*ne10);
  6945. const int64_t i11 = (i - i12*ne11*ne10)/ne10;
  6946. const int64_t i10 = (i - i12*ne11*ne10 - i11*ne10);
  6947. const int64_t i01 = *(int32_t *) ((char *) src1->data + i10*nb10 + i11*nb11 + i12*nb12);
  6948. GGML_ASSERT(i01 >= 0 && i01 < ne01);
  6949. dequantize_row_q(
  6950. (const void *) ((char *) src0->data + i01*nb01 + i11*nb02 + i12*nb03),
  6951. (float *) ((char *) dst->data + i10*nb1 + i11*nb2 + i12*nb3), nc);
  6952. }
  6953. }
  6954. static void ggml_compute_forward_get_rows_f16(
  6955. const struct ggml_compute_params * params,
  6956. struct ggml_tensor * dst) {
  6957. const struct ggml_tensor * src0 = dst->src[0];
  6958. const struct ggml_tensor * src1 = dst->src[1];
  6959. GGML_TENSOR_BINARY_OP_LOCALS
  6960. const int64_t nc = ne00;
  6961. const int64_t nr = ggml_nelements(src1);
  6962. assert(ne0 == nc);
  6963. assert(ne02 == ne11);
  6964. assert(nb00 == sizeof(ggml_fp16_t));
  6965. assert(ggml_nrows(dst) == nr);
  6966. const int ith = params->ith;
  6967. const int nth = params->nth;
  6968. // rows per thread
  6969. const int dr = (nr + nth - 1)/nth;
  6970. // row range for this thread
  6971. const int ir0 = dr*ith;
  6972. const int ir1 = MIN(ir0 + dr, nr);
  6973. for (int64_t i = ir0; i < ir1; ++i) {
  6974. const int64_t i12 = i/(ne11*ne10);
  6975. const int64_t i11 = (i - i12*ne11*ne10)/ne10;
  6976. const int64_t i10 = (i - i12*ne11*ne10 - i11*ne10);
  6977. const int64_t i01 = *(int32_t *) ((char *) src1->data + i10*nb10 + i11*nb11 + i12*nb12);
  6978. GGML_ASSERT(i01 >= 0 && i01 < ne01);
  6979. ggml_fp16_to_fp32_row(
  6980. (const void *) ((char *) src0->data + i01*nb01 + i11*nb02 + i12*nb03),
  6981. (float *) ((char *) dst->data + i10*nb1 + i11*nb2 + i12*nb3), nc);
  6982. }
  6983. }
  6984. static void ggml_compute_forward_get_rows_bf16(
  6985. const struct ggml_compute_params * params,
  6986. struct ggml_tensor * dst) {
  6987. const struct ggml_tensor * src0 = dst->src[0];
  6988. const struct ggml_tensor * src1 = dst->src[1];
  6989. GGML_TENSOR_BINARY_OP_LOCALS
  6990. const int64_t nc = ne00;
  6991. const int64_t nr = ggml_nelements(src1);
  6992. assert(ne0 == nc);
  6993. assert(ne02 == ne11);
  6994. assert(nb00 == sizeof(ggml_bf16_t));
  6995. assert(ggml_nrows(dst) == nr);
  6996. const int ith = params->ith;
  6997. const int nth = params->nth;
  6998. // rows per thread
  6999. const int dr = (nr + nth - 1)/nth;
  7000. // row range for this thread
  7001. const int ir0 = dr*ith;
  7002. const int ir1 = MIN(ir0 + dr, nr);
  7003. for (int64_t i = ir0; i < ir1; ++i) {
  7004. const int64_t i12 = i/(ne11*ne10);
  7005. const int64_t i11 = (i - i12*ne11*ne10)/ne10;
  7006. const int64_t i10 = (i - i12*ne11*ne10 - i11*ne10);
  7007. const int64_t i01 = *(int32_t *) ((char *) src1->data + i10*nb10 + i11*nb11 + i12*nb12);
  7008. GGML_ASSERT(i01 >= 0 && i01 < ne01);
  7009. ggml_bf16_to_fp32_row(
  7010. (const void *) ((char *) src0->data + i01*nb01 + i11*nb02 + i12*nb03),
  7011. (float *) ((char *) dst->data + i10*nb1 + i11*nb2 + i12*nb3), nc);
  7012. }
  7013. }
  7014. static void ggml_compute_forward_get_rows_f32(
  7015. const struct ggml_compute_params * params,
  7016. struct ggml_tensor * dst) {
  7017. const struct ggml_tensor * src0 = dst->src[0];
  7018. const struct ggml_tensor * src1 = dst->src[1];
  7019. GGML_TENSOR_BINARY_OP_LOCALS
  7020. const int64_t nc = ne00;
  7021. const int64_t nr = ggml_nelements(src1);
  7022. assert(ne0 == nc);
  7023. assert(ne02 == ne11);
  7024. assert(nb00 == sizeof(float));
  7025. assert(ggml_nrows(dst) == nr);
  7026. const int ith = params->ith;
  7027. const int nth = params->nth;
  7028. // rows per thread
  7029. const int dr = (nr + nth - 1)/nth;
  7030. // row range for this thread
  7031. const int ir0 = dr*ith;
  7032. const int ir1 = MIN(ir0 + dr, nr);
  7033. for (int64_t i = ir0; i < ir1; ++i) {
  7034. const int64_t i12 = i/(ne11*ne10);
  7035. const int64_t i11 = (i - i12*ne11*ne10)/ne10;
  7036. const int64_t i10 = (i - i12*ne11*ne10 - i11*ne10);
  7037. const int64_t i01 = *(int32_t *) ((char *) src1->data + i10*nb10 + i11*nb11 + i12*nb12);
  7038. GGML_ASSERT(i01 >= 0 && i01 < ne01);
  7039. ggml_vec_cpy_f32(nc,
  7040. (float *) ((char *) dst->data + i10*nb1 + i11*nb2 + i12*nb3),
  7041. (float *) ((char *) src0->data + i01*nb01 + i11*nb02 + i12*nb03));
  7042. }
  7043. }
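// Illustrative note: ignoring the batch dimensions, all get_rows variants gather
// dst[r] = src0[src1[r]] row by row; e.g. with src1 = [3, 0, 3] the three dst rows
// become copies of src0 rows 3, 0 and 3 (dequantized to f32 for the _q/_f16/_bf16
// variants).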
  7044. static void ggml_compute_forward_get_rows(
  7045. const struct ggml_compute_params * params,
  7046. struct ggml_tensor * dst) {
  7047. const struct ggml_tensor * src0 = dst->src[0];
  7048. switch (src0->type) {
  7049. case GGML_TYPE_Q4_0:
  7050. case GGML_TYPE_Q4_1:
  7051. case GGML_TYPE_Q5_0:
  7052. case GGML_TYPE_Q5_1:
  7053. case GGML_TYPE_Q8_0:
  7054. case GGML_TYPE_Q8_1:
  7055. case GGML_TYPE_Q2_K:
  7056. case GGML_TYPE_Q3_K:
  7057. case GGML_TYPE_Q4_K:
  7058. case GGML_TYPE_Q5_K:
  7059. case GGML_TYPE_Q6_K:
  7060. case GGML_TYPE_TQ1_0:
  7061. case GGML_TYPE_TQ2_0:
  7062. case GGML_TYPE_IQ2_XXS:
  7063. case GGML_TYPE_IQ2_XS:
  7064. case GGML_TYPE_IQ3_XXS:
  7065. case GGML_TYPE_IQ1_S:
  7066. case GGML_TYPE_IQ1_M:
  7067. case GGML_TYPE_IQ4_NL:
  7068. case GGML_TYPE_IQ4_XS:
  7069. case GGML_TYPE_IQ3_S:
  7070. case GGML_TYPE_IQ2_S:
  7071. case GGML_TYPE_Q4_0_4_4:
  7072. case GGML_TYPE_Q4_0_4_8:
  7073. case GGML_TYPE_Q4_0_8_8:
  7074. {
  7075. ggml_compute_forward_get_rows_q(params, dst);
  7076. } break;
  7077. case GGML_TYPE_F16:
  7078. {
  7079. ggml_compute_forward_get_rows_f16(params, dst);
  7080. } break;
  7081. case GGML_TYPE_BF16:
  7082. {
  7083. ggml_compute_forward_get_rows_bf16(params, dst);
  7084. } break;
  7085. case GGML_TYPE_F32:
  7086. case GGML_TYPE_I32:
  7087. {
  7088. ggml_compute_forward_get_rows_f32(params, dst);
  7089. } break;
  7090. default:
  7091. {
  7092. GGML_ABORT("fatal error");
  7093. }
  7094. }
  7095. //static bool first = true;
  7096. //printf("ne0 = %d, ne1 = %d, ne2 = %d\n", dst->ne[0], dst->ne[1], dst->ne[2]);
  7097. //if (first) {
  7098. // first = false;
  7099. //} else {
  7100. // for (int k = 0; k < dst->ne[1]; ++k) {
  7101. // for (int j = 0; j < dst->ne[0]/16; ++j) {
  7102. // for (int i = 0; i < 16; ++i) {
  7103. // printf("%8.4f ", ((float *) dst->data)[k*dst->ne[0] + j*16 + i]);
  7104. // }
  7105. // printf("\n");
  7106. // }
  7107. // printf("\n");
  7108. // }
  7109. // printf("\n");
  7110. // exit(0);
  7111. //}
  7112. }
  7113. // ggml_compute_forward_get_rows_back
  7114. static void ggml_compute_forward_get_rows_back_f32_f16(
  7115. const struct ggml_compute_params * params,
  7116. struct ggml_tensor * dst) {
  7117. const struct ggml_tensor * src0 = dst->src[0];
  7118. const struct ggml_tensor * src1 = dst->src[1];
  7119. if (params->ith != 0) {
  7120. return;
  7121. }
  7122. GGML_ASSERT(ggml_is_contiguous(dst));
  7123. // ggml_compute_forward_dup_same_cont(params, opt0, dst);
  7124. memset(dst->data, 0, ggml_nbytes(dst));
  7125. const int nc = src0->ne[0];
  7126. const int nr = ggml_nelements(src1);
  7127. GGML_ASSERT( dst->ne[0] == nc);
  7128. GGML_ASSERT(src0->nb[0] == sizeof(ggml_fp16_t));
  7129. for (int i = 0; i < nr; ++i) {
  7130. const int r = ((int32_t *) src1->data)[i];
  7131. for (int j = 0; j < nc; ++j) {
  7132. ggml_fp16_t v = ((ggml_fp16_t *) ((char *) src0->data + i*src0->nb[1]))[j];
  7133. ((float *) ((char *) dst->data + r*dst->nb[1]))[j] += GGML_FP16_TO_FP32(v);
  7134. }
  7135. }
  7136. }
  7137. static void ggml_compute_forward_get_rows_back_f32(
  7138. const struct ggml_compute_params * params,
  7139. struct ggml_tensor * dst) {
  7140. const struct ggml_tensor * src0 = dst->src[0];
  7141. const struct ggml_tensor * src1 = dst->src[1];
  7142. if (params->ith != 0) {
  7143. return;
  7144. }
  7145. GGML_ASSERT(ggml_is_contiguous(dst));
  7146. // ggml_compute_forward_dup_same_cont(params, opt0, dst);
  7147. memset(dst->data, 0, ggml_nbytes(dst));
  7148. const int nc = src0->ne[0];
  7149. const int nr = ggml_nelements(src1);
  7150. GGML_ASSERT( dst->ne[0] == nc);
  7151. GGML_ASSERT(src0->nb[0] == sizeof(float));
  7152. for (int i = 0; i < nr; ++i) {
  7153. const int r = ((int32_t *) src1->data)[i];
  7154. ggml_vec_add_f32(nc,
  7155. (float *) ((char *) dst->data + r*dst->nb[1]),
  7156. (float *) ((char *) dst->data + r*dst->nb[1]),
  7157. (float *) ((char *) src0->data + i*src0->nb[1]));
  7158. }
  7159. }
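// Illustrative note: this is the scatter-add adjoint of get_rows, dst[src1[i]] += src0[i],
// so repeated indices accumulate; e.g. with src1 = [3, 0, 3], dst row 3 receives the sum
// of src0 rows 0 and 2 while dst row 0 receives src0 row 1.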
  7160. static void ggml_compute_forward_get_rows_back(
  7161. const struct ggml_compute_params * params,
  7162. struct ggml_tensor * dst) {
  7163. const struct ggml_tensor * src0 = dst->src[0];
  7164. switch (src0->type) {
  7165. case GGML_TYPE_F16:
  7166. {
  7167. ggml_compute_forward_get_rows_back_f32_f16(params, dst);
  7168. } break;
  7169. case GGML_TYPE_F32:
  7170. {
  7171. ggml_compute_forward_get_rows_back_f32(params, dst);
  7172. } break;
  7173. default:
  7174. {
  7175. GGML_ABORT("fatal error");
  7176. }
  7177. }
  7178. //static bool first = true;
  7179. //printf("ne0 = %d, ne1 = %d, ne2 = %d\n", dst->ne[0], dst->ne[1], dst->ne[2]);
  7180. //if (first) {
  7181. // first = false;
  7182. //} else {
  7183. // for (int k = 0; k < dst->ne[1]; ++k) {
  7184. // for (int j = 0; j < dst->ne[0]/16; ++j) {
  7185. // for (int i = 0; i < 16; ++i) {
  7186. // printf("%8.4f ", ((float *) dst->data)[k*dst->ne[0] + j*16 + i]);
  7187. // }
  7188. // printf("\n");
  7189. // }
  7190. // printf("\n");
  7191. // }
  7192. // printf("\n");
  7193. // exit(0);
  7194. //}
  7195. }
  7196. // ggml_compute_forward_diag
  7197. static void ggml_compute_forward_diag_f32(
  7198. const struct ggml_compute_params * params,
  7199. struct ggml_tensor * dst) {
  7200. const struct ggml_tensor * src0 = dst->src[0];
  7201. if (params->ith != 0) {
  7202. return;
  7203. }
  7204. // TODO: handle transposed/permuted matrices
  7205. GGML_TENSOR_UNARY_OP_LOCALS
  7206. GGML_ASSERT(ne00 == ne0);
  7207. GGML_ASSERT(ne00 == ne1);
  7208. GGML_ASSERT(ne01 == 1);
  7209. GGML_ASSERT(ne02 == ne2);
  7210. GGML_ASSERT(ne03 == ne3);
  7211. GGML_ASSERT(nb00 == sizeof(float));
  7212. GGML_ASSERT(nb0 == sizeof(float));
  7213. for (int i3 = 0; i3 < ne3; i3++) {
  7214. for (int i2 = 0; i2 < ne2; i2++) {
  7215. for (int i1 = 0; i1 < ne1; i1++) {
  7216. float * d = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1);
  7217. float * s = (float *)((char *) src0->data + i3*nb03 + i2*nb02);
  7218. for (int i0 = 0; i0 < i1; i0++) {
  7219. d[i0] = 0;
  7220. }
  7221. d[i1] = s[i1];
  7222. for (int i0 = i1+1; i0 < ne0; i0++) {
  7223. d[i0] = 0;
  7224. }
  7225. }
  7226. }
  7227. }
  7228. }
  7229. static void ggml_compute_forward_diag(
  7230. const struct ggml_compute_params * params,
  7231. struct ggml_tensor * dst) {
  7232. const struct ggml_tensor * src0 = dst->src[0];
  7233. switch (src0->type) {
  7234. case GGML_TYPE_F32:
  7235. {
  7236. ggml_compute_forward_diag_f32(params, dst);
  7237. } break;
  7238. default:
  7239. {
  7240. GGML_ABORT("fatal error");
  7241. }
  7242. }
  7243. }
  7244. // ggml_compute_forward_diag_mask_inf
  7245. static void ggml_compute_forward_diag_mask_f32(
  7246. const struct ggml_compute_params * params,
  7247. struct ggml_tensor * dst,
  7248. const float value) {
  7249. const struct ggml_tensor * src0 = dst->src[0];
  7250. const int ith = params->ith;
  7251. const int nth = params->nth;
  7252. const int n_past = ((int32_t *) dst->op_params)[0];
  7253. const bool inplace = src0->data == dst->data;
  7254. GGML_ASSERT(n_past >= 0);
  7255. if (!inplace) {
  7256. if (ith == 0) {
  7257. // memcpy needs to be synchronized across threads to avoid race conditions.
  7258. // => do it in INIT phase
  7259. GGML_ASSERT(ggml_nelements(dst) == ggml_nelements(src0));
  7260. GGML_ASSERT(ggml_is_contiguous(dst) && ggml_is_contiguous(src0));
  7261. memcpy(
  7262. ((char *) dst->data),
  7263. ((char *) src0->data),
  7264. ggml_nbytes(dst));
  7265. }
  7266. ggml_barrier(params->threadpool);
  7267. }
  7268. // TODO: handle transposed/permuted matrices
  7269. const int n = ggml_nrows(src0);
  7270. const int nc = src0->ne[0];
  7271. const int nr = src0->ne[1];
  7272. const int nz = n/nr;
  7273. GGML_ASSERT( dst->nb[0] == sizeof(float));
  7274. GGML_ASSERT(src0->nb[0] == sizeof(float));
  7275. for (int k = 0; k < nz; k++) {
  7276. for (int j = ith; j < nr; j += nth) {
  7277. for (int i = n_past; i < nc; i++) {
  7278. if (i > n_past + j) {
  7279. *(float *)((char *) dst->data + k*dst->nb[2] + j*dst->nb[1] + i*dst->nb[0]) = value;
  7280. }
  7281. }
  7282. }
  7283. }
  7284. }
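// Worked example with hypothetical sizes: for nc = 5, nr = 3 and n_past = 2, entries
// with i > n_past + j are overwritten, so each 5-wide row becomes
//   row 0: x x x v v
//   row 1: x x x x v
//   row 2: x x x x x
// where x is the original value and v is the mask value (-INFINITY for
// diag_mask_inf, 0 for diag_mask_zero below).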
  7285. static void ggml_compute_forward_diag_mask_inf(
  7286. const struct ggml_compute_params * params,
  7287. struct ggml_tensor * dst) {
  7288. const struct ggml_tensor * src0 = dst->src[0];
  7289. switch (src0->type) {
  7290. case GGML_TYPE_F32:
  7291. {
  7292. ggml_compute_forward_diag_mask_f32(params, dst, -INFINITY);
  7293. } break;
  7294. default:
  7295. {
  7296. GGML_ABORT("fatal error");
  7297. }
  7298. }
  7299. }
  7300. static void ggml_compute_forward_diag_mask_zero(
  7301. const struct ggml_compute_params * params,
  7302. struct ggml_tensor * dst) {
  7303. const struct ggml_tensor * src0 = dst->src[0];
  7304. switch (src0->type) {
  7305. case GGML_TYPE_F32:
  7306. {
  7307. ggml_compute_forward_diag_mask_f32(params, dst, 0);
  7308. } break;
  7309. default:
  7310. {
  7311. GGML_ABORT("fatal error");
  7312. }
  7313. }
  7314. }
// ggml_compute_forward_soft_max

static void ggml_compute_forward_soft_max_f32(
        const struct ggml_compute_params * params,
        struct ggml_tensor * dst) {

    const struct ggml_tensor * src0 = dst->src[0];
    const struct ggml_tensor * src1 = dst->src[1];

    assert(ggml_is_contiguous(dst));
    assert(ggml_are_same_shape(src0, dst));

    float scale    = 1.0f;
    float max_bias = 0.0f;

    memcpy(&scale,    (float *) dst->op_params + 0, sizeof(float));
    memcpy(&max_bias, (float *) dst->op_params + 1, sizeof(float));

    // TODO: handle transposed/permuted matrices

    const int ith = params->ith;
    const int nth = params->nth;

    GGML_TENSOR_UNARY_OP_LOCALS

    //const int64_t ne11 = src1 ? src1->ne[1] : 1;

    // TODO: is this supposed to be ceil instead of floor?
    //       https://huggingface.co/mosaicml/mpt-7b/blob/main/attention.py#L370
    const uint32_t n_head      = ne02;
    const uint32_t n_head_log2 = 1u << (uint32_t) floor(log2(n_head));

    const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2);
    const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);

    const int nc = src0->ne[0];
    const int nr = ggml_nrows(src0);

    // rows per thread
    const int dr = (nr + nth - 1)/nth;

    // row range for this thread
    const int ir0 = dr*ith;
    const int ir1 = MIN(ir0 + dr, nr);

    float * wp = (float *) params->wdata + (nc + CACHE_LINE_SIZE_F32) * ith;

    const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16);

    for (int i1 = ir0; i1 < ir1; i1++) {
        // ALiBi
        const uint32_t h = (i1/ne01)%ne02; // head
        const float slope = (max_bias > 0.0f) ? h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2*(h - n_head_log2) + 1) : 1.0f;

        float * sp = (float *)((char *) src0->data + i1*src0->nb[1]);
        float * dp = (float *)((char *)  dst->data + i1*dst->nb[1]);

        // broadcast the mask across rows
        ggml_fp16_t * mp_f16 = src1 ? (ggml_fp16_t *)((char *) src1->data) + (i1%ne01)*ne00 : NULL;
        float       * mp_f32 = src1 ? (float       *)((char *) src1->data) + (i1%ne01)*ne00 : NULL;

        ggml_vec_cpy_f32  (nc, wp, sp);
        ggml_vec_scale_f32(nc, wp, scale);
        if (mp_f32) {
            if (use_f16) {
                for (int i = 0; i < nc; ++i) {
                    wp[i] += slope*GGML_FP16_TO_FP32(mp_f16[i]);
                }
            } else {
                for (int i = 0; i < nc; ++i) {
                    wp[i] += slope*mp_f32[i];
                }
            }
        }

#ifndef NDEBUG
        for (int i = 0; i < nc; ++i) {
            //printf("p[%d] = %f\n", i, p[i]);
            assert(!isnan(wp[i]));
        }
#endif

        float max = -INFINITY;
        ggml_vec_max_f32(nc, &max, wp);

        ggml_float sum = ggml_vec_soft_max_f32(nc, dp, wp, max);
        assert(sum > 0.0);

        sum = 1.0/sum;
        ggml_vec_scale_f32(nc, dp, sum);

#ifndef NDEBUG
        for (int i = 0; i < nc; ++i) {
            assert(!isnan(dp[i]));
            assert(!isinf(dp[i]));
        }
#endif
    }
}
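// Roughly what one row of the kernel above computes, written out as a scalar
// sketch (illustrative only; `x`, `mask` and the index range are placeholders):
//
//   w[i] = x[i]*scale + slope*mask[i];   // scaled logits plus ALiBi bias
//   m    = max_i w[i];                   // subtract the row max for stability
//   p[i] = exp(w[i] - m);
//   y[i] = p[i] / sum_i p[i];
//
// The ALiBi slope follows the usual head schedule: heads with h < n_head_log2 use
// m0^(h+1), the remaining heads use m1^(2*(h - n_head_log2) + 1), and when
// max_bias == 0 the slope collapses to 1.0 so the mask is applied unscaled.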
static void ggml_compute_forward_soft_max(
        const struct ggml_compute_params * params,
        struct ggml_tensor * dst) {

    const struct ggml_tensor * src0 = dst->src[0];

    switch (src0->type) {
        case GGML_TYPE_F32:
            {
                ggml_compute_forward_soft_max_f32(params, dst);
            } break;
        default:
            {
                GGML_ABORT("fatal error");
            }
    }
}
// ggml_compute_forward_soft_max_back

static void ggml_compute_forward_soft_max_back_f32(
        const struct ggml_compute_params * params,
        struct ggml_tensor * dst) {

    const struct ggml_tensor * src0 = dst->src[0];
    const struct ggml_tensor * src1 = dst->src[1];

    GGML_ASSERT(ggml_is_contiguous(src0));
    GGML_ASSERT(ggml_is_contiguous(src1));
    GGML_ASSERT(ggml_is_contiguous(dst));
    GGML_ASSERT(ggml_are_same_shape(src0, dst));
    GGML_ASSERT(ggml_are_same_shape(src1, dst));

    // TODO: handle transposed/permuted matrices

    const int ith = params->ith;
    const int nth = params->nth;

    const int nc = src0->ne[0];
    const int nr = ggml_nrows(src0);

    // rows per thread
    const int dr = (nr + nth - 1)/nth;

    // row range for this thread
    const int ir0 = dr*ith;
    const int ir1 = MIN(ir0 + dr, nr);

    for (int i1 = ir0; i1 < ir1; i1++) {
        float *dy = (float *)((char *) src0->data + i1*src0->nb[1]);
        float *y  = (float *)((char *) src1->data + i1*src1->nb[1]);
        float *dx = (float *)((char *) dst->data + i1*dst->nb[1]);

#ifndef NDEBUG
        for (int i = 0; i < nc; ++i) {
            //printf("p[%d] = %f\n", i, p[i]);
            assert(!isnan(dy[i]));
            assert(!isnan(y[i]));
        }
#endif
        // Jii = yi - yi*yi
        // Jij = -yi*yj
        // J = diag(y)-y.T*y
        // dx = J * dy
        // dxk = sum_i(Jki * dyi)
        // dxk = sum_i(-yk*yi * dyi) - (-yk*yk)*dyk + (yk - yk*yk)*dyk
        // dxk = sum_i(-yk*yi * dyi) + yk*yk*dyk + yk*dyk - yk*yk*dyk
        // dxk = sum_i(-yk*yi * dyi) + yk*dyk
        // dxk = -yk * sum_i(yi * dyi) + yk*dyk
        // dxk = -yk * dot(y, dy) + yk*dyk
        // dxk = yk * (- dot(y, dy) + dyk)
        // dxk = yk * (dyk - dot(y, dy))
        //
        // post-order:
        // dot_y_dy := dot(y, dy)
        // dx := dy
        // dx := dx - dot_y_dy
        // dx := dx * y

        // linear runtime, no additional memory
        float dot_y_dy = 0;
        ggml_vec_dot_f32 (nc, &dot_y_dy, 0, y, 0, dy, 0, 1);
        ggml_vec_cpy_f32 (nc, dx, dy);
        ggml_vec_acc1_f32(nc, dx, -dot_y_dy);
        ggml_vec_mul_f32 (nc, dx, dx, y);

#ifndef NDEBUG
        for (int i = 0; i < nc; ++i) {
            assert(!isnan(dx[i]));
            assert(!isinf(dx[i]));
        }
#endif
    }
}
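// The four vector calls above are just the identity derived in the comment,
// dx = y * (dy - dot(y, dy)), evaluated in place. A minimal scalar sketch:
//
//   float dot = 0.0f;
//   for (int i = 0; i < nc; ++i) dot += y[i]*dy[i];
//   for (int i = 0; i < nc; ++i) dx[i] = y[i]*(dy[i] - dot);
//
// which avoids ever materializing the nc x nc Jacobian.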
static void ggml_compute_forward_soft_max_back(
        const struct ggml_compute_params * params,
        struct ggml_tensor * dst) {

    const struct ggml_tensor * src0 = dst->src[0];

    switch (src0->type) {
        case GGML_TYPE_F32:
            {
                ggml_compute_forward_soft_max_back_f32(params, dst);
            } break;
        default:
            {
                GGML_ABORT("fatal error");
            }
    }
}
// ggml_compute_forward_clamp

static void ggml_compute_forward_clamp_f32(
        const struct ggml_compute_params * params,
        struct ggml_tensor * dst) {

    const struct ggml_tensor * src0 = dst->src[0];

    if (params->ith != 0) {
        return;
    }

    float min;
    float max;
    memcpy(&min, (float *) dst->op_params + 0, sizeof(float));
    memcpy(&max, (float *) dst->op_params + 1, sizeof(float));

    const int ith = params->ith;
    const int nth = params->nth;

    const int n  = ggml_nrows(src0);
    const int nc = src0->ne[0];

    const size_t nb00 = src0->nb[0];
    const size_t nb01 = src0->nb[1];

    const size_t nb0 = dst->nb[0];
    const size_t nb1 = dst->nb[1];

    GGML_ASSERT( nb0 == sizeof(float));
    GGML_ASSERT(nb00 == sizeof(float));

    for (int j = ith; j < n; j += nth) {
        float * dst_ptr  = (float *) ((char *)  dst->data + j*nb1);
        float * src0_ptr = (float *) ((char *) src0->data + j*nb01);

        for (int i = 0; i < nc; i++) {
            dst_ptr[i] = MAX(MIN(src0_ptr[i], max), min);
        }
    }
}
static void ggml_compute_forward_clamp(
        const struct ggml_compute_params * params,
        struct ggml_tensor * dst) {

    const struct ggml_tensor * src0 = dst->src[0];

    switch (src0->type) {
        case GGML_TYPE_F32:
            {
                ggml_compute_forward_clamp_f32(params, dst);
            } break;
        case GGML_TYPE_F16:
        case GGML_TYPE_BF16:
        case GGML_TYPE_Q4_0:
        case GGML_TYPE_Q4_1:
        case GGML_TYPE_Q5_0:
        case GGML_TYPE_Q5_1:
        case GGML_TYPE_Q8_0:
        case GGML_TYPE_Q8_1:
        case GGML_TYPE_Q2_K:
        case GGML_TYPE_Q3_K:
        case GGML_TYPE_Q4_K:
        case GGML_TYPE_Q5_K:
        case GGML_TYPE_Q6_K:
        case GGML_TYPE_TQ1_0:
        case GGML_TYPE_TQ2_0:
        case GGML_TYPE_IQ2_XXS:
        case GGML_TYPE_IQ2_XS:
        case GGML_TYPE_IQ3_XXS:
        case GGML_TYPE_IQ1_S:
        case GGML_TYPE_IQ1_M:
        case GGML_TYPE_IQ4_NL:
        case GGML_TYPE_IQ4_XS:
        case GGML_TYPE_IQ3_S:
        case GGML_TYPE_IQ2_S:
        case GGML_TYPE_Q8_K:
        case GGML_TYPE_Q4_0_4_4:
        case GGML_TYPE_Q4_0_4_8:
        case GGML_TYPE_Q4_0_8_8:
        case GGML_TYPE_I8:
        case GGML_TYPE_I16:
        case GGML_TYPE_I32:
        case GGML_TYPE_I64:
        case GGML_TYPE_F64:
        case GGML_TYPE_COUNT:
            {
                GGML_ABORT("fatal error");
            }
    }
}
// ggml_compute_forward_rope

static float rope_yarn_ramp(const float low, const float high, const int i0) {
    const float y = (i0 / 2 - low) / MAX(0.001f, high - low);
    return 1 - MIN(1, MAX(0, y));
}

// YaRN algorithm based on LlamaYaRNScaledRotaryEmbedding.py from https://github.com/jquesnelle/yarn
// MIT licensed. Copyright (c) 2023 Jeffrey Quesnelle and Bowen Peng.
static void rope_yarn(
    float theta_extrap, float freq_scale, float corr_dims[2], int64_t i0, float ext_factor, float mscale,
    float * cos_theta, float * sin_theta) {
    // Get n-d rotational scaling corrected for extrapolation
    float theta_interp = freq_scale * theta_extrap;
    float theta = theta_interp;
    if (ext_factor != 0.0f) {
        float ramp_mix = rope_yarn_ramp(corr_dims[0], corr_dims[1], i0) * ext_factor;
        theta = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix;

        // Get n-d magnitude scaling corrected for interpolation
        mscale *= 1.0f + 0.1f * logf(1.0f / freq_scale);
    }
    *cos_theta = cosf(theta) * mscale;
    *sin_theta = sinf(theta) * mscale;
}
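// Illustrative reading of rope_yarn(): dimensions whose ramp value is 0 use the
// interpolated angle (theta_extrap scaled by freq_scale), dimensions whose ramp
// value is 1 use the original extrapolated angle (when ext_factor is 1), and
// dimensions in between blend the two:
//
//   theta = theta_interp*(1 - ramp_mix) + theta_extrap*ramp_mix
//
// The extra mscale term (1 + 0.1*ln(1/freq_scale)) is the YaRN attention
// "temperature" correction, applied to both cos and sin.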
// Apparently solving `n_rot = 2pi * x * base^((2 * max_pos_emb) / n_dims)` for x, we get
// `corr_dim(n_rot) = n_dims * log(max_pos_emb / (n_rot * 2pi)) / (2 * log(base))`
static float ggml_rope_yarn_corr_dim(int n_dims, int n_ctx_orig, float n_rot, float base) {
    return n_dims * logf(n_ctx_orig / (n_rot * 2 * (float)M_PI)) / (2 * logf(base));
}

static void ggml_rope_cache_init(
     float theta_base, float freq_scale, const float * freq_factors, float corr_dims[2], int64_t ne0, float ext_factor, float mscale,
     float * cache, float sin_sign, float theta_scale) {
    // ref: https://github.com/jquesnelle/yarn/blob/master/scaled_rope/LlamaYaRNScaledRotaryEmbedding.py
    float theta = theta_base;
    for (int64_t i0 = 0; i0 < ne0; i0 += 2) {
        const float ff = freq_factors ? freq_factors[i0/2] : 1.0f;
        rope_yarn(
            theta/ff, freq_scale, corr_dims, i0, ext_factor, mscale, &cache[i0 + 0], &cache[i0 + 1]
        );
        cache[i0 + 1] *= sin_sign;

        theta *= theta_scale;
    }
}

void ggml_rope_yarn_corr_dims(
    int n_dims, int n_ctx_orig, float freq_base, float beta_fast, float beta_slow, float dims[2]
) {
    // start and end correction dims
    float start = floorf(ggml_rope_yarn_corr_dim(n_dims, n_ctx_orig, beta_fast, freq_base));
    float end   =  ceilf(ggml_rope_yarn_corr_dim(n_dims, n_ctx_orig, beta_slow, freq_base));
    dims[0] = MAX(0, start);
    dims[1] = MIN(n_dims - 1, end);
}
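// Sketch of how the correction range is used downstream (parameter values are
// illustrative, not taken from any specific model): with n_dims = 128,
// n_ctx_orig = 4096 and freq_base = 10000, beta_fast/beta_slow select a
// [start, end] window of dimension indices; rope_yarn_ramp() then maps i0/2 at or
// below `start` to 1 (extrapolate) and at or above `end` to 0 (interpolate), and
// the window is clamped to [0, n_dims - 1] here.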
static void ggml_compute_forward_rope_f32(
        const struct ggml_compute_params * params,
        struct ggml_tensor * dst,
        const bool forward) {

    const struct ggml_tensor * src0 = dst->src[0];
    const struct ggml_tensor * src1 = dst->src[1];
    const struct ggml_tensor * src2 = dst->src[2];

    float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;

    //const int n_past     = ((int32_t *) dst->op_params)[0];
    const int n_dims     = ((int32_t *) dst->op_params)[1];
    const int mode       = ((int32_t *) dst->op_params)[2];
    //const int n_ctx      = ((int32_t *) dst->op_params)[3];
    const int n_ctx_orig = ((int32_t *) dst->op_params)[4];

    memcpy(&freq_base,   (int32_t *) dst->op_params +  5, sizeof(float));
    memcpy(&freq_scale,  (int32_t *) dst->op_params +  6, sizeof(float));
    memcpy(&ext_factor,  (int32_t *) dst->op_params +  7, sizeof(float));
    memcpy(&attn_factor, (int32_t *) dst->op_params +  8, sizeof(float));
    memcpy(&beta_fast,   (int32_t *) dst->op_params +  9, sizeof(float));
    memcpy(&beta_slow,   (int32_t *) dst->op_params + 10, sizeof(float));

    GGML_TENSOR_UNARY_OP_LOCALS

    //printf("ne0: %d, ne1: %d, ne2: %d, ne3: %d\n", ne0, ne1, ne2, ne3);
    //printf("n_past = %d, ne2 = %d\n", n_past, ne2);

    GGML_ASSERT(nb00 == sizeof(float));

    const int ith = params->ith;
    const int nth = params->nth;

    const int nr = ggml_nrows(dst);

    GGML_ASSERT(n_dims <= ne0);
    GGML_ASSERT(n_dims % 2 == 0);

    // rows per thread
    const int dr = (nr + nth - 1)/nth;

    // row range for this thread
    const int ir0 = dr*ith;
    const int ir1 = MIN(ir0 + dr, nr);

    // row index used to determine which thread to use
    int ir = 0;

    const float theta_scale = powf(freq_base, -2.0f/n_dims);

    float corr_dims[2];
    ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims);

    const bool is_neox = mode & GGML_ROPE_TYPE_NEOX;

    const float * freq_factors = NULL;
    if (src2 != NULL) {
        GGML_ASSERT(src2->type == GGML_TYPE_F32);
        GGML_ASSERT(src2->ne[0] >= n_dims / 2);
        freq_factors = (const float *) src2->data;
    }

    // backward process uses inverse rotation by cos and sin.
    // cos and sin build a rotation matrix, where the inverse is the transpose.
    // this essentially just switches the sign of sin.
    const float sin_sign = forward ? 1.0f : -1.0f;

    const int32_t * pos = (const int32_t *) src1->data;

    for (int64_t i3 = 0; i3 < ne3; i3++) {
        for (int64_t i2 = 0; i2 < ne2; i2++) {
            const int64_t p = pos[i2];

            float * cache = (float *) params->wdata + (ne0 + CACHE_LINE_SIZE_F32)*ith;
            ggml_rope_cache_init(p, freq_scale, freq_factors, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale);

            for (int64_t i1 = 0; i1 < ne1; i1++) {
                if (ir++ < ir0) continue;
                if (ir   > ir1) break;

                if (!is_neox) {
                    for (int64_t i0 = 0; i0 < n_dims; i0 += 2) {
                        const float cos_theta = cache[i0 + 0];
                        const float sin_theta = cache[i0 + 1];

                        const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
                              float * dst_data  = (float *)((char *)  dst->data + i3*nb3  + i2*nb2  + i1*nb1  + i0*nb0);

                        const float x0 = src[0];
                        const float x1 = src[1];

                        dst_data[0] = x0*cos_theta - x1*sin_theta;
                        dst_data[1] = x0*sin_theta + x1*cos_theta;
                    }
                } else {
                    for (int64_t i0 = 0; i0 < n_dims; i0 += 2) {
                        const int64_t ic = i0/2;

                        const float cos_theta = cache[i0 + 0];
                        const float sin_theta = cache[i0 + 1];

                        const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00);
                              float * dst_data  = (float *)((char *)  dst->data + i3*nb3  + i2*nb2  + i1*nb1  + ic*nb0);

                        const float x0 = src[0];
                        const float x1 = src[n_dims/2];

                        dst_data[0]        = x0*cos_theta - x1*sin_theta;
                        dst_data[n_dims/2] = x0*sin_theta + x1*cos_theta;
                    }
                }

                for (int64_t i0 = n_dims; i0 < ne0; i0 += 2) {
                    const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
                          float * dst_data  = (float *)((char *)  dst->data + i3*nb3  + i2*nb2  + i1*nb1  + i0*nb0);

                    dst_data[0] = src[0];
                    dst_data[1] = src[1];
                }
            }
        }
    }
}
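// Pairing convention, written out for clarity (illustrative): in "normal" RoPE the
// rotation mixes adjacent elements of a row,
//
//   (x[i0], x[i0 + 1])
//
// while in NeoX-style RoPE (GGML_ROPE_TYPE_NEOX) it mixes the first and second
// halves of the rotated sub-vector,
//
//   (x[ic], x[ic + n_dims/2])   with ic = i0/2
//
// Dimensions past n_dims are copied through unchanged in both cases.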
  7703. // TODO: deduplicate f16/f32 code
  7704. static void ggml_compute_forward_rope_f16(
  7705. const struct ggml_compute_params * params,
  7706. struct ggml_tensor * dst,
  7707. const bool forward) {
  7708. const struct ggml_tensor * src0 = dst->src[0];
  7709. const struct ggml_tensor * src1 = dst->src[1];
  7710. const struct ggml_tensor * src2 = dst->src[2];
  7711. float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
  7712. //const int n_past = ((int32_t *) dst->op_params)[0];
  7713. const int n_dims = ((int32_t *) dst->op_params)[1];
  7714. const int mode = ((int32_t *) dst->op_params)[2];
  7715. //const int n_ctx = ((int32_t *) dst->op_params)[3];
  7716. const int n_ctx_orig = ((int32_t *) dst->op_params)[4];
  7717. memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float));
  7718. memcpy(&freq_scale, (int32_t *) dst->op_params + 6, sizeof(float));
  7719. memcpy(&ext_factor, (int32_t *) dst->op_params + 7, sizeof(float));
  7720. memcpy(&attn_factor, (int32_t *) dst->op_params + 8, sizeof(float));
  7721. memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float));
  7722. memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float));
  7723. GGML_TENSOR_UNARY_OP_LOCALS
  7724. //printf("ne0: %d, ne1: %d, ne2: %d, ne3: %d\n", ne0, ne1, ne2, ne3);
  7725. //printf("n_past = %d, ne2 = %d\n", n_past, ne2);
  7726. GGML_ASSERT(nb0 == sizeof(ggml_fp16_t));
  7727. const int ith = params->ith;
  7728. const int nth = params->nth;
  7729. const int nr = ggml_nrows(dst);
  7730. GGML_ASSERT(n_dims <= ne0);
  7731. GGML_ASSERT(n_dims % 2 == 0);
  7732. // rows per thread
  7733. const int dr = (nr + nth - 1)/nth;
  7734. // row range for this thread
  7735. const int ir0 = dr*ith;
  7736. const int ir1 = MIN(ir0 + dr, nr);
  7737. // row index used to determine which thread to use
  7738. int ir = 0;
  7739. const float theta_scale = powf(freq_base, -2.0f/n_dims);
  7740. float corr_dims[2];
  7741. ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims);
  7742. const bool is_neox = mode & GGML_ROPE_TYPE_NEOX;
  7743. const float * freq_factors = NULL;
  7744. if (src2 != NULL) {
  7745. GGML_ASSERT(src2->type == GGML_TYPE_F32);
  7746. GGML_ASSERT(src2->ne[0] >= n_dims / 2);
  7747. freq_factors = (const float *) src2->data;
  7748. }
  7749. // backward process uses inverse rotation by cos and sin.
  7750. // cos and sin build a rotation matrix, where the inverse is the transpose.
  7751. // this essentially just switches the sign of sin.
  7752. const float sin_sign = forward ? 1.0f : -1.0f;
  7753. const int32_t * pos = (const int32_t *) src1->data;
  7754. for (int64_t i3 = 0; i3 < ne3; i3++) {
  7755. for (int64_t i2 = 0; i2 < ne2; i2++) {
  7756. const int64_t p = pos[i2];
  7757. float * cache = (float *) params->wdata + (ne0 + CACHE_LINE_SIZE_F32)*ith;
  7758. ggml_rope_cache_init(p, freq_scale, freq_factors, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale);
  7759. for (int64_t i1 = 0; i1 < ne1; i1++) {
  7760. if (ir++ < ir0) continue;
  7761. if (ir > ir1) break;
  7762. if (!is_neox) {
  7763. for (int64_t i0 = 0; i0 < n_dims; i0 += 2) {
  7764. const float cos_theta = cache[i0 + 0];
  7765. const float sin_theta = cache[i0 + 1];
  7766. const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
  7767. ggml_fp16_t * dst_data = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
  7768. const float x0 = GGML_FP16_TO_FP32(src[0]);
  7769. const float x1 = GGML_FP16_TO_FP32(src[1]);
  7770. dst_data[0] = GGML_FP32_TO_FP16(x0*cos_theta - x1*sin_theta);
  7771. dst_data[1] = GGML_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
  7772. }
  7773. } else {
  7774. for (int64_t i0 = 0; i0 < n_dims; i0 += 2) {
  7775. const int64_t ic = i0/2;
  7776. const float cos_theta = cache[i0 + 0];
  7777. const float sin_theta = cache[i0 + 1];
  7778. const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00);
  7779. ggml_fp16_t * dst_data = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0);
  7780. const float x0 = GGML_FP16_TO_FP32(src[0]);
  7781. const float x1 = GGML_FP16_TO_FP32(src[n_dims/2]);
  7782. dst_data[0] = GGML_FP32_TO_FP16(x0*cos_theta - x1*sin_theta);
  7783. dst_data[n_dims/2] = GGML_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
  7784. }
  7785. }
  7786. for (int64_t i0 = n_dims; i0 < ne0; i0 += 2) {
  7787. const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
  7788. ggml_fp16_t * dst_data = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
  7789. dst_data[0] = src[0];
  7790. dst_data[1] = src[1];
  7791. }
  7792. }
  7793. }
  7794. }
  7795. }
  7796. static void ggml_compute_forward_rope(
  7797. const struct ggml_compute_params * params,
  7798. struct ggml_tensor * dst) {
  7799. const struct ggml_tensor * src0 = dst->src[0];
  7800. switch (src0->type) {
  7801. case GGML_TYPE_F16:
  7802. {
  7803. ggml_compute_forward_rope_f16(params, dst, true);
  7804. } break;
  7805. case GGML_TYPE_F32:
  7806. {
  7807. ggml_compute_forward_rope_f32(params, dst, true);
  7808. } break;
  7809. default:
  7810. {
  7811. GGML_ABORT("fatal error");
  7812. }
  7813. }
  7814. }
  7815. // ggml_compute_forward_rope_back
  7816. static void ggml_compute_forward_rope_back(
  7817. const struct ggml_compute_params * params,
  7818. struct ggml_tensor * dst) {
  7819. const struct ggml_tensor * src0 = dst->src[0];
  7820. switch (src0->type) {
  7821. case GGML_TYPE_F16:
  7822. {
  7823. ggml_compute_forward_rope_f16(params, dst, false);
  7824. } break;
  7825. case GGML_TYPE_F32:
  7826. {
  7827. ggml_compute_forward_rope_f32(params, dst, false);
  7828. } break;
  7829. default:
  7830. {
  7831. GGML_ABORT("fatal error");
  7832. }
  7833. }
  7834. }
  7835. // ggml_compute_forward_conv_transpose_1d
  7836. static void ggml_compute_forward_conv_transpose_1d_f16_f32(
  7837. const struct ggml_compute_params * params,
  7838. struct ggml_tensor * dst) {
  7839. const struct ggml_tensor * src0 = dst->src[0];
  7840. const struct ggml_tensor * src1 = dst->src[1];
  7841. GGML_ASSERT(src0->type == GGML_TYPE_F16);
  7842. GGML_ASSERT(src1->type == GGML_TYPE_F32);
  7843. GGML_ASSERT( dst->type == GGML_TYPE_F32);
  7844. GGML_TENSOR_BINARY_OP_LOCALS
  7845. const int ith = params->ith;
  7846. const int nth = params->nth;
  7847. const int nk = ne00*ne01*ne02;
  7848. GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
  7849. GGML_ASSERT(nb10 == sizeof(float));
  7850. if (ith == 0) {
  7851. memset(params->wdata, 0, params->wsize);
  7852. // permute kernel data (src0) from (K x Cout x Cin) to (Cin x K x Cout)
  7853. {
  7854. ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + 0;
  7855. for (int64_t i02 = 0; i02 < ne02; i02++) {
  7856. for (int64_t i01 = 0; i01 < ne01; i01++) {
  7857. const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i02*nb02 + i01*nb01);
  7858. ggml_fp16_t * dst_data = wdata + i01*ne00*ne02;
  7859. for (int64_t i00 = 0; i00 < ne00; i00++) {
  7860. dst_data[i00*ne02 + i02] = src[i00];
  7861. }
  7862. }
  7863. }
  7864. }
  7865. // permute source data (src1) from (L x Cin) to (Cin x L)
  7866. {
  7867. ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + nk;
  7868. ggml_fp16_t * dst_data = wdata;
  7869. for (int64_t i11 = 0; i11 < ne11; i11++) {
  7870. const float * const src = (float *)((char *) src1->data + i11*nb11);
  7871. for (int64_t i10 = 0; i10 < ne10; i10++) {
  7872. dst_data[i10*ne11 + i11] = GGML_FP32_TO_FP16(src[i10]);
  7873. }
  7874. }
  7875. }
  7876. // need to zero dst since we are accumulating into it
  7877. memset(dst->data, 0, ggml_nbytes(dst));
  7878. }
  7879. ggml_barrier(params->threadpool);
  7880. const int32_t s0 = ((const int32_t*)(dst->op_params))[0];
  7881. // total rows in dst
  7882. const int nr = ne1;
  7883. // rows per thread
  7884. const int dr = (nr + nth - 1)/nth;
  7885. // row range for this thread
  7886. const int ir0 = dr*ith;
  7887. const int ir1 = MIN(ir0 + dr, nr);
  7888. ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + 0;
  7889. ggml_fp16_t * const wdata_src = wdata + nk;
  7890. for (int i1 = ir0; i1 < ir1; i1++) {
  7891. float * dst_data = (float *)((char *) dst->data + i1*nb1);
  7892. ggml_fp16_t * wdata_kernel = wdata + i1*ne02*ne00;
  7893. for (int i10 = 0; i10 < ne10; i10++) {
  7894. const int i1n = i10*ne11;
  7895. for (int i00 = 0; i00 < ne00; i00++) {
  7896. float v = 0;
  7897. ggml_vec_dot_f16(ne02, &v, 0,
  7898. (ggml_fp16_t *) wdata_src + i1n, 0,
  7899. (ggml_fp16_t *) wdata_kernel + i00*ne02, 0, 1);
  7900. dst_data[i10*s0 + i00] += v;
  7901. }
  7902. }
  7903. }
  7904. }
  7905. static void ggml_compute_forward_conv_transpose_1d_f32(
  7906. const struct ggml_compute_params * params,
  7907. struct ggml_tensor * dst) {
  7908. const struct ggml_tensor * src0 = dst->src[0];
  7909. const struct ggml_tensor * src1 = dst->src[1];
  7910. GGML_ASSERT(src0->type == GGML_TYPE_F32);
  7911. GGML_ASSERT(src1->type == GGML_TYPE_F32);
  7912. GGML_ASSERT( dst->type == GGML_TYPE_F32);
  7913. GGML_TENSOR_BINARY_OP_LOCALS
  7914. const int ith = params->ith;
  7915. const int nth = params->nth;
  7916. const int nk = ne00*ne01*ne02;
  7917. GGML_ASSERT(nb00 == sizeof(float));
  7918. GGML_ASSERT(nb10 == sizeof(float));
  7919. if (ith == 0) {
  7920. memset(params->wdata, 0, params->wsize);
  7921. // prepare kernel data (src0) from (K x Cout x Cin) to (Cin x K x Cout)
  7922. {
  7923. float * const wdata = (float *) params->wdata + 0;
  7924. for (int64_t i02 = 0; i02 < ne02; i02++) {
  7925. for (int64_t i01 = 0; i01 < ne01; i01++) {
  7926. const float * const src = (float *)((char *) src0->data + i02*nb02 + i01*nb01);
  7927. float * dst_data = wdata + i01*ne00*ne02;
  7928. for (int64_t i00 = 0; i00 < ne00; i00++) {
  7929. dst_data[i00*ne02 + i02] = src[i00];
  7930. }
  7931. }
  7932. }
  7933. }
  7934. // prepare source data (src1)
  7935. {
  7936. float * const wdata = (float *) params->wdata + nk;
  7937. float * dst_data = wdata;
  7938. for (int64_t i11 = 0; i11 < ne11; i11++) {
  7939. const float * const src = (float *)((char *) src1->data + i11*nb11);
  7940. for (int64_t i10 = 0; i10 < ne10; i10++) {
  7941. dst_data[i10*ne11 + i11] = src[i10];
  7942. }
  7943. }
  7944. }
  7945. // need to zero dst since we are accumulating into it
  7946. memset(dst->data, 0, ggml_nbytes(dst));
  7947. }
  7948. ggml_barrier(params->threadpool);
  7949. const int32_t s0 = ((const int32_t*)(dst->op_params))[0];
  7950. // total rows in dst
  7951. const int nr = ne1;
  7952. // rows per thread
  7953. const int dr = (nr + nth - 1)/nth;
  7954. // row range for this thread
  7955. const int ir0 = dr*ith;
  7956. const int ir1 = MIN(ir0 + dr, nr);
  7957. float * const wdata = (float *) params->wdata + 0;
  7958. float * const wdata_src = wdata + nk;
  7959. for (int i1 = ir0; i1 < ir1; i1++) {
  7960. float * dst_data = (float *)((char *) dst->data + i1*nb1);
  7961. float * wdata_kernel = wdata + i1*ne02*ne00;
  7962. for (int i10 = 0; i10 < ne10; i10++) {
  7963. const int i1n = i10*ne11;
  7964. for (int i00 = 0; i00 < ne00; i00++) {
  7965. float v = 0;
  7966. ggml_vec_dot_f32(ne02, &v, 0,
  7967. wdata_src + i1n, 0,
  7968. wdata_kernel + i00*ne02, 0, 1);
  7969. dst_data[i10*s0 + i00] += v;
  7970. }
  7971. }
  7972. }
  7973. }
  7974. static void ggml_compute_forward_conv_transpose_1d(
  7975. const struct ggml_compute_params * params,
  7976. struct ggml_tensor * dst) {
  7977. const struct ggml_tensor * src0 = dst->src[0];
  7978. switch (src0->type) {
  7979. case GGML_TYPE_F16:
  7980. {
  7981. ggml_compute_forward_conv_transpose_1d_f16_f32(params, dst);
  7982. } break;
  7983. case GGML_TYPE_F32:
  7984. {
  7985. ggml_compute_forward_conv_transpose_1d_f32(params, dst);
  7986. } break;
  7987. default:
  7988. {
  7989. GGML_ABORT("fatal error");
  7990. }
  7991. }
  7992. }
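// Illustrative shape note for the two kernels above: each output row i1 (one per
// output channel) accumulates, for every input position i10 and kernel tap i00,
// into dst_data[i10*s0 + i00]. So with input length ne10, kernel length ne00 and
// stride s0, the expected output length is (ne10 - 1)*s0 + ne00, the usual
// transposed-convolution size formula.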
  7993. // ggml_compute_forward_im2col_f32
  7994. // src0: kernel [OC, IC, KH, KW]
  7995. // src1: image [N, IC, IH, IW]
  7996. // dst: result [N, OH, OW, IC*KH*KW]
  7997. static void ggml_compute_forward_im2col_f32(
  7998. const struct ggml_compute_params * params,
  7999. struct ggml_tensor * dst) {
  8000. const struct ggml_tensor * src0 = dst->src[0];
  8001. const struct ggml_tensor * src1 = dst->src[1];
  8002. GGML_ASSERT(src1->type == GGML_TYPE_F32);
  8003. GGML_ASSERT( dst->type == GGML_TYPE_F32);
  8004. GGML_TENSOR_BINARY_OP_LOCALS;
  8005. const int32_t s0 = ((const int32_t *)(dst->op_params))[0];
  8006. const int32_t s1 = ((const int32_t *)(dst->op_params))[1];
  8007. const int32_t p0 = ((const int32_t *)(dst->op_params))[2];
  8008. const int32_t p1 = ((const int32_t *)(dst->op_params))[3];
  8009. const int32_t d0 = ((const int32_t *)(dst->op_params))[4];
  8010. const int32_t d1 = ((const int32_t *)(dst->op_params))[5];
  8011. const bool is_2D = ((const int32_t *)(dst->op_params))[6] == 1;
  8012. const int ith = params->ith;
  8013. const int nth = params->nth;
  8014. const int64_t N = is_2D ? ne13 : ne12;
  8015. const int64_t IC = is_2D ? ne12 : ne11;
  8016. const int64_t IH = is_2D ? ne11 : 1;
  8017. const int64_t IW = ne10;
  8018. const int64_t KH = is_2D ? ne01 : 1;
  8019. const int64_t KW = ne00;
  8020. const int64_t OH = is_2D ? ne2 : 1;
  8021. const int64_t OW = ne1;
  8022. int ofs0 = is_2D ? nb13 : nb12;
  8023. int ofs1 = is_2D ? nb12 : nb11;
  8024. GGML_ASSERT(nb10 == sizeof(float));
  8025. // im2col: [N, IC, IH, IW] => [N, OH, OW, IC*KH*KW]
  8026. {
  8027. float * const wdata = (float *) dst->data;
  8028. for (int64_t in = 0; in < N; in++) {
  8029. for (int64_t ioh = 0; ioh < OH; ioh++) { // 1
  8030. for (int64_t iow = 0; iow < OW; iow++) {
  8031. for (int64_t iic = ith; iic < IC; iic += nth) {
  8032. // micro kernel
  8033. float * dst_data = wdata + (in*OH*OW + ioh*OW + iow)*(IC*KH*KW); // [IC, KH, KW]
  8034. const float * const src_data = (float *)((char *) src1->data + in*ofs0 + iic*ofs1); // [IH, IW]
  8035. for (int64_t ikh = 0; ikh < KH; ikh++) { // 1
  8036. for (int64_t ikw = 0; ikw < KW; ikw++) {
  8037. const int64_t iiw = iow*s0 + ikw*d0 - p0;
  8038. const int64_t iih = ioh*s1 + ikh*d1 - p1;
  8039. if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) {
  8040. dst_data[iic*(KH*KW) + ikh*KW + ikw] = 0;
  8041. } else {
  8042. dst_data[iic*(KH*KW) + ikh*KW + ikw] = (src_data[iih*IW + iiw]);
  8043. }
  8044. }
  8045. }
  8046. }
  8047. }
  8048. }
  8049. }
  8050. }
  8051. }
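// Illustrative index map for the im2col above: output element
// dst[in][ioh][iow][iic*KH*KW + ikh*KW + ikw] holds the input pixel at
//
//   iih = ioh*s1 + ikh*d1 - p1
//   iiw = iow*s0 + ikw*d0 - p0
//
// (or 0 when that position falls outside the image). A convolution then becomes a
// plain matrix multiplication between this [OH*OW, IC*KH*KW] buffer and the
// flattened kernel.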
  8052. // ggml_compute_forward_im2col_f16
  8053. // src0: kernel [OC, IC, KH, KW]
  8054. // src1: image [N, IC, IH, IW]
  8055. // dst: result [N, OH, OW, IC*KH*KW]
  8056. static void ggml_compute_forward_im2col_f16(
  8057. const struct ggml_compute_params * params,
  8058. struct ggml_tensor * dst) {
  8059. const struct ggml_tensor * src0 = dst->src[0];
  8060. const struct ggml_tensor * src1 = dst->src[1];
  8061. GGML_ASSERT(src0->type == GGML_TYPE_F16);
  8062. GGML_ASSERT(src1->type == GGML_TYPE_F32);
  8063. GGML_ASSERT( dst->type == GGML_TYPE_F16);
  8064. GGML_TENSOR_BINARY_OP_LOCALS;
  8065. const int32_t s0 = ((const int32_t *)(dst->op_params))[0];
  8066. const int32_t s1 = ((const int32_t *)(dst->op_params))[1];
  8067. const int32_t p0 = ((const int32_t *)(dst->op_params))[2];
  8068. const int32_t p1 = ((const int32_t *)(dst->op_params))[3];
  8069. const int32_t d0 = ((const int32_t *)(dst->op_params))[4];
  8070. const int32_t d1 = ((const int32_t *)(dst->op_params))[5];
  8071. const bool is_2D = ((const int32_t *)(dst->op_params))[6] == 1;
  8072. const int ith = params->ith;
  8073. const int nth = params->nth;
  8074. const int64_t N = is_2D ? ne13 : ne12;
  8075. const int64_t IC = is_2D ? ne12 : ne11;
  8076. const int64_t IH = is_2D ? ne11 : 1;
  8077. const int64_t IW = ne10;
  8078. const int64_t KH = is_2D ? ne01 : 1;
  8079. const int64_t KW = ne00;
  8080. const int64_t OH = is_2D ? ne2 : 1;
  8081. const int64_t OW = ne1;
  8082. int ofs0 = is_2D ? nb13 : nb12;
  8083. int ofs1 = is_2D ? nb12 : nb11;
  8084. GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
  8085. GGML_ASSERT(nb10 == sizeof(float));
  8086. // im2col: [N, IC, IH, IW] => [N, OH, OW, IC*KH*KW]
  8087. {
  8088. ggml_fp16_t * const wdata = (ggml_fp16_t *) dst->data;
  8089. for (int64_t in = 0; in < N; in++) {
  8090. for (int64_t ioh = 0; ioh < OH; ioh++) { // 1
  8091. for (int64_t iow = 0; iow < OW; iow++) {
  8092. for (int64_t iic = ith; iic < IC; iic += nth) {
  8093. // micro kernel
  8094. ggml_fp16_t * dst_data = wdata + (in*OH*OW + ioh*OW + iow)*(IC*KH*KW); // [IC, KH, KW]
  8095. const float * const src_data = (float *)((char *) src1->data + in*ofs0 + iic*ofs1); // [IH, IW]
  8096. for (int64_t ikh = 0; ikh < KH; ikh++) { // 1
  8097. for (int64_t ikw = 0; ikw < KW; ikw++) {
  8098. const int64_t iiw = iow*s0 + ikw*d0 - p0;
  8099. const int64_t iih = ioh*s1 + ikh*d1 - p1;
  8100. if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) {
  8101. dst_data[iic*(KH*KW) + ikh*KW + ikw] = 0;
  8102. } else {
  8103. dst_data[iic*(KH*KW) + ikh*KW + ikw] = GGML_FP32_TO_FP16(src_data[iih*IW + iiw]);
  8104. }
  8105. }
  8106. }
  8107. }
  8108. }
  8109. }
  8110. }
  8111. }
  8112. }
  8113. static void ggml_compute_forward_im2col(
  8114. const struct ggml_compute_params * params,
  8115. struct ggml_tensor * dst) {
  8116. switch (dst->type) {
  8117. case GGML_TYPE_F16:
  8118. {
  8119. ggml_compute_forward_im2col_f16(params, dst);
  8120. } break;
  8121. case GGML_TYPE_F32:
  8122. {
  8123. ggml_compute_forward_im2col_f32(params, dst);
  8124. } break;
  8125. default:
  8126. {
  8127. GGML_ABORT("fatal error");
  8128. }
  8129. }
  8130. }
  8131. // ggml_compute_forward_im2col_back_f32
  8132. static void ggml_compute_forward_im2col_back_f32(
  8133. const struct ggml_compute_params * params,
  8134. struct ggml_tensor * dst) {
  8135. const struct ggml_tensor * src0 = dst->src[0];
  8136. const struct ggml_tensor * src1 = dst->src[1];
  8137. GGML_ASSERT(src1->type == GGML_TYPE_F32);
  8138. GGML_ASSERT( dst->type == GGML_TYPE_F32);
  8139. GGML_TENSOR_BINARY_OP_LOCALS;
  8140. const int32_t s0 = ((const int32_t *)(dst->op_params))[0];
  8141. const int32_t s1 = ((const int32_t *)(dst->op_params))[1];
  8142. const int32_t p0 = ((const int32_t *)(dst->op_params))[2];
  8143. const int32_t p1 = ((const int32_t *)(dst->op_params))[3];
  8144. const int32_t d0 = ((const int32_t *)(dst->op_params))[4];
  8145. const int32_t d1 = ((const int32_t *)(dst->op_params))[5];
  8146. const bool is_2D = ((const int32_t *)(dst->op_params))[6] == 1;
  8147. const int ith = params->ith;
  8148. const int nth = params->nth;
  8149. const int64_t N = is_2D ? ne3 : ne2;
  8150. const int64_t IC = is_2D ? ne2 : ne1;
  8151. const int64_t IH = is_2D ? ne1 : 1;
  8152. const int64_t IW = ne0;
  8153. const int64_t KH = is_2D ? ne01 : 1;
  8154. const int64_t KW = ne00;
  8155. const int64_t OH = is_2D ? ne12 : 1;
  8156. const int64_t OW = ne11;
  8157. int ofs0 = is_2D ? nb3 : nb2;
  8158. int ofs1 = is_2D ? nb2 : nb1;
  8159. GGML_ASSERT(nb0 == sizeof(float));
  8160. // im2col: [N, IC, IH, IW] => [N, OH, OW, IC*KH*KW]
  8161. {
  8162. float * const wdata = (float *) dst->data;
  8163. for (int64_t in = 0; in < N; in++) {
  8164. for (int64_t iic = ith; iic < IC; iic += nth) {
  8165. for (int64_t iih = 0; iih < IH; iih++) {
  8166. for (int64_t iiw = 0; iiw < IW; iiw++) {
  8167. // micro kernel
  8168. float grad = 0.0f;
  8169. for (int64_t ikh = 0; ikh < KH; ikh++) {
  8170. for (int64_t ikw = 0; ikw < KW; ikw++) {
  8171. // For s0 > 1 some values were skipped over in the forward pass.
  8172. // These values have tmpw % s0 != 0 and need to be skipped in the backwards pass as well.
  8173. const int64_t tmpw = (iiw + p0 - ikw*d0);
  8174. if (tmpw % s0 != 0) {
  8175. continue;
  8176. }
  8177. const int64_t iow = tmpw / s0;
  8178. // Equivalent logic as above except for s1.
  8179. int64_t ioh;
  8180. if (is_2D) {
  8181. const int64_t tmph = iih + p1 - ikh*d1;
  8182. if (tmph % s1 != 0) {
  8183. continue;
  8184. }
  8185. ioh = tmph / s1;
  8186. } else {
  8187. ioh = 0;
  8188. }
  8189. if (iow < 0 || iow >= OW || ioh < 0 || ioh >= OH) {
  8190. continue;
  8191. }
  8192. const float * const src_data = (const float *) src1->data
  8193. + (in*OH*OW + ioh*OW + iow)*(IC*KH*KW); // [IC, KH, KW]
  8194. grad += src_data[iic*(KH*KW) + ikh*KW + ikw];
  8195. }
  8196. }
  8197. float * dst_data = (float *)((char *) wdata + (in*ofs0 + iic*ofs1)); // [IH, IW]
  8198. dst_data[iih*IW + iiw] = grad;
  8199. }
  8200. }
  8201. }
  8202. }
  8203. }
  8204. }
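// Illustrative note on the backward pass above: it inverts the forward map
// iiw = iow*s0 + ikw*d0 - p0 by solving iow = (iiw + p0 - ikw*d0)/s0, and the
// `tmpw % s0 != 0` / `tmph % s1 != 0` checks drop combinations that no forward
// output ever read (which happens exactly for strides > 1). Each input pixel's
// gradient is then the sum over all (ikh, ikw) taps that touched it.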
  8205. // ggml_compute_forward_conv_transpose_2d
  8206. static void ggml_compute_forward_conv_transpose_2d(
  8207. const struct ggml_compute_params * params,
  8208. struct ggml_tensor * dst) {
  8209. const struct ggml_tensor * src0 = dst->src[0];
  8210. const struct ggml_tensor * src1 = dst->src[1];
  8211. GGML_ASSERT(src0->type == GGML_TYPE_F16);
  8212. GGML_ASSERT(src1->type == GGML_TYPE_F32);
  8213. GGML_ASSERT( dst->type == GGML_TYPE_F32);
  8214. GGML_TENSOR_BINARY_OP_LOCALS
  8215. const int ith = params->ith;
  8216. const int nth = params->nth;
  8217. const int nk = ne00*ne01*ne02*ne03;
  8218. GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
  8219. GGML_ASSERT(nb10 == sizeof(float));
  8220. if (ith == 0) {
  8221. memset(params->wdata, 0, params->wsize);
  8222. // permute kernel data (src0) from (Kw x Kh x Cout x Cin) to (Cin x Kw x Kh x Cout)
  8223. {
  8224. ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + 0;
  8225. for (int64_t i03 = 0; i03 < ne03; i03++) {
  8226. for (int64_t i02 = 0; i02 < ne02; i02++) {
  8227. const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i03*nb03 + i02*nb02);
  8228. ggml_fp16_t * dst_data = wdata + i02*ne01*ne00*ne03;
  8229. for (int64_t i01 = 0; i01 < ne01; i01++) {
  8230. for (int64_t i00 = 0; i00 < ne00; i00++) {
  8231. dst_data[i01*ne00*ne03 + i00*ne03 + i03] = src[i01 * ne00 + i00];
  8232. }
  8233. }
  8234. }
  8235. }
  8236. }
  8237. // permute source data (src1) from (Sw x Sh x Cin) to (Cin x Sw x Sh)
  8238. {
  8239. ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + nk;
  8240. for (int i12 = 0; i12 < ne12; i12++) {
  8241. for (int i11 = 0; i11 < ne11; i11++) {
  8242. const float * const src = (float *)((char *) src1->data + i12*nb12 + i11*nb11);
  8243. ggml_fp16_t * dst_data = wdata + i11*ne10*ne12;
  8244. for (int i10 = 0; i10 < ne10; i10++) {
  8245. dst_data[i10*ne12 + i12] = GGML_FP32_TO_FP16(src[i10]);
  8246. }
  8247. }
  8248. }
  8249. }
  8250. memset(dst->data, 0, ggml_nbytes(dst));
  8251. }
  8252. ggml_barrier(params->threadpool);
  8253. const int32_t stride = ggml_get_op_params_i32(dst, 0);
  8254. // total patches in dst
  8255. const int np = ne2;
  8256. // patches per thread
  8257. const int dp = (np + nth - 1)/nth;
  8258. // patch range for this thread
  8259. const int ip0 = dp*ith;
  8260. const int ip1 = MIN(ip0 + dp, np);
  8261. ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + 0;
  8262. ggml_fp16_t * const wdata_src = wdata + nk;
  8263. for (int i2 = ip0; i2 < ip1; i2++) { // Cout
  8264. float * dst_data = (float *)((char *) dst->data + i2*nb2);
  8265. ggml_fp16_t * wdata_kernel = wdata + i2*ne01*ne00*ne03;
  8266. for (int i11 = 0; i11 < ne11; i11++) {
  8267. for (int i10 = 0; i10 < ne10; i10++) {
  8268. const int i1n = i11*ne10*ne12 + i10*ne12;
  8269. for (int i01 = 0; i01 < ne01; i01++) {
  8270. for (int i00 = 0; i00 < ne00; i00++) {
  8271. float v = 0;
  8272. ggml_vec_dot_f16(ne03, &v, 0,
  8273. wdata_src + i1n, 0,
  8274. wdata_kernel + i01*ne00*ne03 + i00*ne03, 0, 1);
  8275. dst_data[(i11*stride + i01)*ne0 + i10*stride + i00] += v;
  8276. }
  8277. }
  8278. }
  8279. }
  8280. }
  8281. }
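// Illustrative indexing note for conv_transpose_2d: every input position (i10, i11)
// scatters a full Kw x Kh kernel patch into the output, anchored at
// (i10*stride, i11*stride); the inner dot product over ne03 (Cin) produces the
// contribution for one output channel i2, hence the accumulation into
// dst_data[(i11*stride + i01)*ne0 + i10*stride + i00].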
  8282. // ggml_compute_forward_pool_1d_sk_p0
  8283. static void ggml_compute_forward_pool_1d_sk_p0(
  8284. const struct ggml_compute_params * params,
  8285. const enum ggml_op_pool op,
  8286. const int k,
  8287. struct ggml_tensor * dst) {
  8288. const struct ggml_tensor * src = dst->src[0];
  8289. assert(src->type == GGML_TYPE_F32 || src->type == GGML_TYPE_F16);
  8290. if (params->ith != 0) {
  8291. return;
  8292. }
  8293. const char * cdata = (const char *)src->data;
  8294. const char * const data_end = cdata + ggml_nbytes(src);
  8295. float * drow = (float *)dst->data;
  8296. const int64_t rs = dst->ne[0];
  8297. while (cdata < data_end) {
  8298. const void * srow = (const void *)cdata;
  8299. int j = 0;
  8300. for (int64_t i = 0; i < rs; ++i) {
  8301. switch (op) {
  8302. case GGML_OP_POOL_AVG: drow[i] = 0; break;
  8303. case GGML_OP_POOL_MAX: drow[i] = -FLT_MAX; break;
  8304. case GGML_OP_POOL_COUNT: GGML_ABORT("fatal error");
  8305. }
  8306. for (int ki = 0; ki < k; ++ki) {
  8307. const float srow_j = (src->type == GGML_TYPE_F32) ? ((const float*)srow)[j] : GGML_FP16_TO_FP32(((const ggml_fp16_t*)srow)[j]);
  8308. switch (op) {
  8309. case GGML_OP_POOL_AVG: drow[i] += srow_j; break;
  8310. case GGML_OP_POOL_MAX: if (srow_j > drow[i]) drow[i] = srow_j; break;
  8311. case GGML_OP_POOL_COUNT: GGML_ABORT("fatal error");
  8312. }
  8313. ++j;
  8314. }
  8315. switch (op) {
  8316. case GGML_OP_POOL_AVG: drow[i] /= k; break;
  8317. case GGML_OP_POOL_MAX: break;
  8318. case GGML_OP_POOL_COUNT: GGML_ABORT("fatal error");
  8319. }
  8320. }
  8321. cdata += src->nb[1];
  8322. drow += rs;
  8323. }
  8324. }
  8325. // ggml_compute_forward_pool_1d
  8326. static void ggml_compute_forward_pool_1d(
  8327. const struct ggml_compute_params * params,
  8328. struct ggml_tensor * dst) {
  8329. const int32_t * opts = (const int32_t *)dst->op_params;
  8330. enum ggml_op_pool op = opts[0];
  8331. const int k0 = opts[1];
  8332. const int s0 = opts[2];
  8333. const int p0 = opts[3];
  8334. GGML_ASSERT(p0 == 0); // padding not supported
  8335. GGML_ASSERT(k0 == s0); // only s = k supported
  8336. ggml_compute_forward_pool_1d_sk_p0(params, op, k0, dst);
  8337. }
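// Small worked example for the k == s, p == 0 case handled above (illustrative):
// average pooling with k0 = s0 = 2 over the row [1, 3, 5, 7] yields [2, 6], while
// max pooling yields [3, 7]; F16 sources are converted element-wise with
// GGML_FP16_TO_FP32 before the accumulation.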
  8338. // ggml_compute_forward_pool_2d
  8339. static void ggml_compute_forward_pool_2d(
  8340. const struct ggml_compute_params * params,
  8341. struct ggml_tensor * dst) {
  8342. const struct ggml_tensor * src = dst->src[0];
  8343. assert(src->type == GGML_TYPE_F32 || src->type == GGML_TYPE_F16);
  8344. if (params->ith != 0) {
  8345. return;
  8346. }
  8347. const int32_t * opts = (const int32_t *)dst->op_params;
  8348. enum ggml_op_pool op = opts[0];
  8349. const int k0 = opts[1];
  8350. const int k1 = opts[2];
  8351. const int s0 = opts[3];
  8352. const int s1 = opts[4];
  8353. const int p0 = opts[5];
  8354. const int p1 = opts[6];
  8355. const char * cdata = (const char*)src->data;
  8356. const char * const data_end = cdata + ggml_nbytes(src);
  8357. const int64_t px = dst->ne[0];
  8358. const int64_t py = dst->ne[1];
  8359. const int64_t pa = px * py;
  8360. float * dplane = (float *)dst->data;
  8361. const int ka = k0 * k1;
  8362. const int offset0 = -p0;
  8363. const int offset1 = -p1;
  8364. while (cdata < data_end) {
  8365. for (int oy = 0; oy < py; ++oy) {
  8366. float * const drow = dplane + oy * px;
  8367. for (int ox = 0; ox < px; ++ox) {
  8368. float * const out = drow + ox;
  8369. switch (op) {
  8370. case GGML_OP_POOL_AVG: *out = 0; break;
  8371. case GGML_OP_POOL_MAX: *out = -FLT_MAX; break;
  8372. case GGML_OP_POOL_COUNT: GGML_ABORT("fatal error");
  8373. }
  8374. const int ix = offset0 + ox * s0;
  8375. const int iy = offset1 + oy * s1;
  8376. for (int ky = 0; ky < k1; ++ky) {
  8377. if (iy + ky < 0 || iy + ky >= src->ne[1]) continue;
  8378. const void * srow = (const void *)(cdata + src->nb[1] * (iy + ky));
  8379. for (int kx = 0; kx < k0; ++kx) {
  8380. int j = ix + kx;
  8381. if (j < 0 || j >= src->ne[0]) continue;
  8382. const float srow_j = (src->type == GGML_TYPE_F32) ? ((const float*)srow)[j] : GGML_FP16_TO_FP32(((const ggml_fp16_t*)srow)[j]);
  8383. switch (op) {
  8384. case GGML_OP_POOL_AVG: *out += srow_j; break;
  8385. case GGML_OP_POOL_MAX: if (srow_j > *out) *out = srow_j; break;
  8386. case GGML_OP_POOL_COUNT: GGML_ABORT("fatal error");
  8387. }
  8388. }
  8389. }
  8390. switch (op) {
  8391. case GGML_OP_POOL_AVG: *out /= ka; break;
  8392. case GGML_OP_POOL_MAX: break;
  8393. case GGML_OP_POOL_COUNT: GGML_ABORT("fatal error");
  8394. }
  8395. }
  8396. }
  8397. cdata += src->nb[2];
  8398. dplane += pa;
  8399. }
  8400. }
  8401. // ggml_compute_forward_pool_2d_back
  8402. static void ggml_compute_forward_pool_2d_back(
  8403. const struct ggml_compute_params * params,
  8404. struct ggml_tensor * dst) {
  8405. const struct ggml_tensor * src = dst->src[0];
  8406. const struct ggml_tensor * dstf = dst->src[1]; // forward tensor of dst
  8407. assert(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
  8408. if (params->ith != 0) {
  8409. return;
  8410. }
  8411. const int32_t * opts = (const int32_t *)dst->op_params;
  8412. enum ggml_op_pool op = opts[0];
  8413. const int k0 = opts[1];
  8414. const int k1 = opts[2];
  8415. const int s0 = opts[3];
  8416. const int s1 = opts[4];
  8417. const int p0 = opts[5];
  8418. const int p1 = opts[6];
  8419. char * cdata = (char *) dst->data;
  8420. const char * cdataf = (const char *) dstf->data;
  8421. const char * const data_end = cdata + ggml_nbytes(dst);
  8422. GGML_ASSERT(params->ith == 0);
  8423. memset(cdata, 0, ggml_nbytes(dst));
  8424. const int64_t px = src->ne[0];
  8425. const int64_t py = src->ne[1];
  8426. const int64_t pa = px * py;
  8427. const float * splane = (const float *) src->data;
  8428. const int ka = k0 * k1;
  8429. const int offset0 = -p0;
  8430. const int offset1 = -p1;
  8431. while (cdata < data_end) {
  8432. for (int oy = 0; oy < py; ++oy) {
  8433. const float * const srow = splane + oy * px;
  8434. for (int ox = 0; ox < px; ++ox) {
  8435. const float grad0 = srow[ox];
  8436. const int ix = offset0 + ox * s0;
  8437. const int iy = offset1 + oy * s1;
  8438. if (op == GGML_OP_POOL_MAX) {
  8439. float maxval = -FLT_MAX;
  8440. int kxmax = -1;
  8441. int kymax = -1;
  8442. for (int ky = 0; ky < k1; ++ky) {
  8443. if (iy + ky < 0 || iy + ky >= dst->ne[1]) {
  8444. continue;
  8445. }
  8446. const void * drowf = (const void *)(cdataf + dst->nb[1] * (iy + ky));
  8447. for (int kx = 0; kx < k0; ++kx) {
  8448. int j = ix + kx;
  8449. if (j < 0 || j >= dst->ne[0]) {
  8450. continue;
  8451. }
  8452. const float val = dst->type == GGML_TYPE_F32 ?
  8453. ((const float *) drowf)[j] : GGML_FP16_TO_FP32(((const ggml_fp16_t *) drowf)[j]);
  8454. if (val <= maxval) {
  8455. continue;
  8456. }
  8457. maxval = val;
  8458. kxmax = kx;
  8459. kymax = ky;
  8460. }
  8461. }
  8462. if (kxmax == -1 || kymax == -1) {
  8463. continue;
  8464. }
  8465. void * drow = (void *)(cdata + dst->nb[1] * (iy + kymax));
  8466. const int j = ix + kxmax;
  8467. if (dst->type == GGML_TYPE_F32) {
  8468. ((float *) drow)[j] += grad0;
  8469. } else {
  8470. ((ggml_fp16_t *) drow)[j] = GGML_FP32_TO_FP16(grad0 + GGML_FP16_TO_FP32(((const ggml_fp16_t *) drow)[j]));
  8471. }
  8472. } else if (op == GGML_OP_POOL_AVG) {
  8473. const float grad = grad0 / ka;
  8474. for (int ky = 0; ky < k1; ++ky) {
  8475. if (iy + ky < 0 || iy + ky >= dst->ne[1]) {
  8476. continue;
  8477. }
  8478. void * drow = (void *)(cdata + dst->nb[1] * (iy + ky));
  8479. for (int kx = 0; kx < k0; ++kx) {
  8480. int j = ix + kx;
  8481. if (j < 0 || j >= dst->ne[0]) {
  8482. continue;
  8483. }
  8484. if (dst->type == GGML_TYPE_F32) {
  8485. ((float *) drow)[j] += grad;
  8486. } else {
  8487. ((ggml_fp16_t *) drow)[j] += GGML_FP32_TO_FP16(grad);
  8488. }
  8489. }
  8490. }
  8491. } else {
  8492. GGML_ASSERT(false);
  8493. }
  8494. }
  8495. }
  8496. cdata += dst->nb[2];
  8497. cdataf += dst->nb[2];
  8498. splane += pa;
  8499. }
  8500. }
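// Illustrative note on the backward pass above: for max pooling the incoming
// gradient grad0 is routed entirely to the argmax cell, found by re-scanning the
// saved forward values in dstf over the pooling window, while for average pooling
// it is spread uniformly as grad0 / (k0*k1) over every input cell the window
// covered; window positions that fall outside the input contribute nothing.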
  8501. // ggml_compute_forward_upscale
  8502. static void ggml_compute_forward_upscale_f32(
  8503. const struct ggml_compute_params * params,
  8504. struct ggml_tensor * dst) {
  8505. const struct ggml_tensor * src0 = dst->src[0];
  8506. GGML_ASSERT(src0->type == GGML_TYPE_F32);
  8507. const int ith = params->ith;
  8508. const int nth = params->nth;
  8509. GGML_TENSOR_UNARY_OP_LOCALS
  8510. const float sf0 = (float)ne0/src0->ne[0];
  8511. const float sf1 = (float)ne1/src0->ne[1];
  8512. const float sf2 = (float)ne2/src0->ne[2];
  8513. const float sf3 = (float)ne3/src0->ne[3];
  8514. // TODO: optimize
  8515. for (int64_t i3 = 0; i3 < ne3; i3++) {
  8516. const int64_t i03 = i3 / sf3;
  8517. for (int64_t i2 = ith; i2 < ne2; i2 += nth) {
  8518. const int64_t i02 = i2 / sf2;
  8519. for (int64_t i1 = 0; i1 < ne1; i1++) {
  8520. const int64_t i01 = i1 / sf1;
  8521. for (int64_t i0 = 0; i0 < ne0; i0++) {
  8522. const int64_t i00 = i0 / sf0;
  8523. const float * x = (float *)((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
  8524. float * y = (float *)((char *) dst->data + i0*nb0 + i1*nb1 + i2*nb2 + i3*nb3);
  8525. *y = *x;
  8526. }
  8527. }
  8528. }
  8529. }
  8530. }
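// Illustrative note: this is nearest-neighbour upscaling. Each destination index is
// mapped back by dividing by the scale factor and truncating, e.g. with sf0 = 2 the
// source column is i00 = i0 / 2, so every source element is replicated twice along
// dim 0 (and likewise for the other dims).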
  8531. static void ggml_compute_forward_upscale(
  8532. const struct ggml_compute_params * params,
  8533. struct ggml_tensor * dst) {
  8534. const struct ggml_tensor * src0 = dst->src[0];
  8535. switch (src0->type) {
  8536. case GGML_TYPE_F32:
  8537. {
  8538. ggml_compute_forward_upscale_f32(params, dst);
  8539. } break;
  8540. default:
  8541. {
  8542. GGML_ABORT("fatal error");
  8543. }
  8544. }
  8545. }
  8546. // ggml_compute_forward_pad
  8547. static void ggml_compute_forward_pad_f32(
  8548. const struct ggml_compute_params * params,
  8549. struct ggml_tensor * dst) {
  8550. const struct ggml_tensor * src0 = dst->src[0];
  8551. GGML_ASSERT(src0->nb[0] == sizeof(float));
  8552. GGML_ASSERT( dst->nb[0] == sizeof(float));
  8553. const int ith = params->ith;
  8554. const int nth = params->nth;
  8555. GGML_TENSOR_UNARY_OP_LOCALS
  8556. float * dst_ptr = (float *) dst->data;
  8557. // TODO: optimize
  8558. for (int64_t i2 = 0; i2 < ne2; ++i2) {
  8559. for (int64_t i1 = ith; i1 < ne1; i1 += nth) {
  8560. for (int64_t i0 = 0; i0 < ne0; ++i0) {
  8561. for (int64_t i3 = 0; i3 < ne3; ++i3) {
  8562. const int64_t dst_idx = i3*(ne0*ne1*ne2) + i2*(ne0*ne1) + i1*ne0 + i0;
  8563. const float * src_ptr = (const float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
  8564. if (i0 < ne00 && i1 < ne01 && i2 < ne02 && i3 < ne03) {
  8565. dst_ptr[dst_idx] = *src_ptr;
  8566. } else {
  8567. dst_ptr[dst_idx] = 0;
  8568. }
  8569. }
  8570. }
  8571. }
  8572. }
  8573. }
  8574. static void ggml_compute_forward_pad(
  8575. const struct ggml_compute_params * params,
  8576. struct ggml_tensor * dst) {
  8577. const struct ggml_tensor * src0 = dst->src[0];
  8578. switch (src0->type) {
  8579. case GGML_TYPE_F32:
  8580. {
  8581. ggml_compute_forward_pad_f32(params, dst);
  8582. } break;
  8583. default:
  8584. {
  8585. GGML_ABORT("fatal error");
  8586. }
  8587. }
  8588. }
  8589. // ggml_compute_forward_arange
  8590. static void ggml_compute_forward_arange_f32(
  8591. const struct ggml_compute_params * params,
  8592. struct ggml_tensor * dst) {
  8593. GGML_ASSERT(dst->nb[0] == sizeof(float));
  8594. const int ith = params->ith;
  8595. const int nth = params->nth;
  8596. const float start = ggml_get_op_params_f32(dst, 0);
  8597. const float stop = ggml_get_op_params_f32(dst, 1);
  8598. const float step = ggml_get_op_params_f32(dst, 2);
  8599. const int64_t steps = (int64_t) ceilf((stop - start) / step);
  8600. GGML_ASSERT(ggml_nelements(dst) == steps);
  8601. for (int64_t i = ith; i < steps; i+= nth) {
  8602. float value = start + step * i;
  8603. ((float *)dst->data)[i] = value;
  8604. }
  8605. }
  8606. static void ggml_compute_forward_arange(
  8607. const struct ggml_compute_params * params,
  8608. struct ggml_tensor * dst) {
  8609. switch (dst->type) {
  8610. case GGML_TYPE_F32:
  8611. {
  8612. ggml_compute_forward_arange_f32(params, dst);
  8613. } break;
  8614. default:
  8615. {
  8616. GGML_ABORT("fatal error");
  8617. }
  8618. }
  8619. }
  8620. static void ggml_compute_forward_timestep_embedding_f32(
  8621. const struct ggml_compute_params * params,
  8622. struct ggml_tensor * dst) {
  8623. const struct ggml_tensor * src0 = dst->src[0];
  8624. GGML_ASSERT(src0->nb[0] == sizeof(float));
  8625. const int ith = params->ith;
  8626. const int nth = params->nth;
  8627. GGML_TENSOR_UNARY_OP_LOCALS
  8628. const int dim = ggml_get_op_params_i32(dst, 0);
  8629. const int max_period = ggml_get_op_params_i32(dst, 1);
  8630. int half = dim / 2;
  8631. for (int64_t i = 0; i < ne00; i++) {
  8632. float * embed_data = (float *)((char *) dst->data + i*nb1);
  8633. for (int64_t j = ith; j < half; j += nth) {
  8634. float timestep = ((float *)src0->data)[i];
  8635. float freq = (float)expf(-logf(max_period) * j / half);
  8636. float arg = timestep * freq;
  8637. embed_data[j] = cosf(arg);
  8638. embed_data[j + half] = sinf(arg);
  8639. }
  8640. if (dim % 2 != 0 && ith == 0) {
  8641. embed_data[dim] = 0.f;
  8642. }
  8643. }
  8644. }
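// Illustrative formula for the embedding above (matching the loop, with half = dim/2):
//
//   freq_j          = exp(-ln(max_period) * j / half)    for j = 0 .. half-1
//   embed[j]        = cos(t * freq_j)
//   embed[j + half] = sin(t * freq_j)
//
// i.e. the usual sinusoidal timestep embedding used by diffusion models; an odd
// `dim` gets one extra zero-padded slot.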
  8645. static void ggml_compute_forward_timestep_embedding(
  8646. const struct ggml_compute_params * params,
  8647. struct ggml_tensor * dst) {
  8648. const struct ggml_tensor * src0 = dst->src[0];
  8649. switch (src0->type) {
  8650. case GGML_TYPE_F32:
  8651. {
  8652. ggml_compute_forward_timestep_embedding_f32(params, dst);
  8653. } break;
  8654. default:
  8655. {
  8656. GGML_ABORT("fatal error");
  8657. }
  8658. }
  8659. }
  8660. // ggml_compute_forward_argsort
  8661. static void ggml_compute_forward_argsort_f32(
  8662. const struct ggml_compute_params * params,
  8663. struct ggml_tensor * dst) {
  8664. const struct ggml_tensor * src0 = dst->src[0];
  8665. GGML_TENSOR_UNARY_OP_LOCALS
  8666. GGML_ASSERT(nb0 == sizeof(float));
  8667. const int ith = params->ith;
  8668. const int nth = params->nth;
  8669. const int64_t nr = ggml_nrows(src0);
  8670. enum ggml_sort_order order = (enum ggml_sort_order) ggml_get_op_params_i32(dst, 0);
  8671. for (int64_t i = ith; i < nr; i += nth) {
  8672. int32_t * dst_data = (int32_t *)((char *) dst->data + i*nb1);
  8673. const float * src_data = (float *)((char *) src0->data + i*nb01);
  8674. for (int64_t j = 0; j < ne0; j++) {
  8675. dst_data[j] = j;
  8676. }
  8677. // C doesn't have a functional sort, so we do a bubble sort instead
  8678. for (int64_t j = 0; j < ne0; j++) {
  8679. for (int64_t k = j + 1; k < ne0; k++) {
  8680. if ((order == GGML_SORT_ORDER_ASC && src_data[dst_data[j]] > src_data[dst_data[k]]) ||
  8681. (order == GGML_SORT_ORDER_DESC && src_data[dst_data[j]] < src_data[dst_data[k]])) {
  8682. int32_t tmp = dst_data[j];
  8683. dst_data[j] = dst_data[k];
  8684. dst_data[k] = tmp;
  8685. }
  8686. }
  8687. }
  8688. }
  8689. }
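// Illustrative example for the kernel above: sorting the row [0.3, 0.1, 0.2] with
// GGML_SORT_ORDER_ASC produces the index row [1, 2, 0] (indices into the source
// row, not sorted values). The nested loop is an O(ne0^2) exchange sort, which is
// acceptable because rows are typically short and each row is sorted independently
// by its thread.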
  8690. static void ggml_compute_forward_argsort(
  8691. const struct ggml_compute_params * params,
  8692. struct ggml_tensor * dst) {
  8693. const struct ggml_tensor * src0 = dst->src[0];
  8694. switch (src0->type) {
  8695. case GGML_TYPE_F32:
  8696. {
  8697. ggml_compute_forward_argsort_f32(params, dst);
  8698. } break;
  8699. default:
  8700. {
  8701. GGML_ABORT("fatal error");
  8702. }
  8703. }
  8704. }
  8705. // ggml_compute_forward_flash_attn_ext
  8706. static void ggml_compute_forward_flash_attn_ext_f16(
  8707. const struct ggml_compute_params * params,
  8708. const struct ggml_tensor * q,
  8709. const struct ggml_tensor * k,
  8710. const struct ggml_tensor * v,
  8711. const struct ggml_tensor * mask,
  8712. struct ggml_tensor * dst) {
  8713. GGML_TENSOR_LOCALS(int64_t, neq, q, ne)
  8714. GGML_TENSOR_LOCALS(size_t, nbq, q, nb)
  8715. GGML_TENSOR_LOCALS(int64_t, nek, k, ne)
  8716. GGML_TENSOR_LOCALS(size_t, nbk, k, nb)
  8717. GGML_TENSOR_LOCALS(int64_t, nev, v, ne)
  8718. GGML_TENSOR_LOCALS(size_t, nbv, v, nb)
  8719. GGML_TENSOR_LOCALS(int64_t, ne, dst, ne)
  8720. GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
  8721. const int ith = params->ith;
  8722. const int nth = params->nth;
  8723. const int64_t D = neq0;
  8724. const int64_t N = neq1;
  8725. GGML_ASSERT(ne0 == D);
  8726. GGML_ASSERT(ne2 == N);
  8727. // input tensor rows must be contiguous
  8728. GGML_ASSERT(nbq0 == ggml_type_size(q->type));
  8729. GGML_ASSERT(nbk0 == ggml_type_size(k->type));
  8730. GGML_ASSERT(nbv0 == ggml_type_size(v->type));
  8731. GGML_ASSERT(neq0 == D);
  8732. GGML_ASSERT(nek0 == D);
  8733. GGML_ASSERT(nev0 == D);
  8734. GGML_ASSERT(neq1 == N);
  8735. GGML_ASSERT(nev0 == D);
  8736. // dst cannot be transposed or permuted
  8737. GGML_ASSERT(nb0 == sizeof(float));
  8738. GGML_ASSERT(nb0 <= nb1);
  8739. GGML_ASSERT(nb1 <= nb2);
  8740. GGML_ASSERT(nb2 <= nb3);
  8741. // broadcast factors
  8742. const int64_t rk2 = neq2/nek2;
  8743. const int64_t rk3 = neq3/nek3;
  8744. const int64_t rv2 = neq2/nev2;
  8745. const int64_t rv3 = neq3/nev3;
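// e.g. with grouped-query attention (illustrative numbers): neq2 = 32 query heads
// and nek2 = nev2 = 8 KV heads give rk2 = rv2 = 4, i.e. 4 query heads share each KV head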
8746. // parallelize by q rows, using the type-specific kq vec_dot for the KQ products
  8747. // total rows in q
  8748. const int nr = neq1*neq2*neq3;
  8749. // rows per thread
  8750. const int dr = (nr + nth - 1)/nth;
  8751. // row range for this thread
  8752. const int ir0 = dr*ith;
  8753. const int ir1 = MIN(ir0 + dr, nr);
  8754. float scale = 1.0f;
  8755. float max_bias = 0.0f;
  8756. float logit_softcap = 0.0f;
  8757. memcpy(&scale, (float *) dst->op_params + 0, sizeof(float));
  8758. memcpy(&max_bias, (float *) dst->op_params + 1, sizeof(float));
  8759. memcpy(&logit_softcap, (float *) dst->op_params + 2, sizeof(float));
  8760. if (logit_softcap != 0) {
  8761. scale /= logit_softcap;
  8762. }
  8763. const uint32_t n_head = neq2;
  8764. const uint32_t n_head_log2 = 1u << (uint32_t) floor(log2(n_head));
  8765. const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
  8766. const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
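// ALiBi-style head slopes: for head index h (0-based),
//   slope = m0^(h + 1)                      if h <  n_head_log2
//   slope = m1^(2*(h - n_head_log2) + 1)    otherwise
// e.g. (illustrative) max_bias = 8, n_head = 8 -> m0 = 0.5 and the first slopes are 1/2, 1/4, 1/8, ...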
  8767. enum ggml_type const k_vec_dot_type = type_traits_cpu[k->type].vec_dot_type;
  8768. ggml_from_float_t const q_to_vec_dot = ggml_get_type_traits(k_vec_dot_type)->from_float;
  8769. ggml_vec_dot_t const kq_vec_dot = type_traits_cpu[k->type].vec_dot;
  8770. ggml_to_float_t const v_to_float = ggml_get_type_traits(v->type)->to_float;
  8771. GGML_ASSERT(q_to_vec_dot && "fattn: unsupported K-type");
  8772. GGML_ASSERT(v_to_float && "fattn: unsupported V-type");
  8773. // loop over n_batch and n_head
  8774. for (int ir = ir0; ir < ir1; ++ir) {
  8775. // q indices
  8776. const int iq3 = ir/(neq2*neq1);
  8777. const int iq2 = (ir - iq3*neq2*neq1)/neq1;
  8778. const int iq1 = (ir - iq3*neq2*neq1 - iq2*neq1);
  8779. const uint32_t h = iq2; // head index
  8780. const float slope = (max_bias > 0.0f) ? h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2*(h - n_head_log2) + 1) : 1.0f;
  8781. float S = 0.0f; // sum
  8782. float M = -INFINITY; // maximum KQ value
  8783. float * VKQ32 = (float *) params->wdata + ith*(3*D + CACHE_LINE_SIZE_F32); // FP32 VKQ accumulator
  8784. float * V32 = (VKQ32 + 1*D); // (temporary) FP32 V buffer
  8785. ggml_fp16_t * VKQ16 = (ggml_fp16_t *) (VKQ32 + 1*D); // (temporary) FP16 VKQ accumulator
  8786. ggml_fp16_t * Q_q = (ggml_fp16_t *) (VKQ32 + 2*D); // (temporary) buffer for Q converted to quantized/FP16
  8787. if (v->type == GGML_TYPE_F16) {
  8788. memset(VKQ16, 0, D*sizeof(ggml_fp16_t));
  8789. } else {
  8790. memset(VKQ32, 0, D*sizeof(float));
  8791. }
  8792. const ggml_fp16_t * mp = mask ? (ggml_fp16_t *)((char *) mask->data + iq1*mask->nb[1]) : NULL;
  8793. // k indices
  8794. const int ik3 = iq3 / rk3;
  8795. const int ik2 = iq2 / rk2;
  8796. // v indices
  8797. const int iv3 = iq3 / rv3;
  8798. const int iv2 = iq2 / rv2;
  8799. const float * pq = (const float *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3));
  8800. q_to_vec_dot(pq, Q_q, D);
  8801. // online softmax / attention
  8802. // loop over n_kv and n_head_kv
  8803. // ref: https://arxiv.org/pdf/2112.05682.pdf
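// sketch of the running quantities maintained by the loop below:
//   M_new = max(M_old, s)
//   ms    = expf(M_old - M_new)    // rescale factor for the old accumulator
//   vs    = expf(s - M_new)        // softmax weight of the current KV position
//   S     = S*ms + vs              // running sum of softmax weights
//   VKQ   = VKQ*ms + vs*V[ic]      // running weighted sum of V rows
// so that VKQ/S is the attention output row at the end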
  8804. for (int64_t ic = 0; ic < nek1; ++ic) {
  8805. const float mv = mp ? slope*GGML_FP16_TO_FP32(mp[ic]) : 0.0f;
  8806. if (mv == -INFINITY) {
  8807. continue;
  8808. }
  8809. float s; // KQ value
  8810. const char * k_data = (const char *) k->data + ( ic*nbk1 + ik2*nbk2 + ik3*nbk3);
  8811. kq_vec_dot(D, &s, 0, k_data, 0, Q_q, 0, 1);
  8812. s = s*scale; // scale KQ value
  8813. if (logit_softcap != 0.0f) {
  8814. s = logit_softcap*tanhf(s);
  8815. }
  8816. s += mv; // apply mask
  8817. const float Mold = M;
  8818. float ms = 1.0f; // upon new higher max val, scale VKQ and KQ sum with this value
  8819. float vs = 1.0f; // post-softmax KQ value, expf(s - M)
  8820. const char * v_data = ((const char *) v->data + (ic*nbv1 + iv2*nbv2 + iv3*nbv3));
  8821. if (v->type == GGML_TYPE_F16) {
  8822. if (s > M) {
  8823. // s is new maximum, ms < 1.0f, vs == expf(s - s) == 1.0f
  8824. M = s;
  8825. ms = expf(Mold - M);
  8826. // V = V*expf(Mold - M)
  8827. ggml_vec_scale_f16(D, VKQ16, ms);
  8828. } else {
  8829. // no new maximum, ms == 1.0f, vs != 1.0f
  8830. vs = expf(s - M);
  8831. }
  8832. // V += v*expf(s - M)
  8833. ggml_vec_mad_f16(D, VKQ16, (const ggml_fp16_t *) v_data, vs);
  8834. } else {
  8835. if (s > M) {
  8836. // s is new maximum, ms < 1.0f, vs == expf(s - s) == 1.0f
  8837. M = s;
  8838. ms = expf(Mold - M);
  8839. // V = V*expf(Mold - M)
  8840. ggml_vec_scale_f32(D, VKQ32, ms);
  8841. } else {
  8842. // no new maximum, ms == 1.0f, vs != 1.0f
  8843. vs = expf(s - M);
  8844. }
  8845. v_to_float(v_data, V32, D);
  8846. // V += v*expf(s - M)
  8847. ggml_vec_mad_f32(D, VKQ32, V32, vs);
  8848. }
  8849. S = S*ms + vs; // scale and increment sum with partial sum
  8850. }
  8851. if (v->type == GGML_TYPE_F16) {
  8852. for (int64_t d = 0; d < D; ++d) {
  8853. VKQ32[d] = GGML_FP16_TO_FP32(VKQ16[d]);
  8854. }
  8855. }
  8856. // V /= S
  8857. const float S_inv = 1.0f/S;
  8858. ggml_vec_scale_f32(D, VKQ32, S_inv);
  8859. // dst indices
  8860. const int i1 = iq1;
  8861. const int i2 = iq2;
  8862. const int i3 = iq3;
  8863. // original
  8864. //memcpy((char *) dst->data + (i1*nb1 + i2*nb2 + i3*nb3), V, nev0*sizeof(float));
  8865. // permute(0, 2, 1, 3)
  8866. memcpy((char *) dst->data + (i3*ne2*ne1 + i2 + i1*ne1)*nb1, VKQ32, nb1);
  8867. }
  8868. }
  8869. static void ggml_compute_forward_flash_attn_ext(
  8870. const struct ggml_compute_params * params,
  8871. const struct ggml_tensor * q,
  8872. const struct ggml_tensor * k,
  8873. const struct ggml_tensor * v,
  8874. const struct ggml_tensor * mask,
  8875. struct ggml_tensor * dst) {
  8876. switch (dst->op_params[3]) {
  8877. case GGML_PREC_DEFAULT:
  8878. case GGML_PREC_F32:
  8879. {
  8880. // uses F32 accumulators
  8881. ggml_compute_forward_flash_attn_ext_f16(params, q, k, v, mask, dst);
  8882. } break;
  8883. default:
  8884. {
  8885. GGML_ABORT("fatal error");
  8886. }
  8887. }
  8888. }
  8889. // ggml_compute_forward_flash_attn_back
  8890. static void ggml_compute_forward_flash_attn_back_f32(
  8891. const struct ggml_compute_params * params,
  8892. const bool masked,
  8893. struct ggml_tensor * dst) {
  8894. const struct ggml_tensor * q = dst->src[0];
  8895. const struct ggml_tensor * k = dst->src[1];
  8896. const struct ggml_tensor * v = dst->src[2];
  8897. const struct ggml_tensor * d = dst->src[3];
  8898. GGML_TENSOR_LOCALS(int64_t, neq, q, ne)
  8899. GGML_TENSOR_LOCALS(size_t, nbq, q, nb)
  8900. GGML_TENSOR_LOCALS(int64_t, nek, k, ne)
  8901. GGML_TENSOR_LOCALS(size_t, nbk, k, nb)
  8902. GGML_TENSOR_LOCALS(int64_t, nev, v, ne)
  8903. GGML_TENSOR_LOCALS(size_t, nbv, v, nb)
  8904. GGML_TENSOR_LOCALS(int64_t, ned, d, ne)
  8905. GGML_TENSOR_LOCALS(size_t, nbd, d, nb)
  8906. GGML_TENSOR_LOCALS(int64_t, ne, dst, ne)
  8907. GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
  8908. const int ith = params->ith;
  8909. const int nth = params->nth;
  8910. const int64_t D = neq0;
  8911. const int64_t N = neq1;
  8912. const int64_t P = nek1 - N;
  8913. const int64_t M = P + N;
  8914. const int Mup = ggml_up(M, GGML_SOFT_MAX_UNROLL);
  8915. const int mxDM = MAX(D, Mup);
  8916. // GGML_ASSERT(ne0 == D);
  8917. // GGML_ASSERT(ne1 == N);
  8918. GGML_ASSERT(P >= 0);
  8919. GGML_ASSERT(nbq0 == sizeof(float));
  8920. GGML_ASSERT(nbk0 == sizeof(float));
  8921. GGML_ASSERT(nbv0 == sizeof(float));
  8922. GGML_ASSERT(neq0 == D);
  8923. GGML_ASSERT(nek0 == D);
  8924. GGML_ASSERT(nev1 == D);
  8925. GGML_ASSERT(ned0 == D);
  8926. GGML_ASSERT(neq1 == N);
  8927. GGML_ASSERT(nek1 == N + P);
  8928. GGML_ASSERT(nev1 == D);
  8929. GGML_ASSERT(ned1 == N);
  8930. // dst cannot be transposed or permuted
  8931. GGML_ASSERT(nb0 == sizeof(float));
  8932. GGML_ASSERT(nb0 <= nb1);
  8933. GGML_ASSERT(nb1 <= nb2);
  8934. GGML_ASSERT(nb2 <= nb3);
  8935. if (ith == 0) {
  8936. memset(dst->data, 0, nb0*ne0*ne1*ne2*ne3);
  8937. }
  8938. ggml_barrier(params->threadpool);
  8939. const int64_t elem_q = ggml_nelements(q);
  8940. const int64_t elem_k = ggml_nelements(k);
  8941. enum ggml_type result_type = dst->type;
  8942. GGML_ASSERT(ggml_blck_size(result_type) == 1);
  8943. const size_t tsize = ggml_type_size(result_type);
  8944. const size_t offs_q = 0;
  8945. const size_t offs_k = offs_q + GGML_PAD(elem_q * tsize, GGML_MEM_ALIGN);
  8946. const size_t offs_v = offs_k + GGML_PAD(elem_k * tsize, GGML_MEM_ALIGN);
  8947. void * grad_q = (char *) dst->data;
  8948. void * grad_k = (char *) dst->data + offs_k;
  8949. void * grad_v = (char *) dst->data + offs_v;
  8950. const size_t nbgq1 = nb0*neq0;
  8951. const size_t nbgq2 = nb0*neq0*neq1;
  8952. const size_t nbgq3 = nb0*neq0*neq1*neq2;
  8953. const size_t nbgk1 = nb0*nek0;
  8954. const size_t nbgk2 = nb0*nek0*nek1;
  8955. const size_t nbgk3 = nb0*nek0*nek1*neq2;
  8956. const size_t nbgv1 = nb0*nev0;
  8957. const size_t nbgv2 = nb0*nev0*nev1;
  8958. const size_t nbgv3 = nb0*nev0*nev1*neq2;
  8959. // parallelize by k rows using ggml_vec_dot_f32
  8960. // total rows in k
  8961. const int nr = nek2*nek3;
  8962. // rows per thread
  8963. const int dr = (nr + nth - 1)/nth;
  8964. // row range for this thread
  8965. const int ir0 = dr*ith;
  8966. const int ir1 = MIN(ir0 + dr, nr);
  8967. const float scale = 1.0f/sqrtf(D);
  8968. //printf("P=%d N=%d D=%d ir0=%d ir1=%d scale = %f\n", P, N, D, ir0, ir1, scale);
  8969. // how often k2 (and v2) is repeated in q2
  8970. int nrep = neq2/nek2;
  8971. for (int ir = ir0; ir < ir1; ++ir) {
8972. // k indices (the q/d/v indices below are derived from them)
  8973. const int ik3 = ir/(nek2);
  8974. const int ik2 = ir - ik3*nek2;
  8975. const int iq3 = ik3;
  8976. const int id3 = ik3;
  8977. const int iv3 = ik3;
  8978. const int iv2 = ik2;
  8979. for (int irep = 0; irep < nrep; ++irep) {
  8980. const int iq2 = ik2 + irep*nek2;
  8981. const int id2 = iq2;
  8982. // (ik2 + irep*nek2) % nek2 == ik2
  8983. for (int iq1 = 0; iq1 < neq1; ++iq1) {
  8984. const int id1 = iq1;
8985. // not sure about CACHE_LINE_SIZE_F32 here..
8986. // - maybe it should not be multiplied by 2 and should be left out of the 1*(..) offset used for SM?
  8987. float * S = (float *) params->wdata + ith*2*(mxDM + CACHE_LINE_SIZE_F32) + 0*(mxDM+CACHE_LINE_SIZE_F32);
  8988. float * SM = (float *) params->wdata + ith*2*(mxDM + CACHE_LINE_SIZE_F32) + 1*(mxDM+CACHE_LINE_SIZE_F32);
  8989. for (int i = M; i < Mup; ++i) {
  8990. S[i] = -INFINITY;
  8991. }
  8992. const int64_t masked_begin = masked ? (P + iq1 + 1) : M;
  8993. for (int64_t ic = 0; ic < masked_begin; ++ic) {
  8994. // k indices
  8995. const int ik1 = ic;
  8996. // S indices
  8997. const int i1 = ik1;
  8998. ggml_vec_dot_f32(neq0,
  8999. S + i1, 0,
  9000. (float *) ((char *) k->data + (ik1*nbk1 + ik2*nbk2 + ik3*nbk3)), 0,
  9001. (float *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3)), 0, 1);
  9002. }
  9003. // scale
  9004. ggml_vec_scale_f32(masked_begin, S, scale);
  9005. for (int64_t i = masked_begin; i < M; i++) {
  9006. S[i] = -INFINITY;
  9007. }
  9008. // softmax
  9009. // exclude known -INF S[..] values from max and loop
  9010. // dont forget to set their SM values to zero
  9011. {
  9012. float max = -INFINITY;
  9013. ggml_vec_max_f32(masked_begin, &max, S);
  9014. ggml_float sum = 0.0;
  9015. {
  9016. #ifdef GGML_SOFT_MAX_ACCELERATE
  9017. max = -max;
9018. vDSP_vsadd(S, 1, &max, SM, 1, Mup);
  9019. vvexpf(SM, SM, &Mup);
  9020. ggml_vec_sum_f32(Mup, &sum, SM);
  9021. #else
  9022. sum = ggml_vec_soft_max_f32(Mup, SM, S, max);
  9023. #endif
  9024. }
  9025. assert(sum > 0.0);
  9026. sum = 1.0/sum;
  9027. ggml_vec_scale_f32(masked_begin, SM, sum);
  9028. }
  9029. // step-by-step explanation
  9030. {
9031. // columns below: forward-process | shape | grads from backward process
  9032. // parallel_for ik2,ik3:
  9033. // for irep:
  9034. // iq2 = ik2 + irep*nek2
  9035. // k[:D,:M,:,:] [D,M,:,:] grad[k][:D,:M,ik2,ik3] += grad[kcur]
  9036. // q[:D,:N,:,:] [D,N,:,:] grad[q][:D,iq1,iq2,iq3] += grad[qcur]
  9037. // v[:M,:D,:,:] [M,D,:,:] grad[v][:M,:D,iv2,iv3] += grad[vcur]
  9038. // for iq1:
  9039. // kcur = k[:D,:M,ik2,ik3] [D,M,1,1] grad[kcur] = grad[S1].T @ qcur
  9040. // qcur = q[:D,iq1,iq2,iq3] [D,1,1,1] grad[qcur] = grad[S1] @ kcur
  9041. // vcur = v[:M,:D,iv2,iv3] [M,D,1,1] grad[vcur] = grad[S5].T @ S4
  9042. // S0 = -Inf [D,1,1,1]
  9043. // ~S1[i] = dot(kcur[:D,i], qcur)
  9044. // S1 = qcur @ kcur.T [M,1,1,1] grad[S1] = grad[S2] * scale
  9045. // S2 = S1 * scale [M,1,1,1] grad[S2] = diag_mask_zero(grad[S3], P)
  9046. // S3 = diag_mask_inf(S2, P) [M,1,1,1] grad[S3] = S4 * (grad[S4] - dot(S4, grad[S4]))
  9047. // S4 = softmax(S3) [M,1,1,1] grad[S4] = grad[S5] @ vcur
  9048. // ~S5[i] = dot(vcur[:,i], S4)
  9049. // S5 = S4 @ vcur.T [D,1,1,1] grad[S5] = d[:D,id1,id2,id3]
  9050. // ~dst[i,iq1,iq2,iq3] = S5[i] ^
  9051. // dst[:D,iq1,iq2,iq3] = S5 | grad[dst[:D,iq1,iq2,iq3]] = d[:D,id1,id2,id3]
  9052. // dst backward-/ grad[dst] = d
  9053. //
  9054. // output gradients with their dependencies:
  9055. //
  9056. // grad[kcur] = grad[S1].T @ qcur
  9057. // grad[S1] = diag_mask_zero(grad[S3], P) * scale
  9058. // grad[S3] = S4 * (grad[S4] - dot(S4, grad[S4]))
  9059. // grad[S4] = grad[S5] @ vcur
  9060. // grad[S4] = d[:D,id1,id2,id3] @ vcur
  9061. // grad[qcur] = grad[S1] @ kcur
  9062. // grad[vcur] = grad[S5].T @ S4
  9063. // grad[vcur] = d[:D,id1,id2,id3].T @ S4
  9064. //
  9065. // in post-order:
  9066. //
  9067. // S1 = qcur @ kcur.T
  9068. // S2 = S1 * scale
  9069. // S3 = diag_mask_inf(S2, P)
  9070. // S4 = softmax(S3)
  9071. // grad[S4] = d[:D,id1,id2,id3] @ vcur
  9072. // grad[S3] = S4 * (grad[S4] - dot(S4, grad[S4]))
  9073. // grad[S1] = diag_mask_zero(grad[S3], P) * scale
  9074. // grad[qcur] = grad[S1] @ kcur
  9075. // grad[kcur] = grad[S1].T @ qcur
  9076. // grad[vcur] = d[:D,id1,id2,id3].T @ S4
  9077. //
  9078. // using less variables (SM=S4):
  9079. //
  9080. // S = diag_mask_inf(qcur @ kcur.T * scale, P)
  9081. // SM = softmax(S)
  9082. // S = d[:D,iq1,iq2,iq3] @ vcur
  9083. // dot_SM_gradSM = dot(SM, S)
  9084. // S = SM * (S - dot(SM, S))
  9085. // S = diag_mask_zero(S, P) * scale
  9086. //
  9087. // grad[q][:D,iq1,iq2,iq3] += S @ kcur
  9088. // grad[k][:D,:M,ik2,ik3] += S.T @ qcur
  9089. // grad[v][:M,:D,iv2,iv3] += d[:D,id1,id2,id3].T @ SM
  9090. }
  9091. // S = gradSM = d[:D,id1,id2,id3] @ vcur[:,:,iv2,iv3]
  9092. // S = d[:D,id1,id2,id3] @ vcur[:,:,iv2,iv3]
  9093. // for ic:
  9094. // S[:M] += vcur[:M,ic,iv2,iv3] * d[ic,id1,id2,id3]
  9095. // exclude known future zero S[..] values from operation
  9096. ggml_vec_set_f32(masked_begin, S, 0);
  9097. for (int64_t ic = 0; ic < D; ++ic) {
  9098. ggml_vec_mad_f32(masked_begin,
  9099. S,
  9100. (float *) ((char *) v->data + ( ic*nbv1 + iv2*nbv2 + iv3*nbv3)),
  9101. *(float *) ((char *) d->data + (ic*nbd0 + id1*nbd1 + id2*nbd2 + id3*nbd3)));
  9102. }
  9103. // S = SM * (S - dot(SM, S))
  9104. float dot_SM_gradSM = 0;
  9105. ggml_vec_dot_f32 (masked_begin, &dot_SM_gradSM, 0, SM, 0, S, 0, 1);
  9106. ggml_vec_acc1_f32(M, S, -dot_SM_gradSM);
  9107. ggml_vec_mul_f32 (masked_begin, S, S, SM);
  9108. // S = diag_mask_zero(S, P) * scale
  9109. // already done by above ggml_vec_set_f32
  9110. // exclude known zero S[..] values from operation
  9111. ggml_vec_scale_f32(masked_begin, S, scale);
  9112. // S shape [M,1]
  9113. // SM shape [M,1]
  9114. // kcur shape [D,M]
  9115. // qcur shape [D,1]
  9116. // vcur shape [M,D]
  9117. // grad[q][:D,iq1,iq2,iq3] += S @ kcur
  9118. // grad[q][:D,iq1,iq2,iq3] += shape[M,1] @ shape[D,M]
  9119. // for ic:
  9120. // grad[q][:D,iq1,iq2,iq3] += S[ic] * kcur[:D,ic,ik2,ik3]
  9121. // exclude known zero S[..] values from loop
  9122. for (int64_t ic = 0; ic < masked_begin; ++ic) {
  9123. ggml_vec_mad_f32(D,
  9124. (float *) ((char *) grad_q + (iq1*nbgq1 + iq2*nbgq2 + iq3*nbgq3)),
  9125. (float *) ((char *) k->data + (ic*nbk1 + ik2*nbk2 + ik3*nbk3)),
  9126. S[ic]);
  9127. }
9128. // grad[k][:D,:M,ik2,ik3] += S.T @ qcur
9129. // for ic:
9130. // grad[k][:D,ic,ik2,ik3] += S.T[0,ic] * qcur[:D,0]
9131. // grad[k][:D,ic,ik2,ik3] += S[ic] * qcur[:D,0]
  9132. // exclude known zero S[..] values from loop
  9133. for (int64_t ic = 0; ic < masked_begin; ++ic) {
  9134. ggml_vec_mad_f32(D,
  9135. (float *) ((char *) grad_k + (ic*nbgk1 + ik2*nbgk2 + ik3*nbgk3)),
  9136. (float *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3)),
  9137. S[ic]);
  9138. }
  9139. // grad[v][:M,:D,iv2,iv3] += d[:D,id1,id2,id3].T @ SM
  9140. // for ic:
  9141. // grad[v][:M,ic,iv2,iv3] += d[:D,id1,id2,id3].T[0,ic] * SM[:M]
  9142. // grad[v][:M,ic,iv2,iv3] += d[ic,id1,id2,id3] * SM[:M]
  9143. // exclude known zero SM[..] values from mad
  9144. for (int64_t ic = 0; ic < D; ++ic) {
  9145. ggml_vec_mad_f32(masked_begin,
  9146. (float *) ((char *) grad_v + ( ic*nbgv1 + iv2*nbgv2 + iv3*nbgv3)),
  9147. SM,
  9148. *(float *) ((char *) d->data + (ic*nbd0 + id1*nbd1 + id2*nbd2 + id3*nbd3)));
  9149. }
  9150. }
  9151. }
  9152. }
  9153. }
  9154. static void ggml_compute_forward_flash_attn_back(
  9155. const struct ggml_compute_params * params,
  9156. const bool masked,
  9157. struct ggml_tensor * dst) {
  9158. const struct ggml_tensor * q = dst->src[0];
  9159. switch (q->type) {
  9160. case GGML_TYPE_F32:
  9161. {
  9162. ggml_compute_forward_flash_attn_back_f32(params, masked, dst);
  9163. } break;
  9164. default:
  9165. {
  9166. GGML_ABORT("fatal error");
  9167. }
  9168. }
  9169. }
  9170. // ggml_compute_forward_ssm_conv
  9171. static void ggml_compute_forward_ssm_conv_f32(
  9172. const struct ggml_compute_params * params,
  9173. struct ggml_tensor * dst) {
  9174. const struct ggml_tensor * src0 = dst->src[0]; // conv_x
  9175. const struct ggml_tensor * src1 = dst->src[1]; // conv1d.weight
  9176. const int ith = params->ith;
  9177. const int nth = params->nth;
  9178. const int nc = src1->ne[0]; // d_conv
  9179. const int ncs = src0->ne[0]; // d_conv - 1 + n_t
  9180. const int nr = src0->ne[1]; // d_inner
  9181. const int n_t = dst->ne[1]; // tokens per sequence
  9182. const int n_s = dst->ne[2]; // number of sequences in the batch
  9183. GGML_ASSERT( dst->ne[0] == nr);
  9184. GGML_ASSERT(src0->nb[0] == sizeof(float));
  9185. GGML_ASSERT(src1->nb[0] == sizeof(float));
  9186. GGML_ASSERT(src0->nb[1] == src0->ne[0]*sizeof(float));
  9187. // rows per thread
  9188. const int dr = (nr + nth - 1)/nth;
  9189. // row range for this thread
  9190. const int ir0 = dr*ith;
  9191. const int ir1 = MIN(ir0 + dr, nr);
  9192. const int ir = ir1 - ir0;
  9193. for (int i3 = 0; i3 < n_s; ++i3) {
  9194. for (int i2 = 0; i2 < n_t; ++i2) {
  9195. // {d_conv - 1 + n_t, d_inner, n_seqs}
  9196. // sliding window
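// in index form (illustrative, with c = ir0 + i1 the global channel):
//   dst[c, i2, i3] = sum over i0 < d_conv of conv_x[i2 + i0, c, i3] * conv1d_weight[i0, c]
// i.e. a per-channel 1D convolution over d_conv samples, where the d_conv - 1
// prepended samples of conv_x provide the history for the first tokens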
  9197. const float * s = (const float *) ((const char *) src0->data + ir0*(src0->nb[1]) + i2*(src0->nb[0]) + i3*(src0->nb[2])); // {d_conv, d_inner, n_s}
  9198. const float * c = (const float *) ((const char *) src1->data + ir0*(src1->nb[1])); // {d_conv, d_inner}
  9199. float * x = (float *) ((char *) dst->data + ir0*(dst->nb[0]) + i2*(dst->nb[1]) + i3*(dst->nb[2])); // {d_inner, n_t, n_s}
  9200. // TODO: transpose the output for smaller strides for big batches?
  9201. // d_inner
  9202. for (int i1 = 0; i1 < ir; ++i1) {
  9203. // rowwise dot product
  9204. // NOTE: not using ggml_vec_dot_f32, because its sum is in double precision
  9205. float sumf = 0.0f;
  9206. // d_conv
  9207. for (int i0 = 0; i0 < nc; ++i0) {
  9208. sumf += s[i0 + i1*ncs] * c[i0 + i1*nc];
  9209. }
  9210. x[i1] = sumf;
  9211. }
  9212. }
  9213. }
  9214. }
  9215. static void ggml_compute_forward_ssm_conv(
  9216. const struct ggml_compute_params * params,
  9217. struct ggml_tensor * dst) {
  9218. switch (dst->src[0]->type) {
  9219. case GGML_TYPE_F32:
  9220. {
  9221. ggml_compute_forward_ssm_conv_f32(params, dst);
  9222. } break;
  9223. default:
  9224. {
  9225. GGML_ABORT("fatal error");
  9226. }
  9227. }
  9228. }
  9229. // ggml_compute_forward_ssm_scan
  9230. static void ggml_compute_forward_ssm_scan_f32(
  9231. const struct ggml_compute_params * params,
  9232. struct ggml_tensor * dst) {
  9233. const struct ggml_tensor * src0 = dst->src[0]; // s
  9234. const struct ggml_tensor * src1 = dst->src[1]; // x
  9235. const struct ggml_tensor * src2 = dst->src[2]; // dt
  9236. const struct ggml_tensor * src3 = dst->src[3]; // A
  9237. const struct ggml_tensor * src4 = dst->src[4]; // B
  9238. const struct ggml_tensor * src5 = dst->src[5]; // C
  9239. const int ith = params->ith;
  9240. const int nth = params->nth;
  9241. const int64_t nc = src0->ne[0]; // d_state
  9242. const int64_t nr = src0->ne[1]; // d_inner
  9243. const int64_t n_t = src1->ne[1]; // number of tokens per sequence
  9244. const int64_t n_s = src0->ne[2]; // number of sequences in the batch
  9245. GGML_ASSERT(ggml_nelements(src1) + ggml_nelements(src0) == ggml_nelements(dst));
  9246. GGML_ASSERT(src0->nb[0] == sizeof(float));
  9247. GGML_ASSERT(src1->nb[0] == sizeof(float));
  9248. GGML_ASSERT(src2->nb[0] == sizeof(float));
  9249. GGML_ASSERT(src3->nb[0] == sizeof(float));
  9250. GGML_ASSERT(src4->nb[0] == sizeof(float));
  9251. GGML_ASSERT(src5->nb[0] == sizeof(float));
  9252. // required for the dot product between s and C
  9253. GGML_ASSERT(src0->nb[1] == src0->ne[0]*sizeof(float));
  9254. // required for per-sequence offsets for states
  9255. GGML_ASSERT(src0->nb[2] == src0->ne[0]*src0->ne[1]*sizeof(float));
  9256. // required to get correct offset for state destination (i.e. src1->nb[3])
  9257. GGML_ASSERT(src1->nb[3] == src1->ne[0]*src1->ne[1]*src1->ne[2]*sizeof(float));
  9258. // rows per thread
  9259. const int dr = (nr + nth - 1)/nth;
  9260. // row range for this thread
  9261. const int ir0 = dr*ith;
  9262. const int ir1 = MIN(ir0 + dr, nr);
  9263. const int ir = ir1 - ir0;
  9264. for (int i3 = 0; i3 < n_s; ++i3) {
  9265. for (int i2 = 0; i2 < n_t; ++i2) {
  9266. const float * s0 = (const float *) ((const char *) src0->data + ir0*(src0->nb[1]) + i3*(src0->nb[2])); // {d_state, d_inner, n_s}
  9267. const float * x = (const float *) ((const char *) src1->data + ir0*(src1->nb[0]) + i2*(src1->nb[1]) + i3*(src1->nb[2])); // {d_inner, n_t, n_s}
  9268. const float * dt = (const float *) ((const char *) src2->data + ir0*(src2->nb[0]) + i2*(src2->nb[1]) + i3*(src2->nb[2])); // {d_inner, n_t, n_s}
  9269. const float * A = (const float *) ((const char *) src3->data + ir0*(src3->nb[1])); // {d_state, d_inner}
  9270. const float * B = (const float *) ((const char *) src4->data + i2*(src4->nb[1]) + i3*(src4->nb[2])); // {d_state, n_t, n_s}
  9271. const float * C = (const float *) ((const char *) src5->data + i2*(src5->nb[1]) + i3*(src5->nb[2])); // {d_state, n_t, n_s}
  9272. float * y = ( float *) (( char *) dst->data + ir0*(src1->nb[0]) + i2*(src1->nb[1]) + i3*(src1->nb[2])); // {d_inner, n_t, n_s}
  9273. float * s = ( float *) (( char *) dst->data + ir0*(src0->nb[1]) + i3*(src0->nb[2]) + src1->nb[3]); // {d_state, d_inner, n_s}
  9274. // use the output as the source for the next token-wise iterations
  9275. if (i2 > 0) { s0 = s; }
  9276. // d_inner
  9277. for (int i1 = 0; i1 < ir; ++i1) {
  9278. // ref: https://github.com/state-spaces/mamba/blob/34076d664838588a3c97727b263478ab9f621a07/mamba_ssm/ops/triton/selective_state_update.py#L78
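// dt goes through softplus: softplus(x) = log(1 + e^x); for x > 20 the result is
// indistinguishable from x in float precision, so x is used directly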
  9279. float dt_soft_plus = dt[i1] <= 20.0f ? log1pf(expf(dt[i1])) : dt[i1];
  9280. float x_dt = x[i1] * dt_soft_plus;
  9281. float sumf = 0.0f;
  9282. // d_state
  9283. for (int i0 = 0; i0 < nc; ++i0) {
  9284. int i = i0 + i1*nc;
  9285. // state = prev_state * dA + dB * x
  9286. float state = (s0[i] * expf(dt_soft_plus * A[i])) + (B[i0] * x_dt);
  9287. // y = rowwise_dotprod(state, C)
  9288. sumf += state * C[i0];
  9289. s[i] = state;
  9290. }
  9291. y[i1] = sumf;
  9292. }
  9293. }
  9294. }
  9295. }
  9296. static void ggml_compute_forward_ssm_scan(
  9297. const struct ggml_compute_params * params,
  9298. struct ggml_tensor * dst) {
  9299. switch (dst->src[0]->type) {
  9300. case GGML_TYPE_F32:
  9301. {
  9302. ggml_compute_forward_ssm_scan_f32(params, dst);
  9303. } break;
  9304. default:
  9305. {
  9306. GGML_ABORT("fatal error");
  9307. }
  9308. }
  9309. }
  9310. // ggml_compute_forward_win_part
  9311. static void ggml_compute_forward_win_part_f32(
  9312. const struct ggml_compute_params * params,
  9313. struct ggml_tensor * dst) {
  9314. UNUSED(params);
  9315. const struct ggml_tensor * src0 = dst->src[0];
  9316. GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne)
  9317. GGML_TENSOR_LOCALS(int64_t, ne, dst, ne)
  9318. const int32_t nep0 = ((const int32_t *)(dst->op_params))[0];
  9319. const int32_t nep1 = ((const int32_t *)(dst->op_params))[1];
  9320. const int32_t w = ((const int32_t *)(dst->op_params))[2];
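// op_params, as used below: nep0/nep1 = number of w x w windows along the two
// spatial dims, w = window size; positions past ne01/ne02 are zero-padded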
  9321. assert(ne00 == ne0);
  9322. assert(ne3 == nep0*nep1);
  9323. // TODO: optimize / multi-thread
  9324. for (int py = 0; py < nep1; ++py) {
  9325. for (int px = 0; px < nep0; ++px) {
  9326. const int64_t i3 = py*nep0 + px;
  9327. for (int64_t i2 = 0; i2 < ne2; ++i2) {
  9328. for (int64_t i1 = 0; i1 < ne1; ++i1) {
  9329. for (int64_t i0 = 0; i0 < ne0; ++i0) {
  9330. const int64_t i02 = py*w + i2;
  9331. const int64_t i01 = px*w + i1;
  9332. const int64_t i00 = i0;
  9333. const int64_t i = i3*ne2*ne1*ne0 + i2*ne1*ne0 + i1*ne0 + i0;
  9334. const int64_t j = i02*ne01*ne00 + i01*ne00 + i00;
  9335. if (py*w + i2 >= ne02 || px*w + i1 >= ne01) {
  9336. ((float *) dst->data)[i] = 0.0f;
  9337. } else {
  9338. ((float *) dst->data)[i] = ((float *) src0->data)[j];
  9339. }
  9340. }
  9341. }
  9342. }
  9343. }
  9344. }
  9345. }
  9346. static void ggml_compute_forward_win_part(
  9347. const struct ggml_compute_params * params,
  9348. struct ggml_tensor * dst) {
  9349. const struct ggml_tensor * src0 = dst->src[0];
  9350. switch (src0->type) {
  9351. case GGML_TYPE_F32:
  9352. {
  9353. ggml_compute_forward_win_part_f32(params, dst);
  9354. } break;
  9355. default:
  9356. {
  9357. GGML_ABORT("fatal error");
  9358. }
  9359. }
  9360. }
  9361. // ggml_compute_forward_win_unpart
  9362. static void ggml_compute_forward_win_unpart_f32(
  9363. const struct ggml_compute_params * params,
  9364. struct ggml_tensor * dst) {
  9365. UNUSED(params);
  9366. const struct ggml_tensor * src0 = dst->src[0];
  9367. GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne)
  9368. GGML_TENSOR_LOCALS(int64_t, ne, dst, ne)
  9369. const int32_t w = ((const int32_t *)(dst->op_params))[0];
  9370. // padding
  9371. const int px = (w - ne1%w)%w;
  9372. //const int py = (w - ne2%w)%w;
  9373. const int npx = (px + ne1)/w;
  9374. //const int npy = (py + ne2)/w;
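// e.g. (illustrative): w = 14, ne1 = 25 -> px = 3 padding columns and npx = 2 windows along that dim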
  9375. assert(ne0 == ne00);
  9376. // TODO: optimize / multi-thread
  9377. for (int64_t i2 = 0; i2 < ne2; ++i2) {
  9378. for (int64_t i1 = 0; i1 < ne1; ++i1) {
  9379. for (int64_t i0 = 0; i0 < ne0; ++i0) {
  9380. const int ip2 = i2/w;
  9381. const int ip1 = i1/w;
  9382. const int64_t i02 = i2%w;
  9383. const int64_t i01 = i1%w;
  9384. const int64_t i00 = i0;
  9385. const int64_t i = (ip2*npx + ip1)*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00 + i00;
  9386. const int64_t j = i2*ne1*ne0 + i1*ne0 + i0;
  9387. ((float *) dst->data)[j] = ((float *) src0->data)[i];
  9388. }
  9389. }
  9390. }
  9391. }
  9392. static void ggml_compute_forward_win_unpart(
  9393. const struct ggml_compute_params * params,
  9394. struct ggml_tensor * dst) {
  9395. const struct ggml_tensor * src0 = dst->src[0];
  9396. switch (src0->type) {
  9397. case GGML_TYPE_F32:
  9398. {
  9399. ggml_compute_forward_win_unpart_f32(params, dst);
  9400. } break;
  9401. default:
  9402. {
  9403. GGML_ABORT("fatal error");
  9404. }
  9405. }
  9406. }
9407. // ggml_compute_forward_unary
  9408. static void ggml_compute_forward_unary(
  9409. const struct ggml_compute_params * params,
  9410. struct ggml_tensor * dst) {
  9411. const enum ggml_unary_op op = ggml_get_unary_op(dst);
  9412. switch (op) {
  9413. case GGML_UNARY_OP_ABS:
  9414. {
  9415. ggml_compute_forward_abs(params, dst);
  9416. } break;
  9417. case GGML_UNARY_OP_SGN:
  9418. {
  9419. ggml_compute_forward_sgn(params, dst);
  9420. } break;
  9421. case GGML_UNARY_OP_NEG:
  9422. {
  9423. ggml_compute_forward_neg(params, dst);
  9424. } break;
  9425. case GGML_UNARY_OP_STEP:
  9426. {
  9427. ggml_compute_forward_step(params, dst);
  9428. } break;
  9429. case GGML_UNARY_OP_TANH:
  9430. {
  9431. ggml_compute_forward_tanh(params, dst);
  9432. } break;
  9433. case GGML_UNARY_OP_ELU:
  9434. {
  9435. ggml_compute_forward_elu(params, dst);
  9436. } break;
  9437. case GGML_UNARY_OP_RELU:
  9438. {
  9439. ggml_compute_forward_relu(params, dst);
  9440. } break;
  9441. case GGML_UNARY_OP_SIGMOID:
  9442. {
  9443. ggml_compute_forward_sigmoid(params, dst);
  9444. } break;
  9445. case GGML_UNARY_OP_GELU:
  9446. {
  9447. ggml_compute_forward_gelu(params, dst);
  9448. } break;
  9449. case GGML_UNARY_OP_GELU_QUICK:
  9450. {
  9451. ggml_compute_forward_gelu_quick(params, dst);
  9452. } break;
  9453. case GGML_UNARY_OP_SILU:
  9454. {
  9455. ggml_compute_forward_silu(params, dst);
  9456. } break;
  9457. case GGML_UNARY_OP_HARDSWISH:
  9458. {
  9459. ggml_compute_forward_hardswish(params, dst);
  9460. } break;
  9461. case GGML_UNARY_OP_HARDSIGMOID:
  9462. {
  9463. ggml_compute_forward_hardsigmoid(params, dst);
  9464. } break;
  9465. case GGML_UNARY_OP_EXP:
  9466. {
  9467. ggml_compute_forward_exp(params, dst);
  9468. } break;
  9469. default:
  9470. {
  9471. GGML_ABORT("fatal error");
  9472. }
  9473. }
  9474. }
  9475. // ggml_compute_forward_get_rel_pos
  9476. static void ggml_compute_forward_get_rel_pos_f16(
  9477. const struct ggml_compute_params * params,
  9478. struct ggml_tensor * dst) {
  9479. UNUSED(params);
  9480. const struct ggml_tensor * src0 = dst->src[0];
  9481. // ref: https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/modeling/image_encoder.py#L292-L322
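// dst row (i1, i2) gathers the src0 row at index (w - 1) + (i2 - i1), i.e. the
// relative position i2 - i1 shifted by w - 1 so that the index is non-negative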
  9482. GGML_TENSOR_UNARY_OP_LOCALS
  9483. const int64_t w = ne1;
  9484. ggml_fp16_t * src0_data = (ggml_fp16_t *) src0->data;
  9485. ggml_fp16_t * dst_data = (ggml_fp16_t *) dst->data;
  9486. for (int64_t i2 = 0; i2 < ne2; ++i2) {
  9487. for (int64_t i1 = 0; i1 < ne1; ++i1) {
  9488. const int64_t pos = (w - i1 - 1) + i2;
  9489. for (int64_t i0 = 0; i0 < ne0; ++i0) {
  9490. dst_data[i2*ne1*ne0 + i1*ne0 + i0] = src0_data[pos*ne00 + i0];
  9491. }
  9492. }
  9493. }
  9494. }
  9495. static void ggml_compute_forward_get_rel_pos(
  9496. const struct ggml_compute_params * params,
  9497. struct ggml_tensor * dst) {
  9498. const struct ggml_tensor * src0 = dst->src[0];
  9499. switch (src0->type) {
  9500. case GGML_TYPE_F16:
  9501. case GGML_TYPE_BF16:
  9502. {
  9503. ggml_compute_forward_get_rel_pos_f16(params, dst);
  9504. } break;
  9505. default:
  9506. {
  9507. GGML_ABORT("fatal error");
  9508. }
  9509. }
  9510. }
  9511. // ggml_compute_forward_add_rel_pos
  9512. static void ggml_compute_forward_add_rel_pos_f32(
  9513. const struct ggml_compute_params * params,
  9514. struct ggml_tensor * dst) {
  9515. const struct ggml_tensor * src0 = dst->src[0];
  9516. const struct ggml_tensor * src1 = dst->src[1];
  9517. const struct ggml_tensor * src2 = dst->src[2];
  9518. const bool inplace = (bool) ((int32_t *) dst->op_params)[0];
  9519. if (!inplace) {
  9520. if (params->ith == 0) {
  9521. memcpy((char *) dst->data, (char *) src0->data, ggml_nbytes(dst));
  9522. }
  9523. ggml_barrier(params->threadpool);
  9524. }
  9525. // ref: https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/modeling/image_encoder.py#L357-L359
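// a reading of the indexing below (sketch, not authoritative): dst is treated as
// [ne10, ne10, ne11, ...]; for each (i10, i11, i12, i13) the scalar src2[i10, i11, ...]
// is added along one ne10 axis (offsets jdh + j) and src1[i10, i11, ...] along the
// other (offsets jdw + j*ne10), mirroring the decomposed rel-pos add in the reference above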
  9526. float * src1_data = (float *) src1->data;
  9527. float * src2_data = (float *) src2->data;
  9528. float * dst_data = (float *) dst->data;
  9529. const int64_t ne10 = src1->ne[0];
  9530. const int64_t ne11 = src1->ne[1];
  9531. const int64_t ne12 = src1->ne[2];
  9532. const int64_t ne13 = src1->ne[3];
  9533. const int ith = params->ith;
  9534. const int nth = params->nth;
  9535. // total patches in dst
  9536. const int np = ne13;
  9537. // patches per thread
  9538. const int dp = (np + nth - 1)/nth;
  9539. // patch range for this thread
  9540. const int ip0 = dp*ith;
  9541. const int ip1 = MIN(ip0 + dp, np);
  9542. for (int64_t i13 = ip0; i13 < ip1; ++i13) {
  9543. for (int64_t i12 = 0; i12 < ne12; ++i12) {
  9544. for (int64_t i11 = 0; i11 < ne11; ++i11) {
  9545. const int64_t jp1 = i13*ne12*ne11*ne10 + i12*ne11*ne10 + i11*ne10;
  9546. for (int64_t i10 = 0; i10 < ne10; ++i10) {
  9547. const int64_t jp0 = jp1 + i10;
  9548. const float src1_e = src1_data[jp0];
  9549. const float src2_e = src2_data[jp0];
  9550. const int64_t jdh = jp0 * ne10;
  9551. const int64_t jdw = jdh - (ne10 - 1) * i10;
  9552. for (int64_t j = 0; j < ne10; ++j) {
  9553. dst_data[jdh + j ] += src2_e;
  9554. dst_data[jdw + j*ne10] += src1_e;
  9555. }
  9556. }
  9557. }
  9558. }
  9559. }
  9560. }
  9561. static void ggml_compute_forward_add_rel_pos(
  9562. const struct ggml_compute_params * params,
  9563. struct ggml_tensor * dst) {
  9564. const struct ggml_tensor * src0 = dst->src[0];
  9565. switch (src0->type) {
  9566. case GGML_TYPE_F32:
  9567. {
  9568. ggml_compute_forward_add_rel_pos_f32(params, dst);
  9569. } break;
  9570. default:
  9571. {
  9572. GGML_ABORT("fatal error");
  9573. }
  9574. }
  9575. }
  9576. // ggml_compute_forward_rwkv_wkv
  9577. static void ggml_compute_forward_rwkv_wkv_f32(
  9578. const struct ggml_compute_params * params,
  9579. struct ggml_tensor * dst) {
  9580. const size_t T = dst->src[1]->ne[3];
  9581. const size_t C = dst->ne[0];
  9582. const size_t H = dst->src[1]->ne[2];
  9583. const size_t n_seqs = dst->src[5]->ne[1];
  9584. float * dst_data = (float *) dst->data;
  9585. float * state = ((float *) dst->data) + C * T;
  9586. if (params->ith != 0) {
  9587. return;
  9588. }
  9589. memset(dst_data, 0, T * C * sizeof(float));
  9590. float * k = (float *) dst->src[0]->data;
  9591. float * v = (float *) dst->src[1]->data;
  9592. float * r = (float *) dst->src[2]->data;
  9593. float * time_faaaa = (float *) dst->src[3]->data;
  9594. float * time_decay = (float *) dst->src[4]->data;
  9595. size_t t_stride = H * (C / H);
  9596. size_t h_stride = C / H;
  9597. size_t h_stride_2d = (C / H) * (C / H);
  9598. // basically fused operations:
  9599. // dst = r @ (time_faaaa * (k @ v) + state),
  9600. // state = time_decay * state + (k @ v),
  9601. // recursive through each token
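// per head (head size C/H), the loops below compute, in matrix form (illustrative):
//   kv[i][j]         = k[i] * v[j]                      // (C/H) x (C/H) outer product
//   dst[j]          += sum_i r[i] * (time_faaaa[i]*kv[i][j] + state_prev[i][j])
//   state_cur[i][j]  = time_decay[i] * state_prev[i][j] + kv[i][j]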
  9602. for (size_t t = 0; t < T; t++) {
  9603. size_t t_offset = t * t_stride;
  9604. size_t state_offset = (C / H) * C * (t / (T / n_seqs));
  9605. float * state_cur = state + state_offset;
  9606. float * state_prev = t % (T / n_seqs) ? state_cur : (float*)dst->src[5]->data + state_offset;
  9607. for (size_t h = 0; h < H; h++) {
  9608. size_t h_offset = h * h_stride;
  9609. size_t t_h_offset = t_offset + h_offset;
  9610. size_t h_2d_offset = h * h_stride_2d;
  9611. for (size_t i = 0; i < C / H; i++) {
  9612. size_t t_h_i_offset = t_h_offset + i;
  9613. size_t h_i_offset = h_offset + i;
  9614. size_t h_2d_i_offset = h_2d_offset + i * h_stride;
  9615. float k_val = k[t_h_i_offset];
  9616. float r_val = r[t_h_i_offset];
  9617. float time_faaaa_val = time_faaaa[h_i_offset];
  9618. // RWKV v6: different time_decay for each token.
  9619. float time_decay_val = time_decay[t_h_i_offset];
  9620. for (size_t j = 0; j < C / H; j ++) {
  9621. size_t t_h_j_offset = t_h_offset + j;
  9622. size_t h_2d_i_j_offset = h_2d_i_offset + j;
  9623. float v_val = v[t_h_j_offset];
  9624. float kv_val = v_val * k_val;
  9625. float prev_state_val = state_prev[h_2d_i_j_offset];
  9626. float temp_val = kv_val * time_faaaa_val + prev_state_val;
  9627. dst_data[t_h_j_offset] += temp_val * r_val;
  9628. state_cur[h_2d_i_j_offset] = prev_state_val * time_decay_val + kv_val;
  9629. }
  9630. }
  9631. }
  9632. }
  9633. }
  9634. static void ggml_compute_forward_rwkv_wkv(
  9635. const struct ggml_compute_params * params,
  9636. struct ggml_tensor * dst) {
  9637. const struct ggml_tensor * src0 = dst->src[0];
  9638. switch (src0->type) {
  9639. case GGML_TYPE_F32:
  9640. {
  9641. ggml_compute_forward_rwkv_wkv_f32(params, dst);
  9642. } break;
  9643. default:
  9644. {
  9645. GGML_ABORT("fatal error");
  9646. }
  9647. }
  9648. }
  9649. // ggml_compute_forward_map_unary
  9650. static void ggml_compute_forward_map_unary_f32(
  9651. const struct ggml_compute_params * params,
  9652. struct ggml_tensor * dst,
  9653. const ggml_unary_op_f32_t fun) {
  9654. const struct ggml_tensor * src0 = dst->src[0];
  9655. if (params->ith != 0) {
  9656. return;
  9657. }
  9658. assert(ggml_is_contiguous_1(src0));
  9659. assert(ggml_is_contiguous_1(dst));
  9660. assert(ggml_are_same_shape(src0, dst));
  9661. const int n = ggml_nrows(src0);
  9662. const int nc = src0->ne[0];
  9663. for (int i = 0; i < n; i++) {
  9664. fun(nc,
  9665. (float *) ((char *) dst->data + i*( dst->nb[1])),
  9666. (float *) ((char *) src0->data + i*(src0->nb[1])));
  9667. }
  9668. }
  9669. static void ggml_compute_forward_map_unary(
  9670. const struct ggml_compute_params * params,
  9671. struct ggml_tensor * dst,
  9672. const ggml_unary_op_f32_t fun) {
  9673. const struct ggml_tensor * src0 = dst->src[0];
  9674. switch (src0->type) {
  9675. case GGML_TYPE_F32:
  9676. {
  9677. ggml_compute_forward_map_unary_f32(params, dst, fun);
  9678. } break;
  9679. default:
  9680. {
  9681. GGML_ABORT("fatal error");
  9682. }
  9683. }
  9684. }
  9685. // ggml_compute_forward_map_binary
  9686. static void ggml_compute_forward_map_binary_f32(
  9687. const struct ggml_compute_params * params,
  9688. struct ggml_tensor * dst,
  9689. const ggml_binary_op_f32_t fun) {
  9690. const struct ggml_tensor * src0 = dst->src[0];
  9691. const struct ggml_tensor * src1 = dst->src[1];
  9692. if (params->ith != 0) {
  9693. return;
  9694. }
  9695. assert(ggml_is_contiguous_1(src0));
  9696. assert(ggml_is_contiguous_1(src1));
  9697. assert(ggml_is_contiguous_1(dst));
  9698. assert(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst));
  9699. const int n = ggml_nrows(src0);
  9700. const int nc = src0->ne[0];
  9701. for (int i = 0; i < n; i++) {
  9702. fun(nc,
  9703. (float *) ((char *) dst->data + i*( dst->nb[1])),
  9704. (float *) ((char *) src0->data + i*(src0->nb[1])),
  9705. (float *) ((char *) src1->data + i*(src1->nb[1])));
  9706. }
  9707. }
  9708. static void ggml_compute_forward_map_binary(
  9709. const struct ggml_compute_params * params,
  9710. struct ggml_tensor * dst,
  9711. const ggml_binary_op_f32_t fun) {
  9712. const struct ggml_tensor * src0 = dst->src[0];
  9713. switch (src0->type) {
  9714. case GGML_TYPE_F32:
  9715. {
  9716. ggml_compute_forward_map_binary_f32(params, dst, fun);
  9717. } break;
  9718. default:
  9719. {
  9720. GGML_ABORT("fatal error");
  9721. }
  9722. }
  9723. }
  9724. // ggml_compute_forward_map_custom1
  9725. static void ggml_compute_forward_map_custom1_f32(
  9726. const struct ggml_compute_params * params,
  9727. struct ggml_tensor * dst,
  9728. const ggml_custom1_op_f32_t fun) {
  9729. const struct ggml_tensor * a = dst->src[0];
  9730. if (params->ith != 0) {
  9731. return;
  9732. }
  9733. fun(dst, a);
  9734. }
  9735. // ggml_compute_forward_map_custom2
  9736. static void ggml_compute_forward_map_custom2_f32(
  9737. const struct ggml_compute_params * params,
  9738. struct ggml_tensor * dst,
  9739. const ggml_custom2_op_f32_t fun) {
  9740. const struct ggml_tensor * a = dst->src[0];
  9741. const struct ggml_tensor * b = dst->src[1];
  9742. if (params->ith != 0) {
  9743. return;
  9744. }
  9745. fun(dst, a, b);
  9746. }
  9747. // ggml_compute_forward_map_custom3
  9748. static void ggml_compute_forward_map_custom3_f32(
  9749. const struct ggml_compute_params * params,
  9750. struct ggml_tensor * dst,
  9751. const ggml_custom3_op_f32_t fun) {
  9752. const struct ggml_tensor * a = dst->src[0];
  9753. const struct ggml_tensor * b = dst->src[1];
9754. const struct ggml_tensor * c = dst->src[2];
  9755. if (params->ith != 0) {
  9756. return;
  9757. }
  9758. fun(dst, a, b, c);
  9759. }
  9760. // ggml_compute_forward_map_custom1
  9761. static void ggml_compute_forward_map_custom1(
  9762. const struct ggml_compute_params * params,
  9763. struct ggml_tensor * dst) {
  9764. const struct ggml_tensor * a = dst->src[0];
  9765. struct ggml_map_custom1_op_params p;
  9766. memcpy(&p, dst->op_params, sizeof(p));
  9767. p.fun(dst, a, params->ith, params->nth, p.userdata);
  9768. }
  9769. // ggml_compute_forward_map_custom2
  9770. static void ggml_compute_forward_map_custom2(
  9771. const struct ggml_compute_params * params,
  9772. struct ggml_tensor * dst) {
  9773. const struct ggml_tensor * a = dst->src[0];
  9774. const struct ggml_tensor * b = dst->src[1];
  9775. struct ggml_map_custom2_op_params p;
  9776. memcpy(&p, dst->op_params, sizeof(p));
  9777. p.fun(dst, a, b, params->ith, params->nth, p.userdata);
  9778. }
  9779. // ggml_compute_forward_map_custom3
  9780. static void ggml_compute_forward_map_custom3(
  9781. const struct ggml_compute_params * params,
  9782. struct ggml_tensor * dst) {
  9783. const struct ggml_tensor * a = dst->src[0];
  9784. const struct ggml_tensor * b = dst->src[1];
  9785. const struct ggml_tensor * c = dst->src[2];
  9786. struct ggml_map_custom3_op_params p;
  9787. memcpy(&p, dst->op_params, sizeof(p));
  9788. p.fun(dst, a, b, c, params->ith, params->nth, p.userdata);
  9789. }
  9790. // ggml_compute_forward_cross_entropy_loss
  9791. static void ggml_compute_forward_cross_entropy_loss_f32(
  9792. const struct ggml_compute_params * params,
  9793. struct ggml_tensor * dst) {
  9794. const struct ggml_tensor * src0 = dst->src[0];
  9795. const struct ggml_tensor * src1 = dst->src[1];
  9796. GGML_ASSERT(src0->type == GGML_TYPE_F32);
  9797. GGML_ASSERT(src1->type == GGML_TYPE_F32);
  9798. GGML_ASSERT(src0->nb[0] == ggml_type_size(src0->type));
  9799. GGML_ASSERT(src1->nb[0] == ggml_type_size(src1->type));
  9800. GGML_ASSERT(ggml_are_same_shape(src0, src1));
  9801. GGML_ASSERT(ggml_is_scalar(dst));
  9802. GGML_ASSERT(dst->type == GGML_TYPE_F32);
  9803. // TODO: handle transposed/permuted matrices
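// computed quantity (illustrative form):
//   dst = -(1/nr) * sum over rows i of sum_c src1[i,c] * log_softmax(src0[i,:])[c]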
  9804. const int64_t nc = src0->ne[0];
  9805. const int64_t nr = ggml_nrows(src0);
  9806. const int ith = params->ith;
  9807. const int nth = params->nth;
  9808. float * sums = (float *) params->wdata;
  9809. float * st = ((float *) params->wdata) + nth + ith*nc;
  9810. float sum_thread = 0.0f;
  9811. GGML_ASSERT(params->wsize >= sizeof(float) * (nth + nth * nc));
  9812. // rows per thread
  9813. const int64_t dr = (nr + nth - 1)/nth;
  9814. // row range for this thread
  9815. const int64_t ir0 = dr*ith;
  9816. const int64_t ir1 = MIN(ir0 + dr, nr);
  9817. for (int64_t i1 = ir0; i1 < ir1; ++i1) {
  9818. const float * s0 = (const float *)((const char *) src0->data + i1*src0->nb[1]);
  9819. const float * s1 = (const float *)((const char *) src1->data + i1*src1->nb[1]);
  9820. #ifndef NDEBUG
  9821. for (int64_t i = 0; i < nc; ++i) {
  9822. //printf("p[%d] = %f\n", i, p[i]);
  9823. assert(!isnan(s0[i]));
  9824. assert(!isnan(s1[i]));
  9825. }
  9826. #endif
  9827. float max = -INFINITY;
  9828. ggml_vec_max_f32(nc, &max, s0);
  9829. const ggml_float sum_softmax = ggml_vec_log_soft_max_f32(nc, st, s0, max);
  9830. assert(sum_softmax >= 0.0);
  9831. ggml_vec_add1_f32(nc, st, st, -sum_softmax);
  9832. ggml_vec_mul_f32(nc, st, st, s1);
  9833. float sum_st = 0.0f;
  9834. ggml_vec_sum_f32(nc, &sum_st, st);
  9835. sum_thread += sum_st;
  9836. #ifndef NDEBUG
  9837. for (int64_t i = 0; i < nc; ++i) {
  9838. assert(!isnan(st[i]));
  9839. assert(!isinf(st[i]));
  9840. }
  9841. #endif
  9842. }
  9843. sums[ith] = sum_thread;
  9844. ggml_barrier(params->threadpool);
  9845. if (ith == 0) {
  9846. float * dp = (float *) dst->data;
  9847. ggml_vec_sum_f32(nth, dp, sums);
  9848. dp[0] *= -1.0f / (float) nr;
  9849. }
  9850. }
  9851. static void ggml_compute_forward_cross_entropy_loss(
  9852. const struct ggml_compute_params * params,
  9853. struct ggml_tensor * dst) {
  9854. const struct ggml_tensor * src0 = dst->src[0];
  9855. switch (src0->type) {
  9856. case GGML_TYPE_F32:
  9857. {
  9858. ggml_compute_forward_cross_entropy_loss_f32(params, dst);
  9859. } break;
  9860. default:
  9861. {
  9862. GGML_ABORT("fatal error");
  9863. }
  9864. }
  9865. }
  9866. // ggml_compute_forward_cross_entropy_loss_back
  9867. static void ggml_compute_forward_cross_entropy_loss_back_f32(
  9868. const struct ggml_compute_params * params,
  9869. struct ggml_tensor * dst) {
  9870. const struct ggml_tensor * src0 = dst->src[0];
  9871. const struct ggml_tensor * src1 = dst->src[1];
  9872. const struct ggml_tensor * opt0 = dst->src[2];
  9873. GGML_ASSERT(ggml_is_contiguous(dst));
  9874. GGML_ASSERT(ggml_is_contiguous(src0));
  9875. GGML_ASSERT(ggml_is_contiguous(src1));
  9876. GGML_ASSERT(ggml_is_contiguous(opt0));
  9877. GGML_ASSERT(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst));
  9878. const int64_t ith = params->ith;
  9879. const int64_t nth = params->nth;
  9880. // TODO: handle transposed/permuted matrices
  9881. const int64_t nc = src0->ne[0];
  9882. const int64_t nr = ggml_nrows(src0);
  9883. // rows per thread
  9884. const int64_t dr = (nr + nth - 1)/nth;
  9885. // row range for this thread
  9886. const int64_t ir0 = dr*ith;
  9887. const int64_t ir1 = MIN(ir0 + dr, nr);
  9888. const float d_by_nr = ((const float *) opt0->data)[0] / (float) nr;
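// d_by_nr folds the upstream scalar gradient (opt0[0]) together with the 1/nr
// factor coming from the row mean in the forward pass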
  9889. for (int64_t i1 = ir0; i1 < ir1; i1++) {
  9890. float * ds0 = (float *)((char *) dst->data + i1*dst->nb[1]);
  9891. float * s0 = (float *)((char *) src0->data + i1*src0->nb[1]);
  9892. float * s1 = (float *)((char *) src1->data + i1*src1->nb[1]);
  9893. #ifndef NDEBUG
  9894. for (int64_t i = 0; i < nc; ++i) {
  9895. //printf("p[%d] = %f\n", i, p[i]);
  9896. assert(!isnan(s0[i]));
  9897. assert(!isnan(s1[i]));
  9898. }
  9899. #endif
  9900. // soft_max
  9901. float max = -INFINITY;
  9902. ggml_vec_max_f32(nc, &max, s0);
  9903. ggml_float sum = ggml_vec_soft_max_f32(nc, ds0, s0, max);
  9904. assert(sum > 0.0);
  9905. ggml_vec_scale_f32(nc, ds0, 1.0/sum);
  9906. // grad(src0) = (softmax(src0) - src1) * grad(cross_entropy_loss(src0, src1)) / nr
  9907. ggml_vec_sub_f32(nc, ds0, ds0, s1);
  9908. ggml_vec_scale_f32(nc, ds0, d_by_nr);
  9909. #ifndef NDEBUG
  9910. for (int64_t i = 0; i < nc; ++i) {
  9911. assert(!isnan(ds0[i]));
  9912. assert(!isinf(ds0[i]));
  9913. }
  9914. #endif
  9915. }
  9916. }
  9917. static void ggml_compute_forward_cross_entropy_loss_back(
  9918. const struct ggml_compute_params * params,
  9919. struct ggml_tensor * dst) {
  9920. const struct ggml_tensor * src0 = dst->src[0];
  9921. switch (src0->type) {
  9922. case GGML_TYPE_F32:
  9923. {
  9924. ggml_compute_forward_cross_entropy_loss_back_f32(params, dst);
  9925. } break;
  9926. default:
  9927. {
  9928. GGML_ABORT("fatal error");
  9929. }
  9930. }
  9931. }
  9932. static void ggml_compute_forward_opt_step_adamw_f32(
  9933. const struct ggml_compute_params * params,
  9934. struct ggml_tensor * dst) {
  9935. const struct ggml_tensor * src0 = dst->src[0];
  9936. const struct ggml_tensor * src0_grad = dst->src[1];
  9937. const struct ggml_tensor * src0_grad_m = dst->src[2];
  9938. const struct ggml_tensor * src0_grad_v = dst->src[3];
  9939. GGML_ASSERT(ggml_are_same_shape(src0, src0_grad));
  9940. const int ith = params->ith;
  9941. const int nth = params->nth;
  9942. const int nr = ggml_nrows(src0);
  9943. GGML_TENSOR_UNARY_OP_LOCALS
  9944. GGML_ASSERT(nb00 == sizeof(float));
  9945. // rows per thread
  9946. const int dr = (nr + nth - 1)/nth;
  9947. // row range for this thread
  9948. const int ir0 = dr*ith;
  9949. const int ir1 = MIN(ir0 + dr, nr);
  9950. /* const float gnorm = 1.0f; */
  9951. int64_t iter; memcpy(&iter, &dst->op_params[0], sizeof(int64_t));
  9952. const float alpha = ggml_get_op_params_f32(dst, 2);
  9953. const float beta1 = ggml_get_op_params_f32(dst, 3);
  9954. const float beta2 = ggml_get_op_params_f32(dst, 4);
  9955. const float eps = ggml_get_op_params_f32(dst, 5);
  9956. const float wd = ggml_get_op_params_f32(dst, 6);
  9957. const float beta1h = alpha/(1.0f - powf(beta1, iter));
  9958. const float beta2h = 1.0f/(1.0f - powf(beta2, iter));
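// beta1h/beta2h fold the learning rate and the Adam bias corrections, so the update
// below is equivalent to the standard AdamW step (illustrative):
//   m_hat = m/(1 - beta1^iter),  v_hat = v/(1 - beta2^iter)
//   w     = w*(1 - alpha*wd) - alpha*m_hat/(sqrt(v_hat) + eps)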
  9959. for (int ir = ir0; ir < ir1; ++ir) {
  9960. const int64_t i03 = ir/(ne02*ne01);
  9961. const int64_t i02 = (ir - i03*ne02*ne01)/ne01;
  9962. const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01);
  9963. const size_t offset = i03*nb03 + i02*nb02 + i01*nb01;
  9964. float * w = (float *) ((char *) src0->data + offset); // weight
  9965. const float * g = (const float *) ((const char *) src0_grad->data + offset); // grad
  9966. float * m = (float *) ((char *) src0_grad_m->data + offset);
  9967. float * v = (float *) ((char *) src0_grad_v->data + offset);
  9968. for (int i00 = 0; i00 < ne00; ++i00) {
  9969. m[i00] = m[i00]*beta1 + g[i00]*(1.0f - beta1);
  9970. v[i00] = v[i00]*beta2 + g[i00]*g[i00]*(1.0f - beta2);
  9971. const float mh = m[i00]*beta1h;
  9972. const float vh = sqrtf(v[i00]*beta2h) + eps;
  9973. // The weight decay is applied independently of the Adam momenta m and v.
  9974. // This is NOT equivalent to l2 regularization that adds w[i00]*w[i00] to the loss.
  9975. // See: https://arxiv.org/pdf/1711.05101v3.pdf
  9976. w[i00] = w[i00]*(1.0f - alpha*wd) - mh/vh;
  9977. }
  9978. }
  9979. ggml_barrier(params->threadpool);
  9980. if (ith != 0) {
  9981. return;
  9982. }
  9983. iter++;
  9984. memcpy(&dst->op_params[0], &iter, sizeof(int64_t));
  9985. }
  9986. static void ggml_compute_forward_opt_step_adamw(
  9987. const struct ggml_compute_params * params,
  9988. struct ggml_tensor * dst) {
  9989. const struct ggml_tensor * src0 = dst->src[0];
  9990. switch (src0->type) {
  9991. case GGML_TYPE_F32:
  9992. {
  9993. ggml_compute_forward_opt_step_adamw_f32(params, dst);
  9994. } break;
  9995. default:
  9996. {
  9997. GGML_ABORT("fatal error");
  9998. }
  9999. }
  10000. }
  10001. /////////////////////////////////
  10002. static void ggml_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * tensor) {
  10003. GGML_ASSERT(params);
  10004. if (tensor->op == GGML_OP_NONE || ggml_is_empty(tensor)) {
  10005. return;
  10006. }
  10007. switch (tensor->op) {
  10008. case GGML_OP_DUP:
  10009. {
  10010. ggml_compute_forward_dup(params, tensor);
  10011. } break;
  10012. case GGML_OP_ADD:
  10013. {
  10014. ggml_compute_forward_add(params, tensor);
  10015. } break;
  10016. case GGML_OP_ADD1:
  10017. {
  10018. ggml_compute_forward_add1(params, tensor);
  10019. } break;
  10020. case GGML_OP_ACC:
  10021. {
  10022. ggml_compute_forward_acc(params, tensor);
  10023. } break;
  10024. case GGML_OP_SUB:
  10025. {
  10026. ggml_compute_forward_sub(params, tensor);
  10027. } break;
  10028. case GGML_OP_MUL:
  10029. {
  10030. ggml_compute_forward_mul(params, tensor);
  10031. } break;
  10032. case GGML_OP_DIV:
  10033. {
  10034. ggml_compute_forward_div(params, tensor);
  10035. } break;
  10036. case GGML_OP_SQR:
  10037. {
  10038. ggml_compute_forward_sqr(params, tensor);
  10039. } break;
  10040. case GGML_OP_SQRT:
  10041. {
  10042. ggml_compute_forward_sqrt(params, tensor);
  10043. } break;
  10044. case GGML_OP_LOG:
  10045. {
  10046. ggml_compute_forward_log(params, tensor);
  10047. } break;
  10048. case GGML_OP_SIN:
  10049. {
  10050. ggml_compute_forward_sin(params, tensor);
  10051. } break;
  10052. case GGML_OP_COS:
  10053. {
  10054. ggml_compute_forward_cos(params, tensor);
  10055. } break;
  10056. case GGML_OP_SUM:
  10057. {
  10058. ggml_compute_forward_sum(params, tensor);
  10059. } break;
  10060. case GGML_OP_SUM_ROWS:
  10061. {
  10062. ggml_compute_forward_sum_rows(params, tensor);
  10063. } break;
  10064. case GGML_OP_MEAN:
  10065. {
  10066. ggml_compute_forward_mean(params, tensor);
  10067. } break;
  10068. case GGML_OP_ARGMAX:
  10069. {
  10070. ggml_compute_forward_argmax(params, tensor);
  10071. } break;
  10072. case GGML_OP_COUNT_EQUAL:
  10073. {
  10074. ggml_compute_forward_count_equal(params, tensor);
  10075. } break;
  10076. case GGML_OP_REPEAT:
  10077. {
  10078. ggml_compute_forward_repeat(params, tensor);
  10079. } break;
  10080. case GGML_OP_REPEAT_BACK:
  10081. {
  10082. ggml_compute_forward_repeat_back(params, tensor);
  10083. } break;
  10084. case GGML_OP_CONCAT:
  10085. {
  10086. ggml_compute_forward_concat(params, tensor);
  10087. } break;
        case GGML_OP_SILU_BACK:
            {
                ggml_compute_forward_silu_back(params, tensor);
            } break;
        case GGML_OP_NORM:
            {
                ggml_compute_forward_norm(params, tensor);
            } break;
        case GGML_OP_RMS_NORM:
            {
                ggml_compute_forward_rms_norm(params, tensor);
            } break;
        case GGML_OP_RMS_NORM_BACK:
            {
                ggml_compute_forward_rms_norm_back(params, tensor);
            } break;
        case GGML_OP_GROUP_NORM:
            {
                ggml_compute_forward_group_norm(params, tensor);
            } break;
        case GGML_OP_MUL_MAT:
            {
                ggml_compute_forward_mul_mat(params, tensor);
            } break;
        case GGML_OP_MUL_MAT_ID:
            {
                ggml_compute_forward_mul_mat_id(params, tensor);
            } break;
        case GGML_OP_OUT_PROD:
            {
                ggml_compute_forward_out_prod(params, tensor);
            } break;
        case GGML_OP_SCALE:
            {
                ggml_compute_forward_scale(params, tensor);
            } break;
        case GGML_OP_SET:
            {
                ggml_compute_forward_set(params, tensor);
            } break;
        case GGML_OP_CPY:
            {
                ggml_compute_forward_cpy(params, tensor);
            } break;
        case GGML_OP_CONT:
            {
                ggml_compute_forward_cont(params, tensor);
            } break;
        case GGML_OP_RESHAPE:
            {
                ggml_compute_forward_reshape(params, tensor);
            } break;
        case GGML_OP_VIEW:
            {
                ggml_compute_forward_view(params, tensor);
            } break;
        case GGML_OP_PERMUTE:
            {
                ggml_compute_forward_permute(params, tensor);
            } break;
        case GGML_OP_TRANSPOSE:
            {
                ggml_compute_forward_transpose(params, tensor);
            } break;
        case GGML_OP_GET_ROWS:
            {
                ggml_compute_forward_get_rows(params, tensor);
            } break;
        case GGML_OP_GET_ROWS_BACK:
            {
                ggml_compute_forward_get_rows_back(params, tensor);
            } break;
        case GGML_OP_DIAG:
            {
                ggml_compute_forward_diag(params, tensor);
            } break;
        case GGML_OP_DIAG_MASK_INF:
            {
                ggml_compute_forward_diag_mask_inf(params, tensor);
            } break;
        case GGML_OP_DIAG_MASK_ZERO:
            {
                ggml_compute_forward_diag_mask_zero(params, tensor);
            } break;
        case GGML_OP_SOFT_MAX:
            {
                ggml_compute_forward_soft_max(params, tensor);
            } break;
        case GGML_OP_SOFT_MAX_BACK:
            {
                ggml_compute_forward_soft_max_back(params, tensor);
            } break;
        case GGML_OP_ROPE:
            {
                ggml_compute_forward_rope(params, tensor);
            } break;
        case GGML_OP_ROPE_BACK:
            {
                ggml_compute_forward_rope_back(params, tensor);
            } break;
        case GGML_OP_CLAMP:
            {
                ggml_compute_forward_clamp(params, tensor);
            } break;
        case GGML_OP_CONV_TRANSPOSE_1D:
            {
                ggml_compute_forward_conv_transpose_1d(params, tensor);
            } break;
        case GGML_OP_IM2COL:
            {
                ggml_compute_forward_im2col(params, tensor);
            } break;
        case GGML_OP_IM2COL_BACK:
            {
                ggml_compute_forward_im2col_back_f32(params, tensor);
            } break;
        case GGML_OP_CONV_TRANSPOSE_2D:
            {
                ggml_compute_forward_conv_transpose_2d(params, tensor);
            } break;
        case GGML_OP_POOL_1D:
            {
                ggml_compute_forward_pool_1d(params, tensor);
            } break;
        case GGML_OP_POOL_2D:
            {
                ggml_compute_forward_pool_2d(params, tensor);
            } break;
        case GGML_OP_POOL_2D_BACK:
            {
                ggml_compute_forward_pool_2d_back(params, tensor);
            } break;
        case GGML_OP_UPSCALE:
            {
                ggml_compute_forward_upscale(params, tensor);
            } break;
        case GGML_OP_PAD:
            {
                ggml_compute_forward_pad(params, tensor);
            } break;
        case GGML_OP_ARANGE:
            {
                ggml_compute_forward_arange(params, tensor);
            } break;
        case GGML_OP_TIMESTEP_EMBEDDING:
            {
                ggml_compute_forward_timestep_embedding(params, tensor);
            } break;
        case GGML_OP_ARGSORT:
            {
                ggml_compute_forward_argsort(params, tensor);
            } break;
        case GGML_OP_LEAKY_RELU:
            {
                ggml_compute_forward_leaky_relu(params, tensor);
            } break;
        case GGML_OP_FLASH_ATTN_EXT:
            {
                ggml_compute_forward_flash_attn_ext(params, tensor->src[0], tensor->src[1], tensor->src[2], tensor->src[3], tensor);
            } break;
        case GGML_OP_FLASH_ATTN_BACK:
            {
                int32_t t = ggml_get_op_params_i32(tensor, 0);
                GGML_ASSERT(t == 0 || t == 1);
                bool masked = t != 0;
                ggml_compute_forward_flash_attn_back(params, masked, tensor);
            } break;
        case GGML_OP_SSM_CONV:
            {
                ggml_compute_forward_ssm_conv(params, tensor);
            } break;
        case GGML_OP_SSM_SCAN:
            {
                ggml_compute_forward_ssm_scan(params, tensor);
            } break;
        case GGML_OP_WIN_PART:
            {
                ggml_compute_forward_win_part(params, tensor);
            } break;
        case GGML_OP_WIN_UNPART:
            {
                ggml_compute_forward_win_unpart(params, tensor);
            } break;
        case GGML_OP_UNARY:
            {
                ggml_compute_forward_unary(params, tensor);
            } break;
        case GGML_OP_GET_REL_POS:
            {
                ggml_compute_forward_get_rel_pos(params, tensor);
            } break;
        case GGML_OP_ADD_REL_POS:
            {
                ggml_compute_forward_add_rel_pos(params, tensor);
            } break;
        case GGML_OP_RWKV_WKV:
            {
                ggml_compute_forward_rwkv_wkv(params, tensor);
            } break;
        case GGML_OP_MAP_UNARY:
            {
                ggml_unary_op_f32_t fun;
                memcpy(&fun, tensor->op_params, sizeof(fun));
                ggml_compute_forward_map_unary(params, tensor, fun);
            }
            break;
        case GGML_OP_MAP_BINARY:
            {
                ggml_binary_op_f32_t fun;
                memcpy(&fun, tensor->op_params, sizeof(fun));
                ggml_compute_forward_map_binary(params, tensor, fun);
            }
            break;
        case GGML_OP_MAP_CUSTOM1_F32:
            {
                ggml_custom1_op_f32_t fun;
                memcpy(&fun, tensor->op_params, sizeof(fun));
                ggml_compute_forward_map_custom1_f32(params, tensor, fun);
            }
            break;
        case GGML_OP_MAP_CUSTOM2_F32:
            {
                ggml_custom2_op_f32_t fun;
                memcpy(&fun, tensor->op_params, sizeof(fun));
                ggml_compute_forward_map_custom2_f32(params, tensor, fun);
            }
            break;
        case GGML_OP_MAP_CUSTOM3_F32:
            {
                ggml_custom3_op_f32_t fun;
                memcpy(&fun, tensor->op_params, sizeof(fun));
                ggml_compute_forward_map_custom3_f32(params, tensor, fun);
            }
            break;
        case GGML_OP_MAP_CUSTOM1:
            {
                ggml_compute_forward_map_custom1(params, tensor);
            }
            break;
        case GGML_OP_MAP_CUSTOM2:
            {
                ggml_compute_forward_map_custom2(params, tensor);
            }
            break;
        case GGML_OP_MAP_CUSTOM3:
            {
                ggml_compute_forward_map_custom3(params, tensor);
            }
            break;
        case GGML_OP_CROSS_ENTROPY_LOSS:
            {
                ggml_compute_forward_cross_entropy_loss(params, tensor);
            }
            break;
        case GGML_OP_CROSS_ENTROPY_LOSS_BACK:
            {
                ggml_compute_forward_cross_entropy_loss_back(params, tensor);
            }
            break;
        case GGML_OP_OPT_STEP_ADAMW:
            {
                ggml_compute_forward_opt_step_adamw(params, tensor);
            }
            break;
        case GGML_OP_NONE:
            {
                // nop
            } break;
        case GGML_OP_COUNT:
            {
                GGML_ABORT("fatal error");
            }
    }
}
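
// The helpers below pin worker threads to CPUs on NUMA systems. Based on the
// code that follows: GGML_NUMA_STRATEGY_DISTRIBUTE spreads threads round-robin
// across nodes, GGML_NUMA_STRATEGY_ISOLATE keeps them on the current node, and
// GGML_NUMA_STRATEGY_NUMACTL reuses the cpuset inherited from numactl.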
// Android's libc implementation "bionic" does not support setting affinity
#if defined(__gnu_linux__)
static void set_numa_thread_affinity(int thread_n) {
    if (!ggml_is_numa()) {
        return;
    }

    int node_num;
    int rv;
    size_t setsize = CPU_ALLOC_SIZE(g_state.numa.total_cpus);

    switch(g_state.numa.numa_strategy) {
        case GGML_NUMA_STRATEGY_DISTRIBUTE:
            // run thread on node_num = thread_n % n_nodes (spread threads evenly across nodes)
            node_num = thread_n % g_state.numa.n_nodes;
            break;
        case GGML_NUMA_STRATEGY_ISOLATE:
            // run thread on current_node
            node_num = g_state.numa.current_node;
            break;
        case GGML_NUMA_STRATEGY_NUMACTL:
            // use the cpuset that numactl gave us
            rv = pthread_setaffinity_np(pthread_self(), setsize, &g_state.numa.cpuset);
            if (rv) {
                fprintf(stderr, "warning: pthread_setaffinity_np() failed: %s\n", strerror(rv));
            }
            return;
        default:
            return;
    }

    struct ggml_numa_node * node = &g_state.numa.nodes[node_num];

    cpu_set_t * cpus = CPU_ALLOC(g_state.numa.total_cpus);
    CPU_ZERO_S(setsize, cpus);
    for (size_t i = 0; i < node->n_cpus; ++i) {
        CPU_SET_S(node->cpus[i], setsize, cpus);
    }

    rv = pthread_setaffinity_np(pthread_self(), setsize, cpus);
    if (rv) {
        fprintf(stderr, "warning: pthread_setaffinity_np() failed: %s\n", strerror(rv));
    }

    CPU_FREE(cpus);
}

static void clear_numa_thread_affinity(void) {
    if (!ggml_is_numa()) {
        return;
    }

    size_t setsize = CPU_ALLOC_SIZE(g_state.numa.total_cpus);

    cpu_set_t * cpus = CPU_ALLOC(g_state.numa.total_cpus);
    CPU_ZERO_S(setsize, cpus);
    for (unsigned i = 0; i < g_state.numa.total_cpus; ++i) {
        CPU_SET_S(i, setsize, cpus);
    }

    int rv = pthread_setaffinity_np(pthread_self(), setsize, cpus);
    if (rv) {
        fprintf(stderr, "warning: pthread_setaffinity_np() failed: %s\n", strerror(rv));
    }

    CPU_FREE(cpus);
}
#else
// TODO: Windows etc.
// (the linux implementation may also work on BSD, someone should test)
static void set_numa_thread_affinity(int thread_n) { UNUSED(thread_n); }
static void clear_numa_thread_affinity(void) {}
#endif
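
// Note: ggml_get_n_tasks() below maps each op to the number of threads it
// should use. Ops that return n_tasks = 1 are either trivial (views, reshapes,
// diag) or, per the FIXME on GET_ROWS, not worth the thread-launch overhead;
// the heavy ops (mul_mat, norms, attention, convolutions) get the full n_threads.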
static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
    int n_tasks = 0;

    if (ggml_is_empty(node)) {
        // no need to multi-thread a no-op
        n_tasks = 1;
        return n_tasks;
    }

    switch (node->op) {
        case GGML_OP_CPY:
        case GGML_OP_DUP:
        case GGML_OP_CONT:
        case GGML_OP_ADD:
        case GGML_OP_ADD1:
        case GGML_OP_ACC:
            {
                n_tasks = n_threads;
            } break;
        case GGML_OP_SUB:
        case GGML_OP_SQR:
        case GGML_OP_SQRT:
        case GGML_OP_LOG:
        case GGML_OP_SIN:
        case GGML_OP_COS:
        case GGML_OP_SUM:
        case GGML_OP_SUM_ROWS:
        case GGML_OP_MEAN:
        case GGML_OP_ARGMAX:
            {
                n_tasks = 1;
            } break;
        case GGML_OP_COUNT_EQUAL:
            {
                n_tasks = n_threads;
            } break;
        case GGML_OP_REPEAT:
        case GGML_OP_REPEAT_BACK:
        case GGML_OP_LEAKY_RELU:
            {
                n_tasks = 1;
            } break;
        case GGML_OP_UNARY:
            switch (ggml_get_unary_op(node)) {
                case GGML_UNARY_OP_ABS:
                case GGML_UNARY_OP_SGN:
                case GGML_UNARY_OP_NEG:
                case GGML_UNARY_OP_STEP:
                case GGML_UNARY_OP_TANH:
                case GGML_UNARY_OP_ELU:
                case GGML_UNARY_OP_RELU:
                case GGML_UNARY_OP_SIGMOID:
                case GGML_UNARY_OP_HARDSWISH:
                case GGML_UNARY_OP_HARDSIGMOID:
                case GGML_UNARY_OP_EXP:
                    {
                        n_tasks = 1;
                    } break;
                case GGML_UNARY_OP_GELU:
                case GGML_UNARY_OP_GELU_QUICK:
                case GGML_UNARY_OP_SILU:
                    {
                        n_tasks = n_threads;
                    } break;
                default:
                    GGML_ABORT("fatal error");
            }
            break;
        case GGML_OP_SILU_BACK:
        case GGML_OP_MUL:
        case GGML_OP_DIV:
        case GGML_OP_NORM:
        case GGML_OP_RMS_NORM:
        case GGML_OP_RMS_NORM_BACK:
        case GGML_OP_GROUP_NORM:
        case GGML_OP_CONCAT:
        case GGML_OP_MUL_MAT:
        case GGML_OP_MUL_MAT_ID:
        case GGML_OP_OUT_PROD:
            {
                n_tasks = n_threads;
            } break;
        case GGML_OP_GET_ROWS:
            {
                // FIXME: get_rows can use additional threads, but the cost of launching additional threads
                // decreases performance with GPU offloading
                //n_tasks = n_threads;
                n_tasks = 1;
            } break;
        case GGML_OP_SCALE:
        case GGML_OP_SET:
        case GGML_OP_RESHAPE:
        case GGML_OP_VIEW:
        case GGML_OP_PERMUTE:
        case GGML_OP_TRANSPOSE:
        case GGML_OP_GET_ROWS_BACK:
        case GGML_OP_DIAG:
            {
                n_tasks = 1;
            } break;
        case GGML_OP_DIAG_MASK_ZERO:
        case GGML_OP_DIAG_MASK_INF:
        case GGML_OP_SOFT_MAX_BACK:
        case GGML_OP_ROPE:
        case GGML_OP_ROPE_BACK:
        case GGML_OP_ADD_REL_POS:
            {
                n_tasks = n_threads;
            } break;
        case GGML_OP_CLAMP:
            {
                n_tasks = 1; //TODO
            } break;
        case GGML_OP_SOFT_MAX:
            {
                n_tasks = MIN(n_threads, ggml_nrows(node->src[0]));
            } break;
        case GGML_OP_IM2COL:
        case GGML_OP_IM2COL_BACK:
        case GGML_OP_CONV_TRANSPOSE_1D:
        case GGML_OP_CONV_TRANSPOSE_2D:
            {
                n_tasks = n_threads;
            } break;
        case GGML_OP_POOL_1D:
        case GGML_OP_POOL_2D:
        case GGML_OP_POOL_2D_BACK:
            {
                n_tasks = 1;
            } break;
        case GGML_OP_UPSCALE:
        case GGML_OP_PAD:
        case GGML_OP_ARANGE:
        case GGML_OP_TIMESTEP_EMBEDDING:
        case GGML_OP_ARGSORT:
        case GGML_OP_FLASH_ATTN_EXT:
        case GGML_OP_FLASH_ATTN_BACK:
        case GGML_OP_SSM_CONV:
        case GGML_OP_SSM_SCAN:
            {
                n_tasks = n_threads;
            } break;
        case GGML_OP_WIN_PART:
        case GGML_OP_WIN_UNPART:
        case GGML_OP_GET_REL_POS:
        case GGML_OP_RWKV_WKV:
        case GGML_OP_MAP_UNARY:
        case GGML_OP_MAP_BINARY:
        case GGML_OP_MAP_CUSTOM1_F32:
        case GGML_OP_MAP_CUSTOM2_F32:
        case GGML_OP_MAP_CUSTOM3_F32:
            {
                n_tasks = 1;
            } break;
        case GGML_OP_MAP_CUSTOM1:
            {
                struct ggml_map_custom1_op_params p;
                memcpy(&p, node->op_params, sizeof(p));
                if (p.n_tasks == GGML_N_TASKS_MAX) {
                    n_tasks = n_threads;
                } else {
                    n_tasks = MIN(p.n_tasks, n_threads);
                }
            } break;
        case GGML_OP_MAP_CUSTOM2:
            {
                struct ggml_map_custom2_op_params p;
                memcpy(&p, node->op_params, sizeof(p));
                if (p.n_tasks == GGML_N_TASKS_MAX) {
                    n_tasks = n_threads;
                } else {
                    n_tasks = MIN(p.n_tasks, n_threads);
                }
            } break;
        case GGML_OP_MAP_CUSTOM3:
            {
                struct ggml_map_custom3_op_params p;
                memcpy(&p, node->op_params, sizeof(p));
                if (p.n_tasks == GGML_N_TASKS_MAX) {
                    n_tasks = n_threads;
                } else {
                    n_tasks = MIN(p.n_tasks, n_threads);
                }
            } break;
        case GGML_OP_CROSS_ENTROPY_LOSS:
        case GGML_OP_CROSS_ENTROPY_LOSS_BACK:
        case GGML_OP_OPT_STEP_ADAMW:
            {
                n_tasks = n_threads;
            } break;
        case GGML_OP_NONE:
            {
                n_tasks = 1;
            } break;
        case GGML_OP_COUNT:
            {
                GGML_ABORT("fatal error");
            }
        default:
            {
                fprintf(stderr, "%s: op not implemented: ", __func__);
                if (node->op < GGML_OP_COUNT) {
                    fprintf(stderr, "%s\n", ggml_op_name(node->op));
                } else {
                    fprintf(stderr, "%d\n", node->op);
                }
                GGML_ABORT("fatal error");
            }
    }

    assert(n_tasks > 0);

    return n_tasks;
}

static thread_ret_t ggml_graph_compute_secondary_thread(void* data);
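
// Per-platform helpers for pinning worker threads to CPUs and raising their
// scheduling priority: a Windows path (SetThreadAffinityMask/SetThreadPriority),
// an Apple path (priority only, no affinity), a Linux path (pthread/sched
// affinity and sched_setscheduler policy), and a no-op fallback elsewhere.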
#if defined(_WIN32)
#include "windows.h"

// TODO: support > 64 CPUs
bool ggml_thread_apply_affinity(bool * mask) {
    HANDLE   h = GetCurrentThread();
    uint64_t bitmask = 0ULL;

    assert(GGML_MAX_N_THREADS >= 64);

    for (int32_t i = 0; i < 8; i++) {
        int32_t idx = i * 8;
        uint8_t val = 0;
        val |= mask[idx + 0] << 0;
        val |= mask[idx + 1] << 1;
        val |= mask[idx + 2] << 2;
        val |= mask[idx + 3] << 3;
        val |= mask[idx + 4] << 4;
        val |= mask[idx + 5] << 5;
        val |= mask[idx + 6] << 6;
        val |= mask[idx + 7] << 7;
        bitmask |= (uint64_t)val << idx;
    }

    for (int32_t i = 64; i < GGML_MAX_N_THREADS; i++) {
        if (mask[i]) {
            fprintf(stderr, "warn: setting thread-affinity for > 64 CPUs isn't supported on windows!\n");
            break;
        }
    }

    DWORD_PTR m = (DWORD_PTR)bitmask;

    m = SetThreadAffinityMask(h, m);

    return m != 0;
}

static bool ggml_thread_apply_priority(int32_t prio) {
    // Note that on Windows the Process Priority Class must be updated in order to set Thread priority.
    // This is up to the applications.
    DWORD p = THREAD_PRIORITY_NORMAL;
    switch (prio) {
        case GGML_SCHED_PRIO_NORMAL:   p = THREAD_PRIORITY_NORMAL;        break;
        case GGML_SCHED_PRIO_MEDIUM:   p = THREAD_PRIORITY_ABOVE_NORMAL;  break;
        case GGML_SCHED_PRIO_HIGH:     p = THREAD_PRIORITY_HIGHEST;       break;
        case GGML_SCHED_PRIO_REALTIME: p = THREAD_PRIORITY_TIME_CRITICAL; break;
    }

    if (prio == GGML_SCHED_PRIO_NORMAL) {
        // Keep inherited policy/priority
        return true;
    }

    if (!SetThreadPriority(GetCurrentThread(), p)) {
        fprintf(stderr, "warn: failed to set thread priority %d : (%d)\n", prio, (int) GetLastError());
        return false;
    }

    return true;
}

#elif defined(__APPLE__)
#include <sys/types.h>
#include <sys/resource.h>

static bool ggml_thread_apply_affinity(const bool * mask) {
    // Not supported on Apple platforms
    UNUSED(mask);
    return true;
}

static bool ggml_thread_apply_priority(int32_t prio) {
    struct sched_param p;
    int32_t policy = SCHED_OTHER;
    switch (prio) {
        case GGML_SCHED_PRIO_NORMAL:   policy = SCHED_OTHER; p.sched_priority = 0;  break;
        case GGML_SCHED_PRIO_MEDIUM:   policy = SCHED_FIFO;  p.sched_priority = 40; break;
        case GGML_SCHED_PRIO_HIGH:     policy = SCHED_FIFO;  p.sched_priority = 80; break;
        case GGML_SCHED_PRIO_REALTIME: policy = SCHED_FIFO;  p.sched_priority = 90; break;
    }

    if (prio == GGML_SCHED_PRIO_NORMAL) {
        // Keep inherited policy/priority
        return true;
    }

    int32_t err = pthread_setschedparam(pthread_self(), policy, &p);
    if (err != 0) {
        fprintf(stderr, "warn: failed to set thread priority %d : %s (%d)\n", prio, strerror(err), err);
        return false;
    }

    return true;
}

#elif defined(__gnu_linux__)
// TODO: this may not work on BSD, to be verified
static bool ggml_thread_apply_affinity(const bool * mask) {
    cpu_set_t cpuset;
    int err;

    CPU_ZERO(&cpuset);

    for (uint32_t i = 0; i < GGML_MAX_N_THREADS; i++) {
        if (mask[i]) {
            GGML_PRINT_DEBUG("Thread %lx: adding %d to cpuset\n", pthread_self(), i);
            CPU_SET(i, &cpuset);
        }
    }

#ifdef __ANDROID__
    err = sched_setaffinity(0, sizeof(cpuset), &cpuset);
    if (err < 0) {
        err = errno;
    }
#else
    err = pthread_setaffinity_np(pthread_self(), sizeof(cpuset), &cpuset);
#endif
    if (err != 0) {
        fprintf(stderr, "warn: failed to set affinity mask 0x%llx : %s (%d)\n", (unsigned long long)mask, strerror(err), err);
        return false;
    }

    return true;
}

static bool ggml_thread_apply_priority(int32_t prio) {
    struct sched_param p;
    int32_t policy = SCHED_OTHER;
    switch (prio) {
        case GGML_SCHED_PRIO_NORMAL:   policy = SCHED_OTHER; p.sched_priority = 0;  break;
        case GGML_SCHED_PRIO_MEDIUM:   policy = SCHED_FIFO;  p.sched_priority = 40; break;
        case GGML_SCHED_PRIO_HIGH:     policy = SCHED_FIFO;  p.sched_priority = 80; break;
        case GGML_SCHED_PRIO_REALTIME: policy = SCHED_FIFO;  p.sched_priority = 90; break;
    }

    if (prio == GGML_SCHED_PRIO_NORMAL) {
        // Keep inherited policy/priority
        return true;
    }

    int32_t err = pthread_setschedparam(pthread_self(), policy, &p);
    if (err != 0) {
        fprintf(stderr, "warn: failed to set thread priority %d : %s (%d)\n", prio, strerror(err), err);
        return false;
    }

    return true;
}

#else // unsupported platforms

static bool ggml_thread_apply_affinity(const bool * mask) {
    UNUSED(mask);
    return true;
}

static bool ggml_thread_apply_priority(int32_t prio) {
    UNUSED(prio);
    return true;
}

#endif
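
// Helpers for the per-worker CPU masks: a mask is "valid" if at least one bit
// is set, and ggml_thread_cpumask_next() either copies the global mask as-is
// (non-strict placement) or hands out the next set bit round-robin so that each
// worker gets a single CPU (strict placement).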
static bool ggml_thread_cpumask_is_valid(const bool * mask) {
    for (int i = 0; i < GGML_MAX_N_THREADS; i++) {
        if (mask[i]) { return true; }
    }
    return false;
}

static void ggml_thread_cpumask_next(const bool * global_mask, bool * local_mask, bool strict, int32_t* iter) {
    if (!strict) {
        memcpy(local_mask, global_mask, GGML_MAX_N_THREADS);
        return;
    } else {
        memset(local_mask, 0, GGML_MAX_N_THREADS);
        int32_t base_idx = *iter;
        for (int32_t i = 0; i < GGML_MAX_N_THREADS; i++) {
            int32_t idx = base_idx + i;
            if (idx >= GGML_MAX_N_THREADS) {
                // Just a cheaper modulo
                idx -= GGML_MAX_N_THREADS;
            }
            if (global_mask[idx]) {
                local_mask[idx] = 1;
                *iter = idx + 1;
                return;
            }
        }
    }
}
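
// Threadpool teardown: signal stop, wake all workers, join them, then destroy
// the synchronization primitives and free the aligned allocations. Under
// GGML_USE_OPENMP there are no dedicated worker threads to join.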
void ggml_threadpool_free(struct ggml_threadpool* threadpool) {
    if (!threadpool) return;

    const int n_threads = threadpool->n_threads_max;

#ifndef GGML_USE_OPENMP
    struct ggml_compute_state* workers = threadpool->workers;

    ggml_mutex_lock(&threadpool->mutex);

    threadpool->stop = true;
    threadpool->pause = false;

    ggml_cond_broadcast(&threadpool->cond);
    ggml_mutex_unlock(&threadpool->mutex);

    for (int j = 1; j < n_threads; j++) {
        int32_t rc = ggml_thread_join(workers[j].thrd, NULL);
        GGML_ASSERT(rc == GGML_EXIT_SUCCESS || rc == GGML_EXIT_ABORTED);
        UNUSED(rc);
    }

    ggml_mutex_destroy(&threadpool->mutex);
    ggml_cond_destroy(&threadpool->cond);
#endif // GGML_USE_OPENMP

    const size_t workers_size = sizeof(struct ggml_compute_state) * n_threads;
    ggml_aligned_free(threadpool->workers, workers_size);
    ggml_aligned_free(threadpool, sizeof(struct ggml_threadpool));
}

#ifndef GGML_USE_OPENMP
// pause/resume must be called under mutex
static void ggml_threadpool_pause_locked(struct ggml_threadpool * threadpool) {
    GGML_PRINT_DEBUG("Pausing threadpool\n");
    threadpool->pause = true;
    ggml_cond_broadcast(&threadpool->cond);
}

static void ggml_threadpool_resume_locked(struct ggml_threadpool * threadpool) {
    GGML_PRINT_DEBUG("Resuming threadpool\n");
    threadpool->pause = false;
    ggml_cond_broadcast(&threadpool->cond);
}
#endif

void ggml_threadpool_pause(struct ggml_threadpool * threadpool) {
#ifndef GGML_USE_OPENMP
    ggml_mutex_lock(&threadpool->mutex);
    if (!threadpool->pause) {
        ggml_threadpool_pause_locked(threadpool);
    }
    ggml_mutex_unlock(&threadpool->mutex);
#else
    UNUSED(threadpool);
#endif
}

void ggml_threadpool_resume(struct ggml_threadpool * threadpool) {
#ifndef GGML_USE_OPENMP
    ggml_mutex_lock(&threadpool->mutex);
    if (threadpool->pause) {
        ggml_threadpool_resume_locked(threadpool);
    }
    ggml_mutex_unlock(&threadpool->mutex);
#else
    UNUSED(threadpool);
#endif
}
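
// ggml_graph_plan() walks the graph once to pick the thread count and to size
// the shared scratch buffer (cplan.work_size) that ops like MUL_MAT use for
// converted/quantized activations. The caller owns work_data. A minimal usage
// sketch (illustrative only; buffer management is up to the caller):
//
//   struct ggml_cplan cplan = ggml_graph_plan(graph, /*n_threads=*/8, /*threadpool=*/NULL);
//   cplan.work_data = malloc(cplan.work_size);   // or any buffer of at least work_size bytes
//   ggml_graph_compute(graph, &cplan);
//   free(cplan.work_data);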
struct ggml_cplan ggml_graph_plan(
        const struct ggml_cgraph * cgraph,
        int                        n_threads,
        struct ggml_threadpool   * threadpool) {

    if (threadpool == NULL) {
        //GGML_PRINT_DEBUG("Threadpool is not specified. Will create a disposable threadpool : n_threads %d\n", n_threads);
    }
    if (n_threads <= 0) {
        n_threads = threadpool ? threadpool->n_threads_max : GGML_DEFAULT_N_THREADS;
    }

    size_t work_size = 0;

    struct ggml_cplan cplan;
    memset(&cplan, 0, sizeof(struct ggml_cplan));

    int max_tasks = 1;

    // thread scheduling for the different operations + work buffer size estimation
    for (int i = 0; i < cgraph->n_nodes; i++) {
        struct ggml_tensor * node = cgraph->nodes[i];

        const int n_tasks = ggml_get_n_tasks(node, n_threads);

        max_tasks = MAX(max_tasks, n_tasks);

        size_t cur = 0;

        switch (node->op) {
            case GGML_OP_CPY:
            case GGML_OP_DUP:
                {
                    if (ggml_is_quantized(node->type) ||
                        // F16 -> BF16 and BF16 -> F16 copies go through intermediate F32
                        (node->src[0]->type == GGML_TYPE_F16  && node->src[1] && node->src[1]->type == GGML_TYPE_BF16) ||
                        (node->src[0]->type == GGML_TYPE_BF16 && node->src[1] && node->src[1]->type == GGML_TYPE_F16)) {
                        cur = ggml_type_size(GGML_TYPE_F32) * node->ne[0] * n_tasks;
                    }
                } break;
            case GGML_OP_ADD:
            case GGML_OP_ADD1:
                {
                    if (ggml_is_quantized(node->src[0]->type)) {
                        cur = ggml_type_size(GGML_TYPE_F32) * node->src[0]->ne[0] * n_tasks;
                    }
                } break;
            case GGML_OP_ACC:
                {
                    if (ggml_is_quantized(node->src[0]->type)) {
                        cur = ggml_type_size(GGML_TYPE_F32) * node->src[1]->ne[0] * n_tasks;
                    }
                } break;
            case GGML_OP_COUNT_EQUAL:
                {
                    cur = ggml_type_size(node->type)*n_tasks;
                } break;
            case GGML_OP_MUL_MAT:
                {
                    const enum ggml_type vec_dot_type = type_traits_cpu[node->src[0]->type].vec_dot_type;

                    if (node->src[1]->type != vec_dot_type) {
                        cur = ggml_row_size(vec_dot_type, ggml_nelements(node->src[1]));
                    }
                } break;
            case GGML_OP_MUL_MAT_ID:
                {
                    cur = 0;
                    const struct ggml_tensor * src0 = node->src[0];
                    const struct ggml_tensor * src1 = node->src[1];
                    const enum ggml_type vec_dot_type = type_traits_cpu[src0->type].vec_dot_type;
                    if (src1->type != vec_dot_type) {
                        cur += ggml_row_size(vec_dot_type, ggml_nelements(src1));
                    }
                    const int n_as = src0->ne[2];
                    cur += GGML_PAD(cur, sizeof(int64_t));       // align
                    cur += n_as * sizeof(int64_t);               // matrix_row_counts
                    cur += n_as * src1->ne[2] * sizeof(int64_t); // matrix_rows
                } break;
            case GGML_OP_OUT_PROD:
                {
                    if (ggml_is_quantized(node->src[0]->type)) {
                        cur = ggml_type_size(GGML_TYPE_F32) * node->src[0]->ne[0] * n_tasks;
                    }
                } break;
            case GGML_OP_SOFT_MAX:
            case GGML_OP_ROPE:
                {
                    cur = ggml_type_size(GGML_TYPE_F32) * node->ne[0] * n_tasks;
                } break;
            case GGML_OP_CONV_TRANSPOSE_1D:
                {
                    GGML_ASSERT(node->src[0]->ne[3] == 1);
                    GGML_ASSERT(node->src[1]->ne[2] == 1);
                    GGML_ASSERT(node->src[1]->ne[3] == 1);

                    const int64_t ne00 = node->src[0]->ne[0];  // K
                    const int64_t ne01 = node->src[0]->ne[1];  // Cout
                    const int64_t ne02 = node->src[0]->ne[2];  // Cin
                    const int64_t ne10 = node->src[1]->ne[0];  // L
                    const int64_t ne11 = node->src[1]->ne[1];  // Cin

                    if ((node->src[0]->type == GGML_TYPE_F16 ||
                         node->src[0]->type == GGML_TYPE_BF16) &&
                        node->src[1]->type == GGML_TYPE_F32) {
                        cur += sizeof(ggml_fp16_t)*ne00*ne01*ne02;
                        cur += sizeof(ggml_fp16_t)*ne10*ne11;
                    } else if (node->src[0]->type == GGML_TYPE_F32 &&
                               node->src[1]->type == GGML_TYPE_F32) {
                        cur += sizeof(float)*ne00*ne01*ne02;
                        cur += sizeof(float)*ne10*ne11;
                    } else {
                        GGML_ABORT("fatal error");
                    }
                } break;
            case GGML_OP_CONV_TRANSPOSE_2D:
                {
                    const int64_t ne00 = node->src[0]->ne[0]; // W
                    const int64_t ne01 = node->src[0]->ne[1]; // H
                    const int64_t ne02 = node->src[0]->ne[2]; // Channels Out
                    const int64_t ne03 = node->src[0]->ne[3]; // Channels In

                    const int64_t ne10 = node->src[1]->ne[0]; // W
                    const int64_t ne11 = node->src[1]->ne[1]; // H
                    const int64_t ne12 = node->src[1]->ne[2]; // Channels In

                    cur += sizeof(ggml_fp16_t)*ne00*ne01*ne02*ne03;
                    cur += sizeof(ggml_fp16_t)*ne10*ne11*ne12;
                } break;
            case GGML_OP_FLASH_ATTN_EXT:
                {
                    const int64_t ne00 = node->src[0]->ne[0]; // D

                    cur = 3*sizeof(float)*ne00*n_tasks; // 3x head size/thread
                } break;
            case GGML_OP_FLASH_ATTN_BACK:
                {
                    const int64_t    D = node->src[0]->ne[0];
                    const int64_t ne11 = ggml_up(node->src[1]->ne[1], GGML_SOFT_MAX_UNROLL);
                    const int64_t mxDn = MAX(D, ne11) * 2; // *2 because of S and SM in ggml_compute_forward_flash_attn_back
                    if (node->src[1]->type == GGML_TYPE_F32) {
                        cur  = sizeof(float)*mxDn*n_tasks; // TODO: this can become (n_tasks-1)
                        cur += sizeof(float)*mxDn*n_tasks; // this is overestimated by x2
                    } else if (node->src[1]->type == GGML_TYPE_F16) {
                        cur  = sizeof(float)*mxDn*n_tasks; // TODO: this can become (n_tasks-1)
                        cur += sizeof(float)*mxDn*n_tasks; // this is overestimated by x2
                    } else if (node->src[1]->type == GGML_TYPE_BF16) {
                        cur  = sizeof(float)*mxDn*n_tasks; // TODO: this can become (n_tasks-1)
                        cur += sizeof(float)*mxDn*n_tasks; // this is overestimated by x2
                    }
                } break;
            case GGML_OP_CROSS_ENTROPY_LOSS:
                {
                    cur = ggml_type_size(node->type)*(n_tasks + node->src[0]->ne[0]*n_tasks);
                } break;
            case GGML_OP_COUNT:
                {
                    GGML_ABORT("fatal error");
                }
            default:
                break;
        }

        work_size = MAX(work_size, cur);
    }

    if (work_size > 0) {
        work_size += CACHE_LINE_SIZE*(n_threads);
    }

    cplan.threadpool = threadpool;
    cplan.n_threads  = MIN(max_tasks, n_threads);
    cplan.work_size  = work_size;
    cplan.work_data  = NULL;

    return cplan;
}
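
// Each worker (including the calling thread) runs ggml_graph_compute_thread():
// all threads walk the node list together and synchronize with a barrier after
// every node, so a node only starts once the previous one is fully computed.
// Thread 0 additionally checks the user abort callback between nodes.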
static thread_ret_t ggml_graph_compute_thread(void * data) {
    struct ggml_compute_state * state = (struct ggml_compute_state *) data;
    struct ggml_threadpool    * tp    = state->threadpool;

    const struct ggml_cgraph * cgraph = tp->cgraph;
    const struct ggml_cplan  * cplan  = tp->cplan;

    set_numa_thread_affinity(state->ith);

    struct ggml_compute_params params = {
        /*.ith       =*/ state->ith,
        /*.nth       =*/ atomic_load_explicit(&tp->n_threads_cur, memory_order_relaxed),
        /*.wsize     =*/ cplan->work_size,
        /*.wdata     =*/ cplan->work_data,
        /*.threadpool=*/ tp,
    };

    for (int node_n = 0; node_n < cgraph->n_nodes && !tp->abort; node_n++) {
        struct ggml_tensor * node = cgraph->nodes[node_n];

        ggml_compute_forward(&params, node);

        if (state->ith == 0 && cplan->abort_callback &&
                cplan->abort_callback(cplan->abort_callback_data)) {
            tp->abort = true;
            tp->ec    = GGML_STATUS_ABORTED;
        }

        ggml_barrier(state->threadpool);
    }

    return 0;
}
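
// Non-OpenMP worker loop: between graphs each secondary thread first spins
// (hybrid polling, bounded by threadpool->poll) and only then sleeps on the
// condition variable, trading a little idle CPU for lower kickoff latency.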
#ifndef GGML_USE_OPENMP

// check if thread is active
static inline bool ggml_graph_compute_thread_active(struct ggml_compute_state * state) {
    struct ggml_threadpool * threadpool = state->threadpool;
    int n_threads = atomic_load_explicit(&threadpool->n_threads_cur, memory_order_relaxed);
    return (state->ith < n_threads);
}

// check if thread is ready to proceed (exit from polling or sleeping)
static inline bool ggml_graph_compute_thread_ready(struct ggml_compute_state * state) {
    struct ggml_threadpool * threadpool = state->threadpool;

    if (state->pending || threadpool->stop || threadpool->pause) { return true; }

    // check for new graph/work
    int new_graph = atomic_load_explicit(&threadpool->n_graph, memory_order_relaxed);
    if (new_graph != state->last_graph) {
        state->pending    = ggml_graph_compute_thread_active(state);
        state->last_graph = new_graph;
    }

    return state->pending;
}

// sync thread state after polling
static inline void ggml_graph_compute_thread_sync(struct ggml_compute_state * state) {
    // TSAN doesn't support standalone fence yet, we use a dummy read-modify-write instead
    #ifdef GGML_TSAN_ENABLED
    atomic_fetch_add_explicit(&state->threadpool->n_graph, 0, memory_order_seq_cst);
    #else
    atomic_thread_fence(memory_order_seq_cst);
    #endif
    UNUSED(state);
}

static inline bool ggml_graph_compute_poll_for_work(struct ggml_compute_state * state) {
    struct ggml_threadpool * threadpool = state->threadpool;

    // Skip polling for unused threads
    if (!ggml_graph_compute_thread_active(state)) {
        return state->pending;
    }

    // This seems to make 0 ... 100 a decent range for polling level across modern processors.
    // Perhaps, we can adjust it dynamically based on load and things.
    const uint64_t n_rounds = 1024UL * 128 * threadpool->poll;

    for (uint64_t i = 0; !ggml_graph_compute_thread_ready(state) && i < n_rounds; i++) {
        // No new work. Keep polling.
        ggml_thread_cpu_relax();
    }

    return state->pending;
}

static inline bool ggml_graph_compute_check_for_work(struct ggml_compute_state * state) {
    struct ggml_threadpool * threadpool = state->threadpool;

    if (ggml_graph_compute_poll_for_work(state)) {
        ggml_graph_compute_thread_sync(state);
        return state->pending;
    }

    ggml_mutex_lock_shared(&threadpool->mutex);
    while (!ggml_graph_compute_thread_ready(state)) {
        // No new work. Wait for the signal.
        GGML_PRINT_DEBUG("thread #%d waiting for work (sleeping)\n", state->ith);
        ggml_cond_wait(&threadpool->cond, &threadpool->mutex);
    }
    ggml_mutex_unlock_shared(&threadpool->mutex);

    return state->pending;
}

static thread_ret_t ggml_graph_compute_secondary_thread(void* data) {
    struct ggml_compute_state * state = (struct ggml_compute_state *) data;
    struct ggml_threadpool * threadpool = state->threadpool;

    ggml_thread_apply_priority(threadpool->prio);
    if (ggml_thread_cpumask_is_valid(state->cpumask)) {
        ggml_thread_apply_affinity(state->cpumask);
    }

    while (true) {
        // Check if we need to sleep
        while (threadpool->pause) {
            GGML_PRINT_DEBUG("thread #%d inside pause loop\n", state->ith);
            ggml_mutex_lock_shared(&threadpool->mutex);
            if (threadpool->pause) {
                ggml_cond_wait(&threadpool->cond, &threadpool->mutex);
            }
            GGML_PRINT_DEBUG("thread #%d resuming after wait\n", state->ith);
            ggml_mutex_unlock_shared(&threadpool->mutex);
        }

        // This needs to be checked for after the cond_wait
        if (threadpool->stop) break;

        // Check if there is new work
        // The main thread is the only one that can dispatch new work
        ggml_graph_compute_check_for_work(state);
        if (state->pending) {
            state->pending = false;

            ggml_graph_compute_thread(state);
        }
    }

    return (thread_ret_t) 0;
}

// Start processing new graph
static void ggml_graph_compute_kickoff(struct ggml_threadpool * threadpool, int n_threads)
{
    // Always take the mutex here because the worker threads are doing hybrid poll/wait

    ggml_mutex_lock(&threadpool->mutex);

    GGML_PRINT_DEBUG("threadpool: n_threads_cur %d n_threads %d\n", threadpool->n_threads_cur, n_threads);

    // Update the number of active threads
    atomic_store_explicit(&threadpool->n_threads_cur, n_threads, memory_order_relaxed);

    // Indicate the graph is ready to be processed
    // We need the full seq-cst fence here because of the polling threads (used in thread_sync)
    atomic_fetch_add_explicit(&threadpool->n_graph, 1, memory_order_seq_cst);

    if (threadpool->pause) {
        // Update main thread prio and affinity to match the threadpool settings
        ggml_thread_apply_priority(threadpool->prio);
        if (ggml_thread_cpumask_is_valid(threadpool->workers[0].cpumask)) {
            ggml_thread_apply_affinity(threadpool->workers[0].cpumask);
        }

        // resume does cond broadcast
        ggml_threadpool_resume_locked(threadpool);
    } else {
        ggml_cond_broadcast(&threadpool->cond);
    }

    ggml_mutex_unlock(&threadpool->mutex);
}

#endif // GGML_USE_OPENMP
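
// Threadpool parameter helpers. Defaults: normal priority, hybrid polling
// enabled (poll = 50), no strict CPU placement, and an all-zero cpumask which
// means "use the default/inherited affinity". A minimal sketch of driving a
// persistent threadpool (illustrative only; not part of the implementation):
//
//   struct ggml_threadpool_params tpp = ggml_threadpool_params_default(8);
//   struct ggml_threadpool * tp = ggml_threadpool_new(&tpp);
//   struct ggml_cplan cplan = ggml_graph_plan(graph, tpp.n_threads, tp);
//   // ... set cplan.work_data, then call ggml_graph_compute(graph, &cplan) ...
//   ggml_threadpool_free(tp);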
void ggml_threadpool_params_init(struct ggml_threadpool_params * p, int n_threads) {
    p->n_threads  = n_threads;
    p->prio       = 0;     // default priority (usually means normal or inherited)
    p->poll       = 50;    // hybrid-polling enabled
    p->strict_cpu = false; // no strict placement (all threads share same cpumask)
    p->paused     = false; // threads are ready to go
    memset(p->cpumask, 0, GGML_MAX_N_THREADS); // all-zero means use the default affinity (usually inherited)
}

struct ggml_threadpool_params ggml_threadpool_params_default(int n_threads) {
    struct ggml_threadpool_params p;
    ggml_threadpool_params_init(&p, n_threads);
    return p;
}

bool ggml_threadpool_params_match(const struct ggml_threadpool_params * p0, const struct ggml_threadpool_params * p1) {
    if (p0->n_threads  != p1->n_threads )  return false;
    if (p0->prio       != p1->prio )       return false;
    if (p0->poll       != p1->poll )       return false;
    if (p0->strict_cpu != p1->strict_cpu ) return false;
    return memcmp(p0->cpumask, p1->cpumask, GGML_MAX_N_THREADS) == 0;
}

static struct ggml_threadpool * ggml_threadpool_new_impl(
    struct ggml_threadpool_params * tpp,
               struct ggml_cgraph * cgraph,
                struct ggml_cplan * cplan) {

    struct ggml_threadpool * threadpool =
        ggml_aligned_malloc(sizeof(struct ggml_threadpool));
    {
        threadpool->cgraph           = cgraph;
        threadpool->cplan            = cplan;
        threadpool->n_graph          = 0;
        threadpool->n_barrier        = 0;
        threadpool->n_barrier_passed = 0;
        threadpool->current_chunk    = 0;
        threadpool->stop             = false;
        threadpool->pause            = tpp->paused;
        threadpool->abort            = false;
        threadpool->workers          = NULL;
        threadpool->n_threads_max    = tpp->n_threads;
        threadpool->n_threads_cur    = tpp->n_threads;
        threadpool->poll             = tpp->poll;
        threadpool->prio             = tpp->prio;
        threadpool->ec               = GGML_STATUS_SUCCESS;
    }

    // Allocate and init workers state
    const size_t workers_size = sizeof(struct ggml_compute_state) * tpp->n_threads;
    struct ggml_compute_state * workers = ggml_aligned_malloc(workers_size);

    memset(workers, 0, workers_size);
    for (int j = 0; j < tpp->n_threads; j++) {
        workers[j].threadpool = threadpool;
        workers[j].ith        = j;
    }

    threadpool->workers = workers;

#ifndef GGML_USE_OPENMP
    ggml_mutex_init(&threadpool->mutex);
    ggml_cond_init(&threadpool->cond);

    // Spin the threads for all workers, and update CPU placements.
    // Place the main thread last (towards the higher numbered CPU cores).

    int32_t cpumask_iter = 0;

    for (int j = 1; j < tpp->n_threads; j++) {
        ggml_thread_cpumask_next(tpp->cpumask, workers[j].cpumask, tpp->strict_cpu, &cpumask_iter);

        int32_t rc = ggml_thread_create(&workers[j].thrd, NULL, ggml_graph_compute_secondary_thread, &workers[j]);
        GGML_ASSERT(rc == 0);
    }

    ggml_thread_cpumask_next(tpp->cpumask, workers[0].cpumask, tpp->strict_cpu, &cpumask_iter);

    if (!threadpool->pause) {
        // Update main thread prio and affinity at the start, otherwise we'll do it in resume
        ggml_thread_apply_priority(threadpool->prio);
        if (ggml_thread_cpumask_is_valid(threadpool->workers[0].cpumask)) {
            ggml_thread_apply_affinity(threadpool->workers[0].cpumask);
        }
    }
#endif // GGML_USE_OPENMP

    return threadpool;
}

struct ggml_threadpool * ggml_threadpool_new(struct ggml_threadpool_params * tpp) {
    return ggml_threadpool_new_impl(tpp, NULL, NULL);
}
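
// ggml_graph_compute() is the main entry point. If no threadpool is attached to
// the cplan it creates a disposable one for this call and frees it afterwards.
// With GGML_USE_OPENMP the workers are OpenMP threads; otherwise the persistent
// workers are kicked off and the calling thread joins in as worker 0.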
enum ggml_status ggml_graph_compute(struct ggml_cgraph * cgraph, struct ggml_cplan * cplan) {
    ggml_cpu_init();

    GGML_ASSERT(cplan);
    GGML_ASSERT(cplan->n_threads > 0);
    GGML_ASSERT(cplan->work_size == 0 || cplan->work_data != NULL);

    int                      n_threads  = cplan->n_threads;
    struct ggml_threadpool * threadpool = cplan->threadpool;

    bool disposable_threadpool = false;

    if (threadpool == NULL) {
        //GGML_PRINT_DEBUG("Threadpool is not specified. Will create a disposable threadpool : n_threads %d\n", n_threads);
        disposable_threadpool = true;

        struct ggml_threadpool_params ttp = ggml_threadpool_params_default(n_threads);
        threadpool = ggml_threadpool_new_impl(&ttp, cgraph, cplan);
    } else {
        // Reset some of the parameters that need resetting
        // No worker threads should be accessing the parameters below at this stage
        threadpool->cgraph        = cgraph;
        threadpool->cplan         = cplan;
        threadpool->current_chunk = 0;
        threadpool->abort         = false;
        threadpool->ec            = GGML_STATUS_SUCCESS;
    }

#ifdef GGML_USE_OPENMP
    if (n_threads > 1) {
        #pragma omp parallel num_threads(n_threads)
        {
            #pragma omp single
            {
                // update the number of threads from the actual number of threads that we got from OpenMP
                n_threads = omp_get_num_threads();
                atomic_store_explicit(&threadpool->n_threads_cur, n_threads, memory_order_relaxed);
            }

            ggml_graph_compute_thread(&threadpool->workers[omp_get_thread_num()]);
        }
    } else {
        atomic_store_explicit(&threadpool->n_threads_cur, 1, memory_order_relaxed);
        ggml_graph_compute_thread(&threadpool->workers[0]);
    }
#else
    if (n_threads > threadpool->n_threads_max) {
        GGML_LOG_WARN("cplan requested more threads (%d) than available (%d)\n", n_threads, threadpool->n_threads_max);
        n_threads = threadpool->n_threads_max;
    }

    // Kick all threads to start the new graph
    ggml_graph_compute_kickoff(threadpool, n_threads);

    // This is a work thread too
    ggml_graph_compute_thread(&threadpool->workers[0]);
#endif

    // don't leave affinity set on the main thread
    clear_numa_thread_affinity();

    enum ggml_status ret = threadpool->ec;

    if (disposable_threadpool) {
        ggml_threadpool_free(threadpool);
    }

    return ret;
}

enum ggml_status ggml_graph_compute_with_ctx(struct ggml_context * ctx, struct ggml_cgraph * cgraph, int n_threads) {
    struct ggml_cplan cplan = ggml_graph_plan(cgraph, n_threads, NULL);

    cplan.work_data = (uint8_t *)ggml_new_buffer(ctx, cplan.work_size);

    return ggml_graph_compute(cgraph, &cplan);
}
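
// Runtime queries for Arm CPU features detected by ggml_init_arm_arch_features();
// they simply return 0 when the binary is not built for Arm (__ARM_ARCH undefined).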
int ggml_cpu_has_neon(void) {
#if defined(__ARM_ARCH)
    return ggml_arm_arch_features.has_neon;
#else
    return 0;
#endif
}

int ggml_cpu_has_sve(void) {
#if defined(__ARM_ARCH)
    return ggml_arm_arch_features.has_sve;
#else
    return 0;
#endif
}

int ggml_cpu_has_matmul_int8(void) {
#if defined(__ARM_ARCH)
    return ggml_arm_arch_features.has_i8mm;
#else
    return 0;
#endif
}

int ggml_cpu_get_sve_cnt(void) {
#if defined(__ARM_ARCH)
    return ggml_arm_arch_features.sve_cnt;
#else
    return 0;
#endif
}
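
// One-time CPU backend initialization: builds the 64K-entry f16 lookup tables
// for the GELU variants and probes Arm features. Guarded by the ggml critical
// section and an is_first_call flag, so repeated calls are cheap and thread-safe.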
void ggml_cpu_init(void) {
    // needed to initialize f16 tables
    {
        struct ggml_init_params params = { 0, NULL, false };
        struct ggml_context * ctx = ggml_init(params);
        ggml_free(ctx);
    }

    ggml_critical_section_start();

    static bool is_first_call = true;

    if (is_first_call) {
        // initialize GELU, Quick GELU, SILU and EXP F32 tables
        {
            const uint64_t t_start = ggml_time_us(); UNUSED(t_start);

            for (int i = 0; i < (1 << 16); ++i) {
                union {
                    uint16_t u16;
                    ggml_fp16_t fp16;
                } u = {i};
                float f = GGML_FP16_TO_FP32(u.fp16);
                ggml_table_gelu_f16[i]       = GGML_FP32_TO_FP16(ggml_gelu_f32(f));
                ggml_table_gelu_quick_f16[i] = GGML_FP32_TO_FP16(ggml_gelu_quick_f32(f));
            }

            const uint64_t t_end = ggml_time_us(); UNUSED(t_end);

            GGML_PRINT_DEBUG("%s: GELU, Quick GELU, SILU and EXP tables initialized in %f ms\n", __func__, (t_end - t_start)/1000.0);
        }

#if defined(__ARM_ARCH)
        ggml_init_arm_arch_features();
#endif

        is_first_call = false;
    }

    ggml_critical_section_end();
}